From 3638cd2fcf378e85a566d419a3f27043706614d6 Mon Sep 17 00:00:00 2001 From: timdn Date: Thu, 10 Feb 2022 11:17:03 +0000 Subject: [PATCH 001/111] Add CEL/C++ support for hermetic descriptor pools So far, CEL C++ always used only the generated descriptor pool (and message factory). This means only protos that are compiled into the binary using CEL can be used. In our scenario, we have a database providing certain entities that can be added in our system at run-time and which use proto to describe output data. We want to use CEL to formulate expressions over this output. Since we know only at run-time what protos these are, we need to create a hermtic descriptor pool and dynamic message factory for that transitively closed file descriptor set. This CL adds the capability to CEL/C++ to use a custom descriptor poold and message factory. From the users' perspective this is entirely optional and the generated descriptor pool and message factory will be used if not explicitly overridden. PiperOrigin-RevId: 427701926 --- base/BUILD | 131 +++ base/internal/BUILD | 54 ++ base/internal/operators.h | 53 ++ base/internal/type.h | 72 ++ base/internal/value.h | 512 ++++++++++++ base/kind.cc | 62 ++ base/kind.h | 48 ++ base/kind_test.cc | 49 ++ base/operators.cc | 170 ++++ base/operators.h | 151 ++++ base/operators_test.cc | 267 ++++++ base/type.cc | 260 ++++++ base/type.h | 159 ++++ base/type_test.cc | 307 +++++++ base/value.cc | 789 ++++++++++++++++++ base/value.h | 380 +++++++++ base/value_test.cc | 749 +++++++++++++++++ bazel/BUILD | 8 - bazel/antlr.bzl | 75 +- bazel/deps.bzl | 23 +- bazel/deps_extra.bzl | 2 + conformance/BUILD | 2 +- conformance/server.cc | 22 +- eval/compiler/BUILD | 6 + eval/compiler/flat_expr_builder.cc | 9 +- eval/compiler/flat_expr_builder.h | 16 +- eval/compiler/flat_expr_builder_test.cc | 221 +++++ eval/eval/BUILD | 9 + eval/eval/comprehension_step_test.cc | 4 +- eval/eval/const_value_step_test.cc | 5 +- eval/eval/container_access_step_test.cc | 5 +- eval/eval/create_list_step_test.cc | 15 +- eval/eval/create_struct_step.cc | 2 +- eval/eval/create_struct_step_test.cc | 15 +- eval/eval/evaluator_core.cc | 5 +- eval/eval/evaluator_core.h | 25 +- eval/eval/evaluator_core_test.cc | 16 +- eval/eval/function_step_test.cc | 57 +- eval/eval/ident_step_test.cc | 27 +- eval/eval/logic_step_test.cc | 7 +- eval/eval/select_step_test.cc | 33 +- eval/eval/shadowable_value_step_test.cc | 5 +- eval/eval/ternary_step_test.cc | 7 +- eval/public/BUILD | 16 +- eval/public/cel_expr_builder_factory.cc | 138 ++- eval/public/cel_expr_builder_factory.h | 13 + eval/public/cel_expr_builder_factory_test.cc | 164 ++++ eval/public/cel_expression.h | 5 + eval/public/cel_type_registry.cc | 11 +- eval/public/cel_type_registry.h | 9 +- eval/public/containers/field_access.cc | 6 +- eval/public/structs/cel_proto_wrapper.cc | 263 +++--- eval/public/structs/cel_proto_wrapper.h | 7 +- eval/public/structs/cel_proto_wrapper_test.cc | 14 +- eval/testutil/BUILD | 7 + eval/testutil/simple_test_message.proto | 9 + tools/BUILD | 36 - tools/cel_ast_renumber.cc | 152 ---- tools/cel_ast_renumber.h | 33 - tools/reference_inliner.cc | 202 ----- tools/reference_inliner.h | 53 -- 61 files changed, 5175 insertions(+), 797 deletions(-) create mode 100644 base/BUILD create mode 100644 base/internal/BUILD create mode 100644 base/internal/operators.h create mode 100644 base/internal/type.h create mode 100644 base/internal/value.h create mode 100644 base/kind.cc create mode 100644 base/kind.h create mode 100644 base/kind_test.cc create mode 100644 base/operators.cc create mode 100644 base/operators.h create mode 100644 base/operators_test.cc create mode 100644 base/type.cc create mode 100644 base/type.h create mode 100644 base/type_test.cc create mode 100644 base/value.cc create mode 100644 base/value.h create mode 100644 base/value_test.cc create mode 100644 eval/public/cel_expr_builder_factory_test.cc create mode 100644 eval/testutil/simple_test_message.proto delete mode 100644 tools/cel_ast_renumber.cc delete mode 100644 tools/cel_ast_renumber.h delete mode 100644 tools/reference_inliner.cc delete mode 100644 tools/reference_inliner.h diff --git a/base/BUILD b/base/BUILD new file mode 100644 index 000000000..b6f98e7fc --- /dev/null +++ b/base/BUILD @@ -0,0 +1,131 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + # Under active development, not yet being released. + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) + +cc_library( + name = "kind", + srcs = ["kind.cc"], + hdrs = ["kind.h"], + deps = [ + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "kind_test", + srcs = ["kind_test.cc"], + deps = [ + ":kind", + "//internal:testing", + ], +) + +cc_library( + name = "operators", + srcs = ["operators.cc"], + hdrs = ["operators.h"], + deps = [ + "//base/internal:operators", + "@com_google_absl//absl/base", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "operators_test", + srcs = ["operators_test.cc"], + deps = [ + ":operators", + "//internal:testing", + "@com_google_absl//absl/hash:hash_testing", + "@com_google_absl//absl/status", + ], +) + +cc_library( + name = "type", + srcs = ["type.cc"], + hdrs = ["type.h"], + deps = [ + ":kind", + "//base/internal:type", + "//internal:reference_counted", + "@com_google_absl//absl/base", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "type_test", + srcs = ["type_test.cc"], + deps = [ + ":type", + "//internal:testing", + "@com_google_absl//absl/hash:hash_testing", + ], +) + +cc_library( + name = "value", + srcs = ["value.cc"], + hdrs = ["value.h"], + deps = [ + ":kind", + ":type", + "//base/internal:value", + "//internal:casts", + "//internal:reference_counted", + "//internal:status_macros", + "//internal:strings", + "//internal:time", + "@com_google_absl//absl/base", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:variant", + ], +) + +cc_test( + name = "value_test", + srcs = ["value_test.cc"], + deps = [ + ":type", + ":value", + "//internal:strings", + "//internal:testing", + "//internal:time", + "@com_google_absl//absl/hash:hash_testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/time", + ], +) diff --git a/base/internal/BUILD b/base/internal/BUILD new file mode 100644 index 000000000..d4eeffe0d --- /dev/null +++ b/base/internal/BUILD @@ -0,0 +1,54 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +cc_library( + name = "operators", + hdrs = ["operators.h"], + deps = [ + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "type", + hdrs = ["type.h"], + deps = [ + "//base:kind", + "//internal:reference_counted", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "value", + hdrs = ["value.h"], + deps = [ + ":type", + "//base:kind", + "//internal:casts", + "//internal:reference_counted", + "@com_google_absl//absl/base:config", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/numeric:bits", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) diff --git a/base/internal/operators.h b/base/internal/operators.h new file mode 100644 index 000000000..84159dcca --- /dev/null +++ b/base/internal/operators.h @@ -0,0 +1,53 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_OPERATORS_H_ +#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_OPERATORS_H_ + +#include "absl/strings/string_view.h" + +namespace cel { + +enum class OperatorId; + +namespace base_internal { + +struct OperatorData final { + OperatorData() = delete; + + OperatorData(const OperatorData&) = delete; + + OperatorData(OperatorData&&) = delete; + + constexpr OperatorData(cel::OperatorId id, absl::string_view name, + absl::string_view display_name, int precedence, + int arity) + : id(id), + name(name), + display_name(display_name), + precedence(precedence), + arity(arity) {} + + const cel::OperatorId id; + const absl::string_view name; + const absl::string_view display_name; + const int precedence; + const int arity; +}; + +} // namespace base_internal + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_OPERATORS_H_ diff --git a/base/internal/type.h b/base/internal/type.h new file mode 100644 index 000000000..3b2220c42 --- /dev/null +++ b/base/internal/type.h @@ -0,0 +1,72 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_TYPE_H_ + +#include "absl/hash/hash.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "base/kind.h" +#include "internal/reference_counted.h" + +namespace cel { + +class Type; + +namespace base_internal { + +class SimpleType; + +class BaseType : public cel::internal::ReferenceCounted { + public: + // Returns the type kind. + virtual Kind kind() const = 0; + + // Returns the type name, i.e. map or google.protobuf.Any. + virtual absl::string_view name() const = 0; + + // Returns the type parameters of the type, i.e. key and value of map type. + virtual absl::Span parameters() const = 0; + + protected: + // Overriden by subclasses to implement more strictly equality testing. By + // default `cel::Type` ensures `kind()` and `name()` are equal, this behavior + // cannot be overriden. It is completely valid and acceptable to simply return + // `true`. + // + // This method should only ever be called by cel::Type. + virtual bool Equals(const cel::Type& value) const = 0; + + // Overriden by subclasses to implement better hashing. By default `cel::Type` + // hashes `kind()` and `name()`, this behavior cannot be overriden. It is + // completely valid and acceptable to simply do nothing. + // + // This method should only ever be called by cel::Type. + virtual void HashValue(absl::HashState state) const = 0; + + private: + friend class cel::Type; + friend class SimpleType; + + // The default constructor is private so that only sanctioned classes can + // extend it. Users should extend those classes instead of this one. + constexpr BaseType() = default; +}; + +} // namespace base_internal + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_TYPE_H_ diff --git a/base/internal/value.h b/base/internal/value.h new file mode 100644 index 000000000..81fdc0b03 --- /dev/null +++ b/base/internal/value.h @@ -0,0 +1,512 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_VALUE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/config.h" +#include "absl/base/macros.h" +#include "absl/base/optimization.h" +#include "absl/hash/hash.h" +#include "absl/numeric/bits.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "base/internal/type.h" +#include "base/kind.h" +#include "internal/casts.h" +#include "internal/reference_counted.h" + +namespace cel { + +class Value; +class Bytes; + +namespace base_internal { + +// Abstract base class that all non-simple values are derived from. Users will +// not inherit from this directly but rather indirectly through exposed classes +// like cel::Struct. +class BaseValue : public cel::internal::ReferenceCounted { + public: + // Returns a human readable representation of this value. The representation + // is not guaranteed to be consistent across versions and should only be used + // for debugging purposes. + virtual std::string DebugString() const = 0; + + protected: + virtual bool Equals(const cel::Value& value) const = 0; + + virtual void HashValue(absl::HashState state) const = 0; + + private: + friend class cel::Value; + friend class cel::Bytes; + + BaseValue() = default; +}; + +// Type erased state capable of holding a pointer to remote storage or storing +// objects less than two pointers in size inline. +union ExternalDataReleaserState final { + void* remote; + alignas(alignof(std::max_align_t)) char local[sizeof(void*) * 2]; +}; + +// Function which deletes the object referenced by ExternalDataReleaserState. +using ExternalDataReleaserDeleter = void(ExternalDataReleaserState* state); + +template +void LocalExternalDataReleaserDeleter(ExternalDataReleaserState* state) { + reinterpret_cast(&state->local)->~Releaser(); +} + +template +void RemoteExternalDataReleaserDeleter(ExternalDataReleaserState* state) { + ::delete reinterpret_cast(state->remote); +} + +// Function which invokes the object referenced by ExternalDataReleaserState. +using ExternalDataReleaseInvoker = + void(ExternalDataReleaserState* state) noexcept; + +template +void LocalExternalDataReleaserInvoker( + ExternalDataReleaserState* state) noexcept { + (*reinterpret_cast(&state->local))(); +} + +template +void RemoteExternalDataReleaserInvoker( + ExternalDataReleaserState* state) noexcept { + (*reinterpret_cast(&state->remote))(); +} + +struct ExternalDataReleaser final { + ExternalDataReleaser() = delete; + + template + explicit ExternalDataReleaser(Releaser&& releaser) { + using DecayedReleaser = std::decay_t; + if constexpr (sizeof(DecayedReleaser) <= sizeof(void*) * 2 && + alignof(DecayedReleaser) <= alignof(std::max_align_t)) { + // Object meets size and alignment constraints, will be stored + // inline in ExternalDataReleaserState.local. + ::new (static_cast(&state.local)) + DecayedReleaser(std::forward(releaser)); + invoker = LocalExternalDataReleaserInvoker; + if constexpr (std::is_trivially_destructible_v) { + // Object is trivially destructable, no need to call destructor at all. + deleter = nullptr; + } else { + deleter = LocalExternalDataReleaserDeleter; + } + } else { + // Object does not meet size and alignment constraints, allocate on the + // heap and store pointer in ExternalDataReleaserState::remote. inline in + // ExternalDataReleaserState::local. + state.remote = ::new DecayedReleaser(std::forward(releaser)); + invoker = RemoteExternalDataReleaserInvoker; + deleter = RemoteExternalDataReleaserDeleter; + } + } + + ExternalDataReleaser(const ExternalDataReleaser&) = delete; + + ExternalDataReleaser(ExternalDataReleaser&&) = delete; + + ~ExternalDataReleaser() { + (*invoker)(&state); + if (deleter != nullptr) { + (*deleter)(&state); + } + } + + ExternalDataReleaser& operator=(const ExternalDataReleaser&) = delete; + + ExternalDataReleaser& operator=(ExternalDataReleaser&&) = delete; + + ExternalDataReleaserState state; + ExternalDataReleaserDeleter* deleter; + ExternalDataReleaseInvoker* invoker; +}; + +// Utility class encompassing a contiguous array of data which a function that +// must be called when the data is no longer needed. +struct ExternalData final { + ExternalData() = delete; + + ExternalData(const void* data, size_t size, + std::unique_ptr releaser) + : data(data), size(size), releaser(std::move(releaser)) {} + + ExternalData(const ExternalData&) = delete; + + ExternalData(ExternalData&&) noexcept = default; + + ExternalData& operator=(const ExternalData&) = delete; + + ExternalData& operator=(ExternalData&&) noexcept = default; + + const void* data; + size_t size; + std::unique_ptr releaser; +}; + +// Currently absl::Status has a size that is less than or equal to 8, however +// this could change at any time. Thus we delegate the lifetime management to +// BaseInlinedStatus which is always less than or equal to 8 bytes. +template +class BaseInlinedStatus; + +// Specialization for when the size of absl::Status is less than or equal to 8 +// bytes. +template <> +class BaseInlinedStatus final { + public: + BaseInlinedStatus() = default; + + BaseInlinedStatus(const BaseInlinedStatus&) = default; + + BaseInlinedStatus(BaseInlinedStatus&&) = default; + + explicit BaseInlinedStatus(const absl::Status& status) : status_(status) {} + + BaseInlinedStatus& operator=(const BaseInlinedStatus&) = default; + + BaseInlinedStatus& operator=(BaseInlinedStatus&&) = default; + + BaseInlinedStatus& operator=(const absl::Status& status) { + status_ = status; + return *this; + } + + const absl::Status& status() const { return status_; } + + private: + absl::Status status_; +}; + +// Specialization for when the size of absl::Status is greater than 8 bytes. As +// mentioned above, this template is never used today. It could in the future if +// the size of `absl::Status` ever changes. Without this specialization, our +// static asserts below would break and so would compiling CEL. +template <> +class BaseInlinedStatus final { + public: + BaseInlinedStatus() = default; + + BaseInlinedStatus(const BaseInlinedStatus&) = default; + + BaseInlinedStatus(BaseInlinedStatus&&) = default; + + explicit BaseInlinedStatus(const absl::Status& status) + : status_(std::make_shared(status)) {} + + BaseInlinedStatus& operator=(const BaseInlinedStatus&) = default; + + BaseInlinedStatus& operator=(BaseInlinedStatus&&) = default; + + BaseInlinedStatus& operator=(const absl::Status& status) { + if (status_) { + *status_ = status; + } else { + status_ = std::make_shared(status); + } + return *this; + } + + const absl::Status& status() const { + static const absl::Status* ok_status = new absl::Status(); + return status_ ? *status_ : *ok_status; + } + + private: + std::shared_ptr status_; +}; + +using InlinedStatus = BaseInlinedStatus<(sizeof(absl::Status) <= 8)>; + +// ValueMetadata is a specialized tagged union capable of storing either a +// pointer to a BaseType or a Kind. Only simple kinds are stored directly. +// Simple kinds can be converted into cel::Type using cel::Type::Simple. +// ValueMetadata is primarily used to interpret the contents of ValueContent. +// +// We assume that all pointers returned by `malloc()` are at minimum aligned to +// 4 bytes. In practice this assumption is pretty safe and all known +// implementations exhibit this behavior. +// +// The tagged union byte layout depends on the 0 bit. +// +// Bit 0 unset: +// +// -------------------------------- +// | 63 ... 2 | 1 | 0 | +// -------------------------------- +// | pointer | reserved | reffed | +// -------------------------------- +// +// Bit 0 set: +// +// --------------------------------------------------------------- +// | 63 ... 32 | 31 ... 16 | 15 ... 8 | 7 ... 1 | 0 | +// --------------------------------------------------------------- +// | extended_content | reserved | kind | reserved | simple | +// --------------------------------------------------------------- +// +// Q: Why not use absl::variant/std::variant? +// A: In theory, we could. However it would be repetative and inefficient. +// variant has a size equal to the largest of its memory types plus an +// additional field keeping track of the type that is active. For our purposes, +// the field that is active is kept track of by ValueMetadata and the storage in +// ValueContent. We know what is stored in ValueContent by the kind/type in +// ValueMetadata. Since we need to keep the type bundled with the Value, using +// variant would introduce two sources of truth for what is stored in +// ValueContent. If we chose the naive implementation, which would be to use +// Type instead of ValueMetadata and variant instead of ValueContent, each time +// we copy Value we would be guaranteed to incur a reference count causing a +// cache miss. This approach avoids that reference count for simple types. +// Additionally the size of Value would now be roughly 8 + 16 on 64-bit +// platforms. +// +// As with ValueContent, this class is only meant to be used by cel::Value. +class ValueMetadata final { + public: + constexpr ValueMetadata() : raw_(MakeDefault()) {} + + constexpr explicit ValueMetadata(Kind kind) : ValueMetadata(kind, 0) {} + + constexpr ValueMetadata(Kind kind, uint32_t extended_content) + : raw_(MakeSimple(kind, extended_content)) {} + + explicit ValueMetadata(const BaseType* base_type) + : ptr_(reinterpret_cast(base_type)) { + // Assert that the lower 2 bits are 0, a.k.a. at minimum 4 byte aligned. + ABSL_ASSERT(absl::countr_zero(reinterpret_cast(base_type)) >= 2); + } + + ValueMetadata(const ValueMetadata&) = delete; + + ValueMetadata(ValueMetadata&&) = delete; + + ValueMetadata& operator=(const ValueMetadata&) = delete; + + ValueMetadata& operator=(ValueMetadata&&) = delete; + + constexpr bool simple_tag() const { + return (lower_ & kSimpleTag) == kSimpleTag; + } + + constexpr uint32_t extended_content() const { + ABSL_ASSERT(simple_tag()); + return higher_; + } + + const BaseType* base_type() const { + ABSL_ASSERT(!simple_tag()); + return reinterpret_cast(ptr_ & kPtrMask); + } + + Kind kind() const { + return simple_tag() ? static_cast(lower_ >> 8) : base_type()->kind(); + } + + void Reset() { + if (!simple_tag()) { + internal::Unref(base_type()); + } + raw_ = MakeDefault(); + } + + void CopyFrom(const ValueMetadata& other) { + if (ABSL_PREDICT_TRUE(this != std::addressof(other))) { + if (!other.simple_tag()) { + internal::Ref(other.base_type()); + } + if (!simple_tag()) { + internal::Unref(base_type()); + } + raw_ = other.raw_; + } + } + + void MoveFrom(ValueMetadata&& other) { + if (ABSL_PREDICT_TRUE(this != std::addressof(other))) { + if (!simple_tag()) { + internal::Unref(base_type()); + } + raw_ = other.raw_; + other.raw_ = MakeDefault(); + } + } + + private: + static constexpr uint64_t MakeSimple(Kind kind, uint32_t extended_content) { + return static_cast(kSimpleTag | + (static_cast(kind) << 8)) | + (static_cast(extended_content) << 32); + } + + static constexpr uint64_t MakeDefault() { + return MakeSimple(Kind::kNullType, 0); + } + + static constexpr uint32_t kNoTag = 0; + static constexpr uint32_t kSimpleTag = + 1 << 0; // Indicates the kind is simple and there is no BaseType* held. + static constexpr uint32_t kReservedTag = 1 << 1; + static constexpr uintptr_t kPtrMask = + ~static_cast(kSimpleTag | kReservedTag); + + union { + uint64_t raw_; + +#if defined(ABSL_IS_LITTLE_ENDIAN) + struct { + uint32_t lower_; + uint32_t higher_; + }; +#elif defined(ABSL_IS_BIG_ENDIAN) + struct { + uint32_t higher_; + uint32_t lower_; + }; +#else +#error "Platform is neither big endian nor little endian" +#endif + + uintptr_t ptr_; + }; +}; + +static_assert(sizeof(ValueMetadata) == 8, + "Expected sizeof(ValueMetadata) to be 8"); + +// ValueContent is an untagged union whose contents are determined by the +// accompanying ValueMetadata. +// +// As with ValueMetadata, this class is only meant to be used by cel::Value. +class ValueContent final { + public: + constexpr ValueContent() : raw_(0) {} + + constexpr explicit ValueContent(bool value) : bool_value_(value) {} + + constexpr explicit ValueContent(int64_t value) : int_value_(value) {} + + constexpr explicit ValueContent(uint64_t value) : uint_value_(value) {} + + constexpr explicit ValueContent(double value) : double_value_(value) {} + + explicit ValueContent(const absl::Status& status) { + construct_error_value(status); + } + + constexpr explicit ValueContent(BaseValue* base_value) + : base_value_(base_value) {} + + ValueContent(const ValueContent&) = delete; + + ValueContent(ValueContent&&) = delete; + + ~ValueContent() {} + + ValueContent& operator=(const ValueContent&) = delete; + + ValueContent& operator=(ValueContent&&) = delete; + + constexpr bool bool_value() const { return bool_value_; } + + constexpr int64_t int_value() const { return int_value_; } + + constexpr uint64_t uint_value() const { return uint_value_; } + + constexpr double double_value() const { return double_value_; } + + constexpr void construct_trivial_value(uint64_t value) { raw_ = value; } + + constexpr void destruct_trivial_value() { raw_ = 0; } + + constexpr uint64_t trivial_value() const { return raw_; } + + // Updates this to hold `value`, incrementing the reference count. This is + // used during copies. + void construct_reffed_value(BaseValue* value) { + base_value_ = cel::internal::Ref(value); + } + + // Updates this to hold `value` without incrementing the reference count. This + // is used during moves. + void adopt_reffed_value(BaseValue* value) { base_value_ = value; } + + // Decrement the reference count of the currently held reffed value and clear + // this. + void destruct_reffed_value() { + cel::internal::Unref(base_value_); + base_value_ = nullptr; + } + + // Return the currently held reffed value and reset this, without decrementing + // the reference count. This is used during moves. + BaseValue* release_reffed_value() { + BaseValue* reffed_value = base_value_; + base_value_ = nullptr; + return reffed_value; + } + + constexpr BaseValue* reffed_value() const { return base_value_; } + + void construct_error_value(const absl::Status& status) { + ::new (static_cast(std::addressof(error_value_))) + InlinedStatus(status); + } + + void assign_error_value(const absl::Status& status) { error_value_ = status; } + + void destruct_error_value() { + std::addressof(error_value_)->~InlinedStatus(); + } + + constexpr const absl::Status& error_value() const { + return error_value_.status(); + } + + private: + union { + uint64_t raw_; + + bool bool_value_; + int64_t int_value_; + uint64_t uint_value_; + double double_value_; + InlinedStatus error_value_; + BaseValue* base_value_; + }; +}; + +static_assert(sizeof(ValueContent) == 8, + "Expected sizeof(ValueContent) to be 8"); + +} // namespace base_internal + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_VALUE_H_ diff --git a/base/kind.cc b/base/kind.cc new file mode 100644 index 000000000..f1c207e4b --- /dev/null +++ b/base/kind.cc @@ -0,0 +1,62 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/kind.h" + +namespace cel { + +absl::string_view KindToString(Kind kind) { + switch (kind) { + case Kind::kNullType: + return "null_type"; + case Kind::kDyn: + return "dyn"; + case Kind::kAny: + return "any"; + case Kind::kType: + return "type"; + case Kind::kTypeParam: + return "type_param"; + case Kind::kBool: + return "bool"; + case Kind::kInt: + return "int"; + case Kind::kUint: + return "uint"; + case Kind::kDouble: + return "double"; + case Kind::kString: + return "string"; + case Kind::kBytes: + return "bytes"; + case Kind::kEnum: + return "enum"; + case Kind::kDuration: + return "duration"; + case Kind::kTimestamp: + return "timestamp"; + case Kind::kList: + return "list"; + case Kind::kMap: + return "map"; + case Kind::kStruct: + return "struct"; + case Kind::kOpaque: + return "opaque"; + default: + return "*error*"; + } +} + +} // namespace cel diff --git a/base/kind.h b/base/kind.h new file mode 100644 index 000000000..cb294075e --- /dev/null +++ b/base/kind.h @@ -0,0 +1,48 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_KIND_H_ +#define THIRD_PARTY_CEL_CPP_BASE_KIND_H_ + +#include "absl/strings/string_view.h" + +namespace cel { + +enum class Kind { + kNullType = 0, + kError, + kDyn, + kAny, + kType, + kTypeParam, + kBool, + kInt, + kUint, + kDouble, + kString, + kBytes, + kEnum, + kDuration, + kTimestamp, + kList, + kMap, + kStruct, + kOpaque, +}; + +absl::string_view KindToString(Kind kind); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_KIND_H_ diff --git a/base/kind_test.cc b/base/kind_test.cc new file mode 100644 index 000000000..4069f931d --- /dev/null +++ b/base/kind_test.cc @@ -0,0 +1,49 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/kind.h" + +#include + +#include "internal/testing.h" + +namespace cel { +namespace { + +TEST(Kind, ToString) { + EXPECT_EQ(KindToString(Kind::kError), "*error*"); + EXPECT_EQ(KindToString(Kind::kNullType), "null_type"); + EXPECT_EQ(KindToString(Kind::kDyn), "dyn"); + EXPECT_EQ(KindToString(Kind::kAny), "any"); + EXPECT_EQ(KindToString(Kind::kType), "type"); + EXPECT_EQ(KindToString(Kind::kTypeParam), "type_param"); + EXPECT_EQ(KindToString(Kind::kBool), "bool"); + EXPECT_EQ(KindToString(Kind::kInt), "int"); + EXPECT_EQ(KindToString(Kind::kUint), "uint"); + EXPECT_EQ(KindToString(Kind::kDouble), "double"); + EXPECT_EQ(KindToString(Kind::kString), "string"); + EXPECT_EQ(KindToString(Kind::kBytes), "bytes"); + EXPECT_EQ(KindToString(Kind::kEnum), "enum"); + EXPECT_EQ(KindToString(Kind::kDuration), "duration"); + EXPECT_EQ(KindToString(Kind::kTimestamp), "timestamp"); + EXPECT_EQ(KindToString(Kind::kList), "list"); + EXPECT_EQ(KindToString(Kind::kMap), "map"); + EXPECT_EQ(KindToString(Kind::kStruct), "struct"); + EXPECT_EQ(KindToString(Kind::kOpaque), "opaque"); + EXPECT_EQ(KindToString(static_cast(std::numeric_limits::max())), + "*error*"); +} + +} // namespace +} // namespace cel diff --git a/base/operators.cc b/base/operators.cc new file mode 100644 index 000000000..5dc6975ec --- /dev/null +++ b/base/operators.cc @@ -0,0 +1,170 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/operators.h" + +#include + +#include "absl/base/attributes.h" +#include "absl/base/call_once.h" +#include "absl/base/macros.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" + +// Macro definining all the operators and their properties. +// (1) - The identifier. +// (2) - The display name if applicable, otherwise an empty string. +// (3) - The name. +// (4) - The precedence if applicable, otherwise 0. +// (5) - The arity. +#define CEL_OPERATORS_ENUM(XX) \ + XX(Conditional, "", "_?_:_", 8, 3) \ + XX(LogicalOr, "||", "_||_", 7, 2) \ + XX(LogicalAnd, "&&", "_&&_", 6, 2) \ + XX(Equals, "==", "_==_", 5, 2) \ + XX(NotEquals, "!=", "_!=_", 5, 2) \ + XX(Less, "<", "_<_", 5, 2) \ + XX(LessEquals, "<=", "_<=_", 5, 2) \ + XX(Greater, ">", "_>_", 5, 2) \ + XX(GreaterEquals, ">=", "_>=_", 5, 2) \ + XX(In, "in", "@in", 5, 2) \ + XX(OldIn, "in", "_in_", 5, 2) \ + XX(Add, "+", "_+_", 4, 2) \ + XX(Subtract, "-", "_-_", 4, 2) \ + XX(Multiply, "*", "_*_", 3, 2) \ + XX(Divide, "/", "_/_", 3, 2) \ + XX(Modulo, "%", "_%_", 3, 2) \ + XX(LogicalNot, "!", "!_", 2, 1) \ + XX(Negate, "-", "-_", 2, 1) \ + XX(Index, "", "_[_]", 1, 2) \ + XX(NotStrictlyFalse, "", "@not_strictly_false", 0, 1) \ + XX(OldNotStrictlyFalse, "", "__not_strictly_false__", 0, 1) + +namespace cel { + +namespace { + +ABSL_CONST_INIT absl::once_flag operators_once_flag; +ABSL_CONST_INIT const absl::flat_hash_map* + operators_by_name = nullptr; +ABSL_CONST_INIT const absl::flat_hash_map* + operators_by_display_name = nullptr; +ABSL_CONST_INIT const absl::flat_hash_map* + unary_operators = nullptr; +ABSL_CONST_INIT const absl::flat_hash_map* + binary_operators = nullptr; + +void InitializeOperators() { + ABSL_ASSERT(operators_by_name == nullptr); + ABSL_ASSERT(operators_by_display_name == nullptr); + ABSL_ASSERT(unary_operators == nullptr); + ABSL_ASSERT(binary_operators == nullptr); + auto operators_by_name_ptr = + std::make_unique>(); + auto operators_by_display_name_ptr = + std::make_unique>(); + auto unary_operators_ptr = + std::make_unique>(); + auto binary_operators_ptr = + std::make_unique>(); + +#define CEL_DEFINE_OPERATORS_BY_NAME(id, symbol, name, precedence, arity) \ + if constexpr (!absl::string_view(name).empty()) { \ + operators_by_name_ptr->insert({name, Operator::id()}); \ + } + CEL_OPERATORS_ENUM(CEL_DEFINE_OPERATORS_BY_NAME) +#undef CEL_DEFINE_OPERATORS_BY_NAME + +#define CEL_DEFINE_OPERATORS_BY_SYMBOL(id, symbol, name, precedence, arity) \ + if constexpr (!absl::string_view(symbol).empty()) { \ + operators_by_display_name_ptr->insert({symbol, Operator::id()}); \ + } + CEL_OPERATORS_ENUM(CEL_DEFINE_OPERATORS_BY_SYMBOL) +#undef CEL_DEFINE_OPERATORS_BY_SYMBOL + +#define CEL_DEFINE_UNARY_OPERATORS(id, symbol, name, precedence, arity) \ + if constexpr (!absl::string_view(symbol).empty() && arity == 1) { \ + unary_operators_ptr->insert({symbol, Operator::id()}); \ + } + CEL_OPERATORS_ENUM(CEL_DEFINE_UNARY_OPERATORS) +#undef CEL_DEFINE_UNARY_OPERATORS + +#define CEL_DEFINE_BINARY_OPERATORS(id, symbol, name, precedence, arity) \ + if constexpr (!absl::string_view(symbol).empty() && arity == 2) { \ + binary_operators_ptr->insert({symbol, Operator::id()}); \ + } + CEL_OPERATORS_ENUM(CEL_DEFINE_BINARY_OPERATORS) +#undef CEL_DEFINE_BINARY_OPERATORS + + operators_by_name = operators_by_name_ptr.release(); + operators_by_display_name = operators_by_display_name_ptr.release(); + unary_operators = unary_operators_ptr.release(); + binary_operators = binary_operators_ptr.release(); +} + +#define CEL_DEFINE_OPERATOR_DATA(id, symbol, name, precedence, arity) \ + ABSL_CONST_INIT constexpr base_internal::OperatorData k##id##Data( \ + OperatorId::k##id, name, symbol, precedence, arity); +CEL_OPERATORS_ENUM(CEL_DEFINE_OPERATOR_DATA) +#undef CEL_DEFINE_OPERATOR_DATA + +} // namespace + +#define CEL_DEFINE_OPERATOR(id, symbol, name, precedence, arity) \ + Operator Operator::id() { return Operator(std::addressof(k##id##Data)); } +CEL_OPERATORS_ENUM(CEL_DEFINE_OPERATOR) +#undef CEL_DEFINE_OPERATOR + +absl::StatusOr Operator::FindByName(absl::string_view input) { + absl::call_once(operators_once_flag, InitializeOperators); + auto it = operators_by_name->find(input); + if (it != operators_by_name->end()) { + return it->second; + } + return absl::NotFoundError(absl::StrCat("No such operator: ", input)); +} + +absl::StatusOr Operator::FindByDisplayName(absl::string_view input) { + absl::call_once(operators_once_flag, InitializeOperators); + auto it = operators_by_display_name->find(input); + if (it != operators_by_display_name->end()) { + return it->second; + } + return absl::NotFoundError(absl::StrCat("No such operator: ", input)); +} + +absl::StatusOr Operator::FindUnaryByDisplayName( + absl::string_view input) { + absl::call_once(operators_once_flag, InitializeOperators); + auto it = unary_operators->find(input); + if (it != unary_operators->end()) { + return it->second; + } + return absl::NotFoundError(absl::StrCat("No such unary operator: ", input)); +} + +absl::StatusOr Operator::FindBinaryByDisplayName( + absl::string_view input) { + absl::call_once(operators_once_flag, InitializeOperators); + auto it = binary_operators->find(input); + if (it != binary_operators->end()) { + return it->second; + } + return absl::NotFoundError(absl::StrCat("No such binary operator: ", input)); +} + +} // namespace cel + +#undef CEL_OPERATORS_ENUM diff --git a/base/operators.h b/base/operators.h new file mode 100644 index 000000000..7cd40d911 --- /dev/null +++ b/base/operators.h @@ -0,0 +1,151 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_OPERATORS_H_ +#define THIRD_PARTY_CEL_CPP_BASE_OPERATORS_H_ + +#include + +#include "absl/base/attributes.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "base/internal/operators.h" + +namespace cel { + +enum class OperatorId { + kConditional = 1, + kLogicalAnd, + kLogicalOr, + kLogicalNot, + kEquals, + kNotEquals, + kLess, + kLessEquals, + kGreater, + kGreaterEquals, + kAdd, + kSubtract, + kMultiply, + kDivide, + kModulo, + kNegate, + kIndex, + kIn, + kNotStrictlyFalse, + kOldIn, + kOldNotStrictlyFalse, +}; + +class Operator final { + public: + ABSL_ATTRIBUTE_PURE_FUNCTION static Operator Conditional(); + ABSL_ATTRIBUTE_PURE_FUNCTION static Operator LogicalAnd(); + ABSL_ATTRIBUTE_PURE_FUNCTION static Operator LogicalOr(); + ABSL_ATTRIBUTE_PURE_FUNCTION static Operator LogicalNot(); + ABSL_ATTRIBUTE_PURE_FUNCTION static Operator Equals(); + ABSL_ATTRIBUTE_PURE_FUNCTION static Operator NotEquals(); + ABSL_ATTRIBUTE_PURE_FUNCTION static Operator Less(); + ABSL_ATTRIBUTE_PURE_FUNCTION static Operator LessEquals(); + ABSL_ATTRIBUTE_PURE_FUNCTION static Operator Greater(); + ABSL_ATTRIBUTE_PURE_FUNCTION static Operator GreaterEquals(); + ABSL_ATTRIBUTE_PURE_FUNCTION static Operator Add(); + ABSL_ATTRIBUTE_PURE_FUNCTION static Operator Subtract(); + ABSL_ATTRIBUTE_PURE_FUNCTION static Operator Multiply(); + ABSL_ATTRIBUTE_PURE_FUNCTION static Operator Divide(); + ABSL_ATTRIBUTE_PURE_FUNCTION static Operator Modulo(); + ABSL_ATTRIBUTE_PURE_FUNCTION static Operator Negate(); + ABSL_ATTRIBUTE_PURE_FUNCTION static Operator Index(); + ABSL_ATTRIBUTE_PURE_FUNCTION static Operator In(); + ABSL_ATTRIBUTE_PURE_FUNCTION static Operator NotStrictlyFalse(); + ABSL_ATTRIBUTE_PURE_FUNCTION static Operator OldIn(); + ABSL_ATTRIBUTE_PURE_FUNCTION static Operator OldNotStrictlyFalse(); + + static absl::StatusOr FindByName(absl::string_view input); + + static absl::StatusOr FindByDisplayName(absl::string_view input); + + static absl::StatusOr FindUnaryByDisplayName( + absl::string_view input); + + static absl::StatusOr FindBinaryByDisplayName( + absl::string_view input); + + Operator() = delete; + + Operator(const Operator&) = default; + + Operator(Operator&&) = default; + + Operator& operator=(const Operator&) = default; + + Operator& operator=(Operator&&) = default; + + constexpr OperatorId id() const { return data_->id; } + + // Returns the name of the operator. This is the managed representation of the + // operator, for example "_&&_". + constexpr absl::string_view name() const { return data_->name; } + + // Returns the source text representation of the operator. This is the + // unmanaged text representation of the operator, for example "&&". + // + // Note that this will be empty for operators like Conditional() and Index(). + constexpr absl::string_view display_name() const { + return data_->display_name; + } + + constexpr int precedence() const { return data_->precedence; } + + constexpr int arity() const { return data_->arity; } + + private: + constexpr explicit Operator(const base_internal::OperatorData* data) + : data_(data) {} + + const base_internal::OperatorData* data_; +}; + +constexpr bool operator==(const Operator& lhs, const Operator& rhs) { + return lhs.id() == rhs.id(); +} + +constexpr bool operator==(OperatorId lhs, const Operator& rhs) { + return lhs == rhs.id(); +} + +constexpr bool operator==(const Operator& lhs, OperatorId rhs) { + return operator==(rhs, lhs); +} + +constexpr bool operator!=(const Operator& lhs, const Operator& rhs) { + return !operator==(lhs, rhs); +} + +constexpr bool operator!=(OperatorId lhs, const Operator& rhs) { + return !operator==(lhs, rhs); +} + +constexpr bool operator!=(const Operator& lhs, OperatorId rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, const Operator& op) { + return H::combine(std::move(state), op.id()); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_OPERATORS_H_ diff --git a/base/operators_test.cc b/base/operators_test.cc new file mode 100644 index 000000000..b86743e7e --- /dev/null +++ b/base/operators_test.cc @@ -0,0 +1,267 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/operators.h" + +#include + +#include "absl/hash/hash_testing.h" +#include "absl/status/status.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using cel::internal::StatusIs; + +TEST(Operator, TypeTraits) { + EXPECT_FALSE(std::is_default_constructible_v); + EXPECT_TRUE(std::is_copy_constructible_v); + EXPECT_TRUE(std::is_move_constructible_v); + EXPECT_TRUE(std::is_copy_assignable_v); + EXPECT_TRUE(std::is_move_assignable_v); +} + +TEST(Operator, Conditional) { + EXPECT_EQ(Operator::Conditional().id(), OperatorId::kConditional); + EXPECT_EQ(Operator::Conditional().name(), "_?_:_"); + EXPECT_EQ(Operator::Conditional().display_name(), ""); + EXPECT_EQ(Operator::Conditional().precedence(), 8); + EXPECT_EQ(Operator::Conditional().arity(), 3); +} + +TEST(Operator, LogicalAnd) { + EXPECT_EQ(Operator::LogicalAnd().id(), OperatorId::kLogicalAnd); + EXPECT_EQ(Operator::LogicalAnd().name(), "_&&_"); + EXPECT_EQ(Operator::LogicalAnd().display_name(), "&&"); + EXPECT_EQ(Operator::LogicalAnd().precedence(), 6); + EXPECT_EQ(Operator::LogicalAnd().arity(), 2); +} + +TEST(Operator, LogicalOr) { + EXPECT_EQ(Operator::LogicalOr().id(), OperatorId::kLogicalOr); + EXPECT_EQ(Operator::LogicalOr().name(), "_||_"); + EXPECT_EQ(Operator::LogicalOr().display_name(), "||"); + EXPECT_EQ(Operator::LogicalOr().precedence(), 7); + EXPECT_EQ(Operator::LogicalOr().arity(), 2); +} + +TEST(Operator, LogicalNot) { + EXPECT_EQ(Operator::LogicalNot().id(), OperatorId::kLogicalNot); + EXPECT_EQ(Operator::LogicalNot().name(), "!_"); + EXPECT_EQ(Operator::LogicalNot().display_name(), "!"); + EXPECT_EQ(Operator::LogicalNot().precedence(), 2); + EXPECT_EQ(Operator::LogicalNot().arity(), 1); +} + +TEST(Operator, Equals) { + EXPECT_EQ(Operator::Equals().id(), OperatorId::kEquals); + EXPECT_EQ(Operator::Equals().name(), "_==_"); + EXPECT_EQ(Operator::Equals().display_name(), "=="); + EXPECT_EQ(Operator::Equals().precedence(), 5); + EXPECT_EQ(Operator::Equals().arity(), 2); +} + +TEST(Operator, NotEquals) { + EXPECT_EQ(Operator::NotEquals().id(), OperatorId::kNotEquals); + EXPECT_EQ(Operator::NotEquals().name(), "_!=_"); + EXPECT_EQ(Operator::NotEquals().display_name(), "!="); + EXPECT_EQ(Operator::NotEquals().precedence(), 5); + EXPECT_EQ(Operator::NotEquals().arity(), 2); +} + +TEST(Operator, Less) { + EXPECT_EQ(Operator::Less().id(), OperatorId::kLess); + EXPECT_EQ(Operator::Less().name(), "_<_"); + EXPECT_EQ(Operator::Less().display_name(), "<"); + EXPECT_EQ(Operator::Less().precedence(), 5); + EXPECT_EQ(Operator::Less().arity(), 2); +} + +TEST(Operator, LessEquals) { + EXPECT_EQ(Operator::LessEquals().id(), OperatorId::kLessEquals); + EXPECT_EQ(Operator::LessEquals().name(), "_<=_"); + EXPECT_EQ(Operator::LessEquals().display_name(), "<="); + EXPECT_EQ(Operator::LessEquals().precedence(), 5); + EXPECT_EQ(Operator::LessEquals().arity(), 2); +} + +TEST(Operator, Greater) { + EXPECT_EQ(Operator::Greater().id(), OperatorId::kGreater); + EXPECT_EQ(Operator::Greater().name(), "_>_"); + EXPECT_EQ(Operator::Greater().display_name(), ">"); + EXPECT_EQ(Operator::Greater().precedence(), 5); + EXPECT_EQ(Operator::Greater().arity(), 2); +} + +TEST(Operator, GreaterEquals) { + EXPECT_EQ(Operator::GreaterEquals().id(), OperatorId::kGreaterEquals); + EXPECT_EQ(Operator::GreaterEquals().name(), "_>=_"); + EXPECT_EQ(Operator::GreaterEquals().display_name(), ">="); + EXPECT_EQ(Operator::GreaterEquals().precedence(), 5); + EXPECT_EQ(Operator::GreaterEquals().arity(), 2); +} + +TEST(Operator, Add) { + EXPECT_EQ(Operator::Add().id(), OperatorId::kAdd); + EXPECT_EQ(Operator::Add().name(), "_+_"); + EXPECT_EQ(Operator::Add().display_name(), "+"); + EXPECT_EQ(Operator::Add().precedence(), 4); + EXPECT_EQ(Operator::Add().arity(), 2); +} + +TEST(Operator, Subtract) { + EXPECT_EQ(Operator::Subtract().id(), OperatorId::kSubtract); + EXPECT_EQ(Operator::Subtract().name(), "_-_"); + EXPECT_EQ(Operator::Subtract().display_name(), "-"); + EXPECT_EQ(Operator::Subtract().precedence(), 4); + EXPECT_EQ(Operator::Subtract().arity(), 2); +} + +TEST(Operator, Multiply) { + EXPECT_EQ(Operator::Multiply().id(), OperatorId::kMultiply); + EXPECT_EQ(Operator::Multiply().name(), "_*_"); + EXPECT_EQ(Operator::Multiply().display_name(), "*"); + EXPECT_EQ(Operator::Multiply().precedence(), 3); + EXPECT_EQ(Operator::Multiply().arity(), 2); +} + +TEST(Operator, Divide) { + EXPECT_EQ(Operator::Divide().id(), OperatorId::kDivide); + EXPECT_EQ(Operator::Divide().name(), "_/_"); + EXPECT_EQ(Operator::Divide().display_name(), "/"); + EXPECT_EQ(Operator::Divide().precedence(), 3); + EXPECT_EQ(Operator::Divide().arity(), 2); +} + +TEST(Operator, Modulo) { + EXPECT_EQ(Operator::Modulo().id(), OperatorId::kModulo); + EXPECT_EQ(Operator::Modulo().name(), "_%_"); + EXPECT_EQ(Operator::Modulo().display_name(), "%"); + EXPECT_EQ(Operator::Modulo().precedence(), 3); + EXPECT_EQ(Operator::Modulo().arity(), 2); +} + +TEST(Operator, Negate) { + EXPECT_EQ(Operator::Negate().id(), OperatorId::kNegate); + EXPECT_EQ(Operator::Negate().name(), "-_"); + EXPECT_EQ(Operator::Negate().display_name(), "-"); + EXPECT_EQ(Operator::Negate().precedence(), 2); + EXPECT_EQ(Operator::Negate().arity(), 1); +} + +TEST(Operator, Index) { + EXPECT_EQ(Operator::Index().id(), OperatorId::kIndex); + EXPECT_EQ(Operator::Index().name(), "_[_]"); + EXPECT_EQ(Operator::Index().display_name(), ""); + EXPECT_EQ(Operator::Index().precedence(), 1); + EXPECT_EQ(Operator::Index().arity(), 2); +} + +TEST(Operator, In) { + EXPECT_EQ(Operator::In().id(), OperatorId::kIn); + EXPECT_EQ(Operator::In().name(), "@in"); + EXPECT_EQ(Operator::In().display_name(), "in"); + EXPECT_EQ(Operator::In().precedence(), 5); + EXPECT_EQ(Operator::In().arity(), 2); +} + +TEST(Operator, NotStrictlyFalse) { + EXPECT_EQ(Operator::NotStrictlyFalse().id(), OperatorId::kNotStrictlyFalse); + EXPECT_EQ(Operator::NotStrictlyFalse().name(), "@not_strictly_false"); + EXPECT_EQ(Operator::NotStrictlyFalse().display_name(), ""); + EXPECT_EQ(Operator::NotStrictlyFalse().precedence(), 0); + EXPECT_EQ(Operator::NotStrictlyFalse().arity(), 1); +} + +TEST(Operator, OldIn) { + EXPECT_EQ(Operator::OldIn().id(), OperatorId::kOldIn); + EXPECT_EQ(Operator::OldIn().name(), "_in_"); + EXPECT_EQ(Operator::OldIn().display_name(), "in"); + EXPECT_EQ(Operator::OldIn().precedence(), 5); + EXPECT_EQ(Operator::OldIn().arity(), 2); +} + +TEST(Operator, OldNotStrictlyFalse) { + EXPECT_EQ(Operator::OldNotStrictlyFalse().id(), + OperatorId::kOldNotStrictlyFalse); + EXPECT_EQ(Operator::OldNotStrictlyFalse().name(), "__not_strictly_false__"); + EXPECT_EQ(Operator::OldNotStrictlyFalse().display_name(), ""); + EXPECT_EQ(Operator::OldNotStrictlyFalse().precedence(), 0); + EXPECT_EQ(Operator::OldNotStrictlyFalse().arity(), 1); +} + +TEST(Operator, FindByName) { + auto status_or_operator = Operator::FindByName("@in"); + EXPECT_OK(status_or_operator); + EXPECT_EQ(status_or_operator.value(), Operator::In()); + status_or_operator = Operator::FindByName("_in_"); + EXPECT_OK(status_or_operator); + EXPECT_EQ(status_or_operator.value(), Operator::OldIn()); + status_or_operator = Operator::FindByName("in"); + EXPECT_THAT(status_or_operator, StatusIs(absl::StatusCode::kNotFound)); +} + +TEST(Operator, FindByDisplayName) { + auto status_or_operator = Operator::FindByDisplayName("-"); + EXPECT_OK(status_or_operator); + EXPECT_EQ(status_or_operator.value(), Operator::Subtract()); + status_or_operator = Operator::FindByDisplayName("@in"); + EXPECT_THAT(status_or_operator, StatusIs(absl::StatusCode::kNotFound)); +} + +TEST(Operator, FindUnaryByDisplayName) { + auto status_or_operator = Operator::FindUnaryByDisplayName("-"); + EXPECT_OK(status_or_operator); + EXPECT_EQ(status_or_operator.value(), Operator::Negate()); + status_or_operator = Operator::FindUnaryByDisplayName("&&"); + EXPECT_THAT(status_or_operator, StatusIs(absl::StatusCode::kNotFound)); +} + +TEST(Operator, FindBinaryByDisplayName) { + auto status_or_operator = Operator::FindBinaryByDisplayName("-"); + EXPECT_OK(status_or_operator); + EXPECT_EQ(status_or_operator.value(), Operator::Subtract()); + status_or_operator = Operator::FindBinaryByDisplayName("!"); + EXPECT_THAT(status_or_operator, StatusIs(absl::StatusCode::kNotFound)); +} + +TEST(Type, SupportsAbslHash) { + EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly({ + Operator::Conditional(), + Operator::LogicalAnd(), + Operator::LogicalOr(), + Operator::LogicalNot(), + Operator::Equals(), + Operator::NotEquals(), + Operator::Less(), + Operator::LessEquals(), + Operator::Greater(), + Operator::GreaterEquals(), + Operator::Add(), + Operator::Subtract(), + Operator::Multiply(), + Operator::Divide(), + Operator::Modulo(), + Operator::Negate(), + Operator::Index(), + Operator::In(), + Operator::NotStrictlyFalse(), + Operator::OldIn(), + Operator::OldNotStrictlyFalse(), + })); +} + +} // namespace +} // namespace cel diff --git a/base/type.cc b/base/type.cc new file mode 100644 index 000000000..c1e0851c1 --- /dev/null +++ b/base/type.cc @@ -0,0 +1,260 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/type.h" + +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/call_once.h" +#include "absl/base/macros.h" +#include "absl/base/optimization.h" +#include "base/internal/type.h" +#include "internal/reference_counted.h" + +namespace cel { + +namespace base_internal { + +// Implementation of BaseType for simple types. See SimpleTypes below for the +// types being implemented. +class SimpleType final : public BaseType { + public: + constexpr SimpleType(Kind kind, absl::string_view name) + : BaseType(), name_(name), kind_(kind) {} + + ~SimpleType() override { + // Simple types should live for the lifetime of the process, so destructing + // them is definetly a bug. + std::abort(); + } + + Kind kind() const override { return kind_; } + + absl::string_view name() const override { return name_; } + + absl::Span parameters() const override { return {}; } + + protected: + void HashValue(absl::HashState state) const override { + // cel::Type already adds both kind and name to the hash state, nothing else + // for us to do. + static_cast(state); + } + + bool Equals(const cel::Type& other) const override { + // cel::Type already checks that the kind and name are equivalent, so at + // this point the types are the same. + static_cast(other); + return true; + } + + private: + const absl::string_view name_; + const Kind kind_; +}; + +} // namespace base_internal + +namespace { + +struct SimpleTypes final { + constexpr SimpleTypes() = default; + + SimpleTypes(const SimpleTypes&) = delete; + + SimpleTypes(SimpleTypes&&) = delete; + + ~SimpleTypes() = default; + + SimpleTypes& operator=(const SimpleTypes&) = delete; + + SimpleTypes& operator=(SimpleTypes&&) = delete; + + Type error_type; + Type null_type; + Type dyn_type; + Type any_type; + Type bool_type; + Type int_type; + Type uint_type; + Type double_type; + Type string_type; + Type bytes_type; + Type duration_type; + Type timestamp_type; +}; + +ABSL_CONST_INIT absl::once_flag simple_types_once; +ABSL_CONST_INIT SimpleTypes* simple_types = nullptr; + +} // namespace + +void Type::Initialize() { + absl::call_once(simple_types_once, []() { + ABSL_ASSERT(simple_types == nullptr); + simple_types = new SimpleTypes(); + simple_types->error_type = + Type(new base_internal::SimpleType(Kind::kError, "*error*")); + simple_types->dyn_type = + Type(new base_internal::SimpleType(Kind::kDyn, "dyn")); + simple_types->any_type = + Type(new base_internal::SimpleType(Kind::kAny, "google.protobuf.Any")); + simple_types->bool_type = + Type(new base_internal::SimpleType(Kind::kBool, "bool")); + simple_types->int_type = + Type(new base_internal::SimpleType(Kind::kInt, "int")); + simple_types->uint_type = + Type(new base_internal::SimpleType(Kind::kUint, "uint")); + simple_types->double_type = + Type(new base_internal::SimpleType(Kind::kDouble, "double")); + simple_types->string_type = + Type(new base_internal::SimpleType(Kind::kString, "string")); + simple_types->bytes_type = + Type(new base_internal::SimpleType(Kind::kBytes, "bytes")); + simple_types->duration_type = Type(new base_internal::SimpleType( + Kind::kDuration, "google.protobuf.Duration")); + simple_types->timestamp_type = Type(new base_internal::SimpleType( + Kind::kTimestamp, "google.protobuf.Timestamp")); + }); +} + +const Type& Type::Simple(Kind kind) { + switch (kind) { + case Kind::kNullType: + return Null(); + case Kind::kError: + return Error(); + case Kind::kBool: + return Bool(); + case Kind::kInt: + return Int(); + case Kind::kUint: + return Uint(); + case Kind::kDouble: + return Double(); + case Kind::kDuration: + return Duration(); + case Kind::kTimestamp: + return Timestamp(); + case Kind::kString: + return String(); + case Kind::kBytes: + return Bytes(); + default: + // We can only get here via memory corruption in cel::Value via + // cel::base_internal::ValueMetadata, as the the kinds with simple tags + // are all covered here. + std::abort(); + } +} + +const Type& Type::Null() { + Initialize(); + return simple_types->null_type; +} + +const Type& Type::Error() { + Initialize(); + return simple_types->error_type; +} + +const Type& Type::Dyn() { + Initialize(); + return simple_types->dyn_type; +} + +const Type& Type::Any() { + Initialize(); + return simple_types->any_type; +} + +const Type& Type::Bool() { + Initialize(); + return simple_types->bool_type; +} + +const Type& Type::Int() { + Initialize(); + return simple_types->int_type; +} + +const Type& Type::Uint() { + Initialize(); + return simple_types->uint_type; +} + +const Type& Type::Double() { + Initialize(); + return simple_types->double_type; +} + +const Type& Type::String() { + Initialize(); + return simple_types->string_type; +} + +const Type& Type::Bytes() { + Initialize(); + return simple_types->bytes_type; +} + +const Type& Type::Duration() { + Initialize(); + return simple_types->duration_type; +} + +const Type& Type::Timestamp() { + Initialize(); + return simple_types->timestamp_type; +} + +Type::Type(const Type& other) : impl_(other.impl_) { internal::Ref(impl_); } + +Type::Type(Type&& other) : impl_(other.impl_) { other.impl_ = nullptr; } + +Type& Type::operator=(const Type& other) { + if (ABSL_PREDICT_TRUE(this != &other)) { + internal::Ref(other.impl_); + internal::Unref(impl_); + impl_ = other.impl_; + } + return *this; +} + +Type& Type::operator=(Type&& other) { + if (ABSL_PREDICT_TRUE(this != &other)) { + internal::Unref(impl_); + impl_ = other.impl_; + other.impl_ = nullptr; + } + return *this; +} + +bool Type::Equals(const Type& other) const { + return impl_ == other.impl_ || + (kind() == other.kind() && name() == other.name() && + // It should not be possible to reach here if impl_ is nullptr. + impl_->Equals(other)); +} + +void Type::HashValue(absl::HashState state) const { + state = absl::HashState::combine(std::move(state), kind(), name()); + if (impl_) { + impl_->HashValue(std::move(state)); + } +} + +} // namespace cel diff --git a/base/type.h b/base/type.h new file mode 100644 index 000000000..84d201536 --- /dev/null +++ b/base/type.h @@ -0,0 +1,159 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_TYPE_H_ + +#include + +#include "absl/base/attributes.h" +#include "absl/base/macros.h" +#include "absl/hash/hash.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "base/internal/type.h" +#include "base/kind.h" +#include "internal/reference_counted.h" + +namespace cel { + +class Value; + +// A representation of a CEL type that enables reflection, for static analysis, +// and introspection, for program construction, of types. +class Type final { + public: + // Returns the null type. + ABSL_ATTRIBUTE_PURE_FUNCTION static const Type& Null(); + + // Returns the error type. + ABSL_ATTRIBUTE_PURE_FUNCTION static const Type& Error(); + + // Returns the dynamic type. + ABSL_ATTRIBUTE_PURE_FUNCTION static const Type& Dyn(); + + // Returns the any type. + ABSL_ATTRIBUTE_PURE_FUNCTION static const Type& Any(); + + // Returns the bool type. + ABSL_ATTRIBUTE_PURE_FUNCTION static const Type& Bool(); + + // Returns the int type. + ABSL_ATTRIBUTE_PURE_FUNCTION static const Type& Int(); + + // Returns the uint type. + ABSL_ATTRIBUTE_PURE_FUNCTION static const Type& Uint(); + + // Returns the double type. + ABSL_ATTRIBUTE_PURE_FUNCTION static const Type& Double(); + + // Returns the string type. + ABSL_ATTRIBUTE_PURE_FUNCTION static const Type& String(); + + // Returns the bytes type. + ABSL_ATTRIBUTE_PURE_FUNCTION static const Type& Bytes(); + + // Returns the duration type. + ABSL_ATTRIBUTE_PURE_FUNCTION static const Type& Duration(); + + // Returns the timestamp type. + ABSL_ATTRIBUTE_PURE_FUNCTION static const Type& Timestamp(); + + // Equivalent to `Type::Null()`. + constexpr Type() : Type(nullptr) {} + + Type(const Type& other); + + Type(Type&& other); + + ~Type() { internal::Unref(impl_); } + + Type& operator=(const Type& other); + + Type& operator=(Type&& other); + + // Returns the type kind. + Kind kind() const { return impl_ ? impl_->kind() : Kind::kNullType; } + + // Returns the type name, i.e. "list". + absl::string_view name() const { return impl_ ? impl_->name() : "null_type"; } + + // Returns the type parameters of the type, i.e. key and value type of map. + absl::Span parameters() const { + return impl_ ? impl_->parameters() : absl::Span(); + } + + bool IsNull() const { return kind() == Kind::kNullType; } + + bool IsError() const { return kind() == Kind::kError; } + + bool IsDyn() const { return kind() == Kind::kDyn; } + + bool IsAny() const { return kind() == Kind::kAny; } + + bool IsBool() const { return kind() == Kind::kBool; } + + bool IsInt() const { return kind() == Kind::kInt; } + + bool IsUint() const { return kind() == Kind::kUint; } + + bool IsDouble() const { return kind() == Kind::kDouble; } + + bool IsString() const { return kind() == Kind::kString; } + + bool IsBytes() const { return kind() == Kind::kBytes; } + + bool IsDuration() const { return kind() == Kind::kDuration; } + + bool IsTimestamp() const { return kind() == Kind::kTimestamp; } + + template + friend H AbslHashValue(H state, const Type& type) { + type.HashValue(absl::HashState::Create(&state)); + return std::move(state); + } + + friend void swap(Type& lhs, Type& rhs) { + const base_internal::BaseType* impl = lhs.impl_; + lhs.impl_ = rhs.impl_; + rhs.impl_ = impl; + } + + friend bool operator==(const Type& lhs, const Type& rhs) { + return lhs.Equals(rhs); + } + + friend bool operator!=(const Type& lhs, const Type& rhs) { + return !operator==(lhs, rhs); + } + + private: + friend class Value; + + static void Initialize(); + + static const Type& Simple(Kind kind); + + constexpr explicit Type(const base_internal::BaseType* impl) : impl_(impl) {} + + bool Equals(const Type& other) const; + + void HashValue(absl::HashState state) const; + + const base_internal::BaseType* impl_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_TYPE_H_ diff --git a/base/type_test.cc b/base/type_test.cc new file mode 100644 index 000000000..d8df5dae0 --- /dev/null +++ b/base/type_test.cc @@ -0,0 +1,307 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/type.h" + +#include +#include + +#include "absl/hash/hash_testing.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using testing::SizeIs; + +template +constexpr void IS_INITIALIZED(T&) {} + +TEST(Type, TypeTraits) { + EXPECT_TRUE(std::is_default_constructible_v); + EXPECT_TRUE(std::is_copy_constructible_v); + EXPECT_TRUE(std::is_move_constructible_v); + EXPECT_TRUE(std::is_copy_assignable_v); + EXPECT_TRUE(std::is_move_assignable_v); + EXPECT_TRUE(std::is_swappable_v); +} + +TEST(Type, DefaultConstructor) { + Type type; + EXPECT_EQ(type, Type::Null()); +} + +TEST(Type, CopyConstructor) { + Type type(Type::Int()); + EXPECT_EQ(type, Type::Int()); +} + +TEST(Type, MoveConstructor) { + Type from(Type::Int()); + Type to(std::move(from)); + IS_INITIALIZED(from); + EXPECT_EQ(from, Type::Null()); + EXPECT_EQ(to, Type::Int()); +} + +TEST(Type, CopyAssignment) { + Type type; + type = Type::Int(); + EXPECT_EQ(type, Type::Int()); +} + +TEST(Type, MoveAssignment) { + Type from(Type::Int()); + Type to; + to = std::move(from); + IS_INITIALIZED(from); + EXPECT_EQ(from, Type::Null()); + EXPECT_EQ(to, Type::Int()); +} + +TEST(Type, Swap) { + Type lhs = Type::Int(); + Type rhs = Type::Uint(); + std::swap(lhs, rhs); + EXPECT_EQ(lhs, Type::Uint()); + EXPECT_EQ(rhs, Type::Int()); +} + +// The below tests could be made parameterized but doing so requires the +// extension for struct member initiation by name for it to be worth it. That +// feature is not available in C++17. + +TEST(Type, Null) { + EXPECT_EQ(Type::Null().kind(), Kind::kNullType); + EXPECT_EQ(Type::Null().name(), "null_type"); + EXPECT_THAT(Type::Null().parameters(), SizeIs(0)); + EXPECT_TRUE(Type::Null().IsNull()); + EXPECT_FALSE(Type::Null().IsDyn()); + EXPECT_FALSE(Type::Null().IsAny()); + EXPECT_FALSE(Type::Null().IsBool()); + EXPECT_FALSE(Type::Null().IsInt()); + EXPECT_FALSE(Type::Null().IsUint()); + EXPECT_FALSE(Type::Null().IsDouble()); + EXPECT_FALSE(Type::Null().IsString()); + EXPECT_FALSE(Type::Null().IsBytes()); + EXPECT_FALSE(Type::Null().IsDuration()); + EXPECT_FALSE(Type::Null().IsTimestamp()); +} + +TEST(Type, Error) { + EXPECT_EQ(Type::Error().kind(), Kind::kError); + EXPECT_EQ(Type::Error().name(), "*error*"); + EXPECT_THAT(Type::Error().parameters(), SizeIs(0)); + EXPECT_FALSE(Type::Error().IsNull()); + EXPECT_FALSE(Type::Error().IsDyn()); + EXPECT_FALSE(Type::Error().IsAny()); + EXPECT_FALSE(Type::Error().IsBool()); + EXPECT_FALSE(Type::Error().IsInt()); + EXPECT_FALSE(Type::Error().IsUint()); + EXPECT_FALSE(Type::Error().IsDouble()); + EXPECT_FALSE(Type::Error().IsString()); + EXPECT_FALSE(Type::Error().IsBytes()); + EXPECT_FALSE(Type::Error().IsDuration()); + EXPECT_FALSE(Type::Error().IsTimestamp()); +} + +TEST(Type, Dyn) { + EXPECT_EQ(Type::Dyn().kind(), Kind::kDyn); + EXPECT_EQ(Type::Dyn().name(), "dyn"); + EXPECT_THAT(Type::Dyn().parameters(), SizeIs(0)); + EXPECT_FALSE(Type::Dyn().IsNull()); + EXPECT_TRUE(Type::Dyn().IsDyn()); + EXPECT_FALSE(Type::Dyn().IsAny()); + EXPECT_FALSE(Type::Dyn().IsBool()); + EXPECT_FALSE(Type::Dyn().IsInt()); + EXPECT_FALSE(Type::Dyn().IsUint()); + EXPECT_FALSE(Type::Dyn().IsDouble()); + EXPECT_FALSE(Type::Dyn().IsString()); + EXPECT_FALSE(Type::Dyn().IsBytes()); + EXPECT_FALSE(Type::Dyn().IsDuration()); + EXPECT_FALSE(Type::Dyn().IsTimestamp()); +} + +TEST(Type, Any) { + EXPECT_EQ(Type::Any().kind(), Kind::kAny); + EXPECT_EQ(Type::Any().name(), "google.protobuf.Any"); + EXPECT_THAT(Type::Any().parameters(), SizeIs(0)); + EXPECT_FALSE(Type::Any().IsNull()); + EXPECT_FALSE(Type::Any().IsDyn()); + EXPECT_TRUE(Type::Any().IsAny()); + EXPECT_FALSE(Type::Any().IsBool()); + EXPECT_FALSE(Type::Any().IsInt()); + EXPECT_FALSE(Type::Any().IsUint()); + EXPECT_FALSE(Type::Any().IsDouble()); + EXPECT_FALSE(Type::Any().IsString()); + EXPECT_FALSE(Type::Any().IsBytes()); + EXPECT_FALSE(Type::Any().IsDuration()); + EXPECT_FALSE(Type::Any().IsTimestamp()); +} + +TEST(Type, Bool) { + EXPECT_EQ(Type::Bool().kind(), Kind::kBool); + EXPECT_EQ(Type::Bool().name(), "bool"); + EXPECT_THAT(Type::Bool().parameters(), SizeIs(0)); + EXPECT_FALSE(Type::Bool().IsNull()); + EXPECT_FALSE(Type::Bool().IsDyn()); + EXPECT_FALSE(Type::Bool().IsAny()); + EXPECT_TRUE(Type::Bool().IsBool()); + EXPECT_FALSE(Type::Bool().IsInt()); + EXPECT_FALSE(Type::Bool().IsUint()); + EXPECT_FALSE(Type::Bool().IsDouble()); + EXPECT_FALSE(Type::Bool().IsString()); + EXPECT_FALSE(Type::Bool().IsBytes()); + EXPECT_FALSE(Type::Bool().IsDuration()); + EXPECT_FALSE(Type::Bool().IsTimestamp()); +} + +TEST(Type, Int) { + EXPECT_EQ(Type::Int().kind(), Kind::kInt); + EXPECT_EQ(Type::Int().name(), "int"); + EXPECT_THAT(Type::Int().parameters(), SizeIs(0)); + EXPECT_FALSE(Type::Int().IsNull()); + EXPECT_FALSE(Type::Int().IsDyn()); + EXPECT_FALSE(Type::Int().IsAny()); + EXPECT_FALSE(Type::Int().IsBool()); + EXPECT_TRUE(Type::Int().IsInt()); + EXPECT_FALSE(Type::Int().IsUint()); + EXPECT_FALSE(Type::Int().IsDouble()); + EXPECT_FALSE(Type::Int().IsString()); + EXPECT_FALSE(Type::Int().IsBytes()); + EXPECT_FALSE(Type::Int().IsDuration()); + EXPECT_FALSE(Type::Int().IsTimestamp()); +} + +TEST(Type, Uint) { + EXPECT_EQ(Type::Uint().kind(), Kind::kUint); + EXPECT_EQ(Type::Uint().name(), "uint"); + EXPECT_THAT(Type::Uint().parameters(), SizeIs(0)); + EXPECT_FALSE(Type::Uint().IsNull()); + EXPECT_FALSE(Type::Uint().IsDyn()); + EXPECT_FALSE(Type::Uint().IsAny()); + EXPECT_FALSE(Type::Uint().IsBool()); + EXPECT_FALSE(Type::Uint().IsInt()); + EXPECT_TRUE(Type::Uint().IsUint()); + EXPECT_FALSE(Type::Uint().IsDouble()); + EXPECT_FALSE(Type::Uint().IsString()); + EXPECT_FALSE(Type::Uint().IsBytes()); + EXPECT_FALSE(Type::Uint().IsDuration()); + EXPECT_FALSE(Type::Uint().IsTimestamp()); +} + +TEST(Type, Double) { + EXPECT_EQ(Type::Double().kind(), Kind::kDouble); + EXPECT_EQ(Type::Double().name(), "double"); + EXPECT_THAT(Type::Double().parameters(), SizeIs(0)); + EXPECT_FALSE(Type::Double().IsNull()); + EXPECT_FALSE(Type::Double().IsDyn()); + EXPECT_FALSE(Type::Double().IsAny()); + EXPECT_FALSE(Type::Double().IsBool()); + EXPECT_FALSE(Type::Double().IsInt()); + EXPECT_FALSE(Type::Double().IsUint()); + EXPECT_TRUE(Type::Double().IsDouble()); + EXPECT_FALSE(Type::Double().IsString()); + EXPECT_FALSE(Type::Double().IsBytes()); + EXPECT_FALSE(Type::Double().IsDuration()); + EXPECT_FALSE(Type::Double().IsTimestamp()); +} + +TEST(Type, String) { + EXPECT_EQ(Type::String().kind(), Kind::kString); + EXPECT_EQ(Type::String().name(), "string"); + EXPECT_THAT(Type::String().parameters(), SizeIs(0)); + EXPECT_FALSE(Type::String().IsNull()); + EXPECT_FALSE(Type::String().IsDyn()); + EXPECT_FALSE(Type::String().IsAny()); + EXPECT_FALSE(Type::String().IsBool()); + EXPECT_FALSE(Type::String().IsInt()); + EXPECT_FALSE(Type::String().IsUint()); + EXPECT_FALSE(Type::String().IsDouble()); + EXPECT_TRUE(Type::String().IsString()); + EXPECT_FALSE(Type::String().IsBytes()); + EXPECT_FALSE(Type::String().IsDuration()); + EXPECT_FALSE(Type::String().IsTimestamp()); +} + +TEST(Type, Bytes) { + EXPECT_EQ(Type::Bytes().kind(), Kind::kBytes); + EXPECT_EQ(Type::Bytes().name(), "bytes"); + EXPECT_THAT(Type::Bytes().parameters(), SizeIs(0)); + EXPECT_FALSE(Type::Bytes().IsNull()); + EXPECT_FALSE(Type::Bytes().IsDyn()); + EXPECT_FALSE(Type::Bytes().IsAny()); + EXPECT_FALSE(Type::Bytes().IsBool()); + EXPECT_FALSE(Type::Bytes().IsInt()); + EXPECT_FALSE(Type::Bytes().IsUint()); + EXPECT_FALSE(Type::Bytes().IsDouble()); + EXPECT_FALSE(Type::Bytes().IsString()); + EXPECT_TRUE(Type::Bytes().IsBytes()); + EXPECT_FALSE(Type::Bytes().IsDuration()); + EXPECT_FALSE(Type::Bytes().IsTimestamp()); +} + +TEST(Type, Duration) { + EXPECT_EQ(Type::Duration().kind(), Kind::kDuration); + EXPECT_EQ(Type::Duration().name(), "google.protobuf.Duration"); + EXPECT_THAT(Type::Duration().parameters(), SizeIs(0)); + EXPECT_FALSE(Type::Duration().IsNull()); + EXPECT_FALSE(Type::Duration().IsDyn()); + EXPECT_FALSE(Type::Duration().IsAny()); + EXPECT_FALSE(Type::Duration().IsBool()); + EXPECT_FALSE(Type::Duration().IsInt()); + EXPECT_FALSE(Type::Duration().IsUint()); + EXPECT_FALSE(Type::Duration().IsDouble()); + EXPECT_FALSE(Type::Duration().IsString()); + EXPECT_FALSE(Type::Duration().IsBytes()); + EXPECT_TRUE(Type::Duration().IsDuration()); + EXPECT_FALSE(Type::Duration().IsTimestamp()); +} + +TEST(Type, Timestamp) { + EXPECT_EQ(Type::Timestamp().kind(), Kind::kTimestamp); + EXPECT_EQ(Type::Timestamp().name(), "google.protobuf.Timestamp"); + EXPECT_THAT(Type::Timestamp().parameters(), SizeIs(0)); + EXPECT_FALSE(Type::Timestamp().IsNull()); + EXPECT_FALSE(Type::Timestamp().IsDyn()); + EXPECT_FALSE(Type::Timestamp().IsAny()); + EXPECT_FALSE(Type::Timestamp().IsBool()); + EXPECT_FALSE(Type::Timestamp().IsInt()); + EXPECT_FALSE(Type::Timestamp().IsUint()); + EXPECT_FALSE(Type::Timestamp().IsDouble()); + EXPECT_FALSE(Type::Timestamp().IsString()); + EXPECT_FALSE(Type::Timestamp().IsBytes()); + EXPECT_FALSE(Type::Timestamp().IsDuration()); + EXPECT_TRUE(Type::Timestamp().IsTimestamp()); +} + +TEST(Type, SupportsAbslHash) { + EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly({ + Type::Error(), + Type::Null(), + Type::Dyn(), + Type::Any(), + Type::Bool(), + Type::Int(), + Type::Uint(), + Type::Double(), + Type::String(), + Type::Bytes(), + Type::Duration(), + Type::Timestamp(), + })); +} + +} // namespace +} // namespace cel diff --git a/base/value.cc b/base/value.cc new file mode 100644 index 000000000..e28e20c81 --- /dev/null +++ b/base/value.cc @@ -0,0 +1,789 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/value.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/call_once.h" +#include "absl/base/macros.h" +#include "absl/base/optimization.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/status/status.h" +#include "absl/strings/cord.h" +#include "absl/strings/match.h" +#include "absl/strings/string_view.h" +#include "base/internal/value.h" +#include "internal/reference_counted.h" +#include "internal/status_macros.h" +#include "internal/strings.h" +#include "internal/time.h" + +namespace cel { + +namespace { + +struct StatusPayload final { + std::string key; + absl::Cord value; +}; + +void StatusHashValue(absl::HashState state, const absl::Status& status) { + // absl::Status::operator== compares `raw_code()`, `message()` and the + // payloads. + state = absl::HashState::combine(std::move(state), status.raw_code(), + status.message()); + // In order to determistically hash, we need to put the payloads in sorted + // order. There is no guarantee from `absl::Status` on the order of the + // payloads returned from `absl::Status::ForEachPayload`. + // + // This should be the same inline size as + // `absl::status_internal::StatusPayloads`. + absl::InlinedVector payloads; + status.ForEachPayload([&](absl::string_view key, const absl::Cord& value) { + payloads.push_back(StatusPayload{std::string(key), value}); + }); + std::stable_sort( + payloads.begin(), payloads.end(), + [](const StatusPayload& lhs, const StatusPayload& rhs) -> bool { + return lhs.key < rhs.key; + }); + for (const auto& payload : payloads) { + state = + absl::HashState::combine(std::move(state), payload.key, payload.value); + } +} + +// SimpleValues holds common values that are frequently needed and should not be +// constructed everytime they are required, usually because they would require a +// heap allocation. An example of this is an empty byte string. +struct SimpleValues final { + public: + SimpleValues() = default; + + SimpleValues(const SimpleValues&) = delete; + + SimpleValues(SimpleValues&&) = delete; + + SimpleValues& operator=(const SimpleValues&) = delete; + + SimpleValues& operator=(SimpleValues&&) = delete; + + Value empty_bytes; +}; + +ABSL_CONST_INIT absl::once_flag simple_values_once; +ABSL_CONST_INIT SimpleValues* simple_values = nullptr; + +} // namespace + +Value Value::Error(const absl::Status& status) { + ABSL_ASSERT(!status.ok()); + if (ABSL_PREDICT_FALSE(status.ok())) { + return Value(absl::UnknownError( + "If you are seeing this message the caller attempted to construct an " + "error value from a successful status. Refusing to fail " + "successfully.")); + } + return Value(status); +} + +absl::StatusOr Value::Duration(absl::Duration value) { + CEL_RETURN_IF_ERROR(internal::ValidateDuration(value)); + int64_t seconds = absl::IDivDuration(value, absl::Seconds(1), &value); + int64_t nanoseconds = absl::IDivDuration(value, absl::Nanoseconds(1), &value); + return Value(Kind::kDuration, seconds, + absl::bit_cast(static_cast(nanoseconds))); +} + +absl::StatusOr Value::Timestamp(absl::Time value) { + CEL_RETURN_IF_ERROR(internal::ValidateTimestamp(value)); + absl::Duration duration = value - absl::UnixEpoch(); + int64_t seconds = absl::IDivDuration(duration, absl::Seconds(1), &duration); + int64_t nanoseconds = + absl::IDivDuration(duration, absl::Nanoseconds(1), &duration); + return Value(Kind::kTimestamp, seconds, + absl::bit_cast(static_cast(nanoseconds))); +} + +Value::Value(const Value& other) { + // metadata_ is currently equal to the simple null type. + // content_ is zero initialized. + switch (other.kind()) { + case Kind::kNullType: + // `this` is already the null value, do nothing. + return; + case Kind::kBool: + ABSL_FALLTHROUGH_INTENDED; + case Kind::kInt: + ABSL_FALLTHROUGH_INTENDED; + case Kind::kUint: + ABSL_FALLTHROUGH_INTENDED; + case Kind::kDouble: + ABSL_FALLTHROUGH_INTENDED; + case Kind::kDuration: + ABSL_FALLTHROUGH_INTENDED; + case Kind::kTimestamp: + // `other` is a simple value and simple type. We only need to trivially + // copy metadata_ and content_. + metadata_.CopyFrom(other.metadata_); + content_.construct_trivial_value(other.content_.trivial_value()); + return; + case Kind::kError: + // `other` is an error value and a simple type. We need to trivially copy + // metadata_ and copy construct the error value to content_. + metadata_.CopyFrom(other.metadata_); + content_.construct_error_value(other.content_.error_value()); + return; + case Kind::kBytes: + // `other` is a reffed value and a simple type. We need to trivially copy + // metadata_ and copy construct the reffed value to content_. + metadata_.CopyFrom(other.metadata_); + content_.construct_reffed_value(other.content_.reffed_value()); + return; + default: + // TODO(issues/5): remove after implementing other kinds + std::abort(); + } +} + +Value::Value(Value&& other) { + // metadata_ is currently equal to the simple null type. + // content_ is currently zero initialized. + switch (other.kind()) { + case Kind::kNullType: + // `this` and `other` are already the null value, do nothing. + return; + case Kind::kBool: + ABSL_FALLTHROUGH_INTENDED; + case Kind::kInt: + ABSL_FALLTHROUGH_INTENDED; + case Kind::kUint: + ABSL_FALLTHROUGH_INTENDED; + case Kind::kDouble: + ABSL_FALLTHROUGH_INTENDED; + case Kind::kDuration: + ABSL_FALLTHROUGH_INTENDED; + case Kind::kTimestamp: + // `other` is a simple value and simple type. Trivially copy and then + // clear metadata_ and content_, making `other` equivalent to `Value()` or + // `Value::Null()`. + metadata_.MoveFrom(std::move(other.metadata_)); + content_.construct_trivial_value(other.content_.trivial_value()); + other.content_.destruct_trivial_value(); + break; + case Kind::kError: + // `other` is an error value and simple type. Trivially copy and then + // clear metadata_ and copy construct and then clear content_, making + // `other` equivalent to `Value()` or `Value::Null()`. + metadata_.MoveFrom(std::move(other.metadata_)); + content_.construct_error_value(other.content_.error_value()); + other.content_.destruct_error_value(); + break; + case Kind::kBytes: + // `other` is a reffed value and simple type. Trivially copy and then + // clear metadata_ and trivially move content_, making + // `other` equivalent to `Value()` or `Value::Null()`. + metadata_.MoveFrom(std::move(other.metadata_)); + content_.adopt_reffed_value(other.content_.release_reffed_value()); + break; + default: + // TODO(issues/5): remove after implementing other kinds + std::abort(); + } +} + +Value::~Value() { Destruct(this); } + +Value& Value::operator=(const Value& other) { + if (ABSL_PREDICT_TRUE(this != std::addressof(other))) { + switch (other.kind()) { + case Kind::kNullType: + ABSL_FALLTHROUGH_INTENDED; + case Kind::kBool: + ABSL_FALLTHROUGH_INTENDED; + case Kind::kInt: + ABSL_FALLTHROUGH_INTENDED; + case Kind::kUint: + ABSL_FALLTHROUGH_INTENDED; + case Kind::kDouble: + ABSL_FALLTHROUGH_INTENDED; + case Kind::kDuration: + ABSL_FALLTHROUGH_INTENDED; + case Kind::kTimestamp: + // `this` could be a simple value, an error value, or a reffed value. + // First we destruct resetting `this` to `Value()`. Then we perform the + // equivalent work of the copy constructor. + Destruct(this); + metadata_.CopyFrom(other.metadata_); + content_.construct_trivial_value(other.content_.trivial_value()); + break; + case Kind::kError: + if (kind() == Kind::kError) { + // `this` and `other` are error values. Perform a copy assignment + // which is faster than destructing and copy constructing. + content_.assign_error_value(other.content_.error_value()); + } else { + // `this` could be a simple value or a reffed value. First we destruct + // resetting `this` to `Value()`. Then we perform the equivalent work + // of the copy constructor. + Destruct(this); + content_.construct_error_value(other.content_.error_value()); + } + // Always copy metadata, for forward compatibility in case other bits + // are added. + metadata_.CopyFrom(other.metadata_); + break; + case Kind::kBytes: { + // `this` could be a simple value, an error value, or a reffed value. + // First we destruct resetting `this` to `Value()`. Then we perform the + // equivalent work of the copy constructor. + base_internal::BaseValue* reffed_value = + internal::Ref(other.content_.reffed_value()); + Destruct(this); + metadata_.CopyFrom(other.metadata_); + // Adopt is typically used for moves, but in this case we already + // increment the reference count, so it is equivalent to a move. + content_.adopt_reffed_value(reffed_value); + } break; + default: + // TODO(issues/5): remove after implementing other kinds + std::abort(); + } + } + return *this; +} + +Value& Value::operator=(Value&& other) { + if (ABSL_PREDICT_TRUE(this != std::addressof(other))) { + switch (other.kind()) { + case Kind::kNullType: + ABSL_FALLTHROUGH_INTENDED; + case Kind::kBool: + ABSL_FALLTHROUGH_INTENDED; + case Kind::kInt: + ABSL_FALLTHROUGH_INTENDED; + case Kind::kUint: + ABSL_FALLTHROUGH_INTENDED; + case Kind::kDouble: + ABSL_FALLTHROUGH_INTENDED; + case Kind::kDuration: + ABSL_FALLTHROUGH_INTENDED; + case Kind::kTimestamp: + // `this` could be a simple value, an error value, or a reffed value. + // First we destruct resetting `this` to `Value()`. Then we perform the + // equivalent work of the move constructor. + Destruct(this); + metadata_.MoveFrom(std::move(other.metadata_)); + content_.construct_trivial_value(other.content_.trivial_value()); + other.content_.destruct_trivial_value(); + break; + case Kind::kError: + if (kind() == Kind::kError) { + // `this` and `other` are error values. Perform a copy assignment + // which is faster than destructing and copy constructing. `other` + // will be reset below. + content_.assign_error_value(other.content_.error_value()); + } else { + // `this` could be a simple value or a reffed value. First we destruct + // resetting `this` to `Value()`. Then we perform the equivalent work + // of the copy constructor. + Destruct(this); + content_.construct_error_value(other.content_.error_value()); + } + // Always copy metadata, for forward compatibility in case other bits + // are added. + metadata_.CopyFrom(other.metadata_); + // Reset `other` to `Value()`. + Destruct(std::addressof(other)); + break; + case Kind::kBytes: + // `this` could be a simple value, an error value, or a reffed value. + // First we destruct resetting `this` to `Value()`. Then we perform the + // equivalent work of the move constructor. + Destruct(this); + metadata_.MoveFrom(std::move(other.metadata_)); + content_.adopt_reffed_value(other.content_.release_reffed_value()); + break; + default: + // TODO(issues/5): remove after implementing other kinds + std::abort(); + } + } + return *this; +} + +std::string Value::DebugString() const { + switch (kind()) { + case Kind::kNullType: + return "null"; + case Kind::kBool: + return AsBool() ? "true" : "false"; + case Kind::kInt: + return absl::StrCat(AsInt()); + case Kind::kUint: + return absl::StrCat(AsUint(), "u"); + case Kind::kDouble: { + if (std::isfinite(AsDouble())) { + if (static_cast(static_cast(AsDouble())) != + AsDouble()) { + // The double is not representable as a whole number, so use + // absl::StrCat which will add decimal places. + return absl::StrCat(AsDouble()); + } + // absl::StrCat historically would represent 0.0 as 0, and we want the + // decimal places so ZetaSQL correctly assumes the type as double + // instead of int64_t. + std::string stringified = absl::StrCat(AsDouble()); + if (!absl::StrContains(stringified, '.')) { + absl::StrAppend(&stringified, ".0"); + } else { + // absl::StrCat has a decimal now? Use it directly. + } + return stringified; + } + if (std::isnan(AsDouble())) { + return "nan"; + } + if (std::signbit(AsDouble())) { + return "-infinity"; + } + return "+infinity"; + } + case Kind::kDuration: + return internal::FormatDuration(AsDuration()).value(); + case Kind::kTimestamp: + return internal::FormatTimestamp(AsTimestamp()).value(); + case Kind::kError: + return AsError().ToString(); + case Kind::kBytes: + return content_.reffed_value()->DebugString(); + default: + // TODO(issues/5): remove after implementing other kinds + std::abort(); + } +} + +void Value::InitializeSingletons() { + absl::call_once(simple_values_once, []() { + ABSL_ASSERT(simple_values == nullptr); + simple_values = new SimpleValues(); + simple_values->empty_bytes = Value(Kind::kBytes, new cel::Bytes()); + }); +} + +void Value::Destruct(Value* dest) { + // Perform any deallocations or destructions necessary and reset the state + // of `dest` to `Value()` making it the null value. + switch (dest->kind()) { + case Kind::kNullType: + return; + case Kind::kBool: + ABSL_FALLTHROUGH_INTENDED; + case Kind::kInt: + ABSL_FALLTHROUGH_INTENDED; + case Kind::kUint: + ABSL_FALLTHROUGH_INTENDED; + case Kind::kDouble: + ABSL_FALLTHROUGH_INTENDED; + case Kind::kDuration: + ABSL_FALLTHROUGH_INTENDED; + case Kind::kTimestamp: + dest->content_.destruct_trivial_value(); + break; + case Kind::kError: + dest->content_.destruct_error_value(); + break; + case Kind::kBytes: + dest->content_.destruct_reffed_value(); + break; + default: + // TODO(issues/5): remove after implementing other kinds + std::abort(); + } + dest->metadata_.Reset(); +} + +void Value::HashValue(absl::HashState state) const { + state = absl::HashState::combine(std::move(state), type()); + switch (kind()) { + case Kind::kNullType: + absl::HashState::combine(std::move(state), 0); + return; + case Kind::kBool: + absl::HashState::combine(std::move(state), AsBool()); + return; + case Kind::kInt: + absl::HashState::combine(std::move(state), AsInt()); + return; + case Kind::kUint: + absl::HashState::combine(std::move(state), AsUint()); + return; + case Kind::kDouble: + absl::HashState::combine(std::move(state), AsDouble()); + return; + case Kind::kDuration: + absl::HashState::combine(std::move(state), AsDuration()); + return; + case Kind::kTimestamp: + absl::HashState::combine(std::move(state), AsTimestamp()); + return; + case Kind::kError: + StatusHashValue(std::move(state), AsError()); + return; + case Kind::kBytes: + content_.reffed_value()->HashValue(std::move(state)); + return; + default: + // TODO(issues/5): remove after implementing other kinds + std::abort(); + } +} + +bool Value::Equals(const Value& other) const { + // Comparing types is not enough as type may only compare the type name, + // which could be the same in separate environments but different kinds. So + // we also compare the kinds. + if (kind() != other.kind() || type() != other.type()) { + return false; + } + switch (kind()) { + case Kind::kNullType: + return true; + case Kind::kBool: + return AsBool() == other.AsBool(); + case Kind::kInt: + return AsInt() == other.AsInt(); + case Kind::kUint: + return AsUint() == other.AsUint(); + case Kind::kDouble: + return AsDouble() == other.AsDouble(); + case Kind::kDuration: + return AsDuration() == other.AsDuration(); + case Kind::kTimestamp: + return AsTimestamp() == other.AsTimestamp(); + case Kind::kError: + return AsError() == other.AsError(); + case Kind::kBytes: + return content_.reffed_value()->Equals(other); + default: + // TODO(issues/5): remove after implementing other kinds + std::abort(); + } +} + +void Value::Swap(Value& other) { + // TODO(issues/5): Optimize this after other values are implemented + Value tmp(std::move(other)); + other = std::move(*this); + *this = std::move(tmp); +} + +namespace { + +constexpr absl::string_view ExternalDataToStringView( + const base_internal::ExternalData& external_data) { + return absl::string_view(static_cast(external_data.data), + external_data.size); +} + +struct DebugStringVisitor final { + std::string operator()(const std::string& value) const { + return internal::FormatBytesLiteral(value); + } + + std::string operator()(const absl::Cord& value) const { + absl::string_view flat; + if (value.GetFlat(&flat)) { + return internal::FormatBytesLiteral(flat); + } + return internal::FormatBytesLiteral(value.ToString()); + } + + std::string operator()(const base_internal::ExternalData& value) const { + return internal::FormatBytesLiteral(ExternalDataToStringView(value)); + } +}; + +struct ToCordReleaser final { + void operator()() const { internal::Unref(refcnt); } + + const internal::ReferenceCounted* refcnt; +}; + +struct ToStringVisitor final { + std::string operator()(const std::string& value) const { return value; } + + std::string operator()(const absl::Cord& value) const { + return value.ToString(); + } + + std::string operator()(const base_internal::ExternalData& value) const { + return std::string(static_cast(value.data), value.size); + } +}; + +struct ToCordVisitor final { + const internal::ReferenceCounted* refcnt; + + absl::Cord operator()(const std::string& value) const { + internal::Ref(refcnt); + return absl::MakeCordFromExternal(value, ToCordReleaser{refcnt}); + } + + absl::Cord operator()(const absl::Cord& value) const { return value; } + + absl::Cord operator()(const base_internal::ExternalData& value) const { + internal::Ref(refcnt); + return absl::MakeCordFromExternal(ExternalDataToStringView(value), + ToCordReleaser{refcnt}); + } +}; + +struct SizeVisitor final { + size_t operator()(const std::string& value) const { return value.size(); } + + size_t operator()(const absl::Cord& value) const { return value.size(); } + + size_t operator()(const base_internal::ExternalData& value) const { + return value.size; + } +}; + +struct EmptyVisitor final { + bool operator()(const std::string& value) const { return value.empty(); } + + bool operator()(const absl::Cord& value) const { return value.empty(); } + + bool operator()(const base_internal::ExternalData& value) const { + return value.size == 0; + } +}; + +bool EqualsImpl(absl::string_view lhs, absl::string_view rhs) { + return lhs == rhs; +} + +bool EqualsImpl(absl::string_view lhs, const absl::Cord& rhs) { + return lhs == rhs; +} + +bool EqualsImpl(const absl::Cord& lhs, absl::string_view rhs) { + return lhs == rhs; +} + +bool EqualsImpl(const absl::Cord& lhs, const absl::Cord& rhs) { + return lhs == rhs; +} + +int CompareImpl(absl::string_view lhs, absl::string_view rhs) { + return lhs.compare(rhs); +} + +int CompareImpl(absl::string_view lhs, const absl::Cord& rhs) { + return -rhs.Compare(lhs); +} + +int CompareImpl(const absl::Cord& lhs, absl::string_view rhs) { + return lhs.Compare(rhs); +} + +int CompareImpl(const absl::Cord& lhs, const absl::Cord& rhs) { + return lhs.Compare(rhs); +} + +template +class EqualsVisitor final { + public: + explicit EqualsVisitor(const T& ref) : ref_(ref) {} + + bool operator()(const std::string& value) const { + return EqualsImpl(value, ref_); + } + + bool operator()(const absl::Cord& value) const { + return EqualsImpl(value, ref_); + } + + bool operator()(const base_internal::ExternalData& value) const { + return EqualsImpl(ExternalDataToStringView(value), ref_); + } + + private: + const T& ref_; +}; + +template <> +class EqualsVisitor final { + public: + explicit EqualsVisitor(const Bytes& ref) : ref_(ref) {} + + bool operator()(const std::string& value) const { return ref_.Equals(value); } + + bool operator()(const absl::Cord& value) const { return ref_.Equals(value); } + + bool operator()(const base_internal::ExternalData& value) const { + return ref_.Equals(ExternalDataToStringView(value)); + } + + private: + const Bytes& ref_; +}; + +template +class CompareVisitor final { + public: + explicit CompareVisitor(const T& ref) : ref_(ref) {} + + int operator()(const std::string& value) const { + return CompareImpl(value, ref_); + } + + int operator()(const absl::Cord& value) const { + return CompareImpl(value, ref_); + } + + int operator()(const base_internal::ExternalData& value) const { + return CompareImpl(ExternalDataToStringView(value), ref_); + } + + private: + const T& ref_; +}; + +template <> +class CompareVisitor final { + public: + explicit CompareVisitor(const Bytes& ref) : ref_(ref) {} + + int operator()(const std::string& value) const { return ref_.Compare(value); } + + int operator()(const absl::Cord& value) const { return ref_.Compare(value); } + + int operator()(absl::string_view value) const { return ref_.Compare(value); } + + int operator()(const base_internal::ExternalData& value) const { + return ref_.Compare(ExternalDataToStringView(value)); + } + + private: + const Bytes& ref_; +}; + +class HashValueVisitor final { + public: + explicit HashValueVisitor(absl::HashState state) : state_(std::move(state)) {} + + void operator()(const std::string& value) { + absl::HashState::combine(std::move(state_), value); + } + + void operator()(const absl::Cord& value) { + absl::HashState::combine(std::move(state_), value); + } + + void operator()(const base_internal::ExternalData& value) { + absl::HashState::combine(std::move(state_), + ExternalDataToStringView(value)); + } + + private: + absl::HashState state_; +}; + +} // namespace + +Value Bytes::Empty() { + Value::InitializeSingletons(); + return simple_values->empty_bytes; +} + +Value Bytes::New(std::string value) { + if (value.empty()) { + return Empty(); + } + return Value(Kind::kBytes, new Bytes(std::move(value))); +} + +Value Bytes::New(absl::Cord value) { + if (value.empty()) { + return Empty(); + } + return Value(Kind::kBytes, new Bytes(std::move(value))); +} + +Value Bytes::Concat(const Bytes& lhs, const Bytes& rhs) { + absl::Cord value; + value.Append(lhs.ToCord()); + value.Append(rhs.ToCord()); + return New(std::move(value)); +} + +size_t Bytes::size() const { return absl::visit(SizeVisitor{}, data_); } + +bool Bytes::empty() const { return absl::visit(EmptyVisitor{}, data_); } + +bool Bytes::Equals(absl::string_view bytes) const { + return absl::visit(EqualsVisitor(bytes), data_); +} + +bool Bytes::Equals(const absl::Cord& bytes) const { + return absl::visit(EqualsVisitor(bytes), data_); +} + +bool Bytes::Equals(const Bytes& bytes) const { + return absl::visit(EqualsVisitor(*this), bytes.data_); +} + +int Bytes::Compare(absl::string_view bytes) const { + return absl::visit(CompareVisitor(bytes), data_); +} + +int Bytes::Compare(const absl::Cord& bytes) const { + return absl::visit(CompareVisitor(bytes), data_); +} + +int Bytes::Compare(const Bytes& bytes) const { + return absl::visit(CompareVisitor(*this), bytes.data_); +} + +std::string Bytes::ToString() const { + return absl::visit(ToStringVisitor{}, data_); +} + +absl::Cord Bytes::ToCord() const { + return absl::visit(ToCordVisitor{this}, data_); +} + +std::string Bytes::DebugString() const { + return absl::visit(DebugStringVisitor{}, data_); +} + +bool Bytes::Equals(const Value& value) const { + ABSL_ASSERT(value.IsBytes()); + return absl::visit(EqualsVisitor(*this), value.AsBytes().data_); +} + +void Bytes::HashValue(absl::HashState state) const { + absl::visit(HashValueVisitor(std::move(state)), data_); +} + +} // namespace cel diff --git a/base/value.h b/base/value.h new file mode 100644 index 000000000..5b62ff940 --- /dev/null +++ b/base/value.h @@ -0,0 +1,380 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_VALUE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_VALUE_H_ + +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/casts.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "absl/types/variant.h" +#include "base/internal/value.h" +#include "base/kind.h" +#include "base/type.h" +#include "internal/casts.h" + +namespace cel { + +// A representation of a CEL value that enables reflection and introspection of +// values. +// +// TODO(issues/5): document once derived implementations stabilize +class Value final { + public: + // Returns the null value. + ABSL_ATTRIBUTE_PURE_FUNCTION static Value Null() { return Value(); } + + // Constructs an error value. It is required that `status` is non-OK, + // otherwise behavior is undefined. + static Value Error(const absl::Status& status); + + // Returns a bool value. + static Value Bool(bool value) { return Value(value); } + + // Returns the false bool value. Equivalent to `Value::Bool(false)`. + ABSL_ATTRIBUTE_PURE_FUNCTION static Value False() { return Bool(false); } + + // Returns the true bool value. Equivalent to `Value::Bool(true)`. + ABSL_ATTRIBUTE_PURE_FUNCTION static Value True() { return Bool(true); } + + // Returns an int value. + static Value Int(int64_t value) { return Value(value); } + + // Returns a uint value. + static Value Uint(uint64_t value) { return Value(value); } + + // Returns a double value. + static Value Double(double value) { return Value(value); } + + // Returns a NaN double value. Equivalent to `Value::Double(NAN)`. + ABSL_ATTRIBUTE_PURE_FUNCTION static Value NaN() { + return Double(std::numeric_limits::quiet_NaN()); + } + + // Returns a positive infinity double value. Equivalent to + // `Value::Double(INFINITY)`. + ABSL_ATTRIBUTE_PURE_FUNCTION static Value PositiveInfinity() { + return Double(std::numeric_limits::infinity()); + } + + // Returns a negative infinity double value. Equivalent to + // `Value::Double(-INFINITY)`. + ABSL_ATTRIBUTE_PURE_FUNCTION static Value NegativeInfinity() { + return Double(-std::numeric_limits::infinity()); + } + + // Returns a duration value or a `absl::StatusCode::kInvalidArgument` error if + // the value is not in the valid range. + static absl::StatusOr Duration(absl::Duration value); + + // Returns the zero duration value. Equivalent to + // `Value::Duration(absl::ZeroDuration())`. + ABSL_ATTRIBUTE_PURE_FUNCTION static Value ZeroDuration() { + return Value(Kind::kDuration, 0, 0); + } + + // Returns a timestamp value or a `absl::StatusCode::kInvalidArgument` error + // if the value is not in the valid range. + static absl::StatusOr Timestamp(absl::Time value); + + // Returns the zero timestamp value. Equivalent to + // `Value::Timestamp(absl::UnixEpoch())`. + ABSL_ATTRIBUTE_PURE_FUNCTION static Value UnixEpoch() { + return Value(Kind::kTimestamp, 0, 0); + } + + // Equivalent to `Value::Null()`. + constexpr Value() = default; + + Value(const Value& other); + + Value(Value&& other); + + ~Value(); + + Value& operator=(const Value& other); + + Value& operator=(Value&& other); + + // Returns the type of the value. If you only need the kind, prefer `kind()`. + cel::Type type() const { + return metadata_.simple_tag() + ? cel::Type::Simple(metadata_.kind()) + : cel::Type(internal::Ref(metadata_.base_type())); + } + + // Returns the kind of the value. This is equivalent to `type().kind()` but + // faster in many scenarios. As such it should be preffered when only the kind + // is required. + Kind kind() const { return metadata_.kind(); } + + // True if this is the null value, false otherwise. + bool IsNull() const { return kind() == Kind::kNullType; } + + // True if this is an error value, false otherwise. + bool IsError() const { return kind() == Kind::kError; } + + // True if this is a bool value, false otherwise. + bool IsBool() const { return kind() == Kind::kBool; } + + // True if this is an int value, false otherwise. + bool IsInt() const { return kind() == Kind::kInt; } + + // True if this is a uint value, false otherwise. + bool IsUint() const { return kind() == Kind::kUint; } + + // True if this is a double value, false otherwise. + bool IsDouble() const { return kind() == Kind::kDouble; } + + // True if this is a duration value, false otherwise. + bool IsDuration() const { return kind() == Kind::kDuration; } + + // True if this is a timestamp value, false otherwise. + bool IsTimestamp() const { return kind() == Kind::kTimestamp; } + + // True if this is a bytes value, false otherwise. + bool IsBytes() const { return kind() == Kind::kBytes; } + + // Returns the C++ error value. Requires `kind() == Kind::kError` or behavior + // is undefined. + const absl::Status& AsError() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_ASSERT(IsError()); + return content_.error_value(); + } + + // Returns the C++ bool value. Requires `kind() == Kind::kBool` or behavior is + // undefined. + bool AsBool() const { + ABSL_ASSERT(IsBool()); + return content_.bool_value(); + } + + // Returns the C++ int value. Requires `kind() == Kind::kInt` or behavior is + // undefined. + int64_t AsInt() const { + ABSL_ASSERT(IsInt()); + return content_.int_value(); + } + + // Returns the C++ uint value. Requires `kind() == Kind::kUint` or behavior is + // undefined. + uint64_t AsUint() const { + ABSL_ASSERT(IsUint()); + return content_.uint_value(); + } + + // Returns the C++ double value. Requires `kind() == Kind::kDouble` or + // behavior is undefined. + double AsDouble() const { + ABSL_ASSERT(IsDouble()); + return content_.double_value(); + } + + // Returns the C++ duration value. Requires `kind() == Kind::kDuration` or + // behavior is undefined. + absl::Duration AsDuration() const { + ABSL_ASSERT(IsDuration()); + return absl::Seconds(content_.int_value()) + + absl::Nanoseconds( + absl::bit_cast(metadata_.extended_content())); + } + + // Returns the C++ timestamp value. Requires `kind() == Kind::kTimestamp` or + // behavior is undefined. + absl::Time AsTimestamp() const { + // Timestamp is stored as the duration since Unix Epoch. + ABSL_ASSERT(IsTimestamp()); + return absl::UnixEpoch() + absl::Seconds(content_.int_value()) + + absl::Nanoseconds( + absl::bit_cast(metadata_.extended_content())); + } + + std::string DebugString() const; + + const Bytes& AsBytes() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_ASSERT(IsBytes()); + return internal::down_cast(*content_.reffed_value()); + } + + template + friend H AbslHashValue(H state, const Value& value) { + value.HashValue(absl::HashState::Create(&state)); + return std::move(state); + } + + friend void swap(Value& lhs, Value& rhs) { lhs.Swap(rhs); } + + friend bool operator==(const Value& lhs, const Value& rhs) { + return lhs.Equals(rhs); + } + + friend bool operator!=(const Value& lhs, const Value& rhs) { + return !operator==(lhs, rhs); + } + + private: + friend class Bytes; + + using Metadata = base_internal::ValueMetadata; + using Content = base_internal::ValueContent; + + static void InitializeSingletons(); + + static void Destruct(Value* dest); + + constexpr explicit Value(bool value) + : metadata_(Kind::kBool), content_(value) {} + + constexpr explicit Value(int64_t value) + : metadata_(Kind::kInt), content_(value) {} + + constexpr explicit Value(uint64_t value) + : metadata_(Kind::kUint), content_(value) {} + + constexpr explicit Value(double value) + : metadata_(Kind::kDouble), content_(value) {} + + explicit Value(const absl::Status& status) + : metadata_(Kind::kError), content_(status) {} + + constexpr Value(Kind kind, base_internal::BaseValue* base_value) + : metadata_(kind), content_(base_value) {} + + constexpr Value(Kind kind, int64_t content, uint32_t extended_content) + : metadata_(kind, extended_content), content_(content) {} + + bool Equals(const Value& other) const; + + void HashValue(absl::HashState state) const; + + void Swap(Value& other); + + Metadata metadata_; + Content content_; +}; + +// A CEL bytes value specific interface that can be accessed via +// `cel::Value::AsBytes`. It acts as a facade over various native +// representations and provides efficient implementations of CEL builtin +// functions. +class Bytes final : public base_internal::BaseValue { + public: + // Returns a bytes value which has a size of 0 and is empty. + ABSL_ATTRIBUTE_PURE_FUNCTION static Value Empty(); + + // Returns a bytes value with `value` as its contents. + static Value New(std::string value); + + // Returns a bytes value with a copy of `value` as its contents. + static Value New(absl::string_view value) { + return New(std::string(value.data(), value.size())); + } + + // Returns a bytes value with a copy of `value` as its contents. + // + // This is needed for `Value::Bytes("foo")` to be an unambiguous function + // call. + static Value New(const char* value) { + ABSL_ASSERT(value != nullptr); + return New(absl::string_view(value)); + } + + // Returns a bytes value with `value` as its contents. + static Value New(absl::Cord value); + + // Returns a bytes value with `value` as its contents. Unlike `New()` this + // does not copy `value`, instead it expects the contents pointed to by + // `value` to live as long as the returned instance. `releaser` is used to + // notify the caller when the contents pointed to by `value` are no longer + // required. + template + static std::enable_if_t, Value> Wrap( + absl::string_view value, Releaser&& releaser); + + static Value Concat(const Bytes& lhs, const Bytes& rhs); + + size_t size() const; + + bool empty() const; + + bool Equals(absl::string_view bytes) const; + + bool Equals(const absl::Cord& bytes) const; + + bool Equals(const Bytes& bytes) const; + + int Compare(absl::string_view bytes) const; + + int Compare(const absl::Cord& bytes) const; + + int Compare(const Bytes& bytes) const; + + std::string ToString() const; + + absl::Cord ToCord() const; + + std::string DebugString() const override; + + protected: + bool Equals(const Value& value) const override; + + void HashValue(absl::HashState state) const override; + + private: + friend class Value; + + Bytes() : Bytes(std::string()) {} + + explicit Bytes(std::string value) + : base_internal::BaseValue(), + data_(absl::in_place_index<0>, std::move(value)) {} + + explicit Bytes(absl::Cord value) + : base_internal::BaseValue(), + data_(absl::in_place_index<1>, std::move(value)) {} + + explicit Bytes(base_internal::ExternalData value) + : base_internal::BaseValue(), + data_(absl::in_place_index<2>, std::move(value)) {} + + absl::variant data_; +}; + +template +std::enable_if_t, Value> Bytes::Wrap( + absl::string_view value, Releaser&& releaser) { + if (value.empty()) { + std::forward(releaser)(); + return Empty(); + } + return Value(Kind::kBytes, + new Bytes(base_internal::ExternalData( + value.data(), value.size(), + std::make_unique( + std::forward(releaser))))); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_VALUE_H_ diff --git a/base/value_test.cc b/base/value_test.cc new file mode 100644 index 000000000..f9eae5723 --- /dev/null +++ b/base/value_test.cc @@ -0,0 +1,749 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/value.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/hash/hash_testing.h" +#include "absl/status/status.h" +#include "absl/time/time.h" +#include "base/type.h" +#include "internal/strings.h" +#include "internal/testing.h" +#include "internal/time.h" + +namespace cel { +namespace { + +using cel::internal::StatusIs; + +template +constexpr void IS_INITIALIZED(T&) {} + +TEST(Value, TypeTraits) { + EXPECT_TRUE(std::is_default_constructible_v); + EXPECT_TRUE(std::is_copy_constructible_v); + EXPECT_TRUE(std::is_move_constructible_v); + EXPECT_TRUE(std::is_copy_assignable_v); + EXPECT_TRUE(std::is_move_assignable_v); + EXPECT_TRUE(std::is_swappable_v); +} + +TEST(Value, DefaultConstructor) { + Value value; + EXPECT_EQ(value, Value::Null()); +} + +struct ConstructionAssignmentTestCase final { + std::string name; + std::function default_value; +}; + +using ConstructionAssignmentTest = + testing::TestWithParam; + +TEST_P(ConstructionAssignmentTest, CopyConstructor) { + const auto& test_case = GetParam(); + Value from(test_case.default_value()); + Value to(from); + IS_INITIALIZED(to); + EXPECT_EQ(to, test_case.default_value()); +} + +TEST_P(ConstructionAssignmentTest, MoveConstructor) { + const auto& test_case = GetParam(); + Value from(test_case.default_value()); + Value to(std::move(from)); + IS_INITIALIZED(from); + EXPECT_EQ(from, Value::Null()); + EXPECT_EQ(to, test_case.default_value()); +} + +TEST_P(ConstructionAssignmentTest, CopyAssignment) { + const auto& test_case = GetParam(); + Value from(test_case.default_value()); + Value to; + to = from; + EXPECT_EQ(to, from); +} + +TEST_P(ConstructionAssignmentTest, MoveAssignment) { + const auto& test_case = GetParam(); + Value from(test_case.default_value()); + Value to; + to = std::move(from); + IS_INITIALIZED(from); + EXPECT_EQ(from, Value::Null()); + EXPECT_EQ(to, test_case.default_value()); +} + +INSTANTIATE_TEST_SUITE_P( + ConstructionAssignmentTest, ConstructionAssignmentTest, + testing::ValuesIn({ + {"Null", Value::Null}, + {"Bool", Value::False}, + {"Int", []() { return Value::Int(0); }}, + {"Uint", []() { return Value::Uint(0); }}, + {"Double", []() { return Value::Double(0.0); }}, + {"Duration", []() { return Value::ZeroDuration(); }}, + {"Timestamp", []() { return Value::UnixEpoch(); }}, + {"Error", []() { return Value::Error(absl::CancelledError()); }}, + {"Bytes", Bytes::Empty}, + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +TEST(Value, Swap) { + Value lhs = Value::Int(0); + Value rhs = Value::Uint(0); + std::swap(lhs, rhs); + EXPECT_EQ(lhs, Value::Uint(0)); + EXPECT_EQ(rhs, Value::Int(0)); +} + +TEST(Value, NaN) { EXPECT_TRUE(std::isnan(Value::NaN().AsDouble())); } + +TEST(Value, PositiveInfinity) { + EXPECT_TRUE(std::isinf(Value::PositiveInfinity().AsDouble())); + EXPECT_FALSE(std::signbit(Value::PositiveInfinity().AsDouble())); +} + +TEST(Value, NegativeInfinity) { + EXPECT_TRUE(std::isinf(Value::NegativeInfinity().AsDouble())); + EXPECT_TRUE(std::signbit(Value::NegativeInfinity().AsDouble())); +} + +TEST(Value, ZeroDuration) { + EXPECT_EQ(Value::ZeroDuration().AsDuration(), absl::ZeroDuration()); +} + +TEST(Value, UnixEpoch) { + EXPECT_EQ(Value::UnixEpoch().AsTimestamp(), absl::UnixEpoch()); +} + +TEST(Null, DebugString) { EXPECT_EQ(Value::Null().DebugString(), "null"); } + +TEST(Bool, DebugString) { + EXPECT_EQ(Value::False().DebugString(), "false"); + EXPECT_EQ(Value::True().DebugString(), "true"); +} + +TEST(Int, DebugString) { + EXPECT_EQ(Value::Int(-1).DebugString(), "-1"); + EXPECT_EQ(Value::Int(0).DebugString(), "0"); + EXPECT_EQ(Value::Int(1).DebugString(), "1"); +} + +TEST(Uint, DebugString) { + EXPECT_EQ(Value::Uint(0).DebugString(), "0u"); + EXPECT_EQ(Value::Uint(1).DebugString(), "1u"); +} + +TEST(Double, DebugString) { + EXPECT_EQ(Value::Double(-1.0).DebugString(), "-1.0"); + EXPECT_EQ(Value::Double(0.0).DebugString(), "0.0"); + EXPECT_EQ(Value::Double(1.0).DebugString(), "1.0"); + EXPECT_EQ(Value::Double(-1.1).DebugString(), "-1.1"); + EXPECT_EQ(Value::Double(0.1).DebugString(), "0.1"); + EXPECT_EQ(Value::Double(1.1).DebugString(), "1.1"); + + EXPECT_EQ(Value::NaN().DebugString(), "nan"); + EXPECT_EQ(Value::PositiveInfinity().DebugString(), "+infinity"); + EXPECT_EQ(Value::NegativeInfinity().DebugString(), "-infinity"); +} + +TEST(Duration, DebugString) { + EXPECT_EQ(Value::ZeroDuration().DebugString(), + internal::FormatDuration(absl::ZeroDuration()).value()); +} + +TEST(Timestamp, DebugString) { + EXPECT_EQ(Value::UnixEpoch().DebugString(), + internal::FormatTimestamp(absl::UnixEpoch()).value()); +} + +// The below tests could be made parameterized but doing so requires the +// extension for struct member initiation by name for it to be worth it. That +// feature is not available in C++17. + +TEST(Value, Error) { + Value error_value = Value::Error(absl::CancelledError()); + EXPECT_TRUE(error_value.IsError()); + EXPECT_EQ(error_value, error_value); + EXPECT_EQ(error_value, Value::Error(absl::CancelledError())); + EXPECT_EQ(error_value.AsError(), absl::CancelledError()); +} + +TEST(Value, Bool) { + Value false_value = Value::False(); + EXPECT_TRUE(false_value.IsBool()); + EXPECT_EQ(false_value, false_value); + EXPECT_EQ(false_value, Value::Bool(false)); + EXPECT_EQ(false_value.kind(), Kind::kBool); + EXPECT_EQ(false_value.type(), Type::Bool()); + EXPECT_FALSE(false_value.AsBool()); + + Value true_value = Value::True(); + EXPECT_TRUE(true_value.IsBool()); + EXPECT_EQ(true_value, true_value); + EXPECT_EQ(true_value, Value::Bool(true)); + EXPECT_EQ(true_value.kind(), Kind::kBool); + EXPECT_EQ(true_value.type(), Type::Bool()); + EXPECT_TRUE(true_value.AsBool()); + + EXPECT_NE(false_value, true_value); + EXPECT_NE(true_value, false_value); +} + +TEST(Value, Int) { + Value zero_value = Value::Int(0); + EXPECT_TRUE(zero_value.IsInt()); + EXPECT_EQ(zero_value, zero_value); + EXPECT_EQ(zero_value, Value::Int(0)); + EXPECT_EQ(zero_value.kind(), Kind::kInt); + EXPECT_EQ(zero_value.type(), Type::Int()); + EXPECT_EQ(zero_value.AsInt(), 0); + + Value one_value = Value::Int(1); + EXPECT_TRUE(one_value.IsInt()); + EXPECT_EQ(one_value, one_value); + EXPECT_EQ(one_value, Value::Int(1)); + EXPECT_EQ(one_value.kind(), Kind::kInt); + EXPECT_EQ(one_value.type(), Type::Int()); + EXPECT_EQ(one_value.AsInt(), 1); + + EXPECT_NE(zero_value, one_value); + EXPECT_NE(one_value, zero_value); +} + +TEST(Value, Uint) { + Value zero_value = Value::Uint(0); + EXPECT_TRUE(zero_value.IsUint()); + EXPECT_EQ(zero_value, zero_value); + EXPECT_EQ(zero_value, Value::Uint(0)); + EXPECT_EQ(zero_value.kind(), Kind::kUint); + EXPECT_EQ(zero_value.type(), Type::Uint()); + EXPECT_EQ(zero_value.AsUint(), 0); + + Value one_value = Value::Uint(1); + EXPECT_TRUE(one_value.IsUint()); + EXPECT_EQ(one_value, one_value); + EXPECT_EQ(one_value, Value::Uint(1)); + EXPECT_EQ(one_value.kind(), Kind::kUint); + EXPECT_EQ(one_value.type(), Type::Uint()); + EXPECT_EQ(one_value.AsUint(), 1); + + EXPECT_NE(zero_value, one_value); + EXPECT_NE(one_value, zero_value); +} + +TEST(Value, Double) { + Value zero_value = Value::Double(0.0); + EXPECT_TRUE(zero_value.IsDouble()); + EXPECT_EQ(zero_value, zero_value); + EXPECT_EQ(zero_value, Value::Double(0.0)); + EXPECT_EQ(zero_value.kind(), Kind::kDouble); + EXPECT_EQ(zero_value.type(), Type::Double()); + EXPECT_EQ(zero_value.AsDouble(), 0.0); + + Value one_value = Value::Double(1.0); + EXPECT_TRUE(one_value.IsDouble()); + EXPECT_EQ(one_value, one_value); + EXPECT_EQ(one_value, Value::Double(1.0)); + EXPECT_EQ(one_value.kind(), Kind::kDouble); + EXPECT_EQ(one_value.type(), Type::Double()); + EXPECT_EQ(one_value.AsDouble(), 1.0); + + EXPECT_NE(zero_value, one_value); + EXPECT_NE(one_value, zero_value); +} + +TEST(Value, Duration) { + Value zero_value = Value::ZeroDuration(); + EXPECT_TRUE(zero_value.IsDuration()); + EXPECT_EQ(zero_value, zero_value); + EXPECT_EQ(zero_value, Value::ZeroDuration()); + EXPECT_EQ(zero_value.kind(), Kind::kDuration); + EXPECT_EQ(zero_value.type(), Type::Duration()); + EXPECT_EQ(zero_value.AsDuration(), absl::ZeroDuration()); + + ASSERT_OK_AND_ASSIGN(Value one_value, Value::Duration(absl::ZeroDuration() + + absl::Nanoseconds(1))); + EXPECT_TRUE(one_value.IsDuration()); + EXPECT_EQ(one_value, one_value); + EXPECT_EQ(one_value.kind(), Kind::kDuration); + EXPECT_EQ(one_value.type(), Type::Duration()); + EXPECT_EQ(one_value.AsDuration(), + absl::ZeroDuration() + absl::Nanoseconds(1)); + + EXPECT_NE(zero_value, one_value); + EXPECT_NE(one_value, zero_value); + + EXPECT_THAT(Value::Duration(absl::InfiniteDuration()), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(Value, Timestamp) { + Value zero_value = Value::UnixEpoch(); + EXPECT_TRUE(zero_value.IsTimestamp()); + EXPECT_EQ(zero_value, zero_value); + EXPECT_EQ(zero_value, Value::UnixEpoch()); + EXPECT_EQ(zero_value.kind(), Kind::kTimestamp); + EXPECT_EQ(zero_value.type(), Type::Timestamp()); + EXPECT_EQ(zero_value.AsTimestamp(), absl::UnixEpoch()); + + ASSERT_OK_AND_ASSIGN(Value one_value, Value::Timestamp(absl::UnixEpoch() + + absl::Nanoseconds(1))); + EXPECT_TRUE(one_value.IsTimestamp()); + EXPECT_EQ(one_value, one_value); + EXPECT_EQ(one_value.kind(), Kind::kTimestamp); + EXPECT_EQ(one_value.type(), Type::Timestamp()); + EXPECT_EQ(one_value.AsTimestamp(), absl::UnixEpoch() + absl::Nanoseconds(1)); + + EXPECT_NE(zero_value, one_value); + EXPECT_NE(one_value, zero_value); + + EXPECT_THAT(Value::Timestamp(absl::InfiniteFuture()), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(Value, BytesFromString) { + Value zero_value = Bytes::New(std::string("0")); + EXPECT_TRUE(zero_value.IsBytes()); + EXPECT_EQ(zero_value, zero_value); + EXPECT_EQ(zero_value, Bytes::New(std::string("0"))); + EXPECT_EQ(zero_value.kind(), Kind::kBytes); + EXPECT_EQ(zero_value.type(), Type::Bytes()); + EXPECT_EQ(zero_value.AsBytes().ToString(), "0"); + + Value one_value = Bytes::New(std::string("1")); + EXPECT_TRUE(one_value.IsBytes()); + EXPECT_EQ(one_value, one_value); + EXPECT_EQ(one_value, Bytes::New(std::string("1"))); + EXPECT_EQ(one_value.kind(), Kind::kBytes); + EXPECT_EQ(one_value.type(), Type::Bytes()); + EXPECT_EQ(one_value.AsBytes().ToString(), "1"); + + EXPECT_NE(zero_value, one_value); + EXPECT_NE(one_value, zero_value); +} + +TEST(Value, BytesFromStringView) { + Value zero_value = Bytes::New(absl::string_view("0")); + EXPECT_TRUE(zero_value.IsBytes()); + EXPECT_EQ(zero_value, zero_value); + EXPECT_EQ(zero_value, Bytes::New(absl::string_view("0"))); + EXPECT_EQ(zero_value.kind(), Kind::kBytes); + EXPECT_EQ(zero_value.type(), Type::Bytes()); + EXPECT_EQ(zero_value.AsBytes().ToString(), "0"); + + Value one_value = Bytes::New(absl::string_view("1")); + EXPECT_TRUE(one_value.IsBytes()); + EXPECT_EQ(one_value, one_value); + EXPECT_EQ(one_value, Bytes::New(absl::string_view("1"))); + EXPECT_EQ(one_value.kind(), Kind::kBytes); + EXPECT_EQ(one_value.type(), Type::Bytes()); + EXPECT_EQ(one_value.AsBytes().ToString(), "1"); + + EXPECT_NE(zero_value, one_value); + EXPECT_NE(one_value, zero_value); +} + +TEST(Value, BytesFromCord) { + Value zero_value = Bytes::New(absl::Cord("0")); + EXPECT_TRUE(zero_value.IsBytes()); + EXPECT_EQ(zero_value, zero_value); + EXPECT_EQ(zero_value, Bytes::New(absl::Cord("0"))); + EXPECT_EQ(zero_value.kind(), Kind::kBytes); + EXPECT_EQ(zero_value.type(), Type::Bytes()); + EXPECT_EQ(zero_value.AsBytes().ToCord(), "0"); + + Value one_value = Bytes::New(absl::Cord("1")); + EXPECT_TRUE(one_value.IsBytes()); + EXPECT_EQ(one_value, one_value); + EXPECT_EQ(one_value, Bytes::New(absl::Cord("1"))); + EXPECT_EQ(one_value.kind(), Kind::kBytes); + EXPECT_EQ(one_value.type(), Type::Bytes()); + EXPECT_EQ(one_value.AsBytes().ToCord(), "1"); + + EXPECT_NE(zero_value, one_value); + EXPECT_NE(one_value, zero_value); +} + +TEST(Value, BytesFromLiteral) { + Value zero_value = Bytes::New("0"); + EXPECT_TRUE(zero_value.IsBytes()); + EXPECT_EQ(zero_value, zero_value); + EXPECT_EQ(zero_value, Bytes::New("0")); + EXPECT_EQ(zero_value.kind(), Kind::kBytes); + EXPECT_EQ(zero_value.type(), Type::Bytes()); + EXPECT_EQ(zero_value.AsBytes().ToString(), "0"); + + Value one_value = Bytes::New("1"); + EXPECT_TRUE(one_value.IsBytes()); + EXPECT_EQ(one_value, one_value); + EXPECT_EQ(one_value, Bytes::New("1")); + EXPECT_EQ(one_value.kind(), Kind::kBytes); + EXPECT_EQ(one_value.type(), Type::Bytes()); + EXPECT_EQ(one_value.AsBytes().ToString(), "1"); + + EXPECT_NE(zero_value, one_value); + EXPECT_NE(one_value, zero_value); +} + +Value MakeStringBytes(absl::string_view value) { return Bytes::New(value); } + +Value MakeCordBytes(absl::string_view value) { + return Bytes::New(absl::Cord(value)); +} + +Value MakeWrappedBytes(absl::string_view value) { + return Bytes::Wrap(value, []() {}); +} + +struct BytesConcatTestCase final { + std::string lhs; + std::string rhs; +}; + +using BytesConcatTest = testing::TestWithParam; + +TEST_P(BytesConcatTest, Concat) { + const BytesConcatTestCase& test_case = GetParam(); + EXPECT_TRUE(Bytes::Concat(MakeStringBytes(test_case.lhs).AsBytes(), + MakeStringBytes(test_case.rhs).AsBytes()) + .AsBytes() + .Equals(test_case.lhs + test_case.rhs)); + EXPECT_TRUE(Bytes::Concat(MakeStringBytes(test_case.lhs).AsBytes(), + MakeCordBytes(test_case.rhs).AsBytes()) + .AsBytes() + .Equals(test_case.lhs + test_case.rhs)); + EXPECT_TRUE(Bytes::Concat(MakeStringBytes(test_case.lhs).AsBytes(), + MakeWrappedBytes(test_case.rhs).AsBytes()) + .AsBytes() + .Equals(test_case.lhs + test_case.rhs)); + EXPECT_TRUE(Bytes::Concat(MakeCordBytes(test_case.lhs).AsBytes(), + MakeStringBytes(test_case.rhs).AsBytes()) + .AsBytes() + .Equals(test_case.lhs + test_case.rhs)); + EXPECT_TRUE(Bytes::Concat(MakeCordBytes(test_case.lhs).AsBytes(), + MakeWrappedBytes(test_case.rhs).AsBytes()) + .AsBytes() + .Equals(test_case.lhs + test_case.rhs)); + EXPECT_TRUE(Bytes::Concat(MakeCordBytes(test_case.lhs).AsBytes(), + MakeCordBytes(test_case.rhs).AsBytes()) + .AsBytes() + .Equals(test_case.lhs + test_case.rhs)); + EXPECT_TRUE(Bytes::Concat(MakeWrappedBytes(test_case.lhs).AsBytes(), + MakeStringBytes(test_case.rhs).AsBytes()) + .AsBytes() + .Equals(test_case.lhs + test_case.rhs)); + EXPECT_TRUE(Bytes::Concat(MakeWrappedBytes(test_case.lhs).AsBytes(), + MakeCordBytes(test_case.rhs).AsBytes()) + .AsBytes() + .Equals(test_case.lhs + test_case.rhs)); + EXPECT_TRUE(Bytes::Concat(MakeWrappedBytes(test_case.lhs).AsBytes(), + MakeWrappedBytes(test_case.rhs).AsBytes()) + .AsBytes() + .Equals(test_case.lhs + test_case.rhs)); +} + +INSTANTIATE_TEST_SUITE_P(BytesConcatTest, BytesConcatTest, + testing::ValuesIn({ + {"", ""}, + {"", std::string("\0", 1)}, + {std::string("\0", 1), ""}, + {std::string("\0", 1), std::string("\0", 1)}, + {"", "foo"}, + {"foo", ""}, + {"foo", "foo"}, + {"bar", "foo"}, + {"foo", "bar"}, + {"bar", "bar"}, + })); + +struct BytesSizeTestCase final { + std::string data; + size_t size; +}; + +using BytesSizeTest = testing::TestWithParam; + +TEST_P(BytesSizeTest, Size) { + const BytesSizeTestCase& test_case = GetParam(); + EXPECT_EQ(MakeStringBytes(test_case.data).AsBytes().size(), test_case.size); + EXPECT_EQ(MakeCordBytes(test_case.data).AsBytes().size(), test_case.size); + EXPECT_EQ(MakeWrappedBytes(test_case.data).AsBytes().size(), test_case.size); +} + +INSTANTIATE_TEST_SUITE_P(BytesSizeTest, BytesSizeTest, + testing::ValuesIn({ + {"", 0}, + {"1", 1}, + {"foo", 3}, + {"\xef\xbf\xbd", 3}, + })); + +struct BytesEmptyTestCase final { + std::string data; + bool empty; +}; + +using BytesEmptyTest = testing::TestWithParam; + +TEST_P(BytesEmptyTest, Empty) { + const BytesEmptyTestCase& test_case = GetParam(); + EXPECT_EQ(MakeStringBytes(test_case.data).AsBytes().empty(), test_case.empty); + EXPECT_EQ(MakeCordBytes(test_case.data).AsBytes().empty(), test_case.empty); + EXPECT_EQ(MakeWrappedBytes(test_case.data).AsBytes().empty(), + test_case.empty); +} + +INSTANTIATE_TEST_SUITE_P(BytesEmptyTest, BytesEmptyTest, + testing::ValuesIn({ + {"", true}, + {std::string("\0", 1), false}, + {"1", false}, + })); + +struct BytesEqualsTestCase final { + std::string lhs; + std::string rhs; + bool equals; +}; + +using BytesEqualsTest = testing::TestWithParam; + +TEST_P(BytesEqualsTest, Equals) { + const BytesEqualsTestCase& test_case = GetParam(); + EXPECT_EQ(MakeStringBytes(test_case.lhs) + .AsBytes() + .Equals(MakeStringBytes(test_case.rhs).AsBytes()), + test_case.equals); + EXPECT_EQ(MakeStringBytes(test_case.lhs) + .AsBytes() + .Equals(MakeCordBytes(test_case.rhs).AsBytes()), + test_case.equals); + EXPECT_EQ(MakeStringBytes(test_case.lhs) + .AsBytes() + .Equals(MakeWrappedBytes(test_case.rhs).AsBytes()), + test_case.equals); + EXPECT_EQ(MakeCordBytes(test_case.lhs) + .AsBytes() + .Equals(MakeStringBytes(test_case.rhs).AsBytes()), + test_case.equals); + EXPECT_EQ(MakeCordBytes(test_case.lhs) + .AsBytes() + .Equals(MakeWrappedBytes(test_case.rhs).AsBytes()), + test_case.equals); + EXPECT_EQ(MakeCordBytes(test_case.lhs) + .AsBytes() + .Equals(MakeCordBytes(test_case.rhs).AsBytes()), + test_case.equals); + EXPECT_EQ(MakeWrappedBytes(test_case.lhs) + .AsBytes() + .Equals(MakeStringBytes(test_case.rhs).AsBytes()), + test_case.equals); + EXPECT_EQ(MakeWrappedBytes(test_case.lhs) + .AsBytes() + .Equals(MakeCordBytes(test_case.rhs).AsBytes()), + test_case.equals); + EXPECT_EQ(MakeWrappedBytes(test_case.lhs) + .AsBytes() + .Equals(MakeWrappedBytes(test_case.rhs).AsBytes()), + test_case.equals); +} + +INSTANTIATE_TEST_SUITE_P(BytesEqualsTest, BytesEqualsTest, + testing::ValuesIn({ + {"", "", true}, + {"", std::string("\0", 1), false}, + {std::string("\0", 1), "", false}, + {std::string("\0", 1), std::string("\0", 1), true}, + {"", "foo", false}, + {"foo", "", false}, + {"foo", "foo", true}, + {"bar", "foo", false}, + {"foo", "bar", false}, + {"bar", "bar", true}, + })); + +struct BytesCompareTestCase final { + std::string lhs; + std::string rhs; + int compare; +}; + +using BytesCompareTest = testing::TestWithParam; + +int NormalizeCompareResult(int compare) { return std::clamp(compare, -1, 1); } + +TEST_P(BytesCompareTest, Equals) { + const BytesCompareTestCase& test_case = GetParam(); + EXPECT_EQ(NormalizeCompareResult( + MakeStringBytes(test_case.lhs) + .AsBytes() + .Compare(MakeStringBytes(test_case.rhs).AsBytes())), + test_case.compare); + EXPECT_EQ(NormalizeCompareResult( + MakeStringBytes(test_case.lhs) + .AsBytes() + .Compare(MakeCordBytes(test_case.rhs).AsBytes())), + test_case.compare); + EXPECT_EQ(NormalizeCompareResult( + MakeStringBytes(test_case.lhs) + .AsBytes() + .Compare(MakeWrappedBytes(test_case.rhs).AsBytes())), + test_case.compare); + EXPECT_EQ(NormalizeCompareResult( + MakeCordBytes(test_case.lhs) + .AsBytes() + .Compare(MakeStringBytes(test_case.rhs).AsBytes())), + test_case.compare); + EXPECT_EQ(NormalizeCompareResult( + MakeCordBytes(test_case.lhs) + .AsBytes() + .Compare(MakeWrappedBytes(test_case.rhs).AsBytes())), + test_case.compare); + EXPECT_EQ(NormalizeCompareResult( + MakeCordBytes(test_case.lhs) + .AsBytes() + .Compare(MakeCordBytes(test_case.rhs).AsBytes())), + test_case.compare); + EXPECT_EQ(NormalizeCompareResult( + MakeWrappedBytes(test_case.lhs) + .AsBytes() + .Compare(MakeStringBytes(test_case.rhs).AsBytes())), + test_case.compare); + EXPECT_EQ(NormalizeCompareResult( + MakeWrappedBytes(test_case.lhs) + .AsBytes() + .Compare(MakeCordBytes(test_case.rhs).AsBytes())), + test_case.compare); + EXPECT_EQ(NormalizeCompareResult( + MakeWrappedBytes(test_case.lhs) + .AsBytes() + .Compare(MakeWrappedBytes(test_case.rhs).AsBytes())), + test_case.compare); +} + +INSTANTIATE_TEST_SUITE_P(BytesCompareTest, BytesCompareTest, + testing::ValuesIn({ + {"", "", 0}, + {"", std::string("\0", 1), -1}, + {std::string("\0", 1), "", 1}, + {std::string("\0", 1), std::string("\0", 1), 0}, + {"", "foo", -1}, + {"foo", "", 1}, + {"foo", "foo", 0}, + {"bar", "foo", -1}, + {"foo", "bar", 1}, + {"bar", "bar", 0}, + })); + +struct BytesDebugStringTestCase final { + std::string data; +}; + +using BytesDebugStringTest = testing::TestWithParam; + +TEST_P(BytesDebugStringTest, ToCord) { + const BytesDebugStringTestCase& test_case = GetParam(); + EXPECT_EQ(MakeStringBytes(test_case.data).DebugString(), + internal::FormatBytesLiteral(test_case.data)); + EXPECT_EQ(MakeCordBytes(test_case.data).DebugString(), + internal::FormatBytesLiteral(test_case.data)); + EXPECT_EQ(MakeWrappedBytes(test_case.data).DebugString(), + internal::FormatBytesLiteral(test_case.data)); +} + +INSTANTIATE_TEST_SUITE_P(BytesDebugStringTest, BytesDebugStringTest, + testing::ValuesIn({ + {""}, + {"1"}, + {"foo"}, + {"\xef\xbf\xbd"}, + })); + +struct BytesToStringTestCase final { + std::string data; +}; + +using BytesToStringTest = testing::TestWithParam; + +TEST_P(BytesToStringTest, ToString) { + const BytesToStringTestCase& test_case = GetParam(); + EXPECT_EQ(MakeStringBytes(test_case.data).AsBytes().ToString(), + test_case.data); + EXPECT_EQ(MakeCordBytes(test_case.data).AsBytes().ToString(), test_case.data); + EXPECT_EQ(MakeWrappedBytes(test_case.data).AsBytes().ToString(), + test_case.data); +} + +INSTANTIATE_TEST_SUITE_P(BytesToStringTest, BytesToStringTest, + testing::ValuesIn({ + {""}, + {"1"}, + {"foo"}, + {"\xef\xbf\xbd"}, + })); + +struct BytesToCordTestCase final { + std::string data; +}; + +using BytesToCordTest = testing::TestWithParam; + +TEST_P(BytesToCordTest, ToCord) { + const BytesToCordTestCase& test_case = GetParam(); + EXPECT_EQ(MakeStringBytes(test_case.data).AsBytes().ToCord(), test_case.data); + EXPECT_EQ(MakeCordBytes(test_case.data).AsBytes().ToCord(), test_case.data); + EXPECT_EQ(MakeWrappedBytes(test_case.data).AsBytes().ToCord(), + test_case.data); +} + +INSTANTIATE_TEST_SUITE_P(BytesToCordTest, BytesToCordTest, + testing::ValuesIn({ + {""}, + {"1"}, + {"foo"}, + {"\xef\xbf\xbd"}, + })); + +TEST(Value, SupportsAbslHash) { + EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly({ + Value::Null(), + Value::Error(absl::CancelledError()), + Value::Bool(false), + Value::Int(0), + Value::Uint(0), + Value::Double(0.0), + Value::ZeroDuration(), + Value::UnixEpoch(), + Bytes::Empty(), + Bytes::New("foo"), + Bytes::New(absl::Cord("bar")), + Bytes::Wrap("baz", []() {}), + })); +} + +} // namespace +} // namespace cel diff --git a/bazel/BUILD b/bazel/BUILD index f95444438..ffd0fb0cd 100644 --- a/bazel/BUILD +++ b/bazel/BUILD @@ -1,9 +1 @@ package(default_visibility = ["//visibility:public"]) - -load("@rules_java//java:defs.bzl", "java_binary") - -java_binary( - name = "antlr4_tool", - runtime_deps = ["@antlr4_jar//jar"], - main_class = "org.antlr.v4.Tool", -) diff --git a/bazel/antlr.bzl b/bazel/antlr.bzl index def928b39..ea5520582 100644 --- a/bazel/antlr.bzl +++ b/bazel/antlr.bzl @@ -16,18 +16,25 @@ Generate C++ parser and lexer from a grammar file. """ -def antlr_cc_library(name, src, package): +load("@rules_antlr//antlr:antlr4.bzl", "antlr") + +def antlr_cc_library(name, src, package = None, listener = False, visitor = True): """Creates a C++ lexer and parser from a source grammar. Args: name: Base name for the lexer and the parser rules. src: source ANTLR grammar file package: The namespace for the generated code + listener: generate ANTLR listener (default: False) + visitor: generate ANTLR visitor (default: True) """ generated = name + "_grammar" - antlr_library( + antlr( name = generated, - src = src, + srcs = [src], + language = "Cpp", + listener = listener, + visitor = visitor, package = package, ) native.cc_library( @@ -39,65 +46,3 @@ def antlr_cc_library(name, src, package): ], linkstatic = 1, ) - -def _antlr_library(ctx): - output = ctx.actions.declare_directory(ctx.attr.name) - - antlr_args = ctx.actions.args() - antlr_args.add("-Dlanguage=Cpp") - antlr_args.add("-no-listener") - antlr_args.add("-visitor") - antlr_args.add("-o", output.path) - antlr_args.add("-package", ctx.attr.package) - antlr_args.add(ctx.file.src) - - # Strip ".g4" extension. - basename = ctx.file.src.basename[:-3] - - suffixes = ["Lexer", "Parser", "BaseVisitor", "Visitor"] - - ctx.actions.run( - arguments = [antlr_args], - inputs = [ctx.file.src], - outputs = [output], - executable = ctx.executable._tool, - progress_message = "Processing ANTLR grammar", - ) - - files = [] - for suffix in suffixes: - header = ctx.actions.declare_file(basename + suffix + ".h") - source = ctx.actions.declare_file(basename + suffix + ".cpp") - generated = output.path + "/" + ctx.file.src.path[:-3] + suffix - - ctx.actions.run_shell( - mnemonic = "CopyHeader" + suffix, - inputs = [output], - outputs = [header], - command = 'cp "{generated}" "{out}"'.format(generated = generated + ".h", out = header.path), - ) - ctx.actions.run_shell( - mnemonic = "CopySource" + suffix, - inputs = [output], - outputs = [source], - command = 'cp "{generated}" "{out}"'.format(generated = generated + ".cpp", out = source.path), - ) - - files.append(header) - files.append(source) - - compilation_context = cc_common.create_compilation_context(headers = depset(files)) - return [DefaultInfo(files = depset(files)), CcInfo(compilation_context = compilation_context)] - -antlr_library = rule( - implementation = _antlr_library, - attrs = { - "src": attr.label(allow_single_file = [".g4"], mandatory = True), - "package": attr.string(), - "_tool": attr.label( - executable = True, - cfg = "host", - default = Label("//bazel:antlr4_tool"), - ), - }, -) diff --git a/bazel/deps.bzl b/bazel/deps.bzl index abe35fdfc..0edf314df 100644 --- a/bazel/deps.bzl +++ b/bazel/deps.bzl @@ -2,7 +2,7 @@ Main dependencies of cel-cpp. """ -load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_jar") +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") def base_deps(): """Base evaluator and test dependencies.""" @@ -69,9 +69,15 @@ def base_deps(): def parser_deps(): """ANTLR dependency for the parser.""" - # Apr 15, 2022 - ANTLR4_VERSION = "4.10.1" + http_archive( + name = "rules_antlr", + sha256 = "26e6a83c665cf6c1093b628b3a749071322f0f70305d12ede30909695ed85591", + strip_prefix = "rules_antlr-0.5.0", + urls = ["https://github.com/marcohu/rules_antlr/archive/0.5.0.tar.gz"], + ) + ANTLR4_RUNTIME_GIT_SHA = "70b2edcf98eb612a92d3dbaedb2ce0b69533b0cb" # Dec 7, 2021 + ANTLR4_RUNTIME_SHA = "fae73909f95e1320701e29ac03bab9233293fb5b90d3ce857279f1b46b614c83" http_archive( name = "antlr4_runtimes", build_file_content = """ @@ -83,14 +89,9 @@ cc_library( includes = ["runtime/Cpp/runtime/src"], ) """, - sha256 = "a320568b738e42735946bebc5d9d333170e14a251c5734e8b852ad1502efa8a2", - strip_prefix = "antlr4-" + ANTLR4_VERSION, - urls = ["https://github.com/antlr/antlr4/archive/v" + ANTLR4_VERSION + ".tar.gz"], - ) - http_jar( - name = "antlr4_jar", - urls = ["https://www.antlr.org/download/antlr-" + ANTLR4_VERSION + "-complete.jar"], - sha256 = "41949d41f20d31d5b8277187735dd755108df52b38db6c865108d3382040f918", + sha256 = ANTLR4_RUNTIME_SHA, + strip_prefix = "antlr4-" + ANTLR4_RUNTIME_GIT_SHA, + urls = ["https://github.com/antlr/antlr4/archive/" + ANTLR4_RUNTIME_GIT_SHA + ".tar.gz"], ) def flatbuffers_deps(): diff --git a/bazel/deps_extra.bzl b/bazel/deps_extra.bzl index 40a47f01b..76cb8c5d6 100644 --- a/bazel/deps_extra.bzl +++ b/bazel/deps_extra.bzl @@ -4,6 +4,7 @@ Transitive dependencies. load("@com_google_protobuf//:protobuf_deps.bzl", "protobuf_deps") load("@com_google_googleapis//:repository_rules.bzl", "switched_rules_by_language") +load("@rules_antlr//antlr:repositories.bzl", "rules_antlr_dependencies") load("@io_bazel_rules_go//go:deps.bzl", "go_register_toolchains", "go_rules_dependencies") load("@bazel_gazelle//:deps.bzl", "gazelle_dependencies", "go_repository") @@ -49,4 +50,5 @@ def cel_cpp_deps_extra(): cc = True, go = True, # cel-spec requirement ) + rules_antlr_dependencies("4.8") cel_spec_deps_extra() diff --git a/conformance/BUILD b/conformance/BUILD index b620f2282..ab43d7b50 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -56,8 +56,8 @@ cc_binary( "@com_google_absl//absl/strings", "@com_google_cel_spec//proto/test/v1/proto2:test_all_types_cc_proto", "@com_google_cel_spec//proto/test/v1/proto3:test_all_types_cc_proto", + "@com_google_googleapis//google/api/expr/conformance/v1alpha1:conformance_cc_proto", "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", - "@com_google_googleapis//google/api/expr/v1alpha1:conformance_service_cc_proto", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", "@com_google_googleapis//google/rpc:code_cc_proto", "@com_google_protobuf//:protobuf", diff --git a/conformance/server.cc b/conformance/server.cc index 68f77fda7..6a717d470 100644 --- a/conformance/server.cc +++ b/conformance/server.cc @@ -1,7 +1,7 @@ #include #include -#include "google/api/expr/v1alpha1/conformance_service.pb.h" +#include "google/api/expr/conformance/v1alpha1/conformance_service.pb.h" #include "google/api/expr/v1alpha1/syntax.pb.h" #include "google/api/expr/v1alpha1/checked.pb.h" #include "google/api/expr/v1alpha1/eval.pb.h" @@ -44,8 +44,8 @@ class ConformanceServiceImpl { proto3_tests_(&google::api::expr::test::v1::proto3::TestAllTypes:: default_instance()) {} - void Parse(const v1alpha1::ParseRequest* request, - v1alpha1::ParseResponse* response) { + void Parse(const conformance::v1alpha1::ParseRequest* request, + conformance::v1alpha1::ParseResponse* response) { if (request->cel_source().empty()) { auto issue = response->add_issues(); issue->set_message("No source code"); @@ -64,15 +64,15 @@ class ConformanceServiceImpl { } } - void Check(const v1alpha1::CheckRequest* request, - v1alpha1::CheckResponse* response) { + void Check(const conformance::v1alpha1::CheckRequest* request, + conformance::v1alpha1::CheckResponse* response) { auto issue = response->add_issues(); issue->set_message("Check is not supported"); issue->set_code(google::rpc::Code::UNIMPLEMENTED); } - void Eval(const v1alpha1::EvalRequest* request, - v1alpha1::EvalResponse* response) { + void Eval(const conformance::v1alpha1::EvalRequest* request, + conformance::v1alpha1::EvalResponse* response) { const v1alpha1::Expr* expr = nullptr; if (request->has_parsed_expr()) { expr = &request->parsed_expr().expr(); @@ -190,8 +190,8 @@ int RunServer(bool optimize) { std::getline(std::cin, cmd); std::getline(std::cin, input); if (cmd == "parse") { - v1alpha1::ParseRequest request; - v1alpha1::ParseResponse response; + conformance::v1alpha1::ParseRequest request; + conformance::v1alpha1::ParseResponse response; if (!JsonStringToMessage(input, &request).ok()) { std::cerr << "Failed to parse JSON" << std::endl; } @@ -200,8 +200,8 @@ int RunServer(bool optimize) { std::cerr << "Failed to convert to JSON" << std::endl; } } else if (cmd == "eval") { - v1alpha1::EvalRequest request; - v1alpha1::EvalResponse response; + conformance::v1alpha1::EvalRequest request; + conformance::v1alpha1::EvalResponse response; if (!JsonStringToMessage(input, &request).ok()) { std::cerr << "Failed to parse JSON" << std::endl; } diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index 21ba318bd..4d3e94853 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -45,6 +45,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_protobuf//:protobuf", ], ) @@ -53,12 +54,16 @@ cc_test( srcs = [ "flat_expr_builder_test.cc", ], + data = [ + "//eval/testutil:simple_test_message_proto", + ], deps = [ ":flat_expr_builder", "//eval/public:activation", "//eval/public:builtin_func_registrar", "//eval/public:cel_attribute", "//eval/public:cel_builtins", + "//eval/public:cel_expr_builder_factory", "//eval/public:cel_expression", "//eval/public:cel_function_adapter", "//eval/public:cel_options", @@ -74,6 +79,7 @@ cc_test( "//parser", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index 72a810025..69a494d80 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -1056,10 +1056,11 @@ FlatExprBuilder::CreateExpressionImpl( std::unique_ptr expression_impl = absl::make_unique( - expr, std::move(execution_path), comprehension_max_iterations_, - std::move(iter_variable_names), enable_unknowns_, - enable_unknown_function_results_, enable_missing_attribute_errors_, - enable_null_coercion_, std::move(rewrite_buffer)); + expr, std::move(execution_path), descriptor_pool_, message_factory_, + comprehension_max_iterations_, std::move(iter_variable_names), + enable_unknowns_, enable_unknown_function_results_, + enable_missing_attribute_errors_, enable_null_coercion_, + std::move(rewrite_buffer)); if (warnings != nullptr) { *warnings = std::move(warnings_builder).warnings(); diff --git a/eval/compiler/flat_expr_builder.h b/eval/compiler/flat_expr_builder.h index 6ad6e60b6..993672309 100644 --- a/eval/compiler/flat_expr_builder.h +++ b/eval/compiler/flat_expr_builder.h @@ -19,6 +19,7 @@ #include "google/api/expr/v1alpha1/checked.pb.h" #include "google/api/expr/v1alpha1/syntax.pb.h" +#include "google/protobuf/descriptor.h" #include "absl/status/statusor.h" #include "eval/public/cel_expression.h" @@ -28,8 +29,12 @@ namespace google::api::expr::runtime { // Builds instances of CelExpressionFlatImpl. class FlatExprBuilder : public CelExpressionBuilder { public: - FlatExprBuilder() - : enable_unknowns_(false), + explicit FlatExprBuilder(const google::protobuf::DescriptorPool* descriptor_pool = + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory* message_factory = + google::protobuf::MessageFactory::generated_factory()) + : CelExpressionBuilder(descriptor_pool), + enable_unknowns_(false), enable_unknown_function_results_(false), enable_missing_attribute_errors_(false), shortcircuiting_(true), @@ -42,7 +47,9 @@ class FlatExprBuilder : public CelExpressionBuilder { enable_comprehension_list_append_(false), enable_comprehension_vulnerability_check_(false), enable_null_coercion_(true), - enable_wrapper_type_null_unboxing_(false) {} + enable_wrapper_type_null_unboxing_(false), + descriptor_pool_(descriptor_pool), + message_factory_(message_factory) {} // set_enable_unknowns controls support for unknowns in expressions created. void set_enable_unknowns(bool enabled) { enable_unknowns_ = enabled; } @@ -172,6 +179,9 @@ class FlatExprBuilder : public CelExpressionBuilder { bool enable_comprehension_vulnerability_check_; bool enable_null_coercion_; bool enable_wrapper_type_null_unboxing_; + + const google::protobuf::DescriptorPool* descriptor_pool_; + google::protobuf::MessageFactory* message_factory_; }; } // namespace google::api::expr::runtime diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index df0285d41..c0fcc0899 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -16,15 +16,22 @@ #include "eval/compiler/flat_expr_builder.h" +#include +#include #include #include #include #include "google/api/expr/v1alpha1/checked.pb.h" #include "google/api/expr/v1alpha1/syntax.pb.h" +#include "google/protobuf/duration.pb.h" #include "google/protobuf/field_mask.pb.h" +#include "google/protobuf/descriptor.pb.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/dynamic_message.h" #include "google/protobuf/text_format.h" #include "absl/status/status.h" +#include "absl/strings/str_format.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -32,6 +39,7 @@ #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_attribute.h" #include "eval/public/cel_builtins.h" +#include "eval/public/cel_expr_builder_factory.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_function_adapter.h" #include "eval/public/cel_options.h" @@ -61,6 +69,29 @@ using testing::HasSubstr; using cel::internal::IsOk; using cel::internal::StatusIs; +inline constexpr absl::string_view kSimpleTestMessageDescriptorSetFile = + "eval/testutil/" + "simple_test_message_proto-descriptor-set.proto.bin"; + +template +absl::Status ReadBinaryProtoFromDisk(absl::string_view file_name, + MessageClass& message) { + std::ifstream file; + file.open(file_name, std::fstream::in); + if (!file.is_open()) { + return absl::NotFoundError(absl::StrFormat("Failed to open file '%s': %s", + file_name, strerror(errno))); + } + + if (!message.ParseFromIstream(&file)) { + return absl::InvalidArgumentError( + absl::StrFormat("Failed to parse proto of type '%s' from file '%s'", + message.GetTypeName(), file_name)); + } + + return absl::OkStatus(); +} + class ConcatFunction : public CelFunction { public: explicit ConcatFunction() : CelFunction(CreateDescriptor()) {} @@ -1546,6 +1577,196 @@ TEST(FlatExprBuilderTest, NullUnboxingDisabled) { EXPECT_THAT(result, test::IsCelInt64(0)); } +TEST(FlatExprBuilderTest, CustomDescriptorPoolForCreateStruct) { + ASSERT_OK_AND_ASSIGN( + ParsedExpr parsed_expr, + parser::Parse("google.api.expr.runtime.SimpleTestMessage{}")); + + // This time, the message is unknown. We only have the proto as data, we did + // not link the generated message, so it's not included in the generated pool. + FlatExprBuilder builder(google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory()); + EXPECT_THAT( + builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), + StatusIs(absl::StatusCode::kInvalidArgument)); + + // Now we create a custom DescriptorPool to which we add SimpleTestMessage + google::protobuf::DescriptorPool desc_pool; + google::protobuf::FileDescriptorSet filedesc_set; + + ASSERT_OK(ReadBinaryProtoFromDisk(kSimpleTestMessageDescriptorSetFile, + filedesc_set)); + ASSERT_EQ(filedesc_set.file_size(), 1); + desc_pool.BuildFile(filedesc_set.file(0)); + + google::protobuf::DynamicMessageFactory message_factory(&desc_pool); + + // This time, the message is *known*. We are using a custom descriptor pool + // that has been primed with the relevant message. + FlatExprBuilder builder2(&desc_pool, &message_factory); + ASSERT_OK_AND_ASSIGN(auto expression, + builder2.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + Activation activation; + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(CelValue result, + expression->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsMessage()); + EXPECT_EQ(result.MessageOrDie()->GetTypeName(), + "google.api.expr.runtime.SimpleTestMessage"); +} + +TEST(FlatExprBuilderTest, CustomDescriptorPoolForSelect) { + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, + parser::Parse("message.int64_value")); + + google::protobuf::DescriptorPool desc_pool; + google::protobuf::FileDescriptorSet filedesc_set; + + ASSERT_OK(ReadBinaryProtoFromDisk(kSimpleTestMessageDescriptorSetFile, + filedesc_set)); + ASSERT_EQ(filedesc_set.file_size(), 1); + desc_pool.BuildFile(filedesc_set.file(0)); + + google::protobuf::DynamicMessageFactory message_factory(&desc_pool); + + const google::protobuf::Descriptor* desc = desc_pool.FindMessageTypeByName( + "google.api.expr.runtime.SimpleTestMessage"); + const google::protobuf::Message* message_prototype = message_factory.GetPrototype(desc); + google::protobuf::Message* message = message_prototype->New(); + const google::protobuf::Reflection* refl = message->GetReflection(); + const google::protobuf::FieldDescriptor* field = desc->FindFieldByName("int64_value"); + refl->SetInt64(message, field, 123); + + // This time, the message is *known*. We are using a custom descriptor pool + // that has been primed with the relevant message. + FlatExprBuilder builder(&desc_pool, &message_factory); + ASSERT_OK_AND_ASSIGN(auto expression, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + Activation activation; + google::protobuf::Arena arena; + activation.InsertValue("message", + CelProtoWrapper::CreateMessage(message, &arena)); + ASSERT_OK_AND_ASSIGN(CelValue result, + expression->Evaluate(activation, &arena)); + EXPECT_THAT(result, test::IsCelInt64(123)); + + delete message; +} + +std::pair CreateTestMessage( + const google::protobuf::DescriptorPool& descriptor_pool, + google::protobuf::MessageFactory& message_factory, absl::string_view message_type) { + const google::protobuf::Descriptor* desc = + descriptor_pool.FindMessageTypeByName(message_type); + const google::protobuf::Message* message_prototype = message_factory.GetPrototype(desc); + google::protobuf::Message* message = message_prototype->New(); + const google::protobuf::Reflection* refl = message->GetReflection(); + return std::make_pair(message, refl); +} + +struct CustomDescriptorPoolTestParam final { + using SetterFunction = + std::function; + std::string message_type; + std::string field_name; + SetterFunction setter; + test::CelValueMatcher matcher; +}; + +class CustomDescriptorPoolTest + : public ::testing::TestWithParam {}; + +// This test in particular checks for conversion errors in cel_proto_wrapper.cc. +TEST_P(CustomDescriptorPoolTest, TestType) { + const CustomDescriptorPoolTestParam& p = GetParam(); + + google::protobuf::DescriptorPool descriptor_pool; + google::protobuf::Arena arena; + + // Setup descriptor pool and builder + ASSERT_OK(AddStandardMessageTypesToDescriptorPool(&descriptor_pool)); + google::protobuf::DynamicMessageFactory message_factory(&descriptor_pool); + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("m")); + FlatExprBuilder builder(&descriptor_pool, &message_factory); + ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + + // Create test subject, invoke custom setter for message + auto [message, reflection] = + CreateTestMessage(descriptor_pool, message_factory, p.message_type); + const google::protobuf::FieldDescriptor* field = + message->GetDescriptor()->FindFieldByName(p.field_name); + + p.setter(message, reflection, field); + ASSERT_OK_AND_ASSIGN(std::unique_ptr expression, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + // Evaluate expression, verify expectation with custom matcher + Activation activation; + activation.InsertValue("m", CelProtoWrapper::CreateMessage(message, &arena)); + ASSERT_OK_AND_ASSIGN(CelValue result, + expression->Evaluate(activation, &arena)); + EXPECT_THAT(result, p.matcher); + + delete message; +} + +INSTANTIATE_TEST_SUITE_P( + ValueTypes, CustomDescriptorPoolTest, + ::testing::ValuesIn(std::vector{ + {"google.protobuf.Duration", "seconds", + [](google::protobuf::Message* message, const google::protobuf::Reflection* reflection, + const google::protobuf::FieldDescriptor* field) { + reflection->SetInt64(message, field, 10); + }, + test::IsCelDuration(absl::Seconds(10))}, + {"google.protobuf.DoubleValue", "value", + [](google::protobuf::Message* message, const google::protobuf::Reflection* reflection, + const google::protobuf::FieldDescriptor* field) { + reflection->SetDouble(message, field, 1.2); + }, + test::IsCelDouble(1.2)}, + {"google.protobuf.Int64Value", "value", + [](google::protobuf::Message* message, const google::protobuf::Reflection* reflection, + const google::protobuf::FieldDescriptor* field) { + reflection->SetInt64(message, field, -23); + }, + test::IsCelInt64(-23)}, + {"google.protobuf.UInt64Value", "value", + [](google::protobuf::Message* message, const google::protobuf::Reflection* reflection, + const google::protobuf::FieldDescriptor* field) { + reflection->SetUInt64(message, field, 42); + }, + test::IsCelUint64(42)}, + {"google.protobuf.BoolValue", "value", + [](google::protobuf::Message* message, const google::protobuf::Reflection* reflection, + const google::protobuf::FieldDescriptor* field) { + reflection->SetBool(message, field, true); + }, + test::IsCelBool(true)}, + {"google.protobuf.StringValue", "value", + [](google::protobuf::Message* message, const google::protobuf::Reflection* reflection, + const google::protobuf::FieldDescriptor* field) { + reflection->SetString(message, field, "foo"); + }, + test::IsCelString("foo")}, + {"google.protobuf.BytesValue", "value", + [](google::protobuf::Message* message, const google::protobuf::Reflection* reflection, + const google::protobuf::FieldDescriptor* field) { + reflection->SetString(message, field, "bar"); + }, + test::IsCelBytes("bar")}, + {"google.protobuf.Timestamp", "seconds", + [](google::protobuf::Message* message, const google::protobuf::Reflection* reflection, + const google::protobuf::FieldDescriptor* field) { + reflection->SetInt64(message, field, 20); + }, + test::IsCelTimestamp(absl::FromUnixSeconds(20))}})); + } // namespace } // namespace google::api::expr::runtime diff --git a/eval/eval/BUILD b/eval/eval/BUILD index ec47b265f..45e55015d 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -322,6 +322,7 @@ cc_test( "//internal:status_macros", "//internal:testing", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_protobuf//:protobuf", ], ) @@ -339,6 +340,7 @@ cc_test( "//internal:testing", "@com_google_absl//absl/status:statusor", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_protobuf//:protobuf", ], ) @@ -378,6 +380,7 @@ cc_test( "//internal:status_macros", "//internal:testing", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_protobuf//:protobuf", ], ) @@ -406,6 +409,7 @@ cc_test( "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_protobuf//:protobuf", ], ) @@ -423,6 +427,7 @@ cc_test( "//eval/public:unknown_set", "//internal:status_macros", "//internal:testing", + "@com_google_protobuf//:protobuf", ], ) @@ -468,6 +473,7 @@ cc_test( "//internal:testing", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", ], ) @@ -493,6 +499,7 @@ cc_test( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_protobuf//:protobuf", ], ) @@ -620,6 +627,7 @@ cc_test( "//eval/public:unknown_set", "//internal:status_macros", "//internal:testing", + "@com_google_protobuf//:protobuf", ], ) @@ -654,5 +662,6 @@ cc_test( "//internal:testing", "@com_google_absl//absl/status:statusor", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_protobuf//:protobuf", ], ) diff --git a/eval/eval/comprehension_step_test.cc b/eval/eval/comprehension_step_test.cc index cea9fb0db..feb7312dc 100644 --- a/eval/eval/comprehension_step_test.cc +++ b/eval/eval/comprehension_step_test.cc @@ -7,6 +7,7 @@ #include "google/api/expr/v1alpha1/syntax.pb.h" #include "google/protobuf/struct.pb.h" #include "google/protobuf/wrappers.pb.h" +#include "google/protobuf/descriptor.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "eval/eval/evaluator_core.h" @@ -44,7 +45,8 @@ class ListKeysStepTest : public testing::Test { std::unique_ptr MakeExpression( ExecutionPath&& path, bool unknown_attributes = false) { return std::make_unique( - &dummy_expr_, std::move(path), 0, std::set(), + &dummy_expr_, std::move(path), google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, std::set(), unknown_attributes, unknown_attributes); } diff --git a/eval/eval/const_value_step_test.cc b/eval/eval/const_value_step_test.cc index 18598d0a1..5251ee185 100644 --- a/eval/eval/const_value_step_test.cc +++ b/eval/eval/const_value_step_test.cc @@ -3,6 +3,7 @@ #include #include "google/api/expr/v1alpha1/syntax.pb.h" +#include "google/protobuf/descriptor.h" #include "absl/status/statusor.h" #include "eval/eval/evaluator_core.h" #include "eval/public/activation.h" @@ -32,7 +33,9 @@ absl::StatusOr RunConstantExpression(const Expr* expr, google::api::expr::v1alpha1::Expr dummy_expr; - CelExpressionFlatImpl impl(&dummy_expr, std::move(path), 0, {}); + CelExpressionFlatImpl impl( + &dummy_expr, std::move(path), google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, {}); Activation activation; diff --git a/eval/eval/container_access_step_test.cc b/eval/eval/container_access_step_test.cc index 9f7a2bf1c..2af1f9ce6 100644 --- a/eval/eval/container_access_step_test.cc +++ b/eval/eval/container_access_step_test.cc @@ -5,6 +5,7 @@ #include #include "google/protobuf/struct.pb.h" +#include "google/protobuf/descriptor.h" #include "absl/status/status.h" #include "eval/eval/ident_step.h" #include "eval/public/activation.h" @@ -54,7 +55,9 @@ CelValue EvaluateAttributeHelper( std::move(CreateIdentStep(&key_expr->ident_expr(), 2).value())); path.push_back(std::move(CreateContainerAccessStep(call, 3).value())); - CelExpressionFlatImpl cel_expr(&expr, std::move(path), 0, {}, enable_unknown); + CelExpressionFlatImpl cel_expr( + &expr, std::move(path), google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, {}, enable_unknown); Activation activation; activation.InsertValue("container", container); diff --git a/eval/eval/create_list_step_test.cc b/eval/eval/create_list_step_test.cc index ba0e33880..8a80268f2 100644 --- a/eval/eval/create_list_step_test.cc +++ b/eval/eval/create_list_step_test.cc @@ -3,6 +3,7 @@ #include #include +#include "google/protobuf/descriptor.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "eval/eval/const_value_step.h" @@ -45,8 +46,9 @@ absl::StatusOr RunExpression(const std::vector& values, CreateCreateListStep(create_list, dummy_expr.id())); path.push_back(std::move(step)); - CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0, {}, - enable_unknowns); + CelExpressionFlatImpl cel_expr( + &dummy_expr, std::move(path), google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, {}, enable_unknowns); Activation activation; return cel_expr.Evaluate(activation, arena); @@ -78,8 +80,9 @@ absl::StatusOr RunExpressionWithCelValues( CreateCreateListStep(create_list, dummy_expr.id())); path.push_back(std::move(step0)); - CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0, {}, - enable_unknowns); + CelExpressionFlatImpl cel_expr( + &dummy_expr, std::move(path), google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, {}, enable_unknowns); return cel_expr.Evaluate(activation, arena); } @@ -100,7 +103,9 @@ TEST(CreateListStepTest, TestCreateListStackUnderflow) { CreateCreateListStep(create_list, dummy_expr.id())); path.push_back(std::move(step0)); - CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0, {}); + CelExpressionFlatImpl cel_expr( + &dummy_expr, std::move(path), google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, {}); Activation activation; google::protobuf::Arena arena; diff --git a/eval/eval/create_struct_step.cc b/eval/eval/create_struct_step.cc index 464c5ce9d..4cbad64bc 100644 --- a/eval/eval/create_struct_step.cc +++ b/eval/eval/create_struct_step.cc @@ -76,7 +76,7 @@ absl::Status CreateStructStepForMessage::DoEvaluate(ExecutionFrame* frame, } const Message* prototype = - MessageFactory::generated_factory()->GetPrototype(descriptor_); + frame->message_factory()->GetPrototype(descriptor_); Message* msg = (prototype != nullptr) ? prototype->New(frame->arena()) : nullptr; diff --git a/eval/eval/create_struct_step_test.cc b/eval/eval/create_struct_step_test.cc index 8a435e621..80395e49a 100644 --- a/eval/eval/create_struct_step_test.cc +++ b/eval/eval/create_struct_step_test.cc @@ -4,6 +4,7 @@ #include #include "google/api/expr/v1alpha1/syntax.pb.h" +#include "google/protobuf/descriptor.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -69,8 +70,9 @@ absl::StatusOr RunExpression(absl::string_view field, path.push_back(std::move(step0)); path.push_back(std::move(step1)); - CelExpressionFlatImpl cel_expr(&expr1, std::move(path), 0, {}, - enable_unknowns); + CelExpressionFlatImpl cel_expr( + &expr1, std::move(path), google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, {}, enable_unknowns); Activation activation; activation.InsertValue("message", value); @@ -157,8 +159,9 @@ absl::StatusOr RunCreateMapExpression( CreateCreateStructStep(create_struct, expr1.id())); path.push_back(std::move(step1)); - CelExpressionFlatImpl cel_expr(&expr1, std::move(path), 0, {}, - enable_unknowns); + CelExpressionFlatImpl cel_expr( + &expr1, std::move(path), google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, {}, enable_unknowns); return cel_expr.Evaluate(activation, arena); } @@ -179,7 +182,9 @@ TEST_P(CreateCreateStructStepTest, TestEmptyMessageCreation) { CreateCreateStructStep(create_struct, desc, expr1.id())); path.push_back(std::move(step)); - CelExpressionFlatImpl cel_expr(&expr1, std::move(path), 0, {}, GetParam()); + CelExpressionFlatImpl cel_expr( + &expr1, std::move(path), google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, {}, GetParam()); Activation activation; google::protobuf::Arena arena; diff --git a/eval/eval/evaluator_core.cc b/eval/eval/evaluator_core.cc index 06c256f13..df64324e4 100644 --- a/eval/eval/evaluator_core.cc +++ b/eval/eval/evaluator_core.cc @@ -152,8 +152,9 @@ absl::StatusOr CelExpressionFlatImpl::Trace( ::cel::internal::down_cast(_state); state->Reset(); - ExecutionFrame frame(path_, activation, max_iterations_, state, - enable_unknowns_, enable_unknown_function_results_, + ExecutionFrame frame(path_, activation, descriptor_pool_, message_factory_, + max_iterations_, state, enable_unknowns_, + enable_unknown_function_results_, enable_missing_attribute_errors_, enable_null_coercion_); EvaluatorStack* stack = &frame.value_stack(); diff --git a/eval/eval/evaluator_core.h b/eval/eval/evaluator_core.h index 947d97931..8c29574af 100644 --- a/eval/eval/evaluator_core.h +++ b/eval/eval/evaluator_core.h @@ -14,6 +14,7 @@ #include "google/api/expr/v1alpha1/syntax.pb.h" #include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -110,13 +111,17 @@ class ExecutionFrame { // arena serves as allocation manager during the expression evaluation. ExecutionFrame(const ExecutionPath& flat, const BaseActivation& activation, - int max_iterations, CelExpressionFlatEvaluationState* state, - bool enable_unknowns, bool enable_unknown_function_results, + const google::protobuf::DescriptorPool* descriptor_pool, + google::protobuf::MessageFactory* message_factory, int max_iterations, + CelExpressionFlatEvaluationState* state, bool enable_unknowns, + bool enable_unknown_function_results, bool enable_missing_attribute_errors, bool enable_null_coercion) : pc_(0UL), execution_path_(flat), activation_(activation), + descriptor_pool_(descriptor_pool), + message_factory_(message_factory), enable_unknowns_(enable_unknowns), enable_unknown_function_results_(enable_unknown_function_results), enable_missing_attribute_errors_(enable_missing_attribute_errors), @@ -156,6 +161,11 @@ class ExecutionFrame { bool enable_null_coercion() const { return enable_null_coercion_; } google::protobuf::Arena* arena() { return state_->arena(); } + const google::protobuf::DescriptorPool* descriptor_pool() const { + return descriptor_pool_; + } + google::protobuf::MessageFactory* message_factory() const { return message_factory_; } + const AttributeUtility& attribute_utility() const { return attribute_utility_; } @@ -215,6 +225,8 @@ class ExecutionFrame { size_t pc_; // pc_ - Program Counter. Current position on execution path. const ExecutionPath& execution_path_; const BaseActivation& activation_; + const google::protobuf::DescriptorPool* descriptor_pool_; + google::protobuf::MessageFactory* message_factory_; bool enable_unknowns_; bool enable_unknown_function_results_; bool enable_missing_attribute_errors_; @@ -235,7 +247,10 @@ class CelExpressionFlatImpl : public CelExpression { // iterations in the comprehension expressions (use 0 to disable the upper // bound). CelExpressionFlatImpl(ABSL_ATTRIBUTE_UNUSED const Expr* root_expr, - ExecutionPath path, int max_iterations, + ExecutionPath path, + const google::protobuf::DescriptorPool* descriptor_pool, + google::protobuf::MessageFactory* message_factory, + int max_iterations, std::set iter_variable_names, bool enable_unknowns = false, bool enable_unknown_function_results = false, @@ -244,6 +259,8 @@ class CelExpressionFlatImpl : public CelExpression { std::unique_ptr rewritten_expr = nullptr) : rewritten_expr_(std::move(rewritten_expr)), path_(std::move(path)), + descriptor_pool_(descriptor_pool), + message_factory_(message_factory), max_iterations_(max_iterations), iter_variable_names_(std::move(iter_variable_names)), enable_unknowns_(enable_unknowns), @@ -282,6 +299,8 @@ class CelExpressionFlatImpl : public CelExpression { // Maintain lifecycle of a modified expression. std::unique_ptr rewritten_expr_; const ExecutionPath path_; + const google::protobuf::DescriptorPool* descriptor_pool_; + google::protobuf::MessageFactory* message_factory_; const int max_iterations_; const std::set iter_variable_names_; bool enable_unknowns_; diff --git a/eval/eval/evaluator_core_test.cc b/eval/eval/evaluator_core_test.cc index 59bc90a20..57112f69d 100644 --- a/eval/eval/evaluator_core_test.cc +++ b/eval/eval/evaluator_core_test.cc @@ -4,6 +4,7 @@ #include #include "google/api/expr/v1alpha1/syntax.pb.h" +#include "google/protobuf/descriptor.h" #include "eval/compiler/flat_expr_builder.h" #include "eval/eval/attribute_trail.h" #include "eval/public/activation.h" @@ -66,7 +67,10 @@ TEST(EvaluatorCoreTest, ExecutionFrameNext) { Activation activation; CelExpressionFlatEvaluationState state(path.size(), {}, nullptr); - ExecutionFrame frame(path, activation, 0, &state, false, false, false, true); + ExecutionFrame frame(path, activation, + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, &state, + false, false, false, true); EXPECT_THAT(frame.Next(), Eq(path[0].get())); EXPECT_THAT(frame.Next(), Eq(path[1].get())); @@ -84,7 +88,10 @@ TEST(EvaluatorCoreTest, ExecutionFrameSetGetClearVar) { google::protobuf::Arena arena; ExecutionPath path; CelExpressionFlatEvaluationState state(path.size(), {test_iter_var}, nullptr); - ExecutionFrame frame(path, activation, 0, &state, false, false, false, true); + ExecutionFrame frame(path, activation, + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, &state, + false, false, false, true); CelValue original = CelValue::CreateInt64(test_value); Expr ident; @@ -149,7 +156,10 @@ TEST(EvaluatorCoreTest, SimpleEvaluatorTest) { auto dummy_expr = absl::make_unique(); - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0, {}); + CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, + {}); Activation activation; google::protobuf::Arena arena; diff --git a/eval/eval/function_step_test.cc b/eval/eval/function_step_test.cc index 89673b621..7513d5c0d 100644 --- a/eval/eval/function_step_test.cc +++ b/eval/eval/function_step_test.cc @@ -6,6 +6,7 @@ #include #include "google/api/expr/v1alpha1/syntax.pb.h" +#include "google/protobuf/descriptor.h" #include "absl/memory/memory.h" #include "absl/strings/string_view.h" #include "eval/eval/evaluator_core.h" @@ -225,8 +226,9 @@ class FunctionStepTest break; } return absl::make_unique( - &dummy_expr_, std::move(path), 0, std::set(), unknowns, - unknown_function_results); + &dummy_expr_, std::move(path), google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, std::set(), + unknowns, unknown_function_results); } private: @@ -478,9 +480,10 @@ class FunctionStepTestUnknowns unknown_functions = false; break; } - return absl::make_unique(&expr_, std::move(path), 0, - std::set(), - true, unknown_functions); + return absl::make_unique( + &expr_, std::move(path), google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, std::set(), + true, unknown_functions); } private: @@ -629,7 +632,9 @@ TEST(FunctionStepTestUnknownFunctionResults, CaptureArgs) { Expr dummy_expr; - CelExpressionFlatImpl impl(&dummy_expr, std::move(path), 0, {}, true, true); + CelExpressionFlatImpl impl( + &dummy_expr, std::move(path), google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, {}, true, true); Activation activation; google::protobuf::Arena arena; @@ -678,7 +683,9 @@ TEST(FunctionStepTestUnknownFunctionResults, MergeDownCaptureArgs) { Expr dummy_expr; - CelExpressionFlatImpl impl(&dummy_expr, std::move(path), 0, {}, true, true); + CelExpressionFlatImpl impl( + &dummy_expr, std::move(path), google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, {}, true, true); Activation activation; google::protobuf::Arena arena; @@ -727,7 +734,9 @@ TEST(FunctionStepTestUnknownFunctionResults, MergeCaptureArgs) { Expr dummy_expr; - CelExpressionFlatImpl impl(&dummy_expr, std::move(path), 0, {}, true, true); + CelExpressionFlatImpl impl( + &dummy_expr, std::move(path), google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, {}, true, true); Activation activation; google::protobuf::Arena arena; @@ -771,7 +780,9 @@ TEST(FunctionStepTestUnknownFunctionResults, UnknownVsErrorPrecedenceTest) { Expr dummy_expr; - CelExpressionFlatImpl impl(&dummy_expr, std::move(path), 0, {}, true, true); + CelExpressionFlatImpl impl( + &dummy_expr, std::move(path), google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, {}, true, true); Activation activation; google::protobuf::Arena arena; @@ -871,9 +882,10 @@ TEST_F(FunctionStepNullCoercionTest, EnabledSupportsMessageOverloads) { path.push_back(std::move(call_step)); - CelExpressionFlatImpl impl(&dummy_expr_, std::move(path), 0, {}, true, true, - true, - /*enable_null_coercion=*/true); + CelExpressionFlatImpl impl( + &dummy_expr_, std::move(path), google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, {}, true, true, true, + /*enable_null_coercion=*/true); ASSERT_OK_AND_ASSIGN(CelValue value, impl.Evaluate(activation_, &arena_)); ASSERT_TRUE(value.IsString()); @@ -895,9 +907,10 @@ TEST_F(FunctionStepNullCoercionTest, EnabledPrefersNullOverloads) { path.push_back(std::move(call_step)); - CelExpressionFlatImpl impl(&dummy_expr_, std::move(path), 0, {}, true, true, - true, - /*enable_null_coercion=*/true); + CelExpressionFlatImpl impl( + &dummy_expr_, std::move(path), google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, {}, true, true, true, + /*enable_null_coercion=*/true); ASSERT_OK_AND_ASSIGN(CelValue value, impl.Evaluate(activation_, &arena_)); ASSERT_TRUE(value.IsString()); @@ -918,9 +931,10 @@ TEST_F(FunctionStepNullCoercionTest, EnabledNullMessageDoesNotEscape) { path.push_back(std::move(call_step)); - CelExpressionFlatImpl impl(&dummy_expr_, std::move(path), 0, {}, true, true, - true, - /*enable_null_coercion=*/true); + CelExpressionFlatImpl impl( + &dummy_expr_, std::move(path), google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, {}, true, true, true, + /*enable_null_coercion=*/true); ASSERT_OK_AND_ASSIGN(CelValue value, impl.Evaluate(activation_, &arena_)); ASSERT_TRUE(value.IsNull()); @@ -941,9 +955,10 @@ TEST_F(FunctionStepNullCoercionTest, Disabled) { path.push_back(std::move(call_step)); - CelExpressionFlatImpl impl(&dummy_expr_, std::move(path), 0, {}, true, true, - true, - /*enable_null_coercion=*/false); + CelExpressionFlatImpl impl( + &dummy_expr_, std::move(path), google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, {}, true, true, true, + /*enable_null_coercion=*/false); ASSERT_OK_AND_ASSIGN(CelValue value, impl.Evaluate(activation_, &arena_)); ASSERT_TRUE(value.IsError()); diff --git a/eval/eval/ident_step_test.cc b/eval/eval/ident_step_test.cc index 79394dcb7..60680dbdc 100644 --- a/eval/eval/ident_step_test.cc +++ b/eval/eval/ident_step_test.cc @@ -4,6 +4,7 @@ #include #include "google/api/expr/v1alpha1/syntax.pb.h" +#include "google/protobuf/descriptor.h" #include "eval/eval/evaluator_core.h" #include "eval/public/activation.h" #include "internal/status_macros.h" @@ -31,7 +32,10 @@ TEST(IdentStepTest, TestIdentStep) { auto dummy_expr = absl::make_unique(); - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0, {}); + CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, + {}); Activation activation; Arena arena; @@ -59,7 +63,10 @@ TEST(IdentStepTest, TestIdentStepNameNotFound) { auto dummy_expr = absl::make_unique(); - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0, {}); + CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, + {}); Activation activation; Arena arena; @@ -84,7 +91,9 @@ TEST(IdentStepTest, DisableMissingAttributeErrorsOK) { auto dummy_expr = absl::make_unique(); - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0, {}, + CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, {}, /*enable_unknowns=*/false); Activation activation; @@ -121,8 +130,11 @@ TEST(IdentStepTest, TestIdentStepMissingAttributeErrors) { auto dummy_expr = absl::make_unique(); - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0, {}, false, - false, /*enable_missing_attribute_errors=*/true); + CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, {}, + false, false, + /*enable_missing_attribute_errors=*/true); Activation activation; Arena arena; @@ -160,7 +172,10 @@ TEST(IdentStepTest, TestIdentStepUnknownAttribute) { auto dummy_expr = absl::make_unique(); // Expression with unknowns enabled. - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0, {}, true); + CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, {}, + true); Activation activation; Arena arena; diff --git a/eval/eval/logic_step_test.cc b/eval/eval/logic_step_test.cc index 4b09a347e..1300360ed 100644 --- a/eval/eval/logic_step_test.cc +++ b/eval/eval/logic_step_test.cc @@ -2,6 +2,7 @@ #include +#include "google/protobuf/descriptor.h" #include "eval/eval/ident_step.h" #include "eval/public/activation.h" #include "eval/public/unknown_attribute_set.h" @@ -40,8 +41,10 @@ class LogicStepTest : public testing::TestWithParam { path.push_back(std::move(step)); auto dummy_expr = absl::make_unique(); - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0, {}, - enable_unknown); + CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, + {}, enable_unknown); Activation activation; activation.InsertValue("name0", arg0); diff --git a/eval/eval/select_step_test.cc b/eval/eval/select_step_test.cc index 68f69ed0a..8b3ec5452 100644 --- a/eval/eval/select_step_test.cc +++ b/eval/eval/select_step_test.cc @@ -5,6 +5,7 @@ #include "google/api/expr/v1alpha1/syntax.pb.h" #include "google/protobuf/wrappers.pb.h" +#include "google/protobuf/descriptor.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "eval/eval/ident_step.h" @@ -58,8 +59,10 @@ absl::StatusOr RunExpression(const CelValue target, path.push_back(std::move(step0)); path.push_back(std::move(step1)); - CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0, {}, - options.enable_unknowns); + CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, + {}, options.enable_unknowns); Activation activation; activation.InsertValue("target", target); @@ -204,7 +207,9 @@ TEST(SelectStepTest, MapPresenseIsErrorTest) { path.push_back(std::move(step0)); path.push_back(std::move(step1)); path.push_back(std::move(step2)); - CelExpressionFlatImpl cel_expr(&select_expr, std::move(path), 0, {}, false); + CelExpressionFlatImpl cel_expr( + &select_expr, std::move(path), google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, {}, false); Activation activation; activation.InsertValue("target", CelProtoWrapper::CreateMessage(&message, &arena)); @@ -508,8 +513,9 @@ TEST_P(SelectStepTest, CelErrorAsArgument) { google::protobuf::Arena arena; bool enable_unknowns = GetParam(); - CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0, {}, - enable_unknowns); + CelExpressionFlatImpl cel_expr( + &dummy_expr, std::move(path), google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, {}, enable_unknowns); Activation activation; activation.InsertValue("message", CelValue::CreateError(&error)); @@ -542,8 +548,10 @@ TEST(SelectStepTest, DisableMissingAttributeOK) { path.push_back(std::move(step0)); path.push_back(std::move(step1)); - CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0, {}, - /*enable_unknowns=*/false); + CelExpressionFlatImpl cel_expr( + &dummy_expr, std::move(path), google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, {}, + /*enable_unknowns=*/false); Activation activation; activation.InsertValue("message", CelProtoWrapper::CreateMessage(&message, &arena)); @@ -583,9 +591,10 @@ TEST(SelectStepTest, UnrecoverableUnknownValueProducesError) { path.push_back(std::move(step0)); path.push_back(std::move(step1)); - CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0, {}, false, - false, - /*enable_missing_attribute_errors=*/true); + CelExpressionFlatImpl cel_expr( + &dummy_expr, std::move(path), google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, {}, false, false, + /*enable_missing_attribute_errors=*/true); Activation activation; activation.InsertValue("message", CelProtoWrapper::CreateMessage(&message, &arena)); @@ -631,7 +640,9 @@ TEST(SelectStepTest, UnknownPatternResolvesToUnknown) { path.push_back(*std::move(step0_status)); path.push_back(*std::move(step1_status)); - CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0, {}, true); + CelExpressionFlatImpl cel_expr( + &dummy_expr, std::move(path), google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, {}, true); { std::vector unknown_patterns; diff --git a/eval/eval/shadowable_value_step_test.cc b/eval/eval/shadowable_value_step_test.cc index 08fa22a26..e4de0d03e 100644 --- a/eval/eval/shadowable_value_step_test.cc +++ b/eval/eval/shadowable_value_step_test.cc @@ -4,6 +4,7 @@ #include #include "google/api/expr/v1alpha1/syntax.pb.h" +#include "google/protobuf/descriptor.h" #include "absl/status/statusor.h" #include "eval/eval/evaluator_core.h" #include "eval/public/activation.h" @@ -28,7 +29,9 @@ absl::StatusOr RunShadowableExpression(const std::string& identifier, path.push_back(std::move(step)); google::api::expr::v1alpha1::Expr dummy_expr; - CelExpressionFlatImpl impl(&dummy_expr, std::move(path), 0, {}); + CelExpressionFlatImpl impl( + &dummy_expr, std::move(path), google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, {}); return impl.Evaluate(activation, arena); } diff --git a/eval/eval/ternary_step_test.cc b/eval/eval/ternary_step_test.cc index 621fb006f..10d57df61 100644 --- a/eval/eval/ternary_step_test.cc +++ b/eval/eval/ternary_step_test.cc @@ -3,6 +3,7 @@ #include #include +#include "google/protobuf/descriptor.h" #include "eval/eval/ident_step.h" #include "eval/public/activation.h" #include "eval/public/unknown_attribute_set.h" @@ -53,8 +54,10 @@ class LogicStepTest : public testing::TestWithParam { auto dummy_expr = absl::make_unique(); - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0, {}, - enable_unknown); + CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, + {}, enable_unknown); Activation activation; std::string value("test"); diff --git a/eval/public/BUILD b/eval/public/BUILD index bde897f3d..a498ee9c1 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -401,6 +401,20 @@ cc_library( ":cel_expression", ":cel_options", "//eval/compiler:flat_expr_builder", + "@com_google_absl//absl/status", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "cel_expr_builder_factory_test", + srcs = ["cel_expr_builder_factory_test.cc"], + deps = [ + ":cel_expr_builder_factory", + "//eval/testutil:test_message_cc_proto", + "//internal:testing", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_protobuf//:protobuf", ], ) @@ -800,7 +814,7 @@ cc_library( name = "set_util", srcs = ["set_util.cc"], hdrs = ["set_util.h"], - deps = ["//eval/public:cel_value"], + deps = [":cel_value"], ) cc_library( diff --git a/eval/public/cel_expr_builder_factory.cc b/eval/public/cel_expr_builder_factory.cc index b431ab63d..54d51fc5c 100644 --- a/eval/public/cel_expr_builder_factory.cc +++ b/eval/public/cel_expr_builder_factory.cc @@ -18,14 +18,150 @@ #include +#include "google/protobuf/any.pb.h" +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/struct.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "google/protobuf/wrappers.pb.h" +#include "google/protobuf/descriptor.pb.h" +#include "google/protobuf/util/message_differencer.h" +#include "absl/status/status.h" #include "eval/compiler/flat_expr_builder.h" #include "eval/public/cel_options.h" namespace google::api::expr::runtime { +namespace { +template +absl::Status ValidateStandardMessageType( + const google::protobuf::DescriptorPool* descriptor_pool) { + const google::protobuf::Descriptor* descriptor = MessageType::descriptor(); + const google::protobuf::Descriptor* descriptor_from_pool = + descriptor_pool->FindMessageTypeByName(descriptor->full_name()); + if (descriptor_from_pool == nullptr) { + return absl::NotFoundError( + absl::StrFormat("Descriptor '%s' not found in descriptor pool", + descriptor->full_name())); + } + if (descriptor_from_pool == descriptor) { + return absl::OkStatus(); + } + google::protobuf::DescriptorProto descriptor_proto; + google::protobuf::DescriptorProto descriptor_from_pool_proto; + descriptor->CopyTo(&descriptor_proto); + descriptor_from_pool->CopyTo(&descriptor_from_pool_proto); + if (!google::protobuf::util::MessageDifferencer::Equals(descriptor_proto, + descriptor_from_pool_proto)) { + return absl::FailedPreconditionError(absl::StrFormat( + "The descriptor for '%s' in the descriptor pool differs from the " + "compiled-in generated version", + descriptor->full_name())); + } + return absl::OkStatus(); +} + +template +absl::Status AddOrValidateMessageType(google::protobuf::DescriptorPool* descriptor_pool) { + const google::protobuf::Descriptor* descriptor = MessageType::descriptor(); + if (descriptor_pool->FindMessageTypeByName(descriptor->full_name()) != + nullptr) { + return ValidateStandardMessageType(descriptor_pool); + } + google::protobuf::FileDescriptorProto file_descriptor_proto; + descriptor->file()->CopyTo(&file_descriptor_proto); + if (descriptor_pool->BuildFile(file_descriptor_proto) == nullptr) { + return absl::InternalError( + absl::StrFormat("Failed to add descriptor '%s' to descriptor pool", + descriptor->full_name())); + } + return absl::OkStatus(); +} + +absl::Status ValidateStandardMessageTypes( + const google::protobuf::DescriptorPool* descriptor_pool) { + CEL_RETURN_IF_ERROR( + ValidateStandardMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR(ValidateStandardMessageType( + descriptor_pool)); + CEL_RETURN_IF_ERROR(ValidateStandardMessageType( + descriptor_pool)); + CEL_RETURN_IF_ERROR( + ValidateStandardMessageType( + descriptor_pool)); + CEL_RETURN_IF_ERROR( + ValidateStandardMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR(ValidateStandardMessageType( + descriptor_pool)); + CEL_RETURN_IF_ERROR(ValidateStandardMessageType( + descriptor_pool)); + CEL_RETURN_IF_ERROR(ValidateStandardMessageType( + descriptor_pool)); + CEL_RETURN_IF_ERROR(ValidateStandardMessageType( + descriptor_pool)); + CEL_RETURN_IF_ERROR( + ValidateStandardMessageType( + descriptor_pool)); + CEL_RETURN_IF_ERROR( + ValidateStandardMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR(ValidateStandardMessageType( + descriptor_pool)); + CEL_RETURN_IF_ERROR( + ValidateStandardMessageType( + descriptor_pool)); + CEL_RETURN_IF_ERROR( + ValidateStandardMessageType( + descriptor_pool)); + CEL_RETURN_IF_ERROR( + ValidateStandardMessageType(descriptor_pool)); + return absl::OkStatus(); +} + +} // namespace + +absl::Status AddStandardMessageTypesToDescriptorPool( + google::protobuf::DescriptorPool* descriptor_pool) { + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + return absl::OkStatus(); +} + std::unique_ptr CreateCelExpressionBuilder( + const google::protobuf::DescriptorPool* descriptor_pool, + google::protobuf::MessageFactory* message_factory, const InterpreterOptions& options) { - auto builder = absl::make_unique(); + if (!ValidateStandardMessageTypes(descriptor_pool).ok()) { + return nullptr; + } + auto builder = + absl::make_unique(descriptor_pool, message_factory); builder->set_shortcircuiting(options.short_circuiting); builder->set_constant_folding(options.constant_folding, options.constant_arena); diff --git a/eval/public/cel_expr_builder_factory.h b/eval/public/cel_expr_builder_factory.h index f3f08d991..6063dacc2 100644 --- a/eval/public/cel_expr_builder_factory.h +++ b/eval/public/cel_expr_builder_factory.h @@ -1,6 +1,7 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_EXPR_BUILDER_FACTORY_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_EXPR_BUILDER_FACTORY_H_ +#include "google/protobuf/descriptor.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_options.h" @@ -11,8 +12,20 @@ namespace runtime { // Factory creates CelExpressionBuilder implementation for public use. std::unique_ptr CreateCelExpressionBuilder( + const google::protobuf::DescriptorPool* descriptor_pool, + google::protobuf::MessageFactory* message_factory, const InterpreterOptions& options = InterpreterOptions()); +inline std::unique_ptr CreateCelExpressionBuilder( + const InterpreterOptions& options = InterpreterOptions()) { + return CreateCelExpressionBuilder(google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), + options); +} + +absl::Status AddStandardMessageTypesToDescriptorPool( + google::protobuf::DescriptorPool* descriptor_pool); + } // namespace runtime } // namespace expr } // namespace api diff --git a/eval/public/cel_expr_builder_factory_test.cc b/eval/public/cel_expr_builder_factory_test.cc new file mode 100644 index 000000000..571fb6dc5 --- /dev/null +++ b/eval/public/cel_expr_builder_factory_test.cc @@ -0,0 +1,164 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "eval/public/cel_expr_builder_factory.h" + +#include + +#include "google/protobuf/any.pb.h" +#include "absl/container/flat_hash_map.h" +#include "eval/testutil/test_message.pb.h" +#include "internal/testing.h" + +namespace google::api::expr::runtime { + +namespace { + +using testing::HasSubstr; +using cel::internal::StatusIs; + +TEST(DescriptorPoolUtilsTest, PopulatesEmptyDescriptorPool) { + google::protobuf::DescriptorPool descriptor_pool; + + ASSERT_EQ(descriptor_pool.FindMessageTypeByName("google.protobuf.Any"), + nullptr); + ASSERT_EQ(descriptor_pool.FindMessageTypeByName("google.protobuf.BoolValue"), + nullptr); + ASSERT_EQ(descriptor_pool.FindMessageTypeByName("google.protobuf.BytesValue"), + nullptr); + ASSERT_EQ( + descriptor_pool.FindMessageTypeByName("google.protobuf.DoubleValue"), + nullptr); + ASSERT_EQ(descriptor_pool.FindMessageTypeByName("google.protobuf.Duration"), + nullptr); + ASSERT_EQ(descriptor_pool.FindMessageTypeByName("google.protobuf.FloatValue"), + nullptr); + ASSERT_EQ(descriptor_pool.FindMessageTypeByName("google.protobuf.Int32Value"), + nullptr); + ASSERT_EQ(descriptor_pool.FindMessageTypeByName("google.protobuf.Int64Value"), + nullptr); + ASSERT_EQ(descriptor_pool.FindMessageTypeByName("google.protobuf.ListValue"), + nullptr); + ASSERT_EQ( + descriptor_pool.FindMessageTypeByName("google.protobuf.StringValue"), + nullptr); + ASSERT_EQ(descriptor_pool.FindMessageTypeByName("google.protobuf.Struct"), + nullptr); + ASSERT_EQ(descriptor_pool.FindMessageTypeByName("google.protobuf.Timestamp"), + nullptr); + ASSERT_EQ( + descriptor_pool.FindMessageTypeByName("google.protobuf.UInt32Value"), + nullptr); + ASSERT_EQ( + descriptor_pool.FindMessageTypeByName("google.protobuf.UInt64Value"), + nullptr); + ASSERT_EQ(descriptor_pool.FindMessageTypeByName("google.protobuf.Value"), + nullptr); + + ASSERT_OK(AddStandardMessageTypesToDescriptorPool(&descriptor_pool)); + + EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.Any"), + nullptr); + EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.BoolValue"), + nullptr); + EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.BytesValue"), + nullptr); + EXPECT_NE( + descriptor_pool.FindMessageTypeByName("google.protobuf.DoubleValue"), + nullptr); + EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.Duration"), + nullptr); + EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.FloatValue"), + nullptr); + EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.Int32Value"), + nullptr); + EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.Int64Value"), + nullptr); + EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.ListValue"), + nullptr); + EXPECT_NE( + descriptor_pool.FindMessageTypeByName("google.protobuf.StringValue"), + nullptr); + EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.Struct"), + nullptr); + EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.Timestamp"), + nullptr); + EXPECT_NE( + descriptor_pool.FindMessageTypeByName("google.protobuf.UInt32Value"), + nullptr); + EXPECT_NE( + descriptor_pool.FindMessageTypeByName("google.protobuf.UInt64Value"), + nullptr); + EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.Value"), + nullptr); +} + +TEST(DescriptorPoolUtilsTest, AcceptsPreAddedStandardTypes) { + google::protobuf::DescriptorPool descriptor_pool; + + for (auto proto_name : std::vector{ + "google.protobuf.Any", "google.protobuf.BoolValue", + "google.protobuf.BytesValue", "google.protobuf.DoubleValue", + "google.protobuf.Duration", "google.protobuf.FloatValue", + "google.protobuf.Int32Value", "google.protobuf.Int64Value", + "google.protobuf.ListValue", "google.protobuf.StringValue", + "google.protobuf.Struct", "google.protobuf.Timestamp", + "google.protobuf.UInt32Value", "google.protobuf.UInt64Value", + "google.protobuf.Value"}) { + const google::protobuf::Descriptor* descriptor = + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + proto_name); + ASSERT_NE(descriptor, nullptr); + google::protobuf::FileDescriptorProto file_descriptor_proto; + descriptor->file()->CopyTo(&file_descriptor_proto); + ASSERT_NE(descriptor_pool.BuildFile(file_descriptor_proto), nullptr); + } + + EXPECT_OK(AddStandardMessageTypesToDescriptorPool(&descriptor_pool)); +} + +TEST(DescriptorPoolUtilsTest, RejectsModifiedStandardType) { + google::protobuf::DescriptorPool descriptor_pool; + + const google::protobuf::Descriptor* descriptor = + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.protobuf.Duration"); + ASSERT_NE(descriptor, nullptr); + google::protobuf::FileDescriptorProto file_descriptor_proto; + descriptor->file()->CopyTo(&file_descriptor_proto); + // We emulate a modification by external code that replaced the nanos by a + // millis field. + google::protobuf::FieldDescriptorProto seconds_desc_proto; + google::protobuf::FieldDescriptorProto nanos_desc_proto; + descriptor->FindFieldByName("seconds")->CopyTo(&seconds_desc_proto); + descriptor->FindFieldByName("nanos")->CopyTo(&nanos_desc_proto); + nanos_desc_proto.set_name("millis"); + file_descriptor_proto.mutable_message_type(0)->clear_field(); + *file_descriptor_proto.mutable_message_type(0)->add_field() = + seconds_desc_proto; + *file_descriptor_proto.mutable_message_type(0)->add_field() = + nanos_desc_proto; + + descriptor_pool.BuildFile(file_descriptor_proto); + + EXPECT_THAT( + AddStandardMessageTypesToDescriptorPool(&descriptor_pool), + StatusIs(absl::StatusCode::kFailedPrecondition, HasSubstr("differs"))); +} + +} // namespace + +} // namespace google::api::expr::runtime diff --git a/eval/public/cel_expression.h b/eval/public/cel_expression.h index fc77425b2..04f9c98d7 100644 --- a/eval/public/cel_expression.h +++ b/eval/public/cel_expression.h @@ -80,6 +80,11 @@ class CelExpressionBuilder { type_registry_(absl::make_unique()), container_("") {} + explicit CelExpressionBuilder(const google::protobuf::DescriptorPool* descriptor_pool) + : func_registry_(absl::make_unique()), + type_registry_(absl::make_unique(descriptor_pool)), + container_("") {} + virtual ~CelExpressionBuilder() {} // Creates CelExpression object from AST tree. diff --git a/eval/public/cel_type_registry.cc b/eval/public/cel_type_registry.cc index 85c3bb755..085c1daba 100644 --- a/eval/public/cel_type_registry.cc +++ b/eval/public/cel_type_registry.cc @@ -44,7 +44,14 @@ const absl::flat_hash_set GetCoreEnums( } // namespace CelTypeRegistry::CelTypeRegistry() - : types_(GetCoreTypes()), enums_(GetCoreEnums()) {} + : descriptor_pool_(google::protobuf::DescriptorPool::generated_pool()), + types_(GetCoreTypes()), + enums_(GetCoreEnums()) {} + +CelTypeRegistry::CelTypeRegistry(const google::protobuf::DescriptorPool* descriptor_pool) + : descriptor_pool_(descriptor_pool), + types_(GetCoreTypes()), + enums_(GetCoreEnums()) {} void CelTypeRegistry::Register(std::string fully_qualified_type_name) { // Registers the fully qualified type name as a CEL type. @@ -58,7 +65,7 @@ void CelTypeRegistry::Register(const google::protobuf::EnumDescriptor* enum_desc const google::protobuf::Descriptor* CelTypeRegistry::FindDescriptor( absl::string_view fully_qualified_type_name) const { // Public protobuf interface only accepts const string&. - return google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + return descriptor_pool_->FindMessageTypeByName( std::string(fully_qualified_type_name)); } diff --git a/eval/public/cel_type_registry.h b/eval/public/cel_type_registry.h index 17c1382fb..f20eab8d2 100644 --- a/eval/public/cel_type_registry.h +++ b/eval/public/cel_type_registry.h @@ -18,12 +18,14 @@ namespace google::api::expr::runtime { // within the standard CelExpressionBuilder. // // By default, all core CEL types and all linked protobuf message types are -// implicitly registered by way of the generated descriptor pool. In the future, -// such type registrations may be explicit to avoid accidentally exposing linked -// protobuf types to CEL which were intended to remain internal. +// implicitly registered by way of the generated descriptor pool. A descriptor +// pool can be given to avoid accidentally exposing linked protobuf types to CEL +// which were intended to remain internal or to operate on hermetic descriptor +// pools. class CelTypeRegistry { public: CelTypeRegistry(); + explicit CelTypeRegistry(const google::protobuf::DescriptorPool* descriptor_pool); ~CelTypeRegistry() {} @@ -57,6 +59,7 @@ class CelTypeRegistry { } private: + const google::protobuf::DescriptorPool* descriptor_pool_; // externally owned // pointer-stability is required for the strings in the types set, which is // why a node_hash_set is used instead of another container type. absl::node_hash_set types_; diff --git a/eval/public/containers/field_access.cc b/eval/public/containers/field_access.cc index d61d1292c..b7dcc7ead 100644 --- a/eval/public/containers/field_access.cc +++ b/eval/public/containers/field_access.cc @@ -572,13 +572,11 @@ class FieldSetter { break; } case FieldDescriptor::CPPTYPE_MESSAGE: { - const absl::string_view type_name = - field_desc_->message_type()->full_name(); // When the field is a message, it might be a well-known type with a // non-proto representation that requires special handling before it // can be set on the field. - auto wrapped_value = - CelProtoWrapper::MaybeWrapValue(type_name, value, arena_); + auto wrapped_value = CelProtoWrapper::MaybeWrapValue( + field_desc_->message_type(), value, arena_); return AssignMessage(wrapped_value.value_or(value)); } case FieldDescriptor::CPPTYPE_ENUM: { diff --git a/eval/public/structs/cel_proto_wrapper.cc b/eval/public/structs/cel_proto_wrapper.cc index 7bbafd004..12b24e9c6 100644 --- a/eval/public/structs/cel_proto_wrapper.cc +++ b/eval/public/structs/cel_proto_wrapper.cc @@ -23,6 +23,7 @@ #include #include "google/protobuf/any.pb.h" +#include "google/protobuf/duration.pb.h" #include "google/protobuf/struct.pb.h" #include "google/protobuf/timestamp.pb.h" #include "google/protobuf/wrappers.pb.h" @@ -48,6 +49,7 @@ using google::protobuf::Arena; using google::protobuf::Descriptor; using google::protobuf::DescriptorPool; using google::protobuf::Message; +using google::protobuf::MessageFactory; using google::api::expr::internal::EncodeTime; using google::protobuf::Any; @@ -209,7 +211,9 @@ CelValue ValueFromMessage(const Struct* struct_value, Arena* arena) { Arena::Create(arena, struct_value, arena)); } -CelValue ValueFromMessage(const Any* any_value, Arena* arena) { +CelValue ValueFromMessage(const Any* any_value, Arena* arena, + const DescriptorPool* descriptor_pool, + MessageFactory* message_factory) { auto type_url = any_value->type_url(); auto pos = type_url.find_last_of('/'); if (pos == absl::string_view::npos) { @@ -220,7 +224,7 @@ CelValue ValueFromMessage(const Any* any_value, Arena* arena) { std::string full_name = std::string(type_url.substr(pos + 1)); const Descriptor* nested_descriptor = - DescriptorPool::generated_pool()->FindMessageTypeByName(full_name); + descriptor_pool->FindMessageTypeByName(full_name); if (nested_descriptor == nullptr) { // Descriptor not found for the type @@ -228,9 +232,7 @@ CelValue ValueFromMessage(const Any* any_value, Arena* arena) { return CreateErrorValue(arena, "Descriptor not found"); } - const Message* prototype = - google::protobuf::MessageFactory::generated_factory()->GetPrototype( - nested_descriptor); + const Message* prototype = message_factory->GetPrototype(nested_descriptor); if (prototype == nullptr) { // Failed to obtain prototype for the descriptor // TODO(issues/25) What error code? @@ -247,6 +249,11 @@ CelValue ValueFromMessage(const Any* any_value, Arena* arena) { return CelProtoWrapper::CreateMessage(nested_message, arena); } +CelValue ValueFromMessage(const Any* any_value, Arena* arena) { + return ValueFromMessage(any_value, arena, DescriptorPool::generated_pool(), + MessageFactory::generated_factory()); +} + CelValue ValueFromMessage(const BoolValue* wrapper, Arena*) { return CelValue::CreateBool(wrapper->value()); } @@ -314,80 +321,77 @@ class ValueFromMessageFactory { Arena* arena) const = 0; }; -// This template class has a good performance, but performes downcast -// operations on google::protobuf::Message pointers. -template -class CastingValueFromMessageFactory : public ValueFromMessageFactory { - public: - const google::protobuf::Descriptor* GetDescriptor() const override { - return MessageType::descriptor(); - } - - absl::optional CreateValue(const google::protobuf::Message* msg, - Arena* arena) const override { - if (MessageType::descriptor() == msg->GetDescriptor()) { - const MessageType* message = - google::protobuf::DynamicCastToGenerated(msg); - if (message == nullptr) { - auto message_copy = Arena::CreateMessage(arena); - message_copy->CopyFrom(*msg); - message = message_copy; - } - return ValueFromMessage(message, arena); - } - return absl::nullopt; - } -}; - // Class makes CelValue from generic protobuf Message. // It holds a registry of CelValue factories for specific subtypes of Message. // If message does not match any of types stored in registry, generic // message-containing CelValue is created. class ValueFromMessageMaker { public: - explicit ValueFromMessageMaker() { - Add(absl::make_unique>()); - Add(absl::make_unique>()); - Add(absl::make_unique>()); - Add(absl::make_unique>()); - Add(absl::make_unique>()); - - Add(absl::make_unique>()); - - Add(absl::make_unique>()); - Add(absl::make_unique>()); - Add(absl::make_unique>()); - Add(absl::make_unique>()); - Add(absl::make_unique>()); - Add(absl::make_unique>()); - Add(absl::make_unique>()); - Add(absl::make_unique>()); - Add(absl::make_unique>()); - } - - absl::optional CreateValue(const google::protobuf::Message* value, - Arena* arena) const { - auto it = factories_.find(value->GetDescriptor()); - if (it == factories_.end()) { - // Not found for value->GetDescriptor()->name() - return absl::nullopt; + template + static absl::optional CreateWellknownTypeValue( + const google::protobuf::Message* msg, Arena* arena) { + const MessageType* message = + google::protobuf::DynamicCastToGenerated(msg); + if (message == nullptr) { + auto message_copy = Arena::CreateMessage(arena); + if (MessageType::descriptor() == msg->GetDescriptor()) { + message_copy->CopyFrom(*msg); + message = message_copy; + } else { + // message of well-known type but from a descriptor pool other than the + // generated one. + std::string serialized_msg; + if (msg->SerializeToString(&serialized_msg) && + message_copy->ParseFromString(serialized_msg)) { + message = message_copy; + } + } + } + return ValueFromMessage(message, arena); + } + + static absl::optional CreateValue(const google::protobuf::Message* message, + Arena* arena) { + switch (message->GetDescriptor()->well_known_type()) { + case google::protobuf::Descriptor::WELLKNOWNTYPE_DOUBLEVALUE: + return CreateWellknownTypeValue(message, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_FLOATVALUE: + return CreateWellknownTypeValue(message, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_INT64VALUE: + return CreateWellknownTypeValue(message, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT64VALUE: + return CreateWellknownTypeValue(message, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_INT32VALUE: + return CreateWellknownTypeValue(message, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT32VALUE: + return CreateWellknownTypeValue(message, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_STRINGVALUE: + return CreateWellknownTypeValue(message, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_BYTESVALUE: + return CreateWellknownTypeValue(message, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_BOOLVALUE: + return CreateWellknownTypeValue(message, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_ANY: + return CreateWellknownTypeValue(message, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_DURATION: + return CreateWellknownTypeValue(message, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_TIMESTAMP: + return CreateWellknownTypeValue(message, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE: + return CreateWellknownTypeValue(message, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE: + return CreateWellknownTypeValue(message, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT: + return CreateWellknownTypeValue(message, arena); + // WELLKNOWNTYPE_FIELDMASK has no special CelValue type + default: + return absl::nullopt; } - return (it->second)->CreateValue(value, arena); } // Non-copyable, non-assignable ValueFromMessageMaker(const ValueFromMessageMaker&) = delete; ValueFromMessageMaker& operator=(const ValueFromMessageMaker&) = delete; - - private: - void Add(std::unique_ptr factory) { - const Descriptor* desc = factory->GetDescriptor(); - factories_.emplace(desc, std::move(factory)); - } - - absl::flat_hash_map> - factories_; }; absl::optional MessageFromValue(const CelValue& value, @@ -768,8 +772,8 @@ absl::optional MessageFromValue(const CelValue } } break; case CelValue::Type::kMessage: { - any->PackFrom(*(value.MessageOrDie())); - return any; + any->PackFrom(*(value.MessageOrDie())); + return any; } break; default: break; @@ -787,22 +791,27 @@ class MessageFromValueFactory { const CelValue& value, Arena* arena) const = 0; }; -// This template class has a good performance, but performes downcast -// operations on google::protobuf::Message pointers. -template -class CastingMessageFromValueFactory : public MessageFromValueFactory { +// MessageFromValueMaker makes a specific protobuf Message instance based on +// the desired protobuf type name and an input CelValue. +// +// It holds a registry of CelValue factories for specific subtypes of Message. +// If message does not match any of types stored in registry, an the factory +// returns an absent value. +class MessageFromValueMaker { public: - const google::protobuf::Descriptor* GetDescriptor() const override { - return MessageType::descriptor(); - } + // Non-copyable, non-assignable + MessageFromValueMaker(const MessageFromValueMaker&) = delete; + MessageFromValueMaker& operator=(const MessageFromValueMaker&) = delete; - absl::optional WrapMessage( - const CelValue& value, Arena* arena) const override { + template + static absl::optional WrapWellknownTypeMessage( + const CelValue& value, Arena* arena) { // If the value is a message type, see if it is already of the proper type // name, and return it directly. if (value.IsMessage()) { const auto* msg = value.MessageOrDie(); - if (MessageType::descriptor() == msg->GetDescriptor()) { + if (MessageType::descriptor()->well_known_type() == + msg->GetDescriptor()->well_known_type()) { return absl::nullopt; } } @@ -811,55 +820,46 @@ class CastingMessageFromValueFactory : public MessageFromValueFactory { auto* msg_buffer = Arena::CreateMessage(arena); return MessageFromValue(value, msg_buffer); } -}; -// MessageFromValueMaker makes a specific protobuf Message instance based on -// the desired protobuf type name and an input CelValue. -// -// It holds a registry of CelValue factories for specific subtypes of Message. -// If message does not match any of types stored in registry, an the factory -// returns an absent value. -class MessageFromValueMaker { - public: - explicit MessageFromValueMaker() { - Add(absl::make_unique>()); - Add(absl::make_unique>()); - Add(absl::make_unique>()); - Add(absl::make_unique>()); - Add(absl::make_unique>()); - Add(absl::make_unique>()); - Add(absl::make_unique>()); - Add(absl::make_unique>()); - Add(absl::make_unique>()); - Add(absl::make_unique>()); - Add(absl::make_unique>()); - Add(absl::make_unique>()); - Add(absl::make_unique>()); - Add(absl::make_unique>()); - Add(absl::make_unique>()); - } - // Non-copyable, non-assignable - MessageFromValueMaker(const MessageFromValueMaker&) = delete; - MessageFromValueMaker& operator=(const MessageFromValueMaker&) = delete; - - absl::optional MaybeWrapMessage( - absl::string_view type_name, const CelValue& value, Arena* arena) const { - auto it = factories_.find(type_name); - if (it == factories_.end()) { - // Descriptor not found for type name. - return absl::nullopt; + static absl::optional MaybeWrapMessage( + const google::protobuf::Descriptor* descriptor, const CelValue& value, + Arena* arena) { + switch (descriptor->well_known_type()) { + case google::protobuf::Descriptor::WELLKNOWNTYPE_DOUBLEVALUE: + return WrapWellknownTypeMessage(value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_FLOATVALUE: + return WrapWellknownTypeMessage(value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_INT64VALUE: + return WrapWellknownTypeMessage(value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT64VALUE: + return WrapWellknownTypeMessage(value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_INT32VALUE: + return WrapWellknownTypeMessage(value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT32VALUE: + return WrapWellknownTypeMessage(value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_STRINGVALUE: + return WrapWellknownTypeMessage(value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_BYTESVALUE: + return WrapWellknownTypeMessage(value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_BOOLVALUE: + return WrapWellknownTypeMessage(value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_ANY: + return WrapWellknownTypeMessage(value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_DURATION: + return WrapWellknownTypeMessage(value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_TIMESTAMP: + return WrapWellknownTypeMessage(value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE: + return WrapWellknownTypeMessage(value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE: + return WrapWellknownTypeMessage(value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT: + return WrapWellknownTypeMessage(value, arena); + // WELLKNOWNTYPE_FIELDMASK has no special CelValue type + default: + return absl::nullopt; } - return (it->second)->WrapMessage(value, arena); } - - private: - void Add(std::unique_ptr factory) { - const Descriptor* desc = factory->GetDescriptor(); - factories_.emplace(desc->full_name(), std::move(factory)); - } - - absl::flat_hash_map> - factories_; }; } // namespace @@ -869,23 +869,22 @@ class MessageFromValueMaker { // this method contains type checking and downcasts. CelValue CelProtoWrapper::CreateMessage(const google::protobuf::Message* value, Arena* arena) { - static const ValueFromMessageMaker* maker = new ValueFromMessageMaker(); - // Messages are Nullable types if (value == nullptr) { return CelValue::CreateNull(); } - auto special_value = maker->CreateValue(value, arena); + absl::optional special_value; + + special_value = ValueFromMessageMaker::CreateValue(value, arena); return special_value.has_value() ? special_value.value() : CelValue::CreateMessage(value); } absl::optional CelProtoWrapper::MaybeWrapValue( - absl::string_view type_name, const CelValue& value, Arena* arena) { - static const MessageFromValueMaker* maker = new MessageFromValueMaker(); - - auto msg = maker->MaybeWrapMessage(type_name, value, arena); + const google::protobuf::Descriptor* descriptor, const CelValue& value, Arena* arena) { + absl::optional msg = + MessageFromValueMaker::MaybeWrapMessage(descriptor, value, arena); if (!msg.has_value()) { return absl::nullopt; } diff --git a/eval/public/structs/cel_proto_wrapper.h b/eval/public/structs/cel_proto_wrapper.h index e979e8c7f..633be5f28 100644 --- a/eval/public/structs/cel_proto_wrapper.h +++ b/eval/public/structs/cel_proto_wrapper.h @@ -3,6 +3,7 @@ #include "google/protobuf/duration.pb.h" #include "google/protobuf/timestamp.pb.h" +#include "google/protobuf/descriptor.h" #include "eval/public/cel_value.h" #include "internal/proto_util.h" @@ -35,9 +36,9 @@ class CelProtoWrapper { // message to native CelValue representation during a protobuf field read. // Just as CreateMessage should only be used when reading protobuf values, // MaybeWrapValue should only be used when assigning protobuf fields. - static absl::optional MaybeWrapValue(absl::string_view type_name, - const CelValue& value, - google::protobuf::Arena* arena); + static absl::optional MaybeWrapValue( + const google::protobuf::Descriptor* descriptor, const CelValue& value, + google::protobuf::Arena* arena); }; } // namespace google::api::expr::runtime diff --git a/eval/public/structs/cel_proto_wrapper_test.cc b/eval/public/structs/cel_proto_wrapper_test.cc index ab427a7d4..296c32949 100644 --- a/eval/public/structs/cel_proto_wrapper_test.cc +++ b/eval/public/structs/cel_proto_wrapper_test.cc @@ -57,21 +57,21 @@ class CelProtoWrapperTest : public ::testing::Test { void ExpectWrappedMessage(const CelValue& value, const google::protobuf::Message& message) { // Test the input value wraps to the destination message type. - std::string type_name = message.GetTypeName(); - auto result = CelProtoWrapper::MaybeWrapValue(type_name, value, arena()); + auto result = CelProtoWrapper::MaybeWrapValue(message.GetDescriptor(), + value, arena()); EXPECT_TRUE(result.has_value()); EXPECT_TRUE((*result).IsMessage()); EXPECT_THAT((*result).MessageOrDie(), testutil::EqualsProto(message)); // Ensure that double wrapping results in the object being wrapped once. - auto identity = - CelProtoWrapper::MaybeWrapValue(type_name, *result, arena()); + auto identity = CelProtoWrapper::MaybeWrapValue(message.GetDescriptor(), + *result, arena()); EXPECT_FALSE(identity.has_value()); // Check to make sure that even dynamic messages can be used as input to // the wrapping call. result = CelProtoWrapper::MaybeWrapValue( - ReflectedCopy(message)->GetTypeName(), value, arena()); + ReflectedCopy(message)->GetDescriptor(), value, arena()); EXPECT_TRUE(result.has_value()); EXPECT_TRUE((*result).IsMessage()); EXPECT_THAT((*result).MessageOrDie(), testutil::EqualsProto(message)); @@ -79,8 +79,8 @@ class CelProtoWrapperTest : public ::testing::Test { void ExpectNotWrapped(const CelValue& value, const google::protobuf::Message& message) { // Test the input value does not wrap by asserting value == result. - auto result = - CelProtoWrapper::MaybeWrapValue(message.GetTypeName(), value, arena()); + auto result = CelProtoWrapper::MaybeWrapValue(message.GetDescriptor(), + value, arena()); EXPECT_FALSE(result.has_value()); } diff --git a/eval/testutil/BUILD b/eval/testutil/BUILD index 268e225b1..420f29f0c 100644 --- a/eval/testutil/BUILD +++ b/eval/testutil/BUILD @@ -21,3 +21,10 @@ cc_proto_library( name = "test_message_cc_proto", deps = [":test_message_protos"], ) + +proto_library( + name = "simple_test_message_proto", + srcs = [ + "simple_test_message.proto", + ], +) diff --git a/eval/testutil/simple_test_message.proto b/eval/testutil/simple_test_message.proto new file mode 100644 index 000000000..27a822fbb --- /dev/null +++ b/eval/testutil/simple_test_message.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; + +package google.api.expr.runtime; + +// This has no dependencies on any other messages to keep the file descriptor +// set needed to parse this message simple. +message SimpleTestMessage { + int64 int64_value = 1; +} diff --git a/tools/BUILD b/tools/BUILD index 1daaf8756..1146add08 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -2,42 +2,6 @@ package(default_visibility = ["//visibility:public"]) licenses(["notice"]) -cc_library( - name = "cel_ast_renumber", - srcs = ["cel_ast_renumber.cc"], - hdrs = ["cel_ast_renumber.h"], - deps = [ - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", - ], -) - -cc_library( - name = "reference_inliner", - srcs = [ - "reference_inliner.cc", - ], - hdrs = [ - "reference_inliner.h", - ], - deps = [ - ":cel_ast_renumber", - "//eval/public:ast_rewrite", - "//eval/public:ast_traverse", - "//eval/public:ast_visitor_base", - "//eval/public:source_position", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", - "@com_googlesource_code_re2//:re2", - ], -) - cc_library( name = "flatbuffers_backed_impl", srcs = [ diff --git a/tools/cel_ast_renumber.cc b/tools/cel_ast_renumber.cc deleted file mode 100644 index 80aa51cb7..000000000 --- a/tools/cel_ast_renumber.cc +++ /dev/null @@ -1,152 +0,0 @@ -// Copyright 2021 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tools/cel_ast_renumber.h" - -#include "google/api/expr/v1alpha1/checked.pb.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" -#include "absl/container/flat_hash_map.h" - -namespace cel::ast { -namespace { - -using ::google::api::expr::v1alpha1::CheckedExpr; -using ::google::api::expr::v1alpha1::Expr; - -// Renumbers expression IDs in a CheckedExpr. -// Note: does not renumber within macro_calls values. -class Renumberer { - public: - explicit Renumberer(int64_t next_id) : next_id_(next_id) {} - - // Returns the next free expression ID after renumbering. - int64_t Renumber(CheckedExpr* cexpr) { - old_to_new_.clear(); - Visit(cexpr->mutable_expr()); - CheckedExpr c2; // scratch proto tables of the right type - - for (auto it = cexpr->type_map().begin(); it != cexpr->type_map().end(); - it++) { - (*c2.mutable_type_map())[old_to_new_[it->first]] = it->second; - } - std::swap(*cexpr->mutable_type_map(), *c2.mutable_type_map()); - c2.mutable_type_map()->clear(); - - for (auto it = cexpr->reference_map().begin(); - it != cexpr->reference_map().end(); it++) { - (*c2.mutable_reference_map())[old_to_new_[it->first]] = it->second; - } - std::swap(*cexpr->mutable_reference_map(), *c2.mutable_reference_map()); - c2.mutable_reference_map()->clear(); - - if (cexpr->has_source_info()) { - auto* source_info = cexpr->mutable_source_info(); - auto* s2 = c2.mutable_source_info(); - - for (auto it = source_info->positions().begin(); - it != source_info->positions().end(); it++) { - (*s2->mutable_positions())[old_to_new_[it->first]] = it->second; - } - std::swap(*source_info->mutable_positions(), *s2->mutable_positions()); - s2->mutable_positions()->clear(); - - for (auto it = source_info->macro_calls().begin(); - it != source_info->macro_calls().end(); it++) { - (*s2->mutable_macro_calls())[old_to_new_[it->first]] = it->second; - } - std::swap(*source_info->mutable_macro_calls(), - *s2->mutable_macro_calls()); - s2->mutable_macro_calls()->clear(); - } - - return next_id_; - } - - private: - // Insert mapping from old_id to the current next new_id. - // Return next new_id. - int64_t Renumber(int64_t old_id) { - int64_t new_id = next_id_; - ++next_id_; - old_to_new_[old_id] = new_id; - return new_id; - } - - // Renumber this Expr and all sub-exprs and map entries. - void Visit(Expr* e) { - if (!e) { - return; - } - switch (e->expr_kind_case()) { - case Expr::kSelectExpr: - Visit(e->mutable_select_expr()->mutable_operand()); - break; - case Expr::kCallExpr: { - auto call_expr = e->mutable_call_expr(); - if (call_expr->has_target()) { - Visit(call_expr->mutable_target()); - } - for (int i = 0; i < call_expr->args_size(); i++) { - Visit(call_expr->mutable_args(i)); - } - } break; - case Expr::kListExpr: { - auto list_expr = e->mutable_list_expr(); - for (int i = 0; i < list_expr->elements_size(); i++) { - Visit(list_expr->mutable_elements(i)); - } - } break; - case Expr::kStructExpr: { - auto struct_expr = e->mutable_struct_expr(); - for (int i = 0; i < struct_expr->entries_size(); i++) { - auto entry = struct_expr->mutable_entries(i); - if (entry->has_map_key()) { - Visit(entry->mutable_map_key()); - } - Visit(entry->mutable_value()); - entry->set_id(Renumber(entry->id())); - } - } break; - case Expr::kComprehensionExpr: { - auto comp_expr = e->mutable_comprehension_expr(); - Visit(comp_expr->mutable_iter_range()); - Visit(comp_expr->mutable_accu_init()); - Visit(comp_expr->mutable_loop_condition()); - Visit(comp_expr->mutable_loop_step()); - Visit(comp_expr->mutable_result()); - } break; - default: - // no other types have sub-expressions - break; - } - e->set_id(Renumber(e->id())); // do this last to mimic bottom-up build - } - - int64_t next_id_; // saved between Renumber() calls - absl::flat_hash_map - old_to_new_; // cleared between Renumber() calls -}; - -} // namespace - -// Renumbers expression IDs in a CheckedExpr in-place. -// This is intended to be used for injecting multiple sub-expressions into -// a merged expression. -// Note: does not renumber within macro_calls values. -// Returns the next free ID. -int64_t Renumber(int64_t starting_id, CheckedExpr* expr) { - return Renumberer(starting_id).Renumber(expr); -} - -} // namespace cel::ast diff --git a/tools/cel_ast_renumber.h b/tools/cel_ast_renumber.h deleted file mode 100644 index 5dad9d4b9..000000000 --- a/tools/cel_ast_renumber.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2021 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_CEL_CPP_TOOLS_CEL_AST_RENUMBER_H_ -#define THIRD_PARTY_CEL_CPP_TOOLS_CEL_AST_RENUMBER_H_ - -#include - -#include "google/api/expr/v1alpha1/checked.pb.h" - -namespace cel::ast { - -// Renumbers expression IDs in a CheckedExpr in-place. -// This is intended to be used for injecting multiple sub-expressions into -// a merged expression. -// TODO(issues/139): this does not renumber within macro_calls values. -// Returns the next free ID. -int64_t Renumber(int64_t starting_id, google::api::expr::v1alpha1::CheckedExpr* expr); - -} // namespace cel::ast - -#endif // THIRD_PARTY_CEL_CPP_TOOLS_CEL_AST_RENUMBER_H_ diff --git a/tools/reference_inliner.cc b/tools/reference_inliner.cc deleted file mode 100644 index 8fdacba2c..000000000 --- a/tools/reference_inliner.cc +++ /dev/null @@ -1,202 +0,0 @@ -// Copyright 2021 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tools/reference_inliner.h" - -#include -#include -#include - -#include "google/api/expr/v1alpha1/checked.pb.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/status/status.h" -#include "absl/strings/str_join.h" -#include "absl/strings/string_view.h" -#include "eval/public/ast_rewrite.h" -#include "eval/public/ast_traverse.h" -#include "eval/public/ast_visitor_base.h" -#include "eval/public/source_position.h" -#include "tools/cel_ast_renumber.h" -#include "re2/re2.h" -#include "re2/regexp.h" - -namespace cel::ast { -namespace { - -using ::google::api::expr::v1alpha1::CheckedExpr; -using ::google::api::expr::v1alpha1::Expr; -using ::google::api::expr::runtime::AstRewrite; -using ::google::api::expr::runtime::AstRewriterBase; -using ::google::api::expr::runtime::AstTraverse; -using ::google::api::expr::runtime::AstVisitorBase; -using ::google::api::expr::runtime::SourcePosition; - -// Filter for legal select paths. -static LazyRE2 kIdentRegex = { - R"(([_a-zA-Z][_a-zA-Z0-9]*)(\.[_a-zA-Z][_a-zA-Z0-9]*)*)"}; - -using IdentExpr = google::api::expr::v1alpha1::Expr::Ident; -using RewriteRuleMap = - absl::flat_hash_map; - -void MergeMetadata(const CheckedExpr& to_insert, CheckedExpr* base) { - base->mutable_reference_map()->insert(to_insert.reference_map().begin(), - to_insert.reference_map().end()); - base->mutable_type_map()->insert(to_insert.type_map().begin(), - to_insert.type_map().end()); - auto* source_info = base->mutable_source_info(); - source_info->mutable_positions()->insert( - to_insert.source_info().positions().begin(), - to_insert.source_info().positions().end()); - - source_info->mutable_macro_calls()->insert( - to_insert.source_info().macro_calls().begin(), - to_insert.source_info().macro_calls().end()); -} - -void PruneMetadata(const std::vector& ids, CheckedExpr* base) { - auto* source_info = base->mutable_source_info(); - for (int64_t i : ids) { - base->mutable_reference_map()->erase(i); - base->mutable_type_map()->erase(i); - source_info->mutable_positions()->erase(i); - source_info->mutable_macro_calls()->erase(i); - } -} - -class InlinerRewrite : public AstRewriterBase { - public: - InlinerRewrite(const RewriteRuleMap& rewrite_rules, CheckedExpr* base, - int64_t next_id) - : base_(base), rewrite_rules_(rewrite_rules), next_id_(next_id) {} - void PostVisitIdent(const IdentExpr* ident, const Expr* expr, - const SourcePosition* source_pos) override { - // e.g. `com.google.Identifier` would have a path of - // SelectExpr("Identifier"), SelectExpr("google"), IdentExpr("com") - std::vector qualifiers{ident->name()}; - for (int i = path_.size() - 2; i >= 0; i--) { - if (!path_[i]->has_select_expr() || path_[i]->select_expr().test_only()) { - break; - } - qualifiers.push_back(path_[i]->select_expr().field()); - } - - // Check longest possible match first then less specific qualifiers. - for (int path_len = qualifiers.size(); path_len >= 1; path_len--) { - int path_len_offset = qualifiers.size() - path_len; - std::string candidate = absl::StrJoin( - qualifiers.begin(), qualifiers.end() - path_len_offset, "."); - auto rule_it = rewrite_rules_.find(candidate); - if (rule_it != rewrite_rules_.end()) { - std::vector invalidated_ids; - invalidated_ids.reserve(path_len); - for (int offset = 0; offset < path_len; offset++) { - invalidated_ids.push_back(path_[path_.size() - (1 + offset)]->id()); - } - - // The target the root node of the reference subtree to get updated. - int64_t root_id = path_[path_.size() - path_len]->id(); - rewrite_positions_[root_id] = - Rewrite{std::move(invalidated_ids), rule_it->second}; - // Any other rewrites are redundant. - break; - } - } - } - - bool PostVisitRewrite(Expr* expr, const SourcePosition* source_pos) override { - auto it = rewrite_positions_.find(expr->id()); - if (it == rewrite_positions_.end()) { - return false; - } - const Rewrite& rewrite = (it->second); - CheckedExpr new_sub_expr = *rewrite.rewrite; - next_id_ = Renumber(next_id_, &new_sub_expr); - MergeMetadata(new_sub_expr, base_); - expr->Swap(new_sub_expr.mutable_expr()); - PruneMetadata(rewrite.invalidated_ids, base_); - return true; - } - - void TraversalStackUpdate(absl::Span path) override { - path_ = path; - } - - private: - struct Rewrite { - std::vector invalidated_ids; - const CheckedExpr* rewrite; - }; - absl::Span path_; - absl::flat_hash_map rewrite_positions_; - CheckedExpr* base_; - const RewriteRuleMap& rewrite_rules_; - int next_id_; -}; - -// Validate visitor is used to check that an AST is safe for the inlining -// utility -- hand-rolled ASTs may not have a legal numbering for the nodes in -// the tree and metadata maps (i.e. a unique id for each node). -// CheckedExprs generated from a type checker should always be safe. -class ValidateVisitor : public AstVisitorBase { - public: - ValidateVisitor() : max_id_(0), is_valid_(true) {} - void PostVisitExpr(const Expr* expr, const SourcePosition* pos) override { - auto [it, inserted] = visited_.insert(expr->id()); - if (!inserted) { - is_valid_ = false; - } - if (expr->id() > max_id_) { - max_id_ = expr->id(); - } - } - bool IdsValid() { return is_valid_; } - int64_t GetMaxId() { return max_id_; } - - private: - int64_t max_id_; - absl::flat_hash_set visited_; - bool is_valid_; -}; - -} // namespace - -absl::Status Inliner::SetRewriteRule(absl::string_view qualified_identifier, - const CheckedExpr& expr) { - if (!RE2::FullMatch(re2::StringPiece(qualified_identifier.data(), qualified_identifier.size()), *kIdentRegex)) { - return absl::InvalidArgumentError( - absl::StrCat("Unsupported identifier for CheckedExpr rewrite rule: ", - qualified_identifier)); - } - rewrites_.insert_or_assign(qualified_identifier, &expr); - return absl::OkStatus(); -} - -absl::StatusOr Inliner::Inline(const CheckedExpr& expr) const { - // Determine if the source expr has a legal numbering and pick out the next - // available id. - ValidateVisitor validator; - AstTraverse(&expr.expr(), &expr.source_info(), &validator); - if (!validator.IdsValid()) { - return absl::InvalidArgumentError("Invalid Expr IDs"); - } - CheckedExpr output = expr; - InlinerRewrite rewrite_visitor(rewrites_, &output, validator.GetMaxId() + 1); - AstRewrite(output.mutable_expr(), &output.source_info(), &rewrite_visitor); - return output; -} - -} // namespace cel::ast diff --git a/tools/reference_inliner.h b/tools/reference_inliner.h deleted file mode 100644 index 010f74d41..000000000 --- a/tools/reference_inliner.h +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2021 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_CEL_CPP_TOOLS_REFERENCE_INLINER_H_ -#define THIRD_PARTY_CEL_CPP_TOOLS_REFERENCE_INLINER_H_ - -#include - -#include "google/api/expr/v1alpha1/checked.pb.h" -#include "absl/container/flat_hash_map.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" - -namespace cel::ast { - -class Inliner { - public: - Inliner() {} - explicit Inliner(absl::flat_hash_map - rewrites) - : rewrites_(std::move(rewrites)) {} - - // Add a qualified ident to replace with a checked expression. - // The supplied CheckedExpr must outlive the Inliner. - // Replaces any existing rewrite rules for the given identifier -- the last - // call will always overwrite any prior calls for a given identifier. - absl::Status SetRewriteRule(absl::string_view qualified_identifier, - const google::api::expr::v1alpha1::CheckedExpr& expr); - - // Apply all of the rewrites to expr. - // Returns an error if expr is not valid (i.e. unsupported expr ids). - absl::StatusOr Inline( - const google::api::expr::v1alpha1::CheckedExpr& expr) const; - - private: - absl::flat_hash_map - rewrites_; -}; - -} // namespace cel::ast -#endif // THIRD_PARTY_CEL_CPP_TOOLS_REFERENCE_INLINER_H_ From be6479adc1d33a4af57700effc6f6eb37f5142f8 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 10 Feb 2022 20:28:30 +0000 Subject: [PATCH 002/111] Allow registration of non-strict functions Allow users to register non-strict functions, which are functions that could take `CelError` or `UnknownSet` as arguments. - Add a field `is_strict` to `CelFunctionDescriptor`. - Change `CelFunctionRegistry` so that it ensure if the function has any non-strict overload, it has only one overload. - Modify `IsNonStrict` in `function_step.cc` to check `is_strict` field. PiperOrigin-RevId: 427817399 --- eval/eval/BUILD | 1 + eval/eval/function_step.cc | 15 +-- eval/eval/function_step_test.cc | 66 +++++++++++- eval/public/BUILD | 1 + eval/public/cel_function.h | 12 ++- eval/public/cel_function_registry.cc | 31 ++++++ eval/public/cel_function_registry.h | 4 + eval/public/cel_function_registry_test.cc | 124 ++++++++++++++++++++++ 8 files changed, 241 insertions(+), 13 deletions(-) diff --git a/eval/eval/BUILD b/eval/eval/BUILD index 45e55015d..6368811fa 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -403,6 +403,7 @@ cc_test( "//eval/public:cel_value", "//eval/public:unknown_function_result_set", "//eval/public/structs:cel_proto_wrapper", + "//eval/public/testing:matchers", "//eval/testutil:test_message_cc_proto", "//internal:status_macros", "//internal:testing", diff --git a/eval/eval/function_step.cc b/eval/eval/function_step.cc index 620129dd9..cf2322598 100644 --- a/eval/eval/function_step.cc +++ b/eval/eval/function_step.cc @@ -31,11 +31,14 @@ namespace google::api::expr::runtime { namespace { -// Non-strict functions are allowed to consume errors and UnknownSets. Currently -// only the special function "@not_strictly_false" is allowed to do this. -bool IsNonStrict(const std::string& name) { - return (name == builtin::kNotStrictlyFalse || - name == builtin::kNotStrictlyFalseDeprecated); +// Only non-strict functions are allowed to consume errors and unknown sets. +bool IsNonStrict(const CelFunction& function) { + const CelFunctionDescriptor& descriptor = function.descriptor(); + // Special case: built-in function "@not_strictly_false" is treated as + // non-strict. + return !descriptor.is_strict() || + descriptor.name() == builtin::kNotStrictlyFalse || + descriptor.name() == builtin::kNotStrictlyFalseDeprecated; } // Determine if the overload should be considered. Overloads that can consume @@ -47,7 +50,7 @@ bool ShouldAcceptOverload(const CelFunction* function, } for (size_t i = 0; i < arguments.size(); i++) { if (arguments[i].IsUnknownSet() || arguments[i].IsError()) { - return IsNonStrict(function->descriptor().name()); + return IsNonStrict(*function); } } return true; diff --git a/eval/eval/function_step_test.cc b/eval/eval/function_step_test.cc index 7513d5c0d..d64020434 100644 --- a/eval/eval/function_step_test.cc +++ b/eval/eval/function_step_test.cc @@ -19,6 +19,7 @@ #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/structs/cel_proto_wrapper.h" +#include "eval/public/testing/matchers.h" #include "eval/public/unknown_function_result_set.h" #include "eval/testutil/test_message.pb.h" #include "internal/status_macros.h" @@ -123,11 +124,12 @@ class AddFunction : public CelFunction { class SinkFunction : public CelFunction { public: - explicit SinkFunction(CelValue::Type type) - : CelFunction(CreateDescriptor(type)) {} + explicit SinkFunction(CelValue::Type type, bool is_strict = true) + : CelFunction(CreateDescriptor(type, is_strict)) {} - static CelFunctionDescriptor CreateDescriptor(CelValue::Type type) { - return CelFunctionDescriptor{"Sink", false, {type}}; + static CelFunctionDescriptor CreateDescriptor(CelValue::Type type, + bool is_strict = true) { + return CelFunctionDescriptor{"Sink", false, {type}, is_strict}; } static Expr::Call MakeCall() { @@ -964,6 +966,60 @@ TEST_F(FunctionStepNullCoercionTest, Disabled) { ASSERT_TRUE(value.IsError()); } -} // namespace +TEST(FunctionStepStrictnessTest, + IfFunctionStrictAndGivenUnknownSkipsInvocation) { + UnknownSet unknown_set; + CelFunctionRegistry registry; + ASSERT_OK(registry.Register(absl::make_unique( + CelValue::CreateUnknownSet(&unknown_set), "ConstUnknown"))); + ASSERT_OK(registry.Register(std::make_unique( + CelValue::Type::kUnknownSet, /*is_strict=*/true))); + ExecutionPath path; + Expr::Call call0 = ConstFunction::MakeCall("ConstUnknown"); + Expr::Call call1 = SinkFunction::MakeCall(); + ASSERT_OK_AND_ASSIGN(std::unique_ptr step0, + MakeTestFunctionStep(&call0, registry)); + ASSERT_OK_AND_ASSIGN(std::unique_ptr step1, + MakeTestFunctionStep(&call1, registry)); + path.push_back(std::move(step0)); + path.push_back(std::move(step1)); + Expr placeholder_expr; + CelExpressionFlatImpl impl(&placeholder_expr, std::move(path), + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, {}, + true, true); + Activation activation; + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(CelValue value, impl.Evaluate(activation, &arena)); + ASSERT_TRUE(value.IsUnknownSet()); +} +TEST(FunctionStepStrictnessTest, IfFunctionNonStrictAndGivenUnknownInvokesIt) { + UnknownSet unknown_set; + CelFunctionRegistry registry; + ASSERT_OK(registry.Register(absl::make_unique( + CelValue::CreateUnknownSet(&unknown_set), "ConstUnknown"))); + ASSERT_OK(registry.Register(std::make_unique( + CelValue::Type::kUnknownSet, /*is_strict=*/false))); + ExecutionPath path; + Expr::Call call0 = ConstFunction::MakeCall("ConstUnknown"); + Expr::Call call1 = SinkFunction::MakeCall(); + ASSERT_OK_AND_ASSIGN(std::unique_ptr step0, + MakeTestFunctionStep(&call0, registry)); + ASSERT_OK_AND_ASSIGN(std::unique_ptr step1, + MakeTestFunctionStep(&call1, registry)); + path.push_back(std::move(step0)); + path.push_back(std::move(step1)); + Expr placeholder_expr; + CelExpressionFlatImpl impl(&placeholder_expr, std::move(path), + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory(), 0, {}, + true, true); + Activation activation; + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(CelValue value, impl.Evaluate(activation, &arena)); + ASSERT_THAT(value, test::IsCelInt64(Eq(0))); +} + +} // namespace } // namespace google::api::expr::runtime diff --git a/eval/public/BUILD b/eval/public/BUILD index a498ee9c1..35283c8f6 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -589,6 +589,7 @@ cc_test( ":cel_function_registry", "//internal:status_macros", "//internal:testing", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", ], ) diff --git a/eval/public/cel_function.h b/eval/public/cel_function.h index 28250b561..d60a107e3 100644 --- a/eval/public/cel_function.h +++ b/eval/public/cel_function.h @@ -17,10 +17,12 @@ namespace google::api::expr::runtime { class CelFunctionDescriptor { public: CelFunctionDescriptor(absl::string_view name, bool receiver_style, - std::vector types) + std::vector types, + bool is_strict = true) : name_(name), receiver_style_(receiver_style), - types_(std::move(types)) {} + types_(std::move(types)), + is_strict_(is_strict) {} // Function name. const std::string& name() const { return name_; } @@ -31,6 +33,11 @@ class CelFunctionDescriptor { // The argmument types the function accepts. const std::vector& types() const { return types_; } + // if true (strict, default), error or unknown arguments are propagated + // instead of calling the function. if false (non-strict), the function may + // receive error or unknown values as arguments. + bool is_strict() const { return is_strict_; } + // Helper for matching a descriptor. This tests that the shape is the same -- // |other| accepts the same number and types of arguments and is the same call // style). @@ -44,6 +51,7 @@ class CelFunctionDescriptor { std::string name_; bool receiver_style_; std::vector types_; + bool is_strict_; }; // CelFunction is a handler that represents single diff --git a/eval/public/cel_function_registry.cc b/eval/public/cel_function_registry.cc index 6834d6e37..35735d86d 100644 --- a/eval/public/cel_function_registry.cc +++ b/eval/public/cel_function_registry.cc @@ -14,6 +14,10 @@ absl::Status CelFunctionRegistry::Register( absl::StatusCode::kAlreadyExists, "CelFunction with specified parameters already registered"); } + if (!ValidateNonStrictOverload(descriptor)) { + return absl::Status(absl::StatusCode::kAlreadyExists, + "Only one overload is allowed for non-strict function"); + } auto& overloads = functions_[descriptor.name()]; overloads.static_overloads.push_back(std::move(function)); @@ -28,6 +32,10 @@ absl::Status CelFunctionRegistry::RegisterLazyFunction( absl::StatusCode::kAlreadyExists, "CelFunction with specified parameters already registered"); } + if (!ValidateNonStrictOverload(descriptor)) { + return absl::Status(absl::StatusCode::kAlreadyExists, + "Only one overload is allowed for non-strict function"); + } auto& overloads = functions_[descriptor.name()]; LazyFunctionEntry entry = std::make_unique( descriptor, std::move(factory)); @@ -106,4 +114,27 @@ bool CelFunctionRegistry::DescriptorRegistered( .empty()); } +bool CelFunctionRegistry::ValidateNonStrictOverload( + const CelFunctionDescriptor& descriptor) const { + auto overloads = functions_.find(descriptor.name()); + if (overloads == functions_.end()) { + return true; + } + const RegistryEntry& entry = overloads->second; + if (!descriptor.is_strict()) { + // If the newly added overload is a non-strict function, we require that + // there are no other overloads, which is not possible here. + return false; + } + // If the newly added overload is a strict function, we need to make sure + // that no previous overloads are registered non-strict. If the list of + // overload is not empty, we only need to check the first overload. This is + // because if the first overload is strict, other overloads must also be + // strict by the rule. + return (entry.static_overloads.empty() || + entry.static_overloads[0]->descriptor().is_strict()) && + (entry.lazy_overloads.empty() || + entry.lazy_overloads[0]->first.is_strict()); +} + } // namespace google::api::expr::runtime diff --git a/eval/public/cel_function_registry.h b/eval/public/cel_function_registry.h index 79fbbb4d1..f4445609d 100644 --- a/eval/public/cel_function_registry.h +++ b/eval/public/cel_function_registry.h @@ -69,6 +69,10 @@ class CelFunctionRegistry { // Returns whether the descriptor is registered in either as a lazy funtion or // in the static functions. bool DescriptorRegistered(const CelFunctionDescriptor& descriptor) const; + // Returns true if after adding this function, the rule "a non-strict + // function should have only a single overload" will be preserved. + bool ValidateNonStrictOverload(const CelFunctionDescriptor& descriptor) const; + using StaticFunctionEntry = std::unique_ptr; using LazyFunctionEntry = std::unique_ptr< std::pair>>; diff --git a/eval/public/cel_function_registry_test.cc b/eval/public/cel_function_registry_test.cc index 66bd8218e..4f03c9983 100644 --- a/eval/public/cel_function_registry_test.cc +++ b/eval/public/cel_function_registry_test.cc @@ -2,6 +2,7 @@ #include +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "eval/public/activation.h" #include "eval/public/cel_function.h" @@ -14,8 +15,10 @@ namespace google::api::expr::runtime { namespace { using testing::Eq; +using testing::HasSubstr; using testing::Property; using testing::SizeIs; +using cel::internal::StatusIs; class NullLazyFunctionProvider : public virtual CelFunctionProvider { public: @@ -105,6 +108,127 @@ TEST(CelFunctionRegistryTest, DefaultLazyProvider) { Eq("LazyFunction")))); } +TEST(CelFunctionRegistryTest, CanRegisterNonStrictFunction) { + { + CelFunctionRegistry registry; + CelFunctionDescriptor descriptor("NonStrictFunction", + /*receiver_style=*/false, + {CelValue::Type::kAny}, + /*is_strict=*/false); + ASSERT_OK( + registry.Register(std::make_unique(descriptor))); + EXPECT_THAT(registry.FindOverloads("NonStrictFunction", false, + {CelValue::Type::kAny}), + SizeIs(1)); + } + { + CelFunctionRegistry registry; + CelFunctionDescriptor descriptor("NonStrictLazyFunction", + /*receiver_style=*/false, + {CelValue::Type::kAny}, + /*is_strict=*/false); + EXPECT_OK(registry.RegisterLazyFunction(descriptor)); + EXPECT_THAT(registry.FindLazyOverloads("NonStrictLazyFunction", false, + {CelValue::Type::kAny}), + SizeIs(1)); + } +} + +using NonStrictTestCase = std::tuple; +using NonStrictRegistrationFailTest = testing::TestWithParam; + +TEST_P(NonStrictRegistrationFailTest, + IfOtherOverloadExistsRegisteringNonStrictFails) { + bool existing_function_is_lazy, new_function_is_lazy; + std::tie(existing_function_is_lazy, new_function_is_lazy) = GetParam(); + CelFunctionRegistry registry; + CelFunctionDescriptor descriptor("OverloadedFunction", + /*receiver_style=*/false, + {CelValue::Type::kAny}, + /*is_strict=*/true); + if (existing_function_is_lazy) { + ASSERT_OK(registry.RegisterLazyFunction(descriptor)); + } else { + ASSERT_OK( + registry.Register(std::make_unique(descriptor))); + } + CelFunctionDescriptor new_descriptor( + "OverloadedFunction", + /*receiver_style=*/false, {CelValue::Type::kAny, CelValue::Type::kAny}, + /*is_strict=*/false); + absl::Status status; + if (new_function_is_lazy) { + status = registry.RegisterLazyFunction(new_descriptor); + } else { + status = + registry.Register(std::make_unique(new_descriptor)); + } + EXPECT_THAT(status, StatusIs(absl::StatusCode::kAlreadyExists, + HasSubstr("Only one overload"))); +} + +TEST_P(NonStrictRegistrationFailTest, + IfOtherNonStrictExistsRegisteringStrictFails) { + bool existing_function_is_lazy, new_function_is_lazy; + std::tie(existing_function_is_lazy, new_function_is_lazy) = GetParam(); + CelFunctionRegistry registry; + CelFunctionDescriptor descriptor("OverloadedFunction", + /*receiver_style=*/false, + {CelValue::Type::kAny}, + /*is_strict=*/false); + if (existing_function_is_lazy) { + ASSERT_OK(registry.RegisterLazyFunction(descriptor)); + } else { + ASSERT_OK( + registry.Register(std::make_unique(descriptor))); + } + CelFunctionDescriptor new_descriptor( + "OverloadedFunction", + /*receiver_style=*/false, {CelValue::Type::kAny, CelValue::Type::kAny}, + /*is_strict=*/true); + absl::Status status; + if (new_function_is_lazy) { + status = registry.RegisterLazyFunction(new_descriptor); + } else { + status = + registry.Register(std::make_unique(new_descriptor)); + } + EXPECT_THAT(status, StatusIs(absl::StatusCode::kAlreadyExists, + HasSubstr("Only one overload"))); +} + +TEST_P(NonStrictRegistrationFailTest, CanRegisterStrictFunctionsWithoutLimit) { + bool existing_function_is_lazy, new_function_is_lazy; + std::tie(existing_function_is_lazy, new_function_is_lazy) = GetParam(); + CelFunctionRegistry registry; + CelFunctionDescriptor descriptor("OverloadedFunction", + /*receiver_style=*/false, + {CelValue::Type::kAny}, + /*is_strict=*/true); + if (existing_function_is_lazy) { + ASSERT_OK(registry.RegisterLazyFunction(descriptor)); + } else { + ASSERT_OK( + registry.Register(std::make_unique(descriptor))); + } + CelFunctionDescriptor new_descriptor( + "OverloadedFunction", + /*receiver_style=*/false, {CelValue::Type::kAny, CelValue::Type::kAny}, + /*is_strict=*/true); + absl::Status status; + if (new_function_is_lazy) { + status = registry.RegisterLazyFunction(new_descriptor); + } else { + status = + registry.Register(std::make_unique(new_descriptor)); + } + EXPECT_OK(status); +} + +INSTANTIATE_TEST_SUITE_P(NonStrictRegistrationFailTest, + NonStrictRegistrationFailTest, + testing::Combine(testing::Bool(), testing::Bool())); + } // namespace } // namespace google::api::expr::runtime From 9207f71eeb08506e0009ca0ebe6de8780454c072 Mon Sep 17 00:00:00 2001 From: jcking Date: Fri, 11 Feb 2022 18:57:32 +0000 Subject: [PATCH 003/111] Internal change PiperOrigin-RevId: 428042273 --- base/BUILD | 20 +++ base/internal/BUILD | 5 + base/internal/memory_manager.h | 80 ++++++++++++ base/memory_manager.cc | 62 +++++++++ base/memory_manager.h | 139 +++++++++++++++++++++ base/memory_manager_test.cc | 83 ++++++++++++ extensions/protobuf/BUILD | 41 ++++++ extensions/protobuf/memory_manager.cc | 42 +++++++ extensions/protobuf/memory_manager.h | 75 +++++++++++ extensions/protobuf/memory_manager_test.cc | 109 ++++++++++++++++ 10 files changed, 656 insertions(+) create mode 100644 base/internal/memory_manager.h create mode 100644 base/memory_manager.cc create mode 100644 base/memory_manager.h create mode 100644 base/memory_manager_test.cc create mode 100644 extensions/protobuf/BUILD create mode 100644 extensions/protobuf/memory_manager.cc create mode 100644 extensions/protobuf/memory_manager.h create mode 100644 extensions/protobuf/memory_manager_test.cc diff --git a/base/BUILD b/base/BUILD index b6f98e7fc..d9e7f28a3 100644 --- a/base/BUILD +++ b/base/BUILD @@ -37,6 +37,26 @@ cc_test( ], ) +cc_library( + name = "memory_manager", + srcs = ["memory_manager.cc"], + hdrs = ["memory_manager.h"], + deps = [ + "//base/internal:memory_manager", + "@com_google_absl//absl/base:core_headers", + ], +) + +cc_test( + name = "memory_manager_test", + srcs = ["memory_manager_test.cc"], + deps = [ + ":memory_manager", + "//base/internal:memory_manager", + "//internal:testing", + ], +) + cc_library( name = "operators", srcs = ["operators.cc"], diff --git a/base/internal/BUILD b/base/internal/BUILD index d4eeffe0d..ea842ae96 100644 --- a/base/internal/BUILD +++ b/base/internal/BUILD @@ -16,6 +16,11 @@ package(default_visibility = ["//visibility:public"]) licenses(["notice"]) +cc_library( + name = "memory_manager", + textual_hdrs = ["memory_manager.h"], +) + cc_library( name = "operators", hdrs = ["operators.h"], diff --git a/base/internal/memory_manager.h b/base/internal/memory_manager.h new file mode 100644 index 000000000..785cc8a72 --- /dev/null +++ b/base/internal/memory_manager.h @@ -0,0 +1,80 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MEMORY_MANAGER_H_ +#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MEMORY_MANAGER_H_ + +#include + +namespace cel { + +class MemoryManager; + +namespace base_internal { + +template +class MemoryManagerDeleter; + +// True if the deleter is no-op, meaning the object was allocated in an arena +// and the arena will perform any deletion upon its own destruction. +template +bool IsEmptyDeleter(const MemoryManagerDeleter& deleter); + +template +class MemoryManagerDeleter final { + public: + constexpr MemoryManagerDeleter() noexcept = default; + + MemoryManagerDeleter(const MemoryManagerDeleter&) = delete; + + constexpr MemoryManagerDeleter(MemoryManagerDeleter&& other) noexcept + : MemoryManagerDeleter() { + std::swap(memory_manager_, other.memory_manager_); + std::swap(size_, other.size_); + std::swap(align_, other.align_); + } + + void operator()(T* pointer) const; + + private: + friend class cel::MemoryManager; + template + friend bool IsEmptyDeleter(const MemoryManagerDeleter& deleter); + + MemoryManagerDeleter(MemoryManager* memory_manager, size_t size, size_t align) + : memory_manager_(memory_manager), size_(size), align_(align) {} + + MemoryManager* memory_manager_ = nullptr; + size_t size_ = 0; + size_t align_ = 0; +}; + +template +bool IsEmptyDeleter(const MemoryManagerDeleter& deleter) { + return deleter.memory_manager_ == nullptr; +} + +template +class MemoryManagerDestructor final { + private: + friend class cel::MemoryManager; + + static void Destruct(void* pointer) { reinterpret_cast(pointer)->~T(); } +}; + +} // namespace base_internal + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MEMORY_MANAGER_H_ diff --git a/base/memory_manager.cc b/base/memory_manager.cc new file mode 100644 index 000000000..9e1805c38 --- /dev/null +++ b/base/memory_manager.cc @@ -0,0 +1,62 @@ +#include "base/memory_manager.h" + +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/macros.h" + +namespace cel { + +namespace { + +class GlobalMemoryManager final : public MemoryManager { + private: + AllocationResult Allocate(size_t size, size_t align) override { + return {::operator new(size, static_cast(align)), true}; + } + + void Deallocate(void* pointer, size_t size, size_t align) override { + ::operator delete(pointer, size, static_cast(align)); + } +}; + +} // namespace + +MemoryManager* MemoryManager::Global() { + static MemoryManager* const instance = new GlobalMemoryManager(); + return instance; +} + +void MemoryManager::OwnDestructor(void* pointer, void (*destruct)(void*)) { + static_cast(pointer); + static_cast(destruct); + // OwnDestructor is only called for arena-based memory managers by `New`. If + // we got here, something is seriously wrong so crashing is okay. + std::abort(); +} + +void ArenaMemoryManager::Deallocate(void* pointer, size_t size, size_t align) { + static_cast(pointer); + static_cast(size); + static_cast(align); + // Most arena-based allocators will not deallocate individual allocations, so + // we default the implementation to std::abort(). + std::abort(); +} + +} // namespace cel diff --git a/base/memory_manager.h b/base/memory_manager.h new file mode 100644 index 000000000..22b36cc70 --- /dev/null +++ b/base/memory_manager.h @@ -0,0 +1,139 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_MEMORY_MANAGER_H_ +#define THIRD_PARTY_CEL_CPP_BASE_MEMORY_MANAGER_H_ + +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/macros.h" +#include "base/internal/memory_manager.h" + +namespace cel { + +// `ManagedMemory` is a smart pointer which ensures any applicable object +// destructors and deallocation are eventually performed upon its destruction. +// While `ManagedManager` is derived from `std::unique_ptr`, it does not make +// any guarantees that destructors and deallocation are run immediately upon its +// destruction, just that they will eventually be performed. +template +using ManagedMemory = + std::unique_ptr>; + +// `MemoryManager` is an abstraction over memory management that supports +// different allocation strategies. +class MemoryManager { + public: + ABSL_ATTRIBUTE_PURE_FUNCTION static MemoryManager* Global(); + + virtual ~MemoryManager() = default; + + // Allocates and constructs `T`. + // + // TODO(issues/5): mandate out of memory handling and return value? + template + ManagedMemory New(Args&&... args) ABSL_MUST_USE_RESULT { + auto [pointer, owned] = Allocate(sizeof(T), alignof(T)); + ::new (pointer) T(std::forward(args)...); + if (!owned) { + if constexpr (!std::is_trivially_destructible_v) { + OwnDestructor(pointer, + &base_internal::MemoryManagerDestructor::Destruct); + } + } + return ManagedMemory(reinterpret_cast(pointer), + base_internal::MemoryManagerDeleter( + owned ? this : nullptr, sizeof(T), alignof(T))); + } + + protected: + template + struct AllocationResult final { + Pointer pointer = nullptr; + // If true, the responsibility of deallocating and destructing `pointer` is + // passed to the caller of `Allocate`. + bool owned = false; + }; + + private: + template + friend class base_internal::MemoryManagerDeleter; + + // Delete a previous `New()` result when `AllocationResult::owned` is true. + template + void Delete(T* pointer, size_t size, size_t align) { + if (pointer != nullptr) { + if constexpr (!std::is_trivially_destructible_v) { + pointer->~T(); + } + Deallocate(pointer, size, align); + } + } + + // These are virtual private, ensuring only `MemoryManager` calls these. Which + // methods need to be implemented and which are called depends on whether the + // implementation is using arena memory management or not. + // + // If the implementation is using arenas then `Deallocate()` will never be + // called, `OwnDestructor` must be implemented, and `AllocationOnly` must + // return true. If the implementation is *not* using arenas then `Deallocate` + // must be implemented, `OwnDestructor` will never be called, and + // `AllocationOnly` will return false. + + // Allocates memory of at least size `size` in bytes that is at least as + // aligned as `align`. + virtual AllocationResult Allocate(size_t size, size_t align) = 0; + + // Deallocate the given pointer previously allocated via `Allocate`, assuming + // `AllocationResult::owned` was true. Calling this when + // `AllocationResult::owned` was false is undefined behavior. + virtual void Deallocate(void* pointer, size_t size, size_t align) = 0; + + // Registers a destructor to be run upon destruction of the memory management + // implementation. + // + // This method is only valid for arena memory managers. + virtual void OwnDestructor(void* pointer, void (*destruct)(void*)); +}; + +// Base class for all arena-based memory managers. +class ArenaMemoryManager : public MemoryManager { + private: + // Default implementation calls std::abort(). If you have a special case where + // you support deallocating individual allocations, override this. + void Deallocate(void* pointer, size_t size, size_t align) override; + + // OwnDestructor is typically required for arena-based memory managers. + void OwnDestructor(void* pointer, void (*destruct)(void*)) override = 0; +}; + +namespace base_internal { + +template +void MemoryManagerDeleter::operator()(T* pointer) const { + if (memory_manager_) { + memory_manager_->Delete(const_cast*>(pointer), size_, + align_); + } +} + +} // namespace base_internal + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_MEMORY_MANAGER_H_ diff --git a/base/memory_manager_test.cc b/base/memory_manager_test.cc new file mode 100644 index 000000000..f9d8369a9 --- /dev/null +++ b/base/memory_manager_test.cc @@ -0,0 +1,83 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/memory_manager.h" + +#include + +#include "base/internal/memory_manager.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +struct TriviallyDestructible final {}; + +TEST(GlobalMemoryManager, TriviallyDestructible) { + EXPECT_TRUE(std::is_trivially_destructible_v); + auto managed = MemoryManager::Global()->New(); + EXPECT_FALSE(base_internal::IsEmptyDeleter(managed.get_deleter())); +} + +struct NotTriviallyDestuctible final { + ~NotTriviallyDestuctible() { Delete(); } + + MOCK_METHOD(void, Delete, (), ()); +}; + +TEST(GlobalMemoryManager, NotTriviallyDestuctible) { + EXPECT_FALSE(std::is_trivially_destructible_v); + auto managed = MemoryManager::Global()->New(); + EXPECT_FALSE(base_internal::IsEmptyDeleter(managed.get_deleter())); + EXPECT_CALL(*managed, Delete()); +} + +class BadMemoryManager final : public MemoryManager { + private: + AllocationResult Allocate(size_t size, size_t align) override { + // Return {..., false}, indicating that this was an arena allocation when it + // is not, causing OwnDestructor to be called and abort. + return {::operator new(size, static_cast(align)), false}; + } + + void Deallocate(void* pointer, size_t size, size_t align) override { + ::operator delete(pointer, size, static_cast(align)); + } +}; + +TEST(BadMemoryManager, OwnDestructorAborts) { + BadMemoryManager memory_manager; + EXPECT_EXIT(static_cast(memory_manager.New()), + testing::KilledBySignal(SIGABRT), ""); +} + +class BadArenaMemoryManager final : public ArenaMemoryManager { + private: + AllocationResult Allocate(size_t size, size_t align) override { + // Return {..., false}, indicating that this was an arena allocation when it + // is not, causing OwnDestructor to be called and abort. + return {::operator new(size, static_cast(align)), true}; + } + + void OwnDestructor(void* pointer, void (*destructor)(void*)) override {} +}; + +TEST(BadArenaMemoryManager, DeallocateAborts) { + BadArenaMemoryManager memory_manager; + EXPECT_EXIT(static_cast(memory_manager.New()), + testing::KilledBySignal(SIGABRT), ""); +} + +} // namespace +} // namespace cel diff --git a/extensions/protobuf/BUILD b/extensions/protobuf/BUILD new file mode 100644 index 000000000..86588ba62 --- /dev/null +++ b/extensions/protobuf/BUILD @@ -0,0 +1,41 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + # Under active development, not yet being released. + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) + +cc_library( + name = "memory_manager", + srcs = ["memory_manager.cc"], + hdrs = ["memory_manager.h"], + deps = [ + "//base:memory_manager", + "@com_google_absl//absl/base:core_headers", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "memory_manager_test", + srcs = ["memory_manager_test.cc"], + deps = [ + ":memory_manager", + "//internal:testing", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/extensions/protobuf/memory_manager.cc b/extensions/protobuf/memory_manager.cc new file mode 100644 index 000000000..485f5ada1 --- /dev/null +++ b/extensions/protobuf/memory_manager.cc @@ -0,0 +1,42 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/protobuf/memory_manager.h" + +#include + +#include "absl/base/macros.h" + +namespace cel::extensions { + +MemoryManager::AllocationResult ProtoMemoryManager::Allocate( + size_t size, size_t align) { + if (arena_ != nullptr) { + return {arena_->AllocateAligned(size, align), false}; + } + return {::operator new(size, static_cast(align)), true}; +} + +void ProtoMemoryManager::Deallocate(void* pointer, size_t size, size_t align) { + // Only possible when `arena_` is nullptr. + ABSL_HARDENING_ASSERT(arena_ == nullptr); + ::operator delete(pointer, size, static_cast(align)); +} + +void ProtoMemoryManager::OwnDestructor(void* pointer, void (*destruct)(void*)) { + ABSL_HARDENING_ASSERT(arena_ != nullptr); + arena_->OwnCustomDestructor(pointer, destruct); +} + +} // namespace cel::extensions diff --git a/extensions/protobuf/memory_manager.h b/extensions/protobuf/memory_manager.h new file mode 100644 index 000000000..56d88aee6 --- /dev/null +++ b/extensions/protobuf/memory_manager.h @@ -0,0 +1,75 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_MEMORY_MANAGER_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_MEMORY_MANAGER_H_ + +#include + +#include "google/protobuf/arena.h" +#include "absl/base/attributes.h" +#include "absl/base/macros.h" +#include "base/memory_manager.h" + +namespace cel::extensions { + +// `ProtoMemoryManager` is an implementation of `ArenaMemoryManager` using +// `google::protobuf::Arena`. All allocations are valid so long as the underlying +// `google::protobuf::Arena` is still alive. +class ProtoMemoryManager final : public ArenaMemoryManager { + public: + // Passing a nullptr is highly discouraged, but supported for backwards + // compatibility. If `arena` is a nullptr, `ProtoMemoryManager` acts like + // `MemoryManager::Default()`. + explicit ProtoMemoryManager(google::protobuf::Arena* arena) : arena_(arena) {} + + ProtoMemoryManager(const ProtoMemoryManager&) = delete; + + ProtoMemoryManager(ProtoMemoryManager&&) = delete; + + ProtoMemoryManager& operator=(const ProtoMemoryManager&) = delete; + + ProtoMemoryManager& operator=(ProtoMemoryManager&&) = delete; + + google::protobuf::Arena* arena() const { return arena_; } + + private: + AllocationResult Allocate(size_t size, size_t align) override; + + void Deallocate(void* pointer, size_t size, size_t align) override; + + void OwnDestructor(void* pointer, void (*destruct)(void*)) override; + + google::protobuf::Arena* const arena_; +}; + +// Allocate and construct `T` using the `ProtoMemoryManager` provided as +// `memory_manager`. `memory_manager` must be `ProtoMemoryManager` or behavior +// is undefined. Unlike `MemoryManager::New`, this method supports arena-enabled +// messages. +template +ABSL_MUST_USE_RESULT T* NewInProtoArena(MemoryManager* memory_manager, + Args&&... args) { + ABSL_ASSERT(memory_manager != nullptr); +#if !defined(__GNUC__) || defined(__GXX_RTTI) + ABSL_ASSERT(dynamic_cast(memory_manager) != nullptr); +#endif + return google::protobuf::Arena::Create( + static_cast(memory_manager)->arena(), + std::forward(args)...); +} + +} // namespace cel::extensions + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_MEMORY_MANAGER_H_ diff --git a/extensions/protobuf/memory_manager_test.cc b/extensions/protobuf/memory_manager_test.cc new file mode 100644 index 000000000..7d9170598 --- /dev/null +++ b/extensions/protobuf/memory_manager_test.cc @@ -0,0 +1,109 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/protobuf/memory_manager.h" + +#include "google/protobuf/struct.pb.h" +#include "google/protobuf/arena.h" +#include "internal/testing.h" + +namespace cel::extensions { +namespace { + +struct NotArenaCompatible final { + ~NotArenaCompatible() { Delete(); } + + MOCK_METHOD(void, Delete, (), ()); +}; + +TEST(ProtoMemoryManager, ArenaConstructable) { + google::protobuf::Arena arena; + ProtoMemoryManager memory_manager(&arena); + EXPECT_TRUE( + google::protobuf::Arena::is_arena_constructable::value); + auto* object = NewInProtoArena(&memory_manager); + EXPECT_NE(object, nullptr); +} + +TEST(ProtoMemoryManager, NotArenaConstructable) { + google::protobuf::Arena arena; + ProtoMemoryManager memory_manager(&arena); + EXPECT_FALSE( + google::protobuf::Arena::is_arena_constructable::value); + auto* object = NewInProtoArena(&memory_manager); + EXPECT_NE(object, nullptr); + EXPECT_CALL(*object, Delete()); +} + +TEST(ProtoMemoryManagerNoArena, ArenaConstructable) { + ProtoMemoryManager memory_manager(nullptr); + EXPECT_TRUE( + google::protobuf::Arena::is_arena_constructable::value); + auto* object = NewInProtoArena(&memory_manager); + EXPECT_NE(object, nullptr); + delete object; +} + +TEST(ProtoMemoryManagerNoArena, NotArenaConstructable) { + ProtoMemoryManager memory_manager(nullptr); + EXPECT_FALSE( + google::protobuf::Arena::is_arena_constructable::value); + auto* object = NewInProtoArena(&memory_manager); + EXPECT_NE(object, nullptr); + EXPECT_CALL(*object, Delete()); + delete object; +} + +struct TriviallyDestructible final {}; + +struct NotTriviallyDestuctible final { + ~NotTriviallyDestuctible() { Delete(); } + + MOCK_METHOD(void, Delete, (), ()); +}; + +TEST(ProtoMemoryManager, TriviallyDestructible) { + google::protobuf::Arena arena; + ProtoMemoryManager memory_manager(&arena); + EXPECT_TRUE(std::is_trivially_destructible_v); + auto managed = memory_manager.New(); + EXPECT_TRUE(base_internal::IsEmptyDeleter(managed.get_deleter())); +} + +TEST(ProtoMemoryManager, NotTriviallyDestuctible) { + google::protobuf::Arena arena; + ProtoMemoryManager memory_manager(&arena); + EXPECT_FALSE(std::is_trivially_destructible_v); + auto managed = memory_manager.New(); + EXPECT_TRUE(base_internal::IsEmptyDeleter(managed.get_deleter())); + EXPECT_CALL(*managed, Delete()); +} + +TEST(ProtoMemoryManagerNoArena, TriviallyDestructible) { + ProtoMemoryManager memory_manager(nullptr); + EXPECT_TRUE(std::is_trivially_destructible_v); + auto managed = memory_manager.New(); + EXPECT_FALSE(base_internal::IsEmptyDeleter(managed.get_deleter())); +} + +TEST(ProtoMemoryManagerNoArena, NotTriviallyDestuctible) { + ProtoMemoryManager memory_manager(nullptr); + EXPECT_FALSE(std::is_trivially_destructible_v); + auto managed = memory_manager.New(); + EXPECT_FALSE(base_internal::IsEmptyDeleter(managed.get_deleter())); + EXPECT_CALL(*managed, Delete()); +} + +} // namespace +} // namespace cel::extensions From 589145f696458125a4ff5eb7afbb125dfdc2b2aa Mon Sep 17 00:00:00 2001 From: tswadell Date: Mon, 14 Feb 2022 16:42:07 +0000 Subject: [PATCH 004/111] Internal change. PiperOrigin-RevId: 428518876 --- eval/public/extension_func_test.cc | 4 ++-- eval/public/testing/matchers.cc | 17 ++++++++++------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/eval/public/extension_func_test.cc b/eval/public/extension_func_test.cc index 7f3d05b05..0ac9c3f18 100644 --- a/eval/public/extension_func_test.cc +++ b/eval/public/extension_func_test.cc @@ -79,7 +79,7 @@ class ExtensionTest : public ::testing::Test { } // Helper method to test timestamp() function - void PerformTimestampConversion(Arena* arena, std::string ts_str, + void PerformTimestampConversion(Arena* arena, const std::string& ts_str, CelValue* result) { auto functions = registry_.FindOverloads("timestamp", false, {CelValue::Type::kString}); @@ -240,7 +240,7 @@ class ExtensionTest : public ::testing::Test { } // Helper method to test duration() function - void PerformDurationConversion(Arena* arena, std::string ts_str, + void PerformDurationConversion(Arena* arena, const std::string& ts_str, CelValue* result) { auto functions = registry_.FindOverloads("duration", false, {CelValue::Type::kString}); diff --git a/eval/public/testing/matchers.cc b/eval/public/testing/matchers.cc index 18eb8b480..a8333d210 100644 --- a/eval/public/testing/matchers.cc +++ b/eval/public/testing/matchers.cc @@ -1,5 +1,7 @@ #include "eval/public/testing/matchers.h" +#include + #include "gmock/gmock.h" #include "gtest/gtest.h" #include "absl/strings/string_view.h" @@ -67,19 +69,19 @@ CelValueMatcher EqualsCelValue(const CelValue& v) { } CelValueMatcher IsCelBool(testing::Matcher m) { - return CelValueMatcher(new CelValueMatcherImpl(m)); + return CelValueMatcher(new CelValueMatcherImpl(std::move(m))); } CelValueMatcher IsCelInt64(testing::Matcher m) { - return CelValueMatcher(new CelValueMatcherImpl(m)); + return CelValueMatcher(new CelValueMatcherImpl(std::move(m))); } CelValueMatcher IsCelUint64(testing::Matcher m) { - return CelValueMatcher(new CelValueMatcherImpl(m)); + return CelValueMatcher(new CelValueMatcherImpl(std::move(m))); } CelValueMatcher IsCelDouble(testing::Matcher m) { - return CelValueMatcher(new CelValueMatcherImpl(m)); + return CelValueMatcher(new CelValueMatcherImpl(std::move(m))); } CelValueMatcher IsCelString(testing::Matcher m) { @@ -93,15 +95,16 @@ CelValueMatcher IsCelBytes(testing::Matcher m) { } CelValueMatcher IsCelMessage(testing::Matcher m) { - return CelValueMatcher(new CelValueMatcherImpl(m)); + return CelValueMatcher( + new CelValueMatcherImpl(std::move(m))); } CelValueMatcher IsCelDuration(testing::Matcher m) { - return CelValueMatcher(new CelValueMatcherImpl(m)); + return CelValueMatcher(new CelValueMatcherImpl(std::move(m))); } CelValueMatcher IsCelTimestamp(testing::Matcher m) { - return CelValueMatcher(new CelValueMatcherImpl(m)); + return CelValueMatcher(new CelValueMatcherImpl(std::move(m))); } CelValueMatcher IsCelError(testing::Matcher m) { From c5bbf197f416ba07f56ce875e442af523c2bda97 Mon Sep 17 00:00:00 2001 From: tswadell Date: Mon, 14 Feb 2022 17:31:30 +0000 Subject: [PATCH 005/111] Internal change. PiperOrigin-RevId: 428531745 --- eval/compiler/flat_expr_builder_test.cc | 2 -- eval/eval/create_struct_step_test.cc | 1 - eval/eval/ident_step_test.cc | 1 - eval/public/containers/field_backed_map_impl.cc | 1 - internal/strings_test.cc | 1 - parser/parser.cc | 1 - tools/flatbuffers_backed_impl_test.cc | 2 -- 7 files changed, 9 deletions(-) diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index c0fcc0899..5503b3001 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -63,10 +63,8 @@ using google::api::expr::v1alpha1::Expr; using google::api::expr::v1alpha1::ParsedExpr; using google::api::expr::v1alpha1::SourceInfo; -using google::protobuf::FieldMask; using testing::Eq; using testing::HasSubstr; -using cel::internal::IsOk; using cel::internal::StatusIs; inline constexpr absl::string_view kSimpleTestMessageDescriptorSetFile = diff --git a/eval/eval/create_struct_step_test.cc b/eval/eval/create_struct_step_test.cc index 80395e49a..c54b29db8 100644 --- a/eval/eval/create_struct_step_test.cc +++ b/eval/eval/create_struct_step_test.cc @@ -27,7 +27,6 @@ using ::google::protobuf::Arena; using ::google::protobuf::Message; using testing::Eq; -using testing::HasSubstr; using testing::IsNull; using testing::Not; using testing::Pointwise; diff --git a/eval/eval/ident_step_test.cc b/eval/eval/ident_step_test.cc index 60680dbdc..5bbd692ef 100644 --- a/eval/eval/ident_step_test.cc +++ b/eval/eval/ident_step_test.cc @@ -15,7 +15,6 @@ namespace google::api::expr::runtime { namespace { using ::google::api::expr::v1alpha1::Expr; -using ::google::protobuf::FieldMask; using testing::Eq; using google::protobuf::Arena; diff --git a/eval/public/containers/field_backed_map_impl.cc b/eval/public/containers/field_backed_map_impl.cc index aafa85db4..7f7460f99 100644 --- a/eval/public/containers/field_backed_map_impl.cc +++ b/eval/public/containers/field_backed_map_impl.cc @@ -44,7 +44,6 @@ namespace expr { namespace runtime { namespace { -using google::protobuf::Arena; using google::protobuf::Descriptor; using google::protobuf::FieldDescriptor; using google::protobuf::MapValueConstRef; diff --git a/internal/strings_test.cc b/internal/strings_test.cc index a550e30e9..abcac7e93 100644 --- a/internal/strings_test.cc +++ b/internal/strings_test.cc @@ -26,7 +26,6 @@ namespace cel::internal { namespace { -using cel::internal::IsOk; using cel::internal::StatusIs; constexpr char kUnicodeNotAllowedInBytes1[] = diff --git a/parser/parser.cc b/parser/parser.cc index 0fc1db41a..f810408cf 100644 --- a/parser/parser.cc +++ b/parser/parser.cc @@ -57,7 +57,6 @@ namespace { using ::antlr4::CharStream; using ::antlr4::CommonTokenStream; using ::antlr4::DefaultErrorStrategy; -using ::antlr4::IntStream; using ::antlr4::ParseCancellationException; using ::antlr4::Parser; using ::antlr4::ParserRuleContext; diff --git a/tools/flatbuffers_backed_impl_test.cc b/tools/flatbuffers_backed_impl_test.cc index 349dbea23..9f55f793a 100644 --- a/tools/flatbuffers_backed_impl_test.cc +++ b/tools/flatbuffers_backed_impl_test.cc @@ -14,8 +14,6 @@ namespace runtime { namespace { -using google::protobuf::Arena; - constexpr char kReflectionBufferPath[] = "tools/testdata/" "flatbuffers.bfbs"; From c0ad193b0d85aaa4dd278a156fcc7822e6e8260c Mon Sep 17 00:00:00 2001 From: jcking Date: Mon, 14 Feb 2022 18:21:36 +0000 Subject: [PATCH 006/111] Internal change PiperOrigin-RevId: 428545658 --- base/memory_manager.cc | 4 +++- base/memory_manager.h | 19 ++++++++++++------- extensions/protobuf/memory_manager.cc | 4 +++- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/base/memory_manager.cc b/base/memory_manager.cc index 9e1805c38..1daff0c08 100644 --- a/base/memory_manager.cc +++ b/base/memory_manager.cc @@ -27,7 +27,9 @@ namespace { class GlobalMemoryManager final : public MemoryManager { private: AllocationResult Allocate(size_t size, size_t align) override { - return {::operator new(size, static_cast(align)), true}; + return {::operator new(size, static_cast(align), + std::nothrow), + true}; } void Deallocate(void* pointer, size_t size, size_t align) override { diff --git a/base/memory_manager.h b/base/memory_manager.h index 22b36cc70..73cbb2763 100644 --- a/base/memory_manager.h +++ b/base/memory_manager.h @@ -22,6 +22,7 @@ #include "absl/base/attributes.h" #include "absl/base/macros.h" +#include "absl/base/optimization.h" #include "base/internal/memory_manager.h" namespace cel { @@ -43,22 +44,26 @@ class MemoryManager { virtual ~MemoryManager() = default; - // Allocates and constructs `T`. - // - // TODO(issues/5): mandate out of memory handling and return value? + // Allocates and constructs `T`. In the event of an allocation failure nullptr + // is returned. template ManagedMemory New(Args&&... args) ABSL_MUST_USE_RESULT { - auto [pointer, owned] = Allocate(sizeof(T), alignof(T)); + size_t size = sizeof(T); + size_t align = alignof(T); + auto [pointer, owned] = Allocate(size, align); + if (ABSL_PREDICT_FALSE(pointer == nullptr)) { + return ManagedMemory(); + } ::new (pointer) T(std::forward(args)...); - if (!owned) { - if constexpr (!std::is_trivially_destructible_v) { + if constexpr (!std::is_trivially_destructible_v) { + if (!owned) { OwnDestructor(pointer, &base_internal::MemoryManagerDestructor::Destruct); } } return ManagedMemory(reinterpret_cast(pointer), base_internal::MemoryManagerDeleter( - owned ? this : nullptr, sizeof(T), alignof(T))); + owned ? this : nullptr, size, align)); } protected: diff --git a/extensions/protobuf/memory_manager.cc b/extensions/protobuf/memory_manager.cc index 485f5ada1..9d88069c4 100644 --- a/extensions/protobuf/memory_manager.cc +++ b/extensions/protobuf/memory_manager.cc @@ -25,7 +25,9 @@ MemoryManager::AllocationResult ProtoMemoryManager::Allocate( if (arena_ != nullptr) { return {arena_->AllocateAligned(size, align), false}; } - return {::operator new(size, static_cast(align)), true}; + return { + ::operator new(size, static_cast(align), std::nothrow), + true}; } void ProtoMemoryManager::Deallocate(void* pointer, size_t size, size_t align) { From 7fef2ddc700633c34d7c0e92188412d548e0691c Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Tue, 15 Feb 2022 00:40:43 +0000 Subject: [PATCH 007/111] Add allocation focused benchmarks. PiperOrigin-RevId: 428639120 --- eval/tests/BUILD | 33 ++++ eval/tests/allocation_benchmark_test.cc | 220 ++++++++++++++++++++++++ 2 files changed, 253 insertions(+) create mode 100644 eval/tests/allocation_benchmark_test.cc diff --git a/eval/tests/BUILD b/eval/tests/BUILD index 4146afdf6..d3908d152 100644 --- a/eval/tests/BUILD +++ b/eval/tests/BUILD @@ -42,6 +42,39 @@ cc_test( ], ) +cc_test( + name = "allocation_benchmark_test", + size = "small", + srcs = [ + "allocation_benchmark_test.cc", + ], + deps = [ + ":request_context_cc_proto", + "//eval/public:activation", + "//eval/public:builtin_func_registrar", + "//eval/public:cel_expr_builder_factory", + "//eval/public:cel_expression", + "//eval/public:cel_options", + "//eval/public:cel_value", + "//eval/public/containers:container_backed_list_impl", + "//eval/public/containers:container_backed_map_impl", + "//eval/public/structs:cel_proto_wrapper", + "//internal:benchmark", + "//internal:status_macros", + "//internal:testing", + "//parser", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:node_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + cc_test( name = "expression_builder_benchmark_test", size = "small", diff --git a/eval/tests/allocation_benchmark_test.cc b/eval/tests/allocation_benchmark_test.cc new file mode 100644 index 000000000..20bd0849a --- /dev/null +++ b/eval/tests/allocation_benchmark_test.cc @@ -0,0 +1,220 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include +#include + +#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "google/protobuf/text_format.h" +#include "absl/base/attributes.h" +#include "absl/container/btree_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/node_hash_set.h" +#include "absl/status/status.h" +#include "absl/strings/match.h" +#include "absl/strings/substitute.h" +#include "eval/public/activation.h" +#include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_expr_builder_factory.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" +#include "eval/public/containers/container_backed_list_impl.h" +#include "eval/public/containers/container_backed_map_impl.h" +#include "eval/public/structs/cel_proto_wrapper.h" +#include "eval/tests/request_context.pb.h" +#include "internal/benchmark.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "parser/parser.h" + +namespace google::api::expr::runtime { +namespace { + +using ::google::api::expr::parser::Parse; +using testing::HasSubstr; +using cel::internal::StatusIs; + +// Evaluates cel expression: +// '"1" + "1" + ...' +static void BM_StrCatLocalArena(benchmark::State& state) { + std::string expr("'1'"); + int len = state.range(0); + auto builder = CreateCelExpressionBuilder(); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + + for (int i = 0; i < len; i++) { + expr = absl::Substitute("($0 + $0)", expr); + } + + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(expr)); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + for (auto _ : state) { + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + CelValue::StringHolder holder; + ASSERT_TRUE(result.GetValue(&holder)); + ASSERT_EQ(holder.value().length(), 1 << len); + } +} +BENCHMARK(BM_StrCatLocalArena)->DenseRange(0, 8, 2); + +// Evaluates cel expression: +// '("1" + "1") + ...' +static void BM_StrCatSharedArena(benchmark::State& state) { + google::protobuf::Arena arena; + std::string expr("'1'"); + int len = state.range(0); + auto builder = CreateCelExpressionBuilder(); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + + for (int i = 0; i < len; i++) { + expr = absl::Substitute("($0 + $0)", expr); + } + + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(expr)); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + for (auto _ : state) { + Activation activation; + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + CelValue::StringHolder holder; + ASSERT_TRUE(result.GetValue(&holder)); + ASSERT_EQ(holder.value().length(), 1 << len); + } +} + +// Expression grows exponentially. +BENCHMARK(BM_StrCatSharedArena)->DenseRange(0, 8, 2); + +// Series of simple expressions that are expected to require an allocation. +static void BM_AllocateString(benchmark::State& state) { + google::protobuf::Arena arena; + std::string expr("'1' + '1'"); + auto builder = CreateCelExpressionBuilder(); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(expr)); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + for (auto _ : state) { + Activation activation; + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + CelValue::StringHolder holder; + ASSERT_TRUE(result.GetValue(&holder)); + ASSERT_EQ(holder.value(), "11"); + } +} +BENCHMARK(BM_AllocateString); + +static void BM_AllocateError(benchmark::State& state) { + google::protobuf::Arena arena; + std::string expr("1 / 0"); + auto builder = CreateCelExpressionBuilder(); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(expr)); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + for (auto _ : state) { + Activation activation; + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + const CelError* value; + ASSERT_TRUE(result.GetValue(&value)); + ASSERT_THAT(*value, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("divide by zero"))); + } +} +BENCHMARK(BM_AllocateError); + +static void BM_AllocateMap(benchmark::State& state) { + google::protobuf::Arena arena; + std::string expr("{1: 2, 3: 4}"); + auto builder = CreateCelExpressionBuilder(); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(expr)); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + for (auto _ : state) { + Activation activation; + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsMap()); + } +} + +BENCHMARK(BM_AllocateMap); + +static void BM_AllocateMessage(benchmark::State& state) { + google::protobuf::Arena arena; + std::string expr( + "google.api.expr.runtime.RequestContext{" + "ip: '192.168.0.1'," + "path: '/root'}"); + auto builder = CreateCelExpressionBuilder(); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(expr)); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + for (auto _ : state) { + Activation activation; + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsMessage()); + } +} + +BENCHMARK(BM_AllocateMessage); + +static void BM_AllocateList(benchmark::State& state) { + google::protobuf::Arena arena; + std::string expr("[1, 2, 3, 4]"); + auto builder = CreateCelExpressionBuilder(); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(expr)); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + for (auto _ : state) { + Activation activation; + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsList()); + } +} +BENCHMARK(BM_AllocateList); + +} // namespace +} // namespace google::api::expr::runtime From dd133a9cccbd029761e47171266f4fdf8c35c501 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 17 Feb 2022 19:54:20 +0000 Subject: [PATCH 008/111] Add overloads for cel error factories to use the cel::MemoryManager instead of directly using proto2::Arena api. PiperOrigin-RevId: 429366046 --- eval/compiler/flat_expr_builder_test.cc | 2 +- eval/public/BUILD | 6 ++ eval/public/cel_value.cc | 79 +++++++++++++++++++++---- eval/public/cel_value.h | 57 ++++++++++++++---- eval/public/cel_value_test.cc | 58 +++++++++++++++--- 5 files changed, 168 insertions(+), 34 deletions(-) diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index 5503b3001..be684a7c9 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -1071,7 +1071,7 @@ TEST(FlatExprBuilderTest, ComprehensionWorksForNonContainer) { ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsError()); EXPECT_THAT(result.ErrorOrDie()->message(), - Eq("No matching overloads found ")); + Eq("No matching overloads found : ")); } TEST(FlatExprBuilderTest, ComprehensionBudget) { diff --git a/eval/public/BUILD b/eval/public/BUILD index 35283c8f6..11bf8a0ea 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -38,6 +38,8 @@ cc_library( ], deps = [ ":cel_value_internal", + "//base:memory_manager", + "//extensions/protobuf:memory_manager", "//internal:status_macros", "//internal:utf8", "@com_google_absl//absl/base:core_headers", @@ -459,6 +461,10 @@ cc_test( ":cel_value", ":unknown_attribute_set", ":unknown_set", + "//base:memory_manager", + "//eval/public/testing:matchers", + "//extensions/protobuf:memory_manager", + "//internal:status_macros", "//internal:testing", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", diff --git a/eval/public/cel_value.cc b/eval/public/cel_value.cc index 98de290df..603f1bd96 100644 --- a/eval/public/cel_value.cc +++ b/eval/public/cel_value.cc @@ -9,12 +9,15 @@ #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" +#include "base/memory_manager.h" +#include "extensions/protobuf/memory_manager.h" namespace google::api::expr::runtime { namespace { -using google::protobuf::Arena; +using ::cel::extensions::NewInProtoArena; +using ::google::protobuf::Arena; constexpr char kErrNoMatchingOverload[] = "No matching overloads found"; constexpr char kErrNoSuchField[] = "no_such_field"; @@ -232,21 +235,35 @@ const std::string CelValue::DebugString() const { Visit(DebugStringVisitor())); } +CelValue CreateErrorValue(cel::MemoryManager& manager, + absl::string_view message, + absl::StatusCode error_code) { + // TODO(issues/5): assume arena-style allocator while migrating to new + // value type. + CelError* error = NewInProtoArena(&manager, error_code, message); + return CelValue::CreateError(error); +} + CelValue CreateErrorValue(Arena* arena, absl::string_view message, - absl::StatusCode error_code, int) { + absl::StatusCode error_code) { CelError* error = Arena::Create(arena, error_code, message); return CelValue::CreateError(error); } -CelValue CreateNoMatchingOverloadError(google::protobuf::Arena* arena) { - return CreateErrorValue(arena, kErrNoMatchingOverload, - absl::StatusCode::kUnknown); +CelValue CreateNoMatchingOverloadError(cel::MemoryManager& manager, + absl::string_view fn) { + return CreateErrorValue( + manager, + absl::StrCat(kErrNoMatchingOverload, (!fn.empty()) ? " : " : "", fn), + absl::StatusCode::kUnknown); } CelValue CreateNoMatchingOverloadError(google::protobuf::Arena* arena, absl::string_view fn) { - return CreateErrorValue(arena, absl::StrCat(kErrNoMatchingOverload, " ", fn), - absl::StatusCode::kUnknown); + return CreateErrorValue( + arena, + absl::StrCat(kErrNoMatchingOverload, (!fn.empty()) ? " : " : "", fn), + absl::StatusCode::kUnknown); } bool CheckNoMatchingOverloadError(CelValue value) { @@ -256,12 +273,26 @@ bool CheckNoMatchingOverloadError(CelValue value) { kErrNoMatchingOverload); } +CelValue CreateNoSuchFieldError(cel::MemoryManager& manager, + absl::string_view field) { + return CreateErrorValue( + manager, + absl::StrCat(kErrNoSuchField, !field.empty() ? " : " : "", field), + absl::StatusCode::kNotFound); +} + CelValue CreateNoSuchFieldError(google::protobuf::Arena* arena, absl::string_view field) { return CreateErrorValue( arena, absl::StrCat(kErrNoSuchField, !field.empty() ? " : " : "", field), absl::StatusCode::kNotFound); } +CelValue CreateNoSuchKeyError(cel::MemoryManager& manager, + absl::string_view key) { + return CreateErrorValue(manager, absl::StrCat(kErrNoSuchKey, " : ", key), + absl::StatusCode::kNotFound); +} + CelValue CreateNoSuchKeyError(google::protobuf::Arena* arena, absl::string_view key) { return CreateErrorValue(arena, absl::StrCat(kErrNoSuchKey, " : ", key), absl::StatusCode::kNotFound); @@ -302,9 +333,21 @@ CelValue CreateMissingAttributeError(google::protobuf::Arena* arena, return CelValue::CreateError(error); } +CelValue CreateMissingAttributeError(cel::MemoryManager& manager, + absl::string_view missing_attribute_path) { + // TODO(issues/5): assume arena-style allocator while migrating + // to new value type. + CelError* error = NewInProtoArena( + &manager, absl::StatusCode::kInvalidArgument, + absl::StrCat(kErrMissingAttribute, missing_attribute_path)); + error->SetPayload(kPayloadUrlMissingAttributePath, + absl::Cord(missing_attribute_path)); + return CelValue::CreateError(error); +} + bool IsMissingAttributeError(const CelValue& value) { - if (!value.IsError()) return false; - const CelError* error = value.ErrorOrDie(); // Crash ok + const CelError* error; + if (!value.GetValue(&error)) return false; if (error && error->code() == absl::StatusCode::kInvalidArgument) { auto path = error->GetPayload(kPayloadUrlMissingAttributePath); return path.has_value(); @@ -312,6 +355,17 @@ bool IsMissingAttributeError(const CelValue& value) { return false; } +CelValue CreateUnknownFunctionResultError(cel::MemoryManager& manager, + absl::string_view help_message) { + // TODO(issues/5): Assume arena-style allocation until new value type is + // introduced + CelError* error = NewInProtoArena( + &manager, absl::StatusCode::kUnavailable, + absl::StrCat("Unknown function result: ", help_message)); + error->SetPayload(kPayloadUrlUnknownFunctionResult, absl::Cord("true")); + return CelValue::CreateError(error); +} + CelValue CreateUnknownFunctionResultError(google::protobuf::Arena* arena, absl::string_view help_message) { CelError* error = Arena::Create( @@ -322,10 +376,9 @@ CelValue CreateUnknownFunctionResultError(google::protobuf::Arena* arena, } bool IsUnknownFunctionResult(const CelValue& value) { - if (!value.IsError()) { - return false; - } - const CelError* error = value.ErrorOrDie(); + const CelError* error; + if (!value.GetValue(&error)) return false; + if (error == nullptr || error->code() != absl::StatusCode::kUnavailable) { return false; } diff --git a/eval/public/cel_value.h b/eval/public/cel_value.h index cce9bb233..7d09b89af 100644 --- a/eval/public/cel_value.h +++ b/eval/public/cel_value.h @@ -32,6 +32,7 @@ #include "absl/time/time.h" #include "absl/types/optional.h" #include "absl/types/variant.h" +#include "base/memory_manager.h" #include "eval/public/cel_value_internal.h" #include "internal/status_macros.h" #include "internal/utf8.h" @@ -529,12 +530,20 @@ class CelMap { // Utility method that generates CelValue containing CelError. // message an error message // error_code error code -// position location of the error source in CEL expression string the Expr was -// parsed from. -1, if the position can not be determined. +CelValue CreateErrorValue( + cel::MemoryManager& manager ABSL_ATTRIBUTE_LIFETIME_BOUND, + absl::string_view message, + absl::StatusCode error_code = absl::StatusCode::kUnknown); CelValue CreateErrorValue( google::protobuf::Arena* arena, absl::string_view message, - absl::StatusCode error_code = absl::StatusCode::kUnknown, - int position = -1); + absl::StatusCode error_code = absl::StatusCode::kUnknown); + +// Utility method for generating a CelValue from an absl::Status. +inline CelValue CreateErrorValue(cel::MemoryManager& manager + ABSL_ATTRIBUTE_LIFETIME_BOUND, + const absl::Status& status) { + return CreateErrorValue(manager, status.message(), status.code()); +} // Utility method for generating a CelValue from an absl::Status. inline CelValue CreateErrorValue(google::protobuf::Arena* arena, @@ -542,28 +551,39 @@ inline CelValue CreateErrorValue(google::protobuf::Arena* arena, return CreateErrorValue(arena, status.message(), status.code()); } -CelValue CreateNoMatchingOverloadError(google::protobuf::Arena* arena); +// Create an error for failed overload resolution, optionally including the name +// of the function. +CelValue CreateNoMatchingOverloadError(cel::MemoryManager& manager + ABSL_ATTRIBUTE_LIFETIME_BOUND, + absl::string_view fn = ""); +ABSL_DEPRECATED("Prefer using the generic MemoryManager overload") CelValue CreateNoMatchingOverloadError(google::protobuf::Arena* arena, - absl::string_view fn); + absl::string_view fn = ""); bool CheckNoMatchingOverloadError(CelValue value); +CelValue CreateNoSuchFieldError(cel::MemoryManager& manager + ABSL_ATTRIBUTE_LIFETIME_BOUND, + absl::string_view field = ""); +ABSL_DEPRECATED("Prefer using the generic MemoryManager overload") CelValue CreateNoSuchFieldError(google::protobuf::Arena* arena, absl::string_view field = ""); +CelValue CreateNoSuchKeyError(cel::MemoryManager& manager + ABSL_ATTRIBUTE_LIFETIME_BOUND, + absl::string_view key); +ABSL_DEPRECATED("Prefer using the generic MemoryManager overload") CelValue CreateNoSuchKeyError(google::protobuf::Arena* arena, absl::string_view key); -bool CheckNoSuchKeyError(CelValue value); - -ABSL_DEPRECATED("This type of error is no longer used by the evaluator.") -CelValue CreateUnknownValueError(google::protobuf::Arena* arena, - absl::string_view unknown_path); -ABSL_DEPRECATED("This type of error is no longer used by the evaluator.") -bool IsUnknownValueError(const CelValue& value); +bool CheckNoSuchKeyError(CelValue value); // Returns an error indicating that evaluation has accessed an attribute whose // value is undefined. For example, this may represent a field in a proto // message bound to the activation whose value can't be determined by the // hosting application. +CelValue CreateMissingAttributeError(cel::MemoryManager& manager + ABSL_ATTRIBUTE_LIFETIME_BOUND, + absl::string_view missing_attribute_path); +ABSL_DEPRECATED("Prefer using the generic MemoryManager overload") CelValue CreateMissingAttributeError(google::protobuf::Arena* arena, absl::string_view missing_attribute_path); @@ -572,6 +592,10 @@ bool IsMissingAttributeError(const CelValue& value); // Returns error indicating the result of the function is unknown. This is used // as a signal to create an unknown set if unknown function handling is opted // into. +CelValue CreateUnknownFunctionResultError(cel::MemoryManager& manager + ABSL_ATTRIBUTE_LIFETIME_BOUND, + absl::string_view help_message); +ABSL_DEPRECATED("Prefer using the generic MemoryManager overload") CelValue CreateUnknownFunctionResultError(google::protobuf::Arena* arena, absl::string_view help_message); @@ -581,6 +605,13 @@ CelValue CreateUnknownFunctionResultError(google::protobuf::Arena* arena, // into. bool IsUnknownFunctionResult(const CelValue& value); +ABSL_DEPRECATED("This type of error is no longer used by the evaluator.") +CelValue CreateUnknownValueError(google::protobuf::Arena* arena, + absl::string_view unknown_path); + +ABSL_DEPRECATED("This type of error is no longer used by the evaluator.") +bool IsUnknownValueError(const CelValue& value); + } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_VALUE_H_ diff --git a/eval/public/cel_value_test.cc b/eval/public/cel_value_test.cc index 232f0d44c..89955f40d 100644 --- a/eval/public/cel_value_test.cc +++ b/eval/public/cel_value_test.cc @@ -6,13 +6,18 @@ #include "absl/strings/match.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" +#include "base/memory_manager.h" +#include "eval/public/testing/matchers.h" #include "eval/public/unknown_attribute_set.h" #include "eval/public/unknown_set.h" +#include "extensions/protobuf/memory_manager.h" +#include "internal/status_macros.h" #include "internal/testing.h" namespace google::api::expr::runtime { using testing::Eq; +using cel::internal::StatusIs; class DummyMap : public CelMap { public: @@ -272,11 +277,6 @@ TEST(CelValueTest, TestCelType) { CelValue value_unknown = CelValue::CreateUnknownSet(&unknown_set); EXPECT_THAT(value_unknown.type(), Eq(CelValue::Type::kUnknownSet)); EXPECT_TRUE(value_unknown.ObtainCelType().IsUnknownSet()); - - CelValue missing_attribute_error = - CreateMissingAttributeError(&arena, "destination.ip"); - EXPECT_TRUE(IsMissingAttributeError(missing_attribute_error)); - EXPECT_TRUE(missing_attribute_error.ObtainCelType().IsError()); } // This test verifies CelValue support of Unknown type. @@ -294,14 +294,58 @@ TEST(CelValueTest, TestUnknownSet) { EXPECT_THAT(CountTypeMatch(value), Eq(1)); } -TEST(CelValueTest, UnknownFunctionResultErrors) { - ::google::protobuf::Arena arena; +TEST(CelValueTest, SpecialErrorFactories) { + google::protobuf::Arena arena; + cel::extensions::ProtoMemoryManager manager(&arena); + + CelValue error = CreateNoSuchKeyError(manager, "key"); + EXPECT_THAT(error, test::IsCelError(StatusIs(absl::StatusCode::kNotFound))); + EXPECT_TRUE(CheckNoSuchKeyError(error)); + + error = CreateNoSuchFieldError(manager, "field"); + EXPECT_THAT(error, test::IsCelError(StatusIs(absl::StatusCode::kNotFound))); + + error = CreateNoMatchingOverloadError(manager, "function"); + EXPECT_THAT(error, test::IsCelError(StatusIs(absl::StatusCode::kUnknown))); + EXPECT_TRUE(CheckNoMatchingOverloadError(error)); +} + +TEST(CelValueTest, MissingAttributeErrorsDeprecated) { + google::protobuf::Arena arena; + + CelValue missing_attribute_error = + CreateMissingAttributeError(&arena, "destination.ip"); + EXPECT_TRUE(IsMissingAttributeError(missing_attribute_error)); + EXPECT_TRUE(missing_attribute_error.ObtainCelType().IsError()); +} + +TEST(CelValueTest, MissingAttributeErrors) { + google::protobuf::Arena arena; + cel::extensions::ProtoMemoryManager manager(&arena); + + CelValue missing_attribute_error = + CreateMissingAttributeError(manager, "destination.ip"); + EXPECT_TRUE(IsMissingAttributeError(missing_attribute_error)); + EXPECT_TRUE(missing_attribute_error.ObtainCelType().IsError()); +} + +TEST(CelValueTest, UnknownFunctionResultErrorsDeprecated) { + google::protobuf::Arena arena; CelValue value = CreateUnknownFunctionResultError(&arena, "message"); EXPECT_TRUE(value.IsError()); EXPECT_TRUE(IsUnknownFunctionResult(value)); } +TEST(CelValueTest, UnknownFunctionResultErrors) { + google::protobuf::Arena arena; + cel::extensions::ProtoMemoryManager manager(&arena); + + CelValue value = CreateUnknownFunctionResultError(manager, "message"); + EXPECT_TRUE(value.IsError()); + EXPECT_TRUE(IsUnknownFunctionResult(value)); +} + TEST(CelValueTest, DebugString) { EXPECT_EQ(CelValue::CreateNull().DebugString(), "null_type: null"); EXPECT_EQ(CelValue::CreateBool(true).DebugString(), "bool: 1"); From 720f5bd890f4d04c5a282a7589d7b727d421d969 Mon Sep 17 00:00:00 2001 From: jcking Date: Thu, 17 Feb 2022 22:00:11 +0000 Subject: [PATCH 009/111] Internal change PiperOrigin-RevId: 429395768 --- base/memory_manager.cc | 4 ++-- base/memory_manager.h | 5 +++-- base/memory_manager_test.cc | 4 ++-- eval/public/cel_value.cc | 6 +++--- extensions/protobuf/memory_manager.h | 9 ++++----- extensions/protobuf/memory_manager_test.cc | 8 ++++---- 6 files changed, 18 insertions(+), 18 deletions(-) diff --git a/base/memory_manager.cc b/base/memory_manager.cc index 1daff0c08..f10c8b406 100644 --- a/base/memory_manager.cc +++ b/base/memory_manager.cc @@ -39,9 +39,9 @@ class GlobalMemoryManager final : public MemoryManager { } // namespace -MemoryManager* MemoryManager::Global() { +MemoryManager& MemoryManager::Global() { static MemoryManager* const instance = new GlobalMemoryManager(); - return instance; + return *instance; } void MemoryManager::OwnDestructor(void* pointer, void (*destruct)(void*)) { diff --git a/base/memory_manager.h b/base/memory_manager.h index 73cbb2763..5903f23aa 100644 --- a/base/memory_manager.h +++ b/base/memory_manager.h @@ -40,14 +40,15 @@ using ManagedMemory = // different allocation strategies. class MemoryManager { public: - ABSL_ATTRIBUTE_PURE_FUNCTION static MemoryManager* Global(); + ABSL_ATTRIBUTE_PURE_FUNCTION static MemoryManager& Global(); virtual ~MemoryManager() = default; // Allocates and constructs `T`. In the event of an allocation failure nullptr // is returned. template - ManagedMemory New(Args&&... args) ABSL_MUST_USE_RESULT { + ManagedMemory New(Args&&... args) + ABSL_ATTRIBUTE_LIFETIME_BOUND ABSL_MUST_USE_RESULT { size_t size = sizeof(T); size_t align = alignof(T); auto [pointer, owned] = Allocate(size, align); diff --git a/base/memory_manager_test.cc b/base/memory_manager_test.cc index f9d8369a9..6f7a70f13 100644 --- a/base/memory_manager_test.cc +++ b/base/memory_manager_test.cc @@ -26,7 +26,7 @@ struct TriviallyDestructible final {}; TEST(GlobalMemoryManager, TriviallyDestructible) { EXPECT_TRUE(std::is_trivially_destructible_v); - auto managed = MemoryManager::Global()->New(); + auto managed = MemoryManager::Global().New(); EXPECT_FALSE(base_internal::IsEmptyDeleter(managed.get_deleter())); } @@ -38,7 +38,7 @@ struct NotTriviallyDestuctible final { TEST(GlobalMemoryManager, NotTriviallyDestuctible) { EXPECT_FALSE(std::is_trivially_destructible_v); - auto managed = MemoryManager::Global()->New(); + auto managed = MemoryManager::Global().New(); EXPECT_FALSE(base_internal::IsEmptyDeleter(managed.get_deleter())); EXPECT_CALL(*managed, Delete()); } diff --git a/eval/public/cel_value.cc b/eval/public/cel_value.cc index 603f1bd96..d84993e00 100644 --- a/eval/public/cel_value.cc +++ b/eval/public/cel_value.cc @@ -240,7 +240,7 @@ CelValue CreateErrorValue(cel::MemoryManager& manager, absl::StatusCode error_code) { // TODO(issues/5): assume arena-style allocator while migrating to new // value type. - CelError* error = NewInProtoArena(&manager, error_code, message); + CelError* error = NewInProtoArena(manager, error_code, message); return CelValue::CreateError(error); } @@ -338,7 +338,7 @@ CelValue CreateMissingAttributeError(cel::MemoryManager& manager, // TODO(issues/5): assume arena-style allocator while migrating // to new value type. CelError* error = NewInProtoArena( - &manager, absl::StatusCode::kInvalidArgument, + manager, absl::StatusCode::kInvalidArgument, absl::StrCat(kErrMissingAttribute, missing_attribute_path)); error->SetPayload(kPayloadUrlMissingAttributePath, absl::Cord(missing_attribute_path)); @@ -360,7 +360,7 @@ CelValue CreateUnknownFunctionResultError(cel::MemoryManager& manager, // TODO(issues/5): Assume arena-style allocation until new value type is // introduced CelError* error = NewInProtoArena( - &manager, absl::StatusCode::kUnavailable, + manager, absl::StatusCode::kUnavailable, absl::StrCat("Unknown function result: ", help_message)); error->SetPayload(kPayloadUrlUnknownFunctionResult, absl::Cord("true")); return CelValue::CreateError(error); diff --git a/extensions/protobuf/memory_manager.h b/extensions/protobuf/memory_manager.h index 56d88aee6..d13e94bd9 100644 --- a/extensions/protobuf/memory_manager.h +++ b/extensions/protobuf/memory_manager.h @@ -42,7 +42,7 @@ class ProtoMemoryManager final : public ArenaMemoryManager { ProtoMemoryManager& operator=(ProtoMemoryManager&&) = delete; - google::protobuf::Arena* arena() const { return arena_; } + constexpr google::protobuf::Arena* arena() const { return arena_; } private: AllocationResult Allocate(size_t size, size_t align) override; @@ -59,14 +59,13 @@ class ProtoMemoryManager final : public ArenaMemoryManager { // is undefined. Unlike `MemoryManager::New`, this method supports arena-enabled // messages. template -ABSL_MUST_USE_RESULT T* NewInProtoArena(MemoryManager* memory_manager, +ABSL_MUST_USE_RESULT T* NewInProtoArena(MemoryManager& memory_manager, Args&&... args) { - ABSL_ASSERT(memory_manager != nullptr); #if !defined(__GNUC__) || defined(__GXX_RTTI) - ABSL_ASSERT(dynamic_cast(memory_manager) != nullptr); + ABSL_ASSERT(dynamic_cast(&memory_manager) != nullptr); #endif return google::protobuf::Arena::Create( - static_cast(memory_manager)->arena(), + static_cast(memory_manager).arena(), std::forward(args)...); } diff --git a/extensions/protobuf/memory_manager_test.cc b/extensions/protobuf/memory_manager_test.cc index 7d9170598..0db014f2d 100644 --- a/extensions/protobuf/memory_manager_test.cc +++ b/extensions/protobuf/memory_manager_test.cc @@ -32,7 +32,7 @@ TEST(ProtoMemoryManager, ArenaConstructable) { ProtoMemoryManager memory_manager(&arena); EXPECT_TRUE( google::protobuf::Arena::is_arena_constructable::value); - auto* object = NewInProtoArena(&memory_manager); + auto* object = NewInProtoArena(memory_manager); EXPECT_NE(object, nullptr); } @@ -41,7 +41,7 @@ TEST(ProtoMemoryManager, NotArenaConstructable) { ProtoMemoryManager memory_manager(&arena); EXPECT_FALSE( google::protobuf::Arena::is_arena_constructable::value); - auto* object = NewInProtoArena(&memory_manager); + auto* object = NewInProtoArena(memory_manager); EXPECT_NE(object, nullptr); EXPECT_CALL(*object, Delete()); } @@ -50,7 +50,7 @@ TEST(ProtoMemoryManagerNoArena, ArenaConstructable) { ProtoMemoryManager memory_manager(nullptr); EXPECT_TRUE( google::protobuf::Arena::is_arena_constructable::value); - auto* object = NewInProtoArena(&memory_manager); + auto* object = NewInProtoArena(memory_manager); EXPECT_NE(object, nullptr); delete object; } @@ -59,7 +59,7 @@ TEST(ProtoMemoryManagerNoArena, NotArenaConstructable) { ProtoMemoryManager memory_manager(nullptr); EXPECT_FALSE( google::protobuf::Arena::is_arena_constructable::value); - auto* object = NewInProtoArena(&memory_manager); + auto* object = NewInProtoArena(memory_manager); EXPECT_NE(object, nullptr); EXPECT_CALL(*object, Delete()); delete object; From 3392c9688d922057202469e32831bfd95acb9b55 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 23 Feb 2022 17:35:53 +0000 Subject: [PATCH 010/111] Move internal usages of proto2::arena to cel::MemoryManager where possible. This helps move toward making proto2::arena an optional dependency and enabling alternative memory management strategies. PiperOrigin-RevId: 430473093 --- eval/eval/BUILD | 23 ++++++++- eval/eval/attribute_trail.cc | 15 ++++-- eval/eval/attribute_trail.h | 15 +++--- eval/eval/attribute_trail_test.cc | 14 ++++-- eval/eval/attribute_utility.cc | 7 +-- eval/eval/attribute_utility.h | 35 ++++++++++++-- eval/eval/attribute_utility_test.cc | 57 ++++++++++++++++------ eval/eval/comprehension_step.cc | 8 ++-- eval/eval/container_access_step.cc | 47 ++++++++++--------- eval/eval/create_list_step.cc | 12 +++-- eval/eval/create_struct_step.cc | 38 +++++++++------ eval/eval/evaluator_core.cc | 3 +- eval/eval/evaluator_core.h | 17 +++++-- eval/eval/evaluator_core_test.cc | 7 ++- eval/eval/evaluator_stack_test.cc | 5 +- eval/eval/function_step.cc | 35 +++++++------- eval/eval/ident_step.cc | 40 ++++++++++------ eval/eval/jump_step.cc | 4 +- eval/eval/logic_step.cc | 2 +- eval/eval/select_step.cc | 70 ++++++++++++++++------------ eval/eval/shadowable_value_step.cc | 10 +++- eval/eval/ternary_step.cc | 3 +- eval/public/BUILD | 1 + eval/public/activation_test.cc | 9 ++-- eval/public/cel_expression.h | 2 +- eval/public/unknown_attribute_set.h | 3 +- extensions/protobuf/BUILD | 3 ++ extensions/protobuf/memory_manager.h | 12 +++++ 28 files changed, 335 insertions(+), 162 deletions(-) diff --git a/eval/eval/BUILD b/eval/eval/BUILD index 6368811fa..2456b7492 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -18,13 +18,16 @@ cc_library( ":attribute_trail", ":attribute_utility", ":evaluator_stack", + "//base:memory_manager", "//eval/public:base_activation", "//eval/public:cel_attribute", "//eval/public:cel_expression", "//eval/public:cel_value", "//eval/public:unknown_attribute_set", + "//extensions/protobuf:memory_manager", "//internal:casts", "//internal:status_macros", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -57,6 +60,7 @@ cc_test( ], deps = [ ":evaluator_stack", + "//extensions/protobuf:memory_manager", "//internal:testing", ], ) @@ -98,13 +102,13 @@ cc_library( deps = [ ":evaluator_core", ":expression_step_base", + "//base:memory_manager", "//eval/public:cel_value", "//eval/public:unknown_attribute_set", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", - "@com_google_protobuf//:protobuf", ], ) @@ -121,6 +125,8 @@ cc_library( ":evaluator_core", ":expression_step_base", "//eval/public:unknown_attribute_set", + "//extensions/protobuf:memory_manager", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", @@ -148,6 +154,7 @@ cc_library( "//eval/public:unknown_attribute_set", "//eval/public:unknown_function_result_set", "//eval/public:unknown_set", + "//extensions/protobuf:memory_manager", "//internal:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -173,6 +180,8 @@ cc_library( "//eval/public/containers:field_access", "//eval/public/containers:field_backed_list_impl", "//eval/public/containers:field_backed_map_impl", + "//extensions/protobuf:memory_manager", + "//internal:status_macros", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -214,6 +223,7 @@ cc_library( "//eval/public/containers:container_backed_map_impl", "//eval/public/containers:field_access", "//eval/public/structs:cel_proto_wrapper", + "//extensions/protobuf:memory_manager", "//internal:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -319,6 +329,7 @@ cc_test( "//eval/public:builtin_func_registrar", "//eval/public:cel_attribute", "//eval/public:cel_value", + "//extensions/protobuf:memory_manager", "//internal:status_macros", "//internal:testing", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", @@ -535,6 +546,7 @@ cc_library( srcs = ["attribute_trail.cc"], hdrs = ["attribute_trail.h"], deps = [ + "//base:memory_manager", "//eval/public:cel_attribute", "//eval/public:cel_expression", "//eval/public:cel_value", @@ -556,6 +568,7 @@ cc_test( ":attribute_trail", "//eval/public:cel_attribute", "//eval/public:cel_value", + "//extensions/protobuf:memory_manager", "//internal:testing", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", ], @@ -567,9 +580,13 @@ cc_library( hdrs = ["attribute_utility.h"], deps = [ ":attribute_trail", + "//base:logging", + "//base:memory_manager", "//eval/public:cel_attribute", + "//eval/public:cel_function", "//eval/public:cel_value", "//eval/public:unknown_attribute_set", + "//eval/public:unknown_function_result_set", "//eval/public:unknown_set", "@com_google_absl//absl/status", "@com_google_absl//absl/types:optional", @@ -586,10 +603,12 @@ cc_test( ], deps = [ ":attribute_utility", + "//base:memory_manager", "//eval/public:cel_attribute", "//eval/public:cel_value", "//eval/public:unknown_attribute_set", "//eval/public:unknown_set", + "//extensions/protobuf:memory_manager", "//internal:testing", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", ], @@ -640,6 +659,8 @@ cc_library( ":evaluator_core", ":expression_step_base", "//eval/public:cel_value", + "//extensions/protobuf:memory_manager", + "//internal:status_macros", "@com_google_absl//absl/status:statusor", ], ) diff --git a/eval/eval/attribute_trail.cc b/eval/eval/attribute_trail.cc index 42ec8a5b3..7a604e37a 100644 --- a/eval/eval/attribute_trail.cc +++ b/eval/eval/attribute_trail.cc @@ -3,20 +3,29 @@ #include #include "absl/status/status.h" +#include "eval/public/cel_attribute.h" #include "eval/public/cel_value.h" namespace google::api::expr::runtime { +AttributeTrail::AttributeTrail(Expr root, cel::MemoryManager& manager) { + attribute_ = manager + .New(std::move(root), + std::vector()) + .release(); +} + // Creates AttributeTrail with attribute path incremented by "qualifier". AttributeTrail AttributeTrail::Step(CelAttributeQualifier qualifier, - google::protobuf::Arena* arena) const { + cel::MemoryManager& manager) const { // Cannot continue void trail if (empty()) return AttributeTrail(); std::vector qualifiers = attribute_->qualifier_path(); qualifiers.push_back(qualifier); - return AttributeTrail(google::protobuf::Arena::Create( - arena, attribute_->variable(), std::move(qualifiers))); + auto attribute = + manager.New(attribute_->variable(), std::move(qualifiers)); + return AttributeTrail(attribute.release()); } } // namespace google::api::expr::runtime diff --git a/eval/eval/attribute_trail.h b/eval/eval/attribute_trail.h index 96a75097a..38df44b2c 100644 --- a/eval/eval/attribute_trail.h +++ b/eval/eval/attribute_trail.h @@ -2,11 +2,13 @@ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_ATTRIBUTE_TRAIL_H_ #include +#include #include #include "google/api/expr/v1alpha1/syntax.pb.h" #include "google/protobuf/arena.h" #include "absl/types/optional.h" +#include "base/memory_manager.h" #include "eval/public/cel_attribute.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_value.h" @@ -26,26 +28,25 @@ namespace google::api::expr::runtime { class AttributeTrail { public: AttributeTrail() : attribute_(nullptr) {} - AttributeTrail(google::api::expr::v1alpha1::Expr root, google::protobuf::Arena* arena) - : AttributeTrail(google::protobuf::Arena::Create( - arena, std::move(root), std::vector())) {} + + AttributeTrail(Expr root, cel::MemoryManager& manager); // Creates AttributeTrail with attribute path incremented by "qualifier". AttributeTrail Step(CelAttributeQualifier qualifier, - google::protobuf::Arena* arena) const; + cel::MemoryManager& manager) const; // Creates AttributeTrail with attribute path incremented by "qualifier". AttributeTrail Step(const std::string* qualifier, - google::protobuf::Arena* arena) const { + cel::MemoryManager& manager) const { return Step( CelAttributeQualifier::Create(CelValue::CreateString(qualifier)), - arena); + manager); } // Returns CelAttribute that corresponds to content of AttributeTrail. const CelAttribute* attribute() const { return attribute_; } - bool empty() const { return !attribute_; } + bool empty() const { return attribute_ == nullptr; } private: explicit AttributeTrail(const CelAttribute* attribute) diff --git a/eval/eval/attribute_trail_test.cc b/eval/eval/attribute_trail_test.cc index 09d0e5508..adb982860 100644 --- a/eval/eval/attribute_trail_test.cc +++ b/eval/eval/attribute_trail_test.cc @@ -5,30 +5,36 @@ #include "google/api/expr/v1alpha1/syntax.pb.h" #include "eval/public/cel_attribute.h" #include "eval/public/cel_value.h" +#include "extensions/protobuf/memory_manager.h" #include "internal/testing.h" namespace google::api::expr::runtime { -using google::api::expr::v1alpha1::Expr; +using ::cel::extensions::ProtoMemoryManager; +using ::google::api::expr::v1alpha1::Expr; // Attribute Trail behavior TEST(AttributeTrailTest, AttributeTrailEmptyStep) { google::protobuf::Arena arena; + ProtoMemoryManager manager(&arena); + std::string step = "step"; CelValue step_value = CelValue::CreateString(&step); AttributeTrail trail; - ASSERT_TRUE(trail.Step(&step, &arena).empty()); + ASSERT_TRUE(trail.Step(&step, manager).empty()); ASSERT_TRUE( - trail.Step(CelAttributeQualifier::Create(step_value), &arena).empty()); + trail.Step(CelAttributeQualifier::Create(step_value), manager).empty()); } TEST(AttributeTrailTest, AttributeTrailStep) { google::protobuf::Arena arena; + ProtoMemoryManager manager(&arena); + std::string step = "step"; CelValue step_value = CelValue::CreateString(&step); Expr root; root.mutable_ident_expr()->set_name("ident"); - AttributeTrail trail = AttributeTrail(root, &arena).Step(&step, &arena); + AttributeTrail trail = AttributeTrail(root, manager).Step(&step, manager); ASSERT_TRUE(trail.attribute() != nullptr); ASSERT_EQ(*trail.attribute(), diff --git a/eval/eval/attribute_utility.cc b/eval/eval/attribute_utility.cc index 8cd1bf140..69e7813e0 100644 --- a/eval/eval/attribute_utility.cc +++ b/eval/eval/attribute_utility.cc @@ -59,7 +59,7 @@ const UnknownSet* AttributeUtility::MergeUnknowns( if (result == nullptr) { result = current_set; } else { - result = Arena::Create(arena_, *result, *current_set); + result = memory_manager_.New(*result, *current_set).release(); } } @@ -97,9 +97,10 @@ const UnknownSet* AttributeUtility::MergeUnknowns( if (!attr_set.attributes().empty()) { if (initial_set != nullptr) { initial_set = - Arena::Create(arena_, *initial_set, UnknownSet(attr_set)); + memory_manager_.New(*initial_set, UnknownSet(attr_set)) + .release(); } else { - initial_set = Arena::Create(arena_, attr_set); + initial_set = memory_manager_.New(attr_set).release(); } } return MergeUnknowns(args, initial_set); diff --git a/eval/eval/attribute_utility.h b/eval/eval/attribute_utility.h index 0b76387db..79f069215 100644 --- a/eval/eval/attribute_utility.h +++ b/eval/eval/attribute_utility.h @@ -3,13 +3,17 @@ #include +#include "base/logging.h" #include "google/protobuf/arena.h" #include "absl/types/optional.h" #include "absl/types/span.h" +#include "base/memory_manager.h" #include "eval/eval/attribute_trail.h" #include "eval/public/cel_attribute.h" +#include "eval/public/cel_function.h" #include "eval/public/cel_value.h" #include "eval/public/unknown_attribute_set.h" +#include "eval/public/unknown_function_result_set.h" #include "eval/public/unknown_set.h" namespace google::api::expr::runtime { @@ -18,15 +22,21 @@ namespace google::api::expr::runtime { // helpers for merging unknown sets from arguments on the stack and for // identifying unknown/missing attributes based on the patterns for a given // Evaluation. +// Neither moveable nor copyable. class AttributeUtility { public: AttributeUtility( const std::vector* unknown_patterns, const std::vector* missing_attribute_patterns, - google::protobuf::Arena* arena) + cel::MemoryManager& manager) : unknown_patterns_(unknown_patterns), missing_attribute_patterns_(missing_attribute_patterns), - arena_(arena) {} + memory_manager_(manager) {} + + AttributeUtility(const AttributeUtility&) = delete; + AttributeUtility& operator=(const AttributeUtility&) = delete; + AttributeUtility(AttributeUtility&&) = delete; + AttributeUtility& operator=(AttributeUtility&&) = delete; // Checks whether particular corresponds to any patterns that define missing // attribute. @@ -59,10 +69,29 @@ class AttributeUtility { const UnknownSet* initial_set, bool use_partial) const; + // Create an initial UnknownSet from a single attribute. + const UnknownSet* CreateUnknownSet(const CelAttribute* attr) const { + return memory_manager_.New(UnknownAttributeSet({attr})) + .release(); + } + + // Create an initial UnknownSet from a single missing function call. + const UnknownSet* CreateUnknownSet(const CelFunctionDescriptor& fn_descriptor, + int64_t expr_id, + absl::Span args) const { + auto* fn = memory_manager_ + .New( + fn_descriptor, expr_id, + std::vector(args.begin(), args.end())) + .release(); + return memory_manager_.New(UnknownFunctionResultSet(fn)) + .release(); + } + private: const std::vector* unknown_patterns_; const std::vector* missing_attribute_patterns_; - google::protobuf::Arena* arena_; + cel::MemoryManager& memory_manager_; }; } // namespace google::api::expr::runtime diff --git a/eval/eval/attribute_utility_test.cc b/eval/eval/attribute_utility_test.cc index 4c70ebef1..fc80fd2ab 100644 --- a/eval/eval/attribute_utility_test.cc +++ b/eval/eval/attribute_utility_test.cc @@ -5,10 +5,12 @@ #include "eval/public/cel_value.h" #include "eval/public/unknown_attribute_set.h" #include "eval/public/unknown_set.h" +#include "extensions/protobuf/memory_manager.h" #include "internal/testing.h" namespace google::api::expr::runtime { +using ::cel::extensions::ProtoMemoryManager; using ::google::api::expr::v1alpha1::Expr; using testing::Eq; using testing::NotNull; @@ -17,6 +19,7 @@ using testing::UnorderedPointwise; TEST(UnknownsUtilityTest, UnknownsUtilityCheckUnknowns) { google::protobuf::Arena arena; + ProtoMemoryManager manager(&arena); std::vector unknown_patterns = { CelAttributePattern("unknown0", {CelAttributeQualifierPattern::Create( CelValue::CreateInt64(1))}), @@ -29,7 +32,7 @@ TEST(UnknownsUtilityTest, UnknownsUtilityCheckUnknowns) { std::vector missing_attribute_patterns; AttributeUtility utility(&unknown_patterns, &missing_attribute_patterns, - &arena); + manager); // no match for void trail ASSERT_FALSE(utility.CheckForUnknown(AttributeTrail(), true)); ASSERT_FALSE(utility.CheckForUnknown(AttributeTrail(), false)); @@ -37,7 +40,7 @@ TEST(UnknownsUtilityTest, UnknownsUtilityCheckUnknowns) { google::api::expr::v1alpha1::Expr unknown_expr0; unknown_expr0.mutable_ident_expr()->set_name("unknown0"); - AttributeTrail unknown_trail0(unknown_expr0, &arena); + AttributeTrail unknown_trail0(unknown_expr0, manager); { ASSERT_FALSE(utility.CheckForUnknown(unknown_trail0, false)); } @@ -46,20 +49,21 @@ TEST(UnknownsUtilityTest, UnknownsUtilityCheckUnknowns) { { ASSERT_TRUE(utility.CheckForUnknown( unknown_trail0.Step( - CelAttributeQualifier::Create(CelValue::CreateInt64(1)), &arena), + CelAttributeQualifier::Create(CelValue::CreateInt64(1)), manager), false)); } { ASSERT_TRUE(utility.CheckForUnknown( unknown_trail0.Step( - CelAttributeQualifier::Create(CelValue::CreateInt64(1)), &arena), + CelAttributeQualifier::Create(CelValue::CreateInt64(1)), manager), true)); } } TEST(UnknownsUtilityTest, UnknownsUtilityMergeUnknownsFromValues) { google::protobuf::Arena arena; + ProtoMemoryManager manager(&arena); google::api::expr::v1alpha1::Expr unknown_expr0; unknown_expr0.mutable_ident_expr()->set_name("unknown0"); @@ -79,7 +83,7 @@ TEST(UnknownsUtilityTest, UnknownsUtilityMergeUnknownsFromValues) { CelAttribute attribute2(unknown_expr2, {}); AttributeUtility utility(&unknown_patterns, &missing_attribute_patterns, - &arena); + manager); UnknownSet unknown_set0(UnknownAttributeSet({&attribute0})); UnknownSet unknown_set1(UnknownAttributeSet({&attribute1})); @@ -107,6 +111,7 @@ TEST(UnknownsUtilityTest, UnknownsUtilityMergeUnknownsFromValues) { TEST(UnknownsUtilityTest, UnknownsUtilityCheckForUnknownsFromAttributes) { google::protobuf::Arena arena; + ProtoMemoryManager manager(&arena); std::vector unknown_patterns = { CelAttributePattern("unknown0", @@ -121,22 +126,22 @@ TEST(UnknownsUtilityTest, UnknownsUtilityCheckForUnknownsFromAttributes) { google::api::expr::v1alpha1::Expr unknown_expr1; unknown_expr1.mutable_ident_expr()->set_name("unknown1"); - AttributeTrail trail0(unknown_expr0, &arena); - AttributeTrail trail1(unknown_expr1, &arena); + AttributeTrail trail0(unknown_expr0, manager); + AttributeTrail trail1(unknown_expr1, manager); CelAttribute attribute1(unknown_expr1, {}); UnknownSet unknown_set1(UnknownAttributeSet({&attribute1})); AttributeUtility utility(&unknown_patterns, &missing_attribute_patterns, - &arena); + manager); UnknownSet unknown_attr_set(utility.CheckForUnknowns( { AttributeTrail(), // To make sure we handle empty trail gracefully. trail0.Step(CelAttributeQualifier::Create(CelValue::CreateInt64(1)), - &arena), + manager), trail0.Step(CelAttributeQualifier::Create(CelValue::CreateInt64(2)), - &arena), + manager), }, false)); @@ -147,6 +152,7 @@ TEST(UnknownsUtilityTest, UnknownsUtilityCheckForUnknownsFromAttributes) { TEST(UnknownsUtilityTest, UnknownsUtilityCheckForMissingAttributes) { google::protobuf::Arena arena; + ProtoMemoryManager manager(&arena); std::vector unknown_patterns; @@ -159,12 +165,12 @@ TEST(UnknownsUtilityTest, UnknownsUtilityCheckForMissingAttributes) { Expr* ident_expr = select_expr->mutable_operand(); ident_expr->mutable_ident_expr()->set_name("destination"); - AttributeTrail trail(*ident_expr, &arena); + AttributeTrail trail(*ident_expr, manager); trail = trail.Step( - CelAttributeQualifier::Create(CelValue::CreateStringView("ip")), &arena); + CelAttributeQualifier::Create(CelValue::CreateStringView("ip")), manager); AttributeUtility utility0(&unknown_patterns, &missing_attribute_patterns, - &arena); + manager); EXPECT_FALSE(utility0.CheckForMissingAttribute(trail)); missing_attribute_patterns.push_back(CelAttributePattern( @@ -172,8 +178,31 @@ TEST(UnknownsUtilityTest, UnknownsUtilityCheckForMissingAttributes) { CelValue::CreateStringView("ip"))})); AttributeUtility utility1(&unknown_patterns, &missing_attribute_patterns, - &arena); + manager); EXPECT_TRUE(utility1.CheckForMissingAttribute(trail)); } +TEST(AttributeUtilityTest, CreateUnknownSet) { + google::protobuf::Arena arena; + ProtoMemoryManager manager(&arena); + + Expr expr; + auto* select_expr = expr.mutable_select_expr(); + select_expr->set_field("ip"); + + Expr* ident_expr = select_expr->mutable_operand(); + ident_expr->mutable_ident_expr()->set_name("destination"); + + AttributeTrail trail(*ident_expr, manager); + trail = trail.Step( + CelAttributeQualifier::Create(CelValue::CreateStringView("ip")), manager); + + std::vector empty_patterns; + AttributeUtility utility(&empty_patterns, &empty_patterns, manager); + + const UnknownSet* set = utility.CreateUnknownSet(trail.attribute()); + EXPECT_EQ(*set->unknown_attributes().attributes().at(0)->AsString(), + "destination.ip"); +} + } // namespace google::api::expr::runtime diff --git a/eval/eval/comprehension_step.cc b/eval/eval/comprehension_step.cc index 88ab97f26..64b98f058 100644 --- a/eval/eval/comprehension_step.cc +++ b/eval/eval/comprehension_step.cc @@ -95,7 +95,7 @@ absl::Status ComprehensionNextStep::Evaluate(ExecutionFrame* frame) const { return frame->JumpTo(error_jump_offset_); } frame->value_stack().Push( - CreateNoMatchingOverloadError(frame->arena(), "")); + CreateNoMatchingOverloadError(frame->memory_manager(), "")); return frame->JumpTo(error_jump_offset_); } const CelList* cel_list = iter_range.ListOrDie(); @@ -131,7 +131,7 @@ absl::Status ComprehensionNextStep::Evaluate(ExecutionFrame* frame) const { frame->value_stack().Push(CelValue::CreateInt64(current_index)); auto iter_trail = iter_range_attr.Step( CelAttributeQualifier::Create(CelValue::CreateInt64(current_index)), - frame->arena()); + frame->memory_manager()); frame->value_stack().Push(current_value, iter_trail); CEL_RETURN_IF_ERROR(frame->SetIterVar(current_value, iter_trail)); return absl::OkStatus(); @@ -168,8 +168,8 @@ absl::Status ComprehensionCondStep::Evaluate(ExecutionFrame* frame) const { if (loop_condition_value.IsError() || loop_condition_value.IsUnknownSet()) { frame->value_stack().Push(loop_condition_value); } else { - frame->value_stack().Push( - CreateNoMatchingOverloadError(frame->arena(), "")); + frame->value_stack().Push(CreateNoMatchingOverloadError( + frame->memory_manager(), "")); } // The error jump skips the ComprehensionFinish clean-up step, so we // need to update the iteration variable stack here. diff --git a/eval/eval/container_access_step.cc b/eval/eval/container_access_step.cc index ffd01d99b..51ba17ac8 100644 --- a/eval/eval/container_access_step.cc +++ b/eval/eval/container_access_step.cc @@ -2,10 +2,10 @@ #include -#include "google/protobuf/arena.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "base/memory_manager.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" #include "eval/public/cel_value.h" @@ -30,33 +30,33 @@ class ContainerAccessStep : public ExpressionStepBase { ValueAttributePair PerformLookup(ExecutionFrame* frame) const; CelValue LookupInMap(const CelMap* cel_map, const CelValue& key, - google::protobuf::Arena* arena) const; + cel::MemoryManager& manager) const; CelValue LookupInList(const CelList* cel_list, const CelValue& key, - google::protobuf::Arena* arena) const; + cel::MemoryManager& manager) const; }; -inline CelValue ContainerAccessStep::LookupInMap(const CelMap* cel_map, - const CelValue& key, - google::protobuf::Arena* arena) const { +inline CelValue ContainerAccessStep::LookupInMap( + const CelMap* cel_map, const CelValue& key, + cel::MemoryManager& manager) const { auto status = CelValue::CheckMapKeyType(key); if (!status.ok()) { - return CreateErrorValue(arena, status); + return CreateErrorValue(manager, status); } absl::optional maybe_value = (*cel_map)[key]; if (maybe_value.has_value()) { return maybe_value.value(); } - return CreateNoSuchKeyError(arena, "Key not found in map"); + return CreateNoSuchKeyError(manager, "Key not found in map"); } -inline CelValue ContainerAccessStep::LookupInList(const CelList* cel_list, - const CelValue& key, - google::protobuf::Arena* arena) const { +inline CelValue ContainerAccessStep::LookupInList( + const CelList* cel_list, const CelValue& key, + cel::MemoryManager& manager) const { switch (key.type()) { case CelValue::Type::kInt64: { int64_t idx = key.Int64OrDie(); if (idx < 0 || idx >= cel_list->size()) { - return CreateErrorValue(arena, + return CreateErrorValue(manager, absl::StrCat("Index error: index=", idx, " size=", cel_list->size())); } @@ -64,8 +64,8 @@ inline CelValue ContainerAccessStep::LookupInList(const CelList* cel_list, } default: { return CreateErrorValue( - arena, absl::StrCat("Index error: expected integer type, got ", - CelValue::TypeName(key.type()))); + manager, absl::StrCat("Index error: expected integer type, got ", + CelValue::TypeName(key.type()))); } } } @@ -92,12 +92,12 @@ ContainerAccessStep::ValueAttributePair ContainerAccessStep::PerformLookup( frame->value_stack().GetAttributeSpan(kNumContainerAccessArguments); auto container_trail = input_attrs[0]; trail = container_trail.Step(CelAttributeQualifier::Create(key), - frame->arena()); + frame->memory_manager()); if (frame->attribute_utility().CheckForUnknown(trail, /*use_partial=*/false)) { - auto unknown_set = google::protobuf::Arena::Create( - frame->arena(), UnknownAttributeSet({trail.attribute()})); + auto unknown_set = + frame->attribute_utility().CreateUnknownSet(trail.attribute()); return {CelValue::CreateUnknownSet(unknown_set), trail}; } @@ -113,17 +113,18 @@ ContainerAccessStep::ValueAttributePair ContainerAccessStep::PerformLookup( switch (container.type()) { case CelValue::Type::kMap: { const CelMap* cel_map = container.MapOrDie(); - return {LookupInMap(cel_map, key, frame->arena()), trail}; + return {LookupInMap(cel_map, key, frame->memory_manager()), trail}; } case CelValue::Type::kList: { const CelList* cel_list = container.ListOrDie(); - return {LookupInList(cel_list, key, frame->arena()), trail}; + return {LookupInList(cel_list, key, frame->memory_manager()), trail}; } default: { - auto error = CreateErrorValue( - frame->arena(), absl::InvalidArgumentError(absl::StrCat( - "Invalid container type: '", - CelValue::TypeName(container.type()), "'"))); + auto error = + CreateErrorValue(frame->memory_manager(), + absl::InvalidArgumentError(absl::StrCat( + "Invalid container type: '", + CelValue::TypeName(container.type()), "'"))); return {error, trail}; } } diff --git a/eval/eval/create_list_step.cc b/eval/eval/create_list_step.cc index 2567350c9..721743d12 100644 --- a/eval/eval/create_list_step.cc +++ b/eval/eval/create_list_step.cc @@ -65,11 +65,15 @@ absl::Status CreateListStep::Evaluate(ExecutionFrame* frame) const { CelList* cel_list; if (immutable_) { - cel_list = google::protobuf::Arena::Create( - frame->arena(), std::vector(args.begin(), args.end())); + cel_list = frame->memory_manager() + .New( + std::vector(args.begin(), args.end())) + .release(); } else { - cel_list = google::protobuf::Arena::Create( - frame->arena(), std::vector(args.begin(), args.end())); + cel_list = frame->memory_manager() + .New( + std::vector(args.begin(), args.end())) + .release(); } result = CelValue::CreateList(cel_list); frame->value_stack().Pop(list_size_); diff --git a/eval/eval/create_struct_step.cc b/eval/eval/create_struct_step.cc index 4cbad64bc..786e807ca 100644 --- a/eval/eval/create_struct_step.cc +++ b/eval/eval/create_struct_step.cc @@ -5,6 +5,7 @@ #include #include "google/api/expr/v1alpha1/syntax.pb.h" +#include "google/protobuf/arena.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -13,12 +14,14 @@ #include "eval/public/containers/container_backed_map_impl.h" #include "eval/public/containers/field_access.h" #include "eval/public/structs/cel_proto_wrapper.h" +#include "extensions/protobuf/memory_manager.h" #include "internal/status_macros.h" namespace google::api::expr::runtime { namespace { +using ::cel::extensions::ProtoMemoryManager; using ::google::protobuf::Descriptor; using ::google::protobuf::FieldDescriptor; using ::google::protobuf::Message; @@ -64,6 +67,10 @@ absl::Status CreateStructStepForMessage::DoEvaluate(ExecutionFrame* frame, absl::Span args = frame->value_stack().GetSpan(entries_size); + // This implementation requires arena-backed memory manager. + google::protobuf::Arena* arena = + ProtoMemoryManager::CastToProtoArena(frame->memory_manager()); + if (frame->enable_unknowns()) { auto unknown_set = frame->attribute_utility().MergeUnknowns( args, frame->value_stack().GetAttributeSpan(entries_size), @@ -78,12 +85,11 @@ absl::Status CreateStructStepForMessage::DoEvaluate(ExecutionFrame* frame, const Message* prototype = frame->message_factory()->GetPrototype(descriptor_); - Message* msg = - (prototype != nullptr) ? prototype->New(frame->arena()) : nullptr; + Message* msg = (prototype != nullptr) ? prototype->New(arena) : nullptr; if (msg == nullptr) { *result = CreateErrorValue( - frame->arena(), + frame->memory_manager(), absl::Substitute("Failed to create message $0", descriptor_->name())); return absl::OkStatus(); } @@ -149,13 +155,13 @@ absl::Status CreateStructStepForMessage::DoEvaluate(ExecutionFrame* frame, } Message* entry_msg = msg->GetReflection()->AddMessage(msg, entry.field); - status = SetValueToSingleField(key, key_field_descriptor, entry_msg, - frame->arena()); + status = + SetValueToSingleField(key, key_field_descriptor, entry_msg, arena); if (!status.ok()) { break; } status = SetValueToSingleField(value.value(), value_field_descriptor, - entry_msg, frame->arena()); + entry_msg, arena); if (!status.ok()) { break; } @@ -165,7 +171,7 @@ absl::Status CreateStructStepForMessage::DoEvaluate(ExecutionFrame* frame, const CelList* cel_list; if (!arg.GetValue(&cel_list) || cel_list == nullptr) { *result = CreateErrorValue( - frame->arena(), + frame->memory_manager(), absl::Substitute( "Failed to create message $0: value $1 is not CelList", descriptor_->name(), entry.field->name())); @@ -173,24 +179,24 @@ absl::Status CreateStructStepForMessage::DoEvaluate(ExecutionFrame* frame, } for (int i = 0; i < cel_list->size(); i++) { - status = AddValueToRepeatedField((*cel_list)[i], entry.field, msg, - frame->arena()); + status = + AddValueToRepeatedField((*cel_list)[i], entry.field, msg, arena); if (!status.ok()) break; } } else { - status = SetValueToSingleField(arg, entry.field, msg, frame->arena()); + status = SetValueToSingleField(arg, entry.field, msg, arena); } if (!status.ok()) { *result = CreateErrorValue( - frame->arena(), + frame->memory_manager(), absl::Substitute("Failed to create message $0: reason $1", descriptor_->name(), status.ToString())); return absl::OkStatus(); } } - *result = CelProtoWrapper::CreateMessage(msg, frame->arena()); + *result = CelProtoWrapper::CreateMessage(msg, arena); return absl::OkStatus(); } @@ -237,7 +243,7 @@ absl::Status CreateStructStepForMap::DoEvaluate(ExecutionFrame* frame, CreateContainerBackedMap(absl::Span>( map_entries.data(), map_entries.size())); if (!cel_map.ok()) { - *result = CreateErrorValue(frame->arena(), cel_map.status()); + *result = CreateErrorValue(frame->memory_manager(), cel_map.status()); return absl::OkStatus(); } @@ -245,7 +251,11 @@ absl::Status CreateStructStepForMap::DoEvaluate(ExecutionFrame* frame, *result = CelValue::CreateMap(cel_map_ptr.get()); // Pass object ownership to Arena. - frame->arena()->Own(cel_map_ptr.release()); + // TODO(issues/5): Update CEL map implementation to tolerate generic + // allocation api. + google::protobuf::Arena* arena = + ProtoMemoryManager::CastToProtoArena(frame->memory_manager()); + arena->Own(cel_map_ptr.release()); return absl::OkStatus(); } diff --git a/eval/eval/evaluator_core.cc b/eval/eval/evaluator_core.cc index df64324e4..92b13aca3 100644 --- a/eval/eval/evaluator_core.cc +++ b/eval/eval/evaluator_core.cc @@ -8,6 +8,7 @@ #include "absl/types/optional.h" #include "eval/eval/attribute_trail.h" #include "eval/public/cel_value.h" +#include "extensions/protobuf/memory_manager.h" #include "internal/casts.h" #include "internal/status_macros.h" @@ -27,7 +28,7 @@ CelExpressionFlatEvaluationState::CelExpressionFlatEvaluationState( google::protobuf::Arena* arena) : value_stack_(value_stack_size), iter_variable_names_(iter_variable_names), - arena_(arena) {} + memory_manager_(arena) {} void CelExpressionFlatEvaluationState::Reset() { iter_stack_.clear(); diff --git a/eval/eval/evaluator_core.h b/eval/eval/evaluator_core.h index 8c29574af..a59e87a75 100644 --- a/eval/eval/evaluator_core.h +++ b/eval/eval/evaluator_core.h @@ -15,12 +15,14 @@ #include "google/api/expr/v1alpha1/syntax.pb.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" +#include "absl/base/attributes.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" +#include "base/memory_manager.h" #include "eval/eval/attribute_trail.h" #include "eval/eval/attribute_utility.h" #include "eval/eval/evaluator_stack.h" @@ -29,6 +31,7 @@ #include "eval/public/cel_expression.h" #include "eval/public/cel_value.h" #include "eval/public/unknown_attribute_set.h" +#include "extensions/protobuf/memory_manager.h" namespace google::api::expr::runtime { @@ -93,13 +96,18 @@ class CelExpressionFlatEvaluationState : public CelEvaluationState { std::set& iter_variable_names() { return iter_variable_names_; } - google::protobuf::Arena* arena() { return arena_; } + google::protobuf::Arena* arena() { return memory_manager_.arena(); } + + cel::MemoryManager& memory_manager() { return memory_manager_; } private: EvaluatorStack value_stack_; std::set iter_variable_names_; std::vector iter_stack_; - google::protobuf::Arena* arena_; + // TODO(issues/5): State owns a ProtoMemoryManager to adapt from the client + // provided arena. In the future, clients will have to maintain the particular + // manager they want to use for evaluation. + cel::extensions::ProtoMemoryManager memory_manager_; }; // ExecutionFrame provides context for expression evaluation. @@ -128,7 +136,7 @@ class ExecutionFrame { enable_null_coercion_(enable_null_coercion), attribute_utility_(&activation.unknown_attribute_patterns(), &activation.missing_attribute_patterns(), - state->arena()), + state->memory_manager()), max_iterations_(max_iterations), iterations_(0), state_(state) {} @@ -160,7 +168,8 @@ class ExecutionFrame { bool enable_null_coercion() const { return enable_null_coercion_; } - google::protobuf::Arena* arena() { return state_->arena(); } + cel::MemoryManager& memory_manager() { return state_->memory_manager(); } + const google::protobuf::DescriptorPool* descriptor_pool() const { return descriptor_pool_; } diff --git a/eval/eval/evaluator_core_test.cc b/eval/eval/evaluator_core_test.cc index 57112f69d..58946d38a 100644 --- a/eval/eval/evaluator_core_test.cc +++ b/eval/eval/evaluator_core_test.cc @@ -11,11 +11,13 @@ #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_attribute.h" #include "eval/public/cel_value.h" +#include "extensions/protobuf/memory_manager.h" #include "internal/status_macros.h" #include "internal/testing.h" namespace google::api::expr::runtime { +using ::cel::extensions::ProtoMemoryManager; using ::google::api::expr::v1alpha1::Expr; using ::google::api::expr::runtime::RegisterBuiltinFunctions; using testing::_; @@ -86,6 +88,7 @@ TEST(EvaluatorCoreTest, ExecutionFrameSetGetClearVar) { Activation activation; google::protobuf::Arena arena; + ProtoMemoryManager manager(&arena); ExecutionPath path; CelExpressionFlatEvaluationState state(path.size(), {test_iter_var}, nullptr); ExecutionFrame frame(path, activation, @@ -98,9 +101,9 @@ TEST(EvaluatorCoreTest, ExecutionFrameSetGetClearVar) { ident.mutable_ident_expr()->set_name("var"); AttributeTrail original_trail = - AttributeTrail(ident, &arena) + AttributeTrail(ident, manager) .Step(CelAttributeQualifier::Create(CelValue::CreateInt64(1)), - &arena); + manager); CelValue result; const AttributeTrail* trail; diff --git a/eval/eval/evaluator_stack_test.cc b/eval/eval/evaluator_stack_test.cc index b78d41606..98620041b 100644 --- a/eval/eval/evaluator_stack_test.cc +++ b/eval/eval/evaluator_stack_test.cc @@ -1,23 +1,26 @@ #include "eval/eval/evaluator_stack.h" +#include "extensions/protobuf/memory_manager.h" #include "internal/testing.h" namespace google::api::expr::runtime { namespace { +using ::cel::extensions::ProtoMemoryManager; using testing::NotNull; // Test Value Stack Push/Pop operation TEST(EvaluatorStackTest, StackPushPop) { google::protobuf::Arena arena; + ProtoMemoryManager manager(&arena); google::api::expr::v1alpha1::Expr expr; expr.mutable_ident_expr()->set_name("name"); CelAttribute attribute(expr, {}); EvaluatorStack stack(10); stack.Push(CelValue::CreateInt64(1)); stack.Push(CelValue::CreateInt64(2), AttributeTrail()); - stack.Push(CelValue::CreateInt64(3), AttributeTrail(expr, &arena)); + stack.Push(CelValue::CreateInt64(3), AttributeTrail(expr, manager)); ASSERT_EQ(stack.Peek().Int64OrDie(), 3); ASSERT_THAT(stack.PeekAttribute().attribute(), NotNull()); diff --git a/eval/eval/function_step.cc b/eval/eval/function_step.cc index cf2322598..c305559c7 100644 --- a/eval/eval/function_step.cc +++ b/eval/eval/function_step.cc @@ -26,11 +26,15 @@ #include "eval/public/unknown_attribute_set.h" #include "eval/public/unknown_function_result_set.h" #include "eval/public/unknown_set.h" +#include "extensions/protobuf/memory_manager.h" +#include "internal/status_macros.h" namespace google::api::expr::runtime { namespace { +using cel::extensions::ProtoMemoryManager; + // Only non-strict functions are allowed to consume errors and unknown sets. bool IsNonStrict(const CelFunction& function) { const CelFunctionDescriptor& descriptor = function.descriptor(); @@ -70,8 +74,9 @@ std::vector CheckForPartialUnknowns( auto attr_set = frame->attribute_utility().CheckForUnknowns( attrs.subspan(i, 1), /*use_partial=*/true); if (!attr_set.attributes().empty()) { - auto unknown_set = google::protobuf::Arena::Create(frame->arena(), - std::move(attr_set)); + auto unknown_set = frame->memory_manager() + .New(std::move(attr_set)) + .release(); result.push_back(CelValue::CreateUnknownSet(unknown_set)); } else { result.push_back(args.at(i)); @@ -126,27 +131,19 @@ absl::Status AbstractFunctionStep::DoEvaluate(ExecutionFrame* frame, } // Derived class resolves to a single function overload or none. - auto status = ResolveFunction(input_args, frame); - if (!status.ok()) { - return status.status(); - } - const CelFunction* matched_function = status.value(); + CEL_ASSIGN_OR_RETURN(const CelFunction* matched_function, + ResolveFunction(input_args, frame)); // Overload found and is allowed to consume the arguments. if (ShouldAcceptOverload(matched_function, input_args)) { - absl::Status status = - matched_function->Evaluate(input_args, result, frame->arena()); - if (!status.ok()) { - return status; - } + google::protobuf::Arena* arena = + ProtoMemoryManager::CastToProtoArena(frame->memory_manager()); + CEL_RETURN_IF_ERROR(matched_function->Evaluate(input_args, result, arena)); + if (frame->enable_unknown_function_results() && IsUnknownFunctionResult(*result)) { - const auto* function_result = - google::protobuf::Arena::Create( - frame->arena(), matched_function->descriptor(), id(), - std::vector(input_args.begin(), input_args.end())); - const auto* unknown_set = google::protobuf::Arena::Create( - frame->arena(), UnknownFunctionResultSet(function_result)); + auto unknown_set = frame->attribute_utility().CreateUnknownSet( + matched_function->descriptor(), id(), input_args); *result = CelValue::CreateUnknownSet(unknown_set); } } else { @@ -173,7 +170,7 @@ absl::Status AbstractFunctionStep::DoEvaluate(ExecutionFrame* frame, } // If no errors or unknowns in input args, create new CelError. - *result = CreateNoMatchingOverloadError(frame->arena()); + *result = CreateNoMatchingOverloadError(frame->memory_manager()); } return absl::OkStatus(); diff --git a/eval/eval/ident_step.cc b/eval/eval/ident_step.cc index 99c5c3491..d3fd44b68 100644 --- a/eval/eval/ident_step.cc +++ b/eval/eval/ident_step.cc @@ -4,16 +4,21 @@ #include #include "google/protobuf/arena.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "eval/eval/attribute_trail.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" #include "eval/public/unknown_attribute_set.h" +#include "extensions/protobuf/memory_manager.h" namespace google::api::expr::runtime { namespace { + +using ::cel::extensions::ProtoMemoryManager; + class IdentStep : public ExpressionStepBase { public: IdentStep(absl::string_view name, int64_t expr_id) @@ -22,45 +27,50 @@ class IdentStep : public ExpressionStepBase { absl::Status Evaluate(ExecutionFrame* frame) const override; private: - void DoEvaluate(ExecutionFrame* frame, CelValue* result, - AttributeTrail* trail) const; + absl::Status DoEvaluate(ExecutionFrame* frame, CelValue* result, + AttributeTrail* trail) const; std::string name_; }; -void IdentStep::DoEvaluate(ExecutionFrame* frame, CelValue* result, - AttributeTrail* trail) const { +absl::Status IdentStep::DoEvaluate(ExecutionFrame* frame, CelValue* result, + AttributeTrail* trail) const { // Special case - iterator looked up in if (frame->GetIterVar(name_, result)) { const AttributeTrail* iter_trail; if (frame->GetIterAttr(name_, &iter_trail)) { *trail = *iter_trail; } - return; + return absl::OkStatus(); } - auto value = frame->activation().FindValue(name_, frame->arena()); + // TODO(issues/5): Update ValueProducer to support generic memory manager + // API. + google::protobuf::Arena* arena = + ProtoMemoryManager::CastToProtoArena(frame->memory_manager()); + + auto value = frame->activation().FindValue(name_, arena); // Populate trails if either MissingAttributeError or UnknownPattern // is enabled. if (frame->enable_missing_attribute_errors() || frame->enable_unknowns()) { google::api::expr::v1alpha1::Expr expr; expr.mutable_ident_expr()->set_name(name_); - *trail = AttributeTrail(std::move(expr), frame->arena()); + *trail = AttributeTrail(std::move(expr), frame->memory_manager()); } if (frame->enable_missing_attribute_errors() && !name_.empty() && frame->attribute_utility().CheckForMissingAttribute(*trail)) { - *result = CreateMissingAttributeError(frame->arena(), name_); - return; + *result = CreateMissingAttributeError(frame->memory_manager(), name_); + return absl::OkStatus(); } if (frame->enable_unknowns()) { if (frame->attribute_utility().CheckForUnknown(*trail, false)) { - auto unknown_set = google::protobuf::Arena::Create( - frame->arena(), UnknownAttributeSet({trail->attribute()})); + auto unknown_set = + frame->attribute_utility().CreateUnknownSet(trail->attribute()); *result = CelValue::CreateUnknownSet(unknown_set); - return; + return absl::OkStatus(); } } @@ -68,16 +78,18 @@ void IdentStep::DoEvaluate(ExecutionFrame* frame, CelValue* result, *result = value.value(); } else { *result = CreateErrorValue( - frame->arena(), + frame->memory_manager(), absl::StrCat("No value with name \"", name_, "\" found in Activation")); } + + return absl::OkStatus(); } absl::Status IdentStep::Evaluate(ExecutionFrame* frame) const { CelValue result; AttributeTrail trail; - DoEvaluate(frame, &result, &trail); + CEL_RETURN_IF_ERROR(DoEvaluate(frame, &result, &trail)); frame->value_stack().Push(result, trail); diff --git a/eval/eval/jump_step.cc b/eval/eval/jump_step.cc index e99469d47..f59762390 100644 --- a/eval/eval/jump_step.cc +++ b/eval/eval/jump_step.cc @@ -82,8 +82,8 @@ class BoolCheckJumpStep : public JumpStepBase { } if (!value.IsBool()) { - CelValue error_value = - CreateNoMatchingOverloadError(frame->arena(), ""); + CelValue error_value = CreateNoMatchingOverloadError( + frame->memory_manager(), ""); frame->value_stack().PopAndPush(error_value); return Jump(frame); } diff --git a/eval/eval/logic_step.cc b/eval/eval/logic_step.cc index 7be833874..1bcd9fcab 100644 --- a/eval/eval/logic_step.cc +++ b/eval/eval/logic_step.cc @@ -77,7 +77,7 @@ class LogicalOpStep : public ExpressionStepBase { // Fallback. *result = CreateNoMatchingOverloadError( - frame->arena(), + frame->memory_manager(), (op_type_ == OpType::OR) ? builtin::kOr : builtin::kAnd); return absl::OkStatus(); } diff --git a/eval/eval/select_step.cc b/eval/eval/select_step.cc index c200cea33..e8e7c7cb9 100644 --- a/eval/eval/select_step.cc +++ b/eval/eval/select_step.cc @@ -14,11 +14,14 @@ #include "eval/public/containers/field_access.h" #include "eval/public/containers/field_backed_list_impl.h" #include "eval/public/containers/field_backed_map_impl.h" +#include "extensions/protobuf/memory_manager.h" +#include "internal/status_macros.h" namespace google::api::expr::runtime { namespace { +using ::cel::extensions::ProtoMemoryManager; using ::google::protobuf::Descriptor; using ::google::protobuf::FieldDescriptor; using ::google::protobuf::Reflection; @@ -52,7 +55,7 @@ class SelectStep : public ExpressionStepBase { private: absl::Status CreateValueFromField(const google::protobuf::Message& msg, - google::protobuf::Arena* arena, + cel::MemoryManager& manager, CelValue* result) const; std::string field_; @@ -62,16 +65,18 @@ class SelectStep : public ExpressionStepBase { }; absl::Status SelectStep::CreateValueFromField(const google::protobuf::Message& msg, - google::protobuf::Arena* arena, + cel::MemoryManager& manager, CelValue* result) const { const Descriptor* desc = msg.GetDescriptor(); const FieldDescriptor* field_desc = desc->FindFieldByName(field_); if (field_desc == nullptr) { - *result = CreateNoSuchFieldError(arena, field_); + *result = CreateNoSuchFieldError(manager, field_); return absl::OkStatus(); } + google::protobuf::Arena* arena = ProtoMemoryManager::CastToProtoArena(manager); + if (field_desc->is_map()) { CelMap* map = google::protobuf::Arena::Create(arena, &msg, field_desc, arena); @@ -89,42 +94,41 @@ absl::Status SelectStep::CreateValueFromField(const google::protobuf::Message& m result); } -absl::optional CheckForMarkedAttributes(const ExecutionFrame& frame, - const AttributeTrail& trail, - google::protobuf::Arena* arena) { - if (frame.enable_unknowns() && - frame.attribute_utility().CheckForUnknown(trail, - /*use_partial=*/false)) { - auto unknown_set = google::protobuf::Arena::Create( - arena, UnknownAttributeSet({trail.attribute()})); - return CelValue::CreateUnknownSet(unknown_set); +absl::optional CheckForMarkedAttributes(const AttributeTrail& trail, + ExecutionFrame* frame) { + if (frame->enable_unknowns() && + frame->attribute_utility().CheckForUnknown(trail, + /*use_partial=*/false)) { + auto unknown_set = frame->memory_manager().New( + UnknownAttributeSet({trail.attribute()})); + return CelValue::CreateUnknownSet(unknown_set.release()); } - if (frame.enable_missing_attribute_errors() && - frame.attribute_utility().CheckForMissingAttribute(trail)) { + if (frame->enable_missing_attribute_errors() && + frame->attribute_utility().CheckForMissingAttribute(trail)) { auto attribute_string = trail.attribute()->AsString(); if (attribute_string.ok()) { - return CreateMissingAttributeError(arena, *attribute_string); + return CreateMissingAttributeError(frame->memory_manager(), + *attribute_string); } // Invariant broken (an invalid CEL Attribute shouldn't match anything). // Log and return a CelError. GOOGLE_LOG(ERROR) << "Invalid attribute pattern matched select path: " - << attribute_string.status().ToString(); - return CelValue::CreateError( - google::protobuf::Arena::Create(arena, attribute_string.status())); + << attribute_string.status(); + return CreateErrorValue(frame->memory_manager(), attribute_string.status()); } return absl::nullopt; } CelValue TestOnlySelect(const google::protobuf::Message& msg, const std::string& field, - google::protobuf::Arena* arena) { + cel::MemoryManager& manager) { const Reflection* reflection = msg.GetReflection(); const Descriptor* desc = msg.GetDescriptor(); const FieldDescriptor* field_desc = desc->FindFieldByName(field); if (field_desc == nullptr) { - return CreateNoSuchFieldError(arena, field); + return CreateNoSuchFieldError(manager, field); } if (field_desc->is_map()) { @@ -147,12 +151,12 @@ CelValue TestOnlySelect(const google::protobuf::Message& msg, const std::string& } CelValue TestOnlySelect(const CelMap& map, const std::string& field_name, - google::protobuf::Arena* arena) { + cel::MemoryManager& manager) { // Field presence only supports string keys containing valid identifier // characters. auto presence = map.Has(CelValue::CreateStringView(field_name)); if (!presence.ok()) { - return CreateErrorValue(arena, presence.status()); + return CreateErrorValue(manager, presence.status()); } return CelValue::CreateBool(*presence); @@ -177,11 +181,12 @@ absl::Status SelectStep::Evaluate(ExecutionFrame* frame) const { // Handle unknown resolution. if (frame->enable_unknowns() || frame->enable_missing_attribute_errors()) { - result_trail = trail.Step(&field_, frame->arena()); + result_trail = trail.Step(&field_, frame->memory_manager()); } if (arg.IsNull()) { - CelValue error_value = CreateErrorValue(frame->arena(), "Message is NULL"); + CelValue error_value = + CreateErrorValue(frame->memory_manager(), "Message is NULL"); frame->value_stack().PopAndPush(error_value, result_trail); return absl::OkStatus(); } @@ -191,7 +196,7 @@ absl::Status SelectStep::Evaluate(ExecutionFrame* frame) const { } absl::optional marked_attribute_check = - CheckForMarkedAttributes(*frame, result_trail, frame->arena()); + CheckForMarkedAttributes(result_trail, frame); if (marked_attribute_check.has_value()) { frame->value_stack().PopAndPush(marked_attribute_check.value(), result_trail); @@ -203,7 +208,8 @@ absl::Status SelectStep::Evaluate(ExecutionFrame* frame) const { case CelValue::Type::kMap: { if (arg.MapOrDie() == nullptr) { frame->value_stack().PopAndPush( - CreateErrorValue(frame->arena(), "Map is NULL"), result_trail); + CreateErrorValue(frame->memory_manager(), "Map is NULL"), + result_trail); return absl::OkStatus(); } break; @@ -211,7 +217,8 @@ absl::Status SelectStep::Evaluate(ExecutionFrame* frame) const { case CelValue::Type::kMessage: { if (arg.MessageOrDie() == nullptr) { frame->value_stack().PopAndPush( - CreateErrorValue(frame->arena(), "Message is NULL"), result_trail); + CreateErrorValue(frame->memory_manager(), "Message is NULL"), + result_trail); return absl::OkStatus(); } break; @@ -225,11 +232,11 @@ absl::Status SelectStep::Evaluate(ExecutionFrame* frame) const { if (test_field_presence_) { if (arg.IsMap()) { frame->value_stack().PopAndPush( - TestOnlySelect(*arg.MapOrDie(), field_, frame->arena())); + TestOnlySelect(*arg.MapOrDie(), field_, frame->memory_manager())); return absl::OkStatus(); } else if (arg.IsMessage()) { frame->value_stack().PopAndPush( - TestOnlySelect(*arg.MessageOrDie(), field_, frame->arena())); + TestOnlySelect(*arg.MessageOrDie(), field_, frame->memory_manager())); return absl::OkStatus(); } } @@ -241,7 +248,8 @@ absl::Status SelectStep::Evaluate(ExecutionFrame* frame) const { // not null. const google::protobuf::Message* msg = arg.MessageOrDie(); - CEL_RETURN_IF_ERROR(CreateValueFromField(*msg, frame->arena(), &result)); + CEL_RETURN_IF_ERROR( + CreateValueFromField(*msg, frame->memory_manager(), &result)); frame->value_stack().PopAndPush(result, result_trail); return absl::OkStatus(); @@ -257,7 +265,7 @@ absl::Status SelectStep::Evaluate(ExecutionFrame* frame) const { if (lookup_result.has_value()) { result = *lookup_result; } else { - result = CreateNoSuchKeyError(frame->arena(), field_); + result = CreateNoSuchKeyError(frame->memory_manager(), field_); } frame->value_stack().PopAndPush(result, result_trail); return absl::OkStatus(); diff --git a/eval/eval/shadowable_value_step.cc b/eval/eval/shadowable_value_step.cc index 887f48e16..322278ec8 100644 --- a/eval/eval/shadowable_value_step.cc +++ b/eval/eval/shadowable_value_step.cc @@ -7,11 +7,15 @@ #include "absl/status/statusor.h" #include "eval/eval/expression_step_base.h" #include "eval/public/cel_value.h" +#include "extensions/protobuf/memory_manager.h" +#include "internal/status_macros.h" namespace google::api::expr::runtime { namespace { +using ::cel::extensions::ProtoMemoryManager; + class ShadowableValueStep : public ExpressionStepBase { public: ShadowableValueStep(const std::string& identifier, const CelValue& value, @@ -26,7 +30,11 @@ class ShadowableValueStep : public ExpressionStepBase { }; absl::Status ShadowableValueStep::Evaluate(ExecutionFrame* frame) const { - auto var = frame->activation().FindValue(identifier_, frame->arena()); + // TODO(issues/5): update ValueProducer to support generic MemoryManager + // API. + google::protobuf::Arena* arena = + ProtoMemoryManager::CastToProtoArena(frame->memory_manager()); + auto var = frame->activation().FindValue(identifier_, arena); frame->value_stack().Push(var.value_or(value_)); return absl::OkStatus(); } diff --git a/eval/eval/ternary_step.cc b/eval/eval/ternary_step.cc index 97d6f2607..2393b9470 100644 --- a/eval/eval/ternary_step.cc +++ b/eval/eval/ternary_step.cc @@ -51,7 +51,8 @@ absl::Status TernaryStep::Evaluate(ExecutionFrame* frame) const { CelValue result; if (!condition.IsBool()) { - result = CreateNoMatchingOverloadError(frame->arena(), builtin::kTernary); + result = CreateNoMatchingOverloadError(frame->memory_manager(), + builtin::kTernary); } else if (condition.BoolOrDie()) { result = args.at(1); } else { diff --git a/eval/public/BUILD b/eval/public/BUILD index 11bf8a0ea..23788808d 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -501,6 +501,7 @@ cc_test( ":cel_function", "//eval/eval:attribute_trail", "//eval/eval:ident_step", + "//extensions/protobuf:memory_manager", "//internal:status_macros", "//internal:testing", "//parser", diff --git a/eval/public/activation_test.cc b/eval/public/activation_test.cc index 06b32ee4f..e225ea05a 100644 --- a/eval/public/activation_test.cc +++ b/eval/public/activation_test.cc @@ -7,6 +7,7 @@ #include "eval/eval/ident_step.h" #include "eval/public/cel_attribute.h" #include "eval/public/cel_function.h" +#include "extensions/protobuf/memory_manager.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "parser/parser.h" @@ -18,7 +19,8 @@ namespace runtime { namespace { -using google::api::expr::v1alpha1::Expr; +using ::cel::extensions::ProtoMemoryManager; +using ::google::api::expr::v1alpha1::Expr; using ::google::protobuf::Arena; using testing::ElementsAre; using testing::Eq; @@ -204,6 +206,7 @@ TEST(ActivationTest, CheckValueProducerClear) { TEST(ActivationTest, ErrorPathTest) { Activation activation; Arena arena; + ProtoMemoryManager manager(&arena); Expr expr; auto* select_expr = expr.mutable_select_expr(); @@ -216,9 +219,9 @@ TEST(ActivationTest, ErrorPathTest) { "destination", {CelAttributeQualifierPattern::Create(CelValue::CreateStringView("ip"))}); - AttributeTrail trail(*ident_expr, &arena); + AttributeTrail trail(*ident_expr, manager); trail = trail.Step( - CelAttributeQualifier::Create(CelValue::CreateStringView("ip")), &arena); + CelAttributeQualifier::Create(CelValue::CreateStringView("ip")), manager); ASSERT_EQ(destination_ip_pattern.IsMatch(*trail.attribute()), CelAttributePattern::MatchType::FULL); diff --git a/eval/public/cel_expression.h b/eval/public/cel_expression.h index 04f9c98d7..5dc894a9f 100644 --- a/eval/public/cel_expression.h +++ b/eval/public/cel_expression.h @@ -27,7 +27,7 @@ namespace google::api::expr::runtime { using CelEvaluationListener = std::function; -// An opaque state used for evaluation of a cell expression. +// An opaque state used for evaluation of a CEL expression. class CelEvaluationState { public: virtual ~CelEvaluationState() = default; diff --git a/eval/public/unknown_attribute_set.h b/eval/public/unknown_attribute_set.h index b3abdeeb2..a661de69f 100644 --- a/eval/public/unknown_attribute_set.h +++ b/eval/public/unknown_attribute_set.h @@ -19,7 +19,8 @@ class UnknownAttributeSet { UnknownAttributeSet& operator=(const UnknownAttributeSet& other) = default; UnknownAttributeSet() {} - UnknownAttributeSet(const std::vector& attributes) { + explicit UnknownAttributeSet( + const std::vector& attributes) { attributes_.reserve(attributes.size()); for (const auto& attr : attributes) { Add(attr); diff --git a/extensions/protobuf/BUILD b/extensions/protobuf/BUILD index 86588ba62..404594065 100644 --- a/extensions/protobuf/BUILD +++ b/extensions/protobuf/BUILD @@ -25,7 +25,10 @@ cc_library( hdrs = ["memory_manager.h"], deps = [ "//base:memory_manager", + "//internal:casts", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_protobuf//:protobuf", ], ) diff --git a/extensions/protobuf/memory_manager.h b/extensions/protobuf/memory_manager.h index d13e94bd9..f6d77c0bc 100644 --- a/extensions/protobuf/memory_manager.h +++ b/extensions/protobuf/memory_manager.h @@ -20,7 +20,10 @@ #include "google/protobuf/arena.h" #include "absl/base/attributes.h" #include "absl/base/macros.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "base/memory_manager.h" +#include "internal/casts.h" namespace cel::extensions { @@ -44,6 +47,15 @@ class ProtoMemoryManager final : public ArenaMemoryManager { constexpr google::protobuf::Arena* arena() const { return arena_; } + // Expose the underlying google::protobuf::Arena on a generic MemoryManager. This may + // only be called on an instance that is guaranteed to be a + // ProtoMemoryManager. + // + // Note: underlying arena may be null. + static google::protobuf::Arena* CastToProtoArena(MemoryManager& manager) { + return internal::down_cast(manager).arena(); + } + private: AllocationResult Allocate(size_t size, size_t align) override; From 12f334f67aa6e22437c7b181644c78b0016a2e22 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 2 Mar 2022 04:47:19 +0000 Subject: [PATCH 011/111] Expose CelMapBuilder class so it works better with memory manager abstraction. Update create map step to use builder allocated by the memory manager. PiperOrigin-RevId: 431842547 --- eval/eval/BUILD | 1 + eval/eval/create_struct_step.cc | 28 ++--- eval/public/containers/BUILD | 1 + .../containers/container_backed_map_impl.cc | 111 +++++------------- .../containers/container_backed_map_impl.h | 61 ++++++++-- .../container_backed_map_impl_test.cc | 70 +++++++++-- 6 files changed, 157 insertions(+), 115 deletions(-) diff --git a/eval/eval/BUILD b/eval/eval/BUILD index 2456b7492..b918ed5de 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -220,6 +220,7 @@ cc_library( deps = [ ":evaluator_core", ":expression_step_base", + "//eval/public:cel_value", "//eval/public/containers:container_backed_map_impl", "//eval/public/containers:field_access", "//eval/public/structs:cel_proto_wrapper", diff --git a/eval/eval/create_struct_step.cc b/eval/eval/create_struct_step.cc index 786e807ca..5ce180885 100644 --- a/eval/eval/create_struct_step.cc +++ b/eval/eval/create_struct_step.cc @@ -11,6 +11,7 @@ #include "absl/strings/str_cat.h" #include "absl/strings/substitute.h" #include "eval/eval/expression_step_base.h" +#include "eval/public/cel_value.h" #include "eval/public/containers/container_backed_map_impl.h" #include "eval/public/containers/field_access.h" #include "eval/public/structs/cel_proto_wrapper.h" @@ -230,32 +231,21 @@ absl::Status CreateStructStepForMap::DoEvaluate(ExecutionFrame* frame, } std::vector> map_entries; - map_entries.reserve(entry_count_); + auto map_builder = frame->memory_manager().New(); + for (size_t i = 0; i < entry_count_; i += 1) { int map_key_index = 2 * i; int map_value_index = map_key_index + 1; const CelValue& map_key = args[map_key_index]; CEL_RETURN_IF_ERROR(CelValue::CheckMapKeyType(map_key)); - map_entries.push_back({map_key, args[map_value_index]}); - } - - auto cel_map = - CreateContainerBackedMap(absl::Span>( - map_entries.data(), map_entries.size())); - if (!cel_map.ok()) { - *result = CreateErrorValue(frame->memory_manager(), cel_map.status()); - return absl::OkStatus(); + auto key_status = map_builder->Add(map_key, args[map_value_index]); + if (!key_status.ok()) { + *result = CreateErrorValue(frame->memory_manager(), key_status); + return absl::OkStatus(); + } } - auto cel_map_ptr = *std::move(cel_map); - *result = CelValue::CreateMap(cel_map_ptr.get()); - - // Pass object ownership to Arena. - // TODO(issues/5): Update CEL map implementation to tolerate generic - // allocation api. - google::protobuf::Arena* arena = - ProtoMemoryManager::CastToProtoArena(frame->memory_manager()); - arena->Own(cel_map_ptr.release()); + *result = CelValue::CreateMap(map_builder.release()); return absl::OkStatus(); } diff --git a/eval/public/containers/BUILD b/eval/public/containers/BUILD index 8c3dfd6ea..2d78c8681 100644 --- a/eval/public/containers/BUILD +++ b/eval/public/containers/BUILD @@ -113,6 +113,7 @@ cc_test( ":container_backed_map_impl", "//eval/public:cel_value", "//internal:testing", + "@com_google_absl//absl/status", ], ) diff --git a/eval/public/containers/container_backed_map_impl.cc b/eval/public/containers/container_backed_map_impl.cc index 37754ec8e..2bd3ea968 100644 --- a/eval/public/containers/container_backed_map_impl.cc +++ b/eval/public/containers/container_backed_map_impl.cc @@ -1,5 +1,7 @@ #include "eval/public/containers/container_backed_map_impl.h" +#include + #include "absl/container/node_hash_map.h" #include "absl/hash/hash.h" #include "absl/status/status.h" @@ -79,96 +81,47 @@ class CelValueEq { const CelValue& other_; }; -// CelValue hasher functor. -class Hasher { - public: - size_t operator()(const CelValue& key) const { - return key.template Visit(HasherOp()); - } -}; - -// CelValue equality functor. -class Equal { - public: - // - bool operator()(const CelValue& key1, const CelValue& key2) const { - if (key1.type() != key2.type()) { - return false; - } - return key1.template Visit(CelValueEq(key2)); - } -}; +} // namespace -// CelMap implementation that uses STL map container as backing storage. -// KeyType is the type of key values stored in CelValue, InnerKeyType is the -// type of key in STL map. -class ContainerBackedMapImpl : public CelMap { - public: - static absl::StatusOr> Create( - absl::Span> key_values) { - auto cel_map = absl::WrapUnique(new ContainerBackedMapImpl()); - auto status = cel_map->AddItems(key_values); - if (!status.ok()) { - return status; - } - return cel_map; +// Map element access operator. +absl::optional CelMapBuilder::operator[](CelValue cel_key) const { + auto item = values_map_.find(cel_key); + if (item == values_map_.end()) { + return absl::nullopt; } + return item->second; +} - // Map size. - int size() const override { return values_map_.size(); } - - // Map element access operator. - absl::optional operator[](CelValue cel_key) const override { - auto item = values_map_.find(cel_key); - if (item == values_map_.end()) { - return absl::nullopt; - } - return item->second; - } +absl::Status CelMapBuilder::Add(CelValue key, CelValue value) { + auto [unused, inserted] = values_map_.emplace(key, value); - absl::StatusOr Has(const CelValue& cel_key) const override { - return values_map_.contains(cel_key); + if (!inserted) { + return absl::InvalidArgumentError("duplicate map keys"); } + key_list_.Add(key); + return absl::OkStatus(); +} - const CelList* ListKeys() const override { return &key_list_; } - - private: - class KeyList : public CelList { - public: - int size() const override { return keys_.size(); } - - CelValue operator[](int index) const override { return keys_[index]; } - - void Add(const CelValue& key) { keys_.push_back(key); } - - private: - std::vector keys_; - }; - - ContainerBackedMapImpl() = default; - - absl::Status AddItems(absl::Span> key_values) { - for (const auto& item : key_values) { - auto result = values_map_.emplace(item.first, item.second); +// CelValue hasher functor. +size_t CelMapBuilder::Hasher::operator()(const CelValue& key) const { + return key.template Visit(HasherOp()); +} - // Failed to insert pair into map - addition failed. - if (!result.second) { - return absl::InvalidArgumentError("duplicate map keys"); - } - key_list_.Add(item.first); - } - return absl::OkStatus(); +bool CelMapBuilder::Equal::operator()(const CelValue& key1, + const CelValue& key2) const { + if (key1.type() != key2.type()) { + return false; } - - absl::node_hash_map values_map_; - KeyList key_list_; -}; - -} // namespace + return key1.template Visit(CelValueEq(key2)); +} absl::StatusOr> CreateContainerBackedMap( absl::Span> key_values) { - return ContainerBackedMapImpl::Create(key_values); + auto map = std::make_unique(); + for (const auto& key_value : key_values) { + CEL_RETURN_IF_ERROR(map->Add(key_value.first, key_value.second)); + } + return map; } } // namespace runtime diff --git a/eval/public/containers/container_backed_map_impl.h b/eval/public/containers/container_backed_map_impl.h index 8865352e0..ea1976715 100644 --- a/eval/public/containers/container_backed_map_impl.h +++ b/eval/public/containers/container_backed_map_impl.h @@ -4,22 +4,65 @@ #include #include +#include "absl/container/node_hash_map.h" #include "absl/status/statusor.h" #include "absl/types/span.h" #include "eval/public/cel_value.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { -// Template factory method creating container-backed CelMap. +// CelMap implementation that uses STL map container as backing storage. +// KeyType is the type of key values stored in CelValue. +// After building, upcast to CelMap to prevent further additions. +class CelMapBuilder : public CelMap { + public: + CelMapBuilder() {} + + // Try to insert a key value pair into the map. Returns a status if key + // already exists. + absl::Status Add(CelValue key, CelValue value); + + int size() const override { return values_map_.size(); } + + absl::optional operator[](CelValue cel_key) const override; + + absl::StatusOr Has(const CelValue& cel_key) const override { + return values_map_.contains(cel_key); + } + + const CelList* ListKeys() const override { return &key_list_; } + + private: + // Custom CelList implementation for maintaining key list. + class KeyList : public CelList { + public: + KeyList() {} + + int size() const override { return keys_.size(); } + + CelValue operator[](int index) const override { return keys_[index]; } + + void Add(const CelValue& key) { keys_.push_back(key); } + + private: + std::vector keys_; + }; + + struct Hasher { + size_t operator()(const CelValue& key) const; + }; + struct Equal { + bool operator()(const CelValue& key1, const CelValue& key2) const; + }; + + absl::node_hash_map values_map_; + KeyList key_list_; +}; + +// Factory method creating container-backed CelMap. absl::StatusOr> CreateContainerBackedMap( absl::Span> key_values); -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_CONTAINER_BACKED_MAP_IMPL_H_ diff --git a/eval/public/containers/container_backed_map_impl_test.cc b/eval/public/containers/container_backed_map_impl_test.cc index 971e804f5..ff4ac43ac 100644 --- a/eval/public/containers/container_backed_map_impl_test.cc +++ b/eval/public/containers/container_backed_map_impl_test.cc @@ -4,19 +4,18 @@ #include #include +#include "absl/status/status.h" #include "eval/public/cel_value.h" #include "internal/testing.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { namespace { using testing::Eq; using testing::IsNull; using testing::Not; +using cel::internal::StatusIs; TEST(ContainerBackedMapImplTest, TestMapInt64) { std::vector> args = { @@ -125,9 +124,64 @@ TEST(ContainerBackedMapImplTest, TestMapString) { ASSERT_FALSE(lookup3); } +TEST(CelMapBuilder, TestMapString) { + const std::string kKey1 = "1"; + const std::string kKey2 = "2"; + const std::string kKey3 = "3"; + + std::vector> args = { + {CelValue::CreateString(&kKey1), CelValue::CreateInt64(2)}, + {CelValue::CreateString(&kKey2), CelValue::CreateInt64(3)}}; + CelMapBuilder builder; + ASSERT_OK( + builder.Add(CelValue::CreateString(&kKey1), CelValue::CreateInt64(2))); + ASSERT_OK( + builder.Add(CelValue::CreateString(&kKey2), CelValue::CreateInt64(3))); + + CelMap* cel_map = &builder; + + ASSERT_THAT(cel_map, Not(IsNull())); + + EXPECT_THAT(cel_map->size(), Eq(2)); + + // Test lookup with key == 1 ( should succeed ) + auto lookup1 = (*cel_map)[CelValue::CreateString(&kKey1)]; + + ASSERT_TRUE(lookup1); + + CelValue cel_value = lookup1.value(); + + ASSERT_TRUE(cel_value.IsInt64()); + EXPECT_THAT(cel_value.Int64OrDie(), 2); + + // Test lookup with different type ( should fail ) + auto lookup2 = (*cel_map)[CelValue::CreateInt64(1)]; + + ASSERT_FALSE(lookup2); + + // Test lookup with key3 ( should fail ) + auto lookup3 = (*cel_map)[CelValue::CreateString(&kKey3)]; + + ASSERT_FALSE(lookup3); +} + +TEST(CelMapBuilder, RepeatKeysFail) { + const std::string kKey1 = "1"; + const std::string kKey2 = "2"; + + std::vector> args = { + {CelValue::CreateString(&kKey1), CelValue::CreateInt64(2)}, + {CelValue::CreateString(&kKey2), CelValue::CreateInt64(3)}}; + CelMapBuilder builder; + ASSERT_OK( + builder.Add(CelValue::CreateString(&kKey1), CelValue::CreateInt64(2))); + ASSERT_OK( + builder.Add(CelValue::CreateString(&kKey2), CelValue::CreateInt64(3))); + EXPECT_THAT( + builder.Add(CelValue::CreateString(&kKey2), CelValue::CreateInt64(3)), + StatusIs(absl::StatusCode::kInvalidArgument, "duplicate map keys")); +} + } // namespace -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime From 9c3d2cf868d75c41dd2f758c9f077df96122b543 Mon Sep 17 00:00:00 2001 From: jcking Date: Wed, 2 Mar 2022 19:42:03 +0000 Subject: [PATCH 012/111] Internal change PiperOrigin-RevId: 431990497 --- base/BUILD | 4 +- base/internal/BUILD | 4 +- base/internal/memory_manager.h | 45 +--- base/memory_manager.cc | 211 +++++++++++++++++- base/memory_manager.h | 242 +++++++++++++++++---- base/memory_manager_test.cc | 39 ---- extensions/protobuf/memory_manager.cc | 22 +- extensions/protobuf/memory_manager.h | 10 +- extensions/protobuf/memory_manager_test.cc | 4 - internal/BUILD | 5 + internal/no_destructor.h | 92 ++++++++ 11 files changed, 531 insertions(+), 147 deletions(-) create mode 100644 internal/no_destructor.h diff --git a/base/BUILD b/base/BUILD index d9e7f28a3..01b9b3055 100644 --- a/base/BUILD +++ b/base/BUILD @@ -43,7 +43,10 @@ cc_library( hdrs = ["memory_manager.h"], deps = [ "//base/internal:memory_manager", + "//internal:no_destructor", + "@com_google_absl//absl/base:config", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/numeric:bits", ], ) @@ -52,7 +55,6 @@ cc_test( srcs = ["memory_manager_test.cc"], deps = [ ":memory_manager", - "//base/internal:memory_manager", "//internal:testing", ], ) diff --git a/base/internal/BUILD b/base/internal/BUILD index ea842ae96..d32a4ef19 100644 --- a/base/internal/BUILD +++ b/base/internal/BUILD @@ -18,7 +18,9 @@ licenses(["notice"]) cc_library( name = "memory_manager", - textual_hdrs = ["memory_manager.h"], + textual_hdrs = [ + "memory_manager.h", + ], ) cc_library( diff --git a/base/internal/memory_manager.h b/base/internal/memory_manager.h index 785cc8a72..eec6c6dc3 100644 --- a/base/internal/memory_manager.h +++ b/base/internal/memory_manager.h @@ -12,9 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +// IWYU pragma: private + #ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MEMORY_MANAGER_H_ #define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MEMORY_MANAGER_H_ +#include #include namespace cel { @@ -23,48 +26,6 @@ class MemoryManager; namespace base_internal { -template -class MemoryManagerDeleter; - -// True if the deleter is no-op, meaning the object was allocated in an arena -// and the arena will perform any deletion upon its own destruction. -template -bool IsEmptyDeleter(const MemoryManagerDeleter& deleter); - -template -class MemoryManagerDeleter final { - public: - constexpr MemoryManagerDeleter() noexcept = default; - - MemoryManagerDeleter(const MemoryManagerDeleter&) = delete; - - constexpr MemoryManagerDeleter(MemoryManagerDeleter&& other) noexcept - : MemoryManagerDeleter() { - std::swap(memory_manager_, other.memory_manager_); - std::swap(size_, other.size_); - std::swap(align_, other.align_); - } - - void operator()(T* pointer) const; - - private: - friend class cel::MemoryManager; - template - friend bool IsEmptyDeleter(const MemoryManagerDeleter& deleter); - - MemoryManagerDeleter(MemoryManager* memory_manager, size_t size, size_t align) - : memory_manager_(memory_manager), size_(size), align_(align) {} - - MemoryManager* memory_manager_ = nullptr; - size_t size_ = 0; - size_t align_ = 0; -}; - -template -bool IsEmptyDeleter(const MemoryManagerDeleter& deleter) { - return deleter.memory_manager_ == nullptr; -} - template class MemoryManagerDestructor final { private: diff --git a/base/memory_manager.cc b/base/memory_manager.cc index f10c8b406..56d9f670f 100644 --- a/base/memory_manager.cc +++ b/base/memory_manager.cc @@ -14,36 +14,237 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include +#include #include +#include #include +#include #include "absl/base/attributes.h" +#include "absl/base/config.h" #include "absl/base/macros.h" +#include "absl/numeric/bits.h" +#include "internal/no_destructor.h" namespace cel { namespace { class GlobalMemoryManager final : public MemoryManager { + public: + GlobalMemoryManager() : MemoryManager() {} + private: AllocationResult Allocate(size_t size, size_t align) override { - return {::operator new(size, static_cast(align), - std::nothrow), - true}; + void* pointer; + if (ABSL_PREDICT_TRUE(align <= alignof(std::max_align_t))) { + pointer = ::operator new(size, std::nothrow); + } else { + pointer = ::operator new(size, static_cast(align), + std::nothrow); + } + return {pointer}; } void Deallocate(void* pointer, size_t size, size_t align) override { - ::operator delete(pointer, size, static_cast(align)); + if (ABSL_PREDICT_TRUE(align <= alignof(std::max_align_t))) { + ::operator delete(pointer, size); + } else { + ::operator delete(pointer, size, static_cast(align)); + } } }; +struct ControlBlock final { + constexpr explicit ControlBlock(MemoryManager* memory_manager) + : refs(1), memory_manager(memory_manager) {} + + ControlBlock(const ControlBlock&) = delete; + ControlBlock(ControlBlock&&) = delete; + ControlBlock& operator=(const ControlBlock&) = delete; + ControlBlock& operator=(ControlBlock&&) = delete; + + mutable std::atomic refs; + MemoryManager* memory_manager; + + void Ref() const { + const auto cnt = refs.fetch_add(1, std::memory_order_relaxed); + ABSL_ASSERT(cnt >= 1); + } + + bool Unref() const { + const auto cnt = refs.fetch_sub(1, std::memory_order_acq_rel); + ABSL_ASSERT(cnt >= 1); + return cnt == 1; + } +}; + +size_t AlignUp(size_t size, size_t align) { + ABSL_ASSERT(size != 0); + ABSL_ASSERT(absl::has_single_bit(align)); // Assert aligned to power of 2. +#if ABSL_HAVE_BUILTIN(__builtin_align_up) + return __builtin_align_up(size, align); +#else + return (size + align - size_t{1}) & ~(align - size_t{1}); +#endif +} + +inline constexpr size_t kControlBlockSize = sizeof(ControlBlock); +inline constexpr size_t kControlBlockAlign = alignof(ControlBlock); + +// When not using arena-based allocation, MemoryManager needs to embed a pointer +// to itself in the allocation block so the same memory manager can be used to +// deallocate. When the alignment requested is less than or equal to that of the +// native pointer alignment it is embedded at the beginning of the allocated +// block, otherwise its at the end. +// +// For allocations requiring alignment greater than alignof(ControlBlock) we +// cannot place the control block in front as it would change the alignment of +// T, resulting in undefined behavior. For allocations requiring less alignment +// than alignof(ControlBlock), we should not place the control back in back as +// it would waste memory due to having to pad the allocation to ensure +// ControlBlock itself is aligned. +enum class Placement { + kBefore = 0, + kAfter, +}; + +constexpr Placement GetPlacement(size_t align) { + return ABSL_PREDICT_TRUE(align <= kControlBlockAlign) ? Placement::kBefore + : Placement::kAfter; +} + +void* AdjustAfterAllocation(MemoryManager* memory_manager, void* pointer, + size_t size, size_t align) { + switch (GetPlacement(align)) { + case Placement::kBefore: + // Store the pointer to the memory manager at the beginning of the + // allocated block and adjust the pointer to immediately after it. + ::new (pointer) ControlBlock(memory_manager); + pointer = static_cast(static_cast(pointer) + + kControlBlockSize); + break; + case Placement::kAfter: + // Store the pointer to the memory manager at the end of the allocated + // block. Don't need to adjust the pointer. + ::new (static_cast(static_cast(pointer) + size - + kControlBlockSize)) + ControlBlock(memory_manager); + break; + } + return pointer; +} + +void* AdjustForDeallocation(void* pointer, size_t align) { + switch (GetPlacement(align)) { + case Placement::kBefore: + // We need to back up kPointerSize as that is actually the original + // allocated address returned from `Allocate`. + pointer = static_cast(static_cast(pointer) - + kControlBlockSize); + break; + case Placement::kAfter: + // No need to do anything. + break; + } + return pointer; +} + +ControlBlock* GetControlBlock(const void* pointer, size_t size, size_t align) { + ControlBlock* control_block; + switch (GetPlacement(align)) { + case Placement::kBefore: + // Embedded reference count block is located just before `pointer`. + control_block = reinterpret_cast( + static_cast(const_cast(pointer)) - + kControlBlockSize); + break; + case Placement::kAfter: + // Embedded reference count block is located at `pointer + size - + // kControlBlockSize`. + control_block = reinterpret_cast( + static_cast(const_cast(pointer)) + size - + kControlBlockSize); + break; + } + return control_block; +} + +size_t AdjustAllocationSize(size_t size, size_t align) { + if (GetPlacement(align) == Placement::kAfter) { + size = AlignUp(size, kControlBlockAlign); + } + return size + kControlBlockSize; +} + } // namespace MemoryManager& MemoryManager::Global() { - static MemoryManager* const instance = new GlobalMemoryManager(); + static internal::NoDestructor instance; return *instance; } +void* MemoryManager::AllocateInternal(size_t& size, size_t& align) { + ABSL_ASSERT(size != 0); + ABSL_ASSERT(absl::has_single_bit(align)); // Assert aligned to power of 2. + size_t adjusted_size = size; + if (!allocation_only_) { + adjusted_size = AdjustAllocationSize(adjusted_size, align); + } + auto [pointer] = Allocate(adjusted_size, align); + if (ABSL_PREDICT_TRUE(pointer != nullptr) && !allocation_only_) { + pointer = AdjustAfterAllocation(this, pointer, adjusted_size, align); + } else { + // 0 is not a valid result of sizeof. So we use that to signal to the + // deleter that it should not perform a deletion and that the memory manager + // will. + size = align = 0; + } + return pointer; +} + +void MemoryManager::DeallocateInternal(void* pointer, size_t size, + size_t align) { + ABSL_ASSERT(pointer != nullptr); + ABSL_ASSERT(size != 0); + ABSL_ASSERT(absl::has_single_bit(align)); // Assert aligned to power of 2. + // `size` is the unadjusted size, the original sizeof(T) used during + // allocation. We need to adjust it to match the allocation size. + size = AdjustAllocationSize(size, align); + ControlBlock* control_block = GetControlBlock(pointer, size, align); + MemoryManager* memory_manager = control_block->memory_manager; + if constexpr (!std::is_trivially_destructible_v) { + control_block->~ControlBlock(); + } + pointer = AdjustForDeallocation(pointer, align); + memory_manager->Deallocate(pointer, size, align); +} + +void MemoryManager::Ref(const void* pointer, size_t size, size_t align) { + if (pointer != nullptr && size != 0) { + ABSL_ASSERT(absl::has_single_bit(align)); // Assert aligned to power of 2. + // `size` is the unadjusted size, the original sizeof(T) used during + // allocation. We need to adjust it to match the allocation size. + size = AdjustAllocationSize(size, align); + GetControlBlock(pointer, size, align)->Ref(); + } +} + +bool MemoryManager::UnrefInternal(const void* pointer, size_t size, + size_t align) { + bool cleanup = false; + if (pointer != nullptr && size != 0) { + ABSL_ASSERT(absl::has_single_bit(align)); // Assert aligned to power of 2. + // `size` is the unadjusted size, the original sizeof(T) used during + // allocation. We need to adjust it to match the allocation size. + size = AdjustAllocationSize(size, align); + cleanup = GetControlBlock(pointer, size, align)->Unref(); + } + return cleanup; +} + void MemoryManager::OwnDestructor(void* pointer, void (*destruct)(void*)) { static_cast(pointer); static_cast(destruct); diff --git a/base/memory_manager.h b/base/memory_manager.h index 5903f23aa..a02b318c6 100644 --- a/base/memory_manager.h +++ b/base/memory_manager.h @@ -23,18 +23,127 @@ #include "absl/base/attributes.h" #include "absl/base/macros.h" #include "absl/base/optimization.h" -#include "base/internal/memory_manager.h" +#include "base/internal/memory_manager.h" // IWYU pragma: export namespace cel { +class MemoryManager; +class ArenaMemoryManager; + // `ManagedMemory` is a smart pointer which ensures any applicable object -// destructors and deallocation are eventually performed upon its destruction. -// While `ManagedManager` is derived from `std::unique_ptr`, it does not make -// any guarantees that destructors and deallocation are run immediately upon its -// destruction, just that they will eventually be performed. +// destructors and deallocation are eventually performed. Copying does not +// actually copy the underlying T, instead a pointer is copied and optionally +// reference counted. Moving does not actually move the underlying T, instead a +// pointer is moved. +// +// TODO(issues/5): consider feature parity with std::unique_ptr template -using ManagedMemory = - std::unique_ptr>; +class ManagedMemory final { + public: + ManagedMemory() = default; + + ManagedMemory(const ManagedMemory& other) + : ptr_(other.ptr_), size_(other.size_), align_(other.align_) { + Ref(); + } + + ManagedMemory(ManagedMemory&& other) + : ptr_(other.ptr_), size_(other.size_), align_(other.align_) { + other.ptr_ = nullptr; + other.size_ = other.align_ = 0; + } + + ~ManagedMemory() { Unref(); } + + ManagedMemory& operator=(const ManagedMemory& other) { + if (ABSL_PREDICT_TRUE(this != std::addressof(other))) { + other.Ref(); + Unref(); + ptr_ = other.ptr_; + size_ = other.size_; + align_ = other.align_; + } + return *this; + } + + ManagedMemory& operator=(ManagedMemory&& other) { + if (ABSL_PREDICT_TRUE(this != std::addressof(other))) { + reset(); + swap(other); + } + return *this; + } + + T* release() { + ABSL_ASSERT(size_ == 0); + T* ptr = ptr_; + ptr_ = nullptr; + size_ = align_ = 0; + return ptr; + } + + void reset() { + Unref(); + ptr_ = nullptr; + size_ = align_ = 0; + } + + void swap(ManagedMemory& other) { + std::swap(ptr_, other.ptr_); + std::swap(size_, other.size_); + std::swap(align_, other.align_); + } + + constexpr T& get() ABSL_ATTRIBUTE_LIFETIME_BOUND { return *ptr_; } + + constexpr const T& get() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return *ptr_; } + + constexpr T& operator*() ABSL_ATTRIBUTE_LIFETIME_BOUND { return get(); } + + constexpr const T& operator*() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return get(); + } + + constexpr T* operator->() { return ptr_; } + + constexpr const T* operator->() const { return ptr_; } + + constexpr explicit operator bool() const { return ptr_ != nullptr; } + + private: + friend class MemoryManager; + + constexpr ManagedMemory(T* ptr, size_t size, size_t align) + : ptr_(ptr), size_(size), align_(align) {} + + void Ref() const; + + void Unref() const; + + T* ptr_ = nullptr; + size_t size_ = 0; + size_t align_ = 0; +}; + +template +bool operator==(const ManagedMemory& lhs, std::nullptr_t) { + return lhs.get() == nullptr; +} + +template +bool operator==(std::nullptr_t, const ManagedMemory& rhs) { + return rhs.get() == nullptr; +} + +template +bool operator!=(const ManagedMemory& lhs, std::nullptr_t) { + return !operator==(lhs, nullptr); +} + +template +bool operator!=(std::nullptr_t, const ManagedMemory& rhs) { + return !operator==(nullptr, rhs); +} // `MemoryManager` is an abstraction over memory management that supports // different allocation strategies. @@ -47,47 +156,78 @@ class MemoryManager { // Allocates and constructs `T`. In the event of an allocation failure nullptr // is returned. template - ManagedMemory New(Args&&... args) - ABSL_ATTRIBUTE_LIFETIME_BOUND ABSL_MUST_USE_RESULT { + std::enable_if_t, ManagedMemory> New( + Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND ABSL_MUST_USE_RESULT { size_t size = sizeof(T); size_t align = alignof(T); - auto [pointer, owned] = Allocate(size, align); - if (ABSL_PREDICT_FALSE(pointer == nullptr)) { - return ManagedMemory(); - } - ::new (pointer) T(std::forward(args)...); - if constexpr (!std::is_trivially_destructible_v) { - if (!owned) { - OwnDestructor(pointer, - &base_internal::MemoryManagerDestructor::Destruct); + void* pointer = AllocateInternal(size, align); + if (ABSL_PREDICT_TRUE(pointer != nullptr)) { + ::new (pointer) T(std::forward(args)...); + if constexpr (!std::is_trivially_destructible_v) { + if (allocation_only_) { + OwnDestructor(pointer, + &base_internal::MemoryManagerDestructor::Destruct); + } } } - return ManagedMemory(reinterpret_cast(pointer), - base_internal::MemoryManagerDeleter( - owned ? this : nullptr, size, align)); + return ManagedMemory(reinterpret_cast(pointer), size, align); } protected: + MemoryManager() : MemoryManager(false) {} + template struct AllocationResult final { Pointer pointer = nullptr; - // If true, the responsibility of deallocating and destructing `pointer` is - // passed to the caller of `Allocate`. - bool owned = false; }; private: template - friend class base_internal::MemoryManagerDeleter; + friend class ManagedMemory; + friend class ArenaMemoryManager; + + // Only for use by ArenaMemoryManager. + explicit MemoryManager(bool allocation_only) + : allocation_only_(allocation_only) {} + + void* AllocateInternal(size_t& size, size_t& align); + + static void DeallocateInternal(void* pointer, size_t size, size_t align); - // Delete a previous `New()` result when `AllocationResult::owned` is true. + // Potentially increment the reference count in the control block for the + // previously allocated memory from `New()`. This is intended to be called + // from `ManagedMemory`. + // + // If size is 0, then the allocation was arena-based. + static void Ref(const void* pointer, size_t size, size_t align); + + // Potentially decrement the reference count in the control block for the + // previously allocated memory from `New()`. Returns true if `Delete()` should + // be called. + // + // If size is 0, then the allocation was arena-based and this call is a noop. + static bool UnrefInternal(const void* pointer, size_t size, size_t align); + + // Delete a previous `New()` result when `allocation_only_` is false. template - void Delete(T* pointer, size_t size, size_t align) { - if (pointer != nullptr) { - if constexpr (!std::is_trivially_destructible_v) { - pointer->~T(); - } - Deallocate(pointer, size, align); + static void Delete(T* pointer, size_t size, size_t align) { + if constexpr (!std::is_trivially_destructible_v) { + pointer->~T(); + } + DeallocateInternal( + static_cast(const_cast*>(pointer)), size, + align); + } + + // Potentially decrement the reference count in the control block and + // deallocate the memory for the previously allocated memory from `New()`. + // This is intended to be called from `ManagedMemory`. + // + // If size is 0, then the allocation was arena-based and this call is a noop. + template + static void Unref(T* pointer, size_t size, size_t align) { + if (UnrefInternal(pointer, size, align)) { + Delete(pointer, size, align); } } @@ -115,11 +255,37 @@ class MemoryManager { // // This method is only valid for arena memory managers. virtual void OwnDestructor(void* pointer, void (*destruct)(void*)); + + const bool allocation_only_; }; +template +void ManagedMemory::Ref() const { + MemoryManager::Ref(ptr_, size_, align_); +} + +template +void ManagedMemory::Unref() const { + MemoryManager::Unref(ptr_, size_, align_); +} + +namespace extensions { +class ProtoMemoryManager; +} + // Base class for all arena-based memory managers. class ArenaMemoryManager : public MemoryManager { + protected: + ArenaMemoryManager() : ArenaMemoryManager(true) {} + private: + friend class extensions::ProtoMemoryManager; + + // Private so that only ProtoMemoryManager can use it for legacy reasons. All + // other derivations of ArenaMemoryManager should be allocation-only. + explicit ArenaMemoryManager(bool allocation_only) + : MemoryManager(allocation_only) {} + // Default implementation calls std::abort(). If you have a special case where // you support deallocating individual allocations, override this. void Deallocate(void* pointer, size_t size, size_t align) override; @@ -128,18 +294,6 @@ class ArenaMemoryManager : public MemoryManager { void OwnDestructor(void* pointer, void (*destruct)(void*)) override = 0; }; -namespace base_internal { - -template -void MemoryManagerDeleter::operator()(T* pointer) const { - if (memory_manager_) { - memory_manager_->Delete(const_cast*>(pointer), size_, - align_); - } -} - -} // namespace base_internal - } // namespace cel #endif // THIRD_PARTY_CEL_CPP_BASE_MEMORY_MANAGER_H_ diff --git a/base/memory_manager_test.cc b/base/memory_manager_test.cc index 6f7a70f13..dc4b4f7df 100644 --- a/base/memory_manager_test.cc +++ b/base/memory_manager_test.cc @@ -16,7 +16,6 @@ #include -#include "base/internal/memory_manager.h" #include "internal/testing.h" namespace cel { @@ -27,7 +26,6 @@ struct TriviallyDestructible final {}; TEST(GlobalMemoryManager, TriviallyDestructible) { EXPECT_TRUE(std::is_trivially_destructible_v); auto managed = MemoryManager::Global().New(); - EXPECT_FALSE(base_internal::IsEmptyDeleter(managed.get_deleter())); } struct NotTriviallyDestuctible final { @@ -39,45 +37,8 @@ struct NotTriviallyDestuctible final { TEST(GlobalMemoryManager, NotTriviallyDestuctible) { EXPECT_FALSE(std::is_trivially_destructible_v); auto managed = MemoryManager::Global().New(); - EXPECT_FALSE(base_internal::IsEmptyDeleter(managed.get_deleter())); EXPECT_CALL(*managed, Delete()); } -class BadMemoryManager final : public MemoryManager { - private: - AllocationResult Allocate(size_t size, size_t align) override { - // Return {..., false}, indicating that this was an arena allocation when it - // is not, causing OwnDestructor to be called and abort. - return {::operator new(size, static_cast(align)), false}; - } - - void Deallocate(void* pointer, size_t size, size_t align) override { - ::operator delete(pointer, size, static_cast(align)); - } -}; - -TEST(BadMemoryManager, OwnDestructorAborts) { - BadMemoryManager memory_manager; - EXPECT_EXIT(static_cast(memory_manager.New()), - testing::KilledBySignal(SIGABRT), ""); -} - -class BadArenaMemoryManager final : public ArenaMemoryManager { - private: - AllocationResult Allocate(size_t size, size_t align) override { - // Return {..., false}, indicating that this was an arena allocation when it - // is not, causing OwnDestructor to be called and abort. - return {::operator new(size, static_cast(align)), true}; - } - - void OwnDestructor(void* pointer, void (*destructor)(void*)) override {} -}; - -TEST(BadArenaMemoryManager, DeallocateAborts) { - BadArenaMemoryManager memory_manager; - EXPECT_EXIT(static_cast(memory_manager.New()), - testing::KilledBySignal(SIGABRT), ""); -} - } // namespace } // namespace cel diff --git a/extensions/protobuf/memory_manager.cc b/extensions/protobuf/memory_manager.cc index 9d88069c4..7e8d92eb8 100644 --- a/extensions/protobuf/memory_manager.cc +++ b/extensions/protobuf/memory_manager.cc @@ -14,26 +14,38 @@ #include "extensions/protobuf/memory_manager.h" +#include #include #include "absl/base/macros.h" +#include "absl/base/optimization.h" namespace cel::extensions { MemoryManager::AllocationResult ProtoMemoryManager::Allocate( size_t size, size_t align) { + void* pointer; if (arena_ != nullptr) { - return {arena_->AllocateAligned(size, align), false}; + pointer = arena_->AllocateAligned(size, align); + } else { + if (ABSL_PREDICT_TRUE(align <= alignof(std::max_align_t))) { + pointer = ::operator new(size, std::nothrow); + } else { + pointer = ::operator new(size, static_cast(align), + std::nothrow); + } } - return { - ::operator new(size, static_cast(align), std::nothrow), - true}; + return {pointer}; } void ProtoMemoryManager::Deallocate(void* pointer, size_t size, size_t align) { // Only possible when `arena_` is nullptr. ABSL_HARDENING_ASSERT(arena_ == nullptr); - ::operator delete(pointer, size, static_cast(align)); + if (ABSL_PREDICT_TRUE(align <= alignof(std::max_align_t))) { + ::operator delete(pointer, size); + } else { + ::operator delete(pointer, size, static_cast(align)); + } } void ProtoMemoryManager::OwnDestructor(void* pointer, void (*destruct)(void*)) { diff --git a/extensions/protobuf/memory_manager.h b/extensions/protobuf/memory_manager.h index f6d77c0bc..4d515140c 100644 --- a/extensions/protobuf/memory_manager.h +++ b/extensions/protobuf/memory_manager.h @@ -34,8 +34,9 @@ class ProtoMemoryManager final : public ArenaMemoryManager { public: // Passing a nullptr is highly discouraged, but supported for backwards // compatibility. If `arena` is a nullptr, `ProtoMemoryManager` acts like - // `MemoryManager::Default()`. - explicit ProtoMemoryManager(google::protobuf::Arena* arena) : arena_(arena) {} + // `MemoryManager::Default()` and then must outlive all allocations. + explicit ProtoMemoryManager(google::protobuf::Arena* arena) + : ArenaMemoryManager(arena != nullptr), arena_(arena) {} ProtoMemoryManager(const ProtoMemoryManager&) = delete; @@ -73,11 +74,8 @@ class ProtoMemoryManager final : public ArenaMemoryManager { template ABSL_MUST_USE_RESULT T* NewInProtoArena(MemoryManager& memory_manager, Args&&... args) { -#if !defined(__GNUC__) || defined(__GXX_RTTI) - ABSL_ASSERT(dynamic_cast(&memory_manager) != nullptr); -#endif return google::protobuf::Arena::Create( - static_cast(memory_manager).arena(), + ProtoMemoryManager::CastToProtoArena(memory_manager), std::forward(args)...); } diff --git a/extensions/protobuf/memory_manager_test.cc b/extensions/protobuf/memory_manager_test.cc index 0db014f2d..1290d8b7b 100644 --- a/extensions/protobuf/memory_manager_test.cc +++ b/extensions/protobuf/memory_manager_test.cc @@ -78,7 +78,6 @@ TEST(ProtoMemoryManager, TriviallyDestructible) { ProtoMemoryManager memory_manager(&arena); EXPECT_TRUE(std::is_trivially_destructible_v); auto managed = memory_manager.New(); - EXPECT_TRUE(base_internal::IsEmptyDeleter(managed.get_deleter())); } TEST(ProtoMemoryManager, NotTriviallyDestuctible) { @@ -86,7 +85,6 @@ TEST(ProtoMemoryManager, NotTriviallyDestuctible) { ProtoMemoryManager memory_manager(&arena); EXPECT_FALSE(std::is_trivially_destructible_v); auto managed = memory_manager.New(); - EXPECT_TRUE(base_internal::IsEmptyDeleter(managed.get_deleter())); EXPECT_CALL(*managed, Delete()); } @@ -94,14 +92,12 @@ TEST(ProtoMemoryManagerNoArena, TriviallyDestructible) { ProtoMemoryManager memory_manager(nullptr); EXPECT_TRUE(std::is_trivially_destructible_v); auto managed = memory_manager.New(); - EXPECT_FALSE(base_internal::IsEmptyDeleter(managed.get_deleter())); } TEST(ProtoMemoryManagerNoArena, NotTriviallyDestuctible) { ProtoMemoryManager memory_manager(nullptr); EXPECT_FALSE(std::is_trivially_destructible_v); auto managed = memory_manager.New(); - EXPECT_FALSE(base_internal::IsEmptyDeleter(managed.get_deleter())); EXPECT_CALL(*managed, Delete()); } diff --git a/internal/BUILD b/internal/BUILD index 33e8d2460..9a0c1dfd5 100644 --- a/internal/BUILD +++ b/internal/BUILD @@ -126,6 +126,11 @@ cc_test( ], ) +cc_library( + name = "no_destructor", + hdrs = ["no_destructor.h"], +) + cc_library( name = "proto_util", srcs = ["proto_util.cc"], diff --git a/internal/no_destructor.h b/internal/no_destructor.h new file mode 100644 index 000000000..7e8c44c24 --- /dev/null +++ b/internal/no_destructor.h @@ -0,0 +1,92 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_NO_DESTRUCTOR_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_NO_DESTRUCTOR_H_ + +#include +#include +#include +#include + +namespace cel::internal { + +// `NoDestructor` is primarily useful in optimizing the pattern of safe +// on-demand construction of an object with a non-trivial destructor in static +// storage without ever having the destructor called. By using `NoDestructor` +// there is no need to involve a heap allocation. +template +class NoDestructor final { + public: + template + explicit constexpr NoDestructor(Args&&... args) + : impl_(std::in_place, std::forward(args)...) {} + + NoDestructor(const NoDestructor&) = delete; + NoDestructor(NoDestructor&&) = delete; + NoDestructor& operator=(const NoDestructor&) = delete; + NoDestructor& operator=(NoDestructor&&) = delete; + + T& get() { return impl_.get(); } + + const T& get() const { return impl_.get(); } + + T& operator*() { return get(); } + + const T& operator*() const { return get(); } + + T* operator->() { return std::addressof(get()); } + + const T* operator->() const { return std::addressof(get()); } + + private: + class TrivialImpl final { + public: + template + explicit constexpr TrivialImpl(std::in_place_t, Args&&... args) + : value_(std::forward(args)...) {} + + T& get() { return value_; } + + const T& get() const { return value_; } + + private: + T value_; + }; + + class PlacementImpl final { + public: + template + explicit PlacementImpl(std::in_place_t, Args&&... args) { + ::new (static_cast(&value_)) T(std::forward(args)...); + } + + T& get() { return *std::launder(reinterpret_cast(&value_)); } + + const T& get() const { + return *std::launder(reinterpret_cast(&value_)); + } + + private: + alignas(T) uint8_t value_[sizeof(T)]; + }; + + std::conditional_t, TrivialImpl, + PlacementImpl> + impl_; +}; + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_NO_DESTRUCTOR_H_ From 585c96f41018126a6145d14b9536d418bbfdc12a Mon Sep 17 00:00:00 2001 From: jcking Date: Wed, 2 Mar 2022 21:45:50 +0000 Subject: [PATCH 013/111] Internal change PiperOrigin-RevId: 432018873 --- base/memory_manager.h | 30 ++++++++++++++---------------- base/memory_manager_test.cc | 9 +++++++++ 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/base/memory_manager.h b/base/memory_manager.h index a02b318c6..893785d44 100644 --- a/base/memory_manager.h +++ b/base/memory_manager.h @@ -94,19 +94,17 @@ class ManagedMemory final { std::swap(align_, other.align_); } - constexpr T& get() ABSL_ATTRIBUTE_LIFETIME_BOUND { return *ptr_; } - - constexpr const T& get() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return *ptr_; } - - constexpr T& operator*() ABSL_ATTRIBUTE_LIFETIME_BOUND { return get(); } - - constexpr const T& operator*() const ABSL_ATTRIBUTE_LIFETIME_BOUND { - return get(); + constexpr T& get() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_ASSERT(static_cast(*this)); + return *ptr_; } - constexpr T* operator->() { return ptr_; } + constexpr T& operator*() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return get(); } - constexpr const T* operator->() const { return ptr_; } + constexpr T* operator->() const { + ABSL_ASSERT(static_cast(*this)); + return ptr_; + } constexpr explicit operator bool() const { return ptr_ != nullptr; } @@ -126,22 +124,22 @@ class ManagedMemory final { }; template -bool operator==(const ManagedMemory& lhs, std::nullptr_t) { - return lhs.get() == nullptr; +constexpr bool operator==(const ManagedMemory& lhs, std::nullptr_t) { + return !static_cast(lhs); } template -bool operator==(std::nullptr_t, const ManagedMemory& rhs) { - return rhs.get() == nullptr; +constexpr bool operator==(std::nullptr_t, const ManagedMemory& rhs) { + return !static_cast(rhs); } template -bool operator!=(const ManagedMemory& lhs, std::nullptr_t) { +constexpr bool operator!=(const ManagedMemory& lhs, std::nullptr_t) { return !operator==(lhs, nullptr); } template -bool operator!=(std::nullptr_t, const ManagedMemory& rhs) { +constexpr bool operator!=(std::nullptr_t, const ManagedMemory& rhs) { return !operator==(nullptr, rhs); } diff --git a/base/memory_manager_test.cc b/base/memory_manager_test.cc index dc4b4f7df..854c5c49b 100644 --- a/base/memory_manager_test.cc +++ b/base/memory_manager_test.cc @@ -26,6 +26,8 @@ struct TriviallyDestructible final {}; TEST(GlobalMemoryManager, TriviallyDestructible) { EXPECT_TRUE(std::is_trivially_destructible_v); auto managed = MemoryManager::Global().New(); + EXPECT_NE(managed, nullptr); + EXPECT_NE(nullptr, managed); } struct NotTriviallyDestuctible final { @@ -37,8 +39,15 @@ struct NotTriviallyDestuctible final { TEST(GlobalMemoryManager, NotTriviallyDestuctible) { EXPECT_FALSE(std::is_trivially_destructible_v); auto managed = MemoryManager::Global().New(); + EXPECT_NE(managed, nullptr); + EXPECT_NE(nullptr, managed); EXPECT_CALL(*managed, Delete()); } +TEST(ManagedMemory, Null) { + EXPECT_EQ(ManagedMemory(), nullptr); + EXPECT_EQ(nullptr, ManagedMemory()); +} + } // namespace } // namespace cel From f851cc2fc444ff91dc035d0219093b33f1297818 Mon Sep 17 00:00:00 2001 From: jcking Date: Fri, 4 Mar 2022 17:46:37 +0000 Subject: [PATCH 014/111] Internal change PiperOrigin-RevId: 432468397 --- base/internal/BUILD | 3 +- base/internal/memory_manager.post.h | 45 +++++++++++++++++++ ...{memory_manager.h => memory_manager.pre.h} | 20 +++++++-- base/memory_manager.h | 35 +++++++++------ 4 files changed, 86 insertions(+), 17 deletions(-) create mode 100644 base/internal/memory_manager.post.h rename base/internal/{memory_manager.h => memory_manager.pre.h} (63%) diff --git a/base/internal/BUILD b/base/internal/BUILD index d32a4ef19..ac6f4237c 100644 --- a/base/internal/BUILD +++ b/base/internal/BUILD @@ -19,7 +19,8 @@ licenses(["notice"]) cc_library( name = "memory_manager", textual_hdrs = [ - "memory_manager.h", + "memory_manager.pre.h", + "memory_manager.post.h", ], ) diff --git a/base/internal/memory_manager.post.h b/base/internal/memory_manager.post.h new file mode 100644 index 000000000..dde3e425a --- /dev/null +++ b/base/internal/memory_manager.post.h @@ -0,0 +1,45 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private + +#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MEMORY_MANAGER_POST_H_ +#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MEMORY_MANAGER_POST_H_ + +namespace cel::base_internal { + +template +constexpr size_t GetManagedMemorySize(const ManagedMemory& managed_memory) { + return managed_memory.size_; +} + +template +constexpr size_t GetManagedMemoryAlignment( + const ManagedMemory& managed_memory) { + return managed_memory.align_; +} + +template +constexpr T* ManagedMemoryRelease(ManagedMemory& managed_memory) { + // Like ManagedMemory::release except there is no assert. For use during + // handle creation. + T* ptr = managed_memory.ptr_; + managed_memory.ptr_ = nullptr; + managed_memory.size_ = managed_memory.align_ = 0; + return ptr; +} + +} // namespace cel::base_internal + +#endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MEMORY_MANAGER_POST_H_ diff --git a/base/internal/memory_manager.h b/base/internal/memory_manager.pre.h similarity index 63% rename from base/internal/memory_manager.h rename to base/internal/memory_manager.pre.h index eec6c6dc3..aeda27995 100644 --- a/base/internal/memory_manager.h +++ b/base/internal/memory_manager.pre.h @@ -14,18 +14,32 @@ // IWYU pragma: private -#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MEMORY_MANAGER_H_ -#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MEMORY_MANAGER_H_ +#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MEMORY_MANAGER_PRE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MEMORY_MANAGER_PRE_H_ #include #include namespace cel { +template +class ManagedMemory; class MemoryManager; namespace base_internal { +class Resource; + +template +constexpr size_t GetManagedMemorySize(const ManagedMemory& managed_memory); + +template +constexpr size_t GetManagedMemoryAlignment( + const ManagedMemory& managed_memory); + +template +constexpr T* ManagedMemoryRelease(ManagedMemory& managed_memory); + template class MemoryManagerDestructor final { private: @@ -38,4 +52,4 @@ class MemoryManagerDestructor final { } // namespace cel -#endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MEMORY_MANAGER_H_ +#endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MEMORY_MANAGER_PRE_H_ diff --git a/base/memory_manager.h b/base/memory_manager.h index 893785d44..85f796ef4 100644 --- a/base/memory_manager.h +++ b/base/memory_manager.h @@ -23,7 +23,7 @@ #include "absl/base/attributes.h" #include "absl/base/macros.h" #include "absl/base/optimization.h" -#include "base/internal/memory_manager.h" // IWYU pragma: export +#include "base/internal/memory_manager.pre.h" // IWYU pragma: export namespace cel { @@ -76,10 +76,7 @@ class ManagedMemory final { T* release() { ABSL_ASSERT(size_ == 0); - T* ptr = ptr_; - ptr_ = nullptr; - size_ = align_ = 0; - return ptr; + return base_internal::ManagedMemoryRelease(*this); } void reset() { @@ -94,22 +91,31 @@ class ManagedMemory final { std::swap(align_, other.align_); } - constexpr T& get() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + constexpr T* get() const { return ptr_; } + + constexpr T& operator*() const ABSL_ATTRIBUTE_LIFETIME_BOUND { ABSL_ASSERT(static_cast(*this)); - return *ptr_; + return *get(); } - constexpr T& operator*() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return get(); } - constexpr T* operator->() const { ABSL_ASSERT(static_cast(*this)); - return ptr_; + return get(); } - constexpr explicit operator bool() const { return ptr_ != nullptr; } + constexpr explicit operator bool() const { return get() != nullptr; } private: friend class MemoryManager; + template + friend constexpr size_t base_internal::GetManagedMemorySize( + const ManagedMemory& managed_memory); + template + friend constexpr size_t base_internal::GetManagedMemoryAlignment( + const ManagedMemory& managed_memory); + template + friend constexpr F* base_internal::ManagedMemoryRelease( + ManagedMemory& managed_memory); constexpr ManagedMemory(T* ptr, size_t size, size_t align) : ptr_(ptr), size_(size), align_(align) {} @@ -154,8 +160,8 @@ class MemoryManager { // Allocates and constructs `T`. In the event of an allocation failure nullptr // is returned. template - std::enable_if_t, ManagedMemory> New( - Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND ABSL_MUST_USE_RESULT { + ManagedMemory New(Args&&... args) + ABSL_ATTRIBUTE_LIFETIME_BOUND ABSL_MUST_USE_RESULT { size_t size = sizeof(T); size_t align = alignof(T); void* pointer = AllocateInternal(size, align); @@ -183,6 +189,7 @@ class MemoryManager { template friend class ManagedMemory; friend class ArenaMemoryManager; + friend class base_internal::Resource; // Only for use by ArenaMemoryManager. explicit MemoryManager(bool allocation_only) @@ -294,4 +301,6 @@ class ArenaMemoryManager : public MemoryManager { } // namespace cel +#include "base/internal/memory_manager.post.h" // IWYU pragma: export + #endif // THIRD_PARTY_CEL_CPP_BASE_MEMORY_MANAGER_H_ From 760b424bb34b316802e4ec8e7521fdd843109d82 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Fri, 4 Mar 2022 23:01:11 +0000 Subject: [PATCH 015/111] Add cel_number.h with utilities for cross numeric comparisons. Use it for implementations of comparison operations. PiperOrigin-RevId: 432540100 --- eval/public/BUILD | 21 ++ eval/public/cel_number.cc | 30 +++ eval/public/cel_number.h | 243 ++++++++++++++++++++ eval/public/cel_number_test.cc | 144 ++++++++++++ eval/public/comparison_functions.cc | 333 ++++------------------------ 5 files changed, 481 insertions(+), 290 deletions(-) create mode 100644 eval/public/cel_number.cc create mode 100644 eval/public/cel_number.h create mode 100644 eval/public/cel_number_test.cc diff --git a/eval/public/BUILD b/eval/public/BUILD index 23788808d..3f73536ce 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -239,6 +239,7 @@ cc_library( ":cel_builtins", ":cel_function_adapter", ":cel_function_registry", + ":cel_number", ":cel_options", ":cel_value", "//eval/eval:mutable_list_impl", @@ -881,3 +882,23 @@ cc_test( "@com_google_protobuf//:protobuf", ], ) + +cc_library( + name = "cel_number", + srcs = ["cel_number.cc"], + hdrs = ["cel_number.h"], + deps = [ + ":cel_value", + "@com_google_absl//absl/types:variant", + ], +) + +cc_test( + name = "cel_number_test", + srcs = ["cel_number_test.cc"], + deps = [ + ":cel_number", + "//internal:testing", + "@com_google_absl//absl/types:optional", + ], +) diff --git a/eval/public/cel_number.cc b/eval/public/cel_number.cc new file mode 100644 index 000000000..8527ba9e7 --- /dev/null +++ b/eval/public/cel_number.cc @@ -0,0 +1,30 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/cel_number.h" + +#include "eval/public/cel_value.h" + +namespace google::api::expr::runtime { +absl::optional GetNumberFromCelValue(const CelValue& value) { + if (int64_t val; value.GetValue(&val)) { + return CelNumber(val); + } else if (uint64_t val; value.GetValue(&val)) { + return CelNumber(val); + } else if (double val; value.GetValue(&val)) { + return CelNumber(val); + } + return absl::nullopt; +} +} // namespace google::api::expr::runtime diff --git a/eval/public/cel_number.h b/eval/public/cel_number.h new file mode 100644 index 000000000..e4a6a91d4 --- /dev/null +++ b/eval/public/cel_number.h @@ -0,0 +1,243 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_NUMERIC_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_NUMERIC_H_ + +#include +#include + +#include "absl/types/variant.h" +#include "eval/public/cel_value.h" + +namespace google::api::expr::runtime { + +constexpr int64_t kInt64Max = std::numeric_limits::max(); +constexpr int64_t kInt64Min = std::numeric_limits::lowest(); +constexpr uint64_t kUint64Max = std::numeric_limits::max(); +constexpr uint64_t kUintToIntMax = static_cast(kInt64Max); +constexpr double kDoubleToIntMax = static_cast(kInt64Max); +constexpr double kDoubleToIntMin = static_cast(kInt64Min); +constexpr double kDoubleToUintMax = static_cast(kUint64Max); + +namespace internal { + +using NumberVariant = absl::variant; + +enum class ComparisonResult { + kLesser, + kEqual, + kGreater, + // Special case for nan. + kNanInequal +}; + +// Return the inverse relation (i.e. Invert(cmp(b, a)) is the same as cmp(a, b). +constexpr ComparisonResult Invert(ComparisonResult result) { + switch (result) { + case ComparisonResult::kLesser: + return ComparisonResult::kGreater; + case ComparisonResult::kGreater: + return ComparisonResult::kLesser; + case ComparisonResult::kEqual: + return ComparisonResult::kEqual; + case ComparisonResult::kNanInequal: + return ComparisonResult::kNanInequal; + } +} + +template +struct ConversionVisitor { + template + constexpr OutType operator()(InType v) { + return static_cast(v); + } +}; + +template +constexpr ComparisonResult Compare(T a, T b) { + return (a > b) ? ComparisonResult::kGreater + : (a == b) ? ComparisonResult::kEqual + : ComparisonResult::kLesser; +} + +constexpr ComparisonResult DoubleCompare(double a, double b) { + // constexpr friendly isnan check. + if (!(a == a) || !(b == b)) { + return ComparisonResult::kNanInequal; + } + return Compare(a, b); +} + +// Implement generic numeric comparison against double value. +struct DoubleCompareVisitor { + constexpr explicit DoubleCompareVisitor(double v) : v(v) {} + + constexpr ComparisonResult operator()(double other) const { + return DoubleCompare(v, other); + } + + constexpr ComparisonResult operator()(uint64_t other) const { + if (v > kDoubleToUintMax) { + return ComparisonResult::kGreater; + } else if (v < 0) { + return ComparisonResult::kLesser; + } else { + return DoubleCompare(v, static_cast(other)); + } + } + + constexpr ComparisonResult operator()(int64_t other) const { + if (v > kDoubleToIntMax) { + return ComparisonResult::kGreater; + } else if (v < kDoubleToIntMin) { + return ComparisonResult::kLesser; + } else { + return DoubleCompare(v, static_cast(other)); + } + } + double v; +}; + +// Implement generic numeric comparison against uint value. +// Delegates to double comparison if either variable is double. +struct UintCompareVisitor { + constexpr explicit UintCompareVisitor(uint64_t v) : v(v) {} + + constexpr ComparisonResult operator()(double other) const { + return Invert(DoubleCompareVisitor(other)(v)); + } + + constexpr ComparisonResult operator()(uint64_t other) const { + return Compare(v, other); + } + + constexpr ComparisonResult operator()(int64_t other) const { + if (v > kUintToIntMax || other < 0) { + return ComparisonResult::kGreater; + } else { + return Compare(v, static_cast(other)); + } + } + uint64_t v; +}; + +// Implement generic numeric comparison against int value. +// Delegates to uint / double if either value is uint / double. +struct IntCompareVisitor { + constexpr explicit IntCompareVisitor(int64_t v) : v(v) {} + + constexpr ComparisonResult operator()(double other) { + return Invert(DoubleCompareVisitor(other)(v)); + } + + constexpr ComparisonResult operator()(uint64_t other) { + return Invert(UintCompareVisitor(other)(v)); + } + + constexpr ComparisonResult operator()(int64_t other) { + return Compare(v, other); + } + int64_t v; +}; + +struct CompareVisitor { + explicit constexpr CompareVisitor(NumberVariant rhs) : rhs(rhs) {} + + constexpr ComparisonResult operator()(double v) { + return absl::visit(DoubleCompareVisitor(v), rhs); + } + + constexpr ComparisonResult operator()(uint64_t v) { + return absl::visit(UintCompareVisitor(v), rhs); + } + + constexpr ComparisonResult operator()(int64_t v) { + return absl::visit(IntCompareVisitor(v), rhs); + } + NumberVariant rhs; +}; + +} // namespace internal + +// Utility class for CEL number operations. +// +// In CEL expressions, comparisons between differnet numeric types are treated +// as all happening on the same continuous number line. This generally means +// that integers and doubles in convertible range are compared after converting +// to doubles (tolerating some loss of precision). +// +// This extends to key lookups -- {1: 'abc'}[1.0f] is expected to work since +// 1.0 == 1 in CEL. +class CelNumber { + public: + // Factories to resolove ambiguous overload resolutions. + // int literals can't be resolved against the constructor overloads. + static constexpr CelNumber FromInt64(int64_t value) { + return CelNumber(value); + } + static constexpr CelNumber FromUint64(uint64_t value) { + return CelNumber(value); + } + static constexpr CelNumber FromDouble(double value) { + return CelNumber(value); + } + + constexpr explicit CelNumber(double double_value) : value_(double_value) {} + constexpr explicit CelNumber(int64_t int_value) : value_(int_value) {} + constexpr explicit CelNumber(uint64_t uint_value) : value_(uint_value) {} + + constexpr bool operator<(CelNumber other) const { + return Compare(other) == internal::ComparisonResult::kLesser; + } + + constexpr bool operator<=(CelNumber other) const { + internal::ComparisonResult cmp = Compare(other); + return cmp != internal::ComparisonResult::kGreater && + cmp != internal::ComparisonResult::kNanInequal; + } + + constexpr bool operator>(CelNumber other) const { + return Compare(other) == internal::ComparisonResult::kGreater; + } + + constexpr bool operator>=(CelNumber other) const { + internal::ComparisonResult cmp = Compare(other); + return cmp != internal::ComparisonResult::kLesser && + cmp != internal::ComparisonResult::kNanInequal; + } + + constexpr bool operator==(CelNumber other) const { + return Compare(other) == internal::ComparisonResult::kEqual; + } + + constexpr bool operator!=(CelNumber other) const { + return Compare(other) != internal::ComparisonResult::kEqual; + } + + private: + internal::NumberVariant value_; + + constexpr internal::ComparisonResult Compare(CelNumber other) const { + return absl::visit(internal::CompareVisitor(other.value_), value_); + } +}; + +// Return a CelNumber if the value holds a numeric type, otherwise return +// nullopt. +absl::optional GetNumberFromCelValue(const CelValue& value); + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_NUMERIC_H_ diff --git a/eval/public/cel_number_test.cc b/eval/public/cel_number_test.cc new file mode 100644 index 000000000..77b8f44da --- /dev/null +++ b/eval/public/cel_number_test.cc @@ -0,0 +1,144 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/cel_number.h" + +#include + +#include "absl/types/optional.h" +#include "internal/testing.h" + +namespace google::api::expr::runtime { +namespace { + +using testing::Optional; + +constexpr double kNan = std::numeric_limits::quiet_NaN(); +constexpr double kInfinity = std::numeric_limits::infinity(); + +static_assert(CelNumber(1.0f) == CelNumber::FromInt64(1), "double == int"); +static_assert(CelNumber(1.0f) == CelNumber::FromUint64(1), "double == uint"); +static_assert(CelNumber(1.0f) == CelNumber(1.0f), "double == double"); +static_assert(CelNumber::FromInt64(1) == CelNumber::FromInt64(1), "int == int"); +static_assert(CelNumber::FromInt64(1) == CelNumber::FromUint64(1), + "int == uint"); +static_assert(CelNumber::FromInt64(1) == CelNumber(1.0f), "int == double"); +static_assert(CelNumber::FromUint64(1) == CelNumber::FromInt64(1), + "uint == int"); +static_assert(CelNumber::FromUint64(1) == CelNumber::FromUint64(1), + "uint == uint"); +static_assert(CelNumber::FromUint64(1) == CelNumber(1.0f), "uint == double"); + +static_assert(CelNumber(1.0f) >= CelNumber::FromInt64(1), "double >= int"); +static_assert(CelNumber(1.0f) >= CelNumber::FromUint64(1), "double >= uint"); +static_assert(CelNumber(1.0f) >= CelNumber(1.0f), "double >= double"); +static_assert(CelNumber::FromInt64(1) >= CelNumber::FromInt64(1), "int >= int"); +static_assert(CelNumber::FromInt64(1) >= CelNumber::FromUint64(1), + "int >= uint"); +static_assert(CelNumber::FromInt64(1) >= CelNumber(1.0f), "int >= double"); +static_assert(CelNumber::FromUint64(1) >= CelNumber::FromInt64(1), + "uint >= int"); +static_assert(CelNumber::FromUint64(1) >= CelNumber::FromUint64(1), + "uint >= uint"); +static_assert(CelNumber::FromUint64(1) >= CelNumber(1.0f), "uint >= double"); + +static_assert(CelNumber(1.0f) <= CelNumber::FromInt64(1), "double <= int"); +static_assert(CelNumber(1.0f) <= CelNumber::FromUint64(1), "double <= uint"); +static_assert(CelNumber(1.0f) <= CelNumber(1.0f), "double <= double"); +static_assert(CelNumber::FromInt64(1) <= CelNumber::FromInt64(1), "int <= int"); +static_assert(CelNumber::FromInt64(1) <= CelNumber::FromUint64(1), + "int <= uint"); +static_assert(CelNumber::FromInt64(1) <= CelNumber(1.0f), "int <= double"); +static_assert(CelNumber::FromUint64(1) <= CelNumber::FromInt64(1), + "uint <= int"); +static_assert(CelNumber::FromUint64(1) <= CelNumber::FromUint64(1), + "uint <= uint"); +static_assert(CelNumber::FromUint64(1) <= CelNumber(1.0f), "uint <= double"); + +static_assert(CelNumber(1.5f) > CelNumber::FromInt64(1), "double > int"); +static_assert(CelNumber(1.5f) > CelNumber::FromUint64(1), "double > uint"); +static_assert(CelNumber(1.5f) > CelNumber(1.0f), "double > double"); +static_assert(CelNumber::FromInt64(2) > CelNumber::FromInt64(1), "int > int"); +static_assert(CelNumber::FromInt64(2) > CelNumber::FromUint64(1), "int > uint"); +static_assert(CelNumber::FromInt64(2) > CelNumber(1.5f), "int > double"); +static_assert(CelNumber::FromUint64(2) > CelNumber::FromInt64(1), "uint > int"); +static_assert(CelNumber::FromUint64(2) > CelNumber::FromUint64(1), + "uint > uint"); +static_assert(CelNumber::FromUint64(2) > CelNumber(1.5f), "uint > double"); + +static_assert(CelNumber(1.0f) < CelNumber::FromInt64(2), "double < int"); +static_assert(CelNumber(1.0f) < CelNumber::FromUint64(2), "double < uint"); +static_assert(CelNumber(1.0f) < CelNumber(1.1f), "double < double"); +static_assert(CelNumber::FromInt64(1) < CelNumber::FromInt64(2), "int < int"); +static_assert(CelNumber::FromInt64(1) < CelNumber::FromUint64(2), "int < uint"); +static_assert(CelNumber::FromInt64(1) < CelNumber(1.5f), "int < double"); +static_assert(CelNumber::FromUint64(1) < CelNumber::FromInt64(2), "uint < int"); +static_assert(CelNumber::FromUint64(1) < CelNumber::FromUint64(2), + "uint < uint"); +static_assert(CelNumber::FromUint64(1) < CelNumber(1.5f), "uint < double"); + +static_assert(CelNumber(kNan) != CelNumber(kNan), "nan != nan"); +static_assert(!(CelNumber(kNan) == CelNumber(kNan)), "nan == nan"); +static_assert(!(CelNumber(kNan) > CelNumber(kNan)), "nan > nan"); +static_assert(!(CelNumber(kNan) < CelNumber(kNan)), "nan < nan"); +static_assert(!(CelNumber(kNan) >= CelNumber(kNan)), "nan >= nan"); +static_assert(!(CelNumber(kNan) <= CelNumber(kNan)), "nan <= nan"); + +static_assert(CelNumber(kNan) != CelNumber::FromInt64(1), "nan != int"); +static_assert(!(CelNumber(kNan) == CelNumber::FromInt64(1)), "nan == int"); +static_assert(!(CelNumber(kNan) > CelNumber::FromInt64(1)), "nan > int"); +static_assert(!(CelNumber(kNan) < CelNumber::FromInt64(1)), "nan < int"); +static_assert(!(CelNumber(kNan) >= CelNumber::FromInt64(1)), "nan >= int"); +static_assert(!(CelNumber(kNan) <= CelNumber::FromInt64(1)), "nan <= int"); + +static_assert(!(CelNumber(kInfinity) != CelNumber(kInfinity)), "inf != inf"); +static_assert(CelNumber(kInfinity) == CelNumber(kInfinity), "inf == inf"); +static_assert(!(CelNumber(kInfinity) > CelNumber(kInfinity)), "inf > inf"); +static_assert(!(CelNumber(kInfinity) < CelNumber(kInfinity)), "inf < inf"); +static_assert(CelNumber(kInfinity) >= CelNumber(kInfinity), "inf >= inf"); +static_assert(CelNumber(kInfinity) <= CelNumber(kInfinity), "inf <= inf"); + +static_assert(CelNumber(kInfinity) != CelNumber::FromInt64(1), "inf != int"); +static_assert(!(CelNumber(kInfinity) == CelNumber::FromInt64(1)), "inf == int"); +static_assert(CelNumber(kInfinity) > CelNumber::FromInt64(1), "inf > int"); +static_assert(!(CelNumber(kInfinity) < CelNumber::FromInt64(1)), "inf < int"); +static_assert(CelNumber(kInfinity) >= CelNumber::FromInt64(1), "inf >= int"); +static_assert(!(CelNumber(kInfinity) <= CelNumber::FromInt64(1)), "inf <= int"); + +TEST(CelNumber, Basic) { + EXPECT_GT(CelNumber(1.1), CelNumber::FromInt64(1)); + EXPECT_LT(CelNumber::FromUint64(1), CelNumber(1.1)); + EXPECT_EQ(CelNumber(1.1), CelNumber(1.1)); + + EXPECT_EQ(CelNumber::FromUint64(1), CelNumber::FromUint64(1)); + EXPECT_EQ(CelNumber::FromInt64(1), CelNumber::FromUint64(1)); + EXPECT_GT(CelNumber::FromUint64(1), CelNumber::FromInt64(-1)); + + EXPECT_EQ(CelNumber::FromInt64(-1), CelNumber::FromInt64(-1)); +} + +TEST(CelNumber, GetNumberFromCelValue) { + EXPECT_THAT(GetNumberFromCelValue(CelValue::CreateDouble(1.1)), + Optional(CelNumber::FromDouble(1.1))); + EXPECT_THAT(GetNumberFromCelValue(CelValue::CreateInt64(1)), + Optional(CelNumber::FromDouble(1.0))); + EXPECT_THAT(GetNumberFromCelValue(CelValue::CreateUint64(1)), + Optional(CelNumber::FromDouble(1.0))); + + EXPECT_EQ(GetNumberFromCelValue(CelValue::CreateDuration(absl::Seconds(1))), + absl::nullopt); +} + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/public/comparison_functions.cc b/eval/public/comparison_functions.cc index 1f1d900b3..2c03b01bc 100644 --- a/eval/public/comparison_functions.cc +++ b/eval/public/comparison_functions.cc @@ -35,6 +35,7 @@ #include "eval/public/cel_builtins.h" #include "eval/public/cel_function_adapter.h" #include "eval/public/cel_function_registry.h" +#include "eval/public/cel_number.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/containers/container_backed_list_impl.h" @@ -53,14 +54,6 @@ namespace { using ::google::protobuf::Arena; using ::google::protobuf::util::MessageDifferencer; -constexpr int64_t kInt64Max = std::numeric_limits::max(); -constexpr int64_t kInt64Min = std::numeric_limits::lowest(); -constexpr uint64_t kUint64Max = std::numeric_limits::max(); -constexpr uint64_t kUintToIntMax = static_cast(kInt64Max); -constexpr double kDoubleToIntMax = static_cast(kInt64Max); -constexpr double kDoubleToIntMin = static_cast(kInt64Min); -constexpr double kDoubleToUintMax = static_cast(kUint64Max); - // Forward declaration of the functors for generic equality operator. // Equal only defined for same-typed values. struct HomogenousEqualProvider { @@ -165,147 +158,24 @@ bool GreaterThanOrEqual(Arena*, absl::Time t1, absl::Time t2) { return absl::operator>=(t1, t2); } -inline int32_t CompareDouble(double d1, double d2) { - double cmp = d1 - d2; - return cmp < 0 ? -1 : cmp > 0 ? 1 : 0; -} - -int32_t CompareDoubleInt(double d, int64_t i) { - if (d < kDoubleToIntMin) { - return -1; - } - if (d > kDoubleToIntMax) { - return 1; - } - return CompareDouble(d, static_cast(i)); -} - -inline int32_t CompareIntDouble(int64_t i, double d) { - return -CompareDoubleInt(d, i); -} - -int32_t CompareDoubleUint(double d, uint64_t u) { - if (d < 0.0) { - return -1; - } - if (d > kDoubleToUintMax) { - return 1; - } - return CompareDouble(d, static_cast(u)); -} - -inline int32_t CompareUintDouble(uint64_t u, double d) { - return -CompareDoubleUint(d, u); -} - -int32_t CompareIntUint(int64_t i, uint64_t u) { - if (i < 0 || u > kUintToIntMax) { - return -1; - } - // Note, the type conversion cannot overflow as the overflow condition is - // checked earlier as part of the special case comparison. - int64_t cmp = i - static_cast(u); - return cmp < 0 ? -1 : cmp > 0 ? 1 : 0; -} - -inline int32_t CompareUintInt(uint64_t u, int64_t i) { - return -CompareIntUint(i, u); -} - -bool LessThanDoubleInt(Arena*, double d, int64_t i) { - return CompareDoubleInt(d, i) == -1; -} - -bool LessThanIntDouble(Arena*, int64_t i, double d) { - return CompareIntDouble(i, d) == -1; -} - -bool LessThanDoubleUint(Arena*, double d, uint64_t u) { - return CompareDoubleInt(d, u) == -1; -} - -bool LessThanUintDouble(Arena*, uint64_t u, double d) { - return CompareIntDouble(u, d) == -1; -} - -bool LessThanIntUint(Arena*, int64_t i, uint64_t u) { - return CompareIntUint(i, u) == -1; -} - -bool LessThanUintInt(Arena*, uint64_t u, int64_t i) { - return CompareUintInt(u, i) == -1; -} - -bool LessThanOrEqualDoubleInt(Arena*, double d, int64_t i) { - return CompareDoubleInt(d, i) <= 0; -} - -bool LessThanOrEqualIntDouble(Arena*, int64_t i, double d) { - return CompareIntDouble(i, d) <= 0; -} - -bool LessThanOrEqualDoubleUint(Arena*, double d, uint64_t u) { - return CompareDoubleInt(d, u) <= 0; -} - -bool LessThanOrEqualUintDouble(Arena*, uint64_t u, double d) { - return CompareIntDouble(u, d) <= 0; -} - -bool LessThanOrEqualIntUint(Arena*, int64_t i, uint64_t u) { - return CompareIntUint(i, u) <= 0; -} - -bool LessThanOrEqualUintInt(Arena*, uint64_t u, int64_t i) { - return CompareUintInt(u, i) <= 0; -} - -bool GreaterThanDoubleInt(Arena*, double d, int64_t i) { - return CompareDoubleInt(d, i) == 1; -} - -bool GreaterThanIntDouble(Arena*, int64_t i, double d) { - return CompareIntDouble(i, d) == 1; -} - -bool GreaterThanDoubleUint(Arena*, double d, uint64_t u) { - return CompareDoubleInt(d, u) == 1; -} - -bool GreaterThanUintDouble(Arena*, uint64_t u, double d) { - return CompareIntDouble(u, d) == 1; -} - -bool GreaterThanIntUint(Arena*, int64_t i, uint64_t u) { - return CompareIntUint(i, u) == 1; -} - -bool GreaterThanUintInt(Arena*, uint64_t u, int64_t i) { - return CompareUintInt(u, i) == 1; -} - -bool GreaterThanOrEqualDoubleInt(Arena*, double d, int64_t i) { - return CompareDoubleInt(d, i) >= 0; -} - -bool GreaterThanOrEqualIntDouble(Arena*, int64_t i, double d) { - return CompareIntDouble(i, d) >= 0; -} - -bool GreaterThanOrEqualDoubleUint(Arena*, double d, uint64_t u) { - return CompareDoubleInt(d, u) >= 0; +template +bool CrossNumericLessThan(Arena* arena, T t, U u) { + return CelNumber(t) < CelNumber(u); } -bool GreaterThanOrEqualUintDouble(Arena*, uint64_t u, double d) { - return CompareIntDouble(u, d) >= 0; +template +bool CrossNumericGreaterThan(Arena* arena, T t, U u) { + return CelNumber(t) > CelNumber(u); } -bool GreaterThanOrEqualIntUint(Arena*, int64_t i, uint64_t u) { - return CompareIntUint(i, u) >= 0; +template +bool CrossNumericLessOrEqualTo(Arena* arena, T t, U u) { + return CelNumber(t) <= CelNumber(u); } -bool GreaterThanOrEqualUintInt(Arena*, uint64_t u, int64_t i) { - return CompareUintInt(u, i) >= 0; +template +bool CrossNumericGreaterOrEqualTo(Arena* arena, T t, U u) { + return CelNumber(t) >= CelNumber(u); } bool MessageNullEqual(Arena* arena, const google::protobuf::Message* t1, @@ -603,6 +473,23 @@ CelValue GeneralizedInequal(Arena* arena, CelValue t1, CelValue t2) { return CreateNoMatchingOverloadError(arena, builtin::kInequal); } +template +absl::Status RegisterCrossNumericComparisons(CelFunctionRegistry* registry) { + CEL_RETURN_IF_ERROR((FunctionAdapter::CreateAndRegister( + builtin::kLess, /*receiver_style=*/false, &CrossNumericLessThan, + registry))); + CEL_RETURN_IF_ERROR((FunctionAdapter::CreateAndRegister( + builtin::kGreater, /*receiver_style=*/false, + &CrossNumericGreaterThan, registry))); + CEL_RETURN_IF_ERROR((FunctionAdapter::CreateAndRegister( + builtin::kGreaterOrEqual, /*receiver_style=*/false, + &CrossNumericGreaterOrEqualTo, registry))); + CEL_RETURN_IF_ERROR((FunctionAdapter::CreateAndRegister( + builtin::kLessOrEqual, /*receiver_style=*/false, + &CrossNumericLessOrEqualTo, registry))); + return absl::OkStatus(); +} + absl::Status RegisterHeterogeneousComparisonFunctions( CelFunctionRegistry* registry) { CEL_RETURN_IF_ERROR( @@ -614,109 +501,20 @@ absl::Status RegisterHeterogeneousComparisonFunctions( builtin::kInequal, /*receiver_style=*/false, &GeneralizedInequal, registry))); - // Cross-type numeric less than operator - CEL_RETURN_IF_ERROR( - (FunctionAdapter::CreateAndRegister( - builtin::kLess, /*receiver_style=*/false, &LessThanDoubleInt, - registry))); - CEL_RETURN_IF_ERROR( - (FunctionAdapter::CreateAndRegister( - builtin::kLess, /*receiver_style=*/false, &LessThanDoubleUint, - registry))); - CEL_RETURN_IF_ERROR( - (FunctionAdapter::CreateAndRegister( - builtin::kLess, /*receiver_style=*/false, &LessThanIntUint, - registry))); - CEL_RETURN_IF_ERROR( - (FunctionAdapter::CreateAndRegister( - builtin::kLess, /*receiver_style=*/false, &LessThanIntDouble, - registry))); - CEL_RETURN_IF_ERROR( - (FunctionAdapter::CreateAndRegister( - builtin::kLess, /*receiver_style=*/false, &LessThanUintDouble, - registry))); - CEL_RETURN_IF_ERROR( - (FunctionAdapter::CreateAndRegister( - builtin::kLess, /*receiver_style=*/false, &LessThanUintInt, - registry))); - - // Cross-type numeric less than or equal operator - CEL_RETURN_IF_ERROR( - (FunctionAdapter::CreateAndRegister( - builtin::kLessOrEqual, /*receiver_style=*/false, - &LessThanOrEqualDoubleInt, registry))); CEL_RETURN_IF_ERROR( - (FunctionAdapter::CreateAndRegister( - builtin::kLessOrEqual, /*receiver_style=*/false, - &LessThanOrEqualDoubleUint, registry))); + (RegisterCrossNumericComparisons(registry))); CEL_RETURN_IF_ERROR( - (FunctionAdapter::CreateAndRegister( - builtin::kLessOrEqual, /*receiver_style=*/false, - &LessThanOrEqualIntUint, registry))); - CEL_RETURN_IF_ERROR( - (FunctionAdapter::CreateAndRegister( - builtin::kLessOrEqual, /*receiver_style=*/false, - &LessThanOrEqualIntDouble, registry))); - CEL_RETURN_IF_ERROR( - (FunctionAdapter::CreateAndRegister( - builtin::kLessOrEqual, /*receiver_style=*/false, - &LessThanOrEqualUintDouble, registry))); - CEL_RETURN_IF_ERROR( - (FunctionAdapter::CreateAndRegister( - builtin::kLessOrEqual, /*receiver_style=*/false, - &LessThanOrEqualUintInt, registry))); + (RegisterCrossNumericComparisons(registry))); - // Cross-type numeric greater than operator - CEL_RETURN_IF_ERROR( - (FunctionAdapter::CreateAndRegister( - builtin::kGreater, /*receiver_style=*/false, &GreaterThanDoubleInt, - registry))); - CEL_RETURN_IF_ERROR( - (FunctionAdapter::CreateAndRegister( - builtin::kGreater, /*receiver_style=*/false, &GreaterThanDoubleUint, - registry))); CEL_RETURN_IF_ERROR( - (FunctionAdapter::CreateAndRegister( - builtin::kGreater, /*receiver_style=*/false, &GreaterThanIntUint, - registry))); + (RegisterCrossNumericComparisons(registry))); CEL_RETURN_IF_ERROR( - (FunctionAdapter::CreateAndRegister( - builtin::kGreater, /*receiver_style=*/false, &GreaterThanIntDouble, - registry))); - CEL_RETURN_IF_ERROR( - (FunctionAdapter::CreateAndRegister( - builtin::kGreater, /*receiver_style=*/false, &GreaterThanUintDouble, - registry))); - CEL_RETURN_IF_ERROR( - (FunctionAdapter::CreateAndRegister( - builtin::kGreater, /*receiver_style=*/false, &GreaterThanUintInt, - registry))); + (RegisterCrossNumericComparisons(registry))); - // Cross-type numeric greater than or equal operator - CEL_RETURN_IF_ERROR( - (FunctionAdapter::CreateAndRegister( - builtin::kGreaterOrEqual, /*receiver_style=*/false, - &GreaterThanOrEqualDoubleInt, registry))); - CEL_RETURN_IF_ERROR( - (FunctionAdapter::CreateAndRegister( - builtin::kGreaterOrEqual, /*receiver_style=*/false, - &GreaterThanOrEqualDoubleUint, registry))); CEL_RETURN_IF_ERROR( - (FunctionAdapter::CreateAndRegister( - builtin::kGreaterOrEqual, /*receiver_style=*/false, - &GreaterThanOrEqualIntUint, registry))); + (RegisterCrossNumericComparisons(registry))); CEL_RETURN_IF_ERROR( - (FunctionAdapter::CreateAndRegister( - builtin::kGreaterOrEqual, /*receiver_style=*/false, - &GreaterThanOrEqualIntDouble, registry))); - CEL_RETURN_IF_ERROR( - (FunctionAdapter::CreateAndRegister( - builtin::kGreaterOrEqual, /*receiver_style=*/false, - &GreaterThanOrEqualUintDouble, registry))); - CEL_RETURN_IF_ERROR( - (FunctionAdapter::CreateAndRegister( - builtin::kGreaterOrEqual, /*receiver_style=*/false, - &GreaterThanOrEqualUintInt, registry))); + (RegisterCrossNumericComparisons(registry))); CEL_RETURN_IF_ERROR(RegisterOrderingFunctionsForType(registry)); CEL_RETURN_IF_ERROR(RegisterOrderingFunctionsForType(registry)); @@ -762,58 +560,13 @@ absl::optional CelValueEqualImpl(const CelValue& v1, const CelValue& v2) { v2.type() == CelValue::Type::kNullType) { return false; } - switch (v1.type()) { - case CelValue::Type::kDouble: { - double d; - v1.GetValue(&d); - if (std::isnan(d)) { - return false; - } - switch (v2.type()) { - case CelValue::Type::kInt64: - return CompareDoubleInt(d, v2.Int64OrDie()) == 0; - case CelValue::Type::kUint64: - return CompareDoubleUint(d, v2.Uint64OrDie()) == 0; - default: - return absl::nullopt; - } - } - case CelValue::Type::kInt64: - int64_t i; - v1.GetValue(&i); - switch (v2.type()) { - case CelValue::Type::kDouble: { - double d; - v2.GetValue(&d); - if (std::isnan(d)) { - return false; - } - return CompareIntDouble(i, d) == 0; - } - case CelValue::Type::kUint64: - return CompareIntUint(i, v2.Uint64OrDie()) == 0; - default: - return absl::nullopt; - } - case CelValue::Type::kUint64: - uint64_t u; - v1.GetValue(&u); - switch (v2.type()) { - case CelValue::Type::kDouble: { - double d; - v2.GetValue(&d); - if (std::isnan(d)) { - return false; - } - return CompareUintDouble(u, d) == 0; - } - case CelValue::Type::kInt64: - return CompareUintInt(u, v2.Int64OrDie()) == 0; - default: - return absl::nullopt; - } - default: - return absl::nullopt; + absl::optional lhs = GetNumberFromCelValue(v1); + absl::optional rhs = GetNumberFromCelValue(v2); + + if (rhs.has_value() && lhs.has_value()) { + return *lhs == *rhs; + } else { + return absl::nullopt; } } From 84b2760645eef525c021c70960c28b39c7756d69 Mon Sep 17 00:00:00 2001 From: jcking Date: Tue, 8 Mar 2022 17:10:46 +0000 Subject: [PATCH 016/111] Internal change PiperOrigin-RevId: 433227593 --- base/BUILD | 10 + base/handle.h | 496 ++++++++++++++++++++++++++++++++++++ base/internal/BUILD | 14 + base/internal/handle.post.h | 142 +++++++++++ base/internal/handle.pre.h | 178 +++++++++++++ 5 files changed, 840 insertions(+) create mode 100644 base/handle.h create mode 100644 base/internal/handle.post.h create mode 100644 base/internal/handle.pre.h diff --git a/base/BUILD b/base/BUILD index 01b9b3055..f6ee44c30 100644 --- a/base/BUILD +++ b/base/BUILD @@ -19,6 +19,16 @@ package( licenses(["notice"]) +cc_library( + name = "handle", + hdrs = ["handle.h"], + deps = [ + "//base/internal:handle", + "//internal:casts", + "@com_google_absl//absl/base:core_headers", + ], +) + cc_library( name = "kind", srcs = ["kind.cc"], diff --git a/base/handle.h b/base/handle.h new file mode 100644 index 000000000..f82b797ca --- /dev/null +++ b/base/handle.h @@ -0,0 +1,496 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_HANDLE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_HANDLE_H_ + +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/macros.h" +#include "base/internal/handle.pre.h" // IWYU pragma: export +#include "internal/casts.h" + +namespace cel { + +template +class Transient; + +template +class Persistent; + +// `Transient` is a handle that is intended to be short lived and may not +// actually own the referenced `T`. It is only valid as long as the handle it +// was created from or the native C++ value it is wrapping is valid. If you need +// to store a handle such that it can escape the current scope use `Persistent`. +template +class Transient final : private base_internal::HandlePolicy { + private: + using Traits = base_internal::TransientHandleTraits>; + using Handle = typename Traits::handle_type; + + public: + // Default constructs the handle, setting it to an empty state. It is + // undefined behavior to call any functions that attempt to dereference or + // access `T` when in an empty state. + Transient() = default; + + Transient(const Transient&) = default; + + template >> + Transient(const Transient& handle) : impl_(handle.impl_) {} // NOLINT + + Transient(Transient&&) = default; + + // Allow implicit conversion from Persistent to Transient, but not the other + // way around. This is analogous to implicit conversion from std::string to + // std::string_view. + Transient(const Persistent& handle); // NOLINT + + // Allow implicit conversion from Persistent to Transient, but not the other + // way around. This is analygous to implicit conversion from std::string to + // std::string_view. + template >> + Transient(const Persistent& handle); // NOLINT + + Transient& operator=(const Transient&) = default; + + template + std::enable_if_t, Transient&> // NOLINT + operator=(const Transient& handle) { + impl_ = handle.impl_; + return *this; + } + + Transient& operator=(Transient&&) = default; + + Transient& operator=(const Persistent& handle); + + // Same as the constructor above, but for the assign operator. + template + std::enable_if_t, Transient&> // NOLINT + operator=(const Persistent& handle); + + // Reinterpret the handle of type `T` as type `F`. `T` must be derived from + // `F`, `F` must be derived from `T`, or `F` must be the same as `T`. + // + // Persistent handle; + // handle.As()->SubMethod(); + template + std::enable_if_t< + std::disjunction_v, std::is_base_of, + std::is_same>, + Transient&> + As() ABSL_MUST_USE_RESULT { + static_assert(std::is_same_v::Handle>, + "Transient and Transient must have the same " + "implementation type"); + static_assert( + (std::is_const_v == std::is_const_v || std::is_const_v), + "Constness cannot be removed, only added using As()"); + ABSL_ASSERT(this->template Is>()); + // Persistent and Persistent have the same underlying layout + // representation, as ensured via the first static_assert, and they have + // compatible types such that F is the base of T or T is the base of F, as + // ensured via SFINAE on the return value and the second static_assert. Thus + // we can saftley reinterpret_cast. + return *reinterpret_cast*>(this); + } + + // Reinterpret the handle of type `T` as type `F`. `T` must be derived from + // `F`, `F` must be derived from `T`, or `F` must be the same as `T`. + // + // Persistent handle; + // handle.As()->SubMethod(); + template + std::enable_if_t< + std::disjunction_v, std::is_base_of, + std::is_same>, + const Transient&> + As() const ABSL_MUST_USE_RESULT { + static_assert(std::is_same_v::Handle>, + "Transient and Transient must have the same " + "implementation type"); + static_assert( + (std::is_const_v == std::is_const_v || std::is_const_v), + "Constness cannot be removed, only added using As()"); + ABSL_ASSERT(this->template Is>()); + // Persistent and Persistent have the same underlying layout + // representation, as ensured via the first static_assert, and they have + // compatible types such that F is the base of T or T is the base of F, as + // ensured via SFINAE on the return value and the second static_assert. Thus + // we can saftley reinterpret_cast. + return *reinterpret_cast*>(this); + } + + // Is checks wether `T` is an instance of `F`. + template + bool Is() const { + return impl_.template Is(); + } + + T& operator*() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_ASSERT(static_cast(*this)); + return internal::down_cast(*impl_); + } + + T* operator->() const { + ABSL_ASSERT(static_cast(*this)); + return internal::down_cast(impl_.operator->()); + } + + // Tests whether the handle is not empty, returning false if it is empty. + explicit operator bool() const { return static_cast(impl_); } + + friend void swap(Transient& lhs, Transient& rhs) { + std::swap(lhs.impl_, rhs.impl_); + } + + friend bool operator==(const Transient& lhs, const Transient& rhs) { + return lhs.impl_ == rhs.impl_; + } + + template + friend H AbslHashValue(H state, const Transient& handle) { + return H::combine(std::move(state), handle.impl_); + } + + private: + template + friend class Transient; + template + friend class Persistent; + template + friend struct base_internal::HandleFactory; + template + friend bool base_internal::IsManagedHandle(const Transient& handle); + template + friend bool base_internal::IsUnmanagedHandle(const Transient& handle); + template + friend bool base_internal::IsInlinedHandle(const Transient& handle); + + template + explicit Transient(base_internal::HandleInPlace, Args&&... args) + : impl_(std::forward(args)...) {} + + Handle impl_; +}; + +template +std::enable_if_t, bool> operator==( + const Transient& lhs, const Transient& rhs) { + return lhs == rhs.template As(); +} + +template +std::enable_if_t, bool> operator==( + const Transient& lhs, const Transient& rhs) { + return rhs == lhs.template As(); +} + +template +bool operator!=(const Transient& lhs, const Transient& rhs) { + return !operator==(lhs, rhs); +} + +template +std::enable_if_t, bool> operator!=( + const Transient& lhs, const Transient& rhs) { + return !operator==(lhs, rhs); +} + +template +std::enable_if_t, bool> operator!=( + const Transient& lhs, const Transient& rhs) { + return !operator==(lhs, rhs); +} + +// `Persistent` is a handle that is intended to be long lived and shares +// ownership of the referenced `T`. It is valid so long as +// there are 1 or more `Persistent` handles pointing to `T` and the +// `AllocationManager` that constructed it is alive. +template +class Persistent final : private base_internal::HandlePolicy { + private: + using Traits = base_internal::PersistentHandleTraits>; + using Handle = typename Traits::handle_type; + + public: + // Default constructs the handle, setting it to an empty state. It is + // undefined behavior to call any functions that attempt to dereference or + // access `T` when in an empty state. + Persistent() = default; + + Persistent(const Persistent&) = default; + + template >> + Persistent(const Persistent& handle) : impl_(handle.impl_) {} // NOLINT + + Persistent(Persistent&&) = default; + + template >> + Persistent(Persistent&& handle) // NOLINT + : impl_(std::move(handle.impl_)) {} + + // Allow Transient handles to be assigned to Persistent handles. This is + // similar to std::string_view being assignable to std::string. + explicit Persistent(Transient handle) : impl_(handle.impl_) {} + + // Allow Transient handles to be assigned to Persistent handles. This is + // similar to std::string_view being assignable to std::string. + template >> + explicit Persistent(Transient handle) : impl_(handle.impl_) {} + + Persistent& operator=(const Persistent&) = default; + + // Allow Transient handles to be assigned to Persistent handles. This is + // similar to std::string_view being assignable to std::string. + Persistent& operator=(Transient handle) { + impl_ = handle.impl_; + return *this; + } + + // Allow Transient handles to be assigned to Persistent handles. This is + // similar to std::string_view being assignable to std::string. + template + std::enable_if_t, Persistent&> // NOLINT + operator=(Transient handle) { + impl_ = handle.impl_; + return *this; + } + + Persistent& operator=(Persistent&&) = default; + + template + std::enable_if_t, Persistent&> // NOLINT + operator=(const Persistent& handle) { + impl_ = handle.impl_; + return *this; + } + + template + std::enable_if_t, Persistent&> // NOLINT + operator=(Persistent&& handle) { + impl_ = std::move(handle.impl_); + return *this; + } + + // Reinterpret the handle of type `T` as type `F`. `T` must be derived from + // `F`, `F` must be derived from `T`, or `F` must be the same as `T`. + // + // Persistent handle; + // handle.As()->SubMethod(); + template + std::enable_if_t< + std::disjunction_v, std::is_base_of, + std::is_same>, + Persistent&> + As() ABSL_MUST_USE_RESULT { + static_assert(std::is_same_v::Handle>, + "Persistent and Persistent must have the same " + "implementation type"); + static_assert( + (std::is_const_v == std::is_const_v || std::is_const_v), + "Constness cannot be removed, only added using As()"); + ABSL_ASSERT(this->template Is()); + // Persistent and Persistent have the same underlying layout + // representation, as ensured via the first static_assert, and they have + // compatible types such that F is the base of T or T is the base of F, as + // ensured via SFINAE on the return value and the second static_assert. Thus + // we can saftley reinterpret_cast. + return *reinterpret_cast*>(this); + } + + // Reinterpret the handle of type `T` as type `F`. `T` must be derived from + // `F`, `F` must be derived from `T`, or `F` must be the same as `T`. + // + // Persistent handle; + // handle.As()->SubMethod(); + template + std::enable_if_t< + std::disjunction_v, std::is_base_of, + std::is_same>, + const Persistent&> + As() const ABSL_MUST_USE_RESULT { + static_assert(std::is_same_v::Handle>, + "Persistent and Persistent must have the same " + "implementation type"); + static_assert( + (std::is_const_v == std::is_const_v || std::is_const_v), + "Constness cannot be removed, only added using As()"); + ABSL_ASSERT(this->template Is>()); + // Persistent and Persistent have the same underlying layout + // representation, as ensured via the first static_assert, and they have + // compatible types such that F is the base of T or T is the base of F, as + // ensured via SFINAE on the return value and the second static_assert. Thus + // we can saftley reinterpret_cast. + return *reinterpret_cast*>(this); + } + + // Is checks wether `T` is an instance of `F`. + template + bool Is() const { + return impl_.template Is(); + } + + T& operator*() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_ASSERT(static_cast(*this)); + return internal::down_cast(*impl_); + } + + T* operator->() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_ASSERT(static_cast(*this)); + return internal::down_cast(impl_.operator->()); + } + + // Tests whether the handle is not empty, returning false if it is empty. + explicit operator bool() const { return static_cast(impl_); } + + friend void swap(Persistent& lhs, Persistent& rhs) { + std::swap(lhs.impl_, rhs.impl_); + } + + friend bool operator==(const Persistent& lhs, const Persistent& rhs) { + return lhs.impl_ == rhs.impl_; + } + + friend bool operator==(const Transient& lhs, const Persistent& rhs) { + return lhs.impl_ == rhs.impl_; + } + + friend bool operator==(const Persistent& lhs, const Transient& rhs) { + return lhs.impl_ == rhs.impl_; + } + + template + friend H AbslHashValue(H state, const Persistent& handle) { + return H::combine(std::move(state), handle.impl_); + } + + private: + template + friend class Transient; + template + friend class Persistent; + template + friend struct base_internal::HandleFactory; + template + friend bool base_internal::IsManagedHandle(const Persistent& handle); + template + friend bool base_internal::IsUnmanagedHandle(const Persistent& handle); + template + friend bool base_internal::IsInlinedHandle(const Persistent& handle); + + template + explicit Persistent(base_internal::HandleInPlace, Args&&... args) + : impl_(std::forward(args)...) {} + + Handle impl_; +}; + +template +std::enable_if_t, bool> operator==( + const Persistent& lhs, const Persistent& rhs) { + return lhs == rhs.template As(); +} + +template +std::enable_if_t, bool> operator==( + const Persistent& lhs, const Persistent& rhs) { + return rhs == lhs.template As(); +} + +template +std::enable_if_t, bool> operator==( + const Transient& lhs, const Persistent& rhs) { + return lhs == rhs.template As(); +} + +template +std::enable_if_t, bool> operator==( + const Persistent& lhs, const Transient& rhs) { + return rhs == lhs.template As(); +} + +template +bool operator!=(const Persistent& lhs, const Persistent& rhs) { + return !operator==(lhs, rhs); +} + +template +bool operator!=(const Transient& lhs, const Persistent& rhs) { + return !operator==(lhs, rhs); +} + +template +bool operator!=(const Persistent& lhs, const Transient& rhs) { + return !operator==(lhs, rhs); +} + +template +std::enable_if_t, bool> operator!=( + const Persistent& lhs, const Persistent& rhs) { + return !operator==(lhs, rhs); +} + +template +std::enable_if_t, bool> operator!=( + const Persistent& lhs, const Persistent& rhs) { + return !operator==(lhs, rhs); +} + +template +std::enable_if_t, bool> operator!=( + const Transient& lhs, const Persistent& rhs) { + return !operator==(lhs, rhs); +} + +template +std::enable_if_t, bool> operator!=( + const Persistent& lhs, const Transient& rhs) { + return !operator==(lhs, rhs); +} + +template +Transient::Transient(const Persistent& handle) : impl_(handle.impl_) {} + +template +template +Transient::Transient(const Persistent& handle) : impl_(handle.impl_) {} + +template +Transient& Transient::operator=(const Persistent& handle) { + impl_ = handle.impl_; + return *this; +} + +template // NOLINT +template +std::enable_if_t, Transient&> +Transient::operator=(const Persistent& handle) { + impl_ = handle.impl_; + return *this; +} + +} // namespace cel + +#include "base/internal/handle.post.h" // IWYU pragma: export + +#endif // THIRD_PARTY_CEL_CPP_BASE_HANDLE_H_ diff --git a/base/internal/BUILD b/base/internal/BUILD index ac6f4237c..73bb7cb60 100644 --- a/base/internal/BUILD +++ b/base/internal/BUILD @@ -16,6 +16,20 @@ package(default_visibility = ["//visibility:public"]) licenses(["notice"]) +# These headers should only ever be used by ../handle.h. They are here to avoid putting +# large amounts of implementation details in public headers. +cc_library( + name = "handle", + textual_hdrs = [ + "handle.pre.h", + "handle.post.h", + ], + deps = [ + "//base:memory_manager", + "@com_google_absl//absl/base:core_headers", + ], +) + cc_library( name = "memory_manager", textual_hdrs = [ diff --git a/base/internal/handle.post.h b/base/internal/handle.post.h new file mode 100644 index 000000000..be57aed18 --- /dev/null +++ b/base/internal/handle.post.h @@ -0,0 +1,142 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "base/handle.h" + +#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_HANDLE_POST_H_ +#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_HANDLE_POST_H_ + +#include +#include + +#include "absl/base/optimization.h" +#include "base/memory_manager.h" + +namespace cel::base_internal { + +template +struct HandleFactory { + template + static Transient MakeInlined(Args&&... args) { + static_assert(std::is_base_of_v, "F is not derived from T"); + return Transient(kHandleInPlace, kInlinedResource, + std::forward(args)...); + } + + template + static Transient MakeUnmanaged(F& from) { + static_assert(std::is_base_of_v, "F is not derived from T"); + return Transient(kHandleInPlace, kUnmanagedResource, from); + } +}; + +template +struct HandleFactory { + // Constructs a persistent handle whose underlying object is stored in the + // handle itself. + template + static std::enable_if_t, Persistent> + Make(Args&&... args) { + static_assert(std::is_base_of_v, + "T is not derived from Resource"); + static_assert(std::is_base_of_v, "F is not derived from T"); + return Persistent(kHandleInPlace, kInlinedResource, + std::forward(args)...); + } + + // Constructs a persistent handle whose underlying object is heap allocated + // and potentially reference counted, depending on the memory manager + // implementation. + template + static std::enable_if_t>, + Persistent> + Make(MemoryManager& memory_manager, Args&&... args) { + static_assert(std::is_base_of_v, + "T is not derived from Resource"); + static_assert(std::is_base_of_v, "F is not derived from T"); +#if defined(__cpp_lib_is_pointer_interconvertible) && \ + __cpp_lib_is_pointer_interconvertible >= 201907L + // Only available in C++20. + static_assert(std::is_pointer_interconvertible_base_of_v, + "F must be pointer interconvertible to Resource"); +#endif + auto managed_memory = memory_manager.New(std::forward(args)...); + if (ABSL_PREDICT_FALSE(managed_memory == nullptr)) { + return Persistent(); + } + bool unmanaged = GetManagedMemorySize(managed_memory) == 0; +#ifndef NDEBUG + if (!unmanaged) { + // Ensure there is no funny business going on by asserting that the size + // and alignment are the same as F. + ABSL_ASSERT(GetManagedMemorySize(managed_memory) == sizeof(F)); + ABSL_ASSERT(GetManagedMemoryAlignment(managed_memory) == alignof(F)); + // Ensure that the implementation F has correctly overriden + // SizeAndAlignment(). + auto [size, align] = static_cast(managed_memory.get()) + ->SizeAndAlignment(); + ABSL_ASSERT(size == sizeof(F)); + ABSL_ASSERT(align == alignof(F)); + // Ensures that casting F to the most base class does not require + // thunking, which occurs when using multiple inheritance. If thunking is + // used our usage of memory manager will break. If you think you need + // thunking, please consult the CEL team. + ABSL_ASSERT(static_cast( + static_cast(managed_memory.get())) == + static_cast(managed_memory.get())); + } +#endif + // Convert ManagedMemory to Persistent, transferring reference + // counting responsibility to it when applicable. `unmanaged` is true when + // no reference counting is required. + return unmanaged ? Persistent(kHandleInPlace, kUnmanagedResource, + *ManagedMemoryRelease(managed_memory)) + : Persistent(kHandleInPlace, kManagedResource, + *ManagedMemoryRelease(managed_memory)); + } +}; + +template +bool IsManagedHandle(const Transient& handle) { + return handle.impl_.IsManaged(); +} + +template +bool IsUnmanagedHandle(const Transient& handle) { + return handle.impl_.IsUnmanaged(); +} + +template +bool IsInlinedHandle(const Transient& handle) { + return handle.impl_.IsInlined(); +} + +template +bool IsManagedHandle(const Persistent& handle) { + return handle.impl_.IsManaged(); +} + +template +bool IsUnmanagedHandle(const Persistent& handle) { + return handle.impl_.IsUnmanaged(); +} + +template +bool IsInlinedHandle(const Persistent& handle) { + return handle.impl_.IsInlined(); +} + +} // namespace cel::base_internal + +#endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_HANDLE_POST_H_ diff --git a/base/internal/handle.pre.h b/base/internal/handle.pre.h new file mode 100644 index 000000000..423142b58 --- /dev/null +++ b/base/internal/handle.pre.h @@ -0,0 +1,178 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "base/handle.h" + +#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_HANDLE_PRE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_HANDLE_PRE_H_ + +#include +#include +#include + +#include "base/memory_manager.h" + +namespace cel { + +class Type; +class Value; + +template +class Transient; +template +class Persistent; + +namespace base_internal { + +class TypeHandleBase; +class ValueHandleBase; + +// Enumeration of different types of handles. +enum class HandleType { + kTransient = 0, + kPersistent, +}; + +template +struct HandleTraits; + +// Convenient aliases. +template +using TransientHandleTraits = HandleTraits; +template +using PersistentHandleTraits = HandleTraits; + +template +struct HandleFactory; + +// Convenient aliases. +template +using TransientHandleFactory = HandleFactory; +template +using PersistentHandleFactory = HandleFactory; + +struct HandleInPlace { + explicit HandleInPlace() = default; +}; + +// Disambiguation tag used to select the appropriate constructor on Persistent +// and Transient. Think std::in_place. +inline constexpr HandleInPlace kHandleInPlace{}; + +// Virtual base class for all classes that can be managed by handles. +class Resource { + public: + virtual ~Resource() = default; + + Resource& operator=(const Resource&) = delete; + Resource& operator=(Resource&&) = delete; + + private: + friend class cel::Type; + friend class cel::Value; + friend class TypeHandleBase; + friend class ValueHandleBase; + template + friend struct HandleFactory; + + Resource() = default; + Resource(const Resource&) = default; + Resource(Resource&&) = default; + + // For non-inlined resources that are reference counted, this is the result of + // `sizeof` and `alignof` for the most derived class. + virtual std::pair SizeAndAlignment() const = 0; + + // Called by TypeHandleBase, ValueHandleBase, Type, and Value for reference + // counting. + void Ref() const { + auto [size, align] = SizeAndAlignment(); + MemoryManager::Ref(this, size, align); + } + + // Called by TypeHandleBase, ValueHandleBase, Type, and Value for reference + // counting. + void Unref() const { + auto [size, align] = SizeAndAlignment(); + MemoryManager::Unref(this, size, align); + } +}; + +// Non-virtual base class for all classes that can be stored inline in handles. +// This is primarily used with SFINAE. +class ResourceInlined {}; + +template +struct InlinedResource { + explicit InlinedResource() = default; +}; + +// Disambiguation tag used to select the appropriate constructor in the handle +// implementation. Think std::in_place. +template +inline constexpr InlinedResource kInlinedResource{}; + +template +struct ManagedResource { + explicit ManagedResource() = default; +}; + +// Disambiguation tag used to select the appropriate constructor in the handle +// implementation. Think std::in_place. +template +inline constexpr ManagedResource kManagedResource{}; + +template +struct UnmanagedResource { + explicit UnmanagedResource() = default; +}; + +// Disambiguation tag used to select the appropriate constructor in the handle +// implementation. Think std::in_place. +template +inline constexpr UnmanagedResource kUnmanagedResource{}; + +// Non-virtual base class enforces type requirements via static_asserts for +// types used with handles. +template +struct HandlePolicy { + static_assert(!std::is_reference_v, "Handles do not support references"); + static_assert(!std::is_pointer_v, "Handles do not support pointers"); + static_assert(std::is_class_v, "Handles only support classes"); + static_assert(!std::is_volatile_v, "Handles do not support volatile"); + static_assert((std::is_base_of_v> && + !std::is_same_v> && + !std::is_same_v>), + "Handles do not support this type"); +}; + +template +bool IsManagedHandle(const Transient& handle); +template +bool IsUnmanagedHandle(const Transient& handle); +template +bool IsInlinedHandle(const Transient& handle); + +template +bool IsManagedHandle(const Persistent& handle); +template +bool IsUnmanagedHandle(const Persistent& handle); +template +bool IsInlinedHandle(const Persistent& handle); + +} // namespace base_internal + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_HANDLE_PRE_H_ From 61593803b00f61f7d85843d39aa9b8ff6df5bd8c Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Tue, 8 Mar 2022 22:17:26 +0000 Subject: [PATCH 017/111] Update cel_number.h with utilities for cross numeric conversions for key lookups. PiperOrigin-RevId: 433308944 --- eval/public/cel_number.h | 63 ++++++++++++++++++++++++++++++++++ eval/public/cel_number_test.cc | 37 ++++++++++++++++++++ 2 files changed, 100 insertions(+) diff --git a/eval/public/cel_number.h b/eval/public/cel_number.h index e4a6a91d4..54c76a057 100644 --- a/eval/public/cel_number.h +++ b/eval/public/cel_number.h @@ -17,6 +17,7 @@ #include #include +#include #include "absl/types/variant.h" #include "eval/public/cel_value.h" @@ -31,6 +32,19 @@ constexpr double kDoubleToIntMax = static_cast(kInt64Max); constexpr double kDoubleToIntMin = static_cast(kInt64Min); constexpr double kDoubleToUintMax = static_cast(kUint64Max); +// The highest integer values that are round-trippable after rounding and +// casting to double. +template +constexpr int RoundingError() { + return 1 << (std::numeric_limits::digits - + std::numeric_limits::digits - 1); +} + +constexpr double kMaxDoubleRepresentableAsInt = + static_cast(kInt64Max - RoundingError()); +constexpr double kMaxDoubleRepresentableAsUint = + static_cast(kUint64Max - RoundingError()); + namespace internal { using NumberVariant = absl::variant; @@ -169,6 +183,26 @@ struct CompareVisitor { NumberVariant rhs; }; +struct LosslessConvertibleToIntVisitor { + constexpr bool operator()(double value) const { + return value >= kDoubleToIntMin && value <= kMaxDoubleRepresentableAsInt && + value == static_cast(static_cast(value)); + } + constexpr bool operator()(uint64_t value) const { + return value <= kUintToIntMax; + } + constexpr bool operator()(int64_t value) const { return true; } +}; + +struct LosslessConvertibleToUintVisitor { + constexpr bool operator()(double value) const { + return value >= 0 && value <= kMaxDoubleRepresentableAsUint && + value == static_cast(static_cast(value)); + } + constexpr bool operator()(uint64_t value) const { return true; } + constexpr bool operator()(int64_t value) const { return value >= 0; } +}; + } // namespace internal // Utility class for CEL number operations. @@ -198,6 +232,35 @@ class CelNumber { constexpr explicit CelNumber(int64_t int_value) : value_(int_value) {} constexpr explicit CelNumber(uint64_t uint_value) : value_(uint_value) {} + // Return a double representation of the value. + constexpr double AsDouble() const { + return absl::visit(internal::ConversionVisitor(), value_); + } + + // Return signed int64_t representation for the value. + // Caller must guarantee the underlying value is representatble as an + // int. + constexpr int64_t AsInt() const { + return absl::visit(internal::ConversionVisitor(), value_); + } + + // Return unsigned int64_t representation for the value. + // Caller must guarantee the underlying value is representable as an + // uint. + constexpr uint64_t AsUint() const { + return absl::visit(internal::ConversionVisitor(), value_); + } + + // For key lookups, check if the conversion to signed int is lossless. + constexpr bool LosslessConvertibleToInt() const { + return absl::visit(internal::LosslessConvertibleToIntVisitor(), value_); + } + + // For key lookups, check if the conversion to unsigned int is lossless. + constexpr bool LosslessConvertibleToUint() const { + return absl::visit(internal::LosslessConvertibleToUintVisitor(), value_); + } + constexpr bool operator<(CelNumber other) const { return Compare(other) == internal::ComparisonResult::kLesser; } diff --git a/eval/public/cel_number_test.cc b/eval/public/cel_number_test.cc index 77b8f44da..9a9855216 100644 --- a/eval/public/cel_number_test.cc +++ b/eval/public/cel_number_test.cc @@ -14,6 +14,7 @@ #include "eval/public/cel_number.h" +#include #include #include "absl/types/optional.h" @@ -140,5 +141,41 @@ TEST(CelNumber, GetNumberFromCelValue) { absl::nullopt); } +TEST(CelNumber, Conversions) { + EXPECT_TRUE(CelNumber::FromDouble(1.0).LosslessConvertibleToInt()); + EXPECT_TRUE(CelNumber::FromDouble(1.0).LosslessConvertibleToUint()); + EXPECT_FALSE(CelNumber::FromDouble(1.1).LosslessConvertibleToInt()); + EXPECT_FALSE(CelNumber::FromDouble(1.1).LosslessConvertibleToUint()); + EXPECT_TRUE(CelNumber::FromDouble(-1.0).LosslessConvertibleToInt()); + EXPECT_FALSE(CelNumber::FromDouble(-1.0).LosslessConvertibleToUint()); + EXPECT_TRUE( + CelNumber::FromDouble(kDoubleToIntMin).LosslessConvertibleToInt()); + + // Need to add/substract a large number since double resolution is low at this + // range. + static_assert(CelNumber::FromDouble(kMaxDoubleRepresentableAsUint + + RoundingError()) != + CelNumber::FromDouble(kMaxDoubleRepresentableAsUint)); + EXPECT_FALSE(CelNumber::FromDouble(kMaxDoubleRepresentableAsUint + + RoundingError()) + .LosslessConvertibleToUint()); + static_assert(CelNumber::FromDouble(kMaxDoubleRepresentableAsInt + + RoundingError()) != + CelNumber::FromDouble(kMaxDoubleRepresentableAsInt)); + EXPECT_FALSE(CelNumber::FromDouble(kMaxDoubleRepresentableAsInt + + RoundingError()) + .LosslessConvertibleToInt()); + static_assert(CelNumber::FromDouble(kDoubleToIntMin - + (2 * RoundingError() + 1)) != + CelNumber::FromDouble(kDoubleToIntMin)); + EXPECT_FALSE( + CelNumber::FromDouble(kDoubleToIntMin - 1025).LosslessConvertibleToInt()); + + EXPECT_EQ(CelNumber::FromInt64(1).AsUint(), 1u); + EXPECT_EQ(CelNumber::FromUint64(1).AsInt(), 1); + EXPECT_EQ(CelNumber::FromDouble(1.0).AsUint(), 1); + EXPECT_EQ(CelNumber::FromDouble(1.0).AsInt(), 1); +} + } // namespace } // namespace google::api::expr::runtime From 143ce2209fff1dc6bc2a5e303ae10560df1b6936 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 9 Mar 2022 01:13:27 +0000 Subject: [PATCH 018/111] Add support for cross numeric lookups in CEL C++ evaluator. PiperOrigin-RevId: 433349120 --- eval/compiler/flat_expr_builder.cc | 2 +- eval/compiler/flat_expr_builder.h | 9 + eval/compiler/flat_expr_builder_test.cc | 41 +++- eval/eval/BUILD | 8 + eval/eval/container_access_step.cc | 90 +++++--- eval/eval/container_access_step_test.cc | 270 +++++++++++++++++++++++- eval/eval/evaluator_core.cc | 3 +- eval/eval/evaluator_core.h | 15 +- eval/eval/evaluator_core_test.cc | 12 +- eval/public/cel_expr_builder_factory.cc | 2 + 10 files changed, 415 insertions(+), 37 deletions(-) diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index 69a494d80..a2a50e1f1 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -1060,7 +1060,7 @@ FlatExprBuilder::CreateExpressionImpl( comprehension_max_iterations_, std::move(iter_variable_names), enable_unknowns_, enable_unknown_function_results_, enable_missing_attribute_errors_, enable_null_coercion_, - std::move(rewrite_buffer)); + enable_heterogeneous_equality_, std::move(rewrite_buffer)); if (warnings != nullptr) { *warnings = std::move(warnings_builder).warnings(); diff --git a/eval/compiler/flat_expr_builder.h b/eval/compiler/flat_expr_builder.h index 993672309..9094c0c98 100644 --- a/eval/compiler/flat_expr_builder.h +++ b/eval/compiler/flat_expr_builder.h @@ -48,6 +48,7 @@ class FlatExprBuilder : public CelExpressionBuilder { enable_comprehension_vulnerability_check_(false), enable_null_coercion_(true), enable_wrapper_type_null_unboxing_(false), + enable_heterogeneous_equality_(false), descriptor_pool_(descriptor_pool), message_factory_(message_factory) {} @@ -141,6 +142,13 @@ class FlatExprBuilder : public CelExpressionBuilder { enable_wrapper_type_null_unboxing_ = enabled; } + // If enable_heterogeneous_equality is enabled, the evaluator will use + // hetergeneous equality semantics. This includes the == operator and numeric + // index lookups in containers. + void set_enable_heterogeneous_equality(bool enabled) { + enable_heterogeneous_equality_ = enabled; + } + absl::StatusOr> CreateExpression( const google::api::expr::v1alpha1::Expr* expr, const google::api::expr::v1alpha1::SourceInfo* source_info) const override; @@ -179,6 +187,7 @@ class FlatExprBuilder : public CelExpressionBuilder { bool enable_comprehension_vulnerability_check_; bool enable_null_coercion_; bool enable_wrapper_type_null_unboxing_; + bool enable_heterogeneous_equality_; const google::protobuf::DescriptorPool* descriptor_pool_; google::protobuf::MessageFactory* message_factory_; diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index be684a7c9..148ce8e71 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -58,11 +58,6 @@ namespace google::api::expr::runtime { namespace { -using google::api::expr::v1alpha1::CheckedExpr; -using google::api::expr::v1alpha1::Expr; -using google::api::expr::v1alpha1::ParsedExpr; -using google::api::expr::v1alpha1::SourceInfo; - using testing::Eq; using testing::HasSubstr; using cel::internal::StatusIs; @@ -1575,6 +1570,42 @@ TEST(FlatExprBuilderTest, NullUnboxingDisabled) { EXPECT_THAT(result, test::IsCelInt64(0)); } +TEST(FlatExprBuilderTest, HeterogeneousEqualityEnabled) { + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, + parser::Parse("{1: 2, 2u: 3}[1.0]")); + FlatExprBuilder builder; + builder.set_enable_heterogeneous_equality(true); + ASSERT_OK_AND_ASSIGN(auto expression, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + Activation activation; + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(CelValue result, + expression->Evaluate(activation, &arena)); + + EXPECT_THAT(result, test::IsCelInt64(2)); +} + +TEST(FlatExprBuilderTest, HeterogeneousEqualityDisabled) { + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, + parser::Parse("{1: 2, 2u: 3}[1.0]")); + FlatExprBuilder builder; + builder.set_enable_heterogeneous_equality(false); + ASSERT_OK_AND_ASSIGN(auto expression, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + Activation activation; + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(CelValue result, + expression->Evaluate(activation, &arena)); + + EXPECT_THAT(result, + test::IsCelError(StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid map key type")))); +} + TEST(FlatExprBuilderTest, CustomDescriptorPoolForCreateStruct) { ASSERT_OK_AND_ASSIGN( ParsedExpr parsed_expr, diff --git a/eval/eval/BUILD b/eval/eval/BUILD index b918ed5de..22a67c7fa 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -103,6 +103,7 @@ cc_library( ":evaluator_core", ":expression_step_base", "//base:memory_manager", + "//eval/public:cel_number", "//eval/public:cel_value", "//eval/public:unknown_attribute_set", "@com_google_absl//absl/status", @@ -366,15 +367,22 @@ cc_test( ":container_access_step", ":ident_step", "//eval/public:activation", + "//eval/public:builtin_func_registrar", "//eval/public:cel_attribute", "//eval/public:cel_builtins", + "//eval/public:cel_expr_builder_factory", + "//eval/public:cel_expression", + "//eval/public:cel_options", "//eval/public:cel_value", "//eval/public/containers:container_backed_list_impl", "//eval/public/containers:container_backed_map_impl", "//eval/public/structs:cel_proto_wrapper", + "//eval/public/testing:matchers", "//internal:status_macros", "//internal:testing", + "//parser", "@com_google_absl//absl/status", + "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) diff --git a/eval/eval/container_access_step.cc b/eval/eval/container_access_step.cc index 51ba17ac8..cc0bdcb66 100644 --- a/eval/eval/container_access_step.cc +++ b/eval/eval/container_access_step.cc @@ -8,6 +8,7 @@ #include "base/memory_manager.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" +#include "eval/public/cel_number.h" #include "eval/public/cel_value.h" #include "eval/public/unknown_attribute_set.h" @@ -30,44 +31,83 @@ class ContainerAccessStep : public ExpressionStepBase { ValueAttributePair PerformLookup(ExecutionFrame* frame) const; CelValue LookupInMap(const CelMap* cel_map, const CelValue& key, - cel::MemoryManager& manager) const; + ExecutionFrame* frame) const; CelValue LookupInList(const CelList* cel_list, const CelValue& key, - cel::MemoryManager& manager) const; + ExecutionFrame* frame) const; }; -inline CelValue ContainerAccessStep::LookupInMap( - const CelMap* cel_map, const CelValue& key, - cel::MemoryManager& manager) const { - auto status = CelValue::CheckMapKeyType(key); +inline CelValue ContainerAccessStep::LookupInMap(const CelMap* cel_map, + const CelValue& key, + ExecutionFrame* frame) const { + if (frame->enable_heterogeneous_numeric_lookups()) { + // Double isn't a supported key type but may be convertible to an integer. + absl::optional number = GetNumberFromCelValue(key); + if (number.has_value()) { + // consider uint as uint first then try coercion. + if (key.IsUint64()) { + absl::optional maybe_value = (*cel_map)[key]; + if (maybe_value.has_value()) { + return *maybe_value; + } + } + if (number->LosslessConvertibleToInt()) { + absl::optional maybe_value = + (*cel_map)[CelValue::CreateInt64(number->AsInt())]; + if (maybe_value.has_value()) { + return *maybe_value; + } + } + if (number->LosslessConvertibleToUint()) { + absl::optional maybe_value = + (*cel_map)[CelValue::CreateUint64(number->AsUint())]; + if (maybe_value.has_value()) { + return *maybe_value; + } + } + return CreateNoSuchKeyError(frame->memory_manager(), + "Key not found in map"); + } + } + + absl::Status status = CelValue::CheckMapKeyType(key); if (!status.ok()) { - return CreateErrorValue(manager, status); + return CreateErrorValue(frame->memory_manager(), status); } absl::optional maybe_value = (*cel_map)[key]; if (maybe_value.has_value()) { return maybe_value.value(); } - return CreateNoSuchKeyError(manager, "Key not found in map"); + + return CreateNoSuchKeyError(frame->memory_manager(), "Key not found in map"); } -inline CelValue ContainerAccessStep::LookupInList( - const CelList* cel_list, const CelValue& key, - cel::MemoryManager& manager) const { - switch (key.type()) { - case CelValue::Type::kInt64: { - int64_t idx = key.Int64OrDie(); - if (idx < 0 || idx >= cel_list->size()) { - return CreateErrorValue(manager, - absl::StrCat("Index error: index=", idx, - " size=", cel_list->size())); - } - return (*cel_list)[idx]; +inline CelValue ContainerAccessStep::LookupInList(const CelList* cel_list, + const CelValue& key, + ExecutionFrame* frame) const { + absl::optional maybe_idx; + if (frame->enable_heterogeneous_numeric_lookups()) { + auto number = GetNumberFromCelValue(key); + if (number.has_value() && number->LosslessConvertibleToInt()) { + maybe_idx = number->AsInt(); } - default: { + } else if (int64_t held_int; key.GetValue(&held_int)) { + maybe_idx = held_int; + } + + if (maybe_idx.has_value()) { + int64_t idx = *maybe_idx; + if (idx < 0 || idx >= cel_list->size()) { return CreateErrorValue( - manager, absl::StrCat("Index error: expected integer type, got ", - CelValue::TypeName(key.type()))); + frame->memory_manager(), + absl::StrCat("Index error: index=", idx, " size=", cel_list->size())); } + return (*cel_list)[idx]; } + + return CreateErrorValue( + frame->memory_manager(), + absl::StrCat("Index error: expected integer type, got ", + CelValue::TypeName(key.type()))); } ContainerAccessStep::ValueAttributePair ContainerAccessStep::PerformLookup( @@ -113,11 +153,11 @@ ContainerAccessStep::ValueAttributePair ContainerAccessStep::PerformLookup( switch (container.type()) { case CelValue::Type::kMap: { const CelMap* cel_map = container.MapOrDie(); - return {LookupInMap(cel_map, key, frame->memory_manager()), trail}; + return {LookupInMap(cel_map, key, frame), trail}; } case CelValue::Type::kList: { const CelList* cel_list = container.ListOrDie(); - return {LookupInList(cel_list, key, frame->memory_manager()), trail}; + return {LookupInList(cel_list, key, frame), trail}; } default: { auto error = diff --git a/eval/eval/container_access_step_test.cc b/eval/eval/container_access_step_test.cc index 2af1f9ce6..5a8c9f2e5 100644 --- a/eval/eval/container_access_step_test.cc +++ b/eval/eval/container_access_step_test.cc @@ -1,22 +1,31 @@ #include "eval/eval/container_access_step.h" +#include #include #include #include +#include "google/api/expr/v1alpha1/syntax.pb.h" #include "google/protobuf/struct.pb.h" +#include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "absl/status/status.h" #include "eval/eval/ident_step.h" #include "eval/public/activation.h" +#include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_attribute.h" #include "eval/public/cel_builtins.h" +#include "eval/public/cel_expr_builder_factory.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/containers/container_backed_list_impl.h" #include "eval/public/containers/container_backed_map_impl.h" #include "eval/public/structs/cel_proto_wrapper.h" +#include "eval/public/testing/matchers.h" #include "internal/status_macros.h" #include "internal/testing.h" +#include "parser/parser.h" namespace google::api::expr::runtime { @@ -25,6 +34,7 @@ namespace { using ::google::api::expr::v1alpha1::Expr; using ::google::api::expr::v1alpha1::SourceInfo; using ::google::protobuf::Struct; +using testing::_; using testing::HasSubstr; using cel::internal::StatusIs; @@ -317,7 +327,265 @@ TEST_F(ContainerAccessStepTest, TestInvalidContainerType) { INSTANTIATE_TEST_SUITE_P(CombinedContainerTest, ContainerAccessStepUniformityTest, - testing::Combine(testing::Bool(), testing::Bool())); + testing::Combine(/*receiver_style*/ testing::Bool(), + /*unknown_enabled*/ testing::Bool())); + +class ContainerAccessHeterogeneousLookupsTest : public testing::Test { + public: + ContainerAccessHeterogeneousLookupsTest() { + options_.enable_heterogeneous_equality = true; + builder_ = CreateCelExpressionBuilder(options_); + } + + protected: + InterpreterOptions options_; + std::unique_ptr builder_; + google::protobuf::Arena arena_; + Activation activation_; +}; + +TEST_F(ContainerAccessHeterogeneousLookupsTest, DoubleMapKeyInt) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1: 2}[1.0]")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelInt64(2)); +} + +TEST_F(ContainerAccessHeterogeneousLookupsTest, DoubleMapKeyNotAnInt) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1: 2}[1.1]")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelError(_)); +} + +TEST_F(ContainerAccessHeterogeneousLookupsTest, DoubleMapKeyUint) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1u: 2u}[1.0]")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelUint64(2)); +} + +TEST_F(ContainerAccessHeterogeneousLookupsTest, DoubleListIndex) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("[1, 2, 3][1.0]")); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelInt64(2)); +} + +TEST_F(ContainerAccessHeterogeneousLookupsTest, DoubleListIndexNotAnInt) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("[1, 2, 3][1.1]")); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelError(_)); +} + +// treat uint as uint before trying coercion to signed int. +TEST_F(ContainerAccessHeterogeneousLookupsTest, UintKeyAsUint) { + // TODO(issues/5): Map creation should error here instead of permitting + // mixed key types with equivalent values. + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1u: 2u, 1: 2}[1u]")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelUint64(2)); +} + +TEST_F(ContainerAccessHeterogeneousLookupsTest, UintKeyAsInt) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1: 2}[1u]")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelInt64(2)); +} + +TEST_F(ContainerAccessHeterogeneousLookupsTest, IntKeyAsUint) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1u: 2u}[1]")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelUint64(2)); +} + +TEST_F(ContainerAccessHeterogeneousLookupsTest, UintListIndex) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("[1, 2, 3][2u]")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelInt64(3)); +} + +TEST_F(ContainerAccessHeterogeneousLookupsTest, StringKeyUnaffected) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1: 2, '1': 3}['1']")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelInt64(3)); +} + +class ContainerAccessHeterogeneousLookupsDisabledTest : public testing::Test { + public: + ContainerAccessHeterogeneousLookupsDisabledTest() { + builder_ = CreateCelExpressionBuilder(options_); + } + + protected: + InterpreterOptions options_; + std::unique_ptr builder_; + google::protobuf::Arena arena_; + Activation activation_; +}; + +TEST_F(ContainerAccessHeterogeneousLookupsDisabledTest, DoubleMapKeyInt) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1: 2}[1.0]")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelError(_)); +} + +TEST_F(ContainerAccessHeterogeneousLookupsDisabledTest, DoubleMapKeyNotAnInt) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1: 2}[1.1]")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelError(_)); +} + +TEST_F(ContainerAccessHeterogeneousLookupsDisabledTest, DoubleMapKeyUint) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1u: 2u}[1.0]")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelError(_)); +} + +TEST_F(ContainerAccessHeterogeneousLookupsDisabledTest, DoubleListIndex) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("[1, 2, 3][1.0]")); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelError(_)); +} + +TEST_F(ContainerAccessHeterogeneousLookupsDisabledTest, + DoubleListIndexNotAnInt) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("[1, 2, 3][1.1]")); + + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelError(_)); +} + +TEST_F(ContainerAccessHeterogeneousLookupsDisabledTest, UintKeyAsUint) { + // TODO(issues/5): Map creation should error here instead of permitting + // mixed key types with equivalent values. + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1u: 2u, 1: 2}[1u]")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelUint64(2)); +} + +TEST_F(ContainerAccessHeterogeneousLookupsDisabledTest, UintKeyAsInt) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1: 2}[1u]")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelError(_)); +} + +TEST_F(ContainerAccessHeterogeneousLookupsDisabledTest, IntKeyAsUint) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1u: 2u}[1]")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelError(_)); +} + +TEST_F(ContainerAccessHeterogeneousLookupsDisabledTest, UintListIndex) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("[1, 2, 3][2u]")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelError(_)); +} + +TEST_F(ContainerAccessHeterogeneousLookupsDisabledTest, StringKeyUnaffected) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1: 2, '1': 3}['1']")); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( + &expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation_, &arena_)); + + EXPECT_THAT(result, test::IsCelInt64(3)); +} } // namespace diff --git a/eval/eval/evaluator_core.cc b/eval/eval/evaluator_core.cc index 92b13aca3..febbad54d 100644 --- a/eval/eval/evaluator_core.cc +++ b/eval/eval/evaluator_core.cc @@ -156,7 +156,8 @@ absl::StatusOr CelExpressionFlatImpl::Trace( ExecutionFrame frame(path_, activation, descriptor_pool_, message_factory_, max_iterations_, state, enable_unknowns_, enable_unknown_function_results_, - enable_missing_attribute_errors_, enable_null_coercion_); + enable_missing_attribute_errors_, enable_null_coercion_, + enable_heterogeneous_equality_); EvaluatorStack* stack = &frame.value_stack(); size_t initial_stack_size = stack->size(); diff --git a/eval/eval/evaluator_core.h b/eval/eval/evaluator_core.h index a59e87a75..7f3308c6f 100644 --- a/eval/eval/evaluator_core.h +++ b/eval/eval/evaluator_core.h @@ -124,7 +124,8 @@ class ExecutionFrame { CelExpressionFlatEvaluationState* state, bool enable_unknowns, bool enable_unknown_function_results, bool enable_missing_attribute_errors, - bool enable_null_coercion) + bool enable_null_coercion, + bool enable_heterogeneous_numeric_lookups) : pc_(0UL), execution_path_(flat), activation_(activation), @@ -134,6 +135,8 @@ class ExecutionFrame { enable_unknown_function_results_(enable_unknown_function_results), enable_missing_attribute_errors_(enable_missing_attribute_errors), enable_null_coercion_(enable_null_coercion), + enable_heterogeneous_numeric_lookups_( + enable_heterogeneous_numeric_lookups), attribute_utility_(&activation.unknown_attribute_patterns(), &activation.missing_attribute_patterns(), state->memory_manager()), @@ -168,6 +171,10 @@ class ExecutionFrame { bool enable_null_coercion() const { return enable_null_coercion_; } + bool enable_heterogeneous_numeric_lookups() const { + return enable_heterogeneous_numeric_lookups_; + } + cel::MemoryManager& memory_manager() { return state_->memory_manager(); } const google::protobuf::DescriptorPool* descriptor_pool() const { @@ -240,6 +247,7 @@ class ExecutionFrame { bool enable_unknown_function_results_; bool enable_missing_attribute_errors_; bool enable_null_coercion_; + bool enable_heterogeneous_numeric_lookups_; AttributeUtility attribute_utility_; const int max_iterations_; int iterations_; @@ -265,6 +273,7 @@ class CelExpressionFlatImpl : public CelExpression { bool enable_unknown_function_results = false, bool enable_missing_attribute_errors = false, bool enable_null_coercion = true, + bool enable_heterogeneous_equality = false, std::unique_ptr rewritten_expr = nullptr) : rewritten_expr_(std::move(rewritten_expr)), path_(std::move(path)), @@ -275,7 +284,8 @@ class CelExpressionFlatImpl : public CelExpression { enable_unknowns_(enable_unknowns), enable_unknown_function_results_(enable_unknown_function_results), enable_missing_attribute_errors_(enable_missing_attribute_errors), - enable_null_coercion_(enable_null_coercion) {} + enable_null_coercion_(enable_null_coercion), + enable_heterogeneous_equality_(enable_heterogeneous_equality) {} // Move-only CelExpressionFlatImpl(const CelExpressionFlatImpl&) = delete; @@ -316,6 +326,7 @@ class CelExpressionFlatImpl : public CelExpression { bool enable_unknown_function_results_; bool enable_missing_attribute_errors_; bool enable_null_coercion_; + bool enable_heterogeneous_equality_; }; } // namespace google::api::expr::runtime diff --git a/eval/eval/evaluator_core_test.cc b/eval/eval/evaluator_core_test.cc index 58946d38a..61728e73a 100644 --- a/eval/eval/evaluator_core_test.cc +++ b/eval/eval/evaluator_core_test.cc @@ -72,7 +72,11 @@ TEST(EvaluatorCoreTest, ExecutionFrameNext) { ExecutionFrame frame(path, activation, google::protobuf::DescriptorPool::generated_pool(), google::protobuf::MessageFactory::generated_factory(), 0, &state, - false, false, false, true); + /*enable_unknowns=*/false, + /*enable_unknown_funcion_results=*/false, + /*enable_missing_attribute_errors=*/false, + /*enable_null_coercion=*/true, + /*enable_heterogeneous_numeric_lookups=*/true); EXPECT_THAT(frame.Next(), Eq(path[0].get())); EXPECT_THAT(frame.Next(), Eq(path[1].get())); @@ -94,7 +98,11 @@ TEST(EvaluatorCoreTest, ExecutionFrameSetGetClearVar) { ExecutionFrame frame(path, activation, google::protobuf::DescriptorPool::generated_pool(), google::protobuf::MessageFactory::generated_factory(), 0, &state, - false, false, false, true); + /*enable_unknowns=*/false, + /*enable_unknown_funcion_results=*/false, + /*enable_missing_attribute_errors=*/false, + /*enable_null_coercion=*/true, + /*enable_heterogeneous_numeric_lookups=*/true); CelValue original = CelValue::CreateInt64(test_value); Expr ident; diff --git a/eval/public/cel_expr_builder_factory.cc b/eval/public/cel_expr_builder_factory.cc index 54d51fc5c..b349a37d2 100644 --- a/eval/public/cel_expr_builder_factory.cc +++ b/eval/public/cel_expr_builder_factory.cc @@ -178,6 +178,8 @@ std::unique_ptr CreateCelExpressionBuilder( builder->set_enable_null_coercion(options.enable_null_to_message_coercion); builder->set_enable_wrapper_type_null_unboxing( options.enable_empty_wrapper_null_unboxing); + builder->set_enable_heterogeneous_equality( + options.enable_heterogeneous_equality); switch (options.unknown_processing) { case UnknownProcessingOptions::kAttributeAndFunction: From 97cec8817b2b634c597c9937b27817a9c4cb9e98 Mon Sep 17 00:00:00 2001 From: jcking Date: Wed, 9 Mar 2022 03:25:19 +0000 Subject: [PATCH 019/111] Internal change PiperOrigin-RevId: 433368987 --- base/BUILD | 58 ++--- base/internal/BUILD | 25 +-- base/internal/type.h | 72 ------ base/internal/type.post.h | 249 +++++++++++++++++++++ base/internal/type.pre.h | 44 ++++ base/type.cc | 282 ++++++----------------- base/type.h | 417 +++++++++++++++++++++++++++------- base/type_factory.cc | 76 +++++++ base/type_factory.h | 83 +++++++ base/type_test.cc | 457 +++++++++++++++++++++----------------- 10 files changed, 1131 insertions(+), 632 deletions(-) delete mode 100644 base/internal/type.h create mode 100644 base/internal/type.post.h create mode 100644 base/internal/type.pre.h create mode 100644 base/type_factory.cc create mode 100644 base/type_factory.h diff --git a/base/BUILD b/base/BUILD index f6ee44c30..ccb2be2fc 100644 --- a/base/BUILD +++ b/base/BUILD @@ -97,13 +97,20 @@ cc_test( cc_library( name = "type", - srcs = ["type.cc"], - hdrs = ["type.h"], + srcs = [ + "type.cc", + "type_factory.cc", + ], + hdrs = [ + "type.h", + "type_factory.h", + ], deps = [ + ":handle", ":kind", + ":memory_manager", "//base/internal:type", - "//internal:reference_counted", - "@com_google_absl//absl/base", + "//internal:no_destructor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/hash", "@com_google_absl//absl/strings", @@ -115,49 +122,10 @@ cc_test( name = "type_test", srcs = ["type_test.cc"], deps = [ + ":handle", + ":memory_manager", ":type", "//internal:testing", "@com_google_absl//absl/hash:hash_testing", ], ) - -cc_library( - name = "value", - srcs = ["value.cc"], - hdrs = ["value.h"], - deps = [ - ":kind", - ":type", - "//base/internal:value", - "//internal:casts", - "//internal:reference_counted", - "//internal:status_macros", - "//internal:strings", - "//internal:time", - "@com_google_absl//absl/base", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:cord", - "@com_google_absl//absl/time", - "@com_google_absl//absl/types:variant", - ], -) - -cc_test( - name = "value_test", - srcs = ["value_test.cc"], - deps = [ - ":type", - ":value", - "//internal:strings", - "//internal:testing", - "//internal:time", - "@com_google_absl//absl/hash:hash_testing", - "@com_google_absl//absl/status", - "@com_google_absl//absl/time", - ], -) diff --git a/base/internal/BUILD b/base/internal/BUILD index 73bb7cb60..3715ed189 100644 --- a/base/internal/BUILD +++ b/base/internal/BUILD @@ -46,31 +46,18 @@ cc_library( ], ) +# These headers should only ever be used by ../type.h. They are here to avoid putting +# large amounts of implementation details in public headers. cc_library( name = "type", - hdrs = ["type.h"], - deps = [ - "//base:kind", - "//internal:reference_counted", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", + textual_hdrs = [ + "type.pre.h", + "type.post.h", ], -) - -cc_library( - name = "value", - hdrs = ["value.h"], deps = [ - ":type", - "//base:kind", - "//internal:casts", - "//internal:reference_counted", - "@com_google_absl//absl/base:config", + "//base:handle", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/hash", "@com_google_absl//absl/numeric:bits", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", ], ) diff --git a/base/internal/type.h b/base/internal/type.h deleted file mode 100644 index 3b2220c42..000000000 --- a/base/internal/type.h +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2021 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_TYPE_H_ -#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_TYPE_H_ - -#include "absl/hash/hash.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "base/kind.h" -#include "internal/reference_counted.h" - -namespace cel { - -class Type; - -namespace base_internal { - -class SimpleType; - -class BaseType : public cel::internal::ReferenceCounted { - public: - // Returns the type kind. - virtual Kind kind() const = 0; - - // Returns the type name, i.e. map or google.protobuf.Any. - virtual absl::string_view name() const = 0; - - // Returns the type parameters of the type, i.e. key and value of map type. - virtual absl::Span parameters() const = 0; - - protected: - // Overriden by subclasses to implement more strictly equality testing. By - // default `cel::Type` ensures `kind()` and `name()` are equal, this behavior - // cannot be overriden. It is completely valid and acceptable to simply return - // `true`. - // - // This method should only ever be called by cel::Type. - virtual bool Equals(const cel::Type& value) const = 0; - - // Overriden by subclasses to implement better hashing. By default `cel::Type` - // hashes `kind()` and `name()`, this behavior cannot be overriden. It is - // completely valid and acceptable to simply do nothing. - // - // This method should only ever be called by cel::Type. - virtual void HashValue(absl::HashState state) const = 0; - - private: - friend class cel::Type; - friend class SimpleType; - - // The default constructor is private so that only sanctioned classes can - // extend it. Users should extend those classes instead of this one. - constexpr BaseType() = default; -}; - -} // namespace base_internal - -} // namespace cel - -#endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_TYPE_H_ diff --git a/base/internal/type.post.h b/base/internal/type.post.h new file mode 100644 index 000000000..956dd69a9 --- /dev/null +++ b/base/internal/type.post.h @@ -0,0 +1,249 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "base/type.h" + +#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_TYPE_POST_H_ +#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_TYPE_POST_H_ + +#include +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/base/optimization.h" +#include "absl/hash/hash.h" +#include "absl/numeric/bits.h" +#include "base/handle.h" + +namespace cel { + +namespace base_internal { + +// Base implementation of persistent and transient handles. This contains +// implementation details shared among both, but is never used directly. The +// derived classes are responsible for defining appropriate constructors and +// assignments. +class TypeHandleBase { + public: + constexpr TypeHandleBase() = default; + + // Called by `Transient` and `Persistent` to implement the same operator. They + // will handle enforcing const correctness. + Type& operator*() const { return get(); } + + // Called by `Transient` and `Persistent` to implement the same operator. They + // will handle enforcing const correctness. + Type* operator->() const { return std::addressof(get()); } + + // Called by internal accessors `base_internal::IsXHandle`. + constexpr bool IsManaged() const { + return (rep_ & kTypeHandleUnmanaged) == 0; + } + + // Called by internal accessors `base_internal::IsXHandle`. + constexpr bool IsUnmanaged() const { + return (rep_ & kTypeHandleUnmanaged) != 0; + } + + // Called by internal accessors `base_internal::IsXHandle`. + constexpr bool IsInlined() const { return false; } + + // Called by `Transient` and `Persistent` to implement the same function. + template + bool Is() const { + return static_cast(*this) && T::Is(static_cast(**this)); + } + + // Called by `Transient` and `Persistent` to implement the same operator. + explicit operator bool() const { return (rep_ & kTypeHandleMask) != 0; } + + // Called by `Transient` and `Persistent` to implement the same operator. + friend bool operator==(const TypeHandleBase& lhs, const TypeHandleBase& rhs) { + const Type& lhs_type = ABSL_PREDICT_TRUE(static_cast(lhs)) + ? lhs.get() + : static_cast(NullType::Get()); + const Type& rhs_type = ABSL_PREDICT_TRUE(static_cast(rhs)) + ? rhs.get() + : static_cast(NullType::Get()); + return lhs_type.Equals(rhs_type); + } + + // Called by `Transient` and `Persistent` to implement std::swap. + friend void swap(TypeHandleBase& lhs, TypeHandleBase& rhs) { + std::swap(lhs.rep_, rhs.rep_); + } + + template + friend H AbslHashValue(H state, const TypeHandleBase& handle) { + if (ABSL_PREDICT_TRUE(static_cast(handle))) { + handle.get().HashValue(absl::HashState::Create(&state)); + } else { + NullType::Get().HashValue(absl::HashState::Create(&state)); + } + return state; + } + + private: + template + friend class TypeHandle; + + void Unref() const { + if ((rep_ & kTypeHandleUnmanaged) == 0) { + get().Unref(); + } + } + + uintptr_t Ref() const { + if ((rep_ & kTypeHandleUnmanaged) == 0) { + get().Ref(); + } + return rep_; + } + + Type& get() const { return *reinterpret_cast(rep_ & kTypeHandleMask); } + + // There are no inlined types, so we represent everything as a pointer and use + // tagging to differentiate between reference counted and arena-allocated. + uintptr_t rep_ = kTypeHandleUnmanaged; +}; + +// All methods are called by `Transient`. +template <> +class TypeHandle final : public TypeHandleBase { + public: + constexpr TypeHandle() = default; + + constexpr TypeHandle(const TransientTypeHandle& other) = default; + + constexpr TypeHandle(TransientTypeHandle&& other) = default; + + template + TypeHandle(UnmanagedResource, F& from) { + uintptr_t rep = reinterpret_cast( + static_cast(static_cast(std::addressof(from)))); + ABSL_ASSERT(absl::countr_zero(rep) >= + 2); // Verify the lower 2 bits are available. + rep_ = rep | kTypeHandleUnmanaged; + } + + explicit TypeHandle(const PersistentTypeHandle& other); + + TypeHandle& operator=(const TransientTypeHandle& other) = default; + + TypeHandle& operator=(TransientTypeHandle&& other) = default; + + TypeHandle& operator=(const PersistentTypeHandle& other); +}; + +// All methods are called by `Persistent`. +template <> +class TypeHandle final : public TypeHandleBase { + public: + constexpr TypeHandle() = default; + + TypeHandle(const PersistentTypeHandle& other) { rep_ = other.Ref(); } + + TypeHandle(PersistentTypeHandle&& other) { + rep_ = other.rep_; + other.rep_ = kTypeHandleUnmanaged; + } + + explicit TypeHandle(const TransientTypeHandle& other) { rep_ = other.Ref(); } + + ~TypeHandle() { Unref(); } + + TypeHandle& operator=(const PersistentTypeHandle& other) { + Unref(); + rep_ = other.Ref(); + return *this; + } + + TypeHandle& operator=(PersistentTypeHandle&& other) { + Unref(); + rep_ = other.rep_; + other.rep_ = kTypeHandleUnmanaged; + return *this; + } + + TypeHandle& operator=(const TransientTypeHandle& other) { + Unref(); + rep_ = other.Ref(); + return *this; + } +}; + +inline TypeHandle::TypeHandle( + const PersistentTypeHandle& other) { + rep_ = other.rep_; +} + +inline TypeHandle& TypeHandle< + HandleType::kTransient>::operator=(const PersistentTypeHandle& other) { + rep_ = other.rep_; + return *this; +} + +// Specialization for Type providing the implementation to `Transient`. +template <> +struct HandleTraits { + using handle_type = TypeHandle; +}; + +// Partial specialization for `Transient` for all classes derived from Type. +template +struct HandleTraits< + HandleType::kTransient, T, + std::enable_if_t<(std::is_base_of_v && !std::is_same_v)>> + final : public HandleTraits {}; + +// Specialization for Type providing the implementation to `Persistent`. +template <> +struct HandleTraits { + using handle_type = TypeHandle; +}; + +// Partial specialization for `Persistent` for all classes derived from Type. +template +struct HandleTraits< + HandleType::kPersistent, T, + std::enable_if_t<(std::is_base_of_v && !std::is_same_v)>> + final : public HandleTraits {}; + +} // namespace base_internal + +#define CEL_INTERNAL_TYPE_DECL(name) \ + extern template class Transient; \ + extern template class Transient; \ + extern template class Persistent; \ + extern template class Persistent +CEL_INTERNAL_TYPE_DECL(Type); +CEL_INTERNAL_TYPE_DECL(NullType); +CEL_INTERNAL_TYPE_DECL(ErrorType); +CEL_INTERNAL_TYPE_DECL(DynType); +CEL_INTERNAL_TYPE_DECL(AnyType); +CEL_INTERNAL_TYPE_DECL(BoolType); +CEL_INTERNAL_TYPE_DECL(IntType); +CEL_INTERNAL_TYPE_DECL(UintType); +CEL_INTERNAL_TYPE_DECL(DoubleType); +CEL_INTERNAL_TYPE_DECL(BytesType); +CEL_INTERNAL_TYPE_DECL(StringType); +CEL_INTERNAL_TYPE_DECL(DurationType); +CEL_INTERNAL_TYPE_DECL(TimestampType); +#undef CEL_INTERNAL_TYPE_DECL + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_TYPE_POST_H_ diff --git a/base/internal/type.pre.h b/base/internal/type.pre.h new file mode 100644 index 000000000..d6bf8ae0b --- /dev/null +++ b/base/internal/type.pre.h @@ -0,0 +1,44 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "base/type.h" + +#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_TYPE_PRE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_TYPE_PRE_H_ + +#include + +#include "base/handle.h" + +namespace cel::base_internal { + +class TypeHandleBase; +template +class TypeHandle; + +// Convenient aliases. +using TransientTypeHandle = TypeHandle; +using PersistentTypeHandle = TypeHandle; + +// As all objects should be aligned to at least 4 bytes, we can use the lower +// two bits for our own purposes. +inline constexpr uintptr_t kTypeHandleUnmanaged = 1 << 0; +inline constexpr uintptr_t kTypeHandleReserved = 1 << 1; +inline constexpr uintptr_t kTypeHandleBits = + kTypeHandleUnmanaged | kTypeHandleReserved; +inline constexpr uintptr_t kTypeHandleMask = ~kTypeHandleBits; + +} // namespace cel::base_internal + +#endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_TYPE_PRE_H_ diff --git a/base/type.cc b/base/type.cc index c1e0851c1..e6294ae6b 100644 --- a/base/type.cc +++ b/base/type.cc @@ -14,247 +14,107 @@ #include "base/type.h" -#include #include -#include "absl/base/attributes.h" -#include "absl/base/call_once.h" -#include "absl/base/macros.h" -#include "absl/base/optimization.h" -#include "base/internal/type.h" -#include "internal/reference_counted.h" +#include "absl/types/span.h" +#include "base/handle.h" +#include "internal/no_destructor.h" namespace cel { -namespace base_internal { +#define CEL_INTERNAL_TYPE_IMPL(name) \ + template class Transient; \ + template class Transient; \ + template class Persistent; \ + template class Persistent +CEL_INTERNAL_TYPE_IMPL(Type); +CEL_INTERNAL_TYPE_IMPL(NullType); +CEL_INTERNAL_TYPE_IMPL(ErrorType); +CEL_INTERNAL_TYPE_IMPL(DynType); +CEL_INTERNAL_TYPE_IMPL(AnyType); +CEL_INTERNAL_TYPE_IMPL(BoolType); +CEL_INTERNAL_TYPE_IMPL(IntType); +CEL_INTERNAL_TYPE_IMPL(UintType); +CEL_INTERNAL_TYPE_IMPL(DoubleType); +CEL_INTERNAL_TYPE_IMPL(BytesType); +CEL_INTERNAL_TYPE_IMPL(StringType); +CEL_INTERNAL_TYPE_IMPL(DurationType); +CEL_INTERNAL_TYPE_IMPL(TimestampType); +#undef CEL_INTERNAL_TYPE_IMPL + +absl::Span> Type::parameters() const { return {}; } + +std::pair Type::SizeAndAlignment() const { + // Currently no implementation of Type is reference counted. However once we + // introduce Struct it likely will be. Using 0 here will trigger runtime + // asserts in case of undefined behavior. Struct should force this to be pure. + return std::pair(0, 0); +} + +bool Type::Equals(const Type& other) const { return kind() == other.kind(); } -// Implementation of BaseType for simple types. See SimpleTypes below for the -// types being implemented. -class SimpleType final : public BaseType { - public: - constexpr SimpleType(Kind kind, absl::string_view name) - : BaseType(), name_(name), kind_(kind) {} - - ~SimpleType() override { - // Simple types should live for the lifetime of the process, so destructing - // them is definetly a bug. - std::abort(); - } - - Kind kind() const override { return kind_; } - - absl::string_view name() const override { return name_; } - - absl::Span parameters() const override { return {}; } - - protected: - void HashValue(absl::HashState state) const override { - // cel::Type already adds both kind and name to the hash state, nothing else - // for us to do. - static_cast(state); - } - - bool Equals(const cel::Type& other) const override { - // cel::Type already checks that the kind and name are equivalent, so at - // this point the types are the same. - static_cast(other); - return true; - } - - private: - const absl::string_view name_; - const Kind kind_; -}; - -} // namespace base_internal - -namespace { - -struct SimpleTypes final { - constexpr SimpleTypes() = default; - - SimpleTypes(const SimpleTypes&) = delete; - - SimpleTypes(SimpleTypes&&) = delete; - - ~SimpleTypes() = default; - - SimpleTypes& operator=(const SimpleTypes&) = delete; - - SimpleTypes& operator=(SimpleTypes&&) = delete; - - Type error_type; - Type null_type; - Type dyn_type; - Type any_type; - Type bool_type; - Type int_type; - Type uint_type; - Type double_type; - Type string_type; - Type bytes_type; - Type duration_type; - Type timestamp_type; -}; - -ABSL_CONST_INIT absl::once_flag simple_types_once; -ABSL_CONST_INIT SimpleTypes* simple_types = nullptr; - -} // namespace - -void Type::Initialize() { - absl::call_once(simple_types_once, []() { - ABSL_ASSERT(simple_types == nullptr); - simple_types = new SimpleTypes(); - simple_types->error_type = - Type(new base_internal::SimpleType(Kind::kError, "*error*")); - simple_types->dyn_type = - Type(new base_internal::SimpleType(Kind::kDyn, "dyn")); - simple_types->any_type = - Type(new base_internal::SimpleType(Kind::kAny, "google.protobuf.Any")); - simple_types->bool_type = - Type(new base_internal::SimpleType(Kind::kBool, "bool")); - simple_types->int_type = - Type(new base_internal::SimpleType(Kind::kInt, "int")); - simple_types->uint_type = - Type(new base_internal::SimpleType(Kind::kUint, "uint")); - simple_types->double_type = - Type(new base_internal::SimpleType(Kind::kDouble, "double")); - simple_types->string_type = - Type(new base_internal::SimpleType(Kind::kString, "string")); - simple_types->bytes_type = - Type(new base_internal::SimpleType(Kind::kBytes, "bytes")); - simple_types->duration_type = Type(new base_internal::SimpleType( - Kind::kDuration, "google.protobuf.Duration")); - simple_types->timestamp_type = Type(new base_internal::SimpleType( - Kind::kTimestamp, "google.protobuf.Timestamp")); - }); -} - -const Type& Type::Simple(Kind kind) { - switch (kind) { - case Kind::kNullType: - return Null(); - case Kind::kError: - return Error(); - case Kind::kBool: - return Bool(); - case Kind::kInt: - return Int(); - case Kind::kUint: - return Uint(); - case Kind::kDouble: - return Double(); - case Kind::kDuration: - return Duration(); - case Kind::kTimestamp: - return Timestamp(); - case Kind::kString: - return String(); - case Kind::kBytes: - return Bytes(); - default: - // We can only get here via memory corruption in cel::Value via - // cel::base_internal::ValueMetadata, as the the kinds with simple tags - // are all covered here. - std::abort(); - } -} - -const Type& Type::Null() { - Initialize(); - return simple_types->null_type; -} - -const Type& Type::Error() { - Initialize(); - return simple_types->error_type; -} - -const Type& Type::Dyn() { - Initialize(); - return simple_types->dyn_type; -} - -const Type& Type::Any() { - Initialize(); - return simple_types->any_type; +void Type::HashValue(absl::HashState state) const { + absl::HashState::combine(std::move(state), kind(), name()); } -const Type& Type::Bool() { - Initialize(); - return simple_types->bool_type; +const NullType& NullType::Get() { + static const internal::NoDestructor instance; + return *instance; } -const Type& Type::Int() { - Initialize(); - return simple_types->int_type; +const ErrorType& ErrorType::Get() { + static const internal::NoDestructor instance; + return *instance; } -const Type& Type::Uint() { - Initialize(); - return simple_types->uint_type; +const DynType& DynType::Get() { + static const internal::NoDestructor instance; + return *instance; } -const Type& Type::Double() { - Initialize(); - return simple_types->double_type; +const AnyType& AnyType::Get() { + static const internal::NoDestructor instance; + return *instance; } -const Type& Type::String() { - Initialize(); - return simple_types->string_type; +const BoolType& BoolType::Get() { + static const internal::NoDestructor instance; + return *instance; } -const Type& Type::Bytes() { - Initialize(); - return simple_types->bytes_type; +const IntType& IntType::Get() { + static const internal::NoDestructor instance; + return *instance; } -const Type& Type::Duration() { - Initialize(); - return simple_types->duration_type; +const UintType& UintType::Get() { + static const internal::NoDestructor instance; + return *instance; } -const Type& Type::Timestamp() { - Initialize(); - return simple_types->timestamp_type; +const DoubleType& DoubleType::Get() { + static const internal::NoDestructor instance; + return *instance; } -Type::Type(const Type& other) : impl_(other.impl_) { internal::Ref(impl_); } - -Type::Type(Type&& other) : impl_(other.impl_) { other.impl_ = nullptr; } - -Type& Type::operator=(const Type& other) { - if (ABSL_PREDICT_TRUE(this != &other)) { - internal::Ref(other.impl_); - internal::Unref(impl_); - impl_ = other.impl_; - } - return *this; +const StringType& StringType::Get() { + static const internal::NoDestructor instance; + return *instance; } -Type& Type::operator=(Type&& other) { - if (ABSL_PREDICT_TRUE(this != &other)) { - internal::Unref(impl_); - impl_ = other.impl_; - other.impl_ = nullptr; - } - return *this; +const BytesType& BytesType::Get() { + static const internal::NoDestructor instance; + return *instance; } -bool Type::Equals(const Type& other) const { - return impl_ == other.impl_ || - (kind() == other.kind() && name() == other.name() && - // It should not be possible to reach here if impl_ is nullptr. - impl_->Equals(other)); +const DurationType& DurationType::Get() { + static const internal::NoDestructor instance; + return *instance; } -void Type::HashValue(absl::HashState state) const { - state = absl::HashState::combine(std::move(state), kind(), name()); - if (impl_) { - impl_->HashValue(std::move(state)); - } +const TimestampType& TimestampType::Get() { + static const internal::NoDestructor instance; + return *instance; } } // namespace cel diff --git a/base/type.h b/base/type.h index 84d201536..03433af9a 100644 --- a/base/type.h +++ b/base/type.h @@ -22,138 +22,399 @@ #include "absl/hash/hash.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "base/internal/type.h" +#include "base/handle.h" +#include "base/internal/type.pre.h" // IWYU pragma: export #include "base/kind.h" -#include "internal/reference_counted.h" +#include "base/memory_manager.h" namespace cel { -class Value; +class Type; +class NullType; +class ErrorType; +class DynType; +class AnyType; +class BoolType; +class IntType; +class UintType; +class DoubleType; +class StringType; +class BytesType; +class DurationType; +class TimestampType; +class TypeFactory; + +class NullValue; +class ErrorValue; +class BoolValue; +class IntValue; +class UintValue; +class DoubleValue; +class BytesValue; +class DurationValue; +class TimestampValue; +class ValueFactory; + +namespace internal { +template +class NoDestructor; +} // A representation of a CEL type that enables reflection, for static analysis, // and introspection, for program construction, of types. -class Type final { +class Type : public base_internal::Resource { public: - // Returns the null type. - ABSL_ATTRIBUTE_PURE_FUNCTION static const Type& Null(); + // Returns the type kind. + virtual Kind kind() const = 0; - // Returns the error type. - ABSL_ATTRIBUTE_PURE_FUNCTION static const Type& Error(); + // Returns the type name, i.e. "list". + virtual absl::string_view name() const = 0; - // Returns the dynamic type. - ABSL_ATTRIBUTE_PURE_FUNCTION static const Type& Dyn(); + // Returns the type parameters of the type, i.e. key and value type of map. + virtual absl::Span> parameters() const; - // Returns the any type. - ABSL_ATTRIBUTE_PURE_FUNCTION static const Type& Any(); + private: + friend class NullType; + friend class ErrorType; + friend class DynType; + friend class AnyType; + friend class BoolType; + friend class IntType; + friend class UintType; + friend class DoubleType; + friend class StringType; + friend class BytesType; + friend class DurationType; + friend class TimestampType; + friend class base_internal::TypeHandleBase; + + Type() = default; + Type(const Type&) = default; + Type(Type&&) = default; + + // Called by base_internal::TypeHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Type& type) { return true; } + + // For non-inlined types that are reference counted, this is the result of + // `sizeof` and `alignof` for the most derived class. + std::pair SizeAndAlignment() const override; + + using base_internal::Resource::Ref; + using base_internal::Resource::Unref; + + // Called by base_internal::TypeHandleBase. + virtual bool Equals(const Type& other) const; + + // Called by base_internal::TypeHandleBase. + virtual void HashValue(absl::HashState state) const; +}; - // Returns the bool type. - ABSL_ATTRIBUTE_PURE_FUNCTION static const Type& Bool(); +class NullType final : public Type { + public: + Kind kind() const override { return Kind::kNullType; } - // Returns the int type. - ABSL_ATTRIBUTE_PURE_FUNCTION static const Type& Int(); + absl::string_view name() const override { return "null_type"; } - // Returns the uint type. - ABSL_ATTRIBUTE_PURE_FUNCTION static const Type& Uint(); + private: + friend class NullValue; + friend class TypeFactory; + template + friend class internal::NoDestructor; + friend class base_internal::TypeHandleBase; - // Returns the double type. - ABSL_ATTRIBUTE_PURE_FUNCTION static const Type& Double(); + // Called by base_internal::TypeHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Type& type) { return type.kind() == Kind::kNullType; } - // Returns the string type. - ABSL_ATTRIBUTE_PURE_FUNCTION static const Type& String(); + ABSL_ATTRIBUTE_PURE_FUNCTION static const NullType& Get(); - // Returns the bytes type. - ABSL_ATTRIBUTE_PURE_FUNCTION static const Type& Bytes(); + NullType() = default; - // Returns the duration type. - ABSL_ATTRIBUTE_PURE_FUNCTION static const Type& Duration(); + NullType(const NullType&) = delete; + NullType(NullType&&) = delete; +}; - // Returns the timestamp type. - ABSL_ATTRIBUTE_PURE_FUNCTION static const Type& Timestamp(); +class ErrorType final : public Type { + public: + Kind kind() const override { return Kind::kError; } - // Equivalent to `Type::Null()`. - constexpr Type() : Type(nullptr) {} + absl::string_view name() const override { return "*error*"; } - Type(const Type& other); + private: + friend class ErrorValue; + friend class TypeFactory; + template + friend class internal::NoDestructor; + friend class base_internal::TypeHandleBase; - Type(Type&& other); + // Called by base_internal::TypeHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Type& type) { return type.kind() == Kind::kError; } - ~Type() { internal::Unref(impl_); } + ABSL_ATTRIBUTE_PURE_FUNCTION static const ErrorType& Get(); - Type& operator=(const Type& other); + ErrorType() = default; - Type& operator=(Type&& other); + ErrorType(const ErrorType&) = delete; + ErrorType(ErrorType&&) = delete; +}; - // Returns the type kind. - Kind kind() const { return impl_ ? impl_->kind() : Kind::kNullType; } +class DynType final : public Type { + public: + Kind kind() const override { return Kind::kDyn; } - // Returns the type name, i.e. "list". - absl::string_view name() const { return impl_ ? impl_->name() : "null_type"; } + absl::string_view name() const override { return "dyn"; } - // Returns the type parameters of the type, i.e. key and value type of map. - absl::Span parameters() const { - return impl_ ? impl_->parameters() : absl::Span(); - } + private: + friend class TypeFactory; + template + friend class internal::NoDestructor; + friend class base_internal::TypeHandleBase; - bool IsNull() const { return kind() == Kind::kNullType; } + // Called by base_internal::TypeHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Type& type) { return type.kind() == Kind::kDyn; } - bool IsError() const { return kind() == Kind::kError; } + ABSL_ATTRIBUTE_PURE_FUNCTION static const DynType& Get(); - bool IsDyn() const { return kind() == Kind::kDyn; } + DynType() = default; - bool IsAny() const { return kind() == Kind::kAny; } + DynType(const DynType&) = delete; + DynType(DynType&&) = delete; +}; - bool IsBool() const { return kind() == Kind::kBool; } +class AnyType final : public Type { + public: + Kind kind() const override { return Kind::kAny; } - bool IsInt() const { return kind() == Kind::kInt; } + absl::string_view name() const override { return "google.protobuf.Any"; } - bool IsUint() const { return kind() == Kind::kUint; } + private: + friend class TypeFactory; + template + friend class internal::NoDestructor; + friend class base_internal::TypeHandleBase; - bool IsDouble() const { return kind() == Kind::kDouble; } + // Called by base_internal::TypeHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Type& type) { return type.kind() == Kind::kAny; } - bool IsString() const { return kind() == Kind::kString; } + ABSL_ATTRIBUTE_PURE_FUNCTION static const AnyType& Get(); - bool IsBytes() const { return kind() == Kind::kBytes; } + AnyType() = default; - bool IsDuration() const { return kind() == Kind::kDuration; } + AnyType(const AnyType&) = delete; + AnyType(AnyType&&) = delete; +}; - bool IsTimestamp() const { return kind() == Kind::kTimestamp; } +class BoolType final : public Type { + public: + Kind kind() const override { return Kind::kBool; } - template - friend H AbslHashValue(H state, const Type& type) { - type.HashValue(absl::HashState::Create(&state)); - return std::move(state); - } + absl::string_view name() const override { return "bool"; } - friend void swap(Type& lhs, Type& rhs) { - const base_internal::BaseType* impl = lhs.impl_; - lhs.impl_ = rhs.impl_; - rhs.impl_ = impl; - } + private: + friend class BoolValue; + friend class TypeFactory; + template + friend class internal::NoDestructor; + friend class base_internal::TypeHandleBase; - friend bool operator==(const Type& lhs, const Type& rhs) { - return lhs.Equals(rhs); - } + // Called by base_internal::TypeHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Type& type) { return type.kind() == Kind::kBool; } - friend bool operator!=(const Type& lhs, const Type& rhs) { - return !operator==(lhs, rhs); - } + ABSL_ATTRIBUTE_PURE_FUNCTION static const BoolType& Get(); + + BoolType() = default; + + BoolType(const BoolType&) = delete; + BoolType(BoolType&&) = delete; +}; + +class IntType final : public Type { + public: + Kind kind() const override { return Kind::kInt; } + + absl::string_view name() const override { return "int"; } private: - friend class Value; + friend class IntValue; + friend class TypeFactory; + template + friend class internal::NoDestructor; + friend class base_internal::TypeHandleBase; - static void Initialize(); + // Called by base_internal::TypeHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Type& type) { return type.kind() == Kind::kInt; } - static const Type& Simple(Kind kind); + ABSL_ATTRIBUTE_PURE_FUNCTION static const IntType& Get(); - constexpr explicit Type(const base_internal::BaseType* impl) : impl_(impl) {} + IntType() = default; - bool Equals(const Type& other) const; + IntType(const IntType&) = delete; + IntType(IntType&&) = delete; +}; + +class UintType final : public Type { + public: + Kind kind() const override { return Kind::kUint; } - void HashValue(absl::HashState state) const; + absl::string_view name() const override { return "uint"; } - const base_internal::BaseType* impl_; + private: + friend class UintValue; + friend class TypeFactory; + template + friend class internal::NoDestructor; + friend class base_internal::TypeHandleBase; + + // Called by base_internal::TypeHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Type& type) { return type.kind() == Kind::kUint; } + + ABSL_ATTRIBUTE_PURE_FUNCTION static const UintType& Get(); + + UintType() = default; + + UintType(const UintType&) = delete; + UintType(UintType&&) = delete; +}; + +class DoubleType final : public Type { + public: + Kind kind() const override { return Kind::kDouble; } + + absl::string_view name() const override { return "double"; } + + private: + friend class DoubleValue; + friend class TypeFactory; + template + friend class internal::NoDestructor; + friend class base_internal::TypeHandleBase; + + // Called by base_internal::TypeHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Type& type) { return type.kind() == Kind::kDouble; } + + ABSL_ATTRIBUTE_PURE_FUNCTION static const DoubleType& Get(); + + DoubleType() = default; + + DoubleType(const DoubleType&) = delete; + DoubleType(DoubleType&&) = delete; +}; + +class StringType final : public Type { + public: + Kind kind() const override { return Kind::kString; } + + absl::string_view name() const override { return "string"; } + + private: + friend class TypeFactory; + template + friend class internal::NoDestructor; + friend class base_internal::TypeHandleBase; + + // Called by base_internal::TypeHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Type& type) { return type.kind() == Kind::kString; } + + ABSL_ATTRIBUTE_PURE_FUNCTION static const StringType& Get(); + + StringType() = default; + + StringType(const StringType&) = delete; + StringType(StringType&&) = delete; +}; + +class BytesType final : public Type { + public: + Kind kind() const override { return Kind::kBytes; } + + absl::string_view name() const override { return "bytes"; } + + private: + friend class BytesValue; + friend class TypeFactory; + template + friend class internal::NoDestructor; + friend class base_internal::TypeHandleBase; + + // Called by base_internal::TypeHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Type& type) { return type.kind() == Kind::kBytes; } + + ABSL_ATTRIBUTE_PURE_FUNCTION static const BytesType& Get(); + + BytesType() = default; + + BytesType(const BytesType&) = delete; + BytesType(BytesType&&) = delete; +}; + +class DurationType final : public Type { + public: + Kind kind() const override { return Kind::kDuration; } + + absl::string_view name() const override { return "google.protobuf.Duration"; } + + private: + friend class DurationValue; + friend class TypeFactory; + template + friend class internal::NoDestructor; + friend class base_internal::TypeHandleBase; + + // Called by base_internal::TypeHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Type& type) { return type.kind() == Kind::kDuration; } + + ABSL_ATTRIBUTE_PURE_FUNCTION static const DurationType& Get(); + + DurationType() = default; + + DurationType(const DurationType&) = delete; + DurationType(DurationType&&) = delete; +}; + +class TimestampType final : public Type { + public: + Kind kind() const override { return Kind::kTimestamp; } + + absl::string_view name() const override { + return "google.protobuf.Timestamp"; + } + + private: + friend class TimestampValue; + friend class TypeFactory; + template + friend class internal::NoDestructor; + friend class base_internal::TypeHandleBase; + + // Called by base_internal::TypeHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Type& type) { return type.kind() == Kind::kTimestamp; } + + ABSL_ATTRIBUTE_PURE_FUNCTION static const TimestampType& Get(); + + TimestampType() = default; + + TimestampType(const TimestampType&) = delete; + TimestampType(TimestampType&&) = delete; }; } // namespace cel +// type.pre.h forward declares types so they can be friended above. The types +// themselves need to be defined after everything else as they need to access or +// derive from the above types. We do this in type.post.h to avoid mudying this +// header and making it difficult to read. +#include "base/internal/type.post.h" // IWYU pragma: export + #endif // THIRD_PARTY_CEL_CPP_BASE_TYPE_H_ diff --git a/base/type_factory.cc b/base/type_factory.cc new file mode 100644 index 000000000..4f3509fb2 --- /dev/null +++ b/base/type_factory.cc @@ -0,0 +1,76 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/type_factory.h" + +#include "base/handle.h" +#include "base/type.h" + +namespace cel { + +namespace { + +using base_internal::TransientHandleFactory; + +} // namespace + +Persistent TypeFactory::GetNullType() { + return WrapSingletonType(); +} + +Persistent TypeFactory::GetErrorType() { + return WrapSingletonType(); +} + +Persistent TypeFactory::GetDynType() { + return WrapSingletonType(); +} + +Persistent TypeFactory::GetAnyType() { + return WrapSingletonType(); +} + +Persistent TypeFactory::GetBoolType() { + return WrapSingletonType(); +} + +Persistent TypeFactory::GetIntType() { + return WrapSingletonType(); +} + +Persistent TypeFactory::GetUintType() { + return WrapSingletonType(); +} + +Persistent TypeFactory::GetDoubleType() { + return WrapSingletonType(); +} + +Persistent TypeFactory::GetStringType() { + return WrapSingletonType(); +} + +Persistent TypeFactory::GetBytesType() { + return WrapSingletonType(); +} + +Persistent TypeFactory::GetDurationType() { + return WrapSingletonType(); +} + +Persistent TypeFactory::GetTimestampType() { + return WrapSingletonType(); +} + +} // namespace cel diff --git a/base/type_factory.h b/base/type_factory.h new file mode 100644 index 000000000..4e74a3654 --- /dev/null +++ b/base/type_factory.h @@ -0,0 +1,83 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_TYPE_FACTORY_H_ +#define THIRD_PARTY_CEL_CPP_BASE_TYPE_FACTORY_H_ + +#include "absl/base/attributes.h" +#include "base/handle.h" +#include "base/memory_manager.h" +#include "base/type.h" + +namespace cel { + +// TypeFactory provides member functions to get and create type implementations +// of builtin types. +class TypeFactory { + public: + virtual ~TypeFactory() = default; + + Persistent GetNullType() ABSL_ATTRIBUTE_LIFETIME_BOUND; + + Persistent GetErrorType() ABSL_ATTRIBUTE_LIFETIME_BOUND; + + Persistent GetDynType() ABSL_ATTRIBUTE_LIFETIME_BOUND; + + Persistent GetAnyType() ABSL_ATTRIBUTE_LIFETIME_BOUND; + + Persistent GetBoolType() ABSL_ATTRIBUTE_LIFETIME_BOUND; + + Persistent GetIntType() ABSL_ATTRIBUTE_LIFETIME_BOUND; + + Persistent GetUintType() ABSL_ATTRIBUTE_LIFETIME_BOUND; + + Persistent GetDoubleType() ABSL_ATTRIBUTE_LIFETIME_BOUND; + + Persistent GetStringType() ABSL_ATTRIBUTE_LIFETIME_BOUND; + + Persistent GetBytesType() ABSL_ATTRIBUTE_LIFETIME_BOUND; + + Persistent GetDurationType() + ABSL_ATTRIBUTE_LIFETIME_BOUND; + + Persistent GetTimestampType() + ABSL_ATTRIBUTE_LIFETIME_BOUND; + + protected: + // Prevent direct intantiation until more pure virtual methods are added. + explicit TypeFactory(MemoryManager& memory_manager) + : memory_manager_(memory_manager) {} + + // Ignore unused for now, as it will be used in the future. + ABSL_ATTRIBUTE_UNUSED MemoryManager& memory_manager() const { + return memory_manager_; + } + + private: + template + static Persistent WrapSingletonType() { + // This is not normal, but we treat the underlying object as having been + // arena allocated. The only way to do this is through + // TransientHandleFactory. + return Persistent( + base_internal::TransientHandleFactory::template MakeUnmanaged< + const T>(T::Get())); + } + + MemoryManager& memory_manager_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_TYPE_FACTORY_H_ diff --git a/base/type_test.cc b/base/type_test.cc index d8df5dae0..d6e2045fa 100644 --- a/base/type_test.cc +++ b/base/type_test.cc @@ -18,6 +18,9 @@ #include #include "absl/hash/hash_testing.h" +#include "base/handle.h" +#include "base/memory_manager.h" +#include "base/type_factory.h" #include "internal/testing.h" namespace cel { @@ -25,57 +28,83 @@ namespace { using testing::SizeIs; +class TestTypeFactory final : public TypeFactory { + public: + TestTypeFactory() : TypeFactory(MemoryManager::Global()) {} +}; + template constexpr void IS_INITIALIZED(T&) {} -TEST(Type, TypeTraits) { - EXPECT_TRUE(std::is_default_constructible_v); - EXPECT_TRUE(std::is_copy_constructible_v); - EXPECT_TRUE(std::is_move_constructible_v); - EXPECT_TRUE(std::is_copy_assignable_v); - EXPECT_TRUE(std::is_move_assignable_v); - EXPECT_TRUE(std::is_swappable_v); +TEST(Type, TransientHandleTypeTraits) { + EXPECT_TRUE(std::is_default_constructible_v>); + EXPECT_TRUE(std::is_copy_constructible_v>); + EXPECT_TRUE(std::is_move_constructible_v>); + EXPECT_TRUE(std::is_copy_assignable_v>); + EXPECT_TRUE(std::is_move_assignable_v>); + EXPECT_TRUE(std::is_swappable_v>); + EXPECT_TRUE(std::is_default_constructible_v>); + EXPECT_TRUE(std::is_copy_constructible_v>); + EXPECT_TRUE(std::is_move_constructible_v>); + EXPECT_TRUE(std::is_copy_assignable_v>); + EXPECT_TRUE(std::is_move_assignable_v>); + EXPECT_TRUE(std::is_swappable_v>); } -TEST(Type, DefaultConstructor) { - Type type; - EXPECT_EQ(type, Type::Null()); +TEST(Type, PersistentHandleTypeTraits) { + EXPECT_TRUE(std::is_default_constructible_v>); + EXPECT_TRUE(std::is_copy_constructible_v>); + EXPECT_TRUE(std::is_move_constructible_v>); + EXPECT_TRUE(std::is_copy_assignable_v>); + EXPECT_TRUE(std::is_move_assignable_v>); + EXPECT_TRUE(std::is_swappable_v>); + EXPECT_TRUE(std::is_default_constructible_v>); + EXPECT_TRUE(std::is_copy_constructible_v>); + EXPECT_TRUE(std::is_move_constructible_v>); + EXPECT_TRUE(std::is_copy_assignable_v>); + EXPECT_TRUE(std::is_move_assignable_v>); + EXPECT_TRUE(std::is_swappable_v>); } TEST(Type, CopyConstructor) { - Type type(Type::Int()); - EXPECT_EQ(type, Type::Int()); + TestTypeFactory type_factory; + Transient type(type_factory.GetIntType()); + EXPECT_EQ(type, type_factory.GetIntType()); } TEST(Type, MoveConstructor) { - Type from(Type::Int()); - Type to(std::move(from)); + TestTypeFactory type_factory; + Transient from(type_factory.GetIntType()); + Transient to(std::move(from)); IS_INITIALIZED(from); - EXPECT_EQ(from, Type::Null()); - EXPECT_EQ(to, Type::Int()); + EXPECT_EQ(from, type_factory.GetIntType()); + EXPECT_EQ(to, type_factory.GetIntType()); } TEST(Type, CopyAssignment) { - Type type; - type = Type::Int(); - EXPECT_EQ(type, Type::Int()); + TestTypeFactory type_factory; + Transient type(type_factory.GetNullType()); + type = type_factory.GetIntType(); + EXPECT_EQ(type, type_factory.GetIntType()); } TEST(Type, MoveAssignment) { - Type from(Type::Int()); - Type to; + TestTypeFactory type_factory; + Transient from(type_factory.GetIntType()); + Transient to(type_factory.GetNullType()); to = std::move(from); IS_INITIALIZED(from); - EXPECT_EQ(from, Type::Null()); - EXPECT_EQ(to, Type::Int()); + EXPECT_EQ(from, type_factory.GetIntType()); + EXPECT_EQ(to, type_factory.GetIntType()); } TEST(Type, Swap) { - Type lhs = Type::Int(); - Type rhs = Type::Uint(); + TestTypeFactory type_factory; + Transient lhs = type_factory.GetIntType(); + Transient rhs = type_factory.GetUintType(); std::swap(lhs, rhs); - EXPECT_EQ(lhs, Type::Uint()); - EXPECT_EQ(rhs, Type::Int()); + EXPECT_EQ(lhs, type_factory.GetUintType()); + EXPECT_EQ(rhs, type_factory.GetIntType()); } // The below tests could be made parameterized but doing so requires the @@ -83,223 +112,237 @@ TEST(Type, Swap) { // feature is not available in C++17. TEST(Type, Null) { - EXPECT_EQ(Type::Null().kind(), Kind::kNullType); - EXPECT_EQ(Type::Null().name(), "null_type"); - EXPECT_THAT(Type::Null().parameters(), SizeIs(0)); - EXPECT_TRUE(Type::Null().IsNull()); - EXPECT_FALSE(Type::Null().IsDyn()); - EXPECT_FALSE(Type::Null().IsAny()); - EXPECT_FALSE(Type::Null().IsBool()); - EXPECT_FALSE(Type::Null().IsInt()); - EXPECT_FALSE(Type::Null().IsUint()); - EXPECT_FALSE(Type::Null().IsDouble()); - EXPECT_FALSE(Type::Null().IsString()); - EXPECT_FALSE(Type::Null().IsBytes()); - EXPECT_FALSE(Type::Null().IsDuration()); - EXPECT_FALSE(Type::Null().IsTimestamp()); + TestTypeFactory type_factory; + EXPECT_EQ(type_factory.GetNullType()->kind(), Kind::kNullType); + EXPECT_EQ(type_factory.GetNullType()->name(), "null_type"); + EXPECT_THAT(type_factory.GetNullType()->parameters(), SizeIs(0)); + EXPECT_TRUE(type_factory.GetNullType().Is()); + EXPECT_FALSE(type_factory.GetNullType().Is()); + EXPECT_FALSE(type_factory.GetNullType().Is()); + EXPECT_FALSE(type_factory.GetNullType().Is()); + EXPECT_FALSE(type_factory.GetNullType().Is()); + EXPECT_FALSE(type_factory.GetNullType().Is()); + EXPECT_FALSE(type_factory.GetNullType().Is()); + EXPECT_FALSE(type_factory.GetNullType().Is()); + EXPECT_FALSE(type_factory.GetNullType().Is()); + EXPECT_FALSE(type_factory.GetNullType().Is()); + EXPECT_FALSE(type_factory.GetNullType().Is()); } TEST(Type, Error) { - EXPECT_EQ(Type::Error().kind(), Kind::kError); - EXPECT_EQ(Type::Error().name(), "*error*"); - EXPECT_THAT(Type::Error().parameters(), SizeIs(0)); - EXPECT_FALSE(Type::Error().IsNull()); - EXPECT_FALSE(Type::Error().IsDyn()); - EXPECT_FALSE(Type::Error().IsAny()); - EXPECT_FALSE(Type::Error().IsBool()); - EXPECT_FALSE(Type::Error().IsInt()); - EXPECT_FALSE(Type::Error().IsUint()); - EXPECT_FALSE(Type::Error().IsDouble()); - EXPECT_FALSE(Type::Error().IsString()); - EXPECT_FALSE(Type::Error().IsBytes()); - EXPECT_FALSE(Type::Error().IsDuration()); - EXPECT_FALSE(Type::Error().IsTimestamp()); + TestTypeFactory type_factory; + EXPECT_EQ(type_factory.GetErrorType()->kind(), Kind::kError); + EXPECT_EQ(type_factory.GetErrorType()->name(), "*error*"); + EXPECT_THAT(type_factory.GetErrorType()->parameters(), SizeIs(0)); + EXPECT_FALSE(type_factory.GetErrorType().Is()); + EXPECT_FALSE(type_factory.GetErrorType().Is()); + EXPECT_FALSE(type_factory.GetErrorType().Is()); + EXPECT_FALSE(type_factory.GetErrorType().Is()); + EXPECT_FALSE(type_factory.GetErrorType().Is()); + EXPECT_FALSE(type_factory.GetErrorType().Is()); + EXPECT_FALSE(type_factory.GetErrorType().Is()); + EXPECT_FALSE(type_factory.GetErrorType().Is()); + EXPECT_FALSE(type_factory.GetErrorType().Is()); + EXPECT_FALSE(type_factory.GetErrorType().Is()); + EXPECT_FALSE(type_factory.GetErrorType().Is()); } TEST(Type, Dyn) { - EXPECT_EQ(Type::Dyn().kind(), Kind::kDyn); - EXPECT_EQ(Type::Dyn().name(), "dyn"); - EXPECT_THAT(Type::Dyn().parameters(), SizeIs(0)); - EXPECT_FALSE(Type::Dyn().IsNull()); - EXPECT_TRUE(Type::Dyn().IsDyn()); - EXPECT_FALSE(Type::Dyn().IsAny()); - EXPECT_FALSE(Type::Dyn().IsBool()); - EXPECT_FALSE(Type::Dyn().IsInt()); - EXPECT_FALSE(Type::Dyn().IsUint()); - EXPECT_FALSE(Type::Dyn().IsDouble()); - EXPECT_FALSE(Type::Dyn().IsString()); - EXPECT_FALSE(Type::Dyn().IsBytes()); - EXPECT_FALSE(Type::Dyn().IsDuration()); - EXPECT_FALSE(Type::Dyn().IsTimestamp()); + TestTypeFactory type_factory; + EXPECT_EQ(type_factory.GetDynType()->kind(), Kind::kDyn); + EXPECT_EQ(type_factory.GetDynType()->name(), "dyn"); + EXPECT_THAT(type_factory.GetDynType()->parameters(), SizeIs(0)); + EXPECT_FALSE(type_factory.GetDynType().Is()); + EXPECT_TRUE(type_factory.GetDynType().Is()); + EXPECT_FALSE(type_factory.GetDynType().Is()); + EXPECT_FALSE(type_factory.GetDynType().Is()); + EXPECT_FALSE(type_factory.GetDynType().Is()); + EXPECT_FALSE(type_factory.GetDynType().Is()); + EXPECT_FALSE(type_factory.GetDynType().Is()); + EXPECT_FALSE(type_factory.GetDynType().Is()); + EXPECT_FALSE(type_factory.GetDynType().Is()); + EXPECT_FALSE(type_factory.GetDynType().Is()); + EXPECT_FALSE(type_factory.GetDynType().Is()); } TEST(Type, Any) { - EXPECT_EQ(Type::Any().kind(), Kind::kAny); - EXPECT_EQ(Type::Any().name(), "google.protobuf.Any"); - EXPECT_THAT(Type::Any().parameters(), SizeIs(0)); - EXPECT_FALSE(Type::Any().IsNull()); - EXPECT_FALSE(Type::Any().IsDyn()); - EXPECT_TRUE(Type::Any().IsAny()); - EXPECT_FALSE(Type::Any().IsBool()); - EXPECT_FALSE(Type::Any().IsInt()); - EXPECT_FALSE(Type::Any().IsUint()); - EXPECT_FALSE(Type::Any().IsDouble()); - EXPECT_FALSE(Type::Any().IsString()); - EXPECT_FALSE(Type::Any().IsBytes()); - EXPECT_FALSE(Type::Any().IsDuration()); - EXPECT_FALSE(Type::Any().IsTimestamp()); + TestTypeFactory type_factory; + EXPECT_EQ(type_factory.GetAnyType()->kind(), Kind::kAny); + EXPECT_EQ(type_factory.GetAnyType()->name(), "google.protobuf.Any"); + EXPECT_THAT(type_factory.GetAnyType()->parameters(), SizeIs(0)); + EXPECT_FALSE(type_factory.GetAnyType().Is()); + EXPECT_FALSE(type_factory.GetAnyType().Is()); + EXPECT_TRUE(type_factory.GetAnyType().Is()); + EXPECT_FALSE(type_factory.GetAnyType().Is()); + EXPECT_FALSE(type_factory.GetAnyType().Is()); + EXPECT_FALSE(type_factory.GetAnyType().Is()); + EXPECT_FALSE(type_factory.GetAnyType().Is()); + EXPECT_FALSE(type_factory.GetAnyType().Is()); + EXPECT_FALSE(type_factory.GetAnyType().Is()); + EXPECT_FALSE(type_factory.GetAnyType().Is()); + EXPECT_FALSE(type_factory.GetAnyType().Is()); } TEST(Type, Bool) { - EXPECT_EQ(Type::Bool().kind(), Kind::kBool); - EXPECT_EQ(Type::Bool().name(), "bool"); - EXPECT_THAT(Type::Bool().parameters(), SizeIs(0)); - EXPECT_FALSE(Type::Bool().IsNull()); - EXPECT_FALSE(Type::Bool().IsDyn()); - EXPECT_FALSE(Type::Bool().IsAny()); - EXPECT_TRUE(Type::Bool().IsBool()); - EXPECT_FALSE(Type::Bool().IsInt()); - EXPECT_FALSE(Type::Bool().IsUint()); - EXPECT_FALSE(Type::Bool().IsDouble()); - EXPECT_FALSE(Type::Bool().IsString()); - EXPECT_FALSE(Type::Bool().IsBytes()); - EXPECT_FALSE(Type::Bool().IsDuration()); - EXPECT_FALSE(Type::Bool().IsTimestamp()); + TestTypeFactory type_factory; + EXPECT_EQ(type_factory.GetBoolType()->kind(), Kind::kBool); + EXPECT_EQ(type_factory.GetBoolType()->name(), "bool"); + EXPECT_THAT(type_factory.GetBoolType()->parameters(), SizeIs(0)); + EXPECT_FALSE(type_factory.GetBoolType().Is()); + EXPECT_FALSE(type_factory.GetBoolType().Is()); + EXPECT_FALSE(type_factory.GetBoolType().Is()); + EXPECT_TRUE(type_factory.GetBoolType().Is()); + EXPECT_FALSE(type_factory.GetBoolType().Is()); + EXPECT_FALSE(type_factory.GetBoolType().Is()); + EXPECT_FALSE(type_factory.GetBoolType().Is()); + EXPECT_FALSE(type_factory.GetBoolType().Is()); + EXPECT_FALSE(type_factory.GetBoolType().Is()); + EXPECT_FALSE(type_factory.GetBoolType().Is()); + EXPECT_FALSE(type_factory.GetBoolType().Is()); } TEST(Type, Int) { - EXPECT_EQ(Type::Int().kind(), Kind::kInt); - EXPECT_EQ(Type::Int().name(), "int"); - EXPECT_THAT(Type::Int().parameters(), SizeIs(0)); - EXPECT_FALSE(Type::Int().IsNull()); - EXPECT_FALSE(Type::Int().IsDyn()); - EXPECT_FALSE(Type::Int().IsAny()); - EXPECT_FALSE(Type::Int().IsBool()); - EXPECT_TRUE(Type::Int().IsInt()); - EXPECT_FALSE(Type::Int().IsUint()); - EXPECT_FALSE(Type::Int().IsDouble()); - EXPECT_FALSE(Type::Int().IsString()); - EXPECT_FALSE(Type::Int().IsBytes()); - EXPECT_FALSE(Type::Int().IsDuration()); - EXPECT_FALSE(Type::Int().IsTimestamp()); + TestTypeFactory type_factory; + EXPECT_EQ(type_factory.GetIntType()->kind(), Kind::kInt); + EXPECT_EQ(type_factory.GetIntType()->name(), "int"); + EXPECT_THAT(type_factory.GetIntType()->parameters(), SizeIs(0)); + EXPECT_FALSE(type_factory.GetIntType().Is()); + EXPECT_FALSE(type_factory.GetIntType().Is()); + EXPECT_FALSE(type_factory.GetIntType().Is()); + EXPECT_FALSE(type_factory.GetIntType().Is()); + EXPECT_TRUE(type_factory.GetIntType().Is()); + EXPECT_FALSE(type_factory.GetIntType().Is()); + EXPECT_FALSE(type_factory.GetIntType().Is()); + EXPECT_FALSE(type_factory.GetIntType().Is()); + EXPECT_FALSE(type_factory.GetIntType().Is()); + EXPECT_FALSE(type_factory.GetIntType().Is()); + EXPECT_FALSE(type_factory.GetIntType().Is()); } TEST(Type, Uint) { - EXPECT_EQ(Type::Uint().kind(), Kind::kUint); - EXPECT_EQ(Type::Uint().name(), "uint"); - EXPECT_THAT(Type::Uint().parameters(), SizeIs(0)); - EXPECT_FALSE(Type::Uint().IsNull()); - EXPECT_FALSE(Type::Uint().IsDyn()); - EXPECT_FALSE(Type::Uint().IsAny()); - EXPECT_FALSE(Type::Uint().IsBool()); - EXPECT_FALSE(Type::Uint().IsInt()); - EXPECT_TRUE(Type::Uint().IsUint()); - EXPECT_FALSE(Type::Uint().IsDouble()); - EXPECT_FALSE(Type::Uint().IsString()); - EXPECT_FALSE(Type::Uint().IsBytes()); - EXPECT_FALSE(Type::Uint().IsDuration()); - EXPECT_FALSE(Type::Uint().IsTimestamp()); + TestTypeFactory type_factory; + EXPECT_EQ(type_factory.GetUintType()->kind(), Kind::kUint); + EXPECT_EQ(type_factory.GetUintType()->name(), "uint"); + EXPECT_THAT(type_factory.GetUintType()->parameters(), SizeIs(0)); + EXPECT_FALSE(type_factory.GetUintType().Is()); + EXPECT_FALSE(type_factory.GetUintType().Is()); + EXPECT_FALSE(type_factory.GetUintType().Is()); + EXPECT_FALSE(type_factory.GetUintType().Is()); + EXPECT_FALSE(type_factory.GetUintType().Is()); + EXPECT_TRUE(type_factory.GetUintType().Is()); + EXPECT_FALSE(type_factory.GetUintType().Is()); + EXPECT_FALSE(type_factory.GetUintType().Is()); + EXPECT_FALSE(type_factory.GetUintType().Is()); + EXPECT_FALSE(type_factory.GetUintType().Is()); + EXPECT_FALSE(type_factory.GetUintType().Is()); } TEST(Type, Double) { - EXPECT_EQ(Type::Double().kind(), Kind::kDouble); - EXPECT_EQ(Type::Double().name(), "double"); - EXPECT_THAT(Type::Double().parameters(), SizeIs(0)); - EXPECT_FALSE(Type::Double().IsNull()); - EXPECT_FALSE(Type::Double().IsDyn()); - EXPECT_FALSE(Type::Double().IsAny()); - EXPECT_FALSE(Type::Double().IsBool()); - EXPECT_FALSE(Type::Double().IsInt()); - EXPECT_FALSE(Type::Double().IsUint()); - EXPECT_TRUE(Type::Double().IsDouble()); - EXPECT_FALSE(Type::Double().IsString()); - EXPECT_FALSE(Type::Double().IsBytes()); - EXPECT_FALSE(Type::Double().IsDuration()); - EXPECT_FALSE(Type::Double().IsTimestamp()); + TestTypeFactory type_factory; + EXPECT_EQ(type_factory.GetDoubleType()->kind(), Kind::kDouble); + EXPECT_EQ(type_factory.GetDoubleType()->name(), "double"); + EXPECT_THAT(type_factory.GetDoubleType()->parameters(), SizeIs(0)); + EXPECT_FALSE(type_factory.GetDoubleType().Is()); + EXPECT_FALSE(type_factory.GetDoubleType().Is()); + EXPECT_FALSE(type_factory.GetDoubleType().Is()); + EXPECT_FALSE(type_factory.GetDoubleType().Is()); + EXPECT_FALSE(type_factory.GetDoubleType().Is()); + EXPECT_FALSE(type_factory.GetDoubleType().Is()); + EXPECT_TRUE(type_factory.GetDoubleType().Is()); + EXPECT_FALSE(type_factory.GetDoubleType().Is()); + EXPECT_FALSE(type_factory.GetDoubleType().Is()); + EXPECT_FALSE(type_factory.GetDoubleType().Is()); + EXPECT_FALSE(type_factory.GetDoubleType().Is()); } TEST(Type, String) { - EXPECT_EQ(Type::String().kind(), Kind::kString); - EXPECT_EQ(Type::String().name(), "string"); - EXPECT_THAT(Type::String().parameters(), SizeIs(0)); - EXPECT_FALSE(Type::String().IsNull()); - EXPECT_FALSE(Type::String().IsDyn()); - EXPECT_FALSE(Type::String().IsAny()); - EXPECT_FALSE(Type::String().IsBool()); - EXPECT_FALSE(Type::String().IsInt()); - EXPECT_FALSE(Type::String().IsUint()); - EXPECT_FALSE(Type::String().IsDouble()); - EXPECT_TRUE(Type::String().IsString()); - EXPECT_FALSE(Type::String().IsBytes()); - EXPECT_FALSE(Type::String().IsDuration()); - EXPECT_FALSE(Type::String().IsTimestamp()); + TestTypeFactory type_factory; + EXPECT_EQ(type_factory.GetStringType()->kind(), Kind::kString); + EXPECT_EQ(type_factory.GetStringType()->name(), "string"); + EXPECT_THAT(type_factory.GetStringType()->parameters(), SizeIs(0)); + EXPECT_FALSE(type_factory.GetStringType().Is()); + EXPECT_FALSE(type_factory.GetStringType().Is()); + EXPECT_FALSE(type_factory.GetStringType().Is()); + EXPECT_FALSE(type_factory.GetStringType().Is()); + EXPECT_FALSE(type_factory.GetStringType().Is()); + EXPECT_FALSE(type_factory.GetStringType().Is()); + EXPECT_FALSE(type_factory.GetStringType().Is()); + EXPECT_TRUE(type_factory.GetStringType().Is()); + EXPECT_FALSE(type_factory.GetStringType().Is()); + EXPECT_FALSE(type_factory.GetStringType().Is()); + EXPECT_FALSE(type_factory.GetStringType().Is()); } TEST(Type, Bytes) { - EXPECT_EQ(Type::Bytes().kind(), Kind::kBytes); - EXPECT_EQ(Type::Bytes().name(), "bytes"); - EXPECT_THAT(Type::Bytes().parameters(), SizeIs(0)); - EXPECT_FALSE(Type::Bytes().IsNull()); - EXPECT_FALSE(Type::Bytes().IsDyn()); - EXPECT_FALSE(Type::Bytes().IsAny()); - EXPECT_FALSE(Type::Bytes().IsBool()); - EXPECT_FALSE(Type::Bytes().IsInt()); - EXPECT_FALSE(Type::Bytes().IsUint()); - EXPECT_FALSE(Type::Bytes().IsDouble()); - EXPECT_FALSE(Type::Bytes().IsString()); - EXPECT_TRUE(Type::Bytes().IsBytes()); - EXPECT_FALSE(Type::Bytes().IsDuration()); - EXPECT_FALSE(Type::Bytes().IsTimestamp()); + TestTypeFactory type_factory; + EXPECT_EQ(type_factory.GetBytesType()->kind(), Kind::kBytes); + EXPECT_EQ(type_factory.GetBytesType()->name(), "bytes"); + EXPECT_THAT(type_factory.GetBytesType()->parameters(), SizeIs(0)); + EXPECT_FALSE(type_factory.GetBytesType().Is()); + EXPECT_FALSE(type_factory.GetBytesType().Is()); + EXPECT_FALSE(type_factory.GetBytesType().Is()); + EXPECT_FALSE(type_factory.GetBytesType().Is()); + EXPECT_FALSE(type_factory.GetBytesType().Is()); + EXPECT_FALSE(type_factory.GetBytesType().Is()); + EXPECT_FALSE(type_factory.GetBytesType().Is()); + EXPECT_FALSE(type_factory.GetBytesType().Is()); + EXPECT_TRUE(type_factory.GetBytesType().Is()); + EXPECT_FALSE(type_factory.GetBytesType().Is()); + EXPECT_FALSE(type_factory.GetBytesType().Is()); } TEST(Type, Duration) { - EXPECT_EQ(Type::Duration().kind(), Kind::kDuration); - EXPECT_EQ(Type::Duration().name(), "google.protobuf.Duration"); - EXPECT_THAT(Type::Duration().parameters(), SizeIs(0)); - EXPECT_FALSE(Type::Duration().IsNull()); - EXPECT_FALSE(Type::Duration().IsDyn()); - EXPECT_FALSE(Type::Duration().IsAny()); - EXPECT_FALSE(Type::Duration().IsBool()); - EXPECT_FALSE(Type::Duration().IsInt()); - EXPECT_FALSE(Type::Duration().IsUint()); - EXPECT_FALSE(Type::Duration().IsDouble()); - EXPECT_FALSE(Type::Duration().IsString()); - EXPECT_FALSE(Type::Duration().IsBytes()); - EXPECT_TRUE(Type::Duration().IsDuration()); - EXPECT_FALSE(Type::Duration().IsTimestamp()); + TestTypeFactory type_factory; + EXPECT_EQ(type_factory.GetDurationType()->kind(), Kind::kDuration); + EXPECT_EQ(type_factory.GetDurationType()->name(), "google.protobuf.Duration"); + EXPECT_THAT(type_factory.GetDurationType()->parameters(), SizeIs(0)); + EXPECT_FALSE(type_factory.GetDurationType().Is()); + EXPECT_FALSE(type_factory.GetDurationType().Is()); + EXPECT_FALSE(type_factory.GetDurationType().Is()); + EXPECT_FALSE(type_factory.GetDurationType().Is()); + EXPECT_FALSE(type_factory.GetDurationType().Is()); + EXPECT_FALSE(type_factory.GetDurationType().Is()); + EXPECT_FALSE(type_factory.GetDurationType().Is()); + EXPECT_FALSE(type_factory.GetDurationType().Is()); + EXPECT_FALSE(type_factory.GetDurationType().Is()); + EXPECT_TRUE(type_factory.GetDurationType().Is()); + EXPECT_FALSE(type_factory.GetDurationType().Is()); } TEST(Type, Timestamp) { - EXPECT_EQ(Type::Timestamp().kind(), Kind::kTimestamp); - EXPECT_EQ(Type::Timestamp().name(), "google.protobuf.Timestamp"); - EXPECT_THAT(Type::Timestamp().parameters(), SizeIs(0)); - EXPECT_FALSE(Type::Timestamp().IsNull()); - EXPECT_FALSE(Type::Timestamp().IsDyn()); - EXPECT_FALSE(Type::Timestamp().IsAny()); - EXPECT_FALSE(Type::Timestamp().IsBool()); - EXPECT_FALSE(Type::Timestamp().IsInt()); - EXPECT_FALSE(Type::Timestamp().IsUint()); - EXPECT_FALSE(Type::Timestamp().IsDouble()); - EXPECT_FALSE(Type::Timestamp().IsString()); - EXPECT_FALSE(Type::Timestamp().IsBytes()); - EXPECT_FALSE(Type::Timestamp().IsDuration()); - EXPECT_TRUE(Type::Timestamp().IsTimestamp()); + TestTypeFactory type_factory; + EXPECT_EQ(type_factory.GetTimestampType()->kind(), Kind::kTimestamp); + EXPECT_EQ(type_factory.GetTimestampType()->name(), + "google.protobuf.Timestamp"); + EXPECT_THAT(type_factory.GetTimestampType()->parameters(), SizeIs(0)); + EXPECT_FALSE(type_factory.GetTimestampType().Is()); + EXPECT_FALSE(type_factory.GetTimestampType().Is()); + EXPECT_FALSE(type_factory.GetTimestampType().Is()); + EXPECT_FALSE(type_factory.GetTimestampType().Is()); + EXPECT_FALSE(type_factory.GetTimestampType().Is()); + EXPECT_FALSE(type_factory.GetTimestampType().Is()); + EXPECT_FALSE(type_factory.GetTimestampType().Is()); + EXPECT_FALSE(type_factory.GetTimestampType().Is()); + EXPECT_FALSE(type_factory.GetTimestampType().Is()); + EXPECT_FALSE(type_factory.GetTimestampType().Is()); + EXPECT_TRUE(type_factory.GetTimestampType().Is()); } TEST(Type, SupportsAbslHash) { + TestTypeFactory type_factory; EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly({ - Type::Error(), - Type::Null(), - Type::Dyn(), - Type::Any(), - Type::Bool(), - Type::Int(), - Type::Uint(), - Type::Double(), - Type::String(), - Type::Bytes(), - Type::Duration(), - Type::Timestamp(), + Persistent(type_factory.GetNullType()), + Persistent(type_factory.GetErrorType()), + Persistent(type_factory.GetDynType()), + Persistent(type_factory.GetAnyType()), + Persistent(type_factory.GetBoolType()), + Persistent(type_factory.GetIntType()), + Persistent(type_factory.GetUintType()), + Persistent(type_factory.GetDoubleType()), + Persistent(type_factory.GetStringType()), + Persistent(type_factory.GetBytesType()), + Persistent(type_factory.GetDurationType()), + Persistent(type_factory.GetTimestampType()), })); } From d053fa2c608538bb847e04ac07247b3797262a6b Mon Sep 17 00:00:00 2001 From: jcking Date: Wed, 9 Mar 2022 16:35:47 +0000 Subject: [PATCH 020/111] Internal change PiperOrigin-RevId: 433487364 --- base/internal/memory_manager.post.h | 2 +- base/internal/memory_manager.pre.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/base/internal/memory_manager.post.h b/base/internal/memory_manager.post.h index dde3e425a..11da71b3e 100644 --- a/base/internal/memory_manager.post.h +++ b/base/internal/memory_manager.post.h @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// IWYU pragma: private +// IWYU pragma: private, include "base/memory_manager.h" #ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MEMORY_MANAGER_POST_H_ #define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MEMORY_MANAGER_POST_H_ diff --git a/base/internal/memory_manager.pre.h b/base/internal/memory_manager.pre.h index aeda27995..66507a0e1 100644 --- a/base/internal/memory_manager.pre.h +++ b/base/internal/memory_manager.pre.h @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// IWYU pragma: private +// IWYU pragma: private, include "base/memory_manager.h" #ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MEMORY_MANAGER_PRE_H_ #define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MEMORY_MANAGER_PRE_H_ From 470e7ba20a0374b58bef450426319b08be620bce Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 9 Mar 2022 16:45:01 +0000 Subject: [PATCH 021/111] Update heterogeneous equality behavior to return false for mixed types. PiperOrigin-RevId: 433489389 --- conformance/BUILD | 3 +- eval/compiler/constant_folding.cc | 5 +- eval/public/comparison_functions.cc | 19 +++++--- eval/public/comparison_functions_test.cc | 60 +++++++++++++++--------- 4 files changed, 57 insertions(+), 30 deletions(-) diff --git a/conformance/BUILD b/conformance/BUILD index ab43d7b50..97c603d03 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -103,8 +103,9 @@ cc_binary( # Future features for CEL 1.0 # TODO(google/cel-spec/issues/225): These are supported comparisons with heterogeneous equality enabled. - "--skip_test=comparisons/eq_literal/eq_list_elem_mixed_types_error,eq_mixed_types_error", + "--skip_test=comparisons/eq_literal/eq_list_elem_mixed_types_error,eq_mixed_types_error,eq_map_value_mixed_types_error", "--skip_test=comparisons/ne_literal/ne_mixed_types_error", + "--skip_test=macros/exists/list_elem_type_exhaustive,map_key_type_exhaustive", # TODO(issues/119): Strong typing support for enums, specified but not implemented. "--skip_test=enums/strong_proto2", "--skip_test=enums/strong_proto3", diff --git a/eval/compiler/constant_folding.cc b/eval/compiler/constant_folding.cc index 40ef0996d..115467346 100644 --- a/eval/compiler/constant_folding.cc +++ b/eval/compiler/constant_folding.cc @@ -104,7 +104,8 @@ class ConstantFoldingTransform { matched_function = overload; } } - if (matched_function == nullptr) { + if (matched_function == nullptr || + matched_function->descriptor().is_strict()) { // propagate argument errors up the expression for (const CelValue& arg : arg_values) { if (arg.IsError()) { @@ -112,6 +113,8 @@ class ConstantFoldingTransform { return true; } } + } + if (matched_function == nullptr) { makeConstant( CreateNoMatchingOverloadError(arena_, call_expr->function()), out); diff --git a/eval/public/comparison_functions.cc b/eval/public/comparison_functions.cc index 2c03b01bc..59ad41da2 100644 --- a/eval/public/comparison_functions.cc +++ b/eval/public/comparison_functions.cc @@ -276,8 +276,8 @@ absl::optional Inequal(const CelMap* t1, const CelMap* t2) { } bool MessageEqual(const google::protobuf::Message& m1, const google::protobuf::Message& m2) { - // Equality behavior is undefined if input messages have different - // descriptors. + // Equality behavior is undefined for message differencer if input messages + // have different descriptors. For CEL just return false. if (m1.GetDescriptor() != m2.GetDescriptor()) { return false; } @@ -460,6 +460,8 @@ CelValue GeneralizedEqual(Arena* arena, CelValue t1, CelValue t2) { if (result.has_value()) { return CelValue::CreateBool(*result); } + // Note: With full heterogeneous equality enabled, this only happens for + // containers containing special value types (errors, unknowns). return CreateNoMatchingOverloadError(arena, builtin::kEqual); } @@ -556,18 +558,21 @@ absl::optional CelValueEqualImpl(const CelValue& v1, const CelValue& v2) { return HomogenousCelValueEqual(v1, v2); } - if (v1.type() == CelValue::Type::kNullType || - v2.type() == CelValue::Type::kNullType) { - return false; - } absl::optional lhs = GetNumberFromCelValue(v1); absl::optional rhs = GetNumberFromCelValue(v2); if (rhs.has_value() && lhs.has_value()) { return *lhs == *rhs; - } else { + } + + // TODO(issues/5): It's currently possible for the interpreter to create a + // map containing an Error. Return no matching overload to propagate an error + // instead of a false result. + if (v1.IsError() || v1.IsUnknownSet() || v2.IsError() || v2.IsUnknownSet()) { return absl::nullopt; } + + return false; } absl::Status RegisterComparisonFunctions(CelFunctionRegistry* registry, diff --git a/eval/public/comparison_functions_test.cc b/eval/public/comparison_functions_test.cc index a11b4153f..b8723d949 100644 --- a/eval/public/comparison_functions_test.cc +++ b/eval/public/comparison_functions_test.cc @@ -59,6 +59,7 @@ namespace google::api::expr::runtime { namespace { using google::api::expr::v1alpha1::ParsedExpr; +using testing::_; using testing::Combine; using testing::HasSubstr; using testing::Optional; @@ -77,7 +78,7 @@ MATCHER_P2(DefinesHomogenousOverload, name, argument_type, } struct ComparisonTestCase { - enum class ErrorKind { kMissingOverload }; + enum class ErrorKind { kMissingOverload, kMissingIdentifier }; absl::string_view expr; absl::variant result; CelValue lhs = CelValue::CreateNull(); @@ -205,12 +206,11 @@ TEST_P(CelValueEqualImplTypesTest, Basic) { } else { EXPECT_THAT(result, Optional(false)); } - } else if (lhs().type() == rhs().type()) { - EXPECT_THAT(result, Optional(should_be_equal())); - } else if (IsNumeric(lhs().type()) && IsNumeric(rhs().type())) { + } else if (lhs().type() == rhs().type() || + (IsNumeric(lhs().type()) && IsNumeric(rhs().type()))) { EXPECT_THAT(result, Optional(should_be_equal())); } else { - EXPECT_EQ(result, absl::nullopt); + EXPECT_THAT(result, Optional(false)); } } @@ -302,13 +302,13 @@ TEST(CelValueEqualImplTest, LossyNumericEquality) { EXPECT_TRUE(*result); } -TEST(CelValueEqualImplTest, ListMixedTypesEqualityNotDefined) { +TEST(CelValueEqualImplTest, ListMixedTypesInequal) { ContainerBackedListImpl lhs({CelValue::CreateInt64(1)}); ContainerBackedListImpl rhs({CelValue::CreateStringView("abc")}); - EXPECT_EQ( + EXPECT_THAT( CelValueEqualImpl(CelValue::CreateList(&lhs), CelValue::CreateList(&rhs)), - absl::nullopt); + Optional(false)); } TEST(CelValueEqualImplTest, NestedList) { @@ -322,7 +322,7 @@ TEST(CelValueEqualImplTest, NestedList) { Optional(false)); } -TEST(CelValueEqualImplTest, MapMixedValueTypesEqualityNotDefined) { +TEST(CelValueEqualImplTest, MapMixedValueTypesInequal) { std::vector> lhs_data{ {CelValue::CreateInt64(1), CelValue::CreateStringView("abc")}}; std::vector> rhs_data{ @@ -333,9 +333,9 @@ TEST(CelValueEqualImplTest, MapMixedValueTypesEqualityNotDefined) { ASSERT_OK_AND_ASSIGN(std::unique_ptr rhs, CreateContainerBackedMap(absl::MakeSpan(rhs_data))); - EXPECT_EQ(CelValueEqualImpl(CelValue::CreateMap(lhs.get()), - CelValue::CreateMap(rhs.get())), - absl::nullopt); + EXPECT_THAT(CelValueEqualImpl(CelValue::CreateMap(lhs.get()), + CelValue::CreateMap(rhs.get())), + Optional(false)); } TEST(CelValueEqualImplTest, MapMixedKeyTypesInequal) { @@ -604,9 +604,21 @@ TEST_P(ComparisonFunctionTest, SmokeTest) { if (absl::holds_alternative(test_case.result)) { EXPECT_THAT(result, test::IsCelBool(absl::get(test_case.result))); } else { - EXPECT_THAT(result, - test::IsCelError(StatusIs(absl::StatusCode::kUnknown, - HasSubstr("No matching overloads")))); + switch (std::get(test_case.result)) { + case ComparisonTestCase::ErrorKind::kMissingOverload: + EXPECT_THAT(result, test::IsCelError( + StatusIs(absl::StatusCode::kUnknown, + HasSubstr("No matching overloads")))); + break; + case ComparisonTestCase::ErrorKind::kMissingIdentifier: + EXPECT_THAT(result, test::IsCelError( + StatusIs(absl::StatusCode::kUnknown, + HasSubstr("found in Activation")))); + break; + default: + EXPECT_THAT(result, test::IsCelError(_)); + break; + } } } @@ -769,9 +781,12 @@ INSTANTIATE_TEST_SUITE_P( {"lhs == rhs", true, CelValue::CreateTimestamp(absl::FromUnixSeconds(20)), CelValue::CreateTimestamp(absl::FromUnixSeconds(20))}, - // Maps may have errors as values. These don't propagate from - // deep comparisons at the moment, they just return no - // overload. + // This should fail before getting to the equal operator. + {"no_such_identifier == 1", + ComparisonTestCase::ErrorKind::kMissingIdentifier}, + // TODO(issues/5): The C++ evaluator allows creating maps + // with error values. Propagate an error instead of a false + // result. {"{1: no_such_identifier} == {1: 1}", ComparisonTestCase::ErrorKind::kMissingOverload}}), // heterogeneous equality enabled @@ -794,9 +809,12 @@ INSTANTIATE_TEST_SUITE_P( {"lhs != rhs", true, CelValue::CreateTimestamp(absl::FromUnixSeconds(20)), CelValue::CreateTimestamp(absl::FromUnixSeconds(30))}, - // Maps may have errors as values. These don't propagate from - // deep comparisons at the moment, they just return no - // overload. + // This should fail before getting to the equal operator. + {"no_such_identifier != 1", + ComparisonTestCase::ErrorKind::kMissingIdentifier}, + // TODO(issues/5): The C++ evaluator allows creating maps + // with error values. Propagate an error instead of a false + // result. {"{1: no_such_identifier} != {1: 1}", ComparisonTestCase::ErrorKind::kMissingOverload}}), // heterogeneous equality enabled From 202148f1b0f220ef5f7fd744df827454253855f3 Mon Sep 17 00:00:00 2001 From: timdn Date: Thu, 10 Mar 2022 16:19:36 +0000 Subject: [PATCH 022/111] Add utility function to get file descriptor set of standard messages AddStandardMessageTypesToDescriptorPool() cannot be called on descriptor pools backed by a database. Yet users may need to add the relevant types. The new function GetStandardMessageTypesFileDescriptorSet() gives a file descriptor set of those types so that users can add those by themselves. PiperOrigin-RevId: 433757758 --- eval/compiler/BUILD | 1 + eval/compiler/flat_expr_builder_test.cc | 3 +- eval/public/BUILD | 13 +- eval/public/cel_expr_builder_factory.cc | 137 ++---------------- eval/public/cel_expr_builder_factory.h | 3 - eval/public/structs/BUILD | 25 ++++ .../cel_proto_descriptor_pool_builder.cc | 126 ++++++++++++++++ .../cel_proto_descriptor_pool_builder.h | 40 +++++ ...cel_proto_descriptor_pool_builder_test.cc} | 23 ++- internal/BUILD | 1 + internal/proto_util.cc | 42 ++++++ internal/proto_util.h | 33 +++++ 12 files changed, 300 insertions(+), 147 deletions(-) create mode 100644 eval/public/structs/cel_proto_descriptor_pool_builder.cc create mode 100644 eval/public/structs/cel_proto_descriptor_pool_builder.h rename eval/public/{cel_expr_builder_factory_test.cc => structs/cel_proto_descriptor_pool_builder_test.cc} (86%) diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index 4d3e94853..f0ab9b06e 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -71,6 +71,7 @@ cc_test( "//eval/public:unknown_attribute_set", "//eval/public:unknown_set", "//eval/public/containers:container_backed_map_impl", + "//eval/public/structs:cel_proto_descriptor_pool_builder", "//eval/public/structs:cel_proto_wrapper", "//eval/public/testing:matchers", "//eval/testutil:test_message_cc_proto", diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index 148ce8e71..a8077839a 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -45,6 +45,7 @@ #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/containers/container_backed_map_impl.h" +#include "eval/public/structs/cel_proto_descriptor_pool_builder.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "eval/public/testing/matchers.h" #include "eval/public/unknown_attribute_set.h" @@ -1717,7 +1718,7 @@ TEST_P(CustomDescriptorPoolTest, TestType) { google::protobuf::Arena arena; // Setup descriptor pool and builder - ASSERT_OK(AddStandardMessageTypesToDescriptorPool(&descriptor_pool)); + ASSERT_OK(AddStandardMessageTypesToDescriptorPool(descriptor_pool)); google::protobuf::DynamicMessageFactory message_factory(&descriptor_pool); ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("m")); FlatExprBuilder builder(&descriptor_pool, &message_factory); diff --git a/eval/public/BUILD b/eval/public/BUILD index 3f73536ce..3c0a0fce5 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -404,23 +404,12 @@ cc_library( ":cel_expression", ":cel_options", "//eval/compiler:flat_expr_builder", + "//internal:proto_util", "@com_google_absl//absl/status", "@com_google_protobuf//:protobuf", ], ) -cc_test( - name = "cel_expr_builder_factory_test", - srcs = ["cel_expr_builder_factory_test.cc"], - deps = [ - ":cel_expr_builder_factory", - "//eval/testutil:test_message_cc_proto", - "//internal:testing", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_protobuf//:protobuf", - ], -) - cc_library( name = "value_export_util", srcs = [ diff --git a/eval/public/cel_expr_builder_factory.cc b/eval/public/cel_expr_builder_factory.cc index b349a37d2..c78e846c5 100644 --- a/eval/public/cel_expr_builder_factory.cc +++ b/eval/public/cel_expr_builder_factory.cc @@ -16,148 +16,31 @@ #include "eval/public/cel_expr_builder_factory.h" +#include #include -#include "google/protobuf/any.pb.h" -#include "google/protobuf/duration.pb.h" -#include "google/protobuf/struct.pb.h" -#include "google/protobuf/timestamp.pb.h" -#include "google/protobuf/wrappers.pb.h" -#include "google/protobuf/descriptor.pb.h" -#include "google/protobuf/util/message_differencer.h" #include "absl/status/status.h" #include "eval/compiler/flat_expr_builder.h" #include "eval/public/cel_options.h" +#include "internal/proto_util.h" namespace google::api::expr::runtime { namespace { -template -absl::Status ValidateStandardMessageType( - const google::protobuf::DescriptorPool* descriptor_pool) { - const google::protobuf::Descriptor* descriptor = MessageType::descriptor(); - const google::protobuf::Descriptor* descriptor_from_pool = - descriptor_pool->FindMessageTypeByName(descriptor->full_name()); - if (descriptor_from_pool == nullptr) { - return absl::NotFoundError( - absl::StrFormat("Descriptor '%s' not found in descriptor pool", - descriptor->full_name())); - } - if (descriptor_from_pool == descriptor) { - return absl::OkStatus(); - } - google::protobuf::DescriptorProto descriptor_proto; - google::protobuf::DescriptorProto descriptor_from_pool_proto; - descriptor->CopyTo(&descriptor_proto); - descriptor_from_pool->CopyTo(&descriptor_from_pool_proto); - if (!google::protobuf::util::MessageDifferencer::Equals(descriptor_proto, - descriptor_from_pool_proto)) { - return absl::FailedPreconditionError(absl::StrFormat( - "The descriptor for '%s' in the descriptor pool differs from the " - "compiled-in generated version", - descriptor->full_name())); - } - return absl::OkStatus(); -} - -template -absl::Status AddOrValidateMessageType(google::protobuf::DescriptorPool* descriptor_pool) { - const google::protobuf::Descriptor* descriptor = MessageType::descriptor(); - if (descriptor_pool->FindMessageTypeByName(descriptor->full_name()) != - nullptr) { - return ValidateStandardMessageType(descriptor_pool); - } - google::protobuf::FileDescriptorProto file_descriptor_proto; - descriptor->file()->CopyTo(&file_descriptor_proto); - if (descriptor_pool->BuildFile(file_descriptor_proto) == nullptr) { - return absl::InternalError( - absl::StrFormat("Failed to add descriptor '%s' to descriptor pool", - descriptor->full_name())); - } - return absl::OkStatus(); -} - -absl::Status ValidateStandardMessageTypes( - const google::protobuf::DescriptorPool* descriptor_pool) { - CEL_RETURN_IF_ERROR( - ValidateStandardMessageType(descriptor_pool)); - CEL_RETURN_IF_ERROR(ValidateStandardMessageType( - descriptor_pool)); - CEL_RETURN_IF_ERROR(ValidateStandardMessageType( - descriptor_pool)); - CEL_RETURN_IF_ERROR( - ValidateStandardMessageType( - descriptor_pool)); - CEL_RETURN_IF_ERROR( - ValidateStandardMessageType(descriptor_pool)); - CEL_RETURN_IF_ERROR(ValidateStandardMessageType( - descriptor_pool)); - CEL_RETURN_IF_ERROR(ValidateStandardMessageType( - descriptor_pool)); - CEL_RETURN_IF_ERROR(ValidateStandardMessageType( - descriptor_pool)); - CEL_RETURN_IF_ERROR(ValidateStandardMessageType( - descriptor_pool)); - CEL_RETURN_IF_ERROR( - ValidateStandardMessageType( - descriptor_pool)); - CEL_RETURN_IF_ERROR( - ValidateStandardMessageType(descriptor_pool)); - CEL_RETURN_IF_ERROR(ValidateStandardMessageType( - descriptor_pool)); - CEL_RETURN_IF_ERROR( - ValidateStandardMessageType( - descriptor_pool)); - CEL_RETURN_IF_ERROR( - ValidateStandardMessageType( - descriptor_pool)); - CEL_RETURN_IF_ERROR( - ValidateStandardMessageType(descriptor_pool)); - return absl::OkStatus(); -} - +using ::google::api::expr::internal::ValidateStandardMessageTypes; } // namespace -absl::Status AddStandardMessageTypesToDescriptorPool( - google::protobuf::DescriptorPool* descriptor_pool) { - CEL_RETURN_IF_ERROR( - AddOrValidateMessageType(descriptor_pool)); - CEL_RETURN_IF_ERROR( - AddOrValidateMessageType(descriptor_pool)); - CEL_RETURN_IF_ERROR( - AddOrValidateMessageType(descriptor_pool)); - CEL_RETURN_IF_ERROR( - AddOrValidateMessageType(descriptor_pool)); - CEL_RETURN_IF_ERROR( - AddOrValidateMessageType(descriptor_pool)); - CEL_RETURN_IF_ERROR( - AddOrValidateMessageType(descriptor_pool)); - CEL_RETURN_IF_ERROR( - AddOrValidateMessageType(descriptor_pool)); - CEL_RETURN_IF_ERROR( - AddOrValidateMessageType(descriptor_pool)); - CEL_RETURN_IF_ERROR( - AddOrValidateMessageType(descriptor_pool)); - CEL_RETURN_IF_ERROR( - AddOrValidateMessageType(descriptor_pool)); - CEL_RETURN_IF_ERROR( - AddOrValidateMessageType(descriptor_pool)); - CEL_RETURN_IF_ERROR( - AddOrValidateMessageType(descriptor_pool)); - CEL_RETURN_IF_ERROR( - AddOrValidateMessageType(descriptor_pool)); - CEL_RETURN_IF_ERROR( - AddOrValidateMessageType(descriptor_pool)); - CEL_RETURN_IF_ERROR( - AddOrValidateMessageType(descriptor_pool)); - return absl::OkStatus(); -} - std::unique_ptr CreateCelExpressionBuilder( const google::protobuf::DescriptorPool* descriptor_pool, google::protobuf::MessageFactory* message_factory, const InterpreterOptions& options) { - if (!ValidateStandardMessageTypes(descriptor_pool).ok()) { + if (descriptor_pool == nullptr) { + GOOGLE_LOG(ERROR) << "Cannot pass nullptr as descriptor pool to " + "CreateCelExpressionBuilder"; + return nullptr; + } + if (auto s = ValidateStandardMessageTypes(*descriptor_pool); !s.ok()) { + GOOGLE_LOG(WARNING) << "Failed to validate standard message types: " << s; return nullptr; } auto builder = diff --git a/eval/public/cel_expr_builder_factory.h b/eval/public/cel_expr_builder_factory.h index 6063dacc2..7321e29a2 100644 --- a/eval/public/cel_expr_builder_factory.h +++ b/eval/public/cel_expr_builder_factory.h @@ -23,9 +23,6 @@ inline std::unique_ptr CreateCelExpressionBuilder( options); } -absl::Status AddStandardMessageTypesToDescriptorPool( - google::protobuf::DescriptorPool* descriptor_pool); - } // namespace runtime } // namespace expr } // namespace api diff --git a/eval/public/structs/BUILD b/eval/public/structs/BUILD index 651a92b0c..5ed70a3a0 100644 --- a/eval/public/structs/BUILD +++ b/eval/public/structs/BUILD @@ -38,6 +38,31 @@ cc_library( ], ) +cc_library( + name = "cel_proto_descriptor_pool_builder", + srcs = ["cel_proto_descriptor_pool_builder.cc"], + hdrs = ["cel_proto_descriptor_pool_builder.h"], + deps = [ + "//internal:proto_util", + "//internal:status_macros", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "cel_proto_descriptor_pool_builder_test", + srcs = ["cel_proto_descriptor_pool_builder_test.cc"], + deps = [ + ":cel_proto_descriptor_pool_builder", + "//eval/testutil:test_message_cc_proto", + "//internal:testing", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_protobuf//:protobuf", + ], +) + cc_test( name = "cel_proto_wrapper_test", size = "small", diff --git a/eval/public/structs/cel_proto_descriptor_pool_builder.cc b/eval/public/structs/cel_proto_descriptor_pool_builder.cc new file mode 100644 index 000000000..abf35181b --- /dev/null +++ b/eval/public/structs/cel_proto_descriptor_pool_builder.cc @@ -0,0 +1,126 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "eval/public/structs/cel_proto_descriptor_pool_builder.h" + +#include + +#include "google/protobuf/any.pb.h" +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/struct.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "google/protobuf/wrappers.pb.h" +#include "absl/container/flat_hash_map.h" +#include "internal/proto_util.h" +#include "internal/status_macros.h" + +namespace google::api::expr::runtime { +namespace { +template +absl::Status AddOrValidateMessageType(google::protobuf::DescriptorPool& descriptor_pool) { + const google::protobuf::Descriptor* descriptor = MessageType::descriptor(); + if (descriptor_pool.FindMessageTypeByName(descriptor->full_name()) != + nullptr) { + return internal::ValidateStandardMessageType(descriptor_pool); + } + google::protobuf::FileDescriptorProto file_descriptor_proto; + descriptor->file()->CopyTo(&file_descriptor_proto); + if (descriptor_pool.BuildFile(file_descriptor_proto) == nullptr) { + return absl::InternalError( + absl::StrFormat("Failed to add descriptor '%s' to descriptor pool", + descriptor->full_name())); + } + return absl::OkStatus(); +} + +template +void AddStandardMessageTypeToMap( + absl::flat_hash_map& fdmap) { + const google::protobuf::Descriptor* descriptor = MessageType::descriptor(); + + if (fdmap.contains(descriptor->file()->name())) return; + + descriptor->file()->CopyTo(&fdmap[descriptor->file()->name()]); +} + +} // namespace + +absl::Status AddStandardMessageTypesToDescriptorPool( + google::protobuf::DescriptorPool& descriptor_pool) { + // The types below do not depend on each other, hence we can add them in any + // order. Should that change with new messages add them in the proper order, + // i.e., dependencies first. + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR( + AddOrValidateMessageType(descriptor_pool)); + return absl::OkStatus(); +} + +google::protobuf::FileDescriptorSet GetStandardMessageTypesFileDescriptorSet() { + // The types below do not depend on each other, hence we can add them to + // an unordered map. Should that change with new messages being added here + // adapt this to a sorted data structure and add in the proper order. + absl::flat_hash_map files; + AddStandardMessageTypeToMap(files); + AddStandardMessageTypeToMap(files); + AddStandardMessageTypeToMap(files); + AddStandardMessageTypeToMap(files); + AddStandardMessageTypeToMap(files); + AddStandardMessageTypeToMap(files); + AddStandardMessageTypeToMap(files); + AddStandardMessageTypeToMap(files); + AddStandardMessageTypeToMap(files); + AddStandardMessageTypeToMap(files); + AddStandardMessageTypeToMap(files); + AddStandardMessageTypeToMap(files); + AddStandardMessageTypeToMap(files); + AddStandardMessageTypeToMap(files); + AddStandardMessageTypeToMap(files); + google::protobuf::FileDescriptorSet fdset; + for (const auto& [name, fdproto] : files) { + *fdset.add_file() = fdproto; + } + return fdset; +} + +} // namespace google::api::expr::runtime diff --git a/eval/public/structs/cel_proto_descriptor_pool_builder.h b/eval/public/structs/cel_proto_descriptor_pool_builder.h new file mode 100644 index 000000000..d6007c76b --- /dev/null +++ b/eval/public/structs/cel_proto_descriptor_pool_builder.h @@ -0,0 +1,40 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_CEL_PROTO_DESCRIPTOR_POOL_BUILDER_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_CEL_PROTO_DESCRIPTOR_POOL_BUILDER_H_ + +#include "google/protobuf/descriptor.pb.h" +#include "google/protobuf/descriptor.h" +#include "absl/status/status.h" + +namespace google::api::expr::runtime { + +// Add standard message types required by CEL to given descriptor pool. +// This includes standard wrappers, timestamp, duration, any, etc. +// This does not work for descriptor pools that have a fallback database. +// Use GetStandardMessageTypesFileDescriptorSet() below instead to populate. +absl::Status AddStandardMessageTypesToDescriptorPool( + google::protobuf::DescriptorPool& descriptor_pool); + +// Get the standard message types required by CEL. +// This includes standard wrappers, timestamp, duration, any, etc. These can be +// used to, e.g., add them to a DescriptorDatabase backing a DescriptorPool. +google::protobuf::FileDescriptorSet GetStandardMessageTypesFileDescriptorSet(); + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_CEL_PROTO_DESCRIPTOR_POOL_BUILDER_H_ diff --git a/eval/public/cel_expr_builder_factory_test.cc b/eval/public/structs/cel_proto_descriptor_pool_builder_test.cc similarity index 86% rename from eval/public/cel_expr_builder_factory_test.cc rename to eval/public/structs/cel_proto_descriptor_pool_builder_test.cc index 571fb6dc5..3682d1ba3 100644 --- a/eval/public/cel_expr_builder_factory_test.cc +++ b/eval/public/structs/cel_proto_descriptor_pool_builder_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "eval/public/cel_expr_builder_factory.h" +#include "eval/public/structs/cel_proto_descriptor_pool_builder.h" #include @@ -28,6 +28,7 @@ namespace google::api::expr::runtime { namespace { using testing::HasSubstr; +using testing::UnorderedElementsAre; using cel::internal::StatusIs; TEST(DescriptorPoolUtilsTest, PopulatesEmptyDescriptorPool) { @@ -68,7 +69,7 @@ TEST(DescriptorPoolUtilsTest, PopulatesEmptyDescriptorPool) { ASSERT_EQ(descriptor_pool.FindMessageTypeByName("google.protobuf.Value"), nullptr); - ASSERT_OK(AddStandardMessageTypesToDescriptorPool(&descriptor_pool)); + ASSERT_OK(AddStandardMessageTypesToDescriptorPool(descriptor_pool)); EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.Any"), nullptr); @@ -127,7 +128,7 @@ TEST(DescriptorPoolUtilsTest, AcceptsPreAddedStandardTypes) { ASSERT_NE(descriptor_pool.BuildFile(file_descriptor_proto), nullptr); } - EXPECT_OK(AddStandardMessageTypesToDescriptorPool(&descriptor_pool)); + EXPECT_OK(AddStandardMessageTypesToDescriptorPool(descriptor_pool)); } TEST(DescriptorPoolUtilsTest, RejectsModifiedStandardType) { @@ -155,10 +156,24 @@ TEST(DescriptorPoolUtilsTest, RejectsModifiedStandardType) { descriptor_pool.BuildFile(file_descriptor_proto); EXPECT_THAT( - AddStandardMessageTypesToDescriptorPool(&descriptor_pool), + AddStandardMessageTypesToDescriptorPool(descriptor_pool), StatusIs(absl::StatusCode::kFailedPrecondition, HasSubstr("differs"))); } +TEST(DescriptorPoolUtilsTest, GetStandardMessageTypesFileDescriptorSet) { + google::protobuf::FileDescriptorSet fdset = GetStandardMessageTypesFileDescriptorSet(); + std::vector file_names; + for (int i = 0; i < fdset.file_size(); ++i) { + file_names.push_back(fdset.file(i).name()); + } + EXPECT_THAT(file_names, + UnorderedElementsAre("google/protobuf/any.proto", + "google/protobuf/struct.proto", + "google/protobuf/wrappers.proto", + "google/protobuf/timestamp.proto", + "google/protobuf/duration.proto")); +} + } // namespace } // namespace google::api::expr::runtime diff --git a/internal/BUILD b/internal/BUILD index 9a0c1dfd5..cda5eba3b 100644 --- a/internal/BUILD +++ b/internal/BUILD @@ -142,6 +142,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/time", "@com_google_protobuf//:protobuf", ], diff --git a/internal/proto_util.cc b/internal/proto_util.cc index 305a6cf3d..7bc7d049f 100644 --- a/internal/proto_util.cc +++ b/internal/proto_util.cc @@ -16,8 +16,11 @@ #include +#include "google/protobuf/any.pb.h" #include "google/protobuf/duration.pb.h" +#include "google/protobuf/struct.pb.h" #include "google/protobuf/timestamp.pb.h" +#include "google/protobuf/wrappers.pb.h" #include "google/protobuf/util/time_util.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" @@ -101,6 +104,45 @@ absl::StatusOr EncodeTimeToString(absl::Time time) { return google::protobuf::util::TimeUtil::ToString(t); } +absl::Status ValidateStandardMessageTypes( + const google::protobuf::DescriptorPool& descriptor_pool) { + CEL_RETURN_IF_ERROR( + ValidateStandardMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR(ValidateStandardMessageType( + descriptor_pool)); + CEL_RETURN_IF_ERROR(ValidateStandardMessageType( + descriptor_pool)); + CEL_RETURN_IF_ERROR( + ValidateStandardMessageType( + descriptor_pool)); + CEL_RETURN_IF_ERROR( + ValidateStandardMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR(ValidateStandardMessageType( + descriptor_pool)); + CEL_RETURN_IF_ERROR(ValidateStandardMessageType( + descriptor_pool)); + CEL_RETURN_IF_ERROR(ValidateStandardMessageType( + descriptor_pool)); + CEL_RETURN_IF_ERROR(ValidateStandardMessageType( + descriptor_pool)); + CEL_RETURN_IF_ERROR( + ValidateStandardMessageType( + descriptor_pool)); + CEL_RETURN_IF_ERROR( + ValidateStandardMessageType(descriptor_pool)); + CEL_RETURN_IF_ERROR(ValidateStandardMessageType( + descriptor_pool)); + CEL_RETURN_IF_ERROR( + ValidateStandardMessageType( + descriptor_pool)); + CEL_RETURN_IF_ERROR( + ValidateStandardMessageType( + descriptor_pool)); + CEL_RETURN_IF_ERROR( + ValidateStandardMessageType(descriptor_pool)); + return absl::OkStatus(); +} + } // namespace internal } // namespace expr } // namespace api diff --git a/internal/proto_util.h b/internal/proto_util.h index 1549aba31..f82b00172 100644 --- a/internal/proto_util.h +++ b/internal/proto_util.h @@ -17,10 +17,12 @@ #include "google/protobuf/duration.pb.h" #include "google/protobuf/timestamp.pb.h" +#include "google/protobuf/descriptor.pb.h" #include "google/protobuf/util/message_differencer.h" #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_format.h" #include "absl/time/time.h" namespace google { @@ -54,6 +56,37 @@ absl::Duration DecodeDuration(const google::protobuf::Duration& proto); /** Helper function to decode a time from a google::protobuf::Timestamp. */ absl::Time DecodeTime(const google::protobuf::Timestamp& proto); +template +absl::Status ValidateStandardMessageType( + const google::protobuf::DescriptorPool& descriptor_pool) { + const google::protobuf::Descriptor* descriptor = MessageType::descriptor(); + const google::protobuf::Descriptor* descriptor_from_pool = + descriptor_pool.FindMessageTypeByName(descriptor->full_name()); + if (descriptor_from_pool == nullptr) { + return absl::NotFoundError( + absl::StrFormat("Descriptor '%s' not found in descriptor pool", + descriptor->full_name())); + } + if (descriptor_from_pool == descriptor) { + return absl::OkStatus(); + } + google::protobuf::DescriptorProto descriptor_proto; + google::protobuf::DescriptorProto descriptor_from_pool_proto; + descriptor->CopyTo(&descriptor_proto); + descriptor_from_pool->CopyTo(&descriptor_from_pool_proto); + if (!google::protobuf::util::MessageDifferencer::Equals(descriptor_proto, + descriptor_from_pool_proto)) { + return absl::FailedPreconditionError(absl::StrFormat( + "The descriptor for '%s' in the descriptor pool differs from the " + "compiled-in generated version", + descriptor->full_name())); + } + return absl::OkStatus(); +} + +absl::Status ValidateStandardMessageTypes( + const google::protobuf::DescriptorPool& descriptor_pool); + } // namespace internal } // namespace expr } // namespace api From afc86aa708ea6cc7944f5fbb5a335a0beb2762b3 Mon Sep 17 00:00:00 2001 From: timdn Date: Thu, 10 Mar 2022 16:20:20 +0000 Subject: [PATCH 023/111] Ignore json_name on DescriptorProto comparison for pools The json_name field of the FieldDescriptorProto may be set differently (or not at all) depending on the compiler used. We saw differences in just this field leading to the differencer detecting a difference, but the messages would still be compatible. Hence ignore this field. PiperOrigin-RevId: 433757941 --- internal/BUILD | 11 ++++ internal/proto_util.h | 13 ++++- internal/proto_util_test.cc | 113 ++++++++++++++++++++++++++++++++++++ 3 files changed, 135 insertions(+), 2 deletions(-) create mode 100644 internal/proto_util_test.cc diff --git a/internal/BUILD b/internal/BUILD index cda5eba3b..f92794e89 100644 --- a/internal/BUILD +++ b/internal/BUILD @@ -148,6 +148,17 @@ cc_library( ], ) +cc_test( + name = "proto_util_test", + srcs = ["proto_util_test.cc"], + deps = [ + ":proto_util", + ":testing", + "//eval/public/structs:cel_proto_descriptor_pool_builder", + "@com_google_protobuf//:protobuf", + ], +) + cc_library( name = "reference_counted", hdrs = ["reference_counted.h"], diff --git a/internal/proto_util.h b/internal/proto_util.h index f82b00172..386d1309a 100644 --- a/internal/proto_util.h +++ b/internal/proto_util.h @@ -74,8 +74,17 @@ absl::Status ValidateStandardMessageType( google::protobuf::DescriptorProto descriptor_from_pool_proto; descriptor->CopyTo(&descriptor_proto); descriptor_from_pool->CopyTo(&descriptor_from_pool_proto); - if (!google::protobuf::util::MessageDifferencer::Equals(descriptor_proto, - descriptor_from_pool_proto)) { + + google::protobuf::util::MessageDifferencer descriptor_differencer; + // The json_name is a compiler detail and does not change the message content. + // It can differ, e.g., between C++ and Go compilers. Hence ignore. + const google::protobuf::FieldDescriptor* json_name_field_desc = + google::protobuf::FieldDescriptorProto::descriptor()->FindFieldByName("json_name"); + if (json_name_field_desc != nullptr) { + descriptor_differencer.IgnoreField(json_name_field_desc); + } + if (!descriptor_differencer.Compare(descriptor_proto, + descriptor_from_pool_proto)) { return absl::FailedPreconditionError(absl::StrFormat( "The descriptor for '%s' in the descriptor pool differs from the " "compiled-in generated version", diff --git a/internal/proto_util_test.cc b/internal/proto_util_test.cc new file mode 100644 index 000000000..df913b48a --- /dev/null +++ b/internal/proto_util_test.cc @@ -0,0 +1,113 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/proto_util.h" + +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/descriptor.pb.h" +#include "google/protobuf/descriptor.h" +#include "eval/public/structs/cel_proto_descriptor_pool_builder.h" +#include "internal/testing.h" + +namespace cel::internal { +namespace { + +using google::api::expr::internal::ValidateStandardMessageType; +using google::api::expr::internal::ValidateStandardMessageTypes; +using google::api::expr::runtime::AddStandardMessageTypesToDescriptorPool; +using google::api::expr::runtime::GetStandardMessageTypesFileDescriptorSet; + +using testing::HasSubstr; +using cel::internal::StatusIs; + +TEST(ProtoUtil, ValidateStandardMessageTypesOk) { + google::protobuf::DescriptorPool descriptor_pool; + ASSERT_OK(AddStandardMessageTypesToDescriptorPool(descriptor_pool)); + EXPECT_OK(ValidateStandardMessageTypes(descriptor_pool)); +} + +TEST(ProtoUtil, ValidateStandardMessageTypesRejectsMissing) { + google::protobuf::DescriptorPool descriptor_pool; + EXPECT_THAT(ValidateStandardMessageTypes(descriptor_pool), + StatusIs(absl::StatusCode::kNotFound, + HasSubstr("not found in descriptor pool"))); +} + +TEST(ProtoUtil, ValidateStandardMessageTypesRejectsIncompatible) { + google::protobuf::DescriptorPool descriptor_pool; + google::protobuf::FileDescriptorSet standard_fds = + GetStandardMessageTypesFileDescriptorSet(); + + const google::protobuf::Descriptor* descriptor = + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.protobuf.Duration"); + ASSERT_NE(descriptor, nullptr); + google::protobuf::FileDescriptorProto file_descriptor_proto; + descriptor->file()->CopyTo(&file_descriptor_proto); + // We emulate a modification by external code that replaced the nanos by a + // millis field. + google::protobuf::FieldDescriptorProto seconds_desc_proto; + google::protobuf::FieldDescriptorProto nanos_desc_proto; + descriptor->FindFieldByName("seconds")->CopyTo(&seconds_desc_proto); + descriptor->FindFieldByName("nanos")->CopyTo(&nanos_desc_proto); + nanos_desc_proto.set_name("millis"); + file_descriptor_proto.mutable_message_type(0)->clear_field(); + *file_descriptor_proto.mutable_message_type(0)->add_field() = + seconds_desc_proto; + *file_descriptor_proto.mutable_message_type(0)->add_field() = + nanos_desc_proto; + + descriptor_pool.BuildFile(file_descriptor_proto); + + EXPECT_THAT( + ValidateStandardMessageType(descriptor_pool), + StatusIs(absl::StatusCode::kFailedPrecondition, HasSubstr("differs"))); +} + +TEST(ProtoUtil, ValidateStandardMessageTypesIgnoredJsonName) { + google::protobuf::DescriptorPool descriptor_pool; + google::protobuf::FileDescriptorSet standard_fds = + GetStandardMessageTypesFileDescriptorSet(); + bool modified = false; + // This nested loops are used to find the field descriptor proto to modify the + // json_name field of. + for (int i = 0; i < standard_fds.file_size(); ++i) { + if (standard_fds.file(i).name() == "google/protobuf/duration.proto") { + google::protobuf::FileDescriptorProto* fdp = standard_fds.mutable_file(i); + for (int j = 0; j < fdp->message_type_size(); ++j) { + if (fdp->message_type(j).name() == "Duration") { + google::protobuf::DescriptorProto* dp = fdp->mutable_message_type(j); + for (int k = 0; k < dp->field_size(); ++k) { + if (dp->field(k).name() == "seconds") { + // we need to set this to something we are reasonable sure of that + // it won't be set for real to make sure it is ignored + dp->mutable_field(k)->set_json_name("FOOBAR"); + modified = true; + } + } + } + } + } + } + ASSERT_TRUE(modified); + + for (int i = 0; i < standard_fds.file_size(); ++i) { + descriptor_pool.BuildFile(standard_fds.file(i)); + } + + EXPECT_OK(ValidateStandardMessageTypes(descriptor_pool)); +} + +} // namespace +} // namespace cel::internal From 9e21e139b71a0759ed943d6e18321a6634f46bf3 Mon Sep 17 00:00:00 2001 From: jcking Date: Thu, 10 Mar 2022 23:22:51 +0000 Subject: [PATCH 024/111] Internal change PiperOrigin-RevId: 433865216 --- base/BUILD | 54 ++ base/internal/BUILD | 11 + base/internal/type.post.h | 8 +- base/internal/value.h | 512 ------------------- base/internal/value.post.h | 554 ++++++++++++++++++++ base/internal/value.pre.h | 158 ++++++ base/value.cc | 950 ++++++++++++++++------------------- base/value.h | 613 +++++++++++++--------- base/value_factory.cc | 117 +++++ base/value_factory.h | 114 +++++ base/value_factory_test.cc | 38 ++ base/value_test.cc | 889 +++++++++++++++++++------------- internal/BUILD | 8 - internal/reference_counted.h | 99 ---- 14 files changed, 2377 insertions(+), 1748 deletions(-) delete mode 100644 base/internal/value.h create mode 100644 base/internal/value.post.h create mode 100644 base/internal/value.pre.h create mode 100644 base/value_factory.cc create mode 100644 base/value_factory.h create mode 100644 base/value_factory_test.cc delete mode 100644 internal/reference_counted.h diff --git a/base/BUILD b/base/BUILD index ccb2be2fc..516ec2f00 100644 --- a/base/BUILD +++ b/base/BUILD @@ -129,3 +129,57 @@ cc_test( "@com_google_absl//absl/hash:hash_testing", ], ) + +cc_library( + name = "value", + srcs = [ + "value.cc", + "value_factory.cc", + ], + hdrs = [ + "value.h", + "value_factory.h", + ], + deps = [ + ":handle", + ":kind", + ":memory_manager", + ":type", + "//base/internal:value", + "//internal:casts", + "//internal:no_destructor", + "//internal:status_macros", + "//internal:strings", + "//internal:time", + "@com_google_absl//absl/base", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:variant", + ], +) + +cc_test( + name = "value_test", + srcs = [ + "value_factory_test.cc", + "value_test.cc", + ], + deps = [ + ":memory_manager", + ":type", + ":value", + "//internal:strings", + "//internal:testing", + "//internal:time", + "@com_google_absl//absl/hash:hash_testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/time", + ], +) diff --git a/base/internal/BUILD b/base/internal/BUILD index 3715ed189..ce4b046d7 100644 --- a/base/internal/BUILD +++ b/base/internal/BUILD @@ -54,10 +54,21 @@ cc_library( "type.pre.h", "type.post.h", ], +) + +cc_library( + name = "value", + textual_hdrs = [ + "value.pre.h", + "value.post.h", + ], deps = [ "//base:handle", + "//internal:casts", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/hash", "@com_google_absl//absl/numeric:bits", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", ], ) diff --git a/base/internal/type.post.h b/base/internal/type.post.h index 956dd69a9..102c8dee2 100644 --- a/base/internal/type.post.h +++ b/base/internal/type.post.h @@ -32,10 +32,10 @@ namespace cel { namespace base_internal { -// Base implementation of persistent and transient handles. This contains -// implementation details shared among both, but is never used directly. The -// derived classes are responsible for defining appropriate constructors and -// assignments. +// Base implementation of persistent and transient handles for types. This +// contains implementation details shared among both, but is never used +// directly. The derived classes are responsible for defining appropriate +// constructors and assignments. class TypeHandleBase { public: constexpr TypeHandleBase() = default; diff --git a/base/internal/value.h b/base/internal/value.h deleted file mode 100644 index 81fdc0b03..000000000 --- a/base/internal/value.h +++ /dev/null @@ -1,512 +0,0 @@ -// Copyright 2022 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_VALUE_H_ -#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_VALUE_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/base/config.h" -#include "absl/base/macros.h" -#include "absl/base/optimization.h" -#include "absl/hash/hash.h" -#include "absl/numeric/bits.h" -#include "absl/status/status.h" -#include "absl/strings/string_view.h" -#include "base/internal/type.h" -#include "base/kind.h" -#include "internal/casts.h" -#include "internal/reference_counted.h" - -namespace cel { - -class Value; -class Bytes; - -namespace base_internal { - -// Abstract base class that all non-simple values are derived from. Users will -// not inherit from this directly but rather indirectly through exposed classes -// like cel::Struct. -class BaseValue : public cel::internal::ReferenceCounted { - public: - // Returns a human readable representation of this value. The representation - // is not guaranteed to be consistent across versions and should only be used - // for debugging purposes. - virtual std::string DebugString() const = 0; - - protected: - virtual bool Equals(const cel::Value& value) const = 0; - - virtual void HashValue(absl::HashState state) const = 0; - - private: - friend class cel::Value; - friend class cel::Bytes; - - BaseValue() = default; -}; - -// Type erased state capable of holding a pointer to remote storage or storing -// objects less than two pointers in size inline. -union ExternalDataReleaserState final { - void* remote; - alignas(alignof(std::max_align_t)) char local[sizeof(void*) * 2]; -}; - -// Function which deletes the object referenced by ExternalDataReleaserState. -using ExternalDataReleaserDeleter = void(ExternalDataReleaserState* state); - -template -void LocalExternalDataReleaserDeleter(ExternalDataReleaserState* state) { - reinterpret_cast(&state->local)->~Releaser(); -} - -template -void RemoteExternalDataReleaserDeleter(ExternalDataReleaserState* state) { - ::delete reinterpret_cast(state->remote); -} - -// Function which invokes the object referenced by ExternalDataReleaserState. -using ExternalDataReleaseInvoker = - void(ExternalDataReleaserState* state) noexcept; - -template -void LocalExternalDataReleaserInvoker( - ExternalDataReleaserState* state) noexcept { - (*reinterpret_cast(&state->local))(); -} - -template -void RemoteExternalDataReleaserInvoker( - ExternalDataReleaserState* state) noexcept { - (*reinterpret_cast(&state->remote))(); -} - -struct ExternalDataReleaser final { - ExternalDataReleaser() = delete; - - template - explicit ExternalDataReleaser(Releaser&& releaser) { - using DecayedReleaser = std::decay_t; - if constexpr (sizeof(DecayedReleaser) <= sizeof(void*) * 2 && - alignof(DecayedReleaser) <= alignof(std::max_align_t)) { - // Object meets size and alignment constraints, will be stored - // inline in ExternalDataReleaserState.local. - ::new (static_cast(&state.local)) - DecayedReleaser(std::forward(releaser)); - invoker = LocalExternalDataReleaserInvoker; - if constexpr (std::is_trivially_destructible_v) { - // Object is trivially destructable, no need to call destructor at all. - deleter = nullptr; - } else { - deleter = LocalExternalDataReleaserDeleter; - } - } else { - // Object does not meet size and alignment constraints, allocate on the - // heap and store pointer in ExternalDataReleaserState::remote. inline in - // ExternalDataReleaserState::local. - state.remote = ::new DecayedReleaser(std::forward(releaser)); - invoker = RemoteExternalDataReleaserInvoker; - deleter = RemoteExternalDataReleaserDeleter; - } - } - - ExternalDataReleaser(const ExternalDataReleaser&) = delete; - - ExternalDataReleaser(ExternalDataReleaser&&) = delete; - - ~ExternalDataReleaser() { - (*invoker)(&state); - if (deleter != nullptr) { - (*deleter)(&state); - } - } - - ExternalDataReleaser& operator=(const ExternalDataReleaser&) = delete; - - ExternalDataReleaser& operator=(ExternalDataReleaser&&) = delete; - - ExternalDataReleaserState state; - ExternalDataReleaserDeleter* deleter; - ExternalDataReleaseInvoker* invoker; -}; - -// Utility class encompassing a contiguous array of data which a function that -// must be called when the data is no longer needed. -struct ExternalData final { - ExternalData() = delete; - - ExternalData(const void* data, size_t size, - std::unique_ptr releaser) - : data(data), size(size), releaser(std::move(releaser)) {} - - ExternalData(const ExternalData&) = delete; - - ExternalData(ExternalData&&) noexcept = default; - - ExternalData& operator=(const ExternalData&) = delete; - - ExternalData& operator=(ExternalData&&) noexcept = default; - - const void* data; - size_t size; - std::unique_ptr releaser; -}; - -// Currently absl::Status has a size that is less than or equal to 8, however -// this could change at any time. Thus we delegate the lifetime management to -// BaseInlinedStatus which is always less than or equal to 8 bytes. -template -class BaseInlinedStatus; - -// Specialization for when the size of absl::Status is less than or equal to 8 -// bytes. -template <> -class BaseInlinedStatus final { - public: - BaseInlinedStatus() = default; - - BaseInlinedStatus(const BaseInlinedStatus&) = default; - - BaseInlinedStatus(BaseInlinedStatus&&) = default; - - explicit BaseInlinedStatus(const absl::Status& status) : status_(status) {} - - BaseInlinedStatus& operator=(const BaseInlinedStatus&) = default; - - BaseInlinedStatus& operator=(BaseInlinedStatus&&) = default; - - BaseInlinedStatus& operator=(const absl::Status& status) { - status_ = status; - return *this; - } - - const absl::Status& status() const { return status_; } - - private: - absl::Status status_; -}; - -// Specialization for when the size of absl::Status is greater than 8 bytes. As -// mentioned above, this template is never used today. It could in the future if -// the size of `absl::Status` ever changes. Without this specialization, our -// static asserts below would break and so would compiling CEL. -template <> -class BaseInlinedStatus final { - public: - BaseInlinedStatus() = default; - - BaseInlinedStatus(const BaseInlinedStatus&) = default; - - BaseInlinedStatus(BaseInlinedStatus&&) = default; - - explicit BaseInlinedStatus(const absl::Status& status) - : status_(std::make_shared(status)) {} - - BaseInlinedStatus& operator=(const BaseInlinedStatus&) = default; - - BaseInlinedStatus& operator=(BaseInlinedStatus&&) = default; - - BaseInlinedStatus& operator=(const absl::Status& status) { - if (status_) { - *status_ = status; - } else { - status_ = std::make_shared(status); - } - return *this; - } - - const absl::Status& status() const { - static const absl::Status* ok_status = new absl::Status(); - return status_ ? *status_ : *ok_status; - } - - private: - std::shared_ptr status_; -}; - -using InlinedStatus = BaseInlinedStatus<(sizeof(absl::Status) <= 8)>; - -// ValueMetadata is a specialized tagged union capable of storing either a -// pointer to a BaseType or a Kind. Only simple kinds are stored directly. -// Simple kinds can be converted into cel::Type using cel::Type::Simple. -// ValueMetadata is primarily used to interpret the contents of ValueContent. -// -// We assume that all pointers returned by `malloc()` are at minimum aligned to -// 4 bytes. In practice this assumption is pretty safe and all known -// implementations exhibit this behavior. -// -// The tagged union byte layout depends on the 0 bit. -// -// Bit 0 unset: -// -// -------------------------------- -// | 63 ... 2 | 1 | 0 | -// -------------------------------- -// | pointer | reserved | reffed | -// -------------------------------- -// -// Bit 0 set: -// -// --------------------------------------------------------------- -// | 63 ... 32 | 31 ... 16 | 15 ... 8 | 7 ... 1 | 0 | -// --------------------------------------------------------------- -// | extended_content | reserved | kind | reserved | simple | -// --------------------------------------------------------------- -// -// Q: Why not use absl::variant/std::variant? -// A: In theory, we could. However it would be repetative and inefficient. -// variant has a size equal to the largest of its memory types plus an -// additional field keeping track of the type that is active. For our purposes, -// the field that is active is kept track of by ValueMetadata and the storage in -// ValueContent. We know what is stored in ValueContent by the kind/type in -// ValueMetadata. Since we need to keep the type bundled with the Value, using -// variant would introduce two sources of truth for what is stored in -// ValueContent. If we chose the naive implementation, which would be to use -// Type instead of ValueMetadata and variant instead of ValueContent, each time -// we copy Value we would be guaranteed to incur a reference count causing a -// cache miss. This approach avoids that reference count for simple types. -// Additionally the size of Value would now be roughly 8 + 16 on 64-bit -// platforms. -// -// As with ValueContent, this class is only meant to be used by cel::Value. -class ValueMetadata final { - public: - constexpr ValueMetadata() : raw_(MakeDefault()) {} - - constexpr explicit ValueMetadata(Kind kind) : ValueMetadata(kind, 0) {} - - constexpr ValueMetadata(Kind kind, uint32_t extended_content) - : raw_(MakeSimple(kind, extended_content)) {} - - explicit ValueMetadata(const BaseType* base_type) - : ptr_(reinterpret_cast(base_type)) { - // Assert that the lower 2 bits are 0, a.k.a. at minimum 4 byte aligned. - ABSL_ASSERT(absl::countr_zero(reinterpret_cast(base_type)) >= 2); - } - - ValueMetadata(const ValueMetadata&) = delete; - - ValueMetadata(ValueMetadata&&) = delete; - - ValueMetadata& operator=(const ValueMetadata&) = delete; - - ValueMetadata& operator=(ValueMetadata&&) = delete; - - constexpr bool simple_tag() const { - return (lower_ & kSimpleTag) == kSimpleTag; - } - - constexpr uint32_t extended_content() const { - ABSL_ASSERT(simple_tag()); - return higher_; - } - - const BaseType* base_type() const { - ABSL_ASSERT(!simple_tag()); - return reinterpret_cast(ptr_ & kPtrMask); - } - - Kind kind() const { - return simple_tag() ? static_cast(lower_ >> 8) : base_type()->kind(); - } - - void Reset() { - if (!simple_tag()) { - internal::Unref(base_type()); - } - raw_ = MakeDefault(); - } - - void CopyFrom(const ValueMetadata& other) { - if (ABSL_PREDICT_TRUE(this != std::addressof(other))) { - if (!other.simple_tag()) { - internal::Ref(other.base_type()); - } - if (!simple_tag()) { - internal::Unref(base_type()); - } - raw_ = other.raw_; - } - } - - void MoveFrom(ValueMetadata&& other) { - if (ABSL_PREDICT_TRUE(this != std::addressof(other))) { - if (!simple_tag()) { - internal::Unref(base_type()); - } - raw_ = other.raw_; - other.raw_ = MakeDefault(); - } - } - - private: - static constexpr uint64_t MakeSimple(Kind kind, uint32_t extended_content) { - return static_cast(kSimpleTag | - (static_cast(kind) << 8)) | - (static_cast(extended_content) << 32); - } - - static constexpr uint64_t MakeDefault() { - return MakeSimple(Kind::kNullType, 0); - } - - static constexpr uint32_t kNoTag = 0; - static constexpr uint32_t kSimpleTag = - 1 << 0; // Indicates the kind is simple and there is no BaseType* held. - static constexpr uint32_t kReservedTag = 1 << 1; - static constexpr uintptr_t kPtrMask = - ~static_cast(kSimpleTag | kReservedTag); - - union { - uint64_t raw_; - -#if defined(ABSL_IS_LITTLE_ENDIAN) - struct { - uint32_t lower_; - uint32_t higher_; - }; -#elif defined(ABSL_IS_BIG_ENDIAN) - struct { - uint32_t higher_; - uint32_t lower_; - }; -#else -#error "Platform is neither big endian nor little endian" -#endif - - uintptr_t ptr_; - }; -}; - -static_assert(sizeof(ValueMetadata) == 8, - "Expected sizeof(ValueMetadata) to be 8"); - -// ValueContent is an untagged union whose contents are determined by the -// accompanying ValueMetadata. -// -// As with ValueMetadata, this class is only meant to be used by cel::Value. -class ValueContent final { - public: - constexpr ValueContent() : raw_(0) {} - - constexpr explicit ValueContent(bool value) : bool_value_(value) {} - - constexpr explicit ValueContent(int64_t value) : int_value_(value) {} - - constexpr explicit ValueContent(uint64_t value) : uint_value_(value) {} - - constexpr explicit ValueContent(double value) : double_value_(value) {} - - explicit ValueContent(const absl::Status& status) { - construct_error_value(status); - } - - constexpr explicit ValueContent(BaseValue* base_value) - : base_value_(base_value) {} - - ValueContent(const ValueContent&) = delete; - - ValueContent(ValueContent&&) = delete; - - ~ValueContent() {} - - ValueContent& operator=(const ValueContent&) = delete; - - ValueContent& operator=(ValueContent&&) = delete; - - constexpr bool bool_value() const { return bool_value_; } - - constexpr int64_t int_value() const { return int_value_; } - - constexpr uint64_t uint_value() const { return uint_value_; } - - constexpr double double_value() const { return double_value_; } - - constexpr void construct_trivial_value(uint64_t value) { raw_ = value; } - - constexpr void destruct_trivial_value() { raw_ = 0; } - - constexpr uint64_t trivial_value() const { return raw_; } - - // Updates this to hold `value`, incrementing the reference count. This is - // used during copies. - void construct_reffed_value(BaseValue* value) { - base_value_ = cel::internal::Ref(value); - } - - // Updates this to hold `value` without incrementing the reference count. This - // is used during moves. - void adopt_reffed_value(BaseValue* value) { base_value_ = value; } - - // Decrement the reference count of the currently held reffed value and clear - // this. - void destruct_reffed_value() { - cel::internal::Unref(base_value_); - base_value_ = nullptr; - } - - // Return the currently held reffed value and reset this, without decrementing - // the reference count. This is used during moves. - BaseValue* release_reffed_value() { - BaseValue* reffed_value = base_value_; - base_value_ = nullptr; - return reffed_value; - } - - constexpr BaseValue* reffed_value() const { return base_value_; } - - void construct_error_value(const absl::Status& status) { - ::new (static_cast(std::addressof(error_value_))) - InlinedStatus(status); - } - - void assign_error_value(const absl::Status& status) { error_value_ = status; } - - void destruct_error_value() { - std::addressof(error_value_)->~InlinedStatus(); - } - - constexpr const absl::Status& error_value() const { - return error_value_.status(); - } - - private: - union { - uint64_t raw_; - - bool bool_value_; - int64_t int_value_; - uint64_t uint_value_; - double double_value_; - InlinedStatus error_value_; - BaseValue* base_value_; - }; -}; - -static_assert(sizeof(ValueContent) == 8, - "Expected sizeof(ValueContent) to be 8"); - -} // namespace base_internal - -} // namespace cel - -#endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_VALUE_H_ diff --git a/base/internal/value.post.h b/base/internal/value.post.h new file mode 100644 index 000000000..75ab7f7f1 --- /dev/null +++ b/base/internal/value.post.h @@ -0,0 +1,554 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "base/value.h" + +#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_VALUE_POST_H_ +#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_VALUE_POST_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/base/optimization.h" +#include "absl/hash/hash.h" +#include "absl/numeric/bits.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "base/handle.h" +#include "internal/casts.h" + +namespace cel { + +namespace base_internal { + +// Implementation of BytesValue that is stored inlined within a handle. Since +// absl::Cord is reference counted itself, this is more efficient than storing +// this on the heap. +class InlinedCordBytesValue final : public BytesValue, public ResourceInlined { + private: + template + friend class ValueHandle; + + explicit InlinedCordBytesValue(absl::Cord value) : value_(std::move(value)) {} + + InlinedCordBytesValue() = delete; + + InlinedCordBytesValue(const InlinedCordBytesValue&) = default; + InlinedCordBytesValue(InlinedCordBytesValue&&) = default; + + // See comments for respective member functions on `ByteValue` and `Value`. + void CopyTo(Value& address) const override; + void MoveTo(Value& address) override; + absl::Cord ToCord(bool reference_counted) const override; + Rep rep() const override; + + absl::Cord value_; +}; + +// Implementation of BytesValue that is stored inlined within a handle. This +// class is inheritently unsafe and care should be taken when using it. +// Typically this should only be used for empty strings or data that is static +// and lives for the duration of a program. +class InlinedStringViewBytesValue final : public BytesValue, + public ResourceInlined { + private: + template + friend class ValueHandle; + + explicit InlinedStringViewBytesValue(absl::string_view value) + : value_(value) {} + + InlinedStringViewBytesValue() = delete; + + InlinedStringViewBytesValue(const InlinedStringViewBytesValue&) = default; + InlinedStringViewBytesValue(InlinedStringViewBytesValue&&) = default; + + // See comments for respective member functions on `ByteValue` and `Value`. + void CopyTo(Value& address) const override; + void MoveTo(Value& address) override; + absl::Cord ToCord(bool reference_counted) const override; + Rep rep() const override; + + absl::string_view value_; +}; + +// Implementation of BytesValue that uses std::string and is allocated on the +// heap, potentially reference counted. +class StringBytesValue final : public BytesValue { + private: + friend class cel::MemoryManager; + + explicit StringBytesValue(std::string value) : value_(std::move(value)) {} + + StringBytesValue() = delete; + StringBytesValue(const StringBytesValue&) = delete; + StringBytesValue(StringBytesValue&&) = delete; + + // See comments for respective member functions on `ByteValue` and `Value`. + std::pair SizeAndAlignment() const override; + absl::Cord ToCord(bool reference_counted) const override; + Rep rep() const override; + + std::string value_; +}; + +// Implementation of BytesValue that wraps a contiguous array of bytes and calls +// the releaser when it is no longer needed. It is stored on the heap and +// potentially reference counted. +class ExternalDataBytesValue final : public BytesValue { + private: + friend class cel::MemoryManager; + + explicit ExternalDataBytesValue(ExternalData value) + : value_(std::move(value)) {} + + ExternalDataBytesValue() = delete; + ExternalDataBytesValue(const ExternalDataBytesValue&) = delete; + ExternalDataBytesValue(ExternalDataBytesValue&&) = delete; + + // See comments for respective member functions on `ByteValue` and `Value`. + std::pair SizeAndAlignment() const override; + absl::Cord ToCord(bool reference_counted) const override; + Rep rep() const override; + + ExternalData value_; +}; + +// Class used to assert the object memory layout for vptr at compile time, +// otherwise it is unused. +struct ABSL_ATTRIBUTE_UNUSED CheckVptrOffsetBase { + virtual ~CheckVptrOffsetBase() = default; + + virtual void Member() const {} +}; + +// Class used to assert the object memory layout for vptr at compile time, +// otherwise it is unused. +struct ABSL_ATTRIBUTE_UNUSED CheckVptrOffset final + : public CheckVptrOffsetBase { + uintptr_t member; +}; + +// Ensure the hidden vptr is stored at the beginning of the object. See +// ValueHandleData for more information. +static_assert(offsetof(CheckVptrOffset, member) == sizeof(void*), + "CEL C++ requires a compiler that stores the vptr as a hidden " + "member at the beginning of the object. If this static_assert " + "fails, please reach out to the CEL team."); + +// Union of all known inlinable values. +union ValueHandleData final { + // As asserted above, we rely on the fact that the compiler stores the vptr as + // a hidden member at the beginning of the object. We then re-use the first 2 + // bits to differentiate between an inlined value (both 0), a heap allocated + // reference counted value, or a arena allocated value. + void* vptr; + std::aligned_union_t + padding; +}; + +// Base implementation of persistent and transient handles for values. This +// contains implementation details shared among both, but is never used +// directly. The derived classes are responsible for defining appropriate +// constructors and assignments. +class ValueHandleBase { + public: + ValueHandleBase() { Reset(); } + + // Used by derived classes to bypass default construction to perform their own + // construction. + explicit ValueHandleBase(HandleInPlace) {} + + // Called by `Transient` and `Persistent` to implement the same operator. They + // will handle enforcing const correctness. + Value& operator*() const { return get(); } + + // Called by `Transient` and `Persistent` to implement the same operator. They + // will handle enforcing const correctness. + Value* operator->() const { return std::addressof(get()); } + + // Called by internal accessors `base_internal::IsXHandle`. + constexpr bool IsManaged() const { + return (vptr() & kValueHandleManaged) != 0; + } + + // Called by internal accessors `base_internal::IsXHandle`. + constexpr bool IsUnmanaged() const { + return (vptr() & kValueHandleUnmanaged) != 0; + } + + // Called by internal accessors `base_internal::IsXHandle`. + constexpr bool IsInlined() const { return (vptr() & kValueHandleBits) == 0; } + + // Called by `Transient` and `Persistent` to implement the same function. + template + bool Is() const { + // Tests that this is not an empty handle and then dereferences the handle + // calling the RTTI-like implementation T::Is which takes `const Value&`. + return static_cast(*this) && T::Is(static_cast(**this)); + } + + // Called by `Transient` and `Persistent` to implement the same operator. + explicit operator bool() const { return (vptr() & kValueHandleMask) != 0; } + + // Called by `Transient` and `Persistent` to implement the same operator. + friend bool operator==(const ValueHandleBase& lhs, + const ValueHandleBase& rhs) { + const Value& lhs_value = ABSL_PREDICT_TRUE(static_cast(lhs)) + ? lhs.get() + : static_cast(NullValue::Get()); + const Value& rhs_value = ABSL_PREDICT_TRUE(static_cast(rhs)) + ? rhs.get() + : static_cast(NullValue::Get()); + return lhs_value.Equals(rhs_value); + } + + // Called by `Transient` and `Persistent` to implement std::swap. + friend void swap(ValueHandleBase& lhs, ValueHandleBase& rhs) { + if (lhs.empty_or_not_inlined() && rhs.empty_or_not_inlined()) { + // Both `lhs` and `rhs` are simple pointers. Just swap them. + std::swap(lhs.data_.vptr, rhs.data_.vptr); + return; + } + ValueHandleBase tmp; + Move(lhs, tmp); + Move(rhs, lhs); + Move(tmp, rhs); + } + + template + friend H AbslHashValue(H state, const ValueHandleBase& handle) { + if (ABSL_PREDICT_TRUE(static_cast(handle))) { + handle.get().HashValue(absl::HashState::Create(&state)); + } else { + NullValue::Get().HashValue(absl::HashState::Create(&state)); + } + return state; + } + + private: + template + friend class ValueHandle; + + // Resets the state to the same as the default constructor. Does not perform + // any destruction of existing content. + void Reset() { data_.vptr = reinterpret_cast(kValueHandleUnmanaged); } + + void Unref() const { + ABSL_ASSERT(reffed()); + reinterpret_cast(vptr() & kValueHandleMask)->Unref(); + } + + void Ref() const { + ABSL_ASSERT(reffed()); + reinterpret_cast(vptr() & kValueHandleMask)->Ref(); + } + + Value& get() const { + return *(inlined() + ? reinterpret_cast(const_cast(&data_.vptr)) + : reinterpret_cast(vptr() & kValueHandleMask)); + } + + bool empty() const { return !static_cast(*this); } + + // Does the stored data represent an inlined value? + bool inlined() const { return (vptr() & kValueHandleBits) == 0; } + + // Does the stored data represent a non-null inlined value? + bool not_empty_and_inlined() const { + return (vptr() & kValueHandleBits) == 0 && (vptr() & kValueHandleMask) != 0; + } + + // Does the stored data represent null, heap allocated reference counted, or + // arena allocated value? + bool empty_or_not_inlined() const { + return (vptr() & kValueHandleBits) != 0 || (vptr() & kValueHandleMask) == 0; + } + + // Does the stored data required reference counting? + bool reffed() const { return (vptr() & kValueHandleManaged) != 0; } + + uintptr_t vptr() const { return reinterpret_cast(data_.vptr); } + + static void Copy(const ValueHandleBase& from, ValueHandleBase& to) { + if (from.empty_or_not_inlined()) { + // `from` is a simple pointer, just copy it. + to.data_.vptr = from.data_.vptr; + } else { + from.get().CopyTo(*reinterpret_cast(&to.data_.vptr)); + } + } + + static void Move(ValueHandleBase& from, ValueHandleBase& to) { + if (from.empty_or_not_inlined()) { + // `from` is a simple pointer, just swap it. + std::swap(from.data_.vptr, to.data_.vptr); + } else { + from.get().MoveTo(*reinterpret_cast(&to.data_.vptr)); + DestructInlined(from); + } + } + + static void DestructInlined(ValueHandleBase& handle) { + ABSL_ASSERT(!handle.empty_or_not_inlined()); + handle.get().~Value(); + handle.Reset(); + } + + ValueHandleData data_; +}; + +// All methods are called by `Transient`. Unlike `Persistent`, reference +// counting is not performed as `Transient` is a non-owning handle. +template <> +class ValueHandle final : public ValueHandleBase { + private: + using Base = ValueHandleBase; + + public: + ValueHandle() = default; + + template + explicit ValueHandle(InlinedResource, Args&&... args) + : ValueHandleBase(kHandleInPlace) { + static_assert(sizeof(T) <= sizeof(data_.padding), + "T cannot be inlined in Handle"); + static_assert(alignof(T) <= alignof(data_.padding), + "T cannot be inlined in Handle"); + // Same as std::construct_at from C++20. + ::new (const_cast(static_cast(&data_.padding))) + T(std::forward(args)...); + ABSL_ASSERT(absl::countr_zero(vptr()) >= + 2); // Verify the lower 2 bits are available. + } + + template + ValueHandle(UnmanagedResource, F& from) : ValueHandleBase(kHandleInPlace) { + uintptr_t vptr = reinterpret_cast( + static_cast(static_cast(std::addressof(from)))); + ABSL_ASSERT(absl::countr_zero(vptr) >= + 2); // Verify the lower 2 bits are available. + data_.vptr = reinterpret_cast(vptr | kValueHandleUnmanaged); + } + + ValueHandle(const TransientValueHandle& other) : ValueHandle() { + Base::Copy(other, *this); + } + + ValueHandle(TransientValueHandle&& other) : ValueHandle() { + Base::Move(other, *this); + } + + explicit ValueHandle(const PersistentValueHandle& other); + + ~ValueHandle() { + if (not_empty_and_inlined()) { + DestructInlined(*this); + } + } + + ValueHandle& operator=(const TransientValueHandle& other) { + if (not_empty_and_inlined()) { + DestructInlined(*this); + } + Base::Copy(other, *this); + return *this; + } + + ValueHandle& operator=(TransientValueHandle&& other) { + if (not_empty_and_inlined()) { + DestructInlined(*this); + } + Base::Move(other, *this); + return *this; + } + + ValueHandle& operator=(const PersistentValueHandle& other); +}; + +// All methods are called by `Persistent`. +template <> +class ValueHandle final : public ValueHandleBase { + private: + using Base = ValueHandleBase; + + public: + ValueHandle() = default; + + template + explicit ValueHandle(InlinedResource, Args&&... args) + : ValueHandleBase(kHandleInPlace) { + static_assert(sizeof(T) <= sizeof(data_.padding), + "T cannot be inlined in Handle"); + static_assert(alignof(T) <= alignof(data_.padding), + "T cannot be inlined in Handle"); + ::new (const_cast(static_cast(&data_.padding))) + T(std::forward(args)...); + ABSL_ASSERT(absl::countr_zero(vptr()) >= + 2); // Verify the lower 2 bits are available. + } + + template + ValueHandle(UnmanagedResource, F& from) : ValueHandleBase(kHandleInPlace) { + uintptr_t vptr = reinterpret_cast( + static_cast(static_cast(std::addressof(from)))); + ABSL_ASSERT(absl::countr_zero(vptr) >= + 2); // Verify the lower 2 bits are available. + data_.vptr = reinterpret_cast(vptr | kValueHandleUnmanaged); + } + + template + ValueHandle(ManagedResource, F& from) : ValueHandleBase(kHandleInPlace) { + uintptr_t vptr = reinterpret_cast( + static_cast(static_cast(std::addressof(from)))); + ABSL_ASSERT(absl::countr_zero(vptr) >= + 2); // Verify the lower 2 bits are available. + data_.vptr = reinterpret_cast(vptr | kValueHandleManaged); + } + + ValueHandle(const PersistentValueHandle& other) : ValueHandle() { + Base::Copy(other, *this); + if (reffed()) { + Ref(); + } + } + + ValueHandle(PersistentValueHandle&& other) : ValueHandle() { + Base::Move(other, *this); + } + + explicit ValueHandle(const TransientValueHandle& other) { + Base::Copy(other, *this); + if (reffed()) { + Ref(); + } + } + + ~ValueHandle() { + if (not_empty_and_inlined()) { + DestructInlined(*this); + } else if (reffed()) { + Unref(); + } + } + + ValueHandle& operator=(const PersistentValueHandle& other) { + if (not_empty_and_inlined()) { + DestructInlined(*this); + } else if (reffed()) { + Unref(); + } + Base::Copy(other, *this); + if (reffed()) { + Ref(); + } + return *this; + } + + ValueHandle& operator=(PersistentValueHandle&& other) { + if (not_empty_and_inlined()) { + DestructInlined(*this); + } else if (reffed()) { + Unref(); + } + Base::Move(other, *this); + return *this; + } + + ValueHandle& operator=(const TransientValueHandle& other) { + if (not_empty_and_inlined()) { + DestructInlined(*this); + } else if (reffed()) { + Unref(); + } + Base::Copy(other, *this); + if (reffed()) { + Ref(); + } + return *this; + } +}; + +inline ValueHandle::ValueHandle( + const PersistentValueHandle& other) + : ValueHandle() { + Base::Copy(other, *this); +} + +inline ValueHandle& ValueHandle< + HandleType::kTransient>::operator=(const PersistentValueHandle& other) { + if (not_empty_and_inlined()) { + DestructInlined(*this); + } + Base::Copy(other, *this); + return *this; +} + +// Specialization for Value providing the implementation to `Transient`. +template <> +struct HandleTraits { + using handle_type = ValueHandle; +}; + +// Partial specialization for `Transient` for all classes derived from Value. +template +struct HandleTraits && + !std::is_same_v)>> + final : public HandleTraits {}; + +// Specialization for Value providing the implementation to `Persistent`. +template <> +struct HandleTraits { + using handle_type = ValueHandle; +}; + +// Partial specialization for `Persistent` for all classes derived from Value. +template +struct HandleTraits && + !std::is_same_v)>> + final : public HandleTraits {}; + +} // namespace base_internal + +#define CEL_INTERNAL_VALUE_DECL(name) \ + extern template class Transient; \ + extern template class Transient; \ + extern template class Persistent; \ + extern template class Persistent +CEL_INTERNAL_VALUE_DECL(Value); +CEL_INTERNAL_VALUE_DECL(NullValue); +CEL_INTERNAL_VALUE_DECL(ErrorValue); +CEL_INTERNAL_VALUE_DECL(BoolValue); +CEL_INTERNAL_VALUE_DECL(IntValue); +CEL_INTERNAL_VALUE_DECL(UintValue); +CEL_INTERNAL_VALUE_DECL(DoubleValue); +CEL_INTERNAL_VALUE_DECL(BytesValue); +CEL_INTERNAL_VALUE_DECL(DurationValue); +CEL_INTERNAL_VALUE_DECL(TimestampValue); +#undef CEL_INTERNAL_VALUE_DECL + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_VALUE_POST_H_ diff --git a/base/internal/value.pre.h b/base/internal/value.pre.h new file mode 100644 index 000000000..19ac1bca3 --- /dev/null +++ b/base/internal/value.pre.h @@ -0,0 +1,158 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "base/value.h" + +#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_VALUE_PRE_H_ +#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_VALUE_PRE_H_ + +#include +#include +#include + +#include "base/handle.h" + +namespace cel::base_internal { + +class ValueHandleBase; +template +class ValueHandle; + +// Convenient aliases. +using TransientValueHandle = ValueHandle; +using PersistentValueHandle = ValueHandle; + +// As all objects should be aligned to at least 4 bytes, we can use the lower +// two bits for our own purposes. +inline constexpr uintptr_t kValueHandleManaged = 1 << 0; +inline constexpr uintptr_t kValueHandleUnmanaged = 1 << 1; +inline constexpr uintptr_t kValueHandleBits = + kValueHandleManaged | kValueHandleUnmanaged; +inline constexpr uintptr_t kValueHandleMask = ~kValueHandleBits; + +class InlinedCordBytesValue; +class InlinedStringViewBytesValue; +class StringBytesValue; +class ExternalDataBytesValue; + +// Type erased state capable of holding a pointer to remote storage or storing +// objects less than two pointers in size inline. +union ExternalDataReleaserState final { + void* remote; + alignas(alignof(std::max_align_t)) char local[sizeof(void*) * 2]; +}; + +// Function which deletes the object referenced by ExternalDataReleaserState. +using ExternalDataReleaserDeleter = void(ExternalDataReleaserState* state); + +template +void LocalExternalDataReleaserDeleter(ExternalDataReleaserState* state) { + reinterpret_cast(&state->local)->~Releaser(); +} + +template +void RemoteExternalDataReleaserDeleter(ExternalDataReleaserState* state) { + ::delete reinterpret_cast(state->remote); +} + +// Function which invokes the object referenced by ExternalDataReleaserState. +using ExternalDataReleaseInvoker = + void(ExternalDataReleaserState* state) noexcept; + +template +void LocalExternalDataReleaserInvoker( + ExternalDataReleaserState* state) noexcept { + (*reinterpret_cast(&state->local))(); +} + +template +void RemoteExternalDataReleaserInvoker( + ExternalDataReleaserState* state) noexcept { + (*reinterpret_cast(&state->remote))(); +} + +struct ExternalDataReleaser final { + ExternalDataReleaser() = delete; + + template + explicit ExternalDataReleaser(Releaser&& releaser) { + using DecayedReleaser = std::decay_t; + if constexpr (sizeof(DecayedReleaser) <= sizeof(void*) * 2 && + alignof(DecayedReleaser) <= alignof(std::max_align_t)) { + // Object meets size and alignment constraints, will be stored + // inline in ExternalDataReleaserState.local. + ::new (static_cast(&state.local)) + DecayedReleaser(std::forward(releaser)); + invoker = LocalExternalDataReleaserInvoker; + if constexpr (std::is_trivially_destructible_v) { + // Object is trivially destructable, no need to call destructor at all. + deleter = nullptr; + } else { + deleter = LocalExternalDataReleaserDeleter; + } + } else { + // Object does not meet size and alignment constraints, allocate on the + // heap and store pointer in ExternalDataReleaserState::remote. inline in + // ExternalDataReleaserState::local. + state.remote = ::new DecayedReleaser(std::forward(releaser)); + invoker = RemoteExternalDataReleaserInvoker; + deleter = RemoteExternalDataReleaserDeleter; + } + } + + ExternalDataReleaser(const ExternalDataReleaser&) = delete; + + ExternalDataReleaser(ExternalDataReleaser&&) = delete; + + ~ExternalDataReleaser() { + (*invoker)(&state); + if (deleter != nullptr) { + (*deleter)(&state); + } + } + + ExternalDataReleaser& operator=(const ExternalDataReleaser&) = delete; + + ExternalDataReleaser& operator=(ExternalDataReleaser&&) = delete; + + ExternalDataReleaserState state; + ExternalDataReleaserDeleter* deleter; + ExternalDataReleaseInvoker* invoker; +}; + +// Utility class encompassing a contiguous array of data which a function that +// must be called when the data is no longer needed. +struct ExternalData final { + ExternalData() = delete; + + ExternalData(const void* data, size_t size, + std::unique_ptr releaser) + : data(data), size(size), releaser(std::move(releaser)) {} + + ExternalData(const ExternalData&) = delete; + + ExternalData(ExternalData&&) noexcept = default; + + ExternalData& operator=(const ExternalData&) = delete; + + ExternalData& operator=(ExternalData&&) noexcept = default; + + const void* data; + size_t size; + std::unique_ptr releaser; +}; + +} // namespace cel::base_internal + +#endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_VALUE_PRE_H_ diff --git a/base/value.cc b/base/value.cc index e28e20c81..396f87271 100644 --- a/base/value.cc +++ b/base/value.cc @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -30,15 +31,92 @@ #include "absl/status/status.h" #include "absl/strings/cord.h" #include "absl/strings/match.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" -#include "base/internal/value.h" -#include "internal/reference_counted.h" +#include "absl/time/time.h" +#include "base/value_factory.h" +#include "internal/casts.h" +#include "internal/no_destructor.h" #include "internal/status_macros.h" #include "internal/strings.h" #include "internal/time.h" namespace cel { +#define CEL_INTERNAL_VALUE_IMPL(name) \ + template class Transient; \ + template class Transient; \ + template class Persistent; \ + template class Persistent +CEL_INTERNAL_VALUE_IMPL(Value); +CEL_INTERNAL_VALUE_IMPL(NullValue); +CEL_INTERNAL_VALUE_IMPL(ErrorValue); +CEL_INTERNAL_VALUE_IMPL(BoolValue); +CEL_INTERNAL_VALUE_IMPL(IntValue); +CEL_INTERNAL_VALUE_IMPL(UintValue); +CEL_INTERNAL_VALUE_IMPL(DoubleValue); +CEL_INTERNAL_VALUE_IMPL(BytesValue); +CEL_INTERNAL_VALUE_IMPL(DurationValue); +CEL_INTERNAL_VALUE_IMPL(TimestampValue); +#undef CEL_INTERNAL_VALUE_IMPL + +namespace { + +using base_internal::TransientHandleFactory; + +// Both are equivalent to std::construct_at implementation from C++20. +#define CEL_COPY_TO_IMPL(type, src, dest) \ + ::new (const_cast( \ + static_cast(std::addressof(dest)))) type(src) +#define CEL_MOVE_TO_IMPL(type, src, dest) \ + ::new (const_cast(static_cast( \ + std::addressof(dest)))) type(std::move(src)) + +} // namespace + +std::pair Value::SizeAndAlignment() const { + // Currently most implementations of Value are not reference counted, so those + // that are override this and those that do not inherit this. Using 0 here + // will trigger runtime asserts in case of undefined behavior. + return std::pair(0, 0); +} + +void Value::CopyTo(Value& address) const {} + +void Value::MoveTo(Value& address) {} + +Persistent NullValue::Get(ValueFactory& value_factory) { + return value_factory.GetNullValue(); +} + +Transient NullValue::type() const { + return TransientHandleFactory::MakeUnmanaged( + NullType::Get()); +} + +std::string NullValue::DebugString() const { return "null"; } + +const NullValue& NullValue::Get() { + static const internal::NoDestructor instance; + return *instance; +} + +void NullValue::CopyTo(Value& address) const { + CEL_COPY_TO_IMPL(NullValue, *this, address); +} + +void NullValue::MoveTo(Value& address) { + CEL_MOVE_TO_IMPL(NullValue, *this, address); +} + +bool NullValue::Equals(const Value& other) const { + return kind() == other.kind(); +} + +void NullValue::HashValue(absl::HashState state) const { + absl::HashState::combine(std::move(state), type(), 0); +} + namespace { struct StatusPayload final { @@ -72,441 +150,250 @@ void StatusHashValue(absl::HashState state, const absl::Status& status) { } } -// SimpleValues holds common values that are frequently needed and should not be -// constructed everytime they are required, usually because they would require a -// heap allocation. An example of this is an empty byte string. -struct SimpleValues final { - public: - SimpleValues() = default; +} // namespace - SimpleValues(const SimpleValues&) = delete; +Transient ErrorValue::type() const { + return TransientHandleFactory::MakeUnmanaged( + ErrorType::Get()); +} - SimpleValues(SimpleValues&&) = delete; +std::string ErrorValue::DebugString() const { return value().ToString(); } - SimpleValues& operator=(const SimpleValues&) = delete; +void ErrorValue::CopyTo(Value& address) const { + CEL_COPY_TO_IMPL(ErrorValue, *this, address); +} - SimpleValues& operator=(SimpleValues&&) = delete; +void ErrorValue::MoveTo(Value& address) { + CEL_MOVE_TO_IMPL(ErrorValue, *this, address); +} - Value empty_bytes; -}; +bool ErrorValue::Equals(const Value& other) const { + return kind() == other.kind() && + value() == internal::down_cast(other).value(); +} + +void ErrorValue::HashValue(absl::HashState state) const { + StatusHashValue(absl::HashState::combine(std::move(state), type()), value()); +} -ABSL_CONST_INIT absl::once_flag simple_values_once; -ABSL_CONST_INIT SimpleValues* simple_values = nullptr; +Persistent BoolValue::False(ValueFactory& value_factory) { + return value_factory.CreateBoolValue(false); +} -} // namespace +Persistent BoolValue::True(ValueFactory& value_factory) { + return value_factory.CreateBoolValue(true); +} -Value Value::Error(const absl::Status& status) { - ABSL_ASSERT(!status.ok()); - if (ABSL_PREDICT_FALSE(status.ok())) { - return Value(absl::UnknownError( - "If you are seeing this message the caller attempted to construct an " - "error value from a successful status. Refusing to fail " - "successfully.")); - } - return Value(status); -} - -absl::StatusOr Value::Duration(absl::Duration value) { - CEL_RETURN_IF_ERROR(internal::ValidateDuration(value)); - int64_t seconds = absl::IDivDuration(value, absl::Seconds(1), &value); - int64_t nanoseconds = absl::IDivDuration(value, absl::Nanoseconds(1), &value); - return Value(Kind::kDuration, seconds, - absl::bit_cast(static_cast(nanoseconds))); -} - -absl::StatusOr Value::Timestamp(absl::Time value) { - CEL_RETURN_IF_ERROR(internal::ValidateTimestamp(value)); - absl::Duration duration = value - absl::UnixEpoch(); - int64_t seconds = absl::IDivDuration(duration, absl::Seconds(1), &duration); - int64_t nanoseconds = - absl::IDivDuration(duration, absl::Nanoseconds(1), &duration); - return Value(Kind::kTimestamp, seconds, - absl::bit_cast(static_cast(nanoseconds))); -} - -Value::Value(const Value& other) { - // metadata_ is currently equal to the simple null type. - // content_ is zero initialized. - switch (other.kind()) { - case Kind::kNullType: - // `this` is already the null value, do nothing. - return; - case Kind::kBool: - ABSL_FALLTHROUGH_INTENDED; - case Kind::kInt: - ABSL_FALLTHROUGH_INTENDED; - case Kind::kUint: - ABSL_FALLTHROUGH_INTENDED; - case Kind::kDouble: - ABSL_FALLTHROUGH_INTENDED; - case Kind::kDuration: - ABSL_FALLTHROUGH_INTENDED; - case Kind::kTimestamp: - // `other` is a simple value and simple type. We only need to trivially - // copy metadata_ and content_. - metadata_.CopyFrom(other.metadata_); - content_.construct_trivial_value(other.content_.trivial_value()); - return; - case Kind::kError: - // `other` is an error value and a simple type. We need to trivially copy - // metadata_ and copy construct the error value to content_. - metadata_.CopyFrom(other.metadata_); - content_.construct_error_value(other.content_.error_value()); - return; - case Kind::kBytes: - // `other` is a reffed value and a simple type. We need to trivially copy - // metadata_ and copy construct the reffed value to content_. - metadata_.CopyFrom(other.metadata_); - content_.construct_reffed_value(other.content_.reffed_value()); - return; - default: - // TODO(issues/5): remove after implementing other kinds - std::abort(); - } +Transient BoolValue::type() const { + return TransientHandleFactory::MakeUnmanaged( + BoolType::Get()); } -Value::Value(Value&& other) { - // metadata_ is currently equal to the simple null type. - // content_ is currently zero initialized. - switch (other.kind()) { - case Kind::kNullType: - // `this` and `other` are already the null value, do nothing. - return; - case Kind::kBool: - ABSL_FALLTHROUGH_INTENDED; - case Kind::kInt: - ABSL_FALLTHROUGH_INTENDED; - case Kind::kUint: - ABSL_FALLTHROUGH_INTENDED; - case Kind::kDouble: - ABSL_FALLTHROUGH_INTENDED; - case Kind::kDuration: - ABSL_FALLTHROUGH_INTENDED; - case Kind::kTimestamp: - // `other` is a simple value and simple type. Trivially copy and then - // clear metadata_ and content_, making `other` equivalent to `Value()` or - // `Value::Null()`. - metadata_.MoveFrom(std::move(other.metadata_)); - content_.construct_trivial_value(other.content_.trivial_value()); - other.content_.destruct_trivial_value(); - break; - case Kind::kError: - // `other` is an error value and simple type. Trivially copy and then - // clear metadata_ and copy construct and then clear content_, making - // `other` equivalent to `Value()` or `Value::Null()`. - metadata_.MoveFrom(std::move(other.metadata_)); - content_.construct_error_value(other.content_.error_value()); - other.content_.destruct_error_value(); - break; - case Kind::kBytes: - // `other` is a reffed value and simple type. Trivially copy and then - // clear metadata_ and trivially move content_, making - // `other` equivalent to `Value()` or `Value::Null()`. - metadata_.MoveFrom(std::move(other.metadata_)); - content_.adopt_reffed_value(other.content_.release_reffed_value()); - break; - default: - // TODO(issues/5): remove after implementing other kinds - std::abort(); - } +std::string BoolValue::DebugString() const { + return value() ? "true" : "false"; +} + +void BoolValue::CopyTo(Value& address) const { + CEL_COPY_TO_IMPL(BoolValue, *this, address); +} + +void BoolValue::MoveTo(Value& address) { + CEL_MOVE_TO_IMPL(BoolValue, *this, address); } -Value::~Value() { Destruct(this); } - -Value& Value::operator=(const Value& other) { - if (ABSL_PREDICT_TRUE(this != std::addressof(other))) { - switch (other.kind()) { - case Kind::kNullType: - ABSL_FALLTHROUGH_INTENDED; - case Kind::kBool: - ABSL_FALLTHROUGH_INTENDED; - case Kind::kInt: - ABSL_FALLTHROUGH_INTENDED; - case Kind::kUint: - ABSL_FALLTHROUGH_INTENDED; - case Kind::kDouble: - ABSL_FALLTHROUGH_INTENDED; - case Kind::kDuration: - ABSL_FALLTHROUGH_INTENDED; - case Kind::kTimestamp: - // `this` could be a simple value, an error value, or a reffed value. - // First we destruct resetting `this` to `Value()`. Then we perform the - // equivalent work of the copy constructor. - Destruct(this); - metadata_.CopyFrom(other.metadata_); - content_.construct_trivial_value(other.content_.trivial_value()); - break; - case Kind::kError: - if (kind() == Kind::kError) { - // `this` and `other` are error values. Perform a copy assignment - // which is faster than destructing and copy constructing. - content_.assign_error_value(other.content_.error_value()); - } else { - // `this` could be a simple value or a reffed value. First we destruct - // resetting `this` to `Value()`. Then we perform the equivalent work - // of the copy constructor. - Destruct(this); - content_.construct_error_value(other.content_.error_value()); - } - // Always copy metadata, for forward compatibility in case other bits - // are added. - metadata_.CopyFrom(other.metadata_); - break; - case Kind::kBytes: { - // `this` could be a simple value, an error value, or a reffed value. - // First we destruct resetting `this` to `Value()`. Then we perform the - // equivalent work of the copy constructor. - base_internal::BaseValue* reffed_value = - internal::Ref(other.content_.reffed_value()); - Destruct(this); - metadata_.CopyFrom(other.metadata_); - // Adopt is typically used for moves, but in this case we already - // increment the reference count, so it is equivalent to a move. - content_.adopt_reffed_value(reffed_value); - } break; - default: - // TODO(issues/5): remove after implementing other kinds - std::abort(); +bool BoolValue::Equals(const Value& other) const { + return kind() == other.kind() && + value() == internal::down_cast(other).value(); +} + +void BoolValue::HashValue(absl::HashState state) const { + absl::HashState::combine(std::move(state), type(), value()); +} + +Transient IntValue::type() const { + return TransientHandleFactory::MakeUnmanaged( + IntType::Get()); +} + +std::string IntValue::DebugString() const { return absl::StrCat(value()); } + +void IntValue::CopyTo(Value& address) const { + CEL_COPY_TO_IMPL(IntValue, *this, address); +} + +void IntValue::MoveTo(Value& address) { + CEL_MOVE_TO_IMPL(IntValue, *this, address); +} + +bool IntValue::Equals(const Value& other) const { + return kind() == other.kind() && + value() == internal::down_cast(other).value(); +} + +void IntValue::HashValue(absl::HashState state) const { + absl::HashState::combine(std::move(state), type(), value()); +} + +Transient UintValue::type() const { + return TransientHandleFactory::MakeUnmanaged( + UintType::Get()); +} + +std::string UintValue::DebugString() const { + return absl::StrCat(value(), "u"); +} + +void UintValue::CopyTo(Value& address) const { + CEL_COPY_TO_IMPL(UintValue, *this, address); +} + +void UintValue::MoveTo(Value& address) { + CEL_MOVE_TO_IMPL(UintValue, *this, address); +} + +bool UintValue::Equals(const Value& other) const { + return kind() == other.kind() && + value() == internal::down_cast(other).value(); +} + +void UintValue::HashValue(absl::HashState state) const { + absl::HashState::combine(std::move(state), type(), value()); +} + +Persistent DoubleValue::NaN(ValueFactory& value_factory) { + return value_factory.CreateDoubleValue( + std::numeric_limits::quiet_NaN()); +} + +Persistent DoubleValue::PositiveInfinity( + ValueFactory& value_factory) { + return value_factory.CreateDoubleValue( + std::numeric_limits::infinity()); +} + +Persistent DoubleValue::NegativeInfinity( + ValueFactory& value_factory) { + return value_factory.CreateDoubleValue( + -std::numeric_limits::infinity()); +} + +Transient DoubleValue::type() const { + return TransientHandleFactory::MakeUnmanaged( + DoubleType::Get()); +} + +std::string DoubleValue::DebugString() const { + if (std::isfinite(value())) { + if (std::floor(value()) != value()) { + // The double is not representable as a whole number, so use + // absl::StrCat which will add decimal places. + return absl::StrCat(value()); } - } - return *this; -} - -Value& Value::operator=(Value&& other) { - if (ABSL_PREDICT_TRUE(this != std::addressof(other))) { - switch (other.kind()) { - case Kind::kNullType: - ABSL_FALLTHROUGH_INTENDED; - case Kind::kBool: - ABSL_FALLTHROUGH_INTENDED; - case Kind::kInt: - ABSL_FALLTHROUGH_INTENDED; - case Kind::kUint: - ABSL_FALLTHROUGH_INTENDED; - case Kind::kDouble: - ABSL_FALLTHROUGH_INTENDED; - case Kind::kDuration: - ABSL_FALLTHROUGH_INTENDED; - case Kind::kTimestamp: - // `this` could be a simple value, an error value, or a reffed value. - // First we destruct resetting `this` to `Value()`. Then we perform the - // equivalent work of the move constructor. - Destruct(this); - metadata_.MoveFrom(std::move(other.metadata_)); - content_.construct_trivial_value(other.content_.trivial_value()); - other.content_.destruct_trivial_value(); - break; - case Kind::kError: - if (kind() == Kind::kError) { - // `this` and `other` are error values. Perform a copy assignment - // which is faster than destructing and copy constructing. `other` - // will be reset below. - content_.assign_error_value(other.content_.error_value()); - } else { - // `this` could be a simple value or a reffed value. First we destruct - // resetting `this` to `Value()`. Then we perform the equivalent work - // of the copy constructor. - Destruct(this); - content_.construct_error_value(other.content_.error_value()); - } - // Always copy metadata, for forward compatibility in case other bits - // are added. - metadata_.CopyFrom(other.metadata_); - // Reset `other` to `Value()`. - Destruct(std::addressof(other)); - break; - case Kind::kBytes: - // `this` could be a simple value, an error value, or a reffed value. - // First we destruct resetting `this` to `Value()`. Then we perform the - // equivalent work of the move constructor. - Destruct(this); - metadata_.MoveFrom(std::move(other.metadata_)); - content_.adopt_reffed_value(other.content_.release_reffed_value()); - break; - default: - // TODO(issues/5): remove after implementing other kinds - std::abort(); + // absl::StrCat historically would represent 0.0 as 0, and we want the + // decimal places so ZetaSQL correctly assumes the type as double + // instead of int64_t. + std::string stringified = absl::StrCat(value()); + if (!absl::StrContains(stringified, '.')) { + absl::StrAppend(&stringified, ".0"); + } else { + // absl::StrCat has a decimal now? Use it directly. } + return stringified; } - return *this; -} - -std::string Value::DebugString() const { - switch (kind()) { - case Kind::kNullType: - return "null"; - case Kind::kBool: - return AsBool() ? "true" : "false"; - case Kind::kInt: - return absl::StrCat(AsInt()); - case Kind::kUint: - return absl::StrCat(AsUint(), "u"); - case Kind::kDouble: { - if (std::isfinite(AsDouble())) { - if (static_cast(static_cast(AsDouble())) != - AsDouble()) { - // The double is not representable as a whole number, so use - // absl::StrCat which will add decimal places. - return absl::StrCat(AsDouble()); - } - // absl::StrCat historically would represent 0.0 as 0, and we want the - // decimal places so ZetaSQL correctly assumes the type as double - // instead of int64_t. - std::string stringified = absl::StrCat(AsDouble()); - if (!absl::StrContains(stringified, '.')) { - absl::StrAppend(&stringified, ".0"); - } else { - // absl::StrCat has a decimal now? Use it directly. - } - return stringified; - } - if (std::isnan(AsDouble())) { - return "nan"; - } - if (std::signbit(AsDouble())) { - return "-infinity"; - } - return "+infinity"; - } - case Kind::kDuration: - return internal::FormatDuration(AsDuration()).value(); - case Kind::kTimestamp: - return internal::FormatTimestamp(AsTimestamp()).value(); - case Kind::kError: - return AsError().ToString(); - case Kind::kBytes: - return content_.reffed_value()->DebugString(); - default: - // TODO(issues/5): remove after implementing other kinds - std::abort(); + if (std::isnan(value())) { + return "nan"; + } + if (std::signbit(value())) { + return "-infinity"; } + return "+infinity"; } -void Value::InitializeSingletons() { - absl::call_once(simple_values_once, []() { - ABSL_ASSERT(simple_values == nullptr); - simple_values = new SimpleValues(); - simple_values->empty_bytes = Value(Kind::kBytes, new cel::Bytes()); - }); +void DoubleValue::CopyTo(Value& address) const { + CEL_COPY_TO_IMPL(DoubleValue, *this, address); } -void Value::Destruct(Value* dest) { - // Perform any deallocations or destructions necessary and reset the state - // of `dest` to `Value()` making it the null value. - switch (dest->kind()) { - case Kind::kNullType: - return; - case Kind::kBool: - ABSL_FALLTHROUGH_INTENDED; - case Kind::kInt: - ABSL_FALLTHROUGH_INTENDED; - case Kind::kUint: - ABSL_FALLTHROUGH_INTENDED; - case Kind::kDouble: - ABSL_FALLTHROUGH_INTENDED; - case Kind::kDuration: - ABSL_FALLTHROUGH_INTENDED; - case Kind::kTimestamp: - dest->content_.destruct_trivial_value(); - break; - case Kind::kError: - dest->content_.destruct_error_value(); - break; - case Kind::kBytes: - dest->content_.destruct_reffed_value(); - break; - default: - // TODO(issues/5): remove after implementing other kinds - std::abort(); - } - dest->metadata_.Reset(); -} - -void Value::HashValue(absl::HashState state) const { - state = absl::HashState::combine(std::move(state), type()); - switch (kind()) { - case Kind::kNullType: - absl::HashState::combine(std::move(state), 0); - return; - case Kind::kBool: - absl::HashState::combine(std::move(state), AsBool()); - return; - case Kind::kInt: - absl::HashState::combine(std::move(state), AsInt()); - return; - case Kind::kUint: - absl::HashState::combine(std::move(state), AsUint()); - return; - case Kind::kDouble: - absl::HashState::combine(std::move(state), AsDouble()); - return; - case Kind::kDuration: - absl::HashState::combine(std::move(state), AsDuration()); - return; - case Kind::kTimestamp: - absl::HashState::combine(std::move(state), AsTimestamp()); - return; - case Kind::kError: - StatusHashValue(std::move(state), AsError()); - return; - case Kind::kBytes: - content_.reffed_value()->HashValue(std::move(state)); - return; - default: - // TODO(issues/5): remove after implementing other kinds - std::abort(); - } +void DoubleValue::MoveTo(Value& address) { + CEL_MOVE_TO_IMPL(DoubleValue, *this, address); } -bool Value::Equals(const Value& other) const { - // Comparing types is not enough as type may only compare the type name, - // which could be the same in separate environments but different kinds. So - // we also compare the kinds. - if (kind() != other.kind() || type() != other.type()) { - return false; - } - switch (kind()) { - case Kind::kNullType: - return true; - case Kind::kBool: - return AsBool() == other.AsBool(); - case Kind::kInt: - return AsInt() == other.AsInt(); - case Kind::kUint: - return AsUint() == other.AsUint(); - case Kind::kDouble: - return AsDouble() == other.AsDouble(); - case Kind::kDuration: - return AsDuration() == other.AsDuration(); - case Kind::kTimestamp: - return AsTimestamp() == other.AsTimestamp(); - case Kind::kError: - return AsError() == other.AsError(); - case Kind::kBytes: - return content_.reffed_value()->Equals(other); - default: - // TODO(issues/5): remove after implementing other kinds - std::abort(); - } +bool DoubleValue::Equals(const Value& other) const { + return kind() == other.kind() && + value() == internal::down_cast(other).value(); } -void Value::Swap(Value& other) { - // TODO(issues/5): Optimize this after other values are implemented - Value tmp(std::move(other)); - other = std::move(*this); - *this = std::move(tmp); +void DoubleValue::HashValue(absl::HashState state) const { + absl::HashState::combine(std::move(state), type(), value()); } -namespace { +Persistent DurationValue::Zero( + ValueFactory& value_factory) { + // Should never fail, tests assert this. + return value_factory.CreateDurationValue(absl::ZeroDuration()).value(); +} + +Transient DurationValue::type() const { + return TransientHandleFactory::MakeUnmanaged( + DurationType::Get()); +} + +std::string DurationValue::DebugString() const { + return internal::FormatDuration(value()).value(); +} + +void DurationValue::CopyTo(Value& address) const { + CEL_COPY_TO_IMPL(DurationValue, *this, address); +} + +void DurationValue::MoveTo(Value& address) { + CEL_MOVE_TO_IMPL(DurationValue, *this, address); +} + +bool DurationValue::Equals(const Value& other) const { + return kind() == other.kind() && + value() == internal::down_cast(other).value(); +} + +void DurationValue::HashValue(absl::HashState state) const { + absl::HashState::combine(std::move(state), type(), value()); +} + +Persistent TimestampValue::UnixEpoch( + ValueFactory& value_factory) { + // Should never fail, tests assert this. + return value_factory.CreateTimestampValue(absl::UnixEpoch()).value(); +} + +Transient TimestampValue::type() const { + return TransientHandleFactory::MakeUnmanaged( + TimestampType::Get()); +} + +std::string TimestampValue::DebugString() const { + return internal::FormatTimestamp(value()).value(); +} + +void TimestampValue::CopyTo(Value& address) const { + CEL_COPY_TO_IMPL(TimestampValue, *this, address); +} + +void TimestampValue::MoveTo(Value& address) { + CEL_MOVE_TO_IMPL(TimestampValue, *this, address); +} + +bool TimestampValue::Equals(const Value& other) const { + return kind() == other.kind() && + value() == internal::down_cast(other).value(); +} -constexpr absl::string_view ExternalDataToStringView( - const base_internal::ExternalData& external_data) { - return absl::string_view(static_cast(external_data.data), - external_data.size); +void TimestampValue::HashValue(absl::HashState state) const { + absl::HashState::combine(std::move(state), type(), value()); } +namespace { + struct DebugStringVisitor final { - std::string operator()(const std::string& value) const { + std::string operator()(absl::string_view value) const { return internal::FormatBytesLiteral(value); } @@ -515,67 +402,30 @@ struct DebugStringVisitor final { if (value.GetFlat(&flat)) { return internal::FormatBytesLiteral(flat); } - return internal::FormatBytesLiteral(value.ToString()); - } - - std::string operator()(const base_internal::ExternalData& value) const { - return internal::FormatBytesLiteral(ExternalDataToStringView(value)); + return internal::FormatBytesLiteral(static_cast(value)); } }; -struct ToCordReleaser final { - void operator()() const { internal::Unref(refcnt); } - - const internal::ReferenceCounted* refcnt; -}; - struct ToStringVisitor final { - std::string operator()(const std::string& value) const { return value; } - - std::string operator()(const absl::Cord& value) const { - return value.ToString(); - } - - std::string operator()(const base_internal::ExternalData& value) const { - return std::string(static_cast(value.data), value.size); + std::string operator()(absl::string_view value) const { + return std::string(value); } -}; - -struct ToCordVisitor final { - const internal::ReferenceCounted* refcnt; - absl::Cord operator()(const std::string& value) const { - internal::Ref(refcnt); - return absl::MakeCordFromExternal(value, ToCordReleaser{refcnt}); - } - - absl::Cord operator()(const absl::Cord& value) const { return value; } - - absl::Cord operator()(const base_internal::ExternalData& value) const { - internal::Ref(refcnt); - return absl::MakeCordFromExternal(ExternalDataToStringView(value), - ToCordReleaser{refcnt}); + std::string operator()(const absl::Cord& value) const { + return static_cast(value); } }; struct SizeVisitor final { - size_t operator()(const std::string& value) const { return value.size(); } + size_t operator()(absl::string_view value) const { return value.size(); } size_t operator()(const absl::Cord& value) const { return value.size(); } - - size_t operator()(const base_internal::ExternalData& value) const { - return value.size; - } }; struct EmptyVisitor final { - bool operator()(const std::string& value) const { return value.empty(); } + bool operator()(absl::string_view value) const { return value.empty(); } bool operator()(const absl::Cord& value) const { return value.empty(); } - - bool operator()(const base_internal::ExternalData& value) const { - return value.size == 0; - } }; bool EqualsImpl(absl::string_view lhs, absl::string_view rhs) { @@ -615,7 +465,7 @@ class EqualsVisitor final { public: explicit EqualsVisitor(const T& ref) : ref_(ref) {} - bool operator()(const std::string& value) const { + bool operator()(absl::string_view value) const { return EqualsImpl(value, ref_); } @@ -623,29 +473,21 @@ class EqualsVisitor final { return EqualsImpl(value, ref_); } - bool operator()(const base_internal::ExternalData& value) const { - return EqualsImpl(ExternalDataToStringView(value), ref_); - } - private: const T& ref_; }; template <> -class EqualsVisitor final { +class EqualsVisitor final { public: - explicit EqualsVisitor(const Bytes& ref) : ref_(ref) {} + explicit EqualsVisitor(const BytesValue& ref) : ref_(ref) {} - bool operator()(const std::string& value) const { return ref_.Equals(value); } + bool operator()(absl::string_view value) const { return ref_.Equals(value); } bool operator()(const absl::Cord& value) const { return ref_.Equals(value); } - bool operator()(const base_internal::ExternalData& value) const { - return ref_.Equals(ExternalDataToStringView(value)); - } - private: - const Bytes& ref_; + const BytesValue& ref_; }; template @@ -653,7 +495,7 @@ class CompareVisitor final { public: explicit CompareVisitor(const T& ref) : ref_(ref) {} - int operator()(const std::string& value) const { + int operator()(absl::string_view value) const { return CompareImpl(value, ref_); } @@ -661,38 +503,28 @@ class CompareVisitor final { return CompareImpl(value, ref_); } - int operator()(const base_internal::ExternalData& value) const { - return CompareImpl(ExternalDataToStringView(value), ref_); - } - private: const T& ref_; }; template <> -class CompareVisitor final { +class CompareVisitor final { public: - explicit CompareVisitor(const Bytes& ref) : ref_(ref) {} - - int operator()(const std::string& value) const { return ref_.Compare(value); } + explicit CompareVisitor(const BytesValue& ref) : ref_(ref) {} int operator()(const absl::Cord& value) const { return ref_.Compare(value); } int operator()(absl::string_view value) const { return ref_.Compare(value); } - int operator()(const base_internal::ExternalData& value) const { - return ref_.Compare(ExternalDataToStringView(value)); - } - private: - const Bytes& ref_; + const BytesValue& ref_; }; class HashValueVisitor final { public: explicit HashValueVisitor(absl::HashState state) : state_(std::move(state)) {} - void operator()(const std::string& value) { + void operator()(absl::string_view value) { absl::HashState::combine(std::move(state_), value); } @@ -700,90 +532,154 @@ class HashValueVisitor final { absl::HashState::combine(std::move(state_), value); } - void operator()(const base_internal::ExternalData& value) { - absl::HashState::combine(std::move(state_), - ExternalDataToStringView(value)); - } - private: absl::HashState state_; }; } // namespace -Value Bytes::Empty() { - Value::InitializeSingletons(); - return simple_values->empty_bytes; +Persistent BytesValue::Empty(ValueFactory& value_factory) { + return value_factory.GetBytesValue(); } -Value Bytes::New(std::string value) { - if (value.empty()) { - return Empty(); - } - return Value(Kind::kBytes, new Bytes(std::move(value))); +absl::StatusOr> BytesValue::Concat( + ValueFactory& value_factory, const Transient& lhs, + const Transient& rhs) { + absl::Cord cord = lhs->ToCord(base_internal::IsManagedHandle(lhs)); + cord.Append(rhs->ToCord(base_internal::IsManagedHandle(rhs))); + return value_factory.CreateBytesValue(std::move(cord)); } -Value Bytes::New(absl::Cord value) { - if (value.empty()) { - return Empty(); - } - return Value(Kind::kBytes, new Bytes(std::move(value))); +Transient BytesValue::type() const { + return TransientHandleFactory::MakeUnmanaged( + BytesType::Get()); } -Value Bytes::Concat(const Bytes& lhs, const Bytes& rhs) { - absl::Cord value; - value.Append(lhs.ToCord()); - value.Append(rhs.ToCord()); - return New(std::move(value)); +size_t BytesValue::size() const { return absl::visit(SizeVisitor{}, rep()); } + +bool BytesValue::empty() const { return absl::visit(EmptyVisitor{}, rep()); } + +bool BytesValue::Equals(absl::string_view bytes) const { + return absl::visit(EqualsVisitor(bytes), rep()); +} + +bool BytesValue::Equals(const absl::Cord& bytes) const { + return absl::visit(EqualsVisitor(bytes), rep()); +} + +bool BytesValue::Equals(const Transient& bytes) const { + return absl::visit(EqualsVisitor(*this), bytes->rep()); } -size_t Bytes::size() const { return absl::visit(SizeVisitor{}, data_); } +int BytesValue::Compare(absl::string_view bytes) const { + return absl::visit(CompareVisitor(bytes), rep()); +} + +int BytesValue::Compare(const absl::Cord& bytes) const { + return absl::visit(CompareVisitor(bytes), rep()); +} + +int BytesValue::Compare(const Transient& bytes) const { + return absl::visit(CompareVisitor(*this), bytes->rep()); +} + +std::string BytesValue::ToString() const { + return absl::visit(ToStringVisitor{}, rep()); +} + +std::string BytesValue::DebugString() const { + return absl::visit(DebugStringVisitor{}, rep()); +} + +bool BytesValue::Equals(const Value& other) const { + return kind() == other.kind() && + absl::visit(EqualsVisitor(*this), + internal::down_cast(other).rep()); +} -bool Bytes::empty() const { return absl::visit(EmptyVisitor{}, data_); } +void BytesValue::HashValue(absl::HashState state) const { + absl::visit( + HashValueVisitor(absl::HashState::combine(std::move(state), type())), + rep()); +} + +namespace base_internal { + +absl::Cord InlinedCordBytesValue::ToCord(bool reference_counted) const { + static_cast(reference_counted); + return value_; +} -bool Bytes::Equals(absl::string_view bytes) const { - return absl::visit(EqualsVisitor(bytes), data_); +void InlinedCordBytesValue::CopyTo(Value& address) const { + CEL_COPY_TO_IMPL(InlinedCordBytesValue, *this, address); } -bool Bytes::Equals(const absl::Cord& bytes) const { - return absl::visit(EqualsVisitor(bytes), data_); +void InlinedCordBytesValue::MoveTo(Value& address) { + CEL_MOVE_TO_IMPL(InlinedCordBytesValue, *this, address); } -bool Bytes::Equals(const Bytes& bytes) const { - return absl::visit(EqualsVisitor(*this), bytes.data_); +typename InlinedCordBytesValue::Rep InlinedCordBytesValue::rep() const { + return Rep(absl::in_place_type>, + std::cref(value_)); } -int Bytes::Compare(absl::string_view bytes) const { - return absl::visit(CompareVisitor(bytes), data_); +absl::Cord InlinedStringViewBytesValue::ToCord(bool reference_counted) const { + static_cast(reference_counted); + return absl::Cord(value_); } -int Bytes::Compare(const absl::Cord& bytes) const { - return absl::visit(CompareVisitor(bytes), data_); +void InlinedStringViewBytesValue::CopyTo(Value& address) const { + CEL_COPY_TO_IMPL(InlinedStringViewBytesValue, *this, address); } -int Bytes::Compare(const Bytes& bytes) const { - return absl::visit(CompareVisitor(*this), bytes.data_); +void InlinedStringViewBytesValue::MoveTo(Value& address) { + CEL_MOVE_TO_IMPL(InlinedStringViewBytesValue, *this, address); } -std::string Bytes::ToString() const { - return absl::visit(ToStringVisitor{}, data_); +typename InlinedStringViewBytesValue::Rep InlinedStringViewBytesValue::rep() + const { + return Rep(absl::in_place_type, value_); } -absl::Cord Bytes::ToCord() const { - return absl::visit(ToCordVisitor{this}, data_); +std::pair StringBytesValue::SizeAndAlignment() const { + return std::make_pair(sizeof(StringBytesValue), alignof(StringBytesValue)); } -std::string Bytes::DebugString() const { - return absl::visit(DebugStringVisitor{}, data_); +absl::Cord StringBytesValue::ToCord(bool reference_counted) const { + if (reference_counted) { + Ref(); + return absl::MakeCordFromExternal(absl::string_view(value_), + [this]() { Unref(); }); + } + return absl::Cord(value_); } -bool Bytes::Equals(const Value& value) const { - ABSL_ASSERT(value.IsBytes()); - return absl::visit(EqualsVisitor(*this), value.AsBytes().data_); +typename StringBytesValue::Rep StringBytesValue::rep() const { + return Rep(absl::in_place_type, absl::string_view(value_)); } -void Bytes::HashValue(absl::HashState state) const { - absl::visit(HashValueVisitor(std::move(state)), data_); +std::pair ExternalDataBytesValue::SizeAndAlignment() const { + return std::make_pair(sizeof(ExternalDataBytesValue), + alignof(ExternalDataBytesValue)); } +absl::Cord ExternalDataBytesValue::ToCord(bool reference_counted) const { + if (reference_counted) { + Ref(); + return absl::MakeCordFromExternal( + absl::string_view(static_cast(value_.data), value_.size), + [this]() { Unref(); }); + } + return absl::Cord( + absl::string_view(static_cast(value_.data), value_.size)); +} + +typename ExternalDataBytesValue::Rep ExternalDataBytesValue::rep() const { + return Rep( + absl::in_place_type, + absl::string_view(static_cast(value_.data), value_.size)); +} + +} // namespace base_internal + } // namespace cel diff --git a/base/value.h b/base/value.h index 5b62ff940..c123753bf 100644 --- a/base/value.h +++ b/base/value.h @@ -16,365 +16,484 @@ #define THIRD_PARTY_CEL_CPP_BASE_VALUE_H_ #include -#include -#include +#include #include #include #include "absl/base/attributes.h" -#include "absl/base/casts.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "absl/types/variant.h" -#include "base/internal/value.h" +#include "base/handle.h" +#include "base/internal/value.pre.h" // IWYU pragma: export #include "base/kind.h" +#include "base/memory_manager.h" #include "base/type.h" -#include "internal/casts.h" namespace cel { +class Value; +class NullValue; +class ErrorValue; +class BoolValue; +class IntValue; +class UintValue; +class DoubleValue; +class BytesValue; +class DurationValue; +class TimestampValue; +class ValueFactory; + +namespace internal { +template +class NoDestructor; +} + // A representation of a CEL value that enables reflection and introspection of // values. -// -// TODO(issues/5): document once derived implementations stabilize -class Value final { +class Value : public base_internal::Resource { public: - // Returns the null value. - ABSL_ATTRIBUTE_PURE_FUNCTION static Value Null() { return Value(); } + // Returns the type of the value. If you only need the kind, prefer `kind()`. + virtual Transient type() const = 0; - // Constructs an error value. It is required that `status` is non-OK, - // otherwise behavior is undefined. - static Value Error(const absl::Status& status); + // Returns the kind of the value. This is equivalent to `type().kind()` but + // faster in many scenarios. As such it should be preffered when only the kind + // is required. + virtual Kind kind() const { return type()->kind(); } - // Returns a bool value. - static Value Bool(bool value) { return Value(value); } + virtual std::string DebugString() const = 0; - // Returns the false bool value. Equivalent to `Value::Bool(false)`. - ABSL_ATTRIBUTE_PURE_FUNCTION static Value False() { return Bool(false); } + private: + friend class NullValue; + friend class ErrorValue; + friend class BoolValue; + friend class IntValue; + friend class UintValue; + friend class DoubleValue; + friend class BytesValue; + friend class DurationValue; + friend class TimestampValue; + friend class base_internal::ValueHandleBase; + friend class base_internal::StringBytesValue; + friend class base_internal::ExternalDataBytesValue; + + Value() = default; + Value(const Value&) = default; + Value(Value&&) = default; + + // Called by base_internal::ValueHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Value& value) { return true; } + + // For non-inlined values that are reference counted, this is the result of + // `sizeof` and `alignof` for the most derived class. + std::pair SizeAndAlignment() const override; + + // Expose to some value implementations using friendship. + using base_internal::Resource::Ref; + using base_internal::Resource::Unref; + + // Called by base_internal::ValueHandleBase for inlined values. + virtual void CopyTo(Value& address) const; + + // Called by base_internal::ValueHandleBase for inlined values. + virtual void MoveTo(Value& address); + + // Called by base_internal::ValueHandleBase. + virtual bool Equals(const Value& other) const = 0; + + // Called by base_internal::ValueHandleBase. + virtual void HashValue(absl::HashState state) const = 0; +}; - // Returns the true bool value. Equivalent to `Value::Bool(true)`. - ABSL_ATTRIBUTE_PURE_FUNCTION static Value True() { return Bool(true); } +class NullValue final : public Value, public base_internal::ResourceInlined { + public: + static Persistent Get(ValueFactory& value_factory); - // Returns an int value. - static Value Int(int64_t value) { return Value(value); } + Transient type() const override; - // Returns a uint value. - static Value Uint(uint64_t value) { return Value(value); } + Kind kind() const override { return Kind::kNullType; } - // Returns a double value. - static Value Double(double value) { return Value(value); } + std::string DebugString() const override; - // Returns a NaN double value. Equivalent to `Value::Double(NAN)`. - ABSL_ATTRIBUTE_PURE_FUNCTION static Value NaN() { - return Double(std::numeric_limits::quiet_NaN()); - } + private: + friend class ValueFactory; + template + friend class internal::NoDestructor; + template + friend class base_internal::ValueHandle; + friend class base_internal::ValueHandleBase; + + // Called by base_internal::ValueHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Value& value) { return value.kind() == Kind::kNullType; } + + ABSL_ATTRIBUTE_PURE_FUNCTION static const NullValue& Get(); + + NullValue() = default; + NullValue(const NullValue&) = default; + NullValue(NullValue&&) = default; + + // See comments for respective member functions on `Value`. + void CopyTo(Value& address) const override; + void MoveTo(Value& address) override; + bool Equals(const Value& other) const override; + void HashValue(absl::HashState state) const override; +}; - // Returns a positive infinity double value. Equivalent to - // `Value::Double(INFINITY)`. - ABSL_ATTRIBUTE_PURE_FUNCTION static Value PositiveInfinity() { - return Double(std::numeric_limits::infinity()); - } +class ErrorValue final : public Value, public base_internal::ResourceInlined { + public: + Transient type() const override; - // Returns a negative infinity double value. Equivalent to - // `Value::Double(-INFINITY)`. - ABSL_ATTRIBUTE_PURE_FUNCTION static Value NegativeInfinity() { - return Double(-std::numeric_limits::infinity()); - } + Kind kind() const override { return Kind::kError; } - // Returns a duration value or a `absl::StatusCode::kInvalidArgument` error if - // the value is not in the valid range. - static absl::StatusOr Duration(absl::Duration value); + std::string DebugString() const override; - // Returns the zero duration value. Equivalent to - // `Value::Duration(absl::ZeroDuration())`. - ABSL_ATTRIBUTE_PURE_FUNCTION static Value ZeroDuration() { - return Value(Kind::kDuration, 0, 0); - } + const absl::Status& value() const { return value_; } - // Returns a timestamp value or a `absl::StatusCode::kInvalidArgument` error - // if the value is not in the valid range. - static absl::StatusOr Timestamp(absl::Time value); + private: + template + friend class base_internal::ValueHandle; + friend class base_internal::ValueHandleBase; - // Returns the zero timestamp value. Equivalent to - // `Value::Timestamp(absl::UnixEpoch())`. - ABSL_ATTRIBUTE_PURE_FUNCTION static Value UnixEpoch() { - return Value(Kind::kTimestamp, 0, 0); - } + // Called by base_internal::ValueHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Value& value) { return value.kind() == Kind::kError; } - // Equivalent to `Value::Null()`. - constexpr Value() = default; + // Called by `base_internal::ValueHandle` to construct value inline. + explicit ErrorValue(absl::Status value) : value_(std::move(value)) {} - Value(const Value& other); + ErrorValue() = delete; - Value(Value&& other); + ErrorValue(const ErrorValue&) = default; + ErrorValue(ErrorValue&&) = default; - ~Value(); + // See comments for respective member functions on `Value`. + void CopyTo(Value& address) const override; + void MoveTo(Value& address) override; + bool Equals(const Value& other) const override; + void HashValue(absl::HashState state) const override; - Value& operator=(const Value& other); + absl::Status value_; +}; - Value& operator=(Value&& other); +class BoolValue final : public Value, public base_internal::ResourceInlined { + public: + static Persistent False(ValueFactory& value_factory); - // Returns the type of the value. If you only need the kind, prefer `kind()`. - cel::Type type() const { - return metadata_.simple_tag() - ? cel::Type::Simple(metadata_.kind()) - : cel::Type(internal::Ref(metadata_.base_type())); - } + static Persistent True(ValueFactory& value_factory); - // Returns the kind of the value. This is equivalent to `type().kind()` but - // faster in many scenarios. As such it should be preffered when only the kind - // is required. - Kind kind() const { return metadata_.kind(); } + Transient type() const override; - // True if this is the null value, false otherwise. - bool IsNull() const { return kind() == Kind::kNullType; } + Kind kind() const override { return Kind::kBool; } - // True if this is an error value, false otherwise. - bool IsError() const { return kind() == Kind::kError; } + std::string DebugString() const override; - // True if this is a bool value, false otherwise. - bool IsBool() const { return kind() == Kind::kBool; } + constexpr bool value() const { return value_; } - // True if this is an int value, false otherwise. - bool IsInt() const { return kind() == Kind::kInt; } + private: + template + friend class base_internal::ValueHandle; + friend class base_internal::ValueHandleBase; - // True if this is a uint value, false otherwise. - bool IsUint() const { return kind() == Kind::kUint; } + // Called by base_internal::ValueHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Value& value) { return value.kind() == Kind::kBool; } - // True if this is a double value, false otherwise. - bool IsDouble() const { return kind() == Kind::kDouble; } + // Called by `base_internal::ValueHandle` to construct value inline. + explicit BoolValue(bool value) : value_(value) {} - // True if this is a duration value, false otherwise. - bool IsDuration() const { return kind() == Kind::kDuration; } + BoolValue() = delete; - // True if this is a timestamp value, false otherwise. - bool IsTimestamp() const { return kind() == Kind::kTimestamp; } + BoolValue(const BoolValue&) = default; + BoolValue(BoolValue&&) = default; - // True if this is a bytes value, false otherwise. - bool IsBytes() const { return kind() == Kind::kBytes; } + // See comments for respective member functions on `Value`. + void CopyTo(Value& address) const override; + void MoveTo(Value& address) override; + bool Equals(const Value& other) const override; + void HashValue(absl::HashState state) const override; - // Returns the C++ error value. Requires `kind() == Kind::kError` or behavior - // is undefined. - const absl::Status& AsError() const ABSL_ATTRIBUTE_LIFETIME_BOUND { - ABSL_ASSERT(IsError()); - return content_.error_value(); - } + bool value_; +}; - // Returns the C++ bool value. Requires `kind() == Kind::kBool` or behavior is - // undefined. - bool AsBool() const { - ABSL_ASSERT(IsBool()); - return content_.bool_value(); - } +class IntValue final : public Value, public base_internal::ResourceInlined { + public: + Transient type() const override; - // Returns the C++ int value. Requires `kind() == Kind::kInt` or behavior is - // undefined. - int64_t AsInt() const { - ABSL_ASSERT(IsInt()); - return content_.int_value(); - } + Kind kind() const override { return Kind::kInt; } - // Returns the C++ uint value. Requires `kind() == Kind::kUint` or behavior is - // undefined. - uint64_t AsUint() const { - ABSL_ASSERT(IsUint()); - return content_.uint_value(); - } + std::string DebugString() const override; - // Returns the C++ double value. Requires `kind() == Kind::kDouble` or - // behavior is undefined. - double AsDouble() const { - ABSL_ASSERT(IsDouble()); - return content_.double_value(); - } + constexpr int64_t value() const { return value_; } - // Returns the C++ duration value. Requires `kind() == Kind::kDuration` or - // behavior is undefined. - absl::Duration AsDuration() const { - ABSL_ASSERT(IsDuration()); - return absl::Seconds(content_.int_value()) + - absl::Nanoseconds( - absl::bit_cast(metadata_.extended_content())); - } + private: + template + friend class base_internal::ValueHandle; + friend class base_internal::ValueHandleBase; - // Returns the C++ timestamp value. Requires `kind() == Kind::kTimestamp` or - // behavior is undefined. - absl::Time AsTimestamp() const { - // Timestamp is stored as the duration since Unix Epoch. - ABSL_ASSERT(IsTimestamp()); - return absl::UnixEpoch() + absl::Seconds(content_.int_value()) + - absl::Nanoseconds( - absl::bit_cast(metadata_.extended_content())); - } + // Called by base_internal::ValueHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Value& value) { return value.kind() == Kind::kInt; } - std::string DebugString() const; + // Called by `base_internal::ValueHandle` to construct value inline. + explicit IntValue(int64_t value) : value_(value) {} - const Bytes& AsBytes() const ABSL_ATTRIBUTE_LIFETIME_BOUND { - ABSL_ASSERT(IsBytes()); - return internal::down_cast(*content_.reffed_value()); - } + IntValue() = delete; - template - friend H AbslHashValue(H state, const Value& value) { - value.HashValue(absl::HashState::Create(&state)); - return std::move(state); - } + IntValue(const IntValue&) = default; + IntValue(IntValue&&) = default; - friend void swap(Value& lhs, Value& rhs) { lhs.Swap(rhs); } + // See comments for respective member functions on `Value`. + void CopyTo(Value& address) const override; + void MoveTo(Value& address) override; + bool Equals(const Value& other) const override; + void HashValue(absl::HashState state) const override; - friend bool operator==(const Value& lhs, const Value& rhs) { - return lhs.Equals(rhs); - } + int64_t value_; +}; - friend bool operator!=(const Value& lhs, const Value& rhs) { - return !operator==(lhs, rhs); - } +class UintValue final : public Value, public base_internal::ResourceInlined { + public: + Transient type() const override; + + Kind kind() const override { return Kind::kUint; } + + std::string DebugString() const override; + + constexpr uint64_t value() const { return value_; } private: - friend class Bytes; + template + friend class base_internal::ValueHandle; + friend class base_internal::ValueHandleBase; - using Metadata = base_internal::ValueMetadata; - using Content = base_internal::ValueContent; + // Called by base_internal::ValueHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Value& value) { return value.kind() == Kind::kUint; } - static void InitializeSingletons(); + // Called by `base_internal::ValueHandle` to construct value inline. + explicit UintValue(uint64_t value) : value_(value) {} - static void Destruct(Value* dest); + UintValue() = delete; - constexpr explicit Value(bool value) - : metadata_(Kind::kBool), content_(value) {} + UintValue(const UintValue&) = default; + UintValue(UintValue&&) = default; - constexpr explicit Value(int64_t value) - : metadata_(Kind::kInt), content_(value) {} + // See comments for respective member functions on `Value`. + void CopyTo(Value& address) const override; + void MoveTo(Value& address) override; + bool Equals(const Value& other) const override; + void HashValue(absl::HashState state) const override; - constexpr explicit Value(uint64_t value) - : metadata_(Kind::kUint), content_(value) {} + uint64_t value_; +}; - constexpr explicit Value(double value) - : metadata_(Kind::kDouble), content_(value) {} +class DoubleValue final : public Value, public base_internal::ResourceInlined { + public: + static Persistent NaN(ValueFactory& value_factory); - explicit Value(const absl::Status& status) - : metadata_(Kind::kError), content_(status) {} + static Persistent PositiveInfinity( + ValueFactory& value_factory); - constexpr Value(Kind kind, base_internal::BaseValue* base_value) - : metadata_(kind), content_(base_value) {} + static Persistent NegativeInfinity( + ValueFactory& value_factory); - constexpr Value(Kind kind, int64_t content, uint32_t extended_content) - : metadata_(kind, extended_content), content_(content) {} + Transient type() const override; - bool Equals(const Value& other) const; + Kind kind() const override { return Kind::kDouble; } - void HashValue(absl::HashState state) const; + std::string DebugString() const override; - void Swap(Value& other); + constexpr double value() const { return value_; } - Metadata metadata_; - Content content_; -}; + private: + template + friend class base_internal::ValueHandle; + friend class base_internal::ValueHandleBase; -// A CEL bytes value specific interface that can be accessed via -// `cel::Value::AsBytes`. It acts as a facade over various native -// representations and provides efficient implementations of CEL builtin -// functions. -class Bytes final : public base_internal::BaseValue { - public: - // Returns a bytes value which has a size of 0 and is empty. - ABSL_ATTRIBUTE_PURE_FUNCTION static Value Empty(); + // Called by base_internal::ValueHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Value& value) { return value.kind() == Kind::kDouble; } - // Returns a bytes value with `value` as its contents. - static Value New(std::string value); + // Called by `base_internal::ValueHandle` to construct value inline. + explicit DoubleValue(double value) : value_(value) {} - // Returns a bytes value with a copy of `value` as its contents. - static Value New(absl::string_view value) { - return New(std::string(value.data(), value.size())); - } + DoubleValue() = delete; - // Returns a bytes value with a copy of `value` as its contents. - // - // This is needed for `Value::Bytes("foo")` to be an unambiguous function - // call. - static Value New(const char* value) { - ABSL_ASSERT(value != nullptr); - return New(absl::string_view(value)); - } + DoubleValue(const DoubleValue&) = default; + DoubleValue(DoubleValue&&) = default; - // Returns a bytes value with `value` as its contents. - static Value New(absl::Cord value); + // See comments for respective member functions on `Value`. + void CopyTo(Value& address) const override; + void MoveTo(Value& address) override; + bool Equals(const Value& other) const override; + void HashValue(absl::HashState state) const override; - // Returns a bytes value with `value` as its contents. Unlike `New()` this - // does not copy `value`, instead it expects the contents pointed to by - // `value` to live as long as the returned instance. `releaser` is used to - // notify the caller when the contents pointed to by `value` are no longer - // required. - template - static std::enable_if_t, Value> Wrap( - absl::string_view value, Releaser&& releaser); + double value_; +}; - static Value Concat(const Bytes& lhs, const Bytes& rhs); +class BytesValue : public Value { + protected: + using Rep = absl::variant>; + + public: + static Persistent Empty(ValueFactory& value_factory); + + // Concat concatenates the contents of two ByteValue, returning a new + // ByteValue. The resulting ByteValue is not tied to the lifetime of either of + // the input ByteValue. + static absl::StatusOr> Concat( + ValueFactory& value_factory, const Transient& lhs, + const Transient& rhs); + + Transient type() const final; + + Kind kind() const final { return Kind::kBytes; } + + std::string DebugString() const final; size_t size() const; bool empty() const; bool Equals(absl::string_view bytes) const; - bool Equals(const absl::Cord& bytes) const; - - bool Equals(const Bytes& bytes) const; + bool Equals(const Transient& bytes) const; int Compare(absl::string_view bytes) const; - int Compare(const absl::Cord& bytes) const; - - int Compare(const Bytes& bytes) const; + int Compare(const Transient& bytes) const; std::string ToString() const; - absl::Cord ToCord() const; + absl::Cord ToCord() const { + // Without the handle we cannot know if this is reference counted. + return ToCord(/*reference_counted=*/false); + } - std::string DebugString() const override; + private: + template + friend class base_internal::ValueHandle; + friend class base_internal::ValueHandleBase; + friend class base_internal::InlinedCordBytesValue; + friend class base_internal::InlinedStringViewBytesValue; + friend class base_internal::StringBytesValue; + friend class base_internal::ExternalDataBytesValue; + + // Called by base_internal::ValueHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Value& value) { return value.kind() == Kind::kBytes; } + + BytesValue() = default; + BytesValue(const BytesValue&) = default; + BytesValue(BytesValue&&) = default; + + // Get the contents of this BytesValue as absl::Cord. When reference_counted + // is true, the implementation can potentially return an absl::Cord that wraps + // the contents instead of copying. + virtual absl::Cord ToCord(bool reference_counted) const = 0; + + // Get the contents of this BytesValue as either absl::string_view or const + // absl::Cord&. + virtual Rep rep() const = 0; + + // See comments for respective member functions on `Value`. + bool Equals(const Value& other) const final; + void HashValue(absl::HashState state) const final; +}; - protected: - bool Equals(const Value& value) const override; +class DurationValue final : public Value, + public base_internal::ResourceInlined { + public: + static Persistent Zero(ValueFactory& value_factory); - void HashValue(absl::HashState state) const override; + Transient type() const override; + + Kind kind() const override { return Kind::kDuration; } + + std::string DebugString() const override; + + constexpr absl::Duration value() const { return value_; } private: - friend class Value; + template + friend class base_internal::ValueHandle; + friend class base_internal::ValueHandleBase; - Bytes() : Bytes(std::string()) {} + // Called by base_internal::ValueHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Value& value) { return value.kind() == Kind::kDuration; } - explicit Bytes(std::string value) - : base_internal::BaseValue(), - data_(absl::in_place_index<0>, std::move(value)) {} + // Called by `base_internal::ValueHandle` to construct value inline. + explicit DurationValue(absl::Duration value) : value_(value) {} - explicit Bytes(absl::Cord value) - : base_internal::BaseValue(), - data_(absl::in_place_index<1>, std::move(value)) {} + DurationValue() = delete; - explicit Bytes(base_internal::ExternalData value) - : base_internal::BaseValue(), - data_(absl::in_place_index<2>, std::move(value)) {} + DurationValue(const DurationValue&) = default; + DurationValue(DurationValue&&) = default; - absl::variant data_; + // See comments for respective member functions on `Value`. + void CopyTo(Value& address) const override; + void MoveTo(Value& address) override; + bool Equals(const Value& other) const override; + void HashValue(absl::HashState state) const override; + + absl::Duration value_; }; -template -std::enable_if_t, Value> Bytes::Wrap( - absl::string_view value, Releaser&& releaser) { - if (value.empty()) { - std::forward(releaser)(); - return Empty(); +class TimestampValue final : public Value, + public base_internal::ResourceInlined { + public: + static Persistent UnixEpoch( + ValueFactory& value_factory); + + Transient type() const override; + + Kind kind() const override { return Kind::kTimestamp; } + + std::string DebugString() const override; + + constexpr absl::Time value() const { return value_; } + + private: + template + friend class base_internal::ValueHandle; + friend class base_internal::ValueHandleBase; + + // Called by base_internal::ValueHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Value& value) { + return value.kind() == Kind::kTimestamp; } - return Value(Kind::kBytes, - new Bytes(base_internal::ExternalData( - value.data(), value.size(), - std::make_unique( - std::forward(releaser))))); -} + + // Called by `base_internal::ValueHandle` to construct value inline. + explicit TimestampValue(absl::Time value) : value_(value) {} + + TimestampValue() = delete; + + TimestampValue(const TimestampValue&) = default; + TimestampValue(TimestampValue&&) = default; + + // See comments for respective member functions on `Value`. + void CopyTo(Value& address) const override; + void MoveTo(Value& address) override; + bool Equals(const Value& other) const override; + void HashValue(absl::HashState state) const override; + + absl::Time value_; +}; } // namespace cel +// value.pre.h forward declares types so they can be friended above. The types +// themselves need to be defined after everything else as they need to access or +// derive from the above types. We do this in value.post.h to avoid mudying this +// header and making it difficult to read. +#include "base/internal/value.post.h" // IWYU pragma: export + #endif // THIRD_PARTY_CEL_CPP_BASE_VALUE_H_ diff --git a/base/value_factory.cc b/base/value_factory.cc new file mode 100644 index 000000000..59fcac40c --- /dev/null +++ b/base/value_factory.cc @@ -0,0 +1,117 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/value_factory.h" + +#include +#include + +#include "absl/base/optimization.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "base/handle.h" +#include "base/value.h" +#include "internal/status_macros.h" +#include "internal/time.h" + +namespace cel { + +namespace { + +using base_internal::ExternalDataBytesValue; +using base_internal::InlinedCordBytesValue; +using base_internal::InlinedStringViewBytesValue; +using base_internal::PersistentHandleFactory; +using base_internal::StringBytesValue; +using base_internal::TransientHandleFactory; + +} // namespace + +Persistent ValueFactory::GetNullValue() { + return Persistent( + TransientHandleFactory::MakeUnmanaged( + NullValue::Get())); +} + +Persistent ValueFactory::CreateErrorValue( + absl::Status status) { + if (ABSL_PREDICT_FALSE(status.ok())) { + status = absl::UnknownError( + "If you are seeing this message the caller attempted to construct an " + "error value from a successful status. Refusing to fail successfully."); + } + return PersistentHandleFactory::Make( + std::move(status)); +} + +Persistent ValueFactory::CreateBoolValue(bool value) { + return PersistentHandleFactory::Make(value); +} + +Persistent ValueFactory::CreateIntValue(int64_t value) { + return PersistentHandleFactory::Make(value); +} + +Persistent ValueFactory::CreateUintValue(uint64_t value) { + return PersistentHandleFactory::Make(value); +} + +Persistent ValueFactory::CreateDoubleValue(double value) { + return PersistentHandleFactory::Make(value); +} + +absl::StatusOr> ValueFactory::CreateBytesValue( + std::string value) { + if (value.empty()) { + return GetEmptyBytesValue(); + } + return PersistentHandleFactory::Make( + memory_manager(), std::move(value)); +} + +absl::StatusOr> ValueFactory::CreateBytesValue( + absl::Cord value) { + if (value.empty()) { + return GetEmptyBytesValue(); + } + return PersistentHandleFactory::Make( + std::move(value)); +} + +absl::StatusOr> +ValueFactory::CreateDurationValue(absl::Duration value) { + CEL_RETURN_IF_ERROR(internal::ValidateDuration(value)); + return PersistentHandleFactory::Make( + value); +} + +absl::StatusOr> +ValueFactory::CreateTimestampValue(absl::Time value) { + CEL_RETURN_IF_ERROR(internal::ValidateTimestamp(value)); + return PersistentHandleFactory::Make( + value); +} + +Persistent ValueFactory::GetEmptyBytesValue() { + return PersistentHandleFactory::Make< + InlinedStringViewBytesValue>(absl::string_view()); +} + +absl::StatusOr> ValueFactory::CreateBytesValue( + base_internal::ExternalData value) { + return PersistentHandleFactory::Make< + ExternalDataBytesValue>(memory_manager(), std::move(value)); +} + +} // namespace cel diff --git a/base/value_factory.h b/base/value_factory.h new file mode 100644 index 000000000..02b0d32ca --- /dev/null +++ b/base/value_factory.h @@ -0,0 +1,114 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_VALUE_FACTORY_H_ +#define THIRD_PARTY_CEL_CPP_BASE_VALUE_FACTORY_H_ + +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "base/handle.h" +#include "base/memory_manager.h" +#include "base/value.h" + +namespace cel { + +class ValueFactory { + public: + virtual ~ValueFactory() = default; + + Persistent GetNullValue() ABSL_ATTRIBUTE_LIFETIME_BOUND; + + Persistent CreateErrorValue(absl::Status status) + ABSL_ATTRIBUTE_LIFETIME_BOUND; + + Persistent CreateBoolValue(bool value) + ABSL_ATTRIBUTE_LIFETIME_BOUND; + + Persistent CreateIntValue(int64_t value) + ABSL_ATTRIBUTE_LIFETIME_BOUND; + + Persistent CreateUintValue(uint64_t value) + ABSL_ATTRIBUTE_LIFETIME_BOUND; + + Persistent CreateDoubleValue(double value) + ABSL_ATTRIBUTE_LIFETIME_BOUND; + + Persistent GetBytesValue() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetEmptyBytesValue(); + } + + absl::StatusOr> CreateBytesValue( + const char* value) ABSL_ATTRIBUTE_LIFETIME_BOUND { + return CreateBytesValue(absl::string_view(value)); + } + + absl::StatusOr> CreateBytesValue( + absl::string_view value) ABSL_ATTRIBUTE_LIFETIME_BOUND { + return CreateBytesValue(std::string(value)); + } + + absl::StatusOr> CreateBytesValue( + std::string value) ABSL_ATTRIBUTE_LIFETIME_BOUND; + + absl::StatusOr> CreateBytesValue( + absl::Cord value) ABSL_ATTRIBUTE_LIFETIME_BOUND; + + template + absl::StatusOr> CreateBytesValue( + absl::string_view value, + Releaser&& releaser) ABSL_ATTRIBUTE_LIFETIME_BOUND { + if (value.empty()) { + std::forward(releaser)(); + return GetEmptyBytesValue(); + } + return CreateBytesValue(base_internal::ExternalData( + static_cast(value.data()), value.size(), + std::make_unique( + std::forward(releaser)))); + } + + absl::StatusOr> CreateDurationValue( + absl::Duration value) ABSL_ATTRIBUTE_LIFETIME_BOUND; + + absl::StatusOr> CreateTimestampValue( + absl::Time value) ABSL_ATTRIBUTE_LIFETIME_BOUND; + + protected: + // Prevent direct intantiation until more pure virtual methods are added. + explicit ValueFactory(MemoryManager& memory_manager) + : memory_manager_(memory_manager) {} + + MemoryManager& memory_manager() const { return memory_manager_; } + + private: + Persistent GetEmptyBytesValue() + ABSL_ATTRIBUTE_LIFETIME_BOUND; + + absl::StatusOr> CreateBytesValue( + base_internal::ExternalData value) ABSL_ATTRIBUTE_LIFETIME_BOUND; + + MemoryManager& memory_manager_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_VALUE_FACTORY_H_ diff --git a/base/value_factory_test.cc b/base/value_factory_test.cc new file mode 100644 index 000000000..346df9583 --- /dev/null +++ b/base/value_factory_test.cc @@ -0,0 +1,38 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/value_factory.h" + +#include "absl/status/status.h" +#include "base/memory_manager.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using cel::internal::StatusIs; + +class TestValueFactory final : public ValueFactory { + public: + TestValueFactory() : ValueFactory(MemoryManager::Global()) {} +}; + +TEST(ValueFactory, CreateErrorValueReplacesOk) { + TestValueFactory value_factory; + EXPECT_THAT(value_factory.CreateErrorValue(absl::OkStatus())->value(), + StatusIs(absl::StatusCode::kUnknown)); +} + +} // namespace +} // namespace cel diff --git a/base/value_test.cc b/base/value_test.cc index f9eae5723..8f7e87775 100644 --- a/base/value_test.cc +++ b/base/value_test.cc @@ -17,14 +17,19 @@ #include #include #include +#include #include #include #include #include "absl/hash/hash_testing.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/time/time.h" +#include "base/memory_manager.h" #include "base/type.h" +#include "base/type_factory.h" +#include "base/value_factory.h" #include "internal/strings.h" #include "internal/testing.h" #include "internal/time.h" @@ -34,26 +39,69 @@ namespace { using cel::internal::StatusIs; +template +Persistent Must(absl::StatusOr> status_or_handle) { + return std::move(status_or_handle).value(); +} + +class TestTypeFactory final : public TypeFactory { + public: + TestTypeFactory() : TypeFactory(MemoryManager::Global()) {} +}; + +class TestValueFactory final : public ValueFactory { + public: + TestValueFactory() : ValueFactory(MemoryManager::Global()) {} +}; + template constexpr void IS_INITIALIZED(T&) {} -TEST(Value, TypeTraits) { - EXPECT_TRUE(std::is_default_constructible_v); - EXPECT_TRUE(std::is_copy_constructible_v); - EXPECT_TRUE(std::is_move_constructible_v); - EXPECT_TRUE(std::is_copy_assignable_v); - EXPECT_TRUE(std::is_move_assignable_v); - EXPECT_TRUE(std::is_swappable_v); +TEST(Value, HandleSize) { + // Advisory test to ensure we attempt to keep the size of Value handles under + // 32 bytes. As of the time of writing they are 24 bytes. + EXPECT_LE(sizeof(base_internal::ValueHandleData), 32); +} + +TEST(Value, TransientHandleTypeTraits) { + EXPECT_TRUE(std::is_default_constructible_v>); + EXPECT_TRUE(std::is_copy_constructible_v>); + EXPECT_TRUE(std::is_move_constructible_v>); + EXPECT_TRUE(std::is_copy_assignable_v>); + EXPECT_TRUE(std::is_move_assignable_v>); + EXPECT_TRUE(std::is_swappable_v>); + EXPECT_TRUE(std::is_default_constructible_v>); + EXPECT_TRUE(std::is_copy_constructible_v>); + EXPECT_TRUE(std::is_move_constructible_v>); + EXPECT_TRUE(std::is_copy_assignable_v>); + EXPECT_TRUE(std::is_move_assignable_v>); + EXPECT_TRUE(std::is_swappable_v>); +} + +TEST(Value, PersistentHandleTypeTraits) { + EXPECT_TRUE(std::is_default_constructible_v>); + EXPECT_TRUE(std::is_copy_constructible_v>); + EXPECT_TRUE(std::is_move_constructible_v>); + EXPECT_TRUE(std::is_copy_assignable_v>); + EXPECT_TRUE(std::is_move_assignable_v>); + EXPECT_TRUE(std::is_swappable_v>); + EXPECT_TRUE(std::is_default_constructible_v>); + EXPECT_TRUE(std::is_copy_constructible_v>); + EXPECT_TRUE(std::is_move_constructible_v>); + EXPECT_TRUE(std::is_copy_assignable_v>); + EXPECT_TRUE(std::is_move_assignable_v>); + EXPECT_TRUE(std::is_swappable_v>); } TEST(Value, DefaultConstructor) { - Value value; - EXPECT_EQ(value, Value::Null()); + TestValueFactory value_factory; + Transient value; + EXPECT_EQ(value, value_factory.GetNullValue()); } struct ConstructionAssignmentTestCase final { std::string name; - std::function default_value; + std::function(ValueFactory&)> default_value; }; using ConstructionAssignmentTest = @@ -61,122 +109,171 @@ using ConstructionAssignmentTest = TEST_P(ConstructionAssignmentTest, CopyConstructor) { const auto& test_case = GetParam(); - Value from(test_case.default_value()); - Value to(from); + TestValueFactory value_factory; + Persistent from(test_case.default_value(value_factory)); + Persistent to(from); IS_INITIALIZED(to); - EXPECT_EQ(to, test_case.default_value()); + EXPECT_EQ(to, test_case.default_value(value_factory)); } TEST_P(ConstructionAssignmentTest, MoveConstructor) { const auto& test_case = GetParam(); - Value from(test_case.default_value()); - Value to(std::move(from)); + TestValueFactory value_factory; + Persistent from(test_case.default_value(value_factory)); + Persistent to(std::move(from)); IS_INITIALIZED(from); - EXPECT_EQ(from, Value::Null()); - EXPECT_EQ(to, test_case.default_value()); + EXPECT_EQ(from, value_factory.GetNullValue()); + EXPECT_EQ(to, test_case.default_value(value_factory)); } TEST_P(ConstructionAssignmentTest, CopyAssignment) { const auto& test_case = GetParam(); - Value from(test_case.default_value()); - Value to; + TestValueFactory value_factory; + Persistent from(test_case.default_value(value_factory)); + Persistent to; to = from; EXPECT_EQ(to, from); } TEST_P(ConstructionAssignmentTest, MoveAssignment) { const auto& test_case = GetParam(); - Value from(test_case.default_value()); - Value to; + TestValueFactory value_factory; + Persistent from(test_case.default_value(value_factory)); + Persistent to; to = std::move(from); IS_INITIALIZED(from); - EXPECT_EQ(from, Value::Null()); - EXPECT_EQ(to, test_case.default_value()); + EXPECT_EQ(from, value_factory.GetNullValue()); + EXPECT_EQ(to, test_case.default_value(value_factory)); } INSTANTIATE_TEST_SUITE_P( ConstructionAssignmentTest, ConstructionAssignmentTest, testing::ValuesIn({ - {"Null", Value::Null}, - {"Bool", Value::False}, - {"Int", []() { return Value::Int(0); }}, - {"Uint", []() { return Value::Uint(0); }}, - {"Double", []() { return Value::Double(0.0); }}, - {"Duration", []() { return Value::ZeroDuration(); }}, - {"Timestamp", []() { return Value::UnixEpoch(); }}, - {"Error", []() { return Value::Error(absl::CancelledError()); }}, - {"Bytes", Bytes::Empty}, + {"Null", + [](ValueFactory& value_factory) -> Persistent { + return value_factory.GetNullValue(); + }}, + {"Bool", + [](ValueFactory& value_factory) -> Persistent { + return value_factory.CreateBoolValue(false); + }}, + {"Int", + [](ValueFactory& value_factory) -> Persistent { + return value_factory.CreateIntValue(0); + }}, + {"Uint", + [](ValueFactory& value_factory) -> Persistent { + return value_factory.CreateUintValue(0); + }}, + {"Double", + [](ValueFactory& value_factory) -> Persistent { + return value_factory.CreateDoubleValue(0.0); + }}, + {"Duration", + [](ValueFactory& value_factory) -> Persistent { + return Must(value_factory.CreateDurationValue(absl::ZeroDuration())); + }}, + {"Timestamp", + [](ValueFactory& value_factory) -> Persistent { + return Must(value_factory.CreateTimestampValue(absl::UnixEpoch())); + }}, + {"Error", + [](ValueFactory& value_factory) -> Persistent { + return value_factory.CreateErrorValue(absl::CancelledError()); + }}, + {"Bytes", + [](ValueFactory& value_factory) -> Persistent { + return Must(value_factory.CreateBytesValue(0)); + }}, }), [](const testing::TestParamInfo& info) { return info.param.name; }); TEST(Value, Swap) { - Value lhs = Value::Int(0); - Value rhs = Value::Uint(0); + TestValueFactory value_factory; + Persistent lhs = value_factory.CreateIntValue(0); + Persistent rhs = value_factory.CreateUintValue(0); std::swap(lhs, rhs); - EXPECT_EQ(lhs, Value::Uint(0)); - EXPECT_EQ(rhs, Value::Int(0)); -} - -TEST(Value, NaN) { EXPECT_TRUE(std::isnan(Value::NaN().AsDouble())); } - -TEST(Value, PositiveInfinity) { - EXPECT_TRUE(std::isinf(Value::PositiveInfinity().AsDouble())); - EXPECT_FALSE(std::signbit(Value::PositiveInfinity().AsDouble())); -} - -TEST(Value, NegativeInfinity) { - EXPECT_TRUE(std::isinf(Value::NegativeInfinity().AsDouble())); - EXPECT_TRUE(std::signbit(Value::NegativeInfinity().AsDouble())); -} - -TEST(Value, ZeroDuration) { - EXPECT_EQ(Value::ZeroDuration().AsDuration(), absl::ZeroDuration()); -} - -TEST(Value, UnixEpoch) { - EXPECT_EQ(Value::UnixEpoch().AsTimestamp(), absl::UnixEpoch()); -} - -TEST(Null, DebugString) { EXPECT_EQ(Value::Null().DebugString(), "null"); } - -TEST(Bool, DebugString) { - EXPECT_EQ(Value::False().DebugString(), "false"); - EXPECT_EQ(Value::True().DebugString(), "true"); -} - -TEST(Int, DebugString) { - EXPECT_EQ(Value::Int(-1).DebugString(), "-1"); - EXPECT_EQ(Value::Int(0).DebugString(), "0"); - EXPECT_EQ(Value::Int(1).DebugString(), "1"); -} - -TEST(Uint, DebugString) { - EXPECT_EQ(Value::Uint(0).DebugString(), "0u"); - EXPECT_EQ(Value::Uint(1).DebugString(), "1u"); -} - -TEST(Double, DebugString) { - EXPECT_EQ(Value::Double(-1.0).DebugString(), "-1.0"); - EXPECT_EQ(Value::Double(0.0).DebugString(), "0.0"); - EXPECT_EQ(Value::Double(1.0).DebugString(), "1.0"); - EXPECT_EQ(Value::Double(-1.1).DebugString(), "-1.1"); - EXPECT_EQ(Value::Double(0.1).DebugString(), "0.1"); - EXPECT_EQ(Value::Double(1.1).DebugString(), "1.1"); - - EXPECT_EQ(Value::NaN().DebugString(), "nan"); - EXPECT_EQ(Value::PositiveInfinity().DebugString(), "+infinity"); - EXPECT_EQ(Value::NegativeInfinity().DebugString(), "-infinity"); -} - -TEST(Duration, DebugString) { - EXPECT_EQ(Value::ZeroDuration().DebugString(), + EXPECT_EQ(lhs, value_factory.CreateUintValue(0)); + EXPECT_EQ(rhs, value_factory.CreateIntValue(0)); +} + +TEST(NullValue, DebugString) { + TestValueFactory value_factory; + EXPECT_EQ(value_factory.GetNullValue()->DebugString(), "null"); +} + +TEST(BoolValue, DebugString) { + TestValueFactory value_factory; + EXPECT_EQ(value_factory.CreateBoolValue(false)->DebugString(), "false"); + EXPECT_EQ(value_factory.CreateBoolValue(true)->DebugString(), "true"); +} + +TEST(IntValue, DebugString) { + TestValueFactory value_factory; + EXPECT_EQ(value_factory.CreateIntValue(-1)->DebugString(), "-1"); + EXPECT_EQ(value_factory.CreateIntValue(0)->DebugString(), "0"); + EXPECT_EQ(value_factory.CreateIntValue(1)->DebugString(), "1"); + EXPECT_EQ(value_factory.CreateIntValue(std::numeric_limits::min()) + ->DebugString(), + "-9223372036854775808"); + EXPECT_EQ(value_factory.CreateIntValue(std::numeric_limits::max()) + ->DebugString(), + "9223372036854775807"); +} + +TEST(UintValue, DebugString) { + TestValueFactory value_factory; + EXPECT_EQ(value_factory.CreateUintValue(0)->DebugString(), "0u"); + EXPECT_EQ(value_factory.CreateUintValue(1)->DebugString(), "1u"); + EXPECT_EQ(value_factory.CreateUintValue(std::numeric_limits::max()) + ->DebugString(), + "18446744073709551615u"); +} + +TEST(DoubleValue, DebugString) { + TestValueFactory value_factory; + EXPECT_EQ(value_factory.CreateDoubleValue(-1.0)->DebugString(), "-1.0"); + EXPECT_EQ(value_factory.CreateDoubleValue(0.0)->DebugString(), "0.0"); + EXPECT_EQ(value_factory.CreateDoubleValue(1.0)->DebugString(), "1.0"); + EXPECT_EQ(value_factory.CreateDoubleValue(-1.1)->DebugString(), "-1.1"); + EXPECT_EQ(value_factory.CreateDoubleValue(0.1)->DebugString(), "0.1"); + EXPECT_EQ(value_factory.CreateDoubleValue(1.1)->DebugString(), "1.1"); + EXPECT_EQ(value_factory.CreateDoubleValue(-9007199254740991.0)->DebugString(), + "-9.0072e+15"); + EXPECT_EQ(value_factory.CreateDoubleValue(9007199254740991.0)->DebugString(), + "9.0072e+15"); + EXPECT_EQ(value_factory.CreateDoubleValue(-9007199254740991.1)->DebugString(), + "-9.0072e+15"); + EXPECT_EQ(value_factory.CreateDoubleValue(9007199254740991.1)->DebugString(), + "9.0072e+15"); + EXPECT_EQ(value_factory.CreateDoubleValue(9007199254740991.1)->DebugString(), + "9.0072e+15"); + + EXPECT_EQ( + value_factory.CreateDoubleValue(std::numeric_limits::quiet_NaN()) + ->DebugString(), + "nan"); + EXPECT_EQ( + value_factory.CreateDoubleValue(std::numeric_limits::infinity()) + ->DebugString(), + "+infinity"); + EXPECT_EQ( + value_factory.CreateDoubleValue(-std::numeric_limits::infinity()) + ->DebugString(), + "-infinity"); +} + +TEST(DurationValue, DebugString) { + TestValueFactory value_factory; + EXPECT_EQ(DurationValue::Zero(value_factory)->DebugString(), internal::FormatDuration(absl::ZeroDuration()).value()); } -TEST(Timestamp, DebugString) { - EXPECT_EQ(Value::UnixEpoch().DebugString(), +TEST(TimestampValue, DebugString) { + TestValueFactory value_factory; + EXPECT_EQ(TimestampValue::UnixEpoch(value_factory)->DebugString(), internal::FormatTimestamp(absl::UnixEpoch()).value()); } @@ -185,238 +282,317 @@ TEST(Timestamp, DebugString) { // feature is not available in C++17. TEST(Value, Error) { - Value error_value = Value::Error(absl::CancelledError()); - EXPECT_TRUE(error_value.IsError()); + TestValueFactory value_factory; + TestTypeFactory type_factory; + auto error_value = value_factory.CreateErrorValue(absl::CancelledError()); + EXPECT_TRUE(error_value.Is()); + EXPECT_FALSE(error_value.Is()); EXPECT_EQ(error_value, error_value); - EXPECT_EQ(error_value, Value::Error(absl::CancelledError())); - EXPECT_EQ(error_value.AsError(), absl::CancelledError()); + EXPECT_EQ(error_value, + value_factory.CreateErrorValue(absl::CancelledError())); + EXPECT_EQ(error_value->value(), absl::CancelledError()); } TEST(Value, Bool) { - Value false_value = Value::False(); - EXPECT_TRUE(false_value.IsBool()); + TestValueFactory value_factory; + TestTypeFactory type_factory; + auto false_value = BoolValue::False(value_factory); + EXPECT_TRUE(false_value.Is()); + EXPECT_FALSE(false_value.Is()); EXPECT_EQ(false_value, false_value); - EXPECT_EQ(false_value, Value::Bool(false)); - EXPECT_EQ(false_value.kind(), Kind::kBool); - EXPECT_EQ(false_value.type(), Type::Bool()); - EXPECT_FALSE(false_value.AsBool()); - - Value true_value = Value::True(); - EXPECT_TRUE(true_value.IsBool()); + EXPECT_EQ(false_value, value_factory.CreateBoolValue(false)); + EXPECT_EQ(false_value->kind(), Kind::kBool); + EXPECT_EQ(false_value->type(), type_factory.GetBoolType()); + EXPECT_FALSE(false_value->value()); + + auto true_value = BoolValue::True(value_factory); + EXPECT_TRUE(true_value.Is()); + EXPECT_FALSE(true_value.Is()); EXPECT_EQ(true_value, true_value); - EXPECT_EQ(true_value, Value::Bool(true)); - EXPECT_EQ(true_value.kind(), Kind::kBool); - EXPECT_EQ(true_value.type(), Type::Bool()); - EXPECT_TRUE(true_value.AsBool()); + EXPECT_EQ(true_value, value_factory.CreateBoolValue(true)); + EXPECT_EQ(true_value->kind(), Kind::kBool); + EXPECT_EQ(true_value->type(), type_factory.GetBoolType()); + EXPECT_TRUE(true_value->value()); EXPECT_NE(false_value, true_value); EXPECT_NE(true_value, false_value); } TEST(Value, Int) { - Value zero_value = Value::Int(0); - EXPECT_TRUE(zero_value.IsInt()); + TestValueFactory value_factory; + TestTypeFactory type_factory; + auto zero_value = value_factory.CreateIntValue(0); + EXPECT_TRUE(zero_value.Is()); + EXPECT_FALSE(zero_value.Is()); EXPECT_EQ(zero_value, zero_value); - EXPECT_EQ(zero_value, Value::Int(0)); - EXPECT_EQ(zero_value.kind(), Kind::kInt); - EXPECT_EQ(zero_value.type(), Type::Int()); - EXPECT_EQ(zero_value.AsInt(), 0); - - Value one_value = Value::Int(1); - EXPECT_TRUE(one_value.IsInt()); + EXPECT_EQ(zero_value, value_factory.CreateIntValue(0)); + EXPECT_EQ(zero_value->kind(), Kind::kInt); + EXPECT_EQ(zero_value->type(), type_factory.GetIntType()); + EXPECT_EQ(zero_value->value(), 0); + + auto one_value = value_factory.CreateIntValue(1); + EXPECT_TRUE(one_value.Is()); + EXPECT_FALSE(one_value.Is()); EXPECT_EQ(one_value, one_value); - EXPECT_EQ(one_value, Value::Int(1)); - EXPECT_EQ(one_value.kind(), Kind::kInt); - EXPECT_EQ(one_value.type(), Type::Int()); - EXPECT_EQ(one_value.AsInt(), 1); + EXPECT_EQ(one_value, value_factory.CreateIntValue(1)); + EXPECT_EQ(one_value->kind(), Kind::kInt); + EXPECT_EQ(one_value->type(), type_factory.GetIntType()); + EXPECT_EQ(one_value->value(), 1); EXPECT_NE(zero_value, one_value); EXPECT_NE(one_value, zero_value); } TEST(Value, Uint) { - Value zero_value = Value::Uint(0); - EXPECT_TRUE(zero_value.IsUint()); + TestValueFactory value_factory; + TestTypeFactory type_factory; + auto zero_value = value_factory.CreateUintValue(0); + EXPECT_TRUE(zero_value.Is()); + EXPECT_FALSE(zero_value.Is()); EXPECT_EQ(zero_value, zero_value); - EXPECT_EQ(zero_value, Value::Uint(0)); - EXPECT_EQ(zero_value.kind(), Kind::kUint); - EXPECT_EQ(zero_value.type(), Type::Uint()); - EXPECT_EQ(zero_value.AsUint(), 0); - - Value one_value = Value::Uint(1); - EXPECT_TRUE(one_value.IsUint()); + EXPECT_EQ(zero_value, value_factory.CreateUintValue(0)); + EXPECT_EQ(zero_value->kind(), Kind::kUint); + EXPECT_EQ(zero_value->type(), type_factory.GetUintType()); + EXPECT_EQ(zero_value->value(), 0); + + auto one_value = value_factory.CreateUintValue(1); + EXPECT_TRUE(one_value.Is()); + EXPECT_FALSE(one_value.Is()); EXPECT_EQ(one_value, one_value); - EXPECT_EQ(one_value, Value::Uint(1)); - EXPECT_EQ(one_value.kind(), Kind::kUint); - EXPECT_EQ(one_value.type(), Type::Uint()); - EXPECT_EQ(one_value.AsUint(), 1); + EXPECT_EQ(one_value, value_factory.CreateUintValue(1)); + EXPECT_EQ(one_value->kind(), Kind::kUint); + EXPECT_EQ(one_value->type(), type_factory.GetUintType()); + EXPECT_EQ(one_value->value(), 1); EXPECT_NE(zero_value, one_value); EXPECT_NE(one_value, zero_value); } TEST(Value, Double) { - Value zero_value = Value::Double(0.0); - EXPECT_TRUE(zero_value.IsDouble()); + TestValueFactory value_factory; + TestTypeFactory type_factory; + auto zero_value = value_factory.CreateDoubleValue(0.0); + EXPECT_TRUE(zero_value.Is()); + EXPECT_FALSE(zero_value.Is()); EXPECT_EQ(zero_value, zero_value); - EXPECT_EQ(zero_value, Value::Double(0.0)); - EXPECT_EQ(zero_value.kind(), Kind::kDouble); - EXPECT_EQ(zero_value.type(), Type::Double()); - EXPECT_EQ(zero_value.AsDouble(), 0.0); - - Value one_value = Value::Double(1.0); - EXPECT_TRUE(one_value.IsDouble()); + EXPECT_EQ(zero_value, value_factory.CreateDoubleValue(0.0)); + EXPECT_EQ(zero_value->kind(), Kind::kDouble); + EXPECT_EQ(zero_value->type(), type_factory.GetDoubleType()); + EXPECT_EQ(zero_value->value(), 0.0); + + auto one_value = value_factory.CreateDoubleValue(1.0); + EXPECT_TRUE(one_value.Is()); + EXPECT_FALSE(one_value.Is()); EXPECT_EQ(one_value, one_value); - EXPECT_EQ(one_value, Value::Double(1.0)); - EXPECT_EQ(one_value.kind(), Kind::kDouble); - EXPECT_EQ(one_value.type(), Type::Double()); - EXPECT_EQ(one_value.AsDouble(), 1.0); + EXPECT_EQ(one_value, value_factory.CreateDoubleValue(1.0)); + EXPECT_EQ(one_value->kind(), Kind::kDouble); + EXPECT_EQ(one_value->type(), type_factory.GetDoubleType()); + EXPECT_EQ(one_value->value(), 1.0); EXPECT_NE(zero_value, one_value); EXPECT_NE(one_value, zero_value); } TEST(Value, Duration) { - Value zero_value = Value::ZeroDuration(); - EXPECT_TRUE(zero_value.IsDuration()); + TestValueFactory value_factory; + TestTypeFactory type_factory; + auto zero_value = + Must(value_factory.CreateDurationValue(absl::ZeroDuration())); + EXPECT_TRUE(zero_value.Is()); + EXPECT_FALSE(zero_value.Is()); EXPECT_EQ(zero_value, zero_value); - EXPECT_EQ(zero_value, Value::ZeroDuration()); - EXPECT_EQ(zero_value.kind(), Kind::kDuration); - EXPECT_EQ(zero_value.type(), Type::Duration()); - EXPECT_EQ(zero_value.AsDuration(), absl::ZeroDuration()); - - ASSERT_OK_AND_ASSIGN(Value one_value, Value::Duration(absl::ZeroDuration() + - absl::Nanoseconds(1))); - EXPECT_TRUE(one_value.IsDuration()); + EXPECT_EQ(zero_value, + Must(value_factory.CreateDurationValue(absl::ZeroDuration()))); + EXPECT_EQ(zero_value->kind(), Kind::kDuration); + EXPECT_EQ(zero_value->type(), type_factory.GetDurationType()); + EXPECT_EQ(zero_value->value(), absl::ZeroDuration()); + + auto one_value = Must(value_factory.CreateDurationValue( + absl::ZeroDuration() + absl::Nanoseconds(1))); + EXPECT_TRUE(one_value.Is()); + EXPECT_FALSE(one_value.Is()); EXPECT_EQ(one_value, one_value); - EXPECT_EQ(one_value.kind(), Kind::kDuration); - EXPECT_EQ(one_value.type(), Type::Duration()); - EXPECT_EQ(one_value.AsDuration(), - absl::ZeroDuration() + absl::Nanoseconds(1)); + EXPECT_EQ(one_value->kind(), Kind::kDuration); + EXPECT_EQ(one_value->type(), type_factory.GetDurationType()); + EXPECT_EQ(one_value->value(), absl::ZeroDuration() + absl::Nanoseconds(1)); EXPECT_NE(zero_value, one_value); EXPECT_NE(one_value, zero_value); - EXPECT_THAT(Value::Duration(absl::InfiniteDuration()), + EXPECT_THAT(value_factory.CreateDurationValue(absl::InfiniteDuration()), StatusIs(absl::StatusCode::kInvalidArgument)); } TEST(Value, Timestamp) { - Value zero_value = Value::UnixEpoch(); - EXPECT_TRUE(zero_value.IsTimestamp()); + TestValueFactory value_factory; + TestTypeFactory type_factory; + auto zero_value = Must(value_factory.CreateTimestampValue(absl::UnixEpoch())); + EXPECT_TRUE(zero_value.Is()); + EXPECT_FALSE(zero_value.Is()); EXPECT_EQ(zero_value, zero_value); - EXPECT_EQ(zero_value, Value::UnixEpoch()); - EXPECT_EQ(zero_value.kind(), Kind::kTimestamp); - EXPECT_EQ(zero_value.type(), Type::Timestamp()); - EXPECT_EQ(zero_value.AsTimestamp(), absl::UnixEpoch()); - - ASSERT_OK_AND_ASSIGN(Value one_value, Value::Timestamp(absl::UnixEpoch() + - absl::Nanoseconds(1))); - EXPECT_TRUE(one_value.IsTimestamp()); + EXPECT_EQ(zero_value, + Must(value_factory.CreateTimestampValue(absl::UnixEpoch()))); + EXPECT_EQ(zero_value->kind(), Kind::kTimestamp); + EXPECT_EQ(zero_value->type(), type_factory.GetTimestampType()); + EXPECT_EQ(zero_value->value(), absl::UnixEpoch()); + + auto one_value = Must(value_factory.CreateTimestampValue( + absl::UnixEpoch() + absl::Nanoseconds(1))); + EXPECT_TRUE(one_value.Is()); + EXPECT_FALSE(one_value.Is()); EXPECT_EQ(one_value, one_value); - EXPECT_EQ(one_value.kind(), Kind::kTimestamp); - EXPECT_EQ(one_value.type(), Type::Timestamp()); - EXPECT_EQ(one_value.AsTimestamp(), absl::UnixEpoch() + absl::Nanoseconds(1)); + EXPECT_EQ(one_value->kind(), Kind::kTimestamp); + EXPECT_EQ(one_value->type(), type_factory.GetTimestampType()); + EXPECT_EQ(one_value->value(), absl::UnixEpoch() + absl::Nanoseconds(1)); EXPECT_NE(zero_value, one_value); EXPECT_NE(one_value, zero_value); - EXPECT_THAT(Value::Timestamp(absl::InfiniteFuture()), + EXPECT_THAT(value_factory.CreateTimestampValue(absl::InfiniteFuture()), StatusIs(absl::StatusCode::kInvalidArgument)); } TEST(Value, BytesFromString) { - Value zero_value = Bytes::New(std::string("0")); - EXPECT_TRUE(zero_value.IsBytes()); + TestValueFactory value_factory; + TestTypeFactory type_factory; + auto zero_value = Must(value_factory.CreateBytesValue(std::string("0"))); + EXPECT_TRUE(zero_value.Is()); + EXPECT_FALSE(zero_value.Is()); EXPECT_EQ(zero_value, zero_value); - EXPECT_EQ(zero_value, Bytes::New(std::string("0"))); - EXPECT_EQ(zero_value.kind(), Kind::kBytes); - EXPECT_EQ(zero_value.type(), Type::Bytes()); - EXPECT_EQ(zero_value.AsBytes().ToString(), "0"); - - Value one_value = Bytes::New(std::string("1")); - EXPECT_TRUE(one_value.IsBytes()); + EXPECT_EQ(zero_value, Must(value_factory.CreateBytesValue(std::string("0")))); + EXPECT_EQ(zero_value->kind(), Kind::kBytes); + EXPECT_EQ(zero_value->type(), type_factory.GetBytesType()); + EXPECT_EQ(zero_value->ToString(), "0"); + + auto one_value = Must(value_factory.CreateBytesValue(std::string("1"))); + EXPECT_TRUE(one_value.Is()); + EXPECT_FALSE(one_value.Is()); EXPECT_EQ(one_value, one_value); - EXPECT_EQ(one_value, Bytes::New(std::string("1"))); - EXPECT_EQ(one_value.kind(), Kind::kBytes); - EXPECT_EQ(one_value.type(), Type::Bytes()); - EXPECT_EQ(one_value.AsBytes().ToString(), "1"); + EXPECT_EQ(one_value, Must(value_factory.CreateBytesValue(std::string("1")))); + EXPECT_EQ(one_value->kind(), Kind::kBytes); + EXPECT_EQ(one_value->type(), type_factory.GetBytesType()); + EXPECT_EQ(one_value->ToString(), "1"); EXPECT_NE(zero_value, one_value); EXPECT_NE(one_value, zero_value); } TEST(Value, BytesFromStringView) { - Value zero_value = Bytes::New(absl::string_view("0")); - EXPECT_TRUE(zero_value.IsBytes()); + TestValueFactory value_factory; + TestTypeFactory type_factory; + auto zero_value = + Must(value_factory.CreateBytesValue(absl::string_view("0"))); + EXPECT_TRUE(zero_value.Is()); + EXPECT_FALSE(zero_value.Is()); EXPECT_EQ(zero_value, zero_value); - EXPECT_EQ(zero_value, Bytes::New(absl::string_view("0"))); - EXPECT_EQ(zero_value.kind(), Kind::kBytes); - EXPECT_EQ(zero_value.type(), Type::Bytes()); - EXPECT_EQ(zero_value.AsBytes().ToString(), "0"); - - Value one_value = Bytes::New(absl::string_view("1")); - EXPECT_TRUE(one_value.IsBytes()); + EXPECT_EQ(zero_value, + Must(value_factory.CreateBytesValue(absl::string_view("0")))); + EXPECT_EQ(zero_value->kind(), Kind::kBytes); + EXPECT_EQ(zero_value->type(), type_factory.GetBytesType()); + EXPECT_EQ(zero_value->ToString(), "0"); + + auto one_value = Must(value_factory.CreateBytesValue(absl::string_view("1"))); + EXPECT_TRUE(one_value.Is()); + EXPECT_FALSE(one_value.Is()); EXPECT_EQ(one_value, one_value); - EXPECT_EQ(one_value, Bytes::New(absl::string_view("1"))); - EXPECT_EQ(one_value.kind(), Kind::kBytes); - EXPECT_EQ(one_value.type(), Type::Bytes()); - EXPECT_EQ(one_value.AsBytes().ToString(), "1"); + EXPECT_EQ(one_value, + Must(value_factory.CreateBytesValue(absl::string_view("1")))); + EXPECT_EQ(one_value->kind(), Kind::kBytes); + EXPECT_EQ(one_value->type(), type_factory.GetBytesType()); + EXPECT_EQ(one_value->ToString(), "1"); EXPECT_NE(zero_value, one_value); EXPECT_NE(one_value, zero_value); } TEST(Value, BytesFromCord) { - Value zero_value = Bytes::New(absl::Cord("0")); - EXPECT_TRUE(zero_value.IsBytes()); + TestValueFactory value_factory; + TestTypeFactory type_factory; + auto zero_value = Must(value_factory.CreateBytesValue(absl::Cord("0"))); + EXPECT_TRUE(zero_value.Is()); + EXPECT_FALSE(zero_value.Is()); EXPECT_EQ(zero_value, zero_value); - EXPECT_EQ(zero_value, Bytes::New(absl::Cord("0"))); - EXPECT_EQ(zero_value.kind(), Kind::kBytes); - EXPECT_EQ(zero_value.type(), Type::Bytes()); - EXPECT_EQ(zero_value.AsBytes().ToCord(), "0"); - - Value one_value = Bytes::New(absl::Cord("1")); - EXPECT_TRUE(one_value.IsBytes()); + EXPECT_EQ(zero_value, Must(value_factory.CreateBytesValue(absl::Cord("0")))); + EXPECT_EQ(zero_value->kind(), Kind::kBytes); + EXPECT_EQ(zero_value->type(), type_factory.GetBytesType()); + EXPECT_EQ(zero_value->ToCord(), "0"); + + auto one_value = Must(value_factory.CreateBytesValue(absl::Cord("1"))); + EXPECT_TRUE(one_value.Is()); + EXPECT_FALSE(one_value.Is()); EXPECT_EQ(one_value, one_value); - EXPECT_EQ(one_value, Bytes::New(absl::Cord("1"))); - EXPECT_EQ(one_value.kind(), Kind::kBytes); - EXPECT_EQ(one_value.type(), Type::Bytes()); - EXPECT_EQ(one_value.AsBytes().ToCord(), "1"); + EXPECT_EQ(one_value, Must(value_factory.CreateBytesValue(absl::Cord("1")))); + EXPECT_EQ(one_value->kind(), Kind::kBytes); + EXPECT_EQ(one_value->type(), type_factory.GetBytesType()); + EXPECT_EQ(one_value->ToCord(), "1"); EXPECT_NE(zero_value, one_value); EXPECT_NE(one_value, zero_value); } TEST(Value, BytesFromLiteral) { - Value zero_value = Bytes::New("0"); - EXPECT_TRUE(zero_value.IsBytes()); + TestValueFactory value_factory; + TestTypeFactory type_factory; + auto zero_value = Must(value_factory.CreateBytesValue("0")); + EXPECT_TRUE(zero_value.Is()); + EXPECT_FALSE(zero_value.Is()); EXPECT_EQ(zero_value, zero_value); - EXPECT_EQ(zero_value, Bytes::New("0")); - EXPECT_EQ(zero_value.kind(), Kind::kBytes); - EXPECT_EQ(zero_value.type(), Type::Bytes()); - EXPECT_EQ(zero_value.AsBytes().ToString(), "0"); + EXPECT_EQ(zero_value, Must(value_factory.CreateBytesValue("0"))); + EXPECT_EQ(zero_value->kind(), Kind::kBytes); + EXPECT_EQ(zero_value->type(), type_factory.GetBytesType()); + EXPECT_EQ(zero_value->ToString(), "0"); + + auto one_value = Must(value_factory.CreateBytesValue("1")); + EXPECT_TRUE(one_value.Is()); + EXPECT_FALSE(one_value.Is()); + EXPECT_EQ(one_value, one_value); + EXPECT_EQ(one_value, Must(value_factory.CreateBytesValue("1"))); + EXPECT_EQ(one_value->kind(), Kind::kBytes); + EXPECT_EQ(one_value->type(), type_factory.GetBytesType()); + EXPECT_EQ(one_value->ToString(), "1"); - Value one_value = Bytes::New("1"); - EXPECT_TRUE(one_value.IsBytes()); + EXPECT_NE(zero_value, one_value); + EXPECT_NE(one_value, zero_value); +} + +TEST(Value, BytesFromExternal) { + TestValueFactory value_factory; + TestTypeFactory type_factory; + auto zero_value = Must(value_factory.CreateBytesValue("0", []() {})); + EXPECT_TRUE(zero_value.Is()); + EXPECT_FALSE(zero_value.Is()); + EXPECT_EQ(zero_value, zero_value); + EXPECT_EQ(zero_value, Must(value_factory.CreateBytesValue("0", []() {}))); + EXPECT_EQ(zero_value->kind(), Kind::kBytes); + EXPECT_EQ(zero_value->type(), type_factory.GetBytesType()); + EXPECT_EQ(zero_value->ToString(), "0"); + + auto one_value = Must(value_factory.CreateBytesValue("1", []() {})); + EXPECT_TRUE(one_value.Is()); + EXPECT_FALSE(one_value.Is()); EXPECT_EQ(one_value, one_value); - EXPECT_EQ(one_value, Bytes::New("1")); - EXPECT_EQ(one_value.kind(), Kind::kBytes); - EXPECT_EQ(one_value.type(), Type::Bytes()); - EXPECT_EQ(one_value.AsBytes().ToString(), "1"); + EXPECT_EQ(one_value, Must(value_factory.CreateBytesValue("1", []() {}))); + EXPECT_EQ(one_value->kind(), Kind::kBytes); + EXPECT_EQ(one_value->type(), type_factory.GetBytesType()); + EXPECT_EQ(one_value->ToString(), "1"); EXPECT_NE(zero_value, one_value); EXPECT_NE(one_value, zero_value); } -Value MakeStringBytes(absl::string_view value) { return Bytes::New(value); } +Persistent MakeStringBytes(ValueFactory& value_factory, + absl::string_view value) { + return Must(value_factory.CreateBytesValue(value)); +} -Value MakeCordBytes(absl::string_view value) { - return Bytes::New(absl::Cord(value)); +Persistent MakeCordBytes(ValueFactory& value_factory, + absl::string_view value) { + return Must(value_factory.CreateBytesValue(absl::Cord(value))); } -Value MakeWrappedBytes(absl::string_view value) { - return Bytes::Wrap(value, []() {}); +Persistent MakeExternalBytes(ValueFactory& value_factory, + absl::string_view value) { + return Must(value_factory.CreateBytesValue(value, []() {})); } struct BytesConcatTestCase final { @@ -428,42 +604,52 @@ using BytesConcatTest = testing::TestWithParam; TEST_P(BytesConcatTest, Concat) { const BytesConcatTestCase& test_case = GetParam(); - EXPECT_TRUE(Bytes::Concat(MakeStringBytes(test_case.lhs).AsBytes(), - MakeStringBytes(test_case.rhs).AsBytes()) - .AsBytes() - .Equals(test_case.lhs + test_case.rhs)); - EXPECT_TRUE(Bytes::Concat(MakeStringBytes(test_case.lhs).AsBytes(), - MakeCordBytes(test_case.rhs).AsBytes()) - .AsBytes() - .Equals(test_case.lhs + test_case.rhs)); - EXPECT_TRUE(Bytes::Concat(MakeStringBytes(test_case.lhs).AsBytes(), - MakeWrappedBytes(test_case.rhs).AsBytes()) - .AsBytes() - .Equals(test_case.lhs + test_case.rhs)); - EXPECT_TRUE(Bytes::Concat(MakeCordBytes(test_case.lhs).AsBytes(), - MakeStringBytes(test_case.rhs).AsBytes()) - .AsBytes() - .Equals(test_case.lhs + test_case.rhs)); - EXPECT_TRUE(Bytes::Concat(MakeCordBytes(test_case.lhs).AsBytes(), - MakeWrappedBytes(test_case.rhs).AsBytes()) - .AsBytes() - .Equals(test_case.lhs + test_case.rhs)); - EXPECT_TRUE(Bytes::Concat(MakeCordBytes(test_case.lhs).AsBytes(), - MakeCordBytes(test_case.rhs).AsBytes()) - .AsBytes() - .Equals(test_case.lhs + test_case.rhs)); - EXPECT_TRUE(Bytes::Concat(MakeWrappedBytes(test_case.lhs).AsBytes(), - MakeStringBytes(test_case.rhs).AsBytes()) - .AsBytes() - .Equals(test_case.lhs + test_case.rhs)); - EXPECT_TRUE(Bytes::Concat(MakeWrappedBytes(test_case.lhs).AsBytes(), - MakeCordBytes(test_case.rhs).AsBytes()) - .AsBytes() - .Equals(test_case.lhs + test_case.rhs)); - EXPECT_TRUE(Bytes::Concat(MakeWrappedBytes(test_case.lhs).AsBytes(), - MakeWrappedBytes(test_case.rhs).AsBytes()) - .AsBytes() - .Equals(test_case.lhs + test_case.rhs)); + TestValueFactory value_factory; + EXPECT_TRUE( + Must(BytesValue::Concat(value_factory, + MakeStringBytes(value_factory, test_case.lhs), + MakeStringBytes(value_factory, test_case.rhs))) + ->Equals(test_case.lhs + test_case.rhs)); + EXPECT_TRUE( + Must(BytesValue::Concat(value_factory, + MakeStringBytes(value_factory, test_case.lhs), + MakeCordBytes(value_factory, test_case.rhs))) + ->Equals(test_case.lhs + test_case.rhs)); + EXPECT_TRUE( + Must(BytesValue::Concat(value_factory, + MakeStringBytes(value_factory, test_case.lhs), + MakeExternalBytes(value_factory, test_case.rhs))) + ->Equals(test_case.lhs + test_case.rhs)); + EXPECT_TRUE( + Must(BytesValue::Concat(value_factory, + MakeCordBytes(value_factory, test_case.lhs), + MakeStringBytes(value_factory, test_case.rhs))) + ->Equals(test_case.lhs + test_case.rhs)); + EXPECT_TRUE( + Must(BytesValue::Concat(value_factory, + MakeCordBytes(value_factory, test_case.lhs), + MakeCordBytes(value_factory, test_case.rhs))) + ->Equals(test_case.lhs + test_case.rhs)); + EXPECT_TRUE( + Must(BytesValue::Concat(value_factory, + MakeCordBytes(value_factory, test_case.lhs), + MakeExternalBytes(value_factory, test_case.rhs))) + ->Equals(test_case.lhs + test_case.rhs)); + EXPECT_TRUE( + Must(BytesValue::Concat(value_factory, + MakeExternalBytes(value_factory, test_case.lhs), + MakeStringBytes(value_factory, test_case.rhs))) + ->Equals(test_case.lhs + test_case.rhs)); + EXPECT_TRUE( + Must(BytesValue::Concat(value_factory, + MakeExternalBytes(value_factory, test_case.lhs), + MakeCordBytes(value_factory, test_case.rhs))) + ->Equals(test_case.lhs + test_case.rhs)); + EXPECT_TRUE( + Must(BytesValue::Concat(value_factory, + MakeExternalBytes(value_factory, test_case.lhs), + MakeExternalBytes(value_factory, test_case.rhs))) + ->Equals(test_case.lhs + test_case.rhs)); } INSTANTIATE_TEST_SUITE_P(BytesConcatTest, BytesConcatTest, @@ -489,9 +675,13 @@ using BytesSizeTest = testing::TestWithParam; TEST_P(BytesSizeTest, Size) { const BytesSizeTestCase& test_case = GetParam(); - EXPECT_EQ(MakeStringBytes(test_case.data).AsBytes().size(), test_case.size); - EXPECT_EQ(MakeCordBytes(test_case.data).AsBytes().size(), test_case.size); - EXPECT_EQ(MakeWrappedBytes(test_case.data).AsBytes().size(), test_case.size); + TestValueFactory value_factory; + EXPECT_EQ(MakeStringBytes(value_factory, test_case.data)->size(), + test_case.size); + EXPECT_EQ(MakeCordBytes(value_factory, test_case.data)->size(), + test_case.size); + EXPECT_EQ(MakeExternalBytes(value_factory, test_case.data)->size(), + test_case.size); } INSTANTIATE_TEST_SUITE_P(BytesSizeTest, BytesSizeTest, @@ -511,9 +701,12 @@ using BytesEmptyTest = testing::TestWithParam; TEST_P(BytesEmptyTest, Empty) { const BytesEmptyTestCase& test_case = GetParam(); - EXPECT_EQ(MakeStringBytes(test_case.data).AsBytes().empty(), test_case.empty); - EXPECT_EQ(MakeCordBytes(test_case.data).AsBytes().empty(), test_case.empty); - EXPECT_EQ(MakeWrappedBytes(test_case.data).AsBytes().empty(), + TestValueFactory value_factory; + EXPECT_EQ(MakeStringBytes(value_factory, test_case.data)->empty(), + test_case.empty); + EXPECT_EQ(MakeCordBytes(value_factory, test_case.data)->empty(), + test_case.empty); + EXPECT_EQ(MakeExternalBytes(value_factory, test_case.data)->empty(), test_case.empty); } @@ -534,41 +727,33 @@ using BytesEqualsTest = testing::TestWithParam; TEST_P(BytesEqualsTest, Equals) { const BytesEqualsTestCase& test_case = GetParam(); - EXPECT_EQ(MakeStringBytes(test_case.lhs) - .AsBytes() - .Equals(MakeStringBytes(test_case.rhs).AsBytes()), + TestValueFactory value_factory; + EXPECT_EQ(MakeStringBytes(value_factory, test_case.lhs) + ->Equals(MakeStringBytes(value_factory, test_case.rhs)), test_case.equals); - EXPECT_EQ(MakeStringBytes(test_case.lhs) - .AsBytes() - .Equals(MakeCordBytes(test_case.rhs).AsBytes()), + EXPECT_EQ(MakeStringBytes(value_factory, test_case.lhs) + ->Equals(MakeCordBytes(value_factory, test_case.rhs)), test_case.equals); - EXPECT_EQ(MakeStringBytes(test_case.lhs) - .AsBytes() - .Equals(MakeWrappedBytes(test_case.rhs).AsBytes()), + EXPECT_EQ(MakeStringBytes(value_factory, test_case.lhs) + ->Equals(MakeExternalBytes(value_factory, test_case.rhs)), test_case.equals); - EXPECT_EQ(MakeCordBytes(test_case.lhs) - .AsBytes() - .Equals(MakeStringBytes(test_case.rhs).AsBytes()), + EXPECT_EQ(MakeCordBytes(value_factory, test_case.lhs) + ->Equals(MakeStringBytes(value_factory, test_case.rhs)), test_case.equals); - EXPECT_EQ(MakeCordBytes(test_case.lhs) - .AsBytes() - .Equals(MakeWrappedBytes(test_case.rhs).AsBytes()), + EXPECT_EQ(MakeCordBytes(value_factory, test_case.lhs) + ->Equals(MakeCordBytes(value_factory, test_case.rhs)), test_case.equals); - EXPECT_EQ(MakeCordBytes(test_case.lhs) - .AsBytes() - .Equals(MakeCordBytes(test_case.rhs).AsBytes()), + EXPECT_EQ(MakeCordBytes(value_factory, test_case.lhs) + ->Equals(MakeExternalBytes(value_factory, test_case.rhs)), test_case.equals); - EXPECT_EQ(MakeWrappedBytes(test_case.lhs) - .AsBytes() - .Equals(MakeStringBytes(test_case.rhs).AsBytes()), + EXPECT_EQ(MakeExternalBytes(value_factory, test_case.lhs) + ->Equals(MakeStringBytes(value_factory, test_case.rhs)), test_case.equals); - EXPECT_EQ(MakeWrappedBytes(test_case.lhs) - .AsBytes() - .Equals(MakeCordBytes(test_case.rhs).AsBytes()), + EXPECT_EQ(MakeExternalBytes(value_factory, test_case.lhs) + ->Equals(MakeCordBytes(value_factory, test_case.rhs)), test_case.equals); - EXPECT_EQ(MakeWrappedBytes(test_case.lhs) - .AsBytes() - .Equals(MakeWrappedBytes(test_case.rhs).AsBytes()), + EXPECT_EQ(MakeExternalBytes(value_factory, test_case.lhs) + ->Equals(MakeExternalBytes(value_factory, test_case.rhs)), test_case.equals); } @@ -598,50 +783,42 @@ int NormalizeCompareResult(int compare) { return std::clamp(compare, -1, 1); } TEST_P(BytesCompareTest, Equals) { const BytesCompareTestCase& test_case = GetParam(); + TestValueFactory value_factory; EXPECT_EQ(NormalizeCompareResult( - MakeStringBytes(test_case.lhs) - .AsBytes() - .Compare(MakeStringBytes(test_case.rhs).AsBytes())), + MakeStringBytes(value_factory, test_case.lhs) + ->Compare(MakeStringBytes(value_factory, test_case.rhs))), test_case.compare); EXPECT_EQ(NormalizeCompareResult( - MakeStringBytes(test_case.lhs) - .AsBytes() - .Compare(MakeCordBytes(test_case.rhs).AsBytes())), + MakeStringBytes(value_factory, test_case.lhs) + ->Compare(MakeCordBytes(value_factory, test_case.rhs))), test_case.compare); EXPECT_EQ(NormalizeCompareResult( - MakeStringBytes(test_case.lhs) - .AsBytes() - .Compare(MakeWrappedBytes(test_case.rhs).AsBytes())), + MakeStringBytes(value_factory, test_case.lhs) + ->Compare(MakeExternalBytes(value_factory, test_case.rhs))), test_case.compare); EXPECT_EQ(NormalizeCompareResult( - MakeCordBytes(test_case.lhs) - .AsBytes() - .Compare(MakeStringBytes(test_case.rhs).AsBytes())), + MakeCordBytes(value_factory, test_case.lhs) + ->Compare(MakeStringBytes(value_factory, test_case.rhs))), test_case.compare); EXPECT_EQ(NormalizeCompareResult( - MakeCordBytes(test_case.lhs) - .AsBytes() - .Compare(MakeWrappedBytes(test_case.rhs).AsBytes())), + MakeCordBytes(value_factory, test_case.lhs) + ->Compare(MakeCordBytes(value_factory, test_case.rhs))), test_case.compare); EXPECT_EQ(NormalizeCompareResult( - MakeCordBytes(test_case.lhs) - .AsBytes() - .Compare(MakeCordBytes(test_case.rhs).AsBytes())), + MakeCordBytes(value_factory, test_case.lhs) + ->Compare(MakeExternalBytes(value_factory, test_case.rhs))), test_case.compare); EXPECT_EQ(NormalizeCompareResult( - MakeWrappedBytes(test_case.lhs) - .AsBytes() - .Compare(MakeStringBytes(test_case.rhs).AsBytes())), + MakeExternalBytes(value_factory, test_case.lhs) + ->Compare(MakeStringBytes(value_factory, test_case.rhs))), test_case.compare); EXPECT_EQ(NormalizeCompareResult( - MakeWrappedBytes(test_case.lhs) - .AsBytes() - .Compare(MakeCordBytes(test_case.rhs).AsBytes())), + MakeExternalBytes(value_factory, test_case.lhs) + ->Compare(MakeCordBytes(value_factory, test_case.rhs))), test_case.compare); EXPECT_EQ(NormalizeCompareResult( - MakeWrappedBytes(test_case.lhs) - .AsBytes() - .Compare(MakeWrappedBytes(test_case.rhs).AsBytes())), + MakeExternalBytes(value_factory, test_case.lhs) + ->Compare(MakeExternalBytes(value_factory, test_case.rhs))), test_case.compare); } @@ -667,11 +844,12 @@ using BytesDebugStringTest = testing::TestWithParam; TEST_P(BytesDebugStringTest, ToCord) { const BytesDebugStringTestCase& test_case = GetParam(); - EXPECT_EQ(MakeStringBytes(test_case.data).DebugString(), + TestValueFactory value_factory; + EXPECT_EQ(MakeStringBytes(value_factory, test_case.data)->DebugString(), internal::FormatBytesLiteral(test_case.data)); - EXPECT_EQ(MakeCordBytes(test_case.data).DebugString(), + EXPECT_EQ(MakeCordBytes(value_factory, test_case.data)->DebugString(), internal::FormatBytesLiteral(test_case.data)); - EXPECT_EQ(MakeWrappedBytes(test_case.data).DebugString(), + EXPECT_EQ(MakeExternalBytes(value_factory, test_case.data)->DebugString(), internal::FormatBytesLiteral(test_case.data)); } @@ -691,10 +869,12 @@ using BytesToStringTest = testing::TestWithParam; TEST_P(BytesToStringTest, ToString) { const BytesToStringTestCase& test_case = GetParam(); - EXPECT_EQ(MakeStringBytes(test_case.data).AsBytes().ToString(), + TestValueFactory value_factory; + EXPECT_EQ(MakeStringBytes(value_factory, test_case.data)->ToString(), + test_case.data); + EXPECT_EQ(MakeCordBytes(value_factory, test_case.data)->ToString(), test_case.data); - EXPECT_EQ(MakeCordBytes(test_case.data).AsBytes().ToString(), test_case.data); - EXPECT_EQ(MakeWrappedBytes(test_case.data).AsBytes().ToString(), + EXPECT_EQ(MakeExternalBytes(value_factory, test_case.data)->ToString(), test_case.data); } @@ -714,9 +894,12 @@ using BytesToCordTest = testing::TestWithParam; TEST_P(BytesToCordTest, ToCord) { const BytesToCordTestCase& test_case = GetParam(); - EXPECT_EQ(MakeStringBytes(test_case.data).AsBytes().ToCord(), test_case.data); - EXPECT_EQ(MakeCordBytes(test_case.data).AsBytes().ToCord(), test_case.data); - EXPECT_EQ(MakeWrappedBytes(test_case.data).AsBytes().ToCord(), + TestValueFactory value_factory; + EXPECT_EQ(MakeStringBytes(value_factory, test_case.data)->ToCord(), + test_case.data); + EXPECT_EQ(MakeCordBytes(value_factory, test_case.data)->ToCord(), + test_case.data); + EXPECT_EQ(MakeExternalBytes(value_factory, test_case.data)->ToCord(), test_case.data); } @@ -729,19 +912,23 @@ INSTANTIATE_TEST_SUITE_P(BytesToCordTest, BytesToCordTest, })); TEST(Value, SupportsAbslHash) { + TestValueFactory value_factory; EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly({ - Value::Null(), - Value::Error(absl::CancelledError()), - Value::Bool(false), - Value::Int(0), - Value::Uint(0), - Value::Double(0.0), - Value::ZeroDuration(), - Value::UnixEpoch(), - Bytes::Empty(), - Bytes::New("foo"), - Bytes::New(absl::Cord("bar")), - Bytes::Wrap("baz", []() {}), + Persistent(value_factory.GetNullValue()), + Persistent( + value_factory.CreateErrorValue(absl::CancelledError())), + Persistent(value_factory.CreateBoolValue(false)), + Persistent(value_factory.CreateIntValue(0)), + Persistent(value_factory.CreateUintValue(0)), + Persistent(value_factory.CreateDoubleValue(0.0)), + Persistent( + Must(value_factory.CreateDurationValue(absl::ZeroDuration()))), + Persistent( + Must(value_factory.CreateTimestampValue(absl::UnixEpoch()))), + Persistent(value_factory.GetBytesValue()), + Persistent(Must(value_factory.CreateBytesValue("foo"))), + Persistent( + Must(value_factory.CreateBytesValue(absl::Cord("bar")))), })); } diff --git a/internal/BUILD b/internal/BUILD index f92794e89..3b8f43163 100644 --- a/internal/BUILD +++ b/internal/BUILD @@ -159,14 +159,6 @@ cc_test( ], ) -cc_library( - name = "reference_counted", - hdrs = ["reference_counted.h"], - deps = [ - "@com_google_absl//absl/base:core_headers", - ], -) - cc_library( name = "testing", testonly = True, diff --git a/internal/reference_counted.h b/internal/reference_counted.h deleted file mode 100644 index 87dcac1ba..000000000 --- a/internal/reference_counted.h +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright 2022 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_REFERENCE_COUNTED_H_ -#define THIRD_PARTY_CEL_CPP_INTERNAL_REFERENCE_COUNTED_H_ - -#include -#include -#include - -#include "absl/base/macros.h" - -namespace cel::internal { - -class ReferenceCounted; - -void Ref(const ReferenceCounted& refcnt); -void Unref(const ReferenceCounted& refcnt); - -// To make life easier, we return the passed pointer so it can be used inline in -// places like constructors. To ensure this is only be used as intended, we use -// SFINAE. -template -std::enable_if_t, T*> Ref(T* refcnt); - -void Unref(const ReferenceCounted* refcnt); - -class ReferenceCounted { - public: - ReferenceCounted(const ReferenceCounted&) = delete; - - ReferenceCounted(ReferenceCounted&&) = delete; - - virtual ~ReferenceCounted() = default; - - ReferenceCounted& operator=(const ReferenceCounted&) = delete; - - ReferenceCounted& operator=(ReferenceCounted&&) = delete; - - protected: - constexpr ReferenceCounted() : refs_(1) {} - - private: - friend void Ref(const ReferenceCounted& refcnt); - friend void Unref(const ReferenceCounted& refcnt); - template - friend std::enable_if_t, T*> Ref( - T* refcnt); - friend void Unref(const ReferenceCounted* refcnt); - - void Ref() const { - const auto refs = refs_.fetch_add(1, std::memory_order_relaxed); - ABSL_ASSERT(refs >= 1); - } - - void Unref() const { - const auto refs = refs_.fetch_sub(1, std::memory_order_acq_rel); - ABSL_ASSERT(refs >= 1); - if (refs == 1) { - delete this; - } - } - - mutable std::atomic refs_; // NOLINT -}; - -inline void Ref(const ReferenceCounted& refcnt) { refcnt.Ref(); } - -inline void Unref(const ReferenceCounted& refcnt) { refcnt.Unref(); } - -template -inline std::enable_if_t, T*> Ref( - T* refcnt) { - if (refcnt != nullptr) { - (Ref)(*refcnt); - } - return refcnt; -} - -inline void Unref(const ReferenceCounted* refcnt) { - if (refcnt != nullptr) { - (Unref)(*refcnt); - } -} - -} // namespace cel::internal - -#endif // THIRD_PARTY_CEL_CPP_INTERNAL_REFERENCE_COUNTED_H_ From 36f0a84122d0a02653eaa29fad700622192d0ded Mon Sep 17 00:00:00 2001 From: jcking Date: Mon, 14 Mar 2022 19:32:51 +0000 Subject: [PATCH 025/111] Internal change PiperOrigin-RevId: 434537767 --- base/BUILD | 1 + base/internal/value.post.h | 114 ++++++++- base/internal/value.pre.h | 4 + base/type.h | 2 + base/value.cc | 222 +++++++++++++++++- base/value.h | 82 +++++++ base/value_factory.cc | 54 +++++ base/value_factory.h | 46 ++++ base/value_factory_test.cc | 8 + base/value_test.cc | 465 +++++++++++++++++++++++++++++++++++++ 10 files changed, 991 insertions(+), 7 deletions(-) diff --git a/base/BUILD b/base/BUILD index 516ec2f00..1df68443d 100644 --- a/base/BUILD +++ b/base/BUILD @@ -151,6 +151,7 @@ cc_library( "//internal:status_macros", "//internal:strings", "//internal:time", + "//internal:utf8", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", diff --git a/base/internal/value.post.h b/base/internal/value.post.h index 75ab7f7f1..d7a3fe752 100644 --- a/base/internal/value.post.h +++ b/base/internal/value.post.h @@ -130,8 +130,114 @@ class ExternalDataBytesValue final : public BytesValue { ExternalData value_; }; -// Class used to assert the object memory layout for vptr at compile time, -// otherwise it is unused. +// Implementation of StringValue that is stored inlined within a handle. Since +// absl::Cord is reference counted itself, this is more efficient then storing +// this on the heap. +class InlinedCordStringValue final : public StringValue, + public ResourceInlined { + private: + template + friend class ValueHandle; + + explicit InlinedCordStringValue(absl::Cord value) + : InlinedCordStringValue(0, std::move(value)) {} + + InlinedCordStringValue(size_t size, absl::Cord value) + : StringValue(size), value_(std::move(value)) {} + + InlinedCordStringValue() = delete; + + InlinedCordStringValue(const InlinedCordStringValue&) = default; + InlinedCordStringValue(InlinedCordStringValue&&) = default; + + // See comments for respective member functions on `StringValue` and `Value`. + void CopyTo(Value& address) const override; + void MoveTo(Value& address) override; + absl::Cord ToCord(bool reference_counted) const override; + Rep rep() const override; + + absl::Cord value_; +}; + +// Implementation of StringValue that is stored inlined within a handle. This +// class is inheritently unsafe and care should be taken when using it. +// Typically this should only be used for empty strings or data that is static +// and lives for the duration of a program. +class InlinedStringViewStringValue final : public StringValue, + public ResourceInlined { + private: + template + friend class ValueHandle; + + explicit InlinedStringViewStringValue(absl::string_view value) + : InlinedStringViewStringValue(0, value) {} + + InlinedStringViewStringValue(size_t size, absl::string_view value) + : StringValue(size), value_(value) {} + + InlinedStringViewStringValue() = delete; + + InlinedStringViewStringValue(const InlinedStringViewStringValue&) = default; + InlinedStringViewStringValue(InlinedStringViewStringValue&&) = default; + + // See comments for respective member functions on `StringValue` and `Value`. + void CopyTo(Value& address) const override; + void MoveTo(Value& address) override; + absl::Cord ToCord(bool reference_counted) const override; + Rep rep() const override; + + absl::string_view value_; +}; + +// Implementation of StringValue that uses std::string and is allocated on the +// heap, potentially reference counted. +class StringStringValue final : public StringValue { + private: + friend class cel::MemoryManager; + + explicit StringStringValue(std::string value) + : StringStringValue(0, std::move(value)) {} + + StringStringValue(size_t size, std::string value) + : StringValue(size), value_(std::move(value)) {} + + StringStringValue() = delete; + StringStringValue(const StringStringValue&) = delete; + StringStringValue(StringStringValue&&) = delete; + + // See comments for respective member functions on `StringValue` and `Value`. + std::pair SizeAndAlignment() const override; + absl::Cord ToCord(bool reference_counted) const override; + Rep rep() const override; + + std::string value_; +}; + +// Implementation of StringValue that wraps a contiguous array of bytes and +// calls the releaser when it is no longer needed. It is stored on the heap and +// potentially reference counted. +class ExternalDataStringValue final : public StringValue { + private: + friend class cel::MemoryManager; + + explicit ExternalDataStringValue(ExternalData value) + : ExternalDataStringValue(0, std::move(value)) {} + + ExternalDataStringValue(size_t size, ExternalData value) + : StringValue(size), value_(std::move(value)) {} + + ExternalDataStringValue() = delete; + ExternalDataStringValue(const ExternalDataStringValue&) = delete; + ExternalDataStringValue(ExternalDataStringValue&&) = delete; + + // See comments for respective member functions on `StringValue` and `Value`. + std::pair SizeAndAlignment() const override; + absl::Cord ToCord(bool reference_counted) const override; + Rep rep() const override; + + ExternalData value_; +}; + struct ABSL_ATTRIBUTE_UNUSED CheckVptrOffsetBase { virtual ~CheckVptrOffsetBase() = default; @@ -161,7 +267,8 @@ union ValueHandleData final { void* vptr; std::aligned_union_t padding; }; @@ -545,6 +652,7 @@ CEL_INTERNAL_VALUE_DECL(IntValue); CEL_INTERNAL_VALUE_DECL(UintValue); CEL_INTERNAL_VALUE_DECL(DoubleValue); CEL_INTERNAL_VALUE_DECL(BytesValue); +CEL_INTERNAL_VALUE_DECL(StringValue); CEL_INTERNAL_VALUE_DECL(DurationValue); CEL_INTERNAL_VALUE_DECL(TimestampValue); #undef CEL_INTERNAL_VALUE_DECL diff --git a/base/internal/value.pre.h b/base/internal/value.pre.h index 19ac1bca3..837e2f9d5 100644 --- a/base/internal/value.pre.h +++ b/base/internal/value.pre.h @@ -45,6 +45,10 @@ class InlinedCordBytesValue; class InlinedStringViewBytesValue; class StringBytesValue; class ExternalDataBytesValue; +class InlinedCordStringValue; +class InlinedStringViewStringValue; +class StringStringValue; +class ExternalDataStringValue; // Type erased state capable of holding a pointer to remote storage or storing // objects less than two pointers in size inline. diff --git a/base/type.h b/base/type.h index 03433af9a..87093a183 100644 --- a/base/type.h +++ b/base/type.h @@ -51,6 +51,7 @@ class IntValue; class UintValue; class DoubleValue; class BytesValue; +class StringValue; class DurationValue; class TimestampValue; class ValueFactory; @@ -315,6 +316,7 @@ class StringType final : public Type { absl::string_view name() const override { return "string"; } private: + friend class StringValue; friend class TypeFactory; template friend class internal::NoDestructor; diff --git a/base/value.cc b/base/value.cc index 396f87271..9d6ecc948 100644 --- a/base/value.cc +++ b/base/value.cc @@ -15,6 +15,7 @@ #include "base/value.h" #include +#include #include #include #include @@ -40,6 +41,7 @@ #include "internal/status_macros.h" #include "internal/strings.h" #include "internal/time.h" +#include "internal/utf8.h" namespace cel { @@ -56,6 +58,7 @@ CEL_INTERNAL_VALUE_IMPL(IntValue); CEL_INTERNAL_VALUE_IMPL(UintValue); CEL_INTERNAL_VALUE_IMPL(DoubleValue); CEL_INTERNAL_VALUE_IMPL(BytesValue); +CEL_INTERNAL_VALUE_IMPL(StringValue); CEL_INTERNAL_VALUE_IMPL(DurationValue); CEL_INTERNAL_VALUE_IMPL(TimestampValue); #undef CEL_INTERNAL_VALUE_IMPL @@ -392,7 +395,7 @@ void TimestampValue::HashValue(absl::HashState state) const { namespace { -struct DebugStringVisitor final { +struct BytesValueDebugStringVisitor final { std::string operator()(absl::string_view value) const { return internal::FormatBytesLiteral(value); } @@ -406,6 +409,20 @@ struct DebugStringVisitor final { } }; +struct StringValueDebugStringVisitor final { + std::string operator()(absl::string_view value) const { + return internal::FormatStringLiteral(value); + } + + std::string operator()(const absl::Cord& value) const { + absl::string_view flat; + if (value.GetFlat(&flat)) { + return internal::FormatStringLiteral(flat); + } + return internal::FormatStringLiteral(static_cast(value)); + } +}; + struct ToStringVisitor final { std::string operator()(absl::string_view value) const { return std::string(value); @@ -416,12 +433,22 @@ struct ToStringVisitor final { } }; -struct SizeVisitor final { +struct BytesValueSizeVisitor final { size_t operator()(absl::string_view value) const { return value.size(); } size_t operator()(const absl::Cord& value) const { return value.size(); } }; +struct StringValueSizeVisitor final { + size_t operator()(absl::string_view value) const { + return internal::Utf8CodePointCount(value); + } + + size_t operator()(const absl::Cord& value) const { + return internal::Utf8CodePointCount(value); + } +}; + struct EmptyVisitor final { bool operator()(absl::string_view value) const { return value.empty(); } @@ -490,6 +517,19 @@ class EqualsVisitor final { const BytesValue& ref_; }; +template <> +class EqualsVisitor final { + public: + explicit EqualsVisitor(const StringValue& ref) : ref_(ref) {} + + bool operator()(absl::string_view value) const { return ref_.Equals(value); } + + bool operator()(const absl::Cord& value) const { return ref_.Equals(value); } + + private: + const StringValue& ref_; +}; + template class CompareVisitor final { public: @@ -520,6 +560,19 @@ class CompareVisitor final { const BytesValue& ref_; }; +template <> +class CompareVisitor final { + public: + explicit CompareVisitor(const StringValue& ref) : ref_(ref) {} + + int operator()(const absl::Cord& value) const { return ref_.Compare(value); } + + int operator()(absl::string_view value) const { return ref_.Compare(value); } + + private: + const StringValue& ref_; +}; + class HashValueVisitor final { public: explicit HashValueVisitor(absl::HashState state) : state_(std::move(state)) {} @@ -555,7 +608,9 @@ Transient BytesValue::type() const { BytesType::Get()); } -size_t BytesValue::size() const { return absl::visit(SizeVisitor{}, rep()); } +size_t BytesValue::size() const { + return absl::visit(BytesValueSizeVisitor{}, rep()); +} bool BytesValue::empty() const { return absl::visit(EmptyVisitor{}, rep()); } @@ -588,7 +643,7 @@ std::string BytesValue::ToString() const { } std::string BytesValue::DebugString() const { - return absl::visit(DebugStringVisitor{}, rep()); + return absl::visit(BytesValueDebugStringVisitor{}, rep()); } bool BytesValue::Equals(const Value& other) const { @@ -603,6 +658,90 @@ void BytesValue::HashValue(absl::HashState state) const { rep()); } +Persistent StringValue::Empty(ValueFactory& value_factory) { + return value_factory.GetStringValue(); +} + +absl::StatusOr> StringValue::Concat( + ValueFactory& value_factory, const Transient& lhs, + const Transient& rhs) { + absl::Cord cord = lhs->ToCord(base_internal::IsManagedHandle(lhs)); + cord.Append(rhs->ToCord(base_internal::IsManagedHandle(rhs))); + size_t size = 0; + size_t lhs_size = lhs->size_.load(std::memory_order_relaxed); + if (lhs_size != 0 && !lhs->empty()) { + size_t rhs_size = rhs->size_.load(std::memory_order_relaxed); + if (rhs_size != 0 && !rhs->empty()) { + size = lhs_size + rhs_size; + } + } + return value_factory.CreateStringValue(std::move(cord), size); +} + +Transient StringValue::type() const { + return TransientHandleFactory::MakeUnmanaged( + StringType::Get()); +} + +size_t StringValue::size() const { + // We lazily calculate the code point count in some circumstances. If the code + // point count is 0 and the underlying rep is not empty we need to actually + // calculate the size. It is okay if this is done by multiple threads + // simultaneously, it is a benign race. + size_t size = size_.load(std::memory_order_relaxed); + if (size == 0 && !empty()) { + size = absl::visit(StringValueSizeVisitor{}, rep()); + size_.store(size, std::memory_order_relaxed); + } + return size; +} + +bool StringValue::empty() const { return absl::visit(EmptyVisitor{}, rep()); } + +bool StringValue::Equals(absl::string_view string) const { + return absl::visit(EqualsVisitor(string), rep()); +} + +bool StringValue::Equals(const absl::Cord& string) const { + return absl::visit(EqualsVisitor(string), rep()); +} + +bool StringValue::Equals(const Transient& string) const { + return absl::visit(EqualsVisitor(*this), string->rep()); +} + +int StringValue::Compare(absl::string_view string) const { + return absl::visit(CompareVisitor(string), rep()); +} + +int StringValue::Compare(const absl::Cord& string) const { + return absl::visit(CompareVisitor(string), rep()); +} + +int StringValue::Compare(const Transient& string) const { + return absl::visit(CompareVisitor(*this), string->rep()); +} + +std::string StringValue::ToString() const { + return absl::visit(ToStringVisitor{}, rep()); +} + +std::string StringValue::DebugString() const { + return absl::visit(StringValueDebugStringVisitor{}, rep()); +} + +bool StringValue::Equals(const Value& other) const { + return kind() == other.kind() && + absl::visit(EqualsVisitor(*this), + internal::down_cast(other).rep()); +} + +void StringValue::HashValue(absl::HashState state) const { + absl::visit( + HashValueVisitor(absl::HashState::combine(std::move(state), type())), + rep()); +} + namespace base_internal { absl::Cord InlinedCordBytesValue::ToCord(bool reference_counted) const { @@ -680,6 +819,81 @@ typename ExternalDataBytesValue::Rep ExternalDataBytesValue::rep() const { absl::string_view(static_cast(value_.data), value_.size)); } +absl::Cord InlinedCordStringValue::ToCord(bool reference_counted) const { + static_cast(reference_counted); + return value_; +} + +void InlinedCordStringValue::CopyTo(Value& address) const { + CEL_COPY_TO_IMPL(InlinedCordStringValue, *this, address); +} + +void InlinedCordStringValue::MoveTo(Value& address) { + CEL_MOVE_TO_IMPL(InlinedCordStringValue, *this, address); +} + +typename InlinedCordStringValue::Rep InlinedCordStringValue::rep() const { + return Rep(absl::in_place_type>, + std::cref(value_)); +} + +absl::Cord InlinedStringViewStringValue::ToCord(bool reference_counted) const { + static_cast(reference_counted); + return absl::Cord(value_); +} + +void InlinedStringViewStringValue::CopyTo(Value& address) const { + CEL_COPY_TO_IMPL(InlinedStringViewStringValue, *this, address); +} + +void InlinedStringViewStringValue::MoveTo(Value& address) { + CEL_MOVE_TO_IMPL(InlinedStringViewStringValue, *this, address); +} + +typename InlinedStringViewStringValue::Rep InlinedStringViewStringValue::rep() + const { + return Rep(absl::in_place_type, value_); +} + +std::pair StringStringValue::SizeAndAlignment() const { + return std::make_pair(sizeof(StringStringValue), alignof(StringStringValue)); +} + +absl::Cord StringStringValue::ToCord(bool reference_counted) const { + if (reference_counted) { + Ref(); + return absl::MakeCordFromExternal(absl::string_view(value_), + [this]() { Unref(); }); + } + return absl::Cord(value_); +} + +typename StringStringValue::Rep StringStringValue::rep() const { + return Rep(absl::in_place_type, absl::string_view(value_)); +} + +std::pair ExternalDataStringValue::SizeAndAlignment() const { + return std::make_pair(sizeof(ExternalDataStringValue), + alignof(ExternalDataStringValue)); +} + +absl::Cord ExternalDataStringValue::ToCord(bool reference_counted) const { + if (reference_counted) { + Ref(); + return absl::MakeCordFromExternal( + absl::string_view(static_cast(value_.data), value_.size), + [this]() { Unref(); }); + } + return absl::Cord( + absl::string_view(static_cast(value_.data), value_.size)); +} + +typename ExternalDataStringValue::Rep ExternalDataStringValue::rep() const { + return Rep( + absl::in_place_type, + absl::string_view(static_cast(value_.data), value_.size)); +} + } // namespace base_internal } // namespace cel diff --git a/base/value.h b/base/value.h index c123753bf..483eaa21a 100644 --- a/base/value.h +++ b/base/value.h @@ -15,6 +15,7 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_VALUE_H_ #define THIRD_PARTY_CEL_CPP_BASE_VALUE_H_ +#include #include #include #include @@ -43,6 +44,7 @@ class IntValue; class UintValue; class DoubleValue; class BytesValue; +class StringValue; class DurationValue; class TimestampValue; class ValueFactory; @@ -74,11 +76,14 @@ class Value : public base_internal::Resource { friend class UintValue; friend class DoubleValue; friend class BytesValue; + friend class StringValue; friend class DurationValue; friend class TimestampValue; friend class base_internal::ValueHandleBase; friend class base_internal::StringBytesValue; friend class base_internal::ExternalDataBytesValue; + friend class base_internal::StringStringValue; + friend class base_internal::ExternalDataStringValue; Value() = default; Value(const Value&) = default; @@ -407,6 +412,83 @@ class BytesValue : public Value { void HashValue(absl::HashState state) const final; }; +class StringValue : public Value { + protected: + using Rep = absl::variant>; + + public: + static Persistent Empty(ValueFactory& value_factory); + + static absl::StatusOr> Concat( + ValueFactory& value_factory, const Transient& lhs, + const Transient& rhs); + + Transient type() const final; + + Kind kind() const final { return Kind::kString; } + + std::string DebugString() const final; + + size_t size() const; + + bool empty() const; + + bool Equals(absl::string_view string) const; + bool Equals(const absl::Cord& string) const; + bool Equals(const Transient& string) const; + + int Compare(absl::string_view string) const; + int Compare(const absl::Cord& string) const; + int Compare(const Transient& string) const; + + std::string ToString() const; + + absl::Cord ToCord() const { + // Without the handle we cannot know if this is reference counted. + return ToCord(/*reference_counted=*/false); + } + + private: + template + friend class base_internal::ValueHandle; + friend class base_internal::ValueHandleBase; + friend class base_internal::InlinedCordStringValue; + friend class base_internal::InlinedStringViewStringValue; + friend class base_internal::StringStringValue; + friend class base_internal::ExternalDataStringValue; + + // Called by base_internal::ValueHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Value& value) { return value.kind() == Kind::kString; } + + explicit StringValue(size_t size) : size_(size) {} + + StringValue() = default; + + StringValue(const StringValue& other) + : StringValue(other.size_.load(std::memory_order_relaxed)) {} + + StringValue(StringValue&& other) + : StringValue(other.size_.exchange(0, std::memory_order_relaxed)) {} + + // Get the contents of this BytesValue as absl::Cord. When reference_counted + // is true, the implementation can potentially return an absl::Cord that wraps + // the contents instead of copying. + virtual absl::Cord ToCord(bool reference_counted) const = 0; + + // Get the contents of this StringValue as either absl::string_view or const + // absl::Cord&. + virtual Rep rep() const = 0; + + // See comments for respective member functions on `Value`. + bool Equals(const Value& other) const final; + void HashValue(absl::HashState state) const final; + + // Lazily cached code point count. + mutable std::atomic size_ = 0; +}; + class DurationValue final : public Value, public base_internal::ResourceInlined { public: diff --git a/base/value_factory.cc b/base/value_factory.cc index 59fcac40c..d6831f9eb 100644 --- a/base/value_factory.cc +++ b/base/value_factory.cc @@ -24,16 +24,21 @@ #include "base/value.h" #include "internal/status_macros.h" #include "internal/time.h" +#include "internal/utf8.h" namespace cel { namespace { using base_internal::ExternalDataBytesValue; +using base_internal::ExternalDataStringValue; using base_internal::InlinedCordBytesValue; +using base_internal::InlinedCordStringValue; using base_internal::InlinedStringViewBytesValue; +using base_internal::InlinedStringViewStringValue; using base_internal::PersistentHandleFactory; using base_internal::StringBytesValue; +using base_internal::StringStringValue; using base_internal::TransientHandleFactory; } // namespace @@ -89,6 +94,35 @@ absl::StatusOr> ValueFactory::CreateBytesValue( std::move(value)); } +absl::StatusOr> ValueFactory::CreateStringValue( + std::string value) { + // Avoid persisting empty strings which may have underlying storage after + // mutating. + if (value.empty()) { + return GetEmptyStringValue(); + } + auto [count, ok] = internal::Utf8Validate(value); + if (ABSL_PREDICT_FALSE(!ok)) { + return absl::InvalidArgumentError( + "Illegal byte sequence in UTF-8 encoded string"); + } + return PersistentHandleFactory::Make( + memory_manager(), count, std::move(value)); +} + +absl::StatusOr> ValueFactory::CreateStringValue( + absl::Cord value) { + if (value.empty()) { + return GetEmptyStringValue(); + } + auto [count, ok] = internal::Utf8Validate(value); + if (ABSL_PREDICT_FALSE(!ok)) { + return absl::InvalidArgumentError( + "Illegal byte sequence in UTF-8 encoded string"); + } + return CreateStringValue(std::move(value), count); +} + absl::StatusOr> ValueFactory::CreateDurationValue(absl::Duration value) { CEL_RETURN_IF_ERROR(internal::ValidateDuration(value)); @@ -114,4 +148,24 @@ absl::StatusOr> ValueFactory::CreateBytesValue( ExternalDataBytesValue>(memory_manager(), std::move(value)); } +Persistent ValueFactory::GetEmptyStringValue() { + return PersistentHandleFactory::Make< + InlinedStringViewStringValue>(absl::string_view()); +} + +absl::StatusOr> ValueFactory::CreateStringValue( + absl::Cord value, size_t size) { + if (value.empty()) { + return GetEmptyStringValue(); + } + return PersistentHandleFactory::Make< + InlinedCordStringValue>(size, std::move(value)); +} + +absl::StatusOr> ValueFactory::CreateStringValue( + base_internal::ExternalData value) { + return PersistentHandleFactory::Make< + ExternalDataStringValue>(memory_manager(), std::move(value)); +} + } // namespace cel diff --git a/base/value_factory.h b/base/value_factory.h index 02b0d32ca..ab9ce7559 100644 --- a/base/value_factory.h +++ b/base/value_factory.h @@ -18,6 +18,7 @@ #include #include #include +#include #include "absl/base/attributes.h" #include "absl/status/status.h" @@ -86,6 +87,40 @@ class ValueFactory { std::forward(releaser)))); } + Persistent GetStringValue() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return GetEmptyStringValue(); + } + + absl::StatusOr> CreateStringValue( + const char* value) ABSL_ATTRIBUTE_LIFETIME_BOUND { + return CreateStringValue(absl::string_view(value)); + } + + absl::StatusOr> CreateStringValue( + absl::string_view value) ABSL_ATTRIBUTE_LIFETIME_BOUND { + return CreateStringValue(std::string(value)); + } + + absl::StatusOr> CreateStringValue( + std::string value) ABSL_ATTRIBUTE_LIFETIME_BOUND; + + absl::StatusOr> CreateStringValue( + absl::Cord value) ABSL_ATTRIBUTE_LIFETIME_BOUND; + + template + absl::StatusOr> CreateStringValue( + absl::string_view value, + Releaser&& releaser) ABSL_ATTRIBUTE_LIFETIME_BOUND { + if (value.empty()) { + std::forward(releaser)(); + return GetEmptyStringValue(); + } + return CreateStringValue(base_internal::ExternalData( + static_cast(value.data()), value.size(), + std::make_unique( + std::forward(releaser)))); + } + absl::StatusOr> CreateDurationValue( absl::Duration value) ABSL_ATTRIBUTE_LIFETIME_BOUND; @@ -100,12 +135,23 @@ class ValueFactory { MemoryManager& memory_manager() const { return memory_manager_; } private: + friend class StringValue; + Persistent GetEmptyBytesValue() ABSL_ATTRIBUTE_LIFETIME_BOUND; absl::StatusOr> CreateBytesValue( base_internal::ExternalData value) ABSL_ATTRIBUTE_LIFETIME_BOUND; + Persistent GetEmptyStringValue() + ABSL_ATTRIBUTE_LIFETIME_BOUND; + + absl::StatusOr> CreateStringValue( + absl::Cord value, size_t size) ABSL_ATTRIBUTE_LIFETIME_BOUND; + + absl::StatusOr> CreateStringValue( + base_internal::ExternalData value) ABSL_ATTRIBUTE_LIFETIME_BOUND; + MemoryManager& memory_manager_; }; diff --git a/base/value_factory_test.cc b/base/value_factory_test.cc index 346df9583..d873bbd50 100644 --- a/base/value_factory_test.cc +++ b/base/value_factory_test.cc @@ -34,5 +34,13 @@ TEST(ValueFactory, CreateErrorValueReplacesOk) { StatusIs(absl::StatusCode::kUnknown)); } +TEST(ValueFactory, CreateStringValueIllegalByteSequence) { + TestValueFactory value_factory; + EXPECT_THAT(value_factory.CreateStringValue("\xff"), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(value_factory.CreateStringValue(absl::Cord("\xff")), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + } // namespace } // namespace cel diff --git a/base/value_test.cc b/base/value_test.cc index 8f7e87775..92914fa8c 100644 --- a/base/value_test.cc +++ b/base/value_test.cc @@ -580,6 +580,136 @@ TEST(Value, BytesFromExternal) { EXPECT_NE(one_value, zero_value); } +TEST(Value, StringFromString) { + TestValueFactory value_factory; + TestTypeFactory type_factory; + auto zero_value = Must(value_factory.CreateStringValue(std::string("0"))); + EXPECT_TRUE(zero_value.Is()); + EXPECT_FALSE(zero_value.Is()); + EXPECT_EQ(zero_value, zero_value); + EXPECT_EQ(zero_value, + Must(value_factory.CreateStringValue(std::string("0")))); + EXPECT_EQ(zero_value->kind(), Kind::kString); + EXPECT_EQ(zero_value->type(), type_factory.GetStringType()); + EXPECT_EQ(zero_value->ToString(), "0"); + + auto one_value = Must(value_factory.CreateStringValue(std::string("1"))); + EXPECT_TRUE(one_value.Is()); + EXPECT_FALSE(one_value.Is()); + EXPECT_EQ(one_value, one_value); + EXPECT_EQ(one_value, Must(value_factory.CreateStringValue(std::string("1")))); + EXPECT_EQ(one_value->kind(), Kind::kString); + EXPECT_EQ(one_value->type(), type_factory.GetStringType()); + EXPECT_EQ(one_value->ToString(), "1"); + + EXPECT_NE(zero_value, one_value); + EXPECT_NE(one_value, zero_value); +} + +TEST(Value, StringFromStringView) { + TestValueFactory value_factory; + TestTypeFactory type_factory; + auto zero_value = + Must(value_factory.CreateStringValue(absl::string_view("0"))); + EXPECT_TRUE(zero_value.Is()); + EXPECT_FALSE(zero_value.Is()); + EXPECT_EQ(zero_value, zero_value); + EXPECT_EQ(zero_value, + Must(value_factory.CreateStringValue(absl::string_view("0")))); + EXPECT_EQ(zero_value->kind(), Kind::kString); + EXPECT_EQ(zero_value->type(), type_factory.GetStringType()); + EXPECT_EQ(zero_value->ToString(), "0"); + + auto one_value = + Must(value_factory.CreateStringValue(absl::string_view("1"))); + EXPECT_TRUE(one_value.Is()); + EXPECT_FALSE(one_value.Is()); + EXPECT_EQ(one_value, one_value); + EXPECT_EQ(one_value, + Must(value_factory.CreateStringValue(absl::string_view("1")))); + EXPECT_EQ(one_value->kind(), Kind::kString); + EXPECT_EQ(one_value->type(), type_factory.GetStringType()); + EXPECT_EQ(one_value->ToString(), "1"); + + EXPECT_NE(zero_value, one_value); + EXPECT_NE(one_value, zero_value); +} + +TEST(Value, StringFromCord) { + TestValueFactory value_factory; + TestTypeFactory type_factory; + auto zero_value = Must(value_factory.CreateStringValue(absl::Cord("0"))); + EXPECT_TRUE(zero_value.Is()); + EXPECT_FALSE(zero_value.Is()); + EXPECT_EQ(zero_value, zero_value); + EXPECT_EQ(zero_value, Must(value_factory.CreateStringValue(absl::Cord("0")))); + EXPECT_EQ(zero_value->kind(), Kind::kString); + EXPECT_EQ(zero_value->type(), type_factory.GetStringType()); + EXPECT_EQ(zero_value->ToCord(), "0"); + + auto one_value = Must(value_factory.CreateStringValue(absl::Cord("1"))); + EXPECT_TRUE(one_value.Is()); + EXPECT_FALSE(one_value.Is()); + EXPECT_EQ(one_value, one_value); + EXPECT_EQ(one_value, Must(value_factory.CreateStringValue(absl::Cord("1")))); + EXPECT_EQ(one_value->kind(), Kind::kString); + EXPECT_EQ(one_value->type(), type_factory.GetStringType()); + EXPECT_EQ(one_value->ToCord(), "1"); + + EXPECT_NE(zero_value, one_value); + EXPECT_NE(one_value, zero_value); +} + +TEST(Value, StringFromLiteral) { + TestValueFactory value_factory; + TestTypeFactory type_factory; + auto zero_value = Must(value_factory.CreateStringValue("0")); + EXPECT_TRUE(zero_value.Is()); + EXPECT_FALSE(zero_value.Is()); + EXPECT_EQ(zero_value, zero_value); + EXPECT_EQ(zero_value, Must(value_factory.CreateStringValue("0"))); + EXPECT_EQ(zero_value->kind(), Kind::kString); + EXPECT_EQ(zero_value->type(), type_factory.GetStringType()); + EXPECT_EQ(zero_value->ToString(), "0"); + + auto one_value = Must(value_factory.CreateStringValue("1")); + EXPECT_TRUE(one_value.Is()); + EXPECT_FALSE(one_value.Is()); + EXPECT_EQ(one_value, one_value); + EXPECT_EQ(one_value, Must(value_factory.CreateStringValue("1"))); + EXPECT_EQ(one_value->kind(), Kind::kString); + EXPECT_EQ(one_value->type(), type_factory.GetStringType()); + EXPECT_EQ(one_value->ToString(), "1"); + + EXPECT_NE(zero_value, one_value); + EXPECT_NE(one_value, zero_value); +} + +TEST(Value, StringFromExternal) { + TestValueFactory value_factory; + TestTypeFactory type_factory; + auto zero_value = Must(value_factory.CreateStringValue("0", []() {})); + EXPECT_TRUE(zero_value.Is()); + EXPECT_FALSE(zero_value.Is()); + EXPECT_EQ(zero_value, zero_value); + EXPECT_EQ(zero_value, Must(value_factory.CreateStringValue("0", []() {}))); + EXPECT_EQ(zero_value->kind(), Kind::kString); + EXPECT_EQ(zero_value->type(), type_factory.GetStringType()); + EXPECT_EQ(zero_value->ToString(), "0"); + + auto one_value = Must(value_factory.CreateStringValue("1", []() {})); + EXPECT_TRUE(one_value.Is()); + EXPECT_FALSE(one_value.Is()); + EXPECT_EQ(one_value, one_value); + EXPECT_EQ(one_value, Must(value_factory.CreateStringValue("1", []() {}))); + EXPECT_EQ(one_value->kind(), Kind::kString); + EXPECT_EQ(one_value->type(), type_factory.GetStringType()); + EXPECT_EQ(one_value->ToString(), "1"); + + EXPECT_NE(zero_value, one_value); + EXPECT_NE(one_value, zero_value); +} + Persistent MakeStringBytes(ValueFactory& value_factory, absl::string_view value) { return Must(value_factory.CreateBytesValue(value)); @@ -911,6 +1041,337 @@ INSTANTIATE_TEST_SUITE_P(BytesToCordTest, BytesToCordTest, {"\xef\xbf\xbd"}, })); +Persistent MakeStringString(ValueFactory& value_factory, + absl::string_view value) { + return Must(value_factory.CreateStringValue(value)); +} + +Persistent MakeCordString(ValueFactory& value_factory, + absl::string_view value) { + return Must(value_factory.CreateStringValue(absl::Cord(value))); +} + +Persistent MakeExternalString(ValueFactory& value_factory, + absl::string_view value) { + return Must(value_factory.CreateStringValue(value, []() {})); +} + +struct StringConcatTestCase final { + std::string lhs; + std::string rhs; +}; + +using StringConcatTest = testing::TestWithParam; + +TEST_P(StringConcatTest, Concat) { + const StringConcatTestCase& test_case = GetParam(); + TestValueFactory value_factory; + EXPECT_TRUE( + Must(StringValue::Concat(value_factory, + MakeStringString(value_factory, test_case.lhs), + MakeStringString(value_factory, test_case.rhs))) + ->Equals(test_case.lhs + test_case.rhs)); + EXPECT_TRUE( + Must(StringValue::Concat(value_factory, + MakeStringString(value_factory, test_case.lhs), + MakeCordString(value_factory, test_case.rhs))) + ->Equals(test_case.lhs + test_case.rhs)); + EXPECT_TRUE( + Must(StringValue::Concat( + value_factory, MakeStringString(value_factory, test_case.lhs), + MakeExternalString(value_factory, test_case.rhs))) + ->Equals(test_case.lhs + test_case.rhs)); + EXPECT_TRUE( + Must(StringValue::Concat(value_factory, + MakeCordString(value_factory, test_case.lhs), + MakeStringString(value_factory, test_case.rhs))) + ->Equals(test_case.lhs + test_case.rhs)); + EXPECT_TRUE( + Must(StringValue::Concat(value_factory, + MakeCordString(value_factory, test_case.lhs), + MakeCordString(value_factory, test_case.rhs))) + ->Equals(test_case.lhs + test_case.rhs)); + EXPECT_TRUE( + Must(StringValue::Concat( + value_factory, MakeCordString(value_factory, test_case.lhs), + MakeExternalString(value_factory, test_case.rhs))) + ->Equals(test_case.lhs + test_case.rhs)); + EXPECT_TRUE( + Must(StringValue::Concat(value_factory, + MakeExternalString(value_factory, test_case.lhs), + MakeStringString(value_factory, test_case.rhs))) + ->Equals(test_case.lhs + test_case.rhs)); + EXPECT_TRUE( + Must(StringValue::Concat(value_factory, + MakeExternalString(value_factory, test_case.lhs), + MakeCordString(value_factory, test_case.rhs))) + ->Equals(test_case.lhs + test_case.rhs)); + EXPECT_TRUE( + Must(StringValue::Concat( + value_factory, MakeExternalString(value_factory, test_case.lhs), + MakeExternalString(value_factory, test_case.rhs))) + ->Equals(test_case.lhs + test_case.rhs)); +} + +INSTANTIATE_TEST_SUITE_P(StringConcatTest, StringConcatTest, + testing::ValuesIn({ + {"", ""}, + {"", std::string("\0", 1)}, + {std::string("\0", 1), ""}, + {std::string("\0", 1), std::string("\0", 1)}, + {"", "foo"}, + {"foo", ""}, + {"foo", "foo"}, + {"bar", "foo"}, + {"foo", "bar"}, + {"bar", "bar"}, + })); + +struct StringSizeTestCase final { + std::string data; + size_t size; +}; + +using StringSizeTest = testing::TestWithParam; + +TEST_P(StringSizeTest, Size) { + const StringSizeTestCase& test_case = GetParam(); + TestValueFactory value_factory; + EXPECT_EQ(MakeStringString(value_factory, test_case.data)->size(), + test_case.size); + EXPECT_EQ(MakeCordString(value_factory, test_case.data)->size(), + test_case.size); + EXPECT_EQ(MakeExternalString(value_factory, test_case.data)->size(), + test_case.size); +} + +INSTANTIATE_TEST_SUITE_P(StringSizeTest, StringSizeTest, + testing::ValuesIn({ + {"", 0}, + {"1", 1}, + {"foo", 3}, + {"\xef\xbf\xbd", 1}, + })); + +struct StringEmptyTestCase final { + std::string data; + bool empty; +}; + +using StringEmptyTest = testing::TestWithParam; + +TEST_P(StringEmptyTest, Empty) { + const StringEmptyTestCase& test_case = GetParam(); + TestValueFactory value_factory; + EXPECT_EQ(MakeStringString(value_factory, test_case.data)->empty(), + test_case.empty); + EXPECT_EQ(MakeCordString(value_factory, test_case.data)->empty(), + test_case.empty); + EXPECT_EQ(MakeExternalString(value_factory, test_case.data)->empty(), + test_case.empty); +} + +INSTANTIATE_TEST_SUITE_P(StringEmptyTest, StringEmptyTest, + testing::ValuesIn({ + {"", true}, + {std::string("\0", 1), false}, + {"1", false}, + })); + +struct StringEqualsTestCase final { + std::string lhs; + std::string rhs; + bool equals; +}; + +using StringEqualsTest = testing::TestWithParam; + +TEST_P(StringEqualsTest, Equals) { + const StringEqualsTestCase& test_case = GetParam(); + TestValueFactory value_factory; + EXPECT_EQ(MakeStringString(value_factory, test_case.lhs) + ->Equals(MakeStringString(value_factory, test_case.rhs)), + test_case.equals); + EXPECT_EQ(MakeStringString(value_factory, test_case.lhs) + ->Equals(MakeCordString(value_factory, test_case.rhs)), + test_case.equals); + EXPECT_EQ(MakeStringString(value_factory, test_case.lhs) + ->Equals(MakeExternalString(value_factory, test_case.rhs)), + test_case.equals); + EXPECT_EQ(MakeCordString(value_factory, test_case.lhs) + ->Equals(MakeStringString(value_factory, test_case.rhs)), + test_case.equals); + EXPECT_EQ(MakeCordString(value_factory, test_case.lhs) + ->Equals(MakeCordString(value_factory, test_case.rhs)), + test_case.equals); + EXPECT_EQ(MakeCordString(value_factory, test_case.lhs) + ->Equals(MakeExternalString(value_factory, test_case.rhs)), + test_case.equals); + EXPECT_EQ(MakeExternalString(value_factory, test_case.lhs) + ->Equals(MakeStringString(value_factory, test_case.rhs)), + test_case.equals); + EXPECT_EQ(MakeExternalString(value_factory, test_case.lhs) + ->Equals(MakeCordString(value_factory, test_case.rhs)), + test_case.equals); + EXPECT_EQ(MakeExternalString(value_factory, test_case.lhs) + ->Equals(MakeExternalString(value_factory, test_case.rhs)), + test_case.equals); +} + +INSTANTIATE_TEST_SUITE_P(StringEqualsTest, StringEqualsTest, + testing::ValuesIn({ + {"", "", true}, + {"", std::string("\0", 1), false}, + {std::string("\0", 1), "", false}, + {std::string("\0", 1), std::string("\0", 1), true}, + {"", "foo", false}, + {"foo", "", false}, + {"foo", "foo", true}, + {"bar", "foo", false}, + {"foo", "bar", false}, + {"bar", "bar", true}, + })); + +struct StringCompareTestCase final { + std::string lhs; + std::string rhs; + int compare; +}; + +using StringCompareTest = testing::TestWithParam; + +TEST_P(StringCompareTest, Equals) { + const StringCompareTestCase& test_case = GetParam(); + TestValueFactory value_factory; + EXPECT_EQ(NormalizeCompareResult( + MakeStringString(value_factory, test_case.lhs) + ->Compare(MakeStringString(value_factory, test_case.rhs))), + test_case.compare); + EXPECT_EQ(NormalizeCompareResult( + MakeStringString(value_factory, test_case.lhs) + ->Compare(MakeCordString(value_factory, test_case.rhs))), + test_case.compare); + EXPECT_EQ( + NormalizeCompareResult( + MakeStringString(value_factory, test_case.lhs) + ->Compare(MakeExternalString(value_factory, test_case.rhs))), + test_case.compare); + EXPECT_EQ(NormalizeCompareResult( + MakeCordString(value_factory, test_case.lhs) + ->Compare(MakeStringString(value_factory, test_case.rhs))), + test_case.compare); + EXPECT_EQ(NormalizeCompareResult( + MakeCordString(value_factory, test_case.lhs) + ->Compare(MakeCordString(value_factory, test_case.rhs))), + test_case.compare); + EXPECT_EQ(NormalizeCompareResult(MakeCordString(value_factory, test_case.lhs) + ->Compare(MakeExternalString( + value_factory, test_case.rhs))), + test_case.compare); + EXPECT_EQ(NormalizeCompareResult( + MakeExternalString(value_factory, test_case.lhs) + ->Compare(MakeStringString(value_factory, test_case.rhs))), + test_case.compare); + EXPECT_EQ(NormalizeCompareResult( + MakeExternalString(value_factory, test_case.lhs) + ->Compare(MakeCordString(value_factory, test_case.rhs))), + test_case.compare); + EXPECT_EQ( + NormalizeCompareResult( + MakeExternalString(value_factory, test_case.lhs) + ->Compare(MakeExternalString(value_factory, test_case.rhs))), + test_case.compare); +} + +INSTANTIATE_TEST_SUITE_P(StringCompareTest, StringCompareTest, + testing::ValuesIn({ + {"", "", 0}, + {"", std::string("\0", 1), -1}, + {std::string("\0", 1), "", 1}, + {std::string("\0", 1), std::string("\0", 1), 0}, + {"", "foo", -1}, + {"foo", "", 1}, + {"foo", "foo", 0}, + {"bar", "foo", -1}, + {"foo", "bar", 1}, + {"bar", "bar", 0}, + })); + +struct StringDebugStringTestCase final { + std::string data; +}; + +using StringDebugStringTest = testing::TestWithParam; + +TEST_P(StringDebugStringTest, ToCord) { + const StringDebugStringTestCase& test_case = GetParam(); + TestValueFactory value_factory; + EXPECT_EQ(MakeStringString(value_factory, test_case.data)->DebugString(), + internal::FormatStringLiteral(test_case.data)); + EXPECT_EQ(MakeCordString(value_factory, test_case.data)->DebugString(), + internal::FormatStringLiteral(test_case.data)); + EXPECT_EQ(MakeExternalString(value_factory, test_case.data)->DebugString(), + internal::FormatStringLiteral(test_case.data)); +} + +INSTANTIATE_TEST_SUITE_P(StringDebugStringTest, StringDebugStringTest, + testing::ValuesIn({ + {""}, + {"1"}, + {"foo"}, + {"\xef\xbf\xbd"}, + })); + +struct StringToStringTestCase final { + std::string data; +}; + +using StringToStringTest = testing::TestWithParam; + +TEST_P(StringToStringTest, ToString) { + const StringToStringTestCase& test_case = GetParam(); + TestValueFactory value_factory; + EXPECT_EQ(MakeStringString(value_factory, test_case.data)->ToString(), + test_case.data); + EXPECT_EQ(MakeCordString(value_factory, test_case.data)->ToString(), + test_case.data); + EXPECT_EQ(MakeExternalString(value_factory, test_case.data)->ToString(), + test_case.data); +} + +INSTANTIATE_TEST_SUITE_P(StringToStringTest, StringToStringTest, + testing::ValuesIn({ + {""}, + {"1"}, + {"foo"}, + {"\xef\xbf\xbd"}, + })); + +struct StringToCordTestCase final { + std::string data; +}; + +using StringToCordTest = testing::TestWithParam; + +TEST_P(StringToCordTest, ToCord) { + const StringToCordTestCase& test_case = GetParam(); + TestValueFactory value_factory; + EXPECT_EQ(MakeStringString(value_factory, test_case.data)->ToCord(), + test_case.data); + EXPECT_EQ(MakeCordString(value_factory, test_case.data)->ToCord(), + test_case.data); + EXPECT_EQ(MakeExternalString(value_factory, test_case.data)->ToCord(), + test_case.data); +} + +INSTANTIATE_TEST_SUITE_P(StringToCordTest, StringToCordTest, + testing::ValuesIn({ + {""}, + {"1"}, + {"foo"}, + {"\xef\xbf\xbd"}, + })); + TEST(Value, SupportsAbslHash) { TestValueFactory value_factory; EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly({ @@ -929,6 +1390,10 @@ TEST(Value, SupportsAbslHash) { Persistent(Must(value_factory.CreateBytesValue("foo"))), Persistent( Must(value_factory.CreateBytesValue(absl::Cord("bar")))), + Persistent(value_factory.GetStringValue()), + Persistent(Must(value_factory.CreateStringValue("foo"))), + Persistent( + Must(value_factory.CreateStringValue(absl::Cord("bar")))), })); } From 0398fa834236070b434ec8b03150d99ea5fb1a81 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 16 Mar 2022 19:37:38 +0000 Subject: [PATCH 026/111] Introduce type provider API and LegacyTypeAdapter with similar functionality to cel::Types. This introduces a transitional API to refactor the evaluator to be structurally compatible with new Type APIs without forcing a swap to the new type system. PiperOrigin-RevId: 435127247 --- base/BUILD | 9 ++ base/type_provider.h | 50 ++++++++ eval/public/structs/BUILD | 23 ++++ eval/public/structs/legacy_type_adapter.h | 112 ++++++++++++++++++ .../structs/legacy_type_adapter_test.cc | 56 +++++++++ 5 files changed, 250 insertions(+) create mode 100644 base/type_provider.h create mode 100644 eval/public/structs/legacy_type_adapter.h create mode 100644 eval/public/structs/legacy_type_adapter_test.cc diff --git a/base/BUILD b/base/BUILD index 1df68443d..6f1ee5ce2 100644 --- a/base/BUILD +++ b/base/BUILD @@ -130,6 +130,15 @@ cc_test( ], ) +cc_library( + name = "type_provider", + hdrs = ["type_provider.h"], + deps = [ + "//eval/public/structs:legacy_type_adapter", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "value", srcs = [ diff --git a/base/type_provider.h b/base/type_provider.h new file mode 100644 index 000000000..f4359d9eb --- /dev/null +++ b/base/type_provider.h @@ -0,0 +1,50 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_TYPE_PROVIDER_H_ +#define THIRD_PARTY_CEL_CPP_BASE_TYPE_PROVIDER_H_ + +#include "absl/strings/string_view.h" +#include "eval/public/structs/legacy_type_adapter.h" + +namespace cel { + +// Interface for a TypeProvider, allowing host applications to inject +// functionality for operating on custom types in the CEL interpreter. +// +// Type providers are registered with a TypeRegistry. When resolving a type, +// the registry will check if it is a well known type, then check against each +// of the registered providers. If the type can't be resolved, the operation +// will result in an error. +// +// Note: This API is not finalized. Consult the CEL team before introducing new +// implementations. +class TypeProvider { + public: + virtual ~TypeProvider() = default; + + // Return LegacyTypeAdapter for the fully qualified type name if available. + // + // nullopt values are interpreted as not present. + // + // Returned non-null pointers from the adapter implemententation must remain + // valid as long as the type provider. + // TODO(issues/5): add alternative for new type system. + virtual absl::optional + ProvideLegacyType(absl::string_view name) const = 0; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_TYPE_PROVIDER_H_ diff --git a/eval/public/structs/BUILD b/eval/public/structs/BUILD index 5ed70a3a0..0ccfb40b6 100644 --- a/eval/public/structs/BUILD +++ b/eval/public/structs/BUILD @@ -85,3 +85,26 @@ cc_test( "@com_google_protobuf//:protobuf", ], ) + +cc_library( + name = "legacy_type_adapter", + hdrs = ["legacy_type_adapter.h"], + deps = [ + "//base:memory_manager", + "//eval/public:cel_value", + "@com_google_absl//absl/status", + ], +) + +cc_test( + name = "legacy_type_adapter_test", + srcs = ["legacy_type_adapter_test.cc"], + deps = [ + ":legacy_type_adapter", + "//eval/public:cel_value", + "//extensions/protobuf:memory_manager", + "//internal:status_macros", + "//internal:testing", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/eval/public/structs/legacy_type_adapter.h b/eval/public/structs/legacy_type_adapter.h new file mode 100644 index 000000000..fbefe1c35 --- /dev/null +++ b/eval/public/structs/legacy_type_adapter.h @@ -0,0 +1,112 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_LEGACY_TYPE_ADPATER_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_LEGACY_TYPE_ADPATER_H_ + +#include "absl/status/status.h" +#include "base/memory_manager.h" +#include "eval/public/cel_value.h" + +namespace google::api::expr::runtime { + +// Type information about a legacy Struct type. +// Provides methods to the interpreter for interacting with a custom type. +// +// This provides Apis for emulating the behavior of new types working on +// existing cel values. +// +// MutationApis provide equivalent behavior to a cel::Type and cel::ValueFactory +// (resolved from a type name). +// +// AccessApis provide equivalent behavior to cel::StructValue accessors (virtual +// dispatch to a concrete implementation for accessing underlying values). +// +// This class is a simple wrapper around (nullable) pointers to the interface +// implementations. The underlying pointers are expected to be valid as long as +// the type provider that returned this object. +class LegacyTypeAdapter { + public: + // Interface for mutation apis. + // Note: in the new type system, the provider represents this by returning + // a cel::Type and cel::ValueFactory for the type. + class MutationApis { + public: + virtual ~MutationApis() = default; + + // Return whether the type defines the given field. + // TODO(issues/5): This is only used to eagerly fail during the planning + // phase. Check if it's safe to remove this behavior and fail at runtime. + virtual bool DefinesField(absl::string_view field_name) const = 0; + + // Create a new empty instance of the type. + // May return a status if the type is not possible to create. + virtual absl::StatusOr NewInstance( + cel::MemoryManager& memory_manager) const = 0; + + // Normalize special types to a native CEL value after building. + // The default implementation is a no-op. + // The interpreter guarantees that instance is uniquely owned by the + // interpreter, and can be safely mutated. + virtual absl::Status AdaptFromWellKnownType( + cel::MemoryManager& memory_manager, CelValue& instance) const { + return absl::OkStatus(); + } + + // Set field on instance to value. + // The interpreter guarantees that instance is uniquely owned by the + // interpreter, and can be safely mutated. + virtual absl::Status SetField(absl::string_view field_name, + const CelValue& value, + cel::MemoryManager& memory_manager, + CelValue& instance) const = 0; + }; + + // Interface for access apis. + // Note: in new type system this is integrated into the StructValue (via + // dynamic dispatch to concerete implementations). + class AccessApis { + public: + virtual ~AccessApis() = default; + + // Return whether an instance of the type has field set to a non-default + // value. + virtual absl::StatusOr HasField(absl::string_view field_name, + const CelValue& value) const = 0; + + // Access field on instance. + virtual absl::StatusOr GetField( + absl::string_view field_name, const CelValue& instance, + cel::MemoryManager& memory_manager) const = 0; + }; + + LegacyTypeAdapter(const AccessApis* access, const MutationApis* mutation) + : access_apis_(access), mutation_apis_(mutation) {} + + // Apis for access for the represented type. + // If null, access is not supported (this is an opaque type). + const AccessApis* access_apis() { return access_apis_; } + + // Apis for mutation for the represented type. + // If null, mutation is not supported (this type cannot be created). + const MutationApis* mutation_apis() { return mutation_apis_; } + + private: + const AccessApis* access_apis_; + const MutationApis* mutation_apis_; +}; + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_LEGACY_TYPE_ADPATER_H_ diff --git a/eval/public/structs/legacy_type_adapter_test.cc b/eval/public/structs/legacy_type_adapter_test.cc new file mode 100644 index 000000000..ce93f9f71 --- /dev/null +++ b/eval/public/structs/legacy_type_adapter_test.cc @@ -0,0 +1,56 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/structs/legacy_type_adapter.h" + +#include "google/protobuf/arena.h" +#include "eval/public/cel_value.h" +#include "extensions/protobuf/memory_manager.h" +#include "internal/status_macros.h" +#include "internal/testing.h" + +namespace google::api::expr::runtime { +namespace { + +class TestMutationApiImpl : public LegacyTypeAdapter::MutationApis { + public: + TestMutationApiImpl() {} + bool DefinesField(absl::string_view field_name) const override { + return false; + } + + absl::StatusOr NewInstance( + cel::MemoryManager& memory_manager) const override { + return absl::UnimplementedError("Not implemented"); + } + + absl::Status SetField(absl::string_view field_name, const CelValue& value, + cel::MemoryManager& memory_manager, + CelValue& instance) const override { + return absl::UnimplementedError("Not implemented"); + } +}; + +TEST(LegacyTypeAdapterMutationApis, DefaultNoopAdapt) { + CelValue v; + google::protobuf::Arena arena; + cel::extensions::ProtoMemoryManager manager(&arena); + + TestMutationApiImpl impl; + + EXPECT_OK(impl.AdaptFromWellKnownType(manager, v)); +} + +} // namespace +} // namespace google::api::expr::runtime From a3df55997c15ee382dcc2ef053c6f97e505da103 Mon Sep 17 00:00:00 2001 From: jcking Date: Thu, 17 Mar 2022 19:35:34 +0000 Subject: [PATCH 027/111] Internal change PiperOrigin-RevId: 435418602 --- base/BUILD | 10 +---- base/type_provider.h | 15 +------- eval/public/structs/BUILD | 10 +++++ eval/public/structs/legacy_type_provider.h | 43 ++++++++++++++++++++++ 4 files changed, 56 insertions(+), 22 deletions(-) create mode 100644 eval/public/structs/legacy_type_provider.h diff --git a/base/BUILD b/base/BUILD index 6f1ee5ce2..de01486b1 100644 --- a/base/BUILD +++ b/base/BUILD @@ -104,6 +104,7 @@ cc_library( hdrs = [ "type.h", "type_factory.h", + "type_provider.h", ], deps = [ ":handle", @@ -130,15 +131,6 @@ cc_test( ], ) -cc_library( - name = "type_provider", - hdrs = ["type_provider.h"], - deps = [ - "//eval/public/structs:legacy_type_adapter", - "@com_google_absl//absl/strings", - ], -) - cc_library( name = "value", srcs = [ diff --git a/base/type_provider.h b/base/type_provider.h index f4359d9eb..6bbc5ab7d 100644 --- a/base/type_provider.h +++ b/base/type_provider.h @@ -15,11 +15,10 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_TYPE_PROVIDER_H_ #define THIRD_PARTY_CEL_CPP_BASE_TYPE_PROVIDER_H_ -#include "absl/strings/string_view.h" -#include "eval/public/structs/legacy_type_adapter.h" - namespace cel { +class TypeFactory; + // Interface for a TypeProvider, allowing host applications to inject // functionality for operating on custom types in the CEL interpreter. // @@ -33,16 +32,6 @@ namespace cel { class TypeProvider { public: virtual ~TypeProvider() = default; - - // Return LegacyTypeAdapter for the fully qualified type name if available. - // - // nullopt values are interpreted as not present. - // - // Returned non-null pointers from the adapter implemententation must remain - // valid as long as the type provider. - // TODO(issues/5): add alternative for new type system. - virtual absl::optional - ProvideLegacyType(absl::string_view name) const = 0; }; } // namespace cel diff --git a/eval/public/structs/BUILD b/eval/public/structs/BUILD index 0ccfb40b6..43bc0423e 100644 --- a/eval/public/structs/BUILD +++ b/eval/public/structs/BUILD @@ -86,6 +86,16 @@ cc_test( ], ) +cc_library( + name = "legacy_type_provider", + hdrs = ["legacy_type_provider.h"], + deps = [ + ":legacy_type_adapter", + "//base:type", + "@com_google_absl//absl/types:optional", + ], +) + cc_library( name = "legacy_type_adapter", hdrs = ["legacy_type_adapter.h"], diff --git a/eval/public/structs/legacy_type_provider.h b/eval/public/structs/legacy_type_provider.h new file mode 100644 index 000000000..72ac86eaa --- /dev/null +++ b/eval/public/structs/legacy_type_provider.h @@ -0,0 +1,43 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_TYPE_PROVIDER_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_TYPE_PROVIDER_H_ + +#include "absl/types/optional.h" +#include "base/type_provider.h" +#include "eval/public/structs/legacy_type_adapter.h" + +namespace google::api::expr::runtime { + +// An internal extension of cel::TypeProvider that also deals with legacy types. +// +// Note: This API is not finalized. Consult the CEL team before introducing new +// implementations. +class LegacyTypeProvider : public cel::TypeProvider { + public: + // Return LegacyTypeAdapter for the fully qualified type name if available. + // + // nullopt values are interpreted as not present. + // + // Returned non-null pointers from the adapter implemententation must remain + // valid as long as the type provider. + // TODO(issues/5): add alternative for new type system. + virtual absl::optional ProvideLegacyType( + absl::string_view name) const = 0; +}; + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_TYPE_PROVIDER_H_ From f4dd008f7f1bfbb72687c940ddd3b7fdcf7542ee Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Fri, 18 Mar 2022 05:28:09 +0000 Subject: [PATCH 028/111] Add TypeProvider implementation for creation APIs for protocol buffer messages based on the configured descriptor pool. PiperOrigin-RevId: 435544477 --- eval/public/cel_value.h | 2 + eval/public/structs/BUILD | 65 ++++ .../structs/proto_message_type_adapter.cc | 165 +++++++++ .../structs/proto_message_type_adapter.h | 68 ++++ .../proto_message_type_adapter_test.cc | 314 ++++++++++++++++++ .../protobuf_descriptor_type_provider.cc | 55 +++ .../protobuf_descriptor_type_provider.h | 59 ++++ .../protobuf_descriptor_type_provider_test.cc | 82 +++++ 8 files changed, 810 insertions(+) create mode 100644 eval/public/structs/proto_message_type_adapter.cc create mode 100644 eval/public/structs/proto_message_type_adapter.h create mode 100644 eval/public/structs/proto_message_type_adapter_test.cc create mode 100644 eval/public/structs/protobuf_descriptor_type_provider.cc create mode 100644 eval/public/structs/protobuf_descriptor_type_provider.h create mode 100644 eval/public/structs/protobuf_descriptor_type_provider_test.cc diff --git a/eval/public/cel_value.h b/eval/public/cel_value.h index 7d09b89af..5a6442bb6 100644 --- a/eval/public/cel_value.h +++ b/eval/public/cel_value.h @@ -45,6 +45,7 @@ using CelError = absl::Status; class CelList; class CelMap; class UnknownSet; +class LegacyTypeAdapter; class CelValue { public: @@ -452,6 +453,7 @@ class CelValue { } friend class CelProtoWrapper; + friend class ProtoMessageTypeAdapter; friend class EvaluatorStack; }; diff --git a/eval/public/structs/BUILD b/eval/public/structs/BUILD index 43bc0423e..180829047 100644 --- a/eval/public/structs/BUILD +++ b/eval/public/structs/BUILD @@ -118,3 +118,68 @@ cc_test( "@com_google_protobuf//:protobuf", ], ) + +cc_library( + name = "proto_message_type_adapter", + srcs = ["proto_message_type_adapter.cc"], + hdrs = ["proto_message_type_adapter.h"], + deps = [ + ":cel_proto_wrapper", + ":legacy_type_adapter", + "//base:memory_manager", + "//eval/public:cel_value", + "//eval/public/containers:field_access", + "//extensions/protobuf:memory_manager", + "//internal:status_macros", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "proto_message_type_adapter_test", + srcs = ["proto_message_type_adapter_test.cc"], + deps = [ + ":cel_proto_wrapper", + ":proto_message_type_adapter", + "//eval/public:cel_value", + "//eval/public/containers:container_backed_list_impl", + "//eval/public/containers:container_backed_map_impl", + "//eval/public/testing:matchers", + "//eval/testutil:test_message_cc_proto", + "//extensions/protobuf:memory_manager", + "//internal:status_macros", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "protobuf_descriptor_type_provider", + srcs = ["protobuf_descriptor_type_provider.cc"], + hdrs = ["protobuf_descriptor_type_provider.h"], + deps = [ + ":proto_message_type_adapter", + "//eval/public:cel_value", + "//eval/public/structs:legacy_type_provider", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "protobuf_descriptor_type_provider_test", + srcs = ["protobuf_descriptor_type_provider_test.cc"], + deps = [ + ":protobuf_descriptor_type_provider", + "//eval/public:cel_value", + "//eval/public/testing:matchers", + "//extensions/protobuf:memory_manager", + "//internal:status_macros", + "//internal:testing", + ], +) diff --git a/eval/public/structs/proto_message_type_adapter.cc b/eval/public/structs/proto_message_type_adapter.cc new file mode 100644 index 000000000..abefd239f --- /dev/null +++ b/eval/public/structs/proto_message_type_adapter.cc @@ -0,0 +1,165 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/structs/proto_message_type_adapter.h" + +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "eval/public/cel_value.h" +#include "eval/public/containers/field_access.h" +#include "eval/public/structs/cel_proto_wrapper.h" +#include "extensions/protobuf/memory_manager.h" +#include "internal/status_macros.h" + +namespace google::api::expr::runtime { +using ::google::protobuf::Message; + +absl::Status ProtoMessageTypeAdapter::ValidateSetFieldOp( + bool assertion, absl::string_view field, absl::string_view detail) const { + if (!assertion) { + return absl::InvalidArgumentError( + absl::Substitute("SetField failed on message $0, field '$1': $2", + descriptor_->full_name(), field, detail)); + } + return absl::OkStatus(); +} + +absl::StatusOr ProtoMessageTypeAdapter::NewInstance( + cel::MemoryManager& memory_manager) const { + // This implementation requires arena-backed memory manager. + google::protobuf::Arena* arena = + cel::extensions::ProtoMemoryManager::CastToProtoArena(memory_manager); + const Message* prototype = message_factory_->GetPrototype(descriptor_); + + Message* msg = (prototype != nullptr) ? prototype->New(arena) : nullptr; + + if (msg == nullptr) { + return absl::InvalidArgumentError( + absl::StrCat("Failed to create message ", descriptor_->name())); + } + return CelValue::CreateMessage(msg); +} + +bool ProtoMessageTypeAdapter::DefinesField(absl::string_view field_name) const { + return descriptor_->FindFieldByName(field_name.data()) != nullptr; +} + +absl::StatusOr ProtoMessageTypeAdapter::HasField( + absl::string_view field_name, const CelValue& value) const { + return absl::UnimplementedError("Not yet implemented."); +} + +absl::StatusOr ProtoMessageTypeAdapter::GetField( + absl::string_view field_name, const CelValue& instance, + cel::MemoryManager& memory_manager) const { + return absl::UnimplementedError("Not yet implemented."); +} + +absl::Status ProtoMessageTypeAdapter::SetField( + absl::string_view field_name, const CelValue& value, + cel::MemoryManager& memory_manager, CelValue& instance) const { + // Assume proto arena implementation if this provider is used. + google::protobuf::Arena* arena = + cel::extensions::ProtoMemoryManager::CastToProtoArena(memory_manager); + const google::protobuf::Message* message = nullptr; + if (!instance.GetValue(&message) || message == nullptr) { + return absl::InternalError("SetField called on non-message type."); + } + + // Interpreter guarantees this is the top-level instance. + google::protobuf::Message* mutable_message = const_cast(message); + + const google::protobuf::FieldDescriptor* field_descriptor = + descriptor_->FindFieldByName(field_name.data()); + CEL_RETURN_IF_ERROR( + ValidateSetFieldOp(field_descriptor != nullptr, field_name, "not found")); + + if (field_descriptor->is_map()) { + constexpr int kKeyField = 1; + constexpr int kValueField = 2; + + const CelMap* cel_map; + CEL_RETURN_IF_ERROR(ValidateSetFieldOp( + value.GetValue(&cel_map) && cel_map != nullptr, + field_name, "value is not CelMap")); + + auto entry_descriptor = field_descriptor->message_type(); + + CEL_RETURN_IF_ERROR( + ValidateSetFieldOp(entry_descriptor != nullptr, field_name, + "failed to find map entry descriptor")); + auto key_field_descriptor = entry_descriptor->FindFieldByNumber(kKeyField); + auto value_field_descriptor = + entry_descriptor->FindFieldByNumber(kValueField); + + CEL_RETURN_IF_ERROR( + ValidateSetFieldOp(key_field_descriptor != nullptr, field_name, + "failed to find key field descriptor")); + + CEL_RETURN_IF_ERROR( + ValidateSetFieldOp(value_field_descriptor != nullptr, field_name, + "failed to find value field descriptor")); + + const CelList* key_list = cel_map->ListKeys(); + for (int i = 0; i < key_list->size(); i++) { + CelValue key = (*key_list)[i]; + + auto value = (*cel_map)[key]; + CEL_RETURN_IF_ERROR(ValidateSetFieldOp(value.has_value(), field_name, + "error serializing CelMap")); + Message* entry_msg = mutable_message->GetReflection()->AddMessage( + mutable_message, field_descriptor); + CEL_RETURN_IF_ERROR( + SetValueToSingleField(key, key_field_descriptor, entry_msg, arena)); + CEL_RETURN_IF_ERROR(SetValueToSingleField( + value.value(), value_field_descriptor, entry_msg, arena)); + } + + } else if (field_descriptor->is_repeated()) { + const CelList* cel_list; + CEL_RETURN_IF_ERROR(ValidateSetFieldOp( + value.GetValue(&cel_list) && cel_list != nullptr, + field_name, "expected CelList value")); + + for (int i = 0; i < cel_list->size(); i++) { + CEL_RETURN_IF_ERROR(AddValueToRepeatedField( + (*cel_list)[i], field_descriptor, mutable_message, arena)); + } + } else { + CEL_RETURN_IF_ERROR( + SetValueToSingleField(value, field_descriptor, mutable_message, arena)); + } + return absl::OkStatus(); +} + +absl::Status ProtoMessageTypeAdapter::AdaptFromWellKnownType( + cel::MemoryManager& memory_manager, CelValue& instance) const { + // Assume proto arena implementation if this provider is used. + google::protobuf::Arena* arena = + cel::extensions::ProtoMemoryManager::CastToProtoArena(memory_manager); + const google::protobuf::Message* message; + if (!instance.GetValue(&message) || message == nullptr) { + return absl::InternalError( + "Adapt from well-known type failed: not a message"); + } + + instance = CelProtoWrapper::CreateMessage(message, arena); + return absl::OkStatus(); +} + +} // namespace google::api::expr::runtime diff --git a/eval/public/structs/proto_message_type_adapter.h b/eval/public/structs/proto_message_type_adapter.h new file mode 100644 index 000000000..5d75927a6 --- /dev/null +++ b/eval/public/structs/proto_message_type_adapter.h @@ -0,0 +1,68 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_PROTO_MESSAGE_TYPE_ADAPTER_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_PROTO_MESSAGE_TYPE_ADAPTER_H_ + +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "base/memory_manager.h" +#include "eval/public/cel_value.h" +#include "eval/public/structs/legacy_type_adapter.h" + +namespace google::api::expr::runtime { + +class ProtoMessageTypeAdapter : public LegacyTypeAdapter::AccessApis, + public LegacyTypeAdapter::MutationApis { + public: + ProtoMessageTypeAdapter(const google::protobuf::Descriptor* descriptor, + google::protobuf::MessageFactory* message_factory) + : message_factory_(message_factory), descriptor_(descriptor) {} + + ~ProtoMessageTypeAdapter() override = default; + + absl::StatusOr NewInstance( + cel::MemoryManager& memory_manager) const override; + + bool DefinesField(absl::string_view field_name) const override; + + absl::Status SetField(absl::string_view field_name, const CelValue& value, + + cel::MemoryManager& memory_manager, + CelValue& instance) const override; + + absl::Status AdaptFromWellKnownType(cel::MemoryManager& memory_manager, + CelValue& instance) const override; + + absl::StatusOr GetField( + absl::string_view field_name, const CelValue& instance, + cel::MemoryManager& memory_manager) const override; + + absl::StatusOr HasField(absl::string_view field_name, + const CelValue& value) const override; + + private: + // Helper for standardizing error messages for SetField operation. + absl::Status ValidateSetFieldOp(bool assertion, absl::string_view field, + absl::string_view detail) const; + + google::protobuf::MessageFactory* message_factory_; + const google::protobuf::Descriptor* descriptor_; +}; + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_PROTO_MESSAGE_TYPE_ADAPTER_H_ diff --git a/eval/public/structs/proto_message_type_adapter_test.cc b/eval/public/structs/proto_message_type_adapter_test.cc new file mode 100644 index 000000000..40acbacb0 --- /dev/null +++ b/eval/public/structs/proto_message_type_adapter_test.cc @@ -0,0 +1,314 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/structs/proto_message_type_adapter.h" + +#include "google/protobuf/wrappers.pb.h" +#include "google/protobuf/descriptor.pb.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "absl/status/status.h" +#include "eval/public/cel_value.h" +#include "eval/public/containers/container_backed_list_impl.h" +#include "eval/public/containers/container_backed_map_impl.h" +#include "eval/public/structs/cel_proto_wrapper.h" +#include "eval/public/testing/matchers.h" +#include "eval/testutil/test_message.pb.h" +#include "extensions/protobuf/memory_manager.h" +#include "internal/status_macros.h" +#include "internal/testing.h" + +namespace google::api::expr::runtime { +namespace { + +using testing::EqualsProto; +using testing::HasSubstr; +using cel::internal::StatusIs; + +TEST(ProtoMessageTypeAdapter, HasFieldNotYetImplemented) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + + TestMessage example; + example.set_int64_value(10); + + CelValue value = CelProtoWrapper::CreateMessage(&example, &arena); + + EXPECT_THAT(adapter.HasField("value", value), + StatusIs(absl::StatusCode::kUnimplemented)); +} + +TEST(ProtoMessageTypeAdapter, GetFieldNotYetImplemented) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + cel::extensions::ProtoMemoryManager manager(&arena); + + TestMessage example; + example.set_int64_value(10); + + CelValue value = CelProtoWrapper::CreateMessage(&example, &arena); + + EXPECT_THAT(adapter.GetField("int64_value", value, manager), + StatusIs(absl::StatusCode::kUnimplemented)); +} + +TEST(ProtoMessageTypeAdapter, NewInstance) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + cel::extensions::ProtoMemoryManager manager(&arena); + + ASSERT_OK_AND_ASSIGN(CelValue result, adapter.NewInstance(manager)); + const google::protobuf::Message* message; + ASSERT_TRUE(result.GetValue(&message)); + EXPECT_THAT(message, EqualsProto(TestMessage::default_instance())); +} + +TEST(ProtoMessageTypeAdapter, NewInstanceUnsupportedDescriptor) { + google::protobuf::Arena arena; + + google::protobuf::DescriptorPool pool; + google::protobuf::FileDescriptorProto faked_file; + faked_file.set_name("faked.proto"); + faked_file.set_syntax("proto3"); + faked_file.set_package("google.api.expr.runtime"); + auto msg_descriptor = faked_file.add_message_type(); + msg_descriptor->set_name("FakeMessage"); + pool.BuildFile(faked_file); + + ProtoMessageTypeAdapter adapter( + pool.FindMessageTypeByName("google.api.expr.runtime.FakeMessage"), + google::protobuf::MessageFactory::generated_factory()); + cel::extensions::ProtoMemoryManager manager(&arena); + + // Message factory doesn't know how to create our custom message, even though + // we provided a descriptor for it. + EXPECT_THAT( + adapter.NewInstance(manager), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("FakeMessage"))); +} + +TEST(ProtoMessageTypeAdapter, DefinesField) { + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + + EXPECT_TRUE(adapter.DefinesField("int64_value")); + EXPECT_FALSE(adapter.DefinesField("not_a_field")); +} + +TEST(ProtoMessageTypeAdapter, SetFieldSingular) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + cel::extensions::ProtoMemoryManager manager(&arena); + + ASSERT_OK_AND_ASSIGN(CelValue value, adapter.NewInstance(manager)); + + ASSERT_OK(adapter.SetField("int64_value", CelValue::CreateInt64(10), manager, + value)); + + const google::protobuf::Message* message; + ASSERT_TRUE(value.GetValue(&message)); + EXPECT_THAT(message, EqualsProto("int64_value: 10")); + + ASSERT_THAT(adapter.SetField("not_a_field", CelValue::CreateInt64(10), + manager, value), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("field 'not_a_field': not found"))); +} + +TEST(ProtoMessageTypeAdapter, SetFieldMap) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + cel::extensions::ProtoMemoryManager manager(&arena); + + CelMapBuilder builder; + ASSERT_OK(builder.Add(CelValue::CreateInt64(1), CelValue::CreateInt64(2))); + ASSERT_OK(builder.Add(CelValue::CreateInt64(2), CelValue::CreateInt64(4))); + + CelValue value_to_set = CelValue::CreateMap(&builder); + + ASSERT_OK_AND_ASSIGN(CelValue instance, adapter.NewInstance(manager)); + + ASSERT_OK( + adapter.SetField("int64_int32_map", value_to_set, manager, instance)); + + const google::protobuf::Message* message; + ASSERT_TRUE(instance.GetValue(&message)); + EXPECT_THAT(message, EqualsProto(R"pb( + int64_int32_map { key: 1 value: 2 } + int64_int32_map { key: 2 value: 4 } + )pb")); +} + +TEST(ProtoMessageTypeAdapter, SetFieldRepeated) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + cel::extensions::ProtoMemoryManager manager(&arena); + + ContainerBackedListImpl list( + {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); + CelValue value_to_set = CelValue::CreateList(&list); + ASSERT_OK_AND_ASSIGN(CelValue instance, adapter.NewInstance(manager)); + + ASSERT_OK(adapter.SetField("int64_list", value_to_set, manager, instance)); + + const google::protobuf::Message* message; + ASSERT_TRUE(instance.GetValue(&message)); + EXPECT_THAT(message, EqualsProto(R"pb( + int64_list: 1 int64_list: 2 + )pb")); +} + +TEST(ProtoMessageTypeAdapter, SetFieldNotAField) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + cel::extensions::ProtoMemoryManager manager(&arena); + + ASSERT_OK_AND_ASSIGN(CelValue instance, adapter.NewInstance(manager)); + + ASSERT_THAT(adapter.SetField("not_a_field", CelValue::CreateInt64(10), + manager, instance), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("field 'not_a_field': not found"))); +} + +TEST(ProtoMesssageTypeAdapter, SetFieldWrongType) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + cel::extensions::ProtoMemoryManager manager(&arena); + + ContainerBackedListImpl list( + {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); + CelValue list_value = CelValue::CreateList(&list); + + CelMapBuilder builder; + ASSERT_OK(builder.Add(CelValue::CreateInt64(1), CelValue::CreateInt64(2))); + ASSERT_OK(builder.Add(CelValue::CreateInt64(2), CelValue::CreateInt64(4))); + + CelValue map_value = CelValue::CreateMap(&builder); + + CelValue int_value = CelValue::CreateInt64(42); + + ASSERT_OK_AND_ASSIGN(CelValue instance, adapter.NewInstance(manager)); + + EXPECT_THAT(adapter.SetField("int64_value", map_value, manager, instance), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(adapter.SetField("int64_value", list_value, manager, instance), + StatusIs(absl::StatusCode::kInvalidArgument)); + + EXPECT_THAT( + adapter.SetField("int64_int32_map", list_value, manager, instance), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(adapter.SetField("int64_int32_map", int_value, manager, instance), + StatusIs(absl::StatusCode::kInvalidArgument)); + + EXPECT_THAT(adapter.SetField("int64_list", int_value, manager, instance), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(adapter.SetField("int64_list", map_value, manager, instance), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(ProtoMesssageTypeAdapter, SetFieldNotAMessage) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + cel::extensions::ProtoMemoryManager manager(&arena); + + CelValue int_value = CelValue::CreateInt64(42); + CelValue instance = CelValue::CreateNull(); + + EXPECT_THAT(adapter.SetField("int64_value", int_value, manager, instance), + StatusIs(absl::StatusCode::kInternal)); +} + +TEST(ProtoMessageTypeAdapter, AdaptFromWellKnownType) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.protobuf.Int64Value"), + google::protobuf::MessageFactory::generated_factory()); + cel::extensions::ProtoMemoryManager manager(&arena); + + ASSERT_OK_AND_ASSIGN(CelValue instance, adapter.NewInstance(manager)); + ASSERT_OK( + adapter.SetField("value", CelValue::CreateInt64(42), manager, instance)); + + ASSERT_OK(adapter.AdaptFromWellKnownType(manager, instance)); + + EXPECT_THAT(instance, test::IsCelInt64(42)); +} + +TEST(ProtoMessageTypeAdapter, AdaptFromWellKnownTypeUnspecial) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + cel::extensions::ProtoMemoryManager manager(&arena); + + ASSERT_OK_AND_ASSIGN(CelValue instance, adapter.NewInstance(manager)); + ASSERT_OK(adapter.SetField("int64_value", CelValue::CreateInt64(42), manager, + instance)); + + ASSERT_OK(adapter.AdaptFromWellKnownType(manager, instance)); + + // TestMessage should not be converted to a CEL primitive type. + EXPECT_THAT(instance, test::IsCelMessage(EqualsProto("int64_value: 42"))); +} + +TEST(ProtoMessageTypeAdapter, AdaptFromWellKnownTypeNotAMessageError) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + cel::extensions::ProtoMemoryManager manager(&arena); + + CelValue instance = CelValue::CreateNull(); + + // Interpreter guaranteed to call this with a message type, otherwise, + // something has broken. + EXPECT_THAT(adapter.AdaptFromWellKnownType(manager, instance), + StatusIs(absl::StatusCode::kInternal)); +} + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/public/structs/protobuf_descriptor_type_provider.cc b/eval/public/structs/protobuf_descriptor_type_provider.cc new file mode 100644 index 000000000..8c96c6b38 --- /dev/null +++ b/eval/public/structs/protobuf_descriptor_type_provider.cc @@ -0,0 +1,55 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/structs/protobuf_descriptor_type_provider.h" + +#include +#include + +#include "google/protobuf/descriptor.h" +#include "eval/public/cel_value.h" +#include "eval/public/structs/proto_message_type_adapter.h" + +namespace google::api::expr::runtime { + +absl::optional ProtobufDescriptorProvider::ProvideLegacyType( + absl::string_view name) const { + const ProtoMessageTypeAdapter* result = nullptr; + auto it = type_cache_.find(name); + if (it != type_cache_.end()) { + result = it->second.get(); + } else { + auto type_provider = GetType(name); + result = type_provider.get(); + type_cache_[name] = std::move(type_provider); + } + if (result == nullptr) { + return absl::nullopt; + } + // ProtoMessageTypeAdapter provides apis for both access and mutation. + return LegacyTypeAdapter(result, result); +} + +std::unique_ptr ProtobufDescriptorProvider::GetType( + absl::string_view name) const { + const google::protobuf::Descriptor* descriptor = + descriptor_pool_->FindMessageTypeByName(name.data()); + if (descriptor == nullptr) { + return nullptr; + } + + return std::make_unique(descriptor, + message_factory_); +} +} // namespace google::api::expr::runtime diff --git a/eval/public/structs/protobuf_descriptor_type_provider.h b/eval/public/structs/protobuf_descriptor_type_provider.h new file mode 100644 index 000000000..c5091ff2d --- /dev/null +++ b/eval/public/structs/protobuf_descriptor_type_provider.h @@ -0,0 +1,59 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_PROTOBUF_DESCRIPTOR_TYPE_PROVIDER_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_PROTOBUF_DESCRIPTOR_TYPE_PROVIDER_H_ + +#include +#include +#include + +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "eval/public/cel_value.h" +#include "eval/public/structs/legacy_type_provider.h" +#include "eval/public/structs/proto_message_type_adapter.h" + +namespace google::api::expr::runtime { + +// Implementation of a type provider that generates types from protocol buffer +// descriptors. +class ProtobufDescriptorProvider : public LegacyTypeProvider { + public: + ProtobufDescriptorProvider(const google::protobuf::DescriptorPool* pool, + google::protobuf::MessageFactory* factory) + : descriptor_pool_(pool), message_factory_(factory) {} + + absl::optional ProvideLegacyType( + absl::string_view name) const override; + + private: + // Run a lookup if the type adapter hasn't already been built. + // returns nullptr if not found. + std::unique_ptr GetType( + absl::string_view name) const; + + const google::protobuf::DescriptorPool* descriptor_pool_; + google::protobuf::MessageFactory* message_factory_; + mutable absl::flat_hash_map> + type_cache_; +}; + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_PROTOBUF_DESCRIPTOR_TYPE_PROVIDER_H_ diff --git a/eval/public/structs/protobuf_descriptor_type_provider_test.cc b/eval/public/structs/protobuf_descriptor_type_provider_test.cc new file mode 100644 index 000000000..4443bb59a --- /dev/null +++ b/eval/public/structs/protobuf_descriptor_type_provider_test.cc @@ -0,0 +1,82 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/structs/protobuf_descriptor_type_provider.h" + +#include "eval/public/cel_value.h" +#include "eval/public/testing/matchers.h" +#include "extensions/protobuf/memory_manager.h" +#include "internal/status_macros.h" +#include "internal/testing.h" + +namespace google::api::expr::runtime { +namespace { + +TEST(ProtobufDescriptorProvider, Basic) { + ProtobufDescriptorProvider provider( + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory()); + google::protobuf::Arena arena; + cel::extensions::ProtoMemoryManager manager(&arena); + auto type_adapter = provider.ProvideLegacyType("google.protobuf.Int64Value"); + + ASSERT_TRUE(type_adapter.has_value()); + ASSERT_TRUE(type_adapter->mutation_apis() != nullptr); + + ASSERT_TRUE(type_adapter->mutation_apis()->DefinesField("value")); + ASSERT_OK_AND_ASSIGN(CelValue value, + type_adapter->mutation_apis()->NewInstance(manager)); + + ASSERT_TRUE(value.IsMessage()); + ASSERT_OK(type_adapter->mutation_apis()->SetField( + "value", CelValue::CreateInt64(10), manager, value)); + + ASSERT_OK( + type_adapter->mutation_apis()->AdaptFromWellKnownType(manager, value)); + + EXPECT_THAT(value, test::IsCelInt64(10)); +} + +// This is an implementation detail, but testing for coverage. +TEST(ProtobufDescriptorProvider, MemoizesAdapters) { + ProtobufDescriptorProvider provider( + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory()); + google::protobuf::Arena arena; + cel::extensions::ProtoMemoryManager manager(&arena); + auto type_adapter = provider.ProvideLegacyType("google.protobuf.Int64Value"); + + ASSERT_TRUE(type_adapter.has_value()); + ASSERT_TRUE(type_adapter->mutation_apis() != nullptr); + + auto type_adapter2 = provider.ProvideLegacyType("google.protobuf.Int64Value"); + ASSERT_TRUE(type_adapter2.has_value()); + + EXPECT_EQ(type_adapter->mutation_apis(), type_adapter2->mutation_apis()); + EXPECT_EQ(type_adapter->access_apis(), type_adapter2->access_apis()); +} + +TEST(ProtobufDescriptorProvider, NotFound) { + ProtobufDescriptorProvider provider( + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory()); + google::protobuf::Arena arena; + cel::extensions::ProtoMemoryManager manager(&arena); + auto type_adapter = provider.ProvideLegacyType("UnknownType"); + + ASSERT_FALSE(type_adapter.has_value()); +} + +} // namespace +} // namespace google::api::expr::runtime From 96fb8dc22adbce4d5eee53301a98971ea20d484e Mon Sep 17 00:00:00 2001 From: jcking Date: Fri, 18 Mar 2022 15:41:11 +0000 Subject: [PATCH 029/111] Internal change PiperOrigin-RevId: 435647665 --- base/handle.h | 14 +++ base/internal/handle.post.h | 18 ++++ base/internal/handle.pre.h | 13 +++ base/internal/memory_manager.post.h | 5 + base/internal/memory_manager.pre.h | 2 + base/memory_manager.cc | 5 + base/memory_manager.h | 5 + base/type_factory.h | 14 +-- base/type_test.cc | 41 ++++----- base/value.cc | 30 +++++- base/value_factory.h | 15 +-- base/value_factory_test.cc | 9 +- base/value_test.cc | 138 +++++++++++++--------------- 13 files changed, 189 insertions(+), 120 deletions(-) diff --git a/base/handle.h b/base/handle.h index f82b797ca..0b096684d 100644 --- a/base/handle.h +++ b/base/handle.h @@ -25,6 +25,8 @@ namespace cel { +class MemoryManager; + template class Transient; @@ -182,6 +184,12 @@ class Transient final : private base_internal::HandlePolicy { friend bool base_internal::IsUnmanagedHandle(const Transient& handle); template friend bool base_internal::IsInlinedHandle(const Transient& handle); + template + friend MemoryManager& base_internal::GetMemoryManager( + const Transient& handle); + template + friend MemoryManager& base_internal::GetMemoryManager( + const Persistent& handle); template explicit Transient(base_internal::HandleInPlace, Args&&... args) @@ -397,6 +405,12 @@ class Persistent final : private base_internal::HandlePolicy { friend bool base_internal::IsUnmanagedHandle(const Persistent& handle); template friend bool base_internal::IsInlinedHandle(const Persistent& handle); + template + friend MemoryManager& base_internal::GetMemoryManager( + const Transient& handle); + template + friend MemoryManager& base_internal::GetMemoryManager( + const Persistent& handle); template explicit Persistent(base_internal::HandleInPlace, Args&&... args) diff --git a/base/internal/handle.post.h b/base/internal/handle.post.h index be57aed18..5fbfc8199 100644 --- a/base/internal/handle.post.h +++ b/base/internal/handle.post.h @@ -137,6 +137,24 @@ bool IsInlinedHandle(const Persistent& handle) { return handle.impl_.IsInlined(); } +template +MemoryManager& GetMemoryManager(const Transient& handle) { + ABSL_ASSERT(IsManagedHandle(handle)); + auto [size, align] = + static_cast(handle.operator->())->SizeAndAlignment(); + return GetMemoryManager(static_cast(handle.operator->()), size, + align); +} + +template +MemoryManager& GetMemoryManager(const Persistent& handle) { + ABSL_ASSERT(IsManagedHandle(handle)); + auto [size, align] = + static_cast(handle.operator->())->SizeAndAlignment(); + return GetMemoryManager(static_cast(handle.operator->()), size, + align); +} + } // namespace cel::base_internal #endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_HANDLE_POST_H_ diff --git a/base/internal/handle.pre.h b/base/internal/handle.pre.h index 423142b58..867b8f59a 100644 --- a/base/internal/handle.pre.h +++ b/base/internal/handle.pre.h @@ -33,6 +33,8 @@ class Transient; template class Persistent; +class MemoryManager; + namespace base_internal { class TypeHandleBase; @@ -70,6 +72,13 @@ struct HandleInPlace { // and Transient. Think std::in_place. inline constexpr HandleInPlace kHandleInPlace{}; +// If IsManagedHandle returns true, get a reference to the memory manager that +// is managing it. +template +MemoryManager& GetMemoryManager(const Transient& handle); +template +MemoryManager& GetMemoryManager(const Persistent& handle); + // Virtual base class for all classes that can be managed by handles. class Resource { public: @@ -85,6 +94,10 @@ class Resource { friend class ValueHandleBase; template friend struct HandleFactory; + template + friend MemoryManager& GetMemoryManager(const Transient& handle); + template + friend MemoryManager& GetMemoryManager(const Persistent& handle); Resource() = default; Resource(const Resource&) = default; diff --git a/base/internal/memory_manager.post.h b/base/internal/memory_manager.post.h index 11da71b3e..3dec55578 100644 --- a/base/internal/memory_manager.post.h +++ b/base/internal/memory_manager.post.h @@ -40,6 +40,11 @@ constexpr T* ManagedMemoryRelease(ManagedMemory& managed_memory) { return ptr; } +inline MemoryManager& GetMemoryManager(const void* pointer, size_t size, + size_t align) { + return MemoryManager::Get(pointer, size, align); +} + } // namespace cel::base_internal #endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MEMORY_MANAGER_POST_H_ diff --git a/base/internal/memory_manager.pre.h b/base/internal/memory_manager.pre.h index 66507a0e1..28ac19541 100644 --- a/base/internal/memory_manager.pre.h +++ b/base/internal/memory_manager.pre.h @@ -40,6 +40,8 @@ constexpr size_t GetManagedMemoryAlignment( template constexpr T* ManagedMemoryRelease(ManagedMemory& managed_memory); +MemoryManager& GetMemoryManager(const void* pointer, size_t size, size_t align); + template class MemoryManagerDestructor final { private: diff --git a/base/memory_manager.cc b/base/memory_manager.cc index 56d9f670f..db2484646 100644 --- a/base/memory_manager.cc +++ b/base/memory_manager.cc @@ -222,6 +222,11 @@ void MemoryManager::DeallocateInternal(void* pointer, size_t size, memory_manager->Deallocate(pointer, size, align); } +MemoryManager& MemoryManager::Get(const void* pointer, size_t size, + size_t align) { + return *GetControlBlock(pointer, size, align)->memory_manager; +} + void MemoryManager::Ref(const void* pointer, size_t size, size_t align) { if (pointer != nullptr && size != 0) { ABSL_ASSERT(absl::has_single_bit(align)); // Assert aligned to power of 2. diff --git a/base/memory_manager.h b/base/memory_manager.h index 85f796ef4..d53cbf074 100644 --- a/base/memory_manager.h +++ b/base/memory_manager.h @@ -190,11 +190,16 @@ class MemoryManager { friend class ManagedMemory; friend class ArenaMemoryManager; friend class base_internal::Resource; + friend MemoryManager& base_internal::GetMemoryManager(const void* pointer, + size_t size, + size_t align); // Only for use by ArenaMemoryManager. explicit MemoryManager(bool allocation_only) : allocation_only_(allocation_only) {} + static MemoryManager& Get(const void* pointer, size_t size, size_t align); + void* AllocateInternal(size_t& size, size_t& align); static void DeallocateInternal(void* pointer, size_t size, size_t align); diff --git a/base/type_factory.h b/base/type_factory.h index 4e74a3654..304d50d87 100644 --- a/base/type_factory.h +++ b/base/type_factory.h @@ -24,9 +24,15 @@ namespace cel { // TypeFactory provides member functions to get and create type implementations // of builtin types. -class TypeFactory { +class TypeFactory final { public: - virtual ~TypeFactory() = default; + explicit TypeFactory( + MemoryManager& memory_manager ABSL_ATTRIBUTE_LIFETIME_BOUND) + : memory_manager_(memory_manager) {} + + TypeFactory(const TypeFactory&) = delete; + + TypeFactory& operator=(const TypeFactory&) = delete; Persistent GetNullType() ABSL_ATTRIBUTE_LIFETIME_BOUND; @@ -55,10 +61,6 @@ class TypeFactory { ABSL_ATTRIBUTE_LIFETIME_BOUND; protected: - // Prevent direct intantiation until more pure virtual methods are added. - explicit TypeFactory(MemoryManager& memory_manager) - : memory_manager_(memory_manager) {} - // Ignore unused for now, as it will be used in the future. ABSL_ATTRIBUTE_UNUSED MemoryManager& memory_manager() const { return memory_manager_; diff --git a/base/type_test.cc b/base/type_test.cc index d6e2045fa..c98d5c0c5 100644 --- a/base/type_test.cc +++ b/base/type_test.cc @@ -28,11 +28,6 @@ namespace { using testing::SizeIs; -class TestTypeFactory final : public TypeFactory { - public: - TestTypeFactory() : TypeFactory(MemoryManager::Global()) {} -}; - template constexpr void IS_INITIALIZED(T&) {} @@ -67,13 +62,13 @@ TEST(Type, PersistentHandleTypeTraits) { } TEST(Type, CopyConstructor) { - TestTypeFactory type_factory; + TypeFactory type_factory(MemoryManager::Global()); Transient type(type_factory.GetIntType()); EXPECT_EQ(type, type_factory.GetIntType()); } TEST(Type, MoveConstructor) { - TestTypeFactory type_factory; + TypeFactory type_factory(MemoryManager::Global()); Transient from(type_factory.GetIntType()); Transient to(std::move(from)); IS_INITIALIZED(from); @@ -82,14 +77,14 @@ TEST(Type, MoveConstructor) { } TEST(Type, CopyAssignment) { - TestTypeFactory type_factory; + TypeFactory type_factory(MemoryManager::Global()); Transient type(type_factory.GetNullType()); type = type_factory.GetIntType(); EXPECT_EQ(type, type_factory.GetIntType()); } TEST(Type, MoveAssignment) { - TestTypeFactory type_factory; + TypeFactory type_factory(MemoryManager::Global()); Transient from(type_factory.GetIntType()); Transient to(type_factory.GetNullType()); to = std::move(from); @@ -99,7 +94,7 @@ TEST(Type, MoveAssignment) { } TEST(Type, Swap) { - TestTypeFactory type_factory; + TypeFactory type_factory(MemoryManager::Global()); Transient lhs = type_factory.GetIntType(); Transient rhs = type_factory.GetUintType(); std::swap(lhs, rhs); @@ -112,7 +107,7 @@ TEST(Type, Swap) { // feature is not available in C++17. TEST(Type, Null) { - TestTypeFactory type_factory; + TypeFactory type_factory(MemoryManager::Global()); EXPECT_EQ(type_factory.GetNullType()->kind(), Kind::kNullType); EXPECT_EQ(type_factory.GetNullType()->name(), "null_type"); EXPECT_THAT(type_factory.GetNullType()->parameters(), SizeIs(0)); @@ -130,7 +125,7 @@ TEST(Type, Null) { } TEST(Type, Error) { - TestTypeFactory type_factory; + TypeFactory type_factory(MemoryManager::Global()); EXPECT_EQ(type_factory.GetErrorType()->kind(), Kind::kError); EXPECT_EQ(type_factory.GetErrorType()->name(), "*error*"); EXPECT_THAT(type_factory.GetErrorType()->parameters(), SizeIs(0)); @@ -148,7 +143,7 @@ TEST(Type, Error) { } TEST(Type, Dyn) { - TestTypeFactory type_factory; + TypeFactory type_factory(MemoryManager::Global()); EXPECT_EQ(type_factory.GetDynType()->kind(), Kind::kDyn); EXPECT_EQ(type_factory.GetDynType()->name(), "dyn"); EXPECT_THAT(type_factory.GetDynType()->parameters(), SizeIs(0)); @@ -166,7 +161,7 @@ TEST(Type, Dyn) { } TEST(Type, Any) { - TestTypeFactory type_factory; + TypeFactory type_factory(MemoryManager::Global()); EXPECT_EQ(type_factory.GetAnyType()->kind(), Kind::kAny); EXPECT_EQ(type_factory.GetAnyType()->name(), "google.protobuf.Any"); EXPECT_THAT(type_factory.GetAnyType()->parameters(), SizeIs(0)); @@ -184,7 +179,7 @@ TEST(Type, Any) { } TEST(Type, Bool) { - TestTypeFactory type_factory; + TypeFactory type_factory(MemoryManager::Global()); EXPECT_EQ(type_factory.GetBoolType()->kind(), Kind::kBool); EXPECT_EQ(type_factory.GetBoolType()->name(), "bool"); EXPECT_THAT(type_factory.GetBoolType()->parameters(), SizeIs(0)); @@ -202,7 +197,7 @@ TEST(Type, Bool) { } TEST(Type, Int) { - TestTypeFactory type_factory; + TypeFactory type_factory(MemoryManager::Global()); EXPECT_EQ(type_factory.GetIntType()->kind(), Kind::kInt); EXPECT_EQ(type_factory.GetIntType()->name(), "int"); EXPECT_THAT(type_factory.GetIntType()->parameters(), SizeIs(0)); @@ -220,7 +215,7 @@ TEST(Type, Int) { } TEST(Type, Uint) { - TestTypeFactory type_factory; + TypeFactory type_factory(MemoryManager::Global()); EXPECT_EQ(type_factory.GetUintType()->kind(), Kind::kUint); EXPECT_EQ(type_factory.GetUintType()->name(), "uint"); EXPECT_THAT(type_factory.GetUintType()->parameters(), SizeIs(0)); @@ -238,7 +233,7 @@ TEST(Type, Uint) { } TEST(Type, Double) { - TestTypeFactory type_factory; + TypeFactory type_factory(MemoryManager::Global()); EXPECT_EQ(type_factory.GetDoubleType()->kind(), Kind::kDouble); EXPECT_EQ(type_factory.GetDoubleType()->name(), "double"); EXPECT_THAT(type_factory.GetDoubleType()->parameters(), SizeIs(0)); @@ -256,7 +251,7 @@ TEST(Type, Double) { } TEST(Type, String) { - TestTypeFactory type_factory; + TypeFactory type_factory(MemoryManager::Global()); EXPECT_EQ(type_factory.GetStringType()->kind(), Kind::kString); EXPECT_EQ(type_factory.GetStringType()->name(), "string"); EXPECT_THAT(type_factory.GetStringType()->parameters(), SizeIs(0)); @@ -274,7 +269,7 @@ TEST(Type, String) { } TEST(Type, Bytes) { - TestTypeFactory type_factory; + TypeFactory type_factory(MemoryManager::Global()); EXPECT_EQ(type_factory.GetBytesType()->kind(), Kind::kBytes); EXPECT_EQ(type_factory.GetBytesType()->name(), "bytes"); EXPECT_THAT(type_factory.GetBytesType()->parameters(), SizeIs(0)); @@ -292,7 +287,7 @@ TEST(Type, Bytes) { } TEST(Type, Duration) { - TestTypeFactory type_factory; + TypeFactory type_factory(MemoryManager::Global()); EXPECT_EQ(type_factory.GetDurationType()->kind(), Kind::kDuration); EXPECT_EQ(type_factory.GetDurationType()->name(), "google.protobuf.Duration"); EXPECT_THAT(type_factory.GetDurationType()->parameters(), SizeIs(0)); @@ -310,7 +305,7 @@ TEST(Type, Duration) { } TEST(Type, Timestamp) { - TestTypeFactory type_factory; + TypeFactory type_factory(MemoryManager::Global()); EXPECT_EQ(type_factory.GetTimestampType()->kind(), Kind::kTimestamp); EXPECT_EQ(type_factory.GetTimestampType()->name(), "google.protobuf.Timestamp"); @@ -329,7 +324,7 @@ TEST(Type, Timestamp) { } TEST(Type, SupportsAbslHash) { - TestTypeFactory type_factory; + TypeFactory type_factory(MemoryManager::Global()); EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly({ Persistent(type_factory.GetNullType()), Persistent(type_factory.GetErrorType()), diff --git a/base/value.cc b/base/value.cc index 9d6ecc948..b7c9d7908 100644 --- a/base/value.cc +++ b/base/value.cc @@ -589,6 +589,14 @@ class HashValueVisitor final { absl::HashState state_; }; +template +bool CanPerformZeroCopy(MemoryManager& memory_manager, + const Transient& handle) { + return base_internal::IsManagedHandle(handle) && + std::addressof(memory_manager) == + std::addressof(base_internal::GetMemoryManager(handle)); +} + } // namespace Persistent BytesValue::Empty(ValueFactory& value_factory) { @@ -598,8 +606,15 @@ Persistent BytesValue::Empty(ValueFactory& value_factory) { absl::StatusOr> BytesValue::Concat( ValueFactory& value_factory, const Transient& lhs, const Transient& rhs) { - absl::Cord cord = lhs->ToCord(base_internal::IsManagedHandle(lhs)); - cord.Append(rhs->ToCord(base_internal::IsManagedHandle(rhs))); + absl::Cord cord; + // We can only use the potential zero-copy path if the memory managers are + // the same. Otherwise we need to escape the original memory manager scope. + cord.Append( + lhs->ToCord(CanPerformZeroCopy(value_factory.memory_manager(), lhs))); + // We can only use the potential zero-copy path if the memory managers are + // the same. Otherwise we need to escape the original memory manager scope. + cord.Append( + rhs->ToCord(CanPerformZeroCopy(value_factory.memory_manager(), rhs))); return value_factory.CreateBytesValue(std::move(cord)); } @@ -665,8 +680,15 @@ Persistent StringValue::Empty(ValueFactory& value_factory) { absl::StatusOr> StringValue::Concat( ValueFactory& value_factory, const Transient& lhs, const Transient& rhs) { - absl::Cord cord = lhs->ToCord(base_internal::IsManagedHandle(lhs)); - cord.Append(rhs->ToCord(base_internal::IsManagedHandle(rhs))); + absl::Cord cord; + // We can only use the potential zero-copy path if the memory managers are + // the same. Otherwise we need to escape the original memory manager scope. + cord.Append( + lhs->ToCord(CanPerformZeroCopy(value_factory.memory_manager(), lhs))); + // We can only use the potential zero-copy path if the memory managers are + // the same. Otherwise we need to escape the original memory manager scope. + cord.Append( + rhs->ToCord(CanPerformZeroCopy(value_factory.memory_manager(), rhs))); size_t size = 0; size_t lhs_size = lhs->size_.load(std::memory_order_relaxed); if (lhs_size != 0 && !lhs->empty()) { diff --git a/base/value_factory.h b/base/value_factory.h index ab9ce7559..6522cd347 100644 --- a/base/value_factory.h +++ b/base/value_factory.h @@ -32,9 +32,15 @@ namespace cel { -class ValueFactory { +class ValueFactory final { public: - virtual ~ValueFactory() = default; + explicit ValueFactory( + MemoryManager& memory_manager ABSL_ATTRIBUTE_LIFETIME_BOUND) + : memory_manager_(memory_manager) {} + + ValueFactory(const ValueFactory&) = delete; + + ValueFactory& operator=(const ValueFactory&) = delete; Persistent GetNullValue() ABSL_ATTRIBUTE_LIFETIME_BOUND; @@ -128,13 +134,10 @@ class ValueFactory { absl::Time value) ABSL_ATTRIBUTE_LIFETIME_BOUND; protected: - // Prevent direct intantiation until more pure virtual methods are added. - explicit ValueFactory(MemoryManager& memory_manager) - : memory_manager_(memory_manager) {} - MemoryManager& memory_manager() const { return memory_manager_; } private: + friend class BytesValue; friend class StringValue; Persistent GetEmptyBytesValue() diff --git a/base/value_factory_test.cc b/base/value_factory_test.cc index d873bbd50..171f0f360 100644 --- a/base/value_factory_test.cc +++ b/base/value_factory_test.cc @@ -23,19 +23,14 @@ namespace { using cel::internal::StatusIs; -class TestValueFactory final : public ValueFactory { - public: - TestValueFactory() : ValueFactory(MemoryManager::Global()) {} -}; - TEST(ValueFactory, CreateErrorValueReplacesOk) { - TestValueFactory value_factory; + ValueFactory value_factory(MemoryManager::Global()); EXPECT_THAT(value_factory.CreateErrorValue(absl::OkStatus())->value(), StatusIs(absl::StatusCode::kUnknown)); } TEST(ValueFactory, CreateStringValueIllegalByteSequence) { - TestValueFactory value_factory; + ValueFactory value_factory(MemoryManager::Global()); EXPECT_THAT(value_factory.CreateStringValue("\xff"), StatusIs(absl::StatusCode::kInvalidArgument)); EXPECT_THAT(value_factory.CreateStringValue(absl::Cord("\xff")), diff --git a/base/value_test.cc b/base/value_test.cc index 92914fa8c..da3b305cf 100644 --- a/base/value_test.cc +++ b/base/value_test.cc @@ -44,16 +44,6 @@ Persistent Must(absl::StatusOr> status_or_handle) { return std::move(status_or_handle).value(); } -class TestTypeFactory final : public TypeFactory { - public: - TestTypeFactory() : TypeFactory(MemoryManager::Global()) {} -}; - -class TestValueFactory final : public ValueFactory { - public: - TestValueFactory() : ValueFactory(MemoryManager::Global()) {} -}; - template constexpr void IS_INITIALIZED(T&) {} @@ -94,7 +84,7 @@ TEST(Value, PersistentHandleTypeTraits) { } TEST(Value, DefaultConstructor) { - TestValueFactory value_factory; + ValueFactory value_factory(MemoryManager::Global()); Transient value; EXPECT_EQ(value, value_factory.GetNullValue()); } @@ -109,7 +99,7 @@ using ConstructionAssignmentTest = TEST_P(ConstructionAssignmentTest, CopyConstructor) { const auto& test_case = GetParam(); - TestValueFactory value_factory; + ValueFactory value_factory(MemoryManager::Global()); Persistent from(test_case.default_value(value_factory)); Persistent to(from); IS_INITIALIZED(to); @@ -118,7 +108,7 @@ TEST_P(ConstructionAssignmentTest, CopyConstructor) { TEST_P(ConstructionAssignmentTest, MoveConstructor) { const auto& test_case = GetParam(); - TestValueFactory value_factory; + ValueFactory value_factory(MemoryManager::Global()); Persistent from(test_case.default_value(value_factory)); Persistent to(std::move(from)); IS_INITIALIZED(from); @@ -128,7 +118,7 @@ TEST_P(ConstructionAssignmentTest, MoveConstructor) { TEST_P(ConstructionAssignmentTest, CopyAssignment) { const auto& test_case = GetParam(); - TestValueFactory value_factory; + ValueFactory value_factory(MemoryManager::Global()); Persistent from(test_case.default_value(value_factory)); Persistent to; to = from; @@ -137,7 +127,7 @@ TEST_P(ConstructionAssignmentTest, CopyAssignment) { TEST_P(ConstructionAssignmentTest, MoveAssignment) { const auto& test_case = GetParam(); - TestValueFactory value_factory; + ValueFactory value_factory(MemoryManager::Global()); Persistent from(test_case.default_value(value_factory)); Persistent to; to = std::move(from); @@ -191,7 +181,7 @@ INSTANTIATE_TEST_SUITE_P( }); TEST(Value, Swap) { - TestValueFactory value_factory; + ValueFactory value_factory(MemoryManager::Global()); Persistent lhs = value_factory.CreateIntValue(0); Persistent rhs = value_factory.CreateUintValue(0); std::swap(lhs, rhs); @@ -200,18 +190,18 @@ TEST(Value, Swap) { } TEST(NullValue, DebugString) { - TestValueFactory value_factory; + ValueFactory value_factory(MemoryManager::Global()); EXPECT_EQ(value_factory.GetNullValue()->DebugString(), "null"); } TEST(BoolValue, DebugString) { - TestValueFactory value_factory; + ValueFactory value_factory(MemoryManager::Global()); EXPECT_EQ(value_factory.CreateBoolValue(false)->DebugString(), "false"); EXPECT_EQ(value_factory.CreateBoolValue(true)->DebugString(), "true"); } TEST(IntValue, DebugString) { - TestValueFactory value_factory; + ValueFactory value_factory(MemoryManager::Global()); EXPECT_EQ(value_factory.CreateIntValue(-1)->DebugString(), "-1"); EXPECT_EQ(value_factory.CreateIntValue(0)->DebugString(), "0"); EXPECT_EQ(value_factory.CreateIntValue(1)->DebugString(), "1"); @@ -224,7 +214,7 @@ TEST(IntValue, DebugString) { } TEST(UintValue, DebugString) { - TestValueFactory value_factory; + ValueFactory value_factory(MemoryManager::Global()); EXPECT_EQ(value_factory.CreateUintValue(0)->DebugString(), "0u"); EXPECT_EQ(value_factory.CreateUintValue(1)->DebugString(), "1u"); EXPECT_EQ(value_factory.CreateUintValue(std::numeric_limits::max()) @@ -233,7 +223,7 @@ TEST(UintValue, DebugString) { } TEST(DoubleValue, DebugString) { - TestValueFactory value_factory; + ValueFactory value_factory(MemoryManager::Global()); EXPECT_EQ(value_factory.CreateDoubleValue(-1.0)->DebugString(), "-1.0"); EXPECT_EQ(value_factory.CreateDoubleValue(0.0)->DebugString(), "0.0"); EXPECT_EQ(value_factory.CreateDoubleValue(1.0)->DebugString(), "1.0"); @@ -266,13 +256,13 @@ TEST(DoubleValue, DebugString) { } TEST(DurationValue, DebugString) { - TestValueFactory value_factory; + ValueFactory value_factory(MemoryManager::Global()); EXPECT_EQ(DurationValue::Zero(value_factory)->DebugString(), internal::FormatDuration(absl::ZeroDuration()).value()); } TEST(TimestampValue, DebugString) { - TestValueFactory value_factory; + ValueFactory value_factory(MemoryManager::Global()); EXPECT_EQ(TimestampValue::UnixEpoch(value_factory)->DebugString(), internal::FormatTimestamp(absl::UnixEpoch()).value()); } @@ -282,8 +272,8 @@ TEST(TimestampValue, DebugString) { // feature is not available in C++17. TEST(Value, Error) { - TestValueFactory value_factory; - TestTypeFactory type_factory; + ValueFactory value_factory(MemoryManager::Global()); + TypeFactory type_factory(MemoryManager::Global()); auto error_value = value_factory.CreateErrorValue(absl::CancelledError()); EXPECT_TRUE(error_value.Is()); EXPECT_FALSE(error_value.Is()); @@ -294,8 +284,8 @@ TEST(Value, Error) { } TEST(Value, Bool) { - TestValueFactory value_factory; - TestTypeFactory type_factory; + ValueFactory value_factory(MemoryManager::Global()); + TypeFactory type_factory(MemoryManager::Global()); auto false_value = BoolValue::False(value_factory); EXPECT_TRUE(false_value.Is()); EXPECT_FALSE(false_value.Is()); @@ -319,8 +309,8 @@ TEST(Value, Bool) { } TEST(Value, Int) { - TestValueFactory value_factory; - TestTypeFactory type_factory; + ValueFactory value_factory(MemoryManager::Global()); + TypeFactory type_factory(MemoryManager::Global()); auto zero_value = value_factory.CreateIntValue(0); EXPECT_TRUE(zero_value.Is()); EXPECT_FALSE(zero_value.Is()); @@ -344,8 +334,8 @@ TEST(Value, Int) { } TEST(Value, Uint) { - TestValueFactory value_factory; - TestTypeFactory type_factory; + ValueFactory value_factory(MemoryManager::Global()); + TypeFactory type_factory(MemoryManager::Global()); auto zero_value = value_factory.CreateUintValue(0); EXPECT_TRUE(zero_value.Is()); EXPECT_FALSE(zero_value.Is()); @@ -369,8 +359,8 @@ TEST(Value, Uint) { } TEST(Value, Double) { - TestValueFactory value_factory; - TestTypeFactory type_factory; + ValueFactory value_factory(MemoryManager::Global()); + TypeFactory type_factory(MemoryManager::Global()); auto zero_value = value_factory.CreateDoubleValue(0.0); EXPECT_TRUE(zero_value.Is()); EXPECT_FALSE(zero_value.Is()); @@ -394,8 +384,8 @@ TEST(Value, Double) { } TEST(Value, Duration) { - TestValueFactory value_factory; - TestTypeFactory type_factory; + ValueFactory value_factory(MemoryManager::Global()); + TypeFactory type_factory(MemoryManager::Global()); auto zero_value = Must(value_factory.CreateDurationValue(absl::ZeroDuration())); EXPECT_TRUE(zero_value.Is()); @@ -424,8 +414,8 @@ TEST(Value, Duration) { } TEST(Value, Timestamp) { - TestValueFactory value_factory; - TestTypeFactory type_factory; + ValueFactory value_factory(MemoryManager::Global()); + TypeFactory type_factory(MemoryManager::Global()); auto zero_value = Must(value_factory.CreateTimestampValue(absl::UnixEpoch())); EXPECT_TRUE(zero_value.Is()); EXPECT_FALSE(zero_value.Is()); @@ -453,8 +443,8 @@ TEST(Value, Timestamp) { } TEST(Value, BytesFromString) { - TestValueFactory value_factory; - TestTypeFactory type_factory; + ValueFactory value_factory(MemoryManager::Global()); + TypeFactory type_factory(MemoryManager::Global()); auto zero_value = Must(value_factory.CreateBytesValue(std::string("0"))); EXPECT_TRUE(zero_value.Is()); EXPECT_FALSE(zero_value.Is()); @@ -478,8 +468,8 @@ TEST(Value, BytesFromString) { } TEST(Value, BytesFromStringView) { - TestValueFactory value_factory; - TestTypeFactory type_factory; + ValueFactory value_factory(MemoryManager::Global()); + TypeFactory type_factory(MemoryManager::Global()); auto zero_value = Must(value_factory.CreateBytesValue(absl::string_view("0"))); EXPECT_TRUE(zero_value.Is()); @@ -506,8 +496,8 @@ TEST(Value, BytesFromStringView) { } TEST(Value, BytesFromCord) { - TestValueFactory value_factory; - TestTypeFactory type_factory; + ValueFactory value_factory(MemoryManager::Global()); + TypeFactory type_factory(MemoryManager::Global()); auto zero_value = Must(value_factory.CreateBytesValue(absl::Cord("0"))); EXPECT_TRUE(zero_value.Is()); EXPECT_FALSE(zero_value.Is()); @@ -531,8 +521,8 @@ TEST(Value, BytesFromCord) { } TEST(Value, BytesFromLiteral) { - TestValueFactory value_factory; - TestTypeFactory type_factory; + ValueFactory value_factory(MemoryManager::Global()); + TypeFactory type_factory(MemoryManager::Global()); auto zero_value = Must(value_factory.CreateBytesValue("0")); EXPECT_TRUE(zero_value.Is()); EXPECT_FALSE(zero_value.Is()); @@ -556,8 +546,8 @@ TEST(Value, BytesFromLiteral) { } TEST(Value, BytesFromExternal) { - TestValueFactory value_factory; - TestTypeFactory type_factory; + ValueFactory value_factory(MemoryManager::Global()); + TypeFactory type_factory(MemoryManager::Global()); auto zero_value = Must(value_factory.CreateBytesValue("0", []() {})); EXPECT_TRUE(zero_value.Is()); EXPECT_FALSE(zero_value.Is()); @@ -581,8 +571,8 @@ TEST(Value, BytesFromExternal) { } TEST(Value, StringFromString) { - TestValueFactory value_factory; - TestTypeFactory type_factory; + ValueFactory value_factory(MemoryManager::Global()); + TypeFactory type_factory(MemoryManager::Global()); auto zero_value = Must(value_factory.CreateStringValue(std::string("0"))); EXPECT_TRUE(zero_value.Is()); EXPECT_FALSE(zero_value.Is()); @@ -607,8 +597,8 @@ TEST(Value, StringFromString) { } TEST(Value, StringFromStringView) { - TestValueFactory value_factory; - TestTypeFactory type_factory; + ValueFactory value_factory(MemoryManager::Global()); + TypeFactory type_factory(MemoryManager::Global()); auto zero_value = Must(value_factory.CreateStringValue(absl::string_view("0"))); EXPECT_TRUE(zero_value.Is()); @@ -636,8 +626,8 @@ TEST(Value, StringFromStringView) { } TEST(Value, StringFromCord) { - TestValueFactory value_factory; - TestTypeFactory type_factory; + ValueFactory value_factory(MemoryManager::Global()); + TypeFactory type_factory(MemoryManager::Global()); auto zero_value = Must(value_factory.CreateStringValue(absl::Cord("0"))); EXPECT_TRUE(zero_value.Is()); EXPECT_FALSE(zero_value.Is()); @@ -661,8 +651,8 @@ TEST(Value, StringFromCord) { } TEST(Value, StringFromLiteral) { - TestValueFactory value_factory; - TestTypeFactory type_factory; + ValueFactory value_factory(MemoryManager::Global()); + TypeFactory type_factory(MemoryManager::Global()); auto zero_value = Must(value_factory.CreateStringValue("0")); EXPECT_TRUE(zero_value.Is()); EXPECT_FALSE(zero_value.Is()); @@ -686,8 +676,8 @@ TEST(Value, StringFromLiteral) { } TEST(Value, StringFromExternal) { - TestValueFactory value_factory; - TestTypeFactory type_factory; + ValueFactory value_factory(MemoryManager::Global()); + TypeFactory type_factory(MemoryManager::Global()); auto zero_value = Must(value_factory.CreateStringValue("0", []() {})); EXPECT_TRUE(zero_value.Is()); EXPECT_FALSE(zero_value.Is()); @@ -734,7 +724,7 @@ using BytesConcatTest = testing::TestWithParam; TEST_P(BytesConcatTest, Concat) { const BytesConcatTestCase& test_case = GetParam(); - TestValueFactory value_factory; + ValueFactory value_factory(MemoryManager::Global()); EXPECT_TRUE( Must(BytesValue::Concat(value_factory, MakeStringBytes(value_factory, test_case.lhs), @@ -805,7 +795,7 @@ using BytesSizeTest = testing::TestWithParam; TEST_P(BytesSizeTest, Size) { const BytesSizeTestCase& test_case = GetParam(); - TestValueFactory value_factory; + ValueFactory value_factory(MemoryManager::Global()); EXPECT_EQ(MakeStringBytes(value_factory, test_case.data)->size(), test_case.size); EXPECT_EQ(MakeCordBytes(value_factory, test_case.data)->size(), @@ -831,7 +821,7 @@ using BytesEmptyTest = testing::TestWithParam; TEST_P(BytesEmptyTest, Empty) { const BytesEmptyTestCase& test_case = GetParam(); - TestValueFactory value_factory; + ValueFactory value_factory(MemoryManager::Global()); EXPECT_EQ(MakeStringBytes(value_factory, test_case.data)->empty(), test_case.empty); EXPECT_EQ(MakeCordBytes(value_factory, test_case.data)->empty(), @@ -857,7 +847,7 @@ using BytesEqualsTest = testing::TestWithParam; TEST_P(BytesEqualsTest, Equals) { const BytesEqualsTestCase& test_case = GetParam(); - TestValueFactory value_factory; + ValueFactory value_factory(MemoryManager::Global()); EXPECT_EQ(MakeStringBytes(value_factory, test_case.lhs) ->Equals(MakeStringBytes(value_factory, test_case.rhs)), test_case.equals); @@ -913,7 +903,7 @@ int NormalizeCompareResult(int compare) { return std::clamp(compare, -1, 1); } TEST_P(BytesCompareTest, Equals) { const BytesCompareTestCase& test_case = GetParam(); - TestValueFactory value_factory; + ValueFactory value_factory(MemoryManager::Global()); EXPECT_EQ(NormalizeCompareResult( MakeStringBytes(value_factory, test_case.lhs) ->Compare(MakeStringBytes(value_factory, test_case.rhs))), @@ -974,7 +964,7 @@ using BytesDebugStringTest = testing::TestWithParam; TEST_P(BytesDebugStringTest, ToCord) { const BytesDebugStringTestCase& test_case = GetParam(); - TestValueFactory value_factory; + ValueFactory value_factory(MemoryManager::Global()); EXPECT_EQ(MakeStringBytes(value_factory, test_case.data)->DebugString(), internal::FormatBytesLiteral(test_case.data)); EXPECT_EQ(MakeCordBytes(value_factory, test_case.data)->DebugString(), @@ -999,7 +989,7 @@ using BytesToStringTest = testing::TestWithParam; TEST_P(BytesToStringTest, ToString) { const BytesToStringTestCase& test_case = GetParam(); - TestValueFactory value_factory; + ValueFactory value_factory(MemoryManager::Global()); EXPECT_EQ(MakeStringBytes(value_factory, test_case.data)->ToString(), test_case.data); EXPECT_EQ(MakeCordBytes(value_factory, test_case.data)->ToString(), @@ -1024,7 +1014,7 @@ using BytesToCordTest = testing::TestWithParam; TEST_P(BytesToCordTest, ToCord) { const BytesToCordTestCase& test_case = GetParam(); - TestValueFactory value_factory; + ValueFactory value_factory(MemoryManager::Global()); EXPECT_EQ(MakeStringBytes(value_factory, test_case.data)->ToCord(), test_case.data); EXPECT_EQ(MakeCordBytes(value_factory, test_case.data)->ToCord(), @@ -1065,7 +1055,7 @@ using StringConcatTest = testing::TestWithParam; TEST_P(StringConcatTest, Concat) { const StringConcatTestCase& test_case = GetParam(); - TestValueFactory value_factory; + ValueFactory value_factory(MemoryManager::Global()); EXPECT_TRUE( Must(StringValue::Concat(value_factory, MakeStringString(value_factory, test_case.lhs), @@ -1136,7 +1126,7 @@ using StringSizeTest = testing::TestWithParam; TEST_P(StringSizeTest, Size) { const StringSizeTestCase& test_case = GetParam(); - TestValueFactory value_factory; + ValueFactory value_factory(MemoryManager::Global()); EXPECT_EQ(MakeStringString(value_factory, test_case.data)->size(), test_case.size); EXPECT_EQ(MakeCordString(value_factory, test_case.data)->size(), @@ -1162,7 +1152,7 @@ using StringEmptyTest = testing::TestWithParam; TEST_P(StringEmptyTest, Empty) { const StringEmptyTestCase& test_case = GetParam(); - TestValueFactory value_factory; + ValueFactory value_factory(MemoryManager::Global()); EXPECT_EQ(MakeStringString(value_factory, test_case.data)->empty(), test_case.empty); EXPECT_EQ(MakeCordString(value_factory, test_case.data)->empty(), @@ -1188,7 +1178,7 @@ using StringEqualsTest = testing::TestWithParam; TEST_P(StringEqualsTest, Equals) { const StringEqualsTestCase& test_case = GetParam(); - TestValueFactory value_factory; + ValueFactory value_factory(MemoryManager::Global()); EXPECT_EQ(MakeStringString(value_factory, test_case.lhs) ->Equals(MakeStringString(value_factory, test_case.rhs)), test_case.equals); @@ -1242,7 +1232,7 @@ using StringCompareTest = testing::TestWithParam; TEST_P(StringCompareTest, Equals) { const StringCompareTestCase& test_case = GetParam(); - TestValueFactory value_factory; + ValueFactory value_factory(MemoryManager::Global()); EXPECT_EQ(NormalizeCompareResult( MakeStringString(value_factory, test_case.lhs) ->Compare(MakeStringString(value_factory, test_case.rhs))), @@ -1305,7 +1295,7 @@ using StringDebugStringTest = testing::TestWithParam; TEST_P(StringDebugStringTest, ToCord) { const StringDebugStringTestCase& test_case = GetParam(); - TestValueFactory value_factory; + ValueFactory value_factory(MemoryManager::Global()); EXPECT_EQ(MakeStringString(value_factory, test_case.data)->DebugString(), internal::FormatStringLiteral(test_case.data)); EXPECT_EQ(MakeCordString(value_factory, test_case.data)->DebugString(), @@ -1330,7 +1320,7 @@ using StringToStringTest = testing::TestWithParam; TEST_P(StringToStringTest, ToString) { const StringToStringTestCase& test_case = GetParam(); - TestValueFactory value_factory; + ValueFactory value_factory(MemoryManager::Global()); EXPECT_EQ(MakeStringString(value_factory, test_case.data)->ToString(), test_case.data); EXPECT_EQ(MakeCordString(value_factory, test_case.data)->ToString(), @@ -1355,7 +1345,7 @@ using StringToCordTest = testing::TestWithParam; TEST_P(StringToCordTest, ToCord) { const StringToCordTestCase& test_case = GetParam(); - TestValueFactory value_factory; + ValueFactory value_factory(MemoryManager::Global()); EXPECT_EQ(MakeStringString(value_factory, test_case.data)->ToCord(), test_case.data); EXPECT_EQ(MakeCordString(value_factory, test_case.data)->ToCord(), @@ -1373,7 +1363,7 @@ INSTANTIATE_TEST_SUITE_P(StringToCordTest, StringToCordTest, })); TEST(Value, SupportsAbslHash) { - TestValueFactory value_factory; + ValueFactory value_factory(MemoryManager::Global()); EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly({ Persistent(value_factory.GetNullValue()), Persistent( From b1947263c570253f1bc47c4c08efcb3bf4922d66 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Fri, 18 Mar 2022 17:25:16 +0000 Subject: [PATCH 030/111] internal change PiperOrigin-RevId: 435672768 --- eval/tests/BUILD | 1 + eval/tests/allocation_benchmark_test.cc | 41 +++++++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/eval/tests/BUILD b/eval/tests/BUILD index d3908d152..4c80f6b19 100644 --- a/eval/tests/BUILD +++ b/eval/tests/BUILD @@ -71,6 +71,7 @@ cc_test( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto", "@com_google_protobuf//:protobuf", ], ) diff --git a/eval/tests/allocation_benchmark_test.cc b/eval/tests/allocation_benchmark_test.cc index 20bd0849a..26bd41100 100644 --- a/eval/tests/allocation_benchmark_test.cc +++ b/eval/tests/allocation_benchmark_test.cc @@ -15,6 +15,7 @@ #include #include "google/api/expr/v1alpha1/syntax.pb.h" +#include "google/rpc/context/attribute_context.pb.h" #include "google/protobuf/text_format.h" #include "absl/base/attributes.h" #include "absl/container/btree_map.h" @@ -196,6 +197,46 @@ static void BM_AllocateMessage(benchmark::State& state) { BENCHMARK(BM_AllocateMessage); +static void BM_AllocateLargeMessage(benchmark::State& state) { + // Make sure attribute context is loaded in the generated descriptor pool. + rpc::context::AttributeContext context; + static_cast(context); + + google::protobuf::Arena arena; + std::string expr(R"( + google.rpc.context.AttributeContext{ + source: google.rpc.context.AttributeContext.Peer{ + ip: '192.168.0.1', + port: 1025, + labels: {"abc": "123", "def": "456"} + }, + request: google.rpc.context.AttributeContext.Request{ + method: 'GET', + path: 'root', + host: 'www.example.com' + }, + resource: google.rpc.context.AttributeContext.Resource{ + labels: {"abc": "123", "def": "456"}, + } + })"); + auto builder = CreateCelExpressionBuilder(); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(expr)); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + for (auto _ : state) { + Activation activation; + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsMessage()); + } +} + +BENCHMARK(BM_AllocateLargeMessage); + static void BM_AllocateList(benchmark::State& state) { google::protobuf::Arena arena; std::string expr("[1, 2, 3, 4]"); From ba76d687bebaab8cad21a5d010e56a2bd9af9f22 Mon Sep 17 00:00:00 2001 From: jcking Date: Fri, 18 Mar 2022 17:27:17 +0000 Subject: [PATCH 031/111] Internal change PiperOrigin-RevId: 435673258 --- base/BUILD | 4 ++++ base/type_factory.h | 12 ++++++++++-- base/type_manager.h | 32 ++++++++++++++++++++++++++++++++ base/type_provider.h | 15 +++++++++++++++ base/type_registry.h | 27 +++++++++++++++++++++++++++ base/value_factory.h | 4 +++- 6 files changed, 91 insertions(+), 3 deletions(-) create mode 100644 base/type_manager.h create mode 100644 base/type_registry.h diff --git a/base/BUILD b/base/BUILD index de01486b1..be8548188 100644 --- a/base/BUILD +++ b/base/BUILD @@ -104,7 +104,9 @@ cc_library( hdrs = [ "type.h", "type_factory.h", + "type_manager.h", "type_provider.h", + "type_registry.h", ], deps = [ ":handle", @@ -114,6 +116,8 @@ cc_library( "//internal:no_destructor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/hash", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], diff --git a/base/type_factory.h b/base/type_factory.h index 304d50d87..5f578a9e1 100644 --- a/base/type_factory.h +++ b/base/type_factory.h @@ -24,14 +24,18 @@ namespace cel { // TypeFactory provides member functions to get and create type implementations // of builtin types. -class TypeFactory final { +// +// While TypeFactory is not final and has a virtual destructor, inheriting it is +// forbidden outside of the CEL codebase. +class TypeFactory { public: explicit TypeFactory( MemoryManager& memory_manager ABSL_ATTRIBUTE_LIFETIME_BOUND) : memory_manager_(memory_manager) {} - TypeFactory(const TypeFactory&) = delete; + virtual ~TypeFactory() = default; + TypeFactory(const TypeFactory&) = delete; TypeFactory& operator=(const TypeFactory&) = delete; Persistent GetNullType() ABSL_ATTRIBUTE_LIFETIME_BOUND; @@ -60,6 +64,10 @@ class TypeFactory final { Persistent GetTimestampType() ABSL_ATTRIBUTE_LIFETIME_BOUND; + // TODO(issues/5): Add CreateStructType(Args...) + // and CreateEnumType(Args...) which returns + // Persistent + protected: // Ignore unused for now, as it will be used in the future. ABSL_ATTRIBUTE_UNUSED MemoryManager& memory_manager() const { diff --git a/base/type_manager.h b/base/type_manager.h new file mode 100644 index 000000000..e18f30f27 --- /dev/null +++ b/base/type_manager.h @@ -0,0 +1,32 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_TYPE_MANAGER_H_ +#define THIRD_PARTY_CEL_CPP_BASE_TYPE_MANAGER_H_ + +#include "base/type_factory.h" +#include "base/type_registry.h" + +namespace cel { + +// TypeManager is a union of the TypeFactory and TypeRegistry, allowing for both +// the instantiation of type implementations, loading of type implementations, +// and registering type implementations. +// +// TODO(issues/5): more comments after solidifying role +class TypeManager : public TypeFactory, public TypeRegistry {}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_TYPE_MANAGER_H_ diff --git a/base/type_provider.h b/base/type_provider.h index 6bbc5ab7d..8a481801c 100644 --- a/base/type_provider.h +++ b/base/type_provider.h @@ -15,6 +15,12 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_TYPE_PROVIDER_H_ #define THIRD_PARTY_CEL_CPP_BASE_TYPE_PROVIDER_H_ +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "base/handle.h" +#include "base/type.h" + namespace cel { class TypeFactory; @@ -32,6 +38,15 @@ class TypeFactory; class TypeProvider { public: virtual ~TypeProvider() = default; + + // Return a persistent handle to a Type for the fully qualified type name, if + // available. + // + // An empty handle is returned if the provider cannot find the requested type. + virtual absl::StatusOr> ProvideType( + TypeFactory& type_factory, absl::string_view name) const { + return absl::UnimplementedError("ProvideType is not yet implemented"); + } }; } // namespace cel diff --git a/base/type_registry.h b/base/type_registry.h new file mode 100644 index 000000000..3f5e21333 --- /dev/null +++ b/base/type_registry.h @@ -0,0 +1,27 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_TYPE_REGISTRY_H_ +#define THIRD_PARTY_CEL_CPP_BASE_TYPE_REGISTRY_H_ + +#include "base/type_provider.h" + +namespace cel { + +// TODO(issues/5): define interface and consolidate with CelTypeRegistry +class TypeRegistry : public TypeProvider {}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_BASE_TYPE_REGISTRY_H_ diff --git a/base/value_factory.h b/base/value_factory.h index 6522cd347..24b9e6172 100644 --- a/base/value_factory.h +++ b/base/value_factory.h @@ -39,7 +39,6 @@ class ValueFactory final { : memory_manager_(memory_manager) {} ValueFactory(const ValueFactory&) = delete; - ValueFactory& operator=(const ValueFactory&) = delete; Persistent GetNullValue() ABSL_ATTRIBUTE_LIFETIME_BOUND; @@ -133,6 +132,9 @@ class ValueFactory final { absl::StatusOr> CreateTimestampValue( absl::Time value) ABSL_ATTRIBUTE_LIFETIME_BOUND; + // TODO(issues/5): Add CreateStructType(Args...) and + // CreateEnumType(Args...) which returns Persistent + protected: MemoryManager& memory_manager() const { return memory_manager_; } From 7e19b96082d8f814bcb4f376ccbf4c8dd3a048d2 Mon Sep 17 00:00:00 2001 From: tswadell Date: Mon, 21 Mar 2022 17:23:47 +0000 Subject: [PATCH 032/111] Edge case fixes for heterogeneous equality PiperOrigin-RevId: 436242601 --- eval/public/BUILD | 2 + eval/public/builtin_func_registrar.cc | 71 +++++++++++++++++++----- eval/public/builtin_func_test.cc | 24 +++++--- eval/public/comparison_functions.cc | 17 +++++- eval/public/comparison_functions_test.cc | 16 ++++++ 5 files changed, 107 insertions(+), 23 deletions(-) diff --git a/eval/public/BUILD b/eval/public/BUILD index 3c0a0fce5..448c8a220 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -208,6 +208,7 @@ cc_library( ":cel_function", ":cel_function_adapter", ":cel_function_registry", + ":cel_number", ":cel_options", ":cel_value", ":comparison_functions", @@ -222,6 +223,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", + "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf", "@com_googlesource_code_re2//:re2", ], diff --git a/eval/public/builtin_func_registrar.cc b/eval/public/builtin_func_registrar.cc index 52390d148..e867c7608 100644 --- a/eval/public/builtin_func_registrar.cc +++ b/eval/public/builtin_func_registrar.cc @@ -29,10 +29,12 @@ #include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" +#include "absl/types/optional.h" #include "eval/eval/mutable_list_impl.h" #include "eval/public/cel_builtins.h" #include "eval/public/cel_function_adapter.h" #include "eval/public/cel_function_registry.h" +#include "eval/public/cel_number.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/comparison_functions.h" @@ -555,34 +557,70 @@ absl::Status RegisterSetMembershipFunctions(CelFunctionRegistry* registry, auto boolKeyInSet = [](Arena* arena, bool key, const CelMap* cel_map) -> CelValue { const auto& result = cel_map->Has(CelValue::CreateBool(key)); - if (!result.ok()) { - return CreateErrorValue(arena, result.status()); + if (result.ok()) { + return CelValue::CreateBool(*result); } - return CelValue::CreateBool(*result); + return CelValue::CreateBool(false); + }; + auto doubleKeyInSet = [](Arena* arena, double key, + const CelMap* cel_map) -> CelValue { + absl::optional number = + GetNumberFromCelValue(CelValue::CreateDouble(key)); + if (number->LosslessConvertibleToInt()) { + const auto& result = cel_map->Has(CelValue::CreateInt64(number->AsInt())); + if (result.ok() && *result) { + return CelValue::CreateBool(*result); + } + } + if (number->LosslessConvertibleToUint()) { + const auto& result = + cel_map->Has(CelValue::CreateUint64(number->AsUint())); + if (result.ok() && *result) { + return CelValue::CreateBool(*result); + } + } + return CelValue::CreateBool(false); }; auto intKeyInSet = [](Arena* arena, int64_t key, const CelMap* cel_map) -> CelValue { - const auto& result = cel_map->Has(CelValue::CreateInt64(key)); - if (!result.ok()) { - return CreateErrorValue(arena, result.status()); + CelValue int_key = CelValue::CreateInt64(key); + const auto& result = cel_map->Has(int_key); + if (result.ok() && *result) { + return CelValue::CreateBool(*result); } - return CelValue::CreateBool(*result); + absl::optional number = GetNumberFromCelValue(int_key); + if (number->LosslessConvertibleToUint()) { + const auto& result = + cel_map->Has(CelValue::CreateUint64(number->AsUint())); + if (result.ok() && *result) { + return CelValue::CreateBool(*result); + } + } + return CelValue::CreateBool(false); }; auto stringKeyInSet = [](Arena* arena, CelValue::StringHolder key, const CelMap* cel_map) -> CelValue { const auto& result = cel_map->Has(CelValue::CreateString(key)); - if (!result.ok()) { - return CreateErrorValue(arena, result.status()); + if (result.ok()) { + return CelValue::CreateBool(*result); } - return CelValue::CreateBool(*result); + return CelValue::CreateBool(false); }; auto uintKeyInSet = [](Arena* arena, uint64_t key, const CelMap* cel_map) -> CelValue { - const auto& result = cel_map->Has(CelValue::CreateUint64(key)); - if (!result.ok()) { - return CreateErrorValue(arena, result.status()); + CelValue uint_key = CelValue::CreateUint64(key); + const auto& result = cel_map->Has(uint_key); + if (result.ok() && *result) { + return CelValue::CreateBool(*result); + } + absl::optional number = GetNumberFromCelValue(uint_key); + if (number->LosslessConvertibleToInt()) { + const auto& result = cel_map->Has(CelValue::CreateInt64(number->AsInt())); + if (result.ok() && *result) { + return CelValue::CreateBool(*result); + } } - return CelValue::CreateBool(*result); + return CelValue::CreateBool(false); }; for (auto op : in_operators) { @@ -597,6 +635,11 @@ absl::Status RegisterSetMembershipFunctions(CelFunctionRegistry* registry, op, false, boolKeyInSet, registry); if (!status.ok()) return status; + status = + FunctionAdapter::CreateAndRegister( + op, false, doubleKeyInSet, registry); + if (!status.ok()) return status; + status = FunctionAdapter::CreateAndRegister( op, false, intKeyInSet, registry); diff --git a/eval/public/builtin_func_test.cc b/eval/public/builtin_func_test.cc index cba38ceea..6bb9165d3 100644 --- a/eval/public/builtin_func_test.cc +++ b/eval/public/builtin_func_test.cc @@ -388,7 +388,7 @@ class BuiltinsTest : public ::testing::Test { ASSERT_EQ(result_value.IsBool(), true); ASSERT_EQ(result_value.BoolOrDie(), result) - << " for " << CelValue::TypeName(value.type()); + << " for " << value.DebugString(); } void TestInDeprecatedMap(const CelMap* cel_map, const CelValue& value, @@ -1579,11 +1579,8 @@ TEST_F(BuiltinsTest, TestMapInError) { CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun( builtin::kIn, {}, {key, CelValue::CreateMap(&cel_map)}, &result_value)); - - EXPECT_TRUE(result_value.IsError()); - EXPECT_EQ(result_value.ErrorOrDie()->message(), "bad key type"); - EXPECT_EQ(result_value.ErrorOrDie()->code(), - absl::StatusCode::kInvalidArgument); + EXPECT_TRUE(result_value.IsBool()); + EXPECT_FALSE(result_value.BoolOrDie()); } } @@ -1608,7 +1605,14 @@ TEST_F(BuiltinsTest, TestInt64MapIn) { FakeInt64Map cel_map(data); TestInMap(&cel_map, CelValue::CreateInt64(-4), true); TestInMap(&cel_map, CelValue::CreateInt64(4), false); - TestInMap(&cel_map, CelValue::CreateUint64(3), false); + TestInMap(&cel_map, CelValue::CreateUint64(3), true); + TestInMap(&cel_map, CelValue::CreateUint64(4), false); + TestInMap(&cel_map, CelValue::CreateDouble(NAN), false); + TestInMap(&cel_map, CelValue::CreateDouble(-4.0), true); + TestInMap(&cel_map, CelValue::CreateDouble(-4.1), false); + TestInMap(&cel_map, + CelValue::CreateDouble(std::numeric_limits::max()), + false); } TEST_F(BuiltinsTest, TestUint64MapIn) { @@ -1620,7 +1624,11 @@ TEST_F(BuiltinsTest, TestUint64MapIn) { FakeUint64Map cel_map(data); TestInMap(&cel_map, CelValue::CreateUint64(4), true); TestInMap(&cel_map, CelValue::CreateUint64(44), false); - TestInMap(&cel_map, CelValue::CreateInt64(4), false); + TestInMap(&cel_map, CelValue::CreateDouble(4.0), true); + TestInMap(&cel_map, CelValue::CreateDouble(-4.0), false); + TestInMap(&cel_map, CelValue::CreateDouble(7.0), false); + TestInMap(&cel_map, CelValue::CreateInt64(4), true); + TestInMap(&cel_map, CelValue::CreateInt64(-1), false); } TEST_F(BuiltinsTest, TestStringMapIn) { diff --git a/eval/public/comparison_functions.cc b/eval/public/comparison_functions.cc index 59ad41da2..5b6d7a0d8 100644 --- a/eval/public/comparison_functions.cc +++ b/eval/public/comparison_functions.cc @@ -245,6 +245,22 @@ absl::optional MapEqual(const CelMap* t1, const CelMap* t2) { CelValue key = (*keys)[i]; CelValue v1 = (*t1)[key].value(); absl::optional v2 = (*t2)[key]; + if (!v2.has_value()) { + auto number = GetNumberFromCelValue(key); + if (!number.has_value()) { + return false; + } + if (key.IsUint64()) { + v2 = (*t2)[key]; + } + if (!v2.has_value() && number->LosslessConvertibleToInt()) { + v2 = (*t2)[CelValue::CreateInt64(number->AsInt())]; + } + if (!key.IsUint64() && !v2.has_value() && + number->LosslessConvertibleToUint()) { + v2 = (*t2)[CelValue::CreateUint64(number->AsUint())]; + } + } if (!v2.has_value()) { return false; } @@ -254,7 +270,6 @@ absl::optional MapEqual(const CelMap* t1, const CelMap* t2) { return eq; } } - return true; } diff --git a/eval/public/comparison_functions_test.cc b/eval/public/comparison_functions_test.cc index b8723d949..c37d73a10 100644 --- a/eval/public/comparison_functions_test.cc +++ b/eval/public/comparison_functions_test.cc @@ -338,6 +338,22 @@ TEST(CelValueEqualImplTest, MapMixedValueTypesInequal) { Optional(false)); } +TEST(CelValueEqualImplTest, MapMixedKeyTypesEqual) { + std::vector> lhs_data{ + {CelValue::CreateUint64(1), CelValue::CreateStringView("abc")}}; + std::vector> rhs_data{ + {CelValue::CreateInt64(1), CelValue::CreateStringView("abc")}}; + + ASSERT_OK_AND_ASSIGN(std::unique_ptr lhs, + CreateContainerBackedMap(absl::MakeSpan(lhs_data))); + ASSERT_OK_AND_ASSIGN(std::unique_ptr rhs, + CreateContainerBackedMap(absl::MakeSpan(rhs_data))); + + EXPECT_THAT(CelValueEqualImpl(CelValue::CreateMap(lhs.get()), + CelValue::CreateMap(rhs.get())), + Optional(true)); +} + TEST(CelValueEqualImplTest, MapMixedKeyTypesInequal) { std::vector> lhs_data{ {CelValue::CreateInt64(1), CelValue::CreateStringView("abc")}}; From 63f9d597f1755ce658f3a48fdc563c33cd061962 Mon Sep 17 00:00:00 2001 From: tswadell Date: Mon, 21 Mar 2022 18:15:01 +0000 Subject: [PATCH 033/111] Edge case fixes for heterogeneous equality PiperOrigin-RevId: 436256886 --- eval/public/BUILD | 2 - eval/public/builtin_func_registrar.cc | 71 +++++------------------- eval/public/builtin_func_test.cc | 24 +++----- eval/public/comparison_functions.cc | 17 +----- eval/public/comparison_functions_test.cc | 16 ------ 5 files changed, 23 insertions(+), 107 deletions(-) diff --git a/eval/public/BUILD b/eval/public/BUILD index 448c8a220..3c0a0fce5 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -208,7 +208,6 @@ cc_library( ":cel_function", ":cel_function_adapter", ":cel_function_registry", - ":cel_number", ":cel_options", ":cel_value", ":comparison_functions", @@ -223,7 +222,6 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", - "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf", "@com_googlesource_code_re2//:re2", ], diff --git a/eval/public/builtin_func_registrar.cc b/eval/public/builtin_func_registrar.cc index e867c7608..52390d148 100644 --- a/eval/public/builtin_func_registrar.cc +++ b/eval/public/builtin_func_registrar.cc @@ -29,12 +29,10 @@ #include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" -#include "absl/types/optional.h" #include "eval/eval/mutable_list_impl.h" #include "eval/public/cel_builtins.h" #include "eval/public/cel_function_adapter.h" #include "eval/public/cel_function_registry.h" -#include "eval/public/cel_number.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/comparison_functions.h" @@ -557,70 +555,34 @@ absl::Status RegisterSetMembershipFunctions(CelFunctionRegistry* registry, auto boolKeyInSet = [](Arena* arena, bool key, const CelMap* cel_map) -> CelValue { const auto& result = cel_map->Has(CelValue::CreateBool(key)); - if (result.ok()) { - return CelValue::CreateBool(*result); + if (!result.ok()) { + return CreateErrorValue(arena, result.status()); } - return CelValue::CreateBool(false); - }; - auto doubleKeyInSet = [](Arena* arena, double key, - const CelMap* cel_map) -> CelValue { - absl::optional number = - GetNumberFromCelValue(CelValue::CreateDouble(key)); - if (number->LosslessConvertibleToInt()) { - const auto& result = cel_map->Has(CelValue::CreateInt64(number->AsInt())); - if (result.ok() && *result) { - return CelValue::CreateBool(*result); - } - } - if (number->LosslessConvertibleToUint()) { - const auto& result = - cel_map->Has(CelValue::CreateUint64(number->AsUint())); - if (result.ok() && *result) { - return CelValue::CreateBool(*result); - } - } - return CelValue::CreateBool(false); + return CelValue::CreateBool(*result); }; auto intKeyInSet = [](Arena* arena, int64_t key, const CelMap* cel_map) -> CelValue { - CelValue int_key = CelValue::CreateInt64(key); - const auto& result = cel_map->Has(int_key); - if (result.ok() && *result) { - return CelValue::CreateBool(*result); + const auto& result = cel_map->Has(CelValue::CreateInt64(key)); + if (!result.ok()) { + return CreateErrorValue(arena, result.status()); } - absl::optional number = GetNumberFromCelValue(int_key); - if (number->LosslessConvertibleToUint()) { - const auto& result = - cel_map->Has(CelValue::CreateUint64(number->AsUint())); - if (result.ok() && *result) { - return CelValue::CreateBool(*result); - } - } - return CelValue::CreateBool(false); + return CelValue::CreateBool(*result); }; auto stringKeyInSet = [](Arena* arena, CelValue::StringHolder key, const CelMap* cel_map) -> CelValue { const auto& result = cel_map->Has(CelValue::CreateString(key)); - if (result.ok()) { - return CelValue::CreateBool(*result); + if (!result.ok()) { + return CreateErrorValue(arena, result.status()); } - return CelValue::CreateBool(false); + return CelValue::CreateBool(*result); }; auto uintKeyInSet = [](Arena* arena, uint64_t key, const CelMap* cel_map) -> CelValue { - CelValue uint_key = CelValue::CreateUint64(key); - const auto& result = cel_map->Has(uint_key); - if (result.ok() && *result) { - return CelValue::CreateBool(*result); - } - absl::optional number = GetNumberFromCelValue(uint_key); - if (number->LosslessConvertibleToInt()) { - const auto& result = cel_map->Has(CelValue::CreateInt64(number->AsInt())); - if (result.ok() && *result) { - return CelValue::CreateBool(*result); - } + const auto& result = cel_map->Has(CelValue::CreateUint64(key)); + if (!result.ok()) { + return CreateErrorValue(arena, result.status()); } - return CelValue::CreateBool(false); + return CelValue::CreateBool(*result); }; for (auto op : in_operators) { @@ -635,11 +597,6 @@ absl::Status RegisterSetMembershipFunctions(CelFunctionRegistry* registry, op, false, boolKeyInSet, registry); if (!status.ok()) return status; - status = - FunctionAdapter::CreateAndRegister( - op, false, doubleKeyInSet, registry); - if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( op, false, intKeyInSet, registry); diff --git a/eval/public/builtin_func_test.cc b/eval/public/builtin_func_test.cc index 6bb9165d3..cba38ceea 100644 --- a/eval/public/builtin_func_test.cc +++ b/eval/public/builtin_func_test.cc @@ -388,7 +388,7 @@ class BuiltinsTest : public ::testing::Test { ASSERT_EQ(result_value.IsBool(), true); ASSERT_EQ(result_value.BoolOrDie(), result) - << " for " << value.DebugString(); + << " for " << CelValue::TypeName(value.type()); } void TestInDeprecatedMap(const CelMap* cel_map, const CelValue& value, @@ -1579,8 +1579,11 @@ TEST_F(BuiltinsTest, TestMapInError) { CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun( builtin::kIn, {}, {key, CelValue::CreateMap(&cel_map)}, &result_value)); - EXPECT_TRUE(result_value.IsBool()); - EXPECT_FALSE(result_value.BoolOrDie()); + + EXPECT_TRUE(result_value.IsError()); + EXPECT_EQ(result_value.ErrorOrDie()->message(), "bad key type"); + EXPECT_EQ(result_value.ErrorOrDie()->code(), + absl::StatusCode::kInvalidArgument); } } @@ -1605,14 +1608,7 @@ TEST_F(BuiltinsTest, TestInt64MapIn) { FakeInt64Map cel_map(data); TestInMap(&cel_map, CelValue::CreateInt64(-4), true); TestInMap(&cel_map, CelValue::CreateInt64(4), false); - TestInMap(&cel_map, CelValue::CreateUint64(3), true); - TestInMap(&cel_map, CelValue::CreateUint64(4), false); - TestInMap(&cel_map, CelValue::CreateDouble(NAN), false); - TestInMap(&cel_map, CelValue::CreateDouble(-4.0), true); - TestInMap(&cel_map, CelValue::CreateDouble(-4.1), false); - TestInMap(&cel_map, - CelValue::CreateDouble(std::numeric_limits::max()), - false); + TestInMap(&cel_map, CelValue::CreateUint64(3), false); } TEST_F(BuiltinsTest, TestUint64MapIn) { @@ -1624,11 +1620,7 @@ TEST_F(BuiltinsTest, TestUint64MapIn) { FakeUint64Map cel_map(data); TestInMap(&cel_map, CelValue::CreateUint64(4), true); TestInMap(&cel_map, CelValue::CreateUint64(44), false); - TestInMap(&cel_map, CelValue::CreateDouble(4.0), true); - TestInMap(&cel_map, CelValue::CreateDouble(-4.0), false); - TestInMap(&cel_map, CelValue::CreateDouble(7.0), false); - TestInMap(&cel_map, CelValue::CreateInt64(4), true); - TestInMap(&cel_map, CelValue::CreateInt64(-1), false); + TestInMap(&cel_map, CelValue::CreateInt64(4), false); } TEST_F(BuiltinsTest, TestStringMapIn) { diff --git a/eval/public/comparison_functions.cc b/eval/public/comparison_functions.cc index 5b6d7a0d8..59ad41da2 100644 --- a/eval/public/comparison_functions.cc +++ b/eval/public/comparison_functions.cc @@ -245,22 +245,6 @@ absl::optional MapEqual(const CelMap* t1, const CelMap* t2) { CelValue key = (*keys)[i]; CelValue v1 = (*t1)[key].value(); absl::optional v2 = (*t2)[key]; - if (!v2.has_value()) { - auto number = GetNumberFromCelValue(key); - if (!number.has_value()) { - return false; - } - if (key.IsUint64()) { - v2 = (*t2)[key]; - } - if (!v2.has_value() && number->LosslessConvertibleToInt()) { - v2 = (*t2)[CelValue::CreateInt64(number->AsInt())]; - } - if (!key.IsUint64() && !v2.has_value() && - number->LosslessConvertibleToUint()) { - v2 = (*t2)[CelValue::CreateUint64(number->AsUint())]; - } - } if (!v2.has_value()) { return false; } @@ -270,6 +254,7 @@ absl::optional MapEqual(const CelMap* t1, const CelMap* t2) { return eq; } } + return true; } diff --git a/eval/public/comparison_functions_test.cc b/eval/public/comparison_functions_test.cc index c37d73a10..b8723d949 100644 --- a/eval/public/comparison_functions_test.cc +++ b/eval/public/comparison_functions_test.cc @@ -338,22 +338,6 @@ TEST(CelValueEqualImplTest, MapMixedValueTypesInequal) { Optional(false)); } -TEST(CelValueEqualImplTest, MapMixedKeyTypesEqual) { - std::vector> lhs_data{ - {CelValue::CreateUint64(1), CelValue::CreateStringView("abc")}}; - std::vector> rhs_data{ - {CelValue::CreateInt64(1), CelValue::CreateStringView("abc")}}; - - ASSERT_OK_AND_ASSIGN(std::unique_ptr lhs, - CreateContainerBackedMap(absl::MakeSpan(lhs_data))); - ASSERT_OK_AND_ASSIGN(std::unique_ptr rhs, - CreateContainerBackedMap(absl::MakeSpan(rhs_data))); - - EXPECT_THAT(CelValueEqualImpl(CelValue::CreateMap(lhs.get()), - CelValue::CreateMap(rhs.get())), - Optional(true)); -} - TEST(CelValueEqualImplTest, MapMixedKeyTypesInequal) { std::vector> lhs_data{ {CelValue::CreateInt64(1), CelValue::CreateStringView("abc")}}; From 26e4e57385a4a25ee5c8dcbefc1e91df60280c4d Mon Sep 17 00:00:00 2001 From: tswadell Date: Mon, 21 Mar 2022 21:15:33 +0000 Subject: [PATCH 034/111] Edge case fixes for heterogeneous equality PiperOrigin-RevId: 436302402 --- eval/public/BUILD | 2 + eval/public/builtin_func_registrar.cc | 100 +++++++++++++++++++---- eval/public/builtin_func_test.cc | 37 ++++++++- eval/public/comparison_functions.cc | 21 +++++ eval/public/comparison_functions_test.cc | 16 ++++ 5 files changed, 156 insertions(+), 20 deletions(-) diff --git a/eval/public/BUILD b/eval/public/BUILD index 3c0a0fce5..448c8a220 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -208,6 +208,7 @@ cc_library( ":cel_function", ":cel_function_adapter", ":cel_function_registry", + ":cel_number", ":cel_options", ":cel_value", ":comparison_functions", @@ -222,6 +223,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", + "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf", "@com_googlesource_code_re2//:re2", ], diff --git a/eval/public/builtin_func_registrar.cc b/eval/public/builtin_func_registrar.cc index 52390d148..b57782fcd 100644 --- a/eval/public/builtin_func_registrar.cc +++ b/eval/public/builtin_func_registrar.cc @@ -29,10 +29,12 @@ #include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" +#include "absl/types/optional.h" #include "eval/eval/mutable_list_impl.h" #include "eval/public/cel_builtins.h" #include "eval/public/cel_function_adapter.h" #include "eval/public/cel_function_registry.h" +#include "eval/public/cel_number.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/comparison_functions.h" @@ -552,39 +554,98 @@ absl::Status RegisterSetMembershipFunctions(CelFunctionRegistry* registry, } } - auto boolKeyInSet = [](Arena* arena, bool key, - const CelMap* cel_map) -> CelValue { + auto boolKeyInSet = [options](Arena* arena, bool key, + const CelMap* cel_map) -> CelValue { const auto& result = cel_map->Has(CelValue::CreateBool(key)); - if (!result.ok()) { - return CreateErrorValue(arena, result.status()); + if (result.ok()) { + return CelValue::CreateBool(*result); } - return CelValue::CreateBool(*result); + if (options.enable_heterogeneous_equality) { + return CelValue::CreateBool(false); + } + return CreateErrorValue(arena, result.status()); }; - auto intKeyInSet = [](Arena* arena, int64_t key, - const CelMap* cel_map) -> CelValue { - const auto& result = cel_map->Has(CelValue::CreateInt64(key)); + + auto intKeyInSet = [options](Arena* arena, int64_t key, + const CelMap* cel_map) -> CelValue { + CelValue int_key = CelValue::CreateInt64(key); + const auto& result = cel_map->Has(int_key); + if (options.enable_heterogeneous_equality) { + if (result.ok() && *result) { + return CelValue::CreateBool(*result); + } + absl::optional number = GetNumberFromCelValue(int_key); + if (number->LosslessConvertibleToUint()) { + const auto& result = + cel_map->Has(CelValue::CreateUint64(number->AsUint())); + if (result.ok() && *result) { + return CelValue::CreateBool(*result); + } + } + return CelValue::CreateBool(false); + } if (!result.ok()) { return CreateErrorValue(arena, result.status()); } return CelValue::CreateBool(*result); }; - auto stringKeyInSet = [](Arena* arena, CelValue::StringHolder key, - const CelMap* cel_map) -> CelValue { + + auto stringKeyInSet = [options](Arena* arena, CelValue::StringHolder key, + const CelMap* cel_map) -> CelValue { const auto& result = cel_map->Has(CelValue::CreateString(key)); - if (!result.ok()) { - return CreateErrorValue(arena, result.status()); + if (result.ok()) { + return CelValue::CreateBool(*result); } - return CelValue::CreateBool(*result); + if (options.enable_heterogeneous_equality) { + return CelValue::CreateBool(false); + } + return CreateErrorValue(arena, result.status()); }; - auto uintKeyInSet = [](Arena* arena, uint64_t key, - const CelMap* cel_map) -> CelValue { - const auto& result = cel_map->Has(CelValue::CreateUint64(key)); + + auto uintKeyInSet = [options](Arena* arena, uint64_t key, + const CelMap* cel_map) -> CelValue { + CelValue uint_key = CelValue::CreateUint64(key); + const auto& result = cel_map->Has(uint_key); + if (options.enable_heterogeneous_equality) { + if (result.ok() && *result) { + return CelValue::CreateBool(*result); + } + absl::optional number = GetNumberFromCelValue(uint_key); + if (number->LosslessConvertibleToInt()) { + const auto& result = + cel_map->Has(CelValue::CreateInt64(number->AsInt())); + if (result.ok() && *result) { + return CelValue::CreateBool(*result); + } + } + return CelValue::CreateBool(false); + } if (!result.ok()) { return CreateErrorValue(arena, result.status()); } return CelValue::CreateBool(*result); }; + auto doubleKeyInSet = [](Arena* arena, double key, + const CelMap* cel_map) -> CelValue { + absl::optional number = + GetNumberFromCelValue(CelValue::CreateDouble(key)); + if (number->LosslessConvertibleToInt()) { + const auto& result = cel_map->Has(CelValue::CreateInt64(number->AsInt())); + if (result.ok() && *result) { + return CelValue::CreateBool(*result); + } + } + if (number->LosslessConvertibleToUint()) { + const auto& result = + cel_map->Has(CelValue::CreateUint64(number->AsUint())); + if (result.ok() && *result) { + return CelValue::CreateBool(*result); + } + } + return CelValue::CreateBool(false); + }; + for (auto op : in_operators) { auto status = FunctionAdapter::CreateAndRegister( op, false, uintKeyInSet, registry); if (!status.ok()) return status; + + if (options.enable_heterogeneous_equality) { + status = + FunctionAdapter::CreateAndRegister( + op, false, doubleKeyInSet, registry); + if (!status.ok()) return status; + } } return absl::OkStatus(); } diff --git a/eval/public/builtin_func_test.cc b/eval/public/builtin_func_test.cc index cba38ceea..e38a49a0c 100644 --- a/eval/public/builtin_func_test.cc +++ b/eval/public/builtin_func_test.cc @@ -384,11 +384,11 @@ class BuiltinsTest : public ::testing::Test { CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun(builtin::kIn, {}, {value, CelValue::CreateMap(cel_map)}, - &result_value)); + &result_value, options_)); ASSERT_EQ(result_value.IsBool(), true); ASSERT_EQ(result_value.BoolOrDie(), result) - << " for " << CelValue::TypeName(value.type()); + << " for " << value.DebugString(); } void TestInDeprecatedMap(const CelMap* cel_map, const CelValue& value, @@ -396,7 +396,7 @@ class BuiltinsTest : public ::testing::Test { CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun(builtin::kInDeprecated, {}, {value, CelValue::CreateMap(cel_map)}, - &result_value)); + &result_value, options_)); ASSERT_EQ(result_value.IsBool(), true); ASSERT_EQ(result_value.BoolOrDie(), result) @@ -408,7 +408,7 @@ class BuiltinsTest : public ::testing::Test { CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun(builtin::kInFunction, {}, {value, CelValue::CreateMap(cel_map)}, - &result_value)); + &result_value, options_)); ASSERT_EQ(result_value.IsBool(), true); ASSERT_EQ(result_value.BoolOrDie(), result) @@ -1575,6 +1575,17 @@ TEST_F(BuiltinsTest, TestMapInError) { CelValue::CreateStringView("hello"), CelValue::CreateUint64(2), }; + + options_.enable_heterogeneous_equality = true; + for (auto key : kValues) { + CelValue result_value; + ASSERT_NO_FATAL_FAILURE(PerformRun( + builtin::kIn, {}, {key, CelValue::CreateMap(&cel_map)}, &result_value)); + EXPECT_TRUE(result_value.IsBool()); + EXPECT_FALSE(result_value.BoolOrDie()); + } + + options_.enable_heterogeneous_equality = false; for (auto key : kValues) { CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun( @@ -1609,6 +1620,17 @@ TEST_F(BuiltinsTest, TestInt64MapIn) { TestInMap(&cel_map, CelValue::CreateInt64(-4), true); TestInMap(&cel_map, CelValue::CreateInt64(4), false); TestInMap(&cel_map, CelValue::CreateUint64(3), false); + TestInMap(&cel_map, CelValue::CreateUint64(4), false); + + options_.enable_heterogeneous_equality = true; + TestInMap(&cel_map, CelValue::CreateUint64(3), true); + TestInMap(&cel_map, CelValue::CreateUint64(4), false); + TestInMap(&cel_map, CelValue::CreateDouble(NAN), false); + TestInMap(&cel_map, CelValue::CreateDouble(-4.0), true); + TestInMap(&cel_map, CelValue::CreateDouble(-4.1), false); + TestInMap(&cel_map, + CelValue::CreateDouble(std::numeric_limits::max()), + false); } TEST_F(BuiltinsTest, TestUint64MapIn) { @@ -1621,6 +1643,13 @@ TEST_F(BuiltinsTest, TestUint64MapIn) { TestInMap(&cel_map, CelValue::CreateUint64(4), true); TestInMap(&cel_map, CelValue::CreateUint64(44), false); TestInMap(&cel_map, CelValue::CreateInt64(4), false); + + options_.enable_heterogeneous_equality = true; + TestInMap(&cel_map, CelValue::CreateInt64(-1), false); + TestInMap(&cel_map, CelValue::CreateInt64(4), true); + TestInMap(&cel_map, CelValue::CreateDouble(4.0), true); + TestInMap(&cel_map, CelValue::CreateDouble(-4.0), false); + TestInMap(&cel_map, CelValue::CreateDouble(7.0), false); } TEST_F(BuiltinsTest, TestStringMapIn) { diff --git a/eval/public/comparison_functions.cc b/eval/public/comparison_functions.cc index 59ad41da2..cc33df500 100644 --- a/eval/public/comparison_functions.cc +++ b/eval/public/comparison_functions.cc @@ -245,6 +245,27 @@ absl::optional MapEqual(const CelMap* t1, const CelMap* t2) { CelValue key = (*keys)[i]; CelValue v1 = (*t1)[key].value(); absl::optional v2 = (*t2)[key]; + if (!v2.has_value()) { + auto number = GetNumberFromCelValue(key); + if (!number.has_value()) { + return false; + } + if (!key.IsInt64() && number->LosslessConvertibleToInt()) { + CelValue int_key = CelValue::CreateInt64(number->AsInt()); + absl::optional eq = EqualsProvider()(key, int_key); + if (eq.has_value() && *eq) { + v2 = (*t2)[int_key]; + } + } + if (!key.IsUint64() && !v2.has_value() && + number->LosslessConvertibleToUint()) { + CelValue uint_key = CelValue::CreateUint64(number->AsUint()); + absl::optional eq = EqualsProvider()(key, uint_key); + if (eq.has_value() && *eq) { + v2 = (*t2)[uint_key]; + } + } + } if (!v2.has_value()) { return false; } diff --git a/eval/public/comparison_functions_test.cc b/eval/public/comparison_functions_test.cc index b8723d949..c37d73a10 100644 --- a/eval/public/comparison_functions_test.cc +++ b/eval/public/comparison_functions_test.cc @@ -338,6 +338,22 @@ TEST(CelValueEqualImplTest, MapMixedValueTypesInequal) { Optional(false)); } +TEST(CelValueEqualImplTest, MapMixedKeyTypesEqual) { + std::vector> lhs_data{ + {CelValue::CreateUint64(1), CelValue::CreateStringView("abc")}}; + std::vector> rhs_data{ + {CelValue::CreateInt64(1), CelValue::CreateStringView("abc")}}; + + ASSERT_OK_AND_ASSIGN(std::unique_ptr lhs, + CreateContainerBackedMap(absl::MakeSpan(lhs_data))); + ASSERT_OK_AND_ASSIGN(std::unique_ptr rhs, + CreateContainerBackedMap(absl::MakeSpan(rhs_data))); + + EXPECT_THAT(CelValueEqualImpl(CelValue::CreateMap(lhs.get()), + CelValue::CreateMap(rhs.get())), + Optional(true)); +} + TEST(CelValueEqualImplTest, MapMixedKeyTypesInequal) { std::vector> lhs_data{ {CelValue::CreateInt64(1), CelValue::CreateStringView("abc")}}; From 3dd2e367feb461fcfa84b92e22fbce7638c9cd0c Mon Sep 17 00:00:00 2001 From: tswadell Date: Tue, 22 Mar 2022 05:04:04 +0000 Subject: [PATCH 035/111] Internal sync PiperOrigin-RevId: 436384650 --- conformance/BUILD | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/conformance/BUILD b/conformance/BUILD index 97c603d03..d5748fbce 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -84,28 +84,16 @@ cc_binary( # TODO(issues/112): Unbound functions result in empty eval response. "--skip_test=basic/functions/unbound", "--skip_test=basic/functions/unbound_is_runtime_error", - # TODO(issues/113): Aggregate values must logically AND element equality results. - "--skip_test=comparisons/eq_literal/not_eq_list_false_vs_types", - "--skip_test=comparisons/eq_literal/not_eq_map_false_vs_types", - # TODO(issues/114): Ensure the 'in' operator is a logical OR of element equality results. - "--skip_test=comparisons/in_list_literal/elem_in_mixed_type_list_error", - "--skip_test=comparisons/in_map_literal/key_in_mixed_key_type_map_error", + # TODO(issues/116): Debug why dynamic/list/var fails to JSON parse correctly. + "--skip_test=dynamic/list/var", # TODO(issues/97): Parse-only qualified variable lookup "x.y" wtih binding "x.y" or "y" within container "x" fails "--skip_test=fields/qualified_identifier_resolution/qualified_ident,map_field_select,ident_with_longest_prefix_check,qualified_identifier_resolution_unchecked", "--skip_test=namespace/qualified/self_eval_qualified_lookup", "--skip_test=namespace/namespace/self_eval_container_lookup,self_eval_container_lookup_unchecked", - # TODO(issues/116): Debug why dynamic/list/var fails to JSON parse correctly. - "--skip_test=dynamic/list/var", # TODO(issues/117): Integer overflow on enum assignments should error. "--skip_test=enums/legacy_proto2/select_big,select_neg", - # TODO(issues/127): Ensure overflow occurs on conversions of double values which might not work properly on all platforms. - "--skip_test=conversions/int/double_int_min_range", # Future features for CEL 1.0 - # TODO(google/cel-spec/issues/225): These are supported comparisons with heterogeneous equality enabled. - "--skip_test=comparisons/eq_literal/eq_list_elem_mixed_types_error,eq_mixed_types_error,eq_map_value_mixed_types_error", - "--skip_test=comparisons/ne_literal/ne_mixed_types_error", - "--skip_test=macros/exists/list_elem_type_exhaustive,map_key_type_exhaustive", # TODO(issues/119): Strong typing support for enums, specified but not implemented. "--skip_test=enums/strong_proto2", "--skip_test=enums/strong_proto3", From 2fd4dbccf076b119f1b9444b18c66e8172d89e6d Mon Sep 17 00:00:00 2001 From: jcking Date: Tue, 22 Mar 2022 20:28:37 +0000 Subject: [PATCH 036/111] Internal change PiperOrigin-RevId: 436555880 --- base/BUILD | 4 + base/internal/BUILD | 6 ++ base/internal/type.post.h | 23 ++++++ base/internal/value.post.h | 1 + base/type.cc | 18 +++++ base/type.h | 138 ++++++++++++++++++++++++++++++++++ base/type_factory.h | 28 +++++-- base/type_test.cc | 122 ++++++++++++++++++++++++++++++ base/value.cc | 41 ++++++++++ base/value.h | 86 +++++++++++++++++++++ base/value_factory.h | 25 +++++-- base/value_test.cc | 148 +++++++++++++++++++++++++++++++++++++ 12 files changed, 627 insertions(+), 13 deletions(-) diff --git a/base/BUILD b/base/BUILD index be8548188..a0c428379 100644 --- a/base/BUILD +++ b/base/BUILD @@ -120,6 +120,7 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", ], ) @@ -130,8 +131,10 @@ cc_test( ":handle", ":memory_manager", ":type", + ":value", "//internal:testing", "@com_google_absl//absl/hash:hash_testing", + "@com_google_absl//absl/status", ], ) @@ -159,6 +162,7 @@ cc_library( "//internal:utf8", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/hash", "@com_google_absl//absl/status", diff --git a/base/internal/BUILD b/base/internal/BUILD index ce4b046d7..2e13eb5e0 100644 --- a/base/internal/BUILD +++ b/base/internal/BUILD @@ -54,6 +54,12 @@ cc_library( "type.pre.h", "type.post.h", ], + deps = [ + "//base:handle", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/numeric:bits", + ], ) cc_library( diff --git a/base/internal/type.post.h b/base/internal/type.post.h index 102c8dee2..89927e8fc 100644 --- a/base/internal/type.post.h +++ b/base/internal/type.post.h @@ -40,6 +40,10 @@ class TypeHandleBase { public: constexpr TypeHandleBase() = default; + // Used by derived classes to bypass default construction to perform their own + // construction. + explicit TypeHandleBase(HandleInPlace) {} + // Called by `Transient` and `Persistent` to implement the same operator. They // will handle enforcing const correctness. Type& operator*() const { return get(); } @@ -163,6 +167,24 @@ class TypeHandle final : public TypeHandleBase { explicit TypeHandle(const TransientTypeHandle& other) { rep_ = other.Ref(); } + template + TypeHandle(UnmanagedResource, F& from) : TypeHandleBase(kHandleInPlace) { + uintptr_t rep = reinterpret_cast( + static_cast(static_cast(std::addressof(from)))); + ABSL_ASSERT(absl::countr_zero(rep) >= + 2); // Verify the lower 2 bits are available. + rep_ = rep | kTypeHandleUnmanaged; + } + + template + TypeHandle(ManagedResource, F& from) : TypeHandleBase(kHandleInPlace) { + uintptr_t rep = reinterpret_cast( + static_cast(static_cast(std::addressof(from)))); + ABSL_ASSERT(absl::countr_zero(rep) >= + 2); // Verify the lower 2 bits are available. + rep_ = rep; + } + ~TypeHandle() { Unref(); } TypeHandle& operator=(const PersistentTypeHandle& other) { @@ -242,6 +264,7 @@ CEL_INTERNAL_TYPE_DECL(BytesType); CEL_INTERNAL_TYPE_DECL(StringType); CEL_INTERNAL_TYPE_DECL(DurationType); CEL_INTERNAL_TYPE_DECL(TimestampType); +CEL_INTERNAL_TYPE_DECL(EnumType); #undef CEL_INTERNAL_TYPE_DECL } // namespace cel diff --git a/base/internal/value.post.h b/base/internal/value.post.h index d7a3fe752..bc7dfe899 100644 --- a/base/internal/value.post.h +++ b/base/internal/value.post.h @@ -655,6 +655,7 @@ CEL_INTERNAL_VALUE_DECL(BytesValue); CEL_INTERNAL_VALUE_DECL(StringValue); CEL_INTERNAL_VALUE_DECL(DurationValue); CEL_INTERNAL_VALUE_DECL(TimestampValue); +CEL_INTERNAL_VALUE_DECL(EnumValue); #undef CEL_INTERNAL_VALUE_DECL } // namespace cel diff --git a/base/type.cc b/base/type.cc index e6294ae6b..cef3aa0f5 100644 --- a/base/type.cc +++ b/base/type.cc @@ -17,6 +17,7 @@ #include #include "absl/types/span.h" +#include "absl/types/variant.h" #include "base/handle.h" #include "internal/no_destructor.h" @@ -40,6 +41,7 @@ CEL_INTERNAL_TYPE_IMPL(BytesType); CEL_INTERNAL_TYPE_IMPL(StringType); CEL_INTERNAL_TYPE_IMPL(DurationType); CEL_INTERNAL_TYPE_IMPL(TimestampType); +CEL_INTERNAL_TYPE_IMPL(EnumType); #undef CEL_INTERNAL_TYPE_IMPL absl::Span> Type::parameters() const { return {}; } @@ -117,4 +119,20 @@ const TimestampType& TimestampType::Get() { return *instance; } +struct EnumType::FindConstantVisitor final { + const EnumType& enum_type; + + absl::StatusOr operator()(absl::string_view name) const { + return enum_type.FindConstantByName(name); + } + + absl::StatusOr operator()(int64_t number) const { + return enum_type.FindConstantByNumber(number); + } +}; + +absl::StatusOr EnumType::FindConstant(ConstantId id) const { + return absl::visit(FindConstantVisitor{*this}, id.data_); +} + } // namespace cel diff --git a/base/type.h b/base/type.h index 87093a183..028e75b36 100644 --- a/base/type.h +++ b/base/type.h @@ -15,13 +15,16 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_TYPE_H_ #define THIRD_PARTY_CEL_CPP_BASE_TYPE_H_ +#include #include #include "absl/base/attributes.h" #include "absl/base/macros.h" #include "absl/hash/hash.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "absl/types/variant.h" #include "base/handle.h" #include "base/internal/type.pre.h" // IWYU pragma: export #include "base/kind.h" @@ -42,7 +45,9 @@ class StringType; class BytesType; class DurationType; class TimestampType; +class EnumType; class TypeFactory; +class TypeProvider; class NullValue; class ErrorValue; @@ -54,6 +59,7 @@ class BytesValue; class StringValue; class DurationValue; class TimestampValue; +class EnumValue; class ValueFactory; namespace internal { @@ -87,6 +93,7 @@ class Type : public base_internal::Resource { friend class BytesType; friend class DurationType; friend class TimestampType; + friend class EnumType; friend class base_internal::TypeHandleBase; Type() = default; @@ -411,6 +418,123 @@ class TimestampType final : public Type { TimestampType(TimestampType&&) = delete; }; +// EnumType represents an enumeration type. An enumeration is a set of constants +// that can be looked up by name and/or number. +class EnumType : public Type { + public: + struct Constant; + + class ConstantId final { + public: + explicit ConstantId(absl::string_view name) + : data_(absl::in_place_type, name) {} + + explicit ConstantId(int64_t number) + : data_(absl::in_place_type, number) {} + + ConstantId() = delete; + + ConstantId(const ConstantId&) = default; + ConstantId& operator=(const ConstantId&) = default; + + private: + friend class EnumType; + friend class EnumValue; + + absl::variant data_; + }; + + Kind kind() const final { return Kind::kEnum; } + + absl::Span> parameters() const final { + return Type::parameters(); + } + + // Find the constant definition for the given identifier. + absl::StatusOr FindConstant(ConstantId id) const; + + protected: + EnumType() = default; + + // Construct a new instance of EnumValue with a type of this. Called by + // EnumValue::New. + virtual absl::StatusOr> NewInstanceByName( + ValueFactory& value_factory, absl::string_view name) const = 0; + + // Construct a new instance of EnumValue with a type of this. Called by + // EnumValue::New. + virtual absl::StatusOr> NewInstanceByNumber( + ValueFactory& value_factory, int64_t number) const = 0; + + // Called by FindConstant. + virtual absl::StatusOr FindConstantByName( + absl::string_view name) const = 0; + + // Called by FindConstant. + virtual absl::StatusOr FindConstantByNumber( + int64_t number) const = 0; + + private: + struct NewInstanceVisitor; + struct FindConstantVisitor; + + friend struct NewInstanceVisitor; + friend struct FindConstantVisitor; + friend class EnumValue; + friend class TypeFactory; + friend class base_internal::TypeHandleBase; + + // Called by base_internal::TypeHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Type& type) { return type.kind() == Kind::kEnum; } + + EnumType(const EnumType&) = delete; + EnumType(EnumType&&) = delete; + + std::pair SizeAndAlignment() const override = 0; +}; + +// CEL_DECLARE_ENUM_TYPE declares `enum_type` as an enumeration type. It must be +// part of the class definition of `enum_type`. +// +// class MyEnumType : public cel::EnumType { +// ... +// private: +// CEL_DECLARE_ENUM_TYPE(MyEnumType); +// }; +#define CEL_DECLARE_ENUM_TYPE(enum_type) \ + private: \ + friend class ::cel::base_internal::TypeHandleBase; \ + \ + ::std::pair<::std::size_t, ::std::size_t> SizeAndAlignment() const override; + +// CEL_IMPLEMENT_ENUM_TYPE implements `enum_type` as an enumeration type. It +// must be called after the class definition of `enum_type`. +// +// class MyEnumType : public cel::EnumType { +// ... +// private: +// CEL_DECLARE_ENUM_TYPE(MyEnumType); +// }; +// +// CEL_IMPLEMENT_ENUM_TYPE(MyEnumType); +#define CEL_IMPLEMENT_ENUM_TYPE(enum_type) \ + static_assert(::std::is_base_of_v<::cel::EnumType, enum_type>, \ + #enum_type " must inherit from cel::EnumType"); \ + static_assert(!::std::is_abstract_v, \ + "this must not be abstract"); \ + \ + ::std::pair<::std::size_t, ::std::size_t> enum_type::SizeAndAlignment() \ + const { \ + static_assert( \ + ::std::is_same_v>>, \ + "this must be the same as " #enum_type); \ + return ::std::pair<::std::size_t, ::std::size_t>(sizeof(enum_type), \ + alignof(enum_type)); \ + } + } // namespace cel // type.pre.h forward declares types so they can be friended above. The types @@ -419,4 +543,18 @@ class TimestampType final : public Type { // header and making it difficult to read. #include "base/internal/type.post.h" // IWYU pragma: export +namespace cel { + +struct EnumType::Constant final { + explicit Constant(absl::string_view name, int64_t number) + : name(name), number(number) {} + + // The unqualified enumeration value name. + absl::string_view name; + // The enumeration value number. + int64_t number; +}; + +} // namespace cel + #endif // THIRD_PARTY_CEL_CPP_BASE_TYPE_H_ diff --git a/base/type_factory.h b/base/type_factory.h index 5f578a9e1..39049b8ab 100644 --- a/base/type_factory.h +++ b/base/type_factory.h @@ -15,6 +15,8 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_TYPE_FACTORY_H_ #define THIRD_PARTY_CEL_CPP_BASE_TYPE_FACTORY_H_ +#include + #include "absl/base/attributes.h" #include "base/handle.h" #include "base/memory_manager.h" @@ -28,6 +30,14 @@ namespace cel { // While TypeFactory is not final and has a virtual destructor, inheriting it is // forbidden outside of the CEL codebase. class TypeFactory { + private: + template + using PropagateConstT = std::conditional_t, const U, U>; + + template + using EnableIfBaseOfT = + std::enable_if_t>, V>; + public: explicit TypeFactory( MemoryManager& memory_manager ABSL_ATTRIBUTE_LIFETIME_BOUND) @@ -64,14 +74,14 @@ class TypeFactory { Persistent GetTimestampType() ABSL_ATTRIBUTE_LIFETIME_BOUND; - // TODO(issues/5): Add CreateStructType(Args...) - // and CreateEnumType(Args...) which returns - // Persistent - - protected: - // Ignore unused for now, as it will be used in the future. - ABSL_ATTRIBUTE_UNUSED MemoryManager& memory_manager() const { - return memory_manager_; + template + EnableIfBaseOfT>>> + CreateEnumType(Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { + return base_internal::PersistentHandleFactory>::template Make>(memory_manager(), + std::forward( + args)...); } private: @@ -85,6 +95,8 @@ class TypeFactory { const T>(T::Get())); } + MemoryManager& memory_manager() const { return memory_manager_; } + MemoryManager& memory_manager_; }; diff --git a/base/type_test.cc b/base/type_test.cc index c98d5c0c5..d3eb81305 100644 --- a/base/type_test.cc +++ b/base/type_test.cc @@ -18,15 +18,72 @@ #include #include "absl/hash/hash_testing.h" +#include "absl/status/status.h" #include "base/handle.h" #include "base/memory_manager.h" #include "base/type_factory.h" +#include "base/value.h" #include "internal/testing.h" namespace cel { namespace { using testing::SizeIs; +using cel::internal::StatusIs; + +enum class TestEnum { + kValue1 = 1, + kValue2 = 2, +}; + +class TestEnumType final : public EnumType { + public: + using EnumType::EnumType; + + absl::string_view name() const override { return "test_enum.TestEnum"; } + + protected: + absl::StatusOr> NewInstanceByName( + ValueFactory& value_factory, absl::string_view name) const override { + return absl::UnimplementedError(""); + } + + absl::StatusOr> NewInstanceByNumber( + ValueFactory& value_factory, int64_t number) const override { + return absl::UnimplementedError(""); + } + + absl::StatusOr FindConstantByName( + absl::string_view name) const override { + if (name == "VALUE1") { + return Constant("VALUE1", static_cast(TestEnum::kValue1)); + } else if (name == "VALUE2") { + return Constant("VALUE2", static_cast(TestEnum::kValue2)); + } + return absl::NotFoundError(""); + } + + absl::StatusOr FindConstantByNumber(int64_t number) const override { + switch (number) { + case 1: + return Constant("VALUE1", static_cast(TestEnum::kValue1)); + case 2: + return Constant("VALUE2", static_cast(TestEnum::kValue2)); + default: + return absl::NotFoundError(""); + } + } + + private: + CEL_DECLARE_ENUM_TYPE(TestEnumType); +}; + +CEL_IMPLEMENT_ENUM_TYPE(TestEnumType); + +template +Persistent Must(absl::StatusOr> status_or_handle) { + return std::move(status_or_handle).value(); +} template constexpr void IS_INITIALIZED(T&) {} @@ -122,6 +179,7 @@ TEST(Type, Null) { EXPECT_FALSE(type_factory.GetNullType().Is()); EXPECT_FALSE(type_factory.GetNullType().Is()); EXPECT_FALSE(type_factory.GetNullType().Is()); + EXPECT_FALSE(type_factory.GetNullType().Is()); } TEST(Type, Error) { @@ -140,6 +198,7 @@ TEST(Type, Error) { EXPECT_FALSE(type_factory.GetErrorType().Is()); EXPECT_FALSE(type_factory.GetErrorType().Is()); EXPECT_FALSE(type_factory.GetErrorType().Is()); + EXPECT_FALSE(type_factory.GetErrorType().Is()); } TEST(Type, Dyn) { @@ -158,6 +217,7 @@ TEST(Type, Dyn) { EXPECT_FALSE(type_factory.GetDynType().Is()); EXPECT_FALSE(type_factory.GetDynType().Is()); EXPECT_FALSE(type_factory.GetDynType().Is()); + EXPECT_FALSE(type_factory.GetDynType().Is()); } TEST(Type, Any) { @@ -176,6 +236,7 @@ TEST(Type, Any) { EXPECT_FALSE(type_factory.GetAnyType().Is()); EXPECT_FALSE(type_factory.GetAnyType().Is()); EXPECT_FALSE(type_factory.GetAnyType().Is()); + EXPECT_FALSE(type_factory.GetAnyType().Is()); } TEST(Type, Bool) { @@ -194,6 +255,7 @@ TEST(Type, Bool) { EXPECT_FALSE(type_factory.GetBoolType().Is()); EXPECT_FALSE(type_factory.GetBoolType().Is()); EXPECT_FALSE(type_factory.GetBoolType().Is()); + EXPECT_FALSE(type_factory.GetBoolType().Is()); } TEST(Type, Int) { @@ -212,6 +274,7 @@ TEST(Type, Int) { EXPECT_FALSE(type_factory.GetIntType().Is()); EXPECT_FALSE(type_factory.GetIntType().Is()); EXPECT_FALSE(type_factory.GetIntType().Is()); + EXPECT_FALSE(type_factory.GetIntType().Is()); } TEST(Type, Uint) { @@ -230,6 +293,7 @@ TEST(Type, Uint) { EXPECT_FALSE(type_factory.GetUintType().Is()); EXPECT_FALSE(type_factory.GetUintType().Is()); EXPECT_FALSE(type_factory.GetUintType().Is()); + EXPECT_FALSE(type_factory.GetUintType().Is()); } TEST(Type, Double) { @@ -248,6 +312,7 @@ TEST(Type, Double) { EXPECT_FALSE(type_factory.GetDoubleType().Is()); EXPECT_FALSE(type_factory.GetDoubleType().Is()); EXPECT_FALSE(type_factory.GetDoubleType().Is()); + EXPECT_FALSE(type_factory.GetDoubleType().Is()); } TEST(Type, String) { @@ -266,6 +331,7 @@ TEST(Type, String) { EXPECT_FALSE(type_factory.GetStringType().Is()); EXPECT_FALSE(type_factory.GetStringType().Is()); EXPECT_FALSE(type_factory.GetStringType().Is()); + EXPECT_FALSE(type_factory.GetStringType().Is()); } TEST(Type, Bytes) { @@ -284,6 +350,7 @@ TEST(Type, Bytes) { EXPECT_TRUE(type_factory.GetBytesType().Is()); EXPECT_FALSE(type_factory.GetBytesType().Is()); EXPECT_FALSE(type_factory.GetBytesType().Is()); + EXPECT_FALSE(type_factory.GetBytesType().Is()); } TEST(Type, Duration) { @@ -302,6 +369,7 @@ TEST(Type, Duration) { EXPECT_FALSE(type_factory.GetDurationType().Is()); EXPECT_TRUE(type_factory.GetDurationType().Is()); EXPECT_FALSE(type_factory.GetDurationType().Is()); + EXPECT_FALSE(type_factory.GetDurationType().Is()); } TEST(Type, Timestamp) { @@ -321,6 +389,59 @@ TEST(Type, Timestamp) { EXPECT_FALSE(type_factory.GetTimestampType().Is()); EXPECT_FALSE(type_factory.GetTimestampType().Is()); EXPECT_TRUE(type_factory.GetTimestampType().Is()); + EXPECT_FALSE(type_factory.GetTimestampType().Is()); +} + +TEST(Type, Enum) { + TypeFactory type_factory(MemoryManager::Global()); + ASSERT_OK_AND_ASSIGN(auto enum_type, + type_factory.CreateEnumType()); + EXPECT_EQ(enum_type->kind(), Kind::kEnum); + EXPECT_EQ(enum_type->name(), "test_enum.TestEnum"); + EXPECT_THAT(enum_type->parameters(), SizeIs(0)); + EXPECT_FALSE(enum_type.Is()); + EXPECT_FALSE(enum_type.Is()); + EXPECT_FALSE(enum_type.Is()); + EXPECT_FALSE(enum_type.Is()); + EXPECT_FALSE(enum_type.Is()); + EXPECT_FALSE(enum_type.Is()); + EXPECT_FALSE(enum_type.Is()); + EXPECT_FALSE(enum_type.Is()); + EXPECT_FALSE(enum_type.Is()); + EXPECT_FALSE(enum_type.Is()); + EXPECT_FALSE(enum_type.Is()); + EXPECT_TRUE(enum_type.Is()); +} + +TEST(EnumType, FindConstant) { + TypeFactory type_factory(MemoryManager::Global()); + ASSERT_OK_AND_ASSIGN(auto enum_type, + type_factory.CreateEnumType()); + + ASSERT_OK_AND_ASSIGN(auto value1, + enum_type->FindConstant(EnumType::ConstantId("VALUE1"))); + EXPECT_EQ(value1.name, "VALUE1"); + EXPECT_EQ(value1.number, 1); + + ASSERT_OK_AND_ASSIGN(value1, + enum_type->FindConstant(EnumType::ConstantId(1))); + EXPECT_EQ(value1.name, "VALUE1"); + EXPECT_EQ(value1.number, 1); + + ASSERT_OK_AND_ASSIGN(auto value2, + enum_type->FindConstant(EnumType::ConstantId("VALUE2"))); + EXPECT_EQ(value2.name, "VALUE2"); + EXPECT_EQ(value2.number, 2); + + ASSERT_OK_AND_ASSIGN(value2, + enum_type->FindConstant(EnumType::ConstantId(2))); + EXPECT_EQ(value2.name, "VALUE2"); + EXPECT_EQ(value2.number, 2); + + EXPECT_THAT(enum_type->FindConstant(EnumType::ConstantId("VALUE3")), + StatusIs(absl::StatusCode::kNotFound)); + EXPECT_THAT(enum_type->FindConstant(EnumType::ConstantId(3)), + StatusIs(absl::StatusCode::kNotFound)); } TEST(Type, SupportsAbslHash) { @@ -338,6 +459,7 @@ TEST(Type, SupportsAbslHash) { Persistent(type_factory.GetBytesType()), Persistent(type_factory.GetDurationType()), Persistent(type_factory.GetTimestampType()), + Persistent(Must(type_factory.CreateEnumType())), })); } diff --git a/base/value.cc b/base/value.cc index b7c9d7908..0fdfc1d66 100644 --- a/base/value.cc +++ b/base/value.cc @@ -27,6 +27,7 @@ #include "absl/base/call_once.h" #include "absl/base/macros.h" #include "absl/base/optimization.h" +#include "absl/container/btree_set.h" #include "absl/container/inlined_vector.h" #include "absl/hash/hash.h" #include "absl/status/status.h" @@ -61,6 +62,7 @@ CEL_INTERNAL_VALUE_IMPL(BytesValue); CEL_INTERNAL_VALUE_IMPL(StringValue); CEL_INTERNAL_VALUE_IMPL(DurationValue); CEL_INTERNAL_VALUE_IMPL(TimestampValue); +CEL_INTERNAL_VALUE_IMPL(EnumValue); #undef CEL_INTERNAL_VALUE_IMPL namespace { @@ -764,6 +766,45 @@ void StringValue::HashValue(absl::HashState state) const { rep()); } +struct EnumType::NewInstanceVisitor final { + const EnumType& enum_type; + ValueFactory& value_factory; + + absl::StatusOr> operator()( + absl::string_view name) const { + return enum_type.NewInstanceByName(value_factory, name); + } + + absl::StatusOr> operator()(int64_t number) const { + return enum_type.NewInstanceByNumber(value_factory, number); + } +}; + +absl::StatusOr> EnumValue::New( + const Persistent& enum_type, ValueFactory& value_factory, + EnumType::ConstantId id) { + CEL_ASSIGN_OR_RETURN( + auto enum_value, + absl::visit(EnumType::NewInstanceVisitor{*enum_type, value_factory}, + id.data_)); + if (!enum_value->type_) { + // In case somebody is caching, we avoid setting the type_ if it has already + // been set, to avoid a race condition where one CPU sees a half written + // pointer. + const_cast(*enum_value).type_ = enum_type; + } + return enum_value; +} + +bool EnumValue::Equals(const Value& other) const { + return kind() == other.kind() && type() == other.type() && + number() == internal::down_cast(other).number(); +} + +void EnumValue::HashValue(absl::HashState state) const { + absl::HashState::combine(std::move(state), type(), number()); +} + namespace base_internal { absl::Cord InlinedCordBytesValue::ToCord(bool reference_counted) const { diff --git a/base/value.h b/base/value.h index 483eaa21a..60072e20a 100644 --- a/base/value.h +++ b/base/value.h @@ -22,6 +22,7 @@ #include #include "absl/base/attributes.h" +#include "absl/base/macros.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/cord.h" @@ -47,6 +48,7 @@ class BytesValue; class StringValue; class DurationValue; class TimestampValue; +class EnumValue; class ValueFactory; namespace internal { @@ -79,6 +81,7 @@ class Value : public base_internal::Resource { friend class StringValue; friend class DurationValue; friend class TimestampValue; + friend class EnumValue; friend class base_internal::ValueHandleBase; friend class base_internal::StringBytesValue; friend class base_internal::ExternalDataBytesValue; @@ -570,6 +573,89 @@ class TimestampValue final : public Value, absl::Time value_; }; +// EnumValue represents a single constant belonging to cel::EnumType. +class EnumValue : public Value { + public: + static absl::StatusOr> New( + const Persistent& enum_type, ValueFactory& value_factory, + EnumType::ConstantId id); + + Transient type() const final { + ABSL_ASSERT(type_); + return type_; + } + + Kind kind() const final { return Kind::kEnum; } + + virtual int64_t number() const = 0; + + virtual absl::string_view name() const = 0; + + protected: + EnumValue() = default; + + private: + template + friend class base_internal::ValueHandle; + friend class base_internal::ValueHandleBase; + + // Called by base_internal::ValueHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Value& value) { return value.kind() == Kind::kEnum; } + + EnumValue(const EnumValue&) = delete; + EnumValue(EnumValue&&) = delete; + + bool Equals(const Value& other) const final; + void HashValue(absl::HashState state) const final; + + std::pair SizeAndAlignment() const override = 0; + + // Set lazily, by EnumValue::New. + Persistent type_; +}; + +// CEL_DECLARE_ENUM_VALUE declares `enum_value` as an enumeration value. It must +// be part of the class definition of `enum_value`. +// +// class MyEnumValue : public cel::EnumValue { +// ... +// private: +// CEL_DECLARE_ENUM_VALUE(MyEnumValue); +// }; +#define CEL_DECLARE_ENUM_VALUE(enum_value) \ + private: \ + friend class ::cel::base_internal::ValueHandleBase; \ + \ + ::std::pair<::std::size_t, ::std::size_t> SizeAndAlignment() const override; + +// CEL_IMPLEMENT_ENUM_VALUE implements `enum_value` as an enumeration value. It +// must be called after the class definition of `enum_value`. +// +// class MyEnumValue : public cel::EnumValue { +// ... +// private: +// CEL_DECLARE_ENUM_VALUE(MyEnumValue); +// }; +// +// CEL_IMPLEMENT_ENUM_VALUE(MyEnumValue); +#define CEL_IMPLEMENT_ENUM_VALUE(enum_value) \ + static_assert(::std::is_base_of_v<::cel::EnumValue, enum_value>, \ + #enum_value " must inherit from cel::EnumValue"); \ + static_assert(!::std::is_abstract_v, \ + "this must not be abstract"); \ + \ + ::std::pair<::std::size_t, ::std::size_t> enum_value::SizeAndAlignment() \ + const { \ + static_assert( \ + ::std::is_same_v>>, \ + "this must be the same as " #enum_value); \ + return ::std::pair<::std::size_t, ::std::size_t>(sizeof(enum_value), \ + alignof(enum_value)); \ + } + } // namespace cel // value.pre.h forward declares types so they can be friended above. The types diff --git a/base/value_factory.h b/base/value_factory.h index 24b9e6172..22d2d27f6 100644 --- a/base/value_factory.h +++ b/base/value_factory.h @@ -18,6 +18,7 @@ #include #include #include +#include #include #include "absl/base/attributes.h" @@ -33,6 +34,14 @@ namespace cel { class ValueFactory final { + private: + template + using PropagateConstT = std::conditional_t, const U, U>; + + template + using EnableIfBaseOfT = + std::enable_if_t>, V>; + public: explicit ValueFactory( MemoryManager& memory_manager ABSL_ATTRIBUTE_LIFETIME_BOUND) @@ -132,16 +141,22 @@ class ValueFactory final { absl::StatusOr> CreateTimestampValue( absl::Time value) ABSL_ATTRIBUTE_LIFETIME_BOUND; - // TODO(issues/5): Add CreateStructType(Args...) and - // CreateEnumType(Args...) which returns Persistent - - protected: - MemoryManager& memory_manager() const { return memory_manager_; } + template + EnableIfBaseOfT>>> + CreateEnumValue(Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { + return base_internal:: + PersistentHandleFactory>::template Make< + std::remove_const_t>(memory_manager(), + std::forward(args)...); + } private: friend class BytesValue; friend class StringValue; + MemoryManager& memory_manager() const { return memory_manager_; } + Persistent GetEmptyBytesValue() ABSL_ATTRIBUTE_LIFETIME_BOUND; diff --git a/base/value_test.cc b/base/value_test.cc index da3b305cf..8b69644ef 100644 --- a/base/value_test.cc +++ b/base/value_test.cc @@ -39,6 +39,87 @@ namespace { using cel::internal::StatusIs; +enum class TestEnum { + kValue1 = 1, + kValue2 = 2, +}; + +class TestEnumValue final : public EnumValue { + public: + explicit TestEnumValue(TestEnum test_enum) : test_enum_(test_enum) {} + + std::string DebugString() const override { return std::string(name()); } + + absl::string_view name() const override { + switch (test_enum_) { + case TestEnum::kValue1: + return "VALUE1"; + case TestEnum::kValue2: + return "VALUE2"; + } + } + + int64_t number() const override { + switch (test_enum_) { + case TestEnum::kValue1: + return 1; + case TestEnum::kValue2: + return 2; + } + } + + private: + CEL_DECLARE_ENUM_VALUE(TestEnumValue); + + TestEnum test_enum_; +}; + +CEL_IMPLEMENT_ENUM_VALUE(TestEnumValue); + +class TestEnumType final : public EnumType { + public: + using EnumType::EnumType; + + absl::string_view name() const override { return "test_enum.TestEnum"; } + + protected: + absl::StatusOr> NewInstanceByName( + ValueFactory& value_factory, absl::string_view name) const override { + if (name == "VALUE1") { + return value_factory.CreateEnumValue(TestEnum::kValue1); + } else if (name == "VALUE2") { + return value_factory.CreateEnumValue(TestEnum::kValue2); + } + return absl::NotFoundError(""); + } + + absl::StatusOr> NewInstanceByNumber( + ValueFactory& value_factory, int64_t number) const override { + switch (number) { + case 1: + return value_factory.CreateEnumValue(TestEnum::kValue1); + case 2: + return value_factory.CreateEnumValue(TestEnum::kValue2); + default: + return absl::NotFoundError(""); + } + } + + absl::StatusOr FindConstantByName( + absl::string_view name) const override { + return absl::UnimplementedError(""); + } + + absl::StatusOr FindConstantByNumber(int64_t number) const override { + return absl::UnimplementedError(""); + } + + private: + CEL_DECLARE_ENUM_TYPE(TestEnumType); +}; + +CEL_IMPLEMENT_ENUM_TYPE(TestEnumType); + template Persistent Must(absl::StatusOr> status_or_handle) { return std::move(status_or_handle).value(); @@ -1362,8 +1443,74 @@ INSTANTIATE_TEST_SUITE_P(StringToCordTest, StringToCordTest, {"\xef\xbf\xbd"}, })); +TEST(Value, Enum) { + ValueFactory value_factory(MemoryManager::Global()); + TypeFactory type_factory(MemoryManager::Global()); + ASSERT_OK_AND_ASSIGN(auto enum_type, + type_factory.CreateEnumType()); + ASSERT_OK_AND_ASSIGN( + auto one_value, + EnumValue::New(enum_type, value_factory, EnumType::ConstantId("VALUE1"))); + EXPECT_TRUE(one_value.Is()); + EXPECT_FALSE(one_value.Is()); + EXPECT_EQ(one_value, one_value); + EXPECT_EQ(one_value, Must(EnumValue::New(enum_type, value_factory, + EnumType::ConstantId("VALUE1")))); + EXPECT_EQ(one_value->kind(), Kind::kEnum); + EXPECT_EQ(one_value->type(), enum_type); + EXPECT_EQ(one_value->name(), "VALUE1"); + EXPECT_EQ(one_value->number(), 1); + + ASSERT_OK_AND_ASSIGN( + auto two_value, + EnumValue::New(enum_type, value_factory, EnumType::ConstantId("VALUE2"))); + EXPECT_TRUE(two_value.Is()); + EXPECT_FALSE(two_value.Is()); + EXPECT_EQ(two_value, two_value); + EXPECT_EQ(two_value->kind(), Kind::kEnum); + EXPECT_EQ(two_value->type(), enum_type); + EXPECT_EQ(two_value->name(), "VALUE2"); + EXPECT_EQ(two_value->number(), 2); + + EXPECT_NE(one_value, two_value); + EXPECT_NE(two_value, one_value); +} + +TEST(EnumType, NewInstance) { + ValueFactory value_factory(MemoryManager::Global()); + TypeFactory type_factory(MemoryManager::Global()); + ASSERT_OK_AND_ASSIGN(auto enum_type, + type_factory.CreateEnumType()); + ASSERT_OK_AND_ASSIGN( + auto one_value, + EnumValue::New(enum_type, value_factory, EnumType::ConstantId("VALUE1"))); + ASSERT_OK_AND_ASSIGN( + auto two_value, + EnumValue::New(enum_type, value_factory, EnumType::ConstantId("VALUE2"))); + ASSERT_OK_AND_ASSIGN( + auto one_value_by_number, + EnumValue::New(enum_type, value_factory, EnumType::ConstantId(1))); + ASSERT_OK_AND_ASSIGN( + auto two_value_by_number, + EnumValue::New(enum_type, value_factory, EnumType::ConstantId(2))); + EXPECT_EQ(one_value, one_value_by_number); + EXPECT_EQ(two_value, two_value_by_number); + + EXPECT_THAT( + EnumValue::New(enum_type, value_factory, EnumType::ConstantId("VALUE3")), + StatusIs(absl::StatusCode::kNotFound)); + EXPECT_THAT(EnumValue::New(enum_type, value_factory, EnumType::ConstantId(3)), + StatusIs(absl::StatusCode::kNotFound)); +} + TEST(Value, SupportsAbslHash) { ValueFactory value_factory(MemoryManager::Global()); + TypeFactory type_factory(MemoryManager::Global()); + ASSERT_OK_AND_ASSIGN(auto enum_type, + type_factory.CreateEnumType()); + ASSERT_OK_AND_ASSIGN( + auto enum_value, + EnumValue::New(enum_type, value_factory, EnumType::ConstantId("VALUE1"))); EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly({ Persistent(value_factory.GetNullValue()), Persistent( @@ -1384,6 +1531,7 @@ TEST(Value, SupportsAbslHash) { Persistent(Must(value_factory.CreateStringValue("foo"))), Persistent( Must(value_factory.CreateStringValue(absl::Cord("bar")))), + Persistent(enum_value), })); } From a14bf9269de36ba5b87870d90526b87b1f9be6d8 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 24 Mar 2022 06:38:18 +0000 Subject: [PATCH 037/111] Update CEL C++ interpreter to consult registered type providers before creating messages. PiperOrigin-RevId: 436920062 --- eval/compiler/BUILD | 5 + eval/compiler/flat_expr_builder.cc | 14 +- eval/compiler/flat_expr_builder_test.cc | 22 ++- eval/compiler/resolver.cc | 13 +- eval/compiler/resolver.h | 5 + eval/compiler/resolver_test.cc | 43 ++++-- eval/eval/BUILD | 5 +- eval/eval/create_struct_step.cc | 164 ++++------------------ eval/eval/create_struct_step.h | 8 +- eval/eval/create_struct_step_test.cc | 42 ++++-- eval/public/BUILD | 6 + eval/public/cel_expr_builder_factory.cc | 7 +- eval/public/cel_type_registry.cc | 13 ++ eval/public/cel_type_registry.h | 22 ++- eval/public/cel_type_registry_test.cc | 51 ++++++- eval/public/structs/legacy_type_adapter.h | 2 +- 16 files changed, 228 insertions(+), 194 deletions(-) diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index f0ab9b06e..e877d633b 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -73,6 +73,7 @@ cc_test( "//eval/public/containers:container_backed_map_impl", "//eval/public/structs:cel_proto_descriptor_pool_builder", "//eval/public/structs:cel_proto_wrapper", + "//eval/public/structs:protobuf_descriptor_type_provider", "//eval/public/testing:matchers", "//eval/testutil:test_message_cc_proto", "//internal:status_macros", @@ -247,9 +248,13 @@ cc_test( "//eval/public:cel_function", "//eval/public:cel_function_registry", "//eval/public:cel_type_registry", + "//eval/public:cel_value", + "//eval/public/structs:protobuf_descriptor_type_provider", "//eval/testutil:test_message_cc_proto", "//internal:status_macros", "//internal:testing", "@com_google_absl//absl/status", + "@com_google_absl//absl/types:optional", + "@com_google_protobuf//:protobuf", ], ) diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index a2a50e1f1..b2ed5cbc6 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -560,16 +560,18 @@ class FlatExprVisitor : public AstVisitor { // If the message name is not empty, then the message name must be resolved // within the container, and if a descriptor is found, then a proto message // creation step will be created. - auto message_desc = resolver_.FindDescriptor(message_name, expr->id()); - if (ValidateOrError(message_desc != nullptr, - "Invalid message creation: missing descriptor for '", + auto type_adapter = resolver_.FindTypeAdapter(message_name, expr->id()); + if (ValidateOrError(type_adapter.has_value() && + type_adapter->mutation_apis() != nullptr, + "Invalid struct creation: missing type info for '", message_name, "'")) { for (const auto& entry : struct_expr->entries()) { ValidateOrError(entry.has_field_key(), - "Message entry missing field name"); - ValidateOrError(entry.has_value(), "Message entry missing value"); + "Struct entry missing field name"); + ValidateOrError(entry.has_value(), "Struct entry missing value"); } - AddStep(CreateCreateStructStep(struct_expr, message_desc, expr->id())); + AddStep(CreateCreateStructStep(struct_expr, type_adapter->mutation_apis(), + expr->id())); } } diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index a8077839a..ac4bdfc29 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -29,6 +29,7 @@ #include "google/protobuf/descriptor.pb.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/dynamic_message.h" +#include "google/protobuf/message.h" #include "google/protobuf/text_format.h" #include "absl/status/status.h" #include "absl/strings/str_format.h" @@ -47,6 +48,7 @@ #include "eval/public/containers/container_backed_map_impl.h" #include "eval/public/structs/cel_proto_descriptor_pool_builder.h" #include "eval/public/structs/cel_proto_wrapper.h" +#include "eval/public/structs/protobuf_descriptor_type_provider.h" #include "eval/public/testing/matchers.h" #include "eval/public/unknown_attribute_set.h" #include "eval/public/unknown_set.h" @@ -206,6 +208,10 @@ TEST(FlatExprBuilderTest, MessageFieldValueUnset) { Expr expr; SourceInfo source_info; FlatExprBuilder builder; + builder.GetTypeRegistry()->RegisterTypeProvider( + std::make_unique( + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory())); // Don't set either the field or the value for the message creation step. auto* create_message = expr.mutable_struct_expr(); @@ -213,13 +219,13 @@ TEST(FlatExprBuilderTest, MessageFieldValueUnset) { auto* entry = create_message->add_entries(); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Message entry missing field name"))); + HasSubstr("Struct entry missing field name"))); // Set the entry field, but not the value. entry->set_field_key("bool_value"); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Message entry missing value"))); + HasSubstr("Struct entry missing value"))); } TEST(FlatExprBuilderTest, BinaryCallTooManyArguments) { @@ -1616,6 +1622,11 @@ TEST(FlatExprBuilderTest, CustomDescriptorPoolForCreateStruct) { // not link the generated message, so it's not included in the generated pool. FlatExprBuilder builder(google::protobuf::DescriptorPool::generated_pool(), google::protobuf::MessageFactory::generated_factory()); + builder.GetTypeRegistry()->RegisterTypeProvider( + std::make_unique( + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory())); + EXPECT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs(absl::StatusCode::kInvalidArgument)); @@ -1634,6 +1645,10 @@ TEST(FlatExprBuilderTest, CustomDescriptorPoolForCreateStruct) { // This time, the message is *known*. We are using a custom descriptor pool // that has been primed with the relevant message. FlatExprBuilder builder2(&desc_pool, &message_factory); + builder2.GetTypeRegistry()->RegisterTypeProvider( + std::make_unique(&desc_pool, + &message_factory)); + ASSERT_OK_AND_ASSIGN(auto expression, builder2.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info())); @@ -1722,6 +1737,9 @@ TEST_P(CustomDescriptorPoolTest, TestType) { google::protobuf::DynamicMessageFactory message_factory(&descriptor_pool); ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("m")); FlatExprBuilder builder(&descriptor_pool, &message_factory); + builder.GetTypeRegistry()->RegisterTypeProvider( + std::make_unique(&descriptor_pool, + &message_factory)); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); // Create test subject, invoke custom setter for message diff --git a/eval/compiler/resolver.cc b/eval/compiler/resolver.cc index d6474cdff..426df40c1 100644 --- a/eval/compiler/resolver.cc +++ b/eval/compiler/resolver.cc @@ -9,6 +9,7 @@ #include "absl/strings/str_split.h" #include "absl/types/optional.h" #include "eval/public/cel_builtins.h" +#include "eval/public/cel_value.h" namespace google::api::expr::runtime { @@ -144,18 +145,18 @@ std::vector Resolver::FindLazyOverloads( return funcs; } -const google::protobuf::Descriptor* Resolver::FindDescriptor(absl::string_view name, - int64_t expr_id) const { +absl::optional Resolver::FindTypeAdapter( + absl::string_view name, int64_t expr_id) const { // Resolve the fully qualified names and then defer to the type registry // for possible matches. auto names = FullyQualifiedNames(name, expr_id); for (const auto& name : names) { - auto desc = type_registry_->FindDescriptor(name); - if (desc != nullptr) { - return desc; + auto maybe_adapter = type_registry_->FindTypeAdapter(name); + if (maybe_adapter.has_value()) { + return maybe_adapter; } } - return nullptr; + return absl::nullopt; } } // namespace google::api::expr::runtime diff --git a/eval/compiler/resolver.h b/eval/compiler/resolver.h index 739254e07..2156b0570 100644 --- a/eval/compiler/resolver.h +++ b/eval/compiler/resolver.h @@ -47,6 +47,11 @@ class Resolver { const google::protobuf::Descriptor* FindDescriptor(absl::string_view name, int64_t expr_id) const; + // FindTypeAdapter returns the adapter for the given type name if one exists, + // following resolution rules for the expression container. + absl::optional FindTypeAdapter(absl::string_view name, + int64_t expr_id) const; + // FindLazyOverloads returns the set, possibly empty, of lazy overloads // matching the given function signature. std::vector FindLazyOverloads( diff --git a/eval/compiler/resolver_test.cc b/eval/compiler/resolver_test.cc index 4583199a3..8ecfab760 100644 --- a/eval/compiler/resolver_test.cc +++ b/eval/compiler/resolver_test.cc @@ -3,10 +3,15 @@ #include #include +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" #include "absl/status/status.h" +#include "absl/types/optional.h" #include "eval/public/cel_function.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_type_registry.h" +#include "eval/public/cel_value.h" +#include "eval/public/structs/protobuf_descriptor_type_provider.h" #include "eval/testutil/test_message.pb.h" #include "internal/status_macros.h" #include "internal/testing.h" @@ -114,34 +119,48 @@ TEST(ResolverTest, TestFindConstantQualifiedTypeDisabled) { EXPECT_FALSE(type_value.has_value()); } -TEST(ResolverTest, TestFindDescriptorBySimpleName) { +TEST(ResolverTest, FindTypeAdapterBySimpleName) { CelFunctionRegistry func_registry; CelTypeRegistry type_registry; Resolver resolver("google.api.expr.runtime", &func_registry, &type_registry); - - auto desc_value = resolver.FindDescriptor("TestMessage", -1); - EXPECT_TRUE(desc_value != nullptr); - EXPECT_THAT(desc_value, Eq(TestMessage::GetDescriptor())); + type_registry.RegisterTypeProvider( + std::make_unique( + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory())); + + absl::optional adapter = + resolver.FindTypeAdapter("TestMessage", -1); + EXPECT_TRUE(adapter.has_value()); + EXPECT_THAT(adapter->mutation_apis(), testing::NotNull()); } -TEST(ResolverTest, TestFindDescriptorByQualifiedName) { +TEST(ResolverTest, FindTypeAdapterByQualifiedName) { CelFunctionRegistry func_registry; CelTypeRegistry type_registry; + type_registry.RegisterTypeProvider( + std::make_unique( + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory())); Resolver resolver("google.api.expr.runtime", &func_registry, &type_registry); - auto desc_value = - resolver.FindDescriptor(".google.api.expr.runtime.TestMessage", -1); - EXPECT_TRUE(desc_value != nullptr); - EXPECT_THAT(desc_value, Eq(TestMessage::GetDescriptor())); + absl::optional adapter = + resolver.FindTypeAdapter(".google.api.expr.runtime.TestMessage", -1); + EXPECT_TRUE(adapter.has_value()); + EXPECT_THAT(adapter->mutation_apis(), testing::NotNull()); } TEST(ResolverTest, TestFindDescriptorNotFound) { CelFunctionRegistry func_registry; CelTypeRegistry type_registry; + type_registry.RegisterTypeProvider( + std::make_unique( + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory())); Resolver resolver("google.api.expr.runtime", &func_registry, &type_registry); - auto desc_value = resolver.FindDescriptor("UndefinedMessage", -1); - EXPECT_TRUE(desc_value == nullptr); + absl::optional adapter = + resolver.FindTypeAdapter("UndefinedMessage", -1); + EXPECT_FALSE(adapter.has_value()); } TEST(ResolverTest, TestFindOverloads) { diff --git a/eval/eval/BUILD b/eval/eval/BUILD index 22a67c7fa..885b4b86a 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -223,15 +223,12 @@ cc_library( ":expression_step_base", "//eval/public:cel_value", "//eval/public/containers:container_backed_map_impl", - "//eval/public/containers:field_access", "//eval/public/structs:cel_proto_wrapper", - "//extensions/protobuf:memory_manager", "//internal:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", - "@com_google_protobuf//:protobuf", ], ) @@ -512,6 +509,8 @@ cc_test( "//eval/public/containers:container_backed_list_impl", "//eval/public/containers:container_backed_map_impl", "//eval/public/structs:cel_proto_wrapper", + "//eval/public/structs:proto_message_type_adapter", + "//eval/public/structs:protobuf_descriptor_type_provider", "//eval/testutil:test_message_cc_proto", "//internal:status_macros", "//internal:testing", diff --git a/eval/eval/create_struct_step.cc b/eval/eval/create_struct_step.cc index 5ce180885..2d1574c19 100644 --- a/eval/eval/create_struct_step.cc +++ b/eval/eval/create_struct_step.cc @@ -2,10 +2,10 @@ #include #include +#include #include #include "google/api/expr/v1alpha1/syntax.pb.h" -#include "google/protobuf/arena.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -13,31 +13,23 @@ #include "eval/eval/expression_step_base.h" #include "eval/public/cel_value.h" #include "eval/public/containers/container_backed_map_impl.h" -#include "eval/public/containers/field_access.h" -#include "eval/public/structs/cel_proto_wrapper.h" -#include "extensions/protobuf/memory_manager.h" #include "internal/status_macros.h" namespace google::api::expr::runtime { namespace { -using ::cel::extensions::ProtoMemoryManager; -using ::google::protobuf::Descriptor; -using ::google::protobuf::FieldDescriptor; -using ::google::protobuf::Message; -using ::google::protobuf::MessageFactory; - class CreateStructStepForMessage : public ExpressionStepBase { public: struct FieldEntry { - const FieldDescriptor* field; + std::string field_name; }; - CreateStructStepForMessage(int64_t expr_id, const Descriptor* descriptor, - std::vector entries) + CreateStructStepForMessage( + int64_t expr_id, const LegacyTypeAdapter::MutationApis* type_adapter, + std::vector entries) : ExpressionStepBase(expr_id), - descriptor_(descriptor), + type_adapter_(type_adapter), entries_(std::move(entries)) {} absl::Status Evaluate(ExecutionFrame* frame) const override; @@ -45,7 +37,7 @@ class CreateStructStepForMessage : public ExpressionStepBase { private: absl::Status DoEvaluate(ExecutionFrame* frame, CelValue* result) const; - const Descriptor* descriptor_; + const LegacyTypeAdapter::MutationApis* type_adapter_; std::vector entries_; }; @@ -68,10 +60,6 @@ absl::Status CreateStructStepForMessage::DoEvaluate(ExecutionFrame* frame, absl::Span args = frame->value_stack().GetSpan(entries_size); - // This implementation requires arena-backed memory manager. - google::protobuf::Arena* arena = - ProtoMemoryManager::CastToProtoArena(frame->memory_manager()); - if (frame->enable_unknowns()) { auto unknown_set = frame->attribute_utility().MergeUnknowns( args, frame->value_stack().GetAttributeSpan(entries_size), @@ -83,121 +71,20 @@ absl::Status CreateStructStepForMessage::DoEvaluate(ExecutionFrame* frame, } } - const Message* prototype = - frame->message_factory()->GetPrototype(descriptor_); - - Message* msg = (prototype != nullptr) ? prototype->New(arena) : nullptr; - - if (msg == nullptr) { - *result = CreateErrorValue( - frame->memory_manager(), - absl::Substitute("Failed to create message $0", descriptor_->name())); - return absl::OkStatus(); - } + CEL_ASSIGN_OR_RETURN(CelValue instance, + type_adapter_->NewInstance(frame->memory_manager())); int index = 0; for (const auto& entry : entries_) { const CelValue& arg = args[index++]; - absl::Status status = absl::OkStatus(); - - if (entry.field->is_map()) { - constexpr int kKeyField = 1; - constexpr int kValueField = 2; - - const CelMap* cel_map; - if (!arg.GetValue(&cel_map) || cel_map == nullptr) { - status = absl::InvalidArgumentError(absl::Substitute( - "Failed to create message $0, field $1: value is not CelMap", - descriptor_->name(), entry.field->name())); - break; - } - - auto entry_descriptor = entry.field->message_type(); - - if (entry_descriptor == nullptr) { - status = absl::InvalidArgumentError( - absl::Substitute("Failed to create message $0, field $1: failed to " - "find map entry descriptor", - descriptor_->name(), entry.field->name())); - break; - } - - auto key_field_descriptor = - entry_descriptor->FindFieldByNumber(kKeyField); - auto value_field_descriptor = - entry_descriptor->FindFieldByNumber(kValueField); - - if (key_field_descriptor == nullptr) { - status = absl::InvalidArgumentError( - absl::Substitute("Failed to create message $0, field $1: failed to " - "find key field descriptor", - descriptor_->name(), entry.field->name())); - break; - } - if (value_field_descriptor == nullptr) { - status = absl::InvalidArgumentError( - absl::Substitute("Failed to create message $0, field $1: failed to " - "find value field descriptor", - descriptor_->name(), entry.field->name())); - break; - } - - const CelList* key_list = cel_map->ListKeys(); - for (int i = 0; i < key_list->size(); i++) { - CelValue key = (*key_list)[i]; - - auto value = (*cel_map)[key]; - if (!value.has_value()) { - status = absl::InvalidArgumentError(absl::Substitute( - "Failed to create message $0, field $1: Error serializing CelMap", - descriptor_->name(), entry.field->name())); - break; - } - - Message* entry_msg = msg->GetReflection()->AddMessage(msg, entry.field); - status = - SetValueToSingleField(key, key_field_descriptor, entry_msg, arena); - if (!status.ok()) { - break; - } - status = SetValueToSingleField(value.value(), value_field_descriptor, - entry_msg, arena); - if (!status.ok()) { - break; - } - } - - } else if (entry.field->is_repeated()) { - const CelList* cel_list; - if (!arg.GetValue(&cel_list) || cel_list == nullptr) { - *result = CreateErrorValue( - frame->memory_manager(), - absl::Substitute( - "Failed to create message $0: value $1 is not CelList", - descriptor_->name(), entry.field->name())); - return absl::OkStatus(); - } - - for (int i = 0; i < cel_list->size(); i++) { - status = - AddValueToRepeatedField((*cel_list)[i], entry.field, msg, arena); - if (!status.ok()) break; - } - } else { - status = SetValueToSingleField(arg, entry.field, msg, arena); - } - - if (!status.ok()) { - *result = CreateErrorValue( - frame->memory_manager(), - absl::Substitute("Failed to create message $0: reason $1", - descriptor_->name(), status.ToString())); - return absl::OkStatus(); - } + CEL_RETURN_IF_ERROR(type_adapter_->SetField( + entry.field_name, arg, frame->memory_manager(), instance)); } - *result = CelProtoWrapper::CreateMessage(msg, arena); + CEL_RETURN_IF_ERROR( + type_adapter_->AdaptFromWellKnownType(frame->memory_manager(), instance)); + *result = instance; return absl::OkStatus(); } @@ -208,7 +95,10 @@ absl::Status CreateStructStepForMessage::Evaluate(ExecutionFrame* frame) const { } CelValue result; - CEL_RETURN_IF_ERROR(DoEvaluate(frame, &result)); + absl::Status status = DoEvaluate(frame, &result); + if (!status.ok()) { + result = CreateErrorValue(frame->memory_manager(), status); + } frame->value_stack().Pop(entries_.size()); frame->value_stack().Push(result); @@ -268,22 +158,20 @@ absl::Status CreateStructStepForMap::Evaluate(ExecutionFrame* frame) const { absl::StatusOr> CreateCreateStructStep( const google::api::expr::v1alpha1::Expr::CreateStruct* create_struct_expr, - const Descriptor* message_desc, int64_t expr_id) { - if (message_desc != nullptr) { + const LegacyTypeAdapter::MutationApis* type_adapter, int64_t expr_id) { + if (type_adapter != nullptr) { std::vector entries; for (const auto& entry : create_struct_expr->entries()) { - const FieldDescriptor* field_desc = - message_desc->FindFieldByName(entry.field_key()); - if (field_desc == nullptr) { - return absl::InvalidArgumentError( - absl::StrCat("Invalid message creation: field '", entry.field_key(), - "' not found in '", message_desc->full_name(), "'")); + if (!type_adapter->DefinesField(entry.field_key())) { + return absl::InvalidArgumentError(absl::StrCat( + "Invalid message creation: field '", entry.field_key(), + "' not found in '", create_struct_expr->message_name(), "'")); } - entries.push_back({field_desc}); + entries.push_back({entry.field_key()}); } - return std::make_unique(expr_id, message_desc, + return std::make_unique(expr_id, type_adapter, std::move(entries)); } else { // Make map-creating step. diff --git a/eval/eval/create_struct_step.h b/eval/eval/create_struct_step.h index 0f4b66838..c47422782 100644 --- a/eval/eval/create_struct_step.h +++ b/eval/eval/create_struct_step.h @@ -4,23 +4,23 @@ #include #include -#include "google/protobuf/descriptor.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "eval/eval/evaluator_core.h" +#include "eval/public/cel_value.h" namespace google::api::expr::runtime { // Factory method for CreateStruct - based Execution step absl::StatusOr> CreateCreateStructStep( const google::api::expr::v1alpha1::Expr::CreateStruct* create_struct_expr, - const google::protobuf::Descriptor* message_desc, int64_t expr_id); + const LegacyTypeAdapter::MutationApis* type_adapter, int64_t expr_id); inline absl::StatusOr> CreateCreateStructStep( const google::api::expr::v1alpha1::Expr::CreateStruct* create_struct_expr, int64_t expr_id) { - return CreateCreateStructStep(create_struct_expr, /*message_desc=*/nullptr, - expr_id); + return CreateCreateStructStep(create_struct_expr, + /*type_adapter=*/nullptr, expr_id); } } // namespace google::api::expr::runtime diff --git a/eval/eval/create_struct_step_test.cc b/eval/eval/create_struct_step_test.cc index c54b29db8..e62d6a213 100644 --- a/eval/eval/create_struct_step_test.cc +++ b/eval/eval/create_struct_step_test.cc @@ -5,6 +5,7 @@ #include "google/api/expr/v1alpha1/syntax.pb.h" #include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -14,6 +15,8 @@ #include "eval/public/containers/container_backed_list_impl.h" #include "eval/public/containers/container_backed_map_impl.h" #include "eval/public/structs/cel_proto_wrapper.h" +#include "eval/public/structs/proto_message_type_adapter.h" +#include "eval/public/structs/protobuf_descriptor_type_provider.h" #include "eval/testutil/test_message.pb.h" #include "internal/status_macros.h" #include "internal/testing.h" @@ -44,6 +47,10 @@ absl::StatusOr RunExpression(absl::string_view field, bool enable_unknowns) { ExecutionPath path; CelTypeRegistry type_registry; + type_registry.RegisterTypeProvider( + std::make_unique( + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory())); Expr expr0; Expr expr1; @@ -58,13 +65,14 @@ absl::StatusOr RunExpression(absl::string_view field, auto entry = create_struct->add_entries(); entry->set_field_key(field.data()); - auto desc = type_registry.FindDescriptor(create_struct->message_name()); - if (desc == nullptr) { + auto adapter = type_registry.FindTypeAdapter(create_struct->message_name()); + if (!adapter.has_value() || adapter->mutation_apis() == nullptr) { return absl::Status(absl::StatusCode::kFailedPrecondition, "missing proto message type"); } - CEL_ASSIGN_OR_RETURN(auto step1, - CreateCreateStructStep(create_struct, desc, expr1.id())); + CEL_ASSIGN_OR_RETURN( + auto step1, CreateCreateStructStep(create_struct, + adapter->mutation_apis(), expr1.id())); path.push_back(std::move(step0)); path.push_back(std::move(step1)); @@ -169,16 +177,20 @@ class CreateCreateStructStepTest : public testing::TestWithParam {}; TEST_P(CreateCreateStructStepTest, TestEmptyMessageCreation) { ExecutionPath path; CelTypeRegistry type_registry; - + type_registry.RegisterTypeProvider( + std::make_unique( + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory())); Expr expr1; auto create_struct = expr1.mutable_struct_expr(); create_struct->set_message_name("google.api.expr.runtime.TestMessage"); - auto desc = type_registry.FindDescriptor(create_struct->message_name()); - ASSERT_TRUE(desc != nullptr); + auto adapter = type_registry.FindTypeAdapter(create_struct->message_name()); + ASSERT_TRUE(adapter.has_value() && adapter->mutation_apis() != nullptr); - ASSERT_OK_AND_ASSIGN(auto step, - CreateCreateStructStep(create_struct, desc, expr1.id())); + ASSERT_OK_AND_ASSIGN( + auto step, CreateCreateStructStep(create_struct, adapter->mutation_apis(), + expr1.id())); path.push_back(std::move(step)); CelExpressionFlatImpl cel_expr( @@ -199,6 +211,10 @@ TEST_P(CreateCreateStructStepTest, TestEmptyMessageCreation) { TEST_P(CreateCreateStructStepTest, TestMessageCreationBadField) { ExecutionPath path; CelTypeRegistry type_registry; + type_registry.RegisterTypeProvider( + std::make_unique( + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory())); Expr expr1; auto create_struct = expr1.mutable_struct_expr(); @@ -207,10 +223,12 @@ TEST_P(CreateCreateStructStepTest, TestMessageCreationBadField) { entry->set_field_key("bad_field"); auto value = entry->mutable_value(); value->mutable_const_expr()->set_bool_value(true); - auto desc = type_registry.FindDescriptor(create_struct->message_name()); - ASSERT_TRUE(desc != nullptr); + auto adapter = type_registry.FindTypeAdapter(create_struct->message_name()); + ASSERT_TRUE(adapter.has_value() && adapter->mutation_apis() != nullptr); - EXPECT_THAT(CreateCreateStructStep(create_struct, desc, expr1.id()).status(), + EXPECT_THAT(CreateCreateStructStep(create_struct, adapter->mutation_apis(), + expr1.id()) + .status(), StatusIs(absl::StatusCode::kInvalidArgument, testing::HasSubstr("'bad_field'"))); } diff --git a/eval/public/BUILD b/eval/public/BUILD index 448c8a220..e8266f651 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -406,6 +406,8 @@ cc_library( ":cel_expression", ":cel_options", "//eval/compiler:flat_expr_builder", + "//eval/public/structs:proto_message_type_adapter", + "//eval/public/structs:protobuf_descriptor_type_provider", "//internal:proto_util", "@com_google_absl//absl/status", "@com_google_protobuf//:protobuf", @@ -612,6 +614,7 @@ cc_library( hdrs = ["cel_type_registry.h"], deps = [ ":cel_value", + "//eval/public/structs:legacy_type_provider", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:node_hash_set", "@com_google_absl//absl/status", @@ -627,9 +630,12 @@ cc_test( srcs = ["cel_type_registry_test.cc"], deps = [ ":cel_type_registry", + ":cel_value", + "//eval/public/structs:legacy_type_provider", "//eval/testutil:test_message_cc_proto", "//internal:testing", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", ], ) diff --git a/eval/public/cel_expr_builder_factory.cc b/eval/public/cel_expr_builder_factory.cc index c78e846c5..017521457 100644 --- a/eval/public/cel_expr_builder_factory.cc +++ b/eval/public/cel_expr_builder_factory.cc @@ -16,14 +16,16 @@ #include "eval/public/cel_expr_builder_factory.h" +#include #include #include #include "absl/status/status.h" #include "eval/compiler/flat_expr_builder.h" #include "eval/public/cel_options.h" +#include "eval/public/structs/proto_message_type_adapter.h" +#include "eval/public/structs/protobuf_descriptor_type_provider.h" #include "internal/proto_util.h" - namespace google::api::expr::runtime { namespace { @@ -45,6 +47,9 @@ std::unique_ptr CreateCelExpressionBuilder( } auto builder = absl::make_unique(descriptor_pool, message_factory); + builder->GetTypeRegistry()->RegisterTypeProvider( + std::make_unique(descriptor_pool, + message_factory)); builder->set_shortcircuiting(options.short_circuiting); builder->set_constant_folding(options.constant_folding, options.constant_arena); diff --git a/eval/public/cel_type_registry.cc b/eval/public/cel_type_registry.cc index 085c1daba..6bb7d335e 100644 --- a/eval/public/cel_type_registry.cc +++ b/eval/public/cel_type_registry.cc @@ -69,6 +69,19 @@ const google::protobuf::Descriptor* CelTypeRegistry::FindDescriptor( std::string(fully_qualified_type_name)); } +// Find a type's CelValue instance by its fully qualified name. +absl::optional CelTypeRegistry::FindTypeAdapter( + absl::string_view fully_qualified_type_name) const { + for (const auto& provider : type_providers_) { + auto maybe_adapter = provider->ProvideLegacyType(fully_qualified_type_name); + if (maybe_adapter.has_value()) { + return maybe_adapter; + } + } + + return absl::nullopt; +} + absl::optional CelTypeRegistry::FindType( absl::string_view fully_qualified_type_name) const { // Searches through explicitly registered type names first. diff --git a/eval/public/cel_type_registry.h b/eval/public/cel_type_registry.h index f20eab8d2..4e12c6440 100644 --- a/eval/public/cel_type_registry.h +++ b/eval/public/cel_type_registry.h @@ -1,12 +1,16 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_TYPE_REGISTRY_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_TYPE_REGISTRY_H_ +#include +#include + #include "google/protobuf/descriptor.h" #include "absl/container/flat_hash_set.h" #include "absl/container/node_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "eval/public/cel_value.h" +#include "eval/public/structs/legacy_type_provider.h" namespace google::api::expr::runtime { @@ -44,8 +48,17 @@ class CelTypeRegistry { // Enum registration must be performed prior to CelExpression creation. void Register(const google::protobuf::EnumDescriptor* enum_descriptor); - // Find a protobuf Descriptor given a fully qualified protobuf type name. - const google::protobuf::Descriptor* FindDescriptor( + // Register a new type provider. + // + // Type providers are consulted in the order they are added. + void RegisterTypeProvider(std::unique_ptr provider) { + type_providers_.push_back(std::move(provider)); + } + + // Find a type adapter given a fully qualified type name. + // Adapter provides a generic interface for the reflecion operations the + // interpreter needs to provide. + absl::optional FindTypeAdapter( absl::string_view fully_qualified_type_name) const; // Find a type's CelValue instance by its fully qualified name. @@ -59,11 +72,16 @@ class CelTypeRegistry { } private: + // Find a protobuf Descriptor given a fully qualified protobuf type name. + const google::protobuf::Descriptor* FindDescriptor( + absl::string_view fully_qualified_type_name) const; + const google::protobuf::DescriptorPool* descriptor_pool_; // externally owned // pointer-stability is required for the strings in the types set, which is // why a node_hash_set is used instead of another container type. absl::node_hash_set types_; absl::flat_hash_set enums_; + std::vector> type_providers_; }; } // namespace google::api::expr::runtime diff --git a/eval/public/cel_type_registry_test.cc b/eval/public/cel_type_registry_test.cc index d79625804..50b73e6fa 100644 --- a/eval/public/cel_type_registry_test.cc +++ b/eval/public/cel_type_registry_test.cc @@ -1,9 +1,14 @@ #include "eval/public/cel_type_registry.h" +#include #include +#include #include "google/protobuf/any.pb.h" #include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "eval/public/cel_value.h" +#include "eval/public/structs/legacy_type_provider.h" #include "eval/testutil/test_message.pb.h" #include "internal/testing.h" @@ -13,6 +18,27 @@ namespace { using testing::Eq; +class TestTypeProvider : public LegacyTypeProvider { + public: + explicit TestTypeProvider(std::vector types) + : types_(std::move(types)) {} + + // Return a type adapter for an opaque type + // (no reflection operations supported). + absl::optional ProvideLegacyType( + absl::string_view name) const override { + for (const auto& type : types_) { + if (name == type) { + return LegacyTypeAdapter(/*access=*/nullptr, /*mutation=*/nullptr); + } + } + return absl::nullopt; + } + + private: + std::vector types_; +}; + TEST(CelTypeRegistryTest, TestRegisterEnumDescriptor) { CelTypeRegistry registry; registry.Register(TestMessage::TestEnum_descriptor()); @@ -42,17 +68,28 @@ TEST(CelTypeRegistryTest, TestRegisterTypeName) { EXPECT_THAT(type->CelTypeOrDie().value(), Eq("custom_type")); } -TEST(CelTypeRegistryTest, TestFindDescriptorFound) { +TEST(CelTypeRegistryTest, TestFindTypeAdapterFound) { + CelTypeRegistry registry; + registry.RegisterTypeProvider(std::make_unique( + std::vector{"google.protobuf.Any"})); + auto desc = registry.FindTypeAdapter("google.protobuf.Any"); + ASSERT_TRUE(desc.has_value()); +} + +TEST(CelTypeRegistryTest, TestFindTypeAdapterFoundMultipleProviders) { CelTypeRegistry registry; - auto desc = registry.FindDescriptor("google.protobuf.Any"); - ASSERT_TRUE(desc != nullptr); - EXPECT_THAT(desc->full_name(), Eq("google.protobuf.Any")); + registry.RegisterTypeProvider(std::make_unique( + std::vector{"google.protobuf.Int64"})); + registry.RegisterTypeProvider(std::make_unique( + std::vector{"google.protobuf.Any"})); + auto desc = registry.FindTypeAdapter("google.protobuf.Any"); + ASSERT_TRUE(desc.has_value()); } -TEST(CelTypeRegistryTest, TestFindDescriptorNotFound) { +TEST(CelTypeRegistryTest, TestFindTypeAdapterNotFound) { CelTypeRegistry registry; - auto desc = registry.FindDescriptor("missing.MessageType"); - EXPECT_TRUE(desc == nullptr); + auto desc = registry.FindTypeAdapter("missing.MessageType"); + EXPECT_FALSE(desc.has_value()); } TEST(CelTypeRegistryTest, TestFindTypeCoreTypeFound) { diff --git a/eval/public/structs/legacy_type_adapter.h b/eval/public/structs/legacy_type_adapter.h index fbefe1c35..58dea0fd8 100644 --- a/eval/public/structs/legacy_type_adapter.h +++ b/eval/public/structs/legacy_type_adapter.h @@ -39,7 +39,7 @@ namespace google::api::expr::runtime { class LegacyTypeAdapter { public: // Interface for mutation apis. - // Note: in the new type system, the provider represents this by returning + // Note: in the new type system, a type provider represents this by returning // a cel::Type and cel::ValueFactory for the type. class MutationApis { public: From 43cc3a1df6c73c03548be9df56cf4d38d4a972fc Mon Sep 17 00:00:00 2001 From: jcking Date: Thu, 24 Mar 2022 17:00:24 +0000 Subject: [PATCH 038/111] Internal change PiperOrigin-RevId: 437020204 --- internal/BUILD | 15 +++++++++ internal/rtti.h | 75 +++++++++++++++++++++++++++++++++++++++++++ internal/rtti_test.cc | 35 ++++++++++++++++++++ 3 files changed, 125 insertions(+) create mode 100644 internal/rtti.h create mode 100644 internal/rtti_test.cc diff --git a/internal/BUILD b/internal/BUILD index 3b8f43163..e4981349b 100644 --- a/internal/BUILD +++ b/internal/BUILD @@ -159,6 +159,21 @@ cc_test( ], ) +cc_library( + name = "rtti", + hdrs = ["rtti.h"], +) + +cc_test( + name = "rtti_test", + srcs = ["rtti_test.cc"], + deps = [ + ":rtti", + "//internal:testing", + "@com_google_absl//absl/hash:hash_testing", + ], +) + cc_library( name = "testing", testonly = True, diff --git a/internal/rtti.h b/internal/rtti.h new file mode 100644 index 000000000..c10df58ca --- /dev/null +++ b/internal/rtti.h @@ -0,0 +1,75 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_RTTI_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_RTTI_H_ + +#include +#include + +namespace cel::internal { + +class TypeInfo; + +template +TypeInfo TypeId(); + +// TypeInfo is an RTTI-like alternative for identifying a type at runtime. Its +// main benefit is it does not require RTTI being available, allowing CEL to +// work without RTTI. +// +// This is used to implement the runtime type system and conversion between CEL +// values and their native C++ counterparts. +class TypeInfo final { + public: + constexpr TypeInfo() = default; + + TypeInfo(const TypeInfo&) = default; + + TypeInfo& operator=(const TypeInfo&) = default; + + friend bool operator==(const TypeInfo& lhs, const TypeInfo& rhs) { + return lhs.id_ == rhs.id_; + } + + friend bool operator!=(const TypeInfo& lhs, const TypeInfo& rhs) { + return !operator==(lhs, rhs); + } + + template + friend H AbslHashValue(H state, const TypeInfo& type) { + return H::combine(std::move(state), reinterpret_cast(type.id_)); + } + + private: + template + friend TypeInfo TypeId(); + + constexpr explicit TypeInfo(void* id) : id_(id) {} + + void* id_ = nullptr; +}; + +template +TypeInfo TypeId() { + // Adapted from Abseil and GTL. I believe this not being const is to ensure + // the compiler does not merge multiple constants with the same value to share + // the same address. + static char id; + return TypeInfo(&id); +} + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_RTTI_H_ diff --git a/internal/rtti_test.cc b/internal/rtti_test.cc new file mode 100644 index 000000000..94543977c --- /dev/null +++ b/internal/rtti_test.cc @@ -0,0 +1,35 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/rtti.h" + +#include "absl/hash/hash_testing.h" +#include "internal/testing.h" + +namespace cel::internal { +namespace { + +struct Type1 {}; + +struct Type2 {}; + +TEST(TypeInfo, Default) { EXPECT_EQ(TypeInfo(), TypeInfo()); } + +TEST(TypeId, SupportsAbslHash) { + EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly( + {TypeInfo(), TypeId(), TypeId()})); +} + +} // namespace +} // namespace cel::internal From c3fa7b250c52c15d72fbec3e30f7fa2460372590 Mon Sep 17 00:00:00 2001 From: jcking Date: Thu, 24 Mar 2022 20:47:35 +0000 Subject: [PATCH 039/111] Internal change PiperOrigin-RevId: 437077618 --- base/BUILD | 8 ++++++- base/internal/type.post.h | 1 + base/internal/type.pre.h | 2 ++ base/type.cc | 15 +++++++++++++ base/type.h | 40 +++++++++++++++++++++++++++++++++++ base/type_factory.cc | 44 +++++++++++++++++++++++++++++++++++++++ base/type_factory.h | 11 ++++++++++ base/type_factory_test.cc | 34 ++++++++++++++++++++++++++++++ base/type_test.cc | 40 +++++++++++++++++++++++++++++++++++ 9 files changed, 194 insertions(+), 1 deletion(-) create mode 100644 base/type_factory_test.cc diff --git a/base/BUILD b/base/BUILD index a0c428379..e5fc06487 100644 --- a/base/BUILD +++ b/base/BUILD @@ -113,12 +113,15 @@ cc_library( ":kind", ":memory_manager", "//base/internal:type", + "//internal:casts", "//internal:no_destructor", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/hash", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", ], @@ -126,7 +129,10 @@ cc_library( cc_test( name = "type_test", - srcs = ["type_test.cc"], + srcs = [ + "type_factory_test.cc", + "type_test.cc", + ], deps = [ ":handle", ":memory_manager", diff --git a/base/internal/type.post.h b/base/internal/type.post.h index 89927e8fc..5245015ba 100644 --- a/base/internal/type.post.h +++ b/base/internal/type.post.h @@ -265,6 +265,7 @@ CEL_INTERNAL_TYPE_DECL(StringType); CEL_INTERNAL_TYPE_DECL(DurationType); CEL_INTERNAL_TYPE_DECL(TimestampType); CEL_INTERNAL_TYPE_DECL(EnumType); +CEL_INTERNAL_TYPE_DECL(ListType); #undef CEL_INTERNAL_TYPE_DECL } // namespace cel diff --git a/base/internal/type.pre.h b/base/internal/type.pre.h index d6bf8ae0b..65c1722a4 100644 --- a/base/internal/type.pre.h +++ b/base/internal/type.pre.h @@ -39,6 +39,8 @@ inline constexpr uintptr_t kTypeHandleBits = kTypeHandleUnmanaged | kTypeHandleReserved; inline constexpr uintptr_t kTypeHandleMask = ~kTypeHandleBits; +class ListTypeImpl; + } // namespace cel::base_internal #endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_TYPE_PRE_H_ diff --git a/base/type.cc b/base/type.cc index cef3aa0f5..26af7cc39 100644 --- a/base/type.cc +++ b/base/type.cc @@ -19,6 +19,7 @@ #include "absl/types/span.h" #include "absl/types/variant.h" #include "base/handle.h" +#include "internal/casts.h" #include "internal/no_destructor.h" namespace cel { @@ -42,6 +43,7 @@ CEL_INTERNAL_TYPE_IMPL(StringType); CEL_INTERNAL_TYPE_IMPL(DurationType); CEL_INTERNAL_TYPE_IMPL(TimestampType); CEL_INTERNAL_TYPE_IMPL(EnumType); +CEL_INTERNAL_TYPE_IMPL(ListType); #undef CEL_INTERNAL_TYPE_IMPL absl::Span> Type::parameters() const { return {}; } @@ -135,4 +137,17 @@ absl::StatusOr EnumType::FindConstant(ConstantId id) const { return absl::visit(FindConstantVisitor{*this}, id.data_); } +bool ListType::Equals(const Type& other) const { + if (kind() != other.kind()) { + return false; + } + return element() == internal::down_cast(other).element(); +} + +void ListType::HashValue(absl::HashState state) const { + // We specifically hash the element first and then call the parent method to + // avoid hash suffix/prefix collisions. + Type::HashValue(absl::HashState::combine(std::move(state), element())); +} + } // namespace cel diff --git a/base/type.h b/base/type.h index 028e75b36..cba373bb4 100644 --- a/base/type.h +++ b/base/type.h @@ -46,6 +46,7 @@ class BytesType; class DurationType; class TimestampType; class EnumType; +class ListType; class TypeFactory; class TypeProvider; @@ -94,6 +95,7 @@ class Type : public base_internal::Resource { friend class DurationType; friend class TimestampType; friend class EnumType; + friend class ListType; friend class base_internal::TypeHandleBase; Type() = default; @@ -535,6 +537,44 @@ class EnumType : public Type { alignof(enum_type)); \ } +// ListType represents a list type. A list is a sequential container where each +// element is the same type. +class ListType : public Type { + // I would have liked to make this class final, but we cannot instantiate + // Persistent or Transient at this point. It must be + // done after the post include below. Maybe we should separate out the post + // includes on a per type basis so we can do that? + public: + Kind kind() const final { return Kind::kList; } + + absl::string_view name() const final { return "list"; } + + // Returns the type of the elements in the list. + virtual Transient element() const = 0; + + private: + friend class TypeFactory; + friend class base_internal::TypeHandleBase; + friend class base_internal::ListTypeImpl; + + // Called by base_internal::TypeHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Type& type) { return type.kind() == Kind::kList; } + + ListType() = default; + + ListType(const ListType&) = delete; + ListType(ListType&&) = delete; + + std::pair SizeAndAlignment() const override = 0; + + // Called by base_internal::TypeHandleBase. + bool Equals(const Type& other) const final; + + // Called by base_internal::TypeHandleBase. + void HashValue(absl::HashState state) const final; +}; + } // namespace cel // type.pre.h forward declares types so they can be friended above. The types diff --git a/base/type_factory.cc b/base/type_factory.cc index 4f3509fb2..24446e504 100644 --- a/base/type_factory.cc +++ b/base/type_factory.cc @@ -14,6 +14,11 @@ #include "base/type_factory.h" +#include + +#include "absl/base/optimization.h" +#include "absl/status/status.h" +#include "absl/synchronization/mutex.h" #include "base/handle.h" #include "base/type.h" @@ -21,10 +26,30 @@ namespace cel { namespace { +using base_internal::PersistentHandleFactory; using base_internal::TransientHandleFactory; } // namespace +namespace base_internal { + +class ListTypeImpl final : public ListType { + public: + explicit ListTypeImpl(Persistent element) + : element_(std::move(element)) {} + + Transient element() const override { return element_; } + + private: + std::pair SizeAndAlignment() const override { + return std::make_pair(sizeof(ListTypeImpl), alignof(ListTypeImpl)); + } + + Persistent element_; +}; + +} // namespace base_internal + Persistent TypeFactory::GetNullType() { return WrapSingletonType(); } @@ -73,4 +98,23 @@ Persistent TypeFactory::GetTimestampType() { return WrapSingletonType(); } +absl::StatusOr> TypeFactory::CreateListType( + const Persistent& element) { + absl::MutexLock lock(&mutex_); + auto existing = list_types_.find(element); + if (existing != list_types_.end()) { + return existing->second; + } + auto list_type = PersistentHandleFactory::Make< + const base_internal::ListTypeImpl>(memory_manager(), element); + if (ABSL_PREDICT_FALSE(!list_type)) { + // TODO(issues/5): maybe have the handle factories return statuses as + // they can add details on the size and alignment more easily and + // consistently? + return absl::ResourceExhaustedError("Failed to allocate memory"); + } + list_types_.insert({element, list_type}); + return list_type; +} + } // namespace cel diff --git a/base/type_factory.h b/base/type_factory.h index 39049b8ab..8be984cc9 100644 --- a/base/type_factory.h +++ b/base/type_factory.h @@ -18,6 +18,9 @@ #include #include "absl/base/attributes.h" +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/synchronization/mutex.h" #include "base/handle.h" #include "base/memory_manager.h" #include "base/type.h" @@ -84,6 +87,9 @@ class TypeFactory { args)...); } + absl::StatusOr> CreateListType( + const Persistent& element) ABSL_ATTRIBUTE_LIFETIME_BOUND; + private: template static Persistent WrapSingletonType() { @@ -98,6 +104,11 @@ class TypeFactory { MemoryManager& memory_manager() const { return memory_manager_; } MemoryManager& memory_manager_; + absl::Mutex mutex_; + // Mapping from list element types to the list type. This allows us to cache + // list types and avoid re-creating the same type. + absl::flat_hash_map, Persistent> + list_types_ ABSL_GUARDED_BY(mutex_); }; } // namespace cel diff --git a/base/type_factory_test.cc b/base/type_factory_test.cc new file mode 100644 index 000000000..9ddd2f3c8 --- /dev/null +++ b/base/type_factory_test.cc @@ -0,0 +1,34 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/type_factory.h" + +#include "absl/status/status.h" +#include "base/memory_manager.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +TEST(TypeFactory, CreateListTypeCaches) { + TypeFactory type_factory(MemoryManager::Global()); + ASSERT_OK_AND_ASSIGN(auto list_type_1, + type_factory.CreateListType(type_factory.GetBoolType())); + ASSERT_OK_AND_ASSIGN(auto list_type_2, + type_factory.CreateListType(type_factory.GetBoolType())); + EXPECT_EQ(list_type_1.operator->(), list_type_2.operator->()); +} + +} // namespace +} // namespace cel diff --git a/base/type_test.cc b/base/type_test.cc index d3eb81305..ace083bfa 100644 --- a/base/type_test.cc +++ b/base/type_test.cc @@ -180,6 +180,7 @@ TEST(Type, Null) { EXPECT_FALSE(type_factory.GetNullType().Is()); EXPECT_FALSE(type_factory.GetNullType().Is()); EXPECT_FALSE(type_factory.GetNullType().Is()); + EXPECT_FALSE(type_factory.GetNullType().Is()); } TEST(Type, Error) { @@ -199,6 +200,7 @@ TEST(Type, Error) { EXPECT_FALSE(type_factory.GetErrorType().Is()); EXPECT_FALSE(type_factory.GetErrorType().Is()); EXPECT_FALSE(type_factory.GetErrorType().Is()); + EXPECT_FALSE(type_factory.GetErrorType().Is()); } TEST(Type, Dyn) { @@ -218,6 +220,7 @@ TEST(Type, Dyn) { EXPECT_FALSE(type_factory.GetDynType().Is()); EXPECT_FALSE(type_factory.GetDynType().Is()); EXPECT_FALSE(type_factory.GetDynType().Is()); + EXPECT_FALSE(type_factory.GetDynType().Is()); } TEST(Type, Any) { @@ -237,6 +240,7 @@ TEST(Type, Any) { EXPECT_FALSE(type_factory.GetAnyType().Is()); EXPECT_FALSE(type_factory.GetAnyType().Is()); EXPECT_FALSE(type_factory.GetAnyType().Is()); + EXPECT_FALSE(type_factory.GetAnyType().Is()); } TEST(Type, Bool) { @@ -256,6 +260,7 @@ TEST(Type, Bool) { EXPECT_FALSE(type_factory.GetBoolType().Is()); EXPECT_FALSE(type_factory.GetBoolType().Is()); EXPECT_FALSE(type_factory.GetBoolType().Is()); + EXPECT_FALSE(type_factory.GetBoolType().Is()); } TEST(Type, Int) { @@ -275,6 +280,7 @@ TEST(Type, Int) { EXPECT_FALSE(type_factory.GetIntType().Is()); EXPECT_FALSE(type_factory.GetIntType().Is()); EXPECT_FALSE(type_factory.GetIntType().Is()); + EXPECT_FALSE(type_factory.GetIntType().Is()); } TEST(Type, Uint) { @@ -294,6 +300,7 @@ TEST(Type, Uint) { EXPECT_FALSE(type_factory.GetUintType().Is()); EXPECT_FALSE(type_factory.GetUintType().Is()); EXPECT_FALSE(type_factory.GetUintType().Is()); + EXPECT_FALSE(type_factory.GetUintType().Is()); } TEST(Type, Double) { @@ -313,6 +320,7 @@ TEST(Type, Double) { EXPECT_FALSE(type_factory.GetDoubleType().Is()); EXPECT_FALSE(type_factory.GetDoubleType().Is()); EXPECT_FALSE(type_factory.GetDoubleType().Is()); + EXPECT_FALSE(type_factory.GetDoubleType().Is()); } TEST(Type, String) { @@ -332,6 +340,7 @@ TEST(Type, String) { EXPECT_FALSE(type_factory.GetStringType().Is()); EXPECT_FALSE(type_factory.GetStringType().Is()); EXPECT_FALSE(type_factory.GetStringType().Is()); + EXPECT_FALSE(type_factory.GetStringType().Is()); } TEST(Type, Bytes) { @@ -351,6 +360,7 @@ TEST(Type, Bytes) { EXPECT_FALSE(type_factory.GetBytesType().Is()); EXPECT_FALSE(type_factory.GetBytesType().Is()); EXPECT_FALSE(type_factory.GetBytesType().Is()); + EXPECT_FALSE(type_factory.GetBytesType().Is()); } TEST(Type, Duration) { @@ -370,6 +380,7 @@ TEST(Type, Duration) { EXPECT_TRUE(type_factory.GetDurationType().Is()); EXPECT_FALSE(type_factory.GetDurationType().Is()); EXPECT_FALSE(type_factory.GetDurationType().Is()); + EXPECT_FALSE(type_factory.GetDurationType().Is()); } TEST(Type, Timestamp) { @@ -390,6 +401,7 @@ TEST(Type, Timestamp) { EXPECT_FALSE(type_factory.GetTimestampType().Is()); EXPECT_TRUE(type_factory.GetTimestampType().Is()); EXPECT_FALSE(type_factory.GetTimestampType().Is()); + EXPECT_FALSE(type_factory.GetTimestampType().Is()); } TEST(Type, Enum) { @@ -411,6 +423,32 @@ TEST(Type, Enum) { EXPECT_FALSE(enum_type.Is()); EXPECT_FALSE(enum_type.Is()); EXPECT_TRUE(enum_type.Is()); + EXPECT_FALSE(enum_type.Is()); +} + +TEST(Type, List) { + TypeFactory type_factory(MemoryManager::Global()); + ASSERT_OK_AND_ASSIGN(auto list_type, + type_factory.CreateListType(type_factory.GetBoolType())); + EXPECT_EQ(list_type, + Must(type_factory.CreateListType(type_factory.GetBoolType()))); + EXPECT_EQ(list_type->kind(), Kind::kList); + EXPECT_EQ(list_type->name(), "list"); + EXPECT_EQ(list_type->element(), type_factory.GetBoolType()); + EXPECT_THAT(list_type->parameters(), SizeIs(0)); + EXPECT_FALSE(list_type.Is()); + EXPECT_FALSE(list_type.Is()); + EXPECT_FALSE(list_type.Is()); + EXPECT_FALSE(list_type.Is()); + EXPECT_FALSE(list_type.Is()); + EXPECT_FALSE(list_type.Is()); + EXPECT_FALSE(list_type.Is()); + EXPECT_FALSE(list_type.Is()); + EXPECT_FALSE(list_type.Is()); + EXPECT_FALSE(list_type.Is()); + EXPECT_FALSE(list_type.Is()); + EXPECT_FALSE(list_type.Is()); + EXPECT_TRUE(list_type.Is()); } TEST(EnumType, FindConstant) { @@ -460,6 +498,8 @@ TEST(Type, SupportsAbslHash) { Persistent(type_factory.GetDurationType()), Persistent(type_factory.GetTimestampType()), Persistent(Must(type_factory.CreateEnumType())), + Persistent( + Must(type_factory.CreateListType(type_factory.GetBoolType()))), })); } From a45ff1831e6248928749c5b221818ec7111b1b91 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Fri, 25 Mar 2022 04:59:44 +0000 Subject: [PATCH 040/111] Add a gunit matcher for CelList. PiperOrigin-RevId: 437163886 --- eval/public/testing/BUILD | 1 + eval/public/testing/matchers.h | 46 ++++++++++++++++++++++++++-- eval/public/testing/matchers_test.cc | 35 +++++++++++++++++++++ 3 files changed, 80 insertions(+), 2 deletions(-) diff --git a/eval/public/testing/BUILD b/eval/public/testing/BUILD index ab40cbc6a..b348a0bd3 100644 --- a/eval/public/testing/BUILD +++ b/eval/public/testing/BUILD @@ -25,6 +25,7 @@ cc_test( srcs = ["matchers_test.cc"], deps = [ ":matchers", + "//eval/public/containers:container_backed_list_impl", "//eval/public/structs:cel_proto_wrapper", "//eval/testutil:test_message_cc_proto", "//internal:testing", diff --git a/eval/public/testing/matchers.h b/eval/public/testing/matchers.h index 1b59fb8bb..5d8d2e70c 100644 --- a/eval/public/testing/matchers.h +++ b/eval/public/testing/matchers.h @@ -59,8 +59,50 @@ CelValueMatcher IsCelTimestamp(testing::Matcher m); // The matcher |m| is wrapped to allow using the testing::status::... matchers. CelValueMatcher IsCelError(testing::Matcher m); -// TODO(issues/73): add helpers for working with maps, unknown sets, and -// lists. +// A matcher that wraps a Container matcher so that container matchers can be +// used for matching CelList. +// +// This matcher can be avoided if CelList supported the iterators needed by the +// standard container matchers but given that it is an interface it is a much +// larger project. +// +// TODO(issues/73): Re-use CelValueMatcherImpl. There are template details +// that need to be worked out specifically on how CelValueMatcherImpl can accept +// a generic matcher for CelList instead of testing::Matcher. +template +class CelListMatcher : public testing::MatcherInterface { + public: + explicit CelListMatcher(ContainerMatcher m) : container_matcher_(m) {} + + bool MatchAndExplain(const CelValue& v, + testing::MatchResultListener* listener) const override { + const CelList* cel_list; + if (!v.GetValue(&cel_list) || cel_list == nullptr) return false; + + std::vector cel_vector; + cel_vector.reserve(cel_list->size()); + for (int i = 0; i < cel_list->size(); ++i) { + cel_vector.push_back((*cel_list)[i]); + } + return container_matcher_.Matches(cel_vector); + } + + void DescribeTo(std::ostream* os) const override { + CelValue::Type type = + static_cast(CelValue::IndexOf::value); + *os << absl::StrCat("type is ", CelValue::TypeName(type), " and "); + container_matcher_.DescribeTo(os); + } + + private: + const testing::Matcher> container_matcher_; +}; + +template +CelValueMatcher IsCelList(ContainerMatcher m) { + return CelValueMatcher(new CelListMatcher(m)); +} +// TODO(issues/73): add helpers for working with maps and unknown sets. } // namespace test } // namespace runtime diff --git a/eval/public/testing/matchers_test.cc b/eval/public/testing/matchers_test.cc index 64542ecef..6b30a40af 100644 --- a/eval/public/testing/matchers_test.cc +++ b/eval/public/testing/matchers_test.cc @@ -2,6 +2,7 @@ #include "absl/status/status.h" #include "absl/time/time.h" +#include "eval/public/containers/container_backed_list_impl.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "eval/testutil/test_message.pb.h" #include "internal/testing.h" @@ -14,11 +15,14 @@ namespace runtime { namespace test { namespace { +using testing::Contains; using testing::DoubleEq; using testing::DoubleNear; +using testing::ElementsAre; using testing::Gt; using testing::Lt; using testing::Not; +using testing::UnorderedElementsAre; using testutil::EqualsProto; TEST(IsCelValue, EqualitySmoketest) { @@ -117,6 +121,37 @@ TEST(SpecialMatchers, SmokeTest) { EXPECT_THAT(message, IsCelMessage(EqualsProto(proto_message))); } +TEST(ListMatchers, NotList) { + EXPECT_THAT(CelValue::CreateInt64(1), + Not(IsCelList(Contains(IsCelInt64(1))))); +} + +TEST(ListMatchers, All) { + ContainerBackedListImpl list({ + CelValue::CreateInt64(1), + CelValue::CreateInt64(2), + CelValue::CreateInt64(3), + CelValue::CreateInt64(4), + }); + CelValue cel_list = CelValue::CreateList(&list); + EXPECT_THAT(cel_list, IsCelList(Contains(IsCelInt64(3)))); + EXPECT_THAT(cel_list, IsCelList(Not(Contains(IsCelInt64(0))))); + + EXPECT_THAT(cel_list, IsCelList(ElementsAre(IsCelInt64(1), IsCelInt64(2), + IsCelInt64(3), IsCelInt64(4)))); + EXPECT_THAT(cel_list, + IsCelList(Not(ElementsAre(IsCelInt64(2), IsCelInt64(1), + IsCelInt64(3), IsCelInt64(4))))); + + EXPECT_THAT(cel_list, + IsCelList(UnorderedElementsAre(IsCelInt64(2), IsCelInt64(1), + IsCelInt64(4), IsCelInt64(3)))); + EXPECT_THAT( + cel_list, + IsCelList(Not(UnorderedElementsAre(IsCelInt64(2), IsCelInt64(1), + IsCelInt64(4), IsCelInt64(0))))); +} + } // namespace } // namespace test } // namespace runtime From 1ebb79c2b6b019c549cb272b571e67dcaec37fae Mon Sep 17 00:00:00 2001 From: jcking Date: Fri, 25 Mar 2022 18:03:45 +0000 Subject: [PATCH 041/111] Internal change PiperOrigin-RevId: 437289322 --- base/type.cc | 3 ++ base/type.h | 2 ++ base/type_test.cc | 76 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 81 insertions(+) diff --git a/base/type.cc b/base/type.cc index 26af7cc39..c4d972880 100644 --- a/base/type.cc +++ b/base/type.cc @@ -14,6 +14,7 @@ #include "base/type.h" +#include #include #include "absl/types/span.h" @@ -48,6 +49,8 @@ CEL_INTERNAL_TYPE_IMPL(ListType); absl::Span> Type::parameters() const { return {}; } +std::string Type::DebugString() const { return std::string(name()); } + std::pair Type::SizeAndAlignment() const { // Currently no implementation of Type is reference counted. However once we // introduce Struct it likely will be. Using 0 here will trigger runtime diff --git a/base/type.h b/base/type.h index cba373bb4..900347df3 100644 --- a/base/type.h +++ b/base/type.h @@ -81,6 +81,8 @@ class Type : public base_internal::Resource { // Returns the type parameters of the type, i.e. key and value type of map. virtual absl::Span> parameters() const; + virtual std::string DebugString() const; + private: friend class NullType; friend class ErrorType; diff --git a/base/type_test.cc b/base/type_test.cc index ace083bfa..4f93e5d27 100644 --- a/base/type_test.cc +++ b/base/type_test.cc @@ -482,6 +482,82 @@ TEST(EnumType, FindConstant) { StatusIs(absl::StatusCode::kNotFound)); } +TEST(NullType, DebugString) { + TypeFactory type_factory(MemoryManager::Global()); + EXPECT_EQ(type_factory.GetNullType()->DebugString(), "null_type"); +} + +TEST(ErrorType, DebugString) { + TypeFactory type_factory(MemoryManager::Global()); + EXPECT_EQ(type_factory.GetErrorType()->DebugString(), "*error*"); +} + +TEST(DynType, DebugString) { + TypeFactory type_factory(MemoryManager::Global()); + EXPECT_EQ(type_factory.GetDynType()->DebugString(), "dyn"); +} + +TEST(AnyType, DebugString) { + TypeFactory type_factory(MemoryManager::Global()); + EXPECT_EQ(type_factory.GetAnyType()->DebugString(), "google.protobuf.Any"); +} + +TEST(BoolType, DebugString) { + TypeFactory type_factory(MemoryManager::Global()); + EXPECT_EQ(type_factory.GetBoolType()->DebugString(), "bool"); +} + +TEST(IntType, DebugString) { + TypeFactory type_factory(MemoryManager::Global()); + EXPECT_EQ(type_factory.GetIntType()->DebugString(), "int"); +} + +TEST(UintType, DebugString) { + TypeFactory type_factory(MemoryManager::Global()); + EXPECT_EQ(type_factory.GetUintType()->DebugString(), "uint"); +} + +TEST(DoubleType, DebugString) { + TypeFactory type_factory(MemoryManager::Global()); + EXPECT_EQ(type_factory.GetDoubleType()->DebugString(), "double"); +} + +TEST(StringType, DebugString) { + TypeFactory type_factory(MemoryManager::Global()); + EXPECT_EQ(type_factory.GetStringType()->DebugString(), "string"); +} + +TEST(BytesType, DebugString) { + TypeFactory type_factory(MemoryManager::Global()); + EXPECT_EQ(type_factory.GetBytesType()->DebugString(), "bytes"); +} + +TEST(DurationType, DebugString) { + TypeFactory type_factory(MemoryManager::Global()); + EXPECT_EQ(type_factory.GetDurationType()->DebugString(), + "google.protobuf.Duration"); +} + +TEST(TimestampType, DebugString) { + TypeFactory type_factory(MemoryManager::Global()); + EXPECT_EQ(type_factory.GetTimestampType()->DebugString(), + "google.protobuf.Timestamp"); +} + +TEST(EnumType, DebugString) { + TypeFactory type_factory(MemoryManager::Global()); + ASSERT_OK_AND_ASSIGN(auto enum_type, + type_factory.CreateEnumType()); + EXPECT_EQ(enum_type->DebugString(), "test_enum.TestEnum"); +} + +TEST(ListType, DebugString) { + TypeFactory type_factory(MemoryManager::Global()); + ASSERT_OK_AND_ASSIGN(auto list_type, + type_factory.CreateListType(type_factory.GetBoolType())); + EXPECT_EQ(list_type->DebugString(), "list"); +} + TEST(Type, SupportsAbslHash) { TypeFactory type_factory(MemoryManager::Global()); EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly({ From c26c28bb243320bcc46da10206cd5ff81e20599a Mon Sep 17 00:00:00 2001 From: jcking Date: Fri, 25 Mar 2022 19:01:09 +0000 Subject: [PATCH 042/111] Internal change PiperOrigin-RevId: 437302980 --- base/internal/type.post.h | 1 + base/internal/type.pre.h | 1 + base/type.cc | 15 +++++++++++ base/type.h | 43 ++++++++++++++++++++++++++++++ base/type_factory.cc | 40 +++++++++++++++++++++++++++- base/type_factory.h | 17 ++++++++++-- base/type_factory_test.cc | 11 ++++++++ base/type_test.cc | 56 +++++++++++++++++++++++++++++++++++++++ 8 files changed, 181 insertions(+), 3 deletions(-) diff --git a/base/internal/type.post.h b/base/internal/type.post.h index 5245015ba..1ccc4d30d 100644 --- a/base/internal/type.post.h +++ b/base/internal/type.post.h @@ -266,6 +266,7 @@ CEL_INTERNAL_TYPE_DECL(DurationType); CEL_INTERNAL_TYPE_DECL(TimestampType); CEL_INTERNAL_TYPE_DECL(EnumType); CEL_INTERNAL_TYPE_DECL(ListType); +CEL_INTERNAL_TYPE_DECL(MapType); #undef CEL_INTERNAL_TYPE_DECL } // namespace cel diff --git a/base/internal/type.pre.h b/base/internal/type.pre.h index 65c1722a4..f8a9029e6 100644 --- a/base/internal/type.pre.h +++ b/base/internal/type.pre.h @@ -40,6 +40,7 @@ inline constexpr uintptr_t kTypeHandleBits = inline constexpr uintptr_t kTypeHandleMask = ~kTypeHandleBits; class ListTypeImpl; +class MapTypeImpl; } // namespace cel::base_internal diff --git a/base/type.cc b/base/type.cc index c4d972880..9773ace4a 100644 --- a/base/type.cc +++ b/base/type.cc @@ -45,6 +45,7 @@ CEL_INTERNAL_TYPE_IMPL(DurationType); CEL_INTERNAL_TYPE_IMPL(TimestampType); CEL_INTERNAL_TYPE_IMPL(EnumType); CEL_INTERNAL_TYPE_IMPL(ListType); +CEL_INTERNAL_TYPE_IMPL(MapType); #undef CEL_INTERNAL_TYPE_IMPL absl::Span> Type::parameters() const { return {}; } @@ -153,4 +154,18 @@ void ListType::HashValue(absl::HashState state) const { Type::HashValue(absl::HashState::combine(std::move(state), element())); } +bool MapType::Equals(const Type& other) const { + if (kind() != other.kind()) { + return false; + } + return key() == internal::down_cast(other).key() && + value() == internal::down_cast(other).value(); +} + +void MapType::HashValue(absl::HashState state) const { + // We specifically hash the element first and then call the parent method to + // avoid hash suffix/prefix collisions. + Type::HashValue(absl::HashState::combine(std::move(state), key(), value())); +} + } // namespace cel diff --git a/base/type.h b/base/type.h index 900347df3..57bc7fc2e 100644 --- a/base/type.h +++ b/base/type.h @@ -47,6 +47,7 @@ class DurationType; class TimestampType; class EnumType; class ListType; +class MapType; class TypeFactory; class TypeProvider; @@ -98,6 +99,7 @@ class Type : public base_internal::Resource { friend class TimestampType; friend class EnumType; friend class ListType; + friend class MapType; friend class base_internal::TypeHandleBase; Type() = default; @@ -577,6 +579,47 @@ class ListType : public Type { void HashValue(absl::HashState state) const final; }; +// MapType represents a map type. A map is container of key and value pairs +// where each key appears at most once. +class MapType : public Type { + // I would have liked to make this class final, but we cannot instantiate + // Persistent or Transient at this point. It must be + // done after the post include below. Maybe we should separate out the post + // includes on a per type basis so we can do that? + public: + Kind kind() const final { return Kind::kMap; } + + absl::string_view name() const final { return "map"; } + + // Returns the type of the keys in the map. + virtual Transient key() const = 0; + + // Returns the type of the values in the map. + virtual Transient value() const = 0; + + private: + friend class TypeFactory; + friend class base_internal::TypeHandleBase; + friend class base_internal::MapTypeImpl; + + // Called by base_internal::TypeHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Type& type) { return type.kind() == Kind::kMap; } + + MapType() = default; + + MapType(const MapType&) = delete; + MapType(MapType&&) = delete; + + std::pair SizeAndAlignment() const override = 0; + + // Called by base_internal::TypeHandleBase. + bool Equals(const Type& other) const final; + + // Called by base_internal::TypeHandleBase. + void HashValue(absl::HashState state) const final; +}; + } // namespace cel // type.pre.h forward declares types so they can be friended above. The types diff --git a/base/type_factory.cc b/base/type_factory.cc index 24446e504..b29a9ae30 100644 --- a/base/type_factory.cc +++ b/base/type_factory.cc @@ -48,6 +48,24 @@ class ListTypeImpl final : public ListType { Persistent element_; }; +class MapTypeImpl final : public MapType { + public: + MapTypeImpl(Persistent key, Persistent value) + : key_(std::move(key)), value_(std::move(value)) {} + + Transient key() const override { return key_; } + + Transient value() const override { return value_; } + + private: + std::pair SizeAndAlignment() const override { + return std::make_pair(sizeof(MapTypeImpl), alignof(MapTypeImpl)); + } + + Persistent key_; + Persistent value_; +}; + } // namespace base_internal Persistent TypeFactory::GetNullType() { @@ -100,7 +118,7 @@ Persistent TypeFactory::GetTimestampType() { absl::StatusOr> TypeFactory::CreateListType( const Persistent& element) { - absl::MutexLock lock(&mutex_); + absl::MutexLock lock(&list_types_mutex_); auto existing = list_types_.find(element); if (existing != list_types_.end()) { return existing->second; @@ -117,4 +135,24 @@ absl::StatusOr> TypeFactory::CreateListType( return list_type; } +absl::StatusOr> TypeFactory::CreateMapType( + const Persistent& key, const Persistent& value) { + auto key_and_value = std::make_pair(key, value); + absl::MutexLock lock(&map_types_mutex_); + auto existing = map_types_.find(key_and_value); + if (existing != map_types_.end()) { + return existing->second; + } + auto map_type = PersistentHandleFactory::Make< + const base_internal::MapTypeImpl>(memory_manager(), key, value); + if (ABSL_PREDICT_FALSE(!map_type)) { + // TODO(issues/5): maybe have the handle factories return statuses as + // they can add details on the size and alignment more easily and + // consistently? + return absl::ResourceExhaustedError("Failed to allocate memory"); + } + map_types_.insert({std::move(key_and_value), map_type}); + return map_type; +} + } // namespace cel diff --git a/base/type_factory.h b/base/type_factory.h index 8be984cc9..83014eaad 100644 --- a/base/type_factory.h +++ b/base/type_factory.h @@ -16,6 +16,7 @@ #define THIRD_PARTY_CEL_CPP_BASE_TYPE_FACTORY_H_ #include +#include #include "absl/base/attributes.h" #include "absl/base/thread_annotations.h" @@ -90,6 +91,10 @@ class TypeFactory { absl::StatusOr> CreateListType( const Persistent& element) ABSL_ATTRIBUTE_LIFETIME_BOUND; + absl::StatusOr> CreateMapType( + const Persistent& key, + const Persistent& value) ABSL_ATTRIBUTE_LIFETIME_BOUND; + private: template static Persistent WrapSingletonType() { @@ -104,11 +109,19 @@ class TypeFactory { MemoryManager& memory_manager() const { return memory_manager_; } MemoryManager& memory_manager_; - absl::Mutex mutex_; + + absl::Mutex list_types_mutex_; // Mapping from list element types to the list type. This allows us to cache // list types and avoid re-creating the same type. absl::flat_hash_map, Persistent> - list_types_ ABSL_GUARDED_BY(mutex_); + list_types_ ABSL_GUARDED_BY(list_types_mutex_); + + absl::Mutex map_types_mutex_; + // Mapping from map key and value types to the map type. This allows us to + // cache map types and avoid re-creating the same type. + absl::flat_hash_map, Persistent>, + Persistent> + map_types_ ABSL_GUARDED_BY(map_types_mutex_); }; } // namespace cel diff --git a/base/type_factory_test.cc b/base/type_factory_test.cc index 9ddd2f3c8..1dc80d797 100644 --- a/base/type_factory_test.cc +++ b/base/type_factory_test.cc @@ -30,5 +30,16 @@ TEST(TypeFactory, CreateListTypeCaches) { EXPECT_EQ(list_type_1.operator->(), list_type_2.operator->()); } +TEST(TypeFactory, CreateMapTypeCaches) { + TypeFactory type_factory(MemoryManager::Global()); + ASSERT_OK_AND_ASSIGN(auto map_type_1, + type_factory.CreateMapType(type_factory.GetStringType(), + type_factory.GetBoolType())); + ASSERT_OK_AND_ASSIGN(auto map_type_2, + type_factory.CreateMapType(type_factory.GetStringType(), + type_factory.GetBoolType())); + EXPECT_EQ(map_type_1.operator->(), map_type_2.operator->()); +} + } // namespace } // namespace cel diff --git a/base/type_test.cc b/base/type_test.cc index 4f93e5d27..f679410bc 100644 --- a/base/type_test.cc +++ b/base/type_test.cc @@ -181,6 +181,7 @@ TEST(Type, Null) { EXPECT_FALSE(type_factory.GetNullType().Is()); EXPECT_FALSE(type_factory.GetNullType().Is()); EXPECT_FALSE(type_factory.GetNullType().Is()); + EXPECT_FALSE(type_factory.GetNullType().Is()); } TEST(Type, Error) { @@ -201,6 +202,7 @@ TEST(Type, Error) { EXPECT_FALSE(type_factory.GetErrorType().Is()); EXPECT_FALSE(type_factory.GetErrorType().Is()); EXPECT_FALSE(type_factory.GetErrorType().Is()); + EXPECT_FALSE(type_factory.GetErrorType().Is()); } TEST(Type, Dyn) { @@ -221,6 +223,7 @@ TEST(Type, Dyn) { EXPECT_FALSE(type_factory.GetDynType().Is()); EXPECT_FALSE(type_factory.GetDynType().Is()); EXPECT_FALSE(type_factory.GetDynType().Is()); + EXPECT_FALSE(type_factory.GetDynType().Is()); } TEST(Type, Any) { @@ -241,6 +244,7 @@ TEST(Type, Any) { EXPECT_FALSE(type_factory.GetAnyType().Is()); EXPECT_FALSE(type_factory.GetAnyType().Is()); EXPECT_FALSE(type_factory.GetAnyType().Is()); + EXPECT_FALSE(type_factory.GetAnyType().Is()); } TEST(Type, Bool) { @@ -261,6 +265,7 @@ TEST(Type, Bool) { EXPECT_FALSE(type_factory.GetBoolType().Is()); EXPECT_FALSE(type_factory.GetBoolType().Is()); EXPECT_FALSE(type_factory.GetBoolType().Is()); + EXPECT_FALSE(type_factory.GetBoolType().Is()); } TEST(Type, Int) { @@ -281,6 +286,7 @@ TEST(Type, Int) { EXPECT_FALSE(type_factory.GetIntType().Is()); EXPECT_FALSE(type_factory.GetIntType().Is()); EXPECT_FALSE(type_factory.GetIntType().Is()); + EXPECT_FALSE(type_factory.GetIntType().Is()); } TEST(Type, Uint) { @@ -301,6 +307,7 @@ TEST(Type, Uint) { EXPECT_FALSE(type_factory.GetUintType().Is()); EXPECT_FALSE(type_factory.GetUintType().Is()); EXPECT_FALSE(type_factory.GetUintType().Is()); + EXPECT_FALSE(type_factory.GetUintType().Is()); } TEST(Type, Double) { @@ -321,6 +328,7 @@ TEST(Type, Double) { EXPECT_FALSE(type_factory.GetDoubleType().Is()); EXPECT_FALSE(type_factory.GetDoubleType().Is()); EXPECT_FALSE(type_factory.GetDoubleType().Is()); + EXPECT_FALSE(type_factory.GetDoubleType().Is()); } TEST(Type, String) { @@ -341,6 +349,7 @@ TEST(Type, String) { EXPECT_FALSE(type_factory.GetStringType().Is()); EXPECT_FALSE(type_factory.GetStringType().Is()); EXPECT_FALSE(type_factory.GetStringType().Is()); + EXPECT_FALSE(type_factory.GetStringType().Is()); } TEST(Type, Bytes) { @@ -361,6 +370,7 @@ TEST(Type, Bytes) { EXPECT_FALSE(type_factory.GetBytesType().Is()); EXPECT_FALSE(type_factory.GetBytesType().Is()); EXPECT_FALSE(type_factory.GetBytesType().Is()); + EXPECT_FALSE(type_factory.GetBytesType().Is()); } TEST(Type, Duration) { @@ -381,6 +391,7 @@ TEST(Type, Duration) { EXPECT_FALSE(type_factory.GetDurationType().Is()); EXPECT_FALSE(type_factory.GetDurationType().Is()); EXPECT_FALSE(type_factory.GetDurationType().Is()); + EXPECT_FALSE(type_factory.GetDurationType().Is()); } TEST(Type, Timestamp) { @@ -402,6 +413,7 @@ TEST(Type, Timestamp) { EXPECT_TRUE(type_factory.GetTimestampType().Is()); EXPECT_FALSE(type_factory.GetTimestampType().Is()); EXPECT_FALSE(type_factory.GetTimestampType().Is()); + EXPECT_FALSE(type_factory.GetTimestampType().Is()); } TEST(Type, Enum) { @@ -424,6 +436,7 @@ TEST(Type, Enum) { EXPECT_FALSE(enum_type.Is()); EXPECT_TRUE(enum_type.Is()); EXPECT_FALSE(enum_type.Is()); + EXPECT_FALSE(enum_type.Is()); } TEST(Type, List) { @@ -449,6 +462,39 @@ TEST(Type, List) { EXPECT_FALSE(list_type.Is()); EXPECT_FALSE(list_type.Is()); EXPECT_TRUE(list_type.Is()); + EXPECT_FALSE(list_type.Is()); +} + +TEST(Type, Map) { + TypeFactory type_factory(MemoryManager::Global()); + ASSERT_OK_AND_ASSIGN(auto map_type, + type_factory.CreateMapType(type_factory.GetStringType(), + type_factory.GetBoolType())); + EXPECT_EQ(map_type, + Must(type_factory.CreateMapType(type_factory.GetStringType(), + type_factory.GetBoolType()))); + EXPECT_NE(map_type, + Must(type_factory.CreateMapType(type_factory.GetBoolType(), + type_factory.GetStringType()))); + EXPECT_EQ(map_type->kind(), Kind::kMap); + EXPECT_EQ(map_type->name(), "map"); + EXPECT_EQ(map_type->key(), type_factory.GetStringType()); + EXPECT_EQ(map_type->value(), type_factory.GetBoolType()); + EXPECT_THAT(map_type->parameters(), SizeIs(0)); + EXPECT_FALSE(map_type.Is()); + EXPECT_FALSE(map_type.Is()); + EXPECT_FALSE(map_type.Is()); + EXPECT_FALSE(map_type.Is()); + EXPECT_FALSE(map_type.Is()); + EXPECT_FALSE(map_type.Is()); + EXPECT_FALSE(map_type.Is()); + EXPECT_FALSE(map_type.Is()); + EXPECT_FALSE(map_type.Is()); + EXPECT_FALSE(map_type.Is()); + EXPECT_FALSE(map_type.Is()); + EXPECT_FALSE(map_type.Is()); + EXPECT_FALSE(map_type.Is()); + EXPECT_TRUE(map_type.Is()); } TEST(EnumType, FindConstant) { @@ -558,6 +604,14 @@ TEST(ListType, DebugString) { EXPECT_EQ(list_type->DebugString(), "list"); } +TEST(MapType, DebugString) { + TypeFactory type_factory(MemoryManager::Global()); + ASSERT_OK_AND_ASSIGN(auto map_type, + type_factory.CreateMapType(type_factory.GetStringType(), + type_factory.GetBoolType())); + EXPECT_EQ(map_type->DebugString(), "map"); +} + TEST(Type, SupportsAbslHash) { TypeFactory type_factory(MemoryManager::Global()); EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly({ @@ -576,6 +630,8 @@ TEST(Type, SupportsAbslHash) { Persistent(Must(type_factory.CreateEnumType())), Persistent( Must(type_factory.CreateListType(type_factory.GetBoolType()))), + Persistent(Must(type_factory.CreateMapType( + type_factory.GetStringType(), type_factory.GetBoolType()))), })); } From ae678dfa85b243a53d04afee67ad1d5b8b2f8d4d Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Fri, 25 Mar 2022 20:49:56 +0000 Subject: [PATCH 043/111] Refactors: Remove direct references to descriptor pool and message factory from the core evaluator. Wire type registry into execution frame for runtime type lookups. PiperOrigin-RevId: 437327358 --- eval/compiler/flat_expr_builder.cc | 2 +- eval/eval/BUILD | 28 +++++++++++ eval/eval/comprehension_step_test.cc | 6 +-- eval/eval/const_value_step_test.cc | 6 +-- eval/eval/container_access_step_test.cc | 6 +-- eval/eval/create_list_step_test.cc | 16 +++--- eval/eval/create_struct_step_test.cc | 16 +++--- eval/eval/evaluator_core.cc | 4 +- eval/eval/evaluator_core.h | 25 ++++----- eval/eval/evaluator_core_test.cc | 13 ++--- eval/eval/function_step_test.cc | 67 ++++++++++--------------- eval/eval/ident_step_test.cc | 20 +++----- eval/eval/logic_step_test.cc | 5 +- eval/eval/select_step_test.cc | 35 ++++++------- eval/eval/shadowable_value_step_test.cc | 6 +-- eval/eval/ternary_step_test.cc | 5 +- eval/eval/test_type_registry.cc | 40 +++++++++++++++ eval/eval/test_type_registry.h | 27 ++++++++++ 18 files changed, 189 insertions(+), 138 deletions(-) create mode 100644 eval/eval/test_type_registry.cc create mode 100644 eval/eval/test_type_registry.h diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index b2ed5cbc6..d72d33edb 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -1058,7 +1058,7 @@ FlatExprBuilder::CreateExpressionImpl( std::unique_ptr expression_impl = absl::make_unique( - expr, std::move(execution_path), descriptor_pool_, message_factory_, + expr, std::move(execution_path), GetTypeRegistry(), comprehension_max_iterations_, std::move(iter_variable_names), enable_unknowns_, enable_unknown_function_results_, enable_missing_attribute_errors_, enable_null_coercion_, diff --git a/eval/eval/BUILD b/eval/eval/BUILD index 885b4b86a..3d4d5f44b 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -19,9 +19,11 @@ cc_library( ":attribute_utility", ":evaluator_stack", "//base:memory_manager", + "//eval/compiler:resolver", "//eval/public:base_activation", "//eval/public:cel_attribute", "//eval/public:cel_expression", + "//eval/public:cel_type_registry", "//eval/public:cel_value", "//eval/public:unknown_attribute_set", "//extensions/protobuf:memory_manager", @@ -300,6 +302,7 @@ cc_test( ":comprehension_step", ":evaluator_core", ":ident_step", + ":test_type_registry", "//eval/public:activation", "//eval/public:cel_attribute", "//eval/public:cel_options", @@ -323,6 +326,7 @@ cc_test( deps = [ ":attribute_trail", ":evaluator_core", + ":test_type_registry", "//eval/compiler:flat_expr_builder", "//eval/public:activation", "//eval/public:builtin_func_registrar", @@ -345,6 +349,7 @@ cc_test( deps = [ ":const_value_step", ":evaluator_core", + ":test_type_registry", "//eval/public:activation", "//internal:status_macros", "//internal:testing", @@ -363,6 +368,7 @@ cc_test( deps = [ ":container_access_step", ":ident_step", + ":test_type_registry", "//eval/public:activation", "//eval/public:builtin_func_registrar", "//eval/public:cel_attribute", @@ -393,6 +399,7 @@ cc_test( deps = [ ":evaluator_core", ":ident_step", + ":test_type_registry", "//eval/public:activation", "//internal:status_macros", "//internal:testing", @@ -412,6 +419,7 @@ cc_test( ":expression_build_warning", ":function_step", ":ident_step", + ":test_type_registry", "//eval/public:activation", "//eval/public:cel_attribute", "//eval/public:cel_function", @@ -440,6 +448,7 @@ cc_test( deps = [ ":ident_step", ":logic_step", + ":test_type_registry", "//eval/public:activation", "//eval/public:unknown_attribute_set", "//eval/public:unknown_set", @@ -458,6 +467,7 @@ cc_test( deps = [ ":ident_step", ":select_step", + ":test_type_registry", "//eval/public:activation", "//eval/public:cel_attribute", "//eval/public:unknown_attribute_set", @@ -484,6 +494,7 @@ cc_test( ":const_value_step", ":create_list_step", ":ident_step", + ":test_type_registry", "//eval/public:activation", "//eval/public:cel_attribute", "//eval/public:unknown_attribute_set", @@ -504,6 +515,7 @@ cc_test( deps = [ ":create_struct_step", ":ident_step", + ":test_type_registry", "//eval/public:activation", "//eval/public:cel_type_registry", "//eval/public/containers:container_backed_list_impl", @@ -650,6 +662,7 @@ cc_test( deps = [ ":ident_step", ":ternary_step", + ":test_type_registry", "//eval/public:activation", "//eval/public:unknown_attribute_set", "//eval/public:unknown_set", @@ -686,6 +699,7 @@ cc_test( deps = [ ":evaluator_core", ":shadowable_value_step", + ":test_type_registry", "//eval/public:activation", "//eval/public:cel_value", "//internal:status_macros", @@ -695,3 +709,17 @@ cc_test( "@com_google_protobuf//:protobuf", ], ) + +cc_library( + name = "test_type_registry", + testonly = True, + srcs = ["test_type_registry.cc"], + hdrs = ["test_type_registry.h"], + deps = [ + "//eval/public:cel_type_registry", + "//eval/public/containers:field_access", + "//eval/public/structs:protobuf_descriptor_type_provider", + "//internal:no_destructor", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/eval/eval/comprehension_step_test.cc b/eval/eval/comprehension_step_test.cc index feb7312dc..5ee42109b 100644 --- a/eval/eval/comprehension_step_test.cc +++ b/eval/eval/comprehension_step_test.cc @@ -12,6 +12,7 @@ #include "absl/strings/string_view.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/ident_step.h" +#include "eval/eval/test_type_registry.h" #include "eval/public/activation.h" #include "eval/public/cel_attribute.h" #include "eval/public/cel_options.h" @@ -45,9 +46,8 @@ class ListKeysStepTest : public testing::Test { std::unique_ptr MakeExpression( ExecutionPath&& path, bool unknown_attributes = false) { return std::make_unique( - &dummy_expr_, std::move(path), google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, std::set(), - unknown_attributes, unknown_attributes); + &dummy_expr_, std::move(path), &TestTypeRegistry(), 0, + std::set(), unknown_attributes, unknown_attributes); } private: diff --git a/eval/eval/const_value_step_test.cc b/eval/eval/const_value_step_test.cc index 5251ee185..b5f351309 100644 --- a/eval/eval/const_value_step_test.cc +++ b/eval/eval/const_value_step_test.cc @@ -6,6 +6,7 @@ #include "google/protobuf/descriptor.h" #include "absl/status/statusor.h" #include "eval/eval/evaluator_core.h" +#include "eval/eval/test_type_registry.h" #include "eval/public/activation.h" #include "internal/status_macros.h" #include "internal/testing.h" @@ -33,9 +34,8 @@ absl::StatusOr RunConstantExpression(const Expr* expr, google::api::expr::v1alpha1::Expr dummy_expr; - CelExpressionFlatImpl impl( - &dummy_expr, std::move(path), google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, {}); + CelExpressionFlatImpl impl(&dummy_expr, std::move(path), &TestTypeRegistry(), + 0, {}); Activation activation; diff --git a/eval/eval/container_access_step_test.cc b/eval/eval/container_access_step_test.cc index 5a8c9f2e5..89ce881e2 100644 --- a/eval/eval/container_access_step_test.cc +++ b/eval/eval/container_access_step_test.cc @@ -11,6 +11,7 @@ #include "google/protobuf/descriptor.h" #include "absl/status/status.h" #include "eval/eval/ident_step.h" +#include "eval/eval/test_type_registry.h" #include "eval/public/activation.h" #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_attribute.h" @@ -65,9 +66,8 @@ CelValue EvaluateAttributeHelper( std::move(CreateIdentStep(&key_expr->ident_expr(), 2).value())); path.push_back(std::move(CreateContainerAccessStep(call, 3).value())); - CelExpressionFlatImpl cel_expr( - &expr, std::move(path), google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, {}, enable_unknown); + CelExpressionFlatImpl cel_expr(&expr, std::move(path), &TestTypeRegistry(), 0, + {}, enable_unknown); Activation activation; activation.InsertValue("container", container); diff --git a/eval/eval/create_list_step_test.cc b/eval/eval/create_list_step_test.cc index 8a80268f2..516f68cb1 100644 --- a/eval/eval/create_list_step_test.cc +++ b/eval/eval/create_list_step_test.cc @@ -8,6 +8,7 @@ #include "absl/strings/str_cat.h" #include "eval/eval/const_value_step.h" #include "eval/eval/ident_step.h" +#include "eval/eval/test_type_registry.h" #include "eval/public/activation.h" #include "eval/public/cel_attribute.h" #include "eval/public/unknown_attribute_set.h" @@ -46,9 +47,8 @@ absl::StatusOr RunExpression(const std::vector& values, CreateCreateListStep(create_list, dummy_expr.id())); path.push_back(std::move(step)); - CelExpressionFlatImpl cel_expr( - &dummy_expr, std::move(path), google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, {}, enable_unknowns); + CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), + &TestTypeRegistry(), 0, {}, enable_unknowns); Activation activation; return cel_expr.Evaluate(activation, arena); @@ -80,9 +80,8 @@ absl::StatusOr RunExpressionWithCelValues( CreateCreateListStep(create_list, dummy_expr.id())); path.push_back(std::move(step0)); - CelExpressionFlatImpl cel_expr( - &dummy_expr, std::move(path), google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, {}, enable_unknowns); + CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), + &TestTypeRegistry(), 0, {}, enable_unknowns); return cel_expr.Evaluate(activation, arena); } @@ -103,9 +102,8 @@ TEST(CreateListStepTest, TestCreateListStackUnderflow) { CreateCreateListStep(create_list, dummy_expr.id())); path.push_back(std::move(step0)); - CelExpressionFlatImpl cel_expr( - &dummy_expr, std::move(path), google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, {}); + CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), + &TestTypeRegistry(), 0, {}); Activation activation; google::protobuf::Arena arena; diff --git a/eval/eval/create_struct_step_test.cc b/eval/eval/create_struct_step_test.cc index e62d6a213..85efc2d2f 100644 --- a/eval/eval/create_struct_step_test.cc +++ b/eval/eval/create_struct_step_test.cc @@ -10,6 +10,7 @@ #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "eval/eval/ident_step.h" +#include "eval/eval/test_type_registry.h" #include "eval/public/activation.h" #include "eval/public/cel_type_registry.h" #include "eval/public/containers/container_backed_list_impl.h" @@ -77,9 +78,8 @@ absl::StatusOr RunExpression(absl::string_view field, path.push_back(std::move(step0)); path.push_back(std::move(step1)); - CelExpressionFlatImpl cel_expr( - &expr1, std::move(path), google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, {}, enable_unknowns); + CelExpressionFlatImpl cel_expr(&expr1, std::move(path), &type_registry, 0, {}, + enable_unknowns); Activation activation; activation.InsertValue("message", value); @@ -166,9 +166,8 @@ absl::StatusOr RunCreateMapExpression( CreateCreateStructStep(create_struct, expr1.id())); path.push_back(std::move(step1)); - CelExpressionFlatImpl cel_expr( - &expr1, std::move(path), google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, {}, enable_unknowns); + CelExpressionFlatImpl cel_expr(&expr1, std::move(path), &TestTypeRegistry(), + 0, {}, enable_unknowns); return cel_expr.Evaluate(activation, arena); } @@ -193,9 +192,8 @@ TEST_P(CreateCreateStructStepTest, TestEmptyMessageCreation) { expr1.id())); path.push_back(std::move(step)); - CelExpressionFlatImpl cel_expr( - &expr1, std::move(path), google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, {}, GetParam()); + CelExpressionFlatImpl cel_expr(&expr1, std::move(path), &type_registry, 0, {}, + GetParam()); Activation activation; google::protobuf::Arena arena; diff --git a/eval/eval/evaluator_core.cc b/eval/eval/evaluator_core.cc index febbad54d..27904ce45 100644 --- a/eval/eval/evaluator_core.cc +++ b/eval/eval/evaluator_core.cc @@ -153,8 +153,8 @@ absl::StatusOr CelExpressionFlatImpl::Trace( ::cel::internal::down_cast(_state); state->Reset(); - ExecutionFrame frame(path_, activation, descriptor_pool_, message_factory_, - max_iterations_, state, enable_unknowns_, + ExecutionFrame frame(path_, activation, &type_registry_, max_iterations_, + state, enable_unknowns_, enable_unknown_function_results_, enable_missing_attribute_errors_, enable_null_coercion_, enable_heterogeneous_equality_); diff --git a/eval/eval/evaluator_core.h b/eval/eval/evaluator_core.h index 7f3308c6f..b3f867776 100644 --- a/eval/eval/evaluator_core.h +++ b/eval/eval/evaluator_core.h @@ -23,12 +23,14 @@ #include "absl/types/optional.h" #include "absl/types/span.h" #include "base/memory_manager.h" +#include "eval/compiler/resolver.h" #include "eval/eval/attribute_trail.h" #include "eval/eval/attribute_utility.h" #include "eval/eval/evaluator_stack.h" #include "eval/public/base_activation.h" #include "eval/public/cel_attribute.h" #include "eval/public/cel_expression.h" +#include "eval/public/cel_type_registry.h" #include "eval/public/cel_value.h" #include "eval/public/unknown_attribute_set.h" #include "extensions/protobuf/memory_manager.h" @@ -119,8 +121,7 @@ class ExecutionFrame { // arena serves as allocation manager during the expression evaluation. ExecutionFrame(const ExecutionPath& flat, const BaseActivation& activation, - const google::protobuf::DescriptorPool* descriptor_pool, - google::protobuf::MessageFactory* message_factory, int max_iterations, + const CelTypeRegistry* type_registry, int max_iterations, CelExpressionFlatEvaluationState* state, bool enable_unknowns, bool enable_unknown_function_results, bool enable_missing_attribute_errors, @@ -129,8 +130,7 @@ class ExecutionFrame { : pc_(0UL), execution_path_(flat), activation_(activation), - descriptor_pool_(descriptor_pool), - message_factory_(message_factory), + type_registry_(*type_registry), enable_unknowns_(enable_unknowns), enable_unknown_function_results_(enable_unknown_function_results), enable_missing_attribute_errors_(enable_missing_attribute_errors), @@ -177,10 +177,7 @@ class ExecutionFrame { cel::MemoryManager& memory_manager() { return state_->memory_manager(); } - const google::protobuf::DescriptorPool* descriptor_pool() const { - return descriptor_pool_; - } - google::protobuf::MessageFactory* message_factory() const { return message_factory_; } + const CelTypeRegistry& type_registry() { return type_registry_; } const AttributeUtility& attribute_utility() const { return attribute_utility_; @@ -241,8 +238,7 @@ class ExecutionFrame { size_t pc_; // pc_ - Program Counter. Current position on execution path. const ExecutionPath& execution_path_; const BaseActivation& activation_; - const google::protobuf::DescriptorPool* descriptor_pool_; - google::protobuf::MessageFactory* message_factory_; + const CelTypeRegistry& type_registry_; bool enable_unknowns_; bool enable_unknown_function_results_; bool enable_missing_attribute_errors_; @@ -265,8 +261,7 @@ class CelExpressionFlatImpl : public CelExpression { // bound). CelExpressionFlatImpl(ABSL_ATTRIBUTE_UNUSED const Expr* root_expr, ExecutionPath path, - const google::protobuf::DescriptorPool* descriptor_pool, - google::protobuf::MessageFactory* message_factory, + const CelTypeRegistry* type_registry, int max_iterations, std::set iter_variable_names, bool enable_unknowns = false, @@ -277,8 +272,7 @@ class CelExpressionFlatImpl : public CelExpression { std::unique_ptr rewritten_expr = nullptr) : rewritten_expr_(std::move(rewritten_expr)), path_(std::move(path)), - descriptor_pool_(descriptor_pool), - message_factory_(message_factory), + type_registry_(*type_registry), max_iterations_(max_iterations), iter_variable_names_(std::move(iter_variable_names)), enable_unknowns_(enable_unknowns), @@ -318,8 +312,7 @@ class CelExpressionFlatImpl : public CelExpression { // Maintain lifecycle of a modified expression. std::unique_ptr rewritten_expr_; const ExecutionPath path_; - const google::protobuf::DescriptorPool* descriptor_pool_; - google::protobuf::MessageFactory* message_factory_; + const CelTypeRegistry& type_registry_; const int max_iterations_; const std::set iter_variable_names_; bool enable_unknowns_; diff --git a/eval/eval/evaluator_core_test.cc b/eval/eval/evaluator_core_test.cc index 61728e73a..129ef5785 100644 --- a/eval/eval/evaluator_core_test.cc +++ b/eval/eval/evaluator_core_test.cc @@ -7,6 +7,7 @@ #include "google/protobuf/descriptor.h" #include "eval/compiler/flat_expr_builder.h" #include "eval/eval/attribute_trail.h" +#include "eval/eval/test_type_registry.h" #include "eval/public/activation.h" #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_attribute.h" @@ -69,9 +70,7 @@ TEST(EvaluatorCoreTest, ExecutionFrameNext) { Activation activation; CelExpressionFlatEvaluationState state(path.size(), {}, nullptr); - ExecutionFrame frame(path, activation, - google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, &state, + ExecutionFrame frame(path, activation, &TestTypeRegistry(), 0, &state, /*enable_unknowns=*/false, /*enable_unknown_funcion_results=*/false, /*enable_missing_attribute_errors=*/false, @@ -95,9 +94,7 @@ TEST(EvaluatorCoreTest, ExecutionFrameSetGetClearVar) { ProtoMemoryManager manager(&arena); ExecutionPath path; CelExpressionFlatEvaluationState state(path.size(), {test_iter_var}, nullptr); - ExecutionFrame frame(path, activation, - google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, &state, + ExecutionFrame frame(path, activation, &TestTypeRegistry(), 0, &state, /*enable_unknowns=*/false, /*enable_unknown_funcion_results=*/false, /*enable_missing_attribute_errors=*/false, @@ -168,9 +165,7 @@ TEST(EvaluatorCoreTest, SimpleEvaluatorTest) { auto dummy_expr = absl::make_unique(); CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), - google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, - {}); + &TestTypeRegistry(), 0, {}); Activation activation; google::protobuf::Arena arena; diff --git a/eval/eval/function_step_test.cc b/eval/eval/function_step_test.cc index d64020434..690ce82cd 100644 --- a/eval/eval/function_step_test.cc +++ b/eval/eval/function_step_test.cc @@ -12,6 +12,7 @@ #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_build_warning.h" #include "eval/eval/ident_step.h" +#include "eval/eval/test_type_registry.h" #include "eval/public/activation.h" #include "eval/public/cel_attribute.h" #include "eval/public/cel_function.h" @@ -228,9 +229,8 @@ class FunctionStepTest break; } return absl::make_unique( - &dummy_expr_, std::move(path), google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, std::set(), - unknowns, unknown_function_results); + &dummy_expr_, std::move(path), &TestTypeRegistry(), 0, + std::set(), unknowns, unknown_function_results); } private: @@ -483,9 +483,8 @@ class FunctionStepTestUnknowns break; } return absl::make_unique( - &expr_, std::move(path), google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, std::set(), - true, unknown_functions); + &expr_, std::move(path), &TestTypeRegistry(), 0, + std::set(), true, unknown_functions); } private: @@ -634,9 +633,8 @@ TEST(FunctionStepTestUnknownFunctionResults, CaptureArgs) { Expr dummy_expr; - CelExpressionFlatImpl impl( - &dummy_expr, std::move(path), google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, {}, true, true); + CelExpressionFlatImpl impl(&dummy_expr, std::move(path), &TestTypeRegistry(), + 0, {}, true, true); Activation activation; google::protobuf::Arena arena; @@ -685,9 +683,8 @@ TEST(FunctionStepTestUnknownFunctionResults, MergeDownCaptureArgs) { Expr dummy_expr; - CelExpressionFlatImpl impl( - &dummy_expr, std::move(path), google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, {}, true, true); + CelExpressionFlatImpl impl(&dummy_expr, std::move(path), &TestTypeRegistry(), + 0, {}, true, true); Activation activation; google::protobuf::Arena arena; @@ -736,9 +733,8 @@ TEST(FunctionStepTestUnknownFunctionResults, MergeCaptureArgs) { Expr dummy_expr; - CelExpressionFlatImpl impl( - &dummy_expr, std::move(path), google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, {}, true, true); + CelExpressionFlatImpl impl(&dummy_expr, std::move(path), &TestTypeRegistry(), + 0, {}, true, true); Activation activation; google::protobuf::Arena arena; @@ -782,9 +778,8 @@ TEST(FunctionStepTestUnknownFunctionResults, UnknownVsErrorPrecedenceTest) { Expr dummy_expr; - CelExpressionFlatImpl impl( - &dummy_expr, std::move(path), google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, {}, true, true); + CelExpressionFlatImpl impl(&dummy_expr, std::move(path), &TestTypeRegistry(), + 0, {}, true, true); Activation activation; google::protobuf::Arena arena; @@ -884,10 +879,9 @@ TEST_F(FunctionStepNullCoercionTest, EnabledSupportsMessageOverloads) { path.push_back(std::move(call_step)); - CelExpressionFlatImpl impl( - &dummy_expr_, std::move(path), google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, {}, true, true, true, - /*enable_null_coercion=*/true); + CelExpressionFlatImpl impl(&dummy_expr_, std::move(path), &TestTypeRegistry(), + 0, {}, true, true, true, + /*enable_null_coercion=*/true); ASSERT_OK_AND_ASSIGN(CelValue value, impl.Evaluate(activation_, &arena_)); ASSERT_TRUE(value.IsString()); @@ -909,10 +903,9 @@ TEST_F(FunctionStepNullCoercionTest, EnabledPrefersNullOverloads) { path.push_back(std::move(call_step)); - CelExpressionFlatImpl impl( - &dummy_expr_, std::move(path), google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, {}, true, true, true, - /*enable_null_coercion=*/true); + CelExpressionFlatImpl impl(&dummy_expr_, std::move(path), &TestTypeRegistry(), + 0, {}, true, true, true, + /*enable_null_coercion=*/true); ASSERT_OK_AND_ASSIGN(CelValue value, impl.Evaluate(activation_, &arena_)); ASSERT_TRUE(value.IsString()); @@ -933,10 +926,9 @@ TEST_F(FunctionStepNullCoercionTest, EnabledNullMessageDoesNotEscape) { path.push_back(std::move(call_step)); - CelExpressionFlatImpl impl( - &dummy_expr_, std::move(path), google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, {}, true, true, true, - /*enable_null_coercion=*/true); + CelExpressionFlatImpl impl(&dummy_expr_, std::move(path), &TestTypeRegistry(), + 0, {}, true, true, true, + /*enable_null_coercion=*/true); ASSERT_OK_AND_ASSIGN(CelValue value, impl.Evaluate(activation_, &arena_)); ASSERT_TRUE(value.IsNull()); @@ -957,10 +949,9 @@ TEST_F(FunctionStepNullCoercionTest, Disabled) { path.push_back(std::move(call_step)); - CelExpressionFlatImpl impl( - &dummy_expr_, std::move(path), google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, {}, true, true, true, - /*enable_null_coercion=*/false); + CelExpressionFlatImpl impl(&dummy_expr_, std::move(path), &TestTypeRegistry(), + 0, {}, true, true, true, + /*enable_null_coercion=*/false); ASSERT_OK_AND_ASSIGN(CelValue value, impl.Evaluate(activation_, &arena_)); ASSERT_TRUE(value.IsError()); @@ -985,9 +976,7 @@ TEST(FunctionStepStrictnessTest, path.push_back(std::move(step1)); Expr placeholder_expr; CelExpressionFlatImpl impl(&placeholder_expr, std::move(path), - google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, {}, - true, true); + &TestTypeRegistry(), 0, {}, true, true); Activation activation; google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(CelValue value, impl.Evaluate(activation, &arena)); @@ -1012,9 +1001,7 @@ TEST(FunctionStepStrictnessTest, IfFunctionNonStrictAndGivenUnknownInvokesIt) { path.push_back(std::move(step1)); Expr placeholder_expr; CelExpressionFlatImpl impl(&placeholder_expr, std::move(path), - google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, {}, - true, true); + &TestTypeRegistry(), 0, {}, true, true); Activation activation; google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(CelValue value, impl.Evaluate(activation, &arena)); diff --git a/eval/eval/ident_step_test.cc b/eval/eval/ident_step_test.cc index 5bbd692ef..ee2438a17 100644 --- a/eval/eval/ident_step_test.cc +++ b/eval/eval/ident_step_test.cc @@ -6,6 +6,7 @@ #include "google/api/expr/v1alpha1/syntax.pb.h" #include "google/protobuf/descriptor.h" #include "eval/eval/evaluator_core.h" +#include "eval/eval/test_type_registry.h" #include "eval/public/activation.h" #include "internal/status_macros.h" #include "internal/testing.h" @@ -32,9 +33,7 @@ TEST(IdentStepTest, TestIdentStep) { auto dummy_expr = absl::make_unique(); CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), - google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, - {}); + &TestTypeRegistry(), 0, {}); Activation activation; Arena arena; @@ -63,9 +62,7 @@ TEST(IdentStepTest, TestIdentStepNameNotFound) { auto dummy_expr = absl::make_unique(); CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), - google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, - {}); + &TestTypeRegistry(), 0, {}); Activation activation; Arena arena; @@ -91,8 +88,7 @@ TEST(IdentStepTest, DisableMissingAttributeErrorsOK) { auto dummy_expr = absl::make_unique(); CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), - google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, {}, + &TestTypeRegistry(), 0, {}, /*enable_unknowns=*/false); Activation activation; @@ -130,9 +126,7 @@ TEST(IdentStepTest, TestIdentStepMissingAttributeErrors) { auto dummy_expr = absl::make_unique(); CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), - google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, {}, - false, false, + &TestTypeRegistry(), 0, {}, false, false, /*enable_missing_attribute_errors=*/true); Activation activation; @@ -172,9 +166,7 @@ TEST(IdentStepTest, TestIdentStepUnknownAttribute) { // Expression with unknowns enabled. CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), - google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, {}, - true); + &TestTypeRegistry(), 0, {}, true); Activation activation; Arena arena; diff --git a/eval/eval/logic_step_test.cc b/eval/eval/logic_step_test.cc index 1300360ed..7584a4219 100644 --- a/eval/eval/logic_step_test.cc +++ b/eval/eval/logic_step_test.cc @@ -4,6 +4,7 @@ #include "google/protobuf/descriptor.h" #include "eval/eval/ident_step.h" +#include "eval/eval/test_type_registry.h" #include "eval/public/activation.h" #include "eval/public/unknown_attribute_set.h" #include "eval/public/unknown_set.h" @@ -42,9 +43,7 @@ class LogicStepTest : public testing::TestWithParam { auto dummy_expr = absl::make_unique(); CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), - google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, - {}, enable_unknown); + &TestTypeRegistry(), 0, {}, enable_unknown); Activation activation; activation.InsertValue("name0", arg0); diff --git a/eval/eval/select_step_test.cc b/eval/eval/select_step_test.cc index 8b3ec5452..5b1fab4ff 100644 --- a/eval/eval/select_step_test.cc +++ b/eval/eval/select_step_test.cc @@ -9,6 +9,7 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "eval/eval/ident_step.h" +#include "eval/eval/test_type_registry.h" #include "eval/public/activation.h" #include "eval/public/cel_attribute.h" #include "eval/public/containers/container_backed_map_impl.h" @@ -60,9 +61,8 @@ absl::StatusOr RunExpression(const CelValue target, path.push_back(std::move(step1)); CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), - google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, - {}, options.enable_unknowns); + &TestTypeRegistry(), 0, {}, + options.enable_unknowns); Activation activation; activation.InsertValue("target", target); @@ -207,9 +207,8 @@ TEST(SelectStepTest, MapPresenseIsErrorTest) { path.push_back(std::move(step0)); path.push_back(std::move(step1)); path.push_back(std::move(step2)); - CelExpressionFlatImpl cel_expr( - &select_expr, std::move(path), google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, {}, false); + CelExpressionFlatImpl cel_expr(&select_expr, std::move(path), + &TestTypeRegistry(), 0, {}, false); Activation activation; activation.InsertValue("target", CelProtoWrapper::CreateMessage(&message, &arena)); @@ -513,9 +512,8 @@ TEST_P(SelectStepTest, CelErrorAsArgument) { google::protobuf::Arena arena; bool enable_unknowns = GetParam(); - CelExpressionFlatImpl cel_expr( - &dummy_expr, std::move(path), google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, {}, enable_unknowns); + CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), + &TestTypeRegistry(), 0, {}, enable_unknowns); Activation activation; activation.InsertValue("message", CelValue::CreateError(&error)); @@ -548,10 +546,9 @@ TEST(SelectStepTest, DisableMissingAttributeOK) { path.push_back(std::move(step0)); path.push_back(std::move(step1)); - CelExpressionFlatImpl cel_expr( - &dummy_expr, std::move(path), google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, {}, - /*enable_unknowns=*/false); + CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), + &TestTypeRegistry(), 0, {}, + /*enable_unknowns=*/false); Activation activation; activation.InsertValue("message", CelProtoWrapper::CreateMessage(&message, &arena)); @@ -591,10 +588,9 @@ TEST(SelectStepTest, UnrecoverableUnknownValueProducesError) { path.push_back(std::move(step0)); path.push_back(std::move(step1)); - CelExpressionFlatImpl cel_expr( - &dummy_expr, std::move(path), google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, {}, false, false, - /*enable_missing_attribute_errors=*/true); + CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), + &TestTypeRegistry(), 0, {}, false, false, + /*enable_missing_attribute_errors=*/true); Activation activation; activation.InsertValue("message", CelProtoWrapper::CreateMessage(&message, &arena)); @@ -640,9 +636,8 @@ TEST(SelectStepTest, UnknownPatternResolvesToUnknown) { path.push_back(*std::move(step0_status)); path.push_back(*std::move(step1_status)); - CelExpressionFlatImpl cel_expr( - &dummy_expr, std::move(path), google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, {}, true); + CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), + &TestTypeRegistry(), 0, {}, true); { std::vector unknown_patterns; diff --git a/eval/eval/shadowable_value_step_test.cc b/eval/eval/shadowable_value_step_test.cc index e4de0d03e..f90e8add6 100644 --- a/eval/eval/shadowable_value_step_test.cc +++ b/eval/eval/shadowable_value_step_test.cc @@ -7,6 +7,7 @@ #include "google/protobuf/descriptor.h" #include "absl/status/statusor.h" #include "eval/eval/evaluator_core.h" +#include "eval/eval/test_type_registry.h" #include "eval/public/activation.h" #include "eval/public/cel_value.h" #include "internal/status_macros.h" @@ -29,9 +30,8 @@ absl::StatusOr RunShadowableExpression(const std::string& identifier, path.push_back(std::move(step)); google::api::expr::v1alpha1::Expr dummy_expr; - CelExpressionFlatImpl impl( - &dummy_expr, std::move(path), google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, {}); + CelExpressionFlatImpl impl(&dummy_expr, std::move(path), &TestTypeRegistry(), + 0, {}); return impl.Evaluate(activation, arena); } diff --git a/eval/eval/ternary_step_test.cc b/eval/eval/ternary_step_test.cc index 10d57df61..b89512d7c 100644 --- a/eval/eval/ternary_step_test.cc +++ b/eval/eval/ternary_step_test.cc @@ -5,6 +5,7 @@ #include "google/protobuf/descriptor.h" #include "eval/eval/ident_step.h" +#include "eval/eval/test_type_registry.h" #include "eval/public/activation.h" #include "eval/public/unknown_attribute_set.h" #include "eval/public/unknown_set.h" @@ -55,9 +56,7 @@ class LogicStepTest : public testing::TestWithParam { auto dummy_expr = absl::make_unique(); CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), - google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory(), 0, - {}, enable_unknown); + &TestTypeRegistry(), 0, {}, enable_unknown); Activation activation; std::string value("test"); diff --git a/eval/eval/test_type_registry.cc b/eval/eval/test_type_registry.cc new file mode 100644 index 000000000..baa175ae3 --- /dev/null +++ b/eval/eval/test_type_registry.cc @@ -0,0 +1,40 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/eval/test_type_registry.h" + +#include + +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "eval/public/cel_type_registry.h" +#include "eval/public/containers/field_access.h" +#include "eval/public/structs/protobuf_descriptor_type_provider.h" +#include "internal/no_destructor.h" + +namespace google::api::expr::runtime { + +const CelTypeRegistry& TestTypeRegistry() { + static CelTypeRegistry* registry = ([]() { + auto registry = std::make_unique(); + registry->RegisterTypeProvider(std::make_unique( + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory())); + return registry.release(); + }()); + + return *registry; +} + +} // namespace google::api::expr::runtime diff --git a/eval/eval/test_type_registry.h b/eval/eval/test_type_registry.h new file mode 100644 index 000000000..cdf81cffd --- /dev/null +++ b/eval/eval/test_type_registry.h @@ -0,0 +1,27 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_TEST_TYPE_REGISTRY_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_EVAL_TEST_TYPE_REGISTRY_H_ + +#include "eval/public/cel_type_registry.h" +namespace google::api::expr::runtime { + +// Returns a static singleton type registry suitable for use in most +// tests directly creating CelExpressionFlatImpl instances. +const CelTypeRegistry& TestTypeRegistry(); + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_TEST_TYPE_REGISTRY_H_ From 3e50dcd9683dc071a6e1a7df4ae69e1ad27db340 Mon Sep 17 00:00:00 2001 From: jcking Date: Fri, 25 Mar 2022 21:10:56 +0000 Subject: [PATCH 044/111] Internal change PiperOrigin-RevId: 437331917 --- base/BUILD | 2 ++ base/internal/BUILD | 2 ++ base/internal/type.post.h | 5 +++ base/internal/type.pre.h | 13 ++++++-- base/internal/value.post.h | 4 +++ base/internal/value.pre.h | 13 ++++++-- base/type.h | 62 ++++++++++++++++++++++++++------------ base/type_factory.h | 14 +++------ base/type_test.cc | 1 + base/value.h | 62 ++++++++++++++++++++++++++------------ base/value_factory.h | 14 +++------ base/value_test.cc | 2 ++ 12 files changed, 130 insertions(+), 64 deletions(-) diff --git a/base/BUILD b/base/BUILD index e5fc06487..cde3192af 100644 --- a/base/BUILD +++ b/base/BUILD @@ -115,6 +115,7 @@ cc_library( "//base/internal:type", "//internal:casts", "//internal:no_destructor", + "//internal:rtti", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/hash", @@ -162,6 +163,7 @@ cc_library( "//base/internal:value", "//internal:casts", "//internal:no_destructor", + "//internal:rtti", "//internal:status_macros", "//internal:strings", "//internal:time", diff --git a/base/internal/BUILD b/base/internal/BUILD index 2e13eb5e0..33ebe7ea3 100644 --- a/base/internal/BUILD +++ b/base/internal/BUILD @@ -56,6 +56,7 @@ cc_library( ], deps = [ "//base:handle", + "//internal:rtti", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/hash", "@com_google_absl//absl/numeric:bits", @@ -71,6 +72,7 @@ cc_library( deps = [ "//base:handle", "//internal:casts", + "//internal:rtti", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/hash", "@com_google_absl//absl/numeric:bits", diff --git a/base/internal/type.post.h b/base/internal/type.post.h index 1ccc4d30d..4081dadff 100644 --- a/base/internal/type.post.h +++ b/base/internal/type.post.h @@ -27,11 +27,16 @@ #include "absl/hash/hash.h" #include "absl/numeric/bits.h" #include "base/handle.h" +#include "internal/rtti.h" namespace cel { namespace base_internal { +inline internal::TypeInfo GetEnumTypeTypeId(const EnumType& enum_type) { + return enum_type.TypeId(); +} + // Base implementation of persistent and transient handles for types. This // contains implementation details shared among both, but is never used // directly. The derived classes are responsible for defining appropriate diff --git a/base/internal/type.pre.h b/base/internal/type.pre.h index f8a9029e6..b7eda7950 100644 --- a/base/internal/type.pre.h +++ b/base/internal/type.pre.h @@ -20,8 +20,13 @@ #include #include "base/handle.h" +#include "internal/rtti.h" -namespace cel::base_internal { +namespace cel { + +class EnumType; + +namespace base_internal { class TypeHandleBase; template @@ -42,6 +47,10 @@ inline constexpr uintptr_t kTypeHandleMask = ~kTypeHandleBits; class ListTypeImpl; class MapTypeImpl; -} // namespace cel::base_internal +internal::TypeInfo GetEnumTypeTypeId(const EnumType& enum_type); + +} // namespace base_internal + +} // namespace cel #endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_TYPE_PRE_H_ diff --git a/base/internal/value.post.h b/base/internal/value.post.h index bc7dfe899..197fb6c11 100644 --- a/base/internal/value.post.h +++ b/base/internal/value.post.h @@ -37,6 +37,10 @@ namespace cel { namespace base_internal { +inline internal::TypeInfo GetEnumValueTypeId(const EnumValue& enum_value) { + return enum_value.TypeId(); +} + // Implementation of BytesValue that is stored inlined within a handle. Since // absl::Cord is reference counted itself, this is more efficient than storing // this on the heap. diff --git a/base/internal/value.pre.h b/base/internal/value.pre.h index 837e2f9d5..ce8ec888c 100644 --- a/base/internal/value.pre.h +++ b/base/internal/value.pre.h @@ -22,8 +22,13 @@ #include #include "base/handle.h" +#include "internal/rtti.h" -namespace cel::base_internal { +namespace cel { + +class EnumValue; + +namespace base_internal { class ValueHandleBase; template @@ -41,6 +46,8 @@ inline constexpr uintptr_t kValueHandleBits = kValueHandleManaged | kValueHandleUnmanaged; inline constexpr uintptr_t kValueHandleMask = ~kValueHandleBits; +internal::TypeInfo GetEnumValueTypeId(const EnumValue& enum_value); + class InlinedCordBytesValue; class InlinedStringViewBytesValue; class StringBytesValue; @@ -157,6 +164,8 @@ struct ExternalData final { std::unique_ptr releaser; }; -} // namespace cel::base_internal +} // namespace base_internal + +} // namespace cel #endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_VALUE_PRE_H_ diff --git a/base/type.h b/base/type.h index 57bc7fc2e..ff42c8739 100644 --- a/base/type.h +++ b/base/type.h @@ -29,6 +29,8 @@ #include "base/internal/type.pre.h" // IWYU pragma: export #include "base/kind.h" #include "base/memory_manager.h" +#include "internal/casts.h" +#include "internal/rtti.h" namespace cel { @@ -481,6 +483,8 @@ class EnumType : public Type { int64_t number) const = 0; private: + friend internal::TypeInfo base_internal::GetEnumTypeTypeId( + const EnumType& enum_type); struct NewInstanceVisitor; struct FindConstantVisitor; @@ -498,6 +502,9 @@ class EnumType : public Type { EnumType(EnumType&&) = delete; std::pair SizeAndAlignment() const override = 0; + + // Called by CEL_IMPLEMENT_ENUM_TYPE() and Is() to perform type checking. + virtual internal::TypeInfo TypeId() const = 0; }; // CEL_DECLARE_ENUM_TYPE declares `enum_type` as an enumeration type. It must be @@ -508,11 +515,15 @@ class EnumType : public Type { // private: // CEL_DECLARE_ENUM_TYPE(MyEnumType); // }; -#define CEL_DECLARE_ENUM_TYPE(enum_type) \ - private: \ - friend class ::cel::base_internal::TypeHandleBase; \ - \ - ::std::pair<::std::size_t, ::std::size_t> SizeAndAlignment() const override; +#define CEL_DECLARE_ENUM_TYPE(enum_type) \ + private: \ + friend class ::cel::base_internal::TypeHandleBase; \ + \ + static bool Is(const ::cel::Type& type); \ + \ + ::std::pair<::std::size_t, ::std::size_t> SizeAndAlignment() const override; \ + \ + ::cel::internal::TypeInfo TypeId() const override; // CEL_IMPLEMENT_ENUM_TYPE implements `enum_type` as an enumeration type. It // must be called after the class definition of `enum_type`. @@ -524,21 +535,32 @@ class EnumType : public Type { // }; // // CEL_IMPLEMENT_ENUM_TYPE(MyEnumType); -#define CEL_IMPLEMENT_ENUM_TYPE(enum_type) \ - static_assert(::std::is_base_of_v<::cel::EnumType, enum_type>, \ - #enum_type " must inherit from cel::EnumType"); \ - static_assert(!::std::is_abstract_v, \ - "this must not be abstract"); \ - \ - ::std::pair<::std::size_t, ::std::size_t> enum_type::SizeAndAlignment() \ - const { \ - static_assert( \ - ::std::is_same_v>>, \ - "this must be the same as " #enum_type); \ - return ::std::pair<::std::size_t, ::std::size_t>(sizeof(enum_type), \ - alignof(enum_type)); \ +#define CEL_IMPLEMENT_ENUM_TYPE(enum_type) \ + static_assert(::std::is_base_of_v<::cel::EnumType, enum_type>, \ + #enum_type " must inherit from cel::EnumType"); \ + static_assert(!::std::is_abstract_v, \ + "this must not be abstract"); \ + \ + bool enum_type::Is(const ::cel::Type& type) { \ + return type.kind() == ::cel::Kind::kEnum && \ + ::cel::base_internal::GetEnumTypeTypeId( \ + ::cel::internal::down_cast(type)) == \ + ::cel::internal::TypeId(); \ + } \ + \ + ::std::pair<::std::size_t, ::std::size_t> enum_type::SizeAndAlignment() \ + const { \ + static_assert( \ + ::std::is_same_v>>, \ + "this must be the same as " #enum_type); \ + return ::std::pair<::std::size_t, ::std::size_t>(sizeof(enum_type), \ + alignof(enum_type)); \ + } \ + \ + ::cel::internal::TypeInfo enum_type::TypeId() const { \ + return ::cel::internal::TypeId(); \ } // ListType represents a list type. A list is a sequential container where each diff --git a/base/type_factory.h b/base/type_factory.h index 83014eaad..268b993e8 100644 --- a/base/type_factory.h +++ b/base/type_factory.h @@ -35,9 +35,6 @@ namespace cel { // forbidden outside of the CEL codebase. class TypeFactory { private: - template - using PropagateConstT = std::conditional_t, const U, U>; - template using EnableIfBaseOfT = std::enable_if_t>, V>; @@ -79,13 +76,10 @@ class TypeFactory { ABSL_ATTRIBUTE_LIFETIME_BOUND; template - EnableIfBaseOfT>>> - CreateEnumType(Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { - return base_internal::PersistentHandleFactory>::template Make>(memory_manager(), - std::forward( - args)...); + EnableIfBaseOfT>> CreateEnumType( + Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { + return base_internal::PersistentHandleFactory::template Make< + std::remove_const_t>(memory_manager(), std::forward(args)...); } absl::StatusOr> CreateListType( diff --git a/base/type_test.cc b/base/type_test.cc index f679410bc..4ea75f412 100644 --- a/base/type_test.cc +++ b/base/type_test.cc @@ -435,6 +435,7 @@ TEST(Type, Enum) { EXPECT_FALSE(enum_type.Is()); EXPECT_FALSE(enum_type.Is()); EXPECT_TRUE(enum_type.Is()); + EXPECT_TRUE(enum_type.Is()); EXPECT_FALSE(enum_type.Is()); EXPECT_FALSE(enum_type.Is()); } diff --git a/base/value.h b/base/value.h index 60072e20a..41cff1f09 100644 --- a/base/value.h +++ b/base/value.h @@ -34,6 +34,8 @@ #include "base/kind.h" #include "base/memory_manager.h" #include "base/type.h" +#include "internal/casts.h" +#include "internal/rtti.h" namespace cel { @@ -595,6 +597,8 @@ class EnumValue : public Value { EnumValue() = default; private: + friend internal::TypeInfo base_internal::GetEnumValueTypeId( + const EnumValue& enum_value); template friend class base_internal::ValueHandle; friend class base_internal::ValueHandleBase; @@ -611,6 +615,9 @@ class EnumValue : public Value { std::pair SizeAndAlignment() const override = 0; + // Called by CEL_IMPLEMENT_ENUM_VALUE() and Is() to perform type checking. + virtual internal::TypeInfo TypeId() const = 0; + // Set lazily, by EnumValue::New. Persistent type_; }; @@ -623,11 +630,15 @@ class EnumValue : public Value { // private: // CEL_DECLARE_ENUM_VALUE(MyEnumValue); // }; -#define CEL_DECLARE_ENUM_VALUE(enum_value) \ - private: \ - friend class ::cel::base_internal::ValueHandleBase; \ - \ - ::std::pair<::std::size_t, ::std::size_t> SizeAndAlignment() const override; +#define CEL_DECLARE_ENUM_VALUE(enum_value) \ + private: \ + friend class ::cel::base_internal::ValueHandleBase; \ + \ + static bool Is(const ::cel::Value& value); \ + \ + ::std::pair<::std::size_t, ::std::size_t> SizeAndAlignment() const override; \ + \ + ::cel::internal::TypeInfo TypeId() const override; // CEL_IMPLEMENT_ENUM_VALUE implements `enum_value` as an enumeration value. It // must be called after the class definition of `enum_value`. @@ -639,21 +650,32 @@ class EnumValue : public Value { // }; // // CEL_IMPLEMENT_ENUM_VALUE(MyEnumValue); -#define CEL_IMPLEMENT_ENUM_VALUE(enum_value) \ - static_assert(::std::is_base_of_v<::cel::EnumValue, enum_value>, \ - #enum_value " must inherit from cel::EnumValue"); \ - static_assert(!::std::is_abstract_v, \ - "this must not be abstract"); \ - \ - ::std::pair<::std::size_t, ::std::size_t> enum_value::SizeAndAlignment() \ - const { \ - static_assert( \ - ::std::is_same_v>>, \ - "this must be the same as " #enum_value); \ - return ::std::pair<::std::size_t, ::std::size_t>(sizeof(enum_value), \ - alignof(enum_value)); \ +#define CEL_IMPLEMENT_ENUM_VALUE(enum_value) \ + static_assert(::std::is_base_of_v<::cel::EnumValue, enum_value>, \ + #enum_value " must inherit from cel::EnumValue"); \ + static_assert(!::std::is_abstract_v, \ + "this must not be abstract"); \ + \ + bool enum_value::Is(const ::cel::Value& value) { \ + return value.kind() == ::cel::Kind::kEnum && \ + ::cel::base_internal::GetEnumValueTypeId( \ + ::cel::internal::down_cast(value)) == \ + ::cel::internal::TypeId(); \ + } \ + \ + ::std::pair<::std::size_t, ::std::size_t> enum_value::SizeAndAlignment() \ + const { \ + static_assert( \ + ::std::is_same_v>>, \ + "this must be the same as " #enum_value); \ + return ::std::pair<::std::size_t, ::std::size_t>(sizeof(enum_value), \ + alignof(enum_value)); \ + } \ + \ + ::cel::internal::TypeInfo enum_value::TypeId() const { \ + return ::cel::internal::TypeId(); \ } } // namespace cel diff --git a/base/value_factory.h b/base/value_factory.h index 22d2d27f6..e5321b733 100644 --- a/base/value_factory.h +++ b/base/value_factory.h @@ -35,9 +35,6 @@ namespace cel { class ValueFactory final { private: - template - using PropagateConstT = std::conditional_t, const U, U>; - template using EnableIfBaseOfT = std::enable_if_t>, V>; @@ -142,13 +139,10 @@ class ValueFactory final { absl::Time value) ABSL_ATTRIBUTE_LIFETIME_BOUND; template - EnableIfBaseOfT>>> - CreateEnumValue(Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { - return base_internal:: - PersistentHandleFactory>::template Make< - std::remove_const_t>(memory_manager(), - std::forward(args)...); + EnableIfBaseOfT>> CreateEnumValue( + Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { + return base_internal::PersistentHandleFactory::template Make< + std::remove_const_t>(memory_manager(), std::forward(args)...); } private: diff --git a/base/value_test.cc b/base/value_test.cc index 8b69644ef..75a991544 100644 --- a/base/value_test.cc +++ b/base/value_test.cc @@ -1452,6 +1452,7 @@ TEST(Value, Enum) { auto one_value, EnumValue::New(enum_type, value_factory, EnumType::ConstantId("VALUE1"))); EXPECT_TRUE(one_value.Is()); + EXPECT_TRUE(one_value.Is()); EXPECT_FALSE(one_value.Is()); EXPECT_EQ(one_value, one_value); EXPECT_EQ(one_value, Must(EnumValue::New(enum_type, value_factory, @@ -1465,6 +1466,7 @@ TEST(Value, Enum) { auto two_value, EnumValue::New(enum_type, value_factory, EnumType::ConstantId("VALUE2"))); EXPECT_TRUE(two_value.Is()); + EXPECT_TRUE(two_value.Is()); EXPECT_FALSE(two_value.Is()); EXPECT_EQ(two_value, two_value); EXPECT_EQ(two_value->kind(), Kind::kEnum); From b06fa1cbb5840b7c5cffac9395867496df5fe296 Mon Sep 17 00:00:00 2001 From: jcking Date: Mon, 28 Mar 2022 15:48:15 +0000 Subject: [PATCH 045/111] Internal change PiperOrigin-RevId: 437768872 --- base/BUILD | 1 + base/internal/type.post.h | 5 ++ base/internal/type.pre.h | 3 + base/type.cc | 19 +++++ base/type.h | 140 ++++++++++++++++++++++++++++++++ base/type_factory.h | 7 ++ base/type_manager.h | 5 +- base/type_test.cc | 165 ++++++++++++++++++++++++++++++++++++++ 8 files changed, 344 insertions(+), 1 deletion(-) diff --git a/base/BUILD b/base/BUILD index cde3192af..e56c465ca 100644 --- a/base/BUILD +++ b/base/BUILD @@ -140,6 +140,7 @@ cc_test( ":type", ":value", "//internal:testing", + "@com_google_absl//absl/hash", "@com_google_absl//absl/hash:hash_testing", "@com_google_absl//absl/status", ], diff --git a/base/internal/type.post.h b/base/internal/type.post.h index 4081dadff..35111acc9 100644 --- a/base/internal/type.post.h +++ b/base/internal/type.post.h @@ -37,6 +37,10 @@ inline internal::TypeInfo GetEnumTypeTypeId(const EnumType& enum_type) { return enum_type.TypeId(); } +inline internal::TypeInfo GetStructTypeTypeId(const StructType& struct_type) { + return struct_type.TypeId(); +} + // Base implementation of persistent and transient handles for types. This // contains implementation details shared among both, but is never used // directly. The derived classes are responsible for defining appropriate @@ -270,6 +274,7 @@ CEL_INTERNAL_TYPE_DECL(StringType); CEL_INTERNAL_TYPE_DECL(DurationType); CEL_INTERNAL_TYPE_DECL(TimestampType); CEL_INTERNAL_TYPE_DECL(EnumType); +CEL_INTERNAL_TYPE_DECL(StructType); CEL_INTERNAL_TYPE_DECL(ListType); CEL_INTERNAL_TYPE_DECL(MapType); #undef CEL_INTERNAL_TYPE_DECL diff --git a/base/internal/type.pre.h b/base/internal/type.pre.h index b7eda7950..2886ac5fc 100644 --- a/base/internal/type.pre.h +++ b/base/internal/type.pre.h @@ -25,6 +25,7 @@ namespace cel { class EnumType; +class StructType; namespace base_internal { @@ -49,6 +50,8 @@ class MapTypeImpl; internal::TypeInfo GetEnumTypeTypeId(const EnumType& enum_type); +internal::TypeInfo GetStructTypeTypeId(const StructType& struct_type); + } // namespace base_internal } // namespace cel diff --git a/base/type.cc b/base/type.cc index 9773ace4a..fca46d53a 100644 --- a/base/type.cc +++ b/base/type.cc @@ -20,6 +20,7 @@ #include "absl/types/span.h" #include "absl/types/variant.h" #include "base/handle.h" +#include "base/type_manager.h" #include "internal/casts.h" #include "internal/no_destructor.h" @@ -141,6 +142,24 @@ absl::StatusOr EnumType::FindConstant(ConstantId id) const { return absl::visit(FindConstantVisitor{*this}, id.data_); } +struct StructType::FindFieldVisitor final { + const StructType& struct_type; + TypeManager& type_manager; + + absl::StatusOr operator()(absl::string_view name) const { + return struct_type.FindFieldByName(type_manager, name); + } + + absl::StatusOr operator()(int64_t number) const { + return struct_type.FindFieldByNumber(type_manager, number); + } +}; + +absl::StatusOr StructType::FindField( + TypeManager& type_manager, FieldId id) const { + return absl::visit(FindFieldVisitor{*this, type_manager}, id.data_); +} + bool ListType::Equals(const Type& other) const { if (kind() != other.kind()) { return false; diff --git a/base/type.h b/base/type.h index ff42c8739..4cc10bef5 100644 --- a/base/type.h +++ b/base/type.h @@ -52,6 +52,7 @@ class ListType; class MapType; class TypeFactory; class TypeProvider; +class TypeManager; class NullValue; class ErrorValue; @@ -100,6 +101,7 @@ class Type : public base_internal::Resource { friend class DurationType; friend class TimestampType; friend class EnumType; + friend class StructType; friend class ListType; friend class MapType; friend class base_internal::TypeHandleBase; @@ -563,6 +565,131 @@ class EnumType : public Type { return ::cel::internal::TypeId(); \ } +// StructType represents an struct type. An struct is a set of fields +// that can be looked up by name and/or number. +class StructType : public Type { + public: + struct Field; + + class FieldId final { + public: + explicit FieldId(absl::string_view name) + : data_(absl::in_place_type, name) {} + + explicit FieldId(int64_t number) + : data_(absl::in_place_type, number) {} + + FieldId() = delete; + + FieldId(const FieldId&) = default; + FieldId& operator=(const FieldId&) = default; + + private: + friend class StructType; + + absl::variant data_; + }; + + Kind kind() const final { return Kind::kStruct; } + + absl::Span> parameters() const final { + return Type::parameters(); + } + + // Find the field definition for the given identifier. + absl::StatusOr FindField(TypeManager& type_manager, FieldId id) const; + + protected: + StructType() = default; + + // TODO(issues/5): NewInstance + + // Called by FindField. + virtual absl::StatusOr FindFieldByName( + TypeManager& type_manager, absl::string_view name) const = 0; + + // Called by FindField. + virtual absl::StatusOr FindFieldByNumber(TypeManager& type_manager, + int64_t number) const = 0; + + private: + friend internal::TypeInfo base_internal::GetStructTypeTypeId( + const StructType& struct_type); + struct FindFieldVisitor; + + friend struct FindFieldVisitor; + friend class TypeFactory; + friend class base_internal::TypeHandleBase; + + // Called by base_internal::TypeHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Type& type) { return type.kind() == Kind::kStruct; } + + StructType(const StructType&) = delete; + StructType(StructType&&) = delete; + + std::pair SizeAndAlignment() const override = 0; + + // Called by CEL_IMPLEMENT_STRUCT_TYPE() and Is() to perform type checking. + virtual internal::TypeInfo TypeId() const = 0; +}; + +// CEL_DECLARE_STRUCT_TYPE declares `struct_type` as an struct type. It must be +// part of the class definition of `struct_type`. +// +// class MyStructType : public cel::StructType { +// ... +// private: +// CEL_DECLARE_STRUCT_TYPE(MyStructType); +// }; +#define CEL_DECLARE_STRUCT_TYPE(struct_type) \ + private: \ + friend class ::cel::base_internal::TypeHandleBase; \ + \ + static bool Is(const ::cel::Type& type); \ + \ + ::std::pair<::std::size_t, ::std::size_t> SizeAndAlignment() const override; \ + \ + ::cel::internal::TypeInfo TypeId() const override; + +// CEL_IMPLEMENT_ENUM_TYPE implements `struct_type` as an struct type. It +// must be called after the class definition of `struct_type`. +// +// class MyStructType : public cel::StructType { +// ... +// private: +// CEL_DECLARE_STRUCT_TYPE(MyStructType); +// }; +// +// CEL_IMPLEMENT_STRUCT_TYPE(MyStructType); +#define CEL_IMPLEMENT_STRUCT_TYPE(struct_type) \ + static_assert(::std::is_base_of_v<::cel::StructType, struct_type>, \ + #struct_type " must inherit from cel::StructType"); \ + static_assert(!::std::is_abstract_v, \ + "this must not be abstract"); \ + \ + bool struct_type::Is(const ::cel::Type& type) { \ + return type.kind() == ::cel::Kind::kStruct && \ + ::cel::base_internal::GetStructTypeTypeId( \ + ::cel::internal::down_cast(type)) == \ + ::cel::internal::TypeId(); \ + } \ + \ + ::std::pair<::std::size_t, ::std::size_t> struct_type::SizeAndAlignment() \ + const { \ + static_assert( \ + ::std::is_same_v>>, \ + "this must be the same as " #struct_type); \ + return ::std::pair<::std::size_t, ::std::size_t>(sizeof(struct_type), \ + alignof(struct_type)); \ + } \ + \ + ::cel::internal::TypeInfo struct_type::TypeId() const { \ + return ::cel::internal::TypeId(); \ + } + // ListType represents a list type. A list is a sequential container where each // element is the same type. class ListType : public Type { @@ -662,6 +789,19 @@ struct EnumType::Constant final { int64_t number; }; +struct StructType::Field final { + explicit Field(absl::string_view name, int64_t number, + Persistent type) + : name(name), number(number), type(std::move(type)) {} + + // The field name. + absl::string_view name; + // The field number. + int64_t number; + // The field type; + Persistent type; +}; + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_BASE_TYPE_H_ diff --git a/base/type_factory.h b/base/type_factory.h index 268b993e8..0ceab92cb 100644 --- a/base/type_factory.h +++ b/base/type_factory.h @@ -82,6 +82,13 @@ class TypeFactory { std::remove_const_t>(memory_manager(), std::forward(args)...); } + template + EnableIfBaseOfT>> + CreateStructType(Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { + return base_internal::PersistentHandleFactory::template Make< + std::remove_const_t>(memory_manager(), std::forward(args)...); + } + absl::StatusOr> CreateListType( const Persistent& element) ABSL_ATTRIBUTE_LIFETIME_BOUND; diff --git a/base/type_manager.h b/base/type_manager.h index e18f30f27..28353e6b7 100644 --- a/base/type_manager.h +++ b/base/type_manager.h @@ -25,7 +25,10 @@ namespace cel { // and registering type implementations. // // TODO(issues/5): more comments after solidifying role -class TypeManager : public TypeFactory, public TypeRegistry {}; +class TypeManager : public TypeFactory, public TypeRegistry { + public: + using TypeFactory::TypeFactory; +}; } // namespace cel diff --git a/base/type_test.cc b/base/type_test.cc index 4ea75f412..ad6e70ca7 100644 --- a/base/type_test.cc +++ b/base/type_test.cc @@ -17,11 +17,13 @@ #include #include +#include "absl/hash/hash.h" #include "absl/hash/hash_testing.h" #include "absl/status/status.h" #include "base/handle.h" #include "base/memory_manager.h" #include "base/type_factory.h" +#include "base/type_manager.h" #include "base/value.h" #include "internal/testing.h" @@ -80,6 +82,56 @@ class TestEnumType final : public EnumType { CEL_IMPLEMENT_ENUM_TYPE(TestEnumType); +// struct TestStruct { +// bool bool_field; +// int64_t int_field; +// uint64_t uint_field; +// double double_field; +// }; + +class TestStructType final : public StructType { + public: + using StructType::StructType; + + absl::string_view name() const override { return "test_struct.TestStruct"; } + + protected: + absl::StatusOr FindFieldByName(TypeManager& type_manager, + absl::string_view name) const override { + if (name == "bool_field") { + return Field("bool_field", 0, type_manager.GetBoolType()); + } else if (name == "int_field") { + return Field("int_field", 1, type_manager.GetIntType()); + } else if (name == "uint_field") { + return Field("uint_field", 2, type_manager.GetUintType()); + } else if (name == "double_field") { + return Field("double_field", 3, type_manager.GetDoubleType()); + } + return absl::NotFoundError(""); + } + + absl::StatusOr FindFieldByNumber(TypeManager& type_manager, + int64_t number) const override { + switch (number) { + case 0: + return Field("bool_field", 0, type_manager.GetBoolType()); + case 1: + return Field("int_field", 1, type_manager.GetIntType()); + case 2: + return Field("uint_field", 2, type_manager.GetUintType()); + case 3: + return Field("double_field", 3, type_manager.GetDoubleType()); + default: + return absl::NotFoundError(""); + } + } + + private: + CEL_DECLARE_STRUCT_TYPE(TestStructType); +}; + +CEL_IMPLEMENT_STRUCT_TYPE(TestStructType); + template Persistent Must(absl::StatusOr> status_or_handle) { return std::move(status_or_handle).value(); @@ -180,6 +232,7 @@ TEST(Type, Null) { EXPECT_FALSE(type_factory.GetNullType().Is()); EXPECT_FALSE(type_factory.GetNullType().Is()); EXPECT_FALSE(type_factory.GetNullType().Is()); + EXPECT_FALSE(type_factory.GetNullType().Is()); EXPECT_FALSE(type_factory.GetNullType().Is()); EXPECT_FALSE(type_factory.GetNullType().Is()); } @@ -201,6 +254,7 @@ TEST(Type, Error) { EXPECT_FALSE(type_factory.GetErrorType().Is()); EXPECT_FALSE(type_factory.GetErrorType().Is()); EXPECT_FALSE(type_factory.GetErrorType().Is()); + EXPECT_FALSE(type_factory.GetErrorType().Is()); EXPECT_FALSE(type_factory.GetErrorType().Is()); EXPECT_FALSE(type_factory.GetErrorType().Is()); } @@ -222,6 +276,7 @@ TEST(Type, Dyn) { EXPECT_FALSE(type_factory.GetDynType().Is()); EXPECT_FALSE(type_factory.GetDynType().Is()); EXPECT_FALSE(type_factory.GetDynType().Is()); + EXPECT_FALSE(type_factory.GetDynType().Is()); EXPECT_FALSE(type_factory.GetDynType().Is()); EXPECT_FALSE(type_factory.GetDynType().Is()); } @@ -243,6 +298,7 @@ TEST(Type, Any) { EXPECT_FALSE(type_factory.GetAnyType().Is()); EXPECT_FALSE(type_factory.GetAnyType().Is()); EXPECT_FALSE(type_factory.GetAnyType().Is()); + EXPECT_FALSE(type_factory.GetAnyType().Is()); EXPECT_FALSE(type_factory.GetAnyType().Is()); EXPECT_FALSE(type_factory.GetAnyType().Is()); } @@ -264,6 +320,7 @@ TEST(Type, Bool) { EXPECT_FALSE(type_factory.GetBoolType().Is()); EXPECT_FALSE(type_factory.GetBoolType().Is()); EXPECT_FALSE(type_factory.GetBoolType().Is()); + EXPECT_FALSE(type_factory.GetBoolType().Is()); EXPECT_FALSE(type_factory.GetBoolType().Is()); EXPECT_FALSE(type_factory.GetBoolType().Is()); } @@ -285,6 +342,7 @@ TEST(Type, Int) { EXPECT_FALSE(type_factory.GetIntType().Is()); EXPECT_FALSE(type_factory.GetIntType().Is()); EXPECT_FALSE(type_factory.GetIntType().Is()); + EXPECT_FALSE(type_factory.GetIntType().Is()); EXPECT_FALSE(type_factory.GetIntType().Is()); EXPECT_FALSE(type_factory.GetIntType().Is()); } @@ -306,6 +364,7 @@ TEST(Type, Uint) { EXPECT_FALSE(type_factory.GetUintType().Is()); EXPECT_FALSE(type_factory.GetUintType().Is()); EXPECT_FALSE(type_factory.GetUintType().Is()); + EXPECT_FALSE(type_factory.GetUintType().Is()); EXPECT_FALSE(type_factory.GetUintType().Is()); EXPECT_FALSE(type_factory.GetUintType().Is()); } @@ -327,6 +386,7 @@ TEST(Type, Double) { EXPECT_FALSE(type_factory.GetDoubleType().Is()); EXPECT_FALSE(type_factory.GetDoubleType().Is()); EXPECT_FALSE(type_factory.GetDoubleType().Is()); + EXPECT_FALSE(type_factory.GetDoubleType().Is()); EXPECT_FALSE(type_factory.GetDoubleType().Is()); EXPECT_FALSE(type_factory.GetDoubleType().Is()); } @@ -348,6 +408,7 @@ TEST(Type, String) { EXPECT_FALSE(type_factory.GetStringType().Is()); EXPECT_FALSE(type_factory.GetStringType().Is()); EXPECT_FALSE(type_factory.GetStringType().Is()); + EXPECT_FALSE(type_factory.GetStringType().Is()); EXPECT_FALSE(type_factory.GetStringType().Is()); EXPECT_FALSE(type_factory.GetStringType().Is()); } @@ -369,6 +430,7 @@ TEST(Type, Bytes) { EXPECT_FALSE(type_factory.GetBytesType().Is()); EXPECT_FALSE(type_factory.GetBytesType().Is()); EXPECT_FALSE(type_factory.GetBytesType().Is()); + EXPECT_FALSE(type_factory.GetBytesType().Is()); EXPECT_FALSE(type_factory.GetBytesType().Is()); EXPECT_FALSE(type_factory.GetBytesType().Is()); } @@ -390,6 +452,7 @@ TEST(Type, Duration) { EXPECT_TRUE(type_factory.GetDurationType().Is()); EXPECT_FALSE(type_factory.GetDurationType().Is()); EXPECT_FALSE(type_factory.GetDurationType().Is()); + EXPECT_FALSE(type_factory.GetDurationType().Is()); EXPECT_FALSE(type_factory.GetDurationType().Is()); EXPECT_FALSE(type_factory.GetDurationType().Is()); } @@ -412,6 +475,7 @@ TEST(Type, Timestamp) { EXPECT_FALSE(type_factory.GetTimestampType().Is()); EXPECT_TRUE(type_factory.GetTimestampType().Is()); EXPECT_FALSE(type_factory.GetTimestampType().Is()); + EXPECT_FALSE(type_factory.GetTimestampType().Is()); EXPECT_FALSE(type_factory.GetTimestampType().Is()); EXPECT_FALSE(type_factory.GetTimestampType().Is()); } @@ -436,6 +500,32 @@ TEST(Type, Enum) { EXPECT_FALSE(enum_type.Is()); EXPECT_TRUE(enum_type.Is()); EXPECT_TRUE(enum_type.Is()); + EXPECT_FALSE(enum_type.Is()); + EXPECT_FALSE(enum_type.Is()); + EXPECT_FALSE(enum_type.Is()); +} + +TEST(Type, Struct) { + TypeManager type_manager(MemoryManager::Global()); + ASSERT_OK_AND_ASSIGN(auto enum_type, + type_manager.CreateStructType()); + EXPECT_EQ(enum_type->kind(), Kind::kStruct); + EXPECT_EQ(enum_type->name(), "test_struct.TestStruct"); + EXPECT_THAT(enum_type->parameters(), SizeIs(0)); + EXPECT_FALSE(enum_type.Is()); + EXPECT_FALSE(enum_type.Is()); + EXPECT_FALSE(enum_type.Is()); + EXPECT_FALSE(enum_type.Is()); + EXPECT_FALSE(enum_type.Is()); + EXPECT_FALSE(enum_type.Is()); + EXPECT_FALSE(enum_type.Is()); + EXPECT_FALSE(enum_type.Is()); + EXPECT_FALSE(enum_type.Is()); + EXPECT_FALSE(enum_type.Is()); + EXPECT_FALSE(enum_type.Is()); + EXPECT_FALSE(enum_type.Is()); + EXPECT_TRUE(enum_type.Is()); + EXPECT_TRUE(enum_type.Is()); EXPECT_FALSE(enum_type.Is()); EXPECT_FALSE(enum_type.Is()); } @@ -462,6 +552,7 @@ TEST(Type, List) { EXPECT_FALSE(list_type.Is()); EXPECT_FALSE(list_type.Is()); EXPECT_FALSE(list_type.Is()); + EXPECT_FALSE(list_type.Is()); EXPECT_TRUE(list_type.Is()); EXPECT_FALSE(list_type.Is()); } @@ -494,6 +585,7 @@ TEST(Type, Map) { EXPECT_FALSE(map_type.Is()); EXPECT_FALSE(map_type.Is()); EXPECT_FALSE(map_type.Is()); + EXPECT_FALSE(map_type.Is()); EXPECT_FALSE(map_type.Is()); EXPECT_TRUE(map_type.Is()); } @@ -529,6 +621,70 @@ TEST(EnumType, FindConstant) { StatusIs(absl::StatusCode::kNotFound)); } +TEST(StructType, FindField) { + TypeManager type_manager(MemoryManager::Global()); + ASSERT_OK_AND_ASSIGN(auto struct_type, + type_manager.CreateStructType()); + + ASSERT_OK_AND_ASSIGN( + auto field1, + struct_type->FindField(type_manager, StructType::FieldId("bool_field"))); + EXPECT_EQ(field1.name, "bool_field"); + EXPECT_EQ(field1.number, 0); + EXPECT_EQ(field1.type, type_manager.GetBoolType()); + + ASSERT_OK_AND_ASSIGN( + field1, struct_type->FindField(type_manager, StructType::FieldId(0))); + EXPECT_EQ(field1.name, "bool_field"); + EXPECT_EQ(field1.number, 0); + EXPECT_EQ(field1.type, type_manager.GetBoolType()); + + ASSERT_OK_AND_ASSIGN( + auto field2, + struct_type->FindField(type_manager, StructType::FieldId("int_field"))); + EXPECT_EQ(field2.name, "int_field"); + EXPECT_EQ(field2.number, 1); + EXPECT_EQ(field2.type, type_manager.GetIntType()); + + ASSERT_OK_AND_ASSIGN( + field2, struct_type->FindField(type_manager, StructType::FieldId(1))); + EXPECT_EQ(field2.name, "int_field"); + EXPECT_EQ(field2.number, 1); + EXPECT_EQ(field2.type, type_manager.GetIntType()); + + ASSERT_OK_AND_ASSIGN( + auto field3, + struct_type->FindField(type_manager, StructType::FieldId("uint_field"))); + EXPECT_EQ(field3.name, "uint_field"); + EXPECT_EQ(field3.number, 2); + EXPECT_EQ(field3.type, type_manager.GetUintType()); + + ASSERT_OK_AND_ASSIGN( + field3, struct_type->FindField(type_manager, StructType::FieldId(2))); + EXPECT_EQ(field3.name, "uint_field"); + EXPECT_EQ(field3.number, 2); + EXPECT_EQ(field3.type, type_manager.GetUintType()); + + ASSERT_OK_AND_ASSIGN( + auto field4, struct_type->FindField(type_manager, + StructType::FieldId("double_field"))); + EXPECT_EQ(field4.name, "double_field"); + EXPECT_EQ(field4.number, 3); + EXPECT_EQ(field4.type, type_manager.GetDoubleType()); + + ASSERT_OK_AND_ASSIGN( + field4, struct_type->FindField(type_manager, StructType::FieldId(3))); + EXPECT_EQ(field4.name, "double_field"); + EXPECT_EQ(field4.number, 3); + EXPECT_EQ(field4.type, type_manager.GetDoubleType()); + + EXPECT_THAT(struct_type->FindField(type_manager, + StructType::FieldId("missing_field")), + StatusIs(absl::StatusCode::kNotFound)); + EXPECT_THAT(struct_type->FindField(type_manager, StructType::FieldId(4)), + StatusIs(absl::StatusCode::kNotFound)); +} + TEST(NullType, DebugString) { TypeFactory type_factory(MemoryManager::Global()); EXPECT_EQ(type_factory.GetNullType()->DebugString(), "null_type"); @@ -598,6 +754,13 @@ TEST(EnumType, DebugString) { EXPECT_EQ(enum_type->DebugString(), "test_enum.TestEnum"); } +TEST(StructType, DebugString) { + TypeManager type_manager(MemoryManager::Global()); + ASSERT_OK_AND_ASSIGN(auto struct_type, + type_manager.CreateStructType()); + EXPECT_EQ(struct_type->DebugString(), "test_struct.TestStruct"); +} + TEST(ListType, DebugString) { TypeFactory type_factory(MemoryManager::Global()); ASSERT_OK_AND_ASSIGN(auto list_type, @@ -629,6 +792,8 @@ TEST(Type, SupportsAbslHash) { Persistent(type_factory.GetDurationType()), Persistent(type_factory.GetTimestampType()), Persistent(Must(type_factory.CreateEnumType())), + Persistent( + Must(type_factory.CreateStructType())), Persistent( Must(type_factory.CreateListType(type_factory.GetBoolType()))), Persistent(Must(type_factory.CreateMapType( From 2e2ea38b3c269fe2b8d806211e9a703270b41bf7 Mon Sep 17 00:00:00 2001 From: jcking Date: Mon, 28 Mar 2022 17:33:57 +0000 Subject: [PATCH 046/111] Internal change PiperOrigin-RevId: 437795581 --- base/type.cc | 10 ++++++++++ base/type.h | 4 ++++ base/type_test.cc | 4 ++-- 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/base/type.cc b/base/type.cc index fca46d53a..cd578eaf1 100644 --- a/base/type.cc +++ b/base/type.cc @@ -17,6 +17,7 @@ #include #include +#include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "absl/types/variant.h" #include "base/handle.h" @@ -160,6 +161,10 @@ absl::StatusOr StructType::FindField( return absl::visit(FindFieldVisitor{*this, type_manager}, id.data_); } +std::string ListType::DebugString() const { + return absl::StrCat(name(), "(", element()->DebugString(), ")"); +} + bool ListType::Equals(const Type& other) const { if (kind() != other.kind()) { return false; @@ -173,6 +178,11 @@ void ListType::HashValue(absl::HashState state) const { Type::HashValue(absl::HashState::combine(std::move(state), element())); } +std::string MapType::DebugString() const { + return absl::StrCat(name(), "(", key()->DebugString(), ", ", + value()->DebugString(), ")"); +} + bool MapType::Equals(const Type& other) const { if (kind() != other.kind()) { return false; diff --git a/base/type.h b/base/type.h index 4cc10bef5..c1d07f8a0 100644 --- a/base/type.h +++ b/base/type.h @@ -702,6 +702,8 @@ class ListType : public Type { absl::string_view name() const final { return "list"; } + std::string DebugString() const final; + // Returns the type of the elements in the list. virtual Transient element() const = 0; @@ -740,6 +742,8 @@ class MapType : public Type { absl::string_view name() const final { return "map"; } + std::string DebugString() const final; + // Returns the type of the keys in the map. virtual Transient key() const = 0; diff --git a/base/type_test.cc b/base/type_test.cc index ad6e70ca7..c0366d4db 100644 --- a/base/type_test.cc +++ b/base/type_test.cc @@ -765,7 +765,7 @@ TEST(ListType, DebugString) { TypeFactory type_factory(MemoryManager::Global()); ASSERT_OK_AND_ASSIGN(auto list_type, type_factory.CreateListType(type_factory.GetBoolType())); - EXPECT_EQ(list_type->DebugString(), "list"); + EXPECT_EQ(list_type->DebugString(), "list(bool)"); } TEST(MapType, DebugString) { @@ -773,7 +773,7 @@ TEST(MapType, DebugString) { ASSERT_OK_AND_ASSIGN(auto map_type, type_factory.CreateMapType(type_factory.GetStringType(), type_factory.GetBoolType())); - EXPECT_EQ(map_type->DebugString(), "map"); + EXPECT_EQ(map_type->DebugString(), "map(string, bool)"); } TEST(Type, SupportsAbslHash) { From 0bbc8970af225c317d143e82861cc2feb113b486 Mon Sep 17 00:00:00 2001 From: tswadell Date: Mon, 28 Mar 2022 18:17:55 +0000 Subject: [PATCH 047/111] Sync from GitHub PiperOrigin-RevId: 437807796 --- conformance/BUILD | 2 -- conformance/server.cc | 12 ++++++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/conformance/BUILD b/conformance/BUILD index d5748fbce..9c2408c83 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -77,8 +77,6 @@ cc_binary( # Tests which require spec changes. # TODO(issues/93): Deprecate Duration.getMilliseconds. "--skip_test=timestamps/duration_converters/get_milliseconds", - # TODO(issues/110): Tune parse limits to mirror those for proto deserialization and C++ safety limits. - "--skip_test=parse/nest/list_index,message_literal,funcall,list_literal,map_literal;repeat/conditional,add_sub,mul_div,select,index,map_literal,message_literal", # Broken test cases which should be supported. # TODO(issues/112): Unbound functions result in empty eval response. diff --git a/conformance/server.cc b/conformance/server.cc index 6a717d470..c16580026 100644 --- a/conformance/server.cc +++ b/conformance/server.cc @@ -196,8 +196,10 @@ int RunServer(bool optimize) { std::cerr << "Failed to parse JSON" << std::endl; } service.Parse(&request, &response); - if (!MessageToJsonString(response, &output).ok()) { - std::cerr << "Failed to convert to JSON" << std::endl; + auto status = MessageToJsonString(response, &output); + if (!status.ok()) { + std::cerr << "Failed to convert to JSON:" << status.ToString() + << std::endl; } } else if (cmd == "eval") { conformance::v1alpha1::EvalRequest request; @@ -206,8 +208,10 @@ int RunServer(bool optimize) { std::cerr << "Failed to parse JSON" << std::endl; } service.Eval(&request, &response); - if (!MessageToJsonString(response, &output).ok()) { - std::cerr << "Failed to convert to JSON" << std::endl; + auto status = MessageToJsonString(response, &output); + if (!status.ok()) { + std::cerr << "Failed to convert to JSON:" << status.ToString() + << std::endl; } } else if (cmd.empty()) { return 0; From b8c09126b3e1405f216326ff2f8e2ad84b4b2a71 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Mon, 28 Mar 2022 23:14:00 +0000 Subject: [PATCH 048/111] Add accessor APIs to legacy type adapter. PiperOrigin-RevId: 437879303 --- eval/public/structs/BUILD | 3 + .../structs/proto_message_type_adapter.cc | 68 ++++- .../structs/proto_message_type_adapter.h | 10 +- .../proto_message_type_adapter_test.cc | 271 ++++++++++++++++-- .../protobuf_descriptor_type_provider.cc | 4 +- .../protobuf_descriptor_type_provider.h | 5 +- eval/public/testing/matchers.cc | 15 +- eval/public/testing/matchers.h | 3 + eval/public/testing/matchers_test.cc | 15 +- 9 files changed, 348 insertions(+), 46 deletions(-) diff --git a/eval/public/structs/BUILD b/eval/public/structs/BUILD index 180829047..75ff1ec11 100644 --- a/eval/public/structs/BUILD +++ b/eval/public/structs/BUILD @@ -129,6 +129,8 @@ cc_library( "//base:memory_manager", "//eval/public:cel_value", "//eval/public/containers:field_access", + "//eval/public/containers:field_backed_list_impl", + "//eval/public/containers:field_backed_map_impl", "//extensions/protobuf:memory_manager", "//internal:status_macros", "@com_google_absl//absl/status", @@ -146,6 +148,7 @@ cc_test( "//eval/public:cel_value", "//eval/public/containers:container_backed_list_impl", "//eval/public/containers:container_backed_map_impl", + "//eval/public/containers:field_access", "//eval/public/testing:matchers", "//eval/testutil:test_message_cc_proto", "//extensions/protobuf:memory_manager", diff --git a/eval/public/structs/proto_message_type_adapter.cc b/eval/public/structs/proto_message_type_adapter.cc index abefd239f..d48213583 100644 --- a/eval/public/structs/proto_message_type_adapter.cc +++ b/eval/public/structs/proto_message_type_adapter.cc @@ -22,12 +22,17 @@ #include "absl/strings/substitute.h" #include "eval/public/cel_value.h" #include "eval/public/containers/field_access.h" +#include "eval/public/containers/field_backed_list_impl.h" +#include "eval/public/containers/field_backed_map_impl.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "extensions/protobuf/memory_manager.h" #include "internal/status_macros.h" namespace google::api::expr::runtime { +using ::cel::extensions::ProtoMemoryManager; +using ::google::protobuf::FieldDescriptor; using ::google::protobuf::Message; +using ::google::protobuf::Reflection; absl::Status ProtoMessageTypeAdapter::ValidateSetFieldOp( bool assertion, absl::string_view field, absl::string_view detail) const { @@ -42,8 +47,7 @@ absl::Status ProtoMessageTypeAdapter::ValidateSetFieldOp( absl::StatusOr ProtoMessageTypeAdapter::NewInstance( cel::MemoryManager& memory_manager) const { // This implementation requires arena-backed memory manager. - google::protobuf::Arena* arena = - cel::extensions::ProtoMemoryManager::CastToProtoArena(memory_manager); + google::protobuf::Arena* arena = ProtoMemoryManager::CastToProtoArena(memory_manager); const Message* prototype = message_factory_->GetPrototype(descriptor_); Message* msg = (prototype != nullptr) ? prototype->New(arena) : nullptr; @@ -61,13 +65,69 @@ bool ProtoMessageTypeAdapter::DefinesField(absl::string_view field_name) const { absl::StatusOr ProtoMessageTypeAdapter::HasField( absl::string_view field_name, const CelValue& value) const { - return absl::UnimplementedError("Not yet implemented."); + const google::protobuf::Message* message; + if (!value.GetValue(&message) || message == nullptr) { + return absl::InvalidArgumentError("HasField called on non-message type."); + } + + const Reflection* reflection = message->GetReflection(); + ABSL_ASSERT(descriptor_ == message->GetDescriptor()); + + const FieldDescriptor* field_desc = descriptor_->FindFieldByName(field_name.data()); + + if (field_desc == nullptr) { + return absl::NotFoundError(absl::StrCat("no_such_field : ", field_name)); + } + + if (field_desc->is_map()) { + // When the map field appears in a has(msg.map_field) expression, the map + // is considered 'present' when it is non-empty. Since maps are repeated + // fields they don't participate with standard proto presence testing since + // the repeated field is always at least empty. + return reflection->FieldSize(*message, field_desc) != 0; + } + + if (field_desc->is_repeated()) { + // When the list field appears in a has(msg.list_field) expression, the list + // is considered 'present' when it is non-empty. + return reflection->FieldSize(*message, field_desc) != 0; + } + + // Standard proto presence test for non-repeated fields. + return reflection->HasField(*message, field_desc); } absl::StatusOr ProtoMessageTypeAdapter::GetField( absl::string_view field_name, const CelValue& instance, cel::MemoryManager& memory_manager) const { - return absl::UnimplementedError("Not yet implemented."); + const google::protobuf::Message* message; + if (!instance.GetValue(&message) || message == nullptr) { + return absl::InvalidArgumentError("GetField called on non-message type."); + } + + const FieldDescriptor* field_desc = descriptor_->FindFieldByName(field_name.data()); + + if (field_desc == nullptr) { + return CreateNoSuchFieldError(memory_manager, field_name); + } + + google::protobuf::Arena* arena = ProtoMemoryManager::CastToProtoArena(memory_manager); + + if (field_desc->is_map()) { + CelMap* map = google::protobuf::Arena::Create(arena, message, + field_desc, arena); + return CelValue::CreateMap(map); + } + if (field_desc->is_repeated()) { + CelList* list = google::protobuf::Arena::Create( + arena, message, field_desc, arena); + return CelValue::CreateList(list); + } + + CelValue result; + CEL_RETURN_IF_ERROR(CreateValueFromSingleField( + message, field_desc, unboxing_option_, arena, &result)); + return result; } absl::Status ProtoMessageTypeAdapter::SetField( diff --git a/eval/public/structs/proto_message_type_adapter.h b/eval/public/structs/proto_message_type_adapter.h index 5d75927a6..7827c608b 100644 --- a/eval/public/structs/proto_message_type_adapter.h +++ b/eval/public/structs/proto_message_type_adapter.h @@ -21,6 +21,7 @@ #include "absl/strings/string_view.h" #include "base/memory_manager.h" #include "eval/public/cel_value.h" +#include "eval/public/containers/field_access.h" #include "eval/public/structs/legacy_type_adapter.h" namespace google::api::expr::runtime { @@ -29,8 +30,12 @@ class ProtoMessageTypeAdapter : public LegacyTypeAdapter::AccessApis, public LegacyTypeAdapter::MutationApis { public: ProtoMessageTypeAdapter(const google::protobuf::Descriptor* descriptor, - google::protobuf::MessageFactory* message_factory) - : message_factory_(message_factory), descriptor_(descriptor) {} + google::protobuf::MessageFactory* message_factory, + ProtoWrapperTypeOptions unboxing_option = + ProtoWrapperTypeOptions::kUnsetNull) + : message_factory_(message_factory), + descriptor_(descriptor), + unboxing_option_(unboxing_option) {} ~ProtoMessageTypeAdapter() override = default; @@ -61,6 +66,7 @@ class ProtoMessageTypeAdapter : public LegacyTypeAdapter::AccessApis, google::protobuf::MessageFactory* message_factory_; const google::protobuf::Descriptor* descriptor_; + ProtoWrapperTypeOptions unboxing_option_; }; } // namespace google::api::expr::runtime diff --git a/eval/public/structs/proto_message_type_adapter_test.cc b/eval/public/structs/proto_message_type_adapter_test.cc index 40acbacb0..90b734256 100644 --- a/eval/public/structs/proto_message_type_adapter_test.cc +++ b/eval/public/structs/proto_message_type_adapter_test.cc @@ -22,6 +22,7 @@ #include "eval/public/cel_value.h" #include "eval/public/containers/container_backed_list_impl.h" #include "eval/public/containers/container_backed_map_impl.h" +#include "eval/public/containers/field_access.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "eval/public/testing/matchers.h" #include "eval/testutil/test_message.pb.h" @@ -32,32 +33,103 @@ namespace google::api::expr::runtime { namespace { +using testing::_; using testing::EqualsProto; using testing::HasSubstr; +using testing::Optional; +using cel::internal::IsOkAndHolds; using cel::internal::StatusIs; -TEST(ProtoMessageTypeAdapter, HasFieldNotYetImplemented) { +TEST(ProtoMessageTypeAdapter, HasFieldSingular) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory()); + google::protobuf::MessageFactory::generated_factory(), + ProtoWrapperTypeOptions::kUnsetNull); + + TestMessage example; + + CelValue value = CelProtoWrapper::CreateMessage(&example, &arena); + + EXPECT_THAT(adapter.HasField("int64_value", value), IsOkAndHolds(false)); + example.set_int64_value(10); + EXPECT_THAT(adapter.HasField("int64_value", value), IsOkAndHolds(true)); +} + +TEST(ProtoMessageTypeAdapter, HasFieldRepeated) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory(), + ProtoWrapperTypeOptions::kUnsetNull); + + TestMessage example; + + CelValue value = CelProtoWrapper::CreateMessage(&example, &arena); + + EXPECT_THAT(adapter.HasField("int64_list", value), IsOkAndHolds(false)); + example.add_int64_list(10); + EXPECT_THAT(adapter.HasField("int64_list", value), IsOkAndHolds(true)); +} + +TEST(ProtoMessageTypeAdapter, HasFieldMap) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory(), + ProtoWrapperTypeOptions::kUnsetNull); TestMessage example; example.set_int64_value(10); CelValue value = CelProtoWrapper::CreateMessage(&example, &arena); - EXPECT_THAT(adapter.HasField("value", value), - StatusIs(absl::StatusCode::kUnimplemented)); + EXPECT_THAT(adapter.HasField("int64_int32_map", value), IsOkAndHolds(false)); + (*example.mutable_int64_int32_map())[2] = 3; + EXPECT_THAT(adapter.HasField("int64_int32_map", value), IsOkAndHolds(true)); } -TEST(ProtoMessageTypeAdapter, GetFieldNotYetImplemented) { +TEST(ProtoMessageTypeAdapter, HasFieldUnknownField) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory()); + google::protobuf::MessageFactory::generated_factory(), + ProtoWrapperTypeOptions::kUnsetNull); + + TestMessage example; + example.set_int64_value(10); + + CelValue value = CelProtoWrapper::CreateMessage(&example, &arena); + + EXPECT_THAT(adapter.HasField("unknown_field", value), + StatusIs(absl::StatusCode::kNotFound)); +} + +TEST(ProtoMessageTypeAdapter, HasFieldNonMessageType) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory(), + ProtoWrapperTypeOptions::kUnsetNull); + + CelValue value = CelValue::CreateInt64(10); + + EXPECT_THAT(adapter.HasField("unknown_field", value), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(ProtoMessageTypeAdapter, GetFieldSingular) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory(), + ProtoWrapperTypeOptions::kUnsetNull); cel::extensions::ProtoMemoryManager manager(&arena); TestMessage example; @@ -66,7 +138,156 @@ TEST(ProtoMessageTypeAdapter, GetFieldNotYetImplemented) { CelValue value = CelProtoWrapper::CreateMessage(&example, &arena); EXPECT_THAT(adapter.GetField("int64_value", value, manager), - StatusIs(absl::StatusCode::kUnimplemented)); + IsOkAndHolds(test::IsCelInt64(10))); +} + +TEST(ProtoMessageTypeAdapter, GetFieldNoSuchField) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory(), + ProtoWrapperTypeOptions::kUnsetNull); + cel::extensions::ProtoMemoryManager manager(&arena); + + TestMessage example; + example.set_int64_value(10); + + CelValue value = CelProtoWrapper::CreateMessage(&example, &arena); + + EXPECT_THAT(adapter.GetField("unknown_field", value, manager), + IsOkAndHolds(test::IsCelError(StatusIs( + absl::StatusCode::kNotFound, HasSubstr("unknown_field"))))); +} + +TEST(ProtoMessageTypeAdapter, GetFieldNotAMessage) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory(), + ProtoWrapperTypeOptions::kUnsetNull); + cel::extensions::ProtoMemoryManager manager(&arena); + + CelValue value = CelValue::CreateNull(); + + EXPECT_THAT(adapter.GetField("int64_value", value, manager), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(ProtoMessageTypeAdapter, GetFieldRepeated) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory(), + ProtoWrapperTypeOptions::kUnsetNull); + cel::extensions::ProtoMemoryManager manager(&arena); + + TestMessage example; + example.add_int64_list(10); + example.add_int64_list(20); + + CelValue value = CelProtoWrapper::CreateMessage(&example, &arena); + + ASSERT_OK_AND_ASSIGN(CelValue result, + adapter.GetField("int64_list", value, manager)); + + const CelList* held_value; + ASSERT_TRUE(result.GetValue(&held_value)) << result.DebugString(); + + EXPECT_EQ(held_value->size(), 2); + EXPECT_THAT((*held_value)[0], test::IsCelInt64(10)); + EXPECT_THAT((*held_value)[1], test::IsCelInt64(20)); +} + +TEST(ProtoMessageTypeAdapter, GetFieldMap) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory(), + ProtoWrapperTypeOptions::kUnsetNull); + cel::extensions::ProtoMemoryManager manager(&arena); + + TestMessage example; + (*example.mutable_int64_int32_map())[10] = 20; + + CelValue value = CelProtoWrapper::CreateMessage(&example, &arena); + + ASSERT_OK_AND_ASSIGN(CelValue result, + adapter.GetField("int64_int32_map", value, manager)); + + const CelMap* held_value; + ASSERT_TRUE(result.GetValue(&held_value)) << result.DebugString(); + + EXPECT_EQ(held_value->size(), 1); + EXPECT_THAT((*held_value)[CelValue::CreateInt64(10)], + Optional(test::IsCelInt64(20))); +} + +TEST(ProtoMessageTypeAdapter, GetFieldWrapperType) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory(), + ProtoWrapperTypeOptions::kUnsetNull); + cel::extensions::ProtoMemoryManager manager(&arena); + + TestMessage example; + example.mutable_int64_wrapper_value()->set_value(10); + + CelValue value = CelProtoWrapper::CreateMessage(&example, &arena); + + EXPECT_THAT(adapter.GetField("int64_wrapper_value", value, manager), + IsOkAndHolds(test::IsCelInt64(10))); +} + +TEST(ProtoMessageTypeAdapter, GetFieldWrapperTypeUnsetNullUnbox) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory(), + ProtoWrapperTypeOptions::kUnsetNull); + cel::extensions::ProtoMemoryManager manager(&arena); + + TestMessage example; + + CelValue value = CelProtoWrapper::CreateMessage(&example, &arena); + + EXPECT_THAT(adapter.GetField("int64_wrapper_value", value, manager), + IsOkAndHolds(test::IsCelNull())); + + // Wrapper field present, but default value. + example.mutable_int64_wrapper_value()->clear_value(); + EXPECT_THAT(adapter.GetField("int64_wrapper_value", value, manager), + IsOkAndHolds(test::IsCelInt64(_))); +} + +TEST(ProtoMessageTypeAdapter, GetFieldWrapperTypeUnsetDefaultValueUnbox) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory(), + ProtoWrapperTypeOptions::kUnsetProtoDefault); + cel::extensions::ProtoMemoryManager manager(&arena); + + TestMessage example; + + CelValue value = CelProtoWrapper::CreateMessage(&example, &arena); + + EXPECT_THAT(adapter.GetField("int64_wrapper_value", value, manager), + IsOkAndHolds(test::IsCelInt64(_))); + + // Wrapper field present with unset value is used to signal Null, but legacy + // behavior just returns the proto default value. + example.mutable_int64_wrapper_value()->clear_value(); + // Same behavior for this option. + EXPECT_THAT(adapter.GetField("int64_wrapper_value", value, manager), + IsOkAndHolds(test::IsCelInt64(_))); } TEST(ProtoMessageTypeAdapter, NewInstance) { @@ -74,7 +295,8 @@ TEST(ProtoMessageTypeAdapter, NewInstance) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory()); + google::protobuf::MessageFactory::generated_factory(), + ProtoWrapperTypeOptions::kUnsetNull); cel::extensions::ProtoMemoryManager manager(&arena); ASSERT_OK_AND_ASSIGN(CelValue result, adapter.NewInstance(manager)); @@ -97,7 +319,8 @@ TEST(ProtoMessageTypeAdapter, NewInstanceUnsupportedDescriptor) { ProtoMessageTypeAdapter adapter( pool.FindMessageTypeByName("google.api.expr.runtime.FakeMessage"), - google::protobuf::MessageFactory::generated_factory()); + google::protobuf::MessageFactory::generated_factory(), + ProtoWrapperTypeOptions::kUnsetNull); cel::extensions::ProtoMemoryManager manager(&arena); // Message factory doesn't know how to create our custom message, even though @@ -111,7 +334,8 @@ TEST(ProtoMessageTypeAdapter, DefinesField) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory()); + google::protobuf::MessageFactory::generated_factory(), + ProtoWrapperTypeOptions::kUnsetNull); EXPECT_TRUE(adapter.DefinesField("int64_value")); EXPECT_FALSE(adapter.DefinesField("not_a_field")); @@ -122,7 +346,8 @@ TEST(ProtoMessageTypeAdapter, SetFieldSingular) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory()); + google::protobuf::MessageFactory::generated_factory(), + ProtoWrapperTypeOptions::kUnsetNull); cel::extensions::ProtoMemoryManager manager(&arena); ASSERT_OK_AND_ASSIGN(CelValue value, adapter.NewInstance(manager)); @@ -145,7 +370,8 @@ TEST(ProtoMessageTypeAdapter, SetFieldMap) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory()); + google::protobuf::MessageFactory::generated_factory(), + ProtoWrapperTypeOptions::kUnsetNull); cel::extensions::ProtoMemoryManager manager(&arena); CelMapBuilder builder; @@ -172,7 +398,8 @@ TEST(ProtoMessageTypeAdapter, SetFieldRepeated) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory()); + google::protobuf::MessageFactory::generated_factory(), + ProtoWrapperTypeOptions::kUnsetNull); cel::extensions::ProtoMemoryManager manager(&arena); ContainerBackedListImpl list( @@ -194,7 +421,8 @@ TEST(ProtoMessageTypeAdapter, SetFieldNotAField) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory()); + google::protobuf::MessageFactory::generated_factory(), + ProtoWrapperTypeOptions::kUnsetNull); cel::extensions::ProtoMemoryManager manager(&arena); ASSERT_OK_AND_ASSIGN(CelValue instance, adapter.NewInstance(manager)); @@ -210,7 +438,8 @@ TEST(ProtoMesssageTypeAdapter, SetFieldWrongType) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory()); + google::protobuf::MessageFactory::generated_factory(), + ProtoWrapperTypeOptions::kUnsetNull); cel::extensions::ProtoMemoryManager manager(&arena); ContainerBackedListImpl list( @@ -249,7 +478,8 @@ TEST(ProtoMesssageTypeAdapter, SetFieldNotAMessage) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory()); + google::protobuf::MessageFactory::generated_factory(), + ProtoWrapperTypeOptions::kUnsetNull); cel::extensions::ProtoMemoryManager manager(&arena); CelValue int_value = CelValue::CreateInt64(42); @@ -264,7 +494,8 @@ TEST(ProtoMessageTypeAdapter, AdaptFromWellKnownType) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.protobuf.Int64Value"), - google::protobuf::MessageFactory::generated_factory()); + google::protobuf::MessageFactory::generated_factory(), + ProtoWrapperTypeOptions::kUnsetNull); cel::extensions::ProtoMemoryManager manager(&arena); ASSERT_OK_AND_ASSIGN(CelValue instance, adapter.NewInstance(manager)); @@ -281,7 +512,8 @@ TEST(ProtoMessageTypeAdapter, AdaptFromWellKnownTypeUnspecial) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory()); + google::protobuf::MessageFactory::generated_factory(), + ProtoWrapperTypeOptions::kUnsetNull); cel::extensions::ProtoMemoryManager manager(&arena); ASSERT_OK_AND_ASSIGN(CelValue instance, adapter.NewInstance(manager)); @@ -299,7 +531,8 @@ TEST(ProtoMessageTypeAdapter, AdaptFromWellKnownTypeNotAMessageError) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory()); + google::protobuf::MessageFactory::generated_factory(), + ProtoWrapperTypeOptions::kUnsetNull); cel::extensions::ProtoMemoryManager manager(&arena); CelValue instance = CelValue::CreateNull(); diff --git a/eval/public/structs/protobuf_descriptor_type_provider.cc b/eval/public/structs/protobuf_descriptor_type_provider.cc index 8c96c6b38..65e7bc48d 100644 --- a/eval/public/structs/protobuf_descriptor_type_provider.cc +++ b/eval/public/structs/protobuf_descriptor_type_provider.cc @@ -49,7 +49,7 @@ std::unique_ptr ProtobufDescriptorProvider::GetType( return nullptr; } - return std::make_unique(descriptor, - message_factory_); + return std::make_unique(descriptor, message_factory_, + unboxing_option_); } } // namespace google::api::expr::runtime diff --git a/eval/public/structs/protobuf_descriptor_type_provider.h b/eval/public/structs/protobuf_descriptor_type_provider.h index c5091ff2d..4d745a1c7 100644 --- a/eval/public/structs/protobuf_descriptor_type_provider.h +++ b/eval/public/structs/protobuf_descriptor_type_provider.h @@ -36,7 +36,9 @@ class ProtobufDescriptorProvider : public LegacyTypeProvider { public: ProtobufDescriptorProvider(const google::protobuf::DescriptorPool* pool, google::protobuf::MessageFactory* factory) - : descriptor_pool_(pool), message_factory_(factory) {} + : descriptor_pool_(pool), + message_factory_(factory), + unboxing_option_(ProtoWrapperTypeOptions::kUnsetNull) {} absl::optional ProvideLegacyType( absl::string_view name) const override; @@ -49,6 +51,7 @@ class ProtobufDescriptorProvider : public LegacyTypeProvider { const google::protobuf::DescriptorPool* descriptor_pool_; google::protobuf::MessageFactory* message_factory_; + ProtoWrapperTypeOptions unboxing_option_; mutable absl::flat_hash_map> type_cache_; diff --git a/eval/public/testing/matchers.cc b/eval/public/testing/matchers.cc index a8333d210..d9e52c7fd 100644 --- a/eval/public/testing/matchers.cc +++ b/eval/public/testing/matchers.cc @@ -7,10 +7,7 @@ #include "absl/strings/string_view.h" #include "eval/public/set_util.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { void PrintTo(const CelValue& value, std::ostream* os) { *os << value.DebugString(); @@ -19,6 +16,7 @@ void PrintTo(const CelValue& value, std::ostream* os) { namespace test { namespace { +using testing::_; using testing::MatcherInterface; using testing::MatchResultListener; @@ -68,6 +66,10 @@ CelValueMatcher EqualsCelValue(const CelValue& v) { return CelValueMatcher(new CelValueEqualImpl(v)); } +CelValueMatcher IsCelNull() { + return CelValueMatcher(new CelValueMatcherImpl(_)); +} + CelValueMatcher IsCelBool(testing::Matcher m) { return CelValueMatcher(new CelValueMatcherImpl(std::move(m))); } @@ -114,7 +116,4 @@ CelValueMatcher IsCelError(testing::Matcher m) { } } // namespace test -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/public/testing/matchers.h b/eval/public/testing/matchers.h index 5d8d2e70c..82515d8e4 100644 --- a/eval/public/testing/matchers.h +++ b/eval/public/testing/matchers.h @@ -28,6 +28,9 @@ using CelValueMatcher = testing::Matcher; // Tests equality to CelValue v using the set_util implementation. CelValueMatcher EqualsCelValue(const CelValue& v); +// Matches CelValues of type null. +CelValueMatcher IsCelNull(); + // Matches CelValues of type bool whose held value matches |m|. CelValueMatcher IsCelBool(testing::Matcher m); diff --git a/eval/public/testing/matchers_test.cc b/eval/public/testing/matchers_test.cc index 6b30a40af..6a39b2572 100644 --- a/eval/public/testing/matchers_test.cc +++ b/eval/public/testing/matchers_test.cc @@ -8,11 +8,7 @@ #include "internal/testing.h" #include "testutil/util.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { -namespace test { +namespace google::api::expr::runtime::test { namespace { using testing::Contains; @@ -64,6 +60,9 @@ TEST(IsCelValue, EqualitySmoketest) { } TEST(PrimitiveMatchers, Smoketest) { + EXPECT_THAT(CelValue::CreateNull(), IsCelNull()); + EXPECT_THAT(CelValue::CreateBool(false), Not(IsCelNull())); + EXPECT_THAT(CelValue::CreateBool(true), IsCelBool(true)); EXPECT_THAT(CelValue::CreateBool(false), IsCelBool(Not(true))); @@ -153,8 +152,4 @@ TEST(ListMatchers, All) { } } // namespace -} // namespace test -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime::test From a814dc628c14792f03863a037614e36e447543e4 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Tue, 29 Mar 2022 00:50:49 +0000 Subject: [PATCH 049/111] Update internal value representation to use MessageLite (with optional down cast to full message) PiperOrigin-RevId: 437898181 --- eval/eval/BUILD | 3 ++ eval/eval/const_value_step.cc | 7 ++- eval/eval/const_value_step_test.cc | 48 ++++++++++++++++- eval/public/BUILD | 9 ++++ eval/public/cel_attribute.cc | 67 ++++++++++++++++++----- eval/public/cel_attribute.h | 68 ++++++------------------ eval/public/cel_attribute_test.cc | 41 +++++++++------ eval/public/cel_function_adapter.cc | 6 ++- eval/public/cel_function_adapter.h | 5 ++ eval/public/cel_value.h | 73 ++++++++++++++++++++++---- eval/public/cel_value_internal.h | 61 +++++++++++++++++++++ eval/public/cel_value_test.cc | 30 +++++++++++ eval/public/containers/BUILD | 1 + eval/public/containers/field_access.cc | 46 +++------------- eval/public/set_util.cc | 36 ++++++++----- eval/public/testing/BUILD | 1 + eval/public/testing/matchers.cc | 30 ++++++++++- 17 files changed, 385 insertions(+), 147 deletions(-) diff --git a/eval/eval/BUILD b/eval/eval/BUILD index 3d4d5f44b..ae44d8b1f 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -88,6 +88,7 @@ cc_library( ":expression_step_base", "//eval/public:cel_value", "//eval/public/structs:cel_proto_wrapper", + "//internal:proto_util", "@com_google_absl//absl/status:statusor", "@com_google_protobuf//:protobuf", ], @@ -351,9 +352,11 @@ cc_test( ":evaluator_core", ":test_type_registry", "//eval/public:activation", + "//eval/public/testing:matchers", "//internal:status_macros", "//internal:testing", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/time", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], diff --git a/eval/eval/const_value_step.cc b/eval/eval/const_value_step.cc index 7a305cff0..f010abc7d 100644 --- a/eval/eval/const_value_step.cc +++ b/eval/eval/const_value_step.cc @@ -7,6 +7,7 @@ #include "absl/status/statusor.h" #include "eval/eval/expression_step_base.h" #include "eval/public/structs/cel_proto_wrapper.h" +#include "internal/proto_util.h" namespace google::api::expr::runtime { @@ -58,10 +59,12 @@ absl::optional ConvertConstant(const Constant* const_expr) { value = CelValue::CreateBytes(&const_expr->bytes_value()); break; case Constant::kDurationValue: - value = CelProtoWrapper::CreateDuration(&const_expr->duration_value()); + value = CelValue::CreateDuration( + expr::internal::DecodeDuration(const_expr->duration_value())); break; case Constant::kTimestampValue: - value = CelProtoWrapper::CreateTimestamp(&const_expr->timestamp_value()); + value = CelValue::CreateTimestamp( + expr::internal::DecodeTime(const_expr->timestamp_value())); break; default: // constant with no kind specified diff --git a/eval/eval/const_value_step_test.cc b/eval/eval/const_value_step_test.cc index b5f351309..fa339ea93 100644 --- a/eval/eval/const_value_step_test.cc +++ b/eval/eval/const_value_step_test.cc @@ -3,11 +3,15 @@ #include #include "google/api/expr/v1alpha1/syntax.pb.h" +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/timestamp.pb.h" #include "google/protobuf/descriptor.h" #include "absl/status/statusor.h" +#include "absl/time/time.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/test_type_registry.h" #include "eval/public/activation.h" +#include "eval/public/testing/matchers.h" #include "internal/status_macros.h" #include "internal/testing.h" @@ -17,8 +21,10 @@ namespace { using testing::Eq; -using google::api::expr::v1alpha1::Constant; -using google::api::expr::v1alpha1::Expr; +using ::google::api::expr::v1alpha1::Constant; +using ::google::api::expr::v1alpha1::Expr; +using ::google::protobuf::Duration; +using ::google::protobuf::Timestamp; using google::protobuf::Arena; @@ -162,6 +168,44 @@ TEST(ConstValueStepTest, TestEvaluationConstBytes) { EXPECT_THAT(value.BytesOrDie().value(), Eq("test")); } +TEST(ConstValueStepTest, TestEvaluationConstDuration) { + Expr expr; + auto const_expr = expr.mutable_const_expr(); + Duration* duration = const_expr->mutable_duration_value(); + duration->set_seconds(5); + duration->set_nanos(2000); + + google::protobuf::Arena arena; + + auto status = RunConstantExpression(&expr, const_expr, &arena); + + ASSERT_OK(status); + + auto value = status.value(); + + EXPECT_THAT(value, + test::IsCelDuration(absl::Seconds(5) + absl::Nanoseconds(2000))); +} + +TEST(ConstValueStepTest, TestEvaluationConstTimestamp) { + Expr expr; + auto const_expr = expr.mutable_const_expr(); + Timestamp* timestamp_proto = const_expr->mutable_timestamp_value(); + timestamp_proto->set_seconds(3600); + timestamp_proto->set_nanos(1000); + + google::protobuf::Arena arena; + + auto status = RunConstantExpression(&expr, const_expr, &arena); + + ASSERT_OK(status); + + auto value = status.value(); + + EXPECT_THAT(value, test::IsCelTimestamp(absl::FromUnixSeconds(3600) + + absl::Nanoseconds(1000))); +} + } // namespace } // namespace google::api::expr::runtime diff --git a/eval/public/BUILD b/eval/public/BUILD index e8266f651..718e41dd8 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -24,7 +24,11 @@ cc_library( "cel_value_internal.h", ], deps = [ + "//internal:casts", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/numeric:bits", "@com_google_absl//absl/types:variant", + "@com_google_protobuf//:protobuf", ], ) @@ -40,6 +44,7 @@ cc_library( ":cel_value_internal", "//base:memory_manager", "//extensions/protobuf:memory_manager", + "//internal:casts", "//internal:status_macros", "//internal:utf8", "@com_google_absl//absl/base:core_headers", @@ -165,11 +170,13 @@ cc_library( deps = [ ":cel_function", ":cel_function_registry", + ":cel_value", "//eval/public/structs:cel_proto_wrapper", "//internal:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", ], ) @@ -457,6 +464,7 @@ cc_test( ":unknown_set", "//base:memory_manager", "//eval/public/testing:matchers", + "//eval/testutil:test_message_cc_proto", "//extensions/protobuf:memory_manager", "//internal:status_macros", "//internal:testing", @@ -476,6 +484,7 @@ cc_test( ":cel_attribute", ":cel_value", "//eval/public/structs:cel_proto_wrapper", + "//internal:status_macros", "//internal:testing", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", diff --git a/eval/public/cel_attribute.cc b/eval/public/cel_attribute.cc index 893daf81d..917413022 100644 --- a/eval/public/cel_attribute.cc +++ b/eval/public/cel_attribute.cc @@ -8,10 +8,7 @@ #include "absl/types/variant.h" #include "eval/public/cel_value.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { namespace { // Visitation for attribute qualifier kinds @@ -45,7 +42,8 @@ class CelAttributeStringPrinter { public: // String representation for the given qualifier is appended to output. // output must be non-null. - explicit CelAttributeStringPrinter(std::string* output) : output_(*output) {} + explicit CelAttributeStringPrinter(std::string* output, CelValue::Type type) + : output_(*output), type_(type) {} absl::Status operator()(int64_t index) { absl::StrAppend(&output_, "[", index, "]"); @@ -72,12 +70,54 @@ class CelAttributeStringPrinter { // Attributes are represented as generic CelValues, but remaining kinds are // not legal attribute qualifiers. return absl::InvalidArgumentError(absl::StrCat( - "Unsupported attribute qualifier ", - CelValue::TypeName(CelValue::Type(CelValue::IndexOf::value)))); + "Unsupported attribute qualifier ", CelValue::TypeName(type_))); } private: std::string& output_; + CelValue::Type type_; +}; + +// Helper class, used to implement CelAttributeQualifier::operator==. +class EqualVisitor { + public: + template + class NestedEqualVisitor { + public: + explicit NestedEqualVisitor(const T& arg) : arg_(arg) {} + + template + bool operator()(const U&) const { + return false; + } + + bool operator()(const T& other) const { return other == arg_; } + + private: + const T& arg_; + }; + // Message wrapper is unsupported. Add specialization to make visitor + // compile. + template <> + class NestedEqualVisitor { + public: + explicit NestedEqualVisitor( + const CelValue::MessageWrapper&) {} + template + bool operator()(const U&) const { + return false; + } + }; + + explicit EqualVisitor(const CelValue& other) : other_(other) {} + + template + bool operator()(const Type& arg) { + return other_.template InternalVisit(NestedEqualVisitor(arg)); + } + + private: + const CelValue& other_; }; } // namespace @@ -127,14 +167,15 @@ const absl::StatusOr CelAttribute::AsString() const { std::string result = variable_.ident_expr().name(); for (const auto& qualifier : qualifier_path_) { - CEL_RETURN_IF_ERROR( - qualifier.Visit(CelAttributeStringPrinter(&result))); + CEL_RETURN_IF_ERROR(qualifier.Visit( + CelAttributeStringPrinter(&result, qualifier.type()))); } return result; } -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +bool CelAttributeQualifier::IsMatch(const CelValue& cel_value) const { + return value_.template InternalVisit(EqualVisitor(cel_value)); +} + +} // namespace google::api::expr::runtime diff --git a/eval/public/cel_attribute.h b/eval/public/cel_attribute.h index b05cead38..0e5523e0a 100644 --- a/eval/public/cel_attribute.h +++ b/eval/public/cel_attribute.h @@ -19,60 +19,19 @@ #include "eval/public/cel_value_internal.h" #include "internal/status_macros.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { // CelAttributeQualifier represents a segment in // attribute resolutuion path. A segment can be qualified by values of // following types: string/int64_t/uint64/bool. class CelAttributeQualifier { - private: - // Helper class, used to implement CelAttributeQualifier::operator==. - class EqualVisitor { - public: - template - class NestedEqualVisitor { - public: - explicit NestedEqualVisitor(const T& arg) : arg_(arg) {} - - template - bool operator()(const U&) const { - return false; - } - - bool operator()(const T& other) const { return other == arg_; } - - private: - const T& arg_; - }; - - explicit EqualVisitor(const CelValue& other) : other_(other) {} - - template - bool operator()(const Type& arg) { - return other_.template Visit(NestedEqualVisitor(arg)); - } - - private: - const CelValue& other_; - }; - - CelValue value_; - - explicit CelAttributeQualifier(CelValue value) : value_(value) {} - public: // Factory method. static CelAttributeQualifier Create(CelValue value) { return CelAttributeQualifier(value); } - template - T Visit(Op&& operation) const { - return value_.Visit(operation); - } + CelValue::Type type() const { return value_.type(); } // Family of Get... methods. Return values if requested type matches the // stored one. @@ -101,14 +60,23 @@ class CelAttributeQualifier { return IsMatch(other.value_); } - bool IsMatch(const CelValue& cel_value) const { - return value_.template Visit(EqualVisitor(cel_value)); - } + bool IsMatch(const CelValue& cel_value) const; bool IsMatch(absl::string_view other_key) const { absl::optional key = GetStringKey(); return (key.has_value() && key.value() == other_key); } + + private: + friend class CelAttribute; + explicit CelAttributeQualifier(CelValue value) : value_(value) {} + + template + T Visit(Op&& operation) const { + return value_.InternalVisit(operation); + } + + CelValue value_; }; // CelAttributeQualifierPattern matches a segment in @@ -119,7 +87,8 @@ class CelAttributeQualifierPattern { // Qualifier value. If not set, treated as wildcard. absl::optional value_; - CelAttributeQualifierPattern(absl::optional value) + explicit CelAttributeQualifierPattern( + absl::optional value) : value_(value) {} public: @@ -246,9 +215,6 @@ CelAttributePattern CreateCelAttributePattern( CelAttributeQualifierPattern>> path_spec = {}); -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_ATTRIBUTE_PATTERN_H_ diff --git a/eval/public/cel_attribute_test.cc b/eval/public/cel_attribute_test.cc index 2fb81f7a8..8b013c4fb 100644 --- a/eval/public/cel_attribute_test.cc +++ b/eval/public/cel_attribute_test.cc @@ -7,12 +7,10 @@ #include "absl/strings/string_view.h" #include "eval/public/cel_value.h" #include "eval/public/structs/cel_proto_wrapper.h" +#include "internal/status_macros.h" #include "internal/testing.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { namespace { using google::api::expr::v1alpha1::Expr; @@ -22,6 +20,7 @@ using ::google::protobuf::Timestamp; using testing::Eq; using testing::IsEmpty; using testing::SizeIs; +using cel::internal::StatusIs; class DummyMap : public CelMap { public: @@ -351,14 +350,29 @@ TEST(CelAttribute, AsStringInvalidRoot) { TEST(CelAttribute, InvalidQualifiers) { Expr expr; expr.mutable_ident_expr()->set_name("var"); + google::protobuf::Arena arena; - CelAttribute attr(expr, { - CelAttributeQualifier::Create( - CelValue::CreateDuration(absl::Minutes(2))), - }); - - EXPECT_EQ(attr.AsString().status().code(), - absl::StatusCode::kInvalidArgument); + CelAttribute attr1(expr, { + CelAttributeQualifier::Create( + CelValue::CreateDuration(absl::Minutes(2))), + }); + CelAttribute attr2(expr, + { + CelAttributeQualifier::Create( + CelProtoWrapper::CreateMessage(&expr, &arena)), + }); + + // Implementation detail: Messages as attribute qualifiers are unsupported, + // so the implementation treats them inequal to any other. This is included + // for coverage. + EXPECT_FALSE(attr1 == attr2); + EXPECT_FALSE(attr2 == attr1); + EXPECT_FALSE(attr2 == attr2); + + // If the attribute includes an unsupported qualifier, return invalid argument + // error. + EXPECT_THAT(attr1.AsString(), StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(attr2.AsString(), StatusIs(absl::StatusCode::kInvalidArgument)); } TEST(CelAttribute, AsStringQualiferTypes) { @@ -379,7 +393,4 @@ TEST(CelAttribute, AsStringQualiferTypes) { } } // namespace -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/public/cel_function_adapter.cc b/eval/public/cel_function_adapter.cc index ee82673c8..791abf3ed 100644 --- a/eval/public/cel_function_adapter.cc +++ b/eval/public/cel_function_adapter.cc @@ -7,12 +7,16 @@ namespace runtime { namespace internal { +template <> +absl::optional TypeCodeMatch() { + return CelValue::Type::kMessage; +} + template <> absl::optional TypeCodeMatch() { return CelValue::Type::kAny; } - } // namespace internal } // namespace runtime diff --git a/eval/public/cel_function_adapter.h b/eval/public/cel_function_adapter.h index 62bc4b733..d2eb0c9ee 100644 --- a/eval/public/cel_function_adapter.h +++ b/eval/public/cel_function_adapter.h @@ -5,11 +5,13 @@ #include #include +#include "google/protobuf/message.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "eval/public/cel_function.h" #include "eval/public/cel_function_registry.h" +#include "eval/public/cel_value.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "internal/status_macros.h" @@ -31,6 +33,9 @@ absl::optional TypeCodeMatch() { return arg_type; } +template <> +absl::optional TypeCodeMatch(); + // A bit of a trick - to pass Any kind of value, we use generic // CelValue parameters. template <> diff --git a/eval/public/cel_value.h b/eval/public/cel_value.h index 5a6442bb6..f626d51d6 100644 --- a/eval/public/cel_value.h +++ b/eval/public/cel_value.h @@ -34,6 +34,7 @@ #include "absl/types/variant.h" #include "base/memory_manager.h" #include "eval/public/cel_value_internal.h" +#include "internal/casts.h" #include "internal/status_macros.h" #include "internal/utf8.h" @@ -114,13 +115,21 @@ class CelValue { // absl::variant. using NullType = absl::monostate; + // MessageWrapper wraps a tagged MessageLite with the accessors used to + // get field values. + // + // message_ptr(): get the MessageLite pointer for the wrapper. + // + // HasFullProto(): returns whether it's safe to downcast to google::protobuf::Message. + using MessageWrapper = internal::MessageWrapper; + private: // CelError MUST BE the last in the declaration - it is a ceiling for Type // enum using ValueHolder = internal::ValueHolder< NullType, bool, int64_t, uint64_t, double, StringHolder, BytesHolder, - const google::protobuf::Message*, absl::Duration, absl::Time, const CelList*, - const CelMap*, const UnknownSet*, CelTypeHolder, const CelError*>; + MessageWrapper, absl::Duration, absl::Time, const CelList*, const CelMap*, + const UnknownSet*, CelTypeHolder, const CelError*>; public: // Metafunction providing positions corresponding to specific @@ -139,7 +148,7 @@ class CelValue { kDouble = IndexOf::value, kString = IndexOf::value, kBytes = IndexOf::value, - kMessage = IndexOf::value, + kMessage = IndexOf::value, kDuration = IndexOf::value, kTimestamp = IndexOf::value, kList = IndexOf::value, @@ -282,7 +291,10 @@ class CelValue { // Returns stored const Message* value. // Fails if stored value type is not const Message*. const google::protobuf::Message* MessageOrDie() const { - return GetValueOrDie(Type::kMessage); + MessageWrapper wrapped = GetValueOrDie(Type::kMessage); + ABSL_ASSERT(wrapped.HasFullProto()); + return cel::internal::down_cast( + wrapped.message_ptr()); } // Returns stored duration value. @@ -341,7 +353,7 @@ class CelValue { bool IsBytes() const { return value_.is(); } - bool IsMessage() const { return value_.is(); } + bool IsMessage() const { return value_.is(); } bool IsDuration() const { return value_.is(); } @@ -359,21 +371,56 @@ class CelValue { // Invokes op() with the active value, and returns the result. // All overloads of op() must have the same return type. + // Note: this depends on the internals of CelValue, so use with caution. + template + ReturnType InternalVisit(Op&& op) const { + return value_.template Visit(std::forward(op)); + } + + // Invokes op() with the active value, and returns the result. + // All overloads of op() must have the same return type. + // TODO(issues/5): Move to CelProtoWrapper to retain the assumed + // google::protobuf::Message variant version behavior for client code. template ReturnType Visit(Op&& op) const { - return value_.template Visit(op); + return value_.template Visit( + internal::MessageVisitAdapter(std::forward(op))); } // Template-style getter. // Returns true, if assignment successful template bool GetValue(Arg* value) const { - return this->template Visit(AssignerOp(value)); + return this->template InternalVisit(AssignerOp(value)); + } + + // Specialization for MessageWrapper to support legacy behavior while + // migrating off hard dependency on google::protobuf::Message. + // TODO(issues/5): Move to CelProtoWrapper. + template <> + bool GetValue(const google::protobuf::Message** value) const { + auto* held_value = value_.get(); + if (held_value == nullptr || !held_value->HasFullProto()) { + return false; + } + + *value = cel::internal::down_cast( + held_value->message_ptr()); + return true; } // Provides type names for internal logging. static std::string TypeName(Type value_type); + // Factory for message wrapper. This should only be used by internal + // libraries. + // TODO(issues/5): exposed for testing while wiring adapter APIs. Should + // make private visibility after refactors are done. + static CelValue CreateMessageWrapper(MessageWrapper value) { + CheckNullPointer(value.message_ptr(), Type::kMessage); + return CelValue(value); + } + private: ValueHolder value_; @@ -401,7 +448,11 @@ class CelValue { } bool operator()(NullType) const { return true; } - bool operator()(const google::protobuf::Message* arg) const { return arg == nullptr; } + // Note: this is not typically possible, but is supported for allowing + // function resolution for null ptrs as Messages. + bool operator()(const MessageWrapper& arg) const { + return arg.message_ptr() == nullptr; + } }; // Constructs CelValue wrapping value supplied as argument. @@ -413,13 +464,14 @@ class CelValue { // internal libraries. static CelValue CreateMessage(const google::protobuf::Message* value) { CheckNullPointer(value, Type::kMessage); - return CelValue(value); + return CelValue(MessageWrapper(value)); } // This is provided for backwards compatibility with resolving null to message // overloads. static CelValue CreateNullMessage() { - return CelValue(static_cast(nullptr)); + return CelValue( + MessageWrapper(static_cast(nullptr))); } // Crashes with a null pointer error. @@ -455,6 +507,7 @@ class CelValue { friend class CelProtoWrapper; friend class ProtoMessageTypeAdapter; friend class EvaluatorStack; + friend class TestOnly_FactoryAccessor; }; static_assert(absl::is_trivially_destructible::value, diff --git a/eval/public/cel_value_internal.h b/eval/public/cel_value_internal.h index b6654430d..52ad77ab1 100644 --- a/eval/public/cel_value_internal.h +++ b/eval/public/cel_value_internal.h @@ -17,7 +17,15 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_VALUE_INTERNAL_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_VALUE_INTERNAL_H_ +#include +#include + +#include "google/protobuf/message.h" +#include "google/protobuf/message_lite.h" +#include "absl/base/macros.h" +#include "absl/numeric/bits.h" #include "absl/types/variant.h" +#include "internal/casts.h" namespace google::api::expr::runtime::internal { @@ -75,6 +83,59 @@ class ValueHolder { absl::variant value_; }; +class MessageWrapper { + public: + static_assert(alignof(google::protobuf::MessageLite) >= 2, + "Assume that valid MessageLite ptrs have a free low-order bit"); + MessageWrapper() : message_ptr_(0) {} + explicit MessageWrapper(const google::protobuf::MessageLite* message) + : message_ptr_(reinterpret_cast(message)) { + ABSL_ASSERT(absl::countr_zero(reinterpret_cast(message)) >= 1); + } + + explicit MessageWrapper(const google::protobuf::Message* message) + : message_ptr_(reinterpret_cast(message) | kTagMask) { + ABSL_ASSERT(absl::countr_zero(reinterpret_cast(message)) >= 1); + } + + bool HasFullProto() const { return (message_ptr_ & kTagMask) == kTagMask; } + + const google::protobuf::MessageLite* message_ptr() const { + return reinterpret_cast(message_ptr_ & + kPtrMask); + } + + private: + static constexpr uintptr_t kTagMask = 1 << 0; + static constexpr uintptr_t kPtrMask = ~kTagMask; + uintptr_t message_ptr_; + // TODO(issues/5): add LegacyTypeAccessApis to expose generic accessors for + // MessageLite. +}; + +static_assert(sizeof(MessageWrapper) <= 2 * sizeof(uintptr_t), + "MessageWrapper must not increase CelValue size."); + +// Adapter for visitor clients that depend on google::protobuf::Message as a variant type. +template +struct MessageVisitAdapter { + explicit MessageVisitAdapter(Op&& op) : op(std::forward(op)) {} + + template + T operator()(const ArgT& arg) { + return op(arg); + } + + template <> + T operator()(const MessageWrapper& wrapper) { + ABSL_ASSERT(wrapper.HasFullProto()); + return op(cel::internal::down_cast( + wrapper.message_ptr())); + } + + Op op; +}; + } // namespace google::api::expr::runtime::internal #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_VALUE_INTERNAL_H_ diff --git a/eval/public/cel_value_test.cc b/eval/public/cel_value_test.cc index 89955f40d..537ebc20b 100644 --- a/eval/public/cel_value_test.cc +++ b/eval/public/cel_value_test.cc @@ -10,6 +10,7 @@ #include "eval/public/testing/matchers.h" #include "eval/public/unknown_attribute_set.h" #include "eval/public/unknown_set.h" +#include "eval/testutil/test_message.pb.h" #include "extensions/protobuf/memory_manager.h" #include "internal/status_macros.h" #include "internal/testing.h" @@ -375,4 +376,33 @@ TEST(CelValueTest, DebugString) { // List and map DebugString() test coverage is in cel_proto_wrapper_test.cc. } +TEST(CelValueTest, Message) { + TestMessage message; + auto value = + CelValue::CreateMessageWrapper(CelValue::MessageWrapper(&message)); + EXPECT_TRUE(value.IsMessage()); + CelValue::MessageWrapper held; + ASSERT_TRUE(value.GetValue(&held)); + EXPECT_TRUE(held.HasFullProto()); + EXPECT_EQ(held.message_ptr(), + static_cast(&message)); +} + +TEST(CelValueTest, MessageLite) { + TestMessage message; + // Upcast to message lite. + const google::protobuf::MessageLite* ptr = &message; + auto value = CelValue::CreateMessageWrapper(CelValue::MessageWrapper(ptr)); + EXPECT_TRUE(value.IsMessage()); + CelValue::MessageWrapper held; + ASSERT_TRUE(value.GetValue(&held)); + EXPECT_FALSE(held.HasFullProto()); + EXPECT_EQ(held.message_ptr(), &message); +} + +TEST(CelValueTest, Size) { + // CelValue performance degrades when it becomes larger. + static_assert(sizeof(CelValue) <= 3 * sizeof(uintptr_t)); +} + } // namespace google::api::expr::runtime diff --git a/eval/public/containers/BUILD b/eval/public/containers/BUILD index 2d78c8681..bec0dffdc 100644 --- a/eval/public/containers/BUILD +++ b/eval/public/containers/BUILD @@ -29,6 +29,7 @@ cc_library( deps = [ "//eval/public:cel_value", "//eval/public/structs:cel_proto_wrapper", + "//internal:casts", "//internal:overflow", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", diff --git a/eval/public/containers/field_access.cc b/eval/public/containers/field_access.cc index b7dcc7ead..d3019cda3 100644 --- a/eval/public/containers/field_access.cc +++ b/eval/public/containers/field_access.cc @@ -30,6 +30,7 @@ #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "eval/public/structs/cel_proto_wrapper.h" +#include "internal/casts.h" #include "internal/overflow.h" namespace google::api::expr::runtime { @@ -341,35 +342,6 @@ class MapValueAccessor : public FieldAccessor { const MapValueConstRef* value_ref_; }; -// Helper classes that should retrieve values from CelValue, -// when CelValue content inherits from Message. -template -class MessageRetriever { - public: - absl::optional operator()(const T&) const { return {}; } -}; - -// Partial specialization, valid when T is assignable to message -// -template -class MessageRetriever { - public: - absl::optional operator()(const T& arg) const { - const Message* msg = arg; - return msg; - } -}; - -class MessageRetrieverOp { - public: - template - absl::optional operator()(const T& arg) { - // Metaprogramming hacks... - return MessageRetriever::value>()( - arg); - } -}; - } // namespace absl::Status CreateValueFromSingleField(const google::protobuf::Message* msg, @@ -518,18 +490,14 @@ class FieldSetter { return true; } - // We attempt to retrieve value if it derives from google::protobuf::Message. - // That includes both generic Protobuf message types and specific - // message types stored in CelValue as separate entities. - auto value = cel_value.template Visit>( - MessageRetrieverOp()); - - if (!value.has_value()) { - return false; + if (CelValue::MessageWrapper wrapper; + cel_value.GetValue(&wrapper) && wrapper.HasFullProto()) { + static_cast(this)->SetMessage( + cel::internal::down_cast(wrapper.message_ptr())); + return true; } - static_cast(this)->SetMessage(value.value()); - return true; + return false; } // This method provides message field content, wrapped in CelValue. diff --git a/eval/public/set_util.cc b/eval/public/set_util.cc index 885d9031f..43c9e37a3 100644 --- a/eval/public/set_util.cc +++ b/eval/public/set_util.cc @@ -2,10 +2,7 @@ #include -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { namespace { // Default implementation is operator<. @@ -21,6 +18,21 @@ int ComparisonImpl(T lhs, T rhs) { } } +// Message wrapper specialization +template <> +int ComparisonImpl(CelValue::MessageWrapper lhs_wrapper, + CelValue::MessageWrapper rhs_wrapper) { + auto* lhs = lhs_wrapper.message_ptr(); + auto* rhs = rhs_wrapper.message_ptr(); + if (lhs < rhs) { + return -1; + } else if (lhs > rhs) { + return 1; + } else { + return 0; + } +} + // List specialization -- compare size then elementwise compare. template <> int ComparisonImpl(const CelList* lhs, const CelList* rhs) { @@ -88,7 +100,6 @@ int ComparisonImpl(const CelMap* lhs, const CelMap* rhs) { } struct ComparisonVisitor { - CelValue rhs; explicit ComparisonVisitor(CelValue rhs) : rhs(rhs) {} template int operator()(T lhs_value) { @@ -99,27 +110,26 @@ struct ComparisonVisitor { } return ComparisonImpl(lhs_value, rhs_value); } + + CelValue rhs; }; } // namespace int CelValueCompare(CelValue lhs, CelValue rhs) { - return lhs.Visit(ComparisonVisitor(rhs)); + return lhs.InternalVisit(ComparisonVisitor(rhs)); } bool CelValueLessThan(CelValue lhs, CelValue rhs) { - return lhs.Visit(ComparisonVisitor(rhs)) < 0; + return lhs.InternalVisit(ComparisonVisitor(rhs)) < 0; } bool CelValueEqual(CelValue lhs, CelValue rhs) { - return lhs.Visit(ComparisonVisitor(rhs)) == 0; + return lhs.InternalVisit(ComparisonVisitor(rhs)) == 0; } bool CelValueGreaterThan(CelValue lhs, CelValue rhs) { - return lhs.Visit(ComparisonVisitor(rhs)) > 0; + return lhs.InternalVisit(ComparisonVisitor(rhs)) > 0; } -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/public/testing/BUILD b/eval/public/testing/BUILD index b348a0bd3..b74539044 100644 --- a/eval/public/testing/BUILD +++ b/eval/public/testing/BUILD @@ -13,6 +13,7 @@ cc_library( "//eval/public:cel_value", "//eval/public:set_util", "//eval/public:unknown_set", + "//internal:casts", "//internal:testing", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", diff --git a/eval/public/testing/matchers.cc b/eval/public/testing/matchers.cc index d9e52c7fd..dc23827e9 100644 --- a/eval/public/testing/matchers.cc +++ b/eval/public/testing/matchers.cc @@ -2,10 +2,12 @@ #include +#include "google/protobuf/message.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "absl/strings/string_view.h" #include "eval/public/set_util.h" +#include "internal/casts.h" namespace google::api::expr::runtime { @@ -42,7 +44,7 @@ template class CelValueMatcherImpl : public testing::MatcherInterface { public: explicit CelValueMatcherImpl(testing::Matcher m) - : underlying_type_matcher_(m) {} + : underlying_type_matcher_(std::move(m)) {} bool MatchAndExplain(const CelValue& v, testing::MatchResultListener* listener) const override { UnderlyingType arg; @@ -60,6 +62,32 @@ class CelValueMatcherImpl : public testing::MatcherInterface { const testing::Matcher underlying_type_matcher_; }; +// Template specialization for google::protobuf::Message. +template <> +class CelValueMatcherImpl + : public testing::MatcherInterface { + public: + explicit CelValueMatcherImpl(testing::Matcher m) + : underlying_type_matcher_(std::move(m)) {} + bool MatchAndExplain(const CelValue& v, + testing::MatchResultListener* listener) const override { + CelValue::MessageWrapper arg; + return v.GetValue(&arg) && arg.HasFullProto() && + underlying_type_matcher_.Matches( + cel::internal::down_cast( + arg.message_ptr())); + } + + void DescribeTo(std::ostream* os) const override { + *os << absl::StrCat("type is ", + CelValue::TypeName(CelValue::Type::kMessage), " and "); + underlying_type_matcher_.DescribeTo(os); + } + + private: + const testing::Matcher underlying_type_matcher_; +}; + } // namespace CelValueMatcher EqualsCelValue(const CelValue& v) { From 7b66af431ebfa90bc9605196d3f6bc2c3ec8754f Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Tue, 29 Mar 2022 22:31:32 +0000 Subject: [PATCH 050/111] Make protobuf type provider thread-compatible. PiperOrigin-RevId: 438142521 --- eval/public/structs/BUILD | 4 +++- .../protobuf_descriptor_type_provider.cc | 18 +++++++++++------- .../protobuf_descriptor_type_provider.h | 4 +++- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/eval/public/structs/BUILD b/eval/public/structs/BUILD index 75ff1ec11..17f57c10f 100644 --- a/eval/public/structs/BUILD +++ b/eval/public/structs/BUILD @@ -164,11 +164,13 @@ cc_library( srcs = ["protobuf_descriptor_type_provider.cc"], hdrs = ["protobuf_descriptor_type_provider.h"], deps = [ + ":legacy_type_provider", ":proto_message_type_adapter", "//eval/public:cel_value", - "//eval/public/structs:legacy_type_provider", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf", ], diff --git a/eval/public/structs/protobuf_descriptor_type_provider.cc b/eval/public/structs/protobuf_descriptor_type_provider.cc index 65e7bc48d..214d84ee5 100644 --- a/eval/public/structs/protobuf_descriptor_type_provider.cc +++ b/eval/public/structs/protobuf_descriptor_type_provider.cc @@ -18,6 +18,7 @@ #include #include "google/protobuf/descriptor.h" +#include "absl/synchronization/mutex.h" #include "eval/public/cel_value.h" #include "eval/public/structs/proto_message_type_adapter.h" @@ -26,13 +27,16 @@ namespace google::api::expr::runtime { absl::optional ProtobufDescriptorProvider::ProvideLegacyType( absl::string_view name) const { const ProtoMessageTypeAdapter* result = nullptr; - auto it = type_cache_.find(name); - if (it != type_cache_.end()) { - result = it->second.get(); - } else { - auto type_provider = GetType(name); - result = type_provider.get(); - type_cache_[name] = std::move(type_provider); + { + absl::MutexLock lock(&mu_); + auto it = type_cache_.find(name); + if (it != type_cache_.end()) { + result = it->second.get(); + } else { + auto type_provider = GetType(name); + result = type_provider.get(); + type_cache_[name] = std::move(type_provider); + } } if (result == nullptr) { return absl::nullopt; diff --git a/eval/public/structs/protobuf_descriptor_type_provider.h b/eval/public/structs/protobuf_descriptor_type_provider.h index 4d745a1c7..1d0c3a669 100644 --- a/eval/public/structs/protobuf_descriptor_type_provider.h +++ b/eval/public/structs/protobuf_descriptor_type_provider.h @@ -21,6 +21,7 @@ #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" +#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" @@ -54,7 +55,8 @@ class ProtobufDescriptorProvider : public LegacyTypeProvider { ProtoWrapperTypeOptions unboxing_option_; mutable absl::flat_hash_map> - type_cache_; + type_cache_ ABSL_GUARDED_BY(mu_); + mutable absl::Mutex mu_; }; } // namespace google::api::expr::runtime From 4e9a250cc5caf528e89fbb0d195dfc006e32415b Mon Sep 17 00:00:00 2001 From: jcking Date: Wed, 30 Mar 2022 03:28:49 +0000 Subject: [PATCH 051/111] Internal change PiperOrigin-RevId: 438197345 --- base/BUILD | 2 + base/internal/value.post.h | 6 + base/internal/value.pre.h | 3 + base/type.cc | 1 + base/type.h | 6 +- base/type_test.cc | 5 + base/value.cc | 68 +++++++ base/value.h | 133 ++++++++++++ base/value_factory.h | 7 + base/value_test.cc | 407 +++++++++++++++++++++++++++++++++++++ 10 files changed, 637 insertions(+), 1 deletion(-) diff --git a/base/BUILD b/base/BUILD index e56c465ca..9b0131504 100644 --- a/base/BUILD +++ b/base/BUILD @@ -196,9 +196,11 @@ cc_test( "//internal:strings", "//internal:testing", "//internal:time", + "@com_google_absl//absl/hash", "@com_google_absl//absl/hash:hash_testing", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/time", ], ) diff --git a/base/internal/value.post.h b/base/internal/value.post.h index 197fb6c11..c3aa600ac 100644 --- a/base/internal/value.post.h +++ b/base/internal/value.post.h @@ -41,6 +41,11 @@ inline internal::TypeInfo GetEnumValueTypeId(const EnumValue& enum_value) { return enum_value.TypeId(); } +inline internal::TypeInfo GetStructValueTypeId( + const StructValue& struct_value) { + return struct_value.TypeId(); +} + // Implementation of BytesValue that is stored inlined within a handle. Since // absl::Cord is reference counted itself, this is more efficient than storing // this on the heap. @@ -660,6 +665,7 @@ CEL_INTERNAL_VALUE_DECL(StringValue); CEL_INTERNAL_VALUE_DECL(DurationValue); CEL_INTERNAL_VALUE_DECL(TimestampValue); CEL_INTERNAL_VALUE_DECL(EnumValue); +CEL_INTERNAL_VALUE_DECL(StructValue); #undef CEL_INTERNAL_VALUE_DECL } // namespace cel diff --git a/base/internal/value.pre.h b/base/internal/value.pre.h index ce8ec888c..88c7eefd4 100644 --- a/base/internal/value.pre.h +++ b/base/internal/value.pre.h @@ -27,6 +27,7 @@ namespace cel { class EnumValue; +class StructValue; namespace base_internal { @@ -48,6 +49,8 @@ inline constexpr uintptr_t kValueHandleMask = ~kValueHandleBits; internal::TypeInfo GetEnumValueTypeId(const EnumValue& enum_value); +internal::TypeInfo GetStructValueTypeId(const StructValue& struct_value); + class InlinedCordBytesValue; class InlinedStringViewBytesValue; class StringBytesValue; diff --git a/base/type.cc b/base/type.cc index cd578eaf1..dbaa8cada 100644 --- a/base/type.cc +++ b/base/type.cc @@ -46,6 +46,7 @@ CEL_INTERNAL_TYPE_IMPL(StringType); CEL_INTERNAL_TYPE_IMPL(DurationType); CEL_INTERNAL_TYPE_IMPL(TimestampType); CEL_INTERNAL_TYPE_IMPL(EnumType); +CEL_INTERNAL_TYPE_IMPL(StructType); CEL_INTERNAL_TYPE_IMPL(ListType); CEL_INTERNAL_TYPE_IMPL(MapType); #undef CEL_INTERNAL_TYPE_IMPL diff --git a/base/type.h b/base/type.h index c1d07f8a0..e619ced7c 100644 --- a/base/type.h +++ b/base/type.h @@ -65,6 +65,7 @@ class StringValue; class DurationValue; class TimestampValue; class EnumValue; +class StructValue; class ValueFactory; namespace internal { @@ -586,6 +587,7 @@ class StructType : public Type { private: friend class StructType; + friend class StructValue; absl::variant data_; }; @@ -602,7 +604,8 @@ class StructType : public Type { protected: StructType() = default; - // TODO(issues/5): NewInstance + virtual absl::StatusOr> NewInstance( + ValueFactory& value_factory) const = 0; // Called by FindField. virtual absl::StatusOr FindFieldByName( @@ -620,6 +623,7 @@ class StructType : public Type { friend struct FindFieldVisitor; friend class TypeFactory; friend class base_internal::TypeHandleBase; + friend class StructValue; // Called by base_internal::TypeHandleBase to implement Is for Transient and // Persistent. diff --git a/base/type_test.cc b/base/type_test.cc index c0366d4db..10f41caea 100644 --- a/base/type_test.cc +++ b/base/type_test.cc @@ -96,6 +96,11 @@ class TestStructType final : public StructType { absl::string_view name() const override { return "test_struct.TestStruct"; } protected: + absl::StatusOr> NewInstance( + ValueFactory& value_factory) const override { + return absl::UnimplementedError(""); + } + absl::StatusOr FindFieldByName(TypeManager& type_manager, absl::string_view name) const override { if (name == "bool_field") { diff --git a/base/value.cc b/base/value.cc index 0fdfc1d66..e28c0400e 100644 --- a/base/value.cc +++ b/base/value.cc @@ -63,6 +63,7 @@ CEL_INTERNAL_VALUE_IMPL(StringValue); CEL_INTERNAL_VALUE_IMPL(DurationValue); CEL_INTERNAL_VALUE_IMPL(TimestampValue); CEL_INTERNAL_VALUE_IMPL(EnumValue); +CEL_INTERNAL_VALUE_IMPL(StructValue); #undef CEL_INTERNAL_VALUE_IMPL namespace { @@ -805,6 +806,73 @@ void EnumValue::HashValue(absl::HashState state) const { absl::HashState::combine(std::move(state), type(), number()); } +struct StructValue::SetFieldVisitor final { + StructValue& struct_value; + const Persistent& value; + + absl::Status operator()(absl::string_view name) const { + return struct_value.SetFieldByName(name, value); + } + + absl::Status operator()(int64_t number) const { + return struct_value.SetFieldByNumber(number, value); + } +}; + +struct StructValue::GetFieldVisitor final { + const StructValue& struct_value; + ValueFactory& value_factory; + + absl::StatusOr> operator()( + absl::string_view name) const { + return struct_value.GetFieldByName(value_factory, name); + } + + absl::StatusOr> operator()(int64_t number) const { + return struct_value.GetFieldByNumber(value_factory, number); + } +}; + +struct StructValue::HasFieldVisitor final { + const StructValue& struct_value; + + absl::StatusOr operator()(absl::string_view name) const { + return struct_value.HasFieldByName(name); + } + + absl::StatusOr operator()(int64_t number) const { + return struct_value.HasFieldByNumber(number); + } +}; + +absl::StatusOr> StructValue::New( + const Persistent& struct_type, + ValueFactory& value_factory) { + CEL_ASSIGN_OR_RETURN(auto struct_value, + struct_type->NewInstance(value_factory)); + if (!struct_value->type_) { + // In case somebody is caching, we avoid setting the type_ if it has already + // been set, to avoid a race condition where one CPU sees a half written + // pointer. + const_cast(*struct_value).type_ = struct_type; + } + return struct_value; +} + +absl::Status StructValue::SetField(FieldId field, + const Persistent& value) { + return absl::visit(SetFieldVisitor{*this, value}, field.data_); +} + +absl::StatusOr> StructValue::GetField( + ValueFactory& value_factory, FieldId field) const { + return absl::visit(GetFieldVisitor{*this, value_factory}, field.data_); +} + +absl::StatusOr StructValue::HasField(FieldId field) const { + return absl::visit(HasFieldVisitor{*this}, field.data_); +} + namespace base_internal { absl::Cord InlinedCordBytesValue::ToCord(bool reference_counted) const { diff --git a/base/value.h b/base/value.h index 41cff1f09..f6234ef72 100644 --- a/base/value.h +++ b/base/value.h @@ -51,6 +51,7 @@ class StringValue; class DurationValue; class TimestampValue; class EnumValue; +class StructValue; class ValueFactory; namespace internal { @@ -84,6 +85,7 @@ class Value : public base_internal::Resource { friend class DurationValue; friend class TimestampValue; friend class EnumValue; + friend class StructValue; friend class base_internal::ValueHandleBase; friend class base_internal::StringBytesValue; friend class base_internal::ExternalDataBytesValue; @@ -678,6 +680,137 @@ class EnumValue : public Value { return ::cel::internal::TypeId(); \ } +// StructValue represents an instance of cel::StructType. +class StructValue : public Value { + public: + using FieldId = StructType::FieldId; + + static absl::StatusOr> New( + const Persistent& struct_type, + ValueFactory& value_factory); + + Transient type() const final { + ABSL_ASSERT(type_); + return type_; + } + + Kind kind() const final { return Kind::kStruct; } + + absl::Status SetField(FieldId field, const Persistent& value); + + absl::StatusOr> GetField(ValueFactory& value_factory, + FieldId field) const; + + absl::StatusOr HasField(FieldId field) const; + + protected: + StructValue() = default; + + virtual absl::Status SetFieldByName(absl::string_view name, + const Persistent& value) = 0; + + virtual absl::Status SetFieldByNumber( + int64_t number, const Persistent& value) = 0; + + virtual absl::StatusOr> GetFieldByName( + ValueFactory& value_factory, absl::string_view name) const = 0; + + virtual absl::StatusOr> GetFieldByNumber( + ValueFactory& value_factory, int64_t number) const = 0; + + virtual absl::StatusOr HasFieldByName(absl::string_view name) const = 0; + + virtual absl::StatusOr HasFieldByNumber(int64_t number) const = 0; + + private: + struct SetFieldVisitor; + struct GetFieldVisitor; + struct HasFieldVisitor; + + friend struct SetFieldVisitor; + friend struct GetFieldVisitor; + friend struct HasFieldVisitor; + friend internal::TypeInfo base_internal::GetStructValueTypeId( + const StructValue& struct_value); + template + friend class base_internal::ValueHandle; + friend class base_internal::ValueHandleBase; + + // Called by base_internal::ValueHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Value& value) { return value.kind() == Kind::kStruct; } + + StructValue(const StructValue&) = delete; + StructValue(StructValue&&) = delete; + + bool Equals(const Value& other) const override = 0; + void HashValue(absl::HashState state) const override = 0; + + std::pair SizeAndAlignment() const override = 0; + + // Called by CEL_IMPLEMENT_ENUM_VALUE() and Is() to perform type checking. + virtual internal::TypeInfo TypeId() const = 0; + + // Set lazily, by StructValue::New. + Persistent type_; +}; + +// CEL_DECLARE_STRUCT_VALUE declares `struct_value` as an struct value. It must +// be part of the class definition of `struct_value`. +// +// class MyStructValue : public cel::StructValue { +// ... +// private: +// CEL_DECLARE_STRUCT_VALUE(MyStructValue); +// }; +#define CEL_DECLARE_STRUCT_VALUE(struct_value) \ + private: \ + friend class ::cel::base_internal::ValueHandleBase; \ + \ + static bool Is(const ::cel::Value& value); \ + \ + ::std::pair<::std::size_t, ::std::size_t> SizeAndAlignment() const override; \ + \ + ::cel::internal::TypeInfo TypeId() const override; + +// CEL_IMPLEMENT_STRUCT_VALUE implements `struct_value` as an struct +// value. It must be called after the class definition of `struct_value`. +// +// class MyStructValue : public cel::StructValue { +// ... +// private: +// CEL_DECLARE_STRUCT_VALUE(MyStructValue); +// }; +// +// CEL_IMPLEMENT_STRUCT_VALUE(MyStructValue); +#define CEL_IMPLEMENT_STRUCT_VALUE(struct_value) \ + static_assert(::std::is_base_of_v<::cel::StructValue, struct_value>, \ + #struct_value " must inherit from cel::StructValue"); \ + static_assert(!::std::is_abstract_v, \ + "this must not be abstract"); \ + \ + bool struct_value::Is(const ::cel::Value& value) { \ + return value.kind() == ::cel::Kind::kStruct && \ + ::cel::base_internal::GetStructValueTypeId( \ + ::cel::internal::down_cast( \ + value)) == ::cel::internal::TypeId(); \ + } \ + \ + ::std::pair<::std::size_t, ::std::size_t> struct_value::SizeAndAlignment() \ + const { \ + static_assert( \ + ::std::is_same_v>>, \ + "this must be the same as " #struct_value); \ + return ::std::pair<::std::size_t, ::std::size_t>(sizeof(struct_value), \ + alignof(struct_value)); \ + } \ + \ + ::cel::internal::TypeInfo struct_value::TypeId() const { \ + return ::cel::internal::TypeId(); \ + } + } // namespace cel // value.pre.h forward declares types so they can be friended above. The types diff --git a/base/value_factory.h b/base/value_factory.h index e5321b733..cc072fe80 100644 --- a/base/value_factory.h +++ b/base/value_factory.h @@ -145,6 +145,13 @@ class ValueFactory final { std::remove_const_t>(memory_manager(), std::forward(args)...); } + template + EnableIfBaseOfT>> + CreateStructValue(Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { + return base_internal::PersistentHandleFactory::template Make< + std::remove_const_t>(memory_manager(), std::forward(args)...); + } + private: friend class BytesValue; friend class StringValue; diff --git a/base/value_test.cc b/base/value_test.cc index 75a991544..46a4680bb 100644 --- a/base/value_test.cc +++ b/base/value_test.cc @@ -22,13 +22,16 @@ #include #include +#include "absl/hash/hash.h" #include "absl/hash/hash_testing.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/time/time.h" #include "base/memory_manager.h" #include "base/type.h" #include "base/type_factory.h" +#include "base/type_manager.h" #include "base/value_factory.h" #include "internal/strings.h" #include "internal/testing.h" @@ -37,6 +40,8 @@ namespace cel { namespace { +using testing::Eq; +using cel::internal::IsOkAndHolds; using cel::internal::StatusIs; enum class TestEnum { @@ -120,6 +125,224 @@ class TestEnumType final : public EnumType { CEL_IMPLEMENT_ENUM_TYPE(TestEnumType); +struct TestStruct final { + bool bool_field = false; + int64_t int_field = 0; + uint64_t uint_field = 0; + double double_field = 0.0; +}; + +bool operator==(const TestStruct& lhs, const TestStruct& rhs) { + return lhs.bool_field == rhs.bool_field && lhs.int_field == rhs.int_field && + lhs.uint_field == rhs.uint_field && + lhs.double_field == rhs.double_field; +} + +template +H AbslHashValue(H state, const TestStruct& test_struct) { + return H::combine(std::move(state), test_struct.bool_field, + test_struct.int_field, test_struct.uint_field, + test_struct.double_field); +} + +class TestStructValue final : public StructValue { + public: + explicit TestStructValue(TestStruct value) : value_(std::move(value)) {} + + std::string DebugString() const override { + return absl::StrCat("bool_field: ", value().bool_field, + " int_field: ", value().int_field, + " uint_field: ", value().uint_field, + " double_field: ", value().double_field); + } + + const TestStruct& value() const { return value_; } + + protected: + absl::Status SetFieldByName(absl::string_view name, + const Persistent& value) override { + if (name == "bool_field") { + if (!value.Is()) { + return absl::InvalidArgumentError(""); + } + value_.bool_field = value.As()->value(); + } else if (name == "int_field") { + if (!value.Is()) { + return absl::InvalidArgumentError(""); + } + value_.int_field = value.As()->value(); + } else if (name == "uint_field") { + if (!value.Is()) { + return absl::InvalidArgumentError(""); + } + value_.uint_field = value.As()->value(); + } else if (name == "double_field") { + if (!value.Is()) { + return absl::InvalidArgumentError(""); + } + value_.double_field = value.As()->value(); + } else { + return absl::NotFoundError(""); + } + return absl::OkStatus(); + } + + absl::Status SetFieldByNumber(int64_t number, + const Persistent& value) override { + switch (number) { + case 0: + if (!value.Is()) { + return absl::InvalidArgumentError(""); + } + value_.bool_field = value.As()->value(); + break; + case 1: + if (!value.Is()) { + return absl::InvalidArgumentError(""); + } + value_.int_field = value.As()->value(); + break; + case 2: + if (!value.Is()) { + return absl::InvalidArgumentError(""); + } + value_.uint_field = value.As()->value(); + break; + case 3: + if (!value.Is()) { + return absl::InvalidArgumentError(""); + } + value_.double_field = value.As()->value(); + break; + default: + return absl::NotFoundError(""); + } + return absl::OkStatus(); + } + + absl::StatusOr> GetFieldByName( + ValueFactory& value_factory, absl::string_view name) const override { + if (name == "bool_field") { + return value_factory.CreateBoolValue(value().bool_field); + } else if (name == "int_field") { + return value_factory.CreateIntValue(value().int_field); + } else if (name == "uint_field") { + return value_factory.CreateUintValue(value().uint_field); + } else if (name == "double_field") { + return value_factory.CreateDoubleValue(value().double_field); + } + return absl::NotFoundError(""); + } + + absl::StatusOr> GetFieldByNumber( + ValueFactory& value_factory, int64_t number) const override { + switch (number) { + case 0: + return value_factory.CreateBoolValue(value().bool_field); + case 1: + return value_factory.CreateIntValue(value().int_field); + case 2: + return value_factory.CreateUintValue(value().uint_field); + case 3: + return value_factory.CreateDoubleValue(value().double_field); + default: + return absl::NotFoundError(""); + } + } + + absl::StatusOr HasFieldByName(absl::string_view name) const override { + if (name == "bool_field") { + return true; + } else if (name == "int_field") { + return true; + } else if (name == "uint_field") { + return true; + } else if (name == "double_field") { + return true; + } + return absl::NotFoundError(""); + } + + absl::StatusOr HasFieldByNumber(int64_t number) const override { + switch (number) { + case 0: + return true; + case 1: + return true; + case 2: + return true; + case 3: + return true; + default: + return absl::NotFoundError(""); + } + } + + private: + bool Equals(const Value& other) const override { + return Is(other) && + value() == static_cast(other).value(); + } + + void HashValue(absl::HashState state) const override { + absl::HashState::combine(std::move(state), type(), value()); + } + + TestStruct value_; + + CEL_DECLARE_STRUCT_VALUE(TestStructValue); +}; + +CEL_IMPLEMENT_STRUCT_VALUE(TestStructValue); + +class TestStructType final : public StructType { + public: + using StructType::StructType; + + absl::string_view name() const override { return "test_struct.TestStruct"; } + + protected: + absl::StatusOr> NewInstance( + ValueFactory& value_factory) const override { + return value_factory.CreateStructValue(TestStruct{}); + } + + absl::StatusOr FindFieldByName(TypeManager& type_manager, + absl::string_view name) const override { + if (name == "bool_field") { + return Field("bool_field", 0, type_manager.GetBoolType()); + } else if (name == "int_field") { + return Field("int_field", 1, type_manager.GetIntType()); + } else if (name == "uint_field") { + return Field("uint_field", 2, type_manager.GetUintType()); + } else if (name == "double_field") { + return Field("double_field", 3, type_manager.GetDoubleType()); + } + return absl::NotFoundError(""); + } + + absl::StatusOr FindFieldByNumber(TypeManager& type_manager, + int64_t number) const override { + switch (number) { + case 0: + return Field("bool_field", 0, type_manager.GetBoolType()); + case 1: + return Field("int_field", 1, type_manager.GetIntType()); + case 2: + return Field("uint_field", 2, type_manager.GetUintType()); + case 3: + return Field("double_field", 3, type_manager.GetDoubleType()); + default: + return absl::NotFoundError(""); + } + } + + private: + CEL_DECLARE_STRUCT_TYPE(TestStructType); +}; + +CEL_IMPLEMENT_STRUCT_TYPE(TestStructType); + template Persistent Must(absl::StatusOr> status_or_handle) { return std::move(status_or_handle).value(); @@ -1505,14 +1728,197 @@ TEST(EnumType, NewInstance) { StatusIs(absl::StatusCode::kNotFound)); } +TEST(Value, Struct) { + ValueFactory value_factory(MemoryManager::Global()); + TypeFactory type_factory(MemoryManager::Global()); + ASSERT_OK_AND_ASSIGN(auto struct_type, + type_factory.CreateStructType()); + ASSERT_OK_AND_ASSIGN(auto zero_value, + StructValue::New(struct_type, value_factory)); + EXPECT_TRUE(zero_value.Is()); + EXPECT_TRUE(zero_value.Is()); + EXPECT_FALSE(zero_value.Is()); + EXPECT_EQ(zero_value, zero_value); + EXPECT_EQ(zero_value, Must(StructValue::New(struct_type, value_factory))); + EXPECT_EQ(zero_value->kind(), Kind::kStruct); + EXPECT_EQ(zero_value->type(), struct_type); + EXPECT_EQ(zero_value.As()->value(), TestStruct{}); + + ASSERT_OK_AND_ASSIGN(auto one_value, + StructValue::New(struct_type, value_factory)); + ASSERT_OK(one_value->SetField(StructValue::FieldId("bool_field"), + value_factory.CreateBoolValue(true))); + ASSERT_OK(one_value->SetField(StructValue::FieldId("int_field"), + value_factory.CreateIntValue(1))); + ASSERT_OK(one_value->SetField(StructValue::FieldId("uint_field"), + value_factory.CreateUintValue(1))); + ASSERT_OK(one_value->SetField(StructValue::FieldId("double_field"), + value_factory.CreateDoubleValue(1.0))); + EXPECT_TRUE(one_value.Is()); + EXPECT_TRUE(one_value.Is()); + EXPECT_FALSE(one_value.Is()); + EXPECT_EQ(one_value, one_value); + EXPECT_EQ(one_value->kind(), Kind::kStruct); + EXPECT_EQ(one_value->type(), struct_type); + EXPECT_EQ(one_value.As()->value(), + (TestStruct{true, 1, 1, 1.0})); + + EXPECT_NE(zero_value, one_value); + EXPECT_NE(one_value, zero_value); +} + +TEST(StructValue, SetField) { + ValueFactory value_factory(MemoryManager::Global()); + TypeFactory type_factory(MemoryManager::Global()); + ASSERT_OK_AND_ASSIGN(auto struct_type, + type_factory.CreateStructType()); + ASSERT_OK_AND_ASSIGN(auto struct_value, + StructValue::New(struct_type, value_factory)); + EXPECT_OK(struct_value->SetField(StructValue::FieldId("bool_field"), + value_factory.CreateBoolValue(true))); + EXPECT_THAT( + struct_value->GetField(value_factory, StructValue::FieldId("bool_field")), + IsOkAndHolds(Eq(value_factory.CreateBoolValue(true)))); + EXPECT_OK(struct_value->SetField(StructValue::FieldId(0), + value_factory.CreateBoolValue(false))); + EXPECT_THAT(struct_value->GetField(value_factory, StructValue::FieldId(0)), + IsOkAndHolds(Eq(value_factory.CreateBoolValue(false)))); + EXPECT_OK(struct_value->SetField(StructValue::FieldId("int_field"), + value_factory.CreateIntValue(1))); + EXPECT_THAT( + struct_value->GetField(value_factory, StructValue::FieldId("int_field")), + IsOkAndHolds(Eq(value_factory.CreateIntValue(1)))); + EXPECT_OK(struct_value->SetField(StructValue::FieldId(1), + value_factory.CreateIntValue(0))); + EXPECT_THAT(struct_value->GetField(value_factory, StructValue::FieldId(1)), + IsOkAndHolds(Eq(value_factory.CreateIntValue(0)))); + EXPECT_OK(struct_value->SetField(StructValue::FieldId("uint_field"), + value_factory.CreateUintValue(1))); + EXPECT_THAT( + struct_value->GetField(value_factory, StructValue::FieldId("uint_field")), + IsOkAndHolds(Eq(value_factory.CreateUintValue(1)))); + EXPECT_OK(struct_value->SetField(StructValue::FieldId(2), + value_factory.CreateUintValue(0))); + EXPECT_THAT(struct_value->GetField(value_factory, StructValue::FieldId(2)), + IsOkAndHolds(Eq(value_factory.CreateUintValue(0)))); + EXPECT_OK(struct_value->SetField(StructValue::FieldId("double_field"), + value_factory.CreateDoubleValue(1.0))); + EXPECT_THAT(struct_value->GetField(value_factory, + StructValue::FieldId("double_field")), + IsOkAndHolds(Eq(value_factory.CreateDoubleValue(1.0)))); + EXPECT_OK(struct_value->SetField(StructValue::FieldId(3), + value_factory.CreateDoubleValue(0.0))); + EXPECT_THAT(struct_value->GetField(value_factory, StructValue::FieldId(3)), + IsOkAndHolds(Eq(value_factory.CreateDoubleValue(0.0)))); + + EXPECT_THAT(struct_value->SetField(StructValue::FieldId("bool_field"), + value_factory.GetNullValue()), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(struct_value->SetField(StructValue::FieldId(0), + value_factory.GetNullValue()), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(struct_value->SetField(StructValue::FieldId("int_field"), + value_factory.GetNullValue()), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(struct_value->SetField(StructValue::FieldId(1), + value_factory.GetNullValue()), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(struct_value->SetField(StructValue::FieldId("uint_field"), + value_factory.GetNullValue()), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(struct_value->SetField(StructValue::FieldId(2), + value_factory.GetNullValue()), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(struct_value->SetField(StructValue::FieldId("double_field"), + value_factory.GetNullValue()), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(struct_value->SetField(StructValue::FieldId(3), + value_factory.GetNullValue()), + StatusIs(absl::StatusCode::kInvalidArgument)); + + EXPECT_THAT(struct_value->SetField(StructValue::FieldId("missing_field"), + value_factory.GetNullValue()), + StatusIs(absl::StatusCode::kNotFound)); + EXPECT_THAT(struct_value->SetField(StructValue::FieldId(4), + value_factory.GetNullValue()), + StatusIs(absl::StatusCode::kNotFound)); +} + +TEST(StructValue, GetField) { + ValueFactory value_factory(MemoryManager::Global()); + TypeFactory type_factory(MemoryManager::Global()); + ASSERT_OK_AND_ASSIGN(auto struct_type, + type_factory.CreateStructType()); + ASSERT_OK_AND_ASSIGN(auto struct_value, + StructValue::New(struct_type, value_factory)); + EXPECT_THAT( + struct_value->GetField(value_factory, StructValue::FieldId("bool_field")), + IsOkAndHolds(Eq(value_factory.CreateBoolValue(false)))); + EXPECT_THAT(struct_value->GetField(value_factory, StructValue::FieldId(0)), + IsOkAndHolds(Eq(value_factory.CreateBoolValue(false)))); + EXPECT_THAT( + struct_value->GetField(value_factory, StructValue::FieldId("int_field")), + IsOkAndHolds(Eq(value_factory.CreateIntValue(0)))); + EXPECT_THAT(struct_value->GetField(value_factory, StructValue::FieldId(1)), + IsOkAndHolds(Eq(value_factory.CreateIntValue(0)))); + EXPECT_THAT( + struct_value->GetField(value_factory, StructValue::FieldId("uint_field")), + IsOkAndHolds(Eq(value_factory.CreateUintValue(0)))); + EXPECT_THAT(struct_value->GetField(value_factory, StructValue::FieldId(2)), + IsOkAndHolds(Eq(value_factory.CreateUintValue(0)))); + EXPECT_THAT(struct_value->GetField(value_factory, + StructValue::FieldId("double_field")), + IsOkAndHolds(Eq(value_factory.CreateDoubleValue(0.0)))); + EXPECT_THAT(struct_value->GetField(value_factory, StructValue::FieldId(3)), + IsOkAndHolds(Eq(value_factory.CreateDoubleValue(0.0)))); + EXPECT_THAT(struct_value->GetField(value_factory, + StructValue::FieldId("missing_field")), + StatusIs(absl::StatusCode::kNotFound)); + EXPECT_THAT(struct_value->HasField(StructValue::FieldId(4)), + StatusIs(absl::StatusCode::kNotFound)); +} + +TEST(StructValue, HasField) { + ValueFactory value_factory(MemoryManager::Global()); + TypeFactory type_factory(MemoryManager::Global()); + ASSERT_OK_AND_ASSIGN(auto struct_type, + type_factory.CreateStructType()); + ASSERT_OK_AND_ASSIGN(auto struct_value, + StructValue::New(struct_type, value_factory)); + EXPECT_THAT(struct_value->HasField(StructValue::FieldId("bool_field")), + IsOkAndHolds(true)); + EXPECT_THAT(struct_value->HasField(StructValue::FieldId(0)), + IsOkAndHolds(true)); + EXPECT_THAT(struct_value->HasField(StructValue::FieldId("int_field")), + IsOkAndHolds(true)); + EXPECT_THAT(struct_value->HasField(StructValue::FieldId(1)), + IsOkAndHolds(true)); + EXPECT_THAT(struct_value->HasField(StructValue::FieldId("uint_field")), + IsOkAndHolds(true)); + EXPECT_THAT(struct_value->HasField(StructValue::FieldId(2)), + IsOkAndHolds(true)); + EXPECT_THAT(struct_value->HasField(StructValue::FieldId("double_field")), + IsOkAndHolds(true)); + EXPECT_THAT(struct_value->HasField(StructValue::FieldId(3)), + IsOkAndHolds(true)); + EXPECT_THAT(struct_value->HasField(StructValue::FieldId("missing_field")), + StatusIs(absl::StatusCode::kNotFound)); + EXPECT_THAT(struct_value->HasField(StructValue::FieldId(4)), + StatusIs(absl::StatusCode::kNotFound)); +} + TEST(Value, SupportsAbslHash) { ValueFactory value_factory(MemoryManager::Global()); TypeFactory type_factory(MemoryManager::Global()); ASSERT_OK_AND_ASSIGN(auto enum_type, type_factory.CreateEnumType()); + ASSERT_OK_AND_ASSIGN(auto struct_type, + type_factory.CreateStructType()); ASSERT_OK_AND_ASSIGN( auto enum_value, EnumValue::New(enum_type, value_factory, EnumType::ConstantId("VALUE1"))); + ASSERT_OK_AND_ASSIGN(auto struct_value, + StructValue::New(struct_type, value_factory)); EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly({ Persistent(value_factory.GetNullValue()), Persistent( @@ -1534,6 +1940,7 @@ TEST(Value, SupportsAbslHash) { Persistent( Must(value_factory.CreateStringValue(absl::Cord("bar")))), Persistent(enum_value), + Persistent(struct_value), })); } From f94ff294820cf10acfa36ec2e7e8962ef4958458 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Fri, 1 Apr 2022 20:31:55 +0000 Subject: [PATCH 052/111] Make map lookup error actually say which key wasn't found. Currently the error just says "Key not found in map" twice. PiperOrigin-RevId: 438898661 --- eval/eval/container_access_step.cc | 5 ++--- eval/eval/container_access_step_test.cc | 5 +++++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/eval/eval/container_access_step.cc b/eval/eval/container_access_step.cc index cc0bdcb66..576508422 100644 --- a/eval/eval/container_access_step.cc +++ b/eval/eval/container_access_step.cc @@ -64,8 +64,7 @@ inline CelValue ContainerAccessStep::LookupInMap(const CelMap* cel_map, return *maybe_value; } } - return CreateNoSuchKeyError(frame->memory_manager(), - "Key not found in map"); + return CreateNoSuchKeyError(frame->memory_manager(), key.DebugString()); } } @@ -78,7 +77,7 @@ inline CelValue ContainerAccessStep::LookupInMap(const CelMap* cel_map, return maybe_value.value(); } - return CreateNoSuchKeyError(frame->memory_manager(), "Key not found in map"); + return CreateNoSuchKeyError(frame->memory_manager(), key.DebugString()); } inline CelValue ContainerAccessStep::LookupInList(const CelList* cel_list, diff --git a/eval/eval/container_access_step_test.cc b/eval/eval/container_access_step_test.cc index 89ce881e2..7f04f1f30 100644 --- a/eval/eval/container_access_step_test.cc +++ b/eval/eval/container_access_step_test.cc @@ -36,6 +36,7 @@ using ::google::api::expr::v1alpha1::Expr; using ::google::api::expr::v1alpha1::SourceInfo; using ::google::protobuf::Struct; using testing::_; +using testing::AllOf; using testing::HasSubstr; using cel::internal::StatusIs; @@ -201,6 +202,10 @@ TEST_P(ContainerAccessStepUniformityTest, TestMapKeyAccessNotFound) { CelValue::CreateString(&kKey1), std::get<0>(param), std::get<1>(param)); ASSERT_TRUE(result.IsError()); + EXPECT_THAT(*result.ErrorOrDie(), + StatusIs(absl::StatusCode::kNotFound, + AllOf(HasSubstr("Key not found in map : "), + HasSubstr("testkey1")))); } TEST_F(ContainerAccessStepTest, TestInvalidReceiverCreateContainerAccessStep) { From 0f50752745b125f1f122bfeb65fd07a839bae880 Mon Sep 17 00:00:00 2001 From: tswadell Date: Fri, 1 Apr 2022 22:47:48 +0000 Subject: [PATCH 053/111] Make map lookup error actually say which key wasn't found. Currently the error just says "Key not found in map" twice. PiperOrigin-RevId: 438928479 --- eval/eval/container_access_step.cc | 5 +++-- eval/eval/container_access_step_test.cc | 5 ----- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/eval/eval/container_access_step.cc b/eval/eval/container_access_step.cc index 576508422..cc0bdcb66 100644 --- a/eval/eval/container_access_step.cc +++ b/eval/eval/container_access_step.cc @@ -64,7 +64,8 @@ inline CelValue ContainerAccessStep::LookupInMap(const CelMap* cel_map, return *maybe_value; } } - return CreateNoSuchKeyError(frame->memory_manager(), key.DebugString()); + return CreateNoSuchKeyError(frame->memory_manager(), + "Key not found in map"); } } @@ -77,7 +78,7 @@ inline CelValue ContainerAccessStep::LookupInMap(const CelMap* cel_map, return maybe_value.value(); } - return CreateNoSuchKeyError(frame->memory_manager(), key.DebugString()); + return CreateNoSuchKeyError(frame->memory_manager(), "Key not found in map"); } inline CelValue ContainerAccessStep::LookupInList(const CelList* cel_list, diff --git a/eval/eval/container_access_step_test.cc b/eval/eval/container_access_step_test.cc index 7f04f1f30..89ce881e2 100644 --- a/eval/eval/container_access_step_test.cc +++ b/eval/eval/container_access_step_test.cc @@ -36,7 +36,6 @@ using ::google::api::expr::v1alpha1::Expr; using ::google::api::expr::v1alpha1::SourceInfo; using ::google::protobuf::Struct; using testing::_; -using testing::AllOf; using testing::HasSubstr; using cel::internal::StatusIs; @@ -202,10 +201,6 @@ TEST_P(ContainerAccessStepUniformityTest, TestMapKeyAccessNotFound) { CelValue::CreateString(&kKey1), std::get<0>(param), std::get<1>(param)); ASSERT_TRUE(result.IsError()); - EXPECT_THAT(*result.ErrorOrDie(), - StatusIs(absl::StatusCode::kNotFound, - AllOf(HasSubstr("Key not found in map : "), - HasSubstr("testkey1")))); } TEST_F(ContainerAccessStepTest, TestInvalidReceiverCreateContainerAccessStep) { From 8e0e91bf3852a4d350e1c38fc03194c72d8636b4 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Fri, 1 Apr 2022 23:16:01 +0000 Subject: [PATCH 054/111] Seperate proto-specific function adapter code from core implementation. Update expr:runtime::FunctionAdapter to be an alias of the proto-message enabled helper code. PiperOrigin-RevId: 438933702 --- eval/public/BUILD | 21 +- eval/public/cel_function_adapter.cc | 25 -- eval/public/cel_function_adapter.h | 339 ++++++------------------ eval/public/cel_function_adapter_impl.h | 310 ++++++++++++++++++++++ 4 files changed, 404 insertions(+), 291 deletions(-) delete mode 100644 eval/public/cel_function_adapter.cc create mode 100644 eval/public/cel_function_adapter_impl.h diff --git a/eval/public/BUILD b/eval/public/BUILD index 718e41dd8..1e0c64391 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -160,15 +160,30 @@ cc_library( ) cc_library( - name = "cel_function_adapter", - srcs = [ - "cel_function_adapter.cc", + name = "cel_function_adapter_impl", + hdrs = [ + "cel_function_adapter_impl.h", ], + deps = [ + ":cel_function", + ":cel_function_registry", + ":cel_value", + "//internal:status_macros", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "cel_function_adapter", hdrs = [ "cel_function_adapter.h", ], deps = [ ":cel_function", + ":cel_function_adapter_impl", ":cel_function_registry", ":cel_value", "//eval/public/structs:cel_proto_wrapper", diff --git a/eval/public/cel_function_adapter.cc b/eval/public/cel_function_adapter.cc deleted file mode 100644 index 791abf3ed..000000000 --- a/eval/public/cel_function_adapter.cc +++ /dev/null @@ -1,25 +0,0 @@ -#include "eval/public/cel_function_adapter.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { - -namespace internal { - -template <> -absl::optional TypeCodeMatch() { - return CelValue::Type::kMessage; -} - -template <> -absl::optional TypeCodeMatch() { - return CelValue::Type::kAny; -} - -} // namespace internal - -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google diff --git a/eval/public/cel_function_adapter.h b/eval/public/cel_function_adapter.h index d2eb0c9ee..9c5bdb18e 100644 --- a/eval/public/cel_function_adapter.h +++ b/eval/public/cel_function_adapter.h @@ -10,6 +10,7 @@ #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "eval/public/cel_function.h" +#include "eval/public/cel_function_adapter_impl.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_value.h" #include "eval/public/structs/cel_proto_wrapper.h" @@ -19,60 +20,75 @@ namespace google::api::expr::runtime { namespace internal { -// TypeCodeMatch template function family -// Used for CEL type deduction based on C++ native -// type. -template -absl::optional TypeCodeMatch() { - int index = CelValue::IndexOf::value; - if (index < 0) return {}; - CelValue::Type arg_type = static_cast(index); - if (arg_type >= CelValue::Type::kAny) { - return {}; +// A type code matcher that adds support for google::protobuf::Message. +struct ProtoAdapterTypeCodeMatcher { + template + constexpr absl::optional type_code() { + return internal::TypeCodeMatcher().type_code(); } - return arg_type; -} - -template <> -absl::optional TypeCodeMatch(); - -// A bit of a trick - to pass Any kind of value, we use generic -// CelValue parameters. -template <> -absl::optional TypeCodeMatch(); - -template -bool AddType(std::vector*) { - return true; -} -// AddType template method -// Appends CEL type constant deduced from C++ type Type to descriptor -template -bool AddType(std::vector* arg_types) { - auto kind = TypeCodeMatch(); - if (!kind) { - return false; + template <> + constexpr absl::optional type_code() { + return CelValue::Type::kMessage; } +}; - arg_types->push_back(kind.value()); - - return AddType(arg_types); +// A value converter that handles wrapping google::protobuf::Messages as CelValues. +struct ProtoAdapterValueConverter + : public internal::ValueConverterBase { + using BaseType = internal::ValueConverterBase; + using BaseType::NativeToValue; + using BaseType::ValueToNative; - return true; -} + absl::Status NativeToValue(const ::google::protobuf::Message* value, + ::google::protobuf::Arena* arena, CelValue* result) { + if (value == nullptr) { + return absl::Status(absl::StatusCode::kInvalidArgument, + "Null Message pointer returned"); + } + *result = CelProtoWrapper::CreateMessage(value, arena); + return absl::OkStatus(); + } +}; +// Internal alias for message enabled function adapter. +// TODO(issues/5): follow-up will introduce lite proto (via +// CelValue::MessageWrapper) equivalent. +template +using ProtoMessageFunctionAdapter = + internal::FunctionAdapter; } // namespace internal // FunctionAdapter is a helper class that simplifies creation of CelFunction // implementations. -// It accepts method implementations as std::function, allowing -// them to be lambdas/regular C++ functions. CEL method descriptors are -// deduced based on C++ function signatures. // -// CelFunction::Evaluate will set result to the value returned by the handler. -// To handle errors, choose CelValue as the return type, and use the -// CreateError/Create* helpers in cel_value.h. +// The static Create member function accepts CelFunction::Evalaute method +// implementations as std::function, allowing them to be lambdas/regular C++ +// functions. CEL method descriptors ddeduced based on C++ function signatures. +// +// The adapted CelFunction::Evaluate implementation will set result to the +// value returned by the handler. To handle errors, choose CelValue as the +// return type, and use the CreateError/Create* helpers in cel_value.h. +// +// The wrapped std::function may return absl::StatusOr. If the wrapped +// function returns the absl::Status variant, the generated CelFunction +// implementation will return a non-ok status code, rather than a CelError +// wrapping an absl::Status value. A returned non-ok status indicates a hard +// error, meaning the interpreter cannot reasonably continue evaluation (e.g. +// data corruption or broken invariant). To create a CelError that follows +// logical pruning rules, the extension function implementation should return a +// CelError or an error-typed CelValue. +// +// FunctionAdapter +// ReturnType: the C++ return type of the function implementation +// Arguments: the C++ Argument type of the function implementation +// +// Static Methods: +// +// Create(absl::string_view function_name, bool receiver_style, +// FunctionType func) -> absl::StatusOr> // // Usage example: // @@ -82,227 +98,24 @@ bool AddType(std::vector* arg_types) { // // CEL_ASSIGN_OR_RETURN(auto cel_func, // FunctionAdapter::Create("<", false, func)); +// +// CreateAndRegister(absl::string_view function_name, bool receiver_style, +// FunctionType func, CelFunctionRegisry registry) +// -> absl::Status +// +// Usage example: +// +// auto func = [](::google::protobuf::Arena* arena, int64_t i, int64_t j) -> bool { +// return i < j; +// }; +// +// CEL_RETURN_IF_ERROR(( +// FunctionAdapter::CreateAndRegister("<", false, +// func, cel_expression_builder->GetRegistry())); +// template -class FunctionAdapter : public CelFunction { - public: - using FuncType = std::function; - - FunctionAdapter(CelFunctionDescriptor descriptor, FuncType handler) - : CelFunction(std::move(descriptor)), handler_(std::move(handler)) {} - - static absl::StatusOr> Create( - absl::string_view name, bool receiver_type, - std::function handler) { - std::vector arg_types; - arg_types.reserve(sizeof...(Arguments)); - - if (!internal::AddType<0, Arguments...>(&arg_types)) { - return absl::Status( - absl::StatusCode::kInternal, - absl::StrCat("Failed to create adapter for ", name, - ": failed to determine input parameter type")); - } - - return absl::make_unique( - CelFunctionDescriptor(name, receiver_type, std::move(arg_types)), - std::move(handler)); - } - - // Creates function handler and attempts to register it with - // supplied function registry. - static absl::Status CreateAndRegister( - absl::string_view name, bool receiver_type, - std::function handler, - CelFunctionRegistry* registry) { - CEL_ASSIGN_OR_RETURN(auto cel_function, - Create(name, receiver_type, std::move(handler))); - - return registry->Register(std::move(cel_function)); - } - -#if defined(__clang__) || !defined(__GNUC__) - template - inline absl::Status RunWrap(absl::Span arguments, - std::tuple<::google::protobuf::Arena*, Arguments...> input, - CelValue* result, ::google::protobuf::Arena* arena) const { - if (!ConvertFromValue(arguments[arg_index], - &std::get(input))) { - return absl::Status(absl::StatusCode::kInvalidArgument, - "Type conversion failed"); - } - return RunWrap(arguments, input, result, arena); - } - - template <> - inline absl::Status RunWrap( - absl::Span, - std::tuple<::google::protobuf::Arena*, Arguments...> input, CelValue* result, - ::google::protobuf::Arena* arena) const { - return CreateReturnValue(absl::apply(handler_, input), arena, result); - } -#else - inline absl::Status RunWrap( - std::function func, - ABSL_ATTRIBUTE_UNUSED const absl::Span argset, - ::google::protobuf::Arena* arena, CelValue* result, - ABSL_ATTRIBUTE_UNUSED int arg_index) const { - return CreateReturnValue(func(), arena, result); - } - - template - inline absl::Status RunWrap(std::function func, - const absl::Span argset, - ::google::protobuf::Arena* arena, CelValue* result, - int arg_index) const { - Arg argument; - if (!ConvertFromValue(argset[arg_index], &argument)) { - return absl::Status(absl::StatusCode::kInvalidArgument, - "Type conversion failed"); - } - - std::function wrapped_func = - [func, argument](Args... args) -> ReturnType { - return func(argument, args...); - }; - - return RunWrap(std::move(wrapped_func), argset, arena, result, - arg_index + 1); - } -#endif - - absl::Status Evaluate(absl::Span arguments, CelValue* result, - ::google::protobuf::Arena* arena) const override { - if (arguments.size() != sizeof...(Arguments)) { - return absl::Status(absl::StatusCode::kInternal, - "Argument number mismatch"); - } - -#if defined(__clang__) || !defined(__GNUC__) - std::tuple<::google::protobuf::Arena*, Arguments...> input; - std::get<0>(input) = arena; - return RunWrap<0>(arguments, input, result, arena); -#else - const auto* handler = &handler_; - std::function wrapped_handler = - [handler, arena](Arguments... args) -> ReturnType { - return (*handler)(arena, args...); - }; - return RunWrap(std::move(wrapped_handler), arguments, arena, result, 0); -#endif - } - - private: - template - static bool ConvertFromValue(CelValue value, ArgType* result) { - return value.GetValue(result); - } - - // Special conversion - from CelValue to CelValue - plain copy - static bool ConvertFromValue(CelValue value, CelValue* result) { - *result = std::move(value); - return true; - } - - // CreateReturnValue method wraps evaluation result with CelValue. - static absl::Status CreateReturnValue(bool value, ::google::protobuf::Arena*, - CelValue* result) { - *result = CelValue::CreateBool(value); - return absl::OkStatus(); - } - - static absl::Status CreateReturnValue(int64_t value, ::google::protobuf::Arena*, - CelValue* result) { - *result = CelValue::CreateInt64(value); - return absl::OkStatus(); - } - - static absl::Status CreateReturnValue(uint64_t value, ::google::protobuf::Arena*, - CelValue* result) { - *result = CelValue::CreateUint64(value); - return absl::OkStatus(); - } - - static absl::Status CreateReturnValue(double value, ::google::protobuf::Arena*, - CelValue* result) { - *result = CelValue::CreateDouble(value); - return absl::OkStatus(); - } - - static absl::Status CreateReturnValue(CelValue::StringHolder value, - ::google::protobuf::Arena*, CelValue* result) { - *result = CelValue::CreateString(value); - return absl::OkStatus(); - } - - static absl::Status CreateReturnValue(CelValue::BytesHolder value, - ::google::protobuf::Arena*, CelValue* result) { - *result = CelValue::CreateBytes(value); - return absl::OkStatus(); - } - - static absl::Status CreateReturnValue(const ::google::protobuf::Message* value, - ::google::protobuf::Arena* arena, - CelValue* result) { - if (value == nullptr) { - return absl::Status(absl::StatusCode::kInvalidArgument, - "Null Message pointer returned"); - } - *result = CelProtoWrapper::CreateMessage(value, arena); - return absl::OkStatus(); - } - - static absl::Status CreateReturnValue(const CelList* value, ::google::protobuf::Arena*, - CelValue* result) { - if (value == nullptr) { - return absl::Status(absl::StatusCode::kInvalidArgument, - "Null CelList pointer returned"); - } - *result = CelValue::CreateList(value); - return absl::OkStatus(); - } - - static absl::Status CreateReturnValue(const CelMap* value, ::google::protobuf::Arena*, - CelValue* result) { - if (value == nullptr) { - return absl::Status(absl::StatusCode::kInvalidArgument, - "Null CelMap pointer returned"); - } - *result = CelValue::CreateMap(value); - return absl::OkStatus(); - } - - static absl::Status CreateReturnValue(CelValue::CelTypeHolder value, - ::google::protobuf::Arena*, CelValue* result) { - *result = CelValue::CreateCelType(value); - return absl::OkStatus(); - } - - static absl::Status CreateReturnValue(const CelError* value, ::google::protobuf::Arena*, - CelValue* result) { - if (value == nullptr) { - return absl::Status(absl::StatusCode::kInvalidArgument, - "Null CelError pointer returned"); - } - *result = CelValue::CreateError(value); - return absl::OkStatus(); - } - - static absl::Status CreateReturnValue(const CelValue& value, ::google::protobuf::Arena*, - CelValue* result) { - *result = value; - return absl::OkStatus(); - } - - template - static absl::Status CreateReturnValue(absl::StatusOr value, - ::google::protobuf::Arena* arena, - CelValue* result) { - CEL_ASSIGN_OR_RETURN(auto held_value, value); - return CreateReturnValue(held_value, arena, result); - } - - FuncType handler_; -}; +using FunctionAdapter = + internal::ProtoMessageFunctionAdapter; } // namespace google::api::expr::runtime diff --git a/eval/public/cel_function_adapter_impl.h b/eval/public/cel_function_adapter_impl.h new file mode 100644 index 000000000..9e669a21a --- /dev/null +++ b/eval/public/cel_function_adapter_impl.h @@ -0,0 +1,310 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_ADAPTER_IMPL_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_ADAPTER_IMPL_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "eval/public/cel_function.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_value.h" +#include "internal/status_macros.h" + +namespace google::api::expr::runtime { + +namespace internal { +// TypeCodeMatch template helper. +// Used for CEL type deduction based on C++ native type. +struct TypeCodeMatcher { + template + constexpr absl::optional type_code() { + int index = CelValue::IndexOf::value; + if (index < 0) return {}; + CelValue::Type arg_type = static_cast(index); + if (arg_type >= CelValue::Type::kAny) { + return {}; + } + return arg_type; + } + + // A bit of a trick - to pass Any kind of value, we use generic CelValue + // parameters. + template <> + constexpr absl::optional type_code() { + return CelValue::Type::kAny; + } +}; + +// Template helper to construct an argument list for a CelFunctionDescriptor. +template +struct TypeAdder { + template + bool AddType(std::vector* arg_types) const { + auto kind = TypeCodeMatcher().template type_code(); + if (!kind) { + return false; + } + + arg_types->push_back(*kind); + + return AddType(arg_types); + + return true; + } + + template + bool AddType(std::vector* arg_types) const { + return true; + } +}; + +// Template helper for C++ types to CEL conversions. +// Uses CRTP to dispatch to derived class overloads in the StatusOr helper. +template +struct ValueConverterBase { + // Value to native uwraps a CelValue to a native type. + template + bool ValueToNative(CelValue value, T* result) { + return value.GetValue(result); + } + + // Specialization for CelValue (any typed) + template <> + bool ValueToNative(CelValue value, CelValue* result) { + *result = std::move(value); + return true; + } + + // Native to value wraps a native return type to a CelValue. + absl::Status NativeToValue(bool value, ::google::protobuf::Arena*, CelValue* result) { + *result = CelValue::CreateBool(value); + return absl::OkStatus(); + } + + absl::Status NativeToValue(int64_t value, ::google::protobuf::Arena*, + CelValue* result) { + *result = CelValue::CreateInt64(value); + return absl::OkStatus(); + } + + absl::Status NativeToValue(uint64_t value, ::google::protobuf::Arena*, + CelValue* result) { + *result = CelValue::CreateUint64(value); + return absl::OkStatus(); + } + + absl::Status NativeToValue(double value, ::google::protobuf::Arena*, CelValue* result) { + *result = CelValue::CreateDouble(value); + return absl::OkStatus(); + } + + absl::Status NativeToValue(CelValue::StringHolder value, ::google::protobuf::Arena*, + CelValue* result) { + *result = CelValue::CreateString(value); + return absl::OkStatus(); + } + + absl::Status NativeToValue(CelValue::BytesHolder value, ::google::protobuf::Arena*, + CelValue* result) { + *result = CelValue::CreateBytes(value); + return absl::OkStatus(); + } + + absl::Status NativeToValue(const CelList* value, ::google::protobuf::Arena*, + CelValue* result) { + if (value == nullptr) { + return absl::Status(absl::StatusCode::kInvalidArgument, + "Null CelList pointer returned"); + } + *result = CelValue::CreateList(value); + return absl::OkStatus(); + } + + absl::Status NativeToValue(const CelMap* value, ::google::protobuf::Arena*, + CelValue* result) { + if (value == nullptr) { + return absl::Status(absl::StatusCode::kInvalidArgument, + "Null CelMap pointer returned"); + } + *result = CelValue::CreateMap(value); + return absl::OkStatus(); + } + + absl::Status NativeToValue(CelValue::CelTypeHolder value, ::google::protobuf::Arena*, + CelValue* result) { + *result = CelValue::CreateCelType(value); + return absl::OkStatus(); + } + + absl::Status NativeToValue(const CelError* value, ::google::protobuf::Arena*, + CelValue* result) { + if (value == nullptr) { + return absl::Status(absl::StatusCode::kInvalidArgument, + "Null CelError pointer returned"); + } + *result = CelValue::CreateError(value); + return absl::OkStatus(); + } + + // Special case -- just forward a CelValue. + absl::Status NativeToValue(const CelValue& value, ::google::protobuf::Arena*, + CelValue* result) { + *result = value; + return absl::OkStatus(); + } + + template + absl::Status NativeToValue(absl::StatusOr value, ::google::protobuf::Arena* arena, + CelValue* result) { + CEL_ASSIGN_OR_RETURN(auto held_value, value); + return Derived().NativeToValue(held_value, arena, result); + } +}; + +struct ValueConverter : public ValueConverterBase {}; + +// Generalized implementation for function adapters. See comments on +// instantiated versions for details on usage. +// +// TypeCodeMatcher provides the mapping from C++ type to CEL type. +// ValueConverter provides value conversions from native to CEL and vice versa. +// ReturnType and Arguments types are instantiated for the particular shape of +// the adapted functions. +template +class FunctionAdapter : public CelFunction { + public: + using FuncType = std::function; + using TypeAdder = internal::TypeAdder; + + FunctionAdapter(CelFunctionDescriptor descriptor, FuncType handler) + : CelFunction(std::move(descriptor)), handler_(std::move(handler)) {} + + static absl::StatusOr> Create( + absl::string_view name, bool receiver_type, + std::function handler) { + std::vector arg_types; + arg_types.reserve(sizeof...(Arguments)); + + if (!TypeAdder().template AddType<0, Arguments...>(&arg_types)) { + return absl::Status( + absl::StatusCode::kInternal, + absl::StrCat("Failed to create adapter for ", name, + ": failed to determine input parameter type")); + } + + return absl::make_unique( + CelFunctionDescriptor(name, receiver_type, std::move(arg_types)), + std::move(handler)); + } + + // Creates function handler and attempts to register it with + // supplied function registry. + static absl::Status CreateAndRegister( + absl::string_view name, bool receiver_type, + std::function handler, + CelFunctionRegistry* registry) { + CEL_ASSIGN_OR_RETURN(auto cel_function, + Create(name, receiver_type, std::move(handler))); + + return registry->Register(std::move(cel_function)); + } + +#if defined(__clang__) || !defined(__GNUC__) + template + inline absl::Status RunWrap(absl::Span arguments, + std::tuple<::google::protobuf::Arena*, Arguments...> input, + CelValue* result, ::google::protobuf::Arena* arena) const { + if (!ValueConverter().ValueToNative(arguments[arg_index], + &std::get(input))) { + return absl::Status(absl::StatusCode::kInvalidArgument, + "Type conversion failed"); + } + return RunWrap(arguments, input, result, arena); + } + + template <> + inline absl::Status RunWrap( + absl::Span, + std::tuple<::google::protobuf::Arena*, Arguments...> input, CelValue* result, + ::google::protobuf::Arena* arena) const { + return ValueConverter().NativeToValue(absl::apply(handler_, input), arena, + result); + } +#else + inline absl::Status RunWrap( + std::function func, + ABSL_ATTRIBUTE_UNUSED const absl::Span argset, + ::google::protobuf::Arena* arena, CelValue* result, + ABSL_ATTRIBUTE_UNUSED int arg_index) const { + return ValueConverter().NativeToValue(func(), arena, result); + } + + template + inline absl::Status RunWrap(std::function func, + const absl::Span argset, + ::google::protobuf::Arena* arena, CelValue* result, + int arg_index) const { + Arg argument; + if (!ValueConverter().ValueToNative(argset[arg_index], &argument)) { + return absl::Status(absl::StatusCode::kInvalidArgument, + "Type conversion failed"); + } + + std::function wrapped_func = + [func, argument](Args... args) -> ReturnType { + return func(argument, args...); + }; + + return RunWrap(std::move(wrapped_func), argset, arena, result, + arg_index + 1); + } +#endif + + absl::Status Evaluate(absl::Span arguments, CelValue* result, + ::google::protobuf::Arena* arena) const override { + if (arguments.size() != sizeof...(Arguments)) { + return absl::Status(absl::StatusCode::kInternal, + "Argument number mismatch"); + } + +#if defined(__clang__) || !defined(__GNUC__) + std::tuple<::google::protobuf::Arena*, Arguments...> input; + std::get<0>(input) = arena; + return RunWrap<0>(arguments, input, result, arena); +#else + const auto* handler = &handler_; + std::function wrapped_handler = + [handler, arena](Arguments... args) -> ReturnType { + return (*handler)(arena, args...); + }; + return RunWrap(std::move(wrapped_handler), arguments, arena, result, 0); +#endif + } + + private: + FuncType handler_; +}; + +} // namespace internal + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_ADAPTER_IMPL_H_ From df1e2cf9a1d9d75fbc45d41b15f613d9a21142db Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Mon, 4 Apr 2022 17:00:00 +0000 Subject: [PATCH 055/111] Unnest legacy adapter interfaces. PiperOrigin-RevId: 439343585 --- eval/eval/create_struct_step.cc | 10 +- eval/eval/create_struct_step.h | 2 +- eval/public/cel_value.h | 2 + eval/public/structs/legacy_type_adapter.h | 132 +++++++++--------- .../structs/legacy_type_adapter_test.cc | 2 +- .../structs/proto_message_type_adapter.h | 4 +- 6 files changed, 78 insertions(+), 74 deletions(-) diff --git a/eval/eval/create_struct_step.cc b/eval/eval/create_struct_step.cc index 2d1574c19..3328953e4 100644 --- a/eval/eval/create_struct_step.cc +++ b/eval/eval/create_struct_step.cc @@ -25,9 +25,9 @@ class CreateStructStepForMessage : public ExpressionStepBase { std::string field_name; }; - CreateStructStepForMessage( - int64_t expr_id, const LegacyTypeAdapter::MutationApis* type_adapter, - std::vector entries) + CreateStructStepForMessage(int64_t expr_id, + const LegacyTypeMutationApis* type_adapter, + std::vector entries) : ExpressionStepBase(expr_id), type_adapter_(type_adapter), entries_(std::move(entries)) {} @@ -37,7 +37,7 @@ class CreateStructStepForMessage : public ExpressionStepBase { private: absl::Status DoEvaluate(ExecutionFrame* frame, CelValue* result) const; - const LegacyTypeAdapter::MutationApis* type_adapter_; + const LegacyTypeMutationApis* type_adapter_; std::vector entries_; }; @@ -158,7 +158,7 @@ absl::Status CreateStructStepForMap::Evaluate(ExecutionFrame* frame) const { absl::StatusOr> CreateCreateStructStep( const google::api::expr::v1alpha1::Expr::CreateStruct* create_struct_expr, - const LegacyTypeAdapter::MutationApis* type_adapter, int64_t expr_id) { + const LegacyTypeMutationApis* type_adapter, int64_t expr_id) { if (type_adapter != nullptr) { std::vector entries; diff --git a/eval/eval/create_struct_step.h b/eval/eval/create_struct_step.h index c47422782..8f8a2eeac 100644 --- a/eval/eval/create_struct_step.h +++ b/eval/eval/create_struct_step.h @@ -14,7 +14,7 @@ namespace google::api::expr::runtime { // Factory method for CreateStruct - based Execution step absl::StatusOr> CreateCreateStructStep( const google::api::expr::v1alpha1::Expr::CreateStruct* create_struct_expr, - const LegacyTypeAdapter::MutationApis* type_adapter, int64_t expr_id); + const LegacyTypeMutationApis* type_adapter, int64_t expr_id); inline absl::StatusOr> CreateCreateStructStep( const google::api::expr::v1alpha1::Expr::CreateStruct* create_struct_expr, diff --git a/eval/public/cel_value.h b/eval/public/cel_value.h index f626d51d6..345e22b04 100644 --- a/eval/public/cel_value.h +++ b/eval/public/cel_value.h @@ -120,6 +120,8 @@ class CelValue { // // message_ptr(): get the MessageLite pointer for the wrapper. // + // access_apis(): get the accessors used for the type. + // // HasFullProto(): returns whether it's safe to downcast to google::protobuf::Message. using MessageWrapper = internal::MessageWrapper; diff --git a/eval/public/structs/legacy_type_adapter.h b/eval/public/structs/legacy_type_adapter.h index 58dea0fd8..237b92b77 100644 --- a/eval/public/structs/legacy_type_adapter.h +++ b/eval/public/structs/legacy_type_adapter.h @@ -11,6 +11,9 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +// +// Definitions for legacy type APIs to emulate the behavior of the new type +// system. #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_LEGACY_TYPE_ADPATER_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_LEGACY_TYPE_ADPATER_H_ @@ -21,90 +24,89 @@ namespace google::api::expr::runtime { +// Interface for mutation apis. +// Note: in the new type system, a type provider represents this by returning +// a cel::Type and cel::ValueFactory for the type. +class LegacyTypeMutationApis { + public: + virtual ~LegacyTypeMutationApis() = default; + + // Return whether the type defines the given field. + // TODO(issues/5): This is only used to eagerly fail during the planning + // phase. Check if it's safe to remove this behavior and fail at runtime. + virtual bool DefinesField(absl::string_view field_name) const = 0; + + // Create a new empty instance of the type. + // May return a status if the type is not possible to create. + virtual absl::StatusOr NewInstance( + cel::MemoryManager& memory_manager) const = 0; + + // Normalize special types to a native CEL value after building. + // The default implementation is a no-op. + // The interpreter guarantees that instance is uniquely owned by the + // interpreter, and can be safely mutated. + virtual absl::Status AdaptFromWellKnownType( + cel::MemoryManager& memory_manager, CelValue& instance) const { + return absl::OkStatus(); + } + + // Set field on instance to value. + // The interpreter guarantees that instance is uniquely owned by the + // interpreter, and can be safely mutated. + virtual absl::Status SetField(absl::string_view field_name, + const CelValue& value, + cel::MemoryManager& memory_manager, + CelValue& instance) const = 0; +}; + +// Interface for access apis. +// Note: in new type system this is integrated into the StructValue (via +// dynamic dispatch to concerete implementations). +class LegacyTypeAccessApis { + public: + virtual ~LegacyTypeAccessApis() = default; + + // Return whether an instance of the type has field set to a non-default + // value. + virtual absl::StatusOr HasField(absl::string_view field_name, + const CelValue& value) const = 0; + + // Access field on instance. + virtual absl::StatusOr GetField( + absl::string_view field_name, const CelValue& instance, + cel::MemoryManager& memory_manager) const = 0; +}; + // Type information about a legacy Struct type. // Provides methods to the interpreter for interacting with a custom type. // -// This provides Apis for emulating the behavior of new types working on -// existing cel values. -// -// MutationApis provide equivalent behavior to a cel::Type and cel::ValueFactory -// (resolved from a type name). +// mutation_apis() provide equivalent behavior to a cel::Type and +// cel::ValueFactory (resolved from a type name). // -// AccessApis provide equivalent behavior to cel::StructValue accessors (virtual -// dispatch to a concrete implementation for accessing underlying values). +// access_apis() provide equivalent behavior to cel::StructValue accessors +// (virtual dispatch to a concrete implementation for accessing underlying +// values). // // This class is a simple wrapper around (nullable) pointers to the interface // implementations. The underlying pointers are expected to be valid as long as // the type provider that returned this object. class LegacyTypeAdapter { public: - // Interface for mutation apis. - // Note: in the new type system, a type provider represents this by returning - // a cel::Type and cel::ValueFactory for the type. - class MutationApis { - public: - virtual ~MutationApis() = default; - - // Return whether the type defines the given field. - // TODO(issues/5): This is only used to eagerly fail during the planning - // phase. Check if it's safe to remove this behavior and fail at runtime. - virtual bool DefinesField(absl::string_view field_name) const = 0; - - // Create a new empty instance of the type. - // May return a status if the type is not possible to create. - virtual absl::StatusOr NewInstance( - cel::MemoryManager& memory_manager) const = 0; - - // Normalize special types to a native CEL value after building. - // The default implementation is a no-op. - // The interpreter guarantees that instance is uniquely owned by the - // interpreter, and can be safely mutated. - virtual absl::Status AdaptFromWellKnownType( - cel::MemoryManager& memory_manager, CelValue& instance) const { - return absl::OkStatus(); - } - - // Set field on instance to value. - // The interpreter guarantees that instance is uniquely owned by the - // interpreter, and can be safely mutated. - virtual absl::Status SetField(absl::string_view field_name, - const CelValue& value, - cel::MemoryManager& memory_manager, - CelValue& instance) const = 0; - }; - - // Interface for access apis. - // Note: in new type system this is integrated into the StructValue (via - // dynamic dispatch to concerete implementations). - class AccessApis { - public: - virtual ~AccessApis() = default; - - // Return whether an instance of the type has field set to a non-default - // value. - virtual absl::StatusOr HasField(absl::string_view field_name, - const CelValue& value) const = 0; - - // Access field on instance. - virtual absl::StatusOr GetField( - absl::string_view field_name, const CelValue& instance, - cel::MemoryManager& memory_manager) const = 0; - }; - - LegacyTypeAdapter(const AccessApis* access, const MutationApis* mutation) + LegacyTypeAdapter(const LegacyTypeAccessApis* access, + const LegacyTypeMutationApis* mutation) : access_apis_(access), mutation_apis_(mutation) {} // Apis for access for the represented type. // If null, access is not supported (this is an opaque type). - const AccessApis* access_apis() { return access_apis_; } + const LegacyTypeAccessApis* access_apis() { return access_apis_; } // Apis for mutation for the represented type. // If null, mutation is not supported (this type cannot be created). - const MutationApis* mutation_apis() { return mutation_apis_; } + const LegacyTypeMutationApis* mutation_apis() { return mutation_apis_; } private: - const AccessApis* access_apis_; - const MutationApis* mutation_apis_; + const LegacyTypeAccessApis* access_apis_; + const LegacyTypeMutationApis* mutation_apis_; }; } // namespace google::api::expr::runtime diff --git a/eval/public/structs/legacy_type_adapter_test.cc b/eval/public/structs/legacy_type_adapter_test.cc index ce93f9f71..ac2cc53cb 100644 --- a/eval/public/structs/legacy_type_adapter_test.cc +++ b/eval/public/structs/legacy_type_adapter_test.cc @@ -23,7 +23,7 @@ namespace google::api::expr::runtime { namespace { -class TestMutationApiImpl : public LegacyTypeAdapter::MutationApis { +class TestMutationApiImpl : public LegacyTypeMutationApis { public: TestMutationApiImpl() {} bool DefinesField(absl::string_view field_name) const override { diff --git a/eval/public/structs/proto_message_type_adapter.h b/eval/public/structs/proto_message_type_adapter.h index 7827c608b..46cf54d65 100644 --- a/eval/public/structs/proto_message_type_adapter.h +++ b/eval/public/structs/proto_message_type_adapter.h @@ -26,8 +26,8 @@ namespace google::api::expr::runtime { -class ProtoMessageTypeAdapter : public LegacyTypeAdapter::AccessApis, - public LegacyTypeAdapter::MutationApis { +class ProtoMessageTypeAdapter : public LegacyTypeAccessApis, + public LegacyTypeMutationApis { public: ProtoMessageTypeAdapter(const google::protobuf::Descriptor* descriptor, google::protobuf::MessageFactory* message_factory, From 02a47eaf243e4c1227c4b4c85170400095ea3c29 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Tue, 5 Apr 2022 17:17:14 +0000 Subject: [PATCH 056/111] Introduce TypeInfo apis (indirection to resolve cyclic dependency between CelValue and accessors). PiperOrigin-RevId: 439613648 --- eval/public/structs/BUILD | 6 ++ eval/public/structs/legacy_type_info_apis.h | 61 +++++++++++++++++++++ 2 files changed, 67 insertions(+) create mode 100644 eval/public/structs/legacy_type_info_apis.h diff --git a/eval/public/structs/BUILD b/eval/public/structs/BUILD index 17f57c10f..732b569b5 100644 --- a/eval/public/structs/BUILD +++ b/eval/public/structs/BUILD @@ -188,3 +188,9 @@ cc_test( "//internal:testing", ], ) + +cc_library( + name = "legacy_type_info_apis", + hdrs = ["legacy_type_info_apis.h"], + deps = ["//eval/public:cel_value_internal"], +) diff --git a/eval/public/structs/legacy_type_info_apis.h b/eval/public/structs/legacy_type_info_apis.h new file mode 100644 index 000000000..26d77ea40 --- /dev/null +++ b/eval/public/structs/legacy_type_info_apis.h @@ -0,0 +1,61 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_LEGACY_TYPE_INFO_APIS_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_LEGACY_TYPE_INFO_APIS_H_ + +#include + +#include "eval/public/cel_value_internal.h" + +namespace google::api::expr::runtime { + +// Forward declared to resolve cyclic dependency. +class LegacyTypeAccessApis; + +// Interface for providing type info from a user defined type (represented as a +// message). +// +// Provides ability to obtain field access apis, type info, and debug +// representation of a message/ +// +// This is implemented as a separate class from LegacyTypeAccessApis to resolve +// cyclic dependency between CelValue (which needs to access these apis to +// provide DebugString and ObtainCelTypename) and LegacyTypeAccessApis (which +// needs to return CelValue type for field access). +class LegacyTypeInfoApis { + public: + virtual ~LegacyTypeInfoApis() = default; + + // Return a debug representation of the wrapped message. + virtual std::string DebugString( + const internal::MessageWrapper& wrapped_message) const = 0; + + // Return a const-reference to the typename for the wrapped message's type. + // The CEL interpreter assumes that the typename is owned externally and will + // outlive any CelValues created by the interpreter. + virtual const std::string& GetTypename( + const internal::MessageWrapper& wrapped_message) const = 0; + + // Return a pointer to the wrapped message's access api implementation. + // The CEL interpreter assumes that the is owned externally and will + // outlive any CelValues created by the interpreter. + // Nullptr means the value does not provide access apis. + virtual const LegacyTypeAccessApis* GetAccessApis( + const internal::MessageWrapper& wrapped_message) const = 0; +}; + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_LEGACY_TYPE_INFO_APIS_H_ From f629fc34d73ef1f9bf99a6809d9ae73d72775db6 Mon Sep 17 00:00:00 2001 From: tswadell Date: Wed, 6 Apr 2022 17:34:29 +0000 Subject: [PATCH 057/111] Mark tests as opt-out or opt-in for heterogeneous equality PiperOrigin-RevId: 439878041 --- eval/eval/container_access_step_test.cc | 1 + eval/public/builtin_func_test.cc | 7 +++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/eval/eval/container_access_step_test.cc b/eval/eval/container_access_step_test.cc index 89ce881e2..c6630d87b 100644 --- a/eval/eval/container_access_step_test.cc +++ b/eval/eval/container_access_step_test.cc @@ -462,6 +462,7 @@ TEST_F(ContainerAccessHeterogeneousLookupsTest, StringKeyUnaffected) { class ContainerAccessHeterogeneousLookupsDisabledTest : public testing::Test { public: ContainerAccessHeterogeneousLookupsDisabledTest() { + options_.enable_heterogeneous_equality = false; builder_ = CreateCelExpressionBuilder(options_); } diff --git a/eval/public/builtin_func_test.cc b/eval/public/builtin_func_test.cc index e38a49a0c..c30633004 100644 --- a/eval/public/builtin_func_test.cc +++ b/eval/public/builtin_func_test.cc @@ -123,9 +123,10 @@ class BuiltinsTest : public ::testing::Test { // Helper method. Looks up in registry and tests for no matching equality // overload. void TestNoMatchingEqualOverload(const CelValue& ref, const CelValue& other) { + options_.enable_heterogeneous_equality = false; CelValue eq_value; ASSERT_NO_FATAL_FAILURE( - PerformRun(builtin::kEqual, {}, {ref, other}, &eq_value)); + PerformRun(builtin::kEqual, {}, {ref, other}, &eq_value, options_)); ASSERT_TRUE(eq_value.IsError()) << " for " << CelValue::TypeName(ref.type()) << " and " << CelValue::TypeName(other.type()); @@ -133,7 +134,7 @@ class BuiltinsTest : public ::testing::Test { CelValue ineq_value; ASSERT_NO_FATAL_FAILURE( - PerformRun(builtin::kInequal, {}, {ref, other}, &ineq_value)); + PerformRun(builtin::kInequal, {}, {ref, other}, &ineq_value, options_)); ASSERT_TRUE(ineq_value.IsError()) << " for " << CelValue::TypeName(ref.type()) << " and " << CelValue::TypeName(other.type()); @@ -1617,6 +1618,7 @@ TEST_F(BuiltinsTest, TestInt64MapIn) { data[value] = CelValue::CreateInt64(value * value); } FakeInt64Map cel_map(data); + options_.enable_heterogeneous_equality = false; TestInMap(&cel_map, CelValue::CreateInt64(-4), true); TestInMap(&cel_map, CelValue::CreateInt64(4), false); TestInMap(&cel_map, CelValue::CreateUint64(3), false); @@ -1640,6 +1642,7 @@ TEST_F(BuiltinsTest, TestUint64MapIn) { data[value] = CelValue::CreateUint64(value * value); } FakeUint64Map cel_map(data); + options_.enable_heterogeneous_equality = false; TestInMap(&cel_map, CelValue::CreateUint64(4), true); TestInMap(&cel_map, CelValue::CreateUint64(44), false); TestInMap(&cel_map, CelValue::CreateInt64(4), false); From bf7da914353649a6255c591807e99b6c2772dafa Mon Sep 17 00:00:00 2001 From: tswadell Date: Wed, 6 Apr 2022 21:48:33 +0000 Subject: [PATCH 058/111] Remove usage of CelValue.DebugString() from CEL error messages. PiperOrigin-RevId: 439940887 --- eval/public/containers/field_access.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/eval/public/containers/field_access.cc b/eval/public/containers/field_access.cc index d3019cda3..75ca40970 100644 --- a/eval/public/containers/field_access.cc +++ b/eval/public/containers/field_access.cc @@ -736,9 +736,9 @@ absl::Status SetValueToSingleField(const CelValue& value, ? absl::OkStatus() : absl::InvalidArgumentError(absl::Substitute( "Could not assign supplied argument to message \"$0\" field " - "\"$1\" of type $2: value was \"$3\"", + "\"$1\" of type $2: value type \"$3\"", msg->GetDescriptor()->name(), desc->name(), - desc->type_name(), value.DebugString())); + desc->type_name(), CelValue::TypeName(value.type()))); } absl::Status AddValueToRepeatedField(const CelValue& value, @@ -748,10 +748,10 @@ absl::Status AddValueToRepeatedField(const CelValue& value, return (setter.SetFieldFromCelValue(value)) ? absl::OkStatus() : absl::InvalidArgumentError(absl::Substitute( - "Could not add supplied argument \"$2\" to message \"$0\" " - "field \"$1\".", + "Could not add supplied argument to message \"$0\" field " + "\"$1\" of type $2: value type \"$3\"", msg->GetDescriptor()->name(), desc->name(), - value.DebugString())); + desc->type_name(), CelValue::TypeName(value.type()))); } } // namespace google::api::expr::runtime From 076a06fac04d11dce4a397b346654076312526ba Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Fri, 8 Apr 2022 19:34:34 +0000 Subject: [PATCH 059/111] Add basic benchmark for deeply nested field accesses. PiperOrigin-RevId: 440428189 --- eval/tests/benchmark_test.cc | 51 ++++++++++++++++++++++++++++++++ eval/tests/request_context.proto | 14 +++++++++ 2 files changed, 65 insertions(+) diff --git a/eval/tests/benchmark_test.cc b/eval/tests/benchmark_test.cc index 782ecdcc4..e21864865 100644 --- a/eval/tests/benchmark_test.cc +++ b/eval/tests/benchmark_test.cc @@ -576,6 +576,57 @@ void BM_ReadProtoMap(benchmark::State& state) { BENCHMARK(BM_ReadProtoMap); +void BM_NestedProtoFieldRead(benchmark::State& state) { + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(R"cel( + !request.a.b.c.d.e + )cel")); + auto builder = CreateCelExpressionBuilder(); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&parsed_expr.expr(), nullptr)); + + RequestContext request; + request.mutable_a()->mutable_b()->mutable_c()->mutable_d()->set_e(false); + activation.InsertValue("request", + CelProtoWrapper::CreateMessage(&request, &arena)); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsBool()); + ASSERT_TRUE(result.BoolOrDie()); + } +} + +BENCHMARK(BM_NestedProtoFieldRead); + +void BM_NestedProtoFieldReadDefaults(benchmark::State& state) { + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(R"cel( + !request.a.b.c.d.e + )cel")); + auto builder = CreateCelExpressionBuilder(); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&parsed_expr.expr(), nullptr)); + + RequestContext request; + activation.InsertValue("request", + CelProtoWrapper::CreateMessage(&request, &arena)); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsBool()); + ASSERT_TRUE(result.BoolOrDie()); + } +} + +BENCHMARK(BM_NestedProtoFieldReadDefaults); + // This expression has no equivalent CEL expression. // Sum a square with a nested comprehension constexpr char kNestedListSum[] = R"( diff --git a/eval/tests/request_context.proto b/eval/tests/request_context.proto index 2e307d3b1..446cd2df2 100644 --- a/eval/tests/request_context.proto +++ b/eval/tests/request_context.proto @@ -6,8 +6,22 @@ option cc_enable_arenas = true; // Message representing a sample request context message RequestContext { + // Example for deeply nested messages. + message D { + bool e = 1; + } + message C { + D d = 1; + } + message B { + C c = 1; + } + message A { + B b = 1; + } string ip = 1; string path = 2; string token = 3; map headers = 4; + A a = 5; } From fee8517c843757257037af6f09a81c45bc129329 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Fri, 8 Apr 2022 19:36:59 +0000 Subject: [PATCH 060/111] Add simple benchmarks for proto list and struct accesses. PiperOrigin-RevId: 440428702 --- eval/tests/BUILD | 1 + eval/tests/benchmark_test.cc | 66 ++++++++++++++++++++++++++++++++++-- 2 files changed, 65 insertions(+), 2 deletions(-) diff --git a/eval/tests/BUILD b/eval/tests/BUILD index 4c80f6b19..5e792de12 100644 --- a/eval/tests/BUILD +++ b/eval/tests/BUILD @@ -38,6 +38,7 @@ cc_test( "@com_google_absl//absl/container:node_hash_set", "@com_google_absl//absl/strings", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto", "@com_google_protobuf//:protobuf", ], ) diff --git a/eval/tests/benchmark_test.cc b/eval/tests/benchmark_test.cc index e21864865..220bcb1d7 100644 --- a/eval/tests/benchmark_test.cc +++ b/eval/tests/benchmark_test.cc @@ -4,6 +4,8 @@ #include #include "google/api/expr/v1alpha1/syntax.pb.h" +#include "google/protobuf/struct.pb.h" +#include "google/rpc/context/attribute_context.pb.h" #include "google/protobuf/text_format.h" #include "absl/base/attributes.h" #include "absl/container/btree_map.h" @@ -31,8 +33,9 @@ namespace runtime { namespace { -using google::api::expr::v1alpha1::Expr; -using google::api::expr::v1alpha1::SourceInfo; +using ::google::api::expr::v1alpha1::Expr; +using ::google::api::expr::v1alpha1::SourceInfo; +using ::google::rpc::context::AttributeContext; // Benchmark test // Evaluates cel expression: @@ -627,6 +630,65 @@ void BM_NestedProtoFieldReadDefaults(benchmark::State& state) { BENCHMARK(BM_NestedProtoFieldReadDefaults); +void BM_ProtoStructAccess(benchmark::State& state) { + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(R"cel( + has(request.auth.claims.iss) && request.auth.claims.iss == 'accounts.google.com' + )cel")); + auto builder = CreateCelExpressionBuilder(); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&parsed_expr.expr(), nullptr)); + + AttributeContext::Request request; + auto* auth = request.mutable_auth(); + (*auth->mutable_claims()->mutable_fields())["iss"].set_string_value( + "accounts.google.com"); + activation.InsertValue("request", + CelProtoWrapper::CreateMessage(&request, &arena)); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsBool()); + ASSERT_TRUE(result.BoolOrDie()); + } +} + +BENCHMARK(BM_ProtoStructAccess); + +void BM_ProtoListAccess(benchmark::State& state) { + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(R"cel( + "//.../accessLevels/MY_LEVEL_4" in request.auth.access_levels + )cel")); + auto builder = CreateCelExpressionBuilder(); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder->CreateExpression(&parsed_expr.expr(), nullptr)); + + AttributeContext::Request request; + auto* auth = request.mutable_auth(); + auth->add_access_levels("//.../accessLevels/MY_LEVEL_0"); + auth->add_access_levels("//.../accessLevels/MY_LEVEL_1"); + auth->add_access_levels("//.../accessLevels/MY_LEVEL_2"); + auth->add_access_levels("//.../accessLevels/MY_LEVEL_3"); + auth->add_access_levels("//.../accessLevels/MY_LEVEL_4"); + activation.InsertValue("request", + CelProtoWrapper::CreateMessage(&request, &arena)); + + for (auto _ : state) { + ASSERT_OK_AND_ASSIGN(CelValue result, + cel_expr->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsBool()); + ASSERT_TRUE(result.BoolOrDie()); + } +} + +BENCHMARK(BM_ProtoListAccess); + // This expression has no equivalent CEL expression. // Sum a square with a nested comprehension constexpr char kNestedListSum[] = R"( From d1bd6429b168455e3584c97f8dc0158504c16b1c Mon Sep 17 00:00:00 2001 From: jcking Date: Fri, 8 Apr 2022 20:53:08 +0000 Subject: [PATCH 061/111] Internal change PiperOrigin-RevId: 440445755 --- base/BUILD | 1 + base/internal/value.post.h | 8 ++ base/internal/value.pre.h | 3 + base/value.cc | 1 + base/value.h | 114 +++++++++++++++++++++++ base/value_factory.h | 9 ++ base/value_test.cc | 184 +++++++++++++++++++++++++++++++++---- 7 files changed, 302 insertions(+), 18 deletions(-) diff --git a/base/BUILD b/base/BUILD index 9b0131504..b8b6ff4e8 100644 --- a/base/BUILD +++ b/base/BUILD @@ -179,6 +179,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/time", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:variant", ], ) diff --git a/base/internal/value.post.h b/base/internal/value.post.h index c3aa600ac..522d917b0 100644 --- a/base/internal/value.post.h +++ b/base/internal/value.post.h @@ -46,6 +46,10 @@ inline internal::TypeInfo GetStructValueTypeId( return struct_value.TypeId(); } +inline internal::TypeInfo GetListValueTypeId(const ListValue& list_value) { + return list_value.TypeId(); +} + // Implementation of BytesValue that is stored inlined within a handle. Since // absl::Cord is reference counted itself, this is more efficient than storing // this on the heap. @@ -494,6 +498,8 @@ class ValueHandle final : public ValueHandleBase { ValueHandle& operator=(TransientValueHandle&& other) { if (not_empty_and_inlined()) { DestructInlined(*this); + } else { + Reset(); } Base::Move(other, *this); return *this; @@ -586,6 +592,7 @@ class ValueHandle final : public ValueHandleBase { DestructInlined(*this); } else if (reffed()) { Unref(); + Reset(); } Base::Move(other, *this); return *this; @@ -666,6 +673,7 @@ CEL_INTERNAL_VALUE_DECL(DurationValue); CEL_INTERNAL_VALUE_DECL(TimestampValue); CEL_INTERNAL_VALUE_DECL(EnumValue); CEL_INTERNAL_VALUE_DECL(StructValue); +CEL_INTERNAL_VALUE_DECL(ListValue); #undef CEL_INTERNAL_VALUE_DECL } // namespace cel diff --git a/base/internal/value.pre.h b/base/internal/value.pre.h index 88c7eefd4..f38af32e4 100644 --- a/base/internal/value.pre.h +++ b/base/internal/value.pre.h @@ -28,6 +28,7 @@ namespace cel { class EnumValue; class StructValue; +class ListValue; namespace base_internal { @@ -51,6 +52,8 @@ internal::TypeInfo GetEnumValueTypeId(const EnumValue& enum_value); internal::TypeInfo GetStructValueTypeId(const StructValue& struct_value); +internal::TypeInfo GetListValueTypeId(const ListValue& list_value); + class InlinedCordBytesValue; class InlinedStringViewBytesValue; class StringBytesValue; diff --git a/base/value.cc b/base/value.cc index e28c0400e..c743a0772 100644 --- a/base/value.cc +++ b/base/value.cc @@ -64,6 +64,7 @@ CEL_INTERNAL_VALUE_IMPL(DurationValue); CEL_INTERNAL_VALUE_IMPL(TimestampValue); CEL_INTERNAL_VALUE_IMPL(EnumValue); CEL_INTERNAL_VALUE_IMPL(StructValue); +CEL_INTERNAL_VALUE_IMPL(ListValue); #undef CEL_INTERNAL_VALUE_IMPL namespace { diff --git a/base/value.h b/base/value.h index f6234ef72..4637c1e58 100644 --- a/base/value.h +++ b/base/value.h @@ -28,6 +28,7 @@ #include "absl/strings/cord.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" +#include "absl/types/optional.h" #include "absl/types/variant.h" #include "base/handle.h" #include "base/internal/value.pre.h" // IWYU pragma: export @@ -52,7 +53,9 @@ class DurationValue; class TimestampValue; class EnumValue; class StructValue; +class ListValue; class ValueFactory; +class TypedListValueFactory; namespace internal { template @@ -86,6 +89,7 @@ class Value : public base_internal::Resource { friend class TimestampValue; friend class EnumValue; friend class StructValue; + friend class ListValue; friend class base_internal::ValueHandleBase; friend class base_internal::StringBytesValue; friend class base_internal::ExternalDataBytesValue; @@ -811,6 +815,116 @@ class StructValue : public Value { return ::cel::internal::TypeId(); \ } +// ListValue represents an instance of cel::ListType. +class ListValue : public Value { + public: + // TODO(issues/5): implement iterators so we can have cheap concated lists + + Transient type() const final { + ABSL_ASSERT(type_); + return type_; + } + + Kind kind() const final { return Kind::kList; } + + virtual size_t size() const = 0; + + virtual bool empty() const { return size() == 0; } + + virtual absl::StatusOr> Get( + ValueFactory& value_factory, size_t index) const = 0; + + protected: + explicit ListValue(const Persistent& type) : type_(type) {} + + private: + friend internal::TypeInfo base_internal::GetListValueTypeId( + const ListValue& list_value); + template + friend class base_internal::ValueHandle; + friend class base_internal::ValueHandleBase; + + // Called by base_internal::ValueHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Value& value) { return value.kind() == Kind::kList; } + + ListValue(const ListValue&) = delete; + ListValue(ListValue&&) = delete; + + // TODO(issues/5): I do not like this, we should have these two take a + // ValueFactory and return absl::StatusOr and absl::Status. We support + // lazily created values, so errors can occur during equality testing. + // Especially if there are different value implementations for the same type. + bool Equals(const Value& other) const override = 0; + void HashValue(absl::HashState state) const override = 0; + + std::pair SizeAndAlignment() const override = 0; + + // Called by CEL_IMPLEMENT_ENUM_VALUE() and Is() to perform type checking. + virtual internal::TypeInfo TypeId() const = 0; + + const Persistent type_; +}; + +// TODO(issues/5): generalize the macros to avoid repeating them when they +// are ultimately very similar. + +// CEL_DECLARE_LIST_VALUE declares `list_value` as an list value. It must +// be part of the class definition of `list_value`. +// +// class MyListValue : public cel::ListValue { +// ... +// private: +// CEL_DECLARE_LIST_VALUE(MyListValue); +// }; +#define CEL_DECLARE_LIST_VALUE(list_value) \ + private: \ + friend class ::cel::base_internal::ValueHandleBase; \ + \ + static bool Is(const ::cel::Value& value); \ + \ + ::std::pair<::std::size_t, ::std::size_t> SizeAndAlignment() const override; \ + \ + ::cel::internal::TypeInfo TypeId() const override; + +// CEL_IMPLEMENT_LIST_VALUE implements `list_value` as an list +// value. It must be called after the class definition of `list_value`. +// +// class MyListValue : public cel::ListValue { +// ... +// private: +// CEL_DECLARE_LIST_VALUE(MyListValue); +// }; +// +// CEL_IMPLEMENT_LIST_VALUE(MyListValue); +#define CEL_IMPLEMENT_LIST_VALUE(list_value) \ + static_assert(::std::is_base_of_v<::cel::ListValue, list_value>, \ + #list_value " must inherit from cel::ListValue"); \ + static_assert(!::std::is_abstract_v, \ + "this must not be abstract"); \ + \ + bool list_value::Is(const ::cel::Value& value) { \ + return value.kind() == ::cel::Kind::kList && \ + ::cel::base_internal::GetListValueTypeId( \ + ::cel::internal::down_cast(value)) == \ + ::cel::internal::TypeId(); \ + } \ + \ + ::std::pair<::std::size_t, ::std::size_t> list_value::SizeAndAlignment() \ + const { \ + static_assert( \ + ::std::is_same_v>>, \ + "this must be the same as " #list_value); \ + return ::std::pair<::std::size_t, ::std::size_t>(sizeof(list_value), \ + alignof(list_value)); \ + } \ + \ + ::cel::internal::TypeInfo list_value::TypeId() const { \ + return ::cel::internal::TypeId(); \ + } + } // namespace cel // value.pre.h forward declares types so they can be friended above. The types diff --git a/base/value_factory.h b/base/value_factory.h index cc072fe80..450673213 100644 --- a/base/value_factory.h +++ b/base/value_factory.h @@ -152,6 +152,15 @@ class ValueFactory final { std::remove_const_t>(memory_manager(), std::forward(args)...); } + template + EnableIfBaseOfT>> CreateListValue( + const Persistent& type, + Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { + return base_internal::PersistentHandleFactory::template Make< + std::remove_const_t>(memory_manager(), type, + std::forward(args)...); + } + private: friend class BytesValue; friend class StringValue; diff --git a/base/value_test.cc b/base/value_test.cc index 46a4680bb..8a20ab43f 100644 --- a/base/value_test.cc +++ b/base/value_test.cc @@ -27,6 +27,7 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "absl/time/time.h" #include "base/memory_manager.h" #include "base/type.h" @@ -343,11 +344,55 @@ class TestStructType final : public StructType { CEL_IMPLEMENT_STRUCT_TYPE(TestStructType); +class TestListValue final : public ListValue { + public: + explicit TestListValue(const Persistent& type, + std::vector elements) + : ListValue(type), elements_(std::move(elements)) { + ABSL_ASSERT(type->element().Is()); + } + + size_t size() const override { return elements_.size(); } + + absl::StatusOr> Get(ValueFactory& value_factory, + size_t index) const override { + return value_factory.CreateIntValue(elements_[index]); + } + + std::string DebugString() const override { + return absl::StrCat("[", absl::StrJoin(elements_, ", "), "]"); + } + + const std::vector& value() const { return elements_; } + + private: + bool Equals(const Value& other) const override { + return Is(other) && + elements_ == + internal::down_cast(other).elements_; + } + + void HashValue(absl::HashState state) const override { + absl::HashState::combine(std::move(state), type(), elements_); + } + + std::vector elements_; + + CEL_DECLARE_LIST_VALUE(TestListValue); +}; + +CEL_IMPLEMENT_LIST_VALUE(TestListValue); + template Persistent Must(absl::StatusOr> status_or_handle) { return std::move(status_or_handle).value(); } +template +Transient Must(absl::StatusOr> status_or_handle) { + return std::move(status_or_handle).value(); +} + template constexpr void IS_INITIALIZED(T&) {} @@ -395,7 +440,8 @@ TEST(Value, DefaultConstructor) { struct ConstructionAssignmentTestCase final { std::string name; - std::function(ValueFactory&)> default_value; + std::function(TypeFactory&, ValueFactory&)> + default_value; }; using ConstructionAssignmentTest = @@ -403,27 +449,33 @@ using ConstructionAssignmentTest = TEST_P(ConstructionAssignmentTest, CopyConstructor) { const auto& test_case = GetParam(); + TypeFactory type_factory(MemoryManager::Global()); ValueFactory value_factory(MemoryManager::Global()); - Persistent from(test_case.default_value(value_factory)); + Persistent from( + test_case.default_value(type_factory, value_factory)); Persistent to(from); IS_INITIALIZED(to); - EXPECT_EQ(to, test_case.default_value(value_factory)); + EXPECT_EQ(to, test_case.default_value(type_factory, value_factory)); } TEST_P(ConstructionAssignmentTest, MoveConstructor) { const auto& test_case = GetParam(); + TypeFactory type_factory(MemoryManager::Global()); ValueFactory value_factory(MemoryManager::Global()); - Persistent from(test_case.default_value(value_factory)); + Persistent from( + test_case.default_value(type_factory, value_factory)); Persistent to(std::move(from)); IS_INITIALIZED(from); EXPECT_EQ(from, value_factory.GetNullValue()); - EXPECT_EQ(to, test_case.default_value(value_factory)); + EXPECT_EQ(to, test_case.default_value(type_factory, value_factory)); } TEST_P(ConstructionAssignmentTest, CopyAssignment) { const auto& test_case = GetParam(); + TypeFactory type_factory(MemoryManager::Global()); ValueFactory value_factory(MemoryManager::Global()); - Persistent from(test_case.default_value(value_factory)); + Persistent from( + test_case.default_value(type_factory, value_factory)); Persistent to; to = from; EXPECT_EQ(to, from); @@ -431,53 +483,71 @@ TEST_P(ConstructionAssignmentTest, CopyAssignment) { TEST_P(ConstructionAssignmentTest, MoveAssignment) { const auto& test_case = GetParam(); + TypeFactory type_factory(MemoryManager::Global()); ValueFactory value_factory(MemoryManager::Global()); - Persistent from(test_case.default_value(value_factory)); + Persistent from( + test_case.default_value(type_factory, value_factory)); Persistent to; to = std::move(from); IS_INITIALIZED(from); EXPECT_EQ(from, value_factory.GetNullValue()); - EXPECT_EQ(to, test_case.default_value(value_factory)); + EXPECT_EQ(to, test_case.default_value(type_factory, value_factory)); } INSTANTIATE_TEST_SUITE_P( ConstructionAssignmentTest, ConstructionAssignmentTest, testing::ValuesIn({ {"Null", - [](ValueFactory& value_factory) -> Persistent { + [](TypeFactory& type_factory, + ValueFactory& value_factory) -> Persistent { return value_factory.GetNullValue(); }}, {"Bool", - [](ValueFactory& value_factory) -> Persistent { + [](TypeFactory& type_factory, + ValueFactory& value_factory) -> Persistent { return value_factory.CreateBoolValue(false); }}, {"Int", - [](ValueFactory& value_factory) -> Persistent { + [](TypeFactory& type_factory, + ValueFactory& value_factory) -> Persistent { return value_factory.CreateIntValue(0); }}, {"Uint", - [](ValueFactory& value_factory) -> Persistent { + [](TypeFactory& type_factory, + ValueFactory& value_factory) -> Persistent { return value_factory.CreateUintValue(0); }}, {"Double", - [](ValueFactory& value_factory) -> Persistent { + [](TypeFactory& type_factory, + ValueFactory& value_factory) -> Persistent { return value_factory.CreateDoubleValue(0.0); }}, {"Duration", - [](ValueFactory& value_factory) -> Persistent { + [](TypeFactory& type_factory, + ValueFactory& value_factory) -> Persistent { return Must(value_factory.CreateDurationValue(absl::ZeroDuration())); }}, {"Timestamp", - [](ValueFactory& value_factory) -> Persistent { + [](TypeFactory& type_factory, + ValueFactory& value_factory) -> Persistent { return Must(value_factory.CreateTimestampValue(absl::UnixEpoch())); }}, {"Error", - [](ValueFactory& value_factory) -> Persistent { + [](TypeFactory& type_factory, + ValueFactory& value_factory) -> Persistent { return value_factory.CreateErrorValue(absl::CancelledError()); }}, {"Bytes", - [](ValueFactory& value_factory) -> Persistent { - return Must(value_factory.CreateBytesValue(0)); + [](TypeFactory& type_factory, + ValueFactory& value_factory) -> Persistent { + return Must(value_factory.CreateBytesValue(nullptr)); + }}, + {"List", + [](TypeFactory& type_factory, + ValueFactory& value_factory) -> Persistent { + return Must(value_factory.CreateListValue( + Must(type_factory.CreateListType(type_factory.GetIntType())), + std::vector{})); }}, }), [](const testing::TestParamInfo& info) { @@ -1907,6 +1977,78 @@ TEST(StructValue, HasField) { StatusIs(absl::StatusCode::kNotFound)); } +TEST(Value, List) { + ValueFactory value_factory(MemoryManager::Global()); + TypeFactory type_factory(MemoryManager::Global()); + ASSERT_OK_AND_ASSIGN(auto list_type, + type_factory.CreateListType(type_factory.GetIntType())); + ASSERT_OK_AND_ASSIGN(auto zero_value, + value_factory.CreateListValue( + list_type, std::vector{})); + EXPECT_TRUE(zero_value.Is()); + EXPECT_TRUE(zero_value.Is()); + EXPECT_FALSE(zero_value.Is()); + EXPECT_EQ(zero_value, zero_value); + EXPECT_EQ(zero_value, Must(value_factory.CreateListValue( + list_type, std::vector{}))); + EXPECT_EQ(zero_value->kind(), Kind::kList); + EXPECT_EQ(zero_value->type(), list_type); + EXPECT_EQ(zero_value.As()->value(), std::vector{}); + + ASSERT_OK_AND_ASSIGN(auto one_value, + value_factory.CreateListValue( + list_type, std::vector{1})); + EXPECT_TRUE(one_value.Is()); + EXPECT_TRUE(one_value.Is()); + EXPECT_FALSE(one_value.Is()); + EXPECT_EQ(one_value, one_value); + EXPECT_EQ(one_value->kind(), Kind::kList); + EXPECT_EQ(one_value->type(), list_type); + EXPECT_EQ(one_value.As()->value(), std::vector{1}); + + EXPECT_NE(zero_value, one_value); + EXPECT_NE(one_value, zero_value); +} + +TEST(ListValue, DebugString) { + ValueFactory value_factory(MemoryManager::Global()); + TypeFactory type_factory(MemoryManager::Global()); + ASSERT_OK_AND_ASSIGN(auto list_type, + type_factory.CreateListType(type_factory.GetIntType())); + ASSERT_OK_AND_ASSIGN(auto list_value, + value_factory.CreateListValue( + list_type, std::vector{})); + EXPECT_EQ(list_value->DebugString(), "[]"); + ASSERT_OK_AND_ASSIGN(list_value, + value_factory.CreateListValue( + list_type, std::vector{0, 1, 2, 3, 4, 5})); + EXPECT_EQ(list_value->DebugString(), "[0, 1, 2, 3, 4, 5]"); +} + +TEST(ListValue, Get) { + ValueFactory value_factory(MemoryManager::Global()); + TypeFactory type_factory(MemoryManager::Global()); + ASSERT_OK_AND_ASSIGN(auto list_type, + type_factory.CreateListType(type_factory.GetIntType())); + ASSERT_OK_AND_ASSIGN(auto list_value, + value_factory.CreateListValue( + list_type, std::vector{})); + EXPECT_TRUE(list_value->empty()); + EXPECT_EQ(list_value->size(), 0); + + ASSERT_OK_AND_ASSIGN(list_value, + value_factory.CreateListValue( + list_type, std::vector{0, 1, 2})); + EXPECT_FALSE(list_value->empty()); + EXPECT_EQ(list_value->size(), 3); + EXPECT_EQ(Must(list_value->Get(value_factory, 0)), + value_factory.CreateIntValue(0)); + EXPECT_EQ(Must(list_value->Get(value_factory, 1)), + value_factory.CreateIntValue(1)); + EXPECT_EQ(Must(list_value->Get(value_factory, 2)), + value_factory.CreateIntValue(2)); +} + TEST(Value, SupportsAbslHash) { ValueFactory value_factory(MemoryManager::Global()); TypeFactory type_factory(MemoryManager::Global()); @@ -1919,6 +2061,11 @@ TEST(Value, SupportsAbslHash) { EnumValue::New(enum_type, value_factory, EnumType::ConstantId("VALUE1"))); ASSERT_OK_AND_ASSIGN(auto struct_value, StructValue::New(struct_type, value_factory)); + ASSERT_OK_AND_ASSIGN(auto list_type, + type_factory.CreateListType(type_factory.GetIntType())); + ASSERT_OK_AND_ASSIGN(auto list_value, + value_factory.CreateListValue( + list_type, std::vector{})); EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly({ Persistent(value_factory.GetNullValue()), Persistent( @@ -1941,6 +2088,7 @@ TEST(Value, SupportsAbslHash) { Must(value_factory.CreateStringValue(absl::Cord("bar")))), Persistent(enum_value), Persistent(struct_value), + Persistent(list_value), })); } From 221541bc196bf56e69b6847a1cd096929aab5b89 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Tue, 12 Apr 2022 21:02:12 +0000 Subject: [PATCH 062/111] Branch MessageWrapper library to help resolve cyclic dependencies. PiperOrigin-RevId: 441284720 --- eval/public/structs/BUILD | 58 ++ eval/public/structs/cel_proto_wrap_util.cc | 887 +++++++++++++++++ eval/public/structs/cel_proto_wrap_util.h | 48 + .../structs/cel_proto_wrap_util_test.cc | 890 ++++++++++++++++++ eval/public/structs/protobuf_value_factory.h | 36 + 5 files changed, 1919 insertions(+) create mode 100644 eval/public/structs/cel_proto_wrap_util.cc create mode 100644 eval/public/structs/cel_proto_wrap_util.h create mode 100644 eval/public/structs/cel_proto_wrap_util_test.cc create mode 100644 eval/public/structs/protobuf_value_factory.h diff --git a/eval/public/structs/BUILD b/eval/public/structs/BUILD index 732b569b5..99d686e9d 100644 --- a/eval/public/structs/BUILD +++ b/eval/public/structs/BUILD @@ -38,6 +38,64 @@ cc_library( ], ) +cc_library( + name = "protobuf_value_factory", + hdrs = [ + "protobuf_value_factory.h", + ], + deps = [ + "//eval/public:cel_value", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "cel_proto_wrap_util", + srcs = [ + "cel_proto_wrap_util.cc", + ], + hdrs = [ + "cel_proto_wrap_util.h", + ], + deps = [ + ":protobuf_value_factory", + "//eval/public:cel_value", + "//eval/testutil:test_message_cc_proto", + "//internal:overflow", + "//internal:proto_util", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:optional", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "cel_proto_wrap_util_test", + size = "small", + srcs = [ + "cel_proto_wrap_util_test.cc", + ], + deps = [ + ":cel_proto_wrap_util", + ":protobuf_value_factory", + "//eval/public:cel_value", + "//eval/public/containers:container_backed_list_impl", + "//eval/public/containers:container_backed_map_impl", + "//eval/testutil:test_message_cc_proto", + "//internal:proto_util", + "//internal:status_macros", + "//internal:testing", + "//testutil:util", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@com_google_protobuf//:protobuf", + ], +) + cc_library( name = "cel_proto_descriptor_pool_builder", srcs = ["cel_proto_descriptor_pool_builder.cc"], diff --git a/eval/public/structs/cel_proto_wrap_util.cc b/eval/public/structs/cel_proto_wrap_util.cc new file mode 100644 index 000000000..25f0c41e8 --- /dev/null +++ b/eval/public/structs/cel_proto_wrap_util.cc @@ -0,0 +1,887 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/structs/cel_proto_wrap_util.h" + +#include + +#include +#include +#include +#include +#include + +#include "google/protobuf/any.pb.h" +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/struct.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "google/protobuf/wrappers.pb.h" +#include "google/protobuf/message.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/optional.h" +#include "eval/public/cel_value.h" +#include "eval/public/structs/protobuf_value_factory.h" +#include "eval/testutil/test_message.pb.h" +#include "internal/overflow.h" +#include "internal/proto_util.h" + +namespace google::api::expr::runtime::internal { + +namespace { + +using google::api::expr::internal::DecodeDuration; +using google::api::expr::internal::DecodeTime; +using google::api::expr::internal::EncodeTime; +using google::protobuf::Any; +using google::protobuf::BoolValue; +using google::protobuf::BytesValue; +using google::protobuf::DoubleValue; +using google::protobuf::Duration; +using google::protobuf::FloatValue; +using google::protobuf::Int32Value; +using google::protobuf::Int64Value; +using google::protobuf::ListValue; +using google::protobuf::StringValue; +using google::protobuf::Struct; +using google::protobuf::Timestamp; +using google::protobuf::UInt32Value; +using google::protobuf::UInt64Value; +using google::protobuf::Value; +using google::protobuf::Arena; +using google::protobuf::Descriptor; +using google::protobuf::DescriptorPool; +using google::protobuf::Message; +using google::protobuf::MessageFactory; + +// kMaxIntJSON is defined as the Number.MAX_SAFE_INTEGER value per EcmaScript 6. +constexpr int64_t kMaxIntJSON = (1ll << 53) - 1; + +// kMinIntJSON is defined as the Number.MIN_SAFE_INTEGER value per EcmaScript 6. +constexpr int64_t kMinIntJSON = -kMaxIntJSON; + +// Forward declaration for google.protobuf.Value +google::protobuf::Message* MessageFromValue(const CelValue& value, Value* json); + +// IsJSONSafe indicates whether the int is safely representable as a floating +// point value in JSON. +static bool IsJSONSafe(int64_t i) { + return i >= kMinIntJSON && i <= kMaxIntJSON; +} + +// IsJSONSafe indicates whether the uint is safely representable as a floating +// point value in JSON. +static bool IsJSONSafe(uint64_t i) { + return i <= static_cast(kMaxIntJSON); +} + +// Map implementation wrapping google.protobuf.ListValue +class DynamicList : public CelList { + public: + DynamicList(const ListValue* values, ProtobufValueFactory factory, + Arena* arena) + : arena_(arena), factory_(std::move(factory)), values_(values) {} + + CelValue operator[](int index) const override; + + // List size + int size() const override { return values_->values_size(); } + + private: + Arena* arena_; + ProtobufValueFactory factory_; + const ListValue* values_; +}; + +// Map implementation wrapping google.protobuf.Struct. +class DynamicMap : public CelMap { + public: + DynamicMap(const Struct* values, ProtobufValueFactory factory, Arena* arena) + : arena_(arena), + factory_(std::move(factory)), + values_(values), + key_list_(values) {} + + absl::StatusOr Has(const CelValue& key) const override { + CelValue::StringHolder str_key; + if (!key.GetValue(&str_key)) { + // Not a string key. + return absl::InvalidArgumentError(absl::StrCat( + "Invalid map key type: '", CelValue::TypeName(key.type()), "'")); + } + + return values_->fields().contains(std::string(str_key.value())); + } + + absl::optional operator[](CelValue key) const override; + + int size() const override { return values_->fields_size(); } + + const CelList* ListKeys() const override { return &key_list_; } + + private: + // List of keys in Struct.fields map. + // It utilizes lazy initialization, to avoid performance penalties. + class DynamicMapKeyList : public CelList { + public: + explicit DynamicMapKeyList(const Struct* values) + : values_(values), keys_(), initialized_(false) {} + + // Index access + CelValue operator[](int index) const override { + CheckInit(); + return keys_[index]; + } + + // List size + int size() const override { + CheckInit(); + return values_->fields_size(); + } + + private: + void CheckInit() const { + absl::MutexLock lock(&mutex_); + if (!initialized_) { + for (const auto& it : values_->fields()) { + keys_.push_back(CelValue::CreateString(&it.first)); + } + initialized_ = true; + } + } + + const Struct* values_; + mutable absl::Mutex mutex_; + mutable std::vector keys_; + mutable bool initialized_; + }; + + Arena* arena_; + ProtobufValueFactory factory_; + const Struct* values_; + const DynamicMapKeyList key_list_; +}; + +// ValueFactory provides ValueFromMessage(....) function family. +// Functions of this family create CelValue object from specific subtypes of +// protobuf message. +class ValueFactory { + public: + ValueFactory(const ProtobufValueFactory& factory, google::protobuf::Arena* arena) + : factory_(factory), arena_(arena) {} + + CelValue ValueFromMessage(const Duration* duration) { + return CelValue::CreateDuration(DecodeDuration(*duration)); + } + + CelValue ValueFromMessage(const Timestamp* timestamp) { + return CelValue::CreateTimestamp(DecodeTime(*timestamp)); + } + + CelValue ValueFromMessage(const ListValue* list_values) { + return CelValue::CreateList( + Arena::Create(arena_, list_values, factory_, arena_)); + } + + CelValue ValueFromMessage(const Struct* struct_value) { + return CelValue::CreateMap( + Arena::Create(arena_, struct_value, factory_, arena_)); + } + + CelValue ValueFromMessage(const Any* any_value, + const DescriptorPool* descriptor_pool, + MessageFactory* message_factory) { + auto type_url = any_value->type_url(); + auto pos = type_url.find_last_of('/'); + if (pos == absl::string_view::npos) { + // TODO(issues/25) What error code? + // Malformed type_url + return CreateErrorValue(arena_, "Malformed type_url string"); + } + + std::string full_name = std::string(type_url.substr(pos + 1)); + const Descriptor* nested_descriptor = + descriptor_pool->FindMessageTypeByName(full_name); + + if (nested_descriptor == nullptr) { + // Descriptor not found for the type + // TODO(issues/25) What error code? + return CreateErrorValue(arena_, "Descriptor not found"); + } + + const Message* prototype = message_factory->GetPrototype(nested_descriptor); + if (prototype == nullptr) { + // Failed to obtain prototype for the descriptor + // TODO(issues/25) What error code? + return CreateErrorValue(arena_, "Prototype not found"); + } + + Message* nested_message = prototype->New(arena_); + if (!any_value->UnpackTo(nested_message)) { + // Failed to unpack. + // TODO(issues/25) What error code? + return CreateErrorValue(arena_, "Failed to unpack Any into message"); + } + + return UnwrapMessageToValue(nested_message, factory_, arena_); + } + + CelValue ValueFromMessage(const Any* any_value) { + return ValueFromMessage(any_value, DescriptorPool::generated_pool(), + MessageFactory::generated_factory()); + } + + CelValue ValueFromMessage(const BoolValue* wrapper) { + return CelValue::CreateBool(wrapper->value()); + } + + CelValue ValueFromMessage(const Int32Value* wrapper) { + return CelValue::CreateInt64(wrapper->value()); + } + + CelValue ValueFromMessage(const UInt32Value* wrapper) { + return CelValue::CreateUint64(wrapper->value()); + } + + CelValue ValueFromMessage(const Int64Value* wrapper) { + return CelValue::CreateInt64(wrapper->value()); + } + + CelValue ValueFromMessage(const UInt64Value* wrapper) { + return CelValue::CreateUint64(wrapper->value()); + } + + CelValue ValueFromMessage(const FloatValue* wrapper) { + return CelValue::CreateDouble(wrapper->value()); + } + + CelValue ValueFromMessage(const DoubleValue* wrapper) { + return CelValue::CreateDouble(wrapper->value()); + } + + CelValue ValueFromMessage(const StringValue* wrapper) { + return CelValue::CreateString(&wrapper->value()); + } + + CelValue ValueFromMessage(const BytesValue* wrapper) { + // BytesValue stores value as Cord + return CelValue::CreateBytes( + Arena::Create(arena_, std::string(wrapper->value()))); + } + + CelValue ValueFromMessage(const Value* value) { + switch (value->kind_case()) { + case Value::KindCase::kNullValue: + return CelValue::CreateNull(); + case Value::KindCase::kNumberValue: + return CelValue::CreateDouble(value->number_value()); + case Value::KindCase::kStringValue: + return CelValue::CreateString(&value->string_value()); + case Value::KindCase::kBoolValue: + return CelValue::CreateBool(value->bool_value()); + case Value::KindCase::kStructValue: + return UnwrapMessageToValue(&value->struct_value(), factory_, arena_); + case Value::KindCase::kListValue: + return UnwrapMessageToValue(&value->list_value(), factory_, arena_); + default: + return CelValue::CreateNull(); + } + } + + private: + const ProtobufValueFactory& factory_; + google::protobuf::Arena* arena_; +}; + +// Class makes CelValue from generic protobuf Message. +// It holds a registry of CelValue factories for specific subtypes of Message. +// If message does not match any of types stored in registry, generic +// message-containing CelValue is created. +class ValueFromMessageMaker { + public: + template + static CelValue CreateWellknownTypeValue(const google::protobuf::Message* msg, + const ProtobufValueFactory& factory, + Arena* arena) { + const MessageType* message = + google::protobuf::DynamicCastToGenerated(msg); + if (message == nullptr) { + auto message_copy = Arena::CreateMessage(arena); + if (MessageType::descriptor() == msg->GetDescriptor()) { + message_copy->CopyFrom(*msg); + message = message_copy; + } else { + // message of well-known type but from a descriptor pool other than the + // generated one. + std::string serialized_msg; + if (msg->SerializeToString(&serialized_msg) && + message_copy->ParseFromString(serialized_msg)) { + message = message_copy; + } + } + } + return ValueFactory(factory, arena).ValueFromMessage(message); + } + + static absl::optional CreateValue( + const google::protobuf::Message* message, const ProtobufValueFactory& factory, + Arena* arena) { + switch (message->GetDescriptor()->well_known_type()) { + case google::protobuf::Descriptor::WELLKNOWNTYPE_DOUBLEVALUE: + return CreateWellknownTypeValue(message, factory, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_FLOATVALUE: + return CreateWellknownTypeValue(message, factory, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_INT64VALUE: + return CreateWellknownTypeValue(message, factory, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT64VALUE: + return CreateWellknownTypeValue(message, factory, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_INT32VALUE: + return CreateWellknownTypeValue(message, factory, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT32VALUE: + return CreateWellknownTypeValue(message, factory, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_STRINGVALUE: + return CreateWellknownTypeValue(message, factory, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_BYTESVALUE: + return CreateWellknownTypeValue(message, factory, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_BOOLVALUE: + return CreateWellknownTypeValue(message, factory, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_ANY: + return CreateWellknownTypeValue(message, factory, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_DURATION: + return CreateWellknownTypeValue(message, factory, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_TIMESTAMP: + return CreateWellknownTypeValue(message, factory, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE: + return CreateWellknownTypeValue(message, factory, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE: + return CreateWellknownTypeValue(message, factory, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT: + return CreateWellknownTypeValue(message, factory, arena); + // WELLKNOWNTYPE_FIELDMASK has no special CelValue type + default: + return absl::nullopt; + } + } + + // Non-copyable, non-assignable + ValueFromMessageMaker(const ValueFromMessageMaker&) = delete; + ValueFromMessageMaker& operator=(const ValueFromMessageMaker&) = delete; +}; + +CelValue DynamicList::operator[](int index) const { + return ValueFactory(factory_, arena_) + .ValueFromMessage(&values_->values(index)); +} + +absl::optional DynamicMap::operator[](CelValue key) const { + CelValue::StringHolder str_key; + if (!key.GetValue(&str_key)) { + // Not a string key. + return CreateErrorValue(arena_, absl::InvalidArgumentError(absl::StrCat( + "Invalid map key type: '", + CelValue::TypeName(key.type()), "'"))); + } + + auto it = values_->fields().find(std::string(str_key.value())); + if (it == values_->fields().end()) { + return absl::nullopt; + } + + return ValueFactory(factory_, arena_).ValueFromMessage(&it->second); +} + +google::protobuf::Message* MessageFromValue(const CelValue& value, Duration* duration) { + absl::Duration val; + if (!value.GetValue(&val)) { + return nullptr; + } + auto status = google::api::expr::internal::EncodeDuration(val, duration); + if (!status.ok()) { + return nullptr; + } + return duration; +} + +google::protobuf::Message* MessageFromValue(const CelValue& value, BoolValue* wrapper) { + bool val; + if (!value.GetValue(&val)) { + return nullptr; + } + wrapper->set_value(val); + return wrapper; +} + +google::protobuf::Message* MessageFromValue(const CelValue& value, BytesValue* wrapper) { + CelValue::BytesHolder view_val; + if (!value.GetValue(&view_val)) { + return nullptr; + } + wrapper->set_value(view_val.value().data()); + return wrapper; +} + +google::protobuf::Message* MessageFromValue(const CelValue& value, DoubleValue* wrapper) { + double val; + if (!value.GetValue(&val)) { + return nullptr; + } + wrapper->set_value(val); + return wrapper; +} + +google::protobuf::Message* MessageFromValue(const CelValue& value, FloatValue* wrapper) { + double val; + if (!value.GetValue(&val)) { + return nullptr; + } + // Abort the conversion if the value is outside the float range. + if (val > std::numeric_limits::max()) { + wrapper->set_value(std::numeric_limits::infinity()); + return wrapper; + } + if (val < std::numeric_limits::lowest()) { + wrapper->set_value(-std::numeric_limits::infinity()); + return wrapper; + } + wrapper->set_value(val); + return wrapper; +} + +google::protobuf::Message* MessageFromValue(const CelValue& value, Int32Value* wrapper) { + int64_t val; + if (!value.GetValue(&val)) { + return nullptr; + } + // Abort the conversion if the value is outside the int32_t range. + if (!cel::internal::CheckedInt64ToInt32(val).ok()) { + return nullptr; + } + wrapper->set_value(val); + return wrapper; +} + +google::protobuf::Message* MessageFromValue(const CelValue& value, Int64Value* wrapper) { + int64_t val; + if (!value.GetValue(&val)) { + return nullptr; + } + wrapper->set_value(val); + return wrapper; +} + +google::protobuf::Message* MessageFromValue(const CelValue& value, StringValue* wrapper) { + CelValue::StringHolder view_val; + if (!value.GetValue(&view_val)) { + return nullptr; + } + wrapper->set_value(view_val.value().data()); + return wrapper; +} + +google::protobuf::Message* MessageFromValue(const CelValue& value, Timestamp* timestamp) { + absl::Time val; + if (!value.GetValue(&val)) { + return nullptr; + } + auto status = EncodeTime(val, timestamp); + if (!status.ok()) { + return nullptr; + } + return timestamp; +} + +google::protobuf::Message* MessageFromValue(const CelValue& value, UInt32Value* wrapper) { + uint64_t val; + if (!value.GetValue(&val)) { + return nullptr; + } + // Abort the conversion if the value is outside the uint32_t range. + if (!cel::internal::CheckedUint64ToUint32(val).ok()) { + return nullptr; + } + wrapper->set_value(val); + return wrapper; +} + +google::protobuf::Message* MessageFromValue(const CelValue& value, UInt64Value* wrapper) { + uint64_t val; + if (!value.GetValue(&val)) { + return nullptr; + } + wrapper->set_value(val); + return wrapper; +} + +google::protobuf::Message* MessageFromValue(const CelValue& value, ListValue* json_list) { + if (!value.IsList()) { + return nullptr; + } + const CelList& list = *value.ListOrDie(); + for (int i = 0; i < list.size(); i++) { + auto e = list[i]; + Value* elem = json_list->add_values(); + auto result = MessageFromValue(e, elem); + if (result == nullptr) { + return nullptr; + } + } + return json_list; +} + +google::protobuf::Message* MessageFromValue(const CelValue& value, Struct* json_struct) { + if (!value.IsMap()) { + return nullptr; + } + const CelMap& map = *value.MapOrDie(); + const auto& keys = *map.ListKeys(); + auto fields = json_struct->mutable_fields(); + for (int i = 0; i < keys.size(); i++) { + auto k = keys[i]; + // If the key is not a string type, abort the conversion. + if (!k.IsString()) { + return nullptr; + } + absl::string_view key = k.StringOrDie().value(); + + auto v = map[k]; + if (!v.has_value()) { + return nullptr; + } + Value field_value; + auto result = MessageFromValue(*v, &field_value); + // If the value is not a valid JSON type, abort the conversion. + if (result == nullptr) { + return nullptr; + } + (*fields)[std::string(key)] = field_value; + } + return json_struct; +} + +google::protobuf::Message* MessageFromValue(const CelValue& value, Value* json) { + switch (value.type()) { + case CelValue::Type::kBool: { + bool val; + if (value.GetValue(&val)) { + json->set_bool_value(val); + return json; + } + } break; + case CelValue::Type::kBytes: { + // Base64 encode byte strings to ensure they can safely be transpored + // in a JSON string. + CelValue::BytesHolder val; + if (value.GetValue(&val)) { + json->set_string_value(absl::Base64Escape(val.value())); + return json; + } + } break; + case CelValue::Type::kDouble: { + double val; + if (value.GetValue(&val)) { + json->set_number_value(val); + return json; + } + } break; + case CelValue::Type::kDuration: { + // Convert duration values to a protobuf JSON format. + absl::Duration val; + if (value.GetValue(&val)) { + auto encode = google::api::expr::internal::EncodeDurationToString(val); + if (!encode.ok()) { + return nullptr; + } + json->set_string_value(*encode); + return json; + } + } break; + case CelValue::Type::kInt64: { + int64_t val; + // Convert int64_t values within the int53 range to doubles, otherwise + // serialize the value to a string. + if (value.GetValue(&val)) { + if (IsJSONSafe(val)) { + json->set_number_value(val); + } else { + json->set_string_value(absl::StrCat(val)); + } + return json; + } + } break; + case CelValue::Type::kString: { + CelValue::StringHolder val; + if (value.GetValue(&val)) { + json->set_string_value(val.value().data()); + return json; + } + } break; + case CelValue::Type::kTimestamp: { + // Convert timestamp values to a protobuf JSON format. + absl::Time val; + if (value.GetValue(&val)) { + auto encode = google::api::expr::internal::EncodeTimeToString(val); + if (!encode.ok()) { + return nullptr; + } + json->set_string_value(*encode); + return json; + } + } break; + case CelValue::Type::kUint64: { + uint64_t val; + // Convert uint64_t values within the int53 range to doubles, otherwise + // serialize the value to a string. + if (value.GetValue(&val)) { + if (IsJSONSafe(val)) { + json->set_number_value(val); + } else { + json->set_string_value(absl::StrCat(val)); + } + return json; + } + } break; + case CelValue::Type::kList: { + auto lv = MessageFromValue(value, json->mutable_list_value()); + if (lv != nullptr) { + return json; + } + } break; + case CelValue::Type::kMap: { + auto sv = MessageFromValue(value, json->mutable_struct_value()); + if (sv != nullptr) { + return json; + } + } break; + case CelValue::Type::kNullType: + json->set_null_value(protobuf::NULL_VALUE); + return json; + default: + return nullptr; + } + return nullptr; +} + +google::protobuf::Message* MessageFromValue(const CelValue& value, Any* any) { + // In open source, any->PackFrom() returns void rather than boolean. + switch (value.type()) { + case CelValue::Type::kBool: { + BoolValue v; + auto msg = MessageFromValue(value, &v); + if (msg != nullptr) { + any->PackFrom(*msg); + return any; + } + } break; + case CelValue::Type::kBytes: { + BytesValue v; + auto msg = MessageFromValue(value, &v); + if (msg != nullptr) { + any->PackFrom(*msg); + return any; + } + } break; + case CelValue::Type::kDouble: { + DoubleValue v; + auto msg = MessageFromValue(value, &v); + if (msg != nullptr) { + any->PackFrom(*msg); + return any; + } + } break; + case CelValue::Type::kDuration: { + Duration v; + auto msg = MessageFromValue(value, &v); + if (msg != nullptr) { + any->PackFrom(*msg); + return any; + } + } break; + case CelValue::Type::kInt64: { + Int64Value v; + auto msg = MessageFromValue(value, &v); + if (msg != nullptr) { + any->PackFrom(*msg); + return any; + } + } break; + case CelValue::Type::kString: { + StringValue v; + auto msg = MessageFromValue(value, &v); + if (msg != nullptr) { + any->PackFrom(*msg); + return any; + } + } break; + case CelValue::Type::kTimestamp: { + Timestamp v; + auto msg = MessageFromValue(value, &v); + if (msg != nullptr) { + any->PackFrom(*msg); + return any; + } + } break; + case CelValue::Type::kUint64: { + UInt64Value v; + auto msg = MessageFromValue(value, &v); + if (msg != nullptr) { + any->PackFrom(*msg); + return any; + } + } break; + case CelValue::Type::kList: { + ListValue v; + auto msg = MessageFromValue(value, &v); + if (msg != nullptr) { + any->PackFrom(*msg); + return any; + } + } break; + case CelValue::Type::kMap: { + Struct v; + auto msg = MessageFromValue(value, &v); + if (msg != nullptr) { + any->PackFrom(*msg); + return any; + } + } break; + case CelValue::Type::kNullType: { + Value v; + auto msg = MessageFromValue(value, &v); + if (msg != nullptr) { + any->PackFrom(*msg); + return any; + } + } break; + case CelValue::Type::kMessage: { + any->PackFrom(*(value.MessageOrDie())); + return any; + } break; + default: + break; + } + return nullptr; +} + +// Factory class, responsible for populating a Message type instance with the +// value of a simple CelValue. +class MessageFromValueFactory { + public: + virtual ~MessageFromValueFactory() {} + virtual const google::protobuf::Descriptor* GetDescriptor() const = 0; + virtual absl::optional WrapMessage( + const CelValue& value, Arena* arena) const = 0; +}; + +// MessageFromValueMaker makes a specific protobuf Message instance based on +// the desired protobuf type name and an input CelValue. +// +// It holds a registry of CelValue factories for specific subtypes of Message. +// If message does not match any of types stored in registry, an the factory +// returns an absent value. +class MessageFromValueMaker { + public: + // Non-copyable, non-assignable + MessageFromValueMaker(const MessageFromValueMaker&) = delete; + MessageFromValueMaker& operator=(const MessageFromValueMaker&) = delete; + + template + static google::protobuf::Message* WrapWellknownTypeMessage(const CelValue& value, + Arena* arena) { + // If the value is a message type, see if it is already of the proper type + // name, and return it directly. + if (value.IsMessage()) { + const auto* msg = value.MessageOrDie(); + if (MessageType::descriptor()->well_known_type() == + msg->GetDescriptor()->well_known_type()) { + return nullptr; + } + } + // Otherwise, allocate an empty message type, and attempt to populate it + // using the proper MessageFromValue overload. + auto* msg_buffer = Arena::CreateMessage(arena); + return MessageFromValue(value, msg_buffer); + } + + static google::protobuf::Message* MaybeWrapMessage(const google::protobuf::Descriptor* descriptor, + const CelValue& value, + Arena* arena) { + switch (descriptor->well_known_type()) { + case google::protobuf::Descriptor::WELLKNOWNTYPE_DOUBLEVALUE: + return WrapWellknownTypeMessage(value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_FLOATVALUE: + return WrapWellknownTypeMessage(value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_INT64VALUE: + return WrapWellknownTypeMessage(value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT64VALUE: + return WrapWellknownTypeMessage(value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_INT32VALUE: + return WrapWellknownTypeMessage(value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT32VALUE: + return WrapWellknownTypeMessage(value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_STRINGVALUE: + return WrapWellknownTypeMessage(value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_BYTESVALUE: + return WrapWellknownTypeMessage(value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_BOOLVALUE: + return WrapWellknownTypeMessage(value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_ANY: + return WrapWellknownTypeMessage(value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_DURATION: + return WrapWellknownTypeMessage(value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_TIMESTAMP: + return WrapWellknownTypeMessage(value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE: + return WrapWellknownTypeMessage(value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE: + return WrapWellknownTypeMessage(value, arena); + case google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT: + return WrapWellknownTypeMessage(value, arena); + // WELLKNOWNTYPE_FIELDMASK has no special CelValue type + default: + return nullptr; + } + } +}; + +} // namespace + +CelValue UnwrapMessageToValue(const google::protobuf::Message* value, + const ProtobufValueFactory& factory, + Arena* arena) { + // Messages are Nullable types + if (value == nullptr) { + return CelValue::CreateNull(); + } + + absl::optional special_value = + ValueFromMessageMaker::CreateValue(value, factory, arena); + if (special_value.has_value()) { + return *special_value; + } + return factory(value); +} + +const google::protobuf::Message* MaybeWrapValueToMessage( + const google::protobuf::Descriptor* descriptor, const CelValue& value, Arena* arena) { + google::protobuf::Message* msg = + MessageFromValueMaker::MaybeWrapMessage(descriptor, value, arena); + return msg; +} + +} // namespace google::api::expr::runtime::internal diff --git a/eval/public/structs/cel_proto_wrap_util.h b/eval/public/structs/cel_proto_wrap_util.h new file mode 100644 index 000000000..a03f6ba2f --- /dev/null +++ b/eval/public/structs/cel_proto_wrap_util.h @@ -0,0 +1,48 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_CEL_PROTO_WRAP_UTIL_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_CEL_PROTO_WRAP_UTIL_H_ + +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "google/protobuf/descriptor.h" +#include "eval/public/cel_value.h" +#include "eval/public/structs/protobuf_value_factory.h" +#include "internal/proto_util.h" + +namespace google::api::expr::runtime::internal { + +// UnwrapValue creates CelValue from google::protobuf::Message. +// As some of CEL basic types are subclassing google::protobuf::Message, +// this method contains type checking and downcasts. +CelValue UnwrapMessageToValue(const google::protobuf::Message* value, + const ProtobufValueFactory& factory, + google::protobuf::Arena* arena); + +// MaybeWrapValue attempts to wrap the input value in a proto message with +// the given type_name. If the value can be wrapped, it is returned as a +// protobuf message. Otherwise, the result will be nullptr. +// +// This method is the complement to MaybeUnwrapValue which may unwrap a protobuf +// message to native CelValue representation during a protobuf field read. +// Just as CreateMessage should only be used when reading protobuf values, +// MaybeWrapValue should only be used when assigning protobuf fields. +const google::protobuf::Message* MaybeWrapValueToMessage( + const google::protobuf::Descriptor* descriptor, const CelValue& value, + google::protobuf::Arena* arena); + +} // namespace google::api::expr::runtime::internal + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_CEL_PROTO_WRAP_UTIL_H_ diff --git a/eval/public/structs/cel_proto_wrap_util_test.cc b/eval/public/structs/cel_proto_wrap_util_test.cc new file mode 100644 index 000000000..c4d5e0762 --- /dev/null +++ b/eval/public/structs/cel_proto_wrap_util_test.cc @@ -0,0 +1,890 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/structs/cel_proto_wrap_util.h" + +#include +#include +#include +#include + +#include "google/protobuf/any.pb.h" +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/empty.pb.h" +#include "google/protobuf/struct.pb.h" +#include "google/protobuf/wrappers.pb.h" +#include "google/protobuf/dynamic_message.h" +#include "google/protobuf/message.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/time/time.h" +#include "eval/public/cel_value.h" +#include "eval/public/containers/container_backed_list_impl.h" +#include "eval/public/containers/container_backed_map_impl.h" +#include "eval/public/structs/protobuf_value_factory.h" +#include "eval/testutil/test_message.pb.h" +#include "internal/proto_util.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "testutil/util.h" + +namespace google::api::expr::runtime::internal { + +namespace { + +using testing::Eq; +using testing::UnorderedPointwise; + +using google::protobuf::Duration; +using google::protobuf::ListValue; +using google::protobuf::Struct; +using google::protobuf::Timestamp; +using google::protobuf::Value; + +using google::protobuf::Any; +using google::protobuf::BoolValue; +using google::protobuf::BytesValue; +using google::protobuf::DoubleValue; +using google::protobuf::FloatValue; +using google::protobuf::Int32Value; +using google::protobuf::Int64Value; +using google::protobuf::StringValue; +using google::protobuf::UInt32Value; +using google::protobuf::UInt64Value; + +using google::protobuf::Arena; + +CelValue ProtobufValueFactoryImpl(const google::protobuf::Message* m) { + return CelValue::CreateMessageWrapper(CelValue::MessageWrapper(m)); +} + +class CelProtoWrapperTest : public ::testing::Test { + protected: + CelProtoWrapperTest() {} + + void ExpectWrappedMessage(const CelValue& value, + const google::protobuf::Message& message) { + // Test the input value wraps to the destination message type. + auto* result = + MaybeWrapValueToMessage(message.GetDescriptor(), value, arena()); + EXPECT_TRUE(result != nullptr); + EXPECT_THAT(result, testutil::EqualsProto(message)); + + // Ensure that double wrapping results in the object being wrapped once. + auto* identity = MaybeWrapValueToMessage( + message.GetDescriptor(), ProtobufValueFactoryImpl(result), arena()); + EXPECT_TRUE(identity == nullptr); + + // Check to make sure that even dynamic messages can be used as input to + // the wrapping call. + result = MaybeWrapValueToMessage(ReflectedCopy(message)->GetDescriptor(), + value, arena()); + EXPECT_TRUE(result != nullptr); + EXPECT_THAT(result, testutil::EqualsProto(message)); + } + + void ExpectNotWrapped(const CelValue& value, const google::protobuf::Message& message) { + // Test the input value does not wrap by asserting value == result. + auto result = + MaybeWrapValueToMessage(message.GetDescriptor(), value, arena()); + EXPECT_TRUE(result == nullptr); + } + + template + void ExpectUnwrappedPrimitive(const google::protobuf::Message& message, T result) { + CelValue cel_value = + UnwrapMessageToValue(&message, &ProtobufValueFactoryImpl, arena()); + T value; + EXPECT_TRUE(cel_value.GetValue(&value)); + EXPECT_THAT(value, Eq(result)); + + T dyn_value; + CelValue cel_dyn_value = UnwrapMessageToValue( + ReflectedCopy(message).get(), &ProtobufValueFactoryImpl, arena()); + EXPECT_THAT(cel_dyn_value.type(), Eq(cel_value.type())); + EXPECT_TRUE(cel_dyn_value.GetValue(&dyn_value)); + EXPECT_THAT(value, Eq(dyn_value)); + } + + void ExpectUnwrappedMessage(const google::protobuf::Message& message, + google::protobuf::Message* result) { + CelValue cel_value = + UnwrapMessageToValue(&message, &ProtobufValueFactoryImpl, arena()); + if (result == nullptr) { + EXPECT_TRUE(cel_value.IsNull()); + return; + } + EXPECT_TRUE(cel_value.IsMessage()); + EXPECT_THAT(cel_value.MessageOrDie(), testutil::EqualsProto(*result)); + } + + std::unique_ptr ReflectedCopy( + const google::protobuf::Message& message) { + std::unique_ptr dynamic_value( + factory_.GetPrototype(message.GetDescriptor())->New()); + dynamic_value->CopyFrom(message); + return dynamic_value; + } + + Arena* arena() { return &arena_; } + + private: + Arena arena_; + google::protobuf::DynamicMessageFactory factory_; +}; + +TEST_F(CelProtoWrapperTest, TestType) { + Duration msg_duration; + msg_duration.set_seconds(2); + msg_duration.set_nanos(3); + + CelValue value_duration2 = + UnwrapMessageToValue(&msg_duration, &ProtobufValueFactoryImpl, arena()); + EXPECT_THAT(value_duration2.type(), Eq(CelValue::Type::kDuration)); + + Timestamp msg_timestamp; + msg_timestamp.set_seconds(2); + msg_timestamp.set_nanos(3); + + CelValue value_timestamp2 = + UnwrapMessageToValue(&msg_timestamp, &ProtobufValueFactoryImpl, arena()); + EXPECT_THAT(value_timestamp2.type(), Eq(CelValue::Type::kTimestamp)); +} + +// This test verifies CelValue support of Duration type. +TEST_F(CelProtoWrapperTest, TestDuration) { + Duration msg_duration; + msg_duration.set_seconds(2); + msg_duration.set_nanos(3); + CelValue value = + UnwrapMessageToValue(&msg_duration, &ProtobufValueFactoryImpl, arena()); + EXPECT_THAT(value.type(), Eq(CelValue::Type::kDuration)); + + Duration out; + auto status = expr::internal::EncodeDuration(value.DurationOrDie(), &out); + EXPECT_TRUE(status.ok()); + EXPECT_THAT(out, testutil::EqualsProto(msg_duration)); +} + +// This test verifies CelValue support of Timestamp type. +TEST_F(CelProtoWrapperTest, TestTimestamp) { + Timestamp msg_timestamp; + msg_timestamp.set_seconds(2); + msg_timestamp.set_nanos(3); + + CelValue value = + UnwrapMessageToValue(&msg_timestamp, &ProtobufValueFactoryImpl, arena()); + + EXPECT_TRUE(value.IsTimestamp()); + Timestamp out; + auto status = expr::internal::EncodeTime(value.TimestampOrDie(), &out); + EXPECT_TRUE(status.ok()); + EXPECT_THAT(out, testutil::EqualsProto(msg_timestamp)); +} + +// Dynamic Values test +// +TEST_F(CelProtoWrapperTest, UnwrapMessageToValueNull) { + Value json; + json.set_null_value(google::protobuf::NullValue::NULL_VALUE); + ExpectUnwrappedMessage(json, nullptr); +} + +// Test support for unwrapping a google::protobuf::Value to a CEL value. +TEST_F(CelProtoWrapperTest, UnwrapDynamicValueNull) { + Value value_msg; + value_msg.set_null_value(protobuf::NULL_VALUE); + + CelValue value = UnwrapMessageToValue(ReflectedCopy(value_msg).get(), + &ProtobufValueFactoryImpl, arena()); + EXPECT_TRUE(value.IsNull()); +} + +TEST_F(CelProtoWrapperTest, UnwrapMessageToValueBool) { + bool value = true; + + Value json; + json.set_bool_value(true); + ExpectUnwrappedPrimitive(json, value); +} + +TEST_F(CelProtoWrapperTest, UnwrapMessageToValueNumber) { + double value = 1.0; + + Value json; + json.set_number_value(value); + ExpectUnwrappedPrimitive(json, value); +} + +TEST_F(CelProtoWrapperTest, UnwrapMessageToValueString) { + const std::string test = "test"; + auto value = CelValue::StringHolder(&test); + + Value json; + json.set_string_value(test); + ExpectUnwrappedPrimitive(json, value); +} + +TEST_F(CelProtoWrapperTest, UnwrapMessageToValueStruct) { + const std::vector kFields = {"field1", "field2", "field3"}; + Struct value_struct; + + auto& value1 = (*value_struct.mutable_fields())[kFields[0]]; + value1.set_bool_value(true); + + auto& value2 = (*value_struct.mutable_fields())[kFields[1]]; + value2.set_number_value(1.0); + + auto& value3 = (*value_struct.mutable_fields())[kFields[2]]; + value3.set_string_value("test"); + + CelValue value = + UnwrapMessageToValue(&value_struct, &ProtobufValueFactoryImpl, arena()); + ASSERT_TRUE(value.IsMap()); + + const CelMap* cel_map = value.MapOrDie(); + + CelValue field1 = CelValue::CreateString(&kFields[0]); + auto field1_presence = cel_map->Has(field1); + ASSERT_OK(field1_presence); + EXPECT_TRUE(*field1_presence); + auto lookup1 = (*cel_map)[field1]; + ASSERT_TRUE(lookup1.has_value()); + ASSERT_TRUE(lookup1->IsBool()); + EXPECT_EQ(lookup1->BoolOrDie(), true); + + CelValue field2 = CelValue::CreateString(&kFields[1]); + auto field2_presence = cel_map->Has(field2); + ASSERT_OK(field2_presence); + EXPECT_TRUE(*field2_presence); + auto lookup2 = (*cel_map)[field2]; + ASSERT_TRUE(lookup2.has_value()); + ASSERT_TRUE(lookup2->IsDouble()); + EXPECT_DOUBLE_EQ(lookup2->DoubleOrDie(), 1.0); + + CelValue field3 = CelValue::CreateString(&kFields[2]); + auto field3_presence = cel_map->Has(field3); + ASSERT_OK(field3_presence); + EXPECT_TRUE(*field3_presence); + auto lookup3 = (*cel_map)[field3]; + ASSERT_TRUE(lookup3.has_value()); + ASSERT_TRUE(lookup3->IsString()); + EXPECT_EQ(lookup3->StringOrDie().value(), "test"); + + std::string missing = "missing_field"; + CelValue missing_field = CelValue::CreateString(&missing); + auto missing_field_presence = cel_map->Has(missing_field); + ASSERT_OK(missing_field_presence); + EXPECT_FALSE(*missing_field_presence); + + const CelList* key_list = cel_map->ListKeys(); + ASSERT_EQ(key_list->size(), kFields.size()); + + std::vector result_keys; + for (int i = 0; i < key_list->size(); i++) { + CelValue key = (*key_list)[i]; + ASSERT_TRUE(key.IsString()); + result_keys.push_back(std::string(key.StringOrDie().value())); + } + + EXPECT_THAT(result_keys, UnorderedPointwise(Eq(), kFields)); +} + +// Test support for google::protobuf::Struct when it is created as dynamic +// message +TEST_F(CelProtoWrapperTest, UnwrapDynamicStruct) { + Struct struct_msg; + const std::string kFieldInt = "field_int"; + const std::string kFieldBool = "field_bool"; + (*struct_msg.mutable_fields())[kFieldInt].set_number_value(1.); + (*struct_msg.mutable_fields())[kFieldBool].set_bool_value(true); + CelValue value = UnwrapMessageToValue(ReflectedCopy(struct_msg).get(), + &ProtobufValueFactoryImpl, arena()); + EXPECT_TRUE(value.IsMap()); + const CelMap* cel_map = value.MapOrDie(); + ASSERT_TRUE(cel_map != nullptr); + + { + auto lookup = (*cel_map)[CelValue::CreateString(&kFieldInt)]; + ASSERT_TRUE(lookup.has_value()); + auto v = lookup.value(); + ASSERT_TRUE(v.IsDouble()); + EXPECT_THAT(v.DoubleOrDie(), testing::DoubleEq(1.)); + } + { + auto lookup = (*cel_map)[CelValue::CreateString(&kFieldBool)]; + ASSERT_TRUE(lookup.has_value()); + auto v = lookup.value(); + ASSERT_TRUE(v.IsBool()); + EXPECT_EQ(v.BoolOrDie(), true); + } + { + auto presence = cel_map->Has(CelValue::CreateBool(true)); + ASSERT_FALSE(presence.ok()); + EXPECT_EQ(presence.status().code(), absl::StatusCode::kInvalidArgument); + auto lookup = (*cel_map)[CelValue::CreateBool(true)]; + ASSERT_TRUE(lookup.has_value()); + auto v = lookup.value(); + ASSERT_TRUE(v.IsError()); + } +} + +TEST_F(CelProtoWrapperTest, UnwrapDynamicValueStruct) { + const std::string kField1 = "field1"; + const std::string kField2 = "field2"; + Value value_msg; + (*value_msg.mutable_struct_value()->mutable_fields())[kField1] + .set_number_value(1); + (*value_msg.mutable_struct_value()->mutable_fields())[kField2] + .set_number_value(2); + + CelValue value = UnwrapMessageToValue(ReflectedCopy(value_msg).get(), + &ProtobufValueFactoryImpl, arena()); + EXPECT_TRUE(value.IsMap()); + EXPECT_TRUE( + (*value.MapOrDie())[CelValue::CreateString(&kField1)].has_value()); + EXPECT_TRUE( + (*value.MapOrDie())[CelValue::CreateString(&kField2)].has_value()); +} + +TEST_F(CelProtoWrapperTest, UnwrapMessageToValueList) { + const std::vector kFields = {"field1", "field2", "field3"}; + + ListValue list_value; + + list_value.add_values()->set_bool_value(true); + list_value.add_values()->set_number_value(1.0); + list_value.add_values()->set_string_value("test"); + + CelValue value = + UnwrapMessageToValue(&list_value, &ProtobufValueFactoryImpl, arena()); + ASSERT_TRUE(value.IsList()); + + const CelList* cel_list = value.ListOrDie(); + + ASSERT_EQ(cel_list->size(), 3); + + CelValue value1 = (*cel_list)[0]; + ASSERT_TRUE(value1.IsBool()); + EXPECT_EQ(value1.BoolOrDie(), true); + + auto value2 = (*cel_list)[1]; + ASSERT_TRUE(value2.IsDouble()); + EXPECT_DOUBLE_EQ(value2.DoubleOrDie(), 1.0); + + auto value3 = (*cel_list)[2]; + ASSERT_TRUE(value3.IsString()); + EXPECT_EQ(value3.StringOrDie().value(), "test"); +} + +TEST_F(CelProtoWrapperTest, UnwrapDynamicValueListValue) { + Value value_msg; + value_msg.mutable_list_value()->add_values()->set_number_value(1.); + value_msg.mutable_list_value()->add_values()->set_number_value(2.); + + CelValue value = UnwrapMessageToValue(ReflectedCopy(value_msg).get(), + &ProtobufValueFactoryImpl, arena()); + EXPECT_TRUE(value.IsList()); + EXPECT_THAT((*value.ListOrDie())[0].DoubleOrDie(), testing::DoubleEq(1)); + EXPECT_THAT((*value.ListOrDie())[1].DoubleOrDie(), testing::DoubleEq(2)); +} + +// Test support of google.protobuf.Any in CelValue. +TEST_F(CelProtoWrapperTest, UnwrapAnyValue) { + TestMessage test_message; + test_message.set_string_value("test"); + + Any any; + any.PackFrom(test_message); + ExpectUnwrappedMessage(any, &test_message); +} + +TEST_F(CelProtoWrapperTest, UnwrapInvalidAny) { + Any any; + CelValue value = + UnwrapMessageToValue(&any, &ProtobufValueFactoryImpl, arena()); + ASSERT_TRUE(value.IsError()); + + any.set_type_url("/"); + ASSERT_TRUE( + UnwrapMessageToValue(&any, &ProtobufValueFactoryImpl, arena()).IsError()); + + any.set_type_url("/invalid.proto.name"); + ASSERT_TRUE( + UnwrapMessageToValue(&any, &ProtobufValueFactoryImpl, arena()).IsError()); +} + +// Test support of google.protobuf.Value wrappers in CelValue. +TEST_F(CelProtoWrapperTest, UnwrapBoolWrapper) { + bool value = true; + + BoolValue wrapper; + wrapper.set_value(value); + ExpectUnwrappedPrimitive(wrapper, value); +} + +TEST_F(CelProtoWrapperTest, UnwrapInt32Wrapper) { + int64_t value = 12; + + Int32Value wrapper; + wrapper.set_value(value); + ExpectUnwrappedPrimitive(wrapper, value); +} + +TEST_F(CelProtoWrapperTest, UnwrapUInt32Wrapper) { + uint64_t value = 12; + + UInt32Value wrapper; + wrapper.set_value(value); + ExpectUnwrappedPrimitive(wrapper, value); +} + +TEST_F(CelProtoWrapperTest, UnwrapInt64Wrapper) { + int64_t value = 12; + + Int64Value wrapper; + wrapper.set_value(value); + ExpectUnwrappedPrimitive(wrapper, value); +} + +TEST_F(CelProtoWrapperTest, UnwrapUInt64Wrapper) { + uint64_t value = 12; + + UInt64Value wrapper; + wrapper.set_value(value); + ExpectUnwrappedPrimitive(wrapper, value); +} + +TEST_F(CelProtoWrapperTest, UnwrapFloatWrapper) { + double value = 42.5; + + FloatValue wrapper; + wrapper.set_value(value); + ExpectUnwrappedPrimitive(wrapper, value); +} + +TEST_F(CelProtoWrapperTest, UnwrapDoubleWrapper) { + double value = 42.5; + + DoubleValue wrapper; + wrapper.set_value(value); + ExpectUnwrappedPrimitive(wrapper, value); +} + +TEST_F(CelProtoWrapperTest, UnwrapStringWrapper) { + std::string text = "42"; + auto value = CelValue::StringHolder(&text); + + StringValue wrapper; + wrapper.set_value(text); + ExpectUnwrappedPrimitive(wrapper, value); +} + +TEST_F(CelProtoWrapperTest, UnwrapBytesWrapper) { + std::string text = "42"; + auto value = CelValue::BytesHolder(&text); + + BytesValue wrapper; + wrapper.set_value("42"); + ExpectUnwrappedPrimitive(wrapper, value); +} + +TEST_F(CelProtoWrapperTest, WrapNull) { + auto cel_value = CelValue::CreateNull(); + + Value json; + json.set_null_value(protobuf::NULL_VALUE); + ExpectWrappedMessage(cel_value, json); + + Any any; + any.PackFrom(json); + ExpectWrappedMessage(cel_value, any); +} + +TEST_F(CelProtoWrapperTest, WrapBool) { + auto cel_value = CelValue::CreateBool(true); + + Value json; + json.set_bool_value(true); + ExpectWrappedMessage(cel_value, json); + + BoolValue wrapper; + wrapper.set_value(true); + ExpectWrappedMessage(cel_value, wrapper); + + Any any; + any.PackFrom(wrapper); + ExpectWrappedMessage(cel_value, any); +} + +TEST_F(CelProtoWrapperTest, WrapBytes) { + std::string str = "hello world"; + auto cel_value = CelValue::CreateBytes(CelValue::BytesHolder(&str)); + + BytesValue wrapper; + wrapper.set_value(str); + ExpectWrappedMessage(cel_value, wrapper); + + Any any; + any.PackFrom(wrapper); + ExpectWrappedMessage(cel_value, any); +} + +TEST_F(CelProtoWrapperTest, WrapBytesToValue) { + std::string str = "hello world"; + auto cel_value = CelValue::CreateBytes(CelValue::BytesHolder(&str)); + + Value json; + json.set_string_value("aGVsbG8gd29ybGQ="); + ExpectWrappedMessage(cel_value, json); +} + +TEST_F(CelProtoWrapperTest, WrapDuration) { + auto cel_value = CelValue::CreateDuration(absl::Seconds(300)); + + Duration d; + d.set_seconds(300); + ExpectWrappedMessage(cel_value, d); + + Any any; + any.PackFrom(d); + ExpectWrappedMessage(cel_value, any); +} + +TEST_F(CelProtoWrapperTest, WrapDurationToValue) { + auto cel_value = CelValue::CreateDuration(absl::Seconds(300)); + + Value json; + json.set_string_value("300s"); + ExpectWrappedMessage(cel_value, json); +} + +TEST_F(CelProtoWrapperTest, WrapDouble) { + double num = 1.5; + auto cel_value = CelValue::CreateDouble(num); + + Value json; + json.set_number_value(num); + ExpectWrappedMessage(cel_value, json); + + DoubleValue wrapper; + wrapper.set_value(num); + ExpectWrappedMessage(cel_value, wrapper); + + Any any; + any.PackFrom(wrapper); + ExpectWrappedMessage(cel_value, any); +} + +TEST_F(CelProtoWrapperTest, WrapDoubleToFloatValue) { + double num = 1.5; + auto cel_value = CelValue::CreateDouble(num); + + FloatValue wrapper; + wrapper.set_value(num); + ExpectWrappedMessage(cel_value, wrapper); + + // Imprecise double -> float representation results in truncation. + double small_num = -9.9e-100; + wrapper.set_value(small_num); + cel_value = CelValue::CreateDouble(small_num); + ExpectWrappedMessage(cel_value, wrapper); +} + +TEST_F(CelProtoWrapperTest, WrapDoubleOverflow) { + double lowest_double = std::numeric_limits::lowest(); + auto cel_value = CelValue::CreateDouble(lowest_double); + + // Double exceeds float precision, overflow to -infinity. + FloatValue wrapper; + wrapper.set_value(-std::numeric_limits::infinity()); + ExpectWrappedMessage(cel_value, wrapper); + + double max_double = std::numeric_limits::max(); + cel_value = CelValue::CreateDouble(max_double); + + wrapper.set_value(std::numeric_limits::infinity()); + ExpectWrappedMessage(cel_value, wrapper); +} + +TEST_F(CelProtoWrapperTest, WrapInt64) { + int32_t num = std::numeric_limits::lowest(); + auto cel_value = CelValue::CreateInt64(num); + + Value json; + json.set_number_value(static_cast(num)); + ExpectWrappedMessage(cel_value, json); + + Int64Value wrapper; + wrapper.set_value(num); + ExpectWrappedMessage(cel_value, wrapper); + + Any any; + any.PackFrom(wrapper); + ExpectWrappedMessage(cel_value, any); +} + +TEST_F(CelProtoWrapperTest, WrapInt64ToInt32Value) { + int32_t num = std::numeric_limits::lowest(); + auto cel_value = CelValue::CreateInt64(num); + + Int32Value wrapper; + wrapper.set_value(num); + ExpectWrappedMessage(cel_value, wrapper); +} + +TEST_F(CelProtoWrapperTest, WrapFailureInt64ToInt32Value) { + int64_t num = std::numeric_limits::lowest(); + auto cel_value = CelValue::CreateInt64(num); + + Int32Value wrapper; + ExpectNotWrapped(cel_value, wrapper); +} + +TEST_F(CelProtoWrapperTest, WrapInt64ToValue) { + int64_t max = std::numeric_limits::max(); + auto cel_value = CelValue::CreateInt64(max); + + Value json; + json.set_string_value(absl::StrCat(max)); + ExpectWrappedMessage(cel_value, json); + + int64_t min = std::numeric_limits::min(); + cel_value = CelValue::CreateInt64(min); + + json.set_string_value(absl::StrCat(min)); + ExpectWrappedMessage(cel_value, json); +} + +TEST_F(CelProtoWrapperTest, WrapUint64) { + uint32_t num = std::numeric_limits::max(); + auto cel_value = CelValue::CreateUint64(num); + + Value json; + json.set_number_value(static_cast(num)); + ExpectWrappedMessage(cel_value, json); + + UInt64Value wrapper; + wrapper.set_value(num); + ExpectWrappedMessage(cel_value, wrapper); + + Any any; + any.PackFrom(wrapper); + ExpectWrappedMessage(cel_value, any); +} + +TEST_F(CelProtoWrapperTest, WrapUint64ToUint32Value) { + uint32_t num = std::numeric_limits::max(); + auto cel_value = CelValue::CreateUint64(num); + + UInt32Value wrapper; + wrapper.set_value(num); + ExpectWrappedMessage(cel_value, wrapper); +} + +TEST_F(CelProtoWrapperTest, WrapUint64ToValue) { + uint64_t num = std::numeric_limits::max(); + auto cel_value = CelValue::CreateUint64(num); + + Value json; + json.set_string_value(absl::StrCat(num)); + ExpectWrappedMessage(cel_value, json); +} + +TEST_F(CelProtoWrapperTest, WrapFailureUint64ToUint32Value) { + uint64_t num = std::numeric_limits::max(); + auto cel_value = CelValue::CreateUint64(num); + + UInt32Value wrapper; + ExpectNotWrapped(cel_value, wrapper); +} + +TEST_F(CelProtoWrapperTest, WrapString) { + std::string str = "test"; + auto cel_value = CelValue::CreateString(CelValue::StringHolder(&str)); + + Value json; + json.set_string_value(str); + ExpectWrappedMessage(cel_value, json); + + StringValue wrapper; + wrapper.set_value(str); + ExpectWrappedMessage(cel_value, wrapper); + + Any any; + any.PackFrom(wrapper); + ExpectWrappedMessage(cel_value, any); +} + +TEST_F(CelProtoWrapperTest, WrapTimestamp) { + absl::Time ts = absl::FromUnixSeconds(1615852799); + auto cel_value = CelValue::CreateTimestamp(ts); + + Timestamp t; + t.set_seconds(1615852799); + ExpectWrappedMessage(cel_value, t); + + Any any; + any.PackFrom(t); + ExpectWrappedMessage(cel_value, any); +} + +TEST_F(CelProtoWrapperTest, WrapTimestampToValue) { + absl::Time ts = absl::FromUnixSeconds(1615852799); + auto cel_value = CelValue::CreateTimestamp(ts); + + Value json; + json.set_string_value("2021-03-15T23:59:59Z"); + ExpectWrappedMessage(cel_value, json); +} + +TEST_F(CelProtoWrapperTest, WrapList) { + std::vector list_elems = { + CelValue::CreateDouble(1.5), + CelValue::CreateInt64(-2L), + }; + ContainerBackedListImpl list(std::move(list_elems)); + auto cel_value = CelValue::CreateList(&list); + + Value json; + json.mutable_list_value()->add_values()->set_number_value(1.5); + json.mutable_list_value()->add_values()->set_number_value(-2.); + ExpectWrappedMessage(cel_value, json); + ExpectWrappedMessage(cel_value, json.list_value()); + + Any any; + any.PackFrom(json.list_value()); + ExpectWrappedMessage(cel_value, any); +} + +TEST_F(CelProtoWrapperTest, WrapFailureListValueBadJSON) { + TestMessage message; + std::vector list_elems = { + CelValue::CreateDouble(1.5), + UnwrapMessageToValue(&message, &ProtobufValueFactoryImpl, arena()), + }; + ContainerBackedListImpl list(std::move(list_elems)); + auto cel_value = CelValue::CreateList(&list); + + Value json; + ExpectNotWrapped(cel_value, json); +} + +TEST_F(CelProtoWrapperTest, WrapStruct) { + const std::string kField1 = "field1"; + std::vector> args = { + {CelValue::CreateString(CelValue::StringHolder(&kField1)), + CelValue::CreateBool(true)}}; + auto cel_map = + CreateContainerBackedMap( + absl::Span>(args.data(), args.size())) + .value(); + auto cel_value = CelValue::CreateMap(cel_map.get()); + + Value json; + (*json.mutable_struct_value()->mutable_fields())[kField1].set_bool_value( + true); + ExpectWrappedMessage(cel_value, json); + ExpectWrappedMessage(cel_value, json.struct_value()); + + Any any; + any.PackFrom(json.struct_value()); + ExpectWrappedMessage(cel_value, any); +} + +TEST_F(CelProtoWrapperTest, WrapFailureStructBadKeyType) { + std::vector> args = { + {CelValue::CreateInt64(1L), CelValue::CreateBool(true)}}; + auto cel_map = + CreateContainerBackedMap( + absl::Span>(args.data(), args.size())) + .value(); + auto cel_value = CelValue::CreateMap(cel_map.get()); + + Value json; + ExpectNotWrapped(cel_value, json); +} + +TEST_F(CelProtoWrapperTest, WrapFailureStructBadValueType) { + const std::string kField1 = "field1"; + TestMessage bad_value; + std::vector> args = { + {CelValue::CreateString(CelValue::StringHolder(&kField1)), + UnwrapMessageToValue(&bad_value, &ProtobufValueFactoryImpl, arena())}}; + auto cel_map = + CreateContainerBackedMap( + absl::Span>(args.data(), args.size())) + .value(); + auto cel_value = CelValue::CreateMap(cel_map.get()); + Value json; + ExpectNotWrapped(cel_value, json); +} + +TEST_F(CelProtoWrapperTest, WrapFailureWrongType) { + auto cel_value = CelValue::CreateNull(); + std::vector wrong_types = { + &BoolValue::default_instance(), &BytesValue::default_instance(), + &DoubleValue::default_instance(), &Duration::default_instance(), + &FloatValue::default_instance(), &Int32Value::default_instance(), + &Int64Value::default_instance(), &ListValue::default_instance(), + &StringValue::default_instance(), &Struct::default_instance(), + &Timestamp::default_instance(), &UInt32Value::default_instance(), + &UInt64Value::default_instance(), + }; + for (const auto* wrong_type : wrong_types) { + ExpectNotWrapped(cel_value, *wrong_type); + } +} + +TEST_F(CelProtoWrapperTest, WrapFailureErrorToAny) { + auto cel_value = CreateNoSuchFieldError(arena(), "error_field"); + ExpectNotWrapped(cel_value, Any::default_instance()); +} + +TEST_F(CelProtoWrapperTest, DebugString) { + google::protobuf::Empty e; + EXPECT_EQ(UnwrapMessageToValue(&e, &ProtobufValueFactoryImpl, arena()) + .DebugString(), + "Message: "); + + ListValue list_value; + list_value.add_values()->set_bool_value(true); + list_value.add_values()->set_number_value(1.0); + list_value.add_values()->set_string_value("test"); + CelValue value = + UnwrapMessageToValue(&list_value, &ProtobufValueFactoryImpl, arena()); + EXPECT_EQ(value.DebugString(), + "CelList: [bool: 1, double: 1.000000, string: test]"); + + Struct value_struct; + auto& value1 = (*value_struct.mutable_fields())["a"]; + value1.set_bool_value(true); + auto& value2 = (*value_struct.mutable_fields())["b"]; + value2.set_number_value(1.0); + auto& value3 = (*value_struct.mutable_fields())["c"]; + value3.set_string_value("test"); + + value = + UnwrapMessageToValue(&value_struct, &ProtobufValueFactoryImpl, arena()); + EXPECT_THAT( + value.DebugString(), + testing::AllOf(testing::StartsWith("CelMap: {"), + testing::HasSubstr(": "), + testing::HasSubstr(": : "))); +} + +} // namespace + +} // namespace google::api::expr::runtime::internal diff --git a/eval/public/structs/protobuf_value_factory.h b/eval/public/structs/protobuf_value_factory.h new file mode 100644 index 000000000..7d4223411 --- /dev/null +++ b/eval/public/structs/protobuf_value_factory.h @@ -0,0 +1,36 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_PROTOBUF_VALUE_FACTORY_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_PROTOBUF_VALUE_FACTORY_H_ + +#include + +#include "google/protobuf/message.h" +#include "eval/public/cel_value.h" + +namespace google::api::expr::runtime::internal { + +// Definiton for factory producing a properly initialized message-typed +// CelValue. +// +// google::protobuf::Message is assumed adapted as possible, so this function just +// associates it with appropriate type information. +// +// Used to break cyclic dependency between field access and message wrapping -- +// not intended for general use. +using ProtobufValueFactory = std::function; +} // namespace google::api::expr::runtime::internal + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_PROTOBUF_VALUE_FACTORY_H_ From 5c4082a806bd9b11f3a9f7010049aa1167950d56 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Tue, 12 Apr 2022 21:24:06 +0000 Subject: [PATCH 063/111] Update cel message value factory to use internal wrap / unwrap primitives. PiperOrigin-RevId: 441290747 --- eval/public/structs/BUILD | 7 +- eval/public/structs/cel_proto_wrapper.cc | 868 +---------------------- 2 files changed, 15 insertions(+), 860 deletions(-) diff --git a/eval/public/structs/BUILD b/eval/public/structs/BUILD index 99d686e9d..72078deb9 100644 --- a/eval/public/structs/BUILD +++ b/eval/public/structs/BUILD @@ -25,14 +25,9 @@ cc_library( "cel_proto_wrapper.h", ], deps = [ + ":cel_proto_wrap_util", "//eval/public:cel_value", - "//eval/testutil:test_message_cc_proto", - "//internal:overflow", "//internal:proto_util", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf", ], diff --git a/eval/public/structs/cel_proto_wrapper.cc b/eval/public/structs/cel_proto_wrapper.cc index 12b24e9c6..8ff065efc 100644 --- a/eval/public/structs/cel_proto_wrapper.cc +++ b/eval/public/structs/cel_proto_wrapper.cc @@ -14,881 +14,41 @@ #include "eval/public/structs/cel_proto_wrapper.h" -#include - -#include -#include -#include -#include -#include - -#include "google/protobuf/any.pb.h" -#include "google/protobuf/duration.pb.h" -#include "google/protobuf/struct.pb.h" -#include "google/protobuf/timestamp.pb.h" -#include "google/protobuf/wrappers.pb.h" #include "google/protobuf/message.h" -#include "absl/container/flat_hash_map.h" -#include "absl/status/status.h" -#include "absl/strings/escaping.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "absl/strings/substitute.h" -#include "absl/synchronization/mutex.h" #include "absl/types/optional.h" #include "eval/public/cel_value.h" -#include "eval/testutil/test_message.pb.h" -#include "internal/overflow.h" -#include "internal/proto_util.h" +#include "eval/public/structs/cel_proto_wrap_util.h" namespace google::api::expr::runtime { namespace { -using google::protobuf::Arena; -using google::protobuf::Descriptor; -using google::protobuf::DescriptorPool; -using google::protobuf::Message; -using google::protobuf::MessageFactory; - -using google::api::expr::internal::EncodeTime; -using google::protobuf::Any; -using google::protobuf::BoolValue; -using google::protobuf::BytesValue; -using google::protobuf::DoubleValue; -using google::protobuf::Duration; -using google::protobuf::FloatValue; -using google::protobuf::Int32Value; -using google::protobuf::Int64Value; -using google::protobuf::ListValue; -using google::protobuf::StringValue; -using google::protobuf::Struct; -using google::protobuf::Timestamp; -using google::protobuf::UInt32Value; -using google::protobuf::UInt64Value; -using google::protobuf::Value; - -// kMaxIntJSON is defined as the Number.MAX_SAFE_INTEGER value per EcmaScript 6. -constexpr int64_t kMaxIntJSON = (1ll << 53) - 1; - -// kMinIntJSON is defined as the Number.MIN_SAFE_INTEGER value per EcmaScript 6. -constexpr int64_t kMinIntJSON = -kMaxIntJSON; - -// Forward declaration for google.protobuf.Value -CelValue ValueFromMessage(const Value* value, Arena* arena); -absl::optional MessageFromValue(const CelValue& value, - Value* json); - -// IsJSONSafe indicates whether the int is safely representable as a floating -// point value in JSON. -static bool IsJSONSafe(int64_t i) { - return i >= kMinIntJSON && i <= kMaxIntJSON; -} - -// IsJSONSafe indicates whether the uint is safely representable as a floating -// point value in JSON. -static bool IsJSONSafe(uint64_t i) { - return i <= static_cast(kMaxIntJSON); -} - -// Map implementation wrapping google.protobuf.ListValue -class DynamicList : public CelList { - public: - DynamicList(const ListValue* values, Arena* arena) - : arena_(arena), values_(values) {} - - CelValue operator[](int index) const override { - return ValueFromMessage(&values_->values(index), arena_); - } - - // List size - int size() const override { return values_->values_size(); } - - private: - Arena* arena_; - const ListValue* values_; -}; - -// Map implementation wrapping google.protobuf.Struct. -class DynamicMap : public CelMap { - public: - DynamicMap(const Struct* values, Arena* arena) - : arena_(arena), values_(values), key_list_(values) {} - - absl::StatusOr Has(const CelValue& key) const override { - CelValue::StringHolder str_key; - if (!key.GetValue(&str_key)) { - // Not a string key. - return absl::InvalidArgumentError(absl::StrCat( - "Invalid map key type: '", CelValue::TypeName(key.type()), "'")); - } - - return values_->fields().contains(std::string(str_key.value())); - } - - absl::optional operator[](CelValue key) const override { - CelValue::StringHolder str_key; - if (!key.GetValue(&str_key)) { - // Not a string key. - return CreateErrorValue( - arena_, - absl::InvalidArgumentError(absl::StrCat( - "Invalid map key type: '", CelValue::TypeName(key.type()), "'"))); - } - - auto it = values_->fields().find(std::string(str_key.value())); - if (it == values_->fields().end()) { - return absl::nullopt; - } - - return ValueFromMessage(&it->second, arena_); - } - - int size() const override { return values_->fields_size(); } - - const CelList* ListKeys() const override { return &key_list_; } - - private: - // List of keys in Struct.fields map. - // It utilizes lazy initialization, to avoid performance penalties. - class DynamicMapKeyList : public CelList { - public: - explicit DynamicMapKeyList(const Struct* values) - : values_(values), keys_(), initialized_(false) {} - - // Index access - CelValue operator[](int index) const override { - CheckInit(); - return keys_[index]; - } - - // List size - int size() const override { - CheckInit(); - return values_->fields_size(); - } - - private: - void CheckInit() const { - absl::MutexLock lock(&mutex_); - if (!initialized_) { - for (const auto& it : values_->fields()) { - keys_.push_back(CelValue::CreateString(&it.first)); - } - initialized_ = true; - } - } - - const Struct* values_; - mutable absl::Mutex mutex_; - mutable std::vector keys_; - mutable bool initialized_; - }; - - Arena* arena_; - const Struct* values_; - const DynamicMapKeyList key_list_; -}; - -// ValueFromMessage(....) function family. -// Functions of this family create CelValue object from specific subtypes of -// protobuf message. -CelValue ValueFromMessage(const Duration* duration, Arena*) { - return CelProtoWrapper::CreateDuration(duration); -} - -CelValue ValueFromMessage(const Timestamp* timestamp, Arena*) { - return CelProtoWrapper::CreateTimestamp(timestamp); -} - -CelValue ValueFromMessage(const ListValue* list_values, Arena* arena) { - return CelValue::CreateList( - Arena::Create(arena, list_values, arena)); -} - -CelValue ValueFromMessage(const Struct* struct_value, Arena* arena) { - return CelValue::CreateMap( - Arena::Create(arena, struct_value, arena)); -} - -CelValue ValueFromMessage(const Any* any_value, Arena* arena, - const DescriptorPool* descriptor_pool, - MessageFactory* message_factory) { - auto type_url = any_value->type_url(); - auto pos = type_url.find_last_of('/'); - if (pos == absl::string_view::npos) { - // TODO(issues/25) What error code? - // Malformed type_url - return CreateErrorValue(arena, "Malformed type_url string"); - } - - std::string full_name = std::string(type_url.substr(pos + 1)); - const Descriptor* nested_descriptor = - descriptor_pool->FindMessageTypeByName(full_name); - - if (nested_descriptor == nullptr) { - // Descriptor not found for the type - // TODO(issues/25) What error code? - return CreateErrorValue(arena, "Descriptor not found"); - } - - const Message* prototype = message_factory->GetPrototype(nested_descriptor); - if (prototype == nullptr) { - // Failed to obtain prototype for the descriptor - // TODO(issues/25) What error code? - return CreateErrorValue(arena, "Prototype not found"); - } - - Message* nested_message = prototype->New(arena); - if (!any_value->UnpackTo(nested_message)) { - // Failed to unpack. - // TODO(issues/25) What error code? - return CreateErrorValue(arena, "Failed to unpack Any into message"); - } - - return CelProtoWrapper::CreateMessage(nested_message, arena); -} - -CelValue ValueFromMessage(const Any* any_value, Arena* arena) { - return ValueFromMessage(any_value, arena, DescriptorPool::generated_pool(), - MessageFactory::generated_factory()); -} - -CelValue ValueFromMessage(const BoolValue* wrapper, Arena*) { - return CelValue::CreateBool(wrapper->value()); -} - -CelValue ValueFromMessage(const Int32Value* wrapper, Arena*) { - return CelValue::CreateInt64(wrapper->value()); -} - -CelValue ValueFromMessage(const UInt32Value* wrapper, Arena*) { - return CelValue::CreateUint64(wrapper->value()); -} - -CelValue ValueFromMessage(const Int64Value* wrapper, Arena*) { - return CelValue::CreateInt64(wrapper->value()); -} - -CelValue ValueFromMessage(const UInt64Value* wrapper, Arena*) { - return CelValue::CreateUint64(wrapper->value()); -} - -CelValue ValueFromMessage(const FloatValue* wrapper, Arena*) { - return CelValue::CreateDouble(wrapper->value()); -} - -CelValue ValueFromMessage(const DoubleValue* wrapper, Arena*) { - return CelValue::CreateDouble(wrapper->value()); -} - -CelValue ValueFromMessage(const StringValue* wrapper, Arena*) { - return CelValue::CreateString(&wrapper->value()); -} - -CelValue ValueFromMessage(const BytesValue* wrapper, Arena* arena) { - // BytesValue stores value as Cord - return CelValue::CreateBytes( - Arena::Create(arena, std::string(wrapper->value()))); -} - -CelValue ValueFromMessage(const Value* value, Arena* arena) { - switch (value->kind_case()) { - case Value::KindCase::kNullValue: - return CelValue::CreateNull(); - case Value::KindCase::kNumberValue: - return CelValue::CreateDouble(value->number_value()); - case Value::KindCase::kStringValue: - return CelValue::CreateString(&value->string_value()); - case Value::KindCase::kBoolValue: - return CelValue::CreateBool(value->bool_value()); - case Value::KindCase::kStructValue: - return CelProtoWrapper::CreateMessage(&value->struct_value(), arena); - case Value::KindCase::kListValue: - return CelProtoWrapper::CreateMessage(&value->list_value(), arena); - default: - return CelValue::CreateNull(); - } -} - -// Factory class, responsible for creating CelValue object from Message of some -// fixed subtype. -class ValueFromMessageFactory { - public: - virtual ~ValueFromMessageFactory() {} - virtual const google::protobuf::Descriptor* GetDescriptor() const = 0; - virtual absl::optional CreateValue(const google::protobuf::Message* value, - Arena* arena) const = 0; -}; - -// Class makes CelValue from generic protobuf Message. -// It holds a registry of CelValue factories for specific subtypes of Message. -// If message does not match any of types stored in registry, generic -// message-containing CelValue is created. -class ValueFromMessageMaker { - public: - template - static absl::optional CreateWellknownTypeValue( - const google::protobuf::Message* msg, Arena* arena) { - const MessageType* message = - google::protobuf::DynamicCastToGenerated(msg); - if (message == nullptr) { - auto message_copy = Arena::CreateMessage(arena); - if (MessageType::descriptor() == msg->GetDescriptor()) { - message_copy->CopyFrom(*msg); - message = message_copy; - } else { - // message of well-known type but from a descriptor pool other than the - // generated one. - std::string serialized_msg; - if (msg->SerializeToString(&serialized_msg) && - message_copy->ParseFromString(serialized_msg)) { - message = message_copy; - } - } - } - return ValueFromMessage(message, arena); - } - - static absl::optional CreateValue(const google::protobuf::Message* message, - Arena* arena) { - switch (message->GetDescriptor()->well_known_type()) { - case google::protobuf::Descriptor::WELLKNOWNTYPE_DOUBLEVALUE: - return CreateWellknownTypeValue(message, arena); - case google::protobuf::Descriptor::WELLKNOWNTYPE_FLOATVALUE: - return CreateWellknownTypeValue(message, arena); - case google::protobuf::Descriptor::WELLKNOWNTYPE_INT64VALUE: - return CreateWellknownTypeValue(message, arena); - case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT64VALUE: - return CreateWellknownTypeValue(message, arena); - case google::protobuf::Descriptor::WELLKNOWNTYPE_INT32VALUE: - return CreateWellknownTypeValue(message, arena); - case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT32VALUE: - return CreateWellknownTypeValue(message, arena); - case google::protobuf::Descriptor::WELLKNOWNTYPE_STRINGVALUE: - return CreateWellknownTypeValue(message, arena); - case google::protobuf::Descriptor::WELLKNOWNTYPE_BYTESVALUE: - return CreateWellknownTypeValue(message, arena); - case google::protobuf::Descriptor::WELLKNOWNTYPE_BOOLVALUE: - return CreateWellknownTypeValue(message, arena); - case google::protobuf::Descriptor::WELLKNOWNTYPE_ANY: - return CreateWellknownTypeValue(message, arena); - case google::protobuf::Descriptor::WELLKNOWNTYPE_DURATION: - return CreateWellknownTypeValue(message, arena); - case google::protobuf::Descriptor::WELLKNOWNTYPE_TIMESTAMP: - return CreateWellknownTypeValue(message, arena); - case google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE: - return CreateWellknownTypeValue(message, arena); - case google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE: - return CreateWellknownTypeValue(message, arena); - case google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT: - return CreateWellknownTypeValue(message, arena); - // WELLKNOWNTYPE_FIELDMASK has no special CelValue type - default: - return absl::nullopt; - } - } - - // Non-copyable, non-assignable - ValueFromMessageMaker(const ValueFromMessageMaker&) = delete; - ValueFromMessageMaker& operator=(const ValueFromMessageMaker&) = delete; -}; +using ::google::protobuf::Arena; +using ::google::protobuf::Descriptor; +using ::google::protobuf::Message; -absl::optional MessageFromValue(const CelValue& value, - Duration* duration) { - absl::Duration val; - if (!value.GetValue(&val)) { - return absl::nullopt; - } - auto status = google::api::expr::internal::EncodeDuration(val, duration); - if (!status.ok()) { - return absl::nullopt; - } - return duration; -} - -absl::optional MessageFromValue(const CelValue& value, - BoolValue* wrapper) { - bool val; - if (!value.GetValue(&val)) { - return absl::nullopt; - } - wrapper->set_value(val); - return wrapper; -} - -absl::optional MessageFromValue(const CelValue& value, - BytesValue* wrapper) { - CelValue::BytesHolder view_val; - if (!value.GetValue(&view_val)) { - return absl::nullopt; - } - wrapper->set_value(view_val.value().data()); - return wrapper; -} - -absl::optional MessageFromValue(const CelValue& value, - DoubleValue* wrapper) { - double val; - if (!value.GetValue(&val)) { - return absl::nullopt; - } - wrapper->set_value(val); - return wrapper; -} - -absl::optional MessageFromValue(const CelValue& value, - FloatValue* wrapper) { - double val; - if (!value.GetValue(&val)) { - return absl::nullopt; - } - // Abort the conversion if the value is outside the float range. - if (val > std::numeric_limits::max()) { - wrapper->set_value(std::numeric_limits::infinity()); - return wrapper; - } - if (val < std::numeric_limits::lowest()) { - wrapper->set_value(-std::numeric_limits::infinity()); - return wrapper; - } - wrapper->set_value(val); - return wrapper; -} - -absl::optional MessageFromValue(const CelValue& value, - Int32Value* wrapper) { - int64_t val; - if (!value.GetValue(&val)) { - return absl::nullopt; - } - // Abort the conversion if the value is outside the int32_t range. - if (!cel::internal::CheckedInt64ToInt32(val).ok()) { - return absl::nullopt; - } - wrapper->set_value(val); - return wrapper; -} - -absl::optional MessageFromValue(const CelValue& value, - Int64Value* wrapper) { - int64_t val; - if (!value.GetValue(&val)) { - return absl::nullopt; - } - wrapper->set_value(val); - return wrapper; -} - -absl::optional MessageFromValue(const CelValue& value, - StringValue* wrapper) { - CelValue::StringHolder view_val; - if (!value.GetValue(&view_val)) { - return absl::nullopt; - } - wrapper->set_value(view_val.value().data()); - return wrapper; +CelValue WrapMessage(const Message* m) { + return CelValue::CreateMessageWrapper(CelValue::MessageWrapper(m)); } -absl::optional MessageFromValue(const CelValue& value, - Timestamp* timestamp) { - absl::Time val; - if (!value.GetValue(&val)) { - return absl::nullopt; - } - auto status = EncodeTime(val, timestamp); - if (!status.ok()) { - return absl::nullopt; - } - return timestamp; -} - -absl::optional MessageFromValue(const CelValue& value, - UInt32Value* wrapper) { - uint64_t val; - if (!value.GetValue(&val)) { - return absl::nullopt; - } - // Abort the conversion if the value is outside the uint32_t range. - if (!cel::internal::CheckedUint64ToUint32(val).ok()) { - return absl::nullopt; - } - wrapper->set_value(val); - return wrapper; -} - -absl::optional MessageFromValue(const CelValue& value, - UInt64Value* wrapper) { - uint64_t val; - if (!value.GetValue(&val)) { - return absl::nullopt; - } - wrapper->set_value(val); - return wrapper; -} - -absl::optional MessageFromValue(const CelValue& value, - ListValue* json_list) { - if (!value.IsList()) { - return absl::nullopt; - } - const CelList& list = *value.ListOrDie(); - for (int i = 0; i < list.size(); i++) { - auto e = list[i]; - Value* elem = json_list->add_values(); - auto result = MessageFromValue(e, elem); - if (!result.has_value()) { - return absl::nullopt; - } - } - return json_list; -} - -absl::optional MessageFromValue(const CelValue& value, - Struct* json_struct) { - if (!value.IsMap()) { - return absl::nullopt; - } - const CelMap& map = *value.MapOrDie(); - const auto& keys = *map.ListKeys(); - auto fields = json_struct->mutable_fields(); - for (int i = 0; i < keys.size(); i++) { - auto k = keys[i]; - // If the key is not a string type, abort the conversion. - if (!k.IsString()) { - return absl::nullopt; - } - absl::string_view key = k.StringOrDie().value(); - - auto v = map[k]; - if (!v.has_value()) { - return absl::nullopt; - } - Value field_value; - auto result = MessageFromValue(*v, &field_value); - // If the value is not a valid JSON type, abort the conversion. - if (!result.has_value()) { - return absl::nullopt; - } - (*fields)[std::string(key)] = field_value; - } - return json_struct; -} - -absl::optional MessageFromValue(const CelValue& value, - Value* json) { - switch (value.type()) { - case CelValue::Type::kBool: { - bool val; - if (value.GetValue(&val)) { - json->set_bool_value(val); - return json; - } - } break; - case CelValue::Type::kBytes: { - // Base64 encode byte strings to ensure they can safely be transpored - // in a JSON string. - CelValue::BytesHolder val; - if (value.GetValue(&val)) { - json->set_string_value(absl::Base64Escape(val.value())); - return json; - } - } break; - case CelValue::Type::kDouble: { - double val; - if (value.GetValue(&val)) { - json->set_number_value(val); - return json; - } - } break; - case CelValue::Type::kDuration: { - // Convert duration values to a protobuf JSON format. - absl::Duration val; - if (value.GetValue(&val)) { - auto encode = google::api::expr::internal::EncodeDurationToString(val); - if (!encode.ok()) { - return absl::nullopt; - } - json->set_string_value(*encode); - return json; - } - } break; - case CelValue::Type::kInt64: { - int64_t val; - // Convert int64_t values within the int53 range to doubles, otherwise - // serialize the value to a string. - if (value.GetValue(&val)) { - if (IsJSONSafe(val)) { - json->set_number_value(val); - } else { - json->set_string_value(absl::StrCat(val)); - } - return json; - } - } break; - case CelValue::Type::kString: { - CelValue::StringHolder val; - if (value.GetValue(&val)) { - json->set_string_value(val.value().data()); - return json; - } - } break; - case CelValue::Type::kTimestamp: { - // Convert timestamp values to a protobuf JSON format. - absl::Time val; - if (value.GetValue(&val)) { - auto encode = google::api::expr::internal::EncodeTimeToString(val); - if (!encode.ok()) { - return absl::nullopt; - } - json->set_string_value(*encode); - return json; - } - } break; - case CelValue::Type::kUint64: { - uint64_t val; - // Convert uint64_t values within the int53 range to doubles, otherwise - // serialize the value to a string. - if (value.GetValue(&val)) { - if (IsJSONSafe(val)) { - json->set_number_value(val); - } else { - json->set_string_value(absl::StrCat(val)); - } - return json; - } - } break; - case CelValue::Type::kList: { - auto lv = MessageFromValue(value, json->mutable_list_value()); - if (lv.has_value()) { - return json; - } - } break; - case CelValue::Type::kMap: { - auto sv = MessageFromValue(value, json->mutable_struct_value()); - if (sv.has_value()) { - return json; - } - } break; - case CelValue::Type::kNullType: - json->set_null_value(protobuf::NULL_VALUE); - return json; - default: - return absl::nullopt; - } - return absl::nullopt; -} - -absl::optional MessageFromValue(const CelValue& value, - Any* any) { - // In open source, any->PackFrom() returns void rather than boolean. - switch (value.type()) { - case CelValue::Type::kBool: { - BoolValue v; - auto msg = MessageFromValue(value, &v); - if (msg.has_value()) { - any->PackFrom(**msg); - return any; - } - } break; - case CelValue::Type::kBytes: { - BytesValue v; - auto msg = MessageFromValue(value, &v); - if (msg.has_value()) { - any->PackFrom(**msg); - return any; - } - } break; - case CelValue::Type::kDouble: { - DoubleValue v; - auto msg = MessageFromValue(value, &v); - if (msg.has_value()) { - any->PackFrom(**msg); - return any; - } - } break; - case CelValue::Type::kDuration: { - Duration v; - auto msg = MessageFromValue(value, &v); - if (msg.has_value()) { - any->PackFrom(**msg); - return any; - } - } break; - case CelValue::Type::kInt64: { - Int64Value v; - auto msg = MessageFromValue(value, &v); - if (msg.has_value()) { - any->PackFrom(**msg); - return any; - } - } break; - case CelValue::Type::kString: { - StringValue v; - auto msg = MessageFromValue(value, &v); - if (msg.has_value()) { - any->PackFrom(**msg); - return any; - } - } break; - case CelValue::Type::kTimestamp: { - Timestamp v; - auto msg = MessageFromValue(value, &v); - if (msg.has_value()) { - any->PackFrom(**msg); - return any; - } - } break; - case CelValue::Type::kUint64: { - UInt64Value v; - auto msg = MessageFromValue(value, &v); - if (msg.has_value()) { - any->PackFrom(**msg); - return any; - } - } break; - case CelValue::Type::kList: { - ListValue v; - auto msg = MessageFromValue(value, &v); - if (msg.has_value()) { - any->PackFrom(**msg); - return any; - } - } break; - case CelValue::Type::kMap: { - Struct v; - auto msg = MessageFromValue(value, &v); - if (msg.has_value()) { - any->PackFrom(**msg); - return any; - } - } break; - case CelValue::Type::kNullType: { - Value v; - auto msg = MessageFromValue(value, &v); - if (msg.has_value()) { - any->PackFrom(**msg); - return any; - } - } break; - case CelValue::Type::kMessage: { - any->PackFrom(*(value.MessageOrDie())); - return any; - } break; - default: - break; - } - return absl::nullopt; -} - -// Factory class, responsible for populating a Message type instance with the -// value of a simple CelValue. -class MessageFromValueFactory { - public: - virtual ~MessageFromValueFactory() {} - virtual const google::protobuf::Descriptor* GetDescriptor() const = 0; - virtual absl::optional WrapMessage( - const CelValue& value, Arena* arena) const = 0; -}; - -// MessageFromValueMaker makes a specific protobuf Message instance based on -// the desired protobuf type name and an input CelValue. -// -// It holds a registry of CelValue factories for specific subtypes of Message. -// If message does not match any of types stored in registry, an the factory -// returns an absent value. -class MessageFromValueMaker { - public: - // Non-copyable, non-assignable - MessageFromValueMaker(const MessageFromValueMaker&) = delete; - MessageFromValueMaker& operator=(const MessageFromValueMaker&) = delete; - - template - static absl::optional WrapWellknownTypeMessage( - const CelValue& value, Arena* arena) { - // If the value is a message type, see if it is already of the proper type - // name, and return it directly. - if (value.IsMessage()) { - const auto* msg = value.MessageOrDie(); - if (MessageType::descriptor()->well_known_type() == - msg->GetDescriptor()->well_known_type()) { - return absl::nullopt; - } - } - // Otherwise, allocate an empty message type, and attempt to populate it - // using the proper MessageFromValue overload. - auto* msg_buffer = Arena::CreateMessage(arena); - return MessageFromValue(value, msg_buffer); - } - - static absl::optional MaybeWrapMessage( - const google::protobuf::Descriptor* descriptor, const CelValue& value, - Arena* arena) { - switch (descriptor->well_known_type()) { - case google::protobuf::Descriptor::WELLKNOWNTYPE_DOUBLEVALUE: - return WrapWellknownTypeMessage(value, arena); - case google::protobuf::Descriptor::WELLKNOWNTYPE_FLOATVALUE: - return WrapWellknownTypeMessage(value, arena); - case google::protobuf::Descriptor::WELLKNOWNTYPE_INT64VALUE: - return WrapWellknownTypeMessage(value, arena); - case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT64VALUE: - return WrapWellknownTypeMessage(value, arena); - case google::protobuf::Descriptor::WELLKNOWNTYPE_INT32VALUE: - return WrapWellknownTypeMessage(value, arena); - case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT32VALUE: - return WrapWellknownTypeMessage(value, arena); - case google::protobuf::Descriptor::WELLKNOWNTYPE_STRINGVALUE: - return WrapWellknownTypeMessage(value, arena); - case google::protobuf::Descriptor::WELLKNOWNTYPE_BYTESVALUE: - return WrapWellknownTypeMessage(value, arena); - case google::protobuf::Descriptor::WELLKNOWNTYPE_BOOLVALUE: - return WrapWellknownTypeMessage(value, arena); - case google::protobuf::Descriptor::WELLKNOWNTYPE_ANY: - return WrapWellknownTypeMessage(value, arena); - case google::protobuf::Descriptor::WELLKNOWNTYPE_DURATION: - return WrapWellknownTypeMessage(value, arena); - case google::protobuf::Descriptor::WELLKNOWNTYPE_TIMESTAMP: - return WrapWellknownTypeMessage(value, arena); - case google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE: - return WrapWellknownTypeMessage(value, arena); - case google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE: - return WrapWellknownTypeMessage(value, arena); - case google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT: - return WrapWellknownTypeMessage(value, arena); - // WELLKNOWNTYPE_FIELDMASK has no special CelValue type - default: - return absl::nullopt; - } - } -}; - } // namespace // CreateMessage creates CelValue from google::protobuf::Message. // As some of CEL basic types are subclassing google::protobuf::Message, // this method contains type checking and downcasts. -CelValue CelProtoWrapper::CreateMessage(const google::protobuf::Message* value, - Arena* arena) { - // Messages are Nullable types - if (value == nullptr) { - return CelValue::CreateNull(); - } - - absl::optional special_value; - - special_value = ValueFromMessageMaker::CreateValue(value, arena); - return special_value.has_value() ? special_value.value() - : CelValue::CreateMessage(value); +CelValue CelProtoWrapper::CreateMessage(const Message* value, Arena* arena) { + return internal::UnwrapMessageToValue(value, &WrapMessage, arena); } absl::optional CelProtoWrapper::MaybeWrapValue( - const google::protobuf::Descriptor* descriptor, const CelValue& value, Arena* arena) { - absl::optional msg = - MessageFromValueMaker::MaybeWrapMessage(descriptor, value, arena); - if (!msg.has_value()) { + const Descriptor* descriptor, const CelValue& value, Arena* arena) { + const Message* msg = + internal::MaybeWrapValueToMessage(descriptor, value, arena); + if (msg != nullptr) { + return WrapMessage(msg); + } else { return absl::nullopt; } - return CelValue::CreateMessage(msg.value()); } } // namespace google::api::expr::runtime From b36f63a610f51be71daa1b11ced40d1821e0b455 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 13 Apr 2022 18:55:10 +0000 Subject: [PATCH 064/111] Add support for namespaced function resolution for ParsedExpressions. PiperOrigin-RevId: 441541207 --- eval/compiler/BUILD | 5 + eval/compiler/flat_expr_builder.cc | 14 +- eval/compiler/flat_expr_builder.h | 13 + .../flat_expr_builder_comprehensions_test.cc | 3 +- eval/compiler/flat_expr_builder_test.cc | 188 ++++++ eval/compiler/qualified_reference_resolver.cc | 358 +++++------- eval/compiler/qualified_reference_resolver.h | 22 +- .../qualified_reference_resolver_test.cc | 543 +++++++++++------- eval/eval/expression_build_warning.cc | 5 +- eval/eval/expression_build_warning.h | 7 +- eval/public/cel_expr_builder_factory.cc | 2 + eval/public/cel_options.h | 7 + 12 files changed, 729 insertions(+), 438 deletions(-) diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index e877d633b..827d82e03 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -59,6 +59,7 @@ cc_test( ], deps = [ ":flat_expr_builder", + "//eval/eval:expression_build_warning", "//eval/public:activation", "//eval/public:builtin_func_registrar", "//eval/public:cel_attribute", @@ -167,9 +168,12 @@ cc_library( ":resolver", "//eval/eval:const_value_step", "//eval/eval:expression_build_warning", + "//eval/public:ast_rewrite", "//eval/public:cel_builtins", "//eval/public:cel_function_registry", + "//eval/public:source_position", "//internal:status_macros", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -213,6 +217,7 @@ cc_test( "//testutil:util", "@com_google_absl//absl/status", "@com_google_absl//absl/types:optional", + "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index d72d33edb..9f9450f9f 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -1016,19 +1016,23 @@ FlatExprBuilder::CreateExpressionImpl( const Expr* effective_expr = expr; // transformed expression preserving expression IDs + bool rewrites_enabled = enable_qualified_identifier_rewrites_ || + (reference_map != nullptr && !reference_map->empty()); std::unique_ptr rewrite_buffer = nullptr; + // TODO(issues/98): A type checker may perform these rewrites, but there // currently isn't a signal to expose that in an expression. If that becomes // available, we can skip the reference resolve step here if it's already // done. - if (reference_map != nullptr && !reference_map->empty()) { - absl::StatusOr> rewritten = ResolveReferences( - *effective_expr, *reference_map, resolver, &warnings_builder); + if (rewrites_enabled) { + rewrite_buffer = std::make_unique(*expr); + absl::StatusOr rewritten = + ResolveReferences(reference_map, resolver, source_info, + warnings_builder, rewrite_buffer.get()); if (!rewritten.ok()) { return rewritten.status(); } - if (rewritten->has_value()) { - rewrite_buffer = std::make_unique((*std::move(rewritten)).value()); + if (*rewritten) { effective_expr = rewrite_buffer.get(); } // TODO(issues/99): we could setup a check step here that confirms all of diff --git a/eval/compiler/flat_expr_builder.h b/eval/compiler/flat_expr_builder.h index 9094c0c98..fc0c387f3 100644 --- a/eval/compiler/flat_expr_builder.h +++ b/eval/compiler/flat_expr_builder.h @@ -49,6 +49,7 @@ class FlatExprBuilder : public CelExpressionBuilder { enable_null_coercion_(true), enable_wrapper_type_null_unboxing_(false), enable_heterogeneous_equality_(false), + enable_qualified_identifier_rewrites_(false), descriptor_pool_(descriptor_pool), message_factory_(message_factory) {} @@ -149,6 +150,17 @@ class FlatExprBuilder : public CelExpressionBuilder { enable_heterogeneous_equality_ = enabled; } + // If enable_qualified_identifier_rewrites is true, the evaluator will attempt + // to disambiguate namespace qualified identifiers. + // + // For functions, this will attempt to determine whether a function call is a + // receiver call or a namespace qualified function. + void set_enable_qualified_identifier_rewrites( + bool enable_qualified_identifier_rewrites) { + enable_qualified_identifier_rewrites_ = + enable_qualified_identifier_rewrites; + } + absl::StatusOr> CreateExpression( const google::api::expr::v1alpha1::Expr* expr, const google::api::expr::v1alpha1::SourceInfo* source_info) const override; @@ -188,6 +200,7 @@ class FlatExprBuilder : public CelExpressionBuilder { bool enable_null_coercion_; bool enable_wrapper_type_null_unboxing_; bool enable_heterogeneous_equality_; + bool enable_qualified_identifier_rewrites_; const google::protobuf::DescriptorPool* descriptor_pool_; google::protobuf::MessageFactory* message_factory_; diff --git a/eval/compiler/flat_expr_builder_comprehensions_test.cc b/eval/compiler/flat_expr_builder_comprehensions_test.cc index 4a3f6aac1..52b1276ed 100644 --- a/eval/compiler/flat_expr_builder_comprehensions_test.cc +++ b/eval/compiler/flat_expr_builder_comprehensions_test.cc @@ -115,7 +115,8 @@ TEST(FlatExprBuilderComprehensionsTest, InvalidComprehensionWithRewrite) { ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); EXPECT_THAT(builder.CreateExpression(&expr).status(), StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Invalid comprehension"))); + testing::AnyOf(HasSubstr("Invalid comprehension"), + HasSubstr("Invalid empty expression")))); } TEST(FlatExprBuilderComprehensionsTest, ComprehensionWithConcatVulernability) { diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index ac4bdfc29..a30a98932 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -21,6 +21,7 @@ #include #include #include +#include #include "google/api/expr/v1alpha1/checked.pb.h" #include "google/api/expr/v1alpha1/syntax.pb.h" @@ -36,6 +37,7 @@ #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "eval/eval/expression_build_warning.h" #include "eval/public/activation.h" #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_attribute.h" @@ -627,6 +629,192 @@ TEST(FlatExprBuilderTest, InvalidContainer) { HasSubstr("container: 'bad.'"))); } +TEST(FlatExprBuilderTest, ParsedNamespacedFunctionSupport) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("ext.XOr(a, b)")); + FlatExprBuilder builder; + builder.set_enable_qualified_identifier_rewrites(true); + using FunctionAdapterT = FunctionAdapter; + + ASSERT_OK(FunctionAdapterT::CreateAndRegister( + "ext.XOr", /*receiver_style=*/false, + [](google::protobuf::Arena*, bool a, bool b) { return a != b; }, + builder.GetRegistry())); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression( + &expr.expr(), &expr.source_info())); + + google::protobuf::Arena arena; + Activation act1; + act1.InsertValue("a", CelValue::CreateBool(false)); + act1.InsertValue("b", CelValue::CreateBool(true)); + + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(act1, &arena)); + EXPECT_THAT(result, test::IsCelBool(true)); + + Activation act2; + act2.InsertValue("a", CelValue::CreateBool(true)); + act2.InsertValue("b", CelValue::CreateBool(true)); + + ASSERT_OK_AND_ASSIGN(result, cel_expr->Evaluate(act2, &arena)); + EXPECT_THAT(result, test::IsCelBool(false)); +} + +TEST(FlatExprBuilderTest, ParsedNamespacedFunctionSupportWithContainer) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("XOr(a, b)")); + FlatExprBuilder builder; + builder.set_enable_qualified_identifier_rewrites(true); + builder.set_container("ext"); + using FunctionAdapterT = FunctionAdapter; + + ASSERT_OK(FunctionAdapterT::CreateAndRegister( + "ext.XOr", /*receiver_style=*/false, + [](google::protobuf::Arena*, bool a, bool b) { return a != b; }, + builder.GetRegistry())); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression( + &expr.expr(), &expr.source_info())); + google::protobuf::Arena arena; + Activation act1; + act1.InsertValue("a", CelValue::CreateBool(false)); + act1.InsertValue("b", CelValue::CreateBool(true)); + + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(act1, &arena)); + EXPECT_THAT(result, test::IsCelBool(true)); + + Activation act2; + act2.InsertValue("a", CelValue::CreateBool(true)); + act2.InsertValue("b", CelValue::CreateBool(true)); + + ASSERT_OK_AND_ASSIGN(result, cel_expr->Evaluate(act2, &arena)); + EXPECT_THAT(result, test::IsCelBool(false)); +} + +TEST(FlatExprBuilderTest, ParsedNamespacedFunctionResolutionOrder) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("c.d.Get()")); + FlatExprBuilder builder; + builder.set_enable_qualified_identifier_rewrites(true); + builder.set_container("a.b"); + using FunctionAdapterT = FunctionAdapter; + + ASSERT_OK(FunctionAdapterT::CreateAndRegister( + "a.b.c.d.Get", /*receiver_style=*/false, + [](google::protobuf::Arena*) { return true; }, builder.GetRegistry())); + ASSERT_OK(FunctionAdapterT::CreateAndRegister( + "c.d.Get", /*receiver_style=*/false, [](google::protobuf::Arena*) { return false; }, + builder.GetRegistry())); + ASSERT_OK((FunctionAdapter::CreateAndRegister( + "Get", + /*receiver_style=*/true, [](google::protobuf::Arena*, bool) { return false; }, + builder.GetRegistry()))); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression( + &expr.expr(), &expr.source_info())); + google::protobuf::Arena arena; + Activation act1; + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(act1, &arena)); + EXPECT_THAT(result, test::IsCelBool(true)); +} + +TEST(FlatExprBuilderTest, + ParsedNamespacedFunctionResolutionOrderParentContainer) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("c.d.Get()")); + FlatExprBuilder builder; + builder.set_enable_qualified_identifier_rewrites(true); + builder.set_container("a.b"); + using FunctionAdapterT = FunctionAdapter; + + ASSERT_OK(FunctionAdapterT::CreateAndRegister( + "a.c.d.Get", /*receiver_style=*/false, + [](google::protobuf::Arena*) { return true; }, builder.GetRegistry())); + ASSERT_OK(FunctionAdapterT::CreateAndRegister( + "c.d.Get", /*receiver_style=*/false, [](google::protobuf::Arena*) { return false; }, + builder.GetRegistry())); + ASSERT_OK((FunctionAdapter::CreateAndRegister( + "Get", + /*receiver_style=*/true, [](google::protobuf::Arena*, bool) { return false; }, + builder.GetRegistry()))); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression( + &expr.expr(), &expr.source_info())); + google::protobuf::Arena arena; + Activation act1; + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(act1, &arena)); + EXPECT_THAT(result, test::IsCelBool(true)); +} + +TEST(FlatExprBuilderTest, + ParsedNamespacedFunctionResolutionOrderExplicitGlobal) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(".c.d.Get()")); + FlatExprBuilder builder; + builder.set_enable_qualified_identifier_rewrites(true); + builder.set_container("a.b"); + using FunctionAdapterT = FunctionAdapter; + + ASSERT_OK(FunctionAdapterT::CreateAndRegister( + "a.c.d.Get", /*receiver_style=*/false, + [](google::protobuf::Arena*) { return false; }, builder.GetRegistry())); + ASSERT_OK(FunctionAdapterT::CreateAndRegister( + "c.d.Get", /*receiver_style=*/false, [](google::protobuf::Arena*) { return true; }, + builder.GetRegistry())); + ASSERT_OK((FunctionAdapter::CreateAndRegister( + "Get", + /*receiver_style=*/true, [](google::protobuf::Arena*, bool) { return false; }, + builder.GetRegistry()))); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression( + &expr.expr(), &expr.source_info())); + google::protobuf::Arena arena; + Activation act1; + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(act1, &arena)); + EXPECT_THAT(result, test::IsCelBool(true)); +} + +TEST(FlatExprBuilderTest, ParsedNamespacedFunctionResolutionOrderReceiverCall) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("e.Get()")); + FlatExprBuilder builder; + builder.set_enable_qualified_identifier_rewrites(true); + builder.set_container("a.b"); + using FunctionAdapterT = FunctionAdapter; + + ASSERT_OK(FunctionAdapterT::CreateAndRegister( + "a.c.d.Get", /*receiver_style=*/false, + [](google::protobuf::Arena*) { return false; }, builder.GetRegistry())); + ASSERT_OK(FunctionAdapterT::CreateAndRegister( + "c.d.Get", /*receiver_style=*/false, [](google::protobuf::Arena*) { return false; }, + builder.GetRegistry())); + ASSERT_OK((FunctionAdapter::CreateAndRegister( + "Get", + /*receiver_style=*/true, [](google::protobuf::Arena*, bool) { return true; }, + builder.GetRegistry()))); + ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression( + &expr.expr(), &expr.source_info())); + google::protobuf::Arena arena; + Activation act1; + act1.InsertValue("e", CelValue::CreateBool(false)); + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(act1, &arena)); + EXPECT_THAT(result, test::IsCelBool(true)); +} + +TEST(FlatExprBuilderTest, ParsedNamespacedFunctionSupportDisabled) { + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("ext.XOr(a, b)")); + FlatExprBuilder builder; + builder.set_fail_on_warnings(false); + std::vector build_warnings; + builder.set_container("ext"); + using FunctionAdapterT = FunctionAdapter; + + ASSERT_OK(FunctionAdapterT::CreateAndRegister( + "ext.XOr", /*receiver_style=*/false, + [](google::protobuf::Arena*, bool a, bool b) { return a != b; }, + builder.GetRegistry())); + ASSERT_OK_AND_ASSIGN( + auto cel_expr, builder.CreateExpression(&expr.expr(), &expr.source_info(), + &build_warnings)); + google::protobuf::Arena arena; + Activation act1; + act1.InsertValue("a", CelValue::CreateBool(false)); + act1.InsertValue("b", CelValue::CreateBool(true)); + + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(act1, &arena)); + EXPECT_THAT(result, test::IsCelError(StatusIs(absl::StatusCode::kUnknown, + HasSubstr("ext")))); +} + TEST(FlatExprBuilderTest, BasicCheckedExprSupport) { CheckedExpr expr; // foo && bar diff --git a/eval/compiler/qualified_reference_resolver.cc b/eval/compiler/qualified_reference_resolver.cc index 4bf4f5dde..0c880a09d 100644 --- a/eval/compiler/qualified_reference_resolver.cc +++ b/eval/compiler/qualified_reference_resolver.cc @@ -4,6 +4,9 @@ #include #include +#include "google/api/expr/v1alpha1/checked.pb.h" +#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -12,8 +15,10 @@ #include "absl/types/optional.h" #include "eval/eval/const_value_step.h" #include "eval/eval/expression_build_warning.h" +#include "eval/public/ast_rewrite.h" #include "eval/public/cel_builtins.h" #include "eval/public/cel_function_registry.h" +#include "eval/public/source_position.h" #include "internal/status_macros.h" namespace google::api::expr::runtime { @@ -31,29 +36,6 @@ bool IsSpecialFunction(absl::string_view function_name) { function_name == builtin::kIndex || function_name == builtin::kTernary; } -// Convert a select expr sub tree into a namespace name if possible. -// If any operand of the top element is a not a select or an ident node, -// return nullopt. -absl::optional ToNamespace(const Expr& expr) { - absl::optional maybe_parent_namespace; - switch (expr.expr_kind_case()) { - case Expr::kIdentExpr: - return expr.ident_expr().name(); - case Expr::kSelectExpr: - if (expr.select_expr().test_only()) { - return absl::nullopt; - } - maybe_parent_namespace = ToNamespace(expr.select_expr().operand()); - if (!maybe_parent_namespace.has_value()) { - return absl::nullopt; - } - return absl::StrCat(*maybe_parent_namespace, ".", - expr.select_expr().field()); - default: - return absl::nullopt; - } -} - bool OverloadExists(const Resolver& resolver, absl::string_view name, const std::vector& arguments_matcher, bool receiver_style = false) { @@ -76,16 +58,28 @@ absl::optional BestOverloadMatch(const Resolver& resolver, auto names = resolver.FullyQualifiedNames(base_name); for (auto name = names.begin(); name != names.end(); ++name) { if (OverloadExists(resolver, *name, arguments_matcher)) { + if (base_name[0] == '.') { + // Preserve leading '.' to prevent re-resolving at plan time. + return std::string(base_name); + } return *name; } } return absl::nullopt; } -class ReferenceResolver { +// Rewriter visitor for resolving references. +// +// On previsit pass, replace (possibly qualified) identifier branches with the +// canonical name in the reference map (most qualified references considered +// first). +// +// On post visit pass, update function calls to determine whether the function +// target is a namespace for the function or a receiver for the call. +class ReferenceResolver : public AstRewriterBase { public: - ReferenceResolver(const google::protobuf::Map& reference_map, - const Resolver& resolver, BuilderWarnings* warnings) + ReferenceResolver(const google::protobuf::Map* reference_map, + const Resolver& resolver, BuilderWarnings& warnings) : reference_map_(reference_map), resolver_(resolver), warnings_(warnings) {} @@ -95,90 +89,44 @@ class ReferenceResolver { // TODO(issues/95): If possible, it would be nice to write a general utility // for running the preprocess steps when traversing the AST instead of having // one pass per transform. - absl::StatusOr Rewrite(Expr* out) { - const auto reference_iter = reference_map_.find(out->id()); - const Reference* reference = nullptr; - if (reference_iter != reference_map_.end()) { - if (!reference_iter->second.has_value()) { - reference = &reference_iter->second; + bool PreVisitRewrite(Expr* expr, const SourcePosition* position) override { + const Reference* reference = GetReferenceForId(expr->id()); + + // Fold compile time constant (e.g. enum values) + if (reference != nullptr && reference->has_value()) { + if (reference->value().constant_kind_case() == Constant::kInt64Value) { + // Replace enum idents with const reference value. + expr->mutable_const_expr()->set_int64_value( + reference->value().int64_value()); + return true; } else { - if (out->expr_kind_case() == Expr::kIdentExpr && - reference_iter->second.value().constant_kind_case() == - Constant::kInt64Value) { - // Replace enum idents with const reference value. - out->clear_ident_expr(); - out->mutable_const_expr()->set_int64_value( - reference_iter->second.value().int64_value()); - return true; - } + // No update if the constant reference isn't an int (an enum value). + return false; } } - bool updated = false; - switch (out->expr_kind_case()) { - case Expr::kConstExpr: { - return false; - } - case Expr::kIdentExpr: - return MaybeUpdateIdentNode(out, reference); - case Expr::kSelectExpr: - return MaybeUpdateSelectNode(out, reference); - case Expr::kCallExpr: { - return MaybeUpdateCallNode(out, reference); - } - case Expr::kListExpr: { - auto* list_expr = out->mutable_list_expr(); - int list_size = list_expr->elements_size(); - for (int i = 0; i < list_size; i++) { - CEL_ASSIGN_OR_RETURN(bool rewrite_result, - Rewrite(list_expr->mutable_elements(i))); - updated = updated || rewrite_result; - } - return updated; - } - case Expr::kStructExpr: { - return MaybeUpdateStructNode(out, reference); + if (reference != nullptr) { + switch (expr->expr_kind_case()) { + case Expr::kIdentExpr: + return MaybeUpdateIdentNode(expr, *reference); + case Expr::kSelectExpr: + return MaybeUpdateSelectNode(expr, *reference); + default: + // Call nodes are updated on post visit so they will see any select + // path rewrites. + return false; } - case Expr::kComprehensionExpr: { - auto* out_expr = out->mutable_comprehension_expr(); - bool rewrite_result; - - if (out_expr->has_accu_init()) { - CEL_ASSIGN_OR_RETURN(rewrite_result, - Rewrite(out_expr->mutable_accu_init())); - updated = updated || rewrite_result; - } - - if (out_expr->has_iter_range()) { - CEL_ASSIGN_OR_RETURN(rewrite_result, - Rewrite(out_expr->mutable_iter_range())); - updated = updated || rewrite_result; - } - - if (out_expr->has_loop_condition()) { - CEL_ASSIGN_OR_RETURN(rewrite_result, - Rewrite(out_expr->mutable_loop_condition())); - updated = updated || rewrite_result; - } - - if (out_expr->has_loop_step()) { - CEL_ASSIGN_OR_RETURN(rewrite_result, - Rewrite(out_expr->mutable_loop_step())); - updated = updated || rewrite_result; - } - - if (out_expr->has_result()) { - CEL_ASSIGN_OR_RETURN(rewrite_result, - Rewrite(out_expr->mutable_result())); - updated = updated || rewrite_result; - } + } + return false; + } - return updated; - } - default: - GOOGLE_LOG(ERROR) << "Unsupported Expr kind: " << out->expr_kind_case(); - return false; + bool PostVisitRewrite(Expr* expr, + const SourcePosition* source_position) override { + const Reference* reference = GetReferenceForId(expr->id()); + if (expr->has_call_expr()) { + return MaybeUpdateCallNode(expr, reference); } + return false; } private: @@ -187,39 +135,28 @@ class ReferenceResolver { // // TODO(issues/95): This duplicates some of the overload matching behavior // for parsed expressions. We should refactor to consolidate the code. - absl::StatusOr MaybeUpdateCallNode(Expr* out, - const Reference* reference) { + bool MaybeUpdateCallNode(Expr* out, const Reference* reference) { auto* call_expr = out->mutable_call_expr(); if (reference != nullptr && reference->overload_id_size() == 0) { - CEL_RETURN_IF_ERROR(warnings_->AddWarning(absl::InvalidArgumentError( - absl::StrCat("Reference map doesn't provide overloads for ", - out->call_expr().function())))); + warnings_ + .AddWarning(absl::InvalidArgumentError( + absl::StrCat("Reference map doesn't provide overloads for ", + out->call_expr().function()))) + .IgnoreError(); } bool receiver_style = call_expr->has_target(); - bool updated = false; int arg_num = call_expr->args_size(); if (receiver_style) { - // First check the target to see if the reference map indicates it - // should be rewritten. - absl::StatusOr rewrite_result = - Rewrite(call_expr->mutable_target()); - CEL_RETURN_IF_ERROR(rewrite_result.status()); - bool target_updated = rewrite_result.value(); - updated = target_updated; - if (!target_updated) { - // If the function receiver was not rewritten, check to see if it's - // actually a namespace for the function. - auto maybe_namespace = ToNamespace(call_expr->target()); - if (maybe_namespace.has_value()) { - std::string resolved_name = - absl::StrCat(*maybe_namespace, ".", call_expr->function()); - auto maybe_resolved_function = - BestOverloadMatch(resolver_, resolved_name, arg_num); - if (maybe_resolved_function.has_value()) { - call_expr->set_function(maybe_resolved_function.value()); - call_expr->clear_target(); - updated = true; - } + auto maybe_namespace = ToNamespace(call_expr->target()); + if (maybe_namespace.has_value()) { + std::string resolved_name = + absl::StrCat(*maybe_namespace, ".", call_expr->function()); + auto resolved_function = + BestOverloadMatch(resolver_, resolved_name, arg_num); + if (resolved_function.has_value()) { + call_expr->set_function(*resolved_function); + call_expr->clear_target(); + return true; } } } else { @@ -228,12 +165,14 @@ class ReferenceResolver { auto maybe_resolved_function = BestOverloadMatch(resolver_, call_expr->function(), arg_num); if (!maybe_resolved_function.has_value()) { - CEL_RETURN_IF_ERROR(warnings_->AddWarning(absl::InvalidArgumentError( - absl::StrCat("No overload found in reference resolve step for ", - call_expr->function())))); + warnings_ + .AddWarning(absl::InvalidArgumentError( + absl::StrCat("No overload found in reference resolve step for ", + call_expr->function()))) + .IgnoreError(); } else if (maybe_resolved_function.value() != call_expr->function()) { call_expr->set_function(maybe_resolved_function.value()); - updated = true; + return true; } } // For parity, if we didn't rewrite the receiver call style function, @@ -242,102 +181,107 @@ class ReferenceResolver { !OverloadExists(resolver_, call_expr->function(), ArgumentsMatcher(arg_num + 1), /* receiver_style= */ true)) { - CEL_RETURN_IF_ERROR(warnings_->AddWarning(absl::InvalidArgumentError( - absl::StrCat("No overload found in reference resolve step for ", - call_expr->function())))); - } - for (int i = 0; i < arg_num; i++) { - absl::StatusOr rewrite_result = Rewrite(call_expr->mutable_args(i)); - CEL_RETURN_IF_ERROR(rewrite_result.status()); - updated = updated || rewrite_result.value(); + warnings_ + .AddWarning(absl::InvalidArgumentError( + absl::StrCat("No overload found in reference resolve step for ", + call_expr->function()))) + .IgnoreError(); } - return updated; + return false; } - // Attempt to resolve a select node. If reference is not-null and valid, - // replace the select node with the fully qualified ident node. Otherwise, - // continue recursively rewriting the Expr. - absl::StatusOr MaybeUpdateSelectNode(Expr* out, - const Reference* reference) { - if (reference != nullptr) { - if (out->select_expr().test_only()) { - CEL_RETURN_IF_ERROR(warnings_->AddWarning( - absl::InvalidArgumentError("Reference map points to a presence " - "test -- has(container.attr)"))); - } else if (!reference->name().empty()) { - out->clear_select_expr(); - out->mutable_ident_expr()->set_name(reference->name()); - return true; - } + // Attempt to resolve a select node. If reference is valid, + // replace the select node with the fully qualified ident node. + bool MaybeUpdateSelectNode(Expr* out, const Reference& reference) { + if (out->select_expr().test_only()) { + warnings_ + .AddWarning( + absl::InvalidArgumentError("Reference map points to a presence " + "test -- has(container.attr)")) + .IgnoreError(); + } else if (!reference.name().empty()) { + out->mutable_ident_expr()->set_name(reference.name()); + rewritten_reference_.insert(out->id()); + return true; } - return Rewrite(out->mutable_select_expr()->mutable_operand()); + return false; } - // Attempt to resolve an ident node. If reference is not-null and valid, + // Attempt to resolve an ident node. If reference is valid, // replace the node with the fully qualified ident node. - bool MaybeUpdateIdentNode(Expr* out, const Reference* reference) { - if (reference != nullptr && !reference->name().empty() && - reference->name() != out->ident_expr().name()) { - out->mutable_ident_expr()->set_name(reference->name()); + bool MaybeUpdateIdentNode(Expr* out, const Reference& reference) { + if (!reference.name().empty() && + reference.name() != out->ident_expr().name()) { + out->mutable_ident_expr()->set_name(reference.name()); + rewritten_reference_.insert(out->id()); return true; } return false; } - // Update a create struct node. Currently, just handles recursing. - // - // TODO(issues/72): annotating the execution plan with this may help - // identify problems with the environment setup. This will probably - // also require the type map information from a checked expression. - absl::StatusOr MaybeUpdateStructNode(Expr* out, - const Reference* reference) { - auto* struct_expr = out->mutable_struct_expr(); - int entries_size = struct_expr->entries_size(); - bool updated = false; - for (int i = 0; i < entries_size; i++) { - auto* new_entry = struct_expr->mutable_entries(i); - switch (new_entry->key_kind_case()) { - case Expr::CreateStruct::Entry::kFieldKey: - // Nothing to do. - break; - case Expr::CreateStruct::Entry::kMapKey: { - auto key_updated = Rewrite(new_entry->mutable_map_key()); - CEL_RETURN_IF_ERROR(key_updated.status()); - updated = updated || key_updated.value(); - break; + // Convert a select expr sub tree into a namespace name if possible. + // If any operand of the top element is a not a select or an ident node, + // return nullopt. + absl::optional ToNamespace(const Expr& expr) { + absl::optional maybe_parent_namespace; + if (rewritten_reference_.find(expr.id()) != rewritten_reference_.end()) { + // The target expr matches a reference (resolved to an ident decl). + // This should not be treated as a function qualifier. + return absl::nullopt; + } + switch (expr.expr_kind_case()) { + case Expr::kIdentExpr: + return expr.ident_expr().name(); + case Expr::kSelectExpr: + if (expr.select_expr().test_only()) { + return absl::nullopt; } - default: - GOOGLE_LOG(ERROR) << "Unsupported Entry kind: " - << new_entry->key_kind_case(); - break; - } - auto value_updated = Rewrite(new_entry->mutable_value()); - CEL_RETURN_IF_ERROR(value_updated.status()); - updated = updated || value_updated.value(); + maybe_parent_namespace = ToNamespace(expr.select_expr().operand()); + if (!maybe_parent_namespace.has_value()) { + return absl::nullopt; + } + return absl::StrCat(*maybe_parent_namespace, ".", + expr.select_expr().field()); + default: + return absl::nullopt; + } + } + + // Find a reference for the given expr id. + // + // Returns nullptr if no reference is available. + const Reference* GetReferenceForId(int64_t expr_id) { + if (reference_map_ == nullptr) { + return nullptr; } - return updated; + auto iter = reference_map_->find(expr_id); + if (iter == reference_map_->end()) { + return nullptr; + } + return &iter->second; } - const google::protobuf::Map& reference_map_; + const google::protobuf::Map* reference_map_; const Resolver& resolver_; - BuilderWarnings* warnings_; + BuilderWarnings& warnings_; + absl::flat_hash_set rewritten_reference_; }; } // namespace -absl::StatusOr> ResolveReferences( - const Expr& expr, const google::protobuf::Map& reference_map, - const Resolver& resolver, BuilderWarnings* warnings) { - Expr out(expr); +absl::StatusOr ResolveReferences( + const google::protobuf::Map* reference_map, + const Resolver& resolver, const SourceInfo* source_info, + BuilderWarnings& warnings, Expr* expr) { ReferenceResolver ref_resolver(reference_map, resolver, warnings); - absl::StatusOr rewrite_result = ref_resolver.Rewrite(&out); - if (!rewrite_result.ok()) { - return rewrite_result.status(); - } else if (rewrite_result.value()) { - return absl::optional(out); - } else { - return absl::optional(); + + // Rewriting interface doesn't support failing mid traverse propagate first + // error encountered if fail fast enabled. + bool was_rewritten = AstRewrite(expr, source_info, &ref_resolver); + if (warnings.fail_immediately() && !warnings.warnings().empty()) { + return warnings.warnings().front(); } + return was_rewritten; } } // namespace google::api::expr::runtime diff --git a/eval/compiler/qualified_reference_resolver.h b/eval/compiler/qualified_reference_resolver.h index 80b7f84fe..9c79b44d2 100644 --- a/eval/compiler/qualified_reference_resolver.h +++ b/eval/compiler/qualified_reference_resolver.h @@ -14,15 +14,19 @@ namespace google::api::expr::runtime { -// A transformation over input expression that produces a new expression with -// subexpressions replaced by appropriate expressions referring to the -// fully-qualified entity name or constant expressions in case of enums. -// Returns modified expr if updates found. -// Otherwise, returns nullopt. -absl::StatusOr> ResolveReferences( - const google::api::expr::v1alpha1::Expr& expr, - const google::protobuf::Map& reference_map, - const Resolver& resolver, BuilderWarnings* warnings); +// Resolves possibly qualified names in the provided expression, updating +// subexpressions with to use the fully qualified name, or a constant +// expressions in the case of enums. +// +// Returns true if updates were applied. +// +// Will warn or return a non-ok status if references can't be resolved (no +// function overload could match a call) or are inconsistnet (reference map +// points to an expr node that isn't a reference). +absl::StatusOr ResolveReferences( + const google::protobuf::Map* reference_map, + const Resolver& resolver, const SourceInfo* source_info, + BuilderWarnings& warnings, Expr* expr); } // namespace google::api::expr::runtime diff --git a/eval/compiler/qualified_reference_resolver_test.cc b/eval/compiler/qualified_reference_resolver_test.cc index f309d5dd1..48cf0a323 100644 --- a/eval/compiler/qualified_reference_resolver_test.cc +++ b/eval/compiler/qualified_reference_resolver_test.cc @@ -3,6 +3,7 @@ #include #include +#include "google/api/expr/v1alpha1/syntax.pb.h" #include "google/protobuf/text_format.h" #include "absl/status/status.h" #include "absl/types/optional.h" @@ -24,8 +25,9 @@ using ::google::api::expr::v1alpha1::Reference; using testing::ElementsAre; using testing::Eq; using testing::IsEmpty; -using testing::Optional; using testing::UnorderedElementsAre; +using cel::internal::IsOkAndHolds; +using cel::internal::StatusIs; using testutil::EqualsProto; // foo.bar.var1 && bar.foo.var2 @@ -81,6 +83,7 @@ Expr ParseTestProto(const std::string& pb) { TEST(ResolveReferences, Basic) { Expr expr = ParseTestProto(kExpr); + SourceInfo source_info; google::protobuf::Map reference_map; reference_map[2].set_name("foo.bar.var1"); reference_map[5].set_name("bar.foo.var2"); @@ -89,38 +92,49 @@ TEST(ResolveReferences, Basic) { CelTypeRegistry type_registry; Resolver registry("", &func_registry, &type_registry); - auto result = ResolveReferences(expr, reference_map, registry, &warnings); - ASSERT_OK(result); - EXPECT_THAT(*result, Optional(EqualsProto(R"pb( - id: 1 - call_expr { - function: "_&&_" - args { - id: 2 - ident_expr { name: "foo.bar.var1" } - } - args { - id: 5 - ident_expr { name: "bar.foo.var2" } - } - })pb"))); + auto result = ResolveReferences(&reference_map, registry, &source_info, + warnings, &expr); + ASSERT_THAT(result, IsOkAndHolds(true)); + EXPECT_THAT(expr, EqualsProto(R"pb( + id: 1 + call_expr { + function: "_&&_" + args { + id: 2 + ident_expr { name: "foo.bar.var1" } + } + args { + id: 5 + ident_expr { name: "bar.foo.var2" } + } + })pb")); } -TEST(ResolveReferences, ReturnsNulloptIfNoChanges) { +TEST(ResolveReferences, ReturnsFalseIfNoChanges) { Expr expr = ParseTestProto(kExpr); + SourceInfo source_info; google::protobuf::Map reference_map; BuilderWarnings warnings; CelFunctionRegistry func_registry; CelTypeRegistry type_registry; Resolver registry("", &func_registry, &type_registry); - auto result = ResolveReferences(expr, reference_map, registry, &warnings); - ASSERT_OK(result); - EXPECT_THAT(*result, Eq(absl::nullopt)); + auto result = ResolveReferences(&reference_map, registry, &source_info, + warnings, &expr); + ASSERT_THAT(result, IsOkAndHolds(false)); + + // reference to the same name also doesn't count as a rewrite. + reference_map[4].set_name("foo"); + reference_map[7].set_name("bar"); + + result = ResolveReferences(&reference_map, registry, &source_info, warnings, + &expr); + ASSERT_THAT(result, IsOkAndHolds(false)); } TEST(ResolveReferences, NamespacedIdent) { Expr expr = ParseTestProto(kExpr); + SourceInfo source_info; google::protobuf::Map reference_map; BuilderWarnings warnings; CelFunctionRegistry func_registry; @@ -129,33 +143,34 @@ TEST(ResolveReferences, NamespacedIdent) { reference_map[2].set_name("foo.bar.var1"); reference_map[7].set_name("namespace_x.bar"); - auto result = ResolveReferences(expr, reference_map, registry, &warnings); - ASSERT_OK(result); - EXPECT_THAT(*result, Optional(EqualsProto(R"pb( - id: 1 - call_expr { - function: "_&&_" - args { - id: 2 - ident_expr { name: "foo.bar.var1" } - } - args { - id: 5 - select_expr { - field: "var2" - operand { - id: 6 - select_expr { - field: "foo" - operand { - id: 7 - ident_expr { name: "namespace_x.bar" } - } - } - } - } - } - })pb"))); + auto result = ResolveReferences(&reference_map, registry, &source_info, + warnings, &expr); + ASSERT_THAT(result, IsOkAndHolds(true)); + EXPECT_THAT(expr, EqualsProto(R"pb( + id: 1 + call_expr { + function: "_&&_" + args { + id: 2 + ident_expr { name: "foo.bar.var1" } + } + args { + id: 5 + select_expr { + field: "var2" + operand { + id: 6 + select_expr { + field: "foo" + operand { + id: 7 + ident_expr { name: "namespace_x.bar" } + } + } + } + } + } + })pb")); } TEST(ResolveReferences, WarningOnPresenceTest) { @@ -175,6 +190,8 @@ TEST(ResolveReferences, WarningOnPresenceTest) { } } })"); + SourceInfo source_info; + google::protobuf::Map reference_map; BuilderWarnings warnings; CelFunctionRegistry func_registry; @@ -182,9 +199,10 @@ TEST(ResolveReferences, WarningOnPresenceTest) { Resolver registry("", &func_registry, &type_registry); reference_map[1].set_name("foo.bar.var1"); - auto result = ResolveReferences(expr, reference_map, registry, &warnings); - ASSERT_OK(result); - EXPECT_THAT(*result, Eq(absl::nullopt)); + auto result = ResolveReferences(&reference_map, registry, &source_info, + warnings, &expr); + + ASSERT_THAT(result, IsOkAndHolds(false)); EXPECT_THAT( warnings.warnings(), testing::ElementsAre(Eq(absl::Status( @@ -219,8 +237,11 @@ constexpr char kEnumExpr[] = R"( } } )"; + TEST(ResolveReferences, EnumConstReferenceUsed) { Expr expr = ParseTestProto(kEnumExpr); + SourceInfo source_info; + google::protobuf::Map reference_map; CelFunctionRegistry func_registry; ASSERT_OK(RegisterBuiltinFunctions(&func_registry)); @@ -231,25 +252,63 @@ TEST(ResolveReferences, EnumConstReferenceUsed) { reference_map[5].mutable_value()->set_int64_value(9); BuilderWarnings warnings; - auto result = ResolveReferences(expr, reference_map, registry, &warnings); - ASSERT_OK(result); - EXPECT_THAT(*result, Optional(EqualsProto(R"pb( - id: 1 - call_expr { - function: "_==_" - args { - id: 2 - ident_expr { name: "foo.bar.var1" } - } - args { - id: 5 - const_expr { int64_value: 9 } - } - })pb"))); + auto result = ResolveReferences(&reference_map, registry, &source_info, + warnings, &expr); + + ASSERT_THAT(result, IsOkAndHolds(true)); + EXPECT_THAT(expr, EqualsProto(R"pb( + id: 1 + call_expr { + function: "_==_" + args { + id: 2 + ident_expr { name: "foo.bar.var1" } + } + args { + id: 5 + const_expr { int64_value: 9 } + } + })pb")); +} + +TEST(ResolveReferences, EnumConstReferenceUsedSelect) { + Expr expr = ParseTestProto(kEnumExpr); + SourceInfo source_info; + + google::protobuf::Map reference_map; + CelFunctionRegistry func_registry; + ASSERT_OK(RegisterBuiltinFunctions(&func_registry)); + CelTypeRegistry type_registry; + Resolver registry("", &func_registry, &type_registry); + reference_map[2].set_name("foo.bar.var1"); + reference_map[2].mutable_value()->set_int64_value(2); + reference_map[5].set_name("bar.foo.Enum.ENUM_VAL1"); + reference_map[5].mutable_value()->set_int64_value(9); + BuilderWarnings warnings; + + auto result = ResolveReferences(&reference_map, registry, &source_info, + warnings, &expr); + + ASSERT_THAT(result, IsOkAndHolds(true)); + EXPECT_THAT(expr, EqualsProto(R"pb( + id: 1 + call_expr { + function: "_==_" + args { + id: 2 + const_expr { int64_value: 2 } + } + args { + id: 5 + const_expr { int64_value: 9 } + } + })pb")); } TEST(ResolveReferences, ConstReferenceSkipped) { Expr expr = ParseTestProto(kExpr); + SourceInfo source_info; + google::protobuf::Map reference_map; CelFunctionRegistry func_registry; ASSERT_OK(RegisterBuiltinFunctions(&func_registry)); @@ -260,33 +319,35 @@ TEST(ResolveReferences, ConstReferenceSkipped) { reference_map[5].set_name("bar.foo.var2"); BuilderWarnings warnings; - auto result = ResolveReferences(expr, reference_map, registry, &warnings); - ASSERT_OK(result); - EXPECT_THAT(*result, Optional(EqualsProto(R"pb( - id: 1 - call_expr { - function: "_&&_" - args { - id: 2 - select_expr { - field: "var1" - operand { - id: 3 - select_expr { - field: "bar" - operand { - id: 4 - ident_expr { name: "foo" } - } - } - } - } - } - args { - id: 5 - ident_expr { name: "bar.foo.var2" } - } - })pb"))); + auto result = ResolveReferences(&reference_map, registry, &source_info, + warnings, &expr); + + ASSERT_THAT(result, IsOkAndHolds(true)); + EXPECT_THAT(expr, EqualsProto(R"pb( + id: 1 + call_expr { + function: "_&&_" + args { + id: 2 + select_expr { + field: "var1" + operand { + id: 3 + select_expr { + field: "bar" + operand { + id: 4 + ident_expr { name: "foo" } + } + } + } + } + } + args { + id: 5 + ident_expr { name: "bar.foo.var2" } + } + })pb")); } constexpr char kExtensionAndExpr[] = R"( @@ -309,6 +370,8 @@ call_expr { TEST(ResolveReferences, FunctionReferenceBasic) { Expr expr = ParseTestProto(kExtensionAndExpr); + SourceInfo source_info; + google::protobuf::Map reference_map; CelFunctionRegistry func_registry; ASSERT_OK(func_registry.RegisterLazyFunction( @@ -322,13 +385,16 @@ TEST(ResolveReferences, FunctionReferenceBasic) { BuilderWarnings warnings; reference_map[1].add_overload_id("udf_boolean_and"); - auto result = ResolveReferences(expr, reference_map, registry, &warnings); - ASSERT_OK(result); - EXPECT_THAT(*result, Eq(absl::nullopt)); + auto result = ResolveReferences(&reference_map, registry, &source_info, + warnings, &expr); + + ASSERT_THAT(result, IsOkAndHolds(false)); } TEST(ResolveReferences, FunctionReferenceMissingOverloadDetected) { Expr expr = ParseTestProto(kExtensionAndExpr); + SourceInfo source_info; + google::protobuf::Map reference_map; CelFunctionRegistry func_registry; CelTypeRegistry type_registry; @@ -336,9 +402,10 @@ TEST(ResolveReferences, FunctionReferenceMissingOverloadDetected) { BuilderWarnings warnings; reference_map[1].add_overload_id("udf_boolean_and"); - auto result = ResolveReferences(expr, reference_map, registry, &warnings); - ASSERT_OK(result); - EXPECT_THAT(*result, Eq(absl::nullopt)); + auto result = ResolveReferences(&reference_map, registry, &source_info, + warnings, &expr); + + ASSERT_THAT(result, IsOkAndHolds(false)); EXPECT_THAT(warnings.warnings(), ElementsAre(StatusCodeIs(absl::StatusCode::kInvalidArgument))); } @@ -357,6 +424,7 @@ TEST(ResolveReferences, SpecialBuiltinsNotWarned) { const_expr { bool_value: false } } })"); + SourceInfo source_info; std::vector special_builtins{builtin::kAnd, builtin::kOr, builtin::kTernary, builtin::kIndex}; @@ -370,9 +438,10 @@ TEST(ResolveReferences, SpecialBuiltinsNotWarned) { reference_map[1].add_overload_id(absl::StrCat("builtin.", builtin_fn)); expr.mutable_call_expr()->set_function(builtin_fn); - auto result = ResolveReferences(expr, reference_map, registry, &warnings); - ASSERT_OK(result); - EXPECT_THAT(*result, Eq(absl::nullopt)); + auto result = ResolveReferences(&reference_map, registry, &source_info, + warnings, &expr); + + ASSERT_THAT(result, IsOkAndHolds(false)); EXPECT_THAT(warnings.warnings(), IsEmpty()); } } @@ -380,6 +449,8 @@ TEST(ResolveReferences, SpecialBuiltinsNotWarned) { TEST(ResolveReferences, FunctionReferenceMissingOverloadDetectedAndMissingReference) { Expr expr = ParseTestProto(kExtensionAndExpr); + SourceInfo source_info; + google::protobuf::Map reference_map; CelFunctionRegistry func_registry; CelTypeRegistry type_registry; @@ -387,9 +458,10 @@ TEST(ResolveReferences, BuilderWarnings warnings; reference_map[1].set_name("udf_boolean_and"); - auto result = ResolveReferences(expr, reference_map, registry, &warnings); - ASSERT_OK(result); - EXPECT_THAT(*result, Eq(absl::nullopt)); + auto result = ResolveReferences(&reference_map, registry, &source_info, + warnings, &expr); + + ASSERT_THAT(result, IsOkAndHolds(false)); EXPECT_THAT( warnings.warnings(), UnorderedElementsAre( @@ -399,8 +471,28 @@ TEST(ResolveReferences, "Reference map doesn't provide overloads for boolean_and")))); } +TEST(ResolveReferences, EmulatesEagerFailing) { + Expr expr = ParseTestProto(kExtensionAndExpr); + SourceInfo source_info; + + google::protobuf::Map reference_map; + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver registry("", &func_registry, &type_registry); + BuilderWarnings warnings(/*fail_eagerly=*/true); + reference_map[1].set_name("udf_boolean_and"); + + EXPECT_THAT( + ResolveReferences(&reference_map, registry, &source_info, warnings, + &expr), + StatusIs(absl::StatusCode::kInvalidArgument, + "Reference map doesn't provide overloads for boolean_and")); +} + TEST(ResolveReferences, FunctionReferenceToWrongExprKind) { Expr expr = ParseTestProto(kExtensionAndExpr); + SourceInfo source_info; + google::protobuf::Map reference_map; BuilderWarnings warnings; CelFunctionRegistry func_registry; @@ -408,9 +500,10 @@ TEST(ResolveReferences, FunctionReferenceToWrongExprKind) { Resolver registry("", &func_registry, &type_registry); reference_map[2].add_overload_id("udf_boolean_and"); - auto result = ResolveReferences(expr, reference_map, registry, &warnings); - ASSERT_OK(result); - EXPECT_THAT(*result, Eq(absl::nullopt)); + auto result = ResolveReferences(&reference_map, registry, &source_info, + warnings, &expr); + + ASSERT_THAT(result, IsOkAndHolds(false)); EXPECT_THAT(warnings.warnings(), ElementsAre(StatusCodeIs(absl::StatusCode::kInvalidArgument))); } @@ -435,6 +528,8 @@ call_expr { TEST(ResolveReferences, FunctionReferenceWithTargetNoChange) { Expr expr = ParseTestProto(kReceiverCallExtensionAndExpr); + SourceInfo source_info; + google::protobuf::Map reference_map; BuilderWarnings warnings; CelFunctionRegistry func_registry; @@ -444,15 +539,18 @@ TEST(ResolveReferences, FunctionReferenceWithTargetNoChange) { Resolver registry("", &func_registry, &type_registry); reference_map[1].add_overload_id("udf_boolean_and"); - auto result = ResolveReferences(expr, reference_map, registry, &warnings); - ASSERT_OK(result); - EXPECT_THAT(*result, Eq(absl::nullopt)); + auto result = ResolveReferences(&reference_map, registry, &source_info, + warnings, &expr); + + ASSERT_THAT(result, IsOkAndHolds(false)); EXPECT_THAT(warnings.warnings(), IsEmpty()); } TEST(ResolveReferences, FunctionReferenceWithTargetNoChangeMissingOverloadDetected) { Expr expr = ParseTestProto(kReceiverCallExtensionAndExpr); + SourceInfo source_info; + google::protobuf::Map reference_map; BuilderWarnings warnings; CelFunctionRegistry func_registry; @@ -460,15 +558,18 @@ TEST(ResolveReferences, Resolver registry("", &func_registry, &type_registry); reference_map[1].add_overload_id("udf_boolean_and"); - auto result = ResolveReferences(expr, reference_map, registry, &warnings); - ASSERT_OK(result); - EXPECT_THAT(*result, Eq(absl::nullopt)); + auto result = ResolveReferences(&reference_map, registry, &source_info, + warnings, &expr); + + ASSERT_THAT(result, IsOkAndHolds(false)); EXPECT_THAT(warnings.warnings(), ElementsAre(StatusCodeIs(absl::StatusCode::kInvalidArgument))); } TEST(ResolveReferences, FunctionReferenceWithTargetToNamespacedFunction) { Expr expr = ParseTestProto(kReceiverCallExtensionAndExpr); + SourceInfo source_info; + google::protobuf::Map reference_map; BuilderWarnings warnings; CelFunctionRegistry func_registry; @@ -478,24 +579,28 @@ TEST(ResolveReferences, FunctionReferenceWithTargetToNamespacedFunction) { Resolver registry("", &func_registry, &type_registry); reference_map[1].add_overload_id("udf_boolean_and"); - auto result = ResolveReferences(expr, reference_map, registry, &warnings); - ASSERT_OK(result); - EXPECT_THAT(*result, Optional(EqualsProto(R"pb( - id: 1 - call_expr { - function: "ext.boolean_and" - args { - id: 3 - const_expr { bool_value: false } - } - } - )pb"))); + auto result = ResolveReferences(&reference_map, registry, &source_info, + warnings, &expr); + + ASSERT_THAT(result, IsOkAndHolds(true)); + EXPECT_THAT(expr, EqualsProto(R"pb( + id: 1 + call_expr { + function: "ext.boolean_and" + args { + id: 3 + const_expr { bool_value: false } + } + } + )pb")); EXPECT_THAT(warnings.warnings(), IsEmpty()); } TEST(ResolveReferences, FunctionReferenceWithTargetToNamespacedFunctionInContainer) { Expr expr = ParseTestProto(kReceiverCallExtensionAndExpr); + SourceInfo source_info; + google::protobuf::Map reference_map; reference_map[1].add_overload_id("udf_boolean_and"); BuilderWarnings warnings; @@ -504,19 +609,20 @@ TEST(ResolveReferences, "com.google.ext.boolean_and", false, {CelValue::Type::kBool}))); CelTypeRegistry type_registry; Resolver registry("com.google", &func_registry, &type_registry); - - auto result = ResolveReferences(expr, reference_map, registry, &warnings); - ASSERT_OK(result); - EXPECT_THAT(*result, Optional(EqualsProto(R"pb( - id: 1 - call_expr { - function: "com.google.ext.boolean_and" - args { - id: 3 - const_expr { bool_value: false } - } - } - )pb"))); + auto result = ResolveReferences(&reference_map, registry, &source_info, + warnings, &expr); + + ASSERT_THAT(result, IsOkAndHolds(true)); + EXPECT_THAT(expr, EqualsProto(R"pb( + id: 1 + call_expr { + function: "com.google.ext.boolean_and" + args { + id: 3 + const_expr { bool_value: false } + } + } + )pb")); EXPECT_THAT(warnings.warnings(), IsEmpty()); } @@ -548,6 +654,8 @@ call_expr { TEST(ResolveReferences, FunctionReferenceWithHasTargetNoChange) { Expr expr = ParseTestProto(kReceiverCallHasExtensionAndExpr); + SourceInfo source_info; + google::protobuf::Map reference_map; BuilderWarnings warnings; CelFunctionRegistry func_registry; @@ -559,10 +667,12 @@ TEST(ResolveReferences, FunctionReferenceWithHasTargetNoChange) { Resolver registry("", &func_registry, &type_registry); reference_map[1].add_overload_id("udf_boolean_and"); - auto result = ResolveReferences(expr, reference_map, registry, &warnings); - ASSERT_OK(result); + auto result = ResolveReferences(&reference_map, registry, &source_info, + warnings, &expr); + + ASSERT_THAT(result, IsOkAndHolds(false)); // The target is unchanged because it is a test_only select. - EXPECT_THAT(*result, Eq(absl::nullopt)); + EXPECT_THAT(expr, EqualsProto(kReceiverCallHasExtensionAndExpr)); EXPECT_THAT(warnings.warnings(), IsEmpty()); } @@ -635,6 +745,9 @@ comprehension_expr: { )"; TEST(ResolveReferences, EnumConstReferenceUsedInComprehension) { Expr expr = ParseTestProto(kComprehensionExpr); + + SourceInfo source_info; + google::protobuf::Map reference_map; CelFunctionRegistry func_registry; ASSERT_OK(RegisterBuiltinFunctions(&func_registry)); @@ -646,79 +759,81 @@ TEST(ResolveReferences, EnumConstReferenceUsedInComprehension) { reference_map[7].mutable_value()->set_int64_value(2); BuilderWarnings warnings; - auto result = ResolveReferences(expr, reference_map, registry, &warnings); - ASSERT_OK(result); - EXPECT_THAT(*result, Optional(EqualsProto(R"pb( - id: 17 - comprehension_expr { - iter_var: "i" - iter_range { - id: 1 - list_expr { - elements { - id: 2 - const_expr { int64_value: 1 } - } - elements { - id: 3 - const_expr { int64_value: 2 } - } - elements { - id: 4 - const_expr { int64_value: 3 } - } - } - } - accu_var: "__result__" - accu_init { - id: 10 - const_expr { bool_value: false } - } - loop_condition { - id: 13 - call_expr { - function: "@not_strictly_false" - args { - id: 12 - call_expr { - function: "!_" - args { - id: 11 - ident_expr { name: "__result__" } - } - } - } - } - } - loop_step { - id: 15 - call_expr { - function: "_||_" - args { - id: 14 - ident_expr { name: "__result__" } - } - args { - id: 8 - call_expr { - function: "_==_" - args { - id: 7 - const_expr { int64_value: 2 } - } - args { - id: 9 - ident_expr { name: "i" } - } - } - } - } - } - result { - id: 16 - ident_expr { name: "__result__" } - } - })pb"))); + auto result = ResolveReferences(&reference_map, registry, &source_info, + warnings, &expr); + + ASSERT_THAT(result, IsOkAndHolds(true)); + EXPECT_THAT(expr, EqualsProto(R"pb( + id: 17 + comprehension_expr { + iter_var: "i" + iter_range { + id: 1 + list_expr { + elements { + id: 2 + const_expr { int64_value: 1 } + } + elements { + id: 3 + const_expr { int64_value: 2 } + } + elements { + id: 4 + const_expr { int64_value: 3 } + } + } + } + accu_var: "__result__" + accu_init { + id: 10 + const_expr { bool_value: false } + } + loop_condition { + id: 13 + call_expr { + function: "@not_strictly_false" + args { + id: 12 + call_expr { + function: "!_" + args { + id: 11 + ident_expr { name: "__result__" } + } + } + } + } + } + loop_step { + id: 15 + call_expr { + function: "_||_" + args { + id: 14 + ident_expr { name: "__result__" } + } + args { + id: 8 + call_expr { + function: "_==_" + args { + id: 7 + const_expr { int64_value: 2 } + } + args { + id: 9 + ident_expr { name: "i" } + } + } + } + } + } + result { + id: 16 + ident_expr { name: "__result__" } + } + })pb")); } } // namespace diff --git a/eval/eval/expression_build_warning.cc b/eval/eval/expression_build_warning.cc index dc634fb00..b7fba14a3 100644 --- a/eval/eval/expression_build_warning.cc +++ b/eval/eval/expression_build_warning.cc @@ -3,10 +3,13 @@ namespace google::api::expr::runtime { absl::Status BuilderWarnings::AddWarning(const absl::Status& warning) { + // Track errors + warnings_.push_back(warning); + if (fail_immediately_) { return warning; } - warnings_.push_back(warning); + return absl::OkStatus(); } diff --git a/eval/eval/expression_build_warning.h b/eval/eval/expression_build_warning.h index db3e88da8..59d192bda 100644 --- a/eval/eval/expression_build_warning.h +++ b/eval/eval/expression_build_warning.h @@ -1,6 +1,7 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_EXPRESSION_BUILD_WARNING_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_EXPRESSION_BUILD_WARNING_H_ +#include #include #include "absl/status/status.h" @@ -17,8 +18,12 @@ class BuilderWarnings { // set. absl::Status AddWarning(const absl::Status& warning); + bool fail_immediately() const { return fail_immediately_; } + // Return the list of recorded warnings. - const std::vector& warnings() const { return warnings_; } + const std::vector& warnings() const& { return warnings_; } + + std::vector&& warnings() && { return std::move(warnings_); } private: std::vector warnings_; diff --git a/eval/public/cel_expr_builder_factory.cc b/eval/public/cel_expr_builder_factory.cc index 017521457..1fb0f23a5 100644 --- a/eval/public/cel_expr_builder_factory.cc +++ b/eval/public/cel_expr_builder_factory.cc @@ -68,6 +68,8 @@ std::unique_ptr CreateCelExpressionBuilder( options.enable_empty_wrapper_null_unboxing); builder->set_enable_heterogeneous_equality( options.enable_heterogeneous_equality); + builder->set_enable_qualified_identifier_rewrites( + options.enable_qualified_identifier_rewrites); switch (options.unknown_processing) { case UnknownProcessingOptions::kAttributeAndFunction: diff --git a/eval/public/cel_options.h b/eval/public/cel_options.h index d354b952d..38f0511c8 100644 --- a/eval/public/cel_options.h +++ b/eval/public/cel_options.h @@ -131,6 +131,13 @@ struct InterpreterOptions { // that will result in a Null cel value, as opposed to returning the // cel representation of the proto defined default int64_t: 0. bool enable_empty_wrapper_null_unboxing = false; + + // Enables expression rewrites to disambiguate namespace qualified identifiers + // from container access for variables and receiver-style calls for functions. + // + // Note: This makes an implicit copy of the input expression for lifetime + // safety. + bool enable_qualified_identifier_rewrites = false; }; } // namespace google::api::expr::runtime From 667279c20ad6a6d20eeaf31ffc2bb6f04f16d5f9 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 13 Apr 2022 21:12:38 +0000 Subject: [PATCH 065/111] Branch field access implementation to remove direct dependency on CelProtoWrapper. PiperOrigin-RevId: 441576102 --- eval/public/structs/BUILD | 41 + eval/public/structs/field_access_impl.cc | 745 ++++++++++++++++++ eval/public/structs/field_access_impl.h | 88 +++ eval/public/structs/field_access_impl_test.cc | 647 +++++++++++++++ 4 files changed, 1521 insertions(+) create mode 100644 eval/public/structs/field_access_impl.cc create mode 100644 eval/public/structs/field_access_impl.h create mode 100644 eval/public/structs/field_access_impl_test.cc diff --git a/eval/public/structs/BUILD b/eval/public/structs/BUILD index 72078deb9..87fa55fb3 100644 --- a/eval/public/structs/BUILD +++ b/eval/public/structs/BUILD @@ -91,6 +91,47 @@ cc_test( ], ) +cc_library( + name = "field_access_impl", + srcs = [ + "field_access_impl.cc", + ], + hdrs = [ + "field_access_impl.h", + ], + deps = [ + ":cel_proto_wrap_util", + ":protobuf_value_factory", + "//eval/public:cel_value", + "//internal:casts", + "//internal:overflow", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "field_access_impl_test", + srcs = ["field_access_impl_test.cc"], + deps = [ + ":cel_proto_wrapper", + ":field_access_impl", + "//eval/public:cel_value", + "//eval/public/testing:matchers", + "//eval/testutil:test_message_cc_proto", + "//internal:testing", + "//internal:time", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@com_google_cel_spec//proto/test/v1/proto3:test_all_types_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + cc_library( name = "cel_proto_descriptor_pool_builder", srcs = ["cel_proto_descriptor_pool_builder.cc"], diff --git a/eval/public/structs/field_access_impl.cc b/eval/public/structs/field_access_impl.cc new file mode 100644 index 000000000..9f8faf7ba --- /dev/null +++ b/eval/public/structs/field_access_impl.cc @@ -0,0 +1,745 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/structs/field_access_impl.h" + +#include +#include +#include +#include + +#include "google/protobuf/any.pb.h" +#include "google/protobuf/struct.pb.h" +#include "google/protobuf/wrappers.pb.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/map_field.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "eval/public/structs/cel_proto_wrap_util.h" +#include "internal/casts.h" +#include "internal/overflow.h" + +namespace google::api::expr::runtime::internal { + +namespace { + +using ::google::protobuf::Arena; +using ::google::protobuf::FieldDescriptor; +using ::google::protobuf::MapValueConstRef; +using ::google::protobuf::Message; +using ::google::protobuf::Reflection; + +// Well-known type protobuf type names which require special get / set behavior. +constexpr absl::string_view kProtobufAny = "google.protobuf.Any"; +constexpr absl::string_view kTypeGoogleApisComPrefix = "type.googleapis.com/"; + +// Singular message fields and repeated message fields have similar access model +// To provide common approach, we implement accessor classes, based on CRTP. +// FieldAccessor is CRTP base class, specifying Get.. method family. +template +class FieldAccessor { + public: + bool GetBool() const { return static_cast(this)->GetBool(); } + + int64_t GetInt32() const { + return static_cast(this)->GetInt32(); + } + + uint64_t GetUInt32() const { + return static_cast(this)->GetUInt32(); + } + + int64_t GetInt64() const { + return static_cast(this)->GetInt64(); + } + + uint64_t GetUInt64() const { + return static_cast(this)->GetUInt64(); + } + + double GetFloat() const { + return static_cast(this)->GetFloat(); + } + + double GetDouble() const { + return static_cast(this)->GetDouble(); + } + + const std::string* GetString(std::string* buffer) const { + return static_cast(this)->GetString(buffer); + } + + const Message* GetMessage() const { + return static_cast(this)->GetMessage(); + } + + int64_t GetEnumValue() const { + return static_cast(this)->GetEnumValue(); + } + + // This method provides message field content, wrapped in CelValue. + // If value provided successfully, return a CelValue, otherwise returns a + // status with non-ok status code. + // + // arena Arena to use for allocations if needed. + absl::StatusOr CreateValueFromFieldAccessor(Arena* arena) { + switch (field_desc_->cpp_type()) { + case FieldDescriptor::CPPTYPE_BOOL: { + bool value = GetBool(); + return CelValue::CreateBool(value); + } + case FieldDescriptor::CPPTYPE_INT32: { + int64_t value = GetInt32(); + return CelValue::CreateInt64(value); + } + case FieldDescriptor::CPPTYPE_INT64: { + int64_t value = GetInt64(); + return CelValue::CreateInt64(value); + } + case FieldDescriptor::CPPTYPE_UINT32: { + uint64_t value = GetUInt32(); + return CelValue::CreateUint64(value); + } + case FieldDescriptor::CPPTYPE_UINT64: { + uint64_t value = GetUInt64(); + return CelValue::CreateUint64(value); + } + case FieldDescriptor::CPPTYPE_FLOAT: { + double value = GetFloat(); + return CelValue::CreateDouble(value); + } + case FieldDescriptor::CPPTYPE_DOUBLE: { + double value = GetDouble(); + return CelValue::CreateDouble(value); + } + case FieldDescriptor::CPPTYPE_STRING: { + std::string buffer; + const std::string* value = GetString(&buffer); + if (value == &buffer) { + value = google::protobuf::Arena::Create(arena, std::move(buffer)); + } + switch (field_desc_->type()) { + case FieldDescriptor::TYPE_STRING: + return CelValue::CreateString(value); + case FieldDescriptor::TYPE_BYTES: + return CelValue::CreateBytes(value); + default: + return absl::Status(absl::StatusCode::kInvalidArgument, + "Error handling C++ string conversion"); + } + break; + } + case FieldDescriptor::CPPTYPE_MESSAGE: { + const google::protobuf::Message* msg_value = GetMessage(); + return UnwrapMessageToValue(msg_value, protobuf_value_factory_, arena); + } + case FieldDescriptor::CPPTYPE_ENUM: { + int enum_value = GetEnumValue(); + return CelValue::CreateInt64(enum_value); + } + default: + return absl::Status(absl::StatusCode::kInvalidArgument, + "Unhandled C++ type conversion"); + } + return absl::Status(absl::StatusCode::kInvalidArgument, + "Unhandled C++ type conversion"); + } + + protected: + FieldAccessor(const Message* msg, const FieldDescriptor* field_desc, + const ProtobufValueFactory& protobuf_value_factory) + : msg_(msg), + field_desc_(field_desc), + protobuf_value_factory_(protobuf_value_factory) {} + + const Message* msg_; + const FieldDescriptor* field_desc_; + const ProtobufValueFactory& protobuf_value_factory_; +}; + +const absl::flat_hash_set& WellKnownWrapperTypes() { + static auto* wrapper_types = new absl::flat_hash_set{ + "google.protobuf.BoolValue", "google.protobuf.DoubleValue", + "google.protobuf.FloatValue", "google.protobuf.Int64Value", + "google.protobuf.Int32Value", "google.protobuf.UInt64Value", + "google.protobuf.UInt32Value", "google.protobuf.StringValue", + "google.protobuf.BytesValue", + }; + return *wrapper_types; +} + +bool IsWrapperType(const FieldDescriptor* field_descriptor) { + return WellKnownWrapperTypes().find( + field_descriptor->message_type()->full_name()) != + WellKnownWrapperTypes().end(); +} + +// Accessor class, to work with singular fields +class ScalarFieldAccessor : public FieldAccessor { + public: + ScalarFieldAccessor(const Message* msg, const FieldDescriptor* field_desc, + bool unset_wrapper_as_null, + const ProtobufValueFactory& factory) + : FieldAccessor(msg, field_desc, factory), + unset_wrapper_as_null_(unset_wrapper_as_null) {} + + bool GetBool() const { return GetReflection()->GetBool(*msg_, field_desc_); } + + int64_t GetInt32() const { + return GetReflection()->GetInt32(*msg_, field_desc_); + } + + uint64_t GetUInt32() const { + return GetReflection()->GetUInt32(*msg_, field_desc_); + } + + int64_t GetInt64() const { + return GetReflection()->GetInt64(*msg_, field_desc_); + } + + uint64_t GetUInt64() const { + return GetReflection()->GetUInt64(*msg_, field_desc_); + } + + double GetFloat() const { + return GetReflection()->GetFloat(*msg_, field_desc_); + } + + double GetDouble() const { + return GetReflection()->GetDouble(*msg_, field_desc_); + } + + const std::string* GetString(std::string* buffer) const { + return &GetReflection()->GetStringReference(*msg_, field_desc_, buffer); + } + + const Message* GetMessage() const { + // Unset wrapper types have special semantics. + // If set, return the unwrapped value, else return 'null'. + if (unset_wrapper_as_null_ && + !GetReflection()->HasField(*msg_, field_desc_) && + IsWrapperType(field_desc_)) { + return nullptr; + } + return &GetReflection()->GetMessage(*msg_, field_desc_); + } + + int64_t GetEnumValue() const { + return GetReflection()->GetEnumValue(*msg_, field_desc_); + } + + const Reflection* GetReflection() const { return msg_->GetReflection(); } + + private: + bool unset_wrapper_as_null_; +}; + +// Accessor class, to work with repeated fields. +class RepeatedFieldAccessor : public FieldAccessor { + public: + RepeatedFieldAccessor(const Message* msg, const FieldDescriptor* field_desc, + int index, const ProtobufValueFactory& factory) + : FieldAccessor(msg, field_desc, factory), index_(index) {} + + bool GetBool() const { + return GetReflection()->GetRepeatedBool(*msg_, field_desc_, index_); + } + + int64_t GetInt32() const { + return GetReflection()->GetRepeatedInt32(*msg_, field_desc_, index_); + } + + uint64_t GetUInt32() const { + return GetReflection()->GetRepeatedUInt32(*msg_, field_desc_, index_); + } + + int64_t GetInt64() const { + return GetReflection()->GetRepeatedInt64(*msg_, field_desc_, index_); + } + + uint64_t GetUInt64() const { + return GetReflection()->GetRepeatedUInt64(*msg_, field_desc_, index_); + } + + double GetFloat() const { + return GetReflection()->GetRepeatedFloat(*msg_, field_desc_, index_); + } + + double GetDouble() const { + return GetReflection()->GetRepeatedDouble(*msg_, field_desc_, index_); + } + + const std::string* GetString(std::string* buffer) const { + return &GetReflection()->GetRepeatedStringReference(*msg_, field_desc_, + index_, buffer); + } + + const Message* GetMessage() const { + return &GetReflection()->GetRepeatedMessage(*msg_, field_desc_, index_); + } + + int64_t GetEnumValue() const { + return GetReflection()->GetRepeatedEnumValue(*msg_, field_desc_, index_); + } + + const Reflection* GetReflection() const { return msg_->GetReflection(); } + + private: + int index_; +}; + +// Accessor class, to work with map values +class MapValueAccessor : public FieldAccessor { + public: + MapValueAccessor(const Message* msg, const FieldDescriptor* field_desc, + const MapValueConstRef* value_ref, + const ProtobufValueFactory& factory) + : FieldAccessor(msg, field_desc, factory), value_ref_(value_ref) {} + + bool GetBool() const { return value_ref_->GetBoolValue(); } + + int64_t GetInt32() const { return value_ref_->GetInt32Value(); } + + uint64_t GetUInt32() const { return value_ref_->GetUInt32Value(); } + + int64_t GetInt64() const { return value_ref_->GetInt64Value(); } + + uint64_t GetUInt64() const { return value_ref_->GetUInt64Value(); } + + double GetFloat() const { return value_ref_->GetFloatValue(); } + + double GetDouble() const { return value_ref_->GetDoubleValue(); } + + const std::string* GetString(std::string* /*buffer*/) const { + return &value_ref_->GetStringValue(); + } + + const Message* GetMessage() const { return &value_ref_->GetMessageValue(); } + + int64_t GetEnumValue() const { return value_ref_->GetEnumValue(); } + + const Reflection* GetReflection() const { return msg_->GetReflection(); } + + private: + const MapValueConstRef* value_ref_; +}; + +// Singular message fields and repeated message fields have similar access model +// To provide common approach, we implement field setter classes, based on CRTP. +// FieldAccessor is CRTP base class, specifying Get.. method family. +template +class FieldSetter { + public: + bool AssignBool(const CelValue& cel_value) const { + bool value; + + if (!cel_value.GetValue(&value)) { + return false; + } + static_cast(this)->SetBool(value); + return true; + } + + bool AssignInt32(const CelValue& cel_value) const { + int64_t value; + if (!cel_value.GetValue(&value)) { + return false; + } + absl::StatusOr checked_cast = + cel::internal::CheckedInt64ToInt32(value); + if (!checked_cast.ok()) { + return false; + } + static_cast(this)->SetInt32(*checked_cast); + return true; + } + + bool AssignUInt32(const CelValue& cel_value) const { + uint64_t value; + if (!cel_value.GetValue(&value)) { + return false; + } + if (!cel::internal::CheckedUint64ToUint32(value).ok()) { + return false; + } + static_cast(this)->SetUInt32(value); + return true; + } + + bool AssignInt64(const CelValue& cel_value) const { + int64_t value; + if (!cel_value.GetValue(&value)) { + return false; + } + static_cast(this)->SetInt64(value); + return true; + } + + bool AssignUInt64(const CelValue& cel_value) const { + uint64_t value; + if (!cel_value.GetValue(&value)) { + return false; + } + static_cast(this)->SetUInt64(value); + return true; + } + + bool AssignFloat(const CelValue& cel_value) const { + double value; + if (!cel_value.GetValue(&value)) { + return false; + } + static_cast(this)->SetFloat(value); + return true; + } + + bool AssignDouble(const CelValue& cel_value) const { + double value; + if (!cel_value.GetValue(&value)) { + return false; + } + static_cast(this)->SetDouble(value); + return true; + } + + bool AssignString(const CelValue& cel_value) const { + CelValue::StringHolder value; + if (!cel_value.GetValue(&value)) { + return false; + } + static_cast(this)->SetString(value); + return true; + } + + bool AssignBytes(const CelValue& cel_value) const { + CelValue::BytesHolder value; + if (!cel_value.GetValue(&value)) { + return false; + } + static_cast(this)->SetBytes(value); + return true; + } + + bool AssignEnum(const CelValue& cel_value) const { + int64_t value; + if (!cel_value.GetValue(&value)) { + return false; + } + if (!cel::internal::CheckedInt64ToInt32(value).ok()) { + return false; + } + static_cast(this)->SetEnum(value); + return true; + } + + bool AssignMessage(const google::protobuf::Message* message) const { + return static_cast(this)->SetMessage(message); + } + + // This method provides message field content, wrapped in CelValue. + // If value provided successfully, returns Ok. + // arena Arena to use for allocations if needed. + // result pointer to object to store value in. + bool SetFieldFromCelValue(const CelValue& value) { + switch (field_desc_->cpp_type()) { + case FieldDescriptor::CPPTYPE_BOOL: { + return AssignBool(value); + } + case FieldDescriptor::CPPTYPE_INT32: { + return AssignInt32(value); + } + case FieldDescriptor::CPPTYPE_INT64: { + return AssignInt64(value); + } + case FieldDescriptor::CPPTYPE_UINT32: { + return AssignUInt32(value); + } + case FieldDescriptor::CPPTYPE_UINT64: { + return AssignUInt64(value); + } + case FieldDescriptor::CPPTYPE_FLOAT: { + return AssignFloat(value); + } + case FieldDescriptor::CPPTYPE_DOUBLE: { + return AssignDouble(value); + } + case FieldDescriptor::CPPTYPE_STRING: { + switch (field_desc_->type()) { + case FieldDescriptor::TYPE_STRING: + + return AssignString(value); + case FieldDescriptor::TYPE_BYTES: + return AssignBytes(value); + default: + return false; + } + break; + } + case FieldDescriptor::CPPTYPE_MESSAGE: { + // When the field is a message, it might be a well-known type with a + // non-proto representation that requires special handling before it + // can be set on the field. + const google::protobuf::Message* wrapped_value = + MaybeWrapValueToMessage(field_desc_->message_type(), value, arena_); + if (wrapped_value == nullptr) { + // It we aren't unboxing to a protobuf null representation, setting a + // field to null is a no-op. + if (value.IsNull()) { + return true; + } + if (CelValue::MessageWrapper wrapper; + value.GetValue(&wrapper) && wrapper.HasFullProto()) { + wrapped_value = cel::internal::down_cast( + wrapper.message_ptr()); + } else { + return false; + } + } + + return AssignMessage(wrapped_value); + } + case FieldDescriptor::CPPTYPE_ENUM: { + return AssignEnum(value); + } + default: + return false; + } + + return true; + } + + protected: + FieldSetter(Message* msg, const FieldDescriptor* field_desc, Arena* arena) + : msg_(msg), field_desc_(field_desc), arena_(arena) {} + + Message* msg_; + const FieldDescriptor* field_desc_; + Arena* arena_; +}; + +// Accessor class, to work with singular fields +class ScalarFieldSetter : public FieldSetter { + public: + ScalarFieldSetter(Message* msg, const FieldDescriptor* field_desc, + Arena* arena) + : FieldSetter(msg, field_desc, arena) {} + + bool SetBool(bool value) const { + GetReflection()->SetBool(msg_, field_desc_, value); + return true; + } + + bool SetInt32(int32_t value) const { + GetReflection()->SetInt32(msg_, field_desc_, value); + return true; + } + + bool SetUInt32(uint32_t value) const { + GetReflection()->SetUInt32(msg_, field_desc_, value); + return true; + } + + bool SetInt64(int64_t value) const { + GetReflection()->SetInt64(msg_, field_desc_, value); + return true; + } + + bool SetUInt64(uint64_t value) const { + GetReflection()->SetUInt64(msg_, field_desc_, value); + return true; + } + + bool SetFloat(float value) const { + GetReflection()->SetFloat(msg_, field_desc_, value); + return true; + } + + bool SetDouble(double value) const { + GetReflection()->SetDouble(msg_, field_desc_, value); + return true; + } + + bool SetString(CelValue::StringHolder value) const { + GetReflection()->SetString(msg_, field_desc_, std::string(value.value())); + return true; + } + + bool SetBytes(CelValue::BytesHolder value) const { + GetReflection()->SetString(msg_, field_desc_, std::string(value.value())); + return true; + } + + bool SetMessage(const Message* value) const { + if (!value) { + GOOGLE_LOG(ERROR) << "Message is NULL"; + return true; + } + + if (value->GetDescriptor()->full_name() == + field_desc_->message_type()->full_name()) { + GetReflection()->MutableMessage(msg_, field_desc_)->MergeFrom(*value); + return true; + + } else if (field_desc_->message_type()->full_name() == kProtobufAny) { + auto any_msg = google::protobuf::DynamicCastToGenerated( + GetReflection()->MutableMessage(msg_, field_desc_)); + if (any_msg == nullptr) { + // TODO(issues/68): This is probably a dynamic message. We should + // implement this once we add support for dynamic protobuf types. + return false; + } + any_msg->set_type_url(absl::StrCat(kTypeGoogleApisComPrefix, + value->GetDescriptor()->full_name())); + return value->SerializeToString(any_msg->mutable_value()); + } + return false; + } + + bool SetEnum(const int64_t value) const { + GetReflection()->SetEnumValue(msg_, field_desc_, value); + return true; + } + + const Reflection* GetReflection() const { return msg_->GetReflection(); } +}; + +// Appender class, to work with repeated fields +class RepeatedFieldSetter : public FieldSetter { + public: + RepeatedFieldSetter(Message* msg, const FieldDescriptor* field_desc, + Arena* arena) + : FieldSetter(msg, field_desc, arena) {} + + bool SetBool(bool value) const { + GetReflection()->AddBool(msg_, field_desc_, value); + return true; + } + + bool SetInt32(int32_t value) const { + GetReflection()->AddInt32(msg_, field_desc_, value); + return true; + } + + bool SetUInt32(uint32_t value) const { + GetReflection()->AddUInt32(msg_, field_desc_, value); + return true; + } + + bool SetInt64(int64_t value) const { + GetReflection()->AddInt64(msg_, field_desc_, value); + return true; + } + + bool SetUInt64(uint64_t value) const { + GetReflection()->AddUInt64(msg_, field_desc_, value); + return true; + } + + bool SetFloat(float value) const { + GetReflection()->AddFloat(msg_, field_desc_, value); + return true; + } + + bool SetDouble(double value) const { + GetReflection()->AddDouble(msg_, field_desc_, value); + return true; + } + + bool SetString(CelValue::StringHolder value) const { + GetReflection()->AddString(msg_, field_desc_, std::string(value.value())); + return true; + } + + bool SetBytes(CelValue::BytesHolder value) const { + GetReflection()->AddString(msg_, field_desc_, std::string(value.value())); + return true; + } + + bool SetMessage(const Message* value) const { + if (!value) return true; + if (value->GetDescriptor()->full_name() != + field_desc_->message_type()->full_name()) { + return false; + } + + GetReflection()->AddMessage(msg_, field_desc_)->MergeFrom(*value); + return true; + } + + bool SetEnum(const int64_t value) const { + GetReflection()->AddEnumValue(msg_, field_desc_, value); + return true; + } + + private: + const Reflection* GetReflection() const { return msg_->GetReflection(); } +}; + +} // namespace + +absl::StatusOr CreateValueFromSingleField( + const google::protobuf::Message* msg, const FieldDescriptor* desc, + ProtoWrapperTypeOptions options, const ProtobufValueFactory& factory, + google::protobuf::Arena* arena) { + ScalarFieldAccessor accessor( + msg, desc, (options == ProtoWrapperTypeOptions::kUnsetNull), factory); + return accessor.CreateValueFromFieldAccessor(arena); +} + +absl::StatusOr CreateValueFromRepeatedField( + const google::protobuf::Message* msg, const FieldDescriptor* desc, int index, + const ProtobufValueFactory& factory, google::protobuf::Arena* arena) { + RepeatedFieldAccessor accessor(msg, desc, index, factory); + return accessor.CreateValueFromFieldAccessor(arena); +} + +absl::StatusOr CreateValueFromMapValue( + const google::protobuf::Message* msg, const FieldDescriptor* desc, + const MapValueConstRef* value_ref, const ProtobufValueFactory& factory, + google::protobuf::Arena* arena) { + MapValueAccessor accessor(msg, desc, value_ref, factory); + return accessor.CreateValueFromFieldAccessor(arena); +} + +absl::Status SetValueToSingleField(const CelValue& value, + const FieldDescriptor* desc, Message* msg, + Arena* arena) { + ScalarFieldSetter setter(msg, desc, arena); + return (setter.SetFieldFromCelValue(value)) + ? absl::OkStatus() + : absl::InvalidArgumentError(absl::Substitute( + "Could not assign supplied argument to message \"$0\" field " + "\"$1\" of type $2: value type \"$3\"", + msg->GetDescriptor()->name(), desc->name(), + desc->type_name(), CelValue::TypeName(value.type()))); +} + +absl::Status AddValueToRepeatedField(const CelValue& value, + const FieldDescriptor* desc, Message* msg, + Arena* arena) { + RepeatedFieldSetter setter(msg, desc, arena); + return (setter.SetFieldFromCelValue(value)) + ? absl::OkStatus() + : absl::InvalidArgumentError(absl::Substitute( + "Could not add supplied argument to message \"$0\" field " + "\"$1\" of type $2: value type \"$3\"", + msg->GetDescriptor()->name(), desc->name(), + desc->type_name(), CelValue::TypeName(value.type()))); +} + +} // namespace google::api::expr::runtime::internal diff --git a/eval/public/structs/field_access_impl.h b/eval/public/structs/field_access_impl.h new file mode 100644 index 000000000..150280e28 --- /dev/null +++ b/eval/public/structs/field_access_impl.h @@ -0,0 +1,88 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_FIELD_ACCESS_IMPL_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_FIELD_ACCESS_IMPL_H_ + +#include "eval/public/cel_value.h" +#include "eval/public/structs/protobuf_value_factory.h" + +namespace google::api::expr::runtime::internal { + +// Options for handling unset wrapper types. +enum class ProtoWrapperTypeOptions { + // Default: legacy behavior following proto semantics (unset behaves as though + // it is set to default value). + kUnsetProtoDefault, + // CEL spec behavior, unset wrapper is treated as a null value when accessed. + kUnsetNull, +}; + +// Creates CelValue from singular message field. +// Returns status of the operation. +// msg Message containing the field. +// desc Descriptor of the field to access. +// options Option to enable treating unset wrapper type fields as null. +// arena Arena object to allocate result on, if needed. +// result pointer to CelValue to store the result in. +absl::StatusOr CreateValueFromSingleField( + const google::protobuf::Message* msg, const google::protobuf::FieldDescriptor* desc, + ProtoWrapperTypeOptions options, const ProtobufValueFactory& factory, + google::protobuf::Arena* arena); + +// Creates CelValue from repeated message field. +// Returns status of the operation. +// msg Message containing the field. +// desc Descriptor of the field to access. +// arena Arena object to allocate result on, if needed. +// index position in the repeated field. +absl::StatusOr CreateValueFromRepeatedField( + const google::protobuf::Message* msg, const google::protobuf::FieldDescriptor* desc, int index, + const ProtobufValueFactory& factory, google::protobuf::Arena* arena); + +// Creates CelValue from map message field. +// Returns status of the operation. +// msg Message containing the field. +// desc Descriptor of the field to access. +// value_ref pointer to map value. +// arena Arena object to allocate result on, if needed. +// TODO(issues/5): This should be inlined into the FieldBackedMap +// implementation. +absl::StatusOr CreateValueFromMapValue( + const google::protobuf::Message* msg, const google::protobuf::FieldDescriptor* desc, + const google::protobuf::MapValueConstRef* value_ref, + const ProtobufValueFactory& factory, google::protobuf::Arena* arena); + +// Assigns content of CelValue to singular message field. +// Returns status of the operation. +// msg Message containing the field. +// desc Descriptor of the field to access. +// arena Arena to perform allocations, if necessary, when setting the field. +absl::Status SetValueToSingleField(const CelValue& value, + const google::protobuf::FieldDescriptor* desc, + google::protobuf::Message* msg, google::protobuf::Arena* arena); + +// Adds content of CelValue to repeated message field. +// Returns status of the operation. +// msg Message containing the field. +// desc Descriptor of the field to access. +// arena Arena to perform allocations, if necessary, when adding the value. +absl::Status AddValueToRepeatedField(const CelValue& value, + const google::protobuf::FieldDescriptor* desc, + google::protobuf::Message* msg, + google::protobuf::Arena* arena); + +} // namespace google::api::expr::runtime::internal + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_FIELD_ACCESS_IMPL_H_ diff --git a/eval/public/structs/field_access_impl_test.cc b/eval/public/structs/field_access_impl_test.cc new file mode 100644 index 000000000..caa697760 --- /dev/null +++ b/eval/public/structs/field_access_impl_test.cc @@ -0,0 +1,647 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/structs/field_access_impl.h" + +#include +#include + +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "eval/public/cel_value.h" +#include "eval/public/structs/cel_proto_wrapper.h" +#include "eval/public/testing/matchers.h" +#include "eval/testutil/test_message.pb.h" +#include "internal/testing.h" +#include "internal/time.h" +#include "proto/test/v1/proto3/test_all_types.pb.h" + +namespace google::api::expr::runtime::internal { + +namespace { + +using ::cel::internal::MaxDuration; +using ::cel::internal::MaxTimestamp; +using ::google::api::expr::test::v1::proto3::TestAllTypes; +using ::google::protobuf::Arena; +using ::google::protobuf::FieldDescriptor; +using testing::EqualsProto; +using testing::HasSubstr; +using cel::internal::StatusIs; + +CelValue MessageValueFactory(const google::protobuf::Message* message) { + return CelValue::CreateMessageWrapper(CelValue::MessageWrapper(message)); +} + +TEST(FieldAccessTest, SetDuration) { + Arena arena; + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("single_duration"); + auto status = SetValueToSingleField(CelValue::CreateDuration(MaxDuration()), + field, &msg, &arena); + EXPECT_TRUE(status.ok()); +} + +TEST(FieldAccessTest, SetDurationBadDuration) { + Arena arena; + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("single_duration"); + auto status = SetValueToSingleField( + CelValue::CreateDuration(MaxDuration() + absl::Seconds(1)), field, &msg, + &arena); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); +} + +TEST(FieldAccessTest, SetDurationBadInputType) { + Arena arena; + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("single_duration"); + auto status = + SetValueToSingleField(CelValue::CreateInt64(1), field, &msg, &arena); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); +} + +TEST(FieldAccessTest, SetTimestamp) { + Arena arena; + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("single_timestamp"); + auto status = SetValueToSingleField(CelValue::CreateTimestamp(MaxTimestamp()), + field, &msg, &arena); + EXPECT_TRUE(status.ok()); +} + +TEST(FieldAccessTest, SetTimestampBadTime) { + Arena arena; + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("single_timestamp"); + auto status = SetValueToSingleField( + CelValue::CreateTimestamp(MaxTimestamp() + absl::Seconds(1)), field, &msg, + &arena); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); +} + +TEST(FieldAccessTest, SetTimestampBadInputType) { + Arena arena; + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("single_timestamp"); + auto status = + SetValueToSingleField(CelValue::CreateInt64(1), field, &msg, &arena); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); +} + +TEST(FieldAccessTest, SetInt32Overflow) { + Arena arena; + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("single_int32"); + EXPECT_THAT( + SetValueToSingleField( + CelValue::CreateInt64(std::numeric_limits::max() + 1L), + field, &msg, &arena), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Could not assign"))); +} + +TEST(FieldAccessTest, SetUint32Overflow) { + Arena arena; + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("single_uint32"); + EXPECT_THAT( + SetValueToSingleField( + CelValue::CreateUint64(std::numeric_limits::max() + 1L), + field, &msg, &arena), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Could not assign"))); +} + +TEST(FieldAccessTest, SetMessage) { + Arena arena; + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("standalone_message"); + TestAllTypes::NestedMessage* nested_msg = + google::protobuf::Arena::CreateMessage(&arena); + nested_msg->set_bb(1); + auto status = SetValueToSingleField( + CelProtoWrapper::CreateMessage(nested_msg, &arena), field, &msg, &arena); + EXPECT_TRUE(status.ok()); +} + +TEST(FieldAccessTest, SetMessageWithNull) { + Arena arena; + TestAllTypes msg; + const FieldDescriptor* field = + TestAllTypes::descriptor()->FindFieldByName("standalone_message"); + auto status = + SetValueToSingleField(CelValue::CreateNull(), field, &msg, &arena); + EXPECT_TRUE(status.ok()); +} + +struct AccessFieldTestParam { + absl::string_view field_name; + absl::string_view message_textproto; + CelValue cel_value; +}; + +std::string GetTestName( + const testing::TestParamInfo& info) { + return std::string(info.param.field_name); +} + +class SingleFieldTest : public testing::TestWithParam { + public: + absl::string_view field_name() const { return GetParam().field_name; } + absl::string_view message_textproto() const { + return GetParam().message_textproto; + } + CelValue cel_value() const { return GetParam().cel_value; } +}; + +TEST_P(SingleFieldTest, Getter) { + TestAllTypes test_message; + ASSERT_TRUE( + google::protobuf::TextFormat::ParseFromString(message_textproto().data(), &test_message)); + google::protobuf::Arena arena; + + ASSERT_OK_AND_ASSIGN( + CelValue accessed_value, + CreateValueFromSingleField( + &test_message, + test_message.GetDescriptor()->FindFieldByName(field_name().data()), + ProtoWrapperTypeOptions::kUnsetProtoDefault, &MessageValueFactory, + &arena)); + + EXPECT_THAT(accessed_value, test::EqualsCelValue(cel_value())); +} + +TEST_P(SingleFieldTest, Setter) { + TestAllTypes test_message; + CelValue to_set = cel_value(); + google::protobuf::Arena arena; + + ASSERT_OK(SetValueToSingleField( + to_set, test_message.GetDescriptor()->FindFieldByName(field_name().data()), + &test_message, &arena)); + + EXPECT_THAT(test_message, EqualsProto(message_textproto())); +} + +INSTANTIATE_TEST_SUITE_P( + AllTypes, SingleFieldTest, + testing::ValuesIn({ + {"single_int32", "single_int32: 1", CelValue::CreateInt64(1)}, + {"single_int64", "single_int64: 1", CelValue::CreateInt64(1)}, + {"single_uint32", "single_uint32: 1", CelValue::CreateUint64(1)}, + {"single_uint64", "single_uint64: 1", CelValue::CreateUint64(1)}, + {"single_sint32", "single_sint32: 1", CelValue::CreateInt64(1)}, + {"single_sint64", "single_sint64: 1", CelValue::CreateInt64(1)}, + {"single_fixed32", "single_fixed32: 1", CelValue::CreateUint64(1)}, + {"single_fixed64", "single_fixed64: 1", CelValue::CreateUint64(1)}, + {"single_sfixed32", "single_sfixed32: 1", CelValue::CreateInt64(1)}, + {"single_sfixed64", "single_sfixed64: 1", CelValue::CreateInt64(1)}, + {"single_float", "single_float: 1.0", CelValue::CreateDouble(1.0)}, + {"single_double", "single_double: 1.0", CelValue::CreateDouble(1.0)}, + {"single_bool", "single_bool: true", CelValue::CreateBool(true)}, + {"single_string", "single_string: 'abcd'", + CelValue::CreateStringView("abcd")}, + {"single_bytes", "single_bytes: 'asdf'", + CelValue::CreateBytesView("asdf")}, + {"standalone_enum", "standalone_enum: BAZ", CelValue::CreateInt64(2)}, + // Basic coverage for unwrapping -- specifics are managed by the + // wrapping library. + {"single_int64_wrapper", "single_int64_wrapper { value: 20 }", + CelValue::CreateInt64(20)}, + {"single_value", "single_value { null_value: NULL_VALUE }", + CelValue::CreateNull()}, + }), + &GetTestName); + +TEST(CreateValueFromSingleFieldTest, GetMessage) { + TestAllTypes test_message; + google::protobuf::Arena arena; + + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + "standalone_message { bb: 10 }", &test_message)); + + ASSERT_OK_AND_ASSIGN( + CelValue accessed_value, + CreateValueFromSingleField( + &test_message, + test_message.GetDescriptor()->FindFieldByName("standalone_message"), + ProtoWrapperTypeOptions::kUnsetProtoDefault, &MessageValueFactory, + &arena)); + + EXPECT_THAT(accessed_value, test::IsCelMessage(EqualsProto("bb: 10"))); +} + +TEST(SetValueToSingleFieldTest, WrongType) { + TestAllTypes test_message; + google::protobuf::Arena arena; + + EXPECT_THAT(SetValueToSingleField( + CelValue::CreateDouble(1.0), + test_message.GetDescriptor()->FindFieldByName("single_int32"), + &test_message, &arena), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(SetValueToSingleFieldTest, IntOutOfRange) { + CelValue out_of_range = CelValue::CreateInt64(1LL << 31); + TestAllTypes test_message; + const google::protobuf::Descriptor* descriptor = test_message.GetDescriptor(); + google::protobuf::Arena arena; + + EXPECT_THAT(SetValueToSingleField(out_of_range, + descriptor->FindFieldByName("single_int32"), + &test_message, &arena), + StatusIs(absl::StatusCode::kInvalidArgument)); + + // proto enums are are represented as int32_t, but CEL converts to/from int64_t. + EXPECT_THAT(SetValueToSingleField( + out_of_range, descriptor->FindFieldByName("standalone_enum"), + &test_message, &arena), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(SetValueToSingleFieldTest, UintOutOfRange) { + CelValue out_of_range = CelValue::CreateUint64(1LL << 32); + TestAllTypes test_message; + const google::protobuf::Descriptor* descriptor = test_message.GetDescriptor(); + google::protobuf::Arena arena; + + EXPECT_THAT(SetValueToSingleField( + out_of_range, descriptor->FindFieldByName("single_uint32"), + &test_message, &arena), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(SetValueToSingleFieldTest, SetMessage) { + TestAllTypes::NestedMessage nested_message; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( + bb: 42 + )", + &nested_message)); + google::protobuf::Arena arena; + CelValue nested_value = + CelProtoWrapper::CreateMessage(&nested_message, &arena); + TestAllTypes test_message; + const google::protobuf::Descriptor* descriptor = test_message.GetDescriptor(); + + ASSERT_OK(SetValueToSingleField( + nested_value, descriptor->FindFieldByName("standalone_message"), + &test_message, &arena)); + EXPECT_THAT(test_message, EqualsProto("standalone_message { bb: 42 }")); +} + +TEST(SetValueToSingleFieldTest, SetAnyMessage) { + TestAllTypes::NestedMessage nested_message; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( + bb: 42 + )", + &nested_message)); + google::protobuf::Arena arena; + CelValue nested_value = + CelProtoWrapper::CreateMessage(&nested_message, &arena); + TestAllTypes test_message; + const google::protobuf::Descriptor* descriptor = test_message.GetDescriptor(); + + ASSERT_OK(SetValueToSingleField(nested_value, + descriptor->FindFieldByName("single_any"), + &test_message, &arena)); + + TestAllTypes::NestedMessage unpacked; + test_message.single_any().UnpackTo(&unpacked); + EXPECT_THAT(unpacked, EqualsProto("bb: 42")); +} + +TEST(SetValueToSingleFieldTest, SetMessageToNullNoop) { + google::protobuf::Arena arena; + TestAllTypes test_message; + const google::protobuf::Descriptor* descriptor = test_message.GetDescriptor(); + + ASSERT_OK(SetValueToSingleField( + CelValue::CreateNull(), descriptor->FindFieldByName("standalone_message"), + &test_message, &arena)); + EXPECT_THAT(test_message, EqualsProto(test_message.default_instance())); +} + +class RepeatedFieldTest : public testing::TestWithParam { + public: + absl::string_view field_name() const { return GetParam().field_name; } + absl::string_view message_textproto() const { + return GetParam().message_textproto; + } + CelValue cel_value() const { return GetParam().cel_value; } +}; + +TEST_P(RepeatedFieldTest, GetFirstElem) { + TestAllTypes test_message; + ASSERT_TRUE( + google::protobuf::TextFormat::ParseFromString(message_textproto().data(), &test_message)); + google::protobuf::Arena arena; + + ASSERT_OK_AND_ASSIGN( + CelValue accessed_value, + CreateValueFromRepeatedField( + &test_message, + test_message.GetDescriptor()->FindFieldByName(field_name().data()), 0, + &MessageValueFactory, &arena)); + + EXPECT_THAT(accessed_value, test::EqualsCelValue(cel_value())); +} + +TEST_P(RepeatedFieldTest, AppendElem) { + TestAllTypes test_message; + CelValue to_add = cel_value(); + google::protobuf::Arena arena; + + ASSERT_OK(AddValueToRepeatedField( + to_add, test_message.GetDescriptor()->FindFieldByName(field_name().data()), + &test_message, &arena)); + + EXPECT_THAT(test_message, EqualsProto(message_textproto())); +} + +INSTANTIATE_TEST_SUITE_P( + AllTypes, RepeatedFieldTest, + testing::ValuesIn( + {{"repeated_int32", "repeated_int32: 1", CelValue::CreateInt64(1)}, + {"repeated_int64", "repeated_int64: 1", CelValue::CreateInt64(1)}, + {"repeated_uint32", "repeated_uint32: 1", CelValue::CreateUint64(1)}, + {"repeated_uint64", "repeated_uint64: 1", CelValue::CreateUint64(1)}, + {"repeated_sint32", "repeated_sint32: 1", CelValue::CreateInt64(1)}, + {"repeated_sint64", "repeated_sint64: 1", CelValue::CreateInt64(1)}, + {"repeated_fixed32", "repeated_fixed32: 1", CelValue::CreateUint64(1)}, + {"repeated_fixed64", "repeated_fixed64: 1", CelValue::CreateUint64(1)}, + {"repeated_sfixed32", "repeated_sfixed32: 1", + CelValue::CreateInt64(1)}, + {"repeated_sfixed64", "repeated_sfixed64: 1", + CelValue::CreateInt64(1)}, + {"repeated_float", "repeated_float: 1.0", CelValue::CreateDouble(1.0)}, + {"repeated_double", "repeated_double: 1.0", + CelValue::CreateDouble(1.0)}, + {"repeated_bool", "repeated_bool: true", CelValue::CreateBool(true)}, + {"repeated_string", "repeated_string: 'abcd'", + CelValue::CreateStringView("abcd")}, + {"repeated_bytes", "repeated_bytes: 'asdf'", + CelValue::CreateBytesView("asdf")}, + {"repeated_nested_enum", "repeated_nested_enum: BAZ", + CelValue::CreateInt64(2)}}), + &GetTestName); + +TEST(RepeatedFieldTest, GetMessage) { + TestAllTypes test_message; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + "repeated_nested_message { bb: 30 }", &test_message)); + google::protobuf::Arena arena; + + ASSERT_OK_AND_ASSIGN(CelValue accessed_value, + CreateValueFromRepeatedField( + &test_message, + test_message.GetDescriptor()->FindFieldByName( + "repeated_nested_message"), + 0, &MessageValueFactory, &arena)); + + EXPECT_THAT(accessed_value, test::IsCelMessage(EqualsProto("bb: 30"))); +} + +TEST(AddValueToRepeatedFieldTest, WrongType) { + TestAllTypes test_message; + google::protobuf::Arena arena; + + EXPECT_THAT( + AddValueToRepeatedField( + CelValue::CreateDouble(1.0), + test_message.GetDescriptor()->FindFieldByName("repeated_int32"), + &test_message, &arena), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(AddValueToRepeatedFieldTest, IntOutOfRange) { + CelValue out_of_range = CelValue::CreateInt64(1LL << 31); + TestAllTypes test_message; + const google::protobuf::Descriptor* descriptor = test_message.GetDescriptor(); + google::protobuf::Arena arena; + + EXPECT_THAT(AddValueToRepeatedField( + out_of_range, descriptor->FindFieldByName("repeated_int32"), + &test_message, &arena), + StatusIs(absl::StatusCode::kInvalidArgument)); + + // proto enums are are represented as int32_t, but CEL converts to/from int64_t. + EXPECT_THAT( + AddValueToRepeatedField( + out_of_range, descriptor->FindFieldByName("repeated_nested_enum"), + &test_message, &arena), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(AddValueToRepeatedFieldTest, UintOutOfRange) { + CelValue out_of_range = CelValue::CreateUint64(1LL << 32); + TestAllTypes test_message; + const google::protobuf::Descriptor* descriptor = test_message.GetDescriptor(); + google::protobuf::Arena arena; + + EXPECT_THAT(AddValueToRepeatedField( + out_of_range, descriptor->FindFieldByName("repeated_uint32"), + &test_message, &arena), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(AddValueToRepeatedFieldTest, AddMessage) { + TestAllTypes::NestedMessage nested_message; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( + bb: 42 + )", + &nested_message)); + google::protobuf::Arena arena; + CelValue nested_value = + CelProtoWrapper::CreateMessage(&nested_message, &arena); + TestAllTypes test_message; + const google::protobuf::Descriptor* descriptor = test_message.GetDescriptor(); + + ASSERT_OK(AddValueToRepeatedField( + nested_value, descriptor->FindFieldByName("repeated_nested_message"), + &test_message, &arena)); + EXPECT_THAT(test_message, EqualsProto("repeated_nested_message { bb: 42 }")); +} + +constexpr std::array kWrapperFieldNames = { + "single_bool_wrapper", "single_int64_wrapper", "single_int32_wrapper", + "single_uint64_wrapper", "single_uint32_wrapper", "single_double_wrapper", + "single_float_wrapper", "single_string_wrapper", "single_bytes_wrapper"}; + +// Unset wrapper type fields are treated as null if accessed after option +// enabled. +TEST(CreateValueFromFieldTest, UnsetWrapperTypesNullIfEnabled) { + CelValue result; + TestAllTypes test_message; + google::protobuf::Arena arena; + + for (const auto& field : kWrapperFieldNames) { + ASSERT_OK_AND_ASSIGN( + result, + CreateValueFromSingleField( + &test_message, + TestAllTypes::GetDescriptor()->FindFieldByName(field), + ProtoWrapperTypeOptions::kUnsetNull, &MessageValueFactory, &arena)); + ASSERT_TRUE(result.IsNull()) << field << ": " << result.DebugString(); + } +} + +// Unset wrapper type fields are treated as proto default under old +// behavior. +TEST(CreateValueFromFieldTest, UnsetWrapperTypesDefaultValueIfDisabled) { + CelValue result; + TestAllTypes test_message; + google::protobuf::Arena arena; + + for (const auto& field : kWrapperFieldNames) { + ASSERT_OK_AND_ASSIGN( + result, CreateValueFromSingleField( + &test_message, + TestAllTypes::GetDescriptor()->FindFieldByName(field), + ProtoWrapperTypeOptions::kUnsetProtoDefault, + &MessageValueFactory, &arena)); + ASSERT_FALSE(result.IsNull()) << field << ": " << result.DebugString(); + } +} + +// If a wrapper type is set to default value, the corresponding CelValue is the +// proto default value. +TEST(CreateValueFromFieldTest, SetWrapperTypesDefaultValue) { + CelValue result; + TestAllTypes test_message; + google::protobuf::Arena arena; + + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + single_bool_wrapper {} + single_int64_wrapper {} + single_int32_wrapper {} + single_uint64_wrapper {} + single_uint32_wrapper {} + single_double_wrapper {} + single_float_wrapper {} + single_string_wrapper {} + single_bytes_wrapper {} + )pb", + &test_message)); + + ASSERT_OK_AND_ASSIGN( + result, + CreateValueFromSingleField( + &test_message, + TestAllTypes::GetDescriptor()->FindFieldByName("single_bool_wrapper"), + ProtoWrapperTypeOptions::kUnsetNull, &MessageValueFactory, &arena)); + EXPECT_THAT(result, test::IsCelBool(false)); + + ASSERT_OK_AND_ASSIGN( + result, + CreateValueFromSingleField(&test_message, + TestAllTypes::GetDescriptor()->FindFieldByName( + "single_int64_wrapper"), + ProtoWrapperTypeOptions::kUnsetNull, + &MessageValueFactory, &arena)); + EXPECT_THAT(result, test::IsCelInt64(0)); + + ASSERT_OK_AND_ASSIGN( + result, + CreateValueFromSingleField(&test_message, + TestAllTypes::GetDescriptor()->FindFieldByName( + "single_int32_wrapper"), + ProtoWrapperTypeOptions::kUnsetNull, + &MessageValueFactory, &arena)); + EXPECT_THAT(result, test::IsCelInt64(0)); + + ASSERT_OK_AND_ASSIGN( + result, CreateValueFromSingleField( + &test_message, + TestAllTypes::GetDescriptor()->FindFieldByName( + "single_uint64_wrapper"), + ProtoWrapperTypeOptions::kUnsetNull, &MessageValueFactory, + + &arena)); + EXPECT_THAT(result, test::IsCelUint64(0)); + + ASSERT_OK_AND_ASSIGN( + result, CreateValueFromSingleField( + &test_message, + TestAllTypes::GetDescriptor()->FindFieldByName( + "single_uint32_wrapper"), + ProtoWrapperTypeOptions::kUnsetNull, &MessageValueFactory, + + &arena)); + EXPECT_THAT(result, test::IsCelUint64(0)); + + ASSERT_OK_AND_ASSIGN( + result, + CreateValueFromSingleField(&test_message, + TestAllTypes::GetDescriptor()->FindFieldByName( + "single_double_wrapper"), + ProtoWrapperTypeOptions::kUnsetNull, + + &MessageValueFactory, &arena)); + EXPECT_THAT(result, test::IsCelDouble(0.0f)); + + ASSERT_OK_AND_ASSIGN( + result, + CreateValueFromSingleField(&test_message, + TestAllTypes::GetDescriptor()->FindFieldByName( + "single_float_wrapper"), + ProtoWrapperTypeOptions::kUnsetNull, + + &MessageValueFactory, &arena)); + EXPECT_THAT(result, test::IsCelDouble(0.0f)); + + ASSERT_OK_AND_ASSIGN( + result, + CreateValueFromSingleField(&test_message, + TestAllTypes::GetDescriptor()->FindFieldByName( + "single_string_wrapper"), + ProtoWrapperTypeOptions::kUnsetNull, + + &MessageValueFactory, &arena)); + EXPECT_THAT(result, test::IsCelString("")); + + ASSERT_OK_AND_ASSIGN( + result, + CreateValueFromSingleField(&test_message, + TestAllTypes::GetDescriptor()->FindFieldByName( + "single_bytes_wrapper"), + ProtoWrapperTypeOptions::kUnsetNull, + + &MessageValueFactory, &arena)); + EXPECT_THAT(result, test::IsCelBytes("")); +} + +} // namespace + +} // namespace google::api::expr::runtime::internal From 2e12dd4d312ddb77662531c3f268523d29fdb415 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 13 Apr 2022 22:39:03 +0000 Subject: [PATCH 066/111] Migrate public field access helpers to use branched internal implementation. PiperOrigin-RevId: 441597747 --- eval/eval/BUILD | 1 + eval/eval/select_step.cc | 1 + eval/public/cel_options.h | 9 + eval/public/containers/BUILD | 7 +- eval/public/containers/field_access.cc | 710 +---------------------- eval/public/containers/field_access.h | 10 +- eval/public/structs/BUILD | 1 + eval/public/structs/cel_proto_wrapper.cc | 12 +- eval/public/structs/cel_proto_wrapper.h | 4 + eval/public/structs/field_access_impl.h | 10 +- 10 files changed, 46 insertions(+), 719 deletions(-) diff --git a/eval/eval/BUILD b/eval/eval/BUILD index ae44d8b1f..a1a33e7c9 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -180,6 +180,7 @@ cc_library( deps = [ ":evaluator_core", ":expression_step_base", + "//eval/public:cel_options", "//eval/public:cel_value", "//eval/public/containers:field_access", "//eval/public/containers:field_backed_list_impl", diff --git a/eval/eval/select_step.cc b/eval/eval/select_step.cc index e8e7c7cb9..55a72e563 100644 --- a/eval/eval/select_step.cc +++ b/eval/eval/select_step.cc @@ -10,6 +10,7 @@ #include "absl/strings/string_view.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" +#include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/containers/field_access.h" #include "eval/public/containers/field_backed_list_impl.h" diff --git a/eval/public/cel_options.h b/eval/public/cel_options.h index 38f0511c8..9fd18e138 100644 --- a/eval/public/cel_options.h +++ b/eval/public/cel_options.h @@ -32,6 +32,15 @@ enum class UnknownProcessingOptions { kAttributeAndFunction }; +// Options for handling unset wrapper types on field access. +enum class ProtoWrapperTypeOptions { + // Default: legacy behavior following proto semantics (unset behaves as though + // it is set to default value). + kUnsetProtoDefault, + // CEL spec behavior, unset wrapper is treated as a null value when accessed. + kUnsetNull, +}; + // Interpreter options for controlling evaluation and builtin functions. struct InterpreterOptions { // Level of unknown support enabled. diff --git a/eval/public/containers/BUILD b/eval/public/containers/BUILD index bec0dffdc..bc4f11f18 100644 --- a/eval/public/containers/BUILD +++ b/eval/public/containers/BUILD @@ -27,13 +27,12 @@ cc_library( "field_access.h", ], deps = [ + "//eval/public:cel_options", "//eval/public:cel_value", "//eval/public/structs:cel_proto_wrapper", - "//internal:casts", - "//internal:overflow", - "@com_google_absl//absl/container:flat_hash_set", + "//eval/public/structs:field_access_impl", + "//internal:status_macros", "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", ], ) diff --git a/eval/public/containers/field_access.cc b/eval/public/containers/field_access.cc index 75ca40970..ddd2cc93b 100644 --- a/eval/public/containers/field_access.cc +++ b/eval/public/containers/field_access.cc @@ -14,335 +14,19 @@ #include "eval/public/containers/field_access.h" -#include -#include -#include -#include - -#include "google/protobuf/any.pb.h" -#include "google/protobuf/struct.pb.h" -#include "google/protobuf/wrappers.pb.h" #include "google/protobuf/arena.h" #include "google/protobuf/map_field.h" -#include "absl/container/flat_hash_set.h" #include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "absl/strings/substitute.h" #include "eval/public/structs/cel_proto_wrapper.h" -#include "internal/casts.h" -#include "internal/overflow.h" +#include "eval/public/structs/field_access_impl.h" +#include "internal/status_macros.h" namespace google::api::expr::runtime { -namespace { - using ::google::protobuf::Arena; using ::google::protobuf::FieldDescriptor; using ::google::protobuf::MapValueConstRef; using ::google::protobuf::Message; -using ::google::protobuf::Reflection; - -// Well-known type protobuf type names which require special get / set behavior. -constexpr absl::string_view kProtobufAny = "google.protobuf.Any"; -constexpr absl::string_view kTypeGoogleApisComPrefix = "type.googleapis.com/"; - -// Singular message fields and repeated message fields have similar access model -// To provide common approach, we implement accessor classes, based on CRTP. -// FieldAccessor is CRTP base class, specifying Get.. method family. -template -class FieldAccessor { - public: - bool GetBool() const { return static_cast(this)->GetBool(); } - - int64_t GetInt32() const { - return static_cast(this)->GetInt32(); - } - - uint64_t GetUInt32() const { - return static_cast(this)->GetUInt32(); - } - - int64_t GetInt64() const { - return static_cast(this)->GetInt64(); - } - - uint64_t GetUInt64() const { - return static_cast(this)->GetUInt64(); - } - - double GetFloat() const { - return static_cast(this)->GetFloat(); - } - - double GetDouble() const { - return static_cast(this)->GetDouble(); - } - - const std::string* GetString(std::string* buffer) const { - return static_cast(this)->GetString(buffer); - } - - const Message* GetMessage() const { - return static_cast(this)->GetMessage(); - } - - int64_t GetEnumValue() const { - return static_cast(this)->GetEnumValue(); - } - - // This method provides message field content, wrapped in CelValue. - // If value provided successfully, returns Ok. - // arena Arena to use for allocations if needed. - // result pointer to object to store value in. - absl::Status CreateValueFromFieldAccessor(Arena* arena, CelValue* result) { - switch (field_desc_->cpp_type()) { - case FieldDescriptor::CPPTYPE_BOOL: { - bool value = GetBool(); - *result = CelValue::CreateBool(value); - break; - } - case FieldDescriptor::CPPTYPE_INT32: { - int64_t value = GetInt32(); - *result = CelValue::CreateInt64(value); - break; - } - case FieldDescriptor::CPPTYPE_INT64: { - int64_t value = GetInt64(); - *result = CelValue::CreateInt64(value); - break; - } - case FieldDescriptor::CPPTYPE_UINT32: { - uint64_t value = GetUInt32(); - *result = CelValue::CreateUint64(value); - break; - } - case FieldDescriptor::CPPTYPE_UINT64: { - uint64_t value = GetUInt64(); - *result = CelValue::CreateUint64(value); - break; - } - case FieldDescriptor::CPPTYPE_FLOAT: { - double value = GetFloat(); - *result = CelValue::CreateDouble(value); - break; - } - case FieldDescriptor::CPPTYPE_DOUBLE: { - double value = GetDouble(); - *result = CelValue::CreateDouble(value); - break; - } - case FieldDescriptor::CPPTYPE_STRING: { - std::string buffer; - const std::string* value = GetString(&buffer); - if (value == &buffer) { - value = google::protobuf::Arena::Create(arena, std::move(buffer)); - } - switch (field_desc_->type()) { - case FieldDescriptor::TYPE_STRING: - *result = CelValue::CreateString(value); - break; - case FieldDescriptor::TYPE_BYTES: - *result = CelValue::CreateBytes(value); - break; - default: - return absl::Status(absl::StatusCode::kInvalidArgument, - "Error handling C++ string conversion"); - } - break; - } - case FieldDescriptor::CPPTYPE_MESSAGE: { - const google::protobuf::Message* msg_value = GetMessage(); - *result = CelProtoWrapper::CreateMessage(msg_value, arena); - break; - } - case FieldDescriptor::CPPTYPE_ENUM: { - int enum_value = GetEnumValue(); - *result = CelValue::CreateInt64(enum_value); - break; - } - default: - return absl::Status(absl::StatusCode::kInvalidArgument, - "Unhandled C++ type conversion"); - } - - return absl::OkStatus(); - } - - protected: - FieldAccessor(const Message* msg, const FieldDescriptor* field_desc) - : msg_(msg), field_desc_(field_desc) {} - - const Message* msg_; - const FieldDescriptor* field_desc_; -}; - -const absl::flat_hash_set& WellKnownWrapperTypes() { - static auto* wrapper_types = new absl::flat_hash_set{ - "google.protobuf.BoolValue", "google.protobuf.DoubleValue", - "google.protobuf.FloatValue", "google.protobuf.Int64Value", - "google.protobuf.Int32Value", "google.protobuf.UInt64Value", - "google.protobuf.UInt32Value", "google.protobuf.StringValue", - "google.protobuf.BytesValue", - }; - return *wrapper_types; -} - -bool IsWrapperType(const FieldDescriptor* field_descriptor) { - return WellKnownWrapperTypes().find( - field_descriptor->message_type()->full_name()) != - WellKnownWrapperTypes().end(); -} - -// Accessor class, to work with singular fields -class ScalarFieldAccessor : public FieldAccessor { - public: - ScalarFieldAccessor(const Message* msg, const FieldDescriptor* field_desc, - bool unset_wrapper_as_null) - : FieldAccessor(msg, field_desc), - unset_wrapper_as_null_(unset_wrapper_as_null) {} - - bool GetBool() const { return GetReflection()->GetBool(*msg_, field_desc_); } - - int64_t GetInt32() const { - return GetReflection()->GetInt32(*msg_, field_desc_); - } - - uint64_t GetUInt32() const { - return GetReflection()->GetUInt32(*msg_, field_desc_); - } - - int64_t GetInt64() const { - return GetReflection()->GetInt64(*msg_, field_desc_); - } - - uint64_t GetUInt64() const { - return GetReflection()->GetUInt64(*msg_, field_desc_); - } - - double GetFloat() const { - return GetReflection()->GetFloat(*msg_, field_desc_); - } - - double GetDouble() const { - return GetReflection()->GetDouble(*msg_, field_desc_); - } - - const std::string* GetString(std::string* buffer) const { - return &GetReflection()->GetStringReference(*msg_, field_desc_, buffer); - } - - const Message* GetMessage() const { - // Unset wrapper types have special semantics. - // If set, return the unwrapped value, else return 'null'. - if (unset_wrapper_as_null_ && - !GetReflection()->HasField(*msg_, field_desc_) && - IsWrapperType(field_desc_)) { - return nullptr; - } - return &GetReflection()->GetMessage(*msg_, field_desc_); - } - - int64_t GetEnumValue() const { - return GetReflection()->GetEnumValue(*msg_, field_desc_); - } - - const Reflection* GetReflection() const { return msg_->GetReflection(); } - - private: - bool unset_wrapper_as_null_; -}; - -// Accessor class, to work with repeated fields. -class RepeatedFieldAccessor : public FieldAccessor { - public: - RepeatedFieldAccessor(const Message* msg, const FieldDescriptor* field_desc, - int index) - : FieldAccessor(msg, field_desc), index_(index) {} - - bool GetBool() const { - return GetReflection()->GetRepeatedBool(*msg_, field_desc_, index_); - } - - int64_t GetInt32() const { - return GetReflection()->GetRepeatedInt32(*msg_, field_desc_, index_); - } - - uint64_t GetUInt32() const { - return GetReflection()->GetRepeatedUInt32(*msg_, field_desc_, index_); - } - - int64_t GetInt64() const { - return GetReflection()->GetRepeatedInt64(*msg_, field_desc_, index_); - } - - uint64_t GetUInt64() const { - return GetReflection()->GetRepeatedUInt64(*msg_, field_desc_, index_); - } - - double GetFloat() const { - return GetReflection()->GetRepeatedFloat(*msg_, field_desc_, index_); - } - - double GetDouble() const { - return GetReflection()->GetRepeatedDouble(*msg_, field_desc_, index_); - } - - const std::string* GetString(std::string* buffer) const { - return &GetReflection()->GetRepeatedStringReference(*msg_, field_desc_, - index_, buffer); - } - - const Message* GetMessage() const { - return &GetReflection()->GetRepeatedMessage(*msg_, field_desc_, index_); - } - - int64_t GetEnumValue() const { - return GetReflection()->GetRepeatedEnumValue(*msg_, field_desc_, index_); - } - - const Reflection* GetReflection() const { return msg_->GetReflection(); } - - private: - int index_; -}; - -// Accessor class, to work with map values -class MapValueAccessor : public FieldAccessor { - public: - MapValueAccessor(const Message* msg, const FieldDescriptor* field_desc, - const MapValueConstRef* value_ref) - : FieldAccessor(msg, field_desc), value_ref_(value_ref) {} - - bool GetBool() const { return value_ref_->GetBoolValue(); } - - int64_t GetInt32() const { return value_ref_->GetInt32Value(); } - - uint64_t GetUInt32() const { return value_ref_->GetUInt32Value(); } - - int64_t GetInt64() const { return value_ref_->GetInt64Value(); } - - uint64_t GetUInt64() const { return value_ref_->GetUInt64Value(); } - - double GetFloat() const { return value_ref_->GetFloatValue(); } - - double GetDouble() const { return value_ref_->GetDoubleValue(); } - - const std::string* GetString(std::string* /*buffer*/) const { - return &value_ref_->GetStringValue(); - } - - const Message* GetMessage() const { return &value_ref_->GetMessageValue(); } - - int64_t GetEnumValue() const { return value_ref_->GetEnumValue(); } - - const Reflection* GetReflection() const { return msg_->GetReflection(); } - - private: - const MapValueConstRef* value_ref_; -}; - -} // namespace absl::Status CreateValueFromSingleField(const google::protobuf::Message* msg, const FieldDescriptor* desc, @@ -357,401 +41,45 @@ absl::Status CreateValueFromSingleField(const google::protobuf::Message* msg, ProtoWrapperTypeOptions options, google::protobuf::Arena* arena, CelValue* result) { - ScalarFieldAccessor accessor( - msg, desc, (options == ProtoWrapperTypeOptions::kUnsetNull)); - return accessor.CreateValueFromFieldAccessor(arena, result); + CEL_ASSIGN_OR_RETURN( + *result, + internal::CreateValueFromSingleField( + msg, desc, options, &CelProtoWrapper::InternalWrapMessage, arena)); + return absl::OkStatus(); } absl::Status CreateValueFromRepeatedField(const google::protobuf::Message* msg, const FieldDescriptor* desc, google::protobuf::Arena* arena, int index, CelValue* result) { - RepeatedFieldAccessor accessor(msg, desc, index); - return accessor.CreateValueFromFieldAccessor(arena, result); + CEL_ASSIGN_OR_RETURN( + *result, + internal::CreateValueFromRepeatedField( + msg, desc, index, &CelProtoWrapper::InternalWrapMessage, arena)); + return absl::OkStatus(); } absl::Status CreateValueFromMapValue(const google::protobuf::Message* msg, const FieldDescriptor* desc, const MapValueConstRef* value_ref, google::protobuf::Arena* arena, CelValue* result) { - MapValueAccessor accessor(msg, desc, value_ref); - return accessor.CreateValueFromFieldAccessor(arena, result); + CEL_ASSIGN_OR_RETURN( + *result, + internal::CreateValueFromMapValue( + msg, desc, value_ref, &CelProtoWrapper::InternalWrapMessage, arena)); + return absl::OkStatus(); } -// Singular message fields and repeated message fields have similar access model -// To provide common approach, we implement field setter classes, based on CRTP. -// FieldAccessor is CRTP base class, specifying Get.. method family. -template -class FieldSetter { - public: - bool AssignBool(const CelValue& cel_value) const { - bool value; - - if (!cel_value.GetValue(&value)) { - return false; - } - static_cast(this)->SetBool(value); - return true; - } - - bool AssignInt32(const CelValue& cel_value) const { - int64_t value; - if (!cel_value.GetValue(&value)) { - return false; - } - if (!cel::internal::CheckedInt64ToInt32(value).ok()) { - return false; - } - static_cast(this)->SetInt32(value); - return true; - } - - bool AssignUInt32(const CelValue& cel_value) const { - uint64_t value; - if (!cel_value.GetValue(&value)) { - return false; - } - if (!cel::internal::CheckedUint64ToUint32(value).ok()) { - return false; - } - static_cast(this)->SetUInt32(value); - return true; - } - - bool AssignInt64(const CelValue& cel_value) const { - int64_t value; - if (!cel_value.GetValue(&value)) { - return false; - } - static_cast(this)->SetInt64(value); - return true; - } - - bool AssignUInt64(const CelValue& cel_value) const { - uint64_t value; - if (!cel_value.GetValue(&value)) { - return false; - } - static_cast(this)->SetUInt64(value); - return true; - } - - bool AssignFloat(const CelValue& cel_value) const { - double value; - if (!cel_value.GetValue(&value)) { - return false; - } - static_cast(this)->SetFloat(value); - return true; - } - - bool AssignDouble(const CelValue& cel_value) const { - double value; - if (!cel_value.GetValue(&value)) { - return false; - } - static_cast(this)->SetDouble(value); - return true; - } - - bool AssignString(const CelValue& cel_value) const { - CelValue::StringHolder value; - if (!cel_value.GetValue(&value)) { - return false; - } - static_cast(this)->SetString(value); - return true; - } - - bool AssignBytes(const CelValue& cel_value) const { - CelValue::BytesHolder value; - if (!cel_value.GetValue(&value)) { - return false; - } - static_cast(this)->SetBytes(value); - return true; - } - - bool AssignEnum(const CelValue& cel_value) const { - int64_t value; - if (!cel_value.GetValue(&value)) { - return false; - } - if (!cel::internal::CheckedInt64ToInt32(value).ok()) { - return false; - } - static_cast(this)->SetEnum(value); - return true; - } - - bool AssignMessage(const CelValue& cel_value) const { - // Assigning a NULL to a message is OK, but a no-op. - if (cel_value.IsNull()) { - return true; - } - - if (CelValue::MessageWrapper wrapper; - cel_value.GetValue(&wrapper) && wrapper.HasFullProto()) { - static_cast(this)->SetMessage( - cel::internal::down_cast(wrapper.message_ptr())); - return true; - } - - return false; - } - - // This method provides message field content, wrapped in CelValue. - // If value provided successfully, returns Ok. - // arena Arena to use for allocations if needed. - // result pointer to object to store value in. - bool SetFieldFromCelValue(const CelValue& value) { - switch (field_desc_->cpp_type()) { - case FieldDescriptor::CPPTYPE_BOOL: { - return AssignBool(value); - } - case FieldDescriptor::CPPTYPE_INT32: { - return AssignInt32(value); - } - case FieldDescriptor::CPPTYPE_INT64: { - return AssignInt64(value); - } - case FieldDescriptor::CPPTYPE_UINT32: { - return AssignUInt32(value); - } - case FieldDescriptor::CPPTYPE_UINT64: { - return AssignUInt64(value); - } - case FieldDescriptor::CPPTYPE_FLOAT: { - return AssignFloat(value); - } - case FieldDescriptor::CPPTYPE_DOUBLE: { - return AssignDouble(value); - } - case FieldDescriptor::CPPTYPE_STRING: { - switch (field_desc_->type()) { - case FieldDescriptor::TYPE_STRING: - - return AssignString(value); - case FieldDescriptor::TYPE_BYTES: - return AssignBytes(value); - default: - return false; - } - break; - } - case FieldDescriptor::CPPTYPE_MESSAGE: { - // When the field is a message, it might be a well-known type with a - // non-proto representation that requires special handling before it - // can be set on the field. - auto wrapped_value = CelProtoWrapper::MaybeWrapValue( - field_desc_->message_type(), value, arena_); - return AssignMessage(wrapped_value.value_or(value)); - } - case FieldDescriptor::CPPTYPE_ENUM: { - return AssignEnum(value); - } - default: - return false; - } - - return true; - } - - protected: - FieldSetter(Message* msg, const FieldDescriptor* field_desc, Arena* arena) - : msg_(msg), field_desc_(field_desc), arena_(arena) {} - - Message* msg_; - const FieldDescriptor* field_desc_; - Arena* arena_; -}; - -// Accessor class, to work with singular fields -class ScalarFieldSetter : public FieldSetter { - public: - ScalarFieldSetter(Message* msg, const FieldDescriptor* field_desc, - Arena* arena) - : FieldSetter(msg, field_desc, arena) {} - - bool SetBool(bool value) const { - GetReflection()->SetBool(msg_, field_desc_, value); - return true; - } - - bool SetInt32(int32_t value) const { - GetReflection()->SetInt32(msg_, field_desc_, value); - return true; - } - - bool SetUInt32(uint32_t value) const { - GetReflection()->SetUInt32(msg_, field_desc_, value); - return true; - } - - bool SetInt64(int64_t value) const { - GetReflection()->SetInt64(msg_, field_desc_, value); - return true; - } - - bool SetUInt64(uint64_t value) const { - GetReflection()->SetUInt64(msg_, field_desc_, value); - return true; - } - - bool SetFloat(float value) const { - GetReflection()->SetFloat(msg_, field_desc_, value); - return true; - } - - bool SetDouble(double value) const { - GetReflection()->SetDouble(msg_, field_desc_, value); - return true; - } - - bool SetString(CelValue::StringHolder value) const { - GetReflection()->SetString(msg_, field_desc_, std::string(value.value())); - return true; - } - - bool SetBytes(CelValue::BytesHolder value) const { - GetReflection()->SetString(msg_, field_desc_, std::string(value.value())); - return true; - } - - bool SetMessage(const Message* value) const { - if (!value) { - GOOGLE_LOG(ERROR) << "Message is NULL"; - return true; - } - - if (value->GetDescriptor()->full_name() == - field_desc_->message_type()->full_name()) { - GetReflection()->MutableMessage(msg_, field_desc_)->MergeFrom(*value); - return true; - - } else if (field_desc_->message_type()->full_name() == kProtobufAny) { - auto any_msg = google::protobuf::DynamicCastToGenerated( - GetReflection()->MutableMessage(msg_, field_desc_)); - if (any_msg == nullptr) { - // TODO(issues/68): This is probably a dynamic message. We should - // implement this once we add support for dynamic protobuf types. - return false; - } - any_msg->set_type_url(absl::StrCat(kTypeGoogleApisComPrefix, - value->GetDescriptor()->full_name())); - return value->SerializeToString(any_msg->mutable_value()); - } - return false; - } - - bool SetEnum(const int64_t value) const { - GetReflection()->SetEnumValue(msg_, field_desc_, value); - return true; - } - - const Reflection* GetReflection() const { return msg_->GetReflection(); } -}; - -// Appender class, to work with repeated fields -class RepeatedFieldSetter : public FieldSetter { - public: - RepeatedFieldSetter(Message* msg, const FieldDescriptor* field_desc, - Arena* arena) - : FieldSetter(msg, field_desc, arena) {} - - bool SetBool(bool value) const { - GetReflection()->AddBool(msg_, field_desc_, value); - return true; - } - - bool SetInt32(int32_t value) const { - GetReflection()->AddInt32(msg_, field_desc_, value); - return true; - } - - bool SetUInt32(uint32_t value) const { - GetReflection()->AddUInt32(msg_, field_desc_, value); - return true; - } - - bool SetInt64(int64_t value) const { - GetReflection()->AddInt64(msg_, field_desc_, value); - return true; - } - - bool SetUInt64(uint64_t value) const { - GetReflection()->AddUInt64(msg_, field_desc_, value); - return true; - } - - bool SetFloat(float value) const { - GetReflection()->AddFloat(msg_, field_desc_, value); - return true; - } - - bool SetDouble(double value) const { - GetReflection()->AddDouble(msg_, field_desc_, value); - return true; - } - - bool SetString(CelValue::StringHolder value) const { - GetReflection()->AddString(msg_, field_desc_, std::string(value.value())); - return true; - } - - bool SetBytes(CelValue::BytesHolder value) const { - GetReflection()->AddString(msg_, field_desc_, std::string(value.value())); - return true; - } - - bool SetMessage(const Message* value) const { - if (!value) return true; - if (value->GetDescriptor()->full_name() != - field_desc_->message_type()->full_name()) { - return false; - } - - GetReflection()->AddMessage(msg_, field_desc_)->MergeFrom(*value); - return true; - } - - bool SetEnum(const int64_t value) const { - GetReflection()->AddEnumValue(msg_, field_desc_, value); - return true; - } - - private: - const Reflection* GetReflection() const { return msg_->GetReflection(); } -}; - -// This method sets message field -// If value provided successfully, returns Ok. -// arena Arena to use for allocations if needed. -// result pointer to object to store value in. absl::Status SetValueToSingleField(const CelValue& value, const FieldDescriptor* desc, Message* msg, Arena* arena) { - ScalarFieldSetter setter(msg, desc, arena); - return (setter.SetFieldFromCelValue(value)) - ? absl::OkStatus() - : absl::InvalidArgumentError(absl::Substitute( - "Could not assign supplied argument to message \"$0\" field " - "\"$1\" of type $2: value type \"$3\"", - msg->GetDescriptor()->name(), desc->name(), - desc->type_name(), CelValue::TypeName(value.type()))); + return internal::SetValueToSingleField(value, desc, msg, arena); } absl::Status AddValueToRepeatedField(const CelValue& value, const FieldDescriptor* desc, Message* msg, Arena* arena) { - RepeatedFieldSetter setter(msg, desc, arena); - return (setter.SetFieldFromCelValue(value)) - ? absl::OkStatus() - : absl::InvalidArgumentError(absl::Substitute( - "Could not add supplied argument to message \"$0\" field " - "\"$1\" of type $2: value type \"$3\"", - msg->GetDescriptor()->name(), desc->name(), - desc->type_name(), CelValue::TypeName(value.type()))); + return internal::AddValueToRepeatedField(value, desc, msg, arena); } } // namespace google::api::expr::runtime diff --git a/eval/public/containers/field_access.h b/eval/public/containers/field_access.h index bd15227ba..69d3191dd 100644 --- a/eval/public/containers/field_access.h +++ b/eval/public/containers/field_access.h @@ -1,19 +1,11 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_FIELD_ACCESS_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_FIELD_ACCESS_H_ +#include "eval/public/cel_options.h" #include "eval/public/cel_value.h" namespace google::api::expr::runtime { -// Options for handling unset wrapper types. -enum class ProtoWrapperTypeOptions { - // Default: legacy behavior following proto semantics (unset behaves as though - // it is set to default value). - kUnsetProtoDefault, - // CEL spec behavior, unset wrapper is treated as a null value when accessed. - kUnsetNull, -}; - // Creates CelValue from singular message field. // Returns status of the operation. // msg Message containing the field. diff --git a/eval/public/structs/BUILD b/eval/public/structs/BUILD index 87fa55fb3..02360746b 100644 --- a/eval/public/structs/BUILD +++ b/eval/public/structs/BUILD @@ -102,6 +102,7 @@ cc_library( deps = [ ":cel_proto_wrap_util", ":protobuf_value_factory", + "//eval/public:cel_options", "//eval/public:cel_value", "//internal:casts", "//internal:overflow", diff --git a/eval/public/structs/cel_proto_wrapper.cc b/eval/public/structs/cel_proto_wrapper.cc index 8ff065efc..496f134e8 100644 --- a/eval/public/structs/cel_proto_wrapper.cc +++ b/eval/public/structs/cel_proto_wrapper.cc @@ -27,17 +27,17 @@ using ::google::protobuf::Arena; using ::google::protobuf::Descriptor; using ::google::protobuf::Message; -CelValue WrapMessage(const Message* m) { - return CelValue::CreateMessageWrapper(CelValue::MessageWrapper(m)); -} - } // namespace +CelValue CelProtoWrapper::InternalWrapMessage(const Message* message) { + return CelValue::CreateMessage(message); +} + // CreateMessage creates CelValue from google::protobuf::Message. // As some of CEL basic types are subclassing google::protobuf::Message, // this method contains type checking and downcasts. CelValue CelProtoWrapper::CreateMessage(const Message* value, Arena* arena) { - return internal::UnwrapMessageToValue(value, &WrapMessage, arena); + return internal::UnwrapMessageToValue(value, &InternalWrapMessage, arena); } absl::optional CelProtoWrapper::MaybeWrapValue( @@ -45,7 +45,7 @@ absl::optional CelProtoWrapper::MaybeWrapValue( const Message* msg = internal::MaybeWrapValueToMessage(descriptor, value, arena); if (msg != nullptr) { - return WrapMessage(msg); + return InternalWrapMessage(msg); } else { return absl::nullopt; } diff --git a/eval/public/structs/cel_proto_wrapper.h b/eval/public/structs/cel_proto_wrapper.h index 633be5f28..2d65155c5 100644 --- a/eval/public/structs/cel_proto_wrapper.h +++ b/eval/public/structs/cel_proto_wrapper.h @@ -17,6 +17,10 @@ class CelProtoWrapper { static CelValue CreateMessage(const google::protobuf::Message* value, google::protobuf::Arena* arena); + // Internal utility for creating a CelValue wrapping a user defined type. + // Assumes that the message has been properly unpacked. + static CelValue InternalWrapMessage(const google::protobuf::Message* message); + // CreateDuration creates CelValue from a non-null protobuf duration value. static CelValue CreateDuration(const google::protobuf::Duration* value) { return CelValue(expr::internal::DecodeDuration(*value)); diff --git a/eval/public/structs/field_access_impl.h b/eval/public/structs/field_access_impl.h index 150280e28..4e2caca64 100644 --- a/eval/public/structs/field_access_impl.h +++ b/eval/public/structs/field_access_impl.h @@ -15,20 +15,12 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_FIELD_ACCESS_IMPL_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_FIELD_ACCESS_IMPL_H_ +#include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/structs/protobuf_value_factory.h" namespace google::api::expr::runtime::internal { -// Options for handling unset wrapper types. -enum class ProtoWrapperTypeOptions { - // Default: legacy behavior following proto semantics (unset behaves as though - // it is set to default value). - kUnsetProtoDefault, - // CEL spec behavior, unset wrapper is treated as a null value when accessed. - kUnsetNull, -}; - // Creates CelValue from singular message field. // Returns status of the operation. // msg Message containing the field. From 5d2b8ef8f79636150bd3a64c0ced4eeb9bab6038 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 13 Apr 2022 22:39:40 +0000 Subject: [PATCH 067/111] Refactor FieldBackedListImpl to remove required build dependency on CelProtoWrapper::CreateStruct PiperOrigin-RevId: 441597897 --- eval/public/containers/BUILD | 42 ++- .../containers/field_backed_list_impl.cc | 30 --- .../containers/field_backed_list_impl.h | 24 +- .../internal_field_backed_list_impl.cc | 36 +++ .../internal_field_backed_list_impl.h | 59 ++++ .../internal_field_backed_list_impl_test.cc | 252 ++++++++++++++++++ 6 files changed, 391 insertions(+), 52 deletions(-) delete mode 100644 eval/public/containers/field_backed_list_impl.cc create mode 100644 eval/public/containers/internal_field_backed_list_impl.cc create mode 100644 eval/public/containers/internal_field_backed_list_impl.h create mode 100644 eval/public/containers/internal_field_backed_list_impl_test.cc diff --git a/eval/public/containers/BUILD b/eval/public/containers/BUILD index bc4f11f18..7d6bb4b74 100644 --- a/eval/public/containers/BUILD +++ b/eval/public/containers/BUILD @@ -18,6 +18,11 @@ licenses(["notice"]) # Apache 2.0 # TODO(issues/69): Expose this in a public API. +package_group( + name = "cel_internal", + packages = ["//eval/..."], +) + cc_library( name = "field_access", srcs = [ @@ -72,16 +77,13 @@ cc_library( cc_library( name = "field_backed_list_impl", - srcs = [ - "field_backed_list_impl.cc", - ], hdrs = [ "field_backed_list_impl.h", ], deps = [ - ":field_access", + ":internal_field_backed_list_impl", "//eval/public:cel_value", - "@com_google_absl//absl/strings", + "//eval/public/structs:cel_proto_wrapper", ], ) @@ -164,3 +166,33 @@ cc_test( "@com_google_protobuf//:protobuf", ], ) + +cc_library( + name = "internal_field_backed_list_impl", + srcs = [ + "internal_field_backed_list_impl.cc", + ], + hdrs = [ + "internal_field_backed_list_impl.h", + ], + deps = [ + "//eval/public:cel_value", + "//eval/public/structs:field_access_impl", + "//eval/public/structs:protobuf_value_factory", + ], +) + +cc_test( + name = "internal_field_backed_list_impl_test", + size = "small", + srcs = [ + "internal_field_backed_list_impl_test.cc", + ], + deps = [ + ":internal_field_backed_list_impl", + "//eval/public/structs:cel_proto_wrapper", + "//eval/testutil:test_message_cc_proto", + "//internal:testing", + "//testutil:util", + ], +) diff --git a/eval/public/containers/field_backed_list_impl.cc b/eval/public/containers/field_backed_list_impl.cc deleted file mode 100644 index 2fa86c272..000000000 --- a/eval/public/containers/field_backed_list_impl.cc +++ /dev/null @@ -1,30 +0,0 @@ - -#include "eval/public/containers/field_backed_list_impl.h" - -#include "eval/public/cel_value.h" -#include "eval/public/containers/field_access.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { - -int FieldBackedListImpl::size() const { - return reflection_->FieldSize(*message_, descriptor_); -} - -CelValue FieldBackedListImpl::operator[](int index) const { - CelValue result = CelValue::CreateNull(); - auto status = CreateValueFromRepeatedField(message_, descriptor_, arena_, - index, &result); - if (!status.ok()) { - result = CreateErrorValue(arena_, status.ToString()); - } - - return result; -} - -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google diff --git a/eval/public/containers/field_backed_list_impl.h b/eval/public/containers/field_backed_list_impl.h index ac330850c..39f654764 100644 --- a/eval/public/containers/field_backed_list_impl.h +++ b/eval/public/containers/field_backed_list_impl.h @@ -2,6 +2,8 @@ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_FIELD_BACKED_LIST_IMPL_H_ #include "eval/public/cel_value.h" +#include "eval/public/containers/internal_field_backed_list_impl.h" +#include "eval/public/structs/cel_proto_wrapper.h" namespace google { namespace api { @@ -10,29 +12,17 @@ namespace runtime { // CelList implementation that uses "repeated" message field // as backing storage. -class FieldBackedListImpl : public CelList { +class FieldBackedListImpl : public internal::FieldBackedListImpl { public: // message contains the "repeated" field // descriptor FieldDescriptor for the field + // arena is used for incidental allocations when unwrapping the field. FieldBackedListImpl(const google::protobuf::Message* message, const google::protobuf::FieldDescriptor* descriptor, google::protobuf::Arena* arena) - : message_(message), - descriptor_(descriptor), - reflection_(message_->GetReflection()), - arena_(arena) {} - - // List size. - int size() const override; - - // List element access operator. - CelValue operator[](int index) const override; - - private: - const google::protobuf::Message* message_; - const google::protobuf::FieldDescriptor* descriptor_; - const google::protobuf::Reflection* reflection_; - google::protobuf::Arena* arena_; + : internal::FieldBackedListImpl( + message, descriptor, &CelProtoWrapper::InternalWrapMessage, arena) { + } }; } // namespace runtime diff --git a/eval/public/containers/internal_field_backed_list_impl.cc b/eval/public/containers/internal_field_backed_list_impl.cc new file mode 100644 index 000000000..6541db468 --- /dev/null +++ b/eval/public/containers/internal_field_backed_list_impl.cc @@ -0,0 +1,36 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/containers/internal_field_backed_list_impl.h" + +#include "eval/public/cel_value.h" +#include "eval/public/structs/field_access_impl.h" + +namespace google::api::expr::runtime::internal { + +int FieldBackedListImpl::size() const { + return reflection_->FieldSize(*message_, descriptor_); +} + +CelValue FieldBackedListImpl::operator[](int index) const { + auto result = CreateValueFromRepeatedField(message_, descriptor_, index, + factory_, arena_); + if (!result.ok()) { + CreateErrorValue(arena_, result.status().ToString()); + } + + return *result; +} + +} // namespace google::api::expr::runtime::internal diff --git a/eval/public/containers/internal_field_backed_list_impl.h b/eval/public/containers/internal_field_backed_list_impl.h new file mode 100644 index 000000000..95f8de425 --- /dev/null +++ b/eval/public/containers/internal_field_backed_list_impl.h @@ -0,0 +1,59 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_INTERNAL_FIELD_BACKED_LIST_IMPL_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_INTERNAL_FIELD_BACKED_LIST_IMPL_H_ + +#include + +#include "eval/public/cel_value.h" +#include "eval/public/structs/protobuf_value_factory.h" + +namespace google::api::expr::runtime::internal { + +// CelList implementation that uses "repeated" message field +// as backing storage. +// +// The internal implementation allows for interface updates without breaking +// clients that depend on this class for implementing custom CEL lists +class FieldBackedListImpl : public CelList { + public: + // message contains the "repeated" field + // descriptor FieldDescriptor for the field + FieldBackedListImpl(const google::protobuf::Message* message, + const google::protobuf::FieldDescriptor* descriptor, + ProtobufValueFactory factory, google::protobuf::Arena* arena) + : message_(message), + descriptor_(descriptor), + reflection_(message_->GetReflection()), + factory_(std::move(factory)), + arena_(arena) {} + + // List size. + int size() const override; + + // List element access operator. + CelValue operator[](int index) const override; + + private: + const google::protobuf::Message* message_; + const google::protobuf::FieldDescriptor* descriptor_; + const google::protobuf::Reflection* reflection_; + ProtobufValueFactory factory_; + google::protobuf::Arena* arena_; +}; + +} // namespace google::api::expr::runtime::internal + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_INTERNAL_FIELD_BACKED_LIST_IMPL_H_ diff --git a/eval/public/containers/internal_field_backed_list_impl_test.cc b/eval/public/containers/internal_field_backed_list_impl_test.cc new file mode 100644 index 000000000..41b529527 --- /dev/null +++ b/eval/public/containers/internal_field_backed_list_impl_test.cc @@ -0,0 +1,252 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/containers/internal_field_backed_list_impl.h" + +#include + +#include "eval/public/structs/cel_proto_wrapper.h" +#include "eval/testutil/test_message.pb.h" +#include "internal/testing.h" +#include "testutil/util.h" + +namespace google::api::expr::runtime::internal { +namespace { + +using ::google::api::expr::testutil::EqualsProto; +using testing::DoubleEq; +using testing::Eq; + +// Helper method. Creates simple pipeline containing Select step and runs it. +std::unique_ptr CreateList(const TestMessage* message, + const std::string& field, + google::protobuf::Arena* arena) { + const google::protobuf::FieldDescriptor* field_desc = + message->GetDescriptor()->FindFieldByName(field); + + return absl::make_unique( + message, field_desc, &CelProtoWrapper::InternalWrapMessage, arena); +} + +TEST(FieldBackedListImplTest, BoolDatatypeTest) { + TestMessage message; + message.add_bool_list(true); + message.add_bool_list(false); + + google::protobuf::Arena arena; + + auto cel_list = CreateList(&message, "bool_list", &arena); + + ASSERT_EQ(cel_list->size(), 2); + + EXPECT_EQ((*cel_list)[0].BoolOrDie(), true); + EXPECT_EQ((*cel_list)[1].BoolOrDie(), false); +} + +TEST(FieldBackedListImplTest, TestLength0) { + TestMessage message; + + google::protobuf::Arena arena; + + auto cel_list = CreateList(&message, "int32_list", &arena); + + ASSERT_EQ(cel_list->size(), 0); +} + +TEST(FieldBackedListImplTest, TestLength1) { + TestMessage message; + message.add_int32_list(1); + google::protobuf::Arena arena; + + auto cel_list = CreateList(&message, "int32_list", &arena); + + ASSERT_EQ(cel_list->size(), 1); + EXPECT_EQ((*cel_list)[0].Int64OrDie(), 1); +} + +TEST(FieldBackedListImplTest, TestLength100000) { + TestMessage message; + + const int kLen = 100000; + + for (int i = 0; i < kLen; i++) { + message.add_int32_list(i); + } + google::protobuf::Arena arena; + + auto cel_list = CreateList(&message, "int32_list", &arena); + + ASSERT_EQ(cel_list->size(), kLen); + for (int i = 0; i < kLen; i++) { + EXPECT_EQ((*cel_list)[i].Int64OrDie(), i); + } +} + +TEST(FieldBackedListImplTest, Int32DatatypeTest) { + TestMessage message; + message.add_int32_list(1); + message.add_int32_list(2); + + google::protobuf::Arena arena; + + auto cel_list = CreateList(&message, "int32_list", &arena); + + ASSERT_EQ(cel_list->size(), 2); + + EXPECT_EQ((*cel_list)[0].Int64OrDie(), 1); + EXPECT_EQ((*cel_list)[1].Int64OrDie(), 2); +} + +TEST(FieldBackedListImplTest, Int64DatatypeTest) { + TestMessage message; + message.add_int64_list(1); + message.add_int64_list(2); + + google::protobuf::Arena arena; + + auto cel_list = CreateList(&message, "int64_list", &arena); + + ASSERT_EQ(cel_list->size(), 2); + + EXPECT_EQ((*cel_list)[0].Int64OrDie(), 1); + EXPECT_EQ((*cel_list)[1].Int64OrDie(), 2); +} + +TEST(FieldBackedListImplTest, Uint32DatatypeTest) { + TestMessage message; + message.add_uint32_list(1); + message.add_uint32_list(2); + + google::protobuf::Arena arena; + + auto cel_list = CreateList(&message, "uint32_list", &arena); + + ASSERT_EQ(cel_list->size(), 2); + + EXPECT_EQ((*cel_list)[0].Uint64OrDie(), 1); + EXPECT_EQ((*cel_list)[1].Uint64OrDie(), 2); +} + +TEST(FieldBackedListImplTest, Uint64DatatypeTest) { + TestMessage message; + message.add_uint64_list(1); + message.add_uint64_list(2); + + google::protobuf::Arena arena; + + auto cel_list = CreateList(&message, "uint64_list", &arena); + + ASSERT_EQ(cel_list->size(), 2); + + EXPECT_EQ((*cel_list)[0].Uint64OrDie(), 1); + EXPECT_EQ((*cel_list)[1].Uint64OrDie(), 2); +} + +TEST(FieldBackedListImplTest, FloatDatatypeTest) { + TestMessage message; + message.add_float_list(1); + message.add_float_list(2); + + google::protobuf::Arena arena; + + auto cel_list = CreateList(&message, "float_list", &arena); + + ASSERT_EQ(cel_list->size(), 2); + + EXPECT_THAT((*cel_list)[0].DoubleOrDie(), DoubleEq(1)); + EXPECT_THAT((*cel_list)[1].DoubleOrDie(), DoubleEq(2)); +} + +TEST(FieldBackedListImplTest, DoubleDatatypeTest) { + TestMessage message; + message.add_double_list(1); + message.add_double_list(2); + + google::protobuf::Arena arena; + + auto cel_list = CreateList(&message, "double_list", &arena); + + ASSERT_EQ(cel_list->size(), 2); + + EXPECT_THAT((*cel_list)[0].DoubleOrDie(), DoubleEq(1)); + EXPECT_THAT((*cel_list)[1].DoubleOrDie(), DoubleEq(2)); +} + +TEST(FieldBackedListImplTest, StringDatatypeTest) { + TestMessage message; + message.add_string_list("1"); + message.add_string_list("2"); + + google::protobuf::Arena arena; + + auto cel_list = CreateList(&message, "string_list", &arena); + + ASSERT_EQ(cel_list->size(), 2); + + EXPECT_EQ((*cel_list)[0].StringOrDie().value(), "1"); + EXPECT_EQ((*cel_list)[1].StringOrDie().value(), "2"); +} + + +TEST(FieldBackedListImplTest, BytesDatatypeTest) { + TestMessage message; + message.add_bytes_list("1"); + message.add_bytes_list("2"); + + google::protobuf::Arena arena; + + auto cel_list = CreateList(&message, "bytes_list", &arena); + + ASSERT_EQ(cel_list->size(), 2); + + EXPECT_EQ((*cel_list)[0].BytesOrDie().value(), "1"); + EXPECT_EQ((*cel_list)[1].BytesOrDie().value(), "2"); +} + +TEST(FieldBackedListImplTest, MessageDatatypeTest) { + TestMessage message; + TestMessage* msg1 = message.add_message_list(); + TestMessage* msg2 = message.add_message_list(); + + msg1->set_string_value("1"); + msg2->set_string_value("2"); + + google::protobuf::Arena arena; + + auto cel_list = CreateList(&message, "message_list", &arena); + + ASSERT_EQ(cel_list->size(), 2); + + EXPECT_THAT(*msg1, EqualsProto(*((*cel_list)[0].MessageOrDie()))); + EXPECT_THAT(*msg2, EqualsProto(*((*cel_list)[1].MessageOrDie()))); +} + +TEST(FieldBackedListImplTest, EnumDatatypeTest) { + TestMessage message; + + message.add_enum_list(TestMessage::TEST_ENUM_1); + message.add_enum_list(TestMessage::TEST_ENUM_2); + + google::protobuf::Arena arena; + + auto cel_list = CreateList(&message, "enum_list", &arena); + + ASSERT_EQ(cel_list->size(), 2); + + EXPECT_THAT((*cel_list)[0].Int64OrDie(), Eq(TestMessage::TEST_ENUM_1)); + EXPECT_THAT((*cel_list)[1].Int64OrDie(), Eq(TestMessage::TEST_ENUM_2)); +} + +} // namespace +} // namespace google::api::expr::runtime::internal From 7100bdc933ceea01b96d7078758aadc483784bba Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 13 Apr 2022 23:01:05 +0000 Subject: [PATCH 068/111] Migrate FieldBackedMap off of hard dependency on CelProtoWrapper to help resolve cyclic dependency. PiperOrigin-RevId: 441602804 --- eval/public/containers/BUILD | 45 ++- .../public/containers/field_backed_map_impl.h | 55 +--- .../containers/field_backed_map_impl_test.cc | 54 +--- ...l.cc => internal_field_backed_map_impl.cc} | 100 +++--- .../internal_field_backed_map_impl.h | 73 +++++ .../internal_field_backed_map_impl_test.cc | 288 ++++++++++++++++++ 6 files changed, 476 insertions(+), 139 deletions(-) rename eval/public/containers/{field_backed_map_impl.cc => internal_field_backed_map_impl.cc} (77%) create mode 100644 eval/public/containers/internal_field_backed_map_impl.h create mode 100644 eval/public/containers/internal_field_backed_map_impl_test.cc diff --git a/eval/public/containers/BUILD b/eval/public/containers/BUILD index 7d6bb4b74..3eb5effe6 100644 --- a/eval/public/containers/BUILD +++ b/eval/public/containers/BUILD @@ -89,18 +89,14 @@ cc_library( cc_library( name = "field_backed_map_impl", - srcs = [ - "field_backed_map_impl.cc", - ], hdrs = [ "field_backed_map_impl.h", ], deps = [ - ":field_access", + ":internal_field_backed_map_impl", "//eval/public:cel_value", - "@com_google_absl//absl/status", + "//eval/public/structs:cel_proto_wrapper", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", ], ) @@ -196,3 +192,40 @@ cc_test( "//testutil:util", ], ) + +cc_library( + name = "internal_field_backed_map_impl", + srcs = [ + "internal_field_backed_map_impl.cc", + ], + hdrs = [ + "internal_field_backed_map_impl.h", + ], + deps = [ + ":field_access", + "//eval/public:cel_value", + "//eval/public/structs:field_access_impl", + "//eval/public/structs:protobuf_value_factory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "internal_field_backed_map_impl_test", + size = "small", + srcs = [ + "internal_field_backed_map_impl_test.cc", + ], + visibility = [":cel_internal"], + deps = [ + ":internal_field_backed_map_impl", + "//eval/public/structs:cel_proto_wrapper", + "//eval/testutil:test_message_cc_proto", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) diff --git a/eval/public/containers/field_backed_map_impl.h b/eval/public/containers/field_backed_map_impl.h index 1ceb51185..8d8ded8b9 100644 --- a/eval/public/containers/field_backed_map_impl.h +++ b/eval/public/containers/field_backed_map_impl.h @@ -5,60 +5,31 @@ #include "google/protobuf/message.h" #include "absl/status/statusor.h" #include "eval/public/cel_value.h" +#include "eval/public/containers/internal_field_backed_map_impl.h" +#include "eval/public/structs/cel_proto_wrapper.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { // CelMap implementation that uses "map" message field // as backing storage. -class FieldBackedMapImpl : public CelMap { +// +// Trivial subclass of internal implementation to avoid API changes for clients +// that use this directly. +class FieldBackedMapImpl : public internal::FieldBackedMapImpl { public: // message contains the "map" field. Object stores the pointer // to the message, thus it is expected that message outlives the // object. // descriptor FieldDescriptor for the field + // arena is used for incidental allocations from unpacking the field. FieldBackedMapImpl(const google::protobuf::Message* message, const google::protobuf::FieldDescriptor* descriptor, - google::protobuf::Arena* arena); - - // Map size. - int size() const override; - - // Map element access operator. - absl::optional operator[](CelValue key) const override; - - // Presence test function. - absl::StatusOr Has(const CelValue& key) const override; - - const CelList* ListKeys() const override; - - protected: - // These methods are exposed as protected methods for testing purposes since - // whether one or the other is used depends on build time flags, but each - // should be tested accordingly. - - absl::StatusOr LookupMapValue( - const CelValue& key, google::protobuf::MapValueConstRef* value_ref) const; - - absl::StatusOr LegacyHasMapValue(const CelValue& key) const; - - absl::optional LegacyLookupMapValue(const CelValue& key) const; - - private: - const google::protobuf::Message* message_; - const google::protobuf::FieldDescriptor* descriptor_; - const google::protobuf::FieldDescriptor* key_desc_; - const google::protobuf::FieldDescriptor* value_desc_; - const google::protobuf::Reflection* reflection_; - google::protobuf::Arena* arena_; - std::unique_ptr key_list_; + google::protobuf::Arena* arena) + : internal::FieldBackedMapImpl( + message, descriptor, &CelProtoWrapper::InternalWrapMessage, arena) { + } }; -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_FIELD_BACKED_MAP_IMPL_H_ diff --git a/eval/public/containers/field_backed_map_impl_test.cc b/eval/public/containers/field_backed_map_impl_test.cc index b5b11a017..1cf711851 100644 --- a/eval/public/containers/field_backed_map_impl_test.cc +++ b/eval/public/containers/field_backed_map_impl_test.cc @@ -8,10 +8,7 @@ #include "eval/testutil/test_message.pb.h" #include "internal/testing.h" -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime { namespace { using testing::Eq; @@ -19,25 +16,14 @@ using testing::HasSubstr; using testing::UnorderedPointwise; using cel::internal::StatusIs; -class FieldBackedMapTestImpl : public FieldBackedMapImpl { - public: - FieldBackedMapTestImpl(const google::protobuf::Message* message, - const google::protobuf::FieldDescriptor* descriptor, - google::protobuf::Arena* arena) - : FieldBackedMapImpl(message, descriptor, arena) {} - - using FieldBackedMapImpl::LegacyHasMapValue; - using FieldBackedMapImpl::LegacyLookupMapValue; -}; - -// Helper method. Creates simple pipeline containing Select step and runs it. -std::unique_ptr CreateMap(const TestMessage* message, - const std::string& field, - google::protobuf::Arena* arena) { +// Test factory for FieldBackedMaps from message and field name. +std::unique_ptr CreateMap(const TestMessage* message, + const std::string& field, + google::protobuf::Arena* arena) { const google::protobuf::FieldDescriptor* field_desc = message->GetDescriptor()->FindFieldByName(field); - return absl::make_unique(message, field_desc, arena); + return absl::make_unique(message, field_desc, arena); } TEST(FieldBackedMapImplTest, BadKeyTypeTest) { @@ -56,7 +42,6 @@ TEST(FieldBackedMapImplTest, BadKeyTypeTest) { EXPECT_FALSE(result.ok()); EXPECT_THAT(result.status().code(), Eq(absl::StatusCode::kInvalidArgument)); - result = cel_map->LegacyHasMapValue(CelValue::CreateNull()); EXPECT_FALSE(result.ok()); EXPECT_THAT(result.status().code(), Eq(absl::StatusCode::kInvalidArgument)); @@ -65,12 +50,6 @@ TEST(FieldBackedMapImplTest, BadKeyTypeTest) { EXPECT_TRUE(lookup->IsError()); EXPECT_THAT(lookup->ErrorOrDie()->code(), Eq(absl::StatusCode::kInvalidArgument)); - - lookup = cel_map->LegacyLookupMapValue(CelValue::CreateNull()); - EXPECT_TRUE(lookup.has_value()); - EXPECT_TRUE(lookup->IsError()); - EXPECT_THAT(lookup->ErrorOrDie()->code(), - Eq(absl::StatusCode::kInvalidArgument)); } } @@ -86,14 +65,10 @@ TEST(FieldBackedMapImplTest, Int32KeyTest) { EXPECT_EQ((*cel_map)[CelValue::CreateInt64(0)]->Int64OrDie(), 1); EXPECT_EQ((*cel_map)[CelValue::CreateInt64(1)]->Int64OrDie(), 2); EXPECT_TRUE(cel_map->Has(CelValue::CreateInt64(1)).value_or(false)); - EXPECT_TRUE( - cel_map->LegacyHasMapValue(CelValue::CreateInt64(1)).value_or(false)); // Look up nonexistent key EXPECT_FALSE((*cel_map)[CelValue::CreateInt64(3)].has_value()); EXPECT_FALSE(cel_map->Has(CelValue::CreateInt64(3)).value_or(true)); - EXPECT_FALSE( - cel_map->LegacyHasMapValue(CelValue::CreateInt64(3)).value_or(true)); } TEST(FieldBackedMapImplTest, Int32KeyOutOfRangeTest) { @@ -125,10 +100,6 @@ TEST(FieldBackedMapImplTest, Int64KeyTest) { EXPECT_EQ((*cel_map)[CelValue::CreateInt64(0)]->Int64OrDie(), 1); EXPECT_EQ((*cel_map)[CelValue::CreateInt64(1)]->Int64OrDie(), 2); EXPECT_TRUE(cel_map->Has(CelValue::CreateInt64(1)).value_or(false)); - EXPECT_EQ( - cel_map->LegacyLookupMapValue(CelValue::CreateInt64(1))->Int64OrDie(), 2); - EXPECT_TRUE( - cel_map->LegacyHasMapValue(CelValue::CreateInt64(1)).value_or(false)); // Look up nonexistent key EXPECT_EQ((*cel_map)[CelValue::CreateInt64(3)].has_value(), false); @@ -144,8 +115,6 @@ TEST(FieldBackedMapImplTest, BoolKeyTest) { EXPECT_EQ((*cel_map)[CelValue::CreateBool(false)]->Int64OrDie(), 1); EXPECT_TRUE(cel_map->Has(CelValue::CreateBool(false)).value_or(false)); - EXPECT_TRUE( - cel_map->LegacyHasMapValue(CelValue::CreateBool(false)).value_or(false)); // Look up nonexistent key EXPECT_EQ((*cel_map)[CelValue::CreateBool(true)].has_value(), false); @@ -165,8 +134,6 @@ TEST(FieldBackedMapImplTest, Uint32KeyTest) { EXPECT_EQ((*cel_map)[CelValue::CreateUint64(0)]->Uint64OrDie(), 1UL); EXPECT_EQ((*cel_map)[CelValue::CreateUint64(1)]->Uint64OrDie(), 2UL); EXPECT_TRUE(cel_map->Has(CelValue::CreateUint64(1)).value_or(false)); - EXPECT_TRUE( - cel_map->LegacyHasMapValue(CelValue::CreateUint64(1)).value_or(false)); // Look up nonexistent key EXPECT_EQ((*cel_map)[CelValue::CreateUint64(3)].has_value(), false); @@ -197,8 +164,6 @@ TEST(FieldBackedMapImplTest, Uint64KeyTest) { EXPECT_EQ((*cel_map)[CelValue::CreateUint64(0)]->Int64OrDie(), 1); EXPECT_EQ((*cel_map)[CelValue::CreateUint64(1)]->Int64OrDie(), 2); EXPECT_TRUE(cel_map->Has(CelValue::CreateUint64(1)).value_or(false)); - EXPECT_TRUE( - cel_map->LegacyHasMapValue(CelValue::CreateUint64(1)).value_or(false)); // Look up nonexistent key EXPECT_EQ((*cel_map)[CelValue::CreateUint64(3)].has_value(), false); @@ -220,8 +185,6 @@ TEST(FieldBackedMapImplTest, StringKeyTest) { EXPECT_EQ((*cel_map)[CelValue::CreateString(&test0)]->Int64OrDie(), 1); EXPECT_EQ((*cel_map)[CelValue::CreateString(&test1)]->Int64OrDie(), 2); EXPECT_TRUE(cel_map->Has(CelValue::CreateString(&test1)).value_or(false)); - EXPECT_TRUE(cel_map->LegacyHasMapValue(CelValue::CreateString(&test1)) - .value_or(false)); // Look up nonexistent key EXPECT_EQ((*cel_map)[CelValue::CreateString(&test_notfound)].has_value(), @@ -271,7 +234,4 @@ TEST(FieldBackedMapImplTest, KeyListTest) { } } // namespace -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime diff --git a/eval/public/containers/field_backed_map_impl.cc b/eval/public/containers/internal_field_backed_map_impl.cc similarity index 77% rename from eval/public/containers/field_backed_map_impl.cc rename to eval/public/containers/internal_field_backed_map_impl.cc index 7f7460f99..2c837f64d 100644 --- a/eval/public/containers/field_backed_map_impl.cc +++ b/eval/public/containers/internal_field_backed_map_impl.cc @@ -1,6 +1,21 @@ -#include "eval/public/containers/field_backed_map_impl.h" +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/containers/internal_field_backed_map_impl.h" #include +#include #include "google/protobuf/descriptor.h" #include "google/protobuf/map_field.h" @@ -10,12 +25,12 @@ #include "absl/strings/str_cat.h" #include "eval/public/cel_value.h" #include "eval/public/containers/field_access.h" +#include "eval/public/structs/field_access_impl.h" +#include "eval/public/structs/protobuf_value_factory.h" #ifdef GOOGLE_PROTOBUF_HAS_CEL_MAP_REFLECTION_FRIEND -namespace google { -namespace protobuf { -namespace expr { +namespace google::protobuf::expr { // CelMapReflectionFriend provides access to Reflection's private methods. The // class is a friend of google::protobuf::Reflection. We do not add FieldBackedMapImpl as @@ -32,16 +47,11 @@ class CelMapReflectionFriend { } }; -} // namespace expr -} // namespace protobuf -} // namespace google +} // namespace google::protobuf::expr #endif // GOOGLE_PROTOBUF_HAS_CEL_MAP_REFLECTION_FRIEND -namespace google { -namespace api { -namespace expr { -namespace runtime { +namespace google::api::expr::runtime::internal { namespace { using google::protobuf::Descriptor; @@ -60,10 +70,12 @@ class KeyList : public CelList { // message contains the "repeated" field // descriptor FieldDescriptor for the field KeyList(const google::protobuf::Message* message, - const google::protobuf::FieldDescriptor* descriptor, google::protobuf::Arena* arena) + const google::protobuf::FieldDescriptor* descriptor, + const ProtobufValueFactory& factory, google::protobuf::Arena* arena) : message_(message), descriptor_(descriptor), reflection_(message_->GetReflection()), + factory_(factory), arena_(arena) {} // List size. @@ -73,7 +85,6 @@ class KeyList : public CelList { // List element access operator. CelValue operator[](int index) const override { - CelValue key = CelValue::CreateNull(); const Message* entry = &reflection_->GetRepeatedMessage(*message_, descriptor_, index); @@ -86,17 +97,20 @@ class KeyList : public CelList { const FieldDescriptor* key_desc = entry_descriptor->FindFieldByNumber(kKeyTag); - auto status = CreateValueFromSingleField(entry, key_desc, arena_, &key); - if (!status.ok()) { - return CreateErrorValue(arena_, status); + absl::StatusOr key_value = CreateValueFromSingleField( + entry, key_desc, ProtoWrapperTypeOptions::kUnsetProtoDefault, factory_, + arena_); + if (!key_value.ok()) { + return CreateErrorValue(arena_, key_value.status()); } - return key; + return *key_value; } private: const google::protobuf::Message* message_; const google::protobuf::FieldDescriptor* descriptor_; const google::protobuf::Reflection* reflection_; + const ProtobufValueFactory& factory_; google::protobuf::Arena* arena_; }; @@ -128,14 +142,16 @@ absl::Status InvalidMapKeyType(absl::string_view key_type) { FieldBackedMapImpl::FieldBackedMapImpl( const google::protobuf::Message* message, const google::protobuf::FieldDescriptor* descriptor, - google::protobuf::Arena* arena) + ProtobufValueFactory factory, google::protobuf::Arena* arena) : message_(message), descriptor_(descriptor), key_desc_(descriptor_->message_type()->FindFieldByNumber(kKeyTag)), value_desc_(descriptor_->message_type()->FindFieldByNumber(kValueTag)), reflection_(message_->GetReflection()), + factory_(std::move(factory)), arena_(arena), - key_list_(absl::make_unique(message, descriptor, arena)) {} + key_list_( + absl::make_unique(message, descriptor, factory_, arena)) {} int FieldBackedMapImpl::size() const { return reflection_->FieldSize(*message_, descriptor_); @@ -168,13 +184,12 @@ absl::optional FieldBackedMapImpl::operator[](CelValue key) const { // Get value descriptor treating it as a repeated field. // All values in protobuf map have the same type. // The map is not empty, because LookupMapValue returned true. - CelValue result = CelValue::CreateNull(); - const auto& status = CreateValueFromMapValue(message_, value_desc_, - &value_ref, arena_, &result); - if (!status.ok()) { - return CreateErrorValue(arena_, status); + absl::StatusOr result = CreateValueFromMapValue( + message_, value_desc_, &value_ref, factory_, arena_); + if (!result.ok()) { + return CreateErrorValue(arena_, result.status()); } - return result; + return *result; #else // GOOGLE_PROTOBUF_HAS_CEL_MAP_REFLECTION_FRIEND // Default proto implementation, does not use fast-path key lookup. @@ -262,7 +277,6 @@ absl::optional FieldBackedMapImpl::LegacyLookupMapValue( InvalidMapKeyType(key_desc_->cpp_type_name())); } - CelValue proto_key = CelValue::CreateNull(); int map_size = size(); for (int i = 0; i < map_size; i++) { const Message* entry = @@ -270,29 +284,30 @@ absl::optional FieldBackedMapImpl::LegacyLookupMapValue( if (entry == nullptr) continue; // Key Tag == 1 - auto status = - CreateValueFromSingleField(entry, key_desc_, arena_, &proto_key); - if (!status.ok()) { - return CreateErrorValue(arena_, status); + absl::StatusOr key_value = CreateValueFromSingleField( + entry, key_desc_, ProtoWrapperTypeOptions::kUnsetProtoDefault, factory_, + arena_); + if (!key_value.ok()) { + return CreateErrorValue(arena_, key_value.status()); } bool match = false; switch (key_desc_->cpp_type()) { case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: - match = key.BoolOrDie() == proto_key.BoolOrDie(); + match = key.BoolOrDie() == key_value->BoolOrDie(); break; case google::protobuf::FieldDescriptor::CPPTYPE_INT32: // fall through case google::protobuf::FieldDescriptor::CPPTYPE_INT64: - match = key.Int64OrDie() == proto_key.Int64OrDie(); + match = key.Int64OrDie() == key_value->Int64OrDie(); break; case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: // fall through case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: - match = key.Uint64OrDie() == proto_key.Uint64OrDie(); + match = key.Uint64OrDie() == key_value->Uint64OrDie(); break; case google::protobuf::FieldDescriptor::CPPTYPE_STRING: - match = key.StringOrDie() == proto_key.StringOrDie(); + match = key.StringOrDie() == key_value->StringOrDie(); break; default: // this would normally indicate a bad key type, which should not be @@ -301,19 +316,16 @@ absl::optional FieldBackedMapImpl::LegacyLookupMapValue( } if (match) { - CelValue result = CelValue::CreateNull(); - auto status = - CreateValueFromSingleField(entry, value_desc_, arena_, &result); - if (!status.ok()) { - return CreateErrorValue(arena_, status); + absl::StatusOr value_cel_value = CreateValueFromSingleField( + entry, value_desc_, ProtoWrapperTypeOptions::kUnsetProtoDefault, + factory_, arena_); + if (!value_cel_value.ok()) { + return CreateErrorValue(arena_, value_cel_value.status()); } - return result; + return *value_cel_value; } } return {}; } -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google +} // namespace google::api::expr::runtime::internal diff --git a/eval/public/containers/internal_field_backed_map_impl.h b/eval/public/containers/internal_field_backed_map_impl.h new file mode 100644 index 000000000..ae43a5e4c --- /dev/null +++ b/eval/public/containers/internal_field_backed_map_impl.h @@ -0,0 +1,73 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_INTERNAL_FIELD_BACKED_MAP_IMPL_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_INTERNAL_FIELD_BACKED_MAP_IMPL_H_ + +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "absl/status/statusor.h" +#include "eval/public/cel_value.h" +#include "eval/public/structs/protobuf_value_factory.h" + +namespace google::api::expr::runtime::internal { +// CelMap implementation that uses "map" message field +// as backing storage. +class FieldBackedMapImpl : public CelMap { + public: + // message contains the "map" field. Object stores the pointer + // to the message, thus it is expected that message outlives the + // object. + // descriptor FieldDescriptor for the field + FieldBackedMapImpl(const google::protobuf::Message* message, + const google::protobuf::FieldDescriptor* descriptor, + ProtobufValueFactory factory, google::protobuf::Arena* arena); + + // Map size. + int size() const override; + + // Map element access operator. + absl::optional operator[](CelValue key) const override; + + // Presence test function. + absl::StatusOr Has(const CelValue& key) const override; + + const CelList* ListKeys() const override; + + protected: + // These methods are exposed as protected methods for testing purposes since + // whether one or the other is used depends on build time flags, but each + // should be tested accordingly. + + absl::StatusOr LookupMapValue( + const CelValue& key, google::protobuf::MapValueConstRef* value_ref) const; + + absl::StatusOr LegacyHasMapValue(const CelValue& key) const; + + absl::optional LegacyLookupMapValue(const CelValue& key) const; + + private: + const google::protobuf::Message* message_; + const google::protobuf::FieldDescriptor* descriptor_; + const google::protobuf::FieldDescriptor* key_desc_; + const google::protobuf::FieldDescriptor* value_desc_; + const google::protobuf::Reflection* reflection_; + ProtobufValueFactory factory_; + google::protobuf::Arena* arena_; + std::unique_ptr key_list_; +}; + +} // namespace google::api::expr::runtime::internal + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_INTERNAL_FIELD_BACKED_MAP_IMPL_H_ diff --git a/eval/public/containers/internal_field_backed_map_impl_test.cc b/eval/public/containers/internal_field_backed_map_impl_test.cc new file mode 100644 index 000000000..392b84f35 --- /dev/null +++ b/eval/public/containers/internal_field_backed_map_impl_test.cc @@ -0,0 +1,288 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "eval/public/containers/internal_field_backed_map_impl.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "eval/public/structs/cel_proto_wrapper.h" +#include "eval/testutil/test_message.pb.h" +#include "internal/testing.h" + +namespace google::api::expr::runtime::internal { +namespace { + +using testing::Eq; +using testing::HasSubstr; +using testing::UnorderedPointwise; +using cel::internal::StatusIs; + +class FieldBackedMapTestImpl : public FieldBackedMapImpl { + public: + FieldBackedMapTestImpl(const google::protobuf::Message* message, + const google::protobuf::FieldDescriptor* descriptor, + google::protobuf::Arena* arena) + : FieldBackedMapImpl(message, descriptor, + &CelProtoWrapper::InternalWrapMessage, arena) {} + + // For code coverage, expose fallback lookups used when not compiled with + // support for optimized versions. + using FieldBackedMapImpl::LegacyHasMapValue; + using FieldBackedMapImpl::LegacyLookupMapValue; +}; + +// Helper method. Creates simple pipeline containing Select step and runs it. +std::unique_ptr CreateMap(const TestMessage* message, + const std::string& field, + google::protobuf::Arena* arena) { + const google::protobuf::FieldDescriptor* field_desc = + message->GetDescriptor()->FindFieldByName(field); + + return absl::make_unique(message, field_desc, arena); +} + +TEST(FieldBackedMapImplTest, BadKeyTypeTest) { + TestMessage message; + google::protobuf::Arena arena; + constexpr std::array map_types = { + "int64_int32_map", "uint64_int32_map", "string_int32_map", + "bool_int32_map", "int32_int32_map", "uint32_uint32_map", + }; + + for (auto map_type : map_types) { + auto cel_map = CreateMap(&message, std::string(map_type), &arena); + // Look up a boolean key. This should result in an error for both the + // presence test and the value lookup. + auto result = cel_map->Has(CelValue::CreateNull()); + EXPECT_FALSE(result.ok()); + EXPECT_THAT(result.status().code(), Eq(absl::StatusCode::kInvalidArgument)); + + result = cel_map->LegacyHasMapValue(CelValue::CreateNull()); + EXPECT_FALSE(result.ok()); + EXPECT_THAT(result.status().code(), Eq(absl::StatusCode::kInvalidArgument)); + + auto lookup = (*cel_map)[CelValue::CreateNull()]; + EXPECT_TRUE(lookup.has_value()); + EXPECT_TRUE(lookup->IsError()); + EXPECT_THAT(lookup->ErrorOrDie()->code(), + Eq(absl::StatusCode::kInvalidArgument)); + + lookup = cel_map->LegacyLookupMapValue(CelValue::CreateNull()); + EXPECT_TRUE(lookup.has_value()); + EXPECT_TRUE(lookup->IsError()); + EXPECT_THAT(lookup->ErrorOrDie()->code(), + Eq(absl::StatusCode::kInvalidArgument)); + } +} + +TEST(FieldBackedMapImplTest, Int32KeyTest) { + TestMessage message; + auto field_map = message.mutable_int32_int32_map(); + (*field_map)[0] = 1; + (*field_map)[1] = 2; + + google::protobuf::Arena arena; + auto cel_map = CreateMap(&message, "int32_int32_map", &arena); + + EXPECT_EQ((*cel_map)[CelValue::CreateInt64(0)]->Int64OrDie(), 1); + EXPECT_EQ((*cel_map)[CelValue::CreateInt64(1)]->Int64OrDie(), 2); + EXPECT_TRUE(cel_map->Has(CelValue::CreateInt64(1)).value_or(false)); + EXPECT_TRUE( + cel_map->LegacyHasMapValue(CelValue::CreateInt64(1)).value_or(false)); + + // Look up nonexistent key + EXPECT_FALSE((*cel_map)[CelValue::CreateInt64(3)].has_value()); + EXPECT_FALSE(cel_map->Has(CelValue::CreateInt64(3)).value_or(true)); + EXPECT_FALSE( + cel_map->LegacyHasMapValue(CelValue::CreateInt64(3)).value_or(true)); +} + +TEST(FieldBackedMapImplTest, Int32KeyOutOfRangeTest) { + TestMessage message; + google::protobuf::Arena arena; + auto cel_map = CreateMap(&message, "int32_int32_map", &arena); + + // Look up keys out of int32_t range + auto result = cel_map->Has( + CelValue::CreateInt64(std::numeric_limits::max() + 1L)); + EXPECT_THAT(result.status(), + StatusIs(absl::StatusCode::kOutOfRange, HasSubstr("overflow"))); + + result = cel_map->Has( + CelValue::CreateInt64(std::numeric_limits::lowest() - 1L)); + EXPECT_FALSE(result.ok()); + EXPECT_THAT(result.status().code(), Eq(absl::StatusCode::kOutOfRange)); +} + +TEST(FieldBackedMapImplTest, Int64KeyTest) { + TestMessage message; + auto field_map = message.mutable_int64_int32_map(); + (*field_map)[0] = 1; + (*field_map)[1] = 2; + + google::protobuf::Arena arena; + auto cel_map = CreateMap(&message, "int64_int32_map", &arena); + + EXPECT_EQ((*cel_map)[CelValue::CreateInt64(0)]->Int64OrDie(), 1); + EXPECT_EQ((*cel_map)[CelValue::CreateInt64(1)]->Int64OrDie(), 2); + EXPECT_TRUE(cel_map->Has(CelValue::CreateInt64(1)).value_or(false)); + EXPECT_EQ( + cel_map->LegacyLookupMapValue(CelValue::CreateInt64(1))->Int64OrDie(), 2); + EXPECT_TRUE( + cel_map->LegacyHasMapValue(CelValue::CreateInt64(1)).value_or(false)); + + // Look up nonexistent key + EXPECT_EQ((*cel_map)[CelValue::CreateInt64(3)].has_value(), false); +} + +TEST(FieldBackedMapImplTest, BoolKeyTest) { + TestMessage message; + auto field_map = message.mutable_bool_int32_map(); + (*field_map)[false] = 1; + + google::protobuf::Arena arena; + auto cel_map = CreateMap(&message, "bool_int32_map", &arena); + + EXPECT_EQ((*cel_map)[CelValue::CreateBool(false)]->Int64OrDie(), 1); + EXPECT_TRUE(cel_map->Has(CelValue::CreateBool(false)).value_or(false)); + EXPECT_TRUE( + cel_map->LegacyHasMapValue(CelValue::CreateBool(false)).value_or(false)); + // Look up nonexistent key + EXPECT_EQ((*cel_map)[CelValue::CreateBool(true)].has_value(), false); + + (*field_map)[true] = 2; + EXPECT_EQ((*cel_map)[CelValue::CreateBool(true)]->Int64OrDie(), 2); +} + +TEST(FieldBackedMapImplTest, Uint32KeyTest) { + TestMessage message; + auto field_map = message.mutable_uint32_uint32_map(); + (*field_map)[0] = 1u; + (*field_map)[1] = 2u; + + google::protobuf::Arena arena; + auto cel_map = CreateMap(&message, "uint32_uint32_map", &arena); + + EXPECT_EQ((*cel_map)[CelValue::CreateUint64(0)]->Uint64OrDie(), 1UL); + EXPECT_EQ((*cel_map)[CelValue::CreateUint64(1)]->Uint64OrDie(), 2UL); + EXPECT_TRUE(cel_map->Has(CelValue::CreateUint64(1)).value_or(false)); + EXPECT_TRUE( + cel_map->LegacyHasMapValue(CelValue::CreateUint64(1)).value_or(false)); + + // Look up nonexistent key + EXPECT_EQ((*cel_map)[CelValue::CreateUint64(3)].has_value(), false); + EXPECT_EQ(cel_map->Has(CelValue::CreateUint64(3)).value_or(true), false); +} + +TEST(FieldBackedMapImplTest, Uint32KeyOutOfRangeTest) { + TestMessage message; + google::protobuf::Arena arena; + auto cel_map = CreateMap(&message, "uint32_uint32_map", &arena); + + // Look up keys out of uint32_t range + auto result = cel_map->Has( + CelValue::CreateUint64(std::numeric_limits::max() + 1UL)); + EXPECT_FALSE(result.ok()); + EXPECT_THAT(result.status().code(), Eq(absl::StatusCode::kOutOfRange)); +} + +TEST(FieldBackedMapImplTest, Uint64KeyTest) { + TestMessage message; + auto field_map = message.mutable_uint64_int32_map(); + (*field_map)[0] = 1; + (*field_map)[1] = 2; + + google::protobuf::Arena arena; + auto cel_map = CreateMap(&message, "uint64_int32_map", &arena); + + EXPECT_EQ((*cel_map)[CelValue::CreateUint64(0)]->Int64OrDie(), 1); + EXPECT_EQ((*cel_map)[CelValue::CreateUint64(1)]->Int64OrDie(), 2); + EXPECT_TRUE(cel_map->Has(CelValue::CreateUint64(1)).value_or(false)); + EXPECT_TRUE( + cel_map->LegacyHasMapValue(CelValue::CreateUint64(1)).value_or(false)); + + // Look up nonexistent key + EXPECT_EQ((*cel_map)[CelValue::CreateUint64(3)].has_value(), false); +} + +TEST(FieldBackedMapImplTest, StringKeyTest) { + TestMessage message; + auto field_map = message.mutable_string_int32_map(); + (*field_map)["test0"] = 1; + (*field_map)["test1"] = 2; + + google::protobuf::Arena arena; + auto cel_map = CreateMap(&message, "string_int32_map", &arena); + + std::string test0 = "test0"; + std::string test1 = "test1"; + std::string test_notfound = "test_notfound"; + + EXPECT_EQ((*cel_map)[CelValue::CreateString(&test0)]->Int64OrDie(), 1); + EXPECT_EQ((*cel_map)[CelValue::CreateString(&test1)]->Int64OrDie(), 2); + EXPECT_TRUE(cel_map->Has(CelValue::CreateString(&test1)).value_or(false)); + EXPECT_TRUE(cel_map->LegacyHasMapValue(CelValue::CreateString(&test1)) + .value_or(false)); + + // Look up nonexistent key + EXPECT_EQ((*cel_map)[CelValue::CreateString(&test_notfound)].has_value(), + false); +} + +TEST(FieldBackedMapImplTest, EmptySizeTest) { + TestMessage message; + google::protobuf::Arena arena; + auto cel_map = CreateMap(&message, "string_int32_map", &arena); + EXPECT_EQ(cel_map->size(), 0); +} + +TEST(FieldBackedMapImplTest, RepeatedAddTest) { + TestMessage message; + auto field_map = message.mutable_string_int32_map(); + (*field_map)["test0"] = 1; + (*field_map)["test1"] = 2; + (*field_map)["test0"] = 3; + + google::protobuf::Arena arena; + auto cel_map = CreateMap(&message, "string_int32_map", &arena); + + EXPECT_EQ(cel_map->size(), 2); +} + +TEST(FieldBackedMapImplTest, KeyListTest) { + TestMessage message; + auto field_map = message.mutable_string_int32_map(); + std::vector keys; + std::vector keys1; + for (int i = 0; i < 100; i++) { + keys.push_back(absl::StrCat("test", i)); + (*field_map)[keys.back()] = i; + } + + google::protobuf::Arena arena; + auto cel_map = CreateMap(&message, "string_int32_map", &arena); + const CelList* key_list = cel_map->ListKeys(); + + EXPECT_EQ(key_list->size(), 100); + for (int i = 0; i < key_list->size(); i++) { + keys1.push_back(std::string((*key_list)[i].StringOrDie().value())); + } + + EXPECT_THAT(keys, UnorderedPointwise(Eq(), keys1)); +} + +} // namespace +} // namespace google::api::expr::runtime::internal From 6a2eaa3e5791adc471f256436e053d4666f9efc3 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 13 Apr 2022 23:03:10 +0000 Subject: [PATCH 069/111] Update internal factory using definition to be a function ptr instead of std::function. PiperOrigin-RevId: 441603329 --- eval/public/structs/protobuf_value_factory.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eval/public/structs/protobuf_value_factory.h b/eval/public/structs/protobuf_value_factory.h index 7d4223411..59874daec 100644 --- a/eval/public/structs/protobuf_value_factory.h +++ b/eval/public/structs/protobuf_value_factory.h @@ -30,7 +30,7 @@ namespace google::api::expr::runtime::internal { // // Used to break cyclic dependency between field access and message wrapping -- // not intended for general use. -using ProtobufValueFactory = std::function; +using ProtobufValueFactory = CelValue (*)(const google::protobuf::Message*); } // namespace google::api::expr::runtime::internal #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_PROTOBUF_VALUE_FACTORY_H_ From 06045a168c90aeb5bb10423b15ce3888c79d5987 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 13 Apr 2022 23:03:39 +0000 Subject: [PATCH 070/111] Update legacy type APIs to operate on the message wrapper type instead of directly on a CelValue. PiperOrigin-RevId: 441603441 --- eval/eval/create_struct_step.cc | 7 +- eval/public/structs/BUILD | 5 ++ eval/public/structs/legacy_type_adapter.h | 18 ++-- .../structs/legacy_type_adapter_test.cc | 16 +++- .../structs/proto_message_type_adapter.cc | 47 +++++----- .../structs/proto_message_type_adapter.h | 16 ++-- .../proto_message_type_adapter_test.cc | 90 ++++++++++--------- .../protobuf_descriptor_type_provider_test.cc | 8 +- 8 files changed, 120 insertions(+), 87 deletions(-) diff --git a/eval/eval/create_struct_step.cc b/eval/eval/create_struct_step.cc index 3328953e4..03caf078d 100644 --- a/eval/eval/create_struct_step.cc +++ b/eval/eval/create_struct_step.cc @@ -71,7 +71,7 @@ absl::Status CreateStructStepForMessage::DoEvaluate(ExecutionFrame* frame, } } - CEL_ASSIGN_OR_RETURN(CelValue instance, + CEL_ASSIGN_OR_RETURN(CelValue::MessageWrapper instance, type_adapter_->NewInstance(frame->memory_manager())); int index = 0; @@ -82,9 +82,8 @@ absl::Status CreateStructStepForMessage::DoEvaluate(ExecutionFrame* frame, entry.field_name, arg, frame->memory_manager(), instance)); } - CEL_RETURN_IF_ERROR( - type_adapter_->AdaptFromWellKnownType(frame->memory_manager(), instance)); - *result = instance; + CEL_ASSIGN_OR_RETURN(*result, type_adapter_->AdaptFromWellKnownType( + frame->memory_manager(), instance)); return absl::OkStatus(); } diff --git a/eval/public/structs/BUILD b/eval/public/structs/BUILD index 02360746b..86c6cbd41 100644 --- a/eval/public/structs/BUILD +++ b/eval/public/structs/BUILD @@ -207,6 +207,8 @@ cc_test( deps = [ ":legacy_type_adapter", "//eval/public:cel_value", + "//eval/public/testing:matchers", + "//eval/testutil:test_message_cc_proto", "//extensions/protobuf:memory_manager", "//internal:status_macros", "//internal:testing", @@ -223,10 +225,12 @@ cc_library( ":legacy_type_adapter", "//base:memory_manager", "//eval/public:cel_value", + "//eval/public:cel_value_internal", "//eval/public/containers:field_access", "//eval/public/containers:field_backed_list_impl", "//eval/public/containers:field_backed_map_impl", "//extensions/protobuf:memory_manager", + "//internal:casts", "//internal:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", @@ -241,6 +245,7 @@ cc_test( ":cel_proto_wrapper", ":proto_message_type_adapter", "//eval/public:cel_value", + "//eval/public:cel_value_internal", "//eval/public/containers:container_backed_list_impl", "//eval/public/containers:container_backed_map_impl", "//eval/public/containers:field_access", diff --git a/eval/public/structs/legacy_type_adapter.h b/eval/public/structs/legacy_type_adapter.h index 237b92b77..872c17ee0 100644 --- a/eval/public/structs/legacy_type_adapter.h +++ b/eval/public/structs/legacy_type_adapter.h @@ -38,16 +38,17 @@ class LegacyTypeMutationApis { // Create a new empty instance of the type. // May return a status if the type is not possible to create. - virtual absl::StatusOr NewInstance( + virtual absl::StatusOr NewInstance( cel::MemoryManager& memory_manager) const = 0; // Normalize special types to a native CEL value after building. // The default implementation is a no-op. // The interpreter guarantees that instance is uniquely owned by the // interpreter, and can be safely mutated. - virtual absl::Status AdaptFromWellKnownType( - cel::MemoryManager& memory_manager, CelValue& instance) const { - return absl::OkStatus(); + virtual absl::StatusOr AdaptFromWellKnownType( + cel::MemoryManager& memory_manager, + CelValue::MessageWrapper instance) const { + return CelValue::CreateMessageWrapper(instance); } // Set field on instance to value. @@ -56,7 +57,7 @@ class LegacyTypeMutationApis { virtual absl::Status SetField(absl::string_view field_name, const CelValue& value, cel::MemoryManager& memory_manager, - CelValue& instance) const = 0; + CelValue::MessageWrapper& instance) const = 0; }; // Interface for access apis. @@ -68,12 +69,13 @@ class LegacyTypeAccessApis { // Return whether an instance of the type has field set to a non-default // value. - virtual absl::StatusOr HasField(absl::string_view field_name, - const CelValue& value) const = 0; + virtual absl::StatusOr HasField( + absl::string_view field_name, + const CelValue::MessageWrapper& value) const = 0; // Access field on instance. virtual absl::StatusOr GetField( - absl::string_view field_name, const CelValue& instance, + absl::string_view field_name, const CelValue::MessageWrapper& instance, cel::MemoryManager& memory_manager) const = 0; }; diff --git a/eval/public/structs/legacy_type_adapter_test.cc b/eval/public/structs/legacy_type_adapter_test.cc index ac2cc53cb..b6fe9a7f5 100644 --- a/eval/public/structs/legacy_type_adapter_test.cc +++ b/eval/public/structs/legacy_type_adapter_test.cc @@ -16,12 +16,15 @@ #include "google/protobuf/arena.h" #include "eval/public/cel_value.h" +#include "eval/public/testing/matchers.h" +#include "eval/testutil/test_message.pb.h" #include "extensions/protobuf/memory_manager.h" #include "internal/status_macros.h" #include "internal/testing.h" namespace google::api::expr::runtime { namespace { +using testing::EqualsProto; class TestMutationApiImpl : public LegacyTypeMutationApis { public: @@ -30,26 +33,31 @@ class TestMutationApiImpl : public LegacyTypeMutationApis { return false; } - absl::StatusOr NewInstance( + absl::StatusOr NewInstance( cel::MemoryManager& memory_manager) const override { return absl::UnimplementedError("Not implemented"); } absl::Status SetField(absl::string_view field_name, const CelValue& value, cel::MemoryManager& memory_manager, - CelValue& instance) const override { + CelValue::MessageWrapper& instance) const override { return absl::UnimplementedError("Not implemented"); } }; TEST(LegacyTypeAdapterMutationApis, DefaultNoopAdapt) { - CelValue v; + TestMessage message; + internal::MessageWrapper wrapper(&message); google::protobuf::Arena arena; cel::extensions::ProtoMemoryManager manager(&arena); TestMutationApiImpl impl; - EXPECT_OK(impl.AdaptFromWellKnownType(manager, v)); + ASSERT_OK_AND_ASSIGN(CelValue v, + impl.AdaptFromWellKnownType(manager, wrapper)); + + EXPECT_THAT(v, + test::IsCelMessage(EqualsProto(TestMessage::default_instance()))); } } // namespace diff --git a/eval/public/structs/proto_message_type_adapter.cc b/eval/public/structs/proto_message_type_adapter.cc index d48213583..9e7abbd8f 100644 --- a/eval/public/structs/proto_message_type_adapter.cc +++ b/eval/public/structs/proto_message_type_adapter.cc @@ -21,11 +21,13 @@ #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "eval/public/cel_value.h" +#include "eval/public/cel_value_internal.h" #include "eval/public/containers/field_access.h" #include "eval/public/containers/field_backed_list_impl.h" #include "eval/public/containers/field_backed_map_impl.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "extensions/protobuf/memory_manager.h" +#include "internal/casts.h" #include "internal/status_macros.h" namespace google::api::expr::runtime { @@ -44,7 +46,7 @@ absl::Status ProtoMessageTypeAdapter::ValidateSetFieldOp( return absl::OkStatus(); } -absl::StatusOr ProtoMessageTypeAdapter::NewInstance( +absl::StatusOr ProtoMessageTypeAdapter::NewInstance( cel::MemoryManager& memory_manager) const { // This implementation requires arena-backed memory manager. google::protobuf::Arena* arena = ProtoMemoryManager::CastToProtoArena(memory_manager); @@ -56,7 +58,7 @@ absl::StatusOr ProtoMessageTypeAdapter::NewInstance( return absl::InvalidArgumentError( absl::StrCat("Failed to create message ", descriptor_->name())); } - return CelValue::CreateMessage(msg); + return CelValue::MessageWrapper(msg); } bool ProtoMessageTypeAdapter::DefinesField(absl::string_view field_name) const { @@ -64,11 +66,12 @@ bool ProtoMessageTypeAdapter::DefinesField(absl::string_view field_name) const { } absl::StatusOr ProtoMessageTypeAdapter::HasField( - absl::string_view field_name, const CelValue& value) const { - const google::protobuf::Message* message; - if (!value.GetValue(&message) || message == nullptr) { - return absl::InvalidArgumentError("HasField called on non-message type."); + absl::string_view field_name, const CelValue::MessageWrapper& value) const { + if (!value.HasFullProto() || value.message_ptr() == nullptr) { + return absl::InvalidArgumentError("GetField called on non-message type."); } + const google::protobuf::Message* message = + cel::internal::down_cast(value.message_ptr()); const Reflection* reflection = message->GetReflection(); ABSL_ASSERT(descriptor_ == message->GetDescriptor()); @@ -98,13 +101,13 @@ absl::StatusOr ProtoMessageTypeAdapter::HasField( } absl::StatusOr ProtoMessageTypeAdapter::GetField( - absl::string_view field_name, const CelValue& instance, + absl::string_view field_name, const CelValue::MessageWrapper& instance, cel::MemoryManager& memory_manager) const { - const google::protobuf::Message* message; - if (!instance.GetValue(&message) || message == nullptr) { + if (!instance.HasFullProto() || instance.message_ptr() == nullptr) { return absl::InvalidArgumentError("GetField called on non-message type."); } - + const google::protobuf::Message* message = + cel::internal::down_cast(instance.message_ptr()); const FieldDescriptor* field_desc = descriptor_->FindFieldByName(field_name.data()); if (field_desc == nullptr) { @@ -132,15 +135,19 @@ absl::StatusOr ProtoMessageTypeAdapter::GetField( absl::Status ProtoMessageTypeAdapter::SetField( absl::string_view field_name, const CelValue& value, - cel::MemoryManager& memory_manager, CelValue& instance) const { + cel::MemoryManager& memory_manager, + CelValue::MessageWrapper& instance) const { // Assume proto arena implementation if this provider is used. google::protobuf::Arena* arena = cel::extensions::ProtoMemoryManager::CastToProtoArena(memory_manager); - const google::protobuf::Message* message = nullptr; - if (!instance.GetValue(&message) || message == nullptr) { + + if (!instance.HasFullProto() || instance.message_ptr() == nullptr) { return absl::InternalError("SetField called on non-message type."); } + const google::protobuf::Message* message = + cel::internal::down_cast(instance.message_ptr()); + // Interpreter guarantees this is the top-level instance. google::protobuf::Message* mutable_message = const_cast(message); @@ -207,19 +214,19 @@ absl::Status ProtoMessageTypeAdapter::SetField( return absl::OkStatus(); } -absl::Status ProtoMessageTypeAdapter::AdaptFromWellKnownType( - cel::MemoryManager& memory_manager, CelValue& instance) const { +absl::StatusOr ProtoMessageTypeAdapter::AdaptFromWellKnownType( + cel::MemoryManager& memory_manager, + CelValue::MessageWrapper instance) const { // Assume proto arena implementation if this provider is used. google::protobuf::Arena* arena = cel::extensions::ProtoMemoryManager::CastToProtoArena(memory_manager); - const google::protobuf::Message* message; - if (!instance.GetValue(&message) || message == nullptr) { + if (!instance.HasFullProto() || instance.message_ptr() == nullptr) { return absl::InternalError( "Adapt from well-known type failed: not a message"); } - - instance = CelProtoWrapper::CreateMessage(message, arena); - return absl::OkStatus(); + auto* message = + cel::internal::down_cast(instance.message_ptr()); + return CelProtoWrapper::CreateMessage(message, arena); } } // namespace google::api::expr::runtime diff --git a/eval/public/structs/proto_message_type_adapter.h b/eval/public/structs/proto_message_type_adapter.h index 46cf54d65..42d466a70 100644 --- a/eval/public/structs/proto_message_type_adapter.h +++ b/eval/public/structs/proto_message_type_adapter.h @@ -39,7 +39,7 @@ class ProtoMessageTypeAdapter : public LegacyTypeAccessApis, ~ProtoMessageTypeAdapter() override = default; - absl::StatusOr NewInstance( + absl::StatusOr NewInstance( cel::MemoryManager& memory_manager) const override; bool DefinesField(absl::string_view field_name) const override; @@ -47,17 +47,19 @@ class ProtoMessageTypeAdapter : public LegacyTypeAccessApis, absl::Status SetField(absl::string_view field_name, const CelValue& value, cel::MemoryManager& memory_manager, - CelValue& instance) const override; + CelValue::MessageWrapper& instance) const override; - absl::Status AdaptFromWellKnownType(cel::MemoryManager& memory_manager, - CelValue& instance) const override; + absl::StatusOr AdaptFromWellKnownType( + cel::MemoryManager& memory_manager, + CelValue::MessageWrapper instance) const override; absl::StatusOr GetField( - absl::string_view field_name, const CelValue& instance, + absl::string_view field_name, const CelValue::MessageWrapper& instance, cel::MemoryManager& memory_manager) const override; - absl::StatusOr HasField(absl::string_view field_name, - const CelValue& value) const override; + absl::StatusOr HasField( + absl::string_view field_name, + const CelValue::MessageWrapper& value) const override; private: // Helper for standardizing error messages for SetField operation. diff --git a/eval/public/structs/proto_message_type_adapter_test.cc b/eval/public/structs/proto_message_type_adapter_test.cc index 90b734256..2089cbd68 100644 --- a/eval/public/structs/proto_message_type_adapter_test.cc +++ b/eval/public/structs/proto_message_type_adapter_test.cc @@ -18,8 +18,10 @@ #include "google/protobuf/descriptor.pb.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" +#include "google/protobuf/message_lite.h" #include "absl/status/status.h" #include "eval/public/cel_value.h" +#include "eval/public/cel_value_internal.h" #include "eval/public/containers/container_backed_list_impl.h" #include "eval/public/containers/container_backed_map_impl.h" #include "eval/public/containers/field_access.h" @@ -50,7 +52,7 @@ TEST(ProtoMessageTypeAdapter, HasFieldSingular) { TestMessage example; - CelValue value = CelProtoWrapper::CreateMessage(&example, &arena); + internal::MessageWrapper value(&example); EXPECT_THAT(adapter.HasField("int64_value", value), IsOkAndHolds(false)); example.set_int64_value(10); @@ -67,7 +69,7 @@ TEST(ProtoMessageTypeAdapter, HasFieldRepeated) { TestMessage example; - CelValue value = CelProtoWrapper::CreateMessage(&example, &arena); + internal::MessageWrapper value(&example); EXPECT_THAT(adapter.HasField("int64_list", value), IsOkAndHolds(false)); example.add_int64_list(10); @@ -85,7 +87,7 @@ TEST(ProtoMessageTypeAdapter, HasFieldMap) { TestMessage example; example.set_int64_value(10); - CelValue value = CelProtoWrapper::CreateMessage(&example, &arena); + internal::MessageWrapper value(&example); EXPECT_THAT(adapter.HasField("int64_int32_map", value), IsOkAndHolds(false)); (*example.mutable_int64_int32_map())[2] = 3; @@ -103,7 +105,7 @@ TEST(ProtoMessageTypeAdapter, HasFieldUnknownField) { TestMessage example; example.set_int64_value(10); - CelValue value = CelProtoWrapper::CreateMessage(&example, &arena); + internal::MessageWrapper value(&example); EXPECT_THAT(adapter.HasField("unknown_field", value), StatusIs(absl::StatusCode::kNotFound)); @@ -117,7 +119,8 @@ TEST(ProtoMessageTypeAdapter, HasFieldNonMessageType) { google::protobuf::MessageFactory::generated_factory(), ProtoWrapperTypeOptions::kUnsetNull); - CelValue value = CelValue::CreateInt64(10); + internal::MessageWrapper value( + static_cast(nullptr)); EXPECT_THAT(adapter.HasField("unknown_field", value), StatusIs(absl::StatusCode::kInvalidArgument)); @@ -135,7 +138,7 @@ TEST(ProtoMessageTypeAdapter, GetFieldSingular) { TestMessage example; example.set_int64_value(10); - CelValue value = CelProtoWrapper::CreateMessage(&example, &arena); + internal::MessageWrapper value(&example); EXPECT_THAT(adapter.GetField("int64_value", value, manager), IsOkAndHolds(test::IsCelInt64(10))); @@ -153,7 +156,7 @@ TEST(ProtoMessageTypeAdapter, GetFieldNoSuchField) { TestMessage example; example.set_int64_value(10); - CelValue value = CelProtoWrapper::CreateMessage(&example, &arena); + internal::MessageWrapper value(&example); EXPECT_THAT(adapter.GetField("unknown_field", value, manager), IsOkAndHolds(test::IsCelError(StatusIs( @@ -169,7 +172,8 @@ TEST(ProtoMessageTypeAdapter, GetFieldNotAMessage) { ProtoWrapperTypeOptions::kUnsetNull); cel::extensions::ProtoMemoryManager manager(&arena); - CelValue value = CelValue::CreateNull(); + internal::MessageWrapper value( + static_cast(nullptr)); EXPECT_THAT(adapter.GetField("int64_value", value, manager), StatusIs(absl::StatusCode::kInvalidArgument)); @@ -188,7 +192,7 @@ TEST(ProtoMessageTypeAdapter, GetFieldRepeated) { example.add_int64_list(10); example.add_int64_list(20); - CelValue value = CelProtoWrapper::CreateMessage(&example, &arena); + internal::MessageWrapper value(&example); ASSERT_OK_AND_ASSIGN(CelValue result, adapter.GetField("int64_list", value, manager)); @@ -213,7 +217,7 @@ TEST(ProtoMessageTypeAdapter, GetFieldMap) { TestMessage example; (*example.mutable_int64_int32_map())[10] = 20; - CelValue value = CelProtoWrapper::CreateMessage(&example, &arena); + internal::MessageWrapper value(&example); ASSERT_OK_AND_ASSIGN(CelValue result, adapter.GetField("int64_int32_map", value, manager)); @@ -238,7 +242,7 @@ TEST(ProtoMessageTypeAdapter, GetFieldWrapperType) { TestMessage example; example.mutable_int64_wrapper_value()->set_value(10); - CelValue value = CelProtoWrapper::CreateMessage(&example, &arena); + internal::MessageWrapper value(&example); EXPECT_THAT(adapter.GetField("int64_wrapper_value", value, manager), IsOkAndHolds(test::IsCelInt64(10))); @@ -255,7 +259,7 @@ TEST(ProtoMessageTypeAdapter, GetFieldWrapperTypeUnsetNullUnbox) { TestMessage example; - CelValue value = CelProtoWrapper::CreateMessage(&example, &arena); + internal::MessageWrapper value(&example); EXPECT_THAT(adapter.GetField("int64_wrapper_value", value, manager), IsOkAndHolds(test::IsCelNull())); @@ -277,7 +281,7 @@ TEST(ProtoMessageTypeAdapter, GetFieldWrapperTypeUnsetDefaultValueUnbox) { TestMessage example; - CelValue value = CelProtoWrapper::CreateMessage(&example, &arena); + internal::MessageWrapper value(&example); EXPECT_THAT(adapter.GetField("int64_wrapper_value", value, manager), IsOkAndHolds(test::IsCelInt64(_))); @@ -299,10 +303,10 @@ TEST(ProtoMessageTypeAdapter, NewInstance) { ProtoWrapperTypeOptions::kUnsetNull); cel::extensions::ProtoMemoryManager manager(&arena); - ASSERT_OK_AND_ASSIGN(CelValue result, adapter.NewInstance(manager)); - const google::protobuf::Message* message; - ASSERT_TRUE(result.GetValue(&message)); - EXPECT_THAT(message, EqualsProto(TestMessage::default_instance())); + ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper result, + adapter.NewInstance(manager)); + EXPECT_THAT(result.message_ptr(), + EqualsProto(TestMessage::default_instance())); } TEST(ProtoMessageTypeAdapter, NewInstanceUnsupportedDescriptor) { @@ -350,14 +354,13 @@ TEST(ProtoMessageTypeAdapter, SetFieldSingular) { ProtoWrapperTypeOptions::kUnsetNull); cel::extensions::ProtoMemoryManager manager(&arena); - ASSERT_OK_AND_ASSIGN(CelValue value, adapter.NewInstance(manager)); + ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper value, + adapter.NewInstance(manager)); ASSERT_OK(adapter.SetField("int64_value", CelValue::CreateInt64(10), manager, value)); - const google::protobuf::Message* message; - ASSERT_TRUE(value.GetValue(&message)); - EXPECT_THAT(message, EqualsProto("int64_value: 10")); + EXPECT_THAT(value.message_ptr(), EqualsProto("int64_value: 10")); ASSERT_THAT(adapter.SetField("not_a_field", CelValue::CreateInt64(10), manager, value), @@ -380,14 +383,13 @@ TEST(ProtoMessageTypeAdapter, SetFieldMap) { CelValue value_to_set = CelValue::CreateMap(&builder); - ASSERT_OK_AND_ASSIGN(CelValue instance, adapter.NewInstance(manager)); + ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper instance, + adapter.NewInstance(manager)); ASSERT_OK( adapter.SetField("int64_int32_map", value_to_set, manager, instance)); - const google::protobuf::Message* message; - ASSERT_TRUE(instance.GetValue(&message)); - EXPECT_THAT(message, EqualsProto(R"pb( + EXPECT_THAT(instance.message_ptr(), EqualsProto(R"pb( int64_int32_map { key: 1 value: 2 } int64_int32_map { key: 2 value: 4 } )pb")); @@ -405,14 +407,14 @@ TEST(ProtoMessageTypeAdapter, SetFieldRepeated) { ContainerBackedListImpl list( {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); CelValue value_to_set = CelValue::CreateList(&list); - ASSERT_OK_AND_ASSIGN(CelValue instance, adapter.NewInstance(manager)); + ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper instance, + adapter.NewInstance(manager)); ASSERT_OK(adapter.SetField("int64_list", value_to_set, manager, instance)); - const google::protobuf::Message* message; - ASSERT_TRUE(instance.GetValue(&message)); - EXPECT_THAT(message, EqualsProto(R"pb( - int64_list: 1 int64_list: 2 + EXPECT_THAT(instance.message_ptr(), EqualsProto(R"pb( + int64_list: 1 + int64_list: 2 )pb")); } @@ -425,7 +427,8 @@ TEST(ProtoMessageTypeAdapter, SetFieldNotAField) { ProtoWrapperTypeOptions::kUnsetNull); cel::extensions::ProtoMemoryManager manager(&arena); - ASSERT_OK_AND_ASSIGN(CelValue instance, adapter.NewInstance(manager)); + ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper instance, + adapter.NewInstance(manager)); ASSERT_THAT(adapter.SetField("not_a_field", CelValue::CreateInt64(10), manager, instance), @@ -454,7 +457,8 @@ TEST(ProtoMesssageTypeAdapter, SetFieldWrongType) { CelValue int_value = CelValue::CreateInt64(42); - ASSERT_OK_AND_ASSIGN(CelValue instance, adapter.NewInstance(manager)); + ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper instance, + adapter.NewInstance(manager)); EXPECT_THAT(adapter.SetField("int64_value", map_value, manager, instance), StatusIs(absl::StatusCode::kInvalidArgument)); @@ -483,7 +487,8 @@ TEST(ProtoMesssageTypeAdapter, SetFieldNotAMessage) { cel::extensions::ProtoMemoryManager manager(&arena); CelValue int_value = CelValue::CreateInt64(42); - CelValue instance = CelValue::CreateNull(); + CelValue::MessageWrapper instance( + static_cast(nullptr)); EXPECT_THAT(adapter.SetField("int64_value", int_value, manager, instance), StatusIs(absl::StatusCode::kInternal)); @@ -498,13 +503,15 @@ TEST(ProtoMessageTypeAdapter, AdaptFromWellKnownType) { ProtoWrapperTypeOptions::kUnsetNull); cel::extensions::ProtoMemoryManager manager(&arena); - ASSERT_OK_AND_ASSIGN(CelValue instance, adapter.NewInstance(manager)); + ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper instance, + adapter.NewInstance(manager)); ASSERT_OK( adapter.SetField("value", CelValue::CreateInt64(42), manager, instance)); - ASSERT_OK(adapter.AdaptFromWellKnownType(manager, instance)); + ASSERT_OK_AND_ASSIGN(CelValue value, + adapter.AdaptFromWellKnownType(manager, instance)); - EXPECT_THAT(instance, test::IsCelInt64(42)); + EXPECT_THAT(value, test::IsCelInt64(42)); } TEST(ProtoMessageTypeAdapter, AdaptFromWellKnownTypeUnspecial) { @@ -516,14 +523,16 @@ TEST(ProtoMessageTypeAdapter, AdaptFromWellKnownTypeUnspecial) { ProtoWrapperTypeOptions::kUnsetNull); cel::extensions::ProtoMemoryManager manager(&arena); - ASSERT_OK_AND_ASSIGN(CelValue instance, adapter.NewInstance(manager)); + ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper instance, + adapter.NewInstance(manager)); ASSERT_OK(adapter.SetField("int64_value", CelValue::CreateInt64(42), manager, instance)); - ASSERT_OK(adapter.AdaptFromWellKnownType(manager, instance)); + ASSERT_OK_AND_ASSIGN(CelValue value, + adapter.AdaptFromWellKnownType(manager, instance)); // TestMessage should not be converted to a CEL primitive type. - EXPECT_THAT(instance, test::IsCelMessage(EqualsProto("int64_value: 42"))); + EXPECT_THAT(value, test::IsCelMessage(EqualsProto("int64_value: 42"))); } TEST(ProtoMessageTypeAdapter, AdaptFromWellKnownTypeNotAMessageError) { @@ -535,7 +544,8 @@ TEST(ProtoMessageTypeAdapter, AdaptFromWellKnownTypeNotAMessageError) { ProtoWrapperTypeOptions::kUnsetNull); cel::extensions::ProtoMemoryManager manager(&arena); - CelValue instance = CelValue::CreateNull(); + CelValue::MessageWrapper instance( + static_cast(nullptr)); // Interpreter guaranteed to call this with a message type, otherwise, // something has broken. diff --git a/eval/public/structs/protobuf_descriptor_type_provider_test.cc b/eval/public/structs/protobuf_descriptor_type_provider_test.cc index 4443bb59a..39d153026 100644 --- a/eval/public/structs/protobuf_descriptor_type_provider_test.cc +++ b/eval/public/structs/protobuf_descriptor_type_provider_test.cc @@ -35,17 +35,17 @@ TEST(ProtobufDescriptorProvider, Basic) { ASSERT_TRUE(type_adapter->mutation_apis() != nullptr); ASSERT_TRUE(type_adapter->mutation_apis()->DefinesField("value")); - ASSERT_OK_AND_ASSIGN(CelValue value, + ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper value, type_adapter->mutation_apis()->NewInstance(manager)); - ASSERT_TRUE(value.IsMessage()); ASSERT_OK(type_adapter->mutation_apis()->SetField( "value", CelValue::CreateInt64(10), manager, value)); - ASSERT_OK( + ASSERT_OK_AND_ASSIGN( + CelValue adapted, type_adapter->mutation_apis()->AdaptFromWellKnownType(manager, value)); - EXPECT_THAT(value, test::IsCelInt64(10)); + EXPECT_THAT(adapted, test::IsCelInt64(10)); } // This is an implementation detail, but testing for coverage. From be9830bc2b52376e2255e0a4796627323c91a982 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 13 Apr 2022 23:04:11 +0000 Subject: [PATCH 071/111] Move wrapper type unboxing to a parameter on the GetField API PiperOrigin-RevId: 441603554 --- eval/public/structs/BUILD | 2 + eval/public/structs/legacy_type_adapter.h | 2 + .../structs/proto_message_type_adapter.cc | 3 +- .../structs/proto_message_type_adapter.h | 12 +- .../proto_message_type_adapter_test.cc | 117 ++++++++---------- .../protobuf_descriptor_type_provider.cc | 4 +- .../protobuf_descriptor_type_provider.h | 4 +- 7 files changed, 66 insertions(+), 78 deletions(-) diff --git a/eval/public/structs/BUILD b/eval/public/structs/BUILD index 86c6cbd41..1ca7e1487 100644 --- a/eval/public/structs/BUILD +++ b/eval/public/structs/BUILD @@ -196,6 +196,7 @@ cc_library( hdrs = ["legacy_type_adapter.h"], deps = [ "//base:memory_manager", + "//eval/public:cel_options", "//eval/public:cel_value", "@com_google_absl//absl/status", ], @@ -224,6 +225,7 @@ cc_library( ":cel_proto_wrapper", ":legacy_type_adapter", "//base:memory_manager", + "//eval/public:cel_options", "//eval/public:cel_value", "//eval/public:cel_value_internal", "//eval/public/containers:field_access", diff --git a/eval/public/structs/legacy_type_adapter.h b/eval/public/structs/legacy_type_adapter.h index 872c17ee0..a5dfcfb6f 100644 --- a/eval/public/structs/legacy_type_adapter.h +++ b/eval/public/structs/legacy_type_adapter.h @@ -20,6 +20,7 @@ #include "absl/status/status.h" #include "base/memory_manager.h" +#include "eval/public/cel_options.h" #include "eval/public/cel_value.h" namespace google::api::expr::runtime { @@ -76,6 +77,7 @@ class LegacyTypeAccessApis { // Access field on instance. virtual absl::StatusOr GetField( absl::string_view field_name, const CelValue::MessageWrapper& instance, + ProtoWrapperTypeOptions unboxing_option, cel::MemoryManager& memory_manager) const = 0; }; diff --git a/eval/public/structs/proto_message_type_adapter.cc b/eval/public/structs/proto_message_type_adapter.cc index 9e7abbd8f..08af0607c 100644 --- a/eval/public/structs/proto_message_type_adapter.cc +++ b/eval/public/structs/proto_message_type_adapter.cc @@ -102,6 +102,7 @@ absl::StatusOr ProtoMessageTypeAdapter::HasField( absl::StatusOr ProtoMessageTypeAdapter::GetField( absl::string_view field_name, const CelValue::MessageWrapper& instance, + ProtoWrapperTypeOptions unboxing_option, cel::MemoryManager& memory_manager) const { if (!instance.HasFullProto() || instance.message_ptr() == nullptr) { return absl::InvalidArgumentError("GetField called on non-message type."); @@ -129,7 +130,7 @@ absl::StatusOr ProtoMessageTypeAdapter::GetField( CelValue result; CEL_RETURN_IF_ERROR(CreateValueFromSingleField( - message, field_desc, unboxing_option_, arena, &result)); + message, field_desc, unboxing_option, arena, &result)); return result; } diff --git a/eval/public/structs/proto_message_type_adapter.h b/eval/public/structs/proto_message_type_adapter.h index 42d466a70..478354fbb 100644 --- a/eval/public/structs/proto_message_type_adapter.h +++ b/eval/public/structs/proto_message_type_adapter.h @@ -20,8 +20,8 @@ #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "base/memory_manager.h" +#include "eval/public/cel_options.h" #include "eval/public/cel_value.h" -#include "eval/public/containers/field_access.h" #include "eval/public/structs/legacy_type_adapter.h" namespace google::api::expr::runtime { @@ -30,12 +30,8 @@ class ProtoMessageTypeAdapter : public LegacyTypeAccessApis, public LegacyTypeMutationApis { public: ProtoMessageTypeAdapter(const google::protobuf::Descriptor* descriptor, - google::protobuf::MessageFactory* message_factory, - ProtoWrapperTypeOptions unboxing_option = - ProtoWrapperTypeOptions::kUnsetNull) - : message_factory_(message_factory), - descriptor_(descriptor), - unboxing_option_(unboxing_option) {} + google::protobuf::MessageFactory* message_factory) + : message_factory_(message_factory), descriptor_(descriptor) {} ~ProtoMessageTypeAdapter() override = default; @@ -55,6 +51,7 @@ class ProtoMessageTypeAdapter : public LegacyTypeAccessApis, absl::StatusOr GetField( absl::string_view field_name, const CelValue::MessageWrapper& instance, + ProtoWrapperTypeOptions unboxing_option, cel::MemoryManager& memory_manager) const override; absl::StatusOr HasField( @@ -68,7 +65,6 @@ class ProtoMessageTypeAdapter : public LegacyTypeAccessApis, google::protobuf::MessageFactory* message_factory_; const google::protobuf::Descriptor* descriptor_; - ProtoWrapperTypeOptions unboxing_option_; }; } // namespace google::api::expr::runtime diff --git a/eval/public/structs/proto_message_type_adapter_test.cc b/eval/public/structs/proto_message_type_adapter_test.cc index 2089cbd68..3d65be7ef 100644 --- a/eval/public/structs/proto_message_type_adapter_test.cc +++ b/eval/public/structs/proto_message_type_adapter_test.cc @@ -47,8 +47,7 @@ TEST(ProtoMessageTypeAdapter, HasFieldSingular) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory(), - ProtoWrapperTypeOptions::kUnsetNull); + google::protobuf::MessageFactory::generated_factory()); TestMessage example; @@ -64,8 +63,7 @@ TEST(ProtoMessageTypeAdapter, HasFieldRepeated) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory(), - ProtoWrapperTypeOptions::kUnsetNull); + google::protobuf::MessageFactory::generated_factory()); TestMessage example; @@ -81,8 +79,7 @@ TEST(ProtoMessageTypeAdapter, HasFieldMap) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory(), - ProtoWrapperTypeOptions::kUnsetNull); + google::protobuf::MessageFactory::generated_factory()); TestMessage example; example.set_int64_value(10); @@ -99,8 +96,7 @@ TEST(ProtoMessageTypeAdapter, HasFieldUnknownField) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory(), - ProtoWrapperTypeOptions::kUnsetNull); + google::protobuf::MessageFactory::generated_factory()); TestMessage example; example.set_int64_value(10); @@ -116,8 +112,7 @@ TEST(ProtoMessageTypeAdapter, HasFieldNonMessageType) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory(), - ProtoWrapperTypeOptions::kUnsetNull); + google::protobuf::MessageFactory::generated_factory()); internal::MessageWrapper value( static_cast(nullptr)); @@ -131,8 +126,7 @@ TEST(ProtoMessageTypeAdapter, GetFieldSingular) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory(), - ProtoWrapperTypeOptions::kUnsetNull); + google::protobuf::MessageFactory::generated_factory()); cel::extensions::ProtoMemoryManager manager(&arena); TestMessage example; @@ -140,7 +134,8 @@ TEST(ProtoMessageTypeAdapter, GetFieldSingular) { internal::MessageWrapper value(&example); - EXPECT_THAT(adapter.GetField("int64_value", value, manager), + EXPECT_THAT(adapter.GetField("int64_value", value, + ProtoWrapperTypeOptions::kUnsetNull, manager), IsOkAndHolds(test::IsCelInt64(10))); } @@ -149,8 +144,7 @@ TEST(ProtoMessageTypeAdapter, GetFieldNoSuchField) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory(), - ProtoWrapperTypeOptions::kUnsetNull); + google::protobuf::MessageFactory::generated_factory()); cel::extensions::ProtoMemoryManager manager(&arena); TestMessage example; @@ -158,7 +152,8 @@ TEST(ProtoMessageTypeAdapter, GetFieldNoSuchField) { internal::MessageWrapper value(&example); - EXPECT_THAT(adapter.GetField("unknown_field", value, manager), + EXPECT_THAT(adapter.GetField("unknown_field", value, + ProtoWrapperTypeOptions::kUnsetNull, manager), IsOkAndHolds(test::IsCelError(StatusIs( absl::StatusCode::kNotFound, HasSubstr("unknown_field"))))); } @@ -168,14 +163,14 @@ TEST(ProtoMessageTypeAdapter, GetFieldNotAMessage) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory(), - ProtoWrapperTypeOptions::kUnsetNull); + google::protobuf::MessageFactory::generated_factory()); cel::extensions::ProtoMemoryManager manager(&arena); internal::MessageWrapper value( static_cast(nullptr)); - EXPECT_THAT(adapter.GetField("int64_value", value, manager), + EXPECT_THAT(adapter.GetField("int64_value", value, + ProtoWrapperTypeOptions::kUnsetNull, manager), StatusIs(absl::StatusCode::kInvalidArgument)); } @@ -184,8 +179,7 @@ TEST(ProtoMessageTypeAdapter, GetFieldRepeated) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory(), - ProtoWrapperTypeOptions::kUnsetNull); + google::protobuf::MessageFactory::generated_factory()); cel::extensions::ProtoMemoryManager manager(&arena); TestMessage example; @@ -194,8 +188,10 @@ TEST(ProtoMessageTypeAdapter, GetFieldRepeated) { internal::MessageWrapper value(&example); - ASSERT_OK_AND_ASSIGN(CelValue result, - adapter.GetField("int64_list", value, manager)); + ASSERT_OK_AND_ASSIGN( + CelValue result, + adapter.GetField("int64_list", value, ProtoWrapperTypeOptions::kUnsetNull, + manager)); const CelList* held_value; ASSERT_TRUE(result.GetValue(&held_value)) << result.DebugString(); @@ -210,8 +206,7 @@ TEST(ProtoMessageTypeAdapter, GetFieldMap) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory(), - ProtoWrapperTypeOptions::kUnsetNull); + google::protobuf::MessageFactory::generated_factory()); cel::extensions::ProtoMemoryManager manager(&arena); TestMessage example; @@ -219,8 +214,10 @@ TEST(ProtoMessageTypeAdapter, GetFieldMap) { internal::MessageWrapper value(&example); - ASSERT_OK_AND_ASSIGN(CelValue result, - adapter.GetField("int64_int32_map", value, manager)); + ASSERT_OK_AND_ASSIGN( + CelValue result, + adapter.GetField("int64_int32_map", value, + ProtoWrapperTypeOptions::kUnsetNull, manager)); const CelMap* held_value; ASSERT_TRUE(result.GetValue(&held_value)) << result.DebugString(); @@ -235,8 +232,7 @@ TEST(ProtoMessageTypeAdapter, GetFieldWrapperType) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory(), - ProtoWrapperTypeOptions::kUnsetNull); + google::protobuf::MessageFactory::generated_factory()); cel::extensions::ProtoMemoryManager manager(&arena); TestMessage example; @@ -244,7 +240,8 @@ TEST(ProtoMessageTypeAdapter, GetFieldWrapperType) { internal::MessageWrapper value(&example); - EXPECT_THAT(adapter.GetField("int64_wrapper_value", value, manager), + EXPECT_THAT(adapter.GetField("int64_wrapper_value", value, + ProtoWrapperTypeOptions::kUnsetNull, manager), IsOkAndHolds(test::IsCelInt64(10))); } @@ -253,20 +250,21 @@ TEST(ProtoMessageTypeAdapter, GetFieldWrapperTypeUnsetNullUnbox) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory(), - ProtoWrapperTypeOptions::kUnsetNull); + google::protobuf::MessageFactory::generated_factory()); cel::extensions::ProtoMemoryManager manager(&arena); TestMessage example; internal::MessageWrapper value(&example); - EXPECT_THAT(adapter.GetField("int64_wrapper_value", value, manager), + EXPECT_THAT(adapter.GetField("int64_wrapper_value", value, + ProtoWrapperTypeOptions::kUnsetNull, manager), IsOkAndHolds(test::IsCelNull())); // Wrapper field present, but default value. example.mutable_int64_wrapper_value()->clear_value(); - EXPECT_THAT(adapter.GetField("int64_wrapper_value", value, manager), + EXPECT_THAT(adapter.GetField("int64_wrapper_value", value, + ProtoWrapperTypeOptions::kUnsetNull, manager), IsOkAndHolds(test::IsCelInt64(_))); } @@ -275,23 +273,26 @@ TEST(ProtoMessageTypeAdapter, GetFieldWrapperTypeUnsetDefaultValueUnbox) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory(), - ProtoWrapperTypeOptions::kUnsetProtoDefault); + google::protobuf::MessageFactory::generated_factory()); cel::extensions::ProtoMemoryManager manager(&arena); TestMessage example; internal::MessageWrapper value(&example); - EXPECT_THAT(adapter.GetField("int64_wrapper_value", value, manager), - IsOkAndHolds(test::IsCelInt64(_))); + EXPECT_THAT( + adapter.GetField("int64_wrapper_value", value, + ProtoWrapperTypeOptions::kUnsetProtoDefault, manager), + IsOkAndHolds(test::IsCelInt64(_))); // Wrapper field present with unset value is used to signal Null, but legacy // behavior just returns the proto default value. example.mutable_int64_wrapper_value()->clear_value(); // Same behavior for this option. - EXPECT_THAT(adapter.GetField("int64_wrapper_value", value, manager), - IsOkAndHolds(test::IsCelInt64(_))); + EXPECT_THAT( + adapter.GetField("int64_wrapper_value", value, + ProtoWrapperTypeOptions::kUnsetProtoDefault, manager), + IsOkAndHolds(test::IsCelInt64(_))); } TEST(ProtoMessageTypeAdapter, NewInstance) { @@ -299,8 +300,7 @@ TEST(ProtoMessageTypeAdapter, NewInstance) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory(), - ProtoWrapperTypeOptions::kUnsetNull); + google::protobuf::MessageFactory::generated_factory()); cel::extensions::ProtoMemoryManager manager(&arena); ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper result, @@ -323,8 +323,7 @@ TEST(ProtoMessageTypeAdapter, NewInstanceUnsupportedDescriptor) { ProtoMessageTypeAdapter adapter( pool.FindMessageTypeByName("google.api.expr.runtime.FakeMessage"), - google::protobuf::MessageFactory::generated_factory(), - ProtoWrapperTypeOptions::kUnsetNull); + google::protobuf::MessageFactory::generated_factory()); cel::extensions::ProtoMemoryManager manager(&arena); // Message factory doesn't know how to create our custom message, even though @@ -338,8 +337,7 @@ TEST(ProtoMessageTypeAdapter, DefinesField) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory(), - ProtoWrapperTypeOptions::kUnsetNull); + google::protobuf::MessageFactory::generated_factory()); EXPECT_TRUE(adapter.DefinesField("int64_value")); EXPECT_FALSE(adapter.DefinesField("not_a_field")); @@ -350,8 +348,7 @@ TEST(ProtoMessageTypeAdapter, SetFieldSingular) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory(), - ProtoWrapperTypeOptions::kUnsetNull); + google::protobuf::MessageFactory::generated_factory()); cel::extensions::ProtoMemoryManager manager(&arena); ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper value, @@ -373,8 +370,7 @@ TEST(ProtoMessageTypeAdapter, SetFieldMap) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory(), - ProtoWrapperTypeOptions::kUnsetNull); + google::protobuf::MessageFactory::generated_factory()); cel::extensions::ProtoMemoryManager manager(&arena); CelMapBuilder builder; @@ -400,8 +396,7 @@ TEST(ProtoMessageTypeAdapter, SetFieldRepeated) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory(), - ProtoWrapperTypeOptions::kUnsetNull); + google::protobuf::MessageFactory::generated_factory()); cel::extensions::ProtoMemoryManager manager(&arena); ContainerBackedListImpl list( @@ -423,8 +418,7 @@ TEST(ProtoMessageTypeAdapter, SetFieldNotAField) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory(), - ProtoWrapperTypeOptions::kUnsetNull); + google::protobuf::MessageFactory::generated_factory()); cel::extensions::ProtoMemoryManager manager(&arena); ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper instance, @@ -441,8 +435,7 @@ TEST(ProtoMesssageTypeAdapter, SetFieldWrongType) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory(), - ProtoWrapperTypeOptions::kUnsetNull); + google::protobuf::MessageFactory::generated_factory()); cel::extensions::ProtoMemoryManager manager(&arena); ContainerBackedListImpl list( @@ -482,8 +475,7 @@ TEST(ProtoMesssageTypeAdapter, SetFieldNotAMessage) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory(), - ProtoWrapperTypeOptions::kUnsetNull); + google::protobuf::MessageFactory::generated_factory()); cel::extensions::ProtoMemoryManager manager(&arena); CelValue int_value = CelValue::CreateInt64(42); @@ -499,8 +491,7 @@ TEST(ProtoMessageTypeAdapter, AdaptFromWellKnownType) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.protobuf.Int64Value"), - google::protobuf::MessageFactory::generated_factory(), - ProtoWrapperTypeOptions::kUnsetNull); + google::protobuf::MessageFactory::generated_factory()); cel::extensions::ProtoMemoryManager manager(&arena); ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper instance, @@ -519,8 +510,7 @@ TEST(ProtoMessageTypeAdapter, AdaptFromWellKnownTypeUnspecial) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory(), - ProtoWrapperTypeOptions::kUnsetNull); + google::protobuf::MessageFactory::generated_factory()); cel::extensions::ProtoMemoryManager manager(&arena); ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper instance, @@ -540,8 +530,7 @@ TEST(ProtoMessageTypeAdapter, AdaptFromWellKnownTypeNotAMessageError) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory(), - ProtoWrapperTypeOptions::kUnsetNull); + google::protobuf::MessageFactory::generated_factory()); cel::extensions::ProtoMemoryManager manager(&arena); CelValue::MessageWrapper instance( diff --git a/eval/public/structs/protobuf_descriptor_type_provider.cc b/eval/public/structs/protobuf_descriptor_type_provider.cc index 214d84ee5..6467c7835 100644 --- a/eval/public/structs/protobuf_descriptor_type_provider.cc +++ b/eval/public/structs/protobuf_descriptor_type_provider.cc @@ -53,7 +53,7 @@ std::unique_ptr ProtobufDescriptorProvider::GetType( return nullptr; } - return std::make_unique(descriptor, message_factory_, - unboxing_option_); + return std::make_unique(descriptor, + message_factory_); } } // namespace google::api::expr::runtime diff --git a/eval/public/structs/protobuf_descriptor_type_provider.h b/eval/public/structs/protobuf_descriptor_type_provider.h index 1d0c3a669..4a04e9056 100644 --- a/eval/public/structs/protobuf_descriptor_type_provider.h +++ b/eval/public/structs/protobuf_descriptor_type_provider.h @@ -37,9 +37,7 @@ class ProtobufDescriptorProvider : public LegacyTypeProvider { public: ProtobufDescriptorProvider(const google::protobuf::DescriptorPool* pool, google::protobuf::MessageFactory* factory) - : descriptor_pool_(pool), - message_factory_(factory), - unboxing_option_(ProtoWrapperTypeOptions::kUnsetNull) {} + : descriptor_pool_(pool), message_factory_(factory) {} absl::optional ProvideLegacyType( absl::string_view name) const override; From 532df2c17f07a11123eabb87aecb5f65a1cabc9d Mon Sep 17 00:00:00 2001 From: jcking Date: Thu, 14 Apr 2022 18:40:47 +0000 Subject: [PATCH 072/111] Internal change PiperOrigin-RevId: 441814920 --- base/internal/value.post.h | 5 ++ base/internal/value.pre.h | 3 + base/value.h | 111 ++++++++++++++++++++++- base/value_factory.h | 9 ++ base/value_test.cc | 180 +++++++++++++++++++++++++++++++++++++ 5 files changed, 304 insertions(+), 4 deletions(-) diff --git a/base/internal/value.post.h b/base/internal/value.post.h index 522d917b0..cbef6bf19 100644 --- a/base/internal/value.post.h +++ b/base/internal/value.post.h @@ -50,6 +50,10 @@ inline internal::TypeInfo GetListValueTypeId(const ListValue& list_value) { return list_value.TypeId(); } +inline internal::TypeInfo GetMapValueTypeId(const MapValue& map_value) { + return map_value.TypeId(); +} + // Implementation of BytesValue that is stored inlined within a handle. Since // absl::Cord is reference counted itself, this is more efficient than storing // this on the heap. @@ -674,6 +678,7 @@ CEL_INTERNAL_VALUE_DECL(TimestampValue); CEL_INTERNAL_VALUE_DECL(EnumValue); CEL_INTERNAL_VALUE_DECL(StructValue); CEL_INTERNAL_VALUE_DECL(ListValue); +CEL_INTERNAL_VALUE_DECL(MapValue); #undef CEL_INTERNAL_VALUE_DECL } // namespace cel diff --git a/base/internal/value.pre.h b/base/internal/value.pre.h index f38af32e4..88b32e365 100644 --- a/base/internal/value.pre.h +++ b/base/internal/value.pre.h @@ -29,6 +29,7 @@ namespace cel { class EnumValue; class StructValue; class ListValue; +class MapValue; namespace base_internal { @@ -54,6 +55,8 @@ internal::TypeInfo GetStructValueTypeId(const StructValue& struct_value); internal::TypeInfo GetListValueTypeId(const ListValue& list_value); +internal::TypeInfo GetMapValueTypeId(const MapValue& map_value); + class InlinedCordBytesValue; class InlinedStringViewBytesValue; class StringBytesValue; diff --git a/base/value.h b/base/value.h index 4637c1e58..79e7aa327 100644 --- a/base/value.h +++ b/base/value.h @@ -54,6 +54,7 @@ class TimestampValue; class EnumValue; class StructValue; class ListValue; +class MapValue; class ValueFactory; class TypedListValueFactory; @@ -90,6 +91,7 @@ class Value : public base_internal::Resource { friend class EnumValue; friend class StructValue; friend class ListValue; + friend class MapValue; friend class base_internal::ValueHandleBase; friend class base_internal::StringBytesValue; friend class base_internal::ExternalDataBytesValue; @@ -820,10 +822,7 @@ class ListValue : public Value { public: // TODO(issues/5): implement iterators so we can have cheap concated lists - Transient type() const final { - ABSL_ASSERT(type_); - return type_; - } + Transient type() const final { return type_; } Kind kind() const final { return Kind::kList; } @@ -925,6 +924,110 @@ class ListValue : public Value { return ::cel::internal::TypeId(); \ } +// MapValue represents an instance of cel::MapType. +class MapValue : public Value { + public: + Transient type() const final { return type_; } + + Kind kind() const final { return Kind::kMap; } + + virtual size_t size() const = 0; + + virtual bool empty() const { return size() == 0; } + + virtual absl::StatusOr> Get( + ValueFactory& value_factory, const Transient& key) const = 0; + + virtual absl::StatusOr Has(const Transient& key) const = 0; + + protected: + explicit MapValue(const Persistent& type) : type_(type) {} + + private: + friend internal::TypeInfo base_internal::GetMapValueTypeId( + const MapValue& map_value); + template + friend class base_internal::ValueHandle; + friend class base_internal::ValueHandleBase; + + // Called by base_internal::ValueHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Value& value) { return value.kind() == Kind::kMap; } + + MapValue(const MapValue&) = delete; + MapValue(MapValue&&) = delete; + + bool Equals(const Value& other) const override = 0; + void HashValue(absl::HashState state) const override = 0; + + std::pair SizeAndAlignment() const override = 0; + + // Called by CEL_IMPLEMENT_ENUM_VALUE() and Is() to perform type checking. + virtual internal::TypeInfo TypeId() const = 0; + + // Set lazily, by EnumValue::New. + Persistent type_; +}; + +// TODO(issues/5): generalize the macros to avoid repeating them when they +// are ultimately very similar. + +// CEL_DECLARE_MAP_VALUE declares `map_value` as an map value. It must +// be part of the class definition of `map_value`. +// +// class MyMapValue : public cel::MapValue { +// ... +// private: +// CEL_DECLARE_MAP_VALUE(MyMapValue); +// }; +#define CEL_DECLARE_MAP_VALUE(map_value) \ + private: \ + friend class ::cel::base_internal::ValueHandleBase; \ + \ + static bool Is(const ::cel::Value& value); \ + \ + ::std::pair<::std::size_t, ::std::size_t> SizeAndAlignment() const override; \ + \ + ::cel::internal::TypeInfo TypeId() const override; + +// CEL_IMPLEMENT_MAP_VALUE implements `map_value` as an map +// value. It must be called after the class definition of `map_value`. +// +// class MyMapValue : public cel::MapValue { +// ... +// private: +// CEL_DECLARE_MAP_VALUE(MyMapValue); +// }; +// +// CEL_IMPLEMENT_MAP_VALUE(MyMapValue); +#define CEL_IMPLEMENT_MAP_VALUE(map_value) \ + static_assert(::std::is_base_of_v<::cel::MapValue, map_value>, \ + #map_value " must inherit from cel::MapValue"); \ + static_assert(!::std::is_abstract_v, \ + "this must not be abstract"); \ + \ + bool map_value::Is(const ::cel::Value& value) { \ + return value.kind() == ::cel::Kind::kMap && \ + ::cel::base_internal::GetMapValueTypeId( \ + ::cel::internal::down_cast(value)) == \ + ::cel::internal::TypeId(); \ + } \ + \ + ::std::pair<::std::size_t, ::std::size_t> map_value::SizeAndAlignment() \ + const { \ + static_assert( \ + ::std::is_same_v>>, \ + "this must be the same as " #map_value); \ + return ::std::pair<::std::size_t, ::std::size_t>(sizeof(map_value), \ + alignof(map_value)); \ + } \ + \ + ::cel::internal::TypeInfo map_value::TypeId() const { \ + return ::cel::internal::TypeId(); \ + } + } // namespace cel // value.pre.h forward declares types so they can be friended above. The types diff --git a/base/value_factory.h b/base/value_factory.h index 450673213..0d1638f97 100644 --- a/base/value_factory.h +++ b/base/value_factory.h @@ -161,6 +161,15 @@ class ValueFactory final { std::forward(args)...); } + template + EnableIfBaseOfT>> CreateMapValue( + const Persistent& type, + Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { + return base_internal::PersistentHandleFactory::template Make< + std::remove_const_t>(memory_manager(), type, + std::forward(args)...); + } + private: friend class BytesValue; friend class StringValue; diff --git a/base/value_test.cc b/base/value_test.cc index 8a20ab43f..3a361b36e 100644 --- a/base/value_test.cc +++ b/base/value_test.cc @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -356,6 +357,9 @@ class TestListValue final : public ListValue { absl::StatusOr> Get(ValueFactory& value_factory, size_t index) const override { + if (index >= size()) { + return absl::OutOfRangeError(""); + } return value_factory.CreateIntValue(elements_[index]); } @@ -383,6 +387,69 @@ class TestListValue final : public ListValue { CEL_IMPLEMENT_LIST_VALUE(TestListValue); +class TestMapValue final : public MapValue { + public: + explicit TestMapValue(const Persistent& type, + std::map entries) + : MapValue(type), entries_(std::move(entries)) { + ABSL_ASSERT(type->key().Is()); + ABSL_ASSERT(type->value().Is()); + } + + size_t size() const override { return entries_.size(); } + + absl::StatusOr> Get( + ValueFactory& value_factory, + const Transient& key) const override { + if (!key.Is()) { + return absl::InvalidArgumentError(""); + } + auto entry = entries_.find(key.As()->ToString()); + if (entry == entries_.end()) { + return absl::NotFoundError(""); + } + return value_factory.CreateIntValue(entry->second); + } + + absl::StatusOr Has(const Transient& key) const override { + if (!key.Is()) { + return absl::InvalidArgumentError(""); + } + auto entry = entries_.find(key.As()->ToString()); + if (entry == entries_.end()) { + return false; + } + return true; + } + + std::string DebugString() const override { + std::vector parts; + for (const auto& entry : entries_) { + parts.push_back(absl::StrCat(internal::FormatStringLiteral(entry.first), + ": ", entry.second)); + } + return absl::StrCat("{", absl::StrJoin(parts, ", "), "}"); + } + + const std::map& value() const { return entries_; } + + private: + bool Equals(const Value& other) const override { + return Is(other) && + entries_ == internal::down_cast(other).entries_; + } + + void HashValue(absl::HashState state) const override { + absl::HashState::combine(std::move(state), type(), entries_); + } + + std::map entries_; + + CEL_DECLARE_MAP_VALUE(TestMapValue); +}; + +CEL_IMPLEMENT_MAP_VALUE(TestMapValue); + template Persistent Must(absl::StatusOr> status_or_handle) { return std::move(status_or_handle).value(); @@ -549,6 +616,14 @@ INSTANTIATE_TEST_SUITE_P( Must(type_factory.CreateListType(type_factory.GetIntType())), std::vector{})); }}, + {"Map", + [](TypeFactory& type_factory, + ValueFactory& value_factory) -> Persistent { + return Must(value_factory.CreateMapValue( + Must(type_factory.CreateMapType(type_factory.GetStringType(), + type_factory.GetIntType())), + std::map{})); + }}, }), [](const testing::TestParamInfo& info) { return info.param.name; @@ -2047,6 +2122,104 @@ TEST(ListValue, Get) { value_factory.CreateIntValue(1)); EXPECT_EQ(Must(list_value->Get(value_factory, 2)), value_factory.CreateIntValue(2)); + EXPECT_THAT(list_value->Get(value_factory, 3), + StatusIs(absl::StatusCode::kOutOfRange)); +} + +TEST(Value, Map) { + ValueFactory value_factory(MemoryManager::Global()); + TypeFactory type_factory(MemoryManager::Global()); + ASSERT_OK_AND_ASSIGN(auto map_type, + type_factory.CreateMapType(type_factory.GetStringType(), + type_factory.GetIntType())); + ASSERT_OK_AND_ASSIGN(auto zero_value, + value_factory.CreateMapValue( + map_type, std::map{})); + EXPECT_TRUE(zero_value.Is()); + EXPECT_TRUE(zero_value.Is()); + EXPECT_FALSE(zero_value.Is()); + EXPECT_EQ(zero_value, zero_value); + EXPECT_EQ(zero_value, Must(value_factory.CreateMapValue( + map_type, std::map{}))); + EXPECT_EQ(zero_value->kind(), Kind::kMap); + EXPECT_EQ(zero_value->type(), map_type); + EXPECT_EQ(zero_value.As()->value(), + (std::map{})); + + ASSERT_OK_AND_ASSIGN( + auto one_value, + value_factory.CreateMapValue( + map_type, std::map{{"foo", 1}})); + EXPECT_TRUE(one_value.Is()); + EXPECT_TRUE(one_value.Is()); + EXPECT_FALSE(one_value.Is()); + EXPECT_EQ(one_value, one_value); + EXPECT_EQ(one_value->kind(), Kind::kMap); + EXPECT_EQ(one_value->type(), map_type); + EXPECT_EQ(one_value.As()->value(), + (std::map{{"foo", 1}})); + + EXPECT_NE(zero_value, one_value); + EXPECT_NE(one_value, zero_value); +} + +TEST(MapValue, DebugString) { + ValueFactory value_factory(MemoryManager::Global()); + TypeFactory type_factory(MemoryManager::Global()); + ASSERT_OK_AND_ASSIGN(auto map_type, + type_factory.CreateMapType(type_factory.GetStringType(), + type_factory.GetIntType())); + ASSERT_OK_AND_ASSIGN(auto map_value, + value_factory.CreateMapValue( + map_type, std::map{})); + EXPECT_EQ(map_value->DebugString(), "{}"); + ASSERT_OK_AND_ASSIGN(map_value, + value_factory.CreateMapValue( + map_type, std::map{ + {"foo", 1}, {"bar", 2}, {"baz", 3}})); + EXPECT_EQ(map_value->DebugString(), "{\"bar\": 2, \"baz\": 3, \"foo\": 1}"); +} + +TEST(MapValue, GetAndHas) { + ValueFactory value_factory(MemoryManager::Global()); + TypeFactory type_factory(MemoryManager::Global()); + ASSERT_OK_AND_ASSIGN(auto map_type, + type_factory.CreateMapType(type_factory.GetStringType(), + type_factory.GetIntType())); + ASSERT_OK_AND_ASSIGN(auto map_value, + value_factory.CreateMapValue( + map_type, std::map{})); + EXPECT_TRUE(map_value->empty()); + EXPECT_EQ(map_value->size(), 0); + + ASSERT_OK_AND_ASSIGN(map_value, + value_factory.CreateMapValue( + map_type, std::map{ + {"foo", 1}, {"bar", 2}, {"baz", 3}})); + EXPECT_FALSE(map_value->empty()); + EXPECT_EQ(map_value->size(), 3); + EXPECT_EQ(Must(map_value->Get(value_factory, + Must(value_factory.CreateStringValue("foo")))), + value_factory.CreateIntValue(1)); + EXPECT_THAT(map_value->Has(Must(value_factory.CreateStringValue("foo"))), + IsOkAndHolds(true)); + EXPECT_EQ(Must(map_value->Get(value_factory, + Must(value_factory.CreateStringValue("bar")))), + value_factory.CreateIntValue(2)); + EXPECT_THAT(map_value->Has(Must(value_factory.CreateStringValue("bar"))), + IsOkAndHolds(true)); + EXPECT_EQ(Must(map_value->Get(value_factory, + Must(value_factory.CreateStringValue("baz")))), + value_factory.CreateIntValue(3)); + EXPECT_THAT(map_value->Has(Must(value_factory.CreateStringValue("baz"))), + IsOkAndHolds(true)); + EXPECT_THAT(map_value->Get(value_factory, value_factory.CreateIntValue(0)), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(map_value->Get(value_factory, + Must(value_factory.CreateStringValue("missing"))), + StatusIs(absl::StatusCode::kNotFound)); + EXPECT_THAT(map_value->Has(Must(value_factory.CreateStringValue("missing"))), + IsOkAndHolds(false)); } TEST(Value, SupportsAbslHash) { @@ -2066,6 +2239,12 @@ TEST(Value, SupportsAbslHash) { ASSERT_OK_AND_ASSIGN(auto list_value, value_factory.CreateListValue( list_type, std::vector{})); + ASSERT_OK_AND_ASSIGN(auto map_type, + type_factory.CreateMapType(type_factory.GetStringType(), + type_factory.GetIntType())); + ASSERT_OK_AND_ASSIGN(auto map_value, + value_factory.CreateMapValue( + map_type, std::map{})); EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly({ Persistent(value_factory.GetNullValue()), Persistent( @@ -2089,6 +2268,7 @@ TEST(Value, SupportsAbslHash) { Persistent(enum_value), Persistent(struct_value), Persistent(list_value), + Persistent(map_value), })); } From 175a8ab319ec46772973276a53eff3715ddf4db6 Mon Sep 17 00:00:00 2001 From: jcking Date: Thu, 14 Apr 2022 22:01:06 +0000 Subject: [PATCH 073/111] Internal change PiperOrigin-RevId: 441863363 --- base/internal/type.pre.h | 37 ++++++++ base/internal/value.pre.h | 37 ++++++++ base/type.h | 80 ++---------------- base/value.h | 172 +++++--------------------------------- 4 files changed, 101 insertions(+), 225 deletions(-) diff --git a/base/internal/type.pre.h b/base/internal/type.pre.h index 2886ac5fc..caaee404c 100644 --- a/base/internal/type.pre.h +++ b/base/internal/type.pre.h @@ -56,4 +56,41 @@ internal::TypeInfo GetStructTypeTypeId(const StructType& struct_type); } // namespace cel +#define CEL_INTERNAL_DECLARE_TYPE(base, derived) \ + private: \ + friend class ::cel::base_internal::TypeHandleBase; \ + \ + static bool Is(const ::cel::Type& type); \ + \ + ::std::pair<::std::size_t, ::std::size_t> SizeAndAlignment() const override; \ + \ + ::cel::internal::TypeInfo TypeId() const override; + +#define CEL_INTERNAL_IMPLEMENT_TYPE(base, derived) \ + static_assert(::std::is_base_of_v<::cel::base##Type, derived>, \ + #derived " must inherit from cel::" #base "Type"); \ + static_assert(!::std::is_abstract_v, "this must not be abstract"); \ + \ + bool derived::Is(const ::cel::Type& type) { \ + return type.kind() == ::cel::Kind::k##base && \ + ::cel::base_internal::Get##base##TypeTypeId( \ + ::cel::internal::down_cast(type)) == \ + ::cel::internal::TypeId(); \ + } \ + \ + ::std::pair<::std::size_t, ::std::size_t> derived::SizeAndAlignment() \ + const { \ + static_assert( \ + ::std::is_same_v>>, \ + "this must be the same as " #derived); \ + return ::std::pair<::std::size_t, ::std::size_t>(sizeof(derived), \ + alignof(derived)); \ + } \ + \ + ::cel::internal::TypeInfo derived::TypeId() const { \ + return ::cel::internal::TypeId(); \ + } + #endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_TYPE_PRE_H_ diff --git a/base/internal/value.pre.h b/base/internal/value.pre.h index 88b32e365..4441bc7d9 100644 --- a/base/internal/value.pre.h +++ b/base/internal/value.pre.h @@ -177,4 +177,41 @@ struct ExternalData final { } // namespace cel +#define CEL_INTERNAL_DECLARE_VALUE(base, derived) \ + private: \ + friend class ::cel::base_internal::ValueHandleBase; \ + \ + static bool Is(const ::cel::Value& value); \ + \ + ::std::pair<::std::size_t, ::std::size_t> SizeAndAlignment() const override; \ + \ + ::cel::internal::TypeInfo TypeId() const override; + +#define CEL_INTERNAL_IMPLEMENT_VALUE(base, derived) \ + static_assert(::std::is_base_of_v<::cel::base##Value, derived>, \ + #derived " must inherit from cel::" #base "Value"); \ + static_assert(!::std::is_abstract_v, "this must not be abstract"); \ + \ + bool derived::Is(const ::cel::Value& value) { \ + return value.kind() == ::cel::Kind::k##base && \ + ::cel::base_internal::Get##base##ValueTypeId( \ + ::cel::internal::down_cast( \ + value)) == ::cel::internal::TypeId(); \ + } \ + \ + ::std::pair<::std::size_t, ::std::size_t> derived::SizeAndAlignment() \ + const { \ + static_assert( \ + ::std::is_same_v>>, \ + "this must be the same as " #derived); \ + return ::std::pair<::std::size_t, ::std::size_t>(sizeof(derived), \ + alignof(derived)); \ + } \ + \ + ::cel::internal::TypeInfo derived::TypeId() const { \ + return ::cel::internal::TypeId(); \ + } + #endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_VALUE_PRE_H_ diff --git a/base/type.h b/base/type.h index e619ced7c..5a08b2706 100644 --- a/base/type.h +++ b/base/type.h @@ -518,15 +518,8 @@ class EnumType : public Type { // private: // CEL_DECLARE_ENUM_TYPE(MyEnumType); // }; -#define CEL_DECLARE_ENUM_TYPE(enum_type) \ - private: \ - friend class ::cel::base_internal::TypeHandleBase; \ - \ - static bool Is(const ::cel::Type& type); \ - \ - ::std::pair<::std::size_t, ::std::size_t> SizeAndAlignment() const override; \ - \ - ::cel::internal::TypeInfo TypeId() const override; +#define CEL_DECLARE_ENUM_TYPE(enum_type) \ + CEL_INTERNAL_DECLARE_TYPE(Enum, enum_type) // CEL_IMPLEMENT_ENUM_TYPE implements `enum_type` as an enumeration type. It // must be called after the class definition of `enum_type`. @@ -538,33 +531,8 @@ class EnumType : public Type { // }; // // CEL_IMPLEMENT_ENUM_TYPE(MyEnumType); -#define CEL_IMPLEMENT_ENUM_TYPE(enum_type) \ - static_assert(::std::is_base_of_v<::cel::EnumType, enum_type>, \ - #enum_type " must inherit from cel::EnumType"); \ - static_assert(!::std::is_abstract_v, \ - "this must not be abstract"); \ - \ - bool enum_type::Is(const ::cel::Type& type) { \ - return type.kind() == ::cel::Kind::kEnum && \ - ::cel::base_internal::GetEnumTypeTypeId( \ - ::cel::internal::down_cast(type)) == \ - ::cel::internal::TypeId(); \ - } \ - \ - ::std::pair<::std::size_t, ::std::size_t> enum_type::SizeAndAlignment() \ - const { \ - static_assert( \ - ::std::is_same_v>>, \ - "this must be the same as " #enum_type); \ - return ::std::pair<::std::size_t, ::std::size_t>(sizeof(enum_type), \ - alignof(enum_type)); \ - } \ - \ - ::cel::internal::TypeInfo enum_type::TypeId() const { \ - return ::cel::internal::TypeId(); \ - } +#define CEL_IMPLEMENT_ENUM_TYPE(enum_type) \ + CEL_INTERNAL_IMPLEMENT_TYPE(Enum, enum_type) // StructType represents an struct type. An struct is a set of fields // that can be looked up by name and/or number. @@ -646,15 +614,8 @@ class StructType : public Type { // private: // CEL_DECLARE_STRUCT_TYPE(MyStructType); // }; -#define CEL_DECLARE_STRUCT_TYPE(struct_type) \ - private: \ - friend class ::cel::base_internal::TypeHandleBase; \ - \ - static bool Is(const ::cel::Type& type); \ - \ - ::std::pair<::std::size_t, ::std::size_t> SizeAndAlignment() const override; \ - \ - ::cel::internal::TypeInfo TypeId() const override; +#define CEL_DECLARE_STRUCT_TYPE(struct_type) \ + CEL_INTERNAL_DECLARE_TYPE(Struct, struct_type) // CEL_IMPLEMENT_ENUM_TYPE implements `struct_type` as an struct type. It // must be called after the class definition of `struct_type`. @@ -666,33 +627,8 @@ class StructType : public Type { // }; // // CEL_IMPLEMENT_STRUCT_TYPE(MyStructType); -#define CEL_IMPLEMENT_STRUCT_TYPE(struct_type) \ - static_assert(::std::is_base_of_v<::cel::StructType, struct_type>, \ - #struct_type " must inherit from cel::StructType"); \ - static_assert(!::std::is_abstract_v, \ - "this must not be abstract"); \ - \ - bool struct_type::Is(const ::cel::Type& type) { \ - return type.kind() == ::cel::Kind::kStruct && \ - ::cel::base_internal::GetStructTypeTypeId( \ - ::cel::internal::down_cast(type)) == \ - ::cel::internal::TypeId(); \ - } \ - \ - ::std::pair<::std::size_t, ::std::size_t> struct_type::SizeAndAlignment() \ - const { \ - static_assert( \ - ::std::is_same_v>>, \ - "this must be the same as " #struct_type); \ - return ::std::pair<::std::size_t, ::std::size_t>(sizeof(struct_type), \ - alignof(struct_type)); \ - } \ - \ - ::cel::internal::TypeInfo struct_type::TypeId() const { \ - return ::cel::internal::TypeId(); \ - } +#define CEL_IMPLEMENT_STRUCT_TYPE(struct_type) \ + CEL_INTERNAL_IMPLEMENT_TYPE(Struct, struct_type) // ListType represents a list type. A list is a sequential container where each // element is the same type. diff --git a/base/value.h b/base/value.h index 79e7aa327..d1285e322 100644 --- a/base/value.h +++ b/base/value.h @@ -638,15 +638,8 @@ class EnumValue : public Value { // private: // CEL_DECLARE_ENUM_VALUE(MyEnumValue); // }; -#define CEL_DECLARE_ENUM_VALUE(enum_value) \ - private: \ - friend class ::cel::base_internal::ValueHandleBase; \ - \ - static bool Is(const ::cel::Value& value); \ - \ - ::std::pair<::std::size_t, ::std::size_t> SizeAndAlignment() const override; \ - \ - ::cel::internal::TypeInfo TypeId() const override; +#define CEL_DECLARE_ENUM_VALUE(enum_value) \ + CEL_INTERNAL_DECLARE_VALUE(Enum, enum_value) // CEL_IMPLEMENT_ENUM_VALUE implements `enum_value` as an enumeration value. It // must be called after the class definition of `enum_value`. @@ -658,33 +651,8 @@ class EnumValue : public Value { // }; // // CEL_IMPLEMENT_ENUM_VALUE(MyEnumValue); -#define CEL_IMPLEMENT_ENUM_VALUE(enum_value) \ - static_assert(::std::is_base_of_v<::cel::EnumValue, enum_value>, \ - #enum_value " must inherit from cel::EnumValue"); \ - static_assert(!::std::is_abstract_v, \ - "this must not be abstract"); \ - \ - bool enum_value::Is(const ::cel::Value& value) { \ - return value.kind() == ::cel::Kind::kEnum && \ - ::cel::base_internal::GetEnumValueTypeId( \ - ::cel::internal::down_cast(value)) == \ - ::cel::internal::TypeId(); \ - } \ - \ - ::std::pair<::std::size_t, ::std::size_t> enum_value::SizeAndAlignment() \ - const { \ - static_assert( \ - ::std::is_same_v>>, \ - "this must be the same as " #enum_value); \ - return ::std::pair<::std::size_t, ::std::size_t>(sizeof(enum_value), \ - alignof(enum_value)); \ - } \ - \ - ::cel::internal::TypeInfo enum_value::TypeId() const { \ - return ::cel::internal::TypeId(); \ - } +#define CEL_IMPLEMENT_ENUM_VALUE(enum_value) \ + CEL_INTERNAL_IMPLEMENT_VALUE(Enum, enum_value) // StructValue represents an instance of cel::StructType. class StructValue : public Value { @@ -754,7 +722,7 @@ class StructValue : public Value { std::pair SizeAndAlignment() const override = 0; - // Called by CEL_IMPLEMENT_ENUM_VALUE() and Is() to perform type checking. + // Called by CEL_IMPLEMENT_STRUCT_VALUE() and Is() to perform type checking. virtual internal::TypeInfo TypeId() const = 0; // Set lazily, by StructValue::New. @@ -769,15 +737,8 @@ class StructValue : public Value { // private: // CEL_DECLARE_STRUCT_VALUE(MyStructValue); // }; -#define CEL_DECLARE_STRUCT_VALUE(struct_value) \ - private: \ - friend class ::cel::base_internal::ValueHandleBase; \ - \ - static bool Is(const ::cel::Value& value); \ - \ - ::std::pair<::std::size_t, ::std::size_t> SizeAndAlignment() const override; \ - \ - ::cel::internal::TypeInfo TypeId() const override; +#define CEL_DECLARE_STRUCT_VALUE(struct_value) \ + CEL_INTERNAL_DECLARE_VALUE(Struct, struct_value) // CEL_IMPLEMENT_STRUCT_VALUE implements `struct_value` as an struct // value. It must be called after the class definition of `struct_value`. @@ -789,33 +750,8 @@ class StructValue : public Value { // }; // // CEL_IMPLEMENT_STRUCT_VALUE(MyStructValue); -#define CEL_IMPLEMENT_STRUCT_VALUE(struct_value) \ - static_assert(::std::is_base_of_v<::cel::StructValue, struct_value>, \ - #struct_value " must inherit from cel::StructValue"); \ - static_assert(!::std::is_abstract_v, \ - "this must not be abstract"); \ - \ - bool struct_value::Is(const ::cel::Value& value) { \ - return value.kind() == ::cel::Kind::kStruct && \ - ::cel::base_internal::GetStructValueTypeId( \ - ::cel::internal::down_cast( \ - value)) == ::cel::internal::TypeId(); \ - } \ - \ - ::std::pair<::std::size_t, ::std::size_t> struct_value::SizeAndAlignment() \ - const { \ - static_assert( \ - ::std::is_same_v>>, \ - "this must be the same as " #struct_value); \ - return ::std::pair<::std::size_t, ::std::size_t>(sizeof(struct_value), \ - alignof(struct_value)); \ - } \ - \ - ::cel::internal::TypeInfo struct_value::TypeId() const { \ - return ::cel::internal::TypeId(); \ - } +#define CEL_IMPLEMENT_STRUCT_VALUE(struct_value) \ + CEL_INTERNAL_IMPLEMENT_VALUE(Struct, struct_value) // ListValue represents an instance of cel::ListType. class ListValue : public Value { @@ -859,15 +795,12 @@ class ListValue : public Value { std::pair SizeAndAlignment() const override = 0; - // Called by CEL_IMPLEMENT_ENUM_VALUE() and Is() to perform type checking. + // Called by CEL_IMPLEMENT_LIST_VALUE() and Is() to perform type checking. virtual internal::TypeInfo TypeId() const = 0; const Persistent type_; }; -// TODO(issues/5): generalize the macros to avoid repeating them when they -// are ultimately very similar. - // CEL_DECLARE_LIST_VALUE declares `list_value` as an list value. It must // be part of the class definition of `list_value`. // @@ -876,15 +809,8 @@ class ListValue : public Value { // private: // CEL_DECLARE_LIST_VALUE(MyListValue); // }; -#define CEL_DECLARE_LIST_VALUE(list_value) \ - private: \ - friend class ::cel::base_internal::ValueHandleBase; \ - \ - static bool Is(const ::cel::Value& value); \ - \ - ::std::pair<::std::size_t, ::std::size_t> SizeAndAlignment() const override; \ - \ - ::cel::internal::TypeInfo TypeId() const override; +#define CEL_DECLARE_LIST_VALUE(list_value) \ + CEL_INTERNAL_DECLARE_VALUE(List, list_value) // CEL_IMPLEMENT_LIST_VALUE implements `list_value` as an list // value. It must be called after the class definition of `list_value`. @@ -896,33 +822,8 @@ class ListValue : public Value { // }; // // CEL_IMPLEMENT_LIST_VALUE(MyListValue); -#define CEL_IMPLEMENT_LIST_VALUE(list_value) \ - static_assert(::std::is_base_of_v<::cel::ListValue, list_value>, \ - #list_value " must inherit from cel::ListValue"); \ - static_assert(!::std::is_abstract_v, \ - "this must not be abstract"); \ - \ - bool list_value::Is(const ::cel::Value& value) { \ - return value.kind() == ::cel::Kind::kList && \ - ::cel::base_internal::GetListValueTypeId( \ - ::cel::internal::down_cast(value)) == \ - ::cel::internal::TypeId(); \ - } \ - \ - ::std::pair<::std::size_t, ::std::size_t> list_value::SizeAndAlignment() \ - const { \ - static_assert( \ - ::std::is_same_v>>, \ - "this must be the same as " #list_value); \ - return ::std::pair<::std::size_t, ::std::size_t>(sizeof(list_value), \ - alignof(list_value)); \ - } \ - \ - ::cel::internal::TypeInfo list_value::TypeId() const { \ - return ::cel::internal::TypeId(); \ - } +#define CEL_IMPLEMENT_LIST_VALUE(list_value) \ + CEL_INTERNAL_IMPLEMENT_VALUE(List, list_value) // MapValue represents an instance of cel::MapType. class MapValue : public Value { @@ -962,16 +863,13 @@ class MapValue : public Value { std::pair SizeAndAlignment() const override = 0; - // Called by CEL_IMPLEMENT_ENUM_VALUE() and Is() to perform type checking. + // Called by CEL_IMPLEMENT_MAP_VALUE() and Is() to perform type checking. virtual internal::TypeInfo TypeId() const = 0; // Set lazily, by EnumValue::New. Persistent type_; }; -// TODO(issues/5): generalize the macros to avoid repeating them when they -// are ultimately very similar. - // CEL_DECLARE_MAP_VALUE declares `map_value` as an map value. It must // be part of the class definition of `map_value`. // @@ -980,15 +878,8 @@ class MapValue : public Value { // private: // CEL_DECLARE_MAP_VALUE(MyMapValue); // }; -#define CEL_DECLARE_MAP_VALUE(map_value) \ - private: \ - friend class ::cel::base_internal::ValueHandleBase; \ - \ - static bool Is(const ::cel::Value& value); \ - \ - ::std::pair<::std::size_t, ::std::size_t> SizeAndAlignment() const override; \ - \ - ::cel::internal::TypeInfo TypeId() const override; +#define CEL_DECLARE_MAP_VALUE(map_value) \ + CEL_INTERNAL_DECLARE_VALUE(Map, map_value) // CEL_IMPLEMENT_MAP_VALUE implements `map_value` as an map // value. It must be called after the class definition of `map_value`. @@ -1000,33 +891,8 @@ class MapValue : public Value { // }; // // CEL_IMPLEMENT_MAP_VALUE(MyMapValue); -#define CEL_IMPLEMENT_MAP_VALUE(map_value) \ - static_assert(::std::is_base_of_v<::cel::MapValue, map_value>, \ - #map_value " must inherit from cel::MapValue"); \ - static_assert(!::std::is_abstract_v, \ - "this must not be abstract"); \ - \ - bool map_value::Is(const ::cel::Value& value) { \ - return value.kind() == ::cel::Kind::kMap && \ - ::cel::base_internal::GetMapValueTypeId( \ - ::cel::internal::down_cast(value)) == \ - ::cel::internal::TypeId(); \ - } \ - \ - ::std::pair<::std::size_t, ::std::size_t> map_value::SizeAndAlignment() \ - const { \ - static_assert( \ - ::std::is_same_v>>, \ - "this must be the same as " #map_value); \ - return ::std::pair<::std::size_t, ::std::size_t>(sizeof(map_value), \ - alignof(map_value)); \ - } \ - \ - ::cel::internal::TypeInfo map_value::TypeId() const { \ - return ::cel::internal::TypeId(); \ - } +#define CEL_IMPLEMENT_MAP_VALUE(map_value) \ + CEL_INTERNAL_IMPLEMENT_VALUE(Map, map_value) } // namespace cel From d6276a78cbe353d11ff3680406399d4c72d7e62c Mon Sep 17 00:00:00 2001 From: jcking Date: Fri, 15 Apr 2022 00:18:50 +0000 Subject: [PATCH 074/111] Internal change PiperOrigin-RevId: 441893607 --- base/value_test.cc | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/base/value_test.cc b/base/value_test.cc index 3a361b36e..e9f3a984e 100644 --- a/base/value_test.cc +++ b/base/value_test.cc @@ -607,7 +607,26 @@ INSTANTIATE_TEST_SUITE_P( {"Bytes", [](TypeFactory& type_factory, ValueFactory& value_factory) -> Persistent { - return Must(value_factory.CreateBytesValue(nullptr)); + return Must(value_factory.CreateBytesValue("")); + }}, + {"String", + [](TypeFactory& type_factory, + ValueFactory& value_factory) -> Persistent { + return Must(value_factory.CreateStringValue("")); + }}, + {"Enum", + [](TypeFactory& type_factory, + ValueFactory& value_factory) -> Persistent { + return Must( + EnumValue::New(Must(type_factory.CreateEnumType()), + value_factory, EnumType::ConstantId("VALUE1"))); + }}, + {"Struct", + [](TypeFactory& type_factory, + ValueFactory& value_factory) -> Persistent { + return Must(StructValue::New( + Must(type_factory.CreateStructType()), + value_factory)); }}, {"List", [](TypeFactory& type_factory, From 05e0549385139b939296c51bd2ef617b5f6cf3ce Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Sat, 16 Apr 2022 00:19:32 +0000 Subject: [PATCH 075/111] Introduce duck-typed message adapter. This defers to the descriptor of the given message, so it maintains the same behavior as the current implementation directly calling reflection APIs. PiperOrigin-RevId: 442129301 --- eval/public/structs/BUILD | 10 +- eval/public/structs/legacy_type_info_apis.h | 2 +- .../structs/proto_message_type_adapter.cc | 249 +++++++++----- .../structs/proto_message_type_adapter.h | 6 + .../proto_message_type_adapter_test.cc | 305 +++++++++++------- 5 files changed, 377 insertions(+), 195 deletions(-) diff --git a/eval/public/structs/BUILD b/eval/public/structs/BUILD index 1ca7e1487..d6d249224 100644 --- a/eval/public/structs/BUILD +++ b/eval/public/structs/BUILD @@ -222,17 +222,20 @@ cc_library( srcs = ["proto_message_type_adapter.cc"], hdrs = ["proto_message_type_adapter.h"], deps = [ - ":cel_proto_wrapper", + ":cel_proto_wrap_util", + ":field_access_impl", ":legacy_type_adapter", + ":legacy_type_info_apis", "//base:memory_manager", "//eval/public:cel_options", "//eval/public:cel_value", "//eval/public:cel_value_internal", "//eval/public/containers:field_access", - "//eval/public/containers:field_backed_list_impl", - "//eval/public/containers:field_backed_map_impl", + "//eval/public/containers:internal_field_backed_list_impl", + "//eval/public/containers:internal_field_backed_map_impl", "//extensions/protobuf:memory_manager", "//internal:casts", + "//internal:no_destructor", "//internal:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", @@ -245,6 +248,7 @@ cc_test( srcs = ["proto_message_type_adapter_test.cc"], deps = [ ":cel_proto_wrapper", + ":legacy_type_adapter", ":proto_message_type_adapter", "//eval/public:cel_value", "//eval/public:cel_value_internal", diff --git a/eval/public/structs/legacy_type_info_apis.h b/eval/public/structs/legacy_type_info_apis.h index 26d77ea40..5971f23de 100644 --- a/eval/public/structs/legacy_type_info_apis.h +++ b/eval/public/structs/legacy_type_info_apis.h @@ -28,7 +28,7 @@ class LegacyTypeAccessApis; // message). // // Provides ability to obtain field access apis, type info, and debug -// representation of a message/ +// representation of a message. // // This is implemented as a separate class from LegacyTypeAccessApis to resolve // cyclic dependency between CelValue (which needs to access these apis to diff --git a/eval/public/structs/proto_message_type_adapter.cc b/eval/public/structs/proto_message_type_adapter.cc index 08af0607c..199feca9a 100644 --- a/eval/public/structs/proto_message_type_adapter.cc +++ b/eval/public/structs/proto_message_type_adapter.cc @@ -14,6 +14,8 @@ #include "eval/public/structs/proto_message_type_adapter.h" +#include + #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" #include "absl/status/status.h" @@ -23,60 +25,50 @@ #include "eval/public/cel_value.h" #include "eval/public/cel_value_internal.h" #include "eval/public/containers/field_access.h" -#include "eval/public/containers/field_backed_list_impl.h" -#include "eval/public/containers/field_backed_map_impl.h" -#include "eval/public/structs/cel_proto_wrapper.h" +#include "eval/public/containers/internal_field_backed_list_impl.h" +#include "eval/public/containers/internal_field_backed_map_impl.h" +#include "eval/public/structs/cel_proto_wrap_util.h" +#include "eval/public/structs/field_access_impl.h" +#include "eval/public/structs/legacy_type_adapter.h" +#include "eval/public/structs/legacy_type_info_apis.h" #include "extensions/protobuf/memory_manager.h" #include "internal/casts.h" +#include "internal/no_destructor.h" #include "internal/status_macros.h" namespace google::api::expr::runtime { +namespace { + using ::cel::extensions::ProtoMemoryManager; using ::google::protobuf::FieldDescriptor; using ::google::protobuf::Message; using ::google::protobuf::Reflection; -absl::Status ProtoMessageTypeAdapter::ValidateSetFieldOp( - bool assertion, absl::string_view field, absl::string_view detail) const { - if (!assertion) { - return absl::InvalidArgumentError( - absl::Substitute("SetField failed on message $0, field '$1': $2", - descriptor_->full_name(), field, detail)); - } - return absl::OkStatus(); -} - -absl::StatusOr ProtoMessageTypeAdapter::NewInstance( - cel::MemoryManager& memory_manager) const { - // This implementation requires arena-backed memory manager. - google::protobuf::Arena* arena = ProtoMemoryManager::CastToProtoArena(memory_manager); - const Message* prototype = message_factory_->GetPrototype(descriptor_); - - Message* msg = (prototype != nullptr) ? prototype->New(arena) : nullptr; - - if (msg == nullptr) { - return absl::InvalidArgumentError( - absl::StrCat("Failed to create message ", descriptor_->name())); - } - return CelValue::MessageWrapper(msg); +const std::string& UnsupportedTypeName() { + static cel::internal::NoDestructor kUnsupportedTypeName( + ""); + return *kUnsupportedTypeName; } -bool ProtoMessageTypeAdapter::DefinesField(absl::string_view field_name) const { - return descriptor_->FindFieldByName(field_name.data()) != nullptr; -} +CelValue MessageCelValueFactory(const google::protobuf::Message* message); -absl::StatusOr ProtoMessageTypeAdapter::HasField( - absl::string_view field_name, const CelValue::MessageWrapper& value) const { +inline absl::StatusOr UnwrapMessage( + const CelValue::MessageWrapper& value, absl::string_view op) { if (!value.HasFullProto() || value.message_ptr() == nullptr) { - return absl::InvalidArgumentError("GetField called on non-message type."); + return absl::InternalError( + absl::StrCat(op, " called on non-message type.")); } - const google::protobuf::Message* message = - cel::internal::down_cast(value.message_ptr()); + return cel::internal::down_cast(value.message_ptr()); +} +// Shared implementation for HasField. +// Handles list or map specific behavior before calling reflection helpers. +absl::StatusOr HasFieldImpl(const google::protobuf::Message* message, + const google::protobuf::Descriptor* descriptor, + absl::string_view field_name) { + ABSL_ASSERT(descriptor == message->GetDescriptor()); const Reflection* reflection = message->GetReflection(); - ABSL_ASSERT(descriptor_ == message->GetDescriptor()); - - const FieldDescriptor* field_desc = descriptor_->FindFieldByName(field_name.data()); + const FieldDescriptor* field_desc = descriptor->FindFieldByName(field_name.data()); if (field_desc == nullptr) { return absl::NotFoundError(absl::StrCat("no_such_field : ", field_name)); @@ -100,16 +92,15 @@ absl::StatusOr ProtoMessageTypeAdapter::HasField( return reflection->HasField(*message, field_desc); } -absl::StatusOr ProtoMessageTypeAdapter::GetField( - absl::string_view field_name, const CelValue::MessageWrapper& instance, - ProtoWrapperTypeOptions unboxing_option, - cel::MemoryManager& memory_manager) const { - if (!instance.HasFullProto() || instance.message_ptr() == nullptr) { - return absl::InvalidArgumentError("GetField called on non-message type."); - } - const google::protobuf::Message* message = - cel::internal::down_cast(instance.message_ptr()); - const FieldDescriptor* field_desc = descriptor_->FindFieldByName(field_name.data()); +// Shared implementation for GetField. +// Handles list or map specific behavior before calling reflection helpers. +absl::StatusOr GetFieldImpl(const google::protobuf::Message* message, + const google::protobuf::Descriptor* descriptor, + absl::string_view field_name, + ProtoWrapperTypeOptions unboxing_option, + cel::MemoryManager& memory_manager) { + ABSL_ASSERT(descriptor == message->GetDescriptor()); + const FieldDescriptor* field_desc = descriptor->FindFieldByName(field_name.data()); if (field_desc == nullptr) { return CreateNoSuchFieldError(memory_manager, field_name); @@ -118,22 +109,133 @@ absl::StatusOr ProtoMessageTypeAdapter::GetField( google::protobuf::Arena* arena = ProtoMemoryManager::CastToProtoArena(memory_manager); if (field_desc->is_map()) { - CelMap* map = google::protobuf::Arena::Create(arena, message, - field_desc, arena); - return CelValue::CreateMap(map); + auto map = memory_manager.New( + message, field_desc, &MessageCelValueFactory, arena); + + return CelValue::CreateMap(map.release()); } if (field_desc->is_repeated()) { - CelList* list = google::protobuf::Arena::Create( - arena, message, field_desc, arena); - return CelValue::CreateList(list); + auto list = memory_manager.New( + message, field_desc, &MessageCelValueFactory, arena); + return CelValue::CreateList(list.release()); } - CelValue result; - CEL_RETURN_IF_ERROR(CreateValueFromSingleField( - message, field_desc, unboxing_option, arena, &result)); + CEL_ASSIGN_OR_RETURN( + CelValue result, + internal::CreateValueFromSingleField(message, field_desc, unboxing_option, + &MessageCelValueFactory, arena)); return result; } +class DucktypedMessageAdapter : public LegacyTypeAccessApis, + public LegacyTypeInfoApis { + public: + // Implement field access APIs. + absl::StatusOr HasField( + absl::string_view field_name, + const CelValue::MessageWrapper& value) const override { + CEL_ASSIGN_OR_RETURN(const google::protobuf::Message* message, + UnwrapMessage(value, "HasField")); + return HasFieldImpl(message, message->GetDescriptor(), field_name); + } + + absl::StatusOr GetField( + absl::string_view field_name, const CelValue::MessageWrapper& instance, + ProtoWrapperTypeOptions unboxing_option, + cel::MemoryManager& memory_manager) const override { + CEL_ASSIGN_OR_RETURN(const google::protobuf::Message* message, + UnwrapMessage(instance, "GetField")); + return GetFieldImpl(message, message->GetDescriptor(), field_name, + unboxing_option, memory_manager); + } + + // Implement TypeInfo Apis + const std::string& GetTypename( + const internal::MessageWrapper& wrapped_message) const override { + if (!wrapped_message.HasFullProto() || + wrapped_message.message_ptr() == nullptr) { + return UnsupportedTypeName(); + } + auto* message = cel::internal::down_cast( + wrapped_message.message_ptr()); + return message->GetDescriptor()->full_name(); + } + + std::string DebugString( + const internal::MessageWrapper& wrapped_message) const override { + if (!wrapped_message.HasFullProto() || + wrapped_message.message_ptr() == nullptr) { + return UnsupportedTypeName(); + } + auto* message = cel::internal::down_cast( + wrapped_message.message_ptr()); + return message->DebugString(); + } + + const LegacyTypeAccessApis* GetAccessApis( + const internal::MessageWrapper& wrapped_message) const override { + return this; + } + + static DucktypedMessageAdapter& GetSingleton() { + static cel::internal::NoDestructor instance; + return *instance; + } +}; + +CelValue MessageCelValueFactory(const google::protobuf::Message* message) { + return CelValue::CreateMessageWrapper(internal::MessageWrapper(message)); +} + +} // namespace + +absl::Status ProtoMessageTypeAdapter::ValidateSetFieldOp( + bool assertion, absl::string_view field, absl::string_view detail) const { + if (!assertion) { + return absl::InvalidArgumentError( + absl::Substitute("SetField failed on message $0, field '$1': $2", + descriptor_->full_name(), field, detail)); + } + return absl::OkStatus(); +} + +absl::StatusOr ProtoMessageTypeAdapter::NewInstance( + cel::MemoryManager& memory_manager) const { + // This implementation requires arena-backed memory manager. + google::protobuf::Arena* arena = ProtoMemoryManager::CastToProtoArena(memory_manager); + const Message* prototype = message_factory_->GetPrototype(descriptor_); + + Message* msg = (prototype != nullptr) ? prototype->New(arena) : nullptr; + + if (msg == nullptr) { + return absl::InvalidArgumentError( + absl::StrCat("Failed to create message ", descriptor_->name())); + } + return CelValue::MessageWrapper(msg); +} + +bool ProtoMessageTypeAdapter::DefinesField(absl::string_view field_name) const { + return descriptor_->FindFieldByName(field_name.data()) != nullptr; +} + +absl::StatusOr ProtoMessageTypeAdapter::HasField( + absl::string_view field_name, const CelValue::MessageWrapper& value) const { + CEL_ASSIGN_OR_RETURN(const google::protobuf::Message* message, + UnwrapMessage(value, "HasField")); + return HasFieldImpl(message, descriptor_, field_name); +} + +absl::StatusOr ProtoMessageTypeAdapter::GetField( + absl::string_view field_name, const CelValue::MessageWrapper& instance, + ProtoWrapperTypeOptions unboxing_option, + cel::MemoryManager& memory_manager) const { + CEL_ASSIGN_OR_RETURN(const google::protobuf::Message* message, + UnwrapMessage(instance, "GetField")); + + return GetFieldImpl(message, descriptor_, field_name, unboxing_option, + memory_manager); +} + absl::Status ProtoMessageTypeAdapter::SetField( absl::string_view field_name, const CelValue& value, cel::MemoryManager& memory_manager, @@ -142,12 +244,8 @@ absl::Status ProtoMessageTypeAdapter::SetField( google::protobuf::Arena* arena = cel::extensions::ProtoMemoryManager::CastToProtoArena(memory_manager); - if (!instance.HasFullProto() || instance.message_ptr() == nullptr) { - return absl::InternalError("SetField called on non-message type."); - } - - const google::protobuf::Message* message = - cel::internal::down_cast(instance.message_ptr()); + CEL_ASSIGN_OR_RETURN(const google::protobuf::Message* message, + UnwrapMessage(instance, "SetField")); // Interpreter guarantees this is the top-level instance. google::protobuf::Message* mutable_message = const_cast(message); @@ -192,9 +290,9 @@ absl::Status ProtoMessageTypeAdapter::SetField( "error serializing CelMap")); Message* entry_msg = mutable_message->GetReflection()->AddMessage( mutable_message, field_descriptor); - CEL_RETURN_IF_ERROR( - SetValueToSingleField(key, key_field_descriptor, entry_msg, arena)); - CEL_RETURN_IF_ERROR(SetValueToSingleField( + CEL_RETURN_IF_ERROR(internal::SetValueToSingleField( + key, key_field_descriptor, entry_msg, arena)); + CEL_RETURN_IF_ERROR(internal::SetValueToSingleField( value.value(), value_field_descriptor, entry_msg, arena)); } @@ -205,12 +303,12 @@ absl::Status ProtoMessageTypeAdapter::SetField( field_name, "expected CelList value")); for (int i = 0; i < cel_list->size(); i++) { - CEL_RETURN_IF_ERROR(AddValueToRepeatedField( + CEL_RETURN_IF_ERROR(internal::AddValueToRepeatedField( (*cel_list)[i], field_descriptor, mutable_message, arena)); } } else { - CEL_RETURN_IF_ERROR( - SetValueToSingleField(value, field_descriptor, mutable_message, arena)); + CEL_RETURN_IF_ERROR(internal::SetValueToSingleField( + value, field_descriptor, mutable_message, arena)); } return absl::OkStatus(); } @@ -221,13 +319,14 @@ absl::StatusOr ProtoMessageTypeAdapter::AdaptFromWellKnownType( // Assume proto arena implementation if this provider is used. google::protobuf::Arena* arena = cel::extensions::ProtoMemoryManager::CastToProtoArena(memory_manager); - if (!instance.HasFullProto() || instance.message_ptr() == nullptr) { - return absl::InternalError( - "Adapt from well-known type failed: not a message"); - } - auto* message = - cel::internal::down_cast(instance.message_ptr()); - return CelProtoWrapper::CreateMessage(message, arena); + CEL_ASSIGN_OR_RETURN(const google::protobuf::Message* message, + UnwrapMessage(instance, "AdaptFromWellKnownType")); + return internal::UnwrapMessageToValue(message, &MessageCelValueFactory, + arena); +} + +const LegacyTypeInfoApis& GetGenericProtoTypeInfoInstance() { + return DucktypedMessageAdapter::GetSingleton(); } } // namespace google::api::expr::runtime diff --git a/eval/public/structs/proto_message_type_adapter.h b/eval/public/structs/proto_message_type_adapter.h index 478354fbb..99e22e89a 100644 --- a/eval/public/structs/proto_message_type_adapter.h +++ b/eval/public/structs/proto_message_type_adapter.h @@ -23,6 +23,7 @@ #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/structs/legacy_type_adapter.h" +#include "eval/public/structs/legacy_type_info_apis.h" namespace google::api::expr::runtime { @@ -67,6 +68,11 @@ class ProtoMessageTypeAdapter : public LegacyTypeAccessApis, const google::protobuf::Descriptor* descriptor_; }; +// Returns a TypeInfo provider representing an arbitrary message. +// This allows for the legacy duck-typed behavior of messages on field access +// instead of expecting a particular message type given a TypeInfo. +const LegacyTypeInfoApis& GetGenericProtoTypeInfoInstance(); + } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_PROTO_MESSAGE_TYPE_ADAPTER_H_ diff --git a/eval/public/structs/proto_message_type_adapter_test.cc b/eval/public/structs/proto_message_type_adapter_test.cc index 3d65be7ef..de7208a4b 100644 --- a/eval/public/structs/proto_message_type_adapter_test.cc +++ b/eval/public/structs/proto_message_type_adapter_test.cc @@ -26,6 +26,7 @@ #include "eval/public/containers/container_backed_map_impl.h" #include "eval/public/containers/field_access.h" #include "eval/public/structs/cel_proto_wrapper.h" +#include "eval/public/structs/legacy_type_adapter.h" #include "eval/public/testing/matchers.h" #include "eval/testutil/test_message.pb.h" #include "extensions/protobuf/memory_manager.h" @@ -35,6 +36,7 @@ namespace google::api::expr::runtime { namespace { +using ::cel::extensions::ProtoMemoryManager; using testing::_; using testing::EqualsProto; using testing::HasSubstr; @@ -42,145 +44,147 @@ using testing::Optional; using cel::internal::IsOkAndHolds; using cel::internal::StatusIs; -TEST(ProtoMessageTypeAdapter, HasFieldSingular) { +class ProtoMessageTypeAccessorTest : public testing::TestWithParam { + public: + ProtoMessageTypeAccessorTest() + : type_specific_instance_( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()) {} + + const LegacyTypeAccessApis& GetAccessApis() { + bool use_generic_instance = GetParam(); + if (use_generic_instance) { + // implementation detail: in general, type info implementations may + // return a different accessor object based on the messsage instance, but + // this implemenation returns the same one no matter the message. + return *GetGenericProtoTypeInfoInstance().GetAccessApis(dummy_); + + } else { + return type_specific_instance_; + } + } + + private: + ProtoMessageTypeAdapter type_specific_instance_; + CelValue::MessageWrapper dummy_; +}; + +TEST_P(ProtoMessageTypeAccessorTest, HasFieldSingular) { google::protobuf::Arena arena; - ProtoMessageTypeAdapter adapter( - google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( - "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory()); - + const LegacyTypeAccessApis& accessor = GetAccessApis(); TestMessage example; internal::MessageWrapper value(&example); - EXPECT_THAT(adapter.HasField("int64_value", value), IsOkAndHolds(false)); + EXPECT_THAT(accessor.HasField("int64_value", value), IsOkAndHolds(false)); example.set_int64_value(10); - EXPECT_THAT(adapter.HasField("int64_value", value), IsOkAndHolds(true)); + EXPECT_THAT(accessor.HasField("int64_value", value), IsOkAndHolds(true)); } -TEST(ProtoMessageTypeAdapter, HasFieldRepeated) { +TEST_P(ProtoMessageTypeAccessorTest, HasFieldRepeated) { google::protobuf::Arena arena; - ProtoMessageTypeAdapter adapter( - google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( - "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory()); + const LegacyTypeAccessApis& accessor = GetAccessApis(); TestMessage example; internal::MessageWrapper value(&example); - EXPECT_THAT(adapter.HasField("int64_list", value), IsOkAndHolds(false)); + EXPECT_THAT(accessor.HasField("int64_list", value), IsOkAndHolds(false)); example.add_int64_list(10); - EXPECT_THAT(adapter.HasField("int64_list", value), IsOkAndHolds(true)); + EXPECT_THAT(accessor.HasField("int64_list", value), IsOkAndHolds(true)); } -TEST(ProtoMessageTypeAdapter, HasFieldMap) { +TEST_P(ProtoMessageTypeAccessorTest, HasFieldMap) { google::protobuf::Arena arena; - ProtoMessageTypeAdapter adapter( - google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( - "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory()); + const LegacyTypeAccessApis& accessor = GetAccessApis(); TestMessage example; example.set_int64_value(10); internal::MessageWrapper value(&example); - EXPECT_THAT(adapter.HasField("int64_int32_map", value), IsOkAndHolds(false)); + EXPECT_THAT(accessor.HasField("int64_int32_map", value), IsOkAndHolds(false)); (*example.mutable_int64_int32_map())[2] = 3; - EXPECT_THAT(adapter.HasField("int64_int32_map", value), IsOkAndHolds(true)); + EXPECT_THAT(accessor.HasField("int64_int32_map", value), IsOkAndHolds(true)); } -TEST(ProtoMessageTypeAdapter, HasFieldUnknownField) { +TEST_P(ProtoMessageTypeAccessorTest, HasFieldUnknownField) { google::protobuf::Arena arena; - ProtoMessageTypeAdapter adapter( - google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( - "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory()); + const LegacyTypeAccessApis& accessor = GetAccessApis(); TestMessage example; example.set_int64_value(10); internal::MessageWrapper value(&example); - EXPECT_THAT(adapter.HasField("unknown_field", value), + EXPECT_THAT(accessor.HasField("unknown_field", value), StatusIs(absl::StatusCode::kNotFound)); } -TEST(ProtoMessageTypeAdapter, HasFieldNonMessageType) { +TEST_P(ProtoMessageTypeAccessorTest, HasFieldNonMessageType) { google::protobuf::Arena arena; - ProtoMessageTypeAdapter adapter( - google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( - "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory()); + const LegacyTypeAccessApis& accessor = GetAccessApis(); internal::MessageWrapper value( static_cast(nullptr)); - EXPECT_THAT(adapter.HasField("unknown_field", value), - StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(accessor.HasField("unknown_field", value), + StatusIs(absl::StatusCode::kInternal)); } -TEST(ProtoMessageTypeAdapter, GetFieldSingular) { +TEST_P(ProtoMessageTypeAccessorTest, GetFieldSingular) { google::protobuf::Arena arena; - ProtoMessageTypeAdapter adapter( - google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( - "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory()); - cel::extensions::ProtoMemoryManager manager(&arena); + const LegacyTypeAccessApis& accessor = GetAccessApis(); + + ProtoMemoryManager manager(&arena); TestMessage example; example.set_int64_value(10); internal::MessageWrapper value(&example); - EXPECT_THAT(adapter.GetField("int64_value", value, - ProtoWrapperTypeOptions::kUnsetNull, manager), + EXPECT_THAT(accessor.GetField("int64_value", value, + ProtoWrapperTypeOptions::kUnsetNull, manager), IsOkAndHolds(test::IsCelInt64(10))); } -TEST(ProtoMessageTypeAdapter, GetFieldNoSuchField) { +TEST_P(ProtoMessageTypeAccessorTest, GetFieldNoSuchField) { google::protobuf::Arena arena; - ProtoMessageTypeAdapter adapter( - google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( - "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory()); - cel::extensions::ProtoMemoryManager manager(&arena); + const LegacyTypeAccessApis& accessor = GetAccessApis(); + + ProtoMemoryManager manager(&arena); TestMessage example; example.set_int64_value(10); internal::MessageWrapper value(&example); - EXPECT_THAT(adapter.GetField("unknown_field", value, - ProtoWrapperTypeOptions::kUnsetNull, manager), + EXPECT_THAT(accessor.GetField("unknown_field", value, + ProtoWrapperTypeOptions::kUnsetNull, manager), IsOkAndHolds(test::IsCelError(StatusIs( absl::StatusCode::kNotFound, HasSubstr("unknown_field"))))); } -TEST(ProtoMessageTypeAdapter, GetFieldNotAMessage) { +TEST_P(ProtoMessageTypeAccessorTest, GetFieldNotAMessage) { google::protobuf::Arena arena; - ProtoMessageTypeAdapter adapter( - google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( - "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory()); - cel::extensions::ProtoMemoryManager manager(&arena); + const LegacyTypeAccessApis& accessor = GetAccessApis(); + + ProtoMemoryManager manager(&arena); internal::MessageWrapper value( static_cast(nullptr)); - EXPECT_THAT(adapter.GetField("int64_value", value, - ProtoWrapperTypeOptions::kUnsetNull, manager), - StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(accessor.GetField("int64_value", value, + ProtoWrapperTypeOptions::kUnsetNull, manager), + StatusIs(absl::StatusCode::kInternal)); } -TEST(ProtoMessageTypeAdapter, GetFieldRepeated) { +TEST_P(ProtoMessageTypeAccessorTest, GetFieldRepeated) { google::protobuf::Arena arena; - ProtoMessageTypeAdapter adapter( - google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( - "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory()); - cel::extensions::ProtoMemoryManager manager(&arena); + const LegacyTypeAccessApis& accessor = GetAccessApis(); + + ProtoMemoryManager manager(&arena); TestMessage example; example.add_int64_list(10); @@ -190,8 +194,8 @@ TEST(ProtoMessageTypeAdapter, GetFieldRepeated) { ASSERT_OK_AND_ASSIGN( CelValue result, - adapter.GetField("int64_list", value, ProtoWrapperTypeOptions::kUnsetNull, - manager)); + accessor.GetField("int64_list", value, + ProtoWrapperTypeOptions::kUnsetNull, manager)); const CelList* held_value; ASSERT_TRUE(result.GetValue(&held_value)) << result.DebugString(); @@ -201,13 +205,11 @@ TEST(ProtoMessageTypeAdapter, GetFieldRepeated) { EXPECT_THAT((*held_value)[1], test::IsCelInt64(20)); } -TEST(ProtoMessageTypeAdapter, GetFieldMap) { +TEST_P(ProtoMessageTypeAccessorTest, GetFieldMap) { google::protobuf::Arena arena; - ProtoMessageTypeAdapter adapter( - google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( - "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory()); - cel::extensions::ProtoMemoryManager manager(&arena); + const LegacyTypeAccessApis& accessor = GetAccessApis(); + + ProtoMemoryManager manager(&arena); TestMessage example; (*example.mutable_int64_int32_map())[10] = 20; @@ -216,8 +218,8 @@ TEST(ProtoMessageTypeAdapter, GetFieldMap) { ASSERT_OK_AND_ASSIGN( CelValue result, - adapter.GetField("int64_int32_map", value, - ProtoWrapperTypeOptions::kUnsetNull, manager)); + accessor.GetField("int64_int32_map", value, + ProtoWrapperTypeOptions::kUnsetNull, manager)); const CelMap* held_value; ASSERT_TRUE(result.GetValue(&held_value)) << result.DebugString(); @@ -227,62 +229,57 @@ TEST(ProtoMessageTypeAdapter, GetFieldMap) { Optional(test::IsCelInt64(20))); } -TEST(ProtoMessageTypeAdapter, GetFieldWrapperType) { +TEST_P(ProtoMessageTypeAccessorTest, GetFieldWrapperType) { google::protobuf::Arena arena; - ProtoMessageTypeAdapter adapter( - google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( - "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory()); - cel::extensions::ProtoMemoryManager manager(&arena); + const LegacyTypeAccessApis& accessor = GetAccessApis(); + + ProtoMemoryManager manager(&arena); TestMessage example; example.mutable_int64_wrapper_value()->set_value(10); internal::MessageWrapper value(&example); - EXPECT_THAT(adapter.GetField("int64_wrapper_value", value, - ProtoWrapperTypeOptions::kUnsetNull, manager), + EXPECT_THAT(accessor.GetField("int64_wrapper_value", value, + ProtoWrapperTypeOptions::kUnsetNull, manager), IsOkAndHolds(test::IsCelInt64(10))); } -TEST(ProtoMessageTypeAdapter, GetFieldWrapperTypeUnsetNullUnbox) { +TEST_P(ProtoMessageTypeAccessorTest, GetFieldWrapperTypeUnsetNullUnbox) { google::protobuf::Arena arena; - ProtoMessageTypeAdapter adapter( - google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( - "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory()); - cel::extensions::ProtoMemoryManager manager(&arena); + const LegacyTypeAccessApis& accessor = GetAccessApis(); + + ProtoMemoryManager manager(&arena); TestMessage example; internal::MessageWrapper value(&example); - EXPECT_THAT(adapter.GetField("int64_wrapper_value", value, - ProtoWrapperTypeOptions::kUnsetNull, manager), + EXPECT_THAT(accessor.GetField("int64_wrapper_value", value, + ProtoWrapperTypeOptions::kUnsetNull, manager), IsOkAndHolds(test::IsCelNull())); // Wrapper field present, but default value. example.mutable_int64_wrapper_value()->clear_value(); - EXPECT_THAT(adapter.GetField("int64_wrapper_value", value, - ProtoWrapperTypeOptions::kUnsetNull, manager), + EXPECT_THAT(accessor.GetField("int64_wrapper_value", value, + ProtoWrapperTypeOptions::kUnsetNull, manager), IsOkAndHolds(test::IsCelInt64(_))); } -TEST(ProtoMessageTypeAdapter, GetFieldWrapperTypeUnsetDefaultValueUnbox) { +TEST_P(ProtoMessageTypeAccessorTest, + GetFieldWrapperTypeUnsetDefaultValueUnbox) { google::protobuf::Arena arena; - ProtoMessageTypeAdapter adapter( - google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( - "google.api.expr.runtime.TestMessage"), - google::protobuf::MessageFactory::generated_factory()); - cel::extensions::ProtoMemoryManager manager(&arena); + const LegacyTypeAccessApis& accessor = GetAccessApis(); + + ProtoMemoryManager manager(&arena); TestMessage example; internal::MessageWrapper value(&example); EXPECT_THAT( - adapter.GetField("int64_wrapper_value", value, - ProtoWrapperTypeOptions::kUnsetProtoDefault, manager), + accessor.GetField("int64_wrapper_value", value, + ProtoWrapperTypeOptions::kUnsetProtoDefault, manager), IsOkAndHolds(test::IsCelInt64(_))); // Wrapper field present with unset value is used to signal Null, but legacy @@ -290,18 +287,78 @@ TEST(ProtoMessageTypeAdapter, GetFieldWrapperTypeUnsetDefaultValueUnbox) { example.mutable_int64_wrapper_value()->clear_value(); // Same behavior for this option. EXPECT_THAT( - adapter.GetField("int64_wrapper_value", value, - ProtoWrapperTypeOptions::kUnsetProtoDefault, manager), + accessor.GetField("int64_wrapper_value", value, + ProtoWrapperTypeOptions::kUnsetProtoDefault, manager), IsOkAndHolds(test::IsCelInt64(_))); } +INSTANTIATE_TEST_SUITE_P(GenericAndSpecific, ProtoMessageTypeAccessorTest, + testing::Bool()); + +TEST(GetGenericProtoTypeInfoInstance, GetTypeName) { + const LegacyTypeInfoApis& info_api = GetGenericProtoTypeInfoInstance(); + + TestMessage test_message; + CelValue::MessageWrapper wrapped_message(&test_message); + + EXPECT_EQ(info_api.GetTypename(wrapped_message), test_message.GetTypeName()); +} + +TEST(GetGenericProtoTypeInfoInstance, DebugString) { + const LegacyTypeInfoApis& info_api = GetGenericProtoTypeInfoInstance(); + + TestMessage test_message; + test_message.set_string_value("abcd"); + CelValue::MessageWrapper wrapped_message(&test_message); + + EXPECT_EQ(info_api.DebugString(wrapped_message), test_message.DebugString()); +} + +TEST(GetGenericProtoTypeInfoInstance, GetAccessApis) { + const LegacyTypeInfoApis& info_api = GetGenericProtoTypeInfoInstance(); + + TestMessage test_message; + test_message.set_string_value("abcd"); + CelValue::MessageWrapper wrapped_message(&test_message); + + auto* accessor = info_api.GetAccessApis(wrapped_message); + google::protobuf::Arena arena; + ProtoMemoryManager manager(&arena); + + ASSERT_OK_AND_ASSIGN( + CelValue result, + accessor->GetField("string_value", wrapped_message, + ProtoWrapperTypeOptions::kUnsetNull, manager)); + EXPECT_THAT(result, test::IsCelString("abcd")); +} + +TEST(GetGenericProtoTypeInfoInstance, FallbackForNonMessage) { + const LegacyTypeInfoApis& info_api = GetGenericProtoTypeInfoInstance(); + + TestMessage test_message; + test_message.set_string_value("abcd"); + // Upcast to signal no google::protobuf::Message / reflection support. + CelValue::MessageWrapper wrapped_message( + static_cast(&test_message)); + + EXPECT_EQ(info_api.GetTypename(wrapped_message), ""); + EXPECT_EQ(info_api.DebugString(wrapped_message), ""); + + // Check for not-null. + CelValue::MessageWrapper null_message( + static_cast(nullptr)); + + EXPECT_EQ(info_api.GetTypename(null_message), ""); + EXPECT_EQ(info_api.DebugString(null_message), ""); +} + TEST(ProtoMessageTypeAdapter, NewInstance) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); - cel::extensions::ProtoMemoryManager manager(&arena); + ProtoMemoryManager manager(&arena); ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper result, adapter.NewInstance(manager)); @@ -324,7 +381,7 @@ TEST(ProtoMessageTypeAdapter, NewInstanceUnsupportedDescriptor) { ProtoMessageTypeAdapter adapter( pool.FindMessageTypeByName("google.api.expr.runtime.FakeMessage"), google::protobuf::MessageFactory::generated_factory()); - cel::extensions::ProtoMemoryManager manager(&arena); + ProtoMemoryManager manager(&arena); // Message factory doesn't know how to create our custom message, even though // we provided a descriptor for it. @@ -349,7 +406,7 @@ TEST(ProtoMessageTypeAdapter, SetFieldSingular) { google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); - cel::extensions::ProtoMemoryManager manager(&arena); + ProtoMemoryManager manager(&arena); ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper value, adapter.NewInstance(manager)); @@ -371,7 +428,7 @@ TEST(ProtoMessageTypeAdapter, SetFieldMap) { google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); - cel::extensions::ProtoMemoryManager manager(&arena); + ProtoMemoryManager manager(&arena); CelMapBuilder builder; ASSERT_OK(builder.Add(CelValue::CreateInt64(1), CelValue::CreateInt64(2))); @@ -397,7 +454,7 @@ TEST(ProtoMessageTypeAdapter, SetFieldRepeated) { google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); - cel::extensions::ProtoMemoryManager manager(&arena); + ProtoMemoryManager manager(&arena); ContainerBackedListImpl list( {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); @@ -419,7 +476,7 @@ TEST(ProtoMessageTypeAdapter, SetFieldNotAField) { google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); - cel::extensions::ProtoMemoryManager manager(&arena); + ProtoMemoryManager manager(&arena); ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper instance, adapter.NewInstance(manager)); @@ -436,7 +493,7 @@ TEST(ProtoMesssageTypeAdapter, SetFieldWrongType) { google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); - cel::extensions::ProtoMemoryManager manager(&arena); + ProtoMemoryManager manager(&arena); ContainerBackedListImpl list( {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); @@ -476,7 +533,7 @@ TEST(ProtoMesssageTypeAdapter, SetFieldNotAMessage) { google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); - cel::extensions::ProtoMemoryManager manager(&arena); + ProtoMemoryManager manager(&arena); CelValue int_value = CelValue::CreateInt64(42); CelValue::MessageWrapper instance( @@ -486,13 +543,29 @@ TEST(ProtoMesssageTypeAdapter, SetFieldNotAMessage) { StatusIs(absl::StatusCode::kInternal)); } +TEST(ProtoMesssageTypeAdapter, SetFieldNullMessage) { + google::protobuf::Arena arena; + ProtoMessageTypeAdapter adapter( + google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "google.api.expr.runtime.TestMessage"), + google::protobuf::MessageFactory::generated_factory()); + ProtoMemoryManager manager(&arena); + + CelValue int_value = CelValue::CreateInt64(42); + CelValue::MessageWrapper instance( + static_cast(nullptr)); + + EXPECT_THAT(adapter.SetField("int64_value", int_value, manager, instance), + StatusIs(absl::StatusCode::kInternal)); +} + TEST(ProtoMessageTypeAdapter, AdaptFromWellKnownType) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.protobuf.Int64Value"), google::protobuf::MessageFactory::generated_factory()); - cel::extensions::ProtoMemoryManager manager(&arena); + ProtoMemoryManager manager(&arena); ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper instance, adapter.NewInstance(manager)); @@ -511,13 +584,13 @@ TEST(ProtoMessageTypeAdapter, AdaptFromWellKnownTypeUnspecial) { google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); - cel::extensions::ProtoMemoryManager manager(&arena); + ProtoMemoryManager manager(&arena); ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper instance, adapter.NewInstance(manager)); + ASSERT_OK(adapter.SetField("int64_value", CelValue::CreateInt64(42), manager, instance)); - ASSERT_OK_AND_ASSIGN(CelValue value, adapter.AdaptFromWellKnownType(manager, instance)); @@ -531,7 +604,7 @@ TEST(ProtoMessageTypeAdapter, AdaptFromWellKnownTypeNotAMessageError) { google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); - cel::extensions::ProtoMemoryManager manager(&arena); + ProtoMemoryManager manager(&arena); CelValue::MessageWrapper instance( static_cast(nullptr)); From b20d23b1bee61f62184a9ff1ebee5d0daf542a50 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Mon, 18 Apr 2022 17:48:56 +0000 Subject: [PATCH 076/111] Split internal library for encoding / decoding time from other proto-based utilities. PiperOrigin-RevId: 442579438 --- eval/eval/BUILD | 2 +- eval/eval/const_value_step.cc | 10 +- eval/public/BUILD | 7 +- eval/public/builtin_func_registrar.cc | 6 +- eval/public/comparison_functions.cc | 1 - eval/public/structs/BUILD | 8 +- eval/public/structs/cel_proto_wrap_util.cc | 14 +-- eval/public/structs/cel_proto_wrap_util.h | 4 - .../structs/cel_proto_wrap_util_test.cc | 6 +- eval/public/structs/cel_proto_wrapper.h | 6 +- eval/public/structs/cel_proto_wrapper_test.cc | 6 +- eval/public/transform_utility.cc | 6 +- eval/public/value_export_util.cc | 6 +- internal/BUILD | 29 ++++- internal/proto_time_encoding.cc | 102 ++++++++++++++++++ internal/proto_time_encoding.h | 49 +++++++++ internal/proto_time_encoding_test.cc | 73 +++++++++++++ internal/proto_util.cc | 74 ------------- internal/proto_util.h | 22 ---- 19 files changed, 289 insertions(+), 142 deletions(-) create mode 100644 internal/proto_time_encoding.cc create mode 100644 internal/proto_time_encoding.h create mode 100644 internal/proto_time_encoding_test.cc diff --git a/eval/eval/BUILD b/eval/eval/BUILD index a1a33e7c9..c586118af 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -88,7 +88,7 @@ cc_library( ":expression_step_base", "//eval/public:cel_value", "//eval/public/structs:cel_proto_wrapper", - "//internal:proto_util", + "//internal:proto_time_encoding", "@com_google_absl//absl/status:statusor", "@com_google_protobuf//:protobuf", ], diff --git a/eval/eval/const_value_step.cc b/eval/eval/const_value_step.cc index f010abc7d..33bac528b 100644 --- a/eval/eval/const_value_step.cc +++ b/eval/eval/const_value_step.cc @@ -7,7 +7,7 @@ #include "absl/status/statusor.h" #include "eval/eval/expression_step_base.h" #include "eval/public/structs/cel_proto_wrapper.h" -#include "internal/proto_util.h" +#include "internal/proto_time_encoding.h" namespace google::api::expr::runtime { @@ -60,11 +60,11 @@ absl::optional ConvertConstant(const Constant* const_expr) { break; case Constant::kDurationValue: value = CelValue::CreateDuration( - expr::internal::DecodeDuration(const_expr->duration_value())); + cel::internal::DecodeDuration(const_expr->duration_value())); break; case Constant::kTimestampValue: value = CelValue::CreateTimestamp( - expr::internal::DecodeTime(const_expr->timestamp_value())); + cel::internal::DecodeTime(const_expr->timestamp_value())); break; default: // constant with no kind specified @@ -76,13 +76,13 @@ absl::optional ConvertConstant(const Constant* const_expr) { absl::StatusOr> CreateConstValueStep( CelValue value, int64_t expr_id, bool comes_from_ast) { - return absl::make_unique(value, expr_id, comes_from_ast); + return std::make_unique(value, expr_id, comes_from_ast); } // Factory method for Constant(Enum value) - based Execution step absl::StatusOr> CreateConstValueStep( const google::protobuf::EnumValueDescriptor* value_descriptor, int64_t expr_id) { - return absl::make_unique( + return std::make_unique( CelValue::CreateInt64(value_descriptor->number()), expr_id, false); } diff --git a/eval/public/BUILD b/eval/public/BUILD index 1e0c64391..899e4e4a6 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -238,7 +238,7 @@ cc_library( "//eval/public/containers:container_backed_list_impl", "//internal:casts", "//internal:overflow", - "//internal:proto_util", + "//internal:proto_time_encoding", "//internal:status_macros", "//internal:time", "//internal:utf8", @@ -270,7 +270,6 @@ cc_library( "//eval/public/containers:container_backed_list_impl", "//internal:casts", "//internal:overflow", - "//internal:proto_util", "//internal:status_macros", "//internal:time", "//internal:utf8", @@ -446,7 +445,7 @@ cc_library( ], deps = [ ":cel_value", - "//internal:proto_util", + "//internal:proto_time_encoding", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", @@ -830,7 +829,7 @@ cc_library( "//eval/public/containers:container_backed_list_impl", "//eval/public/containers:container_backed_map_impl", "//eval/public/structs:cel_proto_wrapper", - "//internal:proto_util", + "//internal:proto_time_encoding", "//internal:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/eval/public/builtin_func_registrar.cc b/eval/public/builtin_func_registrar.cc index b57782fcd..75600d889 100644 --- a/eval/public/builtin_func_registrar.cc +++ b/eval/public/builtin_func_registrar.cc @@ -41,7 +41,7 @@ #include "eval/public/containers/container_backed_list_impl.h" #include "internal/casts.h" #include "internal/overflow.h" -#include "internal/proto_util.h" +#include "internal/proto_time_encoding.h" #include "internal/status_macros.h" #include "internal/time.h" #include "internal/utf8.h" @@ -51,9 +51,9 @@ namespace google::api::expr::runtime { namespace { +using ::cel::internal::EncodeDurationToString; +using ::cel::internal::EncodeTimeToString; using ::cel::internal::MaxTimestamp; -using ::google::api::expr::internal::EncodeDurationToString; -using ::google::api::expr::internal::EncodeTimeToString; using ::google::protobuf::Arena; // Time representing `9999-12-31T23:59:59.999999999Z`. diff --git a/eval/public/comparison_functions.cc b/eval/public/comparison_functions.cc index cc33df500..cc4cd6faf 100644 --- a/eval/public/comparison_functions.cc +++ b/eval/public/comparison_functions.cc @@ -41,7 +41,6 @@ #include "eval/public/containers/container_backed_list_impl.h" #include "internal/casts.h" #include "internal/overflow.h" -#include "internal/proto_util.h" #include "internal/status_macros.h" #include "internal/time.h" #include "internal/utf8.h" diff --git a/eval/public/structs/BUILD b/eval/public/structs/BUILD index d6d249224..60421cb25 100644 --- a/eval/public/structs/BUILD +++ b/eval/public/structs/BUILD @@ -27,7 +27,7 @@ cc_library( deps = [ ":cel_proto_wrap_util", "//eval/public:cel_value", - "//internal:proto_util", + "//internal:proto_time_encoding", "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf", ], @@ -57,7 +57,7 @@ cc_library( "//eval/public:cel_value", "//eval/testutil:test_message_cc_proto", "//internal:overflow", - "//internal:proto_util", + "//internal:proto_time_encoding", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", @@ -80,7 +80,7 @@ cc_test( "//eval/public/containers:container_backed_list_impl", "//eval/public/containers:container_backed_map_impl", "//eval/testutil:test_message_cc_proto", - "//internal:proto_util", + "//internal:proto_time_encoding", "//internal:status_macros", "//internal:testing", "//testutil:util", @@ -170,7 +170,7 @@ cc_test( "//eval/public/containers:container_backed_list_impl", "//eval/public/containers:container_backed_map_impl", "//eval/testutil:test_message_cc_proto", - "//internal:proto_util", + "//internal:proto_time_encoding", "//internal:status_macros", "//internal:testing", "//testutil:util", diff --git a/eval/public/structs/cel_proto_wrap_util.cc b/eval/public/structs/cel_proto_wrap_util.cc index 25f0c41e8..8ff817b7d 100644 --- a/eval/public/structs/cel_proto_wrap_util.cc +++ b/eval/public/structs/cel_proto_wrap_util.cc @@ -40,15 +40,15 @@ #include "eval/public/structs/protobuf_value_factory.h" #include "eval/testutil/test_message.pb.h" #include "internal/overflow.h" -#include "internal/proto_util.h" +#include "internal/proto_time_encoding.h" namespace google::api::expr::runtime::internal { namespace { -using google::api::expr::internal::DecodeDuration; -using google::api::expr::internal::DecodeTime; -using google::api::expr::internal::EncodeTime; +using cel::internal::DecodeDuration; +using cel::internal::DecodeTime; +using cel::internal::EncodeTime; using google::protobuf::Any; using google::protobuf::BoolValue; using google::protobuf::BytesValue; @@ -411,7 +411,7 @@ google::protobuf::Message* MessageFromValue(const CelValue& value, Duration* dur if (!value.GetValue(&val)) { return nullptr; } - auto status = google::api::expr::internal::EncodeDuration(val, duration); + auto status = cel::internal::EncodeDuration(val, duration); if (!status.ok()) { return nullptr; } @@ -603,7 +603,7 @@ google::protobuf::Message* MessageFromValue(const CelValue& value, Value* json) // Convert duration values to a protobuf JSON format. absl::Duration val; if (value.GetValue(&val)) { - auto encode = google::api::expr::internal::EncodeDurationToString(val); + auto encode = cel::internal::EncodeDurationToString(val); if (!encode.ok()) { return nullptr; } @@ -635,7 +635,7 @@ google::protobuf::Message* MessageFromValue(const CelValue& value, Value* json) // Convert timestamp values to a protobuf JSON format. absl::Time val; if (value.GetValue(&val)) { - auto encode = google::api::expr::internal::EncodeTimeToString(val); + auto encode = cel::internal::EncodeTimeToString(val); if (!encode.ok()) { return nullptr; } diff --git a/eval/public/structs/cel_proto_wrap_util.h b/eval/public/structs/cel_proto_wrap_util.h index a03f6ba2f..e828d3917 100644 --- a/eval/public/structs/cel_proto_wrap_util.h +++ b/eval/public/structs/cel_proto_wrap_util.h @@ -15,12 +15,8 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_CEL_PROTO_WRAP_UTIL_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_CEL_PROTO_WRAP_UTIL_H_ -#include "google/protobuf/duration.pb.h" -#include "google/protobuf/timestamp.pb.h" -#include "google/protobuf/descriptor.h" #include "eval/public/cel_value.h" #include "eval/public/structs/protobuf_value_factory.h" -#include "internal/proto_util.h" namespace google::api::expr::runtime::internal { diff --git a/eval/public/structs/cel_proto_wrap_util_test.cc b/eval/public/structs/cel_proto_wrap_util_test.cc index c4d5e0762..3a3e61f03 100644 --- a/eval/public/structs/cel_proto_wrap_util_test.cc +++ b/eval/public/structs/cel_proto_wrap_util_test.cc @@ -34,7 +34,7 @@ #include "eval/public/containers/container_backed_map_impl.h" #include "eval/public/structs/protobuf_value_factory.h" #include "eval/testutil/test_message.pb.h" -#include "internal/proto_util.h" +#include "internal/proto_time_encoding.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "testutil/util.h" @@ -172,7 +172,7 @@ TEST_F(CelProtoWrapperTest, TestDuration) { EXPECT_THAT(value.type(), Eq(CelValue::Type::kDuration)); Duration out; - auto status = expr::internal::EncodeDuration(value.DurationOrDie(), &out); + auto status = cel::internal::EncodeDuration(value.DurationOrDie(), &out); EXPECT_TRUE(status.ok()); EXPECT_THAT(out, testutil::EqualsProto(msg_duration)); } @@ -188,7 +188,7 @@ TEST_F(CelProtoWrapperTest, TestTimestamp) { EXPECT_TRUE(value.IsTimestamp()); Timestamp out; - auto status = expr::internal::EncodeTime(value.TimestampOrDie(), &out); + auto status = cel::internal::EncodeTime(value.TimestampOrDie(), &out); EXPECT_TRUE(status.ok()); EXPECT_THAT(out, testutil::EqualsProto(msg_timestamp)); } diff --git a/eval/public/structs/cel_proto_wrapper.h b/eval/public/structs/cel_proto_wrapper.h index 2d65155c5..ccfc19b8c 100644 --- a/eval/public/structs/cel_proto_wrapper.h +++ b/eval/public/structs/cel_proto_wrapper.h @@ -5,7 +5,7 @@ #include "google/protobuf/timestamp.pb.h" #include "google/protobuf/descriptor.h" #include "eval/public/cel_value.h" -#include "internal/proto_util.h" +#include "internal/proto_time_encoding.h" namespace google::api::expr::runtime { @@ -23,12 +23,12 @@ class CelProtoWrapper { // CreateDuration creates CelValue from a non-null protobuf duration value. static CelValue CreateDuration(const google::protobuf::Duration* value) { - return CelValue(expr::internal::DecodeDuration(*value)); + return CelValue(cel::internal::DecodeDuration(*value)); } // CreateTimestamp creates CelValue from a non-null protobuf timestamp value. static CelValue CreateTimestamp(const google::protobuf::Timestamp* value) { - return CelValue(expr::internal::DecodeTime(*value)); + return CelValue(cel::internal::DecodeTime(*value)); } // MaybeWrapValue attempts to wrap the input value in a proto message with diff --git a/eval/public/structs/cel_proto_wrapper_test.cc b/eval/public/structs/cel_proto_wrapper_test.cc index 296c32949..b9a7fefde 100644 --- a/eval/public/structs/cel_proto_wrapper_test.cc +++ b/eval/public/structs/cel_proto_wrapper_test.cc @@ -19,7 +19,7 @@ #include "eval/public/containers/container_backed_list_impl.h" #include "eval/public/containers/container_backed_map_impl.h" #include "eval/testutil/test_message.pb.h" -#include "internal/proto_util.h" +#include "internal/proto_time_encoding.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "testutil/util.h" @@ -162,7 +162,7 @@ TEST_F(CelProtoWrapperTest, TestDuration) { CelValue value = CelProtoWrapper::CreateDuration(&msg_duration); EXPECT_TRUE(value.IsDuration()); Duration out; - auto status = expr::internal::EncodeDuration(value.DurationOrDie(), &out); + auto status = cel::internal::EncodeDuration(value.DurationOrDie(), &out); EXPECT_TRUE(status.ok()); EXPECT_THAT(out, testutil::EqualsProto(msg_duration)); } @@ -183,7 +183,7 @@ TEST_F(CelProtoWrapperTest, TestTimestamp) { // CelValue value = CelValue::CreateString("test"); EXPECT_TRUE(value.IsTimestamp()); Timestamp out; - auto status = expr::internal::EncodeTime(value.TimestampOrDie(), &out); + auto status = cel::internal::EncodeTime(value.TimestampOrDie(), &out); EXPECT_TRUE(status.ok()); EXPECT_THAT(out, testutil::EqualsProto(msg_timestamp)); } diff --git a/eval/public/transform_utility.cc b/eval/public/transform_utility.cc index 1af0ac578..1a5cd5d6e 100644 --- a/eval/public/transform_utility.cc +++ b/eval/public/transform_utility.cc @@ -13,7 +13,7 @@ #include "eval/public/containers/container_backed_list_impl.h" #include "eval/public/containers/container_backed_map_impl.h" #include "eval/public/structs/cel_proto_wrapper.h" -#include "internal/proto_util.h" +#include "internal/proto_time_encoding.h" #include "internal/status_macros.h" @@ -47,7 +47,7 @@ absl::Status CelValueToValue(const CelValue& value, Value* result) { case CelValue::Type::kDuration: { google::protobuf::Duration duration; auto status = - expr::internal::EncodeDuration(value.DurationOrDie(), &duration); + cel::internal::EncodeDuration(value.DurationOrDie(), &duration); if (!status.ok()) { return status; } @@ -57,7 +57,7 @@ absl::Status CelValueToValue(const CelValue& value, Value* result) { case CelValue::Type::kTimestamp: { google::protobuf::Timestamp timestamp; auto status = - expr::internal::EncodeTime(value.TimestampOrDie(), ×tamp); + cel::internal::EncodeTime(value.TimestampOrDie(), ×tamp); if (!status.ok()) { return status; } diff --git a/eval/public/value_export_util.cc b/eval/public/value_export_util.cc index c95ef2006..481c3301c 100644 --- a/eval/public/value_export_util.cc +++ b/eval/public/value_export_util.cc @@ -6,7 +6,7 @@ #include "google/protobuf/util/time_util.h" #include "absl/strings/escaping.h" #include "absl/strings/str_cat.h" -#include "internal/proto_util.h" +#include "internal/proto_time_encoding.h" namespace google::api::expr::runtime { @@ -73,7 +73,7 @@ absl::Status ExportAsProtoValue(const CelValue& in_value, Value* out_value) { case CelValue::Type::kDuration: { Duration duration; auto status = - expr::internal::EncodeDuration(in_value.DurationOrDie(), &duration); + cel::internal::EncodeDuration(in_value.DurationOrDie(), &duration); if (!status.ok()) { return status; } @@ -83,7 +83,7 @@ absl::Status ExportAsProtoValue(const CelValue& in_value, Value* out_value) { case CelValue::Type::kTimestamp: { Timestamp timestamp; auto status = - expr::internal::EncodeTime(in_value.TimestampOrDie(), ×tamp); + cel::internal::EncodeTime(in_value.TimestampOrDie(), ×tamp); if (!status.ok()) { return status; } diff --git a/internal/BUILD b/internal/BUILD index e4981349b..8e7483aef 100644 --- a/internal/BUILD +++ b/internal/BUILD @@ -137,13 +137,11 @@ cc_library( hdrs = ["proto_util.h"], deps = [ ":status_macros", - ":time", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/time", "@com_google_protobuf//:protobuf", ], ) @@ -159,6 +157,33 @@ cc_test( ], ) +cc_library( + name = "proto_time_encoding", + srcs = ["proto_time_encoding.cc"], + hdrs = ["proto_time_encoding.h"], + deps = [ + ":status_macros", + ":time", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/time", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "proto_time_encoding_test", + srcs = ["proto_time_encoding_test.cc"], + deps = [ + ":proto_time_encoding", + ":testing", + "@com_google_absl//absl/time", + "@com_google_protobuf//:protobuf", + ], +) + cc_library( name = "rtti", hdrs = ["rtti.h"], diff --git a/internal/proto_time_encoding.cc b/internal/proto_time_encoding.cc new file mode 100644 index 000000000..f61f3dbcd --- /dev/null +++ b/internal/proto_time_encoding.cc @@ -0,0 +1,102 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/proto_time_encoding.h" + +#include + +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "google/protobuf/util/time_util.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/time/time.h" +#include "internal/status_macros.h" +#include "internal/time.h" + +namespace cel::internal { + +namespace { + +absl::Status Validate(absl::Time time) { + if (time < cel::internal::MinTimestamp()) { + return absl::InvalidArgumentError("time below min"); + } + + if (time > cel::internal::MaxTimestamp()) { + return absl::InvalidArgumentError("time above max"); + } + return absl::OkStatus(); +} + +absl::Status CelValidateDuration(absl::Duration duration) { + if (duration < cel::internal::MinDuration()) { + return absl::InvalidArgumentError("duration below min"); + } + + if (duration > cel::internal::MaxDuration()) { + return absl::InvalidArgumentError("duration above max"); + } + return absl::OkStatus(); +} + +} // namespace + +absl::Duration DecodeDuration(const google::protobuf::Duration& proto) { + return absl::Seconds(proto.seconds()) + absl::Nanoseconds(proto.nanos()); +} + +absl::Time DecodeTime(const google::protobuf::Timestamp& proto) { + return absl::FromUnixSeconds(proto.seconds()) + + absl::Nanoseconds(proto.nanos()); +} + +absl::Status EncodeDuration(absl::Duration duration, + google::protobuf::Duration* proto) { + CEL_RETURN_IF_ERROR(CelValidateDuration(duration)); + // s and n may both be negative, per the Duration proto spec. + const int64_t s = absl::IDivDuration(duration, absl::Seconds(1), &duration); + const int64_t n = absl::IDivDuration(duration, absl::Nanoseconds(1), &duration); + proto->set_seconds(s); + proto->set_nanos(n); + return absl::OkStatus(); +} + +absl::StatusOr EncodeDurationToString(absl::Duration duration) { + google::protobuf::Duration d; + auto status = EncodeDuration(duration, &d); + if (!status.ok()) { + return status; + } + return google::protobuf::util::TimeUtil::ToString(d); +} + +absl::Status EncodeTime(absl::Time time, google::protobuf::Timestamp* proto) { + CEL_RETURN_IF_ERROR(Validate(time)); + const int64_t s = absl::ToUnixSeconds(time); + proto->set_seconds(s); + proto->set_nanos((time - absl::FromUnixSeconds(s)) / absl::Nanoseconds(1)); + return absl::OkStatus(); +} + +absl::StatusOr EncodeTimeToString(absl::Time time) { + google::protobuf::Timestamp t; + auto status = EncodeTime(time, &t); + if (!status.ok()) { + return status; + } + return google::protobuf::util::TimeUtil::ToString(t); +} + +} // namespace cel::internal diff --git a/internal/proto_time_encoding.h b/internal/proto_time_encoding.h new file mode 100644 index 000000000..aa4128ee7 --- /dev/null +++ b/internal/proto_time_encoding.h @@ -0,0 +1,49 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Defines basic encode/decode operations for proto time and duration formats. +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_PROTO_TIME_ENCODING_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_PROTO_TIME_ENCODING_H_ + +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/time/time.h" + +namespace cel::internal { + +/** Helper function to encode a duration in a google::protobuf::Duration. */ +absl::Status EncodeDuration(absl::Duration duration, + google::protobuf::Duration* proto); + +/** Helper function to encode an absl::Duration to a JSON-formatted string. */ +absl::StatusOr EncodeDurationToString(absl::Duration duration); + +/** Helper function to encode a time in a google::protobuf::Timestamp. */ +absl::Status EncodeTime(absl::Time time, google::protobuf::Timestamp* proto); + +/** Helper function to encode an absl::Time to a JSON-formatted string. */ +absl::StatusOr EncodeTimeToString(absl::Time time); + +/** Helper function to decode a duration from a google::protobuf::Duration. */ +absl::Duration DecodeDuration(const google::protobuf::Duration& proto); + +/** Helper function to decode a time from a google::protobuf::Timestamp. */ +absl::Time DecodeTime(const google::protobuf::Timestamp& proto); + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_PROTO_TIME_ENCODING_H_ diff --git a/internal/proto_time_encoding_test.cc b/internal/proto_time_encoding_test.cc new file mode 100644 index 000000000..19342354d --- /dev/null +++ b/internal/proto_time_encoding_test.cc @@ -0,0 +1,73 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/proto_time_encoding.h" + +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "absl/time/time.h" +#include "internal/testing.h" + +namespace cel::internal { +namespace { + +using testing::EqualsProto; + +TEST(EncodeDuration, Basic) { + google::protobuf::Duration proto_duration; + ASSERT_OK( + EncodeDuration(absl::Seconds(2) + absl::Nanoseconds(3), &proto_duration)); + + EXPECT_THAT(proto_duration, EqualsProto("seconds: 2 nanos: 3")); +} + +TEST(EncodeDurationToString, Basic) { + ASSERT_OK_AND_ASSIGN( + std::string json, + EncodeDurationToString(absl::Seconds(5) + absl::Nanoseconds(2))); + EXPECT_EQ(json, "5.000000002s"); +} + +TEST(EncodeTime, Basic) { + google::protobuf::Timestamp proto_timestamp; + ASSERT_OK(EncodeTime(absl::FromUnixMillis(300000), &proto_timestamp)); + + EXPECT_THAT(proto_timestamp, EqualsProto("seconds: 300")); +} + +TEST(EncodeTimeToString, Basic) { + ASSERT_OK_AND_ASSIGN(std::string json, + EncodeTimeToString(absl::FromUnixMillis(80000))); + + EXPECT_EQ(json, "1970-01-01T00:01:20Z"); +} + +TEST(DecodeDuration, Basic) { + google::protobuf::Duration proto_duration; + proto_duration.set_seconds(450); + proto_duration.set_nanos(4); + + EXPECT_EQ(DecodeDuration(proto_duration), + absl::Seconds(450) + absl::Nanoseconds(4)); +} + +TEST(DecodeTime, Basic) { + google::protobuf::Timestamp proto_timestamp; + proto_timestamp.set_seconds(450); + + EXPECT_EQ(DecodeTime(proto_timestamp), absl::FromUnixSeconds(450)); +} + +} // namespace +} // namespace cel::internal diff --git a/internal/proto_util.cc b/internal/proto_util.cc index 7bc7d049f..9353196ed 100644 --- a/internal/proto_util.cc +++ b/internal/proto_util.cc @@ -21,89 +21,15 @@ #include "google/protobuf/struct.pb.h" #include "google/protobuf/timestamp.pb.h" #include "google/protobuf/wrappers.pb.h" -#include "google/protobuf/util/time_util.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "internal/status_macros.h" -#include "internal/time.h" namespace google { namespace api { namespace expr { namespace internal { -namespace { - -absl::Status Validate(absl::Time time) { - if (time < cel::internal::MinTimestamp()) { - return absl::InvalidArgumentError("time below min"); - } - - if (time > cel::internal::MaxTimestamp()) { - return absl::InvalidArgumentError("time above max"); - } - return absl::OkStatus(); -} - -absl::Status ValidateDuration(absl::Duration duration) { - if (duration < cel::internal::MinDuration()) { - return absl::InvalidArgumentError("duration below min"); - } - - if (duration > cel::internal::MaxDuration()) { - return absl::InvalidArgumentError("duration above max"); - } - return absl::OkStatus(); -} - -} // namespace - -absl::Duration DecodeDuration(const google::protobuf::Duration& proto) { - return absl::Seconds(proto.seconds()) + absl::Nanoseconds(proto.nanos()); -} - -absl::Time DecodeTime(const google::protobuf::Timestamp& proto) { - return absl::FromUnixSeconds(proto.seconds()) + - absl::Nanoseconds(proto.nanos()); -} - -absl::Status EncodeDuration(absl::Duration duration, - google::protobuf::Duration* proto) { - CEL_RETURN_IF_ERROR(ValidateDuration(duration)); - // s and n may both be negative, per the Duration proto spec. - const int64_t s = absl::IDivDuration(duration, absl::Seconds(1), &duration); - const int64_t n = absl::IDivDuration(duration, absl::Nanoseconds(1), &duration); - proto->set_seconds(s); - proto->set_nanos(n); - return absl::OkStatus(); -} - -absl::StatusOr EncodeDurationToString(absl::Duration duration) { - google::protobuf::Duration d; - auto status = EncodeDuration(duration, &d); - if (!status.ok()) { - return status; - } - return google::protobuf::util::TimeUtil::ToString(d); -} - -absl::Status EncodeTime(absl::Time time, google::protobuf::Timestamp* proto) { - CEL_RETURN_IF_ERROR(Validate(time)); - const int64_t s = absl::ToUnixSeconds(time); - proto->set_seconds(s); - proto->set_nanos((time - absl::FromUnixSeconds(s)) / absl::Nanoseconds(1)); - return absl::OkStatus(); -} - -absl::StatusOr EncodeTimeToString(absl::Time time) { - google::protobuf::Timestamp t; - auto status = EncodeTime(time, &t); - if (!status.ok()) { - return status; - } - return google::protobuf::util::TimeUtil::ToString(t); -} - absl::Status ValidateStandardMessageTypes( const google::protobuf::DescriptorPool& descriptor_pool) { CEL_RETURN_IF_ERROR( diff --git a/internal/proto_util.h b/internal/proto_util.h index 386d1309a..09cd66502 100644 --- a/internal/proto_util.h +++ b/internal/proto_util.h @@ -15,15 +15,12 @@ #ifndef THIRD_PARTY_CEL_CPP_INTERNAL_PROTO_UTIL_H_ #define THIRD_PARTY_CEL_CPP_INTERNAL_PROTO_UTIL_H_ -#include "google/protobuf/duration.pb.h" -#include "google/protobuf/timestamp.pb.h" #include "google/protobuf/descriptor.pb.h" #include "google/protobuf/util/message_differencer.h" #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" -#include "absl/time/time.h" namespace google { namespace api { @@ -37,25 +34,6 @@ struct DefaultProtoEqual { } }; -/** Helper function to encode a duration in a google::protobuf::Duration. */ -absl::Status EncodeDuration(absl::Duration duration, - google::protobuf::Duration* proto); - -/** Helper function to encode an absl::Duration to a JSON-formatted string. */ -absl::StatusOr EncodeDurationToString(absl::Duration duration); - -/** Helper function to encode a time in a google::protobuf::Timestamp. */ -absl::Status EncodeTime(absl::Time time, google::protobuf::Timestamp* proto); - -/** Helper function to encode an absl::Time to a JSON-formatted string. */ -absl::StatusOr EncodeTimeToString(absl::Time time); - -/** Helper function to decode a duration from a google::protobuf::Duration. */ -absl::Duration DecodeDuration(const google::protobuf::Duration& proto); - -/** Helper function to decode a time from a google::protobuf::Timestamp. */ -absl::Time DecodeTime(const google::protobuf::Timestamp& proto); - template absl::Status ValidateStandardMessageType( const google::protobuf::DescriptorPool& descriptor_pool) { From d13b104a4db9ca8a1b6e976d2e293c1ec08fd0d4 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Mon, 18 Apr 2022 19:02:52 +0000 Subject: [PATCH 077/111] Wire TypeInfoApis into CelValue::MessageWrapper. Not used anywhere yet. Includes miscellaneous build clean fixes to avoid cyclic dependency. PiperOrigin-RevId: 442598029 --- eval/public/BUILD | 5 + eval/public/cel_value.h | 15 +- eval/public/cel_value_internal.h | 30 +++- eval/public/cel_value_test.cc | 12 +- eval/public/containers/BUILD | 1 - .../internal_field_backed_map_impl.cc | 1 - eval/public/structs/BUILD | 30 +++- .../structs/cel_proto_wrap_util_test.cc | 6 +- eval/public/structs/cel_proto_wrapper.cc | 5 +- eval/public/structs/field_access_impl_test.cc | 135 +++++++++--------- .../structs/legacy_type_adapter_test.cc | 3 +- eval/public/structs/legacy_type_info_apis.h | 10 +- .../structs/proto_message_type_adapter.cc | 6 +- .../proto_message_type_adapter_test.cc | 43 +++--- .../public/structs/trivial_legacy_type_info.h | 56 ++++++++ .../structs/trivial_legacy_type_info_test.cc | 49 +++++++ 16 files changed, 286 insertions(+), 121 deletions(-) create mode 100644 eval/public/structs/trivial_legacy_type_info.h create mode 100644 eval/public/structs/trivial_legacy_type_info_test.cc diff --git a/eval/public/BUILD b/eval/public/BUILD index 899e4e4a6..064409f0f 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -474,12 +474,17 @@ cc_test( ], deps = [ ":cel_value", + ":cel_value_internal", ":unknown_attribute_set", ":unknown_set", "//base:memory_manager", + "//eval/public/structs:legacy_type_adapter", + "//eval/public/structs:legacy_type_info_apis", + "//eval/public/structs:trivial_legacy_type_info", "//eval/public/testing:matchers", "//eval/testutil:test_message_cc_proto", "//extensions/protobuf:memory_manager", + "//internal:no_destructor", "//internal:status_macros", "//internal:testing", "@com_google_absl//absl/status", diff --git a/eval/public/cel_value.h b/eval/public/cel_value.h index 345e22b04..d0ba11dbd 100644 --- a/eval/public/cel_value.h +++ b/eval/public/cel_value.h @@ -118,9 +118,10 @@ class CelValue { // MessageWrapper wraps a tagged MessageLite with the accessors used to // get field values. // - // message_ptr(): get the MessageLite pointer for the wrapper. + // message_ptr(): get the MessageLite pointer of the wrapped message. // - // access_apis(): get the accessors used for the type. + // legacy_type_info(): get type information about the wrapped message. see + // LegacyTypeInfoApis. // // HasFullProto(): returns whether it's safe to downcast to google::protobuf::Message. using MessageWrapper = internal::MessageWrapper; @@ -420,6 +421,7 @@ class CelValue { // make private visibility after refactors are done. static CelValue CreateMessageWrapper(MessageWrapper value) { CheckNullPointer(value.message_ptr(), Type::kMessage); + CheckNullPointer(value.legacy_type_info(), Type::kMessage); return CelValue(value); } @@ -462,18 +464,11 @@ class CelValue { template explicit CelValue(T value) : value_(value) {} - // Overloads for creating Message types. This should only be used by - // internal libraries. - static CelValue CreateMessage(const google::protobuf::Message* value) { - CheckNullPointer(value, Type::kMessage); - return CelValue(MessageWrapper(value)); - } - // This is provided for backwards compatibility with resolving null to message // overloads. static CelValue CreateNullMessage() { return CelValue( - MessageWrapper(static_cast(nullptr))); + MessageWrapper(static_cast(nullptr), nullptr)); } // Crashes with a null pointer error. diff --git a/eval/public/cel_value_internal.h b/eval/public/cel_value_internal.h index 52ad77ab1..1281635ee 100644 --- a/eval/public/cel_value_internal.h +++ b/eval/public/cel_value_internal.h @@ -27,7 +27,12 @@ #include "absl/types/variant.h" #include "internal/casts.h" -namespace google::api::expr::runtime::internal { +namespace google::api::expr::runtime { + +// Forward declare to resolve circular dependency. +class LegacyTypeInfoApis; + +namespace internal { // Helper classes needed for IndexOf metafunction implementation. template @@ -87,14 +92,19 @@ class MessageWrapper { public: static_assert(alignof(google::protobuf::MessageLite) >= 2, "Assume that valid MessageLite ptrs have a free low-order bit"); - MessageWrapper() : message_ptr_(0) {} - explicit MessageWrapper(const google::protobuf::MessageLite* message) - : message_ptr_(reinterpret_cast(message)) { + MessageWrapper() : message_ptr_(0), legacy_type_info_(nullptr) {} + + MessageWrapper(const google::protobuf::MessageLite* message, + const LegacyTypeInfoApis* legacy_type_info) + : message_ptr_(reinterpret_cast(message)), + legacy_type_info_(legacy_type_info) { ABSL_ASSERT(absl::countr_zero(reinterpret_cast(message)) >= 1); } - explicit MessageWrapper(const google::protobuf::Message* message) - : message_ptr_(reinterpret_cast(message) | kTagMask) { + MessageWrapper(const google::protobuf::Message* message, + const LegacyTypeInfoApis* legacy_type_info) + : message_ptr_(reinterpret_cast(message) | kTagMask), + legacy_type_info_(legacy_type_info) { ABSL_ASSERT(absl::countr_zero(reinterpret_cast(message)) >= 1); } @@ -105,10 +115,15 @@ class MessageWrapper { kPtrMask); } + const LegacyTypeInfoApis* legacy_type_info() const { + return legacy_type_info_; + } + private: static constexpr uintptr_t kTagMask = 1 << 0; static constexpr uintptr_t kPtrMask = ~kTagMask; uintptr_t message_ptr_; + const LegacyTypeInfoApis* legacy_type_info_; // TODO(issues/5): add LegacyTypeAccessApis to expose generic accessors for // MessageLite. }; @@ -136,6 +151,7 @@ struct MessageVisitAdapter { Op op; }; -} // namespace google::api::expr::runtime::internal +} // namespace internal +} // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_VALUE_INTERNAL_H_ diff --git a/eval/public/cel_value_test.cc b/eval/public/cel_value_test.cc index 537ebc20b..6f542e47b 100644 --- a/eval/public/cel_value_test.cc +++ b/eval/public/cel_value_test.cc @@ -7,6 +7,9 @@ #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "base/memory_manager.h" +#include "eval/public/cel_value_internal.h" +#include "eval/public/structs/legacy_type_info_apis.h" +#include "eval/public/structs/trivial_legacy_type_info.h" #include "eval/public/testing/matchers.h" #include "eval/public/unknown_attribute_set.h" #include "eval/public/unknown_set.h" @@ -378,26 +381,29 @@ TEST(CelValueTest, DebugString) { TEST(CelValueTest, Message) { TestMessage message; - auto value = - CelValue::CreateMessageWrapper(CelValue::MessageWrapper(&message)); + auto value = CelValue::CreateMessageWrapper( + CelValue::MessageWrapper(&message, TrivialTypeInfo::GetInstance())); EXPECT_TRUE(value.IsMessage()); CelValue::MessageWrapper held; ASSERT_TRUE(value.GetValue(&held)); EXPECT_TRUE(held.HasFullProto()); EXPECT_EQ(held.message_ptr(), static_cast(&message)); + EXPECT_EQ(held.legacy_type_info(), TrivialTypeInfo::GetInstance()); } TEST(CelValueTest, MessageLite) { TestMessage message; // Upcast to message lite. const google::protobuf::MessageLite* ptr = &message; - auto value = CelValue::CreateMessageWrapper(CelValue::MessageWrapper(ptr)); + auto value = CelValue::CreateMessageWrapper( + CelValue::MessageWrapper(ptr, TrivialTypeInfo::GetInstance())); EXPECT_TRUE(value.IsMessage()); CelValue::MessageWrapper held; ASSERT_TRUE(value.GetValue(&held)); EXPECT_FALSE(held.HasFullProto()); EXPECT_EQ(held.message_ptr(), &message); + EXPECT_EQ(held.legacy_type_info(), TrivialTypeInfo::GetInstance()); } TEST(CelValueTest, Size) { diff --git a/eval/public/containers/BUILD b/eval/public/containers/BUILD index 3eb5effe6..f75b314ae 100644 --- a/eval/public/containers/BUILD +++ b/eval/public/containers/BUILD @@ -202,7 +202,6 @@ cc_library( "internal_field_backed_map_impl.h", ], deps = [ - ":field_access", "//eval/public:cel_value", "//eval/public/structs:field_access_impl", "//eval/public/structs:protobuf_value_factory", diff --git a/eval/public/containers/internal_field_backed_map_impl.cc b/eval/public/containers/internal_field_backed_map_impl.cc index 2c837f64d..4eabb99ad 100644 --- a/eval/public/containers/internal_field_backed_map_impl.cc +++ b/eval/public/containers/internal_field_backed_map_impl.cc @@ -24,7 +24,6 @@ #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "eval/public/cel_value.h" -#include "eval/public/containers/field_access.h" #include "eval/public/structs/field_access_impl.h" #include "eval/public/structs/protobuf_value_factory.h" diff --git a/eval/public/structs/BUILD b/eval/public/structs/BUILD index 60421cb25..23ee01efc 100644 --- a/eval/public/structs/BUILD +++ b/eval/public/structs/BUILD @@ -26,7 +26,9 @@ cc_library( ], deps = [ ":cel_proto_wrap_util", + ":proto_message_type_adapter", "//eval/public:cel_value", + "//eval/public:cel_value_internal", "//internal:proto_time_encoding", "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf", @@ -75,11 +77,15 @@ cc_test( ], deps = [ ":cel_proto_wrap_util", + ":legacy_type_info_apis", ":protobuf_value_factory", + ":trivial_legacy_type_info", "//eval/public:cel_value", + "//eval/public:cel_value_internal", "//eval/public/containers:container_backed_list_impl", "//eval/public/containers:container_backed_map_impl", "//eval/testutil:test_message_cc_proto", + "//internal:no_destructor", "//internal:proto_time_encoding", "//internal:status_macros", "//internal:testing", @@ -207,6 +213,7 @@ cc_test( srcs = ["legacy_type_adapter_test.cc"], deps = [ ":legacy_type_adapter", + ":trivial_legacy_type_info", "//eval/public:cel_value", "//eval/public/testing:matchers", "//eval/testutil:test_message_cc_proto", @@ -230,7 +237,6 @@ cc_library( "//eval/public:cel_options", "//eval/public:cel_value", "//eval/public:cel_value_internal", - "//eval/public/containers:field_access", "//eval/public/containers:internal_field_backed_list_impl", "//eval/public/containers:internal_field_backed_map_impl", "//extensions/protobuf:memory_manager", @@ -249,6 +255,7 @@ cc_test( deps = [ ":cel_proto_wrapper", ":legacy_type_adapter", + ":legacy_type_info_apis", ":proto_message_type_adapter", "//eval/public:cel_value", "//eval/public:cel_value_internal", @@ -300,3 +307,24 @@ cc_library( hdrs = ["legacy_type_info_apis.h"], deps = ["//eval/public:cel_value_internal"], ) + +cc_library( + name = "trivial_legacy_type_info", + testonly = True, + hdrs = ["trivial_legacy_type_info.h"], + deps = [ + ":legacy_type_info_apis", + "//eval/public:cel_value_internal", + "//internal:no_destructor", + ], +) + +cc_test( + name = "trivial_legacy_type_info_test", + srcs = ["trivial_legacy_type_info_test.cc"], + deps = [ + ":trivial_legacy_type_info", + "//eval/public:cel_value_internal", + "//internal:testing", + ], +) diff --git a/eval/public/structs/cel_proto_wrap_util_test.cc b/eval/public/structs/cel_proto_wrap_util_test.cc index 3a3e61f03..57c838746 100644 --- a/eval/public/structs/cel_proto_wrap_util_test.cc +++ b/eval/public/structs/cel_proto_wrap_util_test.cc @@ -30,10 +30,13 @@ #include "absl/strings/str_cat.h" #include "absl/time/time.h" #include "eval/public/cel_value.h" +#include "eval/public/cel_value_internal.h" #include "eval/public/containers/container_backed_list_impl.h" #include "eval/public/containers/container_backed_map_impl.h" #include "eval/public/structs/protobuf_value_factory.h" +#include "eval/public/structs/trivial_legacy_type_info.h" #include "eval/testutil/test_message.pb.h" +#include "internal/no_destructor.h" #include "internal/proto_time_encoding.h" #include "internal/status_macros.h" #include "internal/testing.h" @@ -66,7 +69,8 @@ using google::protobuf::UInt64Value; using google::protobuf::Arena; CelValue ProtobufValueFactoryImpl(const google::protobuf::Message* m) { - return CelValue::CreateMessageWrapper(CelValue::MessageWrapper(m)); + return CelValue::CreateMessageWrapper( + CelValue::MessageWrapper(m, TrivialTypeInfo::GetInstance())); } class CelProtoWrapperTest : public ::testing::Test { diff --git a/eval/public/structs/cel_proto_wrapper.cc b/eval/public/structs/cel_proto_wrapper.cc index 496f134e8..07fb68945 100644 --- a/eval/public/structs/cel_proto_wrapper.cc +++ b/eval/public/structs/cel_proto_wrapper.cc @@ -17,7 +17,9 @@ #include "google/protobuf/message.h" #include "absl/types/optional.h" #include "eval/public/cel_value.h" +#include "eval/public/cel_value_internal.h" #include "eval/public/structs/cel_proto_wrap_util.h" +#include "eval/public/structs/proto_message_type_adapter.h" namespace google::api::expr::runtime { @@ -30,7 +32,8 @@ using ::google::protobuf::Message; } // namespace CelValue CelProtoWrapper::InternalWrapMessage(const Message* message) { - return CelValue::CreateMessage(message); + return CelValue::CreateMessageWrapper( + internal::MessageWrapper(message, &GetGenericProtoTypeInfoInstance())); } // CreateMessage creates CelValue from google::protobuf::Message. diff --git a/eval/public/structs/field_access_impl_test.cc b/eval/public/structs/field_access_impl_test.cc index caa697760..3036eb902 100644 --- a/eval/public/structs/field_access_impl_test.cc +++ b/eval/public/structs/field_access_impl_test.cc @@ -45,10 +45,6 @@ using testing::EqualsProto; using testing::HasSubstr; using cel::internal::StatusIs; -CelValue MessageValueFactory(const google::protobuf::Message* message) { - return CelValue::CreateMessageWrapper(CelValue::MessageWrapper(message)); -} - TEST(FieldAccessTest, SetDuration) { Arena arena; TestAllTypes msg; @@ -195,8 +191,8 @@ TEST_P(SingleFieldTest, Getter) { CreateValueFromSingleField( &test_message, test_message.GetDescriptor()->FindFieldByName(field_name().data()), - ProtoWrapperTypeOptions::kUnsetProtoDefault, &MessageValueFactory, - &arena)); + ProtoWrapperTypeOptions::kUnsetProtoDefault, + &CelProtoWrapper::InternalWrapMessage, &arena)); EXPECT_THAT(accessed_value, test::EqualsCelValue(cel_value())); } @@ -255,8 +251,8 @@ TEST(CreateValueFromSingleFieldTest, GetMessage) { CreateValueFromSingleField( &test_message, test_message.GetDescriptor()->FindFieldByName("standalone_message"), - ProtoWrapperTypeOptions::kUnsetProtoDefault, &MessageValueFactory, - &arena)); + ProtoWrapperTypeOptions::kUnsetProtoDefault, + &CelProtoWrapper::InternalWrapMessage, &arena)); EXPECT_THAT(accessed_value, test::IsCelMessage(EqualsProto("bb: 10"))); } @@ -372,7 +368,7 @@ TEST_P(RepeatedFieldTest, GetFirstElem) { CreateValueFromRepeatedField( &test_message, test_message.GetDescriptor()->FindFieldByName(field_name().data()), 0, - &MessageValueFactory, &arena)); + &CelProtoWrapper::InternalWrapMessage, &arena)); EXPECT_THAT(accessed_value, test::EqualsCelValue(cel_value())); } @@ -427,7 +423,7 @@ TEST(RepeatedFieldTest, GetMessage) { &test_message, test_message.GetDescriptor()->FindFieldByName( "repeated_nested_message"), - 0, &MessageValueFactory, &arena)); + 0, &CelProtoWrapper::InternalWrapMessage, &arena)); EXPECT_THAT(accessed_value, test::IsCelMessage(EqualsProto("bb: 30"))); } @@ -507,11 +503,11 @@ TEST(CreateValueFromFieldTest, UnsetWrapperTypesNullIfEnabled) { for (const auto& field : kWrapperFieldNames) { ASSERT_OK_AND_ASSIGN( - result, - CreateValueFromSingleField( - &test_message, - TestAllTypes::GetDescriptor()->FindFieldByName(field), - ProtoWrapperTypeOptions::kUnsetNull, &MessageValueFactory, &arena)); + result, CreateValueFromSingleField( + &test_message, + TestAllTypes::GetDescriptor()->FindFieldByName(field), + ProtoWrapperTypeOptions::kUnsetNull, + &CelProtoWrapper::InternalWrapMessage, &arena)); ASSERT_TRUE(result.IsNull()) << field << ": " << result.DebugString(); } } @@ -529,7 +525,7 @@ TEST(CreateValueFromFieldTest, UnsetWrapperTypesDefaultValueIfDisabled) { &test_message, TestAllTypes::GetDescriptor()->FindFieldByName(field), ProtoWrapperTypeOptions::kUnsetProtoDefault, - &MessageValueFactory, &arena)); + &CelProtoWrapper::InternalWrapMessage, &arena)); ASSERT_FALSE(result.IsNull()) << field << ": " << result.DebugString(); } } @@ -560,85 +556,88 @@ TEST(CreateValueFromFieldTest, SetWrapperTypesDefaultValue) { CreateValueFromSingleField( &test_message, TestAllTypes::GetDescriptor()->FindFieldByName("single_bool_wrapper"), - ProtoWrapperTypeOptions::kUnsetNull, &MessageValueFactory, &arena)); + ProtoWrapperTypeOptions::kUnsetNull, + &CelProtoWrapper::InternalWrapMessage, &arena)); EXPECT_THAT(result, test::IsCelBool(false)); - ASSERT_OK_AND_ASSIGN( - result, - CreateValueFromSingleField(&test_message, - TestAllTypes::GetDescriptor()->FindFieldByName( - "single_int64_wrapper"), - ProtoWrapperTypeOptions::kUnsetNull, - &MessageValueFactory, &arena)); + ASSERT_OK_AND_ASSIGN(result, + CreateValueFromSingleField( + &test_message, + TestAllTypes::GetDescriptor()->FindFieldByName( + "single_int64_wrapper"), + ProtoWrapperTypeOptions::kUnsetNull, + &CelProtoWrapper::InternalWrapMessage, &arena)); + EXPECT_THAT(result, test::IsCelInt64(0)); + + ASSERT_OK_AND_ASSIGN(result, + CreateValueFromSingleField( + &test_message, + TestAllTypes::GetDescriptor()->FindFieldByName( + "single_int32_wrapper"), + ProtoWrapperTypeOptions::kUnsetNull, + &CelProtoWrapper::InternalWrapMessage, &arena)); EXPECT_THAT(result, test::IsCelInt64(0)); ASSERT_OK_AND_ASSIGN( result, CreateValueFromSingleField(&test_message, TestAllTypes::GetDescriptor()->FindFieldByName( - "single_int32_wrapper"), + "single_uint64_wrapper"), ProtoWrapperTypeOptions::kUnsetNull, - &MessageValueFactory, &arena)); - EXPECT_THAT(result, test::IsCelInt64(0)); - - ASSERT_OK_AND_ASSIGN( - result, CreateValueFromSingleField( - &test_message, - TestAllTypes::GetDescriptor()->FindFieldByName( - "single_uint64_wrapper"), - ProtoWrapperTypeOptions::kUnsetNull, &MessageValueFactory, + &CelProtoWrapper::InternalWrapMessage, - &arena)); - EXPECT_THAT(result, test::IsCelUint64(0)); - - ASSERT_OK_AND_ASSIGN( - result, CreateValueFromSingleField( - &test_message, - TestAllTypes::GetDescriptor()->FindFieldByName( - "single_uint32_wrapper"), - ProtoWrapperTypeOptions::kUnsetNull, &MessageValueFactory, - - &arena)); + &arena)); EXPECT_THAT(result, test::IsCelUint64(0)); ASSERT_OK_AND_ASSIGN( result, CreateValueFromSingleField(&test_message, TestAllTypes::GetDescriptor()->FindFieldByName( - "single_double_wrapper"), + "single_uint32_wrapper"), ProtoWrapperTypeOptions::kUnsetNull, + &CelProtoWrapper::InternalWrapMessage, - &MessageValueFactory, &arena)); + &arena)); + EXPECT_THAT(result, test::IsCelUint64(0)); + + ASSERT_OK_AND_ASSIGN(result, + CreateValueFromSingleField( + &test_message, + TestAllTypes::GetDescriptor()->FindFieldByName( + "single_double_wrapper"), + ProtoWrapperTypeOptions::kUnsetNull, + + &CelProtoWrapper::InternalWrapMessage, &arena)); EXPECT_THAT(result, test::IsCelDouble(0.0f)); - ASSERT_OK_AND_ASSIGN( - result, - CreateValueFromSingleField(&test_message, - TestAllTypes::GetDescriptor()->FindFieldByName( - "single_float_wrapper"), - ProtoWrapperTypeOptions::kUnsetNull, + ASSERT_OK_AND_ASSIGN(result, + CreateValueFromSingleField( + &test_message, + TestAllTypes::GetDescriptor()->FindFieldByName( + "single_float_wrapper"), + ProtoWrapperTypeOptions::kUnsetNull, - &MessageValueFactory, &arena)); + &CelProtoWrapper::InternalWrapMessage, &arena)); EXPECT_THAT(result, test::IsCelDouble(0.0f)); - ASSERT_OK_AND_ASSIGN( - result, - CreateValueFromSingleField(&test_message, - TestAllTypes::GetDescriptor()->FindFieldByName( - "single_string_wrapper"), - ProtoWrapperTypeOptions::kUnsetNull, + ASSERT_OK_AND_ASSIGN(result, + CreateValueFromSingleField( + &test_message, + TestAllTypes::GetDescriptor()->FindFieldByName( + "single_string_wrapper"), + ProtoWrapperTypeOptions::kUnsetNull, - &MessageValueFactory, &arena)); + &CelProtoWrapper::InternalWrapMessage, &arena)); EXPECT_THAT(result, test::IsCelString("")); - ASSERT_OK_AND_ASSIGN( - result, - CreateValueFromSingleField(&test_message, - TestAllTypes::GetDescriptor()->FindFieldByName( - "single_bytes_wrapper"), - ProtoWrapperTypeOptions::kUnsetNull, + ASSERT_OK_AND_ASSIGN(result, + CreateValueFromSingleField( + &test_message, + TestAllTypes::GetDescriptor()->FindFieldByName( + "single_bytes_wrapper"), + ProtoWrapperTypeOptions::kUnsetNull, - &MessageValueFactory, &arena)); + &CelProtoWrapper::InternalWrapMessage, &arena)); EXPECT_THAT(result, test::IsCelBytes("")); } diff --git a/eval/public/structs/legacy_type_adapter_test.cc b/eval/public/structs/legacy_type_adapter_test.cc index b6fe9a7f5..c51289e51 100644 --- a/eval/public/structs/legacy_type_adapter_test.cc +++ b/eval/public/structs/legacy_type_adapter_test.cc @@ -16,6 +16,7 @@ #include "google/protobuf/arena.h" #include "eval/public/cel_value.h" +#include "eval/public/structs/trivial_legacy_type_info.h" #include "eval/public/testing/matchers.h" #include "eval/testutil/test_message.pb.h" #include "extensions/protobuf/memory_manager.h" @@ -47,7 +48,7 @@ class TestMutationApiImpl : public LegacyTypeMutationApis { TEST(LegacyTypeAdapterMutationApis, DefaultNoopAdapt) { TestMessage message; - internal::MessageWrapper wrapper(&message); + internal::MessageWrapper wrapper(&message, TrivialTypeInfo::GetInstance()); google::protobuf::Arena arena; cel::extensions::ProtoMemoryManager manager(&arena); diff --git a/eval/public/structs/legacy_type_info_apis.h b/eval/public/structs/legacy_type_info_apis.h index 5971f23de..939dc8a94 100644 --- a/eval/public/structs/legacy_type_info_apis.h +++ b/eval/public/structs/legacy_type_info_apis.h @@ -49,9 +49,13 @@ class LegacyTypeInfoApis { const internal::MessageWrapper& wrapped_message) const = 0; // Return a pointer to the wrapped message's access api implementation. - // The CEL interpreter assumes that the is owned externally and will - // outlive any CelValues created by the interpreter. - // Nullptr means the value does not provide access apis. + // + // The CEL interpreter assumes that the returned pointer is owned externally + // and will outlive any CelValues created by the interpreter. + // + // Nullptr signals that the value does not provide access apis. For field + // access, the interpreter will treat this the same as accessing a field that + // is not defined for the type. virtual const LegacyTypeAccessApis* GetAccessApis( const internal::MessageWrapper& wrapped_message) const = 0; }; diff --git a/eval/public/structs/proto_message_type_adapter.cc b/eval/public/structs/proto_message_type_adapter.cc index 199feca9a..e8630eb63 100644 --- a/eval/public/structs/proto_message_type_adapter.cc +++ b/eval/public/structs/proto_message_type_adapter.cc @@ -24,7 +24,6 @@ #include "absl/strings/substitute.h" #include "eval/public/cel_value.h" #include "eval/public/cel_value_internal.h" -#include "eval/public/containers/field_access.h" #include "eval/public/containers/internal_field_backed_list_impl.h" #include "eval/public/containers/internal_field_backed_map_impl.h" #include "eval/public/structs/cel_proto_wrap_util.h" @@ -184,7 +183,8 @@ class DucktypedMessageAdapter : public LegacyTypeAccessApis, }; CelValue MessageCelValueFactory(const google::protobuf::Message* message) { - return CelValue::CreateMessageWrapper(internal::MessageWrapper(message)); + return CelValue::CreateMessageWrapper(internal::MessageWrapper( + message, &DucktypedMessageAdapter::GetSingleton())); } } // namespace @@ -211,7 +211,7 @@ absl::StatusOr ProtoMessageTypeAdapter::NewInstance( return absl::InvalidArgumentError( absl::StrCat("Failed to create message ", descriptor_->name())); } - return CelValue::MessageWrapper(msg); + return CelValue::MessageWrapper(msg, &GetGenericProtoTypeInfoInstance()); } bool ProtoMessageTypeAdapter::DefinesField(absl::string_view field_name) const { diff --git a/eval/public/structs/proto_message_type_adapter_test.cc b/eval/public/structs/proto_message_type_adapter_test.cc index de7208a4b..de3b90ad9 100644 --- a/eval/public/structs/proto_message_type_adapter_test.cc +++ b/eval/public/structs/proto_message_type_adapter_test.cc @@ -27,6 +27,7 @@ #include "eval/public/containers/field_access.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "eval/public/structs/legacy_type_adapter.h" +#include "eval/public/structs/legacy_type_info_apis.h" #include "eval/public/testing/matchers.h" #include "eval/testutil/test_message.pb.h" #include "extensions/protobuf/memory_manager.h" @@ -75,7 +76,7 @@ TEST_P(ProtoMessageTypeAccessorTest, HasFieldSingular) { const LegacyTypeAccessApis& accessor = GetAccessApis(); TestMessage example; - internal::MessageWrapper value(&example); + internal::MessageWrapper value(&example, nullptr); EXPECT_THAT(accessor.HasField("int64_value", value), IsOkAndHolds(false)); example.set_int64_value(10); @@ -88,7 +89,7 @@ TEST_P(ProtoMessageTypeAccessorTest, HasFieldRepeated) { TestMessage example; - internal::MessageWrapper value(&example); + internal::MessageWrapper value(&example, nullptr); EXPECT_THAT(accessor.HasField("int64_list", value), IsOkAndHolds(false)); example.add_int64_list(10); @@ -102,7 +103,7 @@ TEST_P(ProtoMessageTypeAccessorTest, HasFieldMap) { TestMessage example; example.set_int64_value(10); - internal::MessageWrapper value(&example); + internal::MessageWrapper value(&example, nullptr); EXPECT_THAT(accessor.HasField("int64_int32_map", value), IsOkAndHolds(false)); (*example.mutable_int64_int32_map())[2] = 3; @@ -116,7 +117,7 @@ TEST_P(ProtoMessageTypeAccessorTest, HasFieldUnknownField) { TestMessage example; example.set_int64_value(10); - internal::MessageWrapper value(&example); + internal::MessageWrapper value(&example, nullptr); EXPECT_THAT(accessor.HasField("unknown_field", value), StatusIs(absl::StatusCode::kNotFound)); @@ -127,7 +128,7 @@ TEST_P(ProtoMessageTypeAccessorTest, HasFieldNonMessageType) { const LegacyTypeAccessApis& accessor = GetAccessApis(); internal::MessageWrapper value( - static_cast(nullptr)); + static_cast(nullptr), nullptr); EXPECT_THAT(accessor.HasField("unknown_field", value), StatusIs(absl::StatusCode::kInternal)); @@ -142,7 +143,7 @@ TEST_P(ProtoMessageTypeAccessorTest, GetFieldSingular) { TestMessage example; example.set_int64_value(10); - internal::MessageWrapper value(&example); + internal::MessageWrapper value(&example, nullptr); EXPECT_THAT(accessor.GetField("int64_value", value, ProtoWrapperTypeOptions::kUnsetNull, manager), @@ -158,7 +159,7 @@ TEST_P(ProtoMessageTypeAccessorTest, GetFieldNoSuchField) { TestMessage example; example.set_int64_value(10); - internal::MessageWrapper value(&example); + internal::MessageWrapper value(&example, nullptr); EXPECT_THAT(accessor.GetField("unknown_field", value, ProtoWrapperTypeOptions::kUnsetNull, manager), @@ -173,7 +174,7 @@ TEST_P(ProtoMessageTypeAccessorTest, GetFieldNotAMessage) { ProtoMemoryManager manager(&arena); internal::MessageWrapper value( - static_cast(nullptr)); + static_cast(nullptr), nullptr); EXPECT_THAT(accessor.GetField("int64_value", value, ProtoWrapperTypeOptions::kUnsetNull, manager), @@ -190,7 +191,7 @@ TEST_P(ProtoMessageTypeAccessorTest, GetFieldRepeated) { example.add_int64_list(10); example.add_int64_list(20); - internal::MessageWrapper value(&example); + internal::MessageWrapper value(&example, nullptr); ASSERT_OK_AND_ASSIGN( CelValue result, @@ -214,7 +215,7 @@ TEST_P(ProtoMessageTypeAccessorTest, GetFieldMap) { TestMessage example; (*example.mutable_int64_int32_map())[10] = 20; - internal::MessageWrapper value(&example); + internal::MessageWrapper value(&example, nullptr); ASSERT_OK_AND_ASSIGN( CelValue result, @@ -238,7 +239,7 @@ TEST_P(ProtoMessageTypeAccessorTest, GetFieldWrapperType) { TestMessage example; example.mutable_int64_wrapper_value()->set_value(10); - internal::MessageWrapper value(&example); + internal::MessageWrapper value(&example, nullptr); EXPECT_THAT(accessor.GetField("int64_wrapper_value", value, ProtoWrapperTypeOptions::kUnsetNull, manager), @@ -253,7 +254,7 @@ TEST_P(ProtoMessageTypeAccessorTest, GetFieldWrapperTypeUnsetNullUnbox) { TestMessage example; - internal::MessageWrapper value(&example); + internal::MessageWrapper value(&example, nullptr); EXPECT_THAT(accessor.GetField("int64_wrapper_value", value, ProtoWrapperTypeOptions::kUnsetNull, manager), @@ -275,7 +276,7 @@ TEST_P(ProtoMessageTypeAccessorTest, TestMessage example; - internal::MessageWrapper value(&example); + internal::MessageWrapper value(&example, nullptr); EXPECT_THAT( accessor.GetField("int64_wrapper_value", value, @@ -299,7 +300,7 @@ TEST(GetGenericProtoTypeInfoInstance, GetTypeName) { const LegacyTypeInfoApis& info_api = GetGenericProtoTypeInfoInstance(); TestMessage test_message; - CelValue::MessageWrapper wrapped_message(&test_message); + CelValue::MessageWrapper wrapped_message(&test_message, nullptr); EXPECT_EQ(info_api.GetTypename(wrapped_message), test_message.GetTypeName()); } @@ -309,7 +310,7 @@ TEST(GetGenericProtoTypeInfoInstance, DebugString) { TestMessage test_message; test_message.set_string_value("abcd"); - CelValue::MessageWrapper wrapped_message(&test_message); + CelValue::MessageWrapper wrapped_message(&test_message, nullptr); EXPECT_EQ(info_api.DebugString(wrapped_message), test_message.DebugString()); } @@ -319,7 +320,7 @@ TEST(GetGenericProtoTypeInfoInstance, GetAccessApis) { TestMessage test_message; test_message.set_string_value("abcd"); - CelValue::MessageWrapper wrapped_message(&test_message); + CelValue::MessageWrapper wrapped_message(&test_message, nullptr); auto* accessor = info_api.GetAccessApis(wrapped_message); google::protobuf::Arena arena; @@ -339,14 +340,14 @@ TEST(GetGenericProtoTypeInfoInstance, FallbackForNonMessage) { test_message.set_string_value("abcd"); // Upcast to signal no google::protobuf::Message / reflection support. CelValue::MessageWrapper wrapped_message( - static_cast(&test_message)); + static_cast(&test_message), nullptr); EXPECT_EQ(info_api.GetTypename(wrapped_message), ""); EXPECT_EQ(info_api.DebugString(wrapped_message), ""); // Check for not-null. CelValue::MessageWrapper null_message( - static_cast(nullptr)); + static_cast(nullptr), nullptr); EXPECT_EQ(info_api.GetTypename(null_message), ""); EXPECT_EQ(info_api.DebugString(null_message), ""); @@ -537,7 +538,7 @@ TEST(ProtoMesssageTypeAdapter, SetFieldNotAMessage) { CelValue int_value = CelValue::CreateInt64(42); CelValue::MessageWrapper instance( - static_cast(nullptr)); + static_cast(nullptr), nullptr); EXPECT_THAT(adapter.SetField("int64_value", int_value, manager, instance), StatusIs(absl::StatusCode::kInternal)); @@ -553,7 +554,7 @@ TEST(ProtoMesssageTypeAdapter, SetFieldNullMessage) { CelValue int_value = CelValue::CreateInt64(42); CelValue::MessageWrapper instance( - static_cast(nullptr)); + static_cast(nullptr), nullptr); EXPECT_THAT(adapter.SetField("int64_value", int_value, manager, instance), StatusIs(absl::StatusCode::kInternal)); @@ -607,7 +608,7 @@ TEST(ProtoMessageTypeAdapter, AdaptFromWellKnownTypeNotAMessageError) { ProtoMemoryManager manager(&arena); CelValue::MessageWrapper instance( - static_cast(nullptr)); + static_cast(nullptr), nullptr); // Interpreter guaranteed to call this with a message type, otherwise, // something has broken. diff --git a/eval/public/structs/trivial_legacy_type_info.h b/eval/public/structs/trivial_legacy_type_info.h new file mode 100644 index 000000000..eabff8858 --- /dev/null +++ b/eval/public/structs/trivial_legacy_type_info.h @@ -0,0 +1,56 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_TRIVIAL_LEGACY_TYPE_INFO_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_TRIVIAL_LEGACY_TYPE_INFO_H_ + +#include + +#include "eval/public/cel_value_internal.h" +#include "eval/public/structs/legacy_type_info_apis.h" +#include "internal/no_destructor.h" + +namespace google::api::expr::runtime { + +// Implementation of type info APIs suitable for testing where no message +// operations need to be supported. +class TrivialTypeInfo : public LegacyTypeInfoApis { + public: + const std::string& GetTypename( + const internal::MessageWrapper& wrapper) const override { + static cel::internal::NoDestructor kTypename("opaque type"); + return *kTypename; + } + + std::string DebugString( + const internal::MessageWrapper& wrapper) const override { + return "opaque"; + } + + const LegacyTypeAccessApis* GetAccessApis( + const internal::MessageWrapper& wrapper) const override { + // Accessors unsupported -- caller should treat this as an opaque type (no + // fields defined, field access always results in a CEL error). + return nullptr; + } + + static const TrivialTypeInfo* GetInstance() { + static cel::internal::NoDestructor kInstance; + return &(kInstance.get()); + } +}; + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_TRIVIAL_LEGACY_TYPE_INFO_H_ diff --git a/eval/public/structs/trivial_legacy_type_info_test.cc b/eval/public/structs/trivial_legacy_type_info_test.cc new file mode 100644 index 000000000..36832e888 --- /dev/null +++ b/eval/public/structs/trivial_legacy_type_info_test.cc @@ -0,0 +1,49 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/structs/trivial_legacy_type_info.h" + +#include "eval/public/cel_value_internal.h" +#include "internal/testing.h" + +namespace google::api::expr::runtime { +namespace { + +TEST(TrivialTypeInfo, GetTypename) { + TrivialTypeInfo info; + internal::MessageWrapper wrapper; + + EXPECT_EQ(info.GetTypename(wrapper), "opaque type"); + EXPECT_EQ(TrivialTypeInfo::GetInstance()->GetTypename(wrapper), + "opaque type"); +} + +TEST(TrivialTypeInfo, DebugString) { + TrivialTypeInfo info; + internal::MessageWrapper wrapper; + + EXPECT_EQ(info.DebugString(wrapper), "opaque"); + EXPECT_EQ(TrivialTypeInfo::GetInstance()->DebugString(wrapper), "opaque"); +} + +TEST(TrivialTypeInfo, GetAccessApis) { + TrivialTypeInfo info; + internal::MessageWrapper wrapper; + + EXPECT_EQ(info.GetAccessApis(wrapper), nullptr); + EXPECT_EQ(TrivialTypeInfo::GetInstance()->GetAccessApis(wrapper), nullptr); +} + +} // namespace +} // namespace google::api::expr::runtime From af70f58c41b7d323e0b985675e322a3dd2694a51 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Tue, 19 Apr 2022 17:52:08 +0000 Subject: [PATCH 078/111] Update CelValue to provide full typename and debug string via type info APIs instead of directly calling proto reflection APIs. PiperOrigin-RevId: 442854438 --- eval/public/BUILD | 1 + eval/public/cel_value.cc | 18 ++++++++++++------ eval/public/cel_value_test.cc | 5 +++++ .../public/structs/cel_proto_wrap_util_test.cc | 4 +++- .../structs/proto_message_type_adapter.cc | 2 +- .../structs/proto_message_type_adapter_test.cc | 3 ++- 6 files changed, 24 insertions(+), 9 deletions(-) diff --git a/eval/public/BUILD b/eval/public/BUILD index 064409f0f..883a64b4a 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -43,6 +43,7 @@ cc_library( deps = [ ":cel_value_internal", "//base:memory_manager", + "//eval/public/structs:legacy_type_info_apis", "//extensions/protobuf:memory_manager", "//internal:casts", "//internal:status_macros", diff --git a/eval/public/cel_value.cc b/eval/public/cel_value.cc index d84993e00..5b12a7362 100644 --- a/eval/public/cel_value.cc +++ b/eval/public/cel_value.cc @@ -10,6 +10,8 @@ #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "base/memory_manager.h" +#include "eval/public/cel_value_internal.h" +#include "eval/public/structs/legacy_type_info_apis.h" #include "extensions/protobuf/memory_manager.h" namespace google::api::expr::runtime { @@ -71,8 +73,10 @@ struct DebugStringVisitor { return absl::StrFormat("%s", arg.value()); } - std::string operator()(const google::protobuf::Message* arg) { - return arg == nullptr ? "NULL" : arg->ShortDebugString(); + std::string operator()(const internal::MessageWrapper& arg) { + return arg.message_ptr() == nullptr + ? "NULL" + : arg.legacy_type_info()->DebugString(arg); } std::string operator()(absl::Duration arg) { @@ -199,13 +203,15 @@ CelValue CelValue::ObtainCelType() const { case Type::kBytes: return CreateCelType(CelTypeHolder(kBytesTypeName)); case Type::kMessage: { - auto msg = MessageOrDie(); - if (msg == nullptr) { + MessageWrapper wrapper; + CelValue::GetValue(&wrapper); + if (wrapper.message_ptr() == nullptr) { return CreateCelType(CelTypeHolder(kNullTypeName)); } // Descritptor::full_name() returns const reference, so using pointer // should be safe. - return CreateCelType(CelTypeHolder(msg->GetDescriptor()->full_name())); + return CreateCelType( + CelTypeHolder(wrapper.legacy_type_info()->GetTypename(wrapper))); } case Type::kDuration: return CreateCelType(CelTypeHolder(kDurationTypeName)); @@ -232,7 +238,7 @@ CelValue CelValue::ObtainCelType() const { // Returns debug string describing a value const std::string CelValue::DebugString() const { return absl::StrCat(CelValue::TypeName(type()), ": ", - Visit(DebugStringVisitor())); + InternalVisit(DebugStringVisitor())); } CelValue CreateErrorValue(cel::MemoryManager& manager, diff --git a/eval/public/cel_value_test.cc b/eval/public/cel_value_test.cc index 6f542e47b..683518563 100644 --- a/eval/public/cel_value_test.cc +++ b/eval/public/cel_value_test.cc @@ -390,6 +390,9 @@ TEST(CelValueTest, Message) { EXPECT_EQ(held.message_ptr(), static_cast(&message)); EXPECT_EQ(held.legacy_type_info(), TrivialTypeInfo::GetInstance()); + // TrivialTypeInfo doesn't provide any details about the specific message. + EXPECT_EQ(value.ObtainCelType().CelTypeOrDie().value(), "opaque type"); + EXPECT_EQ(value.DebugString(), "Message: opaque"); } TEST(CelValueTest, MessageLite) { @@ -404,6 +407,8 @@ TEST(CelValueTest, MessageLite) { EXPECT_FALSE(held.HasFullProto()); EXPECT_EQ(held.message_ptr(), &message); EXPECT_EQ(held.legacy_type_info(), TrivialTypeInfo::GetInstance()); + EXPECT_EQ(value.ObtainCelType().CelTypeOrDie().value(), "opaque type"); + EXPECT_EQ(value.DebugString(), "Message: opaque"); } TEST(CelValueTest, Size) { diff --git a/eval/public/structs/cel_proto_wrap_util_test.cc b/eval/public/structs/cel_proto_wrap_util_test.cc index 57c838746..1a9311a97 100644 --- a/eval/public/structs/cel_proto_wrap_util_test.cc +++ b/eval/public/structs/cel_proto_wrap_util_test.cc @@ -858,9 +858,11 @@ TEST_F(CelProtoWrapperTest, WrapFailureErrorToAny) { TEST_F(CelProtoWrapperTest, DebugString) { google::protobuf::Empty e; + // Note: the value factory is trivial so the debug string for a message-typed + // value is uninteresting. EXPECT_EQ(UnwrapMessageToValue(&e, &ProtobufValueFactoryImpl, arena()) .DebugString(), - "Message: "); + "Message: opaque"); ListValue list_value; list_value.add_values()->set_bool_value(true); diff --git a/eval/public/structs/proto_message_type_adapter.cc b/eval/public/structs/proto_message_type_adapter.cc index e8630eb63..8e32b9806 100644 --- a/eval/public/structs/proto_message_type_adapter.cc +++ b/eval/public/structs/proto_message_type_adapter.cc @@ -168,7 +168,7 @@ class DucktypedMessageAdapter : public LegacyTypeAccessApis, } auto* message = cel::internal::down_cast( wrapped_message.message_ptr()); - return message->DebugString(); + return message->ShortDebugString(); } const LegacyTypeAccessApis* GetAccessApis( diff --git a/eval/public/structs/proto_message_type_adapter_test.cc b/eval/public/structs/proto_message_type_adapter_test.cc index de3b90ad9..09d69dce4 100644 --- a/eval/public/structs/proto_message_type_adapter_test.cc +++ b/eval/public/structs/proto_message_type_adapter_test.cc @@ -312,7 +312,8 @@ TEST(GetGenericProtoTypeInfoInstance, DebugString) { test_message.set_string_value("abcd"); CelValue::MessageWrapper wrapped_message(&test_message, nullptr); - EXPECT_EQ(info_api.DebugString(wrapped_message), test_message.DebugString()); + EXPECT_EQ(info_api.DebugString(wrapped_message), + test_message.ShortDebugString()); } TEST(GetGenericProtoTypeInfoInstance, GetAccessApis) { From eb2dd10a8a33d61c7c4f75c98025eb1d90fa9bec Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Tue, 19 Apr 2022 18:24:32 +0000 Subject: [PATCH 079/111] Update select step to delegate to type_info instead of calling reflection APIs directly. PiperOrigin-RevId: 442863808 --- eval/eval/BUILD | 10 +- eval/eval/select_step.cc | 97 +++++--------- eval/eval/select_step_test.cc | 123 ++++++++++++++++++ .../structs/proto_message_type_adapter.cc | 4 +- 4 files changed, 163 insertions(+), 71 deletions(-) diff --git a/eval/eval/BUILD b/eval/eval/BUILD index c586118af..74d387b61 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -182,9 +182,8 @@ cc_library( ":expression_step_base", "//eval/public:cel_options", "//eval/public:cel_value", - "//eval/public/containers:field_access", - "//eval/public/containers:field_backed_list_impl", - "//eval/public/containers:field_backed_map_impl", + "//eval/public/structs:legacy_type_adapter", + "//eval/public/structs:legacy_type_info_apis", "//extensions/protobuf:memory_manager", "//internal:status_macros", "@com_google_absl//absl/memory", @@ -474,15 +473,20 @@ cc_test( ":test_type_registry", "//eval/public:activation", "//eval/public:cel_attribute", + "//eval/public:cel_value", "//eval/public:unknown_attribute_set", "//eval/public/containers:container_backed_map_impl", "//eval/public/structs:cel_proto_wrapper", + "//eval/public/structs:legacy_type_adapter", + "//eval/public/structs:trivial_legacy_type_info", + "//eval/public/testing:matchers", "//eval/testutil:test_message_cc_proto", "//internal:status_macros", "//internal:testing", "//testutil:util", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], diff --git a/eval/eval/select_step.cc b/eval/eval/select_step.cc index 55a72e563..cbb5d2751 100644 --- a/eval/eval/select_step.cc +++ b/eval/eval/select_step.cc @@ -2,6 +2,7 @@ #include #include +#include #include "absl/memory/memory.h" #include "absl/status/status.h" @@ -12,21 +13,14 @@ #include "eval/eval/expression_step_base.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" -#include "eval/public/containers/field_access.h" -#include "eval/public/containers/field_backed_list_impl.h" -#include "eval/public/containers/field_backed_map_impl.h" -#include "extensions/protobuf/memory_manager.h" +#include "eval/public/structs/legacy_type_adapter.h" +#include "eval/public/structs/legacy_type_info_apis.h" #include "internal/status_macros.h" namespace google::api::expr::runtime { namespace { -using ::cel::extensions::ProtoMemoryManager; -using ::google::protobuf::Descriptor; -using ::google::protobuf::FieldDescriptor; -using ::google::protobuf::Reflection; - // Common error for cases where evaluation attempts to perform select operations // on an unsupported type. // @@ -55,7 +49,7 @@ class SelectStep : public ExpressionStepBase { absl::Status Evaluate(ExecutionFrame* frame) const override; private: - absl::Status CreateValueFromField(const google::protobuf::Message& msg, + absl::Status CreateValueFromField(const CelValue::MessageWrapper& msg, cel::MemoryManager& manager, CelValue* result) const; @@ -65,34 +59,18 @@ class SelectStep : public ExpressionStepBase { ProtoWrapperTypeOptions unboxing_option_; }; -absl::Status SelectStep::CreateValueFromField(const google::protobuf::Message& msg, - cel::MemoryManager& manager, - CelValue* result) const { - const Descriptor* desc = msg.GetDescriptor(); - const FieldDescriptor* field_desc = desc->FindFieldByName(field_); - - if (field_desc == nullptr) { - *result = CreateNoSuchFieldError(manager, field_); - return absl::OkStatus(); - } - - google::protobuf::Arena* arena = ProtoMemoryManager::CastToProtoArena(manager); - - if (field_desc->is_map()) { - CelMap* map = google::protobuf::Arena::Create(arena, &msg, - field_desc, arena); - *result = CelValue::CreateMap(map); - return absl::OkStatus(); - } - if (field_desc->is_repeated()) { - CelList* list = google::protobuf::Arena::Create( - arena, &msg, field_desc, arena); - *result = CelValue::CreateList(list); +absl::Status SelectStep::CreateValueFromField( + const CelValue::MessageWrapper& msg, cel::MemoryManager& manager, + CelValue* result) const { + const LegacyTypeAccessApis* accessor = + msg.legacy_type_info()->GetAccessApis(msg); + if (accessor == nullptr) { + *result = CreateNoSuchFieldError(manager); return absl::OkStatus(); } - - return CreateValueFromSingleField(&msg, field_desc, unboxing_option_, arena, - result); + CEL_ASSIGN_OR_RETURN( + *result, accessor->GetField(field_, msg, unboxing_option_, manager)); + return absl::OkStatus(); } absl::optional CheckForMarkedAttributes(const AttributeTrail& trail, @@ -122,33 +100,19 @@ absl::optional CheckForMarkedAttributes(const AttributeTrail& trail, return absl::nullopt; } -CelValue TestOnlySelect(const google::protobuf::Message& msg, const std::string& field, - cel::MemoryManager& manager) { - const Reflection* reflection = msg.GetReflection(); - const Descriptor* desc = msg.GetDescriptor(); - const FieldDescriptor* field_desc = desc->FindFieldByName(field); - - if (field_desc == nullptr) { - return CreateNoSuchFieldError(manager, field); +CelValue TestOnlySelect(const CelValue::MessageWrapper& msg, + const std::string& field, cel::MemoryManager& manager) { + const LegacyTypeAccessApis* accessor = + msg.legacy_type_info()->GetAccessApis(msg); + if (accessor == nullptr) { + return CreateNoSuchFieldError(manager); } - - if (field_desc->is_map()) { - // When the map field appears in a has(msg.map_field) expression, the map - // is considered 'present' when it is non-empty. Since maps are repeated - // fields they don't participate with standard proto presence testing since - // the repeated field is always at least empty. - - return CelValue::CreateBool(reflection->FieldSize(msg, field_desc) != 0); - } - - if (field_desc->is_repeated()) { - // When the list field appears in a has(msg.list_field) expression, the list - // is considered 'present' when it is non-empty. - return CelValue::CreateBool(reflection->FieldSize(msg, field_desc) != 0); - } - // Standard proto presence test for non-repeated fields. - return CelValue::CreateBool(reflection->HasField(msg, field_desc)); + absl::StatusOr result = accessor->HasField(field, msg); + if (!result.ok()) { + return CreateErrorValue(manager, std::move(result).status()); + } + return CelValue::CreateBool(*result); } CelValue TestOnlySelect(const CelMap& map, const std::string& field_name, @@ -235,9 +199,9 @@ absl::Status SelectStep::Evaluate(ExecutionFrame* frame) const { frame->value_stack().PopAndPush( TestOnlySelect(*arg.MapOrDie(), field_, frame->memory_manager())); return absl::OkStatus(); - } else if (arg.IsMessage()) { + } else if (CelValue::MessageWrapper message; arg.GetValue(&message)) { frame->value_stack().PopAndPush( - TestOnlySelect(*arg.MessageOrDie(), field_, frame->memory_manager())); + TestOnlySelect(message, field_, frame->memory_manager())); return absl::OkStatus(); } } @@ -246,11 +210,12 @@ absl::Status SelectStep::Evaluate(ExecutionFrame* frame) const { // Select steps can be applied to either maps or messages switch (arg.type()) { case CelValue::Type::kMessage: { - // not null. - const google::protobuf::Message* msg = arg.MessageOrDie(); + CelValue::MessageWrapper wrapper; + bool success = arg.GetValue(&wrapper); + ABSL_ASSERT(success); CEL_RETURN_IF_ERROR( - CreateValueFromField(*msg, frame->memory_manager(), &result)); + CreateValueFromField(wrapper, frame->memory_manager(), &result)); frame->value_stack().PopAndPush(result, result_trail); return absl::OkStatus(); diff --git a/eval/eval/select_step_test.cc b/eval/eval/select_step_test.cc index 5b1fab4ff..efe202cc8 100644 --- a/eval/eval/select_step_test.cc +++ b/eval/eval/select_step_test.cc @@ -8,12 +8,17 @@ #include "google/protobuf/descriptor.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "eval/eval/ident_step.h" #include "eval/eval/test_type_registry.h" #include "eval/public/activation.h" #include "eval/public/cel_attribute.h" +#include "eval/public/cel_value.h" #include "eval/public/containers/container_backed_map_impl.h" #include "eval/public/structs/cel_proto_wrapper.h" +#include "eval/public/structs/legacy_type_adapter.h" +#include "eval/public/structs/trivial_legacy_type_info.h" +#include "eval/public/testing/matchers.h" #include "eval/public/unknown_attribute_set.h" #include "eval/testutil/test_message.pb.h" #include "internal/status_macros.h" @@ -25,8 +30,10 @@ namespace google::api::expr::runtime { namespace { using ::google::api::expr::v1alpha1::Expr; +using testing::_; using testing::Eq; using testing::HasSubstr; +using testing::Return; using cel::internal::StatusIs; using testutil::EqualsProto; @@ -36,6 +43,30 @@ struct RunExpressionOptions { bool enable_wrapper_type_null_unboxing = false; }; +// Simple implementation LegacyTypeAccessApis / LegacyTypeInfoApis that allows +// mocking for getters/setters. +class MockAccessor : public LegacyTypeAccessApis, public LegacyTypeInfoApis { + public: + MOCK_METHOD(absl::StatusOr, HasField, + (absl::string_view field_name, + const CelValue::MessageWrapper& value), + (const override)); + MOCK_METHOD(absl::StatusOr, GetField, + (absl::string_view field_name, + const CelValue::MessageWrapper& instance, + ProtoWrapperTypeOptions unboxing_option, + cel::MemoryManager& memory_manager), + (const override)); + MOCK_METHOD((const std::string&), GetTypename, + (const CelValue::MessageWrapper& instance), (const override)); + MOCK_METHOD(std::string, DebugString, + (const CelValue::MessageWrapper& instance), (const override)); + const LegacyTypeAccessApis* GetAccessApis( + const CelValue::MessageWrapper& instance) const override { + return this; + } +}; + // Helper method. Creates simple pipeline containing Select step and runs it. absl::StatusOr RunExpression(const CelValue target, absl::string_view field, bool test, @@ -418,6 +449,98 @@ TEST_P(SelectStepTest, SimpleMessageTest) { EXPECT_THAT(*message2, EqualsProto(*result.MessageOrDie())); } +TEST_P(SelectStepTest, NullMessageAccessor) { + TestMessage message; + TestMessage* message2 = message.mutable_message_value(); + message2->set_int32_value(1); + message2->set_string_value("test"); + google::protobuf::Arena arena; + RunExpressionOptions options; + options.enable_unknowns = GetParam(); + CelValue value = CelValue::CreateMessageWrapper( + CelValue::MessageWrapper(&message, TrivialTypeInfo::GetInstance())); + + ASSERT_OK_AND_ASSIGN(CelValue result, + RunExpression(value, "message_value", + /*test=*/false, &arena, + /*unknown_path=*/"", options)); + + ASSERT_TRUE(result.IsError()); + EXPECT_THAT(*result.ErrorOrDie(), StatusIs(absl::StatusCode::kNotFound)); + + // same for has + ASSERT_OK_AND_ASSIGN(result, RunExpression(value, "message_value", + /*test=*/true, &arena, + /*unknown_path=*/"", options)); + + ASSERT_TRUE(result.IsError()); + EXPECT_THAT(*result.ErrorOrDie(), StatusIs(absl::StatusCode::kNotFound)); +} + +TEST_P(SelectStepTest, CustomAccessor) { + TestMessage message; + TestMessage* message2 = message.mutable_message_value(); + message2->set_int32_value(1); + message2->set_string_value("test"); + google::protobuf::Arena arena; + RunExpressionOptions options; + options.enable_unknowns = GetParam(); + testing::NiceMock accessor; + CelValue value = CelValue::CreateMessageWrapper( + CelValue::MessageWrapper(&message, &accessor)); + + ON_CALL(accessor, GetField(_, _, _, _)) + .WillByDefault(Return(CelValue::CreateInt64(2))); + ON_CALL(accessor, HasField(_, _)).WillByDefault(Return(false)); + + ASSERT_OK_AND_ASSIGN(CelValue result, + RunExpression(value, "message_value", + /*test=*/false, &arena, + /*unknown_path=*/"", options)); + + EXPECT_THAT(result, test::IsCelInt64(2)); + + // testonly select (has) + ASSERT_OK_AND_ASSIGN(result, RunExpression(value, "message_value", + /*test=*/true, &arena, + /*unknown_path=*/"", options)); + + EXPECT_THAT(result, test::IsCelBool(false)); +} + +TEST_P(SelectStepTest, CustomAccessorErrorHandling) { + TestMessage message; + TestMessage* message2 = message.mutable_message_value(); + message2->set_int32_value(1); + message2->set_string_value("test"); + google::protobuf::Arena arena; + RunExpressionOptions options; + options.enable_unknowns = GetParam(); + testing::NiceMock accessor; + CelValue value = CelValue::CreateMessageWrapper( + CelValue::MessageWrapper(&message, &accessor)); + + ON_CALL(accessor, GetField(_, _, _, _)) + .WillByDefault(Return(absl::InternalError("bad data"))); + ON_CALL(accessor, HasField(_, _)) + .WillByDefault(Return(absl::NotFoundError("not found"))); + + // For get field, implementation may return an error-type cel value or a + // status (e.g. broken assumption using a core type). + ASSERT_THAT(RunExpression(value, "message_value", + /*test=*/false, &arena, + /*unknown_path=*/"", options), + StatusIs(absl::StatusCode::kInternal)); + + // testonly select (has) errors are coerced to CelError. + ASSERT_OK_AND_ASSIGN(CelValue result, + RunExpression(value, "message_value", + /*test=*/true, &arena, + /*unknown_path=*/"", options)); + + EXPECT_THAT(result, test::IsCelError(StatusIs(absl::StatusCode::kNotFound))); +} + TEST_P(SelectStepTest, SimpleEnumTest) { TestMessage message; message.set_enum_value(TestMessage::TEST_ENUM_1); diff --git a/eval/public/structs/proto_message_type_adapter.cc b/eval/public/structs/proto_message_type_adapter.cc index 8e32b9806..a7ef932f9 100644 --- a/eval/public/structs/proto_message_type_adapter.cc +++ b/eval/public/structs/proto_message_type_adapter.cc @@ -176,7 +176,7 @@ class DucktypedMessageAdapter : public LegacyTypeAccessApis, return this; } - static DucktypedMessageAdapter& GetSingleton() { + static const DucktypedMessageAdapter& GetSingleton() { static cel::internal::NoDestructor instance; return *instance; } @@ -211,7 +211,7 @@ absl::StatusOr ProtoMessageTypeAdapter::NewInstance( return absl::InvalidArgumentError( absl::StrCat("Failed to create message ", descriptor_->name())); } - return CelValue::MessageWrapper(msg, &GetGenericProtoTypeInfoInstance()); + return internal::MessageWrapper(msg, &GetGenericProtoTypeInfoInstance()); } bool ProtoMessageTypeAdapter::DefinesField(absl::string_view field_name) const { From 52dd139aaf77c09eb463bce42cab0b886d3b87b2 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Tue, 19 Apr 2022 18:25:06 +0000 Subject: [PATCH 080/111] Add equality test to legacy type access apis. PiperOrigin-RevId: 442863972 --- eval/public/structs/legacy_type_adapter.h | 13 ++++++++ .../structs/legacy_type_adapter_test.cc | 30 +++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/eval/public/structs/legacy_type_adapter.h b/eval/public/structs/legacy_type_adapter.h index a5dfcfb6f..af06a72f1 100644 --- a/eval/public/structs/legacy_type_adapter.h +++ b/eval/public/structs/legacy_type_adapter.h @@ -79,6 +79,19 @@ class LegacyTypeAccessApis { absl::string_view field_name, const CelValue::MessageWrapper& instance, ProtoWrapperTypeOptions unboxing_option, cel::MemoryManager& memory_manager) const = 0; + + // Interface for equality operator. + // The interpreter will check that both instances report to be the same type, + // but implementations should confirm that both instances are actually of the + // same type. + // If the two instances are of different type, return false. Otherwise, + // return whether they are equal. + // To conform to the CEL spec, message equality should follow the behavior of + // MessageDifferencer::Equals. + virtual bool IsEqual(const CelValue::MessageWrapper& instance, + const CelValue::MessageWrapper& other_instance) const { + return false; + } }; // Type information about a legacy Struct type. diff --git a/eval/public/structs/legacy_type_adapter_test.cc b/eval/public/structs/legacy_type_adapter_test.cc index c51289e51..69b03db25 100644 --- a/eval/public/structs/legacy_type_adapter_test.cc +++ b/eval/public/structs/legacy_type_adapter_test.cc @@ -46,6 +46,23 @@ class TestMutationApiImpl : public LegacyTypeMutationApis { } }; +class TestAccessApiImpl : public LegacyTypeAccessApis { + public: + TestAccessApiImpl() {} + absl::StatusOr HasField( + absl::string_view field_name, + const CelValue::MessageWrapper& value) const override { + return absl::UnimplementedError("Not implemented"); + } + + absl::StatusOr GetField( + absl::string_view field_name, const CelValue::MessageWrapper& instance, + ProtoWrapperTypeOptions unboxing_option, + cel::MemoryManager& memory_manager) const override { + return absl::UnimplementedError("Not implemented"); + } +}; + TEST(LegacyTypeAdapterMutationApis, DefaultNoopAdapt) { TestMessage message; internal::MessageWrapper wrapper(&message, TrivialTypeInfo::GetInstance()); @@ -61,5 +78,18 @@ TEST(LegacyTypeAdapterMutationApis, DefaultNoopAdapt) { test::IsCelMessage(EqualsProto(TestMessage::default_instance()))); } +TEST(LegacyTypeAdapterAccessApis, DefaultAlwaysInequal) { + TestMessage message; + internal::MessageWrapper wrapper(&message, nullptr); + internal::MessageWrapper wrapper2(&message, nullptr); + + google::protobuf::Arena arena; + cel::extensions::ProtoMemoryManager manager(&arena); + + TestAccessApiImpl impl; + + EXPECT_FALSE(impl.IsEqual(wrapper, wrapper2)); +} + } // namespace } // namespace google::api::expr::runtime From c912af2f26e6ddc3186d68712293d2088aef43c1 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Tue, 19 Apr 2022 18:25:45 +0000 Subject: [PATCH 081/111] Add IsEqualTo Implementation to proto message type adapter. PiperOrigin-RevId: 442864180 --- eval/public/structs/legacy_type_adapter.h | 4 +- .../structs/legacy_type_adapter_test.cc | 2 +- .../structs/proto_message_type_adapter.cc | 40 ++++++++++ .../structs/proto_message_type_adapter.h | 3 + .../proto_message_type_adapter_test.cc | 75 +++++++++++++++++++ 5 files changed, 121 insertions(+), 3 deletions(-) diff --git a/eval/public/structs/legacy_type_adapter.h b/eval/public/structs/legacy_type_adapter.h index af06a72f1..5250f1b70 100644 --- a/eval/public/structs/legacy_type_adapter.h +++ b/eval/public/structs/legacy_type_adapter.h @@ -88,8 +88,8 @@ class LegacyTypeAccessApis { // return whether they are equal. // To conform to the CEL spec, message equality should follow the behavior of // MessageDifferencer::Equals. - virtual bool IsEqual(const CelValue::MessageWrapper& instance, - const CelValue::MessageWrapper& other_instance) const { + virtual bool IsEqualTo(const CelValue::MessageWrapper& instance, + const CelValue::MessageWrapper& other_instance) const { return false; } }; diff --git a/eval/public/structs/legacy_type_adapter_test.cc b/eval/public/structs/legacy_type_adapter_test.cc index 69b03db25..f7632e032 100644 --- a/eval/public/structs/legacy_type_adapter_test.cc +++ b/eval/public/structs/legacy_type_adapter_test.cc @@ -88,7 +88,7 @@ TEST(LegacyTypeAdapterAccessApis, DefaultAlwaysInequal) { TestAccessApiImpl impl; - EXPECT_FALSE(impl.IsEqual(wrapper, wrapper2)); + EXPECT_FALSE(impl.IsEqualTo(wrapper, wrapper2)); } } // namespace diff --git a/eval/public/structs/proto_message_type_adapter.cc b/eval/public/structs/proto_message_type_adapter.cc index a7ef932f9..1a089b235 100644 --- a/eval/public/structs/proto_message_type_adapter.cc +++ b/eval/public/structs/proto_message_type_adapter.cc @@ -18,6 +18,7 @@ #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" +#include "google/protobuf/util/message_differencer.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" @@ -60,6 +61,15 @@ inline absl::StatusOr UnwrapMessage( return cel::internal::down_cast(value.message_ptr()); } +bool ProtoEquals(const google::protobuf::Message& m1, const google::protobuf::Message& m2) { + // Equality behavior is undefined for message differencer if input messages + // have different descriptors. For CEL just return false. + if (m1.GetDescriptor() != m2.GetDescriptor()) { + return false; + } + return google::protobuf::util::MessageDifferencer::Equals(m1, m2); +} + // Shared implementation for HasField. // Handles list or map specific behavior before calling reflection helpers. absl::StatusOr HasFieldImpl(const google::protobuf::Message* message, @@ -148,6 +158,21 @@ class DucktypedMessageAdapter : public LegacyTypeAccessApis, unboxing_option, memory_manager); } + bool IsEqualTo( + const CelValue::MessageWrapper& instance, + const CelValue::MessageWrapper& other_instance) const override { + absl::StatusOr lhs = + UnwrapMessage(instance, "IsEqualTo"); + absl::StatusOr rhs = + UnwrapMessage(other_instance, "IsEqualTo"); + if (!lhs.ok() || !rhs.ok()) { + // Treat this as though the underlying types are different, just return + // false. + return false; + } + return ProtoEquals(**lhs, **rhs); + } + // Implement TypeInfo Apis const std::string& GetTypename( const internal::MessageWrapper& wrapped_message) const override { @@ -325,6 +350,21 @@ absl::StatusOr ProtoMessageTypeAdapter::AdaptFromWellKnownType( arena); } +bool ProtoMessageTypeAdapter::IsEqualTo( + const CelValue::MessageWrapper& instance, + const CelValue::MessageWrapper& other_instance) const { + absl::StatusOr lhs = + UnwrapMessage(instance, "IsEqualTo"); + absl::StatusOr rhs = + UnwrapMessage(other_instance, "IsEqualTo"); + if (!lhs.ok() || !rhs.ok()) { + // Treat this as though the underlying types are different, just return + // false. + return false; + } + return ProtoEquals(**lhs, **rhs); +} + const LegacyTypeInfoApis& GetGenericProtoTypeInfoInstance() { return DucktypedMessageAdapter::GetSingleton(); } diff --git a/eval/public/structs/proto_message_type_adapter.h b/eval/public/structs/proto_message_type_adapter.h index 99e22e89a..5282a6119 100644 --- a/eval/public/structs/proto_message_type_adapter.h +++ b/eval/public/structs/proto_message_type_adapter.h @@ -59,6 +59,9 @@ class ProtoMessageTypeAdapter : public LegacyTypeAccessApis, absl::string_view field_name, const CelValue::MessageWrapper& value) const override; + bool IsEqualTo(const CelValue::MessageWrapper& instance, + const CelValue::MessageWrapper& other_instance) const override; + private: // Helper for standardizing error messages for SetField operation. absl::Status ValidateSetFieldOp(bool assertion, absl::string_view field, diff --git a/eval/public/structs/proto_message_type_adapter_test.cc b/eval/public/structs/proto_message_type_adapter_test.cc index 09d69dce4..001ff82ec 100644 --- a/eval/public/structs/proto_message_type_adapter_test.cc +++ b/eval/public/structs/proto_message_type_adapter_test.cc @@ -38,6 +38,7 @@ namespace google::api::expr::runtime { namespace { using ::cel::extensions::ProtoMemoryManager; +using ::google::protobuf::Int64Value; using testing::_; using testing::EqualsProto; using testing::HasSubstr; @@ -293,6 +294,80 @@ TEST_P(ProtoMessageTypeAccessorTest, IsOkAndHolds(test::IsCelInt64(_))); } +TEST_P(ProtoMessageTypeAccessorTest, IsEqualTo) { + google::protobuf::Arena arena; + const LegacyTypeAccessApis& accessor = GetAccessApis(); + + ProtoMemoryManager manager(&arena); + + TestMessage example; + example.mutable_int64_wrapper_value()->set_value(10); + TestMessage example2; + example2.mutable_int64_wrapper_value()->set_value(10); + + internal::MessageWrapper value(&example, nullptr); + internal::MessageWrapper value2(&example2, nullptr); + + EXPECT_TRUE(accessor.IsEqualTo(value, value2)); + EXPECT_TRUE(accessor.IsEqualTo(value2, value)); +} + +TEST_P(ProtoMessageTypeAccessorTest, IsEqualToSameTypeInequal) { + google::protobuf::Arena arena; + const LegacyTypeAccessApis& accessor = GetAccessApis(); + + ProtoMemoryManager manager(&arena); + + TestMessage example; + example.mutable_int64_wrapper_value()->set_value(10); + TestMessage example2; + example2.mutable_int64_wrapper_value()->set_value(12); + + internal::MessageWrapper value(&example, nullptr); + internal::MessageWrapper value2(&example2, nullptr); + + EXPECT_FALSE(accessor.IsEqualTo(value, value2)); + EXPECT_FALSE(accessor.IsEqualTo(value2, value)); +} + +TEST_P(ProtoMessageTypeAccessorTest, IsEqualToDifferentTypeInequal) { + google::protobuf::Arena arena; + const LegacyTypeAccessApis& accessor = GetAccessApis(); + + ProtoMemoryManager manager(&arena); + + TestMessage example; + example.mutable_int64_wrapper_value()->set_value(10); + Int64Value example2; + example2.set_value(10); + + internal::MessageWrapper value(&example, nullptr); + internal::MessageWrapper value2(&example2, nullptr); + + EXPECT_FALSE(accessor.IsEqualTo(value, value2)); + EXPECT_FALSE(accessor.IsEqualTo(value2, value)); +} + +TEST_P(ProtoMessageTypeAccessorTest, IsEqualToNonMessageInequal) { + google::protobuf::Arena arena; + const LegacyTypeAccessApis& accessor = GetAccessApis(); + + ProtoMemoryManager manager(&arena); + + TestMessage example; + example.mutable_int64_wrapper_value()->set_value(10); + TestMessage example2; + example2.mutable_int64_wrapper_value()->set_value(10); + + internal::MessageWrapper value(&example, nullptr); + // Upcast to message lite to prevent unwrapping to message. + internal::MessageWrapper value2( + static_cast(&example2), nullptr); + + EXPECT_FALSE(accessor.IsEqualTo(value, value2)); + EXPECT_FALSE(accessor.IsEqualTo(value2, value)); +} + INSTANTIATE_TEST_SUITE_P(GenericAndSpecific, ProtoMessageTypeAccessorTest, testing::Bool()); From 1a5c9461b497a1b2786573b887fbcfa4f1979ba8 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Tue, 19 Apr 2022 18:49:15 +0000 Subject: [PATCH 082/111] Update C++ CEL interpreter == implementation to use the type defined == implementation. PiperOrigin-RevId: 442871028 --- eval/public/BUILD | 6 +++- eval/public/comparison_functions.cc | 31 +++++++++++------- eval/public/comparison_functions_test.cc | 40 ++++++++++++++++++++++++ 3 files changed, 65 insertions(+), 12 deletions(-) diff --git a/eval/public/BUILD b/eval/public/BUILD index 883a64b4a..f61db7fd5 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -267,8 +267,10 @@ cc_library( ":cel_number", ":cel_options", ":cel_value", + ":cel_value_internal", "//eval/eval:mutable_list_impl", - "//eval/public/containers:container_backed_list_impl", + "//eval/public/structs:legacy_type_adapter", + "//eval/public/structs:legacy_type_info_apis", "//internal:casts", "//internal:overflow", "//internal:status_macros", @@ -297,12 +299,14 @@ cc_test( ":cel_function_registry", ":cel_options", ":cel_value", + ":cel_value_internal", ":comparison_functions", ":set_util", "//eval/public/containers:container_backed_list_impl", "//eval/public/containers:container_backed_map_impl", "//eval/public/containers:field_backed_list_impl", "//eval/public/structs:cel_proto_wrapper", + "//eval/public/structs:trivial_legacy_type_info", "//eval/public/testing:matchers", "//eval/testutil:test_message_cc_proto", "//internal:status_macros", diff --git a/eval/public/comparison_functions.cc b/eval/public/comparison_functions.cc index cc4cd6faf..ff9705a66 100644 --- a/eval/public/comparison_functions.cc +++ b/eval/public/comparison_functions.cc @@ -21,8 +21,6 @@ #include #include -#include "google/protobuf/map_field.h" -#include "google/protobuf/util/message_differencer.h" #include "absl/status/status.h" #include "absl/strings/match.h" #include "absl/strings/numbers.h" @@ -38,7 +36,9 @@ #include "eval/public/cel_number.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" -#include "eval/public/containers/container_backed_list_impl.h" +#include "eval/public/cel_value_internal.h" +#include "eval/public/structs/legacy_type_adapter.h" +#include "eval/public/structs/legacy_type_info_apis.h" #include "internal/casts.h" #include "internal/overflow.h" #include "internal/status_macros.h" @@ -51,7 +51,6 @@ namespace google::api::expr::runtime { namespace { using ::google::protobuf::Arena; -using ::google::protobuf::util::MessageDifferencer; // Forward declaration of the functors for generic equality operator. // Equal only defined for same-typed values. @@ -295,13 +294,22 @@ absl::optional Inequal(const CelMap* t1, const CelMap* t2) { return absl::nullopt; } -bool MessageEqual(const google::protobuf::Message& m1, const google::protobuf::Message& m2) { - // Equality behavior is undefined for message differencer if input messages - // have different descriptors. For CEL just return false. - if (m1.GetDescriptor() != m2.GetDescriptor()) { +bool MessageEqual(const CelValue::MessageWrapper& m1, + const CelValue::MessageWrapper& m2) { + const LegacyTypeInfoApis* lhs_type_info = m1.legacy_type_info(); + const LegacyTypeInfoApis* rhs_type_info = m2.legacy_type_info(); + + if (lhs_type_info->GetTypename(m1) != rhs_type_info->GetTypename(m2)) { + return false; + } + + const LegacyTypeAccessApis* accessor = lhs_type_info->GetAccessApis(m1); + + if (accessor == nullptr) { return false; } - return MessageDifferencer::Equals(m1, m2); + + return accessor->IsEqualTo(m1, m2); } // Generic equality for CEL values of the same type. @@ -572,8 +580,9 @@ absl::optional CelValueEqualImpl(const CelValue& v1, const CelValue& v2) { if (v1.type() == v2.type()) { // Message equality is only defined if heterogeneous comparions are enabled // to preserve the legacy behavior for equality. - if (v1.type() == CelValue::Type::kMessage) { - return MessageEqual(*v1.MessageOrDie(), *v2.MessageOrDie()); + if (CelValue::MessageWrapper lhs, rhs; + v1.GetValue(&lhs) && v2.GetValue(&rhs)) { + return MessageEqual(lhs, rhs); } return HomogenousCelValueEqual(v1, v2); } diff --git a/eval/public/comparison_functions_test.cc b/eval/public/comparison_functions_test.cc index c37d73a10..e26c025e3 100644 --- a/eval/public/comparison_functions_test.cc +++ b/eval/public/comparison_functions_test.cc @@ -44,11 +44,13 @@ #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" +#include "eval/public/cel_value_internal.h" #include "eval/public/containers/container_backed_list_impl.h" #include "eval/public/containers/container_backed_map_impl.h" #include "eval/public/containers/field_backed_list_impl.h" #include "eval/public/set_util.h" #include "eval/public/structs/cel_proto_wrapper.h" +#include "eval/public/structs/trivial_legacy_type_info.h" #include "eval/public/testing/matchers.h" #include "eval/testutil/test_message.pb.h" // IWYU pragma: keep #include "internal/status_macros.h" @@ -397,6 +399,44 @@ TEST(CelValueEqualImplTest, NestedMaps) { Optional(false)); } +TEST(CelValueEqualImplTest, ProtoEqualityDifferingTypenameInequal) { + // If message wrappers report a different typename, treat as inequal without + // calling into the provided equal implementation. + google::protobuf::Arena arena; + TestMessage example; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( + int32_value: 1 + uint32_value: 2 + string_value: "test" + )", + &example)); + + CelValue lhs = CelProtoWrapper::CreateMessage(&example, &arena); + CelValue rhs = CelValue::CreateMessageWrapper( + internal::MessageWrapper(&example, TrivialTypeInfo::GetInstance())); + + EXPECT_THAT(CelValueEqualImpl(lhs, rhs), Optional(false)); +} + +TEST(CelValueEqualImplTest, ProtoEqualityNoAccessorInequal) { + // If message wrappers report no access apis, then treat as inequal. + google::protobuf::Arena arena; + TestMessage example; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( + int32_value: 1 + uint32_value: 2 + string_value: "test" + )", + &example)); + + CelValue lhs = CelValue::CreateMessageWrapper( + internal::MessageWrapper(&example, TrivialTypeInfo::GetInstance())); + CelValue rhs = CelValue::CreateMessageWrapper( + internal::MessageWrapper(&example, TrivialTypeInfo::GetInstance())); + + EXPECT_THAT(CelValueEqualImpl(lhs, rhs), Optional(false)); +} + TEST(CelValueEqualImplTest, ProtoEqualityAny) { google::protobuf::Arena arena; TestMessage packed_value; From c43bbc7b5698a3b8023392178a765617758c9383 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 20 Apr 2022 00:28:34 +0000 Subject: [PATCH 083/111] Introduce PortableCelExpressionBuilder. PiperOrigin-RevId: 442954156 --- eval/compiler/flat_expr_builder.h | 65 +++++---------- eval/compiler/flat_expr_builder_test.cc | 9 +-- eval/public/BUILD | 26 ++++++ eval/public/cel_expr_builder_factory.cc | 5 +- eval/public/cel_type_registry.cc | 8 ++ eval/public/cel_type_registry.h | 5 +- eval/public/cel_type_registry_test.cc | 20 +++++ .../portable_cel_expr_builder_factory.cc | 80 +++++++++++++++++++ .../portable_cel_expr_builder_factory.h | 39 +++++++++ .../portable_cel_expr_builder_factory_test.cc | 50 ++++++++++++ 10 files changed, 255 insertions(+), 52 deletions(-) create mode 100644 eval/public/portable_cel_expr_builder_factory.cc create mode 100644 eval/public/portable_cel_expr_builder_factory.h create mode 100644 eval/public/portable_cel_expr_builder_factory_test.cc diff --git a/eval/compiler/flat_expr_builder.h b/eval/compiler/flat_expr_builder.h index fc0c387f3..dee1cc189 100644 --- a/eval/compiler/flat_expr_builder.h +++ b/eval/compiler/flat_expr_builder.h @@ -19,7 +19,6 @@ #include "google/api/expr/v1alpha1/checked.pb.h" #include "google/api/expr/v1alpha1/syntax.pb.h" -#include "google/protobuf/descriptor.h" #include "absl/status/statusor.h" #include "eval/public/cel_expression.h" @@ -29,29 +28,10 @@ namespace google::api::expr::runtime { // Builds instances of CelExpressionFlatImpl. class FlatExprBuilder : public CelExpressionBuilder { public: - explicit FlatExprBuilder(const google::protobuf::DescriptorPool* descriptor_pool = - google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory* message_factory = - google::protobuf::MessageFactory::generated_factory()) - : CelExpressionBuilder(descriptor_pool), - enable_unknowns_(false), - enable_unknown_function_results_(false), - enable_missing_attribute_errors_(false), - shortcircuiting_(true), - constant_folding_(false), - constant_arena_(nullptr), - enable_comprehension_(true), - comprehension_max_iterations_(0), - fail_on_warnings_(true), - enable_qualified_type_identifiers_(false), - enable_comprehension_list_append_(false), - enable_comprehension_vulnerability_check_(false), - enable_null_coercion_(true), - enable_wrapper_type_null_unboxing_(false), - enable_heterogeneous_equality_(false), - enable_qualified_identifier_rewrites_(false), - descriptor_pool_(descriptor_pool), - message_factory_(message_factory) {} + FlatExprBuilder() : CelExpressionBuilder() {} + + explicit FlatExprBuilder(const google::protobuf::DescriptorPool* descriptor_pool) + : CelExpressionBuilder(descriptor_pool) {} // set_enable_unknowns controls support for unknowns in expressions created. void set_enable_unknowns(bool enabled) { enable_unknowns_ = enabled; } @@ -184,26 +164,23 @@ class FlatExprBuilder : public CelExpressionBuilder { std::vector* warnings) const; private: - bool enable_unknowns_; - bool enable_unknown_function_results_; - bool enable_missing_attribute_errors_; - bool shortcircuiting_; - - bool constant_folding_; - google::protobuf::Arena* constant_arena_; - bool enable_comprehension_; - int comprehension_max_iterations_; - bool fail_on_warnings_; - bool enable_qualified_type_identifiers_; - bool enable_comprehension_list_append_; - bool enable_comprehension_vulnerability_check_; - bool enable_null_coercion_; - bool enable_wrapper_type_null_unboxing_; - bool enable_heterogeneous_equality_; - bool enable_qualified_identifier_rewrites_; - - const google::protobuf::DescriptorPool* descriptor_pool_; - google::protobuf::MessageFactory* message_factory_; + bool enable_unknowns_ = false; + bool enable_unknown_function_results_ = false; + bool enable_missing_attribute_errors_ = false; + bool shortcircuiting_ = true; + + bool constant_folding_ = false; + google::protobuf::Arena* constant_arena_ = nullptr; + bool enable_comprehension_ = true; + int comprehension_max_iterations_ = 0; + bool fail_on_warnings_ = true; + bool enable_qualified_type_identifiers_ = false; + bool enable_comprehension_list_append_ = false; + bool enable_comprehension_vulnerability_check_ = false; + bool enable_null_coercion_ = true; + bool enable_wrapper_type_null_unboxing_ = false; + bool enable_heterogeneous_equality_ = false; + bool enable_qualified_identifier_rewrites_ = false; }; } // namespace google::api::expr::runtime diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index a30a98932..c2cbd4218 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -1808,8 +1808,7 @@ TEST(FlatExprBuilderTest, CustomDescriptorPoolForCreateStruct) { // This time, the message is unknown. We only have the proto as data, we did // not link the generated message, so it's not included in the generated pool. - FlatExprBuilder builder(google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory()); + FlatExprBuilder builder(google::protobuf::DescriptorPool::generated_pool()); builder.GetTypeRegistry()->RegisterTypeProvider( std::make_unique( google::protobuf::DescriptorPool::generated_pool(), @@ -1832,7 +1831,7 @@ TEST(FlatExprBuilderTest, CustomDescriptorPoolForCreateStruct) { // This time, the message is *known*. We are using a custom descriptor pool // that has been primed with the relevant message. - FlatExprBuilder builder2(&desc_pool, &message_factory); + FlatExprBuilder builder2(&desc_pool); builder2.GetTypeRegistry()->RegisterTypeProvider( std::make_unique(&desc_pool, &message_factory)); @@ -1874,7 +1873,7 @@ TEST(FlatExprBuilderTest, CustomDescriptorPoolForSelect) { // This time, the message is *known*. We are using a custom descriptor pool // that has been primed with the relevant message. - FlatExprBuilder builder(&desc_pool, &message_factory); + FlatExprBuilder builder(&desc_pool); ASSERT_OK_AND_ASSIGN(auto expression, builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info())); @@ -1924,7 +1923,7 @@ TEST_P(CustomDescriptorPoolTest, TestType) { ASSERT_OK(AddStandardMessageTypesToDescriptorPool(descriptor_pool)); google::protobuf::DynamicMessageFactory message_factory(&descriptor_pool); ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("m")); - FlatExprBuilder builder(&descriptor_pool, &message_factory); + FlatExprBuilder builder(&descriptor_pool); builder.GetTypeRegistry()->RegisterTypeProvider( std::make_unique(&descriptor_pool, &message_factory)); diff --git a/eval/public/BUILD b/eval/public/BUILD index f61db7fd5..4092363fe 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -923,6 +923,32 @@ cc_library( ], ) +cc_library( + name = "portable_cel_expr_builder_factory", + srcs = ["portable_cel_expr_builder_factory.cc"], + hdrs = ["portable_cel_expr_builder_factory.h"], + deps = [ + ":cel_expression", + ":cel_options", + "//eval/compiler:flat_expr_builder", + "//eval/public/structs:legacy_type_provider", + "@com_google_absl//absl/status", + ], +) + +cc_test( + name = "portable_cel_expr_builder_factory_test", + srcs = ["portable_cel_expr_builder_factory_test.cc"], + deps = [ + ":builtin_func_registrar", + ":portable_cel_expr_builder_factory", + "//eval/public/structs:cel_proto_descriptor_pool_builder", + "//eval/public/structs:protobuf_descriptor_type_provider", + "//internal:testing", + "@com_google_protobuf//:protobuf", + ], +) + cc_test( name = "cel_number_test", srcs = ["cel_number_test.cc"], diff --git a/eval/public/cel_expr_builder_factory.cc b/eval/public/cel_expr_builder_factory.cc index 1fb0f23a5..3c517ba14 100644 --- a/eval/public/cel_expr_builder_factory.cc +++ b/eval/public/cel_expr_builder_factory.cc @@ -45,11 +45,11 @@ std::unique_ptr CreateCelExpressionBuilder( GOOGLE_LOG(WARNING) << "Failed to validate standard message types: " << s; return nullptr; } - auto builder = - absl::make_unique(descriptor_pool, message_factory); + auto builder = absl::make_unique(descriptor_pool); builder->GetTypeRegistry()->RegisterTypeProvider( std::make_unique(descriptor_pool, message_factory)); + // LINT.IfChange builder->set_shortcircuiting(options.short_circuiting); builder->set_constant_folding(options.constant_folding, options.constant_arena); @@ -85,6 +85,7 @@ std::unique_ptr CreateCelExpressionBuilder( builder->set_enable_missing_attribute_errors( options.enable_missing_attribute_errors); + // LINT.ThenChange(//depot/google3/eval/public/portable_cel_expr_builder_factory.cc) return builder; } diff --git a/eval/public/cel_type_registry.cc b/eval/public/cel_type_registry.cc index 6bb7d335e..ccc3f5cad 100644 --- a/eval/public/cel_type_registry.cc +++ b/eval/public/cel_type_registry.cc @@ -62,6 +62,14 @@ void CelTypeRegistry::Register(const google::protobuf::EnumDescriptor* enum_desc enums_.insert(enum_descriptor); } +std::shared_ptr +CelTypeRegistry::GetFirstTypeProvider() const { + if (type_providers_.empty()) { + return nullptr; + } + return type_providers_[0]; +} + const google::protobuf::Descriptor* CelTypeRegistry::FindDescriptor( absl::string_view fully_qualified_type_name) const { // Public protobuf interface only accepts const string&. diff --git a/eval/public/cel_type_registry.h b/eval/public/cel_type_registry.h index 4e12c6440..95e0c2214 100644 --- a/eval/public/cel_type_registry.h +++ b/eval/public/cel_type_registry.h @@ -55,6 +55,9 @@ class CelTypeRegistry { type_providers_.push_back(std::move(provider)); } + // Get the first registered type provider. + std::shared_ptr GetFirstTypeProvider() const; + // Find a type adapter given a fully qualified type name. // Adapter provides a generic interface for the reflecion operations the // interpreter needs to provide. @@ -81,7 +84,7 @@ class CelTypeRegistry { // why a node_hash_set is used instead of another container type. absl::node_hash_set types_; absl::flat_hash_set enums_; - std::vector> type_providers_; + std::vector> type_providers_; }; } // namespace google::api::expr::runtime diff --git a/eval/public/cel_type_registry_test.cc b/eval/public/cel_type_registry_test.cc index 50b73e6fa..7e8475279 100644 --- a/eval/public/cel_type_registry_test.cc +++ b/eval/public/cel_type_registry_test.cc @@ -68,6 +68,26 @@ TEST(CelTypeRegistryTest, TestRegisterTypeName) { EXPECT_THAT(type->CelTypeOrDie().value(), Eq("custom_type")); } +TEST(CelTypeRegistryTest, TestGetFirstTypeProviderSuccess) { + CelTypeRegistry registry; + registry.RegisterTypeProvider(std::make_unique( + std::vector{"google.protobuf.Int64"})); + registry.RegisterTypeProvider(std::make_unique( + std::vector{"google.protobuf.Any"})); + auto type_provider = registry.GetFirstTypeProvider(); + ASSERT_NE(type_provider, nullptr); + ASSERT_TRUE( + type_provider->ProvideLegacyType("google.protobuf.Int64").has_value()); + ASSERT_FALSE( + type_provider->ProvideLegacyType("google.protobuf.Any").has_value()); +} + +TEST(CelTypeRegistryTest, TestGetFirstTypeProviderFailureOnEmpty) { + CelTypeRegistry registry; + auto type_provider = registry.GetFirstTypeProvider(); + ASSERT_EQ(type_provider, nullptr); +} + TEST(CelTypeRegistryTest, TestFindTypeAdapterFound) { CelTypeRegistry registry; registry.RegisterTypeProvider(std::make_unique( diff --git a/eval/public/portable_cel_expr_builder_factory.cc b/eval/public/portable_cel_expr_builder_factory.cc new file mode 100644 index 000000000..30320b48b --- /dev/null +++ b/eval/public/portable_cel_expr_builder_factory.cc @@ -0,0 +1,80 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "eval/public/portable_cel_expr_builder_factory.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "eval/compiler/flat_expr_builder.h" +#include "eval/public/cel_options.h" + +namespace google::api::expr::runtime { + +std::unique_ptr CreatePortableExprBuilder( + std::unique_ptr type_provider, + const InterpreterOptions& options) { + if (type_provider == nullptr) { + GOOGLE_LOG(ERROR) << "Cannot pass nullptr as type_provider to " + "CreateProtoLiteExprBuilder"; + return nullptr; + } + auto builder = absl::make_unique(); + builder->GetTypeRegistry()->RegisterTypeProvider(std::move(type_provider)); + // LINT.IfChange + builder->set_shortcircuiting(options.short_circuiting); + builder->set_constant_folding(options.constant_folding, + options.constant_arena); + builder->set_enable_comprehension(options.enable_comprehension); + builder->set_enable_comprehension_list_append( + options.enable_comprehension_list_append); + builder->set_comprehension_max_iterations( + options.comprehension_max_iterations); + builder->set_fail_on_warnings(options.fail_on_warnings); + builder->set_enable_qualified_type_identifiers( + options.enable_qualified_type_identifiers); + builder->set_enable_comprehension_vulnerability_check( + options.enable_comprehension_vulnerability_check); + builder->set_enable_null_coercion(options.enable_null_to_message_coercion); + builder->set_enable_wrapper_type_null_unboxing( + options.enable_empty_wrapper_null_unboxing); + builder->set_enable_heterogeneous_equality( + options.enable_heterogeneous_equality); + builder->set_enable_qualified_identifier_rewrites( + options.enable_qualified_identifier_rewrites); + + switch (options.unknown_processing) { + case UnknownProcessingOptions::kAttributeAndFunction: + builder->set_enable_unknown_function_results(true); + builder->set_enable_unknowns(true); + break; + case UnknownProcessingOptions::kAttributeOnly: + builder->set_enable_unknowns(true); + break; + case UnknownProcessingOptions::kDisabled: + break; + } + + builder->set_enable_missing_attribute_errors( + options.enable_missing_attribute_errors); + // LINT.ThenChange(//depot/google3/eval/public/cel_expr_builder_factory.cc) + + return builder; +} + +} // namespace google::api::expr::runtime diff --git a/eval/public/portable_cel_expr_builder_factory.h b/eval/public/portable_cel_expr_builder_factory.h new file mode 100644 index 000000000..84cd86d82 --- /dev/null +++ b/eval/public/portable_cel_expr_builder_factory.h @@ -0,0 +1,39 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_PORTABLE_CEL_EXPR_BUILDER_FACTORY_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_PORTABLE_CEL_EXPR_BUILDER_FACTORY_H_ + +#include "eval/public/cel_expression.h" +#include "eval/public/cel_options.h" +#include "eval/public/structs/legacy_type_provider.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +// Factory creates CelExpressionBuilder implementation for public use. +std::unique_ptr CreatePortableExprBuilder( + std::unique_ptr type_provider, + const InterpreterOptions& options = InterpreterOptions()); + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_PORTABLE_CEL_EXPR_BUILDER_FACTORY_H_ diff --git a/eval/public/portable_cel_expr_builder_factory_test.cc b/eval/public/portable_cel_expr_builder_factory_test.cc new file mode 100644 index 000000000..5382647f1 --- /dev/null +++ b/eval/public/portable_cel_expr_builder_factory_test.cc @@ -0,0 +1,50 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/portable_cel_expr_builder_factory.h" + +#include + +#include "google/protobuf/descriptor.h" +#include "google/protobuf/dynamic_message.h" +#include "eval/public/builtin_func_registrar.h" +#include "eval/public/structs/cel_proto_descriptor_pool_builder.h" +#include "eval/public/structs/protobuf_descriptor_type_provider.h" +#include "internal/testing.h" + +namespace google::api::expr::runtime { +namespace { + +TEST(PortableCelExprBuilderFactoryTest, CreateNullOnMissingTypeProvider) { + std::unique_ptr builder = + CreatePortableExprBuilder(nullptr); + ASSERT_EQ(builder, nullptr); +} + +TEST(PortableCelExprBuilderFactoryTest, CreateSuccess) { + google::protobuf::DescriptorPool descriptor_pool; + google::protobuf::Arena arena; + + // Setup descriptor pool and builder + ASSERT_OK(AddStandardMessageTypesToDescriptorPool(descriptor_pool)); + google::protobuf::DynamicMessageFactory message_factory(&descriptor_pool); + auto type_provider = std::make_unique( + &descriptor_pool, &message_factory); + std::unique_ptr builder = + CreatePortableExprBuilder(std::move(type_provider)); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); +} + +} // namespace +} // namespace google::api::expr::runtime From a12cac6cb7b79923170956c1f85c5cb5a468e1b9 Mon Sep 17 00:00:00 2001 From: jcking Date: Wed, 20 Apr 2022 21:37:03 +0000 Subject: [PATCH 084/111] Internal change PiperOrigin-RevId: 443196376 --- base/type.h | 8 +++--- base/type_test.cc | 6 ++--- base/value.cc | 14 +++++----- base/value.h | 21 ++++++--------- base/value_factory.h | 63 +++++++++++++++++++++++++++++++++++++++++--- base/value_test.cc | 24 ++++++++++------- 6 files changed, 98 insertions(+), 38 deletions(-) diff --git a/base/type.h b/base/type.h index 5a08b2706..2e9314278 100644 --- a/base/type.h +++ b/base/type.h @@ -67,6 +67,8 @@ class TimestampValue; class EnumValue; class StructValue; class ValueFactory; +class TypedEnumValueFactory; +class TypedStructValueFactory; namespace internal { template @@ -470,12 +472,12 @@ class EnumType : public Type { // Construct a new instance of EnumValue with a type of this. Called by // EnumValue::New. virtual absl::StatusOr> NewInstanceByName( - ValueFactory& value_factory, absl::string_view name) const = 0; + TypedEnumValueFactory& factory, absl::string_view name) const = 0; // Construct a new instance of EnumValue with a type of this. Called by // EnumValue::New. virtual absl::StatusOr> NewInstanceByNumber( - ValueFactory& value_factory, int64_t number) const = 0; + TypedEnumValueFactory& factory, int64_t number) const = 0; // Called by FindConstant. virtual absl::StatusOr FindConstantByName( @@ -573,7 +575,7 @@ class StructType : public Type { StructType() = default; virtual absl::StatusOr> NewInstance( - ValueFactory& value_factory) const = 0; + TypedStructValueFactory& factory) const = 0; // Called by FindField. virtual absl::StatusOr FindFieldByName( diff --git a/base/type_test.cc b/base/type_test.cc index 10f41caea..a1d2cc6b4 100644 --- a/base/type_test.cc +++ b/base/type_test.cc @@ -46,12 +46,12 @@ class TestEnumType final : public EnumType { protected: absl::StatusOr> NewInstanceByName( - ValueFactory& value_factory, absl::string_view name) const override { + TypedEnumValueFactory& factory, absl::string_view name) const override { return absl::UnimplementedError(""); } absl::StatusOr> NewInstanceByNumber( - ValueFactory& value_factory, int64_t number) const override { + TypedEnumValueFactory& factory, int64_t number) const override { return absl::UnimplementedError(""); } @@ -97,7 +97,7 @@ class TestStructType final : public StructType { protected: absl::StatusOr> NewInstance( - ValueFactory& value_factory) const override { + TypedStructValueFactory& factory) const override { return absl::UnimplementedError(""); } diff --git a/base/value.cc b/base/value.cc index c743a0772..ed9c7b017 100644 --- a/base/value.cc +++ b/base/value.cc @@ -769,16 +769,18 @@ void StringValue::HashValue(absl::HashState state) const { } struct EnumType::NewInstanceVisitor final { - const EnumType& enum_type; + const Persistent& enum_type; ValueFactory& value_factory; absl::StatusOr> operator()( absl::string_view name) const { - return enum_type.NewInstanceByName(value_factory, name); + TypedEnumValueFactory factory(value_factory, enum_type); + return enum_type->NewInstanceByName(factory, name); } absl::StatusOr> operator()(int64_t number) const { - return enum_type.NewInstanceByNumber(value_factory, number); + TypedEnumValueFactory factory(value_factory, enum_type); + return enum_type->NewInstanceByNumber(factory, number); } }; @@ -787,7 +789,7 @@ absl::StatusOr> EnumValue::New( EnumType::ConstantId id) { CEL_ASSIGN_OR_RETURN( auto enum_value, - absl::visit(EnumType::NewInstanceVisitor{*enum_type, value_factory}, + absl::visit(EnumType::NewInstanceVisitor{enum_type, value_factory}, id.data_)); if (!enum_value->type_) { // In case somebody is caching, we avoid setting the type_ if it has already @@ -849,8 +851,8 @@ struct StructValue::HasFieldVisitor final { absl::StatusOr> StructValue::New( const Persistent& struct_type, ValueFactory& value_factory) { - CEL_ASSIGN_OR_RETURN(auto struct_value, - struct_type->NewInstance(value_factory)); + TypedStructValueFactory factory(value_factory, struct_type); + CEL_ASSIGN_OR_RETURN(auto struct_value, struct_type->NewInstance(factory)); if (!struct_value->type_) { // In case somebody is caching, we avoid setting the type_ if it has already // been set, to avoid a race condition where one CPU sees a half written diff --git a/base/value.h b/base/value.h index d1285e322..2cb47a93d 100644 --- a/base/value.h +++ b/base/value.h @@ -56,7 +56,6 @@ class StructValue; class ListValue; class MapValue; class ValueFactory; -class TypedListValueFactory; namespace internal { template @@ -590,10 +589,7 @@ class EnumValue : public Value { const Persistent& enum_type, ValueFactory& value_factory, EnumType::ConstantId id); - Transient type() const final { - ABSL_ASSERT(type_); - return type_; - } + Transient type() const final { return type_; } Kind kind() const final { return Kind::kEnum; } @@ -602,7 +598,9 @@ class EnumValue : public Value { virtual absl::string_view name() const = 0; protected: - EnumValue() = default; + explicit EnumValue(const Persistent& type) : type_(type) { + ABSL_ASSERT(type_); + } private: friend internal::TypeInfo base_internal::GetEnumValueTypeId( @@ -626,7 +624,6 @@ class EnumValue : public Value { // Called by CEL_IMPLEMENT_ENUM_VALUE() and Is() to perform type checking. virtual internal::TypeInfo TypeId() const = 0; - // Set lazily, by EnumValue::New. Persistent type_; }; @@ -663,10 +660,7 @@ class StructValue : public Value { const Persistent& struct_type, ValueFactory& value_factory); - Transient type() const final { - ABSL_ASSERT(type_); - return type_; - } + Transient type() const final { return type_; } Kind kind() const final { return Kind::kStruct; } @@ -678,7 +672,9 @@ class StructValue : public Value { absl::StatusOr HasField(FieldId field) const; protected: - StructValue() = default; + explicit StructValue(const Persistent& type) : type_(type) { + ABSL_ASSERT(type_); + } virtual absl::Status SetFieldByName(absl::string_view name, const Persistent& value) = 0; @@ -725,7 +721,6 @@ class StructValue : public Value { // Called by CEL_IMPLEMENT_STRUCT_VALUE() and Is() to perform type checking. virtual internal::TypeInfo TypeId() const = 0; - // Set lazily, by StructValue::New. Persistent type_; }; diff --git a/base/value_factory.h b/base/value_factory.h index 0d1638f97..20829e2fb 100644 --- a/base/value_factory.h +++ b/base/value_factory.h @@ -140,16 +140,20 @@ class ValueFactory final { template EnableIfBaseOfT>> CreateEnumValue( + const Persistent& enum_type, Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { return base_internal::PersistentHandleFactory::template Make< - std::remove_const_t>(memory_manager(), std::forward(args)...); + std::remove_const_t>(memory_manager(), enum_type, + std::forward(args)...); } template EnableIfBaseOfT>> - CreateStructValue(Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { + CreateStructValue(const Persistent& struct_type, + Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { return base_internal::PersistentHandleFactory::template Make< - std::remove_const_t>(memory_manager(), std::forward(args)...); + std::remove_const_t>(memory_manager(), struct_type, + std::forward(args)...); } template @@ -194,6 +198,59 @@ class ValueFactory final { MemoryManager& memory_manager_; }; +// TypedEnumValueFactory creates EnumValue scoped to a specific EnumType. Used +// with EnumType::NewInstance. +class TypedEnumValueFactory final { + private: + template + using EnableIfBaseOfT = + std::enable_if_t>, V>; + + public: + TypedEnumValueFactory( + ValueFactory& value_factory ABSL_ATTRIBUTE_LIFETIME_BOUND, + const Persistent& enum_type ABSL_ATTRIBUTE_LIFETIME_BOUND) + : value_factory_(value_factory), enum_type_(enum_type) {} + + template + EnableIfBaseOfT>> CreateEnumValue( + Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { + return value_factory_.CreateEnumValue(enum_type_, + std::forward(args)...); + } + + private: + ValueFactory& value_factory_; + const Persistent& enum_type_; +}; + +// TypedStructValueFactory creates StructValue scoped to a specific StructType. +// Used with StructType::NewInstance. +class TypedStructValueFactory final { + private: + template + using EnableIfBaseOfT = + std::enable_if_t>, V>; + + public: + TypedStructValueFactory(ValueFactory& value_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND, + const Persistent& enum_type + ABSL_ATTRIBUTE_LIFETIME_BOUND) + : value_factory_(value_factory), struct_type_(enum_type) {} + + template + EnableIfBaseOfT>> + CreateStructValue(Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { + return value_factory_.CreateStructValue(struct_type_, + std::forward(args)...); + } + + private: + ValueFactory& value_factory_; + const Persistent& struct_type_; +}; + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_BASE_VALUE_FACTORY_H_ diff --git a/base/value_test.cc b/base/value_test.cc index e9f3a984e..0a90009d8 100644 --- a/base/value_test.cc +++ b/base/value_test.cc @@ -53,7 +53,9 @@ enum class TestEnum { class TestEnumValue final : public EnumValue { public: - explicit TestEnumValue(TestEnum test_enum) : test_enum_(test_enum) {} + explicit TestEnumValue(const Persistent& type, + TestEnum test_enum) + : EnumValue(type), test_enum_(test_enum) {} std::string DebugString() const override { return std::string(name()); } @@ -91,22 +93,22 @@ class TestEnumType final : public EnumType { protected: absl::StatusOr> NewInstanceByName( - ValueFactory& value_factory, absl::string_view name) const override { + TypedEnumValueFactory& factory, absl::string_view name) const override { if (name == "VALUE1") { - return value_factory.CreateEnumValue(TestEnum::kValue1); + return factory.CreateEnumValue(TestEnum::kValue1); } else if (name == "VALUE2") { - return value_factory.CreateEnumValue(TestEnum::kValue2); + return factory.CreateEnumValue(TestEnum::kValue2); } return absl::NotFoundError(""); } absl::StatusOr> NewInstanceByNumber( - ValueFactory& value_factory, int64_t number) const override { + TypedEnumValueFactory& factory, int64_t number) const override { switch (number) { case 1: - return value_factory.CreateEnumValue(TestEnum::kValue1); + return factory.CreateEnumValue(TestEnum::kValue1); case 2: - return value_factory.CreateEnumValue(TestEnum::kValue2); + return factory.CreateEnumValue(TestEnum::kValue2); default: return absl::NotFoundError(""); } @@ -149,7 +151,9 @@ H AbslHashValue(H state, const TestStruct& test_struct) { class TestStructValue final : public StructValue { public: - explicit TestStructValue(TestStruct value) : value_(std::move(value)) {} + explicit TestStructValue(const Persistent& type, + TestStruct value) + : StructValue(type), value_(std::move(value)) {} std::string DebugString() const override { return absl::StrCat("bool_field: ", value().bool_field, @@ -305,8 +309,8 @@ class TestStructType final : public StructType { protected: absl::StatusOr> NewInstance( - ValueFactory& value_factory) const override { - return value_factory.CreateStructValue(TestStruct{}); + TypedStructValueFactory& factory) const override { + return factory.CreateStructValue(TestStruct{}); } absl::StatusOr FindFieldByName(TypeManager& type_manager, From 64a3b9030705cb288c421e68d618ac6aa1fcd46b Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 20 Apr 2022 21:46:40 +0000 Subject: [PATCH 085/111] Make map lookup error actually say which key wasn't found. Currently the error just says "Key not found in map" twice. PiperOrigin-RevId: 443198721 --- eval/eval/container_access_step.cc | 5 ++--- eval/eval/container_access_step_test.cc | 5 +++++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/eval/eval/container_access_step.cc b/eval/eval/container_access_step.cc index cc0bdcb66..576508422 100644 --- a/eval/eval/container_access_step.cc +++ b/eval/eval/container_access_step.cc @@ -64,8 +64,7 @@ inline CelValue ContainerAccessStep::LookupInMap(const CelMap* cel_map, return *maybe_value; } } - return CreateNoSuchKeyError(frame->memory_manager(), - "Key not found in map"); + return CreateNoSuchKeyError(frame->memory_manager(), key.DebugString()); } } @@ -78,7 +77,7 @@ inline CelValue ContainerAccessStep::LookupInMap(const CelMap* cel_map, return maybe_value.value(); } - return CreateNoSuchKeyError(frame->memory_manager(), "Key not found in map"); + return CreateNoSuchKeyError(frame->memory_manager(), key.DebugString()); } inline CelValue ContainerAccessStep::LookupInList(const CelList* cel_list, diff --git a/eval/eval/container_access_step_test.cc b/eval/eval/container_access_step_test.cc index c6630d87b..f1aac2e61 100644 --- a/eval/eval/container_access_step_test.cc +++ b/eval/eval/container_access_step_test.cc @@ -36,6 +36,7 @@ using ::google::api::expr::v1alpha1::Expr; using ::google::api::expr::v1alpha1::SourceInfo; using ::google::protobuf::Struct; using testing::_; +using testing::AllOf; using testing::HasSubstr; using cel::internal::StatusIs; @@ -201,6 +202,10 @@ TEST_P(ContainerAccessStepUniformityTest, TestMapKeyAccessNotFound) { CelValue::CreateString(&kKey1), std::get<0>(param), std::get<1>(param)); ASSERT_TRUE(result.IsError()); + EXPECT_THAT(*result.ErrorOrDie(), + StatusIs(absl::StatusCode::kNotFound, + AllOf(HasSubstr("Key not found in map : "), + HasSubstr("testkey1")))); } TEST_F(ContainerAccessStepTest, TestInvalidReceiverCreateContainerAccessStep) { From 60b68e81fc6d383d8dac0e5502c054865fe7f01f Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 28 Apr 2022 18:08:31 +0000 Subject: [PATCH 086/111] Make use of reflection APIs for CelTypeRegistry loading core enums optional. PiperOrigin-RevId: 445199066 --- eval/compiler/resolver.cc | 27 +++---- eval/public/BUILD | 2 + eval/public/cel_type_registry.cc | 70 +++++++++++++++-- eval/public/cel_type_registry.h | 23 +++++- eval/public/cel_type_registry_test.cc | 109 ++++++++++++++++++++++---- 5 files changed, 193 insertions(+), 38 deletions(-) diff --git a/eval/compiler/resolver.cc b/eval/compiler/resolver.cc index 426df40c1..97ed5ee9f 100644 --- a/eval/compiler/resolver.cc +++ b/eval/compiler/resolver.cc @@ -7,6 +7,7 @@ #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "eval/public/cel_builtins.h" #include "eval/public/cel_value.h" @@ -39,25 +40,23 @@ Resolver::Resolver(absl::string_view container, } for (const auto& prefix : namespace_prefixes_) { - for (auto enum_desc : type_registry->Enums()) { - absl::string_view enum_name = enum_desc->full_name(); + for (auto iter = type_registry->enums_map().begin(); + iter != type_registry->enums_map().end(); ++iter) { + absl::string_view enum_name = iter->first; if (!absl::StartsWith(enum_name, prefix)) { continue; } auto remainder = absl::StripPrefix(enum_name, prefix); - for (int i = 0; i < enum_desc->value_count(); i++) { - auto value_desc = enum_desc->value(i); - if (value_desc) { - // "prefixes" container is ascending-ordered. As such, we will be - // assigning enum reference to the deepest available. - // E.g. if both a.b.c.Name and a.b.Name are available, and - // we try to reference "Name" with the scope of "a.b.c", - // it will be resolved to "a.b.c.Name". - auto key = absl::StrCat(remainder, !remainder.empty() ? "." : "", - value_desc->name()); - enum_value_map_[key] = CelValue::CreateInt64(value_desc->number()); - } + for (const auto& enumerator : iter->second) { + // "prefixes" container is ascending-ordered. As such, we will be + // assigning enum reference to the deepest available. + // E.g. if both a.b.c.Name and a.b.Name are available, and + // we try to reference "Name" with the scope of "a.b.c", + // it will be resolved to "a.b.c.Name". + auto key = absl::StrCat(remainder, !remainder.empty() ? "." : "", + enumerator.name); + enum_value_map_[key] = CelValue::CreateInt64(enumerator.number); } } } diff --git a/eval/public/BUILD b/eval/public/BUILD index 4092363fe..80e0c4bef 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -648,6 +648,8 @@ cc_library( deps = [ ":cel_value", "//eval/public/structs:legacy_type_provider", + "//internal:no_destructor", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:node_hash_set", "@com_google_absl//absl/status", diff --git a/eval/public/cel_type_registry.cc b/eval/public/cel_type_registry.cc index ccc3f5cad..e7a688ed3 100644 --- a/eval/public/cel_type_registry.cc +++ b/eval/public/cel_type_registry.cc @@ -1,6 +1,7 @@ #include "eval/public/cel_type_registry.h" #include +#include #include #include "google/protobuf/struct.pb.h" @@ -10,6 +11,7 @@ #include "absl/status/status.h" #include "absl/types/optional.h" #include "eval/public/cel_value.h" +#include "internal/no_destructor.h" namespace google::api::expr::runtime { @@ -32,12 +34,59 @@ const absl::node_hash_set& GetCoreTypes() { return *kCoreTypes; } -const absl::flat_hash_set GetCoreEnums() { - static const auto* const kCoreEnums = - new absl::flat_hash_set{ - // Register the NULL_VALUE enum. - google::protobuf::NullValue_descriptor(), - }; +using DescriptorSet = absl::flat_hash_set; +using EnumMap = + absl::flat_hash_map>; + +void AddEnumFromDescriptor(const google::protobuf::EnumDescriptor* desc, EnumMap& map) { + std::vector enumerators; + enumerators.reserve(desc->value_count()); + for (int i = 0; i < desc->value_count(); i++) { + enumerators.push_back({desc->value(i)->name(), desc->value(i)->number()}); + } + map.insert(std::pair(desc->full_name(), std::move(enumerators))); +} + +// Portable version. Add overloads for specfic core supported enums. +template +struct EnumAdderT { + template + void AddEnum(DescriptorSet&) {} + + template + void AddEnum(EnumMap&) {} + + template <> + void AddEnum(EnumMap& map) { + map["google.protobuf.NullValue"] = {{"NULL_VALUE", 0}}; + } +}; + +template +struct EnumAdderT, void>::type> { + template + void AddEnum(DescriptorSet& set) { + set.insert(google::protobuf::GetEnumDescriptor()); + } + + template + void AddEnum(EnumMap& map) { + const google::protobuf::EnumDescriptor* desc = google::protobuf::GetEnumDescriptor(); + AddEnumFromDescriptor(desc, map); + } +}; + +// Enable loading the linked descriptor if using the full proto runtime. +// Otherwise, only support explcitly defined enums. +using EnumAdder = EnumAdderT; + +const absl::flat_hash_set& GetCoreEnums() { + static cel::internal::NoDestructor kCoreEnums([]() { + absl::flat_hash_set instance; + EnumAdder().AddEnum(instance); + return instance; + }()); return *kCoreEnums; } @@ -46,12 +95,16 @@ const absl::flat_hash_set GetCoreEnums( CelTypeRegistry::CelTypeRegistry() : descriptor_pool_(google::protobuf::DescriptorPool::generated_pool()), types_(GetCoreTypes()), - enums_(GetCoreEnums()) {} + enums_(GetCoreEnums()) { + EnumAdder().AddEnum(enums_map_); +} CelTypeRegistry::CelTypeRegistry(const google::protobuf::DescriptorPool* descriptor_pool) : descriptor_pool_(descriptor_pool), types_(GetCoreTypes()), - enums_(GetCoreEnums()) {} + enums_(GetCoreEnums()) { + EnumAdder().AddEnum(enums_map_); +} void CelTypeRegistry::Register(std::string fully_qualified_type_name) { // Registers the fully qualified type name as a CEL type. @@ -60,6 +113,7 @@ void CelTypeRegistry::Register(std::string fully_qualified_type_name) { void CelTypeRegistry::Register(const google::protobuf::EnumDescriptor* enum_descriptor) { enums_.insert(enum_descriptor); + AddEnumFromDescriptor(enum_descriptor, enums_map_); } std::shared_ptr diff --git a/eval/public/cel_type_registry.h b/eval/public/cel_type_registry.h index 95e0c2214..b716ea448 100644 --- a/eval/public/cel_type_registry.h +++ b/eval/public/cel_type_registry.h @@ -2,9 +2,11 @@ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_TYPE_REGISTRY_H_ #include +#include #include #include "google/protobuf/descriptor.h" +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/node_hash_set.h" #include "absl/status/statusor.h" @@ -28,6 +30,12 @@ namespace google::api::expr::runtime { // pools. class CelTypeRegistry { public: + // Internal representation for enumerators. + struct Enumerator { + std::string name; + int64_t number; + }; + CelTypeRegistry(); explicit CelTypeRegistry(const google::protobuf::DescriptorPool* descriptor_pool); @@ -74,16 +82,27 @@ class CelTypeRegistry { return enums_; } + // Return the registered enums configured within the type registry in the + // internal format. + const absl::flat_hash_map>& enums_map() + const { + return enums_map_; + } + private: // Find a protobuf Descriptor given a fully qualified protobuf type name. const google::protobuf::Descriptor* FindDescriptor( absl::string_view fully_qualified_type_name) const; const google::protobuf::DescriptorPool* descriptor_pool_; // externally owned - // pointer-stability is required for the strings in the types set, which is - // why a node_hash_set is used instead of another container type. + + // node_hash_set provides pointer-stability, which is required for the + // strings backing CelType objects. absl::node_hash_set types_; + // Set of registered enums. absl::flat_hash_set enums_; + // Internal representation for enums. + absl::flat_hash_map> enums_map_; std::vector> type_providers_; }; diff --git a/eval/public/cel_type_registry_test.cc b/eval/public/cel_type_registry_test.cc index 7e8475279..afbce4301 100644 --- a/eval/public/cel_type_registry_test.cc +++ b/eval/public/cel_type_registry_test.cc @@ -2,9 +2,11 @@ #include #include +#include #include -#include "google/protobuf/any.pb.h" +#include "google/protobuf/struct.pb.h" +#include "google/protobuf/message.h" #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "eval/public/cel_value.h" @@ -16,7 +18,13 @@ namespace google::api::expr::runtime { namespace { +using testing::AllOf; +using testing::Contains; using testing::Eq; +using testing::IsEmpty; +using testing::Key; +using testing::Pair; +using testing::UnorderedElementsAre; class TestTypeProvider : public LegacyTypeProvider { public: @@ -39,18 +47,87 @@ class TestTypeProvider : public LegacyTypeProvider { std::vector types_; }; -TEST(CelTypeRegistryTest, TestRegisterEnumDescriptor) { - CelTypeRegistry registry; - registry.Register(TestMessage::TestEnum_descriptor()); +MATCHER_P(MatchesEnumDescriptor, desc, "") { + const std::vector& enumerators = arg; + + if (enumerators.size() != desc->value_count()) { + return false; + } + + for (int i = 0; i < desc->value_count(); i++) { + const auto* value_desc = desc->value(i); + const auto& enumerator = enumerators[i]; + + if (value_desc->name() != enumerator.name) { + return false; + } + if (value_desc->number() != enumerator.number) { + return false; + } + } + return true; +} + +MATCHER_P2(EqualsEnumerator, name, number, "") { + const CelTypeRegistry::Enumerator& enumerator = arg; + return enumerator.name == name && enumerator.number == number; +} + +// Portable build version. +// Full template specification. Default in case of substitution failure below. +template +struct RegisterEnumDescriptorTestT { + void Test() { + // Portable version doesn't support registering at this time. + CelTypeRegistry registry; + + EXPECT_THAT(registry.Enums(), IsEmpty()); + } +}; - absl::flat_hash_set enum_set; - for (auto enum_desc : registry.Enums()) { - enum_set.insert(enum_desc->full_name()); +// Full proto runtime version. +template +struct RegisterEnumDescriptorTestT< + T, typename std::enable_if>::type> { + void Test() { + CelTypeRegistry registry; + registry.Register(google::protobuf::GetEnumDescriptor()); + + absl::flat_hash_set enum_set; + for (auto enum_desc : registry.Enums()) { + enum_set.insert(enum_desc->full_name()); + } + absl::flat_hash_set expected_set{ + "google.protobuf.NullValue", + "google.api.expr.runtime.TestMessage.TestEnum"}; + EXPECT_THAT(enum_set, Eq(expected_set)); + + EXPECT_THAT( + registry.enums_map(), + AllOf( + Contains(Pair( + "google.protobuf.NullValue", + MatchesEnumDescriptor( + google::protobuf::GetEnumDescriptor()))), + Contains(Pair( + "google.api.expr.runtime.TestMessage.TestEnum", + MatchesEnumDescriptor( + google::protobuf::GetEnumDescriptor()))))); } - absl::flat_hash_set expected_set; - expected_set.insert({"google.protobuf.NullValue"}); - expected_set.insert({"google.api.expr.runtime.TestMessage.TestEnum"}); - EXPECT_THAT(enum_set, Eq(expected_set)); +}; + +using RegisterEnumDescriptorTest = RegisterEnumDescriptorTestT; + +TEST(CelTypeRegistryTest, RegisterEnumDescriptor) { + RegisterEnumDescriptorTest().Test(); +} + +TEST(CelTypeRegistryTest, TestRegisterBuiltInEnum) { + CelTypeRegistry registry; + + ASSERT_THAT(registry.enums_map(), Contains(Key("google.protobuf.NullValue"))); + EXPECT_THAT(registry.enums_map().at("google.protobuf.NullValue"), + UnorderedElementsAre(EqualsEnumerator("NULL_VALUE", 0))); } TEST(CelTypeRegistryTest, TestRegisterTypeName) { @@ -123,9 +200,13 @@ TEST(CelTypeRegistryTest, TestFindTypeCoreTypeFound) { TEST(CelTypeRegistryTest, TestFindTypeProtobufTypeFound) { CelTypeRegistry registry; auto type = registry.FindType("google.protobuf.Any"); - ASSERT_TRUE(type.has_value()); - EXPECT_TRUE(type->IsCelType()); - EXPECT_THAT(type->CelTypeOrDie().value(), Eq("google.protobuf.Any")); + if constexpr (std::is_base_of_v) { + ASSERT_TRUE(type.has_value()); + EXPECT_TRUE(type->IsCelType()); + EXPECT_THAT(type->CelTypeOrDie().value(), Eq("google.protobuf.Any")); + } else { + EXPECT_FALSE(type.has_value()); + } } TEST(CelTypeRegistryTest, TestFindTypeNotRegisteredTypeNotFound) { From 175a10c68d6d1ec98a60528c8a1753de79d59426 Mon Sep 17 00:00:00 2001 From: tswadell Date: Thu, 28 Apr 2022 23:18:14 +0000 Subject: [PATCH 087/111] Ensure that list and map equality account for identity equality efficiently, e.g. listA == listA PiperOrigin-RevId: 445275254 --- eval/public/comparison_functions.cc | 77 ++++++++++++++++------------- 1 file changed, 42 insertions(+), 35 deletions(-) diff --git a/eval/public/comparison_functions.cc b/eval/public/comparison_functions.cc index ff9705a66..a68c4e221 100644 --- a/eval/public/comparison_functions.cc +++ b/eval/public/comparison_functions.cc @@ -18,6 +18,7 @@ #include #include #include +#include #include #include @@ -55,22 +56,22 @@ using ::google::protobuf::Arena; // Forward declaration of the functors for generic equality operator. // Equal only defined for same-typed values. struct HomogenousEqualProvider { - absl::optional operator()(const CelValue& v1, const CelValue& v2) const; + std::optional operator()(const CelValue& v1, const CelValue& v2) const; }; // Equal defined between compatible types. struct HeterogeneousEqualProvider { - absl::optional operator()(const CelValue& v1, const CelValue& v2) const; + std::optional operator()(const CelValue& v1, const CelValue& v2) const; }; // Comparison template functions template -absl::optional Inequal(Type t1, Type t2) { +std::optional Inequal(Type t1, Type t2) { return t1 != t2; } template -absl::optional Equal(Type t1, Type t2) { +std::optional Equal(Type t1, Type t2) { return t1 == t2; } @@ -96,12 +97,12 @@ bool GreaterThanOrEqual(Arena* arena, Type t1, Type t2) { // Duration comparison specializations template <> -absl::optional Inequal(absl::Duration t1, absl::Duration t2) { +std::optional Inequal(absl::Duration t1, absl::Duration t2) { return absl::operator!=(t1, t2); } template <> -absl::optional Equal(absl::Duration t1, absl::Duration t2) { +std::optional Equal(absl::Duration t1, absl::Duration t2) { return absl::operator==(t1, t2); } @@ -127,12 +128,12 @@ bool GreaterThanOrEqual(Arena*, absl::Duration t1, absl::Duration t2) { // Timestamp comparison specializations template <> -absl::optional Inequal(absl::Time t1, absl::Time t2) { +std::optional Inequal(absl::Time t1, absl::Time t2) { return absl::operator!=(t1, t2); } template <> -absl::optional Equal(absl::Time t1, absl::Time t2) { +std::optional Equal(absl::Time t1, absl::Time t2) { return absl::operator==(t1, t2); } @@ -191,7 +192,10 @@ bool MessageNullInequal(Arena* arena, const google::protobuf::Message* t1, // Equality for lists. Template parameter provides either heterogeneous or // homogenous equality for comparing members. template -absl::optional ListEqual(const CelList* t1, const CelList* t2) { +std::optional ListEqual(const CelList* t1, const CelList* t2) { + if (t1 == t2) { + return true; + } int index_size = t1->size(); if (t2->size() != index_size) { return false; @@ -200,7 +204,7 @@ absl::optional ListEqual(const CelList* t1, const CelList* t2) { for (int i = 0; i < index_size; i++) { CelValue e1 = (*t1)[i]; CelValue e2 = (*t2)[i]; - absl::optional eq = EqualsProvider()(e1, e2); + std::optional eq = EqualsProvider()(e1, e2); if (eq.has_value()) { if (!(*eq)) { return false; @@ -216,14 +220,14 @@ absl::optional ListEqual(const CelList* t1, const CelList* t2) { // Homogeneous CelList specific overload implementation for CEL ==. template <> -absl::optional Equal(const CelList* t1, const CelList* t2) { +std::optional Equal(const CelList* t1, const CelList* t2) { return ListEqual(t1, t2); } // Homogeneous CelList specific overload implementation for CEL !=. template <> -absl::optional Inequal(const CelList* t1, const CelList* t2) { - absl::optional eq = Equal(t1, t2); +std::optional Inequal(const CelList* t1, const CelList* t2) { + std::optional eq = Equal(t1, t2); if (eq.has_value()) { return !*eq; } @@ -233,7 +237,10 @@ absl::optional Inequal(const CelList* t1, const CelList* t2) { // Equality for maps. Template parameter provides either heterogeneous or // homogenous equality for comparing values. template -absl::optional MapEqual(const CelMap* t1, const CelMap* t2) { +std::optional MapEqual(const CelMap* t1, const CelMap* t2) { + if (t1 == t2) { + return true; + } if (t1->size() != t2->size()) { return false; } @@ -242,7 +249,7 @@ absl::optional MapEqual(const CelMap* t1, const CelMap* t2) { for (int i = 0; i < keys->size(); i++) { CelValue key = (*keys)[i]; CelValue v1 = (*t1)[key].value(); - absl::optional v2 = (*t2)[key]; + std::optional v2 = (*t2)[key]; if (!v2.has_value()) { auto number = GetNumberFromCelValue(key); if (!number.has_value()) { @@ -250,7 +257,7 @@ absl::optional MapEqual(const CelMap* t1, const CelMap* t2) { } if (!key.IsInt64() && number->LosslessConvertibleToInt()) { CelValue int_key = CelValue::CreateInt64(number->AsInt()); - absl::optional eq = EqualsProvider()(key, int_key); + std::optional eq = EqualsProvider()(key, int_key); if (eq.has_value() && *eq) { v2 = (*t2)[int_key]; } @@ -258,7 +265,7 @@ absl::optional MapEqual(const CelMap* t1, const CelMap* t2) { if (!key.IsUint64() && !v2.has_value() && number->LosslessConvertibleToUint()) { CelValue uint_key = CelValue::CreateUint64(number->AsUint()); - absl::optional eq = EqualsProvider()(key, uint_key); + std::optional eq = EqualsProvider()(key, uint_key); if (eq.has_value() && *eq) { v2 = (*t2)[uint_key]; } @@ -267,7 +274,7 @@ absl::optional MapEqual(const CelMap* t1, const CelMap* t2) { if (!v2.has_value()) { return false; } - absl::optional eq = EqualsProvider()(v1, *v2); + std::optional eq = EqualsProvider()(v1, *v2); if (!eq.has_value() || !*eq) { // Shortcircuit on value comparison errors and 'false' results. return eq; @@ -279,19 +286,19 @@ absl::optional MapEqual(const CelMap* t1, const CelMap* t2) { // Homogeneous CelMap specific overload implementation for CEL ==. template <> -absl::optional Equal(const CelMap* t1, const CelMap* t2) { +std::optional Equal(const CelMap* t1, const CelMap* t2) { return MapEqual(t1, t2); } // Homogeneous CelMap specific overload implementation for CEL !=. template <> -absl::optional Inequal(const CelMap* t1, const CelMap* t2) { - absl::optional eq = Equal(t1, t2); +std::optional Inequal(const CelMap* t1, const CelMap* t2) { + std::optional eq = Equal(t1, t2); if (eq.has_value()) { // Propagate comparison errors. return !*eq; } - return absl::nullopt; + return std::nullopt; } bool MessageEqual(const CelValue::MessageWrapper& m1, @@ -315,10 +322,10 @@ bool MessageEqual(const CelValue::MessageWrapper& m1, // Generic equality for CEL values of the same type. // EqualityProvider is used for equality among members of container types. template -absl::optional HomogenousCelValueEqual(const CelValue& t1, - const CelValue& t2) { +std::optional HomogenousCelValueEqual(const CelValue& t1, + const CelValue& t2) { if (t1.type() != t2.type()) { - return absl::nullopt; + return std::nullopt; } switch (t1.type()) { case CelValue::Type::kNullType: @@ -350,13 +357,13 @@ absl::optional HomogenousCelValueEqual(const CelValue& t1, default: break; } - return absl::nullopt; + return std::nullopt; } template std::function WrapComparison(Op op) { return [op = std::move(op)](Arena* arena, Type lhs, Type rhs) -> CelValue { - absl::optional result = op(lhs, rhs); + std::optional result = op(lhs, rhs); if (result.has_value()) { return CelValue::CreateBool(*result); @@ -484,7 +491,7 @@ absl::Status RegisterNullMessageEqualityFunctions( // Wrapper around CelValueEqualImpl to work with the FunctionAdapter template. // Implements CEL ==, CelValue GeneralizedEqual(Arena* arena, CelValue t1, CelValue t2) { - absl::optional result = CelValueEqualImpl(t1, t2); + std::optional result = CelValueEqualImpl(t1, t2); if (result.has_value()) { return CelValue::CreateBool(*result); } @@ -496,7 +503,7 @@ CelValue GeneralizedEqual(Arena* arena, CelValue t1, CelValue t2) { // Wrapper around CelValueEqualImpl to work with the FunctionAdapter template. // Implements CEL !=. CelValue GeneralizedInequal(Arena* arena, CelValue t1, CelValue t2) { - absl::optional result = CelValueEqualImpl(t1, t2); + std::optional result = CelValueEqualImpl(t1, t2); if (result.has_value()) { return CelValue::CreateBool(!*result); } @@ -561,12 +568,12 @@ absl::Status RegisterHeterogeneousComparisonFunctions( return absl::OkStatus(); } -absl::optional HomogenousEqualProvider::operator()( +std::optional HomogenousEqualProvider::operator()( const CelValue& v1, const CelValue& v2) const { return HomogenousCelValueEqual(v1, v2); } -absl::optional HeterogeneousEqualProvider::operator()( +std::optional HeterogeneousEqualProvider::operator()( const CelValue& v1, const CelValue& v2) const { return CelValueEqualImpl(v1, v2); } @@ -576,7 +583,7 @@ absl::optional HeterogeneousEqualProvider::operator()( // Equal operator is defined for all types at plan time. Runtime delegates to // the correct implementation for types or returns nullopt if the comparison // isn't defined. -absl::optional CelValueEqualImpl(const CelValue& v1, const CelValue& v2) { +std::optional CelValueEqualImpl(const CelValue& v1, const CelValue& v2) { if (v1.type() == v2.type()) { // Message equality is only defined if heterogeneous comparions are enabled // to preserve the legacy behavior for equality. @@ -587,8 +594,8 @@ absl::optional CelValueEqualImpl(const CelValue& v1, const CelValue& v2) { return HomogenousCelValueEqual(v1, v2); } - absl::optional lhs = GetNumberFromCelValue(v1); - absl::optional rhs = GetNumberFromCelValue(v2); + std::optional lhs = GetNumberFromCelValue(v1); + std::optional rhs = GetNumberFromCelValue(v2); if (rhs.has_value() && lhs.has_value()) { return *lhs == *rhs; @@ -598,7 +605,7 @@ absl::optional CelValueEqualImpl(const CelValue& v1, const CelValue& v2) { // map containing an Error. Return no matching overload to propagate an error // instead of a false result. if (v1.IsError() || v1.IsUnknownSet() || v2.IsError() || v2.IsUnknownSet()) { - return absl::nullopt; + return std::nullopt; } return false; From 6e957ec30a5627c8121ec78473bc583b40d6c768 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Mon, 2 May 2022 23:19:13 +0000 Subject: [PATCH 088/111] Add native types alternatives for the current proto type representations of the AST. PiperOrigin-RevId: 446056289 --- base/BUILD | 26 ++ base/ast.h | 983 +++++++++++++++++++++++++++++++++++++++++++++++ base/ast_test.cc | 179 +++++++++ 3 files changed, 1188 insertions(+) create mode 100644 base/ast.h create mode 100644 base/ast_test.cc diff --git a/base/BUILD b/base/BUILD index b8b6ff4e8..8e7c62c83 100644 --- a/base/BUILD +++ b/base/BUILD @@ -205,3 +205,29 @@ cc_test( "@com_google_absl//absl/time", ], ) + +cc_library( + name = "ast", + hdrs = [ + "ast.h", + ], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:variant", + ], +) + +cc_test( + name = "ast_test", + srcs = [ + "ast_test.cc", + ], + deps = [ + ":ast", + "//internal:testing", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:variant", + ], +) diff --git a/base/ast.h b/base/ast.h new file mode 100644 index 000000000..7eb0dfc25 --- /dev/null +++ b/base/ast.h @@ -0,0 +1,983 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_AST_H_ +#define THIRD_PARTY_CEL_CPP_BASE_AST_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/container/flat_hash_map.h" +#include "absl/time/time.h" +#include "absl/types/variant.h" +namespace cel::ast::internal { + +enum class NullValue { kNullValue = 0 }; + +// Represents a primitive literal. +// +// This is similar as the primitives supported in the well-known type +// `google.protobuf.Value`, but richer so it can represent CEL's full range of +// primitives. +// +// Lists and structs are not included as constants as these aggregate types may +// contain [Expr][] elements which require evaluation and are thus not constant. +// +// Examples of constants include: `"hello"`, `b'bytes'`, `1u`, `4.2`, `-2`, +// `true`, `null`. +// +// (-- +// TODO(issues/5): Extend or replace the constant with a canonical Value +// message that can hold any constant object representation supplied or +// produced at evaluation time. +// --) +using Constant = absl::variant; + +class Expr; + +// An identifier expression. e.g. `request`. +class Ident { + public: + explicit Ident(std::string name) : name_(std::move(name)) {} + + void set_name(std::string name) { name_ = std::move(name); } + + const std::string& name() const { return name_; } + + private: + // Required. Holds a single, unqualified identifier, possibly preceded by a + // '.'. + // + // Qualified names are represented by the [Expr.Select][] expression. + std::string name_; +}; + +// A field selection expression. e.g. `request.auth`. +class Select { + public: + Select(std::unique_ptr operand, std::string field, + bool test_only = false) + : operand_(std::move(operand)), + field_(std::move(field)), + test_only_(test_only) {} + + void set_operand(std::unique_ptr operand) { + operand_ = std::move(operand); + } + + void set_field(std::string field) { field_ = std::move(field); } + + void set_test_only(bool test_only) { test_only_ = test_only; } + + const Expr* operand() const { return operand_.get(); } + + Expr& mutable_operand() { + ABSL_ASSERT(operand_ != nullptr); + return *operand_; + } + + const std::string& field() const { return field_; } + + bool test_only() const { return test_only_; } + + private: + // Required. The target of the selection expression. + // + // For example, in the select expression `request.auth`, the `request` + // portion of the expression is the `operand`. + std::unique_ptr operand_; + // Required. The name of the field to select. + // + // For example, in the select expression `request.auth`, the `auth` portion + // of the expression would be the `field`. + std::string field_; + // Whether the select is to be interpreted as a field presence test. + // + // This results from the macro `has(request.auth)`. + bool test_only_; +}; + +// A call expression, including calls to predefined functions and operators. +// +// For example, `value == 10`, `size(map_value)`. +// (-- TODO(issues/5): Convert built-in globals to instance methods --) +class Call { + public: + Call(std::unique_ptr target, std::string function, + std::vector args) + : target_(std::move(target)), + function_(std::move(function)), + args_(std::move(args)) {} + + void set_target(std::unique_ptr target) { target_ = std::move(target); } + + void set_function(std::string function) { function_ = std::move(function); } + + void set_args(std::vector args) { args_ = std::move(args); } + + const Expr* target() const { return target_.get(); } + + Expr& mutable_target() { + ABSL_ASSERT(target_ != nullptr); + return *target_; + } + + const std::string& function() const { return function_; } + + const std::vector& args() const { return args_; } + + std::vector& mutable_args() { return args_; } + + private: + // The target of an method call-style expression. For example, `x` in + // `x.f()`. + std::unique_ptr target_; + // Required. The name of the function or method being called. + std::string function_; + // The arguments. + std::vector args_; +}; + +// A list creation expression. +// +// Lists may either be homogenous, e.g. `[1, 2, 3]`, or heterogeneous, e.g. +// `dyn([1, 'hello', 2.0])` +// (-- +// TODO(issues/5): Determine how to disable heterogeneous types as a feature +// of type-checking rather than through the language construct 'dyn'. +// --) +class CreateList { + public: + CreateList() {} + explicit CreateList(std::vector elements) + : elements_(std::move(elements)) {} + + void set_elements(std::vector elements) { + elements_ = std::move(elements); + } + + const std::vector& elements() const { return elements_; } + + std::vector& mutable_elements() { return elements_; } + + private: + // The elements part of the list. + std::vector elements_; +}; + +// A map or message creation expression. +// +// Maps are constructed as `{'key_name': 'value'}`. Message construction is +// similar, but prefixed with a type name and composed of field ids: +// `types.MyType{field_id: 'value'}`. +class CreateStruct { + public: + // Represents an entry. + class Entry { + public: + Entry(int64_t id, + absl::variant> key_kind, + std::unique_ptr value) + : id_(id), key_kind_(std::move(key_kind)), value_(std::move(value)) {} + + void set_id(int64_t id) { id_ = id; } + + void set_key_kind( + absl::variant> key_kind) { + key_kind_ = std::move(key_kind); + } + + void set_value(std::unique_ptr value) { value_ = std::move(value); } + + int64_t id() const { return id_; } + + const absl::variant>& key_kind() const { + return key_kind_; + } + + absl::variant>& mutable_key_kind() { + return key_kind_; + } + + const Expr* value() const { return value_.get(); } + + Expr& mutable_value() { + ABSL_ASSERT(value_ != nullptr); + return *value_; + } + + private: + // Required. An id assigned to this node by the parser which is unique + // in a given expression tree. This is used to associate type + // information and other attributes to the node. + int64_t id_; + // The `Entry` key kinds. + absl::variant> key_kind_; + // Required. The value assigned to the key. + std::unique_ptr value_; + }; + + CreateStruct() {} + CreateStruct(std::string message_name, std::vector entries) + : message_name_(std::move(message_name)), entries_(std::move(entries)) {} + + void set_message_name(std::string message_name) { + message_name_ = std::move(message_name); + } + + void set_entries(std::vector entries) { + entries_ = std::move(entries); + } + + const std::vector& entries() const { return entries_; } + + std::vector& mutable_entries() { return entries_; } + + private: + // The type name of the message to be created, empty when creating map + // literals. + std::string message_name_; + // The entries in the creation expression. + std::vector entries_; +}; + +// A comprehension expression applied to a list or map. +// +// Comprehensions are not part of the core syntax, but enabled with macros. +// A macro matches a specific call signature within a parsed AST and replaces +// the call with an alternate AST block. Macro expansion happens at parse +// time. +// +// The following macros are supported within CEL: +// +// Aggregate type macros may be applied to all elements in a list or all keys +// in a map: +// +// * `all`, `exists`, `exists_one` - test a predicate expression against +// the inputs and return `true` if the predicate is satisfied for all, +// any, or only one value `list.all(x, x < 10)`. +// * `filter` - test a predicate expression against the inputs and return +// the subset of elements which satisfy the predicate: +// `payments.filter(p, p > 1000)`. +// * `map` - apply an expression to all elements in the input and return the +// output aggregate type: `[1, 2, 3].map(i, i * i)`. +// +// The `has(m.x)` macro tests whether the property `x` is present in struct +// `m`. The semantics of this macro depend on the type of `m`. For proto2 +// messages `has(m.x)` is defined as 'defined, but not set`. For proto3, the +// macro tests whether the property is set to its default. For map and struct +// types, the macro tests whether the property `x` is defined on `m`. +// +// Comprehension evaluation can be best visualized as the following +// pseudocode: +// +// ``` +// let `accu_var` = `accu_init` +// for (let `iter_var` in `iter_range`) { +// if (!`loop_condition`) { +// break +// } +// `accu_var` = `loop_step` +// } +// return `result` +// ``` +// +// (-- +// TODO(issues/5): ensure comprehensions work equally well on maps and +// messages. +// --) +class Comprehension { + public: + Comprehension() {} + Comprehension(std::string iter_var, std::unique_ptr iter_range, + std::string accu_var, std::unique_ptr accu_init, + std::unique_ptr loop_condition, + std::unique_ptr loop_step, std::unique_ptr result) + : iter_var_(std::move(iter_var)), + iter_range_(std::move(iter_range)), + accu_var_(std::move(accu_var)), + accu_init_(std::move(accu_init)), + loop_condition_(std::move(loop_condition)), + loop_step_(std::move(loop_step)), + result_(std::move(result)) {} + + void set_iter_var(std::string iter_var) { iter_var_ = std::move(iter_var); } + + void set_iter_range(std::unique_ptr iter_range) { + iter_range_ = std::move(iter_range); + } + + void set_accu_var(std::string accu_var) { accu_var_ = std::move(accu_var); } + + void set_accu_init(std::unique_ptr accu_init) { + accu_init_ = std::move(accu_init); + } + + void set_loop_condition(std::unique_ptr loop_condition) { + loop_condition_ = std::move(loop_condition); + } + + void set_loop_step(std::unique_ptr loop_step) { + loop_step_ = std::move(loop_step); + } + + void set_result(std::unique_ptr result) { result_ = std::move(result); } + + const std::string& iter_var() const { return iter_var_; } + + const Expr* iter_range() const { return iter_range_.get(); } + + Expr& mutable_iter_range() { + ABSL_ASSERT(iter_range_ != nullptr); + return *iter_range_; + } + + const std::string& accu_var() const { return accu_var_; } + + const Expr* accu_init() const { return accu_init_.get(); } + + Expr& mutable_accu_init() { + ABSL_ASSERT(accu_init_ != nullptr); + return *accu_init_; + } + + const Expr* loop_condition() const { return loop_condition_.get(); } + + Expr& mutable_loop_condition() { + ABSL_ASSERT(loop_condition_ != nullptr); + return *loop_condition_; + } + + const Expr* loop_step() const { return loop_step_.get(); } + + Expr& mutable_loop_step() { + ABSL_ASSERT(loop_step_ != nullptr); + return *loop_step_; + } + + const Expr* result() const { return result_.get(); } + + Expr& mutable_result() { + ABSL_ASSERT(result_ != nullptr); + return *result_; + } + + private: + // The name of the iteration variable. + std::string iter_var_; + + // The range over which var iterates. + std::unique_ptr iter_range_; + + // The name of the variable used for accumulation of the result. + std::string accu_var_; + + // The initial value of the accumulator. + std::unique_ptr accu_init_; + + // An expression which can contain iter_var and accu_var. + // + // Returns false when the result has been computed and may be used as + // a hint to short-circuit the remainder of the comprehension. + std::unique_ptr loop_condition_; + + // An expression which can contain iter_var and accu_var. + // + // Computes the next value of accu_var. + std::unique_ptr loop_step_; + + // An expression which can contain accu_var. + // + // Computes the result. + std::unique_ptr result_; +}; + +using ExprKind = absl::variant; + +// Analogous to google::api::expr::v1alpha1::Expr +// An abstract representation of a common expression. +// +// Expressions are abstractly represented as a collection of identifiers, +// select statements, function calls, literals, and comprehensions. All +// operators with the exception of the '.' operator are modelled as function +// calls. This makes it easy to represent new operators into the existing AST. +// +// All references within expressions must resolve to a [Decl][] provided at +// type-check for an expression to be valid. A reference may either be a bare +// identifier `name` or a qualified identifier `google.api.name`. References +// may either refer to a value or a function declaration. +// +// For example, the expression `google.api.name.startsWith('expr')` references +// the declaration `google.api.name` within a [Expr.Select][] expression, and +// the function declaration `startsWith`. +// Move-only type. +class Expr { + public: + Expr() {} + Expr(int64_t id, ExprKind expr_kind) + : id_(id), expr_kind_(std::move(expr_kind)) {} + + Expr(Expr&& rhs) = default; + Expr& operator=(Expr&& rhs) = default; + + void set_id(int64_t id) { id_ = id; } + + void set_expr_kind(ExprKind expr_kind) { expr_kind_ = std::move(expr_kind); } + + int64_t id() const { return id_; } + + const ExprKind& expr_kind() const { return expr_kind_; } + + ExprKind& mutable_expr_kind() { return expr_kind_; } + + private: + // Required. An id assigned to this node by the parser which is unique in a + // given expression tree. This is used to associate type information and other + // attributes to a node in the parse tree. + int64_t id_ = 0; + // Required. Variants of expressions. + ExprKind expr_kind_; +}; + +// Source information collected at parse time. +class SourceInfo { + public: + SourceInfo() {} + SourceInfo(std::string syntax_version, std::string location, + std::vector line_offsets, + absl::flat_hash_map positions, + absl::flat_hash_map macro_calls) + : syntax_version_(std::move(syntax_version)), + location_(std::move(location)), + line_offsets_(std::move(line_offsets)), + positions_(std::move(positions)), + macro_calls_(std::move(macro_calls)) {} + + void set_syntax_version(std::string syntax_version) { + syntax_version_ = std::move(syntax_version); + } + + void set_location(std::string location) { location_ = std::move(location); } + + void set_line_offsets(std::vector line_offsets) { + line_offsets_ = std::move(line_offsets); + } + + void set_positions(absl::flat_hash_map positions) { + positions_ = std::move(positions); + } + + void set_macro_calls(absl::flat_hash_map macro_calls) { + macro_calls_ = std::move(macro_calls); + } + + const std::string& syntax_version() const { return syntax_version_; } + + const std::string& location() const { return location_; } + + const std::vector& line_offsets() const { return line_offsets_; } + + std::vector& mutable_line_offsets() { return line_offsets_; } + + const absl::flat_hash_map& positions() const { + return positions_; + } + + absl::flat_hash_map& mutable_positions() { + return positions_; + } + + const absl::flat_hash_map& macro_calls() const { + return macro_calls_; + } + + absl::flat_hash_map& mutable_macro_calls() { + return macro_calls_; + } + + private: + // The syntax version of the source, e.g. `cel1`. + std::string syntax_version_; + + // The location name. All position information attached to an expression is + // relative to this location. + // + // The location could be a file, UI element, or similar. For example, + // `acme/app/AnvilPolicy.cel`. + std::string location_; + + // Monotonically increasing list of code point offsets where newlines + // `\n` appear. + // + // The line number of a given position is the index `i` where for a given + // `id` the `line_offsets[i] < id_positions[id] < line_offsets[i+1]`. The + // column may be derivd from `id_positions[id] - line_offsets[i]`. + // + // TODO(issues/5): clarify this documentation + std::vector line_offsets_; + + // A map from the parse node id (e.g. `Expr.id`) to the code point offset + // within source. + absl::flat_hash_map positions_; + + // A map from the parse node id where a macro replacement was made to the + // call `Expr` that resulted in a macro expansion. + // + // For example, `has(value.field)` is a function call that is replaced by a + // `test_only` field selection in the AST. Likewise, the call + // `list.exists(e, e > 10)` translates to a comprehension expression. The key + // in the map corresponds to the expression id of the expanded macro, and the + // value is the call `Expr` that was replaced. + absl::flat_hash_map macro_calls_; +}; + +// Analogous to google::api::expr::v1alpha1::ParsedExpr +// An expression together with source information as returned by the parser. +// Move-only type. +class ParsedExpr { + public: + ParsedExpr() {} + ParsedExpr(Expr expr, SourceInfo source_info) + : expr_(std::move(expr)), source_info_(std::move(source_info)) {} + + ParsedExpr(ParsedExpr&& rhs) = default; + ParsedExpr& operator=(ParsedExpr&& rhs) = default; + + void set_expr(Expr expr) { expr_ = std::move(expr); } + + void set_source_info(SourceInfo source_info) { + source_info_ = std::move(source_info); + } + + const Expr& expr() const { return expr_; } + + Expr& mutable_expr() { return expr_; } + + const SourceInfo& source_info() const { return source_info_; } + + SourceInfo& mutable_source_info() { return source_info_; } + + private: + // The parsed expression. + Expr expr_; + // The source info derived from input that generated the parsed `expr`. + SourceInfo source_info_; +}; + +// CEL primitive types. +enum class PrimitiveType { + // Unspecified type. + kPrimitiveTypeUnspecified = 0, + // Boolean type. + kBool = 1, + // Int64 type. + // + // Proto-based integer values are widened to int64_t. + kInt64 = 2, + // Uint64 type. + // + // Proto-based unsigned integer values are widened to uint64_t. + kUint64 = 3, + // Double type. + // + // Proto-based float values are widened to double values. + kDouble = 4, + // String type. + kString = 5, + // Bytes type. + kBytes = 6, +}; + +// Well-known protobuf types treated with first-class support in CEL. +// +// TODO(issues/5): represent well-known via abstract types (or however) +// they will be named. +enum class WellKnownType { + // Unspecified type. + kWellKnownTypeUnspecified = 0, + // Well-known protobuf.Any type. + // + // Any types are a polymorphic message type. During type-checking they are + // treated like `DYN` types, but at runtime they are resolved to a specific + // message type specified at evaluation time. + kAny = 1, + // Well-known protobuf.Timestamp type, internally referenced as `timestamp`. + kTimestamp = 2, + // Well-known protobuf.Duration type, internally referenced as `duration`. + kDuration = 3, +}; + +class Type; + +// List type with typed elements, e.g. `list`. +class ListType { + explicit ListType(std::unique_ptr elem_type) + : elem_type_(std::move(elem_type)) {} + + void set_elem_type(std::unique_ptr elem_type) { + elem_type_ = std::move(elem_type); + } + + const Type* elem_type() const { return elem_type_.get(); } + + Type& mutable_elem_type() { + ABSL_ASSERT(elem_type_ != nullptr); + return *elem_type_; + } + + private: + std::unique_ptr elem_type_; +}; + +// Map type with parameterized key and value types, e.g. `map`. +class MapType { + public: + MapType(std::unique_ptr key_type, std::unique_ptr value_type) + : key_type_(std::move(key_type)), value_type_(std::move(value_type)) {} + + void set_key_type(std::unique_ptr key_type) { + key_type_ = std::move(key_type); + } + + void set_value_type(std::unique_ptr value_type) { + value_type_ = std::move(value_type); + } + + const Type* key_type() const { return key_type_.get(); } + + const Type* value_type() const { return value_type_.get(); } + + Type& mutable_key_type() { + ABSL_ASSERT(key_type_ != nullptr); + return *key_type_; + } + + Type& mutable_value_type() { + ABSL_ASSERT(value_type_ != nullptr); + return *value_type_; + } + + private: + // The type of the key. + std::unique_ptr key_type_; + + // The type of the value. + std::unique_ptr value_type_; +}; + +// Function type with result and arg types. +// +// (-- +// NOTE: function type represents a lambda-style argument to another function. +// Supported through macros, but not yet a first-class concept in CEL. +// --) +class FunctionType { + public: + FunctionType(std::unique_ptr result_type, std::vector arg_types) + : result_type_(std::move(result_type)), + arg_types_(std::move(arg_types)) {} + + void set_result_type(std::unique_ptr result_type) { + result_type_ = std::move(result_type); + } + + void set_arg_types(std::vector arg_types) { + arg_types_ = std::move(arg_types); + } + + const Type* result_type() const { return result_type_.get(); } + + Type& mutable_result_type() { + ABSL_ASSERT(result_type_.get() != nullptr); + return *result_type_; + } + + const std::vector& arg_types() const { return arg_types_; } + + std::vector& mutable_arg_types() { return arg_types_; } + + private: + // Result type of the function. + std::unique_ptr result_type_; + + // Argument types of the function. + std::vector arg_types_; +}; + +// Application defined abstract type. +// +// TODO(issues/5): decide on final naming for this. +class AbstractType { + AbstractType(std::string name, std::vector parameter_types) + : name_(std::move(name)), parameter_types_(std::move(parameter_types)) {} + + void set_name(std::string name) { name_ = std::move(name); } + + void set_parameter_types(std::vector parameter_types) { + parameter_types_ = std::move(parameter_types); + } + + const std::string& name() const { return name_; } + + const std::vector& parameter_types() const { return parameter_types_; } + + std::vector& mutable_parameter_types() { return parameter_types_; } + + private: + // The fully qualified name of this abstract type. + std::string name_; + + // Parameter types for this abstract type. + std::vector parameter_types_; +}; + +// Wrapper of a primitive type, e.g. `google.protobuf.Int64Value`. +class PrimitiveTypeWrapper { + public: + explicit PrimitiveTypeWrapper(PrimitiveType type) : type_(std::move(type)) {} + + void set_type(PrimitiveType type) { type_ = std::move(type); } + + const PrimitiveType& type() const { return type_; } + + PrimitiveType& type() { return type_; } + + private: + PrimitiveType type_; +}; + +// Protocol buffer message type. +// +// The `message_type` string specifies the qualified message type name. For +// example, `google.plus.Profile`. +class MessageType { + public: + explicit MessageType(std::string type) : type_(std::move(type)) {} + + void set_type(std::string type) { type_ = std::move(type); } + + const std::string& type() { return type_; } + + private: + std::string type_; +}; + +// Type param type. +// +// The `type_param` string specifies the type parameter name, e.g. `list` +// would be a `list_type` whose element type was a `type_param` type +// named `E`. +class ParamType { + public: + explicit ParamType(std::string type) : type_(std::move(type)) {} + + void set_type(std::string type) { type_ = std::move(type); } + + const std::string& type() { return type_; } + + private: + std::string type_; +}; + +// Error type. +// +// During type-checking if an expression is an error, its type is propagated +// as the `ERROR` type. This permits the type-checker to discover other +// errors present in the expression. +enum class ErrorType { kErrorTypeValue = 0 }; + +using DynamicType = absl::monostate; + +using TypeKind = + absl::variant, ErrorType, AbstractType>; + +// Analogous to google::api::expr::v1alpha1::Type. +// Represents a CEL type. +// +// TODO(issues/5): align with value.proto +class Type { + public: + explicit Type(TypeKind type_kind) : type_kind_(std::move(type_kind)) {} + + Type(Type&& rhs) = default; + Type& operator=(Type&& rhs) = default; + + void set_type_kind(TypeKind type_kind) { type_kind_ = std::move(type_kind); } + + const TypeKind& type_kind() const { return type_kind_; } + + TypeKind& mutable_type_kind() { return type_kind_; } + + private: + TypeKind type_kind_; +}; + +// Describes a resolved reference to a declaration. +class Reference { + public: + Reference(std::string name, std::vector overload_id, + Constant value) + : name_(std::move(name)), + overload_id_(std::move(overload_id)), + value_(std::move(value)) {} + + void set_name(std::string name) { name_ = std::move(name); } + + void set_overload_id(std::vector overload_id) { + overload_id_ = std::move(overload_id); + } + + void set_value(Constant value) { value_ = std::move(value); } + + const std::string& name() const { return name_; } + + const std::vector& overload_id() const { return overload_id_; } + + const Constant& value() const { return value_; } + + std::vector& mutable_overload_id() { return overload_id_; } + + Constant& mutable_value() { return value_; } + + private: + // The fully qualified name of the declaration. + std::string name_; + // For references to functions, this is a list of `Overload.overload_id` + // values which match according to typing rules. + // + // If the list has more than one element, overload resolution among the + // presented candidates must happen at runtime because of dynamic types. The + // type checker attempts to narrow down this list as much as possible. + // + // Empty if this is not a reference to a [Decl.FunctionDecl][]. + std::vector overload_id_; + // For references to constants, this may contain the value of the + // constant if known at compile time. + Constant value_; +}; + +// Analogous to google::api::expr::v1alpha1::CheckedExpr +// A CEL expression which has been successfully type checked. +// Move-only type. +class CheckedExpr { + public: + CheckedExpr() {} + CheckedExpr(absl::flat_hash_map reference_map, + absl::flat_hash_map type_map, + SourceInfo source_info, std::string expr_version, Expr expr) + : reference_map_(std::move(reference_map)), + type_map_(std::move(type_map)), + source_info_(std::move(source_info)), + expr_version_(std::move(expr_version)), + expr_(std::move(expr)) {} + + CheckedExpr(CheckedExpr&& rhs) = default; + CheckedExpr& operator=(CheckedExpr&& rhs) = default; + + void set_reference_map( + absl::flat_hash_map reference_map) { + reference_map_ = std::move(reference_map); + } + + void set_type_map(absl::flat_hash_map type_map) { + type_map_ = std::move(type_map); + } + + void set_source_info(SourceInfo source_info) { + source_info_ = std::move(source_info); + } + + void set_expr_version(std::string expr_version) { + expr_version_ = std::move(expr_version); + } + + void set_expr(Expr expr) { expr_ = std::move(expr); } + + const absl::flat_hash_map& reference_map() const { + return reference_map_; + } + + absl::flat_hash_map& mutable_reference_map() { + return reference_map_; + } + + const absl::flat_hash_map& type_map() const { + return type_map_; + } + + absl::flat_hash_map& mutable_type_map() { return type_map_; } + + const SourceInfo& source_info() const { return source_info_; } + + SourceInfo& mutable_source_info() { return source_info_; } + + const std::string& expr_version() const { return expr_version_; } + + const Expr& expr() const { return expr_; } + + Expr& mutable_expr() { return expr_; } + + private: + // A map from expression ids to resolved references. + // + // The following entries are in this table: + // + // - An Ident or Select expression is represented here if it resolves to a + // declaration. For instance, if `a.b.c` is represented by + // `select(select(id(a), b), c)`, and `a.b` resolves to a declaration, + // while `c` is a field selection, then the reference is attached to the + // nested select expression (but not to the id or or the outer select). + // In turn, if `a` resolves to a declaration and `b.c` are field selections, + // the reference is attached to the ident expression. + // - Every Call expression has an entry here, identifying the function being + // called. + // - Every CreateStruct expression for a message has an entry, identifying + // the message. + absl::flat_hash_map reference_map_; + // A map from expression ids to types. + // + // Every expression node which has a type different than DYN has a mapping + // here. If an expression has type DYN, it is omitted from this map to save + // space. + absl::flat_hash_map type_map_; + // The source info derived from input that generated the parsed `expr` and + // any optimizations made during the type-checking pass. + SourceInfo source_info_; + // The expr version indicates the major / minor version number of the `expr` + // representation. + // + // The most common reason for a version change will be to indicate to the CEL + // runtimes that transformations have been performed on the expr during static + // analysis. In some cases, this will save the runtime the work of applying + // the same or similar transformations prior to evaluation. + std::string expr_version_; + // The checked expression. Semantically equivalent to the parsed `expr`, but + // may have structural differences. + Expr expr_; +}; + +} // namespace cel::ast::internal + +#endif // THIRD_PARTY_CEL_CPP_BASE_AST_H_ diff --git a/base/ast_test.cc b/base/ast_test.cc new file mode 100644 index 000000000..987ef4b8a --- /dev/null +++ b/base/ast_test.cc @@ -0,0 +1,179 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/ast.h" + +#include +#include + +#include "absl/memory/memory.h" +#include "absl/types/variant.h" +#include "internal/testing.h" + +namespace cel { +namespace ast { +namespace internal { +namespace { +TEST(AstTest, ExprConstructionConstant) { + Expr expr(1, true); + ASSERT_TRUE(absl::holds_alternative(expr.expr_kind())); + const auto& constant = absl::get(expr.expr_kind()); + ASSERT_TRUE(absl::holds_alternative(constant)); + ASSERT_TRUE(absl::get(constant)); +} + +TEST(AstTest, ExprConstructionIdent) { + Expr expr(1, Ident("var")); + ASSERT_TRUE(absl::holds_alternative(expr.expr_kind())); + ASSERT_EQ(absl::get(expr.expr_kind()).name(), "var"); +} + +TEST(AstTest, ExprConstructionSelect) { + Expr expr(1, Select(std::make_unique(2, Ident("var")), "field")); + ASSERT_TRUE(absl::holds_alternative(expr.expr_kind()); + ASSERT_TRUE(absl::holds_alternative(select.operand()->expr_kind())); + ASSERT_EQ(absl::get(select.operand()->expr_kind()).name(), "var"); + ASSERT_EQ(select.field(), "field"); +} + +TEST(AstTest, ExprConstructionCall) { + Expr expr(1, Call(std::make_unique(2, Ident("var")), "function", {})); + ASSERT_TRUE(absl::holds_alternative(expr.expr_kind())); + const auto& call = absl::get(expr.expr_kind()); + ASSERT_TRUE(absl::holds_alternative(call.target()->expr_kind())); + ASSERT_EQ(absl::get(call.target()->expr_kind()).name(), "var"); + ASSERT_EQ(call.function(), "function"); + ASSERT_TRUE(call.args().empty()); +} + +TEST(AstTest, ExprConstructionCreateList) { + CreateList create_list; + create_list.mutable_elements().emplace_back(Expr(2, Ident("var1"))); + create_list.mutable_elements().emplace_back(Expr(3, Ident("var2"))); + create_list.mutable_elements().emplace_back(Expr(4, Ident("var3"))); + Expr expr(1, std::move(create_list)); + ASSERT_TRUE(absl::holds_alternative(expr.expr_kind())); + const auto& elements = absl::get(expr.expr_kind()).elements(); + ASSERT_EQ(absl::get(elements[0].expr_kind()).name(), "var1"); + ASSERT_EQ(absl::get(elements[1].expr_kind()).name(), "var2"); + ASSERT_EQ(absl::get(elements[2].expr_kind()).name(), "var3"); +} + +TEST(AstTest, ExprConstructionCreateStruct) { + CreateStruct create_struct; + create_struct.set_message_name("name"); + create_struct.mutable_entries().emplace_back(CreateStruct::Entry( + 1, "key1", std::make_unique(2, Ident("value1")))); + create_struct.mutable_entries().emplace_back(CreateStruct::Entry( + 3, "key2", std::make_unique(4, Ident("value2")))); + create_struct.mutable_entries().emplace_back( + CreateStruct::Entry(5, std::make_unique(6, Ident("key3")), + std::make_unique(6, Ident("value3")))); + Expr expr(1, std::move(create_struct)); + ASSERT_TRUE(absl::holds_alternative(expr.expr_kind())); + const auto& entries = absl::get(expr.expr_kind()).entries(); + ASSERT_EQ(absl::get(entries[0].key_kind()), "key1"); + ASSERT_EQ(absl::get(entries[0].value()->expr_kind()).name(), "value1"); + ASSERT_EQ(absl::get(entries[1].key_kind()), "key2"); + ASSERT_EQ(absl::get(entries[1].value()->expr_kind()).name(), "value2"); + ASSERT_EQ( + absl::get( + absl::get>(entries[2].key_kind())->expr_kind()) + .name(), + "key3"); + ASSERT_EQ(absl::get(entries[2].value()->expr_kind()).name(), "value3"); +} + +TEST(AstTest, ExprConstructionComprehension) { + Comprehension comprehension; + comprehension.set_iter_var("iter_var"); + comprehension.set_iter_range(std::make_unique(1, Ident("range"))); + comprehension.set_accu_var("accu_var"); + comprehension.set_accu_init(std::make_unique(2, Ident("init"))); + comprehension.set_loop_condition(std::make_unique(3, Ident("cond"))); + comprehension.set_loop_step(std::make_unique(4, Ident("step"))); + comprehension.set_result(std::make_unique(5, Ident("result"))); + Expr expr(6, std::move(comprehension)); + ASSERT_TRUE(absl::holds_alternative(expr.expr_kind())); + auto& created_expr = absl::get(expr.expr_kind()); + ASSERT_EQ(created_expr.iter_var(), "iter_var"); + ASSERT_EQ(absl::get(created_expr.iter_range()->expr_kind()).name(), + "range"); + ASSERT_EQ(created_expr.accu_var(), "accu_var"); + ASSERT_EQ(absl::get(created_expr.accu_init()->expr_kind()).name(), + "init"); + ASSERT_EQ(absl::get(created_expr.loop_condition()->expr_kind()).name(), + "cond"); + ASSERT_EQ(absl::get(created_expr.loop_step()->expr_kind()).name(), + "step"); + ASSERT_EQ(absl::get(created_expr.result()->expr_kind()).name(), + "result"); +} + +TEST(AstTest, ExprMoveTest) { + Expr expr(1, Ident("var")); + ASSERT_TRUE(absl::holds_alternative(expr.expr_kind())); + ASSERT_EQ(absl::get(expr.expr_kind()).name(), "var"); + Expr new_expr = std::move(expr); + ASSERT_TRUE(absl::holds_alternative(new_expr.expr_kind())); + ASSERT_EQ(absl::get(new_expr.expr_kind()).name(), "var"); +} + +TEST(AstTest, ParsedExpr) { + ParsedExpr parsed_expr; + parsed_expr.set_expr(Expr(1, Ident("name"))); + auto& source_info = parsed_expr.mutable_source_info(); + source_info.set_syntax_version("syntax_version"); + source_info.set_location("location"); + source_info.set_line_offsets({1, 2, 3}); + source_info.set_positions({{1, 1}, {2, 2}}); + ASSERT_TRUE(absl::holds_alternative(parsed_expr.expr().expr_kind())); + ASSERT_EQ(absl::get(parsed_expr.expr().expr_kind()).name(), "name"); + ASSERT_EQ(parsed_expr.source_info().syntax_version(), "syntax_version"); + ASSERT_EQ(parsed_expr.source_info().location(), "location"); + EXPECT_THAT(parsed_expr.source_info().line_offsets(), + testing::UnorderedElementsAre(1, 2, 3)); + EXPECT_THAT( + parsed_expr.source_info().positions(), + testing::UnorderedElementsAre(testing::Pair(1, 1), testing::Pair(2, 2))); +} + +TEST(AstTest, CheckedExpr) { + CheckedExpr checked_expr; + checked_expr.set_expr(Expr(1, Ident("name"))); + auto& source_info = checked_expr.mutable_source_info(); + source_info.set_syntax_version("syntax_version"); + source_info.set_location("location"); + source_info.set_line_offsets({1, 2, 3}); + source_info.set_positions({{1, 1}, {2, 2}}); + checked_expr.set_expr_version("expr_version"); + checked_expr.mutable_type_map().insert( + {1, Type(PrimitiveType(PrimitiveType::kBool))}); + ASSERT_TRUE(absl::holds_alternative(checked_expr.expr().expr_kind())); + ASSERT_EQ(absl::get(checked_expr.expr().expr_kind()).name(), "name"); + ASSERT_EQ(checked_expr.source_info().syntax_version(), "syntax_version"); + ASSERT_EQ(checked_expr.source_info().location(), "location"); + EXPECT_THAT(checked_expr.source_info().line_offsets(), + testing::UnorderedElementsAre(1, 2, 3)); + EXPECT_THAT( + checked_expr.source_info().positions(), + testing::UnorderedElementsAre(testing::Pair(1, 1), testing::Pair(2, 2))); + EXPECT_EQ(checked_expr.expr_version(), "expr_version"); +} + +} // namespace +} // namespace internal +} // namespace ast +} // namespace cel From 0f80c890df5a8d0282d9840860af012e0985397f Mon Sep 17 00:00:00 2001 From: jcking Date: Tue, 3 May 2022 14:55:27 +0000 Subject: [PATCH 089/111] Internal change. PiperOrigin-RevId: 446193337 --- base/BUILD | 4 + base/internal/memory_manager.pre.h | 2 + base/memory_manager.cc | 223 ++++++++++++++++++++++++++++- base/memory_manager.h | 6 + base/memory_manager_test.cc | 13 ++ 5 files changed, 244 insertions(+), 4 deletions(-) diff --git a/base/BUILD b/base/BUILD index 8e7c62c83..5a19418fd 100644 --- a/base/BUILD +++ b/base/BUILD @@ -54,9 +54,13 @@ cc_library( deps = [ "//base/internal:memory_manager", "//internal:no_destructor", + "@com_google_absl//absl/base", "@com_google_absl//absl/base:config", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:dynamic_annotations", "@com_google_absl//absl/numeric:bits", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:optional", ], ) diff --git a/base/internal/memory_manager.pre.h b/base/internal/memory_manager.pre.h index 28ac19541..741142b75 100644 --- a/base/internal/memory_manager.pre.h +++ b/base/internal/memory_manager.pre.h @@ -28,6 +28,8 @@ class MemoryManager; namespace base_internal { +size_t GetPageSize(); + class Resource; template diff --git a/base/memory_manager.cc b/base/memory_manager.cc index db2484646..0f0d40522 100644 --- a/base/memory_manager.cc +++ b/base/memory_manager.cc @@ -1,5 +1,3 @@ -#include "base/memory_manager.h" - // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,18 +12,44 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "base/memory_manager.h" + +#ifndef _WIN32 +#include +#include + +#include +#else +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN 1 +#endif +#ifndef NOMINMAX +#define NOMINMAX 1 +#endif +#include +#endif + +#include #include #include #include #include #include +#include #include +#include #include +#include +#include #include "absl/base/attributes.h" +#include "absl/base/call_once.h" #include "absl/base/config.h" +#include "absl/base/dynamic_annotations.h" #include "absl/base/macros.h" +#include "absl/base/thread_annotations.h" #include "absl/numeric/bits.h" +#include "absl/synchronization/mutex.h" #include "internal/no_destructor.h" namespace cel { @@ -81,16 +105,23 @@ struct ControlBlock final { } }; -size_t AlignUp(size_t size, size_t align) { +uintptr_t AlignUp(uintptr_t size, size_t align) { ABSL_ASSERT(size != 0); ABSL_ASSERT(absl::has_single_bit(align)); // Assert aligned to power of 2. #if ABSL_HAVE_BUILTIN(__builtin_align_up) return __builtin_align_up(size, align); #else - return (size + align - size_t{1}) & ~(align - size_t{1}); + return (size + static_cast(align) - uintptr_t{1}) & + ~(static_cast(align) - uintptr_t{1}); #endif } +template +T* AlignUp(T* pointer, size_t align) { + return reinterpret_cast( + AlignUp(reinterpret_cast(pointer), align)); +} + inline constexpr size_t kControlBlockSize = sizeof(ControlBlock); inline constexpr size_t kControlBlockAlign = alignof(ControlBlock); @@ -179,8 +210,188 @@ size_t AdjustAllocationSize(size_t size, size_t align) { return size + kControlBlockSize; } +struct ArenaBlock final { + // The base pointer of the virtual memory, always points to the start of a + // page. + uint8_t* begin; + // The end pointer of the virtual memory, it's 1 past the last byte of the + // page(s). + uint8_t* end; + // The pointer to the first byte that we have not yet allocated. + uint8_t* current; + + size_t remaining() const { return static_cast(end - current); } + + // Aligns the current pointer to `align`. + ArenaBlock& Align(size_t align) { + current = std::min(end, AlignUp(current, align)); + return *this; + } + + // Allocate `size` bytes from this block. This causes the current pointer to + // advance `size` bytes. + uint8_t* Allocate(size_t size) { + uint8_t* pointer = current; + current += size; + ABSL_ASSERT(current <= end); + return pointer; + } + + size_t capacity() const { return static_cast(end - begin); } +}; + +// Allocate a block of virtual memory from the kernel. `size` must be a multiple +// of `GetArenaPageSize()`. `hint` is a suggestion to the kernel of where we +// would like the virtual memory to be placed. +std::optional ArenaBlockAllocate(size_t size, + void* hint = nullptr) { + void* pointer; +#ifndef _WIN32 + pointer = mmap(hint, size, PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + if (ABSL_PREDICT_FALSE(pointer == MAP_FAILED)) { + return std::nullopt; + } +#else + pointer = VirtualAlloc(hint, size, MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE); + if (ABSL_PREDICT_FALSE(pointer == nullptr)) { + if (hint == nullptr) { + return absl::nullopt; + } + // Try again, without the hint. + pointer = + VirtualAlloc(nullptr, size, MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE); + if (pointer == nullptr) { + return absl::nullopt; + } + } +#endif + ANNOTATE_MEMORY_IS_UNINITIALIZED(pointer, size); + return ArenaBlock{static_cast(pointer), + static_cast(pointer) + size, + static_cast(pointer)}; +} + +// Free the block of virtual memory with the kernel. +void ArenaBlockFree(void* pointer, size_t size) { +#ifndef _WIN32 + if (ABSL_PREDICT_FALSE(munmap(pointer, size))) { + // If this happens its likely a bug and its probably corruption. Just bail. + std::perror("cel: failed to unmap pages from memory"); + std::fflush(stderr); + std::abort(); + } +#else + static_cast(size); + if (ABSL_PREDICT_FALSE(!VirtualFree(pointer, 0, MEM_RELEASE))) { + // TODO(issues/5): print the error + std::abort(); + } +#endif +} + +class DefaultArenaMemoryManager final : public ArenaMemoryManager { + public: + ~DefaultArenaMemoryManager() override { + absl::MutexLock lock(&mutex_); + for (const auto& owned : owned_) { + (*owned.second)(owned.first); + } + for (auto& block : blocks_) { + ArenaBlockFree(block.begin, block.capacity()); + } + } + + private: + AllocationResult Allocate(size_t size, size_t align) override { + auto page_size = base_internal::GetPageSize(); + if (align > page_size) { + // Just, no. We refuse anything that requests alignment over the system + // page size. + return AllocationResult{nullptr}; + } + absl::MutexLock lock(&mutex_); + bool bridge_gap = false; + if (ABSL_PREDICT_FALSE(blocks_.empty() || + blocks_.back().Align(align).remaining() == 0)) { + // Currently no allocated blocks or the allocation alignment is large + // enough that we cannot use any of the last block. Just allocate a block + // large enough. + auto maybe_block = ArenaBlockAllocate(AlignUp(size, page_size)); + if (!maybe_block.has_value()) { + return AllocationResult{nullptr}; + } + blocks_.push_back(std::move(maybe_block).value()); + } else { + // blocks_.back() was aligned above. + auto& last_block = blocks_.back(); + size_t remaining = last_block.remaining(); + if (ABSL_PREDICT_FALSE(remaining < size)) { + auto maybe_block = + ArenaBlockAllocate(AlignUp(size, page_size), last_block.end); + if (!maybe_block.has_value()) { + return AllocationResult{nullptr}; + } + bridge_gap = last_block.end == maybe_block.value().begin; + blocks_.push_back(std::move(maybe_block).value()); + } + } + if (ABSL_PREDICT_FALSE(bridge_gap)) { + // The last block did not have enough to fit the requested size, so we had + // to allocate a new block. However the alignment was low enough and the + // kernel gave us the page immediately after the last. Therefore we can + // span the allocation across both blocks. + auto& second_last_block = blocks_[blocks_.size() - 2]; + size_t remaining = second_last_block.remaining(); + void* pointer = second_last_block.Allocate(remaining); + blocks_.back().Allocate(size - remaining); + return AllocationResult{pointer}; + } + return AllocationResult{blocks_.back().Allocate(size)}; + } + + void OwnDestructor(void* pointer, void (*destruct)(void*)) override { + absl::MutexLock lock(&mutex_); + owned_.emplace_back(pointer, destruct); + } + + absl::Mutex mutex_; + std::vector blocks_ ABSL_GUARDED_BY(mutex_); + std::vector> owned_ ABSL_GUARDED_BY(mutex_); + // TODO(issues/5): we could use a priority queue to keep track of any + // unallocated space at the end blocks. +}; + } // namespace +namespace base_internal { + +// Returns the platforms page size. When requesting vitual memory from the +// kernel, typically the size requested must be a multiple of the page size. +size_t GetPageSize() { + static const size_t page_size = []() -> size_t { +#ifndef _WIN32 + auto value = sysconf(_SC_PAGESIZE); + if (ABSL_PREDICT_FALSE(value == -1)) { + // This should not happen, if it does bail. There is no other way to + // determine the page size. + std::perror("cel: failed to determine system page size"); + std::fflush(stderr); + std::abort(); + } + return static_cast(value); +#else + SYSTEM_INFO system_info; + SecureZeroMemory(&system_info, sizeof(system_info)); + GetSystemInfo(&system_info); + return static_cast(system_info.dwPageSize); +#endif + }(); + return page_size; +} + +} // namespace base_internal + MemoryManager& MemoryManager::Global() { static internal::NoDestructor instance; return *instance; @@ -267,4 +478,8 @@ void ArenaMemoryManager::Deallocate(void* pointer, size_t size, size_t align) { std::abort(); } +std::unique_ptr ArenaMemoryManager::Default() { + return std::make_unique(); +} + } // namespace cel diff --git a/base/memory_manager.h b/base/memory_manager.h index d53cbf074..e333fe18b 100644 --- a/base/memory_manager.h +++ b/base/memory_manager.h @@ -285,6 +285,12 @@ class ProtoMemoryManager; // Base class for all arena-based memory managers. class ArenaMemoryManager : public MemoryManager { + public: + // Returns the default implementation of an arena-based memory manager. In + // most cases it should be good enough, however you should not rely on its + // performance characteristics. + static std::unique_ptr Default(); + protected: ArenaMemoryManager() : ArenaMemoryManager(true) {} diff --git a/base/memory_manager_test.cc b/base/memory_manager_test.cc index 854c5c49b..fe20fb02b 100644 --- a/base/memory_manager_test.cc +++ b/base/memory_manager_test.cc @@ -49,5 +49,18 @@ TEST(ManagedMemory, Null) { EXPECT_EQ(nullptr, ManagedMemory()); } +struct LargeStruct { + char padding[4096 - alignof(char)]; +}; + +TEST(DefaultArenaMemoryManager, OddSizes) { + auto memory_manager = ArenaMemoryManager::Default(); + size_t page_size = base_internal::GetPageSize(); + for (size_t allocated = 0; allocated <= page_size; + allocated += sizeof(LargeStruct)) { + static_cast(memory_manager->New()); + } +} + } // namespace } // namespace cel From 3c8ef7a1c38c1a58d7062539b13ee4e2b4a9d9e3 Mon Sep 17 00:00:00 2001 From: jcking Date: Tue, 3 May 2022 16:35:37 +0000 Subject: [PATCH 090/111] Internal change. PiperOrigin-RevId: 446214501 --- base/BUILD | 2 + base/internal/BUILD | 10 + base/internal/memory_manager_testing.cc | 30 + base/internal/memory_manager_testing.h | 49 + base/type_test.cc | 211 ++-- base/value_test.cc | 1455 ++++++++++++----------- 6 files changed, 1000 insertions(+), 757 deletions(-) create mode 100644 base/internal/memory_manager_testing.cc create mode 100644 base/internal/memory_manager_testing.h diff --git a/base/BUILD b/base/BUILD index 5a19418fd..1eb2ef747 100644 --- a/base/BUILD +++ b/base/BUILD @@ -143,6 +143,7 @@ cc_test( ":memory_manager", ":type", ":value", + "//base/internal:memory_manager_testing", "//internal:testing", "@com_google_absl//absl/hash", "@com_google_absl//absl/hash:hash_testing", @@ -198,6 +199,7 @@ cc_test( ":memory_manager", ":type", ":value", + "//base/internal:memory_manager_testing", "//internal:strings", "//internal:testing", "//internal:time", diff --git a/base/internal/BUILD b/base/internal/BUILD index 33ebe7ea3..f264a4a5f 100644 --- a/base/internal/BUILD +++ b/base/internal/BUILD @@ -38,6 +38,16 @@ cc_library( ], ) +cc_library( + name = "memory_manager_testing", + testonly = True, + srcs = ["memory_manager_testing.cc"], + hdrs = ["memory_manager_testing.h"], + deps = [ + "//internal:testing", + ], +) + cc_library( name = "operators", hdrs = ["operators.h"], diff --git a/base/internal/memory_manager_testing.cc b/base/internal/memory_manager_testing.cc new file mode 100644 index 000000000..5b403e3c1 --- /dev/null +++ b/base/internal/memory_manager_testing.cc @@ -0,0 +1,30 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/internal/memory_manager_testing.h" + +#include + +namespace cel::base_internal { + +std::string MemoryManagerTestModeToString(MemoryManagerTestMode mode) { + switch (mode) { + case MemoryManagerTestMode::kGlobal: + return "Global"; + case MemoryManagerTestMode::kArena: + return "Arena"; + } +} + +} // namespace cel::base_internal diff --git a/base/internal/memory_manager_testing.h b/base/internal/memory_manager_testing.h new file mode 100644 index 000000000..e62e11853 --- /dev/null +++ b/base/internal/memory_manager_testing.h @@ -0,0 +1,49 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MEMORY_MANAGER_TESTING_H_ +#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MEMORY_MANAGER_TESTING_H_ + +#include +#include + +#include "internal/testing.h" + +namespace cel::base_internal { + +enum class MemoryManagerTestMode { + kGlobal = 0, + kArena, +}; + +std::string MemoryManagerTestModeToString(MemoryManagerTestMode mode); + +inline auto MemoryManagerTestModeAll() { + return testing::Values(MemoryManagerTestMode::kGlobal, + MemoryManagerTestMode::kArena); +} + +inline std::string MemoryManagerTestModeName( + const testing::TestParamInfo& info) { + return MemoryManagerTestModeToString(info.param); +} + +inline std::string MemoryManagerTestModeTupleName( + const testing::TestParamInfo>& info) { + return MemoryManagerTestModeToString(std::get<0>(info.param)); +} + +} // namespace cel::base_internal + +#endif // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MEMORY_MANAGER_TESTING_H_ diff --git a/base/type_test.cc b/base/type_test.cc index a1d2cc6b4..5a4e844e6 100644 --- a/base/type_test.cc +++ b/base/type_test.cc @@ -21,6 +21,7 @@ #include "absl/hash/hash_testing.h" #include "absl/status/status.h" #include "base/handle.h" +#include "base/internal/memory_manager_testing.h" #include "base/memory_manager.h" #include "base/type_factory.h" #include "base/type_manager.h" @@ -145,6 +146,34 @@ Persistent Must(absl::StatusOr> status_or_handle) { template constexpr void IS_INITIALIZED(T&) {} +class TypeTest + : public testing::TestWithParam { + protected: + void SetUp() override { + if (GetParam() == base_internal::MemoryManagerTestMode::kArena) { + memory_manager_ = ArenaMemoryManager::Default(); + } + } + + void TearDown() override { + if (GetParam() == base_internal::MemoryManagerTestMode::kArena) { + memory_manager_.reset(); + } + } + + MemoryManager& memory_manager() const { + switch (GetParam()) { + case base_internal::MemoryManagerTestMode::kGlobal: + return MemoryManager::Global(); + case base_internal::MemoryManagerTestMode::kArena: + return *memory_manager_; + } + } + + private: + std::unique_ptr memory_manager_; +}; + TEST(Type, TransientHandleTypeTraits) { EXPECT_TRUE(std::is_default_constructible_v>); EXPECT_TRUE(std::is_copy_constructible_v>); @@ -175,14 +204,14 @@ TEST(Type, PersistentHandleTypeTraits) { EXPECT_TRUE(std::is_swappable_v>); } -TEST(Type, CopyConstructor) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(TypeTest, CopyConstructor) { + TypeFactory type_factory(memory_manager()); Transient type(type_factory.GetIntType()); EXPECT_EQ(type, type_factory.GetIntType()); } -TEST(Type, MoveConstructor) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(TypeTest, MoveConstructor) { + TypeFactory type_factory(memory_manager()); Transient from(type_factory.GetIntType()); Transient to(std::move(from)); IS_INITIALIZED(from); @@ -190,15 +219,15 @@ TEST(Type, MoveConstructor) { EXPECT_EQ(to, type_factory.GetIntType()); } -TEST(Type, CopyAssignment) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(TypeTest, CopyAssignment) { + TypeFactory type_factory(memory_manager()); Transient type(type_factory.GetNullType()); type = type_factory.GetIntType(); EXPECT_EQ(type, type_factory.GetIntType()); } -TEST(Type, MoveAssignment) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(TypeTest, MoveAssignment) { + TypeFactory type_factory(memory_manager()); Transient from(type_factory.GetIntType()); Transient to(type_factory.GetNullType()); to = std::move(from); @@ -207,8 +236,8 @@ TEST(Type, MoveAssignment) { EXPECT_EQ(to, type_factory.GetIntType()); } -TEST(Type, Swap) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(TypeTest, Swap) { + TypeFactory type_factory(memory_manager()); Transient lhs = type_factory.GetIntType(); Transient rhs = type_factory.GetUintType(); std::swap(lhs, rhs); @@ -220,8 +249,8 @@ TEST(Type, Swap) { // extension for struct member initiation by name for it to be worth it. That // feature is not available in C++17. -TEST(Type, Null) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(TypeTest, Null) { + TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetNullType()->kind(), Kind::kNullType); EXPECT_EQ(type_factory.GetNullType()->name(), "null_type"); EXPECT_THAT(type_factory.GetNullType()->parameters(), SizeIs(0)); @@ -242,8 +271,8 @@ TEST(Type, Null) { EXPECT_FALSE(type_factory.GetNullType().Is()); } -TEST(Type, Error) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(TypeTest, Error) { + TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetErrorType()->kind(), Kind::kError); EXPECT_EQ(type_factory.GetErrorType()->name(), "*error*"); EXPECT_THAT(type_factory.GetErrorType()->parameters(), SizeIs(0)); @@ -264,8 +293,8 @@ TEST(Type, Error) { EXPECT_FALSE(type_factory.GetErrorType().Is()); } -TEST(Type, Dyn) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(TypeTest, Dyn) { + TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetDynType()->kind(), Kind::kDyn); EXPECT_EQ(type_factory.GetDynType()->name(), "dyn"); EXPECT_THAT(type_factory.GetDynType()->parameters(), SizeIs(0)); @@ -286,8 +315,8 @@ TEST(Type, Dyn) { EXPECT_FALSE(type_factory.GetDynType().Is()); } -TEST(Type, Any) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(TypeTest, Any) { + TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetAnyType()->kind(), Kind::kAny); EXPECT_EQ(type_factory.GetAnyType()->name(), "google.protobuf.Any"); EXPECT_THAT(type_factory.GetAnyType()->parameters(), SizeIs(0)); @@ -308,8 +337,8 @@ TEST(Type, Any) { EXPECT_FALSE(type_factory.GetAnyType().Is()); } -TEST(Type, Bool) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(TypeTest, Bool) { + TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetBoolType()->kind(), Kind::kBool); EXPECT_EQ(type_factory.GetBoolType()->name(), "bool"); EXPECT_THAT(type_factory.GetBoolType()->parameters(), SizeIs(0)); @@ -330,8 +359,8 @@ TEST(Type, Bool) { EXPECT_FALSE(type_factory.GetBoolType().Is()); } -TEST(Type, Int) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(TypeTest, Int) { + TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetIntType()->kind(), Kind::kInt); EXPECT_EQ(type_factory.GetIntType()->name(), "int"); EXPECT_THAT(type_factory.GetIntType()->parameters(), SizeIs(0)); @@ -352,8 +381,8 @@ TEST(Type, Int) { EXPECT_FALSE(type_factory.GetIntType().Is()); } -TEST(Type, Uint) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(TypeTest, Uint) { + TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetUintType()->kind(), Kind::kUint); EXPECT_EQ(type_factory.GetUintType()->name(), "uint"); EXPECT_THAT(type_factory.GetUintType()->parameters(), SizeIs(0)); @@ -374,8 +403,8 @@ TEST(Type, Uint) { EXPECT_FALSE(type_factory.GetUintType().Is()); } -TEST(Type, Double) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(TypeTest, Double) { + TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetDoubleType()->kind(), Kind::kDouble); EXPECT_EQ(type_factory.GetDoubleType()->name(), "double"); EXPECT_THAT(type_factory.GetDoubleType()->parameters(), SizeIs(0)); @@ -396,8 +425,8 @@ TEST(Type, Double) { EXPECT_FALSE(type_factory.GetDoubleType().Is()); } -TEST(Type, String) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(TypeTest, String) { + TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetStringType()->kind(), Kind::kString); EXPECT_EQ(type_factory.GetStringType()->name(), "string"); EXPECT_THAT(type_factory.GetStringType()->parameters(), SizeIs(0)); @@ -418,8 +447,8 @@ TEST(Type, String) { EXPECT_FALSE(type_factory.GetStringType().Is()); } -TEST(Type, Bytes) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(TypeTest, Bytes) { + TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetBytesType()->kind(), Kind::kBytes); EXPECT_EQ(type_factory.GetBytesType()->name(), "bytes"); EXPECT_THAT(type_factory.GetBytesType()->parameters(), SizeIs(0)); @@ -440,8 +469,8 @@ TEST(Type, Bytes) { EXPECT_FALSE(type_factory.GetBytesType().Is()); } -TEST(Type, Duration) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(TypeTest, Duration) { + TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetDurationType()->kind(), Kind::kDuration); EXPECT_EQ(type_factory.GetDurationType()->name(), "google.protobuf.Duration"); EXPECT_THAT(type_factory.GetDurationType()->parameters(), SizeIs(0)); @@ -462,8 +491,8 @@ TEST(Type, Duration) { EXPECT_FALSE(type_factory.GetDurationType().Is()); } -TEST(Type, Timestamp) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(TypeTest, Timestamp) { + TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetTimestampType()->kind(), Kind::kTimestamp); EXPECT_EQ(type_factory.GetTimestampType()->name(), "google.protobuf.Timestamp"); @@ -485,8 +514,8 @@ TEST(Type, Timestamp) { EXPECT_FALSE(type_factory.GetTimestampType().Is()); } -TEST(Type, Enum) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(TypeTest, Enum) { + TypeFactory type_factory(memory_manager()); ASSERT_OK_AND_ASSIGN(auto enum_type, type_factory.CreateEnumType()); EXPECT_EQ(enum_type->kind(), Kind::kEnum); @@ -510,8 +539,8 @@ TEST(Type, Enum) { EXPECT_FALSE(enum_type.Is()); } -TEST(Type, Struct) { - TypeManager type_manager(MemoryManager::Global()); +TEST_P(TypeTest, Struct) { + TypeManager type_manager(memory_manager()); ASSERT_OK_AND_ASSIGN(auto enum_type, type_manager.CreateStructType()); EXPECT_EQ(enum_type->kind(), Kind::kStruct); @@ -535,8 +564,8 @@ TEST(Type, Struct) { EXPECT_FALSE(enum_type.Is()); } -TEST(Type, List) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(TypeTest, List) { + TypeFactory type_factory(memory_manager()); ASSERT_OK_AND_ASSIGN(auto list_type, type_factory.CreateListType(type_factory.GetBoolType())); EXPECT_EQ(list_type, @@ -562,8 +591,8 @@ TEST(Type, List) { EXPECT_FALSE(list_type.Is()); } -TEST(Type, Map) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(TypeTest, Map) { + TypeFactory type_factory(memory_manager()); ASSERT_OK_AND_ASSIGN(auto map_type, type_factory.CreateMapType(type_factory.GetStringType(), type_factory.GetBoolType())); @@ -595,8 +624,10 @@ TEST(Type, Map) { EXPECT_TRUE(map_type.Is()); } -TEST(EnumType, FindConstant) { - TypeFactory type_factory(MemoryManager::Global()); +using EnumTypeTest = TypeTest; + +TEST_P(EnumTypeTest, FindConstant) { + TypeFactory type_factory(memory_manager()); ASSERT_OK_AND_ASSIGN(auto enum_type, type_factory.CreateEnumType()); @@ -626,8 +657,14 @@ TEST(EnumType, FindConstant) { StatusIs(absl::StatusCode::kNotFound)); } -TEST(StructType, FindField) { - TypeManager type_manager(MemoryManager::Global()); +INSTANTIATE_TEST_SUITE_P(EnumTypeTest, EnumTypeTest, + base_internal::MemoryManagerTestModeAll(), + base_internal::MemoryManagerTestModeName); + +class StructTypeTest : public TypeTest {}; + +TEST_P(StructTypeTest, FindField) { + TypeManager type_manager(memory_manager()); ASSERT_OK_AND_ASSIGN(auto struct_type, type_manager.CreateStructType()); @@ -690,99 +727,109 @@ TEST(StructType, FindField) { StatusIs(absl::StatusCode::kNotFound)); } -TEST(NullType, DebugString) { - TypeFactory type_factory(MemoryManager::Global()); +INSTANTIATE_TEST_SUITE_P(StructTypeTest, StructTypeTest, + base_internal::MemoryManagerTestModeAll(), + base_internal::MemoryManagerTestModeName); + +class DebugStringTest : public TypeTest {}; + +TEST_P(DebugStringTest, NullType) { + TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetNullType()->DebugString(), "null_type"); } -TEST(ErrorType, DebugString) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(DebugStringTest, ErrorType) { + TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetErrorType()->DebugString(), "*error*"); } -TEST(DynType, DebugString) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(DebugStringTest, DynType) { + TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetDynType()->DebugString(), "dyn"); } -TEST(AnyType, DebugString) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(DebugStringTest, AnyType) { + TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetAnyType()->DebugString(), "google.protobuf.Any"); } -TEST(BoolType, DebugString) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(DebugStringTest, BoolType) { + TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetBoolType()->DebugString(), "bool"); } -TEST(IntType, DebugString) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(DebugStringTest, IntType) { + TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetIntType()->DebugString(), "int"); } -TEST(UintType, DebugString) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(DebugStringTest, UintType) { + TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetUintType()->DebugString(), "uint"); } -TEST(DoubleType, DebugString) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(DebugStringTest, DoubleType) { + TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetDoubleType()->DebugString(), "double"); } -TEST(StringType, DebugString) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(DebugStringTest, StringType) { + TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetStringType()->DebugString(), "string"); } -TEST(BytesType, DebugString) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(DebugStringTest, BytesType) { + TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetBytesType()->DebugString(), "bytes"); } -TEST(DurationType, DebugString) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(DebugStringTest, DurationType) { + TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetDurationType()->DebugString(), "google.protobuf.Duration"); } -TEST(TimestampType, DebugString) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(DebugStringTest, TimestampType) { + TypeFactory type_factory(memory_manager()); EXPECT_EQ(type_factory.GetTimestampType()->DebugString(), "google.protobuf.Timestamp"); } -TEST(EnumType, DebugString) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(DebugStringTest, EnumType) { + TypeFactory type_factory(memory_manager()); ASSERT_OK_AND_ASSIGN(auto enum_type, type_factory.CreateEnumType()); EXPECT_EQ(enum_type->DebugString(), "test_enum.TestEnum"); } -TEST(StructType, DebugString) { - TypeManager type_manager(MemoryManager::Global()); +TEST_P(DebugStringTest, StructType) { + TypeManager type_manager(memory_manager()); ASSERT_OK_AND_ASSIGN(auto struct_type, type_manager.CreateStructType()); EXPECT_EQ(struct_type->DebugString(), "test_struct.TestStruct"); } -TEST(ListType, DebugString) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(DebugStringTest, ListType) { + TypeFactory type_factory(memory_manager()); ASSERT_OK_AND_ASSIGN(auto list_type, type_factory.CreateListType(type_factory.GetBoolType())); EXPECT_EQ(list_type->DebugString(), "list(bool)"); } -TEST(MapType, DebugString) { - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(DebugStringTest, MapType) { + TypeFactory type_factory(memory_manager()); ASSERT_OK_AND_ASSIGN(auto map_type, type_factory.CreateMapType(type_factory.GetStringType(), type_factory.GetBoolType())); EXPECT_EQ(map_type->DebugString(), "map(string, bool)"); } -TEST(Type, SupportsAbslHash) { - TypeFactory type_factory(MemoryManager::Global()); +INSTANTIATE_TEST_SUITE_P(DebugStringTest, DebugStringTest, + base_internal::MemoryManagerTestModeAll(), + base_internal::MemoryManagerTestModeName); + +TEST_P(TypeTest, SupportsAbslHash) { + TypeFactory type_factory(memory_manager()); EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly({ Persistent(type_factory.GetNullType()), Persistent(type_factory.GetErrorType()), @@ -806,5 +853,9 @@ TEST(Type, SupportsAbslHash) { })); } +INSTANTIATE_TEST_SUITE_P(TypeTest, TypeTest, + base_internal::MemoryManagerTestModeAll(), + base_internal::MemoryManagerTestModeName); + } // namespace } // namespace cel diff --git a/base/value_test.cc b/base/value_test.cc index 0a90009d8..1e4ccd3c1 100644 --- a/base/value_test.cc +++ b/base/value_test.cc @@ -30,6 +30,7 @@ #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/time/time.h" +#include "base/internal/memory_manager_testing.h" #include "base/memory_manager.h" #include "base/type.h" #include "base/type_factory.h" @@ -40,6 +41,7 @@ #include "internal/time.h" namespace cel { + namespace { using testing::Eq; @@ -467,6 +469,45 @@ Transient Must(absl::StatusOr> status_or_handle) { template constexpr void IS_INITIALIZED(T&) {} +template +class BaseValueTest + : public testing::TestWithParam< + std::tuple> { + using Base = testing::TestWithParam< + std::tuple>; + + protected: + void SetUp() override { + if (std::get<0>(Base::GetParam()) == + base_internal::MemoryManagerTestMode::kArena) { + memory_manager_ = ArenaMemoryManager::Default(); + } + } + + void TearDown() override { + if (std::get<0>(Base::GetParam()) == + base_internal::MemoryManagerTestMode::kArena) { + memory_manager_.reset(); + } + } + + MemoryManager& memory_manager() const { + switch (std::get<0>(Base::GetParam())) { + case base_internal::MemoryManagerTestMode::kGlobal: + return MemoryManager::Global(); + case base_internal::MemoryManagerTestMode::kArena: + return *memory_manager_; + } + } + + const auto& test_case() const { return std::get<1>(Base::GetParam()); } + + private: + std::unique_ptr memory_manager_; +}; + +using ValueTest = BaseValueTest<>; + TEST(Value, HandleSize) { // Advisory test to ensure we attempt to keep the size of Value handles under // 32 bytes. As of the time of writing they are 24 bytes. @@ -503,8 +544,8 @@ TEST(Value, PersistentHandleTypeTraits) { EXPECT_TRUE(std::is_swappable_v>); } -TEST(Value, DefaultConstructor) { - ValueFactory value_factory(MemoryManager::Global()); +TEST_P(ValueTest, DefaultConstructor) { + ValueFactory value_factory(memory_manager()); Transient value; EXPECT_EQ(value, value_factory.GetNullValue()); } @@ -516,144 +557,148 @@ struct ConstructionAssignmentTestCase final { }; using ConstructionAssignmentTest = - testing::TestWithParam; + BaseValueTest; TEST_P(ConstructionAssignmentTest, CopyConstructor) { - const auto& test_case = GetParam(); - TypeFactory type_factory(MemoryManager::Global()); - ValueFactory value_factory(MemoryManager::Global()); + TypeFactory type_factory(memory_manager()); + ValueFactory value_factory(memory_manager()); Persistent from( - test_case.default_value(type_factory, value_factory)); + test_case().default_value(type_factory, value_factory)); Persistent to(from); IS_INITIALIZED(to); - EXPECT_EQ(to, test_case.default_value(type_factory, value_factory)); + EXPECT_EQ(to, test_case().default_value(type_factory, value_factory)); } TEST_P(ConstructionAssignmentTest, MoveConstructor) { - const auto& test_case = GetParam(); - TypeFactory type_factory(MemoryManager::Global()); - ValueFactory value_factory(MemoryManager::Global()); + TypeFactory type_factory(memory_manager()); + ValueFactory value_factory(memory_manager()); Persistent from( - test_case.default_value(type_factory, value_factory)); + test_case().default_value(type_factory, value_factory)); Persistent to(std::move(from)); IS_INITIALIZED(from); EXPECT_EQ(from, value_factory.GetNullValue()); - EXPECT_EQ(to, test_case.default_value(type_factory, value_factory)); + EXPECT_EQ(to, test_case().default_value(type_factory, value_factory)); } TEST_P(ConstructionAssignmentTest, CopyAssignment) { - const auto& test_case = GetParam(); - TypeFactory type_factory(MemoryManager::Global()); - ValueFactory value_factory(MemoryManager::Global()); + TypeFactory type_factory(memory_manager()); + ValueFactory value_factory(memory_manager()); Persistent from( - test_case.default_value(type_factory, value_factory)); + test_case().default_value(type_factory, value_factory)); Persistent to; to = from; EXPECT_EQ(to, from); } TEST_P(ConstructionAssignmentTest, MoveAssignment) { - const auto& test_case = GetParam(); - TypeFactory type_factory(MemoryManager::Global()); - ValueFactory value_factory(MemoryManager::Global()); + TypeFactory type_factory(memory_manager()); + ValueFactory value_factory(memory_manager()); Persistent from( - test_case.default_value(type_factory, value_factory)); + test_case().default_value(type_factory, value_factory)); Persistent to; to = std::move(from); IS_INITIALIZED(from); EXPECT_EQ(from, value_factory.GetNullValue()); - EXPECT_EQ(to, test_case.default_value(type_factory, value_factory)); + EXPECT_EQ(to, test_case().default_value(type_factory, value_factory)); } INSTANTIATE_TEST_SUITE_P( ConstructionAssignmentTest, ConstructionAssignmentTest, - testing::ValuesIn({ - {"Null", - [](TypeFactory& type_factory, - ValueFactory& value_factory) -> Persistent { - return value_factory.GetNullValue(); - }}, - {"Bool", - [](TypeFactory& type_factory, - ValueFactory& value_factory) -> Persistent { - return value_factory.CreateBoolValue(false); - }}, - {"Int", - [](TypeFactory& type_factory, - ValueFactory& value_factory) -> Persistent { - return value_factory.CreateIntValue(0); - }}, - {"Uint", - [](TypeFactory& type_factory, - ValueFactory& value_factory) -> Persistent { - return value_factory.CreateUintValue(0); - }}, - {"Double", - [](TypeFactory& type_factory, - ValueFactory& value_factory) -> Persistent { - return value_factory.CreateDoubleValue(0.0); - }}, - {"Duration", - [](TypeFactory& type_factory, - ValueFactory& value_factory) -> Persistent { - return Must(value_factory.CreateDurationValue(absl::ZeroDuration())); - }}, - {"Timestamp", - [](TypeFactory& type_factory, - ValueFactory& value_factory) -> Persistent { - return Must(value_factory.CreateTimestampValue(absl::UnixEpoch())); - }}, - {"Error", - [](TypeFactory& type_factory, - ValueFactory& value_factory) -> Persistent { - return value_factory.CreateErrorValue(absl::CancelledError()); - }}, - {"Bytes", - [](TypeFactory& type_factory, - ValueFactory& value_factory) -> Persistent { - return Must(value_factory.CreateBytesValue("")); - }}, - {"String", - [](TypeFactory& type_factory, - ValueFactory& value_factory) -> Persistent { - return Must(value_factory.CreateStringValue("")); - }}, - {"Enum", - [](TypeFactory& type_factory, - ValueFactory& value_factory) -> Persistent { - return Must( - EnumValue::New(Must(type_factory.CreateEnumType()), - value_factory, EnumType::ConstantId("VALUE1"))); - }}, - {"Struct", - [](TypeFactory& type_factory, - ValueFactory& value_factory) -> Persistent { - return Must(StructValue::New( - Must(type_factory.CreateStructType()), - value_factory)); - }}, - {"List", - [](TypeFactory& type_factory, - ValueFactory& value_factory) -> Persistent { - return Must(value_factory.CreateListValue( - Must(type_factory.CreateListType(type_factory.GetIntType())), - std::vector{})); - }}, - {"Map", - [](TypeFactory& type_factory, - ValueFactory& value_factory) -> Persistent { - return Must(value_factory.CreateMapValue( - Must(type_factory.CreateMapType(type_factory.GetStringType(), - type_factory.GetIntType())), - std::map{})); - }}, - }), - [](const testing::TestParamInfo& info) { - return info.param.name; + testing::Combine( + base_internal::MemoryManagerTestModeAll(), + testing::ValuesIn({ + {"Null", + [](TypeFactory& type_factory, + ValueFactory& value_factory) -> Persistent { + return value_factory.GetNullValue(); + }}, + {"Bool", + [](TypeFactory& type_factory, + ValueFactory& value_factory) -> Persistent { + return value_factory.CreateBoolValue(false); + }}, + {"Int", + [](TypeFactory& type_factory, + ValueFactory& value_factory) -> Persistent { + return value_factory.CreateIntValue(0); + }}, + {"Uint", + [](TypeFactory& type_factory, + ValueFactory& value_factory) -> Persistent { + return value_factory.CreateUintValue(0); + }}, + {"Double", + [](TypeFactory& type_factory, + ValueFactory& value_factory) -> Persistent { + return value_factory.CreateDoubleValue(0.0); + }}, + {"Duration", + [](TypeFactory& type_factory, + ValueFactory& value_factory) -> Persistent { + return Must( + value_factory.CreateDurationValue(absl::ZeroDuration())); + }}, + {"Timestamp", + [](TypeFactory& type_factory, + ValueFactory& value_factory) -> Persistent { + return Must( + value_factory.CreateTimestampValue(absl::UnixEpoch())); + }}, + {"Error", + [](TypeFactory& type_factory, + ValueFactory& value_factory) -> Persistent { + return value_factory.CreateErrorValue(absl::CancelledError()); + }}, + {"Bytes", + [](TypeFactory& type_factory, + ValueFactory& value_factory) -> Persistent { + return Must(value_factory.CreateBytesValue("")); + }}, + {"String", + [](TypeFactory& type_factory, + ValueFactory& value_factory) -> Persistent { + return Must(value_factory.CreateStringValue("")); + }}, + {"Enum", + [](TypeFactory& type_factory, + ValueFactory& value_factory) -> Persistent { + return Must(EnumValue::New( + Must(type_factory.CreateEnumType()), + value_factory, EnumType::ConstantId("VALUE1"))); + }}, + {"Struct", + [](TypeFactory& type_factory, + ValueFactory& value_factory) -> Persistent { + return Must(StructValue::New( + Must(type_factory.CreateStructType()), + value_factory)); + }}, + {"List", + [](TypeFactory& type_factory, + ValueFactory& value_factory) -> Persistent { + return Must(value_factory.CreateListValue( + Must(type_factory.CreateListType(type_factory.GetIntType())), + std::vector{})); + }}, + {"Map", + [](TypeFactory& type_factory, + ValueFactory& value_factory) -> Persistent { + return Must(value_factory.CreateMapValue( + Must(type_factory.CreateMapType(type_factory.GetStringType(), + type_factory.GetIntType())), + std::map{})); + }}, + })), + [](const testing::TestParamInfo< + std::tuple>& info) { + return absl::StrCat( + base_internal::MemoryManagerTestModeToString(std::get<0>(info.param)), + "_", std::get<1>(info.param).name); }); -TEST(Value, Swap) { - ValueFactory value_factory(MemoryManager::Global()); +TEST_P(ValueTest, Swap) { + ValueFactory value_factory(memory_manager()); Persistent lhs = value_factory.CreateIntValue(0); Persistent rhs = value_factory.CreateUintValue(0); std::swap(lhs, rhs); @@ -661,19 +706,21 @@ TEST(Value, Swap) { EXPECT_EQ(rhs, value_factory.CreateIntValue(0)); } -TEST(NullValue, DebugString) { - ValueFactory value_factory(MemoryManager::Global()); +using DebugStringTest = ValueTest; + +TEST_P(DebugStringTest, NullValue) { + ValueFactory value_factory(memory_manager()); EXPECT_EQ(value_factory.GetNullValue()->DebugString(), "null"); } -TEST(BoolValue, DebugString) { - ValueFactory value_factory(MemoryManager::Global()); +TEST_P(DebugStringTest, BoolValue) { + ValueFactory value_factory(memory_manager()); EXPECT_EQ(value_factory.CreateBoolValue(false)->DebugString(), "false"); EXPECT_EQ(value_factory.CreateBoolValue(true)->DebugString(), "true"); } -TEST(IntValue, DebugString) { - ValueFactory value_factory(MemoryManager::Global()); +TEST_P(DebugStringTest, IntValue) { + ValueFactory value_factory(memory_manager()); EXPECT_EQ(value_factory.CreateIntValue(-1)->DebugString(), "-1"); EXPECT_EQ(value_factory.CreateIntValue(0)->DebugString(), "0"); EXPECT_EQ(value_factory.CreateIntValue(1)->DebugString(), "1"); @@ -685,8 +732,8 @@ TEST(IntValue, DebugString) { "9223372036854775807"); } -TEST(UintValue, DebugString) { - ValueFactory value_factory(MemoryManager::Global()); +TEST_P(DebugStringTest, UintValue) { + ValueFactory value_factory(memory_manager()); EXPECT_EQ(value_factory.CreateUintValue(0)->DebugString(), "0u"); EXPECT_EQ(value_factory.CreateUintValue(1)->DebugString(), "1u"); EXPECT_EQ(value_factory.CreateUintValue(std::numeric_limits::max()) @@ -694,8 +741,8 @@ TEST(UintValue, DebugString) { "18446744073709551615u"); } -TEST(DoubleValue, DebugString) { - ValueFactory value_factory(MemoryManager::Global()); +TEST_P(DebugStringTest, DoubleValue) { + ValueFactory value_factory(memory_manager()); EXPECT_EQ(value_factory.CreateDoubleValue(-1.0)->DebugString(), "-1.0"); EXPECT_EQ(value_factory.CreateDoubleValue(0.0)->DebugString(), "0.0"); EXPECT_EQ(value_factory.CreateDoubleValue(1.0)->DebugString(), "1.0"); @@ -727,25 +774,29 @@ TEST(DoubleValue, DebugString) { "-infinity"); } -TEST(DurationValue, DebugString) { - ValueFactory value_factory(MemoryManager::Global()); +TEST_P(DebugStringTest, DurationValue) { + ValueFactory value_factory(memory_manager()); EXPECT_EQ(DurationValue::Zero(value_factory)->DebugString(), internal::FormatDuration(absl::ZeroDuration()).value()); } -TEST(TimestampValue, DebugString) { - ValueFactory value_factory(MemoryManager::Global()); +TEST_P(DebugStringTest, TimestampValue) { + ValueFactory value_factory(memory_manager()); EXPECT_EQ(TimestampValue::UnixEpoch(value_factory)->DebugString(), internal::FormatTimestamp(absl::UnixEpoch()).value()); } +INSTANTIATE_TEST_SUITE_P(DebugStringTest, DebugStringTest, + base_internal::MemoryManagerTestModeAll(), + base_internal::MemoryManagerTestModeTupleName); + // The below tests could be made parameterized but doing so requires the // extension for struct member initiation by name for it to be worth it. That // feature is not available in C++17. -TEST(Value, Error) { - ValueFactory value_factory(MemoryManager::Global()); - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(ValueTest, Error) { + ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); auto error_value = value_factory.CreateErrorValue(absl::CancelledError()); EXPECT_TRUE(error_value.Is()); EXPECT_FALSE(error_value.Is()); @@ -755,9 +806,9 @@ TEST(Value, Error) { EXPECT_EQ(error_value->value(), absl::CancelledError()); } -TEST(Value, Bool) { - ValueFactory value_factory(MemoryManager::Global()); - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(ValueTest, Bool) { + ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); auto false_value = BoolValue::False(value_factory); EXPECT_TRUE(false_value.Is()); EXPECT_FALSE(false_value.Is()); @@ -780,9 +831,9 @@ TEST(Value, Bool) { EXPECT_NE(true_value, false_value); } -TEST(Value, Int) { - ValueFactory value_factory(MemoryManager::Global()); - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(ValueTest, Int) { + ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); auto zero_value = value_factory.CreateIntValue(0); EXPECT_TRUE(zero_value.Is()); EXPECT_FALSE(zero_value.Is()); @@ -805,9 +856,9 @@ TEST(Value, Int) { EXPECT_NE(one_value, zero_value); } -TEST(Value, Uint) { - ValueFactory value_factory(MemoryManager::Global()); - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(ValueTest, Uint) { + ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); auto zero_value = value_factory.CreateUintValue(0); EXPECT_TRUE(zero_value.Is()); EXPECT_FALSE(zero_value.Is()); @@ -830,9 +881,9 @@ TEST(Value, Uint) { EXPECT_NE(one_value, zero_value); } -TEST(Value, Double) { - ValueFactory value_factory(MemoryManager::Global()); - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(ValueTest, Double) { + ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); auto zero_value = value_factory.CreateDoubleValue(0.0); EXPECT_TRUE(zero_value.Is()); EXPECT_FALSE(zero_value.Is()); @@ -855,9 +906,9 @@ TEST(Value, Double) { EXPECT_NE(one_value, zero_value); } -TEST(Value, Duration) { - ValueFactory value_factory(MemoryManager::Global()); - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(ValueTest, Duration) { + ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); auto zero_value = Must(value_factory.CreateDurationValue(absl::ZeroDuration())); EXPECT_TRUE(zero_value.Is()); @@ -885,9 +936,9 @@ TEST(Value, Duration) { StatusIs(absl::StatusCode::kInvalidArgument)); } -TEST(Value, Timestamp) { - ValueFactory value_factory(MemoryManager::Global()); - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(ValueTest, Timestamp) { + ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); auto zero_value = Must(value_factory.CreateTimestampValue(absl::UnixEpoch())); EXPECT_TRUE(zero_value.Is()); EXPECT_FALSE(zero_value.Is()); @@ -914,9 +965,9 @@ TEST(Value, Timestamp) { StatusIs(absl::StatusCode::kInvalidArgument)); } -TEST(Value, BytesFromString) { - ValueFactory value_factory(MemoryManager::Global()); - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(ValueTest, BytesFromString) { + ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); auto zero_value = Must(value_factory.CreateBytesValue(std::string("0"))); EXPECT_TRUE(zero_value.Is()); EXPECT_FALSE(zero_value.Is()); @@ -939,9 +990,9 @@ TEST(Value, BytesFromString) { EXPECT_NE(one_value, zero_value); } -TEST(Value, BytesFromStringView) { - ValueFactory value_factory(MemoryManager::Global()); - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(ValueTest, BytesFromStringView) { + ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); auto zero_value = Must(value_factory.CreateBytesValue(absl::string_view("0"))); EXPECT_TRUE(zero_value.Is()); @@ -967,9 +1018,9 @@ TEST(Value, BytesFromStringView) { EXPECT_NE(one_value, zero_value); } -TEST(Value, BytesFromCord) { - ValueFactory value_factory(MemoryManager::Global()); - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(ValueTest, BytesFromCord) { + ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); auto zero_value = Must(value_factory.CreateBytesValue(absl::Cord("0"))); EXPECT_TRUE(zero_value.Is()); EXPECT_FALSE(zero_value.Is()); @@ -992,9 +1043,9 @@ TEST(Value, BytesFromCord) { EXPECT_NE(one_value, zero_value); } -TEST(Value, BytesFromLiteral) { - ValueFactory value_factory(MemoryManager::Global()); - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(ValueTest, BytesFromLiteral) { + ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); auto zero_value = Must(value_factory.CreateBytesValue("0")); EXPECT_TRUE(zero_value.Is()); EXPECT_FALSE(zero_value.Is()); @@ -1017,9 +1068,9 @@ TEST(Value, BytesFromLiteral) { EXPECT_NE(one_value, zero_value); } -TEST(Value, BytesFromExternal) { - ValueFactory value_factory(MemoryManager::Global()); - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(ValueTest, BytesFromExternal) { + ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); auto zero_value = Must(value_factory.CreateBytesValue("0", []() {})); EXPECT_TRUE(zero_value.Is()); EXPECT_FALSE(zero_value.Is()); @@ -1042,9 +1093,9 @@ TEST(Value, BytesFromExternal) { EXPECT_NE(one_value, zero_value); } -TEST(Value, StringFromString) { - ValueFactory value_factory(MemoryManager::Global()); - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(ValueTest, StringFromString) { + ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); auto zero_value = Must(value_factory.CreateStringValue(std::string("0"))); EXPECT_TRUE(zero_value.Is()); EXPECT_FALSE(zero_value.Is()); @@ -1068,9 +1119,9 @@ TEST(Value, StringFromString) { EXPECT_NE(one_value, zero_value); } -TEST(Value, StringFromStringView) { - ValueFactory value_factory(MemoryManager::Global()); - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(ValueTest, StringFromStringView) { + ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); auto zero_value = Must(value_factory.CreateStringValue(absl::string_view("0"))); EXPECT_TRUE(zero_value.Is()); @@ -1097,9 +1148,9 @@ TEST(Value, StringFromStringView) { EXPECT_NE(one_value, zero_value); } -TEST(Value, StringFromCord) { - ValueFactory value_factory(MemoryManager::Global()); - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(ValueTest, StringFromCord) { + ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); auto zero_value = Must(value_factory.CreateStringValue(absl::Cord("0"))); EXPECT_TRUE(zero_value.Is()); EXPECT_FALSE(zero_value.Is()); @@ -1122,9 +1173,9 @@ TEST(Value, StringFromCord) { EXPECT_NE(one_value, zero_value); } -TEST(Value, StringFromLiteral) { - ValueFactory value_factory(MemoryManager::Global()); - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(ValueTest, StringFromLiteral) { + ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); auto zero_value = Must(value_factory.CreateStringValue("0")); EXPECT_TRUE(zero_value.Is()); EXPECT_FALSE(zero_value.Is()); @@ -1147,9 +1198,9 @@ TEST(Value, StringFromLiteral) { EXPECT_NE(one_value, zero_value); } -TEST(Value, StringFromExternal) { - ValueFactory value_factory(MemoryManager::Global()); - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(ValueTest, StringFromExternal) { + ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); auto zero_value = Must(value_factory.CreateStringValue("0", []() {})); EXPECT_TRUE(zero_value.Is()); EXPECT_FALSE(zero_value.Is()); @@ -1192,122 +1243,125 @@ struct BytesConcatTestCase final { std::string rhs; }; -using BytesConcatTest = testing::TestWithParam; +using BytesConcatTest = BaseValueTest; TEST_P(BytesConcatTest, Concat) { - const BytesConcatTestCase& test_case = GetParam(); - ValueFactory value_factory(MemoryManager::Global()); + ValueFactory value_factory(memory_manager()); EXPECT_TRUE( Must(BytesValue::Concat(value_factory, - MakeStringBytes(value_factory, test_case.lhs), - MakeStringBytes(value_factory, test_case.rhs))) - ->Equals(test_case.lhs + test_case.rhs)); + MakeStringBytes(value_factory, test_case().lhs), + MakeStringBytes(value_factory, test_case().rhs))) + ->Equals(test_case().lhs + test_case().rhs)); EXPECT_TRUE( Must(BytesValue::Concat(value_factory, - MakeStringBytes(value_factory, test_case.lhs), - MakeCordBytes(value_factory, test_case.rhs))) - ->Equals(test_case.lhs + test_case.rhs)); + MakeStringBytes(value_factory, test_case().lhs), + MakeCordBytes(value_factory, test_case().rhs))) + ->Equals(test_case().lhs + test_case().rhs)); EXPECT_TRUE( - Must(BytesValue::Concat(value_factory, - MakeStringBytes(value_factory, test_case.lhs), - MakeExternalBytes(value_factory, test_case.rhs))) - ->Equals(test_case.lhs + test_case.rhs)); + Must(BytesValue::Concat( + value_factory, MakeStringBytes(value_factory, test_case().lhs), + MakeExternalBytes(value_factory, test_case().rhs))) + ->Equals(test_case().lhs + test_case().rhs)); EXPECT_TRUE( Must(BytesValue::Concat(value_factory, - MakeCordBytes(value_factory, test_case.lhs), - MakeStringBytes(value_factory, test_case.rhs))) - ->Equals(test_case.lhs + test_case.rhs)); + MakeCordBytes(value_factory, test_case().lhs), + MakeStringBytes(value_factory, test_case().rhs))) + ->Equals(test_case().lhs + test_case().rhs)); EXPECT_TRUE( Must(BytesValue::Concat(value_factory, - MakeCordBytes(value_factory, test_case.lhs), - MakeCordBytes(value_factory, test_case.rhs))) - ->Equals(test_case.lhs + test_case.rhs)); + MakeCordBytes(value_factory, test_case().lhs), + MakeCordBytes(value_factory, test_case().rhs))) + ->Equals(test_case().lhs + test_case().rhs)); EXPECT_TRUE( - Must(BytesValue::Concat(value_factory, - MakeCordBytes(value_factory, test_case.lhs), - MakeExternalBytes(value_factory, test_case.rhs))) - ->Equals(test_case.lhs + test_case.rhs)); + Must(BytesValue::Concat( + value_factory, MakeCordBytes(value_factory, test_case().lhs), + MakeExternalBytes(value_factory, test_case().rhs))) + ->Equals(test_case().lhs + test_case().rhs)); EXPECT_TRUE( Must(BytesValue::Concat(value_factory, - MakeExternalBytes(value_factory, test_case.lhs), - MakeStringBytes(value_factory, test_case.rhs))) - ->Equals(test_case.lhs + test_case.rhs)); + MakeExternalBytes(value_factory, test_case().lhs), + MakeStringBytes(value_factory, test_case().rhs))) + ->Equals(test_case().lhs + test_case().rhs)); EXPECT_TRUE( Must(BytesValue::Concat(value_factory, - MakeExternalBytes(value_factory, test_case.lhs), - MakeCordBytes(value_factory, test_case.rhs))) - ->Equals(test_case.lhs + test_case.rhs)); + MakeExternalBytes(value_factory, test_case().lhs), + MakeCordBytes(value_factory, test_case().rhs))) + ->Equals(test_case().lhs + test_case().rhs)); EXPECT_TRUE( - Must(BytesValue::Concat(value_factory, - MakeExternalBytes(value_factory, test_case.lhs), - MakeExternalBytes(value_factory, test_case.rhs))) - ->Equals(test_case.lhs + test_case.rhs)); -} - -INSTANTIATE_TEST_SUITE_P(BytesConcatTest, BytesConcatTest, - testing::ValuesIn({ - {"", ""}, - {"", std::string("\0", 1)}, - {std::string("\0", 1), ""}, - {std::string("\0", 1), std::string("\0", 1)}, - {"", "foo"}, - {"foo", ""}, - {"foo", "foo"}, - {"bar", "foo"}, - {"foo", "bar"}, - {"bar", "bar"}, - })); + Must(BytesValue::Concat( + value_factory, MakeExternalBytes(value_factory, test_case().lhs), + MakeExternalBytes(value_factory, test_case().rhs))) + ->Equals(test_case().lhs + test_case().rhs)); +} + +INSTANTIATE_TEST_SUITE_P( + BytesConcatTest, BytesConcatTest, + testing::Combine(base_internal::MemoryManagerTestModeAll(), + testing::ValuesIn({ + {"", ""}, + {"", std::string("\0", 1)}, + {std::string("\0", 1), ""}, + {std::string("\0", 1), std::string("\0", 1)}, + {"", "foo"}, + {"foo", ""}, + {"foo", "foo"}, + {"bar", "foo"}, + {"foo", "bar"}, + {"bar", "bar"}, + }))); struct BytesSizeTestCase final { std::string data; size_t size; }; -using BytesSizeTest = testing::TestWithParam; +using BytesSizeTest = BaseValueTest; TEST_P(BytesSizeTest, Size) { - const BytesSizeTestCase& test_case = GetParam(); - ValueFactory value_factory(MemoryManager::Global()); - EXPECT_EQ(MakeStringBytes(value_factory, test_case.data)->size(), - test_case.size); - EXPECT_EQ(MakeCordBytes(value_factory, test_case.data)->size(), - test_case.size); - EXPECT_EQ(MakeExternalBytes(value_factory, test_case.data)->size(), - test_case.size); -} - -INSTANTIATE_TEST_SUITE_P(BytesSizeTest, BytesSizeTest, - testing::ValuesIn({ - {"", 0}, - {"1", 1}, - {"foo", 3}, - {"\xef\xbf\xbd", 3}, - })); + ValueFactory value_factory(memory_manager()); + EXPECT_EQ(MakeStringBytes(value_factory, test_case().data)->size(), + test_case().size); + EXPECT_EQ(MakeCordBytes(value_factory, test_case().data)->size(), + test_case().size); + EXPECT_EQ(MakeExternalBytes(value_factory, test_case().data)->size(), + test_case().size); +} + +INSTANTIATE_TEST_SUITE_P( + BytesSizeTest, BytesSizeTest, + testing::Combine(base_internal::MemoryManagerTestModeAll(), + testing::ValuesIn({ + {"", 0}, + {"1", 1}, + {"foo", 3}, + {"\xef\xbf\xbd", 3}, + }))); struct BytesEmptyTestCase final { std::string data; bool empty; }; -using BytesEmptyTest = testing::TestWithParam; +using BytesEmptyTest = BaseValueTest; TEST_P(BytesEmptyTest, Empty) { - const BytesEmptyTestCase& test_case = GetParam(); - ValueFactory value_factory(MemoryManager::Global()); - EXPECT_EQ(MakeStringBytes(value_factory, test_case.data)->empty(), - test_case.empty); - EXPECT_EQ(MakeCordBytes(value_factory, test_case.data)->empty(), - test_case.empty); - EXPECT_EQ(MakeExternalBytes(value_factory, test_case.data)->empty(), - test_case.empty); -} - -INSTANTIATE_TEST_SUITE_P(BytesEmptyTest, BytesEmptyTest, - testing::ValuesIn({ - {"", true}, - {std::string("\0", 1), false}, - {"1", false}, - })); + ValueFactory value_factory(memory_manager()); + EXPECT_EQ(MakeStringBytes(value_factory, test_case().data)->empty(), + test_case().empty); + EXPECT_EQ(MakeCordBytes(value_factory, test_case().data)->empty(), + test_case().empty); + EXPECT_EQ(MakeExternalBytes(value_factory, test_case().data)->empty(), + test_case().empty); +} + +INSTANTIATE_TEST_SUITE_P( + BytesEmptyTest, BytesEmptyTest, + testing::Combine(base_internal::MemoryManagerTestModeAll(), + testing::ValuesIn({ + {"", true}, + {std::string("\0", 1), false}, + {"1", false}, + }))); struct BytesEqualsTestCase final { std::string lhs; @@ -1315,53 +1369,54 @@ struct BytesEqualsTestCase final { bool equals; }; -using BytesEqualsTest = testing::TestWithParam; +using BytesEqualsTest = BaseValueTest; TEST_P(BytesEqualsTest, Equals) { - const BytesEqualsTestCase& test_case = GetParam(); - ValueFactory value_factory(MemoryManager::Global()); - EXPECT_EQ(MakeStringBytes(value_factory, test_case.lhs) - ->Equals(MakeStringBytes(value_factory, test_case.rhs)), - test_case.equals); - EXPECT_EQ(MakeStringBytes(value_factory, test_case.lhs) - ->Equals(MakeCordBytes(value_factory, test_case.rhs)), - test_case.equals); - EXPECT_EQ(MakeStringBytes(value_factory, test_case.lhs) - ->Equals(MakeExternalBytes(value_factory, test_case.rhs)), - test_case.equals); - EXPECT_EQ(MakeCordBytes(value_factory, test_case.lhs) - ->Equals(MakeStringBytes(value_factory, test_case.rhs)), - test_case.equals); - EXPECT_EQ(MakeCordBytes(value_factory, test_case.lhs) - ->Equals(MakeCordBytes(value_factory, test_case.rhs)), - test_case.equals); - EXPECT_EQ(MakeCordBytes(value_factory, test_case.lhs) - ->Equals(MakeExternalBytes(value_factory, test_case.rhs)), - test_case.equals); - EXPECT_EQ(MakeExternalBytes(value_factory, test_case.lhs) - ->Equals(MakeStringBytes(value_factory, test_case.rhs)), - test_case.equals); - EXPECT_EQ(MakeExternalBytes(value_factory, test_case.lhs) - ->Equals(MakeCordBytes(value_factory, test_case.rhs)), - test_case.equals); - EXPECT_EQ(MakeExternalBytes(value_factory, test_case.lhs) - ->Equals(MakeExternalBytes(value_factory, test_case.rhs)), - test_case.equals); -} - -INSTANTIATE_TEST_SUITE_P(BytesEqualsTest, BytesEqualsTest, - testing::ValuesIn({ - {"", "", true}, - {"", std::string("\0", 1), false}, - {std::string("\0", 1), "", false}, - {std::string("\0", 1), std::string("\0", 1), true}, - {"", "foo", false}, - {"foo", "", false}, - {"foo", "foo", true}, - {"bar", "foo", false}, - {"foo", "bar", false}, - {"bar", "bar", true}, - })); + ValueFactory value_factory(memory_manager()); + EXPECT_EQ(MakeStringBytes(value_factory, test_case().lhs) + ->Equals(MakeStringBytes(value_factory, test_case().rhs)), + test_case().equals); + EXPECT_EQ(MakeStringBytes(value_factory, test_case().lhs) + ->Equals(MakeCordBytes(value_factory, test_case().rhs)), + test_case().equals); + EXPECT_EQ(MakeStringBytes(value_factory, test_case().lhs) + ->Equals(MakeExternalBytes(value_factory, test_case().rhs)), + test_case().equals); + EXPECT_EQ(MakeCordBytes(value_factory, test_case().lhs) + ->Equals(MakeStringBytes(value_factory, test_case().rhs)), + test_case().equals); + EXPECT_EQ(MakeCordBytes(value_factory, test_case().lhs) + ->Equals(MakeCordBytes(value_factory, test_case().rhs)), + test_case().equals); + EXPECT_EQ(MakeCordBytes(value_factory, test_case().lhs) + ->Equals(MakeExternalBytes(value_factory, test_case().rhs)), + test_case().equals); + EXPECT_EQ(MakeExternalBytes(value_factory, test_case().lhs) + ->Equals(MakeStringBytes(value_factory, test_case().rhs)), + test_case().equals); + EXPECT_EQ(MakeExternalBytes(value_factory, test_case().lhs) + ->Equals(MakeCordBytes(value_factory, test_case().rhs)), + test_case().equals); + EXPECT_EQ(MakeExternalBytes(value_factory, test_case().lhs) + ->Equals(MakeExternalBytes(value_factory, test_case().rhs)), + test_case().equals); +} + +INSTANTIATE_TEST_SUITE_P( + BytesEqualsTest, BytesEqualsTest, + testing::Combine(base_internal::MemoryManagerTestModeAll(), + testing::ValuesIn({ + {"", "", true}, + {"", std::string("\0", 1), false}, + {std::string("\0", 1), "", false}, + {std::string("\0", 1), std::string("\0", 1), true}, + {"", "foo", false}, + {"foo", "", false}, + {"foo", "foo", true}, + {"bar", "foo", false}, + {"foo", "bar", false}, + {"bar", "bar", true}, + }))); struct BytesCompareTestCase final { std::string lhs; @@ -1369,139 +1424,145 @@ struct BytesCompareTestCase final { int compare; }; -using BytesCompareTest = testing::TestWithParam; +using BytesCompareTest = BaseValueTest; int NormalizeCompareResult(int compare) { return std::clamp(compare, -1, 1); } TEST_P(BytesCompareTest, Equals) { - const BytesCompareTestCase& test_case = GetParam(); - ValueFactory value_factory(MemoryManager::Global()); + ValueFactory value_factory(memory_manager()); EXPECT_EQ(NormalizeCompareResult( - MakeStringBytes(value_factory, test_case.lhs) - ->Compare(MakeStringBytes(value_factory, test_case.rhs))), - test_case.compare); + MakeStringBytes(value_factory, test_case().lhs) + ->Compare(MakeStringBytes(value_factory, test_case().rhs))), + test_case().compare); EXPECT_EQ(NormalizeCompareResult( - MakeStringBytes(value_factory, test_case.lhs) - ->Compare(MakeCordBytes(value_factory, test_case.rhs))), - test_case.compare); - EXPECT_EQ(NormalizeCompareResult( - MakeStringBytes(value_factory, test_case.lhs) - ->Compare(MakeExternalBytes(value_factory, test_case.rhs))), - test_case.compare); - EXPECT_EQ(NormalizeCompareResult( - MakeCordBytes(value_factory, test_case.lhs) - ->Compare(MakeStringBytes(value_factory, test_case.rhs))), - test_case.compare); - EXPECT_EQ(NormalizeCompareResult( - MakeCordBytes(value_factory, test_case.lhs) - ->Compare(MakeCordBytes(value_factory, test_case.rhs))), - test_case.compare); + MakeStringBytes(value_factory, test_case().lhs) + ->Compare(MakeCordBytes(value_factory, test_case().rhs))), + test_case().compare); + EXPECT_EQ( + NormalizeCompareResult( + MakeStringBytes(value_factory, test_case().lhs) + ->Compare(MakeExternalBytes(value_factory, test_case().rhs))), + test_case().compare); EXPECT_EQ(NormalizeCompareResult( - MakeCordBytes(value_factory, test_case.lhs) - ->Compare(MakeExternalBytes(value_factory, test_case.rhs))), - test_case.compare); + MakeCordBytes(value_factory, test_case().lhs) + ->Compare(MakeStringBytes(value_factory, test_case().rhs))), + test_case().compare); EXPECT_EQ(NormalizeCompareResult( - MakeExternalBytes(value_factory, test_case.lhs) - ->Compare(MakeStringBytes(value_factory, test_case.rhs))), - test_case.compare); + MakeCordBytes(value_factory, test_case().lhs) + ->Compare(MakeCordBytes(value_factory, test_case().rhs))), + test_case().compare); + EXPECT_EQ(NormalizeCompareResult(MakeCordBytes(value_factory, test_case().lhs) + ->Compare(MakeExternalBytes( + value_factory, test_case().rhs))), + test_case().compare); EXPECT_EQ(NormalizeCompareResult( - MakeExternalBytes(value_factory, test_case.lhs) - ->Compare(MakeCordBytes(value_factory, test_case.rhs))), - test_case.compare); + MakeExternalBytes(value_factory, test_case().lhs) + ->Compare(MakeStringBytes(value_factory, test_case().rhs))), + test_case().compare); EXPECT_EQ(NormalizeCompareResult( - MakeExternalBytes(value_factory, test_case.lhs) - ->Compare(MakeExternalBytes(value_factory, test_case.rhs))), - test_case.compare); -} - -INSTANTIATE_TEST_SUITE_P(BytesCompareTest, BytesCompareTest, - testing::ValuesIn({ - {"", "", 0}, - {"", std::string("\0", 1), -1}, - {std::string("\0", 1), "", 1}, - {std::string("\0", 1), std::string("\0", 1), 0}, - {"", "foo", -1}, - {"foo", "", 1}, - {"foo", "foo", 0}, - {"bar", "foo", -1}, - {"foo", "bar", 1}, - {"bar", "bar", 0}, - })); + MakeExternalBytes(value_factory, test_case().lhs) + ->Compare(MakeCordBytes(value_factory, test_case().rhs))), + test_case().compare); + EXPECT_EQ( + NormalizeCompareResult( + MakeExternalBytes(value_factory, test_case().lhs) + ->Compare(MakeExternalBytes(value_factory, test_case().rhs))), + test_case().compare); +} + +INSTANTIATE_TEST_SUITE_P( + BytesCompareTest, BytesCompareTest, + testing::Combine(base_internal::MemoryManagerTestModeAll(), + testing::ValuesIn({ + {"", "", 0}, + {"", std::string("\0", 1), -1}, + {std::string("\0", 1), "", 1}, + {std::string("\0", 1), std::string("\0", 1), 0}, + {"", "foo", -1}, + {"foo", "", 1}, + {"foo", "foo", 0}, + {"bar", "foo", -1}, + {"foo", "bar", 1}, + {"bar", "bar", 0}, + }))); struct BytesDebugStringTestCase final { std::string data; }; -using BytesDebugStringTest = testing::TestWithParam; +using BytesDebugStringTest = BaseValueTest; TEST_P(BytesDebugStringTest, ToCord) { - const BytesDebugStringTestCase& test_case = GetParam(); - ValueFactory value_factory(MemoryManager::Global()); - EXPECT_EQ(MakeStringBytes(value_factory, test_case.data)->DebugString(), - internal::FormatBytesLiteral(test_case.data)); - EXPECT_EQ(MakeCordBytes(value_factory, test_case.data)->DebugString(), - internal::FormatBytesLiteral(test_case.data)); - EXPECT_EQ(MakeExternalBytes(value_factory, test_case.data)->DebugString(), - internal::FormatBytesLiteral(test_case.data)); -} - -INSTANTIATE_TEST_SUITE_P(BytesDebugStringTest, BytesDebugStringTest, - testing::ValuesIn({ - {""}, - {"1"}, - {"foo"}, - {"\xef\xbf\xbd"}, - })); + ValueFactory value_factory(memory_manager()); + EXPECT_EQ(MakeStringBytes(value_factory, test_case().data)->DebugString(), + internal::FormatBytesLiteral(test_case().data)); + EXPECT_EQ(MakeCordBytes(value_factory, test_case().data)->DebugString(), + internal::FormatBytesLiteral(test_case().data)); + EXPECT_EQ(MakeExternalBytes(value_factory, test_case().data)->DebugString(), + internal::FormatBytesLiteral(test_case().data)); +} + +INSTANTIATE_TEST_SUITE_P( + BytesDebugStringTest, BytesDebugStringTest, + testing::Combine(base_internal::MemoryManagerTestModeAll(), + testing::ValuesIn({ + {""}, + {"1"}, + {"foo"}, + {"\xef\xbf\xbd"}, + }))); struct BytesToStringTestCase final { std::string data; }; -using BytesToStringTest = testing::TestWithParam; +using BytesToStringTest = BaseValueTest; TEST_P(BytesToStringTest, ToString) { - const BytesToStringTestCase& test_case = GetParam(); - ValueFactory value_factory(MemoryManager::Global()); - EXPECT_EQ(MakeStringBytes(value_factory, test_case.data)->ToString(), - test_case.data); - EXPECT_EQ(MakeCordBytes(value_factory, test_case.data)->ToString(), - test_case.data); - EXPECT_EQ(MakeExternalBytes(value_factory, test_case.data)->ToString(), - test_case.data); -} - -INSTANTIATE_TEST_SUITE_P(BytesToStringTest, BytesToStringTest, - testing::ValuesIn({ - {""}, - {"1"}, - {"foo"}, - {"\xef\xbf\xbd"}, - })); + ValueFactory value_factory(memory_manager()); + EXPECT_EQ(MakeStringBytes(value_factory, test_case().data)->ToString(), + test_case().data); + EXPECT_EQ(MakeCordBytes(value_factory, test_case().data)->ToString(), + test_case().data); + EXPECT_EQ(MakeExternalBytes(value_factory, test_case().data)->ToString(), + test_case().data); +} + +INSTANTIATE_TEST_SUITE_P( + BytesToStringTest, BytesToStringTest, + testing::Combine(base_internal::MemoryManagerTestModeAll(), + testing::ValuesIn({ + {""}, + {"1"}, + {"foo"}, + {"\xef\xbf\xbd"}, + }))); struct BytesToCordTestCase final { std::string data; }; -using BytesToCordTest = testing::TestWithParam; +using BytesToCordTest = BaseValueTest; TEST_P(BytesToCordTest, ToCord) { - const BytesToCordTestCase& test_case = GetParam(); - ValueFactory value_factory(MemoryManager::Global()); - EXPECT_EQ(MakeStringBytes(value_factory, test_case.data)->ToCord(), - test_case.data); - EXPECT_EQ(MakeCordBytes(value_factory, test_case.data)->ToCord(), - test_case.data); - EXPECT_EQ(MakeExternalBytes(value_factory, test_case.data)->ToCord(), - test_case.data); -} - -INSTANTIATE_TEST_SUITE_P(BytesToCordTest, BytesToCordTest, - testing::ValuesIn({ - {""}, - {"1"}, - {"foo"}, - {"\xef\xbf\xbd"}, - })); + ValueFactory value_factory(memory_manager()); + EXPECT_EQ(MakeStringBytes(value_factory, test_case().data)->ToCord(), + test_case().data); + EXPECT_EQ(MakeCordBytes(value_factory, test_case().data)->ToCord(), + test_case().data); + EXPECT_EQ(MakeExternalBytes(value_factory, test_case().data)->ToCord(), + test_case().data); +} + +INSTANTIATE_TEST_SUITE_P( + BytesToCordTest, BytesToCordTest, + testing::Combine(base_internal::MemoryManagerTestModeAll(), + testing::ValuesIn({ + {""}, + {"1"}, + {"foo"}, + {"\xef\xbf\xbd"}, + }))); Persistent MakeStringString(ValueFactory& value_factory, absl::string_view value) { @@ -1523,122 +1584,125 @@ struct StringConcatTestCase final { std::string rhs; }; -using StringConcatTest = testing::TestWithParam; +using StringConcatTest = BaseValueTest; TEST_P(StringConcatTest, Concat) { - const StringConcatTestCase& test_case = GetParam(); - ValueFactory value_factory(MemoryManager::Global()); - EXPECT_TRUE( - Must(StringValue::Concat(value_factory, - MakeStringString(value_factory, test_case.lhs), - MakeStringString(value_factory, test_case.rhs))) - ->Equals(test_case.lhs + test_case.rhs)); - EXPECT_TRUE( - Must(StringValue::Concat(value_factory, - MakeStringString(value_factory, test_case.lhs), - MakeCordString(value_factory, test_case.rhs))) - ->Equals(test_case.lhs + test_case.rhs)); + ValueFactory value_factory(memory_manager()); EXPECT_TRUE( Must(StringValue::Concat( - value_factory, MakeStringString(value_factory, test_case.lhs), - MakeExternalString(value_factory, test_case.rhs))) - ->Equals(test_case.lhs + test_case.rhs)); + value_factory, MakeStringString(value_factory, test_case().lhs), + MakeStringString(value_factory, test_case().rhs))) + ->Equals(test_case().lhs + test_case().rhs)); EXPECT_TRUE( Must(StringValue::Concat(value_factory, - MakeCordString(value_factory, test_case.lhs), - MakeStringString(value_factory, test_case.rhs))) - ->Equals(test_case.lhs + test_case.rhs)); - EXPECT_TRUE( - Must(StringValue::Concat(value_factory, - MakeCordString(value_factory, test_case.lhs), - MakeCordString(value_factory, test_case.rhs))) - ->Equals(test_case.lhs + test_case.rhs)); + MakeStringString(value_factory, test_case().lhs), + MakeCordString(value_factory, test_case().rhs))) + ->Equals(test_case().lhs + test_case().rhs)); EXPECT_TRUE( Must(StringValue::Concat( - value_factory, MakeCordString(value_factory, test_case.lhs), - MakeExternalString(value_factory, test_case.rhs))) - ->Equals(test_case.lhs + test_case.rhs)); + value_factory, MakeStringString(value_factory, test_case().lhs), + MakeExternalString(value_factory, test_case().rhs))) + ->Equals(test_case().lhs + test_case().rhs)); EXPECT_TRUE( - Must(StringValue::Concat(value_factory, - MakeExternalString(value_factory, test_case.lhs), - MakeStringString(value_factory, test_case.rhs))) - ->Equals(test_case.lhs + test_case.rhs)); + Must(StringValue::Concat( + value_factory, MakeCordString(value_factory, test_case().lhs), + MakeStringString(value_factory, test_case().rhs))) + ->Equals(test_case().lhs + test_case().rhs)); EXPECT_TRUE( Must(StringValue::Concat(value_factory, - MakeExternalString(value_factory, test_case.lhs), - MakeCordString(value_factory, test_case.rhs))) - ->Equals(test_case.lhs + test_case.rhs)); + MakeCordString(value_factory, test_case().lhs), + MakeCordString(value_factory, test_case().rhs))) + ->Equals(test_case().lhs + test_case().rhs)); EXPECT_TRUE( Must(StringValue::Concat( - value_factory, MakeExternalString(value_factory, test_case.lhs), - MakeExternalString(value_factory, test_case.rhs))) - ->Equals(test_case.lhs + test_case.rhs)); -} - -INSTANTIATE_TEST_SUITE_P(StringConcatTest, StringConcatTest, - testing::ValuesIn({ - {"", ""}, - {"", std::string("\0", 1)}, - {std::string("\0", 1), ""}, - {std::string("\0", 1), std::string("\0", 1)}, - {"", "foo"}, - {"foo", ""}, - {"foo", "foo"}, - {"bar", "foo"}, - {"foo", "bar"}, - {"bar", "bar"}, - })); + value_factory, MakeCordString(value_factory, test_case().lhs), + MakeExternalString(value_factory, test_case().rhs))) + ->Equals(test_case().lhs + test_case().rhs)); + EXPECT_TRUE(Must(StringValue::Concat( + value_factory, + MakeExternalString(value_factory, test_case().lhs), + MakeStringString(value_factory, test_case().rhs))) + ->Equals(test_case().lhs + test_case().rhs)); + EXPECT_TRUE(Must(StringValue::Concat( + value_factory, + MakeExternalString(value_factory, test_case().lhs), + MakeCordString(value_factory, test_case().rhs))) + ->Equals(test_case().lhs + test_case().rhs)); + EXPECT_TRUE(Must(StringValue::Concat( + value_factory, + MakeExternalString(value_factory, test_case().lhs), + MakeExternalString(value_factory, test_case().rhs))) + ->Equals(test_case().lhs + test_case().rhs)); +} + +INSTANTIATE_TEST_SUITE_P( + StringConcatTest, StringConcatTest, + testing::Combine(base_internal::MemoryManagerTestModeAll(), + testing::ValuesIn({ + {"", ""}, + {"", std::string("\0", 1)}, + {std::string("\0", 1), ""}, + {std::string("\0", 1), std::string("\0", 1)}, + {"", "foo"}, + {"foo", ""}, + {"foo", "foo"}, + {"bar", "foo"}, + {"foo", "bar"}, + {"bar", "bar"}, + }))); struct StringSizeTestCase final { std::string data; size_t size; }; -using StringSizeTest = testing::TestWithParam; +using StringSizeTest = BaseValueTest; TEST_P(StringSizeTest, Size) { - const StringSizeTestCase& test_case = GetParam(); - ValueFactory value_factory(MemoryManager::Global()); - EXPECT_EQ(MakeStringString(value_factory, test_case.data)->size(), - test_case.size); - EXPECT_EQ(MakeCordString(value_factory, test_case.data)->size(), - test_case.size); - EXPECT_EQ(MakeExternalString(value_factory, test_case.data)->size(), - test_case.size); -} - -INSTANTIATE_TEST_SUITE_P(StringSizeTest, StringSizeTest, - testing::ValuesIn({ - {"", 0}, - {"1", 1}, - {"foo", 3}, - {"\xef\xbf\xbd", 1}, - })); + ValueFactory value_factory(memory_manager()); + EXPECT_EQ(MakeStringString(value_factory, test_case().data)->size(), + test_case().size); + EXPECT_EQ(MakeCordString(value_factory, test_case().data)->size(), + test_case().size); + EXPECT_EQ(MakeExternalString(value_factory, test_case().data)->size(), + test_case().size); +} + +INSTANTIATE_TEST_SUITE_P( + StringSizeTest, StringSizeTest, + testing::Combine(base_internal::MemoryManagerTestModeAll(), + testing::ValuesIn({ + {"", 0}, + {"1", 1}, + {"foo", 3}, + {"\xef\xbf\xbd", 1}, + }))); struct StringEmptyTestCase final { std::string data; bool empty; }; -using StringEmptyTest = testing::TestWithParam; +using StringEmptyTest = BaseValueTest; TEST_P(StringEmptyTest, Empty) { - const StringEmptyTestCase& test_case = GetParam(); - ValueFactory value_factory(MemoryManager::Global()); - EXPECT_EQ(MakeStringString(value_factory, test_case.data)->empty(), - test_case.empty); - EXPECT_EQ(MakeCordString(value_factory, test_case.data)->empty(), - test_case.empty); - EXPECT_EQ(MakeExternalString(value_factory, test_case.data)->empty(), - test_case.empty); -} - -INSTANTIATE_TEST_SUITE_P(StringEmptyTest, StringEmptyTest, - testing::ValuesIn({ - {"", true}, - {std::string("\0", 1), false}, - {"1", false}, - })); + ValueFactory value_factory(memory_manager()); + EXPECT_EQ(MakeStringString(value_factory, test_case().data)->empty(), + test_case().empty); + EXPECT_EQ(MakeCordString(value_factory, test_case().data)->empty(), + test_case().empty); + EXPECT_EQ(MakeExternalString(value_factory, test_case().data)->empty(), + test_case().empty); +} + +INSTANTIATE_TEST_SUITE_P( + StringEmptyTest, StringEmptyTest, + testing::Combine(base_internal::MemoryManagerTestModeAll(), + testing::ValuesIn({ + {"", true}, + {std::string("\0", 1), false}, + {"1", false}, + }))); struct StringEqualsTestCase final { std::string lhs; @@ -1646,53 +1710,54 @@ struct StringEqualsTestCase final { bool equals; }; -using StringEqualsTest = testing::TestWithParam; +using StringEqualsTest = BaseValueTest; TEST_P(StringEqualsTest, Equals) { - const StringEqualsTestCase& test_case = GetParam(); - ValueFactory value_factory(MemoryManager::Global()); - EXPECT_EQ(MakeStringString(value_factory, test_case.lhs) - ->Equals(MakeStringString(value_factory, test_case.rhs)), - test_case.equals); - EXPECT_EQ(MakeStringString(value_factory, test_case.lhs) - ->Equals(MakeCordString(value_factory, test_case.rhs)), - test_case.equals); - EXPECT_EQ(MakeStringString(value_factory, test_case.lhs) - ->Equals(MakeExternalString(value_factory, test_case.rhs)), - test_case.equals); - EXPECT_EQ(MakeCordString(value_factory, test_case.lhs) - ->Equals(MakeStringString(value_factory, test_case.rhs)), - test_case.equals); - EXPECT_EQ(MakeCordString(value_factory, test_case.lhs) - ->Equals(MakeCordString(value_factory, test_case.rhs)), - test_case.equals); - EXPECT_EQ(MakeCordString(value_factory, test_case.lhs) - ->Equals(MakeExternalString(value_factory, test_case.rhs)), - test_case.equals); - EXPECT_EQ(MakeExternalString(value_factory, test_case.lhs) - ->Equals(MakeStringString(value_factory, test_case.rhs)), - test_case.equals); - EXPECT_EQ(MakeExternalString(value_factory, test_case.lhs) - ->Equals(MakeCordString(value_factory, test_case.rhs)), - test_case.equals); - EXPECT_EQ(MakeExternalString(value_factory, test_case.lhs) - ->Equals(MakeExternalString(value_factory, test_case.rhs)), - test_case.equals); -} - -INSTANTIATE_TEST_SUITE_P(StringEqualsTest, StringEqualsTest, - testing::ValuesIn({ - {"", "", true}, - {"", std::string("\0", 1), false}, - {std::string("\0", 1), "", false}, - {std::string("\0", 1), std::string("\0", 1), true}, - {"", "foo", false}, - {"foo", "", false}, - {"foo", "foo", true}, - {"bar", "foo", false}, - {"foo", "bar", false}, - {"bar", "bar", true}, - })); + ValueFactory value_factory(memory_manager()); + EXPECT_EQ(MakeStringString(value_factory, test_case().lhs) + ->Equals(MakeStringString(value_factory, test_case().rhs)), + test_case().equals); + EXPECT_EQ(MakeStringString(value_factory, test_case().lhs) + ->Equals(MakeCordString(value_factory, test_case().rhs)), + test_case().equals); + EXPECT_EQ(MakeStringString(value_factory, test_case().lhs) + ->Equals(MakeExternalString(value_factory, test_case().rhs)), + test_case().equals); + EXPECT_EQ(MakeCordString(value_factory, test_case().lhs) + ->Equals(MakeStringString(value_factory, test_case().rhs)), + test_case().equals); + EXPECT_EQ(MakeCordString(value_factory, test_case().lhs) + ->Equals(MakeCordString(value_factory, test_case().rhs)), + test_case().equals); + EXPECT_EQ(MakeCordString(value_factory, test_case().lhs) + ->Equals(MakeExternalString(value_factory, test_case().rhs)), + test_case().equals); + EXPECT_EQ(MakeExternalString(value_factory, test_case().lhs) + ->Equals(MakeStringString(value_factory, test_case().rhs)), + test_case().equals); + EXPECT_EQ(MakeExternalString(value_factory, test_case().lhs) + ->Equals(MakeCordString(value_factory, test_case().rhs)), + test_case().equals); + EXPECT_EQ(MakeExternalString(value_factory, test_case().lhs) + ->Equals(MakeExternalString(value_factory, test_case().rhs)), + test_case().equals); +} + +INSTANTIATE_TEST_SUITE_P( + StringEqualsTest, StringEqualsTest, + testing::Combine(base_internal::MemoryManagerTestModeAll(), + testing::ValuesIn({ + {"", "", true}, + {"", std::string("\0", 1), false}, + {std::string("\0", 1), "", false}, + {std::string("\0", 1), std::string("\0", 1), true}, + {"", "foo", false}, + {"foo", "", false}, + {"foo", "foo", true}, + {"bar", "foo", false}, + {"foo", "bar", false}, + {"bar", "bar", true}, + }))); struct StringCompareTestCase final { std::string lhs; @@ -1700,143 +1765,151 @@ struct StringCompareTestCase final { int compare; }; -using StringCompareTest = testing::TestWithParam; +using StringCompareTest = BaseValueTest; TEST_P(StringCompareTest, Equals) { - const StringCompareTestCase& test_case = GetParam(); - ValueFactory value_factory(MemoryManager::Global()); - EXPECT_EQ(NormalizeCompareResult( - MakeStringString(value_factory, test_case.lhs) - ->Compare(MakeStringString(value_factory, test_case.rhs))), - test_case.compare); - EXPECT_EQ(NormalizeCompareResult( - MakeStringString(value_factory, test_case.lhs) - ->Compare(MakeCordString(value_factory, test_case.rhs))), - test_case.compare); + ValueFactory value_factory(memory_manager()); EXPECT_EQ( NormalizeCompareResult( - MakeStringString(value_factory, test_case.lhs) - ->Compare(MakeExternalString(value_factory, test_case.rhs))), - test_case.compare); - EXPECT_EQ(NormalizeCompareResult( - MakeCordString(value_factory, test_case.lhs) - ->Compare(MakeStringString(value_factory, test_case.rhs))), - test_case.compare); + MakeStringString(value_factory, test_case().lhs) + ->Compare(MakeStringString(value_factory, test_case().rhs))), + test_case().compare); EXPECT_EQ(NormalizeCompareResult( - MakeCordString(value_factory, test_case.lhs) - ->Compare(MakeCordString(value_factory, test_case.rhs))), - test_case.compare); - EXPECT_EQ(NormalizeCompareResult(MakeCordString(value_factory, test_case.lhs) - ->Compare(MakeExternalString( - value_factory, test_case.rhs))), - test_case.compare); + MakeStringString(value_factory, test_case().lhs) + ->Compare(MakeCordString(value_factory, test_case().rhs))), + test_case().compare); + EXPECT_EQ( + NormalizeCompareResult( + MakeStringString(value_factory, test_case().lhs) + ->Compare(MakeExternalString(value_factory, test_case().rhs))), + test_case().compare); + EXPECT_EQ( + NormalizeCompareResult( + MakeCordString(value_factory, test_case().lhs) + ->Compare(MakeStringString(value_factory, test_case().rhs))), + test_case().compare); EXPECT_EQ(NormalizeCompareResult( - MakeExternalString(value_factory, test_case.lhs) - ->Compare(MakeStringString(value_factory, test_case.rhs))), - test_case.compare); + MakeCordString(value_factory, test_case().lhs) + ->Compare(MakeCordString(value_factory, test_case().rhs))), + test_case().compare); + EXPECT_EQ( + NormalizeCompareResult( + MakeCordString(value_factory, test_case().lhs) + ->Compare(MakeExternalString(value_factory, test_case().rhs))), + test_case().compare); + EXPECT_EQ( + NormalizeCompareResult( + MakeExternalString(value_factory, test_case().lhs) + ->Compare(MakeStringString(value_factory, test_case().rhs))), + test_case().compare); EXPECT_EQ(NormalizeCompareResult( - MakeExternalString(value_factory, test_case.lhs) - ->Compare(MakeCordString(value_factory, test_case.rhs))), - test_case.compare); + MakeExternalString(value_factory, test_case().lhs) + ->Compare(MakeCordString(value_factory, test_case().rhs))), + test_case().compare); EXPECT_EQ( NormalizeCompareResult( - MakeExternalString(value_factory, test_case.lhs) - ->Compare(MakeExternalString(value_factory, test_case.rhs))), - test_case.compare); -} - -INSTANTIATE_TEST_SUITE_P(StringCompareTest, StringCompareTest, - testing::ValuesIn({ - {"", "", 0}, - {"", std::string("\0", 1), -1}, - {std::string("\0", 1), "", 1}, - {std::string("\0", 1), std::string("\0", 1), 0}, - {"", "foo", -1}, - {"foo", "", 1}, - {"foo", "foo", 0}, - {"bar", "foo", -1}, - {"foo", "bar", 1}, - {"bar", "bar", 0}, - })); + MakeExternalString(value_factory, test_case().lhs) + ->Compare(MakeExternalString(value_factory, test_case().rhs))), + test_case().compare); +} + +INSTANTIATE_TEST_SUITE_P( + StringCompareTest, StringCompareTest, + testing::Combine(base_internal::MemoryManagerTestModeAll(), + testing::ValuesIn({ + {"", "", 0}, + {"", std::string("\0", 1), -1}, + {std::string("\0", 1), "", 1}, + {std::string("\0", 1), std::string("\0", 1), 0}, + {"", "foo", -1}, + {"foo", "", 1}, + {"foo", "foo", 0}, + {"bar", "foo", -1}, + {"foo", "bar", 1}, + {"bar", "bar", 0}, + }))); struct StringDebugStringTestCase final { std::string data; }; -using StringDebugStringTest = testing::TestWithParam; +using StringDebugStringTest = BaseValueTest; TEST_P(StringDebugStringTest, ToCord) { - const StringDebugStringTestCase& test_case = GetParam(); - ValueFactory value_factory(MemoryManager::Global()); - EXPECT_EQ(MakeStringString(value_factory, test_case.data)->DebugString(), - internal::FormatStringLiteral(test_case.data)); - EXPECT_EQ(MakeCordString(value_factory, test_case.data)->DebugString(), - internal::FormatStringLiteral(test_case.data)); - EXPECT_EQ(MakeExternalString(value_factory, test_case.data)->DebugString(), - internal::FormatStringLiteral(test_case.data)); -} - -INSTANTIATE_TEST_SUITE_P(StringDebugStringTest, StringDebugStringTest, - testing::ValuesIn({ - {""}, - {"1"}, - {"foo"}, - {"\xef\xbf\xbd"}, - })); + ValueFactory value_factory(memory_manager()); + EXPECT_EQ(MakeStringString(value_factory, test_case().data)->DebugString(), + internal::FormatStringLiteral(test_case().data)); + EXPECT_EQ(MakeCordString(value_factory, test_case().data)->DebugString(), + internal::FormatStringLiteral(test_case().data)); + EXPECT_EQ(MakeExternalString(value_factory, test_case().data)->DebugString(), + internal::FormatStringLiteral(test_case().data)); +} + +INSTANTIATE_TEST_SUITE_P( + StringDebugStringTest, StringDebugStringTest, + testing::Combine(base_internal::MemoryManagerTestModeAll(), + testing::ValuesIn({ + {""}, + {"1"}, + {"foo"}, + {"\xef\xbf\xbd"}, + }))); struct StringToStringTestCase final { std::string data; }; -using StringToStringTest = testing::TestWithParam; +using StringToStringTest = BaseValueTest; TEST_P(StringToStringTest, ToString) { - const StringToStringTestCase& test_case = GetParam(); - ValueFactory value_factory(MemoryManager::Global()); - EXPECT_EQ(MakeStringString(value_factory, test_case.data)->ToString(), - test_case.data); - EXPECT_EQ(MakeCordString(value_factory, test_case.data)->ToString(), - test_case.data); - EXPECT_EQ(MakeExternalString(value_factory, test_case.data)->ToString(), - test_case.data); -} - -INSTANTIATE_TEST_SUITE_P(StringToStringTest, StringToStringTest, - testing::ValuesIn({ - {""}, - {"1"}, - {"foo"}, - {"\xef\xbf\xbd"}, - })); + ValueFactory value_factory(memory_manager()); + EXPECT_EQ(MakeStringString(value_factory, test_case().data)->ToString(), + test_case().data); + EXPECT_EQ(MakeCordString(value_factory, test_case().data)->ToString(), + test_case().data); + EXPECT_EQ(MakeExternalString(value_factory, test_case().data)->ToString(), + test_case().data); +} + +INSTANTIATE_TEST_SUITE_P( + StringToStringTest, StringToStringTest, + testing::Combine(base_internal::MemoryManagerTestModeAll(), + testing::ValuesIn({ + {""}, + {"1"}, + {"foo"}, + {"\xef\xbf\xbd"}, + }))); struct StringToCordTestCase final { std::string data; }; -using StringToCordTest = testing::TestWithParam; +using StringToCordTest = BaseValueTest; TEST_P(StringToCordTest, ToCord) { - const StringToCordTestCase& test_case = GetParam(); - ValueFactory value_factory(MemoryManager::Global()); - EXPECT_EQ(MakeStringString(value_factory, test_case.data)->ToCord(), - test_case.data); - EXPECT_EQ(MakeCordString(value_factory, test_case.data)->ToCord(), - test_case.data); - EXPECT_EQ(MakeExternalString(value_factory, test_case.data)->ToCord(), - test_case.data); -} - -INSTANTIATE_TEST_SUITE_P(StringToCordTest, StringToCordTest, - testing::ValuesIn({ - {""}, - {"1"}, - {"foo"}, - {"\xef\xbf\xbd"}, - })); - -TEST(Value, Enum) { - ValueFactory value_factory(MemoryManager::Global()); - TypeFactory type_factory(MemoryManager::Global()); + ValueFactory value_factory(memory_manager()); + EXPECT_EQ(MakeStringString(value_factory, test_case().data)->ToCord(), + test_case().data); + EXPECT_EQ(MakeCordString(value_factory, test_case().data)->ToCord(), + test_case().data); + EXPECT_EQ(MakeExternalString(value_factory, test_case().data)->ToCord(), + test_case().data); +} + +INSTANTIATE_TEST_SUITE_P( + StringToCordTest, StringToCordTest, + testing::Combine(base_internal::MemoryManagerTestModeAll(), + testing::ValuesIn({ + {""}, + {"1"}, + {"foo"}, + {"\xef\xbf\xbd"}, + }))); + +TEST_P(ValueTest, Enum) { + ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); ASSERT_OK_AND_ASSIGN(auto enum_type, type_factory.CreateEnumType()); ASSERT_OK_AND_ASSIGN( @@ -1869,9 +1942,11 @@ TEST(Value, Enum) { EXPECT_NE(two_value, one_value); } -TEST(EnumType, NewInstance) { - ValueFactory value_factory(MemoryManager::Global()); - TypeFactory type_factory(MemoryManager::Global()); +using EnumTypeTest = ValueTest; + +TEST_P(EnumTypeTest, NewInstance) { + ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); ASSERT_OK_AND_ASSIGN(auto enum_type, type_factory.CreateEnumType()); ASSERT_OK_AND_ASSIGN( @@ -1896,9 +1971,13 @@ TEST(EnumType, NewInstance) { StatusIs(absl::StatusCode::kNotFound)); } -TEST(Value, Struct) { - ValueFactory value_factory(MemoryManager::Global()); - TypeFactory type_factory(MemoryManager::Global()); +INSTANTIATE_TEST_SUITE_P(EnumTypeTest, EnumTypeTest, + base_internal::MemoryManagerTestModeAll(), + base_internal::MemoryManagerTestModeTupleName); + +TEST_P(ValueTest, Struct) { + ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); ASSERT_OK_AND_ASSIGN(auto struct_type, type_factory.CreateStructType()); ASSERT_OK_AND_ASSIGN(auto zero_value, @@ -1935,9 +2014,11 @@ TEST(Value, Struct) { EXPECT_NE(one_value, zero_value); } -TEST(StructValue, SetField) { - ValueFactory value_factory(MemoryManager::Global()); - TypeFactory type_factory(MemoryManager::Global()); +using StructValueTest = ValueTest; + +TEST_P(StructValueTest, SetField) { + ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); ASSERT_OK_AND_ASSIGN(auto struct_type, type_factory.CreateStructType()); ASSERT_OK_AND_ASSIGN(auto struct_value, @@ -2012,9 +2093,9 @@ TEST(StructValue, SetField) { StatusIs(absl::StatusCode::kNotFound)); } -TEST(StructValue, GetField) { - ValueFactory value_factory(MemoryManager::Global()); - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(StructValueTest, GetField) { + ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); ASSERT_OK_AND_ASSIGN(auto struct_type, type_factory.CreateStructType()); ASSERT_OK_AND_ASSIGN(auto struct_value, @@ -2046,9 +2127,9 @@ TEST(StructValue, GetField) { StatusIs(absl::StatusCode::kNotFound)); } -TEST(StructValue, HasField) { - ValueFactory value_factory(MemoryManager::Global()); - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(StructValueTest, HasField) { + ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); ASSERT_OK_AND_ASSIGN(auto struct_type, type_factory.CreateStructType()); ASSERT_OK_AND_ASSIGN(auto struct_value, @@ -2075,9 +2156,13 @@ TEST(StructValue, HasField) { StatusIs(absl::StatusCode::kNotFound)); } -TEST(Value, List) { - ValueFactory value_factory(MemoryManager::Global()); - TypeFactory type_factory(MemoryManager::Global()); +INSTANTIATE_TEST_SUITE_P(StructValueTest, StructValueTest, + base_internal::MemoryManagerTestModeAll(), + base_internal::MemoryManagerTestModeTupleName); + +TEST_P(ValueTest, List) { + ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); ASSERT_OK_AND_ASSIGN(auto list_type, type_factory.CreateListType(type_factory.GetIntType())); ASSERT_OK_AND_ASSIGN(auto zero_value, @@ -2108,9 +2193,11 @@ TEST(Value, List) { EXPECT_NE(one_value, zero_value); } -TEST(ListValue, DebugString) { - ValueFactory value_factory(MemoryManager::Global()); - TypeFactory type_factory(MemoryManager::Global()); +using ListValueTest = ValueTest; + +TEST_P(ListValueTest, DebugString) { + ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); ASSERT_OK_AND_ASSIGN(auto list_type, type_factory.CreateListType(type_factory.GetIntType())); ASSERT_OK_AND_ASSIGN(auto list_value, @@ -2123,9 +2210,9 @@ TEST(ListValue, DebugString) { EXPECT_EQ(list_value->DebugString(), "[0, 1, 2, 3, 4, 5]"); } -TEST(ListValue, Get) { - ValueFactory value_factory(MemoryManager::Global()); - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(ListValueTest, Get) { + ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); ASSERT_OK_AND_ASSIGN(auto list_type, type_factory.CreateListType(type_factory.GetIntType())); ASSERT_OK_AND_ASSIGN(auto list_value, @@ -2149,9 +2236,13 @@ TEST(ListValue, Get) { StatusIs(absl::StatusCode::kOutOfRange)); } -TEST(Value, Map) { - ValueFactory value_factory(MemoryManager::Global()); - TypeFactory type_factory(MemoryManager::Global()); +INSTANTIATE_TEST_SUITE_P(ListValueTest, ListValueTest, + base_internal::MemoryManagerTestModeAll(), + base_internal::MemoryManagerTestModeTupleName); + +TEST_P(ValueTest, Map) { + ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); ASSERT_OK_AND_ASSIGN(auto map_type, type_factory.CreateMapType(type_factory.GetStringType(), type_factory.GetIntType())); @@ -2186,9 +2277,11 @@ TEST(Value, Map) { EXPECT_NE(one_value, zero_value); } -TEST(MapValue, DebugString) { - ValueFactory value_factory(MemoryManager::Global()); - TypeFactory type_factory(MemoryManager::Global()); +using MapValueTest = ValueTest; + +TEST_P(MapValueTest, DebugString) { + ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); ASSERT_OK_AND_ASSIGN(auto map_type, type_factory.CreateMapType(type_factory.GetStringType(), type_factory.GetIntType())); @@ -2203,9 +2296,9 @@ TEST(MapValue, DebugString) { EXPECT_EQ(map_value->DebugString(), "{\"bar\": 2, \"baz\": 3, \"foo\": 1}"); } -TEST(MapValue, GetAndHas) { - ValueFactory value_factory(MemoryManager::Global()); - TypeFactory type_factory(MemoryManager::Global()); +TEST_P(MapValueTest, GetAndHas) { + ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); ASSERT_OK_AND_ASSIGN(auto map_type, type_factory.CreateMapType(type_factory.GetStringType(), type_factory.GetIntType())); @@ -2245,9 +2338,13 @@ TEST(MapValue, GetAndHas) { IsOkAndHolds(false)); } -TEST(Value, SupportsAbslHash) { - ValueFactory value_factory(MemoryManager::Global()); - TypeFactory type_factory(MemoryManager::Global()); +INSTANTIATE_TEST_SUITE_P(MapValueTest, MapValueTest, + base_internal::MemoryManagerTestModeAll(), + base_internal::MemoryManagerTestModeTupleName); + +TEST_P(ValueTest, SupportsAbslHash) { + ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); ASSERT_OK_AND_ASSIGN(auto enum_type, type_factory.CreateEnumType()); ASSERT_OK_AND_ASSIGN(auto struct_type, @@ -2295,5 +2392,9 @@ TEST(Value, SupportsAbslHash) { })); } +INSTANTIATE_TEST_SUITE_P(ValueTest, ValueTest, + base_internal::MemoryManagerTestModeAll(), + base_internal::MemoryManagerTestModeTupleName); + } // namespace } // namespace cel From a5774cc804b3a6e1aab1284e0cb4563e38063d7d Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Tue, 3 May 2022 22:16:35 +0000 Subject: [PATCH 091/111] Update portable expr builder test to build with lite proto, using an example type provider and legacy type adapters. PiperOrigin-RevId: 446302462 --- eval/eval/BUILD | 2 - eval/eval/const_value_step.cc | 1 - eval/eval/select_step.cc | 3 +- eval/public/BUILD | 17 +- .../portable_cel_expr_builder_factory.cc | 4 +- .../portable_cel_expr_builder_factory_test.cc | 562 +++++++++++++++++- 6 files changed, 569 insertions(+), 20 deletions(-) diff --git a/eval/eval/BUILD b/eval/eval/BUILD index 74d387b61..6a0c7659b 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -87,7 +87,6 @@ cc_library( ":evaluator_core", ":expression_step_base", "//eval/public:cel_value", - "//eval/public/structs:cel_proto_wrapper", "//internal:proto_time_encoding", "@com_google_absl//absl/status:statusor", "@com_google_protobuf//:protobuf", @@ -226,7 +225,6 @@ cc_library( ":expression_step_base", "//eval/public:cel_value", "//eval/public/containers:container_backed_map_impl", - "//eval/public/structs:cel_proto_wrapper", "//internal:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/eval/eval/const_value_step.cc b/eval/eval/const_value_step.cc index 33bac528b..067ac6054 100644 --- a/eval/eval/const_value_step.cc +++ b/eval/eval/const_value_step.cc @@ -6,7 +6,6 @@ #include "google/protobuf/timestamp.pb.h" #include "absl/status/statusor.h" #include "eval/eval/expression_step_base.h" -#include "eval/public/structs/cel_proto_wrapper.h" #include "internal/proto_time_encoding.h" namespace google::api::expr::runtime { diff --git a/eval/eval/select_step.cc b/eval/eval/select_step.cc index cbb5d2751..8a7e95dd8 100644 --- a/eval/eval/select_step.cc +++ b/eval/eval/select_step.cc @@ -180,7 +180,8 @@ absl::Status SelectStep::Evaluate(ExecutionFrame* frame) const { break; } case CelValue::Type::kMessage: { - if (arg.MessageOrDie() == nullptr) { + if (CelValue::MessageWrapper w; + arg.GetValue(&w) && w.message_ptr() == nullptr) { frame->value_stack().PopAndPush( CreateErrorValue(frame->memory_manager(), "Message is NULL"), result_trail); diff --git a/eval/public/BUILD b/eval/public/BUILD index 80e0c4bef..2118320c9 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -942,11 +942,22 @@ cc_test( name = "portable_cel_expr_builder_factory_test", srcs = ["portable_cel_expr_builder_factory_test.cc"], deps = [ - ":builtin_func_registrar", + ":activation", + ":cel_options", + ":cel_value", ":portable_cel_expr_builder_factory", - "//eval/public/structs:cel_proto_descriptor_pool_builder", - "//eval/public/structs:protobuf_descriptor_type_provider", + "//eval/public/structs:legacy_type_adapter", + "//eval/public/structs:legacy_type_info_apis", + "//eval/public/structs:legacy_type_provider", + "//eval/testutil:test_message_cc_proto", + "//internal:casts", + "//internal:proto_time_encoding", "//internal:testing", + "//parser", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf", ], ) diff --git a/eval/public/portable_cel_expr_builder_factory.cc b/eval/public/portable_cel_expr_builder_factory.cc index 30320b48b..268bd1b35 100644 --- a/eval/public/portable_cel_expr_builder_factory.cc +++ b/eval/public/portable_cel_expr_builder_factory.cc @@ -31,10 +31,10 @@ std::unique_ptr CreatePortableExprBuilder( const InterpreterOptions& options) { if (type_provider == nullptr) { GOOGLE_LOG(ERROR) << "Cannot pass nullptr as type_provider to " - "CreateProtoLiteExprBuilder"; + "CreatePortableExprBuilder"; return nullptr; } - auto builder = absl::make_unique(); + auto builder = std::make_unique(); builder->GetTypeRegistry()->RegisterTypeProvider(std::move(type_provider)); // LINT.IfChange builder->set_shortcircuiting(options.short_circuiting); diff --git a/eval/public/portable_cel_expr_builder_factory_test.cc b/eval/public/portable_cel_expr_builder_factory_test.cc index 5382647f1..a2b7e54ba 100644 --- a/eval/public/portable_cel_expr_builder_factory_test.cc +++ b/eval/public/portable_cel_expr_builder_factory_test.cc @@ -14,18 +14,468 @@ #include "eval/public/portable_cel_expr_builder_factory.h" +#include +#include +#include #include +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "google/protobuf/wrappers.pb.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/dynamic_message.h" -#include "eval/public/builtin_func_registrar.h" -#include "eval/public/structs/cel_proto_descriptor_pool_builder.h" -#include "eval/public/structs/protobuf_descriptor_type_provider.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/optional.h" +#include "eval/public/activation.h" +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" +#include "eval/public/structs/legacy_type_adapter.h" +#include "eval/public/structs/legacy_type_info_apis.h" +#include "eval/public/structs/legacy_type_provider.h" +#include "eval/testutil/test_message.pb.h" +#include "internal/casts.h" +#include "internal/proto_time_encoding.h" #include "internal/testing.h" +#include "parser/parser.h" namespace google::api::expr::runtime { namespace { +using ::google::protobuf::Int64Value; + +// Helpers for c++ / proto to cel value conversions. +std::optional Unwrap(const CelValue::MessageWrapper& wrapper) { + if (wrapper.message_ptr()->GetTypeName() == "google.protobuf.Duration") { + const auto* duration = + cel::internal::down_cast( + wrapper.message_ptr()); + return CelValue::CreateDuration(cel::internal::DecodeDuration(*duration)); + } else if (wrapper.message_ptr()->GetTypeName() == + "google.protobuf.Timestamp") { + const auto* timestamp = + cel::internal::down_cast( + wrapper.message_ptr()); + return CelValue::CreateTimestamp(cel::internal::DecodeTime(*timestamp)); + } + return std::nullopt; +} + +struct NativeToCelValue { + template + std::optional Convert(T arg) const { + return std::nullopt; + } + + std::optional Convert(int64_t v) const { + return CelValue::CreateInt64(v); + } + + std::optional Convert(const std::string& str) const { + return CelValue::CreateString(&str); + } + + std::optional Convert(double v) const { + return CelValue::CreateDouble(v); + } + + std::optional Convert(bool v) const { + return CelValue::CreateBool(v); + } + + std::optional Convert(const Int64Value& v) const { + return CelValue::CreateInt64(v.value()); + } +}; + +template +class FieldImpl; + +template +class ProtoField { + public: + template + using FieldImpl = FieldImpl; + + virtual ~ProtoField() = default; + virtual absl::Status Set(MessageT* m, CelValue v) const = 0; + virtual absl::StatusOr Get(const MessageT* m) const = 0; + virtual bool Has(const MessageT* m) const = 0; +}; + +// template helpers for wrapping member accessors generically. +template +struct ScalarApiWrap { + using GetFn = FieldT (MessageT::*)() const; + using HasFn = bool (MessageT::*)() const; + using SetFn = void (MessageT::*)(FieldT); + + ScalarApiWrap(GetFn get_fn, HasFn has_fn, SetFn set_fn) + : get_fn(get_fn), has_fn(has_fn), set_fn(set_fn) {} + + FieldT InvokeGet(const MessageT* msg) const { + return std::invoke(get_fn, msg); + } + bool InvokeHas(const MessageT* msg) const { + if (has_fn == nullptr) return true; + return std::invoke(has_fn, msg); + } + void InvokeSet(MessageT* msg, FieldT arg) const { + if (set_fn != nullptr) { + std::invoke(set_fn, msg, arg); + } + } + + GetFn get_fn; + HasFn has_fn; + SetFn set_fn; +}; + +template +struct ComplexTypeApiWrap { + public: + using GetFn = const FieldT& (MessageT::*)() const; + using HasFn = bool (MessageT::*)() const; + using SetAllocatedFn = void (MessageT::*)(FieldT*); + + ComplexTypeApiWrap(GetFn get_fn, HasFn has_fn, + SetAllocatedFn set_allocated_fn) + : get_fn(get_fn), has_fn(has_fn), set_allocated_fn(set_allocated_fn) {} + + const FieldT& InvokeGet(const MessageT* msg) const { + return std::invoke(get_fn, msg); + } + bool InvokeHas(const MessageT* msg) const { + if (has_fn == nullptr) return true; + return std::invoke(has_fn, msg); + } + + void InvokeSetAllocated(MessageT* msg, FieldT* arg) const { + if (set_allocated_fn != nullptr) { + std::invoke(set_allocated_fn, msg, arg); + } + } + + GetFn get_fn; + HasFn has_fn; + SetAllocatedFn set_allocated_fn; +}; + +template +class FieldImpl : public ProtoField { + private: + using ApiWrap = ScalarApiWrap; + + public: + FieldImpl(typename ApiWrap::GetFn get_fn, typename ApiWrap::HasFn has_fn, + typename ApiWrap::SetFn set_fn) + : api_wrapper_(get_fn, has_fn, set_fn) {} + absl::Status Set(TestMessage* m, CelValue v) const override { + FieldT arg; + if (!v.GetValue(&arg)) { + return absl::InvalidArgumentError("wrong type for set"); + } + api_wrapper_.InvokeSet(m, arg); + return absl::OkStatus(); + } + + absl::StatusOr Get(const TestMessage* m) const override { + FieldT result = api_wrapper_.InvokeGet(m); + auto converted = NativeToCelValue().Convert(result); + if (converted.has_value()) { + return *converted; + } + return absl::UnimplementedError("not implemented for type"); + } + + bool Has(const TestMessage* m) const override { + return api_wrapper_.InvokeHas(m); + } + + private: + ApiWrap api_wrapper_; +}; + +template +class FieldImpl : public ProtoField { + using ApiWrap = ComplexTypeApiWrap; + + public: + FieldImpl(typename ApiWrap::GetFn get_fn, typename ApiWrap::HasFn has_fn, + typename ApiWrap::SetAllocatedFn set_fn) + : api_wrapper_(get_fn, has_fn, set_fn) {} + absl::Status Set(TestMessage* m, CelValue v) const override { + int64_t arg; + if (!v.GetValue(&arg)) { + return absl::InvalidArgumentError("wrong type for set"); + } + Int64Value* proto_value = new Int64Value(); + proto_value->set_value(arg); + api_wrapper_.InvokeSetAllocated(m, proto_value); + return absl::OkStatus(); + } + + absl::StatusOr Get(const TestMessage* m) const override { + if (!api_wrapper_.InvokeHas(m)) { + return CelValue::CreateNull(); + } + Int64Value result = api_wrapper_.InvokeGet(m); + auto converted = NativeToCelValue().Convert(std::move(result)); + if (converted.has_value()) { + return *converted; + } + return absl::UnimplementedError("not implemented for type"); + } + + bool Has(const TestMessage* m) const override { + return api_wrapper_.InvokeHas(m); + } + + private: + ApiWrap api_wrapper_; +}; + +// Simple type system for Testing. +class DemoTypeProvider; + +class DemoTimestamp : public LegacyTypeMutationApis { + public: + DemoTimestamp() {} + bool DefinesField(absl::string_view field_name) const override { + return field_name == "seconds" || field_name == "nanos"; + } + + absl::StatusOr NewInstance( + cel::MemoryManager& memory_manager) const override; + + absl::StatusOr AdaptFromWellKnownType( + cel::MemoryManager& memory_manager, + CelValue::MessageWrapper instance) const override; + + absl::Status SetField(absl::string_view field_name, const CelValue& value, + cel::MemoryManager& memory_manager, + CelValue::MessageWrapper& instance) const override; + + private: + absl::Status Validate(const CelValue::MessageWrapper& wrapped_message) const { + if (wrapped_message.message_ptr()->GetTypeName() != + "google.protobuf.Timestamp") { + return absl::InvalidArgumentError("not a timestamp"); + } + return absl::OkStatus(); + } +}; + +class DemoTypeInfo : public LegacyTypeInfoApis { + public: + explicit DemoTypeInfo(const DemoTypeProvider* owning_provider) + : owning_provider_(*owning_provider) {} + std::string DebugString( + const internal::MessageWrapper& wrapped_message) const override; + + const std::string& GetTypename( + const internal::MessageWrapper& wrapped_message) const override; + + const LegacyTypeAccessApis* GetAccessApis( + const internal::MessageWrapper& wrapped_message) const override; + + private: + const DemoTypeProvider& owning_provider_; +}; + +class DemoTestMessage : public LegacyTypeMutationApis, + public LegacyTypeAccessApis { + public: + explicit DemoTestMessage(const DemoTypeProvider* owning_provider); + + bool DefinesField(absl::string_view field_name) const override { + return fields_.contains(field_name); + } + + absl::StatusOr NewInstance( + cel::MemoryManager& memory_manager) const override; + + absl::StatusOr AdaptFromWellKnownType( + cel::MemoryManager& memory_manager, + CelValue::MessageWrapper instance) const override; + + absl::Status SetField(absl::string_view field_name, const CelValue& value, + cel::MemoryManager& memory_manager, + CelValue::MessageWrapper& instance) const override; + + absl::StatusOr HasField( + absl::string_view field_name, + const CelValue::MessageWrapper& value) const override; + + absl::StatusOr GetField( + absl::string_view field_name, const CelValue::MessageWrapper& instance, + ProtoWrapperTypeOptions unboxing_option, + cel::MemoryManager& memory_manager) const override; + + private: + using Field = ProtoField; + const DemoTypeProvider& owning_provider_; + absl::flat_hash_map> fields_; +}; + +class DemoTypeProvider : public LegacyTypeProvider { + public: + DemoTypeProvider() : timestamp_type_(), test_message_(this), info_(this) {} + const LegacyTypeInfoApis* GetTypeInfoInstance() const { return &info_; } + + std::optional ProvideLegacyType( + absl::string_view name) const override { + if (name == "google.protobuf.Timestamp") { + return LegacyTypeAdapter(nullptr, ×tamp_type_); + } else if (name == "google.api.expr.runtime.TestMessage") { + return LegacyTypeAdapter(&test_message_, &test_message_); + } + return std::nullopt; + } + + const std::string& GetStableType( + const google::protobuf::MessageLite* wrapped_message) const { + std::string name = wrapped_message->GetTypeName(); + auto [iter, inserted] = stable_types_.insert(name); + return *iter; + } + + CelValue WrapValue(const google::protobuf::MessageLite* message) const { + return CelValue::CreateMessageWrapper( + CelValue::MessageWrapper(message, GetTypeInfoInstance())); + } + + private: + DemoTimestamp timestamp_type_; + DemoTestMessage test_message_; + DemoTypeInfo info_; + mutable absl::node_hash_set stable_types_; // thread hostile +}; + +std::string DemoTypeInfo::DebugString( + const internal::MessageWrapper& wrapped_message) const { + return wrapped_message.message_ptr()->GetTypeName(); +} + +const std::string& DemoTypeInfo::GetTypename( + const internal::MessageWrapper& wrapped_message) const { + return owning_provider_.GetStableType(wrapped_message.message_ptr()); +} + +const LegacyTypeAccessApis* DemoTypeInfo::GetAccessApis( + const internal::MessageWrapper& wrapped_message) const { + auto adapter = owning_provider_.ProvideLegacyType( + wrapped_message.message_ptr()->GetTypeName()); + if (adapter.has_value()) { + return adapter->access_apis(); + } + return nullptr; // not implemented yet. +} + +absl::StatusOr DemoTimestamp::NewInstance( + cel::MemoryManager& memory_manager) const { + auto ts = memory_manager.New(); + return CelValue::MessageWrapper(ts.release(), nullptr); +} +absl::StatusOr DemoTimestamp::AdaptFromWellKnownType( + cel::MemoryManager& memory_manager, + CelValue::MessageWrapper instance) const { + return *Unwrap(instance); +} + +absl::Status DemoTimestamp::SetField(absl::string_view field_name, + const CelValue& value, + cel::MemoryManager& memory_manager, + CelValue::MessageWrapper& instance) const { + ABSL_ASSERT(Validate(instance).ok()); + const auto* const_ts = + cel::internal::down_cast( + instance.message_ptr()); + auto* mutable_ts = const_cast(const_ts); + if (field_name == "seconds" && value.IsInt64()) { + mutable_ts->set_seconds(value.Int64OrDie()); + } else if (field_name == "nanos" && value.IsInt64()) { + mutable_ts->set_nanos(value.Int64OrDie()); + } else { + return absl::UnknownError("no such field"); + } + return absl::OkStatus(); +} + +DemoTestMessage::DemoTestMessage(const DemoTypeProvider* owning_provider) + : owning_provider_(*owning_provider) { + // Note: has for non-optional scalars on proto3 messages would be implemented + // as msg.value() != MessageType::default_instance.value(), but omited for + // brevity. + fields_["int64_value"] = std::make_unique>( + &TestMessage::int64_value, + /*has_fn=*/nullptr, &TestMessage::set_int64_value); + fields_["double_value"] = std::make_unique>( + &TestMessage::double_value, + /*has_fn=*/nullptr, &TestMessage::set_double_value); + fields_["bool_value"] = std::make_unique>( + &TestMessage::bool_value, + /*has_fn=*/nullptr, &TestMessage::set_bool_value); + fields_["int64_wrapper_value"] = + std::make_unique>( + &TestMessage::int64_wrapper_value, + &TestMessage::has_int64_wrapper_value, + &TestMessage::set_allocated_int64_wrapper_value); +} + +absl::StatusOr DemoTestMessage::NewInstance( + cel::MemoryManager& memory_manager) const { + auto ts = memory_manager.New(); + return CelValue::MessageWrapper(ts.release(), + owning_provider_.GetTypeInfoInstance()); +} + +absl::Status DemoTestMessage::SetField( + absl::string_view field_name, const CelValue& value, + cel::MemoryManager& memory_manager, + CelValue::MessageWrapper& instance) const { + auto iter = fields_.find(field_name); + if (iter == fields_.end()) { + return absl::UnknownError("no such field"); + } + auto* test_msg = + cel::internal::down_cast(instance.message_ptr()); + auto* mutable_test_msg = const_cast(test_msg); + return iter->second->Set(mutable_test_msg, value); +} + +absl::StatusOr DemoTestMessage::AdaptFromWellKnownType( + cel::MemoryManager& memory_manager, + CelValue::MessageWrapper instance) const { + return CelValue::CreateMessageWrapper(instance); +} + +absl::StatusOr DemoTestMessage::HasField( + absl::string_view field_name, const CelValue::MessageWrapper& value) const { + auto iter = fields_.find(field_name); + if (iter == fields_.end()) { + return absl::UnknownError("no such field"); + } + auto* test_msg = + cel::internal::down_cast(value.message_ptr()); + return iter->second->Has(test_msg); +} + +// Access field on instance. +absl::StatusOr DemoTestMessage::GetField( + absl::string_view field_name, const CelValue::MessageWrapper& instance, + ProtoWrapperTypeOptions unboxing_option, + cel::MemoryManager& memory_manager) const { + auto iter = fields_.find(field_name); + if (iter == fields_.end()) { + return absl::UnknownError("no such field"); + } + auto* test_msg = + cel::internal::down_cast(instance.message_ptr()); + return iter->second->Get(test_msg); +} + TEST(PortableCelExprBuilderFactoryTest, CreateNullOnMissingTypeProvider) { std::unique_ptr builder = CreatePortableExprBuilder(nullptr); @@ -33,17 +483,107 @@ TEST(PortableCelExprBuilderFactoryTest, CreateNullOnMissingTypeProvider) { } TEST(PortableCelExprBuilderFactoryTest, CreateSuccess) { - google::protobuf::DescriptorPool descriptor_pool; google::protobuf::Arena arena; - // Setup descriptor pool and builder - ASSERT_OK(AddStandardMessageTypesToDescriptorPool(descriptor_pool)); - google::protobuf::DynamicMessageFactory message_factory(&descriptor_pool); - auto type_provider = std::make_unique( - &descriptor_pool, &message_factory); + InterpreterOptions opts; + Activation activation; + std::unique_ptr builder = + CreatePortableExprBuilder(std::make_unique(), opts); + ASSERT_OK_AND_ASSIGN( + ParsedExpr expr, + parser::Parse("google.protobuf.Timestamp{seconds: 3000, nanos: 20}")); + // TODO(issues/5): make builtin functions portable + // ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + + ASSERT_OK_AND_ASSIGN( + auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena)); + + absl::Time result_time; + ASSERT_TRUE(result.GetValue(&result_time)); + EXPECT_EQ(result_time, + absl::UnixEpoch() + absl::Minutes(50) + absl::Nanoseconds(20)); +} + +TEST(PortableCelExprBuilderFactoryTest, CreateCustomMessage) { + google::protobuf::Arena arena; + + InterpreterOptions opts; + Activation activation; + std::unique_ptr builder = + CreatePortableExprBuilder(std::make_unique(), opts); + ASSERT_OK_AND_ASSIGN( + ParsedExpr expr, + parser::Parse("google.api.expr.runtime.TestMessage{int64_value: 20, " + "double_value: 3.5}.double_value")); + + ASSERT_OK_AND_ASSIGN( + auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); + + ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena)); + + double result_double; + ASSERT_TRUE(result.GetValue(&result_double)) << result.DebugString(); + EXPECT_EQ(result_double, 3.5); +} + +TEST(PortableCelExprBuilderFactoryTest, ActivationAndCreate) { + google::protobuf::Arena arena; + + InterpreterOptions opts; + Activation activation; + auto provider = std::make_unique(); + auto* provider_view = provider.get(); std::unique_ptr builder = - CreatePortableExprBuilder(std::move(type_provider)); - ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + CreatePortableExprBuilder(std::move(provider), opts); + builder->set_container("google.api.expr.runtime"); + ASSERT_OK_AND_ASSIGN( + ParsedExpr expr, + parser::Parse("TestMessage{int64_value: 20, bool_value: " + "false}.bool_value || my_var.bool_value ? 1 : 2")); + + ASSERT_OK_AND_ASSIGN( + auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); + TestMessage my_var; + my_var.set_bool_value(true); + activation.InsertValue("my_var", provider_view->WrapValue(&my_var)); + ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena)); + + int64_t result_int64; + ASSERT_TRUE(result.GetValue(&result_int64)) << result.DebugString(); + EXPECT_EQ(result_int64, 1); +} + +TEST(PortableCelExprBuilderFactoryTest, WrapperTypes) { + google::protobuf::Arena arena; + InterpreterOptions opts; + Activation activation; + auto provider = std::make_unique(); + const auto* provider_view = provider.get(); + std::unique_ptr builder = + CreatePortableExprBuilder(std::move(provider), opts); + builder->set_container("google.api.expr.runtime"); + ASSERT_OK_AND_ASSIGN(ParsedExpr null_expr, + parser::Parse("my_var.int64_wrapper_value")); + + TestMessage my_var; + my_var.set_bool_value(true); + activation.InsertValue("my_var", provider_view->WrapValue(&my_var)); + + ASSERT_OK_AND_ASSIGN( + auto plan, + builder->CreateExpression(&null_expr.expr(), &null_expr.source_info())); + ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena)); + + EXPECT_TRUE(result.IsNull()) << result.DebugString(); + + my_var.mutable_int64_wrapper_value()->set_value(30); + + ASSERT_OK_AND_ASSIGN(result, plan->Evaluate(activation, &arena)); + int64_t result_int64; + ASSERT_TRUE(result.GetValue(&result_int64)) << result.DebugString(); + EXPECT_EQ(result_int64, 30); } } // namespace From f382a37a76bb9c7df0383e80252bca1390f49450 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Tue, 3 May 2022 22:26:19 +0000 Subject: [PATCH 092/111] Update type registry to lookup CelType values using the registered type adapters instead of directly consulting a proto DescriptorPool. PiperOrigin-RevId: 446304673 --- eval/compiler/BUILD | 2 -- eval/compiler/flat_expr_builder.cc | 2 +- eval/compiler/flat_expr_builder.h | 3 --- eval/compiler/flat_expr_builder_test.cc | 12 ++++----- eval/compiler/resolver_test.cc | 13 +++++++-- eval/public/BUILD | 2 ++ eval/public/cel_expr_builder_factory.cc | 2 +- eval/public/cel_expression.h | 5 ---- eval/public/cel_type_registry.cc | 35 +++++++++---------------- eval/public/cel_type_registry.h | 12 +++------ eval/public/cel_type_registry_test.cc | 16 +++++------ 11 files changed, 45 insertions(+), 59 deletions(-) diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index 827d82e03..e7ee05866 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -38,14 +38,12 @@ cc_library( "//eval/public:cel_expression", "//eval/public:cel_function_registry", "//eval/public:source_position", - "//internal:status_macros", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", - "@com_google_protobuf//:protobuf", ], ) diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index 9f9450f9f..999d03ad8 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -19,11 +19,11 @@ #include #include #include +#include #include #include #include "google/api/expr/v1alpha1/checked.pb.h" -#include "stack" #include "absl/container/node_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" diff --git a/eval/compiler/flat_expr_builder.h b/eval/compiler/flat_expr_builder.h index dee1cc189..471ddec2d 100644 --- a/eval/compiler/flat_expr_builder.h +++ b/eval/compiler/flat_expr_builder.h @@ -30,9 +30,6 @@ class FlatExprBuilder : public CelExpressionBuilder { public: FlatExprBuilder() : CelExpressionBuilder() {} - explicit FlatExprBuilder(const google::protobuf::DescriptorPool* descriptor_pool) - : CelExpressionBuilder(descriptor_pool) {} - // set_enable_unknowns controls support for unknowns in expressions created. void set_enable_unknowns(bool enabled) { enable_unknowns_ = enabled; } diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index c2cbd4218..c6aae9715 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -1808,7 +1808,7 @@ TEST(FlatExprBuilderTest, CustomDescriptorPoolForCreateStruct) { // This time, the message is unknown. We only have the proto as data, we did // not link the generated message, so it's not included in the generated pool. - FlatExprBuilder builder(google::protobuf::DescriptorPool::generated_pool()); + FlatExprBuilder builder; builder.GetTypeRegistry()->RegisterTypeProvider( std::make_unique( google::protobuf::DescriptorPool::generated_pool(), @@ -1831,7 +1831,7 @@ TEST(FlatExprBuilderTest, CustomDescriptorPoolForCreateStruct) { // This time, the message is *known*. We are using a custom descriptor pool // that has been primed with the relevant message. - FlatExprBuilder builder2(&desc_pool); + FlatExprBuilder builder2; builder2.GetTypeRegistry()->RegisterTypeProvider( std::make_unique(&desc_pool, &message_factory)); @@ -1871,9 +1871,9 @@ TEST(FlatExprBuilderTest, CustomDescriptorPoolForSelect) { const google::protobuf::FieldDescriptor* field = desc->FindFieldByName("int64_value"); refl->SetInt64(message, field, 123); - // This time, the message is *known*. We are using a custom descriptor pool - // that has been primed with the relevant message. - FlatExprBuilder builder(&desc_pool); + // The since this is access only, the evaluator will work with message duck + // typing. + FlatExprBuilder builder; ASSERT_OK_AND_ASSIGN(auto expression, builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info())); @@ -1923,7 +1923,7 @@ TEST_P(CustomDescriptorPoolTest, TestType) { ASSERT_OK(AddStandardMessageTypesToDescriptorPool(descriptor_pool)); google::protobuf::DynamicMessageFactory message_factory(&descriptor_pool); ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("m")); - FlatExprBuilder builder(&descriptor_pool); + FlatExprBuilder builder; builder.GetTypeRegistry()->RegisterTypeProvider( std::make_unique(&descriptor_pool, &message_factory)); diff --git a/eval/compiler/resolver_test.cc b/eval/compiler/resolver_test.cc index 8ecfab760..b3346d436 100644 --- a/eval/compiler/resolver_test.cc +++ b/eval/compiler/resolver_test.cc @@ -98,14 +98,19 @@ TEST(ResolverTest, TestFindConstantUnqualifiedType) { } TEST(ResolverTest, TestFindConstantFullyQualifiedType) { + google::protobuf::LinkMessageReflection(); CelFunctionRegistry func_registry; CelTypeRegistry type_registry; + type_registry.RegisterTypeProvider( + std::make_unique( + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory())); Resolver resolver("cel", &func_registry, &type_registry); auto type_value = resolver.FindConstant(".google.api.expr.runtime.TestMessage", -1); - EXPECT_TRUE(type_value.has_value()); - EXPECT_TRUE(type_value->IsCelType()); + ASSERT_TRUE(type_value.has_value()); + ASSERT_TRUE(type_value->IsCelType()); EXPECT_THAT(type_value->CelTypeOrDie().value(), Eq("google.api.expr.runtime.TestMessage")); } @@ -113,6 +118,10 @@ TEST(ResolverTest, TestFindConstantFullyQualifiedType) { TEST(ResolverTest, TestFindConstantQualifiedTypeDisabled) { CelFunctionRegistry func_registry; CelTypeRegistry type_registry; + type_registry.RegisterTypeProvider( + std::make_unique( + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory())); Resolver resolver("", &func_registry, &type_registry, false); auto type_value = resolver.FindConstant(".google.api.expr.runtime.TestMessage", -1); diff --git a/eval/public/BUILD b/eval/public/BUILD index 2118320c9..e40f25ac8 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -649,12 +649,14 @@ cc_library( ":cel_value", "//eval/public/structs:legacy_type_provider", "//internal:no_destructor", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:node_hash_set", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf", ], diff --git a/eval/public/cel_expr_builder_factory.cc b/eval/public/cel_expr_builder_factory.cc index 3c517ba14..17775e5aa 100644 --- a/eval/public/cel_expr_builder_factory.cc +++ b/eval/public/cel_expr_builder_factory.cc @@ -45,7 +45,7 @@ std::unique_ptr CreateCelExpressionBuilder( GOOGLE_LOG(WARNING) << "Failed to validate standard message types: " << s; return nullptr; } - auto builder = absl::make_unique(descriptor_pool); + auto builder = std::make_unique(); builder->GetTypeRegistry()->RegisterTypeProvider( std::make_unique(descriptor_pool, message_factory)); diff --git a/eval/public/cel_expression.h b/eval/public/cel_expression.h index 5dc894a9f..95b4f5bdc 100644 --- a/eval/public/cel_expression.h +++ b/eval/public/cel_expression.h @@ -80,11 +80,6 @@ class CelExpressionBuilder { type_registry_(absl::make_unique()), container_("") {} - explicit CelExpressionBuilder(const google::protobuf::DescriptorPool* descriptor_pool) - : func_registry_(absl::make_unique()), - type_registry_(absl::make_unique(descriptor_pool)), - container_("") {} - virtual ~CelExpressionBuilder() {} // Creates CelExpression object from AST tree. diff --git a/eval/public/cel_type_registry.cc b/eval/public/cel_type_registry.cc index e7a688ed3..f2925b09d 100644 --- a/eval/public/cel_type_registry.cc +++ b/eval/public/cel_type_registry.cc @@ -9,6 +9,7 @@ #include "absl/container/flat_hash_set.h" #include "absl/container/node_hash_set.h" #include "absl/status/status.h" +#include "absl/synchronization/mutex.h" #include "absl/types/optional.h" #include "eval/public/cel_value.h" #include "internal/no_destructor.h" @@ -93,21 +94,13 @@ const absl::flat_hash_set& GetCoreEnums } // namespace CelTypeRegistry::CelTypeRegistry() - : descriptor_pool_(google::protobuf::DescriptorPool::generated_pool()), - types_(GetCoreTypes()), - enums_(GetCoreEnums()) { - EnumAdder().AddEnum(enums_map_); -} - -CelTypeRegistry::CelTypeRegistry(const google::protobuf::DescriptorPool* descriptor_pool) - : descriptor_pool_(descriptor_pool), - types_(GetCoreTypes()), - enums_(GetCoreEnums()) { + : types_(GetCoreTypes()), enums_(GetCoreEnums()) { EnumAdder().AddEnum(enums_map_); } void CelTypeRegistry::Register(std::string fully_qualified_type_name) { // Registers the fully qualified type name as a CEL type. + absl::MutexLock lock(&mutex_); types_.insert(std::move(fully_qualified_type_name)); } @@ -124,13 +117,6 @@ CelTypeRegistry::GetFirstTypeProvider() const { return type_providers_[0]; } -const google::protobuf::Descriptor* CelTypeRegistry::FindDescriptor( - absl::string_view fully_qualified_type_name) const { - // Public protobuf interface only accepts const string&. - return descriptor_pool_->FindMessageTypeByName( - std::string(fully_qualified_type_name)); -} - // Find a type's CelValue instance by its fully qualified name. absl::optional CelTypeRegistry::FindTypeAdapter( absl::string_view fully_qualified_type_name) const { @@ -146,6 +132,7 @@ absl::optional CelTypeRegistry::FindTypeAdapter( absl::optional CelTypeRegistry::FindType( absl::string_view fully_qualified_type_name) const { + absl::MutexLock lock(&mutex_); // Searches through explicitly registered type names first. auto type = types_.find(fully_qualified_type_name); // The CelValue returned by this call will remain valid as long as the @@ -154,12 +141,14 @@ absl::optional CelTypeRegistry::FindType( return CelValue::CreateCelTypeView(*type); } - // By default falls back to looking at whether the protobuf descriptor is - // linked into the binary. In the future, this functionality may be disabled, - // but this is most consistent with the current CEL C++ behavior. - auto desc = FindDescriptor(fully_qualified_type_name); - if (desc != nullptr) { - return CelValue::CreateCelTypeView(desc->full_name()); + // By default falls back to looking at whether the type is provided by one + // of the registered providers (generally, one backed by the generated + // DescriptorPool). + auto adapter = FindTypeAdapter(fully_qualified_type_name); + if (adapter.has_value()) { + auto [iter, inserted] = + types_.insert(std::string(fully_qualified_type_name)); + return CelValue::CreateCelTypeView(*iter); } return absl::nullopt; } diff --git a/eval/public/cel_type_registry.h b/eval/public/cel_type_registry.h index b716ea448..91294adfb 100644 --- a/eval/public/cel_type_registry.h +++ b/eval/public/cel_type_registry.h @@ -6,11 +6,13 @@ #include #include "google/protobuf/descriptor.h" +#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/node_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "eval/public/cel_value.h" #include "eval/public/structs/legacy_type_provider.h" @@ -37,7 +39,6 @@ class CelTypeRegistry { }; CelTypeRegistry(); - explicit CelTypeRegistry(const google::protobuf::DescriptorPool* descriptor_pool); ~CelTypeRegistry() {} @@ -90,15 +91,10 @@ class CelTypeRegistry { } private: - // Find a protobuf Descriptor given a fully qualified protobuf type name. - const google::protobuf::Descriptor* FindDescriptor( - absl::string_view fully_qualified_type_name) const; - - const google::protobuf::DescriptorPool* descriptor_pool_; // externally owned - + mutable absl::Mutex mutex_; // node_hash_set provides pointer-stability, which is required for the // strings backing CelType objects. - absl::node_hash_set types_; + mutable absl::node_hash_set types_ ABSL_GUARDED_BY(mutex_); // Set of registered enums. absl::flat_hash_set enums_; // Internal representation for enums. diff --git a/eval/public/cel_type_registry_test.cc b/eval/public/cel_type_registry_test.cc index afbce4301..2f6b09619 100644 --- a/eval/public/cel_type_registry_test.cc +++ b/eval/public/cel_type_registry_test.cc @@ -197,16 +197,16 @@ TEST(CelTypeRegistryTest, TestFindTypeCoreTypeFound) { EXPECT_THAT(type->CelTypeOrDie().value(), Eq("int")); } -TEST(CelTypeRegistryTest, TestFindTypeProtobufTypeFound) { +TEST(CelTypeRegistryTest, TestFindTypeAdapterTypeFound) { CelTypeRegistry registry; + registry.RegisterTypeProvider(std::make_unique( + std::vector{"google.protobuf.Int64"})); + registry.RegisterTypeProvider(std::make_unique( + std::vector{"google.protobuf.Any"})); auto type = registry.FindType("google.protobuf.Any"); - if constexpr (std::is_base_of_v) { - ASSERT_TRUE(type.has_value()); - EXPECT_TRUE(type->IsCelType()); - EXPECT_THAT(type->CelTypeOrDie().value(), Eq("google.protobuf.Any")); - } else { - EXPECT_FALSE(type.has_value()); - } + ASSERT_TRUE(type.has_value()); + EXPECT_TRUE(type->IsCelType()); + EXPECT_THAT(type->CelTypeOrDie().value(), Eq("google.protobuf.Any")); } TEST(CelTypeRegistryTest, TestFindTypeNotRegisteredTypeNotFound) { From 17eba2dec83c9442062537c1dd52ccea7508a949 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Tue, 3 May 2022 22:26:51 +0000 Subject: [PATCH 093/111] Expose MessageWrapper type in a publicly visible rule. PiperOrigin-RevId: 446304791 --- eval/public/BUILD | 28 ++++++- eval/public/cel_value.cc | 2 +- eval/public/cel_value.h | 12 +-- eval/public/cel_value_internal.h | 54 +----------- eval/public/comparison_functions.cc | 2 +- eval/public/comparison_functions_test.cc | 8 +- eval/public/message_wrapper.h | 82 +++++++++++++++++++ eval/public/message_wrapper_test.cc | 54 ++++++++++++ .../portable_cel_expr_builder_factory_test.cc | 13 ++- eval/public/structs/BUILD | 15 ++-- .../structs/cel_proto_wrap_util_test.cc | 2 +- eval/public/structs/cel_proto_wrapper.cc | 4 +- .../structs/legacy_type_adapter_test.cc | 6 +- eval/public/structs/legacy_type_info_apis.h | 8 +- .../structs/proto_message_type_adapter.cc | 14 ++-- .../proto_message_type_adapter_test.cc | 50 +++++------ .../public/structs/trivial_legacy_type_info.h | 10 +-- .../structs/trivial_legacy_type_info_test.cc | 8 +- 18 files changed, 236 insertions(+), 136 deletions(-) create mode 100644 eval/public/message_wrapper.h create mode 100644 eval/public/message_wrapper_test.cc diff --git a/eval/public/BUILD b/eval/public/BUILD index e40f25ac8..d01cca09e 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -18,12 +18,35 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) +cc_library( + name = "message_wrapper", + hdrs = [ + "message_wrapper.h", + ], + deps = ["@com_google_protobuf//:protobuf"], +) + +cc_test( + name = "message_wrapper_test", + srcs = [ + "message_wrapper_test.cc", + ], + deps = [ + ":message_wrapper", + "//eval/public/structs:trivial_legacy_type_info", + "//eval/testutil:test_message_cc_proto", + "//internal:testing", + "@com_google_protobuf//:protobuf", + ], +) + cc_library( name = "cel_value_internal", hdrs = [ "cel_value_internal.h", ], deps = [ + ":message_wrapper", "//internal:casts", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/numeric:bits", @@ -42,6 +65,7 @@ cc_library( ], deps = [ ":cel_value_internal", + ":message_wrapper", "//base:memory_manager", "//eval/public/structs:legacy_type_info_apis", "//extensions/protobuf:memory_manager", @@ -267,7 +291,7 @@ cc_library( ":cel_number", ":cel_options", ":cel_value", - ":cel_value_internal", + ":message_wrapper", "//eval/eval:mutable_list_impl", "//eval/public/structs:legacy_type_adapter", "//eval/public/structs:legacy_type_info_apis", @@ -299,8 +323,8 @@ cc_test( ":cel_function_registry", ":cel_options", ":cel_value", - ":cel_value_internal", ":comparison_functions", + ":message_wrapper", ":set_util", "//eval/public/containers:container_backed_list_impl", "//eval/public/containers:container_backed_map_impl", diff --git a/eval/public/cel_value.cc b/eval/public/cel_value.cc index 5b12a7362..4dc5bcc77 100644 --- a/eval/public/cel_value.cc +++ b/eval/public/cel_value.cc @@ -73,7 +73,7 @@ struct DebugStringVisitor { return absl::StrFormat("%s", arg.value()); } - std::string operator()(const internal::MessageWrapper& arg) { + std::string operator()(const MessageWrapper& arg) { return arg.message_ptr() == nullptr ? "NULL" : arg.legacy_type_info()->DebugString(arg); diff --git a/eval/public/cel_value.h b/eval/public/cel_value.h index d0ba11dbd..effa2603a 100644 --- a/eval/public/cel_value.h +++ b/eval/public/cel_value.h @@ -34,6 +34,7 @@ #include "absl/types/variant.h" #include "base/memory_manager.h" #include "eval/public/cel_value_internal.h" +#include "eval/public/message_wrapper.h" #include "internal/casts.h" #include "internal/status_macros.h" #include "internal/utf8.h" @@ -115,16 +116,7 @@ class CelValue { // absl::variant. using NullType = absl::monostate; - // MessageWrapper wraps a tagged MessageLite with the accessors used to - // get field values. - // - // message_ptr(): get the MessageLite pointer of the wrapped message. - // - // legacy_type_info(): get type information about the wrapped message. see - // LegacyTypeInfoApis. - // - // HasFullProto(): returns whether it's safe to downcast to google::protobuf::Message. - using MessageWrapper = internal::MessageWrapper; + using MessageWrapper = MessageWrapper; private: // CelError MUST BE the last in the declaration - it is a ceiling for Type diff --git a/eval/public/cel_value_internal.h b/eval/public/cel_value_internal.h index 1281635ee..af1a5d949 100644 --- a/eval/public/cel_value_internal.h +++ b/eval/public/cel_value_internal.h @@ -25,14 +25,10 @@ #include "absl/base/macros.h" #include "absl/numeric/bits.h" #include "absl/types/variant.h" +#include "eval/public/message_wrapper.h" #include "internal/casts.h" -namespace google::api::expr::runtime { - -// Forward declare to resolve circular dependency. -class LegacyTypeInfoApis; - -namespace internal { +namespace google::api::expr::runtime::internal { // Helper classes needed for IndexOf metafunction implementation. template @@ -88,49 +84,6 @@ class ValueHolder { absl::variant value_; }; -class MessageWrapper { - public: - static_assert(alignof(google::protobuf::MessageLite) >= 2, - "Assume that valid MessageLite ptrs have a free low-order bit"); - MessageWrapper() : message_ptr_(0), legacy_type_info_(nullptr) {} - - MessageWrapper(const google::protobuf::MessageLite* message, - const LegacyTypeInfoApis* legacy_type_info) - : message_ptr_(reinterpret_cast(message)), - legacy_type_info_(legacy_type_info) { - ABSL_ASSERT(absl::countr_zero(reinterpret_cast(message)) >= 1); - } - - MessageWrapper(const google::protobuf::Message* message, - const LegacyTypeInfoApis* legacy_type_info) - : message_ptr_(reinterpret_cast(message) | kTagMask), - legacy_type_info_(legacy_type_info) { - ABSL_ASSERT(absl::countr_zero(reinterpret_cast(message)) >= 1); - } - - bool HasFullProto() const { return (message_ptr_ & kTagMask) == kTagMask; } - - const google::protobuf::MessageLite* message_ptr() const { - return reinterpret_cast(message_ptr_ & - kPtrMask); - } - - const LegacyTypeInfoApis* legacy_type_info() const { - return legacy_type_info_; - } - - private: - static constexpr uintptr_t kTagMask = 1 << 0; - static constexpr uintptr_t kPtrMask = ~kTagMask; - uintptr_t message_ptr_; - const LegacyTypeInfoApis* legacy_type_info_; - // TODO(issues/5): add LegacyTypeAccessApis to expose generic accessors for - // MessageLite. -}; - -static_assert(sizeof(MessageWrapper) <= 2 * sizeof(uintptr_t), - "MessageWrapper must not increase CelValue size."); - // Adapter for visitor clients that depend on google::protobuf::Message as a variant type. template struct MessageVisitAdapter { @@ -151,7 +104,6 @@ struct MessageVisitAdapter { Op op; }; -} // namespace internal -} // namespace google::api::expr::runtime +} // namespace google::api::expr::runtime::internal #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_VALUE_INTERNAL_H_ diff --git a/eval/public/comparison_functions.cc b/eval/public/comparison_functions.cc index a68c4e221..77c5e7069 100644 --- a/eval/public/comparison_functions.cc +++ b/eval/public/comparison_functions.cc @@ -37,7 +37,7 @@ #include "eval/public/cel_number.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" -#include "eval/public/cel_value_internal.h" +#include "eval/public/message_wrapper.h" #include "eval/public/structs/legacy_type_adapter.h" #include "eval/public/structs/legacy_type_info_apis.h" #include "internal/casts.h" diff --git a/eval/public/comparison_functions_test.cc b/eval/public/comparison_functions_test.cc index e26c025e3..597574d88 100644 --- a/eval/public/comparison_functions_test.cc +++ b/eval/public/comparison_functions_test.cc @@ -44,10 +44,10 @@ #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" -#include "eval/public/cel_value_internal.h" #include "eval/public/containers/container_backed_list_impl.h" #include "eval/public/containers/container_backed_map_impl.h" #include "eval/public/containers/field_backed_list_impl.h" +#include "eval/public/message_wrapper.h" #include "eval/public/set_util.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "eval/public/structs/trivial_legacy_type_info.h" @@ -413,7 +413,7 @@ TEST(CelValueEqualImplTest, ProtoEqualityDifferingTypenameInequal) { CelValue lhs = CelProtoWrapper::CreateMessage(&example, &arena); CelValue rhs = CelValue::CreateMessageWrapper( - internal::MessageWrapper(&example, TrivialTypeInfo::GetInstance())); + MessageWrapper(&example, TrivialTypeInfo::GetInstance())); EXPECT_THAT(CelValueEqualImpl(lhs, rhs), Optional(false)); } @@ -430,9 +430,9 @@ TEST(CelValueEqualImplTest, ProtoEqualityNoAccessorInequal) { &example)); CelValue lhs = CelValue::CreateMessageWrapper( - internal::MessageWrapper(&example, TrivialTypeInfo::GetInstance())); + MessageWrapper(&example, TrivialTypeInfo::GetInstance())); CelValue rhs = CelValue::CreateMessageWrapper( - internal::MessageWrapper(&example, TrivialTypeInfo::GetInstance())); + MessageWrapper(&example, TrivialTypeInfo::GetInstance())); EXPECT_THAT(CelValueEqualImpl(lhs, rhs), Optional(false)); } diff --git a/eval/public/message_wrapper.h b/eval/public/message_wrapper.h new file mode 100644 index 000000000..b4e1f00fa --- /dev/null +++ b/eval/public/message_wrapper.h @@ -0,0 +1,82 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_MESSAGE_WRAPPER_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_MESSAGE_WRAPPER_H_ + +#include "google/protobuf/message.h" +#include "google/protobuf/message_lite.h" + +namespace google::api::expr::runtime { + +// Forward declare to resolve cycle. +class LegacyTypeInfoApis; + +// Wrapper type for protobuf messages. This is used to limit internal usages of +// proto APIs and to support working with the proto lite runtime. +// +// Provides operations for checking if down-casting to Message is safe. +class MessageWrapper { + public: + static_assert(alignof(google::protobuf::MessageLite) >= 2, + "Assume that valid MessageLite ptrs have a free low-order bit"); + MessageWrapper() : message_ptr_(0), legacy_type_info_(nullptr) {} + + MessageWrapper(const google::protobuf::MessageLite* message, + const LegacyTypeInfoApis* legacy_type_info) + : message_ptr_(reinterpret_cast(message)), + legacy_type_info_(legacy_type_info) { + ABSL_ASSERT(absl::countr_zero(reinterpret_cast(message)) >= 1); + } + + MessageWrapper(const google::protobuf::Message* message, + const LegacyTypeInfoApis* legacy_type_info) + : message_ptr_(reinterpret_cast(message) | kTagMask), + legacy_type_info_(legacy_type_info) { + ABSL_ASSERT(absl::countr_zero(reinterpret_cast(message)) >= 1); + } + + // If true, the message is using the full proto runtime and downcasting to + // message should be safe. + bool HasFullProto() const { return (message_ptr_ & kTagMask) == kTagMask; } + + // Returns the underlying message. + // + // Clients must check HasFullProto before downcasting to Message. + const google::protobuf::MessageLite* message_ptr() const { + return reinterpret_cast(message_ptr_ & + kPtrMask); + } + + // Type information associated with this message. + const LegacyTypeInfoApis* legacy_type_info() const { + return legacy_type_info_; + } + + private: + MessageWrapper(uintptr_t message_ptr, + const LegacyTypeInfoApis* legacy_type_info) + : message_ptr_(message_ptr), legacy_type_info_(legacy_type_info) {} + + static constexpr uintptr_t kTagMask = 1 << 0; + static constexpr uintptr_t kPtrMask = ~kTagMask; + uintptr_t message_ptr_; + const LegacyTypeInfoApis* legacy_type_info_; +}; + +static_assert(sizeof(MessageWrapper) <= 2 * sizeof(uintptr_t), + "MessageWrapper must not increase CelValue size."); + +} // namespace google::api::expr::runtime +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_MESSAGE_WRAPPER_H_ diff --git a/eval/public/message_wrapper_test.cc b/eval/public/message_wrapper_test.cc new file mode 100644 index 000000000..e3fb2d3f5 --- /dev/null +++ b/eval/public/message_wrapper_test.cc @@ -0,0 +1,54 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/message_wrapper.h" + +#include "google/protobuf/message.h" +#include "google/protobuf/message_lite.h" +#include "eval/public/structs/trivial_legacy_type_info.h" +#include "eval/testutil/test_message.pb.h" +#include "internal/testing.h" + +namespace google::api::expr::runtime { +namespace { + +TEST(MessageWrapper, Size) { + static_assert(sizeof(MessageWrapper) <= 2 * sizeof(uintptr_t), + "MessageWrapper must not increase CelValue size."); +} + +TEST(MessageWrapper, WrapsMessage) { + TestMessage test_message; + + test_message.set_int64_value(20); + test_message.set_double_value(12.3); + + MessageWrapper wrapped_message(&test_message, TrivialTypeInfo::GetInstance()); + + constexpr bool is_full_proto_runtime = + std::is_base_of_v; + + EXPECT_EQ(wrapped_message.message_ptr(), + static_cast(&test_message)); + ASSERT_EQ(wrapped_message.HasFullProto(), is_full_proto_runtime); +} + +TEST(MessageWrapper, DefaultNull) { + MessageWrapper wrapper; + EXPECT_EQ(wrapper.message_ptr(), nullptr); + EXPECT_EQ(wrapper.legacy_type_info(), nullptr); +} + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/public/portable_cel_expr_builder_factory_test.cc b/eval/public/portable_cel_expr_builder_factory_test.cc index a2b7e54ba..79d9dbaaf 100644 --- a/eval/public/portable_cel_expr_builder_factory_test.cc +++ b/eval/public/portable_cel_expr_builder_factory_test.cc @@ -271,14 +271,13 @@ class DemoTypeInfo : public LegacyTypeInfoApis { public: explicit DemoTypeInfo(const DemoTypeProvider* owning_provider) : owning_provider_(*owning_provider) {} - std::string DebugString( - const internal::MessageWrapper& wrapped_message) const override; + std::string DebugString(const MessageWrapper& wrapped_message) const override; const std::string& GetTypename( - const internal::MessageWrapper& wrapped_message) const override; + const MessageWrapper& wrapped_message) const override; const LegacyTypeAccessApis* GetAccessApis( - const internal::MessageWrapper& wrapped_message) const override; + const MessageWrapper& wrapped_message) const override; private: const DemoTypeProvider& owning_provider_; @@ -354,17 +353,17 @@ class DemoTypeProvider : public LegacyTypeProvider { }; std::string DemoTypeInfo::DebugString( - const internal::MessageWrapper& wrapped_message) const { + const MessageWrapper& wrapped_message) const { return wrapped_message.message_ptr()->GetTypeName(); } const std::string& DemoTypeInfo::GetTypename( - const internal::MessageWrapper& wrapped_message) const { + const MessageWrapper& wrapped_message) const { return owning_provider_.GetStableType(wrapped_message.message_ptr()); } const LegacyTypeAccessApis* DemoTypeInfo::GetAccessApis( - const internal::MessageWrapper& wrapped_message) const { + const MessageWrapper& wrapped_message) const { auto adapter = owning_provider_.ProvideLegacyType( wrapped_message.message_ptr()->GetTypeName()); if (adapter.has_value()) { diff --git a/eval/public/structs/BUILD b/eval/public/structs/BUILD index 23ee01efc..e56af26ca 100644 --- a/eval/public/structs/BUILD +++ b/eval/public/structs/BUILD @@ -28,7 +28,7 @@ cc_library( ":cel_proto_wrap_util", ":proto_message_type_adapter", "//eval/public:cel_value", - "//eval/public:cel_value_internal", + "//eval/public:message_wrapper", "//internal:proto_time_encoding", "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf", @@ -77,11 +77,10 @@ cc_test( ], deps = [ ":cel_proto_wrap_util", - ":legacy_type_info_apis", ":protobuf_value_factory", ":trivial_legacy_type_info", "//eval/public:cel_value", - "//eval/public:cel_value_internal", + "//eval/public:message_wrapper", "//eval/public/containers:container_backed_list_impl", "//eval/public/containers:container_backed_map_impl", "//eval/testutil:test_message_cc_proto", @@ -236,7 +235,7 @@ cc_library( "//base:memory_manager", "//eval/public:cel_options", "//eval/public:cel_value", - "//eval/public:cel_value_internal", + "//eval/public:message_wrapper", "//eval/public/containers:internal_field_backed_list_impl", "//eval/public/containers:internal_field_backed_map_impl", "//extensions/protobuf:memory_manager", @@ -258,7 +257,7 @@ cc_test( ":legacy_type_info_apis", ":proto_message_type_adapter", "//eval/public:cel_value", - "//eval/public:cel_value_internal", + "//eval/public:message_wrapper", "//eval/public/containers:container_backed_list_impl", "//eval/public/containers:container_backed_map_impl", "//eval/public/containers:field_access", @@ -305,7 +304,7 @@ cc_test( cc_library( name = "legacy_type_info_apis", hdrs = ["legacy_type_info_apis.h"], - deps = ["//eval/public:cel_value_internal"], + deps = ["//eval/public:message_wrapper"], ) cc_library( @@ -314,7 +313,7 @@ cc_library( hdrs = ["trivial_legacy_type_info.h"], deps = [ ":legacy_type_info_apis", - "//eval/public:cel_value_internal", + "//eval/public:message_wrapper", "//internal:no_destructor", ], ) @@ -324,7 +323,7 @@ cc_test( srcs = ["trivial_legacy_type_info_test.cc"], deps = [ ":trivial_legacy_type_info", - "//eval/public:cel_value_internal", + "//eval/public:message_wrapper", "//internal:testing", ], ) diff --git a/eval/public/structs/cel_proto_wrap_util_test.cc b/eval/public/structs/cel_proto_wrap_util_test.cc index 1a9311a97..8611ef254 100644 --- a/eval/public/structs/cel_proto_wrap_util_test.cc +++ b/eval/public/structs/cel_proto_wrap_util_test.cc @@ -30,9 +30,9 @@ #include "absl/strings/str_cat.h" #include "absl/time/time.h" #include "eval/public/cel_value.h" -#include "eval/public/cel_value_internal.h" #include "eval/public/containers/container_backed_list_impl.h" #include "eval/public/containers/container_backed_map_impl.h" +#include "eval/public/message_wrapper.h" #include "eval/public/structs/protobuf_value_factory.h" #include "eval/public/structs/trivial_legacy_type_info.h" #include "eval/testutil/test_message.pb.h" diff --git a/eval/public/structs/cel_proto_wrapper.cc b/eval/public/structs/cel_proto_wrapper.cc index 07fb68945..f5c82969a 100644 --- a/eval/public/structs/cel_proto_wrapper.cc +++ b/eval/public/structs/cel_proto_wrapper.cc @@ -17,7 +17,7 @@ #include "google/protobuf/message.h" #include "absl/types/optional.h" #include "eval/public/cel_value.h" -#include "eval/public/cel_value_internal.h" +#include "eval/public/message_wrapper.h" #include "eval/public/structs/cel_proto_wrap_util.h" #include "eval/public/structs/proto_message_type_adapter.h" @@ -33,7 +33,7 @@ using ::google::protobuf::Message; CelValue CelProtoWrapper::InternalWrapMessage(const Message* message) { return CelValue::CreateMessageWrapper( - internal::MessageWrapper(message, &GetGenericProtoTypeInfoInstance())); + MessageWrapper(message, &GetGenericProtoTypeInfoInstance())); } // CreateMessage creates CelValue from google::protobuf::Message. diff --git a/eval/public/structs/legacy_type_adapter_test.cc b/eval/public/structs/legacy_type_adapter_test.cc index f7632e032..1402387fa 100644 --- a/eval/public/structs/legacy_type_adapter_test.cc +++ b/eval/public/structs/legacy_type_adapter_test.cc @@ -65,7 +65,7 @@ class TestAccessApiImpl : public LegacyTypeAccessApis { TEST(LegacyTypeAdapterMutationApis, DefaultNoopAdapt) { TestMessage message; - internal::MessageWrapper wrapper(&message, TrivialTypeInfo::GetInstance()); + MessageWrapper wrapper(&message, TrivialTypeInfo::GetInstance()); google::protobuf::Arena arena; cel::extensions::ProtoMemoryManager manager(&arena); @@ -80,8 +80,8 @@ TEST(LegacyTypeAdapterMutationApis, DefaultNoopAdapt) { TEST(LegacyTypeAdapterAccessApis, DefaultAlwaysInequal) { TestMessage message; - internal::MessageWrapper wrapper(&message, nullptr); - internal::MessageWrapper wrapper2(&message, nullptr); + MessageWrapper wrapper(&message, nullptr); + MessageWrapper wrapper2(&message, nullptr); google::protobuf::Arena arena; cel::extensions::ProtoMemoryManager manager(&arena); diff --git a/eval/public/structs/legacy_type_info_apis.h b/eval/public/structs/legacy_type_info_apis.h index 939dc8a94..49ce036af 100644 --- a/eval/public/structs/legacy_type_info_apis.h +++ b/eval/public/structs/legacy_type_info_apis.h @@ -17,7 +17,7 @@ #include -#include "eval/public/cel_value_internal.h" +#include "eval/public/message_wrapper.h" namespace google::api::expr::runtime { @@ -40,13 +40,13 @@ class LegacyTypeInfoApis { // Return a debug representation of the wrapped message. virtual std::string DebugString( - const internal::MessageWrapper& wrapped_message) const = 0; + const MessageWrapper& wrapped_message) const = 0; // Return a const-reference to the typename for the wrapped message's type. // The CEL interpreter assumes that the typename is owned externally and will // outlive any CelValues created by the interpreter. virtual const std::string& GetTypename( - const internal::MessageWrapper& wrapped_message) const = 0; + const MessageWrapper& wrapped_message) const = 0; // Return a pointer to the wrapped message's access api implementation. // @@ -57,7 +57,7 @@ class LegacyTypeInfoApis { // access, the interpreter will treat this the same as accessing a field that // is not defined for the type. virtual const LegacyTypeAccessApis* GetAccessApis( - const internal::MessageWrapper& wrapped_message) const = 0; + const MessageWrapper& wrapped_message) const = 0; }; } // namespace google::api::expr::runtime diff --git a/eval/public/structs/proto_message_type_adapter.cc b/eval/public/structs/proto_message_type_adapter.cc index 1a089b235..58bdb17bf 100644 --- a/eval/public/structs/proto_message_type_adapter.cc +++ b/eval/public/structs/proto_message_type_adapter.cc @@ -24,9 +24,9 @@ #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "eval/public/cel_value.h" -#include "eval/public/cel_value_internal.h" #include "eval/public/containers/internal_field_backed_list_impl.h" #include "eval/public/containers/internal_field_backed_map_impl.h" +#include "eval/public/message_wrapper.h" #include "eval/public/structs/cel_proto_wrap_util.h" #include "eval/public/structs/field_access_impl.h" #include "eval/public/structs/legacy_type_adapter.h" @@ -175,7 +175,7 @@ class DucktypedMessageAdapter : public LegacyTypeAccessApis, // Implement TypeInfo Apis const std::string& GetTypename( - const internal::MessageWrapper& wrapped_message) const override { + const MessageWrapper& wrapped_message) const override { if (!wrapped_message.HasFullProto() || wrapped_message.message_ptr() == nullptr) { return UnsupportedTypeName(); @@ -186,7 +186,7 @@ class DucktypedMessageAdapter : public LegacyTypeAccessApis, } std::string DebugString( - const internal::MessageWrapper& wrapped_message) const override { + const MessageWrapper& wrapped_message) const override { if (!wrapped_message.HasFullProto() || wrapped_message.message_ptr() == nullptr) { return UnsupportedTypeName(); @@ -197,7 +197,7 @@ class DucktypedMessageAdapter : public LegacyTypeAccessApis, } const LegacyTypeAccessApis* GetAccessApis( - const internal::MessageWrapper& wrapped_message) const override { + const MessageWrapper& wrapped_message) const override { return this; } @@ -208,8 +208,8 @@ class DucktypedMessageAdapter : public LegacyTypeAccessApis, }; CelValue MessageCelValueFactory(const google::protobuf::Message* message) { - return CelValue::CreateMessageWrapper(internal::MessageWrapper( - message, &DucktypedMessageAdapter::GetSingleton())); + return CelValue::CreateMessageWrapper( + MessageWrapper(message, &DucktypedMessageAdapter::GetSingleton())); } } // namespace @@ -236,7 +236,7 @@ absl::StatusOr ProtoMessageTypeAdapter::NewInstance( return absl::InvalidArgumentError( absl::StrCat("Failed to create message ", descriptor_->name())); } - return internal::MessageWrapper(msg, &GetGenericProtoTypeInfoInstance()); + return MessageWrapper(msg, &GetGenericProtoTypeInfoInstance()); } bool ProtoMessageTypeAdapter::DefinesField(absl::string_view field_name) const { diff --git a/eval/public/structs/proto_message_type_adapter_test.cc b/eval/public/structs/proto_message_type_adapter_test.cc index 001ff82ec..f31ec9bb9 100644 --- a/eval/public/structs/proto_message_type_adapter_test.cc +++ b/eval/public/structs/proto_message_type_adapter_test.cc @@ -21,10 +21,10 @@ #include "google/protobuf/message_lite.h" #include "absl/status/status.h" #include "eval/public/cel_value.h" -#include "eval/public/cel_value_internal.h" #include "eval/public/containers/container_backed_list_impl.h" #include "eval/public/containers/container_backed_map_impl.h" #include "eval/public/containers/field_access.h" +#include "eval/public/message_wrapper.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "eval/public/structs/legacy_type_adapter.h" #include "eval/public/structs/legacy_type_info_apis.h" @@ -77,7 +77,7 @@ TEST_P(ProtoMessageTypeAccessorTest, HasFieldSingular) { const LegacyTypeAccessApis& accessor = GetAccessApis(); TestMessage example; - internal::MessageWrapper value(&example, nullptr); + MessageWrapper value(&example, nullptr); EXPECT_THAT(accessor.HasField("int64_value", value), IsOkAndHolds(false)); example.set_int64_value(10); @@ -90,7 +90,7 @@ TEST_P(ProtoMessageTypeAccessorTest, HasFieldRepeated) { TestMessage example; - internal::MessageWrapper value(&example, nullptr); + MessageWrapper value(&example, nullptr); EXPECT_THAT(accessor.HasField("int64_list", value), IsOkAndHolds(false)); example.add_int64_list(10); @@ -104,7 +104,7 @@ TEST_P(ProtoMessageTypeAccessorTest, HasFieldMap) { TestMessage example; example.set_int64_value(10); - internal::MessageWrapper value(&example, nullptr); + MessageWrapper value(&example, nullptr); EXPECT_THAT(accessor.HasField("int64_int32_map", value), IsOkAndHolds(false)); (*example.mutable_int64_int32_map())[2] = 3; @@ -118,7 +118,7 @@ TEST_P(ProtoMessageTypeAccessorTest, HasFieldUnknownField) { TestMessage example; example.set_int64_value(10); - internal::MessageWrapper value(&example, nullptr); + MessageWrapper value(&example, nullptr); EXPECT_THAT(accessor.HasField("unknown_field", value), StatusIs(absl::StatusCode::kNotFound)); @@ -128,8 +128,8 @@ TEST_P(ProtoMessageTypeAccessorTest, HasFieldNonMessageType) { google::protobuf::Arena arena; const LegacyTypeAccessApis& accessor = GetAccessApis(); - internal::MessageWrapper value( - static_cast(nullptr), nullptr); + MessageWrapper value(static_cast(nullptr), + nullptr); EXPECT_THAT(accessor.HasField("unknown_field", value), StatusIs(absl::StatusCode::kInternal)); @@ -144,7 +144,7 @@ TEST_P(ProtoMessageTypeAccessorTest, GetFieldSingular) { TestMessage example; example.set_int64_value(10); - internal::MessageWrapper value(&example, nullptr); + MessageWrapper value(&example, nullptr); EXPECT_THAT(accessor.GetField("int64_value", value, ProtoWrapperTypeOptions::kUnsetNull, manager), @@ -160,7 +160,7 @@ TEST_P(ProtoMessageTypeAccessorTest, GetFieldNoSuchField) { TestMessage example; example.set_int64_value(10); - internal::MessageWrapper value(&example, nullptr); + MessageWrapper value(&example, nullptr); EXPECT_THAT(accessor.GetField("unknown_field", value, ProtoWrapperTypeOptions::kUnsetNull, manager), @@ -174,8 +174,8 @@ TEST_P(ProtoMessageTypeAccessorTest, GetFieldNotAMessage) { ProtoMemoryManager manager(&arena); - internal::MessageWrapper value( - static_cast(nullptr), nullptr); + MessageWrapper value(static_cast(nullptr), + nullptr); EXPECT_THAT(accessor.GetField("int64_value", value, ProtoWrapperTypeOptions::kUnsetNull, manager), @@ -192,7 +192,7 @@ TEST_P(ProtoMessageTypeAccessorTest, GetFieldRepeated) { example.add_int64_list(10); example.add_int64_list(20); - internal::MessageWrapper value(&example, nullptr); + MessageWrapper value(&example, nullptr); ASSERT_OK_AND_ASSIGN( CelValue result, @@ -216,7 +216,7 @@ TEST_P(ProtoMessageTypeAccessorTest, GetFieldMap) { TestMessage example; (*example.mutable_int64_int32_map())[10] = 20; - internal::MessageWrapper value(&example, nullptr); + MessageWrapper value(&example, nullptr); ASSERT_OK_AND_ASSIGN( CelValue result, @@ -240,7 +240,7 @@ TEST_P(ProtoMessageTypeAccessorTest, GetFieldWrapperType) { TestMessage example; example.mutable_int64_wrapper_value()->set_value(10); - internal::MessageWrapper value(&example, nullptr); + MessageWrapper value(&example, nullptr); EXPECT_THAT(accessor.GetField("int64_wrapper_value", value, ProtoWrapperTypeOptions::kUnsetNull, manager), @@ -255,7 +255,7 @@ TEST_P(ProtoMessageTypeAccessorTest, GetFieldWrapperTypeUnsetNullUnbox) { TestMessage example; - internal::MessageWrapper value(&example, nullptr); + MessageWrapper value(&example, nullptr); EXPECT_THAT(accessor.GetField("int64_wrapper_value", value, ProtoWrapperTypeOptions::kUnsetNull, manager), @@ -277,7 +277,7 @@ TEST_P(ProtoMessageTypeAccessorTest, TestMessage example; - internal::MessageWrapper value(&example, nullptr); + MessageWrapper value(&example, nullptr); EXPECT_THAT( accessor.GetField("int64_wrapper_value", value, @@ -305,8 +305,8 @@ TEST_P(ProtoMessageTypeAccessorTest, IsEqualTo) { TestMessage example2; example2.mutable_int64_wrapper_value()->set_value(10); - internal::MessageWrapper value(&example, nullptr); - internal::MessageWrapper value2(&example2, nullptr); + MessageWrapper value(&example, nullptr); + MessageWrapper value2(&example2, nullptr); EXPECT_TRUE(accessor.IsEqualTo(value, value2)); EXPECT_TRUE(accessor.IsEqualTo(value2, value)); @@ -323,8 +323,8 @@ TEST_P(ProtoMessageTypeAccessorTest, IsEqualToSameTypeInequal) { TestMessage example2; example2.mutable_int64_wrapper_value()->set_value(12); - internal::MessageWrapper value(&example, nullptr); - internal::MessageWrapper value2(&example2, nullptr); + MessageWrapper value(&example, nullptr); + MessageWrapper value2(&example2, nullptr); EXPECT_FALSE(accessor.IsEqualTo(value, value2)); EXPECT_FALSE(accessor.IsEqualTo(value2, value)); @@ -341,8 +341,8 @@ TEST_P(ProtoMessageTypeAccessorTest, IsEqualToDifferentTypeInequal) { Int64Value example2; example2.set_value(10); - internal::MessageWrapper value(&example, nullptr); - internal::MessageWrapper value2(&example2, nullptr); + MessageWrapper value(&example, nullptr); + MessageWrapper value2(&example2, nullptr); EXPECT_FALSE(accessor.IsEqualTo(value, value2)); EXPECT_FALSE(accessor.IsEqualTo(value2, value)); @@ -359,10 +359,10 @@ TEST_P(ProtoMessageTypeAccessorTest, IsEqualToNonMessageInequal) { TestMessage example2; example2.mutable_int64_wrapper_value()->set_value(10); - internal::MessageWrapper value(&example, nullptr); + MessageWrapper value(&example, nullptr); // Upcast to message lite to prevent unwrapping to message. - internal::MessageWrapper value2( - static_cast(&example2), nullptr); + MessageWrapper value2(static_cast(&example2), + nullptr); EXPECT_FALSE(accessor.IsEqualTo(value, value2)); EXPECT_FALSE(accessor.IsEqualTo(value2, value)); diff --git a/eval/public/structs/trivial_legacy_type_info.h b/eval/public/structs/trivial_legacy_type_info.h index eabff8858..988a43d9c 100644 --- a/eval/public/structs/trivial_legacy_type_info.h +++ b/eval/public/structs/trivial_legacy_type_info.h @@ -17,7 +17,7 @@ #include -#include "eval/public/cel_value_internal.h" +#include "eval/public/message_wrapper.h" #include "eval/public/structs/legacy_type_info_apis.h" #include "internal/no_destructor.h" @@ -27,19 +27,17 @@ namespace google::api::expr::runtime { // operations need to be supported. class TrivialTypeInfo : public LegacyTypeInfoApis { public: - const std::string& GetTypename( - const internal::MessageWrapper& wrapper) const override { + const std::string& GetTypename(const MessageWrapper& wrapper) const override { static cel::internal::NoDestructor kTypename("opaque type"); return *kTypename; } - std::string DebugString( - const internal::MessageWrapper& wrapper) const override { + std::string DebugString(const MessageWrapper& wrapper) const override { return "opaque"; } const LegacyTypeAccessApis* GetAccessApis( - const internal::MessageWrapper& wrapper) const override { + const MessageWrapper& wrapper) const override { // Accessors unsupported -- caller should treat this as an opaque type (no // fields defined, field access always results in a CEL error). return nullptr; diff --git a/eval/public/structs/trivial_legacy_type_info_test.cc b/eval/public/structs/trivial_legacy_type_info_test.cc index 36832e888..eb54c0fcd 100644 --- a/eval/public/structs/trivial_legacy_type_info_test.cc +++ b/eval/public/structs/trivial_legacy_type_info_test.cc @@ -14,7 +14,7 @@ #include "eval/public/structs/trivial_legacy_type_info.h" -#include "eval/public/cel_value_internal.h" +#include "eval/public/message_wrapper.h" #include "internal/testing.h" namespace google::api::expr::runtime { @@ -22,7 +22,7 @@ namespace { TEST(TrivialTypeInfo, GetTypename) { TrivialTypeInfo info; - internal::MessageWrapper wrapper; + MessageWrapper wrapper; EXPECT_EQ(info.GetTypename(wrapper), "opaque type"); EXPECT_EQ(TrivialTypeInfo::GetInstance()->GetTypename(wrapper), @@ -31,7 +31,7 @@ TEST(TrivialTypeInfo, GetTypename) { TEST(TrivialTypeInfo, DebugString) { TrivialTypeInfo info; - internal::MessageWrapper wrapper; + MessageWrapper wrapper; EXPECT_EQ(info.DebugString(wrapper), "opaque"); EXPECT_EQ(TrivialTypeInfo::GetInstance()->DebugString(wrapper), "opaque"); @@ -39,7 +39,7 @@ TEST(TrivialTypeInfo, DebugString) { TEST(TrivialTypeInfo, GetAccessApis) { TrivialTypeInfo info; - internal::MessageWrapper wrapper; + MessageWrapper wrapper; EXPECT_EQ(info.GetAccessApis(wrapper), nullptr); EXPECT_EQ(TrivialTypeInfo::GetInstance()->GetAccessApis(wrapper), nullptr); From 00435fcb81a6399fd6503ce7cd044b646d640814 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Tue, 3 May 2022 22:27:19 +0000 Subject: [PATCH 094/111] Introduce MessageWrapper::Builder abstraction to manage message creation steps. PiperOrigin-RevId: 446304898 --- eval/eval/create_struct_step.cc | 2 +- eval/public/BUILD | 1 + eval/public/message_wrapper.h | 28 +++++++ eval/public/message_wrapper_test.cc | 30 +++++++ .../portable_cel_expr_builder_factory_test.cc | 83 +++++++++---------- eval/public/structs/legacy_type_adapter.h | 15 ++-- .../structs/legacy_type_adapter_test.cc | 35 -------- .../structs/proto_message_type_adapter.cc | 28 ++++--- .../structs/proto_message_type_adapter.h | 12 +-- .../proto_message_type_adapter_test.cc | 28 +++---- .../protobuf_descriptor_type_provider_test.cc | 2 +- 11 files changed, 144 insertions(+), 120 deletions(-) diff --git a/eval/eval/create_struct_step.cc b/eval/eval/create_struct_step.cc index 03caf078d..b4db5e61b 100644 --- a/eval/eval/create_struct_step.cc +++ b/eval/eval/create_struct_step.cc @@ -71,7 +71,7 @@ absl::Status CreateStructStepForMessage::DoEvaluate(ExecutionFrame* frame, } } - CEL_ASSIGN_OR_RETURN(CelValue::MessageWrapper instance, + CEL_ASSIGN_OR_RETURN(MessageWrapper::Builder instance, type_adapter_->NewInstance(frame->memory_manager())); int index = 0; diff --git a/eval/public/BUILD b/eval/public/BUILD index d01cca09e..06d47afb1 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -35,6 +35,7 @@ cc_test( ":message_wrapper", "//eval/public/structs:trivial_legacy_type_info", "//eval/testutil:test_message_cc_proto", + "//internal:casts", "//internal:testing", "@com_google_protobuf//:protobuf", ], diff --git a/eval/public/message_wrapper.h b/eval/public/message_wrapper.h index b4e1f00fa..8b5d17b49 100644 --- a/eval/public/message_wrapper.h +++ b/eval/public/message_wrapper.h @@ -29,6 +29,34 @@ class LegacyTypeInfoApis; // Provides operations for checking if down-casting to Message is safe. class MessageWrapper { public: + // Simple builder class. + // + // Wraps a tagged mutable message lite ptr. + class Builder { + public: + explicit Builder(google::protobuf::MessageLite* message) + : message_ptr_(reinterpret_cast(message)) { + ABSL_ASSERT(absl::countr_zero(reinterpret_cast(message)) >= 1); + } + explicit Builder(google::protobuf::Message* message) + : message_ptr_(reinterpret_cast(message) | kTagMask) { + ABSL_ASSERT(absl::countr_zero(reinterpret_cast(message)) >= 1); + } + + google::protobuf::MessageLite* message_ptr() const { + return reinterpret_cast(message_ptr_ & kPtrMask); + } + + bool HasFullProto() const { return (message_ptr_ & kTagMask) == kTagMask; } + + MessageWrapper Build(const LegacyTypeInfoApis* type_info) { + return MessageWrapper(message_ptr_, type_info); + } + + private: + uintptr_t message_ptr_; + }; + static_assert(alignof(google::protobuf::MessageLite) >= 2, "Assume that valid MessageLite ptrs have a free low-order bit"); MessageWrapper() : message_ptr_(0), legacy_type_info_(nullptr) {} diff --git a/eval/public/message_wrapper_test.cc b/eval/public/message_wrapper_test.cc index e3fb2d3f5..244248add 100644 --- a/eval/public/message_wrapper_test.cc +++ b/eval/public/message_wrapper_test.cc @@ -18,6 +18,7 @@ #include "google/protobuf/message_lite.h" #include "eval/public/structs/trivial_legacy_type_info.h" #include "eval/testutil/test_message.pb.h" +#include "internal/casts.h" #include "internal/testing.h" namespace google::api::expr::runtime { @@ -44,6 +45,35 @@ TEST(MessageWrapper, WrapsMessage) { ASSERT_EQ(wrapped_message.HasFullProto(), is_full_proto_runtime); } +TEST(MessageWrapperBuilder, Builder) { + TestMessage test_message; + + MessageWrapper::Builder builder(&test_message); + constexpr bool is_full_proto_runtime = + std::is_base_of_v; + + ASSERT_EQ(builder.HasFullProto(), is_full_proto_runtime); + + ASSERT_EQ(builder.message_ptr(), + static_cast(&test_message)); + + auto mutable_message = + cel::internal::down_cast(builder.message_ptr()); + mutable_message->set_int64_value(20); + mutable_message->set_double_value(12.3); + + MessageWrapper wrapped_message = + builder.Build(TrivialTypeInfo::GetInstance()); + + ASSERT_EQ(wrapped_message.message_ptr(), + static_cast(&test_message)); + ASSERT_EQ(wrapped_message.HasFullProto(), is_full_proto_runtime); + EXPECT_EQ(wrapped_message.message_ptr(), + static_cast(&test_message)); + EXPECT_EQ(test_message.int64_value(), 20); + EXPECT_EQ(test_message.double_value(), 12.3); +} + TEST(MessageWrapper, DefaultNull) { MessageWrapper wrapper; EXPECT_EQ(wrapper.message_ptr(), nullptr); diff --git a/eval/public/portable_cel_expr_builder_factory_test.cc b/eval/public/portable_cel_expr_builder_factory_test.cc index 79d9dbaaf..329d57741 100644 --- a/eval/public/portable_cel_expr_builder_factory_test.cc +++ b/eval/public/portable_cel_expr_builder_factory_test.cc @@ -46,17 +46,14 @@ namespace { using ::google::protobuf::Int64Value; // Helpers for c++ / proto to cel value conversions. -std::optional Unwrap(const CelValue::MessageWrapper& wrapper) { - if (wrapper.message_ptr()->GetTypeName() == "google.protobuf.Duration") { +std::optional Unwrap(const google::protobuf::MessageLite* wrapper) { + if (wrapper->GetTypeName() == "google.protobuf.Duration") { const auto* duration = - cel::internal::down_cast( - wrapper.message_ptr()); + cel::internal::down_cast(wrapper); return CelValue::CreateDuration(cel::internal::DecodeDuration(*duration)); - } else if (wrapper.message_ptr()->GetTypeName() == - "google.protobuf.Timestamp") { + } else if (wrapper->GetTypeName() == "google.protobuf.Timestamp") { const auto* timestamp = - cel::internal::down_cast( - wrapper.message_ptr()); + cel::internal::down_cast(wrapper); return CelValue::CreateTimestamp(cel::internal::DecodeTime(*timestamp)); } return std::nullopt; @@ -246,21 +243,21 @@ class DemoTimestamp : public LegacyTypeMutationApis { return field_name == "seconds" || field_name == "nanos"; } - absl::StatusOr NewInstance( + absl::StatusOr NewInstance( cel::MemoryManager& memory_manager) const override; absl::StatusOr AdaptFromWellKnownType( cel::MemoryManager& memory_manager, - CelValue::MessageWrapper instance) const override; + CelValue::MessageWrapper::Builder instance) const override; - absl::Status SetField(absl::string_view field_name, const CelValue& value, - cel::MemoryManager& memory_manager, - CelValue::MessageWrapper& instance) const override; + absl::Status SetField( + absl::string_view field_name, const CelValue& value, + cel::MemoryManager& memory_manager, + CelValue::MessageWrapper::Builder& instance) const override; private: - absl::Status Validate(const CelValue::MessageWrapper& wrapped_message) const { - if (wrapped_message.message_ptr()->GetTypeName() != - "google.protobuf.Timestamp") { + absl::Status Validate(const google::protobuf::MessageLite* wrapped_message) const { + if (wrapped_message->GetTypeName() != "google.protobuf.Timestamp") { return absl::InvalidArgumentError("not a timestamp"); } return absl::OkStatus(); @@ -292,16 +289,17 @@ class DemoTestMessage : public LegacyTypeMutationApis, return fields_.contains(field_name); } - absl::StatusOr NewInstance( + absl::StatusOr NewInstance( cel::MemoryManager& memory_manager) const override; absl::StatusOr AdaptFromWellKnownType( cel::MemoryManager& memory_manager, - CelValue::MessageWrapper instance) const override; + CelValue::MessageWrapper::Builder instance) const override; - absl::Status SetField(absl::string_view field_name, const CelValue& value, - cel::MemoryManager& memory_manager, - CelValue::MessageWrapper& instance) const override; + absl::Status SetField( + absl::string_view field_name, const CelValue& value, + cel::MemoryManager& memory_manager, + CelValue::MessageWrapper::Builder& instance) const override; absl::StatusOr HasField( absl::string_view field_name, @@ -372,26 +370,26 @@ const LegacyTypeAccessApis* DemoTypeInfo::GetAccessApis( return nullptr; // not implemented yet. } -absl::StatusOr DemoTimestamp::NewInstance( +absl::StatusOr DemoTimestamp::NewInstance( cel::MemoryManager& memory_manager) const { auto ts = memory_manager.New(); - return CelValue::MessageWrapper(ts.release(), nullptr); + return CelValue::MessageWrapper::Builder(ts.release()); } absl::StatusOr DemoTimestamp::AdaptFromWellKnownType( cel::MemoryManager& memory_manager, - CelValue::MessageWrapper instance) const { - return *Unwrap(instance); + CelValue::MessageWrapper::Builder instance) const { + auto value = Unwrap(instance.message_ptr()); + ABSL_ASSERT(value.has_value()); + return *value; } -absl::Status DemoTimestamp::SetField(absl::string_view field_name, - const CelValue& value, - cel::MemoryManager& memory_manager, - CelValue::MessageWrapper& instance) const { - ABSL_ASSERT(Validate(instance).ok()); - const auto* const_ts = - cel::internal::down_cast( - instance.message_ptr()); - auto* mutable_ts = const_cast(const_ts); +absl::Status DemoTimestamp::SetField( + absl::string_view field_name, const CelValue& value, + cel::MemoryManager& memory_manager, + CelValue::MessageWrapper::Builder& instance) const { + ABSL_ASSERT(Validate(instance.message_ptr()).ok()); + auto* mutable_ts = cel::internal::down_cast( + instance.message_ptr()); if (field_name == "seconds" && value.IsInt64()) { mutable_ts->set_seconds(value.Int64OrDie()); } else if (field_name == "nanos" && value.IsInt64()) { @@ -423,31 +421,30 @@ DemoTestMessage::DemoTestMessage(const DemoTypeProvider* owning_provider) &TestMessage::set_allocated_int64_wrapper_value); } -absl::StatusOr DemoTestMessage::NewInstance( +absl::StatusOr DemoTestMessage::NewInstance( cel::MemoryManager& memory_manager) const { auto ts = memory_manager.New(); - return CelValue::MessageWrapper(ts.release(), - owning_provider_.GetTypeInfoInstance()); + return CelValue::MessageWrapper::Builder(ts.release()); } absl::Status DemoTestMessage::SetField( absl::string_view field_name, const CelValue& value, cel::MemoryManager& memory_manager, - CelValue::MessageWrapper& instance) const { + CelValue::MessageWrapper::Builder& instance) const { auto iter = fields_.find(field_name); if (iter == fields_.end()) { return absl::UnknownError("no such field"); } - auto* test_msg = - cel::internal::down_cast(instance.message_ptr()); - auto* mutable_test_msg = const_cast(test_msg); + auto* mutable_test_msg = + cel::internal::down_cast(instance.message_ptr()); return iter->second->Set(mutable_test_msg, value); } absl::StatusOr DemoTestMessage::AdaptFromWellKnownType( cel::MemoryManager& memory_manager, - CelValue::MessageWrapper instance) const { - return CelValue::CreateMessageWrapper(instance); + CelValue::MessageWrapper::Builder instance) const { + return CelValue::CreateMessageWrapper( + instance.Build(owning_provider_.GetTypeInfoInstance())); } absl::StatusOr DemoTestMessage::HasField( diff --git a/eval/public/structs/legacy_type_adapter.h b/eval/public/structs/legacy_type_adapter.h index 5250f1b70..a7659a7bb 100644 --- a/eval/public/structs/legacy_type_adapter.h +++ b/eval/public/structs/legacy_type_adapter.h @@ -39,26 +39,23 @@ class LegacyTypeMutationApis { // Create a new empty instance of the type. // May return a status if the type is not possible to create. - virtual absl::StatusOr NewInstance( + virtual absl::StatusOr NewInstance( cel::MemoryManager& memory_manager) const = 0; // Normalize special types to a native CEL value after building. - // The default implementation is a no-op. // The interpreter guarantees that instance is uniquely owned by the // interpreter, and can be safely mutated. virtual absl::StatusOr AdaptFromWellKnownType( cel::MemoryManager& memory_manager, - CelValue::MessageWrapper instance) const { - return CelValue::CreateMessageWrapper(instance); - } + CelValue::MessageWrapper::Builder instance) const = 0; // Set field on instance to value. // The interpreter guarantees that instance is uniquely owned by the // interpreter, and can be safely mutated. - virtual absl::Status SetField(absl::string_view field_name, - const CelValue& value, - cel::MemoryManager& memory_manager, - CelValue::MessageWrapper& instance) const = 0; + virtual absl::Status SetField( + absl::string_view field_name, const CelValue& value, + cel::MemoryManager& memory_manager, + CelValue::MessageWrapper::Builder& instance) const = 0; }; // Interface for access apis. diff --git a/eval/public/structs/legacy_type_adapter_test.cc b/eval/public/structs/legacy_type_adapter_test.cc index 1402387fa..726a32342 100644 --- a/eval/public/structs/legacy_type_adapter_test.cc +++ b/eval/public/structs/legacy_type_adapter_test.cc @@ -25,26 +25,6 @@ namespace google::api::expr::runtime { namespace { -using testing::EqualsProto; - -class TestMutationApiImpl : public LegacyTypeMutationApis { - public: - TestMutationApiImpl() {} - bool DefinesField(absl::string_view field_name) const override { - return false; - } - - absl::StatusOr NewInstance( - cel::MemoryManager& memory_manager) const override { - return absl::UnimplementedError("Not implemented"); - } - - absl::Status SetField(absl::string_view field_name, const CelValue& value, - cel::MemoryManager& memory_manager, - CelValue::MessageWrapper& instance) const override { - return absl::UnimplementedError("Not implemented"); - } -}; class TestAccessApiImpl : public LegacyTypeAccessApis { public: @@ -63,21 +43,6 @@ class TestAccessApiImpl : public LegacyTypeAccessApis { } }; -TEST(LegacyTypeAdapterMutationApis, DefaultNoopAdapt) { - TestMessage message; - MessageWrapper wrapper(&message, TrivialTypeInfo::GetInstance()); - google::protobuf::Arena arena; - cel::extensions::ProtoMemoryManager manager(&arena); - - TestMutationApiImpl impl; - - ASSERT_OK_AND_ASSIGN(CelValue v, - impl.AdaptFromWellKnownType(manager, wrapper)); - - EXPECT_THAT(v, - test::IsCelMessage(EqualsProto(TestMessage::default_instance()))); -} - TEST(LegacyTypeAdapterAccessApis, DefaultAlwaysInequal) { TestMessage message; MessageWrapper wrapper(&message, nullptr); diff --git a/eval/public/structs/proto_message_type_adapter.cc b/eval/public/structs/proto_message_type_adapter.cc index 58bdb17bf..1a0eda8f2 100644 --- a/eval/public/structs/proto_message_type_adapter.cc +++ b/eval/public/structs/proto_message_type_adapter.cc @@ -53,7 +53,7 @@ const std::string& UnsupportedTypeName() { CelValue MessageCelValueFactory(const google::protobuf::Message* message); inline absl::StatusOr UnwrapMessage( - const CelValue::MessageWrapper& value, absl::string_view op) { + const MessageWrapper& value, absl::string_view op) { if (!value.HasFullProto() || value.message_ptr() == nullptr) { return absl::InternalError( absl::StrCat(op, " called on non-message type.")); @@ -61,6 +61,15 @@ inline absl::StatusOr UnwrapMessage( return cel::internal::down_cast(value.message_ptr()); } +inline absl::StatusOr UnwrapMessage( + const MessageWrapper::Builder& value, absl::string_view op) { + if (!value.HasFullProto() || value.message_ptr() == nullptr) { + return absl::InternalError( + absl::StrCat(op, " called on non-message type.")); + } + return cel::internal::down_cast(value.message_ptr()); +} + bool ProtoEquals(const google::protobuf::Message& m1, const google::protobuf::Message& m2) { // Equality behavior is undefined for message differencer if input messages // have different descriptors. For CEL just return false. @@ -224,8 +233,8 @@ absl::Status ProtoMessageTypeAdapter::ValidateSetFieldOp( return absl::OkStatus(); } -absl::StatusOr ProtoMessageTypeAdapter::NewInstance( - cel::MemoryManager& memory_manager) const { +absl::StatusOr +ProtoMessageTypeAdapter::NewInstance(cel::MemoryManager& memory_manager) const { // This implementation requires arena-backed memory manager. google::protobuf::Arena* arena = ProtoMemoryManager::CastToProtoArena(memory_manager); const Message* prototype = message_factory_->GetPrototype(descriptor_); @@ -236,7 +245,7 @@ absl::StatusOr ProtoMessageTypeAdapter::NewInstance( return absl::InvalidArgumentError( absl::StrCat("Failed to create message ", descriptor_->name())); } - return MessageWrapper(msg, &GetGenericProtoTypeInfoInstance()); + return MessageWrapper::Builder(msg); } bool ProtoMessageTypeAdapter::DefinesField(absl::string_view field_name) const { @@ -264,17 +273,14 @@ absl::StatusOr ProtoMessageTypeAdapter::GetField( absl::Status ProtoMessageTypeAdapter::SetField( absl::string_view field_name, const CelValue& value, cel::MemoryManager& memory_manager, - CelValue::MessageWrapper& instance) const { + CelValue::MessageWrapper::Builder& instance) const { // Assume proto arena implementation if this provider is used. google::protobuf::Arena* arena = cel::extensions::ProtoMemoryManager::CastToProtoArena(memory_manager); - CEL_ASSIGN_OR_RETURN(const google::protobuf::Message* message, + CEL_ASSIGN_OR_RETURN(google::protobuf::Message * mutable_message, UnwrapMessage(instance, "SetField")); - // Interpreter guarantees this is the top-level instance. - google::protobuf::Message* mutable_message = const_cast(message); - const google::protobuf::FieldDescriptor* field_descriptor = descriptor_->FindFieldByName(field_name.data()); CEL_RETURN_IF_ERROR( @@ -340,11 +346,11 @@ absl::Status ProtoMessageTypeAdapter::SetField( absl::StatusOr ProtoMessageTypeAdapter::AdaptFromWellKnownType( cel::MemoryManager& memory_manager, - CelValue::MessageWrapper instance) const { + CelValue::MessageWrapper::Builder instance) const { // Assume proto arena implementation if this provider is used. google::protobuf::Arena* arena = cel::extensions::ProtoMemoryManager::CastToProtoArena(memory_manager); - CEL_ASSIGN_OR_RETURN(const google::protobuf::Message* message, + CEL_ASSIGN_OR_RETURN(google::protobuf::Message * message, UnwrapMessage(instance, "AdaptFromWellKnownType")); return internal::UnwrapMessageToValue(message, &MessageCelValueFactory, arena); diff --git a/eval/public/structs/proto_message_type_adapter.h b/eval/public/structs/proto_message_type_adapter.h index 5282a6119..d56540e3e 100644 --- a/eval/public/structs/proto_message_type_adapter.h +++ b/eval/public/structs/proto_message_type_adapter.h @@ -36,19 +36,19 @@ class ProtoMessageTypeAdapter : public LegacyTypeAccessApis, ~ProtoMessageTypeAdapter() override = default; - absl::StatusOr NewInstance( + absl::StatusOr NewInstance( cel::MemoryManager& memory_manager) const override; bool DefinesField(absl::string_view field_name) const override; - absl::Status SetField(absl::string_view field_name, const CelValue& value, - - cel::MemoryManager& memory_manager, - CelValue::MessageWrapper& instance) const override; + absl::Status SetField( + absl::string_view field_name, const CelValue& value, + cel::MemoryManager& memory_manager, + CelValue::MessageWrapper::Builder& instance) const override; absl::StatusOr AdaptFromWellKnownType( cel::MemoryManager& memory_manager, - CelValue::MessageWrapper instance) const override; + CelValue::MessageWrapper::Builder instance) const override; absl::StatusOr GetField( absl::string_view field_name, const CelValue::MessageWrapper& instance, diff --git a/eval/public/structs/proto_message_type_adapter_test.cc b/eval/public/structs/proto_message_type_adapter_test.cc index f31ec9bb9..b53406dfd 100644 --- a/eval/public/structs/proto_message_type_adapter_test.cc +++ b/eval/public/structs/proto_message_type_adapter_test.cc @@ -437,7 +437,7 @@ TEST(ProtoMessageTypeAdapter, NewInstance) { google::protobuf::MessageFactory::generated_factory()); ProtoMemoryManager manager(&arena); - ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper result, + ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper::Builder result, adapter.NewInstance(manager)); EXPECT_THAT(result.message_ptr(), EqualsProto(TestMessage::default_instance())); @@ -485,7 +485,7 @@ TEST(ProtoMessageTypeAdapter, SetFieldSingular) { google::protobuf::MessageFactory::generated_factory()); ProtoMemoryManager manager(&arena); - ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper value, + ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper::Builder value, adapter.NewInstance(manager)); ASSERT_OK(adapter.SetField("int64_value", CelValue::CreateInt64(10), manager, @@ -513,7 +513,7 @@ TEST(ProtoMessageTypeAdapter, SetFieldMap) { CelValue value_to_set = CelValue::CreateMap(&builder); - ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper instance, + ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper::Builder instance, adapter.NewInstance(manager)); ASSERT_OK( @@ -536,7 +536,7 @@ TEST(ProtoMessageTypeAdapter, SetFieldRepeated) { ContainerBackedListImpl list( {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); CelValue value_to_set = CelValue::CreateList(&list); - ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper instance, + ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper::Builder instance, adapter.NewInstance(manager)); ASSERT_OK(adapter.SetField("int64_list", value_to_set, manager, instance)); @@ -555,7 +555,7 @@ TEST(ProtoMessageTypeAdapter, SetFieldNotAField) { google::protobuf::MessageFactory::generated_factory()); ProtoMemoryManager manager(&arena); - ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper instance, + ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper::Builder instance, adapter.NewInstance(manager)); ASSERT_THAT(adapter.SetField("not_a_field", CelValue::CreateInt64(10), @@ -584,7 +584,7 @@ TEST(ProtoMesssageTypeAdapter, SetFieldWrongType) { CelValue int_value = CelValue::CreateInt64(42); - ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper instance, + ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper::Builder instance, adapter.NewInstance(manager)); EXPECT_THAT(adapter.SetField("int64_value", map_value, manager, instance), @@ -613,8 +613,8 @@ TEST(ProtoMesssageTypeAdapter, SetFieldNotAMessage) { ProtoMemoryManager manager(&arena); CelValue int_value = CelValue::CreateInt64(42); - CelValue::MessageWrapper instance( - static_cast(nullptr), nullptr); + CelValue::MessageWrapper::Builder instance( + static_cast(nullptr)); EXPECT_THAT(adapter.SetField("int64_value", int_value, manager, instance), StatusIs(absl::StatusCode::kInternal)); @@ -629,8 +629,8 @@ TEST(ProtoMesssageTypeAdapter, SetFieldNullMessage) { ProtoMemoryManager manager(&arena); CelValue int_value = CelValue::CreateInt64(42); - CelValue::MessageWrapper instance( - static_cast(nullptr), nullptr); + CelValue::MessageWrapper::Builder instance( + static_cast(nullptr)); EXPECT_THAT(adapter.SetField("int64_value", int_value, manager, instance), StatusIs(absl::StatusCode::kInternal)); @@ -644,7 +644,7 @@ TEST(ProtoMessageTypeAdapter, AdaptFromWellKnownType) { google::protobuf::MessageFactory::generated_factory()); ProtoMemoryManager manager(&arena); - ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper instance, + ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper::Builder instance, adapter.NewInstance(manager)); ASSERT_OK( adapter.SetField("value", CelValue::CreateInt64(42), manager, instance)); @@ -663,7 +663,7 @@ TEST(ProtoMessageTypeAdapter, AdaptFromWellKnownTypeUnspecial) { google::protobuf::MessageFactory::generated_factory()); ProtoMemoryManager manager(&arena); - ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper instance, + ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper::Builder instance, adapter.NewInstance(manager)); ASSERT_OK(adapter.SetField("int64_value", CelValue::CreateInt64(42), manager, @@ -683,8 +683,8 @@ TEST(ProtoMessageTypeAdapter, AdaptFromWellKnownTypeNotAMessageError) { google::protobuf::MessageFactory::generated_factory()); ProtoMemoryManager manager(&arena); - CelValue::MessageWrapper instance( - static_cast(nullptr), nullptr); + CelValue::MessageWrapper::Builder instance( + static_cast(nullptr)); // Interpreter guaranteed to call this with a message type, otherwise, // something has broken. diff --git a/eval/public/structs/protobuf_descriptor_type_provider_test.cc b/eval/public/structs/protobuf_descriptor_type_provider_test.cc index 39d153026..00c5e09e3 100644 --- a/eval/public/structs/protobuf_descriptor_type_provider_test.cc +++ b/eval/public/structs/protobuf_descriptor_type_provider_test.cc @@ -35,7 +35,7 @@ TEST(ProtobufDescriptorProvider, Basic) { ASSERT_TRUE(type_adapter->mutation_apis() != nullptr); ASSERT_TRUE(type_adapter->mutation_apis()->DefinesField("value")); - ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper value, + ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper::Builder value, type_adapter->mutation_apis()->NewInstance(manager)); ASSERT_OK(type_adapter->mutation_apis()->SetField( From 14df906a60954e42f37b1caefe36b486fb1317d7 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Tue, 3 May 2022 22:27:57 +0000 Subject: [PATCH 095/111] Add portable version of the function adapter helper. This version doesn't include any reflection utilities or assume that values are implementing the full message interface. PiperOrigin-RevId: 446305046 --- eval/public/BUILD | 31 ++++ eval/public/cel_function_adapter.h | 9 +- eval/public/cel_function_adapter_impl.h | 4 +- eval/public/portable_cel_function_adapter.h | 37 +++++ .../portable_cel_function_adapter_test.cc | 150 ++++++++++++++++++ 5 files changed, 222 insertions(+), 9 deletions(-) create mode 100644 eval/public/portable_cel_function_adapter.h create mode 100644 eval/public/portable_cel_function_adapter_test.cc diff --git a/eval/public/BUILD b/eval/public/BUILD index 06d47afb1..713de9e39 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -221,6 +221,37 @@ cc_library( ], ) +cc_library( + name = "portable_cel_function_adapter", + hdrs = [ + "portable_cel_function_adapter.h", + ], + deps = [ + ":cel_function", + ":cel_function_adapter_impl", + ":cel_function_registry", + ":cel_value", + "//internal:status_macros", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "portable_cel_function_adapter_test", + size = "small", + srcs = [ + "portable_cel_function_adapter_test.cc", + ], + deps = [ + ":portable_cel_function_adapter", + "//internal:status_macros", + "//internal:testing", + ], +) + cc_library( name = "cel_function_provider", srcs = [ diff --git a/eval/public/cel_function_adapter.h b/eval/public/cel_function_adapter.h index 9c5bdb18e..2df1229dc 100644 --- a/eval/public/cel_function_adapter.h +++ b/eval/public/cel_function_adapter.h @@ -7,14 +7,9 @@ #include "google/protobuf/message.h" #include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "eval/public/cel_function.h" #include "eval/public/cel_function_adapter_impl.h" -#include "eval/public/cel_function_registry.h" #include "eval/public/cel_value.h" #include "eval/public/structs/cel_proto_wrapper.h" -#include "internal/status_macros.h" namespace google::api::expr::runtime { @@ -23,12 +18,12 @@ namespace internal { // A type code matcher that adds support for google::protobuf::Message. struct ProtoAdapterTypeCodeMatcher { template - constexpr absl::optional type_code() { + constexpr std::optional type_code() { return internal::TypeCodeMatcher().type_code(); } template <> - constexpr absl::optional type_code() { + constexpr std::optional type_code() { return CelValue::Type::kMessage; } }; diff --git a/eval/public/cel_function_adapter_impl.h b/eval/public/cel_function_adapter_impl.h index 9e669a21a..59b5872a5 100644 --- a/eval/public/cel_function_adapter_impl.h +++ b/eval/public/cel_function_adapter_impl.h @@ -34,7 +34,7 @@ namespace internal { // Used for CEL type deduction based on C++ native type. struct TypeCodeMatcher { template - constexpr absl::optional type_code() { + constexpr std::optional type_code() { int index = CelValue::IndexOf::value; if (index < 0) return {}; CelValue::Type arg_type = static_cast(index); @@ -47,7 +47,7 @@ struct TypeCodeMatcher { // A bit of a trick - to pass Any kind of value, we use generic CelValue // parameters. template <> - constexpr absl::optional type_code() { + constexpr std::optional type_code() { return CelValue::Type::kAny; } }; diff --git a/eval/public/portable_cel_function_adapter.h b/eval/public/portable_cel_function_adapter.h new file mode 100644 index 000000000..840fb86de --- /dev/null +++ b/eval/public/portable_cel_function_adapter.h @@ -0,0 +1,37 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_PORTABLE_CEL_FUNCTION_ADAPTER_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_PORTABLE_CEL_FUNCTION_ADAPTER_H_ + +#include "eval/public/cel_function_adapter_impl.h" + +namespace google::api::expr::runtime { + +// Portable version of the FunctionAdapter template utility. +// +// The PortableFunctionAdapter variation provides the same interface, +// but doesn't support unwrapping google::protobuf::Message values. See documentation on +// Function adapter for example usage. +// +// Most users should prefer using the standard FunctionAdapter. +template +using PortableFunctionAdapter = + internal::FunctionAdapter; + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_PORTABLE_CEL_FUNCTION_ADAPTER_H_ diff --git a/eval/public/portable_cel_function_adapter_test.cc b/eval/public/portable_cel_function_adapter_test.cc new file mode 100644 index 000000000..ebe69157b --- /dev/null +++ b/eval/public/portable_cel_function_adapter_test.cc @@ -0,0 +1,150 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/portable_cel_function_adapter.h" + +#include +#include +#include + +#include "internal/status_macros.h" +#include "internal/testing.h" + +namespace google::api::expr::runtime { + +namespace { + +TEST(PortableCelFunctionAdapterTest, TestAdapterNoArg) { + auto func = [](google::protobuf::Arena*) -> int64_t { return 100; }; + ASSERT_OK_AND_ASSIGN(auto cel_func, (PortableFunctionAdapter::Create( + "const", false, func))); + + absl::Span args; + CelValue result = CelValue::CreateNull(); + google::protobuf::Arena arena; + ASSERT_OK(cel_func->Evaluate(args, &result, &arena)); + // Obvious failure, for educational purposes only. + ASSERT_TRUE(result.IsInt64()); +} + +TEST(PortableCelFunctionAdapterTest, TestAdapterOneArg) { + std::function func = + [](google::protobuf::Arena* arena, int64_t i) -> int64_t { return i + 1; }; + ASSERT_OK_AND_ASSIGN( + auto cel_func, + (PortableFunctionAdapter::Create("_++_", false, func))); + + std::vector args_vec; + args_vec.push_back(CelValue::CreateInt64(99)); + + CelValue result = CelValue::CreateNull(); + google::protobuf::Arena arena; + + absl::Span args(&args_vec[0], args_vec.size()); + ASSERT_OK(cel_func->Evaluate(args, &result, &arena)); + ASSERT_TRUE(result.IsInt64()); + EXPECT_EQ(result.Int64OrDie(), 100); +} + +TEST(PortableCelFunctionAdapterTest, TestAdapterTwoArgs) { + auto func = [](google::protobuf::Arena* arena, int64_t i, int64_t j) -> int64_t { + return i + j; + }; + ASSERT_OK_AND_ASSIGN(auto cel_func, + (PortableFunctionAdapter::Create( + "_++_", false, func))); + + std::vector args_vec; + args_vec.push_back(CelValue::CreateInt64(20)); + args_vec.push_back(CelValue::CreateInt64(22)); + + CelValue result = CelValue::CreateNull(); + google::protobuf::Arena arena; + + absl::Span args(&args_vec[0], args_vec.size()); + ASSERT_OK(cel_func->Evaluate(args, &result, &arena)); + ASSERT_TRUE(result.IsInt64()); + EXPECT_EQ(result.Int64OrDie(), 42); +} + +using StringHolder = CelValue::StringHolder; + +TEST(PortableCelFunctionAdapterTest, TestAdapterThreeArgs) { + auto func = [](google::protobuf::Arena* arena, StringHolder s1, StringHolder s2, + StringHolder s3) -> StringHolder { + std::string value = absl::StrCat(s1.value(), s2.value(), s3.value()); + + return StringHolder( + google::protobuf::Arena::Create(arena, std::move(value))); + }; + ASSERT_OK_AND_ASSIGN( + auto cel_func, + (PortableFunctionAdapter::Create("concat", false, func))); + + std::string test1 = "1"; + std::string test2 = "2"; + std::string test3 = "3"; + + std::vector args_vec; + args_vec.push_back(CelValue::CreateString(&test1)); + args_vec.push_back(CelValue::CreateString(&test2)); + args_vec.push_back(CelValue::CreateString(&test3)); + + CelValue result = CelValue::CreateNull(); + google::protobuf::Arena arena; + + absl::Span args(&args_vec[0], args_vec.size()); + ASSERT_OK(cel_func->Evaluate(args, &result, &arena)); + ASSERT_TRUE(result.IsString()); + EXPECT_EQ(result.StringOrDie().value(), "123"); +} + +TEST(PortableCelFunctionAdapterTest, TestTypeDeductionForCelValueBasicTypes) { + auto func = [](google::protobuf::Arena* arena, bool, int64_t, uint64_t, double, + CelValue::StringHolder, CelValue::BytesHolder, + CelValue::MessageWrapper, absl::Duration, absl::Time, + const CelList*, const CelMap*, + const CelError*) -> bool { return false; }; + ASSERT_OK_AND_ASSIGN( + auto cel_func, + (PortableFunctionAdapter::Create("dummy_func", false, + func))); + auto descriptor = cel_func->descriptor(); + + EXPECT_EQ(descriptor.receiver_style(), false); + EXPECT_EQ(descriptor.name(), "dummy_func"); + + int pos = 0; + ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kBool); + ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kInt64); + ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kUint64); + ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kDouble); + ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kString); + ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kBytes); + ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kMessage); + ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kDuration); + ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kTimestamp); + ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kList); + ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kMap); + ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kError); +} + +} // namespace + +} // namespace google::api::expr::runtime From 119cb4342587081f8987e72fde7ae2ba29d5355d Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 4 May 2022 16:58:02 +0000 Subject: [PATCH 096/111] Consolidate flat expr builder setup into the portable implementation. PiperOrigin-RevId: 446484239 --- eval/public/BUILD | 1 + eval/public/cel_expr_builder_factory.cc | 47 +++---------------- .../portable_cel_expr_builder_factory.cc | 2 - .../portable_cel_expr_builder_factory.h | 7 ++- 4 files changed, 13 insertions(+), 44 deletions(-) diff --git a/eval/public/BUILD b/eval/public/BUILD index 713de9e39..c79a1f6f0 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -487,6 +487,7 @@ cc_library( deps = [ ":cel_expression", ":cel_options", + ":portable_cel_expr_builder_factory", "//eval/compiler:flat_expr_builder", "//eval/public/structs:proto_message_type_adapter", "//eval/public/structs:protobuf_descriptor_type_provider", diff --git a/eval/public/cel_expr_builder_factory.cc b/eval/public/cel_expr_builder_factory.cc index 17775e5aa..dbb232689 100644 --- a/eval/public/cel_expr_builder_factory.cc +++ b/eval/public/cel_expr_builder_factory.cc @@ -23,9 +23,11 @@ #include "absl/status/status.h" #include "eval/compiler/flat_expr_builder.h" #include "eval/public/cel_options.h" +#include "eval/public/portable_cel_expr_builder_factory.h" #include "eval/public/structs/proto_message_type_adapter.h" #include "eval/public/structs/protobuf_descriptor_type_provider.h" #include "internal/proto_util.h" + namespace google::api::expr::runtime { namespace { @@ -45,48 +47,11 @@ std::unique_ptr CreateCelExpressionBuilder( GOOGLE_LOG(WARNING) << "Failed to validate standard message types: " << s; return nullptr; } - auto builder = std::make_unique(); - builder->GetTypeRegistry()->RegisterTypeProvider( - std::make_unique(descriptor_pool, - message_factory)); - // LINT.IfChange - builder->set_shortcircuiting(options.short_circuiting); - builder->set_constant_folding(options.constant_folding, - options.constant_arena); - builder->set_enable_comprehension(options.enable_comprehension); - builder->set_enable_comprehension_list_append( - options.enable_comprehension_list_append); - builder->set_comprehension_max_iterations( - options.comprehension_max_iterations); - builder->set_fail_on_warnings(options.fail_on_warnings); - builder->set_enable_qualified_type_identifiers( - options.enable_qualified_type_identifiers); - builder->set_enable_comprehension_vulnerability_check( - options.enable_comprehension_vulnerability_check); - builder->set_enable_null_coercion(options.enable_null_to_message_coercion); - builder->set_enable_wrapper_type_null_unboxing( - options.enable_empty_wrapper_null_unboxing); - builder->set_enable_heterogeneous_equality( - options.enable_heterogeneous_equality); - builder->set_enable_qualified_identifier_rewrites( - options.enable_qualified_identifier_rewrites); - - switch (options.unknown_processing) { - case UnknownProcessingOptions::kAttributeAndFunction: - builder->set_enable_unknown_function_results(true); - builder->set_enable_unknowns(true); - break; - case UnknownProcessingOptions::kAttributeOnly: - builder->set_enable_unknowns(true); - break; - case UnknownProcessingOptions::kDisabled: - break; - } - - builder->set_enable_missing_attribute_errors( - options.enable_missing_attribute_errors); - // LINT.ThenChange(//depot/google3/eval/public/portable_cel_expr_builder_factory.cc) + auto builder = + CreatePortableExprBuilder(std::make_unique( + descriptor_pool, message_factory), + options); return builder; } diff --git a/eval/public/portable_cel_expr_builder_factory.cc b/eval/public/portable_cel_expr_builder_factory.cc index 268bd1b35..025982ff9 100644 --- a/eval/public/portable_cel_expr_builder_factory.cc +++ b/eval/public/portable_cel_expr_builder_factory.cc @@ -36,7 +36,6 @@ std::unique_ptr CreatePortableExprBuilder( } auto builder = std::make_unique(); builder->GetTypeRegistry()->RegisterTypeProvider(std::move(type_provider)); - // LINT.IfChange builder->set_shortcircuiting(options.short_circuiting); builder->set_constant_folding(options.constant_folding, options.constant_arena); @@ -72,7 +71,6 @@ std::unique_ptr CreatePortableExprBuilder( builder->set_enable_missing_attribute_errors( options.enable_missing_attribute_errors); - // LINT.ThenChange(//depot/google3/eval/public/cel_expr_builder_factory.cc) return builder; } diff --git a/eval/public/portable_cel_expr_builder_factory.h b/eval/public/portable_cel_expr_builder_factory.h index 84cd86d82..b31b51ccf 100644 --- a/eval/public/portable_cel_expr_builder_factory.h +++ b/eval/public/portable_cel_expr_builder_factory.h @@ -26,7 +26,12 @@ namespace api { namespace expr { namespace runtime { -// Factory creates CelExpressionBuilder implementation for public use. +// Factory for initializing a CelExpressionBuilder implementation for public +// use. +// +// This version does not include any message type information, instead deferring +// to the type_provider argument. type_provider is guaranteed to be the first +// type provider in the type registry. std::unique_ptr CreatePortableExprBuilder( std::unique_ptr type_provider, const InterpreterOptions& options = InterpreterOptions()); From 201937b8cef52074fc99cba22f7b4906a21f293a Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 4 May 2022 18:52:16 +0000 Subject: [PATCH 097/111] Update builtin function registrar to use portable version of cel function adapter. PiperOrigin-RevId: 446514900 --- eval/public/BUILD | 6 +- eval/public/builtin_func_registrar.cc | 667 ++++++++++-------- eval/public/comparison_functions.cc | 61 +- eval/public/comparison_functions.h | 3 +- .../portable_cel_expr_builder_factory_test.cc | 53 +- 5 files changed, 434 insertions(+), 356 deletions(-) diff --git a/eval/public/BUILD b/eval/public/BUILD index c79a1f6f0..4f28b8f7b 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -285,12 +285,12 @@ cc_library( deps = [ ":cel_builtins", ":cel_function", - ":cel_function_adapter", ":cel_function_registry", ":cel_number", ":cel_options", ":cel_value", ":comparison_functions", + ":portable_cel_function_adapter", "//eval/eval:mutable_list_impl", "//eval/public/containers:container_backed_list_impl", "//internal:casts", @@ -318,12 +318,12 @@ cc_library( ], deps = [ ":cel_builtins", - ":cel_function_adapter", ":cel_function_registry", ":cel_number", ":cel_options", ":cel_value", ":message_wrapper", + ":portable_cel_function_adapter", "//eval/eval:mutable_list_impl", "//eval/public/structs:legacy_type_adapter", "//eval/public/structs:legacy_type_info_apis", @@ -336,7 +336,6 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_absl//absl/types:optional", - "@com_google_protobuf//:protobuf", "@com_googlesource_code_re2//:re2", ], ) @@ -1002,6 +1001,7 @@ cc_test( srcs = ["portable_cel_expr_builder_factory_test.cc"], deps = [ ":activation", + ":builtin_func_registrar", ":cel_options", ":cel_value", ":portable_cel_expr_builder_factory", diff --git a/eval/public/builtin_func_registrar.cc b/eval/public/builtin_func_registrar.cc index 75600d889..613522a4d 100644 --- a/eval/public/builtin_func_registrar.cc +++ b/eval/public/builtin_func_registrar.cc @@ -21,7 +21,6 @@ #include #include -#include "google/protobuf/map_field.h" #include "absl/status/status.h" #include "absl/strings/match.h" #include "absl/strings/numbers.h" @@ -32,13 +31,13 @@ #include "absl/types/optional.h" #include "eval/eval/mutable_list_impl.h" #include "eval/public/cel_builtins.h" -#include "eval/public/cel_function_adapter.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_number.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/comparison_functions.h" #include "eval/public/containers/container_backed_list_impl.h" +#include "eval/public/portable_cel_function_adapter.h" #include "internal/casts.h" #include "internal/overflow.h" #include "internal/proto_time_encoding.h" @@ -201,19 +200,19 @@ CelValue Modulo(Arena* arena, uint64_t v0, uint64_t v1) { template absl::Status RegisterArithmeticFunctionsForType(CelFunctionRegistry* registry) { absl::Status status = - FunctionAdapter::CreateAndRegister( + PortableFunctionAdapter::CreateAndRegister( builtin::kAdd, false, Add, registry); if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter::CreateAndRegister( builtin::kSubtract, false, Sub, registry); if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter::CreateAndRegister( builtin::kMultiply, false, Mul, registry); if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter::CreateAndRegister( builtin::kDivide, false, Div, registry); return status; } @@ -526,30 +525,34 @@ absl::Status RegisterSetMembershipFunctions(CelFunctionRegistry* registry, for (absl::string_view op : in_operators) { if (options.enable_heterogeneous_equality) { CEL_RETURN_IF_ERROR( - (FunctionAdapter:: + (PortableFunctionAdapter:: CreateAndRegister(op, false, &HeterogeneousEqualityIn, registry))); } else { CEL_RETURN_IF_ERROR( - (FunctionAdapter::CreateAndRegister( - op, false, In, registry))); + (PortableFunctionAdapter:: + CreateAndRegister(op, false, In, registry))); CEL_RETURN_IF_ERROR( - (FunctionAdapter::CreateAndRegister( - op, false, In, registry))); + (PortableFunctionAdapter:: + CreateAndRegister(op, false, In, registry))); CEL_RETURN_IF_ERROR( - (FunctionAdapter::CreateAndRegister( - op, false, In, registry))); + (PortableFunctionAdapter:: + CreateAndRegister(op, false, In, registry))); CEL_RETURN_IF_ERROR( - (FunctionAdapter::CreateAndRegister( - op, false, In, registry))); + (PortableFunctionAdapter:: + CreateAndRegister(op, false, In, registry))); CEL_RETURN_IF_ERROR( - (FunctionAdapter:: - CreateAndRegister(op, false, In, - registry))); + (PortableFunctionAdapter< + bool, CelValue::StringHolder, + const CelList*>::CreateAndRegister(op, false, + In, + registry))); CEL_RETURN_IF_ERROR( - (FunctionAdapter:: - CreateAndRegister(op, false, In, - registry))); + (PortableFunctionAdapter< + bool, CelValue::BytesHolder, + const CelList*>::CreateAndRegister(op, false, + In, + registry))); } } } @@ -647,31 +650,37 @@ absl::Status RegisterSetMembershipFunctions(CelFunctionRegistry* registry, }; for (auto op : in_operators) { - auto status = - FunctionAdapter::CreateAndRegister(op, false, - stringKeyInSet, - registry); + auto status = PortableFunctionAdapter< + CelValue, CelValue::StringHolder, + const CelMap*>::CreateAndRegister(op, false, stringKeyInSet, registry); if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( - op, false, boolKeyInSet, registry); + status = + PortableFunctionAdapter::CreateAndRegister(op, false, + boolKeyInSet, + registry); if (!status.ok()) return status; status = - FunctionAdapter::CreateAndRegister( - op, false, intKeyInSet, registry); + PortableFunctionAdapter::CreateAndRegister(op, false, + intKeyInSet, + registry); if (!status.ok()) return status; status = - FunctionAdapter::CreateAndRegister( - op, false, uintKeyInSet, registry); + PortableFunctionAdapter::CreateAndRegister(op, false, + uintKeyInSet, + registry); if (!status.ok()) return status; if (options.enable_heterogeneous_equality) { - status = - FunctionAdapter::CreateAndRegister( - op, false, doubleKeyInSet, registry); + status = PortableFunctionAdapter< + CelValue, double, const CelMap*>::CreateAndRegister(op, false, + doubleKeyInSet, + registry); if (!status.ok()) return status; } } @@ -680,52 +689,58 @@ absl::Status RegisterSetMembershipFunctions(CelFunctionRegistry* registry, absl::Status RegisterStringFunctions(CelFunctionRegistry* registry, const InterpreterOptions& options) { - auto status = - FunctionAdapter:: - CreateAndRegister(builtin::kStringContains, false, StringContains, - registry); + auto status = PortableFunctionAdapter< + bool, CelValue::StringHolder, + CelValue::StringHolder>::CreateAndRegister(builtin::kStringContains, + false, StringContains, + registry); if (!status.ok()) return status; - status = - FunctionAdapter:: - CreateAndRegister(builtin::kStringContains, true, StringContains, - registry); + status = PortableFunctionAdapter< + bool, CelValue::StringHolder, + CelValue::StringHolder>::CreateAndRegister(builtin::kStringContains, true, + StringContains, registry); if (!status.ok()) return status; - status = - FunctionAdapter:: - CreateAndRegister(builtin::kStringEndsWith, false, StringEndsWith, - registry); + status = PortableFunctionAdapter< + bool, CelValue::StringHolder, + CelValue::StringHolder>::CreateAndRegister(builtin::kStringEndsWith, + false, StringEndsWith, + registry); if (!status.ok()) return status; - status = - FunctionAdapter:: - CreateAndRegister(builtin::kStringEndsWith, true, StringEndsWith, - registry); + status = PortableFunctionAdapter< + bool, CelValue::StringHolder, + CelValue::StringHolder>::CreateAndRegister(builtin::kStringEndsWith, true, + StringEndsWith, registry); if (!status.ok()) return status; - status = - FunctionAdapter:: - CreateAndRegister(builtin::kStringStartsWith, false, StringStartsWith, - registry); + status = PortableFunctionAdapter< + bool, CelValue::StringHolder, + CelValue::StringHolder>::CreateAndRegister(builtin::kStringStartsWith, + false, StringStartsWith, + registry); if (!status.ok()) return status; - return FunctionAdapter:: - CreateAndRegister(builtin::kStringStartsWith, true, StringStartsWith, - registry); + return PortableFunctionAdapter< + bool, CelValue::StringHolder, + CelValue::StringHolder>::CreateAndRegister(builtin::kStringStartsWith, + true, StringStartsWith, + registry); } absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, const InterpreterOptions& options) { - auto status = FunctionAdapter:: - CreateAndRegister( - builtin::kFullYear, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetFullYear(arena, ts, tz.value()); }, - registry); + auto status = + PortableFunctionAdapter:: + CreateAndRegister( + builtin::kFullYear, true, + [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) + -> CelValue { return GetFullYear(arena, ts, tz.value()); }, + registry); if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter::CreateAndRegister( builtin::kFullYear, true, [](Arena* arena, absl::Time ts) -> CelValue { return GetFullYear(arena, ts, ""); @@ -733,15 +748,16 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, registry); if (!status.ok()) return status; - status = FunctionAdapter:: - CreateAndRegister( - builtin::kMonth, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetMonth(arena, ts, tz.value()); }, - registry); + status = + PortableFunctionAdapter:: + CreateAndRegister( + builtin::kMonth, true, + [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) + -> CelValue { return GetMonth(arena, ts, tz.value()); }, + registry); if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter::CreateAndRegister( builtin::kMonth, true, [](Arena* arena, absl::Time ts) -> CelValue { return GetMonth(arena, ts, ""); @@ -749,15 +765,16 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, registry); if (!status.ok()) return status; - status = FunctionAdapter:: - CreateAndRegister( - builtin::kDayOfYear, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetDayOfYear(arena, ts, tz.value()); }, - registry); + status = + PortableFunctionAdapter:: + CreateAndRegister( + builtin::kDayOfYear, true, + [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) + -> CelValue { return GetDayOfYear(arena, ts, tz.value()); }, + registry); if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter::CreateAndRegister( builtin::kDayOfYear, true, [](Arena* arena, absl::Time ts) -> CelValue { return GetDayOfYear(arena, ts, ""); @@ -765,15 +782,16 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, registry); if (!status.ok()) return status; - status = FunctionAdapter:: - CreateAndRegister( - builtin::kDayOfMonth, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetDayOfMonth(arena, ts, tz.value()); }, - registry); + status = + PortableFunctionAdapter:: + CreateAndRegister( + builtin::kDayOfMonth, true, + [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) + -> CelValue { return GetDayOfMonth(arena, ts, tz.value()); }, + registry); if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter::CreateAndRegister( builtin::kDayOfMonth, true, [](Arena* arena, absl::Time ts) -> CelValue { return GetDayOfMonth(arena, ts, ""); @@ -781,15 +799,16 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, registry); if (!status.ok()) return status; - status = FunctionAdapter:: - CreateAndRegister( - builtin::kDate, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetDate(arena, ts, tz.value()); }, - registry); + status = + PortableFunctionAdapter:: + CreateAndRegister( + builtin::kDate, true, + [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) + -> CelValue { return GetDate(arena, ts, tz.value()); }, + registry); if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter::CreateAndRegister( builtin::kDate, true, [](Arena* arena, absl::Time ts) -> CelValue { return GetDate(arena, ts, ""); @@ -797,15 +816,16 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, registry); if (!status.ok()) return status; - status = FunctionAdapter:: - CreateAndRegister( - builtin::kDayOfWeek, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetDayOfWeek(arena, ts, tz.value()); }, - registry); + status = + PortableFunctionAdapter:: + CreateAndRegister( + builtin::kDayOfWeek, true, + [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) + -> CelValue { return GetDayOfWeek(arena, ts, tz.value()); }, + registry); if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter::CreateAndRegister( builtin::kDayOfWeek, true, [](Arena* arena, absl::Time ts) -> CelValue { return GetDayOfWeek(arena, ts, ""); @@ -813,15 +833,16 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, registry); if (!status.ok()) return status; - status = FunctionAdapter:: - CreateAndRegister( - builtin::kHours, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetHours(arena, ts, tz.value()); }, - registry); + status = + PortableFunctionAdapter:: + CreateAndRegister( + builtin::kHours, true, + [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) + -> CelValue { return GetHours(arena, ts, tz.value()); }, + registry); if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter::CreateAndRegister( builtin::kHours, true, [](Arena* arena, absl::Time ts) -> CelValue { return GetHours(arena, ts, ""); @@ -829,15 +850,16 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, registry); if (!status.ok()) return status; - status = FunctionAdapter:: - CreateAndRegister( - builtin::kMinutes, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetMinutes(arena, ts, tz.value()); }, - registry); + status = + PortableFunctionAdapter:: + CreateAndRegister( + builtin::kMinutes, true, + [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) + -> CelValue { return GetMinutes(arena, ts, tz.value()); }, + registry); if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter::CreateAndRegister( builtin::kMinutes, true, [](Arena* arena, absl::Time ts) -> CelValue { return GetMinutes(arena, ts, ""); @@ -845,15 +867,16 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, registry); if (!status.ok()) return status; - status = FunctionAdapter:: - CreateAndRegister( - builtin::kSeconds, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetSeconds(arena, ts, tz.value()); }, - registry); + status = + PortableFunctionAdapter:: + CreateAndRegister( + builtin::kSeconds, true, + [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) + -> CelValue { return GetSeconds(arena, ts, tz.value()); }, + registry); if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter::CreateAndRegister( builtin::kSeconds, true, [](Arena* arena, absl::Time ts) -> CelValue { return GetSeconds(arena, ts, ""); @@ -861,15 +884,18 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, registry); if (!status.ok()) return status; - status = FunctionAdapter:: - CreateAndRegister( - builtin::kMilliseconds, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetMilliseconds(arena, ts, tz.value()); }, - registry); + status = + PortableFunctionAdapter:: + CreateAndRegister( + builtin::kMilliseconds, true, + [](Arena* arena, absl::Time ts, + CelValue::StringHolder tz) -> CelValue { + return GetMilliseconds(arena, ts, tz.value()); + }, + registry); if (!status.ok()) return status; - return FunctionAdapter::CreateAndRegister( + return PortableFunctionAdapter::CreateAndRegister( builtin::kMilliseconds, true, [](Arena* arena, absl::Time ts) -> CelValue { return GetMilliseconds(arena, ts, ""); @@ -880,54 +906,57 @@ absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, absl::Status RegisterBytesConversionFunctions(CelFunctionRegistry* registry, const InterpreterOptions&) { // bytes -> bytes - auto status = FunctionAdapter:: + auto status = + PortableFunctionAdapter:: + CreateAndRegister( + builtin::kBytes, false, + [](Arena*, CelValue::BytesHolder value) -> CelValue::BytesHolder { + return value; + }, + registry); + if (!status.ok()) return status; + + // string -> bytes + return PortableFunctionAdapter:: CreateAndRegister( builtin::kBytes, false, - [](Arena*, CelValue::BytesHolder value) -> CelValue::BytesHolder { - return value; + [](Arena* arena, CelValue::StringHolder value) -> CelValue { + return CelValue::CreateBytesView(value.value()); }, registry); - if (!status.ok()) return status; - - // string -> bytes - return FunctionAdapter::CreateAndRegister( - builtin::kBytes, false, - [](Arena* arena, CelValue::StringHolder value) -> CelValue { - return CelValue::CreateBytesView(value.value()); - }, - registry); } absl::Status RegisterDoubleConversionFunctions(CelFunctionRegistry* registry, const InterpreterOptions&) { // double -> double - auto status = FunctionAdapter::CreateAndRegister( + auto status = PortableFunctionAdapter::CreateAndRegister( builtin::kDouble, false, [](Arena*, double v) { return v; }, registry); if (!status.ok()) return status; // int -> double - status = FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter::CreateAndRegister( builtin::kDouble, false, [](Arena*, int64_t v) { return static_cast(v); }, registry); if (!status.ok()) return status; // string -> double - status = FunctionAdapter::CreateAndRegister( - builtin::kDouble, false, - [](Arena* arena, CelValue::StringHolder s) { - double result; - if (absl::SimpleAtod(s.value(), &result)) { - return CelValue::CreateDouble(result); - } else { - return CreateErrorValue(arena, "cannot convert string to double", - absl::StatusCode::kInvalidArgument); - } - }, - registry); + status = PortableFunctionAdapter:: + CreateAndRegister( + builtin::kDouble, false, + [](Arena* arena, CelValue::StringHolder s) { + double result; + if (absl::SimpleAtod(s.value(), &result)) { + return CelValue::CreateDouble(result); + } else { + return CreateErrorValue(arena, "cannot convert string to double", + absl::StatusCode::kInvalidArgument); + } + }, + registry); if (!status.ok()) return status; // uint -> double - return FunctionAdapter::CreateAndRegister( + return PortableFunctionAdapter::CreateAndRegister( builtin::kDouble, false, [](Arena*, uint64_t v) { return static_cast(v); }, registry); } @@ -935,13 +964,13 @@ absl::Status RegisterDoubleConversionFunctions(CelFunctionRegistry* registry, absl::Status RegisterIntConversionFunctions(CelFunctionRegistry* registry, const InterpreterOptions&) { // bool -> int - auto status = FunctionAdapter::CreateAndRegister( + auto status = PortableFunctionAdapter::CreateAndRegister( builtin::kInt, false, [](Arena*, bool v) { return static_cast(v); }, registry); if (!status.ok()) return status; // double -> int - status = FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter::CreateAndRegister( builtin::kInt, false, [](Arena* arena, double v) { auto conv = cel::internal::CheckedDoubleToInt64(v); @@ -954,32 +983,33 @@ absl::Status RegisterIntConversionFunctions(CelFunctionRegistry* registry, if (!status.ok()) return status; // int -> int - status = FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter::CreateAndRegister( builtin::kInt, false, [](Arena*, int64_t v) { return v; }, registry); if (!status.ok()) return status; // string -> int - status = FunctionAdapter::CreateAndRegister( - builtin::kInt, false, - [](Arena* arena, CelValue::StringHolder s) { - int64_t result; - if (!absl::SimpleAtoi(s.value(), &result)) { - return CreateErrorValue(arena, "cannot convert string to int", - absl::StatusCode::kInvalidArgument); - } - return CelValue::CreateInt64(result); - }, - registry); + status = PortableFunctionAdapter:: + CreateAndRegister( + builtin::kInt, false, + [](Arena* arena, CelValue::StringHolder s) { + int64_t result; + if (!absl::SimpleAtoi(s.value(), &result)) { + return CreateErrorValue(arena, "cannot convert string to int", + absl::StatusCode::kInvalidArgument); + } + return CelValue::CreateInt64(result); + }, + registry); if (!status.ok()) return status; // time -> int - status = FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter::CreateAndRegister( builtin::kInt, false, [](Arena*, absl::Time t) { return absl::ToUnixSeconds(t); }, registry); if (!status.ok()) return status; // uint -> int - return FunctionAdapter::CreateAndRegister( + return PortableFunctionAdapter::CreateAndRegister( builtin::kInt, false, [](Arena* arena, uint64_t v) { auto conv = cel::internal::CheckedUint64ToInt64(v); @@ -998,8 +1028,8 @@ absl::Status RegisterStringConversionFunctions( return absl::OkStatus(); } - auto status = - FunctionAdapter::CreateAndRegister( + auto status = PortableFunctionAdapter:: + CreateAndRegister( builtin::kString, false, [](Arena* arena, CelValue::BytesHolder value) -> CelValue { if (::cel::internal::Utf8IsValid(value.value())) { @@ -1012,47 +1042,50 @@ absl::Status RegisterStringConversionFunctions( if (!status.ok()) return status; // double -> string - status = FunctionAdapter::CreateAndRegister( - builtin::kString, false, - [](Arena* arena, double value) -> CelValue::StringHolder { - return CelValue::StringHolder( - Arena::Create(arena, absl::StrCat(value))); - }, - registry); + status = PortableFunctionAdapter:: + CreateAndRegister( + builtin::kString, false, + [](Arena* arena, double value) -> CelValue::StringHolder { + return CelValue::StringHolder( + Arena::Create(arena, absl::StrCat(value))); + }, + registry); if (!status.ok()) return status; // int -> string - status = FunctionAdapter::CreateAndRegister( - builtin::kString, false, - [](Arena* arena, int64_t value) -> CelValue::StringHolder { - return CelValue::StringHolder( - Arena::Create(arena, absl::StrCat(value))); - }, - registry); - if (!status.ok()) return status; - - // string -> string - status = FunctionAdapter:: + status = PortableFunctionAdapter:: CreateAndRegister( builtin::kString, false, - [](Arena*, CelValue::StringHolder value) -> CelValue::StringHolder { - return value; + [](Arena* arena, int64_t value) -> CelValue::StringHolder { + return CelValue::StringHolder( + Arena::Create(arena, absl::StrCat(value))); }, registry); if (!status.ok()) return status; + // string -> string + status = + PortableFunctionAdapter:: + CreateAndRegister( + builtin::kString, false, + [](Arena*, CelValue::StringHolder value) + -> CelValue::StringHolder { return value; }, + registry); + if (!status.ok()) return status; + // uint -> string - status = FunctionAdapter::CreateAndRegister( - builtin::kString, false, - [](Arena* arena, uint64_t value) -> CelValue::StringHolder { - return CelValue::StringHolder( - Arena::Create(arena, absl::StrCat(value))); - }, - registry); + status = PortableFunctionAdapter:: + CreateAndRegister( + builtin::kString, false, + [](Arena* arena, uint64_t value) -> CelValue::StringHolder { + return CelValue::StringHolder( + Arena::Create(arena, absl::StrCat(value))); + }, + registry); if (!status.ok()) return status; // duration -> string - status = FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter::CreateAndRegister( builtin::kString, false, [](Arena* arena, absl::Duration value) -> CelValue { auto encode = EncodeDurationToString(value); @@ -1066,7 +1099,7 @@ absl::Status RegisterStringConversionFunctions( if (!status.ok()) return status; // timestamp -> string - return FunctionAdapter::CreateAndRegister( + return PortableFunctionAdapter::CreateAndRegister( builtin::kString, false, [](Arena* arena, absl::Time value) -> CelValue { auto encode = EncodeTimeToString(value); @@ -1082,7 +1115,7 @@ absl::Status RegisterStringConversionFunctions( absl::Status RegisterUintConversionFunctions(CelFunctionRegistry* registry, const InterpreterOptions&) { // double -> uint - auto status = FunctionAdapter::CreateAndRegister( + auto status = PortableFunctionAdapter::CreateAndRegister( builtin::kUint, false, [](Arena* arena, double v) { auto conv = cel::internal::CheckedDoubleToUint64(v); @@ -1095,7 +1128,7 @@ absl::Status RegisterUintConversionFunctions(CelFunctionRegistry* registry, if (!status.ok()) return status; // int -> uint - status = FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter::CreateAndRegister( builtin::kUint, false, [](Arena* arena, int64_t v) { auto conv = cel::internal::CheckedInt64ToUint64(v); @@ -1108,21 +1141,22 @@ absl::Status RegisterUintConversionFunctions(CelFunctionRegistry* registry, if (!status.ok()) return status; // string -> uint - status = FunctionAdapter::CreateAndRegister( - builtin::kUint, false, - [](Arena* arena, CelValue::StringHolder s) { - uint64_t result; - if (!absl::SimpleAtoi(s.value(), &result)) { - return CreateErrorValue(arena, "doesn't convert to a string", - absl::StatusCode::kInvalidArgument); - } - return CelValue::CreateUint64(result); - }, - registry); + status = PortableFunctionAdapter:: + CreateAndRegister( + builtin::kUint, false, + [](Arena* arena, CelValue::StringHolder s) { + uint64_t result; + if (!absl::SimpleAtoi(s.value(), &result)) { + return CreateErrorValue(arena, "doesn't convert to a string", + absl::StatusCode::kInvalidArgument); + } + return CelValue::CreateUint64(result); + }, + registry); if (!status.ok()) return status; // uint -> uint - return FunctionAdapter::CreateAndRegister( + return PortableFunctionAdapter::CreateAndRegister( builtin::kUint, false, [](Arena*, uint64_t v) { return v; }, registry); } @@ -1135,13 +1169,14 @@ absl::Status RegisterConversionFunctions(CelFunctionRegistry* registry, if (!status.ok()) return status; // duration() conversion from string. - status = FunctionAdapter::CreateAndRegister( - builtin::kDuration, false, CreateDurationFromString, registry); + status = PortableFunctionAdapter:: + CreateAndRegister(builtin::kDuration, false, CreateDurationFromString, + registry); if (!status.ok()) return status; // dyn() identity function. // TODO(issues/102): strip dyn() function references at type-check time. - status = FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter::CreateAndRegister( builtin::kDyn, false, [](Arena*, CelValue value) -> CelValue { return value; }, registry); @@ -1152,7 +1187,7 @@ absl::Status RegisterConversionFunctions(CelFunctionRegistry* registry, if (!status.ok()) return status; // timestamp conversion from int. - status = FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter::CreateAndRegister( builtin::kTimestamp, false, [](Arena*, int64_t epoch_seconds) -> CelValue { return CelValue::CreateTimestamp(absl::FromUnixSeconds(epoch_seconds)); @@ -1162,25 +1197,26 @@ absl::Status RegisterConversionFunctions(CelFunctionRegistry* registry, // timestamp() conversion from string. bool enable_timestamp_duration_overflow_errors = options.enable_timestamp_duration_overflow_errors; - status = FunctionAdapter::CreateAndRegister( - builtin::kTimestamp, false, - [=](Arena* arena, CelValue::StringHolder time_str) -> CelValue { - absl::Time ts; - if (!absl::ParseTime(absl::RFC3339_full, time_str.value(), &ts, - nullptr)) { - return CreateErrorValue(arena, - "String to Timestamp conversion failed", - absl::StatusCode::kInvalidArgument); - } - if (enable_timestamp_duration_overflow_errors) { - if (ts < absl::UniversalEpoch() || ts > kMaxTime) { - return CreateErrorValue(arena, "timestamp overflow", - absl::StatusCode::kOutOfRange); - } - } - return CelValue::CreateTimestamp(ts); - }, - registry); + status = PortableFunctionAdapter:: + CreateAndRegister( + builtin::kTimestamp, false, + [=](Arena* arena, CelValue::StringHolder time_str) -> CelValue { + absl::Time ts; + if (!absl::ParseTime(absl::RFC3339_full, time_str.value(), &ts, + nullptr)) { + return CreateErrorValue(arena, + "String to Timestamp conversion failed", + absl::StatusCode::kInvalidArgument); + } + if (enable_timestamp_duration_overflow_errors) { + if (ts < absl::UniversalEpoch() || ts > kMaxTime) { + return CreateErrorValue(arena, "timestamp overflow", + absl::StatusCode::kOutOfRange); + } + } + return CelValue::CreateTimestamp(ts); + }, + registry); if (!status.ok()) return status; return RegisterUintConversionFunctions(registry, options); @@ -1191,13 +1227,13 @@ absl::Status RegisterConversionFunctions(CelFunctionRegistry* registry, absl::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, const InterpreterOptions& options) { // logical NOT - absl::Status status = FunctionAdapter::CreateAndRegister( + absl::Status status = PortableFunctionAdapter::CreateAndRegister( builtin::kNot, false, [](Arena*, bool value) -> bool { return !value; }, registry); if (!status.ok()) return status; // Negation group - status = FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter::CreateAndRegister( builtin::kNeg, false, [](Arena* arena, int64_t value) -> CelValue { auto inv = cel::internal::CheckedNegation(value); @@ -1209,7 +1245,7 @@ absl::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, registry); if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter::CreateAndRegister( builtin::kNeg, false, [](Arena*, double value) -> double { return -value; }, registry); if (!status.ok()) return status; @@ -1220,27 +1256,27 @@ absl::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, if (!status.ok()) return status; // Strictness - status = FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter::CreateAndRegister( builtin::kNotStrictlyFalse, false, [](Arena*, bool value) -> bool { return value; }, registry); if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter::CreateAndRegister( builtin::kNotStrictlyFalse, false, [](Arena*, const CelError*) -> bool { return true; }, registry); if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter::CreateAndRegister( builtin::kNotStrictlyFalse, false, [](Arena*, const UnknownSet*) -> bool { return true; }, registry); if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter::CreateAndRegister( builtin::kNotStrictlyFalseDeprecated, false, [](Arena*, bool value) -> bool { return value; }, registry); if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter::CreateAndRegister( builtin::kNotStrictlyFalseDeprecated, false, [](Arena*, const CelError*) -> bool { return true; }, registry); if (!status.ok()) return status; @@ -1257,11 +1293,14 @@ absl::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, }; // receiver style = true/false // Support global and receiver style size() operations on strings. - status = FunctionAdapter::CreateAndRegister( - builtin::kSize, true, size_func, registry); + status = PortableFunctionAdapter< + CelValue, CelValue::StringHolder>::CreateAndRegister(builtin::kSize, true, + size_func, registry); if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( - builtin::kSize, false, size_func, registry); + status = PortableFunctionAdapter< + CelValue, CelValue::StringHolder>::CreateAndRegister(builtin::kSize, + false, size_func, + registry); if (!status.ok()) return status; // Bytes size @@ -1270,11 +1309,15 @@ absl::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, }; // receiver style = true/false // Support global and receiver style size() operations on bytes. - status = FunctionAdapter::CreateAndRegister( - builtin::kSize, true, bytes_size_func, registry); + status = PortableFunctionAdapter< + int64_t, CelValue::BytesHolder>::CreateAndRegister(builtin::kSize, true, + bytes_size_func, + registry); if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( - builtin::kSize, false, bytes_size_func, registry); + status = PortableFunctionAdapter< + int64_t, CelValue::BytesHolder>::CreateAndRegister(builtin::kSize, false, + bytes_size_func, + registry); if (!status.ok()) return status; // List size @@ -1283,10 +1326,10 @@ absl::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, }; // receiver style = true/false // Support both the global and receiver style size() for lists. - status = FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter::CreateAndRegister( builtin::kSize, true, list_size_func, registry); if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter::CreateAndRegister( builtin::kSize, false, list_size_func, registry); if (!status.ok()) return status; @@ -1295,10 +1338,10 @@ absl::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, return (*cel_map).size(); }; // receiver style = true/false - status = FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter::CreateAndRegister( builtin::kSize, true, map_size_func, registry); if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter::CreateAndRegister( builtin::kSize, false, map_size_func, registry); if (!status.ok()) return status; @@ -1319,8 +1362,8 @@ absl::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, bool enable_timestamp_duration_overflow_errors = options.enable_timestamp_duration_overflow_errors; // Special arithmetic operators for Timestamp and Duration - status = - FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter:: + CreateAndRegister( builtin::kAdd, false, [=](Arena* arena, absl::Time t1, absl::Duration d2) -> CelValue { if (enable_timestamp_duration_overflow_errors) { @@ -1335,8 +1378,8 @@ absl::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, registry); if (!status.ok()) return status; - status = - FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter:: + CreateAndRegister( builtin::kAdd, false, [=](Arena* arena, absl::Duration d2, absl::Time t1) -> CelValue { if (enable_timestamp_duration_overflow_errors) { @@ -1351,7 +1394,7 @@ absl::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, registry); if (!status.ok()) return status; - status = FunctionAdapter:: + status = PortableFunctionAdapter:: CreateAndRegister( builtin::kAdd, false, [=](Arena* arena, absl::Duration d1, absl::Duration d2) -> CelValue { @@ -1367,8 +1410,8 @@ absl::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, registry); if (!status.ok()) return status; - status = - FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter:: + CreateAndRegister( builtin::kSubtract, false, [=](Arena* arena, absl::Time t1, absl::Duration d2) -> CelValue { if (enable_timestamp_duration_overflow_errors) { @@ -1383,22 +1426,23 @@ absl::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, registry); if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( - builtin::kSubtract, false, - [=](Arena* arena, absl::Time t1, absl::Time t2) -> CelValue { - if (enable_timestamp_duration_overflow_errors) { - auto diff = cel::internal::CheckedSub(t1, t2); - if (!diff.ok()) { - return CreateErrorValue(arena, diff.status()); - } - return CelValue::CreateDuration(*diff); - } - return CelValue::CreateDuration(t1 - t2); - }, - registry); + status = PortableFunctionAdapter:: + CreateAndRegister( + builtin::kSubtract, false, + [=](Arena* arena, absl::Time t1, absl::Time t2) -> CelValue { + if (enable_timestamp_duration_overflow_errors) { + auto diff = cel::internal::CheckedSub(t1, t2); + if (!diff.ok()) { + return CreateErrorValue(arena, diff.status()); + } + return CelValue::CreateDuration(*diff); + } + return CelValue::CreateDuration(t1 - t2); + }, + registry); if (!status.ok()) return status; - status = FunctionAdapter:: + status = PortableFunctionAdapter:: CreateAndRegister( builtin::kSubtract, false, [=](Arena* arena, absl::Duration d1, absl::Duration d2) -> CelValue { @@ -1416,27 +1460,24 @@ absl::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, // Concat group if (options.enable_string_concat) { - status = FunctionAdapter< + status = PortableFunctionAdapter< CelValue::StringHolder, CelValue::StringHolder, CelValue::StringHolder>::CreateAndRegister(builtin::kAdd, false, ConcatString, registry); if (!status.ok()) return status; - status = - FunctionAdapter::CreateAndRegister(builtin::kAdd, - false, - ConcatBytes, - registry); + status = PortableFunctionAdapter< + CelValue::BytesHolder, CelValue::BytesHolder, + CelValue::BytesHolder>::CreateAndRegister(builtin::kAdd, false, + ConcatBytes, registry); if (!status.ok()) return status; } if (options.enable_list_concat) { - status = - FunctionAdapter::CreateAndRegister(builtin::kAdd, false, - ConcatList, - registry); + status = PortableFunctionAdapter< + const CelList*, const CelList*, + const CelList*>::CreateAndRegister(builtin::kAdd, false, ConcatList, + registry); if (!status.ok()) return status; } @@ -1457,42 +1498,45 @@ absl::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, return CelValue::CreateBool(RE2::PartialMatch(re2::StringPiece(target.value().data(), target.value().size()), re2)); }; - status = FunctionAdapter< + status = PortableFunctionAdapter< CelValue, CelValue::StringHolder, CelValue::StringHolder>::CreateAndRegister(builtin::kRegexMatch, false, regex_matches, registry); if (!status.ok()) return status; // Receiver-style matches function. - status = FunctionAdapter< + status = PortableFunctionAdapter< CelValue, CelValue::StringHolder, CelValue::StringHolder>::CreateAndRegister(builtin::kRegexMatch, true, regex_matches, registry); if (!status.ok()) return status; } - status = FunctionAdapter:: - CreateAndRegister(builtin::kRuntimeListAppend, false, AppendList, - registry); + status = + PortableFunctionAdapter:: + CreateAndRegister(builtin::kRuntimeListAppend, false, AppendList, + registry); if (!status.ok()) return status; status = RegisterStringFunctions(registry, options); if (!status.ok()) return status; // Modulo - status = FunctionAdapter::CreateAndRegister( - builtin::kModulo, false, Modulo, registry); + status = + PortableFunctionAdapter::CreateAndRegister( + builtin::kModulo, false, Modulo, registry); if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( - builtin::kModulo, false, Modulo, registry); + status = + PortableFunctionAdapter::CreateAndRegister( + builtin::kModulo, false, Modulo, registry); if (!status.ok()) return status; status = RegisterTimestampFunctions(registry, options); if (!status.ok()) return status; // duration functions - status = FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter::CreateAndRegister( builtin::kHours, true, [](Arena* arena, absl::Duration d) -> CelValue { return GetHours(arena, d); @@ -1500,7 +1544,7 @@ absl::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, registry); if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter::CreateAndRegister( builtin::kMinutes, true, [](Arena* arena, absl::Duration d) -> CelValue { return GetMinutes(arena, d); @@ -1508,7 +1552,7 @@ absl::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, registry); if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter::CreateAndRegister( builtin::kSeconds, true, [](Arena* arena, absl::Duration d) -> CelValue { return GetSeconds(arena, d); @@ -1516,7 +1560,7 @@ absl::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, registry); if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter::CreateAndRegister( builtin::kMilliseconds, true, [](Arena* arena, absl::Duration d) -> CelValue { return GetMilliseconds(arena, d); @@ -1524,12 +1568,13 @@ absl::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, registry); if (!status.ok()) return status; - return FunctionAdapter::CreateAndRegister( - builtin::kType, false, - [](Arena*, CelValue value) -> CelValue::CelTypeHolder { - return value.ObtainCelType().CelTypeOrDie(); - }, - registry); + return PortableFunctionAdapter:: + CreateAndRegister( + builtin::kType, false, + [](Arena*, CelValue value) -> CelValue::CelTypeHolder { + return value.ObtainCelType().CelTypeOrDie(); + }, + registry); } } // namespace google::api::expr::runtime diff --git a/eval/public/comparison_functions.cc b/eval/public/comparison_functions.cc index 77c5e7069..c6ce86e00 100644 --- a/eval/public/comparison_functions.cc +++ b/eval/public/comparison_functions.cc @@ -32,12 +32,12 @@ #include "absl/types/optional.h" #include "eval/eval/mutable_list_impl.h" #include "eval/public/cel_builtins.h" -#include "eval/public/cel_function_adapter.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_number.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/message_wrapper.h" +#include "eval/public/portable_cel_function_adapter.h" #include "eval/public/structs/legacy_type_adapter.h" #include "eval/public/structs/legacy_type_info_apis.h" #include "internal/casts.h" @@ -177,14 +177,12 @@ bool CrossNumericGreaterOrEqualTo(Arena* arena, T t, U u) { return CelNumber(t) >= CelNumber(u); } -bool MessageNullEqual(Arena* arena, const google::protobuf::Message* t1, - CelValue::NullType) { +bool MessageNullEqual(Arena* arena, MessageWrapper t1, CelValue::NullType) { // messages should never be null. return false; } -bool MessageNullInequal(Arena* arena, const google::protobuf::Message* t1, - CelValue::NullType) { +bool MessageNullInequal(Arena* arena, MessageWrapper t1, CelValue::NullType) { // messages should never be null. return true; } @@ -380,13 +378,13 @@ template absl::Status RegisterEqualityFunctionsForType(CelFunctionRegistry* registry) { // Inequality absl::Status status = - FunctionAdapter::CreateAndRegister( + PortableFunctionAdapter::CreateAndRegister( builtin::kInequal, false, WrapComparison(&Inequal), registry); if (!status.ok()) return status; // Equality - status = FunctionAdapter::CreateAndRegister( + status = PortableFunctionAdapter::CreateAndRegister( builtin::kEqual, false, WrapComparison(&Equal), registry); return status; } @@ -395,11 +393,11 @@ template absl::Status RegisterSymmetricFunction( absl::string_view name, std::function fn, CelFunctionRegistry* registry) { - CEL_RETURN_IF_ERROR((FunctionAdapter::CreateAndRegister( + CEL_RETURN_IF_ERROR((PortableFunctionAdapter::CreateAndRegister( name, false, fn, registry))); // the symmetric version - CEL_RETURN_IF_ERROR((FunctionAdapter::CreateAndRegister( + CEL_RETURN_IF_ERROR((PortableFunctionAdapter::CreateAndRegister( name, false, [fn](google::protobuf::Arena* arena, U u, T t) { return fn(arena, t, u); }, registry))); @@ -411,20 +409,25 @@ template absl::Status RegisterOrderingFunctionsForType(CelFunctionRegistry* registry) { // Less than // Extra paranthesis needed for Macros with multiple template arguments. - CEL_RETURN_IF_ERROR((FunctionAdapter::CreateAndRegister( - builtin::kLess, false, LessThan, registry))); + CEL_RETURN_IF_ERROR( + (PortableFunctionAdapter::CreateAndRegister( + builtin::kLess, false, LessThan, registry))); // Less than or Equal - CEL_RETURN_IF_ERROR((FunctionAdapter::CreateAndRegister( - builtin::kLessOrEqual, false, LessThanOrEqual, registry))); + CEL_RETURN_IF_ERROR( + (PortableFunctionAdapter::CreateAndRegister( + builtin::kLessOrEqual, false, LessThanOrEqual, registry))); // Greater than - CEL_RETURN_IF_ERROR((FunctionAdapter::CreateAndRegister( - builtin::kGreater, false, GreaterThan, registry))); + CEL_RETURN_IF_ERROR( + (PortableFunctionAdapter::CreateAndRegister( + builtin::kGreater, false, GreaterThan, registry))); // Greater than or Equal - CEL_RETURN_IF_ERROR((FunctionAdapter::CreateAndRegister( - builtin::kGreaterOrEqual, false, GreaterThanOrEqual, registry))); + CEL_RETURN_IF_ERROR( + (PortableFunctionAdapter::CreateAndRegister( + builtin::kGreaterOrEqual, false, GreaterThanOrEqual, + registry))); return absl::OkStatus(); } @@ -479,17 +482,17 @@ absl::Status RegisterHomogenousComparisonFunctions( absl::Status RegisterNullMessageEqualityFunctions( CelFunctionRegistry* registry) { CEL_RETURN_IF_ERROR( - (RegisterSymmetricFunction( + (RegisterSymmetricFunction( builtin::kEqual, MessageNullEqual, registry))); CEL_RETURN_IF_ERROR( - (RegisterSymmetricFunction( + (RegisterSymmetricFunction( builtin::kInequal, MessageNullInequal, registry))); return absl::OkStatus(); } -// Wrapper around CelValueEqualImpl to work with the FunctionAdapter template. -// Implements CEL ==, +// Wrapper around CelValueEqualImpl to work with the PortableFunctionAdapter +// template. Implements CEL ==, CelValue GeneralizedEqual(Arena* arena, CelValue t1, CelValue t2) { std::optional result = CelValueEqualImpl(t1, t2); if (result.has_value()) { @@ -500,8 +503,8 @@ CelValue GeneralizedEqual(Arena* arena, CelValue t1, CelValue t2) { return CreateNoMatchingOverloadError(arena, builtin::kEqual); } -// Wrapper around CelValueEqualImpl to work with the FunctionAdapter template. -// Implements CEL !=. +// Wrapper around CelValueEqualImpl to work with the PortableFunctionAdapter +// template. Implements CEL !=. CelValue GeneralizedInequal(Arena* arena, CelValue t1, CelValue t2) { std::optional result = CelValueEqualImpl(t1, t2); if (result.has_value()) { @@ -512,16 +515,16 @@ CelValue GeneralizedInequal(Arena* arena, CelValue t1, CelValue t2) { template absl::Status RegisterCrossNumericComparisons(CelFunctionRegistry* registry) { - CEL_RETURN_IF_ERROR((FunctionAdapter::CreateAndRegister( + CEL_RETURN_IF_ERROR((PortableFunctionAdapter::CreateAndRegister( builtin::kLess, /*receiver_style=*/false, &CrossNumericLessThan, registry))); - CEL_RETURN_IF_ERROR((FunctionAdapter::CreateAndRegister( + CEL_RETURN_IF_ERROR((PortableFunctionAdapter::CreateAndRegister( builtin::kGreater, /*receiver_style=*/false, &CrossNumericGreaterThan, registry))); - CEL_RETURN_IF_ERROR((FunctionAdapter::CreateAndRegister( + CEL_RETURN_IF_ERROR((PortableFunctionAdapter::CreateAndRegister( builtin::kGreaterOrEqual, /*receiver_style=*/false, &CrossNumericGreaterOrEqualTo, registry))); - CEL_RETURN_IF_ERROR((FunctionAdapter::CreateAndRegister( + CEL_RETURN_IF_ERROR((PortableFunctionAdapter::CreateAndRegister( builtin::kLessOrEqual, /*receiver_style=*/false, &CrossNumericLessOrEqualTo, registry))); return absl::OkStatus(); @@ -530,11 +533,11 @@ absl::Status RegisterCrossNumericComparisons(CelFunctionRegistry* registry) { absl::Status RegisterHeterogeneousComparisonFunctions( CelFunctionRegistry* registry) { CEL_RETURN_IF_ERROR( - (FunctionAdapter::CreateAndRegister( + (PortableFunctionAdapter::CreateAndRegister( builtin::kEqual, /*receiver_style=*/false, &GeneralizedEqual, registry))); CEL_RETURN_IF_ERROR( - (FunctionAdapter::CreateAndRegister( + (PortableFunctionAdapter::CreateAndRegister( builtin::kInequal, /*receiver_style=*/false, &GeneralizedInequal, registry))); diff --git a/eval/public/comparison_functions.h b/eval/public/comparison_functions.h index 96563e11e..b9300b099 100644 --- a/eval/public/comparison_functions.h +++ b/eval/public/comparison_functions.h @@ -15,7 +15,6 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_COMPARISON_FUNCTIONS_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_COMPARISON_FUNCTIONS_H_ -#include "google/protobuf/arena.h" #include "absl/status/status.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" @@ -27,7 +26,7 @@ namespace google::api::expr::runtime { // // Returns nullopt if the comparison is undefined between differently typed // values. -absl::optional CelValueEqualImpl(const CelValue& v1, const CelValue& v2); +std::optional CelValueEqualImpl(const CelValue& v1, const CelValue& v2); // Register built in comparison functions (==, !=, <, <=, >, >=). // diff --git a/eval/public/portable_cel_expr_builder_factory_test.cc b/eval/public/portable_cel_expr_builder_factory_test.cc index 329d57741..68c56d44a 100644 --- a/eval/public/portable_cel_expr_builder_factory_test.cc +++ b/eval/public/portable_cel_expr_builder_factory_test.cc @@ -22,13 +22,10 @@ #include "google/protobuf/duration.pb.h" #include "google/protobuf/timestamp.pb.h" #include "google/protobuf/wrappers.pb.h" -#include "google/protobuf/descriptor.h" -#include "google/protobuf/dynamic_message.h" -#include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/types/optional.h" #include "eval/public/activation.h" +#include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/structs/legacy_type_adapter.h" @@ -488,8 +485,7 @@ TEST(PortableCelExprBuilderFactoryTest, CreateSuccess) { ASSERT_OK_AND_ASSIGN( ParsedExpr expr, parser::Parse("google.protobuf.Timestamp{seconds: 3000, nanos: 20}")); - // TODO(issues/5): make builtin functions portable - // ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); ASSERT_OK_AND_ASSIGN( auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); @@ -513,6 +509,7 @@ TEST(PortableCelExprBuilderFactoryTest, CreateCustomMessage) { ParsedExpr expr, parser::Parse("google.api.expr.runtime.TestMessage{int64_value: 20, " "double_value: 3.5}.double_value")); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), opts)); ASSERT_OK_AND_ASSIGN( auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); @@ -538,6 +535,7 @@ TEST(PortableCelExprBuilderFactoryTest, ActivationAndCreate) { ParsedExpr expr, parser::Parse("TestMessage{int64_value: 20, bool_value: " "false}.bool_value || my_var.bool_value ? 1 : 2")); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), opts)); ASSERT_OK_AND_ASSIGN( auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); @@ -554,6 +552,7 @@ TEST(PortableCelExprBuilderFactoryTest, ActivationAndCreate) { TEST(PortableCelExprBuilderFactoryTest, WrapperTypes) { google::protobuf::Arena arena; InterpreterOptions opts; + opts.enable_heterogeneous_equality = true; Activation activation; auto provider = std::make_unique(); const auto* provider_view = provider.get(); @@ -561,8 +560,9 @@ TEST(PortableCelExprBuilderFactoryTest, WrapperTypes) { CreatePortableExprBuilder(std::move(provider), opts); builder->set_container("google.api.expr.runtime"); ASSERT_OK_AND_ASSIGN(ParsedExpr null_expr, - parser::Parse("my_var.int64_wrapper_value")); - + parser::Parse("my_var.int64_wrapper_value != null ? " + "my_var.int64_wrapper_value > 29 : null")); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), opts)); TestMessage my_var; my_var.set_bool_value(true); activation.InsertValue("my_var", provider_view->WrapValue(&my_var)); @@ -577,9 +577,40 @@ TEST(PortableCelExprBuilderFactoryTest, WrapperTypes) { my_var.mutable_int64_wrapper_value()->set_value(30); ASSERT_OK_AND_ASSIGN(result, plan->Evaluate(activation, &arena)); - int64_t result_int64; - ASSERT_TRUE(result.GetValue(&result_int64)) << result.DebugString(); - EXPECT_EQ(result_int64, 30); + bool result_bool; + ASSERT_TRUE(result.GetValue(&result_bool)) << result.DebugString(); + EXPECT_TRUE(result_bool); +} + +TEST(PortableCelExprBuilderFactoryTest, SimpleBuiltinFunctions) { + google::protobuf::Arena arena; + InterpreterOptions opts; + opts.enable_heterogeneous_equality = true; + Activation activation; + auto provider = std::make_unique(); + std::unique_ptr builder = + CreatePortableExprBuilder(std::move(provider), opts); + builder->set_container("google.api.expr.runtime"); + + // Fairly complicated but silly expression to cover a mix of builtins + // (comparisons, arithmetic, datetime). + ASSERT_OK_AND_ASSIGN( + ParsedExpr ternary_expr, + parser::Parse( + "TestMessage{int64_value: 2}.int64_value + 1 < " + " TestMessage{double_value: 3.5}.double_value - 0.1 ? " + " (google.protobuf.Timestamp{seconds: 300} - timestamp(240) " + " >= duration('1m') ? 'yes' : 'no') :" + " null")); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), opts)); + + ASSERT_OK_AND_ASSIGN(auto plan, + builder->CreateExpression(&ternary_expr.expr(), + &ternary_expr.source_info())); + ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena)); + + ASSERT_TRUE(result.IsString()) << result.DebugString(); + EXPECT_EQ(result.StringOrDie().value(), "yes"); } } // namespace From b078a781891e41936b871f32b6ad6791ae417c1b Mon Sep 17 00:00:00 2001 From: jcking Date: Fri, 6 May 2022 22:57:39 +0000 Subject: [PATCH 098/111] Remove `CreateUnknownValueError` and `IsUnknownValueError` PiperOrigin-RevId: 447089041 --- eval/public/cel_value.h | 7 ------- 1 file changed, 7 deletions(-) diff --git a/eval/public/cel_value.h b/eval/public/cel_value.h index effa2603a..fd170f5a5 100644 --- a/eval/public/cel_value.h +++ b/eval/public/cel_value.h @@ -649,13 +649,6 @@ CelValue CreateUnknownFunctionResultError(google::protobuf::Arena* arena, // into. bool IsUnknownFunctionResult(const CelValue& value); -ABSL_DEPRECATED("This type of error is no longer used by the evaluator.") -CelValue CreateUnknownValueError(google::protobuf::Arena* arena, - absl::string_view unknown_path); - -ABSL_DEPRECATED("This type of error is no longer used by the evaluator.") -bool IsUnknownValueError(const CelValue& value); - } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_VALUE_H_ From 61815be75171b8dd376238d840fdccdd85a95213 Mon Sep 17 00:00:00 2001 From: jcking Date: Wed, 11 May 2022 17:35:46 +0000 Subject: [PATCH 099/111] Remove more deprecated and unreferenced functions PiperOrigin-RevId: 448029301 --- eval/public/base_activation.h | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/eval/public/base_activation.h b/eval/public/base_activation.h index a63f8c9b7..6b33681ee 100644 --- a/eval/public/base_activation.h +++ b/eval/public/base_activation.h @@ -31,16 +31,6 @@ class BaseActivation { virtual absl::optional FindValue(absl::string_view, google::protobuf::Arena*) const = 0; - ABSL_DEPRECATED( - "No longer supported in the activation. See " - "google::api::expr::runtime::AttributeUtility.") - virtual bool IsPathUnknown(absl::string_view) const { return false; } - - ABSL_DEPRECATED("Use missing_attribute_patterns() instead.") - virtual const google::protobuf::FieldMask& unknown_paths() const { - return google::protobuf::FieldMask::default_instance(); - } - // Return the collection of attribute patterns that determine missing // attributes. virtual const std::vector& missing_attribute_patterns() From bb8475cd73fd4f076427e9d966f1fce43d102c7e Mon Sep 17 00:00:00 2001 From: jcking Date: Thu, 12 May 2022 18:50:03 +0000 Subject: [PATCH 100/111] Remove references to `CelValue` from `CelAttributeQualifier` PiperOrigin-RevId: 448302268 --- eval/public/cel_attribute.cc | 143 +++++++++++++++++++----------- eval/public/cel_attribute.h | 79 ++++++++++------- eval/public/cel_attribute_test.cc | 8 ++ 3 files changed, 148 insertions(+), 82 deletions(-) diff --git a/eval/public/cel_attribute.cc b/eval/public/cel_attribute.cc index 917413022..c7c26c95a 100644 --- a/eval/public/cel_attribute.cc +++ b/eval/public/cel_attribute.cc @@ -2,10 +2,10 @@ #include #include +#include #include "absl/status/status.h" #include "absl/strings/string_view.h" -#include "absl/types/variant.h" #include "eval/public/cel_value.h" namespace google::api::expr::runtime { @@ -45,6 +45,13 @@ class CelAttributeStringPrinter { explicit CelAttributeStringPrinter(std::string* output, CelValue::Type type) : output_(*output), type_(type) {} + absl::Status operator()(const CelValue::Type& ignored) const { + // Attributes are represented as a variant, with illegal attribute + // qualifiers represented with their type as the first alternative. + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported attribute qualifier ", CelValue::TypeName(type_))); + } + absl::Status operator()(int64_t index) { absl::StrAppend(&output_, "[", index, "]"); return absl::OkStatus(); @@ -60,77 +67,102 @@ class CelAttributeStringPrinter { return absl::OkStatus(); } - absl::Status operator()(const CelValue::StringHolder& field) { - absl::StrAppend(&output_, ".", field.value()); + absl::Status operator()(const std::string& field) { + absl::StrAppend(&output_, ".", field); return absl::OkStatus(); } - template - absl::Status operator()(const T&) { - // Attributes are represented as generic CelValues, but remaining kinds are - // not legal attribute qualifiers. - return absl::InvalidArgumentError(absl::StrCat( - "Unsupported attribute qualifier ", CelValue::TypeName(type_))); - } - private: std::string& output_; CelValue::Type type_; }; -// Helper class, used to implement CelAttributeQualifier::operator==. -class EqualVisitor { - public: - template - class NestedEqualVisitor { - public: - explicit NestedEqualVisitor(const T& arg) : arg_(arg) {} +struct CelAttributeQualifierTypeVisitor final { + CelValue::Type operator()(const CelValue::Type& type) const { return type; } - template - bool operator()(const U&) const { - return false; - } + CelValue::Type operator()(int64_t ignored) const { + static_cast(ignored); + return CelValue::Type::kInt64; + } - bool operator()(const T& other) const { return other == arg_; } - - private: - const T& arg_; - }; - // Message wrapper is unsupported. Add specialization to make visitor - // compile. - template <> - class NestedEqualVisitor { - public: - explicit NestedEqualVisitor( - const CelValue::MessageWrapper&) {} - template - bool operator()(const U&) const { - return false; - } - }; + CelValue::Type operator()(uint64_t ignored) const { + static_cast(ignored); + return CelValue::Type::kUint64; + } - explicit EqualVisitor(const CelValue& other) : other_(other) {} + CelValue::Type operator()(const std::string& ignored) const { + static_cast(ignored); + return CelValue::Type::kString; + } - template - bool operator()(const Type& arg) { - return other_.template InternalVisit(NestedEqualVisitor(arg)); + CelValue::Type operator()(bool ignored) const { + static_cast(ignored); + return CelValue::Type::kBool; } +}; - private: - const CelValue& other_; +struct CelAttributeQualifierIsMatchVisitor final { + const CelValue& value; + + bool operator()(const CelValue::Type& ignored) const { + static_cast(ignored); + return false; + } + + bool operator()(int64_t other) const { + int64_t value_value; + return value.GetValue(&value_value) && value_value == other; + } + + bool operator()(uint64_t other) const { + uint64_t value_value; + return value.GetValue(&value_value) && value_value == other; + } + + bool operator()(const std::string& other) const { + CelValue::StringHolder value_value; + return value.GetValue(&value_value) && value_value.value() == other; + } + + bool operator()(bool other) const { + bool value_value; + return value.GetValue(&value_value) && value_value == other; + } }; } // namespace +CelValue::Type CelAttributeQualifier::type() const { + return std::visit(CelAttributeQualifierTypeVisitor{}, value_); +} + +CelAttributeQualifier CelAttributeQualifier::Create(CelValue value) { + switch (value.type()) { + case CelValue::Type::kInt64: + return CelAttributeQualifier(std::in_place_type, + value.Int64OrDie()); + case CelValue::Type::kUint64: + return CelAttributeQualifier(std::in_place_type, + value.Uint64OrDie()); + case CelValue::Type::kString: + return CelAttributeQualifier(std::in_place_type, + std::string(value.StringOrDie().value())); + case CelValue::Type::kBool: + return CelAttributeQualifier(std::in_place_type, value.BoolOrDie()); + default: + return CelAttributeQualifier(); + } +} + CelAttributePattern CreateCelAttributePattern( absl::string_view variable, - std::initializer_list> + std::initializer_list> path_spec) { std::vector path; path.reserve(path_spec.size()); for (const auto& spec_elem : path_spec) { - path.emplace_back(absl::visit(QualifierVisitor(), spec_elem)); + path.emplace_back(std::visit(QualifierVisitor(), spec_elem)); } return CelAttributePattern(std::string(variable), std::move(path)); } @@ -167,15 +199,24 @@ const absl::StatusOr CelAttribute::AsString() const { std::string result = variable_.ident_expr().name(); for (const auto& qualifier : qualifier_path_) { - CEL_RETURN_IF_ERROR(qualifier.Visit( - CelAttributeStringPrinter(&result, qualifier.type()))); + CEL_RETURN_IF_ERROR( + std::visit(CelAttributeStringPrinter(&result, qualifier.type()), + qualifier.value_)); } return result; } bool CelAttributeQualifier::IsMatch(const CelValue& cel_value) const { - return value_.template InternalVisit(EqualVisitor(cel_value)); + return std::visit(CelAttributeQualifierIsMatchVisitor{cel_value}, value_); +} + +bool CelAttributeQualifier::IsMatch(const CelAttributeQualifier& other) const { + if (std::holds_alternative(value_) || + std::holds_alternative(other.value_)) { + return false; + } + return value_ == other.value_; } } // namespace google::api::expr::runtime diff --git a/eval/public/cel_attribute.h b/eval/public/cel_attribute.h index 0e5523e0a..afe8fab87 100644 --- a/eval/public/cel_attribute.h +++ b/eval/public/cel_attribute.h @@ -6,6 +6,10 @@ #include #include #include +#include +#include +#include +#include #include "google/api/expr/v1alpha1/syntax.pb.h" #include "absl/status/status.h" @@ -27,56 +31,69 @@ namespace google::api::expr::runtime { class CelAttributeQualifier { public: // Factory method. - static CelAttributeQualifier Create(CelValue value) { - return CelAttributeQualifier(value); - } + static CelAttributeQualifier Create(CelValue value); + + CelAttributeQualifier(const CelAttributeQualifier&) = default; + CelAttributeQualifier(CelAttributeQualifier&&) = default; + + CelAttributeQualifier& operator=(const CelAttributeQualifier&) = default; + CelAttributeQualifier& operator=(CelAttributeQualifier&&) = default; - CelValue::Type type() const { return value_.type(); } + CelValue::Type type() const; // Family of Get... methods. Return values if requested type matches the // stored one. - absl::optional GetInt64Key() const { - return (value_.IsInt64()) ? absl::optional(value_.Int64OrDie()) - : absl::nullopt; + std::optional GetInt64Key() const { + return std::holds_alternative(value_) + ? std::optional(std::get<1>(value_)) + : std::nullopt; } - absl::optional GetUint64Key() const { - return (value_.IsUint64()) ? absl::optional(value_.Uint64OrDie()) - : absl::nullopt; + std::optional GetUint64Key() const { + return std::holds_alternative(value_) + ? std::optional(std::get<2>(value_)) + : std::nullopt; } - absl::optional GetStringKey() const { - return (value_.IsString()) - ? absl::optional(value_.StringOrDie().value()) - : absl::nullopt; + std::optional GetStringKey() const { + return std::holds_alternative(value_) + ? std::optional(std::get<3>(value_)) + : std::nullopt; } - absl::optional GetBoolKey() const { - return (value_.IsBool()) ? absl::optional(value_.BoolOrDie()) - : absl::nullopt; + std::optional GetBoolKey() const { + return std::holds_alternative(value_) + ? std::optional(std::get<4>(value_)) + : std::nullopt; } bool operator==(const CelAttributeQualifier& other) const { - return IsMatch(other.value_); + return IsMatch(other); } bool IsMatch(const CelValue& cel_value) const; bool IsMatch(absl::string_view other_key) const { - absl::optional key = GetStringKey(); + std::optional key = GetStringKey(); return (key.has_value() && key.value() == other_key); } private: friend class CelAttribute; - explicit CelAttributeQualifier(CelValue value) : value_(value) {} - template - T Visit(Op&& operation) const { - return value_.InternalVisit(operation); - } + CelAttributeQualifier() = default; + + template + CelAttributeQualifier(std::in_place_type_t in_place_type, T&& value) + : value_(in_place_type, std::forward(value)) {} + + bool IsMatch(const CelAttributeQualifier& other) const; - CelValue value_; + // The previous implementation of CelAttribute preserved all CelValue + // instances, regardless of whether they are supported in this context or not. + // We represented unsupported types by using the first alternative and thus + // preserve backwards compatibility with the result of `type()` above. + std::variant value_; }; // CelAttributeQualifierPattern matches a segment in @@ -85,11 +102,11 @@ class CelAttributeQualifier { class CelAttributeQualifierPattern { private: // Qualifier value. If not set, treated as wildcard. - absl::optional value_; + std::optional value_; explicit CelAttributeQualifierPattern( - absl::optional value) - : value_(value) {} + std::optional value) + : value_(std::move(value)) {} public: // Factory method. @@ -98,7 +115,7 @@ class CelAttributeQualifierPattern { } static CelAttributeQualifierPattern CreateWildcard() { - return CelAttributeQualifierPattern(absl::nullopt); + return CelAttributeQualifierPattern(std::nullopt); } bool IsWildcard() const { return !value_.has_value(); } @@ -211,8 +228,8 @@ class CelAttributePattern { // must outlive the returned pattern. CelAttributePattern CreateCelAttributePattern( absl::string_view variable, - std::initializer_list> + std::initializer_list> path_spec = {}); } // namespace google::api::expr::runtime diff --git a/eval/public/cel_attribute_test.cc b/eval/public/cel_attribute_test.cc index 8b013c4fb..7bd09c640 100644 --- a/eval/public/cel_attribute_test.cc +++ b/eval/public/cel_attribute_test.cc @@ -361,6 +361,10 @@ TEST(CelAttribute, InvalidQualifiers) { CelAttributeQualifier::Create( CelProtoWrapper::CreateMessage(&expr, &arena)), }); + CelAttribute attr3( + expr, { + CelAttributeQualifier::Create(CelValue::CreateBool(false)), + }); // Implementation detail: Messages as attribute qualifiers are unsupported, // so the implementation treats them inequal to any other. This is included @@ -368,6 +372,10 @@ TEST(CelAttribute, InvalidQualifiers) { EXPECT_FALSE(attr1 == attr2); EXPECT_FALSE(attr2 == attr1); EXPECT_FALSE(attr2 == attr2); + EXPECT_FALSE(attr1 == attr3); + EXPECT_FALSE(attr3 == attr1); + EXPECT_FALSE(attr2 == attr3); + EXPECT_FALSE(attr3 == attr2); // If the attribute includes an unsupported qualifier, return invalid argument // error. From 99b54d4822e0943add41a424a7f8ed42c1e94b2c Mon Sep 17 00:00:00 2001 From: jcking Date: Thu, 12 May 2022 19:18:22 +0000 Subject: [PATCH 101/111] Remove references to `CelValue` from `CelAttributeQualifier` PiperOrigin-RevId: 448308815 --- eval/eval/attribute_utility.h | 8 +- eval/eval/function_step_test.cc | 25 - eval/public/unknown_function_result_set.cc | 17 - eval/public/unknown_function_result_set.h | 11 +- .../unknown_function_result_set_test.cc | 443 +----------------- eval/public/unknown_set_test.cc | 39 +- eval/tests/unknowns_end_to_end_test.cc | 12 +- 7 files changed, 23 insertions(+), 532 deletions(-) diff --git a/eval/eval/attribute_utility.h b/eval/eval/attribute_utility.h index 79f069215..6d4925f0e 100644 --- a/eval/eval/attribute_utility.h +++ b/eval/eval/attribute_utility.h @@ -79,11 +79,9 @@ class AttributeUtility { const UnknownSet* CreateUnknownSet(const CelFunctionDescriptor& fn_descriptor, int64_t expr_id, absl::Span args) const { - auto* fn = memory_manager_ - .New( - fn_descriptor, expr_id, - std::vector(args.begin(), args.end())) - .release(); + auto* fn = + memory_manager_.New(fn_descriptor, expr_id) + .release(); return memory_manager_.New(UnknownFunctionResultSet(fn)) .release(); } diff --git a/eval/eval/function_step_test.cc b/eval/eval/function_step_test.cc index 690ce82cd..223d6eb83 100644 --- a/eval/eval/function_step_test.cc +++ b/eval/eval/function_step_test.cc @@ -598,16 +598,6 @@ INSTANTIATE_TEST_SUITE_P( UnknownProcessingOptions::kAttributeAndFunction), &TestNameFn); -MATCHER_P2(IsAdd, a, b, "") { - const UnknownFunctionResult* result = arg; - return result->arguments().size() == 2 && - result->arguments().at(0).IsInt64() && - result->arguments().at(1).IsInt64() && - result->arguments().at(0).Int64OrDie() == a && - result->arguments().at(1).Int64OrDie() == b && - result->descriptor().name() == "_+_"; -} - TEST(FunctionStepTestUnknownFunctionResults, CaptureArgs) { ExecutionPath path; CelFunctionRegistry registry; @@ -641,11 +631,6 @@ TEST(FunctionStepTestUnknownFunctionResults, CaptureArgs) { ASSERT_OK_AND_ASSIGN(CelValue value, impl.Evaluate(activation, &arena)); ASSERT_TRUE(value.IsUnknownSet()); - // Arguments captured. - EXPECT_THAT(value.UnknownSetOrDie() - ->unknown_function_results() - .unknown_function_results(), - ElementsAre(IsAdd(2, 3))); } TEST(FunctionStepTestUnknownFunctionResults, MergeDownCaptureArgs) { @@ -691,11 +676,6 @@ TEST(FunctionStepTestUnknownFunctionResults, MergeDownCaptureArgs) { ASSERT_OK_AND_ASSIGN(CelValue value, impl.Evaluate(activation, &arena)); ASSERT_TRUE(value.IsUnknownSet()); - // Arguments captured. - EXPECT_THAT(value.UnknownSetOrDie() - ->unknown_function_results() - .unknown_function_results(), - ElementsAre(IsAdd(2, 3))); } TEST(FunctionStepTestUnknownFunctionResults, MergeCaptureArgs) { @@ -741,11 +721,6 @@ TEST(FunctionStepTestUnknownFunctionResults, MergeCaptureArgs) { ASSERT_OK_AND_ASSIGN(CelValue value, impl.Evaluate(activation, &arena)); ASSERT_TRUE(value.IsUnknownSet()) << *(value.ErrorOrDie()); - // Arguments captured. - EXPECT_THAT(value.UnknownSetOrDie() - ->unknown_function_results() - .unknown_function_results(), - UnorderedElementsAre(IsAdd(2, 3), IsAdd(3, 2))); } TEST(FunctionStepTestUnknownFunctionResults, UnknownVsErrorPrecedenceTest) { diff --git a/eval/public/unknown_function_result_set.cc b/eval/public/unknown_function_result_set.cc index b2ef5b84d..75361c263 100644 --- a/eval/public/unknown_function_result_set.cc +++ b/eval/public/unknown_function_result_set.cc @@ -50,23 +50,6 @@ bool UnknownFunctionResultLessThan(const UnknownFunctionResult& lhs, return false; } - if (lhs.arguments().size() < rhs.arguments().size()) { - return true; - } - - if (lhs.arguments().size() > rhs.arguments().size()) { - return false; - } - - for (size_t i = 0; i < lhs.arguments().size(); i++) { - if (CelValueLessThan(lhs.arguments()[i], rhs.arguments()[i])) { - return true; - } - if (CelValueLessThan(rhs.arguments()[i], lhs.arguments()[i])) { - return false; - } - } - // equal return false; } diff --git a/eval/public/unknown_function_result_set.h b/eval/public/unknown_function_result_set.h index 891b3713f..ed13c3985 100644 --- a/eval/public/unknown_function_result_set.h +++ b/eval/public/unknown_function_result_set.h @@ -17,9 +17,8 @@ namespace runtime { // allows for lazy evaluation of expensive functions. class UnknownFunctionResult { public: - UnknownFunctionResult(const CelFunctionDescriptor& descriptor, int64_t expr_id, - const std::vector& arguments) - : descriptor_(descriptor), expr_id_(expr_id), arguments_(arguments) {} + UnknownFunctionResult(const CelFunctionDescriptor& descriptor, int64_t expr_id) + : descriptor_(descriptor), expr_id_(expr_id) {} // The descriptor of the called function that return Unknown. const CelFunctionDescriptor& descriptor() const { return descriptor_; } @@ -29,18 +28,16 @@ class UnknownFunctionResult { // they will be treated as the same unknown function result. int64_t call_expr_id() const { return expr_id_; } - // The arguments of the function call that generated the unknown. - const std::vector& arguments() const { return arguments_; } - // Equality operator provided for testing. Compatible with set less-than // comparator. // Compares descriptor then arguments elementwise. bool IsEqualTo(const UnknownFunctionResult& other) const; + // TODO(issues/5): re-implement argument capture + private: CelFunctionDescriptor descriptor_; int64_t expr_id_; - std::vector arguments_; }; // Comparator for set semantics. diff --git a/eval/public/unknown_function_result_set_test.cc b/eval/public/unknown_function_result_set_test.cc index 8d89ddc2f..a4005a54c 100644 --- a/eval/public/unknown_function_result_set_test.cc +++ b/eval/public/unknown_function_result_set_test.cc @@ -19,6 +19,7 @@ #include "eval/public/containers/container_backed_map_impl.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "internal/testing.h" + namespace google { namespace api { namespace expr { @@ -42,466 +43,40 @@ bool IsLessThan(const UnknownFunctionResult& lhs, return UnknownFunctionComparator()(&lhs, &rhs); } -TEST(UnknownFunctionResult, ArgumentCapture) { - UnknownFunctionResult call1( - kTwoInt, /*expr_id=*/0, - {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); - - EXPECT_THAT(call1.arguments(), SizeIs(2)); - EXPECT_THAT(call1.arguments().at(0).Int64OrDie(), Eq(1)); -} - TEST(UnknownFunctionResult, Equals) { - UnknownFunctionResult call1( - kTwoInt, /*expr_id=*/0, - {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); + UnknownFunctionResult call1(kTwoInt, /*expr_id=*/0); - UnknownFunctionResult call2( - kTwoInt, /*expr_id=*/0, - {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); + UnknownFunctionResult call2(kTwoInt, /*expr_id=*/0); EXPECT_TRUE(call1.IsEqualTo(call2)); EXPECT_FALSE(IsLessThan(call1, call2)); EXPECT_FALSE(IsLessThan(call2, call1)); - UnknownFunctionResult call3(kOneInt, /*expr_id=*/0, - {CelValue::CreateInt64(1)}); + UnknownFunctionResult call3(kOneInt, /*expr_id=*/0); - UnknownFunctionResult call4(kOneInt, /*expr_id=*/0, - {CelValue::CreateInt64(1)}); + UnknownFunctionResult call4(kOneInt, /*expr_id=*/0); EXPECT_TRUE(call3.IsEqualTo(call4)); } TEST(UnknownFunctionResult, InequalDescriptor) { - UnknownFunctionResult call1( - kTwoInt, /*expr_id=*/0, - {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); + UnknownFunctionResult call1(kTwoInt, /*expr_id=*/0); - UnknownFunctionResult call2(kOneInt, /*expr_id=*/0, - {CelValue::CreateInt64(1)}); + UnknownFunctionResult call2(kOneInt, /*expr_id=*/0); EXPECT_FALSE(call1.IsEqualTo(call2)); EXPECT_TRUE(IsLessThan(call2, call1)); CelFunctionDescriptor one_uint("OneInt", false, {CelValue::Type::kUint64}); - UnknownFunctionResult call3(kOneInt, /*expr_id=*/0, - {CelValue::CreateInt64(1)}); + UnknownFunctionResult call3(kOneInt, /*expr_id=*/0); - UnknownFunctionResult call4(one_uint, /*expr_id=*/0, - {CelValue::CreateUint64(1)}); + UnknownFunctionResult call4(one_uint, /*expr_id=*/0); EXPECT_FALSE(call3.IsEqualTo(call4)); EXPECT_TRUE(IsLessThan(call3, call4)); } -TEST(UnknownFunctionResult, InequalArgs) { - UnknownFunctionResult call1( - kTwoInt, /*expr_id=*/0, - {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); - - UnknownFunctionResult call2( - kTwoInt, /*expr_id=*/0, - {CelValue::CreateInt64(1), CelValue::CreateInt64(3)}); - - EXPECT_FALSE(call1.IsEqualTo(call2)); - EXPECT_TRUE(IsLessThan(call1, call2)); - - UnknownFunctionResult call3( - kTwoInt, /*expr_id=*/0, - {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); - - UnknownFunctionResult call4(kTwoInt, /*expr_id=*/0, - {CelValue::CreateInt64(1)}); - - EXPECT_FALSE(call3.IsEqualTo(call4)); - EXPECT_TRUE(IsLessThan(call4, call3)); -} - -TEST(UnknownFunctionResult, ListsEqual) { - ContainerBackedListImpl cel_list_1(std::vector{ - CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); - - ContainerBackedListImpl cel_list_2(std::vector{ - CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); - - CelFunctionDescriptor desc("OneList", false, {CelValue::Type::kList}); - - UnknownFunctionResult call1(desc, /*expr_id=*/0, - {CelValue::CreateList(&cel_list_1)}); - UnknownFunctionResult call2(desc, /*expr_id=*/0, - {CelValue::CreateList(&cel_list_2)}); - - // [1, 2] == [1, 2] - EXPECT_TRUE(call1.IsEqualTo(call2)); -} - -TEST(UnknownFunctionResult, ListsDifferentSizes) { - ContainerBackedListImpl cel_list_1(std::vector{ - CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); - - ContainerBackedListImpl cel_list_2(std::vector{ - CelValue::CreateInt64(1), - CelValue::CreateInt64(2), - CelValue::CreateInt64(3), - }); - - CelFunctionDescriptor desc("OneList", false, {CelValue::Type::kList}); - - UnknownFunctionResult call1(desc, /*expr_id=*/0, - {CelValue::CreateList(&cel_list_1)}); - UnknownFunctionResult call2(desc, /*expr_id=*/0, - {CelValue::CreateList(&cel_list_2)}); - - // [1, 2] == [1, 2, 3] - EXPECT_FALSE(call1.IsEqualTo(call2)); - EXPECT_TRUE(IsLessThan(call1, call2)); -} - -TEST(UnknownFunctionResult, ListsDifferentMembers) { - ContainerBackedListImpl cel_list_1(std::vector{ - CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); - - ContainerBackedListImpl cel_list_2(std::vector{ - CelValue::CreateInt64(2), CelValue::CreateInt64(2)}); - - CelFunctionDescriptor desc("OneList", false, {CelValue::Type::kList}); - - UnknownFunctionResult call1(desc, /*expr_id=*/0, - {CelValue::CreateList(&cel_list_1)}); - UnknownFunctionResult call2(desc, /*expr_id=*/0, - {CelValue::CreateList(&cel_list_2)}); - - // [1, 2] == [2, 2] - EXPECT_FALSE(call1.IsEqualTo(call2)); - EXPECT_TRUE(IsLessThan(call1, call2)); -} - -TEST(UnknownFunctionResult, MapsEqual) { - std::vector> values{ - {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, - {CelValue::CreateInt64(2), CelValue::CreateInt64(4)}}; - - auto cel_map_1 = CreateContainerBackedMap(absl::MakeSpan(values)).value(); - auto cel_map_2 = CreateContainerBackedMap(absl::MakeSpan(values)).value(); - - CelFunctionDescriptor desc("OneMap", false, {CelValue::Type::kMap}); - - UnknownFunctionResult call1(desc, /*expr_id=*/0, - {CelValue::CreateMap(cel_map_1.get())}); - UnknownFunctionResult call2(desc, /*expr_id=*/0, - {CelValue::CreateMap(cel_map_2.get())}); - - // {1: 2, 2: 4} == {1: 2, 2: 4} - EXPECT_TRUE(call1.IsEqualTo(call2)); -} - -TEST(UnknownFunctionResult, MapsDifferentSizes) { - std::vector> values{ - {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, - {CelValue::CreateInt64(2), CelValue::CreateInt64(4)}}; - - auto cel_map_1 = CreateContainerBackedMap(absl::MakeSpan(values)).value(); - - std::vector> values2{ - {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, - {CelValue::CreateInt64(2), CelValue::CreateInt64(4)}, - {CelValue::CreateInt64(3), CelValue::CreateInt64(6)}}; - - auto cel_map_2 = CreateContainerBackedMap(absl::MakeSpan(values2)).value(); - - CelFunctionDescriptor desc("OneMap", false, {CelValue::Type::kMap}); - - UnknownFunctionResult call1(desc, /*expr_id=*/0, - {CelValue::CreateMap(cel_map_1.get())}); - UnknownFunctionResult call2(desc, /*expr_id=*/0, - {CelValue::CreateMap(cel_map_2.get())}); - - // {1: 2, 2: 4} == {1: 2, 2: 4, 3: 6} - EXPECT_FALSE(call1.IsEqualTo(call2)); - EXPECT_TRUE(IsLessThan(call1, call2)); -} - -TEST(UnknownFunctionResult, MapsDifferentElements) { - std::vector> values{ - {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, - {CelValue::CreateInt64(2), CelValue::CreateInt64(4)}, - {CelValue::CreateInt64(3), CelValue::CreateInt64(6)}}; - - auto cel_map_1 = CreateContainerBackedMap(absl::MakeSpan(values)).value(); - - std::vector> values2{ - {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, - {CelValue::CreateInt64(2), CelValue::CreateInt64(4)}, - {CelValue::CreateInt64(4), CelValue::CreateInt64(8)}}; - - auto cel_map_2 = CreateContainerBackedMap(absl::MakeSpan(values2)).value(); - - std::vector> values3{ - {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, - {CelValue::CreateInt64(2), CelValue::CreateInt64(4)}, - {CelValue::CreateInt64(3), CelValue::CreateInt64(5)}}; - - auto cel_map_3 = CreateContainerBackedMap(absl::MakeSpan(values3)).value(); - - CelFunctionDescriptor desc("OneMap", false, {CelValue::Type::kMap}); - - UnknownFunctionResult call1(desc, /*expr_id=*/0, - {CelValue::CreateMap(cel_map_1.get())}); - UnknownFunctionResult call2(desc, /*expr_id=*/0, - {CelValue::CreateMap(cel_map_2.get())}); - UnknownFunctionResult call3(desc, /*expr_id=*/0, - {CelValue::CreateMap(cel_map_3.get())}); - - // {1: 2, 2: 4, 3: 6} == {1: 2, 2: 4, 4: 8} - EXPECT_FALSE(call1.IsEqualTo(call2)); - EXPECT_TRUE(IsLessThan(call1, call2)); - // {1: 2, 2: 4, 3: 6} == {1: 2, 2: 4, 3: 5} - EXPECT_FALSE(call1.IsEqualTo(call3)); - EXPECT_TRUE(IsLessThan(call3, call1)); -} - -TEST(UnknownFunctionResult, Messages) { - protobuf::Empty message1; - protobuf::Empty message2; - google::protobuf::Arena arena; - - CelFunctionDescriptor desc("OneMessage", false, {CelValue::Type::kMessage}); - - UnknownFunctionResult call1( - desc, /*expr_id=*/0, {CelProtoWrapper::CreateMessage(&message1, &arena)}); - UnknownFunctionResult call2( - desc, /*expr_id=*/0, {CelProtoWrapper::CreateMessage(&message2, &arena)}); - UnknownFunctionResult call3( - desc, /*expr_id=*/0, {CelProtoWrapper::CreateMessage(&message1, &arena)}); - - // &message1 == &message2 - EXPECT_FALSE(call1.IsEqualTo(call2)); - - // &message1 == &message1 - EXPECT_TRUE(call1.IsEqualTo(call3)); -} - -TEST(UnknownFunctionResult, AnyDescriptor) { - CelFunctionDescriptor anyDesc("OneAny", false, {CelValue::Type::kAny}); - - UnknownFunctionResult callAnyInt1(anyDesc, /*expr_id=*/0, - {CelValue::CreateInt64(2)}); - UnknownFunctionResult callInt(kOneInt, /*expr_id=*/0, - {CelValue::CreateInt64(2)}); - - UnknownFunctionResult callAnyInt2(anyDesc, /*expr_id=*/0, - {CelValue::CreateInt64(2)}); - UnknownFunctionResult callAnyUint(anyDesc, /*expr_id=*/0, - {CelValue::CreateUint64(2)}); - - EXPECT_FALSE(callAnyInt1.IsEqualTo(callInt)); - EXPECT_TRUE(IsLessThan(callAnyInt1, callInt)); - EXPECT_FALSE(callAnyInt1.IsEqualTo(callAnyUint)); - EXPECT_TRUE(IsLessThan(callAnyInt1, callAnyUint)); - EXPECT_TRUE(callAnyInt1.IsEqualTo(callAnyInt2)); -} - -TEST(UnknownFunctionResult, Strings) { - CelFunctionDescriptor desc("OneString", false, {CelValue::Type::kString}); - - UnknownFunctionResult callStringSmile(desc, /*expr_id=*/0, - {CelValue::CreateStringView("😁")}); - UnknownFunctionResult callStringFrown(desc, /*expr_id=*/0, - {CelValue::CreateStringView("🙁")}); - UnknownFunctionResult callStringSmile2(desc, /*expr_id=*/0, - {CelValue::CreateStringView("😁")}); - - EXPECT_TRUE(callStringSmile.IsEqualTo(callStringSmile2)); - EXPECT_FALSE(callStringSmile.IsEqualTo(callStringFrown)); -} - -TEST(UnknownFunctionResult, DurationHandling) { - google::protobuf::Arena arena; - absl::Duration duration1 = absl::Seconds(5); - protobuf::Duration duration2; - duration2.set_seconds(5); - - CelFunctionDescriptor durationDesc("OneDuration", false, - {CelValue::Type::kDuration}); - - UnknownFunctionResult callDuration1(durationDesc, /*expr_id=*/0, - {CelValue::CreateDuration(duration1)}); - UnknownFunctionResult callDuration2( - durationDesc, /*expr_id=*/0, - {CelProtoWrapper::CreateMessage(&duration2, &arena)}); - UnknownFunctionResult callDuration3( - durationDesc, /*expr_id=*/0, - {CelProtoWrapper::CreateDuration(&duration2)}); - - EXPECT_TRUE(callDuration1.IsEqualTo(callDuration2)); - EXPECT_TRUE(callDuration1.IsEqualTo(callDuration3)); -} - -TEST(UnknownFunctionResult, TimestampHandling) { - google::protobuf::Arena arena; - absl::Time ts1 = absl::FromUnixMillis(1000); - protobuf::Timestamp ts2; - ts2.set_seconds(1); - - CelFunctionDescriptor timestampDesc("OneTimestamp", false, - {CelValue::Type::kTimestamp}); - - UnknownFunctionResult callTimestamp1(timestampDesc, /*expr_id=*/0, - {CelValue::CreateTimestamp(ts1)}); - UnknownFunctionResult callTimestamp2( - timestampDesc, /*expr_id=*/0, - {CelProtoWrapper::CreateMessage(&ts2, &arena)}); - UnknownFunctionResult callTimestamp3( - timestampDesc, /*expr_id=*/0, {CelProtoWrapper::CreateTimestamp(&ts2)}); - - EXPECT_TRUE(callTimestamp1.IsEqualTo(callTimestamp2)); - EXPECT_TRUE(callTimestamp1.IsEqualTo(callTimestamp3)); -} - -// This tests that the conversion and different map backing implementations are -// compatible with the equality tests. -TEST(UnknownFunctionResult, ProtoStructTreatedAsMap) { - Arena arena; - - const std::vector kFields = {"field1", "field2", "field3"}; - - Struct value_struct; - - auto& value1 = (*value_struct.mutable_fields())[kFields[0]]; - value1.set_bool_value(true); - - auto& value2 = (*value_struct.mutable_fields())[kFields[1]]; - value2.set_number_value(1.0); - - auto& value3 = (*value_struct.mutable_fields())[kFields[2]]; - value3.set_string_value("test"); - - CelValue proto_struct = CelProtoWrapper::CreateMessage(&value_struct, &arena); - ASSERT_TRUE(proto_struct.IsMap()); - - std::vector> values{ - {CelValue::CreateStringView(kFields[2]), - CelValue::CreateStringView("test")}, - {CelValue::CreateStringView(kFields[1]), CelValue::CreateDouble(1.0)}, - {CelValue::CreateStringView(kFields[0]), CelValue::CreateBool(true)}}; - - auto backing_map = CreateContainerBackedMap(absl::MakeSpan(values)).value(); - - CelValue cel_map = CelValue::CreateMap(backing_map.get()); - - CelFunctionDescriptor desc("OneMap", false, {CelValue::Type::kMap}); - - UnknownFunctionResult proto_struct_result(desc, /*expr_id=*/0, - {proto_struct}); - UnknownFunctionResult cel_map_result(desc, /*expr_id=*/0, {cel_map}); - - EXPECT_TRUE(proto_struct_result.IsEqualTo(cel_map_result)); -} - -// This tests that the conversion and different map backing implementations are -// compatible with the equality tests. -TEST(UnknownFunctionResult, ProtoListTreatedAsList) { - Arena arena; - - ListValue list_value; - - list_value.add_values()->set_bool_value(true); - list_value.add_values()->set_number_value(1.0); - list_value.add_values()->set_string_value("test"); - - CelValue proto_list = CelProtoWrapper::CreateMessage(&list_value, &arena); - ASSERT_TRUE(proto_list.IsList()); - - std::vector list_values{CelValue::CreateBool(true), - CelValue::CreateDouble(1.0), - CelValue::CreateStringView("test")}; - - ContainerBackedListImpl list_backing(list_values); - - CelValue cel_list = CelValue::CreateList(&list_backing); - - CelFunctionDescriptor desc("OneList", false, {CelValue::Type::kList}); - - UnknownFunctionResult proto_list_result(desc, /*expr_id=*/0, {proto_list}); - UnknownFunctionResult cel_list_result(desc, /*expr_id=*/0, {cel_list}); - - EXPECT_TRUE(cel_list_result.IsEqualTo(proto_list_result)); -} - -TEST(UnknownFunctionResult, NestedProtoTypes) { - Arena arena; - - ListValue list_value; - - list_value.add_values()->set_bool_value(true); - list_value.add_values()->set_number_value(1.0); - list_value.add_values()->set_string_value("test"); - - std::vector list_values{CelValue::CreateBool(true), - CelValue::CreateDouble(1.0), - CelValue::CreateStringView("test")}; - - ContainerBackedListImpl list_backing(list_values); - - CelValue cel_list = CelValue::CreateList(&list_backing); - - Struct value_struct; - - *(value_struct.mutable_fields()->operator[]("field").mutable_list_value()) = - list_value; - - std::vector> values{ - {CelValue::CreateStringView("field"), cel_list}}; - - auto backing_map = CreateContainerBackedMap(absl::MakeSpan(values)).value(); - - CelValue cel_map = CelValue::CreateMap(backing_map.get()); - CelValue proto_map = CelProtoWrapper::CreateMessage(&value_struct, &arena); - - CelFunctionDescriptor desc("OneMap", false, {CelValue::Type::kMap}); - - UnknownFunctionResult cel_map_result(desc, /*expr_id=*/0, {cel_map}); - UnknownFunctionResult proto_struct_result(desc, /*expr_id=*/0, {proto_map}); - - EXPECT_TRUE(proto_struct_result.IsEqualTo(cel_map_result)); -} - -UnknownFunctionResult MakeUnknown(int64_t i) { - return UnknownFunctionResult(kOneInt, /*expr_id=*/0, - {CelValue::CreateInt64(i)}); -} - -testing::Matcher UnknownMatches( - const UnknownFunctionResult& obj) { - return testing::Truly([&](const UnknownFunctionResult* to_match) { - return obj.IsEqualTo(*to_match); - }); -} - -TEST(UnknownFunctionResultSet, Merge) { - UnknownFunctionResult a = MakeUnknown(1); - UnknownFunctionResult b = MakeUnknown(2); - UnknownFunctionResult c = MakeUnknown(3); - UnknownFunctionResult d = MakeUnknown(1); - - UnknownFunctionResultSet a1(&a); - UnknownFunctionResultSet b1(&b); - UnknownFunctionResultSet c1(&c); - UnknownFunctionResultSet d1(&d); - - UnknownFunctionResultSet ab(a1, b1); - UnknownFunctionResultSet cd(c1, d1); - - UnknownFunctionResultSet merged(ab, cd); - - EXPECT_THAT(merged.unknown_function_results(), SizeIs(3)); - EXPECT_THAT(merged.unknown_function_results(), - testing::UnorderedElementsAre( - UnknownMatches(a), UnknownMatches(b), UnknownMatches(c))); -} - } // namespace } // namespace runtime } // namespace expr diff --git a/eval/public/unknown_set_test.cc b/eval/public/unknown_set_test.cc index 6333a5826..0a9cafdf6 100644 --- a/eval/public/unknown_set_test.cc +++ b/eval/public/unknown_set_test.cc @@ -19,9 +19,8 @@ using testing::UnorderedElementsAre; UnknownFunctionResultSet MakeFunctionResult(Arena* arena, int64_t id) { CelFunctionDescriptor desc("OneInt", false, {CelValue::Type::kInt64}); - std::vector call_args{CelValue::CreateInt64(id)}; - const auto* function_result = Arena::Create( - arena, desc, /*expr_id=*/0, call_args); + const auto* function_result = + Arena::Create(arena, desc, /*expr_id=*/0); return UnknownFunctionResultSet(function_result); } @@ -48,17 +47,6 @@ MATCHER_P(UnknownAttributeIs, id, "") { return maybe_qualifier.value() == id; } -MATCHER_P(UnknownFunctionResultIs, id, "") { - const UnknownFunctionResult* result = arg; - if (result->arguments().size() != 1) { - return false; - } - if (!result->arguments()[0].IsInt64()) { - return false; - } - return result->arguments()[0].Int64OrDie() == id; -} - TEST(UnknownSet, AttributesMerge) { Arena arena; UnknownSet a(MakeAttribute(&arena, 1)); @@ -75,23 +63,6 @@ TEST(UnknownSet, AttributesMerge) { UnorderedElementsAre(UnknownAttributeIs(1), UnknownAttributeIs(2))); } -TEST(UnknownSet, FunctionsMerge) { - Arena arena; - - UnknownSet a(MakeFunctionResult(&arena, 1)); - UnknownSet b(MakeFunctionResult(&arena, 2)); - UnknownSet c(MakeFunctionResult(&arena, 2)); - UnknownSet d(a, b); - UnknownSet e(c, d); - - EXPECT_THAT(d.unknown_function_results().unknown_function_results(), - UnorderedElementsAre(UnknownFunctionResultIs(1), - UnknownFunctionResultIs(2))); - EXPECT_THAT(e.unknown_function_results().unknown_function_results(), - UnorderedElementsAre(UnknownFunctionResultIs(1), - UnknownFunctionResultIs(2))); -} - TEST(UnknownSet, DefaultEmpty) { UnknownSet empty_set; EXPECT_THAT(empty_set.unknown_attributes().attributes(), IsEmpty()); @@ -110,15 +81,9 @@ TEST(UnknownSet, MixedMerges) { EXPECT_THAT(d.unknown_attributes().attributes(), UnorderedElementsAre(UnknownAttributeIs(1))); - EXPECT_THAT(d.unknown_function_results().unknown_function_results(), - UnorderedElementsAre(UnknownFunctionResultIs(1), - UnknownFunctionResultIs(2))); EXPECT_THAT( e.unknown_attributes().attributes(), UnorderedElementsAre(UnknownAttributeIs(1), UnknownAttributeIs(2))); - EXPECT_THAT(e.unknown_function_results().unknown_function_results(), - UnorderedElementsAre(UnknownFunctionResultIs(1), - UnknownFunctionResultIs(2))); } } // namespace diff --git a/eval/tests/unknowns_end_to_end_test.cc b/eval/tests/unknowns_end_to_end_test.cc index 1d9a04fdd..cd873ea51 100644 --- a/eval/tests/unknowns_end_to_end_test.cc +++ b/eval/tests/unknowns_end_to_end_test.cc @@ -161,11 +161,9 @@ class UnknownsTest : public testing::Test { google::api::expr::v1alpha1::Expr expr_; }; -MATCHER_P2(FunctionCallIs, fn_name, fn_arg, "") { +MATCHER_P(FunctionCallIs, fn_name, "") { const UnknownFunctionResult* result = arg; - return result->arguments().size() == 1 && result->arguments()[0].IsString() && - result->arguments()[0].StringOrDie().value() == fn_arg && - result->descriptor().name() == fn_name; + return result->descriptor().name() == fn_name; } MATCHER_P(AttributeIs, attr, "") { @@ -280,7 +278,7 @@ TEST_F(UnknownsTest, UnknownFunctions) { EXPECT_THAT(response.UnknownSetOrDie() ->unknown_function_results() .unknown_function_results(), - ElementsAre(FunctionCallIs("F1", "arg1"))); + ElementsAre(FunctionCallIs("F1"))); } TEST_F(UnknownsTest, UnknownsMerge) { @@ -305,7 +303,7 @@ TEST_F(UnknownsTest, UnknownsMerge) { EXPECT_THAT(response.UnknownSetOrDie() ->unknown_function_results() .unknown_function_results(), - ElementsAre(FunctionCallIs("F1", "arg1"))); + ElementsAre(FunctionCallIs("F1"))); EXPECT_THAT(response.UnknownSetOrDie()->unknown_attributes().attributes(), ElementsAre(AttributeIs("var2"))); } @@ -457,7 +455,7 @@ TEST_F(UnknownsCompTest, UnknownsMerge) { EXPECT_THAT(response.UnknownSetOrDie() ->unknown_function_results() .unknown_function_results(), - testing::SizeIs(10)); + testing::SizeIs(1)); } constexpr char kListCompCondExpr[] = R"pb( From 813c83f1d8479a200c31658b5cf9826a540849b8 Mon Sep 17 00:00:00 2001 From: jcking Date: Thu, 19 May 2022 20:23:03 +0000 Subject: [PATCH 102/111] Internal change PiperOrigin-RevId: 449819425 --- base/internal/type.post.h | 1 + base/internal/value.post.h | 1 + base/type.cc | 6 +++ base/type.h | 29 +++++++++++++ base/type_factory.cc | 4 ++ base/type_factory.h | 2 + base/type_test.cc | 85 +++++++++++++++++++++++++++++--------- base/value.cc | 26 ++++++++++++ base/value.h | 39 +++++++++++++++++ base/value_factory.cc | 5 +++ base/value_factory.h | 3 ++ base/value_test.cc | 34 +++++++++++++++ 12 files changed, 215 insertions(+), 20 deletions(-) diff --git a/base/internal/type.post.h b/base/internal/type.post.h index 35111acc9..782de403a 100644 --- a/base/internal/type.post.h +++ b/base/internal/type.post.h @@ -277,6 +277,7 @@ CEL_INTERNAL_TYPE_DECL(EnumType); CEL_INTERNAL_TYPE_DECL(StructType); CEL_INTERNAL_TYPE_DECL(ListType); CEL_INTERNAL_TYPE_DECL(MapType); +CEL_INTERNAL_TYPE_DECL(TypeType); #undef CEL_INTERNAL_TYPE_DECL } // namespace cel diff --git a/base/internal/value.post.h b/base/internal/value.post.h index cbef6bf19..fafc3da87 100644 --- a/base/internal/value.post.h +++ b/base/internal/value.post.h @@ -679,6 +679,7 @@ CEL_INTERNAL_VALUE_DECL(EnumValue); CEL_INTERNAL_VALUE_DECL(StructValue); CEL_INTERNAL_VALUE_DECL(ListValue); CEL_INTERNAL_VALUE_DECL(MapValue); +CEL_INTERNAL_VALUE_DECL(TypeValue); #undef CEL_INTERNAL_VALUE_DECL } // namespace cel diff --git a/base/type.cc b/base/type.cc index dbaa8cada..82da70757 100644 --- a/base/type.cc +++ b/base/type.cc @@ -49,6 +49,7 @@ CEL_INTERNAL_TYPE_IMPL(EnumType); CEL_INTERNAL_TYPE_IMPL(StructType); CEL_INTERNAL_TYPE_IMPL(ListType); CEL_INTERNAL_TYPE_IMPL(MapType); +CEL_INTERNAL_TYPE_IMPL(TypeType); #undef CEL_INTERNAL_TYPE_IMPL absl::Span> Type::parameters() const { return {}; } @@ -198,4 +199,9 @@ void MapType::HashValue(absl::HashState state) const { Type::HashValue(absl::HashState::combine(std::move(state), key(), value())); } +const TypeType& TypeType::Get() { + static const internal::NoDestructor instance; + return *instance; +} + } // namespace cel diff --git a/base/type.h b/base/type.h index 2e9314278..5e8b5beb0 100644 --- a/base/type.h +++ b/base/type.h @@ -50,6 +50,7 @@ class TimestampType; class EnumType; class ListType; class MapType; +class TypeType; class TypeFactory; class TypeProvider; class TypeManager; @@ -66,6 +67,7 @@ class DurationValue; class TimestampValue; class EnumValue; class StructValue; +class TypeValue; class ValueFactory; class TypedEnumValueFactory; class TypedStructValueFactory; @@ -107,6 +109,7 @@ class Type : public base_internal::Resource { friend class StructType; friend class ListType; friend class MapType; + friend class TypeType; friend class base_internal::TypeHandleBase; Type() = default; @@ -715,6 +718,32 @@ class MapType : public Type { void HashValue(absl::HashState state) const final; }; +// TypeType represents the type of a type. +class TypeType final : public Type { + public: + Kind kind() const override { return Kind::kType; } + + absl::string_view name() const override { return "type"; } + + private: + friend class TypeValue; + friend class TypeFactory; + template + friend class internal::NoDestructor; + friend class base_internal::TypeHandleBase; + + // Called by base_internal::TypeHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Type& type) { return type.kind() == Kind::kType; } + + ABSL_ATTRIBUTE_PURE_FUNCTION static const TypeType& Get(); + + TypeType() = default; + + TypeType(const TypeType&) = delete; + TypeType(TypeType&&) = delete; +}; + } // namespace cel // type.pre.h forward declares types so they can be friended above. The types diff --git a/base/type_factory.cc b/base/type_factory.cc index b29a9ae30..b3f3d2aaa 100644 --- a/base/type_factory.cc +++ b/base/type_factory.cc @@ -116,6 +116,10 @@ Persistent TypeFactory::GetTimestampType() { return WrapSingletonType(); } +Persistent TypeFactory::GetTypeType() { + return WrapSingletonType(); +} + absl::StatusOr> TypeFactory::CreateListType( const Persistent& element) { absl::MutexLock lock(&list_types_mutex_); diff --git a/base/type_factory.h b/base/type_factory.h index 0ceab92cb..914cc5bdd 100644 --- a/base/type_factory.h +++ b/base/type_factory.h @@ -96,6 +96,8 @@ class TypeFactory { const Persistent& key, const Persistent& value) ABSL_ATTRIBUTE_LIFETIME_BOUND; + Persistent GetTypeType() ABSL_ATTRIBUTE_LIFETIME_BOUND; + private: template static Persistent WrapSingletonType() { diff --git a/base/type_test.cc b/base/type_test.cc index 5a4e844e6..6f3fe1de7 100644 --- a/base/type_test.cc +++ b/base/type_test.cc @@ -269,6 +269,7 @@ TEST_P(TypeTest, Null) { EXPECT_FALSE(type_factory.GetNullType().Is()); EXPECT_FALSE(type_factory.GetNullType().Is()); EXPECT_FALSE(type_factory.GetNullType().Is()); + EXPECT_FALSE(type_factory.GetNullType().Is()); } TEST_P(TypeTest, Error) { @@ -291,6 +292,7 @@ TEST_P(TypeTest, Error) { EXPECT_FALSE(type_factory.GetErrorType().Is()); EXPECT_FALSE(type_factory.GetErrorType().Is()); EXPECT_FALSE(type_factory.GetErrorType().Is()); + EXPECT_FALSE(type_factory.GetErrorType().Is()); } TEST_P(TypeTest, Dyn) { @@ -313,6 +315,7 @@ TEST_P(TypeTest, Dyn) { EXPECT_FALSE(type_factory.GetDynType().Is()); EXPECT_FALSE(type_factory.GetDynType().Is()); EXPECT_FALSE(type_factory.GetDynType().Is()); + EXPECT_FALSE(type_factory.GetDynType().Is()); } TEST_P(TypeTest, Any) { @@ -335,6 +338,7 @@ TEST_P(TypeTest, Any) { EXPECT_FALSE(type_factory.GetAnyType().Is()); EXPECT_FALSE(type_factory.GetAnyType().Is()); EXPECT_FALSE(type_factory.GetAnyType().Is()); + EXPECT_FALSE(type_factory.GetAnyType().Is()); } TEST_P(TypeTest, Bool) { @@ -357,6 +361,7 @@ TEST_P(TypeTest, Bool) { EXPECT_FALSE(type_factory.GetBoolType().Is()); EXPECT_FALSE(type_factory.GetBoolType().Is()); EXPECT_FALSE(type_factory.GetBoolType().Is()); + EXPECT_FALSE(type_factory.GetBoolType().Is()); } TEST_P(TypeTest, Int) { @@ -379,6 +384,7 @@ TEST_P(TypeTest, Int) { EXPECT_FALSE(type_factory.GetIntType().Is()); EXPECT_FALSE(type_factory.GetIntType().Is()); EXPECT_FALSE(type_factory.GetIntType().Is()); + EXPECT_FALSE(type_factory.GetIntType().Is()); } TEST_P(TypeTest, Uint) { @@ -401,6 +407,7 @@ TEST_P(TypeTest, Uint) { EXPECT_FALSE(type_factory.GetUintType().Is()); EXPECT_FALSE(type_factory.GetUintType().Is()); EXPECT_FALSE(type_factory.GetUintType().Is()); + EXPECT_FALSE(type_factory.GetUintType().Is()); } TEST_P(TypeTest, Double) { @@ -423,6 +430,7 @@ TEST_P(TypeTest, Double) { EXPECT_FALSE(type_factory.GetDoubleType().Is()); EXPECT_FALSE(type_factory.GetDoubleType().Is()); EXPECT_FALSE(type_factory.GetDoubleType().Is()); + EXPECT_FALSE(type_factory.GetDoubleType().Is()); } TEST_P(TypeTest, String) { @@ -445,6 +453,7 @@ TEST_P(TypeTest, String) { EXPECT_FALSE(type_factory.GetStringType().Is()); EXPECT_FALSE(type_factory.GetStringType().Is()); EXPECT_FALSE(type_factory.GetStringType().Is()); + EXPECT_FALSE(type_factory.GetStringType().Is()); } TEST_P(TypeTest, Bytes) { @@ -467,6 +476,7 @@ TEST_P(TypeTest, Bytes) { EXPECT_FALSE(type_factory.GetBytesType().Is()); EXPECT_FALSE(type_factory.GetBytesType().Is()); EXPECT_FALSE(type_factory.GetBytesType().Is()); + EXPECT_FALSE(type_factory.GetBytesType().Is()); } TEST_P(TypeTest, Duration) { @@ -489,6 +499,7 @@ TEST_P(TypeTest, Duration) { EXPECT_FALSE(type_factory.GetDurationType().Is()); EXPECT_FALSE(type_factory.GetDurationType().Is()); EXPECT_FALSE(type_factory.GetDurationType().Is()); + EXPECT_FALSE(type_factory.GetDurationType().Is()); } TEST_P(TypeTest, Timestamp) { @@ -512,6 +523,7 @@ TEST_P(TypeTest, Timestamp) { EXPECT_FALSE(type_factory.GetTimestampType().Is()); EXPECT_FALSE(type_factory.GetTimestampType().Is()); EXPECT_FALSE(type_factory.GetTimestampType().Is()); + EXPECT_FALSE(type_factory.GetTimestampType().Is()); } TEST_P(TypeTest, Enum) { @@ -537,31 +549,33 @@ TEST_P(TypeTest, Enum) { EXPECT_FALSE(enum_type.Is()); EXPECT_FALSE(enum_type.Is()); EXPECT_FALSE(enum_type.Is()); + EXPECT_FALSE(enum_type.Is()); } TEST_P(TypeTest, Struct) { TypeManager type_manager(memory_manager()); - ASSERT_OK_AND_ASSIGN(auto enum_type, + ASSERT_OK_AND_ASSIGN(auto struct_type, type_manager.CreateStructType()); - EXPECT_EQ(enum_type->kind(), Kind::kStruct); - EXPECT_EQ(enum_type->name(), "test_struct.TestStruct"); - EXPECT_THAT(enum_type->parameters(), SizeIs(0)); - EXPECT_FALSE(enum_type.Is()); - EXPECT_FALSE(enum_type.Is()); - EXPECT_FALSE(enum_type.Is()); - EXPECT_FALSE(enum_type.Is()); - EXPECT_FALSE(enum_type.Is()); - EXPECT_FALSE(enum_type.Is()); - EXPECT_FALSE(enum_type.Is()); - EXPECT_FALSE(enum_type.Is()); - EXPECT_FALSE(enum_type.Is()); - EXPECT_FALSE(enum_type.Is()); - EXPECT_FALSE(enum_type.Is()); - EXPECT_FALSE(enum_type.Is()); - EXPECT_TRUE(enum_type.Is()); - EXPECT_TRUE(enum_type.Is()); - EXPECT_FALSE(enum_type.Is()); - EXPECT_FALSE(enum_type.Is()); + EXPECT_EQ(struct_type->kind(), Kind::kStruct); + EXPECT_EQ(struct_type->name(), "test_struct.TestStruct"); + EXPECT_THAT(struct_type->parameters(), SizeIs(0)); + EXPECT_FALSE(struct_type.Is()); + EXPECT_FALSE(struct_type.Is()); + EXPECT_FALSE(struct_type.Is()); + EXPECT_FALSE(struct_type.Is()); + EXPECT_FALSE(struct_type.Is()); + EXPECT_FALSE(struct_type.Is()); + EXPECT_FALSE(struct_type.Is()); + EXPECT_FALSE(struct_type.Is()); + EXPECT_FALSE(struct_type.Is()); + EXPECT_FALSE(struct_type.Is()); + EXPECT_FALSE(struct_type.Is()); + EXPECT_FALSE(struct_type.Is()); + EXPECT_TRUE(struct_type.Is()); + EXPECT_TRUE(struct_type.Is()); + EXPECT_FALSE(struct_type.Is()); + EXPECT_FALSE(struct_type.Is()); + EXPECT_FALSE(struct_type.Is()); } TEST_P(TypeTest, List) { @@ -589,6 +603,7 @@ TEST_P(TypeTest, List) { EXPECT_FALSE(list_type.Is()); EXPECT_TRUE(list_type.Is()); EXPECT_FALSE(list_type.Is()); + EXPECT_FALSE(list_type.Is()); } TEST_P(TypeTest, Map) { @@ -622,6 +637,30 @@ TEST_P(TypeTest, Map) { EXPECT_FALSE(map_type.Is()); EXPECT_FALSE(map_type.Is()); EXPECT_TRUE(map_type.Is()); + EXPECT_FALSE(map_type.Is()); +} + +TEST_P(TypeTest, TypeType) { + TypeFactory type_factory(memory_manager()); + EXPECT_EQ(type_factory.GetTypeType()->kind(), Kind::kType); + EXPECT_EQ(type_factory.GetTypeType()->name(), "type"); + EXPECT_THAT(type_factory.GetTypeType()->parameters(), SizeIs(0)); + EXPECT_FALSE(type_factory.GetTypeType().Is()); + EXPECT_FALSE(type_factory.GetTypeType().Is()); + EXPECT_FALSE(type_factory.GetTypeType().Is()); + EXPECT_FALSE(type_factory.GetTypeType().Is()); + EXPECT_FALSE(type_factory.GetTypeType().Is()); + EXPECT_FALSE(type_factory.GetTypeType().Is()); + EXPECT_FALSE(type_factory.GetTypeType().Is()); + EXPECT_FALSE(type_factory.GetTypeType().Is()); + EXPECT_FALSE(type_factory.GetTypeType().Is()); + EXPECT_FALSE(type_factory.GetTypeType().Is()); + EXPECT_FALSE(type_factory.GetTypeType().Is()); + EXPECT_FALSE(type_factory.GetTypeType().Is()); + EXPECT_FALSE(type_factory.GetTypeType().Is()); + EXPECT_FALSE(type_factory.GetTypeType().Is()); + EXPECT_FALSE(type_factory.GetTypeType().Is()); + EXPECT_TRUE(type_factory.GetTypeType().Is()); } using EnumTypeTest = TypeTest; @@ -824,6 +863,11 @@ TEST_P(DebugStringTest, MapType) { EXPECT_EQ(map_type->DebugString(), "map(string, bool)"); } +TEST_P(DebugStringTest, TypeType) { + TypeFactory type_factory(memory_manager()); + EXPECT_EQ(type_factory.GetTypeType()->DebugString(), "type"); +} + INSTANTIATE_TEST_SUITE_P(DebugStringTest, DebugStringTest, base_internal::MemoryManagerTestModeAll(), base_internal::MemoryManagerTestModeName); @@ -850,6 +894,7 @@ TEST_P(TypeTest, SupportsAbslHash) { Must(type_factory.CreateListType(type_factory.GetBoolType()))), Persistent(Must(type_factory.CreateMapType( type_factory.GetStringType(), type_factory.GetBoolType()))), + Persistent(type_factory.GetTypeType()), })); } diff --git a/base/value.cc b/base/value.cc index ed9c7b017..e5165ddb6 100644 --- a/base/value.cc +++ b/base/value.cc @@ -65,6 +65,8 @@ CEL_INTERNAL_VALUE_IMPL(TimestampValue); CEL_INTERNAL_VALUE_IMPL(EnumValue); CEL_INTERNAL_VALUE_IMPL(StructValue); CEL_INTERNAL_VALUE_IMPL(ListValue); +CEL_INTERNAL_VALUE_IMPL(MapValue); +CEL_INTERNAL_VALUE_IMPL(TypeValue); #undef CEL_INTERNAL_VALUE_IMPL namespace { @@ -876,6 +878,30 @@ absl::StatusOr StructValue::HasField(FieldId field) const { return absl::visit(HasFieldVisitor{*this}, field.data_); } +Transient TypeValue::type() const { + return TransientHandleFactory::MakeUnmanaged( + TypeType::Get()); +} + +std::string TypeValue::DebugString() const { return value()->DebugString(); } + +void TypeValue::CopyTo(Value& address) const { + CEL_COPY_TO_IMPL(TypeValue, *this, address); +} + +void TypeValue::MoveTo(Value& address) { + CEL_MOVE_TO_IMPL(TypeValue, *this, address); +} + +bool TypeValue::Equals(const Value& other) const { + return kind() == other.kind() && + value() == internal::down_cast(other).value(); +} + +void TypeValue::HashValue(absl::HashState state) const { + absl::HashState::combine(std::move(state), type(), value()); +} + namespace base_internal { absl::Cord InlinedCordBytesValue::ToCord(bool reference_counted) const { diff --git a/base/value.h b/base/value.h index 2cb47a93d..e383df2be 100644 --- a/base/value.h +++ b/base/value.h @@ -55,6 +55,7 @@ class EnumValue; class StructValue; class ListValue; class MapValue; +class TypeValue; class ValueFactory; namespace internal { @@ -91,6 +92,7 @@ class Value : public base_internal::Resource { friend class StructValue; friend class ListValue; friend class MapValue; + friend class TypeValue; friend class base_internal::ValueHandleBase; friend class base_internal::StringBytesValue; friend class base_internal::ExternalDataBytesValue; @@ -889,6 +891,43 @@ class MapValue : public Value { #define CEL_IMPLEMENT_MAP_VALUE(map_value) \ CEL_INTERNAL_IMPLEMENT_VALUE(Map, map_value) +// TypeValue represents an instance of cel::Type. +class TypeValue final : public Value, base_internal::ResourceInlined { + public: + Transient type() const override; + + Kind kind() const override { return Kind::kType; } + + std::string DebugString() const override; + + Transient value() const { return value_; } + + private: + template + friend class base_internal::ValueHandle; + friend class base_internal::ValueHandleBase; + + // Called by base_internal::ValueHandleBase to implement Is for Transient and + // Persistent. + static bool Is(const Value& value) { return value.kind() == Kind::kType; } + + // Called by `base_internal::ValueHandle` to construct value inline. + explicit TypeValue(Persistent type) : value_(std::move(type)) {} + + TypeValue() = delete; + + TypeValue(const TypeValue&) = default; + TypeValue(TypeValue&&) = default; + + // See comments for respective member functions on `Value`. + void CopyTo(Value& address) const override; + void MoveTo(Value& address) override; + bool Equals(const Value& other) const override; + void HashValue(absl::HashState state) const override; + + Persistent value_; +}; + } // namespace cel // value.pre.h forward declares types so they can be friended above. The types diff --git a/base/value_factory.cc b/base/value_factory.cc index d6831f9eb..5364e8f42 100644 --- a/base/value_factory.cc +++ b/base/value_factory.cc @@ -137,6 +137,11 @@ ValueFactory::CreateTimestampValue(absl::Time value) { value); } +Persistent ValueFactory::CreateTypeValue( + const Persistent& value) { + return PersistentHandleFactory::Make(value); +} + Persistent ValueFactory::GetEmptyBytesValue() { return PersistentHandleFactory::Make< InlinedStringViewBytesValue>(absl::string_view()); diff --git a/base/value_factory.h b/base/value_factory.h index 20829e2fb..f795adf61 100644 --- a/base/value_factory.h +++ b/base/value_factory.h @@ -174,6 +174,9 @@ class ValueFactory final { std::forward(args)...); } + Persistent CreateTypeValue( + const Persistent& value) ABSL_ATTRIBUTE_LIFETIME_BOUND; + private: friend class BytesValue; friend class StringValue; diff --git a/base/value_test.cc b/base/value_test.cc index 1e4ccd3c1..0f5c0b63d 100644 --- a/base/value_test.cc +++ b/base/value_test.cc @@ -688,6 +688,11 @@ INSTANTIATE_TEST_SUITE_P( type_factory.GetIntType())), std::map{})); }}, + {"Type", + [](TypeFactory& type_factory, + ValueFactory& value_factory) -> Persistent { + return value_factory.CreateTypeValue(type_factory.GetNullType()); + }}, })), [](const testing::TestParamInfo< std::tuple()); + EXPECT_FALSE(null_value.Is()); + EXPECT_EQ(null_value, null_value); + EXPECT_EQ(null_value, + value_factory.CreateTypeValue(type_factory.GetNullType())); + EXPECT_EQ(null_value->kind(), Kind::kType); + EXPECT_EQ(null_value->type(), type_factory.GetTypeType()); + EXPECT_EQ(null_value->value(), type_factory.GetNullType()); + + auto int_value = value_factory.CreateTypeValue(type_factory.GetIntType()); + EXPECT_TRUE(int_value.Is()); + EXPECT_FALSE(int_value.Is()); + EXPECT_EQ(int_value, int_value); + EXPECT_EQ(int_value, + value_factory.CreateTypeValue(type_factory.GetIntType())); + EXPECT_EQ(int_value->kind(), Kind::kType); + EXPECT_EQ(int_value->type(), type_factory.GetTypeType()); + EXPECT_EQ(int_value->value(), type_factory.GetIntType()); + + EXPECT_NE(null_value, int_value); + EXPECT_NE(int_value, null_value); +} + Persistent MakeStringBytes(ValueFactory& value_factory, absl::string_view value) { return Must(value_factory.CreateBytesValue(value)); @@ -2389,6 +2421,8 @@ TEST_P(ValueTest, SupportsAbslHash) { Persistent(struct_value), Persistent(list_value), Persistent(map_value), + Persistent( + value_factory.CreateTypeValue(type_factory.GetNullType())), })); } From 345f12ff0de5f63c07203daf92fe4d07b6811e41 Mon Sep 17 00:00:00 2001 From: jcking Date: Fri, 20 May 2022 21:27:23 +0000 Subject: [PATCH 103/111] Internal change PiperOrigin-RevId: 450063543 --- base/BUILD | 3 + base/type_manager.cc | 57 ++++++++++++++++ base/type_manager.h | 20 +++++- base/type_provider.cc | 147 ++++++++++++++++++++++++++++++++++++++++++ base/type_provider.h | 4 ++ 5 files changed, 230 insertions(+), 1 deletion(-) create mode 100644 base/type_manager.cc create mode 100644 base/type_provider.cc diff --git a/base/BUILD b/base/BUILD index 1eb2ef747..cfab9e597 100644 --- a/base/BUILD +++ b/base/BUILD @@ -104,6 +104,8 @@ cc_library( srcs = [ "type.cc", "type_factory.cc", + "type_manager.cc", + "type_provider.cc", ], hdrs = [ "type.h", @@ -120,6 +122,7 @@ cc_library( "//internal:casts", "//internal:no_destructor", "//internal:rtti", + "//internal:status_macros", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/hash", diff --git a/base/type_manager.cc b/base/type_manager.cc new file mode 100644 index 000000000..ed706a361 --- /dev/null +++ b/base/type_manager.cc @@ -0,0 +1,57 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/type_manager.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/synchronization/mutex.h" +#include "internal/status_macros.h" + +namespace cel { + +absl::StatusOr> TypeManager::ProvideType( + absl::string_view name) { + // The const_cast is safe because TypeFactory can never call back into + // TypeRegistry or TypeManager methods we expose or override in TypeManager. + // Thus for state defined by TypeManager, we are effectively const. + return ProvideType(const_cast(*this), name); +} + +absl::StatusOr> TypeManager::ProvideType( + TypeFactory& type_factory, absl::string_view name) const { + { + // Check for builtin types first. + CEL_ASSIGN_OR_RETURN( + auto type, TypeProvider::Builtin().ProvideType(type_factory, name)); + if (type) { + return type; + } + } + // Check with the type registry. + absl::MutexLock lock(&mutex_); + auto existing = types_.find(name); + if (existing == types_.end()) { + // Delegate to TypeRegistry implementation. + CEL_ASSIGN_OR_RETURN(auto type, + TypeRegistry::ProvideType(type_factory, name)); + ABSL_ASSERT(!type || type->name() == name); + existing = types_.insert({std::string(name), std::move(type)}).first; + } + return existing->second; +} + +} // namespace cel diff --git a/base/type_manager.h b/base/type_manager.h index 28353e6b7..a50710eeb 100644 --- a/base/type_manager.h +++ b/base/type_manager.h @@ -15,7 +15,14 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_TYPE_MANAGER_H_ #define THIRD_PARTY_CEL_CPP_BASE_TYPE_MANAGER_H_ +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/synchronization/mutex.h" #include "base/type_factory.h" +#include "base/type_provider.h" #include "base/type_registry.h" namespace cel { @@ -25,9 +32,20 @@ namespace cel { // and registering type implementations. // // TODO(issues/5): more comments after solidifying role -class TypeManager : public TypeFactory, public TypeRegistry { +class TypeManager final : public TypeFactory, public TypeRegistry { public: using TypeFactory::TypeFactory; + + absl::StatusOr> ProvideType(absl::string_view name); + + absl::StatusOr> ProvideType( + TypeFactory& type_factory, absl::string_view name) const override; + + private: + mutable absl::Mutex mutex_; + // std::string as the key because we also cache types which do not exist. + mutable absl::flat_hash_map> types_ + ABSL_GUARDED_BY(mutex_); }; } // namespace cel diff --git a/base/type_provider.cc b/base/type_provider.cc new file mode 100644 index 000000000..6e42993d2 --- /dev/null +++ b/base/type_provider.cc @@ -0,0 +1,147 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/type_provider.h" + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "base/type_factory.h" +#include "internal/no_destructor.h" + +namespace cel { + +namespace { + +class BuiltinTypeProvider final : public TypeProvider { + public: + using BuiltinType = + std::pair> (*)(TypeFactory&)>; + + BuiltinTypeProvider() + : types_{{ + {"null_type", GetNullType}, + {"bool", GetBoolType}, + {"int", GetIntType}, + {"uint", GetUintType}, + {"double", GetDoubleType}, + {"bytes", GetBytesType}, + {"string", GetStringType}, + {"google.protobuf.Duration", GetDurationType}, + {"google.protobuf.Timestamp", GetTimestampType}, + {"list", GetListType}, + {"map", GetMapType}, + {"type", GetTypeType}, + }} { + std::stable_sort( + types_.begin(), types_.end(), + [](const BuiltinType& lhs, const BuiltinType& rhs) -> bool { + return lhs.first < rhs.first; + }); + } + + absl::StatusOr> ProvideType( + TypeFactory& type_factory, absl::string_view name) const override { + auto existing = std::lower_bound( + types_.begin(), types_.end(), name, + [](const BuiltinType& lhs, absl::string_view rhs) -> bool { + return lhs.first < rhs; + }); + if (existing == types_.end() || existing->first != name) { + return Persistent(); + } + return (existing->second)(type_factory); + } + + private: + static absl::StatusOr> GetNullType( + TypeFactory& type_factory) { + return type_factory.GetNullType(); + } + + static absl::StatusOr> GetBoolType( + TypeFactory& type_factory) { + return type_factory.GetBoolType(); + } + + static absl::StatusOr> GetIntType( + TypeFactory& type_factory) { + return type_factory.GetIntType(); + } + + static absl::StatusOr> GetUintType( + TypeFactory& type_factory) { + return type_factory.GetUintType(); + } + + static absl::StatusOr> GetDoubleType( + TypeFactory& type_factory) { + return type_factory.GetDoubleType(); + } + + static absl::StatusOr> GetBytesType( + TypeFactory& type_factory) { + return type_factory.GetBytesType(); + } + + static absl::StatusOr> GetStringType( + TypeFactory& type_factory) { + return type_factory.GetStringType(); + } + + static absl::StatusOr> GetDurationType( + TypeFactory& type_factory) { + return type_factory.GetDurationType(); + } + + static absl::StatusOr> GetTimestampType( + TypeFactory& type_factory) { + return type_factory.GetTimestampType(); + } + + static absl::StatusOr> GetListType( + TypeFactory& type_factory) { + // The element type does not matter. + return type_factory.CreateListType(type_factory.GetDynType()); + } + + static absl::StatusOr> GetMapType( + TypeFactory& type_factory) { + // The key and value types do not matter. + return type_factory.CreateMapType(type_factory.GetDynType(), + type_factory.GetDynType()); + } + + static absl::StatusOr> GetTypeType( + TypeFactory& type_factory) { + return type_factory.GetTypeType(); + } + + std::array types_; +}; + +} // namespace + +const TypeProvider& TypeProvider::Builtin() { + static const internal::NoDestructor instance; + return *instance; +} + +} // namespace cel diff --git a/base/type_provider.h b/base/type_provider.h index 8a481801c..1db6ae7e9 100644 --- a/base/type_provider.h +++ b/base/type_provider.h @@ -15,6 +15,7 @@ #ifndef THIRD_PARTY_CEL_CPP_BASE_TYPE_PROVIDER_H_ #define THIRD_PARTY_CEL_CPP_BASE_TYPE_PROVIDER_H_ +#include "absl/base/attributes.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" @@ -37,6 +38,9 @@ class TypeFactory; // implementations. class TypeProvider { public: + // Returns a TypeProvider which provides all of CEL's builtin types. + ABSL_ATTRIBUTE_PURE_FUNCTION static const TypeProvider& Builtin(); + virtual ~TypeProvider() = default; // Return a persistent handle to a Type for the fully qualified type name, if From 2929d2b11a2e5bbfeedac86f093a9108bb2faa05 Mon Sep 17 00:00:00 2001 From: jcking Date: Tue, 24 May 2022 21:58:45 +0000 Subject: [PATCH 104/111] Internal change PiperOrigin-RevId: 450776472 --- base/internal/value.pre.h | 9 +++++++++ base/value.h | 17 +++++++++++++---- base/value_factory.cc | 12 ++++++++++++ base/value_factory.h | 19 +++++++++++++++++++ 4 files changed, 53 insertions(+), 4 deletions(-) diff --git a/base/internal/value.pre.h b/base/internal/value.pre.h index 4441bc7d9..ebd7c1685 100644 --- a/base/internal/value.pre.h +++ b/base/internal/value.pre.h @@ -18,9 +18,13 @@ #define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_VALUE_PRE_H_ #include +#include #include #include +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "absl/types/variant.h" #include "base/handle.h" #include "internal/rtti.h" @@ -173,6 +177,11 @@ struct ExternalData final { std::unique_ptr releaser; }; +using StringValueRep = + absl::variant>; +using BytesValueRep = + absl::variant>; + } // namespace base_internal } // namespace cel diff --git a/base/value.h b/base/value.h index e383df2be..72e1cd626 100644 --- a/base/value.h +++ b/base/value.h @@ -63,6 +63,13 @@ template class NoDestructor; } +namespace interop_internal { +base_internal::StringValueRep GetStringValueRep( + const Transient& value); +base_internal::BytesValueRep GetBytesValueRep( + const Transient& value); +} // namespace interop_internal + // A representation of a CEL value that enables reflection and introspection of // values. class Value : public base_internal::Resource { @@ -357,8 +364,7 @@ class DoubleValue final : public Value, public base_internal::ResourceInlined { class BytesValue : public Value { protected: - using Rep = absl::variant>; + using Rep = base_internal::BytesValueRep; public: static Persistent Empty(ValueFactory& value_factory); @@ -403,6 +409,8 @@ class BytesValue : public Value { friend class base_internal::InlinedStringViewBytesValue; friend class base_internal::StringBytesValue; friend class base_internal::ExternalDataBytesValue; + friend base_internal::BytesValueRep interop_internal::GetBytesValueRep( + const Transient& value); // Called by base_internal::ValueHandleBase to implement Is for Transient and // Persistent. @@ -428,8 +436,7 @@ class BytesValue : public Value { class StringValue : public Value { protected: - using Rep = absl::variant>; + using Rep = base_internal::StringValueRep; public: static Persistent Empty(ValueFactory& value_factory); @@ -471,6 +478,8 @@ class StringValue : public Value { friend class base_internal::InlinedStringViewStringValue; friend class base_internal::StringStringValue; friend class base_internal::ExternalDataStringValue; + friend base_internal::StringValueRep interop_internal::GetStringValueRep( + const Transient& value); // Called by base_internal::ValueHandleBase to implement Is for Transient and // Persistent. diff --git a/base/value_factory.cc b/base/value_factory.cc index 5364e8f42..e410ee1c9 100644 --- a/base/value_factory.cc +++ b/base/value_factory.cc @@ -142,6 +142,12 @@ Persistent ValueFactory::CreateTypeValue( return PersistentHandleFactory::Make(value); } +absl::StatusOr> +ValueFactory::CreateBytesValueFromView(absl::string_view value) { + return PersistentHandleFactory::Make< + InlinedStringViewBytesValue>(value); +} + Persistent ValueFactory::GetEmptyBytesValue() { return PersistentHandleFactory::Make< InlinedStringViewBytesValue>(absl::string_view()); @@ -173,4 +179,10 @@ absl::StatusOr> ValueFactory::CreateStringValue( ExternalDataStringValue>(memory_manager(), std::move(value)); } +absl::StatusOr> +ValueFactory::CreateStringValueFromView(absl::string_view value) { + return PersistentHandleFactory::Make< + InlinedStringViewStringValue>(value); +} + } // namespace cel diff --git a/base/value_factory.h b/base/value_factory.h index f795adf61..281c6e6eb 100644 --- a/base/value_factory.h +++ b/base/value_factory.h @@ -33,6 +33,13 @@ namespace cel { +namespace interop_internal { +absl::StatusOr> CreateStringValueFromView( + cel::ValueFactory& value_factory, absl::string_view input); +absl::StatusOr> CreateBytesValueFromView( + cel::ValueFactory& value_factory, absl::string_view input); +} // namespace interop_internal + class ValueFactory final { private: template @@ -180,6 +187,12 @@ class ValueFactory final { private: friend class BytesValue; friend class StringValue; + friend absl::StatusOr> + interop_internal::CreateStringValueFromView(cel::ValueFactory& value_factory, + absl::string_view input); + friend absl::StatusOr> + interop_internal::CreateBytesValueFromView(cel::ValueFactory& value_factory, + absl::string_view input); MemoryManager& memory_manager() const { return memory_manager_; } @@ -189,6 +202,9 @@ class ValueFactory final { absl::StatusOr> CreateBytesValue( base_internal::ExternalData value) ABSL_ATTRIBUTE_LIFETIME_BOUND; + absl::StatusOr> CreateBytesValueFromView( + absl::string_view value) ABSL_ATTRIBUTE_LIFETIME_BOUND; + Persistent GetEmptyStringValue() ABSL_ATTRIBUTE_LIFETIME_BOUND; @@ -198,6 +214,9 @@ class ValueFactory final { absl::StatusOr> CreateStringValue( base_internal::ExternalData value) ABSL_ATTRIBUTE_LIFETIME_BOUND; + absl::StatusOr> CreateStringValueFromView( + absl::string_view value) ABSL_ATTRIBUTE_LIFETIME_BOUND; + MemoryManager& memory_manager_; }; From 95d4b2deb787aa5eae33b33ef9eac0e7a197a113 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 26 May 2022 18:52:46 +0000 Subject: [PATCH 105/111] Bring native types closer to protos. Do not assert on mutable_*() calls when underlying std::unique_ptr<*> is unset. PiperOrigin-RevId: 451211992 --- base/ast.h | 56 ++++++++++++++++++++++++++++-------- base/ast_test.cc | 75 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 119 insertions(+), 12 deletions(-) diff --git a/base/ast.h b/base/ast.h index 7eb0dfc25..5cba4267d 100644 --- a/base/ast.h +++ b/base/ast.h @@ -74,6 +74,7 @@ class Ident { // A field selection expression. e.g. `request.auth`. class Select { public: + Select() {} Select(std::unique_ptr operand, std::string field, bool test_only = false) : operand_(std::move(operand)), @@ -91,7 +92,9 @@ class Select { const Expr* operand() const { return operand_.get(); } Expr& mutable_operand() { - ABSL_ASSERT(operand_ != nullptr); + if (operand_ == nullptr) { + operand_ = std::make_unique(); + } return *operand_; } @@ -122,6 +125,7 @@ class Select { // (-- TODO(issues/5): Convert built-in globals to instance methods --) class Call { public: + Call() {} Call(std::unique_ptr target, std::string function, std::vector args) : target_(std::move(target)), @@ -137,7 +141,9 @@ class Call { const Expr* target() const { return target_.get(); } Expr& mutable_target() { - ABSL_ASSERT(target_ != nullptr); + if (target_ == nullptr) { + target_ = std::make_unique(); + } return *target_; } @@ -194,6 +200,7 @@ class CreateStruct { // Represents an entry. class Entry { public: + Entry() {} Entry(int64_t id, absl::variant> key_kind, std::unique_ptr value) @@ -221,7 +228,9 @@ class CreateStruct { const Expr* value() const { return value_.get(); } Expr& mutable_value() { - ABSL_ASSERT(value_ != nullptr); + if (value_ == nullptr) { + value_ = std::make_unique(); + } return *value_; } @@ -347,7 +356,9 @@ class Comprehension { const Expr* iter_range() const { return iter_range_.get(); } Expr& mutable_iter_range() { - ABSL_ASSERT(iter_range_ != nullptr); + if (iter_range_ == nullptr) { + iter_range_ = std::make_unique(); + } return *iter_range_; } @@ -356,28 +367,36 @@ class Comprehension { const Expr* accu_init() const { return accu_init_.get(); } Expr& mutable_accu_init() { - ABSL_ASSERT(accu_init_ != nullptr); + if (accu_init_ == nullptr) { + accu_init_ = std::make_unique(); + } return *accu_init_; } const Expr* loop_condition() const { return loop_condition_.get(); } Expr& mutable_loop_condition() { - ABSL_ASSERT(loop_condition_ != nullptr); + if (loop_condition_ == nullptr) { + loop_condition_ = std::make_unique(); + } return *loop_condition_; } const Expr* loop_step() const { return loop_step_.get(); } Expr& mutable_loop_step() { - ABSL_ASSERT(loop_step_ != nullptr); + if (loop_step_ == nullptr) { + loop_step_ = std::make_unique(); + } return *loop_step_; } const Expr* result() const { return result_.get(); } Expr& mutable_result() { - ABSL_ASSERT(result_ != nullptr); + if (result_ == nullptr) { + result_ = std::make_unique(); + } return *result_; } @@ -631,6 +650,8 @@ class Type; // List type with typed elements, e.g. `list`. class ListType { + public: + ListType() {} explicit ListType(std::unique_ptr elem_type) : elem_type_(std::move(elem_type)) {} @@ -641,7 +662,9 @@ class ListType { const Type* elem_type() const { return elem_type_.get(); } Type& mutable_elem_type() { - ABSL_ASSERT(elem_type_ != nullptr); + if (elem_type_ == nullptr) { + elem_type_ = std::make_unique(); + } return *elem_type_; } @@ -652,6 +675,7 @@ class ListType { // Map type with parameterized key and value types, e.g. `map`. class MapType { public: + MapType() {} MapType(std::unique_ptr key_type, std::unique_ptr value_type) : key_type_(std::move(key_type)), value_type_(std::move(value_type)) {} @@ -668,12 +692,16 @@ class MapType { const Type* value_type() const { return value_type_.get(); } Type& mutable_key_type() { - ABSL_ASSERT(key_type_ != nullptr); + if (key_type_ == nullptr) { + key_type_ = std::make_unique(); + } return *key_type_; } Type& mutable_value_type() { - ABSL_ASSERT(value_type_ != nullptr); + if (value_type_ == nullptr) { + value_type_ = std::make_unique(); + } return *value_type_; } @@ -693,6 +721,7 @@ class MapType { // --) class FunctionType { public: + FunctionType() {} FunctionType(std::unique_ptr result_type, std::vector arg_types) : result_type_(std::move(result_type)), arg_types_(std::move(arg_types)) {} @@ -708,7 +737,9 @@ class FunctionType { const Type* result_type() const { return result_type_.get(); } Type& mutable_result_type() { - ABSL_ASSERT(result_type_.get() != nullptr); + if (result_type_ == nullptr) { + result_type_ = std::make_unique(); + } return *result_type_; } @@ -819,6 +850,7 @@ using TypeKind = // TODO(issues/5): align with value.proto class Type { public: + Type() {} explicit Type(TypeKind type_kind) : type_kind_(std::move(type_kind)) {} Type(Type&& rhs) = default; diff --git a/base/ast_test.cc b/base/ast_test.cc index 987ef4b8a..8f1bf3bd7 100644 --- a/base/ast_test.cc +++ b/base/ast_test.cc @@ -48,6 +48,13 @@ TEST(AstTest, ExprConstructionSelect) { ASSERT_EQ(select.field(), "field"); } +TEST(AstTest, SelectMutableOperand) { + Select select; + select.mutable_operand().set_expr_kind(Ident("var")); + ASSERT_TRUE(absl::holds_alternative(select.operand()->expr_kind())); + ASSERT_EQ(absl::get(select.operand()->expr_kind()).name(), "var"); +} + TEST(AstTest, ExprConstructionCall) { Expr expr(1, Call(std::make_unique(2, Ident("var")), "function", {})); ASSERT_TRUE(absl::holds_alternative(expr.expr_kind())); @@ -58,6 +65,13 @@ TEST(AstTest, ExprConstructionCall) { ASSERT_TRUE(call.args().empty()); } +TEST(AstTest, CallMutableTarget) { + Call call; + call.mutable_target().set_expr_kind(Ident("var")); + ASSERT_TRUE(absl::holds_alternative(call.target()->expr_kind())); + ASSERT_EQ(absl::get(call.target()->expr_kind()).name(), "var"); +} + TEST(AstTest, ExprConstructionCreateList) { CreateList create_list; create_list.mutable_elements().emplace_back(Expr(2, Ident("var1"))); @@ -96,6 +110,13 @@ TEST(AstTest, ExprConstructionCreateStruct) { ASSERT_EQ(absl::get(entries[2].value()->expr_kind()).name(), "value3"); } +TEST(AstTest, CreateStructEntryMutableValue) { + CreateStruct::Entry entry; + entry.mutable_value().set_expr_kind(Ident("var")); + ASSERT_TRUE(absl::holds_alternative(entry.value()->expr_kind())); + ASSERT_EQ(absl::get(entry.value()->expr_kind()).name(), "var"); +} + TEST(AstTest, ExprConstructionComprehension) { Comprehension comprehension; comprehension.set_iter_var("iter_var"); @@ -122,6 +143,36 @@ TEST(AstTest, ExprConstructionComprehension) { "result"); } +TEST(AstTest, ComprehensionMutableConstruction) { + Comprehension comprehension; + comprehension.mutable_iter_range().set_expr_kind(Ident("var")); + ASSERT_TRUE( + absl::holds_alternative(comprehension.iter_range()->expr_kind())); + ASSERT_EQ(absl::get(comprehension.iter_range()->expr_kind()).name(), + "var"); + comprehension.mutable_accu_init().set_expr_kind(Ident("var")); + ASSERT_TRUE( + absl::holds_alternative(comprehension.accu_init()->expr_kind())); + ASSERT_EQ(absl::get(comprehension.accu_init()->expr_kind()).name(), + "var"); + comprehension.mutable_loop_condition().set_expr_kind(Ident("var")); + ASSERT_TRUE(absl::holds_alternative( + comprehension.loop_condition()->expr_kind())); + ASSERT_EQ( + absl::get(comprehension.loop_condition()->expr_kind()).name(), + "var"); + comprehension.mutable_loop_step().set_expr_kind(Ident("var")); + ASSERT_TRUE( + absl::holds_alternative(comprehension.loop_step()->expr_kind())); + ASSERT_EQ(absl::get(comprehension.loop_step()->expr_kind()).name(), + "var"); + comprehension.mutable_result().set_expr_kind(Ident("var")); + ASSERT_TRUE( + absl::holds_alternative(comprehension.result()->expr_kind())); + ASSERT_EQ(absl::get(comprehension.result()->expr_kind()).name(), + "var"); +} + TEST(AstTest, ExprMoveTest) { Expr expr(1, Ident("var")); ASSERT_TRUE(absl::holds_alternative(expr.expr_kind())); @@ -150,6 +201,30 @@ TEST(AstTest, ParsedExpr) { testing::UnorderedElementsAre(testing::Pair(1, 1), testing::Pair(2, 2))); } +TEST(AstTest, ListTypeMutableConstruction) { + ListType type; + type.mutable_elem_type() = Type(PrimitiveType::kBool); + EXPECT_EQ(absl::get(type.elem_type()->type_kind()), + PrimitiveType::kBool); +} + +TEST(AstTest, MapTypeMutableConstruction) { + MapType type; + type.mutable_key_type() = Type(PrimitiveType::kBool); + type.mutable_value_type() = Type(PrimitiveType::kBool); + EXPECT_EQ(absl::get(type.key_type()->type_kind()), + PrimitiveType::kBool); + EXPECT_EQ(absl::get(type.value_type()->type_kind()), + PrimitiveType::kBool); +} + +TEST(AstTest, FunctionTypeMutableConstruction) { + FunctionType type; + type.mutable_result_type() = Type(PrimitiveType::kBool); + EXPECT_EQ(absl::get(type.result_type()->type_kind()), + PrimitiveType::kBool); +} + TEST(AstTest, CheckedExpr) { CheckedExpr checked_expr; checked_expr.set_expr(Expr(1, Ident("name"))); From 21db34fc700f81f7d0f30306039511aa45bad3fa Mon Sep 17 00:00:00 2001 From: jcking Date: Sat, 28 May 2022 00:53:20 +0000 Subject: [PATCH 106/111] Internal change. PiperOrigin-RevId: 451524495 --- base/type_factory.h | 8 +- base/type_manager.cc | 14 +-- base/type_manager.h | 21 +++- base/type_provider.cc | 4 +- base/type_provider.h | 5 +- base/type_test.cc | 60 +++++----- base/value_factory.h | 20 +++- base/value_factory_test.cc | 8 +- base/value_test.cc | 225 ++++++++++++++++++++++++++----------- 9 files changed, 239 insertions(+), 126 deletions(-) diff --git a/base/type_factory.h b/base/type_factory.h index 914cc5bdd..2c2300e78 100644 --- a/base/type_factory.h +++ b/base/type_factory.h @@ -33,7 +33,7 @@ namespace cel { // // While TypeFactory is not final and has a virtual destructor, inheriting it is // forbidden outside of the CEL codebase. -class TypeFactory { +class TypeFactory final { private: template using EnableIfBaseOfT = @@ -44,8 +44,6 @@ class TypeFactory { MemoryManager& memory_manager ABSL_ATTRIBUTE_LIFETIME_BOUND) : memory_manager_(memory_manager) {} - virtual ~TypeFactory() = default; - TypeFactory(const TypeFactory&) = delete; TypeFactory& operator=(const TypeFactory&) = delete; @@ -98,6 +96,8 @@ class TypeFactory { Persistent GetTypeType() ABSL_ATTRIBUTE_LIFETIME_BOUND; + MemoryManager& memory_manager() const { return memory_manager_; } + private: template static Persistent WrapSingletonType() { @@ -109,8 +109,6 @@ class TypeFactory { const T>(T::Get())); } - MemoryManager& memory_manager() const { return memory_manager_; } - MemoryManager& memory_manager_; absl::Mutex list_types_mutex_; diff --git a/base/type_manager.cc b/base/type_manager.cc index ed706a361..796d38694 100644 --- a/base/type_manager.cc +++ b/base/type_manager.cc @@ -23,20 +23,12 @@ namespace cel { -absl::StatusOr> TypeManager::ProvideType( +absl::StatusOr> TypeManager::ResolveType( absl::string_view name) { - // The const_cast is safe because TypeFactory can never call back into - // TypeRegistry or TypeManager methods we expose or override in TypeManager. - // Thus for state defined by TypeManager, we are effectively const. - return ProvideType(const_cast(*this), name); -} - -absl::StatusOr> TypeManager::ProvideType( - TypeFactory& type_factory, absl::string_view name) const { { // Check for builtin types first. CEL_ASSIGN_OR_RETURN( - auto type, TypeProvider::Builtin().ProvideType(type_factory, name)); + auto type, TypeProvider::Builtin().ProvideType(type_factory(), name)); if (type) { return type; } @@ -47,7 +39,7 @@ absl::StatusOr> TypeManager::ProvideType( if (existing == types_.end()) { // Delegate to TypeRegistry implementation. CEL_ASSIGN_OR_RETURN(auto type, - TypeRegistry::ProvideType(type_factory, name)); + type_provider().ProvideType(type_factory(), name)); ABSL_ASSERT(!type || type->name() == name); existing = types_.insert({std::string(name), std::move(type)}).first; } diff --git a/base/type_manager.h b/base/type_manager.h index a50710eeb..bbeea1b3e 100644 --- a/base/type_manager.h +++ b/base/type_manager.h @@ -17,6 +17,7 @@ #include +#include "absl/base/attributes.h" #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" @@ -32,16 +33,26 @@ namespace cel { // and registering type implementations. // // TODO(issues/5): more comments after solidifying role -class TypeManager final : public TypeFactory, public TypeRegistry { +class TypeManager final { public: - using TypeFactory::TypeFactory; + TypeManager(TypeFactory& type_factory ABSL_ATTRIBUTE_LIFETIME_BOUND, + TypeProvider& type_provider ABSL_ATTRIBUTE_LIFETIME_BOUND) + : type_factory_(type_factory), type_provider_(type_provider) {} - absl::StatusOr> ProvideType(absl::string_view name); + MemoryManager& memory_manager() const { + return type_factory().memory_manager(); + } - absl::StatusOr> ProvideType( - TypeFactory& type_factory, absl::string_view name) const override; + TypeFactory& type_factory() const { return type_factory_; } + + TypeProvider& type_provider() const { return type_provider_; } + + absl::StatusOr> ResolveType(absl::string_view name); private: + TypeFactory& type_factory_; + TypeProvider& type_provider_; + mutable absl::Mutex mutex_; // std::string as the key because we also cache types which do not exist. mutable absl::flat_hash_map> types_ diff --git a/base/type_provider.cc b/base/type_provider.cc index 6e42993d2..c3bc38f2b 100644 --- a/base/type_provider.cc +++ b/base/type_provider.cc @@ -139,8 +139,8 @@ class BuiltinTypeProvider final : public TypeProvider { } // namespace -const TypeProvider& TypeProvider::Builtin() { - static const internal::NoDestructor instance; +TypeProvider& TypeProvider::Builtin() { + static internal::NoDestructor instance; return *instance; } diff --git a/base/type_provider.h b/base/type_provider.h index 1db6ae7e9..cde5befa8 100644 --- a/base/type_provider.h +++ b/base/type_provider.h @@ -38,8 +38,9 @@ class TypeFactory; // implementations. class TypeProvider { public: - // Returns a TypeProvider which provides all of CEL's builtin types. - ABSL_ATTRIBUTE_PURE_FUNCTION static const TypeProvider& Builtin(); + // Returns a TypeProvider which provides all of CEL's builtin types. It is + // thread safe. + ABSL_ATTRIBUTE_PURE_FUNCTION static TypeProvider& Builtin(); virtual ~TypeProvider() = default; diff --git a/base/type_test.cc b/base/type_test.cc index 6f3fe1de7..e9df8905c 100644 --- a/base/type_test.cc +++ b/base/type_test.cc @@ -105,13 +105,14 @@ class TestStructType final : public StructType { absl::StatusOr FindFieldByName(TypeManager& type_manager, absl::string_view name) const override { if (name == "bool_field") { - return Field("bool_field", 0, type_manager.GetBoolType()); + return Field("bool_field", 0, type_manager.type_factory().GetBoolType()); } else if (name == "int_field") { - return Field("int_field", 1, type_manager.GetIntType()); + return Field("int_field", 1, type_manager.type_factory().GetIntType()); } else if (name == "uint_field") { - return Field("uint_field", 2, type_manager.GetUintType()); + return Field("uint_field", 2, type_manager.type_factory().GetUintType()); } else if (name == "double_field") { - return Field("double_field", 3, type_manager.GetDoubleType()); + return Field("double_field", 3, + type_manager.type_factory().GetDoubleType()); } return absl::NotFoundError(""); } @@ -120,13 +121,16 @@ class TestStructType final : public StructType { int64_t number) const override { switch (number) { case 0: - return Field("bool_field", 0, type_manager.GetBoolType()); + return Field("bool_field", 0, + type_manager.type_factory().GetBoolType()); case 1: - return Field("int_field", 1, type_manager.GetIntType()); + return Field("int_field", 1, type_manager.type_factory().GetIntType()); case 2: - return Field("uint_field", 2, type_manager.GetUintType()); + return Field("uint_field", 2, + type_manager.type_factory().GetUintType()); case 3: - return Field("double_field", 3, type_manager.GetDoubleType()); + return Field("double_field", 3, + type_manager.type_factory().GetDoubleType()); default: return absl::NotFoundError(""); } @@ -553,9 +557,11 @@ TEST_P(TypeTest, Enum) { } TEST_P(TypeTest, Struct) { - TypeManager type_manager(memory_manager()); - ASSERT_OK_AND_ASSIGN(auto struct_type, - type_manager.CreateStructType()); + TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ASSERT_OK_AND_ASSIGN( + auto struct_type, + type_manager.type_factory().CreateStructType()); EXPECT_EQ(struct_type->kind(), Kind::kStruct); EXPECT_EQ(struct_type->name(), "test_struct.TestStruct"); EXPECT_THAT(struct_type->parameters(), SizeIs(0)); @@ -703,61 +709,63 @@ INSTANTIATE_TEST_SUITE_P(EnumTypeTest, EnumTypeTest, class StructTypeTest : public TypeTest {}; TEST_P(StructTypeTest, FindField) { - TypeManager type_manager(memory_manager()); - ASSERT_OK_AND_ASSIGN(auto struct_type, - type_manager.CreateStructType()); + TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ASSERT_OK_AND_ASSIGN( + auto struct_type, + type_manager.type_factory().CreateStructType()); ASSERT_OK_AND_ASSIGN( auto field1, struct_type->FindField(type_manager, StructType::FieldId("bool_field"))); EXPECT_EQ(field1.name, "bool_field"); EXPECT_EQ(field1.number, 0); - EXPECT_EQ(field1.type, type_manager.GetBoolType()); + EXPECT_EQ(field1.type, type_manager.type_factory().GetBoolType()); ASSERT_OK_AND_ASSIGN( field1, struct_type->FindField(type_manager, StructType::FieldId(0))); EXPECT_EQ(field1.name, "bool_field"); EXPECT_EQ(field1.number, 0); - EXPECT_EQ(field1.type, type_manager.GetBoolType()); + EXPECT_EQ(field1.type, type_manager.type_factory().GetBoolType()); ASSERT_OK_AND_ASSIGN( auto field2, struct_type->FindField(type_manager, StructType::FieldId("int_field"))); EXPECT_EQ(field2.name, "int_field"); EXPECT_EQ(field2.number, 1); - EXPECT_EQ(field2.type, type_manager.GetIntType()); + EXPECT_EQ(field2.type, type_manager.type_factory().GetIntType()); ASSERT_OK_AND_ASSIGN( field2, struct_type->FindField(type_manager, StructType::FieldId(1))); EXPECT_EQ(field2.name, "int_field"); EXPECT_EQ(field2.number, 1); - EXPECT_EQ(field2.type, type_manager.GetIntType()); + EXPECT_EQ(field2.type, type_manager.type_factory().GetIntType()); ASSERT_OK_AND_ASSIGN( auto field3, struct_type->FindField(type_manager, StructType::FieldId("uint_field"))); EXPECT_EQ(field3.name, "uint_field"); EXPECT_EQ(field3.number, 2); - EXPECT_EQ(field3.type, type_manager.GetUintType()); + EXPECT_EQ(field3.type, type_manager.type_factory().GetUintType()); ASSERT_OK_AND_ASSIGN( field3, struct_type->FindField(type_manager, StructType::FieldId(2))); EXPECT_EQ(field3.name, "uint_field"); EXPECT_EQ(field3.number, 2); - EXPECT_EQ(field3.type, type_manager.GetUintType()); + EXPECT_EQ(field3.type, type_manager.type_factory().GetUintType()); ASSERT_OK_AND_ASSIGN( auto field4, struct_type->FindField(type_manager, StructType::FieldId("double_field"))); EXPECT_EQ(field4.name, "double_field"); EXPECT_EQ(field4.number, 3); - EXPECT_EQ(field4.type, type_manager.GetDoubleType()); + EXPECT_EQ(field4.type, type_manager.type_factory().GetDoubleType()); ASSERT_OK_AND_ASSIGN( field4, struct_type->FindField(type_manager, StructType::FieldId(3))); EXPECT_EQ(field4.name, "double_field"); EXPECT_EQ(field4.number, 3); - EXPECT_EQ(field4.type, type_manager.GetDoubleType()); + EXPECT_EQ(field4.type, type_manager.type_factory().GetDoubleType()); EXPECT_THAT(struct_type->FindField(type_manager, StructType::FieldId("missing_field")), @@ -842,9 +850,11 @@ TEST_P(DebugStringTest, EnumType) { } TEST_P(DebugStringTest, StructType) { - TypeManager type_manager(memory_manager()); - ASSERT_OK_AND_ASSIGN(auto struct_type, - type_manager.CreateStructType()); + TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ASSERT_OK_AND_ASSIGN( + auto struct_type, + type_manager.type_factory().CreateStructType()); EXPECT_EQ(struct_type->DebugString(), "test_struct.TestStruct"); } diff --git a/base/value_factory.h b/base/value_factory.h index 281c6e6eb..ad13b750b 100644 --- a/base/value_factory.h +++ b/base/value_factory.h @@ -29,6 +29,7 @@ #include "absl/time/time.h" #include "base/handle.h" #include "base/memory_manager.h" +#include "base/type_manager.h" #include "base/value.h" namespace cel { @@ -47,13 +48,18 @@ class ValueFactory final { std::enable_if_t>, V>; public: - explicit ValueFactory( - MemoryManager& memory_manager ABSL_ATTRIBUTE_LIFETIME_BOUND) - : memory_manager_(memory_manager) {} + explicit ValueFactory(TypeManager& type_manager ABSL_ATTRIBUTE_LIFETIME_BOUND) + : type_manager_(type_manager) {} ValueFactory(const ValueFactory&) = delete; ValueFactory& operator=(const ValueFactory&) = delete; + TypeFactory& type_factory() const { return type_manager().type_factory(); } + + TypeProvider& type_provider() const { return type_manager().type_provider(); } + + TypeManager& type_manager() const { return type_manager_; } + Persistent GetNullValue() ABSL_ATTRIBUTE_LIFETIME_BOUND; Persistent CreateErrorValue(absl::Status status) @@ -184,6 +190,10 @@ class ValueFactory final { Persistent CreateTypeValue( const Persistent& value) ABSL_ATTRIBUTE_LIFETIME_BOUND; + MemoryManager& memory_manager() const { + return type_manager().memory_manager(); + } + private: friend class BytesValue; friend class StringValue; @@ -194,8 +204,6 @@ class ValueFactory final { interop_internal::CreateBytesValueFromView(cel::ValueFactory& value_factory, absl::string_view input); - MemoryManager& memory_manager() const { return memory_manager_; } - Persistent GetEmptyBytesValue() ABSL_ATTRIBUTE_LIFETIME_BOUND; @@ -217,7 +225,7 @@ class ValueFactory final { absl::StatusOr> CreateStringValueFromView( absl::string_view value) ABSL_ATTRIBUTE_LIFETIME_BOUND; - MemoryManager& memory_manager_; + TypeManager& type_manager_; }; // TypedEnumValueFactory creates EnumValue scoped to a specific EnumType. Used diff --git a/base/value_factory_test.cc b/base/value_factory_test.cc index 171f0f360..36d7ac285 100644 --- a/base/value_factory_test.cc +++ b/base/value_factory_test.cc @@ -24,13 +24,17 @@ namespace { using cel::internal::StatusIs; TEST(ValueFactory, CreateErrorValueReplacesOk) { - ValueFactory value_factory(MemoryManager::Global()); + TypeFactory type_factory(MemoryManager::Global()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); EXPECT_THAT(value_factory.CreateErrorValue(absl::OkStatus())->value(), StatusIs(absl::StatusCode::kUnknown)); } TEST(ValueFactory, CreateStringValueIllegalByteSequence) { - ValueFactory value_factory(MemoryManager::Global()); + TypeFactory type_factory(MemoryManager::Global()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); EXPECT_THAT(value_factory.CreateStringValue("\xff"), StatusIs(absl::StatusCode::kInvalidArgument)); EXPECT_THAT(value_factory.CreateStringValue(absl::Cord("\xff")), diff --git a/base/value_test.cc b/base/value_test.cc index 0f5c0b63d..9f67cf262 100644 --- a/base/value_test.cc +++ b/base/value_test.cc @@ -318,13 +318,14 @@ class TestStructType final : public StructType { absl::StatusOr FindFieldByName(TypeManager& type_manager, absl::string_view name) const override { if (name == "bool_field") { - return Field("bool_field", 0, type_manager.GetBoolType()); + return Field("bool_field", 0, type_manager.type_factory().GetBoolType()); } else if (name == "int_field") { - return Field("int_field", 1, type_manager.GetIntType()); + return Field("int_field", 1, type_manager.type_factory().GetIntType()); } else if (name == "uint_field") { - return Field("uint_field", 2, type_manager.GetUintType()); + return Field("uint_field", 2, type_manager.type_factory().GetUintType()); } else if (name == "double_field") { - return Field("double_field", 3, type_manager.GetDoubleType()); + return Field("double_field", 3, + type_manager.type_factory().GetDoubleType()); } return absl::NotFoundError(""); } @@ -333,13 +334,16 @@ class TestStructType final : public StructType { int64_t number) const override { switch (number) { case 0: - return Field("bool_field", 0, type_manager.GetBoolType()); + return Field("bool_field", 0, + type_manager.type_factory().GetBoolType()); case 1: - return Field("int_field", 1, type_manager.GetIntType()); + return Field("int_field", 1, type_manager.type_factory().GetIntType()); case 2: - return Field("uint_field", 2, type_manager.GetUintType()); + return Field("uint_field", 2, + type_manager.type_factory().GetUintType()); case 3: - return Field("double_field", 3, type_manager.GetDoubleType()); + return Field("double_field", 3, + type_manager.type_factory().GetDoubleType()); default: return absl::NotFoundError(""); } @@ -545,7 +549,9 @@ TEST(Value, PersistentHandleTypeTraits) { } TEST_P(ValueTest, DefaultConstructor) { - ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); Transient value; EXPECT_EQ(value, value_factory.GetNullValue()); } @@ -561,7 +567,8 @@ using ConstructionAssignmentTest = TEST_P(ConstructionAssignmentTest, CopyConstructor) { TypeFactory type_factory(memory_manager()); - ValueFactory value_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); Persistent from( test_case().default_value(type_factory, value_factory)); Persistent to(from); @@ -571,7 +578,8 @@ TEST_P(ConstructionAssignmentTest, CopyConstructor) { TEST_P(ConstructionAssignmentTest, MoveConstructor) { TypeFactory type_factory(memory_manager()); - ValueFactory value_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); Persistent from( test_case().default_value(type_factory, value_factory)); Persistent to(std::move(from)); @@ -582,7 +590,8 @@ TEST_P(ConstructionAssignmentTest, MoveConstructor) { TEST_P(ConstructionAssignmentTest, CopyAssignment) { TypeFactory type_factory(memory_manager()); - ValueFactory value_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); Persistent from( test_case().default_value(type_factory, value_factory)); Persistent to; @@ -592,7 +601,8 @@ TEST_P(ConstructionAssignmentTest, CopyAssignment) { TEST_P(ConstructionAssignmentTest, MoveAssignment) { TypeFactory type_factory(memory_manager()); - ValueFactory value_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); Persistent from( test_case().default_value(type_factory, value_factory)); Persistent to; @@ -703,7 +713,9 @@ INSTANTIATE_TEST_SUITE_P( }); TEST_P(ValueTest, Swap) { - ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); Persistent lhs = value_factory.CreateIntValue(0); Persistent rhs = value_factory.CreateUintValue(0); std::swap(lhs, rhs); @@ -714,18 +726,24 @@ TEST_P(ValueTest, Swap) { using DebugStringTest = ValueTest; TEST_P(DebugStringTest, NullValue) { - ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); EXPECT_EQ(value_factory.GetNullValue()->DebugString(), "null"); } TEST_P(DebugStringTest, BoolValue) { - ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); EXPECT_EQ(value_factory.CreateBoolValue(false)->DebugString(), "false"); EXPECT_EQ(value_factory.CreateBoolValue(true)->DebugString(), "true"); } TEST_P(DebugStringTest, IntValue) { - ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); EXPECT_EQ(value_factory.CreateIntValue(-1)->DebugString(), "-1"); EXPECT_EQ(value_factory.CreateIntValue(0)->DebugString(), "0"); EXPECT_EQ(value_factory.CreateIntValue(1)->DebugString(), "1"); @@ -738,7 +756,9 @@ TEST_P(DebugStringTest, IntValue) { } TEST_P(DebugStringTest, UintValue) { - ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); EXPECT_EQ(value_factory.CreateUintValue(0)->DebugString(), "0u"); EXPECT_EQ(value_factory.CreateUintValue(1)->DebugString(), "1u"); EXPECT_EQ(value_factory.CreateUintValue(std::numeric_limits::max()) @@ -747,7 +767,9 @@ TEST_P(DebugStringTest, UintValue) { } TEST_P(DebugStringTest, DoubleValue) { - ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); EXPECT_EQ(value_factory.CreateDoubleValue(-1.0)->DebugString(), "-1.0"); EXPECT_EQ(value_factory.CreateDoubleValue(0.0)->DebugString(), "0.0"); EXPECT_EQ(value_factory.CreateDoubleValue(1.0)->DebugString(), "1.0"); @@ -780,13 +802,17 @@ TEST_P(DebugStringTest, DoubleValue) { } TEST_P(DebugStringTest, DurationValue) { - ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); EXPECT_EQ(DurationValue::Zero(value_factory)->DebugString(), internal::FormatDuration(absl::ZeroDuration()).value()); } TEST_P(DebugStringTest, TimestampValue) { - ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); EXPECT_EQ(TimestampValue::UnixEpoch(value_factory)->DebugString(), internal::FormatTimestamp(absl::UnixEpoch()).value()); } @@ -800,8 +826,9 @@ INSTANTIATE_TEST_SUITE_P(DebugStringTest, DebugStringTest, // feature is not available in C++17. TEST_P(ValueTest, Error) { - ValueFactory value_factory(memory_manager()); TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); auto error_value = value_factory.CreateErrorValue(absl::CancelledError()); EXPECT_TRUE(error_value.Is()); EXPECT_FALSE(error_value.Is()); @@ -812,8 +839,9 @@ TEST_P(ValueTest, Error) { } TEST_P(ValueTest, Bool) { - ValueFactory value_factory(memory_manager()); TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); auto false_value = BoolValue::False(value_factory); EXPECT_TRUE(false_value.Is()); EXPECT_FALSE(false_value.Is()); @@ -837,8 +865,9 @@ TEST_P(ValueTest, Bool) { } TEST_P(ValueTest, Int) { - ValueFactory value_factory(memory_manager()); TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); auto zero_value = value_factory.CreateIntValue(0); EXPECT_TRUE(zero_value.Is()); EXPECT_FALSE(zero_value.Is()); @@ -862,8 +891,9 @@ TEST_P(ValueTest, Int) { } TEST_P(ValueTest, Uint) { - ValueFactory value_factory(memory_manager()); TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); auto zero_value = value_factory.CreateUintValue(0); EXPECT_TRUE(zero_value.Is()); EXPECT_FALSE(zero_value.Is()); @@ -887,8 +917,9 @@ TEST_P(ValueTest, Uint) { } TEST_P(ValueTest, Double) { - ValueFactory value_factory(memory_manager()); TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); auto zero_value = value_factory.CreateDoubleValue(0.0); EXPECT_TRUE(zero_value.Is()); EXPECT_FALSE(zero_value.Is()); @@ -912,8 +943,9 @@ TEST_P(ValueTest, Double) { } TEST_P(ValueTest, Duration) { - ValueFactory value_factory(memory_manager()); TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); auto zero_value = Must(value_factory.CreateDurationValue(absl::ZeroDuration())); EXPECT_TRUE(zero_value.Is()); @@ -942,8 +974,9 @@ TEST_P(ValueTest, Duration) { } TEST_P(ValueTest, Timestamp) { - ValueFactory value_factory(memory_manager()); TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); auto zero_value = Must(value_factory.CreateTimestampValue(absl::UnixEpoch())); EXPECT_TRUE(zero_value.Is()); EXPECT_FALSE(zero_value.Is()); @@ -971,8 +1004,9 @@ TEST_P(ValueTest, Timestamp) { } TEST_P(ValueTest, BytesFromString) { - ValueFactory value_factory(memory_manager()); TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); auto zero_value = Must(value_factory.CreateBytesValue(std::string("0"))); EXPECT_TRUE(zero_value.Is()); EXPECT_FALSE(zero_value.Is()); @@ -996,8 +1030,9 @@ TEST_P(ValueTest, BytesFromString) { } TEST_P(ValueTest, BytesFromStringView) { - ValueFactory value_factory(memory_manager()); TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); auto zero_value = Must(value_factory.CreateBytesValue(absl::string_view("0"))); EXPECT_TRUE(zero_value.Is()); @@ -1024,8 +1059,9 @@ TEST_P(ValueTest, BytesFromStringView) { } TEST_P(ValueTest, BytesFromCord) { - ValueFactory value_factory(memory_manager()); TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); auto zero_value = Must(value_factory.CreateBytesValue(absl::Cord("0"))); EXPECT_TRUE(zero_value.Is()); EXPECT_FALSE(zero_value.Is()); @@ -1049,8 +1085,9 @@ TEST_P(ValueTest, BytesFromCord) { } TEST_P(ValueTest, BytesFromLiteral) { - ValueFactory value_factory(memory_manager()); TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); auto zero_value = Must(value_factory.CreateBytesValue("0")); EXPECT_TRUE(zero_value.Is()); EXPECT_FALSE(zero_value.Is()); @@ -1074,8 +1111,9 @@ TEST_P(ValueTest, BytesFromLiteral) { } TEST_P(ValueTest, BytesFromExternal) { - ValueFactory value_factory(memory_manager()); TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); auto zero_value = Must(value_factory.CreateBytesValue("0", []() {})); EXPECT_TRUE(zero_value.Is()); EXPECT_FALSE(zero_value.Is()); @@ -1099,8 +1137,9 @@ TEST_P(ValueTest, BytesFromExternal) { } TEST_P(ValueTest, StringFromString) { - ValueFactory value_factory(memory_manager()); TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); auto zero_value = Must(value_factory.CreateStringValue(std::string("0"))); EXPECT_TRUE(zero_value.Is()); EXPECT_FALSE(zero_value.Is()); @@ -1125,8 +1164,9 @@ TEST_P(ValueTest, StringFromString) { } TEST_P(ValueTest, StringFromStringView) { - ValueFactory value_factory(memory_manager()); TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); auto zero_value = Must(value_factory.CreateStringValue(absl::string_view("0"))); EXPECT_TRUE(zero_value.Is()); @@ -1154,8 +1194,9 @@ TEST_P(ValueTest, StringFromStringView) { } TEST_P(ValueTest, StringFromCord) { - ValueFactory value_factory(memory_manager()); TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); auto zero_value = Must(value_factory.CreateStringValue(absl::Cord("0"))); EXPECT_TRUE(zero_value.Is()); EXPECT_FALSE(zero_value.Is()); @@ -1179,8 +1220,9 @@ TEST_P(ValueTest, StringFromCord) { } TEST_P(ValueTest, StringFromLiteral) { - ValueFactory value_factory(memory_manager()); TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); auto zero_value = Must(value_factory.CreateStringValue("0")); EXPECT_TRUE(zero_value.Is()); EXPECT_FALSE(zero_value.Is()); @@ -1204,8 +1246,9 @@ TEST_P(ValueTest, StringFromLiteral) { } TEST_P(ValueTest, StringFromExternal) { - ValueFactory value_factory(memory_manager()); TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); auto zero_value = Must(value_factory.CreateStringValue("0", []() {})); EXPECT_TRUE(zero_value.Is()); EXPECT_FALSE(zero_value.Is()); @@ -1229,8 +1272,9 @@ TEST_P(ValueTest, StringFromExternal) { } TEST_P(ValueTest, Type) { - ValueFactory value_factory(memory_manager()); TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); auto null_value = value_factory.CreateTypeValue(type_factory.GetNullType()); EXPECT_TRUE(null_value.Is()); EXPECT_FALSE(null_value.Is()); @@ -1278,7 +1322,9 @@ struct BytesConcatTestCase final { using BytesConcatTest = BaseValueTest; TEST_P(BytesConcatTest, Concat) { - ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); EXPECT_TRUE( Must(BytesValue::Concat(value_factory, MakeStringBytes(value_factory, test_case().lhs), @@ -1350,7 +1396,9 @@ struct BytesSizeTestCase final { using BytesSizeTest = BaseValueTest; TEST_P(BytesSizeTest, Size) { - ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); EXPECT_EQ(MakeStringBytes(value_factory, test_case().data)->size(), test_case().size); EXPECT_EQ(MakeCordBytes(value_factory, test_case().data)->size(), @@ -1377,7 +1425,9 @@ struct BytesEmptyTestCase final { using BytesEmptyTest = BaseValueTest; TEST_P(BytesEmptyTest, Empty) { - ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); EXPECT_EQ(MakeStringBytes(value_factory, test_case().data)->empty(), test_case().empty); EXPECT_EQ(MakeCordBytes(value_factory, test_case().data)->empty(), @@ -1404,7 +1454,9 @@ struct BytesEqualsTestCase final { using BytesEqualsTest = BaseValueTest; TEST_P(BytesEqualsTest, Equals) { - ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); EXPECT_EQ(MakeStringBytes(value_factory, test_case().lhs) ->Equals(MakeStringBytes(value_factory, test_case().rhs)), test_case().equals); @@ -1461,7 +1513,9 @@ using BytesCompareTest = BaseValueTest; int NormalizeCompareResult(int compare) { return std::clamp(compare, -1, 1); } TEST_P(BytesCompareTest, Equals) { - ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); EXPECT_EQ(NormalizeCompareResult( MakeStringBytes(value_factory, test_case().lhs) ->Compare(MakeStringBytes(value_factory, test_case().rhs))), @@ -1525,7 +1579,9 @@ struct BytesDebugStringTestCase final { using BytesDebugStringTest = BaseValueTest; TEST_P(BytesDebugStringTest, ToCord) { - ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); EXPECT_EQ(MakeStringBytes(value_factory, test_case().data)->DebugString(), internal::FormatBytesLiteral(test_case().data)); EXPECT_EQ(MakeCordBytes(value_factory, test_case().data)->DebugString(), @@ -1551,7 +1607,9 @@ struct BytesToStringTestCase final { using BytesToStringTest = BaseValueTest; TEST_P(BytesToStringTest, ToString) { - ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); EXPECT_EQ(MakeStringBytes(value_factory, test_case().data)->ToString(), test_case().data); EXPECT_EQ(MakeCordBytes(value_factory, test_case().data)->ToString(), @@ -1577,7 +1635,9 @@ struct BytesToCordTestCase final { using BytesToCordTest = BaseValueTest; TEST_P(BytesToCordTest, ToCord) { - ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); EXPECT_EQ(MakeStringBytes(value_factory, test_case().data)->ToCord(), test_case().data); EXPECT_EQ(MakeCordBytes(value_factory, test_case().data)->ToCord(), @@ -1619,7 +1679,9 @@ struct StringConcatTestCase final { using StringConcatTest = BaseValueTest; TEST_P(StringConcatTest, Concat) { - ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); EXPECT_TRUE( Must(StringValue::Concat( value_factory, MakeStringString(value_factory, test_case().lhs), @@ -1691,7 +1753,9 @@ struct StringSizeTestCase final { using StringSizeTest = BaseValueTest; TEST_P(StringSizeTest, Size) { - ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); EXPECT_EQ(MakeStringString(value_factory, test_case().data)->size(), test_case().size); EXPECT_EQ(MakeCordString(value_factory, test_case().data)->size(), @@ -1718,7 +1782,9 @@ struct StringEmptyTestCase final { using StringEmptyTest = BaseValueTest; TEST_P(StringEmptyTest, Empty) { - ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); EXPECT_EQ(MakeStringString(value_factory, test_case().data)->empty(), test_case().empty); EXPECT_EQ(MakeCordString(value_factory, test_case().data)->empty(), @@ -1745,7 +1811,9 @@ struct StringEqualsTestCase final { using StringEqualsTest = BaseValueTest; TEST_P(StringEqualsTest, Equals) { - ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); EXPECT_EQ(MakeStringString(value_factory, test_case().lhs) ->Equals(MakeStringString(value_factory, test_case().rhs)), test_case().equals); @@ -1800,7 +1868,9 @@ struct StringCompareTestCase final { using StringCompareTest = BaseValueTest; TEST_P(StringCompareTest, Equals) { - ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); EXPECT_EQ( NormalizeCompareResult( MakeStringString(value_factory, test_case().lhs) @@ -1868,7 +1938,9 @@ struct StringDebugStringTestCase final { using StringDebugStringTest = BaseValueTest; TEST_P(StringDebugStringTest, ToCord) { - ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); EXPECT_EQ(MakeStringString(value_factory, test_case().data)->DebugString(), internal::FormatStringLiteral(test_case().data)); EXPECT_EQ(MakeCordString(value_factory, test_case().data)->DebugString(), @@ -1894,7 +1966,9 @@ struct StringToStringTestCase final { using StringToStringTest = BaseValueTest; TEST_P(StringToStringTest, ToString) { - ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); EXPECT_EQ(MakeStringString(value_factory, test_case().data)->ToString(), test_case().data); EXPECT_EQ(MakeCordString(value_factory, test_case().data)->ToString(), @@ -1920,7 +1994,9 @@ struct StringToCordTestCase final { using StringToCordTest = BaseValueTest; TEST_P(StringToCordTest, ToCord) { - ValueFactory value_factory(memory_manager()); + TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); EXPECT_EQ(MakeStringString(value_factory, test_case().data)->ToCord(), test_case().data); EXPECT_EQ(MakeCordString(value_factory, test_case().data)->ToCord(), @@ -1940,8 +2016,9 @@ INSTANTIATE_TEST_SUITE_P( }))); TEST_P(ValueTest, Enum) { - ValueFactory value_factory(memory_manager()); TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); ASSERT_OK_AND_ASSIGN(auto enum_type, type_factory.CreateEnumType()); ASSERT_OK_AND_ASSIGN( @@ -1977,8 +2054,9 @@ TEST_P(ValueTest, Enum) { using EnumTypeTest = ValueTest; TEST_P(EnumTypeTest, NewInstance) { - ValueFactory value_factory(memory_manager()); TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); ASSERT_OK_AND_ASSIGN(auto enum_type, type_factory.CreateEnumType()); ASSERT_OK_AND_ASSIGN( @@ -2008,8 +2086,9 @@ INSTANTIATE_TEST_SUITE_P(EnumTypeTest, EnumTypeTest, base_internal::MemoryManagerTestModeTupleName); TEST_P(ValueTest, Struct) { - ValueFactory value_factory(memory_manager()); TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); ASSERT_OK_AND_ASSIGN(auto struct_type, type_factory.CreateStructType()); ASSERT_OK_AND_ASSIGN(auto zero_value, @@ -2049,8 +2128,9 @@ TEST_P(ValueTest, Struct) { using StructValueTest = ValueTest; TEST_P(StructValueTest, SetField) { - ValueFactory value_factory(memory_manager()); TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); ASSERT_OK_AND_ASSIGN(auto struct_type, type_factory.CreateStructType()); ASSERT_OK_AND_ASSIGN(auto struct_value, @@ -2126,8 +2206,9 @@ TEST_P(StructValueTest, SetField) { } TEST_P(StructValueTest, GetField) { - ValueFactory value_factory(memory_manager()); TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); ASSERT_OK_AND_ASSIGN(auto struct_type, type_factory.CreateStructType()); ASSERT_OK_AND_ASSIGN(auto struct_value, @@ -2160,8 +2241,9 @@ TEST_P(StructValueTest, GetField) { } TEST_P(StructValueTest, HasField) { - ValueFactory value_factory(memory_manager()); TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); ASSERT_OK_AND_ASSIGN(auto struct_type, type_factory.CreateStructType()); ASSERT_OK_AND_ASSIGN(auto struct_value, @@ -2193,8 +2275,9 @@ INSTANTIATE_TEST_SUITE_P(StructValueTest, StructValueTest, base_internal::MemoryManagerTestModeTupleName); TEST_P(ValueTest, List) { - ValueFactory value_factory(memory_manager()); TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); ASSERT_OK_AND_ASSIGN(auto list_type, type_factory.CreateListType(type_factory.GetIntType())); ASSERT_OK_AND_ASSIGN(auto zero_value, @@ -2228,8 +2311,9 @@ TEST_P(ValueTest, List) { using ListValueTest = ValueTest; TEST_P(ListValueTest, DebugString) { - ValueFactory value_factory(memory_manager()); TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); ASSERT_OK_AND_ASSIGN(auto list_type, type_factory.CreateListType(type_factory.GetIntType())); ASSERT_OK_AND_ASSIGN(auto list_value, @@ -2243,8 +2327,9 @@ TEST_P(ListValueTest, DebugString) { } TEST_P(ListValueTest, Get) { - ValueFactory value_factory(memory_manager()); TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); ASSERT_OK_AND_ASSIGN(auto list_type, type_factory.CreateListType(type_factory.GetIntType())); ASSERT_OK_AND_ASSIGN(auto list_value, @@ -2273,8 +2358,9 @@ INSTANTIATE_TEST_SUITE_P(ListValueTest, ListValueTest, base_internal::MemoryManagerTestModeTupleName); TEST_P(ValueTest, Map) { - ValueFactory value_factory(memory_manager()); TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); ASSERT_OK_AND_ASSIGN(auto map_type, type_factory.CreateMapType(type_factory.GetStringType(), type_factory.GetIntType())); @@ -2312,8 +2398,9 @@ TEST_P(ValueTest, Map) { using MapValueTest = ValueTest; TEST_P(MapValueTest, DebugString) { - ValueFactory value_factory(memory_manager()); TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); ASSERT_OK_AND_ASSIGN(auto map_type, type_factory.CreateMapType(type_factory.GetStringType(), type_factory.GetIntType())); @@ -2329,8 +2416,9 @@ TEST_P(MapValueTest, DebugString) { } TEST_P(MapValueTest, GetAndHas) { - ValueFactory value_factory(memory_manager()); TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); ASSERT_OK_AND_ASSIGN(auto map_type, type_factory.CreateMapType(type_factory.GetStringType(), type_factory.GetIntType())); @@ -2375,8 +2463,9 @@ INSTANTIATE_TEST_SUITE_P(MapValueTest, MapValueTest, base_internal::MemoryManagerTestModeTupleName); TEST_P(ValueTest, SupportsAbslHash) { - ValueFactory value_factory(memory_manager()); TypeFactory type_factory(memory_manager()); + TypeManager type_manager(type_factory, TypeProvider::Builtin()); + ValueFactory value_factory(type_manager); ASSERT_OK_AND_ASSIGN(auto enum_type, type_factory.CreateEnumType()); ASSERT_OK_AND_ASSIGN(auto struct_type, From 23cd804130d42ab31677782286800e417412ebc9 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Tue, 31 May 2022 15:00:38 +0000 Subject: [PATCH 107/111] Add conversion utilities going from the proto type representations to native type representations of the AST. PiperOrigin-RevId: 452047564 --- base/BUILD | 35 ++ base/ast.h | 27 +- base/ast_utility.cc | 506 +++++++++++++++++++++++ base/ast_utility.h | 44 ++ base/ast_utility_test.cc | 848 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 1443 insertions(+), 17 deletions(-) create mode 100644 base/ast_utility.cc create mode 100644 base/ast_utility.h create mode 100644 base/ast_utility_test.cc diff --git a/base/BUILD b/base/BUILD index cfab9e597..7a547dd68 100644 --- a/base/BUILD +++ b/base/BUILD @@ -240,3 +240,38 @@ cc_test( "@com_google_absl//absl/types:variant", ], ) + +cc_library( + name = "ast_utility", + srcs = ["ast_utility.cc"], + hdrs = ["ast_utility.h"], + deps = [ + ":ast", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/time", + "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", + "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "ast_utility_test", + srcs = [ + "ast_utility_test.cc", + ], + deps = [ + ":ast", + ":ast_utility", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:variant", + "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", + "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/base/ast.h b/base/ast.h index 5cba4267d..a4fcc34ac 100644 --- a/base/ast.h +++ b/base/ast.h @@ -200,30 +200,22 @@ class CreateStruct { // Represents an entry. class Entry { public: + using KeyKind = absl::variant>; Entry() {} - Entry(int64_t id, - absl::variant> key_kind, - std::unique_ptr value) + Entry(int64_t id, KeyKind key_kind, std::unique_ptr value) : id_(id), key_kind_(std::move(key_kind)), value_(std::move(value)) {} void set_id(int64_t id) { id_ = id; } - void set_key_kind( - absl::variant> key_kind) { - key_kind_ = std::move(key_kind); - } + void set_key_kind(KeyKind key_kind) { key_kind_ = std::move(key_kind); } void set_value(std::unique_ptr value) { value_ = std::move(value); } int64_t id() const { return id_; } - const absl::variant>& key_kind() const { - return key_kind_; - } + const KeyKind& key_kind() const { return key_kind_; } - absl::variant>& mutable_key_kind() { - return key_kind_; - } + KeyKind& mutable_key_kind() { return key_kind_; } const Expr* value() const { return value_.get(); } @@ -240,7 +232,7 @@ class CreateStruct { // information and other attributes to the node. int64_t id_; // The `Entry` key kinds. - absl::variant> key_kind_; + KeyKind key_kind_; // Required. The value assigned to the key. std::unique_ptr value_; }; @@ -759,6 +751,7 @@ class FunctionType { // // TODO(issues/5): decide on final naming for this. class AbstractType { + public: AbstractType(std::string name, std::vector parameter_types) : name_(std::move(name)), parameter_types_(std::move(parameter_types)) {} @@ -791,7 +784,7 @@ class PrimitiveTypeWrapper { const PrimitiveType& type() const { return type_; } - PrimitiveType& type() { return type_; } + PrimitiveType& mutable_type() { return type_; } private: PrimitiveType type_; @@ -807,7 +800,7 @@ class MessageType { void set_type(std::string type) { type_ = std::move(type); } - const std::string& type() { return type_; } + const std::string& type() const { return type_; } private: std::string type_; @@ -824,7 +817,7 @@ class ParamType { void set_type(std::string type) { type_ = std::move(type); } - const std::string& type() { return type_; } + const std::string& type() const { return type_; } private: std::string type_; diff --git a/base/ast_utility.cc b/base/ast_utility.cc new file mode 100644 index 000000000..812470d8b --- /dev/null +++ b/base/ast_utility.cc @@ -0,0 +1,506 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/ast_utility.h" + +#include +#include +#include +#include +#include + +#include "google/api/expr/v1alpha1/checked.pb.h" +#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "absl/container/flat_hash_map.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/time/time.h" +#include "base/ast.h" + +namespace cel::ast::internal { + +absl::StatusOr ToNative(const google::api::expr::v1alpha1::Constant& constant) { + switch (constant.constant_kind_case()) { + case google::api::expr::v1alpha1::Constant::kNullValue: + return NullValue::kNullValue; + case google::api::expr::v1alpha1::Constant::kBoolValue: + return constant.bool_value(); + case google::api::expr::v1alpha1::Constant::kInt64Value: + return constant.int64_value(); + case google::api::expr::v1alpha1::Constant::kUint64Value: + return constant.uint64_value(); + case google::api::expr::v1alpha1::Constant::kDoubleValue: + return constant.double_value(); + case google::api::expr::v1alpha1::Constant::kStringValue: + return constant.string_value(); + case google::api::expr::v1alpha1::Constant::kBytesValue: + return constant.bytes_value(); + case google::api::expr::v1alpha1::Constant::kDurationValue: + return absl::Seconds(constant.duration_value().seconds()) + + absl::Nanoseconds(constant.duration_value().nanos()); + case google::api::expr::v1alpha1::Constant::kTimestampValue: + return absl::FromUnixSeconds(constant.timestamp_value().seconds()) + + absl::Nanoseconds(constant.timestamp_value().nanos()); + default: + return absl::InvalidArgumentError( + "Illegal type supplied for google::api::expr::v1alpha1::Constant."); + } +} + +Ident ToNative(const google::api::expr::v1alpha1::Expr::Ident& ident) { + return Ident(ident.name()); +} + +absl::StatusOr(native_expr->expr_kind())); + auto& native_select = absl::get