diff --git a/src/ast/desugar.rs b/src/ast/desugar.rs index 2d1cbb49..3eb8ebca 100644 --- a/src/ast/desugar.rs +++ b/src/ast/desugar.rs @@ -44,17 +44,55 @@ pub(crate) fn desugar_command( name, variants, } => desugar_datatype(span, name, variants), - Command::Rewrite(ruleset, rewrite, subsume) => desugar_rewrite( - ruleset, - rewrite_name(&rewrite), - &rewrite, - subsume, - ), - Command::BiRewrite(ruleset, rewrite) => desugar_birewrite( - ruleset, - rewrite_name(&rewrite), - &rewrite, - ), + Command::Datatypes { span: _, datatypes } => { + let mut res = vec![]; + for datatype in datatypes.iter() { + let span = datatype.0.clone(); + let name = datatype.1; + if datatype.2.is_ok() { + res.push(NCommand::Sort(span, name, None)); + } + } + let (variants_vec, sorts): (Vec<_>, Vec<_>) = datatypes + .into_iter() + .partition(|datatype| datatype.2.is_ok()); + + for sort in sorts { + let span = sort.0.clone(); + let name = sort.1; + let constructor = sort.2.unwrap_err(); + res.push(NCommand::Sort(span, name, Some(constructor))); + } + + for variants in variants_vec { + let datatype = variants.1; + let variants = variants.2.unwrap(); + for variant in variants { + res.push(NCommand::Function(FunctionDecl { + name: variant.name, + schema: Schema { + input: variant.types, + output: datatype, + }, + merge: None, + merge_action: Actions::default(), + default: None, + cost: variant.cost, + unextractable: false, + ignore_viz: false, + span: variant.span, + })); + } + } + + res + } + Command::Rewrite(ruleset, rewrite, subsume) => { + desugar_rewrite(ruleset, rewrite_name(&rewrite).into(), &rewrite, subsume) + } + Command::BiRewrite(ruleset, rewrite) => { + desugar_birewrite(ruleset, rewrite_name(&rewrite).into(), &rewrite) + } Command::Include(span, file) => { let s = std::fs::read_to_string(&file) .unwrap_or_else(|_| panic!("{} Failed to read file {file}", span.get_quote()));