diff --git a/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc b/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc
index 9f77598ff5a..cb157d80fd5 100644
--- a/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc
+++ b/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc
@@ -103,6 +103,28 @@ __global__ __launch_bounds__(
 }
 
 
+template <typename Group, typename ValueType>
+__device__ __forceinline__ void single_rhs_compute_conj_dot(Group subgroup,
+                                                            const int num_rows,
+                                                            const ValueType* x,
+                                                            const ValueType* y,
+                                                            ValueType& result)
+
+{
+    ValueType val = zero<ValueType>();
+    for (int r = subgroup.thread_rank(); r < num_rows; r += subgroup.size()) {
+        val += conj(x[r]) * y[r];
+    }
+
+    // subgroup level reduction
+    val = reduce(subgroup, val, thrust::plus<ValueType>{});
+
+    if (subgroup.thread_rank() == 0) {
+        result = val;
+    }
+}
+
+
 template
 __device__ __forceinline__ void gen_one_dot(
     const gko::batch::multi_vector::batch_item& x,
@@ -165,6 +187,27 @@ __launch_bounds__(default_block_size, sm_oversubscription) void compute_gen_dot_
 }
 
 
+template <typename Group, typename ValueType>
+__device__ __forceinline__ void single_rhs_compute_norm2(
+    Group subgroup, const int num_rows, const ValueType* x,
+    remove_complex<ValueType>& result)
+{
+    using real_type = typename gko::remove_complex<ValueType>;
+    real_type val = zero<real_type>();
+
+    for (int r = subgroup.thread_rank(); r < num_rows; r += subgroup.size()) {
+        val += squared_norm(x[r]);
+    }
+
+    // subgroup level reduction
+    val = reduce(subgroup, val, thrust::plus<remove_complex<ValueType>>{});
+
+    if (subgroup.thread_rank() == 0) {
+        result = sqrt(val);
+    }
+}
+
+
 template
 __device__ __forceinline__ void one_norm2(
     const gko::batch::multi_vector::batch_item& x,
@@ -238,6 +281,17 @@ __global__ __launch_bounds__(
 }
 
 
+template <typename ValueType>
+__device__ __forceinline__ void single_rhs_copy(const int num_rows,
+                                                const ValueType* in,
+                                                ValueType* out)
+{
+    for (int iz = threadIdx.x; iz < num_rows; iz += blockDim.x) {
+        out[iz] = in[iz];
+    }
+}
+
+
 /**
  * Copies the values of one multi-vector into another.
  *
diff --git a/common/cuda_hip/log/batch_logger.hpp.inc b/common/cuda_hip/log/batch_logger.hpp.inc
index 7a4d59b67e9..e8cf77960ef 100644
--- a/common/cuda_hip/log/batch_logger.hpp.inc
+++ b/common/cuda_hip/log/batch_logger.hpp.inc
@@ -36,7 +36,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
template class SimpleFinalLogger final { public: - using real_type = remove_complex; + using real_type = RealType; SimpleFinalLogger(real_type* const batch_residuals, int* const batch_iters) : final_residuals_{batch_residuals}, final_iters_{batch_iters} diff --git a/common/cuda_hip/preconditioner/batch_identity.hpp.inc b/common/cuda_hip/preconditioner/batch_identity.hpp.inc index 1b1fb7b5482..923ed4ce946 100644 --- a/common/cuda_hip/preconditioner/batch_identity.hpp.inc +++ b/common/cuda_hip/preconditioner/batch_identity.hpp.inc @@ -45,16 +45,9 @@ public: return 0; } - __device__ __forceinline__ void generate( - size_type, - const gko::batch::matrix::ell::batch_item&, - ValueType*) - {} - - __device__ __forceinline__ void generate( - size_type, - const gko::batch::matrix::dense::batch_item&, - ValueType*) + template + __device__ __forceinline__ void generate(size_type, const batch_item_type&, + ValueType*) {} __device__ __forceinline__ void apply(const int num_rows, diff --git a/common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc b/common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc new file mode 100644 index 00000000000..faee2e069a7 --- /dev/null +++ b/common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc @@ -0,0 +1,382 @@ +/************************************************************* +Copyright (c) 2017-2023, the Ginkgo authors +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*************************************************************/
+
+
+template <typename Group, typename BatchMatrixType_entry, typename ValueType>
+__device__ __forceinline__ void initialize(
+    Group subgroup, const int num_rows, const BatchMatrixType_entry& mat_entry,
+    const ValueType* const b_global_entry,
+    const ValueType* const x_global_entry, ValueType& rho_old, ValueType& omega,
+    ValueType& alpha, ValueType* const x_shared_entry,
+    ValueType* const r_shared_entry, ValueType* const r_hat_shared_entry,
+    ValueType* const p_shared_entry, ValueType* const p_hat_shared_entry,
+    ValueType* const v_shared_entry,
+    typename gko::remove_complex<ValueType>& rhs_norm,
+    typename gko::remove_complex<ValueType>& res_norm)
+{
+    rho_old = one<ValueType>();
+    omega = one<ValueType>();
+    alpha = one<ValueType>();
+
+    // copy x from global to shared memory
+    // r = b
+    for (int iz = threadIdx.x; iz < num_rows; iz += blockDim.x) {
+        x_shared_entry[iz] = x_global_entry[iz];
+        r_shared_entry[iz] = b_global_entry[iz];
+    }
+    __syncthreads();
+
+    // r = b - A*x
+    advanced_apply(static_cast<ValueType>(-1.0), mat_entry, x_shared_entry,
+                   static_cast<ValueType>(1.0), r_shared_entry);
+    __syncthreads();
+
+    if (threadIdx.x / config::warp_size == 0) {
+        single_rhs_compute_norm2(subgroup, num_rows, r_shared_entry, res_norm);
+    } else if (threadIdx.x / config::warp_size == 1) {
+        // Compute norms of rhs
+        single_rhs_compute_norm2(subgroup, num_rows, b_global_entry, rhs_norm);
+    }
+    __syncthreads();
+
+    for (int iz = threadIdx.x; iz < num_rows; iz += blockDim.x) {
+        r_hat_shared_entry[iz] = r_shared_entry[iz];
+        p_shared_entry[iz] = zero<ValueType>();
+        p_hat_shared_entry[iz] = zero<ValueType>();
+        v_shared_entry[iz] = zero<ValueType>();
+    }
+}
+
+
+template <typename ValueType>
+__device__ __forceinline__ void update_p(
+    const int num_rows, const ValueType& rho_new, const ValueType& rho_old,
+    const ValueType& alpha, const ValueType& omega,
+    const ValueType* const r_shared_entry,
+    const ValueType* const v_shared_entry, ValueType* const p_shared_entry)
+{
+    const ValueType beta = (rho_new / rho_old) * (alpha / omega);
+    for (int r = threadIdx.x; r < num_rows; r += blockDim.x) {
+        p_shared_entry[r] =
+            r_shared_entry[r] +
+            beta * (p_shared_entry[r] - omega * v_shared_entry[r]);
+    }
+}
+
+template <typename Group, typename ValueType>
+__device__ __forceinline__ void compute_alpha(
+    Group subgroup, const int num_rows, const ValueType& rho_new,
+    const ValueType* const r_hat_shared_entry,
+    const ValueType* const v_shared_entry, ValueType& alpha)
+{
+    if (threadIdx.x / config::warp_size == 0) {
+        single_rhs_compute_conj_dot(subgroup, num_rows, r_hat_shared_entry,
+                                    v_shared_entry, alpha);
+    }
+    __syncthreads();
+    if (threadIdx.x == 0) {
+        alpha = rho_new / alpha;
+    }
+}
+
+
+template <typename ValueType>
+__device__ __forceinline__ void update_s(const int num_rows,
+                                         const ValueType* const r_shared_entry,
+                                         const ValueType& alpha,
+                                         const ValueType* const v_shared_entry,
+                                         ValueType* const s_shared_entry)
+{
+    for (int r = threadIdx.x; r < num_rows; r += blockDim.x) {
+        s_shared_entry[r] = r_shared_entry[r] - alpha * v_shared_entry[r];
+    }
+}
+
+
+template <typename Group, typename ValueType>
+__device__ __forceinline__ void compute_omega(
+    Group subgroup, const int num_rows, const ValueType* const t_shared_entry,
+    const ValueType* const s_shared_entry, ValueType& temp, ValueType& omega)
+{
+    if (threadIdx.x / config::warp_size == 0) {
+        single_rhs_compute_conj_dot(subgroup, num_rows, t_shared_entry,
+                                    s_shared_entry, omega);
+    } else if (threadIdx.x / config::warp_size == 1) {
+        single_rhs_compute_conj_dot(subgroup, num_rows, t_shared_entry,
+                                    t_shared_entry, temp);
+    }
+
+    __syncthreads();
+    if (threadIdx.x == 0) {
+        omega /= temp;
+    }
+}
+
+template <typename ValueType>
+__device__ __forceinline__ void update_x_and_r(
+    const int num_rows, const ValueType* const p_hat_shared_entry,
+    const ValueType* const s_hat_shared_entry, const ValueType& alpha,
+    const ValueType& omega, const ValueType* const s_shared_entry,
+    const ValueType* const t_shared_entry, ValueType* const x_shared_entry,
+    ValueType* const r_shared_entry)
+{
+    for (int r = threadIdx.x; r < num_rows; r += blockDim.x) {
+        x_shared_entry[r] = x_shared_entry[r] + alpha * p_hat_shared_entry[r] +
+                            omega * s_hat_shared_entry[r];
+        r_shared_entry[r] = s_shared_entry[r] - omega * t_shared_entry[r];
+    }
+}
+
+
+template <typename ValueType>
+__device__ __forceinline__ void update_x_middle(
+    const int num_rows, const ValueType& alpha,
+    const ValueType* const p_hat_shared_entry, ValueType* const x_shared_entry)
+{
+    for (int r = threadIdx.x; r < num_rows; r += blockDim.x) {
+        x_shared_entry[r] = x_shared_entry[r] + alpha * p_hat_shared_entry[r];
+    }
+}
+
+
+template <typename StopType, const int n_shared, const bool prec_shared_bool,
+          typename PrecType, typename LogType, typename BatchMatrixType,
+          typename ValueType>
+__global__ void apply_kernel(
+    const gko::kernels::batch_bicgstab::storage_config sconf,
+    const int max_iter, const gko::remove_complex<ValueType> tol,
+    LogType logger, PrecType prec_shared, const BatchMatrixType mat,
+    const ValueType* const __restrict__ b, ValueType* const __restrict__ x,
+    ValueType* const __restrict__ workspace = nullptr)
+{
+    using real_type = typename gko::remove_complex<ValueType>;
+    const auto num_batch_items = mat.num_batch_items;
+    const auto num_rows = mat.num_rows;
+
+    constexpr auto tile_size = config::warp_size;
+    auto thread_block = group::this_thread_block();
+    auto subgroup = group::tiled_partition<tile_size>(thread_block);
+
+    for (int batch_id = blockIdx.x; batch_id < num_batch_items;
+         batch_id += gridDim.x) {
+        const int gmem_offset =
+            batch_id * sconf.gmem_stride_bytes / sizeof(ValueType);
+        extern __shared__ char local_mem_sh[];
+
+        ValueType* p_hat_sh;
+        ValueType* s_hat_sh;
+        ValueType* p_sh;
+        ValueType* s_sh;
+        ValueType* r_sh;
+        ValueType* r_hat_sh;
+        ValueType* v_sh;
+        ValueType* t_sh;
+        ValueType* x_sh;
+        ValueType* prec_work_sh;
+
+        if (n_shared >= 1) {
+            p_hat_sh = reinterpret_cast<ValueType*>(local_mem_sh);
+        } else {
+            p_hat_sh = workspace + gmem_offset;
+        }
+        if (n_shared == 1) {
+            s_hat_sh = workspace + gmem_offset;
+        } else {
+            s_hat_sh = p_hat_sh + sconf.padded_vec_len;
+        }
+        if (n_shared == 2) {
+            v_sh = workspace + gmem_offset;
+        } else {
+            v_sh = s_hat_sh + sconf.padded_vec_len;
+        }
+        if (n_shared == 3) {
+            t_sh = workspace + gmem_offset;
+        } else {
+            t_sh = v_sh + sconf.padded_vec_len;
+        }
+        if (n_shared == 4) {
+            p_sh = workspace + gmem_offset;
+        } else {
+            p_sh = t_sh + sconf.padded_vec_len;
+        }
+        if (n_shared == 5) {
+            s_sh = workspace + gmem_offset;
+        } else {
+            s_sh = p_sh + sconf.padded_vec_len;
+        }
+        if (n_shared == 6) {
+            r_sh = workspace + gmem_offset;
+        } else {
+            r_sh = s_sh + sconf.padded_vec_len;
+        }
+        if (n_shared == 7) {
+            r_hat_sh = workspace + gmem_offset;
+        } else {
+            r_hat_sh = r_sh + sconf.padded_vec_len;
+        }
+        if (n_shared == 8) {
+            x_sh = workspace + gmem_offset;
+        } else {
+            x_sh = r_hat_sh + sconf.padded_vec_len;
+        }
+        if (!prec_shared_bool && n_shared == 9) {
+            prec_work_sh = workspace + gmem_offset;
+        } else {
+            prec_work_sh = x_sh + sconf.padded_vec_len;
+        }
+
+        __shared__ uninitialized_array<ValueType, 1> rho_old_sh;
+        __shared__ uninitialized_array<ValueType, 1> rho_new_sh;
+        __shared__ uninitialized_array<ValueType, 1> omega_sh;
+        __shared__ uninitialized_array<ValueType, 1> alpha_sh;
+        __shared__ uninitialized_array<ValueType, 1> temp_sh;
+        __shared__ real_type norms_rhs_sh[1];
+        __shared__ real_type norms_res_sh[1];
+
+        const auto mat_entry =
+            gko::batch::matrix::extract_batch_item(mat, batch_id);
+        const
ValueType* const b_entry_ptr = + gko::batch::multi_vector::batch_item_ptr(b, 1, num_rows, batch_id); + ValueType* const x_gl_entry_ptr = + gko::batch::multi_vector::batch_item_ptr(x, 1, num_rows, batch_id); + + // generate preconditioner + prec_shared.generate(batch_id, mat_entry, prec_work_sh); + + // initialization + // rho_old = 1, omega = 1, alpha = 1 + // compute b norms + // copy x from global to shared memory + // r = b - A*x + // compute residual norms + // r_hat = r + // p = 0 + // p_hat = 0 + // v = 0 + initialize(subgroup, num_rows, mat_entry, b_entry_ptr, x_gl_entry_ptr, + rho_old_sh[0], omega_sh[0], alpha_sh[0], x_sh, r_sh, + r_hat_sh, p_sh, p_hat_sh, v_sh, norms_rhs_sh[0], + norms_res_sh[0]); + __syncthreads(); + + // stopping criterion object + StopType stop(tol, norms_rhs_sh); + + int iter = 0; + for (; iter < max_iter; iter++) { + if (stop.check_converged(norms_res_sh)) { + logger.log_iteration(batch_id, iter, norms_res_sh[0]); + break; + } + + // rho_new = < r_hat , r > = (r_hat)' * (r) + if (threadIdx.x / config::warp_size == 0) { + single_rhs_compute_conj_dot(subgroup, num_rows, r_hat_sh, r_sh, + rho_new_sh[0]); + } + __syncthreads(); + + // beta = (rho_new / rho_old)*(alpha / omega) + // p = r + beta*(p - omega * v) + update_p(num_rows, rho_new_sh[0], rho_old_sh[0], alpha_sh[0], + omega_sh[0], r_sh, v_sh, p_sh); + __syncthreads(); + + // p_hat = precond * p + prec_shared.apply(num_rows, p_sh, p_hat_sh); + __syncthreads(); + + // v = A * p_hat + simple_apply(mat_entry, p_hat_sh, v_sh); + __syncthreads(); + + // alpha = rho_new / < r_hat , v> + compute_alpha(subgroup, num_rows, rho_new_sh[0], r_hat_sh, v_sh, + alpha_sh[0]); + __syncthreads(); + + // s = r - alpha*v + update_s(num_rows, r_sh, alpha_sh[0], v_sh, s_sh); + __syncthreads(); + + // an estimate of residual norms + if (threadIdx.x / config::warp_size == 0) { + single_rhs_compute_norm2(subgroup, num_rows, s_sh, + norms_res_sh[0]); + } + __syncthreads(); + + // if (norms_res_sh[0] / norms_rhs_sh[0] < tol) { + if (stop.check_converged(norms_res_sh)) { + update_x_middle(num_rows, alpha_sh[0], p_hat_sh, x_sh); + logger.log_iteration(batch_id, iter, norms_res_sh[0]); + break; + } + + // s_hat = precond * s + prec_shared.apply(num_rows, s_sh, s_hat_sh); + __syncthreads(); + + // t = A * s_hat + simple_apply(mat_entry, s_hat_sh, t_sh); + __syncthreads(); + + // omega = / + compute_omega(subgroup, num_rows, t_sh, s_sh, temp_sh[0], + omega_sh[0]); + __syncthreads(); + + // x = x + alpha*p_hat + omega *s_hat + // r = s - omega * t + update_x_and_r(num_rows, p_hat_sh, s_hat_sh, alpha_sh[0], + omega_sh[0], s_sh, t_sh, x_sh, r_sh); + __syncthreads(); + + if (threadIdx.x / config::warp_size == 0) { + single_rhs_compute_norm2(subgroup, num_rows, r_sh, + norms_res_sh[0]); + } + //__syncthreads(); + + if (threadIdx.x == blockDim.x - 1) { + rho_old_sh[0] = rho_new_sh[0]; + } + __syncthreads(); + } + + logger.log_iteration(batch_id, iter, norms_res_sh[0]); + + // copy x back to global memory + single_rhs_copy(num_rows, x_sh, x_gl_entry_ptr); + __syncthreads(); + } +} diff --git a/contributors.txt b/contributors.txt index 1f1259bc082..aec120d93dd 100644 --- a/contributors.txt +++ b/contributors.txt @@ -20,6 +20,7 @@ Kashi Aditya Karlsruhe Institute of Technology Koch Marcel Karlsruhe Institute of Technology Maier Matthias Texas A&M University Nayak Pratik Karlsruhe Institute of Technology +Nguyen Phuong University of Tennessee, Knoxville Olenik Gregor HPSim Ribizel Tobias Karlsruhe Institute of Technology Riemer Lukas Karlsruhe 
Institute of Technology diff --git a/core/base/batch_struct.hpp b/core/base/batch_struct.hpp index 975671739eb..041630af66e 100644 --- a/core/base/batch_struct.hpp +++ b/core/base/batch_struct.hpp @@ -78,6 +78,15 @@ struct uniform_batch { }; +template +GKO_ATTRIBUTES GKO_INLINE ValueType* batch_item_ptr( + ValueType* const batch_start, const size_type stride, const int num_rows, + const size_type batch_idx) +{ + return batch_start + batch_idx * stride * num_rows; +} + + } // namespace multi_vector diff --git a/core/base/batch_utilities.hpp b/core/base/batch_utilities.hpp index f05a80322aa..cc92d294173 100644 --- a/core/base/batch_utilities.hpp +++ b/core/base/batch_utilities.hpp @@ -201,8 +201,9 @@ std::unique_ptr read( std::forward(create_args)...); for (size_type b = 0; b < num_batch_items; ++b) { - if (data.at(b).size != data.at(0).size) + if (data.at(b).size != data.at(0).size) { GKO_INVALID_STATE("Incorrect data passed in"); + } tmp->create_view_for_item(b)->read(data[b]); } diff --git a/core/matrix/batch_ell.cpp b/core/matrix/batch_ell.cpp index 19b2dcae5c3..88863a05dd4 100644 --- a/core/matrix/batch_ell.cpp +++ b/core/matrix/batch_ell.cpp @@ -134,7 +134,10 @@ Ell* Ell::apply( ptr_param> b, ptr_param> x) { - static_cast(this)->apply(b, x); + this->validate_application_parameters(b.get(), x.get()); + auto exec = this->get_executor(); + this->apply_impl(make_temporary_clone(exec, b).get(), + make_temporary_clone(exec, x).get()); return this; } @@ -159,7 +162,13 @@ Ell* Ell::apply( ptr_param> beta, ptr_param> x) { - static_cast(this)->apply(alpha, b, beta, x); + this->validate_application_parameters(alpha.get(), b.get(), beta.get(), + x.get()); + auto exec = this->get_executor(); + this->apply_impl(make_temporary_clone(exec, alpha).get(), + make_temporary_clone(exec, b).get(), + make_temporary_clone(exec, beta).get(), + make_temporary_clone(exec, x).get()); return this; } diff --git a/core/solver/batch_bicgstab_kernels.hpp b/core/solver/batch_bicgstab_kernels.hpp index 4689badeebd..32291562afd 100644 --- a/core/solver/batch_bicgstab_kernels.hpp +++ b/core/solver/batch_bicgstab_kernels.hpp @@ -92,6 +92,112 @@ inline int local_memory_requirement(const int num_rows, const int num_rhs) } +struct storage_config { + // preconditioner storage + bool prec_shared; + // total number of shared vectors + int n_shared; + // number of vectors in global memory + int n_global; + // global stride from one batch entry to the next + int gmem_stride_bytes; + // padded vector length + int padded_vec_len; +}; + + +template +void set_gmem_stride_bytes(storage_config& sconf, + const int multi_vector_size_bytes, + const int prec_storage_bytes) +{ + int gmem_stride = sconf.n_global * multi_vector_size_bytes; + if (!sconf.prec_shared) { + gmem_stride += prec_storage_bytes; + } + // align global memory chunks + sconf.gmem_stride_bytes = + gmem_stride > 0 ? ceildiv(gmem_stride, align_bytes) * align_bytes : 0; +} + + +/** + * Calculates the amount of in-solver storage needed by batch-Bicgstab and + * the split between shared and global memory. + * + * The calculation includes multivectors for + * - r + * - r_hat + * - p + * - p_hat + * - v + * - s + * - s_hat + * - t + * - x + * In addition, small arrays are needed for + * - rho_old + * - rho_new + * - omega + * - alpha + * - temp + * - rhs_norms + * - res_norms + * + * @param available_shared_mem The amount of shared memory per block to use + * for keeping intermediate vectors. In case keeping the matrix in L1 cache etc. 
+ * should be prioritized, the cache configuration must be updated separately + * and the needed space should be subtracted before passing to this + * function. + * @param num_rows Size of the matrix. + * @param num_nz Number of nonzeros in the matrix + * @param num_rhs Number of right-hand-sides in the vectors. + * @return A struct containing allocation information specific to Bicgstab. + */ +template +storage_config compute_shared_storage(const int available_shared_mem, + const int num_rows, const int num_nz, + const int num_rhs) +{ + using real_type = remove_complex; + const int vec_size = num_rows * num_rhs * sizeof(ValueType); + const int num_main_vecs = 9; + const int prec_storage = + Prectype::dynamic_work_size(num_rows, num_nz) * sizeof(ValueType); + int rem_shared = available_shared_mem; + // Set default values. Initially all vecs are in global memory. + // {prec_shared, n_shared, n_global, gmem_stride_bytes, padded_vec_len} + storage_config sconf{false, 0, num_main_vecs, 0, num_rows}; + // If available shared mem is zero, set all vecs to global. + if (rem_shared <= 0) { + set_gmem_stride_bytes(sconf, vec_size, prec_storage); + return sconf; + } + // Compute the number of vecs that can be stored in shared memory and assign + // the rest to global memory. + const int initial_vecs_available = rem_shared / vec_size; + const int num_vecs_shared = min(initial_vecs_available, num_main_vecs); + sconf.n_shared += num_vecs_shared; + sconf.n_global -= num_vecs_shared; + rem_shared -= num_vecs_shared * vec_size; + // Set the storage configuration with preconditioner workspace in global if + // there are any vectors in global memory. + if (sconf.n_global > 0) { + set_gmem_stride_bytes(sconf, vec_size, prec_storage); + return sconf; + } + // If more shared memory space is available and preconditioner workspace is + // needed, enable preconditioner workspace to use shared memory. + if (rem_shared >= prec_storage && prec_storage > 0) { + sconf.prec_shared = true; + rem_shared -= prec_storage; + } + // Set the global storage config and align to align_bytes bytes. 
+ set_gmem_stride_bytes(sconf, vec_size, prec_storage); + return sconf; +} + + } // namespace batch_bicgstab diff --git a/core/test/utils/batch_helpers.hpp b/core/test/utils/batch_helpers.hpp index 77c2d397889..eee31050505 100644 --- a/core/test/utils/batch_helpers.hpp +++ b/core/test/utils/batch_helpers.hpp @@ -166,7 +166,7 @@ std::unique_ptr generate_diag_dominant_batch_matrix( static_cast(num_cols)}, {}}; auto engine = std::default_random_engine(42); - auto rand_diag_dist = std::normal_distribution(4.0, 12.0); + auto rand_diag_dist = std::normal_distribution(8.0, 1.0); for (int row = 0; row < num_rows; ++row) { std::uniform_int_distribution rand_nnz_dist{1, row + 1}; const auto k = rand_nnz_dist(engine); @@ -175,8 +175,8 @@ std::unique_ptr generate_diag_dominant_batch_matrix( } data.nonzeros.emplace_back( row, row, - static_cast( - detail::get_rand_value(rand_diag_dist, engine))); + std::abs(static_cast( + detail::get_rand_value(rand_diag_dist, engine)))); if (row < num_rows - 1) { data.nonzeros.emplace_back(row, k, value_type{-1.0}); data.nonzeros.emplace_back(row, row + 1, value_type{-1.0}); @@ -208,8 +208,15 @@ std::unique_ptr generate_diag_dominant_batch_matrix( auto rand_data = fill_random_matrix_data( num_rows, num_cols, row_idxs, col_idxs, rand_val_dist, engine); gko::utils::make_diag_dominant(rand_data); - batch_data.emplace_back(rand_data); GKO_ASSERT(rand_data.size == batch_data.at(0).size); + GKO_ASSERT(rand_data.nonzeros.size() == data.nonzeros.size()); + // Copy over the diagonal values + for (int i = 0; i < data.nonzeros.size(); ++i) { + if (data.nonzeros[i].row == data.nonzeros[i].column) { + rand_data.nonzeros[i] = data.nonzeros[i]; + } + } + batch_data.emplace_back(rand_data); } return gko::batch::read( exec, batch_data, std::forward(args)...); @@ -224,7 +231,7 @@ struct LinearSystem { std::shared_ptr matrix; std::shared_ptr rhs; - std::shared_ptr rhs_norm; + std::shared_ptr host_rhs_norm; std::shared_ptr exact_sol; }; @@ -250,8 +257,8 @@ LinearSystem generate_batch_linear_system( // A * x_{exact} = b sys.matrix->apply(sys.exact_sol, sys.rhs); const gko::batch_dim<2> norm_dim(num_batch_items, gko::dim<2>(1, num_rhs)); - sys.rhs_norm = real_vec::create(exec, norm_dim); - sys.rhs->compute_norm2(sys.rhs_norm.get()); + sys.host_rhs_norm = real_vec::create(exec->get_master(), norm_dim); + sys.rhs->compute_norm2(sys.host_rhs_norm.get()); return sys; } @@ -273,13 +280,13 @@ compute_residual_norms( const gko::batch_dim<2> norm_dim(num_batch_items, gko::dim<2>(1, num_rhs)); auto residual_vec = b->clone(); - auto res_norms = real_vec::create(exec, norm_dim); + auto res_norm = real_vec::create(exec->get_master(), norm_dim); auto alpha = gko::batch::initialize(num_batch_items, {-1.0}, exec); auto beta = gko::batch::initialize(num_batch_items, {1.0}, exec); mtx->apply(alpha, x, beta, residual_vec); - residual_vec->compute_norm2(res_norms); - return res_norms; + residual_vec->compute_norm2(res_norm); + return res_norm; } @@ -289,7 +296,7 @@ struct Result { using real_vec = batch::MultiVector>; std::shared_ptr x; - std::shared_ptr res_norm; + std::shared_ptr host_res_norm; }; @@ -323,7 +330,7 @@ Result solve_linear_system( result.x->fill(zero()); solver->apply(sys.rhs, result.x); - result.res_norm = + result.host_res_norm = compute_residual_norms(sys.matrix.get(), sys.rhs.get(), result.x.get()); return std::move(result); @@ -369,7 +376,7 @@ ResultWithLogData solve_linear_system( result.log_data->iter_counts = log_data->iter_counts; result.log_data->res_norms = log_data->res_norms; - 
result.res_norm = + result.host_res_norm = compute_residual_norms(sys.matrix.get(), sys.rhs.get(), result.x.get()); return std::move(result); diff --git a/cuda/base/kernel_config.hpp b/cuda/base/kernel_config.hpp new file mode 100644 index 00000000000..a4aecea1d55 --- /dev/null +++ b/cuda/base/kernel_config.hpp @@ -0,0 +1,88 @@ +/************************************************************* +Copyright (c) 2017-2023, the Ginkgo authors +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*************************************************************/ + +#ifndef GKO_CUDA_BASE_KERNEL_CONFIG_HPP_ +#define GKO_CUDA_BASE_KERNEL_CONFIG_HPP_ + + +#include + + +#include + + +namespace gko { +namespace kernels { +namespace cuda { +namespace detail { + + +template +class shared_memory_config_guard { +public: + using value_type = ValueType; + shared_memory_config_guard() : original_config_{} + { + GKO_ASSERT_NO_CUDA_ERRORS( + cudaDeviceGetSharedMemConfig(&original_config_)); + + if (sizeof(value_type) == 4) { + GKO_ASSERT_NO_CUDA_ERRORS( + cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeFourByte)); + } else if (sizeof(value_type) % 8 == 0) { + GKO_ASSERT_NO_CUDA_ERRORS( + cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeEightByte)); + } else { + GKO_ASSERT_NO_CUDA_ERRORS( + cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeDefault)); + } + } + + + ~shared_memory_config_guard() + { + // No need to exit or throw if we cant set the value back. + cudaDeviceSetSharedMemConfig(original_config_); + } + +private: + cudaSharedMemConfig original_config_; +}; + + +} // namespace detail +} // namespace cuda +} // namespace kernels +} // namespace gko + + +#endif // GKO_CUDA_BASE_KERNEL_CONFIG_HPP_ diff --git a/cuda/matrix/batch_struct.hpp b/cuda/matrix/batch_struct.hpp index 4a2a1835961..55a30c043e3 100644 --- a/cuda/matrix/batch_struct.hpp +++ b/cuda/matrix/batch_struct.hpp @@ -92,7 +92,8 @@ get_batch_struct(batch::matrix::Dense* const op) * Generates an immutable uniform batch struct from a batch of ell matrices. 
*/ template -inline batch::matrix::ell::uniform_batch, IndexType> +inline batch::matrix::ell::uniform_batch, + const IndexType> get_batch_struct(const batch::matrix::Ell* const op) { return {as_cuda_type(op->get_const_values()), diff --git a/cuda/solver/batch_bicgstab_kernels.cu b/cuda/solver/batch_bicgstab_kernels.cu index ee7d0948b99..1d80f206c1b 100644 --- a/cuda/solver/batch_bicgstab_kernels.cu +++ b/cuda/solver/batch_bicgstab_kernels.cu @@ -33,21 +33,39 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "core/solver/batch_bicgstab_kernels.hpp" +#include +#include + + #include #include +#include "core/base/batch_struct.hpp" +#include "core/matrix/batch_struct.hpp" #include "core/solver/batch_dispatch.hpp" +#include "cuda/base/batch_struct.hpp" #include "cuda/base/config.hpp" +#include "cuda/base/kernel_config.hpp" +#include "cuda/base/thrust.cuh" #include "cuda/base/types.hpp" #include "cuda/components/cooperative_groups.cuh" +#include "cuda/components/reduction.cuh" #include "cuda/components/thread_ids.cuh" +#include "cuda/components/uninitialized_array.hpp" #include "cuda/matrix/batch_struct.hpp" namespace gko { namespace kernels { namespace cuda { + + +// NOTE: this default block size is not used for the main solver kernel. +constexpr int default_block_size = 256; +constexpr int sm_oversubscription = 4; + + /** * @brief The batch Bicgstab solver namespace. * @@ -56,19 +74,219 @@ namespace cuda { namespace batch_bicgstab { +#include "common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc" +#include "common/cuda_hip/components/uninitialized_array.hpp.inc" +#include "common/cuda_hip/matrix/batch_dense_kernels.hpp.inc" +#include "common/cuda_hip/matrix/batch_ell_kernels.hpp.inc" +#include "common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc" + + +template +int get_num_threads_per_block(std::shared_ptr exec, + const int num_rows) +{ + int num_warps = std::max(num_rows / 4, 2); + constexpr int warp_sz = static_cast(config::warp_size); + const int min_block_size = 2 * warp_sz; + const int device_max_threads = + ((std::max(num_rows, min_block_size)) / warp_sz) * warp_sz; + cudaFuncAttributes funcattr; + cudaFuncGetAttributes(&funcattr, + apply_kernel); + const int num_regs_used = funcattr.numRegs; + int max_regs_blk = 0; + cudaDeviceGetAttribute(&max_regs_blk, cudaDevAttrMaxRegistersPerBlock, + exec->get_device_id()); + const int max_threads_regs = + ((max_regs_blk / static_cast(num_regs_used)) / warp_sz) * warp_sz; + int max_threads = std::min(max_threads_regs, device_max_threads); + max_threads = max_threads <= 1024 ? 
max_threads : 1024; + return std::max(std::min(num_warps * warp_sz, max_threads), min_block_size); +} + + +template +int get_max_dynamic_shared_memory(std::shared_ptr exec) +{ + int shmem_per_sm = 0; + cudaDeviceGetAttribute(&shmem_per_sm, + cudaDevAttrMaxSharedMemoryPerMultiprocessor, + exec->get_device_id()); + GKO_ASSERT_NO_CUDA_ERRORS(cudaFuncSetAttribute( + apply_kernel, + cudaFuncAttributePreferredSharedMemoryCarveout, 99 /*%*/)); + cudaFuncAttributes funcattr; + cudaFuncGetAttributes(&funcattr, + apply_kernel); + return funcattr.maxDynamicSharedSizeBytes; +} + + template using settings = gko::kernels::batch_bicgstab::settings; +template +class kernel_caller { +public: + using value_type = CuValueType; + + kernel_caller(std::shared_ptr exec, + const settings> settings) + : exec_{std::move(exec)}, settings_{settings} + {} + + template + void launch_apply_kernel( + const gko::kernels::batch_bicgstab::storage_config& sconf, + LogType& logger, PrecType& prec, const BatchMatrixType& mat, + const value_type* const __restrict__ b_values, + value_type* const __restrict__ x_values, + value_type* const __restrict__ workspace_data, const int& block_size, + const size_t& shared_size) const + { + apply_kernel + <<get_stream()>>>(sconf, settings_.max_iterations, + settings_.residual_tol, logger, prec, mat, + b_values, x_values, workspace_data); + } + + + template + void call_kernel( + LogType logger, const BatchMatrixType& mat, PrecType prec, + const gko::batch::multi_vector::uniform_batch& b, + const gko::batch::multi_vector::uniform_batch& x) const + { + using real_type = gko::remove_complex; + const size_type num_batch_items = mat.num_batch_items; + constexpr int align_multiple = 8; + const int padded_num_rows = + ceildiv(mat.num_rows, align_multiple) * align_multiple; + auto shem_guard = + gko::kernels::cuda::detail::shared_memory_config_guard< + value_type>(); + const int shmem_per_blk = + get_max_dynamic_shared_memory(exec_); + const int block_size = + get_num_threads_per_block( + exec_, mat.num_rows); + GKO_ASSERT(block_size >= 2 * config::warp_size); + + const size_t prec_size = + PrecType::dynamic_work_size(padded_num_rows, + mat.get_single_item_num_nnz()) * + sizeof(value_type); + const auto sconf = + gko::kernels::batch_bicgstab::compute_shared_storage( + shmem_per_blk, padded_num_rows, mat.get_single_item_num_nnz(), + b.num_rhs); + const size_t shared_size = + sconf.n_shared * padded_num_rows * sizeof(value_type) + + (sconf.prec_shared ? 
prec_size : 0); + auto workspace = gko::array( + exec_, + sconf.gmem_stride_bytes * num_batch_items / sizeof(value_type)); + assert(sconf.gmem_stride_bytes % sizeof(value_type) == 0); + + value_type* const workspace_data = workspace.get_data(); + + // Template parameters launch_apply_kernel + if (sconf.prec_shared) { + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, workspace_data, + block_size, shared_size); + } else { + switch (sconf.n_shared) { + case 0: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, block_size, shared_size); + break; + case 1: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, block_size, shared_size); + break; + case 2: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, block_size, shared_size); + break; + case 3: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, block_size, shared_size); + break; + case 4: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, block_size, shared_size); + break; + case 5: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, block_size, shared_size); + break; + case 6: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, block_size, shared_size); + break; + case 7: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, block_size, shared_size); + break; + case 8: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, block_size, shared_size); + break; + case 9: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, block_size, shared_size); + break; + default: + GKO_NOT_IMPLEMENTED; + } + } + } + +private: + std::shared_ptr exec_; + const settings> settings_; +}; + + template void apply(std::shared_ptr exec, const settings>& settings, - const batch::BatchLinOp* const a, + const batch::BatchLinOp* const mat, const batch::BatchLinOp* const precon, const batch::MultiVector* const b, batch::MultiVector* const x, batch::log::detail::log_data>& logdata) - GKO_NOT_IMPLEMENTED; +{ + using cu_value_type = cuda_type; + auto dispatcher = batch::solver::create_dispatcher( + kernel_caller(exec, settings), settings, mat, precon); + dispatcher.apply(b, x, logdata); +} GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL); diff --git a/dpcpp/base/batch_multi_vector_kernels.dp.cpp b/dpcpp/base/batch_multi_vector_kernels.dp.cpp index e0bc15fdc61..51665d26ff9 100644 --- a/dpcpp/base/batch_multi_vector_kernels.dp.cpp +++ b/dpcpp/base/batch_multi_vector_kernels.dp.cpp @@ -33,6 +33,9 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#include "core/base/batch_multi_vector_kernels.hpp" +#include + + #include @@ -77,10 +80,15 @@ void scale(std::shared_ptr exec, const auto alpha_ub = get_batch_struct(alpha); const auto x_ub = get_batch_struct(x); + const int num_rows = x->get_common_size()[0]; + constexpr int max_subgroup_size = config::warp_size; const auto num_batches = x_ub.num_batch_items; auto device = exec->get_queue()->get_device(); - auto group_size = + long max_group_size = device.get_info(); + int group_size = + std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size, + max_group_size); const dim3 block(group_size); const dim3 grid(num_batches); @@ -125,13 +133,16 @@ void add_scaled(std::shared_ptr exec, const batch::MultiVector* const x, batch::MultiVector* const y) { - const size_type num_rows = x->get_common_size()[0]; - const size_type num_cols = x->get_common_size()[1]; - + constexpr int max_subgroup_size = config::warp_size; + const int num_rows = x->get_common_size()[0]; + const int num_cols = x->get_common_size()[1]; const auto num_batches = x->get_num_batch_items(); auto device = exec->get_queue()->get_device(); - auto group_size = + long max_group_size = device.get_info(); + int group_size = + std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size, + max_group_size); const dim3 block(group_size); const dim3 grid(num_batches); @@ -183,29 +194,59 @@ void compute_dot(std::shared_ptr exec, const auto y_ub = get_batch_struct(y); const auto res_ub = get_batch_struct(result); + constexpr int max_subgroup_size = config::warp_size; const auto num_batches = x_ub.num_batch_items; + const int num_rows = x_ub.num_rows; auto device = exec->get_queue()->get_device(); - auto group_size = + + long max_group_size = device.get_info(); + int group_size = + std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size, + max_group_size); const dim3 block(group_size); const dim3 grid(num_batches); - - // TODO: Remove reqd_sub_group size and use sycl::reduce_over_group - exec->get_queue()->submit([&](sycl::handler& cgh) { - cgh.parallel_for( - sycl_nd_range(grid, block), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( - config::warp_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto x_b = batch::extract_batch_item(x_ub, group_id); - const auto y_b = batch::extract_batch_item(y_ub, group_id); - const auto res_b = batch::extract_batch_item(res_ub, group_id); - compute_gen_dot_product_kernel(x_b, y_b, res_b, item_ct1, - [](auto val) { return val; }); - }); - }); + if (x->get_common_size()[1] == 1) { + exec->get_queue()->submit([&](sycl::handler& cgh) { + cgh.parallel_for( + sycl_nd_range(grid, block), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(max_subgroup_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto x_b = + batch::extract_batch_item(x_ub, group_id); + const auto y_b = + batch::extract_batch_item(y_ub, group_id); + const auto res_b = + batch::extract_batch_item(res_ub, group_id); + single_rhs_compute_conj_dot_sg( + x_b.num_rows, x_b.values, y_b.values, + res_b.values[0], item_ct1); + }); + }); + } else { + // TODO: Remove reqd_sub_group size and use sycl::reduce_over_group + exec->get_queue()->submit([&](sycl::handler& cgh) { + cgh.parallel_for( + sycl_nd_range(grid, block), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(max_subgroup_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto 
x_b = + batch::extract_batch_item(x_ub, group_id); + const auto y_b = + batch::extract_batch_item(y_ub, group_id); + const auto res_b = + batch::extract_batch_item(res_ub, group_id); + compute_gen_dot_product_kernel( + x_b, y_b, res_b, item_ct1, + [](auto val) { return val; }); + }); + }); + } } GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( @@ -222,10 +263,15 @@ void compute_conj_dot(std::shared_ptr exec, const auto y_ub = get_batch_struct(y); const auto res_ub = get_batch_struct(result); + constexpr int max_subgroup_size = config::warp_size; + const int num_rows = x->get_common_size()[0]; const auto num_batches = x_ub.num_batch_items; auto device = exec->get_queue()->get_device(); - auto group_size = + long max_group_size = device.get_info(); + int group_size = + std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size, + max_group_size); const dim3 block(group_size); const dim3 grid(num_batches); @@ -234,7 +280,7 @@ void compute_conj_dot(std::shared_ptr exec, cgh.parallel_for( sycl_nd_range(grid, block), [=](sycl::nd_item<3> item_ct1) - [[sycl::reqd_sub_group_size(config::warp_size)]] { + [[sycl::reqd_sub_group_size(max_subgroup_size)]] { auto group = item_ct1.get_group(); auto group_id = group.get_group_linear_id(); const auto x_b = batch::extract_batch_item(x_ub, group_id); @@ -261,26 +307,50 @@ void compute_norm2(std::shared_ptr exec, const auto res_ub = get_batch_struct(result); const auto num_batches = x_ub.num_batch_items; + const int num_rows = x->get_common_size()[0]; auto device = exec->get_queue()->get_device(); - auto group_size = + + constexpr int max_subgroup_size = config::warp_size; + long max_group_size = device.get_info(); + int group_size = + std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size, + max_group_size); const dim3 block(group_size); const dim3 grid(num_batches); - - exec->get_queue()->submit([&](sycl::handler& cgh) { - cgh.parallel_for(sycl_nd_range(grid, block), - [=](sycl::nd_item<3> item_ct1) - [[sycl::reqd_sub_group_size(config::warp_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto x_b = - batch::extract_batch_item(x_ub, group_id); - const auto res_b = batch::extract_batch_item( - res_ub, group_id); - compute_norm2_kernel(x_b, res_b, item_ct1); - }); - }); + if (x->get_common_size()[1] == 1) { + exec->get_queue()->submit([&](sycl::handler& cgh) { + cgh.parallel_for( + sycl_nd_range(grid, block), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(max_subgroup_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto x_b = + batch::extract_batch_item(x_ub, group_id); + const auto res_b = + batch::extract_batch_item(res_ub, group_id); + single_rhs_compute_norm2_sg(x_b.num_rows, x_b.values, + res_b.values[0], item_ct1); + }); + }); + } else { + exec->get_queue()->submit([&](sycl::handler& cgh) { + cgh.parallel_for( + sycl_nd_range(grid, block), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(max_subgroup_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto x_b = + batch::extract_batch_item(x_ub, group_id); + const auto res_b = + batch::extract_batch_item(res_ub, group_id); + compute_norm2_kernel(x_b, res_b, item_ct1); + }); + }); + } } GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( @@ -296,9 +366,14 @@ void copy(std::shared_ptr exec, const auto result_ub = get_batch_struct(result); const auto num_batches = x_ub.num_batch_items; + const int num_rows = 
x->get_common_size()[0]; auto device = exec->get_queue()->get_device(); - auto group_size = + constexpr int max_subgroup_size = config::warp_size; + long max_group_size = device.get_info(); + int group_size = + std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size, + max_group_size); const dim3 block(group_size); const dim3 grid(num_batches); diff --git a/dpcpp/base/batch_multi_vector_kernels.hpp.inc b/dpcpp/base/batch_multi_vector_kernels.hpp.inc index 22d00d780f9..be9d02aa88d 100644 --- a/dpcpp/base/batch_multi_vector_kernels.hpp.inc +++ b/dpcpp/base/batch_multi_vector_kernels.hpp.inc @@ -67,6 +67,53 @@ __dpct_inline__ void add_scaled_kernel( } +template +__dpct_inline__ void single_rhs_compute_conj_dot( + const int num_rows, const ValueType* const __restrict__ x, + const ValueType* const __restrict__ y, ValueType& result, + sycl::nd_item<3> item_ct1) +{ + const auto group = item_ct1.get_group(); + const auto group_size = item_ct1.get_local_range().size(); + const auto tid = item_ct1.get_local_linear_id(); + + ValueType val = zero(); + + for (int r = tid; r < num_rows; r += group_size) { + val += conj(x[r]) * y[r]; + } + result = sycl::reduce_over_group(group, val, sycl::plus<>()); +} + + +template +__dpct_inline__ void single_rhs_compute_conj_dot_sg( + const int num_rows, const ValueType* const __restrict__ x, + const ValueType* const __restrict__ y, ValueType& result, + sycl::nd_item<3> item_ct1) +{ + auto subg = + group::tiled_partition(group::this_thread_block(item_ct1)); + const auto subgroup = static_cast(subg); + const int subgroup_id = subgroup.get_group_id(); + const int subgroup_size = subgroup.get_local_range().size(); + const auto subgroup_tid = subgroup.get_local_id(); + + ValueType val = zero(); + + for (int r = subgroup_tid; r < num_rows; r += subgroup_size) { + val += conj(x[r]) * y[r]; + } + + val = ::gko::kernels::dpcpp::reduce( + subg, val, [](ValueType a, ValueType b) { return a + b; }); + + if (subgroup_tid == 0) { + result = val; + } +} + + template __dpct_inline__ void compute_gen_dot_product_kernel( const gko::batch::multi_vector::batch_item& x, @@ -102,6 +149,55 @@ __dpct_inline__ void compute_gen_dot_product_kernel( } +template +__dpct_inline__ void single_rhs_compute_norm2_sg( + const int num_rows, const ValueType* const __restrict__ x, + gko::remove_complex& result, sycl::nd_item<3> item_ct1) +{ + auto subg = + group::tiled_partition(group::this_thread_block(item_ct1)); + const auto subgroup = static_cast(subg); + const int subgroup_id = subgroup.get_group_id(); + const int subgroup_size = subgroup.get_local_range().size(); + + using real_type = typename gko::remove_complex; + real_type val = zero(); + + for (int r = subgroup.get_local_id(); r < num_rows; r += subgroup_size) { + val += squared_norm(x[r]); + } + + val = ::gko::kernels::dpcpp::reduce( + subg, val, [](real_type a, real_type b) { return a + b; }); + + if (subgroup.get_local_id() == 0) { + result = sqrt(val); + } +} + + +template +__dpct_inline__ void single_rhs_compute_norm2( + const int num_rows, const ValueType* const __restrict__ x, + gko::remove_complex& result, sycl::nd_item<3> item_ct1) +{ + const auto group = item_ct1.get_group(); + const auto group_size = item_ct1.get_local_range().size(); + const auto tid = item_ct1.get_local_linear_id(); + + using real_type = typename gko::remove_complex; + real_type val = zero(); + + for (int r = tid; r < num_rows; r += group_size) { + val += squared_norm(x[r]); + } + + val = sycl::reduce_over_group(group, val, sycl::plus<>()); + + result = 
sqrt(val); +} + + template __dpct_inline__ void compute_norm2_kernel( const gko::batch::multi_vector::batch_item& x, @@ -136,6 +232,17 @@ __dpct_inline__ void compute_norm2_kernel( } +template +__dpct_inline__ void copy_kernel(const int num_rows, const ValueType* in, + ValueType* out, sycl::nd_item<3>& item_ct1) +{ + for (int iz = item_ct1.get_local_linear_id(); iz < num_rows; + iz += item_ct1.get_local_range().size()) { + out[iz] = in[iz]; + } +} + + template __dpct_inline__ void copy_kernel( const gko::batch::multi_vector::batch_item& in, diff --git a/dpcpp/matrix/batch_dense_kernels.dp.cpp b/dpcpp/matrix/batch_dense_kernels.dp.cpp index a6fba2df8e3..d1320e79968 100644 --- a/dpcpp/matrix/batch_dense_kernels.dp.cpp +++ b/dpcpp/matrix/batch_dense_kernels.dp.cpp @@ -109,7 +109,8 @@ void simple_apply(std::shared_ptr exec, batch::matrix::extract_batch_item(mat_ub, group_id); const auto b_b = batch::extract_batch_item(b_ub, group_id); const auto x_b = batch::extract_batch_item(x_ub, group_id); - simple_apply_kernel(mat_b, b_b, x_b, item_ct1); + simple_apply_kernel(mat_b, b_b.values, x_b.values, + item_ct1); }); }); } @@ -160,7 +161,8 @@ void advanced_apply(std::shared_ptr exec, batch::extract_batch_item(alpha_ub, group_id); const auto beta_b = batch::extract_batch_item(beta_ub, group_id); - advanced_apply_kernel(alpha_b, mat_b, b_b, beta_b, x_b, + advanced_apply_kernel(alpha_b.values[0], mat_b, b_b.values, + beta_b.values[0], x_b.values, item_ct1); }); }); diff --git a/dpcpp/matrix/batch_dense_kernels.hpp.inc b/dpcpp/matrix/batch_dense_kernels.hpp.inc index 88ef5f54764..ba232ea02e4 100644 --- a/dpcpp/matrix/batch_dense_kernels.hpp.inc +++ b/dpcpp/matrix/batch_dense_kernels.hpp.inc @@ -33,9 +33,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
template __dpct_inline__ void simple_apply_kernel( const gko::batch::matrix::dense::batch_item& mat, - const gko::batch::multi_vector::batch_item& b, - const gko::batch::multi_vector::batch_item& x, - sycl::nd_item<3>& item_ct1) + const ValueType* b, ValueType* x, sycl::nd_item<3>& item_ct1) { constexpr auto tile_size = config::warp_size; auto subg = @@ -50,14 +48,14 @@ __dpct_inline__ void simple_apply_kernel( for (int j = subgroup.get_local_id(); j < mat.num_cols; j += subgroup_size) { const ValueType val = mat.values[row * mat.stride + j]; - temp += val * b.values[j]; + temp += val * b[j]; } temp = ::gko::kernels::dpcpp::reduce( subg, temp, [](ValueType a, ValueType b) { return a + b; }); if (subgroup.get_local_id() == 0) { - x.values[row] = temp; + x[row] = temp; } } } @@ -65,11 +63,9 @@ __dpct_inline__ void simple_apply_kernel( template __dpct_inline__ void advanced_apply_kernel( - const gko::batch::multi_vector::batch_item& alpha, + const ValueType alpha, const gko::batch::matrix::dense::batch_item& mat, - const gko::batch::multi_vector::batch_item& b, - const gko::batch::multi_vector::batch_item& beta, - const gko::batch::multi_vector::batch_item& x, + const ValueType* b, const ValueType beta, ValueType* x, sycl::nd_item<3>& item_ct1) { constexpr auto tile_size = config::warp_size; @@ -85,14 +81,14 @@ __dpct_inline__ void advanced_apply_kernel( for (int j = subgroup.get_local_id(); j < mat.num_cols; j += subgroup_size) { const ValueType val = mat.values[row * mat.stride + j]; - temp += alpha.values[0] * val * b.values[j]; + temp += alpha * val * b[j]; } temp = ::gko::kernels::dpcpp::reduce( subg, temp, [](ValueType a, ValueType b) { return a + b; }); if (subgroup.get_local_id() == 0) { - x.values[row] = temp + beta.values[0] * x.values[row]; + x[row] = temp + beta * x[row]; } } } diff --git a/dpcpp/matrix/batch_ell_kernels.dp.cpp b/dpcpp/matrix/batch_ell_kernels.dp.cpp index 5a69bbd3d5d..f565f69f270 100644 --- a/dpcpp/matrix/batch_ell_kernels.dp.cpp +++ b/dpcpp/matrix/batch_ell_kernels.dp.cpp @@ -106,7 +106,8 @@ void simple_apply(std::shared_ptr exec, batch::matrix::extract_batch_item(mat_ub, group_id); const auto b_b = batch::extract_batch_item(b_ub, group_id); const auto x_b = batch::extract_batch_item(x_ub, group_id); - simple_apply_kernel(mat_b, b_b, x_b, item_ct1); + simple_apply_kernel(mat_b, b_b.values, x_b.values, + item_ct1); }); }); } @@ -158,7 +159,8 @@ void advanced_apply(std::shared_ptr exec, batch::extract_batch_item(alpha_ub, group_id); const auto beta_b = batch::extract_batch_item(beta_ub, group_id); - advanced_apply_kernel(alpha_b, mat_b, b_b, beta_b, x_b, + advanced_apply_kernel(alpha_b.values[0], mat_b, b_b.values, + beta_b.values[0], x_b.values, item_ct1); }); }); diff --git a/dpcpp/matrix/batch_ell_kernels.hpp.inc b/dpcpp/matrix/batch_ell_kernels.hpp.inc index 64d71710dbb..8c54d48db7d 100644 --- a/dpcpp/matrix/batch_ell_kernels.hpp.inc +++ b/dpcpp/matrix/batch_ell_kernels.hpp.inc @@ -33,9 +33,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
template __dpct_inline__ void simple_apply_kernel( const gko::batch::matrix::ell::batch_item& mat, - const gko::batch::multi_vector::batch_item& b, - const gko::batch::multi_vector::batch_item& x, - sycl::nd_item<3>& item_ct1) + const ValueType* b, ValueType* x, sycl::nd_item<3>& item_ct1) { for (int tidx = item_ct1.get_local_linear_id(); tidx < mat.num_rows; tidx += item_ct1.get_local_range().size()) { @@ -45,22 +43,19 @@ __dpct_inline__ void simple_apply_kernel( if (col_idx == invalid_index()) { break; } else { - temp += mat.values[tidx + idx * mat.stride] * - b.values[col_idx * b.stride]; + temp += mat.values[tidx + idx * mat.stride] * b[col_idx]; } } - x.values[tidx * x.stride] = temp; + x[tidx] = temp; } } template __dpct_inline__ void advanced_apply_kernel( - const gko::batch::multi_vector::batch_item& alpha, + const ValueType alpha, const gko::batch::matrix::ell::batch_item& mat, - const gko::batch::multi_vector::batch_item& b, - const gko::batch::multi_vector::batch_item& beta, - const gko::batch::multi_vector::batch_item& x, + const ValueType* b, const ValueType beta, ValueType* x, sycl::nd_item<3>& item_ct1) { for (int tidx = item_ct1.get_local_linear_id(); tidx < mat.num_rows; @@ -71,11 +66,9 @@ __dpct_inline__ void advanced_apply_kernel( if (col_idx == invalid_index()) { break; } else { - temp += mat.values[tidx + idx * mat.stride] * - b.values[col_idx * b.stride]; + temp += mat.values[tidx + idx * mat.stride] * b[col_idx]; } } - x.values[tidx * x.stride] = - alpha.values[0] * temp + beta.values[0] * x.values[tidx * x.stride]; + x[tidx] = alpha * temp + beta * x[tidx]; } } diff --git a/dpcpp/matrix/batch_struct.hpp b/dpcpp/matrix/batch_struct.hpp index fe04407d82d..7f36378d8e1 100644 --- a/dpcpp/matrix/batch_struct.hpp +++ b/dpcpp/matrix/batch_struct.hpp @@ -91,7 +91,7 @@ inline batch::matrix::dense::uniform_batch get_batch_struct( * Generates an immutable uniform batch struct from a batch of ell matrices. */ template -inline batch::matrix::ell::uniform_batch +inline batch::matrix::ell::uniform_batch get_batch_struct(const batch::matrix::Ell* const op) { return {op->get_const_values(), diff --git a/dpcpp/preconditioner/batch_identity.hpp.inc b/dpcpp/preconditioner/batch_identity.hpp.inc index e15a4d37399..792886f845d 100644 --- a/dpcpp/preconditioner/batch_identity.hpp.inc +++ b/dpcpp/preconditioner/batch_identity.hpp.inc @@ -42,15 +42,9 @@ public: static int dynamic_work_size(int, int) { return 0; } - void generate(size_type batch_id, - const gko::batch::matrix::ell::batch_item&, - ValueType* const, sycl::nd_item<3> item_ct1) - {} - - void generate(size_type batch_id, - const gko::batch::matrix::dense::batch_item&, - ValueType* const, sycl::nd_item<3> item_ct1) + template + void generate(size_type, const batch_item_type&, ValueType*, + sycl::nd_item<3> item_ct1) {} __dpct_inline__ void apply(const int num_rows, const ValueType* const r, diff --git a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp index 81519d8e2aa..9e353734f36 100644 --- a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp +++ b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp @@ -33,12 +33,26 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#include "core/solver/batch_bicgstab_kernels.hpp" -#include -#include +#include +#include +#include +#include + + +#include "core/base/batch_struct.hpp" +#include "core/matrix/batch_struct.hpp" #include "core/solver/batch_dispatch.hpp" +#include "dpcpp/base/batch_struct.hpp" #include "dpcpp/base/config.hpp" +#include "dpcpp/base/dim3.dp.hpp" +#include "dpcpp/base/dpct.hpp" +#include "dpcpp/base/helper.hpp" +#include "dpcpp/components/cooperative_groups.dp.hpp" +#include "dpcpp/components/intrinsics.dp.hpp" +#include "dpcpp/components/reduction.dp.hpp" +#include "dpcpp/components/thread_ids.dp.hpp" #include "dpcpp/matrix/batch_struct.hpp" @@ -53,19 +67,226 @@ namespace dpcpp { namespace batch_bicgstab { +#include "dpcpp/base/batch_multi_vector_kernels.hpp.inc" +#include "dpcpp/matrix/batch_dense_kernels.hpp.inc" +#include "dpcpp/matrix/batch_ell_kernels.hpp.inc" +#include "dpcpp/solver/batch_bicgstab_kernels.hpp.inc" + + template using settings = gko::kernels::batch_bicgstab::settings; +__dpct_inline__ int get_group_size(int value, + int subgroup_size = config::warp_size) +{ + int num_sg = ceildiv(value, subgroup_size); + return num_sg * subgroup_size; +} + + +template +class KernelCaller { +public: + KernelCaller(std::shared_ptr exec, + const settings> settings) + : exec_{std::move(exec)}, settings_{settings} + {} + + template + __dpct_inline__ void launch_apply_kernel( + const gko::kernels::batch_bicgstab::storage_config& sconf, + LogType& logger, PrecType& prec, const BatchMatrixType mat, + const ValueType* const __restrict__ b_values, + ValueType* const __restrict__ x_values, + ValueType* const __restrict__ workspace, const int& group_size, + const int& shared_size) const + { + auto num_rows = mat.num_rows; + + const dim3 block(group_size); + const dim3 grid(mat.num_batch_items); + + auto max_iters = settings_.max_iterations; + auto res_tol = settings_.residual_tol; + + exec_->get_queue()->submit([&](sycl::handler& cgh) { + sycl::accessor + slm_values(sycl::range<1>(shared_size), cgh); + + cgh.parallel_for( + sycl_nd_range(grid, block), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size( + subgroup_size)]] [[intel::kernel_args_restrict]] { + auto batch_id = item_ct1.get_group_linear_id(); + const auto mat_global_entry = + gko::batch::matrix::extract_batch_item(mat, batch_id); + const ValueType* const b_global_entry = + gko::batch::multi_vector::batch_item_ptr( + b_values, 1, num_rows, batch_id); + ValueType* const x_global_entry = + gko::batch::multi_vector::batch_item_ptr( + x_values, 1, num_rows, batch_id); + apply_kernel( + sconf, max_iters, res_tol, logger, prec, + mat_global_entry, b_global_entry, x_global_entry, + num_rows, mat.get_single_item_num_nnz(), + static_cast(slm_values.get_pointer()), + item_ct1, workspace); + }); + }); + } + + template + void call_kernel( + LogType logger, const BatchMatrixType& mat, PrecType prec, + const gko::batch::multi_vector::uniform_batch& b, + const gko::batch::multi_vector::uniform_batch& x) const + { + using real_type = gko::remove_complex; + const size_type num_batch_items = mat.num_batch_items; + const auto num_rows = mat.num_rows; + const auto num_rhs = b.num_rhs; + GKO_ASSERT(num_rhs == 1); + + auto device = exec_->get_queue()->get_device(); + auto max_group_size = + device.get_info(); + int group_size = + device.get_info(); + if (group_size > num_rows) { + group_size = get_group_size(num_rows); + }; + group_size = std::min( + std::max(group_size, static_cast(2 * config::warp_size)), + static_cast(max_group_size)); + + // 
reserve 5 for intermediate rho-s, norms, + // alpha, omega, temp and for reduce_over_group + // If the value available is negative, then set it to 0 + const int static_var_mem = + (group_size + 5) * sizeof(ValueType) + 2 * sizeof(real_type); + int shmem_per_blk = std::max( + static_cast( + device.get_info()) - + static_var_mem, + 0); + const int padded_num_rows = num_rows; + const size_type prec_size = PrecType::dynamic_work_size( + padded_num_rows, mat.get_single_item_num_nnz()); + const auto sconf = + gko::kernels::batch_bicgstab::compute_shared_storage( + shmem_per_blk, padded_num_rows, mat.get_single_item_num_nnz(), + b.num_rhs); + const size_t shared_size = sconf.n_shared * padded_num_rows + + (sconf.prec_shared ? prec_size : 0); + auto workspace = gko::array( + exec_, + sconf.gmem_stride_bytes * num_batch_items / sizeof(ValueType)); + GKO_ASSERT(sconf.gmem_stride_bytes % sizeof(ValueType) == 0); + + ValueType* const workspace_data = workspace.get_data(); + int n_shared_total = sconf.n_shared + int(sconf.prec_shared); + + // template + // launch_apply_kernel + if (num_rows <= 32 && n_shared_total == 10) { + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, workspace_data, + group_size, shared_size); + } else if (num_rows <= 256 && n_shared_total == 10) { + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, workspace_data, + group_size, shared_size); + } else { + switch (n_shared_total) { + case 0: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, group_size, shared_size); + break; + case 1: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, group_size, shared_size); + break; + case 2: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, group_size, shared_size); + break; + case 3: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, group_size, shared_size); + break; + case 4: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, group_size, shared_size); + break; + case 5: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, group_size, shared_size); + break; + case 6: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, group_size, shared_size); + break; + case 7: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, group_size, shared_size); + break; + case 8: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, group_size, shared_size); + break; + case 9: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, group_size, shared_size); + break; + case 10: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, group_size, shared_size); + break; + default: + GKO_NOT_IMPLEMENTED; + } + } + } + +private: + std::shared_ptr exec_; + const settings> settings_; +}; + + template void apply(std::shared_ptr exec, const settings>& settings, - const batch::BatchLinOp* const a, - const batch::BatchLinOp* const precon, + const batch::BatchLinOp* const mat, + const batch::BatchLinOp* const precond, const batch::MultiVector* const b, batch::MultiVector* const x, batch::log::detail::log_data>& logdata) - GKO_NOT_IMPLEMENTED; +{ + auto dispatcher = batch::solver::create_dispatcher( + KernelCaller(exec, settings), settings, mat, precond); + dispatcher.apply(b, x, logdata); +} + 
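A minimal stand-alone sketch of the shared-memory budgeting performed in KernelCaller::call_kernel above (helper names are hypothetical; the quantities n_shared, prec_shared, padded_num_rows, prec_size and gmem_stride_bytes are the ones used by the patch): n_shared work vectors of padded length stay in SLM, the preconditioner scratch is added only when it also fits, and everything else is served from the global workspace with one stride-sized slice per batch item.

#include <cstddef>

// Mirrors: sconf.n_shared * padded_num_rows + (sconf.prec_shared ? prec_size : 0)
// (all quantities are counts of ValueType elements in the DPC++ caller).
inline std::size_t slm_value_count(int n_shared, bool prec_shared,
                                   int padded_num_rows, std::size_t prec_size)
{
    return static_cast<std::size_t>(n_shared) * padded_num_rows +
           (prec_shared ? prec_size : 0);
}

// Mirrors the per-batch-item offset into the global workspace:
// batch_id * sconf.gmem_stride_bytes / sizeof(ValueType).
inline std::size_t gmem_value_offset(std::size_t batch_id,
                                     std::size_t gmem_stride_bytes,
                                     std::size_t value_size)
{
    return batch_id * gmem_stride_bytes / value_size;
}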
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL); diff --git a/dpcpp/solver/batch_bicgstab_kernels.hpp.inc b/dpcpp/solver/batch_bicgstab_kernels.hpp.inc new file mode 100644 index 00000000000..03f8ea31165 --- /dev/null +++ b/dpcpp/solver/batch_bicgstab_kernels.hpp.inc @@ -0,0 +1,413 @@ +/************************************************************* +Copyright (c) 2017-2023, the Ginkgo authors +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*************************************************************/ + +template +__dpct_inline__ void initialize( + const int num_rows, const BatchMatrixType_entry& mat_global_entry, + const ValueType* const b_global_entry, + const ValueType* const x_global_entry, ValueType& rho_old, ValueType& omega, + ValueType& alpha, ValueType* const x_shared_entry, + ValueType* const r_shared_entry, ValueType* const r_hat_shared_entry, + ValueType* const p_shared_entry, ValueType* const v_shared_entry, + ValueType* const p_hat_shared_entry, + typename gko::remove_complex& rhs_norm, + typename gko::remove_complex& res_norm, + sycl::nd_item<3> item_ct1) +{ + auto sg = item_ct1.get_sub_group(); + const auto sg_id = sg.get_group_id(); + const auto tid = item_ct1.get_local_linear_id(); + const auto group_size = item_ct1.get_local_range().size(); + const auto group = item_ct1.get_group(); + + rho_old = one(); + omega = one(); + alpha = one(); + + // copy x from global to shared memory + // r = b + for (int iz = tid; iz < num_rows; iz += group_size) { + x_shared_entry[iz] = x_global_entry[iz]; + r_shared_entry[iz] = b_global_entry[iz]; + } + item_ct1.barrier(sycl::access::fence_space::global_and_local); + + // r = b - A*x + advanced_apply_kernel(static_cast(-1.0), mat_global_entry, + x_shared_entry, static_cast(1.0), + r_shared_entry, item_ct1); + item_ct1.barrier(sycl::access::fence_space::global_and_local); + + if (sg_id == 0) { + single_rhs_compute_norm2_sg(num_rows, r_shared_entry, res_norm, + item_ct1); + } else if (sg_id == 1) { + single_rhs_compute_norm2_sg(num_rows, b_global_entry, rhs_norm, + item_ct1); + } + item_ct1.barrier(sycl::access::fence_space::global_and_local); + + + for (int iz = tid; iz < num_rows; iz += group_size) { + r_hat_shared_entry[iz] = r_shared_entry[iz]; + p_shared_entry[iz] = zero(); + p_hat_shared_entry[iz] = zero(); + v_shared_entry[iz] = zero(); + } +} + + +template +__dpct_inline__ void update_p(const int num_rows, const ValueType& rho_new, + const ValueType& rho_old, const ValueType& alpha, + const ValueType& omega, + const ValueType* const r_shared_entry, + const ValueType* const v_shared_entry, + ValueType* const p_shared_entry, + sycl::nd_item<3> item_ct1) +{ + const ValueType beta = (rho_new / rho_old) * (alpha / omega); + for (int r = item_ct1.get_local_linear_id(); r < num_rows; + r += item_ct1.get_local_range().size()) { + p_shared_entry[r] = + r_shared_entry[r] + + beta * (p_shared_entry[r] - omega * v_shared_entry[r]); + } +} + + +template +__dpct_inline__ void compute_alpha(const int num_rows, const ValueType& rho_new, + const ValueType* const r_hat_shared_entry, + const ValueType* const v_shared_entry, + ValueType& alpha, sycl::nd_item<3> item_ct1) +{ + auto sg = item_ct1.get_sub_group(); + const auto sg_id = sg.get_group_id(); + const auto tid = item_ct1.get_local_linear_id(); + if (sg_id == 0) { + single_rhs_compute_conj_dot_sg(num_rows, r_hat_shared_entry, + v_shared_entry, alpha, item_ct1); + } + item_ct1.barrier(sycl::access::fence_space::global_and_local); + if (tid == 0) { + alpha = rho_new / alpha; + } + item_ct1.barrier(sycl::access::fence_space::global_and_local); +} + + +template +__dpct_inline__ void update_s(const int num_rows, + const ValueType* const r_shared_entry, + const ValueType& alpha, + const ValueType* const v_shared_entry, + ValueType* const s_shared_entry, + sycl::nd_item<3> item_ct1) +{ + for (int r = item_ct1.get_local_linear_id(); r < num_rows; + r += item_ct1.get_local_range().size()) { + s_shared_entry[r] = r_shared_entry[r] - alpha * 
v_shared_entry[r]; + } +} + + +template +__dpct_inline__ void compute_omega(const int num_rows, + const ValueType* const t_shared_entry, + const ValueType* const s_shared_entry, + ValueType& temp, ValueType& omega, + sycl::nd_item<3> item_ct1) +{ + auto sg = item_ct1.get_sub_group(); + const auto sg_id = sg.get_group_id(); + const auto tid = item_ct1.get_local_linear_id(); + if (sg_id == 0) { + single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry, s_shared_entry, + omega, item_ct1); + } else if (sg_id == 1) { + single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry, t_shared_entry, + temp, item_ct1); + } + item_ct1.barrier(sycl::access::fence_space::global_and_local); + if (tid == 0) { + omega /= temp; + } + item_ct1.barrier(sycl::access::fence_space::global_and_local); +} + + +template +__dpct_inline__ void update_x_and_r( + const int num_rows, const ValueType* const p_hat_shared_entry, + const ValueType* const s_hat_shared_entry, const ValueType& alpha, + const ValueType& omega, const ValueType* const s_shared_entry, + const ValueType* const t_shared_entry, ValueType* const x_shared_entry, + ValueType* const r_shared_entry, sycl::nd_item<3> item_ct1) +{ + for (int r = item_ct1.get_local_linear_id(); r < num_rows; + r += item_ct1.get_local_range().size()) { + x_shared_entry[r] = x_shared_entry[r] + alpha * p_hat_shared_entry[r] + + omega * s_hat_shared_entry[r]; + r_shared_entry[r] = s_shared_entry[r] - omega * t_shared_entry[r]; + } +} + + +template +__dpct_inline__ void update_x_middle(const int num_rows, const ValueType& alpha, + const ValueType* const p_hat_shared_entry, + ValueType* const x_shared_entry, + sycl::nd_item<3> item_ct1) +{ + for (int r = item_ct1.get_local_linear_id(); r < num_rows; + r += item_ct1.get_local_range().size()) { + x_shared_entry[r] = x_shared_entry[r] + alpha * p_hat_shared_entry[r]; + } +} + + +template +void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf, + const int max_iter, const gko::remove_complex tol, + LogType logger, PrecType prec_shared, + const BatchMatrixType mat_global_entry, + const ValueType* const __restrict__ b_global_entry, + ValueType* const __restrict__ x_global_entry, + const size_type num_rows, const size_type nnz, + ValueType* const __restrict__ slm_values, + sycl::nd_item<3> item_ct1, + ValueType* const __restrict__ workspace = nullptr) +{ + using real_type = typename gko::remove_complex; + + const auto sg = item_ct1.get_sub_group(); + const int sg_id = sg.get_group_id(); + const int tid = item_ct1.get_local_linear_id(); + auto group = item_ct1.get_group(); + const int group_size = item_ct1.get_local_range().size(); + + const auto batch_id = item_ct1.get_group_linear_id(); + + ValueType* rho_old_sh; + ValueType* rho_new_sh; + ValueType* alpha_sh; + ValueType* omega_sh; + ValueType* temp_sh; + real_type* norms_rhs_sh; + real_type* norms_res_sh; + + using tile_value_t = ValueType[5]; + tile_value_t& values = + *sycl::ext::oneapi::group_local_memory_for_overwrite( + group); + using tile_real_t = real_type[2]; + tile_real_t& reals = + *sycl::ext::oneapi::group_local_memory_for_overwrite( + group); + rho_old_sh = &values[0]; + rho_new_sh = &values[1]; + alpha_sh = &values[2]; + omega_sh = &values[3]; + temp_sh = &values[4]; + norms_rhs_sh = &reals[0]; + norms_res_sh = &reals[1]; + const int gmem_offset = + batch_id * sconf.gmem_stride_bytes / sizeof(ValueType); + ValueType* p_hat_sh; + ValueType* s_hat_sh; + ValueType* s_sh; + ValueType* p_sh; + ValueType* r_sh; + ValueType* r_hat_sh; + ValueType* v_sh; + 
ValueType* t_sh; + ValueType* x_sh; + ValueType* prec_work_sh; + + if constexpr (n_shared_total >= 1) { + p_hat_sh = slm_values; + } else { + p_hat_sh = workspace + gmem_offset; + } + if constexpr (n_shared_total == 1) { + s_hat_sh = workspace + gmem_offset; + } else { + s_hat_sh = p_hat_sh + sconf.padded_vec_len; + } + if constexpr (n_shared_total == 2) { + v_sh = workspace + gmem_offset; + } else { + v_sh = s_hat_sh + sconf.padded_vec_len; + } + if constexpr (n_shared_total == 3) { + t_sh = workspace + gmem_offset; + } else { + t_sh = v_sh + sconf.padded_vec_len; + } + if constexpr (n_shared_total == 4) { + p_sh = workspace + gmem_offset; + } else { + p_sh = t_sh + sconf.padded_vec_len; + } + if constexpr (n_shared_total == 5) { + s_sh = workspace + gmem_offset; + } else { + s_sh = p_sh + sconf.padded_vec_len; + } + if constexpr (n_shared_total == 6) { + r_sh = workspace + gmem_offset; + } else { + r_sh = s_sh + sconf.padded_vec_len; + } + if constexpr (n_shared_total == 7) { + r_hat_sh = workspace + gmem_offset; + } else { + r_hat_sh = r_sh + sconf.padded_vec_len; + } + if constexpr (n_shared_total == 8) { + x_sh = workspace + gmem_offset; + } else { + x_sh = r_hat_sh + sconf.padded_vec_len; + } + if constexpr (n_shared_total == 9) { + prec_work_sh = workspace + gmem_offset; + } else { + prec_work_sh = x_sh + sconf.padded_vec_len; + } + + // generate preconditioner + prec_shared.generate(batch_id, mat_global_entry, prec_work_sh, item_ct1); + + // initialization + // rho_old = 1, omega = 1, alpha = 1 + // compute b norms + // copy x from global to shared memory + // r = b - A*x + // compute residual norms + // r_hat = r + // p = 0 + // p_hat = 0 + // v = 0 + initialize(num_rows, mat_global_entry, b_global_entry, x_global_entry, + rho_old_sh[0], omega_sh[0], alpha_sh[0], x_sh, r_sh, r_hat_sh, + p_sh, p_hat_sh, v_sh, norms_rhs_sh[0], norms_res_sh[0], + item_ct1); + item_ct1.barrier(sycl::access::fence_space::global_and_local); + + // stopping criterion object + StopType stop(tol, norms_rhs_sh); + + int iter = 0; + for (; iter < max_iter; iter++) { + if (stop.check_converged(norms_res_sh)) { + logger.log_iteration(batch_id, iter, norms_res_sh[0]); + break; + } + + // rho_new = < r_hat , r > = (r_hat)' * (r) + if (sg_id == 0) { + single_rhs_compute_conj_dot_sg(num_rows, r_hat_sh, r_sh, + rho_new_sh[0], item_ct1); + } + item_ct1.barrier(sycl::access::fence_space::global_and_local); + + // beta = (rho_new / rho_old)*(alpha / omega) + // p = r + beta*(p - omega * v) + update_p(num_rows, rho_new_sh[0], rho_old_sh[0], alpha_sh[0], + omega_sh[0], r_sh, v_sh, p_sh, item_ct1); + item_ct1.barrier(sycl::access::fence_space::global_and_local); + + // p_hat = precond * p + prec_shared.apply(num_rows, p_sh, p_hat_sh, item_ct1); + item_ct1.barrier(sycl::access::fence_space::global_and_local); + + // v = A * p_hat + simple_apply_kernel(mat_global_entry, p_hat_sh, v_sh, item_ct1); + item_ct1.barrier(sycl::access::fence_space::global_and_local); + + // alpha = rho_new / < r_hat , v> + compute_alpha(num_rows, rho_new_sh[0], r_hat_sh, v_sh, alpha_sh[0], + item_ct1); + item_ct1.barrier(sycl::access::fence_space::global_and_local); + + // s = r - alpha*v + update_s(num_rows, r_sh, alpha_sh[0], v_sh, s_sh, item_ct1); + item_ct1.barrier(sycl::access::fence_space::global_and_local); + + // an estimate of residual norms + if (sg_id == 0) { + single_rhs_compute_norm2_sg(num_rows, s_sh, norms_res_sh[0], + item_ct1); + } + item_ct1.barrier(sycl::access::fence_space::global_and_local); + + if 
(stop.check_converged(norms_res_sh)) { + update_x_middle(num_rows, alpha_sh[0], p_hat_sh, x_sh, item_ct1); + logger.log_iteration(batch_id, iter, norms_res_sh[0]); + break; + } + + // s_hat = precond * s + prec_shared.apply(num_rows, s_sh, s_hat_sh, item_ct1); + item_ct1.barrier(sycl::access::fence_space::global_and_local); + + // t = A * s_hat + simple_apply_kernel(mat_global_entry, s_hat_sh, t_sh, item_ct1); + item_ct1.barrier(sycl::access::fence_space::global_and_local); + + // omega = / + compute_omega(num_rows, t_sh, s_sh, temp_sh[0], omega_sh[0], item_ct1); + item_ct1.barrier(sycl::access::fence_space::global_and_local); + + // x = x + alpha*p_hat + omega *s_hat + // r = s - omega * t + update_x_and_r(num_rows, p_hat_sh, s_hat_sh, alpha_sh[0], omega_sh[0], + s_sh, t_sh, x_sh, r_sh, item_ct1); + item_ct1.barrier(sycl::access::fence_space::global_and_local); + + if (sg_id == 0) + single_rhs_compute_norm2_sg(num_rows, r_sh, norms_res_sh[0], + item_ct1); + if (tid == group_size - 1) { + rho_old_sh[0] = rho_new_sh[0]; + } + item_ct1.barrier(sycl::access::fence_space::global_and_local); + } + + logger.log_iteration(batch_id, iter, norms_res_sh[0]); + + // copy x back to global memory + copy_kernel(num_rows, x_sh, x_global_entry, item_ct1); + item_ct1.barrier(sycl::access::fence_space::global_and_local); +} diff --git a/hip/matrix/batch_struct.hip.hpp b/hip/matrix/batch_struct.hip.hpp index e35f13f1249..ba75b1b634e 100644 --- a/hip/matrix/batch_struct.hip.hpp +++ b/hip/matrix/batch_struct.hip.hpp @@ -92,7 +92,8 @@ get_batch_struct(batch::matrix::Dense* const op) * Generates an immutable uniform batch struct from a batch of ell matrices. */ template -inline batch::matrix::ell::uniform_batch, IndexType> +inline batch::matrix::ell::uniform_batch, + const IndexType> get_batch_struct(const batch::matrix::Ell* const op) { return {as_hip_type(op->get_const_values()), diff --git a/hip/solver/batch_bicgstab_kernels.hip.cpp b/hip/solver/batch_bicgstab_kernels.hip.cpp index 4ef8cd36c1b..217d314a5c9 100644 --- a/hip/solver/batch_bicgstab_kernels.hip.cpp +++ b/hip/solver/batch_bicgstab_kernels.hip.cpp @@ -34,21 +34,37 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include +#include +#include #include #include +#include "core/base/batch_struct.hpp" +#include "core/matrix/batch_struct.hpp" #include "core/solver/batch_dispatch.hpp" #include "hip/base/batch_struct.hip.hpp" #include "hip/base/config.hip.hpp" +#include "hip/base/math.hip.hpp" +#include "hip/base/thrust.hip.hpp" +#include "hip/base/types.hip.hpp" +#include "hip/components/cooperative_groups.hip.hpp" +#include "hip/components/reduction.hip.hpp" +#include "hip/components/thread_ids.hip.hpp" +#include "hip/components/uninitialized_array.hip.hpp" #include "hip/matrix/batch_struct.hip.hpp" namespace gko { namespace kernels { namespace hip { + + +constexpr int default_block_size = 256; +constexpr int sm_oversubscription = 4; + /** * @brief The batch Bicgstab solver namespace. 
* @@ -57,19 +73,193 @@ namespace hip { namespace batch_bicgstab { +#include "common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc" +#include "common/cuda_hip/components/uninitialized_array.hpp.inc" +#include "common/cuda_hip/matrix/batch_dense_kernels.hpp.inc" +#include "common/cuda_hip/matrix/batch_ell_kernels.hpp.inc" +#include "common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc" + + +template +int get_num_threads_per_block(std::shared_ptr exec, + const int num_rows) +{ + int num_warps = std::max(num_rows / 4, 2); + constexpr int warp_sz = static_cast(config::warp_size); + const int min_block_size = 2 * warp_sz; + const int device_max_threads = + ((std::max(num_rows, min_block_size)) / warp_sz) * warp_sz; + // This value has been taken from ROCm docs. This is the number of registers + // that maximizes the occupancy on an AMD GPU (MI200). HIP does not have an + // API to query the number of registers a function uses. + const int num_regs_used_per_thread = 64; + int max_regs_blk = 0; + GKO_ASSERT_NO_HIP_ERRORS(hipDeviceGetAttribute( + &max_regs_blk, hipDeviceAttributeMaxRegistersPerBlock, + exec->get_device_id())); + const int max_threads_regs = (max_regs_blk / num_regs_used_per_thread); + int max_threads = std::min(max_threads_regs, device_max_threads); + max_threads = max_threads <= 1024 ? max_threads : 1024; + return std::max(std::min(num_warps * warp_sz, max_threads), min_block_size); +} + + template using settings = gko::kernels::batch_bicgstab::settings; +template +class kernel_caller { +public: + using value_type = HipValueType; + + kernel_caller(std::shared_ptr exec, + const settings> settings) + : exec_{exec}, settings_{settings} + {} + + template + void launch_apply_kernel( + const gko::kernels::batch_bicgstab::storage_config& sconf, + LogType& logger, PrecType& prec, const BatchMatrixType& mat, + const value_type* const __restrict__ b_values, + value_type* const __restrict__ x_values, + value_type* const __restrict__ workspace_data, const int& block_size, + const size_t& shared_size) const + { + apply_kernel + <<get_stream()>>>(sconf, settings_.max_iterations, + settings_.residual_tol, logger, prec, mat, + b_values, x_values, workspace_data); + } + + + template + void call_kernel( + LogType logger, const BatchMatrixType& mat, PrecType prec, + const gko::batch::multi_vector::uniform_batch& b, + const gko::batch::multi_vector::uniform_batch& x) const + { + using real_type = gko::remove_complex; + const size_type num_batch_items = mat.num_batch_items; + constexpr int align_multiple = 8; + const int padded_num_rows = + ceildiv(mat.num_rows, align_multiple) * align_multiple; + int shmem_per_blk = 0; + GKO_ASSERT_NO_HIP_ERRORS(hipDeviceGetAttribute( + &shmem_per_blk, hipDeviceAttributeMaxSharedMemoryPerBlock, + exec_->get_device_id())); + const int block_size = + get_num_threads_per_block(exec_, mat.num_rows); + GKO_ASSERT(block_size >= 2 * config::warp_size); + + const size_t prec_size = + PrecType::dynamic_work_size(padded_num_rows, + mat.get_single_item_num_nnz()) * + sizeof(value_type); + const auto sconf = + gko::kernels::batch_bicgstab::compute_shared_storage( + shmem_per_blk, padded_num_rows, mat.get_single_item_num_nnz(), + b.num_rhs); + const size_t shared_size = + sconf.n_shared * padded_num_rows * sizeof(value_type) + + (sconf.prec_shared ? 
prec_size : 0); + auto workspace = gko::array( + exec_, + sconf.gmem_stride_bytes * num_batch_items / sizeof(value_type)); + assert(sconf.gmem_stride_bytes % sizeof(value_type) == 0); + + value_type* const workspace_data = workspace.get_data(); + + // Template parameters launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, workspace_data, + block_size, shared_size); + } else { + switch (sconf.n_shared) { + case 0: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, block_size, shared_size); + break; + case 1: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, block_size, shared_size); + break; + case 2: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, block_size, shared_size); + break; + case 3: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, block_size, shared_size); + break; + case 4: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, block_size, shared_size); + break; + case 5: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, block_size, shared_size); + break; + case 6: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, block_size, shared_size); + break; + case 7: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, block_size, shared_size); + break; + case 8: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, block_size, shared_size); + break; + case 9: + launch_apply_kernel( + sconf, logger, prec, mat, b.values, x.values, + workspace_data, block_size, shared_size); + break; + default: + GKO_NOT_IMPLEMENTED; + } + } + } + +private: + std::shared_ptr exec_; + const settings> settings_; +}; + + template void apply(std::shared_ptr exec, const settings>& settings, - const batch::BatchLinOp* const a, + const batch::BatchLinOp* const mat, const batch::BatchLinOp* const precon, const batch::MultiVector* const b, batch::MultiVector* const x, batch::log::detail::log_data>& logdata) - GKO_NOT_IMPLEMENTED; +{ + using hip_value_type = hip_type; + auto dispatcher = batch::solver::create_dispatcher( + kernel_caller(exec, settings), settings, mat, precon); + dispatcher.apply(b, x, logdata); +} GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL); diff --git a/include/ginkgo/core/solver/batch_solver_base.hpp b/include/ginkgo/core/solver/batch_solver_base.hpp index 3141812e259..cd4ae8d1590 100644 --- a/include/ginkgo/core/solver/batch_solver_base.hpp +++ b/include/ginkgo/core/solver/batch_solver_base.hpp @@ -177,30 +177,9 @@ class BatchSolver { }; -/** - * The parameter type shared between all preconditioned iterative solvers, - * excluding the parameters available in iterative_solver_factory_parameters. - * @see GKO_CREATE_FACTORY_PARAMETERS - */ -struct preconditioned_iterative_solver_factory_parameters { - /** - * The preconditioner to be used by the iterative solver. By default, no - * preconditioner is used. - */ - std::shared_ptr preconditioner{nullptr}; - - /** - * Already generated preconditioner. If one is provided, the factory - * `preconditioner` will be ignored. 
- */ - std::shared_ptr generated_preconditioner{nullptr}; -}; - - template struct enable_preconditioned_iterative_solver_factory_parameters - : enable_parameters_type, - preconditioned_iterative_solver_factory_parameters { + : enable_parameters_type { /** * Default maximum number iterations allowed. * @@ -225,40 +204,18 @@ struct enable_preconditioned_iterative_solver_factory_parameters tolerance_type, ::gko::batch::stop::tolerance_type::absolute); /** - * Provides a preconditioner factory to be used by the iterative solver in a - * fluent interface. - * @see preconditioned_iterative_solver_factory_parameters::preconditioner + * The preconditioner to be used by the iterative solver. By default, no + * preconditioner is used. */ - Parameters& with_preconditioner( - deferred_factory_parameter preconditioner) - { - this->preconditioner_generator = std::move(preconditioner); - this->deferred_factories["preconditioner"] = [](const auto& exec, - auto& params) { - if (!params.preconditioner_generator.is_empty()) { - params.preconditioner = - params.preconditioner_generator.on(exec); - } - }; - return *self(); - } + std::shared_ptr GKO_DEFERRED_FACTORY_PARAMETER( + preconditioner); /** - * Provides a concrete preconditioner to be used by the iterative solver in - * a fluent interface. - * @see preconditioned_iterative_solver_factory_parameters::preconditioner + * Already generated preconditioner. If one is provided, the factory + * `preconditioner` will be ignored. */ - Parameters& with_generated_preconditioner( - std::shared_ptr generated_preconditioner) - { - this->generated_preconditioner = std::move(generated_preconditioner); - return *self(); - } - -private: - GKO_ENABLE_SELF(Parameters); - - deferred_factory_parameter preconditioner_generator; + std::shared_ptr GKO_FACTORY_PARAMETER_SCALAR( + generated_preconditioner, nullptr); }; @@ -277,6 +234,7 @@ class EnableBatchSolver public EnableBatchLinOp { public: using real_type = remove_complex; + const ConcreteSolver* apply(ptr_param> b, ptr_param> x) const { @@ -305,7 +263,10 @@ class EnableBatchSolver ConcreteSolver* apply(ptr_param> b, ptr_param> x) { - static_cast(this)->apply(b, x); + this->validate_application_parameters(b.get(), x.get()); + auto exec = this->get_executor(); + this->apply_impl(make_temporary_clone(exec, b).get(), + make_temporary_clone(exec, x).get()); return self(); } @@ -314,7 +275,13 @@ class EnableBatchSolver ptr_param> beta, ptr_param> x) { - static_cast(this)->apply(alpha, b, beta, x); + this->validate_application_parameters(alpha.get(), b.get(), beta.get(), + x.get()); + auto exec = this->get_executor(); + this->apply_impl(make_temporary_clone(exec, alpha).get(), + make_temporary_clone(exec, b).get(), + make_temporary_clone(exec, beta).get(), + make_temporary_clone(exec, x).get()); return self(); } diff --git a/include/ginkgo/core/solver/solver_base.hpp b/include/ginkgo/core/solver/solver_base.hpp index cd0043c7b44..070cc4e6b4a 100644 --- a/include/ginkgo/core/solver/solver_base.hpp +++ b/include/ginkgo/core/solver/solver_base.hpp @@ -856,7 +856,6 @@ class EnablePreconditionedIterativeSolver template struct enable_iterative_solver_factory_parameters : enable_parameters_type { - using parameters_type = Parameters; /** * Stopping criteria to be used by the solver. 
*/ @@ -868,8 +867,6 @@ struct enable_iterative_solver_factory_parameters template struct enable_preconditioned_iterative_solver_factory_parameters : enable_iterative_solver_factory_parameters { - using parameters_type = Parameters; - /** * The preconditioner to be used by the iterative solver. By default, no * preconditioner is used. diff --git a/include/ginkgo/core/stop/batch_stop_enum.hpp b/include/ginkgo/core/stop/batch_stop_enum.hpp index 1694dd164d9..3c463b8730c 100644 --- a/include/ginkgo/core/stop/batch_stop_enum.hpp +++ b/include/ginkgo/core/stop/batch_stop_enum.hpp @@ -48,7 +48,7 @@ namespace stop { * * With the `relative` tolerance type, the solver * convergence criteria checks against the relative residual norm - * ($||r|| \leq ||b|| \times \tau$, where $||b||$$ is the L2 norm of the rhs). + * ($||r|| \leq ||b|| \times \tau$, where $||b||$ is the L2 norm of the rhs). * * @note the computed residual norm, $||r||$ may be implicit or explicit * depending on the solver algorithm. diff --git a/reference/log/batch_logger.hpp b/reference/log/batch_logger.hpp index a70af0af51c..2598c23766f 100644 --- a/reference/log/batch_logger.hpp +++ b/reference/log/batch_logger.hpp @@ -51,8 +51,6 @@ namespace batch_log { template class SimpleFinalLogger final { public: - using real_type = remove_complex; - /** * Constructor * @@ -61,7 +59,7 @@ class SimpleFinalLogger final { * @param batch_iters final iteration counts for each * linear system in the batch. */ - SimpleFinalLogger(real_type* const batch_residuals, int* const batch_iters) + SimpleFinalLogger(RealType* const batch_residuals, int* const batch_iters) : final_residuals_{batch_residuals}, final_iters_{batch_iters} {} @@ -73,14 +71,14 @@ class SimpleFinalLogger final { * @param res_norm Norm of final residual norm */ void log_iteration(const size_type batch_idx, const int iter, - const real_type res_norm) + const RealType res_norm) { final_iters_[batch_idx] = iter; final_residuals_[batch_idx] = res_norm; } private: - real_type* const final_residuals_; + RealType* const final_residuals_; int* const final_iters_; }; diff --git a/reference/matrix/batch_struct.hpp b/reference/matrix/batch_struct.hpp index bb7680d1493..94beff5c2c2 100644 --- a/reference/matrix/batch_struct.hpp +++ b/reference/matrix/batch_struct.hpp @@ -95,7 +95,7 @@ inline batch::matrix::dense::uniform_batch get_batch_struct( * Generates an immutable uniform batch struct from a batch of ell matrices. */ template -inline batch::matrix::ell::uniform_batch +inline batch::matrix::ell::uniform_batch get_batch_struct(const batch::matrix::Ell* const op) { return {op->get_const_values(), diff --git a/reference/preconditioner/batch_identity.hpp b/reference/preconditioner/batch_identity.hpp index b0bf869c6be..6d6d462e660 100644 --- a/reference/preconditioner/batch_identity.hpp +++ b/reference/preconditioner/batch_identity.hpp @@ -34,6 +34,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#define GKO_REFERENCE_PRECONDITIONER_BATCH_IDENTITY_HPP_ +#include "core/base/batch_struct.hpp" #include "core/matrix/batch_struct.hpp" diff --git a/reference/test/solver/batch_bicgstab_kernels.cpp b/reference/test/solver/batch_bicgstab_kernels.cpp index c47c80e64dc..211318e8a8f 100644 --- a/reference/test/solver/batch_bicgstab_kernels.cpp +++ b/reference/test/solver/batch_bicgstab_kernels.cpp @@ -87,7 +87,7 @@ class BatchBicgstab : public ::testing::Test { std::shared_ptr exec; const real_type eps = 1e-3; const gko::size_type num_batch_items = 2; - const int num_rows = 3; + const int num_rows = 15; const int num_rhs = 1; const Settings solver_settings{100, eps, gko::batch::stop::tolerance_type::relative}; @@ -108,8 +108,8 @@ TYPED_TEST(BatchBicgstab, SolvesStencilSystem) this->linear_system); for (size_t i = 0; i < this->num_batch_items; i++) { - ASSERT_LE(res.res_norm->get_const_values()[i] / - this->linear_system.rhs_norm->get_const_values()[i], + ASSERT_LE(res.host_res_norm->get_const_values()[i] / + this->linear_system.host_rhs_norm->get_const_values()[i], this->solver_settings.residual_tol); } GKO_ASSERT_BATCH_MTX_NEAR(res.x, this->linear_system.exact_sol, @@ -130,9 +130,10 @@ TYPED_TEST(BatchBicgstab, StencilSystemLoggerLogsResidual) auto iter_array = res.log_data->iter_counts.get_const_data(); auto res_log_array = res.log_data->res_norms.get_const_data(); for (size_t i = 0; i < this->num_batch_items; i++) { - ASSERT_LE(res_log_array[i] / this->linear_system.rhs_norm->at(i, 0, 0), - this->solver_settings.residual_tol); - ASSERT_NEAR(res_log_array[i], res.res_norm->get_const_values()[i], + ASSERT_LE( + res_log_array[i] / this->linear_system.host_rhs_norm->at(i, 0, 0), + this->solver_settings.residual_tol); + ASSERT_NEAR(res_log_array[i], res.host_res_norm->get_const_values()[i], 10 * this->eps); } } @@ -186,8 +187,8 @@ TYPED_TEST(BatchBicgstab, CanSolveDenseSystem) GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, tol * 10); for (size_t i = 0; i < num_batch_items; i++) { - ASSERT_LE(res.res_norm->get_const_values()[i] / - linear_system.rhs_norm->get_const_values()[i], + ASSERT_LE(res.host_res_norm->get_const_values()[i] / + linear_system.host_rhs_norm->get_const_values()[i], tol); } } @@ -228,8 +229,8 @@ TYPED_TEST(BatchBicgstab, ApplyLogsResAndIters) auto res_norm = logger->get_residual_norm(); GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, tol * 50); for (size_t i = 0; i < num_batch_items; i++) { - auto rel_res_norm = res.res_norm->get_const_values()[i] / - linear_system.rhs_norm->get_const_values()[i]; + auto rel_res_norm = res.host_res_norm->get_const_values()[i] / + linear_system.host_rhs_norm->get_const_values()[i]; ASSERT_LE(iter_counts.get_const_data()[i], max_iters); EXPECT_LE(res_norm.get_const_data()[i], tol * 50); ASSERT_LE(rel_res_norm, tol * 50); @@ -266,8 +267,8 @@ TYPED_TEST(BatchBicgstab, CanSolveEllSystem) GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, tol * 10); for (size_t i = 0; i < num_batch_items; i++) { - ASSERT_LE(res.res_norm->get_const_values()[i] / - linear_system.rhs_norm->get_const_values()[i], + ASSERT_LE(res.host_res_norm->get_const_values()[i] / + linear_system.host_rhs_norm->get_const_values()[i], tol * 10); } } @@ -302,6 +303,6 @@ TYPED_TEST(BatchBicgstab, CanSolveDenseHpdSystem) GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, tol * 50); for (size_t i = 0; i < num_batch_items; i++) { - ASSERT_LE(res.res_norm->get_const_values()[i], tol * 50); + ASSERT_LE(res.host_res_norm->get_const_values()[i], tol * 50); } } 
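The factory-parameter rework in include/ginkgo/core/solver/batch_solver_base.hpp above replaces the hand-written with_preconditioner/with_generated_preconditioner setters with GKO_DEFERRED_FACTORY_PARAMETER and GKO_FACTORY_PARAMETER_SCALAR, so the fluent builder interface used in the tests keeps the same shape. A hedged usage sketch, assuming the concrete solver type name and the preconditioner factory handle for illustration (neither is taken from this patch):

#include <memory>

#include <ginkgo/ginkgo.hpp>

// Builds a batch BiCGSTAB solver through the macro-generated parameters;
// `precond_factory` stands in for any batch preconditioner factory (assumed).
std::unique_ptr<gko::batch::solver::Bicgstab<double>> make_solver(
    std::shared_ptr<const gko::Executor> exec,
    std::shared_ptr<gko::batch::BatchLinOp> system_matrix,
    std::shared_ptr<const gko::batch::BatchLinOpFactory> precond_factory)
{
    auto factory =
        gko::batch::solver::Bicgstab<double>::build()
            .with_max_iterations(100)
            .with_tolerance(1e-6)
            .with_tolerance_type(gko::batch::stop::tolerance_type::relative)
            // setter generated by GKO_DEFERRED_FACTORY_PARAMETER(preconditioner)
            .with_preconditioner(precond_factory)
            .on(exec);
    return factory->generate(system_matrix);
}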
diff --git a/test/base/batch_multi_vector_kernels.cpp b/test/base/batch_multi_vector_kernels.cpp index be625853656..6f4eb3d05a8 100644 --- a/test/base/batch_multi_vector_kernels.cpp +++ b/test/base/batch_multi_vector_kernels.cpp @@ -70,10 +70,9 @@ class MultiVector : public CommonTestFixture { std::normal_distribution<>(-1.0, 1.0), rand_engine, ref); } - void set_up_vector_data(gko::size_type num_vecs, + void set_up_vector_data(gko::size_type num_vecs, const int num_rows = 252, bool different_alpha = false) { - const int num_rows = 252; x = gen_mtx(batch_size, num_rows, num_vecs); y = gen_mtx(batch_size, num_rows, num_vecs); c_x = gen_mtx(batch_size, num_rows, num_vecs); @@ -143,7 +142,7 @@ TEST_F(MultiVector, MultipleVectorAddScaledIsEquivalentToRef) TEST_F(MultiVector, MultipleVectorAddScaledWithDifferentAlphaIsEquivalentToRef) { - set_up_vector_data(20, true); + set_up_vector_data(20, 252, true); x->add_scaled(alpha.get(), y.get()); dx->add_scaled(dalpha.get(), dy.get()); @@ -185,6 +184,21 @@ TEST_F(MultiVector, MultipleVectorScaleWithDifferentAlphaIsEquivalentToRef) } +TEST_F(MultiVector, ComputeNorm2SingleSmallIsEquivalentToRef) +{ + set_up_vector_data(1, 10); + auto norm_size = + gko::batch_dim<2>(batch_size, gko::dim<2>{1, x->get_common_size()[1]}); + auto norm_expected = NormVector::create(this->ref, norm_size); + auto dnorm = NormVector::create(this->exec, norm_size); + + x->compute_norm2(norm_expected.get()); + dx->compute_norm2(dnorm.get()); + + GKO_ASSERT_BATCH_MTX_NEAR(norm_expected, dnorm, 5 * r::value); +} + + TEST_F(MultiVector, ComputeNorm2SingleIsEquivalentToRef) { set_up_vector_data(1); @@ -250,6 +264,21 @@ TEST_F(MultiVector, ComputeDotSingleIsEquivalentToRef) } +TEST_F(MultiVector, ComputeDotSingleSmallIsEquivalentToRef) +{ + set_up_vector_data(1, 10); + auto dot_size = + gko::batch_dim<2>(batch_size, gko::dim<2>{1, x->get_common_size()[1]}); + auto dot_expected = Mtx::create(this->ref, dot_size); + auto ddot = Mtx::create(this->exec, dot_size); + + x->compute_dot(y.get(), dot_expected.get()); + dx->compute_dot(dy.get(), ddot.get()); + + GKO_ASSERT_BATCH_MTX_NEAR(dot_expected, ddot, 5 * r::value); +} + + TEST_F(MultiVector, ComputeConjDotIsEquivalentToRef) { set_up_vector_data(20); diff --git a/test/matrix/batch_dense_kernels.cpp b/test/matrix/batch_dense_kernels.cpp index a243d51f3c1..fa75a8f61e4 100644 --- a/test/matrix/batch_dense_kernels.cpp +++ b/test/matrix/batch_dense_kernels.cpp @@ -71,10 +71,9 @@ class Dense : public CommonTestFixture { std::normal_distribution<>(-1.0, 1.0), rand_engine, ref); } - void set_up_apply_data(gko::size_type num_vecs = 1) + void set_up_apply_data(gko::size_type num_rows, gko::size_type num_vecs = 1) { - const int num_rows = 252; - const int num_cols = 32; + const gko::size_type num_cols = 32; mat = gen_mtx(batch_size, num_rows, num_cols); y = gen_mtx(batch_size, num_cols, num_vecs); alpha = gen_mtx(batch_size, 1, 1); @@ -92,7 +91,7 @@ class Dense : public CommonTestFixture { std::default_random_engine rand_engine; - const size_t batch_size = 11; + const gko::size_type batch_size = 11; std::unique_ptr mat; std::unique_ptr y; std::unique_ptr alpha; @@ -106,9 +105,20 @@ class Dense : public CommonTestFixture { }; +TEST_F(Dense, SingleVectorApplyIsEquivalentToRefForSmallMatrices) +{ + set_up_apply_data(10); + + mat->apply(y.get(), expected.get()); + dmat->apply(dy.get(), dresult.get()); + + GKO_ASSERT_BATCH_MTX_NEAR(dresult, expected, r::value); +} + + TEST_F(Dense, SingleVectorApplyIsEquivalentToRef) { - set_up_apply_data(1); + 
set_up_apply_data(257); mat->apply(y.get(), expected.get()); dmat->apply(dy.get(), dresult.get()); @@ -119,7 +129,7 @@ TEST_F(Dense, SingleVectorApplyIsEquivalentToRef) TEST_F(Dense, SingleVectorAdvancedApplyIsEquivalentToRef) { - set_up_apply_data(1); + set_up_apply_data(257); mat->apply(alpha.get(), y.get(), beta.get(), expected.get()); dmat->apply(dalpha.get(), dy.get(), dbeta.get(), dresult.get()); diff --git a/test/matrix/batch_ell_kernels.cpp b/test/matrix/batch_ell_kernels.cpp index 572f47ba47d..7a4c6558c5d 100644 --- a/test/matrix/batch_ell_kernels.cpp +++ b/test/matrix/batch_ell_kernels.cpp @@ -87,8 +87,8 @@ class Ell : public CommonTestFixture { void set_up_apply_data(gko::size_type num_vecs = 1, int num_elems_per_row = 5) { - const int num_rows = 252; - const int num_cols = 32; + const gko::size_type num_rows = 252; + const gko::size_type num_cols = 32; GKO_ASSERT(num_elems_per_row <= num_cols); mat = gen_mtx(batch_size, num_rows, num_cols, num_elems_per_row); y = gen_mvec(batch_size, num_cols, num_vecs); @@ -107,7 +107,7 @@ class Ell : public CommonTestFixture { std::ranlux48 rand_engine; - const size_t batch_size = 11; + const gko::size_type batch_size = 11; std::unique_ptr mat; std::unique_ptr y; std::unique_ptr alpha; diff --git a/test/solver/CMakeLists.txt b/test/solver/CMakeLists.txt index 296a55b6271..00c78eb93a0 100644 --- a/test/solver/CMakeLists.txt +++ b/test/solver/CMakeLists.txt @@ -1,4 +1,4 @@ -ginkgo_create_common_test(batch_bicgstab_kernels DISABLE_EXECUTORS dpcpp cuda hip) +ginkgo_create_common_test(batch_bicgstab_kernels) ginkgo_create_common_test(bicg_kernels) ginkgo_create_common_test(bicgstab_kernels) ginkgo_create_common_test(cb_gmres_kernels) diff --git a/test/solver/batch_bicgstab_kernels.cpp b/test/solver/batch_bicgstab_kernels.cpp index adb68d92314..4bec19a165f 100644 --- a/test/solver/batch_bicgstab_kernels.cpp +++ b/test/solver/batch_bicgstab_kernels.cpp @@ -117,8 +117,8 @@ TEST_F(BatchBicgstab, SolvesStencilSystem) solver_settings, linear_system); for (size_t i = 0; i < num_batch_items; i++) { - ASSERT_LE(res.res_norm->get_const_values()[i] / - linear_system.rhs_norm->get_const_values()[i], + ASSERT_LE(res.host_res_norm->get_const_values()[i] / + linear_system.host_rhs_norm->get_const_values()[i], solver_settings.residual_tol); } GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, tol); @@ -141,9 +141,9 @@ TEST_F(BatchBicgstab, StencilSystemLoggerLogsResidual) auto res_log_array = res.log_data->res_norms.get_const_data(); for (size_t i = 0; i < num_batch_items; i++) { - ASSERT_LE(res_log_array[i] / linear_system.rhs_norm->at(i, 0, 0), + ASSERT_LE(res_log_array[i] / linear_system.host_rhs_norm->at(i, 0, 0), solver_settings.residual_tol); - ASSERT_NEAR(res_log_array[i], res.res_norm->get_const_values()[i], + ASSERT_NEAR(res_log_array[i], res.host_res_norm->get_const_values()[i], 10 * tol); } } @@ -171,7 +171,7 @@ TEST_F(BatchBicgstab, StencilSystemLoggerLogsIterations) TEST_F(BatchBicgstab, CanSolve3ptStencilSystem) { - const int num_batch_items = 12; + const int num_batch_items = 8; const int num_rows = 100; const int num_rhs = 1; const real_type tol = 1e-5; @@ -185,35 +185,59 @@ TEST_F(BatchBicgstab, CanSolve3ptStencilSystem) GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, tol * 10); for (size_t i = 0; i < num_batch_items; i++) { - auto comp_res_norm = - exec->copy_val_to_host(res.res_norm->get_const_values() + i) / - exec->copy_val_to_host(linear_system.rhs_norm->get_const_values() + - i); + auto comp_res_norm = 
res.host_res_norm->get_const_values()[i] / + linear_system.host_rhs_norm->get_const_values()[i]; ASSERT_LE(comp_res_norm, tol); } } -TEST_F(BatchBicgstab, CanSolveLargeHpdSystem) +TEST_F(BatchBicgstab, CanSolveLargeBatchSizeHpdSystem) { - const int num_batch_items = 3; + const int num_batch_items = 100; + const int num_rows = 102; + const int num_rhs = 1; + const real_type tol = 1e-5; + const int max_iters = num_rows * 2; + std::shared_ptr logger = Logger::create(); + auto mat = gko::share(gko::test::generate_diag_dominant_batch_matrix( + exec, num_batch_items, num_rows, true)); + auto linear_system = setup_linsys_and_solver(mat, num_rhs, tol, max_iters); + auto solver = gko::share(solver_factory->generate(linear_system.matrix)); + solver->add_logger(logger); + + auto res = gko::test::solve_linear_system(exec, linear_system, solver); + + solver->remove_logger(logger); + auto iter_counts = gko::make_temporary_clone(exec->get_master(), + &logger->get_num_iterations()); + auto res_norm = gko::make_temporary_clone(exec->get_master(), + &logger->get_residual_norm()); + GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, tol * 500); + for (size_t i = 0; i < num_batch_items; i++) { + auto comp_res_norm = res.host_res_norm->get_const_values()[i] / + linear_system.host_rhs_norm->get_const_values()[i]; + ASSERT_LE(iter_counts->get_const_data()[i], max_iters); + EXPECT_LE(res_norm->get_const_data()[i] / + linear_system.host_rhs_norm->get_const_values()[i], + tol); + EXPECT_GT(res_norm->get_const_data()[i], real_type{0.0}); + ASSERT_LE(comp_res_norm, tol * 10); + } +} + + +TEST_F(BatchBicgstab, CanSolveLargeMatrixSizeHpdSystem) +{ + const int num_batch_items = 12; const int num_rows = 1025; const int num_rhs = 1; const real_type tol = 1e-5; - const int max_iters = 2000; - const real_type comp_tol = tol * 100; - auto solver_factory = - solver_type::build() - .with_max_iterations(max_iters) - .with_tolerance(tol) - .with_tolerance_type(gko::batch::stop::tolerance_type::absolute) - .on(exec); + const int max_iters = num_rows * 2; std::shared_ptr logger = Logger::create(); - auto diag_dom_mat = - gko::share(gko::test::generate_diag_dominant_batch_matrix( - exec, num_batch_items, num_rows, true)); - auto linear_system = - gko::test::generate_batch_linear_system(diag_dom_mat, num_rhs); + auto mat = gko::share(gko::test::generate_diag_dominant_batch_matrix( + exec, num_batch_items, num_rows, true)); + auto linear_system = setup_linsys_and_solver(mat, num_rhs, tol, max_iters); auto solver = gko::share(solver_factory->generate(linear_system.matrix)); solver->add_logger(logger); @@ -224,13 +248,15 @@ TEST_F(BatchBicgstab, CanSolveLargeHpdSystem) &logger->get_num_iterations()); auto res_norm = gko::make_temporary_clone(exec->get_master(), &logger->get_residual_norm()); - GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, comp_tol); + GKO_ASSERT_BATCH_MTX_NEAR(res.x, linear_system.exact_sol, tol * 500); for (size_t i = 0; i < num_batch_items; i++) { - auto comp_res_norm = - exec->copy_val_to_host(res.res_norm->get_const_values() + i); + auto comp_res_norm = res.host_res_norm->get_const_values()[i] / + linear_system.host_rhs_norm->get_const_values()[i]; ASSERT_LE(iter_counts->get_const_data()[i], max_iters); - EXPECT_LE(res_norm->get_const_data()[i], comp_tol); + EXPECT_LE(res_norm->get_const_data()[i] / + linear_system.host_rhs_norm->get_const_values()[i], + tol); EXPECT_GT(res_norm->get_const_data()[i], real_type{0.0}); - ASSERT_LE(comp_res_norm, comp_tol); + ASSERT_LE(comp_res_norm, tol * 10); } }
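The updated assertions above all reduce to the same per-item check, following the relative tolerance convention documented in batch_stop_enum.hpp ($||r|| \leq ||b|| \times \tau$). A small sketch of that check, with a hypothetical helper name:

#include <cstddef>

// Relative convergence check per batch item, as exercised by the tests:
// rel_res_norm = host_res_norm[i] / host_rhs_norm[i], converged if <= tol.
template <typename RealType>
bool batch_item_converged(const RealType* host_res_norm,
                          const RealType* host_rhs_norm, std::size_t item,
                          RealType tol)
{
    const auto rel_res_norm = host_res_norm[item] / host_rhs_norm[item];
    return rel_res_norm <= tol;
}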