Unify and simplify batch functionality: Multivector #1651

Merged 6 commits on Aug 21, 2024
common/cuda_hip/CMakeLists.txt (1 addition, 0 deletions)
@@ -1,5 +1,6 @@
include(${PROJECT_SOURCE_DIR}/cmake/template_instantiation.cmake)
set(CUDA_HIP_SOURCES
+    base/batch_multi_vector_kernels.cpp
    base/device_matrix_data_kernels.cpp
    base/index_set_kernels.cpp
    components/prefix_sum_kernels.cpp
common/cuda_hip/base/batch_multi_vector_kernels.cpp
@@ -2,6 +2,32 @@
//
// SPDX-License-Identifier: BSD-3-Clause

#include "common/cuda_hip/base/batch_multi_vector_kernels.hpp"

#include <thrust/functional.h>
#include <thrust/transform.h>

#include <ginkgo/core/base/batch_multi_vector.hpp>
#include <ginkgo/core/base/exception_helpers.hpp>
#include <ginkgo/core/base/math.hpp>
#include <ginkgo/core/base/types.hpp>

#include "common/cuda_hip/base/config.hpp"
#include "common/cuda_hip/base/math.hpp"
#include "common/cuda_hip/base/runtime.hpp"
#include "core/base/batch_multi_vector_kernels.hpp"
#include "core/base/batch_struct.hpp"


namespace gko {
namespace kernels {
namespace GKO_DEVICE_NAMESPACE {
namespace batch_multi_vector {


constexpr auto default_block_size = 256;


template <typename ValueType>
void scale(std::shared_ptr<const DefaultExecutor> exec,
const batch::MultiVector<ValueType>* const alpha,
@@ -11,16 +37,19 @@ void scale(std::shared_ptr<const DefaultExecutor> exec,
    const auto alpha_ub = get_batch_struct(alpha);
    const auto x_ub = get_batch_struct(x);
    if (alpha->get_common_size()[1] == 1) {
-        scale_kernel<<<num_blocks, default_block_size, 0, exec->get_stream()>>>(
+        batch_single_kernels::scale_kernel<<<num_blocks, default_block_size, 0,
+                                             exec->get_stream()>>>(
            alpha_ub, x_ub,
            [] __device__(int row, int col, int stride) { return 0; });
    } else if (alpha->get_common_size() == x->get_common_size()) {
-        scale_kernel<<<num_blocks, default_block_size, 0, exec->get_stream()>>>(
+        batch_single_kernels::scale_kernel<<<num_blocks, default_block_size, 0,
+                                             exec->get_stream()>>>(
            alpha_ub, x_ub, [] __device__(int row, int col, int stride) {
                return row * stride + col;
            });
    } else {
-        scale_kernel<<<num_blocks, default_block_size, 0, exec->get_stream()>>>(
+        batch_single_kernels::scale_kernel<<<num_blocks, default_block_size, 0,
+                                             exec->get_stream()>>>(
            alpha_ub, x_ub,
            [] __device__(int row, int col, int stride) { return col; });
    }
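The dispatch above selects an index mapping for alpha based on its shape: a single scalar maps every entry to index 0, a full-size alpha maps to row * stride + col, and a one-row alpha maps to col. A minimal sketch of how a device-side loop might consume such a mapping (illustrative only; the helper name and parameters are assumptions, not the exact Ginkgo device code):

// Sketch: each thread scales a strided range of entries of one batch item,
// looking up its scaling factor through the mapping chosen by the host.
template <typename ValueType, typename Mapping>
__device__ void scale_item(int num_rows, int num_rhs, int stride,
                           const ValueType* alpha, ValueType* x, Mapping map)
{
    for (int li = threadIdx.x; li < num_rows * num_rhs; li += blockDim.x) {
        const int row = li / num_rhs;
        const int col = li % num_rhs;
        // map(row, col, stride) returns 0, row * stride + col, or col,
        // matching the three launch branches above.
        x[row * stride + col] = alpha[map(row, col, stride)] * x[row * stride + col];
    }
}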
@@ -42,12 +71,12 @@ void add_scaled(std::shared_ptr<const DefaultExecutor> exec,
    const auto x_ub = get_batch_struct(x);
    const auto y_ub = get_batch_struct(y);
    if (alpha->get_common_size()[1] == 1) {
-        add_scaled_kernel<<<num_blocks, default_block_size, 0,
-                            exec->get_stream()>>>(
+        batch_single_kernels::add_scaled_kernel<<<
+            num_blocks, default_block_size, 0, exec->get_stream()>>>(
            alpha_ub, x_ub, y_ub, [] __device__(int col) { return 0; });
    } else {
-        add_scaled_kernel<<<num_blocks, default_block_size, 0,
-                            exec->get_stream()>>>(
+        batch_single_kernels::add_scaled_kernel<<<
+            num_blocks, default_block_size, 0, exec->get_stream()>>>(
            alpha_ub, x_ub, y_ub, [] __device__(int col) { return col; });
    }
}
@@ -67,8 +96,8 @@ void compute_dot(std::shared_ptr<const DefaultExecutor> exec,
    const auto x_ub = get_batch_struct(x);
    const auto y_ub = get_batch_struct(y);
    const auto res_ub = get_batch_struct(result);
-    compute_gen_dot_product_kernel<<<num_blocks, default_block_size, 0,
-                                     exec->get_stream()>>>(
+    batch_single_kernels::compute_gen_dot_product_kernel<<<
+        num_blocks, default_block_size, 0, exec->get_stream()>>>(
        x_ub, y_ub, res_ub, [] __device__(auto val) { return val; });
}

@@ -87,8 +116,8 @@ void compute_conj_dot(std::shared_ptr<const DefaultExecutor> exec,
    const auto x_ub = get_batch_struct(x);
    const auto y_ub = get_batch_struct(y);
    const auto res_ub = get_batch_struct(result);
-    compute_gen_dot_product_kernel<<<num_blocks, default_block_size, 0,
-                                     exec->get_stream()>>>(
+    batch_single_kernels::compute_gen_dot_product_kernel<<<
+        num_blocks, default_block_size, 0, exec->get_stream()>>>(
        x_ub, y_ub, res_ub, [] __device__(auto val) { return conj(val); });
}
}

@@ -105,8 +134,9 @@ void compute_norm2(std::shared_ptr<const DefaultExecutor> exec,
    const auto num_rhs = x->get_common_size()[1];
    const auto x_ub = get_batch_struct(x);
    const auto res_ub = get_batch_struct(result);
-    compute_norm2_kernel<<<num_blocks, default_block_size, 0,
-                           exec->get_stream()>>>(x_ub, res_ub);
+    batch_single_kernels::compute_norm2_kernel<<<num_blocks, default_block_size,
+                                                 0, exec->get_stream()>>>(
+        x_ub, res_ub);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(
@@ -121,8 +151,15 @@ void copy(std::shared_ptr<const DefaultExecutor> exec,
    const auto num_blocks = x->get_num_batch_items();
    const auto result_ub = get_batch_struct(result);
    const auto x_ub = get_batch_struct(x);
-    copy_kernel<<<num_blocks, default_block_size, 0, exec->get_stream()>>>(
-        x_ub, result_ub);
+    batch_single_kernels::
+        copy_kernel<<<num_blocks, default_block_size, 0, exec->get_stream()>>>(
+            x_ub, result_ub);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR_COPY_KERNEL);


} // namespace batch_multi_vector
} // namespace GKO_DEVICE_NAMESPACE
} // namespace kernels
} // namespace gko
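This shared source is compiled once per backend, and GKO_DEVICE_NAMESPACE expands to the backend's namespace, so the same code lands in gko::kernels::cuda or gko::kernels::hip. A minimal sketch of that mechanism (illustrative; the actual definition lives in Ginkgo's backend headers, not in this PR):

// Sketch only: pick the device namespace from the compilation mode,
// mirroring the #if/#error guard used for the batch_struct includes below.
#if defined(GKO_COMPILING_CUDA)
#define GKO_DEVICE_NAMESPACE cuda
#elif defined(GKO_COMPILING_HIP)
#define GKO_DEVICE_NAMESPACE hip
#else
#error "GKO_DEVICE_NAMESPACE requires a device backend"
#endif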
common/cuda_hip/base/batch_multi_vector_kernels.hpp
@@ -2,6 +2,44 @@
//
// SPDX-License-Identifier: BSD-3-Clause

#include <thrust/functional.h>
#include <thrust/transform.h>

#include <ginkgo/core/base/batch_multi_vector.hpp>
#include <ginkgo/core/base/exception_helpers.hpp>
#include <ginkgo/core/base/math.hpp>
#include <ginkgo/core/base/types.hpp>

#include "common/cuda_hip/base/config.hpp"
#include "common/cuda_hip/base/math.hpp"
#include "common/cuda_hip/base/runtime.hpp"
#include "common/cuda_hip/base/thrust.hpp"
#include "common/cuda_hip/base/types.hpp"
#include "common/cuda_hip/components/cooperative_groups.hpp"
#include "common/cuda_hip/components/format_conversion.hpp"
#include "common/cuda_hip/components/reduction.hpp"
#include "common/cuda_hip/components/segment_scan.hpp"
#include "common/cuda_hip/components/thread_ids.hpp"
#include "common/cuda_hip/components/warp_blas.hpp"

#if defined(GKO_COMPILING_CUDA)
#include "cuda/base/batch_struct.hpp"
#elif defined(GKO_COMPILING_HIP)
#include "hip/base/batch_struct.hip.hpp"
#else
#error "batch struct def missing"
#endif


namespace gko {
namespace kernels {
namespace GKO_DEVICE_NAMESPACE {
namespace batch_single_kernels {


constexpr auto default_block_size = 256;


template <typename ValueType, typename Mapping>
__device__ __forceinline__ void scale(
const gko::batch::multi_vector::batch_item<const ValueType>& alpha,
@@ -20,8 +58,7 @@ __device__ __forceinline__ void scale(


template <typename ValueType, typename Mapping>
-__global__
-__launch_bounds__(default_block_size, sm_oversubscription) void scale_kernel(
+__global__ __launch_bounds__(default_block_size) void scale_kernel(
Review comment (Member):
You have used 4 for sm_oversubscription on both CUDA and HIP. I assume the CUDA value is the correct one and HIP just reuses it. If you want a more accurate mapping, HIP should use (min_blocks_per_multiprocessor (4) * max_threads_per_block (256)) / 32 = 32. You will need to distinguish the two with a macro.

Reply (Member, Author):

I haven't fully benchmarked that yet. I agree that this will have to be a macro specialized for CUDA and HIP, but that will be done in a future PR. It has already been noted in #1376.
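A sketch of the backend-specific constant suggested here (hypothetical; the HIP value follows the reviewer's arithmetic and is not part of this PR):

// Illustrative only: per-backend oversubscription factor that could be
// passed as the second argument of __launch_bounds__.
#if defined(GKO_COMPILING_CUDA)
constexpr int sm_oversubscription = 4;
#elif defined(GKO_COMPILING_HIP)
// (min_blocks_per_multiprocessor (4) * max_threads_per_block (256)) / 32
constexpr int sm_oversubscription = 32;
#endif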

+    const gko::batch::multi_vector::uniform_batch<const ValueType> alpha,
+    const gko::batch::multi_vector::uniform_batch<ValueType> x, Mapping map)
{
@@ -52,20 +89,10 @@ __device__ __forceinline__ void add_scaled(


template <typename ValueType, typename Mapping>
-__global__ __launch_bounds__(
-    default_block_size,
-    sm_oversubscription) void add_scaled_kernel(const gko::batch::multi_vector::
-                                                    uniform_batch<const ValueType>
-                                                    alpha,
-                                                const gko::batch::multi_vector::
-                                                    uniform_batch<const ValueType>
-                                                    x,
-                                                const gko::batch::multi_vector::
-                                                    uniform_batch<ValueType>
-                                                    y,
-                                                Mapping map)
+__global__ __launch_bounds__(default_block_size) void add_scaled_kernel(
+    const gko::batch::multi_vector::uniform_batch<const ValueType> alpha,
+    const gko::batch::multi_vector::uniform_batch<const ValueType> x,
+    const gko::batch::multi_vector::uniform_batch<ValueType> y, Mapping map)
{
for (size_type batch_id = blockIdx.x; batch_id < x.num_batch_items;
batch_id += gridDim.x) {
@@ -145,7 +172,7 @@ __device__ __forceinline__ void compute_gen_dot_product(

template <typename ValueType, typename Mapping>
__global__
-__launch_bounds__(default_block_size, sm_oversubscription) void compute_gen_dot_product_kernel(
+__launch_bounds__(default_block_size) void compute_gen_dot_product_kernel(
const gko::batch::multi_vector::uniform_batch<const ValueType> x,
const gko::batch::multi_vector::uniform_batch<const ValueType> y,
const gko::batch::multi_vector::uniform_batch<ValueType> result,
@@ -232,19 +259,10 @@ __device__ __forceinline__ void compute_norm2(


template <typename ValueType>
-__global__ __launch_bounds__(
-    default_block_size,
-    sm_oversubscription) void compute_norm2_kernel(const gko::batch::
-                                                       multi_vector::uniform_batch<
-                                                           const ValueType>
-                                                       x,
-                                                   const gko::batch::
-                                                       multi_vector::uniform_batch<
-                                                           remove_complex<ValueType>>
-                                                       result)
+__global__ __launch_bounds__(default_block_size) void compute_norm2_kernel(
+    const gko::batch::multi_vector::uniform_batch<const ValueType> x,
+    const gko::batch::multi_vector::uniform_batch<remove_complex<ValueType>>
+        result)
{
for (size_type batch_id = blockIdx.x; batch_id < x.num_batch_items;
batch_id += gridDim.x) {
@@ -287,8 +305,7 @@ __device__ __forceinline__ void copy(


template <typename ValueType>
-__global__
-__launch_bounds__(default_block_size, sm_oversubscription) void copy_kernel(
+__global__ __launch_bounds__(default_block_size) void copy_kernel(
const gko::batch::multi_vector::uniform_batch<const ValueType> src,
const gko::batch::multi_vector::uniform_batch<ValueType> dst)
{
@@ -299,3 +316,9 @@ __launch_bounds__(default_block_size, sm_oversubscription) void copy_kernel(
copy(src_b, dst_b);
}
}


} // namespace batch_single_kernels
} // namespace GKO_DEVICE_NAMESPACE
} // namespace kernels
} // namespace gko
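Every kernel in this header shares the same outer skeleton: a grid-stride loop over batch items, with one thread block working on one item at a time, so any grid size covers any batch count. A condensed sketch of the pattern (the extract_batch_item helper and the per-item body are simplified placeholders):

// Sketch of the shared kernel skeleton: blocks wrap around the batch.
template <typename ValueType>
__global__ __launch_bounds__(default_block_size) void example_kernel(
    const gko::batch::multi_vector::uniform_batch<ValueType> x)
{
    for (gko::size_type batch_id = blockIdx.x; batch_id < x.num_batch_items;
         batch_id += gridDim.x) {
        // Extract the views of one batch entry; the threads of this block
        // then cooperate on that single item.
        const auto x_b = gko::batch::extract_batch_item(x, batch_id);
        // ... per-item work distributed over threadIdx.x ...
    }
}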
common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc (26 additions, 15 deletions)
@@ -32,10 +32,14 @@ __device__ __forceinline__ void initialize(
__syncthreads();

    if (threadIdx.x / config::warp_size == 0) {
-        single_rhs_compute_norm2(subgroup, num_rows, r_shared_entry, res_norm);
+        gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
+            single_rhs_compute_norm2(subgroup, num_rows, r_shared_entry,
+                                     res_norm);
    } else if (threadIdx.x / config::warp_size == 1) {
        // Compute norms of rhs
-        single_rhs_compute_norm2(subgroup, num_rows, b_global_entry, rhs_norm);
+        gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
+            single_rhs_compute_norm2(subgroup, num_rows, b_global_entry,
+                                     rhs_norm);
    }
    __syncthreads();

@@ -70,8 +74,9 @@ __device__ __forceinline__ void compute_alpha(
const ValueType* const v_shared_entry, ValueType& alpha)
{
    if (threadIdx.x / config::warp_size == 0) {
-        single_rhs_compute_conj_dot(subgroup, num_rows, r_hat_shared_entry,
-                                    v_shared_entry, alpha);
+        gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
+            single_rhs_compute_conj_dot(subgroup, num_rows, r_hat_shared_entry,
+                                        v_shared_entry, alpha);
    }
__syncthreads();
if (threadIdx.x == 0) {
@@ -99,11 +104,13 @@ __device__ __forceinline__ void compute_omega(
const ValueType* const s_shared_entry, ValueType& temp, ValueType& omega)
{
    if (threadIdx.x / config::warp_size == 0) {
-        single_rhs_compute_conj_dot(subgroup, num_rows, t_shared_entry,
-                                    s_shared_entry, omega);
+        gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
+            single_rhs_compute_conj_dot(subgroup, num_rows, t_shared_entry,
+                                        s_shared_entry, omega);
    } else if (threadIdx.x / config::warp_size == 1) {
-        single_rhs_compute_conj_dot(subgroup, num_rows, t_shared_entry,
-                                    t_shared_entry, temp);
+        gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
+            single_rhs_compute_conj_dot(subgroup, num_rows, t_shared_entry,
+                                        t_shared_entry, temp);
    }

__syncthreads();
@@ -271,8 +278,9 @@ __global__ void apply_kernel(

// rho_new = < r_hat , r > = (r_hat)' * (r)
        if (threadIdx.x / config::warp_size == 0) {
-            single_rhs_compute_conj_dot(subgroup, num_rows, r_hat_sh, r_sh,
-                                        rho_new_sh[0]);
+            gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
+                single_rhs_compute_conj_dot(subgroup, num_rows, r_hat_sh,
+                                            r_sh, rho_new_sh[0]);
        }
__syncthreads();

@@ -301,8 +309,9 @@ __global__ void apply_kernel(

// an estimate of residual norms
        if (threadIdx.x / config::warp_size == 0) {
-            single_rhs_compute_norm2(subgroup, num_rows, s_sh,
-                                     norms_res_sh[0]);
+            gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
+                single_rhs_compute_norm2(subgroup, num_rows, s_sh,
+                                         norms_res_sh[0]);
        }
__syncthreads();

@@ -333,8 +342,9 @@ __global__ void apply_kernel(
__syncthreads();

        if (threadIdx.x / config::warp_size == 0) {
-            single_rhs_compute_norm2(subgroup, num_rows, r_sh,
-                                     norms_res_sh[0]);
+            gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
+                single_rhs_compute_norm2(subgroup, num_rows, r_sh,
+                                         norms_res_sh[0]);
        }
//__syncthreads();

@@ -347,7 +357,8 @@ __global__ void apply_kernel(
logger.log_iteration(batch_id, iter, norms_res_sh[0]);

// copy x back to global memory
-        single_rhs_copy(num_rows, x_sh, x_gl_entry_ptr);
+        gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
+            single_rhs_copy(num_rows, x_sh, x_gl_entry_ptr);
__syncthreads();
}
}
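The fully qualified gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels:: prefix is repeated at every call site above. A namespace alias could shorten these calls; a possible form (an assumption, not something this PR introduces):

// Hypothetical shorthand for the repeated qualification above.
namespace batch_kernels = gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels;

// ... which would reduce a call site to:
// batch_kernels::single_rhs_compute_norm2(subgroup, num_rows, r_sh,
//                                         norms_res_sh[0]);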