Skip to content

Commit

Permalink
Applied pre-commit formatting rules
Browse files Browse the repository at this point in the history
  • Loading branch information
antonwolfy committed May 16, 2024
1 parent c8425db commit 1f89cc5
Show file tree
Hide file tree
Showing 22 changed files with 460 additions and 616 deletions.
1 change: 0 additions & 1 deletion dpnp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ build_dpnp_cython_ext_with_backend(dparray ${CMAKE_CURRENT_SOURCE_DIR}/dparray.p
add_subdirectory(backend)
add_subdirectory(backend/extensions/blas)
add_subdirectory(backend/extensions/lapack)
add_subdirectory(backend/extensions/rng)
add_subdirectory(backend/extensions/rng/device)
add_subdirectory(backend/extensions/vm)
add_subdirectory(backend/extensions/sycl_ext)
Expand Down
74 changes: 0 additions & 74 deletions dpnp/backend/extensions/rng/CMakeLists.txt

This file was deleted.

46 changes: 32 additions & 14 deletions dpnp/backend/extensions/rng/device/common_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@

#include <pybind11/pybind11.h>

#include <sycl/sycl.hpp>
#include <oneapi/mkl/rng/device.hpp>
#include <sycl/sycl.hpp>

namespace dpnp
{
Expand Down Expand Up @@ -57,11 +57,14 @@ struct RngContigFunctor

EngineBuilderT engine_;
DistributorBuilderT distr_;
DataT * const res_ = nullptr;
DataT *const res_ = nullptr;
const std::size_t nelems_;

public:
// Constructs the functor by initializing each data member from the matching
// argument: engine_ from engine, distr_ from distr, res_ from res and nelems_
// from n_elems. The constructor body itself does no further work.
// NOTE(review): res is presumably the output buffer that the call operator
// writes into and n_elems its length in elements -- confirm against the caller.
RngContigFunctor(EngineBuilderT& engine, DistributorBuilderT& distr, DataT *res, const std::size_t n_elems)
RngContigFunctor(EngineBuilderT &engine,
DistributorBuilderT &distr,
DataT *res,
const std::size_t n_elems)
: engine_(engine), distr_(distr), res_(res), nelems_(n_elems)
{
}
Expand All @@ -82,31 +85,46 @@ struct RngContigFunctor
DistrT distr = distr_();

if constexpr (enable_sg_load) {
const std::size_t base = vi_per_wi * (nd_it.get_group(0) * nd_it.get_local_range(0) + sg.get_group_id()[0] * max_sg_size);
const std::size_t base =
vi_per_wi * (nd_it.get_group(0) * nd_it.get_local_range(0) +
sg.get_group_id()[0] * max_sg_size);

if ((sg_size == max_sg_size) && (base + vi_per_wi * sg_size < nelems_)) {
if ((sg_size == max_sg_size) &&
(base + vi_per_wi * sg_size < nelems_)) {
#pragma unroll
for (std::uint16_t it = 0; it < vi_per_wi; it += vec_sz) {
std::size_t offset = base + static_cast<std::size_t>(it) * static_cast<std::size_t>(sg_size);
auto out_multi_ptr = sycl::address_space_cast<sycl::access::address_space::global_space, sycl::access::decorated::yes>(&res_[offset]);

sycl::vec<DataT, vec_sz> rng_val_vec = mkl_rng_dev::generate<DistrT, EngineT>(distr, engine);
std::size_t offset =
base + static_cast<std::size_t>(it) *
static_cast<std::size_t>(sg_size);
auto out_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&res_[offset]);

sycl::vec<DataT, vec_sz> rng_val_vec =
mkl_rng_dev::generate<DistrT, EngineT>(distr, engine);
sg.store<vec_sz>(out_multi_ptr, rng_val_vec);
}
}
else {
for (std::size_t offset = base + sg.get_local_id()[0]; offset < nelems_; offset += sg_size) {
res_[offset] = mkl_rng_dev::generate_single<DistrT, EngineT>(distr, engine);
for (std::size_t offset = base + sg.get_local_id()[0];
offset < nelems_; offset += sg_size)
{
res_[offset] =
mkl_rng_dev::generate_single<DistrT, EngineT>(distr,
engine);
}
}
}
else {
std::size_t base = nd_it.get_global_linear_id();

base = (base / sg_size) * sg_size * vi_per_wi + (base % sg_size);
for (std::size_t offset = base; offset < std::min(nelems_, base + sg_size * vi_per_wi); offset += sg_size)
for (std::size_t offset = base;
offset < std::min(nelems_, base + sg_size * vi_per_wi);
offset += sg_size)
{
res_[offset] = mkl_rng_dev::generate_single<DistrT, EngineT>(distr, engine);
res_[offset] = mkl_rng_dev::generate_single<DistrT, EngineT>(
distr, engine);
}
}
}
Expand All @@ -116,4 +134,4 @@ struct RngContigFunctor
} // namespace rng
} // namespace ext
} // namespace backend
} // namespace dpnp
} // namespace dpnp
28 changes: 20 additions & 8 deletions dpnp/backend/extensions/rng/device/dispatch/matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@

