Skip to content

Commit

Permalink
Added uniform distribution
Browse files Browse the repository at this point in the history
  • Loading branch information
antonwolfy committed May 23, 2024
1 parent 9d9540e commit 199f13b
Show file tree
Hide file tree
Showing 7 changed files with 387 additions and 49 deletions.
1 change: 1 addition & 0 deletions dpnp/backend/extensions/rng/device/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ set(python_module_name _rng_dev_impl)
pybind11_add_module(${python_module_name} MODULE
rng_py.cpp
gaussian.cpp
uniform.cpp
)

if (WIN32)
Expand Down
19 changes: 2 additions & 17 deletions dpnp/backend/extensions/rng/device/common_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,7 @@
#include <oneapi/mkl/rng/device.hpp>
#include <sycl/sycl.hpp>

namespace dpnp
{
namespace backend
{
namespace ext
{
namespace rng
{
namespace device
{
namespace details
namespace dpnp::backend::ext::rng::device::details
{
namespace py = pybind11;

Expand Down Expand Up @@ -129,9 +119,4 @@ struct RngContigFunctor
}
}
};
} // namespace details
} // namespace device
} // namespace rng
} // namespace ext
} // namespace backend
} // namespace dpnp
} // namespace dpnp::backend::ext::rng::device::details
24 changes: 20 additions & 4 deletions dpnp/backend/extensions/rng/device/dispatch/matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,35 @@ struct GaussianTypePairSupportFactory
TypePairDefinedEntry<T,
double,
M,
mkl_rng_dev::gaussian_method::by_default>,
mkl_rng_dev::gaussian_method::box_muller2>,
TypePairDefinedEntry<T,
double,
float,
M,
mkl_rng_dev::gaussian_method::box_muller2>,
// fall-through
dpctl_td_ns::NotDefinedEntry>::is_defined;
};

template <typename T, typename M>
struct UniformTypePairSupportFactory
{
static constexpr bool is_defined = std::disjunction<
TypePairDefinedEntry<T,
double,
M,
mkl_rng_dev::uniform_method::standard>,
TypePairDefinedEntry<T,
double,
M,
mkl_rng_dev::uniform_method::accurate>,
TypePairDefinedEntry<T,
float,
M,
mkl_rng_dev::gaussian_method::by_default>,
mkl_rng_dev::uniform_method::standard>,
TypePairDefinedEntry<T,
float,
M,
mkl_rng_dev::gaussian_method::box_muller2>,
mkl_rng_dev::uniform_method::accurate>,
// fall-through
dpctl_td_ns::NotDefinedEntry>::is_defined;
};
Expand Down
48 changes: 22 additions & 26 deletions dpnp/backend/extensions/rng/device/dispatch/table_builder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,37 +40,34 @@ template <typename funcPtrT,
class Dispatch3DTableBuilder
{
private:
template <typename E, typename T>
template <typename E, typename T, typename... Methods>
const std::vector<funcPtrT> row_per_method() const
{
std::vector<funcPtrT> per_method = {
factory<funcPtrT, E, T, mkl_rng_dev::gaussian_method::by_default>{}
.get(),
factory<funcPtrT, E, T, mkl_rng_dev::gaussian_method::box_muller2>{}
.get(),
factory<funcPtrT, E, T, Methods>{}.get()...,
};
assert(per_method.size() == _no_of_methods);
return per_method;
}

template <typename E>
template <typename E, typename... Methods>
auto table_per_type_and_method() const
{
std::vector<std::vector<funcPtrT>> table_by_type = {
row_per_method<E, bool>(),
row_per_method<E, int8_t>(),
row_per_method<E, uint8_t>(),
row_per_method<E, int16_t>(),
row_per_method<E, uint16_t>(),
row_per_method<E, int32_t>(),
row_per_method<E, uint32_t>(),
row_per_method<E, int64_t>(),
row_per_method<E, uint64_t>(),
row_per_method<E, sycl::half>(),
row_per_method<E, float>(),
row_per_method<E, double>(),
row_per_method<E, std::complex<float>>(),
row_per_method<E, std::complex<double>>()};
row_per_method<E, bool, Methods...>(),
row_per_method<E, int8_t, Methods...>(),
row_per_method<E, uint8_t, Methods...>(),
row_per_method<E, int16_t, Methods...>(),
row_per_method<E, uint16_t, Methods...>(),
row_per_method<E, int32_t, Methods...>(),
row_per_method<E, uint32_t, Methods...>(),
row_per_method<E, int64_t, Methods...>(),
row_per_method<E, uint64_t, Methods...>(),
row_per_method<E, sycl::half, Methods...>(),
row_per_method<E, float, Methods...>(),
row_per_method<E, double, Methods...>(),
row_per_method<E, std::complex<float>, Methods...>(),
row_per_method<E, std::complex<double>, Methods...>()};
assert(table_by_type.size() == _no_of_types);
return table_by_type;
}
Expand All @@ -79,16 +76,15 @@ class Dispatch3DTableBuilder
Dispatch3DTableBuilder() = default;
~Dispatch3DTableBuilder() = default;

template <std::uint8_t... VecSizes>
template <typename... Methods, std::uint8_t... VecSizes>
void populate(funcPtrT table[][_no_of_types][_no_of_methods],
std::integer_sequence<std::uint8_t, VecSizes...>) const
{
const auto map_by_engine = {
table_per_type_and_method<mkl_rng_dev::mrg32k3a<VecSizes>>()...,
table_per_type_and_method<
mkl_rng_dev::philox4x32x10<VecSizes>>()...,
table_per_type_and_method<mkl_rng_dev::mcg31m1<VecSizes>>()...,
table_per_type_and_method<mkl_rng_dev::mcg59<VecSizes>>()...};
table_per_type_and_method<mkl_rng_dev::mrg32k3a<VecSizes>, Methods...>()...,
table_per_type_and_method<mkl_rng_dev::philox4x32x10<VecSizes>, Methods...>()...,
table_per_type_and_method<mkl_rng_dev::mcg31m1<VecSizes>, Methods...>()...,
table_per_type_and_method<mkl_rng_dev::mcg59<VecSizes>, Methods...>()...};
assert(map_by_engine.size() == _no_of_engines);

std::uint16_t engine_id = 0;
Expand Down
4 changes: 2 additions & 2 deletions dpnp/backend/extensions/rng/device/gaussian.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ using dpctl_krn_ns::disabled_sg_loadstore_wrapper_krn;
using dpctl_krn_ns::is_aligned;
using dpctl_krn_ns::required_alignment;

constexpr auto no_of_methods = 2; // number of methods of gaussian distribution
constexpr auto no_of_methods = 1; // number of methods of gaussian distribution

constexpr auto seq_of_vec_sizes =
std::integer_sequence<std::uint8_t, 2, 4, 8, 16>{};
Expand Down Expand Up @@ -291,6 +291,6 @@ void init_gaussian_dispatch_3d_table(void)
GaussianContigFactory, no_of_engines,
dpctl_td_ns::num_types, no_of_methods>
contig;
contig.populate(gaussian_dispatch_table, seq_of_vec_sizes);
contig.populate<mkl_rng_dev::gaussian_method::box_muller2>(gaussian_dispatch_table, seq_of_vec_sizes);
}
} // namespace dpnp::backend::ext::rng::device
Loading

0 comments on commit 199f13b

Please sign in to comment.