Skip to content

Commit

Permalink
fix up subst as it was before, rename
Browse files Browse the repository at this point in the history
  • Loading branch information
oflatt committed Feb 15, 2024
1 parent 3a699e7 commit e26a499
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 82 deletions.
32 changes: 9 additions & 23 deletions src/ast/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,12 @@ impl<Head: Clone + Display, Leaf: Hash + Clone + Display + Eq, Ann: Clone>

/// Applys `f` to all sub-expressions (including `self`)
/// bottom-up, collecting the results.
pub fn map(self, f: &mut impl FnMut(Self) -> Self) -> Self {
pub fn visit_exprs(self, f: &mut impl FnMut(Self) -> Self) -> Self {
match self {
GenericExpr::Lit(..) => f(self),
GenericExpr::Var(..) => f(self),
GenericExpr::Call(ann, op, children) => {
let children = children.into_iter().map(|c| c.map(f)).collect();
let children = children.into_iter().map(|c| c.visit_exprs(f)).collect();
f(GenericExpr::Call(ann.clone(), op.clone(), children))
}
}
Expand All @@ -202,35 +202,21 @@ impl<Head: Clone + Display, Leaf: Hash + Clone + Display + Eq, Ann: Clone>
}
}

pub fn map_heads<Head2>(
self,
f: &mut impl FnMut(Head) -> Head2,
) -> GenericExpr<Head2, Leaf, Ann> {
match self {
GenericExpr::Var(ann, v) => GenericExpr::Var(ann, v),
GenericExpr::Lit(ann, lit) => GenericExpr::Lit(ann, lit),
GenericExpr::Call(ann, op, children) => {
let children = children.into_iter().map(|c| c.map_heads(f)).collect();
GenericExpr::Call(ann, f(op), children)
}
}
}

// TODO: Currently, subst_leaf takes a leaf but not an annotation over the leaf,
// so it has to "make up" annotations for the returned GenericExpr. A better
// approach is for subst_leaf to also take the annotation, which we should
// implement after we use real non-() annotations
pub fn subst<Head2, Leaf2>(
self,
subst_leaf: &mut impl FnMut(Leaf) -> GenericExpr<Head2, Leaf2, Ann>,
subst_head: &mut impl FnMut(Head) -> Head2,
&self,
subst_leaf: &mut impl FnMut(&Leaf) -> GenericExpr<Head2, Leaf2, Ann>,
subst_head: &mut impl FnMut(&Head) -> Head2,
) -> GenericExpr<Head2, Leaf2, Ann> {
match self {
GenericExpr::Lit(ann, lit) => GenericExpr::Lit(ann.clone(), lit.clone()),
GenericExpr::Var(_ann, v) => subst_leaf(v),
GenericExpr::Call(ann, op, children) => {
let children = children
.into_iter()
.iter()
.map(|c| c.subst(subst_leaf, subst_head))
.collect();
GenericExpr::Call(ann.clone(), subst_head(op), children)
Expand All @@ -239,10 +225,10 @@ impl<Head: Clone + Display, Leaf: Hash + Clone + Display + Eq, Ann: Clone>
}

pub fn subst_leaf<Leaf2>(
self,
subst: &mut impl FnMut(Leaf) -> GenericExpr<Head, Leaf2, Ann>,
&self,
subst_leaf: &mut impl FnMut(&Leaf) -> GenericExpr<Head, Leaf2, Ann>,
) -> GenericExpr<Head, Leaf2, Ann> {
self.subst(subst, &mut |op| op.clone())
self.subst(subst_leaf, &mut |x| x.clone())
}

pub fn vars(&self) -> impl Iterator<Item = Leaf> + '_ {
Expand Down
141 changes: 85 additions & 56 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ where
}
}

