Skip to content

Commit 836663b

Browse files
authored
New element properties: dx, dy, tilt (#902)
* new dx, dy, tilt properties * Added tests for the new properties
1 parent c8c5844 commit 836663b

File tree

2 files changed

+120
-42
lines changed

2 files changed

+120
-42
lines changed

pyat/at/lattice/elements.py

+89-31
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,24 @@
1010

1111
import abc
1212
import re
13+
import math
1314
from abc import ABC
1415
from collections.abc import Generator, Iterable
1516
from copy import copy, deepcopy
1617
from typing import Any, Optional
1718

18-
import numpy
19+
import numpy as np
1920

2021
# noinspection PyProtectedMember
2122
from .variables import _nop
2223

24+
_zero6 = np.zeros(6)
25+
_eye6 = np.eye(6, order="F")
2326

24-
def _array(value, shape=(-1,), dtype=numpy.float64):
27+
28+
def _array(value, shape=(-1,), dtype=np.float64):
2529
# Ensure proper ordering(F) and alignment(A) for "C" access in integrators
26-
return numpy.require(value, dtype=dtype, requirements=["F", "A"]).reshape(
30+
return np.require(value, dtype=dtype, requirements=["F", "A"]).reshape(
2731
shape, order="F"
2832
)
2933

@@ -408,7 +412,7 @@ def definition(self) -> tuple[str, tuple, dict]:
408412
keywords = {
409413
k: v
410414
for k, v in attrs.items()
411-
if not numpy.array_equal(v, getattr(defelem, k, None))
415+
if not np.array_equal(v, getattr(defelem, k, None))
412416
}
413417
return self.__class__.__name__, arguments, keywords
414418

@@ -448,6 +452,60 @@ def is_collective(self) -> bool:
448452
""":py:obj:`True` if the element involves collective effects"""
449453
return self._get_collective()
450454

455+
def _getshift(self, idx: int):
456+
t1 = getattr(self, "T1", _zero6)
457+
t2 = getattr(self, "T2", _zero6)
458+
return 0.5 * float(t2[idx] - t1[idx])
459+
460+
def _setshift(self, value: float, idx: int) -> None:
461+
t1 = getattr(self, "T1", _zero6.copy())
462+
t2 = getattr(self, "T2", _zero6.copy())
463+
sm = 0.5 * (t2[idx] + t1[idx])
464+
t2[idx] = sm + value
465+
t1[idx] = sm - value
466+
self.T1 = t1
467+
self.T2 = t2
468+
469+
@property
470+
def dx(self) -> float:
471+
"""Horizontal element shift"""
472+
return self._getshift(0)
473+
474+
@dx.setter
475+
def dx(self, value: float) -> None:
476+
self._setshift(value, 0)
477+
478+
@property
479+
def dy(self) -> float:
480+
"""Vertical element shift"""
481+
return self._getshift(2)
482+
483+
@dy.setter
484+
def dy(self, value: float) -> None:
485+
self._setshift(value, 2)
486+
487+
@property
488+
def tilt(self) -> float:
489+
"""Element tilt"""
490+
r1 = getattr(self, "R1", _eye6)
491+
r2 = getattr(self, "R2", _eye6)
492+
c = float(r2[0, 0] + r1[0, 0])
493+
s = float(r2[2, 0] - r1[2, 0])
494+
return math.atan2(s, c)
495+
496+
@tilt.setter
497+
def tilt(self, value: float) -> None:
498+
r1 = getattr(self, "R1", _eye6.copy())
499+
r2 = getattr(self, "R2", _eye6.copy())
500+
ct, st = math.cos(value), math.sin(value)
501+
r44 = np.diag([ct, ct, ct, ct])
502+
r44[0, 2] = r44[1, 3] = st
503+
r44[2, 0] = r44[3, 1] = -st
504+
r1[:4, :4] = r44
505+
r2[:4, :4] = r44.T
506+
self.R1 = r1
507+
self.R2 = r2
508+
451509

452510
class LongElement(Element):
453511
"""Base class for long elements"""
@@ -480,15 +538,15 @@ def popattr(element, attr):
480538
delattr(element, attr)
481539
return attr, val
482540

483-
frac = numpy.asarray(frac, dtype=float)
541+
frac = np.asarray(frac, dtype=float)
484542
el = self.copy()
485543
# Remove entrance and exit attributes
486544
fin = dict(
487545
popattr(el, key) for key in vars(self) if key in self._entrance_fields
488546
)
489547
fout = dict(popattr(el, key) for key in vars(self) if key in self._exit_fields)
490548
# Split element
491-
element_list = [el._part(f, numpy.sum(frac)) for f in frac]
549+
element_list = [el._part(f, np.sum(frac)) for f in frac]
492550
# Restore entrance and exit attributes
493551
for key, value in fin.items():
494552
setattr(element_list[0], key, value)
@@ -505,7 +563,7 @@ def compatible_field(fieldname):
505563
elif f1 is None or f2 is None: # only one
506564
return False
507565
else: # both
508-
return numpy.all(f1 == f2)
566+
return np.all(f1 == f2)
509567

510568
if not (type(other) is type(self) and self.PassMethod == other.PassMethod):
511569
return False
@@ -538,13 +596,13 @@ def __init__(self, family_name: str, **kwargs):
538596
Default PassMethod: ``BeamMomentsPass``
539597
"""
540598
kwargs.setdefault("PassMethod", "BeamMomentsPass")
541-
self._stds = numpy.zeros((6, 1, 1), order="F")
542-
self._means = numpy.zeros((6, 1, 1), order="F")
599+
self._stds = np.zeros((6, 1, 1), order="F")
600+
self._means = np.zeros((6, 1, 1), order="F")
543601
super().__init__(family_name, **kwargs)
544602

545603
def set_buffers(self, nturns, nbunch):
546-
self._stds = numpy.zeros((6, nbunch, nturns), order="F")
547-
self._means = numpy.zeros((6, nbunch, nturns), order="F")
604+
self._stds = np.zeros((6, nbunch, nturns), order="F")
605+
self._means = np.zeros((6, nbunch, nturns), order="F")
548606

549607
@property
550608
def stds(self):
@@ -583,20 +641,20 @@ def __init__(self, family_name: str, nslice: int, **kwargs):
583641
self.startturn = self._startturn
584642
self.endturn = self._endturn
585643
self._dturns = self.endturn - self.startturn
586-
self._stds = numpy.zeros((3, nslice, self._dturns), order="F")
587-
self._means = numpy.zeros((3, nslice, self._dturns), order="F")
588-
self._spos = numpy.zeros((nslice, self._dturns), order="F")
589-
self._weights = numpy.zeros((nslice, self._dturns), order="F")
644+
self._stds = np.zeros((3, nslice, self._dturns), order="F")
645+
self._means = np.zeros((3, nslice, self._dturns), order="F")
646+
self._spos = np.zeros((nslice, self._dturns), order="F")
647+
self._weights = np.zeros((nslice, self._dturns), order="F")
590648
self.set_buffers(self._endturn, 1)
591649

592650
def set_buffers(self, nturns, nbunch):
593651
self.endturn = min(self.endturn, nturns)
594652
self._dturns = self.endturn - self.startturn
595653
self._nbunch = nbunch
596-
self._stds = numpy.zeros((3, nbunch * self.nslice, self._dturns), order="F")
597-
self._means = numpy.zeros((3, nbunch * self.nslice, self._dturns), order="F")
598-
self._spos = numpy.zeros((nbunch * self.nslice, self._dturns), order="F")
599-
self._weights = numpy.zeros((nbunch * self.nslice, self._dturns), order="F")
654+
self._stds = np.zeros((3, nbunch * self.nslice, self._dturns), order="F")
655+
self._means = np.zeros((3, nbunch * self.nslice, self._dturns), order="F")
656+
self._spos = np.zeros((nbunch * self.nslice, self._dturns), order="F")
657+
self._weights = np.zeros((nbunch * self.nslice, self._dturns), order="F")
600658

601659
@property
602660
def stds(self):
@@ -730,11 +788,11 @@ def insert(
730788
"""
731789
frac, elements = zip(*insert_list)
732790
lg = [0.0 if el is None else el.Length for el in elements]
733-
fr = numpy.asarray(frac, dtype=float)
734-
lg = 0.5 * numpy.asarray(lg, dtype=float) / self.Length
735-
drfrac = numpy.hstack((fr - lg, 1.0)) - numpy.hstack((0.0, fr + lg))
791+
fr = np.asarray(frac, dtype=float)
792+
lg = 0.5 * np.asarray(lg, dtype=float) / self.Length
793+
drfrac = np.hstack((fr - lg, 1.0)) - np.hstack((0.0, fr + lg))
736794
long_elems = drfrac != 0.0
737-
drifts = numpy.ndarray((len(drfrac),), dtype="O")
795+
drifts = np.ndarray((len(drfrac),), dtype="O")
738796
drifts[long_elems] = self.divide(drfrac[long_elems])
739797
nline = len(drifts) + len(elements)
740798
line = [None] * nline # type: list[Optional[Element]]
@@ -783,12 +841,12 @@ def __init__(self, family_name: str, poly_a, poly_b, **kwargs):
783841
"""
784842

785843
def getpol(poly):
786-
nonzero = numpy.flatnonzero(poly != 0.0)
844+
nonzero = np.flatnonzero(poly != 0.0)
787845
return poly, len(poly), nonzero[-1] if len(nonzero) > 0 else -1
788846

789847
def lengthen(poly, dl):
790848
if dl > 0:
791-
return numpy.concatenate((poly, numpy.zeros(dl)))
849+
return np.concatenate((poly, np.zeros(dl)))
792850
else:
793851
return poly
794852

@@ -1198,7 +1256,7 @@ def __init__(self, family_name: str, m66=None, **kwargs):
11981256
Default PassMethod: ``Matrix66Pass``
11991257
"""
12001258
if m66 is None:
1201-
m66 = numpy.identity(6)
1259+
m66 = np.identity(6)
12021260
kwargs.setdefault("PassMethod", "Matrix66Pass")
12031261
kwargs.setdefault("M66", m66)
12041262
super().__init__(family_name, **kwargs)
@@ -1311,24 +1369,24 @@ def __init__(
13111369
if taux == 0.0:
13121370
dampx = 1
13131371
else:
1314-
dampx = numpy.exp(-1 / taux)
1372+
dampx = np.exp(-1 / taux)
13151373

13161374
assert tauy >= 0.0, "tauy must be greater than or equal to 0"
13171375
if tauy == 0.0:
13181376
dampy = 1
13191377
else:
1320-
dampy = numpy.exp(-1 / tauy)
1378+
dampy = np.exp(-1 / tauy)
13211379

13221380
assert tauz >= 0.0, "tauz must be greater than or equal to 0"
13231381
if tauz == 0.0:
13241382
dampz = 1
13251383
else:
1326-
dampz = numpy.exp(-1 / tauz)
1384+
dampz = np.exp(-1 / tauz)
13271385

13281386
kwargs.setdefault("PassMethod", self.default_pass[True])
13291387
kwargs.setdefault("U0", U0)
13301388
kwargs.setdefault(
1331-
"damp_mat_diag", numpy.array([dampx, dampx, dampy, dampy, dampz, dampz])
1389+
"damp_mat_diag", np.array([dampx, dampx, dampy, dampy, dampz, dampz])
13321390
)
13331391

13341392
super().__init__(family_name, **kwargs)
@@ -1449,7 +1507,7 @@ class QuantumDiffusion(_DictLongtMotion, Element):
14491507
default_pass = {False: "IdentityPass", True: "QuantDiffPass"}
14501508
_conversions = dict(Element._conversions, Lmatp=_array66)
14511509

1452-
def __init__(self, family_name: str, lmatp: numpy.ndarray, **kwargs):
1510+
def __init__(self, family_name: str, lmatp: np.ndarray, **kwargs):
14531511
"""Quantum diffusion element
14541512
14551513
Args:

pyat/test/test_lattice_utils.py

+31-11
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import numpy
22
import pytest
33

4+
import math
5+
from at.lattice import elements as elt
46
from at.lattice import checktype, get_cells, refpts_iterator, get_elements
57
from at.lattice import elements, uint32_refpts, bool_refpts, checkattr
68
from at.lattice import get_s_pos, tilt_elem, shift_elem, set_tilt, set_shift
@@ -176,20 +178,38 @@ def test_get_s_pos_returns_all_pts_for_lat_with_2_elements_using_bool_refpts():
176178
)
177179

178180

179-
def test_tilt_elem(simple_ring):
180-
tilt_elem(simple_ring[0], (numpy.pi / 4))
181-
v = 1 / 2**0.5
181+
def test_tilt_elem():
182+
elem = elt.Drift("Drift", 1.0)
183+
# Test tilt_elem function
184+
tilt_elem(elem, math.pi / 4.0)
185+
v = math.sqrt(2.0) / 2.0
182186
a = numpy.diag([v, v, v, v, 1.0, 1.0])
183187
a[0, 2], a[1, 3], a[2, 0], a[3, 1] = v, v, -v, -v
184-
numpy.testing.assert_allclose(simple_ring[0].R1, a)
185-
numpy.testing.assert_allclose(simple_ring[0].R2, a.T)
186-
187-
188-
def test_shift_elem(simple_ring):
189-
shift_elem(simple_ring[2], 1.0, 0.5)
188+
numpy.testing.assert_allclose(elem.R1, a, atol=1.0e-15)
189+
numpy.testing.assert_allclose(elem.R2, a.T, atol=1.0e-15)
190+
numpy.testing.assert_allclose(elem.tilt, numpy.pi / 4.0)
191+
# Test tilt property
192+
elem.tilt = math.pi / 2.0
193+
a = numpy.diag([0.0, 0.0, 0.0, 0.0, 1.0, 1.0])
194+
a[0, 2], a[1, 3], a[2, 0], a[3, 1] = 1.0, 1.0, -1.0, -1.0
195+
numpy.testing.assert_allclose(elem.R1, a, atol=1.0e-15)
196+
numpy.testing.assert_allclose(elem.R2, a.T, atol=1.0e-15)
197+
198+
199+
def test_shift_elem():
200+
elem = elt.Drift("Drift", 1.0)
201+
# Test shift_elem function
202+
shift_elem(elem, 1.0, 0.5)
190203
a = numpy.array([1.0, 0.0, 0.5, 0.0, 0.0, 0.0])
191-
numpy.testing.assert_equal(simple_ring[2].T1, -a)
192-
numpy.testing.assert_equal(simple_ring[2].T2, a)
204+
numpy.testing.assert_equal(elem.T1, -a)
205+
numpy.testing.assert_equal(elem.T2, a)
206+
numpy.testing.assert_equal(elem.dx, a[0])
207+
numpy.testing.assert_equal(elem.dy, a[2])
208+
# Test dx, dy properties
209+
elem.dx = -2.0
210+
elem.dy = -1.0
211+
numpy.testing.assert_equal(elem.T1, 2 * a)
212+
numpy.testing.assert_equal(elem.T2, -2 * a)
193213

194214

195215
def test_set_tilt(simple_ring):

0 commit comments

Comments
 (0)