Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement BCM and BAM lattice elements #342

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cheetah/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from . import converters # noqa: F401
from .accelerator import ( # noqa: F401
BAM,
BCM,
BPM,
Aperture,
Cavity,
Expand Down
2 changes: 2 additions & 0 deletions cheetah/accelerator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from .aperture import Aperture # noqa: F401
from .bam import BAM # noqa: F401
from .bcm import BCM # noqa: F401
from .bpm import BPM # noqa: F401
from .cavity import Cavity # noqa: F401
from .custom_transfer_map import CustomTransferMap # noqa: F401
Expand Down
65 changes: 65 additions & 0 deletions cheetah/accelerator/bam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from typing import Optional

import matplotlib.pyplot as plt
import torch
from matplotlib.patches import Rectangle

from cheetah.accelerator.element import Element
from cheetah.particles import Beam, ParameterBeam, ParticleBeam
from cheetah.utils import UniqueNameGenerator

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")


class BAM(Element):
"""
Beam Position Monitor (BPM) in a particle accelerator.

:param is_active: If `True` the BPM is active and will record the beam's position.
If `False` the BPM is inactive and will not record the beam's position.
:param name: Unique identifier of the element.
"""

def __init__(self, is_active: bool = False, name: Optional[str] = None) -> None:
super().__init__(name=name)

self.is_active = is_active
self.reading = None

@property
def is_skippable(self) -> bool:
return not self.is_active

def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
return torch.eye(7, device=energy.device, dtype=energy.dtype).repeat(
(*energy.shape, 1, 1)
)

def track(self, incoming: Beam) -> Beam:
if isinstance(incoming, ParameterBeam):
self.reading = torch.stack([incoming.mu_x, incoming.mu_y])
elif isinstance(incoming, ParticleBeam):
self.reading = torch.stack([incoming.mu_x, incoming.mu_y])
else:
raise TypeError(f"Parameter incoming is of invalid type {type(incoming)}")

return incoming.clone()

def split(self, resolution: torch.Tensor) -> list[Element]:
return [self]

def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> None:
plot_s = s[vector_idx] if s.dim() > 0 else s

alpha = 1 if self.is_active else 0.2
patch = Rectangle(
(plot_s, -0.3), 0, 0.3 * 2, color="darkkhaki", alpha=alpha, zorder=2
)
ax.add_patch(patch)

@property
def defining_features(self) -> list[str]:
return super().defining_features + ["is_active"]

def __repr__(self) -> str:
return f"{self.__class__.__name__}(name={repr(self.name)})"
65 changes: 65 additions & 0 deletions cheetah/accelerator/bcm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from typing import Optional

import matplotlib.pyplot as plt
import torch
from matplotlib.patches import Rectangle

from cheetah.accelerator.element import Element
from cheetah.particles import Beam, ParameterBeam, ParticleBeam
from cheetah.utils import UniqueNameGenerator

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")


class BCM(Element):
"""
Beam Position Monitor (BPM) in a particle accelerator.

:param is_active: If `True` the BPM is active and will record the beam's position.
If `False` the BPM is inactive and will not record the beam's position.
:param name: Unique identifier of the element.
"""

def __init__(self, is_active: bool = False, name: Optional[str] = None) -> None:
super().__init__(name=name)

self.is_active = is_active
self.reading = None

@property
def is_skippable(self) -> bool:
return not self.is_active

def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
return torch.eye(7, device=energy.device, dtype=energy.dtype).repeat(
(*energy.shape, 1, 1)
)

def track(self, incoming: Beam) -> Beam:
if isinstance(incoming, ParameterBeam):
self.reading = torch.stack([incoming.mu_x, incoming.mu_y])
elif isinstance(incoming, ParticleBeam):
self.reading = torch.stack([incoming.mu_x, incoming.mu_y])
else:
raise TypeError(f"Parameter incoming is of invalid type {type(incoming)}")

return incoming.clone()

def split(self, resolution: torch.Tensor) -> list[Element]:
return [self]

def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> None:
plot_s = s[vector_idx] if s.dim() > 0 else s

alpha = 1 if self.is_active else 0.2
patch = Rectangle(
(plot_s, -0.3), 0, 0.3 * 2, color="darkkhaki", alpha=alpha, zorder=2
)
ax.add_patch(patch)

@property
def defining_features(self) -> list[str]:
return super().defining_features + ["is_active"]

def __repr__(self) -> str:
return f"{self.__class__.__name__}(name={repr(self.name)})"