From e26a499a64dd81cf10b1729e94911f8abdddd2d4 Mon Sep 17 00:00:00 2001 From: oflatt Date: Thu, 15 Feb 2024 13:19:10 -0800 Subject: [PATCH] fix up subst as it was before, rename --- src/ast/expr.rs | 32 +++------ src/ast/mod.rs | 141 +++++++++++++++++++++++--------------- src/ast/remove_globals.rs | 6 +- 3 files changed, 97 insertions(+), 82 deletions(-) diff --git a/src/ast/expr.rs b/src/ast/expr.rs index 98c9e9abc..95d6c277f 100644 --- a/src/ast/expr.rs +++ b/src/ast/expr.rs @@ -177,12 +177,12 @@ impl /// Applys `f` to all sub-expressions (including `self`) /// bottom-up, collecting the results. - pub fn map(self, f: &mut impl FnMut(Self) -> Self) -> Self { + pub fn visit_exprs(self, f: &mut impl FnMut(Self) -> Self) -> Self { match self { GenericExpr::Lit(..) => f(self), GenericExpr::Var(..) => f(self), GenericExpr::Call(ann, op, children) => { - let children = children.into_iter().map(|c| c.map(f)).collect(); + let children = children.into_iter().map(|c| c.visit_exprs(f)).collect(); f(GenericExpr::Call(ann.clone(), op.clone(), children)) } } @@ -202,35 +202,21 @@ impl } } - pub fn map_heads( - self, - f: &mut impl FnMut(Head) -> Head2, - ) -> GenericExpr { - match self { - GenericExpr::Var(ann, v) => GenericExpr::Var(ann, v), - GenericExpr::Lit(ann, lit) => GenericExpr::Lit(ann, lit), - GenericExpr::Call(ann, op, children) => { - let children = children.into_iter().map(|c| c.map_heads(f)).collect(); - GenericExpr::Call(ann, f(op), children) - } - } - } - // TODO: Currently, subst_leaf takes a leaf but not an annotation over the leaf, // so it has to "make up" annotations for the returned GenericExpr. A better // approach is for subst_leaf to also take the annotation, which we should // implement after we use real non-() annotations pub fn subst( - self, - subst_leaf: &mut impl FnMut(Leaf) -> GenericExpr, - subst_head: &mut impl FnMut(Head) -> Head2, + &self, + subst_leaf: &mut impl FnMut(&Leaf) -> GenericExpr, + subst_head: &mut impl FnMut(&Head) -> Head2, ) -> GenericExpr { match self { GenericExpr::Lit(ann, lit) => GenericExpr::Lit(ann.clone(), lit.clone()), GenericExpr::Var(_ann, v) => subst_leaf(v), GenericExpr::Call(ann, op, children) => { let children = children - .into_iter() + .iter() .map(|c| c.subst(subst_leaf, subst_head)) .collect(); GenericExpr::Call(ann.clone(), subst_head(op), children) @@ -239,10 +225,10 @@ impl } pub fn subst_leaf( - self, - subst: &mut impl FnMut(Leaf) -> GenericExpr, + &self, + subst_leaf: &mut impl FnMut(&Leaf) -> GenericExpr, ) -> GenericExpr { - self.subst(subst, &mut |op| op.clone()) + self.subst(subst_leaf, &mut |x| x.clone()) } pub fn vars(&self) -> impl Iterator + '_ { diff --git a/src/ast/mod.rs b/src/ast/mod.rs index c61a7bb11..a027990b4 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -147,7 +147,7 @@ where } } - pub fn map_exprs( + pub fn visit_exprs( self, f: &mut impl FnMut(GenericExpr) -> GenericExpr, ) -> Self { @@ -157,7 +157,7 @@ where value: f(value.clone()), }, GenericNCommand::Sort(name, params) => GenericNCommand::Sort(name, params), - GenericNCommand::Function(func) => GenericNCommand::Function(func.map_exprs(f)), + GenericNCommand::Function(func) => GenericNCommand::Function(func.visit_exprs(f)), GenericNCommand::AddRuleset(name) => GenericNCommand::AddRuleset(name), GenericNCommand::NormRule { name, @@ -166,15 +166,17 @@ where } => GenericNCommand::NormRule { name, ruleset, - rule: rule.map_exprs(f), + rule: rule.visit_exprs(f), }, GenericNCommand::RunSchedule(schedule) => { - GenericNCommand::RunSchedule(schedule.map_exprs(f)) + GenericNCommand::RunSchedule(schedule.visit_exprs(f)) } GenericNCommand::PrintOverallStatistics => GenericNCommand::PrintOverallStatistics, - GenericNCommand::CoreAction(action) => GenericNCommand::CoreAction(action.map_exprs(f)), + GenericNCommand::CoreAction(action) => { + GenericNCommand::CoreAction(action.visit_exprs(f)) + } GenericNCommand::Check(facts) => { - GenericNCommand::Check(facts.into_iter().map(|fact| fact.map_exprs(f)).collect()) + GenericNCommand::Check(facts.into_iter().map(|fact| fact.visit_exprs(f)).collect()) } GenericNCommand::CheckProof => GenericNCommand::CheckProof, GenericNCommand::PrintTable(name, n) => GenericNCommand::PrintTable(name, n), @@ -185,7 +187,7 @@ where }, GenericNCommand::Push(n) => GenericNCommand::Push(n), GenericNCommand::Pop(n) => GenericNCommand::Pop(n), - GenericNCommand::Fail(cmd) => GenericNCommand::Fail(Box::new(cmd.map_exprs(f))), + GenericNCommand::Fail(cmd) => GenericNCommand::Fail(Box::new(cmd.visit_exprs(f))), GenericNCommand::Input { name, file } => GenericNCommand::Input { name, file }, } } @@ -254,20 +256,20 @@ where Head: Clone + Display, Leaf: Clone + PartialEq + Eq + Display + Hash, { - fn map_exprs( + fn visit_exprs( self, f: &mut impl FnMut(GenericExpr) -> GenericExpr, ) -> Self { match self { GenericSchedule::Saturate(sched) => { - GenericSchedule::Saturate(Box::new(sched.map_exprs(f))) + GenericSchedule::Saturate(Box::new(sched.visit_exprs(f))) } GenericSchedule::Repeat(size, sched) => { - GenericSchedule::Repeat(size, Box::new(sched.map_exprs(f))) + GenericSchedule::Repeat(size, Box::new(sched.visit_exprs(f))) } - GenericSchedule::Run(config) => GenericSchedule::Run(config.map_exprs(f)), + GenericSchedule::Run(config) => GenericSchedule::Run(config.visit_exprs(f)), GenericSchedule::Sequence(scheds) => { - GenericSchedule::Sequence(scheds.into_iter().map(|s| s.map_exprs(f)).collect()) + GenericSchedule::Sequence(scheds.into_iter().map(|s| s.visit_exprs(f)).collect()) } } } @@ -756,7 +758,7 @@ where Head: Clone + Display, Leaf: Clone + PartialEq + Eq + Display + Hash, { - pub fn map_exprs( + pub fn visit_exprs( self, f: &mut impl FnMut(GenericExpr) -> GenericExpr, ) -> Self { @@ -764,7 +766,7 @@ where ruleset: self.ruleset, until: self .until - .map(|until| until.into_iter().map(|fact| fact.map_exprs(f)).collect()), + .map(|until| until.into_iter().map(|fact| fact.visit_exprs(f)).collect()), } } } @@ -871,16 +873,16 @@ where Head: Clone + Display, Leaf: Clone + PartialEq + Eq + Display + Hash, { - pub fn map_exprs( + pub fn visit_exprs( self, f: &mut impl FnMut(GenericExpr) -> GenericExpr, ) -> GenericFunctionDecl { GenericFunctionDecl { name: self.name, schema: self.schema, - default: self.default.map(|expr| expr.map(f)), - merge: self.merge.map(|expr| expr.map(f)), - merge_action: self.merge_action.map_exprs(f), + default: self.default.map(|expr| expr.visit_exprs(f)), + merge: self.merge.map(|expr| expr.visit_exprs(f)), + merge_action: self.merge_action.visit_exprs(f), cost: self.cost, unextractable: self.unextractable, } @@ -1034,40 +1036,34 @@ where Head: Clone + Display, Leaf: Clone + PartialEq + Eq + Display + Hash, { - pub(crate) fn map_exprs( + pub(crate) fn visit_exprs( self, f: &mut impl FnMut(GenericExpr) -> GenericExpr, ) -> GenericFact { match self { GenericFact::Eq(exprs) => { - GenericFact::Eq(exprs.into_iter().map(|expr| expr.map(f)).collect()) + GenericFact::Eq(exprs.into_iter().map(|expr| expr.visit_exprs(f)).collect()) } - GenericFact::Fact(expr) => GenericFact::Fact(expr.map(f)), + GenericFact::Fact(expr) => GenericFact::Fact(expr.visit_exprs(f)), } } - pub(crate) fn map_leafs( - self, - f: &mut impl FnMut(Leaf) -> Leaf2, - ) -> GenericFact { + pub(crate) fn map_exprs( + &self, + f: &mut impl FnMut(&GenericExpr) -> GenericExpr, + ) -> GenericFact { match self { - GenericFact::Eq(exprs) => { - GenericFact::Eq(exprs.into_iter().map(|e| e.map_leafs(f)).collect()) - } - GenericFact::Fact(expr) => GenericFact::Fact(expr.map_leafs(f)), + GenericFact::Eq(exprs) => GenericFact::Eq(exprs.iter().map(f).collect()), + GenericFact::Fact(expr) => GenericFact::Fact(f(expr)), } } - pub(crate) fn map_heads( - self, - f: &mut impl FnMut(Head) -> Head2, - ) -> GenericFact { - match self { - GenericFact::Eq(exprs) => { - GenericFact::Eq(exprs.into_iter().map(|e| e.map_heads(f)).collect()) - } - GenericFact::Fact(expr) => GenericFact::Fact(expr.map_heads(f)), - } + pub(crate) fn subst( + &self, + subst_leaf: &mut impl FnMut(&Leaf) -> GenericExpr, + subst_head: &mut impl FnMut(&Head) -> Head2, + ) -> GenericFact { + self.map_exprs(&mut |e| e.subst(subst_leaf, subst_head)) } } @@ -1081,8 +1077,9 @@ where Leaf: SymbolLike, Head: SymbolLike, { - self.map_leafs(&mut |v| v.to_symbol()) - .map_heads(&mut |h| h.to_symbol()) + self.subst(&mut |v| Expr::Var((), v.to_symbol()), &mut |h| { + h.to_symbol() + }) } } @@ -1230,11 +1227,11 @@ where self.0.is_empty() } - pub(crate) fn map_exprs( + pub(crate) fn visit_exprs( self, f: &mut impl FnMut(GenericExpr) -> GenericExpr, ) -> Self { - Self(self.0.into_iter().map(|a| a.map_exprs(f)).collect()) + Self(self.0.into_iter().map(|a| a.visit_exprs(f)).collect()) } } @@ -1263,45 +1260,77 @@ where Leaf: Clone + Eq + Display + Hash, Ann: Clone + Default, { + // Applys `f` to all expressions in the action. + pub fn map_exprs( + &self, + f: &mut impl FnMut(&GenericExpr) -> GenericExpr, + ) -> Self { + match self { + GenericAction::Let(ann, lhs, rhs) => { + GenericAction::Let(ann.clone(), lhs.clone(), f(rhs)) + } + GenericAction::Set(ann, lhs, args, rhs) => { + let right = f(rhs); + GenericAction::Set( + ann.clone(), + lhs.clone(), + args.iter().map(f).collect(), + right, + ) + } + GenericAction::Delete(ann, lhs, args) => { + GenericAction::Delete(ann.clone(), lhs.clone(), args.iter().map(f).collect()) + } + GenericAction::Union(ann, lhs, rhs) => { + GenericAction::Union(ann.clone(), f(lhs), f(rhs)) + } + GenericAction::Extract(ann, expr, variants) => { + GenericAction::Extract(ann.clone(), f(expr), f(variants)) + } + GenericAction::Panic(ann, msg) => GenericAction::Panic(ann.clone(), msg.clone()), + GenericAction::Expr(ann, e) => GenericAction::Expr(ann.clone(), f(e)), + } + } + /// Applys `f` to all sub-expressions (including `self`) /// bottom-up, collecting the results. - pub fn map_exprs( + pub fn visit_exprs( self, f: &mut impl FnMut(GenericExpr) -> GenericExpr, ) -> Self { match self { GenericAction::Let(ann, lhs, rhs) => { - GenericAction::Let(ann.clone(), lhs.clone(), rhs.map(f)) + GenericAction::Let(ann.clone(), lhs.clone(), rhs.visit_exprs(f)) } // TODO should we refactor `Set` so that we can map over Expr::Call(lhs, args)? // This seems more natural to oflatt GenericAction::Set(ann, lhs, args, rhs) => { - let args = args.into_iter().map(|e| e.map(f)).collect(); - GenericAction::Set(ann.clone(), lhs.clone(), args, rhs.map(f)) + let args = args.into_iter().map(|e| e.visit_exprs(f)).collect(); + GenericAction::Set(ann.clone(), lhs.clone(), args, rhs.visit_exprs(f)) } GenericAction::Delete(ann, lhs, args) => { - let args = args.into_iter().map(|e| e.map(f)).collect(); + let args = args.into_iter().map(|e| e.visit_exprs(f)).collect(); GenericAction::Delete(ann.clone(), lhs.clone(), args) } GenericAction::Union(ann, lhs, rhs) => { - GenericAction::Union(ann.clone(), lhs.map(f), rhs.map(f)) + GenericAction::Union(ann.clone(), lhs.visit_exprs(f), rhs.visit_exprs(f)) } GenericAction::Extract(ann, expr, variants) => { - GenericAction::Extract(ann.clone(), expr.map(f), variants.map(f)) + GenericAction::Extract(ann.clone(), expr.visit_exprs(f), variants.visit_exprs(f)) } GenericAction::Panic(ann, msg) => GenericAction::Panic(ann.clone(), msg.clone()), - GenericAction::Expr(ann, e) => GenericAction::Expr(ann.clone(), e.map(f)), + GenericAction::Expr(ann, e) => GenericAction::Expr(ann.clone(), e.visit_exprs(f)), } } - pub fn subst(self, subst: &mut impl FnMut(Leaf) -> GenericExpr) -> Self { + pub fn subst(&self, subst: &mut impl FnMut(&Leaf) -> GenericExpr) -> Self { self.map_exprs(&mut |e| e.subst_leaf(subst)) } pub fn map_def_use(self, fvar: &mut impl FnMut(Leaf, bool) -> Leaf) -> Self { macro_rules! fvar_expr { () => { - |s: _| GenericExpr::Var(Ann::default(), fvar(s, false)) + |s: _| GenericExpr::Var(Ann::default(), fvar(s.clone(), false)) }; } match self { @@ -1374,16 +1403,16 @@ where Head: Clone + Display, Leaf: Clone + PartialEq + Eq + Display + Hash, { - pub(crate) fn map_exprs( + pub(crate) fn visit_exprs( self, f: &mut impl FnMut(GenericExpr) -> GenericExpr, ) -> Self { Self { - head: self.head.map_exprs(f), + head: self.head.visit_exprs(f), body: self .body .into_iter() - .map(|bexpr| bexpr.map_exprs(f)) + .map(|bexpr| bexpr.visit_exprs(f)) .collect(), } } diff --git a/src/ast/remove_globals.rs b/src/ast/remove_globals.rs index d9e276d86..9e0e46f7c 100644 --- a/src/ast/remove_globals.rs +++ b/src/ast/remove_globals.rs @@ -59,11 +59,11 @@ fn replace_global_var(expr: ResolvedExpr) -> ResolvedExpr { } fn remove_globals_expr(expr: ResolvedExpr) -> ResolvedExpr { - expr.map(&mut replace_global_var) + expr.visit_exprs(&mut replace_global_var) } fn remove_globals_action(action: ResolvedAction) -> ResolvedAction { - action.map_exprs(&mut replace_global_var) + action.visit_exprs(&mut replace_global_var) } fn remove_globals_cmd(type_info: &TypeInfo, cmd: ResolvedNCommand) -> Vec { @@ -101,6 +101,6 @@ fn remove_globals_cmd(type_info: &TypeInfo, cmd: ResolvedNCommand) -> Vec vec![GenericNCommand::CoreAction(remove_globals_action(action))], }, - _ => vec![cmd.map_exprs(&mut remove_globals_expr)], + _ => vec![cmd.visit_exprs(&mut remove_globals_expr)], } }