Skip to content

Commit 2529c66

Browse files
authored
First version of MATX dense2sparse conversion (dispatch to cuSPARSE) (#868)
1 parent 9114af3 commit 2529c66

File tree

5 files changed

+467
-7
lines changed

5 files changed

+467
-7
lines changed

examples/sparse_tensor.cu

+17
Original file line numberDiff line numberDiff line change
@@ -149,5 +149,22 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
149149
(X = solve(Acsr, Y)).run(exec);
150150
print(X);
151151

152+
//
153+
// A direct dense2sparse conversion. This is the correct way of
154+
// performing an efficient sparse operation. Note, however,
155+
// that assigning a right-hand-side value to a sparse tensor
156+
// (viz. the lval Acoo) is an experimental operation recently
157+
// added to MatX, and it is currently restricted to a direct
158+
// "dense2sparse" operation at the right-hand-side.
159+
//
160+
auto D = make_tensor<float, 2>({4, 8});
161+
D.SetVals({
162+
{0, 11, 0, 12, 0, 0, 0, 0},
163+
{0, 0, 13, 0, 0, 0, 0, 0},
164+
{0, 0, 0, 0, 0, 0, 0, 14},
165+
{0, 15, 0, 0, 16, 0, 17, 0}});
166+
(Acoo = dense2sparse(D)).run(exec);
167+
print(Acoo);
168+
152169
MATX_EXIT_HANDLER();
153170
}

include/matx/core/sparse_tensor.h

+63-7
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,33 @@
3636

3737
#include "matx/core/sparse_tensor_format.h"
3838
#include "matx/core/tensor_impl.h"
39+
#include "matx/operators/base_operator.h"
3940

