Skip to content

Commit ddf1512

Browse files
committed
typing
1 parent 416befb commit ddf1512

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

docs/p/notebooks/response_matrices.ipynb

+2-2
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,7 @@
550550
"name": "stderr",
551551
"output_type": "stream",
552552
"text": [
553-
"/Users/laurent/dev/libraries/at/pyat/at/latticetools/response_matrix.py:572: AtWarning: No new excluded value\n",
553+
"/Users/laurent/dev/libraries/at/pyat/at/latticetools/response_matrix.py:580: AtWarning: No new excluded value\n",
554554
" warnings.warn(AtWarning(\"No new excluded value\"), stacklevel=1)\n"
555555
]
556556
}
@@ -591,7 +591,7 @@
591591
"name": "stderr",
592592
"output_type": "stream",
593593
"text": [
594-
"/Users/laurent/dev/libraries/at/pyat/at/latticetools/response_matrix.py:572: AtWarning: No new excluded value\n",
594+
"/Users/laurent/dev/libraries/at/pyat/at/latticetools/response_matrix.py:580: AtWarning: No new excluded value\n",
595595
" warnings.warn(AtWarning(\"No new excluded value\"), stacklevel=1)\n"
596596
]
597597
}

pyat/at/latticetools/response_matrix.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@
149149
import math
150150

151151
import numpy as np
152+
import numpy.typing as npt
152153
from numpy.ma.core import logical_not
153154

154155
from .observables import ElementObservable
@@ -220,6 +221,7 @@ def _resp_fork(variables: VariableList, **kwargs):
220221

221222
class _SvdSolver(abc.ABC):
222223
"""SVD solver for response matrices."""
224+
_response: npt.NDArray[float] | None
223225

224226
def __init__(self, nobs: int, nvar: int):
225227
self._shape = (nobs, nvar)
@@ -284,15 +286,15 @@ def check_norm(self) -> tuple[np.ndarray, np.ndarray]:
284286
return obs, var
285287

286288
@property
287-
def response(self) -> np.ndarray:
289+
def response(self) -> npt.NDArray[float]:
288290
"""Response matrix."""
289291
resp = self._response
290292
if resp is None:
291293
raise AtError("No matrix yet: run build() or load() first")
292294
return resp
293295

294296
@response.setter
295-
def response(self, response: np.ndarray) -> None:
297+
def response(self, response: npt.NDArray[float]) -> None:
296298
l1, c1 = self._shape
297299
l2, c2 = response.shape
298300
if l1 != l1 or c1 != c2:
@@ -306,7 +308,7 @@ def weighted_response(self) -> np.ndarray:
306308
"""Weighted response matrix."""
307309
return self.response * (self.varweights / self.obsweights.reshape(-1, 1))
308310

309-
def correction_matrix(self, nvals: int | None = None) -> np.ndarray:
311+
def correction_matrix(self, nvals: int | None = None) -> npt.NDArray[float]:
310312
"""Return the correction matrix (pseudo-inverse of the response matrix).
311313
312314
Args:
@@ -327,7 +329,7 @@ def correction_matrix(self, nvals: int | None = None) -> np.ndarray:
327329

328330
def get_correction(
329331
self, observed: np.ndarray, nvals: int | None = None
330-
) -> np.ndarray:
332+
) -> npt.NDArray[float]:
331333
"""Compute the correction of the given observation.
332334
333335
Args:

0 commit comments

Comments
 (0)