66
66
from typing import Tuple
67
67
68
68
import numpy as np
69
+ import numpy .typing as npt
69
70
70
71
from ..lattice import AtError , AxisDef , axis_ , plane_
71
72
from ..lattice import Lattice , Refpts , End
@@ -204,8 +205,8 @@ def __init__(
204
205
fun : Callable ,
205
206
* args ,
206
207
name : str | None = None ,
207
- target = None ,
208
- weight = 1.0 ,
208
+ target : npt . ArrayLike | None = None ,
209
+ weight : npt . ArrayLike = 1.0 ,
209
210
bounds = (0.0 , 0.0 ),
210
211
needs : Set [Need ] | None = None ,
211
212
** kwargs ,
@@ -256,12 +257,12 @@ def __init__(
256
257
self .fun : Callable = fun #: Evaluation function
257
258
self .needs : Set [Need ] = needs or set () #: Set of requirements
258
259
self .name : str = name #: Observable name
259
- self .target = target #: Target value
260
- self .w = weight
260
+ self .target : npt . ArrayLike | None = target #: Target value
261
+ self .w : npt . ArrayLike = weight
261
262
self .lbound , self .ubound = bounds
262
- self .initial = None
263
- self ._value = None
264
- self ._shape = None
263
+ self .initial : npt . NDArray [ float ] | None = None
264
+ self ._value : npt . NDArray [ float ] | Exception | None = None
265
+ self ._shape : tuple [ int , ...] | None = None
265
266
self .args = args
266
267
self .kwargs = kwargs
267
268
@@ -299,23 +300,25 @@ def _all_lines(self):
299
300
vnow = self ._value
300
301
if vnow is None or isinstance (vnow , Exception ):
301
302
deviation = None
302
- else :
303
- deviation = self .deviation
304
- if self .target is None :
305
303
vmin = None
306
304
vmax = None
307
305
else :
308
- target = np .broadcast_to (self .target , vnow .shape )
309
- vmin = target + self .lbound
310
- vmax = target + self .ubound
306
+ deviation = self .deviation
307
+ if self .target is None :
308
+ vmin = None
309
+ vmax = None
310
+ else :
311
+ target = np .broadcast_to (self .target , vnow .shape ) # type: ignore
312
+ vmin = target + self .lbound
313
+ vmax = target + self .ubound
311
314
values = self ._line ("" , self .initial , vnow , vmin , vmax , deviation )
312
315
return "\n " .join ((self .name , values ))
313
316
314
317
def _setup (self , ring : Lattice ):
315
318
"""Setup function called when the observable is added to a list."""
316
319
pass
317
320
318
- def evaluate (self , * data , initial : bool = False ):
321
+ def evaluate (self , * data , initial : bool = False ) -> npt . NDArray [ float ] | Exception :
319
322
"""Compute and store the value of the observable.
320
323
321
324
The direct evaluation of a single :py:class:`Observable` is normally
@@ -327,6 +330,9 @@ def evaluate(self, *data, initial: bool = False):
327
330
sent to the evaluation function
328
331
initial: It :py:obj:`None`, store the result as the initial
329
332
value
333
+
334
+ Returns:
335
+ value: The value of the observable or the error in evaluation
330
336
"""
331
337
for d in data :
332
338
if isinstance (d , Exception ):
@@ -354,30 +360,34 @@ def check(self) -> bool:
354
360
return self .value is not None
355
361
356
362
@staticmethod
357
- def check_value (value ) :
363
+ def check_value (value : npt . NDArray [ float ] | Exception ) -> npt . NDArray [ float ] :
358
364
if isinstance (value , Exception ):
359
365
raise type (value )(value .args [0 ]) from value
360
366
return value
361
367
362
368
@property
363
- def value (self ):
369
+ def value (self ) -> npt . NDArray [ float ] :
364
370
"""Value of the observable."""
365
371
return self .check_value (self ._value )
366
372
367
373
@property
368
- def weight (self ):
374
+ def weight (self ) -> npt . NDArray [ float ] :
369
375
"""Observable weight."""
370
- return np .broadcast_to (self .w , self ._value .shape )
376
+ return np .broadcast_to (self .w , self ._value .shape ) # type: ignore
377
+
378
+ @weight .setter
379
+ def weight (self , w : npt .ArrayLike ):
380
+ self .w = w
371
381
372
382
@property
373
- def weighted_value (self ):
383
+ def weighted_value (self ) -> npt . NDArray [ float ] :
374
384
"""Weighted value of the Observable, computed as
375
385
:pycode:`weighted_value = value/weight`.
376
386
"""
377
387
return self .value / self .w
378
388
379
389
@property
380
- def deviation (self ):
390
+ def deviation (self ) -> npt . NDArray [ float ] :
381
391
"""Deviation from target value, computed as
382
392
:pycode:`deviation = value-target`.
383
393
"""
@@ -395,12 +405,12 @@ def deviation(self):
395
405
return deviation
396
406
397
407
@property
398
- def weighted_deviation (self ):
408
+ def weighted_deviation (self ) -> npt . NDArray [ float ] :
399
409
""":pycode:`weighted_deviation = (value-target)/weight`."""
400
410
return self .deviation / self .w
401
411
402
412
@property
403
- def residual (self ):
413
+ def residual (self ) -> npt . NDArray [ float ] :
404
414
"""residual, computed as :pycode:`residual = ((value-target)/weight)**2`."""
405
415
return (self .deviation / self .w ) ** 2
406
416
@@ -527,7 +537,7 @@ def __init__(
527
537
self ._excluded = None
528
538
self ._locations = ["" ]
529
539
530
- def check (self ):
540
+ def check (self ) -> bool :
531
541
ok = super ().check ()
532
542
shp = self ._shape
533
543
if ok and shp and shp [0 ] <= 0 :
@@ -552,7 +562,7 @@ def _all_lines(self):
552
562
vmin = repeat (None )
553
563
vmax = repeat (None )
554
564
else :
555
- target = np .broadcast_to (self .target , vnow .shape )
565
+ target = np .broadcast_to (self .target , vnow .shape ) # type: ignore
556
566
vmin = target + self .lbound
557
567
vmax = target + self .ubound
558
568
vini = self .initial
@@ -655,8 +665,8 @@ def __init__(
655
665
656
666
Observe the horizontal closed orbit at monitor locations
657
667
"""
658
- name = self ._set_name (name , "orbit" , axis_ (axis , "code" ))
659
- fun = _ArrayAccess (axis_ (axis , "index" ))
668
+ name = self ._set_name (name , "orbit" , axis_ (axis , key = "code" ))
669
+ fun = _ArrayAccess (axis_ (axis , key = "index" ))
660
670
needs = {Need .ORBIT }
661
671
super ().__init__ (fun , refpts , needs = needs , name = name , ** kwargs )
662
672
@@ -705,8 +715,8 @@ def __init__(
705
715
Observe the transfer matrix from origin to monitor locations and
706
716
extract T[0,1]
707
717
"""
708
- name = self ._set_name (name , "matrix" , axis_ (axis , "code" ))
709
- fun = _ArrayAccess (axis_ (axis , "index" ))
718
+ name = self ._set_name (name , "matrix" , axis_ (axis , key = "code" ))
719
+ fun = _ArrayAccess (axis_ (axis , key = "index" ))
710
720
needs = {Need .MATRIX }
711
721
super ().__init__ (fun , refpts , needs = needs , name = name , ** kwargs )
712
722
@@ -739,12 +749,12 @@ def __init__(
739
749
shape of *value*.
740
750
"""
741
751
needs = {Need .GLOBALOPTICS }
742
- name = self ._set_name (name , param , plane_ (plane , "code" ))
752
+ name = self ._set_name (name , param , plane_ (plane , key = "code" ))
743
753
if callable (param ):
744
754
fun = param
745
755
needs .add (Need .CHROMATICITY )
746
756
else :
747
- fun = partial (_record_access , param , plane_ (plane , "index" ))
757
+ fun = partial (_record_access , param , plane_ (plane , key = "index" ))
748
758
if param == "chromaticity" :
749
759
needs .add (Need .CHROMATICITY )
750
760
super ().__init__ (fun , needs = needs , name = name , ** kwargs )
@@ -881,8 +891,8 @@ def __init__(
881
891
ax_ = plane_
882
892
883
893
needs = {Need .LOCALOPTICS }
884
- name = self ._set_name (name , param , ax_ (plane , "code" ))
885
- index = _all_rows (ax_ (plane , "index" ))
894
+ name = self ._set_name (name , param , ax_ (plane , key = "code" ))
895
+ index = _all_rows (ax_ (plane , key = "index" ))
886
896
if callable (param ):
887
897
fun = param
888
898
else :
@@ -922,7 +932,9 @@ def __init__(
922
932
923
933
Example:
924
934
925
- >>> obs = LatticeObservable(at.Sextupole, "KickAngle", index=0, statfun=np.sum)
935
+ >>> obs = LatticeObservable(
936
+ ... at.Sextupole, "KickAngle", index=0, statfun=np.sum
937
+ ... )
926
938
927
939
Observe the sum of horizontal kicks in Sextupoles
928
940
"""
@@ -967,8 +979,8 @@ def __init__(
967
979
The *target*, *weight* and *bounds* inputs must be broadcastable to the
968
980
shape of *value*.
969
981
"""
970
- name = self ._set_name (name , "trajectory" , axis_ (axis , "code" ))
971
- fun = _ArrayAccess (axis_ (axis , "index" ))
982
+ name = self ._set_name (name , "trajectory" , axis_ (axis , key = "code" ))
983
+ fun = _ArrayAccess (axis_ (axis , key = "index" ))
972
984
needs = {Need .TRAJECTORY }
973
985
super ().__init__ (fun , refpts , needs = needs , name = name , ** kwargs )
974
986
@@ -1020,11 +1032,11 @@ def __init__(
1020
1032
1021
1033
Observe the horizontal emittance
1022
1034
"""
1023
- name = self ._set_name (name , param , plane_ (plane , "code" ))
1035
+ name = self ._set_name (name , param , plane_ (plane , key = "code" ))
1024
1036
if callable (param ):
1025
1037
fun = param
1026
1038
else :
1027
- fun = partial (_record_access , param , plane_ (plane , "index" ))
1039
+ fun = partial (_record_access , param , plane_ (plane , key = "index" ))
1028
1040
needs = {Need .EMITTANCE }
1029
1041
super ().__init__ (fun , needs = needs , name = name , ** kwargs )
1030
1042
@@ -1087,10 +1099,10 @@ def GlobalOpticsObservable(
1087
1099
"""
1088
1100
if param == "tune" and use_integer :
1089
1101
# noinspection PyProtectedMember
1090
- name = ElementObservable ._set_name (name , param , plane_ (plane , "code" ))
1102
+ name = ElementObservable ._set_name (name , param , plane_ (plane , key = "code" ))
1091
1103
return LocalOpticsObservable (
1092
1104
End ,
1093
- _Tune (plane_ (plane , "index" )),
1105
+ _Tune (plane_ (plane , key = "index" )),
1094
1106
name = name ,
1095
1107
summary = True ,
1096
1108
all_points = True ,
0 commit comments