149
149
import math
150
150
151
151
import numpy as np
152
+ import numpy .typing as npt
152
153
from numpy .ma .core import logical_not
153
154
154
155
from .observables import ElementObservable
@@ -220,6 +221,7 @@ def _resp_fork(variables: VariableList, **kwargs):
220
221
221
222
class _SvdSolver (abc .ABC ):
222
223
"""SVD solver for response matrices."""
224
+ _response : npt .NDArray [float ] | None
223
225
224
226
def __init__ (self , nobs : int , nvar : int ):
225
227
self ._shape = (nobs , nvar )
@@ -284,15 +286,15 @@ def check_norm(self) -> tuple[np.ndarray, np.ndarray]:
284
286
return obs , var
285
287
286
288
@property
287
- def response (self ) -> np . ndarray :
289
+ def response (self ) -> npt . NDArray [ float ] :
288
290
"""Response matrix."""
289
291
resp = self ._response
290
292
if resp is None :
291
293
raise AtError ("No matrix yet: run build() or load() first" )
292
294
return resp
293
295
294
296
@response .setter
295
- def response (self , response : np . ndarray ) -> None :
297
+ def response (self , response : npt . NDArray [ float ] ) -> None :
296
298
l1 , c1 = self ._shape
297
299
l2 , c2 = response .shape
298
300
if l1 != l1 or c1 != c2 :
@@ -306,7 +308,7 @@ def weighted_response(self) -> np.ndarray:
306
308
"""Weighted response matrix."""
307
309
return self .response * (self .varweights / self .obsweights .reshape (- 1 , 1 ))
308
310
309
- def correction_matrix (self , nvals : int | None = None ) -> np . ndarray :
311
+ def correction_matrix (self , nvals : int | None = None ) -> npt . NDArray [ float ] :
310
312
"""Return the correction matrix (pseudo-inverse of the response matrix).
311
313
312
314
Args:
@@ -327,7 +329,7 @@ def correction_matrix(self, nvals: int | None = None) -> np.ndarray:
327
329
328
330
def get_correction (
329
331
self , observed : np .ndarray , nvals : int | None = None
330
- ) -> np . ndarray :
332
+ ) -> npt . NDArray [ float ] :
331
333
"""Compute the correction of the given observation.
332
334
333
335
Args:
0 commit comments