diff --git a/cel/cel_test.go b/cel/cel_test.go index ab9c1b0e..d4ad99ad 100644 --- a/cel/cel_test.go +++ b/cel/cel_test.go @@ -91,7 +91,12 @@ func Test_ExampleWithBuiltins(t *testing.T) { } func TestEval(t *testing.T) { - env, err := NewEnv(Variable("input", ListType(IntType))) + env, err := NewEnv( + Variable("input", ListType(IntType)), + CostEstimatorOptions( + checker.OverloadCostEstimate(overloads.TimestampToYear, estimateTimestampToYear), + ), + ) if err != nil { t.Fatalf("NewEnv() failed: %v", err) } @@ -114,6 +119,9 @@ func TestEval(t *testing.T) { ctx := context.Background() prgOpts := []ProgramOption{ CostTracking(testRuntimeCostEstimator{}), + CostTrackerOptions( + interpreter.OverloadCostTracker(overloads.TimestampToYear, trackTimestampToYear), + ), EvalOptions(OptOptimize, OptTrackCost), InterruptCheckFrequency(100), } @@ -1338,7 +1346,7 @@ func TestEstimateCostAndRuntimeCost(t *testing.T) { name string expr string decls []EnvOption - hints map[string]int64 + hints map[string]uint64 want checker.CostEstimate in any }{ @@ -1362,7 +1370,7 @@ func TestEstimateCostAndRuntimeCost(t *testing.T) { Variable("str1", StringType), Variable("str2", StringType), }, - hints: map[string]int64{"str1": 10, "str2": 10}, + hints: map[string]uint64{"str1": 10, "str2": 10}, want: checker.CostEstimate{Min: 2, Max: 6}, in: map[string]any{"str1": "val1111111", "str2": "val2222222"}, }, @@ -1373,9 +1381,15 @@ func TestEstimateCostAndRuntimeCost(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() if tc.hints == nil { - tc.hints = map[string]int64{} + tc.hints = map[string]uint64{} } - env := testEnv(t, tc.decls...) + envOpts := []EnvOption{ + CostEstimatorOptions( + checker.OverloadCostEstimate(overloads.TimestampToYear, estimateTimestampToYear), + ), + } + envOpts = append(envOpts, tc.decls...) + env := testEnv(t, envOpts...) ast, iss := env.Compile(tc.expr) if iss.Err() != nil { t.Fatalf("env.Compile(%v) failed: %v", tc.expr, iss.Err()) @@ -1394,7 +1408,12 @@ func TestEstimateCostAndRuntimeCost(t *testing.T) { t.Fatalf(`Env.Check(ast *Ast) failed to check expression: %v`, iss.Err()) } // Evaluate expression. - program, err := env.Program(checkedAst, CostTracking(testRuntimeCostEstimator{})) + program, err := env.Program(checkedAst, + CostTracking(testRuntimeCostEstimator{}), + CostTrackerOptions( + interpreter.OverloadCostTracker(overloads.TimestampToYear, trackTimestampToYear), + ), + ) if err != nil { t.Fatalf(`Env.Program(ast *Ast, opts ...ProgramOption) failed to construct program: %v`, err) } @@ -2631,27 +2650,26 @@ func BenchmarkDynamicDispatch(b *testing.B) { // TODO: ideally testCostEstimator and testRuntimeCostEstimator would be shared in a test fixtures package type testCostEstimator struct { - hints map[string]int64 + hints map[string]uint64 } func (tc testCostEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate { if l, ok := tc.hints[strings.Join(element.Path(), ".")]; ok { - return &checker.SizeEstimate{Min: 0, Max: uint64(l)} + return &checker.SizeEstimate{Min: 0, Max: l} } return nil } func (tc testCostEstimator) EstimateCallCost(function, overloadID string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { - switch overloadID { - case overloads.TimestampToYear: - return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 7, Max: 7}} - } return nil } -type testRuntimeCostEstimator struct { +func estimateTimestampToYear(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { + return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 7, Max: 7}} } +type testRuntimeCostEstimator struct{} + var timeToYearCost uint64 = 7 func (e testRuntimeCostEstimator) CallCost(function, overloadID string, args []ref.Val, result ref.Val) *uint64 { @@ -2667,13 +2685,11 @@ func (e testRuntimeCostEstimator) CallCost(function, overloadID string, args []r argsSize[i] = 1 } } + return nil +} - switch overloadID { - case overloads.TimestampToYear: - return &timeToYearCost - default: - return nil - } +func trackTimestampToYear(args []ref.Val, result ref.Val) *uint64 { + return &timeToYearCost } func testEnv(t testing.TB, opts ...EnvOption) *Env { diff --git a/cel/env.go b/cel/env.go index 69ba34db..b5c3b4cc 100644 --- a/cel/env.go +++ b/cel/env.go @@ -119,6 +119,7 @@ type Env struct { appliedFeatures map[int]bool libraries map[string]bool validators []ASTValidator + costOptions []checker.CostOption // Internal parser representation prsr *parser.Parser @@ -181,6 +182,7 @@ func NewCustomEnv(opts ...EnvOption) (*Env, error) { libraries: map[string]bool{}, validators: []ASTValidator{}, progOpts: []ProgramOption{}, + costOptions: []checker.CostOption{}, }).configure(opts) } @@ -356,6 +358,8 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) { } validatorsCopy := make([]ASTValidator, len(e.validators)) copy(validatorsCopy, e.validators) + costOptsCopy := make([]checker.CostOption, len(e.costOptions)) + copy(costOptsCopy, e.costOptions) ext := &Env{ Container: e.Container, @@ -371,6 +375,7 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) { provider: provider, chkOpts: chkOptsCopy, prsrOpts: prsrOptsCopy, + costOptions: costOptsCopy, } return ext.configure(opts) } @@ -557,7 +562,10 @@ func (e *Env) EstimateCost(ast *Ast, estimator checker.CostEstimator, opts ...ch TypeMap: ast.typeMap, ReferenceMap: ast.refMap, } - return checker.Cost(checked, estimator, opts...) + extendedOpts := make([]checker.CostOption, 0, len(e.costOptions)) + extendedOpts = append(extendedOpts, opts...) + extendedOpts = append(extendedOpts, e.costOptions...) + return checker.Cost(checked, estimator, extendedOpts...) } // configure applies a series of EnvOptions to the current environment. diff --git a/cel/options.go b/cel/options.go index d47f55d8..05867730 100644 --- a/cel/options.go +++ b/cel/options.go @@ -23,6 +23,7 @@ import ( "google.golang.org/protobuf/reflect/protoregistry" "google.golang.org/protobuf/types/dynamicpb" + "github.com/google/cel-go/checker" "github.com/google/cel-go/common/containers" "github.com/google/cel-go/common/functions" "github.com/google/cel-go/common/types" @@ -469,6 +470,24 @@ func InterruptCheckFrequency(checkFrequency uint) ProgramOption { } } +// CostEstimatorOptions configure type-check time options for estimating expression cost. +func CostEstimatorOptions(costOpts ...checker.CostOption) EnvOption { + return func(e *Env) (*Env, error) { + e.costOptions = append(e.costOptions, costOpts...) + return e, nil + } +} + +// CostTrackerOptions configures a set of options for cost-tracking. +// +// Note, CostTrackerOptions is a no-op unless CostTracking is also enabled. +func CostTrackerOptions(costOpts ...interpreter.CostTrackerOption) ProgramOption { + return func(p *prog) (*prog, error) { + p.costOptions = append(p.costOptions, costOpts...) + return p, nil + } +} + // CostTracking enables cost tracking and registers a ActualCostEstimator that can optionally provide a runtime cost estimate for any function calls. func CostTracking(costEstimator interpreter.ActualCostEstimator) ProgramOption { return func(p *prog) (*prog, error) { diff --git a/cel/program.go b/cel/program.go index 11c5c447..2dd72f75 100644 --- a/cel/program.go +++ b/cel/program.go @@ -106,7 +106,7 @@ func (ed *EvalDetails) State() interpreter.EvalState { // ActualCost returns the tracked cost through the course of execution when `CostTracking` is enabled. // Otherwise, returns nil if the cost was not enabled. func (ed *EvalDetails) ActualCost() *uint64 { - if ed.costTracker == nil { + if ed == nil || ed.costTracker == nil { return nil } cost := ed.costTracker.ActualCost() @@ -130,10 +130,14 @@ type prog struct { // Interpretable configured from an Ast and aggregate decorator set based on program options. interpretable interpreter.Interpretable callCostEstimator interpreter.ActualCostEstimator + costOptions []interpreter.CostTrackerOption costLimit *uint64 } func (p *prog) clone() *prog { + costOptsCopy := make([]interpreter.CostTrackerOption, len(p.costOptions)) + copy(costOptsCopy, p.costOptions) + return &prog{ Env: p.Env, evalOpts: p.evalOpts, @@ -155,9 +159,10 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) { // Ensure the default attribute factory is set after the adapter and provider are // configured. p := &prog{ - Env: e, - decorators: []interpreter.InterpretableDecorator{}, - dispatcher: disp, + Env: e, + decorators: []interpreter.InterpretableDecorator{}, + dispatcher: disp, + costOptions: []interpreter.CostTrackerOption{}, } // Configure the program via the ProgramOption values. @@ -242,6 +247,12 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) { factory := func(state interpreter.EvalState, costTracker *interpreter.CostTracker) (Program, error) { costTracker.Estimator = p.callCostEstimator costTracker.Limit = p.costLimit + for _, costOpt := range p.costOptions { + err := costOpt(costTracker) + if err != nil { + return nil, err + } + } // Limit capacity to guarantee a reallocation when calling 'append(decs, ...)' below. This // prevents the underlying memory from being shared between factory function calls causing // undesired mutations. @@ -371,7 +382,11 @@ type progGen struct { // the test is successful. func newProgGen(factory progFactory) (Program, error) { // Test the factory to make sure that configuration errors are spotted at config - _, err := factory(interpreter.NewEvalState(), &interpreter.CostTracker{}) + tracker, err := interpreter.NewCostTracker(nil) + if err != nil { + return nil, err + } + _, err = factory(interpreter.NewEvalState(), tracker) if err != nil { return nil, err } @@ -384,7 +399,10 @@ func (gen *progGen) Eval(input any) (ref.Val, *EvalDetails, error) { // new EvalState instance for each call to ensure that unique evaluations yield unique stateful // results. state := interpreter.NewEvalState() - costTracker := &interpreter.CostTracker{} + costTracker, err := interpreter.NewCostTracker(nil) + if err != nil { + return nil, nil, err + } det := &EvalDetails{state: state, costTracker: costTracker} // Generate a new instance of the interpretable using the factory configured during the call to @@ -412,7 +430,10 @@ func (gen *progGen) ContextEval(ctx context.Context, input any) (ref.Val, *EvalD // new EvalState instance for each call to ensure that unique evaluations yield unique stateful // results. state := interpreter.NewEvalState() - costTracker := &interpreter.CostTracker{} + costTracker, err := interpreter.NewCostTracker(nil) + if err != nil { + return nil, nil, err + } det := &EvalDetails{state: state, costTracker: costTracker} // Generate a new instance of the interpretable using the factory configured during the call to diff --git a/checker/cost.go b/checker/cost.go index d02f2628..f232f30d 100644 --- a/checker/cost.go +++ b/checker/cost.go @@ -230,7 +230,7 @@ func addUint64NoOverflow(x, y uint64) uint64 { // multiplyUint64NoOverflow multiplies non-negative ints. If the result is exceeds math.MaxUint64, math.MaxUint64 // is returned. func multiplyUint64NoOverflow(x, y uint64) uint64 { - if x > 0 && y > 0 && x > math.MaxUint64/y { + if y != 0 && x > math.MaxUint64/y { return math.MaxUint64 } return x * y @@ -242,7 +242,11 @@ func multiplyByCostFactor(x uint64, y float64) uint64 { if xFloat > 0 && y > 0 && xFloat > math.MaxUint64/y { return math.MaxUint64 } - return uint64(math.Ceil(xFloat * y)) + ceil := math.Ceil(xFloat * y) + if ceil >= doubleTwoTo64 { + return math.MaxUint64 + } + return uint64(ceil) } var ( @@ -260,9 +264,10 @@ type coster struct { // iterRanges tracks the iterRange of each iterVar. iterRanges iterRangeScopes // computedSizes tracks the computed sizes of call results. - computedSizes map[int64]SizeEstimate - checkedAST *ast.CheckedAST - estimator CostEstimator + computedSizes map[int64]SizeEstimate + checkedAST *ast.CheckedAST + estimator CostEstimator + overloadEstimators map[string]FunctionEstimator // presenceTestCost will either be a zero or one based on whether has() macros count against cost computations. presenceTestCost CostEstimate } @@ -291,6 +296,7 @@ func (vs iterRangeScopes) peek(varName string) (int64, bool) { type CostOption func(*coster) error // PresenceTestHasCost determines whether presence testing has a cost of one or zero. +// // Defaults to presence test has a cost of one. func PresenceTestHasCost(hasCost bool) CostOption { return func(c *coster) error { @@ -303,15 +309,30 @@ func PresenceTestHasCost(hasCost bool) CostOption { } } +// FunctionEstimator provides a CallEstimate given the target and arguments for a specific function, overload pair. +type FunctionEstimator func(estimator CostEstimator, target *AstNode, args []AstNode) *CallEstimate + +// OverloadCostEstimate binds a FunctionCoster to a specific function overload ID. +// +// When a OverloadCostEstimate is provided, it will override the cost calculation of the CostEstimator provided to +// the Cost() call. +func OverloadCostEstimate(overloadID string, functionCoster FunctionEstimator) CostOption { + return func(c *coster) error { + c.overloadEstimators[overloadID] = functionCoster + return nil + } +} + // Cost estimates the cost of the parsed and type checked CEL expression. func Cost(checker *ast.CheckedAST, estimator CostEstimator, opts ...CostOption) (CostEstimate, error) { c := &coster{ - checkedAST: checker, - estimator: estimator, - exprPath: map[int64][]string{}, - iterRanges: map[string][]int64{}, - computedSizes: map[int64]SizeEstimate{}, - presenceTestCost: CostEstimate{Min: 1, Max: 1}, + checkedAST: checker, + estimator: estimator, + overloadEstimators: map[string]FunctionEstimator{}, + exprPath: map[int64][]string{}, + iterRanges: map[string][]int64{}, + computedSizes: map[int64]SizeEstimate{}, + presenceTestCost: CostEstimate{Min: 1, Max: 1}, } for _, opt := range opts { err := opt(c) @@ -532,7 +553,14 @@ func (c *coster) functionCost(function, overloadID string, target *AstNode, args } return sum } - + if len(c.overloadEstimators) != 0 { + if estimator, found := c.overloadEstimators[overloadID]; found { + if est := estimator(c.estimator, target, args); est != nil { + callEst := *est + return CallEstimate{CostEstimate: callEst.Add(argCostSum()), ResultSize: est.ResultSize} + } + } + } if est := c.estimator.EstimateCallCost(function, overloadID, target, args); est != nil { callEst := *est return CallEstimate{CostEstimate: callEst.Add(argCostSum()), ResultSize: est.ResultSize} @@ -682,3 +710,7 @@ func isScalar(t *types.Type) bool { } return false } + +var ( + doubleTwoTo64 = math.Ldexp(1.0, 64) +) diff --git a/checker/cost_test.go b/checker/cost_test.go index a035fb04..c94c1c2b 100644 --- a/checker/cost_test.go +++ b/checker/cost_test.go @@ -15,6 +15,7 @@ package checker import ( + "math" "strings" "testing" @@ -43,7 +44,7 @@ func TestCost(t *testing.T) { name string expr string vars []*decls.VariableDecl - hints map[string]int64 + hints map[string]uint64 options []CostOption wanted CostEstimate }{ @@ -128,14 +129,14 @@ func TestCost(t *testing.T) { { name: "all comprehension", vars: []*decls.VariableDecl{decls.NewVariable("input", allList)}, - hints: map[string]int64{"input": 100}, + hints: map[string]uint64{"input": 100}, expr: `input.all(x, true)`, wanted: CostEstimate{Min: 2, Max: 302}, }, { name: "nested all comprehension", vars: []*decls.VariableDecl{decls.NewVariable("input", nestedList)}, - hints: map[string]int64{"input": 50, "input.@items": 10}, + hints: map[string]uint64{"input": 50, "input.@items": 10}, expr: `input.all(x, x.all(y, true))`, wanted: CostEstimate{Min: 2, Max: 1752}, }, @@ -147,7 +148,7 @@ func TestCost(t *testing.T) { { name: "variable cost function", vars: []*decls.VariableDecl{decls.NewVariable("input", types.StringType)}, - hints: map[string]int64{"input": 500}, + hints: map[string]uint64{"input": 500}, expr: `input.matches('[0-9]')`, wanted: CostEstimate{Min: 3, Max: 103}, }, @@ -256,14 +257,14 @@ func TestCost(t *testing.T) { { name: "bytes to string conversion", vars: []*decls.VariableDecl{decls.NewVariable("input", types.BytesType)}, - hints: map[string]int64{"input": 500}, + hints: map[string]uint64{"input": 500}, expr: `string(input)`, wanted: CostEstimate{Min: 1, Max: 51}, }, { name: "bytes to string conversion equality", vars: []*decls.VariableDecl{decls.NewVariable("input", types.BytesType)}, - hints: map[string]int64{"input": 500}, + hints: map[string]uint64{"input": 500}, // equality check ensures that the resultSize calculation is included in cost expr: `string(input) == string(input)`, wanted: CostEstimate{Min: 3, Max: 152}, @@ -271,14 +272,14 @@ func TestCost(t *testing.T) { { name: "string to bytes conversion", vars: []*decls.VariableDecl{decls.NewVariable("input", types.StringType)}, - hints: map[string]int64{"input": 500}, + hints: map[string]uint64{"input": 500}, expr: `bytes(input)`, wanted: CostEstimate{Min: 1, Max: 51}, }, { name: "string to bytes conversion equality", vars: []*decls.VariableDecl{decls.NewVariable("input", types.StringType)}, - hints: map[string]int64{"input": 500}, + hints: map[string]uint64{"input": 500}, // equality check ensures that the resultSize calculation is included in cost expr: `bytes(input) == bytes(input)`, wanted: CostEstimate{Min: 3, Max: 302}, @@ -295,7 +296,7 @@ func TestCost(t *testing.T) { decls.NewVariable("input", types.StringType), decls.NewVariable("arg1", types.StringType), }, - hints: map[string]int64{"input": 500, "arg1": 500}, + hints: map[string]uint64{"input": 500, "arg1": 500}, wanted: CostEstimate{Min: 2, Max: 2502}, }, { @@ -304,7 +305,7 @@ func TestCost(t *testing.T) { vars: []*decls.VariableDecl{ decls.NewVariable("input", types.StringType), }, - hints: map[string]int64{"input": 500}, + hints: map[string]uint64{"input": 500}, wanted: CostEstimate{Min: 3, Max: 103}, }, { @@ -314,7 +315,7 @@ func TestCost(t *testing.T) { decls.NewVariable("input", types.StringType), decls.NewVariable("arg1", types.StringType), }, - hints: map[string]int64{"arg1": 500}, + hints: map[string]uint64{"arg1": 500}, wanted: CostEstimate{Min: 2, Max: 52}, }, { @@ -324,7 +325,7 @@ func TestCost(t *testing.T) { decls.NewVariable("input", types.StringType), decls.NewVariable("arg1", types.StringType), }, - hints: map[string]int64{"arg1": 500}, + hints: map[string]uint64{"arg1": 500}, wanted: CostEstimate{Min: 2, Max: 52}, }, { @@ -351,7 +352,7 @@ func TestCost(t *testing.T) { decls.NewVariable("input1", allList), decls.NewVariable("input2", allList), }, - hints: map[string]int64{"input1": 1, "input2": 1}, + hints: map[string]uint64{"input1": 1, "input2": 1}, wanted: CostEstimate{Min: 4, Max: 7}, }, { @@ -360,7 +361,7 @@ func TestCost(t *testing.T) { vars: []*decls.VariableDecl{ decls.NewVariable("input", allMap), }, - hints: map[string]int64{"input": 10}, + hints: map[string]uint64{"input": 10}, wanted: CostEstimate{Min: 2, Max: 82}, }, { @@ -369,7 +370,7 @@ func TestCost(t *testing.T) { vars: []*decls.VariableDecl{ decls.NewVariable("input", nestedMap), }, - hints: map[string]int64{"input": 5, "input.@values": 10}, + hints: map[string]uint64{"input": 5, "input.@values": 10}, wanted: CostEstimate{Min: 2, Max: 187}, }, { @@ -378,7 +379,7 @@ func TestCost(t *testing.T) { vars: []*decls.VariableDecl{ decls.NewVariable("input", nestedMap), }, - hints: map[string]int64{"input": 5, "input.@keys": 10}, + hints: map[string]uint64{"input": 5, "input.@keys": 10}, wanted: CostEstimate{Min: 2, Max: 32}, }, { @@ -387,7 +388,7 @@ func TestCost(t *testing.T) { vars: []*decls.VariableDecl{ decls.NewVariable("input", nestedMap), }, - hints: map[string]int64{"input": 2, "input.@values": 2, "input.@keys": 5}, + hints: map[string]uint64{"input": 2, "input.@values": 2, "input.@keys": 5}, wanted: CostEstimate{Min: 2, Max: 34}, }, { @@ -396,7 +397,7 @@ func TestCost(t *testing.T) { vars: []*decls.VariableDecl{ decls.NewVariable("input", nestedMap), }, - hints: map[string]int64{"input": 2, "input.@values": 2, "input.@keys": 5}, + hints: map[string]uint64{"input": 2, "input.@values": 2, "input.@keys": 5}, wanted: CostEstimate{Min: 2, Max: 34}, }, { @@ -406,7 +407,7 @@ func TestCost(t *testing.T) { decls.NewVariable("list1", types.NewListType(types.IntType)), decls.NewVariable("list2", types.NewListType(types.IntType)), }, - hints: map[string]int64{"list1": 10, "list2": 10}, + hints: map[string]uint64{"list1": 10, "list2": 10}, wanted: CostEstimate{Min: 4, Max: 64}, }, { @@ -416,9 +417,30 @@ func TestCost(t *testing.T) { decls.NewVariable("str1", types.StringType), decls.NewVariable("str2", types.StringType), }, - hints: map[string]int64{"str1": 10, "str2": 10}, + hints: map[string]uint64{"str1": 10, "str2": 10}, wanted: CostEstimate{Min: 2, Max: 6}, }, + { + name: "str concat custom cost estimate", + expr: `"abcdefg".contains(str1 + str2)`, + vars: []*decls.VariableDecl{ + decls.NewVariable("str1", types.StringType), + decls.NewVariable("str2", types.StringType), + }, + hints: map[string]uint64{"str1": 10, "str2": 10}, + options: []CostOption{ + OverloadCostEstimate(overloads.ContainsString, + func(estimator CostEstimator, target *AstNode, args []AstNode) *CallEstimate { + if target != nil && len(args) == 1 { + strSize := estimateSize(estimator, *target).MultiplyByCostFactor(0.2) + subSize := estimateSize(estimator, args[0]).MultiplyByCostFactor(0.2) + return &CallEstimate{CostEstimate: strSize.Multiply(subSize)} + } + return nil + }), + }, + wanted: CostEstimate{Min: 2, Max: 12}, + }, { name: "list size comparison", expr: `list1.size() == list2.size()`, @@ -485,7 +507,7 @@ func TestCost(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { if tc.hints == nil { - tc.hints = map[string]int64{} + tc.hints = map[string]uint64{} } p, err := parser.NewParser(parser.Macros(parser.AllMacros...)) if err != nil { @@ -530,12 +552,12 @@ func TestCost(t *testing.T) { } type testCostEstimator struct { - hints map[string]int64 + hints map[string]uint64 } func (tc testCostEstimator) EstimateSize(element AstNode) *SizeEstimate { if l, ok := tc.hints[strings.Join(element.Path(), ".")]; ok { - return &SizeEstimate{Min: 0, Max: uint64(l)} + return &SizeEstimate{Min: 0, Max: l} } return nil } @@ -547,3 +569,13 @@ func (tc testCostEstimator) EstimateCallCost(function, overloadID string, target } return nil } + +func estimateSize(estimator CostEstimator, node AstNode) SizeEstimate { + if l := node.ComputedSize(); l != nil { + return *l + } + if l := estimator.EstimateSize(node); l != nil { + return *l + } + return SizeEstimate{Min: 0, Max: math.MaxUint64} +} diff --git a/ext/BUILD.bazel b/ext/BUILD.bazel index ebdf7d01..6fdcc60c 100644 --- a/ext/BUILD.bazel +++ b/ext/BUILD.bazel @@ -20,6 +20,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//cel:go_default_library", + "//checker:go_default_library", "//checker/decls:go_default_library", "//common/overloads:go_default_library", "//common/types:go_default_library", diff --git a/ext/sets.go b/ext/sets.go index 4820d619..833c15f6 100644 --- a/ext/sets.go +++ b/ext/sets.go @@ -15,10 +15,14 @@ package ext import ( + "math" + "github.com/google/cel-go/cel" + "github.com/google/cel-go/checker" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/traits" + "github.com/google/cel-go/interpreter" ) // Sets returns a cel.EnvOption to configure namespaced set relationship @@ -95,12 +99,24 @@ func (setsLib) CompileOptions() []cel.EnvOption { cel.Function("sets.intersects", cel.Overload("list_sets_intersects_list", []*cel.Type{listType, listType}, cel.BoolType, cel.BinaryBinding(setsIntersects))), + cel.CostEstimatorOptions( + checker.OverloadCostEstimate("list_sets_contains_list", estimateSetsCost(1)), + checker.OverloadCostEstimate("list_sets_intersects_list", estimateSetsCost(1)), + // equivalence requires potentially two m*n comparisons to ensure each list is contained by the other + checker.OverloadCostEstimate("list_sets_equivalent_list", estimateSetsCost(2)), + ), } } // ProgramOptions implements the Library interface method. func (setsLib) ProgramOptions() []cel.ProgramOption { - return []cel.ProgramOption{} + return []cel.ProgramOption{ + cel.CostTrackerOptions( + interpreter.OverloadCostTracker("list_sets_contains_list", trackSetsCost(1)), + interpreter.OverloadCostTracker("list_sets_intersects_list", trackSetsCost(1)), + interpreter.OverloadCostTracker("list_sets_equivalent_list", trackSetsCost(2)), + ), + } } func setsIntersects(listA, listB ref.Val) ref.Val { @@ -136,3 +152,46 @@ func setsEquivalent(listA, listB ref.Val) ref.Val { } return setsContains(listB, listA) } + +func estimateSetsCost(costFactor float64) checker.FunctionEstimator { + return func(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { + if len(args) == 2 { + arg0Size := estimateSize(estimator, args[0]) + arg1Size := estimateSize(estimator, args[1]) + costEstimate := arg0Size.Multiply(arg1Size).MultiplyByCostFactor(costFactor).Add(callCostEstimate) + return &checker.CallEstimate{CostEstimate: costEstimate} + } + return nil + } +} + +func estimateSize(estimator checker.CostEstimator, node checker.AstNode) checker.SizeEstimate { + if l := node.ComputedSize(); l != nil { + return *l + } + if l := estimator.EstimateSize(node); l != nil { + return *l + } + return checker.SizeEstimate{Min: 0, Max: math.MaxUint64} +} + +func trackSetsCost(costFactor float64) interpreter.FunctionTracker { + return func(args []ref.Val, _ ref.Val) *uint64 { + lhsSize := actualSize(args[0]) + rhsSize := actualSize(args[1]) + cost := callCost + uint64(float64(lhsSize*rhsSize)*costFactor) + return &cost + } +} + +func actualSize(value ref.Val) uint64 { + if sz, ok := value.(traits.Sizer); ok { + return uint64(sz.Size().(types.Int)) + } + return 1 +} + +var ( + callCostEstimate = checker.CostEstimate{Min: 1, Max: 1} + callCost = uint64(1) +) diff --git a/ext/sets_test.go b/ext/sets_test.go index 5fbe3b9a..ddaf96d8 100644 --- a/ext/sets_test.go +++ b/ext/sets_test.go @@ -15,67 +15,267 @@ package ext import ( - "fmt" + "math" + "reflect" + "strings" "testing" "github.com/google/cel-go/cel" + "github.com/google/cel-go/checker" ) func TestSets(t *testing.T) { setsTests := []struct { - expr string + expr string + vars []cel.EnvOption + in map[string]any + hints map[string]uint64 + estimatedCost checker.CostEstimate + actualCost uint64 }{ // set containment - {expr: `sets.contains([], [])`}, - {expr: `sets.contains([1], [])`}, - {expr: `sets.contains([1], [1])`}, - {expr: `sets.contains([1], [1, 1])`}, - {expr: `sets.contains([1, 1], [1])`}, - {expr: `sets.contains([2, 1], [1])`}, - {expr: `sets.contains([1, 2, 3, 4], [2, 3])`}, - {expr: `sets.contains([1], [1.0, 1])`}, - {expr: `sets.contains([1, 2], [2u, 2.0])`}, - {expr: `sets.contains([1, 2u], [2, 2.0])`}, - {expr: `sets.contains([1, 2.0, 3u], [1.0, 2u, 3])`}, - {expr: `sets.contains([[1], [2, 3]], [[2, 3.0]])`}, - {expr: `!sets.contains([1], [2])`}, - {expr: `!sets.contains([1], [1, 2])`}, - {expr: `!sets.contains([1], ["1", 1])`}, - {expr: `!sets.contains([1], [1.1, 1u])`}, - // set equivalence - {expr: `sets.equivalent([], [])`}, - {expr: `sets.equivalent([1], [1])`}, - {expr: `sets.equivalent([1], [1, 1])`}, - {expr: `sets.equivalent([1, 1], [1])`}, - {expr: `sets.equivalent([1], [1u, 1.0])`}, - {expr: `sets.equivalent([1], [1u, 1.0])`}, - {expr: `sets.equivalent([1, 2, 3], [3u, 2.0, 1])`}, - {expr: `sets.equivalent([[1.0], [2, 3]], [[1], [2, 3.0]])`}, - {expr: `!sets.equivalent([2, 1], [1])`}, - {expr: `!sets.equivalent([1], [1, 2])`}, - {expr: `!sets.equivalent([1, 2], [2u, 2, 2.0])`}, - {expr: `!sets.equivalent([1, 2], [1u, 2, 2.3])`}, + { + expr: `sets.contains(x, [1, 2, 3])`, + vars: []cel.EnvOption{cel.Variable("x", cel.ListType(cel.IntType))}, + in: map[string]any{"x": []int64{5, 4, 3, 2, 1}}, + hints: map[string]uint64{"x": 10}, + // min cost is input 'x' length 0, 10 for list creation, 2 for arg costs + // max cost is input 'x' lenght 10, 10 for list creation, 2 for arg costs + estimatedCost: checker.CostEstimate{Min: 12, Max: 42}, + // actual cost is 'x' length 5 * list literal length 3, 10 for list creation, 2 for arg cost + actualCost: 27, + }, + { + expr: `sets.contains(x, [1, 1, 1, 1, 1])`, + vars: []cel.EnvOption{cel.Variable("x", cel.ListType(cel.IntType))}, + in: map[string]any{"x": []int64{5, 4, 3, 2, 1}}, + // min cost is input 'x' length 0, 10 for list creation, 2 for arg costs + // max cost is effectively infinite due to missing size hint for 'x' + estimatedCost: checker.CostEstimate{Min: 12, Max: math.MaxUint64}, + // actual cost is 'x' length 5 * list literal length 5, 10 for list creation, 2 for arg cost + actualCost: 37, + }, + { + expr: `sets.contains([], [])`, + estimatedCost: checker.CostEstimate{Min: 21, Max: 21}, + actualCost: 21, + }, + { + expr: `sets.contains([1], [])`, + estimatedCost: checker.CostEstimate{Min: 21, Max: 21}, + actualCost: 21, + }, + { + expr: `sets.contains([1], [1])`, + estimatedCost: checker.CostEstimate{Min: 22, Max: 22}, + actualCost: 22, + }, + { + expr: `sets.contains([1], [1, 1])`, + estimatedCost: checker.CostEstimate{Min: 23, Max: 23}, + actualCost: 23, + }, + { + expr: `sets.contains([1, 1], [1])`, + estimatedCost: checker.CostEstimate{Min: 23, Max: 23}, + actualCost: 23, + }, + { + expr: `sets.contains([2, 1], [1])`, + estimatedCost: checker.CostEstimate{Min: 23, Max: 23}, + actualCost: 23, + }, + { + expr: `sets.contains([1, 2, 3, 4], [2, 3])`, + estimatedCost: checker.CostEstimate{Min: 29, Max: 29}, + actualCost: 29, + }, + { + expr: `sets.contains([1], [1.0, 1])`, + estimatedCost: checker.CostEstimate{Min: 23, Max: 23}, + actualCost: 23, + }, + { + expr: `sets.contains([1, 2], [2u, 2.0])`, + estimatedCost: checker.CostEstimate{Min: 25, Max: 25}, + actualCost: 25, + }, + { + expr: `sets.contains([1, 2u], [2, 2.0])`, + estimatedCost: checker.CostEstimate{Min: 25, Max: 25}, + actualCost: 25, + }, + { + expr: `sets.contains([1, 2.0, 3u], [1.0, 2u, 3])`, + estimatedCost: checker.CostEstimate{Min: 30, Max: 30}, + actualCost: 30, + }, + { + expr: `sets.contains([[1], [2, 3]], [[2, 3.0]])`, + // 10 for each list creation, top-level list sizes are 2, 1 + estimatedCost: checker.CostEstimate{Min: 53, Max: 53}, + actualCost: 53, + }, + { + expr: `!sets.contains([1], [2])`, + estimatedCost: checker.CostEstimate{Min: 23, Max: 23}, + actualCost: 23, + }, + { + expr: `!sets.contains([1], [1, 2])`, + estimatedCost: checker.CostEstimate{Min: 24, Max: 24}, + actualCost: 24, + }, + { + expr: `!sets.contains([1], ["1", 1])`, + estimatedCost: checker.CostEstimate{Min: 24, Max: 24}, + actualCost: 24, + }, + { + expr: `!sets.contains([1], [1.1, 1u])`, + estimatedCost: checker.CostEstimate{Min: 24, Max: 24}, + actualCost: 24, + }, + + // set equivalence (note the cost factor is higher as it's basically two contains checks) + { + expr: `sets.equivalent([], [])`, + estimatedCost: checker.CostEstimate{Min: 21, Max: 21}, + actualCost: 21, + }, + { + expr: `sets.equivalent([1], [1])`, + estimatedCost: checker.CostEstimate{Min: 23, Max: 23}, + actualCost: 23, + }, + { + expr: `sets.equivalent([1], [1, 1])`, + estimatedCost: checker.CostEstimate{Min: 25, Max: 25}, + actualCost: 25, + }, + { + expr: `sets.equivalent([1, 1], [1])`, + estimatedCost: checker.CostEstimate{Min: 25, Max: 25}, + actualCost: 25, + }, + { + expr: `sets.equivalent([1], [1u, 1.0])`, + estimatedCost: checker.CostEstimate{Min: 25, Max: 25}, + actualCost: 25, + }, + { + expr: `sets.equivalent([1], [1u, 1.0])`, + estimatedCost: checker.CostEstimate{Min: 25, Max: 25}, + actualCost: 25, + }, + { + expr: `sets.equivalent([1, 2, 3], [3u, 2.0, 1])`, + estimatedCost: checker.CostEstimate{Min: 39, Max: 39}, + actualCost: 39, + }, + { + expr: `sets.equivalent([[1.0], [2, 3]], [[1], [2, 3.0]])`, + estimatedCost: checker.CostEstimate{Min: 69, Max: 69}, + actualCost: 69, + }, + { + expr: `!sets.equivalent([2, 1], [1])`, + estimatedCost: checker.CostEstimate{Min: 26, Max: 26}, + actualCost: 26, + }, + { + expr: `!sets.equivalent([1], [1, 2])`, + estimatedCost: checker.CostEstimate{Min: 26, Max: 26}, + actualCost: 26, + }, + { + expr: `!sets.equivalent([1, 2], [2u, 2, 2.0])`, + estimatedCost: checker.CostEstimate{Min: 34, Max: 34}, + actualCost: 34, + }, + { + expr: `!sets.equivalent([1, 2], [1u, 2, 2.3])`, + estimatedCost: checker.CostEstimate{Min: 34, Max: 34}, + actualCost: 34, + }, + // set intersection - {expr: `sets.intersects([1], [1])`}, - {expr: `sets.intersects([1], [1, 1])`}, - {expr: `sets.intersects([1, 1], [1])`}, - {expr: `sets.intersects([2, 1], [1])`}, - {expr: `sets.intersects([1], [1, 2])`}, - {expr: `sets.intersects([1], [1.0, 2])`}, - {expr: `sets.intersects([1, 2], [2u, 2, 2.0])`}, - {expr: `sets.intersects([1, 2], [1u, 2, 2.3])`}, - {expr: `sets.intersects([[1], [2, 3]], [[1, 2], [2, 3.0]])`}, - {expr: `!sets.intersects([], [])`}, - {expr: `!sets.intersects([1], [])`}, - {expr: `!sets.intersects([1], [2])`}, - {expr: `!sets.intersects([1], ["1", 2])`}, - {expr: `!sets.intersects([1], [1.1, 2u])`}, + { + expr: `sets.intersects([1], [1])`, + estimatedCost: checker.CostEstimate{Min: 22, Max: 22}, + actualCost: 22, + }, + { + expr: `sets.intersects([1], [1, 1])`, + estimatedCost: checker.CostEstimate{Min: 23, Max: 23}, + actualCost: 23, + }, + { + expr: `sets.intersects([1, 1], [1])`, + estimatedCost: checker.CostEstimate{Min: 23, Max: 23}, + actualCost: 23, + }, + { + expr: `sets.intersects([2, 1], [1])`, + estimatedCost: checker.CostEstimate{Min: 23, Max: 23}, + actualCost: 23, + }, + { + expr: `sets.intersects([1], [1, 2])`, + estimatedCost: checker.CostEstimate{Min: 23, Max: 23}, + actualCost: 23, + }, + { + expr: `sets.intersects([1], [1.0, 2])`, + estimatedCost: checker.CostEstimate{Min: 23, Max: 23}, + actualCost: 23, + }, + { + expr: `sets.intersects([1, 2], [2u, 2, 2.0])`, + estimatedCost: checker.CostEstimate{Min: 27, Max: 27}, + actualCost: 27, + }, + { + expr: `sets.intersects([1, 2], [1u, 2, 2.3])`, + estimatedCost: checker.CostEstimate{Min: 27, Max: 27}, + actualCost: 27, + }, + { + expr: `sets.intersects([[1], [2, 3]], [[1, 2], [2, 3.0]])`, + estimatedCost: checker.CostEstimate{Min: 65, Max: 65}, + actualCost: 65, + }, + { + expr: `!sets.intersects([], [])`, + estimatedCost: checker.CostEstimate{Min: 22, Max: 22}, + actualCost: 22, + }, + { + expr: `!sets.intersects([1], [])`, + estimatedCost: checker.CostEstimate{Min: 22, Max: 22}, + actualCost: 22, + }, + { + expr: `!sets.intersects([1], [2])`, + estimatedCost: checker.CostEstimate{Min: 23, Max: 23}, + actualCost: 23, + }, + { + expr: `!sets.intersects([1], ["1", 2])`, + estimatedCost: checker.CostEstimate{Min: 24, Max: 24}, + actualCost: 24, + }, + { + expr: `!sets.intersects([1], [1.1, 2u])`, + estimatedCost: checker.CostEstimate{Min: 24, Max: 24}, + actualCost: 24, + }, } - env := testSetsEnv(t) - for i, tst := range setsTests { + for _, tst := range setsTests { tc := tst - t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + t.Run(tc.expr, func(t *testing.T) { + env := testSetsEnv(t, tc.vars...) var asts []*cel.Ast pAst, iss := env.Parse(tc.expr) if iss.Err() != nil { @@ -86,20 +286,43 @@ func TestSets(t *testing.T) { if iss.Err() != nil { t.Fatalf("env.Check(%v) failed: %v", tc.expr, iss.Err()) } + + hints := map[string]uint64{} + if len(tc.hints) != 0 { + hints = tc.hints + } + est, err := env.EstimateCost(cAst, testSetsCostEstimator{hints: hints}) + if err != nil { + t.Fatalf("env.EstimateCost() failed: %v", err) + } + if !reflect.DeepEqual(est, tc.estimatedCost) { + t.Errorf("env.EstimateCost() got %v, wanted %v", est, tc.estimatedCost) + } asts = append(asts, cAst) for _, ast := range asts { - prg, err := env.Program(ast) + prgOpts := []cel.ProgramOption{} + if ast.IsChecked() { + prgOpts = append(prgOpts, cel.CostTracking(nil)) + } + prg, err := env.Program(ast, prgOpts...) if err != nil { t.Fatalf("env.Program() failed: %v", err) } - out, _, err := prg.Eval(cel.NoVars()) + in := tc.in + if in == nil { + in = map[string]any{} + } + out, det, err := prg.Eval(in) if err != nil { t.Fatalf("prg.Eval() failed: %v", err) } if out.Value() != true { t.Errorf("prg.Eval() got %v, wanted true for expr: %s", out.Value(), tc.expr) } + if det.ActualCost() != nil && *det.ActualCost() != tc.actualCost { + t.Errorf("prg.Eval() had cost %v, wanted %v", *det.ActualCost(), tc.actualCost) + } } }) } @@ -114,3 +337,18 @@ func testSetsEnv(t *testing.T, opts ...cel.EnvOption) *cel.Env { } return env } + +type testSetsCostEstimator struct { + hints map[string]uint64 +} + +func (tc testSetsCostEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate { + if l, ok := tc.hints[strings.Join(element.Path(), ".")]; ok { + return &checker.SizeEstimate{Min: 0, Max: l} + } + return nil +} + +func (testSetsCostEstimator) EstimateCallCost(function, overloadID string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { + return nil +} diff --git a/interpreter/runtimecost.go b/interpreter/runtimecost.go index 96faed2e..b9b307c1 100644 --- a/interpreter/runtimecost.go +++ b/interpreter/runtimecost.go @@ -133,6 +133,7 @@ func PresenceTestHasCost(hasCost bool) CostTrackerOption { func NewCostTracker(estimator ActualCostEstimator, opts ...CostTrackerOption) (*CostTracker, error) { tracker := &CostTracker{ Estimator: estimator, + overloadTrackers: map[string]FunctionTracker{}, presenceTestHasCost: true, } for _, opt := range opts { @@ -144,9 +145,24 @@ func NewCostTracker(estimator ActualCostEstimator, opts ...CostTrackerOption) (* return tracker, nil } +// OverloadCostTracker binds an overload ID to a runtime FunctionTracker implementation. +// +// OverloadCostTracker instances augment or override ActualCostEstimator decisions, allowing for versioned and/or +// optional cost tracking changes. +func OverloadCostTracker(overloadID string, fnTracker FunctionTracker) CostTrackerOption { + return func(tracker *CostTracker) error { + tracker.overloadTrackers[overloadID] = fnTracker + return nil + } +} + +// FunctionTracker computes the actual cost of evaluating the functions with the given arguments and result. +type FunctionTracker func(args []ref.Val, result ref.Val) *uint64 + // CostTracker represents the information needed for tracking runtime cost. type CostTracker struct { Estimator ActualCostEstimator + overloadTrackers map[string]FunctionTracker Limit *uint64 presenceTestHasCost bool @@ -159,10 +175,19 @@ func (c *CostTracker) ActualCost() uint64 { return c.cost } -func (c *CostTracker) costCall(call InterpretableCall, argValues []ref.Val, result ref.Val) uint64 { +func (c *CostTracker) costCall(call InterpretableCall, args []ref.Val, result ref.Val) uint64 { var cost uint64 + if len(c.overloadTrackers) != 0 { + if tracker, found := c.overloadTrackers[call.OverloadID()]; found { + callCost := tracker(args, result) + if callCost != nil { + cost += *callCost + return cost + } + } + } if c.Estimator != nil { - callCost := c.Estimator.CallCost(call.Function(), call.OverloadID(), argValues, result) + callCost := c.Estimator.CallCost(call.Function(), call.OverloadID(), args, result) if callCost != nil { cost += *callCost return cost @@ -173,11 +198,11 @@ func (c *CostTracker) costCall(call InterpretableCall, argValues []ref.Val, resu switch call.OverloadID() { // O(n) functions case overloads.StartsWithString, overloads.EndsWithString, overloads.StringToBytes, overloads.BytesToString, overloads.ExtQuoteString, overloads.ExtFormatString: - cost += uint64(math.Ceil(float64(c.actualSize(argValues[0])) * common.StringTraversalCostFactor)) + cost += uint64(math.Ceil(float64(c.actualSize(args[0])) * common.StringTraversalCostFactor)) case overloads.InList: // If a list is composed entirely of constant values this is O(1), but we don't account for that here. // We just assume all list containment checks are O(n). - cost += c.actualSize(argValues[1]) + cost += c.actualSize(args[1]) // O(min(m, n)) functions case overloads.LessString, overloads.GreaterString, overloads.LessEqualsString, overloads.GreaterEqualsString, overloads.LessBytes, overloads.GreaterBytes, overloads.LessEqualsBytes, overloads.GreaterEqualsBytes, @@ -185,8 +210,8 @@ func (c *CostTracker) costCall(call InterpretableCall, argValues []ref.Val, resu // When we check the equality of 2 scalar values (e.g. 2 integers, 2 floating-point numbers, 2 booleans etc.), // the CostTracker.actualSize() function by definition returns 1 for each operand, resulting in an overall cost // of 1. - lhsSize := c.actualSize(argValues[0]) - rhsSize := c.actualSize(argValues[1]) + lhsSize := c.actualSize(args[0]) + rhsSize := c.actualSize(args[1]) minSize := lhsSize if rhsSize < minSize { minSize = rhsSize @@ -195,23 +220,23 @@ func (c *CostTracker) costCall(call InterpretableCall, argValues []ref.Val, resu // O(m+n) functions case overloads.AddString, overloads.AddBytes: // In the worst case scenario, we would need to reallocate a new backing store and copy both operands over. - cost += uint64(math.Ceil(float64(c.actualSize(argValues[0])+c.actualSize(argValues[1])) * common.StringTraversalCostFactor)) + cost += uint64(math.Ceil(float64(c.actualSize(args[0])+c.actualSize(args[1])) * common.StringTraversalCostFactor)) // O(nm) functions case overloads.MatchesString: // https://swtch.com/~rsc/regexp/regexp1.html applies to RE2 implementation supported by CEL // Add one to string length for purposes of cost calculation to prevent product of string and regex to be 0 // in case where string is empty but regex is still expensive. - strCost := uint64(math.Ceil((1.0 + float64(c.actualSize(argValues[0]))) * common.StringTraversalCostFactor)) + strCost := uint64(math.Ceil((1.0 + float64(c.actualSize(args[0]))) * common.StringTraversalCostFactor)) // We don't know how many expressions are in the regex, just the string length (a huge // improvement here would be to somehow get a count the number of expressions in the regex or // how many states are in the regex state machine and use that to measure regex cost). // For now, we're making a guess that each expression in a regex is typically at least 4 chars // in length. - regexCost := uint64(math.Ceil(float64(c.actualSize(argValues[1])) * common.RegexStringLengthCostFactor)) + regexCost := uint64(math.Ceil(float64(c.actualSize(args[1])) * common.RegexStringLengthCostFactor)) cost += strCost * regexCost case overloads.ContainsString: - strCost := uint64(math.Ceil(float64(c.actualSize(argValues[0])) * common.StringTraversalCostFactor)) - substrCost := uint64(math.Ceil(float64(c.actualSize(argValues[1])) * common.StringTraversalCostFactor)) + strCost := uint64(math.Ceil(float64(c.actualSize(args[0])) * common.StringTraversalCostFactor)) + substrCost := uint64(math.Ceil(float64(c.actualSize(args[1])) * common.StringTraversalCostFactor)) cost += strCost * substrCost default: diff --git a/interpreter/runtimecost_test.go b/interpreter/runtimecost_test.go index 9a700a4b..1c6ac124 100644 --- a/interpreter/runtimecost_test.go +++ b/interpreter/runtimecost_test.go @@ -16,6 +16,7 @@ package interpreter import ( "fmt" + "math" "math/rand" "reflect" "strings" @@ -29,6 +30,7 @@ import ( "github.com/google/cel-go/common/overloads" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/common/types/traits" "github.com/google/cel-go/parser" proto3pb "github.com/google/cel-go/test/proto3pb" @@ -727,6 +729,25 @@ func TestRuntimeCost(t *testing.T) { want: 6, in: map[string]any{"str1": "val1", "str2": "val2222222"}, }, + { + name: "str concat custom cost tracker", + expr: `"abcdefg".contains(str1 + str2)`, + vars: []*decls.VariableDecl{ + decls.NewVariable("str1", types.StringType), + decls.NewVariable("str2", types.StringType), + }, + options: []CostTrackerOption{ + OverloadCostTracker(overloads.ContainsString, + func(args []ref.Val, result ref.Val) *uint64 { + strCost := uint64(math.Ceil(float64(actualSize(args[0])) * 0.2)) + substrCost := uint64(math.Ceil(float64(actualSize(args[1])) * 0.2)) + cost := strCost * substrCost + return &cost + }), + }, + want: 10, + in: map[string]any{"str1": "val1", "str2": "val2222222"}, + }, { name: "at limit", expr: `"abcdefg".contains(str1 + str2)`, @@ -803,3 +824,10 @@ func TestRuntimeCost(t *testing.T) { }) } } + +func actualSize(val ref.Val) uint64 { + if sz, ok := val.(traits.Sizer); ok { + return uint64(sz.Size().(types.Int)) + } + return 1 +}