diff --git a/common/value.cc b/common/value.cc index 8e613f4fa..ced272417 100644 --- a/common/value.cc +++ b/common/value.cc @@ -63,7 +63,7 @@ namespace cel { namespace { -static constexpr std::array kValueToKindArray = { +static constexpr std::array kValueToKindArray = { ValueKind::kError, ValueKind::kBool, ValueKind::kBytes, ValueKind::kDouble, ValueKind::kDuration, ValueKind::kError, ValueKind::kInt, ValueKind::kList, ValueKind::kList, @@ -72,7 +72,7 @@ static constexpr std::array kValueToKindArray = { ValueKind::kNull, ValueKind::kOpaque, ValueKind::kString, ValueKind::kStruct, ValueKind::kStruct, ValueKind::kStruct, ValueKind::kTimestamp, ValueKind::kType, ValueKind::kUint, - ValueKind::kUnknown}; + ValueKind::kUnknown, ValueKind::kInt}; static_assert(kValueToKindArray.size() == absl::variant_size(), @@ -750,17 +750,20 @@ namespace { Value NonNullEnumValue( absl::Nonnull value) { ABSL_DCHECK(value != nullptr); - return IntValue(value->number()); + return EnumValue(value); } Value NonNullEnumValue(absl::Nonnull type, int32_t number) { ABSL_DCHECK(type != nullptr); - if (type->is_closed()) { - if (ABSL_PREDICT_FALSE(type->FindValueByNumber(number) == nullptr)) { - return ErrorValue(absl::InvalidArgumentError(absl::StrCat( - "closed enum has no such value: ", type->full_name(), ".", number))); - } + const google::protobuf::EnumValueDescriptor* enum_value = + type->FindValueByNumber(number); + if (type->is_closed() && ABSL_PREDICT_FALSE(enum_value == nullptr)) { + return ErrorValue(absl::InvalidArgumentError(absl::StrCat( + "closed enum has no such value: ", type->full_name(), ".", number))); + } + if (enum_value != nullptr) { + return EnumValue(enum_value); } return IntValue(number); } @@ -1943,6 +1946,18 @@ absl::optional Value::AsInt() const { alternative != nullptr) { return *alternative; } + if (const auto* alternative = absl::get_if(&variant_); + alternative != nullptr) { + return IntValue(alternative->NativeValue()); + } + return absl::nullopt; +} + +absl::optional Value::AsEnum() const { + if (const auto* alternative = absl::get_if(&variant_); + alternative != nullptr) { + return *alternative; + } return absl::nullopt; } @@ -2350,11 +2365,6 @@ ErrorValue Value::GetError() && { return absl::get(std::move(variant_)); } -IntValue Value::GetInt() const { - ABSL_DCHECK(IsInt()) << *this; - return absl::get(variant_); -} - #ifdef ABSL_HAVE_EXCEPTIONS #define CEL_VALUE_THROW_BAD_VARIANT_ACCESS() throw absl::bad_variant_access() #else @@ -2362,6 +2372,24 @@ IntValue Value::GetInt() const { ABSL_LOG(FATAL) << absl::bad_variant_access().what() /* Crash OK */ #endif +IntValue Value::GetInt() const { + ABSL_DCHECK(IsInt()) << *this; + if (const auto* alternative = absl::get_if(&variant_); + alternative != nullptr) { + return *alternative; + } + if (const auto* alternative = absl::get_if(&variant_); + alternative != nullptr) { + return IntValue(alternative->NativeValue()); + } + CEL_VALUE_THROW_BAD_VARIANT_ACCESS(); +} + +EnumValue Value::GetEnum() const { + ABSL_DCHECK(IsEnum()) << *this; + return absl::get(variant_); +} + ListValue Value::GetList() const& { ABSL_DCHECK(IsList()) << *this; if (const auto* alternative = diff --git a/common/value.h b/common/value.h index 5251f532e..f2f42f73d 100644 --- a/common/value.h +++ b/common/value.h @@ -501,7 +501,13 @@ class Value final { bool IsError() const { return absl::holds_alternative(variant_); } // Returns `true` if this value is an instance of an int value. - bool IsInt() const { return absl::holds_alternative(variant_); } + bool IsInt() const { + return absl::holds_alternative(variant_) || + absl::holds_alternative(variant_); + } + + // Returns `true` if this value is an instance of an enum value. + bool IsEnum() const { return absl::holds_alternative(variant_); } // Returns `true` if this value is an instance of a list value. bool IsList() const { @@ -671,6 +677,13 @@ class Value final { return IsInt(); } + // Convenience method for use with template metaprogramming. See + // `IsEnum()`. + template + std::enable_if_t, bool> Is() const { + return IsEnum(); + } + // Convenience method for use with template metaprogramming. See // `IsList()`. template @@ -856,6 +869,11 @@ class Value final { // int value. Otherwise an empty optional is returned. absl::optional AsInt() const; + // Performs a checked cast from a value to an enum value, + // returning a non-empty optional with either a value or reference to the + // enum value. Otherwise an empty optional is returned. + absl::optional AsEnum() const; + // Performs a checked cast from a value to a list value, // returning a non-empty optional with either a value or reference to the // list value. Otherwise an empty optional is returned. @@ -1791,6 +1809,11 @@ class Value final { // false, calling this method is undefined behavior. IntValue GetInt() const; + // Performs an unchecked cast from a value to an int value. In + // debug builds a best effort is made to crash. If `IsInt()` would return + // false, calling this method is undefined behavior. + EnumValue GetEnum() const; + // Performs an unchecked cast from a value to a list value. In // debug builds a best effort is made to crash. If `IsList()` would return // false, calling this method is undefined behavior. @@ -2104,6 +2127,25 @@ class Value final { return GetInt(); } + // Convenience method for use with template metaprogramming. See + // `GetEnum()`. + template + std::enable_if_t, EnumValue> Get() & { + return GetEnum(); + } + template + std::enable_if_t, EnumValue> Get() const& { + return GetEnum(); + } + template + std::enable_if_t, EnumValue> Get() && { + return GetEnum(); + } + template + std::enable_if_t, EnumValue> Get() const&& { + return GetEnum(); + } + // Convenience method for use with template metaprogramming. See // `GetList()`. template diff --git a/common/values/enum_value.cc b/common/values/enum_value.cc new file mode 100644 index 000000000..241ed507c --- /dev/null +++ b/common/values/enum_value.cc @@ -0,0 +1,26 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "google/protobuf/wrappers.pb.h" +#include "absl/status/statusor.h" +#include "common/value.h" + +namespace cel { + +absl::StatusOr EnumValue::Equal(ValueManager& value_manager, + const Value& other) const { + return IntValue(NativeValue()).Equal(value_manager, other); +} + +} // namespace cel diff --git a/common/values/enum_value.h b/common/values/enum_value.h index 71f437e62..62615fe3c 100644 --- a/common/values/enum_value.h +++ b/common/values/enum_value.h @@ -18,13 +18,27 @@ #ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_ENUM_VALUE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_VALUES_ENUM_VALUE_H_ +#include +#include +#include #include #include "google/protobuf/struct.pb.h" +#include "absl/base/nullability.h" #include "absl/meta/type_traits.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "common/type.h" +#include "common/value_kind.h" +#include "common/values/int_value.h" +#include "google/protobuf/descriptor.h" #include "google/protobuf/generated_enum_util.h" +#include "google/protobuf/message.h" -namespace cel::common_internal { +namespace cel { +namespace common_internal { template > inline constexpr bool kIsWellKnownEnumType = @@ -44,6 +58,100 @@ using EnableIfGeneratedEnum = std::enable_if_t< absl::negation>>>::value, R>; -} // namespace cel::common_internal +} // namespace common_internal + +class Value; +class ValueManager; +class IntValue; +class TypeManager; +class EnumValue; + +// `EnumValue` represents protobuf enum values which behave like values of the +// primitive `int` type, except that they return the enum name in +// `DebugString()`. +class EnumValue final { + public: + static constexpr ValueKind kKind = ValueKind::kInt; + + explicit EnumValue( + absl::Nonnull value) noexcept + : value_(value->number()), name_(value->name()) {} + explicit EnumValue(absl::string_view name, int64_t value) noexcept + : value_(value), name_(name) {} + + EnumValue(const EnumValue&) = default; + EnumValue(EnumValue&&) = default; + EnumValue& operator=(const EnumValue&) = default; + EnumValue& operator=(EnumValue&&) = default; + + ValueKind kind() const { return kKind; } + + absl::string_view GetTypeName() const { return IntType::kName; } + + absl::string_view GetEnumName() const { return name_; } + + std::string DebugString() const { return std::string(GetEnumName()); } + + // See Value::SerializeTo(). + absl::Status SerializeTo( + absl::Nonnull descriptor_pool, + absl::Nonnull message_factory, + absl::Cord& value) const { + return IntValue(NativeValue()) + .SerializeTo(descriptor_pool, message_factory, value); + } + + // See Value::ConvertToJson(). + absl::Status ConvertToJson( + absl::Nonnull descriptor_pool, + absl::Nonnull message_factory, + absl::Nonnull json) const { + return IntValue(NativeValue()) + .ConvertToJson(descriptor_pool, message_factory, json); + } + + absl::Status Equal(ValueManager& value_manager, const Value& other, + Value& result) const { + return IntValue(NativeValue()).Equal(value_manager, other, result); + } + absl::StatusOr Equal(ValueManager& value_manager, + const Value& other) const; + + bool IsZeroValue() const { return NativeValue() == 0; } + + int64_t NativeValue() const { return static_cast(*this); } + + // NOLINTNEXTLINE(google-explicit-constructor) + operator int64_t() const noexcept { return value_; } + + friend void swap(EnumValue& lhs, EnumValue& rhs) noexcept { + using std::swap; + swap(lhs.value_, rhs.value_); + swap(lhs.name_, rhs.name_); + } + + private: + int64_t value_; + absl::string_view name_; +}; + +template +H AbslHashValue(H state, EnumValue value) { + return H::combine(std::move(state), value.NativeValue()); +} + +inline bool operator==(EnumValue lhs, EnumValue rhs) { + return lhs.NativeValue() == rhs.NativeValue(); +} + +inline bool operator!=(EnumValue lhs, EnumValue rhs) { + return !operator==(lhs, rhs); +} + +inline std::ostream& operator<<(std::ostream& out, EnumValue value) { + return out << value.DebugString(); +} + +} // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_ENUM_VALUE_H_ diff --git a/common/values/values.h b/common/values/values.h index 59437ed56..66bc09cf5 100644 --- a/common/values/values.h +++ b/common/values/values.h @@ -55,6 +55,7 @@ class ParsedMapFieldValue; class ParsedRepeatedFieldValue; class ParsedJsonListValue; class ParsedJsonMapValue; +class EnumValue; class CustomListValue; class CustomListValueInterface; @@ -165,18 +166,21 @@ struct IsValueAlternative std::is_same, std::is_base_of, std::is_same, IsStructValueAlternative, std::is_same, std::is_same, - std::is_same, std::is_same>> {}; + std::is_same, std::is_same, + std::is_same>> {}; template inline constexpr bool IsValueAlternativeV = IsValueAlternative::value; -using ValueVariant = absl::variant< - absl::monostate, BoolValue, BytesValue, DoubleValue, DurationValue, - ErrorValue, IntValue, LegacyListValue, CustomListValue, - ParsedRepeatedFieldValue, ParsedJsonListValue, LegacyMapValue, - CustomMapValue, ParsedMapFieldValue, ParsedJsonMapValue, NullValue, - OpaqueValue, StringValue, LegacyStructValue, CustomStructValue, - ParsedMessageValue, TimestampValue, TypeValue, UintValue, UnknownValue>; +using ValueVariant = + absl::variant; // Get the base type alternative for the given alternative or interface. The // base type alternative is the type stored in the `ValueVariant`. diff --git a/eval/compiler/resolver.cc b/eval/compiler/resolver.cc index d1d20e30a..af07cb0af 100644 --- a/eval/compiler/resolver.cc +++ b/eval/compiler/resolver.cc @@ -38,7 +38,7 @@ namespace google::api::expr::runtime { -using ::cel::IntValue; +using ::cel::EnumValue; using ::cel::TypeValue; using ::cel::Value; @@ -84,7 +84,7 @@ Resolver::Resolver( for (const auto& enumerator : enum_type.enumerators) { auto key = absl::StrCat(remainder, !remainder.empty() ? "." : "", enumerator.name); - enum_value_map_[key] = IntValue(enumerator.number); + enum_value_map_[key] = EnumValue(enumerator.name, enumerator.number); } } } diff --git a/extensions/BUILD b/extensions/BUILD index 12d9a4cde..bd15ab94d 100644 --- a/extensions/BUILD +++ b/extensions/BUILD @@ -612,6 +612,12 @@ cc_test( ":formatting", "//common:allocator", "//common:value", + "//eval/public:activation", + "//eval/public:builtin_func_registrar", + "//eval/public:cel_expr_builder_factory", + "//eval/public:cel_value", + "//eval/public:portable_cel_function_adapter", + "//extensions/protobuf:enum_adapter", "//extensions/protobuf:runtime_adapter", "//internal:parse_text_proto", "//internal:testing", @@ -621,11 +627,14 @@ cc_test( "//parser:options", "//runtime", "//runtime:activation", + "//runtime:function_adapter", "//runtime:runtime_builder", "//runtime:runtime_options", "//runtime:standard_runtime_builder_factory", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", diff --git a/extensions/formatting.cc b/extensions/formatting.cc index b02eafcea..b1767265d 100644 --- a/extensions/formatting.cc +++ b/extensions/formatting.cc @@ -298,7 +298,11 @@ absl::StatusOr FormatString( return "null"; case ValueKind::kInt: scratch.clear(); - absl::StrAppend(&scratch, value.GetInt().NativeValue()); + if (value.IsEnum()) { + absl::StrAppend(&scratch, value.GetEnum().GetEnumName()); + } else { + absl::StrAppend(&scratch, value.GetInt().NativeValue()); + } return scratch; case ValueKind::kUint: scratch.clear(); diff --git a/extensions/formatting_test.cc b/extensions/formatting_test.cc index 38d7174e6..17a36d8df 100644 --- a/extensions/formatting_test.cc +++ b/extensions/formatting_test.cc @@ -25,13 +25,21 @@ #include "cel/expr/syntax.pb.h" #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" #include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "common/allocator.h" #include "common/value.h" +#include "eval/public/activation.h" +#include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_expr_builder_factory.h" +#include "eval/public/cel_value.h" +#include "eval/public/portable_cel_function_adapter.h" +#include "extensions/protobuf/enum_adapter.h" #include "extensions/protobuf/runtime_adapter.h" #include "internal/parse_text_proto.h" #include "internal/testing.h" @@ -40,6 +48,7 @@ #include "parser/options.h" #include "parser/parser.h" #include "runtime/activation.h" +#include "runtime/function_adapter.h" #include "runtime/runtime.h" #include "runtime/runtime_builder.h" #include "runtime/runtime_options.h" @@ -55,6 +64,9 @@ using ::cel::expr::conformance::proto3::TestAllTypes; using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::google::api::expr::parser::ParserOptions; +using ::google::api::expr::runtime::CelValue; +using ::google::api::expr::runtime::PortableUnaryFunctionAdapter; +using ::google::api::expr::runtime::RegisterBuiltinFunctions; using ::testing::HasSubstr; using ::testing::TestWithParam; using ::testing::ValuesIn; @@ -90,6 +102,34 @@ TEST_P(StringFormatTest, TestStringFormatting) { internal::GetTestingDescriptorPool(), options)); auto registration_status = RegisterStringFormattingFunctions(builder.function_registry(), options); + ASSERT_THAT( + cel::extensions::RegisterProtobufEnum( + builder.type_registry(), TestAllTypes::NestedEnum_descriptor()), + IsOk()); + ASSERT_THAT(builder.function_registry().Register( + UnaryFunctionAdapter::CreateDescriptor( + "modernFunc", /*receiver_style=*/false), + UnaryFunctionAdapter::WrapFunction( + [](ValueManager& value_manager, int64_t arg) { + return IntValue{arg}; + })), + IsOk()); + ASSERT_THAT( + builder.function_registry().Register( + UnaryFunctionAdapter, Value>::CreateDescriptor( + "ifNotZero", /*receiver_style=*/false), + UnaryFunctionAdapter, Value>::WrapFunction( + [](ValueManager& value_manager, + Value arg) -> absl::StatusOr { + if (!arg.IsInt()) { + return absl::InvalidArgumentError("not an int"); + } + if (arg.GetInt().IsZeroValue()) { + return absl::InvalidArgumentError("is zero"); + } + return arg; + })), + IsOk()); if (test_case.error.has_value() && !registration_status.ok()) { EXPECT_THAT(registration_status.message(), HasSubstr(*test_case.error)); return; @@ -352,6 +392,52 @@ INSTANTIATE_TEST_SUITE_P( }, .expected = "10,000.12 42", }, + { + .name = "FormatEnumProtoField", + .format = "%s", + .format_args = "message.standalone_enum", + .dyn_args = + { + {"message", MakeMessage(R"pb(standalone_enum: + FOO)pb")}, + }, + .expected = "FOO", + }, + { + .name = "FormatEnumConstant", + .format = "%s", + .format_args = + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum.BAR", + .expected = "BAR", + }, + { + .name = "EnumBehavesLikeInt", + .format = "%s", + .format_args = + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum.BAR + 42", + .expected = "43", + }, + { + .name = "EnumComparesToInt", + .format = "%s", + .format_args = + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum.BAR == 1", + .expected = "true", + }, + { + .name = "FormatEnumWithModernFunction", + .format = "%s", + .format_args = "modernFunc(cel.expr.conformance.proto3." + "TestAllTypes.NestedEnum.BAR)", + .expected = "1", + }, + { + .name = "FormatEnumWithCheckFunction", + .format = "%s", + .format_args = "ifNotZero(cel.expr.conformance.proto3." + "TestAllTypes.NestedEnum.BAR)", + .expected = "BAR", + }, { .name = "NoOp", .format = "no substitution", @@ -926,5 +1012,73 @@ INSTANTIATE_TEST_SUITE_P( return info.param.name; }); +// Move enum related tests to a separate file. DO NOT SUBMIT! +TEST(LegacyEnumTest, LegacyExtensionFunctionTest) { + google::protobuf::Arena arena; + auto builder = ::google::api::expr::runtime::CreateCelExpressionBuilder(); + ASSERT_THAT(RegisterBuiltinFunctions(builder->GetRegistry(), {}), IsOk()); + + ASSERT_THAT(RegisterStringFormattingFunctions( + builder->GetRegistry()->InternalGetRegistry(), {}), + IsOk()); + + ASSERT_THAT(builder->GetRegistry()->Register( + PortableUnaryFunctionAdapter::Create( + "legacyFunc", /*receiver_style=*/false, + [](google::protobuf::Arena* arena, int64_t arg) -> CelValue { + return CelValue::CreateInt64(arg); + })), + IsOk()); + ASSERT_THAT( + builder->GetRegistry()->Register( + PortableUnaryFunctionAdapter, CelValue>:: + Create("isNotZero", /*receiver_style=*/false, + [](google::protobuf::Arena* arena, + CelValue arg) -> absl::StatusOr { + if (!arg.IsInt64()) { + return absl::InvalidArgumentError( + "Argument must be an integer"); + } + if (arg.Int64OrDie() == 0) { + return absl::InvalidArgumentError("Argument is zero"); + } + return arg; + })), + IsOk()); + builder->GetTypeRegistry()->Register(TestAllTypes::NestedEnum_descriptor()); + + ASSERT_OK_AND_ASSIGN( + auto parsed_expr, + Parse("('%s'.format([isNotZero(cel.expr.conformance.proto3." + "TestAllTypes.NestedEnum.BAR)]) == '1' ? 1 : 0) + " + "legacyFunc(cel.expr.conformance.proto3." + "TestAllTypes.NestedEnum.BAR)")); + ASSERT_OK_AND_ASSIGN(auto expr_plan, + builder->CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + ::google::api::expr::runtime::Activation activation; + ASSERT_OK_AND_ASSIGN(auto value, expr_plan->Evaluate(activation, &arena)); + ASSERT_TRUE(value.IsInt64()); + EXPECT_EQ(value.Int64OrDie(), 2); +} + +TEST(LegacyEnumTest, LegacyReturnEnumTest) { + google::protobuf::Arena arena; + auto builder = ::google::api::expr::runtime::CreateCelExpressionBuilder(); + ASSERT_THAT(RegisterBuiltinFunctions(builder->GetRegistry(), {}), IsOk()); + builder->GetTypeRegistry()->Register(TestAllTypes::NestedEnum_descriptor()); + + ASSERT_OK_AND_ASSIGN( + auto parsed_expr, + Parse("cel.expr.conformance.proto3.TestAllTypes.NestedEnum.BAR")); + ASSERT_OK_AND_ASSIGN(auto expr_plan, + builder->CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + ::google::api::expr::runtime::Activation activation; + ASSERT_OK_AND_ASSIGN(auto value, expr_plan->Evaluate(activation, &arena)); + ASSERT_TRUE(value.IsInt64()); + EXPECT_EQ(value.Int64OrDie(), 1); +} + } // namespace } // namespace cel::extensions