Skip to content

Commit efe8f43

Browse files
authored
Add sparse direct-solver tests (#874)
1 parent 63150a8 commit efe8f43

File tree

3 files changed

+117
-3
lines changed

3 files changed

+117
-3
lines changed

include/matx/transforms/solve/solve_cudss.h

+2-3
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,6 @@ class SolveCUDSSHandle_t {
8585

8686
// Create cuDSS handle for sparse matrix A.
8787
static_assert(is_sparse_tensor_v<TensorTypeA>);
88-
MATX_ASSERT(TypeToInt<typename TensorTypeA::pos_type> ==
89-
TypeToInt<typename TensorTypeA::crd_type>,
90-
matxNotSupported);
9188
cudaDataType itp = MatXTypeToCudaType<typename TensorTypeA::crd_type>();
9289
cudaDataType dta = MatXTypeToCudaType<TA>();
9390
cudssMatrixType_t mtp = CUDSS_MTYPE_GENERAL;
@@ -253,6 +250,8 @@ void sparse_solve_impl_trans(TensorTypeC &c, const TensorTypeA &a,
253250
a.Size(RANKA - 1) == b.Size(RANKB - 1) &&
254251
a.Size(RANKA - 2) == b.Size(RANKB - 1) &&
255252
b.Size(RANKB - 2) == c.Size(RANKC - 2), matxInvalidSize);
253+
static_assert(std::is_same_v<typename TensorTypeA::pos_type, int32_t> &&
254+
std::is_same_v<typename TensorTypeA::crd_type, int32_t>, "unsupported index type");
256255

257256
// Get parameters required by these tensors (for caching).
258257
auto params = detail::SolveCUDSSHandle_t<TensorTypeC, TensorTypeA, TensorTypeB>::GetSolveParams(

test/00_sparse/Solve.cu

+114
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
////////////////////////////////////////////////////////////////////////////////
2+
// BSD 3-Clause License
3+
//
4+
// Copyright (c) 2025, NVIDIA Corporation
5+
// All rights reserved.
6+
//
7+
// Redistribution and use in source and binary forms, with or without
8+
// modification, are permitted provided that the following conditions are met:
9+
//
10+
// 1. Redistributions of source code must retain the above copyright notice,
11+
// this list of conditions and the following disclaimer.
12+
//
13+
// 2. Redistributions in binary form must reproduce the above copyright notice,
14+
// this list of conditions and the following disclaimer in the documentation
15+
// and/or other materials provided with the distribution.
16+
//
17+
// 3. Neither the name of the copyright holder nor the names of its
18+
// contributors may be used to endorse or promote products derived from
19+
// this software without specific prior written permission.
20+
//
21+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24+
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25+
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26+
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27+
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28+
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29+
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30+
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31+
/////////////////////////////////////////////////////////////////////////////////
32+
33+
#include "assert.h"
34+
#include "matx.h"
35+
#include "test_types.h"
36+
#include "utilities.h"
37+
#include "gtest/gtest.h"
38+
39+
using namespace matx;
40+
41+
template <typename T> class SolveSparseTest : public ::testing::Test {
42+
protected:
43+
float thresh = 0.001f;
44+
};
45+
46+
template <typename T> class SolveSparseTestsAll : public SolveSparseTest<T> { };
47+
48+
TYPED_TEST_SUITE(SolveSparseTestsAll, MatXFloatNonHalfTypesCUDAExec);
49+
50+
TYPED_TEST(SolveSparseTestsAll, SolveCSR) {
51+
MATX_ENTER_HANDLER();
52+
using TestType = cuda::std::tuple_element_t<0, TypeParam>;
53+
using ExecType = cuda::std::tuple_element_t<1, TypeParam>;
54+
55+
ExecType exec{};
56+
57+
//
58+
// Setup a system of equations AX=Y, where X is the unknown.
59+
// Solve for sparse A in CSR format.
60+
//
61+
// | 1 2 0 0 | | 1 5 | | 5 17 |
62+
// | 0 3 0 0 | x | 2 6 | = | 6 18 |
63+
// | 0 0 4 0 | | 3 7 | | 12 28 |
64+
// | 0 0 0 5 | | 4 8 | | 20 40 |
65+
//
66+
auto A = make_tensor<TestType>({4, 4});
67+
auto X = make_tensor<TestType>({4, 2});
68+
auto E = make_tensor<TestType>({4, 2});
69+
auto Y = make_tensor<TestType>({4, 2});
70+
// Coeffs.
71+
A(0, 0) = static_cast<TestType>(1); A(0, 1) = static_cast<TestType>(2);
72+
A(0, 2) = static_cast<TestType>(0); A(0, 3) = static_cast<TestType>(0);
73+
A(1, 0) = static_cast<TestType>(0); A(1, 1) = static_cast<TestType>(3);
74+
A(1, 2) = static_cast<TestType>(0); A(1, 3) = static_cast<TestType>(0);
75+
A(2, 0) = static_cast<TestType>(0); A(2, 1) = static_cast<TestType>(0);
76+
A(2, 2) = static_cast<TestType>(4); A(2, 3) = static_cast<TestType>(0);
77+
A(3, 0) = static_cast<TestType>(0); A(3, 1) = static_cast<TestType>(0);
78+
A(3, 2) = static_cast<TestType>(0); A(3, 3) = static_cast<TestType>(5);
79+
// Expected.
80+
E(0, 0) = static_cast<TestType>(1); E(0, 1) = static_cast<TestType>(5);
81+
E(1, 0) = static_cast<TestType>(2); E(1, 1) = static_cast<TestType>(6);
82+
E(2, 0) = static_cast<TestType>(3); E(2, 1) = static_cast<TestType>(7);
83+
E(3, 0) = static_cast<TestType>(4); E(3, 1) = static_cast<TestType>(8);
84+
// RHS.
85+
Y(0, 0) = static_cast<TestType>(5); Y(0, 1) = static_cast<TestType>(17);
86+
Y(1, 0) = static_cast<TestType>(6); Y(1, 1) = static_cast<TestType>(18);
87+
Y(2, 0) = static_cast<TestType>(12); Y(2, 1) = static_cast<TestType>(28);
88+
Y(3, 0) = static_cast<TestType>(20); Y(3, 1) = static_cast<TestType>(40);
89+
90+
// Convert dense A to sparse S in CSR format with int-32 indices.
91+
auto S = experimental::make_zero_tensor_csr<TestType, int32_t, int32_t>({4, 4});
92+
(S = dense2sparse(A)).run(exec);
93+
ASSERT_EQ(S.Nse(), 5);
94+
95+
// Solve the system.
96+
(X = solve(S, Y)).run(exec);
97+
98+
// Verify result.
99+
exec.sync();
100+
for (index_t i = 0; i < 4; i++) {
101+
for (index_t j = 0; j < 2; j++) {
102+
if constexpr (is_complex_v<TestType>) {
103+
ASSERT_NEAR(X(i, j).real(), E(i, j).real(), this->thresh);
104+
ASSERT_NEAR(X(i, j).imag(), E(i,j ).imag(), this->thresh);
105+
}
106+
else {
107+
ASSERT_NEAR(X(i, j), E(i, j), this->thresh);
108+
}
109+
110+
}
111+
}
112+
113+
MATX_EXIT_HANDLER();
114+
}

test/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ set (test_sources
3939
01_radar/dct.cu
4040
00_sparse/Basic.cu
4141
00_sparse/Convert.cu
42+
00_sparse/Solve.cu
4243
main.cu
4344
)
4445

0 commit comments

Comments
 (0)