Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix rebuilding and extraction bugs for EqSort containers #191

Merged
merged 21 commits into from
Aug 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 45 additions & 46 deletions src/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@ use crate::termdag::{Term, TermDag};
use crate::util::HashMap;
use crate::{ArcSort, EGraph, Function, Id, Value};

type Cost = usize;
pub type Cost = usize;

#[derive(Debug)]
pub(crate) struct Node<'a> {
sym: Symbol,
inputs: &'a [Value],
}

pub(crate) struct Extractor<'a> {
costs: HashMap<Id, (Cost, Term)>,
pub struct Extractor<'a> {
pub costs: HashMap<Id, (Cost, Term)>,
ctors: Vec<Symbol>,
egraph: &'a EGraph,
}
Expand All @@ -31,7 +31,29 @@ impl EGraph {
}

pub fn extract(&self, value: Value, termdag: &mut TermDag, arcsort: &ArcSort) -> (Cost, Term) {
Extractor::new(self, termdag).find_best(value, termdag, arcsort)
let extractor = Extractor::new(self, termdag);
extractor
.find_best(value, termdag, arcsort)
.unwrap_or_else(|| {
log::error!("No cost for {:?}", value);
for func in self.functions.values() {
for (inputs, output) in func.nodes.iter() {
if output.value == value {
log::error!("Found unextractable function: {:?}", func.decl.name);
log::error!("Inputs: {:?}", inputs);
log::error!(
"{:?}",
inputs
.iter()
.map(|input| extractor.costs.get(&extractor.find(input)))
.collect::<Vec<_>>()
);
}
}
}

panic!("No cost for {:?}", value)
})
}

pub fn extract_variants(
Expand All @@ -57,7 +79,9 @@ impl EGraph {
.filter_map(|(inputs, output)| {
(&output.value == output_value).then(|| {
let node = Node { sym, inputs };
ext.expr_from_node(&node, termdag)
ext.expr_from_node(&node, termdag).expect(
"extract_variants should be called after extractor initialization",
)
})
})
.collect()
Expand Down Expand Up @@ -89,46 +113,29 @@ impl<'a> Extractor<'a> {
extractor
}

fn expr_from_node(&self, node: &Node, termdag: &mut TermDag) -> Term {
fn expr_from_node(&self, node: &Node, termdag: &mut TermDag) -> Option<Term> {
let mut children = vec![];
for value in node.inputs {
let arcsort = self.egraph.get_sort(value).unwrap();
children.push(self.find_best(*value, termdag, arcsort).1)
children.push(self.find_best(*value, termdag, arcsort)?.1)
}

termdag.make(node.sym, children)
Some(termdag.make(node.sym, children))
}

pub fn find_best(&self, value: Value, termdag: &mut TermDag, sort: &ArcSort) -> (Cost, Term) {
pub fn find_best(
&self,
value: Value,
termdag: &mut TermDag,
sort: &ArcSort,
) -> Option<(Cost, Term)> {
if sort.is_eq_sort() {
let id = self.find(&value);
let (cost, node) = self
.costs
.get(&id)
.unwrap_or_else(|| {
log::error!("No cost for {:?}", value);
for func in self.egraph.functions.values() {
for (inputs, output) in func.nodes.iter() {
if output.value == value {
log::error!("Found unextractable function: {:?}", func.decl.name);
log::error!("Inputs: {:?}", inputs);
log::error!(
"{:?}",
inputs
.iter()
.map(|input| self.costs.get(&self.find(input)))
.collect::<Vec<_>>()
);
}
}
}

panic!("No cost for {:?}", value)
})
.clone();
(cost, node)
let (cost, node) = self.costs.get(&id)?.clone();
Some((cost, node))
} else {
(0, termdag.expr_to_term(&sort.make_expr(self.egraph, value)))
let (cost, node) = sort.extract_expr(self.egraph, value, self, termdag)?;
Some((cost, termdag.expr_to_term(&node)))
}
}

Expand All @@ -142,17 +149,9 @@ impl<'a> Extractor<'a> {
let types = &function.schema.input;
let mut terms: Vec<Term> = vec![];
for (ty, value) in types.iter().zip(children) {
cost = cost.saturating_add(if ty.is_eq_sort() {
let id = self.egraph.find(Id::from(value.bits as usize));
// TODO costs should probably map values?
let (cost, term) = self.costs.get(&id)?;
terms.push(term.clone());
*cost
} else {
let term = termdag.expr_to_term(&ty.make_expr(self.egraph, *value));
terms.push(term);
1
});
let (term_cost, term) = self.find_best(*value, termdag, ty)?;
terms.push(term.clone());
cost = cost.saturating_add(term_cost);
}
Some((terms, cost))
}
Expand Down
9 changes: 8 additions & 1 deletion src/function/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,14 @@ impl Function {
) -> Result<(usize, Vec<DeferredMerge>), Error> {
// Make sure indexes are up to date.
self.update_indexes(self.nodes.num_offsets());
if self.schema.input.iter().all(|s| !s.is_eq_sort()) && !self.schema.output.is_eq_sort() {
if self
.schema
.input
.iter()
.all(|s| !s.is_eq_sort() && !s.is_eq_container_sort())
&& !self.schema.output.is_eq_sort()
&& !self.schema.output.is_eq_container_sort()
{
return Ok((std::mem::take(&mut self.updates), Default::default()));
}
let mut deferred_merges = Vec::new();
Expand Down
16 changes: 9 additions & 7 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -369,10 +369,11 @@ impl EGraph {
}

// now update global bindings
let mut new_global_bindings = self.global_bindings.clone();
for (_sym, (_sort, value, ts)) in new_global_bindings.iter_mut() {
*value = self.bad_find_value(*value);
*ts = self.timestamp;
let mut new_global_bindings = std::mem::take(&mut self.global_bindings);
for (_sym, (sort, value, ts)) in new_global_bindings.iter_mut() {
if sort.canonicalize(value, &self.unionfind) {
*ts = self.timestamp;
}
}
self.global_bindings = new_global_bindings;

Expand Down Expand Up @@ -494,18 +495,19 @@ impl EGraph {
let mut children = Vec::new();
for (a, a_type) in ins.iter().copied().zip(&schema.input) {
if a_type.is_eq_sort() {
children.push(extractor.find_best(a, &mut termdag, a_type).1);
children.push(extractor.find_best(a, &mut termdag, a_type).unwrap().1);
} else {
children.push(termdag.expr_to_term(&a_type.make_expr(self, a)));
children.push(termdag.expr_to_term(&a_type.make_expr(self, a).1));
};
}

let out = if schema.output.is_eq_sort() {
extractor
.find_best(out.value, &mut termdag, &schema.output)
.unwrap()
.1
} else {
termdag.expr_to_term(&schema.output.make_expr(self, out.value))
termdag.expr_to_term(&schema.output.make_expr(self, out.value).1)
};
terms.push((termdag.make(sym, children), out));
}
Expand Down
2 changes: 1 addition & 1 deletion src/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ impl EGraph {
println!("{} is a container sort", sort.name());
sort.name().to_string()
} else {
sort.make_expr(self, *value).to_string()
sort.make_expr(self, *value).1.to_string()
};
egraph.nodes.insert(
node_id.clone(),
Expand Down
7 changes: 5 additions & 2 deletions src/sort/f64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,12 @@ impl Sort for F64Sort {

}

fn make_expr(&self, _egraph: &EGraph, value: Value) -> Expr {
fn make_expr(&self, _egraph: &EGraph, value: Value) -> (Cost, Expr) {
assert!(value.tag == self.name());
Expr::Lit(Literal::F64(OrderedFloat(f64::from_bits(value.bits))))
(
1,
Expr::Lit(Literal::F64(OrderedFloat(f64::from_bits(value.bits)))),
)
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/sort/i64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ impl Sort for I64Sort {

}

fn make_expr(&self, _egraph: &EGraph, value: Value) -> Expr {
fn make_expr(&self, _egraph: &EGraph, value: Value) -> (Cost, Expr) {
assert!(value.tag == self.name());
Expr::Lit(Literal::Int(value.bits as _))
(1, Expr::Lit(Literal::Int(value.bits as _)))
}
}

Expand Down
26 changes: 20 additions & 6 deletions src/sort/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,19 +119,33 @@ impl Sort for MapSort {
});
}

fn make_expr(&self, egraph: &EGraph, value: Value) -> Expr {
fn make_expr(&self, egraph: &EGraph, value: Value) -> (Cost, Expr) {
let mut termdag = TermDag::default();
let extractor = Extractor::new(egraph, &mut termdag);
self.extract_expr(egraph, value, &extractor, &mut termdag)
.expect("Extraction should be successful since extractor has been fully initialized")
}

fn extract_expr(
&self,
_egraph: &EGraph,
value: Value,
extractor: &Extractor,
termdag: &mut TermDag,
) -> Option<(Cost, Expr)> {
let map = ValueMap::load(self, &value);
let mut expr = Expr::call("map-empty", []);
let mut termdag = TermDag::default();
let mut cost = 0usize;
for (k, v) in map.iter().rev() {
let k = egraph.extract(*k, &mut termdag, &self.key).1;
let v = egraph.extract(*v, &mut termdag, &self.value).1;
let k = extractor.find_best(*k, termdag, &self.key)?;
let v = extractor.find_best(*v, termdag, &self.value)?;
cost = cost.saturating_add(k.0).saturating_add(v.0);
expr = Expr::call(
"map-insert",
[expr, termdag.term_to_expr(&k), termdag.term_to_expr(&v)],
[expr, termdag.term_to_expr(&k.1), termdag.term_to_expr(&v.1)],
)
}
expr
Some((cost, expr))
}
}

Expand Down
21 changes: 19 additions & 2 deletions src/sort/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub use set::*;
mod vec;
pub use vec::*;

use crate::extract::{Cost, Extractor};
use crate::*;

pub trait Sort: Any + Send + Sync + Debug {
Expand Down Expand Up @@ -69,7 +70,23 @@ pub trait Sort: Any + Send + Sync + Debug {
let _ = info;
}

fn make_expr(&self, egraph: &EGraph, value: Value) -> Expr;
/// Extracting an expression (with smallest cost) out of a primitive value
fn make_expr(&self, egraph: &EGraph, value: Value) -> (Cost, Expr);

/// For values like EqSort containers, to make/extract an expression from it
/// requires an extractor. Moreover, the extraction may be unsuccessful if
/// the extractor is not fully initialized.
///
/// The default behavior is to call make_expr
fn extract_expr(
&self,
egraph: &EGraph,
value: Value,
_extractor: &Extractor,
_termdag: &mut TermDag,
) -> Option<(Cost, Expr)> {
Some(self.make_expr(egraph, value))
}
}

#[derive(Debug)]
Expand Down Expand Up @@ -101,7 +118,7 @@ impl Sort for EqSort {
}
}

fn make_expr(&self, _egraph: &EGraph, _value: Value) -> Expr {
fn make_expr(&self, _egraph: &EGraph, _value: Value) -> (Cost, Expr) {
unimplemented!("No make_expr for EqSort {}", self.name)
}
}
Expand Down
17 changes: 10 additions & 7 deletions src/sort/rational.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,17 +106,20 @@ impl Sort for RationalSort {
add_primitives!(eg, "<=" = |a: R, b: R| -> Opt { if a <= b {Some(())} else {None} });
add_primitives!(eg, ">=" = |a: R, b: R| -> Opt { if a >= b {Some(())} else {None} });
}
fn make_expr(&self, _egraph: &EGraph, value: Value) -> Expr {
fn make_expr(&self, _egraph: &EGraph, value: Value) -> (Cost, Expr) {
assert!(value.tag == self.name());
let rat = R::load(self, &value);
let numer = *rat.numer();
let denom = *rat.denom();
Expr::call(
"rational",
vec![
Expr::Lit(Literal::Int(numer)),
Expr::Lit(Literal::Int(denom)),
],
(
1,
Expr::call(
"rational",
vec![
Expr::Lit(Literal::Int(numer)),
Expr::Lit(Literal::Int(denom)),
],
),
)
}
}
Expand Down
24 changes: 19 additions & 5 deletions src/sort/set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,29 @@ impl Sort for SetSort {
});
}

fn make_expr(&self, egraph: &EGraph, value: Value) -> Expr {
fn make_expr(&self, egraph: &EGraph, value: Value) -> (Cost, Expr) {
let mut termdag = TermDag::default();
let extractor = Extractor::new(egraph, &mut termdag);
self.extract_expr(egraph, value, &extractor, &mut termdag)
.expect("Extraction should be successful since extractor has been fully initialized")
}

fn extract_expr(
&self,
_egraph: &EGraph,
value: Value,
extractor: &Extractor,
termdag: &mut TermDag,
) -> Option<(Cost, Expr)> {
let set = ValueSet::load(self, &value);
let mut expr = Expr::call("set-empty", []);
let mut termdag = TermDag::default();
let mut cost = 0usize;
for e in set.iter().rev() {
let e = egraph.extract(*e, &mut termdag, &self.element).1;
expr = Expr::call("set-insert", [expr, termdag.term_to_expr(&e)])
let e = extractor.find_best(*e, termdag, &self.element)?;
cost = cost.saturating_add(e.0);
expr = Expr::call("set-insert", [expr, termdag.term_to_expr(&e.1)])
}
expr
Some((cost, expr))
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/sort/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ impl Sort for StringSort {
self
}

fn make_expr(&self, _egraph: &EGraph, value: Value) -> Expr {
fn make_expr(&self, _egraph: &EGraph, value: Value) -> (Cost, Expr) {
assert!(value.tag == self.name);
let sym = Symbol::from(NonZeroU32::new(value.bits as _).unwrap());
Expr::Lit(Literal::String(sym))
(1, Expr::Lit(Literal::String(sym)))
}

fn register_primitives(self: Arc<Self>, typeinfo: &mut TypeInfo) {
Expand Down
Loading