Skip to content

Commit

Permalink
A better colored union find for parallel runs
Browse files Browse the repository at this point in the history
  • Loading branch information
eytans committed Sep 15, 2024
1 parent afabcfc commit 35da4f4
Show file tree
Hide file tree
Showing 6 changed files with 250 additions and 64 deletions.
69 changes: 45 additions & 24 deletions .idea/workspace.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 7 additions & 6 deletions src/colors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use itertools::Itertools;
use std::fmt::Formatter;
use indexmap::{IndexMap, IndexSet};
use crate::colored_union_find::ColoredUnionFind;

Check warning on line 9 in src/colors.rs

View workflow job for this annotation

GitHub Actions / clippy

unused import: `crate::colored_union_find::ColoredUnionFind`

warning: unused import: `crate::colored_union_find::ColoredUnionFind` --> src/colors.rs:9:5 | 9 | use crate::colored_union_find::ColoredUnionFind; | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | = note: `#[warn(unused_imports)]` on by default
use crate::unionfind::{SimpleUnionFind, UnionFindWrapper};

Check warning on line 10 in src/colors.rs

View workflow job for this annotation

GitHub Actions / clippy

unused import: `SimpleUnionFind`

warning: unused import: `SimpleUnionFind` --> src/colors.rs:10:24 | 10 | use crate::unionfind::{SimpleUnionFind, UnionFindWrapper}; | ^^^^^^^^^^^^^^^

global_counter!(COLOR_IDS, usize, usize::default());

Expand All @@ -19,7 +20,7 @@ pub struct Color<L: Language, N: Analysis<L>> {
pub(crate) equality_classes: IndexMap<Id, IndexSet<Id>>,
/// Used to implement a union find. Opposite function of `equality_classes`.
/// Supports removal of elements when they are not needed.
union_find: ColoredUnionFind,
union_find: UnionFindWrapper<(), Id>,
/// Used to determine for each a colored equality class what is the black colored class.
/// Relevant when a colored edge was added.
pub(crate) black_colored_classes: IndexMap<Id, Id>,
Expand Down Expand Up @@ -74,7 +75,7 @@ impl<L: Language, N: Analysis<L>> Color<L, N> {
pub(crate) fn verify_uf_minimal(&self, egraph: &EGraph<L, N>) {
let mut parents: IndexMap<Id, usize> = IndexMap::default();
for (k, _v) in self.union_find.iter() {
let v = self.find(egraph, k);
let v = self.find(egraph, *k);
*parents.entry(v).or_default() += 1;
}
for (k, v) in parents {
Expand Down Expand Up @@ -153,8 +154,8 @@ impl<L: Language, N: Analysis<L>> Color<L, N> {
// This part only needs to happen if one of the two is in the union find.
let orig_to = orig_to.unwrap_or(base_to);
let orig_from = orig_from.unwrap_or(base_from);
self.union_find.insert(orig_to);
self.union_find.insert(orig_from);
self.union_find.insert(orig_to, ());
self.union_find.insert(orig_from, ());
self.union_find.union(&orig_to, &orig_from).unwrap()
} else {
(base_to, base_from)
Expand Down Expand Up @@ -195,8 +196,8 @@ impl<L: Language, N: Analysis<L>> Color<L, N> {
// Assumed id1 and id2 are parent canonized
pub(crate) fn inner_colored_union(&mut self, id1: Id, id2: Id) -> (Id, Id, bool, Vec<(Id, Id)>) {
// Parent classes will be updated in black union to come.
self.union_find.insert(id1);
self.union_find.insert(id2);
self.union_find.insert(id1, ());
self.union_find.insert(id2, ());
let (to, from) = self.union_find.union(&id1, &id2).unwrap();
let changed = to != from;
let g_todo = self.update_black_classes(to, from).into_iter().collect_vec();
Expand Down
7 changes: 4 additions & 3 deletions src/egraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@ use crate::Pattern;
use crate::RecExpr;
use crate::Searcher;
use crate::Subst;
use crate::UnionFind;
use crate::SimpleUnionFind;
use crate::{OpId, SymbolLang};

pub use crate::colors::{Color, ColorId};
use itertools::Itertools;
use multimap::MultiMap;
use serde::{Deserialize, Serialize};
use crate::unionfind::UnionFind;
use crate::util::UniqueQueue;

/** A data structure to keep track of equalities between expressions.
Expand Down Expand Up @@ -195,7 +196,7 @@ pub struct EGraph<L: Language, N: Analysis<L>> {
/// The `Analysis` given when creating this `EGraph`.
pub analysis: N,
pub(crate) memo: IndexMap<L, Id>,
unionfind: UnionFind,
unionfind: SimpleUnionFind,
classes: SparseVec<EClass<L, N::Data>>,
/// Nodes which need to be processed for rebuilding. The `Id` is the `Id` of the enode,
/// not the canonical id of the eclass.
Expand Down Expand Up @@ -248,7 +249,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {

impl<L: Language, N: Analysis<L>> EGraph<L, N> {
pub(crate) fn inner_new(
uf: UnionFind,
uf: SimpleUnionFind,
classes: Vec<Option<Box<EClass<SymbolLang, ()>>>>,
memo: IndexMap<SymbolLang, Id>,
) -> EGraph<SymbolLang, ()> {
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ impl std::fmt::Display for ColorId {
}
}

pub(crate) use unionfind::UnionFind;
pub(crate) use unionfind::SimpleUnionFind;

pub use {
dot::Dot,
Expand Down
4 changes: 2 additions & 2 deletions src/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::io::{Read, Write, BufReader, Result, BufRead};
use indexmap::IndexMap;
use itertools::Itertools;
use crate::{EGraph, Analysis, Language, Id, EClass, SymbolLang, ColorId};
use crate::unionfind::UnionFind;
use crate::unionfind::SimpleUnionFind;

/// A trait for EGraphs that can be serialized.
pub trait Serialization {
Expand Down Expand Up @@ -149,7 +149,7 @@ impl Deserialization for EGraph<SymbolLang, ()> {

#[derive(Debug, Clone, Default)]
struct EGraphBuilder {
unionfind: UnionFind,
unionfind: SimpleUnionFind,
classes: Vec<Option<Box<EClass<SymbolLang, ()>>>>,
memo: IndexMap<SymbolLang, Id>,
_palette: ColorPalette,
Expand Down
Loading

0 comments on commit 35da4f4

Please sign in to comment.