aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndreas Grois <andi@grois.info>2025-01-08 21:35:06 +0100
committerAndreas Grois <andi@grois.info>2025-01-08 21:35:06 +0100
commitbf12dcd38698d9e7fd0396722d7dd75ea1cefc88 (patch)
tree75d7cc60c55224736e36177fd0c79597d9afe61e
parent4cc7ae8a6653402e091d246ed746a0ed45b099dc (diff)
Change AStarNode type class to use out param for Costs.
Fixes instance search for cost addition.
-rw-r--r--LeanAStar/Basic.lean52
1 files changed, 24 insertions, 28 deletions
diff --git a/LeanAStar/Basic.lean b/LeanAStar/Basic.lean
index 15f8c76..ba1cc33 100644
--- a/LeanAStar/Basic.lean
+++ b/LeanAStar/Basic.lean
@@ -5,29 +5,28 @@ import LeanAStar.HashSet
namespace LeanAStar
/-- The type-class any node type needs to implement in order to be usable in AStar -/
-class AStarNode.{u, v} (α : Type u) extends Finite α, Hashable α, BEq α where
- Costs : Type v
+class AStarNode.{u, v} (α : Type u) (Costs : outParam (Type v)) extends Finite α, Hashable α, BEq α where
costsLe : Costs → Costs → Bool
costs_order : BinaryHeap.TotalAndTransitiveLe costsLe
getNeighbours : α → List (α × Costs)
isGoal : α → Bool
remaining_costs_heuristic : α → Costs
-protected structure OpenSetEntry (α : Type u) [AStarNode α] where
+protected structure OpenSetEntry (α : Type u) [AStarNode α Costs] where
node : α
- accumulated_costs : AStarNode.Costs α
- estimated_total_costs : AStarNode.Costs α
+ accumulated_costs : Costs
+ estimated_total_costs : Costs
-protected def OpenSetEntry.le {α : Type u} [AStarNode α] (a b : LeanAStar.OpenSetEntry α) : Bool :=
- AStarNode.costsLe a.estimated_total_costs b.estimated_total_costs
+protected def OpenSetEntry.le {α : Type u} [AStarNode α Costs] (a b : LeanAStar.OpenSetEntry α) : Bool :=
+ AStarNode.costsLe α a.estimated_total_costs b.estimated_total_costs
-protected def OpenSetEntry.le_total {α : Type u} [AStarNode α] : BinaryHeap.total_le (α := LeanAStar.OpenSetEntry α) OpenSetEntry.le :=
+protected def OpenSetEntry.le_total {α : Type u} [AStarNode α Costs] : BinaryHeap.total_le (α := LeanAStar.OpenSetEntry α) OpenSetEntry.le :=
λa b ↦ AStarNode.costs_order.right a.estimated_total_costs b.estimated_total_costs
-protected def OpenSetEntry.le_trans {α : Type u} [AStarNode α] : BinaryHeap.transitive_le (α := LeanAStar.OpenSetEntry α) OpenSetEntry.le :=
+protected def OpenSetEntry.le_trans {α : Type u} [AStarNode α Costs] : BinaryHeap.transitive_le (α := LeanAStar.OpenSetEntry α) OpenSetEntry.le :=
λa b c ↦ AStarNode.costs_order.left a.estimated_total_costs b.estimated_total_costs c.estimated_total_costs
-protected def findFirstNotInClosedSet {α : Type u} [AStarNode α] {n : Nat} (openSet : BinaryHeap (LeanAStar.OpenSetEntry α) OpenSetEntry.le n) (closedSet : Std.HashSet α) : Option ((r : Nat) × LeanAStar.OpenSetEntry α × BinaryHeap (LeanAStar.OpenSetEntry α) OpenSetEntry.le r) :=
+protected def findFirstNotInClosedSet {α : Type u} [AStarNode α Costs] {n : Nat} (openSet : BinaryHeap (LeanAStar.OpenSetEntry α) OpenSetEntry.le n) (closedSet : Std.HashSet α) : Option ((r : Nat) × LeanAStar.OpenSetEntry α × BinaryHeap (LeanAStar.OpenSetEntry α) OpenSetEntry.le r) :=
match n, openSet with
| 0, _ => none
| m+1, openSet =>
@@ -37,7 +36,7 @@ protected def findFirstNotInClosedSet {α : Type u} [AStarNode α] {n : Nat} (op
else
some ⟨m, openSetEntry, openSet⟩
-protected theorem findFirstNotInClosedSet_not_in_closed_set {α : Type u} [AStarNode α] {n : Nat} (openSet : BinaryHeap (LeanAStar.OpenSetEntry α) OpenSetEntry.le n) (closedSet : Std.HashSet α) {result : (r : Nat) × LeanAStar.OpenSetEntry α × BinaryHeap (LeanAStar.OpenSetEntry α) OpenSetEntry.le r} (h₁ : LeanAStar.findFirstNotInClosedSet openSet closedSet = some result) : ¬closedSet.contains result.snd.fst.node := by
+protected theorem findFirstNotInClosedSet_not_in_closed_set {α : Type u} [AStarNode α Costs] {n : Nat} (openSet : BinaryHeap (LeanAStar.OpenSetEntry α) OpenSetEntry.le n) (closedSet : Std.HashSet α) {result : (r : Nat) × LeanAStar.OpenSetEntry α × BinaryHeap (LeanAStar.OpenSetEntry α) OpenSetEntry.le r} (h₁ : LeanAStar.findFirstNotInClosedSet openSet closedSet = some result) : ¬closedSet.contains result.snd.fst.node := by
simp
unfold LeanAStar.findFirstNotInClosedSet at h₁
split at h₁; contradiction
@@ -52,7 +51,7 @@ protected theorem findFirstNotInClosedSet_not_in_closed_set {α : Type u} [AStar
subst result
assumption
-protected def findPath_Aux {α : Type u} [AStarNode α] [Add (AStarNode.Costs α)] [LawfulBEq α] {n : Nat} (openSet : BinaryHeap (LeanAStar.OpenSetEntry α) OpenSetEntry.le n) (closedSet : Std.HashSet α) : Option (AStarNode.Costs α) :=
+protected def findPath_Aux {α : Type u} [AStarNode α Costs] [Add Costs] [LawfulBEq α] {n : Nat} (openSet : BinaryHeap (LeanAStar.OpenSetEntry α) OpenSetEntry.le n) (closedSet : Std.HashSet α) : Option Costs :=
match _h₁ : LeanAStar.findFirstNotInClosedSet openSet closedSet with
| none => none
| some ⟨_,({node, accumulated_costs,..}, openSet)⟩ =>
@@ -81,21 +80,21 @@ decreasing_by
have : closedSet.size < (Finite.cardinality α) := Std.HashSet.size_lt_finite_cardinality_of_not_mem closedSet ⟨_,h₂⟩
omega
-structure StartPoint (α : Type u) [AStarNode α] where
+structure StartPoint (α : Type u) [AStarNode α Costs] where
start : α
- initial_costs : AStarNode.Costs α
+ initial_costs : Costs
-protected structure PathFindCosts (α : Type u) [AStarNode α] where
+protected structure PathFindCosts (α : Type u) [AStarNode α Costs] where
previousNodes : List α
- actualCosts : AStarNode.Costs α
+ actualCosts : Costs
-protected structure PathFindCostsAdapter (α : Type u) [AStarNode α] where
+protected structure PathFindCostsAdapter (α : Type u) [AStarNode α Costs] where
node : α
private theorem Function.comp_assoc {α : Type u} {β : Type v} {γ : Type w} {δ : Type x} (f : γ → δ) (g : β → γ) (h : α → β) : f ∘ g ∘ h = (f ∘ g) ∘ h := rfl
private theorem Function.comp_id_neutral_left {α : Type u} {β : Type v} (f : α → β) : id ∘ f = f := rfl
-protected instance {α : Type u} [AStarNode α] : AStarNode (LeanAStar.PathFindCostsAdapter α) where
+protected instance {α : Type u} [AStarNode α Costs] : AStarNode (LeanAStar.PathFindCostsAdapter α) (LeanAStar.PathFindCosts α) where
cardinality := Finite.cardinality α
enumerate := Finite.enumerate ∘ PathFindCostsAdapter.node
nth := PathFindCostsAdapter.mk ∘ Finite.nth
@@ -112,8 +111,7 @@ protected instance {α : Type u} [AStarNode α] : AStarNode (LeanAStar.PathFindC
exact Finite.enumerate_inverse_nth
hash := Hashable.hash ∘ PathFindCostsAdapter.node
beq := λa b ↦ a.node == b.node
- Costs := LeanAStar.PathFindCosts α
- costsLe := λa b ↦ AStarNode.costsLe a.actualCosts b.actualCosts
+ costsLe := λa b ↦ AStarNode.costsLe α a.actualCosts b.actualCosts
costs_order := ⟨λa b c ↦ AStarNode.costs_order.left a.actualCosts b.actualCosts c.actualCosts, λa b ↦ AStarNode.costs_order.right a.actualCosts b.actualCosts⟩
getNeighbours := λx ↦
(AStarNode.getNeighbours x.node).map λ(node, actualCosts) ↦ ({node},{previousNodes := [x.node], actualCosts})
@@ -121,34 +119,34 @@ protected instance {α : Type u} [AStarNode α] : AStarNode (LeanAStar.PathFindC
remaining_costs_heuristic := λx ↦
{previousNodes := [x.node], actualCosts := AStarNode.remaining_costs_heuristic x.node}
-protected instance {α : Type u} [AStarNode α] [Add (AStarNode.Costs α)] : Add (LeanAStar.PathFindCosts α) where
+protected instance {α : Type u} [AStarNode α Costs] [Add Costs] : Add (LeanAStar.PathFindCosts α) where
add := λa b ↦
{
previousNodes := b.previousNodes ++ a.previousNodes
actualCosts := a.actualCosts + b.actualCosts
}
-protected instance {α : Type u} [AStarNode α] [LawfulBEq α] : LawfulBEq (LeanAStar.PathFindCostsAdapter α) where
+protected instance {α : Type u} [AStarNode α Costs] [LawfulBEq α] : LawfulBEq (LeanAStar.PathFindCostsAdapter α) where
rfl := by
intro a
cases a
- unfold BEq.beq AStarNode.toBEq LeanAStar.instAStarNodePathFindCostsAdapter
+ unfold BEq.beq AStarNode.toBEq LeanAStar.instAStarNodePathFindCostsAdapterPathFindCosts
simp only [beq_self_eq_true]
eq_of_beq := by
intros a b
cases a
cases b
- unfold BEq.beq AStarNode.toBEq LeanAStar.instAStarNodePathFindCostsAdapter
+ unfold BEq.beq AStarNode.toBEq LeanAStar.instAStarNodePathFindCostsAdapterPathFindCosts
simp only [beq_iff_eq, PathFindCostsAdapter.mk.injEq, imp_self]
/-- Returns the lowest-costs from any start to the nearest goal. -/
-def findLowestCosts {α : Type u} [AStarNode α] [Add (AStarNode.Costs α)] [LawfulBEq α] (starts : List (StartPoint (α := α))) : Option (AStarNode.Costs α) :=
+def findLowestCosts {α : Type u} [AStarNode α Costs] [Add Costs] [LawfulBEq α] (starts : List (StartPoint (α := α))) : Option Costs :=
let openSet := BinaryHeap.ofList ⟨OpenSetEntry.le_trans, OpenSetEntry.le_total⟩ $ starts.map λ{start, initial_costs}↦
{node := start, accumulated_costs := initial_costs, estimated_total_costs:= AStarNode.remaining_costs_heuristic start : LeanAStar.OpenSetEntry α}
LeanAStar.findPath_Aux openSet Std.HashSet.empty
/-- Helper to not only get the lowest costs, but also the shortest path. Could be implemented more efficient. -/
-def findShortestPath {α : Type u} [AStarNode α] [Add (AStarNode.Costs α)] [LawfulBEq α] (starts : List (StartPoint (α := α))) : Option (AStarNode.Costs α × List α) :=
+def findShortestPath {α : Type u} [AStarNode α Costs] [Add Costs] [LawfulBEq α] (starts : List (StartPoint (α := α))) : Option (Costs × List α) :=
let starts : List (StartPoint (α := LeanAStar.PathFindCostsAdapter α)) := starts.map λ{start, initial_costs} ↦
{
start := {node := start}
@@ -157,8 +155,6 @@ def findShortestPath {α : Type u} [AStarNode α] [Add (AStarNode.Costs α)] [La
actualCosts := initial_costs
}
}
- -- no idea why this is needed, but without this in the local context, the call to findLowerstCosts fails
- have : Add (LeanAStar.PathFindCosts α) := inferInstance
let i := findLowestCosts starts
i.map λ{previousNodes, actualCosts} ↦
(actualCosts, previousNodes)