Skip to content

Commit

Permalink
dpcpp kernel updates
Browse files Browse the repository at this point in the history
Co-authored-by: Phuong Nguyen <[email protected]>
  • Loading branch information
pratikvn committed Jul 21, 2023
1 parent 8d5acab commit ab81b76
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 72 deletions.
7 changes: 4 additions & 3 deletions core/base/batch_multi_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

namespace gko {
namespace batch_multi_vector {
namespace {


GKO_REGISTER_OPERATION(scale, batch_multi_vector::scale);
Expand All @@ -60,6 +61,7 @@ GKO_REGISTER_OPERATION(compute_norm2, batch_multi_vector::compute_norm2);
GKO_REGISTER_OPERATION(copy, batch_multi_vector::copy);


} // namespace
} // namespace batch_multi_vector


Expand Down Expand Up @@ -248,9 +250,8 @@ void BatchMultiVector<ValueType>::write(std::vector<mat_data32>& data) const
}


// Declaration macro that expands to `class BatchMultiVector<_type>`; it is
// handed to GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE, which presumably emits one
// explicit instantiation per supported value type — confirm against that
// macro's definition.
// NOTE(review): both the `_MATRIX`-suffixed and the shorter macro name appear
// below because this span is a rendered diff; presumably only the shorter
// name remains after the change — confirm against the repository.
#define GKO_DECLARE_BATCH_MULTI_VECTOR_MATRIX(_type) \
class BatchMultiVector<_type>
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR_MATRIX);
#define GKO_DECLARE_BATCH_MULTI_VECTOR(_type) class BatchMultiVector<_type>
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR);


} // namespace gko
9 changes: 7 additions & 2 deletions dpcpp/base/batch_multi_vector_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,21 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "core/base/batch_multi_vector_kernels.hpp"


#include <algorithm>
#include <dpcpp/matrix/batch_struct.hpp>
#include <CL/sycl.hpp>


#include <ginkgo/core/base/array.hpp>
#include <ginkgo/core/base/exception_helpers.hpp>
#include <ginkgo/core/base/math.hpp>
#include <ginkgo/core/base/range_accessors.hpp>


#include "core/components/prefix_sum_kernels.hpp"
#include "dpcpp/base/batch_struct.hpp"
#include "dpcpp/base/config.hpp"
#include "dpcpp/base/dim3.dp.hpp"
#include "dpcpp/base/dpct.hpp"
#include "dpcpp/base/helper.hpp"


