Skip to content

Commit

Permalink
Merge pull request #300 from dalg24/access_traits
Browse files Browse the repository at this point in the history
  • Loading branch information
aprokop authored Jun 10, 2020
2 parents b7342ad + 251d784 commit 4b8fe6f
Show file tree
Hide file tree
Showing 17 changed files with 135 additions and 133 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -186,10 +186,8 @@ struct RadiusSearches

namespace ArborX
{
namespace Traits
{
template <typename DeviceType>
struct Access<RadiusSearches<DeviceType>, PredicatesTag>
struct AccessTraits<RadiusSearches<DeviceType>, PredicatesTag>
{
using memory_space = typename DeviceType::memory_space;
static std::size_t size(RadiusSearches<DeviceType> const &pred)
Expand All @@ -203,7 +201,7 @@ struct Access<RadiusSearches<DeviceType>, PredicatesTag>
}
};
template <typename DeviceType>
struct Access<NearestNeighborsSearches<DeviceType>, PredicatesTag>
struct AccessTraits<NearestNeighborsSearches<DeviceType>, PredicatesTag>
{
using memory_space = typename DeviceType::memory_space;
static std::size_t size(NearestNeighborsSearches<DeviceType> const &pred)
Expand All @@ -216,7 +214,6 @@ struct Access<NearestNeighborsSearches<DeviceType>, PredicatesTag>
return nearest(pred.points(i), pred.k);
}
};
} // namespace Traits
} // namespace ArborX

namespace bpo = boost::program_options;
Expand Down
7 changes: 2 additions & 5 deletions examples/access_traits/example_cuda_access_traits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,8 @@ struct Spheres

namespace ArborX
{
namespace Traits
{
template <>
struct Access<PointCloud, PrimitivesTag>
struct AccessTraits<PointCloud, PrimitivesTag>
{
static std::size_t size(PointCloud const &cloud) { return cloud.N; }
KOKKOS_FUNCTION static Point get(PointCloud const &cloud, std::size_t i)
Expand All @@ -50,7 +48,7 @@ struct Access<PointCloud, PrimitivesTag>
};

template <>
struct Access<Spheres, PredicatesTag>
struct AccessTraits<Spheres, PredicatesTag>
{
static std::size_t size(Spheres const &d) { return d.N; }
KOKKOS_FUNCTION static auto get(Spheres const &d, std::size_t i)
Expand All @@ -59,7 +57,6 @@ struct Access<Spheres, PredicatesTag>
}
using memory_space = Kokkos::CudaSpace;
};
} // namespace Traits
} // namespace ArborX

int main(int argc, char *argv[])
Expand Down
5 changes: 1 addition & 4 deletions examples/access_traits/example_host_access_traits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@

namespace ArborX
{
namespace Traits
{
template <typename T, typename Tag>
struct Access<std::vector<T>, Tag>
struct AccessTraits<std::vector<T>, Tag>
{
static std::size_t size(std::vector<T> const &v) { return v.size(); }
KOKKOS_FUNCTION static T const &get(std::vector<T> const &v, std::size_t i)
Expand All @@ -31,7 +29,6 @@ struct Access<std::vector<T>, Tag>
}
using memory_space = Kokkos::HostSpace;
};
} // namespace Traits
} // namespace ArborX

int main(int argc, char *argv[])
Expand Down
7 changes: 2 additions & 5 deletions examples/callback/example_callback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,8 @@ struct NearestToOrigin

namespace ArborX
{
namespace Traits
{
template <>
struct Access<FirstOctant, PredicatesTag>
struct AccessTraits<FirstOctant, PredicatesTag>
{
KOKKOS_FUNCTION static std::size_t size(FirstOctant) { return 1; }
KOKKOS_FUNCTION static auto get(FirstOctant, std::size_t)
Expand All @@ -44,7 +42,7 @@ struct Access<FirstOctant, PredicatesTag>
using memory_space = MemorySpace;
};
template <>
struct Access<NearestToOrigin, PredicatesTag>
struct AccessTraits<NearestToOrigin, PredicatesTag>
{
KOKKOS_FUNCTION static std::size_t size(NearestToOrigin) { return 1; }
KOKKOS_FUNCTION static auto get(NearestToOrigin d, std::size_t)
Expand All @@ -53,7 +51,6 @@ struct Access<NearestToOrigin, PredicatesTag>
}
using memory_space = MemorySpace;
};
} // namespace Traits
} // namespace ArborX

