Review updates
Co-authored-by: Yu-Hsiang Tsai <[email protected]>
Co-authored-by: Marcel Koch <[email protected]>
3 people committed Oct 30, 2023
1 parent 6459e2f commit b653d3b
Showing 14 changed files with 177 additions and 183 deletions.
2 changes: 1 addition & 1 deletion common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc
@@ -104,7 +104,7 @@ __global__ __launch_bounds__(


template <typename Group, typename ValueType>
__device__ __forceinline__ void single_rhs_compute_dot(Group subgroup,
__device__ __forceinline__ void single_rhs_compute_conj_dot(Group subgroup,
const int num_rows,
const ValueType* x,
const ValueType* y,
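The rename from `single_rhs_compute_dot` to `single_rhs_compute_conj_dot` makes explicit that the reduction conjugates its first argument, which only matters for complex value types; for real types the two coincide. As a minimal host-side illustration (plain C++, not the device kernel; names here are chosen for the example):

```cpp
#include <complex>
#include <cstddef>
#include <vector>

// Host-side sketch of a conjugated dot product <x, y> = sum_i conj(x_i) * y_i.
// For real value types the conjugation is a no-op, so the rename is purely a
// clarification there; for complex types it matches the (r_hat)' * r style
// products used by the BiCGSTAB kernels further down in this commit.
std::complex<double> conj_dot(const std::vector<std::complex<double>>& x,
                              const std::vector<std::complex<double>>& y)
{
    std::complex<double> result{0.0, 0.0};
    for (std::size_t i = 0; i < x.size(); ++i) {
        result += std::conj(x[i]) * y[i];  // conjugate the first operand
    }
    return result;
}
```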
3 changes: 2 additions & 1 deletion common/cuda_hip/preconditioner/batch_identity.hpp.inc
@@ -47,7 +47,8 @@ public:

__device__ __forceinline__ void generate(
size_type,
const gko::batch::matrix::ell::batch_item<const ValueType, gko::int32>&,
const gko::batch::matrix::ell::batch_item<const ValueType,
const gko::int32>&,
ValueType*)
{}

26 changes: 15 additions & 11 deletions common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc
@@ -38,7 +38,8 @@ __device__ __forceinline__ void initialize(
const ValueType* const x_global_entry, ValueType& rho_old, ValueType& omega,
ValueType& alpha, ValueType* const x_shared_entry,
ValueType* const r_shared_entry, ValueType* const r_hat_shared_entry,
ValueType* const p_shared_entry, ValueType* const v_shared_entry,
ValueType* const p_shared_entry, ValueType* const p_hat_shared_entry,
ValueType* const v_shared_entry,
typename gko::remove_complex<ValueType>& rhs_norm,
typename gko::remove_complex<ValueType>& res_norm)
{
@@ -70,6 +71,7 @@ __device__ __forceinline__ void initialize(
for (int iz = threadIdx.x; iz < num_rows; iz += blockDim.x) {
r_hat_shared_entry[iz] = r_shared_entry[iz];
p_shared_entry[iz] = zero<ValueType>();
p_hat_shared_entry[iz] = zero<ValueType>();
v_shared_entry[iz] = zero<ValueType>();
}
}
@@ -82,8 +84,8 @@ __device__ __forceinline__ void update_p(
const ValueType* const r_shared_entry,
const ValueType* const v_shared_entry, ValueType* const p_shared_entry)
{
const ValueType beta = (rho_new / rho_old) * (alpha / omega);
for (int r = threadIdx.x; r < num_rows; r += blockDim.x) {
const ValueType beta = (rho_new / rho_old) * (alpha / omega);
p_shared_entry[r] =
r_shared_entry[r] +
beta * (p_shared_entry[r] - omega * v_shared_entry[r]);
@@ -97,8 +99,8 @@ __device__ __forceinline__ void compute_alpha(
const ValueType* const v_shared_entry, ValueType& alpha)
{
if (threadIdx.x / config::warp_size == 0) {
single_rhs_compute_dot(subgroup, num_rows, r_hat_shared_entry,
v_shared_entry, alpha);
single_rhs_compute_conj_dot(subgroup, num_rows, r_hat_shared_entry,
v_shared_entry, alpha);
}
__syncthreads();
if (threadIdx.x == 0) {
@@ -126,11 +128,11 @@ __device__ __forceinline__ void compute_omega(
const ValueType* const s_shared_entry, ValueType& temp, ValueType& omega)
{
if (threadIdx.x / config::warp_size == 0) {
single_rhs_compute_dot(subgroup, num_rows, t_shared_entry,
s_shared_entry, omega);
single_rhs_compute_conj_dot(subgroup, num_rows, t_shared_entry,
s_shared_entry, omega);
} else if (threadIdx.x / config::warp_size == 1) {
single_rhs_compute_dot(subgroup, num_rows, t_shared_entry,
t_shared_entry, temp);
single_rhs_compute_conj_dot(subgroup, num_rows, t_shared_entry,
t_shared_entry, temp);
}

__syncthreads();
@@ -278,10 +280,12 @@ __global__ void apply_kernel(
// compute residual norms
// r_hat = r
// p = 0
// p_hat = 0
// v = 0
initialize(subgroup, num_rows, mat_entry, b_entry_ptr, x_gl_entry_ptr,
rho_old_sh[0], omega_sh[0], alpha_sh[0], x_sh, r_sh,
r_hat_sh, p_sh, v_sh, norms_rhs_sh[0], norms_res_sh[0]);
r_hat_sh, p_sh, p_hat_sh, v_sh, norms_rhs_sh[0],
norms_res_sh[0]);
__syncthreads();

// stopping criterion object
@@ -296,8 +300,8 @@

// rho_new = < r_hat , r > = (r_hat)' * (r)
if (threadIdx.x / config::warp_size == 0) {
single_rhs_compute_dot(subgroup, num_rows, r_hat_sh, r_sh,
rho_new_sh[0]);
single_rhs_compute_conj_dot(subgroup, num_rows, r_hat_sh, r_sh,
rho_new_sh[0]);
}
__syncthreads();

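For orientation, the quantities updated in these hunks follow the standard BiCGSTAB recurrences; the conjugated dot products introduced above are the ⟨·,·⟩ = xᴴy terms below. This is a sketch for readers of the diff only — the divisions for α and ω happen in parts of `compute_alpha`/`compute_omega` that are not shown in the hunks:

```latex
\rho_{\mathrm{new}} = \langle \hat r, r \rangle, \qquad
\beta = \frac{\rho_{\mathrm{new}}}{\rho_{\mathrm{old}}} \cdot \frac{\alpha}{\omega}, \qquad
p \leftarrow r + \beta\,(p - \omega v), \qquad
\alpha = \frac{\rho_{\mathrm{new}}}{\langle \hat r, v \rangle}, \qquad
\omega = \frac{\langle t, s \rangle}{\langle t, t \rangle}
```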
18 changes: 9 additions & 9 deletions core/solver/batch_bicgstab_kernels.hpp
@@ -115,8 +115,7 @@ void set_gmem_stride_bytes(storage_config& sconf,
}
// align global memory chunks
sconf.gmem_stride_bytes =
gmem_stride > 0 ? ((gmem_stride - 1) / align_bytes + 1) * align_bytes
: 0;
gmem_stride > 0 ? ceildiv(gmem_stride, align_bytes) * align_bytes : 0;
}
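The `ceildiv` rewrite above is a pure simplification: for positive `gmem_stride`, `ceildiv(x, a) * a` rounds `x` up to the next multiple of `a`, exactly like the old `((x - 1) / a + 1) * a`. A tiny self-contained check (plain C++; the local `ceildiv` below uses the usual definition and is assumed to match `gko::ceildiv`):

```cpp
#include <cassert>

// Usual integer ceiling division; assumed to behave like gko::ceildiv here.
constexpr int ceildiv(int num, int den) { return (num - 1) / den + 1; }

int main()
{
    constexpr int align_bytes = 32;
    for (int stride = 1; stride <= 4096; ++stride) {
        const int old_way = ((stride - 1) / align_bytes + 1) * align_bytes;
        const int new_way = ceildiv(stride, align_bytes) * align_bytes;
        assert(old_way == new_way);          // same rounded-up value
        assert(new_way % align_bytes == 0);  // always an aligned multiple
    }
}
```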


@@ -143,8 +142,8 @@ void set_gmem_stride_bytes(storage_config& sconf,
* - rhs_norms
* - res_norms
*
* @param shared_mem_per_blk The amount of shared memory per block to use for
* keeping intermediate vectors. In case keeping the matrix in L1 cache etc.
* @param available_shared_mem The amount of shared memory per block to use
* for keeping intermediate vectors. In case keeping the matrix in L1 cache etc.
* should be prioritized, the cache configuration must be updated separately
* and the needed space should be subtracted before passing to this
* function.
@@ -154,7 +153,7 @@ void set_gmem_stride_bytes(storage_config& sconf,
* @return A struct containing allocation information specific to Bicgstab.
*/
template <typename Prectype, typename ValueType, int align_bytes = 32>
storage_config compute_shared_storage(const int shared_mem_per_blk,
storage_config compute_shared_storage(const int available_shared_mem,
const int num_rows, const int num_nz,
const int num_rhs)
{
@@ -163,10 +162,11 @@ storage_config compute_shared_storage(const int shared_mem_per_blk,
const int num_main_vecs = 9;
const int prec_storage =
Prectype::dynamic_work_size(num_rows, num_nz) * sizeof(ValueType);
int rem_shared = shared_mem_per_blk;
// Set default values. All vecs are in global.
int rem_shared = available_shared_mem;
// Set default values. Initially all vecs are in global memory.
// {prec_shared, n_shared, n_global, gmem_stride_bytes, padded_vec_len}
storage_config sconf{false, 0, num_main_vecs, 0, num_rows};
// If available shared mem, is zero, set all vecs to global.
// If available shared mem is zero, set all vecs to global.
if (rem_shared <= 0) {
set_gmem_stride_bytes<align_bytes>(sconf, vec_size, prec_storage);
return sconf;
@@ -177,13 +177,13 @@ storage_config compute_shared_storage(const int shared_mem_per_blk,
const int num_vecs_shared = min(initial_vecs_available, num_main_vecs);
sconf.n_shared += num_vecs_shared;
sconf.n_global -= num_vecs_shared;
rem_shared -= num_vecs_shared * vec_size;
// Set the storage configuration with preconditioner workspace in global if
// there are any vectors in global memory.
if (sconf.n_global > 0) {
set_gmem_stride_bytes<align_bytes>(sconf, vec_size, prec_storage);
return sconf;
}
rem_shared -= num_vecs_shared * vec_size;
// If more shared memory space is available and preconditioner workspace is
// needed, enable preconditioner workspace to use shared memory.
if (rem_shared >= prec_storage && prec_storage > 0) {
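For readers skimming the diff, here is a condensed paraphrase of the allocation policy `compute_shared_storage` implements above — a sketch only, with alignment and the gmem-stride bookkeeping omitted and simplified parameter names; it is not the Ginkgo implementation:

```cpp
struct storage_plan {
    bool prec_shared;  // preconditioner workspace placed in shared memory?
    int n_shared;      // main vectors kept in shared memory
    int n_global;      // main vectors left in global memory
};

// Greedy policy mirrored from the hunk above: fit as many of the 9 main
// vectors as possible into the available shared memory; move the
// preconditioner workspace to shared memory only if all main vectors already
// fit and enough space remains for it.
storage_plan plan_storage(int available_shared_mem, int vec_bytes,
                          int prec_storage_bytes)
{
    const int num_main_vecs = 9;
    storage_plan plan{false, 0, num_main_vecs};
    if (available_shared_mem <= 0) {
        return plan;  // everything stays in global memory
    }
    const int fitting = available_shared_mem / vec_bytes;
    plan.n_shared = fitting < num_main_vecs ? fitting : num_main_vecs;
    plan.n_global = num_main_vecs - plan.n_shared;
    if (plan.n_global > 0) {
        return plan;  // some vectors spill, so the preconditioner stays global
    }
    const int remaining = available_shared_mem - plan.n_shared * vec_bytes;
    if (prec_storage_bytes > 0 && remaining >= prec_storage_bytes) {
        plan.prec_shared = true;  // preconditioner workspace fits as well
    }
    return plan;
}
```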
3 changes: 2 additions & 1 deletion cuda/matrix/batch_struct.hpp
@@ -92,7 +92,8 @@ get_batch_struct(batch::matrix::Dense<ValueType>* const op)
* Generates an immutable uniform batch struct from a batch of ell matrices.
*/
template <typename ValueType, typename IndexType>
inline batch::matrix::ell::uniform_batch<const cuda_type<ValueType>, IndexType>
inline batch::matrix::ell::uniform_batch<const cuda_type<ValueType>,
const IndexType>
get_batch_struct(const batch::matrix::Ell<ValueType, IndexType>* const op)
{
return {as_cuda_type(op->get_const_values()),
5 changes: 1 addition & 4 deletions cuda/solver/batch_bicgstab_kernels.cu
@@ -101,10 +101,7 @@ int get_num_threads_per_block(std::shared_ptr<const DefaultExecutor> exec,
cudaDeviceGetAttribute(&max_regs_blk, cudaDevAttrMaxRegistersPerBlock,
exec->get_device_id());
const int max_threads_regs =
((max_regs_blk /
static_cast<int>((static_cast<double>(num_regs_used)))) /
warp_sz) *
warp_sz;
((max_regs_blk / static_cast<int>(num_regs_used)) / warp_sz) * warp_sz;
int max_threads = std::min(max_threads_regs, device_max_threads);
max_threads = max_threads <= 1024 ? max_threads : 1024;
return std::min(num_warps * warp_sz, max_threads);
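The change above only removes a redundant cast through `double`; the formula is unchanged: the register-limited thread count is `max_regs_blk / num_regs_used` rounded down to a multiple of `warp_sz`, then capped by the device thread limit and 1024. A worked example with assumed, illustrative numbers (the real code queries these via `cudaDeviceGetAttribute`):

```cpp
#include <algorithm>
#include <iostream>

int main()
{
    // Assumed, illustrative values only.
    const int max_regs_blk = 65536;       // registers available per block
    const int num_regs_used = 100;        // registers used per thread
    const int warp_sz = 32;
    const int device_max_threads = 2048;  // attribute from the device
    const int num_warps = 16;             // warps requested by the caller

    // 65536 / 100 = 655 threads by register pressure, rounded down to 640.
    const int max_threads_regs =
        ((max_regs_blk / num_regs_used) / warp_sz) * warp_sz;
    int max_threads = std::min(max_threads_regs, device_max_threads);
    max_threads = max_threads <= 1024 ? max_threads : 1024;  // hard cap

    std::cout << std::min(num_warps * warp_sz, max_threads) << "\n";  // 512
}
```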
137 changes: 65 additions & 72 deletions dpcpp/base/batch_multi_vector_kernels.dp.cpp
@@ -87,7 +87,7 @@ void scale(std::shared_ptr<const DefaultExecutor> exec,
long max_group_size =
device.get_info<sycl::info::device::max_work_group_size>();
int group_size =
std::max(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
max_group_size);

const dim3 block(group_size);
@@ -141,7 +141,7 @@ void add_scaled(std::shared_ptr<const DefaultExecutor> exec,
long max_group_size =
device.get_info<sycl::info::device::max_work_group_size>();
int group_size =
std::max(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
max_group_size);

const dim3 block(group_size);
@@ -202,49 +202,45 @@ void compute_dot(std::shared_ptr<const DefaultExecutor> exec,
long max_group_size =
device.get_info<sycl::info::device::max_work_group_size>();
int group_size =
std::max(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
max_group_size);

const dim3 block(group_size);
const dim3 grid(num_batches);
if (x->get_common_size()[1] == 1) {
exec->get_queue()->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl_nd_range(grid, block),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(max_subgroup_size)]] {
auto group = item_ct1.get_group();
auto group_id = group.get_group_linear_id();
const auto x_b =
batch::extract_batch_item(x_ub, group_id);
const auto y_b =
batch::extract_batch_item(y_ub, group_id);
const auto res_b =
batch::extract_batch_item(res_ub, group_id);
single_rhs_compute_dot_sg(x_b.num_rows, x_b.values,
y_b.values, res_b.values[0],
item_ct1);
});
sycl_nd_range(grid, block), [=
](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(
max_subgroup_size)]] {
auto group = item_ct1.get_group();
auto group_id = group.get_group_linear_id();
const auto x_b = batch::extract_batch_item(x_ub, group_id);
const auto y_b = batch::extract_batch_item(y_ub, group_id);
const auto res_b =
batch::extract_batch_item(res_ub, group_id);
single_rhs_compute_conj_dot_sg(x_b.num_rows, x_b.values,
y_b.values, res_b.values[0],
item_ct1);
});
});
} else {
// TODO: Remove reqd_sub_group size and use sycl::reduce_over_group
exec->get_queue()->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl_nd_range(grid, block),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(max_subgroup_size)]] {
auto group = item_ct1.get_group();
auto group_id = group.get_group_linear_id();
const auto x_b =
batch::extract_batch_item(x_ub, group_id);
const auto y_b =
batch::extract_batch_item(y_ub, group_id);
const auto res_b =
batch::extract_batch_item(res_ub, group_id);
compute_gen_dot_product_kernel(
x_b, y_b, res_b, item_ct1,
[](auto val) { return val; });
});
sycl_nd_range(grid, block), [=
](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(
max_subgroup_size)]] {
auto group = item_ct1.get_group();
auto group_id = group.get_group_linear_id();
const auto x_b = batch::extract_batch_item(x_ub, group_id);
const auto y_b = batch::extract_batch_item(y_ub, group_id);
const auto res_b =
batch::extract_batch_item(res_ub, group_id);
compute_gen_dot_product_kernel(
x_b, y_b, res_b, item_ct1,
[](auto val) { return val; });
});
});
}
}
@@ -270,27 +266,26 @@ void compute_conj_dot(std::shared_ptr<const DefaultExecutor> exec,
long max_group_size =
device.get_info<sycl::info::device::max_work_group_size>();
int group_size =
std::max(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
max_group_size);

const dim3 block(group_size);
const dim3 grid(num_batches);

exec->get_queue()->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl_nd_range(grid, block),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(max_subgroup_size)]] {
auto group = item_ct1.get_group();
auto group_id = group.get_group_linear_id();
const auto x_b = batch::extract_batch_item(x_ub, group_id);
const auto y_b = batch::extract_batch_item(y_ub, group_id);
const auto res_b =
batch::extract_batch_item(res_ub, group_id);
compute_gen_dot_product_kernel(
x_b, y_b, res_b, item_ct1,
[](auto val) { return conj(val); });
});
sycl_nd_range(grid, block), [=
](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(
max_subgroup_size)]] {
auto group = item_ct1.get_group();
auto group_id = group.get_group_linear_id();
const auto x_b = batch::extract_batch_item(x_ub, group_id);
const auto y_b = batch::extract_batch_item(y_ub, group_id);
const auto res_b = batch::extract_batch_item(res_ub, group_id);
compute_gen_dot_product_kernel(
x_b, y_b, res_b, item_ct1,
[](auto val) { return conj(val); });
});
});
}

@@ -314,41 +309,39 @@ void compute_norm2(std::shared_ptr<const DefaultExecutor> exec,
long max_group_size =
device.get_info<sycl::info::device::max_work_group_size>();
int group_size =
std::max(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
max_group_size);

const dim3 block(group_size);
const dim3 grid(num_batches);
if (x->get_common_size()[1] == 1) {
exec->get_queue()->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl_nd_range(grid, block),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(max_subgroup_size)]] {
auto group = item_ct1.get_group();
auto group_id = group.get_group_linear_id();
const auto x_b =
batch::extract_batch_item(x_ub, group_id);
const auto res_b =
batch::extract_batch_item(res_ub, group_id);
single_rhs_compute_norm2_sg(x_b.num_rows, x_b.values,
res_b.values[0], item_ct1);
});
sycl_nd_range(grid, block), [=
](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(
max_subgroup_size)]] {
auto group = item_ct1.get_group();
auto group_id = group.get_group_linear_id();
const auto x_b = batch::extract_batch_item(x_ub, group_id);
const auto res_b =
batch::extract_batch_item(res_ub, group_id);
single_rhs_compute_norm2_sg(x_b.num_rows, x_b.values,
res_b.values[0], item_ct1);
});
});
} else {
exec->get_queue()->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl_nd_range(grid, block),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(max_subgroup_size)]] {
auto group = item_ct1.get_group();
auto group_id = group.get_group_linear_id();
const auto x_b =
batch::extract_batch_item(x_ub, group_id);
const auto res_b =
batch::extract_batch_item(res_ub, group_id);
compute_norm2_kernel(x_b, res_b, item_ct1);
});
sycl_nd_range(grid, block), [=
](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(
max_subgroup_size)]] {
auto group = item_ct1.get_group();
auto group_id = group.get_group_linear_id();
const auto x_b = batch::extract_batch_item(x_ub, group_id);
const auto res_b =
batch::extract_batch_item(res_ub, group_id);
compute_norm2_kernel(x_b, res_b, item_ct1);
});
});
}
}
@@ -372,7 +365,7 @@ void copy(std::shared_ptr<const DefaultExecutor> exec,
long max_group_size =
device.get_info<sycl::info::device::max_work_group_size>();
int group_size =
std::max(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
max_group_size);

const dim3 block(group_size);
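The recurring change in this file replaces `std::max` with `std::min` when choosing the work-group size: the intent is `num_rows` rounded up to a multiple of the sub-group size, but never more than the device's `max_work_group_size`. With `std::max`, a large `num_rows` would request a group size above the device limit. A sketch of the corrected computation with assumed values (a 32-wide sub-group and a 1024 work-item limit are illustrative; the kernels query both from the SYCL device):

```cpp
#include <algorithm>

// Round num_rows up to a multiple of the sub-group size, but never exceed the
// device's maximum work-group size.
int pick_group_size(int num_rows, int max_subgroup_size, int max_group_size)
{
    const int rounded_up =
        ((num_rows + max_subgroup_size - 1) / max_subgroup_size) *
        max_subgroup_size;
    // With std::max this could exceed max_group_size for large num_rows,
    // which is what the fix above avoids.
    return std::min(rounded_up, max_group_size);
}

// e.g. pick_group_size(10000, 32, 1024) == 1024  (capped at the device limit)
//      pick_group_size(70,    32, 1024) == 96    (rounded up to the sub-group)
```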