From 8659efb6bf0f3e21f0ab8d78e657739bc2238142 Mon Sep 17 00:00:00 2001 From: Andreas Grois Date: Sat, 20 Jul 2024 00:10:13 +0200 Subject: Remove CompleteTree.indexOfLast, add CompleteTree.heapRemoveLastWithIndex indexOfLast was going through the tree the same way as heapRemoveLast did, so heapRemoveLast now can optionally compute the index. --- Common/BinaryHeap.lean | 137 +++++++++++++++++++++++++++---------------------- 1 file changed, 77 insertions(+), 60 deletions(-) diff --git a/Common/BinaryHeap.lean b/Common/BinaryHeap.lean index 4cbc802..88afeda 100644 --- a/Common/BinaryHeap.lean +++ b/Common/BinaryHeap.lean @@ -400,38 +400,7 @@ def CompleteTree.get {α : Type u} {n : Nat} (index : Fin (n+1)) (heap : Complet match p with | (pp + 1) => get ⟨j - o, h₆⟩ r ---TODO: Make this use numbers instead of traversing the actual tree. -def CompleteTree.indexOfLast {α : Type u} (heap : CompleteTree α (o+1)) : Fin (o+1) := - match o, heap with - | (n+m), .branch _ l r _ _ _ => - if p : 0 = (n+m) then - 0 - else - let rightIsFull : Bool := (m+1).nextPowerOfTwo = m+1 - have m_gt_0_or_rightIsFull : m > 0 ∨ rightIsFull := by cases m <;> simp_arith (config := { ground:=true })[rightIsFull] - if h₁ : m < n ∧ rightIsFull then - match n with - | .zero => absurd (Nat.zero_lt_of_lt h₁.left) (Nat.lt_irrefl 0) - | .succ nn => (l.indexOfLast.succ.castAdd m).cast (by simp_arith) - else - have : m > 0 := by - cases m_gt_0_or_rightIsFull - case inl => assumption - case inr h => - simp_arith [h] at h₁ - cases n - case zero => - simp[Nat.zero_lt_of_ne_zero] at p - exact Nat.zero_lt_of_ne_zero (Ne.symm p) - case succ q _ _ _ => - cases m - . exact False.elim $ Nat.not_succ_le_zero q h₁ - . simp_arith - match m with - | .succ mm => - ⟨r.indexOfLast.val + 1 + n, by omega⟩ - -/-- Helper for heapRemoveLast -/ +/-- Helper for heapRemoveLastAux -/ private theorem CompleteTree.removeRightRightNotEmpty {n m : Nat} (m_gt_0_or_rightIsFull : m > 0 ∨ ((m+1).nextPowerOfTwo = m+1 : Bool)) (h₁ : 0 ≠ n + m) (h₂ : ¬(m < n ∧ ((m+1).nextPowerOfTwo = m+1 : Bool))) : m > 0 := match m_gt_0_or_rightIsFull with | Or.inl h => h @@ -444,7 +413,7 @@ private theorem CompleteTree.removeRightRightNotEmpty {n m : Nat} (m_gt_0_or_rig . exact absurd h₂ $ Nat.not_succ_le_zero q . exact Nat.succ_pos _ -/-- Helper for heapRemoveLast -/ +/-- Helper for heapRemoveLastAux -/ private theorem CompleteTree.removeRightLeftIsFull {n m : Nat} (r : ¬(m < n ∧ ((m+1).nextPowerOfTwo = m+1 : Bool))) (m_le_n : m ≤ n) (subtree_complete : (n + 1).isPowerOfTwo ∨ (m + 1).isPowerOfTwo) : (n+1).isPowerOfTwo := by rewrite[Decidable.not_and_iff_or_not] at r cases r @@ -457,7 +426,7 @@ private theorem CompleteTree.removeRightLeftIsFull {n m : Nat} (r : ¬(m < n ∧ simp[h₁] at subtree_complete assumption -/-- Helper for heapRemoveLast-/ +/-- Helper for heapRemoveLastAux -/ private theorem CompleteTree.stillInRange {n m : Nat} (r : ¬(m < n ∧ ((m+1).nextPowerOfTwo = m+1 : Bool))) (m_le_n : m ≤ n) (m_gt_0 : m > 0) (leftIsFull : (n+1).isPowerOfTwo) (max_height_difference: n < 2 * (m + 1)) : n < 2*m := by rewrite[Decidable.not_and_iff_or_not] at r cases r with @@ -467,11 +436,20 @@ private theorem CompleteTree.stillInRange {n m : Nat} (r : ¬(m < n ∧ ((m+1).n | inr h₁ => simp (config := { zetaDelta := true }) only [← Nat.power_of_two_iff_next_power_eq, decide_eq_true_eq] at h₁ apply power_of_two_mul_two_le <;> assumption -private def CompleteTree.heapRemoveLast {α : Type u} (heap : CompleteTree α (o+1)) : (CompleteTree α o × α) := +private def CompleteTree.heapRemoveLastAux +{α : Type u} +{β : Nat → Type u} +{o : Nat} +(heap : CompleteTree α (o+1)) +(aux0 : α → (β 1)) +(auxl : {prev_size curr_size : Nat} → β prev_size → (h₁ : prev_size < curr_size) → β curr_size) +(auxr : {prev_size curr_size : Nat} → β prev_size → (left_size : Nat) → (h₁ : prev_size + left_size < curr_size) → β curr_size) +: (CompleteTree α o × (β (o+1))) +:= match o, heap with | (n+m), .branch a (left : CompleteTree α n) (right : CompleteTree α m) m_le_n max_height_difference subtree_complete => if p : 0 = (n+m) then - (p▸CompleteTree.leaf, a) + (p▸CompleteTree.leaf, p▸aux0 a) else let rightIsFull : Bool := (m+1).nextPowerOfTwo = m+1 have m_gt_0_or_rightIsFull : m > 0 ∨ rightIsFull := by cases m <;> simp (config := { ground:=true })[rightIsFull] @@ -479,22 +457,33 @@ private def CompleteTree.heapRemoveLast {α : Type u} (heap : CompleteTree α (o --remove left match n, left with | (l+1), left => - let ((newLeft : CompleteTree α l), res) := left.heapRemoveLast + let ((newLeft : CompleteTree α l), res) := left.heapRemoveLastAux aux0 auxl auxr have q : l + m + 1 = l + 1 + m := Nat.add_right_comm l m 1 have s : m ≤ l := Nat.le_of_lt_succ r.left have rightIsFull : (m+1).isPowerOfTwo := (Nat.power_of_two_iff_next_power_eq (m+1)).mpr $ decide_eq_true_eq.mp r.right have l_lt_2_m_succ : l < 2 * (m+1) := Nat.lt_of_succ_lt max_height_difference + let res := auxl res (by simp_arith) (q▸CompleteTree.branch a newLeft right s l_lt_2_m_succ (Or.inr rightIsFull), res) else --remove right have m_gt_0 : m > 0 := removeRightRightNotEmpty m_gt_0_or_rightIsFull p r let l := m.pred have h₂ : l.succ = m := (Nat.succ_pred $ Nat.not_eq_zero_of_lt (Nat.gt_of_not_le $ Nat.not_le_of_gt m_gt_0)) - let ((newRight : CompleteTree α l), res) := (h₂.symm▸right).heapRemoveLast + let ((newRight : CompleteTree α l), res) := (h₂.symm▸right).heapRemoveLastAux aux0 auxl auxr have leftIsFull : (n+1).isPowerOfTwo := removeRightLeftIsFull r m_le_n subtree_complete have still_in_range : n < 2 * (l+1) := h₂.substr (p := λx ↦ n < 2 * x) $ stillInRange r m_le_n m_gt_0 leftIsFull max_height_difference + let res := auxr res n (by omega) (h₂▸CompleteTree.branch a left newRight (Nat.le_of_succ_le (h₂▸m_le_n)) still_in_range (Or.inl leftIsFull), res) +private def CompleteTree.heapRemoveLast {α : Type u} {o : Nat} (heap : CompleteTree α (o+1)) : (CompleteTree α o × α) := + heap.heapRemoveLastAux id (λ(a : α) _ ↦ a) (λa _ _ ↦ a) + +private def CompleteTree.heapRemoveLastWithIndex {α : Type u} {o : Nat} (heap : CompleteTree α (o+1)) : (CompleteTree α o × α × Fin (o+1)) := + heap.heapRemoveLastAux (β := λn ↦ α × Fin n) + (λ(a : α) ↦ (a, Fin.mk 0 (Nat.succ_pos 0))) + (λ(a, prev_idx) h₁ ↦ (a, prev_idx.succ.castLE $ Nat.succ_le_of_lt h₁) ) + (λ(a, prev_idx) left_size h₁ ↦ (a, (prev_idx.addNat left_size).succ.castLE $ Nat.succ_le_of_lt h₁)) + private theorem CompleteTree.castZeroHeap (n m : Nat) (heap : CompleteTree α 0) (h₁ : 0=n+m) {le : α → α → Bool} : HeapPredicate (h₁ ▸ heap) le := by have h₂ : heap = (CompleteTree.empty : CompleteTree α 0) := by simp[empty] @@ -530,14 +519,29 @@ private theorem HeapPredicate.seesThroughCast2 assumption -- If there is only one element left, the result is a leaf. -private theorem CompleteTree.heapRemoveLastLeaf (heap : CompleteTree α 1) : heap.heapRemoveLast.fst = CompleteTree.leaf := by - let l := heap.heapRemoveLast.fst +private theorem CompleteTree.heapRemoveLastAuxLeaf +{α : Type u} +{β : Nat → Type u} +(heap : CompleteTree α 1) +(aux0 : α → (β 1)) +(auxl : {prev_size curr_size : Nat} → β prev_size → (h₁ : prev_size < curr_size) → β curr_size) +(auxr : {prev_size curr_size : Nat} → β prev_size → (left_size : Nat) → (h₁ : prev_size + left_size < curr_size) → β curr_size) +: (heap.heapRemoveLastAux aux0 auxl auxr).fst = CompleteTree.leaf := by + let l := (heap.heapRemoveLastAux aux0 auxl auxr).fst have h₁ : l = CompleteTree.leaf := match l with | .leaf => rfl exact h₁ -private theorem CompleteTree.heapRemoveLastLeavesRoot (heap : CompleteTree α (n+1)) (h₁ : n > 0) : heap.root (Nat.zero_lt_of_ne_zero $ Nat.succ_ne_zero n) = heap.heapRemoveLast.fst.root (h₁) := by - unfold heapRemoveLast +private theorem CompleteTree.heapRemoveLastAuxLeavesRoot +{α : Type u} +{β : Nat → Type u} +(heap : CompleteTree α (n+1)) +(aux0 : α → (β 1)) +(auxl : {prev_size curr_size : Nat} → β prev_size → (h₁ : prev_size < curr_size) → β curr_size) +(auxr : {prev_size curr_size : Nat} → β prev_size → (left_size : Nat) → (h₁ : prev_size + left_size < curr_size) → β curr_size) +(h₁ : n > 0) +: heap.root (Nat.zero_lt_of_ne_zero $ Nat.succ_ne_zero n) = (heap.heapRemoveLastAux aux0 auxl auxr).fst.root (h₁) := by + unfold heapRemoveLastAux split rename_i o p v l r _ _ _ have h₃ : (0 ≠ o + p) := Ne.symm $ Nat.not_eq_zero_of_lt h₁ @@ -561,8 +565,16 @@ private theorem CompleteTree.heapRemoveLastLeavesRoot (heap : CompleteTree α (n simp_arith apply root_unfold -private theorem CompleteTree.heapRemoveLastIsHeap {heap : CompleteTree α (o+1)} {le : α → α → Bool} (h₁ : HeapPredicate heap le) (h₂ : transitive_le le) (h₃ : total_le le) : HeapPredicate (heap.heapRemoveLast.fst) le := by - unfold heapRemoveLast +private theorem CompleteTree.heapRemoveLastAuxIsHeap +{α : Type u} +{β : Nat → Type u} +{heap : CompleteTree α (o+1)} +{le : α → α → Bool} +(aux0 : α → (β 1)) +(auxl : {prev_size curr_size : Nat} → β prev_size → (h₁ : prev_size < curr_size) → β curr_size) +(auxr : {prev_size curr_size : Nat} → β prev_size → (left_size : Nat) → (h₁ : prev_size + left_size < curr_size) → β curr_size) +(h₁ : HeapPredicate heap le) (h₂ : transitive_le le) (h₃ : total_le le) : HeapPredicate ((heap.heapRemoveLastAux aux0 auxl auxr).fst) le := by + unfold heapRemoveLastAux split rename_i n m v l r _ _ _ exact @@ -581,7 +593,7 @@ private theorem CompleteTree.heapRemoveLastIsHeap {heap : CompleteTree α (o+1)} apply HeapPredicate.seesThroughCast2 <;> try simp_arith cases ll case a.zero => -- if ll is zero, then (heapRemoveLast l).snd is a leaf. - have h₆ : l.heapRemoveLast.fst = .leaf := heapRemoveLastLeaf l + have h₆ := heapRemoveLastAuxLeaf l aux0 auxl auxr rw[h₆] unfold HeapPredicate at * have h₇ : HeapPredicate .leaf le := by trivial @@ -589,10 +601,10 @@ private theorem CompleteTree.heapRemoveLastIsHeap {heap : CompleteTree α (o+1)} exact ⟨h₇,h₁.right.left, h₈, h₁.right.right.right⟩ case a.succ nn => -- if ll is not zero, then the root element before and after heapRemoveLast is the same. unfold HeapPredicate at * - simp[h₁.right.left, h₁.right.right.right, heapRemoveLastIsHeap h₁.left h₂ h₃] + simp[h₁.right.left, h₁.right.right.right, heapRemoveLastAuxIsHeap aux0 auxl auxr h₁.left h₂ h₃] unfold HeapPredicate.leOrLeaf simp - rw[←heapRemoveLastLeavesRoot] + rw[←heapRemoveLastAuxLeavesRoot] exact h₁.right.right.left else by simp[h₅] @@ -603,16 +615,22 @@ private theorem CompleteTree.heapRemoveLastIsHeap {heap : CompleteTree α (o+1)} case succ mm h₆ h₇ h₈ => simp unfold HeapPredicate at * - simp[h₁, heapRemoveLastIsHeap h₁.right.left h₂ h₃] + simp[h₁, heapRemoveLastAuxIsHeap aux0 auxl auxr h₁.right.left h₂ h₃] unfold HeapPredicate.leOrLeaf exact match mm with | 0 => rfl | o+1 => - have h₉ : le v ((r.heapRemoveLast).fst.root (Nat.zero_lt_succ o)) := by - rw[←heapRemoveLastLeavesRoot] + have h₉ : le v ((r.heapRemoveLastAux _ _ _).fst.root (Nat.zero_lt_succ o)) := by + rw[←heapRemoveLastAuxLeavesRoot] exact h₁.right.right.right h₉ +private theorem CompleteTree.heapRemoveLastIsHeap {α : Type u} {heap : CompleteTree α (o+1)} {le : α → α → Bool} (h₁ : HeapPredicate heap le) (h₂ : transitive_le le) (h₃ : total_le le) : HeapPredicate (heap.heapRemoveLast.fst) le := + heapRemoveLastAuxIsHeap _ _ _ h₁ h₂ h₃ + +private theorem CompleteTree.heapRemoveLastWithIndexIsHeap {α : Type u} {heap : CompleteTree α (o+1)} {le : α → α → Bool} (h₁ : HeapPredicate heap le) (h₂ : transitive_le le) (h₃ : total_le le) : HeapPredicate (heap.heapRemoveLastWithIndex.fst) le := + heapRemoveLastAuxIsHeap _ _ _ h₁ h₂ h₃ + private def BinaryHeap.heapRemoveLast {α : Type u} {le : α → α → Bool} {n : Nat} : (BinaryHeap α le (n+1)) → BinaryHeap α le n × α | {tree, valid, wellDefinedLe} => let result := tree.heapRemoveLast @@ -1063,30 +1081,29 @@ def CompleteTree.heapRemoveAt {α : Type u} {n : Nat} (le : α → α → Bool) if index_ne_zero : index = 0 then heapPop le heap else - let lastIndex := heap.indexOfLast - let l := heap.heapRemoveLast - if p : index = lastIndex then - l + let (remaining_tree, removed_element, removed_index) := heap.heapRemoveLastWithIndex + if p : index = removed_index then + (remaining_tree, removed_element) else have n_gt_zero : n > 0 := by cases n case succ nn => exact Nat.zero_lt_succ nn case zero => omega - if index_lt_lastIndex : index ≥ lastIndex then + if index_lt_lastIndex : index ≥ removed_index then let index := index.pred index_ne_zero - heapUpdateAt le index l.snd l.fst n_gt_zero + heapUpdateAt le index removed_element remaining_tree n_gt_zero else let h₁ : index < n := by omega let index : Fin n := ⟨index, h₁⟩ - heapUpdateAt le index l.snd l.fst n_gt_zero + heapUpdateAt le index removed_element remaining_tree n_gt_zero theorem CompleteTree.heapRemoveAtIsHeap {α : Type u} {n : Nat} (le : α → α → Bool) (index : Fin (n+1)) (heap : CompleteTree α (n+1)) (h₁ : HeapPredicate heap le) (wellDefinedLe : transitive_le le ∧ total_le le) : HeapPredicate (heap.heapRemoveAt le index).fst le := by - have h₂ : HeapPredicate heap.heapRemoveLast.fst le := heapRemoveLastIsHeap h₁ wellDefinedLe.left wellDefinedLe.right + have h₂ : HeapPredicate heap.heapRemoveLastWithIndex.fst le := heapRemoveLastWithIndexIsHeap h₁ wellDefinedLe.left wellDefinedLe.right unfold heapRemoveAt split case isTrue => exact heapPopIsHeap le heap h₁ wellDefinedLe case isFalse h₃ => - cases h: (index = heap.indexOfLast : Bool) + cases h: (index = heap.heapRemoveLastWithIndex.snd.snd : Bool) <;> simp_all split <;> apply heapUpdateAtIsHeap <;> simp_all @@ -1120,7 +1137,7 @@ private def TestHeap := |> ins 3 #eval TestHeap -#eval TestHeap.heapRemoveLast +#eval TestHeap.heapRemoveLastWithIndex #eval TestHeap.indexOf (13 = ·) #eval TestHeap.heapRemoveAt (.≤.) 7 -- cgit v1.2.3