From b653d3b2df31eb060741137975593e0dffe68801 Mon Sep 17 00:00:00 2001 From: Pratik Nayak Date: Mon, 30 Oct 2023 12:13:02 +0100 Subject: [PATCH] Review updates Co-authored-by: Yu-Hsiang Tsai Co-authored-by: Marcel Koch --- .../base/batch_multi_vector_kernels.hpp.inc | 2 +- .../preconditioner/batch_identity.hpp.inc | 3 +- .../solver/batch_bicgstab_kernels.hpp.inc | 26 ++-- core/solver/batch_bicgstab_kernels.hpp | 18 +-- cuda/matrix/batch_struct.hpp | 3 +- cuda/solver/batch_bicgstab_kernels.cu | 5 +- dpcpp/base/batch_multi_vector_kernels.dp.cpp | 137 +++++++++--------- dpcpp/base/batch_multi_vector_kernels.hpp.inc | 14 +- dpcpp/matrix/batch_struct.hpp | 2 +- dpcpp/preconditioner/batch_identity.hpp.inc | 8 +- dpcpp/solver/batch_bicgstab_kernels.dp.cpp | 25 ++-- dpcpp/solver/batch_bicgstab_kernels.hpp.inc | 112 +++++++------- hip/matrix/batch_struct.hip.hpp | 3 +- reference/matrix/batch_struct.hpp | 2 +- 14 files changed, 177 insertions(+), 183 deletions(-) diff --git a/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc b/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc index 72d58ecf5b3..1e0cb3bbcff 100644 --- a/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc +++ b/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc @@ -104,7 +104,7 @@ __global__ __launch_bounds__( template -__device__ __forceinline__ void single_rhs_compute_dot(Group subgroup, +__device__ __forceinline__ void single_rhs_compute_conj_dot(Group subgroup, const int num_rows, const ValueType* x, const ValueType* y, diff --git a/common/cuda_hip/preconditioner/batch_identity.hpp.inc b/common/cuda_hip/preconditioner/batch_identity.hpp.inc index 1b1fb7b5482..d3fa7fe737c 100644 --- a/common/cuda_hip/preconditioner/batch_identity.hpp.inc +++ b/common/cuda_hip/preconditioner/batch_identity.hpp.inc @@ -47,7 +47,8 @@ public: __device__ __forceinline__ void generate( size_type, - const gko::batch::matrix::ell::batch_item&, + const gko::batch::matrix::ell::batch_item&, ValueType*) {} diff --git a/common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc b/common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc index 0f666f205e8..faee2e069a7 100644 --- a/common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc +++ b/common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc @@ -38,7 +38,8 @@ __device__ __forceinline__ void initialize( const ValueType* const x_global_entry, ValueType& rho_old, ValueType& omega, ValueType& alpha, ValueType* const x_shared_entry, ValueType* const r_shared_entry, ValueType* const r_hat_shared_entry, - ValueType* const p_shared_entry, ValueType* const v_shared_entry, + ValueType* const p_shared_entry, ValueType* const p_hat_shared_entry, + ValueType* const v_shared_entry, typename gko::remove_complex& rhs_norm, typename gko::remove_complex& res_norm) { @@ -70,6 +71,7 @@ __device__ __forceinline__ void initialize( for (int iz = threadIdx.x; iz < num_rows; iz += blockDim.x) { r_hat_shared_entry[iz] = r_shared_entry[iz]; p_shared_entry[iz] = zero(); + p_hat_shared_entry[iz] = zero(); v_shared_entry[iz] = zero(); } } @@ -82,8 +84,8 @@ __device__ __forceinline__ void update_p( const ValueType* const r_shared_entry, const ValueType* const v_shared_entry, ValueType* const p_shared_entry) { + const ValueType beta = (rho_new / rho_old) * (alpha / omega); for (int r = threadIdx.x; r < num_rows; r += blockDim.x) { - const ValueType beta = (rho_new / rho_old) * (alpha / omega); p_shared_entry[r] = r_shared_entry[r] + beta * (p_shared_entry[r] - omega * v_shared_entry[r]); @@ -97,8 +99,8 @@ __device__ __forceinline__ void compute_alpha( const ValueType* const v_shared_entry, ValueType& alpha) { if (threadIdx.x / config::warp_size == 0) { - single_rhs_compute_dot(subgroup, num_rows, r_hat_shared_entry, - v_shared_entry, alpha); + single_rhs_compute_conj_dot(subgroup, num_rows, r_hat_shared_entry, + v_shared_entry, alpha); } __syncthreads(); if (threadIdx.x == 0) { @@ -126,11 +128,11 @@ __device__ __forceinline__ void compute_omega( const ValueType* const s_shared_entry, ValueType& temp, ValueType& omega) { if (threadIdx.x / config::warp_size == 0) { - single_rhs_compute_dot(subgroup, num_rows, t_shared_entry, - s_shared_entry, omega); + single_rhs_compute_conj_dot(subgroup, num_rows, t_shared_entry, + s_shared_entry, omega); } else if (threadIdx.x / config::warp_size == 1) { - single_rhs_compute_dot(subgroup, num_rows, t_shared_entry, - t_shared_entry, temp); + single_rhs_compute_conj_dot(subgroup, num_rows, t_shared_entry, + t_shared_entry, temp); } __syncthreads(); @@ -278,10 +280,12 @@ __global__ void apply_kernel( // compute residual norms // r_hat = r // p = 0 + // p_hat = 0 // v = 0 initialize(subgroup, num_rows, mat_entry, b_entry_ptr, x_gl_entry_ptr, rho_old_sh[0], omega_sh[0], alpha_sh[0], x_sh, r_sh, - r_hat_sh, p_sh, v_sh, norms_rhs_sh[0], norms_res_sh[0]); + r_hat_sh, p_sh, p_hat_sh, v_sh, norms_rhs_sh[0], + norms_res_sh[0]); __syncthreads(); // stopping criterion object @@ -296,8 +300,8 @@ __global__ void apply_kernel( // rho_new = < r_hat , r > = (r_hat)' * (r) if (threadIdx.x / config::warp_size == 0) { - single_rhs_compute_dot(subgroup, num_rows, r_hat_sh, r_sh, - rho_new_sh[0]); + single_rhs_compute_conj_dot(subgroup, num_rows, r_hat_sh, r_sh, + rho_new_sh[0]); } __syncthreads(); diff --git a/core/solver/batch_bicgstab_kernels.hpp b/core/solver/batch_bicgstab_kernels.hpp index b3a5faf0a49..f43a7f2ddd5 100644 --- a/core/solver/batch_bicgstab_kernels.hpp +++ b/core/solver/batch_bicgstab_kernels.hpp @@ -115,8 +115,7 @@ void set_gmem_stride_bytes(storage_config& sconf, } // align global memory chunks sconf.gmem_stride_bytes = - gmem_stride > 0 ? ((gmem_stride - 1) / align_bytes + 1) * align_bytes - : 0; + gmem_stride > 0 ? ceildiv(gmem_stride, align_bytes) * align_bytes : 0; } @@ -143,8 +142,8 @@ void set_gmem_stride_bytes(storage_config& sconf, * - rhs_norms * - res_norms * - * @param shared_mem_per_blk The amount of shared memory per block to use for - * keeping intermediate vectors. In case keeping the matrix in L1 cache etc. + * @param available_shared_mem The amount of shared memory per block to use + * for keeping intermediate vectors. In case keeping the matrix in L1 cache etc. * should be prioritized, the cache configuration must be updated separately * and the needed space should be subtracted before passing to this * function. @@ -154,7 +153,7 @@ void set_gmem_stride_bytes(storage_config& sconf, * @return A struct containing allocation information specific to Bicgstab. */ template -storage_config compute_shared_storage(const int shared_mem_per_blk, +storage_config compute_shared_storage(const int available_shared_mem, const int num_rows, const int num_nz, const int num_rhs) { @@ -163,10 +162,11 @@ storage_config compute_shared_storage(const int shared_mem_per_blk, const int num_main_vecs = 9; const int prec_storage = Prectype::dynamic_work_size(num_rows, num_nz) * sizeof(ValueType); - int rem_shared = shared_mem_per_blk; - // Set default values. All vecs are in global. + int rem_shared = available_shared_mem; + // Set default values. Initially all vecs are in global memory. + // {prec_shared, n_shared, n_global, gmem_stride_bytes, padded_vec_len} storage_config sconf{false, 0, num_main_vecs, 0, num_rows}; - // If available shared mem, is zero, set all vecs to global. + // If available shared mem is zero, set all vecs to global. if (rem_shared <= 0) { set_gmem_stride_bytes(sconf, vec_size, prec_storage); return sconf; @@ -177,13 +177,13 @@ storage_config compute_shared_storage(const int shared_mem_per_blk, const int num_vecs_shared = min(initial_vecs_available, num_main_vecs); sconf.n_shared += num_vecs_shared; sconf.n_global -= num_vecs_shared; + rem_shared -= num_vecs_shared * vec_size; // Set the storage configuration with preconditioner workspace in global if // there are any vectors in global memory. if (sconf.n_global > 0) { set_gmem_stride_bytes(sconf, vec_size, prec_storage); return sconf; } - rem_shared -= num_vecs_shared * vec_size; // If more shared memory space is available and preconditioner workspace is // needed, enable preconditioner workspace to use shared memory. if (rem_shared >= prec_storage && prec_storage > 0) { diff --git a/cuda/matrix/batch_struct.hpp b/cuda/matrix/batch_struct.hpp index 4a2a1835961..55a30c043e3 100644 --- a/cuda/matrix/batch_struct.hpp +++ b/cuda/matrix/batch_struct.hpp @@ -92,7 +92,8 @@ get_batch_struct(batch::matrix::Dense* const op) * Generates an immutable uniform batch struct from a batch of ell matrices. */ template -inline batch::matrix::ell::uniform_batch, IndexType> +inline batch::matrix::ell::uniform_batch, + const IndexType> get_batch_struct(const batch::matrix::Ell* const op) { return {as_cuda_type(op->get_const_values()), diff --git a/cuda/solver/batch_bicgstab_kernels.cu b/cuda/solver/batch_bicgstab_kernels.cu index 9ecb27aecf2..e2762656abe 100644 --- a/cuda/solver/batch_bicgstab_kernels.cu +++ b/cuda/solver/batch_bicgstab_kernels.cu @@ -101,10 +101,7 @@ int get_num_threads_per_block(std::shared_ptr exec, cudaDeviceGetAttribute(&max_regs_blk, cudaDevAttrMaxRegistersPerBlock, exec->get_device_id()); const int max_threads_regs = - ((max_regs_blk / - static_cast((static_cast(num_regs_used)))) / - warp_sz) * - warp_sz; + ((max_regs_blk / static_cast(num_regs_used)) / warp_sz) * warp_sz; int max_threads = std::min(max_threads_regs, device_max_threads); max_threads = max_threads <= 1024 ? max_threads : 1024; return std::min(num_warps * warp_sz, max_threads); diff --git a/dpcpp/base/batch_multi_vector_kernels.dp.cpp b/dpcpp/base/batch_multi_vector_kernels.dp.cpp index 3068b654b75..c9809696889 100644 --- a/dpcpp/base/batch_multi_vector_kernels.dp.cpp +++ b/dpcpp/base/batch_multi_vector_kernels.dp.cpp @@ -87,7 +87,7 @@ void scale(std::shared_ptr exec, long max_group_size = device.get_info(); int group_size = - std::max(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size, + std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size, max_group_size); const dim3 block(group_size); @@ -141,7 +141,7 @@ void add_scaled(std::shared_ptr exec, long max_group_size = device.get_info(); int group_size = - std::max(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size, + std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size, max_group_size); const dim3 block(group_size); @@ -202,7 +202,7 @@ void compute_dot(std::shared_ptr exec, long max_group_size = device.get_info(); int group_size = - std::max(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size, + std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size, max_group_size); const dim3 block(group_size); @@ -210,41 +210,37 @@ void compute_dot(std::shared_ptr exec, if (x->get_common_size()[1] == 1) { exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl_nd_range(grid, block), - [=](sycl::nd_item<3> item_ct1) - [[sycl::reqd_sub_group_size(max_subgroup_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto x_b = - batch::extract_batch_item(x_ub, group_id); - const auto y_b = - batch::extract_batch_item(y_ub, group_id); - const auto res_b = - batch::extract_batch_item(res_ub, group_id); - single_rhs_compute_dot_sg(x_b.num_rows, x_b.values, - y_b.values, res_b.values[0], - item_ct1); - }); + sycl_nd_range(grid, block), [= + ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( + max_subgroup_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto x_b = batch::extract_batch_item(x_ub, group_id); + const auto y_b = batch::extract_batch_item(y_ub, group_id); + const auto res_b = + batch::extract_batch_item(res_ub, group_id); + single_rhs_compute_conj_dot_sg(x_b.num_rows, x_b.values, + y_b.values, res_b.values[0], + item_ct1); + }); }); } else { // TODO: Remove reqd_sub_group size and use sycl::reduce_over_group exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl_nd_range(grid, block), - [=](sycl::nd_item<3> item_ct1) - [[sycl::reqd_sub_group_size(max_subgroup_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto x_b = - batch::extract_batch_item(x_ub, group_id); - const auto y_b = - batch::extract_batch_item(y_ub, group_id); - const auto res_b = - batch::extract_batch_item(res_ub, group_id); - compute_gen_dot_product_kernel( - x_b, y_b, res_b, item_ct1, - [](auto val) { return val; }); - }); + sycl_nd_range(grid, block), [= + ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( + max_subgroup_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto x_b = batch::extract_batch_item(x_ub, group_id); + const auto y_b = batch::extract_batch_item(y_ub, group_id); + const auto res_b = + batch::extract_batch_item(res_ub, group_id); + compute_gen_dot_product_kernel( + x_b, y_b, res_b, item_ct1, + [](auto val) { return val; }); + }); }); } } @@ -270,7 +266,7 @@ void compute_conj_dot(std::shared_ptr exec, long max_group_size = device.get_info(); int group_size = - std::max(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size, + std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size, max_group_size); const dim3 block(group_size); @@ -278,19 +274,18 @@ void compute_conj_dot(std::shared_ptr exec, exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl_nd_range(grid, block), - [=](sycl::nd_item<3> item_ct1) - [[sycl::reqd_sub_group_size(max_subgroup_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto x_b = batch::extract_batch_item(x_ub, group_id); - const auto y_b = batch::extract_batch_item(y_ub, group_id); - const auto res_b = - batch::extract_batch_item(res_ub, group_id); - compute_gen_dot_product_kernel( - x_b, y_b, res_b, item_ct1, - [](auto val) { return conj(val); }); - }); + sycl_nd_range(grid, block), [= + ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( + max_subgroup_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto x_b = batch::extract_batch_item(x_ub, group_id); + const auto y_b = batch::extract_batch_item(y_ub, group_id); + const auto res_b = batch::extract_batch_item(res_ub, group_id); + compute_gen_dot_product_kernel( + x_b, y_b, res_b, item_ct1, + [](auto val) { return conj(val); }); + }); }); } @@ -314,7 +309,7 @@ void compute_norm2(std::shared_ptr exec, long max_group_size = device.get_info(); int group_size = - std::max(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size, + std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size, max_group_size); const dim3 block(group_size); @@ -322,33 +317,31 @@ void compute_norm2(std::shared_ptr exec, if (x->get_common_size()[1] == 1) { exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl_nd_range(grid, block), - [=](sycl::nd_item<3> item_ct1) - [[sycl::reqd_sub_group_size(max_subgroup_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto x_b = - batch::extract_batch_item(x_ub, group_id); - const auto res_b = - batch::extract_batch_item(res_ub, group_id); - single_rhs_compute_norm2_sg(x_b.num_rows, x_b.values, - res_b.values[0], item_ct1); - }); + sycl_nd_range(grid, block), [= + ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( + max_subgroup_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto x_b = batch::extract_batch_item(x_ub, group_id); + const auto res_b = + batch::extract_batch_item(res_ub, group_id); + single_rhs_compute_norm2_sg(x_b.num_rows, x_b.values, + res_b.values[0], item_ct1); + }); }); } else { exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl_nd_range(grid, block), - [=](sycl::nd_item<3> item_ct1) - [[sycl::reqd_sub_group_size(max_subgroup_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto x_b = - batch::extract_batch_item(x_ub, group_id); - const auto res_b = - batch::extract_batch_item(res_ub, group_id); - compute_norm2_kernel(x_b, res_b, item_ct1); - }); + sycl_nd_range(grid, block), [= + ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( + max_subgroup_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto x_b = batch::extract_batch_item(x_ub, group_id); + const auto res_b = + batch::extract_batch_item(res_ub, group_id); + compute_norm2_kernel(x_b, res_b, item_ct1); + }); }); } } @@ -372,7 +365,7 @@ void copy(std::shared_ptr exec, long max_group_size = device.get_info(); int group_size = - std::max(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size, + std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size, max_group_size); const dim3 block(group_size); diff --git a/dpcpp/base/batch_multi_vector_kernels.hpp.inc b/dpcpp/base/batch_multi_vector_kernels.hpp.inc index 4db1dc5e1d7..6b503efddd2 100644 --- a/dpcpp/base/batch_multi_vector_kernels.hpp.inc +++ b/dpcpp/base/batch_multi_vector_kernels.hpp.inc @@ -68,13 +68,13 @@ __dpct_inline__ void add_scaled_kernel( template -__dpct_inline__ void single_rhs_compute_dot( +__dpct_inline__ void single_rhs_compute_conj_dot( const int num_rows, const ValueType* const __restrict__ x, const ValueType* const __restrict__ y, ValueType& result, sycl::nd_item<3> item_ct1) { - // auto grp = - // group::tiled_partition(group::this_thread_block(item_ct1)); + auto grp = + group::tiled_partition(group::this_thread_block(item_ct1)); // auto grp = group::this_thread_block(item_ct1); const auto group = item_ct1.get_group(); const auto group_size = item_ct1.get_local_range().size(); @@ -90,7 +90,7 @@ __dpct_inline__ void single_rhs_compute_dot( template -__dpct_inline__ void single_rhs_compute_dot_sg( +__dpct_inline__ void single_rhs_compute_conj_dot_sg( const int num_rows, const ValueType* const __restrict__ x, const ValueType* const __restrict__ y, ValueType& result, sycl::nd_item<3> item_ct1) @@ -178,13 +178,13 @@ __dpct_inline__ void single_rhs_compute_norm2_sg( } -template +template __dpct_inline__ void single_rhs_compute_norm2( const int num_rows, const ValueType* const __restrict__ x, gko::remove_complex& result, sycl::nd_item<3> item_ct1) { - // auto grp = - // group::tiled_partition(group::this_thread_block(item_ct1)); + auto grp = + group::tiled_partition(group::this_thread_block(item_ct1)); const auto group = item_ct1.get_group(); const auto group_size = item_ct1.get_local_range().size(); const auto tid = item_ct1.get_local_linear_id(); diff --git a/dpcpp/matrix/batch_struct.hpp b/dpcpp/matrix/batch_struct.hpp index fe04407d82d..7f36378d8e1 100644 --- a/dpcpp/matrix/batch_struct.hpp +++ b/dpcpp/matrix/batch_struct.hpp @@ -91,7 +91,7 @@ inline batch::matrix::dense::uniform_batch get_batch_struct( * Generates an immutable uniform batch struct from a batch of ell matrices. */ template -inline batch::matrix::ell::uniform_batch +inline batch::matrix::ell::uniform_batch get_batch_struct(const batch::matrix::Ell* const op) { return {op->get_const_values(), diff --git a/dpcpp/preconditioner/batch_identity.hpp.inc b/dpcpp/preconditioner/batch_identity.hpp.inc index 404d987a3f4..e15a4d37399 100644 --- a/dpcpp/preconditioner/batch_identity.hpp.inc +++ b/dpcpp/preconditioner/batch_identity.hpp.inc @@ -42,10 +42,10 @@ public: static int dynamic_work_size(int, int) { return 0; } - void generate( - size_type batch_id, - const gko::batch::matrix::ell::batch_item&, - ValueType* const, sycl::nd_item<3> item_ct1) + void generate(size_type batch_id, + const gko::batch::matrix::ell::batch_item&, + ValueType* const, sycl::nd_item<3> item_ct1) {} void generate(size_type batch_id, diff --git a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp index 33749e91ae4..85ff5d442a1 100644 --- a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp +++ b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp @@ -60,9 +60,9 @@ namespace gko { namespace kernels { namespace dpcpp { /** - * @brief The batch Cg solver namespace. + * @brief The batch Bicgstab solver namespace. * - * @ingroup batch_cg + * @ingroup batch_bicgstab */ namespace batch_bicgstab { @@ -77,10 +77,10 @@ template using settings = gko::kernels::batch_bicgstab::settings; -__dpct_inline__ int get_group_size(int value, int simd_len = 32) +__dpct_inline__ int get_group_size(int value, int subgroup_size = 32) { - int num_sg = ceildiv(value, simd_len); - return num_sg * simd_len; + int num_sg = ceildiv(value, subgroup_size); + return num_sg * subgroup_size; } @@ -92,9 +92,9 @@ class KernelCaller { : exec_{std::move(exec)}, settings_{settings} {} - template + template __dpct_inline__ void launch_apply_kernel( const gko::kernels::batch_bicgstab::storage_config& sconf, LogType& logger, PrecType& prec, const BatchMatrixType mat, @@ -111,15 +111,16 @@ class KernelCaller { auto max_iters = settings_.max_iterations; auto res_tol = settings_.residual_tol; - (exec_->get_queue())->submit([&](sycl::handler& cgh) { + exec_->get_queue()->submit([&](sycl::handler& cgh) { sycl::accessor slm_values(sycl::range<1>(shared_size), cgh); cgh.parallel_for( - sycl_nd_range(grid, block), - [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size( - simd_len)]] [[intel::kernel_args_restrict]] { + sycl_nd_range(grid, block), [= + ](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size( + subgroup_size)]] [ + [intel::kernel_args_restrict]] { auto batch_id = item_ct1.get_group_linear_id(); const auto mat_global_entry = gko::batch::matrix::extract_batch_item(mat, batch_id); diff --git a/dpcpp/solver/batch_bicgstab_kernels.hpp.inc b/dpcpp/solver/batch_bicgstab_kernels.hpp.inc index 67057f80e53..0b6f4511f02 100644 --- a/dpcpp/solver/batch_bicgstab_kernels.hpp.inc +++ b/dpcpp/solver/batch_bicgstab_kernels.hpp.inc @@ -39,6 +39,7 @@ __dpct_inline__ void initialize( ValueType& alpha, ValueType* const x_shared_entry, ValueType* const r_shared_entry, ValueType* const r_hat_shared_entry, ValueType* const p_shared_entry, ValueType* const v_shared_entry, + ValueType* const p_hat_shared_entry, typename gko::remove_complex& rhs_norm, typename gko::remove_complex& res_norm, sycl::nd_item<3> item_ct1) @@ -85,6 +86,7 @@ __dpct_inline__ void initialize( for (int iz = tid; iz < num_rows; iz += group_size) { r_hat_shared_entry[iz] = r_shared_entry[iz]; p_shared_entry[iz] = zero(); + p_hat_shared_entry[iz] = zero(); v_shared_entry[iz] = zero(); } } @@ -115,23 +117,24 @@ __dpct_inline__ void compute_alpha(const int num_rows, const ValueType& rho_new, const ValueType* const v_shared_entry, ValueType& alpha, sycl::nd_item<3> item_ct1) { + auto sg = item_ct1.get_sub_group(); + const auto sg_id = sg.get_group_id(); + const auto tid = item_ct1.get_local_linear_id(); if constexpr (sg_kernel_all) { - auto sg = item_ct1.get_sub_group(); - const auto sg_id = sg.get_group_id(); - const auto tid = item_ct1.get_local_linear_id(); - if (sg_id == 0) { - single_rhs_compute_dot_sg(num_rows, r_hat_shared_entry, - v_shared_entry, alpha, item_ct1); + single_rhs_compute_conj_dot_sg(num_rows, r_hat_shared_entry, + v_shared_entry, alpha, item_ct1); } if (tid == 0) { alpha = rho_new / alpha; } item_ct1.barrier(sycl::access::fence_space::local_space); } else { - single_rhs_compute_dot(num_rows, r_hat_shared_entry, v_shared_entry, - alpha, item_ct1); - alpha = rho_new / alpha; + single_rhs_compute_conj_dot(num_rows, r_hat_shared_entry, + v_shared_entry, alpha, item_ct1); + if (tid == 0) { + alpha = rho_new / alpha; + } } } @@ -158,26 +161,30 @@ __dpct_inline__ void compute_omega(const int num_rows, ValueType& temp, ValueType& omega, sycl::nd_item<3> item_ct1) { + auto sg = item_ct1.get_sub_group(); + const auto sg_id = sg.get_group_id(); + const auto tid = item_ct1.get_local_linear_id(); if constexpr (sg_kernel_all) { - auto sg = item_ct1.get_sub_group(); - const auto sg_id = sg.get_group_id(); - const auto tid = item_ct1.get_local_linear_id(); - - if (sg_id == 0) - single_rhs_compute_dot_sg(num_rows, t_shared_entry, s_shared_entry, - omega, item_ct1); - else if (sg_id == 1) - single_rhs_compute_dot_sg(num_rows, t_shared_entry, t_shared_entry, - temp, item_ct1); + if (sg_id == 0) { + single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry, + s_shared_entry, omega, item_ct1); + } else if (sg_id == 1) { + single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry, + t_shared_entry, temp, item_ct1); + } item_ct1.barrier(sycl::access::fence_space::local_space); - if (tid == 0) omega /= temp; + if (tid == 0) { + omega /= temp; + } item_ct1.barrier(sycl::access::fence_space::local_space); } else { - single_rhs_compute_dot(num_rows, t_shared_entry, s_shared_entry, omega, - item_ct1); - single_rhs_compute_dot(num_rows, t_shared_entry, t_shared_entry, temp, - item_ct1); - omega /= temp; + single_rhs_compute_conj_dot(num_rows, t_shared_entry, s_shared_entry, + omega, item_ct1); + single_rhs_compute_conj_dot(num_rows, t_shared_entry, t_shared_entry, + temp, item_ct1); + if (tid == 0) { + omega /= temp; + } } } @@ -244,33 +251,21 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, real_type* norms_rhs_sh; real_type* norms_res_sh; - if constexpr (sg_kernel_all) { - using tile_value_t = ValueType[5]; - tile_value_t& values = - *sycl::ext::oneapi::group_local_memory_for_overwrite( - group); - using tile_real_t = real_type[2]; - tile_real_t& reals = - *sycl::ext::oneapi::group_local_memory_for_overwrite( - group); - rho_old_sh = &values[0]; - rho_new_sh = &values[1]; - alpha_sh = &values[2]; - omega_sh = &values[3]; - temp_sh = &values[4]; - norms_rhs_sh = &reals[0]; - norms_res_sh = &reals[1]; - } else { - ValueType values[5]; - real_type reals[2]; - rho_old_sh = &values[0]; - rho_new_sh = &values[1]; - alpha_sh = &values[2]; - omega_sh = &values[3]; - temp_sh = &values[4]; - norms_rhs_sh = &reals[0]; - norms_res_sh = &reals[1]; - } + using tile_value_t = ValueType[5]; + tile_value_t& values = + *sycl::ext::oneapi::group_local_memory_for_overwrite( + group); + using tile_real_t = real_type[2]; + tile_real_t& reals = + *sycl::ext::oneapi::group_local_memory_for_overwrite( + group); + rho_old_sh = &values[0]; + rho_new_sh = &values[1]; + alpha_sh = &values[2]; + omega_sh = &values[3]; + temp_sh = &values[4]; + norms_rhs_sh = &reals[0]; + norms_res_sh = &reals[1]; const int gmem_offset = batch_id * sconf.gmem_stride_bytes / sizeof(ValueType); ValueType* p_hat_sh; @@ -346,11 +341,12 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, // compute residual norms // r_hat = r // p = 0 + // p_hat = 0 // v = 0 initialize(num_rows, mat_global_entry, b_global_entry, x_global_entry, rho_old_sh[0], omega_sh[0], - alpha_sh[0], x_sh, r_sh, r_hat_sh, p_sh, v_sh, - norms_rhs_sh[0], norms_res_sh[0], item_ct1); + alpha_sh[0], x_sh, r_sh, r_hat_sh, p_sh, p_hat_sh, + v_sh, norms_rhs_sh[0], norms_res_sh[0], item_ct1); item_ct1.barrier(sycl::access::fence_space::local_space); // stopping criterion object @@ -366,13 +362,13 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, // rho_new = < r_hat , r > = (r_hat)' * (r) if constexpr (sg_kernel_all) { if (sg_id == 0) { - single_rhs_compute_dot_sg(num_rows, r_hat_sh, r_sh, - rho_new_sh[0], item_ct1); + single_rhs_compute_conj_dot_sg(num_rows, r_hat_sh, r_sh, + rho_new_sh[0], item_ct1); } item_ct1.barrier(sycl::access::fence_space::local_space); } else { - single_rhs_compute_dot(num_rows, r_hat_sh, r_sh, rho_new_sh[0], - item_ct1); + single_rhs_compute_conj_dot(num_rows, r_hat_sh, r_sh, rho_new_sh[0], + item_ct1); } // beta = (rho_new / rho_old)*(alpha / omega) diff --git a/hip/matrix/batch_struct.hip.hpp b/hip/matrix/batch_struct.hip.hpp index e35f13f1249..ba75b1b634e 100644 --- a/hip/matrix/batch_struct.hip.hpp +++ b/hip/matrix/batch_struct.hip.hpp @@ -92,7 +92,8 @@ get_batch_struct(batch::matrix::Dense* const op) * Generates an immutable uniform batch struct from a batch of ell matrices. */ template -inline batch::matrix::ell::uniform_batch, IndexType> +inline batch::matrix::ell::uniform_batch, + const IndexType> get_batch_struct(const batch::matrix::Ell* const op) { return {as_hip_type(op->get_const_values()), diff --git a/reference/matrix/batch_struct.hpp b/reference/matrix/batch_struct.hpp index bb7680d1493..94beff5c2c2 100644 --- a/reference/matrix/batch_struct.hpp +++ b/reference/matrix/batch_struct.hpp @@ -95,7 +95,7 @@ inline batch::matrix::dense::uniform_batch get_batch_struct( * Generates an immutable uniform batch struct from a batch of ell matrices. */ template -inline batch::matrix::ell::uniform_batch +inline batch::matrix::ell::uniform_batch get_batch_struct(const batch::matrix::Ell* const op) { return {op->get_const_values(),