Skip to content

Commit 2ca5935

Browse files
committed
cleanup
1 parent 0b55b4d commit 2ca5935

File tree

2 files changed

+68
-58
lines changed

2 files changed

+68
-58
lines changed

pyat/at/lattice/axisdef.py

+54-44
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
11
"""Helper functions for axis and plane descriptions"""
22

33
from __future__ import annotations
4-
from typing import Optional, Union
54

6-
# For sys.version_info.minor < 9:
5+
# Necessary for type aliases in python <= 3.8 :
76
from typing import Tuple
7+
from typing import Union
88

99
AxisCode = Union[str, int, slice, None, type(Ellipsis)]
1010
AxisDef = Union[AxisCode, Tuple[AxisCode, AxisCode]]
1111

12-
_axis_def = dict(
13-
x=dict(index=0, label="x", unit=" [m]"),
14-
px=dict(index=1, label=r"$p_x$", unit=" [rad]"),
15-
y=dict(index=2, label="y", unit=" [m]"),
16-
py=dict(index=3, label=r"$p_y$", unit=" [rad]"),
17-
dp=dict(index=4, label=r"$\delta$", unit=""),
18-
ct=dict(index=5, label=r"$\beta c \tau$", unit=" [m]"),
19-
)
20-
for xk, xv in [it for it in _axis_def.items()]:
12+
_axis_def = {
13+
"x": {"index": 0, "label": "x", "unit": " [m]"},
14+
"px": {"index": 1, "label": r"$p_x$", "unit": " [rad]"},
15+
"y": {"index": 2, "label": "y", "unit": " [m]"},
16+
"py": {"index": 3, "label": r"$p_y$", "unit": " [rad]"},
17+
"dp": {"index": 4, "label": r"$\delta$", "unit": ""},
18+
"ct": {"index": 5, "label": r"$\beta c \tau$", "unit": " [m]"},
19+
}
20+
for xk, xv in list(_axis_def.items()):
2121
xv["code"] = xk
2222
_axis_def[xv["index"]] = xv
2323
_axis_def[xk.upper()] = xv
@@ -26,41 +26,43 @@
2626
_axis_def["yp"] = _axis_def["py"] # For backward compatibility
2727
_axis_def["s"] = _axis_def["ct"]
2828
_axis_def["S"] = _axis_def["ct"]
29-
_axis_def[None] = dict(index=None, label="", unit="", code=":")
30-
_axis_def[Ellipsis] = dict(index=Ellipsis, label="", unit="", code="...")
31-
32-
_plane_def = dict(
33-
x=dict(index=0, label="x", unit=" [m]"),
34-
y=dict(index=1, label="y", unit=" [m]"),
35-
z=dict(index=2, label="z", unit=""),
36-
)
37-
for xk, xv in [it for it in _plane_def.items()]:
29+
_axis_def[None] = {"index": None, "label": "", "unit": "", "code": ":"}
30+
_axis_def[Ellipsis] = {"index": Ellipsis, "label": "", "unit": "", "code": "..."}
31+
32+
_plane_def = {
33+
"x": {"index": 0, "label": "x", "unit": " [m]"},
34+
"y": {"index": 1, "label": "y", "unit": " [m]"},
35+
"z": {"index": 2, "label": "z", "unit": ""},
36+
}
37+
for xk, xv in list(_plane_def.items()):
3838
xv["code"] = xk
3939
_plane_def[xv["index"]] = xv
4040
_plane_def[xk.upper()] = xv
4141
_plane_def["h"] = _plane_def["x"]
4242
_plane_def["v"] = _plane_def["y"]
4343
_plane_def["H"] = _plane_def["x"]
4444
_plane_def["V"] = _plane_def["y"]
45-
_plane_def[None] = dict(index=None, label="", unit="", code=":")
46-
_plane_def[Ellipsis] = dict(index=Ellipsis, label="", unit="", code="...")
45+
_plane_def[None] = {"index": None, "label": "", "unit": "", "code": ":"}
46+
_plane_def[Ellipsis] = {"index": Ellipsis, "label": "", "unit": "", "code": "..."}
4747

