Skip to content

Commit

Permalink
Support local forward rules (+ forward tracing)
Browse files Browse the repository at this point in the history
Local forward rules are now re-elaborated each time we match a
hypothesis against them (and again when we reconstruct a queue entry).
Previously, they were elaborated once and for all for the first goal,
but this means that the resulting expressions could become invalid if
the internal names of hypotheses in the goal changed.
  • Loading branch information
JLimperg committed Sep 30, 2024
1 parent 3858f66 commit ca2a27a
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 84 deletions.
3 changes: 2 additions & 1 deletion Aesop/Builder/Forward.lean
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ def forwardCore₂ (t : ElabRuleTerm) (immediate? : Option (Array Name))
return some {
toForwardRuleInfo := info
name := { phase := phase.phase, name, scope := t.scope, builder := .forward }
expr, prio
term := t.toRuleTerm
prio
}

def forwardCore (t : ElabRuleTerm) (immediate? : Option (Array Name))
Expand Down
17 changes: 7 additions & 10 deletions Aesop/Forward/RuleInfo.lean
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ structure Slot where
/-- Index of the slot. Slots are always part of a list of slots, and `index`
is the 0-based index of this slot in that list. -/
index : SlotIndex
/-- 0-based index of the premise represented by this slot in the rule type.
Note that the slots array may use a different ordering than the original
order of premises, so we *don't* always have `index ≤ premiseIndex`. -/
premiseIndex : PremiseIndex
/-- The previous premises that the premise of this slot depends on. -/
deps : Std.HashSet PremiseIndex
/-- Common variables shared between this slot and the previous slots. -/
common : Std.HashSet PremiseIndex
/-- 0-based index of the premise represented by this slot in the rule type.
Note that the slots array may use a different ordering than the original
order of premises, so we *don't* always have `slotIndex ≤ premiseIndex`. -/
premiseIndex : PremiseIndex
deriving Inhabited

local instance : BEq Slot :=
Expand All @@ -45,10 +45,8 @@ local instance : Hashable Slot :=

/-- Information about the decomposed type of a forward rule. -/
structure ForwardRuleInfo where
/-- Metavariable context in which `premises` and `slotClusters` are valid. -/
mctx : MetavarContext
/-- Metavariables representing the premises of the forward rule. -/
premises : Array MVarId
/-- The rule's number of premises. -/
numPremises : Nat
/-- Slots representing the maximal premises of the forward rule, partitioned
into metavariable clusters. -/
slotClusters : Array (Array Slot)
Expand All @@ -60,7 +58,6 @@ namespace ForwardRuleInfo
def ofExpr (thm : Expr) : MetaM ForwardRuleInfo := withNewMCtxDepth do
let e ← inferType thm
let (premises, _, _) ← forallMetaTelescope e
let mctx ← getMCtx
let premises := premises.map (·.mvarId!)
let mut premiseToIdx : Std.HashMap MVarId PremiseIndex := ∅
for h : i in [:premises.size] do
Expand Down Expand Up @@ -93,7 +90,7 @@ def ofExpr (thm : Expr) : MetaM ForwardRuleInfo := withNewMCtxDepth do
-- slot has some variables in common with the previous slots.
assert! ! slotClusters.any λ cluster => cluster.any λ slot =>
slot.index.toNat > 0 && slot.common.isEmpty
return { premises, slotClusters, mctx }
return { slotClusters, numPremises := premises.size }
where
/-- Sort slots such that each slot has at least one variable in common with
the previous slots. -/
Expand Down
172 changes: 108 additions & 64 deletions Aesop/Forward/State.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,24 @@ Authors: Xavier Généreux, Jannis Limperg
-/

import Aesop.Rule.Forward
import Aesop.RuleTac.ElabRuleTerm
import Batteries.Lean.Meta.SavedState
import Batteries.Data.BinomialHeap.Basic

