diff --git a/cel/cel_test.go b/cel/cel_test.go index e247c558..dd5043a0 100644 --- a/cel/cel_test.go +++ b/cel/cel_test.go @@ -913,6 +913,58 @@ func TestContextEval(t *testing.T) { } } +func TestContextEvalPropagation(t *testing.T) { + env, err := NewEnv( + Declarations( + decls.NewFunction("sleep", decls.NewOverload( + "sleep", []*exprpb.Type{decls.Int}, decls.Null, + )), + ), + ) + if err != nil { + t.Fatalf("NewEnv() failed: %v", err) + } + ast, iss := env.Compile("sleep(20)") + if iss.Err() != nil { + t.Fatalf("env.Compile(expr) failed: %v", iss.Err()) + } + prg, err := env.Program(ast, EvalOptions(OptOptimize|OptTrackState), Functions(&functions.ContextOverload{ + Operator: "sleep", + Unary: func(ctx context.Context, value ref.Val) ref.Val { + t := time.NewTimer(time.Duration(value.Value().(int64)) * time.Microsecond) + select { + case <-t.C: + return types.NullValue + case <-ctx.Done(): + return types.NewErr("ctx done") + } + }, + })) + if err != nil { + t.Fatalf("env.Program() failed: %v", err) + } + + ctx := context.TODO() + out, _, err := prg.ContextEval(ctx, map[string]interface{}{}) + if err != nil { + t.Fatalf("prg.ContextEval() failed: %v", err) + } + if out != types.NullValue { + t.Errorf("prg.ContextEval() got %v, wanted true", out) + } + + evalCtx, cancel := context.WithTimeout(ctx, time.Microsecond) + defer cancel() + + out, _, err = prg.ContextEval(evalCtx, map[string]interface{}{}) + if err == nil { + t.Errorf("Got result %v, wanted timeout error", out) + } + if err != nil && err.Error() != "ctx done" { + t.Errorf("Got %v, wanted operation interrupted error", err) + } +} + func BenchmarkContextEval(b *testing.B) { env, err := NewEnv( Declarations( diff --git a/cel/options.go b/cel/options.go index 65f9b8c2..4c1921ab 100644 --- a/cel/options.go +++ b/cel/options.go @@ -332,7 +332,7 @@ func CustomDecorator(dec interpreter.InterpretableDecorator) ProgramOption { } // Functions adds function overloads that extend or override the set of CEL built-ins. -func Functions(funcs ...*functions.Overload) ProgramOption { +func Functions(funcs ...functions.Overloader) ProgramOption { return func(p *prog) (*prog, error) { if err := p.dispatcher.Add(funcs...); err != nil { return nil, err diff --git a/cel/program.go b/cel/program.go index 8fbd350a..74c26893 100644 --- a/cel/program.go +++ b/cel/program.go @@ -19,6 +19,7 @@ import ( "fmt" "math" "sync" + "time" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" @@ -301,12 +302,12 @@ func (p *prog) ContextEval(ctx context.Context, input interface{}) (ref.Val, *Ev var vars interpreter.Activation switch v := input.(type) { case interpreter.Activation: - vars = ctxActivationPool.Setup(v, ctx.Done(), p.interruptCheckFrequency) + vars = ctxActivationPool.Setup(ctx, v, p.interruptCheckFrequency) defer ctxActivationPool.Put(vars) case map[string]interface{}: rawVars := activationPool.Setup(v) defer activationPool.Put(rawVars) - vars = ctxActivationPool.Setup(rawVars, ctx.Done(), p.interruptCheckFrequency) + vars = ctxActivationPool.Setup(ctx, rawVars, p.interruptCheckFrequency) defer ctxActivationPool.Put(vars) default: return nil, nil, fmt.Errorf("invalid input, wanted Activation or map[string]interface{}, got: (%T)%v", input, input) @@ -415,12 +416,63 @@ func estimateCost(i interface{}) (min, max int64) { } type ctxEvalActivation struct { + ctx context.Context parent interpreter.Activation interrupt <-chan struct{} interruptCheckCount uint interruptCheckFrequency uint } +func (a *ctxEvalActivation) Deadline() (deadline time.Time, ok bool) { + if a.parent != nil { + if d1, ok := a.parent.Deadline(); ok { + if d2, ok := a.ctx.Deadline(); ok { + if d1.Before(d2) { + return d1, true + } else { + return d2, true + } + } + return d1, ok + } + } + return a.ctx.Deadline() +} + +func (a *ctxEvalActivation) Done() <-chan struct{} { + if a.parent != nil { + if a.parent.Done() != nil { + c := make(chan struct{}) + go func() { + select { + case c <- <-a.parent.Done(): + case c <- <-a.ctx.Done(): + } + }() + return c + } + } + return a.ctx.Done() +} + +func (a *ctxEvalActivation) Err() error { + if a.parent != nil { + if err := a.parent.Err(); err != nil { + return err + } + } + return a.ctx.Err() +} + +func (a *ctxEvalActivation) Value(key interface{}) interface{} { + if a.parent != nil { + if v := a.parent.Value(key); v != nil { + return v + } + } + return a.ctx.Value(key) +} + // ResolveName implements the Activation interface method, but adds a special #interrupted variable // which is capable of testing whether a 'done' signal is provided from a context.Context channel. func (a *ctxEvalActivation) ResolveName(name string) (interface{}, bool) { @@ -447,7 +499,7 @@ func newCtxEvalActivationPool() *ctxEvalActivationPool { return &ctxEvalActivationPool{ Pool: sync.Pool{ New: func() interface{} { - return &ctxEvalActivation{} + return &ctxEvalActivation{ctx: context.Background()} }, }, } @@ -458,20 +510,27 @@ type ctxEvalActivationPool struct { } // Setup initializes a pooled Activation with the ability check for context.Context cancellation -func (p *ctxEvalActivationPool) Setup(vars interpreter.Activation, done <-chan struct{}, interruptCheckRate uint) *ctxEvalActivation { +func (p *ctxEvalActivationPool) Setup(ctx context.Context, vars interpreter.Activation, interruptCheckRate uint) *ctxEvalActivation { a := p.Pool.Get().(*ctxEvalActivation) + a.ctx = ctx a.parent = vars - a.interrupt = done + a.interrupt = ctx.Done() a.interruptCheckCount = 0 a.interruptCheckFrequency = interruptCheckRate return a } type evalActivation struct { + ctx context.Context vars map[string]interface{} lazyVars map[string]interface{} } +func (a *evalActivation) Deadline() (deadline time.Time, ok bool) { return a.ctx.Deadline() } +func (a *evalActivation) Done() <-chan struct{} { return a.ctx.Done() } +func (a *evalActivation) Err() error { return a.ctx.Err() } +func (a *evalActivation) Value(key interface{}) interface{} { return a.ctx.Value } + // ResolveName looks up the value of the input variable name, if found. // // Lazy bindings may be supplied within the map-based input in either of the following forms: @@ -516,7 +575,7 @@ func newEvalActivationPool() *evalActivationPool { return &evalActivationPool{ Pool: sync.Pool{ New: func() interface{} { - return &evalActivation{lazyVars: make(map[string]interface{})} + return &evalActivation{ctx: context.Background(), lazyVars: make(map[string]interface{})} }, }, } diff --git a/interpreter/activation.go b/interpreter/activation.go index 8686d4f0..511c78ff 100644 --- a/interpreter/activation.go +++ b/interpreter/activation.go @@ -15,9 +15,11 @@ package interpreter import ( + "context" "errors" "fmt" "sync" + "time" "github.com/google/cel-go/common/types/ref" ) @@ -26,6 +28,7 @@ import ( // // An Activation is the primary mechanism by which a caller supplies input into a CEL program. type Activation interface { + context.Context // ResolveName returns a value from the activation by qualified name, or false if the name // could not be found. ResolveName(name string) (interface{}, bool) @@ -37,14 +40,20 @@ type Activation interface { // EmptyActivation returns a variable-free activation. func EmptyActivation() Activation { - return emptyActivation{} + return &emptyActivation{ctx: context.Background()} } // emptyActivation is a variable-free activation. -type emptyActivation struct{} +type emptyActivation struct { + ctx context.Context +} -func (emptyActivation) ResolveName(string) (interface{}, bool) { return nil, false } -func (emptyActivation) Parent() Activation { return nil } +func (a *emptyActivation) Deadline() (deadline time.Time, ok bool) { return a.ctx.Deadline() } +func (a *emptyActivation) Done() <-chan struct{} { return a.ctx.Done() } +func (a *emptyActivation) Err() error { return a.ctx.Err() } +func (a *emptyActivation) Value(key interface{}) interface{} { return a.ctx.Value } +func (a *emptyActivation) ResolveName(string) (interface{}, bool) { return nil, false } +func (a *emptyActivation) Parent() Activation { return nil } // NewActivation returns an activation based on a map-based binding where the map keys are // expected to be qualified names used with ResolveName calls. @@ -73,7 +82,7 @@ func NewActivation(bindings interface{}) (Activation, error) { "activation input must be an activation or map[string]interface: got %T", bindings) } - return &mapActivation{bindings: m}, nil + return &mapActivation{ctx: context.Background(), bindings: m}, nil } // mapActivation which implements Activation and maps of named values. @@ -81,9 +90,15 @@ func NewActivation(bindings interface{}) (Activation, error) { // Named bindings may lazily supply values by providing a function which accepts no arguments and // produces an interface value. type mapActivation struct { + ctx context.Context bindings map[string]interface{} } +func (a *mapActivation) Deadline() (deadline time.Time, ok bool) { return a.ctx.Deadline() } +func (a *mapActivation) Done() <-chan struct{} { return a.ctx.Done() } +func (a *mapActivation) Err() error { return a.ctx.Err() } +func (a *mapActivation) Value(key interface{}) interface{} { return a.ctx.Value } + // Parent implements the Activation interface method. func (a *mapActivation) Parent() Activation { return nil @@ -115,6 +130,54 @@ type hierarchicalActivation struct { child Activation } +func (a *hierarchicalActivation) Deadline() (deadline time.Time, ok bool) { + if d1, ok := a.child.Deadline(); ok { + if d2, ok := a.parent.Deadline(); ok { + if d1.Before(d2) { + return d1, true + } else { + return d2, true + } + } + return d1, ok + } + return a.parent.Deadline() +} + +func (a *hierarchicalActivation) Done() <-chan struct{} { + if a.parent.Done() != nil { + if a.child.Done() != nil { + c := make(chan struct{}) + go func() { + select { + case c <- <-a.parent.Done(): + case c <- <-a.child.Done(): + } + }() + return c + } else { + return a.parent.Done() + } + } + return a.child.Done() +} + +func (a *hierarchicalActivation) Err() error { + if err := a.child.Err(); err != nil { + return err + } else if err = a.parent.Err(); err != nil { + return err + } + return nil +} + +func (a *hierarchicalActivation) Value(key interface{}) interface{} { + if v := a.child.Value(key); v != nil { + return v + } + return a.parent.Value(key) +} + // Parent implements the Activation interface method. func (a *hierarchicalActivation) Parent() Activation { return a.parent @@ -178,6 +241,34 @@ type varActivation struct { val ref.Val } +func (a *varActivation) Deadline() (deadline time.Time, ok bool) { + if a.parent != nil { + return a.parent.Deadline() + } + return time.Time{}, ok +} + +func (a *varActivation) Done() <-chan struct{} { + if a.parent != nil { + return a.parent.Done() + } + return nil +} + +func (a *varActivation) Err() error { + if a.parent != nil { + return a.parent.Err() + } + return nil +} + +func (a *varActivation) Value(key interface{}) interface{} { + if a.parent != nil { + return a.parent.Value(key) + } + return nil +} + // Parent implements the Activation interface method. func (v *varActivation) Parent() Activation { return v.parent diff --git a/interpreter/dispatcher.go b/interpreter/dispatcher.go index febf9d8a..aaaa3e02 100644 --- a/interpreter/dispatcher.go +++ b/interpreter/dispatcher.go @@ -23,10 +23,10 @@ import ( // Dispatcher resolves function calls to their appropriate overload. type Dispatcher interface { // Add one or more overloads, returning an error if any Overload has the same Overload#Name. - Add(overloads ...*functions.Overload) error + Add(overloads ...functions.Overloader) error // FindOverload returns an Overload definition matching the provided name. - FindOverload(overload string) (*functions.Overload, bool) + FindOverload(overload string) (functions.Overloader, bool) // OverloadIds returns the set of all overload identifiers configured for dispatch. OverloadIds() []string @@ -35,7 +35,7 @@ type Dispatcher interface { // NewDispatcher returns an empty Dispatcher instance. func NewDispatcher() Dispatcher { return &defaultDispatcher{ - overloads: make(map[string]*functions.Overload)} + overloads: make(map[string]functions.Overloader)} } // ExtendDispatcher returns a Dispatcher which inherits the overloads of its parent, and @@ -44,11 +44,11 @@ func NewDispatcher() Dispatcher { func ExtendDispatcher(parent Dispatcher) Dispatcher { return &defaultDispatcher{ parent: parent, - overloads: make(map[string]*functions.Overload)} + overloads: make(map[string]functions.Overloader)} } // overloadMap helper type for indexing overloads by function name. -type overloadMap map[string]*functions.Overload +type overloadMap map[string]functions.Overloader // defaultDispatcher struct which contains an overload map. type defaultDispatcher struct { @@ -57,20 +57,20 @@ type defaultDispatcher struct { } // Add implements the Dispatcher.Add interface method. -func (d *defaultDispatcher) Add(overloads ...*functions.Overload) error { +func (d *defaultDispatcher) Add(overloads ...functions.Overloader) error { for _, o := range overloads { // add the overload unless an overload of the same name has already been provided. - if _, found := d.overloads[o.Operator]; found { - return fmt.Errorf("overload already exists '%s'", o.Operator) + if _, found := d.overloads[o.GetOperator()]; found { + return fmt.Errorf("overload already exists '%s'", o.GetOperator()) } // index the overload by function name. - d.overloads[o.Operator] = o + d.overloads[o.GetOperator()] = o } return nil } // FindOverload implements the Dispatcher.FindOverload interface method. -func (d *defaultDispatcher) FindOverload(overload string) (*functions.Overload, bool) { +func (d *defaultDispatcher) FindOverload(overload string) (functions.Overloader, bool) { o, found := d.overloads[overload] // Attempt to dispatch to an overload defined in the parent. if !found && d.parent != nil { diff --git a/interpreter/functions/functions.go b/interpreter/functions/functions.go index dd1e9ddd..460daa09 100644 --- a/interpreter/functions/functions.go +++ b/interpreter/functions/functions.go @@ -16,7 +16,20 @@ // interpreter and as declared within the checker#StandardDeclarations. package functions -import "github.com/google/cel-go/common/types/ref" +import ( + "context" + + "github.com/google/cel-go/common/types/ref" +) + +type Overloader interface { + GetOperator() string + GetOperandTrait() int + GetUnary() ContextUnaryOp + GetBinary() ContextBinaryOp + GetFunction() ContextFunctionOp + IsNonStrict() bool +} // Overload defines a named overload of a function, indicating an operand trait // which must be present on the first argument to the overload as well as one @@ -51,12 +64,69 @@ type Overload struct { NonStrict bool } +func (o *Overload) GetOperator() string { return o.Operator } +func (o *Overload) GetOperandTrait() int { return o.OperandTrait } +func (o *Overload) GetUnary() ContextUnaryOp { + if o.Unary != nil { + return func(ctx context.Context, value ref.Val) ref.Val { return o.Unary(value) } + } + return nil +} +func (o *Overload) GetBinary() ContextBinaryOp { + if o.Binary != nil { + return func(ctx context.Context, lhs, rhs ref.Val) ref.Val { return o.Binary(lhs, rhs) } + } + return nil +} +func (o *Overload) GetFunction() ContextFunctionOp { + if o.Function != nil { + return func(ctx context.Context, values ...ref.Val) ref.Val { return o.Function(values...) } + } + return nil +} +func (o *Overload) IsNonStrict() bool { return o.NonStrict } + +type ContextOverload struct { + // Operator name as written in an expression or defined within + // operators.go. + Operator string + + // Operand trait used to dispatch the call. The zero-value indicates a + // global function overload or that one of the Unary / Binary / Function + // definitions should be used to execute the call. + OperandTrait int + + // Unary defines the overload with a UnaryOp implementation. May be nil. + Unary ContextUnaryOp + + // Binary defines the overload with a BinaryOp implementation. May be nil. + Binary ContextBinaryOp + + // Function defines the overload with a FunctionOp implementation. May be + // nil. + Function ContextFunctionOp + + // NonStrict specifies whether the Overload will tolerate arguments that + // are types.Err or types.Unknown. + NonStrict bool +} + +func (o *ContextOverload) GetOperator() string { return o.Operator } +func (o *ContextOverload) GetOperandTrait() int { return o.OperandTrait } +func (o *ContextOverload) GetUnary() ContextUnaryOp { return o.Unary } +func (o *ContextOverload) GetBinary() ContextBinaryOp { return o.Binary } +func (o *ContextOverload) GetFunction() ContextFunctionOp { return o.Function } +func (o *ContextOverload) IsNonStrict() bool { return o.NonStrict } + // UnaryOp is a function that takes a single value and produces an output. type UnaryOp func(value ref.Val) ref.Val +type ContextUnaryOp func(ctx context.Context, value ref.Val) ref.Val // BinaryOp is a function that takes two values and produces an output. type BinaryOp func(lhs ref.Val, rhs ref.Val) ref.Val +type ContextBinaryOp func(ctx context.Context, lhs ref.Val, rhs ref.Val) ref.Val // FunctionOp is a function with accepts zero or more arguments and produces // an value (as interface{}) or error as a result. type FunctionOp func(values ...ref.Val) ref.Val +type ContextFunctionOp func(ctx context.Context, values ...ref.Val) ref.Val diff --git a/interpreter/functions/standard.go b/interpreter/functions/standard.go index 73e93611..346a7bb5 100644 --- a/interpreter/functions/standard.go +++ b/interpreter/functions/standard.go @@ -23,10 +23,10 @@ import ( ) // StandardOverloads returns the definitions of the built-in overloads. -func StandardOverloads() []*Overload { - return []*Overload{ +func StandardOverloads() []Overloader { + return []Overloader{ // Logical not (!a) - { + &Overload{ Operator: operators.LogicalNot, OperandTrait: traits.NegatorType, Unary: func(value ref.Val) ref.Val { @@ -36,16 +36,16 @@ func StandardOverloads() []*Overload { return value.(traits.Negater).Negate() }}, // Not strictly false: IsBool(a) ? a : true - { + &Overload{ Operator: operators.NotStrictlyFalse, Unary: notStrictlyFalse}, // Deprecated: not strictly false, may be overridden in the environment. - { + &Overload{ Operator: operators.OldNotStrictlyFalse, Unary: notStrictlyFalse}, // Less than operator - {Operator: operators.Less, + &Overload{Operator: operators.Less, OperandTrait: traits.ComparerType, Binary: func(lhs ref.Val, rhs ref.Val) ref.Val { cmp := lhs.(traits.Comparer).Compare(rhs) @@ -59,7 +59,7 @@ func StandardOverloads() []*Overload { }}, // Less than or equal operator - {Operator: operators.LessEquals, + &Overload{Operator: operators.LessEquals, OperandTrait: traits.ComparerType, Binary: func(lhs ref.Val, rhs ref.Val) ref.Val { cmp := lhs.(traits.Comparer).Compare(rhs) @@ -73,7 +73,7 @@ func StandardOverloads() []*Overload { }}, // Greater than operator - {Operator: operators.Greater, + &Overload{Operator: operators.Greater, OperandTrait: traits.ComparerType, Binary: func(lhs ref.Val, rhs ref.Val) ref.Val { cmp := lhs.(traits.Comparer).Compare(rhs) @@ -87,7 +87,7 @@ func StandardOverloads() []*Overload { }}, // Greater than equal operators - {Operator: operators.GreaterEquals, + &Overload{Operator: operators.GreaterEquals, OperandTrait: traits.ComparerType, Binary: func(lhs ref.Val, rhs ref.Val) ref.Val { cmp := lhs.(traits.Comparer).Compare(rhs) @@ -101,42 +101,42 @@ func StandardOverloads() []*Overload { }}, // Add operator - {Operator: operators.Add, + &Overload{Operator: operators.Add, OperandTrait: traits.AdderType, Binary: func(lhs ref.Val, rhs ref.Val) ref.Val { return lhs.(traits.Adder).Add(rhs) }}, // Subtract operators - {Operator: operators.Subtract, + &Overload{Operator: operators.Subtract, OperandTrait: traits.SubtractorType, Binary: func(lhs ref.Val, rhs ref.Val) ref.Val { return lhs.(traits.Subtractor).Subtract(rhs) }}, // Multiply operator - {Operator: operators.Multiply, + &Overload{Operator: operators.Multiply, OperandTrait: traits.MultiplierType, Binary: func(lhs ref.Val, rhs ref.Val) ref.Val { return lhs.(traits.Multiplier).Multiply(rhs) }}, // Divide operator - {Operator: operators.Divide, + &Overload{Operator: operators.Divide, OperandTrait: traits.DividerType, Binary: func(lhs ref.Val, rhs ref.Val) ref.Val { return lhs.(traits.Divider).Divide(rhs) }}, // Modulo operator - {Operator: operators.Modulo, + &Overload{Operator: operators.Modulo, OperandTrait: traits.ModderType, Binary: func(lhs ref.Val, rhs ref.Val) ref.Val { return lhs.(traits.Modder).Modulo(rhs) }}, // Negate operator - {Operator: operators.Negate, + &Overload{Operator: operators.Negate, OperandTrait: traits.NegatorType, Unary: func(value ref.Val) ref.Val { if types.IsBool(value) { @@ -146,26 +146,26 @@ func StandardOverloads() []*Overload { }}, // Index operator - {Operator: operators.Index, + &Overload{Operator: operators.Index, OperandTrait: traits.IndexerType, Binary: func(lhs ref.Val, rhs ref.Val) ref.Val { return lhs.(traits.Indexer).Get(rhs) }}, // Size function - {Operator: overloads.Size, + &Overload{Operator: overloads.Size, OperandTrait: traits.SizerType, Unary: func(value ref.Val) ref.Val { return value.(traits.Sizer).Size() }}, // In operator - {Operator: operators.In, Binary: inAggregate}, + &Overload{Operator: operators.In, Binary: inAggregate}, // Deprecated: in operator, may be overridden in the environment. - {Operator: operators.OldIn, Binary: inAggregate}, + &Overload{Operator: operators.OldIn, Binary: inAggregate}, // Matches function - {Operator: overloads.Matches, + &Overload{Operator: overloads.Matches, OperandTrait: traits.MatcherType, Binary: func(lhs ref.Val, rhs ref.Val) ref.Val { return lhs.(traits.Matcher).Match(rhs) @@ -175,78 +175,78 @@ func StandardOverloads() []*Overload { // TODO: verify type conversion safety of numeric values. // Int conversions. - {Operator: overloads.TypeConvertInt, + &Overload{Operator: overloads.TypeConvertInt, Unary: func(value ref.Val) ref.Val { return value.ConvertToType(types.IntType) }}, // Uint conversions. - {Operator: overloads.TypeConvertUint, + &Overload{Operator: overloads.TypeConvertUint, Unary: func(value ref.Val) ref.Val { return value.ConvertToType(types.UintType) }}, // Double conversions. - {Operator: overloads.TypeConvertDouble, + &Overload{Operator: overloads.TypeConvertDouble, Unary: func(value ref.Val) ref.Val { return value.ConvertToType(types.DoubleType) }}, // Bool conversions. - {Operator: overloads.TypeConvertBool, + &Overload{Operator: overloads.TypeConvertBool, Unary: func(value ref.Val) ref.Val { return value.ConvertToType(types.BoolType) }}, // Bytes conversions. - {Operator: overloads.TypeConvertBytes, + &Overload{Operator: overloads.TypeConvertBytes, Unary: func(value ref.Val) ref.Val { return value.ConvertToType(types.BytesType) }}, // String conversions. - {Operator: overloads.TypeConvertString, + &Overload{Operator: overloads.TypeConvertString, Unary: func(value ref.Val) ref.Val { return value.ConvertToType(types.StringType) }}, // Timestamp conversions. - {Operator: overloads.TypeConvertTimestamp, + &Overload{Operator: overloads.TypeConvertTimestamp, Unary: func(value ref.Val) ref.Val { return value.ConvertToType(types.TimestampType) }}, // Duration conversions. - {Operator: overloads.TypeConvertDuration, + &Overload{Operator: overloads.TypeConvertDuration, Unary: func(value ref.Val) ref.Val { return value.ConvertToType(types.DurationType) }}, // Type operations. - {Operator: overloads.TypeConvertType, + &Overload{Operator: overloads.TypeConvertType, Unary: func(value ref.Val) ref.Val { return value.ConvertToType(types.TypeType) }}, // Dyn conversion (identity function). - {Operator: overloads.TypeConvertDyn, + &Overload{Operator: overloads.TypeConvertDyn, Unary: func(value ref.Val) ref.Val { return value }}, - {Operator: overloads.Iterator, + &Overload{Operator: overloads.Iterator, OperandTrait: traits.IterableType, Unary: func(value ref.Val) ref.Val { return value.(traits.Iterable).Iterator() }}, - {Operator: overloads.HasNext, + &Overload{Operator: overloads.HasNext, OperandTrait: traits.IteratorType, Unary: func(value ref.Val) ref.Val { return value.(traits.Iterator).HasNext() }}, - {Operator: overloads.Next, + &Overload{Operator: overloads.Next, OperandTrait: traits.IteratorType, Unary: func(value ref.Val) ref.Val { return value.(traits.Iterator).Next() diff --git a/interpreter/interpretable.go b/interpreter/interpretable.go index f957c760..096f0a41 100644 --- a/interpreter/interpretable.go +++ b/interpreter/interpretable.go @@ -387,7 +387,7 @@ type evalZeroArity struct { id int64 function string overload string - impl functions.FunctionOp + impl functions.ContextFunctionOp } // ID implements the Interpretable interface method. @@ -397,7 +397,7 @@ func (zero *evalZeroArity) ID() int64 { // Eval implements the Interpretable interface method. func (zero *evalZeroArity) Eval(ctx Activation) ref.Val { - return zero.impl() + return zero.impl(ctx) } // Cost returns 1 representing the heuristic cost of the function. @@ -426,7 +426,7 @@ type evalUnary struct { overload string arg Interpretable trait int - impl functions.UnaryOp + impl functions.ContextUnaryOp nonStrict bool } @@ -446,7 +446,7 @@ func (un *evalUnary) Eval(ctx Activation) ref.Val { // If the implementation is bound and the argument value has the right traits required to // invoke it, then call the implementation. if un.impl != nil && (un.trait == 0 || argVal.Type().HasTrait(un.trait)) { - return un.impl(argVal) + return un.impl(ctx, argVal) } // Otherwise, if the argument is a ReceiverType attempt to invoke the receiver method on the // operand (arg0). @@ -486,7 +486,7 @@ type evalBinary struct { lhs Interpretable rhs Interpretable trait int - impl functions.BinaryOp + impl functions.ContextBinaryOp nonStrict bool } @@ -512,7 +512,7 @@ func (bin *evalBinary) Eval(ctx Activation) ref.Val { // If the implementation is bound and the argument value has the right traits required to // invoke it, then call the implementation. if bin.impl != nil && (bin.trait == 0 || lVal.Type().HasTrait(bin.trait)) { - return bin.impl(lVal, rVal) + return bin.impl(ctx, lVal, rVal) } // Otherwise, if the argument is a ReceiverType attempt to invoke the receiver method on the // operand (arg0). @@ -548,12 +548,12 @@ type evalVarArgs struct { overload string args []Interpretable trait int - impl functions.FunctionOp + impl functions.ContextFunctionOp nonStrict bool } // NewCall creates a new call Interpretable. -func NewCall(id int64, function, overload string, args []Interpretable, impl functions.FunctionOp) InterpretableCall { +func NewCall(id int64, function, overload string, args []Interpretable, impl functions.ContextFunctionOp) InterpretableCall { return &evalVarArgs{ id: id, function: function, @@ -583,7 +583,7 @@ func (fn *evalVarArgs) Eval(ctx Activation) ref.Val { // invoke it, then call the implementation. arg0 := argVals[0] if fn.impl != nil && (fn.trait == 0 || arg0.Type().HasTrait(fn.trait)) { - return fn.impl(argVals...) + return fn.impl(ctx, argVals...) } // Otherwise, if the argument is a ReceiverType attempt to invoke the receiver method on the // operand (arg0). diff --git a/interpreter/interpreter_test.go b/interpreter/interpreter_test.go index 6d50f346..d1f3f22c 100644 --- a/interpreter/interpreter_test.go +++ b/interpreter/interpreter_test.go @@ -56,7 +56,7 @@ type testCase struct { abbrevs []string env []*exprpb.Decl types []proto.Message - funcs []*functions.Overload + funcs []functions.Overloader attrs AttributeFactory unchecked bool extraOpts []InterpretableDecorator @@ -121,8 +121,8 @@ var ( expr: `zero()`, cost: []int64{1, 1}, unchecked: true, - funcs: []*functions.Overload{ - { + funcs: []functions.Overloader{ + &functions.Overload{ Operator: "zero", Function: func(args ...ref.Val) ref.Val { return types.IntZero @@ -136,8 +136,8 @@ var ( expr: `neg(1)`, cost: []int64{1, 1}, unchecked: true, - funcs: []*functions.Overload{ - { + funcs: []functions.Overloader{ + &functions.Overload{ Operator: "neg", OperandTrait: traits.NegatorType, Unary: func(arg ref.Val) ref.Val { @@ -152,8 +152,8 @@ var ( expr: `b'abc'.concat(b'def')`, cost: []int64{1, 1}, unchecked: true, - funcs: []*functions.Overload{ - { + funcs: []functions.Overloader{ + &functions.Overload{ Operator: "concat", OperandTrait: traits.AdderType, Binary: func(lhs, rhs ref.Val) ref.Val { @@ -168,8 +168,8 @@ var ( expr: `addall(a, b, c, d) == 10`, cost: []int64{6, 6}, unchecked: true, - funcs: []*functions.Overload{ - { + funcs: []functions.Overloader{ + &functions.Overload{ Operator: "addall", OperandTrait: traits.AdderType, Function: func(args ...ref.Val) ref.Val { @@ -196,12 +196,12 @@ var ( decls.String), ), }, - funcs: []*functions.Overload{ - { + funcs: []functions.Overloader{ + &functions.Overload{ Operator: "base64.encode", Unary: base64Encode, }, - { + &functions.Overload{ Operator: "base64_encode_string", Unary: base64Encode, }, @@ -213,8 +213,8 @@ var ( expr: `base64.encode('hello')`, cost: []int64{1, 1}, unchecked: true, - funcs: []*functions.Overload{ - { + funcs: []functions.Overloader{ + &functions.Overload{ Operator: "base64.encode", Unary: base64Encode, }, @@ -233,12 +233,12 @@ var ( decls.String), ), }, - funcs: []*functions.Overload{ - { + funcs: []functions.Overloader{ + &functions.Overload{ Operator: "base64.encode", Unary: base64Encode, }, - { + &functions.Overload{ Operator: "base64_encode_string", Unary: base64Encode, }, @@ -251,8 +251,8 @@ var ( cost: []int64{1, 1}, container: `base64`, unchecked: true, - funcs: []*functions.Overload{ - { + funcs: []functions.Overloader{ + &functions.Overload{ Operator: "base64.encode", Unary: base64Encode, }, @@ -1296,8 +1296,8 @@ var ( decls.NewOverload("string_to_json", []*exprpb.Type{decls.String}, decls.Dyn)), }, - funcs: []*functions.Overload{ - { + funcs: []functions.Overloader{ + &functions.Overload{ Operator: "json", Unary: func(val ref.Val) ref.Val { str, ok := val.(types.String) @@ -1342,8 +1342,8 @@ var ( name: "call_with_error_unary", expr: `try(0/0)`, unchecked: true, - funcs: []*functions.Overload{ - { + funcs: []functions.Overloader{ + &functions.Overload{ Operator: "try", Unary: func(arg ref.Val) ref.Val { if types.IsError(arg) { @@ -1360,8 +1360,8 @@ var ( name: "call_with_error_binary", expr: `try(0/0, 0)`, unchecked: true, - funcs: []*functions.Overload{ - { + funcs: []functions.Overloader{ + &functions.Overload{ Operator: "try", Binary: func(arg0, arg1 ref.Val) ref.Val { if types.IsError(arg0) { @@ -1378,8 +1378,8 @@ var ( name: "call_with_error_function", expr: `try(0/0, 0, 0)`, unchecked: true, - funcs: []*functions.Overload{ - { + funcs: []functions.Overloader{ + &functions.Overload{ Operator: "try", Function: func(args ...ref.Val) ref.Val { if types.IsError(args[0]) { diff --git a/interpreter/optimizations.go b/interpreter/optimizations.go index 2fc87e69..f40d17d3 100644 --- a/interpreter/optimizations.go +++ b/interpreter/optimizations.go @@ -15,6 +15,7 @@ package interpreter import ( + "context" "regexp" "github.com/google/cel-go/common/types" @@ -32,7 +33,7 @@ var MatchesRegexOptimization = &RegexOptimization{ if err != nil { return nil, err } - return NewCall(call.ID(), call.Function(), call.OverloadID(), call.Args(), func(values ...ref.Val) ref.Val { + return NewCall(call.ID(), call.Function(), call.OverloadID(), call.Args(), func(ctx context.Context, values ...ref.Val) ref.Val { if len(values) != 2 { return types.NoSuchOverloadErr() } diff --git a/interpreter/planner.go b/interpreter/planner.go index ae8abbb0..c9250680 100644 --- a/interpreter/planner.go +++ b/interpreter/planner.go @@ -290,7 +290,7 @@ func (p *planner) planCall(expr *exprpb.Expr) (Interpretable, error) { // Otherwise, generate Interpretable calls specialized by argument count. // Try to find the specific function by overload id. - var fnDef *functions.Overload + var fnDef functions.Overloader if oName != "" { fnDef, _ = p.disp.FindOverload(oName) } @@ -314,15 +314,15 @@ func (p *planner) planCall(expr *exprpb.Expr) (Interpretable, error) { func (p *planner) planCallZero(expr *exprpb.Expr, function string, overload string, - impl *functions.Overload) (Interpretable, error) { - if impl == nil || impl.Function == nil { + impl functions.Overloader) (Interpretable, error) { + if impl == nil || impl.GetFunction() == nil { return nil, fmt.Errorf("no such overload: %s()", function) } return &evalZeroArity{ id: expr.Id, function: function, overload: overload, - impl: impl.Function, + impl: impl.GetFunction(), }, nil } @@ -330,18 +330,18 @@ func (p *planner) planCallZero(expr *exprpb.Expr, func (p *planner) planCallUnary(expr *exprpb.Expr, function string, overload string, - impl *functions.Overload, + impl functions.Overloader, args []Interpretable) (Interpretable, error) { - var fn functions.UnaryOp + var fn functions.ContextUnaryOp var trait int var nonStrict bool if impl != nil { - if impl.Unary == nil { + if impl.GetUnary() == nil { return nil, fmt.Errorf("no such overload: %s(arg)", function) } - fn = impl.Unary - trait = impl.OperandTrait - nonStrict = impl.NonStrict + fn = impl.GetUnary() + trait = impl.GetOperandTrait() + nonStrict = impl.IsNonStrict() } return &evalUnary{ id: expr.Id, @@ -358,18 +358,18 @@ func (p *planner) planCallUnary(expr *exprpb.Expr, func (p *planner) planCallBinary(expr *exprpb.Expr, function string, overload string, - impl *functions.Overload, + impl functions.Overloader, args []Interpretable) (Interpretable, error) { - var fn functions.BinaryOp + var fn functions.ContextBinaryOp var trait int var nonStrict bool if impl != nil { - if impl.Binary == nil { + if impl.GetBinary() == nil { return nil, fmt.Errorf("no such overload: %s(lhs, rhs)", function) } - fn = impl.Binary - trait = impl.OperandTrait - nonStrict = impl.NonStrict + fn = impl.GetBinary() + trait = impl.GetOperandTrait() + nonStrict = impl.IsNonStrict() } return &evalBinary{ id: expr.Id, @@ -387,18 +387,18 @@ func (p *planner) planCallBinary(expr *exprpb.Expr, func (p *planner) planCallVarArgs(expr *exprpb.Expr, function string, overload string, - impl *functions.Overload, + impl functions.Overloader, args []Interpretable) (Interpretable, error) { - var fn functions.FunctionOp + var fn functions.ContextFunctionOp var trait int var nonStrict bool if impl != nil { - if impl.Function == nil { + if impl.GetFunction() == nil { return nil, fmt.Errorf("no such overload: %s(...)", function) } - fn = impl.Function - trait = impl.OperandTrait - nonStrict = impl.NonStrict + fn = impl.GetFunction() + trait = impl.GetOperandTrait() + nonStrict = impl.IsNonStrict() } return &evalVarArgs{ id: expr.Id,