summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndreas Grois <andi@grois.info>2024-07-20 00:10:13 +0200
committerAndreas Grois <andi@grois.info>2024-07-20 00:10:13 +0200
commit8659efb6bf0f3e21f0ab8d78e657739bc2238142 (patch)
tree8643472ed6fc99af1b326fdd3e3d213952cab671
parentb22f0f04c6fdf378a4c586cab657975f7d49f992 (diff)
Remove CompleteTree.indexOfLast, add CompleteTree.heapRemoveLastWithIndex
indexOfLast was going through the tree the same way as heapRemoveLast did, so heapRemoveLast now can optionally compute the index.
-rw-r--r--Common/BinaryHeap.lean137
1 files changed, 77 insertions, 60 deletions
diff --git a/Common/BinaryHeap.lean b/Common/BinaryHeap.lean
index 4cbc802..88afeda 100644
--- a/Common/BinaryHeap.lean
+++ b/Common/BinaryHeap.lean
@@ -400,38 +400,7 @@ def CompleteTree.get {α : Type u} {n : Nat} (index : Fin (n+1)) (heap : Complet
match p with
| (pp + 1) => get ⟨j - o, h₆⟩ r
---TODO: Make this use numbers instead of traversing the actual tree.
-def CompleteTree.indexOfLast {α : Type u} (heap : CompleteTree α (o+1)) : Fin (o+1) :=
- match o, heap with
- | (n+m), .branch _ l r _ _ _ =>
- if p : 0 = (n+m) then
- 0
- else
- let rightIsFull : Bool := (m+1).nextPowerOfTwo = m+1
- have m_gt_0_or_rightIsFull : m > 0 ∨ rightIsFull := by cases m <;> simp_arith (config := { ground:=true })[rightIsFull]
- if h₁ : m < n ∧ rightIsFull then
- match n with
- | .zero => absurd (Nat.zero_lt_of_lt h₁.left) (Nat.lt_irrefl 0)
- | .succ nn => (l.indexOfLast.succ.castAdd m).cast (by simp_arith)
- else
- have : m > 0 := by
- cases m_gt_0_or_rightIsFull
- case inl => assumption
- case inr h =>
- simp_arith [h] at h₁
- cases n
- case zero =>
- simp[Nat.zero_lt_of_ne_zero] at p
- exact Nat.zero_lt_of_ne_zero (Ne.symm p)
- case succ q _ _ _ =>
- cases m
- . exact False.elim $ Nat.not_succ_le_zero q h₁
- . simp_arith
- match m with
- | .succ mm =>
- ⟨r.indexOfLast.val + 1 + n, by omega⟩
-
-/-- Helper for heapRemoveLast -/
+/-- Helper for heapRemoveLastAux -/
private theorem CompleteTree.removeRightRightNotEmpty {n m : Nat} (m_gt_0_or_rightIsFull : m > 0 ∨ ((m+1).nextPowerOfTwo = m+1 : Bool)) (h₁ : 0 ≠ n + m) (h₂ : ¬(m < n ∧ ((m+1).nextPowerOfTwo = m+1 : Bool))) : m > 0 :=
match m_gt_0_or_rightIsFull with
| Or.inl h => h
@@ -444,7 +413,7 @@ private theorem CompleteTree.removeRightRightNotEmpty {n m : Nat} (m_gt_0_or_rig
. exact absurd h₂ $ Nat.not_succ_le_zero q
. exact Nat.succ_pos _
-/-- Helper for heapRemoveLast -/
+/-- Helper for heapRemoveLastAux -/
private theorem CompleteTree.removeRightLeftIsFull {n m : Nat} (r : ¬(m < n ∧ ((m+1).nextPowerOfTwo = m+1 : Bool))) (m_le_n : m ≤ n) (subtree_complete : (n + 1).isPowerOfTwo ∨ (m + 1).isPowerOfTwo) : (n+1).isPowerOfTwo := by
rewrite[Decidable.not_and_iff_or_not] at r
cases r
@@ -457,7 +426,7 @@ private theorem CompleteTree.removeRightLeftIsFull {n m : Nat} (r : ¬(m < n ∧
simp[h₁] at subtree_complete
assumption
-/-- Helper for heapRemoveLast-/
+/-- Helper for heapRemoveLastAux -/
private theorem CompleteTree.stillInRange {n m : Nat} (r : ¬(m < n ∧ ((m+1).nextPowerOfTwo = m+1 : Bool))) (m_le_n : m ≤ n) (m_gt_0 : m > 0) (leftIsFull : (n+1).isPowerOfTwo) (max_height_difference: n < 2 * (m + 1)) : n < 2*m := by
rewrite[Decidable.not_and_iff_or_not] at r
cases r with
@@ -467,11 +436,20 @@ private theorem CompleteTree.stillInRange {n m : Nat} (r : ¬(m < n ∧ ((m+1).n
| inr h₁ => simp (config := { zetaDelta := true }) only [← Nat.power_of_two_iff_next_power_eq, decide_eq_true_eq] at h₁
apply power_of_two_mul_two_le <;> assumption
-private def CompleteTree.heapRemoveLast {α : Type u} (heap : CompleteTree α (o+1)) : (CompleteTree α o × α) :=
+private def CompleteTree.heapRemoveLastAux
+{α : Type u}
+{β : Nat → Type u}
+{o : Nat}
+(heap : CompleteTree α (o+1))
+(aux0 : α → (β 1))
+(auxl : {prev_size curr_size : Nat} → β prev_size → (h₁ : prev_size < curr_size) → β curr_size)
+(auxr : {prev_size curr_size : Nat} → β prev_size → (left_size : Nat) → (h₁ : prev_size + left_size < curr_size) → β curr_size)
+: (CompleteTree α o × (β (o+1)))
+:=
match o, heap with
| (n+m), .branch a (left : CompleteTree α n) (right : CompleteTree α m) m_le_n max_height_difference subtree_complete =>
if p : 0 = (n+m) then
- (p▸CompleteTree.leaf, a)
+ (p▸CompleteTree.leaf, p▸aux0 a)
else
let rightIsFull : Bool := (m+1).nextPowerOfTwo = m+1
have m_gt_0_or_rightIsFull : m > 0 ∨ rightIsFull := by cases m <;> simp (config := { ground:=true })[rightIsFull]
@@ -479,22 +457,33 @@ private def CompleteTree.heapRemoveLast {α : Type u} (heap : CompleteTree α (o
--remove left
match n, left with
| (l+1), left =>
- let ((newLeft : CompleteTree α l), res) := left.heapRemoveLast
+ let ((newLeft : CompleteTree α l), res) := left.heapRemoveLastAux aux0 auxl auxr
have q : l + m + 1 = l + 1 + m := Nat.add_right_comm l m 1
have s : m ≤ l := Nat.le_of_lt_succ r.left
have rightIsFull : (m+1).isPowerOfTwo := (Nat.power_of_two_iff_next_power_eq (m+1)).mpr $ decide_eq_true_eq.mp r.right
have l_lt_2_m_succ : l < 2 * (m+1) := Nat.lt_of_succ_lt max_height_difference
+ let res := auxl res (by simp_arith)
(q▸CompleteTree.branch a newLeft right s l_lt_2_m_succ (Or.inr rightIsFull), res)
else
--remove right
have m_gt_0 : m > 0 := removeRightRightNotEmpty m_gt_0_or_rightIsFull p r
let l := m.pred
have h₂ : l.succ = m := (Nat.succ_pred $ Nat.not_eq_zero_of_lt (Nat.gt_of_not_le $ Nat.not_le_of_gt m_gt_0))
- let ((newRight : CompleteTree α l), res) := (h₂.symm▸right).heapRemoveLast
+ let ((newRight : CompleteTree α l), res) := (h₂.symm▸right).heapRemoveLastAux aux0 auxl auxr
have leftIsFull : (n+1).isPowerOfTwo := removeRightLeftIsFull r m_le_n subtree_complete
have still_in_range : n < 2 * (l+1) := h₂.substr (p := λx ↦ n < 2 * x) $ stillInRange r m_le_n m_gt_0 leftIsFull max_height_difference
+ let res := auxr res n (by omega)
(h₂▸CompleteTree.branch a left newRight (Nat.le_of_succ_le (h₂▸m_le_n)) still_in_range (Or.inl leftIsFull), res)
+private def CompleteTree.heapRemoveLast {α : Type u} {o : Nat} (heap : CompleteTree α (o+1)) : (CompleteTree α o × α) :=
+ heap.heapRemoveLastAux id (λ(a : α) _ ↦ a) (λa _ _ ↦ a)
+
+private def CompleteTree.heapRemoveLastWithIndex {α : Type u} {o : Nat} (heap : CompleteTree α (o+1)) : (CompleteTree α o × α × Fin (o+1)) :=
+ heap.heapRemoveLastAux (β := λn ↦ α × Fin n)
+ (λ(a : α) ↦ (a, Fin.mk 0 (Nat.succ_pos 0)))
+ (λ(a, prev_idx) h₁ ↦ (a, prev_idx.succ.castLE $ Nat.succ_le_of_lt h₁) )
+ (λ(a, prev_idx) left_size h₁ ↦ (a, (prev_idx.addNat left_size).succ.castLE $ Nat.succ_le_of_lt h₁))
+
private theorem CompleteTree.castZeroHeap (n m : Nat) (heap : CompleteTree α 0) (h₁ : 0=n+m) {le : α → α → Bool} : HeapPredicate (h₁ ▸ heap) le := by
have h₂ : heap = (CompleteTree.empty : CompleteTree α 0) := by
simp[empty]
@@ -530,14 +519,29 @@ private theorem HeapPredicate.seesThroughCast2
assumption
-- If there is only one element left, the result is a leaf.
-private theorem CompleteTree.heapRemoveLastLeaf (heap : CompleteTree α 1) : heap.heapRemoveLast.fst = CompleteTree.leaf := by
- let l := heap.heapRemoveLast.fst
+private theorem CompleteTree.heapRemoveLastAuxLeaf
+{α : Type u}
+{β : Nat → Type u}
+(heap : CompleteTree α 1)
+(aux0 : α → (β 1))
+(auxl : {prev_size curr_size : Nat} → β prev_size → (h₁ : prev_size < curr_size) → β curr_size)
+(auxr : {prev_size curr_size : Nat} → β prev_size → (left_size : Nat) → (h₁ : prev_size + left_size < curr_size) → β curr_size)
+: (heap.heapRemoveLastAux aux0 auxl auxr).fst = CompleteTree.leaf := by
+ let l := (heap.heapRemoveLastAux aux0 auxl auxr).fst
have h₁ : l = CompleteTree.leaf := match l with
| .leaf => rfl
exact h₁
-private theorem CompleteTree.heapRemoveLastLeavesRoot (heap : CompleteTree α (n+1)) (h₁ : n > 0) : heap.root (Nat.zero_lt_of_ne_zero $ Nat.succ_ne_zero n) = heap.heapRemoveLast.fst.root (h₁) := by
- unfold heapRemoveLast
+private theorem CompleteTree.heapRemoveLastAuxLeavesRoot
+{α : Type u}
+{β : Nat → Type u}
+(heap : CompleteTree α (n+1))
+(aux0 : α → (β 1))
+(auxl : {prev_size curr_size : Nat} → β prev_size → (h₁ : prev_size < curr_size) → β curr_size)
+(auxr : {prev_size curr_size : Nat} → β prev_size → (left_size : Nat) → (h₁ : prev_size + left_size < curr_size) → β curr_size)
+(h₁ : n > 0)
+: heap.root (Nat.zero_lt_of_ne_zero $ Nat.succ_ne_zero n) = (heap.heapRemoveLastAux aux0 auxl auxr).fst.root (h₁) := by
+ unfold heapRemoveLastAux
split
rename_i o p v l r _ _ _
have h₃ : (0 ≠ o + p) := Ne.symm $ Nat.not_eq_zero_of_lt h₁
@@ -561,8 +565,16 @@ private theorem CompleteTree.heapRemoveLastLeavesRoot (heap : CompleteTree α (n
simp_arith
apply root_unfold
-private theorem CompleteTree.heapRemoveLastIsHeap {heap : CompleteTree α (o+1)} {le : α → α → Bool} (h₁ : HeapPredicate heap le) (h₂ : transitive_le le) (h₃ : total_le le) : HeapPredicate (heap.heapRemoveLast.fst) le := by
- unfold heapRemoveLast
+private theorem CompleteTree.heapRemoveLastAuxIsHeap
+{α : Type u}
+{β : Nat → Type u}
+{heap : CompleteTree α (o+1)}
+{le : α → α → Bool}
+(aux0 : α → (β 1))
+(auxl : {prev_size curr_size : Nat} → β prev_size → (h₁ : prev_size < curr_size) → β curr_size)
+(auxr : {prev_size curr_size : Nat} → β prev_size → (left_size : Nat) → (h₁ : prev_size + left_size < curr_size) → β curr_size)
+(h₁ : HeapPredicate heap le) (h₂ : transitive_le le) (h₃ : total_le le) : HeapPredicate ((heap.heapRemoveLastAux aux0 auxl auxr).fst) le := by
+ unfold heapRemoveLastAux
split
rename_i n m v l r _ _ _
exact
@@ -581,7 +593,7 @@ private theorem CompleteTree.heapRemoveLastIsHeap {heap : CompleteTree α (o+1)}
apply HeapPredicate.seesThroughCast2 <;> try simp_arith
cases ll
case a.zero => -- if ll is zero, then (heapRemoveLast l).snd is a leaf.
- have h₆ : l.heapRemoveLast.fst = .leaf := heapRemoveLastLeaf l
+ have h₆ := heapRemoveLastAuxLeaf l aux0 auxl auxr
rw[h₆]
unfold HeapPredicate at *
have h₇ : HeapPredicate .leaf le := by trivial
@@ -589,10 +601,10 @@ private theorem CompleteTree.heapRemoveLastIsHeap {heap : CompleteTree α (o+1)}
exact ⟨h₇,h₁.right.left, h₈, h₁.right.right.right⟩
case a.succ nn => -- if ll is not zero, then the root element before and after heapRemoveLast is the same.
unfold HeapPredicate at *
- simp[h₁.right.left, h₁.right.right.right, heapRemoveLastIsHeap h₁.left h₂ h₃]
+ simp[h₁.right.left, h₁.right.right.right, heapRemoveLastAuxIsHeap aux0 auxl auxr h₁.left h₂ h₃]
unfold HeapPredicate.leOrLeaf
simp
- rw[←heapRemoveLastLeavesRoot]
+ rw[←heapRemoveLastAuxLeavesRoot]
exact h₁.right.right.left
else by
simp[h₅]
@@ -603,16 +615,22 @@ private theorem CompleteTree.heapRemoveLastIsHeap {heap : CompleteTree α (o+1)}
case succ mm h₆ h₇ h₈ =>
simp
unfold HeapPredicate at *
- simp[h₁, heapRemoveLastIsHeap h₁.right.left h₂ h₃]
+ simp[h₁, heapRemoveLastAuxIsHeap aux0 auxl auxr h₁.right.left h₂ h₃]
unfold HeapPredicate.leOrLeaf
exact match mm with
| 0 => rfl
| o+1 =>
- have h₉ : le v ((r.heapRemoveLast).fst.root (Nat.zero_lt_succ o)) := by
- rw[←heapRemoveLastLeavesRoot]
+ have h₉ : le v ((r.heapRemoveLastAux _ _ _).fst.root (Nat.zero_lt_succ o)) := by
+ rw[←heapRemoveLastAuxLeavesRoot]
exact h₁.right.right.right
h₉
+private theorem CompleteTree.heapRemoveLastIsHeap {α : Type u} {heap : CompleteTree α (o+1)} {le : α → α → Bool} (h₁ : HeapPredicate heap le) (h₂ : transitive_le le) (h₃ : total_le le) : HeapPredicate (heap.heapRemoveLast.fst) le :=
+ heapRemoveLastAuxIsHeap _ _ _ h₁ h₂ h₃
+
+private theorem CompleteTree.heapRemoveLastWithIndexIsHeap {α : Type u} {heap : CompleteTree α (o+1)} {le : α → α → Bool} (h₁ : HeapPredicate heap le) (h₂ : transitive_le le) (h₃ : total_le le) : HeapPredicate (heap.heapRemoveLastWithIndex.fst) le :=
+ heapRemoveLastAuxIsHeap _ _ _ h₁ h₂ h₃
+
private def BinaryHeap.heapRemoveLast {α : Type u} {le : α → α → Bool} {n : Nat} : (BinaryHeap α le (n+1)) → BinaryHeap α le n × α
| {tree, valid, wellDefinedLe} =>
let result := tree.heapRemoveLast
@@ -1063,30 +1081,29 @@ def CompleteTree.heapRemoveAt {α : Type u} {n : Nat} (le : α → α → Bool)
if index_ne_zero : index = 0 then
heapPop le heap
else
- let lastIndex := heap.indexOfLast
- let l := heap.heapRemoveLast
- if p : index = lastIndex then
- l
+ let (remaining_tree, removed_element, removed_index) := heap.heapRemoveLastWithIndex
+ if p : index = removed_index then
+ (remaining_tree, removed_element)
else
have n_gt_zero : n > 0 := by
cases n
case succ nn => exact Nat.zero_lt_succ nn
case zero => omega
- if index_lt_lastIndex : index ≥ lastIndex then
+ if index_lt_lastIndex : index ≥ removed_index then
let index := index.pred index_ne_zero
- heapUpdateAt le index l.snd l.fst n_gt_zero
+ heapUpdateAt le index removed_element remaining_tree n_gt_zero
else
let h₁ : index < n := by omega
let index : Fin n := ⟨index, h₁⟩
- heapUpdateAt le index l.snd l.fst n_gt_zero
+ heapUpdateAt le index removed_element remaining_tree n_gt_zero
theorem CompleteTree.heapRemoveAtIsHeap {α : Type u} {n : Nat} (le : α → α → Bool) (index : Fin (n+1)) (heap : CompleteTree α (n+1)) (h₁ : HeapPredicate heap le) (wellDefinedLe : transitive_le le ∧ total_le le) : HeapPredicate (heap.heapRemoveAt le index).fst le := by
- have h₂ : HeapPredicate heap.heapRemoveLast.fst le := heapRemoveLastIsHeap h₁ wellDefinedLe.left wellDefinedLe.right
+ have h₂ : HeapPredicate heap.heapRemoveLastWithIndex.fst le := heapRemoveLastWithIndexIsHeap h₁ wellDefinedLe.left wellDefinedLe.right
unfold heapRemoveAt
split
case isTrue => exact heapPopIsHeap le heap h₁ wellDefinedLe
case isFalse h₃ =>
- cases h: (index = heap.indexOfLast : Bool)
+ cases h: (index = heap.heapRemoveLastWithIndex.snd.snd : Bool)
<;> simp_all
split
<;> apply heapUpdateAtIsHeap <;> simp_all
@@ -1120,7 +1137,7 @@ private def TestHeap :=
|> ins 3
#eval TestHeap
-#eval TestHeap.heapRemoveLast
+#eval TestHeap.heapRemoveLastWithIndex
#eval TestHeap.indexOf (13 = ·)
#eval TestHeap.heapRemoveAt (.≤.) 7