Skip to content

Commit 52e055f

Browse files
committed
bug fix
1 parent ddf1512 commit 52e055f

File tree

4 files changed

+76
-202
lines changed

4 files changed

+76
-202
lines changed

docs/p/notebooks/response_matrices.ipynb

+30-161
Large diffs are not rendered by default.

pyat/at/lattice/utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050

5151

5252
__all__ = ['All', 'End', 'AtError', 'AtWarning', 'axis_descr',
53+
'BoolRefpts', 'Uint32Refpts',
5354
'check_radiation', 'check_6d',
5455
'set_radiation', 'set_6d',
5556
'make_copy', 'uint32_refpts', 'bool_refpts',

pyat/at/latticetools/response_matrix.py

+44-40
Original file line numberDiff line numberDiff line change
@@ -144,23 +144,25 @@
144144
import abc
145145
import warnings
146146
from collections.abc import Sequence, Generator
147+
from typing import Any, ClassVar
147148
from itertools import chain
148149
from functools import partial
149150
import math
150151

151152
import numpy as np
152153
import numpy.typing as npt
153-
from numpy.ma.core import logical_not
154154

155155
from .observables import ElementObservable
156156
from .observables import TrajectoryObservable, OrbitObservable, LatticeObservable
157157
from .observables import LocalOpticsObservable, GlobalOpticsObservable
158158
from .observablelist import ObservableList
159-
from ..lattice import AtError, AtWarning, Lattice, Refpts, AxisDef, plane_, All
160-
from ..lattice import Monitor, checkattr
159+
from ..lattice import AtError, AtWarning, Refpts, Uint32Refpts, All
160+
from ..lattice import AxisDef, plane_, Lattice, Monitor, checkattr
161161
from ..lattice.lattice_variables import RefptsVariable
162162
from ..lattice.variables import VariableList
163163

164+
FloatArray = npt.NDArray[np.float64]
165+
164166
_orbit_correctors = checkattr("KickAngle")
165167

166168
_globring: Lattice | None = None
@@ -221,16 +223,19 @@ def _resp_fork(variables: VariableList, **kwargs):
221223

222224
class _SvdSolver(abc.ABC):
223225
"""SVD solver for response matrices."""
224-
_response: npt.NDArray[float] | None
226+
227+
_shape: tuple[int, int]
228+
_obsmask: npt.NDArray[bool]
229+
_varmask: npt.NDArray[bool]
230+
_response: FloatArray | None = None
231+
v: FloatArray | None = None
232+
uh: FloatArray | None = None
233+
singular_values: FloatArray | None = None
225234

226235
def __init__(self, nobs: int, nvar: int):
227236
self._shape = (nobs, nvar)
228-
self._response = None
229237
self._obsmask = np.ones(nobs, dtype=bool)
230238
self._varmask = np.ones(nvar, dtype=bool)
231-
self.v = None
232-
self.uh = None
233-
self.singular_values = None
234239

235240
def reset_vars(self):
236241
"""Reset the variable exclusion mask: enable all variables"""
@@ -286,15 +291,15 @@ def check_norm(self) -> tuple[np.ndarray, np.ndarray]:
286291
return obs, var
287292

288293
@property
289-
def response(self) -> npt.NDArray[float]:
294+
def response(self) -> FloatArray:
290295
"""Response matrix."""
291296
resp = self._response
292297
if resp is None:
293298
raise AtError("No matrix yet: run build() or load() first")
294299
return resp
295300

296301
@response.setter
297-
def response(self, response: npt.NDArray[float]) -> None:
302+
def response(self, response: FloatArray) -> None:
298303
l1, c1 = self._shape
299304
l2, c2 = response.shape
300305
if l1 != l1 or c1 != c2:
@@ -308,7 +313,7 @@ def weighted_response(self) -> np.ndarray:
308313
"""Weighted response matrix."""
309314
return self.response * (self.varweights / self.obsweights.reshape(-1, 1))
310315

311-
def correction_matrix(self, nvals: int | None = None) -> npt.NDArray[float]:
316+
def correction_matrix(self, nvals: int | None = None) -> FloatArray:
312317
"""Return the correction matrix (pseudo-inverse of the response matrix).
313318
314319
Args:
@@ -322,14 +327,14 @@ def correction_matrix(self, nvals: int | None = None) -> npt.NDArray[float]:
322327
self.solve()
323328
if nvals is None:
324329
nvals = len(self.singular_values)
325-
cormat = np.zeros(self._shape)
326-
selected = np.ix_(self._obsmask, self._varmask)
330+
cormat = np.zeros(self._shape[::-1])
331+
selected = np.ix_(self._varmask, self._obsmask)
327332
cormat[selected] = self.v[:, :nvals] @ self.uh[:nvals, :]
328333
return cormat
329334

330335
def get_correction(
331336
self, observed: np.ndarray, nvals: int | None = None
332-
) -> npt.NDArray[float]:
337+
) -> FloatArray:
333338
"""Compute the correction of the given observation.
334339
335340
Args:
@@ -375,8 +380,11 @@ class ResponseMatrix(_SvdSolver):
375380
of their :py:class:`~.variables.VariableBase`\ s and :py:class:`.Observable`\s to
376381
produce combined responses.
377382
"""
383+
384+
ring: Lattice
378385
variables: VariableList #: List of matrix :py:class:`Variable <.VariableBase>`\ s
379386
observables: ObservableList #: List of matrix :py:class:`.Observable`\s
387+
eval_args: dict[str, Any] = {}
380388

381389
def __init__(
382390
self,
@@ -404,7 +412,6 @@ def limits(obslist):
404412
self.ring = ring
405413
self.variables = variables
406414
self.observables = observables
407-
self.buildargs = {}
408415
variables.get(ring=ring, initial=True)
409416
observables.evaluate(ring=ring, initial=True)
410417
super().__init__(len(observables.flat_values), len(variables))
@@ -418,7 +425,7 @@ def __add__(self, other: ResponseMatrix):
418425
return ResponseMatrix(
419426
self.ring,
420427
VariableList(self.variables + other.variables),
421-
self.observables + other.observables
428+
self.observables + other.observables,
422429
)
423430

424431
def __str__(self):
@@ -460,7 +467,7 @@ def correct(
460467
self.variables.get(ring=ring)
461468
sumcorr = np.array([0.0])
462469
for _ in range(niter):
463-
obs.evaluate(ring, **self.buildargs)
470+
obs.evaluate(ring, **self.eval_args)
464471
corr = self.get_correction(obs.flat_deviations, nvals=nvals)
465472
sumcorr = sumcorr + corr # non-broadcastable sumcorr
466473
if apply:
@@ -473,7 +480,7 @@ def build_tracking(
473480
pool_size: int | None = None,
474481
start_method: str | None = None,
475482
**kwargs,
476-
) -> np.ndarray:
483+
) -> FloatArray:
477484
"""Build the response matrix.
478485
479486
Args:
@@ -498,8 +505,9 @@ def build_tracking(
498505
Returns:
499506
response: Response matrix
500507
"""
501-
self.buildargs = kwargs
508+
self.eval_args = kwargs
502509
self.observables.evaluate(self.ring)
510+
ring = self.ring.deepcopy()
503511

504512
if use_mp:
505513
global _globring
@@ -509,35 +517,26 @@ def build_tracking(
509517
pool_size = min(len(self.variables), os.cpu_count())
510518
obschunks = sequence_split(self.variables, pool_size)
511519
if ctx.get_start_method() == "fork":
512-
_globring = self.ring
520+
_globring = ring
513521
_globobs = self.observables
514522
_single_resp = partial(_resp_fork, **kwargs)
515-
516523
else:
517-
_single_resp = partial(_resp, self.ring, self.observables, **kwargs)
524+
_single_resp = partial(_resp, ring, self.observables, **kwargs)
518525
with concurrent.futures.ProcessPoolExecutor(
519526
max_workers=pool_size,
520527
mp_context=ctx,
521528
) as pool:
522529
results = list(chain(*pool.map(_single_resp, obschunks)))
523-
# with ctx.Pool(pool_size) as pool:
524-
# results = pool.map(_single_resp, self.variables)
525530
_globring = None
526531
_globobs = None
527532
else:
528-
ring = self.ring
529-
boolrefs = ring.get_bool_index(None)
530-
for var in self.variables:
531-
boolrefs |= ring.get_bool_index(var.refpts)
532-
533-
ring = ring.replace(boolrefs)
534-
results = _resp(ring.deepcopy(), self.observables, self.variables, **kwargs)
533+
results = _resp(ring, self.observables, self.variables, **kwargs)
535534

536535
resp = np.stack(results, axis=-1)
537536
self.response = resp
538537
return resp
539538

540-
def build_analytical(self) -> np.ndarray:
539+
def build_analytical(self) -> FloatArray:
541540
"""Build the response matrix."""
542541
raise NotImplementedError(
543542
f"build_analytical not implemented for {self.__class__.__name__}"
@@ -626,15 +625,15 @@ def exclude_vars(self, *varid: int | str) -> None:
626625
>>> resp.exclude_vars(0, "var1", -1)
627626
628627
Exclude the 1st variable, the variable named "var1" and the last variable.
629-
"""
628+
"""
630629
nameset = set(nm for nm in varid if isinstance(nm, str))
631630
varidx = [nm for nm in varid if not isinstance(nm, str)]
632631
mask = np.array([var.name in nameset for var in self.variables])
633632
mask[varidx] = True
634633
miss = nameset - {var.name for var, ok in zip(self.variables, mask) if ok}
635634
if miss:
636635
raise ValueError(f"Unknown variables: {miss}")
637-
self._varmask &= logical_not(mask)
636+
self._varmask &= np.logical_not(mask)
638637

639638
@property
640639
def excluded_vars(self) -> list:
@@ -675,6 +674,9 @@ class OrbitResponseMatrix(ResponseMatrix):
675674
... )
676675
"""
677676

677+
bpmrefs: Uint32Refpts
678+
steerrefs: Uint32Refpts
679+
678680
def __init__(
679681
self,
680682
ring: Lattice,
@@ -759,9 +761,7 @@ def set_norm():
759761
cavd, stsw = set_norm()
760762

761763
# Observables
762-
bpms = OrbitObservable(
763-
bpmrefs, axis=2 * pl, target=bpmtarget, weight=bpmweight
764-
)
764+
bpms = OrbitObservable(bpmrefs, axis=2 * pl, target=bpmtarget, weight=bpmweight)
765765
observables = ObservableList([bpms])
766766
if steersum:
767767
# noinspection PyUnboundLocalVariable
@@ -876,7 +876,7 @@ def normalise(
876876
self.stsumweight * normobs[-1] / np.mean(normobs[:-1]) / stsum_ampl
877877
)
878878

879-
def build_analytical(self, **kwargs) -> np.ndarray:
879+
def build_analytical(self, **kwargs) -> FloatArray:
880880
"""Build analytically the response matrix.
881881
882882
Keyword Args:
@@ -995,6 +995,10 @@ class TrajectoryResponseMatrix(ResponseMatrix):
995995
996996
"""
997997

998+
bpmrefs: Uint32Refpts
999+
steerrefs: Uint32Refpts
1000+
_default_twiss_in: ClassVar[dict] = {"beta": np.ones(2), "alpha": np.zeros(2)}
1001+
9981002
def __init__(
9991003
self,
10001004
ring: Lattice,
@@ -1047,7 +1051,7 @@ def steerer(ik, delta):
10471051
self.nbsteers = nbsteers
10481052
self.bpmrefs = ring.get_uint32_index(bpmrefs)
10491053

1050-
def build_analytical(self, **kwargs) -> np.ndarray:
1054+
def build_analytical(self, **kwargs) -> FloatArray:
10511055
"""Build analytically the response matrix.
10521056
10531057
Keyword Args:
@@ -1061,7 +1065,7 @@ def build_analytical(self, **kwargs) -> np.ndarray:
10611065
"""
10621066
ring = self.ring
10631067
pl = self.plane
1064-
twiss_in = kwargs.pop("twiss_in", {"beta": np.ones(2), "alpha": np.zeros(2)})
1068+
twiss_in = self.eval_args.get("twiss_in", self._default_twiss_in)
10651069
_, _, elemdata = ring.linopt6(All, twiss_in=twiss_in, **kwargs)
10661070
dataj = elemdata[self.bpmrefs]
10671071
dataw = elemdata[self.steerrefs]

pyat/at/plot/response_matrix.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def plot_obs_analysis(
6868
if resp.singular_values is None:
6969
resp.solve()
7070
obs = resp.observables
71-
obs.evaluate(lattice, **resp.buildargs)
71+
obs.evaluate(lattice, **resp.eval_args)
7272
corr = resp.uh @ obs.flat_deviations
7373
if ax is None:
7474
fig, ax = plt.subplots()

0 commit comments

Comments
 (0)