Cleanup to use pass-by-reference more consistently #861

Merged: 2 commits, Feb 5, 2025
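The cleanup below switches the sparse transform entry points from pass-by-value to pass-by-reference parameters and drops the per-call GetSupportedTensor wrappers and decltype aliases. A minimal sketch of the before/after shape of these functions, using placeholder names (impl_by_value, impl_by_reference) and a generic tensor type rather than the actual MatX code:

// Illustrative sketch only: OutT/InT stand in for MatX tensor view types;
// these are not the actual MatX functions.
template <typename OutT, typename InT>
void impl_by_value(OutT O, const InT A) {   // before: each call copies the view objects
  auto a = A;                               // extra local alias of the input
  auto o = O;                               // wrapper conversion of the output
  using atype = decltype(a);                // types re-derived from the locals
  using otype = decltype(o);
  // ... build handles keyed on <otype, atype>, then run on (o, a) ...
}

template <typename OutT, typename InT>
void impl_by_reference(OutT &o, const InT &a) {  // after: binds the caller's tensors
  // OutT and InT already name the tensor types, so handles can be keyed
  // on the template parameters directly and (o, a) are used as-is.
}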
11 changes: 6 additions & 5 deletions include/matx/operators/solve.h
@@ -102,10 +102,11 @@ class SolveOp : public BaseOp<SolveOp<OpA, OpB>> {
   }

   template <typename ShapeType, typename Executor>
-  __MATX_INLINE__ void InnerPreRun([[maybe_unused]] ShapeType &&shape,
-                                   [[maybe_unused]] Executor &&ex) const noexcept {
+  __MATX_INLINE__ void
+  InnerPreRun([[maybe_unused]] ShapeType &&shape,
+              [[maybe_unused]] Executor &&ex) const noexcept {
     static_assert(is_sparse_tensor_v<OpA>,
-        "Direct solver currently only supports sparse system");
+                  "Direct solver currently only supports sparse system");
     if constexpr (is_matx_op<OpB>()) {
       b_.PreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
     }
@@ -122,9 +123,9 @@ class SolveOp : public BaseOp<SolveOp<OpA, OpB>> {

   template <typename ShapeType, typename Executor>
   __MATX_INLINE__ void PostRun([[maybe_unused]] ShapeType &&shape,
-                               [[maybe_unused]]Executor &&ex) const noexcept {
+                               [[maybe_unused]] Executor &&ex) const noexcept {
     static_assert(is_sparse_tensor_v<OpA>,
-        "Direct solver currently only supports sparse system");
+                  "Direct solver currently only supports sparse system");
     if constexpr (is_matx_op<OpB>()) {
       b_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
     }
5 changes: 2 additions & 3 deletions include/matx/operators/sparse2dense.h
@@ -86,13 +86,12 @@ class Sparse2DenseOp : public BaseOp<Sparse2DenseOp<OpA>> {
   template <typename Out, typename Executor>
   void Exec([[maybe_unused]] Out &&out, [[maybe_unused]] Executor &&ex) const {
     if constexpr (is_sparse_tensor_v<OpA>) {
-      auto ref = cuda::std::get<0>(out);
-      using Rtype = decltype(ref);
+      using Rtype = decltype(cuda::std::get<0>(out));
       if constexpr (is_sparse_tensor_v<Rtype>) {
         MATX_THROW(matxNotSupported,
                    "Cannot use sparse2dense for sparse output");
       } else {
-        sparse2dense_impl(ref, a_, ex);
+        sparse2dense_impl(cuda::std::get<0>(out), a_, ex);
       }
     } else {
       MATX_THROW(matxNotSupported, "Cannot use sparse2dense on dense input");
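The operator change above stops binding the output to a local `auto ref` (which deduces a copy of the tuple element) and instead takes `decltype` of, and forwards, `cuda::std::get<0>(out)` directly, so the element binds to the new reference parameter of `sparse2dense_impl`. A small standard-C++ analogy of the deduction difference, using std::tuple rather than the MatX/CCCL types:

#include <tuple>
#include <type_traits>

int main() {
  std::tuple<int> out{42};

  auto ref = std::get<0>(out);            // 'auto' drops the reference: ref is a copy (int)
  using R1 = decltype(ref);               // R1 == int
  using R2 = decltype(std::get<0>(out));  // R2 == int& (the element's reference type)

  static_assert(std::is_same_v<R1, int>);
  static_assert(std::is_same_v<R2, int &>);

  // Passing std::get<0>(out) directly lets a function taking 'int &'
  // write through to the tuple element instead of a detached copy.
  return 0;
}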
19 changes: 3 additions & 16 deletions include/matx/transforms/convert/sparse2dense_cusparse.h
@@ -215,33 +215,20 @@ using sparse2dense_cache_t =

 } // end namespace detail

-template <typename Op>
-__MATX_INLINE__ auto getSparse2DenseSupportedTensor(const Op &in,
-                                                    cudaStream_t stream) {
-  const auto support_func = [&in]() { return true; };
-  return GetSupportedTensor(in, support_func, MATX_ASYNC_DEVICE_MEMORY, stream);
-}
-
 template <typename OutputTensorType, typename InputTensorType>
-void sparse2dense_impl(OutputTensorType O, const InputTensorType A,
+void sparse2dense_impl(OutputTensorType &o, const InputTensorType &a,
                        const cudaExecutor &exec) {
   MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
   const auto stream = exec.getStream();

-  auto a = A; // always sparse
-  auto o = getSparse2DenseSupportedTensor(O, stream);
-
   // TODO: some more checking, supported type? on device? etc.

-  using atype = decltype(a);
-  using otype = decltype(o);
-
   // Get parameters required by these tensors (for caching).
   auto params =
-      detail::Sparse2DenseHandle_t<otype, atype>::GetConvParams(o, a, stream);
+      detail::Sparse2DenseHandle_t<OutputTensorType, InputTensorType>::GetConvParams(o, a, stream);

   // Lookup and cache.
-  using cache_val_type = detail::Sparse2DenseHandle_t<otype, atype>;
+  using cache_val_type = detail::Sparse2DenseHandle_t<OutputTensorType, InputTensorType>;
   detail::GetCache().LookupAndExec<detail::sparse2dense_cache_t>(
       detail::GetCacheIdFromType<detail::sparse2dense_cache_t>(), params,
       [&]() { return std::make_shared<cache_val_type>(o, a, stream); },
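With `o` and `a` now declared as references, `decltype(o)` would name `OutputTensorType &` rather than `OutputTensorType`, which is presumably why the handle and cache types above are now spelled directly in terms of the template parameters. A short standard-C++ sketch of that distinction (the `handle_key` alias is illustrative only, not a MatX name):

#include <type_traits>

template <typename T>
void f(T &o) {
  // With a reference parameter, decltype(o) is T&, not T.
  static_assert(std::is_same_v<decltype(o), T &>);
  // Keying caches/handles off the template parameter itself (as the diff does
  // with OutputTensorType/InputTensorType) keeps the reference qualifier out
  // of those types.
  using handle_key = T;
  static_assert(!std::is_reference_v<handle_key>);
}

int main() { int x = 0; f(x); return 0; }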
21 changes: 3 additions & 18 deletions include/matx/transforms/matmul/matmul_cusparse.h
@@ -254,37 +254,22 @@ using gemm_cusparse_cache_t =

 } // end namespace detail

-template <typename Op>
-__MATX_INLINE__ auto getCUSPARSESupportedTensor(const Op &in,
-                                                cudaStream_t stream) {
-  const auto support_func = [&in]() { return true; };
-  return GetSupportedTensor(in, support_func, MATX_ASYNC_DEVICE_MEMORY, stream);
-}
-
 template <typename TensorTypeC, typename TensorTypeA, typename TensorTypeB>
-void sparse_matmul_impl(TensorTypeC C, const TensorTypeA A, const TensorTypeB B,
+void sparse_matmul_impl(TensorTypeC &c, const TensorTypeA &a, const TensorTypeB &b,
                         const cudaExecutor &exec, float alpha = 1.0,
                         float beta = 0.0) {
   MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
   const auto stream = exec.getStream();

-  auto a = A; // always sparse
-  auto b = getCUSPARSESupportedTensor(B, stream);
-  auto c = getCUSPARSESupportedTensor(C, stream);
-
   // TODO: some more checking, supported type? on device? etc.

-  using atype = decltype(a);
-  using btype = decltype(b);
-  using ctype = decltype(c);
-
   // Get parameters required by these tensors (for caching).
   auto params =
-      detail::MatMulCUSPARSEHandle_t<ctype, atype, btype>::GetGemmParams(
+      detail::MatMulCUSPARSEHandle_t<TensorTypeC, TensorTypeA, TensorTypeB>::GetGemmParams(
           c, a, b, stream, alpha, beta);

   // Lookup and cache.
-  using cache_val_type = detail::MatMulCUSPARSEHandle_t<ctype, atype, btype>;
+  using cache_val_type = detail::MatMulCUSPARSEHandle_t<TensorTypeC, TensorTypeA, TensorTypeB>;
   detail::GetCache().LookupAndExec<detail::gemm_cusparse_cache_t>(
       detail::GetCacheIdFromType<detail::gemm_cusparse_cache_t>(), params,
       [&]() {
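The matmul entry point keeps its `alpha`/`beta` defaults; assuming it follows the usual cuSPARSE SpMM convention, the computed result is c = alpha * (a * b) + beta * c. A small host-side reference implementation of that update for dense row-major arrays, purely as a semantic illustration (gemm_reference is not a MatX function):

#include <cstddef>

// Reference (dense, host-side) semantics of the operation the cuSPARSE-backed
// path is assumed to compute: C = alpha * A * B + beta * C, with row-major
// m x k, k x n, and m x n arrays.
void gemm_reference(const float *A, const float *B, float *C,
                    std::size_t m, std::size_t k, std::size_t n,
                    float alpha = 1.0f, float beta = 0.0f) {
  for (std::size_t i = 0; i < m; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (std::size_t p = 0; p < k; ++p) {
        acc += A[i * k + p] * B[p * n + j];
      }
      C[i * n + j] = alpha * acc + beta * C[i * n + j];
    }
  }
}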
41 changes: 13 additions & 28 deletions include/matx/transforms/solve/solve_cudss.h
@@ -238,35 +238,20 @@ using gemm_cudss_cache_t =

 } // end namespace detail

-template <typename Op>
-__MATX_INLINE__ auto getCUDSSSupportedTensor(const Op &in,
-                                             cudaStream_t stream) {
-  const auto support_func = [&in]() { return true; };
-  return GetSupportedTensor(in, support_func, MATX_ASYNC_DEVICE_MEMORY, stream);
-}
-
 template <typename TensorTypeC, typename TensorTypeA, typename TensorTypeB>
-void sparse_solve_impl_trans(TensorTypeC C, const TensorTypeA A,
-                             const TensorTypeB B, const cudaExecutor &exec) {
+void sparse_solve_impl_trans(TensorTypeC &c, const TensorTypeA &a,
+                             const TensorTypeB &b, const cudaExecutor &exec) {
   MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
   const auto stream = exec.getStream();

-  auto a = A; // always sparse
-  auto b = getCUDSSSupportedTensor(B, stream);
-  auto c = getCUDSSSupportedTensor(C, stream);
-
   // TODO: some more checking, supported type? on device? etc.

-  using atype = decltype(a);
-  using btype = decltype(b);
-  using ctype = decltype(c);
-
   // Get parameters required by these tensors (for caching).
-  auto params = detail::SolveCUDSSHandle_t<ctype, atype, btype>::GetSolveParams(
+  auto params = detail::SolveCUDSSHandle_t<TensorTypeC, TensorTypeA, TensorTypeB>::GetSolveParams(
       c, a, b, stream);

   // Lookup and cache.
-  using cache_val_type = detail::SolveCUDSSHandle_t<ctype, atype, btype>;
+  using cache_val_type = detail::SolveCUDSSHandle_t<TensorTypeC, TensorTypeA, TensorTypeB>;
   detail::GetCache().LookupAndExec<detail::gemm_cudss_cache_t>(
       detail::GetCacheIdFromType<detail::gemm_cudss_cache_t>(), params,
       [&]() { return std::make_shared<cache_val_type>(c, a, b, stream); },
@@ -282,29 +282,29 @@ void sparse_solve_impl_trans(TensorTypeC C, const TensorTypeA A,
 // supports MATX native row-major storage, which will clean up the copies from
 // and to memory.
 template <typename TensorTypeC, typename TensorTypeA, typename TensorTypeB>
-void sparse_solve_impl(TensorTypeC C, const TensorTypeA A, const TensorTypeB B,
-                       const cudaExecutor &exec) {
+void sparse_solve_impl(TensorTypeC &c, const TensorTypeA &a,
+                       const TensorTypeB &b, const cudaExecutor &exec) {
   const auto stream = exec.getStream();

   // Some copying-in hacks, assumes rank 2.
   using TB = typename TensorTypeB::value_type;
   using TC = typename TensorTypeB::value_type;
   TB *bptr;
   matxAlloc(reinterpret_cast<void **>(&bptr),
-            sizeof(TB) * B.Size(0) * B.Size(1), MATX_ASYNC_DEVICE_MEMORY,
+            sizeof(TB) * b.Size(0) * b.Size(1), MATX_ASYNC_DEVICE_MEMORY,
             stream);
-  auto bT = make_tensor(bptr, {B.Size(1), B.Size(0)});
-  (bT = transpose(B)).run(exec);
+  auto bT = make_tensor(bptr, {b.Size(1), b.Size(0)});
+  (bT = transpose(b)).run(exec);
   TC *cptr;
   matxAlloc(reinterpret_cast<void **>(&cptr),
-            sizeof(TC) * C.Size(0) * C.Size(1), MATX_ASYNC_DEVICE_MEMORY,
+            sizeof(TC) * c.Size(0) * c.Size(1), MATX_ASYNC_DEVICE_MEMORY,
             stream);
-  auto cT = make_tensor(cptr, {C.Size(1), C.Size(0)});
+  auto cT = make_tensor(cptr, {c.Size(1), c.Size(0)});

-  sparse_solve_impl_trans(cT, A, bT, exec);
+  sparse_solve_impl_trans(cT, a, bT, exec);

   // Some copying-back hacks.
-  (C = transpose(cT)).run(exec);
+  (c = transpose(cT)).run(exec);
 }

 } // end namespace matx
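The copy-in/copy-out block in `sparse_solve_impl` allocates a buffer with the two dimensions swapped, writes `transpose(b)` into it, solves against that, and transposes back into `c`, presumably because this cuDSS path cannot yet consume MatX's native row-major layout directly (see the comment about row-major storage above). A minimal host-side sketch of why the transpose trick works: a row-major transpose occupies memory exactly like the original matrix stored column-major.

#include <cassert>
#include <cstddef>
#include <vector>

// Host-side illustration only; the PR's code does the same on device with
// matxAlloc/make_tensor/transpose.
int main() {
  const std::size_t m = 2, n = 3;
  std::vector<float> B = {1, 2, 3,
                          4, 5, 6};          // row-major m x n
  std::vector<float> BT(n * m);
  for (std::size_t i = 0; i < m; ++i)
    for (std::size_t j = 0; j < n; ++j)
      BT[j * m + i] = B[i * n + j];          // row-major n x m transpose

  // BT's flat storage equals B interpreted column-major: {1,4,2,5,3,6}.
  assert(BT[0] == 1 && BT[1] == 4 && BT[2] == 2);
  return 0;
}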