Skip to content

Commit

Permalink
Allow TreeTraversal to run individual queries in non-batch mode (#917)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel Arndt <[email protected]>
  • Loading branch information
masterleinad and Daniel Arndt authored Sep 20, 2023
1 parent 3f8fbc9 commit 2d3c687
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 5 deletions.
26 changes: 26 additions & 0 deletions src/ArborX_LinearBVH.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@
namespace ArborX
{

namespace Experimental
{
struct PerThread
{};
} // namespace Experimental

namespace Details
{
struct HappyTreeFriends;
Expand Down Expand Up @@ -101,6 +107,18 @@ class BasicBoundingVolumeHierarchy
std::forward<View>(view), std::forward<Args>(args)...);
}

template <typename Predicate, typename Callback>
KOKKOS_FUNCTION void query(Experimental::PerThread,
Predicate const &predicate,
Callback const &callback) const
{
ArborX::Details::TreeTraversal<BasicBoundingVolumeHierarchy,
/* Predicates Dummy */ std::true_type,
Callback, typename Predicate::Tag>
tree_traversal(*this, callback);
tree_traversal(predicate);
}

private:
friend struct Details::HappyTreeFriends;

Expand Down Expand Up @@ -161,6 +179,14 @@ class BoundingVolumeHierarchy
std::forward<CallbackOrView>(callback_or_view),
std::forward<View>(view), std::forward<Args>(args)...);
}

template <typename Predicate, typename Callback>
KOKKOS_FUNCTION void query(Experimental::PerThread tag,
Predicate const &predicate,
Callback const &callback) const
{
base_type::query(tag, predicate, callback);
}
};

template <typename MemorySpace>
Expand Down
37 changes: 32 additions & 5 deletions src/details/ArborX_DetailsTreeTraversal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,23 @@ struct TreeTraversal<BVH, Predicates, Callback, SpatialPredicateTag>
else
{
Kokkos::parallel_for("ArborX::TreeTraversal::spatial",
Kokkos::RangePolicy<ExecutionSpace>(
Kokkos::RangePolicy<ExecutionSpace, FullTree>(
space, 0, Access::size(predicates)),
*this);
}
}

KOKKOS_FUNCTION TreeTraversal(BVH const &bvh, Callback const &callback)
: _bvh{bvh}
, _callback{callback}
{}

struct OneLeafTree
{};

struct FullTree
{};

KOKKOS_FUNCTION void operator()(OneLeafTree, int queryIndex) const
{
auto const &predicate = Access::get(_predicates, queryIndex);
Expand All @@ -84,10 +92,15 @@ struct TreeTraversal<BVH, Predicates, Callback, SpatialPredicateTag>
}
}

KOKKOS_FUNCTION void operator()(int queryIndex) const
KOKKOS_FUNCTION void operator()(FullTree, int queryIndex) const
{
auto const &predicate = Access::get(_predicates, queryIndex);
operator()(predicate);
}

template <typename Predicate>
KOKKOS_FUNCTION void operator()(Predicate const &predicate) const
{
int node = HappyTreeFriends::getRoot(_bvh); // start with root
do
{
Expand Down Expand Up @@ -415,15 +428,23 @@ struct TreeTraversal<BVH, Predicates, Callback,
{
Kokkos::parallel_for(
"ArborX::Experimental::TreeTraversal::OrderedSpatialPredicate",
Kokkos::RangePolicy<ExecutionSpace>(space, 0,
Access::size(predicates)),
Kokkos::RangePolicy<ExecutionSpace, FullTree>(
space, 0, Access::size(predicates)),
*this);
}
}

KOKKOS_FUNCTION TreeTraversal(BVH const &bvh, Callback const &callback)
: _bvh{bvh}
, _callback{callback}
{}

struct OneLeafTree
{};

struct FullTree
{};

KOKKOS_FUNCTION void operator()(OneLeafTree, int queryIndex) const
{
auto const &predicate = Access::get(_predicates, queryIndex);
Expand All @@ -440,9 +461,15 @@ struct TreeTraversal<BVH, Predicates, Callback,
}
}

KOKKOS_FUNCTION void operator()(int queryIndex) const
KOKKOS_FUNCTION void operator()(FullTree, int queryIndex) const
{
auto const &predicate = Access::get(_predicates, queryIndex);
operator()(predicate);
}

