Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Register element readings as torch buffers for proper dtype conversion #335

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
### 🐛 Bug fixes

- Fix issue where a space before a comma could cause the Elegant and Bmad converters to fail (see #327) (@jank324)
- Fix issue of `BPM` and `Screen` not properly converting the `dtype` of their readings (#335) (@Hespe)

### 🐆 Other

Expand Down
14 changes: 11 additions & 3 deletions cheetah/accelerator/bpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,19 @@ class BPM(Element):
:param name: Unique identifier of the element.
"""

def __init__(self, is_active: bool = False, name: Optional[str] = None) -> None:
super().__init__(name=name)
def __init__(
    self,
    is_active: bool = False,
    name: Optional[str] = None,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> None:
    """
    Initialize the BPM element.

    :param is_active: Whether the BPM records a reading during tracking.
    :param name: Unique identifier of the element.
    :param device: Device on which the element's buffers are created.
    :param dtype: Floating-point dtype of the element's buffers.
    """
    super().__init__(name=name, device=device, dtype=dtype)

    self.is_active = is_active

    # The reading starts out as a scalar NaN placeholder. Registering it as
    # a buffer (instead of a plain attribute) lets .to() / .double() etc.
    # convert its device and dtype together with the rest of the module.
    placeholder_reading = torch.tensor(torch.nan, device=device, dtype=dtype)
    self.register_buffer("reading", placeholder_reading)

@property
def is_skippable(self) -> bool:
Expand Down
19 changes: 15 additions & 4 deletions cheetah/accelerator/screen.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,14 @@ def __init__(
self.register_buffer("misalignment", torch.tensor((0.0, 0.0), **factory_kwargs))
self.register_buffer("kde_bandwidth", torch.clone(self.pixel_size[0]))

# NOTE: According to its type hint, the operation on resolution below is a
# no-op. However, this form is robust against accidentally passing a
# torch.Tensor, preventing crashes in some instances.
self.register_buffer(
"cached_reading",
torch.full((resolution[0], resolution[1]), torch.nan, **factory_kwargs),
)

if pixel_size is not None:
self.pixel_size = torch.as_tensor(pixel_size, **factory_kwargs)
if misalignment is not None:
Expand All @@ -82,7 +90,6 @@ def __init__(
self.kde_bandwidth = torch.as_tensor(kde_bandwidth, **factory_kwargs)

self.set_read_beam(None)
self.cached_reading = None

@property
def is_skippable(self) -> bool:
Expand Down Expand Up @@ -190,8 +197,7 @@ def track(self, incoming: Beam) -> Beam:

@property
def reading(self) -> torch.Tensor:
image = None
if self.cached_reading is not None:
if not torch.all(torch.isnan(self.cached_reading)):
return self.cached_reading

read_beam = self.get_read_beam()
Expand Down Expand Up @@ -296,7 +302,12 @@ def set_read_beam(self, value: Beam) -> None:
# prevent `nn.Module` from intercepting the read beam, which is itself an
# `nn.Module`, and registering it as a submodule of the screen.
self._read_beam = value
self.cached_reading = None
self.cached_reading = torch.full(
(self.resolution[0], self.resolution[1]),
torch.nan,
device=self.cached_reading.device,
dtype=self.cached_reading.dtype,
)

def split(self, resolution: torch.Tensor) -> list[Element]:
return [self]
Expand Down
18 changes: 18 additions & 0 deletions tests/test_bpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,21 @@ def test_no_tracking_error(is_bpm_active, beam_class):
segment.my_bpm.is_active = is_bpm_active

_ = segment.track(beam)


def test_reading_dtype_conversion():
    """Test that a dtype conversion is correctly reflected in the BPM reading."""
    elements = [
        cheetah.Drift(length=torch.tensor(1.0), dtype=torch.float32),
        cheetah.BPM(name="bpm", is_active=True, dtype=torch.float32),
    ]
    segment = cheetah.Segment(elements=elements)
    incoming = cheetah.ParameterBeam.from_parameters(dtype=torch.float32)

    # The reading buffer starts out in the requested single precision.
    assert segment.bpm.reading.dtype == torch.float32

    # Tracking must not silently change the reading's dtype.
    segment.track(incoming)
    assert segment.bpm.reading.dtype == torch.float32

    # Converting the whole segment converts the registered reading buffer too.
    segment = segment.double()
    assert segment.bpm.reading.dtype == torch.float64
27 changes: 27 additions & 0 deletions tests/test_screen.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,30 @@ def test_reading_shows_beam_ares(screen_method):
assert segment.AREABSCR1.reading.shape == (2040, 2448)
assert torch.all(segment.AREABSCR1.reading >= 0.0)
assert torch.any(segment.AREABSCR1.reading > 0.0)


def test_reading_dtype_conversion():
    """Test that a dtype conversion is correctly reflected in the screen reading."""
    segment = cheetah.Segment(
        elements=[
            cheetah.Drift(length=torch.tensor(1.0), dtype=torch.float32),
            cheetah.Screen(name="screen", is_active=True, dtype=torch.float32),
        ],
    )
    incoming = cheetah.ParameterBeam.from_parameters(dtype=torch.float32)

    # Before any tracking the (placeholder) reading has the requested dtype.
    assert segment.screen.reading.dtype == torch.float32

    # Case 1: converting after tracking invalidates the cache, so a fresh
    # image is generated in the new dtype.
    converted = segment.clone()
    converted.track(incoming)
    converted = converted.double()
    assert torch.isnan(converted.screen.cached_reading).all()
    assert converted.screen.reading.dtype == torch.float64

    # Case 2: the cached reading itself is converted along with the segment.
    segment.track(incoming)
    assert segment.screen.reading.dtype == torch.float32
    assert segment.screen.cached_reading.dtype == torch.float32
    segment = segment.double()
    assert segment.screen.cached_reading.dtype == torch.float64
    assert segment.screen.reading.dtype == torch.float64