36
36
37
37
#include " matx/core/sparse_tensor_format.h"
38
38
#include " matx/core/tensor_impl.h"
39
+ #include " matx/operators/base_operator.h"
39
40
40
41
namespace matx {
42
+
43
+ namespace detail {
44
+
45
+ //
46
+ // A sparse_set operation. Assigning to a sparse tensor is very different
47
+ // from all other MatX assignments, because the underlying storage and
48
+ // buffers may have to be resized to accomodate the output. Therefore,
49
+ // for now, we provide a customized set operation that passes a direct
50
+ // reference to the executor.
51
+ //
52
+ template <typename T, typename Op>
53
+ class sparse_set : public BaseOp <sparse_set<T, Op>> {
54
+ private:
55
+ T &out_;
56
+ mutable typename detail::base_type_t <Op> op_;
57
+ public:
58
+ inline sparse_set (T &out, const Op &op) : out_(out), op_(op) {}
59
+ template <typename Ex> __MATX_INLINE__ void run (Ex &&ex) {
60
+ op_.Exec (out_, std::forward<Ex>(ex));
61
+ }
62
+ };
63
+
64
+ } // end namespace detail
65
+
41
66
namespace experimental {
42
67
43
68
//
@@ -61,6 +86,7 @@ class sparse_tensor_t
61
86
using crd_type = CRD;
62
87
using pos_type = POS;
63
88
using Format = TF;
89
+
64
90
static constexpr int DIM = TF::DIM;
65
91
static constexpr int LVL = TF::LVL;
66
92
@@ -84,13 +110,33 @@ class sparse_tensor_t
84
110
: detail::tensor_impl_t<VAL, DIM, DimDesc,
85
111
detail::SparseTensorData<VAL, CRD, POS, TF>>(
86
112
shape) {
87
- // Initialize primary and secondary storage.
88
113
values_ = std::move (vals);
89
114
for (int l = 0 ; l < LVL; l++) {
90
115
coordinates_[l] = std::move (crd[l]);
91
116
positions_[l] = std::move (pos[l]);
92
117
}
93
- // Set the sparse data in tensor_impl.
118
+ SetSparseDataImpl ();
119
+ }
120
+
121
+ // Default destructor.
122
+ __MATX_INLINE__ ~sparse_tensor_t () = default ;
123
+
124
+ // Sets value storage.
125
+ __MATX_INLINE__ void SetVal (StorageV &&val) { values_ = std::move (val); }
126
+
127
+ // Sets coordinates storage.
128
+ __MATX_INLINE__ void SetCrd (int l, StorageC &&crd) {
129
+ coordinates_[l] = std::move (crd);
130
+ }
131
+
132
+ // Sets positions storage.
133
+ __MATX_INLINE__ void SetPos (int l, StorageP &&pos) {
134
+ positions_[l] = std::move (pos);
135
+ }
136
+
137
+ // Sets sparse data in tensor_impl_t. This method must be called
138
+ // every time changes are made to the underlying storage objects.
139
+ void SetSparseDataImpl () {
94
140
VAL *v = values_.data ();
95
141
CRD *c[LVL];
96
142
POS *p[LVL];
@@ -104,13 +150,23 @@ class sparse_tensor_t
104
150
this ->SetSparseData (v, c, p);
105
151
}
106
152
107
- // Default destructor.
108
- __MATX_INLINE__ ~sparse_tensor_t () = default ;
153
+ // A direct sparse tensor assignment (viz. (Acoo = ...).exec();).
154
+ template <typename T>
155
+ [[nodiscard]] __MATX_INLINE__ __MATX_HOST__ auto operator =(const T &op) {
156
+ [[maybe_unused]] typename T::dense2sparse_xform_op valid = true ;
157
+ return detail::sparse_set (*this , op);
158
+ }
109
159
110
160
// Size getters.
111
- index_t Nse () const { return static_cast <index_t >(values_.size () / sizeof (VAL)); }
112
- index_t crdSize (int l) const { return static_cast <index_t >(coordinates_[l].size () / sizeof (CRD)); }
113
- index_t posSize (int l) const { return static_cast <index_t >(positions_[l].size () / sizeof (POS)); }
161
+ index_t Nse () const {
162
+ return static_cast <index_t >(values_.size () / sizeof (VAL));
163
+ }
164
+ index_t crdSize (int l) const {
165
+ return static_cast <index_t >(coordinates_[l].size () / sizeof (CRD));
166
+ }
167
+ index_t posSize (int l) const {
168
+ return static_cast <index_t >(positions_[l].size () / sizeof (POS));
169
+ }
114
170
115
171
private:
116
172
// Primary storage of sparse tensor (explicitly stored element values).
0 commit comments