4041
namespace matx {
42+
43+
namespace detail {
44+
45+
//
46+
// A sparse_set operation. Assigning to a sparse tensor is very different
47+
// from all other MatX assignments, because the underlying storage and
48+
// buffers may have to be resized to accomodate the output. Therefore,
49+
// for now, we provide a customized set operation that passes a direct
50+
// reference to the executor.
51+
//
52+
template <typename T, typename Op>
53+
class sparse_set : public BaseOp<sparse_set<T, Op>> {
54+
private:
55+
T &out_;
56+
mutable typename detail::base_type_t<Op> op_;
57+
public:
58+
inline sparse_set(T &out, const Op &op) : out_(out), op_(op) {}
59+
template <typename Ex> __MATX_INLINE__ void run(Ex &&ex) {
60+
op_.Exec(out_, std::forward<Ex>(ex));
61+
}
62+
};
63+
64+
} // end namespace detail
65+
4166
namespace experimental {
4267

4368
//
@@ -61,6 +86,7 @@ class sparse_tensor_t
6186
using crd_type = CRD;
6287
using pos_type = POS;
6388
using Format = TF;
89+
6490
static constexpr int DIM = TF::DIM;
6591
static constexpr int LVL = TF::LVL;
6692

@@ -84,13 +110,33 @@ class sparse_tensor_t
84110
: detail::tensor_impl_t<VAL, DIM, DimDesc,
85111
detail::SparseTensorData<VAL, CRD, POS, TF>>(
86112
shape) {
87-
// Initialize primary and secondary storage.
88113
values_ = std::move(vals);
89114
for (int l = 0; l < LVL; l++) {
90115
coordinates_[l] = std::move(crd[l]);
91116
positions_[l] = std::move(pos[l]);
92117
}
93-
// Set the sparse data in tensor_impl.
118+
SetSparseDataImpl();
119+
}
120+
121+
// Default destructor.
122+
__MATX_INLINE__ ~sparse_tensor_t() = default;
123+
124+
// Sets value storage.
125+
__MATX_INLINE__ void SetVal(StorageV &&val) { values_ = std::move(val); }
126+
127+
// Sets coordinates storage.
128+
__MATX_INLINE__ void SetCrd(int l, StorageC &&crd) {
129+
coordinates_[l] = std::move(crd);
130+
}
131+
132+
// Sets positions storage.
133+
__MATX_INLINE__ void SetPos(int l, StorageP &&pos) {
134+
positions_[l] = std::move(pos);
135+
}
136+
137+
// Sets sparse data in tensor_impl_t. This method must be called
138+
// every time changes are made to the underlying storage objects.
139+
void SetSparseDataImpl() {
94140
VAL *v = values_.data();
95141
CRD *c[LVL];
96142
POS *p[LVL];
@@ -104,13 +150,23 @@ class sparse_tensor_t
104150
this->SetSparseData(v, c, p);
105151
}
106152

107-
// Default destructor.
108-
__MATX_INLINE__ ~sparse_tensor_t() = default;
153+
// A direct sparse tensor assignment (viz. (Acoo = ...).exec();).
154+
template <typename T>
155+
[[nodiscard]] __MATX_INLINE__ __MATX_HOST__ auto operator=(const T &op) {
156+
[[maybe_unused]] typename T::dense2sparse_xform_op valid = true;
157+
return detail::sparse_set(*this, op);
158+
}
109159

110160
// Size getters.
111-
index_t Nse() const { return static_cast<index_t>(values_.size() / sizeof(VAL)); }
112-
index_t crdSize(int l) const { return static_cast<index_t>(coordinates_[l].size() / sizeof(CRD)); }
113-
index_t posSize(int l) const { return static_cast<index_t>(positions_[l].size() / sizeof(POS)); }
161+
index_t Nse() const {
162+
return static_cast<index_t>(values_.size() / sizeof(VAL));
163+
}
164+
index_t crdSize(int l) const {
165+
return static_cast<index_t>(coordinates_[l].size() / sizeof(CRD));
166+
}
167+
index_t posSize(int l) const {
168+
return static_cast<index_t>(positions_[l].size() / sizeof(POS));
169+
}
114170

115171
private:
116172
// Primary storage of sparse tensor (explicitly stored element values).

include/matx/operators/dense2sparse.h

+103
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
////////////////////////////////////////////////////////////////////////////////
2+
// BSD 3-Clause License
3+
//
4+
// Copyright (c) 2025, NVIDIA Corporation
5+
// All rights reserved.
6+
//
7+
// Redistribution and use in source and binary forms, with or without
8+
// modification, are permitted provided that the following conditions are met:
9+
//
10+
// 1. Redistributions of source code must retain the above copyright notice, this
11+
// list of conditions and the following disclaimer.
12+
//
13+
// 2. Redistributions in binary form must reproduce the above copyright notice,
14+
// this list of conditions and the following disclaimer in the documentation
15+
// and/or other materials provided with the distribution.
16+
//
17+
// 3. Neither the name of the copyright holder nor the names of its
18+
// contributors may be used to endorse or promote products derived from
19+
// this software without specific prior written permission.
20+
//
21+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24+
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25+
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26+
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27+
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28+
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29+
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30+
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31+
/////////////////////////////////////////////////////////////////////////////////
32+
33+
#pragma once
34+
35+
#include "matx/core/type_utils.h"
36+
#include "matx/operators/base_operator.h"
37+
#include "matx/transforms/convert/dense2sparse_cusparse.h"
38+
39+
namespace matx {
40+
namespace detail {
41+
42+
template <typename OpA>
43+
class Dense2SparseOp : public BaseOp<Dense2SparseOp<OpA>> {
44+
private:
45+
typename detail::base_type_t<OpA> a_;
46+
47+
public:
48+
using matxop = bool;
49+
using matx_transform_op = bool;
50+
using dense2sparse_xform_op = bool;
51+
using value_type = typename OpA::value_type;
52+
53+
__MATX_INLINE__ Dense2SparseOp(const OpA &a) : a_(a) {}
54+
55+
__MATX_INLINE__ std::string str() const {
56+
return "dense2sparse(" + get_type_str(a_) + ")";
57+
}
58+
59+
static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t
60+
Rank() {
61+
return remove_cvref_t<OpA>::Rank();
62+
}
63+
64+
constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t
65+
Size(int dim) const {
66+
return a_.Size(dim);
67+
}
68+
69+
template <typename Out, typename Executor>
70+
void Exec([[maybe_unused]] Out &&out, [[maybe_unused]] Executor &&ex) const {
71+
if constexpr (is_sparse_tensor_v<OpA>) {
72+
MATX_THROW(matxNotSupported, "Cannot use dense2sparse on sparse input");
73+
} else {
74+
// NOTE: sparse assignment A = dense2sparse(B) takes direct reference!
75+
if constexpr (is_sparse_tensor_v<Out>) {
76+
dense2sparse_impl(out, a_, ex);
77+
} else {
78+
MATX_THROW(matxNotSupported,
79+
"Cannot use dense2sparse for dense output");
80+
}
81+
}
82+
}
83+
};
84+
85+
} // end namespace detail
86+
87+
/**
88+
* Convert a dense tensor into a sparse tensor.
89+
*
90+
* @tparam OpA
91+
* Data type of A tensor
92+
*
93+
* @param A
94+
* Dense input tensor
95+
*
96+
* @return
97+
* Sparse output tensor
98+
*/
99+
template <typename OpA> __MATX_INLINE__ auto dense2sparse(const OpA &A) {
100+
return detail::Dense2SparseOp(A);
101+
}
102+
103+
} // end namespace matx

include/matx/operators/operators.h

+1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
#include "matx/operators/cov.h"
5353
#include "matx/operators/cross.h"
5454
#include "matx/operators/cumsum.h"
55+
#include "matx/operators/dense2sparse.h"
5556
#include "matx/operators/diag.h"
5657
#include "matx/operators/dct.h"
5758
#include "matx/operators/det.h"

0 commit comments

Comments
 (0)