Skip to content

Commit

Permalink
Add function name and result to callcost functions to ease estimation…
Browse files Browse the repository at this point in the history
… calculations (#506)
  • Loading branch information
jpbetz authored Mar 9, 2022
1 parent ca0dc96 commit b9108d5
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 13 deletions.
4 changes: 2 additions & 2 deletions cel/cel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1438,7 +1438,7 @@ func (tc testCostEstimator) EstimateSize(element checker.AstNode) *checker.SizeE
return nil
}

func (tc testCostEstimator) EstimateCallCost(overloadID string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
func (tc testCostEstimator) EstimateCallCost(function, overloadID string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
switch overloadID {
case overloads.TimestampToYear:
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 7, Max: 7}}
Expand All @@ -1451,7 +1451,7 @@ type testRuntimeCostEstimator struct {

var timeToYearCost uint64 = 7

func (e testRuntimeCostEstimator) CallCost(overloadID string, args []ref.Val) *uint64 {
func (e testRuntimeCostEstimator) CallCost(function, overloadID string, args []ref.Val, result ref.Val) *uint64 {
argsSize := make([]uint64, len(args))
for i, arg := range args {
reflectV := reflect.ValueOf(arg.Value())
Expand Down
8 changes: 4 additions & 4 deletions checker/cost.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ type CostEstimator interface {
EstimateSize(element AstNode) *SizeEstimate
// EstimateCallCost returns the estimated cost of an invocation, or nil if
// the estimator has no estimate to provide.
EstimateCallCost(overloadID string, target *AstNode, args []AstNode) *CallEstimate
EstimateCallCost(function, overloadID string, target *AstNode, args []AstNode) *CallEstimate
}

// CallEstimate includes a CostEstimate for the call, and an optional estimate of the result object size.
Expand Down Expand Up @@ -384,7 +384,7 @@ func (c *coster) costCall(e *exprpb.Expr) CostEstimate {
fnCost := CostEstimate{Min: uint64(math.MaxUint64), Max: 0}
var resultSize *SizeEstimate
for _, overload := range ref.GetOverloadId() {
overloadCost := c.functionCost(overload, &targetType, argTypes, argCosts)
overloadCost := c.functionCost(call.GetFunction(), overload, &targetType, argTypes, argCosts)
fnCost = fnCost.Union(overloadCost.CostEstimate)
if overloadCost.ResultSize != nil {
if resultSize == nil {
Expand Down Expand Up @@ -479,7 +479,7 @@ func (c *coster) sizeEstimate(t AstNode) SizeEstimate {
return SizeEstimate{Min: 0, Max: math.MaxUint64}
}

func (c *coster) functionCost(overloadID string, target *AstNode, args []AstNode, argCosts []CostEstimate) CallEstimate {
func (c *coster) functionCost(function, overloadID string, target *AstNode, args []AstNode, argCosts []CostEstimate) CallEstimate {
argCostSum := func() CostEstimate {
var sum CostEstimate
for _, a := range argCosts {
Expand All @@ -488,7 +488,7 @@ func (c *coster) functionCost(overloadID string, target *AstNode, args []AstNode
return sum
}

if est := c.estimator.EstimateCallCost(overloadID, target, args); est != nil {
if est := c.estimator.EstimateCallCost(function, overloadID, target, args); est != nil {
callEst := *est
return CallEstimate{CostEstimate: callEst.Add(argCostSum())}
}
Expand Down
2 changes: 1 addition & 1 deletion checker/cost_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ func (tc testCostEstimator) EstimateSize(element AstNode) *SizeEstimate {
return nil
}

func (tc testCostEstimator) EstimateCallCost(overloadID string, target *AstNode, args []AstNode) *CallEstimate {
func (tc testCostEstimator) EstimateCallCost(function, overloadID string, target *AstNode, args []AstNode) *CallEstimate {
switch overloadID {
case overloads.TimestampToYear:
return &CallEstimate{CostEstimate: CostEstimate{Min: 7, Max: 7}}
Expand Down
8 changes: 4 additions & 4 deletions interpreter/runtimecost.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import (
// estimate to provide. CEL attempts to provide reasonable estimates for its standard function library, so CallCost
// should typically not need to provide an estimate for CELs standard function.
type ActualCostEstimator interface {
CallCost(overloadID string, args []ref.Val) *uint64
CallCost(function, overloadID string, args []ref.Val, result ref.Val) *uint64
}

// CostObserver provides an observer that tracks runtime cost.
Expand All @@ -58,7 +58,7 @@ func CostObserver(tracker *CostTracker) EvalObserver {
tracker.cost++
case InterpretableCall:
if argVals, ok := tracker.stack.pop(len(t.Args())); ok {
tracker.cost += tracker.costCall(t, argVals)
tracker.cost += tracker.costCall(t, argVals, val)
}
case InterpretableConstructor:
switch t.Type() {
Expand Down Expand Up @@ -93,10 +93,10 @@ func (c CostTracker) ActualCost() uint64 {
return c.cost
}

func (c CostTracker) costCall(call InterpretableCall, argValues []ref.Val) uint64 {
func (c CostTracker) costCall(call InterpretableCall, argValues []ref.Val, result ref.Val) uint64 {
var cost uint64
if c.Estimator != nil {
callCost := c.Estimator.CallCost(call.OverloadID(), argValues)
callCost := c.Estimator.CallCost(call.Function(), call.OverloadID(), argValues, result)
if callCost != nil {
cost += *callCost
return cost
Expand Down
4 changes: 2 additions & 2 deletions interpreter/runtimecost_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ type testRuntimeCostEstimator struct {

var timeToYearCost uint64 = 7

func (e testRuntimeCostEstimator) CallCost(overloadID string, args []ref.Val) *uint64 {
func (e testRuntimeCostEstimator) CallCost(function, overloadID string, args []ref.Val, result ref.Val) *uint64 {
argsSize := make([]uint64, len(args))
for i, arg := range args {
reflectV := reflect.ValueOf(arg.Value())
Expand Down Expand Up @@ -216,7 +216,7 @@ func (tc testCostEstimator) EstimateSize(element checker.AstNode) *checker.SizeE
return nil
}

func (tc testCostEstimator) EstimateCallCost(overloadID string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
func (tc testCostEstimator) EstimateCallCost(function, overloadID string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
switch overloadID {
case overloads.TimestampToYear:
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 7, Max: 7}}
Expand Down

0 comments on commit b9108d5

Please sign in to comment.