Skip to content

Commit

Permalink
Merge pull request arborx#733 from masterleinad/extend_distributed_tree_callback
Browse files Browse the repository at this point in the history

Add DistributedTree query only taking a callback
  • Loading branch information
aprokop authored Jul 1, 2024
2 parents 5bc2187 + 627eb10 commit 3e8d9ee
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 52 deletions.
7 changes: 7 additions & 0 deletions src/ArborX_DistributedTree.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,13 @@ class DistributedTree<MemorySpace, Details::LegacyDefaultTemplateValue,
std::forward<OffsetView>(offset));
}

// Callback-only overload of query(): takes an execution space, the user
// predicates and a callback, and no output/offset views.
//
// space           - execution space instance passed through unchanged.
// user_predicates - predicates passed through unchanged.
// callback        - forwarding reference; its value category is preserved
//                   via std::forward when handed to base_type::query.
//
// The method adds no logic of its own: it delegates to base_type::query.
template <typename ExecutionSpace, typename UserPredicates, typename Callback>
void query(ExecutionSpace const &space, UserPredicates const &user_predicates,
Callback &&callback) const
{
base_type::query(space, user_predicates, std::forward<Callback>(callback));
}

template <typename ExecutionSpace, typename UserPredicates, typename Callback,
typename Indices, typename Offset>
void query(ExecutionSpace const &space, UserPredicates const &user_predicates,
Expand Down
7 changes: 7 additions & 0 deletions src/details/ArborX_DetailsDistributedTreeImpl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ struct DistributedTreeImpl
ExecutionSpace const &space, Predicates const &queries,
Callback const &callback, OutputView &out, OffsetView &offset);

// Spatial-predicate dispatch (selected by SpatialPredicateTag) that takes
// only a callback: unlike the overload declared just above, there are no
// output or offset views in the signature, and the function returns void.
// Declaration only; the body is not part of this struct definition.
template <typename DistributedTree, typename ExecutionSpace,
typename Predicates, typename Callback>
static void queryDispatch(SpatialPredicateTag, DistributedTree const &tree,
ExecutionSpace const &space,
Predicates const &predicates,
Callback const &callback);

// nearest neighbors queries
template <typename DistributedTree, typename ExecutionSpace,
typename Predicates, typename Callback, typename Indices,
Expand Down
40 changes: 38 additions & 2 deletions src/details/ArborX_DetailsDistributedTreeSpatial.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,52 @@ DistributedTreeImpl::queryDispatch(SpatialPredicateTag, Tree const &tree,

using MemorySpace = typename Tree::memory_space;

auto const &top_tree = tree._top_tree;

Kokkos::View<int *, MemorySpace> intersected_ranks(
"ArborX::DistributedTree::query::spatial::intersected_ranks", 0);
tree._top_tree.query(space, predicates, LegacyDefaultCallback{},
intersected_ranks, offset);
top_tree.query(space, predicates, intersected_ranks, offset);

DistributedTree::forwardQueriesAndCommunicateResults(
tree.getComm(), space, tree._bottom_tree, predicates, callback,
intersected_ranks, offset, values);
}

// Callback-only ("pure") spatial query on a distributed tree.
//
// Three steps, none of which sends anything back to the originating rank:
//   1. query the top tree with the local predicates, filling
//      `intersected_ranks` and `offset`;
//   2. forward the predicates to other ranks based on those two views;
//   3. run the bottom tree on the predicates received, invoking `callback`.
//
// tree       - distributed tree; its _top_tree, _bottom_tree and getComm()
//              are used.
// space      - execution space instance used for every step.
// predicates - predicates owned by the calling rank.
// callback   - passed as-is to the bottom tree query for the forwarded
//              predicates.
//
// Returns immediately (after opening the profiling region) if the tree
// reports empty().
template <typename Tree, typename ExecutionSpace, typename Predicates,
          typename Callback>
