From bd1ec924928544edf70839110c7a28d61ac011a3 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Mon, 9 Dec 2024 14:51:20 -0800 Subject: [PATCH] Fix two-variable comprehension pruning (#1083) * Fix two-variable comprehension pruning * Ensure only cel.bind() comprehensions are pruned --- ext/comprehensions_test.go | 254 ++++++++++++++++++++++++++++++ interpreter/activation.go | 22 +++ interpreter/attribute_patterns.go | 13 +- interpreter/interpretable.go | 23 +++ interpreter/prune.go | 61 +++++-- 5 files changed, 346 insertions(+), 27 deletions(-) diff --git a/ext/comprehensions_test.go b/ext/comprehensions_test.go index 1bd65fa4..b41a3154 100644 --- a/ext/comprehensions_test.go +++ b/ext/comprehensions_test.go @@ -20,6 +20,8 @@ import ( "testing" "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/interpreter" ) func TestTwoVarComprehensions(t *testing.T) { @@ -359,6 +361,258 @@ func TestTwoVarComprehensionsVersion(t *testing.T) { } } +func TestTwoVarComprehensionsUnparse(t *testing.T) { + tests := []struct { + name string + expr string + unparsed string + }{ + { + name: "transform map entry", + expr: `[0, 0u].transformMapEntry(i, v, {v: i})`, + unparsed: `[0, 0u].transformMapEntry(i, v, {v: i})`, + }, + { + name: "transform map", + expr: `{'a': 'world', 'b': 'hello'}.transformMap(i, v, i == 'a' ? v.upperAscii() : v)`, + unparsed: `{"a": "world", "b": "hello"}.transformMap(i, v, (i == "a") ? v.upperAscii() : v)`, + }, + { + name: "transform list", + expr: `[1.0, 2.0, 2.0].transformList(i, v, i / 2.0 == 1.0)`, + unparsed: `[1.0, 2.0, 2.0].transformList(i, v, i / 2.0 == 1.0)`, + }, + { + name: "existsOne", + expr: `{'a': 'b', 'c': 'd'}.existsOne(k, v, k == 'b' || v == 'b')`, + unparsed: `{"a": "b", "c": "d"}.existsOne(k, v, k == "b" || v == "b")`, + }, + { + name: "exists", + expr: `{'a': 'b', 'c': 'd'}.exists(k, v, k == 'b' || v == 'b')`, + unparsed: `{"a": "b", "c": "d"}.exists(k, v, k == "b" || v == "b")`, + }, + { + name: "all", + expr: `[null, null, 'hello', string].all(i, v, i == 0 || type(v) != int)`, + unparsed: `[null, null, "hello", string].all(i, v, i == 0 || type(v) != int)`, + }, + } + env := testCompreEnv(t) + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + ast, iss := env.Parse(tc.expr) + if iss.Err() != nil { + t.Fatalf("env.Parse(%q) failed: %v", tc.expr, iss.Err()) + } + unparsed, err := cel.AstToString(ast) + if err != nil { + t.Fatalf("cel.AstToString() failed: %v", err) + } + if unparsed != tc.unparsed { + t.Errorf("cel.AstToString() got %q, wanted %q", unparsed, tc.unparsed) + } + }) + } +} + +func TestTwoVarComprehensionsResidualAST(t *testing.T) { + tests := []struct { + name string + in map[string]any + varOpts []cel.EnvOption + unks []*interpreter.AttributePattern + expr string + residual string + }{ + { + name: "transform map entry residual compare", + varOpts: []cel.EnvOption{ + cel.Variable("x", cel.ListType(cel.DynType)), + cel.Variable("y", cel.IntType), + }, + in: map[string]any{ + "x": []any{0, uint(1)}, + }, + unks: []*interpreter.AttributePattern{cel.AttributePattern("y")}, + expr: `x.transformMapEntry(i, v, {v: i}).size() < y`, + residual: `2 < y`, + }, + { + name: "transform map entry residual transform", + varOpts: []cel.EnvOption{ + cel.Variable("x", cel.ListType(cel.DynType)), + cel.Variable("y", cel.IntType), + }, + in: map[string]any{ + "x": []any{0, uint(1)}, + }, + unks: []*interpreter.AttributePattern{cel.AttributePattern("y")}, + expr: `x.transformMapEntry(i, v, i < y, {v: i})`, + residual: `[0, 1u].transformMapEntry(i, v, i < y, {v: i})`, + }, + { + name: "nested exists unknown inner range", + varOpts: []cel.EnvOption{ + cel.Variable("x", cel.ListType(cel.IntType)), + cel.Variable("y", cel.MapType(cel.IntType, cel.DynType)), + }, + in: map[string]any{ + "x": []any{1, 2, 3}, + }, + unks: []*interpreter.AttributePattern{cel.AttributePattern("y")}, + expr: `x.exists(val, y.exists(key, _, key == val))`, + residual: `[1, 2, 3].exists(val, y.exists(key, _, key == val))`, + }, + { + name: "nested exists unknown inner range", + varOpts: []cel.EnvOption{ + cel.Variable("x", cel.ListType(cel.IntType)), + cel.Variable("y", cel.MapType(cel.IntType, cel.DynType)), + }, + in: map[string]any{ + "y": map[int]string{1: "hi", 2: "hello", 3: "howdy"}, + }, + unks: []*interpreter.AttributePattern{cel.AttributePattern("x")}, + expr: `x.exists(val, y.exists(key, _, key == val))`, + residual: `x.exists(val, y.exists(key, _, key == val))`, + }, + { + name: "nested exists unknown outer range with extra predicate", + varOpts: []cel.EnvOption{ + cel.Variable("x", cel.ListType(cel.IntType)), + cel.Variable("y", cel.MapType(cel.IntType, cel.DynType)), + }, + in: map[string]any{ + "y": map[int]string{1: "hi", 2: "hello", 3: "howdy"}, + }, + unks: []*interpreter.AttributePattern{cel.AttributePattern("x")}, + expr: `x.exists(val, y.exists(key, _, key == val)) && y.all(key, val, val.startsWith('h'))`, + residual: `x.exists(val, y.exists(key, _, key == val))`, + }, + { + name: "nested exists partial unknown outer range", + varOpts: []cel.EnvOption{ + cel.Variable("x", cel.ListType(cel.IntType)), + cel.Variable("y", cel.MapType(cel.IntType, cel.DynType)), + }, + in: map[string]any{ + "x": []int{42, 0, 43}, + "y": map[int]string{1: "hi", 2: "hello", 3: "howdy"}, + }, + unks: []*interpreter.AttributePattern{cel.AttributePattern("x").QualInt(1)}, + expr: `x.exists(val, y.exists(key, _, key == val)) || x[0] == 0 || x[1] == 1 || x[2] == 2`, + residual: `x.exists(val, y.exists(key, _, key == val)) || x[1] == 1`, + }, + { + name: "nested exists partial unknown outer range with optionals", + varOpts: []cel.EnvOption{ + cel.Variable("x", cel.ListType(cel.IntType)), + cel.Variable("y", cel.MapType(cel.IntType, cel.DynType)), + }, + in: map[string]any{ + "x": []int{42, 0, 43}, + "y": map[int]string{1: "hi", 2: "hello", 3: "howdy"}, + }, + unks: []*interpreter.AttributePattern{cel.AttributePattern("x").QualInt(1)}, + expr: `x.exists(val, y.exists(key, _, key == val)) || (x[?0].hasValue() && x[?1].hasValue())`, + residual: `x.exists(val, y.exists(key, _, key == val)) || x[?1].hasValue()`, + }, + { + name: "inner value partial unknown two-var", + varOpts: []cel.EnvOption{ + cel.Variable("x", cel.ListType(cel.StringType)), + cel.Variable("y", cel.MapType(cel.IntType, cel.StringType)), + }, + in: map[string]any{ + "x": []string{"howdy", "hello", "hi"}, + "y": map[int]string{0: "hi", 1: "hello", 2: "howdy"}, + }, + unks: []*interpreter.AttributePattern{cel.AttributePattern("y").QualInt(1)}, + expr: `x.exists(key, val, y[?key] == optional.of(val))`, + residual: `["howdy", "hello", "hi"].exists(key, val, y[?key] == optional.of(val))`, + }, + { + name: "inner value partial unknown one-var", + varOpts: []cel.EnvOption{ + cel.Variable("x", cel.ListType(cel.StringType)), + cel.Variable("y", cel.MapType(cel.IntType, cel.StringType)), + }, + in: map[string]any{ + "x": []string{"howdy"}, + "y": map[int]string{0: "hello"}, + }, + unks: []*interpreter.AttributePattern{cel.AttributePattern("x").QualInt(0)}, + expr: `y.exists(key, y[?key] == x[?key])`, + residual: `{0: "hello"}.exists(key, y[?key] == x[?key])`, + }, + { + name: "simple bind", + varOpts: []cel.EnvOption{ + cel.Variable("y", cel.MapType(cel.IntType, cel.StringType)), + }, + in: map[string]any{ + "y": map[int]string{0: "hi", 1: "hello", 2: "howdy"}, + }, + unks: []*interpreter.AttributePattern{cel.AttributePattern("y").QualInt(1)}, + expr: `cel.bind(z, y[0], z + y[1])`, + residual: `cel.bind(z, "hi", "hi" + y[1])`, + }, + { + name: "bind with comprehension", + varOpts: []cel.EnvOption{ + cel.Variable("x", cel.ListType(cel.StringType)), + cel.Variable("y", cel.MapType(cel.IntType, cel.StringType)), + }, + in: map[string]any{ + "x": []string{"hi", "hello", "howdy"}, + "y": map[int]string{0: "hi", 1: "hello", 2: "howdy"}, + }, + unks: []*interpreter.AttributePattern{cel.AttributePattern("y").QualInt(1)}, + expr: `cel.bind(z, y[0], x.all(i, val, val == z || optional.of(val) == y[?i]))`, + residual: `cel.bind(z, "hi", ["hi", "hello", "howdy"].all(i, val, val == z || optional.of(val) == y[?i]))`, + }, + } + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + env := testCompreEnv(t, tc.varOpts...) + ast, iss := env.Compile(tc.expr) + if iss.Err() != nil { + t.Fatalf("env.Compile(%q) failed: %v", tc.expr, iss.Err()) + } + prg, err := env.Program(ast, + cel.EvalOptions(cel.OptTrackState, cel.OptPartialEval)) + if err != nil { + t.Fatalf("env.Program() failed: %v", err) + } + unkVars, err := cel.PartialVars(tc.in, tc.unks...) + if err != nil { + t.Fatalf("PartialVars() failed: %v", err) + } + out, det, err := prg.Eval(unkVars) + if !types.IsUnknown(out) { + t.Fatalf("got %v, expected unknown", out) + } + if err != nil { + t.Fatalf("prg.Eval() failed: %v", err) + } + residual, err := env.ResidualAst(ast, det) + if err != nil { + t.Fatalf("env.ResidualAst() failed: %v", err) + } + expr, err := cel.AstToString(residual) + if err != nil { + t.Fatalf("cel.AstToString() failed: %v", err) + } + if expr != tc.residual { + t.Errorf("got expr: %s, wanted %s", expr, tc.residual) + } + }) + } +} + func testCompreEnv(t *testing.T, opts ...cel.EnvOption) *cel.Env { t.Helper() baseOpts := []cel.EnvOption{ diff --git a/interpreter/activation.go b/interpreter/activation.go index 1577f359..c20d19de 100644 --- a/interpreter/activation.go +++ b/interpreter/activation.go @@ -156,6 +156,11 @@ type PartialActivation interface { UnknownAttributePatterns() []*AttributePattern } +// partialActivationConverter indicates whether an Activation implementation supports conversion to a PartialActivation +type partialActivationConverter interface { + asPartialActivation() (PartialActivation, bool) +} + // partActivation is the default implementations of the PartialActivation interface. type partActivation struct { Activation @@ -166,3 +171,20 @@ type partActivation struct { func (a *partActivation) UnknownAttributePatterns() []*AttributePattern { return a.unknowns } + +// asPartialActivation returns the partActivation as a PartialActivation interface. +func (a *partActivation) asPartialActivation() (PartialActivation, bool) { + return a, true +} + +func asPartialActivation(vars Activation) (PartialActivation, bool) { + // Only internal activation instances may implement this interface + if pv, ok := vars.(partialActivationConverter); ok { + return pv.asPartialActivation() + } + // Since Activations may be hierarchical, test whether a parent converts to a PartialActivation + if vars.Parent() != nil { + return asPartialActivation(vars.Parent()) + } + return nil, false +} diff --git a/interpreter/attribute_patterns.go b/interpreter/attribute_patterns.go index 8f19bde7..7e5c2db0 100644 --- a/interpreter/attribute_patterns.go +++ b/interpreter/attribute_patterns.go @@ -358,7 +358,7 @@ func (m *attributeMatcher) AddQualifier(qual Qualifier) (Attribute, error) { func (m *attributeMatcher) Resolve(vars Activation) (any, error) { id := m.NamespacedAttribute.ID() // Bug in how partial activation is resolved, should search parents as well. - partial, isPartial := toPartialActivation(vars) + partial, isPartial := asPartialActivation(vars) if isPartial { unk, err := m.fac.matchesUnknownPatterns( partial, @@ -384,14 +384,3 @@ func (m *attributeMatcher) Qualify(vars Activation, obj any) (any, error) { func (m *attributeMatcher) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { return attrQualifyIfPresent(m.fac, vars, obj, m, presenceOnly) } - -func toPartialActivation(vars Activation) (PartialActivation, bool) { - pv, ok := vars.(PartialActivation) - if ok { - return pv, true - } - if vars.Parent() != nil { - return toPartialActivation(vars.Parent()) - } - return nil, false -} diff --git a/interpreter/interpretable.go b/interpreter/interpretable.go index b7d2db00..591b7688 100644 --- a/interpreter/interpretable.go +++ b/interpreter/interpretable.go @@ -762,6 +762,9 @@ func (fold *evalFold) Eval(ctx Activation) ref.Val { defer releaseFolder(f) foldRange := fold.iterRange.Eval(ctx) + if types.IsUnknownOrError(foldRange) { + return foldRange + } if fold.iterVar2 != "" { var foldable traits.Foldable switch r := foldRange.(type) { @@ -1363,6 +1366,26 @@ func (f *folder) Parent() Activation { return f.activation } +// UnknownAttributePatterns implements the PartialActivation interface returning the unknown patterns +// if they were provided to the input activation, or an empty set if the proxied activation is not partial. +func (f *folder) UnknownAttributePatterns() []*AttributePattern { + if pv, ok := f.activation.(partialActivationConverter); ok { + if partial, isPartial := pv.asPartialActivation(); isPartial { + return partial.UnknownAttributePatterns() + } + } + return []*AttributePattern{} +} + +func (f *folder) asPartialActivation() (PartialActivation, bool) { + if pv, ok := f.activation.(partialActivationConverter); ok { + if _, isPartial := pv.asPartialActivation(); isPartial { + return f, true + } + } + return nil, false +} + // evalResult computes the final result of the fold after all entries have been folded and accumulated. func (f *folder) evalResult() ref.Val { f.computeResult = true diff --git a/interpreter/prune.go b/interpreter/prune.go index 410d80dc..d3efa7f0 100644 --- a/interpreter/prune.go +++ b/interpreter/prune.go @@ -281,13 +281,29 @@ func (p *astPruner) prune(node ast.Expr) (ast.Expr, bool) { } if macro, found := p.macroCalls[node.ID()]; found { // Ensure that intermediate values for the comprehension are cleared during pruning + pruneMacroCall := node.Kind() != ast.UnspecifiedExprKind if node.Kind() == ast.ComprehensionKind { - compre := node.AsComprehension() - visit(macro, clearIterVarVisitor(compre.IterVar(), p.state)) + // Only prune cel.bind() calls since the variables of the comprehension are all + // visible to the user, so there's no chance of an incorrect value being observed + // as a result of looking at intermediate computations within a comprehension. + pruneMacroCall = isCelBindMacro(macro) } - // prune the expression in terms of the macro call instead of the expanded form. - if newMacro, pruned := p.prune(macro); pruned { - p.macroCalls[node.ID()] = newMacro + if pruneMacroCall { + // prune the expression in terms of the macro call instead of the expanded form when + // dealing with macro call tracking references. + if newMacro, pruned := p.prune(macro); pruned { + p.macroCalls[node.ID()] = newMacro + } + } else { + // Otherwise just prune the macro target in keeping with the pruning behavior of the + // comprehensions later in the call graph. + macroCall := macro.AsCall() + if macroCall.Target() != nil { + if newTarget, pruned := p.prune(macroCall.Target()); pruned { + macro = p.NewMemberCall(macro.ID(), macroCall.FunctionName(), newTarget, macroCall.Args()...) + p.macroCalls[node.ID()] = macro + } + } } } @@ -421,6 +437,19 @@ func (p *astPruner) prune(node ast.Expr) (ast.Expr, bool) { // the last iteration of the comprehension and not each step in the evaluation which // means that the any residuals computed in between might be inaccurate. if newRange, pruned := p.maybePrune(compre.IterRange()); pruned { + if compre.HasIterVar2() { + return p.NewComprehensionTwoVar( + node.ID(), + newRange, + compre.IterVar(), + compre.IterVar2(), + compre.AccuVar(), + compre.AccuInit(), + compre.LoopCondition(), + compre.LoopStep(), + compre.Result(), + ), true + } return p.NewComprehension( node.ID(), newRange, @@ -468,16 +497,6 @@ func getMaxID(expr ast.Expr) int64 { return maxID } -func clearIterVarVisitor(varName string, state EvalState) astVisitor { - return astVisitor{ - visitExpr: func(e ast.Expr) { - if e.Kind() == ast.IdentKind && e.AsIdent() == varName { - state.SetValue(e.ID(), nil) - } - }, - } -} - func maxIDVisitor(maxID *int64) astVisitor { return astVisitor{ visitExpr: func(e ast.Expr) { @@ -541,3 +560,15 @@ func visit(expr ast.Expr, visitor astVisitor) { } } } + +func isCelBindMacro(macro ast.Expr) bool { + if macro.Kind() != ast.CallKind { + return false + } + macroCall := macro.AsCall() + target := macroCall.Target() + return macroCall.FunctionName() == "bind" && + macroCall.IsMemberFunction() && + target.Kind() == ast.IdentKind && + target.AsIdent() == "cel" +}