diff --git a/checker/cost.go b/checker/cost.go index 6cf8c4fe..f1102686 100644 --- a/checker/cost.go +++ b/checker/cost.go @@ -347,6 +347,7 @@ func (c *coster) costSelect(e *exprpb.Expr) CostEstimate { // this is equivalent to how evalTestOnly increments the runtime cost counter // but does not add any additional cost for the qualifier, except here we do // the reverse (ident adds cost) + sum = sum.Add(selectAndIdentCost) sum = sum.Add(c.cost(sel.GetOperand())) return sum } diff --git a/checker/cost_test.go b/checker/cost_test.go index 20feff26..30bb21e3 100644 --- a/checker/cost_test.go +++ b/checker/cost_test.go @@ -75,7 +75,13 @@ func TestCost(t *testing.T) { name: "select: field test only", expr: `has(input.single_int32)`, decls: []*exprpb.Decl{decls.NewVar("input", decls.NewObjectType("google.expr.proto3.test.TestAllTypes"))}, - wanted: CostEstimate{Min: 1, Max: 1}, + wanted: CostEstimate{Min: 2, Max: 2}, + }, + { + name: "select: non-proto field test", + expr: `has(input.testAttr.nestedAttr)`, + decls: []*exprpb.Decl{decls.NewVar("input", nestedMap)}, + wanted: CostEstimate{Min: 3, Max: 3}, }, { name: "estimated function call", diff --git a/interpreter/attributes.go b/interpreter/attributes.go index d2205d90..1b19dc2b 100644 --- a/interpreter/attributes.go +++ b/interpreter/attributes.go @@ -294,7 +294,11 @@ func (a *absoluteAttribute) Resolve(vars Activation) (any, error) { return nil, err } if isOpt { - return types.OptionalOf(a.adapter.NativeToValue(obj)), nil + val := a.adapter.NativeToValue(obj) + if types.IsUnknown(val) { + return val, nil + } + return types.OptionalOf(val), nil } return obj, nil } @@ -558,7 +562,11 @@ func (a *relativeAttribute) Resolve(vars Activation) (any, error) { return nil, err } if isOpt { - return types.OptionalOf(a.adapter.NativeToValue(obj)), nil + val := a.adapter.NativeToValue(obj) + if types.IsUnknown(val) { + return val, nil + } + return types.OptionalOf(val), nil } return obj, nil } @@ -1171,6 +1179,9 @@ func applyQualifiers(vars Activation, obj any, qualifiers []Qualifier) (any, boo return nil, false, err } if !present { + // We return optional none here with a presence of 'false' as the layers + // above will attempt to call types.OptionalOf() on a present value if any + // of the qualifiers is optional. return types.OptionalNone, false, nil } } else { @@ -1223,6 +1234,8 @@ func refQualify(adapter ref.TypeAdapter, obj any, idx ref.Val, presenceTest, pre return nil, false, v case traits.Mapper: val, found := v.Find(idx) + // If the index is of the wrong type for the map, then it is possible + // for the Find call to produce an error. if types.IsError(val) { return nil, false, val.(*types.Err) } @@ -1234,6 +1247,8 @@ func refQualify(adapter ref.TypeAdapter, obj any, idx ref.Val, presenceTest, pre } return nil, false, missingKey(idx) case traits.Lister: + // If the index argument is not a valid numeric type, then it is possible + // for the index operation to produce an error. i, err := types.IndexOrError(idx) if err != nil { return nil, false, err @@ -1254,6 +1269,8 @@ func refQualify(adapter ref.TypeAdapter, obj any, idx ref.Val, presenceTest, pre if types.IsError(presence) { return nil, false, presence.(*types.Err) } + // If not found or presence only test, then return. + // Otherwise, if found, obtain the value later on. if presenceOnly || presence == types.False { return nil, presence == types.True, nil } @@ -1320,17 +1337,3 @@ func (e *resolutionError) Error() string { func (e *resolutionError) Is(err error) bool { return err.Error() == e.Error() } - -func findMin(x, y int64) int64 { - if x < y { - return x - } - return y -} - -func findMax(x, y int64) int64 { - if x > y { - return x - } - return y -} diff --git a/interpreter/attributes_test.go b/interpreter/attributes_test.go index 2fa8f64f..a5c99350 100644 --- a/interpreter/attributes_test.go +++ b/interpreter/attributes_test.go @@ -24,6 +24,7 @@ import ( "github.com/google/cel-go/checker/decls" "github.com/google/cel-go/common" "github.com/google/cel-go/common/containers" + "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/parser" @@ -68,7 +69,7 @@ func TestAttributesAbsoluteAttr(t *testing.T) { } } -func TestAttributesAbsoluteAttr_Type(t *testing.T) { +func TestAttributesAbsoluteAttrType(t *testing.T) { reg := newTestRegistry(t) attrs := NewAttributeFactory(containers.DefaultContainer, reg, reg) @@ -861,6 +862,7 @@ func TestAttributeStateTracking(t *testing.T) { env []*exprpb.Decl in map[string]any out ref.Val + attrs []*AttributePattern state map[int64]any }{ { @@ -1051,12 +1053,54 @@ func TestAttributeStateTracking(t *testing.T) { 3: types.String("world"), }, }, + { + expr: `m[has(a.b)]`, + env: []*exprpb.Decl{ + decls.NewVar("a", decls.NewMapType(decls.String, decls.String)), + decls.NewVar("m", decls.NewMapType(decls.Bool, decls.String)), + }, + in: map[string]any{ + "a": map[string]string{"b": ""}, + "m": map[bool]string{true: "world"}, + }, + out: types.String("world"), + }, + { + expr: `m[?has(a.b)]`, + env: []*exprpb.Decl{ + decls.NewVar("a", decls.NewMapType(decls.String, decls.String)), + decls.NewVar("m", decls.NewMapType(decls.Bool, decls.String)), + }, + in: map[string]any{ + "a": map[string]string{"b": ""}, + "m": map[bool]string{true: "world"}, + }, + out: types.OptionalOf(types.String("world")), + }, + { + expr: `m[?has(a.b.c)]`, + env: []*exprpb.Decl{ + decls.NewVar("a", decls.NewMapType(decls.String, decls.Dyn)), + decls.NewVar("m", decls.NewMapType(decls.Bool, decls.String)), + }, + in: map[string]any{ + "a": map[string]any{}, + "m": map[bool]string{true: "world"}, + }, + out: types.Unknown{5}, + attrs: []*AttributePattern{ + NewAttributePattern("a").QualString("b"), + }, + }, } for _, test := range tests { tc := test t.Run(tc.expr, func(t *testing.T) { src := common.NewTextSource(tc.expr) - p, err := parser.NewParser(parser.EnableOptionalSyntax(true)) + p, err := parser.NewParser( + parser.EnableOptionalSyntax(true), + parser.Macros(parser.AllMacros...), + ) if err != nil { t.Fatalf("parser.NewParser() failed: %v", err) } @@ -1071,6 +1115,7 @@ func TestAttributeStateTracking(t *testing.T) { t.Fatalf("checker.NewEnv() failed: %v", err) } env.Add(checker.StandardDeclarations()...) + env.Add(optionalSignatures()...) if tc.env != nil { env.Add(tc.env...) } @@ -1079,6 +1124,9 @@ func TestAttributeStateTracking(t *testing.T) { t.Fatalf(errors.ToDisplayString()) } attrs := NewAttributeFactory(cont, reg, reg) + if tc.attrs != nil { + attrs = NewPartialAttributeFactory(cont, reg, reg) + } interp := NewStandardInterpreter(cont, reg, reg, attrs) // Show that program planning will now produce an error. st := NewEvalState() @@ -1089,10 +1137,14 @@ func TestAttributeStateTracking(t *testing.T) { if err != nil { t.Fatal(err) } - in, _ := NewActivation(tc.in) + in, _ := NewPartialActivation(tc.in, tc.attrs...) out := i.Eval(in) - if tc.out.Equal(out) != types.True { - t.Errorf("got %v, wanted %v", out.Value(), tc.out) + if types.IsUnknown(tc.out) && types.IsUnknown(out) { + if !reflect.DeepEqual(tc.out, out) { + t.Errorf("got %v, wanted %v", out, tc.out) + } + } else if tc.out.Equal(out) != types.True { + t.Errorf("got %v, wanted %v", out, tc.out) } for id, val := range tc.state { stVal, found := st.Value(id) @@ -1209,3 +1261,15 @@ func findField(t testing.TB, reg ref.TypeRegistry, typeName, field string) *ref. } return ft } + +func optionalSignatures() []*exprpb.Decl { + return []*exprpb.Decl{ + decls.NewFunction(operators.OptIndex, + decls.NewParameterizedOverload("map_optindex_optional_value", []*exprpb.Type{ + decls.NewMapType(decls.NewTypeParamType("K"), decls.NewTypeParamType("V")), + decls.NewTypeParamType("K")}, + decls.NewOptionalType(decls.NewTypeParamType("V")), + []string{"K", "V"}, + )), + } +} diff --git a/interpreter/interpretable.go b/interpreter/interpretable.go index ac9a63f3..7d67a92b 100644 --- a/interpreter/interpretable.go +++ b/interpreter/interpretable.go @@ -15,6 +15,8 @@ package interpreter import ( + "fmt" + "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/overloads" "github.com/google/cel-go/common/types" @@ -69,6 +71,7 @@ type InterpretableAttribute interface { // to whether the qualifier is present. QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) + // IsOptional indicates whether the resulting value is an optional type. IsOptional() bool // Resolve returns the value of the Attribute given the current Activation. @@ -108,10 +111,8 @@ type InterpretableConstructor interface { // Core Interpretable implementations used during the program planning phase. type evalTestOnly struct { - id int64 - attr InterpretableAttribute - qual Qualifier - field types.String + id int64 + InterpretableAttribute } // ID implements the Interpretable interface method. @@ -121,28 +122,58 @@ func (test *evalTestOnly) ID() int64 { // Eval implements the Interpretable interface method. func (test *evalTestOnly) Eval(ctx Activation) ref.Val { - val, err := test.attr.Resolve(ctx) + val, err := test.Resolve(ctx) + // Return an error if the resolve step fails if err != nil { - return types.NewErr(err.Error()) + return types.WrapErr(err) } - optVal, isOpt := val.(*types.Optional) - if isOpt { - if !optVal.HasValue() { - return types.False - } - val = optVal.GetValue() + if optVal, isOpt := val.(*types.Optional); isOpt { + return types.Bool(optVal.HasValue()) + } + return test.Adapter().NativeToValue(val) +} + +// AddQualifier appends a qualifier that will always and only perform a presence test. +func (test *evalTestOnly) AddQualifier(q Qualifier) (Attribute, error) { + cq, ok := q.(ConstantQualifier) + if !ok { + return nil, fmt.Errorf("test only expressions must have constant qualifiers: %v", q) } - out, found, err := test.qual.QualifyIfPresent(ctx, val, true) + return test.InterpretableAttribute.AddQualifier(&testOnlyQualifier{ConstantQualifier: cq}) +} + +type testOnlyQualifier struct { + ConstantQualifier +} + +// Qualify determines whether the test-only qualifier is present on the input object. +func (q *testOnlyQualifier) Qualify(vars Activation, obj any) (any, error) { + out, present, err := q.ConstantQualifier.QualifyIfPresent(vars, obj, true) if err != nil { - return types.NewErr(err.Error()) + return nil, err } if unk, isUnk := out.(types.Unknown); isUnk { - return unk + return unk, nil } - if found { - return types.True + if opt, isOpt := out.(types.Optional); isOpt { + return opt.HasValue(), nil } - return types.False + return present, nil +} + +// QualifyIfPresent returns whether the target field in the test-only expression is present. +// +// This method should never be called as the has() macro and optional syntax are incompatible +// when used on the same field. +func (q *testOnlyQualifier) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { + // Only ever test for presence. + return q.ConstantQualifier.QualifyIfPresent(vars, obj, true) +} + +// QualifierValueEquals determines whether the test-only constant qualifier equals the input value. +func (q *testOnlyQualifier) QualifierValueEquals(value any) bool { + // The input qualifier will always be of type string + return q.ConstantQualifier.Value().Value() == value } // NewConstValue creates a new constant valued Interpretable. @@ -875,7 +906,7 @@ func (e *evalWatchConstQual) Qualify(vars Activation, obj any) (any, error) { out, err := e.ConstantQualifier.Qualify(vars, obj) var val ref.Val if err != nil { - val = types.NewErr(err.Error()) + val = types.WrapErr(err) } else { val = e.adapter.NativeToValue(out) } @@ -888,11 +919,13 @@ func (e *evalWatchConstQual) QualifyIfPresent(vars Activation, obj any, presence out, present, err := e.ConstantQualifier.QualifyIfPresent(vars, obj, presenceOnly) var val ref.Val if err != nil { - val = types.NewErr(err.Error()) - } else if present { + val = types.WrapErr(err) + } else if out != nil { val = e.adapter.NativeToValue(out) + } else if out == nil && presenceOnly { + val = types.Bool(present) } - if present { + if present || presenceOnly { e.observer(e.ID(), e.ConstantQualifier, val) } return out, present, err @@ -916,7 +949,7 @@ func (e *evalWatchQual) Qualify(vars Activation, obj any) (any, error) { out, err := e.Qualifier.Qualify(vars, obj) var val ref.Val if err != nil { - val = types.NewErr(err.Error()) + val = types.WrapErr(err) } else { val = e.adapter.NativeToValue(out) } @@ -929,11 +962,13 @@ func (e *evalWatchQual) QualifyIfPresent(vars Activation, obj any, presenceOnly out, present, err := e.Qualifier.QualifyIfPresent(vars, obj, presenceOnly) var val ref.Val if err != nil { - val = types.NewErr(err.Error()) - } else if present { + val = types.WrapErr(err) + } else if out != nil { val = e.adapter.NativeToValue(out) + } else if out == nil && presenceOnly { + val = types.Bool(present) } - if present { + if present || presenceOnly { e.observer(e.ID(), e.Qualifier, val) } return out, present, err @@ -1058,12 +1093,12 @@ func (cond *evalExhaustiveConditional) Eval(ctx Activation) ref.Val { } if cBool { if tErr != nil { - return types.NewErr(tErr.Error()) + return types.WrapErr(tErr) } return cond.adapter.NativeToValue(tVal) } if fErr != nil { - return types.NewErr(fErr.Error()) + return types.WrapErr(fErr) } return cond.adapter.NativeToValue(fVal) } @@ -1075,6 +1110,8 @@ type evalAttr struct { optional bool } +var _ InterpretableAttribute = &evalAttr{} + // ID of the attribute instruction. func (a *evalAttr) ID() int64 { return a.attr.ID() @@ -1101,7 +1138,7 @@ func (a *evalAttr) Adapter() ref.TypeAdapter { func (a *evalAttr) Eval(ctx Activation) ref.Val { v, err := a.attr.Resolve(ctx) if err != nil { - return types.NewErr(err.Error()) + return types.WrapErr(err) } return a.adapter.NativeToValue(v) } diff --git a/interpreter/planner.go b/interpreter/planner.go index 9cf8e4e5..0b65d0fa 100644 --- a/interpreter/planner.go +++ b/interpreter/planner.go @@ -20,7 +20,6 @@ import ( "github.com/google/cel-go/common/containers" "github.com/google/cel-go/common/operators" - "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/interpreter/functions" @@ -217,18 +216,14 @@ func (p *planner) planSelect(expr *exprpb.Expr) (Interpretable, error) { if err != nil { return nil, err } - - // Return the test only eval expression. + // Modify the attribute to be test-only. if sel.GetTestOnly() { - return &evalTestOnly{ - id: expr.GetId(), - field: types.String(sel.GetField()), - attr: attr, - qual: qual, - }, nil + attr = &evalTestOnly{ + id: expr.GetId(), + InterpretableAttribute: attr, + } } - - // Otherwise, append the qualifier on the attribute. + // Append the qualifier on the attribute. _, err = attr.AddQualifier(qual) return attr, err } diff --git a/interpreter/prune.go b/interpreter/prune.go index b7f3a4d2..d1b5d6bd 100644 --- a/interpreter/prune.go +++ b/interpreter/prune.go @@ -227,12 +227,13 @@ func (p *astPruner) maybePruneOptional(elem *exprpb.Expr) (*exprpb.Expr, bool) { } func (p *astPruner) maybePruneIn(node *exprpb.Expr) (*exprpb.Expr, bool) { + // elem in list call := node.GetCallExpr() - v, exists := p.value(call.GetArgs()[1].GetId()) - if !exists || types.IsUnknownOrError(v) { + val, exists := p.maybeValue(call.GetArgs()[1].GetId()) + if !exists { return nil, false } - if sz, ok := v.(traits.Sizer); ok && sz.Size() == types.IntZero { + if sz, ok := val.(traits.Sizer); ok && sz.Size() == types.IntZero { return p.maybeCreateLiteral(node.GetId(), types.False) } return nil, false @@ -241,24 +242,49 @@ func (p *astPruner) maybePruneIn(node *exprpb.Expr) (*exprpb.Expr, bool) { func (p *astPruner) maybePruneLogicalNot(node *exprpb.Expr) (*exprpb.Expr, bool) { call := node.GetCallExpr() arg := call.GetArgs()[0] - v, exists := p.value(arg.GetId()) - if !exists || types.IsUnknownOrError(v) { + val, exists := p.maybeValue(arg.GetId()) + if !exists { return nil, false } - if b, ok := v.(types.Bool); ok { + if b, ok := val.(types.Bool); ok { return p.maybeCreateLiteral(node.GetId(), !b) } return nil, false } -func (p *astPruner) maybePruneAndOr(node *exprpb.Expr) (*exprpb.Expr, bool) { +func (p *astPruner) maybePruneOr(node *exprpb.Expr) (*exprpb.Expr, bool) { call := node.GetCallExpr() // We know result is unknown, so we have at least one unknown arg // and if one side is a known value, we know we can ignore it. - if p.existsWithKnownValue(call.GetArgs()[0].GetId()) { + if v, exists := p.maybeValue(call.GetArgs()[0].GetId()); exists { + if v == types.True { + return p.maybeCreateLiteral(node.GetId(), types.True) + } + return call.GetArgs()[1], true + } + if v, exists := p.maybeValue(call.GetArgs()[1].GetId()); exists { + if v == types.True { + return p.maybeCreateLiteral(node.GetId(), types.True) + } + return call.GetArgs()[0], true + } + return nil, false +} + +func (p *astPruner) maybePruneAnd(node *exprpb.Expr) (*exprpb.Expr, bool) { + call := node.GetCallExpr() + // We know result is unknown, so we have at least one unknown arg + // and if one side is a known value, we know we can ignore it. + if v, exists := p.maybeValue(call.GetArgs()[0].GetId()); exists { + if v == types.False { + return p.maybeCreateLiteral(node.GetId(), types.False) + } return call.GetArgs()[1], true } - if p.existsWithKnownValue(call.GetArgs()[1].GetId()) { + if v, exists := p.maybeValue(call.GetArgs()[1].GetId()); exists { + if v == types.False { + return p.maybeCreateLiteral(node.GetId(), types.False) + } return call.GetArgs()[0], true } return nil, false @@ -266,8 +292,8 @@ func (p *astPruner) maybePruneAndOr(node *exprpb.Expr) (*exprpb.Expr, bool) { func (p *astPruner) maybePruneConditional(node *exprpb.Expr) (*exprpb.Expr, bool) { call := node.GetCallExpr() - cond, exists := p.value(call.GetArgs()[0].GetId()) - if !exists || types.IsUnknownOrError(cond) { + cond, exists := p.maybeValue(call.GetArgs()[0].GetId()) + if !exists { return nil, false } if cond.Value().(bool) { @@ -277,9 +303,15 @@ func (p *astPruner) maybePruneConditional(node *exprpb.Expr) (*exprpb.Expr, bool } func (p *astPruner) maybePruneFunction(node *exprpb.Expr) (*exprpb.Expr, bool) { + if _, exists := p.value(node.GetId()); !exists { + return nil, false + } call := node.GetCallExpr() - if call.Function == operators.LogicalOr || call.Function == operators.LogicalAnd { - return p.maybePruneAndOr(node) + if call.Function == operators.LogicalOr { + return p.maybePruneOr(node) + } + if call.Function == operators.LogicalAnd { + return p.maybePruneAnd(node) } if call.Function == operators.Conditional { return p.maybePruneConditional(node) @@ -301,12 +333,10 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) { if node == nil { return node, false } - v, exists := p.value(node.GetId()) - if exists && !types.IsUnknownOrError(v) { - if newNode, ok := p.maybeCreateLiteral(node.GetId(), v); ok { - // if the macro completely evaluated, then delete the reference to it, if one exists. + val, valueExists := p.maybeValue(node.GetId()) + if valueExists { + if newNode, ok := p.maybeCreateLiteral(node.GetId(), val); ok { delete(p.macroCalls, node.GetId()) - // return the literal value. return newNode, true } } @@ -320,7 +350,6 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) { // We have either an unknown/error value, or something we don't want to // transform, or expression was not evaluated. If possible, drill down // more. - switch node.GetExprKind().(type) { case *exprpb.Expr_SelectExpr: if operand, pruned := p.maybePrune(node.GetSelectExpr().GetOperand()); pruned { @@ -386,11 +415,11 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) { if isOpt { newElem, pruned := p.maybePruneOptional(elem) if pruned { + prunedList = true if newElem != nil { newElems = append(newElems, newElem) prunedIdx++ } - prunedList = true continue } newOptIndexMap[int32(prunedIdx)] = true @@ -468,21 +497,18 @@ func (p *astPruner) value(id int64) (ref.Val, bool) { return val, (found && val != nil) } -func (p *astPruner) existsWithKnownValue(id int64) bool { - val, valueExists := p.value(id) - return valueExists && !types.IsUnknownOrError(val) +func (p *astPruner) maybeValue(id int64) (ref.Val, bool) { + val, found := p.value(id) + if !found || types.IsUnknownOrError(val) { + return nil, false + } + return val, true } func (p *astPruner) nextID() int64 { - for { - _, found := p.state.Value(p.nextExprID) - if !found { - next := p.nextExprID - p.nextExprID++ - return next - } - p.nextExprID++ - } + next := p.nextExprID + p.nextExprID++ + return next } type astVisitor struct { diff --git a/interpreter/prune_test.go b/interpreter/prune_test.go index fe60df78..f8f5f176 100644 --- a/interpreter/prune_test.go +++ b/interpreter/prune_test.go @@ -24,6 +24,7 @@ import ( "github.com/google/cel-go/interpreter/functions" "github.com/google/cel-go/parser" "github.com/google/cel-go/test" + "github.com/google/cel-go/test/proto3pb" ) type testInfo struct { @@ -68,6 +69,90 @@ var testCases = []testInfo{ expr: `this in []`, out: `false`, }, + { + in: partialActivation(map[string]any{ + "this": map[string]string{"b": "exists"}, + }, NewAttributePattern("this")), + expr: `has(this.a) || !has(this.b)`, + out: `has(this.a) || !has(this.b)`, + }, + { + in: partialActivation(map[string]any{ + "this": map[string]string{"b": "exists"}, + }, NewAttributePattern("this").QualString("a")), + expr: `has(this.a) || !has(this.b)`, + out: `has(this.a)`, + }, + { + in: partialActivation(map[string]any{ + "this": map[string]string{"b": "exists"}, + }, NewAttributePattern("this").QualString("a")), + expr: `!has(this.b) || has(this.a)`, + out: `has(this.a)`, + }, + { + in: partialActivation(map[string]any{ + "this": map[string]string{}, + }, NewAttributePattern("this")), + expr: `(!(this.a in []) || has(this.a)) || !has(this.b)`, + out: `true`, + }, + { + in: partialActivation(map[string]any{ + "this": map[string]string{}, + }, NewAttributePattern("this")), + expr: `has(this.a) || !has(this.b)`, + out: `has(this.a) || !has(this.b)`, + }, + { + in: partialActivation(map[string]any{ + "this": map[string]string{}, + }, NewAttributePattern("this")), + expr: `(has(this.a) || !(this.a in [])) || !has(this.b)`, + out: `true`, + }, + { + in: partialActivation(map[string]any{ + "this": map[string]string{"a": "exists"}, + }, NewAttributePattern("this").QualString("b")), + expr: `has(this.a) && !has(this.b)`, + out: `!has(this.b)`, + }, + { + in: partialActivation(map[string]any{ + "this": map[string]string{}, + }, NewAttributePattern("this")), + expr: `(has(this.a) && this.a in []) || !has(this.b)`, + out: `!has(this.b)`, + }, + { + in: partialActivation(map[string]any{ + "this": map[string]string{}, + }, NewAttributePattern("this")), + expr: `(this.a in [] && has(this.a)) || !has(this.b)`, + out: `!has(this.b)`, + }, + { + in: partialActivation(map[string]any{ + "this": map[string]any{"a": map[string]string{}}, + }, NewAttributePattern("this").QualString("a")), + expr: `has(this.a.b)`, + out: `has(this.a.b)`, + }, + { + in: partialActivation(map[string]any{ + "this": map[string]any{"a": map[string]string{}}, + }, NewAttributePattern("this").QualString("a")), + expr: `has(this["a"].b)`, + out: `has(this["a"].b)`, + }, + { + in: partialActivation(map[string]any{ + "this": &proto3pb.TestAllTypes{SingleInt32: 0, SingleInt64: 1}, + }, NewAttributePattern("this").QualString("single_int64")), + expr: `has(this.single_int32) && !has(this.single_int64)`, + out: `false`, + }, { in: unknownActivation("this"), expr: `this in {}`, @@ -158,6 +243,11 @@ var testCases = []testInfo{ expr: `[?optional.of(10), ?a, 2, 3]`, out: `[10, ?a, 2, 3]`, }, + { + in: unknownActivation("a"), + expr: `[?optional.of(10), a, 2, 3]`, + out: `[10, a, 2, 3]`, + }, { in: partialActivation(map[string]any{"a": "hi"}, "b"), expr: `{?a: b.?c}`, @@ -314,7 +404,7 @@ func TestPrune(t *testing.T) { t.Fatalf(iss.ToDisplayString()) } state := NewEvalState() - reg := newTestRegistry(t) + reg := newTestRegistry(t, &proto3pb.TestAllTypes{}) attrs := NewPartialAttributeFactory(containers.DefaultContainer, reg, reg) dispatcher := NewDispatcher() dispatcher.Add(functions.StandardOverloads()...) @@ -331,6 +421,10 @@ func TestPrune(t *testing.T) { t.Error(err) } if !test.Compare(actual, tst.out) { + for _, id := range state.IDs() { + v, _ := state.Value(id) + t.Logf("state[%d] %v\n", id, v) + } t.Errorf("prune[%d], diff: %s", i, test.DiffMessage("structure", actual, tst.out)) } } @@ -345,10 +439,17 @@ func unknownActivation(vars ...string) PartialActivation { return a } -func partialActivation(in map[string]any, vars ...string) PartialActivation { +func partialActivation(in map[string]any, vars ...any) PartialActivation { pats := make([]*AttributePattern, len(vars)) for i, v := range vars { - pats[i] = NewAttributePattern(v) + if pat, ok := v.(*AttributePattern); ok { + pats[i] = pat + continue + } + if str, ok := v.(string); ok { + pats[i] = NewAttributePattern(str) + continue + } } a, _ := NewPartialActivation(in, pats...) return a diff --git a/interpreter/runtimecost.go b/interpreter/runtimecost.go index e7daf011..a47ed59b 100644 --- a/interpreter/runtimecost.go +++ b/interpreter/runtimecost.go @@ -69,8 +69,6 @@ func CostObserver(tracker *CostTracker) EvalObserver { tracker.stack.drop(t.rhs.ID(), t.lhs.ID()) case *evalFold: tracker.stack.drop(t.iterRange.ID()) - case *evalTestOnly: - tracker.cost += common.SelectAndIdentCost case Qualifier: tracker.cost++ case InterpretableCall: diff --git a/interpreter/runtimecost_test.go b/interpreter/runtimecost_test.go index b8d0fc5e..3e099dbe 100644 --- a/interpreter/runtimecost_test.go +++ b/interpreter/runtimecost_test.go @@ -306,7 +306,7 @@ func TestRuntimeCost(t *testing.T) { name: "select: field test only", expr: `has(input.single_int32)`, decls: []*exprpb.Decl{decls.NewVar("input", decls.NewObjectType("google.expr.proto3.test.TestAllTypes"))}, - want: 1, + want: 2, in: map[string]any{ "input": &proto3pb.TestAllTypes{ RepeatedBool: []bool{false}, @@ -321,7 +321,7 @@ func TestRuntimeCost(t *testing.T) { name: "select: non-proto field test", expr: `has(input.testAttr.nestedAttr)`, decls: []*exprpb.Decl{decls.NewVar("input", nestedMap)}, - want: 2, + want: 3, in: map[string]any{ "input": map[string]any{ "testAttr": map[string]any{ @@ -725,7 +725,6 @@ func TestRuntimeCost(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { ctx := constructActivation(t, tc.in) - var costLimit *uint64 if tc.limit > 0 { costLimit = &tc.limit