From 0b54df40f9fa99d992b90b0a1ab34ca1cae628ef Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Tue, 23 May 2023 13:33:55 -0700 Subject: [PATCH] Update presence test state tracking and pruning (#711) The has() macro calls were not being adequately captured during state tracking which led to issues in cost estimation and ast pruning function. The updated behavior ensures that the has() macro calls have the same cost as a traditional select expression and that the state for the presence test is correctly recorded in a way that works with partial evaluation. --- checker/cost.go | 1 + checker/cost_test.go | 8 ++- interpreter/attributes.go | 35 ++++++----- interpreter/attributes_test.go | 74 ++++++++++++++++++++-- interpreter/interpretable.go | 95 +++++++++++++++++++--------- interpreter/planner.go | 17 ++--- interpreter/prune.go | 90 +++++++++++++++++---------- interpreter/prune_test.go | 107 +++++++++++++++++++++++++++++++- interpreter/runtimecost.go | 2 - interpreter/runtimecost_test.go | 5 +- 10 files changed, 332 insertions(+), 102 deletions(-) diff --git a/checker/cost.go b/checker/cost.go index 6cf8c4fe..f1102686 100644 --- a/checker/cost.go +++ b/checker/cost.go @@ -347,6 +347,7 @@ func (c *coster) costSelect(e *exprpb.Expr) CostEstimate { // this is equivalent to how evalTestOnly increments the runtime cost counter // but does not add any additional cost for the qualifier, except here we do // the reverse (ident adds cost) + sum = sum.Add(selectAndIdentCost) sum = sum.Add(c.cost(sel.GetOperand())) return sum } diff --git a/checker/cost_test.go b/checker/cost_test.go index 20feff26..30bb21e3 100644 --- a/checker/cost_test.go +++ b/checker/cost_test.go @@ -75,7 +75,13 @@ func TestCost(t *testing.T) { name: "select: field test only", expr: `has(input.single_int32)`, decls: []*exprpb.Decl{decls.NewVar("input", decls.NewObjectType("google.expr.proto3.test.TestAllTypes"))}, - wanted: CostEstimate{Min: 1, Max: 1}, + wanted: CostEstimate{Min: 2, Max: 2}, + }, + { + name: "select: non-proto field test", + expr: `has(input.testAttr.nestedAttr)`, + decls: []*exprpb.Decl{decls.NewVar("input", nestedMap)}, + wanted: CostEstimate{Min: 3, Max: 3}, }, { name: "estimated function call", diff --git a/interpreter/attributes.go b/interpreter/attributes.go index d2205d90..1b19dc2b 100644 --- a/interpreter/attributes.go +++ b/interpreter/attributes.go @@ -294,7 +294,11 @@ func (a *absoluteAttribute) Resolve(vars Activation) (any, error) { return nil, err } if isOpt { - return types.OptionalOf(a.adapter.NativeToValue(obj)), nil + val := a.adapter.NativeToValue(obj) + if types.IsUnknown(val) { + return val, nil + } + return types.OptionalOf(val), nil } return obj, nil } @@ -558,7 +562,11 @@ func (a *relativeAttribute) Resolve(vars Activation) (any, error) { return nil, err } if isOpt { - return types.OptionalOf(a.adapter.NativeToValue(obj)), nil + val := a.adapter.NativeToValue(obj) + if types.IsUnknown(val) { + return val, nil + } + return types.OptionalOf(val), nil } return obj, nil } @@ -1171,6 +1179,9 @@ func applyQualifiers(vars Activation, obj any, qualifiers []Qualifier) (any, boo return nil, false, err } if !present { + // We return optional none here with a presence of 'false' as the layers + // above will attempt to call types.OptionalOf() on a present value if any + // of the qualifiers is optional. return types.OptionalNone, false, nil } } else { @@ -1223,6 +1234,8 @@ func refQualify(adapter ref.TypeAdapter, obj any, idx ref.Val, presenceTest, pre return nil, false, v case traits.Mapper: val, found := v.Find(idx) + // If the index is of the wrong type for the map, then it is possible + // for the Find call to produce an error. if types.IsError(val) { return nil, false, val.(*types.Err) } @@ -1234,6 +1247,8 @@ func refQualify(adapter ref.TypeAdapter, obj any, idx ref.Val, presenceTest, pre } return nil, false, missingKey(idx) case traits.Lister: + // If the index argument is not a valid numeric type, then it is possible + // for the index operation to produce an error. i, err := types.IndexOrError(idx) if err != nil { return nil, false, err @@ -1254,6 +1269,8 @@ func refQualify(adapter ref.TypeAdapter, obj any, idx ref.Val, presenceTest, pre if types.IsError(presence) { return nil, false, presence.(*types.Err) } + // If not found or presence only test, then return. + // Otherwise, if found, obtain the value later on. if presenceOnly || presence == types.False { return nil, presence == types.True, nil } @@ -1320,17 +1337,3 @@ func (e *resolutionError) Error() string { func (e *resolutionError) Is(err error) bool { return err.Error() == e.Error() } - -func findMin(x, y int64) int64 { - if x < y { - return x - } - return y -} - -func findMax(x, y int64) int64 { - if x > y { - return x - } - return y -} diff --git a/interpreter/attributes_test.go b/interpreter/attributes_test.go index 2fa8f64f..a5c99350 100644 --- a/interpreter/attributes_test.go +++ b/interpreter/attributes_test.go @@ -24,6 +24,7 @@ import ( "github.com/google/cel-go/checker/decls" "github.com/google/cel-go/common" "github.com/google/cel-go/common/containers" + "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/parser" @@ -68,7 +69,7 @@ func TestAttributesAbsoluteAttr(t *testing.T) { } } -func TestAttributesAbsoluteAttr_Type(t *testing.T) { +func TestAttributesAbsoluteAttrType(t *testing.T) { reg := newTestRegistry(t) attrs := NewAttributeFactory(containers.DefaultContainer, reg, reg) @@ -861,6 +862,7 @@ func TestAttributeStateTracking(t *testing.T) { env []*exprpb.Decl in map[string]any out ref.Val + attrs []*AttributePattern state map[int64]any }{ { @@ -1051,12 +1053,54 @@ func TestAttributeStateTracking(t *testing.T) { 3: types.String("world"), }, }, + { + expr: `m[has(a.b)]`, + env: []*exprpb.Decl{ + decls.NewVar("a", decls.NewMapType(decls.String, decls.String)), + decls.NewVar("m", decls.NewMapType(decls.Bool, decls.String)), + }, + in: map[string]any{ + "a": map[string]string{"b": ""}, + "m": map[bool]string{true: "world"}, + }, + out: types.String("world"), + }, + { + expr: `m[?has(a.b)]`, + env: []*exprpb.Decl{ + decls.NewVar("a", decls.NewMapType(decls.String, decls.String)), + decls.NewVar("m", decls.NewMapType(decls.Bool, decls.String)), + }, + in: map[string]any{ + "a": map[string]string{"b": ""}, + "m": map[bool]string{true: "world"}, + }, + out: types.OptionalOf(types.String("world")), + }, + { + expr: `m[?has(a.b.c)]`, + env: []*exprpb.Decl{ + decls.NewVar("a", decls.NewMapType(decls.String, decls.Dyn)), + decls.NewVar("m", decls.NewMapType(decls.Bool, decls.String)), + }, + in: map[string]any{ + "a": map[string]any{}, + "m": map[bool]string{true: "world"}, + }, + out: types.Unknown{5}, + attrs: []*AttributePattern{ + NewAttributePattern("a").QualString("b"), + }, + }, } for _, test := range tests { tc := test t.Run(tc.expr, func(t *testing.T) { src := common.NewTextSource(tc.expr) - p, err := parser.NewParser(parser.EnableOptionalSyntax(true)) + p, err := parser.NewParser( + parser.EnableOptionalSyntax(true), + parser.Macros(parser.AllMacros...), + ) if err != nil { t.Fatalf("parser.NewParser() failed: %v", err) } @@ -1071,6 +1115,7 @@ func TestAttributeStateTracking(t *testing.T) { t.Fatalf("checker.NewEnv() failed: %v", err) } env.Add(checker.StandardDeclarations()...) + env.Add(optionalSignatures()...) if tc.env != nil { env.Add(tc.env...) } @@ -1079,6 +1124,9 @@ func TestAttributeStateTracking(t *testing.T) { t.Fatalf(errors.ToDisplayString()) } attrs := NewAttributeFactory(cont, reg, reg) + if tc.attrs != nil { + attrs = NewPartialAttributeFactory(cont, reg, reg) + } interp := NewStandardInterpreter(cont, reg, reg, attrs) // Show that program planning will now produce an error. st := NewEvalState() @@ -1089,10 +1137,14 @@ func TestAttributeStateTracking(t *testing.T) { if err != nil { t.Fatal(err) } - in, _ := NewActivation(tc.in) + in, _ := NewPartialActivation(tc.in, tc.attrs...) out := i.Eval(in) - if tc.out.Equal(out) != types.True { - t.Errorf("got %v, wanted %v", out.Value(), tc.out) + if types.IsUnknown(tc.out) && types.IsUnknown(out) { + if !reflect.DeepEqual(tc.out, out) { + t.Errorf("got %v, wanted %v", out, tc.out) + } + } else if tc.out.Equal(out) != types.True { + t.Errorf("got %v, wanted %v", out, tc.out) } for id, val := range tc.state { stVal, found := st.Value(id) @@ -1209,3 +1261,15 @@ func findField(t testing.TB, reg ref.TypeRegistry, typeName, field string) *ref. } return ft } + +func optionalSignatures() []*exprpb.Decl { + return []*exprpb.Decl{ + decls.NewFunction(operators.OptIndex, + decls.NewParameterizedOverload("map_optindex_optional_value", []*exprpb.Type{ + decls.NewMapType(decls.NewTypeParamType("K"), decls.NewTypeParamType("V")), + decls.NewTypeParamType("K")}, + decls.NewOptionalType(decls.NewTypeParamType("V")), + []string{"K", "V"}, + )), + } +} diff --git a/interpreter/interpretable.go b/interpreter/interpretable.go index ac9a63f3..7d67a92b 100644 --- a/interpreter/interpretable.go +++ b/interpreter/interpretable.go @@ -15,6 +15,8 @@ package interpreter import ( + "fmt" + "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/overloads" "github.com/google/cel-go/common/types" @@ -69,6 +71,7 @@ type InterpretableAttribute interface { // to whether the qualifier is present. QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) + // IsOptional indicates whether the resulting value is an optional type. IsOptional() bool // Resolve returns the value of the Attribute given the current Activation. @@ -108,10 +111,8 @@ type InterpretableConstructor interface { // Core Interpretable implementations used during the program planning phase. type evalTestOnly struct { - id int64 - attr InterpretableAttribute - qual Qualifier - field types.String + id int64 + InterpretableAttribute } // ID implements the Interpretable interface method. @@ -121,28 +122,58 @@ func (test *evalTestOnly) ID() int64 { // Eval implements the Interpretable interface method. func (test *evalTestOnly) Eval(ctx Activation) ref.Val { - val, err := test.attr.Resolve(ctx) + val, err := test.Resolve(ctx) + // Return an error if the resolve step fails if err != nil { - return types.NewErr(err.Error()) + return types.WrapErr(err) } - optVal, isOpt := val.(*types.Optional) - if isOpt { - if !optVal.HasValue() { - return types.False - } - val = optVal.GetValue() + if optVal, isOpt := val.(*types.Optional); isOpt { + return types.Bool(optVal.HasValue()) + } + return test.Adapter().NativeToValue(val) +} + +// AddQualifier appends a qualifier that will always and only perform a presence test. +func (test *evalTestOnly) AddQualifier(q Qualifier) (Attribute, error) { + cq, ok := q.(ConstantQualifier) + if !ok { + return nil, fmt.Errorf("test only expressions must have constant qualifiers: %v", q) } - out, found, err := test.qual.QualifyIfPresent(ctx, val, true) + return test.InterpretableAttribute.AddQualifier(&testOnlyQualifier{ConstantQualifier: cq}) +} + +type testOnlyQualifier struct { + ConstantQualifier +} + +// Qualify determines whether the test-only qualifier is present on the input object. +func (q *testOnlyQualifier) Qualify(vars Activation, obj any) (any, error) { + out, present, err := q.ConstantQualifier.QualifyIfPresent(vars, obj, true) if err != nil { - return types.NewErr(err.Error()) + return nil, err } if unk, isUnk := out.(types.Unknown); isUnk { - return unk + return unk, nil } - if found { - return types.True + if opt, isOpt := out.(types.Optional); isOpt { + return opt.HasValue(), nil } - return types.False + return present, nil +} + +// QualifyIfPresent returns whether the target field in the test-only expression is present. +// +// This method should never be called as the has() macro and optional syntax are incompatible +// when used on the same field. +func (q *testOnlyQualifier) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { + // Only ever test for presence. + return q.ConstantQualifier.QualifyIfPresent(vars, obj, true) +} + +// QualifierValueEquals determines whether the test-only constant qualifier equals the input value. +func (q *testOnlyQualifier) QualifierValueEquals(value any) bool { + // The input qualifier will always be of type string + return q.ConstantQualifier.Value().Value() == value } // NewConstValue creates a new constant valued Interpretable. @@ -875,7 +906,7 @@ func (e *evalWatchConstQual) Qualify(vars Activation, obj any) (any, error) { out, err := e.ConstantQualifier.Qualify(vars, obj) var val ref.Val if err != nil { - val = types.NewErr(err.Error()) + val = types.WrapErr(err) } else { val = e.adapter.NativeToValue(out) } @@ -888,11 +919,13 @@ func (e *evalWatchConstQual) QualifyIfPresent(vars Activation, obj any, presence out, present, err := e.ConstantQualifier.QualifyIfPresent(vars, obj, presenceOnly) var val ref.Val if err != nil { - val = types.NewErr(err.Error()) - } else if present { + val = types.WrapErr(err) + } else if out != nil { val = e.adapter.NativeToValue(out) + } else if out == nil && presenceOnly { + val = types.Bool(present) } - if present { + if present || presenceOnly { e.observer(e.ID(), e.ConstantQualifier, val) } return out, present, err @@ -916,7 +949,7 @@ func (e *evalWatchQual) Qualify(vars Activation, obj any) (any, error) { out, err := e.Qualifier.Qualify(vars, obj) var val ref.Val if err != nil { - val = types.NewErr(err.Error()) + val = types.WrapErr(err) } else { val = e.adapter.NativeToValue(out) } @@ -929,11 +962,13 @@ func (e *evalWatchQual) QualifyIfPresent(vars Activation, obj any, presenceOnly out, present, err := e.Qualifier.QualifyIfPresent(vars, obj, presenceOnly) var val ref.Val if err != nil { - val = types.NewErr(err.Error()) - } else if present { + val = types.WrapErr(err) + } else if out != nil { val = e.adapter.NativeToValue(out) + } else if out == nil && presenceOnly { + val = types.Bool(present) } - if present { + if present || presenceOnly { e.observer(e.ID(), e.Qualifier, val) } return out, present, err @@ -1058,12 +1093,12 @@ func (cond *evalExhaustiveConditional) Eval(ctx Activation) ref.Val { } if cBool { if tErr != nil { - return types.NewErr(tErr.Error()) + return types.WrapErr(tErr) } return cond.adapter.NativeToValue(tVal) } if fErr != nil { - return types.NewErr(fErr.Error()) + return types.WrapErr(fErr) } return cond.adapter.NativeToValue(fVal) } @@ -1075,6 +1110,8 @@ type evalAttr struct { optional bool } +var _ InterpretableAttribute = &evalAttr{} + // ID of the attribute instruction. func (a *evalAttr) ID() int64 { return a.attr.ID() @@ -1101,7 +1138,7 @@ func (a *evalAttr) Adapter() ref.TypeAdapter { func (a *evalAttr) Eval(ctx Activation) ref.Val { v, err := a.attr.Resolve(ctx) if err != nil { - return types.NewErr(err.Error()) + return types.WrapErr(err) } return a.adapter.NativeToValue(v) } diff --git a/interpreter/planner.go b/interpreter/planner.go index 9cf8e4e5..0b65d0fa 100644 --- a/interpreter/planner.go +++ b/interpreter/planner.go @@ -20,7 +20,6 @@ import ( "github.com/google/cel-go/common/containers" "github.com/google/cel-go/common/operators" - "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/interpreter/functions" @@ -217,18 +216,14 @@ func (p *planner) planSelect(expr *exprpb.Expr) (Interpretable, error) { if err != nil { return nil, err } - - // Return the test only eval expression. + // Modify the attribute to be test-only. if sel.GetTestOnly() { - return &evalTestOnly{ - id: expr.GetId(), - field: types.String(sel.GetField()), - attr: attr, - qual: qual, - }, nil + attr = &evalTestOnly{ + id: expr.GetId(), + InterpretableAttribute: attr, + } } - - // Otherwise, append the qualifier on the attribute. + // Append the qualifier on the attribute. _, err = attr.AddQualifier(qual) return attr, err } diff --git a/interpreter/prune.go b/interpreter/prune.go index b7f3a4d2..d1b5d6bd 100644 --- a/interpreter/prune.go +++ b/interpreter/prune.go @@ -227,12 +227,13 @@ func (p *astPruner) maybePruneOptional(elem *exprpb.Expr) (*exprpb.Expr, bool) { } func (p *astPruner) maybePruneIn(node *exprpb.Expr) (*exprpb.Expr, bool) { + // elem in list call := node.GetCallExpr() - v, exists := p.value(call.GetArgs()[1].GetId()) - if !exists || types.IsUnknownOrError(v) { + val, exists := p.maybeValue(call.GetArgs()[1].GetId()) + if !exists { return nil, false } - if sz, ok := v.(traits.Sizer); ok && sz.Size() == types.IntZero { + if sz, ok := val.(traits.Sizer); ok && sz.Size() == types.IntZero { return p.maybeCreateLiteral(node.GetId(), types.False) } return nil, false @@ -241,24 +242,49 @@ func (p *astPruner) maybePruneIn(node *exprpb.Expr) (*exprpb.Expr, bool) { func (p *astPruner) maybePruneLogicalNot(node *exprpb.Expr) (*exprpb.Expr, bool) { call := node.GetCallExpr() arg := call.GetArgs()[0] - v, exists := p.value(arg.GetId()) - if !exists || types.IsUnknownOrError(v) { + val, exists := p.maybeValue(arg.GetId()) + if !exists { return nil, false } - if b, ok := v.(types.Bool); ok { + if b, ok := val.(types.Bool); ok { return p.maybeCreateLiteral(node.GetId(), !b) } return nil, false } -func (p *astPruner) maybePruneAndOr(node *exprpb.Expr) (*exprpb.Expr, bool) { +func (p *astPruner) maybePruneOr(node *exprpb.Expr) (*exprpb.Expr, bool) { call := node.GetCallExpr() // We know result is unknown, so we have at least one unknown arg // and if one side is a known value, we know we can ignore it. - if p.existsWithKnownValue(call.GetArgs()[0].GetId()) { + if v, exists := p.maybeValue(call.GetArgs()[0].GetId()); exists { + if v == types.True { + return p.maybeCreateLiteral(node.GetId(), types.True) + } + return call.GetArgs()[1], true + } + if v, exists := p.maybeValue(call.GetArgs()[1].GetId()); exists { + if v == types.True { + return p.maybeCreateLiteral(node.GetId(), types.True) + } + return call.GetArgs()[0], true + } + return nil, false +} + +func (p *astPruner) maybePruneAnd(node *exprpb.Expr) (*exprpb.Expr, bool) { + call := node.GetCallExpr() + // We know result is unknown, so we have at least one unknown arg + // and if one side is a known value, we know we can ignore it. + if v, exists := p.maybeValue(call.GetArgs()[0].GetId()); exists { + if v == types.False { + return p.maybeCreateLiteral(node.GetId(), types.False) + } return call.GetArgs()[1], true } - if p.existsWithKnownValue(call.GetArgs()[1].GetId()) { + if v, exists := p.maybeValue(call.GetArgs()[1].GetId()); exists { + if v == types.False { + return p.maybeCreateLiteral(node.GetId(), types.False) + } return call.GetArgs()[0], true } return nil, false @@ -266,8 +292,8 @@ func (p *astPruner) maybePruneAndOr(node *exprpb.Expr) (*exprpb.Expr, bool) { func (p *astPruner) maybePruneConditional(node *exprpb.Expr) (*exprpb.Expr, bool) { call := node.GetCallExpr() - cond, exists := p.value(call.GetArgs()[0].GetId()) - if !exists || types.IsUnknownOrError(cond) { + cond, exists := p.maybeValue(call.GetArgs()[0].GetId()) + if !exists { return nil, false } if cond.Value().(bool) { @@ -277,9 +303,15 @@ func (p *astPruner) maybePruneConditional(node *exprpb.Expr) (*exprpb.Expr, bool } func (p *astPruner) maybePruneFunction(node *exprpb.Expr) (*exprpb.Expr, bool) { + if _, exists := p.value(node.GetId()); !exists { + return nil, false + } call := node.GetCallExpr() - if call.Function == operators.LogicalOr || call.Function == operators.LogicalAnd { - return p.maybePruneAndOr(node) + if call.Function == operators.LogicalOr { + return p.maybePruneOr(node) + } + if call.Function == operators.LogicalAnd { + return p.maybePruneAnd(node) } if call.Function == operators.Conditional { return p.maybePruneConditional(node) @@ -301,12 +333,10 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) { if node == nil { return node, false } - v, exists := p.value(node.GetId()) - if exists && !types.IsUnknownOrError(v) { - if newNode, ok := p.maybeCreateLiteral(node.GetId(), v); ok { - // if the macro completely evaluated, then delete the reference to it, if one exists. + val, valueExists := p.maybeValue(node.GetId()) + if valueExists { + if newNode, ok := p.maybeCreateLiteral(node.GetId(), val); ok { delete(p.macroCalls, node.GetId()) - // return the literal value. return newNode, true } } @@ -320,7 +350,6 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) { // We have either an unknown/error value, or something we don't want to // transform, or expression was not evaluated. If possible, drill down // more. - switch node.GetExprKind().(type) { case *exprpb.Expr_SelectExpr: if operand, pruned := p.maybePrune(node.GetSelectExpr().GetOperand()); pruned { @@ -386,11 +415,11 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) { if isOpt { newElem, pruned := p.maybePruneOptional(elem) if pruned { + prunedList = true if newElem != nil { newElems = append(newElems, newElem) prunedIdx++ } - prunedList = true continue } newOptIndexMap[int32(prunedIdx)] = true @@ -468,21 +497,18 @@ func (p *astPruner) value(id int64) (ref.Val, bool) { return val, (found && val != nil) } -func (p *astPruner) existsWithKnownValue(id int64) bool { - val, valueExists := p.value(id) - return valueExists && !types.IsUnknownOrError(val) +func (p *astPruner) maybeValue(id int64) (ref.Val, bool) { + val, found := p.value(id) + if !found || types.IsUnknownOrError(val) { + return nil, false + } + return val, true } func (p *astPruner) nextID() int64 { - for { - _, found := p.state.Value(p.nextExprID) - if !found { - next := p.nextExprID - p.nextExprID++ - return next - } - p.nextExprID++ - } + next := p.nextExprID + p.nextExprID++ + return next } type astVisitor struct { diff --git a/interpreter/prune_test.go b/interpreter/prune_test.go index fe60df78..f8f5f176 100644 --- a/interpreter/prune_test.go +++ b/interpreter/prune_test.go @@ -24,6 +24,7 @@ import ( "github.com/google/cel-go/interpreter/functions" "github.com/google/cel-go/parser" "github.com/google/cel-go/test" + "github.com/google/cel-go/test/proto3pb" ) type testInfo struct { @@ -68,6 +69,90 @@ var testCases = []testInfo{ expr: `this in []`, out: `false`, }, + { + in: partialActivation(map[string]any{ + "this": map[string]string{"b": "exists"}, + }, NewAttributePattern("this")), + expr: `has(this.a) || !has(this.b)`, + out: `has(this.a) || !has(this.b)`, + }, + { + in: partialActivation(map[string]any{ + "this": map[string]string{"b": "exists"}, + }, NewAttributePattern("this").QualString("a")), + expr: `has(this.a) || !has(this.b)`, + out: `has(this.a)`, + }, + { + in: partialActivation(map[string]any{ + "this": map[string]string{"b": "exists"}, + }, NewAttributePattern("this").QualString("a")), + expr: `!has(this.b) || has(this.a)`, + out: `has(this.a)`, + }, + { + in: partialActivation(map[string]any{ + "this": map[string]string{}, + }, NewAttributePattern("this")), + expr: `(!(this.a in []) || has(this.a)) || !has(this.b)`, + out: `true`, + }, + { + in: partialActivation(map[string]any{ + "this": map[string]string{}, + }, NewAttributePattern("this")), + expr: `has(this.a) || !has(this.b)`, + out: `has(this.a) || !has(this.b)`, + }, + { + in: partialActivation(map[string]any{ + "this": map[string]string{}, + }, NewAttributePattern("this")), + expr: `(has(this.a) || !(this.a in [])) || !has(this.b)`, + out: `true`, + }, + { + in: partialActivation(map[string]any{ + "this": map[string]string{"a": "exists"}, + }, NewAttributePattern("this").QualString("b")), + expr: `has(this.a) && !has(this.b)`, + out: `!has(this.b)`, + }, + { + in: partialActivation(map[string]any{ + "this": map[string]string{}, + }, NewAttributePattern("this")), + expr: `(has(this.a) && this.a in []) || !has(this.b)`, + out: `!has(this.b)`, + }, + { + in: partialActivation(map[string]any{ + "this": map[string]string{}, + }, NewAttributePattern("this")), + expr: `(this.a in [] && has(this.a)) || !has(this.b)`, + out: `!has(this.b)`, + }, + { + in: partialActivation(map[string]any{ + "this": map[string]any{"a": map[string]string{}}, + }, NewAttributePattern("this").QualString("a")), + expr: `has(this.a.b)`, + out: `has(this.a.b)`, + }, + { + in: partialActivation(map[string]any{ + "this": map[string]any{"a": map[string]string{}}, + }, NewAttributePattern("this").QualString("a")), + expr: `has(this["a"].b)`, + out: `has(this["a"].b)`, + }, + { + in: partialActivation(map[string]any{ + "this": &proto3pb.TestAllTypes{SingleInt32: 0, SingleInt64: 1}, + }, NewAttributePattern("this").QualString("single_int64")), + expr: `has(this.single_int32) && !has(this.single_int64)`, + out: `false`, + }, { in: unknownActivation("this"), expr: `this in {}`, @@ -158,6 +243,11 @@ var testCases = []testInfo{ expr: `[?optional.of(10), ?a, 2, 3]`, out: `[10, ?a, 2, 3]`, }, + { + in: unknownActivation("a"), + expr: `[?optional.of(10), a, 2, 3]`, + out: `[10, a, 2, 3]`, + }, { in: partialActivation(map[string]any{"a": "hi"}, "b"), expr: `{?a: b.?c}`, @@ -314,7 +404,7 @@ func TestPrune(t *testing.T) { t.Fatalf(iss.ToDisplayString()) } state := NewEvalState() - reg := newTestRegistry(t) + reg := newTestRegistry(t, &proto3pb.TestAllTypes{}) attrs := NewPartialAttributeFactory(containers.DefaultContainer, reg, reg) dispatcher := NewDispatcher() dispatcher.Add(functions.StandardOverloads()...) @@ -331,6 +421,10 @@ func TestPrune(t *testing.T) { t.Error(err) } if !test.Compare(actual, tst.out) { + for _, id := range state.IDs() { + v, _ := state.Value(id) + t.Logf("state[%d] %v\n", id, v) + } t.Errorf("prune[%d], diff: %s", i, test.DiffMessage("structure", actual, tst.out)) } } @@ -345,10 +439,17 @@ func unknownActivation(vars ...string) PartialActivation { return a } -func partialActivation(in map[string]any, vars ...string) PartialActivation { +func partialActivation(in map[string]any, vars ...any) PartialActivation { pats := make([]*AttributePattern, len(vars)) for i, v := range vars { - pats[i] = NewAttributePattern(v) + if pat, ok := v.(*AttributePattern); ok { + pats[i] = pat + continue + } + if str, ok := v.(string); ok { + pats[i] = NewAttributePattern(str) + continue + } } a, _ := NewPartialActivation(in, pats...) return a diff --git a/interpreter/runtimecost.go b/interpreter/runtimecost.go index e7daf011..a47ed59b 100644 --- a/interpreter/runtimecost.go +++ b/interpreter/runtimecost.go @@ -69,8 +69,6 @@ func CostObserver(tracker *CostTracker) EvalObserver { tracker.stack.drop(t.rhs.ID(), t.lhs.ID()) case *evalFold: tracker.stack.drop(t.iterRange.ID()) - case *evalTestOnly: - tracker.cost += common.SelectAndIdentCost case Qualifier: tracker.cost++ case InterpretableCall: diff --git a/interpreter/runtimecost_test.go b/interpreter/runtimecost_test.go index b8d0fc5e..3e099dbe 100644 --- a/interpreter/runtimecost_test.go +++ b/interpreter/runtimecost_test.go @@ -306,7 +306,7 @@ func TestRuntimeCost(t *testing.T) { name: "select: field test only", expr: `has(input.single_int32)`, decls: []*exprpb.Decl{decls.NewVar("input", decls.NewObjectType("google.expr.proto3.test.TestAllTypes"))}, - want: 1, + want: 2, in: map[string]any{ "input": &proto3pb.TestAllTypes{ RepeatedBool: []bool{false}, @@ -321,7 +321,7 @@ func TestRuntimeCost(t *testing.T) { name: "select: non-proto field test", expr: `has(input.testAttr.nestedAttr)`, decls: []*exprpb.Decl{decls.NewVar("input", nestedMap)}, - want: 2, + want: 3, in: map[string]any{ "input": map[string]any{ "testAttr": map[string]any{ @@ -725,7 +725,6 @@ func TestRuntimeCost(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { ctx := constructActivation(t, tc.in) - var costLimit *uint64 if tc.limit > 0 { costLimit = &tc.limit