Skip to content

Commit

Permalink
Fix two-variable comprehension pruning (#1083)
Browse files Browse the repository at this point in the history
* Fix two-variable comprehension pruning
* Ensure only cel.bind() comprehensions are pruned
  • Loading branch information
TristonianJones authored Dec 9, 2024
1 parent 6202a67 commit bd1ec92
Show file tree
Hide file tree
Showing 5 changed files with 346 additions and 27 deletions.
254 changes: 254 additions & 0 deletions ext/comprehensions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"testing"

"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/interpreter"
)

func TestTwoVarComprehensions(t *testing.T) {
Expand Down Expand Up @@ -359,6 +361,258 @@ func TestTwoVarComprehensionsVersion(t *testing.T) {
}
}

func TestTwoVarComprehensionsUnparse(t *testing.T) {
tests := []struct {
name string
expr string
unparsed string
}{
{
name: "transform map entry",
expr: `[0, 0u].transformMapEntry(i, v, {v: i})`,
unparsed: `[0, 0u].transformMapEntry(i, v, {v: i})`,
},
{
name: "transform map",
expr: `{'a': 'world', 'b': 'hello'}.transformMap(i, v, i == 'a' ? v.upperAscii() : v)`,
unparsed: `{"a": "world", "b": "hello"}.transformMap(i, v, (i == "a") ? v.upperAscii() : v)`,
},
{
name: "transform list",
expr: `[1.0, 2.0, 2.0].transformList(i, v, i / 2.0 == 1.0)`,
unparsed: `[1.0, 2.0, 2.0].transformList(i, v, i / 2.0 == 1.0)`,
},
{
name: "existsOne",
expr: `{'a': 'b', 'c': 'd'}.existsOne(k, v, k == 'b' || v == 'b')`,
unparsed: `{"a": "b", "c": "d"}.existsOne(k, v, k == "b" || v == "b")`,
},
{
name: "exists",
expr: `{'a': 'b', 'c': 'd'}.exists(k, v, k == 'b' || v == 'b')`,
unparsed: `{"a": "b", "c": "d"}.exists(k, v, k == "b" || v == "b")`,
},
{
name: "all",
expr: `[null, null, 'hello', string].all(i, v, i == 0 || type(v) != int)`,
unparsed: `[null, null, "hello", string].all(i, v, i == 0 || type(v) != int)`,
},
}
env := testCompreEnv(t)
for _, tst := range tests {
tc := tst
t.Run(tc.name, func(t *testing.T) {
ast, iss := env.Parse(tc.expr)
if iss.Err() != nil {
t.Fatalf("env.Parse(%q) failed: %v", tc.expr, iss.Err())
}
unparsed, err := cel.AstToString(ast)
if err != nil {
t.Fatalf("cel.AstToString() failed: %v", err)
}
if unparsed != tc.unparsed {
t.Errorf("cel.AstToString() got %q, wanted %q", unparsed, tc.unparsed)
}
})
}
}