namespace gko {
Expand Down
115 changes: 63 additions & 52 deletions dpcpp/base/batch_multi_vector_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

template <typename ValueType>
__dpct_inline__ void scale_kernel(
const gko::batch_dense::BatchEntry<const ValueType>& alpha,
const gko::batch_dense::BatchEntry<ValueType>& x,
const gko::batch_multi_vector::BatchEntry<const ValueType>& alpha,
const gko::batch_multi_vector::BatchEntry<ValueType>& x,
sycl::nd_item<3>& item_ct1)
{
const int max_li = x.num_rows * x.num_rhs;
Expand All @@ -53,84 +53,95 @@ __dpct_inline__ void scale_kernel(
}


/**
 * Adds a scaled (multi-)vector to another: y += alpha * x.
 *
 * NOTE(review): this span is a rendered diff, so two variants are interleaved
 * below. The pointer-based lines update `y[li] += alpha * x[li]` for
 * `li < num_rows`; the BatchEntry-based lines are described by the parameter
 * list that follows. Presumably the BatchEntry form is the post-change
 * version — confirm against the repository.
 *
 * @param alpha  Scaling factor(s). If `alpha.num_rhs == 1`, `alpha.values[0]`
 *               scales every column; otherwise `alpha.values[col]` scales
 *               column `col`.
 * @param[in] x  Entry to scale and add; element (row, col) is read at
 *               `x.values[row * x.stride + col]`.
 * @param[in,out] y  Entry to add to; element (row, col) is updated at
 *                   `y.values[row * y.stride + col]`.
 * @param item_ct1  Work-item handle. Each work-item starts at its local id in
 *                  dimension 2 and advances by the local range of that
 *                  dimension until `x.num_rows * x.num_rhs` is reached.
 */
template <typename ValueType>
__dpct_inline__ void add_scaled_kernel(const int num_rows,
const ValueType alpha,
const ValueType* const __restrict__ x,
ValueType* const __restrict__ y,
sycl::nd_item<3> item_ct1)
__dpct_inline__ void add_scaled_kernel(
const gko::batch_multi_vector::BatchEntry<const ValueType>& alpha,
const gko::batch_multi_vector::BatchEntry<const ValueType>& x,
const gko::batch_multi_vector::BatchEntry<ValueType>& y,
sycl::nd_item<3>& item_ct1)
{
for (int li = item_ct1.get_local_linear_id(); li < num_rows;
li += item_ct1.get_local_range().size()) {
y[li] += alpha * x[li];
// Total number of (row, col) elements visited by the loop below.
const int max_li = x.num_rows * x.num_rhs;
for (int li = item_ct1.get_local_id(2); li < max_li;
li += item_ct1.get_local_range(2)) {
// Split the flat index into (row, col) using num_rhs columns per row.
const int row = li / x.num_rhs;
const int col = li % x.num_rhs;

// One shared alpha for all columns, or one alpha per column.
if (alpha.num_rhs == 1) {
y.values[row * y.stride + col] +=
alpha.values[0] * x.values[row * x.stride + col];
} else {
y.values[row * y.stride + col] +=
alpha.values[col] * x.values[row * x.stride + col];
}
}
}


/**
 * Computes the conjugated dot product of x and y: sum of conj(x) * y.
 *
 * NOTE(review): this span is a rendered diff, so two variants are interleaved
 * below. The pointer-based lines accumulate `conj(x[r]) * y[r]` over
 * `r < num_rows` and reduce over the whole work-group into a single `result`.
 * The BatchEntry-based lines compute one dot product per column. Presumably
 * the BatchEntry form is the post-change version — confirm against the
 * repository.
 *
 * @param[in] x  First operand; element (r, rhs_index) is read at
 *               `x.values[r * x.stride + rhs_index]` and conjugated.
 * @param[in] y  Second operand; read at `y.values[r * y.stride + rhs_index]`.
 * @param[out] result  Receives one value per column in
 *                     `result.values[rhs_index]`.
 * @param item_ct1  Work-item handle used to obtain the sub-group.
 */
template <typename ValueType>
__dpct_inline__ void compute_dot_product_kernel(
const int num_rows, const ValueType* const __restrict__ x,
const ValueType* const __restrict__ y, ValueType& result,
sycl::nd_item<3> item_ct1)
const gko::batch_multi_vector::BatchEntry<const ValueType>& x,
const gko::batch_multi_vector::BatchEntry<const ValueType>& y,
const gko::batch_multi_vector::BatchEntry<ValueType>& result,
sycl::nd_item<3>& item_ct1)
{
const auto group = item_ct1.get_group();
const auto group_size = item_ct1.get_local_range().size();
const auto tid = item_ct1.get_local_linear_id();
const auto sg = item_ct1.get_sub_group();
const int sg_id = sg.get_group_id();
const int sg_size = sg.get_local_range().size();
const int num_sg = sg.get_group_range().size();

// Columns are distributed over sub-groups: sub-group sg_id handles columns
// sg_id, sg_id + num_sg, ...
for (int rhs_index = sg_id; rhs_index < x.num_rhs; rhs_index += num_sg) {
ValueType val = zero<ValueType>();

// Each sub-group member accumulates a partial sum over rows spaced sg_size
// apart.
for (int r = sg.get_local_id(); r < x.num_rows; r += sg_size) {
val += conj(x.values[r * x.stride + rhs_index]) *
y.values[r * y.stride + rhs_index];
}

ValueType val = zero<ValueType>();
// Combine the partial sums of all members of the sub-group.
val = sycl::reduce_over_group(sg, val, sycl::plus<>());

for (int r = tid; r < num_rows; r += group_size) {
val += conj(x[r]) * y[r];
// Only the sub-group member with local id 0 stores the column's result.
if (sg.get_local_id() == 0) {
result.values[rhs_index] = val;
}
}
result = sycl::reduce_over_group(group, val, sycl::plus<>());
}


/**
 * Computes the 2-norm of x: the square root of the summed squared_norm values.
 *
 * NOTE(review): this span is a rendered diff, so two variants are interleaved
 * below. The pointer-based lines accumulate `squared_norm(x[r])` over
 * `r < num_rows`, reduce over the whole work-group and store `sqrt(val)` in a
 * single `result`. The BatchEntry-based lines compute one norm per column.
 * Presumably the BatchEntry form is the post-change version — confirm against
 * the repository.
 *
 * @param[in] x  Input entry; element (r, rhs_index) is read at
 *               `x.values[r * x.stride + rhs_index]`.
 * @param[out] result  Receives one real-valued norm per column in
 *                     `result.values[rhs_index]`.
 * @param item_ct1  Work-item handle used to obtain the sub-group.
 */
template <typename ValueType>
__dpct_inline__ void compute_norm2_kernel(
const int num_rows, const ValueType* const __restrict__ x,
gko::remove_complex<ValueType>& result, sycl::nd_item<3> item_ct1)
const gko::batch_multi_vector::BatchEntry<const ValueType>& x,
const gko::batch_multi_vector::BatchEntry<remove_complex<ValueType>>&
result,
sycl::nd_item<3>& item_ct1)
{
const auto group = item_ct1.get_group();
const auto group_size = item_ct1.get_local_range().size();
const auto tid = item_ct1.get_local_linear_id();
const auto sg = item_ct1.get_sub_group();
const int sg_id = sg.get_group_id();
const int sg_size = sg.get_local_range().size();
const int num_sg = sg.get_group_range().size();

// The accumulator uses the real type associated with ValueType.
using real_type = typename gko::remove_complex<ValueType>;
real_type val = zero<real_type>();
// Columns are distributed over sub-groups: sub-group sg_id handles columns
// sg_id, sg_id + num_sg, ...
for (int rhs_index = sg_id; rhs_index < x.num_rhs; rhs_index += num_sg) {
real_type val = zero<real_type>();

for (int r = tid; r < num_rows; r += group_size) {
val += squared_norm(x[r]);
}
// Each sub-group member accumulates squared norms over rows spaced sg_size
// apart.
for (int r = sg.get_local_id(); r < x.num_rows; r += sg_size)
val += squared_norm(x.values[r * x.stride + rhs_index]);

val = sycl::reduce_over_group(group, val, sycl::plus<>());
// Combine the partial sums of all members of the sub-group.
val = sycl::reduce_over_group(sg, val, sycl::plus<>());

result = sqrt(val);
// Only the sub-group member with local id 0 stores the column's norm.
if (sg.get_local_id() == 0) result.values[rhs_index] = sqrt(val);
}
}


/**
 * Copies the values of one (multi-)vector into another.
 *
 * NOTE(review): this span is a rendered diff, so two variants are interleaved
 * below. The pointer-based lines copy `out[iz] = in[iz]` for `iz < num_rows`;
 * the BatchEntry-based lines copy element-wise with separate strides.
 * Presumably the BatchEntry form is the post-change version — confirm against
 * the repository.
 *
 * @param[in] in  Entry to copy from; element (i, j) is read at
 *                `in.values[i * in.stride + j]`.
 * @param[out] out  Entry to copy into; element (i, j) is written at
 *                  `out.values[i * out.stride + j]`.
 * @param item_ct1  Work-item handle. Each work-item starts at its local linear
 *                  id and advances by the total local range size until
 *                  `in.num_rows * in.num_rhs` is reached.
 */
template <typename ValueType>
__dpct_inline__ void copy_kernel(const int num_rows,
const ValueType* const __restrict__ in,
ValueType* const __restrict__ out,
sycl::nd_item<3> item_ct1)
__dpct_inline__ void copy_kernel(
const gko::batch_multi_vector::BatchEntry<const ValueType>& in,
const gko::batch_multi_vector::BatchEntry<ValueType>& out,
sycl::nd_item<3>& item_ct1)
{
for (int iz = item_ct1.get_local_linear_id(); iz < num_rows;
for (int iz = item_ct1.get_local_linear_id(); iz < in.num_rows * in.num_rhs;
iz += item_ct1.get_local_range().size()) {
out[iz] = in[iz];
// Split the flat index into (i, j) using num_rhs columns per row; `in` and
// `out` are indexed with their own strides.
const int i = iz / in.num_rhs;
const int j = iz % in.num_rhs;
out.values[i * out.stride + j] = in.values[i * in.stride + j];
}
}
2 changes: 1 addition & 1 deletion dpcpp/base/batch_struct.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define GKO_DPCPP_BASE_BATCH_STRUCT_HPP_


#include <ginkgo/core/base/batch_multi_vector.hpp>
#include <ginkgo/core/base/math.hpp>
#include <ginkgo/core/matrix/batch_multi_vector.hpp>


#include "core/base/batch_struct.hpp"
Expand Down
29 changes: 15 additions & 14 deletions test/base/batch_multi_vector_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <ginkgo/core/base/batch_multi_vector.hpp>


#include <memory>
#include <random>


Expand All @@ -45,16 +46,16 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include "core/base/batch_multi_vector_kernels.hpp"
#include "core/test/utils.hpp"
#include "core/test/utils/assertions.hpp"
#include "core/test/utils/batch_helpers.hpp"
#include "test/utils/executor.hpp"


class BatchMultiVector : public CommonTestFixture {
protected:
using vtype = double;
using Mtx = gko::BatchMultiVector<vtype>;
using NormVector = gko::BatchMultiVector<gko::remove_complex<vtype>>;
using ComplexMtx = gko::BatchMultiVector<std::complex<vtype>>;
using Mtx = gko::BatchMultiVector<value_type>;
using NormVector = gko::BatchMultiVector<gko::remove_complex<value_type>>;
using ComplexMtx = gko::BatchMultiVector<std::complex<value_type>>;

BatchMultiVector() : rand_engine(15) {}

Expand Down Expand Up @@ -148,7 +149,7 @@ TEST_F(BatchMultiVector, SingleVectorAddScaledIsEquivalentToRef)
x->add_scaled(alpha.get(), y.get());
dx->add_scaled(dalpha.get(), dy.get());

GKO_ASSERT_BATCH_MTX_NEAR(dx, x, 1e-14);
GKO_ASSERT_BATCH_MTX_NEAR(dx, x, r<value_type>::value);
}


Expand All @@ -159,7 +160,7 @@ TEST_F(BatchMultiVector, MultipleVectorAddScaledIsEquivalentToRef)
x->add_scaled(alpha.get(), y.get());
dx->add_scaled(dalpha.get(), dy.get());

GKO_ASSERT_BATCH_MTX_NEAR(dx, x, 1e-14);
GKO_ASSERT_BATCH_MTX_NEAR(dx, x, 5 * r<value_type>::value);
}


Expand All @@ -171,7 +172,7 @@ TEST_F(BatchMultiVector,
x->add_scaled(alpha.get(), y.get());
dx->add_scaled(dalpha.get(), dy.get());

GKO_ASSERT_BATCH_MTX_NEAR(dx, x, 1e-14);
GKO_ASSERT_BATCH_MTX_NEAR(dx, x, 5 * r<value_type>::value);
}


Expand All @@ -182,7 +183,7 @@ TEST_F(BatchMultiVector, SingleVectorScaleIsEquivalentToRef)
x->scale(alpha.get());
dx->scale(dalpha.get());

GKO_ASSERT_BATCH_MTX_NEAR(dx, x, 1e-14);
GKO_ASSERT_BATCH_MTX_NEAR(dx, x, 5 * r<value_type>::value);
}


Expand All @@ -193,7 +194,7 @@ TEST_F(BatchMultiVector, MultipleVectorScaleIsEquivalentToRef)
x->scale(alpha.get());
dx->scale(dalpha.get());

GKO_ASSERT_BATCH_MTX_NEAR(dx, x, 1e-14);
GKO_ASSERT_BATCH_MTX_NEAR(dx, x, 5 * r<value_type>::value);
}


