Skip to content

Commit 8f29745

Browse files
committed
Added tests
1 parent b828b30 commit 8f29745

File tree

4 files changed

+181
-52
lines changed

4 files changed

+181
-52
lines changed

docs/p/notebooks/observables.ipynb

+16-16
Original file line numberDiff line numberDiff line change
@@ -777,22 +777,22 @@
777777
"amax(beta[y])\n",
778778
" 7.17614 7.17614 - - 0.0 \n",
779779
"closed_orbit[slice(None, 4, None)]\n",
780-
" QF1A [-3.028e-09 ...] [-3.028e-09 ...] 0.0 0.0 [ 9.17e-06 ...] \n",
781-
" QD2A [-1.785e-09 ...] [-1.785e-09 ...] 0.0 0.0 [ 3.186e-06 ...] \n",
782-
" QD3A [ 2.06e-07 ...] [ 2.06e-07 ...] 0.0 0.0 [ 0.04245 ...] \n",
783-
" QF4A [ 4.635e-07 ...] [ 4.635e-07 ...] 0.0 0.0 [ 0.2148 ...] \n",
784-
" QF4B [ 4.929e-07 ...] [ 4.929e-07 ...] 0.0 0.0 [ 0.2429 ...] \n",
785-
" QD5B [ 2.391e-07 ...] [ 2.391e-07 ...] 0.0 0.0 [ 0.05716 ...] \n",
786-
" QF6B [ 2.25e-08 ...] [ 2.25e-08 ...] 0.0 0.0 [ 0.000506 ...] \n",
787-
" QF8B [-2.958e-08 ...] [-2.958e-08 ...] 0.0 0.0 [ 0.000875 ...] \n",
788-
" QF8D [ 3.864e-08 ...] [ 3.864e-08 ...] 0.0 0.0 [ 0.001493 ...] \n",
789-
" QF6D [-1.147e-08 ...] [-1.147e-08 ...] 0.0 0.0 [ 0.0001316 ...] \n",
790-
" QD5D [-1.925e-07 ...] [-1.925e-07 ...] 0.0 0.0 [ 0.03705 ...] \n",
791-
" QF4D [-4.585e-07 ...] [-4.585e-07 ...] 0.0 0.0 [ 0.2103 ...] \n",
792-
" QF4E [-4.902e-07 ...] [-4.902e-07 ...] 0.0 0.0 [ 0.2403 ...] \n",
793-
" QD3E [-2.424e-07 ...] [-2.424e-07 ...] 0.0 0.0 [ 0.05877 ...] \n",
794-
" QD2E [-8.05e-10 ...] [-8.05e-10 ...] 0.0 0.0 [ 6.48e-07 ...] \n",
795-
" QF1E [-1.927e-09 ...] [-1.927e-09 ...] 0.0 0.0 [ 3.715e-06 ...] \n",
780+
" QF1A [-3.028e-09 ...] [-3.028e-09 ...] [ 0.0 ...] [ 0.0 ...] [ 9.17e-06 ...] \n",
781+
" QD2A [-1.785e-09 ...] [-1.785e-09 ...] [ 0.0 ...] [ 0.0 ...] [ 3.186e-06 ...] \n",
782+
" QD3A [ 2.06e-07 ...] [ 2.06e-07 ...] [ 0.0 ...] [ 0.0 ...] [ 0.04245 ...] \n",
783+
" QF4A [ 4.635e-07 ...] [ 4.635e-07 ...] [ 0.0 ...] [ 0.0 ...] [ 0.2148 ...] \n",
784+
" QF4B [ 4.929e-07 ...] [ 4.929e-07 ...] [ 0.0 ...] [ 0.0 ...] [ 0.2429 ...] \n",
785+
" QD5B [ 2.391e-07 ...] [ 2.391e-07 ...] [ 0.0 ...] [ 0.0 ...] [ 0.05716 ...] \n",
786+
" QF6B [ 2.25e-08 ...] [ 2.25e-08 ...] [ 0.0 ...] [ 0.0 ...] [ 0.000506 ...] \n",
787+
" QF8B [-2.958e-08 ...] [-2.958e-08 ...] [ 0.0 ...] [ 0.0 ...] [ 0.000875 ...] \n",
788+
" QF8D [ 3.864e-08 ...] [ 3.864e-08 ...] [ 0.0 ...] [ 0.0 ...] [ 0.001493 ...] \n",
789+
" QF6D [-1.147e-08 ...] [-1.147e-08 ...] [ 0.0 ...] [ 0.0 ...] [ 0.0001316 ...] \n",
790+
" QD5D [-1.925e-07 ...] [-1.925e-07 ...] [ 0.0 ...] [ 0.0 ...] [ 0.03705 ...] \n",
791+
" QF4D [-4.585e-07 ...] [-4.585e-07 ...] [ 0.0 ...] [ 0.0 ...] [ 0.2103 ...] \n",
792+
" QF4E [-4.902e-07 ...] [-4.902e-07 ...] [ 0.0 ...] [ 0.0 ...] [ 0.2403 ...] \n",
793+
" QD3E [-2.424e-07 ...] [-2.424e-07 ...] [ 0.0 ...] [ 0.0 ...] [ 0.05877 ...] \n",
794+
" QD2E [-8.05e-10 ...] [-8.05e-10 ...] [ 0.0 ...] [ 0.0 ...] [ 6.48e-07 ...] \n",
795+
" QF1E [-1.927e-09 ...] [-1.927e-09 ...] [ 0.0 ...] [ 0.0 ...] [ 3.715e-06 ...] \n",
796796
"s_pos\n",
797797
" QF1A 2.69395 2.69395 - - 0.0 \n",
798798
" QD2A 3.42956 3.42956 - - 0.0 \n",

