From ca2a27a4c416a73e17230e36ddbbbf042fbd5b1e Mon Sep 17 00:00:00 2001 From: Jannis Limperg Date: Mon, 30 Sep 2024 18:44:30 +0200 Subject: [PATCH] Support local forward rules (+ forward tracing) Local forward rules are now re-elaborated each time we match a hypothesis against them (and again when we reconstruct a queue entry). Previously, they were elaborated once and for all for the first goal, but this means that the resulting expressions could become invalid if the internal names of hypotheses in the goal changed. --- Aesop/Builder/Forward.lean | 3 +- Aesop/Forward/RuleInfo.lean | 17 ++-- Aesop/Forward/State.lean | 172 ++++++++++++++++++++++-------------- Aesop/Rule/Forward.lean | 7 +- Aesop/RuleTac/Basic.lean | 3 +- Aesop/Saturate.lean | 8 +- Aesop/Tracing.lean | 4 + 7 files changed, 130 insertions(+), 84 deletions(-) diff --git a/Aesop/Builder/Forward.lean b/Aesop/Builder/Forward.lean index 9a2a684..71ff940 100644 --- a/Aesop/Builder/Forward.lean +++ b/Aesop/Builder/Forward.lean @@ -113,7 +113,8 @@ def forwardCore₂ (t : ElabRuleTerm) (immediate? : Option (Array Name)) return some { toForwardRuleInfo := info name := { phase := phase.phase, name, scope := t.scope, builder := .forward } - expr, prio + term := t.toRuleTerm + prio } def forwardCore (t : ElabRuleTerm) (immediate? : Option (Array Name)) diff --git a/Aesop/Forward/RuleInfo.lean b/Aesop/Forward/RuleInfo.lean index 053fadd..dc08d9f 100644 --- a/Aesop/Forward/RuleInfo.lean +++ b/Aesop/Forward/RuleInfo.lean @@ -27,14 +27,14 @@ structure Slot where /-- Index of the slot. Slots are always part of a list of slots, and `index` is the 0-based index of this slot in that list. -/ index : SlotIndex + /-- 0-based index of the premise represented by this slot in the rule type. + Note that the slots array may use a different ordering than the original + order of premises, so we *don't* always have `index ≤ premiseIndex`. -/ + premiseIndex : PremiseIndex /-- The previous premises that the premise of this slot depends on. -/ deps : Std.HashSet PremiseIndex /-- Common variables shared between this slot and the previous slots. -/ common : Std.HashSet PremiseIndex - /-- 0-based index of the premise represented by this slot in the rule type. - Note that the slots array may use a different ordering than the original - order of premises, so we *don't* always have `slotIndex ≤ premiseIndex`. -/ - premiseIndex : PremiseIndex deriving Inhabited local instance : BEq Slot := @@ -45,10 +45,8 @@ local instance : Hashable Slot := /-- Information about the decomposed type of a forward rule. -/ structure ForwardRuleInfo where - /-- Metavariable context in which `premises` and `slotClusters` are valid. -/ - mctx : MetavarContext - /-- Metavariables representing the premises of the forward rule. -/ - premises : Array MVarId + /-- The rule's number of premises. -/ + numPremises : Nat /-- Slots representing the maximal premises of the forward rule, partitioned into metavariable clusters. -/ slotClusters : Array (Array Slot) @@ -60,7 +58,6 @@ namespace ForwardRuleInfo def ofExpr (thm : Expr) : MetaM ForwardRuleInfo := withNewMCtxDepth do let e ← inferType thm let (premises, _, _) ← forallMetaTelescope e - let mctx ← getMCtx let premises := premises.map (·.mvarId!) let mut premiseToIdx : Std.HashMap MVarId PremiseIndex := ∅ for h : i in [:premises.size] do @@ -93,7 +90,7 @@ def ofExpr (thm : Expr) : MetaM ForwardRuleInfo := withNewMCtxDepth do -- slot has some variables in common with the previous slots. assert! ! slotClusters.any λ cluster => cluster.any λ slot => slot.index.toNat > 0 && slot.common.isEmpty - return { premises, slotClusters, mctx } + return { slotClusters, numPremises := premises.size } where /-- Sort slots such that each slot has at least one variable in common with the previous slots. -/ diff --git a/Aesop/Forward/State.lean b/Aesop/Forward/State.lean index 0b595f7..665fb49 100644 --- a/Aesop/Forward/State.lean +++ b/Aesop/Forward/State.lean @@ -5,16 +5,24 @@ Authors: Xavier Généreux, Jannis Limperg -/ import Aesop.Rule.Forward +import Aesop.RuleTac.ElabRuleTerm import Batteries.Lean.Meta.SavedState import Batteries.Data.BinomialHeap.Basic open Lean Lean.Meta open Batteries (BinomialHeap) +open ExceptToEmoji (toEmoji) set_option linter.missingDocs true namespace Aesop +/-- Elaborate the term of a forward rule in the current goal. -/ +def elabForwardRuleTerm? (goal : MVarId) : RuleTerm → MetaM (Option Expr) + | .const n => mkConstWithFreshMVarLevels n + | .term stx => observing? do + (withFullElaboration $ elabRuleTermForApplyLikeMetaM goal stx) |>.run' + /-- A substitution maps premise indices to assignments. -/ abbrev Substitution := AssocList PremiseIndex Expr @@ -236,8 +244,6 @@ end VariableMap /-- Structure representing the state of a slot cluster. -/ structure ClusterState where - /-- The metavariable context in which `slots` and `variableMaps` are valid. -/ - mctx : MetavarContext /-- The cluster's slots. -/ slots : Array Slot /-- The variable map for this cluster. -/ @@ -257,17 +263,17 @@ def slot! (cs : ClusterState) (slot : SlotIndex) : Slot := def findSlot? (cs : ClusterState) (i : PremiseIndex) : Option Slot := cs.slots.find? (·.premiseIndex == i) -/-- Match hypothesis `hyp` against the slot with index `slot` in `rs` (which +/-- Match hypothesis `hyp` against the slot with index `slot` in `cs` (which must be a valid index). -/ def matchPremise? (premises : Array MVarId) (cs : ClusterState) (slot : SlotIndex) (hyp : FVarId) : MetaM (Option Substitution) := do let some slot := cs.slots[slot.toNat]? | throwError "aesop: internal error: matchPremise?: no slot with index {slot}" - withMCtx cs.mctx do - let some slotPremise := premises[slot.premiseIndex.toNat]? - | throwError "aesop: internal error: matchPremise?: slot with premise index {slot.premiseIndex}, but only {premises.size} premises" - let inputHypType ← slotPremise.getType - let hypType ← hyp.getType + let some slotPremise := premises[slot.premiseIndex.toNat]? + | throwError "aesop: internal error: matchPremise?: slot with premise index {slot.premiseIndex}, but only {premises.size} premises" + let inputHypType ← slotPremise.getType + let hypType ← hyp.getType + withAesopTraceNode .debug (λ r => return m!"{toEmoji r} match against premise {slot.premiseIndex}: {hypType} ≟ {inputHypType}") do if ← isDefEq inputHypType hypType then /- Note: This was over `slot.common` and not `slot.deps`. We need `slot.deps` because, among other issues, `slot.common` is empty in the first slot. Even though @@ -361,9 +367,10 @@ structure CompleteMatch where namespace CompleteMatch -/-- Given a complete match `m` for `r`, produce an application of the theorem -`r.expr` to the hypotheses from `m`. -/ -def reconstruct (r : ForwardRule) (m : CompleteMatch) : Expr := Id.run do +/-- Given a complete match `m` for `r`, get arguments to `r` contained in the +match's slots and substitution. -/ +def reconstructArgs (r : ForwardRule) (m : CompleteMatch) : + Array Expr := Id.run do let mut slotHyps : Std.HashMap PremiseIndex FVarId := ∅ for h : i in [:r.slotClusters.size] do let cluster := r.slotClusters[i] @@ -380,8 +387,8 @@ def reconstruct (r : ForwardRule) (m : CompleteMatch) : Expr := Id.run do for m in m.clusterMatches do subst := subst.mergeCompatible m.subst - let mut args := Array.mkEmpty r.premises.size - for i in [:r.premises.size] do + let mut args := Array.mkEmpty r.numPremises + for i in [:r.numPremises] do if let some hyp := slotHyps.get? ⟨i⟩ then args := args.push (.fvar hyp) else if let some inst := subst.find? ⟨i⟩ then @@ -389,7 +396,16 @@ def reconstruct (r : ForwardRule) (m : CompleteMatch) : Expr := Id.run do else panic! s!"match for rule {r.name} is not complete: no hyp or instantiation for premise {i}" - return mkAppN r.expr args + return args + +/-- Given a complete match `m` for `r`, produce an application of the theorem +`r.expr` to the hypotheses from `m`. Returns `none` if the term of `e` does not +elaborate in the current goal. -/ +def reconstruct? (goal : MVarId) (r : ForwardRule) (m : CompleteMatch) : + MetaM (Option Expr) := do + let some e ← elabForwardRuleTerm? goal r.term + | return none + return mkAppN e (m.reconstructArgs r) end CompleteMatch @@ -411,9 +427,12 @@ protected def le (q₁ q₂ : CompleteMatchQueue.Entry) : Bool := | .unsafe x, .unsafe y => x ≥ y | _, _ => panic! "comparing QueueEntries with different priority types" -/-- Build the proof corresponding to the complete match contained in `entry`. -/ -def toProof (entry : CompleteMatchQueue.Entry) : Expr := - entry.match.reconstruct entry.rule +/-- Build the proof corresponding to the complete match contained in `entry`. +May fail if the forward rule is local and can't be elaborated in the current +goal. -/ +def toProof? (goal : MVarId) (entry : CompleteMatchQueue.Entry) : + MetaM (Option Expr) := + entry.match.reconstruct? goal entry.rule end CompleteMatchQueue.Entry @@ -432,26 +451,31 @@ structure RuleState where /-- The initial (empty) rule state for a given forward rule. -/ def ForwardRule.initialRuleState (r : ForwardRule) : RuleState := let clusterStates := r.slotClusters.map λ slots => - { mctx := r.mctx, slots, variableMap := ∅, completeMatches := {} } + { slots, variableMap := ∅, completeMatches := {} } { rule := r, clusterStates } namespace RuleState /-- Add a hypothesis to the rule state. Returns the new rule state and any newly completed matches. If `h` does not match premise `pi`, nothing happens. -/ -def addHyp (h : FVarId) (pi : PremiseIndex) (rs : RuleState) : +def addHyp (goal : MVarId) (h : FVarId) (pi : PremiseIndex) (rs : RuleState) : MetaM (RuleState × Array CompleteMatch) := do - let mut rs := rs - let mut clusterStates := rs.clusterStates - let mut completeMatches := #[] - for i in [:clusterStates.size] do - let cs := clusterStates[i]! - let (cs, newCompleteMatches) ← cs.addHyp rs.rule.premises pi h - clusterStates := clusterStates.set! i cs - completeMatches := - completeMatches ++ - getCompleteMatches clusterStates i newCompleteMatches - return ({ rs with clusterStates }, completeMatches) + let some ruleExpr ← elabForwardRuleTerm? goal rs.rule.term + | return (rs, #[]) + withNewMCtxDepth do + let (premises, _, _) ← forallMetaTelescope (← inferType ruleExpr) + let premises := premises.map (·.mvarId!) + let mut rs := rs + let mut clusterStates := rs.clusterStates + let mut completeMatches := #[] + for i in [:clusterStates.size] do + let cs := clusterStates[i]! + let (cs, newCompleteMatches) ← cs.addHyp premises pi h + clusterStates := clusterStates.set! i cs + completeMatches := + completeMatches ++ + getCompleteMatches clusterStates i newCompleteMatches + return ({ rs with clusterStates }, completeMatches) where getCompleteMatches (clusterStates : Array ClusterState) (clusterIdx : Nat) (newCompleteMatches : Array Match) : @@ -474,7 +498,8 @@ where if completeMatches.isEmpty then return clusterMatches.map ({ clusterMatches := #[·] }) else - let mut newCompleteMatches := Array.mkEmpty (completeMatches.size * clusterMatches.size) + let mut newCompleteMatches := + Array.mkEmpty (completeMatches.size * clusterMatches.size) for completeMatch in completeMatches do for clusterMatch in clusterMatches do newCompleteMatches := newCompleteMatches.push @@ -531,17 +556,19 @@ def addCompleteMatchQueueEntry (entry : CompleteMatchQueue.Entry) /-- Add a hypothesis to the forward state. If `fs` represents a local context `lctx`, then `fs.addHyp h ms` represents `lctx` with `h` added. `ms` must overapproximate the rules for which `h` may unify with a maximal premise. -/ -def addHyp (h : FVarId) (ms : Array (ForwardRule × PremiseIndex)) +def addHyp (goal : MVarId) (h : FVarId) (ms : Array (ForwardRule × PremiseIndex)) (fs : ForwardState) : MetaM ForwardState := do - let mut fs := fs - for (r, i) in ms do - let rs := fs.ruleStates.find? r.name |>.getD r.initialRuleState - let (rs, completeMatches) ← rs.addHyp h i - fs := { fs with ruleStates := fs.ruleStates.insert r.name rs } - for m in completeMatches do - let entry := { rule := r, «match» := m } - fs := fs.addCompleteMatchQueueEntry entry r.name.phase - return fs + goal.withContext do + withConstAesopTraceNode .debug (return m!"add hyp {Expr.fvar h}") do + ms.foldlM (init := fs) λ fs (r, i) => do + withConstAesopTraceNode .debug (return m!"rule {r.name}, premise {i}") do + let rs := fs.ruleStates.find? r.name |>.getD r.initialRuleState + let (rs, completeMatches) ← rs.addHyp goal h i + let fs := { fs with ruleStates := fs.ruleStates.insert r.name rs } + completeMatches.foldlM (init := fs) λ fs m => do + aesop_trace[forward] "new complete match with args {m.reconstructArgs r}" + let entry := { rule := r, «match» := m } + return fs.addCompleteMatchQueueEntry entry r.name.phase /-- Remove a hypothesis from the forward state. If `fs` represents a local context `lctx`, then `fs.eraseHyp h ms` represents `lctx` with `h` removed. `ms` @@ -557,39 +584,56 @@ def eraseHyp (h : FVarId) (ms : Array (ForwardRule × PremiseIndex)) return { fs with ruleStates, erasedHyps := fs.erasedHyps.insert h } @[inline] -private partial def popFirstMatch?' (fs : ForwardState) - (queue : CompleteMatchQueue) : Option (Expr × CompleteMatchQueue) := +private partial def popFirstMatch?' (goal : MVarId) (fs : ForwardState) + (queue : CompleteMatchQueue) : + MetaM (Option (Expr × CompleteMatchQueue)) := do match queue.deleteMin with - | none => none + | none => return none | some (entry, queue) => - let entryHasErasedHyp := - entry.match.clusterMatches.any λ m => - m.revHyps.any (fs.erasedHyps.contains ·) - if entryHasErasedHyp then - popFirstMatch?' fs queue - else - (entry.toProof, queue) + let result? ← + withAesopTraceNode .debug (λ r => return m!"{toEmoji r} reconstruct queue entry for rule {entry.rule.name} with args {entry.match.reconstructArgs entry.rule}") do + let entryHasErasedHyp := + entry.match.clusterMatches.any λ m => + m.revHyps.any (fs.erasedHyps.contains ·) + if entryHasErasedHyp then + aesop_trace[forward] "args contain erased hyp" + pure none + else if let some prf ← entry.toProof? goal then + pure (prf, queue) + else + aesop_trace[forward] "rule does not elaborate" + pure none + match result? with + | none => popFirstMatch?' goal fs queue + | some result => return result /-- Get a proof for the first complete match of a norm rule. -/ -def popFirstNormMatch? (fs : ForwardState) : Option (Expr × ForwardState) := - fs.popFirstMatch?' fs.normQueue - |>.map λ (e, q) => (e, { fs with normQueue := q }) +def popFirstNormMatch? (goal : MVarId) (fs : ForwardState) : + MetaM (Option (Expr × ForwardState)) := + return (← fs.popFirstMatch?' goal fs.normQueue).map λ (e, q) => + (e, { fs with normQueue := q }) /-- Get a proof for the first complete match of a safe rule. -/ -def popFirstSafeMatch? (fs : ForwardState) : Option (Expr × ForwardState) := - fs.popFirstMatch?' fs.safeQueue - |>.map λ (e, q) => (e, { fs with safeQueue := q }) +def popFirstSafeMatch? (goal : MVarId) (fs : ForwardState) : + MetaM (Option (Expr × ForwardState)) := + return (← fs.popFirstMatch?' goal fs.safeQueue).map λ (e, q) => + (e, { fs with safeQueue := q }) /-- Get a proof for the first complete match of an unsafe rule. -/ -def popFirstUnsafeMatch? (fs : ForwardState) : Option (Expr × ForwardState) := - fs.popFirstMatch?' fs.unsafeQueue - |>.map λ (e, q) => (e, { fs with unsafeQueue := q }) +def popFirstUnsafeMatch? (goal : MVarId) (fs : ForwardState) : + MetaM (Option (Expr × ForwardState)) := + return (← fs.popFirstMatch?' goal fs.unsafeQueue).map λ (e, q) => + (e, { fs with unsafeQueue := q }) /-- Get a proof for the first complete match. Norm rules are prioritised over safe rules, and safe over unsafe rules. -/ -def popFirstMatch? (fs : ForwardState) : Option (Expr × ForwardState) := - fs.popFirstNormMatch? <|> - fs.popFirstSafeMatch? <|> - fs.popFirstUnsafeMatch? +def popFirstMatch? (goal : MVarId) (fs : ForwardState) : + MetaM (Option (Expr × ForwardState)) := do + if let some result ← fs.popFirstNormMatch? goal then + return result + else if let some result ← fs.popFirstSafeMatch? goal then + return result + else + fs.popFirstUnsafeMatch? goal end Aesop.ForwardState diff --git a/Aesop/Rule/Forward.lean b/Aesop/Rule/Forward.lean index 614ed4c..2687736 100644 --- a/Aesop/Rule/Forward.lean +++ b/Aesop/Rule/Forward.lean @@ -4,9 +4,10 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Xavier Généreux, Jannis Limperg -/ +import Aesop.Forward.RuleInfo import Aesop.Percent import Aesop.Rule.Name -import Aesop.Forward.RuleInfo +import Aesop.RuleTac.Basic set_option linter.missingDocs true @@ -28,8 +29,7 @@ structure ForwardRule extends ForwardRuleInfo where /-- The rule's name. Should be unique among all rules in a rule set. -/ name : RuleName /-- The theorem from which this rule is derived. -/ - -- FIXME What happens if this expr becomes invalid due to fvar renamings etc.? - expr : Expr + term : RuleTerm /-- The rule's priority. -/ prio : ForwardRulePriority deriving Inhabited @@ -45,5 +45,4 @@ instance : Hashable ForwardRule := instance : Ord ForwardRule := ⟨λ r₁ r₂ => compare r₁.name r₂.name⟩ - end Aesop.ForwardRule diff --git a/Aesop/RuleTac/Basic.lean b/Aesop/RuleTac/Basic.lean index 0d08853..8d61b4d 100644 --- a/Aesop/RuleTac/Basic.lean +++ b/Aesop/RuleTac/Basic.lean @@ -19,7 +19,6 @@ open Lean.Meta namespace Aesop - /-! # Rule Tactic Types -/ -- TODO put docs on the structure fields instead of the structures @@ -170,10 +169,12 @@ inductive CasesTarget inductive RuleTerm | const (decl : Name) | term (term : Term) + deriving Inhabited inductive ElabRuleTerm | const (decl : Name) | term (term : Term) (expr : Expr) + deriving Inhabited namespace ElabRuleTerm diff --git a/Aesop/Saturate.lean b/Aesop/Saturate.lean index 6894587..b5b27ec 100644 --- a/Aesop/Saturate.lean +++ b/Aesop/Saturate.lean @@ -117,23 +117,23 @@ partial def saturateCore (rs : LocalRuleSet) (goal : MVarId) : if ldecl.isImplementationDetail then continue let rules ← index.get ldecl.type - fs ← fs.addHyp ldecl.fvarId rules + fs ← fs.addHyp goal ldecl.fvarId rules go fs goal where go (fs : ForwardState) (goal : MVarId) : ScriptM MVarId := do withIncRecDepth do goal.withContext do - if let some (prf, fs) := fs.popFirstMatch? then + if let some (prf, fs) ← fs.popFirstMatch? goal then trace[saturate] "goal:{indentD goal}" let name ← getUnusedUserName forwardHypPrefix let type ← inferType prf - trace[saturate] "add: {name} : {type} := {prf}" + trace[saturate] "add hyp {name} : {type} := {prf}" let hyp := { userName := name, value := prf, type } let (goal, #[hyp]) ← assertHypothesisS goal hyp (md := .default) | unreachable! goal.withContext do let rules ← rs.forwardRules.get type - let fs ← fs.addHyp hyp rules + let fs ← fs.addHyp goal hyp rules go fs goal else return goal diff --git a/Aesop/Tracing.lean b/Aesop/Tracing.lean index 7730a5f..9f3bb4a 100644 --- a/Aesop/Tracing.lean +++ b/Aesop/Tracing.lean @@ -66,6 +66,10 @@ initialize script : TraceOption ← registerTraceOption `script "(aesop) Print a trace of script generation." +initialize forward : TraceOption ← + registerTraceOption `forward + "(aesop) Trace forward reasoning." + end TraceOption section