void DistributedTreeImpl::queryDispatch(SpatialPredicateTag, Tree const &tree,
                                        ExecutionSpace const &space,
                                        Predicates const &predicates,
                                        Callback const &callback)
{
  using MemorySpace = typename Tree::memory_space;
  using Query = typename Predicates::value_type;

  // Shared label for the profiling region and all view names below.
  std::string const label = "ArborX::DistributedTree::query::spatial(pure)";
  Kokkos::Profiling::ScopedRegion guard(label);

  if (tree.empty())
  {
    return;
  }

  // Step 1: top tree query. Both views start with zero length and are
  // filled by the query.
  Kokkos::View<int *, MemorySpace> intersected_ranks(
      label + "::intersected_ranks", 0);
  Kokkos::View<int *, MemorySpace> offset(label + "::offset", 0);
  tree._top_tree.query(space, predicates, intersected_ranks, offset);

  // Step 2: send the predicates out; only the predicates themselves are
  // requested back from forwardQueries (no ids, no ranks).
  Kokkos::View<Query *, MemorySpace> fwd_predicates(label + "::fwd_predicates",
                                                    0);
  DistributedTree::forwardQueries(tree.getComm(), space, predicates,
                                  intersected_ranks, offset, fwd_predicates);

  // Step 3: the callback is invoked by the bottom tree query on the
  // forwarded predicates.
  tree._bottom_tree.query(space, fwd_predicates, callback);
}

