diff --git a/CHANGELOG.md b/CHANGELOG.md index f8d4cc36..ac2083ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ ### 🚀 Features - `ParticleBeam` now supports importing from and exporting to [openPMD-beamphysics](https://github.com/ChristopherMayes/openPMD-beamphysics) HDF5 files and `ParticleGroup` objects. This allows for easy conversion to and from other file formats supported by openPMD-beamphysics. (see #305) (@cr-xu) +- `Cavity` now supports travelling wave cavities in addition to standing wave cavities via the `cavity_type` argument. (see #286) (@zihan-zh, @jank324) ### 🐛 Bug fixes @@ -20,6 +21,8 @@ ### 🌟 First Time Contributors +- Zihan Zhu (@zihan-zh) + ## [v0.7.0](https://github.com/desy-ml/cheetah/releases/tag/v0.7.0) (2024-12-13) We are proud to announce this new major release of Cheetah! This is probably the biggest release since the original Cheetah release, with many with significant upgrades under the hood. Cheetah is now fully vectorised and compatible with PyTorch broadcasting rules, while additional physics and higher fidelity options for existing physics have also been introduced. Despite extensive testing, you might still encounter a few bugs. Please report them by opening an issue, so we can fix them as soon as possible and improve the experience for everyone. diff --git a/README.md b/README.md index 7db8d08a..c72f35ad 100644 --- a/README.md +++ b/README.md @@ -153,6 +153,7 @@ The following people have contributed to the development of Cheetah: - Juan Pablo Gonzalez-Aguilera (@jp-ga) - Ryan Roussel (@roussel-ryan) - Auralee Edelen (@lee-edelen) +- Zihan Zhu (@zihan-zh) ### Institutions diff --git a/cheetah/accelerator/cavity.py b/cheetah/accelerator/cavity.py index 029ed19a..2b902b3b 100644 --- a/cheetah/accelerator/cavity.py +++ b/cheetah/accelerator/cavity.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Literal, Optional import matplotlib.pyplot as plt import torch @@ -29,6 +29,7 @@ class Cavity(Element): :param phase: Phase of the cavity in degrees. :param frequency: Frequency of the cavity in Hz. :param name: Unique identifier of the element. + :param cavity_type: Type of the cavity. """ def __init__( @@ -37,6 +38,7 @@ def __init__( voltage: Optional[torch.Tensor] = None, phase: Optional[torch.Tensor] = None, frequency: Optional[torch.Tensor] = None, + cavity_type: Literal["standing_wave", "traveling_wave"] = "standing_wave", name: Optional[str] = None, device=None, dtype=None, @@ -59,6 +61,8 @@ def __init__( if frequency is not None: self.frequency = torch.as_tensor(frequency, **factory_kwargs) + self.cavity_type = cavity_type + @property def is_active(self) -> bool: return torch.any(self.voltage != 0) @@ -252,31 +256,61 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor: alpha = torch.sqrt(eta / 8) / torch.cos(phi) * torch.log(Ef / Ei) - r11 = torch.cos(alpha) - torch.sqrt(2 / eta) * torch.cos(phi) * torch.sin(alpha) + if self.cavity_type == "standing_wave": + r11 = torch.cos(alpha) - torch.sqrt(2 / eta) * torch.cos(phi) * torch.sin( + alpha + ) - # In Ocelot r12 is defined as below only if abs(Ep) > 10, and self.length - # otherwise. This is implemented differently here in order to achieve results - # closer to Bmad. - r12 = torch.sqrt(8 / eta) * Ei / Ep * torch.cos(phi) * torch.sin(alpha) + # In Ocelot r12 is defined as below only if abs(Ep) > 10, and self.length + # otherwise. This is implemented differently here to achieve results + # closer to Bmad. + r12 = torch.sqrt(8 / eta) * Ei / Ep * torch.cos(phi) * torch.sin(alpha) + + r21 = ( + -Ep + / Ef + * ( + torch.cos(phi) / torch.sqrt(2 * eta) + + torch.sqrt(eta / 8) / torch.cos(phi) + ) + * torch.sin(alpha) + ) - r21 = ( - -Ep - / Ef - * ( - torch.cos(phi) / torch.sqrt(2 * eta) - + torch.sqrt(eta / 8) / torch.cos(phi) + r22 = ( + Ei + / Ef + * ( + torch.cos(alpha) + + torch.sqrt(2 / eta) * torch.cos(phi) * torch.sin(alpha) + ) ) - * torch.sin(alpha) - ) + elif self.cavity_type == "traveling_wave": + # Reference paper: Rosenzweig and Serafini, PhysRevE, Vol.49, p.1599,(1994) + dE = Ef - Ei + f = (Ei / dE) * torch.log(1 + (dE / Ei)) - r22 = ( - Ei - / Ef - * ( - torch.cos(alpha) - + torch.sqrt(2 / eta) * torch.cos(phi) * torch.sin(alpha) + vector_shape = torch.broadcast_shapes( + self.length.shape, f.shape, Ei.shape, Ef.shape, dE.shape ) - ) + + M_body = torch.eye(2, **factory_kwargs).repeat((*vector_shape, 1, 1)) + M_body[..., 0, 1] = self.length * f + M_body[..., 1, 1] = Ei / Ef + + M_f_entry = torch.eye(2, **factory_kwargs).repeat((*vector_shape, 1, 1)) + M_f_entry[..., 1, 0] = -dE / (2 * self.length * Ei) + + M_f_exit = torch.eye(2, **factory_kwargs).repeat((*vector_shape, 1, 1)) + M_f_exit[..., 1, 0] = dE / (2 * self.length * Ef) + + M_combined = M_f_exit @ M_body @ M_f_entry + + r11 = M_combined[..., 0, 0] + r12 = M_combined[..., 0, 1] + r21 = M_combined[..., 1, 0] + r22 = M_combined[..., 1, 1] + else: + raise ValueError(f"Invalid cavity type: {self.cavity_type}") r56 = torch.tensor(0.0, **factory_kwargs) beta0 = torch.tensor(1.0, **factory_kwargs) @@ -345,7 +379,13 @@ def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> No @property def defining_features(self) -> list[str]: - return super().defining_features + ["length", "voltage", "phase", "frequency"] + return super().defining_features + [ + "length", + "voltage", + "phase", + "frequency", + "cavity_type", + ] def __repr__(self) -> str: return ( @@ -353,5 +393,6 @@ def __repr__(self) -> str: + f"voltage={repr(self.voltage)}, " + f"phase={repr(self.phase)}, " + f"frequency={repr(self.frequency)}, " + + f"cavity_type={repr(self.cavity_type)}, " + f"name={repr(self.name)})" ) diff --git a/cheetah/converters/bmad.py b/cheetah/converters/bmad.py index a6c34ef6..fa17574a 100644 --- a/cheetah/converters/bmad.py +++ b/cheetah/converters/bmad.py @@ -206,6 +206,7 @@ def convert_element( -np.degrees(bmad_parsed.get("phi0", 0.0) * 2 * np.pi) ), frequency=torch.tensor(bmad_parsed["rf_frequency"]), + cavity_type=bmad_parsed["cavity_type"], name=name, device=device, dtype=dtype, diff --git a/cheetah/converters/ocelot.py b/cheetah/converters/ocelot.py index d1577563..c6afd0b5 100644 --- a/cheetah/converters/ocelot.py +++ b/cheetah/converters/ocelot.py @@ -116,6 +116,18 @@ def convert_element_to_cheetah( voltage=torch.tensor(element.v, dtype=torch.float32) * 1e9, frequency=torch.tensor(element.freq, dtype=torch.float32), phase=torch.tensor(element.phi, dtype=torch.float32), + cavity_type="standing_wave", + name=element.id, + device=device, + dtype=dtype, + ) + elif isinstance(element, ocelot.TWCavity): + return cheetah.Cavity( + length=torch.tensor(element.l, dtype=torch.float32), + voltage=torch.tensor(element.v, dtype=torch.float32) * 1e9, + frequency=torch.tensor(element.freq, dtype=torch.float32), + phase=torch.tensor(element.phi, dtype=torch.float32), + cavity_type="traveling_wave", name=element.id, device=device, dtype=dtype, diff --git a/cheetah/particles/particle_beam.py b/cheetah/particles/particle_beam.py index f840514c..faecad3c 100644 --- a/cheetah/particles/particle_beam.py +++ b/cheetah/particles/particle_beam.py @@ -646,9 +646,9 @@ def from_ocelot(cls, parray, device=None, dtype=torch.float32) -> "ParticleBeam" particle_charges = torch.tensor(parray.q_array) return cls( - particles=particles.unsqueeze(0), - energy=torch.tensor(1e9 * parray.E).unsqueeze(0), - particle_charges=particle_charges.unsqueeze(0), + particles=particles, + energy=torch.tensor(1e9 * parray.E), + particle_charges=particle_charges, device=device, dtype=dtype, ) diff --git a/docs/index.rst b/docs/index.rst index 0558bc78..b54f3f9f 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -133,6 +133,7 @@ The following people have contributed to the development of Cheetah: - Juan Pablo Gonzalez-Aguilera (@jp-ga) - Ryan Roussel (@roussel-ryan) - Auralee Edelen (@lee-edelen) +- Zihan Zhu (@zihan-zh) Institutions ~~~~~~~~~~~~ diff --git a/test_requirements.txt b/test_requirements.txt index 5db7984a..54a87726 100644 --- a/test_requirements.txt +++ b/test_requirements.txt @@ -1,3 +1,3 @@ -git+https://github.com/cr-xu/ocelot@update-scipy-compatibility # Ocelot +git+https://github.com/ocelot-collab/ocelot.git # Ocelot pytest pytest-cov diff --git a/tests/test_compare_ocelot.py b/tests/test_compare_ocelot.py index 639ea1bf..11712ead 100644 --- a/tests/test_compare_ocelot.py +++ b/tests/test_compare_ocelot.py @@ -2,6 +2,7 @@ import numpy as np import ocelot +import pytest import torch import cheetah @@ -637,7 +638,8 @@ def test_asymmetric_bend(): ) -def test_cavity(): +@pytest.mark.parametrize("cavity_type", ["standing_wave", "traveling_wave"]) +def test_cavity(cavity_type): """ Compare tracking through a cavity that is on. @@ -680,7 +682,11 @@ def test_cavity(): p_array = ocelot.generate_parray(tws=tws, charge=5e-9) - cell = [ocelot.Cavity(l=1.0377, v=0.01815975, freq=1.3e9, phi=0.0)] + cell = ( + [ocelot.Cavity(l=1.0377, v=0.01815975, freq=1.3e9, phi=0.0)] + if cavity_type == "standing_wave" + else [ocelot.TWCavity(l=1.0377, v=0.01815975, freq=1.3e9, phi=0.0)] + ) lattice = ocelot.MagneticLattice(cell) navigator = ocelot.Navigator(lattice=lattice) @@ -696,6 +702,7 @@ def test_cavity(): voltage=torch.tensor(0.01815975e9), frequency=torch.tensor(1.3e9), phase=torch.tensor(0.0), + cavity_type=cavity_type, dtype=torch.float64, ) outgoing_beam = cheetah_cavity.track(incoming_beam) @@ -709,16 +716,17 @@ def test_cavity(): outgoing_beam.total_charge.cpu().numpy(), np.sum(outgoing_parray.q_array) ) assert np.allclose( - outgoing_beam.particles[:, :, 5].cpu().numpy(), + outgoing_beam.particles[:, 5].cpu().numpy(), outgoing_parray.rparticles.transpose()[:, 5], ) assert np.allclose( - outgoing_beam.particles[:, :, 4].cpu().numpy(), + outgoing_beam.particles[:, 4].cpu().numpy(), outgoing_parray.rparticles.transpose()[:, 4], ) -def test_cavity_non_zero_phase(): +@pytest.mark.parametrize("cavity_type", ["standing_wave", "traveling_wave"]) +def test_cavity_non_zero_phase(cavity_type): """Compare tracking through a cavity with a phase offset.""" # Ocelot tws = ocelot.Twiss() @@ -734,7 +742,11 @@ def test_cavity_non_zero_phase(): p_array = ocelot.generate_parray(tws=tws, charge=5e-9) - cell = [ocelot.Cavity(l=1.0377, v=0.01815975, freq=1.3e9, phi=30.0)] + cell = ( + [ocelot.Cavity(l=1.0377, v=0.01815975, freq=1.3e9, phi=30.0)] + if cavity_type == "standing_wave" + else [ocelot.TWCavity(l=1.0377, v=0.01815975, freq=1.3e9, phi=30.0)] + ) lattice = ocelot.MagneticLattice(cell) navigator = ocelot.Navigator(lattice=lattice) @@ -750,6 +762,7 @@ def test_cavity_non_zero_phase(): voltage=torch.tensor(0.01815975e9), frequency=torch.tensor(1.3e9), phase=torch.tensor(30.0), + cavity_type=cavity_type, dtype=torch.float64, ) outgoing_beam = cheetah_cavity.track(incoming_beam) @@ -763,10 +776,10 @@ def test_cavity_non_zero_phase(): outgoing_beam.total_charge.cpu().numpy(), np.sum(outgoing_parray.q_array) ) assert np.allclose( - outgoing_beam.particles[:, :, 5].cpu().numpy(), + outgoing_beam.particles[:, 5].cpu().numpy(), outgoing_parray.rparticles.transpose()[:, 5], ) assert np.allclose( - outgoing_beam.particles[:, :, 4].cpu().numpy(), + outgoing_beam.particles[:, 4].cpu().numpy(), outgoing_parray.rparticles.transpose()[:, 4], ) diff --git a/tests/test_ocelot_import.py b/tests/test_ocelot_import.py index cb2ee660..fe98510a 100644 --- a/tests/test_ocelot_import.py +++ b/tests/test_ocelot_import.py @@ -57,12 +57,12 @@ def test_ocelot_to_particlebeam(): parray = ocelot.astraBeam2particleArray("tests/resources/ACHIP_EA1_2021.1351.001") beam = cheetah.ParticleBeam.from_ocelot(parray) - assert np.allclose(beam.particles[0, :, 0].cpu().numpy(), parray.x()) - assert np.allclose(beam.particles[0, :, 1].cpu().numpy(), parray.px()) - assert np.allclose(beam.particles[0, :, 2].cpu().numpy(), parray.y()) - assert np.allclose(beam.particles[0, :, 3].cpu().numpy(), parray.py()) - assert np.allclose(beam.particles[0, :, 4].cpu().numpy(), parray.tau()) - assert np.allclose(beam.particles[0, :, 5].cpu().numpy(), parray.p()) + assert np.allclose(beam.particles[:, 0].cpu().numpy(), parray.x()) + assert np.allclose(beam.particles[:, 1].cpu().numpy(), parray.px()) + assert np.allclose(beam.particles[:, 2].cpu().numpy(), parray.y()) + assert np.allclose(beam.particles[:, 3].cpu().numpy(), parray.py()) + assert np.allclose(beam.particles[:, 4].cpu().numpy(), parray.tau()) + assert np.allclose(beam.particles[:, 5].cpu().numpy(), parray.p()) assert np.allclose(beam.energy.cpu().numpy(), parray.E * 1e9) assert np.allclose(beam.particle_charges.cpu().numpy(), parray.q_array) diff --git a/tests/test_vectorized.py b/tests/test_vectorized.py index 21038ad4..5ebc6aba 100644 --- a/tests/test_vectorized.py +++ b/tests/test_vectorized.py @@ -211,16 +211,18 @@ def test_enormous_through_ares_ea(): @pytest.mark.parametrize("BeamClass", [cheetah.ParticleBeam, cheetah.ParameterBeam]) -def test_cavity_with_zero_and_non_zero_voltage(BeamClass): +@pytest.mark.parametrize("cavity_type", ["standing_wave", "traveling_wave"]) +def test_cavity_with_zero_and_non_zero_voltage(BeamClass, cavity_type): """ Tests that if zero and non-zero voltages are passed to a cavity in a single batch, there are no errors. This test does NOT check physical correctness. """ cavity = cheetah.Cavity( length=torch.tensor(3.0441), - voltage=torch.tensor([0.0, 48_198_468.0, 0.0]), + voltage=torch.tensor([0.0, 48198468.0, 0.0]), phase=torch.tensor(48198468.0), frequency=torch.tensor(2.8560e09), + cavity_type=cavity_type, name="my_test_cavity", ) incoming = BeamClass.from_parameters(sigma_x=torch.tensor(1e-5))