From ef44deae09dfe7ec606ca8bca9d6c3368349ea9b Mon Sep 17 00:00:00 2001 From: Alex Fischman Date: Thu, 24 Oct 2024 21:05:07 -0700 Subject: [PATCH] Improve EGraph::find signature --- src/extract.rs | 12 ++++++------ src/lib.rs | 8 ++++---- src/sort/map.rs | 4 ++-- src/sort/set.rs | 2 +- src/sort/vec.rs | 2 +- 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/extract.rs b/src/extract.rs index dabc22a7..97977892 100644 --- a/src/extract.rs +++ b/src/extract.rs @@ -59,9 +59,9 @@ impl EGraph { inputs .iter() .zip(&func.schema.input) - .map(|(input, sort)| extractor.costs.get( - &extractor.egraph.find(sort.is_eq_sort(), *input).bits - )) + .map(|(input, sort)| extractor + .costs + .get(&extractor.egraph.find(sort, *input).bits)) .collect::>() ); } @@ -80,7 +80,7 @@ impl EGraph { termdag: &mut TermDag, ) -> Vec { let output_sort = sort.name(); - let output_value = self.find(sort.is_eq_sort(), value); + let output_value = self.find(sort, value); let ext = &Extractor::new(self, termdag); ext.ctors .iter() @@ -152,7 +152,7 @@ impl<'a> Extractor<'a> { sort: &ArcSort, ) -> Option<(Cost, Term)> { if sort.is_eq_sort() { - let id = self.egraph.find(true, value).bits; + let id = self.egraph.find(sort, value).bits; let (cost, node) = self.costs.get(&id)?.clone(); Some((cost, node)) } else { @@ -192,7 +192,7 @@ impl<'a> Extractor<'a> { { let make_new_pair = || (new_cost, termdag.app(sym, term_inputs)); - let id = self.egraph.find(true, output.value).bits; + let id = self.egraph.find(&func.schema.output, output.value).bits; match self.costs.entry(id) { Entry::Vacant(e) => { did_something = true; diff --git a/src/lib.rs b/src/lib.rs index fc22c64a..66fe982e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -543,14 +543,14 @@ impl EGraph { for (input, sort) in inputs.iter().zip(&function.schema.input) { assert_eq!( input, - &self.find(sort.is_eq_sort(), *input), + &self.find(sort, *input), "[{i}] {name}({inputs:?}) = {output:?}\n{:?}", function.schema, ) } assert_eq!( output.value, - self.find(function.schema.output.is_eq_sort(), output.value), + self.find(&function.schema.output, output.value), "[{i}] {name}({inputs:?}) = {output:?}\n{:?}", function.schema, ) @@ -591,8 +591,8 @@ impl EGraph { } /// find the leader value for a particular eclass - pub fn find(&self, is_eq_sort: bool, value: Value) -> Value { - if is_eq_sort { + pub fn find(&self, sort: &ArcSort, value: Value) -> Value { + if sort.is_eq_sort() { Value { #[cfg(debug_assertions)] tag: value.tag, diff --git a/src/sort/map.rs b/src/sort/map.rs index 2e35b75c..fd6b3698 100644 --- a/src/sort/map.rs +++ b/src/sort/map.rs @@ -251,8 +251,8 @@ impl PrimitiveLike for MapRebuild { .iter() .map(|(k, v)| { ( - egraph.find(self.map.key.is_eq_sort(), *k), - egraph.find(self.map.value.is_eq_sort(), *v), + egraph.find(&self.map.key, *k), + egraph.find(&self.map.value, *v), ) }) .collect(); diff --git a/src/sort/set.rs b/src/sort/set.rs index a2de6828..84cd32db 100644 --- a/src/sort/set.rs +++ b/src/sort/set.rs @@ -304,7 +304,7 @@ impl PrimitiveLike for SetRebuild { let set = ValueSet::load(&self.set, &values[0]); let new_set: ValueSet = set .iter() - .map(|e| egraph.find(self.set.element.is_eq_sort(), *e)) + .map(|e| egraph.find(&self.set.element, *e)) .collect(); // drop set to make sure we lose lock drop(set); diff --git a/src/sort/vec.rs b/src/sort/vec.rs index aeed7667..40c3e862 100644 --- a/src/sort/vec.rs +++ b/src/sort/vec.rs @@ -258,7 +258,7 @@ impl PrimitiveLike for VecRebuild { let vec = ValueVec::load(&self.vec, &values[0]); let new_vec: ValueVec = vec .iter() - .map(|e| egraph.find(self.vec.element.is_eq_sort(), *e)) + .map(|e| egraph.find(&self.vec.element, *e)) .collect(); drop(vec); Some(new_vec.store(&self.vec).unwrap())