Skip to content

Commit

Permalink
unify cuda/hip batch_mvec
Browse files Browse the repository at this point in the history
  • Loading branch information
pratikvn committed Aug 19, 2024
1 parent 9f1c41b commit c960038
Show file tree
Hide file tree
Showing 8 changed files with 390 additions and 161 deletions.
1 change: 1 addition & 0 deletions common/cuda_hip/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
include(${PROJECT_SOURCE_DIR}/cmake/template_instantiation.cmake)
set(CUDA_HIP_SOURCES
base/batch_multi_vector_kernels.cpp
base/device_matrix_data_kernels.cpp
base/index_set_kernels.cpp
components/prefix_sum_kernels.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,32 @@
//
// SPDX-License-Identifier: BSD-3-Clause

#include "common/cuda_hip/base/batch_multi_vector_kernels.hpp"

#include <thrust/functional.h>
#include <thrust/transform.h>

#include <ginkgo/core/base/batch_multi_vector.hpp>
#include <ginkgo/core/base/exception_helpers.hpp>
#include <ginkgo/core/base/math.hpp>
#include <ginkgo/core/base/types.hpp>

#include "common/cuda_hip/base/config.hpp"
#include "common/cuda_hip/base/math.hpp"
#include "common/cuda_hip/base/runtime.hpp"
#include "core/base/batch_multi_vector_kernels.hpp"
#include "core/base/batch_struct.hpp"


namespace gko {
namespace kernels {
namespace GKO_DEVICE_NAMESPACE {
namespace batch_multi_vector {


constexpr auto default_block_size = 256;


template <typename ValueType>
void scale(std::shared_ptr<const DefaultExecutor> exec,
const batch::MultiVector<ValueType>* const alpha,
Expand All @@ -11,16 +37,19 @@ void scale(std::shared_ptr<const DefaultExecutor> exec,
const auto alpha_ub = get_batch_struct(alpha);
const auto x_ub = get_batch_struct(x);
if (alpha->get_common_size()[1] == 1) {
scale_kernel<<<num_blocks, default_block_size, 0, exec->get_stream()>>>(
batch_single_kernels::scale_kernel<<<num_blocks, default_block_size, 0,
exec->get_stream()>>>(
alpha_ub, x_ub,
[] __device__(int row, int col, int stride) { return 0; });
} else if (alpha->get_common_size() == x->get_common_size()) {
scale_kernel<<<num_blocks, default_block_size, 0, exec->get_stream()>>>(
batch_single_kernels::scale_kernel<<<num_blocks, default_block_size, 0,
exec->get_stream()>>>(
alpha_ub, x_ub, [] __device__(int row, int col, int stride) {
return row * stride + col;
});
} else {
scale_kernel<<<num_blocks, default_block_size, 0, exec->get_stream()>>>(
batch_single_kernels::scale_kernel<<<num_blocks, default_block_size, 0,
exec->get_stream()>>>(
alpha_ub, x_ub,
[] __device__(int row, int col, int stride) { return col; });
}
Expand All @@ -42,12 +71,12 @@ void add_scaled(std::shared_ptr<const DefaultExecutor> exec,
const auto x_ub = get_batch_struct(x);
const auto y_ub = get_batch_struct(y);
if (alpha->get_common_size()[1] == 1) {
add_scaled_kernel<<<num_blocks, default_block_size, 0,
exec->get_stream()>>>(
batch_single_kernels::add_scaled_kernel<<<
num_blocks, default_block_size, 0, exec->get_stream()>>>(
alpha_ub, x_ub, y_ub, [] __device__(int col) { return 0; });
} else {
add_scaled_kernel<<<num_blocks, default_block_size, 0,
exec->get_stream()>>>(
batch_single_kernels::add_scaled_kernel<<<
num_blocks, default_block_size, 0, exec->get_stream()>>>(
alpha_ub, x_ub, y_ub, [] __device__(int col) { return col; });
}
}
Expand All @@ -67,8 +96,8 @@ void compute_dot(std::shared_ptr<const DefaultExecutor> exec,
const auto x_ub = get_batch_struct(x);
const auto y_ub = get_batch_struct(y);
const auto res_ub = get_batch_struct(result);
compute_gen_dot_product_kernel<<<num_blocks, default_block_size, 0,
exec->get_stream()>>>(
batch_single_kernels::compute_gen_dot_product_kernel<<<
num_blocks, default_block_size, 0, exec->get_stream()>>>(
x_ub, y_ub, res_ub, [] __device__(auto val) { return val; });
}

Expand All @@ -87,8 +116,8 @@ void compute_conj_dot(std::shared_ptr<const DefaultExecutor> exec,
const auto x_ub = get_batch_struct(x);
const auto y_ub = get_batch_struct(y);
const auto res_ub = get_batch_struct(result);
compute_gen_dot_product_kernel<<<num_blocks, default_block_size, 0,
exec->get_stream()>>>(
batch_single_kernels::compute_gen_dot_product_kernel<<<
num_blocks, default_block_size, 0, exec->get_stream()>>>(
x_ub, y_ub, res_ub, [] __device__(auto val) { return conj(val); });
}

Expand All @@ -105,8 +134,9 @@ void compute_norm2(std::shared_ptr<const DefaultExecutor> exec,
const auto num_rhs = x->get_common_size()[1];
const auto x_ub = get_batch_struct(x);
const auto res_ub = get_batch_struct(result);
compute_norm2_kernel<<<num_blocks, default_block_size, 0,
exec->get_stream()>>>(x_ub, res_ub);
batch_single_kernels::compute_norm2_kernel<<<num_blocks, default_block_size,
0, exec->get_stream()>>>(
x_ub, res_ub);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(
Expand All @@ -121,8 +151,15 @@ void copy(std::shared_ptr<const DefaultExecutor> exec,
const auto num_blocks = x->get_num_batch_items();
const auto result_ub = get_batch_struct(result);
const auto x_ub = get_batch_struct(x);
copy_kernel<<<num_blocks, default_block_size, 0, exec->get_stream()>>>(
x_ub, result_ub);
batch_single_kernels::
copy_kernel<<<num_blocks, default_block_size, 0, exec->get_stream()>>>(
x_ub, result_ub);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR_COPY_KERNEL);


} // namespace batch_multi_vector
} // namespace GKO_DEVICE_NAMESPACE
} // namespace kernels
} // namespace gko
Loading

0 comments on commit c960038

Please sign in to comment.