Skip to content

Commit

Permalink
Add SpMV support for matvec transformation (#904)
Browse files Browse the repository at this point in the history
* Add SpMV support for matvec transformation

with tests and doc

* typo
  • Loading branch information
aartbik authored Mar 11, 2025
1 parent 1e5c64a commit 245b036
Show file tree
Hide file tree
Showing 7 changed files with 545 additions and 8 deletions.
2 changes: 2 additions & 0 deletions docs_input/api/linalg/matvec/matvec.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ Matrix-vector multiplication

.. doxygenfunction:: matvec

For information on experimental sparse tensor support for Sparse-Matrix x Vector (SpMV), please see :ref:`sparse_tensor_api`.

Examples
~~~~~~~~

Expand Down
4 changes: 3 additions & 1 deletion docs_input/basics/sparse_tensor.rst
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,12 @@ correct way of performing the conversion above is as follows::
(A = sparse2dense(Acoo)).run(exec);

The current experimental sparse support in MatX provides efficient
operations for sparse-to-dense, dense-to-sparse, matmul, and solve::
operations for sparse-to-dense, dense-to-sparse, matvec, matmul,
and solve::

(A = sparse2dense(Acoo)).run(exec);
(Acoo = dense2sparse(D)).run(exec);
(V = matvec(Acoo, W)).run(exec); // only Sparse-Matrix x Vector (SpMV)
(C = matmul(Acoo, B)).run(exec); // only Sparse-Matrix x Matrix (SpMM)
(X = solve(Acsr, Y)).run(exec); // only on CSR format

Expand Down
8 changes: 7 additions & 1 deletion include/matx/operators/matvec.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "matx/core/type_utils.h"
#include "matx/operators/base_operator.h"
#include "matx/transforms/matvec.h"
#include "matx/transforms/matmul/matvec_cusparse.h"

namespace matx
{
Expand Down Expand Up @@ -91,7 +92,12 @@ namespace matx

template <typename Out, typename Executor>
void Exec(Out &&out, Executor &&ex) const{
  // A sparse operand is only supported on the left-hand side of the
  // product (SpMV); a sparse vector operand has no implementation yet.
  static_assert(!is_sparse_tensor_v<OpB>, "sparse rhs not implemented");
  if constexpr (!is_sparse_tensor_v<OpA>) {
    // Dense matrix: use the regular dense matvec path.
    matvec_impl(cuda::std::get<0>(out), a_, b_, ex, alpha_, beta_);
  } else {
    // Sparse matrix: dispatch to the dedicated sparse matvec
    // implementation (declared in matvec_cusparse.h).
    sparse_matvec_impl(cuda::std::get<0>(out), a_, b_, ex, alpha_, beta_);
  }
}

template <typename ShapeType, typename Executor>
Expand Down
10 changes: 4 additions & 6 deletions include/matx/transforms/matmul/matmul_cusparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,8 @@ class MatMulCUSPARSEHandle_t {

/**
* Construct a sparse GEMM handle
* SpMV
* SpMM <- for now
* SpGEMM
*
*/
MatMulCUSPARSEHandle_t(TensorTypeC &c, const TensorTypeA &a,
const TensorTypeB &b, cudaStream_t stream, float alpha,
Expand Down Expand Up @@ -256,8 +254,8 @@ using gemm_cusparse_cache_t =
MatMulCUSPARSEParamsKeyHash, MatMulCUSPARSEParamsKeyEq>;

template <typename Op>
__MATX_INLINE__ auto getCuSparseSupportedTensor(const Op &in,
cudaStream_t stream) {
__MATX_INLINE__ auto getCuSparseGemmSupportedTensor(const Op &in,
cudaStream_t stream) {
const auto func = [&]() {
if constexpr (is_tensor_view_v<Op>) {
return in.Stride(Op::Rank() - 1) == 1;
Expand All @@ -278,8 +276,8 @@ void sparse_matmul_impl(TensorTypeC &C, const TensorTypeA &a,
const auto stream = exec.getStream();

// Transform into supported form.
auto b = getCuSparseSupportedTensor(B, stream);
auto c = getCuSparseSupportedTensor(C, stream);
auto b = getCuSparseGemmSupportedTensor(B, stream);
auto c = getCuSparseGemmSupportedTensor(C, stream);
if (!is_matx_transform_op<TensorTypeB>() && !b.isSameView(B)) {
(b = B).run(stream);
}
Expand Down
Loading

0 comments on commit 245b036

Please sign in to comment.