diff --git a/src/ast/desugar.rs b/src/ast/desugar.rs index 568068ba..18e21c9f 100644 --- a/src/ast/desugar.rs +++ b/src/ast/desugar.rs @@ -44,17 +44,60 @@ pub(crate) fn desugar_command( name, variants, } => desugar_datatype(span, name, variants), - Command::Rewrite(ruleset, rewrite, subsume) => desugar_rewrite( - ruleset, - rewrite.to_string().replace('\"', "'").into(), - &rewrite, - subsume, - ), - Command::BiRewrite(ruleset, rewrite) => desugar_birewrite( - ruleset, - rewrite.to_string().replace('\"', "'").into(), - &rewrite, - ), + Command::Datatypes { span: _, datatypes } => { + // first declare all the datatypes as sorts, then add all explicit sorts which could refer to the datatypes, and finally add all the variants as functions + let mut res = vec![]; + for datatype in datatypes.iter() { + let span = datatype.0.clone(); + let name = datatype.1; + if let Subdatatypes::Variants(..) = datatype.2 { + res.push(NCommand::Sort(span, name, None)); + } + } + let (variants_vec, sorts): (Vec<_>, Vec<_>) = datatypes + .into_iter() + .partition(|datatype| matches!(datatype.2, Subdatatypes::Variants(..))); + + for sort in sorts { + let span = sort.0.clone(); + let name = sort.1; + let Subdatatypes::NewSort(sort, args) = sort.2 else { + unreachable!() + }; + res.push(NCommand::Sort(span, name, Some((sort, args)))); + } + + for variants in variants_vec { + let datatype = variants.1; + let Subdatatypes::Variants(variants) = variants.2 else { + unreachable!(); + }; + for variant in variants { + res.push(NCommand::Function(FunctionDecl { + name: variant.name, + schema: Schema { + input: variant.types, + output: datatype, + }, + merge: None, + merge_action: Actions::default(), + default: None, + cost: variant.cost, + unextractable: false, + ignore_viz: false, + span: variant.span, + })); + } + } + + res + } + Command::Rewrite(ruleset, rewrite, subsume) => { + desugar_rewrite(ruleset, rewrite_name(&rewrite).into(), &rewrite, subsume) + } + Command::BiRewrite(ruleset, rewrite) => { + desugar_birewrite(ruleset, rewrite_name(&rewrite).into(), &rewrite) + } Command::Include(span, file) => { let s = std::fs::read_to_string(&file) .unwrap_or_else(|_| panic!("{} Failed to read file {file}", span.get_quote())); @@ -356,3 +399,7 @@ fn desugar_simplify( res.push(NCommand::Pop(span, 1)); res } + +pub(crate) fn rewrite_name(rewrite: &Rewrite) -> String { + rewrite.to_string().replace('\"', "'") +} diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 948f2d2e..a4556fab 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -450,6 +450,12 @@ pub type Command = GenericCommand; pub type Subsume = bool; +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Subdatatypes { + Variants(Vec), + NewSort(Symbol, Vec), +} + /// A [`Command`] is the top-level construct in egglog. /// It includes defining rules, declaring functions, /// adding to tables, and running rules (via a [`Schedule`]). @@ -500,6 +506,10 @@ where name: Symbol, variants: Vec, }, + Datatypes { + span: Span, + datatypes: Vec<(Span, Symbol, Subdatatypes)>, + }, /// Create a new user-defined sort, which can then /// be used in new [`Command::Function`] declarations. /// The [`Command::Datatype`] command desugars directly to this command, with one [`Command::Function`] @@ -515,11 +525,7 @@ where /// ``` /// /// Now `MathVec` can be used as an input or output sort. - Sort( - Span, - Symbol, - Option<(Symbol, Vec>)>, - ), + Sort(Span, Symbol, Option<(Symbol, Vec)>), /// Declare an egglog function, which is a database table with a /// a functional dependency (also called a primary key) on its inputs to one output. /// @@ -887,6 +893,18 @@ where expr, schedule, } => list!("simplify", schedule, expr), + GenericCommand::Datatypes { span: _, datatypes } => { + let datatypes: Vec<_> = datatypes + .iter() + .map(|(_, name, variants)| match variants { + Subdatatypes::Variants(variants) => list!(name, ++ variants), + Subdatatypes::NewSort(head, args) => { + list!("sort", name, list!(head, ++ args)) + } + }) + .collect(); + list!("datatype*", ++ datatypes) + } } } } diff --git a/src/ast/parse.lalrpop b/src/ast/parse.lalrpop index ab0479db..59ed4d35 100644 --- a/src/ast/parse.lalrpop +++ b/src/ast/parse.lalrpop @@ -44,11 +44,19 @@ Comma: Vec = { } }; +RecDatatype: (Span, Symbol, Subdatatypes) = { + => (Span(srcfile.clone(), lo, hi), name, Subdatatypes::Variants(variants)), + "sort" LParen RParen => (Span(srcfile.clone(), lo, hi), name, Subdatatypes::NewSort(head, exprs)), +} + Command: Command = { LParen "set-option" RParen => Command::SetOption { name, value }, "datatype" => Command::Datatype { span: Span(srcfile.clone(), lo, hi), name, variants }, "sort" LParen RParen => Command::Sort (Span(srcfile.clone(), lo, hi), name, Some((head, tail))), "sort" => Command::Sort (Span(srcfile.clone(), lo, hi), name, None), + "datatype*" + + => Command::Datatypes { span: Span(srcfile.clone(), lo, hi), datatypes }, "function" >)?> diff --git a/tests/datatypes.egg b/tests/datatypes.egg new file mode 100644 index 00000000..9f5cfa86 --- /dev/null +++ b/tests/datatypes.egg @@ -0,0 +1,11 @@ +(datatype* + (Math + (Add Math Math) + (Sum MathVec) + (B Bool)) + (sort MathVec (Vec Math)) + (Bool + (True) + (False))) + +(let expr (Add (Sum (vec-of (B (True)) (B (False)))) (B (True))))