From 19b40f2f4276131cf9d0a8d8850111a08fe8962e Mon Sep 17 00:00:00 2001 From: Pratik Nayak Date: Mon, 24 Jul 2023 14:14:45 +0200 Subject: [PATCH] Update dpcpp kernels and fix for 2022-1 Cannot use sycl::reduce_over_group for older DPCPP versions. --- dpcpp/base/batch_multi_vector_kernels.dp.cpp | 84 +++++++++---- dpcpp/base/batch_multi_vector_kernels.hpp.inc | 111 +++++++----------- 2 files changed, 104 insertions(+), 91 deletions(-) diff --git a/dpcpp/base/batch_multi_vector_kernels.dp.cpp b/dpcpp/base/batch_multi_vector_kernels.dp.cpp index 97f7469a6f6..1cd7061c161 100644 --- a/dpcpp/base/batch_multi_vector_kernels.dp.cpp +++ b/dpcpp/base/batch_multi_vector_kernels.dp.cpp @@ -48,6 +48,10 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "dpcpp/base/dim3.dp.hpp" #include "dpcpp/base/dpct.hpp" #include "dpcpp/base/helper.hpp" +#include "dpcpp/components/cooperative_groups.dp.hpp" +#include "dpcpp/components/intrinsics.dp.hpp" +#include "dpcpp/components/reduction.dp.hpp" +#include "dpcpp/components/thread_ids.dp.hpp" namespace gko { @@ -81,16 +85,31 @@ void scale(std::shared_ptr exec, const dim3 grid(num_batches); // Launch a kernel that has nbatches blocks, each block has max group size - (exec->get_queue())->submit([&](sycl::handler& cgh) { - cgh.parallel_for( - sycl_nd_range(grid, block), [=](sycl::nd_item<3> item_ct1) { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto alpha_b = batch::batch_entry(alpha_ub, group_id); - const auto x_b = batch::batch_entry(x_ub, group_id); - scale_kernel(alpha_b, x_b, item_ct1); - }); - }); + if (alpha->get_common_size()[1] == 1) { + (exec->get_queue())->submit([&](sycl::handler& cgh) { + cgh.parallel_for( + sycl_nd_range(grid, block), [=](sycl::nd_item<3> item_ct1) { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto alpha_b = batch::batch_entry(alpha_ub, group_id); + const auto x_b = batch::batch_entry(x_ub, group_id); + scale_kernel(alpha_b, x_b, item_ct1, + [](int col) { return 0; }); + }); + }); + } else { + (exec->get_queue())->submit([&](sycl::handler& cgh) { + cgh.parallel_for( + sycl_nd_range(grid, block), [=](sycl::nd_item<3> item_ct1) { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto alpha_b = batch::batch_entry(alpha_ub, group_id); + const auto x_b = batch::batch_entry(x_ub, group_id); + scale_kernel(alpha_b, x_b, item_ct1, + [](int col) { return col; }); + }); + }); + } } GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( @@ -116,17 +135,33 @@ void add_scaled(std::shared_ptr exec, const auto alpha_ub = get_batch_struct(alpha); const auto x_ub = get_batch_struct(x); const auto y_ub = get_batch_struct(y); - (exec->get_queue())->submit([&](sycl::handler& cgh) { - cgh.parallel_for( - sycl_nd_range(grid, block), [=](sycl::nd_item<3> item_ct1) { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto alpha_b = batch::batch_entry(alpha_ub, group_id); - const auto x_b = batch::batch_entry(x_ub, group_id); - const auto y_b = batch::batch_entry(y_ub, group_id); - add_scaled_kernel(alpha_b, x_b, y_b, item_ct1); - }); - }); + if (alpha->get_common_size()[1] == 1) { + (exec->get_queue())->submit([&](sycl::handler& cgh) { + cgh.parallel_for( + sycl_nd_range(grid, block), [=](sycl::nd_item<3> item_ct1) { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto alpha_b = batch::batch_entry(alpha_ub, group_id); + const auto x_b = batch::batch_entry(x_ub, group_id); + const auto y_b = batch::batch_entry(y_ub, group_id); + add_scaled_kernel(alpha_b, x_b, y_b, item_ct1, + [](auto col) { return 0; }); + }); + }); + } else { + (exec->get_queue())->submit([&](sycl::handler& cgh) { + cgh.parallel_for( + sycl_nd_range(grid, block), [=](sycl::nd_item<3> item_ct1) { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto alpha_b = batch::batch_entry(alpha_ub, group_id); + const auto x_b = batch::batch_entry(x_ub, group_id); + const auto y_b = batch::batch_entry(y_ub, group_id); + add_scaled_kernel(alpha_b, x_b, y_b, item_ct1, + [](auto col) { return col; }); + }); + }); + } } GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( @@ -159,7 +194,8 @@ void compute_dot(std::shared_ptr exec, const auto x_b = batch::batch_entry(x_ub, group_id); const auto y_b = batch::batch_entry(y_ub, group_id); const auto res_b = batch::batch_entry(res_ub, group_id); - compute_dot_product_kernel(x_b, y_b, res_b, item_ct1); + compute_gen_dot_product_kernel(x_b, y_b, res_b, item_ct1, + [](auto val) { return val; }); }); }); } @@ -194,7 +230,9 @@ void compute_conj_dot(std::shared_ptr exec, const auto x_b = batch::batch_entry(x_ub, group_id); const auto y_b = batch::batch_entry(y_ub, group_id); const auto res_b = batch::batch_entry(res_ub, group_id); - compute_conj_dot_product_kernel(x_b, y_b, res_b, item_ct1); + compute_gen_dot_product_kernel( + x_b, y_b, res_b, item_ct1, + [](auto val) { return conj(val); }); }); }); } diff --git a/dpcpp/base/batch_multi_vector_kernels.hpp.inc b/dpcpp/base/batch_multi_vector_kernels.hpp.inc index cb2ccd4ae50..6e22c5c078f 100644 --- a/dpcpp/base/batch_multi_vector_kernels.hpp.inc +++ b/dpcpp/base/batch_multi_vector_kernels.hpp.inc @@ -30,11 +30,11 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. *************************************************************/ -template +template __dpct_inline__ void scale_kernel( const gko::batch_multi_vector::batch_entry& alpha, const gko::batch_multi_vector::batch_entry& x, - sycl::nd_item<3>& item_ct1) + sycl::nd_item<3>& item_ct1, Mapping map) { const int max_li = x.num_rows * x.num_rhs; for (int li = item_ct1.get_local_linear_id(); li < max_li; @@ -42,23 +42,18 @@ __dpct_inline__ void scale_kernel( const int row = li / x.num_rhs; const int col = li % x.num_rhs; - if (alpha.num_rhs == 1) { - x.values[row * x.stride + col] = - alpha.values[0] * x.values[row * x.stride + col]; - } else { - x.values[row * x.stride + col] = - alpha.values[col] * x.values[row * x.stride + col]; - } + x.values[row * x.stride + col] = + alpha.values[map(col)] * x.values[row * x.stride + col]; } } -template +template __dpct_inline__ void add_scaled_kernel( const gko::batch_multi_vector::batch_entry& alpha, const gko::batch_multi_vector::batch_entry& x, const gko::batch_multi_vector::batch_entry& y, - sycl::nd_item<3>& item_ct1) + sycl::nd_item<3>& item_ct1, Mapping map) { const int max_li = x.num_rows * x.num_rhs; for (int li = item_ct1.get_local_id(2); li < max_li; @@ -66,69 +61,41 @@ __dpct_inline__ void add_scaled_kernel( const int row = li / x.num_rhs; const int col = li % x.num_rhs; - if (alpha.num_rhs == 1) { - y.values[row * y.stride + col] += - alpha.values[0] * x.values[row * x.stride + col]; - } else { - y.values[row * y.stride + col] += - alpha.values[col] * x.values[row * x.stride + col]; - } + y.values[row * y.stride + col] += + alpha.values[map(col)] * x.values[row * x.stride + col]; } } -template -__dpct_inline__ void compute_dot_product_kernel( +template +__dpct_inline__ void compute_gen_dot_product_kernel( const gko::batch_multi_vector::batch_entry& x, const gko::batch_multi_vector::batch_entry& y, const gko::batch_multi_vector::batch_entry& result, - sycl::nd_item<3>& item_ct1) + sycl::nd_item<3>& item_ct1, Mapping conj_map) { - const auto sg = item_ct1.get_sub_group(); - const int sg_id = sg.get_group_id(); - const int sg_size = sg.get_local_range().size(); - const int num_sg = sg.get_group_range().size(); - - for (int rhs_index = sg_id; rhs_index < x.num_rhs; rhs_index += num_sg) { + constexpr auto tile_size = config::warp_size; + const auto subgroup = item_ct1.get_sub_group(); + const int subgroup_id = subgroup.get_group_id(); + const int subgroup_size = subgroup.get_local_range().size(); + const int num_subgroups = subgroup.get_group_range().size(); + auto subg = + group::tiled_partition(group::this_thread_block(item_ct1)); + + for (int rhs_index = subgroup_id; rhs_index < x.num_rhs; + rhs_index += num_subgroups) { ValueType val = zero(); - for (int r = sg.get_local_id(); r < x.num_rows; r += sg_size) { - val += x.values[r * x.stride + rhs_index] * + for (int r = subgroup.get_local_id(); r < x.num_rows; + r += subgroup_size) { + val += conj_map(x.values[r * x.stride + rhs_index]) * y.values[r * y.stride + rhs_index]; } - val = sycl::reduce_over_group(sg, val, sycl::plus<>()); + val = ::gko::kernels::dpcpp::reduce( + subg, val, [](ValueType a, ValueType b) { return a + b; }); - if (sg.get_local_id() == 0) { - result.values[rhs_index] = val; - } - } -} - - -template -__dpct_inline__ void compute_conj_dot_product_kernel( - const gko::batch_multi_vector::batch_entry& x, - const gko::batch_multi_vector::batch_entry& y, - const gko::batch_multi_vector::batch_entry& result, - sycl::nd_item<3>& item_ct1) -{ - const auto sg = item_ct1.get_sub_group(); - const int sg_id = sg.get_group_id(); - const int sg_size = sg.get_local_range().size(); - const int num_sg = sg.get_group_range().size(); - - for (int rhs_index = sg_id; rhs_index < x.num_rhs; rhs_index += num_sg) { - ValueType val = zero(); - - for (int r = sg.get_local_id(); r < x.num_rows; r += sg_size) { - val += conj(x.values[r * x.stride + rhs_index]) * - y.values[r * y.stride + rhs_index]; - } - - val = sycl::reduce_over_group(sg, val, sycl::plus<>()); - - if (sg.get_local_id() == 0) { + if (subgroup.get_local_id() == 0) { result.values[rhs_index] = val; } } @@ -142,21 +109,29 @@ __dpct_inline__ void compute_norm2_kernel( result, sycl::nd_item<3>& item_ct1) { - const auto sg = item_ct1.get_sub_group(); - const int sg_id = sg.get_group_id(); - const int sg_size = sg.get_local_range().size(); - const int num_sg = sg.get_group_range().size(); + constexpr auto tile_size = config::warp_size; + const auto subgroup = item_ct1.get_sub_group(); + const int subgroup_id = subgroup.get_group_id(); + const int subgroup_size = subgroup.get_local_range().size(); + const int num_subgroups = subgroup.get_group_range().size(); + auto subg = + group::tiled_partition(group::this_thread_block(item_ct1)); using real_type = typename gko::remove_complex; - for (int rhs_index = sg_id; rhs_index < x.num_rhs; rhs_index += num_sg) { + for (int rhs_index = subgroup_id; rhs_index < x.num_rhs; + rhs_index += num_subgroups) { real_type val = zero(); - for (int r = sg.get_local_id(); r < x.num_rows; r += sg_size) + for (int r = subgroup.get_local_id(); r < x.num_rows; + r += subgroup_size) val += squared_norm(x.values[r * x.stride + rhs_index]); - val = sycl::reduce_over_group(sg, val, sycl::plus<>()); + val = ::gko::kernels::dpcpp::reduce( + subg, val, [](real_type a, real_type b) { return a + b; }); - if (sg.get_local_id() == 0) result.values[rhs_index] = sqrt(val); + if (subgroup.get_local_id() == 0) { + result.values[rhs_index] = sqrt(val); + } } }