Skip to content

Commit

Permalink
Merge pull request #69 from haasad/sparray-tests
Browse files Browse the repository at this point in the history
Add minimal tests to ensure pypardiso works with scipy sparse arrays
  • Loading branch information
haasad authored Mar 5, 2024
2 parents 7f53e21 + 4cd6034 commit fed4dd5
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 3 deletions.
8 changes: 8 additions & 0 deletions tests/test_basic_solve.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# coding: utf-8
import scipy.sparse as sp

from utils import create_test_A_b_rand, basic_solve
from pypardiso.scipy_aliases import pypardiso_solver

Expand All @@ -13,3 +15,9 @@ def test_bvector_smoketest():
def test_bmatrix_smoketest():
A, b = create_test_A_b_rand(matrix=True)
basic_solve(A, b)


def test_Aarray_smoketest():
A, b = create_test_A_b_rand()
A = sp.csr_array(A)
basic_solve(A, b)
2 changes: 1 addition & 1 deletion tests/test_input_b.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_input_b_sparse():
bsparse = sp.csr_matrix(b)
with pytest.warns(SparseEfficiencyWarning):
x = ps.solve(A, bsparse)
np.testing.assert_array_almost_equal(A*x, b)
np.testing.assert_array_almost_equal(A @ x, b)


def test_input_b_shape():
Expand Down
12 changes: 11 additions & 1 deletion tests/test_scipy_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,17 @@ def test_basic_spsolve_matrix():
np.testing.assert_array_almost_equal(xpp, xscipy)


@pytest.mark.filterwarnings("ignore:splu requires CSC matrix format")
def test_basic_spsolve_sparray():
ps.remove_stored_factorization()
ps.free_memory()
A, b = create_test_A_b_rand(matrix=True)
A = sp.csr_array(A)
xpp = spsolve(A, b)
xscipy = scipyspsolve(A, b)
np.testing.assert_array_almost_equal(xpp, xscipy)


@pytest.mark.filterwarnings("ignore:splu converted its input to CSC format")
def test_basic_factorized():
ps.remove_stored_factorization()
ps.free_memory()
Expand Down
2 changes: 1 addition & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,4 @@ def create_test_A_b_rand(n=1000, density=0.05, matrix=False):

def basic_solve(A, b):
x = ps.solve(A, b)
np.testing.assert_array_almost_equal(A*x, b)
np.testing.assert_array_almost_equal(A @ x, b)

0 comments on commit fed4dd5

Please sign in to comment.