From aaf917ec4937cf8422952bc147536135d1241397 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Mon, 14 Apr 2025 15:00:03 -0700 Subject: [PATCH] Adding mapInsert internal runtime function PiperOrigin-RevId: 747580855 --- .../main/java/dev/cel/extensions/BUILD.bazel | 20 ++ .../dev/cel/extensions/CelComprehensions.java | 211 ++++++++++++++++++ .../dev/cel/extensions/CelExtensions.java | 8 +- .../cel/extensions/CelComprehensionsTest.java | 138 ++++++++++++ .../dev/cel/extensions/CelExtensionsTest.java | 3 +- 5 files changed, 378 insertions(+), 2 deletions(-) create mode 100644 extensions/src/main/java/dev/cel/extensions/CelComprehensions.java create mode 100644 extensions/src/test/java/dev/cel/extensions/CelComprehensionsTest.java diff --git a/extensions/src/main/java/dev/cel/extensions/BUILD.bazel b/extensions/src/main/java/dev/cel/extensions/BUILD.bazel index 82b9fea95..ab6a6c890 100644 --- a/extensions/src/main/java/dev/cel/extensions/BUILD.bazel +++ b/extensions/src/main/java/dev/cel/extensions/BUILD.bazel @@ -17,6 +17,7 @@ java_library( ], deps = [ ":bindings", + ":comprehensions", ":encoders", ":lists", ":math", @@ -171,3 +172,22 @@ java_library( "@maven//:com_google_guava_guava", ], ) + +java_library( + name = "comprehensions", + srcs = ["CelComprehensions.java"], + tags = [ + ], + deps = [ + "//checker:checker_builder", + "//common:compiler_common", + "//common/ast", + "//common/types", + "//compiler:compiler_builder", + "//parser:macro", + "//parser:parser_builder", + "//runtime", + "//runtime:function_binding", + "@maven//:com_google_guava_guava", + ], +) diff --git a/extensions/src/main/java/dev/cel/extensions/CelComprehensions.java b/extensions/src/main/java/dev/cel/extensions/CelComprehensions.java new file mode 100644 index 000000000..e7a0ad8f2 --- /dev/null +++ b/extensions/src/main/java/dev/cel/extensions/CelComprehensions.java @@ -0,0 +1,211 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.extensions; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Maps; +import dev.cel.checker.CelCheckerBuilder; +import dev.cel.common.CelFunctionDecl; +import dev.cel.common.CelIssue; +import dev.cel.common.CelOverloadDecl; +import dev.cel.common.ast.CelExpr; +import dev.cel.common.ast.CelExpr.ExprKind.Kind; +import dev.cel.common.types.MapType; +import dev.cel.common.types.TypeParamType; +import dev.cel.compiler.CelCompilerLibrary; +import dev.cel.parser.CelMacro; +import dev.cel.parser.CelMacroExprFactory; +import dev.cel.parser.CelParserBuilder; +import dev.cel.runtime.CelFunctionBinding; +import dev.cel.runtime.CelRuntimeBuilder; +import dev.cel.runtime.CelRuntimeLibrary; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +/** */ +final class CelComprehensions implements CelCompilerLibrary, CelRuntimeLibrary { + + private static final TypeParamType TYPE_PARAM_A = TypeParamType.create("A"); + private static final TypeParamType TYPE_PARAM_B = TypeParamType.create("B"); + private static final MapType MAP_OF_AB = MapType.create(TYPE_PARAM_A, TYPE_PARAM_B); + private static final String CEL_NAMESPACE = "cel"; + private static final String MAP_INSERT_FUNCTION = "cel.@mapInsert"; + private static final String MAP_INSERT_OVERLOAD_MAP_MAP = "@mapInsert_map_map"; + private static final String MAP_INSERT_OVERLOAD_KEY_VALUE = "@mapInsert_map_key_value"; + + public enum Function { + MAP_INSERT( + CelFunctionDecl.newFunctionDeclaration( + MAP_INSERT_FUNCTION, + CelOverloadDecl.newGlobalOverload( + MAP_INSERT_OVERLOAD_MAP_MAP, "map insertion", MAP_OF_AB, MAP_OF_AB, MAP_OF_AB), + CelOverloadDecl.newGlobalOverload( + MAP_INSERT_OVERLOAD_KEY_VALUE, + "map insertion", + MAP_OF_AB, + MAP_OF_AB, + TYPE_PARAM_A, + TYPE_PARAM_B)), + CelFunctionBinding.from( + MAP_INSERT_OVERLOAD_MAP_MAP, Map.class, Map.class, CelComprehensions::mapInsert), + CelFunctionBinding.from( + MAP_INSERT_OVERLOAD_KEY_VALUE, + ImmutableList.of(Map.class, Object.class, Object.class), + CelComprehensions::mapInsert)); + + private final CelFunctionDecl functionDecl; + private final ImmutableSet functionBindings; + + String getFunction() { + return functionDecl.name(); + } + + Function(CelFunctionDecl functionDecl, CelFunctionBinding... functionBindings) { + this.functionDecl = functionDecl; + this.functionBindings = ImmutableSet.copyOf(functionBindings); + } + } + + private final ImmutableSet functions; + + CelComprehensions() { + this.functions = ImmutableSet.copyOf(Function.values()); + } + + CelComprehensions(Set functions) { + this.functions = ImmutableSet.copyOf(functions); + } + + @Override + public void setParserOptions(CelParserBuilder parserBuilder) { + parserBuilder.addMacros( + CelMacro.newReceiverVarArgMacro("mapInsert", CelComprehensions::expandMapInsertMacro)); + } + + @Override + public void setCheckerOptions(CelCheckerBuilder checkerBuilder) { + functions.forEach(function -> checkerBuilder.addFunctionDeclarations(function.functionDecl)); + } + + @Override + public void setRuntimeOptions(CelRuntimeBuilder runtimeBuilder) { + functions.forEach(function -> runtimeBuilder.addFunctionBindings(function.functionBindings)); + } + + private static Map mapInsert(Map first, Map second) { + // TODO: return a mutable map instead of an actual copy. + Map result = Maps.newHashMapWithExpectedSize(first.size() + second.size()); + result.putAll(first); + for (Map.Entry entry : second.entrySet()) { + if (result.containsKey(entry.getKey())) { + throw new IllegalArgumentException( + String.format("insert failed: key '%s' already exists", entry.getKey())); + } + result.put(entry.getKey(), entry.getValue()); + } + return result; + } + + private static Map mapInsert(Object[] args) { + Map map = (Map) args[0]; + Object key = args[1]; + Object value = args[2]; + // TODO: return a mutable map instead of an actual copy. + if (map.containsKey(key)) { + throw new IllegalArgumentException( + String.format("insert failed: key '%s' already exists", key)); + } + Map result = Maps.newHashMapWithExpectedSize(map.size() + 1); + result.putAll(map); + result.put(key, value); + return result; + } + + private static Optional expandMapInsertMacro( + CelMacroExprFactory exprFactory, CelExpr target, ImmutableList arguments) { + if (!isTargetInNamespace(target)) { + // Return empty to indicate that we're not interested in expanding this macro, and + // that the parser should default to a function call on the receiver. + return Optional.empty(); + } + + switch (arguments.size()) { + case 2: + Optional invalidArg = + checkInvalidArgument(exprFactory, MAP_INSERT_OVERLOAD_MAP_MAP, arguments); + if (invalidArg.isPresent()) { + return invalidArg; + } + + return Optional.of(exprFactory.newGlobalCall(MAP_INSERT_FUNCTION, arguments)); + case 3: + invalidArg = checkInvalidArgument(exprFactory, MAP_INSERT_OVERLOAD_KEY_VALUE, arguments); + if (invalidArg.isPresent()) { + return invalidArg; + } + + return Optional.of(exprFactory.newGlobalCall(MAP_INSERT_FUNCTION, arguments)); + default: + return newError( + exprFactory, + "cel.mapInsert() arguments must be either two maps or a map and a key-value pair", + target); + } + } + + private static boolean isTargetInNamespace(CelExpr target) { + return target.exprKind().getKind().equals(Kind.IDENT) + && target.ident().name().equals(CEL_NAMESPACE); + } + + private static Optional checkInvalidArgument( + CelMacroExprFactory exprFactory, String functionName, List arguments) { + + if (functionName.equals(MAP_INSERT_OVERLOAD_MAP_MAP)) { + for (CelExpr arg : arguments) { + if (arg.exprKind().getKind() != Kind.MAP) { + return newError( + exprFactory, String.format("Invalid argument '%s': must be a map", arg), arg); + } + } + } + if (functionName.equals(MAP_INSERT_OVERLOAD_KEY_VALUE)) { + if (arguments.get(0).exprKind().getKind() != Kind.MAP) { + return newError( + exprFactory, + String.format("Invalid argument '%s': must be a map", arguments.get(0)), + arguments.get(0)); + } + if (arguments.get(1).exprKind().getKind() != Kind.CONSTANT) { + return newError( + exprFactory, + String.format("'%s' is an invalid Key", arguments.get(1)), + arguments.get(1)); + } + } + + return Optional.empty(); + } + + private static Optional newError( + CelMacroExprFactory exprFactory, String errorMessage, CelExpr argument) { + return Optional.of( + exprFactory.reportError( + CelIssue.formatError(exprFactory.getSourceLocation(argument), errorMessage))); + } +} diff --git a/extensions/src/main/java/dev/cel/extensions/CelExtensions.java b/extensions/src/main/java/dev/cel/extensions/CelExtensions.java index eb795341e..7c9091e65 100644 --- a/extensions/src/main/java/dev/cel/extensions/CelExtensions.java +++ b/extensions/src/main/java/dev/cel/extensions/CelExtensions.java @@ -258,6 +258,10 @@ public static CelListsExtensions lists(Set function return new CelListsExtensions(functions); } + public static CelComprehensions comprehensions() { + return new CelComprehensions(); + } + /** * Retrieves all function names used by every extension libraries. * @@ -276,7 +280,9 @@ public static ImmutableSet getAllFunctionNames() { stream(CelEncoderExtensions.Function.values()) .map(CelEncoderExtensions.Function::getFunction), stream(CelListsExtensions.Function.values()) - .map(CelListsExtensions.Function::getFunction)) + .map(CelListsExtensions.Function::getFunction), + stream(CelComprehensions.Function.values()) + .map(CelComprehensions.Function::getFunction)) .collect(toImmutableSet()); } diff --git a/extensions/src/test/java/dev/cel/extensions/CelComprehensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelComprehensionsTest.java new file mode 100644 index 000000000..ec6f32d62 --- /dev/null +++ b/extensions/src/test/java/dev/cel/extensions/CelComprehensionsTest.java @@ -0,0 +1,138 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package dev.cel.extensions; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.common.collect.ImmutableMap; +import com.google.testing.junit.testparameterinjector.TestParameter; +import com.google.testing.junit.testparameterinjector.TestParameterInjector; +import com.google.testing.junit.testparameterinjector.TestParameters; +import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.CelOptions; +import dev.cel.common.CelValidationException; +import dev.cel.compiler.CelCompiler; +import dev.cel.compiler.CelCompilerFactory; +import dev.cel.runtime.CelEvaluationException; +import dev.cel.runtime.CelRuntime; +import dev.cel.runtime.CelRuntimeFactory; +import java.util.Map; +import java.util.Objects; +import org.junit.Test; +import org.junit.runner.RunWith; + +@RunWith(TestParameterInjector.class) +public class CelComprehensionsTest { + private static final CelOptions CEL_OPTIONS = CelOptions.current().build(); + private static final CelCompiler CEL_COMPILER = + CelCompilerFactory.standardCelCompilerBuilder() + .setOptions(CEL_OPTIONS) + .addLibraries(CelExtensions.comprehensions()) + .build(); + private static final CelRuntime CEL_RUNTIME = + CelRuntimeFactory.standardCelRuntimeBuilder() + .setOptions(CEL_OPTIONS) + .addLibraries(CelExtensions.comprehensions()) + .build(); + + @SuppressWarnings("ImmutableEnumChecker") // Test only + private enum MapInsertTestCase { + EMPTY_MAP("cel.mapInsert({}, {})", ImmutableMap.of()), + EMPTY_FULL_MAP("cel.mapInsert({}, {2.0: 3.0})", ImmutableMap.of(2.0, 3.0)), + DOUBLE_MAP("cel.mapInsert({1.0: 5.0}, {2.0: 3.0})", ImmutableMap.of(1.0, 5.0, 2.0, 3.0)), + INT_MAP("cel.mapInsert({1: 5}, {2: 3})", ImmutableMap.of(1L, 5L, 2L, 3L)), + LONG_MAP("cel.mapInsert({1: 5}, {2: 3})", ImmutableMap.of(1L, 5L, 2L, 3L)), + MIXED_INPUT_MAP("cel.mapInsert({5.0: 30}, {2.0: 30})", ImmutableMap.of(5.0, 30L, 2.0, 30L)), + SIMPLE_MAP_INSERT("cel.mapInsert({'a': 7}, 'b', 3)", ImmutableMap.of("a", 7L, "b", 3L)), + NESTED_MAP_INSERT( + "cel.mapInsert({'a': {1: 5}}, 'b', {1: 3})", + ImmutableMap.of("a", ImmutableMap.of(1L, 5L), "b", ImmutableMap.of(1L, 3L))); + + private final String expr; + private final Map expectedResult; + + MapInsertTestCase(String expr, Map expectedResult) { + this.expr = expr; + this.expectedResult = expectedResult; + } + } + + @Test + public void mapInsert_success(@TestParameter MapInsertTestCase testCase) throws Exception { + CelAbstractSyntaxTree ast = CEL_COMPILER.compile(testCase.expr).getAst(); + + Object result = CEL_RUNTIME.createProgram(ast).eval(); + + assertThat(Objects.equals(result, testCase.expectedResult)).isTrue(); + } + + @Test + @TestParameters("{expr: 'cel.mapInsert()'}") + @TestParameters("{expr: 'cel.mapInsert({})'}") + @TestParameters("{expr: 'cel.mapInsert({1: 5}, 1, 3, 13, 72})'}") + public void mapInsert_invalidSizeArgs_throwsCompilationException(String expr) { + CelValidationException e = + assertThrows(CelValidationException.class, () -> CEL_COMPILER.compile(expr).getAst()); + + assertThat(e) + .hasMessageThat() + .contains( + "cel.mapInsert() arguments must be either two maps or a map and a key-value pair"); + } + + @Test + @TestParameters("{expr: 'cel.mapInsert({1: 5}, {1: 3}, {1: 3})'}") + @TestParameters("{expr: 'cel.mapInsert({1: 21}, [1], 3)'}") + public void mapInsertMapKeyValue_invalidKeyArgs_throwsCompilationException(String expr) { + CelValidationException e = + assertThrows(CelValidationException.class, () -> CEL_COMPILER.compile(expr).getAst()); + + assertThat(e).hasMessageThat().contains("is an invalid Key"); + } + + @Test + @TestParameters("{expr: 'cel.mapInsert(1, 1, 3)'}") + @TestParameters("{expr: 'cel.mapInsert([1], 1, 3)'}") + public void mapInsertMapKeyValue_invalidMapArgs_throwsCompilationException(String expr) { + CelValidationException e = + assertThrows(CelValidationException.class, () -> CEL_COMPILER.compile(expr).getAst()); + + assertThat(e).hasMessageThat().contains("must be a map"); + } + + @Test + @TestParameters("{expr: 'cel.mapInsert(1, {1: 2})'}") + @TestParameters("{expr: 'cel.mapInsert({1:[2]}, 3)'}") + public void mapInsertMapMap_invalidMapArgs_throwsCompilationException(String expr) { + CelValidationException e = + assertThrows(CelValidationException.class, () -> CEL_COMPILER.compile(expr).getAst()); + + assertThat(e).hasMessageThat().contains("must be a map"); + } + + @Test + @TestParameters("{expr: 'cel.mapInsert({1: 5},{1: 3})'}") + @TestParameters("{expr: 'cel.mapInsert({1: 5}, 1, 3)'}") + public void mapInsert_sameKey_throwsRuntimeException(String expr) throws Exception { + CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); + + CelEvaluationException e = + assertThrows(CelEvaluationException.class, () -> CEL_RUNTIME.createProgram(ast).eval()); + + assertThat(e).hasMessageThat().contains("evaluation error"); + assertThat(e).hasCauseThat().isInstanceOf(IllegalArgumentException.class); + assertThat(e).hasCauseThat().hasMessageThat().contains("insert failed: key '1' already exists"); + } +} diff --git a/extensions/src/test/java/dev/cel/extensions/CelExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelExtensionsTest.java index ffff33c7e..5d2000db9 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelExtensionsTest.java @@ -177,6 +177,7 @@ public void getAllFunctionNames() { "sets.intersects", "base64.decode", "base64.encode", - "flatten"); + "flatten", + "cel.@mapInsert"); } }