Expand All @@ -204,7 +205,7 @@ TEST_F(BatchMultiVector, MultipleVectorScaleWithDifferentAlphaIsEquivalentToRef)
x->scale(alpha.get());
dx->scale(dalpha.get());

GKO_ASSERT_BATCH_MTX_NEAR(dx, x, 1e-14);
GKO_ASSERT_BATCH_MTX_NEAR(dx, x, 5 * r<value_type>::value);
}


Expand All @@ -219,7 +220,7 @@ TEST_F(BatchMultiVector, ComputeNorm2SingleIsEquivalentToRef)
x->compute_norm2(norm_expected.get());
dx->compute_norm2(dnorm.get());

GKO_ASSERT_BATCH_MTX_NEAR(norm_expected, dnorm, 1e-14);
GKO_ASSERT_BATCH_MTX_NEAR(norm_expected, dnorm, 5 * r<value_type>::value);
}


Expand All @@ -234,7 +235,7 @@ TEST_F(BatchMultiVector, ComputeNorm2IsEquivalentToRef)
x->compute_norm2(norm_expected.get());
dx->compute_norm2(dnorm.get());

GKO_ASSERT_BATCH_MTX_NEAR(norm_expected, dnorm, 1e-14);
GKO_ASSERT_BATCH_MTX_NEAR(norm_expected, dnorm, 5 * r<value_type>::value);
}


Expand All @@ -249,7 +250,7 @@ TEST_F(BatchMultiVector, ComputeDotIsEquivalentToRef)
x->compute_dot(y.get(), dot_expected.get());
dx->compute_dot(dy.get(), ddot.get());

GKO_ASSERT_BATCH_MTX_NEAR(dot_expected, ddot, 1e-14);
GKO_ASSERT_BATCH_MTX_NEAR(dot_expected, ddot, 5 * r<value_type>::value);
}


Expand All @@ -264,7 +265,7 @@ TEST_F(BatchMultiVector, ComputeDotSingleIsEquivalentToRef)
x->compute_dot(y.get(), dot_expected.get());
dx->compute_dot(dy.get(), ddot.get());

GKO_ASSERT_BATCH_MTX_NEAR(dot_expected, ddot, 1e-14);
GKO_ASSERT_BATCH_MTX_NEAR(dot_expected, ddot, 5 * r<value_type>::value);
}


Expand Down

0 comments on commit ab81b76

Please sign in to comment.