Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add traveling-wave cavity model #286

Open
wants to merge 27 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
446a07c
Add cavity_type param in Cavity element and update TW cavity model
zihan-zh Oct 25, 2024
71082b9
Add cavity_type param in Cavity element and update TW cavity model
zihan-zh Oct 25, 2024
eceb2d3
fixed all issues found by ./cheetah/accelerator/cavity.py:322:89: E5…
zihan-zh Oct 25, 2024
4796e19
fixed all issues found by flake8
zihan-zh Oct 25, 2024
a0c044d
fixed all issues found by flake8
zihan-zh Oct 25, 2024
d39fb67
fixed all issues found by flake8
zihan-zh Oct 25, 2024
43b4839
Merge branch 'master' into tw-cavity-fix
jank324 Nov 20, 2024
7f8f669
Merge branch 'master' into tw-cavity-fix
jank324 Dec 12, 2024
a6e23fa
Merge branch 'master' into tw-cavity-fix
jank324 Jan 10, 2025
0afae61
Run formatting
jank324 Jan 10, 2025
7d17531
Fix non-differentiable travelling wave implementation
jank324 Jan 10, 2025
da43532
Fix type annotation for cavity type
jank324 Jan 10, 2025
98ee181
Fix broken cavity type check
jank324 Jan 10, 2025
6a2d66e
Correctly import cavity type from Ocelot
jank324 Jan 10, 2025
952c13a
Fix an issue where Ocelot beam import still added a dummy vector dime…
jank324 Jan 10, 2025
1f06661
Move all of the travelling wave code into the respective if branch
jank324 Jan 10, 2025
8dc641d
Fix test failure as a result of Ocelot import fix
jank324 Jan 10, 2025
5915008
Move valid cavity type check to `else` so it is guaranteed to hit
jank324 Jan 14, 2025
e5217f8
FIx issue that was Cheetah cavity to produce wrong result with `stand…
jank324 Jan 14, 2025
86a3f2d
Fix remaining issue from wrong cavity result fix
jank324 Jan 14, 2025
4093ce3
Add test to detect vectorisation issue in `traveling_wave` mode of `C…
jank324 Jan 14, 2025
79bebc9
Correctly vectorise `traveling_wave` mode of `Cavity`
jank324 Jan 14, 2025
917c591
Update changelog
jank324 Jan 14, 2025
3c7516b
Add Zihan to contributor lists
jank324 Jan 14, 2025
258c979
Add travelling wave cavity to Ocelot converter
jank324 Jan 15, 2025
46ae49b
Add travelling wave cavity test
jank324 Jan 15, 2025
5d7681b
Change Ocelot dependency back to Ocelot `master`
jank324 Jan 15, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
### 🚀 Features

