diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 01bda64101..007821038f 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -448,6 +448,9 @@ def backward( mu, rsigma, ) = restore_from_saved(ctx.tensor_objects, saved_tensors) + # Delete the references to tensor objects once they've been consumed + # by the `restore_from_saved` method to construct back the actual tensors. + ctx.tensor_objects = None # Since main_grad can be modified inplace, it should not be a part of saved_tensors main_grad = ( diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 88eebc8e6c..f4ee0a1155 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -567,6 +567,10 @@ def backward( mu, rsigma, ) = restore_from_saved(ctx.tensor_objects, saved_tensors) + # Delete the references to tensor objects once they've been consumed + # by the `restore_from_saved` method to construct back the actual tensors. + ctx.tensor_objects = None + # Since main_grad can be modified inplace, it should not be a part of saved_tensors fc1_weight_main_grad = ( ctx.fc1_main_grad diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index bae21eebfd..83dc652c62 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -354,6 +354,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat, weight_fp8, weight, bias = ( # pylint: disable=unbalanced-tuple-unpacking restore_from_saved(ctx.tensor_objects, saved_tensors) ) + # Delete the references to tensor objects once they've been consumed + # by the `restore_from_saved` method to construct back the actual tensors. 
+ ctx.tensor_objects = None # Since main_grad can be modified inplace, it should not be a part of saved_tensors main_grad = ( diff --git a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py index 6b816db3b5..8ae45c9375 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py @@ -105,8 +105,8 @@ def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8Tensor """ tensors = [self._data, self._transpose] - # self._data = None - # self._transpose = None + self._data = None + self._transpose = None return tensors, self def restore_from_saved( diff --git a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py index d78bd55d9a..ea7fc3cf2f 100644 --- a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py @@ -100,8 +100,8 @@ def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorB """ tensors = [self._rowwise_data, self._columnwise_data] - # self._rowwise_data = None - # self._columnwise_data = None + self._rowwise_data = None + self._columnwise_data = None return tensors, self def restore_from_saved( diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index c9e65bd93a..333b8d1733 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -348,6 +348,15 @@ def clear(self): self._transpose = torch.Tensor() if self._transpose is not None else None self._transpose_invalid = True + def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8TensorBase]: + """Prepare the tensor for saving for backward + + Unlike the base-class method, the tensor is saved directly (it is itself a
+ ``torch.Tensor``), so the instance keeps its data. + + """ + return [self], None + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 6e3835fbef..940f2ae46f 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -285,6 +285,15 @@ def clear(self): self._rowwise_data = torch.Tensor() if self._rowwise_data is not None else None self._columnwise_data = torch.Tensor() if self._columnwise_data is not None else None + def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorBase]: + """Prepare the tensor for saving for backward + + Unlike the base-class method, the tensor is saved directly (it is itself a + ``torch.Tensor``), so the instance keeps its data. + + """ + return [self], None + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None):