Skip to content

linear: clear row-wise weight at the end of forward #1770

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

kshitij12345
Copy link

@kshitij12345 kshitij12345 commented May 12, 2025

Description

image

Per the above image, row-wise quantized weight can be freed post the forward. However, this does not happen without explicit update_usage(columnwise=True, rowwise=False) (as with the default, row-wise copy is preserved)

if inp.requires_grad:
if isinstance(weightmat, QuantizedTensorBase):
weightmat.update_usage(columnwise_usage=True)

Implementation of update_usage for MXFP8TensorBase.

Implementation of update_usage for Float8TensorBase.

Example Script

import torch

import transformer_engine
from transformer_engine.pytorch import fp8_autocast, Linear, fp8_model_init

dim = 1024 * 22 # Large input for demonstration of memory change.

with fp8_model_init(enabled=False):
    linear = Linear(dim, dim, bias=False, params_dtype=torch.bfloat16)

# 1015.021568 MB PARAM
print(torch.cuda.memory_allocated() / 1e6, "MB PARAM")

x = torch.randn(dim, dim, requires_grad=True, device="cuda", dtype=torch.bfloat16)

# 2030.043136 MB X
print(torch.cuda.memory_allocated() / 1e6, "MB X")

for _ in range(10):
    with fp8_autocast():
        o = linear(x)
        g_o = torch.randn_like(o)
    

    o.backward(g_o)

# Without patch - 10661.943808 MB
# With patch - 10154.433024 MB MB
print(torch.cuda.max_memory_allocated() / 1e6, "MB")

NOTE: Will add a test if the patch makes sense.

@kshitij12345
Copy link
Author

Ping @ksivaman @ptrendx to see if it makes sense.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant