Skip to content

Commit

Permalink
[dpcpp] move to proper headers
Browse files Browse the repository at this point in the history
  • Loading branch information
pratikvn committed Aug 20, 2024
1 parent b35d079 commit ad5f7cd
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 63 deletions.
64 changes: 30 additions & 34 deletions dpcpp/base/batch_multi_vector_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "core/base/batch_struct.hpp"
#include "core/components/prefix_sum_kernels.hpp"
#include "dpcpp/base/batch_multi_vector_kernels.hpp"
#include "dpcpp/base/batch_struct.hpp"
#include "dpcpp/base/config.hpp"
#include "dpcpp/base/dim3.dp.hpp"
Expand All @@ -29,17 +30,9 @@
namespace gko {
namespace kernels {
namespace dpcpp {
/**
* @brief The MultiVector matrix format namespace.
* @ref MultiVector
* @ingroup batch_multi_vector
*/
namespace batch_multi_vector {


#include "dpcpp/base/batch_multi_vector_kernels.hpp.inc"


template <typename ValueType>
void scale(std::shared_ptr<const DefaultExecutor> exec,
const batch::MultiVector<ValueType>* const alpha,
Expand Down Expand Up @@ -71,7 +64,7 @@ void scale(std::shared_ptr<const DefaultExecutor> exec,
const auto alpha_b =
batch::extract_batch_item(alpha_ub, group_id);
const auto x_b = batch::extract_batch_item(x_ub, group_id);
scale_kernel(
batch_single_kernels::scale_kernel(
alpha_b, x_b, item_ct1,
[](int row, int col, int stride) { return 0; });
});
Expand All @@ -85,10 +78,11 @@ void scale(std::shared_ptr<const DefaultExecutor> exec,
const auto alpha_b =
batch::extract_batch_item(alpha_ub, group_id);
const auto x_b = batch::extract_batch_item(x_ub, group_id);
scale_kernel(alpha_b, x_b, item_ct1,
[](int row, int col, int stride) {
return row * stride + col;
});
batch_single_kernels::scale_kernel(
alpha_b, x_b, item_ct1,
[](int row, int col, int stride) {
return row * stride + col;
});
});
});
} else {
Expand All @@ -100,7 +94,7 @@ void scale(std::shared_ptr<const DefaultExecutor> exec,
const auto alpha_b =
batch::extract_batch_item(alpha_ub, group_id);
const auto x_b = batch::extract_batch_item(x_ub, group_id);
scale_kernel(
batch_single_kernels::scale_kernel(
alpha_b, x_b, item_ct1,
[](int row, int col, int stride) { return col; });
});
Expand Down Expand Up @@ -144,8 +138,9 @@ void add_scaled(std::shared_ptr<const DefaultExecutor> exec,
batch::extract_batch_item(alpha_ub, group_id);
const auto x_b = batch::extract_batch_item(x_ub, group_id);
const auto y_b = batch::extract_batch_item(y_ub, group_id);
add_scaled_kernel(alpha_b, x_b, y_b, item_ct1,
[](auto col) { return 0; });
batch_single_kernels::add_scaled_kernel(
alpha_b, x_b, y_b, item_ct1,
[](auto col) { return 0; });
});
});
} else {
Expand All @@ -158,8 +153,9 @@ void add_scaled(std::shared_ptr<const DefaultExecutor> exec,
batch::extract_batch_item(alpha_ub, group_id);
const auto x_b = batch::extract_batch_item(x_ub, group_id);
const auto y_b = batch::extract_batch_item(y_ub, group_id);
add_scaled_kernel(alpha_b, x_b, y_b, item_ct1,
[](auto col) { return col; });
batch_single_kernels::add_scaled_kernel(
alpha_b, x_b, y_b, item_ct1,
[](auto col) { return col; });
});
});
}
Expand Down Expand Up @@ -206,7 +202,7 @@ void compute_dot(std::shared_ptr<const DefaultExecutor> exec,
batch::extract_batch_item(y_ub, group_id);
const auto res_b =
batch::extract_batch_item(res_ub, group_id);
single_rhs_compute_conj_dot_sg(
batch_single_kernels::single_rhs_compute_conj_dot_sg(
x_b.num_rows, x_b.values, y_b.values,
res_b.values[0], item_ct1);
});
Expand All @@ -226,7 +222,7 @@ void compute_dot(std::shared_ptr<const DefaultExecutor> exec,
batch::extract_batch_item(y_ub, group_id);
const auto res_b =
batch::extract_batch_item(res_ub, group_id);
compute_gen_dot_product_kernel(
batch_single_kernels::compute_gen_dot_product_kernel(
x_b, y_b, res_b, item_ct1,
[](auto val) { return val; });
});
Expand Down Expand Up @@ -272,7 +268,7 @@ void compute_conj_dot(std::shared_ptr<const DefaultExecutor> exec,
const auto y_b = batch::extract_batch_item(y_ub, group_id);
const auto res_b =
batch::extract_batch_item(res_ub, group_id);
compute_gen_dot_product_kernel(
batch_single_kernels::compute_gen_dot_product_kernel(
x_b, y_b, res_b, item_ct1,
[](auto val) { return conj(val); });
});
Expand Down Expand Up @@ -308,17 +304,16 @@ void compute_norm2(std::shared_ptr<const DefaultExecutor> exec,
exec->get_queue()->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl_nd_range(grid, block),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(max_subgroup_size)]] {
auto group = item_ct1.get_group();
auto group_id = group.get_group_linear_id();
const auto x_b =
batch::extract_batch_item(x_ub, group_id);
const auto res_b =
batch::extract_batch_item(res_ub, group_id);
single_rhs_compute_norm2_sg(x_b.num_rows, x_b.values,
res_b.values[0], item_ct1);
});
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(
max_subgroup_size)]] {
auto group = item_ct1.get_group();
auto group_id = group.get_group_linear_id();
const auto x_b = batch::extract_batch_item(x_ub, group_id);
const auto res_b =
batch::extract_batch_item(res_ub, group_id);
batch_single_kernels::single_rhs_compute_norm2_sg(
x_b.num_rows, x_b.values, res_b.values[0], item_ct1);
});
});
} else {
exec->get_queue()->submit([&](sycl::handler& cgh) {
Expand All @@ -332,7 +327,8 @@ void compute_norm2(std::shared_ptr<const DefaultExecutor> exec,
batch::extract_batch_item(x_ub, group_id);
const auto res_b =
batch::extract_batch_item(res_ub, group_id);
compute_norm2_kernel(x_b, res_b, item_ct1);
batch_single_kernels::compute_norm2_kernel(x_b, res_b,
item_ct1);
});
});
}
Expand Down Expand Up @@ -371,7 +367,7 @@ void copy(std::shared_ptr<const DefaultExecutor> exec,
const auto x_b = batch::extract_batch_item(x_ub, group_id);
const auto result_b =
batch::extract_batch_item(result_ub, group_id);
copy_kernel(x_b, result_b, item_ct1);
batch_single_kernels::copy_kernel(x_b, result_b, item_ct1);
});
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,29 @@
//
// SPDX-License-Identifier: BSD-3-Clause