4848

49-
def _descr(dd: dict, arg: AxisDef, key: Optional[str] = None):
50-
if isinstance(arg, tuple):
51-
return tuple(_descr(dd, a, key=key) for a in arg)
52-
else:
53-
try:
54-
descr = dd[arg]
55-
except (TypeError, KeyError):
56-
descr = dict(index=arg, code=arg, label="", unit="")
57-
if key is None:
58-
return descr
49+
def _descr(dd: dict, *args: AxisDef, key: str | None = None):
50+
for arg in args:
51+
if isinstance(arg, tuple):
52+
for a in arg:
53+
yield from _descr(dd, a, key=key)
5954
else:
60-
return descr[key]
55+
if isinstance(arg, slice):
56+
descr = {"index": arg, "code": arg, "label": "", "unit": ""}
57+
else:
58+
descr = dd[arg]
59+
if key is None:
60+
yield descr
61+
else:
62+
yield descr[key]
6163

6264

63-
def axis_(axis: AxisDef, key: Optional[str] = None):
65+
def axis_(*axis: AxisDef, key: str | None = None):
6466
r"""Return axis descriptions
6567
6668
Parameters:
@@ -100,28 +102,32 @@ def axis_(axis: AxisDef, key: Optional[str] = None):
100102
101103
Examples:
102104
103-
>>> axis_(('x','dp'), key='index')
105+
>>> axis_("x", "dp", key="index")
104106
(0, 4)
105107
106108
returns the indices in the standard coordinate vector
107109
108-
>>> dplabel = axis_('dp', key='label')
110+
>>> dplabel = axis_("dp", key="label")
109111
>>> print(dplabel)
110112
$\delta$
111113
112114
returns the coordinate label for plot annotation
113115
114-
>>> axis_((0,'dp'))
116+
>>> axis_(0, "dp")
115117
({'plane': 0, 'label': 'x', 'unit': ' [m]', 'code': 'x'},
116118
{'plane': 4, 'label': '$\\delta$', 'unit': '', 'code': 'dp'})
117119
118120
returns the entire description directories
119121
120122
"""
121-
return _descr(_axis_def, axis, key=key)
123+
ret = tuple(_descr(_axis_def, *axis, key=key))
124+
if len(ret) > 1:
125+
return ret
126+
else:
127+
return ret[0]
122128

123129

124-
def plane_(plane: AxisDef, key: Optional[str] = None):
130+
def plane_(*plane: AxisDef, key: str | None = None):
125131
r"""Return plane descriptions
126132
127133
Parameters:
@@ -154,16 +160,20 @@ def plane_(plane: AxisDef, key: Optional[str] = None):
154160
155161
Examples:
156162
157-
>>> plane_('v', key='index')
163+
>>> plane_("v", key="index")
158164
1
159165
160166
returns the indices in the standard coordinate vector
161167
162-
>>> plane_(('x','y'))
163-
({'plane': 0, 'label': 'h', 'unit': ' [m]', 'code': 'h'},
164-
{'plane': 1, 'label': 'v', 'unit': ' [m]', 'code': 'v'})
168+
>>> plane_("x", "y")
169+
({'plane': 0, 'label': 'x', 'unit': ' [m]', 'code': 'h'},
170+
{'plane': 1, 'label': 'y', 'unit': ' [m]', 'code': 'v'})
165171
166172
returns the entire description directories
167173
168174
"""
169-
return _descr(_plane_def, plane, key=key)
175+
ret = tuple(_descr(_plane_def, *plane, key=key))
176+
if len(ret) > 1:
177+
return ret
178+
else:
179+
return ret[0]

pyat/at/latticetools/observables.py