template <typename Tree, typename ExecutionSpace, typename Predicates,
typename Values, typename Offset>
std::enable_if_t<Kokkos::is_view_v<Values> && Kokkos::is_view_v<Offset>>
Expand Down
111 changes: 61 additions & 50 deletions src/details/ArborX_DetailsDistributedTreeUtils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,9 @@ void forwardQueries(MPI_Comm comm, ExecutionSpace const &space,
Offset const &offset, FwdQueries &fwd_queries,
FwdIds &fwd_ids, Ranks &fwd_ranks)
{
Kokkos::Profiling::ScopedRegion guard(
"ArborX::DistributedTree::forwardQueries");
std::string prefix = "ArborX::DistributedTree::query::forwardQueries";

Kokkos::Profiling::ScopedRegion guard(prefix);

using MemorySpace = typename Predicates::memory_space;
using Query = typename Predicates::value_type;
Expand All @@ -68,77 +69,87 @@ void forwardQueries(MPI_Comm comm, ExecutionSpace const &space,
int const n_exports = KokkosExt::lastElement(space, offset);
int const n_imports = distributor.createFromSends(space, indices);

static_assert(std::is_same_v<Query, typename Predicates::value_type>);

{
Kokkos::View<int *, MemorySpace> export_ranks(
Kokkos::view_alloc(
space, Kokkos::WithoutInitializing,
"ArborX::DistributedTree::query::forwardQueries::export_ranks"),
n_exports);
Kokkos::deep_copy(space, export_ranks, comm_rank);

Kokkos::View<int *, MemorySpace> import_ranks(
Kokkos::view_alloc(
space, Kokkos::WithoutInitializing,
"ArborX::DistributedTree::query::forwardQueries::import_ranks"),
n_imports);

distributor.doPostsAndWaits(space, export_ranks, import_ranks);
fwd_ranks = import_ranks;
}

{
Kokkos::View<Query *, MemorySpace> exports(
Kokkos::view_alloc(
space, Kokkos::WithoutInitializing,
"ArborX::DistributedTree::query::forwardQueries::exports"),
Kokkos::View<Query *, MemorySpace> export_queries(
Kokkos::view_alloc(space, Kokkos::WithoutInitializing,
prefix + "::export_queries"),
n_exports);
Kokkos::parallel_for(
"ArborX::DistributedTree::query::forward_queries_fill_buffer",
prefix + "::forward_queries_fill_buffer",
Kokkos::RangePolicy<ExecutionSpace>(space, 0, n_queries),
KOKKOS_LAMBDA(int q) {
for (int i = offset(q); i < offset(q + 1); ++i)
{
exports(i) = queries(q);
}
export_queries(i) = queries(q);
});
Kokkos::View<Query *, MemorySpace> imports(
Kokkos::view_alloc(
space, Kokkos::WithoutInitializing,
"ArborX::DistributedTree::query::forwardQueries::imports"),
n_imports);

distributor.doPostsAndWaits(space, exports, imports);
fwd_queries = imports;
KokkosExt::reallocWithoutInitializing(space, fwd_queries, n_imports);
distributor.doPostsAndWaits(space, export_queries, fwd_queries);
}

{
Kokkos::View<int *, MemorySpace> export_ranks(
Kokkos::view_alloc(space, Kokkos::WithoutInitializing,
prefix + "::export_ranks"),
n_exports);
Kokkos::deep_copy(space, export_ranks, comm_rank);

KokkosExt::reallocWithoutInitializing(space, fwd_ranks, n_imports);
distributor.doPostsAndWaits(space, export_ranks, fwd_ranks);
}

{
Kokkos::View<int *, MemorySpace> export_ids(
Kokkos::view_alloc(
space, Kokkos::WithoutInitializing,
"ArborX::DistributedTree::query::forwardQueries::export_ids"),
Kokkos::view_alloc(space, Kokkos::WithoutInitializing,
prefix + "::export_ids"),
n_exports);
Kokkos::parallel_for(
"ArborX::DistributedTree::query::forward_queries_fill_ids",
prefix + "::forward_queries_fill_ids",
Kokkos::RangePolicy<ExecutionSpace>(space, 0, n_queries),
KOKKOS_LAMBDA(int q) {
for (int i = offset(q); i < offset(q + 1); ++i)
{
export_ids(i) = q;
}
});
Kokkos::View<int *, MemorySpace> import_ids(
Kokkos::view_alloc(
space, Kokkos::WithoutInitializing,
"ArborX::DistributedTree::query::forwardQueries::import_ids"),
n_imports);

distributor.doPostsAndWaits(space, export_ids, import_ids);
fwd_ids = import_ids;
KokkosExt::reallocWithoutInitializing(space, fwd_ids, n_imports);
distributor.doPostsAndWaits(space, export_ids, fwd_ids);
}
}

// Reduced ("partial") variant of forwardQueries: only the queries are
// communicated, with no ids and no ranks.
//
// comm        - communicator handed to the Distributor.
// space       - execution space used for allocation, the fill kernel and the
//               communication calls.
// queries     - local queries; queries(q) is copied once for each i in
//               [offset(q), offset(q + 1)).
// indices     - passed to Distributor::createFromSends (presumably the
//               destination rank for each export slot; confirm against
//               Distributor).
// offset      - per-query ranges into the export buffer; its last element is
//               used as the total number of exports.
// fwd_queries - output; reallocated to the import count returned by
//               createFromSends and filled by doPostsAndWaits.
template <typename ExecutionSpace, typename Predicates, typename Indices,
          typename Offset, typename FwdQueries>
void forwardQueries(MPI_Comm comm, ExecutionSpace const &space,
                    Predicates const &queries, Indices const &indices,
                    Offset const &offset, FwdQueries &fwd_queries)
{
  using MemorySpace = typename Predicates::memory_space;
  using Query = typename Predicates::value_type;

  // Shared label for the profiling region, view name and kernel name.
  std::string const label =
      "ArborX::DistributedTree::query::forwardQueries(partial)";
  Kokkos::Profiling::ScopedRegion guard(label);

  Distributor<MemorySpace> dist(comm);

  int const num_exports = KokkosExt::lastElement(space, offset);
  int const num_imports = dist.createFromSends(space, indices);

  // Send buffer: every slot is written by the kernel below, so it is
  // allocated without initialization.
  Kokkos::View<Query *, MemorySpace> send_buffer(
      Kokkos::view_alloc(space, Kokkos::WithoutInitializing,
                         label + "::export_queries"),
      num_exports);

  Kokkos::RangePolicy<ExecutionSpace> const policy(space, 0, queries.size());
  Kokkos::parallel_for(
      label + "::forward_queries_fill_buffer", policy, KOKKOS_LAMBDA(int q) {
        int const first = offset(q);
        int const last = offset(q + 1);
        for (int slot = first; slot < last; ++slot)
        {
          send_buffer(slot) = queries(q);
        }
      });

  // Receive straight into the caller's view.
  KokkosExt::reallocWithoutInitializing(space, fwd_queries, num_imports);
  dist.doPostsAndWaits(space, send_buffer, fwd_queries);
}

template <typename ExecutionSpace, typename OutputView, typename Offset,
typename Ranks, typename Ids>
void communicateResultsBack(MPI_Comm comm, ExecutionSpace const &space,
Expand Down
63 changes: 63 additions & 0 deletions test/tstDistributedTreeSpatial.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ struct CustomPostCallbackWithAttachment
});
}
};

