Skip to content

Commit

Permalink
Support recursive datatypes in egglog (#432)
Browse files Browse the repository at this point in the history
* datatype*

* Fix build due to bad merge

* fix datatype sexp

* use Subdatatypes than Result

* fix

* Update src/ast/desugar.rs

Co-authored-by: Saul Shanabrook <s.shanabrook@gmail.com>

---------

Co-authored-by: Saul Shanabrook <s.shanabrook@gmail.com>
  • Loading branch information
yihozhang and saulshanabrook authored Oct 11, 2024
1 parent 8bacebf commit bb97e1e
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 16 deletions.
69 changes: 58 additions & 11 deletions src/ast/desugar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,60 @@ pub(crate) fn desugar_command(
name,
variants,
} => desugar_datatype(span, name, variants),
Command::Rewrite(ruleset, rewrite, subsume) => desugar_rewrite(
ruleset,
rewrite.to_string().replace('\"', "'").into(),
&rewrite,
subsume,
),
Command::BiRewrite(ruleset, rewrite) => desugar_birewrite(
ruleset,
rewrite.to_string().replace('\"', "'").into(),
&rewrite,
),
Command::Datatypes { span: _, datatypes } => {
// first declare all the datatypes as sorts, then add all explicit sorts which could refer to the datatypes, and finally add all the variants as functions
let mut res = vec![];
for datatype in datatypes.iter() {
let span = datatype.0.clone();
let name = datatype.1;
if let Subdatatypes::Variants(..) = datatype.2 {
res.push(NCommand::Sort(span, name, None));
}
}
let (variants_vec, sorts): (Vec<_>, Vec<_>) = datatypes
.into_iter()
.partition(|datatype| matches!(datatype.2, Subdatatypes::Variants(..)));

for sort in sorts {
let span = sort.0.clone();
let name = sort.1;
let Subdatatypes::NewSort(sort, args) = sort.2 else {
unreachable!()
};
res.push(NCommand::Sort(span, name, Some((sort, args))));
}

for variants in variants_vec {
let datatype = variants.1;
let Subdatatypes::Variants(variants) = variants.2 else {
unreachable!();
};
for variant in variants {
res.push(NCommand::Function(FunctionDecl {
name: variant.name,
schema: Schema {
input: variant.types,
output: datatype,
},
merge: None,
merge_action: Actions::default(),
default: None,
cost: variant.cost,
unextractable: false,
ignore_viz: false,
span: variant.span,
}));
}
}

res
}
Command::Rewrite(ruleset, rewrite, subsume) => {
desugar_rewrite(ruleset, rewrite_name(&rewrite).into(), &rewrite, subsume)
}
Command::BiRewrite(ruleset, rewrite) => {
desugar_birewrite(ruleset, rewrite_name(&rewrite).into(), &rewrite)
}
Command::Include(span, file) => {
let s = std::fs::read_to_string(&file)
.unwrap_or_else(|_| panic!("{} Failed to read file {file}", span.get_quote()));
Expand Down Expand Up @@ -356,3 +399,7 @@ fn desugar_simplify(
res.push(NCommand::Pop(span, 1));
res
}

pub(crate) fn rewrite_name(rewrite: &Rewrite) -> String {
rewrite.to_string().replace('\"', "'")
}
28 changes: 23 additions & 5 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,12 @@ pub type Command = GenericCommand<Symbol, Symbol>;

pub type Subsume = bool;

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Subdatatypes {
Variants(Vec<Variant>),
NewSort(Symbol, Vec<Expr>),
}

/// A [`Command`] is the top-level construct in egglog.
/// It includes defining rules, declaring functions,
/// adding to tables, and running rules (via a [`Schedule`]).
Expand Down Expand Up @@ -500,6 +506,10 @@ where
name: Symbol,
variants: Vec<Variant>,
},
Datatypes {
span: Span,
datatypes: Vec<(Span, Symbol, Subdatatypes)>,
},
/// Create a new user-defined sort, which can then
/// be used in new [`Command::Function`] declarations.
/// The [`Command::Datatype`] command desugars directly to this command, with one [`Command::Function`]
Expand All @@ -515,11 +525,7 @@ where
/// ```
///
/// Now `MathVec` can be used as an input or output sort.
Sort(
Span,
Symbol,
Option<(Symbol, Vec<GenericExpr<Symbol, Symbol>>)>,
),
Sort(Span, Symbol, Option<(Symbol, Vec<Expr>)>),
/// Declare an egglog function, which is a database table with a
/// a functional dependency (also called a primary key) on its inputs to one output.
///
Expand Down Expand Up @@ -887,6 +893,18 @@ where
expr,
schedule,
} => list!("simplify", schedule, expr),
GenericCommand::Datatypes { span: _, datatypes } => {
let datatypes: Vec<_> = datatypes
.iter()
.map(|(_, name, variants)| match variants {
Subdatatypes::Variants(variants) => list!(name, ++ variants),
Subdatatypes::NewSort(head, args) => {
list!("sort", name, list!(head, ++ args))
}
})
.collect();
list!("datatype*", ++ datatypes)
}
}
}
}
Expand Down
8 changes: 8 additions & 0 deletions src/ast/parse.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,19 @@ Comma<T>: Vec<T> = {
}
};

RecDatatype: (Span, Symbol, Subdatatypes) = {
<lo:LParen> <name:Ident> <variants:(Variant)*> <hi:RParen> => (Span(srcfile.clone(), lo, hi), name, Subdatatypes::Variants(variants)),
<lo:LParen> "sort" <name:Ident> LParen <head:Ident> <exprs:(Expr)*> RParen <hi:RParen> => (Span(srcfile.clone(), lo, hi), name, Subdatatypes::NewSort(head, exprs)),
}

Command: Command = {
LParen "set-option" <name:Ident> <value:Expr> RParen => Command::SetOption { name, value },
<lo:LParen> "datatype" <name:Ident> <variants:(Variant)*> <hi:RParen> => Command::Datatype { span: Span(srcfile.clone(), lo, hi), name, variants },
<lo:LParen> "sort" <name:Ident> LParen <head:Ident> <tail:(Expr)*> RParen <hi:RParen> => Command::Sort (Span(srcfile.clone(), lo, hi), name, Some((head, tail))),
<lo:LParen> "sort" <name:Ident> <hi:RParen> => Command::Sort (Span(srcfile.clone(), lo, hi), name, None),
<lo:LParen> "datatype*"
<datatypes:RecDatatype*>
<hi:RParen> => Command::Datatypes { span: Span(srcfile.clone(), lo, hi), datatypes },
<lo:LParen> "function" <name:Ident> <schema:Schema> <cost:Cost>
<unextractable:(":unextractable")?>
<merge_action:(":on_merge" <List<Action>>)?>
Expand Down
11 changes: 11 additions & 0 deletions tests/datatypes.egg
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
(datatype*
(Math
(Add Math Math)
(Sum MathVec)
(B Bool))
(sort MathVec (Vec Math))
(Bool
(True)
(False)))

(let expr (Add (Sum (vec-of (B (True)) (B (False)))) (B (True))))

0 comments on commit bb97e1e

Please sign in to comment.