+14-14
Original file line numberDiff line numberDiff line change
@@ -620,8 +620,8 @@ def __init__(
620620
621621
Observe the horizontal closed orbit at monitor locations
622622
"""
623-
name = self._set_name(name, "orbit", axis_(axis, "code"))
624-
fun = _ArrayAccess(axis_(axis, "index"))
623+
name = self._set_name(name, "orbit", axis_(axis, key="code"))
624+
fun = _ArrayAccess(axis_(axis, key="index"))
625625
needs = {Need.ORBIT}
626626
super().__init__(fun, refpts, needs=needs, name=name, **kwargs)
627627

@@ -670,8 +670,8 @@ def __init__(
670670
Observe the transfer matrix from origin to monitor locations and
671671
extract T[0,1]
672672
"""
673-
name = self._set_name(name, "matrix", axis_(axis, "code"))
674-
fun = _ArrayAccess(axis_(axis, "index"))
673+
name = self._set_name(name, "matrix", axis_(axis, key="code"))
674+
fun = _ArrayAccess(axis_(axis, key="index"))
675675
needs = {Need.MATRIX}
676676
super().__init__(fun, refpts, needs=needs, name=name, **kwargs)
677677

@@ -704,12 +704,12 @@ def __init__(
704704
shape of *value*.
705705
"""
706706
needs = {Need.GLOBALOPTICS}
707-
name = self._set_name(name, param, plane_(plane, "code"))
707+
name = self._set_name(name, param, plane_(plane, key="code"))
708708
if callable(param):
709709
fun = param
710710
needs.add(Need.CHROMATICITY)
711711
else:
712-
fun = _RecordAccess(param, plane_(plane, "index"))
712+
fun = _RecordAccess(param, plane_(plane, key="index"))
713713
if param == "chromaticity":
714714
needs.add(Need.CHROMATICITY)
715715
super().__init__(fun, needs=needs, name=name, **kwargs)
@@ -807,11 +807,11 @@ def __init__(
807807
ax_ = plane_
808808

809809
needs = {Need.LOCALOPTICS}
810-
name = self._set_name(name, param, ax_(plane, "code"))
810+
name = self._set_name(name, param, ax_(plane, key="code"))
811811
if callable(param):
812812
fun = param
813813
else:
814-
fun = _RecordAccess(param, _all_rows(ax_(plane, "index")))
814+
fun = _RecordAccess(param, _all_rows(ax_(plane, key="index")))
815815
if param == "mu" or all_points:
816816
needs.add(Need.ALL_POINTS)
817817
if param in {"W", "Wp", "dalpha", "dbeta", "dmu", "ddispersion", "dR"}:
@@ -892,8 +892,8 @@ def __init__(
892892
The *target*, *weight* and *bounds* inputs must be broadcastable to the
893893
shape of *value*.
894894
"""
895-
name = self._set_name(name, "trajectory", axis_(axis, "code"))
896-
fun = _ArrayAccess(axis_(axis, "index"))
895+
name = self._set_name(name, "trajectory", axis_(axis, key="code"))
896+
fun = _ArrayAccess(axis_(axis, key="index"))
897897
needs = {Need.TRAJECTORY}
898898
super().__init__(fun, refpts, needs=needs, name=name, **kwargs)
899899

@@ -945,11 +945,11 @@ def __init__(
945945
946946
Observe the horizontal emittance
947947
"""
948-
name = self._set_name(name, param, plane_(plane, "code"))
948+
name = self._set_name(name, param, plane_(plane, key="code"))
949949
if callable(param):
950950
fun = param
951951
else:
952-
fun = _RecordAccess(param, plane_(plane, "index"))
952+
fun = _RecordAccess(param, plane_(plane, key="index"))
953953
needs = {Need.EMITTANCE}
954954
super().__init__(fun, needs=needs, name=name, **kwargs)
955955

@@ -1012,10 +1012,10 @@ def GlobalOpticsObservable(
10121012
"""
10131013
if param == "tune" and use_integer:
10141014
# noinspection PyProtectedMember
1015-
name = ElementObservable._set_name(name, param, plane_(plane, "code"))
1015+
name = ElementObservable._set_name(name, param, plane_(plane, key="code"))
10161016
return LocalOpticsObservable(
10171017
End,
1018-
_Tune(plane_(plane, "index")),
1018+
_Tune(plane_(plane, key="index")),
10191019
name=name,
10201020
summary=True,
10211021
all_points=True,

0 commit comments

Comments
 (0)