From 31429502dd15de4e25ef63e70775b8c57a77bbde Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Mon, 7 Oct 2024 22:11:31 -0700 Subject: [PATCH] Two-variable comprehension support (#1034) Two-variable comprehensions with support for transformMapEntry --- checker/checker.go | 29 ++- ext/README.md | 139 ++++++++++++- ext/comprehensions.go | 406 +++++++++++++++++++++++++++++++++++++ ext/comprehensions_test.go | 346 +++++++++++++++++++++++++++++++ ext/guards.go | 16 +- ext/lists.go | 1 + parser/helper.go | 34 +++- parser/macro.go | 38 +++- 8 files changed, 983 insertions(+), 26 deletions(-) create mode 100644 ext/comprehensions.go create mode 100644 ext/comprehensions_test.go diff --git a/checker/checker.go b/checker/checker.go index 57fb3ce5..0603cfa3 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -496,16 +496,32 @@ func (c *checker) checkComprehension(e ast.Expr) { comp := e.AsComprehension() c.check(comp.IterRange()) c.check(comp.AccuInit()) - accuType := c.getType(comp.AccuInit()) rangeType := substitute(c.mappings, c.getType(comp.IterRange()), false) - var varType *types.Type + // Create a scope for the comprehension since it has a local accumulation variable. + // This scope will contain the accumulation variable used to compute the result. + accuType := c.getType(comp.AccuInit()) + c.env = c.env.enterScope() + c.env.AddIdents(decls.NewVariable(comp.AccuVar(), accuType)) + + var varType, var2Type *types.Type switch rangeType.Kind() { case types.ListKind: + // varType represents the list element type for one-variable comprehensions. varType = rangeType.Parameters()[0] + if comp.HasIterVar2() { + // varType represents the list index (int) for two-variable comprehensions, + // and var2Type represents the list element type. + var2Type = varType + varType = types.IntType + } case types.MapKind: - // Ranges over the keys. + // varType represents the map entry key for all comprehension types. varType = rangeType.Parameters()[0] + if comp.HasIterVar2() { + // var2Type represents the map entry value for two-variable comprehensions. + var2Type = rangeType.Parameters()[1] + } case types.DynKind, types.ErrorKind, types.TypeParamKind: // Set the range type to DYN to prevent assignment to a potentially incorrect type // at a later point in type-checking. The isAssignable call will update the type @@ -518,13 +534,12 @@ func (c *checker) checkComprehension(e ast.Expr) { varType = types.ErrorType } - // Create a scope for the comprehension since it has a local accumulation variable. - // This scope will contain the accumulation variable used to compute the result. - c.env = c.env.enterScope() - c.env.AddIdents(decls.NewVariable(comp.AccuVar(), accuType)) // Create a block scope for the loop. c.env = c.env.enterScope() c.env.AddIdents(decls.NewVariable(comp.IterVar(), varType)) + if comp.HasIterVar2() { + c.env.AddIdents(decls.NewVariable(comp.IterVar2(), var2Type)) + } // Check the variable references in the condition and step. c.check(comp.LoopCondition()) c.assertType(comp.LoopCondition(), types.BoolType) diff --git a/ext/README.md b/ext/README.md index b6f88a1d..abd70d59 100644 --- a/ext/README.md +++ b/ext/README.md @@ -3,12 +3,12 @@ CEL extensions are a related set of constants, functions, macros, or other features which may not be covered by the core CEL spec. -## Bindings +## Bindings Returns a cel.EnvOption to configure support for local variable bindings in expressions. -# Cel.Bind +### Cel.Bind Binds a simple identifier to an initialization expression which may be used in a subsequenct result expression. Bindings may also be nested within each @@ -19,11 +19,11 @@ other. Examples: cel.bind(a, 'hello', - cel.bind(b, 'world', a + b + b + a)) // "helloworldworldhello" + cel.bind(b, 'world', a + b + b + a)) // "helloworldworldhello" // Avoid a list allocation within the exists comprehension. cel.bind(valid_values, [a, b, c], - [d, e, f].exists(elem, elem in valid_values)) + [d, e, f].exists(elem, elem in valid_values)) Local bindings are not guaranteed to be evaluated before use. @@ -684,3 +684,134 @@ Examples: 'gums'.reverse() // returns 'smug' 'John Smith'.reverse() // returns 'htimS nhoJ' + +## TwoVarComprehensions + +TwoVarComprehensions introduces support for two-variable comprehensions. + +The two-variable form of comprehensions looks similar to the one-variable +counterparts. Where possible, the same macro names were used and additional +macro signatures added. The notable distinction for two-variable comprehensions +is the introduction of `transformList`, `transformMap`, and `transformMapEntry` +support for list and map types rather than the more traditional `map` and +`filter` macros. + +### All + +Comprehension which tests whether all elements in the list or map satisfy a +given predicate. The `all` macro evaluates in a manner consistent with logical +AND and will short-circuit when encountering a `false` value. + + .all(indexVar, valueVar, ) -> bool + .all(keyVar, valueVar, ) -> bool + +Examples: + + [1, 2, 3].all(i, j, i < j) // returns true + {'hello': 'world', 'taco': 'taco'}.all(k, v, k != v) // returns false + + // Combines two-variable comprehension with single variable + {'h': ['hello', 'hi'], 'j': ['joke', 'jog']} + .all(k, vals, vals.all(v, v.startsWith(k))) // returns true + +### Exists + +Comprehension which tests whether any element in a list or map exists which +satisfies a given predicate. The `exists` macro evaluates in a manner consistent +with logical OR and will short-circuit when encountering a `true` value. + + .exists(indexVar, valueVar, ) -> bool + .exists(keyVar, valueVar, ) -> bool + +Examples: + + {'greeting': 'hello', 'farewell': 'goodbye'} + .exists(k, v, k.startsWith('good') || v.endsWith('bye')) // returns true + [1, 2, 4, 8, 16].exists(i, v, v == 1024 && i == 10) // returns false + +### ExistsOne + +Comprehension which tests whether exactly one element in a list or map exists +which satisfies a given predicate expression. This comprehension does not +short-circuit in keeping with the one-variable exists one macro semantics. + + .existsOne(indexVar, valueVar, ) + .existsOne(keyVar, valueVar, ) + +This macro may also be used with the `exists_one` function name, for +compatibility with the one-variable macro of the same name. + +Examples: + + [1, 2, 1, 3, 1, 4].existsOne(i, v, i == 1 || v == 1) // returns false + [1, 1, 2, 2, 3, 3].existsOne(i, v, i == 2 && v == 2) // returns true + {'i': 0, 'j': 1, 'k': 2}.existsOne(i, v, i == 'l' || v == 1) // returns true + +### TransformList + +Comprehension which converts a map or a list into a list value. The output +expression of the comprehension determines the contents of the output list. +Elements in the list may optionally be filtered according to a predicate +expression, where elements that satisfy the predicate are transformed. + + .transformList(indexVar, valueVar, ) + .transformList(indexVar, valueVar, , ) + .transformList(keyVar, valueVar, ) + .transformList(keyVar, valueVar, , ) + +Examples: + + [1, 2, 3].transformList(indexVar, valueVar, + (indexVar * valueVar) + valueVar) // returns [1, 4, 9] + [1, 2, 3].transformList(indexVar, valueVar, indexVar % 2 == 0 + (indexVar * valueVar) + valueVar) // returns [1, 9] + {'greeting': 'hello', 'farewell': 'goodbye'} + .transformList(k, _, k) // returns ['greeting', 'farewell'] + {'greeting': 'hello', 'farewell': 'goodbye'} + .transformList(_, v, v) // returns ['hello', 'goodbye'] + +### TransformMap + +Comprehension which converts a map or a list into a map value. The output +expression of the comprehension determines the value of the output map entry; +however, the key remains fixed. Elements in the map may optionally be filtered +according to a predicate expression, where elements that satisfy the predicate +are transformed. + + .transformMap(indexVar, valueVar, ) + .transformMap(indexVar, valueVar, , ) + .transformMap(keyVar, valueVar, ) + .transformMap(keyVar, valueVar, , ) + +Examples: + + [1, 2, 3].transformMap(indexVar, valueVar, + (indexVar * valueVar) + valueVar) // returns {0: 1, 1: 4, 2: 9} + [1, 2, 3].transformMap(indexVar, valueVar, indexVar % 2 == 0 + (indexVar * valueVar) + valueVar) // returns {0: 1, 2: 9} + {'greeting': 'hello'}.transformMap(k, v, v + '!') // returns {'greeting': 'hello!'} + +### TransformMapEntry + +Comprehension which converts a map or a list into a map value; however, this +transform expects the entry expression be a map literal. If the transform +produces an entry which duplicates a key in the target map, the comprehension +will error. + +Elements in the map may optionally be filtered according to a predicate +expression, where elements that satisfy the predicate are transformed. + + .transformMap(indexVar, valueVar, ) + .transformMap(indexVar, valueVar, , ) + .transformMap(keyVar, valueVar, ) + .transformMap(keyVar, valueVar, , ) + +Examples: + + // returns {'hello': 'greeting'} + {'greeting': 'hello'}.transformMapEntry(keyVar, valueVar, {valueVar: keyVar}) + // reverse lookup, require all values in list be unique + [1, 2, 3].transformMapEntry(indexVar, valueVar, {valueVar: indexVar}) + + {'greeting': 'aloha', 'farewell': 'aloha'} + .transformMapEntry(keyVar, valueVar, {valueVar: keyVar}) // error, duplicate key diff --git a/ext/comprehensions.go b/ext/comprehensions.go new file mode 100644 index 00000000..db9316e1 --- /dev/null +++ b/ext/comprehensions.go @@ -0,0 +1,406 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ext + +import ( + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/operators" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/common/types/traits" + "github.com/google/cel-go/parser" +) + +const ( + mapInsert = "cel.@mapInsert" + mapInsertOverloadMap = "@mapInsert_map_map" + mapInsertOverloadKeyValue = "@mapInsert_map_key_value" +) + +// TwoVarComprehensions introduces support for two-variable comprehensions. +// +// The two-variable form of comprehensions looks similar to the one-variable counterparts. +// Where possible, the same macro names were used and additional macro signatures added. +// The notable distinction for two-variable comprehensions is the introduction of +// `transformList`, `transformMap`, and `transformMapEntry` support for list and map types +// rather than the more traditional `map` and `filter` macros. +// +// # All +// +// Comprehension which tests whether all elements in the list or map satisfy a given +// predicate. The `all` macro evaluates in a manner consistent with logical AND and will +// short-circuit when encountering a `false` value. +// +// .all(indexVar, valueVar, ) -> bool +// .all(keyVar, valueVar, ) -> bool +// +// Examples: +// +// [1, 2, 3].all(i, j, i < j) // returns true +// {'hello': 'world', 'taco': 'taco'}.all(k, v, k != v) // returns false +// +// // Combines two-variable comprehension with single variable +// {'h': ['hello', 'hi'], 'j': ['joke', 'jog']} +// .all(k, vals, vals.all(v, v.startsWith(k))) // returns true +// +// # Exists +// +// Comprehension which tests whether any element in a list or map exists which satisfies +// a given predicate. The `exists` macro evaluates in a manner consistent with logical OR +// and will short-circuit when encountering a `true` value. +// +// .exists(indexVar, valueVar, ) -> bool +// .exists(keyVar, valueVar, ) -> bool +// +// Examples: +// +// {'greeting': 'hello', 'farewell': 'goodbye'} +// .exists(k, v, k.startsWith('good') || v.endsWith('bye')) // returns true +// [1, 2, 4, 8, 16].exists(i, v, v == 1024 && i == 10) // returns false +// +// # ExistsOne +// +// Comprehension which tests whether exactly one element in a list or map exists which +// satisfies a given predicate expression. This comprehension does not short-circuit in +// keeping with the one-variable exists one macro semantics. +// +// .existsOne(indexVar, valueVar, ) +// .existsOne(keyVar, valueVar, ) +// +// This macro may also be used with the `exists_one` function name, for compatibility +// with the one-variable macro of the same name. +// +// Examples: +// +// [1, 2, 1, 3, 1, 4].existsOne(i, v, i == 1 || v == 1) // returns false +// [1, 1, 2, 2, 3, 3].existsOne(i, v, i == 2 && v == 2) // returns true +// {'i': 0, 'j': 1, 'k': 2}.existsOne(i, v, i == 'l' || v == 1) // returns true +// +// # TransformList +// +// Comprehension which converts a map or a list into a list value. The output expression +// of the comprehension determines the contents of the output list. Elements in the list +// may optionally be filtered according to a predicate expression, where elements that +// satisfy the predicate are transformed. +// +// .transformList(indexVar, valueVar, ) +// .transformList(indexVar, valueVar, , ) +// .transformList(keyVar, valueVar, ) +// .transformList(keyVar, valueVar, , ) +// +// Examples: +// +// [1, 2, 3].transformList(indexVar, valueVar, +// (indexVar * valueVar) + valueVar) // returns [1, 4, 9] +// [1, 2, 3].transformList(indexVar, valueVar, indexVar % 2 == 0 +// (indexVar * valueVar) + valueVar) // returns [1, 9] +// {'greeting': 'hello', 'farewell': 'goodbye'} +// .transformList(k, _, k) // returns ['greeting', 'farewell'] +// {'greeting': 'hello', 'farewell': 'goodbye'} +// .transformList(_, v, v) // returns ['hello', 'goodbye'] +// +// # TransformMap +// +// Comprehension which converts a map or a list into a map value. The output expression +// of the comprehension determines the value of the output map entry; however, the key +// remains fixed. Elements in the map may optionally be filtered according to a predicate +// expression, where elements that satisfy the predicate are transformed. +// +// .transformMap(indexVar, valueVar, ) +// .transformMap(indexVar, valueVar, , ) +// .transformMap(keyVar, valueVar, ) +// .transformMap(keyVar, valueVar, , ) +// +// Examples: +// +// [1, 2, 3].transformMap(indexVar, valueVar, +// (indexVar * valueVar) + valueVar) // returns {0: 1, 1: 4, 2: 9} +// [1, 2, 3].transformMap(indexVar, valueVar, indexVar % 2 == 0 +// (indexVar * valueVar) + valueVar) // returns {0: 1, 2: 9} +// {'greeting': 'hello'}.transformMap(k, v, v + '!') // returns {'greeting': 'hello!'} +// +// # TransformMapEntry +// +// Comprehension which converts a map or a list into a map value; however, this transform +// expects the entry expression be a map literal. If the tranform produces an entry which +// duplicates a key in the target map, the comprehension will error. +// +// Elements in the map may optionally be filtered according to a predicate expression, where +// elements that satisfy the predicate are transformed. +// +// .transformMap(indexVar, valueVar, ) +// .transformMap(indexVar, valueVar, , ) +// .transformMap(keyVar, valueVar, ) +// .transformMap(keyVar, valueVar, , ) +// +// Examples: +// +// // returns {'hello': 'greeting'} +// {'greeting': 'hello'}.transformMapEntry(keyVar, valueVar, {valueVar: keyVar}) +// // reverse lookup, require all values in list be unique +// [1, 2, 3].transformMapEntry(indexVar, valueVar, {valueVar: indexVar}) +// +// {'greeting': 'aloha', 'farewell': 'aloha'} +// .transformMapEntry(keyVar, valueVar, {valueVar: keyVar}) // error, duplicate key +func TwoVarComprehensions() cel.EnvOption { + return cel.Lib(compreV2Lib{}) +} + +type compreV2Lib struct{} + +// LibraryName implements that SingletonLibrary interface method. +func (compreV2Lib) LibraryName() string { + return "cel.lib.ext.comprev2" +} + +// CompileOptions implements the cel.Library interface method. +func (compreV2Lib) CompileOptions() []cel.EnvOption { + kType := cel.TypeParamType("K") + vType := cel.TypeParamType("V") + mapKVType := cel.MapType(kType, vType) + opts := []cel.EnvOption{ + cel.Macros( + cel.ReceiverMacro("all", 3, quantifierAll), + cel.ReceiverMacro("exists", 3, quantifierExists), + cel.ReceiverMacro("existsOne", 3, quantifierExistsOne), + cel.ReceiverMacro("exists_one", 3, quantifierExistsOne), + cel.ReceiverMacro("transformList", 3, transformList), + cel.ReceiverMacro("transformList", 4, transformList), + cel.ReceiverMacro("transformMap", 3, transformMap), + cel.ReceiverMacro("transformMap", 4, transformMap), + cel.ReceiverMacro("transformMapEntry", 3, transformMapEntry), + cel.ReceiverMacro("transformMapEntry", 4, transformMapEntry), + ), + cel.Function(mapInsert, + cel.Overload(mapInsertOverloadKeyValue, []*cel.Type{mapKVType, kType, vType}, mapKVType, + cel.FunctionBinding(func(args ...ref.Val) ref.Val { + m := args[0].(traits.Mapper) + k := args[1] + v := args[2] + return types.InsertMapKeyValue(m, k, v) + })), + cel.Overload(mapInsertOverloadMap, []*cel.Type{mapKVType, mapKVType}, mapKVType, + cel.BinaryBinding(func(targetMap, updateMap ref.Val) ref.Val { + tm := targetMap.(traits.Mapper) + um := updateMap.(traits.Mapper) + umIt := um.Iterator() + for umIt.HasNext() == types.True { + k := umIt.Next() + updateOrErr := types.InsertMapKeyValue(tm, k, um.Get(k)) + if types.IsError(updateOrErr) { + return updateOrErr + } + tm = updateOrErr.(traits.Mapper) + } + return tm + })), + ), + } + return opts +} + +// ProgramOptions implements the cel.Library interface method +func (compreV2Lib) ProgramOptions() []cel.ProgramOption { + return []cel.ProgramOption{} +} + +func quantifierAll(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) { + iterVar1, err := extractIterVar(mef, args[0]) + if err != nil { + return nil, err + } + iterVar2, err := extractIterVar(mef, args[1]) + if err != nil { + return nil, err + } + return mef.NewComprehensionTwoVar( + target, + iterVar1, + iterVar2, + parser.AccumulatorName, + /*accuInit=*/ mef.NewLiteral(types.True), + /*condition=*/ mef.NewCall(operators.NotStrictlyFalse, mef.NewAccuIdent()), + /*step=*/ mef.NewCall(operators.LogicalAnd, mef.NewAccuIdent(), args[2]), + /*result=*/ mef.NewAccuIdent(), + ), nil +} + +func quantifierExists(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) { + iterVar1, err := extractIterVar(mef, args[0]) + if err != nil { + return nil, err + } + iterVar2, err := extractIterVar(mef, args[1]) + if err != nil { + return nil, err + } + return mef.NewComprehensionTwoVar( + target, + iterVar1, + iterVar2, + parser.AccumulatorName, + /*accuInit=*/ mef.NewLiteral(types.False), + /*condition=*/ mef.NewCall(operators.NotStrictlyFalse, mef.NewCall(operators.LogicalNot, mef.NewAccuIdent())), + /*step=*/ mef.NewCall(operators.LogicalOr, mef.NewAccuIdent(), args[2]), + /*result=*/ mef.NewAccuIdent(), + ), nil +} + +func quantifierExistsOne(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) { + iterVar1, err := extractIterVar(mef, args[0]) + if err != nil { + return nil, err + } + iterVar2, err := extractIterVar(mef, args[1]) + if err != nil { + return nil, err + } + return mef.NewComprehensionTwoVar( + target, + iterVar1, + iterVar2, + parser.AccumulatorName, + /*accuInit=*/ mef.NewLiteral(types.Int(0)), + /*condition=*/ mef.NewLiteral(types.True), + /*step=*/ mef.NewCall(operators.Conditional, args[2], + mef.NewCall(operators.Add, mef.NewAccuIdent(), mef.NewLiteral(types.Int(1))), + mef.NewAccuIdent()), + /*result=*/ mef.NewCall(operators.Equals, mef.NewAccuIdent(), mef.NewLiteral(types.Int(1))), + ), nil +} + +func transformList(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) { + iterVar1, err := extractIterVar(mef, args[0]) + if err != nil { + return nil, err + } + iterVar2, err := extractIterVar(mef, args[1]) + if err != nil { + return nil, err + } + + var transform ast.Expr + var filter ast.Expr + if len(args) == 4 { + filter = args[2] + transform = args[3] + } else { + filter = nil + transform = args[2] + } + + // __result__ = __result__ + [transform] + step := mef.NewCall(operators.Add, mef.NewAccuIdent(), mef.NewList(transform)) + if filter != nil { + // __result__ = (filter) ? __result__ + [transform] : __result__ + step = mef.NewCall(operators.Conditional, filter, step, mef.NewAccuIdent()) + } + + return mef.NewComprehensionTwoVar( + target, + iterVar1, + iterVar2, + parser.AccumulatorName, + /*accuInit=*/ mef.NewList(), + /*condition=*/ mef.NewLiteral(types.True), + step, + /*result=*/ mef.NewAccuIdent(), + ), nil +} + +func transformMap(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) { + iterVar1, err := extractIterVar(mef, args[0]) + if err != nil { + return nil, err + } + iterVar2, err := extractIterVar(mef, args[1]) + if err != nil { + return nil, err + } + + var transform ast.Expr + var filter ast.Expr + if len(args) == 4 { + filter = args[2] + transform = args[3] + } else { + filter = nil + transform = args[2] + } + + // __result__ = cel.@mapInsert(__result__, iterVar1, transform) + step := mef.NewCall(mapInsert, mef.NewAccuIdent(), mef.NewIdent(iterVar1), transform) + if filter != nil { + // __result__ = (filter) ? cel.@mapInsert(__result__, iterVar1, transform) : __result__ + step = mef.NewCall(operators.Conditional, filter, step, mef.NewAccuIdent()) + } + return mef.NewComprehensionTwoVar( + target, + iterVar1, + iterVar2, + parser.AccumulatorName, + /*accuInit=*/ mef.NewMap(), + /*condition=*/ mef.NewLiteral(types.True), + step, + /*result=*/ mef.NewAccuIdent(), + ), nil +} + +func transformMapEntry(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) { + iterVar1, err := extractIterVar(mef, args[0]) + if err != nil { + return nil, err + } + iterVar2, err := extractIterVar(mef, args[1]) + if err != nil { + return nil, err + } + + var transform ast.Expr + var filter ast.Expr + if len(args) == 4 { + filter = args[2] + transform = args[3] + } else { + filter = nil + transform = args[2] + } + + // __result__ = cel.@mapInsert(__result__, transform) + step := mef.NewCall(mapInsert, mef.NewAccuIdent(), transform) + if filter != nil { + // __result__ = (filter) ? cel.@mapInsert(__result__, transform) : __result__ + step = mef.NewCall(operators.Conditional, filter, step, mef.NewAccuIdent()) + } + return mef.NewComprehensionTwoVar( + target, + iterVar1, + iterVar2, + parser.AccumulatorName, + /*accuInit=*/ mef.NewMap(), + /*condition=*/ mef.NewLiteral(types.True), + step, + /*result=*/ mef.NewAccuIdent(), + ), nil +} + +func extractIterVar(meh cel.MacroExprFactory, target ast.Expr) (string, *cel.Error) { + iterVar, found := extractIdent(target) + if !found { + return "", meh.NewError(target.ID(), "argument must be a simple name") + } + return iterVar, nil +} diff --git a/ext/comprehensions_test.go b/ext/comprehensions_test.go new file mode 100644 index 00000000..7f257859 --- /dev/null +++ b/ext/comprehensions_test.go @@ -0,0 +1,346 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ext + +import ( + "fmt" + "strings" + "testing" + + "github.com/google/cel-go/cel" +) + +func TestTwoVarComprehensions(t *testing.T) { + compreTests := []struct { + expr string + }{ + // list.all() + {expr: "[1, 2, 3, 4].all(i, v, i < 5 && v > 0)"}, + {expr: "[1, 2, 3, 4].all(i, v, i < v)"}, + {expr: "[1, 2, 3, 4].all(i, v, i > v) == false"}, + {expr: ` + cel.bind(listA, [1, 2, 3, 4], + cel.bind(listB, [1, 2, 3, 4, 5], + listA.all(i, v, listB[?i].hasValue() && listB[i] == v) + )) + `}, + {expr: ` + cel.bind(listA, [1, 2, 3, 4, 5, 6], + cel.bind(listB, [1, 2, 3, 4, 5], + listA.all(i, v, listB[?i].hasValue() && listB[i] == v) + )) == false + `}, + // list.exists() + {expr: ` + cel.bind(l, ['hello', 'world', 'hello!', 'worlds'], + l.exists(i, v, + v.startsWith('hello') && l[?(i+1)].optMap(next, next.endsWith('world')).orValue(false) + ) + ) + `}, + // list.existsOne() + {expr: ` + cel.bind(l, ['hello', 'world', 'hello!', 'worlds'], + l.existsOne(i, v, + v.startsWith('hello') && l[?(i+1)].optMap(next, next.endsWith('world')).orValue(false) + ) + ) + `}, + {expr: ` + cel.bind(l, ['hello', 'goodbye', 'hello!', 'goodbye'], + l.exists_one(i, v, + v.startsWith('hello') && l[?(i+1)].optMap(next, next == "goodbye").orValue(false) + ) + ) == false + `}, + // list.transformList() + {expr: ` + ['Hello', 'world'].transformList(i, v, "[%d]%s".format([i, v.lowerAscii()])) == ["[0]hello", "[1]world"] + `}, + {expr: ` + ['hello', 'world'].transformList(i, v, v.startsWith('greeting'), "[%d]%s".format([i, v])) == [] + `}, + {expr: ` + [1, 2, 3].transformList(indexVar, valueVar, (indexVar * valueVar) + valueVar) == [1, 4, 9] + `}, + {expr: ` + [1, 2, 3].transformList(indexVar, valueVar, indexVar % 2 == 0, (indexVar * valueVar) + valueVar) == [1, 9] + `}, + // list.transformMap() + {expr: ` + ['Hello', 'world'].transformMap(i, v, [v.lowerAscii()]) == {0: ['hello'], 1: ['world']} + `}, + {expr: ` + // round-tripping example + ['world', 'Hello'].transformMap(i, v, [v.lowerAscii()]) + .transformList(k, v, v) // extract the list back form the map + .flatten() + .sort() == ['hello', 'world'] + `}, + {expr: ` + [1, 2, 3].transformMap(indexVar, valueVar, + (indexVar * valueVar) + valueVar) == {0: 1, 1: 4, 2: 9} + `}, + {expr: ` + [1, 2, 3].transformMap(indexVar, valueVar, indexVar % 2 == 0, + (indexVar * valueVar) + valueVar) == {0: 1, 2: 9} + `}, + // list.transformMapEntry() + {expr: ` + "key1:value1 key2:value2 key3:value3".split(" ") + .transformMapEntry(i, v, + cel.bind(entry, v.split(":"), + entry.size() == 2 ? {entry[0]: entry[1]} : {} + ) + ) == {'key1': 'value1', 'key2': 'value2', 'key3': 'value3'} + `}, + {expr: ` + "key1:value1:extra key2:value2 key3".split(" ") + .transformMapEntry(i, v, + cel.bind(entry, v.split(":"), {?entry[0]: entry[?1]}) + ) == {'key1': 'value1', 'key2': 'value2'} + `}, + // map.all() + {expr: ` + {'hello': 'world', 'hello!': 'world'}.all(k, v, k.startsWith('hello') && v == 'world') + `}, + {expr: ` + {'hello': 'world', 'hello!': 'worlds'}.all(k, v, k.startsWith('hello') && v.endsWith('world')) == false + `}, + // map.exists() + {expr: ` + {'hello': 'world', 'hello!': 'worlds'}.exists(k, v, k.startsWith('hello') && v.endsWith('world')) + `}, + // map.existsOne() + {expr: ` + {'hello': 'world', 'hello!': 'worlds'}.existsOne(k, v, k.startsWith('hello') && v.endsWith('world')) + `}, + // map.exists_one() + {expr: ` + {'hello': 'world', 'hello!': 'worlds'}.exists_one(k, v, k.startsWith('hello') && v.endsWith('world')) + `}, + {expr: ` + {'hello': 'world', 'hello!': 'wow, world'}.exists_one(k, v, k.startsWith('hello') && v.endsWith('world')) == false + `}, + // map.transformList() + {expr: ` + {'Hello': 'world'}.transformList(k, v, "%s=%s".format([k.lowerAscii(), v])) == ["hello=world"] + `}, + {expr: ` + {'hello': 'world'}.transformList(k, v, k.startsWith('greeting'), "%s=%s".format([k, v])) == [] + `}, + {expr: ` + {'greeting': 'hello', 'farewell': 'goodbye'} + .transformList(k, _, k).sort() == ['farewell', 'greeting'] + `}, + {expr: ` + {'greeting': 'hello', 'farewell': 'goodbye'} + .transformList(_, v, v).sort() == ['goodbye', 'hello'] + `}, + // map.transformMap() + {expr: ` + {'hello': 'world', 'goodbye': 'cruel world'}.transformMap(k, v, "%s, %s!".format([k, v])) + == {'hello': 'hello, world!', 'goodbye': 'goodbye, cruel world!'} + `}, + {expr: ` + {'hello': 'world', 'goodbye': 'cruel world'}.transformMap(k, v, v.startsWith('world'), "%s, %s!".format([k, v])) + == {'hello': 'hello, world!'} + `}, + // map.transformMapEntry() + {expr: ` + {'hello': 'world', 'greetings': 'tacocat'}.transformMapEntry(k, v, {k.reverse(): v.reverse()}) + == {'olleh': 'dlrow', 'sgniteerg': 'tacocat'} + `}, + {expr: ` + {'hello': 'world', 'greetings': 'tacocat'}.transformMapEntry(k, v, v.reverse() == v, {k.reverse(): v.reverse()}) + == {'sgniteerg': 'tacocat'} + `}, + {expr: ` + {'hello': 'world', 'greetings': 'tacocat'}.transformMapEntry(k, v, {}) == {} + `}, + } + + env := testCompreEnv(t) + for i, tst := range compreTests { + tc := tst + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + var asts []*cel.Ast + pAst, iss := env.Parse(tc.expr) + if iss.Err() != nil { + t.Fatalf("env.Parse(%v) failed: %v", tc.expr, iss.Err()) + } + asts = append(asts, pAst) + cAst, iss := env.Check(pAst) + if iss.Err() != nil { + t.Fatalf("env.Check(%v) failed: %v", tc.expr, iss.Err()) + } + asts = append(asts, cAst) + + for _, ast := range asts { + prg, err := env.Program(ast) + if err != nil { + t.Fatalf("env.Program() failed: %v", err) + } + out, _, err := prg.Eval(cel.NoVars()) + if err != nil { + t.Fatalf("prg.Eval() failed: %v", err) + } + if out.Value() != true { + t.Errorf("prg.Eval() got %v, wanted true for expr: %s", out.Value(), tc.expr) + } + } + }) + } +} + +func TestTwoVarComprehensionsStaticErrors(t *testing.T) { + tests := []struct { + expr string + err string + }{ + { + expr: "[].all(i.j, k, i.j < k)", + err: "argument must be a simple name", + }, + { + expr: "[].all(j, i.k, j < i.k)", + err: "argument must be a simple name", + }, + { + expr: "1.all(j, k, j < k)", + err: "cannot be range", + }, + { + expr: "[].exists(i.j, k, i.j < k)", + err: "argument must be a simple name", + }, + { + expr: "[].exists(j, i.k, j < i.k)", + err: "argument must be a simple name", + }, + { + expr: "''.exists(j, k, j < k)", + err: "cannot be range", + }, + { + expr: "[].exists_one(i.j, k, i.j < k)", + err: "argument must be a simple name", + }, + { + expr: "[].existsOne(j, i.k, j < i.k)", + err: "argument must be a simple name", + }, + { + expr: "[].exists_one(i.j, k, i.j < k)", + err: "argument must be a simple name", + }, + { + expr: "''.existsOne(j, k, j < k)", + err: "cannot be range", + }, + { + expr: "[].transformList(i.j, k, i.j + k)", + err: "argument must be a simple name", + }, + { + expr: "[].transformList(j, i.k, j + i.k)", + err: "argument must be a simple name", + }, + { + expr: "{}.transformMap(i.j, k, i.j + k)", + err: "argument must be a simple name", + }, + { + expr: "{}.transformMap(j, i.k, j + i.k)", + err: "argument must be a simple name", + }, + { + expr: "{}.transformMapEntry(j, i.k, {j: i.k})", + err: "argument must be a simple name", + }, + { + expr: "{}.transformMapEntry(i.j, k, {k: i.j})", + err: "argument must be a simple name", + }, + { + expr: "{}.transformMapEntry(j, k, 'bad filter', {k: j})", + err: "no matching overload", + }, + { + expr: "[1, 2].transformList(i, v, v % 2 == 0 ? [v] : v)", + err: "no matching overload", + }, + { + expr: `{'hello': 'world', 'greetings': 'tacocat'}.transformMapEntry(k, v, []) == {}`, + err: "no matching overload"}, + } + env := testCompreEnv(t) + for i, tst := range tests { + tc := tst + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + _, iss := env.Compile(tc.expr) + if iss.Err() == nil || !strings.Contains(iss.Err().Error(), tc.err) { + t.Errorf("env.Compile(%q) got %v, wanted error %v", tc.expr, iss.Err(), tc.err) + } + }) + } +} + +func TestTwoVarComprehensionsRuntimeErrors(t *testing.T) { + tests := []struct { + expr string + err string + }{ + { + expr: "[1, 1].transformMapEntry(i, v, {v: i})", + err: "insert failed: key 1 already exists", + }, + } + env := testCompreEnv(t) + for i, tst := range tests { + tc := tst + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + ast, iss := env.Compile(tc.expr) + if iss.Err() != nil { + t.Fatalf("env.Compile(%q) failed with error %v", tc.expr, iss.Err()) + } + prg, err := env.Program(ast) + if err != nil { + t.Fatalf("env.Program(ast) failed: %v", err) + } + in := cel.NoVars() + _, _, err = prg.Eval(in) + if err == nil || !strings.Contains(err.Error(), tc.err) { + t.Errorf("prg.Eval() got %v, wanted %v", err, tc.err) + } + }) + } +} + +func testCompreEnv(t *testing.T, opts ...cel.EnvOption) *cel.Env { + t.Helper() + baseOpts := []cel.EnvOption{ + TwoVarComprehensions(), + Bindings(), + Lists(), + Strings(), + cel.OptionalTypes(), + cel.EnableMacroCallTracking()} + env, err := cel.NewEnv(append(baseOpts, opts...)...) + if err != nil { + t.Fatalf("cel.NewEnv(TwoVarComprehensions()) failed: %v", err) + } + return env +} diff --git a/ext/guards.go b/ext/guards.go index 2c00bfe3..ccede289 100644 --- a/ext/guards.go +++ b/ext/guards.go @@ -50,14 +50,18 @@ func listStringOrError(strs []string, err error) ref.Val { return types.DefaultTypeAdapter.NativeToValue(strs) } -func macroTargetMatchesNamespace(ns string, target ast.Expr) bool { +func extractIdent(target ast.Expr) (string, bool) { switch target.Kind() { case ast.IdentKind: - if target.AsIdent() != ns { - return false - } - return true + return target.AsIdent(), true default: - return false + return "", false + } +} + +func macroTargetMatchesNamespace(ns string, target ast.Expr) bool { + if id, found := extractIdent(target); found { + return id == ns } + return false } diff --git a/ext/lists.go b/ext/lists.go index aa964ac4..ff4546de 100644 --- a/ext/lists.go +++ b/ext/lists.go @@ -39,6 +39,7 @@ var comparableTypes = []*cel.Type{ // Lists returns a cel.EnvOption to configure extended functions for list manipulation. // As a general note, all indices are zero-based. +// // # Slice // // Returns a new sub-list using the indexes provided. diff --git a/parser/helper.go b/parser/helper.go index 96748358..9f09ead0 100644 --- a/parser/helper.go +++ b/parser/helper.go @@ -115,7 +115,7 @@ func (p *parserHelper) newObjectField(fieldID int64, field string, value ast.Exp func (p *parserHelper) newComprehension(ctx any, iterRange ast.Expr, - iterVar string, + iterVar, accuVar string, accuInit ast.Expr, condition ast.Expr, @@ -125,6 +125,18 @@ func (p *parserHelper) newComprehension(ctx any, p.newID(ctx), iterRange, iterVar, accuVar, accuInit, condition, step, result) } +func (p *parserHelper) newComprehensionTwoVar(ctx any, + iterRange ast.Expr, + iterVar, iterVar2, + accuVar string, + accuInit ast.Expr, + condition ast.Expr, + step ast.Expr, + result ast.Expr) ast.Expr { + return p.exprFactory.NewComprehensionTwoVar( + p.newID(ctx), iterRange, iterVar, iterVar2, accuVar, accuInit, condition, step, result) +} + func (p *parserHelper) newID(ctx any) int64 { if id, isID := ctx.(int64); isID { return id @@ -383,8 +395,10 @@ func (e *exprHelper) Copy(expr ast.Expr) ast.Expr { cond := e.Copy(compre.LoopCondition()) step := e.Copy(compre.LoopStep()) result := e.Copy(compre.Result()) - return e.exprFactory.NewComprehension(copyID, - iterRange, compre.IterVar(), compre.AccuVar(), accuInit, cond, step, result) + // All comprehensions can be represented by the two-variable comprehension since the + // differentiation between one and two-variable is whether the iterVar2 value is non-empty. + return e.exprFactory.NewComprehensionTwoVar(copyID, + iterRange, compre.IterVar(), compre.IterVar2(), compre.AccuVar(), accuInit, cond, step, result) } return e.exprFactory.NewUnspecifiedExpr(copyID) } @@ -432,6 +446,20 @@ func (e *exprHelper) NewComprehension( e.nextMacroID(), iterRange, iterVar, accuVar, accuInit, condition, step, result) } +// NewComprehensionTwoVar implements the ExprHelper interface method. +func (e *exprHelper) NewComprehensionTwoVar( + iterRange ast.Expr, + iterVar, + iterVar2, + accuVar string, + accuInit, + condition, + step, + result ast.Expr) ast.Expr { + return e.exprFactory.NewComprehensionTwoVar( + e.nextMacroID(), iterRange, iterVar, iterVar2, accuVar, accuInit, condition, step, result) +} + // NewIdent implements the ExprHelper interface method. func (e *exprHelper) NewIdent(name string) ast.Expr { return e.exprFactory.NewIdent(e.nextMacroID(), name) diff --git a/parser/macro.go b/parser/macro.go index c1936b69..dc47b420 100644 --- a/parser/macro.go +++ b/parser/macro.go @@ -170,11 +170,12 @@ type ExprHelper interface { // NewStructField creates a new struct field initializer from the field name and value. NewStructField(field string, init ast.Expr, optional bool) ast.EntryExpr - // NewComprehension creates a new comprehension instruction. + // NewComprehension creates a new one-variable comprehension instruction. // // - iterRange represents the expression that resolves to a list or map where the elements or // keys (respectively) will be iterated over. - // - iterVar is the iteration variable name. + // - iterVar is the variable name for the list element value, or the map key, depending on the + // range type. // - accuVar is the accumulation variable name, typically parser.AccumulatorName. // - accuInit is the initial expression whose value will be set for the accuVar prior to // folding. @@ -186,11 +187,36 @@ type ExprHelper interface { // environment in the step and condition expressions. Presently, the name __result__ is commonly // used by built-in macros but this may change in the future. NewComprehension(iterRange ast.Expr, - iterVar string, + iterVar, accuVar string, - accuInit ast.Expr, - condition ast.Expr, - step ast.Expr, + accuInit, + condition, + step, + result ast.Expr) ast.Expr + + // NewComprehensionTwoVar creates a new two-variable comprehension instruction. + // + // - iterRange represents the expression that resolves to a list or map where the elements or + // keys (respectively) will be iterated over. + // - iterVar is the iteration variable assigned to the list index or the map key. + // - iterVar2 is the iteration variable assigned to the list element value or the map key value. + // - accuVar is the accumulation variable name, typically parser.AccumulatorName. + // - accuInit is the initial expression whose value will be set for the accuVar prior to + // folding. + // - condition is the expression to test to determine whether to continue folding. + // - step is the expression to evaluation at the conclusion of a single fold iteration. + // - result is the computation to evaluate at the conclusion of the fold. + // + // The accuVar should not shadow variable names that you would like to reference within the + // environment in the step and condition expressions. Presently, the name __result__ is commonly + // used by built-in macros but this may change in the future. + NewComprehensionTwoVar(iterRange ast.Expr, + iterVar, + iterVar2, + accuVar string, + accuInit, + condition, + step, result ast.Expr) ast.Expr // NewIdent creates an identifier Expr value.