Optimised the classes field of EGraph to avoid extra hashing #284

Closed
wants to merge 2 commits into from

dewert99
Contributor

I noticed that the classes field of an EGraph is currently represented using a HashMap<Id, EClass> (or possibly an IndexMap). Since the set of possible Ids forms a contiguous range and the union_find field already contains an element for each Id, I thought I could try to use it like the "index" of an IndexMap. This works especially well because union_find currently maps non-canonical Ids to their parent, and we would only need it to additionally map canonical Ids to an index into the class list. Each element of the union_find then only needs an index into another list plus a tag saying whether or not the Id is canonical. I decided to use one bit of a u32 as the tag so that the union_find would stay the same size.

Hopefully, this will improve performance (do you have a setup for reliably benchmarking two commits and comparing them?)

The main downsides to this change are that EGraphs can now only support 2^31 Ids instead of 2^32, and that we now rely more on the EClass's id field, so if a user changes it (which is possible since it is public) it would likely break the EGraph (if it didn't already).
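To make the tagging scheme concrete, here is a minimal sketch of packing a "canonical or not" tag into the top bit of a u32. The names Elt, pack, unpack, ROOT_TAG and PAYLOAD_MASK are hypothetical and chosen for illustration; they are not taken from the PR, and the actual bit layout used there may differ.

```rust
/// Hypothetical packed union-find element: the top bit of a `u32` is the tag,
/// the low 31 bits are the payload.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Elt {
    /// Non-canonical id: the payload is the parent id.
    Parent(u32),
    /// Canonical id: the payload is the index into the class list.
    Root(u32),
}

const ROOT_TAG: u32 = 1 << 31;
const PAYLOAD_MASK: u32 = ROOT_TAG - 1;

fn pack(elt: Elt) -> u32 {
    match elt {
        Elt::Parent(p) => {
            assert!(p <= PAYLOAD_MASK, "id does not fit in 31 bits");
            p
        }
        Elt::Root(i) => {
            assert!(i <= PAYLOAD_MASK, "class index does not fit in 31 bits");
            i | ROOT_TAG
        }
    }
}

fn unpack(raw: u32) -> Elt {
    if raw & ROOT_TAG != 0 {
        Elt::Root(raw & PAYLOAD_MASK)
    } else {
        Elt::Parent(raw)
    }
}

fn main() {
    // Id 0 is canonical and owns class slot 0; id 1 was merged into id 0.
    let union_find: Vec<u32> = vec![pack(Elt::Root(0)), pack(Elt::Parent(0))];
    assert_eq!(unpack(union_find[0]), Elt::Root(0));
    assert_eq!(unpack(union_find[1]), Elt::Parent(0));
    // One bit is spent on the tag, so only 2^31 distinct payloads remain.
    assert_eq!(PAYLOAD_MASK as u64 + 1, 1u64 << 31);
}
```

Because the tag lives inside the same 32-bit word, the vector stays the same size as a plain Vec of u32 ids, which is the trade-off described above: no extra memory, at the price of halving the id space.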

@mwillsey
Member

Interesting idea, seems like it's probably faster but I'd want to make sure before merging something like this with non-trivial complexity.

You can run the built-in test suite and control various knobs as seen in the test.rs file. Something like

EGG_BENCH_CSV=out.csv cargo test --release --test=math -- --test-threads=1

You need test threads to be 1 to avoid concurrent writes to the CSV. You can do the same for the lambda benchmark suite.

Or you could try the e-matching benchmark suites:

EGG_ITER_LIMIT=5 cargo test --release -- --nocapture bench

Try tuning that limit up to make a more "significant" benchmark run.

@dewert99
Contributor Author

Ok I tried to do some benchmarking. Since the performance was quite variable between runs I made a script to run the benchmark 20 times and summarize the results.

use polars::prelude::*;
use std::fs::File;
use std::io::{BufWriter, Write};
use tempdir::TempDir;

fn quantile(e: Expr, n: f32) -> Expr {
    e.quantile(lit(n), QuantileInterpolOptions::Nearest)
}

fn main() {
    let samples = std::env::args().nth(1).unwrap().parse().unwrap();
    let args = [
        "test",
        "--release",
        "--test=math",
        "--test=lambda",
        "--",
        "--test-threads=1",
    ];
    let tmp_dir = TempDir::new("temp").unwrap();
    let tmp_path = tmp_dir.path().join("out.csv");
    let _ = File::create(&tmp_path).unwrap();
    let mut command = std::process::Command::new("cargo");
    command.args(args).env("EGG_BENCH_CSV", &tmp_path);
    for i in 0..samples {
        command.output().unwrap();
        print!("{i} ");
        std::io::stdout().flush().unwrap();
    }
    std::env::set_var("POLARS_FMT_MAX_COLS", "-1");
    std::env::set_var("POLARS_FMT_MAX_ROWS", "-1");
    let df = LazyCsvReader::new(tmp_path)
        .has_header(false)
        .finish()
        .unwrap();
    let df = df.rename(["column_1", "column_2"], ["test", "time"]);
    let df = df.group_by([col("test")]).agg([
        col("time").mean().alias("mean"),
        col("time").median().alias("median"),
        quantile(col("time"), 0.25).alias(".25"),
        quantile(col("time"), 0.75).alias(".75"),
        col("time").count().alias("count"),
    ]);
    let df = df.sort("test", SortOptions::default());
    let out_path = std::env::args().nth(2).unwrap();
    let mut writer = BufWriter::new(File::create(out_path).unwrap());
    write!(writer, "{}", df.collect().unwrap()).unwrap()
}

I alternated running the repo before and after the changes and here are the results:

Before 1

shape: (28, 6)
┌────────────────────────┬──────────┬──────────┬──────────┬──────────┬───────┐
│ test                   ┆ mean     ┆ median   ┆ .25      ┆ .75      ┆ count │
│ ---                    ┆ ---      ┆ ---      ┆ ---      ┆ ---      ┆ ---   │
│ str                    ┆ f64      ┆ f64      ┆ f64      ┆ f64      ┆ u32   │
╞════════════════════════╪══════════╪══════════╪══════════╪══════════╪═══════╡
│ diff_power_harder      ┆ 0.011519 ┆ 0.011347 ┆ 0.011303 ┆ 0.011447 ┆ 20    │
│ diff_power_simple      ┆ 0.001099 ┆ 0.001081 ┆ 0.001076 ┆ 0.001114 ┆ 20    │
│ integ_one              ┆ 0.000057 ┆ 0.000056 ┆ 0.000054 ┆ 0.000058 ┆ 20    │
│ integ_part1            ┆ 0.001768 ┆ 0.001748 ┆ 0.001733 ┆ 0.001788 ┆ 20    │
│ integ_part2            ┆ 0.006823 ┆ 0.006753 ┆ 0.006744 ┆ 0.0068   ┆ 20    │
│ integ_part3            ┆ 0.000712 ┆ 0.000702 ┆ 0.000699 ┆ 0.000721 ┆ 20    │
│ integ_sin              ┆ 0.000052 ┆ 0.000051 ┆ 0.000051 ┆ 0.000052 ┆ 20    │
│ integ_x                ┆ 0.00007  ┆ 0.00007  ┆ 0.000069 ┆ 0.00007  ┆ 20    │
│ lambda_compose         ┆ 0.001281 ┆ 0.001253 ┆ 0.001233 ┆ 0.001266 ┆ 20    │
│ lambda_compose_many    ┆ 0.005435 ┆ 0.00519  ┆ 0.005167 ┆ 0.005326 ┆ 20    │
│ lambda_fib             ┆ 0.812159 ┆ 0.758073 ┆ 0.751036 ┆ 0.763596 ┆ 20    │
│ lambda_function_repeat ┆ 1.745412 ┆ 1.704187 ┆ 1.691335 ┆ 1.744573 ┆ 20    │
│ lambda_if              ┆ 0.000538 ┆ 0.00052  ┆ 0.000518 ┆ 0.000541 ┆ 20    │
│ lambda_if_elim         ┆ 0.000189 ┆ 0.000182 ┆ 0.000182 ┆ 0.000198 ┆ 20    │
│ lambda_if_simple       ┆ 0.000022 ┆ 0.000022 ┆ 0.000021 ┆ 0.000023 ┆ 20    │
│ lambda_let_simple      ┆ 0.000134 ┆ 0.000131 ┆ 0.00013  ┆ 0.000132 ┆ 20    │
│ lambda_under           ┆ 0.000056 ┆ 0.000054 ┆ 0.000054 ┆ 0.000056 ┆ 20    │
│ math_associate_adds    ┆ 0.011395 ┆ 0.011341 ┆ 0.011314 ┆ 0.01141  ┆ 20    │
│ math_diff_different    ┆ 0.000059 ┆ 0.000059 ┆ 0.000058 ┆ 0.00006  ┆ 20    │
│ math_diff_ln           ┆ 0.000052 ┆ 0.000052 ┆ 0.000052 ┆ 0.000053 ┆ 20    │
│ math_diff_same         ┆ 0.000053 ┆ 0.000053 ┆ 0.000052 ┆ 0.000053 ┆ 20    │
│ math_diff_simple1      ┆ 0.000696 ┆ 0.000694 ┆ 0.000693 ┆ 0.000699 ┆ 20    │
│ math_diff_simple2      ┆ 0.000534 ┆ 0.000533 ┆ 0.000531 ┆ 0.000535 ┆ 20    │
│ math_powers            ┆ 0.000076 ┆ 0.000075 ┆ 0.000074 ┆ 0.000076 ┆ 20    │
│ math_simplify_add      ┆ 0.000402 ┆ 0.000397 ┆ 0.000396 ┆ 0.000401 ┆ 20    │
│ math_simplify_const    ┆ 0.000194 ┆ 0.000191 ┆ 0.00019  ┆ 0.000192 ┆ 20    │
│ math_simplify_factor   ┆ 0.001625 ┆ 0.001613 ┆ 0.001609 ┆ 0.001631 ┆ 20    │
│ math_simplify_root     ┆ 0.002327 ┆ 0.002307 ┆ 0.002298 ┆ 0.002326 ┆ 20    │
└────────────────────────┴──────────┴──────────┴──────────┴──────────┴───────┘

Before 2

shape: (28, 6)
┌────────────────────────┬──────────┬──────────┬───────────┬───────────┬───────┐
│ test                   ┆ mean     ┆ median   ┆ .25       ┆ .75       ┆ count │
│ ---                    ┆ ---      ┆ ---      ┆ ---       ┆ ---       ┆ ---   │
│ str                    ┆ f64      ┆ f64      ┆ f64       ┆ f64       ┆ u32   │
╞════════════════════════╪══════════╪══════════╪═══════════╪═══════════╪═══════╡
│ diff_power_harder      ┆ 0.011681 ┆ 0.011343 ┆ 0.01131   ┆ 0.011516  ┆ 20    │
│ diff_power_simple      ┆ 0.001121 ┆ 0.001085 ┆ 0.001079  ┆ 0.001159  ┆ 20    │
│ integ_one              ┆ 0.000057 ┆ 0.000057 ┆ 0.000055  ┆ 0.00006   ┆ 20    │
│ integ_part1            ┆ 0.001812 ┆ 0.001784 ┆ 0.001742  ┆ 0.001861  ┆ 20    │
│ integ_part2            ┆ 0.007002 ┆ 0.00681  ┆ 0.00674   ┆ 0.007219  ┆ 20    │
│ integ_part3            ┆ 0.000713 ┆ 0.000704 ┆ 0.0007    ┆ 0.000717  ┆ 20    │
│ integ_sin              ┆ 0.000052 ┆ 0.000051 ┆ 0.000051  ┆ 0.000052  ┆ 20    │
│ integ_x                ┆ 0.00007  ┆ 0.00007  ┆ 0.000069  ┆ 0.000071  ┆ 20    │
│ lambda_compose         ┆ 0.001275 ┆ 0.001251 ┆ 0.001238  ┆ 0.00134   ┆ 20    │
│ lambda_compose_many    ┆ 0.005307 ┆ 0.005211 ┆ 0.005178  ┆ 0.005591  ┆ 20    │
│ lambda_fib             ┆ 0.781245 ┆ 0.772962 ┆ 0.754739  ┆ 0.808954  ┆ 20    │
│ lambda_function_repeat ┆ 1.740313 ┆ 1.71898  ┆ 1.699867  ┆ 1.780055  ┆ 20    │
│ lambda_if              ┆ 0.00053  ┆ 0.000523 ┆ 0.000519  ┆ 0.000539  ┆ 20    │
│ lambda_if_elim         ┆ 0.000184 ┆ 0.000182 ┆ 0.00018   ┆ 0.000186  ┆ 20    │
│ lambda_if_simple       ┆ 0.000022 ┆ 0.000022 ┆ 0.000022  ┆ 0.0000226 ┆ 20    │
│ lambda_let_simple      ┆ 0.000133 ┆ 0.00013  ┆ 0.00013   ┆ 0.000131  ┆ 20    │
│ lambda_under           ┆ 0.000055 ┆ 0.000054 ┆ 0.000054  ┆ 0.000055  ┆ 20    │
│ math_associate_adds    ┆ 0.011478 ┆ 0.011367 ┆ 0.011329  ┆ 0.011536  ┆ 20    │
│ math_diff_different    ┆ 0.00006  ┆ 0.000059 ┆ 0.000059  ┆ 0.00006   ┆ 20    │
│ math_diff_ln           ┆ 0.000052 ┆ 0.000052 ┆ 0.0000515 ┆ 0.000053  ┆ 20    │
│ math_diff_same         ┆ 0.000053 ┆ 0.000053 ┆ 0.000052  ┆ 0.000053  ┆ 20    │
│ math_diff_simple1      ┆ 0.000702 ┆ 0.000695 ┆ 0.000691  ┆ 0.000706  ┆ 20    │
│ math_diff_simple2      ┆ 0.000536 ┆ 0.000532 ┆ 0.000531  ┆ 0.000537  ┆ 20    │
│ math_powers            ┆ 0.000075 ┆ 0.000074 ┆ 0.000073  ┆ 0.000075  ┆ 20    │
│ math_simplify_add      ┆ 0.000402 ┆ 0.000397 ┆ 0.000396  ┆ 0.000411  ┆ 20    │
│ math_simplify_const    ┆ 0.000193 ┆ 0.000191 ┆ 0.000189  ┆ 0.000194  ┆ 20    │
│ math_simplify_factor   ┆ 0.001632 ┆ 0.001622 ┆ 0.00161   ┆ 0.001633  ┆ 20    │
│ math_simplify_root     ┆ 0.002337 ┆ 0.00231  ┆ 0.002302  ┆ 0.002319  ┆ 20    │
└────────────────────────┴──────────┴──────────┴───────────┴───────────┴───────┘

After 1

shape: (28, 6)
┌────────────────────────┬───────────┬──────────┬──────────┬──────────┬───────┐
│ test                   ┆ mean      ┆ median   ┆ .25      ┆ .75      ┆ count │
│ ---                    ┆ ---       ┆ ---      ┆ ---      ┆ ---      ┆ ---   │
│ str                    ┆ f64       ┆ f64      ┆ f64      ┆ f64      ┆ u32   │
╞════════════════════════╪═══════════╪══════════╪══════════╪══════════╪═══════╡
│ diff_power_harder      ┆ 0.01138   ┆ 0.011162 ┆ 0.011012 ┆ 0.011228 ┆ 20    │
│ diff_power_simple      ┆ 0.001069  ┆ 0.001054 ┆ 0.001045 ┆ 0.00109  ┆ 20    │
│ integ_one              ┆ 0.0000553 ┆ 0.000054 ┆ 0.000054 ┆ 0.000056 ┆ 20    │
│ integ_part1            ┆ 0.001719  ┆ 0.001692 ┆ 0.001686 ┆ 0.001721 ┆ 20    │
│ integ_part2            ┆ 0.006611  ┆ 0.006578 ┆ 0.006545 ┆ 0.006614 ┆ 20    │
│ integ_part3            ┆ 0.000689  ┆ 0.000685 ┆ 0.000682 ┆ 0.000696 ┆ 20    │
│ integ_sin              ┆ 0.000054  ┆ 0.000051 ┆ 0.000051 ┆ 0.000052 ┆ 20    │
│ integ_x                ┆ 0.00007   ┆ 0.00007  ┆ 0.00007  ┆ 0.000071 ┆ 20    │
│ lambda_compose         ┆ 0.001176  ┆ 0.001184 ┆ 0.001179 ┆ 0.001194 ┆ 20    │
│ lambda_compose_many    ┆ 0.004883  ┆ 0.004925 ┆ 0.004898 ┆ 0.004956 ┆ 20    │
│ lambda_fib             ┆ 0.690429  ┆ 0.680668 ┆ 0.672984 ┆ 0.696083 ┆ 20    │
│ lambda_function_repeat ┆ 1.565192  ┆ 1.560691 ┆ 1.550662 ┆ 1.576227 ┆ 20    │
│ lambda_if              ┆ 0.000506  ┆ 0.0005   ┆ 0.000499 ┆ 0.000518 ┆ 20    │
│ lambda_if_elim         ┆ 0.000181  ┆ 0.000178 ┆ 0.000176 ┆ 0.000183 ┆ 20    │
│ lambda_if_simple       ┆ 0.000021  ┆ 0.000021 ┆ 0.000021 ┆ 0.000022 ┆ 20    │
│ lambda_let_simple      ┆ 0.000129  ┆ 0.000126 ┆ 0.000126 ┆ 0.00013  ┆ 20    │
│ lambda_under           ┆ 0.000054  ┆ 0.000054 ┆ 0.000053 ┆ 0.000056 ┆ 20    │
│ math_associate_adds    ┆ 0.012375  ┆ 0.012193 ┆ 0.012174 ┆ 0.012234 ┆ 20    │
│ math_diff_different    ┆ 0.000061  ┆ 0.00006  ┆ 0.00006  ┆ 0.000061 ┆ 20    │
│ math_diff_ln           ┆ 0.000055  ┆ 0.000053 ┆ 0.000052 ┆ 0.000054 ┆ 20    │
│ math_diff_same         ┆ 0.000054  ┆ 0.000053 ┆ 0.000053 ┆ 0.000054 ┆ 20    │
│ math_diff_simple1      ┆ 0.000708  ┆ 0.000685 ┆ 0.000684 ┆ 0.000703 ┆ 20    │
│ math_diff_simple2      ┆ 0.000532  ┆ 0.000523 ┆ 0.000521 ┆ 0.000527 ┆ 20    │
│ math_powers            ┆ 0.000076  ┆ 0.000073 ┆ 0.000073 ┆ 0.000074 ┆ 20    │
│ math_simplify_add      ┆ 0.000407  ┆ 0.000389 ┆ 0.000388 ┆ 0.00039  ┆ 20    │
│ math_simplify_const    ┆ 0.000191  ┆ 0.000187 ┆ 0.000187 ┆ 0.00019  ┆ 20    │
│ math_simplify_factor   ┆ 0.001624  ┆ 0.001591 ┆ 0.001577 ┆ 0.001617 ┆ 20    │
│ math_simplify_root     ┆ 0.002287  ┆ 0.002249 ┆ 0.002245 ┆ 0.002265 ┆ 20    │
└────────────────────────┴───────────┴──────────┴──────────┴──────────┴───────┘

After 2

shape: (28, 6)
┌────────────────────────┬──────────┬──────────┬──────────┬──────────┬───────┐
│ test                   ┆ mean     ┆ median   ┆ .25      ┆ .75      ┆ count │
│ ---                    ┆ ---      ┆ ---      ┆ ---      ┆ ---      ┆ ---   │
│ str                    ┆ f64      ┆ f64      ┆ f64      ┆ f64      ┆ u32   │
╞════════════════════════╪══════════╪══════════╪══════════╪══════════╪═══════╡
│ diff_power_harder      ┆ 0.011573 ┆ 0.0112   ┆ 0.011157 ┆ 0.011272 ┆ 20    │
│ diff_power_simple      ┆ 0.001084 ┆ 0.001059 ┆ 0.001046 ┆ 0.001143 ┆ 20    │
│ integ_one              ┆ 0.000056 ┆ 0.000055 ┆ 0.000054 ┆ 0.000058 ┆ 20    │
│ integ_part1            ┆ 0.001742 ┆ 0.001706 ┆ 0.0017   ┆ 0.00181  ┆ 20    │
│ integ_part2            ┆ 0.006698 ┆ 0.006569 ┆ 0.00654  ┆ 0.006827 ┆ 20    │
│ integ_part3            ┆ 0.000693 ┆ 0.000685 ┆ 0.000683 ┆ 0.000697 ┆ 20    │
│ integ_sin              ┆ 0.000051 ┆ 0.000051 ┆ 0.000051 ┆ 0.000052 ┆ 20    │
│ integ_x                ┆ 0.000072 ┆ 0.000071 ┆ 0.00007  ┆ 0.000071 ┆ 20    │
│ lambda_compose         ┆ 0.001221 ┆ 0.001188 ┆ 0.001183 ┆ 0.00129  ┆ 20    │
│ lambda_compose_many    ┆ 0.005066 ┆ 0.004928 ┆ 0.004906 ┆ 0.005319 ┆ 20    │
│ lambda_fib             ┆ 0.725474 ┆ 0.735219 ┆ 0.683442 ┆ 0.757344 ┆ 20    │
│ lambda_function_repeat ┆ 1.615716 ┆ 1.603787 ┆ 1.574573 ┆ 1.653676 ┆ 20    │
│ lambda_if              ┆ 0.0005   ┆ 0.000497 ┆ 0.000496 ┆ 0.000502 ┆ 20    │
│ lambda_if_elim         ┆ 0.000279 ┆ 0.000178 ┆ 0.000175 ┆ 0.00018  ┆ 20    │
│ lambda_if_simple       ┆ 0.000021 ┆ 0.000021 ┆ 0.000021 ┆ 0.000022 ┆ 20    │
│ lambda_let_simple      ┆ 0.000127 ┆ 0.000127 ┆ 0.000126 ┆ 0.000128 ┆ 20    │
│ lambda_under           ┆ 0.000054 ┆ 0.000053 ┆ 0.000053 ┆ 0.000054 ┆ 20    │
│ math_associate_adds    ┆ 0.012412 ┆ 0.012272 ┆ 0.012256 ┆ 0.01236  ┆ 20    │
│ math_diff_different    ┆ 0.000061 ┆ 0.00006  ┆ 0.000059 ┆ 0.000061 ┆ 20    │
│ math_diff_ln           ┆ 0.000053 ┆ 0.000052 ┆ 0.000052 ┆ 0.000053 ┆ 20    │
│ math_diff_same         ┆ 0.000054 ┆ 0.000053 ┆ 0.000053 ┆ 0.000054 ┆ 20    │
│ math_diff_simple1      ┆ 0.000698 ┆ 0.000684 ┆ 0.000682 ┆ 0.000699 ┆ 20    │
│ math_diff_simple2      ┆ 0.000538 ┆ 0.000525 ┆ 0.000523 ┆ 0.000543 ┆ 20    │
│ math_powers            ┆ 0.000073 ┆ 0.000073 ┆ 0.000072 ┆ 0.000074 ┆ 20    │
│ math_simplify_add      ┆ 0.000397 ┆ 0.00039  ┆ 0.000387 ┆ 0.000392 ┆ 20    │
│ math_simplify_const    ┆ 0.000188 ┆ 0.000187 ┆ 0.000186 ┆ 0.000189 ┆ 20    │
│ math_simplify_factor   ┆ 0.001606 ┆ 0.001584 ┆ 0.001577 ┆ 0.001623 ┆ 20    │
│ math_simplify_root     ┆ 0.002271 ┆ 0.002248 ┆ 0.002239 ┆ 0.002265 ┆ 20    │
└────────────────────────┴──────────┴──────────┴──────────┴──────────┴───────┘
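One rough way to read the four tables is to compare medians for a few rows. The sketch below copies the median values from the tables above, averages the two "before" runs and the two "after" runs, and prints the percentage change. Averaging two medians is a simplifying assumption made here for illustration; it says nothing about statistical significance, and the helper names percent_change and mean2 are hypothetical.

```rust
/// Percentage change from `before` to `after`; negative means faster.
fn percent_change(before: f64, after: f64) -> f64 {
    (after - before) / before * 100.0
}

/// Plain average of two runs.
fn mean2(a: f64, b: f64) -> f64 {
    (a + b) / 2.0
}

fn main() {
    // (test, median Before 1, median Before 2, median After 1, median After 2)
    let rows = [
        ("lambda_fib", 0.758073, 0.772962, 0.680668, 0.735219),
        ("lambda_function_repeat", 1.704187, 1.71898, 1.560691, 1.603787),
        ("diff_power_harder", 0.011347, 0.011343, 0.011162, 0.0112),
        ("math_associate_adds", 0.011341, 0.011367, 0.012193, 0.012272),
    ];
    for (name, b1, b2, a1, a2) in rows {
        let change = percent_change(mean2(b1, b2), mean2(a1, a2));
        println!("{name:<24} {change:+.1}%");
    }
}
```

On these rows the lambda tests get faster, while math_associate_adds gets slower in both "after" runs, so the effect is not uniform across the suite.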

I'm not sure if you have a more stable way of benchmarking? It might also be interesting to benchmark under PGO or BOLT to see if they affect one version more than the other.

I also changed the implementation of UnionFind::find_mut to:

pub(super) fn find_mut_full(&mut self, mut current: Id) -> (Id, ClassId) {
    let canon = self.find(current);
    loop {
        match self.parent(current) {
            UnionFindElt::Parent(parent) => {
                self.set_parent(current, canon);
                current = parent;
            }
            UnionFindElt::Root(cid) => {
                debug_assert!(current == canon);
                return (current, cid);
            }
        }
    }
}

I did this because one advantage of the path-halving strategy was that the parent of a root was the root itself, which isn't true anymore.

A strategy closer to the original path halving would be:

pub(super) fn find_mut_full(&mut self, mut current: Id) -> (Id, ClassId) {
    loop {
        match self.parent(current) {
            UnionFindElt::Parent(parent) => {
                match self.parent(parent) {
                    // Path halving: skip over the parent and continue from the grandparent.
                    UnionFindElt::Parent(grand_parent) => {
                        self.set_parent(current, grand_parent);
                        current = grand_parent;
                    }
                    // The parent is the root, so there is nothing left to shorten.
                    UnionFindElt::Root(cid) => return (parent, cid),
                }
            }
            UnionFindElt::Root(cid) => return (current, cid),
        }
    }
}

Another strategy I thought of was:

pub(super) fn find_mut_full(&mut self, mut current: Id) -> (Id, ClassId) {
    // First pass: walk to the root, counting the parent links followed.
    let mut count = 0u32;
    let mut old_current = current;
    let (canon, cid) = loop {
        match self.parent(current) {
            UnionFindElt::Parent(parent) => {
                count += 1;
                current = parent;
            }
            UnionFindElt::Root(cid) => break (current, cid),
        }
    };
    // Second pass: repoint every node that is more than one link from the root.
    while count > 1 {
        let next = self.parents[usize::from(old_current)].as_parent();
        count -= 1;
        self.set_parent(old_current, canon);
        old_current = next;
    }
    (canon, cid)
}

The idea is that, most of the time, the second loop won't be needed.

I don't think my benchmarking is precise enough to tell the difference, but if anyone wants to try comparing these they can.
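For anyone who wants to experiment, the three strategies can be tried side by side on a toy structure. The sketch below is a self-contained approximation under stated assumptions: ToyUnionFind, Elt, chain and the find_* method names are hypothetical, elements are stored as a plain enum rather than a packed u32, and indices are usize instead of egg's Id. It only checks that the strategies agree on the result and shows how each one reshapes the path; it is not a benchmark.

```rust
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Elt {
    Parent(usize), // non-canonical: index of the parent
    Root(u32),     // canonical: hypothetical class index
}

struct ToyUnionFind {
    parents: Vec<Elt>,
}

impl ToyUnionFind {
    /// Two-pass full compression: find the root, then repoint the whole path.
    fn find_full(&mut self, mut current: usize) -> (usize, u32) {
        let mut canon = current;
        while let Elt::Parent(p) = self.parents[canon] {
            canon = p;
        }
        loop {
            match self.parents[current] {
                Elt::Parent(parent) => {
                    self.parents[current] = Elt::Parent(canon);
                    current = parent;
                }
                Elt::Root(cid) => return (current, cid),
            }
        }
    }

    /// Path halving: repoint each visited node to its grandparent.
    fn find_halving(&mut self, mut current: usize) -> (usize, u32) {
        loop {
            match self.parents[current] {
                Elt::Parent(parent) => match self.parents[parent] {
                    Elt::Parent(grand_parent) => {
                        self.parents[current] = Elt::Parent(grand_parent);
                        current = grand_parent;
                    }
                    Elt::Root(cid) => return (parent, cid),
                },
                Elt::Root(cid) => return (current, cid),
            }
        }
    }

    /// Count the links to the root, then repoint nodes more than one link away.
    fn find_counted(&mut self, start: usize) -> (usize, u32) {
        let (mut current, mut count) = (start, 0u32);
        let (canon, cid) = loop {
            match self.parents[current] {
                Elt::Parent(parent) => {
                    count += 1;
                    current = parent;
                }
                Elt::Root(cid) => break (current, cid),
            }
        };
        let mut node = start;
        while count > 1 {
            let next = match self.parents[node] {
                Elt::Parent(p) => p,
                Elt::Root(_) => unreachable!("count > 1 implies a parent link"),
            };
            self.parents[node] = Elt::Parent(canon);
            node = next;
            count -= 1;
        }
        (canon, cid)
    }
}

/// Builds the chain 0 <- 1 <- ... <- n-1, where 0 is the root of class 7.
fn chain(n: usize) -> ToyUnionFind {
    let mut parents = vec![Elt::Root(7)];
    parents.extend((1..n).map(|i| Elt::Parent(i - 1)));
    ToyUnionFind { parents }
}

fn main() {
    let (mut full, mut halving, mut counted) = (chain(6), chain(6), chain(6));
    // All strategies must agree on the canonical id and class index.
    assert_eq!(full.find_full(5), (0, 7));
    assert_eq!(halving.find_halving(5), (0, 7));
    assert_eq!(counted.find_counted(5), (0, 7));
    // Full and counted compression leave every node pointing at the root.
    assert!(full.parents[1..].iter().all(|e| *e == Elt::Parent(0)));
    assert!(counted.parents[1..].iter().all(|e| *e == Elt::Parent(0)));
    // Halving only shortens the path: 5 -> 3 -> 1 -> 0.
    assert_eq!(halving.parents[5], Elt::Parent(3));
    assert_eq!(halving.parents[3], Elt::Parent(1));
}
```

On paths of length zero or one, which are likely the common case after earlier compressions, the counted variant never enters its second loop, which matches the intuition behind that strategy.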

@mwillsey
Member

Nice work! Yes, this is sort of what I expected. I think the benefit of this change isn't worth the added complexity, since performance is so dominated by other factors. So I think we should close this for now unless/until a need is demonstrated.
