Skip to content

Commit

Permalink
Get some display on stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
eytans committed Mar 12, 2024
1 parent c590048 commit 5020ad6
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 25 deletions.
4 changes: 2 additions & 2 deletions src/dot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ pub struct Dot<'a, L: Language, N: Analysis<L>> {

impl<'a, L, N> Dot<'a, L, N>
where
L: Language + Display,
L: Language,
N: Analysis<L>,
{
/// Writes the `Dot` to a .dot file with the given filename.
Expand Down Expand Up @@ -178,7 +178,7 @@ impl<'a, L: Language, N: Analysis<L>> Debug for Dot<'a, L, N> {

impl<'a, L, N> Display for Dot<'a, L, N>
where
L: Language + Display,
L: Language,
N: Analysis<L>,
{
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
Expand Down
2 changes: 1 addition & 1 deletion src/egraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -928,7 +928,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
}
}

impl<L: Language + Display, N: Analysis<L>> EGraph<L, N> {
impl<L: Language, N: Analysis<L>> EGraph<L, N> {
/// Panic if the given eclass doesn't contain the given patterns
///
/// Useful for testing.
Expand Down
12 changes: 6 additions & 6 deletions src/explain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,14 +112,14 @@ pub struct Explanation<L: Language> {
flat_explanation: Option<FlatExplanation<L>>,
}

impl<L: Language + Display + FromOp> Display for Explanation<L> {
impl<L: Language + FromOp> Display for Explanation<L> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
let s = self.get_sexp().to_string();
f.write_str(&s)
}
}

impl<L: Language + Display + FromOp> Explanation<L> {
impl<L: Language + FromOp> Explanation<L> {
/// Get each flattened term in the explanation as an s-expression string.
///
/// The s-expression format mirrors the format of each [`FlatTerm`].
Expand Down Expand Up @@ -591,7 +591,7 @@ pub struct FlatTerm<L: Language> {
pub children: FlatExplanation<L>,
}

impl<L: Language + Display + FromOp> Display for FlatTerm<L> {
impl<L: Language + FromOp> Display for FlatTerm<L> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
let s = self.get_sexp().to_string();
write!(f, "{}", s)
Expand Down Expand Up @@ -648,7 +648,7 @@ impl<L: Language> Default for Explain<L> {
}
}

impl<L: Language + Display + FromOp> FlatTerm<L> {
impl<L: Language + FromOp> FlatTerm<L> {
/// Convert this FlatTerm to an S-expression.
/// See [`get_flat_string`](Explanation::get_flat_string) for the format of these expressions.
pub fn get_string(&self) -> String {
Expand Down Expand Up @@ -692,7 +692,7 @@ impl<L: Language + Display + FromOp> FlatTerm<L> {
}
}

impl<L: Language + Display + FromOp> Display for TreeTerm<L> {
impl<L: Language + FromOp> Display for TreeTerm<L> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
let mut buf = String::new();
let width = 80;
Expand All @@ -701,7 +701,7 @@ impl<L: Language + Display + FromOp> Display for TreeTerm<L> {
}
}

impl<L: Language + Display + FromOp> TreeTerm<L> {
impl<L: Language + FromOp> TreeTerm<L> {
/// Convert this TreeTerm to an S-expression.
fn get_sexp(&self) -> Sexp {
self.get_sexp_with_bindings(&Default::default())
Expand Down
8 changes: 4 additions & 4 deletions src/language.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use thiserror::Error;
///
/// See [`SymbolLang`] for quick-and-dirty use cases.
#[allow(clippy::len_without_is_empty)]
pub trait Language: Debug + Clone + Eq + Ord + Hash {
pub trait Language: Debug + Clone + Eq + Ord + Hash + Display {
/// Returns true if this enode matches another enode.
/// This should only consider the operator, not the children `Id`s.
fn matches(&self, other: &Self) -> bool;
Expand Down Expand Up @@ -370,7 +370,7 @@ pub struct RecExpr<L> {
}

#[cfg(feature = "serde-1")]
impl<L: Language + Display> serde::Serialize for RecExpr<L> {
impl<L: Language> serde::Serialize for RecExpr<L> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
Expand Down Expand Up @@ -454,7 +454,7 @@ impl<L: Language> IndexMut<Id> for RecExpr<L> {
}
}

impl<L: Language + Display> Display for RecExpr<L> {
impl<L: Language> Display for RecExpr<L> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if self.nodes.is_empty() {
Display::fmt("()", f)
Expand All @@ -465,7 +465,7 @@ impl<L: Language + Display> Display for RecExpr<L> {
}
}

impl<L: Language + Display> RecExpr<L> {
impl<L: Language> RecExpr<L> {
/// Convert this RecExpr into an Sexp
pub(crate) fn to_sexp(&self) -> Sexp {
let last = self.nodes.len() - 1;
Expand Down
14 changes: 13 additions & 1 deletion src/multipattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,19 @@ impl<L: Language + FromOp> FromStr for MultiPattern<L> {
}
}

impl<L: Language, A: Analysis<L>> Searcher<L, A> for MultiPattern<L> {
impl<L: Language + std::fmt::Display> std::fmt::Display for MultiPattern<L> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
for (i, (v, pat)) in self.asts.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{} = {}", v, pat)?;
}
Ok(())
}
}

impl<L: Language + std::fmt::Display, A: Analysis<L>> Searcher<L, A> for MultiPattern<L> {
fn search_eclass_with_limit(
&self,
egraph: &EGraph<L, A>,
Expand Down
6 changes: 3 additions & 3 deletions src/pattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ impl<L: Language> Pattern<L> {
}
}

impl<L: Language + Display> Pattern<L> {
impl<L: Language> Pattern<L> {
/// Pretty print this pattern as a sexp with the given width
pub fn pretty(&self, width: usize) -> String {
self.ast.pretty(width)
Expand Down Expand Up @@ -159,7 +159,7 @@ impl<L: Language> Language for ENodeOrVar<L> {
}
}

impl<L: Language + Display> Display for ENodeOrVar<L> {
impl<L: Language> Display for ENodeOrVar<L> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match self {
Self::ENode(node) => Display::fmt(node, f),
Expand Down Expand Up @@ -240,7 +240,7 @@ impl<L: Language> TryFrom<Pattern<L>> for RecExpr<L> {
}
}

impl<L: Language + Display> Display for Pattern<L> {
impl<L: Language> Display for Pattern<L> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
Display::fmt(&self.ast, f)
}
Expand Down
51 changes: 44 additions & 7 deletions src/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub struct Rewrite<L, N> {

impl<L, N> Debug for Rewrite<L, N>
where
L: Language + Display + 'static,
L: Language + 'static,
N: Analysis<L> + 'static,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
Expand Down Expand Up @@ -159,7 +159,7 @@ where
/// matching substitutions.
/// Right now the only significant [`Searcher`] is [`Pattern`].
///
pub trait Searcher<L, N>
pub trait Searcher<L, N> : Display
where
L: Language,
N: Analysis<L>,
Expand Down Expand Up @@ -321,7 +321,7 @@ where
/// let start = "(+ x (* y z))".parse().unwrap();
/// Runner::default().with_expr(&start).run(rules);
/// ```
pub trait Applier<L, N>
pub trait Applier<L, N> : Display
where
L: Language,
N: Analysis<L>,
Expand Down Expand Up @@ -416,6 +416,16 @@ pub struct ConditionalApplier<C, A> {
pub applier: A,
}

impl<C, A> Display for ConditionalApplier<C, A>
where
C: Display,
A: Display
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "ConditionalApplier({}, {})", self.condition, self.applier)
}
}

impl<C, A, N, L> Applier<L, N> for ConditionalApplier<C, A>
where
L: Language,
Expand Down Expand Up @@ -459,7 +469,7 @@ where
///
/// [`check`]: Condition::check()
/// [`Fn`]: std::ops::Fn
pub trait Condition<L, N>
pub trait Condition<L, N> : Display
where
L: Language,
N: Analysis<L>,
Expand All @@ -482,17 +492,35 @@ where
}
}

impl<L, F, N> Condition<L, N> for F
pub struct LambdaCond<L, N> where
L: Language,
N: Analysis<L>,
{
pub f: Box<dyn Fn(&mut EGraph<L, N>, Id, &Subst) -> bool + Send + Sync>,
pub d: String
}

impl<L, N> Display for LambdaCond<L, N>
where
L: Language,
N: Analysis<L>,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "LambdaCond({})", self.d)
}
}

impl<L, N> Condition<L, N> for LambdaCond<L, N>
where
L: Language,
N: Analysis<L>,
F: Fn(&mut EGraph<L, N>, Id, &Subst) -> bool,
{
fn check(&self, egraph: &mut EGraph<L, N>, eclass: Id, subst: &Subst) -> bool {
self(egraph, eclass, subst)
(self.f)(egraph, eclass, subst)
}
}


/// A [`Condition`] that checks if two terms are equivalent.
///
/// This condition adds its two [`Pattern`] to the egraph and passes
Expand Down Expand Up @@ -523,6 +551,15 @@ impl<L: FromOp> ConditionEqual<L> {
}
}

impl<L> Display for ConditionEqual<L>
where
L: Language,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "({} = {})", self.p1, self.p2)
}
}

impl<L, N> Condition<L, N> for ConditionEqual<L>
where
L: Language,
Expand Down
2 changes: 1 addition & 1 deletion src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ pub fn test_runner<L, A>(
check_fn: Option<fn(Runner<L, A, ()>)>,
should_check: bool,
) where
L: Language + Display + FromOp + 'static,
L: Language + FromOp + 'static,
A: Analysis<L> + Default,
{
let _ = env_logger::builder().is_test(true).try_init();
Expand Down

0 comments on commit 5020ad6

Please sign in to comment.