diff --git a/src/ArborX_LinearBVH.hpp b/src/ArborX_LinearBVH.hpp index 9279f333b..a7b915b63 100644 --- a/src/ArborX_LinearBVH.hpp +++ b/src/ArborX_LinearBVH.hpp @@ -36,6 +36,12 @@ namespace ArborX { +namespace Experimental +{ +struct PerThread +{}; +} // namespace Experimental + namespace Details { struct HappyTreeFriends; @@ -101,6 +107,18 @@ class BasicBoundingVolumeHierarchy std::forward(view), std::forward(args)...); } + template + KOKKOS_FUNCTION void query(Experimental::PerThread, + Predicate const &predicate, + Callback const &callback) const + { + ArborX::Details::TreeTraversal + tree_traversal(*this, callback); + tree_traversal(predicate); + } + private: friend struct Details::HappyTreeFriends; @@ -161,6 +179,14 @@ class BoundingVolumeHierarchy std::forward(callback_or_view), std::forward(view), std::forward(args)...); } + + template + KOKKOS_FUNCTION void query(Experimental::PerThread tag, + Predicate const &predicate, + Callback const &callback) const + { + base_type::query(tag, predicate, callback); + } }; template diff --git a/src/details/ArborX_DetailsTreeTraversal.hpp b/src/details/ArborX_DetailsTreeTraversal.hpp index c94c2c5cb..ce758c6aa 100644 --- a/src/details/ArborX_DetailsTreeTraversal.hpp +++ b/src/details/ArborX_DetailsTreeTraversal.hpp @@ -63,15 +63,23 @@ struct TreeTraversal else { Kokkos::parallel_for("ArborX::TreeTraversal::spatial", - Kokkos::RangePolicy( + Kokkos::RangePolicy( space, 0, Access::size(predicates)), *this); } } + KOKKOS_FUNCTION TreeTraversal(BVH const &bvh, Callback const &callback) + : _bvh{bvh} + , _callback{callback} + {} + struct OneLeafTree {}; + struct FullTree + {}; + KOKKOS_FUNCTION void operator()(OneLeafTree, int queryIndex) const { auto const &predicate = Access::get(_predicates, queryIndex); @@ -84,10 +92,15 @@ struct TreeTraversal } } - KOKKOS_FUNCTION void operator()(int queryIndex) const + KOKKOS_FUNCTION void operator()(FullTree, int queryIndex) const { auto const &predicate = Access::get(_predicates, queryIndex); + operator()(predicate); + } + template + KOKKOS_FUNCTION void operator()(Predicate const &predicate) const + { int node = HappyTreeFriends::getRoot(_bvh); // start with root do { @@ -415,15 +428,23 @@ struct TreeTraversal(space, 0, - Access::size(predicates)), + Kokkos::RangePolicy( + space, 0, Access::size(predicates)), *this); } } + KOKKOS_FUNCTION TreeTraversal(BVH const &bvh, Callback const &callback) + : _bvh{bvh} + , _callback{callback} + {} + struct OneLeafTree {}; + struct FullTree + {}; + KOKKOS_FUNCTION void operator()(OneLeafTree, int queryIndex) const { auto const &predicate = Access::get(_predicates, queryIndex); @@ -440,9 +461,15 @@ struct TreeTraversal + KOKKOS_FUNCTION void operator()(Predicate const &predicate) const + { using ArborX::Details::HappyTreeFriends; using distance_type = decltype(predicate.distance( diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 85dbc5dfa..7afac2d56 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -124,6 +124,7 @@ foreach(_test Callbacks Degenerate ManufacturedSolution ComparisonWithBoost) endforeach() endforeach() list(APPEND ARBORX_TEST_QUERY_TREE_SOURCES + tstQueryTreeCallbackQueryPerThread.cpp tstQueryTreeRay.cpp tstQueryTreeTraversalPolicy.cpp tstQueryTreeIntersectsKDOP.cpp diff --git a/test/tstQueryTreeCallbackQueryPerThread.cpp b/test/tstQueryTreeCallbackQueryPerThread.cpp new file mode 100644 index 000000000..2734c4bf8 --- /dev/null +++ b/test/tstQueryTreeCallbackQueryPerThread.cpp @@ -0,0 +1,113 @@ +/**************************************************************************** + * Copyright (c) 2023 by the ArborX authors * + * All rights reserved. * + * * + * This file is part of the ArborX library. ArborX is * + * distributed under a BSD 3-clause license. For the licensing terms see * + * the LICENSE file in the top-level directory. * + * * + * SPDX-License-Identifier: BSD-3-Clause * + ****************************************************************************/ + +#include "ArborX_EnableDeviceTypes.hpp" // ARBORX_DEVICE_TYPES +#include + +#include "BoostTest_CUDA_clang_workarounds.hpp" +#include + +#include +#include + +BOOST_AUTO_TEST_SUITE(PerThread) + +struct IntersectionCallback +{ + int query_index; + bool &success; + + template + KOKKOS_FUNCTION void operator()(Query const &, Value const &value) const + { + success = (query_index == value.index); + } +}; + +BOOST_AUTO_TEST_CASE_TEMPLATE(callback_intersects, DeviceType, + ARBORX_DEVICE_TYPES) +{ + using MemorySpace = typename DeviceType::memory_space; + using ExecutionSpace = typename DeviceType::execution_space; + using Tree = ArborX::BVH; + + int const n = 10; + Kokkos::View points( + Kokkos::view_alloc(Kokkos::WithoutInitializing, "points"), n); + Kokkos::parallel_for( + Kokkos::RangePolicy(0, n), KOKKOS_LAMBDA(int i) { + points(i) = {{(float)i, (float)i, (float)i}}; + }); + + Tree const tree(ExecutionSpace{}, points); + + bool success; + Kokkos::parallel_reduce( + Kokkos::RangePolicy(0, n), + KOKKOS_LAMBDA(int i, bool &update) { + float center = i; + ArborX::Box box{{center - .5f, center - .5f, center - .5f}, + {center + .5f, center + .5f, center + .5f}}; + tree.query(ArborX::Experimental::PerThread{}, ArborX::intersects(box), + IntersectionCallback{i, update}); + }, + Kokkos::LAnd(success)); + + BOOST_TEST(success); +} + +struct OrderedIntersectionCallback +{ + int query_index; + bool &success; + + template + KOKKOS_FUNCTION auto operator()(Query const &, Value const &value) const + { + success = (query_index == value.index); + return ArborX::CallbackTreeTraversalControl::early_exit; + } +}; + +BOOST_AUTO_TEST_CASE_TEMPLATE(callback_ordered_intersects, DeviceType, + ARBORX_DEVICE_TYPES) +{ + using MemorySpace = typename DeviceType::memory_space; + using ExecutionSpace = typename DeviceType::execution_space; + using Tree = ArborX::BVH; + + int const n = 10; + Kokkos::View points( + Kokkos::view_alloc(Kokkos::WithoutInitializing, "points"), n); + Kokkos::parallel_for( + Kokkos::RangePolicy(0, n), KOKKOS_LAMBDA(int i) { + points(i) = {{(float)i, (float)i, (float)i}}; + }); + + Tree const tree(ExecutionSpace{}, points); + + bool success; + Kokkos::parallel_reduce( + Kokkos::RangePolicy(0, n), + KOKKOS_LAMBDA(int i, bool &update) { + float center = i; + ArborX::Box box{{center - .5f, center - .5f, center - .5f}, + {center + .5f, center + .5f, center + .5f}}; + tree.query(ArborX::Experimental::PerThread{}, + ArborX::Experimental::ordered_intersects(box), + OrderedIntersectionCallback{i, update}); + }, + Kokkos::LAnd(success)); + + BOOST_TEST(success); +} + +BOOST_AUTO_TEST_SUITE_END()