Skip to content

Commit

Permalink
Emplace into optional to be able to visit variant containers containi… (
Browse files Browse the repository at this point in the history
  • Loading branch information
Urfoex authored Nov 8, 2024
1 parent 88cf229 commit 716de7b
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 149 deletions.
6 changes: 6 additions & 0 deletions include/rfl/TaggedUnion.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,12 @@ struct TaggedUnion {
VariantType variant_;
};

template <typename T>
concept TaggedUnionBased = requires(T t) {
[]<internal::StringLiteral _discriminator, typename... Args>(
TaggedUnion<_discriminator, Args...> const&) {}(t);
};

template <class T>
struct PossibleTags;

Expand Down
71 changes: 17 additions & 54 deletions include/rfl/Variant.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#define RFL_VARIANT_HPP_

#include <array>
#include <bit>
#include <cstdint>
#include <limits>
#include <optional>
Expand All @@ -14,7 +13,6 @@
#include "internal/nth_element_t.hpp"
#include "internal/variant/find_max_size.hpp"
#include "internal/variant/is_alternative_type.hpp"
#include "internal/variant/is_convertible_to.hpp"
#include "internal/variant/result_t.hpp"

namespace rfl {
Expand Down Expand Up @@ -152,80 +150,40 @@ class Variant {
}

template <class F>
result_t<F> visit(F& _f) {
result_t<F> visit(F&& _f) {
using ResultType = result_t<F>;
if constexpr (std::is_same_v<ResultType, void>) {
bool visited = false;
do_visit_no_result(_f, &visited,
do_visit_no_result(std::forward<F>(_f), &visited,
std::make_integer_sequence<IndexType, size_>());
} else if constexpr (std::is_reference_v<ResultType>) {
std::remove_reference_t<ResultType>* res = nullptr;
do_visit_with_reference(_f, &res,
do_visit_with_reference(std::forward<F>(_f), &res,
std::make_integer_sequence<IndexType, size_>());
return *res;
} else {
auto res = std::optional<ResultType>();
do_visit_with_result(_f, &res,
do_visit_with_result(std::forward<F>(_f), &res,
std::make_integer_sequence<IndexType, size_>());
return std::move(*res);
}
}

template <class F>
result_t<F> visit(F& _f) const {
result_t<F> visit(F&& _f) const {
using ResultType = result_t<F>;
if constexpr (std::is_same_v<ResultType, void>) {
bool visited = false;
do_visit_no_result(_f, &visited,
do_visit_no_result(std::forward<F>(_f), &visited,
std::make_integer_sequence<IndexType, size_>());
} else if constexpr (std::is_reference_v<ResultType>) {
std::remove_reference_t<ResultType>* res = nullptr;
do_visit_with_reference(_f, &res,
do_visit_with_reference(std::forward<F>(_f), &res,
std::make_integer_sequence<IndexType, size_>());
return *res;
} else {
auto res = std::optional<ResultType>();
do_visit_with_result(_f, &res,
std::make_integer_sequence<IndexType, size_>());
return std::move(*res);
}
}

template <class F>
result_t<F> visit(const F& _f) {
using ResultType = std::remove_reference_t<result_t<F>>;
if constexpr (std::is_same_v<ResultType, void>) {
bool visited = false;
do_visit_no_result(_f, &visited,
std::make_integer_sequence<IndexType, size_>());
} else if constexpr (std::is_reference_v<ResultType>) {
std::remove_reference_t<ResultType>* res = nullptr;
do_visit_with_reference(_f, &res,
std::make_integer_sequence<IndexType, size_>());
return *res;
} else {
auto res = std::optional<ResultType>();
do_visit_with_result(_f, &res,
std::make_integer_sequence<IndexType, size_>());
return std::move(*res);
}
}

template <class F>
result_t<F> visit(const F& _f) const {
using ResultType = result_t<F>;
if constexpr (std::is_same_v<ResultType, void>) {
bool visited = false;
do_visit_no_result(_f, &visited,
std::make_integer_sequence<IndexType, size_>());
} else if constexpr (std::is_reference_v<ResultType>) {
std::remove_reference_t<ResultType>* res = nullptr;
do_visit_with_reference(_f, &res,
std::make_integer_sequence<IndexType, size_>());
return *res;
} else {
auto res = std::optional<ResultType>();
do_visit_with_result(_f, &res,
do_visit_with_result(std::forward<F>(_f), &res,
std::make_integer_sequence<IndexType, size_>());
return std::move(*res);
}
Expand Down Expand Up @@ -330,7 +288,7 @@ class Variant {
std::optional<ResultType>* _res,
Index<_i>) {
if (!*_res && index_ == _i) {
*_res = _f(get_alternative<_i>());
_res->emplace(_f(get_alternative<_i>()));
}
};
(visit_one(_f, _res, Index<_is>{}), ...);
Expand All @@ -343,7 +301,7 @@ class Variant {
std::optional<ResultType>* _res,
Index<_i>) {
if (!*_res && index_ == _i) {
*_res = _f(get_alternative<_i>());
_res->emplace(_f(get_alternative<_i>()));
}
};
(visit_one(_f, _res, Index<_is>{}), ...);
Expand All @@ -356,7 +314,7 @@ class Variant {
std::optional<ResultType>* _res,
Index<_i>) {
if (!*_res && index_ == _i) {
*_res = _f(get_alternative<_i>());
_res->emplace(_f(get_alternative<_i>()));
}
};
(visit_one(_f, _res, Index<_is>{}), ...);
Expand All @@ -369,7 +327,7 @@ class Variant {
std::optional<ResultType>* _res,
Index<_i>) {
if (!*_res && index_ == _i) {
*_res = _f(get_alternative<_i>());
_res->emplace(_f(get_alternative<_i>()));
}
};
(visit_one(_f, _res, Index<_is>{}), ...);
Expand Down Expand Up @@ -460,6 +418,11 @@ class Variant {
alignas(AlternativeTypes...) DataType data_;
};

template <typename V>
concept VariantBased = requires(std::decay_t<V> v) {
[]<typename... Args>(Variant<Args...> const&) {}(v);
};

template <class T, class... Types>
constexpr T* get_if(Variant<Types...>* _v) noexcept {
const auto get = [](auto& _v) -> T* {
Expand Down
3 changes: 1 addition & 2 deletions include/rfl/parsing/call_destructors_where_necessary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ void call_destructors_where_necessary(const std::array<bool, _size>& _set,
[&]<int... is>(std::integer_sequence<int, is...>) {
(call_destructor_on_one_if_necessary<ViewType, _size, is>(_set, _view),
...);
}
(std::make_integer_sequence<int, _size>());
}(std::make_integer_sequence<int, _size>());
}

} // namespace rfl::parsing
Expand Down
108 changes: 16 additions & 92 deletions include/rfl/visit.hpp
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
#ifndef RFL_VISIT_HPP_
#define RFL_VISIT_HPP_

#include <variant>
#include <utility>

#include "Literal.hpp"
#include "TaggedUnion.hpp"
#include "internal/StringLiteral.hpp"
#include "internal/VisitTree.hpp"
#include "internal/VisitorWrapper.hpp"
#include "internal/variant/result_t.hpp"

namespace rfl {

Expand All @@ -22,104 +21,29 @@ inline auto visit(const Visitor& _visitor, const Literal<_fields...> _literal,
wrapper, _literal.value(), _args...);
}

template <class F, class... AlternativeTypes>
inline internal::variant::result_t<F, AlternativeTypes...> visit(
F& _f, Variant<AlternativeTypes...>& _v) {
return _v.visit(_f);
template <class F, VariantBased V>
inline auto visit(F&& _f, V&& _v)
-> decltype(std::declval<V&&>().visit(std::declval<F&&>())) {
return std::forward<V>(_v).visit(std::forward<F>(_f));
}

template <class F, class... AlternativeTypes>
inline internal::variant::result_t<F, AlternativeTypes...> visit(
F& _f, Variant<AlternativeTypes...>&& _v) {
return _v.visit(_f);
}

template <class F, class... AlternativeTypes>
inline internal::variant::result_t<F, AlternativeTypes...> visit(
F& _f, const Variant<AlternativeTypes...>& _v) {
return _v.visit(_f);
}

template <class F, class... AlternativeTypes>
inline internal::variant::result_t<F, AlternativeTypes...> visit(
const F& _f, Variant<AlternativeTypes...>& _v) {
return _v.visit(_f);
}

template <class F, class... AlternativeTypes>
inline internal::variant::result_t<F, AlternativeTypes...> visit(
const F& _f, Variant<AlternativeTypes...>&& _v) {
return _v.visit(_f);
}

template <class F, class... AlternativeTypes>
inline internal::variant::result_t<F, AlternativeTypes...> visit(
const F& _f, const Variant<AlternativeTypes...>& _v) {
return _v.visit(_f);
}

template <class F, internal::StringLiteral _discriminator,
class... AlternativeTypes>
inline internal::variant::result_t<F, AlternativeTypes...> visit(
F& _f, TaggedUnion<_discriminator, AlternativeTypes...>& _tagged_union) {
return _tagged_union.variant().visit(_f);
}

template <class F, internal::StringLiteral _discriminator,
class... AlternativeTypes>
inline internal::variant::result_t<F, AlternativeTypes...> visit(
F& _f, TaggedUnion<_discriminator, AlternativeTypes...>&& _tagged_union) {
return _tagged_union.variant().visit(_f);
}

template <class F, internal::StringLiteral _discriminator,
class... AlternativeTypes>
inline internal::variant::result_t<F, AlternativeTypes...> visit(
F& _f,
const TaggedUnion<_discriminator, AlternativeTypes...>& _tagged_union) {
return _tagged_union.variant().visit(_f);
}

template <class F, internal::StringLiteral _discriminator,
template <class F, TaggedUnionBased T, internal::StringLiteral _discriminator,
class... AlternativeTypes>
inline internal::variant::result_t<F, AlternativeTypes...> visit(
const F& _f,
TaggedUnion<_discriminator, AlternativeTypes...>& _tagged_union) {
return _tagged_union.variant().visit(_f);
}

template <class F, internal::StringLiteral _discriminator,
class... AlternativeTypes>
inline internal::variant::result_t<F, AlternativeTypes...> visit(
const F& _f,
TaggedUnion<_discriminator, AlternativeTypes...>&& _tagged_union) {
return _tagged_union.variant().visit(_f);
}

template <class F, internal::StringLiteral _discriminator,
class... AlternativeTypes>
inline internal::variant::result_t<F, AlternativeTypes...> visit(
const F& _f,
const TaggedUnion<_discriminator, AlternativeTypes...>& _tagged_union) {
return _tagged_union.variant().visit(_f);
}

template <class F, class Head, class... Tail>
inline auto visit(F& _f, Head& _head, Tail&... _tail) {
const auto f_outer = [&](auto& _h) {
const auto f_inner = [&](auto&... _t) { return _f(_h, _t...); };
return visit(f_inner, _tail...);
};
return _head.visit(f_outer);
inline auto visit(
F&& _f, TaggedUnion<_discriminator, AlternativeTypes...>&& _tagged_union)
-> decltype(std::declval<T&&>().variant().visit(std::declval<F&&>())) {
return std::forward<T>(_tagged_union).variant().visit(std::forward<F>(_f));
}

template <class F, class Head, class... Tail>
inline auto visit(const F& _f, Head& _head, Tail&... _tail) {
inline auto visit(F&& _f, Head&& _head, Tail&&... _tail) {
const auto f_outer = [&](auto& _h) {
const auto f_inner = [&](auto&... _t) { return _f(_h, _t...); };
return visit(f_inner, _tail...);
const auto f_inner = [&](auto&... _t) {
return std::forward<F>(_f)(_h, _t...);
};
return visit(f_inner, std::forward<Tail>(_tail)...);
};
return _head.visit(f_outer);
return std::forward<Head>(_head).visit(f_outer);
}

} // namespace rfl
Expand Down
25 changes: 24 additions & 1 deletion tests/json/test_rfl_variant_visit_move_only.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#include <cassert>
#include <iostream>
#include <memory>
#include <rfl.hpp>
#include <rfl/json.hpp>
Expand All @@ -19,6 +18,10 @@ struct Square {
std::unique_ptr<double> width;
};

struct Circle {
double const radius;
};

using Shapes = rfl::Variant<Rectangle, Square>;

TEST(json, test_rfl_variant_visit_move_only) {
Expand All @@ -29,4 +32,24 @@ TEST(json, test_rfl_variant_visit_move_only) {
};
EXPECT_EQ(*rfl::visit(get_width, r), 5.0);
}

TEST(json, test_rfl_variant_visit_return_move_only_const) {
auto const circle = Circle{.radius = 10};
auto const variant = rfl::Variant<Circle>{circle};
const auto get_radius = [](const auto& object) -> const double& {
return object.radius;
};
auto const result = rfl::visit(get_radius, variant);
EXPECT_EQ(result, circle.radius);
}

TEST(json, test_rfl_variant_visit_return_move_only) {
auto circle = Circle{.radius = 10};
auto variant = rfl::Variant<Circle>{circle};
auto const get_radius = [](auto& object) -> double const& {
return object.radius;
};
auto result = rfl::visit(get_radius, variant);
EXPECT_EQ(result, circle.radius);
}
} // namespace test_rfl_variant_visit_move_only

0 comments on commit 716de7b

Please sign in to comment.