Skip to content

Commit

Permalink
AST scan to establish a non-conflicting expression ID for pruning (#703)
Browse files Browse the repository at this point in the history
* Scan the input AST to establish a non-conflicting expression ID for pruned expressions
* Refactors to clarify max identifier logic in prune step
  • Loading branch information
TristonianJones authored May 15, 2023
1 parent c08c0cc commit a465d93
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 5 deletions.
50 changes: 45 additions & 5 deletions interpreter/prune.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,19 +68,15 @@ type astPruner struct {
// the overloads accordingly.
func PruneAst(expr *exprpb.Expr, macroCalls map[int64]*exprpb.Expr, state EvalState) *exprpb.ParsedExpr {
pruneState := NewEvalState()
maxID := int64(1)
for _, id := range state.IDs() {
v, _ := state.Value(id)
pruneState.SetValue(id, v)
if id > maxID {
maxID = id + 1
}
}
pruner := &astPruner{
expr: expr,
macroCalls: macroCalls,
state: pruneState,
nextExprID: maxID}
nextExprID: getMaxID(expr)}
newExpr, _ := pruner.maybePrune(expr)
return &exprpb.ParsedExpr{
Expr: newExpr,
Expand Down Expand Up @@ -463,3 +459,47 @@ func (p *astPruner) nextID() int64 {
p.nextExprID++
}
}

func getMaxID(expr *exprpb.Expr) int64 {
maxID := int64(1)
exprs := []*exprpb.Expr{expr}
for len(exprs) != 0 {
e := exprs[0]
if e.GetId() >= maxID {
maxID = e.GetId() + 1
}
exprs = exprs[1:]
switch e.GetExprKind().(type) {
case *exprpb.Expr_SelectExpr:
exprs = append(exprs, e.GetSelectExpr().GetOperand())
case *exprpb.Expr_CallExpr:
call := e.GetCallExpr()
if call.GetTarget() != nil {
exprs = append(exprs, call.GetTarget())
}
exprs = append(exprs, call.GetArgs()...)
case *exprpb.Expr_ComprehensionExpr:
compre := e.GetComprehensionExpr()
exprs = append(exprs,
compre.GetIterRange(),
compre.GetAccuInit(),
compre.GetLoopCondition(),
compre.GetLoopStep(),
compre.GetResult())
case *exprpb.Expr_ListExpr:
list := e.GetListExpr()
exprs = append(exprs, list.GetElements()...)
case *exprpb.Expr_StructExpr:
for _, entry := range expr.GetStructExpr().GetEntries() {
if entry.GetMapKey() != nil {
exprs = append(exprs, entry.GetMapKey())
}
exprs = append(exprs, entry.GetValue())
if entry.GetId() >= maxID {
maxID = entry.GetId() + 1
}
}
}
}
return maxID
}
10 changes: 10 additions & 0 deletions interpreter/prune_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,16 @@ var testCases = []testInfo{
expr: `foo == "bar" && r.attr.loc in ["GB", "US"]`,
out: `r.attr.loc in ["GB", "US"]`,
},
{
in: partialActivation(map[string]any{
"users": []map[string]string{
{"name": "alice", "role": "EMPLOYEE"},
{"name": "bob", "role": "MANAGER"},
{"name": "eve", "role": "CUSTOMER"},
}}, "r.attr.*"),
expr: `users.filter(u, u.role=="MANAGER").map(u, u.name) == r.attr.authorized["managers"]`,
out: `["bob"] == r.attr.authorized["managers"]`,
},
// TODO: the output of an expression like this relies on either
// a) doing replacements on the original macro call, or
// b) mutating the macro call tracking data rather than the core
Expand Down

0 comments on commit a465d93

Please sign in to comment.