summaryrefslogtreecommitdiff
path: root/Common
diff options
context:
space:
mode:
authorAndreas Grois <andi@grois.info>2024-07-20 00:10:13 +0200
committerAndreas Grois <andi@grois.info>2024-07-20 00:10:13 +0200
commit8659efb6bf0f3e21f0ab8d78e657739bc2238142 (patch)
tree8643472ed6fc99af1b326fdd3e3d213952cab671 /Common
parentb22f0f04c6fdf378a4c586cab657975f7d49f992 (diff)
Remove CompleteTree.indexOfLast, add CompleteTree.heapRemoveLastWithIndex
indexOfLast was going through the tree the same way as heapRemoveLast did, so heapRemoveLast now can optionally compute the index.
Diffstat (limited to 'Common')
-rw-r--r--Common/BinaryHeap.lean137
1 files changed, 77 insertions, 60 deletions
diff --git a/Common/BinaryHeap.lean b/Common/BinaryHeap.lean
index 4cbc802..88afeda 100644
--- a/Common/BinaryHeap.lean
+++ b/Common/BinaryHeap.lean
@@ -400,38 +400,7 @@ def CompleteTree.get {α : Type u} {n : Nat} (index : Fin (n+1)) (heap : Complet
match p with
| (pp + 1) => get ⟨j - o, h₆⟩ r
---TODO: Make this use numbers instead of traversing the actual tree.
-def CompleteTree.indexOfLast {α : Type u} (heap : CompleteTree α (o+1)) : Fin (o+1) :=
- match o, heap with
- | (n+m), .branch _ l r _ _ _ =>
- if p : 0 = (n+m) then
- 0
- else
- let rightIsFull : Bool := (m+1).nextPowerOfTwo = m+1
- have m_gt_0_or_rightIsFull : m > 0 ∨ rightIsFull := by cases m <;> simp_arith (config := { ground:=true })[rightIsFull]
- if h₁ : m < n ∧ rightIsFull then
- match n with
- | .zero => absurd (Nat.zero_lt_of_lt h₁.left) (Nat.lt_irrefl 0)
- | .succ nn => (l.indexOfLast.succ.castAdd m).cast (by simp_arith)
- else
- have : m > 0 := by
- cases m_gt_0_or_rightIsFull
- case inl => assumption
- case inr h =>
- simp_arith [h] at h₁
- cases n
- case zero =>
- simp[Nat.zero_lt_of_ne_zero] at p
- exact Nat.zero_lt_of_ne_zero (Ne.symm p)
- case succ q _ _ _ =>
- cases m
- . exact False.elim $ Nat.not_succ_le_zero q h₁
- . simp_arith
- match m with
- | .succ mm =>
- ⟨r.indexOfLast.val + 1 + n, by omega⟩
-
-/-- Helper for heapRemoveLast -/
+/-- Helper for heapRemoveLastAux -/
private theorem CompleteTree.removeRightRightNotEmpty {n m : Nat} (m_gt_0_or_rightIsFull : m > 0 ∨ ((m+1).nextPowerOfTwo = m+1 : Bool)) (h₁ : 0 ≠ n + m) (h₂ : ¬(m < n ∧ ((m+1).nextPowerOfTwo = m+1 : Bool))) : m > 0 :=
match m_gt_0_or_rightIsFull with
| Or.inl h => h
@@ -444,7 +413,7 @@ private theorem CompleteTree.removeRightRightNotEmpty {n m : Nat} (m_gt_0_or_rig
. exact absurd h₂ $ Nat.not_succ_le_zero q
. exact Nat.succ_pos _
-/-- Helper for heapRemoveLast -/
+/-- Helper for heapRemoveLastAux -/
private theorem CompleteTree.removeRightLeftIsFull {n m : Nat} (r : ¬(m < n ∧ ((m+1).nextPowerOfTwo = m+1 : Bool))) (m_le_n : m ≤ n) (subtree_complete : (n + 1).isPowerOfTwo ∨ (m + 1).isPowerOfTwo) : (n+1).isPowerOfTwo := by
rewrite[Decidable.not_and_iff_or_not] at r
cases r
@@ -457,7 +426,7 @@ private theorem CompleteTree.removeRightLeftIsFull {n m : Nat} (r : ¬(m < n ∧
simp[h₁] at subtree_complete
assumption
-/-- Helper for heapRemoveLast-/
+/-- Helper for heapRemoveLastAux -/
private theorem CompleteTree.stillInRange {n m : Nat} (r : ¬(m < n ∧ ((m+1).nextPowerOfTwo = m+1 : Bool))) (m_le_n : m ≤ n) (m_gt_0 : m > 0) (leftIsFull : (n+1).isPowerOfTwo) (max_height_difference: n < 2 * (m + 1)) : n < 2*m := by
rewrite[Decidable.not_and_iff_or_not] at r
cases r with
@@ -467,11 +436,20 @@ private theorem CompleteTree.stillInRange {n m : Nat} (r : ¬(m < n ∧ ((m+1).n
| inr h₁ => simp (config := { zetaDelta := true }) only [← Nat.power_of_two_iff_next_power_eq, decide_eq_true_eq] at h₁
apply power_of_two_mul_two_le <;> assumption
-private def CompleteTree.heapRemoveLast {α : Type u} (heap : CompleteTree α (o+1)) : (CompleteTree α o × α) :=
+private def CompleteTree.heapRemoveLastAux
+{α : Type u}
+{β : Nat → Type u}
+{o : Nat}
+(heap : CompleteTree α (o+1))
+(aux0 : α → (β 1))
+(auxl : {prev_size curr_size : Nat} → β prev_size → (h₁ : prev_size < curr_size) → β curr_size)
+(auxr : {prev_size curr_size : Nat} → β prev_size → (left_size : Nat) → (h₁ : prev_size + left_size < curr_size) → β curr_size)
+: (CompleteTree α o × (β (o+1)))
+:=
match o, heap with
| (n+m), .branch a (left : CompleteTree α n) (right : CompleteTree α m) m_le_n max_height_difference subtree_complete =>
if p : 0 = (n+m) then
- (p▸CompleteTree.leaf, a)
+ (p▸CompleteTree.leaf, p▸aux0 a)
else
let rightIsFull : Bool := (m+1).nextPowerOfTwo = m+1
have m_gt_0_or_rightIsFull : m > 0 ∨ rightIsFull := by cases m <;> simp (config := { ground:=true })[rightIsFull]
@@ -479,22 +457,33 @@ private def CompleteTree.heapRemoveLast {α : Type u} (heap : CompleteTree α (o
--remove left
match n, left with
| (l+1), left =>
- let ((newLeft : CompleteTree α l), res) := left.heapRemoveLast
+ let ((newLeft : CompleteTree α l), res) := left.heapRemoveLastAux aux0 auxl auxr
have q : l + m + 1 = l + 1 + m := Nat.add_right_comm l m 1
have s : m ≤ l := Nat.le_of_lt_succ r.left
have rightIsFull : (m+1).isPowerOfTwo := (Nat.power_of_two_iff_next_power_eq (m+1)).mpr $ decide_eq_true_eq.mp r.right
have l_lt_2_m_succ : l < 2 * (m+1) := Nat.lt_of_succ_lt max_height_difference
+ let res := auxl res (by simp_arith)
(q▸CompleteTree.branch a newLeft right s l_lt_2_m_succ (Or.inr rightIsFull), res)
else
--remove right
have m_gt_0 : m > 0 := removeRightRightNotEmpty m_gt_0_or_rightIsFull p r
let l := m.pred
have h₂ : l.succ = m := (Nat.succ_pred $ Nat.not_eq_zero_of_lt (Nat.gt_of_not_le $ Nat.not_le_of_gt m_gt_0))
- let ((newRight : CompleteTree α l), res) := (h₂.symm▸right).heapRemoveLast
+ let ((newRight : CompleteTree α l), res) := (h₂.symm▸right).heapRemoveLastAux aux0 auxl auxr
have leftIsFull : (n+1).isPowerOfTwo := removeRightLeftIsFull r m_le_n subtree_complete
have still_in_range : n < 2 * (l+1) := h₂.substr (p := λx ↦ n < 2 * x) $ stillInRange r m_le_n m_gt_0 leftIsFull max_height_difference
+ let res := auxr res n (by omega)
(h₂▸CompleteTree.branch a left newRight (Nat.le_of_succ_le (h₂▸m_le_n)) still_in_range (Or.inl leftIsFull), res)
+private def CompleteTree.heapRemoveLast {α : Type u} {o : Nat} (heap : CompleteTree α (o+1)) : (CompleteTree α o × α) :=
+ heap.heapRemoveLastAux id (λ(a : α) _ ↦ a) (λa _ _ ↦ a)
+
+private def CompleteTree.heapRemoveLastWithIndex {α : Type u} {o : Nat} (heap : CompleteTree α (o+1)) : (CompleteTree α o × α × Fin (o+1)) :=
+ heap.heapRemoveLastAux (β := λn ↦ α × Fin n)
+ (λ(a : α) ↦ (a, Fin.mk 0 (Nat.succ_pos 0)))
+ (λ(a, prev_idx) h₁ ↦ (a, prev_idx.succ.castLE $ Nat.succ_le_of_lt h₁) )
+ (λ(a, prev_idx) left_size h₁ ↦ (a, (prev_idx.addNat left_size).succ.castLE $ Nat.succ_le_of_lt h₁))
+
private theorem CompleteTree.castZeroHeap (n m : Nat) (heap : CompleteTree α 0) (h₁ : 0=n+m) {le : α → α → Bool} : HeapPredicate (h₁ ▸ heap) le := by
have h₂ : heap = (CompleteTree.empty : CompleteTree α 0) := by
simp[empty]
@@ -530,14 +519,29 @@ private theorem HeapPredicate.seesThroughCast2
assumption
-- If there is only one element left, the result is a leaf.
-private theorem CompleteTree.heapRemoveLastLeaf (heap : CompleteTree α 1) : heap.heapRemoveLast.fst = CompleteTree.leaf := by
- let l := heap.heapRemoveLast.fst
+private theorem CompleteTree.heapRemoveLastAuxLeaf
+{α : Type u}
+{β : Nat → Type u}
+(heap : CompleteTree α 1)
+(aux0 : α → (β 1))
+(auxl : {prev_size curr_size : Nat} → β prev_size → (h₁ : prev_size < curr_size) → β curr_size)
+(auxr : {prev_size curr_size : Nat} → β prev_size → (left_size : Nat) → (h₁ : prev_size + left_size < curr_size) → β curr_size)
+: (heap.heapRemoveLastAux aux0 auxl auxr).fst = CompleteTree.leaf := by
+ let l := (heap.heapRemoveLastAux aux0 auxl auxr).fst
have h₁ : l = CompleteTree.leaf := match l with
| .leaf => rfl
exact h₁
-private theorem CompleteTree.heapRemoveLastLeavesRoot (heap : CompleteTree α (n+1)) (h₁ : n > 0) : heap.root (Nat.zero_lt_of_ne_zero $ Nat.succ_ne_zero n) = heap.heapRemoveLast.fst.root (h₁) := by
- unfold heapRemoveLast
+private theorem CompleteTree.heapRemoveLastAuxLeavesRoot
+{α : Type u}
+{β : Nat → Type u}
+(heap : CompleteTree α (n+1))
+(aux0 : α → (β 1))
+(auxl : {prev_size curr_size : Nat} → β prev_size → (h₁ : prev_size < curr_size) → β curr_size)
+(auxr : {prev_size curr_size : Nat} → β prev_size → (left_size : Nat) → (h₁ : prev_size + left_size < curr_size) → β curr_size)
+(h₁ : n > 0)
+: heap.root (Nat.zero_lt_of_ne_zero $ Nat.succ_ne_zero n) = (heap.heapRemoveLastAux aux0 auxl auxr).fst.root (h₁) := by
+ unfold heapRemoveLastAux
split
rename_i o p v l r _ _ _
have h₃ : (0 ≠ o + p) := Ne.symm $ Nat.not_eq_zero_of_lt h₁
@@ -561,8 +565,16 @@ private theorem CompleteTree.heapRemoveLastLeavesRoot (heap : CompleteTree α (n
simp_arith
apply root_unfold
-private theorem CompleteTree.heapRemoveLastIsHeap {heap : CompleteTree α (o+1)} {le : α → α → Bool} (h₁ : HeapPredicate heap le) (h₂ : transitive_le le) (h₃ : total_le le) : HeapPredicate (heap.heapRemoveLast.fst) le := by
- unfold heapRemoveLast
+private theorem CompleteTree.heapRemoveLastAuxIsHeap
+{α : Type u}
+{β : Nat → Type u}
+{heap : CompleteTree α (o+1)}
+{le : α → α → Bool}
+(aux0 : α → (β 1))
+(auxl : {prev_size curr_size : Nat} → β prev_size → (h₁ : prev_size < curr_size) → β curr_size)
+(auxr : {prev_size curr_size : Nat} → β prev_size → (left_size : Nat) → (h₁ : prev_size + left_size < curr_size) → β curr_size)
+(h₁ : HeapPredicate heap le) (h₂ : transitive_le le) (h₃ : total_le le) : HeapPredicate ((heap.heapRemoveLastAux aux0 auxl auxr).fst) le := by
+ unfold heapRemoveLastAux
split
rename_i n m v l r _ _ _
exact
@@ -581,7 +593,7 @@ private theorem CompleteTree.heapRemoveLastIsHeap {heap : CompleteTree α (o+1)}
apply HeapPredicate.seesThroughCast2 <;> try simp_arith
cases ll
case a.zero => -- if ll is zero, then (heapRemoveLast l).snd is a leaf.
- have h₆ : l.heapRemoveLast.fst = .leaf := heapRemoveLastLeaf l
+ have h₆ := heapRemoveLastAuxLeaf l aux0 auxl auxr
rw[h₆]
unfold HeapPredicate at *
have h₇ : HeapPredicate .leaf le := by trivial
@@ -589,10 +601,10 @@ private theorem CompleteTree.heapRemoveLastIsHeap {heap : CompleteTree α (o+1)}
exact ⟨h₇,h₁.right.left, h₈, h₁.right.right.right⟩
case a.succ nn => -- if ll is not zero, then the root element before and after heapRemoveLast is the same.
unfold HeapPredicate at *
- simp[h₁.right.left, h₁.right.right.right, heapRemoveLastIsHeap h₁.left h₂ h₃]
+ simp[h₁.right.left, h₁.right.right.right, heapRemoveLastAuxIsHeap aux0 auxl auxr h₁.left h₂ h₃]
unfold HeapPredicate.leOrLeaf
simp
- rw[←heapRemoveLastLeavesRoot]
+ rw[←heapRemoveLastAuxLeavesRoot]
exact h₁.right.right.left
else by
simp[h₅]
@@ -603,16 +615,22 @@ private theorem CompleteTree.heapRemoveLastIsHeap {heap : CompleteTree α (o+1)}
case succ mm h₆ h₇ h₈ =>
simp
unfold HeapPredicate at *
- simp[h₁, heapRemoveLastIsHeap h₁.right.left h₂ h₃]
+ simp[h₁, heapRemoveLastAuxIsHeap aux0 auxl auxr h₁.right.left h₂ h₃]
unfold HeapPredicate.leOrLeaf
exact match mm with
| 0 => rfl
| o+1 =>
- have h₉ : le v ((r.heapRemoveLast).fst.root (Nat.zero_lt_succ o)) := by
- rw[←heapRemoveLastLeavesRoot]
+ have h₉ : le v ((r.heapRemoveLastAux _ _ _).fst.root (Nat.zero_lt_succ o)) := by
+ rw[←heapRemoveLastAuxLeavesRoot]
exact h₁.right.right.right
h₉
+private theorem CompleteTree.heapRemoveLastIsHeap {α : Type u} {heap : CompleteTree α (o+1)} {le : α → α → Bool} (h₁ : HeapPredicate heap le) (h₂ : transitive_le le) (h₃ : total_le le) : HeapPredicate (heap.heapRemoveLast.fst) le :=
+ heapRemoveLastAuxIsHeap _ _ _ h₁ h₂ h₃
+
+private theorem CompleteTree.heapRemoveLastWithIndexIsHeap {α : Type u} {heap : CompleteTree α (o+1)} {le : α → α → Bool} (h₁ : HeapPredicate heap le) (h₂ : transitive_le le) (h₃ : total_le le) : HeapPredicate (heap.heapRemoveLastWithIndex.fst) le :=
+ heapRemoveLastAuxIsHeap _ _ _ h₁ h₂ h₃
+
private def BinaryHeap.heapRemoveLast {α : Type u} {le : α → α → Bool} {n : Nat} : (BinaryHeap α le (n+1)) → BinaryHeap α le n × α
| {tree, valid, wellDefinedLe} =>
let result := tree.heapRemoveLast
@@ -1063,30 +1081,29 @@ def CompleteTree.heapRemoveAt {α : Type u} {n : Nat} (le : α → α → Bool)
if index_ne_zero : index = 0 then
heapPop le heap
else
- let lastIndex := heap.indexOfLast
- let l := heap.heapRemoveLast
- if p : index = lastIndex then
- l
+ let (remaining_tree, removed_element, removed_index) := heap.heapRemoveLastWithIndex
+ if p : index = removed_index then
+ (remaining_tree, removed_element)
else
have n_gt_zero : n > 0 := by
cases n
case succ nn => exact Nat.zero_lt_succ nn
case zero => omega
- if index_lt_lastIndex : index ≥ lastIndex then
+ if index_lt_lastIndex : index ≥ removed_index then
let index := index.pred index_ne_zero
- heapUpdateAt le index l.snd l.fst n_gt_zero
+ heapUpdateAt le index removed_element remaining_tree n_gt_zero
else
let h₁ : index < n := by omega
let index : Fin n := ⟨index, h₁⟩
- heapUpdateAt le index l.snd l.fst n_gt_zero
+ heapUpdateAt le index removed_element remaining_tree n_gt_zero
theorem CompleteTree.heapRemoveAtIsHeap {α : Type u} {n : Nat} (le : α → α → Bool) (index : Fin (n+1)) (heap : CompleteTree α (n+1)) (h₁ : HeapPredicate heap le) (wellDefinedLe : transitive_le le ∧ total_le le) : HeapPredicate (heap.heapRemoveAt le index).fst le := by
- have h₂ : HeapPredicate heap.heapRemoveLast.fst le := heapRemoveLastIsHeap h₁ wellDefinedLe.left wellDefinedLe.right
+ have h₂ : HeapPredicate heap.heapRemoveLastWithIndex.fst le := heapRemoveLastWithIndexIsHeap h₁ wellDefinedLe.left wellDefinedLe.right
unfold heapRemoveAt
split
case isTrue => exact heapPopIsHeap le heap h₁ wellDefinedLe
case isFalse h₃ =>
- cases h: (index = heap.indexOfLast : Bool)
+ cases h: (index = heap.heapRemoveLastWithIndex.snd.snd : Bool)
<;> simp_all
split
<;> apply heapUpdateAtIsHeap <;> simp_all
@@ -1120,7 +1137,7 @@ private def TestHeap :=
|> ins 3
#eval TestHeap
-#eval TestHeap.heapRemoveLast
+#eval TestHeap.heapRemoveLastWithIndex
#eval TestHeap.indexOf (13 = ·)
#eval TestHeap.heapRemoveAt (.≤.) 7