diff --git a/interpreter/prune.go b/interpreter/prune.go index 85b3b065..fc2c8135 100644 --- a/interpreter/prune.go +++ b/interpreter/prune.go @@ -68,19 +68,15 @@ type astPruner struct { // the overloads accordingly. func PruneAst(expr *exprpb.Expr, macroCalls map[int64]*exprpb.Expr, state EvalState) *exprpb.ParsedExpr { pruneState := NewEvalState() - maxID := int64(1) for _, id := range state.IDs() { v, _ := state.Value(id) pruneState.SetValue(id, v) - if id > maxID { - maxID = id + 1 - } } pruner := &astPruner{ expr: expr, macroCalls: macroCalls, state: pruneState, - nextExprID: maxID} + nextExprID: getMaxID(expr)} newExpr, _ := pruner.maybePrune(expr) return &exprpb.ParsedExpr{ Expr: newExpr, @@ -463,3 +459,47 @@ func (p *astPruner) nextID() int64 { p.nextExprID++ } } + +func getMaxID(expr *exprpb.Expr) int64 { + maxID := int64(1) + exprs := []*exprpb.Expr{expr} + for len(exprs) != 0 { + e := exprs[0] + if e.GetId() >= maxID { + maxID = e.GetId() + 1 + } + exprs = exprs[1:] + switch e.GetExprKind().(type) { + case *exprpb.Expr_SelectExpr: + exprs = append(exprs, e.GetSelectExpr().GetOperand()) + case *exprpb.Expr_CallExpr: + call := e.GetCallExpr() + if call.GetTarget() != nil { + exprs = append(exprs, call.GetTarget()) + } + exprs = append(exprs, call.GetArgs()...) + case *exprpb.Expr_ComprehensionExpr: + compre := e.GetComprehensionExpr() + exprs = append(exprs, + compre.GetIterRange(), + compre.GetAccuInit(), + compre.GetLoopCondition(), + compre.GetLoopStep(), + compre.GetResult()) + case *exprpb.Expr_ListExpr: + list := e.GetListExpr() + exprs = append(exprs, list.GetElements()...) + case *exprpb.Expr_StructExpr: + for _, entry := range expr.GetStructExpr().GetEntries() { + if entry.GetMapKey() != nil { + exprs = append(exprs, entry.GetMapKey()) + } + exprs = append(exprs, entry.GetValue()) + if entry.GetId() >= maxID { + maxID = entry.GetId() + 1 + } + } + } + } + return maxID +} diff --git a/interpreter/prune_test.go b/interpreter/prune_test.go index fb7855d5..302d0bf5 100644 --- a/interpreter/prune_test.go +++ b/interpreter/prune_test.go @@ -157,6 +157,16 @@ var testCases = []testInfo{ expr: `foo == "bar" && r.attr.loc in ["GB", "US"]`, out: `r.attr.loc in ["GB", "US"]`, }, + { + in: partialActivation(map[string]any{ + "users": []map[string]string{ + {"name": "alice", "role": "EMPLOYEE"}, + {"name": "bob", "role": "MANAGER"}, + {"name": "eve", "role": "CUSTOMER"}, + }}, "r.attr.*"), + expr: `users.filter(u, u.role=="MANAGER").map(u, u.name) == r.attr.authorized["managers"]`, + out: `["bob"] == r.attr.authorized["managers"]`, + }, // TODO: the output of an expression like this relies on either // a) doing replacements on the original macro call, or // b) mutating the macro call tracking data rather than the core