diff --git a/include/Quadtree.h b/include/Quadtree.h index f4af30f..010b069 100644 --- a/include/Quadtree.h +++ b/include/Quadtree.h @@ -10,7 +10,27 @@ namespace quadtree { -template, typename Float = float> +namespace detail +{ + template + struct StdMakeUnique + { + template + std::unique_ptr operator() (Args&&... args) + { + return std::make_unique(std::forward(args)...); + } + }; +} + +template< + typename T, + typename GetBox, + typename Equal = std::equal_to, + typename Float = float, + template class Allocator = std::allocator, + template class MakeUnique = detail::StdMakeUnique +> class Quadtree { static_assert(std::is_convertible_v, Box>, @@ -20,9 +40,12 @@ class Quadtree static_assert(std::is_arithmetic_v); public: + template + using vector_type = std::vector< U, Allocator >; + Quadtree(const Box& box, const GetBox& getBox = GetBox(), const Equal& equal = Equal()) : - mBox(box), mRoot(std::make_unique()), mGetBox(getBox), mEqual(equal) + mBox(box), mMakeUnique(), mRoot(mMakeUnique()), mGetBox(getBox), mEqual(equal) { } @@ -37,16 +60,16 @@ class Quadtree remove(mRoot.get(), nullptr, mBox, value); } - std::vector query(const Box& box) const + vector_type query(const Box& box) const { - auto values = std::vector(); + auto values = vector_type(); query(mRoot.get(), mBox, box, values); return values; } - std::vector> findAllIntersections() const + vector_type> findAllIntersections() const { - auto intersections = std::vector>(); + auto intersections = vector_type>(); findAllIntersections(mRoot.get(), intersections); return intersections; } @@ -55,14 +78,22 @@ class Quadtree static constexpr auto Threshold = std::size_t(16); static constexpr auto MaxDepth = std::size_t(8); + struct Node; +#if __cplusplus < 201703L + typedef typename std::result_of()>::type UniqueNodePtr; +#else + typedef std::invoke_result_t> UniqueNodePtr; +#endif + struct Node { - std::array, 4> children; - std::vector values; + std::array children; + vector_type values; }; Box mBox; - std::unique_ptr mRoot; + MakeUnique mMakeUnique; + UniqueNodePtr mRoot; GetBox mGetBox; Equal mEqual; @@ -163,9 +194,9 @@ class Quadtree assert(isLeaf(node) && "Only leaves can be split"); // Create children for (auto& child : node->children) - child = std::make_unique(); + child = mMakeUnique(); // Assign values to children - auto newValues = std::vector(); // New values for this node + auto newValues = vector_type(); // New values for this node for (const auto& value : node->values) { auto i = getQuadrant(box, mGetBox(value)); @@ -238,7 +269,7 @@ class Quadtree } } - void query(Node* node, const Box& box, const Box& queryBox, std::vector& values) const + void query(Node* node, const Box& box, const Box& queryBox, vector_type& values) const { assert(node != nullptr); assert(queryBox.intersects(box)); @@ -258,7 +289,7 @@ class Quadtree } } - void findAllIntersections(Node* node, std::vector>& intersections) const + void findAllIntersections(Node* node, vector_type>& intersections) const { // Find intersections between values stored in this node // Make sure to not report the same intersection twice @@ -284,7 +315,7 @@ class Quadtree } } - void findIntersectionsInDescendants(Node* node, const T& value, std::vector>& intersections) const + void findIntersectionsInDescendants(Node* node, const T& value, vector_type>& intersections) const { // Test against the values stored in this node for (const auto& other : node->values)