diff --git a/Cargo.lock b/Cargo.lock index b101b5f0..18b5a295 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -184,6 +184,15 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" +[[package]] +name = "bitmaps" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "031043d04099746d8db04daf1fa424b2bc8bd69d92b25962dcde24da39ab64a2" +dependencies = [ + "typenum", +] + [[package]] name = "block-buffer" version = "0.10.4" @@ -300,6 +309,15 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "chrono" +version = "0.4.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" +dependencies = [ + "num-traits", +] + [[package]] name = "clap" version = "4.5.4" @@ -641,14 +659,16 @@ dependencies = [ [[package]] name = "egglog" -version = "0.1.0" -source = "git+https://github.com/egraphs-good/egglog?rev=325814fd90767b5e43c72bc2eb65e14ff0b8746c#325814fd90767b5e43c72bc2eb65e14ff0b8746c" +version = "0.3.0" +source = "git+https://github.com/egraphs-good/egglog?rev=12ecb21e8aeb25297a36be2a04d846222daf5297#12ecb21e8aeb25297a36be2a04d846222daf5297" dependencies = [ + "chrono", "clap", "egraph-serialize", "env_logger 0.10.2", "generic_symbolic_expressions", "hashbrown 0.14.5", + "im-rc", "indexmap", "instant", "lalrpop", @@ -669,9 +689,9 @@ dependencies = [ [[package]] name = "egraph-serialize" -version = "0.1.0" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a41150f383849cfc16ae6230f592112b3c0a2c0e3ec43eb0b09db037bfcce703" +checksum = "c31c5c0d7f760f9c1c84e21d73dcd3b3ce7a4770c27689f56a0db26e0f3e79ca" dependencies = [ "graphviz-rust 0.6.6", "indexmap", @@ -813,8 +833,9 @@ dependencies = [ [[package]] name = "generic_symbolic_expressions" -version = "5.0.3" -source = "git+https://github.com/oflatt/symbolic-expressions?rev=655b6a4c06b4b3d3b2300e17779860b4abe440f0#655b6a4c06b4b3d3b2300e17779860b4abe440f0" +version = "5.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "597eb584fb7cfd1935294fc3608a453fc35a58dfa9da4299c8fd3bc75a4c0b4b" [[package]] name = "getrandom" @@ -928,6 +949,20 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" +[[package]] +name = "im-rc" +version = "15.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af1955a75fa080c677d3972822ec4bad316169ab1cfc6c257a942c2265dbe5fe" +dependencies = [ + "bitmaps", + "rand_core", + "rand_xoshiro", + "sized-chunks", + "typenum", + "version_check", +] + [[package]] name = "indexmap" version = "2.2.6" @@ -1444,6 +1479,15 @@ dependencies = [ "serde", ] +[[package]] +name = "rand_xoshiro" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f97cdb2a36ed4183de61b2f824cc45c9f1037f28afe0a322e9fff4c108b5aaa" +dependencies = [ + "rand_core", +] + [[package]] name = "redox_syscall" version = "0.5.1" @@ -1662,6 +1706,16 @@ version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" +[[package]] +name = "sized-chunks" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16d69225bde7a69b235da73377861095455d298f2b970996eec25ddbb42b3d1e" +dependencies = [ + "bitmaps", + "typenum", +] + [[package]] name = "slice-group-by" version = "0.3.1" diff --git a/Cargo.toml b/Cargo.toml index fa915af6..15b0830d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,8 +9,8 @@ name = "files" [dependencies] -egglog = { git = "https://github.com/egraphs-good/egglog", rev = "325814fd90767b5e43c72bc2eb65e14ff0b8746c" } -egraph-serialize = "0.1.0" +egglog = { git = "https://github.com/egraphs-good/egglog", rev = "12ecb21e8aeb25297a36be2a04d846222daf5297" } +egraph-serialize = "0.2.0" log = "0.4.19" thiserror = "1" lalrpop-util = { version = "0.20.2", features = ["lexer"] } diff --git a/dag_in_context/Cargo.lock b/dag_in_context/Cargo.lock index 90965f51..f8f4d8a7 100644 --- a/dag_in_context/Cargo.lock +++ b/dag_in_context/Cargo.lock @@ -140,6 +140,15 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" +[[package]] +name = "bitmaps" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "031043d04099746d8db04daf1fa424b2bc8bd69d92b25962dcde24da39ab64a2" +dependencies = [ + "typenum", +] + [[package]] name = "block-buffer" version = "0.10.4" @@ -165,6 +174,15 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "chrono" +version = "0.4.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" +dependencies = [ + "num-traits", +] + [[package]] name = "clap" version = "4.5.4" @@ -319,14 +337,16 @@ checksum = "675e35c02a51bb4d4618cb4885b3839ce6d1787c97b664474d9208d074742e20" [[package]] name = "egglog" -version = "0.1.0" -source = "git+https://github.com/egraphs-good/egglog?rev=325814fd90767b5e43c72bc2eb65e14ff0b8746c#325814fd90767b5e43c72bc2eb65e14ff0b8746c" +version = "0.3.0" +source = "git+https://github.com/egraphs-good/egglog?rev=12ecb21e8aeb25297a36be2a04d846222daf5297#12ecb21e8aeb25297a36be2a04d846222daf5297" dependencies = [ + "chrono", "clap", "egraph-serialize", "env_logger 0.10.2", "generic_symbolic_expressions", "hashbrown 0.14.3", + "im-rc", "indexmap", "instant", "lalrpop", @@ -347,9 +367,9 @@ dependencies = [ [[package]] name = "egraph-serialize" -version = "0.1.0" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a41150f383849cfc16ae6230f592112b3c0a2c0e3ec43eb0b09db037bfcce703" +checksum = "c31c5c0d7f760f9c1c84e21d73dcd3b3ce7a4770c27689f56a0db26e0f3e79ca" dependencies = [ "graphviz-rust 0.6.6", "indexmap", @@ -456,8 +476,9 @@ dependencies = [ [[package]] name = "generic_symbolic_expressions" -version = "5.0.3" -source = "git+https://github.com/oflatt/symbolic-expressions?rev=655b6a4c06b4b3d3b2300e17779860b4abe440f0#655b6a4c06b4b3d3b2300e17779860b4abe440f0" +version = "5.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "597eb584fb7cfd1935294fc3608a453fc35a58dfa9da4299c8fd3bc75a4c0b4b" [[package]] name = "getrandom" @@ -545,6 +566,20 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" +[[package]] +name = "im-rc" +version = "15.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af1955a75fa080c677d3972822ec4bad316169ab1cfc6c257a942c2265dbe5fe" +dependencies = [ + "bitmaps", + "rand_core", + "rand_xoshiro", + "sized-chunks", + "typenum", + "version_check", +] + [[package]] name = "indexmap" version = "2.2.6" @@ -938,6 +973,15 @@ dependencies = [ "serde", ] +[[package]] +name = "rand_xoshiro" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f97cdb2a36ed4183de61b2f824cc45c9f1037f28afe0a322e9fff4c108b5aaa" +dependencies = [ + "rand_core", +] + [[package]] name = "redox_syscall" version = "0.4.1" @@ -1097,6 +1141,16 @@ version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" +[[package]] +name = "sized-chunks" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16d69225bde7a69b235da73377861095455d298f2b970996eec25ddbb42b3d1e" +dependencies = [ + "bitmaps", + "typenum", +] + [[package]] name = "smallvec" version = "1.13.2" diff --git a/dag_in_context/Cargo.toml b/dag_in_context/Cargo.toml index fb1fea1a..0e56536b 100644 --- a/dag_in_context/Cargo.toml +++ b/dag_in_context/Cargo.toml @@ -6,12 +6,12 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -egglog = { git = "https://github.com/egraphs-good/egglog", rev = "325814fd90767b5e43c72bc2eb65e14ff0b8746c" } +egglog = { git = "https://github.com/egraphs-good/egglog", rev = "12ecb21e8aeb25297a36be2a04d846222daf5297" } strum = "0.25" strum_macros = "0.25" main_error = "0.1.2" thiserror = "1.0" -egraph-serialize = "0.1.0" +egraph-serialize = "0.2.0" bril-rs = { git = "https://github.com/uwplse/bril", rev = "e2be3f5" } indexmap = "2.0.0" rustc-hash = "1.1.0" diff --git a/dag_in_context/src/add_context.rs b/dag_in_context/src/add_context.rs index 941dc40c..9a8dac55 100644 --- a/dag_in_context/src/add_context.rs +++ b/dag_in_context/src/add_context.rs @@ -4,7 +4,7 @@ //! Mantains the sharing invariant (see restore_sharing_invariant) by using a cache. use egglog::Term; -use std::collections::HashMap; +use indexmap::IndexMap; use crate::{ print_with_intermediate_helper, @@ -14,8 +14,8 @@ use crate::{ }; pub struct ContextCache { - with_ctx: HashMap<(*const Expr, AssumptionRef), RcExpr>, - symbol_gen: HashMap<(*const Expr, AssumptionRef), String>, + with_ctx: IndexMap<(*const Expr, AssumptionRef), RcExpr>, + symbol_gen: IndexMap<(*const Expr, AssumptionRef), String>, /// How many placeholder contexts we've created (used to make a unique name each time) loop_context_placeholder_counter: usize, /// The unions that we need to make between assumptions @@ -48,8 +48,8 @@ impl ContextCache { pub fn new() -> ContextCache { ContextCache { - with_ctx: HashMap::new(), - symbol_gen: HashMap::new(), + with_ctx: IndexMap::new(), + symbol_gen: IndexMap::new(), loop_context_placeholder_counter: 0, loop_context_unions: Vec::new(), symbolic_ctx: false, @@ -59,8 +59,8 @@ impl ContextCache { pub fn new_symbolic_ctx() -> ContextCache { ContextCache { - with_ctx: HashMap::new(), - symbol_gen: HashMap::new(), + with_ctx: IndexMap::new(), + symbol_gen: IndexMap::new(), loop_context_placeholder_counter: 0, loop_context_unions: Vec::new(), symbolic_ctx: true, @@ -70,8 +70,8 @@ impl ContextCache { pub fn new_dummy_ctx() -> ContextCache { ContextCache { - with_ctx: HashMap::new(), - symbol_gen: HashMap::new(), + with_ctx: IndexMap::new(), + symbol_gen: IndexMap::new(), loop_context_placeholder_counter: 0, loop_context_unions: Vec::new(), symbolic_ctx: false, @@ -94,7 +94,7 @@ impl ContextCache { &self, printed: &mut String, tree_state: &mut TreeToEgglog, - term_cache: &mut HashMap, + term_cache: &mut IndexMap, ) -> String { self.loop_context_unions .iter() diff --git a/dag_in_context/src/from_egglog.rs b/dag_in_context/src/from_egglog.rs index 7b07b84a..d43ab630 100644 --- a/dag_in_context/src/from_egglog.rs +++ b/dag_in_context/src/from_egglog.rs @@ -1,9 +1,10 @@ //! Converts from an egglog AST directly to the rust representation of that AST. //! Common subexpressions (common terms) must be converted to the same RcExpr (pointer equality). -use std::{collections::HashMap, rc::Rc}; +use std::rc::Rc; use egglog::{ast::Literal, match_term_app, Term}; +use indexmap::IndexMap; use crate::schema::{ Assumption, BaseType, BinaryOp, Constant, Expr, RcExpr, TernaryOp, TreeProgram, Type, UnaryOp, @@ -11,13 +12,13 @@ use crate::schema::{ pub struct FromEgglog<'a> { pub termdag: &'a egglog::TermDag, - pub conversion_cache: HashMap, + pub conversion_cache: IndexMap, } pub fn program_from_egglog(program: Term, termdag: &egglog::TermDag) -> TreeProgram { let mut converter = FromEgglog { termdag, - conversion_cache: HashMap::new(), + conversion_cache: IndexMap::new(), }; converter.program_from_egglog(program) } @@ -28,7 +29,7 @@ pub fn program_from_egglog_preserve_ctx_nodes( ) -> TreeProgram { let mut converter = FromEgglog { termdag, - conversion_cache: HashMap::new(), + conversion_cache: IndexMap::new(), }; converter.program_from_egglog(program) } diff --git a/dag_in_context/src/greedy_dag_extractor.rs b/dag_in_context/src/greedy_dag_extractor.rs index c33d6681..f5f4bd2e 100644 --- a/dag_in_context/src/greedy_dag_extractor.rs +++ b/dag_in_context/src/greedy_dag_extractor.rs @@ -1,10 +1,10 @@ -use egglog::{ast::Literal, util::IndexMap, *}; +use egglog::{ast::Literal, *}; use egraph_serialize::{ClassId, EGraph, NodeId}; +use indexmap::{IndexMap, IndexSet}; use ordered_float::{NotNan, OrderedFloat}; use rpds::HashTrieMap; -use rustc_hash::FxHashMap; use std::{ - collections::{HashMap, HashSet, VecDeque}, + collections::{HashSet, VecDeque}, f64::INFINITY, }; use strum::IntoEnumIterator; @@ -22,24 +22,24 @@ pub(crate) struct EgraphInfo<'a> { pub(crate) egraph: EGraph, // For every (root, eclass) pair, store the parent // (root, enode) pairs that may depend on it. - pub(crate) parents: HashMap<(RootId, ClassId), Vec<(RootId, NodeId)>>, + pub(crate) parents: IndexMap<(RootId, ClassId), Vec<(RootId, NodeId)>>, pub(crate) roots: Vec<(RootId, NodeId)>, pub(crate) cm: &'a dyn CostModel, /// Optionally, a loop with (inputs, outputs) can have an estimated number of iterations. /// This is found by looking at LoopNumItersGuess in the database. - pub(crate) loop_iteration_estimates: HashMap<(RootId, RootId), i64>, + pub(crate) loop_iteration_estimates: IndexMap<(RootId, RootId), i64>, /// A set of names of functions that are unextractable - unextractables: HashSet, + unextractables: IndexSet, /// A set of (func args) of calls that have been inlined, to indicate we shouldn't /// extract the corresponding (Call func args). - inlined_calls: HashSet<(ClassId, ClassId)>, + inlined_calls: IndexSet<(ClassId, ClassId)>, } pub(crate) struct Extractor<'a> { pub(crate) termdag: &'a mut TermDag, costsets: Vec, - costsetmemo: FxHashMap<(NodeId, Vec), CostSetIndex>, - costs: FxHashMap>, + costsetmemo: IndexMap<(NodeId, Vec), CostSetIndex>, + costs: IndexMap>, // use to get the type of an expression pub(crate) typechecker: TypeChecker<'a>, @@ -49,8 +49,8 @@ pub(crate) struct Extractor<'a> { pub(crate) correspondence: IndexMap, // Get the expression corresponding to a term. // This is computed after the extraction is done. - pub(crate) term_to_expr: Option>, - pub(crate) eclass_type: Option>, + pub(crate) term_to_expr: Option>, + pub(crate) eclass_type: Option>, } impl<'a> EgraphInfo<'a> { @@ -64,18 +64,18 @@ impl<'a> EgraphInfo<'a> { .unwrap() } - fn get_loop_iteration_estimates(egraph: &EGraph) -> HashMap<(ClassId, ClassId), i64> { + fn get_loop_iteration_estimates(egraph: &EGraph) -> IndexMap<(ClassId, ClassId), i64> { // for every eclass that represents a single i64 in the egraph, // map the eclass to that integer - let mut integers = HashMap::new(); + let mut integers: IndexMap = IndexMap::default(); for (nodeid, node) in &egraph.nodes { if let Ok(integer) = node.op.parse::() { let eclass = egraph.nid_to_cid(nodeid); - integers.insert(eclass, integer); + integers.insert(eclass.clone(), integer); } } - let mut loop_iteration_estimates = HashMap::new(); + let mut loop_iteration_estimates = IndexMap::default(); // loop over all nodes, finding LoopNumItersGuess nodes for (_nodeid, node) in &egraph.nodes { @@ -99,8 +99,8 @@ impl<'a> EgraphInfo<'a> { loop_iteration_estimates } - fn get_inlined_calls(egraph: &EGraph) -> HashSet<(ClassId, ClassId)> { - let mut inlined_calls = HashSet::new(); + fn get_inlined_calls(egraph: &EGraph) -> IndexSet<(ClassId, ClassId)> { + let mut inlined_calls = IndexSet::new(); // loop over all nodes, finding InlinedCall nodes for (_nodeid, node) in &egraph.nodes { @@ -124,13 +124,13 @@ impl<'a> EgraphInfo<'a> { pub(crate) fn new( cm: &'a dyn CostModel, egraph: EGraph, - unextractables: HashSet, + unextractables: IndexSet, ) -> Self { let loop_iteration_estimates = Self::get_loop_iteration_estimates(&egraph); let inlined_calls = Self::get_inlined_calls(&egraph); // get all the roots needed - let mut region_roots = HashSet::new(); + let mut region_roots = IndexSet::new(); for (_nodeid, node) in &egraph.nodes { for root in enode_regions(&egraph, node) { region_roots.insert(root); @@ -179,7 +179,8 @@ impl<'a> EgraphInfo<'a> { roots.sort(); log::info!("Found {} roots", roots.len()); - let mut parents: HashMap<(RootId, ClassId), HashSet<(RootId, NodeId)>> = HashMap::new(); + let mut parents: IndexMap<(RootId, ClassId), IndexSet<(RootId, NodeId)>> = + IndexMap::default(); for (root, eclass) in relavent_eclasses { // iterate over every root, enode pair for enode in egraph.classes()[&eclass].nodes.iter() { @@ -217,7 +218,7 @@ impl<'a> EgraphInfo<'a> { parents.values().map(|v| v.len()).sum::() ); - let mut parents_sorted = HashMap::new(); + let mut parents_sorted = IndexMap::default(); for (key, parents) in parents { let mut parents_vec = parents.into_iter().collect::>(); parents_vec.sort(); @@ -278,7 +279,7 @@ impl<'a> Extractor<'a> { termdag: self.termdag, conversion_cache: Default::default(), }; - let mut node_to_type: HashMap = Default::default(); + let mut node_to_type: IndexMap = Default::default(); for (term, node_id) in &self.correspondence { let node = info.egraph.nodes.get(node_id).unwrap(); @@ -368,7 +369,7 @@ pub(crate) fn get_root(egraph: &egraph_serialize::EGraph) -> NodeId { res.0.clone() } -pub fn get_unextractables(egraph: &egglog::EGraph) -> HashSet { +pub fn get_unextractables(egraph: &egglog::EGraph) -> IndexSet { let unextractables = egraph .functions .iter() @@ -385,7 +386,7 @@ pub fn get_unextractables(egraph: &egglog::EGraph) -> HashSet { pub fn serialized_egraph( egglog_egraph: egglog::EGraph, -) -> (egraph_serialize::EGraph, HashSet) { +) -> (egraph_serialize::EGraph, IndexSet) { let config = SerializeConfig::default(); let egraph = egglog_egraph.serialize(config); @@ -687,7 +688,7 @@ fn node_cost_in_region( pub fn extract( original_prog: &TreeProgram, egraph: egraph_serialize::EGraph, - unextractables: HashSet, + unextractables: IndexSet, termdag: &mut TermDag, cost_model: impl CostModel, ) -> (CostSet, TreeProgram) { @@ -719,7 +720,7 @@ pub fn extract_with_paths( // If effectful paths are present, // for each region we will only consider // effectful nodes that are in effectful_path[rootid] - effectful_paths: Option<&HashMap>>, + effectful_paths: Option<&IndexMap>>, ) -> (CostSet, TreeProgram) { if effectful_paths.is_some() { log::info!("Re-extracting program after linear path is found."); @@ -1095,8 +1096,8 @@ fn region_reachable_classes( egraph: &egraph_serialize::EGraph, root: ClassId, cm: &dyn CostModel, -) -> HashSet { - let mut visited = HashSet::new(); +) -> IndexSet { + let mut visited = IndexSet::new(); let mut queue = UniqueQueue::default(); queue.insert(root); @@ -1143,7 +1144,7 @@ fn dag_extraction_test(prog: &TreeProgram, expected_cost: NotNan) { }; let mut egraph = egglog::EGraph::default(); - egraph.parse_and_run_program(&string_prog).unwrap(); + egraph.parse_and_run_program(None, &string_prog).unwrap(); let (serialized_egraph, unextractables) = serialized_egraph(egraph); let mut termdag = TermDag::default(); @@ -1170,7 +1171,7 @@ fn dag_extraction_linearity_check(prog: &TreeProgram, error_message: &str) { }; let mut egraph = egglog::EGraph::default(); - egraph.parse_and_run_program(&string_prog).unwrap(); + egraph.parse_and_run_program(None, &string_prog).unwrap(); let (serialized_egraph, unextractables) = serialized_egraph(egraph); let mut termdag = TermDag::default(); @@ -1389,7 +1390,7 @@ fn test_validity_of_extraction() { }; let mut egraph = egglog::EGraph::default(); - egraph.parse_and_run_program(&string_prog).unwrap(); + egraph.parse_and_run_program(None, &string_prog).unwrap(); let (serialized_egraph, unextractables) = serialized_egraph(egraph); let mut termdag = TermDag::default(); diff --git a/dag_in_context/src/lib.rs b/dag_in_context/src/lib.rs index a9d38168..be73c674 100644 --- a/dag_in_context/src/lib.rs +++ b/dag_in_context/src/lib.rs @@ -1,7 +1,6 @@ -use std::collections::HashMap; - use egglog::{Term, TermDag}; use greedy_dag_extractor::{extract, serialized_egraph, DefaultCostModel}; +use indexmap::IndexMap; use interpreter::Value; use schema::TreeProgram; use std::fmt::Write; @@ -79,7 +78,7 @@ pub fn prologue() -> String { fn print_with_intermediate_helper( termdag: &TermDag, term: Term, - cache: &mut HashMap, + cache: &mut IndexMap, res: &mut String, ) -> String { if let Some(var) = cache.get(&term) { @@ -108,7 +107,7 @@ fn print_with_intermediate_helper( pub fn print_with_intermediate_vars(termdag: &TermDag, term: Term) -> String { let mut printed = String::new(); - let mut cache = HashMap::::new(); + let mut cache = IndexMap::::new(); let res = print_with_intermediate_helper(termdag, term, &mut cache, &mut printed); printed.push_str(&format!("(let PROG {res})\n")); printed @@ -120,7 +119,7 @@ pub fn build_program(program: &TreeProgram, cache: &mut ContextCache, optimize: // Create a global cache for generating intermediate variables let mut tree_state = TreeToEgglog::new(); - let mut term_cache = HashMap::::new(); + let mut term_cache = IndexMap::::new(); // Generate function inlining egglog let function_inlining_unions = if !optimize { @@ -190,7 +189,7 @@ pub fn check_roundtrip_egraph(program: &TreeProgram) { let egglog_prog = build_program(program, &mut ContextCache::new(), false); log::info!("Running egglog program..."); let mut egraph = egglog::EGraph::default(); - egraph.parse_and_run_program(&egglog_prog).unwrap(); + egraph.parse_and_run_program(None, &egglog_prog).unwrap(); let (serialized, unextractables) = serialized_egraph(egraph); let (_res_cost, res) = extract( @@ -219,7 +218,7 @@ pub fn optimize( let egglog_prog = build_program(program, cache, true); log::info!("Running egglog program..."); let mut egraph = egglog::EGraph::default(); - egraph.parse_and_run_program(&egglog_prog)?; + egraph.parse_and_run_program(None, &egglog_prog)?; let (serialized, unextractables) = serialized_egraph(egraph); let mut termdag = egglog::TermDag::default(); @@ -271,7 +270,7 @@ fn check_program_gets_type(program: TreeProgram) -> Result { ); egglog::EGraph::default() - .parse_and_run_program(&s) + .parse_and_run_program(None, &s) .map(|lines| { for line in lines { println!("{}", line); @@ -349,7 +348,7 @@ fn egglog_test_internal( } let res = egglog::EGraph::default() - .parse_and_run_program(&program) + .parse_and_run_program(None, &program) .map(|lines| { for line in lines { println!("{}", line); diff --git a/dag_in_context/src/linearity.rs b/dag_in_context/src/linearity.rs index c8cf2b4d..baa7c5dd 100644 --- a/dag_in_context/src/linearity.rs +++ b/dag_in_context/src/linearity.rs @@ -2,11 +2,7 @@ //! program use memory linearly. //! In particular, it finds all the effectful e-nodes in an extracted term that are along the state edge path. -use std::{ - collections::{HashMap, HashSet}, - iter, - rc::Rc, -}; +use std::{collections::HashSet, iter, rc::Rc}; use egglog::Term; use egraph_serialize::{ClassId, NodeId}; @@ -21,8 +17,8 @@ type EffectfulNodes = IndexMap>; struct Linearity { effectful_nodes: EffectfulNodes, - expr_to_term: HashMap<*const Expr, Term>, - n2c: HashMap, + expr_to_term: IndexMap<*const Expr, Term>, + n2c: IndexMap, } impl<'a> Extractor<'a> { @@ -34,8 +30,8 @@ impl<'a> Extractor<'a> { &mut self, prog: &TreeProgram, egraph_info: &EgraphInfo, - ) -> HashMap> { - let mut expr_to_term = HashMap::new(); + ) -> IndexMap> { + let mut expr_to_term = IndexMap::new(); for (term, expr) in self.term_to_expr.as_ref().unwrap() { expr_to_term.insert(Rc::as_ptr(expr), term.clone()); } @@ -58,7 +54,7 @@ impl<'a> Extractor<'a> { self.find_effectful_nodes_in_region(function.func_body().unwrap(), &mut linearity); } - let effectful_nodes: HashMap> = linearity + let effectful_nodes: IndexMap> = linearity .effectful_nodes .into_iter() .map(|(k, v)| { @@ -72,7 +68,7 @@ impl<'a> Extractor<'a> { // assert that we only find one node per eclass (otherwise the extractor is incorrect) for nodes in effectful_nodes.values() { - let mut eclasses = HashSet::new(); + let mut eclasses = IndexSet::new(); for node in nodes { assert!(eclasses.insert(egraph_info.egraph.nid_to_cid(node))); } @@ -234,7 +230,7 @@ impl<'a> Extractor<'a> { }) .collect(); - let mut effectful_parent: HashMap<*const Expr, RcExpr> = Default::default(); + let mut effectful_parent: IndexMap<*const Expr, RcExpr> = Default::default(); for expr in exprs { let Some(expr) = get_if_effectful(self, expr) else { @@ -279,7 +275,7 @@ impl Expr { self: &RcExpr, root: &RcExpr, reachable_from: &mut IndexMap<*const Expr, IndexSet<*const Expr>>, - raw_to_rc: &mut HashMap<*const Expr, RcExpr>, + raw_to_rc: &mut IndexMap<*const Expr, RcExpr>, ) { raw_to_rc .entry(Rc::as_ptr(self)) diff --git a/dag_in_context/src/optimizations/function_inlining.rs b/dag_in_context/src/optimizations/function_inlining.rs index 7f0e4eaa..3bdee5c8 100644 --- a/dag_in_context/src/optimizations/function_inlining.rs +++ b/dag_in_context/src/optimizations/function_inlining.rs @@ -1,10 +1,7 @@ -use std::{ - collections::{HashMap, HashSet}, - rc::Rc, - vec, -}; +use std::{rc::Rc, vec}; use egglog::Term; +use indexmap::{IndexMap, IndexSet}; use crate::{ add_context::ContextCache, @@ -23,7 +20,7 @@ pub struct CallBody { fn get_calls_with_cache( expr: &RcExpr, calls: &mut Vec, - seen_exprs: &mut HashSet<*const Expr>, + seen_exprs: &mut IndexSet<*const Expr>, ) { if seen_exprs.get(&Rc::as_ptr(expr)).is_some() { return; @@ -48,7 +45,7 @@ fn get_calls_with_cache( // to look up the body fn subst_call( call: &RcExpr, - func_to_body: &HashMap, + func_to_body: &IndexMap, cache: &mut ContextCache, ) -> CallBody { if let Expr::Call(func_name, args) = call.as_ref() { @@ -83,10 +80,10 @@ pub fn function_inlining_pairs( func.func_body().expect("Func has body"), ) }) - .collect::>(); + .collect::>(); // Inline once - let mut seen_exprs: HashSet<*const Expr> = HashSet::new(); + let mut seen_exprs: IndexSet<*const Expr> = IndexSet::new(); let mut calls: Vec = Vec::new(); all_funcs .iter() @@ -127,7 +124,7 @@ pub fn print_function_inlining_pairs( function_inlining_pairs: Vec, printed: &mut String, tree_state: &mut TreeToEgglog, - term_cache: &mut HashMap, + term_cache: &mut IndexMap, ) -> String { let inlined_calls = "(relation InlinedCall (String Expr))"; // Get unions and mark each call as inlined for extraction purposes diff --git a/dag_in_context/src/pretty_print.rs b/dag_in_context/src/pretty_print.rs index 5d0b28a3..693e2ec0 100644 --- a/dag_in_context/src/pretty_print.rs +++ b/dag_in_context/src/pretty_print.rs @@ -13,16 +13,17 @@ use crate::{ schema_helpers::AssumptionRef, to_egglog::TreeToEgglog, }; -use egglog::{Term, TermDag}; +use egglog::{ast::DUMMY_SPAN, Term, TermDag}; +use indexmap::IndexMap; -use std::{collections::HashMap, hash::Hash, rc::Rc, vec}; +use std::{hash::Hash, rc::Rc, vec}; #[derive(Default)] pub struct PrettyPrinter { // Type/Assum/BaseType -> intermediate variables - symbols: HashMap, + symbols: IndexMap, // intermediate variable -> Type/Assum/BaseType lookup - table: HashMap, + table: IndexMap, fresh_count: u64, } @@ -118,15 +119,15 @@ impl PrettyPrinter { let bounded_expr = format!("(let {} {})", binding.clone(), str_expr); let prog = prologue().to_owned() + &bounded_expr; let mut egraph = egglog::EGraph::default(); - egraph.parse_and_run_program(&prog).unwrap(); + egraph.parse_and_run_program(None, &prog).unwrap(); let mut termdag = TermDag::default(); let (sort, value) = egraph - .eval_expr(&egglog::ast::Expr::Var((), binding.into())) + .eval_expr(&egglog::ast::Expr::Var(DUMMY_SPAN.clone(), binding.into())) .unwrap(); let (_, extracted) = egraph.extract(value, &mut termdag, &sort); let mut converter = FromEgglog { termdag: &termdag, - conversion_cache: HashMap::default(), + conversion_cache: IndexMap::default(), }; let expr = converter.expr_from_egglog(extracted); if to_rust { diff --git a/dag_in_context/src/to_egglog.rs b/dag_in_context/src/to_egglog.rs index 46e4d084..08fa5bd8 100644 --- a/dag_in_context/src/to_egglog.rs +++ b/dag_in_context/src/to_egglog.rs @@ -1,9 +1,10 @@ -use std::{collections::HashMap, rc::Rc, vec}; +use std::{rc::Rc, vec}; use egglog::{ ast::{Literal, Symbol}, Term, TermDag, }; +use indexmap::IndexMap; use crate::{ from_egglog::program_from_egglog_preserve_ctx_nodes, @@ -15,7 +16,7 @@ use crate::{ pub(crate) struct TreeToEgglog { pub termdag: TermDag, // Cache for shared subexpressions - converted_cache: HashMap<*const Expr, Term>, + converted_cache: IndexMap<*const Expr, Term>, } impl TreeToEgglog { @@ -23,7 +24,7 @@ impl TreeToEgglog { pub fn new() -> TreeToEgglog { TreeToEgglog { termdag: TermDag::default(), - converted_cache: HashMap::new(), + converted_cache: IndexMap::new(), } } @@ -292,7 +293,7 @@ impl TreeProgram { pub fn to_egglog_with_termdag(&self, termdag: TermDag) -> (Term, TermDag) { let mut state = TreeToEgglog { termdag, - converted_cache: HashMap::new(), + converted_cache: IndexMap::new(), }; (self.to_egglog_with(&mut state), state.termdag) } @@ -341,10 +342,22 @@ fn test_expr_parses_to(expr: RcExpr, expected: &str) { test_parses_to(term, &mut termdag, expected); } +#[cfg(test)] +pub const DEFAULT_FILENAME: &str = ""; + #[cfg(test)] fn test_parses_to(term: Term, termdag: &mut TermDag, expected: &str) { + use std::sync::Arc; + + use egglog::ast::SrcFile; + + let filename = DEFAULT_FILENAME.to_string(); + let srcfile = Arc::new(SrcFile { + name: filename, + contents: Some(expected.to_string()), + }); let parser = egglog::ast::parse::ExprParser::new(); - let parsed = parser.parse(expected).unwrap(); + let parsed = parser.parse(&srcfile, expected).unwrap(); let term2 = termdag.expr_to_term(&parsed); let pretty1 = termdag.term_to_expr(&term).to_sexp().pretty(); let pretty2 = termdag.term_to_expr(&term2).to_sexp().pretty(); diff --git a/dag_in_context/src/typechecker.rs b/dag_in_context/src/typechecker.rs index 5d933a5d..bb5931cc 100644 --- a/dag_in_context/src/typechecker.rs +++ b/dag_in_context/src/typechecker.rs @@ -1,4 +1,6 @@ -use std::{collections::HashMap, rc::Rc}; +use std::rc::Rc; + +use indexmap::IndexMap; use crate::{ ast::{base, empty, emptyt, function, program, statet}, @@ -91,11 +93,11 @@ impl Expr { /// This map is used to memoize the results of typechecking. /// It maps the old untyped expression to the new typed expression /// The type can be None when `expect_fully_typed` is true. -pub type TypedExprCache = HashMap<(*const Expr, Option), RcExpr>; +pub type TypedExprCache = IndexMap<(*const Expr, Option), RcExpr>; /// We also need to keep track of the type of the newly typed expression. /// This maps the newly instrumented expression to its type. -pub type TypeCache = HashMap<*const Expr, Type>; +pub type TypeCache = IndexMap<*const Expr, Type>; /// Type checks program fragments. /// Uses the program to look up function types. pub(crate) struct TypeChecker<'a> { @@ -112,8 +114,8 @@ impl<'a> TypeChecker<'a> { pub(crate) fn new(prog: &'a TreeProgram, expect_fully_typed: bool) -> Self { TypeChecker { program: prog, - type_cache: HashMap::new(), - type_expr_cache: HashMap::new(), + type_cache: IndexMap::new(), + type_expr_cache: IndexMap::new(), expect_fully_typed, } } diff --git a/src/canonicalize_names.rs b/src/canonicalize_names.rs index 38b3bc85..ca559cfe 100644 --- a/src/canonicalize_names.rs +++ b/src/canonicalize_names.rs @@ -3,10 +3,10 @@ //! track those down instead. use bril_rs::{Code, Function, Instruction, Program}; -use hashbrown::HashMap; +use indexmap::IndexMap; struct Renamer { - name_map: HashMap, + name_map: IndexMap, } pub(crate) fn canonicalize_bril(prog: &Program) -> Program { @@ -18,7 +18,7 @@ pub(crate) fn canonicalize_bril(prog: &Program) -> Program { fn canonicalize_func_names(func: &Function) -> Function { let mut renamer = Renamer { - name_map: HashMap::new(), + name_map: IndexMap::new(), }; for arg in &func.args { // don't touch argument names diff --git a/src/cfg/mod.rs b/src/cfg/mod.rs index 8a5567a3..d50ed02a 100644 --- a/src/cfg/mod.rs +++ b/src/cfg/mod.rs @@ -4,13 +4,14 @@ //! look for here are instructions that may break up basic blocks (`jmp`, `br`, //! `ret`), and labels. All other instructions are copied into the CFG. use core::fmt::Debug; +use std::fmt::Display; use std::fmt::Formatter; use std::str::FromStr; -use std::{collections::HashMap, fmt::Display}; use std::{fmt, mem}; use bril_rs::{Argument, Code, EffectOps, Function, Instruction, Position, Program, Type}; -use hashbrown::HashSet; +use indexmap::IndexMap; +use indexmap::IndexSet; use petgraph::dot::Dot; use petgraph::stable_graph::{EdgeReference, StableDiGraph}; @@ -356,7 +357,7 @@ pub type SwitchCfgFunction = CfgFunction; impl CfgFunction { pub(crate) fn remove_unreachable(&mut self) { - let mut reachable = HashSet::new(); + let mut reachable = IndexSet::new(); let mut dfs = DfsPostOrder::new(&self.graph, self.entry); while let Some(node) = dfs.next(&self.graph) { reachable.insert(node); @@ -611,7 +612,7 @@ pub(crate) fn function_to_cfg(func: &Function) -> SimpleCfgFunction { struct CfgBuilder { cfg: SimpleCfgFunction, - label_to_block: HashMap, + label_to_block: IndexMap, } impl CfgBuilder { @@ -629,7 +630,7 @@ impl CfgBuilder { return_ty: func.return_type.clone(), _phantom: Simple, }, - label_to_block: HashMap::new(), + label_to_block: IndexMap::new(), } } diff --git a/src/rvsdg/from_cfg.rs b/src/rvsdg/from_cfg.rs index ae59ac63..f45a14db 100644 --- a/src/rvsdg/from_cfg.rs +++ b/src/rvsdg/from_cfg.rs @@ -13,7 +13,7 @@ use std::io::Write; use std::process::Command; use bril_rs::{ConstOps, EffectOps, Instruction, Literal, Position, Type, ValueOps}; -use hashbrown::HashMap; +use indexmap::IndexMap; use petgraph::algo::dominators; use petgraph::dot::Dot; @@ -154,16 +154,16 @@ pub(crate) fn cfg_func_to_rvsdg( /// to the type of the function. /// Bril doesn't have a void type, so this /// is `None` when the function returns nothing. -pub(crate) type FunctionTypes = HashMap>; +pub(crate) type FunctionTypes = IndexMap>; pub(crate) struct RvsdgBuilder<'a> { cfg: &'a mut SwitchCfgFunction, expr: Vec, // Maps from branch node to join point. - join_point: HashMap, + join_point: IndexMap, analysis: LiveVariableAnalysis, dom: Dominators, - store: HashMap, + store: IndexMap, function_types: FunctionTypes, } @@ -519,7 +519,7 @@ impl<'a> RvsdgBuilder<'a> { fn convert_args( args: &[String], analysis: &mut LiveVariableAnalysis, - env: &mut HashMap, + env: &mut IndexMap, pos: &Option, ) -> Result> { let mut ops = Vec::with_capacity(args.len()); @@ -764,7 +764,7 @@ fn get_id(exprs: &mut Vec, body: RvsdgBody) -> Id { fn get_op( var: VarId, pos: &Option, - env: &HashMap, + env: &IndexMap, intern: &Names, ) -> Result { match env.get(&var) { diff --git a/src/rvsdg/from_dag.rs b/src/rvsdg/from_dag.rs index 626135f1..54a2ef8d 100644 --- a/src/rvsdg/from_dag.rs +++ b/src/rvsdg/from_dag.rs @@ -2,13 +2,14 @@ //! This is a strait-forward translation, since DAG programs are like RVSDGs //! but with tuple constructs such as Concat. -use std::{collections::HashMap, rc::Rc}; +use std::rc::Rc; use bril_rs::{ConstOps, EffectOps, Literal, ValueOps}; use dag_in_context::{ schema::{BaseType, BinaryOp, Expr, RcExpr, TernaryOp, TreeProgram, Type, UnaryOp}, typechecker::TypeCache, }; +use indexmap::IndexMap; use super::{BasicExpr, Operand, RvsdgBody, RvsdgFunction, RvsdgProgram, RvsdgType}; @@ -23,7 +24,7 @@ struct TreeToRvsdg<'a> { /// A cache of already converted expressions. /// Shared expressions must be converted to the same RVSDG nodes. /// For branches, this can be pre-propulated with the arguments passed to the branch. - translation_cache: HashMap<*const Expr, Operands>, + translation_cache: IndexMap<*const Expr, Operands>, nodes: &'a mut Vec, /// The current arguments to the tree program /// as RVSDG operands. @@ -111,7 +112,7 @@ fn tree_func_to_rvsdg(func: RcExpr, program: &TreeProgram) -> RvsdgFunction { let mut converter = TreeToRvsdg { program: &typechecked_program, type_cache: &type_cache, - translation_cache: HashMap::new(), + translation_cache: IndexMap::new(), nodes: &mut nodes, // initial arguments are the first n arguments current_args: (0..input_types.len()).map(Operand::Arg).collect(), @@ -213,7 +214,7 @@ impl<'a> TreeToRvsdg<'a> { program: self.program, nodes: self.nodes, type_cache: self.type_cache, - translation_cache: HashMap::new(), + translation_cache: IndexMap::new(), current_args: args, }; translator.convert_expr(expr) diff --git a/src/rvsdg/live_variables.rs b/src/rvsdg/live_variables.rs index 8900715b..5ad58794 100644 --- a/src/rvsdg/live_variables.rs +++ b/src/rvsdg/live_variables.rs @@ -9,7 +9,7 @@ use std::{collections::BTreeMap, fmt, mem}; use bril_rs::{self, EffectOps, Instruction, ValueOps}; use fixedbitset::FixedBitSet; use hashbrown::HashMap; -use indexmap::IndexSet; +use indexmap::{IndexMap, IndexSet}; use petgraph::{ stable_graph::NodeIndex, visit::{DfsPostOrder, VisitMap, Visitable}, @@ -270,7 +270,7 @@ impl<'a> fmt::Debug for StateAndNames<'a> { /// Type information recorded during live variable analysis. #[derive(Default)] pub(crate) struct VarTypes { - data: HashMap, + data: IndexMap, } impl VarTypes { diff --git a/src/rvsdg/optimize_direct_jumps.rs b/src/rvsdg/optimize_direct_jumps.rs index 016f865f..caeb493e 100644 --- a/src/rvsdg/optimize_direct_jumps.rs +++ b/src/rvsdg/optimize_direct_jumps.rs @@ -4,7 +4,7 @@ //! This is used by `to_cfg` to clean up //! the output. -use hashbrown::HashMap; +use indexmap::IndexMap; use petgraph::{ graph::EdgeIndex, stable_graph::{NodeIndex, StableDiGraph, StableGraph}, @@ -32,7 +32,7 @@ impl SimpleCfgFunction { // new graph // if a node was fused into another node, // it points to the new, fused node - let mut node_mapping: HashMap = HashMap::new(); + let mut node_mapping: IndexMap = IndexMap::new(); // we use a dfs post order // so dependencies are visited before parents diff --git a/src/rvsdg/restructure.rs b/src/rvsdg/restructure.rs index 0c871169..56fd85ff 100644 --- a/src/rvsdg/restructure.rs +++ b/src/rvsdg/restructure.rs @@ -9,7 +9,7 @@ use std::collections::VecDeque; use bril_rs::Type; -use hashbrown::{HashMap, HashSet}; +use indexmap::{IndexMap, IndexSet}; use petgraph::{ algo::{dominators, tarjan_scc}, graph::NodeIndex, @@ -132,8 +132,8 @@ impl SwitchCfgFunction { // The following follows the paper fairly literally. let scc_set = node_set(scc.iter().copied()); - let mut entry_arcs = HashSet::new(); - let mut entry_vertices = HashSet::new(); + let mut entry_arcs = IndexSet::new(); + let mut entry_vertices = IndexSet::new(); for edge_ref in scc .iter() .flat_map(|node| self.graph.edges_directed(*node, Direction::Incoming)) @@ -143,8 +143,8 @@ impl SwitchCfgFunction { entry_vertices.insert(edge_ref.target()); } - let mut exit_arcs = HashSet::new(); - let mut exit_vertices = HashSet::new(); + let mut exit_arcs = IndexSet::new(); + let mut exit_vertices = IndexSet::new(); for edge_ref in scc .iter() @@ -155,7 +155,7 @@ impl SwitchCfgFunction { exit_vertices.insert(edge_ref.target()); } - let repetition_arcs: HashSet = entry_vertices + let repetition_arcs: IndexSet = entry_vertices .iter() .flat_map(|node| self.graph.edges_directed(*node, Direction::Incoming)) .filter(|e| scc_set.is_visited(&e.source())) @@ -256,8 +256,8 @@ impl SwitchCfgFunction { node: NodeIndex, targets: impl IntoIterator, state: &mut RestructureState, - ) -> (HashMap, Identifier) { - let mut blocks = HashMap::new(); + ) -> (IndexMap, Identifier) { + let mut blocks = IndexMap::default(); for node in targets { let cur_len = u32::try_from(blocks.len()).unwrap(); blocks.entry(node).or_insert(cur_len); @@ -283,9 +283,9 @@ impl SwitchCfgFunction { } /// Compute the subgraph of the CFG dominated by the given edge. - fn dominator_graph(&self, edge: EdgeIndex) -> HashSet { - let mut nodes = HashSet::new(); - let mut edges = HashSet::new(); + fn dominator_graph(&self, edge: EdgeIndex) -> IndexSet { + let mut nodes = IndexSet::default(); + let mut edges = IndexSet::new(); edges.insert(edge); let mut frontier = VecDeque::with_capacity(1); let (_, target) = self.graph.edge_endpoints(edge).unwrap(); @@ -542,9 +542,9 @@ const JMP: Branch = Branch { struct Continuation { /// Nodes in the "tail" (`T` above) that are targetted by an edge out of the /// given branch node. - reentry_nodes: HashSet, + reentry_nodes: IndexSet, /// A mapping from branch edge, to edges back to nodes not dominated by that edge. - exit_arcs: HashMap>, + exit_arcs: IndexMap>, } struct EdgeData { diff --git a/src/rvsdg/simplify_branches.rs b/src/rvsdg/simplify_branches.rs index 6822bed7..13f98130 100644 --- a/src/rvsdg/simplify_branches.rs +++ b/src/rvsdg/simplify_branches.rs @@ -63,7 +63,6 @@ use std::{collections::VecDeque, io::Write, mem}; use crate::cfg::{BasicBlock, BlockName, Branch, BranchOp, CondVal, Identifier, SimpleCfgFunction}; use bril_rs::{Argument, Instruction, Literal, Type, ValueOps}; -use hashbrown::{HashMap, HashSet}; use indexmap::{IndexMap, IndexSet}; use petgraph::{ graph::{EdgeIndex, NodeIndex}, @@ -453,7 +452,7 @@ struct ValueState { Transform, )>, /// The set of variables written to in this basic block. - kills: HashSet, + kills: IndexSet, /// The materialized output of transforms on inherited. outputs: IndexMap, /// A variable indicating if `outputs` is stale. @@ -575,7 +574,7 @@ impl ValueState { } struct ValueAnalysis { - data: HashMap, + data: IndexMap, } impl ValueAnalysis { diff --git a/src/rvsdg/to_cfg.rs b/src/rvsdg/to_cfg.rs index 3bd1e1fb..cc9c5064 100644 --- a/src/rvsdg/to_cfg.rs +++ b/src/rvsdg/to_cfg.rs @@ -13,7 +13,7 @@ use bril_rs::{Argument, ConstOps, EffectOps, Instruction, Literal, Type, ValueOps}; -use hashbrown::HashMap; +use indexmap::IndexMap; use petgraph::graph::NodeIndex; use petgraph::stable_graph::StableDiGraph; @@ -92,8 +92,8 @@ struct RvsdgToCfg<'a> { /// the Option is the context, which is important becuase /// arguments are different in different contexts /// The context is none at the top level - operand_cache: HashMap<(RvsdgContext, Operand), Vec>, - body_cache: HashMap<(RvsdgContext, Id), Vec>, + operand_cache: IndexMap<(RvsdgContext, Operand), Vec>, + body_cache: IndexMap<(RvsdgContext, Id), Vec>, } impl RvsdgProgram { diff --git a/src/rvsdg/to_dag.rs b/src/rvsdg/to_dag.rs index 19848d42..4bb12607 100644 --- a/src/rvsdg/to_dag.rs +++ b/src/rvsdg/to_dag.rs @@ -13,6 +13,7 @@ use dag_in_context::ast::*; use dag_in_context::interpreter::Value; #[cfg(test)] use dag_in_context::schema::Constant; +use indexmap::IndexMap; use crate::rvsdg::{BasicExpr, Id, Operand, RvsdgBody, RvsdgFunction, RvsdgProgram}; use bril_rs::{EffectOps, Literal, ValueOps}; @@ -21,7 +22,6 @@ use dag_in_context::{ ast::{add, call, dowhile, function, int, less_than, program_vec, tfalse, ttrue}, schema::{RcExpr, TreeProgram, Type}, }; -use hashbrown::HashMap; use super::RvsdgType; @@ -85,7 +85,7 @@ impl StoredValue { struct DagTranslator<'a> { /// `stored_node` is a cache of already translated rvsdg nodes. - stored_node: HashMap, + stored_node: IndexMap, /// A reference to the nodes in the RVSDG. nodes: &'a [RvsdgBody], /// The next id to assign to an alloc. @@ -353,7 +353,7 @@ impl<'a> DagTranslator<'a> { impl RvsdgFunction { fn to_dag_encoding(&self) -> RcExpr { let mut translator = DagTranslator { - stored_node: HashMap::new(), + stored_node: IndexMap::new(), nodes: &self.nodes, next_alloc_id: 0, }; diff --git a/src/util.rs b/src/util.rs index a4e61899..e2769ba3 100644 --- a/src/util.rs +++ b/src/util.rs @@ -201,7 +201,6 @@ pub enum RunMode { /// Convert the original program to a RVSDG and then to a CFG, outputting one SVG per function. RvsdgToCfg, /// Converts to an executable using brilift (not using eggcc). - /// Brilift does not support phi nodes right now, so we can't run the optimized program with it. Cranelift, /// Converts to an executable using brillvm. /// `optimize_egglog` and `optimize_bril_llvm` must be set. @@ -660,7 +659,7 @@ impl Run { format!("{unfolded_program} \n {folded_program} \n (check (= PROG_PP PROG))"); //println!("{}", program); egglog::EGraph::default() - .parse_and_run_program(&program) + .parse_and_run_program(None, &program) .unwrap(); (vec![], None) } diff --git a/tests/snapshots/files__block-diamond-optimize.snap b/tests/snapshots/files__block-diamond-optimize.snap index 9263fd9b..dea10bb6 100644 --- a/tests/snapshots/files__block-diamond-optimize.snap +++ b/tests/snapshots/files__block-diamond-optimize.snap @@ -4,30 +4,32 @@ expression: visualization.result --- @main(v0: int) { .b1_: - c2_: int = const 2; - v3_: bool = lt v0 c2_; - v4_: bool = not v3_; + c2_: int = const 1; + c3_: int = const 2; + v4_: bool = lt v0 c3_; c5_: int = const 0; - c6_: int = const 1; - c7_: int = const 5; - v8_: int = id c6_; - v9_: int = id c6_; - v10_: int = id c2_; - br v3_ .b11_ .b12_; + c6_: int = const 5; + v7_: int = id c2_; + v8_: int = id c2_; + v9_: int = id c3_; + br v4_ .b10_ .b11_; +.b10_: + c12_: int = const 4; + v7_: int = id c12_; + v8_: int = id c2_; + v9_: int = id c3_; + v13_: int = id v7_; + v14_: int = id v8_; + v15_: int = add c2_ v13_; + print v15_; + ret; .b11_: - c13_: int = const 4; - v8_: int = id c13_; - v9_: int = id c6_; - v10_: int = id c2_; -.b12_: + v13_: int = id v7_; + v14_: int = id v8_; + v16_: int = add v7_ v9_; + v13_: int = id v16_; v14_: int = id v8_; - v15_: int = id v9_; - br v4_ .b16_ .b17_; -.b16_: - v18_: int = add v10_ v8_; - v14_: int = id v18_; - v15_: int = id v9_; .b17_: - v19_: int = add c6_ v14_; - print v19_; + v15_: int = add c2_ v13_; + print v15_; } diff --git a/tests/snapshots/files__collatz_redundant_computation-optimize.snap b/tests/snapshots/files__collatz_redundant_computation-optimize.snap index e0d66df6..7340d95a 100644 --- a/tests/snapshots/files__collatz_redundant_computation-optimize.snap +++ b/tests/snapshots/files__collatz_redundant_computation-optimize.snap @@ -17,8 +17,8 @@ expression: visualization.result .b12_: v13_: bool = eq v7_ v8_; v14_: int = id v6_; - v15_: int = id v7_; - v16_: int = id v7_; + v15_: int = id v8_; + v16_: int = id v8_; v17_: int = id v9_; v18_: int = id v10_; v19_: int = id v11_; diff --git a/tests/snapshots/files__fib_recursive-optimize.snap b/tests/snapshots/files__fib_recursive-optimize.snap index e674c7e5..9be0217f 100644 --- a/tests/snapshots/files__fib_recursive-optimize.snap +++ b/tests/snapshots/files__fib_recursive-optimize.snap @@ -128,18 +128,18 @@ expression: visualization.result br v95_ .b96_ .b97_; .b96_: v98_: bool = eq c2_ c2_; - v99_: int = id v71_; + v99_: int = id c4_; br v98_ .b100_ .b101_; .b101_: - v102_: bool = eq c2_ v71_; + v102_: bool = eq c2_ c4_; br v102_ .b103_ .b104_; .b103_: v105_: int = call @fac c2_; - v106_: int = id c2_; + v106_: int = id c4_; .b107_: v99_: int = id v106_; .b100_: - v108_: int = id v71_; + v108_: int = id c4_; .b109_: v92_: int = id v108_; .b93_: @@ -259,51 +259,51 @@ expression: visualization.result } @main { .b0_: - c1_: int = const 0; - c2_: int = const 2; + c1_: int = const 2; + c2_: int = const 0; v3_: bool = eq c1_ c2_; c4_: int = const 1; v5_: int = id c4_; br v3_ .b6_ .b7_; .b7_: - v8_: bool = eq c2_ c4_; + v8_: bool = eq c1_ c4_; br v8_ .b9_ .b10_; .b9_: - v11_: bool = eq c1_ c1_; - v12_: int = id c2_; + v11_: bool = eq c2_ c2_; + v12_: int = id c1_; br v11_ .b13_ .b14_; .b14_: v15_: bool = eq c1_ c2_; br v15_ .b16_ .b17_; .b16_: - v18_: int = call @fac c1_; + v18_: int = call @fac c2_; v19_: int = id c2_; .b20_: v12_: int = id v19_; .b13_: - v21_: int = id c2_; + v21_: int = id c1_; .b22_: v5_: int = id v21_; print v5_; ret; .b17_: c23_: int = const -1; - v24_: int = sub c23_ c2_; + v24_: int = sub c23_ c1_; v25_: int = call @fac c23_; v26_: int = call @fac v24_; v27_: int = add v25_ v26_; v19_: int = id v27_; jmp .b20_; .b10_: - v28_: bool = eq c1_ c1_; - v29_: bool = eq c1_ c4_; + v28_: bool = eq c2_ c2_; + v29_: bool = eq c2_ c4_; v30_: int = id c4_; br v29_ .b31_ .b32_; .b32_: v33_: bool = eq c4_ c4_; br v33_ .b34_ .b35_; .b34_: - v36_: int = call @fac c1_; + v36_: int = call @fac c2_; v37_: int = id c4_; .b38_: v30_: int = id v37_; @@ -311,11 +311,11 @@ expression: visualization.result v39_: int = id c4_; br v28_ .b40_ .b41_; .b41_: - v42_: bool = eq c1_ c4_; + v42_: bool = eq c2_ c4_; br v42_ .b43_ .b44_; .b43_: - v45_: int = call @fac c1_; - v46_: int = id c1_; + v45_: int = call @fac c2_; + v46_: int = id c4_; .b47_: v39_: int = id v46_; .b40_: @@ -332,7 +332,7 @@ expression: visualization.result jmp .b47_; .b35_: c54_: int = const -1; - v55_: int = call @fac c1_; + v55_: int = call @fac c2_; v56_: int = call @fac c54_; v57_: int = add v55_ v56_; v37_: int = id v57_; diff --git a/tests/snapshots/files__if_dead_code_nested-optimize.snap b/tests/snapshots/files__if_dead_code_nested-optimize.snap index e6ffbdca..10c37378 100644 --- a/tests/snapshots/files__if_dead_code_nested-optimize.snap +++ b/tests/snapshots/files__if_dead_code_nested-optimize.snap @@ -4,61 +4,60 @@ expression: visualization.result --- @main(v0: int) { .b1_: - c2_: int = const 1; - v3_: bool = lt v0 c2_; - c4_: int = const 0; - c5_: int = const 3; - c6_: int = const 2; - br v3_ .b7_ .b8_; -.b7_: - v9_: bool = lt v0 c4_; - c10_: bool = const true; - c11_: int = const 1; - c12_: int = const 2; - v13_: int = id c12_; - v14_: bool = id c10_; - v15_: int = id c11_; - br v9_ .b16_ .b17_; + c2_: int = const 0; + v3_: bool = le v0 c2_; + c4_: int = const 3; + c5_: int = const 2; + br v3_ .b6_ .b7_; +.b6_: + v8_: bool = lt v0 c2_; + c9_: bool = const true; + c10_: int = const 1; + c11_: int = const 2; + v12_: int = id c11_; + v13_: bool = id c9_; + v14_: int = id c10_; + br v8_ .b15_ .b16_; +.b15_: + v12_: int = id c10_; + v13_: bool = id c9_; + v14_: int = id c10_; .b16_: - v13_: int = id c11_; - v14_: bool = id c10_; - v15_: int = id c11_; -.b17_: - v18_: int = id v13_; - v19_: int = id c11_; - print v19_; - print v3_; + v17_: int = id v12_; + v18_: int = id c10_; print v18_; + print v3_; + print v17_; ret; -.b8_: - v20_: bool = lt c6_ v0; - c21_: bool = const false; - c22_: int = const 2; - v23_: int = id c22_; - v24_: bool = id c21_; - v25_: int = id c4_; - br v20_ .b26_ .b27_; -.b26_: - v28_: bool = gt v0 c5_; - c29_: int = const 4; - v30_: int = id c29_; - v31_: bool = id c21_; - v32_: int = id c4_; - br v28_ .b33_ .b34_; +.b7_: + v19_: bool = gt v0 c5_; + c20_: bool = const false; + c21_: int = const 2; + v22_: int = id c21_; + v23_: bool = id c20_; + v24_: int = id c2_; + br v19_ .b25_ .b26_; +.b25_: + v27_: bool = gt v0 c4_; + c28_: int = const 4; + v29_: int = id c28_; + v30_: bool = id c20_; + v31_: int = id c2_; + br v27_ .b32_ .b33_; +.b32_: + c34_: int = const 3; + v29_: int = id c34_; + v30_: bool = id c20_; + v31_: int = id c2_; .b33_: - c35_: int = const 3; - v30_: int = id c35_; - v31_: bool = id c21_; - v32_: int = id c4_; -.b34_: - v23_: int = id v30_; - v24_: bool = id v31_; - v25_: int = id v32_; -.b27_: - v18_: int = id v23_; - v19_: int = id c4_; -.b36_: - print v19_; - print v3_; + v22_: int = id v29_; + v23_: bool = id v30_; + v24_: int = id v31_; +.b26_: + v17_: int = id v22_; + v18_: int = id c2_; +.b35_: print v18_; + print v3_; + print v17_; } diff --git a/tests/snapshots/files__if_in_loop-optimize.snap b/tests/snapshots/files__if_in_loop-optimize.snap index 187187d1..6ff71a95 100644 --- a/tests/snapshots/files__if_in_loop-optimize.snap +++ b/tests/snapshots/files__if_in_loop-optimize.snap @@ -7,41 +7,39 @@ expression: visualization.result c2_: int = const 0; c3_: int = const 1; c4_: int = const 10; - v5_: bool = lt v0 c3_; - v6_: int = id c2_; - v7_: int = id c3_; - v8_: int = id v0; - v9_: int = id c2_; - v10_: int = id c4_; - v11_: bool = id v5_; -.b12_: - v13_: bool = lt v6_ v10_; - v14_: bool = id v13_; + v5_: int = id c2_; + v6_: int = id c3_; + v7_: int = id v0; + v8_: int = id c2_; + v9_: int = id c4_; +.b10_: + v11_: bool = lt v7_ v6_; + v12_: bool = lt v5_ v9_; + v13_: bool = id v12_; + v14_: int = id v5_; v15_: int = id v6_; - v16_: int = id v7_; - v17_: int = id v9_; + v16_: int = id v8_; + v17_: int = id v7_; v18_: int = id v8_; v19_: int = id v9_; - v20_: int = id v10_; - br v11_ .b21_ .b22_; -.b21_: - v14_: bool = id v13_; + br v11_ .b20_ .b21_; +.b20_: + v13_: bool = id v12_; + v14_: int = id v5_; v15_: int = id v6_; - v16_: int = id v7_; + v16_: int = id v6_; v17_: int = id v7_; v18_: int = id v8_; v19_: int = id v9_; - v20_: int = id v10_; -.b22_: - print v17_; +.b21_: + print v16_; print v11_; - v23_: int = add v6_ v7_; - v6_: int = id v23_; + v22_: int = add v5_ v6_; + v5_: int = id v22_; + v6_: int = id v6_; v7_: int = id v7_; v8_: int = id v8_; v9_: int = id v9_; - v10_: int = id v10_; - v11_: bool = id v11_; - br v13_ .b12_ .b24_; -.b24_: + br v12_ .b10_ .b23_; +.b23_: } diff --git a/tests/snapshots/files__nested_call-optimize.snap b/tests/snapshots/files__nested_call-optimize.snap index f45a1c0f..45ff3258 100644 --- a/tests/snapshots/files__nested_call-optimize.snap +++ b/tests/snapshots/files__nested_call-optimize.snap @@ -4,10 +4,10 @@ expression: visualization.result --- @inc(v0: int): int { .b1_: - c2_: int = const 1; - v3_: int = add c2_ v0; - c4_: int = const 2; - v5_: int = mul c4_ v3_; + c2_: int = const 2; + c3_: int = const 1; + v4_: int = add c3_ v0; + v5_: int = mul c2_ v4_; ret v5_; } @double(v0: int): int { diff --git a/tests/snapshots/files__sqrt-optimize.snap b/tests/snapshots/files__sqrt-optimize.snap index c0f3bcce..19c2e1cc 100644 --- a/tests/snapshots/files__sqrt-optimize.snap +++ b/tests/snapshots/files__sqrt-optimize.snap @@ -12,7 +12,7 @@ expression: visualization.result ret; .b5_: v6_: bool = feq v0 v0; - c7_: bool = const true; + c7_: bool = const false; v8_: float = id c2_; v9_: bool = id c7_; br v6_ .b10_ .b11_; @@ -36,8 +36,8 @@ expression: visualization.result v28_: float = fadd v21_ v27_; v29_: float = fdiv v28_ v24_; v30_: float = fdiv v29_ v21_; - v31_: bool = fle v30_ v22_; - v32_: bool = fge v30_ v23_; + v31_: bool = fge v30_ v23_; + v32_: bool = fle v30_ v22_; v33_: bool = and v31_ v32_; v34_: bool = not v33_; v20_: float = id v20_; @@ -51,14 +51,15 @@ expression: visualization.result print v21_; v13_: float = id v20_; .b14_: + v36_: bool = not v12_; v8_: float = id v13_; - v9_: bool = id v12_; - br v9_ .b11_ .b36_; -.b36_: + v9_: bool = id v36_; + br v9_ .b37_ .b11_; +.b37_: ret; .b11_: - v37_: float = fdiv v8_ v8_; - print v37_; - jmp .b36_; -.b38_: + v38_: float = fdiv v8_ v8_; + print v38_; + jmp .b37_; +.b39_: }