Skip to content

Commit f2ae598

Browse files
committed
merged from response_matrices
1 parent 5bf0be2 commit f2ae598

File tree

1 file changed

+51
-39
lines changed

1 file changed

+51
-39
lines changed

pyat/at/latticetools/observables.py

+51-39
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
from typing import Tuple
6767

6868
import numpy as np
69+
import numpy.typing as npt
6970

7071
from ..lattice import AtError, AxisDef, axis_, plane_
7172
from ..lattice import Lattice, Refpts, End
@@ -204,8 +205,8 @@ def __init__(
204205
fun: Callable,
205206
*args,
206207
name: str | None = None,
207-
target=None,
208-
weight=1.0,
208+
target: npt.ArrayLike | None = None,
209+
weight: npt.ArrayLike = 1.0,
209210
bounds=(0.0, 0.0),
210211
needs: Set[Need] | None = None,
211212
**kwargs,
@@ -256,12 +257,12 @@ def __init__(
256257
self.fun: Callable = fun #: Evaluation function
257258
self.needs: Set[Need] = needs or set() #: Set of requirements
258259
self.name: str = name #: Observable name
259-
self.target = target #: Target value
260-
self.w = weight
260+
self.target: npt.ArrayLike | None = target #: Target value
261+
self.w: npt.ArrayLike = weight
261262
self.lbound, self.ubound = bounds
262-
self.initial = None
263-
self._value = None
264-
self._shape = None
263+
self.initial: npt.NDArray[float] | None = None
264+
self._value: npt.NDArray[float] | Exception | None = None
265+
self._shape: tuple[int, ...] | None = None
265266
self.args = args
266267
self.kwargs = kwargs
267268

@@ -299,23 +300,25 @@ def _all_lines(self):
299300
vnow = self._value
300301
if vnow is None or isinstance(vnow, Exception):
301302
deviation = None
302-
else:
303-
deviation = self.deviation
304-
if self.target is None:
305303
vmin = None
306304
vmax = None
307305
else:
308-
target = np.broadcast_to(self.target, vnow.shape)
309-
vmin = target + self.lbound
310-
vmax = target + self.ubound
306+
deviation = self.deviation
307+
if self.target is None:
308+
vmin = None
309+
vmax = None
310+
else:
311+
target = np.broadcast_to(self.target, vnow.shape) # type: ignore
312+
vmin = target + self.lbound
313+
vmax = target + self.ubound
311314
values = self._line("", self.initial, vnow, vmin, vmax, deviation)
312315
return "\n".join((self.name, values))
313316

314317
def _setup(self, ring: Lattice):
315318
"""Setup function called when the observable is added to a list."""
316319
pass
317320

318-
def evaluate(self, *data, initial: bool = False):
321+
def evaluate(self, *data, initial: bool = False) -> npt.NDArray[float] | Exception:
319322
"""Compute and store the value of the observable.
320323
321324
The direct evaluation of a single :py:class:`Observable` is normally
@@ -327,6 +330,9 @@ def evaluate(self, *data, initial: bool = False):
327330
sent to the evaluation function
328331
initial: It :py:obj:`None`, store the result as the initial
329332
value
333+
334+
Returns:
335+
value: The value of the observable or the error in evaluation
330336
"""
331337
for d in data:
332338
if isinstance(d, Exception):
@@ -354,30 +360,34 @@ def check(self) -> bool:
354360
return self.value is not None
355361

356362
@staticmethod
357-
def check_value(value):
363+
def check_value(value: npt.NDArray[float] | Exception) -> npt.NDArray[float]:
358364
if isinstance(value, Exception):
359365
raise type(value)(value.args[0]) from value
360366
return value
361367

362368
@property
363-
def value(self):
369+
def value(self) -> npt.NDArray[float]:
364370
"""Value of the observable."""
365371
return self.check_value(self._value)
366372

367373
@property
368-
def weight(self):
374+
def weight(self) -> npt.NDArray[float]:
369375
"""Observable weight."""
370-
return np.broadcast_to(self.w, self._value.shape)
376+
return np.broadcast_to(self.w, self._value.shape) # type: ignore
377+
378+
@weight.setter
379+
def weight(self, w: npt.ArrayLike):
380+
self.w = w
371381

372382
@property
373-
def weighted_value(self):
383+
def weighted_value(self) -> npt.NDArray[float]:
374384
"""Weighted value of the Observable, computed as
375385
:pycode:`weighted_value = value/weight`.
376386
"""
377387
return self.value / self.w
378388

379389
@property
380-
def deviation(self):
390+
def deviation(self) -> npt.NDArray[float]:
381391
"""Deviation from target value, computed as
382392
:pycode:`deviation = value-target`.
383393
"""
@@ -395,12 +405,12 @@ def deviation(self):
395405
return deviation
396406

397407
@property
398-
def weighted_deviation(self):
408+
def weighted_deviation(self) -> npt.NDArray[float]:
399409
""":pycode:`weighted_deviation = (value-target)/weight`."""
400410
return self.deviation / self.w
401411

402412
@property
403-
def residual(self):
413+
def residual(self) -> npt.NDArray[float]:
404414
"""residual, computed as :pycode:`residual = ((value-target)/weight)**2`."""
405415
return (self.deviation / self.w) ** 2
406416

@@ -527,7 +537,7 @@ def __init__(
527537
self._excluded = None
528538
self._locations = [""]
529539

530-
def check(self):
540+
def check(self) -> bool:
531541
ok = super().check()
532542
shp = self._shape
533543
if ok and shp and shp[0] <= 0:
@@ -552,7 +562,7 @@ def _all_lines(self):
552562
vmin = repeat(None)
553563
vmax = repeat(None)
554564
else:
555-
target = np.broadcast_to(self.target, vnow.shape)
565+
target = np.broadcast_to(self.target, vnow.shape) # type: ignore
556566
vmin = target + self.lbound
557567
vmax = target + self.ubound
558568
vini = self.initial
@@ -655,8 +665,8 @@ def __init__(
655665
656666
Observe the horizontal closed orbit at monitor locations
657667
"""
658-
name = self._set_name(name, "orbit", axis_(axis, "code"))
659-
fun = _ArrayAccess(axis_(axis, "index"))
668+
name = self._set_name(name, "orbit", axis_(axis, key="code"))
669+
fun = _ArrayAccess(axis_(axis, key="index"))
660670
needs = {Need.ORBIT}
661671
super().__init__(fun, refpts, needs=needs, name=name, **kwargs)
662672

@@ -705,8 +715,8 @@ def __init__(
705715
Observe the transfer matrix from origin to monitor locations and
706716
extract T[0,1]
707717
"""
708-
name = self._set_name(name, "matrix", axis_(axis, "code"))
709-
fun = _ArrayAccess(axis_(axis, "index"))
718+
name = self._set_name(name, "matrix", axis_(axis, key="code"))
719+
fun = _ArrayAccess(axis_(axis, key="index"))
710720
needs = {Need.MATRIX}
711721
super().__init__(fun, refpts, needs=needs, name=name, **kwargs)
712722

@@ -739,12 +749,12 @@ def __init__(
739749
shape of *value*.
740750
"""
741751
needs = {Need.GLOBALOPTICS}
742-
name = self._set_name(name, param, plane_(plane, "code"))
752+
name = self._set_name(name, param, plane_(plane, key="code"))
743753
if callable(param):
744754
fun = param
745755
needs.add(Need.CHROMATICITY)
746756
else:
747-
fun = partial(_record_access, param, plane_(plane, "index"))
757+
fun = partial(_record_access, param, plane_(plane, key="index"))
748758
if param == "chromaticity":
749759
needs.add(Need.CHROMATICITY)
750760
super().__init__(fun, needs=needs, name=name, **kwargs)
@@ -881,8 +891,8 @@ def __init__(
881891
ax_ = plane_
882892

883893
needs = {Need.LOCALOPTICS}
884-
name = self._set_name(name, param, ax_(plane, "code"))
885-
index = _all_rows(ax_(plane, "index"))
894+
name = self._set_name(name, param, ax_(plane, key="code"))
895+
index = _all_rows(ax_(plane, key="index"))
886896
if callable(param):
887897
fun = param
888898
else:
@@ -922,7 +932,9 @@ def __init__(
922932
923933
Example:
924934
925-
>>> obs = LatticeObservable(at.Sextupole, "KickAngle", index=0, statfun=np.sum)
935+
>>> obs = LatticeObservable(
936+
... at.Sextupole, "KickAngle", index=0, statfun=np.sum
937+
... )
926938
927939
Observe the sum of horizontal kicks in Sextupoles
928940
"""
@@ -967,8 +979,8 @@ def __init__(
967979
The *target*, *weight* and *bounds* inputs must be broadcastable to the
968980
shape of *value*.
969981
"""
970-
name = self._set_name(name, "trajectory", axis_(axis, "code"))
971-
fun = _ArrayAccess(axis_(axis, "index"))
982+
name = self._set_name(name, "trajectory", axis_(axis, key="code"))
983+
fun = _ArrayAccess(axis_(axis, key="index"))
972984
needs = {Need.TRAJECTORY}
973985
super().__init__(fun, refpts, needs=needs, name=name, **kwargs)
974986

@@ -1020,11 +1032,11 @@ def __init__(
10201032
10211033
Observe the horizontal emittance
10221034
"""
1023-
name = self._set_name(name, param, plane_(plane, "code"))
1035+
name = self._set_name(name, param, plane_(plane, key="code"))
10241036
if callable(param):
10251037
fun = param
10261038
else:
1027-
fun = partial(_record_access, param, plane_(plane, "index"))
1039+
fun = partial(_record_access, param, plane_(plane, key="index"))
10281040
needs = {Need.EMITTANCE}
10291041
super().__init__(fun, needs=needs, name=name, **kwargs)
10301042

@@ -1087,10 +1099,10 @@ def GlobalOpticsObservable(
10871099
"""
10881100
if param == "tune" and use_integer:
10891101
# noinspection PyProtectedMember
1090-
name = ElementObservable._set_name(name, param, plane_(plane, "code"))
1102+
name = ElementObservable._set_name(name, param, plane_(plane, key="code"))
10911103
return LocalOpticsObservable(
10921104
End,
1093-
_Tune(plane_(plane, "index")),
1105+
_Tune(plane_(plane, key="index")),
10941106
name=name,
10951107
summary=True,
10961108
all_points=True,

0 commit comments

Comments
 (0)