Skip to content

Commit

Permalink
Fix partial evaluation with the new folder objects used with comprehe…
Browse files Browse the repository at this point in the history
…nsions (#1084)
  • Loading branch information
TristonianJones authored Dec 5, 2024
1 parent 5910569 commit 2e67731
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 39 deletions.
97 changes: 65 additions & 32 deletions cel/cel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1159,39 +1159,72 @@ func TestResidualAstComplex(t *testing.T) {
}

func TestResidualAstMacros(t *testing.T) {
env := testEnv(t,
Variable("x", ListType(IntType)),
Variable("y", IntType),
EnableMacroCallTracking(),
)
unkVars, _ := PartialVars(map[string]any{"y": 11}, AttributePattern("x"))
ast, iss := env.Compile(`x.exists(i, i < 10) && [11, 12, 13].all(i, i in [y, 12, 13])`)
if iss.Err() != nil {
t.Fatalf("env.Compile() failed: %v", iss.Err())
}
prg, err := env.Program(ast,
EvalOptions(OptTrackState, OptPartialEval),
)
if err != nil {
t.Fatalf("env.Program() failed: %v", err)
}
out, det, err := prg.Eval(unkVars)
if !types.IsUnknown(out) {
t.Fatalf("got %v, expected unknown", out)
}
if err != nil {
t.Fatal(err)
}
residual, err := env.ResidualAst(ast, det)
if err != nil {
t.Fatal(err)
}
expr, err := AstToString(residual)
if err != nil {
t.Fatal(err)
tests := []struct {
env *Env
in map[string]any
unks []*interpreter.AttributePattern
expr string
residual string
}{
{
env: testEnv(t,
Variable("x", ListType(IntType)),
Variable("y", IntType),
EnableMacroCallTracking()),
in: map[string]any{"y": 11},
unks: []*interpreter.AttributePattern{AttributePattern("x")},
expr: `x.exists(i, i < 10) && [11, 12, 13].all(i, i in [y, 12, 13])`,
residual: `x.exists(i, i < 10)`,
},
{
env: testEnv(t,
Variable("bar", MapType(StringType, DynType)),
Variable("foo", MapType(StringType, DynType)),
EnableMacroCallTracking()),
in: map[string]any{"foo": map[string]any{"a": "b"}},
unks: []*interpreter.AttributePattern{
AttributePattern("bar").QualString("baz").Wildcard(),
},
expr: `foo.exists(t, t == bar.baz.x)`,
residual: `{"a": "b"}.exists(t, t == bar.baz.x)`,
},
}
if expr != "x.exists(i, i < 10)" {
t.Errorf("got expr: %s, wanted x.exists(i, i < 10)", expr)

for i, tst := range tests {
tc := tst
t.Run(fmt.Sprintf("[%d]", i), func(t *testing.T) {
env := tc.env
unkVars, err := PartialVars(tc.in, tc.unks...)
if err != nil {
t.Fatalf("PartialVars() failed: %v", err)
}
ast, iss := env.Compile(tc.expr)
if iss.Err() != nil {
t.Fatalf("env.Compile() failed: %v", iss.Err())
}
prg, err := env.Program(ast, EvalOptions(OptTrackState, OptPartialEval))
if err != nil {
t.Fatalf("env.Program() failed: %v", err)
}
out, det, err := prg.Eval(unkVars)
if !types.IsUnknown(out) {
t.Fatalf("got %v, expected unknown", out)
}
if err != nil {
t.Fatalf("prg.Eval() failed: %v", err)
}
residual, err := env.ResidualAst(ast, det)
if err != nil {
t.Fatalf("env.ResidualAst() failed: %v", err)
}
expr, err := AstToString(residual)
if err != nil {
t.Fatalf("AstToString() failed: %v", err)
}
if expr != tc.residual {
t.Errorf("got expr: %s, wanted %s", expr, tc.residual)
}
})
}
}

Expand Down
19 changes: 12 additions & 7 deletions interpreter/interpretable.go
Original file line number Diff line number Diff line change
Expand Up @@ -1241,7 +1241,7 @@ func invalidOptionalElementInit(value ref.Val) ref.Val {
func newFolder(eval *evalFold, ctx Activation) *folder {
f := folderPool.Get().(*folder)
f.evalFold = eval
f.Activation = ctx
f.activation = ctx
return f
}

Expand All @@ -1262,7 +1262,7 @@ func releaseFolder(f *folder) {
// cel.bind or cel.@block.
type folder struct {
*evalFold
Activation
activation Activation

// fold state objects.
accuVal ref.Val
Expand Down Expand Up @@ -1290,7 +1290,7 @@ func (f *folder) foldIterable(iterable traits.Iterable) ref.Val {
// Update the accumulation value and check for eval interuption.
f.accuVal = f.step.Eval(f)
f.initialized = true
if f.interruptable && checkInterrupt(f.Activation) {
if f.interruptable && checkInterrupt(f.activation) {
f.interrupted = true
return f.evalResult()
}
Expand All @@ -1316,7 +1316,7 @@ func (f *folder) FoldEntry(key, val any) bool {
// Update the accumulation value and check for eval interuption.
f.accuVal = f.step.Eval(f)
f.initialized = true
if f.interruptable && checkInterrupt(f.Activation) {
if f.interruptable && checkInterrupt(f.activation) {
f.interrupted = true
return false
}
Expand All @@ -1330,7 +1330,7 @@ func (f *folder) ResolveName(name string) (any, bool) {
if name == f.accuVar {
if !f.initialized {
f.initialized = true
initVal := f.accu.Eval(f.Activation)
initVal := f.accu.Eval(f.activation)
if !f.exhaustive {
if l, isList := initVal.(traits.Lister); isList && l.Size() == types.IntZero {
initVal = types.NewMutableList(f.adapter)
Expand All @@ -1355,7 +1355,12 @@ func (f *folder) ResolveName(name string) (any, bool) {
return f.iterVar2Val, true
}
}
return f.Activation.ResolveName(name)
return f.activation.ResolveName(name)
}

// Parent returns the activation embedded into the folder.
func (f *folder) Parent() Activation {
return f.activation
}

// evalResult computes the final result of the fold after all entries have been folded and accumulated.
Expand All @@ -1381,7 +1386,7 @@ func (f *folder) evalResult() ref.Val {
// reset clears any state associated with folder evaluation.
func (f *folder) reset() {
f.evalFold = nil
f.Activation = nil
f.activation = nil
f.accuVal = nil
f.iterVar1Val = nil
f.iterVar2Val = nil
Expand Down

0 comments on commit 2e67731

Please sign in to comment.