diff --git a/ext/comprehensions.go b/ext/comprehensions.go index 1428558d..58a1dbc2 100644 --- a/ext/comprehensions.go +++ b/ext/comprehensions.go @@ -16,6 +16,7 @@ package ext import ( "fmt" + "math" "github.com/google/cel-go/cel" "github.com/google/cel-go/common/ast" @@ -159,19 +160,36 @@ const ( // // {'greeting': 'aloha', 'farewell': 'aloha'} // .transformMapEntry(keyVar, valueVar, {valueVar: keyVar}) // error, duplicate key -func TwoVarComprehensions() cel.EnvOption { - return cel.Lib(compreV2Lib{}) +func TwoVarComprehensions(options ...TwoVarComprehensionsOption) cel.EnvOption { + l := &compreV2Lib{version: math.MaxUint32} + for _, o := range options { + l = o(l) + } + return cel.Lib(l) +} + +// TwoVarComprehensionsOption declares a functional operator for configuring two-variable comprehensions. +type TwoVarComprehensionsOption func(*compreV2Lib) *compreV2Lib + +// TwoVarComprehensionsVersion sets the library version for two-variable comprehensions. +func TwoVarComprehensionsVersion(version uint32) TwoVarComprehensionsOption { + return func(lib *compreV2Lib) *compreV2Lib { + lib.version = version + return lib + } } -type compreV2Lib struct{} +type compreV2Lib struct { + version uint32 +} // LibraryName implements that SingletonLibrary interface method. -func (compreV2Lib) LibraryName() string { +func (*compreV2Lib) LibraryName() string { return "cel.lib.ext.comprev2" } // CompileOptions implements the cel.Library interface method. -func (compreV2Lib) CompileOptions() []cel.EnvOption { +func (*compreV2Lib) CompileOptions() []cel.EnvOption { kType := cel.TypeParamType("K") vType := cel.TypeParamType("V") mapKVType := cel.MapType(kType, vType) @@ -217,7 +235,7 @@ func (compreV2Lib) CompileOptions() []cel.EnvOption { } // ProgramOptions implements the cel.Library interface method -func (compreV2Lib) ProgramOptions() []cel.ProgramOption { +func (*compreV2Lib) ProgramOptions() []cel.ProgramOption { return []cel.ProgramOption{} } diff --git a/ext/comprehensions_test.go b/ext/comprehensions_test.go index 6416e0e7..1bd65fa4 100644 --- a/ext/comprehensions_test.go +++ b/ext/comprehensions_test.go @@ -352,6 +352,13 @@ func TestTwoVarComprehensionsRuntimeErrors(t *testing.T) { } } +func TestTwoVarComprehensionsVersion(t *testing.T) { + _, err := cel.NewEnv(TwoVarComprehensions(TwoVarComprehensionsVersion(0))) + if err != nil { + t.Fatalf("TwoVarComprehensionVersion(0) failed: %v", err) + } +} + func testCompreEnv(t *testing.T, opts ...cel.EnvOption) *cel.Env { t.Helper() baseOpts := []cel.EnvOption{ diff --git a/ext/encoders.go b/ext/encoders.go index ac04b1a7..731c3d09 100644 --- a/ext/encoders.go +++ b/ext/encoders.go @@ -16,6 +16,7 @@ package ext import ( "encoding/base64" + "math" "github.com/google/cel-go/cel" "github.com/google/cel-go/common/types" @@ -47,17 +48,34 @@ import ( // Examples: // // base64.encode(b'hello') // return b'aGVsbG8=' -func Encoders() cel.EnvOption { - return cel.Lib(encoderLib{}) +func Encoders(options ...EncodersOption) cel.EnvOption { + l := &encoderLib{version: math.MaxUint32} + for _, o := range options { + l = o(l) + } + return cel.Lib(l) +} + +// EncodersOption declares a functional operator for configuring encoder extensions. +type EncodersOption func(*encoderLib) *encoderLib + +// EncodersVersion sets the library version for encoder extensions. +func EncodersVersion(version uint32) EncodersOption { + return func(lib *encoderLib) *encoderLib { + lib.version = version + return lib + } } -type encoderLib struct{} +type encoderLib struct { + version uint32 +} -func (encoderLib) LibraryName() string { +func (*encoderLib) LibraryName() string { return "cel.lib.ext.encoders" } -func (encoderLib) CompileOptions() []cel.EnvOption { +func (*encoderLib) CompileOptions() []cel.EnvOption { return []cel.EnvOption{ cel.Function("base64.decode", cel.Overload("base64_decode_string", []*cel.Type{cel.StringType}, cel.BytesType, @@ -74,7 +92,7 @@ func (encoderLib) CompileOptions() []cel.EnvOption { } } -func (encoderLib) ProgramOptions() []cel.ProgramOption { +func (*encoderLib) ProgramOptions() []cel.ProgramOption { return []cel.ProgramOption{} } diff --git a/ext/encoders_test.go b/ext/encoders_test.go index 959cb837..be6f764e 100644 --- a/ext/encoders_test.go +++ b/ext/encoders_test.go @@ -86,3 +86,10 @@ func TestEncoders(t *testing.T) { }) } } + +func TestEncodersVersion(t *testing.T) { + _, err := cel.NewEnv(Encoders(EncodersVersion(0))) + if err != nil { + t.Fatalf("EncodersVersion(0) failed: %v", err) + } +} diff --git a/ext/lists.go b/ext/lists.go index d0b90ea9..675ea867 100644 --- a/ext/lists.go +++ b/ext/lists.go @@ -145,13 +145,10 @@ var comparableTypes = []*cel.Type{ // == ["bar", "foo", "baz"] func Lists(options ...ListsOption) cel.EnvOption { - l := &listsLib{ - version: math.MaxUint32, - } + l := &listsLib{version: math.MaxUint32} for _, o := range options { l = o(l) } - return cel.Lib(l) } @@ -211,9 +208,10 @@ func (lib listsLib) CompileOptions() []cel.EnvOption { cel.MemberOverload("list_flatten", []*cel.Type{listListType}, listType, cel.UnaryBinding(func(arg ref.Val) ref.Val { + // double-check as type-guards disabled list, ok := arg.(traits.Lister) if !ok { - return types.MaybeNoSuchOverloadErr(arg) + return types.ValOrErr(arg, "no such overload: %v.flatten()", arg.Type()) } flatList, err := flatten(list, 1) if err != nil { @@ -226,13 +224,14 @@ func (lib listsLib) CompileOptions() []cel.EnvOption { cel.MemberOverload("list_flatten_int", []*cel.Type{listDyn, types.IntType}, listDyn, cel.BinaryBinding(func(arg1, arg2 ref.Val) ref.Val { + // double-check as type-guards disabled list, ok := arg1.(traits.Lister) if !ok { - return types.MaybeNoSuchOverloadErr(arg1) + return types.ValOrErr(arg1, "no such overload: %v.flatten(%v)", arg1.Type(), arg2.Type()) } depth, ok := arg2.(types.Int) if !ok { - return types.MaybeNoSuchOverloadErr(arg2) + return types.ValOrErr(arg1, "no such overload: %v.flatten(%v)", arg1.Type(), arg2.Type()) } flatList, err := flatten(list, int64(depth)) if err != nil { @@ -260,10 +259,8 @@ func (lib listsLib) CompileOptions() []cel.EnvOption { }), cel.SingletonUnaryBinding( func(arg ref.Val) ref.Val { - list, ok := arg.(traits.Lister) - if !ok { - return types.MaybeNoSuchOverloadErr(arg) - } + // validated by type-guards + list := arg.(traits.Lister) sorted, err := sortList(list) if err != nil { return types.WrapErr(err) @@ -287,15 +284,10 @@ func (lib listsLib) CompileOptions() []cel.EnvOption { ) }), cel.SingletonBinaryBinding( - func(arg1 ref.Val, arg2 ref.Val) ref.Val { - list, ok := arg1.(traits.Lister) - if !ok { - return types.MaybeNoSuchOverloadErr(arg1) - } - keys, ok := arg2.(traits.Lister) - if !ok { - return types.MaybeNoSuchOverloadErr(arg2) - } + func(arg1, arg2 ref.Val) ref.Val { + // validated by type-guards + list := arg1.(traits.Lister) + keys := arg2.(traits.Lister) sorted, err := sortListByAssociatedKeys(list, keys) if err != nil { return types.WrapErr(err) @@ -498,8 +490,9 @@ func sortByMacro(meh cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (as if targetKind != ast.ListKind && targetKind != ast.SelectKind && targetKind != ast.IdentKind && - targetKind != ast.ComprehensionKind && targetKind != ast.CallKind { - return nil, meh.NewError(target.ID(), fmt.Sprintf("sortBy can only be applied to a list, identifier, comprehension, call or select expression")) + targetKind != ast.ComprehensionKind && + targetKind != ast.CallKind { + return nil, meh.NewError(target.ID(), "sortBy can only be applied to a list, identifier, comprehension, call or select expression") } mapCompr, err := parser.MakeMap(meh, meh.Copy(varIdent), args) diff --git a/ext/lists_test.go b/ext/lists_test.go index 2baff0e6..c885e8cb 100644 --- a/ext/lists_test.go +++ b/ext/lists_test.go @@ -20,6 +20,7 @@ import ( "testing" "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/types" proto2pb "github.com/google/cel-go/test/proto2pb" ) @@ -113,6 +114,118 @@ func TestLists(t *testing.T) { } } +func TestListsRuntimeErrors(t *testing.T) { + env, err := cel.NewEnv(Lists(ListsVersion(1))) + if err != nil { + t.Fatalf("cel.NewEnv() failed: %v", err) + } + listsTests := []struct { + expr string + err string + }{ + { + expr: "dyn({}).flatten()", + err: "no such overload", + }, + { + expr: "dyn({}).flatten(0)", + err: "no such overload", + }, + { + expr: "[].flatten(-1)", + err: "level must be non-negative", + }, + { + expr: "[].flatten(dyn('1'))", + err: "no such overload", + }, + } + for i, tst := range listsTests { + tc := tst + t.Run(fmt.Sprintf("[%d]", i), func(t *testing.T) { + ast, iss := env.Compile(tc.expr) + if iss.Err() != nil { + t.Fatalf("env.Compile(%q) failed: %v", tc.expr, iss.Err()) + } + prg, err := env.Program(ast) + if err != nil { + t.Fatalf("env.Program() failed: %v", err) + } + _, _, err = prg.Eval(cel.NoVars()) + if err == nil || !strings.Contains(err.Error(), tc.err) { + t.Errorf("prg.Eval() got %v, wanted %v", err, tc.err) + } + }) + } +} + +func TestListsVersion(t *testing.T) { + versionCases := []struct { + version uint32 + supportedFunctions map[string]string + }{ + { + version: 0, + supportedFunctions: map[string]string{ + "slice": "[1, 2, 3, 4, 5].slice(2, 4) == [3, 4]", + }, + }, + { + version: 1, + supportedFunctions: map[string]string{ + "flatten": "[[1, 2], [3, 4]].flatten() == [1, 2, 3, 4]", + }, + }, + { + version: 2, + supportedFunctions: map[string]string{ + "distinct": "[1, 2, 2, 1].distinct() == [1, 2]", + "range": "lists.range(5) == [0, 1, 2, 3, 4]", + "reverse": "[1, 2, 3].reverse() == [3, 2, 1]", + "sort": "[2, 1, 3].sort() == [1, 2, 3]", + "sortBy": "[{'field': 'lo'}, {'field': 'hi'}].sortBy(m, m.field) == [{'field': 'hi'}, {'field': 'lo'}]", + }, + }, + } + for _, lib := range versionCases { + env, err := cel.NewEnv(Lists(ListsVersion(lib.version))) + if err != nil { + t.Fatalf("cel.NewEnv(Lists(ListsVersion(%d))) failed: %v", lib.version, err) + } + t.Run(fmt.Sprintf("version=%d", lib.version), func(t *testing.T) { + for _, tc := range versionCases { + for name, expr := range tc.supportedFunctions { + supported := lib.version >= tc.version + t.Run(fmt.Sprintf("%s-supported=%t", name, supported), func(t *testing.T) { + ast, iss := env.Compile(expr) + if supported { + if iss.Err() != nil { + t.Errorf("unexpected error: %v", iss.Err()) + } + } else { + if iss.Err() == nil || !strings.Contains(iss.Err().Error(), "undeclared reference") { + t.Errorf("got error %v, wanted error %s for expr: %s, version: %d", iss.Err(), "undeclared reference", expr, tc.version) + } + return + } + prg, err := env.Program(ast) + if err != nil { + t.Fatalf("env.Program() failed: %v", err) + } + out, _, err := prg.Eval(cel.NoVars()) + if err != nil { + t.Fatalf("prg.Eval() failed: %v", err) + } + if out != types.True { + t.Errorf("prg.Eval() got %v, wanted true", out) + } + }) + } + } + }) + } +} + func testListsEnv(t *testing.T, opts ...cel.EnvOption) *cel.Env { t.Helper() baseOpts := []cel.EnvOption{ diff --git a/ext/native.go b/ext/native.go index 36ab4a7a..83f36589 100644 --- a/ext/native.go +++ b/ext/native.go @@ -17,6 +17,7 @@ package ext import ( "errors" "fmt" + "math" "reflect" "strings" "time" @@ -98,7 +99,9 @@ var ( func NativeTypes(args ...any) cel.EnvOption { return func(env *cel.Env) (*cel.Env, error) { nativeTypes := make([]any, 0, len(args)) - tpOptions := nativeTypeOptions{} + tpOptions := nativeTypeOptions{ + version: math.MaxUint32, + } for _, v := range args { switch v := v.(type) { @@ -128,6 +131,14 @@ func NativeTypes(args ...any) cel.EnvOption { // NativeTypesOption is a functional interface for configuring handling of native types. type NativeTypesOption func(*nativeTypeOptions) error +// NativeTypesVersion sets the native types version support for native extensions functions. +func NativeTypesVersion(version uint32) NativeTypesOption { + return func(opts *nativeTypeOptions) error { + opts.version = version + return nil + } +} + // NativeTypesFieldNameHandler is a handler for mapping a reflect.StructField to a CEL field name. // This can be used to override the default Go struct field to CEL field name mapping. type NativeTypesFieldNameHandler = func(field reflect.StructField) string @@ -158,6 +169,9 @@ type nativeTypeOptions struct { // This is most commonly used for switching to parsing based off the struct field tag, // such as "cel" or "json". fieldNameHandler NativeTypesFieldNameHandler + + // version is the native types library version. + version uint32 } // ParseStructTags configures if native types field names should be overridable by CEL struct tags. diff --git a/ext/native_test.go b/ext/native_test.go index 4a62ec04..2c0ce3d3 100644 --- a/ext/native_test.go +++ b/ext/native_test.go @@ -941,6 +941,13 @@ func TestNativeStructEmbedded(t *testing.T) { } } +func TestNativeTypesVersion(t *testing.T) { + _, err := cel.NewEnv(NativeTypes(NativeTypesVersion(0))) + if err != nil { + t.Fatalf("NewEnv(NativeTypes(NativeTypesVersion(0))) failed: %v", err) + } +} + // testEnv initializes the test environment common to all tests. func testNativeEnv(t *testing.T, opts ...any) *cel.Env { t.Helper() diff --git a/ext/protos.go b/ext/protos.go index 68796f60..b09db25b 100644 --- a/ext/protos.go +++ b/ext/protos.go @@ -15,6 +15,8 @@ package ext import ( + "math" + "github.com/google/cel-go/cel" "github.com/google/cel-go/common/ast" ) @@ -49,8 +51,23 @@ import ( // Examples: // // proto.hasExt(msg, google.expr.proto2.test.int32_ext) // returns true || false -func Protos() cel.EnvOption { - return cel.Lib(protoLib{}) +func Protos(options ...ProtosOption) cel.EnvOption { + l := &protoLib{version: math.MaxUint32} + for _, o := range options { + l = o(l) + } + return cel.Lib(l) +} + +// ProtosOption declares a functional operator for configuring protobuf utilities. +type ProtosOption func(*protoLib) *protoLib + +// ProtosVersion sets the library version for extensions for protobuf utilities. +func ProtosVersion(version uint32) ProtosOption { + return func(lib *protoLib) *protoLib { + lib.version = version + return lib + } } var ( @@ -59,7 +76,9 @@ var ( getExtension = "getExt" ) -type protoLib struct{} +type protoLib struct { + version uint32 +} // LibraryName implements the SingletonLibrary interface method. func (protoLib) LibraryName() string { diff --git a/ext/protos_test.go b/ext/protos_test.go index 14646c3c..02e74658 100644 --- a/ext/protos_test.go +++ b/ext/protos_test.go @@ -219,6 +219,13 @@ func TestProtosWithExtension(t *testing.T) { } } +func TestProtosVersion(t *testing.T) { + _, err := cel.NewEnv(Protos(ProtosVersion(0))) + if err != nil { + t.Fatalf("ProtosVersion(0) failed: %v", err) + } +} + // msgWithExtensions generates a new example message with all possible extensions set. func msgWithExtensions() *proto2pb.ExampleType { msg := &proto2pb.ExampleType{ diff --git a/ext/sets.go b/ext/sets.go index 7e941665..9a9ef6ee 100644 --- a/ext/sets.go +++ b/ext/sets.go @@ -77,11 +77,28 @@ import ( // sets.intersects([1], []) // false // sets.intersects([1], [1, 2]) // true // sets.intersects([[1], [2, 3]], [[1, 2], [2, 3.0]]) // true -func Sets() cel.EnvOption { - return cel.Lib(setsLib{}) +func Sets(options ...SetsOption) cel.EnvOption { + l := &setsLib{} + for _, o := range options { + l = o(l) + } + return cel.Lib(l) +} + +// SetsOption declares a functional operator for configuring set extensions. +type SetsOption func(*setsLib) *setsLib + +// SetsVersion sets the library version for set extensions. +func SetsVersion(version uint32) SetsOption { + return func(lib *setsLib) *setsLib { + lib.version = version + return lib + } } -type setsLib struct{} +type setsLib struct { + version uint32 +} // LibraryName implements the SingletonLibrary interface method. func (setsLib) LibraryName() string { diff --git a/ext/sets_test.go b/ext/sets_test.go index 42dcad9f..70ca2393 100644 --- a/ext/sets_test.go +++ b/ext/sets_test.go @@ -497,6 +497,13 @@ func TestSetsMembershipRewriter(t *testing.T) { } } +func TestSetsVersion(t *testing.T) { + _, err := cel.NewEnv(Sets(SetsVersion(0))) + if err != nil { + t.Fatalf("SetsVersion(0) failed: %v", err) + } +} + func testSetsEnv(t *testing.T, opts ...cel.EnvOption) *cel.Env { t.Helper() baseOpts := []cel.EnvOption{cel.EnableMacroCallTracking(), Sets()} diff --git a/ext/strings_test.go b/ext/strings_test.go index 1b89f981..37bbc352 100644 --- a/ext/strings_test.go +++ b/ext/strings_test.go @@ -387,6 +387,12 @@ func TestStringsVersions(t *testing.T) { "quote": `strings.quote('\a \b "double quotes"') == '"\\a \\b \\"double quotes\\""'`, }, }, + { + version: 3, + supportedFunctions: map[string]string{ + "reverse": "'taco'.reverse() == 'ocat'", + }, + }, } for _, lib := range versionCases { env, err := cel.NewEnv(Strings(StringsVersion(lib.version))) @@ -427,10 +433,6 @@ func TestStringsVersions(t *testing.T) { } } -func version(v uint32) *uint32 { - return &v -} - func TestStringsWithExtension(t *testing.T) { env, err := cel.NewEnv(Strings()) if err != nil { diff --git a/policy/config.go b/policy/config.go index f6cd3fdc..cab8da1c 100644 --- a/policy/config.go +++ b/policy/config.go @@ -264,30 +264,30 @@ func (od *OverloadDecl) AsFunctionOption(baseEnv *cel.Env) (cel.FunctionOpt, err var extFactories = map[string]ExtensionFactory{ "bindings": func(version uint32) cel.EnvOption { - return ext.Bindings() + return ext.Bindings(ext.BindingsVersion(version)) }, "encoders": func(version uint32) cel.EnvOption { - return ext.Encoders() + return ext.Encoders(ext.EncodersVersion(version)) }, "lists": func(version uint32) cel.EnvOption { - return ext.Lists() + return ext.Lists(ext.ListsVersion(version)) }, "math": func(version uint32) cel.EnvOption { - return ext.Math() + return ext.Math(ext.MathVersion(version)) }, "optional": func(version uint32) cel.EnvOption { return cel.OptionalTypes(cel.OptionalTypesVersion(version)) }, "protos": func(version uint32) cel.EnvOption { - return ext.Protos() + return ext.Protos(ext.ProtosVersion(version)) }, "sets": func(version uint32) cel.EnvOption { - return ext.Sets() + return ext.Sets(ext.SetsVersion(version)) }, "strings": func(version uint32) cel.EnvOption { return ext.Strings(ext.StringsVersion(version)) }, "two-var-comprehensions": func(version uint32) cel.EnvOption { - return ext.TwoVarComprehensions() + return ext.TwoVarComprehensions(ext.TwoVarComprehensionsVersion(version)) }, }