Skip to content

Commit

Permalink
Support mixed-precision for SpMV (#907)
Browse files Browse the repository at this point in the history
* Support mixed-precision for SpMV
  • Loading branch information
aartbik authored Mar 11, 2025
1 parent 6b20e04 commit a3fbd5d
Showing 1 changed file with 17 additions and 12 deletions.
29 changes: 17 additions & 12 deletions include/matx/transforms/matmul/matvec_cusparse.h
Original file line number Diff line number Diff line change
@@ -77,6 +77,9 @@ class MatVecCUSPARSEHandle_t {
  using TB = typename TensorTypeB::value_type;
  using TC = typename TensorTypeC::value_type;

+ // Mixed-precision compute type.
+ using TCOMP = std::conditional_t<is_matx_half_v<TC>, float, TC>;
+
  /**
  * Construct a SpMV handle
  */
@@ -87,12 +90,12 @@ class MatVecCUSPARSEHandle_t {
  params_ = GetSpMVParams(c, a, b, stream, alpha, beta);

  // Properly typed alpha, beta.
- if constexpr (std::is_same_v<TC, cuda::std::complex<float>> ||
- std::is_same_v<TC, cuda::std::complex<double>>) {
+ if constexpr (std::is_same_v<TCOMP, cuda::std::complex<float>> ||
+ std::is_same_v<TCOMP, cuda::std::complex<double>>) {
  salpha_ = {alpha, 0};
  sbeta_ = {beta, 0};
- } else if constexpr (std::is_same_v<TC, float> ||
- std::is_same_v<TC, double>) {
+ } else if constexpr (std::is_same_v<TCOMP, float> ||
+ std::is_same_v<TCOMP, double>) {
  salpha_ = alpha;
  sbeta_ = beta;
  } else {
@@ -139,7 +142,7 @@ class MatVecCUSPARSEHandle_t {

  // Allocate a workspace for SpMV.
  const cusparseSpMVAlg_t algo = CUSPARSE_SPMV_ALG_DEFAULT;
- const cudaDataType comptp = dtc; // TODO: support separate comp type?!
+ const cudaDataType comptp = MatXTypeToCudaType<TCOMP>();
  ret =
  cusparseSpMV_bufferSize(handle_, params_.opA, &salpha_, matA_, vecB_,
  &sbeta_, vecC_, comptp, algo, &workspaceSize_);
@@ -188,7 +191,7 @@ class MatVecCUSPARSEHandle_t {
  [[maybe_unused]] const TensorTypeB &b) {
  MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL);
  const cusparseSpMVAlg_t algo = CUSPARSE_SPMV_ALG_DEFAULT;
- const cudaDataType comptp = MatXTypeToCudaType<TC>(); // TODO: see above
+ const cudaDataType comptp = MatXTypeToCudaType<TCOMP>();
  [[maybe_unused]] cusparseStatus_t ret =
  cusparseSpMV(handle_, params_.opA, &salpha_, matA_, vecB_, &sbeta_,
  vecC_, comptp, algo, workspace_);
@@ -203,8 +206,8 @@ class MatVecCUSPARSEHandle_t {
  size_t workspaceSize_ = 0;
  void *workspace_ = nullptr;
  detail::MatVecCUSPARSEParams_t params_;
- TC salpha_;
- TC sbeta_;
+ TCOMP salpha_;
+ TCOMP sbeta_;
  };

  /**
@@ -287,10 +290,12 @@ void sparse_matvec_impl(TensorTypeC &C, const TensorTypeA &a,
  "tensors must have SpMV rank");
  static_assert(std::is_same_v<TC, TA> && std::is_same_v<TC, TB>,
  "tensors must have the same data type");
- // TODO: allow MIXED-PRECISION computation!
- static_assert(std::is_same_v<TC, float> || std::is_same_v<TC, double> ||
- std::is_same_v<TC, cuda::std::complex<float>> ||
- std::is_same_v<TC, cuda::std::complex<double>>,
+ static_assert(std::is_same_v<TC, matx::matxFp16> ||
+ std::is_same_v<TC, matx::matxBf16> ||
+ std::is_same_v<TC, float> ||
+ std::is_same_v<TC, double> ||
+ std::is_same_v<TC, cuda::std::complex<float>> ||
+ std::is_same_v<TC, cuda::std::complex<double>>,
  "unsupported data type");
  MATX_ASSERT(a.Size(RANKA - 1) == b.Size(RANKB - 1) &&
  a.Size(RANKA - 2) == c.Size(RANKC - 1),
Expand Down

0 comments on commit a3fbd5d

Please sign in to comment.