From 7cde545747bb029c02058e35989474a454faf305 Mon Sep 17 00:00:00 2001 From: "Aart J.C. Bik" Date: Mon, 10 Mar 2025 18:13:42 -0700 Subject: [PATCH 1/2] Support mixed-precision for SpMV --- .../matx/transforms/matmul/matvec_cusparse.h | 31 ++++++++++++------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/include/matx/transforms/matmul/matvec_cusparse.h b/include/matx/transforms/matmul/matvec_cusparse.h index 2f264623..cb5f5ad2 100644 --- a/include/matx/transforms/matmul/matvec_cusparse.h +++ b/include/matx/transforms/matmul/matvec_cusparse.h @@ -77,6 +77,11 @@ class MatVecCUSPARSEHandle_t { using TB = typename TensorTypeB::value_type; using TC = typename TensorTypeC::value_type; + // Mixed-precision compute type. + using TCOMP = std::conditional_t< + std::is_same_v || + std::is_same_v, float, TC>; + /** * Construct a SpMV handle */ @@ -87,12 +92,12 @@ class MatVecCUSPARSEHandle_t { params_ = GetSpMVParams(c, a, b, stream, alpha, beta); // Properly typed alpha, beta. - if constexpr (std::is_same_v> || - std::is_same_v>) { + if constexpr (std::is_same_v> || + std::is_same_v>) { salpha_ = {alpha, 0}; sbeta_ = {beta, 0}; - } else if constexpr (std::is_same_v || - std::is_same_v) { + } else if constexpr (std::is_same_v || + std::is_same_v) { salpha_ = alpha; sbeta_ = beta; } else { @@ -139,7 +144,7 @@ class MatVecCUSPARSEHandle_t { // Allocate a workspace for SpMV. const cusparseSpMVAlg_t algo = CUSPARSE_SPMV_ALG_DEFAULT; - const cudaDataType comptp = dtc; // TODO: support separate comp type?! + const cudaDataType comptp = MatXTypeToCudaType(); ret = cusparseSpMV_bufferSize(handle_, params_.opA, &salpha_, matA_, vecB_, &sbeta_, vecC_, comptp, algo, &workspaceSize_); @@ -188,7 +193,7 @@ class MatVecCUSPARSEHandle_t { [[maybe_unused]] const TensorTypeB &b) { MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL); const cusparseSpMVAlg_t algo = CUSPARSE_SPMV_ALG_DEFAULT; - const cudaDataType comptp = MatXTypeToCudaType(); // TODO: see above + const cudaDataType comptp = MatXTypeToCudaType(); [[maybe_unused]] cusparseStatus_t ret = cusparseSpMV(handle_, params_.opA, &salpha_, matA_, vecB_, &sbeta_, vecC_, comptp, algo, workspace_); @@ -203,8 +208,8 @@ class MatVecCUSPARSEHandle_t { size_t workspaceSize_ = 0; void *workspace_ = nullptr; detail::MatVecCUSPARSEParams_t params_; - TC salpha_; - TC sbeta_; + TCOMP salpha_; + TCOMP sbeta_; }; /** @@ -287,10 +292,12 @@ void sparse_matvec_impl(TensorTypeC &C, const TensorTypeA &a, "tensors must have SpMV rank"); static_assert(std::is_same_v && std::is_same_v, "tensors must have the same data type"); - // TODO: allow MIXED-PRECISION computation! - static_assert(std::is_same_v || std::is_same_v || - std::is_same_v> || - std::is_same_v>, + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v> || + std::is_same_v>, "unsupported data type"); MATX_ASSERT(a.Size(RANKA - 1) == b.Size(RANKB - 1) && a.Size(RANKA - 2) == c.Size(RANKC - 1), From 86787b59aee434a106b580f4ce90c9ccb6fcdb71 Mon Sep 17 00:00:00 2001 From: "Aart J.C. Bik" Date: Mon, 10 Mar 2025 20:17:22 -0700 Subject: [PATCH 2/2] use type trait for half --- include/matx/transforms/matmul/matvec_cusparse.h | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/include/matx/transforms/matmul/matvec_cusparse.h b/include/matx/transforms/matmul/matvec_cusparse.h index cb5f5ad2..d3452d9e 100644 --- a/include/matx/transforms/matmul/matvec_cusparse.h +++ b/include/matx/transforms/matmul/matvec_cusparse.h @@ -78,9 +78,7 @@ class MatVecCUSPARSEHandle_t { using TC = typename TensorTypeC::value_type; // Mixed-precision compute type. - using TCOMP = std::conditional_t< - std::is_same_v || - std::is_same_v, float, TC>; + using TCOMP = std::conditional_t, float, TC>; /** * Construct a SpMV handle