#include "utils/type_dispatch.hpp"


namespace dpnp::backend::ext::rng::device::dispatch
{
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
namespace mkl_rng_dev = oneapi::mkl::rng::device;

// Compile-time entry describing one supported (data type, method) combination.
// The inherited std::bool_constant value is true only when Ty is the same type
// as ArgTy and Method is the same type as argMethod. The is_defined member is
// unconditionally true, so it is meaningful only once the entry has been
// selected on the basis of its bool_constant value.
template <typename Ty, typename ArgTy, typename Method, typename argMethod>
struct TypePairDefinedEntry : std::bool_constant<std::is_same_v<Ty, ArgTy> &&
std::is_same_v<Method, argMethod>>
struct TypePairDefinedEntry
: std::bool_constant<std::is_same_v<Ty, ArgTy> &&
std::is_same_v<Method, argMethod>>
{
static constexpr bool is_defined = true;
};
Expand All @@ -46,11 +46,23 @@ template <typename T, typename M>
struct GaussianTypePairSupportFactory
{
// is_defined is read from the entry that std::disjunction selects: the first
// listed entry whose bool_constant value is true, or the last entry when none
// of the earlier ones is. The explicit entries cover T in {double, float}
// combined with M in {gaussian_method::by_default,
// gaussian_method::box_muller2}; every other (T, M) pair ends up at
// dpctl_td_ns::NotDefinedEntry.
// NOTE(review): NotDefinedEntry is declared outside this view -- presumably
// its is_defined is false; confirm in utils/type_dispatch.hpp.
static constexpr bool is_defined = std::disjunction<
TypePairDefinedEntry<T, double, M, mkl_rng_dev::gaussian_method::by_default>,
TypePairDefinedEntry<T, double, M, mkl_rng_dev::gaussian_method::box_muller2>,
TypePairDefinedEntry<T, float, M, mkl_rng_dev::gaussian_method::by_default>,
TypePairDefinedEntry<T, float, M, mkl_rng_dev::gaussian_method::box_muller2>,
TypePairDefinedEntry<T,
double,
M,
mkl_rng_dev::gaussian_method::by_default>,
TypePairDefinedEntry<T,
double,
M,
mkl_rng_dev::gaussian_method::box_muller2>,
TypePairDefinedEntry<T,
float,
M,
mkl_rng_dev::gaussian_method::by_default>,
TypePairDefinedEntry<T,
float,
M,
mkl_rng_dev::gaussian_method::box_muller2>,
// fall-through
dpctl_td_ns::NotDefinedEntry>::is_defined;
};
} // dpnp::backend::ext::rng::device::dispatch
} // namespace dpnp::backend::ext::rng::device::dispatch
55 changes: 30 additions & 25 deletions dpnp/backend/extensions/rng/device/dispatch/table_builder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@

#include <oneapi/mkl/rng/device.hpp>