- `ParticleBeam` now supports importing from and exporting to [openPMD-beamphysics](https://github.com/ChristopherMayes/openPMD-beamphysics) HDF5 files and `ParticleGroup` objects. This allows for easy conversion to and from other file formats supported by openPMD-beamphysics. (see #305) (@cr-xu)
- `Cavity` now supports travelling wave cavities in addition to standing wave cavities via the `cavity_type` argument. (see #286) (@zihan-zh, @jank324)

### 🐛 Bug fixes

Expand All @@ -20,6 +21,8 @@

### 🌟 First Time Contributors

- Zihan Zhu (@zihan-zh)

## [v0.7.0](https://github.com/desy-ml/cheetah/releases/tag/v0.7.0) (2024-12-13)

We are proud to announce this new major release of Cheetah! This is probably the biggest release since the original Cheetah release, with many with significant upgrades under the hood. Cheetah is now fully vectorised and compatible with PyTorch broadcasting rules, while additional physics and higher fidelity options for existing physics have also been introduced. Despite extensive testing, you might still encounter a few bugs. Please report them by opening an issue, so we can fix them as soon as possible and improve the experience for everyone.
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ The following people have contributed to the development of Cheetah:
- Juan Pablo Gonzalez-Aguilera (@jp-ga)
- Ryan Roussel (@roussel-ryan)
- Auralee Edelen (@lee-edelen)
- Zihan Zhu (@zihan-zh)

### Institutions

Expand Down
85 changes: 63 additions & 22 deletions cheetah/accelerator/cavity.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Literal, Optional

import matplotlib.pyplot as plt
import torch
Expand Down Expand Up @@ -29,6 +29,7 @@ class Cavity(Element):
:param phase: Phase of the cavity in degrees.
:param frequency: Frequency of the cavity in Hz.
:param name: Unique identifier of the element.
:param cavity_type: Type of the cavity.
"""

def __init__(
Expand All @@ -37,6 +38,7 @@ def __init__(
voltage: Optional[torch.Tensor] = None,
phase: Optional[torch.Tensor] = None,
frequency: Optional[torch.Tensor] = None,
cavity_type: Literal["standing_wave", "traveling_wave"] = "standing_wave",
name: Optional[str] = None,
device=None,
dtype=None,
Expand All @@ -59,6 +61,8 @@ def __init__(
if frequency is not None:
self.frequency = torch.as_tensor(frequency, **factory_kwargs)

self.cavity_type = cavity_type

@property
def is_active(self) -> bool:
return torch.any(self.voltage != 0)
Expand Down Expand Up @@ -252,31 +256,61 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor:

alpha = torch.sqrt(eta / 8) / torch.cos(phi) * torch.log(Ef / Ei)

r11 = torch.cos(alpha) - torch.sqrt(2 / eta) * torch.cos(phi) * torch.sin(alpha)
if self.cavity_type == "standing_wave":
r11 = torch.cos(alpha) - torch.sqrt(2 / eta) * torch.cos(phi) * torch.sin(
alpha
)

# In Ocelot r12 is defined as below only if abs(Ep) > 10, and self.length
# otherwise. This is implemented differently here in order to achieve results
# closer to Bmad.
r12 = torch.sqrt(8 / eta) * Ei / Ep * torch.cos(phi) * torch.sin(alpha)
# In Ocelot r12 is defined as below only if abs(Ep) > 10, and self.length
# otherwise. This is implemented differently here to achieve results
# closer to Bmad.
r12 = torch.sqrt(8 / eta) * Ei / Ep * torch.cos(phi) * torch.sin(alpha)

r21 = (
-Ep
/ Ef
* (
torch.cos(phi) / torch.sqrt(2 * eta)
+ torch.sqrt(eta / 8) / torch.cos(phi)
)
* torch.sin(alpha)
)

r21 = (
-Ep
/ Ef
* (
torch.cos(phi) / torch.sqrt(2 * eta)
+ torch.sqrt(eta / 8) / torch.cos(phi)
r22 = (
Ei
/ Ef
* (
torch.cos(alpha)
+ torch.sqrt(2 / eta) * torch.cos(phi) * torch.sin(alpha)
)
)
* torch.sin(alpha)
)
elif self.cavity_type == "traveling_wave":
# Reference paper: Rosenzweig and Serafini, PhysRevE, Vol.49, p.1599,(1994)
dE = Ef - Ei
f = (Ei / dE) * torch.log(1 + (dE / Ei))

r22 = (
Ei
/ Ef
* (
torch.cos(alpha)
+ torch.sqrt(2 / eta) * torch.cos(phi) * torch.sin(alpha)
vector_shape = torch.broadcast_shapes(
self.length.shape, f.shape, Ei.shape, Ef.shape, dE.shape
)
)

M_body = torch.eye(2, **factory_kwargs).repeat((*vector_shape, 1, 1))
M_body[..., 0, 1] = self.length * f
M_body[..., 1, 1] = Ei / Ef

M_f_entry = torch.eye(2, **factory_kwargs).repeat((*vector_shape, 1, 1))
M_f_entry[..., 1, 0] = -dE / (2 * self.length * Ei)

M_f_exit = torch.eye(2, **factory_kwargs).repeat((*vector_shape, 1, 1))
M_f_exit[..., 1, 0] = dE / (2 * self.length * Ef)

M_combined = M_f_exit @ M_body @ M_f_entry

r11 = M_combined[..., 0, 0]
r12 = M_combined[..., 0, 1]
r21 = M_combined[..., 1, 0]
r22 = M_combined[..., 1, 1]
else:
raise ValueError(f"Invalid cavity type: {self.cavity_type}")

r56 = torch.tensor(0.0, **factory_kwargs)
beta0 = torch.tensor(1.0, **factory_kwargs)
Expand Down Expand Up @@ -345,13 +379,20 @@ def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> No

@property
def defining_features(self) -> list[str]:
return super().defining_features + ["length", "voltage", "phase", "frequency"]
return super().defining_features + [
"length",
"voltage",
"phase",
"frequency",
"cavity_type",
]

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}(length={repr(self.length)}, "
+ f"voltage={repr(self.voltage)}, "
+ f"phase={repr(self.phase)}, "
+ f"frequency={repr(self.frequency)}, "
+ f"cavity_type={repr(self.cavity_type)}, "
+ f"name={repr(self.name)})"
)
1 change: 1 addition & 0 deletions cheetah/converters/bmad.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def convert_element(
-np.degrees(bmad_parsed.get("phi0", 0.0) * 2 * np.pi)
),
frequency=torch.tensor(bmad_parsed["rf_frequency"]),
cavity_type=bmad_parsed["cavity_type"],
name=name,
device=device,
dtype=dtype,
Expand Down
12 changes: 12 additions & 0 deletions cheetah/converters/ocelot.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,18 @@ def convert_element_to_cheetah(
voltage=torch.tensor(element.v, dtype=torch.float32) * 1e9,
frequency=torch.tensor(element.freq, dtype=torch.float32),
phase=torch.tensor(element.phi, dtype=torch.float32),
cavity_type="standing_wave",
name=element.id,
device=device,
dtype=dtype,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think here we should also include the conversion for ocelot.TWCacity as it's implemented now

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cr-xu like this?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, exactly.

I think the proper test for TWCavity is still missing though. It would be nice to have a benchmark against both OCELOT and Bmad tracking results.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. There is a test for the standing wave already that I think can just be parameterised to cover travelling wave as well.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just added the test, comparing the results to Ocelot. Weirdly, it seems the TWCavity implementation in Ocelot is broken. 🤔

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well... indeed the OCELOT TWCavity seems to be broken, both for the stable 22.12.0 version and the unreleased one in the master branch now 24.12.0. I already opened an Issue there: ocelot-collab/ocelot#268

Let's see if that can be resolved soon. Otherwise I would suggest to have a static bmad test case and compare the result, so that we can have the PR merged soon-ish.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On a side note, the Ocelot master branch did merge my PR, so now we can also point to its official master instead of my fork for the CI/CD.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I asked @zihan-zh to provide results from a Bmad tracking, so we can use those for a static comparison.

)
elif isinstance(element, ocelot.TWCavity):
return cheetah.Cavity(
length=torch.tensor(element.l, dtype=torch.float32),
voltage=torch.tensor(element.v, dtype=torch.float32) * 1e9,
frequency=torch.tensor(element.freq, dtype=torch.float32),
phase=torch.tensor(element.phi, dtype=torch.float32),
cavity_type="traveling_wave",
name=element.id,
device=device,
dtype=dtype,
Expand Down
6 changes: 3 additions & 3 deletions cheetah/particles/particle_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,9 +646,9 @@ def from_ocelot(cls, parray, device=None, dtype=torch.float32) -> "ParticleBeam"
particle_charges = torch.tensor(parray.q_array)

return cls(
particles=particles.unsqueeze(0),
energy=torch.tensor(1e9 * parray.E).unsqueeze(0),
particle_charges=particle_charges.unsqueeze(0),
particles=particles,
energy=torch.tensor(1e9 * parray.E),
particle_charges=particle_charges,
device=device,
dtype=dtype,
)
Expand Down
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ The following people have contributed to the development of Cheetah:
- Juan Pablo Gonzalez-Aguilera (@jp-ga)
- Ryan Roussel (@roussel-ryan)
- Auralee Edelen (@lee-edelen)
- Zihan Zhu (@zihan-zh)

Institutions
~~~~~~~~~~~~
Expand Down
2 changes: 1 addition & 1 deletion test_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
git+https://github.com/cr-xu/ocelot@update-scipy-compatibility # Ocelot
git+https://github.com/ocelot-collab/ocelot.git # Ocelot
pytest
pytest-cov
29 changes: 21 additions & 8 deletions tests/test_compare_ocelot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import ocelot
import pytest
import torch

import cheetah
Expand Down Expand Up @@ -637,7 +638,8 @@ def test_asymmetric_bend():
)


def test_cavity():
@pytest.mark.parametrize("cavity_type", ["standing_wave", "traveling_wave"])
def test_cavity(cavity_type):
"""
Compare tracking through a cavity that is on.

Expand Down Expand Up @@ -680,7 +682,11 @@ def test_cavity():

p_array = ocelot.generate_parray(tws=tws, charge=5e-9)

cell = [ocelot.Cavity(l=1.0377, v=0.01815975, freq=1.3e9, phi=0.0)]
cell = (
[ocelot.Cavity(l=1.0377, v=0.01815975, freq=1.3e9, phi=0.0)]
if cavity_type == "standing_wave"
else [ocelot.TWCavity(l=1.0377, v=0.01815975, freq=1.3e9, phi=0.0)]
)
lattice = ocelot.MagneticLattice(cell)
navigator = ocelot.Navigator(lattice=lattice)

Expand All @@ -696,6 +702,7 @@ def test_cavity():
voltage=torch.tensor(0.01815975e9),
frequency=torch.tensor(1.3e9),
phase=torch.tensor(0.0),
cavity_type=cavity_type,
dtype=torch.float64,
)
outgoing_beam = cheetah_cavity.track(incoming_beam)
Expand All @@ -709,16 +716,17 @@ def test_cavity():
outgoing_beam.total_charge.cpu().numpy(), np.sum(outgoing_parray.q_array)
)
assert np.allclose(
outgoing_beam.particles[:, :, 5].cpu().numpy(),
outgoing_beam.particles[:, 5].cpu().numpy(),
outgoing_parray.rparticles.transpose()[:, 5],
)
assert np.allclose(
outgoing_beam.particles[:, :, 4].cpu().numpy(),
outgoing_beam.particles[:, 4].cpu().numpy(),
outgoing_parray.rparticles.transpose()[:, 4],
)


