diff --git a/include/matx/transforms/matmul/matvec_cusparse.h b/include/matx/transforms/matmul/matvec_cusparse.h index 2f264623..d3452d9e 100644 --- a/include/matx/transforms/matmul/matvec_cusparse.h +++ b/include/matx/transforms/matmul/matvec_cusparse.h @@ -77,6 +77,9 @@ class MatVecCUSPARSEHandle_t { using TB = typename TensorTypeB::value_type; using TC = typename TensorTypeC::value_type; + // Mixed-precision compute type. + using TCOMP = std::conditional_t, float, TC>; + /** * Construct a SpMV handle */ @@ -87,12 +90,12 @@ class MatVecCUSPARSEHandle_t { params_ = GetSpMVParams(c, a, b, stream, alpha, beta); // Properly typed alpha, beta. - if constexpr (std::is_same_v> || - std::is_same_v>) { + if constexpr (std::is_same_v> || + std::is_same_v>) { salpha_ = {alpha, 0}; sbeta_ = {beta, 0}; - } else if constexpr (std::is_same_v || - std::is_same_v) { + } else if constexpr (std::is_same_v || + std::is_same_v) { salpha_ = alpha; sbeta_ = beta; } else { @@ -139,7 +142,7 @@ class MatVecCUSPARSEHandle_t { // Allocate a workspace for SpMV. const cusparseSpMVAlg_t algo = CUSPARSE_SPMV_ALG_DEFAULT; - const cudaDataType comptp = dtc; // TODO: support separate comp type?! + const cudaDataType comptp = MatXTypeToCudaType(); ret = cusparseSpMV_bufferSize(handle_, params_.opA, &salpha_, matA_, vecB_, &sbeta_, vecC_, comptp, algo, &workspaceSize_); @@ -188,7 +191,7 @@ class MatVecCUSPARSEHandle_t { [[maybe_unused]] const TensorTypeB &b) { MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL); const cusparseSpMVAlg_t algo = CUSPARSE_SPMV_ALG_DEFAULT; - const cudaDataType comptp = MatXTypeToCudaType(); // TODO: see above + const cudaDataType comptp = MatXTypeToCudaType(); [[maybe_unused]] cusparseStatus_t ret = cusparseSpMV(handle_, params_.opA, &salpha_, matA_, vecB_, &sbeta_, vecC_, comptp, algo, workspace_); @@ -203,8 +206,8 @@ class MatVecCUSPARSEHandle_t { size_t workspaceSize_ = 0; void *workspace_ = nullptr; detail::MatVecCUSPARSEParams_t params_; - TC salpha_; - TC sbeta_; + TCOMP salpha_; + TCOMP sbeta_; }; /** @@ -287,10 +290,12 @@ void sparse_matvec_impl(TensorTypeC &C, const TensorTypeA &a, "tensors must have SpMV rank"); static_assert(std::is_same_v && std::is_same_v, "tensors must have the same data type"); - // TODO: allow MIXED-PRECISION computation! - static_assert(std::is_same_v || std::is_same_v || - std::is_same_v> || - std::is_same_v>, + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v> || + std::is_same_v>, "unsupported data type"); MATX_ASSERT(a.Size(RANKA - 1) == b.Size(RANKB - 1) && a.Size(RANKA - 2) == c.Size(RANKC - 1),