Skip to content

Commit

Permalink
[batch] add launch bounds and fix register check
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcelKoch committed Oct 2, 2024
1 parent 85b80df commit 6cdfded
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 29 deletions.
19 changes: 13 additions & 6 deletions common/cuda_hip/solver/batch_bicgstab_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@
namespace gko {
namespace kernels {
namespace GKO_DEVICE_NAMESPACE {


constexpr int max_bicgstab_threads = 1024;


namespace batch_single_kernels {


Expand Down Expand Up @@ -170,12 +175,14 @@ __device__ __forceinline__ void update_x_middle(
template <typename StopType, int n_shared, bool prec_shared_bool,
typename PrecType, typename LogType, typename BatchMatrixType,
typename ValueType>
__global__ void apply_kernel(
const gko::kernels::batch_bicgstab::storage_config sconf,
const int max_iter, const gko::remove_complex<ValueType> tol,
LogType logger, PrecType prec_shared, const BatchMatrixType mat,
const ValueType* const __restrict__ b, ValueType* const __restrict__ x,
ValueType* const __restrict__ workspace = nullptr)
__global__ void __launch_bounds__(max_bicgstab_threads)
apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf,
const int max_iter, const gko::remove_complex<ValueType> tol,
LogType logger, PrecType prec_shared,
const BatchMatrixType mat,
const ValueType* const __restrict__ b,
ValueType* const __restrict__ x,
ValueType* const __restrict__ workspace = nullptr)
{
using real_type = typename gko::remove_complex<ValueType>;
const auto num_batch_items = mat.num_batch_items;
Expand Down
21 changes: 13 additions & 8 deletions common/cuda_hip/solver/batch_cg_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@
namespace gko {
namespace kernels {
namespace GKO_DEVICE_NAMESPACE {


constexpr int max_cg_threads = 1024;


namespace batch_single_kernels {


Expand Down Expand Up @@ -115,14 +120,14 @@ __device__ __forceinline__ void update_x_and_r(
template <typename StopType, const int n_shared, const bool prec_shared_bool,
typename PrecType, typename LogType, typename BatchMatrixType,
typename ValueType>
__global__ void apply_kernel(const gko::kernels::batch_cg::storage_config sconf,
const int max_iter,
const gko::remove_complex<ValueType> tol,
LogType logger, PrecType prec_shared,
const BatchMatrixType mat,
const ValueType* const __restrict__ b,
ValueType* const __restrict__ x,
ValueType* const __restrict__ workspace = nullptr)
__global__ void __launch_bounds__(max_cg_threads)
apply_kernel(const gko::kernels::batch_cg::storage_config sconf,
const int max_iter, const gko::remove_complex<ValueType> tol,
LogType logger, PrecType prec_shared,
const BatchMatrixType mat,
const ValueType* const __restrict__ b,
ValueType* const __restrict__ x,
ValueType* const __restrict__ workspace = nullptr)
{
using real_type = typename gko::remove_complex<ValueType>;
const auto num_batch_items = mat.num_batch_items;
Expand Down
25 changes: 17 additions & 8 deletions cuda/solver/batch_bicgstab_launch.instantiate.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,29 @@ int get_num_threads_per_block(std::shared_ptr<const DefaultExecutor> exec,
constexpr int warp_sz = static_cast<int>(config::warp_size);
const int min_block_size = 2 * warp_sz;
const int device_max_threads =
((std::max(num_rows, min_block_size)) / warp_sz) * warp_sz;
cudaFuncAttributes funcattr;
cudaFuncGetAttributes(
&funcattr,
batch_single_kernels::apply_kernel<StopType, 9, true, PrecType, LogType,
BatchMatrixType, ValueType>);
const int num_regs_used = funcattr.numRegs;
(std::max(num_rows, min_block_size) / warp_sz) * warp_sz;
auto get_num_regs = [](const auto func) {
cudaFuncAttributes funcattr;
cudaFuncGetAttributes(&funcattr, func);
return funcattr.numRegs;
};
const int num_regs_used = std::max(
get_num_regs(
batch_single_kernels::apply_kernel<StopType, 9, true, PrecType,
LogType, BatchMatrixType,
ValueType>),
get_num_regs(
batch_single_kernels::apply_kernel<StopType, 0, false, PrecType,
LogType, BatchMatrixType,
ValueType>));
int max_regs_blk = 0;
cudaDeviceGetAttribute(&max_regs_blk, cudaDevAttrMaxRegistersPerBlock,
exec->get_device_id());
const int max_threads_regs =
((max_regs_blk / static_cast<int>(num_regs_used)) / warp_sz) * warp_sz;
int max_threads = std::min(max_threads_regs, device_max_threads);
max_threads = max_threads <= 1024 ? max_threads : 1024;
max_threads = max_threads <= max_bicgstab_threads ? max_threads
: max_bicgstab_threads;
return std::max(std::min(num_warps * warp_sz, max_threads), min_block_size);
}

Expand Down
22 changes: 15 additions & 7 deletions cuda/solver/batch_cg_launch.instantiate.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,27 @@ int get_num_threads_per_block(std::shared_ptr<const DefaultExecutor> exec,
const int min_block_size = 2 * warp_sz;
const int device_max_threads =
(std::max(num_rows, min_block_size) / warp_sz) * warp_sz;
cudaFuncAttributes funcattr;
cudaFuncGetAttributes(
&funcattr,
batch_single_kernels::apply_kernel<StopType, 5, true, PrecType, LogType,
BatchMatrixType, ValueType>);
const int num_regs_used = funcattr.numRegs;
auto get_num_regs = [](const auto func) {
cudaFuncAttributes funcattr;
cudaFuncGetAttributes(&funcattr, func);
return funcattr.numRegs;
};
const int num_regs_used = std::max(
get_num_regs(
batch_single_kernels::apply_kernel<StopType, 5, true, PrecType,
LogType, BatchMatrixType,
ValueType>),
get_num_regs(
batch_single_kernels::apply_kernel<StopType, 0, false, PrecType,
LogType, BatchMatrixType,
ValueType>));
int max_regs_blk = 0;
cudaDeviceGetAttribute(&max_regs_blk, cudaDevAttrMaxRegistersPerBlock,
exec->get_device_id());
const int max_threads_regs =
((max_regs_blk / static_cast<int>(num_regs_used)) / warp_sz) * warp_sz;
int max_threads = std::min(max_threads_regs, device_max_threads);
max_threads = max_threads <= 1024 ? max_threads : 1024;
max_threads = max_threads <= max_cg_threads ? max_threads : max_cg_threads;
return std::max(std::min(num_warps * warp_sz, max_threads), min_block_size);
}

Expand Down

0 comments on commit 6cdfded

Please sign in to comment.