Skip to content

Commit b8fea2a

Browse files
authoredMay 25, 2024··
Variables (#700)
* rebased on master * Element array attributes are always ParamArrays * Documentation of Variables * Documentation of Parameters * PEP8 * [0.0, k] instead of [0, k] * Removed parameters * merged from master * lattice_variables * Refactored Variable to VariableBase * Refactored Variable to VariableBase * "black" formatting * merged elements.py from master * str and repr * str and repr * small fixes * utils.py * changed Number * minimised differences with master * Optimised imports * Documentation * PEP8 checks * variable notebook * axisdef * rebased on master * merged from master * merged from master * code style * Changed the examples of custom variables in the "variables" notebook
1 parent 588edbc commit b8fea2a

12 files changed

+1784
-154
lines changed
 

‎atintegrators/atelem.c

+4-2
Original file line numberDiff line numberDiff line change
@@ -193,16 +193,18 @@ static long atGetLong(const PyObject *element, const char *name)
193193
{
194194
const PyObject *attr = PyObject_GetAttrString((PyObject *)element, name);
195195
if (!attr) return 0L;
196+
long l = PyLong_AsLong((PyObject *)attr);
196197
Py_DECREF(attr);
197-
return PyLong_AsLong((PyObject *)attr);
198+
return l;
198199
}
199200

200201
static double atGetDouble(const PyObject *element, const char *name)
201202
{
202203
const PyObject *attr = PyObject_GetAttrString((PyObject *)element, name);
203204
if (!attr) return 0.0;
205+
double d = PyFloat_AsDouble((PyObject *)attr);
204206
Py_DECREF(attr);
205-
return PyFloat_AsDouble((PyObject *)attr);
207+
return d;
206208
}
207209

208210
static long atGetOptionalLong(const PyObject *element, const char *name, long default_value)

‎docs/conf.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,9 @@
4040
'sphinx.ext.intersphinx',
4141
'sphinx.ext.githubpages',
4242
'sphinx.ext.viewcode',
43-
'myst_parser',
43+
'myst_nb',
4444
'sphinx_copybutton',
45+
'sphinx_design',
4546
]
4647

4748
intersphinx_mapping = {'python': ('https://docs.python.org/3', None),
@@ -56,7 +57,7 @@
5657
# List of patterns, relative to source directory, that match files and
5758
# directories to ignore when looking for source files.
5859
# This pattern also affects html_static_path and html_extra_path.
59-
exclude_patterns = ["README.rst", "**/*.so"]
60+
exclude_patterns = ["README.rst", "**/*.so", "_build/*"]
6061
rst_prolog = """
6162
.. role:: pycode(code)
6263
:language: python
@@ -92,6 +93,8 @@
9293
"deflist"
9394
]
9495
myst_heading_anchors = 3
96+
nb_execution_mode = "auto"
97+
nb_execution_allow_errors = True
9598

9699
# -- Options for HTML output -------------------------------------------------
97100

‎docs/p/index.rst

+1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ Sub-packages
2626

2727
howto/Installation
2828
howto/Primer
29+
notebooks/variables
2930

3031
.. toctree::
3132
:maxdepth: 2

‎docs/p/notebooks/variables.ipynb

+942
Large diffs are not rendered by default.

‎pyat/at/future.py

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .lattice.variables import *
2+
from .lattice.lattice_variables import *

‎pyat/at/lattice/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,14 @@
99
from .axisdef import *
1010
from .options import DConstant, random
1111
from .particle_object import Particle
12+
# from .variables import *
13+
from .variables import VariableList
1214
from .elements import *
1315
from .rectangular_bend import *
1416
from .idtable_element import InsertionDeviceKickMap
1517
from .utils import *
1618
from .lattice_object import *
19+
# from .lattice_variables import *
1720
from .cavity_access import *
1821
from .variable_elements import *
1922
from .deprecated import *

‎pyat/at/lattice/axisdef.py

+18-16
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""Helper functions for axis and plane descriptions"""
2+
23
from __future__ import annotations
34
from typing import Optional, Union
5+
46
# For sys.version_info.minor < 9:
57
from typing import Tuple
68

@@ -16,31 +18,31 @@
1618
ct=dict(index=5, label=r"$\beta c \tau$", unit=" [m]"),
1719
)
1820
for xk, xv in [it for it in _axis_def.items()]:
19-
xv['code'] = xk
20-
_axis_def[xv['index']] = xv
21+
xv["code"] = xk
22+
_axis_def[xv["index"]] = xv
2123
_axis_def[xk.upper()] = xv
22-
_axis_def['delta'] = _axis_def['dp']
23-
_axis_def['xp'] = _axis_def['px'] # For backward compatibility
24-
_axis_def['yp'] = _axis_def['py'] # For backward compatibility
25-
_axis_def['s'] = _axis_def['ct']
26-
_axis_def['S'] = _axis_def['ct']
27-
_axis_def[None] = dict(index=slice(None), label="", unit="", code=":")
24+
_axis_def["delta"] = _axis_def["dp"]
25+
_axis_def["xp"] = _axis_def["px"] # For backward compatibility
26+
_axis_def["yp"] = _axis_def["py"] # For backward compatibility
27+
_axis_def["s"] = _axis_def["ct"]
28+
_axis_def["S"] = _axis_def["ct"]
29+
_axis_def[None] = dict(index=None, label="", unit="", code=":")
2830
_axis_def[Ellipsis] = dict(index=Ellipsis, label="", unit="", code="...")
2931

