144
144
import abc
145
145
import warnings
146
146
from collections .abc import Sequence , Generator
147
+ from typing import Any , ClassVar
147
148
from itertools import chain
148
149
from functools import partial
149
150
import math
150
151
151
152
import numpy as np
152
153
import numpy .typing as npt
153
- from numpy .ma .core import logical_not
154
154
155
155
from .observables import ElementObservable
156
156
from .observables import TrajectoryObservable , OrbitObservable , LatticeObservable
157
157
from .observables import LocalOpticsObservable , GlobalOpticsObservable
158
158
from .observablelist import ObservableList
159
- from ..lattice import AtError , AtWarning , Lattice , Refpts , AxisDef , plane_ , All
160
- from ..lattice import Monitor , checkattr
159
+ from ..lattice import AtError , AtWarning , Refpts , Uint32Refpts , All
160
+ from ..lattice import AxisDef , plane_ , Lattice , Monitor , checkattr
161
161
from ..lattice .lattice_variables import RefptsVariable
162
162
from ..lattice .variables import VariableList
163
163
164
+ FloatArray = npt .NDArray [np .float64 ]
165
+
164
166
_orbit_correctors = checkattr ("KickAngle" )
165
167
166
168
_globring : Lattice | None = None
@@ -221,16 +223,19 @@ def _resp_fork(variables: VariableList, **kwargs):
221
223
222
224
class _SvdSolver (abc .ABC ):
223
225
"""SVD solver for response matrices."""
224
- _response : npt .NDArray [float ] | None
226
+
227
+ _shape : tuple [int , int ]
228
+ _obsmask : npt .NDArray [bool ]
229
+ _varmask : npt .NDArray [bool ]
230
+ _response : FloatArray | None = None
231
+ v : FloatArray | None = None
232
+ uh : FloatArray | None = None
233
+ singular_values : FloatArray | None = None
225
234
226
235
def __init__ (self , nobs : int , nvar : int ):
227
236
self ._shape = (nobs , nvar )
228
- self ._response = None
229
237
self ._obsmask = np .ones (nobs , dtype = bool )
230
238
self ._varmask = np .ones (nvar , dtype = bool )
231
- self .v = None
232
- self .uh = None
233
- self .singular_values = None
234
239
235
240
def reset_vars (self ):
236
241
"""Reset the variable exclusion mask: enable all variables"""
@@ -286,15 +291,15 @@ def check_norm(self) -> tuple[np.ndarray, np.ndarray]:
286
291
return obs , var
287
292
288
293
@property
289
- def response (self ) -> npt . NDArray [ float ] :
294
+ def response (self ) -> FloatArray :
290
295
"""Response matrix."""
291
296
resp = self ._response
292
297
if resp is None :
293
298
raise AtError ("No matrix yet: run build() or load() first" )
294
299
return resp
295
300
296
301
@response .setter
297
- def response (self , response : npt . NDArray [ float ] ) -> None :
302
+ def response (self , response : FloatArray ) -> None :
298
303
l1 , c1 = self ._shape
299
304
l2 , c2 = response .shape
300
305
if l1 != l1 or c1 != c2 :
@@ -308,7 +313,7 @@ def weighted_response(self) -> np.ndarray:
308
313
"""Weighted response matrix."""
309
314
return self .response * (self .varweights / self .obsweights .reshape (- 1 , 1 ))
310
315
311
- def correction_matrix (self , nvals : int | None = None ) -> npt . NDArray [ float ] :
316
+ def correction_matrix (self , nvals : int | None = None ) -> FloatArray :
312
317
"""Return the correction matrix (pseudo-inverse of the response matrix).
313
318
314
319
Args:
@@ -322,14 +327,14 @@ def correction_matrix(self, nvals: int | None = None) -> npt.NDArray[float]:
322
327
self .solve ()
323
328
if nvals is None :
324
329
nvals = len (self .singular_values )
325
- cormat = np .zeros (self ._shape )
326
- selected = np .ix_ (self ._obsmask , self ._varmask )
330
+ cormat = np .zeros (self ._shape [:: - 1 ] )
331
+ selected = np .ix_ (self ._varmask , self ._obsmask )
327
332
cormat [selected ] = self .v [:, :nvals ] @ self .uh [:nvals , :]
328
333
return cormat
329
334
330
335
def get_correction (
331
336
self , observed : np .ndarray , nvals : int | None = None
332
- ) -> npt . NDArray [ float ] :
337
+ ) -> FloatArray :
333
338
"""Compute the correction of the given observation.
334
339
335
340
Args:
@@ -375,8 +380,11 @@ class ResponseMatrix(_SvdSolver):
375
380
of their :py:class:`~.variables.VariableBase`\ s and :py:class:`.Observable`\s to
376
381
produce combined responses.
377
382
"""
383
+
384
+ ring : Lattice
378
385
variables : VariableList #: List of matrix :py:class:`Variable <.VariableBase>`\ s
379
386
observables : ObservableList #: List of matrix :py:class:`.Observable`\s
387
+ eval_args : dict [str , Any ] = {}
380
388
381
389
def __init__ (
382
390
self ,
@@ -404,7 +412,6 @@ def limits(obslist):
404
412
self .ring = ring
405
413
self .variables = variables
406
414
self .observables = observables
407
- self .buildargs = {}
408
415
variables .get (ring = ring , initial = True )
409
416
observables .evaluate (ring = ring , initial = True )
410
417
super ().__init__ (len (observables .flat_values ), len (variables ))
@@ -418,7 +425,7 @@ def __add__(self, other: ResponseMatrix):
418
425
return ResponseMatrix (
419
426
self .ring ,
420
427
VariableList (self .variables + other .variables ),
421
- self .observables + other .observables
428
+ self .observables + other .observables ,
422
429
)
423
430
424
431
def __str__ (self ):
@@ -460,7 +467,7 @@ def correct(
460
467
self .variables .get (ring = ring )
461
468
sumcorr = np .array ([0.0 ])
462
469
for _ in range (niter ):
463
- obs .evaluate (ring , ** self .buildargs )
470
+ obs .evaluate (ring , ** self .eval_args )
464
471
corr = self .get_correction (obs .flat_deviations , nvals = nvals )
465
472
sumcorr = sumcorr + corr # non-broadcastable sumcorr
466
473
if apply :
@@ -473,7 +480,7 @@ def build_tracking(
473
480
pool_size : int | None = None ,
474
481
start_method : str | None = None ,
475
482
** kwargs ,
476
- ) -> np . ndarray :
483
+ ) -> FloatArray :
477
484
"""Build the response matrix.
478
485
479
486
Args:
@@ -498,8 +505,9 @@ def build_tracking(
498
505
Returns:
499
506
response: Response matrix
500
507
"""
501
- self .buildargs = kwargs
508
+ self .eval_args = kwargs
502
509
self .observables .evaluate (self .ring )
510
+ ring = self .ring .deepcopy ()
503
511
504
512
if use_mp :
505
513
global _globring
@@ -509,35 +517,26 @@ def build_tracking(
509
517
pool_size = min (len (self .variables ), os .cpu_count ())
510
518
obschunks = sequence_split (self .variables , pool_size )
511
519
if ctx .get_start_method () == "fork" :
512
- _globring = self . ring
520
+ _globring = ring
513
521
_globobs = self .observables
514
522
_single_resp = partial (_resp_fork , ** kwargs )
515
-
516
523
else :
517
- _single_resp = partial (_resp , self . ring , self .observables , ** kwargs )
524
+ _single_resp = partial (_resp , ring , self .observables , ** kwargs )
518
525
with concurrent .futures .ProcessPoolExecutor (
519
526
max_workers = pool_size ,
520
527
mp_context = ctx ,
521
528
) as pool :
522
529
results = list (chain (* pool .map (_single_resp , obschunks )))
523
- # with ctx.Pool(pool_size) as pool:
524
- # results = pool.map(_single_resp, self.variables)
525
530
_globring = None
526
531
_globobs = None
527
532
else :
528
- ring = self .ring
529
- boolrefs = ring .get_bool_index (None )
530
- for var in self .variables :
531
- boolrefs |= ring .get_bool_index (var .refpts )
532
-
533
- ring = ring .replace (boolrefs )
534
- results = _resp (ring .deepcopy (), self .observables , self .variables , ** kwargs )
533
+ results = _resp (ring , self .observables , self .variables , ** kwargs )
535
534
536
535
resp = np .stack (results , axis = - 1 )
537
536
self .response = resp
538
537
return resp
539
538
540
- def build_analytical (self ) -> np . ndarray :
539
+ def build_analytical (self ) -> FloatArray :
541
540
"""Build the response matrix."""
542
541
raise NotImplementedError (
543
542
f"build_analytical not implemented for { self .__class__ .__name__ } "
@@ -626,15 +625,15 @@ def exclude_vars(self, *varid: int | str) -> None:
626
625
>>> resp.exclude_vars(0, "var1", -1)
627
626
628
627
Exclude the 1st variable, the variable named "var1" and the last variable.
629
- """
628
+ """
630
629
nameset = set (nm for nm in varid if isinstance (nm , str ))
631
630
varidx = [nm for nm in varid if not isinstance (nm , str )]
632
631
mask = np .array ([var .name in nameset for var in self .variables ])
633
632
mask [varidx ] = True
634
633
miss = nameset - {var .name for var , ok in zip (self .variables , mask ) if ok }
635
634
if miss :
636
635
raise ValueError (f"Unknown variables: { miss } " )
637
- self ._varmask &= logical_not (mask )
636
+ self ._varmask &= np . logical_not (mask )
638
637
639
638
@property
640
639
def excluded_vars (self ) -> list :
@@ -675,6 +674,9 @@ class OrbitResponseMatrix(ResponseMatrix):
675
674
... )
676
675
"""
677
676
677
+ bpmrefs : Uint32Refpts
678
+ steerrefs : Uint32Refpts
679
+
678
680
def __init__ (
679
681
self ,
680
682
ring : Lattice ,
@@ -759,9 +761,7 @@ def set_norm():
759
761
cavd , stsw = set_norm ()
760
762
761
763
# Observables
762
- bpms = OrbitObservable (
763
- bpmrefs , axis = 2 * pl , target = bpmtarget , weight = bpmweight
764
- )
764
+ bpms = OrbitObservable (bpmrefs , axis = 2 * pl , target = bpmtarget , weight = bpmweight )
765
765
observables = ObservableList ([bpms ])
766
766
if steersum :
767
767
# noinspection PyUnboundLocalVariable
@@ -876,7 +876,7 @@ def normalise(
876
876
self .stsumweight * normobs [- 1 ] / np .mean (normobs [:- 1 ]) / stsum_ampl
877
877
)
878
878
879
- def build_analytical (self , ** kwargs ) -> np . ndarray :
879
+ def build_analytical (self , ** kwargs ) -> FloatArray :
880
880
"""Build analytically the response matrix.
881
881
882
882
Keyword Args:
@@ -995,6 +995,10 @@ class TrajectoryResponseMatrix(ResponseMatrix):
995
995
996
996
"""
997
997
998
+ bpmrefs : Uint32Refpts
999
+ steerrefs : Uint32Refpts
1000
+ _default_twiss_in : ClassVar [dict ] = {"beta" : np .ones (2 ), "alpha" : np .zeros (2 )}
1001
+
998
1002
def __init__ (
999
1003
self ,
1000
1004
ring : Lattice ,
@@ -1047,7 +1051,7 @@ def steerer(ik, delta):
1047
1051
self .nbsteers = nbsteers
1048
1052
self .bpmrefs = ring .get_uint32_index (bpmrefs )
1049
1053
1050
- def build_analytical (self , ** kwargs ) -> np . ndarray :
1054
+ def build_analytical (self , ** kwargs ) -> FloatArray :
1051
1055
"""Build analytically the response matrix.
1052
1056
1053
1057
Keyword Args:
@@ -1061,7 +1065,7 @@ def build_analytical(self, **kwargs) -> np.ndarray:
1061
1065
"""
1062
1066
ring = self .ring
1063
1067
pl = self .plane
1064
- twiss_in = kwargs . pop ("twiss_in" , { "beta" : np . ones ( 2 ), "alpha" : np . zeros ( 2 )} )
1068
+ twiss_in = self . eval_args . get ("twiss_in" , self . _default_twiss_in )
1065
1069
_ , _ , elemdata = ring .linopt6 (All , twiss_in = twiss_in , ** kwargs )
1066
1070
dataj = elemdata [self .bpmrefs ]
1067
1071
dataw = elemdata [self .steerrefs ]
0 commit comments