template <typename Predicate>
KOKKOS_FUNCTION void operator()(Predicate const &predicate) const
{
using ArborX::Details::HappyTreeFriends;

using distance_type = decltype(predicate.distance(
Expand Down
1 change: 1 addition & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ foreach(_test Callbacks Degenerate ManufacturedSolution ComparisonWithBoost)
endforeach()
endforeach()
list(APPEND ARBORX_TEST_QUERY_TREE_SOURCES
tstQueryTreeCallbackQueryPerThread.cpp
tstQueryTreeRay.cpp
tstQueryTreeTraversalPolicy.cpp
tstQueryTreeIntersectsKDOP.cpp
Expand Down
113 changes: 113 additions & 0 deletions test/tstQueryTreeCallbackQueryPerThread.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/****************************************************************************
* Copyright (c) 2023 by the ArborX authors *
* All rights reserved. *
* *
* This file is part of the ArborX library. ArborX is *
* distributed under a BSD 3-clause license. For the licensing terms see *
* the LICENSE file in the top-level directory. *
* *
* SPDX-License-Identifier: BSD-3-Clause *
****************************************************************************/

#include "ArborX_EnableDeviceTypes.hpp" // ARBORX_DEVICE_TYPES
#include <ArborX.hpp>

#include "BoostTest_CUDA_clang_workarounds.hpp"
#include <boost/test/unit_test.hpp>

#include <numeric>
#include <vector>

BOOST_AUTO_TEST_SUITE(PerThread)

struct IntersectionCallback
{
int query_index;
bool &success;

template <typename Query, typename Value>
KOKKOS_FUNCTION void operator()(Query const &, Value const &value) const
{
success = (query_index == value.index);
}
};

BOOST_AUTO_TEST_CASE_TEMPLATE(callback_intersects, DeviceType,
ARBORX_DEVICE_TYPES)
{
using MemorySpace = typename DeviceType::memory_space;
using ExecutionSpace = typename DeviceType::execution_space;
using Tree = ArborX::BVH<MemorySpace>;

int const n = 10;
Kokkos::View<ArborX::Point *, DeviceType> points(
Kokkos::view_alloc(Kokkos::WithoutInitializing, "points"), n);
Kokkos::parallel_for(
Kokkos::RangePolicy<ExecutionSpace>(0, n), KOKKOS_LAMBDA(int i) {
points(i) = {{(float)i, (float)i, (float)i}};
});

Tree const tree(ExecutionSpace{}, points);

bool success;
Kokkos::parallel_reduce(
Kokkos::RangePolicy<ExecutionSpace>(0, n),
KOKKOS_LAMBDA(int i, bool &update) {
float center = i;
ArborX::Box box{{center - .5f, center - .5f, center - .5f},
{center + .5f, center + .5f, center + .5f}};
tree.query(ArborX::Experimental::PerThread{}, ArborX::intersects(box),
IntersectionCallback{i, update});
},
Kokkos::LAnd<bool, Kokkos::HostSpace>(success));

BOOST_TEST(success);
}

struct OrderedIntersectionCallback
{
int query_index;
bool &success;

template <typename Query, typename Value>
KOKKOS_FUNCTION auto operator()(Query const &, Value const &value) const
{
success = (query_index == value.index);
return ArborX::CallbackTreeTraversalControl::early_exit;
}
};

BOOST_AUTO_TEST_CASE_TEMPLATE(callback_ordered_intersects, DeviceType,
ARBORX_DEVICE_TYPES)
{
using MemorySpace = typename DeviceType::memory_space;
using ExecutionSpace = typename DeviceType::execution_space;
using Tree = ArborX::BVH<MemorySpace>;

int const n = 10;
Kokkos::View<ArborX::Point *, DeviceType> points(
Kokkos::view_alloc(Kokkos::WithoutInitializing, "points"), n);
Kokkos::parallel_for(
Kokkos::RangePolicy<ExecutionSpace>(0, n), KOKKOS_LAMBDA(int i) {
points(i) = {{(float)i, (float)i, (float)i}};
});

Tree const tree(ExecutionSpace{}, points);

bool success;
Kokkos::parallel_reduce(
Kokkos::RangePolicy<ExecutionSpace>(0, n),
KOKKOS_LAMBDA(int i, bool &update) {
float center = i;
ArborX::Box box{{center - .5f, center - .5f, center - .5f},
{center + .5f, center + .5f, center + .5f}};
tree.query(ArborX::Experimental::PerThread{},
ArborX::Experimental::ordered_intersects(box),
OrderedIntersectionCallback{i, update});
},
Kokkos::LAnd<bool, Kokkos::HostSpace>(success));

BOOST_TEST(success);
}

BOOST_AUTO_TEST_SUITE_END()

0 comments on commit 2d3c687

Please sign in to comment.