@@ -75,23 +75,9 @@ class SolveCUDSSHandle_t {
   using TB = typename TensorTypeB::value_type;
   using TC = typename TensorTypeC::value_type;
 
-  static constexpr int RANKA = TensorTypeC::Rank();
-  static constexpr int RANKB = TensorTypeC::Rank();
-  static constexpr int RANKC = TensorTypeC::Rank();
-
   SolveCUDSSHandle_t(TensorTypeC &c, const TensorTypeA &a, const TensorTypeB &b,
                      cudaStream_t stream) {
     MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
-
-    static_assert(RANKA == 2);
-    static_assert(RANKB == 2);
-    static_assert(RANKC == 2);
-
-    // Note: B,C transposed!
-    MATX_ASSERT(a.Size(RANKA - 1) == b.Size(RANKB - 1), matxInvalidSize);
-    MATX_ASSERT(a.Size(RANKA - 2) == b.Size(RANKB - 1), matxInvalidSize);
-    MATX_ASSERT(b.Size(RANKB - 2) == c.Size(RANKC - 2), matxInvalidSize);
-
     params_ = GetSolveParams(c, a, b, stream);
 
     [[maybe_unused]] cudssStatus_t ret = cudssCreate(&handle_);
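
With the rank and size checks gone, the constructor now moves directly from GetSolveParams to cudssCreate. As a standalone illustration of that call, here is a minimal sketch (not MatX code; it assumes only the public cuDSS C API) that creates a handle and checks the returned status explicitly:

// Sketch only: exercises cudssCreate/cudssDestroy outside of MatX.
#include <cudss.h>
#include <cstdio>

int main() {
  cudssHandle_t handle{};
  // cudssCreate allocates the cuDSS library context.
  cudssStatus_t ret = cudssCreate(&handle);
  if (ret != CUDSS_STATUS_SUCCESS) {
    std::printf("cudssCreate failed with status %d\n", static_cast<int>(ret));
    return 1;
  }
  // Release the context once all cuDSS work using it has completed.
  cudssDestroy(handle);
  return 0;
}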
@@ -100,7 +86,7 @@ class SolveCUDSSHandle_t {
     // Create cuDSS handle for sparse matrix A.
     static_assert(is_sparse_tensor_v<TensorTypeA>);
     MATX_ASSERT(TypeToInt<typename TensorTypeA::pos_type> ==
-                    TypeToInt<typename TensorTypeA::crd_type>,
+                TypeToInt<typename TensorTypeA::crd_type>,
                 matxNotSupported);
     cudaDataType itp = MatXTypeToCudaType<typename TensorTypeA::crd_type>();
     cudaDataType dta = MatXTypeToCudaType<TA>();
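
The assertion above requires the sparse matrix's position and coordinate index types to map to the same integer width before the data is handed to cuDSS. A compile-time sketch of the same idea, using a hypothetical to_int_tag trait in place of MatX's TypeToInt, looks like this:

// Sketch only: 'to_int_tag' is a hypothetical stand-in for MatX's TypeToInt.
#include <cstdint>
#include <type_traits>

template <typename T> constexpr int to_int_tag() {
  if constexpr (std::is_same_v<T, int32_t>) return 32;
  else if constexpr (std::is_same_v<T, int64_t>) return 64;
  else return 0;  // unsupported index type
}

template <typename PosT, typename CrdT>
void check_index_types() {
  // Mirrors the spirit of the MATX_ASSERT above: both index types must agree.
  static_assert(to_int_tag<PosT>() == to_int_tag<CrdT>(),
                "position and coordinate index types must match");
}

int main() {
  check_index_types<int32_t, int32_t>();     // compiles
  // check_index_types<int32_t, int64_t>();  // would fail to compile
  return 0;
}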
@@ -244,7 +230,29 @@ void sparse_solve_impl_trans(TensorTypeC &c, const TensorTypeA &a,
   MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
   const auto stream = exec.getStream();
 
-  // TODO: some more checking, supported type? on device? etc.
+  using TA = typename TensorTypeA::value_type;
+  using TB = typename TensorTypeB::value_type;
+  using TC = typename TensorTypeC::value_type;
+
+  static constexpr int RANKA = TensorTypeA::Rank();
+  static constexpr int RANKB = TensorTypeB::Rank();
+  static constexpr int RANKC = TensorTypeC::Rank();
+
+  // Restrictions.
+  static_assert(RANKA == 2 && RANKB == 2 && RANKC == 2,
+                "tensors must have rank-2");
+  static_assert(std::is_same_v<TC, TA> &&
+                std::is_same_v<TC, TB>,
+                "tensors must have the same data type");
+  static_assert(std::is_same_v<TC, float> ||
+                std::is_same_v<TC, double> ||
+                std::is_same_v<TC, cuda::std::complex<float>> ||
+                std::is_same_v<TC, cuda::std::complex<double>>,
+                "unsupported data type");
+  MATX_ASSERT( // Note: B,C transposed!
+      a.Size(RANKA - 1) == b.Size(RANKB - 1) &&
+      a.Size(RANKA - 2) == b.Size(RANKB - 1) &&
+      b.Size(RANKB - 2) == c.Size(RANKC - 2), matxInvalidSize);
 
   // Get parameters required by these tensors (for caching).
   auto params = detail::SolveCUDSSHandle_t<TensorTypeC, TensorTypeA, TensorTypeB>::GetSolveParams(
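
This hunk relocates the operand checks from the handle constructor into sparse_solve_impl_trans and tightens them: rank-2 operands only, identical scalar types, and one of the four types accepted here. A minimal standalone sketch of the same pattern, using a hypothetical FakeTensor stub rather than the MatX tensor types, shows how the static_asserts reject bad instantiations:

// Sketch only: 'FakeTensor' is a hypothetical stand-in for a MatX tensor type.
#include <cuda/std/complex>
#include <type_traits>

template <typename T, int R> struct FakeTensor {
  using value_type = T;
  static constexpr int Rank() { return R; }
};

template <typename C, typename A, typename B>
void check_solve_operands() {
  using TA = typename A::value_type;
  using TB = typename B::value_type;
  using TC = typename C::value_type;
  static_assert(A::Rank() == 2 && B::Rank() == 2 && C::Rank() == 2,
                "tensors must have rank-2");
  static_assert(std::is_same_v<TC, TA> && std::is_same_v<TC, TB>,
                "tensors must have the same data type");
  static_assert(std::is_same_v<TC, float> || std::is_same_v<TC, double> ||
                std::is_same_v<TC, cuda::std::complex<float>> ||
                std::is_same_v<TC, cuda::std::complex<double>>,
                "unsupported data type");
}

int main() {
  check_solve_operands<FakeTensor<float, 2>, FakeTensor<float, 2>,
                       FakeTensor<float, 2>>();         // compiles
  // check_solve_operands<FakeTensor<float, 2>, FakeTensor<double, 2>,
  //                      FakeTensor<float, 2>>();      // rejected: mixed types
  return 0;
}

The commented-out call would stop compilation with the "tensors must have the same data type" message, which is the behavior the relocated checks give the transposed solve path.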
@@ -266,12 +274,16 @@ void sparse_solve_impl_trans(TensorTypeC &c, const TensorTypeA &a,
 // convoluted way of performing the solve step must be removed once cuDSS
 // supports MATX native row-major storage, which will clean up the copies from
 // and to memory.
+//
+// TODO: remove this when cuDSS supports row-major storage
+//
 template <typename TensorTypeC, typename TensorTypeA, typename TensorTypeB>
 void sparse_solve_impl(TensorTypeC &c, const TensorTypeA &a,
                        const TensorTypeB &b, const cudaExecutor &exec) {
   const auto stream = exec.getStream();
 
-  // Some copying-in hacks, assumes rank 2.
+  // Some copying-in hacks.
+  static_assert(TensorTypeB::Rank() == 2 && TensorTypeC::Rank() == 2);
   using TB = typename TensorTypeB::value_type;
   using TC = typename TensorTypeB::value_type;
   TB *bptr;
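
The copy-in hacks exist because cuDSS consumes dense right-hand sides in column-major order while MatX tensors are row-major, which is also why the earlier size checks treat B and C as transposed. A minimal host-side sketch of that layout conversion (plain C++, not the actual device-side copies in this routine) is:

// Sketch only: row-major to column-major repack of a dense m x n matrix.
#include <cstdio>
#include <vector>

void to_col_major(const std::vector<float> &row_major,
                  std::vector<float> &col_major, int m, int n) {
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j)
      col_major[static_cast<size_t>(j) * m + i] =
          row_major[static_cast<size_t>(i) * n + j];
}

int main() {
  const int m = 2, n = 3;
  std::vector<float> b = {1, 2, 3, 4, 5, 6};  // row-major 2x3
  std::vector<float> bt(b.size());
  to_col_major(b, bt, m, n);                  // column-major copy for the solver
  for (float v : bt) std::printf("%g ", v);   // prints 1 4 2 5 3 6
  std::printf("\n");
  return 0;
}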