From 43e788f2e7ec23108b2cdc7adfd35745b0447552 Mon Sep 17 00:00:00 2001 From: Pratik Nayak Date: Mon, 19 Aug 2024 15:31:01 +0200 Subject: [PATCH] [dpcpp] move to proper headers --- dpcpp/base/batch_multi_vector_kernels.dp.cpp | 64 +++++++++---------- ...hpp.inc => batch_multi_vector_kernels.hpp} | 29 +++++++++ dpcpp/solver/batch_bicgstab_kernels.dp.cpp | 2 +- dpcpp/solver/batch_bicgstab_kernels.hpp.inc | 43 ++++++++----- dpcpp/solver/batch_cg_kernels.dp.cpp | 2 +- dpcpp/solver/batch_cg_kernels.hpp.inc | 25 +++++--- 6 files changed, 102 insertions(+), 63 deletions(-) rename dpcpp/base/{batch_multi_vector_kernels.hpp.inc => batch_multi_vector_kernels.hpp} (92%) diff --git a/dpcpp/base/batch_multi_vector_kernels.dp.cpp b/dpcpp/base/batch_multi_vector_kernels.dp.cpp index 8f607725bc8..0d2662bdccd 100644 --- a/dpcpp/base/batch_multi_vector_kernels.dp.cpp +++ b/dpcpp/base/batch_multi_vector_kernels.dp.cpp @@ -15,6 +15,7 @@ #include "core/base/batch_struct.hpp" #include "core/components/prefix_sum_kernels.hpp" +#include "dpcpp/base/batch_multi_vector_kernels.hpp" #include "dpcpp/base/batch_struct.hpp" #include "dpcpp/base/config.hpp" #include "dpcpp/base/dim3.dp.hpp" @@ -29,17 +30,9 @@ namespace gko { namespace kernels { namespace dpcpp { -/** - * @brief The MultiVector matrix format namespace. - * @ref MultiVector - * @ingroup batch_multi_vector - */ namespace batch_multi_vector { -#include "dpcpp/base/batch_multi_vector_kernels.hpp.inc" - - template void scale(std::shared_ptr exec, const batch::MultiVector* const alpha, @@ -71,7 +64,7 @@ void scale(std::shared_ptr exec, const auto alpha_b = batch::extract_batch_item(alpha_ub, group_id); const auto x_b = batch::extract_batch_item(x_ub, group_id); - scale_kernel( + batch_single_kernels::scale_kernel( alpha_b, x_b, item_ct1, [](int row, int col, int stride) { return 0; }); }); @@ -85,10 +78,11 @@ void scale(std::shared_ptr exec, const auto alpha_b = batch::extract_batch_item(alpha_ub, group_id); const auto x_b = batch::extract_batch_item(x_ub, group_id); - scale_kernel(alpha_b, x_b, item_ct1, - [](int row, int col, int stride) { - return row * stride + col; - }); + batch_single_kernels::scale_kernel( + alpha_b, x_b, item_ct1, + [](int row, int col, int stride) { + return row * stride + col; + }); }); }); } else { @@ -100,7 +94,7 @@ void scale(std::shared_ptr exec, const auto alpha_b = batch::extract_batch_item(alpha_ub, group_id); const auto x_b = batch::extract_batch_item(x_ub, group_id); - scale_kernel( + batch_single_kernels::scale_kernel( alpha_b, x_b, item_ct1, [](int row, int col, int stride) { return col; }); }); @@ -144,8 +138,9 @@ void add_scaled(std::shared_ptr exec, batch::extract_batch_item(alpha_ub, group_id); const auto x_b = batch::extract_batch_item(x_ub, group_id); const auto y_b = batch::extract_batch_item(y_ub, group_id); - add_scaled_kernel(alpha_b, x_b, y_b, item_ct1, - [](auto col) { return 0; }); + batch_single_kernels::add_scaled_kernel( + alpha_b, x_b, y_b, item_ct1, + [](auto col) { return 0; }); }); }); } else { @@ -158,8 +153,9 @@ void add_scaled(std::shared_ptr exec, batch::extract_batch_item(alpha_ub, group_id); const auto x_b = batch::extract_batch_item(x_ub, group_id); const auto y_b = batch::extract_batch_item(y_ub, group_id); - add_scaled_kernel(alpha_b, x_b, y_b, item_ct1, - [](auto col) { return col; }); + batch_single_kernels::add_scaled_kernel( + alpha_b, x_b, y_b, item_ct1, + [](auto col) { return col; }); }); }); } @@ -206,7 +202,7 @@ void compute_dot(std::shared_ptr exec, batch::extract_batch_item(y_ub, group_id); const auto res_b = batch::extract_batch_item(res_ub, group_id); - single_rhs_compute_conj_dot_sg( + batch_single_kernels::single_rhs_compute_conj_dot_sg( x_b.num_rows, x_b.values, y_b.values, res_b.values[0], item_ct1); }); @@ -226,7 +222,7 @@ void compute_dot(std::shared_ptr exec, batch::extract_batch_item(y_ub, group_id); const auto res_b = batch::extract_batch_item(res_ub, group_id); - compute_gen_dot_product_kernel( + batch_single_kernels::compute_gen_dot_product_kernel( x_b, y_b, res_b, item_ct1, [](auto val) { return val; }); }); @@ -272,7 +268,7 @@ void compute_conj_dot(std::shared_ptr exec, const auto y_b = batch::extract_batch_item(y_ub, group_id); const auto res_b = batch::extract_batch_item(res_ub, group_id); - compute_gen_dot_product_kernel( + batch_single_kernels::compute_gen_dot_product_kernel( x_b, y_b, res_b, item_ct1, [](auto val) { return conj(val); }); }); @@ -308,17 +304,16 @@ void compute_norm2(std::shared_ptr exec, exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( sycl_nd_range(grid, block), - [=](sycl::nd_item<3> item_ct1) - [[sycl::reqd_sub_group_size(max_subgroup_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto x_b = - batch::extract_batch_item(x_ub, group_id); - const auto res_b = - batch::extract_batch_item(res_ub, group_id); - single_rhs_compute_norm2_sg(x_b.num_rows, x_b.values, - res_b.values[0], item_ct1); - }); + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( + max_subgroup_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto x_b = batch::extract_batch_item(x_ub, group_id); + const auto res_b = + batch::extract_batch_item(res_ub, group_id); + batch_single_kernels::single_rhs_compute_norm2_sg( + x_b.num_rows, x_b.values, res_b.values[0], item_ct1); + }); }); } else { exec->get_queue()->submit([&](sycl::handler& cgh) { @@ -332,7 +327,8 @@ void compute_norm2(std::shared_ptr exec, batch::extract_batch_item(x_ub, group_id); const auto res_b = batch::extract_batch_item(res_ub, group_id); - compute_norm2_kernel(x_b, res_b, item_ct1); + batch_single_kernels::compute_norm2_kernel(x_b, res_b, + item_ct1); }); }); } @@ -371,7 +367,7 @@ void copy(std::shared_ptr exec, const auto x_b = batch::extract_batch_item(x_ub, group_id); const auto result_b = batch::extract_batch_item(result_ub, group_id); - copy_kernel(x_b, result_b, item_ct1); + batch_single_kernels::copy_kernel(x_b, result_b, item_ct1); }); }); } diff --git a/dpcpp/base/batch_multi_vector_kernels.hpp.inc b/dpcpp/base/batch_multi_vector_kernels.hpp similarity index 92% rename from dpcpp/base/batch_multi_vector_kernels.hpp.inc rename to dpcpp/base/batch_multi_vector_kernels.hpp index c41eafd7efd..a16df237e34 100644 --- a/dpcpp/base/batch_multi_vector_kernels.hpp.inc +++ b/dpcpp/base/batch_multi_vector_kernels.hpp @@ -2,6 +2,29 @@ // // SPDX-License-Identifier: BSD-3-Clause + +#include + +#include + +#include "core/base/batch_struct.hpp" +#include "dpcpp/base/batch_struct.hpp" +#include "dpcpp/base/config.hpp" +#include "dpcpp/base/dim3.dp.hpp" +#include "dpcpp/base/dpct.hpp" +#include "dpcpp/base/helper.hpp" +#include "dpcpp/components/cooperative_groups.dp.hpp" +#include "dpcpp/components/intrinsics.dp.hpp" +#include "dpcpp/components/reduction.dp.hpp" +#include "dpcpp/components/thread_ids.dp.hpp" + + +namespace gko { +namespace kernels { +namespace GKO_DEVICE_NAMESPACE { +namespace batch_single_kernels { + + template __dpct_inline__ void scale_kernel( const gko::batch::multi_vector::batch_item& alpha, @@ -229,3 +252,9 @@ __dpct_inline__ void copy_kernel( out.values[i * out.stride + j] = in.values[i * in.stride + j]; } } + + +} // namespace batch_single_kernels +} // namespace GKO_DEVICE_NAMESPACE +} // namespace kernels +} // namespace gko diff --git a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp index bb84283b49f..7dc8f3ec23b 100644 --- a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp +++ b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp @@ -13,6 +13,7 @@ #include "core/base/batch_struct.hpp" #include "core/matrix/batch_struct.hpp" #include "core/solver/batch_dispatch.hpp" +#include "dpcpp/base/batch_multi_vector_kernels.hpp" #include "dpcpp/base/batch_struct.hpp" #include "dpcpp/base/config.hpp" #include "dpcpp/base/dim3.dp.hpp" @@ -36,7 +37,6 @@ namespace dpcpp { namespace batch_bicgstab { -#include "dpcpp/base/batch_multi_vector_kernels.hpp.inc" #include "dpcpp/matrix/batch_csr_kernels.hpp.inc" #include "dpcpp/matrix/batch_dense_kernels.hpp.inc" #include "dpcpp/matrix/batch_ell_kernels.hpp.inc" diff --git a/dpcpp/solver/batch_bicgstab_kernels.hpp.inc b/dpcpp/solver/batch_bicgstab_kernels.hpp.inc index ad7eaeff556..f5a88e9d59d 100644 --- a/dpcpp/solver/batch_bicgstab_kernels.hpp.inc +++ b/dpcpp/solver/batch_bicgstab_kernels.hpp.inc @@ -39,11 +39,13 @@ __dpct_inline__ void initialize( item_ct1.barrier(sycl::access::fence_space::global_and_local); if (sg_id == 0) { - single_rhs_compute_norm2_sg(num_rows, r_shared_entry, res_norm, - item_ct1); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_compute_norm2_sg(num_rows, r_shared_entry, res_norm, + item_ct1); } else if (sg_id == 1) { - single_rhs_compute_norm2_sg(num_rows, b_global_entry, rhs_norm, - item_ct1); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_compute_norm2_sg(num_rows, b_global_entry, rhs_norm, + item_ct1); } item_ct1.barrier(sycl::access::fence_space::global_and_local); @@ -86,8 +88,9 @@ __dpct_inline__ void compute_alpha(const int num_rows, const ValueType& rho_new, const auto sg_id = sg.get_group_id(); const auto tid = item_ct1.get_local_linear_id(); if (sg_id == 0) { - single_rhs_compute_conj_dot_sg(num_rows, r_hat_shared_entry, - v_shared_entry, alpha, item_ct1); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_compute_conj_dot_sg(num_rows, r_hat_shared_entry, + v_shared_entry, alpha, item_ct1); } item_ct1.barrier(sycl::access::fence_space::global_and_local); if (tid == 0) { @@ -123,11 +126,13 @@ __dpct_inline__ void compute_omega(const int num_rows, const auto sg_id = sg.get_group_id(); const auto tid = item_ct1.get_local_linear_id(); if (sg_id == 0) { - single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry, s_shared_entry, - omega, item_ct1); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry, + s_shared_entry, omega, item_ct1); } else if (sg_id == 1) { - single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry, t_shared_entry, - temp, item_ct1); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry, + t_shared_entry, temp, item_ct1); } item_ct1.barrier(sycl::access::fence_space::global_and_local); if (tid == 0) { @@ -308,8 +313,9 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, // rho_new = < r_hat , r > = (r_hat)' * (r) if (sg_id == 0) { - single_rhs_compute_conj_dot_sg(num_rows, r_hat_sh, r_sh, - rho_new_sh[0], item_ct1); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_compute_conj_dot_sg(num_rows, r_hat_sh, r_sh, + rho_new_sh[0], item_ct1); } item_ct1.barrier(sycl::access::fence_space::global_and_local); @@ -338,8 +344,9 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, // an estimate of residual norms if (sg_id == 0) { - single_rhs_compute_norm2_sg(num_rows, s_sh, norms_res_sh[0], - item_ct1); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_compute_norm2_sg(num_rows, s_sh, norms_res_sh[0], + item_ct1); } item_ct1.barrier(sycl::access::fence_space::global_and_local); @@ -368,8 +375,9 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, item_ct1.barrier(sycl::access::fence_space::global_and_local); if (sg_id == 0) - single_rhs_compute_norm2_sg(num_rows, r_sh, norms_res_sh[0], - item_ct1); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_compute_norm2_sg(num_rows, r_sh, norms_res_sh[0], + item_ct1); if (tid == group_size - 1) { rho_old_sh[0] = rho_new_sh[0]; } @@ -379,6 +387,7 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, logger.log_iteration(batch_id, iter, norms_res_sh[0]); // copy x back to global memory - copy_kernel(num_rows, x_sh, x_global_entry, item_ct1); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::copy_kernel( + num_rows, x_sh, x_global_entry, item_ct1); item_ct1.barrier(sycl::access::fence_space::global_and_local); } diff --git a/dpcpp/solver/batch_cg_kernels.dp.cpp b/dpcpp/solver/batch_cg_kernels.dp.cpp index 61591f9efb6..f25d8266803 100644 --- a/dpcpp/solver/batch_cg_kernels.dp.cpp +++ b/dpcpp/solver/batch_cg_kernels.dp.cpp @@ -13,6 +13,7 @@ #include "core/base/batch_struct.hpp" #include "core/matrix/batch_struct.hpp" #include "core/solver/batch_dispatch.hpp" +#include "dpcpp/base/batch_multi_vector_kernels.hpp" #include "dpcpp/base/batch_struct.hpp" #include "dpcpp/base/config.hpp" #include "dpcpp/base/dim3.dp.hpp" @@ -36,7 +37,6 @@ namespace dpcpp { namespace batch_cg { -#include "dpcpp/base/batch_multi_vector_kernels.hpp.inc" #include "dpcpp/matrix/batch_csr_kernels.hpp.inc" #include "dpcpp/matrix/batch_dense_kernels.hpp.inc" #include "dpcpp/matrix/batch_ell_kernels.hpp.inc" diff --git a/dpcpp/solver/batch_cg_kernels.hpp.inc b/dpcpp/solver/batch_cg_kernels.hpp.inc index cef6e620b64..7a91bcb2bbf 100644 --- a/dpcpp/solver/batch_cg_kernels.hpp.inc +++ b/dpcpp/solver/batch_cg_kernels.hpp.inc @@ -40,11 +40,13 @@ __dpct_inline__ void initialize( // Compute norms of rhs // and rho_old = r' * z if (sg_id == 0) { - single_rhs_compute_norm2_sg(num_rows, b_global_entry, rhs_norms, - item_ct1); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_compute_norm2_sg(num_rows, b_global_entry, rhs_norms, + item_ct1); } else if (sg_id == 1) { - single_rhs_compute_conj_dot_sg(num_rows, r_shared_entry, z_shared_entry, - rho_old, item_ct1); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_compute_conj_dot_sg(num_rows, r_shared_entry, + z_shared_entry, rho_old, item_ct1); } item_ct1.barrier(sycl::access::fence_space::global_and_local); @@ -80,9 +82,10 @@ __dpct_inline__ void update_x_and_r( auto sg = item_ct1.get_sub_group(); const auto tid = item_ct1.get_local_linear_id(); if (sg.get_group_id() == 0) { - single_rhs_compute_conj_dot_sg(num_rows, p_shared_entry, - Ap_shared_entry, alpha_shared_entry, - item_ct1); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_compute_conj_dot_sg(num_rows, p_shared_entry, + Ap_shared_entry, alpha_shared_entry, + item_ct1); } item_ct1.barrier(sycl::access::fence_space::global_and_local); if (tid == 0) { @@ -221,8 +224,9 @@ __dpct_inline__ void apply_kernel( // rho_new = (r)' * (z) if (sg_id == 0) { - single_rhs_compute_conj_dot_sg(num_rows, r_sh, z_sh, rho_new_sh[0], - item_ct1); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_compute_conj_dot_sg(num_rows, r_sh, z_sh, + rho_new_sh[0], item_ct1); } item_ct1.barrier(sycl::access::fence_space::global_and_local); @@ -239,6 +243,7 @@ __dpct_inline__ void apply_kernel( logger.log_iteration(batch_id, iter, norms_res_sh[0]); // copy x back to global memory - copy_kernel(num_rows, x_sh, x_global_entry, item_ct1); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::copy_kernel( + num_rows, x_sh, x_global_entry, item_ct1); item_ct1.barrier(sycl::access::fence_space::global_and_local); }