Skip to content

Commit

Permalink
Review updates
Browse files Browse the repository at this point in the history
Co-authored-by: Yu-Hsiang Tsai<[email protected]>
  • Loading branch information
pratikvn committed Jul 27, 2023
1 parent 57539bc commit 62e84a2
Show file tree
Hide file tree
Showing 16 changed files with 103 additions and 147 deletions.
38 changes: 19 additions & 19 deletions common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,6 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
******************************<GINKGO LICENSE>*******************************/


/**
* Scales the vectors in global or shared memory with a factor of alpha (alpha
* is in global memory or shared memory)
*/
template <typename ValueType, typename Mapping>
__device__ __forceinline__ void scale(
const gko::batch_multi_vector::batch_entry<const ValueType>& alpha,
Expand All @@ -52,9 +48,9 @@ __device__ __forceinline__ void scale(

template <typename ValueType, typename Mapping>
__global__
__launch_bounds__(default_block_size, sm_multiplier) void scale_kernel(
const gko::batch_multi_vector::uniform_batch<const ValueType> alpha,
const gko::batch_multi_vector::uniform_batch<ValueType> x, Mapping map)
__launch_bounds__(default_block_size, sm_multiplier) void scale_kernel(
const gko::batch_multi_vector::uniform_batch<const ValueType> alpha,
const gko::batch_multi_vector::uniform_batch<ValueType> x, Mapping map)
{
for (size_type batch_id = blockIdx.x; batch_id < x.num_batch_entries;
batch_id += gridDim.x) {
Expand Down Expand Up @@ -83,10 +79,10 @@ __device__ __forceinline__ void add_scaled(

template <typename ValueType, typename Mapping>
__global__
__launch_bounds__(default_block_size, sm_multiplier) void add_scaled_kernel(
const gko::batch_multi_vector::uniform_batch<const ValueType> alpha,
const gko::batch_multi_vector::uniform_batch<const ValueType> x,
const gko::batch_multi_vector::uniform_batch<ValueType> y, Mapping map)
__launch_bounds__(default_block_size, sm_multiplier) void add_scaled_kernel(
const gko::batch_multi_vector::uniform_batch<const ValueType> alpha,
const gko::batch_multi_vector::uniform_batch<const ValueType> x,
const gko::batch_multi_vector::uniform_batch<ValueType> y, Mapping map)
{
for (size_type batch_id = blockIdx.x; batch_id < x.num_batch_entries;
batch_id += gridDim.x) {
Expand Down Expand Up @@ -222,11 +218,15 @@ __device__ __forceinline__ void compute_norm2(


template <typename ValueType>
__global__
__launch_bounds__(default_block_size, sm_multiplier) void compute_norm2_kernel(
const gko::batch_multi_vector::uniform_batch<const ValueType> x,
const gko::batch_multi_vector::uniform_batch<remove_complex<ValueType>>
result)
__global__ __launch_bounds__(
default_block_size,
sm_multiplier) void compute_norm2_kernel(const gko::batch_multi_vector::
uniform_batch<const ValueType>
x,
const gko::batch_multi_vector::
uniform_batch<
remove_complex<ValueType>>
result)
{
for (size_type batch_id = blockIdx.x; batch_id < x.num_batch_entries;
batch_id += gridDim.x) {
Expand Down Expand Up @@ -259,9 +259,9 @@ __device__ __forceinline__ void copy(

template <typename ValueType>
__global__
__launch_bounds__(default_block_size, sm_multiplier) void copy_kernel(
const gko::batch_multi_vector::uniform_batch<const ValueType> src,
const gko::batch_multi_vector::uniform_batch<ValueType> dst)
__launch_bounds__(default_block_size, sm_multiplier) void copy_kernel(
const gko::batch_multi_vector::uniform_batch<const ValueType> src,
const gko::batch_multi_vector::uniform_batch<ValueType> dst)
{
for (size_type batch_id = blockIdx.x; batch_id < src.num_batch_entries;
batch_id += gridDim.x) {
Expand Down
15 changes: 7 additions & 8 deletions core/base/batch_struct.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,21 +44,20 @@ namespace batch_multi_vector {


/**
* Encapsulates one matrix from a batch of dense matrices (vectors).
* Encapsulates one matrix from a batch of multi-vectors.
*/
template <typename ValueType>
struct batch_entry {
using value_type = ValueType;
ValueType* values;
size_type stride;
int stride;
int num_rows;
int num_rhs;
};


/**
* A 'simple' structure to store a global uniform batch of dense matrices.
*
* It is uniform in the sense that all matrices in the batch have common sizes.
* A 'simple' structure to store a global uniform batch of multi-vectors.
*/
template <typename ValueType>
struct uniform_batch {
Expand All @@ -67,7 +66,7 @@ struct uniform_batch {

ValueType* values;
size_type num_batch_entries;
size_type stride;
int stride;
int num_rows;
int num_rhs;

Expand Down Expand Up @@ -122,8 +121,8 @@ batch_entry(const batch_multi_vector::uniform_batch<ValueType>& batch,

template <typename ValueType>
GKO_ATTRIBUTES GKO_INLINE batch_multi_vector::batch_entry<ValueType>
batch_entry(ValueType* const batch_values, const size_type stride,
const int num_rows, const int num_rhs, const size_type batch_idx)
batch_entry(ValueType* const batch_values, const int stride, const int num_rows,
const int num_rhs, const size_type batch_idx)
{
return {batch_values + batch_idx * stride * num_rows, stride, num_rows,
num_rhs};
Expand Down
51 changes: 25 additions & 26 deletions core/test/base/batch_multi_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class BatchMultiVector : public ::testing::Test {
{
ASSERT_EQ(m->get_num_batch_entries(), 0);
ASSERT_EQ(m->get_common_size(), gko::dim<2>{});
ASSERT_EQ(m->get_const_values(), nullptr);
}

std::shared_ptr<const gko::Executor> exec;
Expand All @@ -100,13 +101,6 @@ TYPED_TEST(BatchMultiVector, CanBeEmpty)
}


TYPED_TEST(BatchMultiVector, ReturnsNullValuesArrayWhenEmpty)
{
auto empty = gko::BatchMultiVector<TypeParam>::create(this->exec);
ASSERT_EQ(empty->get_const_values(), nullptr);
}


TYPED_TEST(BatchMultiVector, KnowsItsSizeAndValues)
{
ASSERT_NE(this->mtx->get_const_values(), nullptr);
Expand Down Expand Up @@ -165,10 +159,12 @@ TYPED_TEST(BatchMultiVector, CanBeConstructedFromExistingData)
using size_type = gko::size_type;
// clang-format off
value_type data[] = {
1.0, 2.0, -1.0,
3.0, 4.0, -1.0,
3.0, 5.0, 1.0,
5.0, 6.0, -3.0};
1.0, 2.0,
-1.0,3.0,
4.0, -1.0,
3.0, 5.0,
1.0, 5.0,
6.0, -3.0};
// clang-format on

auto m = gko::BatchMultiVector<TypeParam>::create(
Expand All @@ -192,11 +188,13 @@ TYPED_TEST(BatchMultiVector, CanBeConstructedFromExistingConstData)
using value_type = typename TestFixture::value_type;
using size_type = gko::size_type;
// clang-format off
const value_type data[] = {
1.0, 2.0, -1.0,
3.0, 4.0, -1.0,
3.0, 5.0, 1.0,
5.0, 6.0, -3.0};
value_type data[] = {
1.0, 2.0,
-1.0,3.0,
4.0, -1.0,
3.0, 5.0,
1.0, 5.0,
6.0, -3.0};
// clang-format on

auto m = gko::BatchMultiVector<TypeParam>::create_const(
Expand All @@ -215,7 +213,7 @@ TYPED_TEST(BatchMultiVector, CanBeConstructedFromExistingConstData)
}


TYPED_TEST(BatchMultiVector, CanBeConstructedFromBatchMultiVectorMatrices)
TYPED_TEST(BatchMultiVector, CanBeConstructedFromDenseMatrices)
{
using value_type = typename TestFixture::value_type;
using DenseMtx = typename TestFixture::DenseMtx;
Expand All @@ -227,12 +225,8 @@ TYPED_TEST(BatchMultiVector, CanBeConstructedFromBatchMultiVectorMatrices)

auto m = gko::BatchMultiVector<TypeParam>::create(
this->exec, std::vector<DenseMtx*>{mat1.get(), mat2.get()});
auto m_ref = gko::BatchMultiVector<TypeParam>::create(
this->exec, std::vector<DenseMtx*>{mat1.get(), mat2.get(), mat1.get(),
mat2.get(), mat1.get(), mat2.get()});
auto m2 = gko::BatchMultiVector<TypeParam>::create(this->exec, 3, m.get());

GKO_ASSERT_BATCH_MTX_NEAR(m2.get(), m_ref.get(), 1e-14);
this->assert_equal_to_original_mtx(m.get());
}


Expand All @@ -255,7 +249,7 @@ TYPED_TEST(BatchMultiVector, CanBeConstructedFromDenseMatricesByDuplication)
}


TYPED_TEST(BatchMultiVector, CanBeConstructedFromDenseMatrices)
TYPED_TEST(BatchMultiVector, CanBeConstructedFromBatchMultiVectorMatrices)
{
using value_type = typename TestFixture::value_type;
using DenseMtx = typename TestFixture::DenseMtx;
Expand All @@ -264,11 +258,15 @@ TYPED_TEST(BatchMultiVector, CanBeConstructedFromDenseMatrices)
this->exec);
auto mat2 = gko::initialize<DenseMtx>({{1.0, 2.5, 3.0}, {1.0, 2.0, 3.0}},
this->exec);

auto m = gko::BatchMultiVector<TypeParam>::create(
this->exec, std::vector<DenseMtx*>{mat1.get(), mat2.get()});
auto m_ref = gko::BatchMultiVector<TypeParam>::create(
this->exec, std::vector<DenseMtx*>{mat1.get(), mat2.get(), mat1.get(),
mat2.get(), mat1.get(), mat2.get()});

this->assert_equal_to_original_mtx(m.get());
auto m2 = gko::BatchMultiVector<TypeParam>::create(this->exec, 3, m.get());

GKO_ASSERT_BATCH_MTX_NEAR(m2.get(), m_ref.get(), 1e-14);
}


Expand Down Expand Up @@ -356,6 +354,7 @@ TYPED_TEST(BatchMultiVector, CanBeUnbatchedIntoDenseMatrices)

auto dense_mats = this->mtx->unbatch();

ASSERT_EQ(dense_mats.size(), 2);
GKO_ASSERT_MTX_NEAR(dense_mats[0].get(), mat1.get(), 0.);
GKO_ASSERT_MTX_NEAR(dense_mats[1].get(), mat2.get(), 0.);
}
Expand All @@ -380,8 +379,8 @@ TYPED_TEST(BatchMultiVector, CanBeReadFromMatrixData)

ASSERT_EQ(m->get_common_size(), gko::dim<2>(2, 2));
EXPECT_EQ(m->at(0, 0, 0), value_type{1.0});
EXPECT_EQ(m->at(0, 1, 0), value_type{0.0});
EXPECT_EQ(m->at(0, 0, 1), value_type{3.0});
EXPECT_EQ(m->at(0, 1, 0), value_type{0.0});
EXPECT_EQ(m->at(0, 1, 1), value_type{5.0});
EXPECT_EQ(m->at(1, 0, 0), value_type{-1.0});
EXPECT_EQ(m->at(1, 0, 1), value_type{0.5});
Expand Down
13 changes: 5 additions & 8 deletions core/test/utils/assertions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,21 +323,18 @@ ::testing::AssertionResult batch_matrices_near_impl(
const MatrixData2& second, double tolerance)
{
std::vector<double> err;
std::vector<bool> err_flag;
for (size_type b = 0; b < first.size(); ++b) {
auto num_rows = first[b].size[0];
auto num_cols = first[b].size[1];
if (num_rows != second[b].size[0] || num_cols != second[b].size[1]) {
if (first.size() != second.size()) {
return ::testing::AssertionFailure()
<< "Expected matrices of equal size\n\t" << first_expression
<< " is of size [" << num_rows << " x " << num_cols
<< "]\n\t" << second_expression << " is of size ["
<< second[b].size[0] << " x " << second[b].size[1] << "]"
<< " is of size [" << first[b].size[0] << " x "
<< first[b].size[1] << "]\n\t" << second_expression
<< " is of size [" << second[b].size[0] << " x "
<< second[b].size[1] << "]"
<< " for batch " << b;
}

err.push_back(detail::get_relative_error(first[b], second[b]));
err_flag.push_back(err.back() <= tolerance);
}

auto bat = std::find_if(err.begin(), err.end(),
Expand Down
6 changes: 4 additions & 2 deletions cuda/base/batch_multi_vector_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,16 @@ namespace batch_multi_vector {
constexpr auto default_block_size = 256;
constexpr int sm_multiplier = 4;

// clang-format off

// NOTE: DO NOT CHANGE THE ORDERING OF THE INCLUDES
// force-top: on

#include "common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc"
// force-top: off


#include "common/cuda_hip/base/batch_multi_vector_kernel_launcher.hpp.inc"

// clang-format on

} // namespace batch_multi_vector
} // namespace cuda
Expand Down
4 changes: 2 additions & 2 deletions cuda/base/batch_struct.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ inline gko::batch_multi_vector::uniform_batch<const cuda_type<ValueType>>
get_batch_struct(const BatchMultiVector<ValueType>* const op)
{
return {as_cuda_type(op->get_const_values()), op->get_num_batch_entries(),
op->get_common_size()[1],
static_cast<int>(op->get_common_size()[1]),
static_cast<int>(op->get_common_size()[0]),
static_cast<int>(op->get_common_size()[1])};
}
Expand All @@ -79,7 +79,7 @@ inline gko::batch_multi_vector::uniform_batch<cuda_type<ValueType>>
get_batch_struct(BatchMultiVector<ValueType>* const op)
{
return {as_cuda_type(op->get_values()), op->get_num_batch_entries(),
op->get_common_size()[1],
static_cast<int>(op->get_common_size()[1]),
static_cast<int>(op->get_common_size()[0]),
static_cast<int>(op->get_common_size()[1])};
}
Expand Down
16 changes: 8 additions & 8 deletions dpcpp/base/batch_multi_vector_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ void scale(std::shared_ptr<const DefaultExecutor> exec,

// Launch a kernel that has nbatches blocks, each block has max group size
if (alpha->get_common_size()[1] == 1) {
(exec->get_queue())->submit([&](sycl::handler& cgh) {
exec->get_queue()->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl_nd_range(grid, block), [=](sycl::nd_item<3> item_ct1) {
auto group = item_ct1.get_group();
Expand All @@ -98,7 +98,7 @@ void scale(std::shared_ptr<const DefaultExecutor> exec,
});
});
} else {
(exec->get_queue())->submit([&](sycl::handler& cgh) {
exec->get_queue()->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl_nd_range(grid, block), [=](sycl::nd_item<3> item_ct1) {
auto group = item_ct1.get_group();
Expand Down Expand Up @@ -136,7 +136,7 @@ void add_scaled(std::shared_ptr<const DefaultExecutor> exec,
const auto x_ub = get_batch_struct(x);
const auto y_ub = get_batch_struct(y);
if (alpha->get_common_size()[1] == 1) {
(exec->get_queue())->submit([&](sycl::handler& cgh) {
exec->get_queue()->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl_nd_range(grid, block), [=](sycl::nd_item<3> item_ct1) {
auto group = item_ct1.get_group();
Expand All @@ -149,7 +149,7 @@ void add_scaled(std::shared_ptr<const DefaultExecutor> exec,
});
});
} else {
(exec->get_queue())->submit([&](sycl::handler& cgh) {
exec->get_queue()->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl_nd_range(grid, block), [=](sycl::nd_item<3> item_ct1) {
auto group = item_ct1.get_group();
Expand Down Expand Up @@ -187,7 +187,7 @@ void compute_dot(std::shared_ptr<const DefaultExecutor> exec,
const dim3 grid(num_batches);

// TODO: Remove reqd_sub_group size and use sycl::reduce_over_group
(exec->get_queue())->submit([&](sycl::handler& cgh) {
exec->get_queue()->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl_nd_range(grid, block), [=
](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(
Expand Down Expand Up @@ -225,7 +225,7 @@ void compute_conj_dot(std::shared_ptr<const DefaultExecutor> exec,
const dim3 block(group_size);
const dim3 grid(num_batches);

(exec->get_queue())->submit([&](sycl::handler& cgh) {
exec->get_queue()->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl_nd_range(grid, block), [=
](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(
Expand Down Expand Up @@ -262,7 +262,7 @@ void compute_norm2(std::shared_ptr<const DefaultExecutor> exec,
const dim3 block(group_size);
const dim3 grid(num_batches);

(exec->get_queue())->submit([&](sycl::handler& cgh) {
exec->get_queue()->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl_nd_range(grid, block), [=
](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(
Expand Down Expand Up @@ -296,7 +296,7 @@ void copy(std::shared_ptr<const DefaultExecutor> exec,
const dim3 block(group_size);
const dim3 grid(num_batches);

(exec->get_queue())->submit([&](sycl::handler& cgh) {
exec->get_queue()->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl_nd_range(grid, block), [=](sycl::nd_item<3> item_ct1) {
auto group = item_ct1.get_group();
Expand Down
4 changes: 2 additions & 2 deletions dpcpp/base/batch_struct.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ inline gko::batch_multi_vector::uniform_batch<const ValueType> get_batch_struct(
const BatchMultiVector<ValueType>* const op)
{
return {op->get_const_values(), op->get_num_batch_entries(),
op->get_common_size()[1],
static_cast<int>(op->get_common_size()[1]),
static_cast<int>(op->get_common_size()[0]),
static_cast<int>(op->get_common_size()[1])};
}
Expand All @@ -79,7 +79,7 @@ inline gko::batch_multi_vector::uniform_batch<ValueType> get_batch_struct(
BatchMultiVector<ValueType>* const op)
{
return {op->get_values(), op->get_num_batch_entries(),
op->get_common_size()[1],
static_cast<int>(op->get_common_size()[1]),
static_cast<int>(op->get_common_size()[0]),
static_cast<int>(op->get_common_size()[1])};
}
Expand Down
Loading

0 comments on commit 62e84a2

Please sign in to comment.