From b01dbebabd6fdbc29f81f24c5b43f054adc121e4 Mon Sep 17 00:00:00 2001 From: "Aart J.C. Bik" Date: Tue, 4 Feb 2025 16:00:23 -0800 Subject: [PATCH 1/2] Cleanup to use pass-by-reference more consistently --- include/matx/operators/solve.h | 12 +++--- include/matx/operators/sparse2dense.h | 6 +-- .../convert/sparse2dense_cusparse.h | 19 ++------- .../matx/transforms/matmul/matmul_cusparse.h | 21 ++-------- include/matx/transforms/solve/solve_cudss.h | 41 ++++++------------- 5 files changed, 27 insertions(+), 72 deletions(-) diff --git a/include/matx/operators/solve.h b/include/matx/operators/solve.h index 07423809..3b5a2ef8 100644 --- a/include/matx/operators/solve.h +++ b/include/matx/operators/solve.h @@ -56,7 +56,6 @@ class SolveOp : public BaseOp> { using matxop = bool; using matx_transform_op = bool; using solve_xform_op = bool; - using value_type = typename OpA::value_type; __MATX_INLINE__ SolveOp(const OpA &a, const OpB &b) : a_(a), b_(b) { for (int r = 0, rank = Rank(); r < rank; r++) { @@ -102,10 +101,11 @@ class SolveOp : public BaseOp> { } template - __MATX_INLINE__ void InnerPreRun([[maybe_unused]] ShapeType &&shape, - [[maybe_unused]] Executor &&ex) const noexcept { + __MATX_INLINE__ void + InnerPreRun([[maybe_unused]] ShapeType &&shape, + [[maybe_unused]] Executor &&ex) const noexcept { static_assert(is_sparse_tensor_v, - "Direct solver currently only supports sparse system"); + "Direct solver currently only supports sparse system"); if constexpr (is_matx_op()) { b_.PreRun(std::forward(shape), std::forward(ex)); } @@ -122,9 +122,9 @@ class SolveOp : public BaseOp> { template __MATX_INLINE__ void PostRun([[maybe_unused]] ShapeType &&shape, - [[maybe_unused]]Executor &&ex) const noexcept { + [[maybe_unused]] Executor &&ex) const noexcept { static_assert(is_sparse_tensor_v, - "Direct solver currently only supports sparse system"); + "Direct solver currently only supports sparse system"); if constexpr (is_matx_op()) { b_.PostRun(std::forward(shape), std::forward(ex)); } diff --git a/include/matx/operators/sparse2dense.h b/include/matx/operators/sparse2dense.h index 55d3c751..4231e4d7 100644 --- a/include/matx/operators/sparse2dense.h +++ b/include/matx/operators/sparse2dense.h @@ -53,7 +53,6 @@ class Sparse2DenseOp : public BaseOp> { using matxop = bool; using matx_transform_op = bool; using sparse2dense_xform_op = bool; - using value_type = typename OpA::value_type; __MATX_INLINE__ Sparse2DenseOp(const OpA &a) : a_(a) { for (int r = 0; r < Rank(); r++) { @@ -86,13 +85,12 @@ class Sparse2DenseOp : public BaseOp> { template void Exec([[maybe_unused]] Out &&out, [[maybe_unused]] Executor &&ex) const { if constexpr (is_sparse_tensor_v) { - auto ref = cuda::std::get<0>(out); - using Rtype = decltype(ref); + using Rtype = decltype(cuda::std::get<0>(out)); if constexpr (is_sparse_tensor_v) { MATX_THROW(matxNotSupported, "Cannot use sparse2dense for sparse output"); } else { - sparse2dense_impl(ref, a_, ex); + sparse2dense_impl(cuda::std::get<0>(out), a_, ex); } } else { MATX_THROW(matxNotSupported, "Cannot use sparse2dense on dense input"); diff --git a/include/matx/transforms/convert/sparse2dense_cusparse.h b/include/matx/transforms/convert/sparse2dense_cusparse.h index 8373b0f1..24d95dbb 100644 --- a/include/matx/transforms/convert/sparse2dense_cusparse.h +++ b/include/matx/transforms/convert/sparse2dense_cusparse.h @@ -215,33 +215,20 @@ using sparse2dense_cache_t = } // end namespace detail -template -__MATX_INLINE__ auto getSparse2DenseSupportedTensor(const Op &in, - cudaStream_t stream) { - const auto support_func = [&in]() { return true; }; - return GetSupportedTensor(in, support_func, MATX_ASYNC_DEVICE_MEMORY, stream); -} - template -void sparse2dense_impl(OutputTensorType O, const InputTensorType A, +void sparse2dense_impl(OutputTensorType &o, const InputTensorType &a, const cudaExecutor &exec) { MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) const auto stream = exec.getStream(); - auto a = A; // always sparse - auto o = getSparse2DenseSupportedTensor(O, stream); - // TODO: some more checking, supported type? on device? etc. - using atype = decltype(a); - using otype = decltype(o); - // Get parameters required by these tensors (for caching). auto params = - detail::Sparse2DenseHandle_t::GetConvParams(o, a, stream); + detail::Sparse2DenseHandle_t::GetConvParams(o, a, stream); // Lookup and cache. - using cache_val_type = detail::Sparse2DenseHandle_t; + using cache_val_type = detail::Sparse2DenseHandle_t; detail::GetCache().LookupAndExec( detail::GetCacheIdFromType(), params, [&]() { return std::make_shared(o, a, stream); }, diff --git a/include/matx/transforms/matmul/matmul_cusparse.h b/include/matx/transforms/matmul/matmul_cusparse.h index 075204fe..5fa06fbf 100644 --- a/include/matx/transforms/matmul/matmul_cusparse.h +++ b/include/matx/transforms/matmul/matmul_cusparse.h @@ -254,37 +254,22 @@ using gemm_cusparse_cache_t = } // end namespace detail -template -__MATX_INLINE__ auto getCUSPARSESupportedTensor(const Op &in, - cudaStream_t stream) { - const auto support_func = [&in]() { return true; }; - return GetSupportedTensor(in, support_func, MATX_ASYNC_DEVICE_MEMORY, stream); -} - template -void sparse_matmul_impl(TensorTypeC C, const TensorTypeA A, const TensorTypeB B, +void sparse_matmul_impl(TensorTypeC &c, const TensorTypeA &a, const TensorTypeB &b, const cudaExecutor &exec, float alpha = 1.0, float beta = 0.0) { MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) const auto stream = exec.getStream(); - auto a = A; // always sparse - auto b = getCUSPARSESupportedTensor(B, stream); - auto c = getCUSPARSESupportedTensor(C, stream); - // TODO: some more checking, supported type? on device? etc. - using atype = decltype(a); - using btype = decltype(b); - using ctype = decltype(c); - // Get parameters required by these tensors (for caching). auto params = - detail::MatMulCUSPARSEHandle_t::GetGemmParams( + detail::MatMulCUSPARSEHandle_t::GetGemmParams( c, a, b, stream, alpha, beta); // Lookup and cache. - using cache_val_type = detail::MatMulCUSPARSEHandle_t; + using cache_val_type = detail::MatMulCUSPARSEHandle_t; detail::GetCache().LookupAndExec( detail::GetCacheIdFromType(), params, [&]() { diff --git a/include/matx/transforms/solve/solve_cudss.h b/include/matx/transforms/solve/solve_cudss.h index 8299b36e..423376e3 100644 --- a/include/matx/transforms/solve/solve_cudss.h +++ b/include/matx/transforms/solve/solve_cudss.h @@ -238,35 +238,20 @@ using gemm_cudss_cache_t = } // end namespace detail -template -__MATX_INLINE__ auto getCUDSSSupportedTensor(const Op &in, - cudaStream_t stream) { - const auto support_func = [&in]() { return true; }; - return GetSupportedTensor(in, support_func, MATX_ASYNC_DEVICE_MEMORY, stream); -} - template -void sparse_solve_impl_trans(TensorTypeC C, const TensorTypeA A, - const TensorTypeB B, const cudaExecutor &exec) { +void sparse_solve_impl_trans(TensorTypeC &c, const TensorTypeA &a, + const TensorTypeB &b, const cudaExecutor &exec) { MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) const auto stream = exec.getStream(); - auto a = A; // always sparse - auto b = getCUDSSSupportedTensor(B, stream); - auto c = getCUDSSSupportedTensor(C, stream); - // TODO: some more checking, supported type? on device? etc. - using atype = decltype(a); - using btype = decltype(b); - using ctype = decltype(c); - // Get parameters required by these tensors (for caching). - auto params = detail::SolveCUDSSHandle_t::GetSolveParams( + auto params = detail::SolveCUDSSHandle_t::GetSolveParams( c, a, b, stream); // Lookup and cache. - using cache_val_type = detail::SolveCUDSSHandle_t; + using cache_val_type = detail::SolveCUDSSHandle_t; detail::GetCache().LookupAndExec( detail::GetCacheIdFromType(), params, [&]() { return std::make_shared(c, a, b, stream); }, @@ -282,8 +267,8 @@ void sparse_solve_impl_trans(TensorTypeC C, const TensorTypeA A, // supports MATX native row-major storage, which will clean up the copies from // and to memory. template -void sparse_solve_impl(TensorTypeC C, const TensorTypeA A, const TensorTypeB B, - const cudaExecutor &exec) { +void sparse_solve_impl(TensorTypeC &c, const TensorTypeA &a, + const TensorTypeB &b, const cudaExecutor &exec) { const auto stream = exec.getStream(); // Some copying-in hacks, assumes rank 2. @@ -291,20 +276,20 @@ void sparse_solve_impl(TensorTypeC C, const TensorTypeA A, const TensorTypeB B, using TC = typename TensorTypeB::value_type; TB *bptr; matxAlloc(reinterpret_cast(&bptr), - sizeof(TB) * B.Size(0) * B.Size(1), MATX_ASYNC_DEVICE_MEMORY, + sizeof(TB) * b.Size(0) * b.Size(1), MATX_ASYNC_DEVICE_MEMORY, stream); - auto bT = make_tensor(bptr, {B.Size(1), B.Size(0)}); - (bT = transpose(B)).run(exec); + auto bT = make_tensor(bptr, {b.Size(1), b.Size(0)}); + (bT = transpose(b)).run(exec); TC *cptr; matxAlloc(reinterpret_cast(&cptr), - sizeof(TC) * C.Size(0) * C.Size(1), MATX_ASYNC_DEVICE_MEMORY, + sizeof(TC) * c.Size(0) * c.Size(1), MATX_ASYNC_DEVICE_MEMORY, stream); - auto cT = make_tensor(cptr, {C.Size(1), C.Size(0)}); + auto cT = make_tensor(cptr, {c.Size(1), c.Size(0)}); - sparse_solve_impl_trans(cT, A, bT, exec); + sparse_solve_impl_trans(cT, a, bT, exec); // Some copying-back hacks. - (C = transpose(cT)).run(exec); + (c = transpose(cT)).run(exec); } } // end namespace matx From b2c424a39f43ccf107c47cbea695e422c96086d1 Mon Sep 17 00:00:00 2001 From: "Aart J.C. Bik" Date: Tue, 4 Feb 2025 16:14:01 -0800 Subject: [PATCH 2/2] put value_type back --- include/matx/operators/solve.h | 1 + include/matx/operators/sparse2dense.h | 1 + 2 files changed, 2 insertions(+) diff --git a/include/matx/operators/solve.h b/include/matx/operators/solve.h index 3b5a2ef8..2fc62181 100644 --- a/include/matx/operators/solve.h +++ b/include/matx/operators/solve.h @@ -56,6 +56,7 @@ class SolveOp : public BaseOp> { using matxop = bool; using matx_transform_op = bool; using solve_xform_op = bool; + using value_type = typename OpA::value_type; __MATX_INLINE__ SolveOp(const OpA &a, const OpB &b) : a_(a), b_(b) { for (int r = 0, rank = Rank(); r < rank; r++) { diff --git a/include/matx/operators/sparse2dense.h b/include/matx/operators/sparse2dense.h index 4231e4d7..fe8acef2 100644 --- a/include/matx/operators/sparse2dense.h +++ b/include/matx/operators/sparse2dense.h @@ -53,6 +53,7 @@ class Sparse2DenseOp : public BaseOp> { using matxop = bool; using matx_transform_op = bool; using sparse2dense_xform_op = bool; + using value_type = typename OpA::value_type; __MATX_INLINE__ Sparse2DenseOp(const OpA &a) : a_(a) { for (int r = 0; r < Rank(); r++) {