diff --git a/common/cuda_hip/CMakeLists.txt b/common/cuda_hip/CMakeLists.txt index 463abfd9284..15d3a82419e 100644 --- a/common/cuda_hip/CMakeLists.txt +++ b/common/cuda_hip/CMakeLists.txt @@ -1,5 +1,6 @@ include(${PROJECT_SOURCE_DIR}/cmake/template_instantiation.cmake) set(CUDA_HIP_SOURCES + base/batch_multi_vector_kernels.cpp base/device_matrix_data_kernels.cpp base/index_set_kernels.cpp components/prefix_sum_kernels.cpp diff --git a/common/cuda_hip/base/batch_multi_vector_kernel_launcher.hpp.inc b/common/cuda_hip/base/batch_multi_vector_kernels.cpp similarity index 67% rename from common/cuda_hip/base/batch_multi_vector_kernel_launcher.hpp.inc rename to common/cuda_hip/base/batch_multi_vector_kernels.cpp index 19b5b74a547..17f65487464 100644 --- a/common/cuda_hip/base/batch_multi_vector_kernel_launcher.hpp.inc +++ b/common/cuda_hip/base/batch_multi_vector_kernels.cpp @@ -2,6 +2,32 @@ // // SPDX-License-Identifier: BSD-3-Clause +#include "common/cuda_hip/base/batch_multi_vector_kernels.hpp" + +#include +#include + +#include +#include +#include +#include + +#include "common/cuda_hip/base/config.hpp" +#include "common/cuda_hip/base/math.hpp" +#include "common/cuda_hip/base/runtime.hpp" +#include "core/base/batch_multi_vector_kernels.hpp" +#include "core/base/batch_struct.hpp" + + +namespace gko { +namespace kernels { +namespace GKO_DEVICE_NAMESPACE { +namespace batch_multi_vector { + + +constexpr auto default_block_size = 256; + + template void scale(std::shared_ptr exec, const batch::MultiVector* const alpha, @@ -11,16 +37,19 @@ void scale(std::shared_ptr exec, const auto alpha_ub = get_batch_struct(alpha); const auto x_ub = get_batch_struct(x); if (alpha->get_common_size()[1] == 1) { - scale_kernel<<get_stream()>>>( + batch_single_kernels::scale_kernel<<get_stream()>>>( alpha_ub, x_ub, [] __device__(int row, int col, int stride) { return 0; }); } else if (alpha->get_common_size() == x->get_common_size()) { - scale_kernel<<get_stream()>>>( + batch_single_kernels::scale_kernel<<get_stream()>>>( alpha_ub, x_ub, [] __device__(int row, int col, int stride) { return row * stride + col; }); } else { - scale_kernel<<get_stream()>>>( + batch_single_kernels::scale_kernel<<get_stream()>>>( alpha_ub, x_ub, [] __device__(int row, int col, int stride) { return col; }); } @@ -42,12 +71,12 @@ void add_scaled(std::shared_ptr exec, const auto x_ub = get_batch_struct(x); const auto y_ub = get_batch_struct(y); if (alpha->get_common_size()[1] == 1) { - add_scaled_kernel<<get_stream()>>>( + batch_single_kernels::add_scaled_kernel<<< + num_blocks, default_block_size, 0, exec->get_stream()>>>( alpha_ub, x_ub, y_ub, [] __device__(int col) { return 0; }); } else { - add_scaled_kernel<<get_stream()>>>( + batch_single_kernels::add_scaled_kernel<<< + num_blocks, default_block_size, 0, exec->get_stream()>>>( alpha_ub, x_ub, y_ub, [] __device__(int col) { return col; }); } } @@ -67,8 +96,8 @@ void compute_dot(std::shared_ptr exec, const auto x_ub = get_batch_struct(x); const auto y_ub = get_batch_struct(y); const auto res_ub = get_batch_struct(result); - compute_gen_dot_product_kernel<<get_stream()>>>( + batch_single_kernels::compute_gen_dot_product_kernel<<< + num_blocks, default_block_size, 0, exec->get_stream()>>>( x_ub, y_ub, res_ub, [] __device__(auto val) { return val; }); } @@ -87,8 +116,8 @@ void compute_conj_dot(std::shared_ptr exec, const auto x_ub = get_batch_struct(x); const auto y_ub = get_batch_struct(y); const auto res_ub = get_batch_struct(result); - compute_gen_dot_product_kernel<<get_stream()>>>( + batch_single_kernels::compute_gen_dot_product_kernel<<< + num_blocks, default_block_size, 0, exec->get_stream()>>>( x_ub, y_ub, res_ub, [] __device__(auto val) { return conj(val); }); } @@ -105,8 +134,9 @@ void compute_norm2(std::shared_ptr exec, const auto num_rhs = x->get_common_size()[1]; const auto x_ub = get_batch_struct(x); const auto res_ub = get_batch_struct(result); - compute_norm2_kernel<<get_stream()>>>(x_ub, res_ub); + batch_single_kernels::compute_norm2_kernel<<get_stream()>>>( + x_ub, res_ub); } GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( @@ -121,8 +151,15 @@ void copy(std::shared_ptr exec, const auto num_blocks = x->get_num_batch_items(); const auto result_ub = get_batch_struct(result); const auto x_ub = get_batch_struct(x); - copy_kernel<<get_stream()>>>( - x_ub, result_ub); + batch_single_kernels:: + copy_kernel<<get_stream()>>>( + x_ub, result_ub); } GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR_COPY_KERNEL); + + +} // namespace batch_multi_vector +} // namespace GKO_DEVICE_NAMESPACE +} // namespace kernels +} // namespace gko diff --git a/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc b/common/cuda_hip/base/batch_multi_vector_kernels.hpp similarity index 82% rename from common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc rename to common/cuda_hip/base/batch_multi_vector_kernels.hpp index 9b6301674be..bb3aac67b55 100644 --- a/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc +++ b/common/cuda_hip/base/batch_multi_vector_kernels.hpp @@ -2,6 +2,44 @@ // // SPDX-License-Identifier: BSD-3-Clause +#include +#include + +#include +#include +#include +#include + +#include "common/cuda_hip/base/config.hpp" +#include "common/cuda_hip/base/math.hpp" +#include "common/cuda_hip/base/runtime.hpp" +#include "common/cuda_hip/base/thrust.hpp" +#include "common/cuda_hip/base/types.hpp" +#include "common/cuda_hip/components/cooperative_groups.hpp" +#include "common/cuda_hip/components/format_conversion.hpp" +#include "common/cuda_hip/components/reduction.hpp" +#include "common/cuda_hip/components/segment_scan.hpp" +#include "common/cuda_hip/components/thread_ids.hpp" +#include "common/cuda_hip/components/warp_blas.hpp" + +#if defined(GKO_COMPILING_CUDA) +#include "cuda/base/batch_struct.hpp" +#elif defined(GKO_COMPILING_HIP) +#include "hip/base/batch_struct.hip.hpp" +#else +#error "batch struct def missing" +#endif + + +namespace gko { +namespace kernels { +namespace GKO_DEVICE_NAMESPACE { +namespace batch_single_kernels { + + +constexpr auto default_block_size = 256; + + template __device__ __forceinline__ void scale( const gko::batch::multi_vector::batch_item& alpha, @@ -20,8 +58,7 @@ __device__ __forceinline__ void scale( template -__global__ -__launch_bounds__(default_block_size, sm_oversubscription) void scale_kernel( +__global__ __launch_bounds__(default_block_size) void scale_kernel( const gko::batch::multi_vector::uniform_batch alpha, const gko::batch::multi_vector::uniform_batch x, Mapping map) { @@ -52,20 +89,10 @@ __device__ __forceinline__ void add_scaled( template -__global__ __launch_bounds__( - default_block_size, - sm_oversubscription) void add_scaled_kernel(const gko::batch::multi_vector:: - uniform_batch< - const ValueType> - alpha, - const gko::batch::multi_vector:: - uniform_batch< - const ValueType> - x, - const gko::batch::multi_vector:: - uniform_batch - y, - Mapping map) +__global__ __launch_bounds__(default_block_size) void add_scaled_kernel( + const gko::batch::multi_vector::uniform_batch alpha, + const gko::batch::multi_vector::uniform_batch x, + const gko::batch::multi_vector::uniform_batch y, Mapping map) { for (size_type batch_id = blockIdx.x; batch_id < x.num_batch_items; batch_id += gridDim.x) { @@ -145,7 +172,7 @@ __device__ __forceinline__ void compute_gen_dot_product( template __global__ -__launch_bounds__(default_block_size, sm_oversubscription) void compute_gen_dot_product_kernel( +__launch_bounds__(default_block_size) void compute_gen_dot_product_kernel( const gko::batch::multi_vector::uniform_batch x, const gko::batch::multi_vector::uniform_batch y, const gko::batch::multi_vector::uniform_batch result, @@ -232,19 +259,10 @@ __device__ __forceinline__ void compute_norm2( template -__global__ __launch_bounds__( - default_block_size, - sm_oversubscription) void compute_norm2_kernel(const gko::batch:: - multi_vector:: - uniform_batch< - const ValueType> - x, - const gko::batch:: - multi_vector:: - uniform_batch< - remove_complex< - ValueType>> - result) +__global__ __launch_bounds__(default_block_size) void compute_norm2_kernel( + const gko::batch::multi_vector::uniform_batch x, + const gko::batch::multi_vector::uniform_batch> + result) { for (size_type batch_id = blockIdx.x; batch_id < x.num_batch_items; batch_id += gridDim.x) { @@ -287,8 +305,7 @@ __device__ __forceinline__ void copy( template -__global__ -__launch_bounds__(default_block_size, sm_oversubscription) void copy_kernel( +__global__ __launch_bounds__(default_block_size) void copy_kernel( const gko::batch::multi_vector::uniform_batch src, const gko::batch::multi_vector::uniform_batch dst) { @@ -299,3 +316,9 @@ __launch_bounds__(default_block_size, sm_oversubscription) void copy_kernel( copy(src_b, dst_b); } } + + +} // namespace batch_single_kernels +} // namespace GKO_DEVICE_NAMESPACE +} // namespace kernels +} // namespace gko diff --git a/common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc b/common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc index f71c8c40c3e..c2a53b2e518 100644 --- a/common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc +++ b/common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc @@ -32,10 +32,14 @@ __device__ __forceinline__ void initialize( __syncthreads(); if (threadIdx.x / config::warp_size == 0) { - single_rhs_compute_norm2(subgroup, num_rows, r_shared_entry, res_norm); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_compute_norm2(subgroup, num_rows, r_shared_entry, + res_norm); } else if (threadIdx.x / config::warp_size == 1) { // Compute norms of rhs - single_rhs_compute_norm2(subgroup, num_rows, b_global_entry, rhs_norm); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_compute_norm2(subgroup, num_rows, b_global_entry, + rhs_norm); } __syncthreads(); @@ -70,8 +74,9 @@ __device__ __forceinline__ void compute_alpha( const ValueType* const v_shared_entry, ValueType& alpha) { if (threadIdx.x / config::warp_size == 0) { - single_rhs_compute_conj_dot(subgroup, num_rows, r_hat_shared_entry, - v_shared_entry, alpha); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_compute_conj_dot(subgroup, num_rows, r_hat_shared_entry, + v_shared_entry, alpha); } __syncthreads(); if (threadIdx.x == 0) { @@ -99,11 +104,13 @@ __device__ __forceinline__ void compute_omega( const ValueType* const s_shared_entry, ValueType& temp, ValueType& omega) { if (threadIdx.x / config::warp_size == 0) { - single_rhs_compute_conj_dot(subgroup, num_rows, t_shared_entry, - s_shared_entry, omega); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_compute_conj_dot(subgroup, num_rows, t_shared_entry, + s_shared_entry, omega); } else if (threadIdx.x / config::warp_size == 1) { - single_rhs_compute_conj_dot(subgroup, num_rows, t_shared_entry, - t_shared_entry, temp); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_compute_conj_dot(subgroup, num_rows, t_shared_entry, + t_shared_entry, temp); } __syncthreads(); @@ -271,8 +278,9 @@ __global__ void apply_kernel( // rho_new = < r_hat , r > = (r_hat)' * (r) if (threadIdx.x / config::warp_size == 0) { - single_rhs_compute_conj_dot(subgroup, num_rows, r_hat_sh, r_sh, - rho_new_sh[0]); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_compute_conj_dot(subgroup, num_rows, r_hat_sh, + r_sh, rho_new_sh[0]); } __syncthreads(); @@ -301,8 +309,9 @@ __global__ void apply_kernel( // an estimate of residual norms if (threadIdx.x / config::warp_size == 0) { - single_rhs_compute_norm2(subgroup, num_rows, s_sh, - norms_res_sh[0]); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_compute_norm2(subgroup, num_rows, s_sh, + norms_res_sh[0]); } __syncthreads(); @@ -333,8 +342,9 @@ __global__ void apply_kernel( __syncthreads(); if (threadIdx.x / config::warp_size == 0) { - single_rhs_compute_norm2(subgroup, num_rows, r_sh, - norms_res_sh[0]); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_compute_norm2(subgroup, num_rows, r_sh, + norms_res_sh[0]); } //__syncthreads(); @@ -347,7 +357,8 @@ __global__ void apply_kernel( logger.log_iteration(batch_id, iter, norms_res_sh[0]); // copy x back to global memory - single_rhs_copy(num_rows, x_sh, x_gl_entry_ptr); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_copy(num_rows, x_sh, x_gl_entry_ptr); __syncthreads(); } } diff --git a/common/cuda_hip/solver/batch_cg_kernels.hpp.inc b/common/cuda_hip/solver/batch_cg_kernels.hpp.inc index ffee501b58c..c95a6b1cf05 100644 --- a/common/cuda_hip/solver/batch_cg_kernels.hpp.inc +++ b/common/cuda_hip/solver/batch_cg_kernels.hpp.inc @@ -32,12 +32,14 @@ __device__ __forceinline__ void initialize( if (threadIdx.x / config::warp_size == 0) { // Compute norms of rhs - single_rhs_compute_norm2(subgroup, num_rows, b_global_entry, - rhs_norms_sh); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_compute_norm2(subgroup, num_rows, b_global_entry, + rhs_norms_sh); } else if (threadIdx.x / config::warp_size == 1) { // rho_old = r' * z - single_rhs_compute_conj_dot(subgroup, num_rows, r_shared_entry, - z_shared_entry, rho_old_shared_entry); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_compute_conj_dot(subgroup, num_rows, r_shared_entry, + z_shared_entry, rho_old_shared_entry); } // p = z @@ -69,8 +71,9 @@ __device__ __forceinline__ void update_x_and_r( ValueType* const x_shared_entry, ValueType* const r_shared_entry) { if (threadIdx.x / config::warp_size == 0) { - single_rhs_compute_conj_dot(subgroup, num_rows, p_shared_entry, - Ap_shared_entry, alpha_shared_entry); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_compute_conj_dot(subgroup, num_rows, p_shared_entry, + Ap_shared_entry, alpha_shared_entry); } __syncthreads(); @@ -202,8 +205,9 @@ __global__ void apply_kernel(const gko::kernels::batch_cg::storage_config sconf, if (threadIdx.x / config::warp_size == 0) { // rho_new = (r)' * (z) - single_rhs_compute_conj_dot(subgroup, num_rows, r_sh, z_sh, - rho_new_sh[0]); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_compute_conj_dot(subgroup, num_rows, r_sh, z_sh, + rho_new_sh[0]); } __syncthreads(); @@ -222,7 +226,8 @@ __global__ void apply_kernel(const gko::kernels::batch_cg::storage_config sconf, logger.log_iteration(batch_id, iter, norms_res_sh[0]); // copy x back to global memory - single_rhs_copy(num_rows, x_sh, x_global_entry); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_copy(num_rows, x_sh, x_global_entry); __syncthreads(); } } diff --git a/cuda/CMakeLists.txt b/cuda/CMakeLists.txt index d4a94eda802..3631a65f48d 100644 --- a/cuda/CMakeLists.txt +++ b/cuda/CMakeLists.txt @@ -7,7 +7,6 @@ add_instantiation_files(${PROJECT_SOURCE_DIR}/common/cuda_hip matrix/fbcsr_kerne list(APPEND GKO_UNIFIED_COMMON_SOURCES ${PROJECT_SOURCE_DIR}/common/unified/matrix/dense_kernels.instantiate.cpp) target_sources(ginkgo_cuda PRIVATE - base/batch_multi_vector_kernels.cu base/device.cpp base/exception.cpp base/executor.cpp diff --git a/cuda/base/batch_multi_vector_kernels.cu b/cuda/base/batch_multi_vector_kernels.cu deleted file mode 100644 index 3dad5ba94f1..00000000000 --- a/cuda/base/batch_multi_vector_kernels.cu +++ /dev/null @@ -1,56 +0,0 @@ -// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors -// -// SPDX-License-Identifier: BSD-3-Clause - -#include "core/base/batch_multi_vector_kernels.hpp" - -#include -#include - -#include -#include - -#include "common/cuda_hip/base/blas_bindings.hpp" -#include "common/cuda_hip/base/config.hpp" -#include "common/cuda_hip/base/pointer_mode_guard.hpp" -#include "common/cuda_hip/base/runtime.hpp" -#include "common/cuda_hip/base/thrust.hpp" -#include "common/cuda_hip/components/cooperative_groups.hpp" -#include "common/cuda_hip/components/reduction.hpp" -#include "common/cuda_hip/components/thread_ids.hpp" -#include "common/cuda_hip/components/warp_blas.hpp" -#include "core/base/batch_struct.hpp" -#include "cuda/base/batch_struct.hpp" - - -namespace gko { -namespace kernels { -namespace cuda { -/** - * @brief The MultiVector matrix format namespace. - * - * @ingroup batch_multi_vector - */ -namespace batch_multi_vector { - - -constexpr auto default_block_size = 256; -constexpr int sm_oversubscription = 4; - - -// clang-format off - -// NOTE: DO NOT CHANGE THE ORDERING OF THE INCLUDES - -#include "common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc" - - -#include "common/cuda_hip/base/batch_multi_vector_kernel_launcher.hpp.inc" - -// clang-format on - - -} // namespace batch_multi_vector -} // namespace cuda -} // namespace kernels -} // namespace gko diff --git a/cuda/solver/batch_bicgstab_kernels.cu b/cuda/solver/batch_bicgstab_kernels.cu index 3c7fe50709c..4d3deb742fe 100644 --- a/cuda/solver/batch_bicgstab_kernels.cu +++ b/cuda/solver/batch_bicgstab_kernels.cu @@ -10,6 +10,7 @@ #include #include +#include "common/cuda_hip/base/batch_multi_vector_kernels.hpp" #include "common/cuda_hip/base/config.hpp" #include "common/cuda_hip/base/runtime.hpp" #include "common/cuda_hip/base/thrust.hpp" @@ -43,7 +44,6 @@ constexpr int sm_oversubscription = 4; namespace batch_bicgstab { -#include "common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc" #include "common/cuda_hip/matrix/batch_csr_kernels.hpp.inc" #include "common/cuda_hip/matrix/batch_dense_kernels.hpp.inc" #include "common/cuda_hip/matrix/batch_ell_kernels.hpp.inc" diff --git a/cuda/solver/batch_cg_kernels.cu b/cuda/solver/batch_cg_kernels.cu index b681bd13ce3..21c3e3d43c4 100644 --- a/cuda/solver/batch_cg_kernels.cu +++ b/cuda/solver/batch_cg_kernels.cu @@ -10,6 +10,7 @@ #include #include +#include "common/cuda_hip/base/batch_multi_vector_kernels.hpp" #include "common/cuda_hip/base/config.hpp" #include "common/cuda_hip/base/thrust.hpp" #include "common/cuda_hip/base/types.hpp" @@ -42,7 +43,6 @@ constexpr int sm_oversubscription = 4; namespace batch_cg { -#include "common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc" #include "common/cuda_hip/matrix/batch_csr_kernels.hpp.inc" #include "common/cuda_hip/matrix/batch_dense_kernels.hpp.inc" #include "common/cuda_hip/matrix/batch_ell_kernels.hpp.inc" diff --git a/dpcpp/base/batch_multi_vector_kernels.dp.cpp b/dpcpp/base/batch_multi_vector_kernels.dp.cpp index 8f607725bc8..0d2662bdccd 100644 --- a/dpcpp/base/batch_multi_vector_kernels.dp.cpp +++ b/dpcpp/base/batch_multi_vector_kernels.dp.cpp @@ -15,6 +15,7 @@ #include "core/base/batch_struct.hpp" #include "core/components/prefix_sum_kernels.hpp" +#include "dpcpp/base/batch_multi_vector_kernels.hpp" #include "dpcpp/base/batch_struct.hpp" #include "dpcpp/base/config.hpp" #include "dpcpp/base/dim3.dp.hpp" @@ -29,17 +30,9 @@ namespace gko { namespace kernels { namespace dpcpp { -/** - * @brief The MultiVector matrix format namespace. - * @ref MultiVector - * @ingroup batch_multi_vector - */ namespace batch_multi_vector { -#include "dpcpp/base/batch_multi_vector_kernels.hpp.inc" - - template void scale(std::shared_ptr exec, const batch::MultiVector* const alpha, @@ -71,7 +64,7 @@ void scale(std::shared_ptr exec, const auto alpha_b = batch::extract_batch_item(alpha_ub, group_id); const auto x_b = batch::extract_batch_item(x_ub, group_id); - scale_kernel( + batch_single_kernels::scale_kernel( alpha_b, x_b, item_ct1, [](int row, int col, int stride) { return 0; }); }); @@ -85,10 +78,11 @@ void scale(std::shared_ptr exec, const auto alpha_b = batch::extract_batch_item(alpha_ub, group_id); const auto x_b = batch::extract_batch_item(x_ub, group_id); - scale_kernel(alpha_b, x_b, item_ct1, - [](int row, int col, int stride) { - return row * stride + col; - }); + batch_single_kernels::scale_kernel( + alpha_b, x_b, item_ct1, + [](int row, int col, int stride) { + return row * stride + col; + }); }); }); } else { @@ -100,7 +94,7 @@ void scale(std::shared_ptr exec, const auto alpha_b = batch::extract_batch_item(alpha_ub, group_id); const auto x_b = batch::extract_batch_item(x_ub, group_id); - scale_kernel( + batch_single_kernels::scale_kernel( alpha_b, x_b, item_ct1, [](int row, int col, int stride) { return col; }); }); @@ -144,8 +138,9 @@ void add_scaled(std::shared_ptr exec, batch::extract_batch_item(alpha_ub, group_id); const auto x_b = batch::extract_batch_item(x_ub, group_id); const auto y_b = batch::extract_batch_item(y_ub, group_id); - add_scaled_kernel(alpha_b, x_b, y_b, item_ct1, - [](auto col) { return 0; }); + batch_single_kernels::add_scaled_kernel( + alpha_b, x_b, y_b, item_ct1, + [](auto col) { return 0; }); }); }); } else { @@ -158,8 +153,9 @@ void add_scaled(std::shared_ptr exec, batch::extract_batch_item(alpha_ub, group_id); const auto x_b = batch::extract_batch_item(x_ub, group_id); const auto y_b = batch::extract_batch_item(y_ub, group_id); - add_scaled_kernel(alpha_b, x_b, y_b, item_ct1, - [](auto col) { return col; }); + batch_single_kernels::add_scaled_kernel( + alpha_b, x_b, y_b, item_ct1, + [](auto col) { return col; }); }); }); } @@ -206,7 +202,7 @@ void compute_dot(std::shared_ptr exec, batch::extract_batch_item(y_ub, group_id); const auto res_b = batch::extract_batch_item(res_ub, group_id); - single_rhs_compute_conj_dot_sg( + batch_single_kernels::single_rhs_compute_conj_dot_sg( x_b.num_rows, x_b.values, y_b.values, res_b.values[0], item_ct1); }); @@ -226,7 +222,7 @@ void compute_dot(std::shared_ptr exec, batch::extract_batch_item(y_ub, group_id); const auto res_b = batch::extract_batch_item(res_ub, group_id); - compute_gen_dot_product_kernel( + batch_single_kernels::compute_gen_dot_product_kernel( x_b, y_b, res_b, item_ct1, [](auto val) { return val; }); }); @@ -272,7 +268,7 @@ void compute_conj_dot(std::shared_ptr exec, const auto y_b = batch::extract_batch_item(y_ub, group_id); const auto res_b = batch::extract_batch_item(res_ub, group_id); - compute_gen_dot_product_kernel( + batch_single_kernels::compute_gen_dot_product_kernel( x_b, y_b, res_b, item_ct1, [](auto val) { return conj(val); }); }); @@ -308,17 +304,16 @@ void compute_norm2(std::shared_ptr exec, exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( sycl_nd_range(grid, block), - [=](sycl::nd_item<3> item_ct1) - [[sycl::reqd_sub_group_size(max_subgroup_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto x_b = - batch::extract_batch_item(x_ub, group_id); - const auto res_b = - batch::extract_batch_item(res_ub, group_id); - single_rhs_compute_norm2_sg(x_b.num_rows, x_b.values, - res_b.values[0], item_ct1); - }); + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( + max_subgroup_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto x_b = batch::extract_batch_item(x_ub, group_id); + const auto res_b = + batch::extract_batch_item(res_ub, group_id); + batch_single_kernels::single_rhs_compute_norm2_sg( + x_b.num_rows, x_b.values, res_b.values[0], item_ct1); + }); }); } else { exec->get_queue()->submit([&](sycl::handler& cgh) { @@ -332,7 +327,8 @@ void compute_norm2(std::shared_ptr exec, batch::extract_batch_item(x_ub, group_id); const auto res_b = batch::extract_batch_item(res_ub, group_id); - compute_norm2_kernel(x_b, res_b, item_ct1); + batch_single_kernels::compute_norm2_kernel(x_b, res_b, + item_ct1); }); }); } @@ -371,7 +367,7 @@ void copy(std::shared_ptr exec, const auto x_b = batch::extract_batch_item(x_ub, group_id); const auto result_b = batch::extract_batch_item(result_ub, group_id); - copy_kernel(x_b, result_b, item_ct1); + batch_single_kernels::copy_kernel(x_b, result_b, item_ct1); }); }); } diff --git a/dpcpp/base/batch_multi_vector_kernels.hpp.inc b/dpcpp/base/batch_multi_vector_kernels.hpp similarity index 92% rename from dpcpp/base/batch_multi_vector_kernels.hpp.inc rename to dpcpp/base/batch_multi_vector_kernels.hpp index c41eafd7efd..bbcc540ae60 100644 --- a/dpcpp/base/batch_multi_vector_kernels.hpp.inc +++ b/dpcpp/base/batch_multi_vector_kernels.hpp @@ -2,6 +2,28 @@ // // SPDX-License-Identifier: BSD-3-Clause +#include + +#include + +#include "core/base/batch_struct.hpp" +#include "dpcpp/base/batch_struct.hpp" +#include "dpcpp/base/config.hpp" +#include "dpcpp/base/dim3.dp.hpp" +#include "dpcpp/base/dpct.hpp" +#include "dpcpp/base/helper.hpp" +#include "dpcpp/components/cooperative_groups.dp.hpp" +#include "dpcpp/components/intrinsics.dp.hpp" +#include "dpcpp/components/reduction.dp.hpp" +#include "dpcpp/components/thread_ids.dp.hpp" + + +namespace gko { +namespace kernels { +namespace GKO_DEVICE_NAMESPACE { +namespace batch_single_kernels { + + template __dpct_inline__ void scale_kernel( const gko::batch::multi_vector::batch_item& alpha, @@ -229,3 +251,9 @@ __dpct_inline__ void copy_kernel( out.values[i * out.stride + j] = in.values[i * in.stride + j]; } } + + +} // namespace batch_single_kernels +} // namespace GKO_DEVICE_NAMESPACE +} // namespace kernels +} // namespace gko diff --git a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp index bb84283b49f..7dc8f3ec23b 100644 --- a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp +++ b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp @@ -13,6 +13,7 @@ #include "core/base/batch_struct.hpp" #include "core/matrix/batch_struct.hpp" #include "core/solver/batch_dispatch.hpp" +#include "dpcpp/base/batch_multi_vector_kernels.hpp" #include "dpcpp/base/batch_struct.hpp" #include "dpcpp/base/config.hpp" #include "dpcpp/base/dim3.dp.hpp" @@ -36,7 +37,6 @@ namespace dpcpp { namespace batch_bicgstab { -#include "dpcpp/base/batch_multi_vector_kernels.hpp.inc" #include "dpcpp/matrix/batch_csr_kernels.hpp.inc" #include "dpcpp/matrix/batch_dense_kernels.hpp.inc" #include "dpcpp/matrix/batch_ell_kernels.hpp.inc" diff --git a/dpcpp/solver/batch_bicgstab_kernels.hpp.inc b/dpcpp/solver/batch_bicgstab_kernels.hpp.inc index ad7eaeff556..f5a88e9d59d 100644 --- a/dpcpp/solver/batch_bicgstab_kernels.hpp.inc +++ b/dpcpp/solver/batch_bicgstab_kernels.hpp.inc @@ -39,11 +39,13 @@ __dpct_inline__ void initialize( item_ct1.barrier(sycl::access::fence_space::global_and_local); if (sg_id == 0) { - single_rhs_compute_norm2_sg(num_rows, r_shared_entry, res_norm, - item_ct1); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_compute_norm2_sg(num_rows, r_shared_entry, res_norm, + item_ct1); } else if (sg_id == 1) { - single_rhs_compute_norm2_sg(num_rows, b_global_entry, rhs_norm, - item_ct1); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_compute_norm2_sg(num_rows, b_global_entry, rhs_norm, + item_ct1); } item_ct1.barrier(sycl::access::fence_space::global_and_local); @@ -86,8 +88,9 @@ __dpct_inline__ void compute_alpha(const int num_rows, const ValueType& rho_new, const auto sg_id = sg.get_group_id(); const auto tid = item_ct1.get_local_linear_id(); if (sg_id == 0) { - single_rhs_compute_conj_dot_sg(num_rows, r_hat_shared_entry, - v_shared_entry, alpha, item_ct1); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_compute_conj_dot_sg(num_rows, r_hat_shared_entry, + v_shared_entry, alpha, item_ct1); } item_ct1.barrier(sycl::access::fence_space::global_and_local); if (tid == 0) { @@ -123,11 +126,13 @@ __dpct_inline__ void compute_omega(const int num_rows, const auto sg_id = sg.get_group_id(); const auto tid = item_ct1.get_local_linear_id(); if (sg_id == 0) { - single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry, s_shared_entry, - omega, item_ct1); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry, + s_shared_entry, omega, item_ct1); } else if (sg_id == 1) { - single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry, t_shared_entry, - temp, item_ct1); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry, + t_shared_entry, temp, item_ct1); } item_ct1.barrier(sycl::access::fence_space::global_and_local); if (tid == 0) { @@ -308,8 +313,9 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, // rho_new = < r_hat , r > = (r_hat)' * (r) if (sg_id == 0) { - single_rhs_compute_conj_dot_sg(num_rows, r_hat_sh, r_sh, - rho_new_sh[0], item_ct1); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_compute_conj_dot_sg(num_rows, r_hat_sh, r_sh, + rho_new_sh[0], item_ct1); } item_ct1.barrier(sycl::access::fence_space::global_and_local); @@ -338,8 +344,9 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, // an estimate of residual norms if (sg_id == 0) { - single_rhs_compute_norm2_sg(num_rows, s_sh, norms_res_sh[0], - item_ct1); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_compute_norm2_sg(num_rows, s_sh, norms_res_sh[0], + item_ct1); } item_ct1.barrier(sycl::access::fence_space::global_and_local); @@ -368,8 +375,9 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, item_ct1.barrier(sycl::access::fence_space::global_and_local); if (sg_id == 0) - single_rhs_compute_norm2_sg(num_rows, r_sh, norms_res_sh[0], - item_ct1); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_compute_norm2_sg(num_rows, r_sh, norms_res_sh[0], + item_ct1); if (tid == group_size - 1) { rho_old_sh[0] = rho_new_sh[0]; } @@ -379,6 +387,7 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, logger.log_iteration(batch_id, iter, norms_res_sh[0]); // copy x back to global memory - copy_kernel(num_rows, x_sh, x_global_entry, item_ct1); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::copy_kernel( + num_rows, x_sh, x_global_entry, item_ct1); item_ct1.barrier(sycl::access::fence_space::global_and_local); } diff --git a/dpcpp/solver/batch_cg_kernels.dp.cpp b/dpcpp/solver/batch_cg_kernels.dp.cpp index 61591f9efb6..f25d8266803 100644 --- a/dpcpp/solver/batch_cg_kernels.dp.cpp +++ b/dpcpp/solver/batch_cg_kernels.dp.cpp @@ -13,6 +13,7 @@ #include "core/base/batch_struct.hpp" #include "core/matrix/batch_struct.hpp" #include "core/solver/batch_dispatch.hpp" +#include "dpcpp/base/batch_multi_vector_kernels.hpp" #include "dpcpp/base/batch_struct.hpp" #include "dpcpp/base/config.hpp" #include "dpcpp/base/dim3.dp.hpp" @@ -36,7 +37,6 @@ namespace dpcpp { namespace batch_cg { -#include "dpcpp/base/batch_multi_vector_kernels.hpp.inc" #include "dpcpp/matrix/batch_csr_kernels.hpp.inc" #include "dpcpp/matrix/batch_dense_kernels.hpp.inc" #include "dpcpp/matrix/batch_ell_kernels.hpp.inc" diff --git a/dpcpp/solver/batch_cg_kernels.hpp.inc b/dpcpp/solver/batch_cg_kernels.hpp.inc index cef6e620b64..7a91bcb2bbf 100644 --- a/dpcpp/solver/batch_cg_kernels.hpp.inc +++ b/dpcpp/solver/batch_cg_kernels.hpp.inc @@ -40,11 +40,13 @@ __dpct_inline__ void initialize( // Compute norms of rhs // and rho_old = r' * z if (sg_id == 0) { - single_rhs_compute_norm2_sg(num_rows, b_global_entry, rhs_norms, - item_ct1); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_compute_norm2_sg(num_rows, b_global_entry, rhs_norms, + item_ct1); } else if (sg_id == 1) { - single_rhs_compute_conj_dot_sg(num_rows, r_shared_entry, z_shared_entry, - rho_old, item_ct1); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_compute_conj_dot_sg(num_rows, r_shared_entry, + z_shared_entry, rho_old, item_ct1); } item_ct1.barrier(sycl::access::fence_space::global_and_local); @@ -80,9 +82,10 @@ __dpct_inline__ void update_x_and_r( auto sg = item_ct1.get_sub_group(); const auto tid = item_ct1.get_local_linear_id(); if (sg.get_group_id() == 0) { - single_rhs_compute_conj_dot_sg(num_rows, p_shared_entry, - Ap_shared_entry, alpha_shared_entry, - item_ct1); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_compute_conj_dot_sg(num_rows, p_shared_entry, + Ap_shared_entry, alpha_shared_entry, + item_ct1); } item_ct1.barrier(sycl::access::fence_space::global_and_local); if (tid == 0) { @@ -221,8 +224,9 @@ __dpct_inline__ void apply_kernel( // rho_new = (r)' * (z) if (sg_id == 0) { - single_rhs_compute_conj_dot_sg(num_rows, r_sh, z_sh, rho_new_sh[0], - item_ct1); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + single_rhs_compute_conj_dot_sg(num_rows, r_sh, z_sh, + rho_new_sh[0], item_ct1); } item_ct1.barrier(sycl::access::fence_space::global_and_local); @@ -239,6 +243,7 @@ __dpct_inline__ void apply_kernel( logger.log_iteration(batch_id, iter, norms_res_sh[0]); // copy x back to global memory - copy_kernel(num_rows, x_sh, x_global_entry, item_ct1); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::copy_kernel( + num_rows, x_sh, x_global_entry, item_ct1); item_ct1.barrier(sycl::access::fence_space::global_and_local); } diff --git a/hip/CMakeLists.txt b/hip/CMakeLists.txt index 46b2d7bd19b..84bba295120 100644 --- a/hip/CMakeLists.txt +++ b/hip/CMakeLists.txt @@ -5,7 +5,6 @@ add_instantiation_files(${PROJECT_SOURCE_DIR}/common/cuda_hip matrix/fbcsr_kerne # we don't split up the dense kernels into distinct compilations list(APPEND GKO_UNIFIED_COMMON_SOURCES ${PROJECT_SOURCE_DIR}/common/unified/matrix/dense_kernels.instantiate.cpp) set(GINKGO_HIP_SOURCES - base/batch_multi_vector_kernels.hip.cpp base/device.hip.cpp base/exception.hip.cpp base/executor.hip.cpp diff --git a/hip/base/batch_multi_vector_kernels.hip.cpp b/hip/base/batch_multi_vector_kernels.hip.cpp deleted file mode 100644 index 701f4655a9a..00000000000 --- a/hip/base/batch_multi_vector_kernels.hip.cpp +++ /dev/null @@ -1,56 +0,0 @@ -// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors -// -// SPDX-License-Identifier: BSD-3-Clause - -#include "core/base/batch_multi_vector_kernels.hpp" - -#include -#include - -#include -#include - -#include "common/cuda_hip/base/blas_bindings.hpp" -#include "common/cuda_hip/base/config.hpp" -#include "common/cuda_hip/base/pointer_mode_guard.hpp" -#include "common/cuda_hip/base/runtime.hpp" -#include "common/cuda_hip/base/thrust.hpp" -#include "common/cuda_hip/components/cooperative_groups.hpp" -#include "common/cuda_hip/components/reduction.hpp" -#include "common/cuda_hip/components/thread_ids.hpp" -#include "common/cuda_hip/components/uninitialized_array.hpp" -#include "core/base/batch_struct.hpp" -#include "hip/base/batch_struct.hip.hpp" - - -namespace gko { -namespace kernels { -namespace hip { -/** - * @brief The MultiVector matrix format namespace. - * - * @ingroup batch_multi_vector - */ -namespace batch_multi_vector { - - -constexpr auto default_block_size = 256; -constexpr int sm_oversubscription = 4; - - -// clang-format off - -// NOTE: DO NOT CHANGE THE ORDERING OF THE INCLUDES - -#include "common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc" - - -#include "common/cuda_hip/base/batch_multi_vector_kernel_launcher.hpp.inc" - -// clang-format on - - -} // namespace batch_multi_vector -} // namespace hip -} // namespace kernels -} // namespace gko diff --git a/hip/solver/batch_bicgstab_kernels.hip.cpp b/hip/solver/batch_bicgstab_kernels.hip.cpp index ca49fa5eb9c..1c1be8b21f7 100644 --- a/hip/solver/batch_bicgstab_kernels.hip.cpp +++ b/hip/solver/batch_bicgstab_kernels.hip.cpp @@ -10,6 +10,7 @@ #include #include +#include "common/cuda_hip/base/batch_multi_vector_kernels.hpp" #include "common/cuda_hip/base/config.hpp" #include "common/cuda_hip/base/math.hpp" #include "common/cuda_hip/base/runtime.hpp" @@ -42,7 +43,6 @@ constexpr int sm_oversubscription = 4; namespace batch_bicgstab { -#include "common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc" #include "common/cuda_hip/matrix/batch_csr_kernels.hpp.inc" #include "common/cuda_hip/matrix/batch_dense_kernels.hpp.inc" #include "common/cuda_hip/matrix/batch_ell_kernels.hpp.inc" diff --git a/hip/solver/batch_cg_kernels.hip.cpp b/hip/solver/batch_cg_kernels.hip.cpp index 3a1642edfea..c860286c17c 100644 --- a/hip/solver/batch_cg_kernels.hip.cpp +++ b/hip/solver/batch_cg_kernels.hip.cpp @@ -10,6 +10,7 @@ #include #include +#include "common/cuda_hip/base/batch_multi_vector_kernels.hpp" #include "common/cuda_hip/base/config.hpp" #include "common/cuda_hip/base/math.hpp" #include "common/cuda_hip/base/runtime.hpp" @@ -42,7 +43,6 @@ constexpr int sm_oversubscription = 4; namespace batch_cg { -#include "common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc" #include "common/cuda_hip/matrix/batch_csr_kernels.hpp.inc" #include "common/cuda_hip/matrix/batch_dense_kernels.hpp.inc" #include "common/cuda_hip/matrix/batch_ell_kernels.hpp.inc" diff --git a/omp/base/batch_multi_vector_kernels.cpp b/omp/base/batch_multi_vector_kernels.cpp index 395bf96cc7a..f740e3c32f0 100644 --- a/omp/base/batch_multi_vector_kernels.cpp +++ b/omp/base/batch_multi_vector_kernels.cpp @@ -10,24 +10,18 @@ #include #include +#include "common/unified/base/kernel_launch.hpp" #include "core/components/prefix_sum_kernels.hpp" +#include "reference/base/batch_multi_vector_kernels.hpp" #include "reference/base/batch_struct.hpp" namespace gko { namespace kernels { -namespace omp { -/** - * @brief The batch::MultiVector matrix format namespace. - * @ref batch::MultiVector - * @ingroup batch_multi_vector - */ +namespace GKO_DEVICE_NAMESPACE { namespace batch_multi_vector { -#include "reference/base/batch_multi_vector_kernels.hpp.inc" - - template void scale(std::shared_ptr exec, const batch::MultiVector* const alpha, @@ -39,7 +33,7 @@ void scale(std::shared_ptr exec, for (size_type batch = 0; batch < x->get_num_batch_items(); ++batch) { const auto alpha_b = gko::batch::extract_batch_item(alpha_ub, batch); const auto x_b = gko::batch::extract_batch_item(x_ub, batch); - scale_kernel(alpha_b, x_b); + batch_single_kernels::scale_kernel(alpha_b, x_b); } } @@ -61,7 +55,7 @@ void add_scaled(std::shared_ptr exec, const auto alpha_b = gko::batch::extract_batch_item(alpha_ub, batch); const auto x_b = gko::batch::extract_batch_item(x_ub, batch); const auto y_b = gko::batch::extract_batch_item(y_ub, batch); - add_scaled_kernel(alpha_b, x_b, y_b); + batch_single_kernels::add_scaled_kernel(alpha_b, x_b, y_b); } } @@ -83,7 +77,7 @@ void compute_dot(std::shared_ptr exec, const auto res_b = gko::batch::extract_batch_item(res_ub, batch); const auto x_b = gko::batch::extract_batch_item(x_ub, batch); const auto y_b = gko::batch::extract_batch_item(y_ub, batch); - compute_dot_product_kernel(x_b, y_b, res_b); + batch_single_kernels::compute_dot_product_kernel(x_b, y_b, res_b); } } @@ -105,7 +99,7 @@ void compute_conj_dot(std::shared_ptr exec, const auto res_b = gko::batch::extract_batch_item(res_ub, batch); const auto x_b = gko::batch::extract_batch_item(x_ub, batch); const auto y_b = gko::batch::extract_batch_item(y_ub, batch); - compute_conj_dot_product_kernel(x_b, y_b, res_b); + batch_single_kernels::compute_conj_dot_product_kernel(x_b, y_b, res_b); } } @@ -124,7 +118,7 @@ void compute_norm2(std::shared_ptr exec, for (size_type batch = 0; batch < result->get_num_batch_items(); ++batch) { const auto res_b = gko::batch::extract_batch_item(res_ub, batch); const auto x_b = gko::batch::extract_batch_item(x_ub, batch); - compute_norm2_kernel(x_b, res_b); + batch_single_kernels::compute_norm2_kernel(x_b, res_b); } } @@ -143,7 +137,7 @@ void copy(std::shared_ptr exec, for (size_type batch = 0; batch < x->get_num_batch_items(); ++batch) { const auto result_b = gko::batch::extract_batch_item(result_ub, batch); const auto x_b = gko::batch::extract_batch_item(x_ub, batch); - copy_kernel(x_b, result_b); + batch_single_kernels::copy_kernel(x_b, result_b); } } @@ -151,6 +145,6 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR_COPY_KERNEL); } // namespace batch_multi_vector -} // namespace omp +} // namespace GKO_DEVICE_NAMESPACE } // namespace kernels } // namespace gko diff --git a/omp/solver/batch_bicgstab_kernels.cpp b/omp/solver/batch_bicgstab_kernels.cpp index 81df9c45e51..c245f284106 100644 --- a/omp/solver/batch_bicgstab_kernels.cpp +++ b/omp/solver/batch_bicgstab_kernels.cpp @@ -9,6 +9,7 @@ #include #include "core/solver/batch_dispatch.hpp" +#include "reference/base/batch_multi_vector_kernels.hpp" namespace gko { @@ -28,7 +29,6 @@ namespace { constexpr int max_num_rhs = 1; -#include "reference/base/batch_multi_vector_kernels.hpp.inc" #include "reference/matrix/batch_csr_kernels.hpp.inc" #include "reference/matrix/batch_dense_kernels.hpp.inc" #include "reference/matrix/batch_ell_kernels.hpp.inc" diff --git a/omp/solver/batch_cg_kernels.cpp b/omp/solver/batch_cg_kernels.cpp index 51c794ab597..55d6ee29321 100644 --- a/omp/solver/batch_cg_kernels.cpp +++ b/omp/solver/batch_cg_kernels.cpp @@ -9,6 +9,7 @@ #include #include "core/solver/batch_dispatch.hpp" +#include "reference/base/batch_multi_vector_kernels.hpp" namespace gko { @@ -28,7 +29,6 @@ namespace { constexpr int max_num_rhs = 1; -#include "reference/base/batch_multi_vector_kernels.hpp.inc" #include "reference/matrix/batch_csr_kernels.hpp.inc" #include "reference/matrix/batch_dense_kernels.hpp.inc" #include "reference/matrix/batch_ell_kernels.hpp.inc" diff --git a/reference/base/batch_multi_vector_kernels.cpp b/reference/base/batch_multi_vector_kernels.cpp index b0d20a6b826..f5e1c653054 100644 --- a/reference/base/batch_multi_vector_kernels.cpp +++ b/reference/base/batch_multi_vector_kernels.cpp @@ -10,24 +10,21 @@ #include #include + +#define GKO_DEVICE_NAMESPACE reference + + #include "core/base/batch_struct.hpp" +#include "reference/base/batch_multi_vector_kernels.hpp" #include "reference/base/batch_struct.hpp" namespace gko { namespace kernels { -namespace reference { -/** - * @brief The batch::MultiVector matrix format namespace. - * @ref batch::MultiVector - * @ingroup batch_multi_vector - */ +namespace GKO_DEVICE_NAMESPACE { namespace batch_multi_vector { -#include "reference/base/batch_multi_vector_kernels.hpp.inc" - - template void scale(std::shared_ptr exec, const batch::MultiVector* alpha, @@ -38,7 +35,7 @@ void scale(std::shared_ptr exec, for (size_type batch = 0; batch < x->get_num_batch_items(); ++batch) { const auto alpha_b = batch::extract_batch_item(alpha_ub, batch); const auto x_b = batch::extract_batch_item(x_ub, batch); - scale_kernel(alpha_b, x_b); + batch_single_kernels::scale_kernel(alpha_b, x_b); } } @@ -59,7 +56,7 @@ void add_scaled(std::shared_ptr exec, const auto alpha_b = batch::extract_batch_item(alpha_ub, batch); const auto x_b = batch::extract_batch_item(x_ub, batch); const auto y_b = batch::extract_batch_item(y_ub, batch); - add_scaled_kernel(alpha_b, x_b, y_b); + batch_single_kernels::add_scaled_kernel(alpha_b, x_b, y_b); } } @@ -80,7 +77,7 @@ void compute_dot(std::shared_ptr exec, const auto res_b = batch::extract_batch_item(res_ub, batch); const auto x_b = batch::extract_batch_item(x_ub, batch); const auto y_b = batch::extract_batch_item(y_ub, batch); - compute_dot_product_kernel(x_b, y_b, res_b); + batch_single_kernels::compute_dot_product_kernel(x_b, y_b, res_b); } } @@ -101,7 +98,7 @@ void compute_conj_dot(std::shared_ptr exec, const auto res_b = batch::extract_batch_item(res_ub, batch); const auto x_b = batch::extract_batch_item(x_ub, batch); const auto y_b = batch::extract_batch_item(y_ub, batch); - compute_conj_dot_product_kernel(x_b, y_b, res_b); + batch_single_kernels::compute_conj_dot_product_kernel(x_b, y_b, res_b); } } @@ -119,7 +116,7 @@ void compute_norm2(std::shared_ptr exec, for (size_type batch = 0; batch < result->get_num_batch_items(); ++batch) { const auto res_b = batch::extract_batch_item(res_ub, batch); const auto x_b = batch::extract_batch_item(x_ub, batch); - compute_norm2_kernel(x_b, res_b); + batch_single_kernels::compute_norm2_kernel(x_b, res_b); } } @@ -137,7 +134,7 @@ void copy(std::shared_ptr exec, for (size_type batch = 0; batch < x->get_num_batch_items(); ++batch) { const auto result_b = batch::extract_batch_item(result_ub, batch); const auto x_b = batch::extract_batch_item(x_ub, batch); - copy_kernel(x_b, result_b); + batch_single_kernels::copy_kernel(x_b, result_b); } } @@ -145,6 +142,6 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR_COPY_KERNEL); } // namespace batch_multi_vector -} // namespace reference +} // namespace GKO_DEVICE_NAMESPACE } // namespace kernels } // namespace gko diff --git a/reference/base/batch_multi_vector_kernels.hpp.inc b/reference/base/batch_multi_vector_kernels.hpp similarity index 90% rename from reference/base/batch_multi_vector_kernels.hpp.inc rename to reference/base/batch_multi_vector_kernels.hpp index 24e59664b74..88f531f29cc 100644 --- a/reference/base/batch_multi_vector_kernels.hpp.inc +++ b/reference/base/batch_multi_vector_kernels.hpp @@ -2,6 +2,20 @@ // // SPDX-License-Identifier: BSD-3-Clause +#include +#include +#include +#include + +#include "reference/base/batch_struct.hpp" + + +namespace gko { +namespace kernels { +namespace GKO_DEVICE_NAMESPACE { +namespace batch_single_kernels { + + template inline void scale_kernel( const gko::batch::multi_vector::batch_item& alpha, @@ -129,3 +143,9 @@ inline void copy_kernel( out.values[i * out.stride + j] = in.values[i * in.stride + j]; } } + + +} // namespace batch_single_kernels +} // namespace GKO_DEVICE_NAMESPACE +} // namespace kernels +} // namespace gko diff --git a/reference/solver/batch_bicgstab_kernels.cpp b/reference/solver/batch_bicgstab_kernels.cpp index 97de157fb90..e68caffa936 100644 --- a/reference/solver/batch_bicgstab_kernels.cpp +++ b/reference/solver/batch_bicgstab_kernels.cpp @@ -5,6 +5,7 @@ #include "core/solver/batch_bicgstab_kernels.hpp" #include "core/solver/batch_dispatch.hpp" +#include "reference/base/batch_multi_vector_kernels.hpp" namespace gko { @@ -26,7 +27,6 @@ namespace { constexpr int max_num_rhs = 1; -#include "reference/base/batch_multi_vector_kernels.hpp.inc" #include "reference/matrix/batch_csr_kernels.hpp.inc" #include "reference/matrix/batch_dense_kernels.hpp.inc" #include "reference/matrix/batch_ell_kernels.hpp.inc" diff --git a/reference/solver/batch_bicgstab_kernels.hpp.inc b/reference/solver/batch_bicgstab_kernels.hpp.inc index b61db3669ef..1f8537ab66d 100644 --- a/reference/solver/batch_bicgstab_kernels.hpp.inc +++ b/reference/solver/batch_bicgstab_kernels.hpp.inc @@ -25,17 +25,20 @@ inline void initialize( alpha_entry.values[0] = one(); // Compute norms of rhs - compute_norm2_kernel(b_entry, rhs_norms_entry); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + compute_norm2_kernel(b_entry, rhs_norms_entry); // r = b - copy_kernel(b_entry, r_entry); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::copy_kernel( + b_entry, r_entry); // r = b - A*x advanced_apply_kernel(static_cast(-1.0), A_entry, gko::batch::to_const(x_entry), static_cast(1.0), r_entry); - compute_norm2_kernel(gko::batch::to_const(r_entry), - res_norms_entry); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + compute_norm2_kernel(gko::batch::to_const(r_entry), + res_norms_entry); for (int r = 0; r < p_entry.num_rows; r++) { r_hat_entry.values[r * r_hat_entry.stride] = @@ -75,7 +78,9 @@ inline void compute_alpha( const gko::batch::multi_vector::batch_item& v_entry, const gko::batch::multi_vector::batch_item& alpha_entry) { - compute_dot_product_kernel(r_hat_entry, v_entry, alpha_entry); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + compute_dot_product_kernel(r_hat_entry, v_entry, + alpha_entry); alpha_entry.values[0] = rho_new_entry.values[0] / alpha_entry.values[0]; } @@ -102,8 +107,10 @@ inline void compute_omega( const gko::batch::multi_vector::batch_item& temp_entry, const gko::batch::multi_vector::batch_item& omega_entry) { - compute_dot_product_kernel(t_entry, s_entry, omega_entry); - compute_dot_product_kernel(t_entry, t_entry, temp_entry); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + compute_dot_product_kernel(t_entry, s_entry, omega_entry); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + compute_dot_product_kernel(t_entry, t_entry, temp_entry); omega_entry.values[0] /= temp_entry.values[0]; } @@ -246,9 +253,10 @@ inline void batch_entry_bicgstab_impl( } // rho_new = < r_hat , r > = (r_hat)' * (r) - compute_dot_product_kernel(gko::batch::to_const(r_hat_entry), - gko::batch::to_const(r_entry), - rho_new_entry); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + compute_dot_product_kernel( + gko::batch::to_const(r_hat_entry), + gko::batch::to_const(r_entry), rho_new_entry); // beta = (rho_new / rho_old)*(alpha / omega) // p = r + beta*(p - omega * v) @@ -277,8 +285,9 @@ inline void batch_entry_bicgstab_impl( gko::batch::to_const(v_entry), s_entry); // an estimate of residual norms - compute_norm2_kernel(gko::batch::to_const(s_entry), - res_norms_entry); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + compute_norm2_kernel(gko::batch::to_const(s_entry), + res_norms_entry); if (stop.check_converged(res_norms_entry.values)) { // update x for the systems @@ -310,11 +319,13 @@ inline void batch_entry_bicgstab_impl( gko::batch::to_const(s_entry), gko::batch::to_const(t_entry), x_entry, r_entry); - compute_norm2_kernel(gko::batch::to_const(r_entry), - res_norms_entry); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + compute_norm2_kernel(gko::batch::to_const(r_entry), + res_norms_entry); // rho_old = rho_new - copy_kernel(gko::batch::to_const(rho_new_entry), rho_old_entry); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::copy_kernel( + gko::batch::to_const(rho_new_entry), rho_old_entry); } logger.log_iteration(batch_item_id, iter, res_norms_entry.values[0]); diff --git a/reference/solver/batch_cg_kernels.cpp b/reference/solver/batch_cg_kernels.cpp index 290fbc3718b..785a7a868a2 100644 --- a/reference/solver/batch_cg_kernels.cpp +++ b/reference/solver/batch_cg_kernels.cpp @@ -5,6 +5,7 @@ #include "core/solver/batch_cg_kernels.hpp" #include "core/solver/batch_dispatch.hpp" +#include "reference/base/batch_multi_vector_kernels.hpp" namespace gko { @@ -26,7 +27,6 @@ namespace { constexpr int max_num_rhs = 1; -#include "reference/base/batch_multi_vector_kernels.hpp.inc" #include "reference/matrix/batch_csr_kernels.hpp.inc" #include "reference/matrix/batch_dense_kernels.hpp.inc" #include "reference/matrix/batch_ell_kernels.hpp.inc" diff --git a/reference/solver/batch_cg_kernels.hpp.inc b/reference/solver/batch_cg_kernels.hpp.inc index b3df5ba97fd..ca88940cd69 100644 --- a/reference/solver/batch_cg_kernels.hpp.inc +++ b/reference/solver/batch_cg_kernels.hpp.inc @@ -26,10 +26,12 @@ inline void initialize( } // Compute norms of rhs - compute_norm2_kernel(b_entry, rhs_norms_entry); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + compute_norm2_kernel(b_entry, rhs_norms_entry); // r = b - copy_kernel(b_entry, r_entry); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::copy_kernel( + b_entry, r_entry); // r = b - A*x advanced_apply_kernel(static_cast(-1.0), A_entry, @@ -46,7 +48,8 @@ inline void update_p( const gko::batch::multi_vector::batch_item& p_entry) { if (rho_old_entry.values[0] == zero()) { - copy_kernel(z_entry, p_entry); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::copy_kernel( + z_entry, p_entry); return; } const ValueType beta = rho_new_entry.values[0] / rho_old_entry.values[0]; @@ -67,7 +70,9 @@ inline void update_x_and_r( const gko::batch::multi_vector::batch_item& x_entry, const gko::batch::multi_vector::batch_item& r_entry) { - compute_conj_dot_product_kernel(p_entry, Ap_entry, alpha_entry); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + compute_conj_dot_product_kernel(p_entry, Ap_entry, + alpha_entry); const ValueType temp = rho_old_entry.values[0] / alpha_entry.values[0]; for (int row = 0; row < r_entry.num_rows; row++) { @@ -154,9 +159,10 @@ inline void batch_entry_cg_impl( prec.apply(gko::batch::to_const(r_entry), z_entry); // rho_new = < r , z > = (r)' * (z) - compute_conj_dot_product_kernel( - gko::batch::to_const(r_entry), gko::batch::to_const(z_entry), - rho_new_entry); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: + compute_conj_dot_product_kernel( + gko::batch::to_const(r_entry), gko::batch::to_const(z_entry), + rho_new_entry); ++iter; // use implicit residual norms res_norms_entry.values[0] = sqrt(abs(rho_new_entry.values[0])); @@ -185,7 +191,8 @@ inline void batch_entry_cg_impl( gko::batch::to_const(Ap_entry), alpha_entry, x_entry, r_entry); // rho_old = rho_new - copy_kernel(gko::batch::to_const(rho_new_entry), rho_old_entry); + gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::copy_kernel( + gko::batch::to_const(rho_new_entry), rho_old_entry); } logger.log_iteration(batch_item_id, iter, res_norms_entry.values[0]);