From 76d63e49fdd25a02f4a54a4e3c1147f66a4223a6 Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Wed, 27 Mar 2024 12:52:15 +0100 Subject: [PATCH] Added base f/w for host API as an extension --- dpnp/CMakeLists.txt | 1 + .../extensions/rng/host/CMakeLists.txt | 77 ++++++++ .../extensions/rng/host/dispatch/matrix.hpp | 65 +++++++ .../rng/host/dispatch/table_builder.hpp | 106 ++++++++++ dpnp/backend/extensions/rng/host/gaussian.cpp | 181 ++++++++++++++++++ dpnp/backend/extensions/rng/host/gaussian.hpp | 44 +++++ dpnp/backend/extensions/rng/host/rng_py.cpp | 142 ++++++++++++++ 7 files changed, 616 insertions(+) create mode 100644 dpnp/backend/extensions/rng/host/CMakeLists.txt create mode 100644 dpnp/backend/extensions/rng/host/dispatch/matrix.hpp create mode 100644 dpnp/backend/extensions/rng/host/dispatch/table_builder.hpp create mode 100644 dpnp/backend/extensions/rng/host/gaussian.cpp create mode 100644 dpnp/backend/extensions/rng/host/gaussian.hpp create mode 100644 dpnp/backend/extensions/rng/host/rng_py.cpp diff --git a/dpnp/CMakeLists.txt b/dpnp/CMakeLists.txt index b4bdf13abbd..9808d57c825 100644 --- a/dpnp/CMakeLists.txt +++ b/dpnp/CMakeLists.txt @@ -59,6 +59,7 @@ add_subdirectory(backend) add_subdirectory(backend/extensions/blas) add_subdirectory(backend/extensions/lapack) add_subdirectory(backend/extensions/rng/device) +add_subdirectory(backend/extensions/rng/host) add_subdirectory(backend/extensions/vm) add_subdirectory(backend/extensions/sycl_ext) diff --git a/dpnp/backend/extensions/rng/host/CMakeLists.txt b/dpnp/backend/extensions/rng/host/CMakeLists.txt new file mode 100644 index 00000000000..f3db23c2a23 --- /dev/null +++ b/dpnp/backend/extensions/rng/host/CMakeLists.txt @@ -0,0 +1,77 @@ +# ***************************************************************************** +# Copyright (c) 2023, Intel Corporation +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# - Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +# THE POSSIBILITY OF SUCH DAMAGE. +# ***************************************************************************** + + +set(python_module_name _rng_host_impl) +pybind11_add_module(${python_module_name} MODULE + rng_py.cpp + gaussian.cpp +) + +if (WIN32) + if (${CMAKE_VERSION} VERSION_LESS "3.27") + # this is a work-around for target_link_options inserting option after -link option, cause + # linker to ignore it. + set(CMAKE_CXX_LINK_FLAGS "${CMAKE_CXX_LINK_FLAGS} -fsycl-device-code-split=per_kernel") + endif() +endif() + +set_target_properties(${python_module_name} PROPERTIES CMAKE_POSITION_INDEPENDENT_CODE ON) + +target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) +target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/engine) +target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../include) +target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../src) + +target_include_directories(${python_module_name} PUBLIC ${Dpctl_INCLUDE_DIRS}) +target_include_directories(${python_module_name} PUBLIC ${Dpctl_TENSOR_INCLUDE_DIR}) + +if (WIN32) + target_compile_options(${python_module_name} PRIVATE + /clang:-fno-approx-func + /clang:-fno-finite-math-only + ) +else() + target_compile_options(${python_module_name} PRIVATE + -fno-approx-func + -fno-finite-math-only + ) +endif() + +target_link_options(${python_module_name} PUBLIC -fsycl-device-code-split=per_kernel) +if (UNIX) + # this option is support on Linux only + target_link_options(${python_module_name} PUBLIC -fsycl-link-huge-device-code) +endif() + +if (DPNP_GENERATE_COVERAGE) + target_link_options(${python_module_name} PRIVATE -fprofile-instr-generate -fcoverage-mapping) +endif() + +target_link_libraries(${python_module_name} PUBLIC MKL::MKL_DPCPP) + +install(TARGETS ${python_module_name} + DESTINATION "dpnp/backend/extensions/rng/host" +) diff --git a/dpnp/backend/extensions/rng/host/dispatch/matrix.hpp b/dpnp/backend/extensions/rng/host/dispatch/matrix.hpp new file mode 100644 index 00000000000..eb7aa43450f --- /dev/null +++ b/dpnp/backend/extensions/rng/host/dispatch/matrix.hpp @@ -0,0 +1,65 @@ +//***************************************************************************** +// Copyright (c) 2024, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#pragma once + +#include + +#include "utils/type_dispatch.hpp" + +namespace dpnp::backend::ext::rng::host::dispatch +{ +namespace dpctl_td_ns = dpctl::tensor::type_dispatch; +namespace mkl_rng = oneapi::mkl::rng; + +template +struct TypePairDefinedEntry + : std::bool_constant && + std::is_same_v> +{ + static constexpr bool is_defined = true; +}; + +template +struct GaussianTypePairSupportFactory +{ + static constexpr bool is_defined = std::disjunction< + TypePairDefinedEntry, + TypePairDefinedEntry, + TypePairDefinedEntry, + TypePairDefinedEntry, + // fall-through + dpctl_td_ns::NotDefinedEntry>::is_defined; +}; +} // namespace dpnp::backend::ext::rng::host::dispatch diff --git a/dpnp/backend/extensions/rng/host/dispatch/table_builder.hpp b/dpnp/backend/extensions/rng/host/dispatch/table_builder.hpp new file mode 100644 index 00000000000..772fa8fc9f3 --- /dev/null +++ b/dpnp/backend/extensions/rng/host/dispatch/table_builder.hpp @@ -0,0 +1,106 @@ +//***************************************************************************** +// Copyright (c) 2024, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#pragma once + +#include + +namespace dpnp::backend::ext::rng::host::dispatch +{ +namespace mkl_rng = oneapi::mkl::rng; + +template + typename factory, + int _no_of_engines, + int _no_of_types, + int _no_of_methods> +class Dispatch3DTableBuilder +{ +private: + template + const std::vector row_per_method() const + { + std::vector per_method = { + factory{} + .get(), + factory{} + .get(), + }; + assert(per_method.size() == _no_of_methods); + return per_method; + } + + template + auto table_per_type_and_method() const + { + std::vector> table_by_type = { + row_per_method(), + row_per_method(), + row_per_method(), + row_per_method(), + row_per_method(), + row_per_method(), + row_per_method(), + row_per_method(), + row_per_method(), + row_per_method(), + row_per_method(), + row_per_method(), + row_per_method>(), + row_per_method>()}; + assert(table_by_type.size() == _no_of_types); + return table_by_type; + } + +public: + Dispatch3DTableBuilder() = default; + ~Dispatch3DTableBuilder() = default; + + void populate(funcPtrT table[][_no_of_types][_no_of_methods]) const + { + const auto map_by_engine = { + table_per_type_and_method(), + table_per_type_and_method(), + table_per_type_and_method(), + table_per_type_and_method()}; + assert(map_by_engine.size() == _no_of_engines); + + std::uint16_t engine_id = 0; + for (auto &table_by_type : map_by_engine) { + std::uint16_t type_id = 0; + for (auto &row_by_method : table_by_type) { + std::uint16_t method_id = 0; + for (auto &fn_ptr : row_by_method) { + table[engine_id][type_id][method_id] = fn_ptr; + ++method_id; + } + ++type_id; + } + ++engine_id; + } + } +}; +} // namespace dpnp::backend::ext::rng::host::dispatch diff --git a/dpnp/backend/extensions/rng/host/gaussian.cpp b/dpnp/backend/extensions/rng/host/gaussian.cpp new file mode 100644 index 00000000000..d6b4a238702 --- /dev/null +++ b/dpnp/backend/extensions/rng/host/gaussian.cpp @@ -0,0 +1,181 @@ +//***************************************************************************** +// Copyright (c) 2023, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#include + +#include + +// dpctl tensor headers +#include "utils/output_validation.hpp" +#include "utils/type_dispatch.hpp" +#include "utils/type_utils.hpp" + +#include "gaussian.hpp" + +#include "dispatch/matrix.hpp" +#include "dispatch/table_builder.hpp" + +namespace dpnp::backend::ext::rng::host +{ +namespace dpctl_td_ns = dpctl::tensor::type_dispatch; +namespace dpctl_tu_ns = dpctl::tensor::type_utils; +namespace mkl_rng = oneapi::mkl::rng; +namespace py = pybind11; + +constexpr auto no_of_methods = 2; // number of methods of gaussian distribution +constexpr auto no_of_engines = device::engine::no_of_engines; + +typedef sycl::event (*gaussian_impl_fn_ptr_t)( + device::engine::EngineBase *engine, + const double, + const double, + const std::uint64_t, + char *, + const std::vector &); + +static gaussian_impl_fn_ptr_t gaussian_dispatch_table[no_of_engines] + [dpctl_td_ns::num_types] + [no_of_methods]; + +template +static sycl::event gaussian_impl(device::engine::EngineBase *engine, + const double mean_val, + const double stddev_val, + const std::uint64_t n, + char *out_ptr, + const std::vector &depends) +{ + auto &exec_q = engine->get_queue(); + dpctl_tu_ns::validate_type_for_device(exec_q); + + DataT *out = reinterpret_cast(out_ptr); + DataT mean = static_cast(mean_val); + DataT stddev = static_cast(stddev_val); + + auto seed_values = engine->get_seeds(); + auto no_of_seeds = seed_values.size(); + if (no_of_seeds > 1) { + throw std::runtime_error(""); + } + + mkl_rng::gaussian distribution(mean, stddev); + mkl_rng::mcg59 eng(exec_q, seed_values[0]); + + return mkl_rng::generate(distribution, eng, n, out, depends); +} + +std::pair + gaussian(device::engine::EngineBase *engine, + const std::uint8_t method_id, + const double mean, + const double stddev, + const std::uint64_t n, + dpctl::tensor::usm_ndarray res, + const std::vector &depends) +{ + auto &exec_q = engine->get_queue(); + + const int res_nd = res.get_ndim(); + const py::ssize_t *res_shape = res.get_shape_raw(); + + size_t res_nelems(1); + for (int i = 0; i < res_nd; ++i) { + res_nelems *= static_cast(res_shape[i]); + } + + if (res_nelems == 0) { + // nothing to do + return std::make_pair(sycl::event(), sycl::event()); + } + + // ensure that output is ample enough to accommodate all elements + dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(res, res_nelems); + + if (!dpctl::utils::queues_are_compatible(exec_q, {res})) { + throw py::value_error( + "Execution queue is not compatible with the allocation queue"); + } + + bool is_res_c_contig = res.is_c_contiguous(); + if (!is_res_c_contig) { + throw std::runtime_error( + "Only population of contiguous array is supported."); + } + + auto enginge_id = engine->get_type().id(); + if (enginge_id >= device::engine::no_of_engines) { + throw std::runtime_error( + "Unknown engine type=" + std::to_string(enginge_id) + + " for gaussian distribution."); + } + + if (method_id >= no_of_methods) { + throw std::runtime_error("Unknown method=" + std::to_string(method_id) + + " for gaussian distribution."); + } + + auto array_types = dpctl_td_ns::usm_ndarray_types(); + int res_type_id = array_types.typenum_to_lookup_id(res.get_typenum()); + + auto gaussian_fn = + gaussian_dispatch_table[enginge_id][res_type_id][method_id]; + if (gaussian_fn == nullptr) { + throw py::value_error( + "No gaussian implementation defined for a required type"); + } + + char *res_data = res.get_data(); + sycl::event gaussian_ev = + gaussian_fn(engine, mean, stddev, n, res_data, depends); + + sycl::event ht_ev = + dpctl::utils::keep_args_alive(exec_q, {res}, {gaussian_ev}); + return std::make_pair(ht_ev, gaussian_ev); +} + +template +struct GaussianContigFactory +{ + fnT get() + { + if constexpr (dispatch::GaussianTypePairSupportFactory::is_defined) { + return gaussian_impl; + } + else { + return nullptr; + } + } +}; + +void init_gaussian_dispatch_3d_table(void) +{ + dispatch::Dispatch3DTableBuilder + contig; + contig.populate(gaussian_dispatch_table); +} +} // namespace dpnp::backend::ext::rng::host diff --git a/dpnp/backend/extensions/rng/host/gaussian.hpp b/dpnp/backend/extensions/rng/host/gaussian.hpp new file mode 100644 index 00000000000..2ebf5f976e2 --- /dev/null +++ b/dpnp/backend/extensions/rng/host/gaussian.hpp @@ -0,0 +1,44 @@ +//***************************************************************************** +// Copyright (c) 2024, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#pragma once + +#include + +#include "../device/engine/base_engine.hpp" + +namespace dpnp::backend::ext::rng::host +{ +extern std::pair + gaussian(device::engine::EngineBase *engine, + const std::uint8_t method_id, + const double mean, + const double stddev, + const std::uint64_t n, + dpctl::tensor::usm_ndarray res, + const std::vector &depends = {}); + +extern void init_gaussian_dispatch_3d_table(void); +} // namespace dpnp::backend::ext::rng::host diff --git a/dpnp/backend/extensions/rng/host/rng_py.cpp b/dpnp/backend/extensions/rng/host/rng_py.cpp new file mode 100644 index 00000000000..d76b07242e7 --- /dev/null +++ b/dpnp/backend/extensions/rng/host/rng_py.cpp @@ -0,0 +1,142 @@ +//***************************************************************************** +// Copyright (c) 2023, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +// This file defines functions of dpnp.backend._rng_impl extensions +// +//***************************************************************************** + +#include +#include + +#include +#include + +#include "gaussian.hpp" + +// #include "../device/engine/mcg31m1_engine.hpp" +// #include "../device/engine/mcg59_engine.hpp" +// #include "../device/engine/mrg32k3a_engine.hpp" +// #include "../device/engine/philox4x32x10_engine.hpp" + +namespace mkl_rng = oneapi::mkl::rng; +namespace rng_host_ext = dpnp::backend::ext::rng::host; +// namespace rng_dev_engine = dpnp::backend::ext::rng::device::engine; +namespace py = pybind11; + +// populate dispatch 3-D tables +void init_dispatch_3d_tables(void) +{ + rng_host_ext::init_gaussian_dispatch_3d_table(); +} + +// class PyEngineBase : public rng_dev_engine::EngineBase +// { +// public: +// // inherit the constructor +// using EngineBase::EngineBase; + +// // trampoline (need one for each virtual function) +// // sycl::queue &get_queue() { +// // PYBIND11_OVERRIDE_PURE( +// // sycl::queue&, /* Return type */ +// // EngineBase, /* Parent class */ +// // get_queue, /* Name of function in C++ (must match Python +// name) +// // */ +// // ); +// // } +// }; + +PYBIND11_MODULE(_rng_host_impl, m) +{ + init_dispatch_3d_tables(); + + // py::class_( + // m, "EngineBase") + // .def(py::init<>()) + // .def("get_queue", &rng_dev_engine::EngineBase::get_queue); + + // py::class_(m, + // "MRG32k3a") + // .def(py::init(), + // py::arg("sycl_queue"), py::arg("seed"), py::arg("offset") = + // 0) + // .def(py::init &, + // std::uint64_t>(), + // py::arg("sycl_queue"), py::arg("seed"), py::arg("offset") = + // 0) + // .def(py::init &>(), + // py::arg("sycl_queue"), py::arg("seed"), + // py::arg("offset") = py::list()) + // .def(py::init &, + // std::vector &>(), + // py::arg("sycl_queue"), py::arg("seed"), + // py::arg("offset") = py::list()); + + // py::class_( + // m, "PHILOX4x32x10") + // .def(py::init(), + // py::arg("sycl_queue"), py::arg("seed"), py::arg("offset") = + // 0) + // .def(py::init &, + // std::uint64_t>(), + // py::arg("sycl_queue"), py::arg("seed"), py::arg("offset") = + // 0) + // .def(py::init &>(), + // py::arg("sycl_queue"), py::arg("seed"), + // py::arg("offset") = py::list()) + // .def(py::init &, + // std::vector &>(), + // py::arg("sycl_queue"), py::arg("seed"), + // py::arg("offset") = py::list()); + + // py::class_(m, + // "MCG31M1") + // .def(py::init(), + // py::arg("sycl_queue"), py::arg("seed"), py::arg("offset") = + // 0) + // .def(py::init &, + // std::uint64_t>(), + // py::arg("sycl_queue"), py::arg("seed"), py::arg("offset") = + // 0); + + // py::class_(m, + // "MCG59") + // .def(py::init(), + // py::arg("sycl_queue"), py::arg("seed"), py::arg("offset") = + // 0) + // .def(py::init &, + // std::uint64_t>(), + // py::arg("sycl_queue"), py::arg("seed"), py::arg("offset") = + // 0); + + m.def("_gaussian", &rng_host_ext::gaussian, "", py::arg("engine"), + py::arg("method_id"), py::arg("mean"), py::arg("stddev"), + py::arg("n"), py::arg("res"), py::arg("depends") = py::list()); +}