def test_cavity_non_zero_phase():
@pytest.mark.parametrize("cavity_type", ["standing_wave", "traveling_wave"])
def test_cavity_non_zero_phase(cavity_type):
"""Compare tracking through a cavity with a phase offset."""
# Ocelot
tws = ocelot.Twiss()
Expand All @@ -734,7 +742,11 @@ def test_cavity_non_zero_phase():

p_array = ocelot.generate_parray(tws=tws, charge=5e-9)

cell = [ocelot.Cavity(l=1.0377, v=0.01815975, freq=1.3e9, phi=30.0)]
cell = (
[ocelot.Cavity(l=1.0377, v=0.01815975, freq=1.3e9, phi=30.0)]
if cavity_type == "standing_wave"
else [ocelot.TWCavity(l=1.0377, v=0.01815975, freq=1.3e9, phi=30.0)]
)
lattice = ocelot.MagneticLattice(cell)
navigator = ocelot.Navigator(lattice=lattice)

Expand All @@ -750,6 +762,7 @@ def test_cavity_non_zero_phase():
voltage=torch.tensor(0.01815975e9),
frequency=torch.tensor(1.3e9),
phase=torch.tensor(30.0),
cavity_type=cavity_type,
dtype=torch.float64,
)
outgoing_beam = cheetah_cavity.track(incoming_beam)
Expand All @@ -763,10 +776,10 @@ def test_cavity_non_zero_phase():
outgoing_beam.total_charge.cpu().numpy(), np.sum(outgoing_parray.q_array)
)
assert np.allclose(
outgoing_beam.particles[:, :, 5].cpu().numpy(),
outgoing_beam.particles[:, 5].cpu().numpy(),
outgoing_parray.rparticles.transpose()[:, 5],
)
assert np.allclose(
outgoing_beam.particles[:, :, 4].cpu().numpy(),
outgoing_beam.particles[:, 4].cpu().numpy(),
outgoing_parray.rparticles.transpose()[:, 4],
)
12 changes: 6 additions & 6 deletions tests/test_ocelot_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,12 @@ def test_ocelot_to_particlebeam():
parray = ocelot.astraBeam2particleArray("tests/resources/ACHIP_EA1_2021.1351.001")
beam = cheetah.ParticleBeam.from_ocelot(parray)

