Skip to content

Commit

Permalink
Add runtime cost limit
Browse files Browse the repository at this point in the history
  • Loading branch information
jpbetz committed Mar 8, 2022
1 parent 91cdb04 commit a22969d
Show file tree
Hide file tree
Showing 15 changed files with 463 additions and 296 deletions.
1 change: 0 additions & 1 deletion cel/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ go_test(
srcs = [
"cel_test.go",
"io_test.go",
"runtimecost_test.go",
],
data = [
"//cel/testdata:gen_test_fds",
Expand Down
68 changes: 41 additions & 27 deletions cel/cel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1426,6 +1426,7 @@ func TestCustomInterpreterDecorator(t *testing.T) {
}
}

// TODO: ideally testCostEstimator and testRuntimeCostEstimator would be shared in a test fixtures package
type testCostEstimator struct {
hints map[string]int64
}
Expand All @@ -1445,6 +1446,34 @@ func (tc testCostEstimator) EstimateCallCost(overloadId string, target *checker.
return nil
}

type testRuntimeCostEstimator struct {
}

var timeToYearCost uint64 = 7

func (e testRuntimeCostEstimator) CallCost(overloadId string, args []ref.Val) *uint64 {
argsSize := make([]uint64, len(args))
for i, arg := range args {
reflectV := reflect.ValueOf(arg.Value())
switch reflectV.Kind() {
// Note that the CEL bytes type is implemented with Go byte slices, therefore also supported by the following
// code.
case reflect.String, reflect.Array, reflect.Slice, reflect.Map:
argsSize[i] = uint64(reflectV.Len())
default:
argsSize[i] = 1
}
}

switch overloadId {
case overloads.TimestampToYear:
return &timeToYearCost
default:
return nil
}
}

// TestEstimateCostAndRuntimeCost sanity checks that the cost systems are usable from the program API.
func TestEstimateCostAndRuntimeCost(t *testing.T) {
intList := decls.NewListType(decls.Int)
zeroCost := checker.CostEstimate{}
Expand All @@ -1460,6 +1489,7 @@ func TestEstimateCostAndRuntimeCost(t *testing.T) {
name: "const",
expr: `"Hello World!"`,
want: zeroCost,
in: map[string]interface{}{},
},
{
name: "identity",
Expand All @@ -1479,21 +1509,6 @@ func TestEstimateCostAndRuntimeCost(t *testing.T) {
want: checker.CostEstimate{Min: 2, Max: 6},
in: map[string]interface{}{"str1": "val1111111", "str2": "val2222222"},
},
{
name: "ternary with var",
expr: `true > false ? [1, 2, 3].all(x, true) : false`,
want: checker.CostEstimate{Min: 1, Max: 21},
},
{
name: "short circuited, 0 cost lhs",
expr: `true || 4 > 3 || 3 > 2 || 2 > 1 || 1 > 0`,
want: checker.CostEstimate{Min: 0, Max: 4},
},
{
name: "short circuited, 1 cost lhs",
expr: `3 > 4 || 4 > 3 || 3 > 2 || 2 > 1 || 1 > 0`,
want: checker.CostEstimate{Min: 1, Max: 5},
},
}

for _, tc := range cases {
Expand All @@ -1505,42 +1520,41 @@ func TestEstimateCostAndRuntimeCost(t *testing.T) {
Declarations(tc.decls...),
Types(&proto3pb.TestAllTypes{}))
if err != nil {
t.Fatalf("environment creation error: %s\n", err)
t.Fatalf("NewEnv(opts ...EnvOption) failed to create an environment: %s\n", err)
}
ast, iss := e.Compile(tc.expr)
if iss.Err() != nil {
t.Fatal(iss.Err())
}
est, err := e.EstimateCost(ast, testCostEstimator{hints: tc.hints})
if err != nil {
t.Fatalf("estimate cost error: %s\n", err)
t.Fatalf("Env.EstimateCost(ast *Ast, estimator checker.CostEstimator) failed to estimate cost: %s\n", err)
}
if est.Min != tc.want.Min || est.Max != tc.want.Max {
t.Fatalf("Got cost interval [%v, %v], wanted [%v, %v]",
t.Fatalf("Env.EstimateCost(ast *Ast, estimator checker.CostEstimator) failed to return the right cost interval. Got [%v, %v], wanted [%v, %v]",
est.Min, est.Max, tc.want.Min, tc.want.Max)
}

ctx := constructActivation(t, tc.in)
checked_ast, iss := e.Check(ast)
checkedAst, iss := e.Check(ast)
if iss.Err() != nil {
t.Fatalf(`Failed to check expression with error: %v`, iss.Err())
t.Fatalf(`Env.Check(ast *Ast) failed to check expression: %v`, iss.Err())
}
// Evaluate expression.
program, err := e.Program(checked_ast, ActualCostTracking(testRuntimeCostEstimator{}))
program, err := e.Program(checkedAst, CostTracking(testRuntimeCostEstimator{}))
if err != nil {
t.Fatalf(`Failed to construct Program with error: %v`, err)
t.Fatalf(`Env.Program(ast *Ast, opts ...ProgramOption) failed to construct program: %v`, err)
}
_, details, err := program.Eval(ctx)
_, details, err := program.Eval(tc.in)
if err != nil {
t.Fatalf(`Failed to evaluate expression with error: %v`, err)
t.Fatalf(`Program.Eval(vars interface{}) failed to evaluate expression: %v`, err)
}
actualCost := details.ActualCost()
if actualCost == nil {
t.Fatalf(`Null pointer returned for the cost of expression "%s"`, tc.expr)
t.Errorf(`EvalDetails.ActualCost() got nil for "%s" cost, wanted %d`, tc.expr, actualCost)
}

if est.Min > *actualCost || est.Max < *actualCost {
t.Fatalf("runtime cost %d is out of the range of estimate cost [%d, %d]", *actualCost,
t.Errorf("EvalDetails.ActualCost() failed to return a runtime cost %d is the range of estimate cost [%d, %d]", *actualCost,
est.Min, est.Max)
}
})
Expand Down
17 changes: 14 additions & 3 deletions cel/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -400,16 +400,27 @@ func InterruptCheckFrequency(checkFrequency uint) ProgramOption {
}
}

// ActualCostTracking enables cost tracking and registers a ActualCostEstimator that can optionally provide a runtime cost estimate for any function calls.
// This enables runtime costs to be assigned to calls function library extensions.
func ActualCostTracking(costEstimator interpreter.ActualCostEstimator) ProgramOption {
// CostTracking enables cost tracking and registers a ActualCostEstimator that can optionally provide a runtime cost estimate for any function calls.
func CostTracking(costEstimator interpreter.ActualCostEstimator) ProgramOption {
return func(p *prog) (*prog, error) {
p.callCostEstimator = costEstimator
p.evalOpts |= OptTrackCost
return p, nil
}
}

// CostLimit enables cost tracking and sets configures program evaluation to exit early with a
// "runtime cost limit exceeded" error if the runtime cost exceeds the costLimit.
// The CostLimit is a metric that corresponds to the number and estimated expense of operations
// performed while evaluating an expression. It is indicative of CPU usage, not memory usage.
func CostLimit(costLimit uint64) ProgramOption {
return func(p *prog) (*prog, error) {
p.costLimit = &costLimit
p.evalOpts |= OptTrackCost
return p, nil
}
}

func fieldToCELType(field protoreflect.FieldDescriptor) (*exprpb.Type, error) {
if field.Kind() == protoreflect.MessageKind {
msgName := (string)(field.Message().FullName())
Expand Down
13 changes: 10 additions & 3 deletions cel/program.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ type prog struct {
// Interpretable configured from an Ast and aggregate decorator set based on program options.
interpretable interpreter.Interpretable
callCostEstimator interpreter.ActualCostEstimator
costLimit *uint64
}

func (p *prog) clone() *prog {
Expand Down Expand Up @@ -196,7 +197,8 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
// Enable exhaustive eval, state tracking and cost tracking last since they require a factory.
if p.evalOpts&(OptExhaustiveEval|OptTrackState|OptTrackCost) != 0 {
factory := func(state interpreter.EvalState, costTracker *interpreter.CostTracker) (Program, error) {
costTracker.CallCostEstimator = p.callCostEstimator
costTracker.Estimator = p.callCostEstimator
costTracker.Limit = p.costLimit
decs := decorators
var observers []interpreter.EvalObserver

Expand All @@ -210,7 +212,7 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {

// Enable exhaustive eval over a basic observer since it offers a superset of features.
if p.evalOpts&OptExhaustiveEval == OptExhaustiveEval {
decs = append(decs, interpreter.ExhaustiveEvalWrapper(interpreter.Observe(observers...)))
decs = append(decs, interpreter.ExhaustiveEval(), interpreter.Observe(observers...))
} else if len(observers) > 0 {
decs = append(decs, interpreter.Observe(observers...))
}
Expand Down Expand Up @@ -255,7 +257,12 @@ func (p *prog) Eval(input interface{}) (v ref.Val, det *EvalDetails, err error)
// function.
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("internal error: %v", r)
switch t := r.(type) {
case interpreter.EvalCancelledError:
err = t
default:
err = fmt.Errorf("internal error: %v", r)
}
}
}()
// Build a hierarchical activation if there are default vars set.
Expand Down
67 changes: 47 additions & 20 deletions checker/cost.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@ package checker
import (
"math"

"github.com/google/cel-go/common"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/parser"

exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)

// WARNING: Any changes to cost calculations in this file require a corresponding change in interpreter/runtimecost.go

// CostEstimator estimates the sizes of variable length input data and the costs of functions.
type CostEstimator interface {
// EstimateSize returns a SizeEstimate for the given AstNode, or nil if
Expand Down Expand Up @@ -56,7 +59,8 @@ type AstNode interface {
Expr() *exprpb.Expr
// ComputedSize returns a size estimate of the AstNode derived from information available in the CEL expression.
// For constants and inline list and map declarations, the exact size is returned. For concatenated list, strings
// and bytes, the size is derived from the size estimates of the operands.
// and bytes, the size is derived from the size estimates of the operands. nil is returned if there is no
// computed size available.
ComputedSize() *SizeEstimate
}

Expand Down Expand Up @@ -91,6 +95,10 @@ func (e astNode) ComputedSize() *SizeEstimate {
v = uint64(len(ck.StringValue))
case *exprpb.Constant_BytesValue:
v = uint64(len(ck.BytesValue))
case *exprpb.Constant_BoolValue, *exprpb.Constant_DoubleValue, *exprpb.Constant_DurationValue,
*exprpb.Constant_Int64Value, *exprpb.Constant_TimestampValue, *exprpb.Constant_Uint64Value,
*exprpb.Constant_NullValue:
v = uint64(1)
default:
return nil
}
Expand Down Expand Up @@ -233,13 +241,12 @@ func multiplyByCostFactor(x uint64, y float64) uint64 {
}

var (
identCost = CostEstimate{Min: 1, Max: 1}
selectCost = CostEstimate{Min: 1, Max: 1}
constCost = CostEstimate{Min: 0, Max: 0}
selectAndIdentCost = CostEstimate{Min: common.SelectAndIdentCost, Max: common.SelectAndIdentCost}
constCost = CostEstimate{Min: common.ConstCost, Max: common.ConstCost}

createListBaseCost = CostEstimate{Min: 10, Max: 10}
createMapBaseCost = CostEstimate{Min: 30, Max: 30}
createMessageBaseCost = CostEstimate{Min: 40, Max: 40}
createListBaseCost = CostEstimate{Min: common.ListCreateBaseCost, Max: common.ListCreateBaseCost}
createMapBaseCost = CostEstimate{Min: common.MapCreateBaseCost, Max: common.MapCreateBaseCost}
createMessageBaseCost = CostEstimate{Min: common.StructCreateBaseCost, Max: common.StructCreateBaseCost}
)

type coster struct {
Expand Down Expand Up @@ -326,7 +333,7 @@ func (c *coster) costIdent(e *exprpb.Expr) CostEstimate {
c.addPath(e, []string{identExpr.GetName()})
}

return identCost
return selectAndIdentCost
}

func (c *coster) costSelect(e *exprpb.Expr) CostEstimate {
Expand All @@ -339,7 +346,7 @@ func (c *coster) costSelect(e *exprpb.Expr) CostEstimate {
targetType := c.getType(sel.GetOperand())
switch kindOf(targetType) {
case kindMap, kindObject, kindTypeParam:
sum = sum.Add(selectCost)
sum = sum.Add(selectAndIdentCost)
}

// build and track the field path
Expand Down Expand Up @@ -456,8 +463,6 @@ func (c *coster) costComprehension(e *exprpb.Expr) CostEstimate {
stepCost := c.cost(comp.GetLoopStep())
c.iterRanges.pop(comp.GetIterVar())
sum = sum.Add(c.cost(comp.Result))
// TODO: comprehensions short circuit, so even if the min list size > 0, the minimum number of elements evaluated
// will be 1.
rangeCnt := c.sizeEstimate(c.newAstNode(comp.GetIterRange()))
rangeCost := rangeCnt.MultiplyByCost(stepCost.Add(loopCost))
sum = sum.Add(rangeCost)
Expand All @@ -477,9 +482,6 @@ func (c *coster) sizeEstimate(t AstNode) SizeEstimate {

func (c *coster) functionCost(overloadId string, target *AstNode, args []AstNode, argCosts []CostEstimate) CallEstimate {
argCostSum := func() CostEstimate {
// TODO: handle ternary
// TODO: && || operators short circuit, so min cost should only include 1st arg eval
// unless exhaustive evaluation is enabled
var sum CostEstimate
for _, a := range argCosts {
sum = sum.Add(a)
Expand All @@ -495,7 +497,7 @@ func (c *coster) functionCost(overloadId string, target *AstNode, args []AstNode
// O(n) functions
case overloads.StartsWithString, overloads.EndsWithString, overloads.StringToBytes, overloads.BytesToString:
if len(args) == 1 {
return CallEstimate{CostEstimate: c.sizeEstimate(args[0]).MultiplyByCostFactor(0.1).Add(argCostSum())}
return CallEstimate{CostEstimate: c.sizeEstimate(args[0]).MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum())}
}
case overloads.InList:
// If a list is composed entirely of constant values this is O(1), but we don't account for that here.
Expand All @@ -507,19 +509,21 @@ func (c *coster) functionCost(overloadId string, target *AstNode, args []AstNode
case overloads.MatchesString:
// https://swtch.com/~rsc/regexp/regexp1.html applies to RE2 implementation supported by CEL
if target != nil && len(args) == 1 {
strCost := c.sizeEstimate(*target).MultiplyByCostFactor(0.1)
// Add one to string length for purposes of cost calculation to prevent product of string and regex to be 0
// in case where string is empty but regex is still expensive.
strCost := c.sizeEstimate(*target).Add(SizeEstimate{Min: 1, Max: 1}).MultiplyByCostFactor(common.StringTraversalCostFactor)
// We don't know how many expressions are in the regex, just the string length (a huge
// improvement here would be to somehow get a count the number of expressions in the regex or
// how many states are in the regex state machine and use that to measure regex cost).
// For now, we're making a guess that each expression in a regex is typically at least 4 chars
// in length.
regexCost := c.sizeEstimate(args[0]).MultiplyByCostFactor(0.25)
regexCost := c.sizeEstimate(args[0]).MultiplyByCostFactor(common.RegexStringLengthCostFactor)
return CallEstimate{CostEstimate: strCost.Multiply(regexCost).Add(argCostSum())}
}
case overloads.ContainsString:
if target != nil && len(args) == 1 {
strCost := c.sizeEstimate(*target).MultiplyByCostFactor(0.1)
substrCost := c.sizeEstimate(args[0]).MultiplyByCostFactor(0.1)
strCost := c.sizeEstimate(*target).MultiplyByCostFactor(common.StringTraversalCostFactor)
substrCost := c.sizeEstimate(args[0]).MultiplyByCostFactor(common.StringTraversalCostFactor)
return CallEstimate{CostEstimate: strCost.Multiply(substrCost).Add(argCostSum())}
}
case overloads.LogicalOr, overloads.LogicalAnd:
Expand All @@ -540,10 +544,33 @@ func (c *coster) functionCost(overloadId string, target *AstNode, args []AstNode
lhsSize := c.sizeEstimate(args[0])
rhsSize := c.sizeEstimate(args[1])
resultSize := lhsSize.Add(rhsSize)
return CallEstimate{CostEstimate: resultSize.MultiplyByCostFactor(0.1).Add(argCostSum()), ResultSize: &resultSize}
switch overloadId {
case overloads.AddList:
// list concatenation is O(1), but we handle it here to track size
return CallEstimate{CostEstimate: CostEstimate{Min: 1, Max: 1}.Add(argCostSum()), ResultSize: &resultSize}
default:
return CallEstimate{CostEstimate: resultSize.MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum()), ResultSize: &resultSize}
}
}
case overloads.LessString, overloads.GreaterString, overloads.LessEqualsString, overloads.GreaterEqualsString,
overloads.LessBytes, overloads.GreaterBytes, overloads.LessEqualsBytes, overloads.GreaterEqualsBytes,
overloads.Equals, overloads.NotEquals:
lhsCost := c.sizeEstimate(args[0])
rhsCost := c.sizeEstimate(args[1])
min := uint64(0)
smallestMax := lhsCost.Max
if rhsCost.Max < smallestMax {
smallestMax = rhsCost.Max
}
if smallestMax > 0 {
min = 1
}
// equality of 2 scalar values results in a cost of 1
return CallEstimate{CostEstimate: CostEstimate{Min: min, Max: smallestMax}.MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum())}
}
// O(1) functions
// See CostTracker.costCall for more details about O(1) cost calculations

// Benchmarks suggest that most of the other operations take +/- 50% of a base cost unit
// which on an Intel xeon 2.20GHz CPU is 50ns.
return CallEstimate{CostEstimate: CostEstimate{Min: 1, Max: 1}.Add(argCostSum())}
Expand Down
6 changes: 3 additions & 3 deletions checker/cost_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func TestCost(t *testing.T) {
decls: []*exprpb.Decl{decls.NewVar("input", decls.String)},
hints: map[string]int64{"input": 500},
expr: `input.matches('[0-9]')`,
wanted: CostEstimate{Min: 1, Max: 101},
wanted: CostEstimate{Min: 3, Max: 103},
},
{
name: "variable cost function with constant",
Expand Down Expand Up @@ -240,7 +240,7 @@ func TestCost(t *testing.T) {
decls.NewVar("input", decls.String),
},
hints: map[string]int64{"input": 500},
wanted: CostEstimate{Min: 1, Max: 101},
wanted: CostEstimate{Min: 3, Max: 103},
},
{
name: "startsWith",
Expand Down Expand Up @@ -342,7 +342,7 @@ func TestCost(t *testing.T) {
decls.NewVar("list2", decls.NewListType(decls.Int)),
},
hints: map[string]int64{"list1": 10, "list2": 10},
wanted: CostEstimate{Min: 3, Max: 65},
wanted: CostEstimate{Min: 4, Max: 64},
},
{
name: "str concat",
Expand Down
Loading

0 comments on commit a22969d

Please sign in to comment.