struct PairIndexDistance
Expand Down
4 changes: 2 additions & 2 deletions src/ArborX_DistributedSearchTree.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ class DistributedSearchTree
Args &&... args) const
{
static_assert(Kokkos::is_execution_space<ExecutionSpace>::value, "");
using Access = Traits::Access<Predicates, Traits::PredicatesTag>;
using Tag = typename Traits::Helper<Access>::tag;
using Access = AccessTraits<Predicates, PredicatesTag>;
using Tag = typename Details::AccessTraitsHelper<Access>::tag;
using DeviceType = Kokkos::Device<ExecutionSpace, MemorySpace>;
Details::DistributedSearchTreeImpl<DeviceType>::queryDispatch(
Tag{}, *this, space, predicates, std::forward<Args>(args)...);
Expand Down
12 changes: 6 additions & 6 deletions src/ArborX_LinearBVH.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,16 @@ class BoundingVolumeHierarchy
CallbackOrView &&callback_or_view, View &&view,
Args &&... args) const
{
Details::check_valid_access_traits(Traits::PredicatesTag{}, predicates);
using Access = Traits::Access<Predicates, Traits::PredicatesTag>;
Details::check_valid_access_traits(PredicatesTag{}, predicates);
using Access = AccessTraits<Predicates, PredicatesTag>;
static_assert(KokkosExt::is_accessible_from<typename Access::memory_space,
ExecutionSpace>::value,
"Predicates must be accessible from the execution space");

Details::check_valid_callback_if_first_argument_is_not_a_view(
callback_or_view, predicates, view);

using Tag = typename Traits::Helper<Access>::tag;
using Tag = typename Details::AccessTraitsHelper<Access>::tag;

Details::BoundingVolumeHierarchyImpl::queryDispatch(
Tag{}, *this, space, predicates,
Expand Down Expand Up @@ -155,15 +155,15 @@ template <typename MemorySpace, typename Enable>
template <typename ExecutionSpace, typename Primitives>
BoundingVolumeHierarchy<MemorySpace, Enable>::BoundingVolumeHierarchy(
ExecutionSpace const &space, Primitives const &primitives)
: _size(Traits::Access<Primitives, Traits::PrimitivesTag>::size(primitives))
: _size(AccessTraits<Primitives, PrimitivesTag>::size(primitives))
, _internal_and_leaf_nodes(
Kokkos::ViewAllocateWithoutInitializing("internal_and_leaf_nodes"),
_size > 0 ? 2 * _size - 1 : 0)
{
Kokkos::Profiling::pushRegion("ArborX:BVH:construction");

Details::check_valid_access_traits(Traits::PrimitivesTag{}, primitives);
using Access = Traits::Access<Primitives, Traits::PrimitivesTag>;
Details::check_valid_access_traits(PrimitivesTag{}, primitives);
using Access = AccessTraits<Primitives, PrimitivesTag>;
static_assert(KokkosExt::is_accessible_from<typename Access::memory_space,
ExecutionSpace>::value,
"Primitives must be accessible from the execution space");
Expand Down
102 changes: 54 additions & 48 deletions src/details/ArborX_AccessTraits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@

namespace ArborX
{
namespace Traits
{

struct PrimitivesTag
{
Expand All @@ -31,14 +29,19 @@ struct PredicatesTag
{
};

// Only a declaration so that existence of a specialization can be detected
template <typename T, typename Tag, typename Enable = void>
struct Access;
struct AccessTraits
{
using not_specialized = void; // tag to detect existence of a specialization
};

template <typename Traits>
using AccessTraitsNotSpecializedArchetypeAlias =
typename Traits::not_specialized;

template <typename View, typename Tag>
struct Access<View, Tag,
typename std::enable_if<Kokkos::is_view<View>::value &&
View::rank == 1>::type>
struct AccessTraits<
View, Tag, std::enable_if_t<Kokkos::is_view<View>{} && View::rank == 1>>
{
// Returns a const reference
KOKKOS_FUNCTION static typename View::const_value_type &get(View const &v,
Expand All @@ -53,9 +56,8 @@ struct Access<View, Tag,
};

template <typename View, typename Tag>
struct Access<View, Tag,
typename std::enable_if<Kokkos::is_view<View>::value &&
View::rank == 2>::type>
struct AccessTraits<
View, Tag, std::enable_if_t<Kokkos::is_view<View>{} && View::rank == 2>>
{
// Returns by value
KOKKOS_FUNCTION static Point get(View const &v, int i)
Expand All @@ -68,11 +70,6 @@ struct Access<View, Tag,
using memory_space = typename View::memory_space;
};

} // namespace Traits
} // namespace ArborX

namespace ArborX
{
namespace Details
{

Expand All @@ -90,94 +87,103 @@ template <typename Traits>
using AccessTraitsGetArchetypeExpression = decltype(
Traits::get(std::declval<first_template_parameter_t<Traits> const &>(), 0));

} // namespace Details

namespace Traits
{
template <typename Access,
typename = std::enable_if_t<Details::is_complete<Access>{}>>
struct Helper
template <typename Access>
struct AccessTraitsHelper
{
// Deduce return type of get()
using type = std::decay_t<
Details::detected_t<Details::AccessTraitsGetArchetypeExpression, Access>>;
using tag = typename Details::Tag<type>::type;
using type =
std::decay_t<detected_t<AccessTraitsGetArchetypeExpression, Access>>;
using tag = typename Tag<type>::type;
};
} // namespace Traits

namespace Details
{

template <typename Predicates>
void check_valid_access_traits(Traits::PredicatesTag, Predicates const &)
void check_valid_access_traits(PredicatesTag, Predicates const &)
{
using Access = Traits::Access<Predicates, Traits::PredicatesTag>;
using Access = AccessTraits<Predicates, PredicatesTag>;
static_assert(
is_complete<Access>{},
"Must specialize 'Traits::Access<Predicates,Traits::PredicatesTag>'");
!is_detected<AccessTraitsNotSpecializedArchetypeAlias, Access>{},
"Must specialize 'AccessTraits<Predicates,PredicatesTag>'");

static_assert(is_detected<AccessTraitsMemorySpaceArchetypeAlias, Access>{},
"Traits::Access<Predicates,Traits::PredicatesTag> must define "
"AccessTraits<Predicates,PredicatesTag> must define "
"'memory_space' member type");
static_assert(
Kokkos::is_memory_space<
detected_t<AccessTraitsMemorySpaceArchetypeAlias, Access>>{},
"'memory_space' member type must be a valid Kokkos memory space");

static_assert(is_detected<AccessTraitsSizeArchetypeExpression, Access>{},
"Traits::Access<Predicates,Traits::PredicatesTag> must define "
"AccessTraits<Predicates,PredicatesTag> must define "
"'size()' static member function");
static_assert(
std::is_integral<
detected_t<AccessTraitsSizeArchetypeExpression, Access>>{},
"size() static member function return type is not an integral type");

static_assert(is_detected<AccessTraitsGetArchetypeExpression, Access>{},
"Traits::Access<Predicates,Traits::PredicatesTag> must define "
"AccessTraits<Predicates,PredicatesTag> must define "
"'get()' static member function");

using Tag = typename Traits::Helper<Access>::tag;
using Tag = typename AccessTraitsHelper<Access>::tag;
static_assert(std::is_same<Tag, NearestPredicateTag>{} ||
std::is_same<Tag, SpatialPredicateTag>{},
"Invalid tag for the predicates");
}

template <typename Primitives>
void check_valid_access_traits(Traits::PrimitivesTag, Primitives const &)
void check_valid_access_traits(PrimitivesTag, Primitives const &)
{
using Access = Traits::Access<Primitives, Traits::PrimitivesTag>;
using Access = AccessTraits<Primitives, PrimitivesTag>;
static_assert(
is_complete<Access>{},
"Must specialize 'Traits::Access<Primitives,Traits::PrimitivesTag>'");
!is_detected<AccessTraitsNotSpecializedArchetypeAlias, Access>{},
"Must specialize 'AccessTraits<Primitives,PrimitivesTag>'");

static_assert(is_detected<AccessTraitsMemorySpaceArchetypeAlias, Access>{},
"Traits::Access<Primitives,Traits::PrimitivesTag> must define "
"AccessTraits<Primitives,PrimitivesTag> must define "
"'memory_space' member type");
static_assert(
Kokkos::is_memory_space<
detected_t<AccessTraitsMemorySpaceArchetypeAlias, Access>>{},
"'memory_space' member type must be a valid Kokkos memory space");

static_assert(is_detected<AccessTraitsSizeArchetypeExpression, Access>{},
"Traits::Access<Primitives,Traits::PrimitivesTag> must define "
"AccessTraits<Primitives,PrimitivesTag> must define "
"'size()' static member function");
static_assert(
std::is_integral<
detected_t<AccessTraitsSizeArchetypeExpression, Access>>{},
"size() static member function return type is not an integral type");

static_assert(is_detected<AccessTraitsGetArchetypeExpression, Access>{},
"Traits::Access<Primitives,Traits::PrimitivesTag> must define "
"AccessTraits<Primitives,PrimitivesTag> must define "
"'get()' static member function");
using T =
std::decay_t<detected_t<AccessTraitsGetArchetypeExpression, Access>>;
static_assert(
std::is_same<T, Point>{} || std::is_same<T, Box>{},
"Traits::Access<Primitives,Traits::PrimitivesTag>::get() return type "
"must decay to Point or to Box");
static_assert(std::is_same<T, Point>{} || std::is_same<T, Box>{},
"AccessTraits<Primitives,PrimitivesTag>::get() return type "
"must decay to Point or to Box");
}

} // namespace Details

namespace Traits
{
using ::ArborX::PredicatesTag;
using ::ArborX::PrimitivesTag;
template <typename T, typename Tag, typename Enable = void>
struct Access
{
using not_specialized = void;
};
} // namespace Traits
template <typename T, typename Tag>
struct AccessTraits<
T, Tag,
std::enable_if_t<!Details::is_detected<
AccessTraitsNotSpecializedArchetypeAlias, Traits::Access<T, Tag>>{}>>
: Traits::Access<T, Tag>
{
};
} // namespace ArborX

#endif
6 changes: 3 additions & 3 deletions src/details/ArborX_Callbacks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ void check_valid_callback(Callback const &, Predicates const &,
"__host__ __device__ extended lambdas cannot be generic lambdas");
#endif

using Access = Traits::Access<Predicates, Traits::PredicatesTag>;
using PredicateTag = typename Traits::Helper<Access>::tag;
using Predicate = typename Traits::Helper<Access>::type;
using Access = AccessTraits<Predicates, PredicatesTag>;
using PredicateTag = typename AccessTraitsHelper<Access>::tag;
using Predicate = typename AccessTraitsHelper<Access>::type;

static_assert(
(std::is_same<PredicateTag, SpatialPredicateTag>{} &&
Expand Down
12 changes: 5 additions & 7 deletions src/details/ArborX_DetailsBatchedQueries.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ struct BatchedQueries
Box const &scene_bounding_box,
Predicates const &predicates)
{
using Access = Traits::Access<Predicates, Traits::PredicatesTag>;
using Access = AccessTraits<Predicates, PredicatesTag>;
auto const n_queries = Access::size(predicates);

Kokkos::View<unsigned int *, DeviceType> morton_codes(
Expand All @@ -79,13 +79,11 @@ struct BatchedQueries
applyPermutation(ExecutionSpace const &space,
Kokkos::View<unsigned int const *, DeviceType> permute,
Predicates const &v)
-> Kokkos::View<
std::decay_t<
decltype(Traits::Access<Predicates, Traits::PredicatesTag>::get(
std::declval<Predicates const &>(), std::declval<int>()))> *,
DeviceType>
-> Kokkos::View<typename AccessTraitsHelper<
AccessTraits<Predicates, PredicatesTag>>::type *,
DeviceType>
{
using Access = Traits::Access<Predicates, Traits::PredicatesTag>;
using Access = AccessTraits<Predicates, PredicatesTag>;
auto const n = Access::size(v);
ARBORX_ASSERT(permute.extent(0) == n);

Expand Down
Loading

0 comments on commit 4b8fe6f

Please sign in to comment.