Update dpcpp kernels and fix for 2022-1
Cannot use sycl::reduce_over_group for older DPCPP versions.
pratikvn committed Jul 24, 2023
1 parent 76da3ac commit 19b40f2
Showing 2 changed files with 104 additions and 91 deletions.
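The core of the fix is visible in the second file below: the sub-group sum previously written with sycl::reduce_over_group is rewritten on top of Ginkgo's own cooperative-group reduce (dpcpp/components/reduction.dp.hpp), which also compiles with the 2022.1 DPC++ release. The standalone sketch below is not that helper; it only illustrates, assuming a SYCL 2020 compiler that provides sycl::permute_group_by_xor, how a sub-group sum can be built from shuffles when reduce_over_group cannot be used. The helper name subgroup_sum and the toy kernel are hypothetical.

// Standalone sketch only -- not Ginkgo's reduce helper. Assumes a SYCL 2020
// compiler with sycl::permute_group_by_xor; subgroup_sum is an illustrative name.
#include <sycl/sycl.hpp>
#include <iostream>

template <typename T>
T subgroup_sum(sycl::sub_group sg, T val)
{
    // Butterfly (XOR) reduction: after log2(width) shuffle steps every lane
    // of the sub-group holds the full sum.
    for (unsigned offset = sg.get_local_linear_range() / 2; offset > 0;
         offset /= 2) {
        val += sycl::permute_group_by_xor(sg, val, offset);
    }
    return val;
}

int main()
{
    constexpr int n = 256;
    sycl::queue q;
    float* data = sycl::malloc_shared<float>(n, q);
    float* result = sycl::malloc_shared<float>(1, q);
    for (int i = 0; i < n; ++i) data[i] = 1.0f;
    *result = 0.0f;

    q.parallel_for(
         sycl::nd_range<1>{sycl::range<1>(n), sycl::range<1>(64)},
         [=](sycl::nd_item<1> item) {
             const float partial = subgroup_sum(
                 item.get_sub_group(), data[item.get_global_linear_id()]);
             if (item.get_sub_group().get_local_linear_id() == 0) {
                 // One relaxed atomic add per sub-group accumulates the partials.
                 sycl::atomic_ref<float, sycl::memory_order::relaxed,
                                  sycl::memory_scope::device>(*result)
                     .fetch_add(partial);
             }
         })
        .wait();

    std::cout << "sum = " << *result << " (expected " << n << ")\n";
    sycl::free(data, q);
    sycl::free(result, q);
}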
84 changes: 61 additions & 23 deletions dpcpp/base/batch_multi_vector_kernels.dp.cpp
@@ -48,6 +48,10 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "dpcpp/base/dim3.dp.hpp"
#include "dpcpp/base/dpct.hpp"
#include "dpcpp/base/helper.hpp"
#include "dpcpp/components/cooperative_groups.dp.hpp"
#include "dpcpp/components/intrinsics.dp.hpp"
#include "dpcpp/components/reduction.dp.hpp"
#include "dpcpp/components/thread_ids.dp.hpp"


namespace gko {
@@ -81,16 +85,31 @@ void scale(std::shared_ptr<const DefaultExecutor> exec,
const dim3 grid(num_batches);

// Launch a kernel that has nbatches blocks, each block has max group size
(exec->get_queue())->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl_nd_range(grid, block), [=](sycl::nd_item<3> item_ct1) {
auto group = item_ct1.get_group();
auto group_id = group.get_group_linear_id();
const auto alpha_b = batch::batch_entry(alpha_ub, group_id);
const auto x_b = batch::batch_entry(x_ub, group_id);
scale_kernel(alpha_b, x_b, item_ct1);
});
});
if (alpha->get_common_size()[1] == 1) {
(exec->get_queue())->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl_nd_range(grid, block), [=](sycl::nd_item<3> item_ct1) {
auto group = item_ct1.get_group();
auto group_id = group.get_group_linear_id();
const auto alpha_b = batch::batch_entry(alpha_ub, group_id);
const auto x_b = batch::batch_entry(x_ub, group_id);
scale_kernel(alpha_b, x_b, item_ct1,
[](int col) { return 0; });
});
});
} else {
(exec->get_queue())->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl_nd_range(grid, block), [=](sycl::nd_item<3> item_ct1) {
auto group = item_ct1.get_group();
auto group_id = group.get_group_linear_id();
const auto alpha_b = batch::batch_entry(alpha_ub, group_id);
const auto x_b = batch::batch_entry(x_ub, group_id);
scale_kernel(alpha_b, x_b, item_ct1,
[](int col) { return col; });
});
});
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(
@@ -116,17 +135,33 @@ void add_scaled(std::shared_ptr<const DefaultExecutor> exec,
const auto alpha_ub = get_batch_struct(alpha);
const auto x_ub = get_batch_struct(x);
const auto y_ub = get_batch_struct(y);
(exec->get_queue())->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl_nd_range(grid, block), [=](sycl::nd_item<3> item_ct1) {
auto group = item_ct1.get_group();
auto group_id = group.get_group_linear_id();
const auto alpha_b = batch::batch_entry(alpha_ub, group_id);
const auto x_b = batch::batch_entry(x_ub, group_id);
const auto y_b = batch::batch_entry(y_ub, group_id);
add_scaled_kernel(alpha_b, x_b, y_b, item_ct1);
});
});
if (alpha->get_common_size()[1] == 1) {
(exec->get_queue())->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl_nd_range(grid, block), [=](sycl::nd_item<3> item_ct1) {
auto group = item_ct1.get_group();
auto group_id = group.get_group_linear_id();
const auto alpha_b = batch::batch_entry(alpha_ub, group_id);
const auto x_b = batch::batch_entry(x_ub, group_id);
const auto y_b = batch::batch_entry(y_ub, group_id);
add_scaled_kernel(alpha_b, x_b, y_b, item_ct1,
[](auto col) { return 0; });
});
});
} else {
(exec->get_queue())->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl_nd_range(grid, block), [=](sycl::nd_item<3> item_ct1) {
auto group = item_ct1.get_group();
auto group_id = group.get_group_linear_id();
const auto alpha_b = batch::batch_entry(alpha_ub, group_id);
const auto x_b = batch::batch_entry(x_ub, group_id);
const auto y_b = batch::batch_entry(y_ub, group_id);
add_scaled_kernel(alpha_b, x_b, y_b, item_ct1,
[](auto col) { return col; });
});
});
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(
@@ -159,7 +194,8 @@ void compute_dot(std::shared_ptr<const DefaultExecutor> exec,
const auto x_b = batch::batch_entry(x_ub, group_id);
const auto y_b = batch::batch_entry(y_ub, group_id);
const auto res_b = batch::batch_entry(res_ub, group_id);
compute_dot_product_kernel(x_b, y_b, res_b, item_ct1);
compute_gen_dot_product_kernel(x_b, y_b, res_b, item_ct1,
[](auto val) { return val; });
});
});
}
@@ -194,7 +230,9 @@ void compute_conj_dot(std::shared_ptr<const DefaultExecutor> exec,
const auto x_b = batch::batch_entry(x_ub, group_id);
const auto y_b = batch::batch_entry(y_ub, group_id);
const auto res_b = batch::batch_entry(res_ub, group_id);
compute_conj_dot_product_kernel(x_b, y_b, res_b, item_ct1);
compute_gen_dot_product_kernel(
x_b, y_b, res_b, item_ct1,
[](auto val) { return conj(val); });
});
});
}
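For reference, the branch added above in scale and add_scaled selects one of two index-mapping lambdas once per kernel launch: when alpha has a single column, every column maps to alpha[0]; otherwise each column reads its own alpha entry. The host-only sketch below (the helper name scale_columns and the data are hypothetical) shows the same pattern outside SYCL, which is what lets the kernel bodies in the file that follows drop their per-element branch on alpha.num_rhs.

// Host-only illustration of the mapping-lambda dispatch; names are hypothetical.
#include <cstddef>
#include <cstdio>
#include <vector>

template <typename Mapping>
void scale_columns(std::vector<double>& x, const std::vector<double>& alpha,
                   int num_cols, Mapping map)
{
    for (std::size_t i = 0; i < x.size(); ++i) {
        const int col = static_cast<int>(i) % num_cols;
        x[i] *= alpha[map(col)];  // map selects alpha[0] or alpha[col]
    }
}

int main()
{
    std::vector<double> x{1, 2, 3, 4, 5, 6};  // 2 rows x 3 cols, row-major
    std::vector<double> single_alpha{10.0};
    std::vector<double> per_col_alpha{1.0, 2.0, 3.0};

    auto y = x;
    // One alpha for all columns: map every column index to 0.
    scale_columns(y, single_alpha, 3, [](int) { return 0; });
    // One alpha per column: use the column index unchanged.
    scale_columns(x, per_col_alpha, 3, [](int col) { return col; });

    std::printf("broadcast:  %g %g %g\n", y[0], y[1], y[2]);
    std::printf("per-column: %g %g %g\n", x[0], x[1], x[2]);
}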
111 changes: 43 additions & 68 deletions dpcpp/base/batch_multi_vector_kernels.hpp.inc
@@ -30,105 +30,72 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
******************************<GINKGO LICENSE>*******************************/

template <typename ValueType>
template <typename ValueType, typename Mapping>
__dpct_inline__ void scale_kernel(
const gko::batch_multi_vector::batch_entry<const ValueType>& alpha,
const gko::batch_multi_vector::batch_entry<ValueType>& x,
sycl::nd_item<3>& item_ct1)
sycl::nd_item<3>& item_ct1, Mapping map)
{
const int max_li = x.num_rows * x.num_rhs;
for (int li = item_ct1.get_local_linear_id(); li < max_li;
li += item_ct1.get_local_range().size()) {
const int row = li / x.num_rhs;
const int col = li % x.num_rhs;

if (alpha.num_rhs == 1) {
x.values[row * x.stride + col] =
alpha.values[0] * x.values[row * x.stride + col];
} else {
x.values[row * x.stride + col] =
alpha.values[col] * x.values[row * x.stride + col];
}
x.values[row * x.stride + col] =
alpha.values[map(col)] * x.values[row * x.stride + col];
}
}


template <typename ValueType>
template <typename ValueType, typename Mapping>
__dpct_inline__ void add_scaled_kernel(
const gko::batch_multi_vector::batch_entry<const ValueType>& alpha,
const gko::batch_multi_vector::batch_entry<const ValueType>& x,
const gko::batch_multi_vector::batch_entry<ValueType>& y,
sycl::nd_item<3>& item_ct1)
sycl::nd_item<3>& item_ct1, Mapping map)
{
const int max_li = x.num_rows * x.num_rhs;
for (int li = item_ct1.get_local_id(2); li < max_li;
li += item_ct1.get_local_range(2)) {
const int row = li / x.num_rhs;
const int col = li % x.num_rhs;

if (alpha.num_rhs == 1) {
y.values[row * y.stride + col] +=
alpha.values[0] * x.values[row * x.stride + col];
} else {
y.values[row * y.stride + col] +=
alpha.values[col] * x.values[row * x.stride + col];
}
y.values[row * y.stride + col] +=
alpha.values[map(col)] * x.values[row * x.stride + col];
}
}


template <typename ValueType>
__dpct_inline__ void compute_dot_product_kernel(
template <typename ValueType, typename Mapping>
__dpct_inline__ void compute_gen_dot_product_kernel(
const gko::batch_multi_vector::batch_entry<const ValueType>& x,
const gko::batch_multi_vector::batch_entry<const ValueType>& y,
const gko::batch_multi_vector::batch_entry<ValueType>& result,
sycl::nd_item<3>& item_ct1)
sycl::nd_item<3>& item_ct1, Mapping conj_map)
{
const auto sg = item_ct1.get_sub_group();
const int sg_id = sg.get_group_id();
const int sg_size = sg.get_local_range().size();
const int num_sg = sg.get_group_range().size();

for (int rhs_index = sg_id; rhs_index < x.num_rhs; rhs_index += num_sg) {
constexpr auto tile_size = config::warp_size;
const auto subgroup = item_ct1.get_sub_group();
const int subgroup_id = subgroup.get_group_id();
const int subgroup_size = subgroup.get_local_range().size();
const int num_subgroups = subgroup.get_group_range().size();
auto subg =
group::tiled_partition<tile_size>(group::this_thread_block(item_ct1));

for (int rhs_index = subgroup_id; rhs_index < x.num_rhs;
rhs_index += num_subgroups) {
ValueType val = zero<ValueType>();

for (int r = sg.get_local_id(); r < x.num_rows; r += sg_size) {
val += x.values[r * x.stride + rhs_index] *
for (int r = subgroup.get_local_id(); r < x.num_rows;
r += subgroup_size) {
val += conj_map(x.values[r * x.stride + rhs_index]) *
y.values[r * y.stride + rhs_index];
}

val = sycl::reduce_over_group(sg, val, sycl::plus<>());
val = ::gko::kernels::dpcpp::reduce(
subg, val, [](ValueType a, ValueType b) { return a + b; });

if (sg.get_local_id() == 0) {
result.values[rhs_index] = val;
}
}
}


template <typename ValueType>
__dpct_inline__ void compute_conj_dot_product_kernel(
const gko::batch_multi_vector::batch_entry<const ValueType>& x,
const gko::batch_multi_vector::batch_entry<const ValueType>& y,
const gko::batch_multi_vector::batch_entry<ValueType>& result,
sycl::nd_item<3>& item_ct1)
{
const auto sg = item_ct1.get_sub_group();
const int sg_id = sg.get_group_id();
const int sg_size = sg.get_local_range().size();
const int num_sg = sg.get_group_range().size();

for (int rhs_index = sg_id; rhs_index < x.num_rhs; rhs_index += num_sg) {
ValueType val = zero<ValueType>();

for (int r = sg.get_local_id(); r < x.num_rows; r += sg_size) {
val += conj(x.values[r * x.stride + rhs_index]) *
y.values[r * y.stride + rhs_index];
}

val = sycl::reduce_over_group(sg, val, sycl::plus<>());

if (sg.get_local_id() == 0) {
if (subgroup.get_local_id() == 0) {
result.values[rhs_index] = val;
}
}
@@ -142,21 +109,29 @@ __dpct_inline__ void compute_norm2_kernel(
result,
sycl::nd_item<3>& item_ct1)
{
const auto sg = item_ct1.get_sub_group();
const int sg_id = sg.get_group_id();
const int sg_size = sg.get_local_range().size();
const int num_sg = sg.get_group_range().size();
constexpr auto tile_size = config::warp_size;
const auto subgroup = item_ct1.get_sub_group();
const int subgroup_id = subgroup.get_group_id();
const int subgroup_size = subgroup.get_local_range().size();
const int num_subgroups = subgroup.get_group_range().size();
auto subg =
group::tiled_partition<tile_size>(group::this_thread_block(item_ct1));

using real_type = typename gko::remove_complex<ValueType>;
for (int rhs_index = sg_id; rhs_index < x.num_rhs; rhs_index += num_sg) {
for (int rhs_index = subgroup_id; rhs_index < x.num_rhs;
rhs_index += num_subgroups) {
real_type val = zero<real_type>();

for (int r = sg.get_local_id(); r < x.num_rows; r += sg_size)
for (int r = subgroup.get_local_id(); r < x.num_rows;
r += subgroup_size)
val += squared_norm(x.values[r * x.stride + rhs_index]);

val = sycl::reduce_over_group(sg, val, sycl::plus<>());
val = ::gko::kernels::dpcpp::reduce(
subg, val, [](real_type a, real_type b) { return a + b; });

if (sg.get_local_id() == 0) result.values[rhs_index] = sqrt(val);
if (subgroup.get_local_id() == 0) {
result.values[rhs_index] = sqrt(val);
}
}
}

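The two deleted kernels, compute_dot_product_kernel and compute_conj_dot_product_kernel, differed only in whether the left operand is conjugated, so the new compute_gen_dot_product_kernel takes that choice as a lambda, as the call sites in the first file above show. A host-only sketch of the same idea (the helper name gen_dot is hypothetical):

// Host-only illustration of a dot product generalized over a conjugation map.
#include <complex>
#include <iostream>
#include <vector>

template <typename ValueType, typename Mapping>
ValueType gen_dot(const std::vector<ValueType>& x,
                  const std::vector<ValueType>& y, Mapping conj_map)
{
    ValueType sum{};
    for (std::size_t i = 0; i < x.size(); ++i) {
        sum += conj_map(x[i]) * y[i];  // identity map -> dot, conj -> conj dot
    }
    return sum;
}

int main()
{
    using cpx = std::complex<double>;
    std::vector<cpx> x{{1, 1}, {0, 2}};
    std::vector<cpx> y{{2, 0}, {1, -1}};

    auto dot = gen_dot(x, y, [](cpx v) { return v; });
    auto conj_dot = gen_dot(x, y, [](cpx v) { return std::conj(v); });
    std::cout << "dot = " << dot << ", conj dot = " << conj_dot << "\n";
}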
