From 5020ad629de66378ff5592a62d76fadec3cef402 Mon Sep 17 00:00:00 2001 From: Eytan Singher Date: Tue, 12 Mar 2024 14:41:46 +0200 Subject: [PATCH] Get some display on stuff --- src/dot.rs | 4 ++-- src/egraph.rs | 2 +- src/explain.rs | 12 +++++------ src/language.rs | 8 +++---- src/multipattern.rs | 14 ++++++++++++- src/pattern.rs | 6 +++--- src/rewrite.rs | 51 ++++++++++++++++++++++++++++++++++++++------- src/test.rs | 2 +- 8 files changed, 74 insertions(+), 25 deletions(-) diff --git a/src/dot.rs b/src/dot.rs index cefaf440..fa52903a 100644 --- a/src/dot.rs +++ b/src/dot.rs @@ -61,7 +61,7 @@ pub struct Dot<'a, L: Language, N: Analysis> { impl<'a, L, N> Dot<'a, L, N> where - L: Language + Display, + L: Language, N: Analysis, { /// Writes the `Dot` to a .dot file with the given filename. @@ -178,7 +178,7 @@ impl<'a, L: Language, N: Analysis> Debug for Dot<'a, L, N> { impl<'a, L, N> Display for Dot<'a, L, N> where - L: Language + Display, + L: Language, N: Analysis, { fn fmt(&self, f: &mut Formatter) -> fmt::Result { diff --git a/src/egraph.rs b/src/egraph.rs index 22956348..97e1678c 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -928,7 +928,7 @@ impl> EGraph { } } -impl> EGraph { +impl> EGraph { /// Panic if the given eclass doesn't contain the given patterns /// /// Useful for testing. diff --git a/src/explain.rs b/src/explain.rs index 79692326..0e74a6eb 100644 --- a/src/explain.rs +++ b/src/explain.rs @@ -112,14 +112,14 @@ pub struct Explanation { flat_explanation: Option>, } -impl Display for Explanation { +impl Display for Explanation { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { let s = self.get_sexp().to_string(); f.write_str(&s) } } -impl Explanation { +impl Explanation { /// Get each flattened term in the explanation as an s-expression string. /// /// The s-expression format mirrors the format of each [`FlatTerm`]. @@ -591,7 +591,7 @@ pub struct FlatTerm { pub children: FlatExplanation, } -impl Display for FlatTerm { +impl Display for FlatTerm { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { let s = self.get_sexp().to_string(); write!(f, "{}", s) @@ -648,7 +648,7 @@ impl Default for Explain { } } -impl FlatTerm { +impl FlatTerm { /// Convert this FlatTerm to an S-expression. /// See [`get_flat_string`](Explanation::get_flat_string) for the format of these expressions. pub fn get_string(&self) -> String { @@ -692,7 +692,7 @@ impl FlatTerm { } } -impl Display for TreeTerm { +impl Display for TreeTerm { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { let mut buf = String::new(); let width = 80; @@ -701,7 +701,7 @@ impl Display for TreeTerm { } } -impl TreeTerm { +impl TreeTerm { /// Convert this TreeTerm to an S-expression. fn get_sexp(&self) -> Sexp { self.get_sexp_with_bindings(&Default::default()) diff --git a/src/language.rs b/src/language.rs index c51ae826..9704524c 100644 --- a/src/language.rs +++ b/src/language.rs @@ -28,7 +28,7 @@ use thiserror::Error; /// /// See [`SymbolLang`] for quick-and-dirty use cases. #[allow(clippy::len_without_is_empty)] -pub trait Language: Debug + Clone + Eq + Ord + Hash { +pub trait Language: Debug + Clone + Eq + Ord + Hash + Display { /// Returns true if this enode matches another enode. /// This should only consider the operator, not the children `Id`s. fn matches(&self, other: &Self) -> bool; @@ -370,7 +370,7 @@ pub struct RecExpr { } #[cfg(feature = "serde-1")] -impl serde::Serialize for RecExpr { +impl serde::Serialize for RecExpr { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, @@ -454,7 +454,7 @@ impl IndexMut for RecExpr { } } -impl Display for RecExpr { +impl Display for RecExpr { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if self.nodes.is_empty() { Display::fmt("()", f) @@ -465,7 +465,7 @@ impl Display for RecExpr { } } -impl RecExpr { +impl RecExpr { /// Convert this RecExpr into an Sexp pub(crate) fn to_sexp(&self) -> Sexp { let last = self.nodes.len() - 1; diff --git a/src/multipattern.rs b/src/multipattern.rs index 4fe61212..e6e659de 100644 --- a/src/multipattern.rs +++ b/src/multipattern.rs @@ -100,7 +100,19 @@ impl FromStr for MultiPattern { } } -impl> Searcher for MultiPattern { +impl std::fmt::Display for MultiPattern { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for (i, (v, pat)) in self.asts.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{} = {}", v, pat)?; + } + Ok(()) + } +} + +impl> Searcher for MultiPattern { fn search_eclass_with_limit( &self, egraph: &EGraph, diff --git a/src/pattern.rs b/src/pattern.rs index 475465f8..1b35c2ae 100644 --- a/src/pattern.rs +++ b/src/pattern.rs @@ -122,7 +122,7 @@ impl Pattern { } } -impl Pattern { +impl Pattern { /// Pretty print this pattern as a sexp with the given width pub fn pretty(&self, width: usize) -> String { self.ast.pretty(width) @@ -159,7 +159,7 @@ impl Language for ENodeOrVar { } } -impl Display for ENodeOrVar { +impl Display for ENodeOrVar { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match self { Self::ENode(node) => Display::fmt(node, f), @@ -240,7 +240,7 @@ impl TryFrom> for RecExpr { } } -impl Display for Pattern { +impl Display for Pattern { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { Display::fmt(&self.ast, f) } diff --git a/src/rewrite.rs b/src/rewrite.rs index 1687caa6..478dce21 100644 --- a/src/rewrite.rs +++ b/src/rewrite.rs @@ -26,7 +26,7 @@ pub struct Rewrite { impl Debug for Rewrite where - L: Language + Display + 'static, + L: Language + 'static, N: Analysis + 'static, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -159,7 +159,7 @@ where /// matching substitutions. /// Right now the only significant [`Searcher`] is [`Pattern`]. /// -pub trait Searcher +pub trait Searcher : Display where L: Language, N: Analysis, @@ -321,7 +321,7 @@ where /// let start = "(+ x (* y z))".parse().unwrap(); /// Runner::default().with_expr(&start).run(rules); /// ``` -pub trait Applier +pub trait Applier : Display where L: Language, N: Analysis, @@ -416,6 +416,16 @@ pub struct ConditionalApplier { pub applier: A, } +impl Display for ConditionalApplier +where + C: Display, + A: Display +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "ConditionalApplier({}, {})", self.condition, self.applier) + } +} + impl Applier for ConditionalApplier where L: Language, @@ -459,7 +469,7 @@ where /// /// [`check`]: Condition::check() /// [`Fn`]: std::ops::Fn -pub trait Condition +pub trait Condition : Display where L: Language, N: Analysis, @@ -482,17 +492,35 @@ where } } -impl Condition for F +pub struct LambdaCond where + L: Language, + N: Analysis, +{ + pub f: Box, Id, &Subst) -> bool + Send + Sync>, + pub d: String +} + +impl Display for LambdaCond +where + L: Language, + N: Analysis, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "LambdaCond({})", self.d) + } +} + +impl Condition for LambdaCond where L: Language, N: Analysis, - F: Fn(&mut EGraph, Id, &Subst) -> bool, { fn check(&self, egraph: &mut EGraph, eclass: Id, subst: &Subst) -> bool { - self(egraph, eclass, subst) + (self.f)(egraph, eclass, subst) } } + /// A [`Condition`] that checks if two terms are equivalent. /// /// This condition adds its two [`Pattern`] to the egraph and passes @@ -523,6 +551,15 @@ impl ConditionEqual { } } +impl Display for ConditionEqual +where + L: Language, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "({} = {})", self.p1, self.p2) + } +} + impl Condition for ConditionEqual where L: Language, diff --git a/src/test.rs b/src/test.rs index 805a5f73..22d534f8 100644 --- a/src/test.rs +++ b/src/test.rs @@ -34,7 +34,7 @@ pub fn test_runner( check_fn: Option)>, should_check: bool, ) where - L: Language + Display + FromOp + 'static, + L: Language + FromOp + 'static, A: Analysis + Default, { let _ = env_logger::builder().is_test(true).try_init();