linear: clear row-wise weight at the end of forward #1770
Description
Per the above image, the row-wise quantized weight can be freed after the forward pass. However, this does not happen without an explicit `update_usage(columnwise=True, rowwise=False)` call; with the default arguments, the row-wise copy is preserved.

Relevant code (all at commit 51cd441):
- Call site: TransformerEngine/transformer_engine/pytorch/module/linear.py, lines 357 to 359
- Implementation of `update_usage` for `MXFP8TensorBase`: TransformerEngine/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py, line 159
- Implementation of `update_usage` for `Float8TensorBase`: TransformerEngine/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py, line 163
Example Script
NOTE: I will add a test if the patch makes sense.
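To make the intended lifecycle concrete, here is a minimal, self-contained toy sketch (these are NOT the real TransformerEngine classes; the names and the list-based "quantized" storage are illustrative assumptions). It models a weight that keeps both a row-wise copy (consumed by the forward GEMM) and a column-wise copy (consumed by the wgrad GEMM in backward), and shows `update_usage` freeing the row-wise copy at the end of forward:

```python
class ToyQuantizedTensor:
    """Toy stand-in (NOT the real TE class) for a quantized weight that
    keeps both a row-wise copy (used by the forward GEMM) and a
    column-wise copy (used by the wgrad GEMM in backward)."""

    def __init__(self, data):
        self.rowwise_data = [list(row) for row in data]  # forward GEMM input
        self.columnwise_data = list(zip(*data))          # transposed copy for wgrad

    def update_usage(self, rowwise_usage=True, columnwise_usage=True):
        # Mirrors the idea of update_usage in Float8TensorBase /
        # MXFP8TensorBase: dropping a usage frees the corresponding copy.
        if not rowwise_usage:
            self.rowwise_data = None
        if not columnwise_usage:
            self.columnwise_data = None


def forward(w: ToyQuantizedTensor):
    # ... the forward GEMM would consume w.rowwise_data here ...
    assert w.rowwise_data is not None
    # Proposed change: after the forward GEMM, only the column-wise copy
    # is still needed (for backward), so the row-wise copy can be freed.
    w.update_usage(rowwise_usage=False, columnwise_usage=True)


w = ToyQuantizedTensor([[1, 2], [3, 4]])
forward(w)
print(w.rowwise_data)     # None: row-wise copy freed at the end of forward
print(w.columnwise_data)  # [(1, 3), (2, 4)]: kept for the backward wgrad
```

The point of the patch is exactly the `update_usage(rowwise_usage=False, columnwise_usage=True)` call at the end of `forward`; without it, both copies stay resident until backward completes.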