Skip to content

Commit 22a71fd

Browse files
aartbikcliffburdick
authored andcommitted
allow transforming output for sparse2dense (#882)
1 parent 9d1eee3 commit 22a71fd

File tree

3 files changed

+32
-2
lines changed

3 files changed

+32
-2
lines changed

include/matx/transforms/convert/dense2sparse_cusparse.h

+1
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ void dense2sparse_impl(OutputTensorType &o, const InputTensorType &A,
270270
if (!is_matx_transform_op<InputTensorType>() && !a.isSameView(A)) {
271271
(a = A).run(stream);
272272
}
273+
273274
using atype = decltype(a);
274275
using otype = OutputTensorType;
275276

include/matx/transforms/convert/sparse2dense_cusparse.h

+20-2
Original file line numberDiff line numberDiff line change
@@ -207,16 +207,29 @@ using sparse2dense_cache_t =
207207
std::unordered_map<Sparse2DenseParams_t, std::any,
208208
Sparse2DenseParamsKeyHash, Sparse2DenseParamsKeyEq>;
209209

210+
template <typename Op>
211+
__MATX_INLINE__ auto getS2DSupportedTensor(const Op &in, cudaStream_t stream) {
212+
const auto func = [&]() {
213+
if constexpr (is_tensor_view_v<Op>)
214+
return in.Stride(Op::Rank() - 1) == 1;
215+
return true;
216+
};
217+
return GetSupportedTensor(in, func, MATX_ASYNC_DEVICE_MEMORY, stream);
218+
}
219+
210220
} // end namespace detail
211221

212222
template <typename OutputTensorType, typename InputTensorType>
213-
void sparse2dense_impl(OutputTensorType &o, const InputTensorType &a,
223+
void sparse2dense_impl(OutputTensorType &O, const InputTensorType &a,
214224
const cudaExecutor &exec) {
215225
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
216226
const auto stream = exec.getStream();
217227

228+
// Transform into supported form.
229+
auto o = getS2DSupportedTensor(O, stream);
230+
218231
using atype = InputTensorType;
219-
using otype = OutputTensorType;
232+
using otype = decltype(o);
220233

221234
using TA = typename atype::value_type;
222235
using TO = typename otype::value_type;
@@ -248,6 +261,11 @@ void sparse2dense_impl(OutputTensorType &o, const InputTensorType &a,
248261
[&](std::shared_ptr<cache_val_type> cache_type) {
249262
cache_type->Exec(o, a);
250263
});
264+
265+
// Copy transformed output back.
266+
if (!o.isSameView(O)) {
267+
(O = o).run(stream);
268+
}
251269
}
252270

253271
} // end namespace matx

test/00_sparse/Convert.cu

+11
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,17 @@ TYPED_TEST(ConvertSparseTestsAll, ConvertCSR) {
150150
}
151151
}
152152

153+
// Allow transforming output.
154+
(transpose(O) = sparse2dense(S)).run(exec);
155+
156+
// Verify result.
157+
exec.sync();
158+
for (index_t i = 0; i < m; i++) {
159+
for (index_t j = 0; j < n; j++) {
160+
ASSERT_EQ(O(j, i), D(i, j));
161+
}
162+
}
163+
153164
MATX_EXIT_HANDLER();
154165
}
155166

0 commit comments

Comments
 (0)