From a22969df5e0afe445b8bfddeeb2e02bab9d62c6f Mon Sep 17 00:00:00 2001 From: Joe Betz Date: Thu, 3 Mar 2022 13:47:59 -0500 Subject: [PATCH] Add runtime cost limit --- cel/BUILD.bazel | 1 - cel/cel_test.go | 68 +++-- cel/options.go | 17 +- cel/program.go | 13 +- checker/cost.go | 67 +++-- checker/cost_test.go | 6 +- common/BUILD.bazel | 1 + common/cost.go | 40 +++ interpreter/attributes_test.go | 2 +- interpreter/interpretable.go | 51 ---- interpreter/interpreter.go | 55 ++-- interpreter/interpreter_test.go | 14 +- interpreter/prune_test.go | 2 +- interpreter/runtimecost.go | 106 ++++---- {cel => interpreter}/runtimecost_test.go | 316 +++++++++++++++-------- 15 files changed, 463 insertions(+), 296 deletions(-) create mode 100644 common/cost.go rename {cel => interpreter}/runtimecost_test.go (61%) diff --git a/cel/BUILD.bazel b/cel/BUILD.bazel index 6c8cb477..b43beedc 100644 --- a/cel/BUILD.bazel +++ b/cel/BUILD.bazel @@ -46,7 +46,6 @@ go_test( srcs = [ "cel_test.go", "io_test.go", - "runtimecost_test.go", ], data = [ "//cel/testdata:gen_test_fds", diff --git a/cel/cel_test.go b/cel/cel_test.go index cb163a9c..31883f09 100644 --- a/cel/cel_test.go +++ b/cel/cel_test.go @@ -1426,6 +1426,7 @@ func TestCustomInterpreterDecorator(t *testing.T) { } } +// TODO: ideally testCostEstimator and testRuntimeCostEstimator would be shared in a test fixtures package type testCostEstimator struct { hints map[string]int64 } @@ -1445,6 +1446,34 @@ func (tc testCostEstimator) EstimateCallCost(overloadId string, target *checker. return nil } +type testRuntimeCostEstimator struct { +} + +var timeToYearCost uint64 = 7 + +func (e testRuntimeCostEstimator) CallCost(overloadId string, args []ref.Val) *uint64 { + argsSize := make([]uint64, len(args)) + for i, arg := range args { + reflectV := reflect.ValueOf(arg.Value()) + switch reflectV.Kind() { + // Note that the CEL bytes type is implemented with Go byte slices, therefore also supported by the following + // code. 
+ case reflect.String, reflect.Array, reflect.Slice, reflect.Map: + argsSize[i] = uint64(reflectV.Len()) + default: + argsSize[i] = 1 + } + } + + switch overloadId { + case overloads.TimestampToYear: + return &timeToYearCost + default: + return nil + } +} + +// TestEstimateCostAndRuntimeCost sanity checks that the cost systems are usable from the program API. func TestEstimateCostAndRuntimeCost(t *testing.T) { intList := decls.NewListType(decls.Int) zeroCost := checker.CostEstimate{} @@ -1460,6 +1489,7 @@ func TestEstimateCostAndRuntimeCost(t *testing.T) { name: "const", expr: `"Hello World!"`, want: zeroCost, + in: map[string]interface{}{}, }, { name: "identity", @@ -1479,21 +1509,6 @@ func TestEstimateCostAndRuntimeCost(t *testing.T) { want: checker.CostEstimate{Min: 2, Max: 6}, in: map[string]interface{}{"str1": "val1111111", "str2": "val2222222"}, }, - { - name: "ternary with var", - expr: `true > false ? [1, 2, 3].all(x, true) : false`, - want: checker.CostEstimate{Min: 1, Max: 21}, - }, - { - name: "short circuited, 0 cost lhs", - expr: `true || 4 > 3 || 3 > 2 || 2 > 1 || 1 > 0`, - want: checker.CostEstimate{Min: 0, Max: 4}, - }, - { - name: "short circuited, 1 cost lhs", - expr: `3 > 4 || 4 > 3 || 3 > 2 || 2 > 1 || 1 > 0`, - want: checker.CostEstimate{Min: 1, Max: 5}, - }, } for _, tc := range cases { @@ -1505,7 +1520,7 @@ func TestEstimateCostAndRuntimeCost(t *testing.T) { Declarations(tc.decls...), Types(&proto3pb.TestAllTypes{})) if err != nil { - t.Fatalf("environment creation error: %s\n", err) + t.Fatalf("NewEnv(opts ...EnvOption) failed to create an environment: %s\n", err) } ast, iss := e.Compile(tc.expr) if iss.Err() != nil { @@ -1513,34 +1528,33 @@ func TestEstimateCostAndRuntimeCost(t *testing.T) { } est, err := e.EstimateCost(ast, testCostEstimator{hints: tc.hints}) if err != nil { - t.Fatalf("estimate cost error: %s\n", err) + t.Fatalf("Env.EstimateCost(ast *Ast, estimator checker.CostEstimator) failed to estimate cost: %s\n", err) } if est.Min 
!= tc.want.Min || est.Max != tc.want.Max { - t.Fatalf("Got cost interval [%v, %v], wanted [%v, %v]", + t.Fatalf("Env.EstimateCost(ast *Ast, estimator checker.CostEstimator) failed to return the right cost interval. Got [%v, %v], wanted [%v, %v]", est.Min, est.Max, tc.want.Min, tc.want.Max) } - ctx := constructActivation(t, tc.in) - checked_ast, iss := e.Check(ast) + checkedAst, iss := e.Check(ast) if iss.Err() != nil { - t.Fatalf(`Failed to check expression with error: %v`, iss.Err()) + t.Fatalf(`Env.Check(ast *Ast) failed to check expression: %v`, iss.Err()) } // Evaluate expression. - program, err := e.Program(checked_ast, ActualCostTracking(testRuntimeCostEstimator{})) + program, err := e.Program(checkedAst, CostTracking(testRuntimeCostEstimator{})) if err != nil { - t.Fatalf(`Failed to construct Program with error: %v`, err) + t.Fatalf(`Env.Program(ast *Ast, opts ...ProgramOption) failed to construct program: %v`, err) } - _, details, err := program.Eval(ctx) + _, details, err := program.Eval(tc.in) if err != nil { - t.Fatalf(`Failed to evaluate expression with error: %v`, err) + t.Fatalf(`Program.Eval(vars interface{}) failed to evaluate expression: %v`, err) } actualCost := details.ActualCost() if actualCost == nil { - t.Fatalf(`Null pointer returned for the cost of expression "%s"`, tc.expr) + t.Errorf(`EvalDetails.ActualCost() got nil for "%s" cost, wanted %d`, tc.expr, actualCost) } if est.Min > *actualCost || est.Max < *actualCost { - t.Fatalf("runtime cost %d is out of the range of estimate cost [%d, %d]", *actualCost, + t.Errorf("EvalDetails.ActualCost() failed to return a runtime cost %d is the range of estimate cost [%d, %d]", *actualCost, est.Min, est.Max) } }) diff --git a/cel/options.go b/cel/options.go index 5750464f..9a5df568 100644 --- a/cel/options.go +++ b/cel/options.go @@ -400,9 +400,8 @@ func InterruptCheckFrequency(checkFrequency uint) ProgramOption { } } -// ActualCostTracking enables cost tracking and registers a ActualCostEstimator that 
can optionally provide a runtime cost estimate for any function calls. -// This enables runtime costs to be assigned to calls function library extensions. -func ActualCostTracking(costEstimator interpreter.ActualCostEstimator) ProgramOption { +// CostTracking enables cost tracking and registers an ActualCostEstimator that can optionally provide a runtime cost estimate for any function calls. +func CostTracking(costEstimator interpreter.ActualCostEstimator) ProgramOption { return func(p *prog) (*prog, error) { p.callCostEstimator = costEstimator p.evalOpts |= OptTrackCost @@ -410,6 +409,18 @@ func ActualCostTracking(costEstimator interpreter.ActualCostEstimator) ProgramOp } } +// CostLimit enables cost tracking and configures program evaluation to exit early with an +// "operation cancelled: actual cost limit exceeded" error if the runtime cost exceeds the costLimit. +// The CostLimit is a metric that corresponds to the number and estimated expense of operations +// performed while evaluating an expression. It is indicative of CPU usage, not memory usage. +func CostLimit(costLimit uint64) ProgramOption { + return func(p *prog) (*prog, error) { + p.costLimit = &costLimit + p.evalOpts |= OptTrackCost + return p, nil + } +} + func fieldToCELType(field protoreflect.FieldDescriptor) (*exprpb.Type, error) { if field.Kind() == protoreflect.MessageKind { msgName := (string)(field.Message().FullName()) diff --git a/cel/program.go b/cel/program.go index 259a7ebf..467f4d6b 100644 --- a/cel/program.go +++ b/cel/program.go @@ -127,6 +127,7 @@ type prog struct { // Interpretable configured from an Ast and aggregate decorator set based on program options. interpretable interpreter.Interpretable callCostEstimator interpreter.ActualCostEstimator + costLimit *uint64 } func (p *prog) clone() *prog { @@ -196,7 +197,8 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) { // Enable exhaustive eval, state tracking and cost tracking last since they require a factory.
if p.evalOpts&(OptExhaustiveEval|OptTrackState|OptTrackCost) != 0 { factory := func(state interpreter.EvalState, costTracker *interpreter.CostTracker) (Program, error) { - costTracker.CallCostEstimator = p.callCostEstimator + costTracker.Estimator = p.callCostEstimator + costTracker.Limit = p.costLimit decs := decorators var observers []interpreter.EvalObserver @@ -210,7 +212,7 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) { // Enable exhaustive eval over a basic observer since it offers a superset of features. if p.evalOpts&OptExhaustiveEval == OptExhaustiveEval { - decs = append(decs, interpreter.ExhaustiveEvalWrapper(interpreter.Observe(observers...))) + decs = append(decs, interpreter.ExhaustiveEval(), interpreter.Observe(observers...)) } else if len(observers) > 0 { decs = append(decs, interpreter.Observe(observers...)) } @@ -255,7 +257,12 @@ func (p *prog) Eval(input interface{}) (v ref.Val, det *EvalDetails, err error) // function. defer func() { if r := recover(); r != nil { - err = fmt.Errorf("internal error: %v", r) + switch t := r.(type) { + case interpreter.EvalCancelledError: + err = t + default: + err = fmt.Errorf("internal error: %v", r) + } } }() // Build a hierarchical activation if there are default vars set. diff --git a/checker/cost.go b/checker/cost.go index 6c512d2f..68eaf8be 100644 --- a/checker/cost.go +++ b/checker/cost.go @@ -17,12 +17,15 @@ package checker import ( "math" + "github.com/google/cel-go/common" "github.com/google/cel-go/common/overloads" "github.com/google/cel-go/parser" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) +// WARNING: Any changes to cost calculations in this file require a corresponding change in interpreter/runtimecost.go + // CostEstimator estimates the sizes of variable length input data and the costs of functions. 
type CostEstimator interface { // EstimateSize returns a SizeEstimate for the given AstNode, or nil if @@ -56,7 +59,8 @@ type AstNode interface { Expr() *exprpb.Expr // ComputedSize returns a size estimate of the AstNode derived from information available in the CEL expression. // For constants and inline list and map declarations, the exact size is returned. For concatenated list, strings - // and bytes, the size is derived from the size estimates of the operands. + // and bytes, the size is derived from the size estimates of the operands. nil is returned if there is no + // computed size available. ComputedSize() *SizeEstimate } @@ -91,6 +95,10 @@ func (e astNode) ComputedSize() *SizeEstimate { v = uint64(len(ck.StringValue)) case *exprpb.Constant_BytesValue: v = uint64(len(ck.BytesValue)) + case *exprpb.Constant_BoolValue, *exprpb.Constant_DoubleValue, *exprpb.Constant_DurationValue, + *exprpb.Constant_Int64Value, *exprpb.Constant_TimestampValue, *exprpb.Constant_Uint64Value, + *exprpb.Constant_NullValue: + v = uint64(1) default: return nil } @@ -233,13 +241,12 @@ func multiplyByCostFactor(x uint64, y float64) uint64 { } var ( - identCost = CostEstimate{Min: 1, Max: 1} - selectCost = CostEstimate{Min: 1, Max: 1} - constCost = CostEstimate{Min: 0, Max: 0} + selectAndIdentCost = CostEstimate{Min: common.SelectAndIdentCost, Max: common.SelectAndIdentCost} + constCost = CostEstimate{Min: common.ConstCost, Max: common.ConstCost} - createListBaseCost = CostEstimate{Min: 10, Max: 10} - createMapBaseCost = CostEstimate{Min: 30, Max: 30} - createMessageBaseCost = CostEstimate{Min: 40, Max: 40} + createListBaseCost = CostEstimate{Min: common.ListCreateBaseCost, Max: common.ListCreateBaseCost} + createMapBaseCost = CostEstimate{Min: common.MapCreateBaseCost, Max: common.MapCreateBaseCost} + createMessageBaseCost = CostEstimate{Min: common.StructCreateBaseCost, Max: common.StructCreateBaseCost} ) type coster struct { @@ -326,7 +333,7 @@ func (c *coster) costIdent(e 
*exprpb.Expr) CostEstimate { c.addPath(e, []string{identExpr.GetName()}) } - return identCost + return selectAndIdentCost } func (c *coster) costSelect(e *exprpb.Expr) CostEstimate { @@ -339,7 +346,7 @@ func (c *coster) costSelect(e *exprpb.Expr) CostEstimate { targetType := c.getType(sel.GetOperand()) switch kindOf(targetType) { case kindMap, kindObject, kindTypeParam: - sum = sum.Add(selectCost) + sum = sum.Add(selectAndIdentCost) } // build and track the field path @@ -456,8 +463,6 @@ func (c *coster) costComprehension(e *exprpb.Expr) CostEstimate { stepCost := c.cost(comp.GetLoopStep()) c.iterRanges.pop(comp.GetIterVar()) sum = sum.Add(c.cost(comp.Result)) - // TODO: comprehensions short circuit, so even if the min list size > 0, the minimum number of elements evaluated - // will be 1. rangeCnt := c.sizeEstimate(c.newAstNode(comp.GetIterRange())) rangeCost := rangeCnt.MultiplyByCost(stepCost.Add(loopCost)) sum = sum.Add(rangeCost) @@ -477,9 +482,6 @@ func (c *coster) sizeEstimate(t AstNode) SizeEstimate { func (c *coster) functionCost(overloadId string, target *AstNode, args []AstNode, argCosts []CostEstimate) CallEstimate { argCostSum := func() CostEstimate { - // TODO: handle ternary - // TODO: && || operators short circuit, so min cost should only include 1st arg eval - // unless exhaustive evaluation is enabled var sum CostEstimate for _, a := range argCosts { sum = sum.Add(a) @@ -495,7 +497,7 @@ func (c *coster) functionCost(overloadId string, target *AstNode, args []AstNode // O(n) functions case overloads.StartsWithString, overloads.EndsWithString, overloads.StringToBytes, overloads.BytesToString: if len(args) == 1 { - return CallEstimate{CostEstimate: c.sizeEstimate(args[0]).MultiplyByCostFactor(0.1).Add(argCostSum())} + return CallEstimate{CostEstimate: c.sizeEstimate(args[0]).MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum())} } case overloads.InList: // If a list is composed entirely of constant values this is O(1), but we don't 
account for that here. @@ -507,19 +509,21 @@ func (c *coster) functionCost(overloadId string, target *AstNode, args []AstNode case overloads.MatchesString: // https://swtch.com/~rsc/regexp/regexp1.html applies to RE2 implementation supported by CEL if target != nil && len(args) == 1 { - strCost := c.sizeEstimate(*target).MultiplyByCostFactor(0.1) + // Add one to string length for purposes of cost calculation to prevent product of string and regex to be 0 + // in case where string is empty but regex is still expensive. + strCost := c.sizeEstimate(*target).Add(SizeEstimate{Min: 1, Max: 1}).MultiplyByCostFactor(common.StringTraversalCostFactor) // We don't know how many expressions are in the regex, just the string length (a huge // improvement here would be to somehow get a count the number of expressions in the regex or // how many states are in the regex state machine and use that to measure regex cost). // For now, we're making a guess that each expression in a regex is typically at least 4 chars // in length. 
- regexCost := c.sizeEstimate(args[0]).MultiplyByCostFactor(0.25) + regexCost := c.sizeEstimate(args[0]).MultiplyByCostFactor(common.RegexStringLengthCostFactor) return CallEstimate{CostEstimate: strCost.Multiply(regexCost).Add(argCostSum())} } case overloads.ContainsString: if target != nil && len(args) == 1 { - strCost := c.sizeEstimate(*target).MultiplyByCostFactor(0.1) - substrCost := c.sizeEstimate(args[0]).MultiplyByCostFactor(0.1) + strCost := c.sizeEstimate(*target).MultiplyByCostFactor(common.StringTraversalCostFactor) + substrCost := c.sizeEstimate(args[0]).MultiplyByCostFactor(common.StringTraversalCostFactor) return CallEstimate{CostEstimate: strCost.Multiply(substrCost).Add(argCostSum())} } case overloads.LogicalOr, overloads.LogicalAnd: @@ -540,10 +544,33 @@ func (c *coster) functionCost(overloadId string, target *AstNode, args []AstNode lhsSize := c.sizeEstimate(args[0]) rhsSize := c.sizeEstimate(args[1]) resultSize := lhsSize.Add(rhsSize) - return CallEstimate{CostEstimate: resultSize.MultiplyByCostFactor(0.1).Add(argCostSum()), ResultSize: &resultSize} + switch overloadId { + case overloads.AddList: + // list concatenation is O(1), but we handle it here to track size + return CallEstimate{CostEstimate: CostEstimate{Min: 1, Max: 1}.Add(argCostSum()), ResultSize: &resultSize} + default: + return CallEstimate{CostEstimate: resultSize.MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum()), ResultSize: &resultSize} + } + } + case overloads.LessString, overloads.GreaterString, overloads.LessEqualsString, overloads.GreaterEqualsString, + overloads.LessBytes, overloads.GreaterBytes, overloads.LessEqualsBytes, overloads.GreaterEqualsBytes, + overloads.Equals, overloads.NotEquals: + lhsCost := c.sizeEstimate(args[0]) + rhsCost := c.sizeEstimate(args[1]) + min := uint64(0) + smallestMax := lhsCost.Max + if rhsCost.Max < smallestMax { + smallestMax = rhsCost.Max } + if smallestMax > 0 { + min = 1 + } + // equality of 2 scalar values results 
in a cost of 1 + return CallEstimate{CostEstimate: CostEstimate{Min: min, Max: smallestMax}.MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum())} } // O(1) functions + // See CostTracker.costCall for more details about O(1) cost calculations + // Benchmarks suggest that most of the other operations take +/- 50% of a base cost unit // which on an Intel xeon 2.20GHz CPU is 50ns. return CallEstimate{CostEstimate: CostEstimate{Min: 1, Max: 1}.Add(argCostSum())} diff --git a/checker/cost_test.go b/checker/cost_test.go index 91d0fc24..1b362a52 100644 --- a/checker/cost_test.go +++ b/checker/cost_test.go @@ -122,7 +122,7 @@ func TestCost(t *testing.T) { decls: []*exprpb.Decl{decls.NewVar("input", decls.String)}, hints: map[string]int64{"input": 500}, expr: `input.matches('[0-9]')`, - wanted: CostEstimate{Min: 1, Max: 101}, + wanted: CostEstimate{Min: 3, Max: 103}, }, { name: "variable cost function with constant", @@ -240,7 +240,7 @@ func TestCost(t *testing.T) { decls.NewVar("input", decls.String), }, hints: map[string]int64{"input": 500}, - wanted: CostEstimate{Min: 1, Max: 101}, + wanted: CostEstimate{Min: 3, Max: 103}, }, { name: "startsWith", @@ -342,7 +342,7 @@ func TestCost(t *testing.T) { decls.NewVar("list2", decls.NewListType(decls.Int)), }, hints: map[string]int64{"list1": 10, "list2": 10}, - wanted: CostEstimate{Min: 3, Max: 65}, + wanted: CostEstimate{Min: 4, Max: 64}, }, { name: "str concat", diff --git a/common/BUILD.bazel b/common/BUILD.bazel index 9e4ad65e..a0058aeb 100644 --- a/common/BUILD.bazel +++ b/common/BUILD.bazel @@ -8,6 +8,7 @@ package( go_library( name = "go_default_library", srcs = [ + "cost.go", "error.go", "errors.go", "location.go", diff --git a/common/cost.go b/common/cost.go new file mode 100644 index 00000000..5e24bd0f --- /dev/null +++ b/common/cost.go @@ -0,0 +1,40 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in 
compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package common + +const ( + // SelectAndIdentCost is the cost of an operation that accesses an identifier or performs a select. + SelectAndIdentCost = 1 + + // ConstCost is the cost of an operation that accesses a constant. + ConstCost = 0 + + // ListCreateBaseCost is the base cost of any operation that creates a new list. + ListCreateBaseCost = 10 + + // MapCreateBaseCost is the base cost of any operation that creates a new map. + MapCreateBaseCost = 30 + + // StructCreateBaseCost is the base cost of any operation that creates a new struct. + StructCreateBaseCost = 40 + + // StringTraversalCostFactor is multiplied to a length of a string when computing the cost of traversing the entire + // string once. + StringTraversalCostFactor = 0.1 + + // RegexStringLengthCostFactor is multiplied to the length of a regex string pattern when computing the cost of + // applying the regex to a string of unit cost. + RegexStringLengthCostFactor = 0.25 +) diff --git a/interpreter/attributes_test.go b/interpreter/attributes_test.go index 1dfefdfc..9ee5007f 100644 --- a/interpreter/attributes_test.go +++ b/interpreter/attributes_test.go @@ -690,7 +690,7 @@ func TestAttributeStateTracking(t *testing.T) { interp := NewStandardInterpreter(cont, reg, reg, attrs) // Show that program planning will now produce an error.
st := NewEvalState() - i, err := interp.NewInterpretable(checked, Optimize(), TrackState(st)) + i, err := interp.NewInterpretable(checked, Optimize(), Observe(EvalStateObserver(st))) if err != nil { t.Fatal(err) } diff --git a/interpreter/interpretable.go b/interpreter/interpretable.go index 6d161a1b..c87b351a 100644 --- a/interpreter/interpretable.go +++ b/interpreter/interpretable.go @@ -100,17 +100,6 @@ type InterpretableConstructor interface { Type() ref.Type } -// InterpretableBooleanBinaryOp interface for inspecting "&&" and "||" booelan operations. -type InterpretableBooleanBinaryOp interface { - Interpretable - - // LHS returns the left-hand size. - LHS() Interpretable - - // RHS returns the right-hand size. - RHS() Interpretable -} - // Core Interpretable implementations used during the program planning phase. type evalTestOnly struct { @@ -211,16 +200,6 @@ func (or *evalOr) ID() int64 { return or.id } -// LHS implements the InterpretableBooleanBinaryOp interface method. -func (or *evalOr) LHS() Interpretable { - return or.lhs -} - -// RHS implements the InterpretableBooleanBinaryOp interface method. -func (or *evalOr) RHS() Interpretable { - return or.rhs -} - // Eval implements the Interpretable interface method. func (or *evalOr) Eval(ctx Activation) ref.Val { // short-circuit lhs. @@ -271,16 +250,6 @@ func (and *evalAnd) ID() int64 { return and.id } -// LHS implements the InterpretableBooleanBinaryOp interface method. -func (and *evalAnd) LHS() Interpretable { - return and.lhs -} - -// RHS implements the InterpretableBooleanBinaryOp interface method. -func (and *evalAnd) RHS() Interpretable { - return and.rhs -} - // Eval implements the Interpretable interface method. func (and *evalAnd) Eval(ctx Activation) ref.Val { // short-circuit lhs. @@ -1061,16 +1030,6 @@ func (or *evalExhaustiveOr) ID() int64 { return or.id } -// LHS implements the InterpretableBooleanBinaryOp interface method. 
-func (or *evalExhaustiveOr) LHS() Interpretable { - return or.lhs -} - -// RHS implements the InterpretableBooleanBinaryOp interface method. -func (or *evalExhaustiveOr) RHS() Interpretable { - return or.rhs -} - // Eval implements the Interpretable interface method. func (or *evalExhaustiveOr) Eval(ctx Activation) ref.Val { lVal := or.lhs.Eval(ctx) @@ -1117,16 +1076,6 @@ func (and *evalExhaustiveAnd) ID() int64 { return and.id } -// LHS implements the InterpretableBooleanBinaryOp interface method. -func (and *evalExhaustiveAnd) LHS() Interpretable { - return and.lhs -} - -// RHS implements the InterpretableBooleanBinaryOp interface method. -func (and *evalExhaustiveAnd) RHS() Interpretable { - return and.rhs -} - // Eval implements the Interpretable interface method. func (and *evalExhaustiveAnd) Eval(ctx Activation) ref.Val { lVal := and.lhs.Eval(ctx) diff --git a/interpreter/interpreter.go b/interpreter/interpreter.go index 629e9727..c46feea8 100644 --- a/interpreter/interpreter.go +++ b/interpreter/interpreter.go @@ -39,8 +39,12 @@ type Interpreter interface { } // EvalObserver is a functional interface that accepts an expression id and an observed value. +// The id identifies the expression that was evaluated, the programStep is the Interpretable or Qualifier that +// was evaluated and value is the result of the evaluation. type EvalObserver func(id int64, programStep interface{}, value ref.Val) +// Observe constructs a decorator that calls all the provided observers in order after evaluating each Interpretable +// or Qualifier during program evaluation. func Observe(observers ...EvalObserver) InterpretableDecorator { if len(observers) == 1 { return decObserveEval(observers[0]) @@ -53,6 +57,29 @@ func Observe(observers ...EvalObserver) InterpretableDecorator { return decObserveEval(observeFn) } +// EvalCancelledError represents a cancelled program evaluation operation. 
+type EvalCancelledError struct { + Message string + // Cause identifies the cause of the cancellation. + Cause CancellationCause +} + +func (e EvalCancelledError) Error() string { + return e.Message +} + +// CancellationCause enumerates the ways a program evaluation operation can be cancelled. +type CancellationCause int + +const ( + // ContextCancelled indicates that the operation was cancelled in response to a Golang context cancellation. + ContextCancelled CancellationCause = iota + + // CostLimitExceeded indicates that the operation was cancelled in response to the actual cost limit being + // exceeded. + CostLimitExceeded +) + // TODO: Replace all usages of TrackState with EvalStateObserver // TrackState decorates each expression node with an observer which records the value @@ -74,23 +101,6 @@ func EvalStateObserver(state EvalState) EvalObserver { } } -// ExhaustiveEvalWrapper replaces operations that short-circuit with versions that evaluate -// expressions and couples this behavior with the TrackState observer to provide -// insight into the evaluation state of the entire expression. The EvalStateObserver must be -// included in the underlying decorator. This decorator is not thread-safe, and the EvalState -// must be reset between Eval() calls. -func ExhaustiveEvalWrapper(underlying InterpretableDecorator) InterpretableDecorator { - ex := decDisableShortcircuits() - return func(i Interpretable) (Interpretable, error) { - var err error - i, err = ex(i) - if err != nil { - return nil, err - } - return underlying(i) - } -} - // TODO: Replace all usages of ExhaustiveEval with ExhaustiveEvalWrapper // ExhaustiveEval replaces operations that short-circuit with versions that evaluate @@ -98,17 +108,10 @@ func ExhaustiveEvalWrapper(underlying InterpretableDecorator) InterpretableDecor // insight into the evaluation state of the entire expression. EvalState must be // provided to the decorator.
This decorator is not thread-safe, and the EvalState // must be reset between Eval() calls. -// DEPRECATED: Please use ExhaustiveEvalWrapper instead. It composes gracefully with additional observers. -func ExhaustiveEval(state EvalState) InterpretableDecorator { +func ExhaustiveEval() InterpretableDecorator { ex := decDisableShortcircuits() - obs := TrackState(state) return func(i Interpretable) (Interpretable, error) { - var err error - i, err = ex(i) - if err != nil { - return nil, err - } - return obs(i) + return ex(i) } } diff --git a/interpreter/interpreter_test.go b/interpreter/interpreter_test.go index 0e297c4a..ccd28a4f 100644 --- a/interpreter/interpreter_test.go +++ b/interpreter/interpreter_test.go @@ -1402,13 +1402,13 @@ func TestInterpreter(t *testing.T) { } } state := NewEvalState() - opts := map[string]InterpretableDecorator{ - "optimize": Optimize(), - "exhaustive": ExhaustiveEval(state), - "track": TrackState(state), + opts := map[string][]InterpretableDecorator{ + "optimize": {Optimize()}, + "exhaustive": {ExhaustiveEval(), Observe(EvalStateObserver(state))}, + "track": {Observe(EvalStateObserver(state))}, } for mode, opt := range opts { - opts := []InterpretableDecorator{opt} + opts := opt if tc.extraOpts != nil { opts = append(opts, tc.extraOpts...) 
} @@ -1542,7 +1542,7 @@ func TestInterpreter_ExhaustiveConditionalExpr(t *testing.T) { intr := NewStandardInterpreter(cont, reg, reg, attrs) interpretable, _ := intr.NewUncheckedInterpretable( parsed.GetExpr(), - ExhaustiveEval(state)) + ExhaustiveEval(), Observe(EvalStateObserver(state))) vars, _ := NewActivation(map[string]interface{}{ "a": types.True, "b": types.Double(0.999), @@ -1631,7 +1631,7 @@ func TestInterpreter_ExhaustiveLogicalOrEquals(t *testing.T) { interp := NewStandardInterpreter(cont, reg, reg, attrs) i, _ := interp.NewUncheckedInterpretable( parsed.GetExpr(), - ExhaustiveEval(state)) + ExhaustiveEval(), Observe(EvalStateObserver(state))) vars, _ := NewActivation(map[string]interface{}{ "a": true, "b": "b", diff --git a/interpreter/prune_test.go b/interpreter/prune_test.go index 34d91df6..507bdc7f 100644 --- a/interpreter/prune_test.go +++ b/interpreter/prune_test.go @@ -143,7 +143,7 @@ func TestPrune(t *testing.T) { interpretable, _ := interp.NewUncheckedInterpretable( ast.Expr, - ExhaustiveEval(state)) + ExhaustiveEval(), Observe(EvalStateObserver(state))) interpretable.Eval(testActivation(t, tst.in)) newExpr := PruneAst(ast.Expr, state) actual, err := parser.Unparse(newExpr, nil) diff --git a/interpreter/runtimecost.go b/interpreter/runtimecost.go index 5445016e..9e9287b6 100644 --- a/interpreter/runtimecost.go +++ b/interpreter/runtimecost.go @@ -21,13 +21,15 @@ package interpreter import ( "math" - "github.com/google/cel-go/checker" + "github.com/google/cel-go/common" "github.com/google/cel-go/common/overloads" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/traits" ) +// WARNING: Any changes to cost calculations in this file require a corresponding change in checker/cost.go + // ActualCostEstimator provides function call cost estimations at runtime // CallCost returns an estimated cost for the function overload invocation with the given args, or nil if it has no // 
estimate to provide. CEL attempts to provide reasonable estimates for its standard function library, so CallCost @@ -40,52 +42,54 @@ type ActualCostEstimator interface { func CostObserver(tracker *CostTracker) EvalObserver { observer := func(id int64, programStep interface{}, val ref.Val) { switch t := programStep.(type) { - case InterpretableConst: - // zero cost case ConstantQualifier: + // TODO: Push identifiers on to the stack before observing constant qualifiers that apply to them + // and enable the below pop. Once enabled this case can be collapsed into the Qualifier case. + //tracker.stack.pop(1) tracker.cost += 1 + case InterpretableConst: + // zero cost case InterpretableAttribute: - tracker.cost += 1 + // Ternary has no direct cost. All cost is from the conditional and the true/false branch expressions. + _, isConditional := t.Attr().(*conditionalAttribute) + if !isConditional { + tracker.cost += common.SelectAndIdentCost + } + case *evalExhaustiveConditional, *evalOr, *evalAnd, *evalExhaustiveOr, *evalExhaustiveAnd: + // Ternary and logical operators have no direct cost. All cost is from their operand expressions.
case Qualifier: - tracker.stack.pop() + tracker.stack.pop(1) tracker.cost += 1 case InterpretableCall: - argVals := make([]ref.Val, len(t.Args())) - argsFound := true - for i := len(t.Args()) - 1; i >= 0; i-- { - if v, ok := tracker.stack.pop(); ok { - argVals[i] = v - } else { - // should never happen - argsFound = false - } - } - if argsFound { + if argVals, ok := tracker.stack.pop(len(t.Args())); ok { tracker.cost += tracker.costCall(t, argVals) } case InterpretableConstructor: switch t.Type() { case types.ListType: - tracker.cost += 10 + tracker.cost += common.ListCreateBaseCost case types.MapType: - tracker.cost += 30 + tracker.cost += common.MapCreateBaseCost default: - tracker.cost += 40 + tracker.cost += common.StructCreateBaseCost } - case InterpretableBooleanBinaryOp: - tracker.cost += 1 } tracker.stack.push(val) + + if tracker.Limit != nil && tracker.cost > *tracker.Limit { + panic(EvalCancelledError{Cause: CostLimitExceeded, Message: "operation cancelled: actual cost limit exceeded"}) + } } return observer } // CostTracker represents the information needed for tacking runtime cost type CostTracker struct { - estimator *checker.CostEstimator - CallCostEstimator ActualCostEstimator - cost uint64 - stack refValStack + Estimator ActualCostEstimator + Limit *uint64 + + cost uint64 + stack refValStack } // ActualCost returns the runtime cost @@ -95,8 +99,8 @@ func (c CostTracker) ActualCost() uint64 { func (c CostTracker) costCall(call InterpretableCall, argValues []ref.Val) uint64 { var cost uint64 - if c.CallCostEstimator != nil { - callCost := c.CallCostEstimator.CallCost(call.OverloadID(), argValues) + if c.Estimator != nil { + callCost := c.Estimator.CallCost(call.OverloadID(), argValues) if callCost != nil { cost += *callCost return cost @@ -107,7 +111,7 @@ func (c CostTracker) costCall(call InterpretableCall, argValues []ref.Val) uint6 switch call.OverloadID() { // O(n) functions case overloads.StartsWithString, overloads.EndsWithString, 
overloads.StringToBytes, overloads.BytesToString: - cost += uint64(math.Ceil(float64(c.actualSize(argValues[0])) * 0.1)) + cost += uint64(math.Ceil(float64(c.actualSize(argValues[0])) * common.StringTraversalCostFactor)) case overloads.InList: // If a list is composed entirely of constant values this is O(1), but we don't account for that here. // We just assume all list containment checks are O(n). @@ -119,37 +123,43 @@ func (c CostTracker) costCall(call InterpretableCall, argValues []ref.Val) uint6 // When we check the equality of 2 scalar values (e.g. 2 integers, 2 floating-point numbers, 2 booleans etc.), // the CostTracker.actualSize() function by definition returns 1 for each operand, resulting in an overall cost // of 1. - lhsCost := c.actualSize(argValues[0]) - rhsCost := c.actualSize(argValues[1]) - if lhsCost > rhsCost { - cost += rhsCost - } else { - cost += lhsCost + lhsSize := c.actualSize(argValues[0]) + rhsSize := c.actualSize(argValues[1]) + minSize := lhsSize + if rhsSize < minSize { + minSize = rhsSize } + cost += uint64(math.Ceil(float64(minSize) * common.StringTraversalCostFactor)) // O(m+n) functions - case overloads.AddString, overloads.AddBytes, overloads.AddList: + case overloads.AddString, overloads.AddBytes: // In the worst case scenario, we would need to reallocate a new backing store and copy both operands over. - cost += uint64(math.Ceil(float64(c.actualSize(argValues[0])+c.actualSize(argValues[1])) * 0.1)) + cost += uint64(math.Ceil(float64(c.actualSize(argValues[0])+c.actualSize(argValues[1])) * common.StringTraversalCostFactor)) // O(nm) functions case overloads.MatchesString: // https://swtch.com/~rsc/regexp/regexp1.html applies to RE2 implementation supported by CEL - strCost := uint64(math.Ceil(float64(c.actualSize(argValues[0])) * 0.1)) + // Add one to string length for purposes of cost calculation to prevent product of string and regex to be 0 + // in case where string is empty but regex is still expensive. 
+ strCost := uint64(math.Ceil((1.0 + float64(c.actualSize(argValues[0]))) * common.StringTraversalCostFactor)) // We don't know how many expressions are in the regex, just the string length (a huge // improvement here would be to somehow get a count the number of expressions in the regex or // how many states are in the regex state machine and use that to measure regex cost). // For now, we're making a guess that each expression in a regex is typically at least 4 chars // in length. - regexCost := uint64(math.Ceil(float64(c.actualSize(argValues[1])) * 0.25)) + regexCost := uint64(math.Ceil(float64(c.actualSize(argValues[1])) * common.RegexStringLengthCostFactor)) cost += strCost * regexCost case overloads.ContainsString: - strCost := uint64(math.Ceil(float64(c.actualSize(argValues[0])) * 0.1)) - substrCost := uint64(math.Ceil(float64(c.actualSize(argValues[1])) * 0.1)) + strCost := uint64(math.Ceil(float64(c.actualSize(argValues[0])) * common.StringTraversalCostFactor)) + substrCost := uint64(math.Ceil(float64(c.actualSize(argValues[1])) * common.StringTraversalCostFactor)) cost += strCost * substrCost + default: // The following operations are assumed to have O(1) complexity. - // 1. Concatenation of 2 lists: see the implementation of the concatList type. - // 2. Computing the size of strings, byte sequences, lists and maps: presumably, the length of each of these - // data structures are cached and can be retrieved in constant time. + // - AddList due to the implementation. Index lookup can be O(c) the + // number of concatenated lists, but we don't track that is cost calculations. + // - Conversions, since none perform a traversal of a type of unbound length. + // - Computing the size of strings, byte sequences, lists and maps. + // - Logical operations and all operators on fixed width scalars (comparisons, equality) + // - Any functions that don't have a declared cost either here or in provided ActualCostEstimator. 
cost += 1 } @@ -159,7 +169,7 @@ func (c CostTracker) costCall(call InterpretableCall, argValues []ref.Val) uint6 // actualSize returns the size of value func (c CostTracker) actualSize(value ref.Val) uint64 { if sz, ok := value.(traits.Sizer); ok { - return uint64(sz.Size().Value().(int64)) + return uint64(sz.Size().(types.Int)) } return 1 } @@ -171,12 +181,12 @@ func (s *refValStack) push(value ref.Val) { *s = append(*s, value) } -func (s *refValStack) pop() (ref.Val, bool) { - if len(*s) == 0 { +func (s *refValStack) pop(count int) ([]ref.Val, bool) { + if len(*s) < count { return nil, false } - idx := len(*s) - 1 - el := (*s)[idx] + idx := len(*s) - count + el := (*s)[idx:] *s = (*s)[:idx] return el, true } diff --git a/cel/runtimecost_test.go b/interpreter/runtimecost_test.go similarity index 61% rename from cel/runtimecost_test.go rename to interpreter/runtimecost_test.go index 7a694946..4545854e 100644 --- a/cel/runtimecost_test.go +++ b/interpreter/runtimecost_test.go @@ -12,51 +12,34 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-package cel +package interpreter import ( + "fmt" "math/rand" "reflect" + "strings" "testing" "time" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + "github.com/google/cel-go/checker" "github.com/google/cel-go/checker/decls" + "github.com/google/cel-go/common" + "github.com/google/cel-go/common/containers" "github.com/google/cel-go/common/overloads" - "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" - "github.com/google/cel-go/interpreter" + "github.com/google/cel-go/parser" "github.com/google/cel-go/test/proto3pb" ) -type testInfo struct { - env *Env - in interface{} - lhsExpr string - rhsExpr string -} - -func computeCosts(t *testing.T, info *testInfo) (lhsCost, rhsCost uint64) { - t.Helper() - - env := info.env - if env == nil { - emptyEnv, err := NewEnv() - if err != nil { - t.Fatalf("Failed to create empty environment, error: %v", err) - } - env = emptyEnv - } - ctx := constructActivation(t, info.in) - lhsCost = computeCost(t, env, info.lhsExpr, &ctx) - rhsCost = computeCost(t, env, info.rhsExpr, &ctx) - - return lhsCost, rhsCost -} - func TestTrackCostAdvanced(t *testing.T) { - var equalCases = []testInfo{ + var equalCases = []struct { + in interface{} + lhsExpr string + rhsExpr string + }{ { lhsExpr: `1`, rhsExpr: `2`, @@ -70,14 +53,29 @@ func TestTrackCostAdvanced(t *testing.T) { rhsExpr: `2 in [15, 17, 16]`, }, } - for i, testCase := range equalCases { - lhsCost, rhsCost := computeCosts(t, &testCase) - if lhsCost != rhsCost { - t.Errorf(`Expected equal cost case #%d, expressions "%s" vs. "%s", respective cost %d vs. 
%d`, i, - testCase.lhsExpr, testCase.rhsExpr, lhsCost, rhsCost) - } + for _, tc := range equalCases { + t.Run(tc.lhsExpr+" vs "+tc.rhsExpr, func(t *testing.T) { + ctx := constructActivation(t, tc.in) + lhsCost, _, err := computeCost(t, tc.lhsExpr, nil, ctx, nil) + if err != nil { + t.Fatalf("Interpreter.Eval(activation Activation) failed to eval expression due: %v", err) + } + rhsCost, _, err := computeCost(t, tc.rhsExpr, nil, ctx, nil) + if err != nil { + t.Fatalf("Interpreter.Eval(activation Activation) failed to eval expression due: %v", err) + } + if lhsCost != rhsCost { + t.Errorf(`Interpreter.Eval(activation Activation) failed return a cost for %s of %d equal to a cost for %s of %d`, + tc.lhsExpr, lhsCost, tc.rhsExpr, rhsCost) + } + }) + } - var smallerCases = []testInfo{ + var smallerCases = []struct { + in interface{} + lhsExpr string + rhsExpr string + }{ { lhsExpr: `1`, rhsExpr: `1 + 2`, @@ -91,50 +89,79 @@ func TestTrackCostAdvanced(t *testing.T) { rhsExpr: `1 in [4, 5, 6, 7, 8, 9]`, }, } - for i, testCase := range smallerCases { - lhsCost, rhsCost := computeCosts(t, &testCase) - if lhsCost >= rhsCost { - t.Errorf(`Expected smaller cost case #%d, expect the cost of expression "%s" to be strictly smaller than "%s", but got %d vs. 
%d respectively`, - i, testCase.lhsExpr, testCase.rhsExpr, lhsCost, rhsCost) - } + for _, tc := range smallerCases { + t.Run(tc.lhsExpr+" vs "+tc.rhsExpr, func(t *testing.T) { + ctx := constructActivation(t, tc.in) + lhsCost, _, err := computeCost(t, tc.lhsExpr, nil, ctx, nil) + if err != nil { + t.Fatalf("Interpreter.Eval(activation Activation) failed to eval expression due: %v", err) + } + rhsCost, _, err := computeCost(t, tc.rhsExpr, nil, ctx, nil) + if err != nil { + t.Fatalf("Interpreter.Eval(activation Activation) failed to eval expression due: %v", err) + } + if lhsCost >= rhsCost { + t.Errorf(`Interpreter.Eval(activation Activation) failed return a cost for %s of %d less than the cost for %s of %d`, + tc.lhsExpr, lhsCost, tc.rhsExpr, rhsCost) + } + }) } } -func computeCost(t *testing.T, env *Env, expr string, ctx *interpreter.Activation) uint64 { +func computeCost(t *testing.T, expr string, decls []*exprpb.Decl, ctx Activation, limit *uint64) (cost uint64, est checker.CostEstimate, err error) { t.Helper() - // Compile and check expression. - ast, iss := env.Compile(expr) - if iss.Err() != nil { - t.Fatalf(`Failed to compile expression "%s", error: %v`, expr, iss.Err()) + s := common.NewTextSource(expr) + p, err := parser.NewParser(parser.Macros(parser.AllMacros...)) + if err != nil { + t.Fatalf("Failed to initialize parser: %v", err) } - checked_ast, iss := env.Check(ast) - if iss.Err() != nil { - t.Fatalf(`Failed to check expression "%s", error: %v`, expr, iss.Err()) + parsed, errs := p.Parse(s) + if len(errs.GetErrors()) != 0 { + t.Fatalf(`Failed to Parse expression "%s", error: %v`, expr, errs.GetErrors()) } - // Evaluate expression. - program, err := env.Program(checked_ast, EvalOptions(OptTrackCost)) + cont := containers.DefaultContainer + reg := newTestRegistry(t, &proto3pb.TestAllTypes{}) + attrs := NewAttributeFactory(cont, reg, reg) + env := newTestEnv(t, cont, reg) + err = env.Add(decls...) 
if err != nil { - t.Fatalf(`Failed to construct Program from expression "%s", error: %v`, expr, err) + t.Fatalf("Failed to initialize env: %v", err) } - _, details, err := program.Eval(*ctx) - if err != nil { - t.Fatalf(`Failed to evaluate expression "%s", error: %v`, expr, err) + + checked, errs := checker.Check(parsed, s, env) + if len(errs.GetErrors()) != 0 { + t.Fatalf(`Failed to check expression "%s", error: %v`, expr, errs.GetErrors()) } - costPtr := details.ActualCost() - if costPtr == nil { - t.Fatalf(`Null pointer returned for the cost of expression "%s"`, expr) + est = checker.Cost(checked, testCostEstimator{}) + interp := NewStandardInterpreter(cont, reg, reg, attrs) + costTracker := &CostTracker{Estimator: &testRuntimeCostEstimator{}, Limit: limit} + prg, err := interp.NewInterpretable(checked, Observe(CostObserver(costTracker))) + if err != nil { + t.Fatalf(`Failed to check expression "%s", error: %v`, expr, errs.GetErrors()) } - return *costPtr + + defer func() { + if r := recover(); r != nil { + switch t := r.(type) { + case EvalCancelledError: + err = t + default: + err = fmt.Errorf("internal error: %v", r) + } + } + }() + prg.Eval(ctx) + return costTracker.cost, est, err } -func constructActivation(t *testing.T, in interface{}) interpreter.Activation { +func constructActivation(t *testing.T, in interface{}) Activation { t.Helper() if in == nil { - return interpreter.EmptyActivation() + return EmptyActivation() } - a, err := interpreter.NewActivation(in) + a, err := NewActivation(in) if err != nil { t.Fatalf("NewActivation(%v) failed: %v", in, err) } @@ -178,6 +205,25 @@ func (e testRuntimeCostEstimator) CallCost(overloadId string, args []ref.Val) *u } } +type testCostEstimator struct { + hints map[string]int64 +} + +func (tc testCostEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate { + if l, ok := tc.hints[strings.Join(element.Path(), ".")]; ok { + return &checker.SizeEstimate{Min: 0, Max: uint64(l)} + } + return nil +} + +func 
(tc testCostEstimator) EstimateCallCost(overloadId string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { + switch overloadId { + case overloads.TimestampToYear: + return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 7, Max: 7}} + } + return nil +} + func TestRuntimeCost(t *testing.T) { allTypes := decls.NewObjectType("google.expr.proto3.test.TestAllTypes") allList := decls.NewListType(allTypes) @@ -193,6 +239,9 @@ func TestRuntimeCost(t *testing.T) { want uint64 in interface{} testFuncCost bool + limit uint64 + + expectExceedsLimit bool }{ { name: "const", @@ -213,6 +262,13 @@ func TestRuntimeCost(t *testing.T) { want: 2, in: map[string]interface{}{"input": map[string]string{"key": "v"}}, }, + { + name: "select: array index", + expr: `input[1]`, + decls: []*exprpb.Decl{decls.NewVar("input", decls.NewListType(decls.String))}, + want: 2, + in: map[string]interface{}{"input": []string{"v"}}, + }, { name: "select: field", expr: `input.single_int32`, @@ -228,6 +284,20 @@ func TestRuntimeCost(t *testing.T) { }, }, }, + { + name: "expr select: map", + expr: `input['ke' + 'y']`, + decls: []*exprpb.Decl{decls.NewVar("input", decls.NewMapType(decls.String, decls.String))}, + want: 3, + in: map[string]interface{}{"input": map[string]string{"key": "v"}}, + }, + { + name: "expr select: array index", + expr: `input[3-2]`, + decls: []*exprpb.Decl{decls.NewVar("input", decls.NewListType(decls.String))}, + want: 3, + in: map[string]interface{}{"input": []string{"v"}}, + }, { name: "select: field test only", expr: `has(input.single_int32)`, @@ -287,13 +357,13 @@ func TestRuntimeCost(t *testing.T) { { name: "all comprehension on literal", expr: `[1, 2, 3].all(x, true)`, - want: 23, + want: 20, }, { name: "variable cost function", decls: []*exprpb.Decl{decls.NewVar("input", decls.String)}, expr: `input.matches('[0-9]')`, - want: 101, + want: 103, in: map[string]interface{}{"input": string(randSeq(500))}, }, { @@ -304,12 +374,12 @@ func 
TestRuntimeCost(t *testing.T) { { name: "or", expr: `true || false`, - want: 1, + want: 0, }, { name: "and", expr: `true && false`, - want: 1, + want: 0, }, { name: "lt", @@ -369,13 +439,18 @@ func TestRuntimeCost(t *testing.T) { { name: "ternary", expr: `true ? 1 : 2`, - want: 1, + want: 0, }, { name: "string size", expr: `size("123")`, want: 1, }, + { + name: "str eq str", + expr: `'12345678901234567890' == '123456789012345678901234567890'`, + want: 2, + }, { name: "bytes to string conversion", decls: []*exprpb.Decl{decls.NewVar("input", decls.Bytes)}, @@ -411,7 +486,7 @@ func TestRuntimeCost(t *testing.T) { decls: []*exprpb.Decl{ decls.NewVar("input", decls.String), }, - want: 101, + want: 103, in: map[string]interface{}{"input": string(randSeq(500)), "arg1": string(randSeq(500))}, }, { @@ -460,16 +535,28 @@ func TestRuntimeCost(t *testing.T) { decls.NewVar("input1", allList), decls.NewVar("input2", allList), }, - want: 8, + want: 6, in: map[string]interface{}{"input1": []proto3pb.TestAllTypes{proto3pb.TestAllTypes{}}, "input2": []proto3pb.TestAllTypes{proto3pb.TestAllTypes{}}, "x": 1}, }, + { + name: "ternary eval trivial, true", + expr: `true ? false : 1 > 3`, + want: 0, + in: map[string]interface{}{}, + }, + { + name: "ternary eval trivial, false", + expr: `false ? 
false : 1 > 3`, + want: 1, + in: map[string]interface{}{}, + }, { name: "comprehension over map", expr: `input.all(k, input[k].single_int32 > 3)`, decls: []*exprpb.Decl{ decls.NewVar("input", allMap), }, - want: 10, + want: 9, in: map[string]interface{}{"input": map[string]interface{}{"val": &proto3pb.TestAllTypes{}}}, }, { @@ -515,7 +602,7 @@ func TestRuntimeCost(t *testing.T) { decls.NewVar("list1", decls.NewListType(decls.Int)), decls.NewVar("list2", decls.NewListType(decls.Int)), }, - want: 3, + want: 4, in: map[string]interface{}{"list1": []int{}, "list2": []int{}}, }, { @@ -528,50 +615,69 @@ func TestRuntimeCost(t *testing.T) { want: 6, in: map[string]interface{}{"str1": "val1", "str2": "val2222222"}, }, + { + name: "at limit", + expr: `"abcdefg".contains(str1 + str2)`, + decls: []*exprpb.Decl{ + decls.NewVar("str1", decls.String), + decls.NewVar("str2", decls.String), + }, + in: map[string]interface{}{"str1": "val1", "str2": "val2222222"}, + limit: 6, + want: 6, + }, + { + name: "above limit", + expr: `"abcdefg".contains(str1 + str2)`, + decls: []*exprpb.Decl{ + decls.NewVar("str1", decls.String), + decls.NewVar("str2", decls.String), + }, + in: map[string]interface{}{"str1": "val1", "str2": "val2222222"}, + limit: 5, + expectExceedsLimit: true, + }, + { + name: "ternary as operand", + expr: `(1 > 2 ? 
5 : 3) > 1`, + decls: []*exprpb.Decl{}, + in: map[string]interface{}{}, + want: 2, + }, + { + name: "ternary as operand", + expr: `(1 > 2 || 2 > 1) == true`, + decls: []*exprpb.Decl{}, + in: map[string]interface{}{}, + want: 3, + }, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - e, err := NewEnv( - Declarations(tc.decls...), - Types(&proto3pb.TestAllTypes{}), - CustomTypeAdapter(types.DefaultTypeAdapter)) - if err != nil { - t.Fatalf("environment creation error: %s\n", err) - } - ast, iss := e.Compile(tc.expr) - if iss.Err() != nil { - t.Fatal(iss.Err()) - } - ctx := constructActivation(t, tc.in) - checked_ast, iss := e.Check(ast) - if iss.Err() != nil { - t.Fatalf(`Failed to check expression with error: %v`, iss.Err()) - } - // Evaluate expression. - var program Program - if tc.testFuncCost { - program, err = e.Program(checked_ast, ActualCostTracking(testRuntimeCostEstimator{})) - } else { - program, err = e.Program(checked_ast, EvalOptions(OptTrackCost)) - } - if err != nil { - t.Fatalf(`Failed to construct Program with error: %v`, err) + var costLimit *uint64 + if tc.limit > 0 { + costLimit = &tc.limit } - _, details, err := program.Eval(ctx) + actualCost, est, err := computeCost(t, tc.expr, tc.decls, ctx, costLimit) if err != nil { - t.Fatalf(`Failed to evaluate expression with error: %v`, err) + if tc.expectExceedsLimit { + return + } + t.Fatalf("Interpreter.Eval(activation Activation) failed due to: %v", err) } - actualCost := details.ActualCost() - if actualCost == nil { - t.Fatalf(`Null pointer returned for the cost of expression "%s"`, tc.expr) + if tc.expectExceedsLimit { + t.Fatalf("Interpreter.Eval(activation Activation) failed to return a cost exceeded error for limit %d, got cost %d", tc.limit, actualCost) } - if *actualCost != tc.want { - t.Fatalf("runtime cost %d does not match expected runtime cost %d", *actualCost, tc.want) + if actualCost != tc.want { + t.Fatalf("Interpreter.Eval(activation Activation) failed to return 
expected runtime cost %d, got %d", tc.want, actualCost) + } + if est.Min > actualCost || est.Max < actualCost { + t.Fatalf("Interpreter.Eval(activation Activation) failed to return cost in range of estimate cost [%d, %d], got %d", + est.Min, est.Max, actualCost) } - }) } }