Skip to content

Commit

Permalink
Update presence test state tracking and pruning (#711)
Browse files Browse the repository at this point in the history
The has() macro calls were not being adequately captured during
state tracking which led to issues in cost estimation and ast
pruning function. The updated behavior ensures that the has()
macro calls have the same cost as a traditional select expression
and that the state for the presence test is correctly recorded
in a way that works with partial evaluation.
  • Loading branch information
TristonianJones authored May 23, 2023
1 parent aebe3fa commit 0b54df4
Show file tree
Hide file tree
Showing 10 changed files with 332 additions and 102 deletions.
1 change: 1 addition & 0 deletions checker/cost.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ func (c *coster) costSelect(e *exprpb.Expr) CostEstimate {
// this is equivalent to how evalTestOnly increments the runtime cost counter
// but does not add any additional cost for the qualifier, except here we do
// the reverse (ident adds cost)
sum = sum.Add(selectAndIdentCost)
sum = sum.Add(c.cost(sel.GetOperand()))
return sum
}
Expand Down
8 changes: 7 additions & 1 deletion checker/cost_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,13 @@ func TestCost(t *testing.T) {
name: "select: field test only",
expr: `has(input.single_int32)`,
decls: []*exprpb.Decl{decls.NewVar("input", decls.NewObjectType("google.expr.proto3.test.TestAllTypes"))},
wanted: CostEstimate{Min: 1, Max: 1},
wanted: CostEstimate{Min: 2, Max: 2},
},
{
name: "select: non-proto field test",
expr: `has(input.testAttr.nestedAttr)`,
decls: []*exprpb.Decl{decls.NewVar("input", nestedMap)},
wanted: CostEstimate{Min: 3, Max: 3},
},
{
name: "estimated function call",
Expand Down
35 changes: 19 additions & 16 deletions interpreter/attributes.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,11 @@ func (a *absoluteAttribute) Resolve(vars Activation) (any, error) {
return nil, err
}
if isOpt {
return types.OptionalOf(a.adapter.NativeToValue(obj)), nil
val := a.adapter.NativeToValue(obj)
if types.IsUnknown(val) {
return val, nil
}
return types.OptionalOf(val), nil
}
return obj, nil
}
Expand Down Expand Up @@ -558,7 +562,11 @@ func (a *relativeAttribute) Resolve(vars Activation) (any, error) {
return nil, err
}
if isOpt {
return types.OptionalOf(a.adapter.NativeToValue(obj)), nil
val := a.adapter.NativeToValue(obj)
if types.IsUnknown(val) {
return val, nil
}
return types.OptionalOf(val), nil
}
return obj, nil
}
Expand Down Expand Up @@ -1171,6 +1179,9 @@ func applyQualifiers(vars Activation, obj any, qualifiers []Qualifier) (any, boo
return nil, false, err
}
if !present {
// We return optional none here with a presence of 'false' as the layers
// above will attempt to call types.OptionalOf() on a present value if any
// of the qualifiers is optional.
return types.OptionalNone, false, nil
}
} else {
Expand Down Expand Up @@ -1223,6 +1234,8 @@ func refQualify(adapter ref.TypeAdapter, obj any, idx ref.Val, presenceTest, pre
return nil, false, v
case traits.Mapper:
val, found := v.Find(idx)
// If the index is of the wrong type for the map, then it is possible
// for the Find call to produce an error.
if types.IsError(val) {
return nil, false, val.(*types.Err)
}
Expand All @@ -1234,6 +1247,8 @@ func refQualify(adapter ref.TypeAdapter, obj any, idx ref.Val, presenceTest, pre
}
return nil, false, missingKey(idx)
case traits.Lister:
// If the index argument is not a valid numeric type, then it is possible
// for the index operation to produce an error.
i, err := types.IndexOrError(idx)
if err != nil {
return nil, false, err
Expand All @@ -1254,6 +1269,8 @@ func refQualify(adapter ref.TypeAdapter, obj any, idx ref.Val, presenceTest, pre
if types.IsError(presence) {
return nil, false, presence.(*types.Err)
}
// If not found or presence only test, then return.
// Otherwise, if found, obtain the value later on.
if presenceOnly || presence == types.False {
return nil, presence == types.True, nil
}
Expand Down Expand Up @@ -1320,17 +1337,3 @@ func (e *resolutionError) Error() string {
func (e *resolutionError) Is(err error) bool {
return err.Error() == e.Error()
}

func findMin(x, y int64) int64 {
if x < y {
return x
}
return y
}

func findMax(x, y int64) int64 {
if x > y {
return x
}
return y
}
74 changes: 69 additions & 5 deletions interpreter/attributes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/parser"
Expand Down Expand Up @@ -68,7 +69,7 @@ func TestAttributesAbsoluteAttr(t *testing.T) {
}
}

func TestAttributesAbsoluteAttr_Type(t *testing.T) {
func TestAttributesAbsoluteAttrType(t *testing.T) {
reg := newTestRegistry(t)
attrs := NewAttributeFactory(containers.DefaultContainer, reg, reg)

Expand Down Expand Up @@ -861,6 +862,7 @@ func TestAttributeStateTracking(t *testing.T) {
env []*exprpb.Decl
in map[string]any
out ref.Val
attrs []*AttributePattern
state map[int64]any
}{
{
Expand Down Expand Up @@ -1051,12 +1053,54 @@ func TestAttributeStateTracking(t *testing.T) {
3: types.String("world"),
},
},
{
expr: `m[has(a.b)]`,
env: []*exprpb.Decl{
decls.NewVar("a", decls.NewMapType(decls.String, decls.String)),
decls.NewVar("m", decls.NewMapType(decls.Bool, decls.String)),
},
in: map[string]any{
"a": map[string]string{"b": ""},
"m": map[bool]string{true: "world"},
},
out: types.String("world"),
},
{
expr: `m[?has(a.b)]`,
env: []*exprpb.Decl{
decls.NewVar("a", decls.NewMapType(decls.String, decls.String)),
decls.NewVar("m", decls.NewMapType(decls.Bool, decls.String)),
},
in: map[string]any{
"a": map[string]string{"b": ""},
"m": map[bool]string{true: "world"},
},
out: types.OptionalOf(types.String("world")),
},
{
expr: `m[?has(a.b.c)]`,
env: []*exprpb.Decl{
decls.NewVar("a", decls.NewMapType(decls.String, decls.Dyn)),
decls.NewVar("m", decls.NewMapType(decls.Bool, decls.String)),
},
in: map[string]any{
"a": map[string]any{},
"m": map[bool]string{true: "world"},
},
out: types.Unknown{5},
attrs: []*AttributePattern{
NewAttributePattern("a").QualString("b"),
},
},
}
for _, test := range tests {
tc := test
t.Run(tc.expr, func(t *testing.T) {
src := common.NewTextSource(tc.expr)
p, err := parser.NewParser(parser.EnableOptionalSyntax(true))
p, err := parser.NewParser(
parser.EnableOptionalSyntax(true),
parser.Macros(parser.AllMacros...),
)
if err != nil {
t.Fatalf("parser.NewParser() failed: %v", err)
}
Expand All @@ -1071,6 +1115,7 @@ func TestAttributeStateTracking(t *testing.T) {
t.Fatalf("checker.NewEnv() failed: %v", err)
}
env.Add(checker.StandardDeclarations()...)
env.Add(optionalSignatures()...)
if tc.env != nil {
env.Add(tc.env...)
}
Expand All @@ -1079,6 +1124,9 @@ func TestAttributeStateTracking(t *testing.T) {
t.Fatalf(errors.ToDisplayString())
}
attrs := NewAttributeFactory(cont, reg, reg)
if tc.attrs != nil {
attrs = NewPartialAttributeFactory(cont, reg, reg)
}
interp := NewStandardInterpreter(cont, reg, reg, attrs)
// Show that program planning will now produce an error.
st := NewEvalState()
Expand All @@ -1089,10 +1137,14 @@ func TestAttributeStateTracking(t *testing.T) {
if err != nil {
t.Fatal(err)
}
in, _ := NewActivation(tc.in)
in, _ := NewPartialActivation(tc.in, tc.attrs...)
out := i.Eval(in)
if tc.out.Equal(out) != types.True {
t.Errorf("got %v, wanted %v", out.Value(), tc.out)
if types.IsUnknown(tc.out) && types.IsUnknown(out) {
if !reflect.DeepEqual(tc.out, out) {
t.Errorf("got %v, wanted %v", out, tc.out)
}
} else if tc.out.Equal(out) != types.True {
t.Errorf("got %v, wanted %v", out, tc.out)
}
for id, val := range tc.state {
stVal, found := st.Value(id)
Expand Down Expand Up @@ -1209,3 +1261,15 @@ func findField(t testing.TB, reg ref.TypeRegistry, typeName, field string) *ref.
}
return ft
}

func optionalSignatures() []*exprpb.Decl {
return []*exprpb.Decl{
decls.NewFunction(operators.OptIndex,
decls.NewParameterizedOverload("map_optindex_optional_value", []*exprpb.Type{
decls.NewMapType(decls.NewTypeParamType("K"), decls.NewTypeParamType("V")),
decls.NewTypeParamType("K")},
decls.NewOptionalType(decls.NewTypeParamType("V")),
[]string{"K", "V"},
)),
}
}
Loading

0 comments on commit 0b54df4

Please sign in to comment.