Skip to content

Commit

Permalink
Added push/pop
Browse files Browse the repository at this point in the history
  • Loading branch information
dewert99 committed Jan 2, 2024
1 parent 812c76c commit 1f5a954
Show file tree
Hide file tree
Showing 11 changed files with 486 additions and 24 deletions.
40 changes: 33 additions & 7 deletions src/egraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::{
#[cfg(feature = "serde-1")]
use serde::{Deserialize, Serialize};

use crate::semi_persistent::UndoLogT;
use log::*;

/** A data structure to keep track of equalities between expressions.
Expand Down Expand Up @@ -56,16 +57,24 @@ pub struct EGraph<L: Language, N: Analysis<L>> {
pub analysis: N,
/// The `Explain` used to explain equivalences in this `EGraph`.
pub(crate) explain: Option<Explain<L>>,
unionfind: UnionFind,
#[cfg_attr(
feature = "serde-1",
serde(bound(
serialize = "N::UndoLog: Serialize",
deserialize = "N::UndoLog: for<'a> Deserialize<'a>",
))
)]
pub(crate) undo_log: N::UndoLog,
pub(crate) unionfind: UnionFind,
/// Stores each enode's `Id`, not the `Id` of the eclass.
/// Enodes in the memo are canonicalized at each rebuild, but after rebuilding new
/// unions can cause them to become out of date.
#[cfg_attr(feature = "serde-1", serde(with = "vectorize"))]
memo: HashMap<L, Id>,
pub(crate) memo: HashMap<L, Id>,
/// Nodes which need to be processed for rebuilding. The `Id` is the `Id` of the enode,
/// not the canonical id of the eclass.
pending: Vec<(L, Id)>,
analysis_pending: UniqueQueue<(L, Id)>,
pub(crate) pending: Vec<(L, Id)>,
pub(crate) analysis_pending: UniqueQueue<(L, Id)>,
#[cfg_attr(
feature = "serde-1",
serde(bound(
Expand Down Expand Up @@ -103,6 +112,8 @@ impl<L: Language, N: Analysis<L>> Debug for EGraph<L, N> {
f.debug_struct("EGraph")
.field("memo", &self.memo)
.field("classes", &self.classes)
.field("undo_log", &self.undo_log)
.field("explain", &self.explain)
.finish()
}
}
Expand All @@ -120,6 +131,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
memo: Default::default(),
analysis_pending: Default::default(),
classes_by_op: Default::default(),
undo_log: Default::default(),
}
}

Expand Down Expand Up @@ -769,9 +781,12 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
*existing_explain
} else {
let new_id = self.unionfind.make_set();
self.undo_log.add_node(&original, new_id);
explain.add(original, new_id, new_id);
self.unionfind.union(id, new_id);
self.undo_log.union(id, new_id);
explain.union(existing_id, new_id, Justification::Congruence, true);
self.undo_log.union_explain(existing_id, new_id);
new_id
}
} else {
Expand All @@ -780,6 +795,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
} else {
let id = self.make_new_eclass(enode);
if let Some(explain) = self.explain.as_mut() {
self.undo_log.add_node(&original, id);
explain.add(original, id, id);
}

Expand Down Expand Up @@ -811,7 +827,8 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
self.pending.push((enode.clone(), id));

self.classes.insert(id, class);
assert!(self.memo.insert(enode, id).is_none());
let old = self.undo_log.modify_memo(&mut self.memo, enode, Some(id));
assert!(old.is_none());

id
}
Expand Down Expand Up @@ -919,7 +936,12 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
if id1 == id2 {
if let Some(Justification::Rule(_)) = rule {
if let Some(explain) = &mut self.explain {
explain.alternate_rewrite(enode_id1, enode_id2, rule.unwrap());
explain.alternate_rewrite(
enode_id1,
enode_id2,
rule.unwrap(),
&mut self.undo_log,
);
}
}
return false;
Expand All @@ -933,10 +955,12 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {

if let Some(explain) = &mut self.explain {
explain.union(enode_id1, enode_id2, rule.unwrap(), any_new_rhs);
self.undo_log.union_explain(enode_id1, enode_id2);
}

// make id1 the new root
self.unionfind.union(id1, id2);
self.undo_log.union(id1, id2);

assert_ne!(id1, id2);
let class2 = self.classes.remove(&id2).unwrap();
Expand Down Expand Up @@ -1105,7 +1129,9 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
while !self.pending.is_empty() || !self.analysis_pending.is_empty() {
while let Some((mut node, class)) = self.pending.pop() {
node.update_children(|id| self.find_mut(id));
if let Some(memo_class) = self.memo.insert(node, class) {
if let Some(memo_class) =
self.undo_log.modify_memo(&mut self.memo, node, Some(class))
{
let did_something = self.perform_union(
memo_class,
class,
Expand Down
46 changes: 30 additions & 16 deletions src/explain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use std::collections::{BinaryHeap, VecDeque};
use std::fmt::{self, Debug, Display, Formatter};
use std::rc::Rc;

use crate::semi_persistent::UndoLogT;
use symbolic_expressions::Sexp;

type ProofCost = Saturating<usize>;
Expand All @@ -29,32 +30,43 @@ pub enum Justification {

#[derive(Debug, Clone, Hash, PartialEq, Eq)]
#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))]
struct Connection {
next: Id,
pub(crate) struct Connection {
pub(crate) next: Id,
current: Id,
justification: Justification,
is_rewrite_forward: bool,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))]
struct ExplainNode<L: Language> {
node: L,
pub(crate) struct ExplainNode<L: Language> {
pub(crate) node: L,
// neighbors includes parent connections
neighbors: Vec<Connection>,
parent_connection: Connection,
pub(crate) neighbors: Vec<Connection>,
pub(crate) parent_connection: Connection,
// it was inserted because of:
// 1) it's parent is inserted (points to parent enode)
// 2) a rewrite instantiated it (points to adjacent enode)
// 3) it was inserted directly (points to itself)
// if 1 is true but it's also adjacent (2) then either works and it picks 2
existance_node: Id,
pub(crate) existance_node: Id,
}

impl Connection {
pub(crate) fn dummy(set: Id) -> Self {
Connection {
justification: Justification::Congruence,
is_rewrite_forward: false,
next: set,
current: set,
}
}
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))]
pub struct Explain<L: Language> {
explainfind: Vec<ExplainNode<L>>,
pub(crate) explainfind: Vec<ExplainNode<L>>,
#[cfg_attr(feature = "serde-1", serde(with = "vectorize"))]
pub uncanon_memo: HashMap<L, Id>,
/// By default, egg uses a greedy algorithm to find shorter explanations when they are extracted.
Expand All @@ -66,7 +78,7 @@ pub struct Explain<L: Language> {
// Invariant: The distance is always <= the unoptimized distance
// That is, less than or equal to the result of `distance_between`
#[cfg_attr(feature = "serde-1", serde(skip))]
shortest_explanation_memo: HashMap<(Id, Id), (ProofCost, Id)>,
pub(crate) shortest_explanation_memo: HashMap<(Id, Id), (ProofCost, Id)>,
}

#[derive(Default)]
Expand Down Expand Up @@ -1048,12 +1060,7 @@ impl<L: Language> Explain<L> {
self.explainfind.push(ExplainNode {
node,
neighbors: vec![],
parent_connection: Connection {
justification: Justification::Congruence,
is_rewrite_forward: false,
next: set,
current: set,
},
parent_connection: Connection::dummy(set),
existance_node,
});
set
Expand All @@ -1075,7 +1082,13 @@ impl<L: Language> Explain<L> {
}
}

pub(crate) fn alternate_rewrite(&mut self, node1: Id, node2: Id, justification: Justification) {
pub(crate) fn alternate_rewrite(
&mut self,
node1: Id,
node2: Id,
justification: Justification,
undo: &mut impl UndoLogT<L>,
) {
if node1 == node2 {
return;
}
Expand All @@ -1084,6 +1097,7 @@ impl<L: Language> Explain<L> {
return;
}
}
undo.union_explain(node1, node2);

let lconnection = Connection {
justification: justification.clone(),
Expand Down
24 changes: 24 additions & 0 deletions src/language.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use std::{hash::Hash, str::FromStr};

use crate::*;

use crate::semi_persistent::{UndoLog, UndoLogT};
use fmt::Formatter;
use symbolic_expressions::{Sexp, SexpError};
use thiserror::Error;
Expand Down Expand Up @@ -655,6 +656,7 @@ define_language! {
struct ConstantFolding;
impl Analysis<SimpleMath> for ConstantFolding {
type Data = Option<i32>;
type UndoLog = ();
fn merge(&mut self, to: &mut Self::Data, from: Self::Data) -> DidMerge {
egg::merge_max(to, from)
Expand Down Expand Up @@ -700,6 +702,12 @@ pub trait Analysis<L: Language>: Sized {
/// The per-[`EClass`] data for this analysis.
type Data: Debug;

/// Determines whether the [`EGraph`] supports [`push`](EGraph::push) and [`pop`](EGraph::pop)
/// Setting this to `()` disables [`push`](EGraph::push) and [`pop`](EGraph::pop)
/// Setting this to [`UndoLog`](UndoLog) enables [`push`](EGraph::push) and [`pop`](EGraph::pop)
/// Doing this requires that the [`EGraph`] has explanations enabled
type UndoLog: UndoLogT<L>;

/// Makes a new [`Analysis`] data for a given e-node.
///
/// Note the mutable `egraph` parameter: this is needed for some
Expand Down Expand Up @@ -765,6 +773,22 @@ pub trait Analysis<L: Language>: Sized {

impl<L: Language> Analysis<L> for () {
type Data = ();

type UndoLog = ();
fn make(_egraph: &mut EGraph<L, Self>, _enode: &L) -> Self::Data {}
fn merge(&mut self, _: &mut Self::Data, _: Self::Data) -> DidMerge {
DidMerge(false, false)
}
}

/// Simple [`Analysis`], similar to `()` but enables [`push`](EGraph::push) and [`pop`](EGraph::pop)
/// Doing this requires that the [`EGraph`] has explanations enabled
pub struct WithUndo;

impl<L: Language> Analysis<L> for WithUndo {
type Data = ();

type UndoLog = UndoLog<L>;
fn make(_egraph: &mut EGraph<L, Self>, _enode: &L) -> Self::Data {}
fn merge(&mut self, _: &mut Self::Data, _: Self::Data) -> DidMerge {
DidMerge(false, false)
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ mod multipattern;
mod pattern;
mod rewrite;
mod run;
mod semi_persistent;
mod subst;
mod unionfind;
mod util;
Expand Down Expand Up @@ -101,6 +102,7 @@ pub use {
pattern::{ENodeOrVar, Pattern, PatternAst, SearchMatches},
rewrite::{Applier, Condition, ConditionEqual, ConditionalApplier, Rewrite, Searcher},
run::*,
semi_persistent::UndoLog,
subst::{Subst, Var},
util::*,
};
Expand Down
1 change: 1 addition & 0 deletions src/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ where
/// struct MinSize;
/// impl Analysis<Math> for MinSize {
/// type Data = usize;
/// type UndoLog = ();
/// fn merge(&mut self, to: &mut Self::Data, from: Self::Data) -> DidMerge {
/// merge_min(to, from)
/// }
Expand Down
Loading

0 comments on commit 1f5a954

Please sign in to comment.