From e550f1d236d6f84bfb18abfe3526d9c30388b32a Mon Sep 17 00:00:00 2001 From: ksagiyam Date: Wed, 20 Dec 2023 23:51:50 +0000 Subject: [PATCH] sparsity: add test --- test/unit/test_matrices.py | 40 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/test/unit/test_matrices.py b/test/unit/test_matrices.py index 306637ee7..4f8ab1d1e 100644 --- a/test/unit/test_matrices.py +++ b/test/unit/test_matrices.py @@ -39,6 +39,7 @@ from pyop2 import op2 from pyop2.exceptions import MapValueError, ModeValueError from pyop2.mpi import COMM_WORLD +from pyop2.datatypes import IntType from petsc4py.PETSc import ScalarType @@ -941,6 +942,45 @@ def test_assemble_mixed_rhs_vector(self, mset, mmap, mvdat): assert_allclose(dat[1].data_ro, exp, eps) +def test_matrices_sparsity_blockwise_specification(): + # + # 0 1 2 3 nodesetA + # x----x----x----x + # 0 1 2 setA + # + # 0 1 2 nodesetB + # x----x----x + # 0 1 setB + # + # 0 1 2 3 | 0 1 2 + # 0 x | + # 1 x | x x + # 2 x | x x x + # 3 x | x x sparsity + # ----------+------ + # 0 x x | x + # 1 x x x | x + # 2 x x | x + # + arity = 2 + setA = op2.Set(3) + nodesetA = op2.Set(4) + setB = op2.Set(2) + nodesetB = op2.Set(3) + nodesetAB = op2.MixedSet((nodesetA, nodesetB)) + datasetAB = nodesetAB ** 1 + mapA = op2.Map(setA, nodesetA, arity, values=[[0, 1], [1, 2], [2, 3]]) + mapB = op2.Map(setB, nodesetB, arity, values=[[0, 1], [1, 2]]) + mapBA = op2.Map(setB, setA, 1, values=[1, 2]) + mapAB = op2.Map(setA, setB, 1, values=[-1, 0, 1]) # "inverse" map + s = op2.Sparsity((datasetAB, datasetAB), {(1, 0): [(mapB, op2.ComposedMap(mapA, mapBA), None)], + (0, 1): [(mapA, op2.ComposedMap(mapB, mapAB), None)]}) + assert np.all(s._blocks[0][0].nnz == np.array([1, 1, 1, 1], dtype=IntType)) + assert np.all(s._blocks[0][1].nnz == np.array([0, 2, 3, 2], dtype=IntType)) + assert np.all(s._blocks[1][0].nnz == np.array([2, 3, 2], dtype=IntType)) + assert np.all(s._blocks[1][1].nnz == np.array([1, 1, 1], dtype=IntType)) + + if __name__ == '__main__': import os pytest.main(os.path.abspath(__file__))