3032
_plane_def = dict(
3133
x=dict(index=0, label="x", unit=" [m]"),
3234
y=dict(index=1, label="y", unit=" [m]"),
33-
z=dict(index=2, label="z", unit="")
35+
z=dict(index=2, label="z", unit=""),
3436
)
3537
for xk, xv in [it for it in _plane_def.items()]:
36-
xv['code'] = xk
37-
_plane_def[xv['index']] = xv
38+
xv["code"] = xk
39+
_plane_def[xv["index"]] = xv
3840
_plane_def[xk.upper()] = xv
39-
_plane_def['h'] = _plane_def['x']
40-
_plane_def['v'] = _plane_def['y']
41-
_plane_def['H'] = _plane_def['x']
42-
_plane_def['V'] = _plane_def['y']
43-
_plane_def[None] = dict(index=slice(None), label="", unit="", code=":")
41+
_plane_def["h"] = _plane_def["x"]
42+
_plane_def["v"] = _plane_def["y"]
43+
_plane_def["H"] = _plane_def["x"]
44+
_plane_def["V"] = _plane_def["y"]
45+
_plane_def[None] = dict(index=None, label="", unit="", code=":")
4446
_plane_def[Ellipsis] = dict(index=Ellipsis, label="", unit="", code="...")
4547

4648

‎pyat/at/lattice/elements.py

+129-107
Large diffs are not rendered by default.

‎pyat/at/lattice/lattice_object.py

+14-10
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from .elements import Element
3434
from .particle_object import Particle
3535
from .utils import AtError, AtWarning, Refpts
36-
from .utils import get_s_pos, get_elements,get_value_refpts, set_value_refpts
36+
from .utils import get_s_pos, get_elements, get_value_refpts, set_value_refpts
3737
# noinspection PyProtectedMember
3838
from .utils import get_uint32_index, get_bool_index, _refcount, Uint32Refpts
3939
from .utils import refpts_iterator, checktype, set_shift, set_tilt, get_geometry
@@ -296,11 +296,11 @@ def _addition_filter(self, elems: Iterable[Element], copy_elements=False):
296296
if cavities and not hasattr(self, '_cell_harmnumber'):
297297
cavities.sort(key=lambda el: el.Frequency)
298298
try:
299-
self._cell_harmnumber = getattr(cavities[0], 'HarmNumber')
299+
self._cell_harmnumber = cavities[0].HarmNumber
300300
except AttributeError:
301301
length += self.get_s_pos(len(self))[0]
302302
rev = self.beta * clight / length
303-
frequency = getattr(cavities[0], 'Frequency')
303+
frequency = cavities[0].Frequency
304304
self._cell_harmnumber = int(round(frequency / rev))
305305
self._radiation |= params.pop('_radiation')
306306

@@ -314,13 +314,13 @@ def insert(self, idx: SupportsIndex, elem: Element, copy_elements=False):
314314
If :py:obj:`True` a deep copy of elem
315315
is used.
316316
"""
317-
# noinspection PyUnusedLocal
318317
# scan the new element to update it
319-
elist = list(self._addition_filter([elem],
318+
elist = list(self._addition_filter([elem], # noqa: F841
320319
copy_elements=copy_elements))
321320
super().insert(idx, elem)
322321

323322
def extend(self, elems: Iterable[Element], copy_elements=False):
323+
# noinspection PyUnresolvedReferences
324324
r"""This method adds all the elements of `elems` to the end of the
325325
lattice. The behavior is the same as for a :py:obj:`list`
326326
@@ -343,6 +343,7 @@ def extend(self, elems: Iterable[Element], copy_elements=False):
343343
super().extend(elems)
344344

