Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatically detect default dtype for lattice and beam importers #340

Open
wants to merge 13 commits into
base: master
Choose a base branch
from

Conversation

Hespe
Copy link
Member

@Hespe Hespe commented Feb 11, 2025

Description

Motivation and Context

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist

  • I have updated the changelog accordingly (required).
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.
  • I have reformatted the code and checked that formatting passes (required).
  • I have have fixed all issues found by flake8 (required).
  • I have ensured that all pytest tests pass (required).
  • I have run pytest on a machine with a CUDA GPU and made sure all tests pass (required).
  • I have checked that the documentation builds (required).

Note: We are using a maximum length of 88 characters per line.

@Hespe Hespe added the enhancement New feature or request label Feb 11, 2025
@Hespe Hespe self-assigned this Feb 11, 2025
@Hespe
Copy link
Member Author

Hespe commented Feb 11, 2025

Before finishing this PR, we should discuss the default behaviour of the beam importers that import numpy data. Right now there are actually in an inconsistent state:

  • When from_ocelot is used with default dtype=None, it uses the dtype of the numpy array passed to it. In my opinion this is the correct behaviour and it mimickes the dtype one obtains from torch.tensor when passing a numpy array. In fact, the behaviour is implemented using torch.as_tensor to determine the appropriate dtype.
  • from_astra on the other hand uses a default of dtype=torch.float32. This does not follow the automatic dtype selection we are implementing in this PR and thus definitely has to change. I see two possibilities:
    1. Simply change to dtype=None, making its behaviour identical to from_ocelot. This would be my preferred choice but it leads to a number of test failures currently. The problem is that the tests implicitly assume a beam of type torch.float32 and the solution would be to explicitly set dtype=torch.float32 in the tests while loading the beam.
    2. Change to dtype=None, but query torch.get_default_dtype in that case to determine the dtype. This makes the API mostly consistent with the current behaviour since torch.get_default_dtype() defaults to torch.float32 but deviates from how torch.tensor determines its dtype from numpy arrays.

In case we go for the second option, from_ocelot should be updated in the same way.

@cr-xu
Copy link
Member

cr-xu commented Feb 17, 2025

Hi, I just noticed that the lattice loaded with Segment.from_lattice_json() also doesn't have an option to select dtype, which is annoying because it now defaults to float64.

@cr-xu
Copy link
Member

cr-xu commented Feb 17, 2025

On a side note, I discovered that now the entire Segment can be reliably casted to float or double simply by calling
segment.to(torch.float32)
Did we know that already? @Hespe @jank324 I vaguely remembered that we were discussing this at some point but it was not possible. Probably it's worth to document it somewhere.

@Hespe
Copy link
Member Author

Hespe commented Feb 18, 2025

You should even be ablo to simply call segment.double() or segment.single() to convert between float32 and float64, same for instances of Beam. However, you are right that this should be properly documented somewhere.

@Hespe Hespe changed the title Improve beam and lattice importers Automatically detect default dtype for lattice and beam importers Feb 18, 2025
@Hespe
Copy link
Member Author

Hespe commented Feb 18, 2025

Hi, I just noticed that the lattice loaded with Segment.from_lattice_json() also doesn't have an option to select dtype, which is annoying because it now defaults to float64.

I've just added a device and dtype argument to Segment.from_lattice_json(). Furthermore, it should now always default to torch.float32 unless you use torch.set_default_dtype. But to be fair, I'm not even sure in which cases it would previously default totorch.float64.

@cr-xu
Copy link
Member

cr-xu commented Feb 18, 2025

Hi, I just noticed that the lattice loaded with Segment.from_lattice_json() also doesn't have an option to select dtype, which is annoying because it now defaults to float64.

I've just added a device and dtype argument to Segment.from_lattice_json(). Furthermore, it should now always default to torch.float32 unless you use torch.set_default_dtype. But to be fair, I'm not even sure in which cases it would previously default totorch.float64.

Thanks! I was testing that on released v0.7.0 version, not this PR branch.

@Hespe Hespe marked this pull request as ready for review February 18, 2025 08:19
@Hespe
Copy link
Member Author

Hespe commented Feb 18, 2025

I noticed a small inconsistency in the dtype of cheetah elements while writing this PR. All elements inherit a length property from Element. But for elements that are not supposed to have a length, e.g., Marker, there is no length argument to their constructors and also no dtype and device arguments such that their length always has dtype torch.float32. Importing or constructing lattices of type torch.float64 therefore leads to a Segment with mixed dtype.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Adopt automatic type scheme for importers
2 participants