Skip to content

Commit

Permalink
Add CEL enum formatting support
Browse files Browse the repository at this point in the history
Introduces the `EnumValue`, which behaves like an `IntValue`, but keeps track of the enum name, so that it can be printed using `string.format()`.

PiperOrigin-RevId: 717454914
  • Loading branch information
zeitgeist87 authored and copybara-github committed Jan 31, 2025
1 parent 4fa3096 commit 1ac4b1e
Show file tree
Hide file tree
Showing 9 changed files with 402 additions and 27 deletions.
54 changes: 41 additions & 13 deletions common/value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
namespace cel {
namespace {

static constexpr std::array<ValueKind, 25> kValueToKindArray = {
static constexpr std::array kValueToKindArray = {
ValueKind::kError, ValueKind::kBool, ValueKind::kBytes,
ValueKind::kDouble, ValueKind::kDuration, ValueKind::kError,
ValueKind::kInt, ValueKind::kList, ValueKind::kList,
Expand All @@ -72,7 +72,7 @@ static constexpr std::array<ValueKind, 25> kValueToKindArray = {
ValueKind::kNull, ValueKind::kOpaque, ValueKind::kString,
ValueKind::kStruct, ValueKind::kStruct, ValueKind::kStruct,
ValueKind::kTimestamp, ValueKind::kType, ValueKind::kUint,
ValueKind::kUnknown};
ValueKind::kUnknown, ValueKind::kInt};

static_assert(kValueToKindArray.size() ==
absl::variant_size<common_internal::ValueVariant>(),
Expand Down Expand Up @@ -750,17 +750,20 @@ namespace {
Value NonNullEnumValue(
absl::Nonnull<const google::protobuf::EnumValueDescriptor*> value) {
ABSL_DCHECK(value != nullptr);
return IntValue(value->number());
return EnumValue(value);
}

Value NonNullEnumValue(absl::Nonnull<const google::protobuf::EnumDescriptor*> type,
int32_t number) {
ABSL_DCHECK(type != nullptr);
if (type->is_closed()) {
if (ABSL_PREDICT_FALSE(type->FindValueByNumber(number) == nullptr)) {
return ErrorValue(absl::InvalidArgumentError(absl::StrCat(
"closed enum has no such value: ", type->full_name(), ".", number)));
}
const google::protobuf::EnumValueDescriptor* enum_value =
type->FindValueByNumber(number);
if (type->is_closed() && ABSL_PREDICT_FALSE(enum_value == nullptr)) {
return ErrorValue(absl::InvalidArgumentError(absl::StrCat(
"closed enum has no such value: ", type->full_name(), ".", number)));
}
if (enum_value != nullptr) {
return EnumValue(enum_value);
}
return IntValue(number);
}
Expand Down Expand Up @@ -1943,6 +1946,18 @@ absl::optional<IntValue> Value::AsInt() const {
alternative != nullptr) {
return *alternative;
}
if (const auto* alternative = absl::get_if<EnumValue>(&variant_);
alternative != nullptr) {
return IntValue(alternative->NativeValue());
}
return absl::nullopt;
}

absl::optional<EnumValue> Value::AsEnum() const {
if (const auto* alternative = absl::get_if<EnumValue>(&variant_);
alternative != nullptr) {
return *alternative;
}
return absl::nullopt;
}

Expand Down Expand Up @@ -2350,18 +2365,31 @@ ErrorValue Value::GetError() && {
return absl::get<ErrorValue>(std::move(variant_));
}

IntValue Value::GetInt() const {
ABSL_DCHECK(IsInt()) << *this;
return absl::get<IntValue>(variant_);
}

#ifdef ABSL_HAVE_EXCEPTIONS
#define CEL_VALUE_THROW_BAD_VARIANT_ACCESS() throw absl::bad_variant_access()
#else
#define CEL_VALUE_THROW_BAD_VARIANT_ACCESS() \
ABSL_LOG(FATAL) << absl::bad_variant_access().what() /* Crash OK */
#endif

IntValue Value::GetInt() const {
ABSL_DCHECK(IsInt()) << *this;
if (const auto* alternative = absl::get_if<IntValue>(&variant_);
alternative != nullptr) {
return *alternative;
}
if (const auto* alternative = absl::get_if<EnumValue>(&variant_);
alternative != nullptr) {
return IntValue(alternative->NativeValue());
}
CEL_VALUE_THROW_BAD_VARIANT_ACCESS();
}

EnumValue Value::GetEnum() const {
ABSL_DCHECK(IsEnum()) << *this;
return absl::get<EnumValue>(variant_);
}

ListValue Value::GetList() const& {
ABSL_DCHECK(IsList()) << *this;
if (const auto* alternative =
Expand Down
44 changes: 43 additions & 1 deletion common/value.h
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,13 @@ class Value final {
bool IsError() const { return absl::holds_alternative<ErrorValue>(variant_); }

// Returns `true` if this value is an instance of an int value.
bool IsInt() const { return absl::holds_alternative<IntValue>(variant_); }
bool IsInt() const {
return absl::holds_alternative<IntValue>(variant_) ||
absl::holds_alternative<EnumValue>(variant_);
}

// Returns `true` if this value is an instance of an enum value.
bool IsEnum() const { return absl::holds_alternative<EnumValue>(variant_); }

// Returns `true` if this value is an instance of a list value.
bool IsList() const {
Expand Down Expand Up @@ -671,6 +677,13 @@ class Value final {
return IsInt();
}

// Convenience method for use with template metaprogramming. See
// `IsEnum()`.
template <typename T>
std::enable_if_t<std::is_same_v<EnumValue, T>, bool> Is() const {
return IsEnum();
}

// Convenience method for use with template metaprogramming. See
// `IsList()`.
template <typename T>
Expand Down Expand Up @@ -856,6 +869,11 @@ class Value final {
// int value. Otherwise an empty optional is returned.
absl::optional<IntValue> AsInt() const;

// Performs a checked cast from a value to an enum value,
// returning a non-empty optional with either a value or reference to the
// enum value. Otherwise an empty optional is returned.
absl::optional<EnumValue> AsEnum() const;

// Performs a checked cast from a value to a list value,
// returning a non-empty optional with either a value or reference to the
// list value. Otherwise an empty optional is returned.
Expand Down Expand Up @@ -1791,6 +1809,11 @@ class Value final {
// false, calling this method is undefined behavior.
IntValue GetInt() const;

// Performs an unchecked cast from a value to an int value. In
// debug builds a best effort is made to crash. If `IsInt()` would return
// false, calling this method is undefined behavior.
EnumValue GetEnum() const;

// Performs an unchecked cast from a value to a list value. In
// debug builds a best effort is made to crash. If `IsList()` would return
// false, calling this method is undefined behavior.
Expand Down Expand Up @@ -2104,6 +2127,25 @@ class Value final {
return GetInt();
}

// Convenience method for use with template metaprogramming. See
// `GetEnum()`.
template <typename T>
std::enable_if_t<std::is_same_v<EnumValue, T>, EnumValue> Get() & {
return GetEnum();
}
template <typename T>
std::enable_if_t<std::is_same_v<EnumValue, T>, EnumValue> Get() const& {
return GetEnum();
}
template <typename T>
std::enable_if_t<std::is_same_v<EnumValue, T>, EnumValue> Get() && {
return GetEnum();
}
template <typename T>
std::enable_if_t<std::is_same_v<EnumValue, T>, EnumValue> Get() const&& {
return GetEnum();
}

// Convenience method for use with template metaprogramming. See
// `GetList()`.
template <typename T>
Expand Down
26 changes: 26 additions & 0 deletions common/values/enum_value.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "google/protobuf/wrappers.pb.h"
#include "absl/status/statusor.h"
#include "common/value.h"

namespace cel {

absl::StatusOr<Value> EnumValue::Equal(ValueManager& value_manager,
const Value& other) const {
return IntValue(NativeValue()).Equal(value_manager, other);
}

} // namespace cel
112 changes: 110 additions & 2 deletions common/values/enum_value.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,27 @@
#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_ENUM_VALUE_H_
#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_ENUM_VALUE_H_

#include <cstdint>
#include <ostream>
#include <string>
#include <type_traits>

#include "google/protobuf/struct.pb.h"
#include "absl/base/nullability.h"
#include "absl/meta/type_traits.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/strings/string_view.h"
#include "common/type.h"
#include "common/value_kind.h"
#include "common/values/int_value.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/generated_enum_util.h"
#include "google/protobuf/message.h"

namespace cel::common_internal {
namespace cel {
namespace common_internal {

template <typename T, typename U = absl::remove_cv_t<T>>
inline constexpr bool kIsWellKnownEnumType =
Expand All @@ -44,6 +58,100 @@ using EnableIfGeneratedEnum = std::enable_if_t<
absl::negation<std::bool_constant<kIsWellKnownEnumType<T>>>>::value,
R>;

} // namespace cel::common_internal
} // namespace common_internal

class Value;
class ValueManager;
class IntValue;
class TypeManager;
class EnumValue;

// `EnumValue` represents protobuf enum values which behave like values of the
// primitive `int` type, except that they return the enum name in
// `DebugString()`.
class EnumValue final {
public:
static constexpr ValueKind kKind = ValueKind::kInt;

explicit EnumValue(
absl::Nonnull<const google::protobuf::EnumValueDescriptor*> value) noexcept
: value_(value->number()), name_(value->name()) {}
explicit EnumValue(absl::string_view name, int64_t value) noexcept
: value_(value), name_(name) {}

EnumValue(const EnumValue&) = default;
EnumValue(EnumValue&&) = default;
EnumValue& operator=(const EnumValue&) = default;
EnumValue& operator=(EnumValue&&) = default;

ValueKind kind() const { return kKind; }

absl::string_view GetTypeName() const { return IntType::kName; }

absl::string_view GetEnumName() const { return name_; }

std::string DebugString() const { return std::string(GetEnumName()); }

// See Value::SerializeTo().
absl::Status SerializeTo(
absl::Nonnull<const google::protobuf::DescriptorPool*> descriptor_pool,
absl::Nonnull<google::protobuf::MessageFactory*> message_factory,
absl::Cord& value) const {
return IntValue(NativeValue())
.SerializeTo(descriptor_pool, message_factory, value);
}

// See Value::ConvertToJson().
absl::Status ConvertToJson(
absl::Nonnull<const google::protobuf::DescriptorPool*> descriptor_pool,
absl::Nonnull<google::protobuf::MessageFactory*> message_factory,
absl::Nonnull<google::protobuf::Message*> json) const {
return IntValue(NativeValue())
.ConvertToJson(descriptor_pool, message_factory, json);
}

absl::Status Equal(ValueManager& value_manager, const Value& other,
Value& result) const {
return IntValue(NativeValue()).Equal(value_manager, other, result);
}
absl::StatusOr<Value> Equal(ValueManager& value_manager,
const Value& other) const;

bool IsZeroValue() const { return NativeValue() == 0; }

int64_t NativeValue() const { return static_cast<int64_t>(*this); }

// NOLINTNEXTLINE(google-explicit-constructor)
operator int64_t() const noexcept { return value_; }

friend void swap(EnumValue& lhs, EnumValue& rhs) noexcept {
using std::swap;
swap(lhs.value_, rhs.value_);
swap(lhs.name_, rhs.name_);
}

private:
int64_t value_;
absl::string_view name_;
};

template <typename H>
H AbslHashValue(H state, EnumValue value) {
return H::combine(std::move(state), value.NativeValue());
}

inline bool operator==(EnumValue lhs, EnumValue rhs) {
return lhs.NativeValue() == rhs.NativeValue();
}

inline bool operator!=(EnumValue lhs, EnumValue rhs) {
return !operator==(lhs, rhs);
}

inline std::ostream& operator<<(std::ostream& out, EnumValue value) {
return out << value.DebugString();
}

} // namespace cel

#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_ENUM_VALUE_H_
20 changes: 12 additions & 8 deletions common/values/values.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class ParsedMapFieldValue;
class ParsedRepeatedFieldValue;
class ParsedJsonListValue;
class ParsedJsonMapValue;
class EnumValue;

class CustomListValue;
class CustomListValueInterface;
Expand Down Expand Up @@ -165,18 +166,21 @@ struct IsValueAlternative
std::is_same<NullValue, T>, std::is_base_of<OpaqueValue, T>,
std::is_same<StringValue, T>, IsStructValueAlternative<T>,
std::is_same<TimestampValue, T>, std::is_same<TypeValue, T>,
std::is_same<UintValue, T>, std::is_same<UnknownValue, T>>> {};
std::is_same<UintValue, T>, std::is_same<UnknownValue, T>,
std::is_same<EnumValue, T>>> {};

template <typename T>
inline constexpr bool IsValueAlternativeV = IsValueAlternative<T>::value;

using ValueVariant = absl::variant<
absl::monostate, BoolValue, BytesValue, DoubleValue, DurationValue,
ErrorValue, IntValue, LegacyListValue, CustomListValue,
ParsedRepeatedFieldValue, ParsedJsonListValue, LegacyMapValue,
CustomMapValue, ParsedMapFieldValue, ParsedJsonMapValue, NullValue,
OpaqueValue, StringValue, LegacyStructValue, CustomStructValue,
ParsedMessageValue, TimestampValue, TypeValue, UintValue, UnknownValue>;
using ValueVariant =
absl::variant<absl::monostate, BoolValue, BytesValue, DoubleValue,
DurationValue, ErrorValue, IntValue, LegacyListValue,
CustomListValue, ParsedRepeatedFieldValue,
ParsedJsonListValue, LegacyMapValue, CustomMapValue,
ParsedMapFieldValue, ParsedJsonMapValue, NullValue,
OpaqueValue, StringValue, LegacyStructValue,
CustomStructValue, ParsedMessageValue, TimestampValue,
TypeValue, UintValue, UnknownValue, EnumValue>;

// Get the base type alternative for the given alternative or interface. The
// base type alternative is the type stored in the `ValueVariant`.
Expand Down
4 changes: 2 additions & 2 deletions eval/compiler/resolver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@

namespace google::api::expr::runtime {

using ::cel::IntValue;
using ::cel::EnumValue;
using ::cel::TypeValue;
using ::cel::Value;

Expand Down Expand Up @@ -84,7 +84,7 @@ Resolver::Resolver(
for (const auto& enumerator : enum_type.enumerators) {
auto key = absl::StrCat(remainder, !remainder.empty() ? "." : "",
enumerator.name);
enum_value_map_[key] = IntValue(enumerator.number);
enum_value_map_[key] = EnumValue(enumerator.name, enumerator.number);
}
}
}
Expand Down
Loading

0 comments on commit 1ac4b1e

Please sign in to comment.