
Commit 55dd664

First version of MATX Sparse-Direct-Solve (using dispatch to cuDSS) (#849)
1 parent b3ca482 commit 55dd664
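
At a glance, the commit adds a solve() operator that solves AX = B for a dense unknown X, dispatching to NVIDIA's cuDSS direct solver when A is a sparse tensor and MatX is built with MATX_EN_CUDSS. A condensed sketch of the new API, drawn from the example updated in this commit (the MATX_ENTER_HANDLER() scaffolding and default stream are assumed from MatX's usual example conventions):

    #include "matx.h"
    using namespace matx;

    int main() {
      MATX_ENTER_HANDLER();
      cudaStream_t stream = 0;
      cudaExecutor exec{stream};

      // 4x4 CSR system matrix: values, compressed row pointers, column indices.
      auto coeffs = make_tensor<float>({5});
      auto rowptr = make_tensor<int>({5});
      auto colidx = make_tensor<int>({5});
      coeffs.SetVals({1, 2, 3, 4, 5});
      rowptr.SetVals({0, 2, 3, 4, 5});
      colidx.SetVals({0, 1, 1, 2, 3});
      auto Acsr = experimental::make_tensor_csr(coeffs, rowptr, colidx, {4, 4});

      // Solve A X = Y for two right-hand-side columns at once.
      auto X = make_tensor<float>({4, 2});
      auto Y = make_tensor<float>({4, 2});
      Y.SetVals({ {5, 17}, {6, 18}, {12, 28}, {20, 40} });
      (X = solve(Acsr, Y)).run(exec);
      print(X);

      MATX_EXIT_HANDLER();
    }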

File tree

5 files changed (+513 -36 lines)


examples/sparse_tensor.cu

+38 -36
@@ -32,6 +32,8 @@
 
 #include "matx.h"
 
+// Note that sparse tensor support in MatX is still experimental.
+
 using namespace matx;
 
 int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
@@ -42,7 +44,7 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
   cudaExecutor exec{stream};
 
   //
-  // Print some formats that are used for the versatile sparse tensor
+  // Print some formats that are used for the universal sparse tensor
   // type. Note that common formats like COO and CSR have good library
   // support in e.g. cuSPARSE, but MatX provides a much more general
   // way to define the sparse tensor storage through a DSL (see doc).
@@ -68,25 +70,6 @@
   // | 0, 0, 0, 0, 0, 0, 0, 0 |
   // | 0, 0, 3, 4, 0, 5, 0, 0 |
   //
-
-  constexpr index_t m = 4;
-  constexpr index_t n = 8;
-  constexpr index_t nse = 5;
-
-  tensor_t<float, 1> values{{nse}};
-  tensor_t<int, 1> row_idx{{nse}};
-  tensor_t<int, 1> col_idx{{nse}};
-
-  values.SetVals({ 1, 2, 3, 4, 5 });
-  row_idx.SetVals({ 0, 0, 3, 3, 3 });
-  col_idx.SetVals({ 0, 1, 2, 3, 5 });
-
-  // Note that sparse tensor support in MatX is still experimental.
-  auto Acoo = experimental::make_tensor_coo(values, row_idx, col_idx, {m, n});
-
-  //
-  // This shows:
-  //
   // tensor_impl_2_f32: SparseTensor{float} Rank: 2, Sizes:[4, 8], Levels:[4, 8]
   // nse = 5
   // format = ( d0, d1 ) -> ( d0 : compressed(non-unique), d1 : singleton )
@@ -95,6 +78,13 @@
   // values = ( 1.0000e+00 2.0000e+00 3.0000e+00 4.0000e+00 5.0000e+00 )
   // space = CUDA managed memory
   //
+  auto vals = make_tensor<float>({5});
+  auto idxi = make_tensor<int>({5});
+  auto idxj = make_tensor<int>({5});
+  vals.SetVals({1, 2, 3, 4, 5});
+  idxi.SetVals({0, 0, 3, 3, 3});
+  idxj.SetVals({0, 1, 2, 3, 5});
+  auto Acoo = experimental::make_tensor_coo(vals, idxi, idxj, {4, 8});
   print(Acoo);
 
   //
@@ -107,9 +97,9 @@
   // use sparse operations that are tailored for the sparse data
   // structure (such as scanning by row for CSR).
   //
-  tensor_t<float, 2> A{{m, n}};
-  for (index_t i = 0; i < m; i++) {
-    for (index_t j = 0; j < n; j++) {
+  auto A = make_tensor<float>({4, 8});
+  for (index_t i = 0; i < 4; i++) {
+    for (index_t j = 0; j < 8; j++) {
       A(i, j) = Acoo(i, j);
     }
   }
@@ -119,24 +109,36 @@
   // SpMM is implemented on COO through cuSPARSE. This is the
   // correct way of performing an efficient sparse operation.
   //
-  tensor_t<float, 2> B{{8, 4}};
-  tensor_t<float, 2> C{{4, 4}};
-  B.SetVals({ {  0,  1,  2,  3 },
-              {  4,  5,  6,  7 },
-              {  8,  9, 10, 11 },
-              { 12, 13, 14, 15 },
-              { 16, 17, 18, 19 },
-              { 20, 21, 22, 23 },
-              { 24, 25, 26, 27 },
-              { 28, 29, 30, 31 } });
+  auto B = make_tensor<float, 2>({8, 4});
+  auto C = make_tensor<float>({4, 4});
+  B.SetVals({
+      { 0,  1,  2,  3}, { 4,  5,  6,  7}, { 8,  9, 10, 11}, {12, 13, 14, 15},
+      {16, 17, 18, 19}, {20, 21, 22, 23}, {24, 25, 26, 27}, {28, 29, 30, 31} });
   (C = matmul(Acoo, B)).run(exec);
   print(C);
 
   //
-  // Verify by computing the equivelent dense GEMM.
+  // Creates a CSR matrix which is used to solve the following
+  // system of equations AX=Y, where X is the unknown.
   //
-  (C = matmul(A, B)).run(exec);
-  print(C);
+  // | 1 2 0 0 |   | 1 5 |   |  5 17 |
+  // | 0 3 0 0 | x | 2 6 | = |  6 18 |
+  // | 0 0 4 0 |   | 3 7 |   | 12 28 |
+  // | 0 0 0 5 |   | 4 8 |   | 20 40 |
+  //
+  auto coeffs = make_tensor<float>({5});
+  auto rowptr = make_tensor<int>({5});
+  auto colidx = make_tensor<int>({5});
+  coeffs.SetVals({1, 2, 3, 4, 5});
+  rowptr.SetVals({0, 2, 3, 4, 5});
+  colidx.SetVals({0, 1, 1, 2, 3});
+  auto Acsr = experimental::make_tensor_csr(coeffs, rowptr, colidx, {4, 4});
+  print(Acsr);
+  auto X = make_tensor<float>({4, 2});
+  auto Y = make_tensor<float>({4, 2});
+  Y.SetVals({ {5, 17}, {6, 18}, {12, 28}, {20, 40} });
+  (X = solve(Acsr, Y)).run(exec);
+  print(X);
 
   MATX_EXIT_HANDLER();
 }
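
As a side note, the expected solution here is X = {{1, 5}, {2, 6}, {3, 7}, {4, 8}}, per the comment block in the hunk above. If desired, it can be sanity-checked the same slow-but-simple way the example densifies the COO matrix: element access into the sparse tensor plus a dense GEMM. A minimal sketch, assuming elementwise reads work on CSR just as they do on COO above (Adense and Ycheck are hypothetical names):

    // Densify Acsr by (slow) element access, then confirm A * X reproduces Y.
    auto Adense = make_tensor<float>({4, 4});
    for (index_t i = 0; i < 4; i++) {
      for (index_t j = 0; j < 4; j++) {
        Adense(i, j) = Acsr(i, j);
      }
    }
    auto Ycheck = make_tensor<float>({4, 2});
    (Ycheck = matmul(Adense, X)).run(exec);
    print(Ycheck);  // should match Y: {5, 17}, {6, 18}, {12, 28}, {20, 40}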

include/matx/core/type_utils.h

+3 -0
@@ -1126,6 +1126,9 @@ template <typename T> constexpr cudaDataType_t MatXTypeToCudaType()
   if constexpr (std::is_same_v<T, int8_t>) {
     return CUDA_R_8I;
   }
+  if constexpr (std::is_same_v<T, int>) {
+    return CUDA_R_32I;
+  }
   if constexpr (std::is_same_v<T, float>) {
     return CUDA_R_32F;
   }
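
This addition lets the int index tensors used for the sparse positions above (e.g. rowptr, colidx) map onto a CUDA data type when handed to the backend. Since the surrounding function is declared constexpr (see the hunk header), the new mapping is checkable at compile time; a small illustrative sketch, assuming the helper lives directly in the matx namespace as type_utils.h suggests:

    #include "matx.h"

    // Compile-time check of the new int -> CUDA_R_32I branch (illustrative only).
    static_assert(matx::MatXTypeToCudaType<int>() == CUDA_R_32I);
    static_assert(matx::MatXTypeToCudaType<float>() == CUDA_R_32F);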

include/matx/operators/operators.h

+1 -0
@@ -99,6 +99,7 @@
 #include "matx/operators/shift.h"
 #include "matx/operators/sign.h"
 #include "matx/operators/slice.h"
+#include "matx/operators/solve.h"
 #include "matx/operators/sort.h"
 #include "matx/operators/sph2cart.h"
 #include "matx/operators/stack.h"

include/matx/operators/solve.h

+161 -0
@@ -0,0 +1,161 @@
+////////////////////////////////////////////////////////////////////////////////
+// BSD 3-Clause License
+//
+// Copyright (c) 2025, NVIDIA Corporation
+// All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+//    list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+//    this list of conditions and the following disclaimer in the documentation
+//    and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+//    contributors may be used to endorse or promote products derived from
+//    this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+/////////////////////////////////////////////////////////////////////////////////
+
+#pragma once
+
+#include "matx/core/type_utils.h"
+#include "matx/operators/base_operator.h"
+#ifdef MATX_EN_CUDSS
+#include "matx/transforms/solve/solve_cudss.h"
+#endif
+
+namespace matx {
+namespace detail {
+
+template <typename OpA, typename OpB>
+class SolveOp : public BaseOp<SolveOp<OpA, OpB>> {
+private:
+  typename detail::base_type_t<OpA> a_;
+  typename detail::base_type_t<OpB> b_;
+
+  static constexpr int out_rank = OpB::Rank();
+  cuda::std::array<index_t, out_rank> out_dims_;
+  mutable detail::tensor_impl_t<typename OpA::value_type, out_rank> tmp_out_;
+  mutable typename OpA::value_type *ptr = nullptr;
+
+public:
+  using matxop = bool;
+  using matx_transform_op = bool;
+  using solve_xform_op = bool;
+  using value_type = typename OpA::value_type;
+
+  __MATX_INLINE__ SolveOp(const OpA &a, const OpB &b) : a_(a), b_(b) {
+    for (int r = 0, rank = Rank(); r < rank; r++) {
+      out_dims_[r] = b_.Size(r);
+    }
+  }
+
+  __MATX_INLINE__ std::string str() const {
+    return "solve(" + get_type_str(a_) + "," + get_type_str(b_) + ")";
+  }
+
+  __MATX_HOST__ __MATX_INLINE__ auto Data() const noexcept { return ptr; }
+
+  template <typename... Is>
+  __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto)
+  operator()(Is... indices) const {
+    return tmp_out_(indices...);
+  }
+
+  static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t
+  Rank() {
+    return remove_cvref_t<OpB>::Rank();
+  }
+
+  constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t
+  Size(int dim) const {
+    return out_dims_[dim];
+  }
+
+  template <typename Out, typename Executor>
+  void Exec([[maybe_unused]] Out &&out, [[maybe_unused]] Executor &&ex) const {
+    static_assert(!is_sparse_tensor_v<OpB>, "sparse rhs not implemented");
+    if constexpr (is_sparse_tensor_v<OpA>) {
+#ifdef MATX_EN_CUDSS
+      sparse_solve_impl(cuda::std::get<0>(out), a_, b_, ex);
+#else
+      MATX_THROW(matxNotSupported, "Sparse direct solver requires cuDSS");
+#endif
+    } else {
+      MATX_THROW(matxNotSupported,
+                 "Direct solver currently only supports sparse system");
+    }
+  }
+
+  template <typename ShapeType, typename Executor>
+  __MATX_INLINE__ void InnerPreRun([[maybe_unused]] ShapeType &&shape,
+                                   [[maybe_unused]] Executor &&ex) const noexcept {
+    static_assert(is_sparse_tensor_v<OpA>,
+                  "Direct solver currently only supports sparse system");
+    if constexpr (is_matx_op<OpB>()) {
+      b_.PreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
+    }
+  }
+
+  template <typename ShapeType, typename Executor>
+  __MATX_INLINE__ void PreRun([[maybe_unused]] ShapeType &&shape,
+                              [[maybe_unused]] Executor &&ex) const noexcept {
+    InnerPreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
+    detail::AllocateTempTensor(tmp_out_, std::forward<Executor>(ex), out_dims_,
+                               &ptr);
+    Exec(cuda::std::make_tuple(tmp_out_), std::forward<Executor>(ex));
+  }
+
+  template <typename ShapeType, typename Executor>
+  __MATX_INLINE__ void PostRun([[maybe_unused]] ShapeType &&shape,
+                               [[maybe_unused]] Executor &&ex) const noexcept {
+    static_assert(is_sparse_tensor_v<OpA>,
+                  "Direct solver currently only supports sparse system");
+    if constexpr (is_matx_op<OpB>()) {
+      b_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
+    }
+    matxFree(ptr);
+  }
+};
+
+} // end namespace detail
+
+/**
+ * Run a direct SOLVE (viz. X = solve(A, B) solves system AX=B for unknown X).
+ *
+ * Note that currently, this operation is only implemented for solving
+ * a linear system with a very **sparse** matrix A.
+ *
+ * @tparam OpA
+ *   Data type of A tensor (sparse)
+ * @tparam OpB
+ *   Data type of B tensor
+ *
+ * @param A
+ *   A Sparse tensor with system coefficients
+ * @param B
+ *   B Dense tensor of known values
+ *
+ * @return
+ *   Operator that produces the output tensor X with the solution
+ */
+template <typename OpA, typename OpB>
+__MATX_INLINE__ auto solve(const OpA &A, const OpB &B) {
+  return detail::SolveOp(A, B);
+}
+
+} // end namespace matx
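
Note the dispatch structure in Exec(): a sparse A routes to sparse_solve_impl() (cuDSS) only when MATX_EN_CUDSS is defined, otherwise the operator throws matxNotSupported at run time, while a dense A is rejected by the static_assert in InnerPreRun(). A sketch of guarding a call site, assuming MatX exceptions derive from std::exception, with Acsr, Y, X, and exec as in the example above:

    try {
      // PreRun() allocates the temporary output; Exec() then dispatches to
      // cuDSS through sparse_solve_impl().
      (X = solve(Acsr, Y)).run(exec);
    } catch (const std::exception &e) {
      // Without cuDSS enabled at build time, this reports:
      // "Sparse direct solver requires cuDSS"
      std::cerr << e.what() << std::endl;  // requires <iostream>
    }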
