[wip] Python-only float8 data type + bare bones UEX #23
Summary:
This is a lightweight example of how a `Float8Tensor` could be built out of core, and how it could hook up with a scaling UEX based on module swapping. The `Float8Tensor` here is the important part. The `Float8Linear` part is just an example to demonstrate e2e; I'd expect framework owners to create their own UEX for now. `Float8Tensor` should ideally hook into the existing TransformerEngine cleanly and simplify things.

Note: Ready for initial review, but things might change after we add distributed support.

Note: this is WIP and does not represent PyTorch's opinion on how we will integrate float8. At this point, this is a prototype to get some light feedback.
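To make the scaling idea concrete, here is a minimal pure-Python sketch of the dynamic-range handling a scaled float8 tensor needs: pick a scale from the observed absolute max (amax) so values fill the float8 range, then clamp. All names (`compute_scale`, `quantize`, `E4M3_MAX`) are illustrative, not from this PR, and rounding to actual float8 values is omitted; the real `Float8Tensor` would do this on torch tensors.

```python
# Illustrative constants/functions only; not the PR's actual API.
E4M3_MAX = 448.0  # largest finite value in the float8 e4m3 format


def compute_scale(amax: float) -> float:
    """Scale that maps the observed absolute max onto the float8 range."""
    return E4M3_MAX / amax if amax > 0 else 1.0


def quantize(values, scale):
    """Scale, then clamp to the representable float8 range.

    A real implementation would also round to the nearest representable
    float8 value; here we only model the dynamic-range handling.
    """
    return [max(-E4M3_MAX, min(E4M3_MAX, v * scale)) for v in values]


def dequantize(values, scale):
    """Undo the scaling to recover approximately the original values."""
    return [v / scale for v in values]


data = [0.5, -2.0, 3.0]
amax = max(abs(v) for v in data)
scale = compute_scale(amax)
roundtrip = dequantize(quantize(data, scale), scale)
```

Values within range round-trip exactly in this simplified model; out-of-range values get clamped, which is where the amax-based scale choice matters.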
TODOs that need to be implemented before review of this PR / sending to NVIDIA:
What is out of scope for this POC
a. hooking up to real float8 ops (saved for later, just needs someone to do it)
b. real UEX (saved for later and will need a lot of design discussion)
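As a sketch of what the module-swapping UEX above means, the pattern is to walk the module tree and replace each `Linear` with a `Float8Linear` built from it via a `from_float` constructor. The stand-in `Module`/`Linear` classes below are placeholders so the example runs without torch; the real version would recurse over `torch.nn.Module` children.

```python
# Stand-in module tree; a real UEX would operate on torch.nn.Module.
class Module:
    def __init__(self):
        self._children = {}

    def add(self, name, child):
        self._children[name] = child
        return child

    def named_children(self):
        return self._children.items()

    def set_child(self, name, child):
        self._children[name] = child


class Linear(Module):
    pass


class Float8Linear(Linear):
    @classmethod
    def from_float(cls, mod):
        """Build a float8 wrapper from an existing Linear (illustrative)."""
        new = cls()
        new._children = mod._children
        return new


def swap_linear_with_float8(module):
    """Recursively replace every plain Linear child with a Float8Linear."""
    for name, child in list(module.named_children()):
        if type(child) is Linear:  # exact type: don't re-swap Float8Linear
            module.set_child(name, Float8Linear.from_float(child))
        else:
            swap_linear_with_float8(child)
    return module
```

This keeps `Float8Tensor` independent of the UEX: framework owners can write their own swap (or use parametrizations, hooks, etc.) while reusing the tensor subclass.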
Test plan: