Skip to content

Commit

Permalink
Improve EGraph::find signature
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex-Fischman committed Oct 25, 2024
1 parent ba5ca66 commit ef44dea
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 14 deletions.
12 changes: 6 additions & 6 deletions src/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ impl EGraph {
inputs
.iter()
.zip(&func.schema.input)
.map(|(input, sort)| extractor.costs.get(
&extractor.egraph.find(sort.is_eq_sort(), *input).bits
))
.map(|(input, sort)| extractor
.costs
.get(&extractor.egraph.find(sort, *input).bits))
.collect::<Vec<_>>()
);
}
Expand All @@ -80,7 +80,7 @@ impl EGraph {
termdag: &mut TermDag,
) -> Vec<Term> {
let output_sort = sort.name();
let output_value = self.find(sort.is_eq_sort(), value);
let output_value = self.find(sort, value);
let ext = &Extractor::new(self, termdag);
ext.ctors
.iter()
Expand Down Expand Up @@ -152,7 +152,7 @@ impl<'a> Extractor<'a> {
sort: &ArcSort,
) -> Option<(Cost, Term)> {
if sort.is_eq_sort() {
let id = self.egraph.find(true, value).bits;
let id = self.egraph.find(sort, value).bits;
let (cost, node) = self.costs.get(&id)?.clone();
Some((cost, node))
} else {
Expand Down Expand Up @@ -192,7 +192,7 @@ impl<'a> Extractor<'a> {
{
let make_new_pair = || (new_cost, termdag.app(sym, term_inputs));

let id = self.egraph.find(true, output.value).bits;
let id = self.egraph.find(&func.schema.output, output.value).bits;
match self.costs.entry(id) {
Entry::Vacant(e) => {
did_something = true;
Expand Down
8 changes: 4 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -543,14 +543,14 @@ impl EGraph {
for (input, sort) in inputs.iter().zip(&function.schema.input) {
assert_eq!(
input,
&self.find(sort.is_eq_sort(), *input),
&self.find(sort, *input),
"[{i}] {name}({inputs:?}) = {output:?}\n{:?}",
function.schema,
)
}
assert_eq!(
output.value,
self.find(function.schema.output.is_eq_sort(), output.value),
self.find(&function.schema.output, output.value),
"[{i}] {name}({inputs:?}) = {output:?}\n{:?}",
function.schema,
)
Expand Down Expand Up @@ -591,8 +591,8 @@ impl EGraph {
}

/// find the leader value for a particular eclass
pub fn find(&self, is_eq_sort: bool, value: Value) -> Value {
if is_eq_sort {
pub fn find(&self, sort: &ArcSort, value: Value) -> Value {
if sort.is_eq_sort() {
Value {
#[cfg(debug_assertions)]
tag: value.tag,
Expand Down
4 changes: 2 additions & 2 deletions src/sort/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,8 @@ impl PrimitiveLike for MapRebuild {
.iter()
.map(|(k, v)| {
(
egraph.find(self.map.key.is_eq_sort(), *k),
egraph.find(self.map.value.is_eq_sort(), *v),
egraph.find(&self.map.key, *k),
egraph.find(&self.map.value, *v),
)
})
.collect();
Expand Down
2 changes: 1 addition & 1 deletion src/sort/set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ impl PrimitiveLike for SetRebuild {
let set = ValueSet::load(&self.set, &values[0]);
let new_set: ValueSet = set
.iter()
.map(|e| egraph.find(self.set.element.is_eq_sort(), *e))
.map(|e| egraph.find(&self.set.element, *e))
.collect();
// drop set to make sure we lose lock
drop(set);
Expand Down
2 changes: 1 addition & 1 deletion src/sort/vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ impl PrimitiveLike for VecRebuild {
let vec = ValueVec::load(&self.vec, &values[0]);
let new_vec: ValueVec = vec
.iter()
.map(|e| egraph.find(self.vec.element.is_eq_sort(), *e))
.map(|e| egraph.find(&self.vec.element, *e))
.collect();
drop(vec);
Some(new_vec.store(&self.vec).unwrap())
Expand Down

0 comments on commit ef44dea

Please sign in to comment.