From 2e6773191ac434ccfa822c2581d45491f062e239 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Wed, 4 Dec 2024 18:52:34 -0800 Subject: [PATCH] Fix partial evaluation with the new folder objects used with comprehensions (#1084) --- cel/cel_test.go | 97 ++++++++++++++++++++++++------------ interpreter/interpretable.go | 19 ++++--- 2 files changed, 77 insertions(+), 39 deletions(-) diff --git a/cel/cel_test.go b/cel/cel_test.go index 304007db..2de42461 100644 --- a/cel/cel_test.go +++ b/cel/cel_test.go @@ -1159,39 +1159,72 @@ func TestResidualAstComplex(t *testing.T) { } func TestResidualAstMacros(t *testing.T) { - env := testEnv(t, - Variable("x", ListType(IntType)), - Variable("y", IntType), - EnableMacroCallTracking(), - ) - unkVars, _ := PartialVars(map[string]any{"y": 11}, AttributePattern("x")) - ast, iss := env.Compile(`x.exists(i, i < 10) && [11, 12, 13].all(i, i in [y, 12, 13])`) - if iss.Err() != nil { - t.Fatalf("env.Compile() failed: %v", iss.Err()) - } - prg, err := env.Program(ast, - EvalOptions(OptTrackState, OptPartialEval), - ) - if err != nil { - t.Fatalf("env.Program() failed: %v", err) - } - out, det, err := prg.Eval(unkVars) - if !types.IsUnknown(out) { - t.Fatalf("got %v, expected unknown", out) - } - if err != nil { - t.Fatal(err) - } - residual, err := env.ResidualAst(ast, det) - if err != nil { - t.Fatal(err) - } - expr, err := AstToString(residual) - if err != nil { - t.Fatal(err) + tests := []struct { + env *Env + in map[string]any + unks []*interpreter.AttributePattern + expr string + residual string + }{ + { + env: testEnv(t, + Variable("x", ListType(IntType)), + Variable("y", IntType), + EnableMacroCallTracking()), + in: map[string]any{"y": 11}, + unks: []*interpreter.AttributePattern{AttributePattern("x")}, + expr: `x.exists(i, i < 10) && [11, 12, 13].all(i, i in [y, 12, 13])`, + residual: `x.exists(i, i < 10)`, + }, + { + env: testEnv(t, + Variable("bar", MapType(StringType, DynType)), + Variable("foo", MapType(StringType, DynType)), + EnableMacroCallTracking()), + in: map[string]any{"foo": map[string]any{"a": "b"}}, + unks: []*interpreter.AttributePattern{ + AttributePattern("bar").QualString("baz").Wildcard(), + }, + expr: `foo.exists(t, t == bar.baz.x)`, + residual: `{"a": "b"}.exists(t, t == bar.baz.x)`, + }, } - if expr != "x.exists(i, i < 10)" { - t.Errorf("got expr: %s, wanted x.exists(i, i < 10)", expr) + + for i, tst := range tests { + tc := tst + t.Run(fmt.Sprintf("[%d]", i), func(t *testing.T) { + env := tc.env + unkVars, err := PartialVars(tc.in, tc.unks...) + if err != nil { + t.Fatalf("PartialVars() failed: %v", err) + } + ast, iss := env.Compile(tc.expr) + if iss.Err() != nil { + t.Fatalf("env.Compile() failed: %v", iss.Err()) + } + prg, err := env.Program(ast, EvalOptions(OptTrackState, OptPartialEval)) + if err != nil { + t.Fatalf("env.Program() failed: %v", err) + } + out, det, err := prg.Eval(unkVars) + if !types.IsUnknown(out) { + t.Fatalf("got %v, expected unknown", out) + } + if err != nil { + t.Fatalf("prg.Eval() failed: %v", err) + } + residual, err := env.ResidualAst(ast, det) + if err != nil { + t.Fatalf("env.ResidualAst() failed: %v", err) + } + expr, err := AstToString(residual) + if err != nil { + t.Fatalf("AstToString() failed: %v", err) + } + if expr != tc.residual { + t.Errorf("got expr: %s, wanted %s", expr, tc.residual) + } + }) } } diff --git a/interpreter/interpretable.go b/interpreter/interpretable.go index ebc432e9..b7d2db00 100644 --- a/interpreter/interpretable.go +++ b/interpreter/interpretable.go @@ -1241,7 +1241,7 @@ func invalidOptionalElementInit(value ref.Val) ref.Val { func newFolder(eval *evalFold, ctx Activation) *folder { f := folderPool.Get().(*folder) f.evalFold = eval - f.Activation = ctx + f.activation = ctx return f } @@ -1262,7 +1262,7 @@ func releaseFolder(f *folder) { // cel.bind or cel.@block. type folder struct { *evalFold - Activation + activation Activation // fold state objects. accuVal ref.Val @@ -1290,7 +1290,7 @@ func (f *folder) foldIterable(iterable traits.Iterable) ref.Val { // Update the accumulation value and check for eval interuption. f.accuVal = f.step.Eval(f) f.initialized = true - if f.interruptable && checkInterrupt(f.Activation) { + if f.interruptable && checkInterrupt(f.activation) { f.interrupted = true return f.evalResult() } @@ -1316,7 +1316,7 @@ func (f *folder) FoldEntry(key, val any) bool { // Update the accumulation value and check for eval interuption. f.accuVal = f.step.Eval(f) f.initialized = true - if f.interruptable && checkInterrupt(f.Activation) { + if f.interruptable && checkInterrupt(f.activation) { f.interrupted = true return false } @@ -1330,7 +1330,7 @@ func (f *folder) ResolveName(name string) (any, bool) { if name == f.accuVar { if !f.initialized { f.initialized = true - initVal := f.accu.Eval(f.Activation) + initVal := f.accu.Eval(f.activation) if !f.exhaustive { if l, isList := initVal.(traits.Lister); isList && l.Size() == types.IntZero { initVal = types.NewMutableList(f.adapter) @@ -1355,7 +1355,12 @@ func (f *folder) ResolveName(name string) (any, bool) { return f.iterVar2Val, true } } - return f.Activation.ResolveName(name) + return f.activation.ResolveName(name) +} + +// Parent returns the activation embedded into the folder. +func (f *folder) Parent() Activation { + return f.activation } // evalResult computes the final result of the fold after all entries have been folded and accumulated. @@ -1381,7 +1386,7 @@ func (f *folder) evalResult() ref.Val { // reset clears any state associated with folder evaluation. func (f *folder) reset() { f.evalFold = nil - f.Activation = nil + f.activation = nil f.accuVal = nil f.iterVar1Val = nil f.iterVar2Val = nil