|
| 1 | +"""Wrappers for linear solver packages to handle factorizing and solving linear systems.""" |
| 2 | + |
| 3 | +import time |
| 4 | +from typing import Hashable, Dict |
| 5 | + |
| 6 | +import numpy as np |
| 7 | +import scipy as sp |
| 8 | +import scipy.sparse |
| 9 | +from pyMKL import pardisoSolver |
| 10 | + |
| 11 | + |
| 12 | +# This is an arbitrary limit on the maximum number of factorizations we will store before clearing memory. |
| 13 | +_MAX_FACTORIZATIONS: int = 20 |
| 14 | + |
| 15 | + |
| 16 | +class MultipleSystemPardisoSolver(pardisoSolver): |
| 17 | + """A specialized subclass of the pardisoSolver from pyMKL for supporting factorizations of multiple systems. |
| 18 | + Note that if the matrix sparsity pattern or size changes, then `clear()` MUST be called before trying to solve with |
| 19 | + the new matrix. |
| 20 | + Args: |
| 21 | + verbose: Whether or not to show verbose output. Defaults to True. |
| 22 | + maxfct: The maximum number of factorizations to store before resetting and clearing all memory. |
| 23 | + """ |
| 24 | + |
| 25 | + def __init__(self, verbose=True, maxfct: int = _MAX_FACTORIZATIONS): |
| 26 | + self.maxfct = maxfct |
| 27 | + self._cache: Dict[Hashable, int] = {} |
| 28 | + self._next_mnum = 1 |
| 29 | + self._verbose = verbose |
| 30 | + self._must_initialize_super = True |
| 31 | + |
| 32 | + @staticmethod |
| 33 | + def _matrix_to_key(matrix: sp.sparse.csr_matrix) -> Hashable: |
| 34 | + # We implicitly assume here that all matrices with the same nonzero entries have the same sparsity structure! |
| 35 | + return matrix.data.tobytes() |
| 36 | + |
| 37 | + def clear(self): |
| 38 | + """Clear the memory for all matrices and reset the cache.""" |
| 39 | + super().clear() |
| 40 | + self._next_mnum = 1 |
| 41 | + self._cache = {} |
| 42 | + self._must_initialize_super = True |
| 43 | + |
| 44 | + def _initialize_if_needed(self, matrix: sp.sparse.csr_matrix): |
| 45 | + if not self._must_initialize_super: |
| 46 | + return |
| 47 | + old = self.maxfct |
| 48 | + super().__init__(matrix, mtype=13) # Set to 13 (complex unsymmetric), which is correct for SC-PML. |
| 49 | + self.maxfct = old |
| 50 | + |
| 51 | + self._cache[self._matrix_to_key(matrix)] = self._next_mnum |
| 52 | + |
| 53 | + if self._verbose: |
| 54 | + print("Performing a brand-new symbolic factorization...") |
| 55 | + start = time.time() |
| 56 | + self.run_pardiso(phase=11) |
| 57 | + if self._verbose: |
| 58 | + print("(took %3.3f seconds)" % (time.time() - start)) |
| 59 | + |
| 60 | + self._must_initialize_super = False |
| 61 | + |
| 62 | + def set_matrix(self, matrix: sp.sparse.csr_matrix): |
| 63 | + """Set the matrix to `matrix`, which will perform a new factorization if `matrix` has not been seen before.""" |
| 64 | + key = self._matrix_to_key(matrix) |
| 65 | + if key in self._cache: |
| 66 | + self.mnum = self._cache[key] |
| 67 | + return |
| 68 | + |
| 69 | + # Otherwise, this is a brand new matrix. |
| 70 | + if self._next_mnum > self.maxfct: |
| 71 | + if self._verbose: |
| 72 | + print("Clearing all factorizations (reached max limit)") |
| 73 | + self.clear() |
| 74 | + |
| 75 | + self._initialize_if_needed(matrix) |
| 76 | + self.mnum = self._next_mnum |
| 77 | + self._cache[key] = self.mnum |
| 78 | + self._next_mnum += 1 |
| 79 | + self._set_pardiso_matrix_data(matrix) |
| 80 | + |
| 81 | + def solve(self, matrix: sp.sparse.csr_matrix, rhs: np.ndarray, transpose: bool = False) -> np.ndarray: |
| 82 | + """Return `matrix` inverse times `rhs`.""" |
| 83 | + self.set_matrix(matrix) |
| 84 | + if transpose: |
| 85 | + self.iparm[11] = 2 |
| 86 | + else: |
| 87 | + self.iparm[11] = 0 |
| 88 | + if self._verbose: |
| 89 | + print("Performing a solve...", transpose) |
| 90 | + start = time.time() |
| 91 | + out = super().solve(rhs).reshape(rhs.shape) |
| 92 | + if self._verbose: |
| 93 | + print("(took %3.3f seconds)" % (time.time() - start)) |
| 94 | + return out |
| 95 | + |
| 96 | + def _set_pardiso_matrix_data(self, matrix: sp.sparse.csr_matrix): |
| 97 | + A = matrix |
| 98 | + # If A is symmetric, store only the upper triangular portion |
| 99 | + if self.mtype in [2, -2, 4, -4, 6]: |
| 100 | + A = sp.sparse.triu(A, format="csr") |
| 101 | + elif self.mtype in [11, 13]: |
| 102 | + A = A.tocsr() |
| 103 | + |
| 104 | + if not A.has_sorted_indices: |
| 105 | + A.sort_indices() |
| 106 | + |
| 107 | + self.a = A.data |
| 108 | + self._MKL_a = self.a.ctypes.data_as(self.ctypes_dtype) |
| 109 | + if self._verbose: |
| 110 | + print("Performing a brand-new numerical factorization...") |
| 111 | + start = time.time() |
| 112 | + self.run_pardiso(phase=22) |
| 113 | + if self._verbose: |
| 114 | + print("(took %3.3f seconds)" % (time.time() - start)) |
| 115 | + |
| 116 | +multiple_system_solver = MultipleSystemPardisoSolver(verbose=False) |
0 commit comments