Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type-safe and static higher-order functions #452

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions src/actions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,8 @@ impl<'a> ActionCompiler<'a> {
}

fn do_prim(&mut self, prim: &SpecializedPrimitive) {
self.instructions.push(Instruction::CallPrimitive(
prim.primitive.clone(),
prim.input.len(),
));
self.instructions
.push(Instruction::CallPrimitive(prim.clone(), prim.input.len()));
}
}

Expand All @@ -126,7 +124,7 @@ enum Instruction {
CallFunction(Symbol, bool),
/// Pop primitive arguments off the stack, calls the primitive,
/// and push the result onto the stack.
CallPrimitive(Primitive, usize),
CallPrimitive(SpecializedPrimitive, usize),
/// Pop function arguments off the stack and either deletes or subsumes the corresponding row
/// in the function.
Change(Change, Symbol),
Expand Down Expand Up @@ -321,11 +319,13 @@ impl EGraph {
Instruction::CallPrimitive(p, arity) => {
let new_len = stack.len() - arity;
let values = &stack[new_len..];
if let Some(value) = p.apply(values, Some(self)) {
if let Some(value) =
p.primitive.apply(values, (&p.input, &p.output), Some(self))
{
stack.truncate(new_len);
stack.push(value);
} else {
return Err(Error::PrimitiveError(p.clone(), values.to_vec()));
return Err(Error::PrimitiveError(p.primitive.clone(), values.to_vec()));
}
}
Instruction::Set(f) => {
Expand Down
2 changes: 1 addition & 1 deletion src/ast/parse.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ F64: OrderedFloat<f64> = {
"inf" => OrderedFloat::<f64>(f64::INFINITY),
"-inf" => OrderedFloat::<f64>(f64::NEG_INFINITY),
}
Ident: Symbol = <s:r"(([[:alpha:]][\w-]*)|([-+*/?!=<>&|^/%_]))+"> => s.parse().unwrap();
Ident: Symbol = <s:r"(([[:alpha:]][\w-]*)|([-+*/?!=<>&|^/%_#]))+"> => s.parse().unwrap();
SymString: Symbol = <String> => Symbol::from(<>);

String: String = <r#"("[^"]*")+"#> => {
Expand Down
141 changes: 95 additions & 46 deletions src/constraint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,40 @@ pub enum ImpossibleConstraint {
actual_output: ArcSort,
actual_input: Vec<ArcSort>,
},
CompileTimeConstantExpected {
span: Span,
sort: ArcSort,
},
UnboundedFunction {
head: Symbol,
span: Span,
},
}

#[derive(Debug)]
pub enum Constraint<Var, Value> {
pub enum Constraint<'a, Var, Value> {
Eq(Var, Var),
Assign(Var, Value),
And(Vec<Constraint<Var, Value>>),
And(Vec<Constraint<'a, Var, Value>>),
// Exactly one of the constraints holds
// and all others are false
Xor(Vec<Constraint<Var, Value>>),
Xor(Vec<Constraint<'a, Var, Value>>),
LazyConstraint(Var, Box<dyn Fn(&Value) -> Self + 'a>),
Impossible(ImpossibleConstraint),
}

impl<'a, Var: Debug, Value: Debug> Debug for Constraint<'a, Var, Value> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Constraint::Eq(x, y) => write!(f, "{:?} = {:?}", x, y),
Constraint::Assign(x, v) => write!(f, "{:?} = {:?}", x, v),
Constraint::And(cs) => write!(f, "And({:?})", cs),
Constraint::Xor(cs) => write!(f, "Xor({:?})", cs),
Constraint::LazyConstraint(x, _) => write!(f, "LazyConstraint({:?}, trigger=...)", x),
Constraint::Impossible(c) => write!(f, "Impossible({:?})", c),
}
}
}

pub enum ConstraintError<Var, Value> {
InconsistentConstraint(Var, Value, Value),
UnconstrainedVar(Var),
Expand Down Expand Up @@ -84,11 +105,17 @@ impl ConstraintError<AtomTerm, ArcSort> {
actual_output.clone(),
actual_input.clone(),
),
ConstraintError::ImpossibleCaseIdentified(
ImpossibleConstraint::CompileTimeConstantExpected { span, sort },
) => TypeError::CompileTimeConstantExpected(sort.clone(), span.clone()),
ConstraintError::ImpossibleCaseIdentified(
ImpossibleConstraint::UnboundedFunction { head, span },
) => TypeError::UnboundFunction(*head, span.clone()),
}
}
}

impl<Var, Value> Constraint<Var, Value>
impl<'a, Var, Value> Constraint<'a, Var, Value>
where
Var: Eq + PartialEq + Hash + Clone + Debug,
Value: Clone + Debug,
Expand All @@ -97,11 +124,12 @@ where
/// If there's a conflict, returns the conflicting variable, the assigned conflicting types.
/// Otherwise, return whether the assignment is updated.
fn update<K: Eq>(
&self,
&mut self,
assignment: &mut Assignment<Var, Value>,
key: impl Fn(&Value) -> K + Copy,
) -> Result<bool, ConstraintError<Var, Value>> {
match self {
let mut new_self = None;
let result = match self {
Constraint::Eq(x, y) => match (assignment.0.get(x), assignment.0.get(y)) {
(Some(value), None) => {
assignment.insert(y.clone(), value.clone());
Expand Down Expand Up @@ -198,17 +226,33 @@ where
}
Ok(updated)
}
Constraint::LazyConstraint(var, trigger) => {
if assignment.0.contains_key(var) {
//
let value = assignment.0.get(var).unwrap();
let mut constraint = trigger(value);
constraint.update(assignment, key)?;
new_self = Some(constraint);
Ok(true)
} else {
Ok(false)
}
}
};
if let Some(new_self) = new_self {
*self = new_self;
}
result
}
}

#[derive(Debug)]
pub struct Problem<Var, Value> {
pub constraints: Vec<Constraint<Var, Value>>,
pub struct Problem<'a, Var, Value> {
pub constraints: Vec<Constraint<'a, Var, Value>>,
pub range: HashSet<Var>,
}

impl Default for Problem<AtomTerm, ArcSort> {
impl<'a> Default for Problem<'a, AtomTerm, ArcSort> {
fn default() -> Self {
Self {
constraints: vec![],
Expand Down Expand Up @@ -422,20 +466,20 @@ impl Assignment<AtomTerm, ArcSort> {
}
}

impl<Var, Value> Problem<Var, Value>
impl<'a, Var, Value> Problem<'a, Var, Value>
where
Var: Eq + PartialEq + Hash + Clone + Debug,
Value: Clone + Debug,
{
pub(crate) fn solve<K: Eq + Debug>(
&self,
mut self,
key: impl Fn(&Value) -> K + Copy,
) -> Result<Assignment<Var, Value>, ConstraintError<Var, Value>> {
let mut assignment = Assignment(HashMap::default());
let mut changed = true;
while changed {
changed = false;
for constraint in self.constraints.iter() {
for constraint in self.constraints.iter_mut() {
changed |= constraint.update(&mut assignment, key)?;
}
}
Expand All @@ -453,11 +497,11 @@ where
}
}

impl Problem<AtomTerm, ArcSort> {
impl<'a> Problem<'a, AtomTerm, ArcSort> {
pub(crate) fn add_query(
&mut self,
query: &Query<SymbolOrEq, Symbol>,
typeinfo: &TypeInfo,
typeinfo: &'a TypeInfo,
) -> Result<(), TypeError> {
self.constraints.extend(query.get_constraints(typeinfo)?);
self.range.extend(query.atom_terms());
Expand All @@ -467,7 +511,7 @@ impl Problem<AtomTerm, ArcSort> {
pub fn add_actions(
&mut self,
actions: &GenericCoreActions<Symbol, Symbol>,
typeinfo: &TypeInfo,
typeinfo: &'a TypeInfo,
symbol_gen: &mut SymbolGen,
) -> Result<(), TypeError> {
for action in actions.0.iter() {
Expand All @@ -491,7 +535,7 @@ impl Problem<AtomTerm, ArcSort> {
pub(crate) fn add_rule(
&mut self,
rule: &CoreRule,
typeinfo: &TypeInfo,
typeinfo: &'a TypeInfo,
symbol_gen: &mut SymbolGen,
) -> Result<(), TypeError> {
let CoreRule {
Expand All @@ -517,18 +561,18 @@ impl Problem<AtomTerm, ArcSort> {
}

impl CoreAction {
pub(crate) fn get_constraints(
pub(crate) fn get_constraints<'a>(
&self,
typeinfo: &TypeInfo,
typeinfo: &'a TypeInfo,
symbol_gen: &mut SymbolGen,
) -> Result<Vec<Constraint<AtomTerm, ArcSort>>, TypeError> {
) -> Result<Vec<Constraint<'a, AtomTerm, ArcSort>>, TypeError> {
match self {
CoreAction::Let(span, symbol, f, args) => {
let mut args = args.clone();
args.push(AtomTerm::Var(span.clone(), *symbol));

Ok(get_literal_and_global_constraints(&args, typeinfo)
.chain(get_atom_application_constraints(f, &args, span, typeinfo)?)
.chain(get_atom_application_constraints(f, &args, span, typeinfo))
.collect())
}
CoreAction::Set(span, head, args, rhs) => {
Expand All @@ -538,7 +582,7 @@ impl CoreAction {
Ok(get_literal_and_global_constraints(&args, typeinfo)
.chain(get_atom_application_constraints(
head, &args, span, typeinfo,
)?)
))
.collect())
}
CoreAction::Change(span, _change, head, args) => {
Expand All @@ -550,7 +594,7 @@ impl CoreAction {
Ok(get_literal_and_global_constraints(&args, typeinfo)
.chain(get_atom_application_constraints(
head, &args, span, typeinfo,
)?)
))
.collect())
}
CoreAction::Union(_ann, lhs, rhs) => Ok(get_literal_and_global_constraints(
Expand Down Expand Up @@ -584,10 +628,10 @@ impl CoreAction {
}

impl Atom<SymbolOrEq> {
pub fn get_constraints(
pub fn get_constraints<'a>(
&self,
type_info: &TypeInfo,
) -> Result<Vec<Constraint<AtomTerm, ArcSort>>, TypeError> {
type_info: &'a TypeInfo,
) -> Result<Vec<Constraint<'a, AtomTerm, ArcSort>>, TypeError> {
let literal_constraints = get_literal_and_global_constraints(&self.args, type_info);
match &self.head {
SymbolOrEq::Eq => {
Expand All @@ -603,25 +647,25 @@ impl Atom<SymbolOrEq> {
SymbolOrEq::Symbol(head) => Ok(literal_constraints
.chain(get_atom_application_constraints(
head, &self.args, &self.span, type_info,
)?)
))
.collect()),
}
}
}

fn get_atom_application_constraints(
pub(crate) fn get_atom_application_constraints<'a>(
head: &Symbol,
args: &[AtomTerm],
span: &Span,
type_info: &TypeInfo,
) -> Result<Vec<Constraint<AtomTerm, ArcSort>>, TypeError> {
type_info: &'a TypeInfo,
) -> Vec<Constraint<'a, AtomTerm, ArcSort>> {
// An atom can have potentially different semantics due to polymorphism
// e.g. (set-empty) can mean any empty set with some element type.
// To handle this, we collect each possible instantiations of an atom
// (where each instantiation is a vec of constraints, thus vec of vec)
// into `xor_constraints`.
// `Constraint::Xor` means one and only one of the instantiation can hold.
let mut xor_constraints: Vec<Vec<Constraint<AtomTerm, ArcSort>>> = vec![];
let mut xor_constraints: Vec<Vec<Constraint<'a, AtomTerm, ArcSort>>> = vec![];

// function atom constraints
if let Some(typ) = type_info.func_types.get(head) {
Expand Down Expand Up @@ -664,18 +708,23 @@ fn get_atom_application_constraints(
// do literal and global variable constraints first
// as they are the most "informative"
match xor_constraints.len() {
0 => Err(TypeError::UnboundFunction(*head, span.clone())),
1 => Ok(xor_constraints.pop().unwrap()),
_ => Ok(vec![Constraint::Xor(
0 => vec![Constraint::Impossible(
ImpossibleConstraint::UnboundedFunction {
head: *head,
span: span.clone(),
},
)],
1 => xor_constraints.pop().unwrap(),
_ => vec![Constraint::Xor(
xor_constraints.into_iter().map(Constraint::And).collect(),
)]),
)],
}
}

fn get_literal_and_global_constraints<'a>(
fn get_literal_and_global_constraints<'a, 'b>(
args: &'a [AtomTerm],
type_info: &'a TypeInfo,
) -> impl Iterator<Item = Constraint<AtomTerm, ArcSort>> + 'a {
) -> impl Iterator<Item = Constraint<'b, AtomTerm, ArcSort>> + 'a {
args.iter().filter_map(|arg| {
match arg {
AtomTerm::Var(_, _) => None,
Expand All @@ -696,11 +745,11 @@ fn get_literal_and_global_constraints<'a>(
}

pub trait TypeConstraint {
fn get(
fn get<'a>(
&self,
arguments: &[AtomTerm],
typeinfo: &TypeInfo,
) -> Vec<Constraint<AtomTerm, ArcSort>>;
typeinfo: &'a TypeInfo,
) -> Vec<Constraint<'a, AtomTerm, ArcSort>>;
}

/// Construct a set of `Assign` constraints that fully constrain the type of arguments
Expand All @@ -721,11 +770,11 @@ impl SimpleTypeConstraint {
}

impl TypeConstraint for SimpleTypeConstraint {
fn get(
fn get<'a>(
&self,
arguments: &[AtomTerm],
_typeinfo: &TypeInfo,
) -> Vec<Constraint<AtomTerm, ArcSort>> {
_typeinfo: &'a TypeInfo,
) -> Vec<Constraint<'a, AtomTerm, ArcSort>> {
if arguments.len() != self.sorts.len() {
vec![Constraint::Impossible(
ImpossibleConstraint::ArityMismatch {
Expand Down Expand Up @@ -796,11 +845,11 @@ impl AllEqualTypeConstraint {
}

impl TypeConstraint for AllEqualTypeConstraint {
fn get(
fn get<'a>(
&self,
mut arguments: &[AtomTerm],
_typeinfo: &TypeInfo,
) -> Vec<Constraint<AtomTerm, ArcSort>> {
_typeinfo: &'a TypeInfo,
) -> Vec<Constraint<'a, AtomTerm, ArcSort>> {
if arguments.is_empty() {
panic!("all arguments should have length > 0")
}
Expand Down
Loading
Loading