From 9e65c4468cb7544c1fc2adc34d39ee021fcf0cad Mon Sep 17 00:00:00 2001
From: "Aart J.C. Bik"
Date: Mon, 3 Feb 2025 17:05:08 -0800
Subject: [PATCH 1/3] First version of MATX sparse2dense conversion (using
 dispatch to cuSPARSE)

---
 examples/sparse_tensor.cu                      |  25 +-
 include/matx/core/type_utils.h                 |  18 ++
 include/matx/operators/operators.h             |   1 +
 include/matx/operators/sparse2dense.h          | 144 ++++++++++
 .../convert/sparse2dense_cusparse.h            | 253 ++++++++++++++++++
 .../matx/transforms/matmul/matmul_cusparse.h   |  31 +--
 include/matx/transforms/solve/solve_cudss.h    |   2 +-
 7 files changed, 438 insertions(+), 36 deletions(-)
 create mode 100644 include/matx/operators/sparse2dense.h
 create mode 100644 include/matx/transforms/convert/sparse2dense_cusparse.h

diff --git a/examples/sparse_tensor.cu b/examples/sparse_tensor.cu
index b2659991..96caaa6c 100644
--- a/examples/sparse_tensor.cu
+++ b/examples/sparse_tensor.cu
@@ -90,24 +90,33 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
   //
   // A very naive way to convert the sparse matrix back to a dense
   // matrix. Note that one should **never** use the ()-operator in
-  // performance critical code, since sparse data structures do
+  // performance critical code, since sparse storage formats do
   // not provide O(1) random access to their elements (compressed
   // levels will use some form of search to determine if an element
   // is present). Instead, conversions (and other operations) should
-  // use sparse operations that are tailored for the sparse data
-  // structure (such as scanning by row for CSR).
+  // use sparse operations that are tailored for the sparse storage
+  // format (such as scanning by row for CSR).
   //
-  auto A = make_tensor<float>({4, 8});
+  auto A1 = make_tensor<float>({4, 8});
   for (index_t i = 0; i < 4; i++) {
     for (index_t j = 0; j < 8; j++) {
-      A(i, j) = Acoo(i, j);
+      A1(i, j) = Acoo(i, j);
     }
   }
-  print(A);
+  print(A1);
 
   //
-  // SpMM is implemented on COO through cuSPARSE. This is the
-  // correct way of performing an efficient sparse operation.
+  // A direct sparse2dense conversion. This is the correct way of
+  // performing the conversion, since the underlying implementation
+  // knows how to properly manipulate the sparse storage format.
+  //
+  auto A2 = make_tensor<float>({4, 8});
+  (A2 = sparse2dense(Acoo)).run(exec);
+  print(A2);
+
+  //
+  // Perform a direct SpMM. This is also the correct way of performing
+  // an efficient sparse operation.
   //
   auto B = make_tensor<float>({8, 4});
   auto C = make_tensor<float>({4, 4});
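
The cost that the rewritten comment warns about is easy to make concrete for COO: a single Acoo(i, j) lookup must locate the coordinate pair among the stored nonzeros, typically by binary search when coordinates are kept sorted, so each access costs O(log nse) rather than O(1). A minimal host-side sketch of such a lookup (illustrative only, not MatX code; the sorted row-major coordinate layout is an assumption):

    #include <cstdint>
    #include <vector>

    // One random access into a sorted COO matrix: O(log nse), not O(1).
    float coo_at(const std::vector<int64_t> &rows,
                 const std::vector<int64_t> &cols,
                 const std::vector<float> &vals, int64_t i, int64_t j) {
      int64_t lo = 0, hi = static_cast<int64_t>(vals.size()) - 1;
      // Binary search on the (row, col) pairs, sorted row-major.
      while (lo <= hi) {
        int64_t mid = lo + (hi - lo) / 2;
        if (rows[mid] < i || (rows[mid] == i && cols[mid] < j)) {
          lo = mid + 1;
        } else if (rows[mid] > i || (rows[mid] == i && cols[mid] > j)) {
          hi = mid - 1;
        } else {
          return vals[mid];
        }
      }
      return 0.0f; // implicit zero
    }
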
diff --git a/include/matx/core/type_utils.h b/include/matx/core/type_utils.h
index cf164db6..c4022eff 100644
--- a/include/matx/core/type_utils.h
+++ b/include/matx/core/type_utils.h
@@ -38,6 +38,7 @@
 #include
 #include
 #include
+#include <cusparse.h>
 #include
 #include "cuda_fp16.h"
@@ -1166,6 +1167,23 @@ template <typename T> constexpr cublasComputeType_t MatXTypeToCudaComputeType()
   return CUBLAS_COMPUTE_32F;
 }
+
+template <typename T>
+constexpr cusparseIndexType_t MatXTypeToCuSparseIndexType() {
+  if constexpr (std::is_same_v<T, uint16_t>) {
+    return CUSPARSE_INDEX_16U;
+  }
+  if constexpr (std::is_same_v<T, int32_t>) {
+    return CUSPARSE_INDEX_32I;
+  }
+  if constexpr (std::is_same_v<T, int64_t>) {
+    return CUSPARSE_INDEX_64I;
+  }
+  if constexpr (std::is_same_v<T, index_t>) {
+    return CUSPARSE_INDEX_64I;
+  }
+}
+
 } // end namespace detail
 } // end namespace matx
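
Because the mapping just added is constexpr, its behavior can be pinned down at compile time. A small sanity-check sketch (illustrative; it assumes the bracketed template arguments restored in this hunk, which extraction had stripped):

    #include <cusparse.h>
    #include "matx/core/type_utils.h"

    // Each covered index type maps to exactly one cuSPARSE enum value.
    static_assert(matx::detail::MatXTypeToCuSparseIndexType<uint16_t>() ==
                  CUSPARSE_INDEX_16U);
    static_assert(matx::detail::MatXTypeToCuSparseIndexType<int32_t>() ==
                  CUSPARSE_INDEX_32I);
    static_assert(matx::detail::MatXTypeToCuSparseIndexType<int64_t>() ==
                  CUSPARSE_INDEX_64I);
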
diff --git a/include/matx/operators/operators.h b/include/matx/operators/operators.h
index 9d800ecc..a8ed9072 100644
--- a/include/matx/operators/operators.h
+++ b/include/matx/operators/operators.h
@@ -99,6 +99,7 @@
 #include "matx/operators/shift.h"
 #include "matx/operators/sign.h"
 #include "matx/operators/slice.h"
+#include "matx/operators/sparse2dense.h"
 #include "matx/operators/solve.h"
 #include "matx/operators/sort.h"
 #include "matx/operators/sph2cart.h"
diff --git a/include/matx/operators/sparse2dense.h b/include/matx/operators/sparse2dense.h
new file mode 100644
index 00000000..0a72d891
--- /dev/null
+++ b/include/matx/operators/sparse2dense.h
@@ -0,0 +1,144 @@
+////////////////////////////////////////////////////////////////////////////////
+// BSD 3-Clause License
+//
+// Copyright (c) 2025, NVIDIA Corporation
+// All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+//    list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+//    this list of conditions and the following disclaimer in the documentation
+//    and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+//    contributors may be used to endorse or promote products derived from
+//    this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+/////////////////////////////////////////////////////////////////////////////////
+
+#pragma once
+
+#include "matx/core/type_utils.h"
+#include "matx/operators/base_operator.h"
+#include "matx/transforms/convert/sparse2dense_cusparse.h"
+
+namespace matx {
+namespace detail {
+
+template <typename OpA>
+class Sparse2DenseOp : public BaseOp<Sparse2DenseOp<OpA>> {
+private:
+  typename detail::base_type_t<OpA> a_;
+
+  static constexpr int out_rank = OpA::Rank();
+  cuda::std::array<index_t, out_rank> out_dims_;
+  mutable detail::tensor_impl_t<typename OpA::value_type, out_rank> tmp_out_;
+  mutable typename OpA::value_type *ptr = nullptr;
+
+public:
+  using matxop = bool;
+  using matx_transform_op = bool;
+  using sparse2dense_xform_op = bool;
+  using value_type = typename OpA::value_type;
+
+  __MATX_INLINE__ Sparse2DenseOp(const OpA &a) : a_(a) {
+    for (int r = 0; r < Rank(); r++) {
+      out_dims_[r] = a_.Size(r);
+    }
+  }
+
+  __MATX_INLINE__ std::string str() const {
+    return "sparse2dense(" + get_type_str(a_) + ")";
+  }
+
+  __MATX_HOST__ __MATX_INLINE__ auto Data() const noexcept { return ptr; }
+
+  template <typename... Is>
+  __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto)
+  operator()(Is... indices) const {
+    return tmp_out_(indices...);
+  }
+
+  static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t
+  Rank() {
+    return remove_cvref_t<OpA>::Rank();
+  }
+
+  constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t
+  Size(int dim) const {
+    return out_dims_[dim];
+  }
+
+  template <typename Out, typename Executor>
+  void Exec([[maybe_unused]] Out &&out, [[maybe_unused]] Executor &&ex) const {
+    if constexpr (is_sparse_tensor_v<OpA>) {
+      if constexpr (is_sparse_tensor_v<Out>) {
+        MATX_THROW(matxNotSupported,
+                   "Cannot use sparse2dense for sparse output");
+      } else {
+        sparse2dense_impl(cuda::std::get<0>(out), a_, ex);
+      }
+    } else {
+      MATX_THROW(matxNotSupported, "Cannot use sparse2dense on dense input");
+    }
+  }
+
+  template <typename ShapeType, typename Executor>
+  __MATX_INLINE__ void
+  InnerPreRun([[maybe_unused]] ShapeType &&shape,
+              [[maybe_unused]] Executor &&ex) const noexcept {
+    static_assert(is_sparse_tensor_v<OpA>,
+                  "Cannot use sparse2dense on dense input");
+  }
+
+  template <typename ShapeType, typename Executor>
+  __MATX_INLINE__ void PreRun([[maybe_unused]] ShapeType &&shape,
+                              [[maybe_unused]] Executor &&ex) const noexcept {
+    InnerPreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
+    detail::AllocateTempTensor(tmp_out_, std::forward<Executor>(ex), out_dims_,
+                               &ptr);
+    Exec(cuda::std::make_tuple(tmp_out_), std::forward<Executor>(ex));
+  }
+
+  template <typename ShapeType, typename Executor>
+  __MATX_INLINE__ void PostRun([[maybe_unused]] ShapeType &&shape,
+                               [[maybe_unused]] Executor &&ex) const noexcept {
+    static_assert(is_sparse_tensor_v<OpA>,
+                  "Cannot use sparse2dense on dense input");
+    matxFree(ptr);
+  }
+};
+
+} // end namespace detail
+
+/**
+ * Convert a sparse tensor into a dense tensor.
+ *
+ * @tparam OpA
+ *    Data type of A tensor
+ *
+ * @param A
+ *    Sparse input tensor
+ *
+ * @return
+ *    Dense output tensor
+ */
+template <typename OpA> __MATX_INLINE__ auto sparse2dense(const OpA &A) {
+  return detail::Sparse2DenseOp(A);
+}
+
+} // end namespace matx
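
Since Sparse2DenseOp is a full MatX operator (it provides operator(), Rank(), and Size(), and PreRun materializes the conversion into a temporary), the transform should also compose inside larger expressions, not only as a direct assignment target. A hedged usage sketch, reusing the tensor names from the example file (the composed form is assumed to work through the PreRun temporary; it is not exercised by this patch):

    cudaExecutor exec{};
    auto D = make_tensor<float>({4, 8});
    // Direct form, as in examples/sparse_tensor.cu:
    (D = sparse2dense(Acoo)).run(exec);
    // Composed form: the conversion lands in the internal temporary
    // first, and the elementwise addition then reads from it.
    (D = sparse2dense(Acoo) + 1.0f).run(exec);
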
diff --git a/include/matx/transforms/convert/sparse2dense_cusparse.h b/include/matx/transforms/convert/sparse2dense_cusparse.h
new file mode 100644
index 00000000..5ce44161
--- /dev/null
+++ b/include/matx/transforms/convert/sparse2dense_cusparse.h
@@ -0,0 +1,253 @@
+////////////////////////////////////////////////////////////////////////////////
+// BSD 3-Clause License
+//
+// Copyright (c) 2025, NVIDIA Corporation
+// All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+//    list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+//    this list of conditions and the following disclaimer in the documentation
+//    and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+//    contributors may be used to endorse or promote products derived from
+//    this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+/////////////////////////////////////////////////////////////////////////////////
+
+#pragma once
+
+#include <numeric>
+
+#include <cusparse.h>
+
+#include "matx/core/cache.h"
+#include "matx/core/sparse_tensor.h"
+#include "matx/core/tensor.h"
+
+namespace matx {
+
+namespace detail {
+
+/**
+ * Parameters needed to execute a cuSPARSE sparse2dense.
+ */
+struct Sparse2DenseParams_t {
+  MatXDataType_t dtype;
+  MatXDataType_t ptype;
+  MatXDataType_t ctype;
+  int rank;
+  cudaStream_t stream;
+  index_t nse;
+  index_t m;
+  index_t n;
+  // Matrix handles in cuSPARSE are data specific (unlike e.g. cuBLAS
+  // where the same plan can be shared between different data buffers).
+  void *ptrA0;
+  void *ptrA1;
+  void *ptrA2;
+  void *ptrA3;
+  void *ptrA4;
+  void *ptrO;
+};
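+
+// For reference, the handle constructor below consumes these pointers as
+// follows: values are always in ptrA0, and per storage format
+//   COO: row indices    = ptrA3, column indices = ptrA4
+//   CSR: row offsets    = ptrA2, column indices = ptrA4
+//   CSC: column offsets = ptrA2, row indices    = ptrA4
+// (ptrA1 carries the remaining position buffer and only participates in
+// the cache-key comparisons.)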
+
+template <typename TensorTypeO, typename TensorTypeA>
+class Sparse2DenseHandle_t {
+public:
+  using TA = typename TensorTypeA::value_type;
+  using TO = typename TensorTypeO::value_type;
+
+  static constexpr int RANKA = TensorTypeA::Rank();
+  static constexpr int RANKO = TensorTypeO::Rank();
+
+  /**
+   * Construct a sparse2dense handle.
+   */
+  Sparse2DenseHandle_t(TensorTypeO &o, const TensorTypeA &a,
+                       cudaStream_t stream) {
+    MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
+
+    static_assert(RANKA == RANKO);
+
+    params_ = GetConvParams(o, a, stream);
+
+    [[maybe_unused]] cusparseStatus_t ret = cusparseCreate(&handle_);
+    MATX_ASSERT(ret == CUSPARSE_STATUS_SUCCESS, matxCudaError);
+
+    // Create cuSPARSE handle for sparse matrix A.
+    static_assert(is_sparse_tensor_v<TensorTypeA>);
+    cusparseIndexType_t pt =
+        MatXTypeToCuSparseIndexType<typename TensorTypeA::pos_type>();
+    cusparseIndexType_t ct =
+        MatXTypeToCuSparseIndexType<typename TensorTypeA::crd_type>();
+    cusparseIndexBase_t zb = CUSPARSE_INDEX_BASE_ZERO;
+    cudaDataType dta = MatXTypeToCudaType<TA>();
+    if constexpr (TensorTypeA::Format::isCOO()) {
+      ret = cusparseCreateCoo(&matA_, params_.m, params_.n, params_.nse,
+                              params_.ptrA3, params_.ptrA4, params_.ptrA0, ct,
+                              zb, dta);
+    } else if constexpr (TensorTypeA::Format::isCSR()) {
+      ret = cusparseCreateCsr(&matA_, params_.m, params_.n, params_.nse,
+                              params_.ptrA2, params_.ptrA4, params_.ptrA0, pt,
+                              ct, zb, dta);
+    } else if constexpr (TensorTypeA::Format::isCSC()) {
+      ret = cusparseCreateCsc(&matA_, params_.m, params_.n, params_.nse,
+                              params_.ptrA2, params_.ptrA4, params_.ptrA0, pt,
+                              ct, zb, dta);
+    } else {
+      MATX_THROW(matxNotSupported,
+                 "Sparse2Dense currently only supports COO/CSR/CSC");
+    }
+    MATX_ASSERT(ret == CUSPARSE_STATUS_SUCCESS, matxCudaError);
+
+    // Create cuSPARSE handle for dense matrix O.
+    static_assert(is_tensor_view_v<TensorTypeO>);
+    cudaDataType dto = MatXTypeToCudaType<TO>();
+    const cusparseOrder_t order = CUSPARSE_ORDER_ROW;
+    ret = cusparseCreateDnMat(&matO_, params_.m, params_.n, /*ld=*/params_.n,
+                              params_.ptrO, dto, order);
+    MATX_ASSERT(ret == CUSPARSE_STATUS_SUCCESS, matxCudaError);
+
+    // Allocate a workspace for sparse2dense.
+    const cusparseSparseToDenseAlg_t algo = CUSPARSE_SPARSETODENSE_ALG_DEFAULT;
+    ret = cusparseSparseToDense_bufferSize(handle_, matA_, matO_, algo,
+                                           &workspaceSize_);
+    MATX_ASSERT(ret == CUSPARSE_STATUS_SUCCESS, matxCudaError);
+    if (workspaceSize_) {
+      matxAlloc((void **)&workspace_, workspaceSize_, MATX_DEVICE_MEMORY);
+    }
+  }
+
+  ~Sparse2DenseHandle_t() {
+    if (workspaceSize_) {
+      matxFree(workspace_);
+    }
+    cusparseDestroy(handle_);
+  }
+
+  static detail::Sparse2DenseParams_t
+  GetConvParams(TensorTypeO &o, const TensorTypeA &a, cudaStream_t stream) {
+    detail::Sparse2DenseParams_t params;
+    params.dtype = TypeToInt<TA>();
+    params.ptype = TypeToInt<typename TensorTypeA::pos_type>();
+    params.ctype = TypeToInt<typename TensorTypeA::crd_type>();
+    params.rank = a.Rank();
+    params.stream = stream;
+    // TODO: simple no-batch, row-wise, no-transpose for now
+    params.nse = a.Nse();
+    params.m = a.Size(TensorTypeA::Rank() - 2);
+    params.n = a.Size(TensorTypeA::Rank() - 1);
+    // Matrix handles in cuSPARSE are data specific. Therefore, the pointers
+    // to the underlying buffers are part of the conversion parameters.
+    params.ptrA0 = a.Data();
+    params.ptrA1 = a.POSData(0);
+    params.ptrA2 = a.POSData(1);
+    params.ptrA3 = a.CRDData(0);
+    params.ptrA4 = a.CRDData(1);
+    params.ptrO = o.Data();
+    return params;
+  }
+
+  __MATX_INLINE__ void Exec([[maybe_unused]] TensorTypeO &o,
+                            [[maybe_unused]] const TensorTypeA &a) {
+    MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL);
+    const cusparseSparseToDenseAlg_t algo = CUSPARSE_SPARSETODENSE_ALG_DEFAULT;
+    [[maybe_unused]] cusparseStatus_t ret =
+        cusparseSparseToDense(handle_, matA_, matO_, algo, workspace_);
+    MATX_ASSERT(ret == CUSPARSE_STATUS_SUCCESS, matxCudaError);
+  }
+
+private:
+  cusparseHandle_t handle_ = nullptr; // TODO: share handle globally?
+  cusparseSpMatDescr_t matA_ = nullptr;
+  cusparseDnMatDescr_t matO_ = nullptr;
+  size_t workspaceSize_ = 0;
+  void *workspace_ = nullptr;
+  detail::Sparse2DenseParams_t params_;
+};
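+
+// Note that all descriptor creation and workspace sizing is done once in
+// the constructor above, while Exec() only launches cusparseSparseToDense().
+// A cached handle therefore makes repeated conversions on the same buffers
+// and stream cheap, which is exactly what the cache below keys on.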
+
+/**
+ * Crude hash on Sparse2Dense to get a reasonably good delta for collisions.
+ * This doesn't need to be perfect, but fast enough to not slow down lookups,
+ * and different enough so the common conversion parameters change.
+ */
+struct Sparse2DenseParamsKeyHash {
+  std::size_t operator()(const Sparse2DenseParams_t &k) const noexcept {
+    return std::hash<uint64_t>()(reinterpret_cast<uint64_t>(k.ptrA0)) +
+           std::hash<uint64_t>()(reinterpret_cast<uint64_t>(k.ptrO)) +
+           std::hash<uint64_t>()(reinterpret_cast<uint64_t>(k.stream));
+  }
+};
+
+/**
+ * Test Sparse2Dense parameters for equality. Unlike the hash, all parameters
+ * must match exactly to ensure the hashed kernel can be reused for the
+ * computation.
+ */
+struct Sparse2DenseParamsKeyEq {
+  bool operator()(const Sparse2DenseParams_t &l,
+                  const Sparse2DenseParams_t &t) const noexcept {
+    return l.dtype == t.dtype && l.ptype == t.ptype && l.ctype == t.ctype &&
+           l.rank == t.rank && l.stream == t.stream && l.nse == t.nse &&
+           l.m == t.m && l.n == t.n && l.ptrA0 == t.ptrA0 &&
+           l.ptrA1 == t.ptrA1 && l.ptrA2 == t.ptrA2 && l.ptrA3 == t.ptrA3 &&
+           l.ptrA4 == t.ptrA4 && l.ptrO == t.ptrO;
+  }
+};
+
+using sparse2dense_cache_t =
+    std::unordered_map<Sparse2DenseParams_t, std::any,
+                       Sparse2DenseParamsKeyHash, Sparse2DenseParamsKeyEq>;
+
+} // end namespace detail
+
+template <typename Op>
+__MATX_INLINE__ auto getSparse2DenseSupportedTensor(const Op &in,
+                                                    cudaStream_t stream) {
+  const auto support_func = [&in]() { return true; };
+  return GetSupportedTensor(in, support_func, MATX_ASYNC_DEVICE_MEMORY, stream);
+}
+
+template <typename OutputTensorType, typename InputTensorType>
+void sparse2dense_impl(OutputTensorType O, const InputTensorType A,
+                       const cudaExecutor &exec) {
+  MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
+  const auto stream = exec.getStream();
+
+  auto a = A; // always sparse
+  auto o = getSparse2DenseSupportedTensor(O, stream);
+
+  // TODO: some more checking, supported type? on device? etc.
+
+  typedef decltype(o) otype;
+  typedef decltype(a) atype;
+
+  // Get parameters required by these tensors (for caching).
+  auto params =
+      detail::Sparse2DenseHandle_t<otype, atype>::GetConvParams(o, a, stream);
+
+  // Lookup and cache.
+  using cache_val_type = detail::Sparse2DenseHandle_t<otype, atype>;
+  detail::GetCache().LookupAndExec<detail::sparse2dense_cache_t>(
+      detail::GetCacheIdFromType<detail::sparse2dense_cache_t>(), params,
+      [&]() { return std::make_shared<cache_val_type>(o, a, stream); },
+      [&](std::shared_ptr<cache_val_type> cache_type) {
+        cache_type->Exec(o, a);
+      });
+}
+
+} // end namespace matx
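
Taken together, the cache means descriptor setup and workspace sizing are paid only on the first conversion for a given parameter set; later calls go straight to Exec(). A minimal sketch of the call path (names Acoo and D assumed to exist, as in the example file; normally one goes through the sparse2dense operator rather than calling the impl directly):

    cudaExecutor exec{};
    auto D = make_tensor<float>({4, 8});
    // First call constructs and caches a Sparse2DenseHandle_t for these
    // buffers and this stream.
    sparse2dense_impl(D, Acoo, exec);
    // Same buffers, same stream: cache hit, only cusparseSparseToDense() runs.
    sparse2dense_impl(D, Acoo, exec);
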
diff --git a/include/matx/transforms/matmul/matmul_cusparse.h b/include/matx/transforms/matmul/matmul_cusparse.h
index e66c4f23..436ce21a 100644
--- a/include/matx/transforms/matmul/matmul_cusparse.h
+++ b/include/matx/transforms/matmul/matmul_cusparse.h
@@ -44,23 +44,6 @@
 namespace matx {
 
 namespace detail {
 
-// Translate MatXType for indices to cuSPARSE index type.
-template <typename T>
-constexpr cusparseIndexType_t MatXTypeToCuSparseIndexType() {
-  if constexpr (std::is_same_v<T, uint16_t>) {
-    return CUSPARSE_INDEX_16U;
-  }
-  if constexpr (std::is_same_v<T, int32_t>) {
-    return CUSPARSE_INDEX_32I;
-  }
-  if constexpr (std::is_same_v<T, int64_t>) {
-    return CUSPARSE_INDEX_64I;
-  }
-  if constexpr (std::is_same_v<T, index_t>) {
-    return CUSPARSE_INDEX_64I;
-  }
-}
-
 /**
  * Parameters needed to execute a cuSPARSE GEMM.
  */
@@ -151,11 +134,11 @@ class MatMulCUSPARSEHandle_t {
     MATX_ASSERT(ret == CUSPARSE_STATUS_SUCCESS, matxMatMulError);
 
     // Create cuSPARSE handle for dense matrices B and C.
-    static_assert(is_tensor_view_v<TensorTypeB>);
     static_assert(is_tensor_view_v<TensorTypeB>);
+    static_assert(is_tensor_view_v<TensorTypeC>);
     cudaDataType dtb = MatXTypeToCudaType<TB>();
     cudaDataType dtc = MatXTypeToCudaType<TC>();
-    const cusparseOrder_t order = CUSPARSE_ORDER_ROW; // TODO: support col B,C?
+    const cusparseOrder_t order = CUSPARSE_ORDER_ROW;
     ret = cusparseCreateDnMat(&matB_, params_.k, params_.n, /*ld=*/params_.n,
                               params_.ptrB, dtb, order);
     MATX_ASSERT(ret == CUSPARSE_STATUS_SUCCESS, matxMatMulError);
@@ -197,7 +180,7 @@ class MatMulCUSPARSEHandle_t {
     params.nse = a.Nse();
     params.m = a.Size(TensorTypeA::Rank() - 2);
     params.n = b.Size(TensorTypeB::Rank() - 1);
-    params.k = a.Size(TensorTypeB::Rank() - 1);
+    params.k = a.Size(TensorTypeA::Rank() - 1);
     params.opA = CUSPARSE_OPERATION_NON_TRANSPOSE;
     params.opB = CUSPARSE_OPERATION_NON_TRANSPOSE;
     // Matrix handles in cuSPARSE are data specific. Therefore, the pointers
@@ -274,13 +257,7 @@ using gemm_cusparse_cache_t =
 template <typename Op>
 __MATX_INLINE__ auto getCUSPARSESupportedTensor(const Op &in,
                                                 cudaStream_t stream) {
-  const auto support_func = [&in]() {
-    if constexpr (is_tensor_view_v<Op>) {
-      return in.Stride(Op::Rank() - 1) == 1; // TODO: more than row-wise
-    } else {
-      return true;
-    }
-  };
+  const auto support_func = [&in]() { return true; };
   return GetSupportedTensor(in, support_func, MATX_ASYNC_DEVICE_MEMORY, stream);
 }
 
diff --git a/include/matx/transforms/solve/solve_cudss.h b/include/matx/transforms/solve/solve_cudss.h
index 6d303216..7e41a163 100644
--- a/include/matx/transforms/solve/solve_cudss.h
+++ b/include/matx/transforms/solve/solve_cudss.h
@@ -165,7 +165,7 @@ class SolveCUDSSHandle_t {
     params.n = c.Size(TensorTypeC::Rank() - 2); // Note: B,C transposed!
     params.k = a.Size(TensorTypeA::Rank() - 1);
     // Matrix handles in cuDSS are data specific. Therefore, the pointers
-    // to the underlying buffers are part of the GEMM parameters.
+    // to the underlying buffers are part of the SOLVE parameters.
     params.ptrA0 = a.Data();
     params.ptrA1 = a.POSData(0);
     params.ptrA2 = a.POSData(1);

From a25165dfc9938c1e356e58800147a827d3b3aef3 Mon Sep 17 00:00:00 2001
From: "Aart J.C. Bik"
Date: Mon, 3 Feb 2025 18:48:30 -0800
Subject: [PATCH 2/3] proper output type testing

---
 include/matx/operators/sparse2dense.h | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/include/matx/operators/sparse2dense.h b/include/matx/operators/sparse2dense.h
index 0a72d891..009ab42c 100644
--- a/include/matx/operators/sparse2dense.h
+++ b/include/matx/operators/sparse2dense.h
@@ -86,11 +86,13 @@ class Sparse2DenseOp : public BaseOp<Sparse2DenseOp<OpA>> {
   template <typename Out, typename Executor>
   void Exec([[maybe_unused]] Out &&out, [[maybe_unused]] Executor &&ex) const {
     if constexpr (is_sparse_tensor_v<OpA>) {
-      if constexpr (is_sparse_tensor_v<Out>) {
+      auto ref = cuda::std::get<0>(out);
+      typedef decltype(ref) Rtype;
+      if constexpr (is_sparse_tensor_v<Rtype>) {
         MATX_THROW(matxNotSupported,
                    "Cannot use sparse2dense for sparse output");
       } else {
-        sparse2dense_impl(cuda::std::get<0>(out), a_, ex);
+        sparse2dense_impl(ref, a_, ex);
       }
     } else {
       MATX_THROW(matxNotSupported, "Cannot use sparse2dense on dense input");
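
The point of this second patch: Out in Exec is the whole tuple of output operands, not a tensor, so applying is_sparse_tensor_v to Out itself can never detect a sparse output; the trait has to be applied to the tuple element instead. A small illustration of the distinction (t is a hypothetical stand-in for any MatX tensor):

    auto out = cuda::std::make_tuple(t);
    using Ttuple = decltype(out);                    // a cuda::std::tuple
    using Relem  = decltype(cuda::std::get<0>(out)); // the tensor itself
    // is_sparse_tensor_v<Ttuple> is always false, even when t is sparse;
    // is_sparse_tensor_v<Relem> reflects the operand's actual storage.
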
Bik" Date: Mon, 3 Feb 2025 22:45:12 -0800 Subject: [PATCH 3/3] addressed reviewer comments --- include/matx/operators/sparse2dense.h | 2 +- include/matx/transforms/convert/sparse2dense_cusparse.h | 4 ++-- include/matx/transforms/matmul/matmul_cusparse.h | 6 +++--- include/matx/transforms/solve/solve_cudss.h | 6 +++--- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/include/matx/operators/sparse2dense.h b/include/matx/operators/sparse2dense.h index 009ab42c..55d3c751 100644 --- a/include/matx/operators/sparse2dense.h +++ b/include/matx/operators/sparse2dense.h @@ -87,7 +87,7 @@ class Sparse2DenseOp : public BaseOp> { void Exec([[maybe_unused]] Out &&out, [[maybe_unused]] Executor &&ex) const { if constexpr (is_sparse_tensor_v) { auto ref = cuda::std::get<0>(out); - typedef decltype(ref) Rtype; + using Rtype = decltype(ref); if constexpr (is_sparse_tensor_v) { MATX_THROW(matxNotSupported, "Cannot use sparse2dense for sparse output"); diff --git a/include/matx/transforms/convert/sparse2dense_cusparse.h b/include/matx/transforms/convert/sparse2dense_cusparse.h index 5ce44161..8373b0f1 100644 --- a/include/matx/transforms/convert/sparse2dense_cusparse.h +++ b/include/matx/transforms/convert/sparse2dense_cusparse.h @@ -233,8 +233,8 @@ void sparse2dense_impl(OutputTensorType O, const InputTensorType A, // TODO: some more checking, supported type? on device? etc. - typedef decltype(o) otype; - typedef decltype(a) atype; + using atype = decltype(a); + using otype = decltype(o); // Get parameters required by these tensors (for caching). auto params = diff --git a/include/matx/transforms/matmul/matmul_cusparse.h b/include/matx/transforms/matmul/matmul_cusparse.h index 436ce21a..075204fe 100644 --- a/include/matx/transforms/matmul/matmul_cusparse.h +++ b/include/matx/transforms/matmul/matmul_cusparse.h @@ -274,9 +274,9 @@ void sparse_matmul_impl(TensorTypeC C, const TensorTypeA A, const TensorTypeB B, // TODO: some more checking, supported type? on device? etc. - typedef decltype(c) ctype; - typedef decltype(a) atype; - typedef decltype(b) btype; + using atype = decltype(a); + using btype = decltype(b); + using ctype = decltype(c); // Get parameters required by these tensors (for caching). auto params = diff --git a/include/matx/transforms/solve/solve_cudss.h b/include/matx/transforms/solve/solve_cudss.h index 7e41a163..8299b36e 100644 --- a/include/matx/transforms/solve/solve_cudss.h +++ b/include/matx/transforms/solve/solve_cudss.h @@ -257,9 +257,9 @@ void sparse_solve_impl_trans(TensorTypeC C, const TensorTypeA A, // TODO: some more checking, supported type? on device? etc. - typedef decltype(c) ctype; - typedef decltype(a) atype; - typedef decltype(b) btype; + using atype = decltype(a); + using btype = decltype(b); + using ctype = decltype(c); // Get parameters required by these tensors (for caching). auto params = detail::SolveCUDSSHandle_t::GetSolveParams(