Skip to content

Commit 3b5566e

Browse files
authored
Fix the energy loss computation (#877)
* keep only element-by-element tracking in get_energy_lss * update tests * Update Matlab tests
1 parent b90e44d commit 3b5566e

File tree

6 files changed

+256
-580
lines changed

6 files changed

+256
-580
lines changed

.github/workflows/python-tests.yml

+8-13
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,19 @@ jobs:
1515

1616
strategy:
1717
matrix:
18-
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12', '3.13']
19-
os: [macos-latest, ubuntu-latest, windows-latest]
20-
exclude:
21-
- os: windows-latest
22-
python-version: '3.7'
23-
- os: macos-latest
24-
python-version: '3.7'
25-
- os: macos-latest
26-
python-version: '3.8'
27-
- os: macos-latest
28-
python-version: '3.9'
18+
python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
19+
os: [macos-13, macos-latest, ubuntu-latest, windows-latest]
2920
include:
3021
- os: macos-13
3122
python-version: '3.7'
3223
- os: macos-13
3324
python-version: '3.8'
34-
- os: macos-13
35-
python-version: '3.9'
25+
- os: ubuntu-22.04
26+
python-version: '3.7'
27+
- os: ubuntu-22.04
28+
python-version: '3.8'
29+
- os: windows-latest
30+
python-version: '3.8'
3631

3732

3833
steps:

atmat/atphysics/Radiation/atgetU0.m

+8-6
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,14 @@
7575
function U0=tracking(ring)
7676
% Ensure 6d is enabled
7777
check_6d(ring,true);
78-
% Turn cavities off
79-
ringtmp=atdisable_6d(ring,'allpass','','cavipass','auto',...
80-
'quantdiffpass','auto','simplequantdiffpass','auto');
81-
o0=zeros(6,1);
82-
o6=ringpass(ringtmp,o0);
83-
U0=-o6(5)*energy;
78+
radiating=atgetcells(ring,'PassMethod','*RadPass');
79+
sumd=sum(cellfun(@comp, ring(radiating)));
80+
U0=-sumd*energy;
81+
82+
function delta = comp(elem)
83+
rout=elempass(elem,zeros(6,1),'Energy',energy);
84+
delta=rout(5);
85+
end
8486
end
8587

8688
end

atmat/attests/pytests.m

+2-2
Original file line numberDiff line numberDiff line change
@@ -183,8 +183,8 @@ function tunechrom6(testCase,lat2,dp)
183183
[mtune,mchrom]=tunechrom(mlat,'get_chrom');
184184
ptune=double(plat.get_tune());
185185
pchrom=double(plat.get_chrom());
186-
testCase.verifyEqual(mod(mtune*periodicity,1),ptune,AbsTol=1.e-9);
187-
testCase.verifyEqual(mchrom*periodicity,pchrom,RelTol=1.e-4,AbsTol=3.e-4);
186+
testCase.verifyEqual(mod(mtune*periodicity,1),ptune,AbsTol=2.5e-9);
187+
testCase.verifyEqual(mchrom*periodicity,pchrom,RelTol=3.e-4,AbsTol=2.e-4);
188188
end
189189

190190
function linopt1(testCase,dp)

pyat/at/physics/energy_loss.py

+100-84
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,23 @@
1+
from __future__ import annotations
2+
3+
__all__ = ["get_energy_loss", "set_cavity_phase", "ELossMethod", "get_timelag_fromU0"]
4+
15
from enum import Enum
26
from warnings import warn
3-
from math import pi
4-
from typing import Optional, Tuple
5-
import numpy
7+
from collections.abc import Sequence
8+
9+
import numpy as np
610
from scipy.optimize import least_squares
11+
12+
from at.constants import clight, Cgamma
713
from at.lattice import Lattice, Dipole, Wiggler, RFCavity, Refpts, EnergyLoss
814
from at.lattice import check_radiation, AtError, AtWarning
9-
from at.lattice import QuantumDiffusion, Collective, SimpleQuantDiff
1015
from at.lattice import get_bool_index, set_value_refpts
11-
from at.constants import clight, Cgamma
12-
from at.tracking import internal_lpass
13-
14-
__all__ = ['get_energy_loss', 'set_cavity_phase', 'ELossMethod',
15-
'get_timelag_fromU0']
1616

1717

1818
class ELossMethod(Enum):
1919
"""methods for the computation of energy losses"""
20+
2021
#: The losses are obtained from
2122
#: :math:`E_{loss}=C_\gamma/2\pi . E^4 . I_2`.
2223
#: Takes into account bending magnets and wigglers.
@@ -26,9 +27,9 @@ class ELossMethod(Enum):
2627
TRACKING = 2
2728

2829

29-
def get_energy_loss(ring: Lattice,
30-
method: Optional[ELossMethod] = ELossMethod.INTEGRAL
31-
) -> float:
30+
def get_energy_loss(
31+
ring: Lattice, method: ELossMethod | None = ELossMethod.INTEGRAL
32+
) -> float:
3233
"""Computes the energy loss per turn
3334
3435
Parameters:
@@ -42,153 +43,168 @@ def get_energy_loss(ring: Lattice,
4243

4344
# noinspection PyShadowingNames
4445
def integral(ring):
45-
"""Losses = Cgamma / 2pi * EGeV^4 * i2
46-
"""
46+
"""Losses = Cgamma / 2pi * EGeV^4 * i2"""
4747

4848
def wiggler_i2(wiggler: Wiggler):
4949
rhoinv = wiggler.Bmax / ring.BRho
5050
coefh = wiggler.By[1, :]
5151
coefv = wiggler.Bx[1, :]
52-
return wiggler.Length * (numpy.sum(coefh * coefh) + numpy.sum(
53-
coefv*coefv)) * rhoinv ** 2 / 2
52+
return (
53+
wiggler.Length
54+
* (np.sum(coefh * coefh) + np.sum(coefv * coefv))
55+
* rhoinv**2
56+
/ 2
57+
)
5458

5559
def dipole_i2(dipole: Dipole):
56-
return dipole.BendingAngle ** 2 / dipole.Length
60+
return dipole.BendingAngle**2 / dipole.Length
5761

5862
def eloss_i2(eloss: EnergyLoss):
5963
return eloss.EnergyLoss / coef
6064

6165
i2 = 0.0
62-
coef = Cgamma / 2.0 / pi * ring.energy ** 4
66+
coef = Cgamma / 2.0 / np.pi * ring.energy**4
6367
for el in ring:
6468
if isinstance(el, Dipole):
6569
i2 += dipole_i2(el)
66-
elif isinstance(el, Wiggler) and el.PassMethod != 'DriftPass':
70+
elif isinstance(el, Wiggler) and el.PassMethod != "DriftPass":
6771
i2 += wiggler_i2(el)
68-
elif isinstance(el, EnergyLoss) and el.PassMethod != 'IdentityPass':
72+
elif isinstance(el, EnergyLoss) and el.PassMethod != "IdentityPass":
6973
i2 += eloss_i2(el)
7074
e_loss = coef * i2
7175
return e_loss
7276

7377
# noinspection PyShadowingNames
7478
@check_radiation(True)
7579
def tracking(ring):
76-
"""Losses from tracking
77-
"""
78-
ringtmp = ring.disable_6d(RFCavity, QuantumDiffusion, Collective,
79-
SimpleQuantDiff, copy=True)
80-
o6 = numpy.squeeze(internal_lpass(ringtmp, numpy.zeros(6),
81-
refpts=len(ringtmp)))
82-
if numpy.isnan(o6[0]):
83-
dp = 0
84-
for e in ringtmp:
85-
ot = numpy.squeeze(internal_lpass([e], numpy.zeros(6)))
86-
dp += -ot[4] * ring.energy
87-
return dp
88-
else:
89-
return -o6[4] * ring.energy
80+
"""Losses from tracking"""
81+
energy = ring.energy
82+
particle = ring.particle
83+
delta = 0.0
84+
for e in ring:
85+
if e.PassMethod.endswith("RadPass"):
86+
ot = e.track(np.zeros(6), energy=energy, particle=particle)
87+
delta += ot[4]
88+
return -delta * energy
9089

9190
if isinstance(method, str):
9291
method = ELossMethod[method.upper()]
93-
warn(FutureWarning('You should use {0!s}'.format(method)))
92+
warn(FutureWarning(f"You should use {method!s}"), stacklevel=2)
9493
if method is ELossMethod.INTEGRAL:
9594
return ring.periodicity * integral(ring)
9695
elif method == ELossMethod.TRACKING:
9796
return ring.periodicity * tracking(ring)
9897
else:
99-
raise AtError('Invalid method: {}'.format(method))
98+
raise AtError(f"Invalid method: {method}")
10099

101100

102101
# noinspection PyPep8Naming
103-
def get_timelag_fromU0(ring: Lattice,
104-
method: Optional[ELossMethod] = ELossMethod.TRACKING,
105-
cavpts: Optional[Refpts] = None,
106-
divider: Optional[int] = 4,
107-
ts_tol: Optional[float] = 1.0e-9) -> Tuple[float, float]:
102+
def get_timelag_fromU0(
103+
ring: Lattice,
104+
*,
105+
method: ELossMethod | None = ELossMethod.TRACKING,
106+
cavpts: Refpts | None = None,
107+
divider: int | None = 4,
108+
ts_tol: float | None = 1.0e-9,
109+
) -> tuple[Sequence[float], float]:
108110
"""
109111
Get the TimeLag attribute of RF cavities based on frequency,
110112
voltage and energy loss per turn, so that the synchronous phase is zero.
111-
An error occurs if all cavities do not have the same frequency.
112113
Used in set_cavity_phase()
113114
114-
115115
Parameters:
116116
ring: Lattice description
117117
method: Method for energy loss computation.
118118
See :py:class:`ELossMethod`.
119119
cavpts: Cavity location. If None, use all cavities.
120120
This allows to ignore harmonic cavities.
121121
divider: number of segments to search for ts
122-
phis_tol: relative tolerance for ts calculation
122+
ts_tol: relative tolerance for ts calculation
123123
Returns:
124-
timelag (float): Timelag
124+
timelag (float): (ncav,) array of *Timelag* values
125125
ts (float): Time difference with the present value
126126
"""
127+
127128
def singlev(values):
128-
vals = numpy.unique(values)
129+
vals = np.unique(values)
129130
if len(vals) > 1:
130-
raise AtError('values not equal for all cavities')
131+
raise AtError("values not equal for all cavities")
131132
return vals[0]
132133

133134
def eq(x, freq, rfv, tl0, u0):
134-
omf = 2*numpy.pi*freq/clight
135+
omf = 2 * np.pi * freq / clight
135136
if u0 > 0.0:
136-
eq1 = (numpy.sum(-rfv*numpy.sin(omf*(x-tl0)))-u0)/u0
137+
eq1 = (np.sum(-rfv * np.sin(omf * (x - tl0))) - u0) / u0
137138
else:
138-
eq1 = numpy.sum(-rfv * numpy.sin(omf * (x - tl0)))
139-
eq2 = numpy.sum(-omf*rfv*numpy.cos(omf*(x-tl0)))
139+
eq1 = np.sum(-rfv * np.sin(omf * (x - tl0)))
140+
eq2 = np.sum(-omf * rfv * np.cos(omf * (x - tl0)))
140141
if eq2 > 0:
141-
return numpy.sqrt(eq1**2+eq2**2)
142+
return np.sqrt(eq1**2 + eq2**2)
142143
else:
143144
return abs(eq1)
144145

145146
if cavpts is None:
146147
cavpts = get_bool_index(ring, RFCavity)
147148
u0 = get_energy_loss(ring, method=method) / ring.periodicity
148-
freq = numpy.array([cav.Frequency for cav in ring.select(cavpts)])
149-
rfv = numpy.array([cav.Voltage for cav in ring.select(cavpts)])
150-
tl0 = numpy.array([cav.TimeLag for cav in ring.select(cavpts)])
149+
freq = np.array([cav.Frequency for cav in ring.select(cavpts)])
150+
rfv = np.array([cav.Voltage for cav in ring.select(cavpts)])
151+
tl0 = np.array([cav.TimeLag for cav in ring.select(cavpts)])
151152
try:
152153
frf = singlev(freq)
153154
tml = singlev(tl0)
154155
except AtError:
155-
ctmax = clight/numpy.amin(freq)/2
156-
tt0 = tl0[numpy.argmin(freq)]
156+
ctmax = clight / np.amin(freq) / 2
157+
tt0 = tl0[np.argmin(freq)]
157158
bounds = (-ctmax, ctmax)
158159
args = (freq, rfv, tl0, u0)
159160
r = []
160161
for i in range(divider):
161-
fact = (i+1)/divider
162-
r.append(least_squares(eq, bounds[0]*fact+tt0,
163-
args=args, bounds=bounds+tt0))
164-
r.append(least_squares(eq, bounds[1]*fact+tt0,
165-
args=args, bounds=bounds+tt0))
166-
res = numpy.array([ri.fun[0] for ri in r])
162+
fact = (i + 1) / divider
163+
r.append(
164+
least_squares(
165+
eq, bounds[0] * fact + tt0, args=args, bounds=bounds + tt0
166+
)
167+
)
168+
r.append(
169+
least_squares(
170+
eq, bounds[1] * fact + tt0, args=args, bounds=bounds + tt0
171+
)
172+
)
173+
res = np.array([ri.fun[0] for ri in r])
167174
ok = res < ts_tol
168-
vals = numpy.array([abs(ri.x[0]).round(decimals=6) for ri in r])
169-
if not numpy.any(ok):
170-
raise AtError('No solution found for Phis, please check '
171-
'RF settings')
172-
if len(numpy.unique(vals[ok])) > 1:
173-
warn(AtWarning('More than one solution found for Phis: use '
174-
'best fit, please check RF settings'))
175-
ts = -r[numpy.argmin(res)].x[0]
176-
timelag = ts+tl0
175+
vals = np.array([abs(ri.x[0]).round(decimals=6) for ri in r])
176+
if not np.any(ok):
177+
raise AtError("No solution found for Phis: check RF settings") from None
178+
if len(np.unique(vals[ok])) > 1:
179+
warn(
180+
AtWarning("More than one solution found for Phis: check RF settings"),
181+
stacklevel=2,
182+
)
183+
ts = -r[np.argmin(res)].x[0]
184+
timelag = ts + tl0
177185
else:
178-
if u0 > numpy.sum(rfv):
179-
raise AtError('Not enough RF voltage: unstable ring')
180-
vrf = numpy.sum(rfv)
181-
timelag = clight/(2*numpy.pi*frf)*numpy.arcsin(u0/vrf)
186+
vrf = np.sum(rfv)
187+
if u0 > vrf:
188+
v1 = ring.periodicity * vrf
189+
v2 = ring.periodicity * u0
190+
raise AtError(
191+
f"The RF voltage ({v1:.3e} eV) is lower than "
192+
f"the radiation losses ({v2:.3e} eV)."
193+
)
194+
timelag = clight / (2 * np.pi * frf) * np.arcsin(u0 / vrf)
182195
ts = timelag - tml
183-
timelag *= numpy.ones(ring.refcount(cavpts))
196+
timelag *= np.ones(ring.refcount(cavpts))
184197
return timelag, ts
185198

186199

187-
def set_cavity_phase(ring: Lattice,
188-
method: ELossMethod = ELossMethod.TRACKING,
189-
refpts: Optional[Refpts] = None,
190-
cavpts: Optional[Refpts] = None,
191-
copy: bool = False) -> None:
200+
def set_cavity_phase(
201+
ring: Lattice,
202+
*,
203+
method: ELossMethod = ELossMethod.TRACKING,
204+
refpts: Refpts | None = None,
205+
cavpts: Refpts | None = None,
206+
copy: bool = False,
207+
) -> None:
192208
"""
193209
Adjust the TimeLag attribute of RF cavities based on frequency,
194210
voltage and energy loss per turn, so that the synchronous phase is zero.
@@ -209,12 +225,12 @@ def set_cavity_phase(ring: Lattice,
209225
"""
210226
# refpts is kept for backward compatibility
211227
if cavpts is None and refpts is not None:
212-
warn(FutureWarning('You should use "cavpts" instead of "refpts"'))
228+
warn(FutureWarning('You should use "cavpts" instead of "refpts"'), stacklevel=2)
213229
cavpts = refpts
214230
elif cavpts is None:
215231
cavpts = get_bool_index(ring, RFCavity)
216232
timelag, _ = get_timelag_fromU0(ring, method=method, cavpts=cavpts)
217-
set_value_refpts(ring, cavpts, 'TimeLag', timelag, copy=copy)
233+
set_value_refpts(ring, cavpts, "TimeLag", timelag, copy=copy)
218234

219235

220236
Lattice.get_energy_loss = get_energy_loss

pyat/at/physics/orbit.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -321,12 +321,11 @@ def _orbit6(ring: Lattice, cavpts=None, guess=None, keep_lattice=False,
321321
harm_number = round(f_rf*l0/ring.beta/clight)
322322

323323
if guess is None:
324-
_, dt = get_timelag_fromU0(ring, method=method, cavpts=cavpts)
325-
# Getting timelag by tracking uses a different lattice,
326-
# so we cannot now use the same one again.
327-
if method is ELossMethod.TRACKING:
328-
keep_lattice = False
329324
ref_in = numpy.zeros((6,), order='F')
325+
try:
326+
_, dt = get_timelag_fromU0(ring, method=method, cavpts=cavpts)
327+
except AtError as exc:
328+
raise AtError("Could not determine the initial synchronous phase") from exc
330329
ref_in[5] = -dt
331330
else:
332331
ref_in = numpy.copy(guess)

0 commit comments

Comments
 (0)