@@ -238,35 +238,20 @@ using gemm_cudss_cache_t =
 
 } // end namespace detail
 
-template <typename Op>
-__MATX_INLINE__ auto getCUDSSSupportedTensor(const Op &in,
-                                             cudaStream_t stream) {
-  const auto support_func = [&in]() { return true; };
-  return GetSupportedTensor(in, support_func, MATX_ASYNC_DEVICE_MEMORY, stream);
-}
-
 template <typename TensorTypeC, typename TensorTypeA, typename TensorTypeB>
-void sparse_solve_impl_trans(TensorTypeC C, const TensorTypeA A,
-                             const TensorTypeB B, const cudaExecutor &exec) {
+void sparse_solve_impl_trans(TensorTypeC &c, const TensorTypeA &a,
+                             const TensorTypeB &b, const cudaExecutor &exec) {
   MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
   const auto stream = exec.getStream();
 
-  auto a = A; // always sparse
-  auto b = getCUDSSSupportedTensor(B, stream);
-  auto c = getCUDSSSupportedTensor(C, stream);
-
   // TODO: some more checking, supported type? on device? etc.
 
-  using atype = decltype(a);
-  using btype = decltype(b);
-  using ctype = decltype(c);
-
   // Get parameters required by these tensors (for caching).
-  auto params = detail::SolveCUDSSHandle_t<ctype, atype, btype>::GetSolveParams(
+  auto params = detail::SolveCUDSSHandle_t<TensorTypeC, TensorTypeA, TensorTypeB>::GetSolveParams(
       c, a, b, stream);
 
   // Lookup and cache.
-  using cache_val_type = detail::SolveCUDSSHandle_t<ctype, atype, btype>;
+  using cache_val_type = detail::SolveCUDSSHandle_t<TensorTypeC, TensorTypeA, TensorTypeB>;
   detail::GetCache().LookupAndExec<detail::gemm_cudss_cache_t>(
       detail::GetCacheIdFromType<detail::gemm_cudss_cache_t>(), params,
       [&]() { return std::make_shared<cache_val_type>(c, a, b, stream); },
@@ -282,29 +267,29 @@ void sparse_solve_impl_trans(TensorTypeC C, const TensorTypeA A,
 // supports MATX native row-major storage, which will clean up the copies from
 // and to memory.
 template <typename TensorTypeC, typename TensorTypeA, typename TensorTypeB>
-void sparse_solve_impl(TensorTypeC C, const TensorTypeA A, const TensorTypeB B,
-                       const cudaExecutor &exec) {
+void sparse_solve_impl(TensorTypeC &c, const TensorTypeA &a,
+                       const TensorTypeB &b, const cudaExecutor &exec) {
   const auto stream = exec.getStream();
 
   // Some copying-in hacks, assumes rank 2.
   using TB = typename TensorTypeB::value_type;
   using TC = typename TensorTypeB::value_type;
   TB *bptr;
   matxAlloc(reinterpret_cast<void **>(&bptr),
-            sizeof(TB) * B.Size(0) * B.Size(1), MATX_ASYNC_DEVICE_MEMORY,
+            sizeof(TB) * b.Size(0) * b.Size(1), MATX_ASYNC_DEVICE_MEMORY,
             stream);
-  auto bT = make_tensor(bptr, {B.Size(1), B.Size(0)});
-  (bT = transpose(B)).run(exec);
+  auto bT = make_tensor(bptr, {b.Size(1), b.Size(0)});
+  (bT = transpose(b)).run(exec);
   TC *cptr;
   matxAlloc(reinterpret_cast<void **>(&cptr),
-            sizeof(TC) * C.Size(0) * C.Size(1), MATX_ASYNC_DEVICE_MEMORY,
+            sizeof(TC) * c.Size(0) * c.Size(1), MATX_ASYNC_DEVICE_MEMORY,
             stream);
-  auto cT = make_tensor(cptr, {C.Size(1), C.Size(0)});
+  auto cT = make_tensor(cptr, {c.Size(1), c.Size(0)});
 
-  sparse_solve_impl_trans(cT, A, bT, exec);
+  sparse_solve_impl_trans(cT, a, bT, exec);
 
   // Some copying-back hacks.
-  (C = transpose(cT)).run(exec);
+  (c = transpose(cT)).run(exec);
 }
 
 } // end namespace matx
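
A minimal standalone sketch (not part of this commit) of the copy-in pattern sparse_solve_impl relies on above: a row-major tensor is transposed into a freshly allocated buffer with swapped extents so the transposed solve path can consume it. The shapes, the float element type, and the cudaStreamSynchronize/matxFree cleanup are illustrative assumptions; the matxAlloc/make_tensor/transpose calls mirror the diff.

// Illustrative sketch only -- not from this commit.
#include "matx.h"

int main() {
  using namespace matx;

  cudaStream_t stream = 0;
  cudaExecutor exec{stream};

  // Rank-2, row-major right-hand side, as sparse_solve_impl assumes.
  auto b = make_tensor<float>({4, 2});
  (b = 1.0f).run(exec);

  // Copy-in: allocate a buffer with swapped extents and materialize the
  // transpose so a column-major consumer can read it directly.
  float *bptr;
  matxAlloc(reinterpret_cast<void **>(&bptr),
            sizeof(float) * b.Size(0) * b.Size(1), MATX_ASYNC_DEVICE_MEMORY,
            stream);
  auto bT = make_tensor(bptr, {b.Size(1), b.Size(0)});
  (bT = transpose(b)).run(exec);

  // ... a solve would consume bT here, and its result would be transposed
  // back into the caller's tensor, exactly as sparse_solve_impl does with cT.

  cudaStreamSynchronize(stream);
  matxFree(bptr);
  return 0;
}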