345345
def append(self, elem: Element, copy_elements=False):
346+
# noinspection PyUnresolvedReferences
346347
r"""This method overwrites the inherited method :py:meth:`list.append()`,
347348
its behavior is changed, it accepts only AT lattice elements
348349
:py:obj:`Element` as input argument.
@@ -361,6 +362,7 @@ def append(self, elem: Element, copy_elements=False):
361362
self.extend([elem], copy_elements=copy_elements)
362363

363364
def repeat(self, n: int, copy_elements: bool = True):
365+
# noinspection SpellCheckingInspection,PyUnresolvedReferences,PyRedeclaration
364366
r"""This method allows to repeat the lattice `n` times.
365367
If `n` does not divide `ring.periodicity`, the new ring
366368
periodicity is set to 1, otherwise it is set to
@@ -405,6 +407,7 @@ def copy_fun(elem, copy):
405407

406408
def concatenate(self, *lattices: Iterable[Element],
407409
copy_elements=False, copy=False):
410+
# noinspection PyUnresolvedReferences,SpellCheckingInspection,PyRedeclaration
408411
"""Concatenate several `Iterable[Element]` with the lattice
409412
410413
Equivalent syntaxes:
@@ -439,6 +442,7 @@ def concatenate(self, *lattices: Iterable[Element],
439442
return lattice if copy else None
440443

441444
def reverse(self, copy=False):
445+
# noinspection PyUnresolvedReferences
442446
r"""Reverse the order of the lattice and swapt the faces
443447
of elements. Alignment errors are not swapped
444448
@@ -516,7 +520,7 @@ def copy(self) -> Lattice:
516520
def deepcopy(self) -> Lattice:
517521
"""Returns a deep copy of the lattice"""
518522
return copy.deepcopy(self)
519-
523+
520524
def slice_elements(self, refpts: Refpts, slices: int = 1) -> Lattice:
521525
"""Create a new lattice by slicing the elements at refpts
522526
@@ -538,7 +542,7 @@ def slice_generator(_):
538542
else:
539543
yield el
540544

541-
return Lattice(slice_generator, iterator=self.attrs_filter)
545+
return Lattice(slice_generator, iterator=self.attrs_filter)
542546

543547
def slice(self, size: Optional[float] = None, slices: Optional[int] = 1) \
544548
-> Lattice:
@@ -635,8 +639,8 @@ def energy(self) -> float:
635639
def energy(self, energy: float):
636640
# Set the Energy attribute of radiating elements
637641
for elem in self:
638-
if (isinstance(elem, (elt.RFCavity, elt.Wiggler)) or
639-
elem.PassMethod.endswith('RadPass')):
642+
if (isinstance(elem, (elt.RFCavity, elt.Wiggler))
643+
or elem.PassMethod.endswith('RadPass')):
640644
elem.Energy = energy
641645
# Set the energy attribute of the Lattice
642646
# Use a numpy scalar to allow division by zero
@@ -1471,7 +1475,7 @@ def params_filter(params, elem_filter: Filter, *args) -> Generator[Element, None
14711475
cavities = []
14721476
cell_length = 0
14731477

1474-
for idx, elem in enumerate(elem_filter(params, *args)):
1478+
for elem in elem_filter(params, *args):
14751479
if isinstance(elem, elt.RFCavity):
14761480
cavities.append(elem)
14771481
elif hasattr(elem, 'Energy'):

‎pyat/at/lattice/lattice_variables.py

+137
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
"""Variables are **references** to scalar attributes of lattice elements. There are 2
2+
kinds of element variables:
3+
4+
- an :py:class:`ElementVariable` is associated to an element object, and acts on all
5+
occurences of this object. But it will not affect any copy, neither shallow nor deep,
6+
of the original object,
7+
- a :py:class:`RefptsVariable` is not associated to an element object, but to an element
8+
location in a :py:class:`.Lattice`. It acts on any copy of the initial lattice. A
9+
*ring* argument must be provided to the *set* and *get* methods to identify the
10+
lattice, which may be a possibly modified copy of the original lattice
11+
"""
12+
13+
from __future__ import annotations
14+
15+
__all__ = ["RefptsVariable", "ElementVariable"]
16+
17+
from collections.abc import Sequence
18+
from typing import Union, Optional
19+
20+
import numpy as np
21+
22+
from .elements import Element
23+
from .lattice_object import Lattice
24+
from .utils import Refpts, getval, setval
25+
from .variables import VariableBase
26+
27+
28+
class RefptsVariable(VariableBase):
29+
r"""A reference to a scalar attribute of :py:class:`.Lattice` elements.
30+
31+
It can refer to:
32+
33+
* a scalar attribute or
34+
* an element of an array attribute
35+
36+
of one or several :py:class:`.Element`\ s of a lattice.
37+
38+
A :py:class:`RefptsVariable` is not associated to element objets themselves, but
39+
to the location of these elements in a lattice. So a :py:class:`RefptsVariable`
40+
will act equally on any copy of the initial ring.
41+
As a consequence, a *ring* keyword argument (:py:class:`.Lattice` object) must be
42+
supplied for getting or setting the variable.
43+
"""
44+
45+
def __init__(
46+
self, refpts: Refpts, attrname: str, index: Optional[int] = None, **kwargs
47+
):
48+
r"""
49+
Parameters:
50+
refpts: Location of variable :py:class:`.Element`\ s
51+
attrname: Attribute name
52+
index: Index in the attribute array. Use :py:obj:`None` for
53+
scalar attributes
54+
55+
Keyword Args:
56+
name (str): Name of the Variable. Default: ``''``
57+
bounds (tuple[float, float]): Lower and upper bounds of the
58+
variable value. Default: (-inf, inf)
59+
delta (float): Step. Default: 1.0
60+
ring (Lattice): If specified, it is used to get and store the initial
61+
value of the variable. Otherwise, the initial value is set to None
62+
"""
63+
self._getf = getval(attrname, index=index)
64+
self._setf = setval(attrname, index=index)
65+
self.refpts = refpts
66+
super().__init__(**kwargs)
67+
68+
def _setfun(self, value: float, ring: Lattice = None):
69+
if ring is None:
70+
raise ValueError("Can't set values if ring is None")
71+
for elem in ring.select(self.refpts):
72+
self._setf(elem, value)
73+
74+
def _getfun(self, ring: Lattice = None) -> float:
75+
if ring is None:
76+
raise ValueError("Can't get values if ring is None")
77+
values = np.array([self._getf(elem) for elem in ring.select(self.refpts)])
78+
return np.average(values)
79+
80+
81+
class ElementVariable(VariableBase):
82+
r"""A reference to a scalar attribute of :py:class:`.Lattice` elements.
83+
84+
It can refer to:
85+
86+
* a scalar attribute or
87+
* an element of an array attribute
88+
89+
of one or several :py:class:`.Element`\ s of a lattice.
90+
91+
An :py:class:`ElementVariable` is associated to an element object, and acts on all
92+
occurrences of this object. But it will not affect any copy, neither shallow nor
93+
deep, of the original object.
94+
"""
95+
96+
def __init__(
97+
self,
98+
elements: Union[Element, Sequence[Element]],
99+
attrname: str,
100+
index: Optional[int] = None,
101+
**kwargs,
102+
):
103+
r"""
104+
Parameters:
105+
elements: :py:class:`.Element` or Sequence[Element] whose
106+
attribute is varied
107+
attrname: Attribute name
108+
index: Index in the attribute array. Use :py:obj:`None` for
109+
scalar attributes
110+
111+
Keyword Args:
112+
name (str): Name of the Variable. Default: ``''``
113+
bounds (tuple[float, float]): Lower and upper bounds of the
114+
variable value. Default: (-inf, inf)
115+
delta (float): Step. Default: 1.0
116+
"""
117+
# Ensure the uniqueness of elements
118+
if isinstance(elements, Element):
119+
self._elements = {elements}
120+
else:
121+
self._elements = set(elements)
122+
self._getf = getval(attrname, index=index)
123+
self._setf = setval(attrname, index=index)
124+
super().__init__(**kwargs)
125+
126+
def _setfun(self, value: float, **kwargs):
127+
for elem in self._elements:
128+
self._setf(elem, value)
129+
130+
def _getfun(self, **kwargs) -> float:
131+
values = np.array([self._getf(elem) for elem in self._elements])
132+
return np.average(values)
133+
134+
@property
135+
def elements(self):
136+
"""Return the set of elements acted upon by the variable"""
137+
return self._elements

‎pyat/at/lattice/utils.py

+71-17
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from typing import Union, Tuple, List, Type
3838
from enum import Enum
3939
from itertools import compress
40+
from operator import attrgetter
4041
from fnmatch import fnmatch
4142
from .elements import Element, Dipole
4243

@@ -58,7 +59,7 @@
5859
'set_shift', 'set_tilt', 'set_rotation',
5960
'tilt_elem', 'shift_elem', 'rotate_elem',
6061
'get_value_refpts', 'set_value_refpts', 'Refpts',
61-
'get_geometry']
62+
'get_geometry', 'setval', 'getval']
6263

6364
_axis_def = dict(
6465
x=dict(index=0, label="x", unit=" [m]"),
@@ -113,6 +114,72 @@ def _type_error(refpts, types):
113114
"Invalid refpts type {0}. Allowed types: {1}".format(tp, types))
114115

115116

117+
# setval and getval return pickleable functions: no inner, nested function
118+
# are allowed. So nested functions are replaced be module-level callable
119+
# class instances
120+
class _AttrItemGetter(object):
121+
__slots__ = ["attrname", "index"]
122+
123+
def __init__(self, attrname: str, index: int):
124+
self.attrname = attrname
125+
self.index = index
126+
127+
def __call__(self, elem):
128+
return getattr(elem, self.attrname)[self.index]
129+
130+
131+
def getval(attrname: str, index: Optional[int] = None) -> Callable:
132+
"""Return a callable object which fetches item *index* of
133+
attribute *attrname* of its operand. Examples:
134+
135+
- After ``f = getval('Length')``, ``f(elem)`` returns ``elem.Length``
136+
- After ``f = getval('PolynomB, index=1)``, ``f(elem)`` returns
137+
``elem.PolynomB[1]``
138+
139+
"""
140+
if index is None:
141+
return attrgetter(attrname)
142+
else:
143+
return _AttrItemGetter(attrname, index)
144+
145+
146+
class _AttrSetter(object):
147+
__slots__ = ["attrname"]
148+
149+
def __init__(self, attrname: str):
150+
self.attrname = attrname
151+
152+
def __call__(self, elem, value):
153+
setattr(elem, self.attrname, value)
154+
155+
156+
class _AttrItemSetter(object):
157+
__slots__ = ["attrname", "index"]
158+
159+
def __init__(self, attrname: str, index: int):
160+
self.attrname = attrname
161+
self.index = index
162+
163+
def __call__(self, elem, value):
164+
getattr(elem, self.attrname)[self.index] = value
165+
166+
167+
def setval(attrname: str, index: Optional[int] = None) -> Callable:
168+
"""Return a callable object which sets the value of item *index* of
169+
attribute *attrname* of its 1st argument to it 2nd orgument.
170+
171+
- After ``f = setval('Length')``, ``f(elem, value)`` is equivalent to
172+
``elem.Length = value``
173+
- After ``f = setval('PolynomB, index=1)``, ``f(elem, value)`` is
174+
equivalent to ``elem.PolynomB[1] = value``
175+
176+
"""
177+
if index is None:
178+
return _AttrSetter(attrname)
179+
else:
180+
return _AttrItemSetter(attrname, index)
181+
182+
116183
# noinspection PyIncorrectDocstring
117184
def axis_descr(*args, key=None) -> Tuple:
118185
r"""axis_descr(axis [ ,axis], key=None)
@@ -779,13 +846,7 @@ def get_value_refpts(ring: Sequence[Element], refpts: Refpts,
779846
Returns:
780847
attrvalues: numpy Array of attribute values.
781848
"""
782-
if index is None:
783-
def getf(elem):
784-
return getattr(elem, attrname)
785-
else:
786-
def getf(elem):
787-
return getattr(elem, attrname)[index]
788-
849+
getf = getval(attrname, index=index)
789850
return numpy.array([getf(elem) for elem in refpts_iterator(ring, refpts,
790851
regex=regex)])
791852

@@ -822,13 +883,7 @@ def set_value_refpts(ring: Sequence[Element], refpts: Refpts,
822883
elements are shared with the original lattice.
823884
Any further modification will affect both lattices.
824885
"""
825-
if index is None:
826-
def setf(elem, value):
827-
setattr(elem, attrname, value)
828-
else:
829-
def setf(elem, value):
830-
getattr(elem, attrname)[index] = value
831-
886+
setf = setval(attrname, index=index)
832887
if increment:
833888
attrvalues += get_value_refpts(ring, refpts,
834889
attrname, index=index,
@@ -841,8 +896,7 @@ def setf(elem, value):
841896
# noinspection PyShadowingNames
842897
@make_copy(copy)
843898
def apply(ring, refpts, values, regex):
844-
for elm, val in zip(refpts_iterator(ring, refpts,
845-
regex=regex), values):
899+
for elm, val in zip(refpts_iterator(ring, refpts, regex=regex), values):
846900
setf(elm, val)
847901

848902
return apply(ring, refpts, attrvalues, regex)

‎pyat/at/lattice/variables.py

+458
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)
Please sign in to comment.