diff --git a/.gitignore b/.gitignore index 53547711..47d80105 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,8 @@ default.nix TODO.md flamegraph.svg perf.data* +profile.json *.bench .vscode/settings.json out.csv +tmp* diff --git a/Cargo.toml b/Cargo.toml index b56dc602..610fe76f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,8 +35,20 @@ vectorize = {version = "0.2.0", optional = true} saturating = "0.1.0" serde_json = {version = "1.0.81", optional = true} +# for parallelisation support +# rayon = {version = "1.10.0", optional = true} +rayon = "1.10.0" + [dev-dependencies] ordered-float = "3.0.0" +criterion = "0.3" + +[[bench]] +name = "math" +harness = false + +[profile.bench] +debug = true [features] # forces the use of indexmaps over hashmaps @@ -51,6 +63,7 @@ serde-1 = [ "vectorize", ] wasm-bindgen = [] +# parallel = ["rayon"] # private features for testing test-explanations = [] diff --git a/Makefile b/Makefile index 229977bf..a8f51e61 100644 --- a/Makefile +++ b/Makefile @@ -6,7 +6,7 @@ test: cargo test --release --features=lp # don't run examples in proof-production mode cargo test --release --features "test-explanations" - + .PHONY: nits nits: @@ -23,4 +23,29 @@ nits: .PHONY: docs docs: - RUSTDOCFLAGS="--cfg docsrs" cargo +nightly doc --all-features --open \ No newline at end of file + RUSTDOCFLAGS="--cfg docsrs" cargo +nightly doc --all-features --open + + + +math.csv: + EGG_BENCH_CSV=math.csv cargo test --test math --release -- --nocapture --test --test-threads=1 + +lambda.csv: + EGG_BENCH_CSV=lambda.csv cargo test --test lambda --release -- --nocapture --test --test-threads=1 + +.PHONY: existing-bench +existing-bench: math.csv lambda.csv + +.PHONY: clean-bench +clean-bench: + rm math.csv lambda.csv profile.json + +.PHONY: bench +bench: + cargo build --profile test && cargo bench + +profile.json: + cargo build --profile test && samply record cargo bench + +.PHONY: profile +profile: profile.json diff --git a/benches/definitions.rs b/benches/definitions.rs new file mode 100644 index 00000000..93de9bec --- /dev/null +++ b/benches/definitions.rs @@ -0,0 +1,242 @@ +pub mod simple { + use egg::*; + + define_language! { + pub enum SimpleLanguage { + Num(i32), + "+" = Add([Id; 2]), + "*" = Mul([Id; 2]), + Symbol(Symbol), + } + } + + pub fn make_rules() -> Vec> { + vec![ + rewrite!("commute-add"; "(+ ?a ?b)" => "(+ ?b ?a)"), + rewrite!("commute-mul"; "(* ?a ?b)" => "(* ?b ?a)"), + rewrite!("add-0"; "(+ ?a 0)" => "?a"), + rewrite!("mul-0"; "(* ?a 0)" => "0"), + rewrite!("mul-1"; "(* ?a 1)" => "?a"), + ] + } + + pub const EXAMPLE_INPUTS: &'static [&'static str] = &[ + "(* 0 42)", + "(+ 0 (* 1 foo))" + ]; +} + + +pub mod math { + use egg::{rewrite as rw, *}; + use ordered_float::NotNan; + + pub type EGraph = egg::EGraph; + pub type Rewrite = egg::Rewrite; + + pub type Constant = NotNan; + + define_language! { + pub enum Math { + "d" = Diff([Id; 2]), + "i" = Integral([Id; 2]), + + "+" = Add([Id; 2]), + "-" = Sub([Id; 2]), + "*" = Mul([Id; 2]), + "/" = Div([Id; 2]), + "pow" = Pow([Id; 2]), + "ln" = Ln(Id), + "sqrt" = Sqrt(Id), + + "sin" = Sin(Id), + "cos" = Cos(Id), + + Constant(Constant), + Symbol(Symbol), + } + } + + // You could use egg::AstSize, but this is useful for debugging, since + // it will really try to get rid of the Diff operator + pub struct MathCostFn; + impl egg::CostFunction for MathCostFn { + type Cost = usize; + fn cost(&mut self, enode: &Math, mut costs: C) -> Self::Cost + where + C: FnMut(Id) -> Self::Cost, + { + let op_cost = match enode { + Math::Diff(..) => 100, + Math::Integral(..) => 100, + _ => 1, + }; + enode.fold(op_cost, |sum, i| sum + costs(i)) + } + } + + #[derive(Default)] + pub struct ConstantFold; + impl Analysis for ConstantFold { + type Data = Option<(Constant, PatternAst)>; + + fn make(egraph: &mut EGraph, enode: &Math) -> Self::Data { + let x = |i: &Id| egraph[*i].data.as_ref().map(|d| d.0); + Some(match enode { + Math::Constant(c) => (*c, format!("{}", c).parse().unwrap()), + Math::Add([a, b]) => ( + x(a)? + x(b)?, + format!("(+ {} {})", x(a)?, x(b)?).parse().unwrap(), + ), + Math::Sub([a, b]) => ( + x(a)? - x(b)?, + format!("(- {} {})", x(a)?, x(b)?).parse().unwrap(), + ), + Math::Mul([a, b]) => ( + x(a)? * x(b)?, + format!("(* {} {})", x(a)?, x(b)?).parse().unwrap(), + ), + Math::Div([a, b]) if x(b) != Some(NotNan::new(0.0).unwrap()) => ( + x(a)? / x(b)?, + format!("(/ {} {})", x(a)?, x(b)?).parse().unwrap(), + ), + _ => return None, + }) + } + + fn merge(&mut self, to: &mut Self::Data, from: Self::Data) -> DidMerge { + merge_option(to, from, |a, b| { + assert_eq!(a.0, b.0, "Merged non-equal constants"); + DidMerge(false, false) + }) + } + + fn modify(egraph: &mut EGraph, id: Id) { + let data = egraph[id].data.clone(); + if let Some((c, pat)) = data { + if egraph.are_explanations_enabled() { + egraph.union_instantiations( + &pat, + &format!("{}", c).parse().unwrap(), + &Default::default(), + "constant_fold".to_string(), + ); + } else { + let added = egraph.add(Math::Constant(c)); + egraph.union(id, added); + } + // to not prune, comment this out + egraph[id].nodes.retain(|n| n.is_leaf()); + + #[cfg(debug_assertions)] + egraph[id].assert_unique_leaves(); + } + } + } + + fn is_const_or_distinct_var(v: &str, w: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { + let v = v.parse().unwrap(); + let w = w.parse().unwrap(); + move |egraph, _, subst| { + egraph.find(subst[v]) != egraph.find(subst[w]) + && (egraph[subst[v]].data.is_some() + || egraph[subst[v]] + .nodes + .iter() + .any(|n| matches!(n, Math::Symbol(..)))) + } + } + + fn is_const(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { + let var = var.parse().unwrap(); + move |egraph, _, subst| egraph[subst[var]].data.is_some() + } + + fn is_sym(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { + let var = var.parse().unwrap(); + move |egraph, _, subst| { + egraph[subst[var]] + .nodes + .iter() + .any(|n| matches!(n, Math::Symbol(..))) + } + } + + fn is_not_zero(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { + let var = var.parse().unwrap(); + move |egraph, _, subst| { + if let Some(n) = &egraph[subst[var]].data { + *(n.0) != 0.0 + } else { + true + } + } + } + + #[rustfmt::skip] + pub fn rules() -> Vec { vec![ + rw!("comm-add"; "(+ ?a ?b)" => "(+ ?b ?a)"), + rw!("comm-mul"; "(* ?a ?b)" => "(* ?b ?a)"), + rw!("assoc-add"; "(+ ?a (+ ?b ?c))" => "(+ (+ ?a ?b) ?c)"), + rw!("assoc-mul"; "(* ?a (* ?b ?c))" => "(* (* ?a ?b) ?c)"), + + rw!("sub-canon"; "(- ?a ?b)" => "(+ ?a (* -1 ?b))"), + rw!("div-canon"; "(/ ?a ?b)" => "(* ?a (pow ?b -1))" if is_not_zero("?b")), + // rw!("canon-sub"; "(+ ?a (* -1 ?b))" => "(- ?a ?b)"), + // rw!("canon-div"; "(* ?a (pow ?b -1))" => "(/ ?a ?b)" if is_not_zero("?b")), + + rw!("zero-add"; "(+ ?a 0)" => "?a"), + rw!("zero-mul"; "(* ?a 0)" => "0"), + rw!("one-mul"; "(* ?a 1)" => "?a"), + + rw!("add-zero"; "?a" => "(+ ?a 0)"), + rw!("mul-one"; "?a" => "(* ?a 1)"), + + rw!("cancel-sub"; "(- ?a ?a)" => "0"), + rw!("cancel-div"; "(/ ?a ?a)" => "1" if is_not_zero("?a")), + + rw!("distribute"; "(* ?a (+ ?b ?c))" => "(+ (* ?a ?b) (* ?a ?c))"), + rw!("factor" ; "(+ (* ?a ?b) (* ?a ?c))" => "(* ?a (+ ?b ?c))"), + + rw!("pow-mul"; "(* (pow ?a ?b) (pow ?a ?c))" => "(pow ?a (+ ?b ?c))"), + rw!("pow0"; "(pow ?x 0)" => "1" + if is_not_zero("?x")), + rw!("pow1"; "(pow ?x 1)" => "?x"), + rw!("pow2"; "(pow ?x 2)" => "(* ?x ?x)"), + rw!("pow-recip"; "(pow ?x -1)" => "(/ 1 ?x)" + if is_not_zero("?x")), + rw!("recip-mul-div"; "(* ?x (/ 1 ?x))" => "1" if is_not_zero("?x")), + + rw!("d-variable"; "(d ?x ?x)" => "1" if is_sym("?x")), + rw!("d-constant"; "(d ?x ?c)" => "0" if is_sym("?x") if is_const_or_distinct_var("?c", "?x")), + + rw!("d-add"; "(d ?x (+ ?a ?b))" => "(+ (d ?x ?a) (d ?x ?b))"), + rw!("d-mul"; "(d ?x (* ?a ?b))" => "(+ (* ?a (d ?x ?b)) (* ?b (d ?x ?a)))"), + + rw!("d-sin"; "(d ?x (sin ?x))" => "(cos ?x)"), + rw!("d-cos"; "(d ?x (cos ?x))" => "(* -1 (sin ?x))"), + + rw!("d-ln"; "(d ?x (ln ?x))" => "(/ 1 ?x)" if is_not_zero("?x")), + + rw!("d-power"; + "(d ?x (pow ?f ?g))" => + "(* (pow ?f ?g) + (+ (* (d ?x ?f) + (/ ?g ?f)) + (* (d ?x ?g) + (ln ?f))))" + if is_not_zero("?f") + if is_not_zero("?g") + ), + + rw!("i-one"; "(i 1 ?x)" => "?x"), + rw!("i-power-const"; "(i (pow ?x ?c) ?x)" => + "(/ (pow ?x (+ ?c 1)) (+ ?c 1))" if is_const("?c")), + rw!("i-cos"; "(i (cos ?x) ?x)" => "(sin ?x)"), + rw!("i-sin"; "(i (sin ?x) ?x)" => "(* -1 (cos ?x))"), + rw!("i-sum"; "(i (+ ?f ?g) ?x)" => "(+ (i ?f ?x) (i ?g ?x))"), + rw!("i-dif"; "(i (- ?f ?g) ?x)" => "(- (i ?f ?x) (i ?g ?x))"), + rw!("i-parts"; "(i (* ?a ?b) ?x)" => + "(- (* ?a (i ?b ?x)) (i (* (d ?x ?a) (i ?b ?x)) ?x))"), + ]} +} diff --git a/benches/lambda.rs b/benches/lambda.rs new file mode 100644 index 00000000..e69de29b diff --git a/benches/math.rs b/benches/math.rs new file mode 100644 index 00000000..094a4c95 --- /dev/null +++ b/benches/math.rs @@ -0,0 +1,419 @@ +use egg::{rewrite as rw, *}; + +mod definitions; +use definitions::math; + +use criterion::{criterion_group, criterion_main, Criterion}; + +// fn math_() { +// egg::test::test_runner( +// "", +// None, +// &math::rules(), +// "".parse().unwrap(), +// &["".parse().unwrap()], +// None, +// true +// ) +// } + +fn math_associate_adds() { + egg::test::test_runner( + "math_associate_adds", + Some(Runner::default() + .with_iter_limit(7) + .with_scheduler(SimpleScheduler)), + &[ + rw!("comm-add"; "(+ ?a ?b)" => "(+ ?b ?a)"), + rw!("assoc-add"; "(+ ?a (+ ?b ?c))" => "(+ (+ ?a ?b) ?c)"), + ], + "(+ 1 (+ 2 (+ 3 (+ 4 (+ 5 (+ 6 7))))))".parse().unwrap(), + &["(+ 7 (+ 6 (+ 5 (+ 4 (+ 3 (+ 2 1))))))".parse().unwrap()], + Some(|r: Runner| assert_eq!(r.egraph.number_of_classes(), 127)), + true + ) +} + +// NOTE: Less suitable for benchmarking +// fn math_fail() { +// let _ = std::panic::catch_unwind( +// || egg::test::test_runner( +// "math_fail", +// None, +// &math::rules(), +// "(+ x y)".parse().unwrap(), +// &["(/ x y)".parse().unwrap()], +// None, +// true +// ) +// ); +// } + +fn math_simplify_add() { + egg::test::test_runner( + "math_simplify_add", + None, + &math::rules(), + "(+ x (+ x (+ x x)))".parse().unwrap(), + &["(* 4 x)".parse().unwrap()], + None, + true + ) +} + +fn math_powers() { + egg::test::test_runner( + "math_powers", + None, + &math::rules(), + "(* (pow 2 x) (pow 2 y))".parse().unwrap(), + &["(pow 2 (+ x y))".parse().unwrap()], + None, + true + ) +} + +fn math_simplify_const() { + egg::test::test_runner( + "math_simplify_const", + None, + &math::rules(), + "(+ 1 (- a (* (- 2 1) a)))".parse().unwrap(), + &["1".parse().unwrap()], + None, + true + ) +} + +fn math_simplify_root() { + egg::test::test_runner( + "math_simplify_root", + Some(Runner::default().with_node_limit(75_000)), + &math::rules(), + r#" + (/ 1 + (- (/ (+ 1 (sqrt five)) + 2) + (/ (- 1 (sqrt five)) + 2)))"#.parse().unwrap(), + &["(/ 1 (sqrt five))".parse().unwrap()], + None, + true + ) +} + +fn math_simplify_factor() { + egg::test::test_runner( + "math_simplify_factor", + None, + &math::rules(), + "(* (+ x 3) (+ x 1))".parse().unwrap(), + &["(+ (+ (* x x) (* 4 x)) 3)".parse().unwrap()], + None, + true + ) +} + +fn math_diff_same() { + egg::test::test_runner( + "math_diff_same", + None, + &math::rules(), + "(d x x)".parse().unwrap(), + &["1".parse().unwrap()], + None, + true + ) +} + +fn math_diff_different() { + egg::test::test_runner( + "math_diff_different", + None, + &math::rules(), + "(d x y)".parse().unwrap(), + &["0".parse().unwrap()], + None, + true + ) +} + +fn math_diff_simple1() { + egg::test::test_runner( + "math_diff_simple1", + None, + &math::rules(), + "(d x (+ 1 (* 2 x)))".parse().unwrap(), + &["2".parse().unwrap()], + None, + true + ) +} + +fn math_diff_simple2() { + egg::test::test_runner( + "math_diff_simple2", + None, + &math::rules(), + "(d x (+ 1 (* y x)))".parse().unwrap(), + &["y".parse().unwrap()], + None, + true + ) +} + +fn math_diff_ln() { + egg::test::test_runner( + "math_diff_ln", + None, + &math::rules(), + "(d x (ln x))".parse().unwrap(), + &["(/ 1 x)".parse().unwrap()], + None, + true + ) +} + +fn diff_power_simple() { + egg::test::test_runner( + "diff_power_simple", + None, + &math::rules(), + "(d x (pow x 3))".parse().unwrap(), + &["(* 3 (pow x 2))".parse().unwrap()], + None, + true + ) +} + +fn diff_power_harder() { + egg::test::test_runner( + "diff_power_harder", + Some(Runner::default() + .with_time_limit(std::time::Duration::from_secs(10)) + .with_iter_limit(60) + .with_node_limit(100_000) + .with_explanations_enabled() + // HACK this needs to "see" the end expression + .with_expr(&"(* x (- (* 3 x) 14))".parse().unwrap())), + &math::rules(), + "(d x (- (pow x 3) (* 7 (pow x 2))))".parse().unwrap(), + &["(* x (- (* 3 x) 14))".parse().unwrap()], + None, + true + ) +} + +fn integ_one() { + egg::test::test_runner( + "integ_one", + None, + &math::rules(), + "(i 1 x)".parse().unwrap(), + &["x".parse().unwrap()], + None, + true + ) +} + +fn integ_sin() { + egg::test::test_runner( + "integ_sin", + None, + &math::rules(), + "(i (cos x) x)".parse().unwrap(), + &["(sin x)".parse().unwrap()], + None, + true + ) +} + +fn integ_x() { + egg::test::test_runner( + "integ_x", + None, + &math::rules(), + "(i (pow x 1) x)".parse().unwrap(), + &["(/ (pow x 2) 2)".parse().unwrap()], + None, + true + ) +} + +fn integ_part1() { + egg::test::test_runner( + "integ_part1", + None, + &math::rules(), + "(i (* x (cos x)) x)".parse().unwrap(), + &["(+ (* x (sin x)) (cos x))".parse().unwrap()], + None, + true + ) +} + +fn integ_part2() { + egg::test::test_runner( + "integ_part2", + None, + &math::rules(), + "(i (* (cos x) x) x)".parse().unwrap(), + &["(+ (* x (sin x)) (cos x))".parse().unwrap()], + None, + true + ) +} + +fn integ_part3() { + egg::test::test_runner( + "integ_part3", + None, + &math::rules(), + "(i (ln x) x)".parse().unwrap(), + &["(- (* x (ln x)) x)".parse().unwrap()], + None, + true + ) +} + +pub fn ematching_benches(c: &mut Criterion) { + let exprs = &[ + "(i (ln x) x)", + "(i (+ x (cos x)) x)", + "(i (* (cos x) x) x)", + "(d x (+ 1 (* 2 x)))", + "(d x (- (pow x 3) (* 7 (pow x 2))))", + "(+ (* y (+ x y)) (- (+ x 2) (+ x x)))", + "(/ 1 (- (/ (+ 1 (sqrt five)) 2) (/ (- 1 (sqrt five)) 2)))", + ]; + + let extra_patterns = &[ + "(+ ?a (+ ?b ?c))", + "(+ (+ ?a ?b) ?c)", + "(* ?a (* ?b ?c))", + "(* (* ?a ?b) ?c)", + "(+ ?a (* -1 ?b))", + "(* ?a (pow ?b -1))", + "(* ?a (+ ?b ?c))", + "(pow ?a (+ ?b ?c))", + "(+ (* ?a ?b) (* ?a ?c))", + "(* (pow ?a ?b) (pow ?a ?c))", + "(* ?x (/ 1 ?x))", + "(d ?x (+ ?a ?b))", + "(+ (d ?x ?a) (d ?x ?b))", + "(d ?x (* ?a ?b))", + "(+ (* ?a (d ?x ?b)) (* ?b (d ?x ?a)))", + "(d ?x (sin ?x))", + "(d ?x (cos ?x))", + "(* -1 (sin ?x))", + "(* -1 (cos ?x))", + "(i (cos ?x) ?x)", + "(i (sin ?x) ?x)", + "(d ?x (ln ?x))", + "(d ?x (pow ?f ?g))", + "(* (pow ?f ?g) (+ (* (d ?x ?f) (/ ?g ?f)) (* (d ?x ?g) (ln ?f))))", + "(i (pow ?x ?c) ?x)", + "(/ (pow ?x (+ ?c 1)) (+ ?c 1))", + "(i (+ ?f ?g) ?x)", + "(i (- ?f ?g) ?x)", + "(+ (i ?f ?x) (i ?g ?x))", + "(- (i ?f ?x) (i ?g ?x))", + "(i (* ?a ?b) ?x)", + "(- (* ?a (i ?b ?x)) (i (* (d ?x ?a) (i ?b ?x)) ?x))", + ]; + + c.bench_function( + "ematching_benches", + |b| b.iter( + || egg::test::bench_egraph("math", math::rules(), exprs, extra_patterns) + ) + ); +} + +pub fn math_tests(c: &mut Criterion) { + let mut group = c.benchmark_group("math_tests"); + group.bench_function( + "math_associate_adds", + |b| b.iter(math_associate_adds) + ); + // group.bench_function( + // "math_fail", + // |b| b.iter(math_fail) + // ); + group.bench_function( + "math_simplify_add", + |b| b.iter(math_simplify_add) + ); + group.bench_function( + "math_powers", + |b| b.iter(math_powers) + ); + group.bench_function( + "math_simplify_const", + |b| b.iter(math_simplify_const) + ); + group.bench_function( + "math_simplify_root", + |b| b.iter(math_simplify_root) + ); + group.bench_function( + "math_simplify_factor", + |b| b.iter(math_simplify_factor) + ); + group.bench_function( + "math_diff_same", + |b| b.iter(math_diff_same) + ); + group.bench_function( + "math_diff_different", + |b| b.iter(math_diff_different) + ); + group.bench_function( + "math_diff_simple1", + |b| b.iter(math_diff_simple1) + ); + group.bench_function( + "math_diff_simple2", + |b| b.iter(math_diff_simple2) + ); + group.bench_function( + "math_diff_ln", + |b| b.iter(math_diff_ln) + ); + group.bench_function( + "diff_power_simple", + |b| b.iter(diff_power_simple) + ); + group.bench_function( + "diff_power_harder", + |b| b.iter(diff_power_harder) + ); + group.bench_function( + "integ_one", + |b| b.iter(integ_one) + ); + group.bench_function( + "integ_sin", + |b| b.iter(integ_sin) + ); + group.bench_function( + "integ_x", + |b| b.iter(integ_x) + ); + group.bench_function( + "integ_part1", + |b| b.iter(integ_part1) + ); + group.bench_function( + "integ_part2", + |b| b.iter(integ_part2) + ); + group.bench_function( + "integ_part3", + |b| b.iter(integ_part3) + ); + group.finish(); +} + +criterion_group!(benches, ematching_benches); +criterion_main!(benches); diff --git a/benches/schedulers.rs b/benches/schedulers.rs new file mode 100644 index 00000000..42a24df7 --- /dev/null +++ b/benches/schedulers.rs @@ -0,0 +1,92 @@ +pub mod schedulers { + use egg::*; + use rayon::prelude::*; + + pub struct SerialRewriteScheduler; + impl> RewriteScheduler for SerialRewriteScheduler { + fn search_rewrites<'a>( + &mut self, + iteration: usize, + egraph: &EGraph, + rewrites: &[&'a Rewrite], + limits: &RunnerLimits, + ) -> RunnerResult>>> { + rewrites + .iter() + .map(|rw| { + let ms = rw.search(egraph); + limits.check_limits(iteration, egraph)?; + Ok(ms) + }) + .collect() + } + } + + pub struct ParallelRewriteScheduler; + impl RewriteScheduler for ParallelRewriteScheduler + where + L: Language + Sync + Send, + L::Discriminant: Sync + Send, + N: Analysis + Sync + Send, + N::Data: Sync + Send + { + // impl RewriteScheduler for ParallelRewriteScheduler { + fn search_rewrites<'a>( + &mut self, + iteration: usize, + egraph: &EGraph, + rewrites: &[&'a Rewrite], + limits: &RunnerLimits, + ) -> RunnerResult>>> { + // This implementation just ignores the limits + // fake `par_map` to enforce Send + Sync, in real life use rayon + // fn par_map(slice: &[T], f: F) -> Vec + // where + // T: Send + Sync, + // F: Fn(&T) -> T2 + Send + Sync, + // T2: Send + Sync, + // { + // slice.iter().map(f).collect() + // } + // Ok(par_map(rewrites, |rw| rw.search(egraph))) + + rewrites + .par_iter() + .map(|rw| { + let ms = rw.search(egraph); + limits.check_limits(iteration, egraph)?; + Ok(ms) + }) + .collect() // ::>>>>() + + // TODO: Note that `Sync + Send` traits were added to both language and + // discriminant. Could this impact correctness? + } + } + + + pub struct RestrictedParallelRewriteScheduler; + impl RewriteScheduler for RestrictedParallelRewriteScheduler + where + L: Language + Sync + Send, + L::Discriminant: Sync + Send, + { + // impl RewriteScheduler for ParallelRewriteScheduler { + fn search_rewrites<'a>( + &mut self, + iteration: usize, + egraph: &EGraph, + rewrites: &[&'a Rewrite], + limits: &RunnerLimits, + ) -> RunnerResult>>> { + rewrites + .par_iter() + .map(|rw| { + let ms = rw.search(egraph); + limits.check_limits(iteration, egraph)?; + Ok(ms) + }) + .collect() + } + } +} diff --git a/benches/simple.rs b/benches/simple.rs new file mode 100644 index 00000000..f156b6d5 --- /dev/null +++ b/benches/simple.rs @@ -0,0 +1,121 @@ +use egg::*; +use rayon::prelude::*; + +mod definitions; +use definitions::simple; + +mod schedulers; +use schedulers::schedulers::*; + +use criterion::{criterion_group, criterion_main, Criterion, BenchmarkId}; + + +fn serial_simplify(s: &str) -> String { + let expr: RecExpr = s.parse().unwrap(); + let runner = Runner::default() + .with_scheduler(SerialRewriteScheduler) + .with_expr(&expr) + .run(&simple::make_rules()); + let root = runner.roots[0]; + let extractor = Extractor::new(&runner.egraph, AstSize); + let (_best_cost, best) = extractor.find_best(root); + best.to_string() +} + +fn parallel_simplify(s: &str) -> String { + let expr: RecExpr = s.parse().unwrap(); + let runner = Runner::default() + .with_scheduler(ParallelRewriteScheduler) + .with_expr(&expr) + .run(&simple::make_rules()); + let root = runner.roots[0]; + let extractor = Extractor::new(&runner.egraph, AstSize); + let (_best_cost, best) = extractor.find_best(root); + best.to_string() +} + +fn restricted_parallel_simplify(s: &str) -> String { + let expr: RecExpr = s.parse().unwrap(); + let runner = Runner::default() + .with_scheduler(RestrictedParallelRewriteScheduler) + .with_expr(&expr) + .run(&simple::make_rules()); + let root = runner.roots[0]; + let extractor = Extractor::new(&runner.egraph, AstSize); + let (_best_cost, best) = extractor.find_best(root); + best.to_string() +} + + +pub fn serial_simple_bench(c: &mut Criterion) { + c.bench_function( + "serial_simplify", + |b| b.iter( + || serial_simplify("(+ 0 (* 1 foo))") + ) + ); +} + +pub fn parallel_simple_bench(c: &mut Criterion) { + c.bench_function( + "parallel_simplify", + |b| b.iter( + || parallel_simplify("(+ 0 (* 1 foo))") + ) + ); +} + +pub fn restricted_parallel_simple_bench(c: &mut Criterion) { + c.bench_function( + "restricted_parallel_simplify", + |b| b.iter( + || restricted_parallel_simplify("(+ 0 (* 1 foo))") + ) + ); +} + +pub fn comparison_simple_bench(c: &mut Criterion) { + let mut group = c.benchmark_group("simplify"); + for i in simple::EXAMPLE_INPUTS.iter() { + group.bench_with_input(BenchmarkId::new("serial_simplify", i), i, + |b, i| b.iter(|| serial_simplify(*i))); + group.bench_with_input(BenchmarkId::new("parallel_simplify", i), i, + |b, i| b.iter(|| parallel_simplify(*i))); + group.bench_with_input(BenchmarkId::new("restricted_parallel_simplify", i), i, + |b, i| b.iter(|| restricted_parallel_simplify(*i))); + } + group.finish(); +} + + +// fn math_serial_simplify_root() { +// egg::test::test_runner( +// "math_simplify_root", +// Some(Runner::default().with_node_limit(75_000)), +// &math::rules(), +// r#" +// (/ 1 +// (- (/ (+ 1 (sqrt five)) +// 2) +// (/ (- 1 (sqrt five)) +// 2)))"#.parse().unwrap(), +// &["(/ 1 (sqrt five))".parse().unwrap()], +// None, +// true +// ) +// } + +// pub fn math_bench(c: &mut Criterion) { +// c.bench_function( +// "math_simplify_root", +// |b| b.iter(math_serial_simplify_root) +// ); +// //c.bench_function( +// // "math_simplify_factor", +// // |b| b.iter(math_simplify_factor) +// //); +// } + + +criterion_group!(benches, comparison_simple_bench); +criterion_main!(benches); diff --git a/src/egraph.rs b/src/egraph.rs index d3213bed..9f271d5c 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -1220,10 +1220,10 @@ impl> EGraph { /// Useful for testing. pub fn check_goals(&self, id: Id, goals: &[Pattern]) { let (cost, best) = Extractor::new(self, AstSize).find_best(id); - println!("End ({}): {}", cost, best.pretty(80)); + info!("End ({}): {}", cost, best.pretty(80)); for (i, goal) in goals.iter().enumerate() { - println!("Trying to prove goal {}: {}", i, goal.pretty(40)); + info!("Trying to prove goal {}: {}", i, goal.pretty(40)); let matches = goal.search_eclass(self, id); if matches.is_none() { let best = Extractor::new(self, AstSize).find_best(id).1; diff --git a/src/run.rs b/src/run.rs index fbd37b30..56ab62ac 100644 --- a/src/run.rs +++ b/src/run.rs @@ -566,10 +566,50 @@ where let search_time = start_time.elapsed().as_secs_f64(); info!("Search time: {}", search_time); - let apply_time = Instant::now(); + // // =================================================== // + // // ===== Novel algorithm with separated rebuilds ===== // + // // =================================================== // + // let apply_time = Instant::now(); + + // result = result.and_then(|_| { + // rules.iter().zip(matches).try_for_each(|(rw, ms)| { + // let total_matches: usize = ms.iter().map(|m| m.substs.len()).sum(); + // debug!("Applying {} {} times", rw.name, total_matches); + + // let actually_matched = self.scheduler.apply_rewrite(i, &mut self.egraph, rw, ms); + // if actually_matched > 0 { + // if let Some(count) = applied.get_mut(&rw.name) { + // *count += actually_matched; + // } else { + // applied.insert(rw.name.to_owned(), actually_matched); + // } + // debug!("Applied {} {} times", rw.name, actually_matched); + // } + // self.check_limits() + // }) + // }); + + // let apply_time = apply_time.elapsed().as_secs_f64(); + + // let rebuild_time = Instant::now(); + // let n_rebuilds = self.egraph.rebuild(); + // if self.egraph.are_explanations_enabled() { + // debug_assert!(self.egraph.check_each_explain(rules)); + // } + // let rebuild_time = rebuild_time.elapsed().as_secs_f64(); + // // =================================================== // + + + // // =================================================== // + // // ===== Old algorithm with interleaved rebuilds ===== // + // // =================================================== // + let mut apply_time = 0.0_f64; + let mut rebuild_time = 0.0_f64; + let mut n_rebuilds = 0; result = result.and_then(|_| { rules.iter().zip(matches).try_for_each(|(rw, ms)| { + let single_apply_time = Instant::now(); let total_matches: usize = ms.iter().map(|m| m.substs.len()).sum(); debug!("Applying {} {} times", rw.name, total_matches); @@ -582,20 +622,23 @@ where } debug!("Applied {} {} times", rw.name, actually_matched); } - self.check_limits() + let limits = self.check_limits(); + apply_time += single_apply_time.elapsed().as_secs_f64(); + + + let single_rebuild_time = Instant::now(); + n_rebuilds += self.egraph.rebuild(); + if self.egraph.are_explanations_enabled() { + debug_assert!(self.egraph.check_each_explain(rules)); + } + rebuild_time += single_rebuild_time.elapsed().as_secs_f64(); + + limits }) }); + // // =================================================== // - let apply_time = apply_time.elapsed().as_secs_f64(); info!("Apply time: {}", apply_time); - - let rebuild_time = Instant::now(); - let n_rebuilds = self.egraph.rebuild(); - if self.egraph.are_explanations_enabled() { - debug_assert!(self.egraph.check_each_explain(rules)); - } - - let rebuild_time = rebuild_time.elapsed().as_secs_f64(); info!("Rebuild time: {}", rebuild_time); info!( "Size: n={}, e={}", @@ -603,6 +646,7 @@ where self.egraph.number_of_classes() ); + let can_be_saturated = applied.is_empty() && self.scheduler.can_stop(i) // now make sure the hooks didn't do anything diff --git a/src/test.rs b/src/test.rs index 8565f79a..d7568874 100644 --- a/src/test.rs +++ b/src/test.rs @@ -6,6 +6,8 @@ These are not considered part of the public api. use num_traits::identities::Zero; use std::{fmt::Display, fs::File, io::Write, path::PathBuf}; +use log::*; + use crate::*; pub fn env_var(s: &str) -> Option @@ -80,7 +82,7 @@ pub fn test_runner( if should_check { let report = runner.report(); - println!("{report}"); + info!("{report}"); runner.egraph.check_goals(id, goals); if let Some(filename) = env_var::("EGG_BENCH_CSV") { @@ -89,7 +91,7 @@ pub fn test_runner( .append(true) .open(&filename) .unwrap_or_else(|_| panic!("Couldn't open {:?}", filename)); - writeln!(file, "{},{}", name, runner.report().total_time).unwrap(); + // writeln!(file, "{},{}", name, runner.report().total_time).unwrap(); } if runner.egraph.are_explanations_enabled() { @@ -153,7 +155,7 @@ where patterns.push(p.ast.alpha_rename().into()); } - eprintln!("{} patterns", patterns.len()); + info!("{} patterns", patterns.len()); patterns.retain(|p| p.ast.len() > 1); patterns.sort_by_key(|p| p.to_string()); @@ -164,8 +166,8 @@ where let node_limit = env_var("EGG_NODE_LIMIT").unwrap_or(1_000_000); let time_limit = env_var("EGG_TIME_LIMIT").unwrap_or(1000); let n_samples = env_var("EGG_SAMPLES").unwrap_or(100); - eprintln!("Benching {} samples", n_samples); - eprintln!( + info!("Benching {} samples", n_samples); + info!( "Limits: {} iters, {} nodes, {} seconds", iter_limit, node_limit, time_limit ); @@ -174,7 +176,7 @@ where .with_scheduler(SimpleScheduler) .with_hook(move |runner| { let n_nodes = runner.egraph.total_number_of_nodes(); - eprintln!("Iter {}, {} nodes", runner.iterations.len(), n_nodes); + info!("Iter {}, {} nodes", runner.iterations.len(), n_nodes); if n_nodes > node_limit { Err("Bench stopped".into()) } else { @@ -190,7 +192,7 @@ where } let runner = runner.run(&rules); - eprintln!("{}", runner.report()); + info!("{}", runner.report()); let egraph = runner.egraph; let get_len = |pat: &Pattern| pat.to_string().len(); @@ -207,7 +209,7 @@ where .collect(); times.sort_unstable(); - println!( + info!( "test {name:10} ns/iter (+/- {iqr})", name = pat.to_string().replace(' ', "_"), width = max_width, diff --git a/tests/prop.rs b/tests/prop.rs index ed1c7469..7aa5548d 100644 --- a/tests/prop.rs +++ b/tests/prop.rs @@ -1,5 +1,7 @@ use egg::*; +use log::*; + define_language! { enum Prop { Bool(bool), @@ -118,7 +120,7 @@ fn prove_something(name: &str, start: &str, rewrites: &[Rewrite], goals: &[&str] let egraph = runner.run(rewrites).egraph; for (i, (goal_expr, goal_str)) in goal_exprs.iter().zip(goals).enumerate() { - println!("Trying to prove goal {}: {}", i, goal_str); + info!("Trying to prove goal {}: {}", i, goal_str); let equivs = egraph.equivs(&start_expr, goal_expr); if equivs.is_empty() { panic!("Couldn't prove goal {}: {}", i, goal_str);