Skip to content

Commit

Permalink
Support mixed-precision for SpMM (#906)
Browse files Browse the repository at this point in the history
* Support mixed-precision for SpMM

Also fixes two minor issues: zero-size allocations (now skipped or
asserted against) and host-side writes to device memory (now done
with cudaMemset/cudaMemcpy).

* Use a type trait (is_matx_half_v) to detect half-precision types
  • Loading branch information
aartbik authored Mar 11, 2025
1 parent 245b036 commit 4445f72
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 21 deletions.
10 changes: 6 additions & 4 deletions include/matx/core/make_sparse_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,13 @@ namespace experimental {
template <typename T>
__MATX_INLINE__ static auto
makeDefaultNonOwningZeroStorage(index_t sz, matxMemorySpace_t space) {
T *ptr;
T *ptr = nullptr;
assert(sz > 0);
matxAlloc((void **)&ptr, sz * sizeof(T), space, 0);
// TODO: introduce a more efficient matxCalloc or matxMemset?
for (index_t i = 0; i < sz; i++) {
ptr[i] = 0;
if (space == MATX_DEVICE_MEMORY || space == MATX_ASYNC_DEVICE_MEMORY) {
cudaMemset(ptr, 0, sz * sizeof(T));
} else {
memset(ptr, 0, sz * sizeof(T));
}
raw_pointer_buffer<T, matx_allocator<T>> buf{ptr, sz * sizeof(T),
/*owning=*/false};
Expand Down
15 changes: 10 additions & 5 deletions include/matx/transforms/convert/dense2sparse_cusparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,11 @@ template <typename T>
__MATX_INLINE__ static auto makeDefaultNonOwningStorage(size_t sz,
matxMemorySpace_t space,
cudaStream_t stream) {
T *ptr;
matxAlloc(reinterpret_cast<void **>(&ptr), sz * sizeof(T), space, stream);
raw_pointer_buffer<T, matx_allocator<T>> buf{ptr, sz * sizeof(T),
/*owning=*/false};
T *ptr = nullptr;
if (sz != 0) {
matxAlloc(reinterpret_cast<void **>(&ptr), sz * sizeof(T), space, stream);
}
raw_pointer_buffer<T, matx_allocator<T>> buf{ptr, sz * sizeof(T), /*owning=*/false};
return basic_storage<decltype(buf)>{std::move(buf)};
}

Expand Down Expand Up @@ -147,8 +148,12 @@ class Dense2SparseHandle_t {
// the nnz is updated explicitly here before allocating
// the new components of COO.
POS *pos = reinterpret_cast<POS *>(params_.ptrO1);
pos[1] = nnz;
matxMemorySpace_t space = GetPointerKind(pos);
if (space == MATX_DEVICE_MEMORY || space == MATX_ASYNC_DEVICE_MEMORY) {
cudaMemcpy(pos + 1, &nnz, sizeof(POS), cudaMemcpyHostToDevice);
} else {
pos[1] = nnz;
}
o.SetVal(makeDefaultNonOwningStorage<VAL>(nnz, space, stream));
o.SetCrd(0, makeDefaultNonOwningStorage<CRD>(nnz, space, stream));
o.SetCrd(1, makeDefaultNonOwningStorage<CRD>(nnz, space, stream));
Expand Down
29 changes: 17 additions & 12 deletions include/matx/transforms/matmul/matmul_cusparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ class MatMulCUSPARSEHandle_t {
using TB = typename TensorTypeB::value_type;
using TC = typename TensorTypeC::value_type;

// Mixed-precision compute type.
using TCOMP = std::conditional_t<is_matx_half_v<TC>, float, TC>;

/**
* Construct a sparse GEMM handle
* SpMM <- for now
Expand All @@ -92,12 +95,12 @@ class MatMulCUSPARSEHandle_t {
params_ = GetGemmParams(c, a, b, stream, alpha, beta);

// Properly typed alpha, beta.
if constexpr (std::is_same_v<TC, cuda::std::complex<float>> ||
std::is_same_v<TC, cuda::std::complex<double>>) {
if constexpr (std::is_same_v<TCOMP, cuda::std::complex<float>> ||
std::is_same_v<TCOMP, cuda::std::complex<double>>) {
salpha_ = {alpha, 0};
sbeta_ = {beta, 0};
} else if constexpr (std::is_same_v<TC, float> ||
std::is_same_v<TC, double>) {
} else if constexpr (std::is_same_v<TCOMP, float> ||
std::is_same_v<TCOMP, double>) {
salpha_ = alpha;
sbeta_ = beta;
} else {
Expand Down Expand Up @@ -147,7 +150,7 @@ class MatMulCUSPARSEHandle_t {

// Allocate a workspace for SpMM.
const cusparseSpMMAlg_t algo = CUSPARSE_SPMM_ALG_DEFAULT;
const cudaDataType comptp = dtc; // TODO: support separate comp type?!
const cudaDataType comptp = MatXTypeToCudaType<TCOMP>();
ret = cusparseSpMM_bufferSize(handle_, params_.opA, params_.opB, &salpha_,
matA_, matB_, &sbeta_, matC_, comptp, algo,
&workspaceSize_);
Expand Down Expand Up @@ -199,7 +202,7 @@ class MatMulCUSPARSEHandle_t {
[[maybe_unused]] const TensorTypeB &b) {
MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL);
const cusparseSpMMAlg_t algo = CUSPARSE_SPMM_ALG_DEFAULT;
const cudaDataType comptp = MatXTypeToCudaType<TC>(); // TODO: see above
const cudaDataType comptp = MatXTypeToCudaType<TCOMP>();
[[maybe_unused]] cusparseStatus_t ret =
cusparseSpMM(handle_, params_.opA, params_.opB, &salpha_, matA_, matB_,
&sbeta_, matC_, comptp, algo, workspace_);
Expand All @@ -214,8 +217,8 @@ class MatMulCUSPARSEHandle_t {
size_t workspaceSize_ = 0;
void *workspace_ = nullptr;
detail::MatMulCUSPARSEParams_t params_;
TC salpha_;
TC sbeta_;
TCOMP salpha_;
TCOMP sbeta_;
};

/**
Expand Down Expand Up @@ -299,10 +302,12 @@ void sparse_matmul_impl(TensorTypeC &C, const TensorTypeA &a,
"tensors must have rank-2");
static_assert(std::is_same_v<TC, TA> && std::is_same_v<TC, TB>,
"tensors must have the same data type");
// TODO: allow MIXED-PRECISION computation!
static_assert(std::is_same_v<TC, float> || std::is_same_v<TC, double> ||
std::is_same_v<TC, cuda::std::complex<float>> ||
std::is_same_v<TC, cuda::std::complex<double>>,
static_assert(std::is_same_v<TC, matx::matxFp16> ||
std::is_same_v<TC, matx::matxBf16> ||
std::is_same_v<TC, float> ||
std::is_same_v<TC, double> ||
std::is_same_v<TC, cuda::std::complex<float>> ||
std::is_same_v<TC, cuda::std::complex<double>>,
"unsupported data type");
MATX_ASSERT(a.Size(RANKA - 1) == b.Size(RANKB - 2) &&
c.Size(RANKC - 1) == b.Size(RANKB - 1) &&
Expand Down

0 comments on commit 4445f72

Please sign in to comment.