pyat/at/latticetools/observablelist.py

-7
Original file line numberDiff line numberDiff line change
@@ -304,13 +304,6 @@ def ringeval(
304304
)
305305
except AtError as err:
306306
rgdata = eldata = err
307-
else:
308-
if Need.TUNEUNIT in needs:
309-
eldata["mu"] = eldata["mu"] / (2.0 * np.pi)
310-
if Need.MODULO in needs:
311-
eldata["mu"] = eldata["mu"] % 1.0
312-
elif Need.MODULO in needs:
313-
eldata["mu"] = eldata["mu"] % (2.0 * np.pi)
314307

315308
if Need.EMITTANCE in needs and o0 is not None:
316309
# Emittance computation

pyat/at/latticetools/observables.py

+67-29
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,16 @@ def __call__(self, ring, data):
111111
return data if index is None else data[self.index]
112112

113113

114+
class _MuAccess(_RecordAccess):
115+
"""Access to selected items in a record array"""
116+
117+
def __init__(self, index):
118+
super().__init__("mu", index)
119+
120+
def __call__(self, ring, data):
121+
return super().__call__(ring, data) % (2.0 * np.pi)
122+
123+
114124
def _all_rows(index: Optional[RefIndex]):
115125
"""Prepends "all rows" (":") to an index tuple"""
116126
if index is None:
@@ -173,12 +183,6 @@ class Need(Enum):
173183
#: Specify geometry computation and provide the full data at evaluation
174184
#: points
175185
GEOMETRY = 9
176-
#: Specify whether the modulo of the phase has to be used, necessary when
177-
#: matching the fractional part only
178-
MODULO = 10
179-
#: Specify whether the phase is expressed in units of tune units, in
180-
#: case phase = phase/(2*pi)
181-
TUNEUNIT = 11
182186

183187

184188
class Observable(object):
@@ -486,9 +490,12 @@ def _setup(self, ring: Lattice):
486490

487491

488492
class GeometryObservable(_ElementObservable):
489-
"""Observe the geometrical parameters of the reference trajectory"""
493+
"""Observe the geometrical parameters of the reference trajectory.
494+
495+
Process the result of calling :py:func:`.get_geometry`.
496+
"""
490497