namespace dpnp::backend::ext::rng::device::dispatch
{
namespace mkl_rng_dev = oneapi::mkl::rng::device;

template <typename funcPtrT,
template <typename fnT, typename E, typename T, typename M> typename factory,
template <typename fnT, typename E, typename T, typename M>
typename factory,
int _no_of_engines,
int _no_of_types,
int _no_of_methods>
Expand All @@ -44,8 +44,10 @@ class Dispatch3DTableBuilder
const std::vector<funcPtrT> row_per_method() const
{
std::vector<funcPtrT> per_method = {
factory<funcPtrT, E, T, mkl_rng_dev::gaussian_method::by_default>{}.get(),
factory<funcPtrT, E, T, mkl_rng_dev::gaussian_method::box_muller2>{}.get(),
factory<funcPtrT, E, T, mkl_rng_dev::gaussian_method::by_default>{}
.get(),
factory<funcPtrT, E, T, mkl_rng_dev::gaussian_method::box_muller2>{}
.get(),
};
assert(per_method.size() == _no_of_methods);
return per_method;
Expand All @@ -54,21 +56,21 @@ class Dispatch3DTableBuilder
// Builds the per-type table for engine type E: one row_per_method<E, T>() row
// for each element type T, in the order listed below (bool, the 8/16/32/64-bit
// signed and unsigned integers, sycl::half, float, double, std::complex<float>,
// std::complex<double>). Returns the resulting vector of rows by value.
template <typename E>
auto table_per_type_and_method() const
{
std::vector<std::vector<funcPtrT>>
table_by_type = {row_per_method<E, bool>(),
row_per_method<E, int8_t>(),
row_per_method<E, uint8_t>(),
row_per_method<E, int16_t>(),
row_per_method<E, uint16_t>(),
row_per_method<E, int32_t>(),
row_per_method<E, uint32_t>(),
row_per_method<E, int64_t>(),
row_per_method<E, uint64_t>(),
row_per_method<E, sycl::half>(),
row_per_method<E, float>(),
row_per_method<E, double>(),
row_per_method<E, std::complex<float>>(),
row_per_method<E, std::complex<double>>()};
std::vector<std::vector<funcPtrT>> table_by_type = {
row_per_method<E, bool>(),
row_per_method<E, int8_t>(),
row_per_method<E, uint8_t>(),
row_per_method<E, int16_t>(),
row_per_method<E, uint16_t>(),
row_per_method<E, int32_t>(),
row_per_method<E, uint32_t>(),
row_per_method<E, int64_t>(),
row_per_method<E, uint64_t>(),
row_per_method<E, sycl::half>(),
row_per_method<E, float>(),
row_per_method<E, double>(),
row_per_method<E, std::complex<float>>(),
row_per_method<E, std::complex<double>>()};
// The number of rows listed above must agree with the _no_of_types template
// parameter; the check is active only in builds where assert is enabled.
assert(table_by_type.size() == _no_of_types);
return table_by_type;
}
Expand All @@ -78,12 +80,15 @@ class Dispatch3DTableBuilder
~Dispatch3DTableBuilder() = default;

template <std::uint8_t... VecSizes>
void populate(funcPtrT table[][_no_of_types][_no_of_methods], std::integer_sequence<std::uint8_t, VecSizes...>) const
void populate(funcPtrT table[][_no_of_types][_no_of_methods],
std::integer_sequence<std::uint8_t, VecSizes...>) const
{
const auto map_by_engine = {table_per_type_and_method<mkl_rng_dev::mrg32k3a<VecSizes>>()...,
table_per_type_and_method<mkl_rng_dev::philox4x32x10<VecSizes>>()...,
table_per_type_and_method<mkl_rng_dev::mcg31m1<VecSizes>>()...,
table_per_type_and_method<mkl_rng_dev::mcg59<VecSizes>>()...};
const auto map_by_engine = {
table_per_type_and_method<mkl_rng_dev::mrg32k3a<VecSizes>>()...,
table_per_type_and_method<
mkl_rng_dev::philox4x32x10<VecSizes>>()...,
table_per_type_and_method<mkl_rng_dev::mcg31m1<VecSizes>>()...,
table_per_type_and_method<mkl_rng_dev::mcg59<VecSizes>>()...};
assert(map_by_engine.size() == _no_of_engines);

std::uint16_t engine_id = 0;
Expand All @@ -101,4 +106,4 @@ class Dispatch3DTableBuilder
}
}
};
} // dpnp::backend::ext::rng::device::dispatch
} // namespace dpnp::backend::ext::rng::device::dispatch
Loading

0 comments on commit 1f89cc5

Please sign in to comment.