func TestTwoVarComprehensionsResidualAST(t *testing.T) {
tests := []struct {
name string
in map[string]any
varOpts []cel.EnvOption
unks []*interpreter.AttributePattern
expr string
residual string
}{
{
name: "transform map entry residual compare",
varOpts: []cel.EnvOption{
cel.Variable("x", cel.ListType(cel.DynType)),
cel.Variable("y", cel.IntType),
},
in: map[string]any{
"x": []any{0, uint(1)},
},
unks: []*interpreter.AttributePattern{cel.AttributePattern("y")},
expr: `x.transformMapEntry(i, v, {v: i}).size() < y`,
residual: `2 < y`,
},
{
name: "transform map entry residual transform",
varOpts: []cel.EnvOption{
cel.Variable("x", cel.ListType(cel.DynType)),
cel.Variable("y", cel.IntType),
},
in: map[string]any{
"x": []any{0, uint(1)},
},
unks: []*interpreter.AttributePattern{cel.AttributePattern("y")},
expr: `x.transformMapEntry(i, v, i < y, {v: i})`,
residual: `[0, 1u].transformMapEntry(i, v, i < y, {v: i})`,
},
{
name: "nested exists unknown inner range",
varOpts: []cel.EnvOption{
cel.Variable("x", cel.ListType(cel.IntType)),
cel.Variable("y", cel.MapType(cel.IntType, cel.DynType)),
},
in: map[string]any{
"x": []any{1, 2, 3},
},
unks: []*interpreter.AttributePattern{cel.AttributePattern("y")},
expr: `x.exists(val, y.exists(key, _, key == val))`,
residual: `[1, 2, 3].exists(val, y.exists(key, _, key == val))`,
},
{
name: "nested exists unknown inner range",
varOpts: []cel.EnvOption{
cel.Variable("x", cel.ListType(cel.IntType)),
cel.Variable("y", cel.MapType(cel.IntType, cel.DynType)),
},
in: map[string]any{
"y": map[int]string{1: "hi", 2: "hello", 3: "howdy"},
},
unks: []*interpreter.AttributePattern{cel.AttributePattern("x")},
expr: `x.exists(val, y.exists(key, _, key == val))`,
residual: `x.exists(val, y.exists(key, _, key == val))`,
},
{
name: "nested exists unknown outer range with extra predicate",
varOpts: []cel.EnvOption{
cel.Variable("x", cel.ListType(cel.IntType)),
cel.Variable("y", cel.MapType(cel.IntType, cel.DynType)),
},
in: map[string]any{
"y": map[int]string{1: "hi", 2: "hello", 3: "howdy"},
},
unks: []*interpreter.AttributePattern{cel.AttributePattern("x")},
expr: `x.exists(val, y.exists(key, _, key == val)) && y.all(key, val, val.startsWith('h'))`,
residual: `x.exists(val, y.exists(key, _, key == val))`,
},
{
name: "nested exists partial unknown outer range",
varOpts: []cel.EnvOption{
cel.Variable("x", cel.ListType(cel.IntType)),
cel.Variable("y", cel.MapType(cel.IntType, cel.DynType)),
},
in: map[string]any{
"x": []int{42, 0, 43},
"y": map[int]string{1: "hi", 2: "hello", 3: "howdy"},
},
unks: []*interpreter.AttributePattern{cel.AttributePattern("x").QualInt(1)},
expr: `x.exists(val, y.exists(key, _, key == val)) || x[0] == 0 || x[1] == 1 || x[2] == 2`,
residual: `x.exists(val, y.exists(key, _, key == val)) || x[1] == 1`,
},
{
name: "nested exists partial unknown outer range with optionals",
varOpts: []cel.EnvOption{
cel.Variable("x", cel.ListType(cel.IntType)),
cel.Variable("y", cel.MapType(cel.IntType, cel.DynType)),
},
in: map[string]any{
"x": []int{42, 0, 43},
"y": map[int]string{1: "hi", 2: "hello", 3: "howdy"},
},
unks: []*interpreter.AttributePattern{cel.AttributePattern("x").QualInt(1)},
expr: `x.exists(val, y.exists(key, _, key == val)) || (x[?0].hasValue() && x[?1].hasValue())`,
residual: `x.exists(val, y.exists(key, _, key == val)) || x[?1].hasValue()`,
},
{
name: "inner value partial unknown two-var",
varOpts: []cel.EnvOption{
cel.Variable("x", cel.ListType(cel.StringType)),
cel.Variable("y", cel.MapType(cel.IntType, cel.StringType)),
},
in: map[string]any{
"x": []string{"howdy", "hello", "hi"},
"y": map[int]string{0: "hi", 1: "hello", 2: "howdy"},
},
unks: []*interpreter.AttributePattern{cel.AttributePattern("y").QualInt(1)},
expr: `x.exists(key, val, y[?key] == optional.of(val))`,
residual: `["howdy", "hello", "hi"].exists(key, val, y[?key] == optional.of(val))`,
},
{
name: "inner value partial unknown one-var",
varOpts: []cel.EnvOption{
cel.Variable("x", cel.ListType(cel.StringType)),
cel.Variable("y", cel.MapType(cel.IntType, cel.StringType)),
},
in: map[string]any{
"x": []string{"howdy"},
"y": map[int]string{0: "hello"},
},
unks: []*interpreter.AttributePattern{cel.AttributePattern("x").QualInt(0)},
expr: `y.exists(key, y[?key] == x[?key])`,
residual: `{0: "hello"}.exists(key, y[?key] == x[?key])`,
},
{
name: "simple bind",
varOpts: []cel.EnvOption{
cel.Variable("y", cel.MapType(cel.IntType, cel.StringType)),
},
in: map[string]any{
"y": map[int]string{0: "hi", 1: "hello", 2: "howdy"},
},
unks: []*interpreter.AttributePattern{cel.AttributePattern("y").QualInt(1)},
expr: `cel.bind(z, y[0], z + y[1])`,
residual: `cel.bind(z, "hi", "hi" + y[1])`,
},
{
name: "bind with comprehension",
varOpts: []cel.EnvOption{
cel.Variable("x", cel.ListType(cel.StringType)),
cel.Variable("y", cel.MapType(cel.IntType, cel.StringType)),
},
in: map[string]any{
"x": []string{"hi", "hello", "howdy"},
"y": map[int]string{0: "hi", 1: "hello", 2: "howdy"},
},
unks: []*interpreter.AttributePattern{cel.AttributePattern("y").QualInt(1)},
expr: `cel.bind(z, y[0], x.all(i, val, val == z || optional.of(val) == y[?i]))`,
residual: `cel.bind(z, "hi", ["hi", "hello", "howdy"].all(i, val, val == z || optional.of(val) == y[?i]))`,
},
}
for _, tst := range tests {
tc := tst
t.Run(tc.name, func(t *testing.T) {
env := testCompreEnv(t, tc.varOpts...)
ast, iss := env.Compile(tc.expr)
if iss.Err() != nil {
t.Fatalf("env.Compile(%q) failed: %v", tc.expr, iss.Err())
}
prg, err := env.Program(ast,
cel.EvalOptions(cel.OptTrackState, cel.OptPartialEval))
if err != nil {
t.Fatalf("env.Program() failed: %v", err)
}
unkVars, err := cel.PartialVars(tc.in, tc.unks...)
if err != nil {
t.Fatalf("PartialVars() failed: %v", err)
}
out, det, err := prg.Eval(unkVars)
if !types.IsUnknown(out) {
t.Fatalf("got %v, expected unknown", out)
}
if err != nil {
t.Fatalf("prg.Eval() failed: %v", err)
}
residual, err := env.ResidualAst(ast, det)
if err != nil {
t.Fatalf("env.ResidualAst() failed: %v", err)
}
expr, err := cel.AstToString(residual)
if err != nil {
t.Fatalf("cel.AstToString() failed: %v", err)
}
if expr != tc.residual {
t.Errorf("got expr: %s, wanted %s", expr, tc.residual)
}
})
}
}

