From 0fae98717bb1ea3afd8b87aa0531c5d0f056f1f2 Mon Sep 17 00:00:00 2001 From: Hudd Date: Wed, 12 Apr 2023 14:00:58 +0400 Subject: [PATCH] utility: Imporve paired() Changed paired() such that it doesn't access invalid memory when passed a rvalue --- utility/include/aw/utility/ranges/paired.h | 91 +++++++++++++--------- utility/tests/ranges_paired.c++ | 11 +++ 2 files changed, 66 insertions(+), 36 deletions(-) diff --git a/utility/include/aw/utility/ranges/paired.h b/utility/include/aw/utility/ranges/paired.h index d0fd24ce..99107272 100644 --- a/utility/include/aw/utility/ranges/paired.h +++ b/utility/include/aw/utility/ranges/paired.h @@ -9,25 +9,52 @@ */ #ifndef aw_utility_ranges_paired_h #define aw_utility_ranges_paired_h -#include +#include "iter_range.h" + #include +#include namespace aw { template struct pairs_sentinel : public std::pair { using std::pair::pair; }; -template +namespace adl { +using std::begin; +using std::end; +template +using begin_type = decltype( begin(std::declval()) ); +template +using end_type = decltype( end(std::declval()) ); +} // namespace adl + +template struct pairs_iterator { private: - using reference1 = typename std::iterator_traits::reference; - using reference2 = typename std::iterator_traits::reference; + using iterator1 = adl::begin_type; + using iterator2 = adl::begin_type; + + using sentinel1 = adl::end_type; + using sentinel2 = adl::end_type; + + using reference1 = typename std::iterator_traits::reference; + using reference2 = typename std::iterator_traits::reference; + + using iter_pair = std::pair; + + constexpr iter_pair make_pair() + { + using std::begin; + return { begin(ranges.first), begin(ranges.second) }; + } public: - using sentinel = pairs_sentinel; + using iterator = pairs_iterator; + using sentinel = pairs_sentinel; - constexpr pairs_iterator(Base1 iter1, Base2 iter2) - : iters{iter1, iter2} + constexpr pairs_iterator(Range1&& first, Range2&& second) + : ranges(std::forward(first), std::forward(second)) + , iters(make_pair()) {} constexpr std::pair operator*() @@ -49,49 +76,41 @@ struct pairs_iterator { iters.second != other.second; } - constexpr bool operator!=(Base1 const& it) + constexpr bool operator!=(iterator1 const& it) { return iters.first != it; } - std::pair iters; -}; - -namespace adl { -using std::begin; -using std::end; -template -using begin_type = decltype( begin(std::declval()) ); -template -using end_type = decltype( begin(std::declval()) ); -} // namespace adl - -template -struct pairs_adapter { -private: - using _iter1 = adl::begin_type; - using _iter2 = adl::begin_type; -public: - using iterator = pairs_iterator<_iter1, _iter2>; - using sentinel = pairs_sentinel<_iter1, _iter2>; - - constexpr pairs_adapter(Range1 range1, Range2 range2) - : ranges{range1, range2} - {} - constexpr iterator begin() { using std::begin; - return {begin(ranges.first), begin(ranges.second)}; + return *this; } constexpr sentinel end() { using std::end; - return {end(ranges.first), end(ranges.second)}; + return { end(ranges.first), end(ranges.second) }; + } + + constexpr iter_range first() + { + using std::end; + return { iters.first, end(ranges.first) }; } + constexpr iter_range second() + { + using std::end; + return { iters.second, end(ranges.second) }; + } + + // Safety thing for temporary ranges: + // If an rvalue is passed to paired() it gets stored here std::pair ranges; + + // Current state + std::pair iters; }; /*! @@ -102,7 +121,7 @@ struct pairs_adapter { template constexpr auto paired(Ranges&&... ranges) { - return pairs_adapter(std::forward(ranges)...); + return pairs_iterator(std::forward(ranges)...); } } // namespace aw #endif//aw_utility_ranges_paired_h diff --git a/utility/tests/ranges_paired.c++ b/utility/tests/ranges_paired.c++ index a4464df1..69f4dfa6 100644 --- a/utility/tests/ranges_paired.c++ +++ b/utility/tests/ranges_paired.c++ @@ -32,4 +32,15 @@ Test(paired_ref) { const std::vector expected({ 1, 3, 3, 5, 5, 7, 7 }); TestEqual(vec2, expected); } + +Test(paired_rvalue) { + std::vector vec2(10, 0); + + for (auto [v1, v2] : paired(std::vector{ 1, 2, 3, 4, 5, 6, 7 }, vec2)) { + v2 = static_cast(v1) | 0x1; + } + + const std::vector expected({ 1, 3, 3, 5, 5, 7, 7, 0, 0, 0 }); + TestEqual(vec2, expected); +} } // namespace aw