Skip to content

Commit

Permalink
Delete extra tensor objects after restoring float8 tensors (NVIDIA#1500)
Browse files Browse the repository at this point in the history
* delete extra tensor objects after restoring float8 tensors

Signed-off-by: Sudhakar Singh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* nit fix

Signed-off-by: Sudhakar Singh <[email protected]>

* fix the leak in float8tensor and mxfloat8tensor classes

Signed-off-by: Sudhakar Singh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* uncomment the fix

Signed-off-by: Sudhakar Singh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix lint

Signed-off-by: Sudhakar Singh <[email protected]>

---------

Signed-off-by: Sudhakar Singh <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
sudhakarsingh27 and pre-commit-ci[bot] authored Feb 28, 2025
1 parent 303c6d1 commit d3efaeb
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 4 deletions.
3 changes: 3 additions & 0 deletions transformer_engine/pytorch/module/layernorm_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,9 @@ def backward(
mu,
rsigma,
) = restore_from_saved(ctx.tensor_objects, saved_tensors)
# Delete the references to tensor objects once they've been consumed
# by the `restore_from_saved` method to construct back the actual tensors.
ctx.tensor_objects = None

# Since main_grad can be modified inplace, it should not be a part of saved_tensors
main_grad = (
Expand Down
4 changes: 4 additions & 0 deletions transformer_engine/pytorch/module/layernorm_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,10 @@ def backward(
mu,
rsigma,
) = restore_from_saved(ctx.tensor_objects, saved_tensors)
# Delete the references to tensor objects once they've been consumed
# by the `restore_from_saved` method to construct back the actual tensors.
ctx.tensor_objects = None

# Since main_grad can be modified inplace, it should not be a part of saved_tensors
fc1_weight_main_grad = (
ctx.fc1_main_grad
Expand Down
3 changes: 3 additions & 0 deletions transformer_engine/pytorch/module/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
inputmat, weight_fp8, weight, bias = ( # pylint: disable=unbalanced-tuple-unpacking
restore_from_saved(ctx.tensor_objects, saved_tensors)
)
# Delete the references to tensor objects once they've been consumed
# by the `restore_from_saved` method to construct back the actual tensors.
ctx.tensor_objects = None

# Since main_grad can be modified inplace, it should not be a part of saved_tensors
main_grad = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8Tensor
"""
tensors = [self._data, self._transpose]
# self._data = None
# self._transpose = None
self._data = None
self._transpose = None
return tensors, self

def restore_from_saved(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorB
"""
tensors = [self._rowwise_data, self._columnwise_data]
# self._rowwise_data = None
# self._columnwise_data = None
self._rowwise_data = None
self._columnwise_data = None
return tensors, self

def restore_from_saved(
Expand Down
9 changes: 9 additions & 0 deletions transformer_engine/pytorch/tensor/float8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,15 @@ def clear(self):
self._transpose = torch.Tensor() if self._transpose is not None else None
self._transpose_invalid = True

def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8TensorBase]:
"""Prepare the tensor base for saving for backward
After calling this, the tensor instance does not hold any
data.
"""
return [self], None

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):

Expand Down
9 changes: 9 additions & 0 deletions transformer_engine/pytorch/tensor/mxfp8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,15 @@ def clear(self):
self._rowwise_data = torch.Tensor() if self._rowwise_data is not None else None
self._columnwise_data = torch.Tensor() if self._columnwise_data is not None else None

def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorBase]:
"""Prepare the tensor base for saving for backward
After calling this, the tensor instance does not hold any
data.
"""
return [self], None

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):

Expand Down

0 comments on commit d3efaeb

Please sign in to comment.