From efc893b4f30b40ecadad1b96ee0a988428cbf234 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Thu, 14 Jul 2022 10:46:04 -0700 Subject: [PATCH] Ensure overloads are searched in the order they are declared (#566) * Ensure overloads are searched in the order they are declared during dynamic dispatch * Improved support for dynamic dispatch --- cel/cel_test.go | 205 ++++++++++++++++++++++++++++++++++++++-------- cel/decls.go | 112 +++++++++++++++++-------- cel/decls_test.go | 22 ++++- 3 files changed, 269 insertions(+), 70 deletions(-) diff --git a/cel/cel_test.go b/cel/cel_test.go index 36dbdcc4..5b8e0b69 100644 --- a/cel/cel_test.go +++ b/cel/cel_test.go @@ -674,7 +674,10 @@ func TestGlobalVars(t *testing.T) { t.Run("attrs_alt", func(t *testing.T) { vars := map[string]interface{}{ "attrs": map[string]interface{}{"second": "yep"}} - out, _, _ := prg.Eval(vars) + out, _, err := prg.Eval(vars) + if err != nil { + t.Fatalf("prg.Eval(vars) failed: %v", err) + } if out.Equal(types.String("yep")) != types.True { t.Errorf("got '%v', expected 'yep'.", out.Value()) } @@ -1657,7 +1660,7 @@ func TestDefaultUTCTimeZone(t *testing.T) { if err != nil { t.Fatalf("NewEnv() failed: %v", err) } - ast, iss := env.Compile(` + out, err := interpret(t, env, ` x.getFullYear() == 1970 && x.getMonth() == 0 && x.getDayOfYear() == 0 @@ -1687,16 +1690,10 @@ func TestDefaultUTCTimeZone(t *testing.T) { && x.getHours('23:15') == 1 && x.getMinutes('23:15') == 20 && x.getSeconds('23:15') == 6 - && x.getMilliseconds('23:15') == 1 - `) - if iss.Err() != nil { - t.Fatalf("env.Compile() failed: %v", iss.Err()) - } - prg, err := env.Program(ast) - if err != nil { - t.Fatalf("env.Program() failed: %v", err) - } - out, _, err := prg.Eval(map[string]interface{}{"x": time.Unix(7506, 1000000).Local()}) + && x.getMilliseconds('23:15') == 1`, + map[string]interface{}{ + "x": time.Unix(7506, 1000000).Local(), + }) if err != nil { t.Fatalf("prg.Eval() failed: %v", err) } @@ -1718,20 +1715,12 @@ func TestDefaultUTCTimeZoneExtension(t *testing.T) { if err != nil { t.Fatalf("env.Extend() failed: %v", err) } - ast, iss := env.Compile(` + out, err := interpret(t, env, ` x.getFullYear() == 1970 && y.getHours() == 2 && y.getMinutes() == 120 && y.getSeconds() == 7235 - && y.getMilliseconds() == 7235000`) - if iss.Err() != nil { - t.Fatalf("env.Compile() failed: %v", iss.Err()) - } - prg, err := env.Program(ast) - if err != nil { - t.Fatalf("env.Program() failed: %v", err) - } - out, _, err := prg.Eval( + && y.getMilliseconds() == 7235000`, map[string]interface{}{ "x": time.Unix(7506, 1000000).Local(), "y": time.Duration(7235) * time.Second, @@ -1750,7 +1739,7 @@ func TestDefaultUTCTimeZoneError(t *testing.T) { if err != nil { t.Fatalf("NewEnv() failed: %v", err) } - ast, iss := env.Compile(` + out, err := interpret(t, env, ` x.getFullYear(':xx') == 1969 || x.getDayOfYear('xx:') == 364 || x.getMonth('Am/Ph') == 11 @@ -1761,30 +1750,180 @@ func TestDefaultUTCTimeZoneError(t *testing.T) { || x.getMinutes('Am/Ph') == 5 || x.getSeconds('Am/Ph') == 6 || x.getMilliseconds('Am/Ph') == 1 - `) - if iss.Err() != nil { - t.Fatalf("env.Compile() failed: %v", iss.Err()) + `, map[string]interface{}{ + "x": time.Unix(7506, 1000000).Local(), + }, + ) + if err == nil { + t.Fatalf("prg.Eval() got %v wanted error", out) } - prg, err := env.Program(ast) +} + +func TestDynamicDispatch(t *testing.T) { + env, err := NewEnv( + HomogeneousAggregateLiterals(), + Function("first", + MemberOverload("first_list_int", []*Type{ListType(IntType)}, IntType, + UnaryBinding(func(list ref.Val) ref.Val { + l := list.(traits.Lister) + if l.Size() == types.IntZero { + return types.IntZero + } + return l.Get(types.IntZero) + }), + ), + MemberOverload("first_list_double", []*Type{ListType(DoubleType)}, DoubleType, + UnaryBinding(func(list ref.Val) ref.Val { + l := list.(traits.Lister) + if l.Size() == types.IntZero { + return types.Double(0.0) + } + return l.Get(types.IntZero) + }), + ), + MemberOverload("first_list_string", []*Type{ListType(StringType)}, StringType, + UnaryBinding(func(list ref.Val) ref.Val { + l := list.(traits.Lister) + if l.Size() == types.IntZero { + return types.String("") + } + return l.Get(types.IntZero) + }), + ), + MemberOverload("first_list_list_string", []*Type{ListType(ListType(StringType))}, ListType(StringType), + UnaryBinding(func(list ref.Val) ref.Val { + l := list.(traits.Lister) + if l.Size() == types.IntZero { + return types.DefaultTypeAdapter.NativeToValue([]string{}) + } + return l.Get(types.IntZero) + }), + ), + ), + ) if err != nil { - t.Fatalf("env.Program() failed: %v", err) + t.Fatalf("NewEnv() failed: %v", err) } - out, _, err := prg.Eval(map[string]interface{}{"x": time.Unix(7506, 1000000).Local()}) - if err == nil { - t.Fatalf("prg.Eval() got %v wanted error", out) + out, err := interpret(t, env, ` + [].first() == 0 + && [1, 2].first() == 1 + && [1.0, 2.0].first() == 1.0 + && ["hello", "world"].first() == "hello" + && [["hello"], ["world", "!"]].first().first() == "hello" + && [[], ["empty"]].first().first() == "" + && dyn([1, 2]).first() == 1 + && dyn([1.0, 2.0]).first() == 1.0 + && dyn(["hello", "world"]).first() == "hello" + && dyn([["hello"], ["world", "!"]]).first().first() == "hello" + `, map[string]interface{}{}, + ) + if err != nil { + t.Fatalf("prg.Eval() failed: %v", err) + } + if out != types.True { + t.Fatalf("prg.Eval() got %v wanted true", out) + } +} + +func BenchmarkDynamicDispatch(b *testing.B) { + env, err := NewEnv( + HomogeneousAggregateLiterals(), + Function("first", + MemberOverload("first_list_int", []*Type{ListType(IntType)}, IntType, + UnaryBinding(func(list ref.Val) ref.Val { + l := list.(traits.Lister) + if l.Size() == types.IntZero { + return types.IntZero + } + return l.Get(types.IntZero) + }), + ), + MemberOverload("first_list_double", []*Type{ListType(DoubleType)}, DoubleType, + UnaryBinding(func(list ref.Val) ref.Val { + l := list.(traits.Lister) + if l.Size() == types.IntZero { + return types.Double(0.0) + } + return l.Get(types.IntZero) + }), + ), + MemberOverload("first_list_string", []*Type{ListType(StringType)}, StringType, + UnaryBinding(func(list ref.Val) ref.Val { + l := list.(traits.Lister) + if l.Size() == types.IntZero { + return types.String("") + } + return l.Get(types.IntZero) + }), + ), + MemberOverload("first_list_list_string", []*Type{ListType(ListType(StringType))}, ListType(StringType), + UnaryBinding(func(list ref.Val) ref.Val { + l := list.(traits.Lister) + if l.Size() == types.IntZero { + return types.DefaultTypeAdapter.NativeToValue([]string{}) + } + return l.Get(types.IntZero) + }), + ), + ), + ) + if err != nil { + b.Fatalf("NewEnv() failed: %v", err) } + prg := compile(b, env, ` + [].first() == 0 + && [1, 2].first() == 1 + && [1.0, 2.0].first() == 1.0 + && ["hello", "world"].first() == "hello" + && [["hello"], ["world", "!"]].first().first() == "hello"`) + prgDyn := compile(b, env, ` + dyn([]).first() == 0 + && dyn([1, 2]).first() == 1 + && dyn([1.0, 2.0]).first() == 1.0 + && dyn(["hello", "world"]).first() == "hello" + && dyn([["hello"], ["world", "!"]]).first().first() == "hello"`) + b.ResetTimer() + b.Run("DirectDispatch", func(b *testing.B) { + for i := 0; i < b.N; i++ { + prg.Eval(NoVars()) + } + }) + b.ResetTimer() + b.Run("DynamicDispatch", func(b *testing.B) { + for i := 0; i < b.N; i++ { + prgDyn.Eval(NoVars()) + } + }) } -func interpret(t *testing.T, env *Env, expr string, vars interface{}) (ref.Val, error) { +func compile(t testing.TB, env *Env, expr string) Program { + t.Helper() + prg, err := compileOrError(t, env, expr) + if err != nil { + t.Fatal(err) + } + return prg +} + +func compileOrError(t testing.TB, env *Env, expr string) (Program, error) { t.Helper() ast, iss := env.Compile(expr) if iss.Err() != nil { return nil, fmt.Errorf("env.Compile(%s) failed: %v", expr, iss.Err()) } - prg, err := env.Program(ast) + prg, err := env.Program(ast, EvalOptions(OptOptimize)) if err != nil { return nil, fmt.Errorf("env.Program() failed: %v", err) } + return prg, nil +} + +func interpret(t testing.TB, env *Env, expr string, vars interface{}) (ref.Val, error) { + t.Helper() + prg, err := compileOrError(t, env, expr) + if err != nil { + return nil, err + } out, _, err := prg.Eval(vars) if err != nil { return nil, fmt.Errorf("prg.Eval(%v) failed: %v", vars, err) diff --git a/cel/decls.go b/cel/decls.go index 55532788..f2df721d 100644 --- a/cel/decls.go +++ b/cel/decls.go @@ -21,6 +21,7 @@ import ( "github.com/google/cel-go/checker/decls" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/common/types/traits" "github.com/google/cel-go/interpreter/functions" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" @@ -162,7 +163,7 @@ type Type struct { // isAssignableRuntimeType function determines whether the runtime type (with erasure) is assignable to this type. // A nil value for the isAssignableRuntimeType function falls back to the equality of the type or type name. - isAssignableRuntimeType func(other ref.Type) bool + isAssignableRuntimeType func(other ref.Val) bool } // IsAssignableType determines whether the current type is type-check assignable from the input fromType. @@ -177,11 +178,11 @@ func (t *Type) IsAssignableType(fromType *Type) bool { // // At runtime, parameterized types are erased and so a function which type-checks to support a map(string, string) // will have a runtime assignable type of a map. -func (t *Type) IsAssignableRuntimeType(runtimeType ref.Type) bool { +func (t *Type) IsAssignableRuntimeType(val ref.Val) bool { if t.isAssignableRuntimeType != nil { - return t.isAssignableRuntimeType(runtimeType) + return t.isAssignableRuntimeType(val) } - return t.defaultIsAssignableRuntimeType(runtimeType) + return t.defaultIsAssignableRuntimeType(val) } // String returns a human-readable definition of the type name. @@ -221,7 +222,7 @@ func (t *Type) equals(other *Type) bool { // - The from types are the same instance // - The target type is dynamic // - The fromType has the same kind and type name as the target type, and all parameters of the target type -// are IsAssignableType() from the parameters of the fromType. +// are IsAssignableType() from the parameters of the fromType. func (t *Type) defaultIsAssignableType(fromType *Type) bool { if t == fromType || t.isDyn() { return true @@ -240,8 +241,40 @@ func (t *Type) defaultIsAssignableType(fromType *Type) bool { return true } -func (t *Type) defaultIsAssignableRuntimeType(runtimeType ref.Type) bool { - return t.runtimeType == runtimeType || t.isDyn() || t.runtimeType.TypeName() == runtimeType.TypeName() +// defaultIsAssignableRuntimeType inspects the type and in the case of list and map elements, the key and element types +// to determine whether a ref.Val is assignable to the declared type for a function signature. +func (t *Type) defaultIsAssignableRuntimeType(val ref.Val) bool { + valType := val.Type() + if !(t.runtimeType == valType || t.isDyn() || t.runtimeType.TypeName() == valType.TypeName()) { + return false + } + switch t.runtimeType { + case types.ListType: + elemType := t.parameters[0] + l := val.(traits.Lister) + if l.Size() == types.IntZero { + return true + } + it := l.Iterator() + for it.HasNext() == types.True { + elemVal := it.Next() + return elemType.IsAssignableRuntimeType(elemVal) + } + case types.MapType: + keyType := t.parameters[0] + elemType := t.parameters[1] + m := val.(traits.Mapper) + if m.Size() == types.IntZero { + return true + } + it := m.Iterator() + for it.HasNext() == types.True { + keyVal := it.Next() + elemVal := m.Get(keyVal) + return keyType.IsAssignableRuntimeType(keyVal) && elemType.IsAssignableRuntimeType(elemVal) + } + } + return true } // ListType creates an instances of a list type value with the provided element type. @@ -273,7 +306,7 @@ func NullableType(wrapped *Type) *Type { isAssignableType: func(other *Type) bool { return NullType.IsAssignableType(other) || wrapped.IsAssignableType(other) }, - isAssignableRuntimeType: func(other ref.Type) bool { + isAssignableRuntimeType: func(other ref.Val) bool { return NullType.IsAssignableRuntimeType(other) || wrapped.IsAssignableRuntimeType(other) }, } @@ -328,12 +361,26 @@ func Variable(name string, t *Type) EnvOption { // One key difference with using Function() is that each FunctionDecl provided will handle dynamic // dispatch based on the type-signatures of the overloads provided which means overload resolution at // runtime is handled out of the box rather than via a custom binding for overload resolution via -// Functions(). +// Functions(): +// +// - Overloads are searched in the order they are declared +// - Dynamic dispatch for lists and maps is limited by inspection of the list and map contents +// at runtime. Empty lists and maps will result in a 'default dispatch' +// - In the event that a default dispatch occurs, the first overload provided is the one invoked +// +// If you intend to use overloads which differentiate based on the key or element type of a list or +// map, consider using a generic function instead: e.g. func(list(T)) or func(map(K, V)) as this +// will allow your implementation to determine how best to handle dispatch and the default behavior +// for empty lists and maps whose contents cannot be inspected. +// +// For functions which use parameterized opaque types (abstract types), consider using a singleton +// function which is capable of inspecting the contents of the type and resolving the appropriate +// overload as CEL can only make inferences by type-name regarding such types. func Function(name string, opts ...FunctionOpt) EnvOption { return func(e *Env) (*Env, error) { fn := &functionDecl{ name: name, - overloads: map[string]*overloadDecl{}, + overloads: []*overloadDecl{}, options: opts, } err := fn.init() @@ -510,7 +557,7 @@ func OverloadOperandTrait(trait int) OverloadOpt { type functionDecl struct { name string - overloads map[string]*overloadDecl + overloads []*overloadDecl options []FunctionOpt singleton *functions.Overload initialized bool @@ -591,22 +638,22 @@ func (f *functionDecl) bindings() ([]*functions.Overload, error) { // performs dynamic dispatch to the proper overload based on the argument types. bindings := append([]*functions.Overload{}, overloads...) funcDispatch := func(args ...ref.Val) ref.Val { - for _, overloadDecl := range f.overloads { - if !overloadDecl.matchesRuntimeSignature(args...) { + for _, o := range f.overloads { + if !o.matchesRuntimeSignature(args...) { continue } switch len(args) { case 1: - if overloadDecl.unaryOp != nil { - return overloadDecl.unaryOp(args[0]) + if o.unaryOp != nil { + return o.unaryOp(args[0]) } case 2: - if overloadDecl.binaryOp != nil { - return overloadDecl.binaryOp(args[0], args[1]) + if o.binaryOp != nil { + return o.binaryOp(args[0], args[1]) } } - if overloadDecl.functionOp != nil { - return overloadDecl.functionOp(args...) + if o.functionOp != nil { + return o.functionOp(args...) } // eventually this will fall through to the noSuchOverload below. } @@ -639,14 +686,12 @@ func (f *functionDecl) merge(other *functionDecl) (*functionDecl, error) { } merged := &functionDecl{ name: f.name, - overloads: map[string]*overloadDecl{}, + overloads: make([]*overloadDecl, len(f.overloads)), options: []FunctionOpt{}, initialized: true, singleton: f.singleton, } - for id, o := range f.overloads { - merged.overloads[id] = o - } + copy(merged.overloads, f.overloads) for _, o := range other.overloads { err := merged.addOverload(o) if err != nil { @@ -666,20 +711,21 @@ func (f *functionDecl) merge(other *functionDecl) (*functionDecl, error) { // however, if the function signatures are identical, the implementation may be rewritten as its // difficult to compare functions by object identity. func (f *functionDecl) addOverload(overload *overloadDecl) error { - for id, o := range f.overloads { - if id != overload.id && o.signatureOverlaps(overload) { + for index, o := range f.overloads { + if o.id != overload.id && o.signatureOverlaps(overload) { return fmt.Errorf("overload signature collision in function %s: %s collides with %s", f.name, o.id, overload.id) } - if id == overload.id { + if o.id == overload.id { if o.signatureEquals(overload) && o.nonStrict == overload.nonStrict { // Allow redefinition of an overload implementation so long as the signatures match. - f.overloads[id] = overload + f.overloads[index] = overload + return nil } else { return fmt.Errorf("overload redefinition in function. %s: %s has multiple definitions", f.name, o.id) } } } - f.overloads[overload.id] = overload + f.overloads = append(f.overloads, overload) return nil } @@ -757,19 +803,19 @@ func (o *overloadDecl) matchesRuntimeUnarySignature(arg ref.Val) bool { if o.nonStrict && types.IsUnknownOrError(arg) { return true } - return o.argTypes[0].IsAssignableRuntimeType(arg.Type()) && (o.operandTrait == 0 || arg.Type().HasTrait(o.operandTrait)) + return o.argTypes[0].IsAssignableRuntimeType(arg) && (o.operandTrait == 0 || arg.Type().HasTrait(o.operandTrait)) } // matchesRuntimeBinarySignature indicates whether the argument types are runtime assiganble to the overload's expected arguments. func (o *overloadDecl) matchesRuntimeBinarySignature(arg1, arg2 ref.Val) bool { if o.nonStrict { if types.IsUnknownOrError(arg1) { - return types.IsUnknownOrError(arg2) || o.argTypes[1].IsAssignableRuntimeType(arg2.Type()) + return types.IsUnknownOrError(arg2) || o.argTypes[1].IsAssignableRuntimeType(arg2) } - } else if !o.argTypes[1].IsAssignableRuntimeType(arg2.Type()) { + } else if !o.argTypes[1].IsAssignableRuntimeType(arg2) { return false } - return o.argTypes[0].IsAssignableRuntimeType(arg1.Type()) && (o.operandTrait == 0 || arg1.Type().HasTrait(o.operandTrait)) + return o.argTypes[0].IsAssignableRuntimeType(arg1) && (o.operandTrait == 0 || arg1.Type().HasTrait(o.operandTrait)) } // matchesRuntimeSignature indicates whether the argument types are runtime assiganble to the overload's expected arguments. @@ -785,7 +831,7 @@ func (o *overloadDecl) matchesRuntimeSignature(args ...ref.Val) bool { if o.nonStrict && types.IsUnknownOrError(arg) { continue } - allArgsMatch = allArgsMatch && o.argTypes[i].IsAssignableRuntimeType(arg.Type()) + allArgsMatch = allArgsMatch && o.argTypes[i].IsAssignableRuntimeType(arg) } arg := args[0] diff --git a/cel/decls_test.go b/cel/decls_test.go index 655b0ed9..d1cadd72 100644 --- a/cel/decls_test.go +++ b/cel/decls_test.go @@ -20,6 +20,7 @@ import ( "reflect" "strings" "testing" + "time" "github.com/google/cel-go/checker/decls" "github.com/google/cel-go/common/operators" @@ -659,14 +660,27 @@ func TestIsAssignableType(t *testing.T) { } func TestIsAssignableRuntimeType(t *testing.T) { - if !NullableType(DoubleType).IsAssignableRuntimeType(types.NullType) { + if !NullableType(DoubleType).IsAssignableRuntimeType(types.NullValue) { t.Error("nullable double cannot be assigned from null") } - if !NullableType(DoubleType).IsAssignableRuntimeType(types.DoubleType) { + if !NullableType(DoubleType).IsAssignableRuntimeType(types.Double(0.0)) { t.Error("nullable double cannot be assigned from double") } - if !MapType(StringType, DurationType).IsAssignableRuntimeType(types.MapType) { - t.Error("map(string, duration) not assibale to map at runtime") + if !MapType(StringType, DurationType).IsAssignableRuntimeType( + types.DefaultTypeAdapter.NativeToValue(map[string]time.Duration{})) { + t.Error("map(string, duration) not assignable to map at runtime") + } + if !MapType(StringType, DurationType).IsAssignableRuntimeType( + types.DefaultTypeAdapter.NativeToValue(map[string]time.Duration{"one": time.Duration(1)})) { + t.Error("map(string, duration) not assignable to map at runtime") + } + if !MapType(StringType, DynType).IsAssignableRuntimeType( + types.DefaultTypeAdapter.NativeToValue(map[string]time.Duration{"one": time.Duration(1)})) { + t.Error("map(string, dyn) not assignable to map at runtime") + } + if MapType(StringType, DynType).IsAssignableRuntimeType( + types.DefaultTypeAdapter.NativeToValue(map[int64]time.Duration{1: time.Duration(1)})) { + t.Error("map(string, dyn) must not be assignable to map(int, duration) at runtime") } }