Skip to content

Commit 263e586

Browse files
Create linear_solver.py
1 parent c6c0dd8 commit 263e586

File tree

1 file changed

+116
-0
lines changed

1 file changed

+116
-0
lines changed

linear_solver.py

+116
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
"""Wrappers for linear solver packages to handle factorizing and solving linear systems."""
2+
3+
import time
4+
from typing import Hashable, Dict
5+
6+
import numpy as np
7+
import scipy as sp
8+
import scipy.sparse
9+
from pyMKL import pardisoSolver
10+
11+
12+
# This is an arbitrary limit on the maximum number of factorizations we will store before clearing memory.
13+
_MAX_FACTORIZATIONS: int = 20
14+
15+
16+
class MultipleSystemPardisoSolver(pardisoSolver):
17+
"""A specialized subclass of the pardisoSolver from pyMKL for supporting factorizations of multiple systems.
18+
Note that if the matrix sparsity pattern or size changes, then `clear()` MUST be called before trying to solve with
19+
the new matrix.
20+
Args:
21+
verbose: Whether or not to show verbose output. Defaults to True.
22+
maxfct: The maximum number of factorizations to store before resetting and clearing all memory.
23+
"""
24+
25+
def __init__(self, verbose=True, maxfct: int = _MAX_FACTORIZATIONS):
26+
self.maxfct = maxfct
27+
self._cache: Dict[Hashable, int] = {}
28+
self._next_mnum = 1
29+
self._verbose = verbose
30+
self._must_initialize_super = True
31+
32+
@staticmethod
33+
def _matrix_to_key(matrix: sp.sparse.csr_matrix) -> Hashable:
34+
# We implicitly assume here that all matrices with the same nonzero entries have the same sparsity structure!
35+
return matrix.data.tobytes()
36+
37+
def clear(self):
38+
"""Clear the memory for all matrices and reset the cache."""
39+
super().clear()
40+
self._next_mnum = 1
41+
self._cache = {}
42+
self._must_initialize_super = True
43+
44+
def _initialize_if_needed(self, matrix: sp.sparse.csr_matrix):
45+
if not self._must_initialize_super:
46+
return
47+
old = self.maxfct
48+
super().__init__(matrix, mtype=13) # Set to 13 (complex unsymmetric), which is correct for SC-PML.
49+
self.maxfct = old
50+
51+
self._cache[self._matrix_to_key(matrix)] = self._next_mnum
52+
53+
if self._verbose:
54+
print("Performing a brand-new symbolic factorization...")
55+
start = time.time()
56+
self.run_pardiso(phase=11)
57+
if self._verbose:
58+
print("(took %3.3f seconds)" % (time.time() - start))
59+
60+
self._must_initialize_super = False
61+
62+
def set_matrix(self, matrix: sp.sparse.csr_matrix):
63+
"""Set the matrix to `matrix`, which will perform a new factorization if `matrix` has not been seen before."""
64+
key = self._matrix_to_key(matrix)
65+
if key in self._cache:
66+
self.mnum = self._cache[key]
67+
return
68+
69+
# Otherwise, this is a brand new matrix.
70+
if self._next_mnum > self.maxfct:
71+
if self._verbose:
72+
print("Clearing all factorizations (reached max limit)")
73+
self.clear()
74+
75+
self._initialize_if_needed(matrix)
76+
self.mnum = self._next_mnum
77+
self._cache[key] = self.mnum
78+
self._next_mnum += 1
79+
self._set_pardiso_matrix_data(matrix)
80+
81+
def solve(self, matrix: sp.sparse.csr_matrix, rhs: np.ndarray, transpose: bool = False) -> np.ndarray:
82+
"""Return `matrix` inverse times `rhs`."""
83+
self.set_matrix(matrix)
84+
if transpose:
85+
self.iparm[11] = 2
86+
else:
87+
self.iparm[11] = 0
88+
if self._verbose:
89+
print("Performing a solve...", transpose)
90+
start = time.time()
91+
out = super().solve(rhs).reshape(rhs.shape)
92+
if self._verbose:
93+
print("(took %3.3f seconds)" % (time.time() - start))
94+
return out
95+
96+
def _set_pardiso_matrix_data(self, matrix: sp.sparse.csr_matrix):
97+
A = matrix
98+
# If A is symmetric, store only the upper triangular portion
99+
if self.mtype in [2, -2, 4, -4, 6]:
100+
A = sp.sparse.triu(A, format="csr")
101+
elif self.mtype in [11, 13]:
102+
A = A.tocsr()
103+
104+
if not A.has_sorted_indices:
105+
A.sort_indices()
106+
107+
self.a = A.data
108+
self._MKL_a = self.a.ctypes.data_as(self.ctypes_dtype)
109+
if self._verbose:
110+
print("Performing a brand-new numerical factorization...")
111+
start = time.time()
112+
self.run_pardiso(phase=22)
113+
if self._verbose:
114+
print("(took %3.3f seconds)" % (time.time() - start))
115+
116+
multiple_system_solver = MultipleSystemPardisoSolver(verbose=False)

0 commit comments

Comments
 (0)