assert np.allclose(beam.particles[0, :, 0].cpu().numpy(), parray.x())
assert np.allclose(beam.particles[0, :, 1].cpu().numpy(), parray.px())
assert np.allclose(beam.particles[0, :, 2].cpu().numpy(), parray.y())
assert np.allclose(beam.particles[0, :, 3].cpu().numpy(), parray.py())
assert np.allclose(beam.particles[0, :, 4].cpu().numpy(), parray.tau())
assert np.allclose(beam.particles[0, :, 5].cpu().numpy(), parray.p())
assert np.allclose(beam.particles[:, 0].cpu().numpy(), parray.x())
assert np.allclose(beam.particles[:, 1].cpu().numpy(), parray.px())
assert np.allclose(beam.particles[:, 2].cpu().numpy(), parray.y())
assert np.allclose(beam.particles[:, 3].cpu().numpy(), parray.py())
assert np.allclose(beam.particles[:, 4].cpu().numpy(), parray.tau())
assert np.allclose(beam.particles[:, 5].cpu().numpy(), parray.p())
assert np.allclose(beam.energy.cpu().numpy(), parray.E * 1e9)
assert np.allclose(beam.particle_charges.cpu().numpy(), parray.q_array)

Expand Down
6 changes: 4 additions & 2 deletions tests/test_vectorized.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,16 +211,18 @@ def test_enormous_through_ares_ea():


@pytest.mark.parametrize("BeamClass", [cheetah.ParticleBeam, cheetah.ParameterBeam])
def test_cavity_with_zero_and_non_zero_voltage(BeamClass):
@pytest.mark.parametrize("cavity_type", ["standing_wave", "traveling_wave"])
def test_cavity_with_zero_and_non_zero_voltage(BeamClass, cavity_type):
"""
Tests that if zero and non-zero voltages are passed to a cavity in a single batch,
there are no errors. This test does NOT check physical correctness.
"""
cavity = cheetah.Cavity(
length=torch.tensor(3.0441),
voltage=torch.tensor([0.0, 48_198_468.0, 0.0]),
voltage=torch.tensor([0.0, 48198468.0, 0.0]),
phase=torch.tensor(48198468.0),
frequency=torch.tensor(2.8560e09),
cavity_type=cavity_type,
name="my_test_cavity",
)
incoming = BeamClass.from_parameters(sigma_x=torch.tensor(1e-5))
Expand Down
Loading