#include <memory>

#include <CL/sycl.hpp>

#include "core/base/batch_struct.hpp"
#include "dpcpp/base/batch_struct.hpp"
#include "dpcpp/base/config.hpp"
#include "dpcpp/base/dim3.dp.hpp"
#include "dpcpp/base/dpct.hpp"
#include "dpcpp/base/helper.hpp"
#include "dpcpp/components/cooperative_groups.dp.hpp"
#include "dpcpp/components/intrinsics.dp.hpp"
#include "dpcpp/components/reduction.dp.hpp"
#include "dpcpp/components/thread_ids.dp.hpp"


namespace gko {
namespace kernels {
namespace GKO_DEVICE_NAMESPACE {
namespace batch_single_kernels {


template <typename ValueType, typename Mapping>
__dpct_inline__ void scale_kernel(
const gko::batch::multi_vector::batch_item<const ValueType>& alpha,
Expand Down Expand Up @@ -229,3 +252,9 @@ __dpct_inline__ void copy_kernel(
out.values[i * out.stride + j] = in.values[i * in.stride + j];
}
}


} // namespace batch_single_kernels
} // namespace GKO_DEVICE_NAMESPACE
} // namespace kernels
} // namespace gko
2 changes: 1 addition & 1 deletion dpcpp/solver/batch_bicgstab_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "core/base/batch_struct.hpp"
#include "core/matrix/batch_struct.hpp"
#include "core/solver/batch_dispatch.hpp"
#include "dpcpp/base/batch_multi_vector_kernels.hpp"
#include "dpcpp/base/batch_struct.hpp"
#include "dpcpp/base/config.hpp"
#include "dpcpp/base/dim3.dp.hpp"
Expand All @@ -36,7 +37,6 @@ namespace dpcpp {
namespace batch_bicgstab {


#include "dpcpp/base/batch_multi_vector_kernels.hpp.inc"
#include "dpcpp/matrix/batch_csr_kernels.hpp.inc"
#include "dpcpp/matrix/batch_dense_kernels.hpp.inc"
#include "dpcpp/matrix/batch_ell_kernels.hpp.inc"
Expand Down
43 changes: 26 additions & 17 deletions dpcpp/solver/batch_bicgstab_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,13 @@ __dpct_inline__ void initialize(
item_ct1.barrier(sycl::access::fence_space::global_and_local);

if (sg_id == 0) {
single_rhs_compute_norm2_sg(num_rows, r_shared_entry, res_norm,
item_ct1);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
single_rhs_compute_norm2_sg(num_rows, r_shared_entry, res_norm,
item_ct1);
} else if (sg_id == 1) {
single_rhs_compute_norm2_sg(num_rows, b_global_entry, rhs_norm,
item_ct1);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
single_rhs_compute_norm2_sg(num_rows, b_global_entry, rhs_norm,
item_ct1);
}
item_ct1.barrier(sycl::access::fence_space::global_and_local);

Expand Down Expand Up @@ -86,8 +88,9 @@ __dpct_inline__ void compute_alpha(const int num_rows, const ValueType& rho_new,
const auto sg_id = sg.get_group_id();
const auto tid = item_ct1.get_local_linear_id();
if (sg_id == 0) {
single_rhs_compute_conj_dot_sg(num_rows, r_hat_shared_entry,
v_shared_entry, alpha, item_ct1);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
single_rhs_compute_conj_dot_sg(num_rows, r_hat_shared_entry,
v_shared_entry, alpha, item_ct1);
}
item_ct1.barrier(sycl::access::fence_space::global_and_local);
if (tid == 0) {
Expand Down Expand Up @@ -123,11 +126,13 @@ __dpct_inline__ void compute_omega(const int num_rows,
const auto sg_id = sg.get_group_id();
const auto tid = item_ct1.get_local_linear_id();
if (sg_id == 0) {
single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry, s_shared_entry,
omega, item_ct1);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry,
s_shared_entry, omega, item_ct1);
} else if (sg_id == 1) {
single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry, t_shared_entry,
temp, item_ct1);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry,
t_shared_entry, temp, item_ct1);
}
item_ct1.barrier(sycl::access::fence_space::global_and_local);
if (tid == 0) {
Expand Down Expand Up @@ -308,8 +313,9 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf,

// rho_new = < r_hat , r > = (r_hat)' * (r)
if (sg_id == 0) {
single_rhs_compute_conj_dot_sg(num_rows, r_hat_sh, r_sh,
rho_new_sh[0], item_ct1);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
single_rhs_compute_conj_dot_sg(num_rows, r_hat_sh, r_sh,
rho_new_sh[0], item_ct1);
}
item_ct1.barrier(sycl::access::fence_space::global_and_local);

Expand Down Expand Up @@ -338,8 +344,9 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf,

// an estimate of residual norms
if (sg_id == 0) {
single_rhs_compute_norm2_sg(num_rows, s_sh, norms_res_sh[0],
item_ct1);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
single_rhs_compute_norm2_sg(num_rows, s_sh, norms_res_sh[0],
item_ct1);
}
item_ct1.barrier(sycl::access::fence_space::global_and_local);

Expand Down Expand Up @@ -368,8 +375,9 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf,
item_ct1.barrier(sycl::access::fence_space::global_and_local);

if (sg_id == 0)
single_rhs_compute_norm2_sg(num_rows, r_sh, norms_res_sh[0],
item_ct1);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
single_rhs_compute_norm2_sg(num_rows, r_sh, norms_res_sh[0],
item_ct1);
if (tid == group_size - 1) {
rho_old_sh[0] = rho_new_sh[0];
}
Expand All @@ -379,6 +387,7 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf,
logger.log_iteration(batch_id, iter, norms_res_sh[0]);

// copy x back to global memory
copy_kernel(num_rows, x_sh, x_global_entry, item_ct1);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::copy_kernel(
num_rows, x_sh, x_global_entry, item_ct1);
item_ct1.barrier(sycl::access::fence_space::global_and_local);
}
2 changes: 1 addition & 1 deletion dpcpp/solver/batch_cg_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "core/base/batch_struct.hpp"
#include "core/matrix/batch_struct.hpp"
#include "core/solver/batch_dispatch.hpp"
#include "dpcpp/base/batch_multi_vector_kernels.hpp"
#include "dpcpp/base/batch_struct.hpp"
#include "dpcpp/base/config.hpp"
#include "dpcpp/base/dim3.dp.hpp"
Expand All @@ -36,7 +37,6 @@ namespace dpcpp {
namespace batch_cg {


#include "dpcpp/base/batch_multi_vector_kernels.hpp.inc"
#include "dpcpp/matrix/batch_csr_kernels.hpp.inc"
#include "dpcpp/matrix/batch_dense_kernels.hpp.inc"
#include "dpcpp/matrix/batch_ell_kernels.hpp.inc"
Expand Down
25 changes: 15 additions & 10 deletions dpcpp/solver/batch_cg_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,13 @@ __dpct_inline__ void initialize(
// Compute norms of rhs
// and rho_old = r' * z
if (sg_id == 0) {
single_rhs_compute_norm2_sg(num_rows, b_global_entry, rhs_norms,
item_ct1);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
single_rhs_compute_norm2_sg(num_rows, b_global_entry, rhs_norms,
item_ct1);
} else if (sg_id == 1) {
single_rhs_compute_conj_dot_sg(num_rows, r_shared_entry, z_shared_entry,
rho_old, item_ct1);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
single_rhs_compute_conj_dot_sg(num_rows, r_shared_entry,
z_shared_entry, rho_old, item_ct1);
}
item_ct1.barrier(sycl::access::fence_space::global_and_local);

Expand Down Expand Up @@ -80,9 +82,10 @@ __dpct_inline__ void update_x_and_r(
auto sg = item_ct1.get_sub_group();
const auto tid = item_ct1.get_local_linear_id();
if (sg.get_group_id() == 0) {
single_rhs_compute_conj_dot_sg(num_rows, p_shared_entry,
Ap_shared_entry, alpha_shared_entry,
item_ct1);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
single_rhs_compute_conj_dot_sg(num_rows, p_shared_entry,
Ap_shared_entry, alpha_shared_entry,
item_ct1);
}
item_ct1.barrier(sycl::access::fence_space::global_and_local);
if (tid == 0) {
Expand Down Expand Up @@ -221,8 +224,9 @@ __dpct_inline__ void apply_kernel(

// rho_new = (r)' * (z)
if (sg_id == 0) {
single_rhs_compute_conj_dot_sg(num_rows, r_sh, z_sh, rho_new_sh[0],
item_ct1);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
single_rhs_compute_conj_dot_sg(num_rows, r_sh, z_sh,
rho_new_sh[0], item_ct1);
}
item_ct1.barrier(sycl::access::fence_space::global_and_local);

Expand All @@ -239,6 +243,7 @@ __dpct_inline__ void apply_kernel(
logger.log_iteration(batch_id, iter, norms_res_sh[0]);

// copy x back to global memory
copy_kernel(num_rows, x_sh, x_global_entry, item_ct1);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::copy_kernel(
num_rows, x_sh, x_global_entry, item_ct1);
item_ct1.barrier(sycl::access::fence_space::global_and_local);
}

0 comments on commit ad5f7cd

Please sign in to comment.