Commit fcb6d3d

Enforce library restrictions on MatX transformations (#872)

1 parent 252dc03 · commit fcb6d3d

4 files changed: +85 -45 lines changed
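Across all four files the change follows one pattern: rank and type checks move out of the cached handle constructors and into the public *_impl entry points, so violations surface at the call site. Ranks and element types are rejected at compile time with static_assert; shapes, known only at run time, are checked with MATX_ASSERT. A minimal standalone sketch of the pattern (foo_impl and its single size check are illustrative, not code from this commit):

template <typename OutputTensorType, typename InputTensorType>
void foo_impl(OutputTensorType &o, const InputTensorType &a) {
  using TA = typename InputTensorType::value_type;
  using TO = typename OutputTensorType::value_type;

  // Compile-time restrictions: a bad instantiation fails to build.
  static_assert(OutputTensorType::Rank() == InputTensorType::Rank(),
                "tensors must have same rank");
  static_assert(std::is_same_v<TA, TO>,
                "tensors must have the same data type");

  // Run-time restriction: sizes are only known at run time.
  MATX_ASSERT(o.Size(0) == a.Size(0), matxInvalidSize);

  // ... build or fetch the cached handle and dispatch ...
}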

include/matx/transforms/convert/dense2sparse_cusparse.h (+16 -7)

@@ -84,18 +84,12 @@ class Dense2SparseHandle_t {
   using POS = typename TensorTypeO::pos_type;
   using CRD = typename TensorTypeO::crd_type;
 
-  static constexpr int RANKA = TensorTypeA::Rank();
-  static constexpr int RANKO = TensorTypeO::Rank();
-
   /**
    * Construct a dense2sparse handle.
    */
   Dense2SparseHandle_t(TensorTypeO &o, const TensorTypeA &a,
                        cudaStream_t stream) {
     MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
-
-    static_assert(RANKA == RANKO);
-
     params_ = GetConvParams(o, a, stream);
 
     [[maybe_unused]] cusparseStatus_t ret = cusparseCreate(&handle_);
@@ -261,7 +255,22 @@ void dense2sparse_impl(OutputTensorType &o, const InputTensorType &a,
   MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
   const auto stream = exec.getStream();
 
-  // TODO: some more checking, supported type? on device? etc.
+  using TA = typename InputTensorType::value_type;
+  using TO = typename OutputTensorType::value_type;
+
+  // Restrictions.
+  static_assert(OutputTensorType::Rank() == InputTensorType::Rank(),
+                "tensors must have same rank");
+  static_assert(std::is_same_v<TA, TO>,
+                "tensors must have the same data type");
+  static_assert(std::is_same_v<TA, int8_t> ||
+                std::is_same_v<TA, matx::matxFp16> ||
+                std::is_same_v<TA, matx::matxBf16> ||
+                std::is_same_v<TA, float> ||
+                std::is_same_v<TA, double> ||
+                std::is_same_v<TA, cuda::std::complex<float>> ||
+                std::is_same_v<TA, cuda::std::complex<double>>,
+                "unsupported data type");
 
   // Get parameters required by these tensors (for caching).
   auto params =
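With the restrictions at the top of dense2sparse_impl, a mismatched conversion now fails to compile instead of reaching cuSPARSE at run time. A hedged illustration (the tensor construction is sketched from the MatX API and may not match the experimental sparse interface exactly):

auto d = matx::make_tensor<float>({8, 8});  // dense float input
// Suppose 'sp' is a rank-2 sparse output tensor with value_type double.
// dense2sparse_impl(sp, d, exec);  // rejected at compile time:
//   error: static assertion failed: tensors must have the same data type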

include/matx/transforms/convert/sparse2dense_cusparse.h (+16 -7)

@@ -72,18 +72,12 @@ class Sparse2DenseHandle_t {
   using TA = typename TensorTypeA::value_type;
   using TO = typename TensorTypeO::value_type;
 
-  static constexpr int RANKA = TensorTypeA::Rank();
-  static constexpr int RANKO = TensorTypeO::Rank();
-
   /**
    * Construct a sparse2dense handle.
    */
   Sparse2DenseHandle_t(TensorTypeO &o, const TensorTypeA &a,
                        cudaStream_t stream) {
     MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
-
-    static_assert(RANKA == RANKO);
-
     params_ = GetConvParams(o, a, stream);
 
     [[maybe_unused]] cusparseStatus_t ret = cusparseCreate(&handle_);
@@ -221,7 +215,22 @@ void sparse2dense_impl(OutputTensorType &o, const InputTensorType &a,
   MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
   const auto stream = exec.getStream();
 
-  // TODO: some more checking, supported type? on device? etc.
+  using TA = typename InputTensorType::value_type;
+  using TO = typename OutputTensorType::value_type;
+
+  // Restrictions.
+  static_assert(OutputTensorType::Rank() == InputTensorType::Rank(),
+                "tensors must have same rank");
+  static_assert(std::is_same_v<TA, TO>,
+                "tensors must have the same data type");
+  static_assert(std::is_same_v<TA, int8_t> ||
+                std::is_same_v<TA, matx::matxFp16> ||
+                std::is_same_v<TA, matx::matxBf16> ||
+                std::is_same_v<TA, float> ||
+                std::is_same_v<TA, double> ||
+                std::is_same_v<TA, cuda::std::complex<float>> ||
+                std::is_same_v<TA, cuda::std::complex<double>>,
+                "unsupported data type");
 
   // Get parameters required by these tensors (for caching).
   auto params =
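sparse2dense_impl enforces the mirror-image restrictions, including the same element-type allowlist. An illustrative violation (tensor names are hypothetical):

// Rank-2 tensors 'd_int' and 'sp_int' with value_type int32_t would hit the
// third static_assert, since int32_t is not in the supported set
// (int8_t, matxFp16, matxBf16, float, double, complex<float>/<double>):
// sparse2dense_impl(d_int, sp_int, exec);
//   error: static assertion failed: unsupported data type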

include/matx/transforms/matmul/matmul_cusparse.h (+24 -14)

@@ -79,10 +79,6 @@ class MatMulCUSPARSEHandle_t {
   using TB = typename TensorTypeB::value_type;
   using TC = typename TensorTypeC::value_type;
 
-  static constexpr int RANKA = TensorTypeC::Rank();
-  static constexpr int RANKB = TensorTypeC::Rank();
-  static constexpr int RANKC = TensorTypeC::Rank();
-
   /**
    * Construct a sparse GEMM handle
    * SpMV
@@ -94,15 +90,6 @@ class MatMulCUSPARSEHandle_t {
                        const TensorTypeB &b, cudaStream_t stream, float alpha,
                        float beta) {
     MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
-
-    static_assert(RANKA == 2);
-    static_assert(RANKB == 2);
-    static_assert(RANKC == 2);
-
-    MATX_ASSERT(a.Size(RANKA - 1) == b.Size(RANKB - 2), matxInvalidSize);
-    MATX_ASSERT(c.Size(RANKC - 1) == b.Size(RANKB - 1), matxInvalidSize);
-    MATX_ASSERT(c.Size(RANKC - 2) == a.Size(RANKA - 2), matxInvalidSize);
-
     params_ = GetGemmParams(c, a, b, stream, alpha, beta);
 
     [[maybe_unused]] cusparseStatus_t ret = cusparseCreate(&handle_);
@@ -261,7 +248,30 @@ void sparse_matmul_impl(TensorTypeC &c, const TensorTypeA &a, const TensorTypeB
   MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
   const auto stream = exec.getStream();
 
-  // TODO: some more checking, supported type? on device? etc.
+  using TA = typename TensorTypeA::value_type;
+  using TB = typename TensorTypeB::value_type;
+  using TC = typename TensorTypeC::value_type;
+
+  static constexpr int RANKA = TensorTypeA::Rank();
+  static constexpr int RANKB = TensorTypeB::Rank();
+  static constexpr int RANKC = TensorTypeC::Rank();
+
+  // Restrictions.
+  static_assert(RANKA == 2 && RANKB == 2 && RANKC == 2,
+                "tensors must have rank-2");
+  static_assert(std::is_same_v<TC, TA> &&
+                std::is_same_v<TC, TB>,
+                "tensors must have the same data type");
+  // TODO: allow MIXED-PRECISION computation!
+  static_assert(std::is_same_v<TC, float> ||
+                std::is_same_v<TC, double> ||
+                std::is_same_v<TC, cuda::std::complex<float>> ||
+                std::is_same_v<TC, cuda::std::complex<double>>,
+                "unsupported data type");
+  MATX_ASSERT(
+      a.Size(RANKA - 1) == b.Size(RANKB - 2) &&
+      c.Size(RANKC - 1) == b.Size(RANKB - 1) &&
+      c.Size(RANKC - 2) == a.Size(RANKA - 2), matxInvalidSize);
 
   // Get parameters required by these tensors (for caching).
   auto params =
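The collapsed MATX_ASSERT encodes the standard GEMM conformance rule for C = A · B. Reading the three conjuncts against m×k, k×n, and m×n shapes (a worked mapping, not code from the commit):

// A is m×k: a.Size(RANKA-2) = m, a.Size(RANKA-1) = k
// B is k×n: b.Size(RANKB-2) = k, b.Size(RANKB-1) = n
// C is m×n: c.Size(RANKC-2) = m, c.Size(RANKC-1) = n
//
// a.Size(RANKA - 1) == b.Size(RANKB - 2)  // inner dims agree: k == k
// c.Size(RANKC - 1) == b.Size(RANKB - 1)  // result columns:   n == n
// c.Size(RANKC - 2) == a.Size(RANKA - 2)  // result rows:      m == m

Note that the removed constructor code defined RANKA and RANKB from TensorTypeC::Rank(), a copy-paste slip the new per-operand definitions also fix.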

include/matx/transforms/solve/solve_cudss.h (+29 -17)

@@ -75,23 +75,9 @@ class SolveCUDSSHandle_t {
   using TB = typename TensorTypeB::value_type;
   using TC = typename TensorTypeC::value_type;
 
-  static constexpr int RANKA = TensorTypeC::Rank();
-  static constexpr int RANKB = TensorTypeC::Rank();
-  static constexpr int RANKC = TensorTypeC::Rank();
-
   SolveCUDSSHandle_t(TensorTypeC &c, const TensorTypeA &a, const TensorTypeB &b,
                      cudaStream_t stream) {
     MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
-
-    static_assert(RANKA == 2);
-    static_assert(RANKB == 2);
-    static_assert(RANKC == 2);
-
-    // Note: B,C transposed!
-    MATX_ASSERT(a.Size(RANKA - 1) == b.Size(RANKB - 1), matxInvalidSize);
-    MATX_ASSERT(a.Size(RANKA - 2) == b.Size(RANKB - 1), matxInvalidSize);
-    MATX_ASSERT(b.Size(RANKB - 2) == c.Size(RANKC - 2), matxInvalidSize);
-
     params_ = GetSolveParams(c, a, b, stream);
 
     [[maybe_unused]] cudssStatus_t ret = cudssCreate(&handle_);
@@ -100,7 +86,7 @@ class SolveCUDSSHandle_t {
     // Create cuDSS handle for sparse matrix A.
     static_assert(is_sparse_tensor_v<TensorTypeA>);
     MATX_ASSERT(TypeToInt<typename TensorTypeA::pos_type> ==
-                  TypeToInt<typename TensorTypeA::crd_type>,
+                TypeToInt<typename TensorTypeA::crd_type>,
                 matxNotSupported);
     cudaDataType itp = MatXTypeToCudaType<typename TensorTypeA::crd_type>();
     cudaDataType dta = MatXTypeToCudaType<TA>();
@@ -244,7 +230,29 @@ void sparse_solve_impl_trans(TensorTypeC &c, const TensorTypeA &a,
   MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
   const auto stream = exec.getStream();
 
-  // TODO: some more checking, supported type? on device? etc.
+  using TA = typename TensorTypeA::value_type;
+  using TB = typename TensorTypeB::value_type;
+  using TC = typename TensorTypeC::value_type;
+
+  static constexpr int RANKA = TensorTypeA::Rank();
+  static constexpr int RANKB = TensorTypeB::Rank();
+  static constexpr int RANKC = TensorTypeC::Rank();
+
+  // Restrictions.
+  static_assert(RANKA == 2 && RANKB == 2 && RANKC == 2,
+                "tensors must have rank-2");
+  static_assert(std::is_same_v<TC, TA> &&
+                std::is_same_v<TC, TB>,
+                "tensors must have the same data type");
+  static_assert(std::is_same_v<TC, float> ||
+                std::is_same_v<TC, double> ||
+                std::is_same_v<TC, cuda::std::complex<float>> ||
+                std::is_same_v<TC, cuda::std::complex<double>>,
+                "unsupported data type");
+  MATX_ASSERT( // Note: B,C transposed!
+      a.Size(RANKA - 1) == b.Size(RANKB - 1) &&
+      a.Size(RANKA - 2) == b.Size(RANKB - 1) &&
+      b.Size(RANKB - 2) == c.Size(RANKC - 2), matxInvalidSize);
 
   // Get parameters required by these tensors (for caching).
   auto params = detail::SolveCUDSSHandle_t<TensorTypeC, TensorTypeA, TensorTypeB>::GetSolveParams(
@@ -266,12 +274,16 @@ void sparse_solve_impl_trans(TensorTypeC &c, const TensorTypeA &a,
 // convoluted way of performing the solve step must be removed once cuDSS
 // supports MATX native row-major storage, which will clean up the copies from
 // and to memory.
+//
+// TODO: remove this when cuDSS supports row-major storage
+//
 template <typename TensorTypeC, typename TensorTypeA, typename TensorTypeB>
 void sparse_solve_impl(TensorTypeC &c, const TensorTypeA &a,
                        const TensorTypeB &b, const cudaExecutor &exec) {
   const auto stream = exec.getStream();
 
-  // Some copying-in hacks, assumes rank 2.
+  // Some copying-in hacks.
+  static_assert(TensorTypeB::Rank() == 2 && TensorTypeC::Rank() == 2);
   using TB = typename TensorTypeB::value_type;
   using TC = typename TensorTypeB::value_type;
   TB *bptr;
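The size checks in sparse_solve_impl_trans keep the "B,C transposed!" convention from the removed constructor code: cuDSS works on column-major data, so B and C are passed transposed, with each right-hand side stored as a row. Reading the conjuncts under that convention (an interpretation of the checks, not text from the commit):

// A is n×n, B and C are nrhs×n (transposed):
// a.Size(RANKA - 1) == b.Size(RANKB - 1)  // A's columns match unknowns per rhs
// a.Size(RANKA - 2) == b.Size(RANKB - 1)  // combined with the above: A is square
// b.Size(RANKB - 2) == c.Size(RANKC - 2)  // B and C carry the same number of rhs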
