From 21976cf0b6a475ecb1329593843351deb4efc3d6 Mon Sep 17 00:00:00 2001 From: Joe Betz Date: Mon, 8 Jul 2024 20:03:32 -0400 Subject: [PATCH] Add function accessor to Env (#978) * Add function accessor to env * Add test to strings verifying that existing versions have a known list of functions --- cel/env.go | 11 ++++++++++ cel/env_test.go | 16 +++++++++++++++ ext/strings_test.go | 50 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 77 insertions(+) diff --git a/cel/env.go b/cel/env.go index 504e5d3d..17c52180 100644 --- a/cel/env.go +++ b/cel/env.go @@ -412,6 +412,17 @@ func (e *Env) Libraries() []string { return libraries } +// HasFunction returns whether a specific function has been configured in the environment +func (e *Env) HasFunction(functionName string) bool { + _, ok := e.functions[functionName] + return ok +} + +// Functions returns map of Functions, keyed by function name, that have been configured in the environment. +func (e *Env) Functions() map[string]*decls.FunctionDecl { + return e.functions +} + // HasValidator returns whether a specific ASTValidator has been configured in the environment. func (e *Env) HasValidator(name string) bool { for _, v := range e.validators { diff --git a/cel/env_test.go b/cel/env_test.go index d089bb90..3c3235c5 100644 --- a/cel/env_test.go +++ b/cel/env_test.go @@ -271,6 +271,22 @@ func TestLibraries(t *testing.T) { } } +func TestFunctions(t *testing.T) { + e, err := NewEnv(OptionalTypes()) + if err != nil { + t.Fatalf("NewEnv() failed: %v", err) + } + for _, expected := range []string{"optional.of", "or"} { + if !e.HasFunction(expected) { + t.Errorf("Expected HasFunction() to return true for '%s'", expected) + } + + if _, ok := e.Functions()[expected]; !ok { + t.Errorf("Expected Functions() to include '%s'", expected) + } + } +} + func BenchmarkNewCustomEnvLazy(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { diff --git a/ext/strings_test.go b/ext/strings_test.go index 564fd10e..8f7e416f 100644 --- a/ext/strings_test.go +++ b/ext/strings_test.go @@ -1746,6 +1746,56 @@ func evalWithCEL(input string, expectedRuntimeCost uint64, expectedEstimatedCost return out.Value().(string) } +func TestFunctionsForVersions(t *testing.T) { + tests := []struct { + version uint32 + introducedFunctions []string + }{ + { + version: 0, + introducedFunctions: []string{"lastIndexOf", "lowerAscii", "split", "trim", "join", "charAt", "indexOf", "replace", "substring", "upperAscii"}, + }, + { + version: 1, + introducedFunctions: []string{"strings.quote", "format"}, + }, + { + version: 2, + introducedFunctions: []string{}, // join changed, no functions added + }, + { + version: 3, + introducedFunctions: []string{"reverse"}, + }, + } + var functions []string + for _, tt := range tests { + functions = append(functions, tt.introducedFunctions...) + t.Run(fmt.Sprintf("version %d", tt.version), func(t *testing.T) { + e, err := cel.NewCustomEnv(Strings(StringsVersion(tt.version))) + if err != nil { + t.Fatalf("NewEnv() failed: %v", err) + } + if len(functions) != len(e.Functions()) { + var functionNames []string + for name := range e.Functions() { + functionNames = append(functionNames, name) + } + t.Fatalf("Expected functions: %#v, got %#v", functions, functionNames) + } + for _, expected := range functions { + if !e.HasFunction(expected) { + t.Errorf("Expected HasFunction() to return true for '%s'", expected) + } + + if _, ok := e.Functions()[expected]; !ok { + t.Errorf("Expected Functions() to include '%s'", expected) + } + } + }) + } +} + func FuzzQuote(f *testing.F) { tests := []string{ "this is a test",