Skip to content

Commit

Permalink
Switch NearestBufferProvider to non-hardcoded float
Browse files Browse the repository at this point in the history
  • Loading branch information
aprokop committed Dec 3, 2024
1 parent 51b8677 commit 3eded6d
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 11 deletions.
6 changes: 4 additions & 2 deletions src/spatial/detail/ArborX_BruteForceImpl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,9 @@ struct BruteForceImpl
int const n_indexables = values.size();
int const n_predicates = predicates.size();

NearestBufferProvider<MemorySpace> buffer_provider(space, predicates);
using Coordinate = decltype(predicates(0).distance(indexables(0)));
NearestBufferProvider<MemorySpace, Coordinate> buffer_provider(space,
predicates);

Kokkos::parallel_for(
"ArborX::BruteForce::query::nearest::"
Expand All @@ -168,7 +170,7 @@ struct BruteForceImpl
return;

using PairIndexDistance =
typename NearestBufferProvider<MemorySpace>::PairIndexDistance;
typename decltype(buffer_provider)::PairIndexDistance;
struct CompareDistance
{
KOKKOS_INLINE_FUNCTION bool
Expand Down
4 changes: 2 additions & 2 deletions src/spatial/detail/ArborX_NearestBufferProvider.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
namespace ArborX::Details
{

template <typename MemorySpace>
template <typename MemorySpace, typename Coordinate>
struct NearestBufferProvider
{
static_assert(Kokkos::is_memory_space_v<MemorySpace>);

using PairIndexDistance = Kokkos::pair<int, float>;
using PairIndexDistance = Kokkos::pair<int, Coordinate>;

Kokkos::View<PairIndexDistance *, MemorySpace> _buffer;
Kokkos::View<int *, MemorySpace> _offset;
Expand Down
14 changes: 7 additions & 7 deletions src/spatial/detail/ArborX_TreeTraversal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,10 @@ struct TreeTraversal<BVH, Predicates, Callback, NearestPredicateTag>
Predicates _predicates;
Callback _callback;

NearestBufferProvider<MemorySpace> _buffer;
using Coordinate = decltype(std::declval<Predicates>()(0).distance(
HappyTreeFriends::getIndexable(_bvh, 0)));

NearestBufferProvider<MemorySpace, Coordinate> _buffer;

template <typename ExecutionSpace>
TreeTraversal(ExecutionSpace const &space, BVH const &bvh,
Expand All @@ -151,7 +154,8 @@ struct TreeTraversal<BVH, Predicates, Callback, NearestPredicateTag>
}
else
{
_buffer = NearestBufferProvider<MemorySpace>(space, predicates);
_buffer =
NearestBufferProvider<MemorySpace, Coordinate>(space, predicates);

Kokkos::parallel_for("ArborX::TreeTraversal::nearest",
Kokkos::RangePolicy(space, 0, predicates.size()),
Expand Down Expand Up @@ -184,8 +188,7 @@ struct TreeTraversal<BVH, Predicates, Callback, NearestPredicateTag>
if (k < 1)
return;

using PairIndexDistance =
typename NearestBufferProvider<MemorySpace>::PairIndexDistance;
using PairIndexDistance = typename decltype(_buffer)::PairIndexDistance;
struct CompareDistance
{
KOKKOS_INLINE_FUNCTION bool operator()(PairIndexDistance const &lhs,
Expand All @@ -212,9 +215,6 @@ struct TreeTraversal<BVH, Predicates, Callback, NearestPredicateTag>
HappyTreeFriends::getInternalBoundingVolume(bvh, j));
};

using Coordinate =
decltype(predicate.distance(HappyTreeFriends::getIndexable(bvh, 0)));

constexpr int SENTINEL = -1;
int stack[64];
auto *stack_ptr = stack;
Expand Down

0 comments on commit 3eded6d

Please sign in to comment.