pub fn map_exprs(
pub fn visit_exprs(
self,
f: &mut impl FnMut(GenericExpr<Head, Leaf, ()>) -> GenericExpr<Head, Leaf, ()>,
) -> Self {
Expand All @@ -157,7 +157,7 @@ where
value: f(value.clone()),
},
GenericNCommand::Sort(name, params) => GenericNCommand::Sort(name, params),
GenericNCommand::Function(func) => GenericNCommand::Function(func.map_exprs(f)),
GenericNCommand::Function(func) => GenericNCommand::Function(func.visit_exprs(f)),
GenericNCommand::AddRuleset(name) => GenericNCommand::AddRuleset(name),
GenericNCommand::NormRule {
name,
Expand All @@ -166,15 +166,17 @@ where
} => GenericNCommand::NormRule {
name,
ruleset,
rule: rule.map_exprs(f),
rule: rule.visit_exprs(f),
},
GenericNCommand::RunSchedule(schedule) => {
GenericNCommand::RunSchedule(schedule.map_exprs(f))
GenericNCommand::RunSchedule(schedule.visit_exprs(f))
}
GenericNCommand::PrintOverallStatistics => GenericNCommand::PrintOverallStatistics,
GenericNCommand::CoreAction(action) => GenericNCommand::CoreAction(action.map_exprs(f)),
GenericNCommand::CoreAction(action) => {
GenericNCommand::CoreAction(action.visit_exprs(f))
}
GenericNCommand::Check(facts) => {
GenericNCommand::Check(facts.into_iter().map(|fact| fact.map_exprs(f)).collect())
GenericNCommand::Check(facts.into_iter().map(|fact| fact.visit_exprs(f)).collect())
}
GenericNCommand::CheckProof => GenericNCommand::CheckProof,
GenericNCommand::PrintTable(name, n) => GenericNCommand::PrintTable(name, n),
Expand All @@ -185,7 +187,7 @@ where
},
GenericNCommand::Push(n) => GenericNCommand::Push(n),
GenericNCommand::Pop(n) => GenericNCommand::Pop(n),
GenericNCommand::Fail(cmd) => GenericNCommand::Fail(Box::new(cmd.map_exprs(f))),
GenericNCommand::Fail(cmd) => GenericNCommand::Fail(Box::new(cmd.visit_exprs(f))),
GenericNCommand::Input { name, file } => GenericNCommand::Input { name, file },
}
}
Expand Down Expand Up @@ -254,20 +256,20 @@ where
Head: Clone + Display,
Leaf: Clone + PartialEq + Eq + Display + Hash,
{
fn map_exprs(
fn visit_exprs(
self,
f: &mut impl FnMut(GenericExpr<Head, Leaf, Ann>) -> GenericExpr<Head, Leaf, Ann>,
) -> Self {
match self {
GenericSchedule::Saturate(sched) => {
GenericSchedule::Saturate(Box::new(sched.map_exprs(f)))
GenericSchedule::Saturate(Box::new(sched.visit_exprs(f)))
}
GenericSchedule::Repeat(size, sched) => {
GenericSchedule::Repeat(size, Box::new(sched.map_exprs(f)))
GenericSchedule::Repeat(size, Box::new(sched.visit_exprs(f)))
}
GenericSchedule::Run(config) => GenericSchedule::Run(config.map_exprs(f)),
GenericSchedule::Run(config) => GenericSchedule::Run(config.visit_exprs(f)),
GenericSchedule::Sequence(scheds) => {
GenericSchedule::Sequence(scheds.into_iter().map(|s| s.map_exprs(f)).collect())
GenericSchedule::Sequence(scheds.into_iter().map(|s| s.visit_exprs(f)).collect())
}
}
}
Expand Down Expand Up @@ -756,15 +758,15 @@ where
Head: Clone + Display,
Leaf: Clone + PartialEq + Eq + Display + Hash,
{
pub fn map_exprs(
pub fn visit_exprs(
self,
f: &mut impl FnMut(GenericExpr<Head, Leaf, Ann>) -> GenericExpr<Head, Leaf, Ann>,
) -> Self {
Self {
ruleset: self.ruleset,
until: self
.until
.map(|until| until.into_iter().map(|fact| fact.map_exprs(f)).collect()),
.map(|until| until.into_iter().map(|fact| fact.visit_exprs(f)).collect()),
}
}
}
Expand Down Expand Up @@ -871,16 +873,16 @@ where
Head: Clone + Display,
Leaf: Clone + PartialEq + Eq + Display + Hash,
{
pub fn map_exprs(
pub fn visit_exprs(
self,
f: &mut impl FnMut(GenericExpr<Head, Leaf, Ann>) -> GenericExpr<Head, Leaf, Ann>,
) -> GenericFunctionDecl<Head, Leaf, Ann> {
GenericFunctionDecl {
name: self.name,
schema: self.schema,
default: self.default.map(|expr| expr.map(f)),
merge: self.merge.map(|expr| expr.map(f)),
merge_action: self.merge_action.map_exprs(f),
default: self.default.map(|expr| expr.visit_exprs(f)),
merge: self.merge.map(|expr| expr.visit_exprs(f)),
merge_action: self.merge_action.visit_exprs(f),
cost: self.cost,
unextractable: self.unextractable,
}
Expand Down Expand Up @@ -1034,40 +1036,34 @@ where
Head: Clone + Display,
Leaf: Clone + PartialEq + Eq + Display + Hash,
{
pub(crate) fn map_exprs(
pub(crate) fn visit_exprs(
self,
f: &mut impl FnMut(GenericExpr<Head, Leaf, Ann>) -> GenericExpr<Head, Leaf, Ann>,
) -> GenericFact<Head, Leaf, Ann> {
match self {
GenericFact::Eq(exprs) => {
GenericFact::Eq(exprs.into_iter().map(|expr| expr.map(f)).collect())
GenericFact::Eq(exprs.into_iter().map(|expr| expr.visit_exprs(f)).collect())
}
GenericFact::Fact(expr) => GenericFact::Fact(expr.map(f)),
GenericFact::Fact(expr) => GenericFact::Fact(expr.visit_exprs(f)),
}
}

pub(crate) fn map_leafs<Leaf2>(
self,
f: &mut impl FnMut(Leaf) -> Leaf2,
) -> GenericFact<Head, Leaf2, Ann> {
pub(crate) fn map_exprs<Head2, Leaf2>(
&self,
f: &mut impl FnMut(&GenericExpr<Head, Leaf, Ann>) -> GenericExpr<Head2, Leaf2, Ann>,
) -> GenericFact<Head2, Leaf2, Ann> {
match self {
GenericFact::Eq(exprs) => {
GenericFact::Eq(exprs.into_iter().map(|e| e.map_leafs(f)).collect())
}
GenericFact::Fact(expr) => GenericFact::Fact(expr.map_leafs(f)),
GenericFact::Eq(exprs) => GenericFact::Eq(exprs.iter().map(f).collect()),
GenericFact::Fact(expr) => GenericFact::Fact(f(expr)),
}
}

pub(crate) fn map_heads<Head2>(
self,
f: &mut impl FnMut(Head) -> Head2,
) -> GenericFact<Head2, Leaf, Ann> {
match self {
GenericFact::Eq(exprs) => {
GenericFact::Eq(exprs.into_iter().map(|e| e.map_heads(f)).collect())
}
GenericFact::Fact(expr) => GenericFact::Fact(expr.map_heads(f)),
}
pub(crate) fn subst<Leaf2, Head2>(
&self,
subst_leaf: &mut impl FnMut(&Leaf) -> GenericExpr<Head2, Leaf2, Ann>,
subst_head: &mut impl FnMut(&Head) -> Head2,
) -> GenericFact<Head2, Leaf2, Ann> {
self.map_exprs(&mut |e| e.subst(subst_leaf, subst_head))
}
}

Expand All @@ -1081,8 +1077,9 @@ where
Leaf: SymbolLike,
Head: SymbolLike,
{
self.map_leafs(&mut |v| v.to_symbol())
.map_heads(&mut |h| h.to_symbol())
self.subst(&mut |v| Expr::Var((), v.to_symbol()), &mut |h| {
h.to_symbol()
})
}
}

Expand Down Expand Up @@ -1230,11 +1227,11 @@ where
self.0.is_empty()
}

pub(crate) fn map_exprs(
pub(crate) fn visit_exprs(
self,
f: &mut impl FnMut(GenericExpr<Head, Leaf, Ann>) -> GenericExpr<Head, Leaf, Ann>,
) -> Self {
Self(self.0.into_iter().map(|a| a.map_exprs(f)).collect())
Self(self.0.into_iter().map(|a| a.visit_exprs(f)).collect())
}
}

Expand Down Expand Up @@ -1263,45 +1260,77 @@ where
Leaf: Clone + Eq + Display + Hash,
Ann: Clone + Default,
{
// Applys `f` to all expressions in the action.
pub fn map_exprs(
&self,
f: &mut impl FnMut(&GenericExpr<Head, Leaf, Ann>) -> GenericExpr<Head, Leaf, Ann>,
) -> Self {
match self {
GenericAction::Let(ann, lhs, rhs) => {
GenericAction::Let(ann.clone(), lhs.clone(), f(rhs))
}
GenericAction::Set(ann, lhs, args, rhs) => {
let right = f(rhs);
GenericAction::Set(
ann.clone(),
lhs.clone(),
args.iter().map(f).collect(),
right,
)
}
GenericAction::Delete(ann, lhs, args) => {
GenericAction::Delete(ann.clone(), lhs.clone(), args.iter().map(f).collect())
}
GenericAction::Union(ann, lhs, rhs) => {
GenericAction::Union(ann.clone(), f(lhs), f(rhs))
}
GenericAction::Extract(ann, expr, variants) => {
GenericAction::Extract(ann.clone(), f(expr), f(variants))
}
GenericAction::Panic(ann, msg) => GenericAction::Panic(ann.clone(), msg.clone()),
GenericAction::Expr(ann, e) => GenericAction::Expr(ann.clone(), f(e)),
}
}

/// Applys `f` to all sub-expressions (including `self`)
/// bottom-up, collecting the results.
pub fn map_exprs(
pub fn visit_exprs(
self,
f: &mut impl FnMut(GenericExpr<Head, Leaf, Ann>) -> GenericExpr<Head, Leaf, Ann>,
) -> Self {
match self {
GenericAction::Let(ann, lhs, rhs) => {
GenericAction::Let(ann.clone(), lhs.clone(), rhs.map(f))
GenericAction::Let(ann.clone(), lhs.clone(), rhs.visit_exprs(f))
}
// TODO should we refactor `Set` so that we can map over Expr::Call(lhs, args)?
// This seems more natural to oflatt
GenericAction::Set(ann, lhs, args, rhs) => {
let args = args.into_iter().map(|e| e.map(f)).collect();
GenericAction::Set(ann.clone(), lhs.clone(), args, rhs.map(f))
let args = args.into_iter().map(|e| e.visit_exprs(f)).collect();
GenericAction::Set(ann.clone(), lhs.clone(), args, rhs.visit_exprs(f))
}
GenericAction::Delete(ann, lhs, args) => {
let args = args.into_iter().map(|e| e.map(f)).collect();
let args = args.into_iter().map(|e| e.visit_exprs(f)).collect();
GenericAction::Delete(ann.clone(), lhs.clone(), args)
}
GenericAction::Union(ann, lhs, rhs) => {
GenericAction::Union(ann.clone(), lhs.map(f), rhs.map(f))
GenericAction::Union(ann.clone(), lhs.visit_exprs(f), rhs.visit_exprs(f))
}
GenericAction::Extract(ann, expr, variants) => {
GenericAction::Extract(ann.clone(), expr.map(f), variants.map(f))
GenericAction::Extract(ann.clone(), expr.visit_exprs(f), variants.visit_exprs(f))
}
GenericAction::Panic(ann, msg) => GenericAction::Panic(ann.clone(), msg.clone()),
GenericAction::Expr(ann, e) => GenericAction::Expr(ann.clone(), e.map(f)),
GenericAction::Expr(ann, e) => GenericAction::Expr(ann.clone(), e.visit_exprs(f)),
}
}

pub fn subst(self, subst: &mut impl FnMut(Leaf) -> GenericExpr<Head, Leaf, Ann>) -> Self {
pub fn subst(&self, subst: &mut impl FnMut(&Leaf) -> GenericExpr<Head, Leaf, Ann>) -> Self {
self.map_exprs(&mut |e| e.subst_leaf(subst))
}

pub fn map_def_use(self, fvar: &mut impl FnMut(Leaf, bool) -> Leaf) -> Self {
macro_rules! fvar_expr {
() => {
|s: _| GenericExpr::Var(Ann::default(), fvar(s, false))
|s: _| GenericExpr::Var(Ann::default(), fvar(s.clone(), false))
};
}
match self {
Expand Down Expand Up @@ -1374,16 +1403,16 @@ where
Head: Clone + Display,
Leaf: Clone + PartialEq + Eq + Display + Hash,
{
pub(crate) fn map_exprs(
pub(crate) fn visit_exprs(
self,
f: &mut impl FnMut(GenericExpr<Head, Leaf, Ann>) -> GenericExpr<Head, Leaf, Ann>,
) -> Self {
Self {
head: self.head.map_exprs(f),
head: self.head.visit_exprs(f),
body: self
.body
.into_iter()
.map(|bexpr| bexpr.map_exprs(f))
.map(|bexpr| bexpr.visit_exprs(f))
.collect(),
}
}
Expand Down
6 changes: 3 additions & 3 deletions src/ast/remove_globals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,11 @@ fn replace_global_var(expr: ResolvedExpr) -> ResolvedExpr {
}

fn remove_globals_expr(expr: ResolvedExpr) -> ResolvedExpr {
expr.map(&mut replace_global_var)
expr.visit_exprs(&mut replace_global_var)
}

fn remove_globals_action(action: ResolvedAction) -> ResolvedAction {
action.map_exprs(&mut replace_global_var)
action.visit_exprs(&mut replace_global_var)
}

fn remove_globals_cmd(type_info: &TypeInfo, cmd: ResolvedNCommand) -> Vec<ResolvedNCommand> {
Expand Down Expand Up @@ -101,6 +101,6 @@ fn remove_globals_cmd(type_info: &TypeInfo, cmd: ResolvedNCommand) -> Vec<Resolv
}
_ => vec![GenericNCommand::CoreAction(remove_globals_action(action))],
},
_ => vec![cmd.map_exprs(&mut remove_globals_expr)],
_ => vec![cmd.visit_exprs(&mut remove_globals_expr)],
}
}

0 comments on commit e26a499

Please sign in to comment.