Skip to content

Commit

Permalink
Enabled mixed precision tests for SpMM and SpMV (#910)
Browse files Browse the repository at this point in the history
  • Loading branch information
aartbik authored Mar 11, 2025
1 parent 207dfc7 commit 6d38754
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 2 deletions.
2 changes: 1 addition & 1 deletion test/00_sparse/Matmul.cu
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ protected:

template <typename T> class MatmulSparseTestsAll : public MatmulSparseTest<T> { };

TYPED_TEST_SUITE(MatmulSparseTestsAll, MatXFloatNonHalfTypesCUDAExec);
TYPED_TEST_SUITE(MatmulSparseTestsAll, MatXFloatNonComplexHalfTypesCUDAExec);

TYPED_TEST(MatmulSparseTestsAll, MatmulCOO) {
MATX_ENTER_HANDLER();
Expand Down
2 changes: 1 addition & 1 deletion test/00_sparse/Matvec.cu
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ protected:

template <typename T> class MatvecSparseTestsAll : public MatvecSparseTest<T> { };

TYPED_TEST_SUITE(MatvecSparseTestsAll, MatXFloatNonHalfTypesCUDAExec);
TYPED_TEST_SUITE(MatvecSparseTestsAll, MatXFloatNonComplexHalfTypesCUDAExec);

TYPED_TEST(MatvecSparseTestsAll, MatvecCOO) {
MATX_ENTER_HANDLER();
Expand Down
4 changes: 4 additions & 0 deletions test/include/test_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ using MatXFloatNonComplexNonHalfTuple = cuda::std::tuple<float, double>;
using MatXNumericNonHalfTuple = cuda::std::tuple<uint32_t, int32_t, uint64_t, int64_t, float, double,
cuda::std::complex<float>, cuda::std::complex<double>>;
using MatXFloatNonHalfTuple = cuda::std::tuple<float, double, cuda::std::complex<float>, cuda::std::complex<double>>;
using MatXFloatNonComplexHalfTuple = cuda::std::tuple<matx::matxFp16, matx::matxBf16, float, double,
cuda::std::complex<float>, cuda::std::complex<double>>;
using MatXComplexNonHalfTuple = cuda::std::tuple<cuda::std::complex<float>, cuda::std::complex<double>>;
using MatXNumericNonComplexTuple = cuda::std::tuple<uint32_t, int32_t, uint64_t, int64_t, float, double>;
using MatXComplexTuple = cuda::std::tuple<cuda::std::complex<float>, cuda::std::complex<double>,
Expand Down Expand Up @@ -138,6 +140,7 @@ using MatXDoubleOnlyTuple = cuda::std::tuple<double>;
using MatXAllTypesCUDAExec = TupleToTypes<TypedCartesianProduct<MatXAllTuple, ExecutorTypesCUDAOnly>::type>::type;
using MatXFloatTypesCUDAExec = TupleToTypes<TypedCartesianProduct<MatXFloatTuple, ExecutorTypesCUDAOnly>::type>::type;
using MatXFloatNonHalfTypesCUDAExec = TupleToTypes<TypedCartesianProduct<MatXFloatNonHalfTuple, ExecutorTypesCUDAOnly>::type>::type;
using MatXFloatNonComplexHalfTypesCUDAExec = TupleToTypes<TypedCartesianProduct<MatXFloatNonComplexHalfTuple, ExecutorTypesCUDAOnly>::type>::type;
using MatXFloatNonComplexTypesCUDAExec = TupleToTypes<TypedCartesianProduct<MatXFloatNonComplexTuple, ExecutorTypesCUDAOnly>::type>::type;
using MatXFloatHalfTypesCUDAExec = TupleToTypes<TypedCartesianProduct<MatXFloatHalfTuple, ExecutorTypesCUDAOnly>::type>::type;
using MatXNumericTypesCUDAExec = TupleToTypes<TypedCartesianProduct<MatXNumericTuple, ExecutorTypesCUDAOnly>::type>::type;
Expand All @@ -153,6 +156,7 @@ using MatXDoubleOnlyTypeCUDAExec = TupleToTypes<TypedCartesianProdu
// All executor types
using MatXNumericNonComplexTypesAllExecs = TupleToTypes<TypedCartesianProduct<MatXNumericNonComplexTuple, ExecutorTypesAll>::type>::type;
using MatXFloatNonHalfTypesAllExecs = TupleToTypes<TypedCartesianProduct<MatXFloatNonHalfTuple, ExecutorTypesAll>::type>::type;
using MatXFloatNonComplexHalfTypesAllExecs = TupleToTypes<TypedCartesianProduct<MatXFloatNonComplexHalfTuple, ExecutorTypesAll>::type>::type;
using MatXFloatNonComplexNonHalfTypesAllExecs = TupleToTypes<TypedCartesianProduct<MatXFloatNonComplexNonHalfTuple, ExecutorTypesAll>::type>::type;
using MatXNumericNoHalfTypesAllExecs = TupleToTypes<TypedCartesianProduct<MatXNumericNonHalfTuple, ExecutorTypesAll>::type>::type;
using MatXComplexNonHalfTypesAllExecs = TupleToTypes<TypedCartesianProduct<MatXComplexNonHalfTuple, ExecutorTypesAll>::type>::type;
Expand Down

0 comments on commit 6d38754

Please sign in to comment.