@@ -207,16 +207,29 @@ using sparse2dense_cache_t =
     std::unordered_map<Sparse2DenseParams_t, std::any,
                        Sparse2DenseParamsKeyHash, Sparse2DenseParamsKeyEq>;
 
+template <typename Op>
+__MATX_INLINE__ auto getS2DSupportedTensor(const Op &in, cudaStream_t stream) {
+  const auto func = [&]() {
+    if constexpr (is_tensor_view_v<Op>)
+      return in.Stride(Op::Rank() - 1) == 1;
+    return true;
+  };
+  return GetSupportedTensor(in, func, MATX_ASYNC_DEVICE_MEMORY, stream);
+}
+
 } // end namespace detail
 
 template <typename OutputTensorType, typename InputTensorType>
-void sparse2dense_impl(OutputTensorType &o, const InputTensorType &a,
+void sparse2dense_impl(OutputTensorType &O, const InputTensorType &a,
                        const cudaExecutor &exec) {
   MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
   const auto stream = exec.getStream();
 
+  // Transform into supported form.
+  auto o = getS2DSupportedTensor(O, stream);
+
   using atype = InputTensorType;
-  using otype = OutputTensorType;
+  using otype = decltype(o);
 
   using TA = typename atype::value_type;
   using TO = typename otype::value_type;
@@ -248,6 +261,11 @@ void sparse2dense_impl(OutputTensorType &o, const InputTensorType &a,
       [&](std::shared_ptr<cache_val_type> cache_type) {
         cache_type->Exec(o, a);
       });
+
+  // Copy transformed output back.
+  if (!o.isSameView(O)) {
+    (O = o).run(stream);
+  }
 }
 
 } // end namespace matx
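
Note on the pattern in this diff: if the destination tensor is not in a form the kernel supports (here, unit innermost stride), the transform runs into a contiguous temporary and the result is copied back afterwards. GetSupportedTensor allocates that temporary in MATX_ASYNC_DEVICE_MEMORY and the copy-back `(O = o).run(stream)` is enqueued on the same CUDA stream, so the detour stays stream-ordered; the `if constexpr` also keeps the `Stride` query from being instantiated for operator types that are not tensor views. The standalone C++ sketch below illustrates the same staging idea without MatX; StridedView, fill_iota, and run_with_staging are illustrative names invented for this sketch, not MatX APIs.

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // A 1-D view over existing storage with an element stride.
    struct StridedView {
      float *data;
      std::size_t size;   // number of logical elements
      std::size_t stride; // distance between consecutive elements
    };

    // Stand-in for a kernel that can only write contiguous output.
    void fill_iota(float *dst, std::size_t n) {
      for (std::size_t i = 0; i < n; ++i)
        dst[i] = static_cast<float>(i);
    }

    void run_with_staging(StridedView out) {
      if (out.stride == 1) {
        // Output is already in supported (unit-stride) form: write in place.
        fill_iota(out.data, out.size);
        return;
      }
      // Otherwise stage through a contiguous temporary...
      std::vector<float> tmp(out.size);
      fill_iota(tmp.data(), out.size);
      // ...and copy the transformed output back, honoring the stride.
      for (std::size_t i = 0; i < out.size; ++i)
        out.data[i * out.stride] = tmp[i];
    }

    int main() {
      std::vector<float> buf(10, -1.0f);
      run_with_staging({buf.data(), 5, 2}); // writes every other element
      assert(buf[0] == 0.0f && buf[2] == 1.0f && buf[1] == -1.0f);
      return 0;
    }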