Skip to content

Commit

Permalink
Merge pull request #436 from Alex-Fischman/symbol-gen
Browse files Browse the repository at this point in the history
SymbolGen cleanup
  • Loading branch information
Alex-Fischman authored Oct 7, 2024
2 parents 015d25d + 3dee1f4 commit 75e1466
Show file tree
Hide file tree
Showing 8 changed files with 192 additions and 210 deletions.
54 changes: 26 additions & 28 deletions src/ast/desugar.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,8 @@
use super::{Rewrite, Rule};
use crate::*;

pub struct Desugar {
pub(crate) fresh_gen: SymbolGen,
}

impl Default for Desugar {
fn default() -> Self {
Self {
// the default reserved string in egglog is "_"
fresh_gen: SymbolGen::new("_".repeat(2)),
}
}
}
#[derive(Default)]
pub struct Desugar {}

fn desugar_datatype(span: Span, name: Symbol, variants: Vec<Variant>) -> Vec<NCommand> {
vec![NCommand::Sort(span.clone(), name, None)]
Expand Down Expand Up @@ -104,7 +94,11 @@ fn desugar_birewrite(ruleset: Symbol, name: Symbol, rewrite: &Rewrite) -> Vec<NC
}

// TODO(yz): we can delete this code once we enforce that all rule bodies cannot read the database (except EqSort).
fn add_semi_naive_rule(desugar: &mut Desugar, rule: Rule) -> Option<Rule> {
fn add_semi_naive_rule(
desugar: &mut Desugar,
symbol_gen: &mut SymbolGen,
rule: Rule,
) -> Option<Rule> {
let mut new_rule = rule;
// Whenever an Let(_, expr@Call(...)) or Set(_, expr@Call(...)) is present in action,
// an additional seminaive rule should be created.
Expand All @@ -120,7 +114,7 @@ fn add_semi_naive_rule(desugar: &mut Desugar, rule: Rule) -> Option<Rule> {
if let Expr::Call(..) = expr {
add_new_rule = true;

let fresh_symbol = desugar.get_fresh();
let fresh_symbol = desugar.get_fresh(symbol_gen);
let fresh_var = Expr::Var(span.clone(), fresh_symbol);
let expr = std::mem::replace(expr, fresh_var.clone());
new_head_atoms.push(Fact::Eq(span.clone(), vec![fresh_var, expr]));
Expand Down Expand Up @@ -157,9 +151,10 @@ fn desugar_simplify(
expr: &Expr,
schedule: &Schedule,
span: Span,
symbol_gen: &mut SymbolGen,
) -> Vec<NCommand> {
let mut res = vec![NCommand::Push(1)];
let lhs = desugar.get_fresh();
let lhs = desugar.get_fresh(symbol_gen);
res.push(NCommand::CoreAction(Action::Let(
span.clone(),
lhs,
Expand All @@ -174,6 +169,7 @@ fn desugar_simplify(
expr: Expr::Var(span.clone(), lhs),
},
desugar,
symbol_gen,
false,
)
.unwrap(),
Expand All @@ -193,6 +189,7 @@ pub(crate) fn rewrite_name(rewrite: &Rewrite) -> String {
pub(crate) fn desugar_command(
command: Command,
desugar: &mut Desugar,
symbol_gen: &mut SymbolGen,
seminaive_transform: bool,
) -> Result<Vec<NCommand>, Error> {
let res = match command {
Expand Down Expand Up @@ -226,6 +223,7 @@ pub(crate) fn desugar_command(
return desugar_commands(
parse_program(Some(file), &s)?,
desugar,
symbol_gen,
seminaive_transform,
);
}
Expand All @@ -245,7 +243,7 @@ pub(crate) fn desugar_command(
}];

if seminaive_transform {
if let Some(new_rule) = add_semi_naive_rule(desugar, rule) {
if let Some(new_rule) = add_semi_naive_rule(desugar, symbol_gen, rule) {
result.push(NCommand::NormRule {
ruleset,
name,
Expand All @@ -266,7 +264,7 @@ pub(crate) fn desugar_command(
span,
expr,
schedule,
} => desugar_simplify(desugar, &expr, &schedule, span),
} => desugar_simplify(desugar, &expr, &schedule, span, symbol_gen),
Command::RunSchedule(sched) => {
vec![NCommand::RunSchedule(sched.clone())]
}
Expand All @@ -293,9 +291,9 @@ pub(crate) fn desugar_command(
// ((extract {fresh} {variants}))
// :ruleset {fresh_ruleset})
// (run {fresh_ruleset} 1)
let fresh = desugar.get_fresh();
let fresh_ruleset = desugar.get_fresh();
let fresh_rulename = desugar.get_fresh();
let fresh = desugar.get_fresh(symbol_gen);
let fresh_ruleset = desugar.get_fresh(symbol_gen);
let fresh_rulename = desugar.get_fresh(symbol_gen);
let rule = Rule {
span: span.clone(),
body: vec![Fact::Eq(
Expand Down Expand Up @@ -339,7 +337,7 @@ pub(crate) fn desugar_command(
vec![NCommand::Pop(span, num)]
}
Command::Fail(span, cmd) => {
let mut desugared = desugar_command(*cmd, desugar, seminaive_transform)?;
let mut desugared = desugar_command(*cmd, desugar, symbol_gen, seminaive_transform)?;

let last = desugared.pop().unwrap();
desugared.push(NCommand::Fail(span, Box::new(last)));
Expand All @@ -356,35 +354,35 @@ pub(crate) fn desugar_command(
pub(crate) fn desugar_commands(
program: Vec<Command>,
desugar: &mut Desugar,
symbol_gen: &mut SymbolGen,
seminaive_transform: bool,
) -> Result<Vec<NCommand>, Error> {
let mut res = vec![];
for command in program {
let desugared = desugar_command(command, desugar, seminaive_transform)?;
let desugared = desugar_command(command, desugar, symbol_gen, seminaive_transform)?;
res.extend(desugared);
}
Ok(res)
}

impl Clone for Desugar {
fn clone(&self) -> Self {
Self {
fresh_gen: self.fresh_gen.clone(),
}
Self {}
}
}

impl Desugar {
pub fn get_fresh(&mut self) -> Symbol {
self.fresh_gen.fresh(&"v".into())
pub fn get_fresh(&mut self, symbol_gen: &mut SymbolGen) -> Symbol {
symbol_gen.fresh(&"v".into())
}

pub(crate) fn desugar_program(
&mut self,
program: Vec<Command>,
symbol_gen: &mut SymbolGen,
seminaive_transform: bool,
) -> Result<Vec<NCommand>, Error> {
let res = desugar_commands(program, self, seminaive_transform)?;
let res = desugar_commands(program, self, symbol_gen, seminaive_transform)?;
Ok(res)
}
}
7 changes: 4 additions & 3 deletions src/constraint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -471,11 +471,11 @@ impl Problem<AtomTerm, ArcSort> {
&mut self,
actions: &GenericCoreActions<Symbol, Symbol>,
typeinfo: &TypeInfo,
symbol_gen: &mut SymbolGen,
) -> Result<(), TypeError> {
let mut symbol_gen = SymbolGen::new("$".to_string());
for action in actions.0.iter() {
self.constraints
.extend(action.get_constraints(typeinfo, &mut symbol_gen)?);
.extend(action.get_constraints(typeinfo, symbol_gen)?);

// bound vars are added to range
match action {
Expand All @@ -495,14 +495,15 @@ impl Problem<AtomTerm, ArcSort> {
&mut self,
rule: &CoreRule,
typeinfo: &TypeInfo,
symbol_gen: &mut SymbolGen,
) -> Result<(), TypeError> {
let CoreRule {
span: _,
head,
body,
} = rule;
self.add_query(body, typeinfo)?;
self.add_actions(head, typeinfo)?;
self.add_actions(head, typeinfo, symbol_gen)?;
Ok(())
}

Expand Down
28 changes: 12 additions & 16 deletions src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,7 @@ where
pub(crate) fn to_core_rule(
&self,
typeinfo: &TypeInfo,
mut fresh_gen: impl FreshGen<Head, Leaf>,
fresh_gen: &mut impl FreshGen<Head, Leaf>,
) -> Result<GenericCoreRule<HeadOrEq<Head>, Head, Leaf>, TypeError>
where
Leaf: SymbolLike,
Expand All @@ -807,10 +807,9 @@ where
body,
} = self;

let (body, _correspondence) = Facts(body.clone()).to_query(typeinfo, &mut fresh_gen);
let (body, _correspondence) = Facts(body.clone()).to_query(typeinfo, fresh_gen);
let mut binding = body.get_vars();
let (head, _correspondence) =
head.to_core_actions(typeinfo, &mut binding, &mut fresh_gen)?;
let (head, _correspondence) = head.to_core_actions(typeinfo, &mut binding, fresh_gen)?;
Ok(GenericCoreRule {
span: self.span.clone(),
body,
Expand All @@ -821,7 +820,7 @@ where
fn to_canonicalized_core_rule_impl(
&self,
typeinfo: &TypeInfo,
fresh_gen: impl FreshGen<Head, Leaf>,
fresh_gen: &mut impl FreshGen<Head, Leaf>,
value_eq: impl Fn(&GenericAtomTerm<Leaf>, &GenericAtomTerm<Leaf>) -> Head,
) -> Result<GenericCoreRule<Head, Head, Leaf>, TypeError>
where
Expand All @@ -836,19 +835,16 @@ impl ResolvedRule {
pub(crate) fn to_canonicalized_core_rule(
&self,
typeinfo: &TypeInfo,
fresh_gen: &mut SymbolGen,
) -> Result<ResolvedCoreRule, TypeError> {
let value_eq = &typeinfo.primitives.get(&Symbol::from("value-eq")).unwrap()[0];
let unit = typeinfo.get_sort_nofail::<UnitSort>();
self.to_canonicalized_core_rule_impl(
typeinfo,
ResolvedGen::new("$".to_string()),
|at1, at2| {
ResolvedCall::Primitive(SpecializedPrimitive {
primitive: value_eq.clone(),
input: vec![at1.output(typeinfo), at2.output(typeinfo)],
output: unit.clone(),
})
},
)
self.to_canonicalized_core_rule_impl(typeinfo, fresh_gen, |at1, at2| {
ResolvedCall::Primitive(SpecializedPrimitive {
primitive: value_eq.clone(),
input: vec![at1.output(typeinfo), at2.output(typeinfo)],
output: unit.clone(),
})
})
}
}
6 changes: 3 additions & 3 deletions src/function/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ impl Debug for Function {
pub(crate) type DeferredMerge = (ValueVec, Value, Value);

impl Function {
pub(crate) fn new(egraph: &EGraph, decl: &ResolvedFunctionDecl) -> Result<Self, Error> {
pub(crate) fn new(egraph: &mut EGraph, decl: &ResolvedFunctionDecl) -> Result<Self, Error> {
let mut input = Vec::with_capacity(decl.schema.input.len());
for s in &decl.schema.input {
input.push(match egraph.type_info.sorts.get(s) {
Expand Down Expand Up @@ -125,7 +125,7 @@ impl Function {
let (actions, mapped_expr) = merge_expr.to_core_actions(
&egraph.type_info,
&mut binding.clone(),
&mut ResolvedGen::new("$".to_string()),
&mut egraph.symbol_gen,
)?;
let target = mapped_expr.get_corresponding_var_or_lit(&egraph.type_info);
let program = egraph
Expand All @@ -144,7 +144,7 @@ impl Function {
let (merge_action, _) = decl.merge_action.to_core_actions(
&egraph.type_info,
&mut binding.clone(),
&mut ResolvedGen::new("$".to_string()),
&mut egraph.symbol_gen,
)?;
let program = egraph
.compile_actions(&binding, &merge_action)
Expand Down
32 changes: 19 additions & 13 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,7 @@ impl FromStr for RunMode {

#[derive(Clone)]
pub struct EGraph {
symbol_gen: SymbolGen,
egraphs: Vec<Self>,
unionfind: UnionFind,
pub(crate) desugar: Desugar,
Expand All @@ -447,6 +448,7 @@ pub struct EGraph {
impl Default for EGraph {
fn default() -> Self {
let mut egraph = Self {
symbol_gen: SymbolGen::new("$".to_string()),
egraphs: vec![],
unionfind: Default::default(),
functions: Default::default(),
Expand Down Expand Up @@ -1042,7 +1044,7 @@ impl EGraph {
ruleset: Symbol,
) -> Result<Symbol, Error> {
let name = Symbol::from(name);
let core_rule = rule.to_canonicalized_core_rule(&self.type_info)?;
let core_rule = rule.to_canonicalized_core_rule(&self.type_info, &mut self.symbol_gen)?;
let (query, actions) = (core_rule.body, core_rule.head);

let vars = query.get_vars();
Expand Down Expand Up @@ -1083,7 +1085,7 @@ impl EGraph {
let (actions, _) = actions.to_core_actions(
&self.type_info,
&mut Default::default(),
&mut ResolvedGen::new("$".to_string()),
&mut self.symbol_gen,
)?;
let program = self
.compile_actions(&Default::default(), &actions)
Expand All @@ -1094,7 +1096,7 @@ impl EGraph {
}

pub fn eval_expr(&mut self, expr: &Expr) -> Result<(ArcSort, Value), Error> {
let fresh_name = self.desugar.get_fresh();
let fresh_name = self.desugar.get_fresh(&mut self.symbol_gen);
let command = Command::Action(Action::Let(DUMMY_SPAN.clone(), fresh_name, expr.clone()));
self.run_program(vec![command])?;
// find the table with the same name as the fresh name
Expand All @@ -1110,7 +1112,7 @@ impl EGraph {
let (actions, mapped_expr) = expr.to_core_actions(
&self.type_info,
&mut Default::default(),
&mut ResolvedGen::new("$".to_string()),
&mut self.symbol_gen,
)?;
let target = mapped_expr.get_corresponding_var_or_lit(&self.type_info);
let program = self
Expand Down Expand Up @@ -1154,7 +1156,7 @@ impl EGraph {
head: ResolvedActions::default(),
body: facts.to_vec(),
};
let core_rule = rule.to_canonicalized_core_rule(&self.type_info)?;
let core_rule = rule.to_canonicalized_core_rule(&self.type_info, &mut self.symbol_gen)?;
let query = core_rule.body;
let ordering = &query.get_vars();
let query = self.compile_gj_query(query, ordering);
Expand Down Expand Up @@ -1382,7 +1384,9 @@ impl EGraph {
.into_iter()
.map(NCommand::CoreAction)
.collect::<Vec<_>>();
let commands: Vec<_> = self.type_info.typecheck_program(&commands)?;
let commands: Vec<_> = self
.type_info
.typecheck_program(&mut self.symbol_gen, &commands)?;
for command in commands {
self.run_command(command)?;
}
Expand All @@ -1398,20 +1402,22 @@ impl EGraph {

pub fn set_reserved_symbol(&mut self, sym: Symbol) {
assert!(
!self.desugar.fresh_gen.has_been_used(),
!self.symbol_gen.has_been_used(),
"Reserved symbol must be set before any symbols are generated"
);
self.desugar.fresh_gen = SymbolGen::new(sym.to_string());
self.symbol_gen = SymbolGen::new(sym.to_string());
}

fn process_command(&mut self, command: Command) -> Result<Vec<ResolvedNCommand>, Error> {
let program = self
.desugar
.desugar_program(vec![command], self.seminaive)?;
let program =
self.desugar
.desugar_program(vec![command], &mut self.symbol_gen, self.seminaive)?;

let program = self.type_info.typecheck_program(&program)?;
let program = self
.type_info
.typecheck_program(&mut self.symbol_gen, &program)?;

let program = remove_globals(&self.type_info, program, &mut self.desugar.fresh_gen);
let program = remove_globals(&self.type_info, program, &mut self.symbol_gen);

Ok(program)
}
Expand Down
2 changes: 1 addition & 1 deletion src/sort/fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ fn call_fn(egraph: &mut EGraph, name: &Symbol, types: Vec<ArcSort>, args: Vec<Va
.to_core_actions(
&egraph.type_info,
&mut binding.clone(),
&mut ResolvedGen::new("$".to_string()),
&mut egraph.symbol_gen,
)
.unwrap();
let target = mapped_expr.get_corresponding_var_or_lit(&egraph.type_info);
Expand Down
Loading

0 comments on commit 75e1466

Please sign in to comment.