func testCompreEnv(t *testing.T, opts ...cel.EnvOption) *cel.Env {
t.Helper()
baseOpts := []cel.EnvOption{
Expand Down
22 changes: 22 additions & 0 deletions interpreter/activation.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,11 @@ type PartialActivation interface {
UnknownAttributePatterns() []*AttributePattern
}

// partialActivationConverter indicates whether an Activation implementation supports conversion to a PartialActivation
type partialActivationConverter interface {
asPartialActivation() (PartialActivation, bool)
}

// partActivation is the default implementations of the PartialActivation interface.
type partActivation struct {
Activation
Expand All @@ -166,3 +171,20 @@ type partActivation struct {
func (a *partActivation) UnknownAttributePatterns() []*AttributePattern {
return a.unknowns
}

// asPartialActivation returns the partActivation as a PartialActivation interface.
func (a *partActivation) asPartialActivation() (PartialActivation, bool) {
return a, true
}

func asPartialActivation(vars Activation) (PartialActivation, bool) {
// Only internal activation instances may implement this interface
if pv, ok := vars.(partialActivationConverter); ok {
return pv.asPartialActivation()
}
// Since Activations may be hierarchical, test whether a parent converts to a PartialActivation
if vars.Parent() != nil {
return asPartialActivation(vars.Parent())
}
return nil, false
}
13 changes: 1 addition & 12 deletions interpreter/attribute_patterns.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ func (m *attributeMatcher) AddQualifier(qual Qualifier) (Attribute, error) {
func (m *attributeMatcher) Resolve(vars Activation) (any, error) {
id := m.NamespacedAttribute.ID()
// Bug in how partial activation is resolved, should search parents as well.
partial, isPartial := toPartialActivation(vars)
partial, isPartial := asPartialActivation(vars)
if isPartial {
unk, err := m.fac.matchesUnknownPatterns(
partial,
Expand All @@ -384,14 +384,3 @@ func (m *attributeMatcher) Qualify(vars Activation, obj any) (any, error) {
func (m *attributeMatcher) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) {
return attrQualifyIfPresent(m.fac, vars, obj, m, presenceOnly)
}

func toPartialActivation(vars Activation) (PartialActivation, bool) {
pv, ok := vars.(PartialActivation)
if ok {
return pv, true
}
if vars.Parent() != nil {
return toPartialActivation(vars.Parent())
}
return nil, false
}
23 changes: 23 additions & 0 deletions interpreter/interpretable.go
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,9 @@ func (fold *evalFold) Eval(ctx Activation) ref.Val {
defer releaseFolder(f)

foldRange := fold.iterRange.Eval(ctx)
if types.IsUnknownOrError(foldRange) {
return foldRange
}
if fold.iterVar2 != "" {
var foldable traits.Foldable
switch r := foldRange.(type) {
Expand Down Expand Up @@ -1363,6 +1366,26 @@ func (f *folder) Parent() Activation {
return f.activation
}

// UnknownAttributePatterns implements the PartialActivation interface returning the unknown patterns
// if they were provided to the input activation, or an empty set if the proxied activation is not partial.
func (f *folder) UnknownAttributePatterns() []*AttributePattern {
if pv, ok := f.activation.(partialActivationConverter); ok {
if partial, isPartial := pv.asPartialActivation(); isPartial {
return partial.UnknownAttributePatterns()
}
}
return []*AttributePattern{}
}

func (f *folder) asPartialActivation() (PartialActivation, bool) {
if pv, ok := f.activation.(partialActivationConverter); ok {
if _, isPartial := pv.asPartialActivation(); isPartial {
return f, true
}
}
return nil, false
}

// evalResult computes the final result of the fold after all entries have been folded and accumulated.
func (f *folder) evalResult() ref.Val {
f.computeResult = true
Expand Down
Loading

0 comments on commit bd1ec92

Please sign in to comment.