BOOST_AUTO_TEST_CASE_TEMPLATE(callback_with_attachment, DeviceType,
ARBORX_DEVICE_TYPES)
{
Expand Down Expand Up @@ -376,6 +377,68 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(callback_with_attachment, DeviceType,
}
}

// Callback that produces no output values: each invocation atomically
// increments counts(index). The query argument is ignored.
template <typename DeviceType>
struct CustomPureInlineCallback
{
  // Counters, indexed by the `index` passed to operator().
  Kokkos::View<int *, DeviceType> counts;

  template <typename Query>
  KOKKOS_FUNCTION void operator()(Query const &, int index) const
  {
    // Atomic because several invocations may hit the same entry concurrently.
    auto *const counter = &counts(index);
    Kokkos::atomic_inc(counter);
  }
};

// Exercises the callback-only DistributedTree::query overload: no results are
// returned, so the test observes the callback through a per-rank counter.
BOOST_AUTO_TEST_CASE_TEMPLATE(pure_spatial_callback, DeviceType,
                              ARBORX_DEVICE_TYPES)
{
  using ExecutionSpace = typename DeviceType::execution_space;

  MPI_Comm comm = MPI_COMM_WORLD;
  int comm_rank;
  MPI_Comm_rank(comm, &comm_rank);

  // Lower x-coordinate of the single box owned by this rank.
  auto const x_lo = static_cast<float>(comm_rank);

  // Each rank owns one unit box, [rank, rank + 1] along x:
  //
  //  +----------0----------1----------2----------3
  //  |          |          |          |          |
  //  |          |          |          |          |
  //  |          |          |          |          |
  //  |          |          |          |          |
  //  0----------1----------2----------3----------+
  //  [  rank 0  ]
  //             [  rank 1  ]
  //                        [  rank 2  ]
  //                                   [  rank 3  ]
  auto const tree = makeDistributedTree<DeviceType>(
      comm, {{{{x_lo, 0., 0.}}, {{x_lo + 1, 1., 1.}}}});

  // Each rank issues one point query at x = rank + 1.5, i.e. in the middle of
  // the box owned by the next rank (outside every box for the last rank):
  //
  //  +----------0----------1----------2----------3
  //  |          |          |          |          |
  //  |          |          |          |          |
  //  |          |          |          |          |
  //  |          |          |          |          |
  //  0----------1-----x----2-----x----3-----x----+     x
  //                   ^          ^          ^          ^
  //                   0          1          2          3
  Kokkos::View<decltype(ArborX::intersects(ArborX::Point{})) *, DeviceType>
      queries("Testing::queries", 1);
  auto queries_host = Kokkos::create_mirror_view(queries);
  queries_host(0) = ArborX::intersects(ArborX::Point{1.5f + x_lo, 0, 0});
  Kokkos::deep_copy(queries, queries_host);

  Kokkos::View<int *, DeviceType> counts("Testing::counts", queries.size());
  tree.query(ExecutionSpace{}, queries,
             CustomPureInlineCallback<DeviceType>{counts});
  auto counts_host =
      Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace{}, counts);

  // The callback runs where the matching box lives, not where the query was
  // issued: rank r's box is hit by the query from rank r - 1, so every rank
  // except rank 0 sees exactly one invocation.
  std::vector<int> counts_ref;
  counts_ref.push_back(comm_rank > 0 ? 1 : 0);

  BOOST_TEST(counts_host == counts_ref, tt::per_element());
}

BOOST_AUTO_TEST_CASE_TEMPLATE(boost_comparison, DeviceType, ARBORX_DEVICE_TYPES)
{
using ExecutionSpace = typename DeviceType::execution_space;
Expand Down

0 comments on commit 3e8d9ee

Please sign in to comment.