diff --git a/cel/cel_test.go b/cel/cel_test.go index 9259d2a6..b05ec1b0 100644 --- a/cel/cel_test.go +++ b/cel/cel_test.go @@ -1438,7 +1438,7 @@ func (tc testCostEstimator) EstimateSize(element checker.AstNode) *checker.SizeE return nil } -func (tc testCostEstimator) EstimateCallCost(overloadID string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { +func (tc testCostEstimator) EstimateCallCost(function, overloadID string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { switch overloadID { case overloads.TimestampToYear: return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 7, Max: 7}} @@ -1451,7 +1451,7 @@ type testRuntimeCostEstimator struct { var timeToYearCost uint64 = 7 -func (e testRuntimeCostEstimator) CallCost(overloadID string, args []ref.Val) *uint64 { +func (e testRuntimeCostEstimator) CallCost(function, overloadID string, args []ref.Val, result ref.Val) *uint64 { argsSize := make([]uint64, len(args)) for i, arg := range args { reflectV := reflect.ValueOf(arg.Value()) diff --git a/checker/cost.go b/checker/cost.go index b1e83020..4e1ea5aa 100644 --- a/checker/cost.go +++ b/checker/cost.go @@ -37,7 +37,7 @@ type CostEstimator interface { EstimateSize(element AstNode) *SizeEstimate // EstimateCallCost returns the estimated cost of an invocation, or nil if // the estimator has no estimate to provide. - EstimateCallCost(overloadID string, target *AstNode, args []AstNode) *CallEstimate + EstimateCallCost(function, overloadID string, target *AstNode, args []AstNode) *CallEstimate } // CallEstimate includes a CostEstimate for the call, and an optional estimate of the result object size. @@ -384,7 +384,7 @@ func (c *coster) costCall(e *exprpb.Expr) CostEstimate { fnCost := CostEstimate{Min: uint64(math.MaxUint64), Max: 0} var resultSize *SizeEstimate for _, overload := range ref.GetOverloadId() { - overloadCost := c.functionCost(overload, &targetType, argTypes, argCosts) + overloadCost := c.functionCost(call.GetFunction(), overload, &targetType, argTypes, argCosts) fnCost = fnCost.Union(overloadCost.CostEstimate) if overloadCost.ResultSize != nil { if resultSize == nil { @@ -479,7 +479,7 @@ func (c *coster) sizeEstimate(t AstNode) SizeEstimate { return SizeEstimate{Min: 0, Max: math.MaxUint64} } -func (c *coster) functionCost(overloadID string, target *AstNode, args []AstNode, argCosts []CostEstimate) CallEstimate { +func (c *coster) functionCost(function, overloadID string, target *AstNode, args []AstNode, argCosts []CostEstimate) CallEstimate { argCostSum := func() CostEstimate { var sum CostEstimate for _, a := range argCosts { @@ -488,7 +488,7 @@ func (c *coster) functionCost(overloadID string, target *AstNode, args []AstNode return sum } - if est := c.estimator.EstimateCallCost(overloadID, target, args); est != nil { + if est := c.estimator.EstimateCallCost(function, overloadID, target, args); est != nil { callEst := *est return CallEstimate{CostEstimate: callEst.Add(argCostSum())} } diff --git a/checker/cost_test.go b/checker/cost_test.go index 08f8f5ff..e2118f87 100644 --- a/checker/cost_test.go +++ b/checker/cost_test.go @@ -411,7 +411,7 @@ func (tc testCostEstimator) EstimateSize(element AstNode) *SizeEstimate { return nil } -func (tc testCostEstimator) EstimateCallCost(overloadID string, target *AstNode, args []AstNode) *CallEstimate { +func (tc testCostEstimator) EstimateCallCost(function, overloadID string, target *AstNode, args []AstNode) *CallEstimate { switch overloadID { case overloads.TimestampToYear: return &CallEstimate{CostEstimate: CostEstimate{Min: 7, Max: 7}} diff --git a/interpreter/runtimecost.go b/interpreter/runtimecost.go index acbe5696..d8233804 100644 --- a/interpreter/runtimecost.go +++ b/interpreter/runtimecost.go @@ -31,7 +31,7 @@ import ( // estimate to provide. CEL attempts to provide reasonable estimates for its standard function library, so CallCost // should typically not need to provide an estimate for CELs standard function. type ActualCostEstimator interface { - CallCost(overloadID string, args []ref.Val) *uint64 + CallCost(function, overloadID string, args []ref.Val, result ref.Val) *uint64 } // CostObserver provides an observer that tracks runtime cost. @@ -58,7 +58,7 @@ func CostObserver(tracker *CostTracker) EvalObserver { tracker.cost++ case InterpretableCall: if argVals, ok := tracker.stack.pop(len(t.Args())); ok { - tracker.cost += tracker.costCall(t, argVals) + tracker.cost += tracker.costCall(t, argVals, val) } case InterpretableConstructor: switch t.Type() { @@ -93,10 +93,10 @@ func (c CostTracker) ActualCost() uint64 { return c.cost } -func (c CostTracker) costCall(call InterpretableCall, argValues []ref.Val) uint64 { +func (c CostTracker) costCall(call InterpretableCall, argValues []ref.Val, result ref.Val) uint64 { var cost uint64 if c.Estimator != nil { - callCost := c.Estimator.CallCost(call.OverloadID(), argValues) + callCost := c.Estimator.CallCost(call.Function(), call.OverloadID(), argValues, result) if callCost != nil { cost += *callCost return cost diff --git a/interpreter/runtimecost_test.go b/interpreter/runtimecost_test.go index 2e49b9fd..633057d5 100644 --- a/interpreter/runtimecost_test.go +++ b/interpreter/runtimecost_test.go @@ -183,7 +183,7 @@ type testRuntimeCostEstimator struct { var timeToYearCost uint64 = 7 -func (e testRuntimeCostEstimator) CallCost(overloadID string, args []ref.Val) *uint64 { +func (e testRuntimeCostEstimator) CallCost(function, overloadID string, args []ref.Val, result ref.Val) *uint64 { argsSize := make([]uint64, len(args)) for i, arg := range args { reflectV := reflect.ValueOf(arg.Value()) @@ -216,7 +216,7 @@ func (tc testCostEstimator) EstimateSize(element checker.AstNode) *checker.SizeE return nil } -func (tc testCostEstimator) EstimateCallCost(overloadID string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { +func (tc testCostEstimator) EstimateCallCost(function, overloadID string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { switch overloadID { case overloads.TimestampToYear: return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 7, Max: 7}}