open Lean Lean.Meta
open Batteries (BinomialHeap)
open ExceptToEmoji (toEmoji)

set_option linter.missingDocs true

namespace Aesop

/-- Elaborate the term of a forward rule in the current goal. -/
def elabForwardRuleTerm? (goal : MVarId) : RuleTerm → MetaM (Option Expr)
| .const n => mkConstWithFreshMVarLevels n
| .term stx => observing? do
(withFullElaboration $ elabRuleTermForApplyLikeMetaM goal stx) |>.run'

/-- A substitution maps premise indices to assignments. -/
abbrev Substitution := AssocList PremiseIndex Expr

Expand Down Expand Up @@ -236,8 +244,6 @@ end VariableMap

/-- Structure representing the state of a slot cluster. -/
structure ClusterState where
/-- The metavariable context in which `slots` and `variableMaps` are valid. -/
mctx : MetavarContext
/-- The cluster's slots. -/
slots : Array Slot
/-- The variable map for this cluster. -/
Expand All @@ -257,17 +263,17 @@ def slot! (cs : ClusterState) (slot : SlotIndex) : Slot :=
def findSlot? (cs : ClusterState) (i : PremiseIndex) : Option Slot :=
cs.slots.find? (·.premiseIndex == i)

/-- Match hypothesis `hyp` against the slot with index `slot` in `rs` (which
/-- Match hypothesis `hyp` against the slot with index `slot` in `cs` (which
must be a valid index). -/
def matchPremise? (premises : Array MVarId) (cs : ClusterState)
(slot : SlotIndex) (hyp : FVarId) : MetaM (Option Substitution) := do
let some slot := cs.slots[slot.toNat]?
| throwError "aesop: internal error: matchPremise?: no slot with index {slot}"
withMCtx cs.mctx do
let some slotPremise := premises[slot.premiseIndex.toNat]?
| throwError "aesop: internal error: matchPremise?: slot with premise index {slot.premiseIndex}, but only {premises.size} premises"
let inputHypTypeslotPremise.getType
let hypType ← hyp.getType
let some slotPremise := premises[slot.premiseIndex.toNat]?
| throwError "aesop: internal error: matchPremise?: slot with premise index {slot.premiseIndex}, but only {premises.size} premises"
let inputHypType ← slotPremise.getType
let hypTypehyp.getType
withAesopTraceNode .debug (λ r => return m!"{toEmoji r} match against premise {slot.premiseIndex}: {hypType} ≟ {inputHypType}") do
if ← isDefEq inputHypType hypType then
/- Note: This was over `slot.common` and not `slot.deps`. We need `slot.deps`
because, among other issues, `slot.common` is empty in the first slot. Even though
Expand Down Expand Up @@ -361,9 +367,10 @@ structure CompleteMatch where

namespace CompleteMatch

/-- Given a complete match `m` for `r`, produce an application of the theorem
`r.expr` to the hypotheses from `m`. -/
def reconstruct (r : ForwardRule) (m : CompleteMatch) : Expr := Id.run do
/-- Given a complete match `m` for `r`, get arguments to `r` contained in the
match's slots and substitution. -/
def reconstructArgs (r : ForwardRule) (m : CompleteMatch) :
Array Expr := Id.run do
let mut slotHyps : Std.HashMap PremiseIndex FVarId := ∅
for h : i in [:r.slotClusters.size] do
let cluster := r.slotClusters[i]
Expand All @@ -380,16 +387,25 @@ def reconstruct (r : ForwardRule) (m : CompleteMatch) : Expr := Id.run do
for m in m.clusterMatches do
subst := subst.mergeCompatible m.subst

let mut args := Array.mkEmpty r.premises.size
for i in [:r.premises.size] do
let mut args := Array.mkEmpty r.numPremises
for i in [:r.numPremises] do
if let some hyp := slotHyps.get? ⟨i⟩ then
args := args.push (.fvar hyp)
else if let some inst := subst.find? ⟨i⟩ then
args := args.push inst
else
panic! s!"match for rule {r.name} is not complete: no hyp or instantiation for premise {i}"

return mkAppN r.expr args
return args

/-- Given a complete match `m` for `r`, produce an application of the theorem
`r.expr` to the hypotheses from `m`. Returns `none` if the term of `e` does not
elaborate in the current goal. -/
def reconstruct? (goal : MVarId) (r : ForwardRule) (m : CompleteMatch) :
MetaM (Option Expr) := do
let some e ← elabForwardRuleTerm? goal r.term
| return none
return mkAppN e (m.reconstructArgs r)

end CompleteMatch

Expand All @@ -411,9 +427,12 @@ protected def le (q₁ q₂ : CompleteMatchQueue.Entry) : Bool :=
| .unsafe x, .unsafe y => x ≥ y
| _, _ => panic! "comparing QueueEntries with different priority types"

/-- Build the proof corresponding to the complete match contained in `entry`. -/
def toProof (entry : CompleteMatchQueue.Entry) : Expr :=
entry.match.reconstruct entry.rule
/-- Build the proof corresponding to the complete match contained in `entry`.
May fail if the forward rule is local and can't be elaborated in the current
goal. -/
def toProof? (goal : MVarId) (entry : CompleteMatchQueue.Entry) :
MetaM (Option Expr) :=
entry.match.reconstruct? goal entry.rule

end CompleteMatchQueue.Entry

Expand All @@ -432,26 +451,31 @@ structure RuleState where
/-- The initial (empty) rule state for a given forward rule. -/
def ForwardRule.initialRuleState (r : ForwardRule) : RuleState :=
let clusterStates := r.slotClusters.map λ slots =>
{ mctx := r.mctx, slots, variableMap := ∅, completeMatches := {} }
{ slots, variableMap := ∅, completeMatches := {} }
{ rule := r, clusterStates }

namespace RuleState

/-- Add a hypothesis to the rule state. Returns the new rule state and any newly
completed matches. If `h` does not match premise `pi`, nothing happens. -/
def addHyp (h : FVarId) (pi : PremiseIndex) (rs : RuleState) :
def addHyp (goal : MVarId) (h : FVarId) (pi : PremiseIndex) (rs : RuleState) :
MetaM (RuleState × Array CompleteMatch) := do
let mut rs := rs
let mut clusterStates := rs.clusterStates
let mut completeMatches := #[]
for i in [:clusterStates.size] do
let cs := clusterStates[i]!
let (cs, newCompleteMatches) ← cs.addHyp rs.rule.premises pi h
clusterStates := clusterStates.set! i cs
completeMatches :=
completeMatches ++
getCompleteMatches clusterStates i newCompleteMatches
return ({ rs with clusterStates }, completeMatches)
let some ruleExpr ← elabForwardRuleTerm? goal rs.rule.term
| return (rs, #[])
withNewMCtxDepth do
let (premises, _, _) ← forallMetaTelescope (← inferType ruleExpr)
let premises := premises.map (·.mvarId!)
let mut rs := rs
let mut clusterStates := rs.clusterStates
let mut completeMatches := #[]
for i in [:clusterStates.size] do
let cs := clusterStates[i]!
let (cs, newCompleteMatches) ← cs.addHyp premises pi h
clusterStates := clusterStates.set! i cs
completeMatches :=
completeMatches ++
getCompleteMatches clusterStates i newCompleteMatches
return ({ rs with clusterStates }, completeMatches)
where
getCompleteMatches (clusterStates : Array ClusterState) (clusterIdx : Nat)
(newCompleteMatches : Array Match) :
Expand All @@ -474,7 +498,8 @@ where
if completeMatches.isEmpty then
return clusterMatches.map ({ clusterMatches := #[·] })
else
let mut newCompleteMatches := Array.mkEmpty (completeMatches.size * clusterMatches.size)
let mut newCompleteMatches :=
Array.mkEmpty (completeMatches.size * clusterMatches.size)
for completeMatch in completeMatches do
for clusterMatch in clusterMatches do
newCompleteMatches := newCompleteMatches.push
Expand Down Expand Up @@ -531,17 +556,19 @@ def addCompleteMatchQueueEntry (entry : CompleteMatchQueue.Entry)
/-- Add a hypothesis to the forward state. If `fs` represents a local context
`lctx`, then `fs.addHyp h ms` represents `lctx` with `h` added. `ms` must
overapproximate the rules for which `h` may unify with a maximal premise. -/
def addHyp (h : FVarId) (ms : Array (ForwardRule × PremiseIndex))
def addHyp (goal : MVarId) (h : FVarId) (ms : Array (ForwardRule × PremiseIndex))
(fs : ForwardState) : MetaM ForwardState := do
let mut fs := fs
for (r, i) in ms do
let rs := fs.ruleStates.find? r.name |>.getD r.initialRuleState
let (rs, completeMatches) ← rs.addHyp h i
fs := { fs with ruleStates := fs.ruleStates.insert r.name rs }
for m in completeMatches do
let entry := { rule := r, «match» := m }
fs := fs.addCompleteMatchQueueEntry entry r.name.phase
return fs
goal.withContext do
withConstAesopTraceNode .debug (return m!"add hyp {Expr.fvar h}") do
ms.foldlM (init := fs) λ fs (r, i) => do
withConstAesopTraceNode .debug (return m!"rule {r.name}, premise {i}") do
let rs := fs.ruleStates.find? r.name |>.getD r.initialRuleState
let (rs, completeMatches) ← rs.addHyp goal h i
let fs := { fs with ruleStates := fs.ruleStates.insert r.name rs }
completeMatches.foldlM (init := fs) λ fs m => do
aesop_trace[forward] "new complete match with args {m.reconstructArgs r}"
let entry := { rule := r, «match» := m }
return fs.addCompleteMatchQueueEntry entry r.name.phase

/-- Remove a hypothesis from the forward state. If `fs` represents a local
context `lctx`, then `fs.eraseHyp h ms` represents `lctx` with `h` removed. `ms`
Expand All @@ -557,39 +584,56 @@ def eraseHyp (h : FVarId) (ms : Array (ForwardRule × PremiseIndex))
return { fs with ruleStates, erasedHyps := fs.erasedHyps.insert h }

@[inline]
private partial def popFirstMatch?' (fs : ForwardState)
(queue : CompleteMatchQueue) : Option (Expr × CompleteMatchQueue) :=
private partial def popFirstMatch?' (goal : MVarId) (fs : ForwardState)
(queue : CompleteMatchQueue) :
MetaM (Option (Expr × CompleteMatchQueue)) := do
match queue.deleteMin with
| none => none
| none => return none
| some (entry, queue) =>
let entryHasErasedHyp :=
entry.match.clusterMatches.any λ m =>
m.revHyps.any (fs.erasedHyps.contains ·)
if entryHasErasedHyp then
popFirstMatch?' fs queue
else
(entry.toProof, queue)
let result? ←
withAesopTraceNode .debug (λ r => return m!"{toEmoji r} reconstruct queue entry for rule {entry.rule.name} with args {entry.match.reconstructArgs entry.rule}") do
let entryHasErasedHyp :=
entry.match.clusterMatches.any λ m =>
m.revHyps.any (fs.erasedHyps.contains ·)
if entryHasErasedHyp then
aesop_trace[forward] "args contain erased hyp"
pure none
else if let some prf ← entry.toProof? goal then
pure (prf, queue)
else
aesop_trace[forward] "rule does not elaborate"
pure none
match result? with
| none => popFirstMatch?' goal fs queue
| some result => return result

/-- Get a proof for the first complete match of a norm rule. -/
def popFirstNormMatch? (fs : ForwardState) : Option (Expr × ForwardState) :=
fs.popFirstMatch?' fs.normQueue
|>.map λ (e, q) => (e, { fs with normQueue := q })
def popFirstNormMatch? (goal : MVarId) (fs : ForwardState) :
MetaM (Option (Expr × ForwardState)) :=
return (← fs.popFirstMatch?' goal fs.normQueue).map λ (e, q) =>
(e, { fs with normQueue := q })

/-- Get a proof for the first complete match of a safe rule. -/
def popFirstSafeMatch? (fs : ForwardState) : Option (Expr × ForwardState) :=
fs.popFirstMatch?' fs.safeQueue
|>.map λ (e, q) => (e, { fs with safeQueue := q })
def popFirstSafeMatch? (goal : MVarId) (fs : ForwardState) :
MetaM (Option (Expr × ForwardState)) :=
return (← fs.popFirstMatch?' goal fs.safeQueue).map λ (e, q) =>
(e, { fs with safeQueue := q })

/-- Get a proof for the first complete match of an unsafe rule. -/
def popFirstUnsafeMatch? (fs : ForwardState) : Option (Expr × ForwardState) :=
fs.popFirstMatch?' fs.unsafeQueue
|>.map λ (e, q) => (e, { fs with unsafeQueue := q })
def popFirstUnsafeMatch? (goal : MVarId) (fs : ForwardState) :
MetaM (Option (Expr × ForwardState)) :=
return (← fs.popFirstMatch?' goal fs.unsafeQueue).map λ (e, q) =>
(e, { fs with unsafeQueue := q })

/-- Get a proof for the first complete match. Norm rules are prioritised over
safe rules, and safe over unsafe rules. -/
def popFirstMatch? (fs : ForwardState) : Option (Expr × ForwardState) :=
fs.popFirstNormMatch? <|>
fs.popFirstSafeMatch? <|>
fs.popFirstUnsafeMatch?
def popFirstMatch? (goal : MVarId) (fs : ForwardState) :
MetaM (Option (Expr × ForwardState)) := do
if let some result ← fs.popFirstNormMatch? goal then
return result
else if let some result ← fs.popFirstSafeMatch? goal then
return result
else
fs.popFirstUnsafeMatch? goal

end Aesop.ForwardState
7 changes: 3 additions & 4 deletions Aesop/Rule/Forward.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Xavier Généreux, Jannis Limperg
-/

import Aesop.Forward.RuleInfo
import Aesop.Percent
import Aesop.Rule.Name
import Aesop.Forward.RuleInfo
import Aesop.RuleTac.Basic

set_option linter.missingDocs true

Expand All @@ -28,8 +29,7 @@ structure ForwardRule extends ForwardRuleInfo where
/-- The rule's name. Should be unique among all rules in a rule set. -/
name : RuleName
/-- The theorem from which this rule is derived. -/
-- FIXME What happens if this expr becomes invalid due to fvar renamings etc.?
expr : Expr
term : RuleTerm
/-- The rule's priority. -/
prio : ForwardRulePriority
deriving Inhabited
Expand All @@ -45,5 +45,4 @@ instance : Hashable ForwardRule :=
instance : Ord ForwardRule :=
⟨λ r₁ r₂ => compare r₁.name r₂.name⟩


end Aesop.ForwardRule
3 changes: 2 additions & 1 deletion Aesop/RuleTac/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ open Lean.Meta

namespace Aesop


/-! # Rule Tactic Types -/

-- TODO put docs on the structure fields instead of the structures
Expand Down Expand Up @@ -170,10 +169,12 @@ inductive CasesTarget
inductive RuleTerm
| const (decl : Name)
| term (term : Term)
deriving Inhabited

inductive ElabRuleTerm
| const (decl : Name)
| term (term : Term) (expr : Expr)
deriving Inhabited

namespace ElabRuleTerm

Expand Down
Loading

0 comments on commit ca2a27a

Please sign in to comment.