491-
field_list = {"x", "y", "angle"}
498+
_field_list = {"x", "y", "angle"}
492499

493500
def __init__(
494501
self, refpts: Refpts, param: str, name: Optional[str] = None, **kwargs
@@ -522,16 +529,19 @@ def __init__(
522529
523530
Observe x coordinate of monitors
524531
"""
525-
if param not in self.field_list:
526-
raise ValueError(f"Expected {param!r} to be one of {self.field_list!r}")
532+
if param not in self._field_list:
533+
raise ValueError(f"Expected {param!r} to be one of {self._field_list!r}")
527534
name = self._set_name(name, "geometry", param)
528535
fun = _RecordAccess(param, None)
529536
needs = {Need.GEOMETRY}
530537
super().__init__(fun, refpts, needs=needs, name=name, **kwargs)
531538

532539

533540
class OrbitObservable(_ElementObservable):
534-
"""Observes the transfer matrix at selected locations"""
541+
"""Observe the transfer matrix at selected locations.
542+
543+
Process the result of calling :py:func:`.find_orbit`.
544+
"""
535545

536546
def __init__(
537547
self, refpts: Refpts, axis: AxisDef = None, name: Optional[str] = None, **kwargs
@@ -573,7 +583,11 @@ def __init__(
573583

574584

575585
class MatrixObservable(_ElementObservable):
576-
"""Observes the closed orbit at selected locations"""
586+
"""Observe the closed orbit at selected locations.
587+
588+
Processs the result of calling :py:func:`.find_m44` or :py:func:`.find_m44`
589+
depending of :py:meth:`~.Lattice.is_6d`.
590+
"""
577591

578592
def __init__(
579593
self,
@@ -660,7 +674,10 @@ def __init__(
660674

661675

662676
class LocalOpticsObservable(_ElementObservable):
663-
"""Observe a local optics parameter at selected locations"""
677+
"""Observe a local optics parameter at selected locations.
678+
679+
Process the local output of :py:func:`.get_optics`.
680+
"""
664681

665682
def __init__(
666683
self,
@@ -685,6 +702,12 @@ def __init__(
685702
use_integer: For the *'mu'* parameter, compute the
686703
phase advance at all points to avoid discontinuities (slower)
687704
705+
.. Attention::
706+
707+
if *use_integer* is :py:obj:`False` (default value), all phase advance
708+
values are folded into the :math:`[0, 2\pi]` interval to avoid
709+
unpredictible jumps.
710+
688711
Keyword Args:
689712
summary: Set to :py:obj:`True` if the user-defined
690713
evaluation function returns a single item (see below)
@@ -708,7 +731,8 @@ def __init__(
708731
709732
:pycode:`value = fun(ring, elemdata)`
710733
711-
*elemdata* if the output of :py:func:`.get_optics`.
734+
*elemdata* if the output of :py:func:`.get_optics`, evaluated at the *refpts*
735+
of the observable.
712736
713737
*value* is the value of the Observable and must have one line per
714738
refpoint. Alternatively, it may be a single line, but then the
@@ -746,25 +770,19 @@ def __init__(
746770
if callable(param):
747771
fun = param
748772
needs.add(Need.CHROMATICITY)
773+
elif param == "mu" and not use_integer:
774+
# values and target are taken modulo 2*pi
775+
fun = _MuAccess(_all_rows(ax_(plane, "index")))
749776
else:
750777
fun = _RecordAccess(param, _all_rows(ax_(plane, "index")))
751778
if use_integer:
752779
needs.add(Need.ALL_POINTS)
753-
else:
754-
needs.add(Need.MODULO)
755-
target = kwargs.get("target", None)
756-
if target is not None and param == "mu":
757-
kwargs["target"] = target % (2.0 * np.pi)
758-
elif target is not None and param == "mun":
759-
kwargs["target"] = target % 1.0
760-
needs.add(Need.TUNEUNIT)
761-
fun = _RecordAccess("mu", _all_rows(ax_(plane, "index")))
762780

763781
super().__init__(fun, refpts, needs=needs, name=name, **kwargs)
764782

765783

766784
class LatticeObservable(_ElementObservable):
767-
"""Observe an attribute of selected lattice elements"""
785+
"""Observe an attribute of selected lattice elements."""
768786

769787
def __init__(
770788
self,
@@ -841,15 +859,18 @@ def __init__(
841859

842860

843861
class EmittanceObservable(Observable):
844-
"""Observe emittance-related parameters"""
862+
"""Observe emittance-related parameters.
863+
864+
Process the output of :py:func:`.envelope_parameters`.
865+
"""
845866

846867
def __init__(
847868
self, param: str, plane: AxisDef = None, name: Optional[str] = None, **kwargs
848869
):
849870
r"""
850871
Args:
851-
param: Parameter name (see
852-
:py:func:`.envelope_parameters`)
872+
param: Parameter name (see :py:func:`.envelope_parameters`) or
873+
:ref:`user-defined evaluation function <emittance_eval>`
853874
plane: One out of {0, 'x', 'h', 'H'} for horizontal plane,
854875
one out of {1, 'y', 'v', 'V'} for vertival plane or one out of
855876
{2, 'z', 'l', 'L'} for longitudinal plane
@@ -867,14 +888,29 @@ def __init__(
867888
is constrained in the interval
868889
[*target*\ +\ *low_bound* *target*\ +\ *up_bound*]
869890
891+
.. _emittance_eval:
892+
.. rubric:: User-defined evaluation function
893+
894+
It is called as:
895+
896+
:pycode:`value = fun(ring, paramdata)`
897+
898+
*paramdata* if the :py:class:`.RingParameters` object returned by
899+
:py:func:`.envelope_parameters`.
900+
901+
*value* is the value of the Observable.
902+
870903
Example:
871904
872905
>>> EmittanceObservable('emittances', plane='h')
873906
874907
Observe the horizontal emittance
875908
"""
876909
name = self._set_name(name, param, plane_(plane, "code"))
877-
fun = _RecordAccess(param, plane_(plane, "index"))
910+
if callable(param):
911+
fun = param
912+
else:
913+
fun = _RecordAccess(param, plane_(plane, "index"))
878914
needs = {Need.EMITTANCE}
879915
super().__init__(fun, needs=needs, name=name, **kwargs)
880916

@@ -888,7 +924,9 @@ def GlobalOpticsObservable(
888924
**kwargs,
889925
):
890926
# noinspection PyUnresolvedReferences
891-
r"""Observe a global optics parameter
927+
r"""Observe a global optics parameter.
928+
929+
Process the global output of :py:func:`.get_optics`.
892930
893931
Args:
894932
param: Optics parameter name (see :py:func:`.get_optics`)

pyat/test/test_observables.py

+98
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import numpy as np
2+
from numpy.testing import assert_allclose as assert_close
3+
4+
import at
5+
from at import (
6+
Observable,
7+
ObservableList,
8+
OrbitObservable,
9+
GlobalOpticsObservable,
10+
LocalOpticsObservable,
11+
MatrixObservable,
12+
TrajectoryObservable,
13+
EmittanceObservable,
14+
LatticeObservable,
15+
GeometryObservable,
16+
)
17+
18+
19+
def test_observables(hmba_lattice):
20+
# noinspection PyUnusedLocal
21+
def phase_advance(rng, elemdata):
22+
mu = elemdata.mu
23+
return mu[-1] - mu[0]
24+
25+
def circumference(rng):
26+
return rng.cell_length
27+
28+
ring = hmba_lattice.enable_6d(copy=True)
29+
ring.set_rf_frequency(dp=0.0)
30+
31+
# Create an empty ObervableList
32+
allobs = ObservableList()
33+
# Populate with all kinds of Observables
34+
allobs.append(OrbitObservable(at.Monitor, axis="x"))
35+
allobs.append(
36+
LocalOpticsObservable(
37+
at.Monitor, "beta", plane=1, target=7.0, bounds=(-np.inf, 0.0)
38+
)
39+
)
40+
allobs.append(MatrixObservable("BPM_02"))
41+
allobs.append(LocalOpticsObservable(at.Monitor, "beta", plane="v", statfun=np.amax))
42+
allobs.append(
43+
LocalOpticsObservable(
44+
at.Quadrupole, "closed_orbit", plane=slice(4), target=0.0, weight=1.0e-6
45+
)
46+
)
47+
allobs.append(LocalOpticsObservable(at.Quadrupole, "s_pos"))
48+
allobs.append(
49+
LocalOpticsObservable([33, 101], phase_advance, use_integer=True, summary=True)
50+
)
51+
allobs.append(GlobalOpticsObservable("tune", plane=0, use_integer=True))
52+
allobs.append(LocalOpticsObservable(at.End, "mu", use_integer=True))
53+
allobs.append(GlobalOpticsObservable("chromaticity"))
54+
allobs.append(LatticeObservable(at.Sextupole, "H", statfun=np.mean))
55+
allobs.append(LatticeObservable(at.Sextupole, "PolynomB", index=2))
56+
allobs.append(EmittanceObservable("emittances", plane="x"))
57+
allobs.append(Observable(circumference))
58+
allobs.append(TrajectoryObservable(at.Monitor, axis="px"))
59+
allobs.append(GeometryObservable(at.Monitor, "x"))
60+
61+
# Evaluate the Observables
62+
r_in = np.zeros(6)
63+
r_in[0] = 0.001
64+
r_in[2] = 0.001
65+
allobs.evaluate(ring, r_in=r_in, initial=True)
66+
67+
# Get the expected values
68+
o0, o = ring.find_orbit(refpts=at.All)
69+
el0, rg, el = ring.get_optics(refpts=at.All, orbit=o0, get_chrom=True)
70+
m66, ms = ring.find_m66(refpts="BPM_02", orbit=o0)
71+
prms = ring.envelope_parameters(orbit=o0)
72+
rout, _, _ = ring.track(r_in, refpts=at.Monitor)
73+
geodata, _ = ring.get_geometry(refpts=at.Monitor)
74+
75+
monitors = ring.get_bool_index(at.Monitor)
76+
quadrupoles = ring.get_bool_index(at.Quadrupole)
77+
78+
# Compare the results
79+
assert_close(allobs.values[0], o[monitors, 0])
80+
assert_close(allobs.values[1], el.beta[monitors, 1])
81+
assert_close(allobs.values[2], ms)
82+
assert_close(allobs.values[3], np.amax(el.beta[monitors, 1]))
83+
assert_close(allobs.values[4], el.closed_orbit[quadrupoles, :4])
84+
assert_close(allobs.values[5], el.s_pos[quadrupoles])
85+
assert_close(allobs.values[6], el.mu[101] - el.mu[33])
86+
assert_close(allobs.values[7], el.mu[-1, 0] / 2.0 / np.pi)
87+
assert_close(allobs.values[8], [el.mu[-1]])
88+
assert_close(allobs.values[9], rg.chromaticity)
89+
assert_close(
90+
allobs.values[10], np.mean([elem.H for elem in ring.select(at.Sextupole)])
91+
)
92+
assert_close(
93+
allobs.values[11], [elem.PolynomB[2] for elem in ring.select(at.Sextupole)]
94+
)
95+
assert_close(allobs.values[12], prms.emittances[0])
96+
assert_close(allobs.values[13], ring.cell_length)
97+
assert_close(allobs.values[14], rout[1, 0, :, 0])
98+
assert_close(allobs.values[15], geodata.x)

0 commit comments

Comments
 (0)