Skip to content

Commit

Permalink
Add versioning support for extensions (#1075)
Browse files Browse the repository at this point in the history
  • Loading branch information
TristonianJones authored Dec 5, 2024
1 parent 2e67731 commit 000958d
Show file tree
Hide file tree
Showing 14 changed files with 281 additions and 52 deletions.
30 changes: 24 additions & 6 deletions ext/comprehensions.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package ext

import (
"fmt"
"math"

"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/ast"
Expand Down Expand Up @@ -159,19 +160,36 @@ const (
//
// {'greeting': 'aloha', 'farewell': 'aloha'}
// .transformMapEntry(keyVar, valueVar, {valueVar: keyVar}) // error, duplicate key
func TwoVarComprehensions() cel.EnvOption {
return cel.Lib(compreV2Lib{})
func TwoVarComprehensions(options ...TwoVarComprehensionsOption) cel.EnvOption {
l := &compreV2Lib{version: math.MaxUint32}
for _, o := range options {
l = o(l)
}
return cel.Lib(l)
}

// TwoVarComprehensionsOption declares a functional operator for configuring two-variable comprehensions.
type TwoVarComprehensionsOption func(*compreV2Lib) *compreV2Lib

// TwoVarComprehensionsVersion sets the library version for two-variable comprehensions.
func TwoVarComprehensionsVersion(version uint32) TwoVarComprehensionsOption {
return func(lib *compreV2Lib) *compreV2Lib {
lib.version = version
return lib
}
}

type compreV2Lib struct{}
type compreV2Lib struct {
version uint32
}

// LibraryName implements that SingletonLibrary interface method.
func (compreV2Lib) LibraryName() string {
func (*compreV2Lib) LibraryName() string {
return "cel.lib.ext.comprev2"
}

// CompileOptions implements the cel.Library interface method.
func (compreV2Lib) CompileOptions() []cel.EnvOption {
func (*compreV2Lib) CompileOptions() []cel.EnvOption {
kType := cel.TypeParamType("K")
vType := cel.TypeParamType("V")
mapKVType := cel.MapType(kType, vType)
Expand Down Expand Up @@ -217,7 +235,7 @@ func (compreV2Lib) CompileOptions() []cel.EnvOption {
}

// ProgramOptions implements the cel.Library interface method
func (compreV2Lib) ProgramOptions() []cel.ProgramOption {
func (*compreV2Lib) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}

Expand Down
7 changes: 7 additions & 0 deletions ext/comprehensions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,13 @@ func TestTwoVarComprehensionsRuntimeErrors(t *testing.T) {
}
}

func TestTwoVarComprehensionsVersion(t *testing.T) {
_, err := cel.NewEnv(TwoVarComprehensions(TwoVarComprehensionsVersion(0)))
if err != nil {
t.Fatalf("TwoVarComprehensionVersion(0) failed: %v", err)
}
}

func testCompreEnv(t *testing.T, opts ...cel.EnvOption) *cel.Env {
t.Helper()
baseOpts := []cel.EnvOption{
Expand Down
30 changes: 24 additions & 6 deletions ext/encoders.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package ext

import (
"encoding/base64"
"math"

"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
Expand Down Expand Up @@ -47,17 +48,34 @@ import (
// Examples:
//
// base64.encode(b'hello') // return b'aGVsbG8='
func Encoders() cel.EnvOption {
return cel.Lib(encoderLib{})
func Encoders(options ...EncodersOption) cel.EnvOption {
l := &encoderLib{version: math.MaxUint32}
for _, o := range options {
l = o(l)
}
return cel.Lib(l)
}

// EncodersOption declares a functional operator for configuring encoder extensions.
type EncodersOption func(*encoderLib) *encoderLib

// EncodersVersion sets the library version for encoder extensions.
func EncodersVersion(version uint32) EncodersOption {
return func(lib *encoderLib) *encoderLib {
lib.version = version
return lib
}
}

type encoderLib struct{}
type encoderLib struct {
version uint32
}

func (encoderLib) LibraryName() string {
func (*encoderLib) LibraryName() string {
return "cel.lib.ext.encoders"
}

func (encoderLib) CompileOptions() []cel.EnvOption {
func (*encoderLib) CompileOptions() []cel.EnvOption {
return []cel.EnvOption{
cel.Function("base64.decode",
cel.Overload("base64_decode_string", []*cel.Type{cel.StringType}, cel.BytesType,
Expand All @@ -74,7 +92,7 @@ func (encoderLib) CompileOptions() []cel.EnvOption {
}
}

func (encoderLib) ProgramOptions() []cel.ProgramOption {
func (*encoderLib) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}

Expand Down
7 changes: 7 additions & 0 deletions ext/encoders_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,10 @@ func TestEncoders(t *testing.T) {
})
}
}

func TestEncodersVersion(t *testing.T) {
_, err := cel.NewEnv(Encoders(EncodersVersion(0)))
if err != nil {
t.Fatalf("EncodersVersion(0) failed: %v", err)
}
}
37 changes: 15 additions & 22 deletions ext/lists.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,13 +145,10 @@ var comparableTypes = []*cel.Type{
// == ["bar", "foo", "baz"]

func Lists(options ...ListsOption) cel.EnvOption {
l := &listsLib{
version: math.MaxUint32,
}
l := &listsLib{version: math.MaxUint32}
for _, o := range options {
l = o(l)
}

return cel.Lib(l)
}

Expand Down Expand Up @@ -211,9 +208,10 @@ func (lib listsLib) CompileOptions() []cel.EnvOption {
cel.MemberOverload("list_flatten",
[]*cel.Type{listListType}, listType,
cel.UnaryBinding(func(arg ref.Val) ref.Val {
// double-check as type-guards disabled
list, ok := arg.(traits.Lister)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
return types.ValOrErr(arg, "no such overload: %v.flatten()", arg.Type())
}
flatList, err := flatten(list, 1)
if err != nil {
Expand All @@ -226,13 +224,14 @@ func (lib listsLib) CompileOptions() []cel.EnvOption {
cel.MemberOverload("list_flatten_int",
[]*cel.Type{listDyn, types.IntType}, listDyn,
cel.BinaryBinding(func(arg1, arg2 ref.Val) ref.Val {
// double-check as type-guards disabled
list, ok := arg1.(traits.Lister)
if !ok {
return types.MaybeNoSuchOverloadErr(arg1)
return types.ValOrErr(arg1, "no such overload: %v.flatten(%v)", arg1.Type(), arg2.Type())
}
depth, ok := arg2.(types.Int)
if !ok {
return types.MaybeNoSuchOverloadErr(arg2)
return types.ValOrErr(arg1, "no such overload: %v.flatten(%v)", arg1.Type(), arg2.Type())
}
flatList, err := flatten(list, int64(depth))
if err != nil {
Expand Down Expand Up @@ -260,10 +259,8 @@ func (lib listsLib) CompileOptions() []cel.EnvOption {
}),
cel.SingletonUnaryBinding(
func(arg ref.Val) ref.Val {
list, ok := arg.(traits.Lister)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
// validated by type-guards
list := arg.(traits.Lister)
sorted, err := sortList(list)
if err != nil {
return types.WrapErr(err)
Expand All @@ -287,15 +284,10 @@ func (lib listsLib) CompileOptions() []cel.EnvOption {
)
}),
cel.SingletonBinaryBinding(
func(arg1 ref.Val, arg2 ref.Val) ref.Val {
list, ok := arg1.(traits.Lister)
if !ok {
return types.MaybeNoSuchOverloadErr(arg1)
}
keys, ok := arg2.(traits.Lister)
if !ok {
return types.MaybeNoSuchOverloadErr(arg2)
}
func(arg1, arg2 ref.Val) ref.Val {
// validated by type-guards
list := arg1.(traits.Lister)
keys := arg2.(traits.Lister)
sorted, err := sortListByAssociatedKeys(list, keys)
if err != nil {
return types.WrapErr(err)
Expand Down Expand Up @@ -498,8 +490,9 @@ func sortByMacro(meh cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (as
if targetKind != ast.ListKind &&
targetKind != ast.SelectKind &&
targetKind != ast.IdentKind &&
targetKind != ast.ComprehensionKind && targetKind != ast.CallKind {
return nil, meh.NewError(target.ID(), fmt.Sprintf("sortBy can only be applied to a list, identifier, comprehension, call or select expression"))
targetKind != ast.ComprehensionKind &&
targetKind != ast.CallKind {
return nil, meh.NewError(target.ID(), "sortBy can only be applied to a list, identifier, comprehension, call or select expression")
}

mapCompr, err := parser.MakeMap(meh, meh.Copy(varIdent), args)
Expand Down
113 changes: 113 additions & 0 deletions ext/lists_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"testing"

"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
proto2pb "github.com/google/cel-go/test/proto2pb"
)

Expand Down Expand Up @@ -113,6 +114,118 @@ func TestLists(t *testing.T) {
}
}

func TestListsRuntimeErrors(t *testing.T) {
env, err := cel.NewEnv(Lists(ListsVersion(1)))
if err != nil {
t.Fatalf("cel.NewEnv() failed: %v", err)
}
listsTests := []struct {
expr string
err string
}{
{
expr: "dyn({}).flatten()",
err: "no such overload",
},
{
expr: "dyn({}).flatten(0)",
err: "no such overload",
},
{
expr: "[].flatten(-1)",
err: "level must be non-negative",
},
{
expr: "[].flatten(dyn('1'))",
err: "no such overload",
},
}
for i, tst := range listsTests {
tc := tst
t.Run(fmt.Sprintf("[%d]", i), func(t *testing.T) {
ast, iss := env.Compile(tc.expr)
if iss.Err() != nil {
t.Fatalf("env.Compile(%q) failed: %v", tc.expr, iss.Err())
}
prg, err := env.Program(ast)
if err != nil {
t.Fatalf("env.Program() failed: %v", err)
}
_, _, err = prg.Eval(cel.NoVars())
if err == nil || !strings.Contains(err.Error(), tc.err) {
t.Errorf("prg.Eval() got %v, wanted %v", err, tc.err)
}
})
}
}

func TestListsVersion(t *testing.T) {
versionCases := []struct {
version uint32
supportedFunctions map[string]string
}{
{
version: 0,
supportedFunctions: map[string]string{
"slice": "[1, 2, 3, 4, 5].slice(2, 4) == [3, 4]",
},
},
{
version: 1,
supportedFunctions: map[string]string{
"flatten": "[[1, 2], [3, 4]].flatten() == [1, 2, 3, 4]",
},
},
{
version: 2,
supportedFunctions: map[string]string{
"distinct": "[1, 2, 2, 1].distinct() == [1, 2]",
"range": "lists.range(5) == [0, 1, 2, 3, 4]",
"reverse": "[1, 2, 3].reverse() == [3, 2, 1]",
"sort": "[2, 1, 3].sort() == [1, 2, 3]",
"sortBy": "[{'field': 'lo'}, {'field': 'hi'}].sortBy(m, m.field) == [{'field': 'hi'}, {'field': 'lo'}]",
},
},
}
for _, lib := range versionCases {
env, err := cel.NewEnv(Lists(ListsVersion(lib.version)))
if err != nil {
t.Fatalf("cel.NewEnv(Lists(ListsVersion(%d))) failed: %v", lib.version, err)
}
t.Run(fmt.Sprintf("version=%d", lib.version), func(t *testing.T) {
for _, tc := range versionCases {
for name, expr := range tc.supportedFunctions {
supported := lib.version >= tc.version
t.Run(fmt.Sprintf("%s-supported=%t", name, supported), func(t *testing.T) {
ast, iss := env.Compile(expr)
if supported {
if iss.Err() != nil {
t.Errorf("unexpected error: %v", iss.Err())
}
} else {
if iss.Err() == nil || !strings.Contains(iss.Err().Error(), "undeclared reference") {
t.Errorf("got error %v, wanted error %s for expr: %s, version: %d", iss.Err(), "undeclared reference", expr, tc.version)
}
return
}
prg, err := env.Program(ast)
if err != nil {
t.Fatalf("env.Program() failed: %v", err)
}
out, _, err := prg.Eval(cel.NoVars())
if err != nil {
t.Fatalf("prg.Eval() failed: %v", err)
}
if out != types.True {
t.Errorf("prg.Eval() got %v, wanted true", out)
}
})
}
}
})
}
}

func testListsEnv(t *testing.T, opts ...cel.EnvOption) *cel.Env {
t.Helper()
baseOpts := []cel.EnvOption{
Expand Down
Loading

0 comments on commit 000958d

Please sign in to comment.