
Commit da1bbf9

fixes
Signed-off-by: Pawel Gadzinski <[email protected]>
1 parent ac93bb4 commit da1bbf9

5 files changed: +15 -9 lines changed

tests/pytorch/test_cpu_offloading.py

Lines changed: 5 additions & 5 deletions
@@ -26,10 +26,10 @@
     recipe.Float8BlockScaling(),
 ]
 
-SIZE = 512
+SIZE = 64
 NUM_HEADS = 8
 NUM_LAYERS = 5
-EPSILON = 0.1
+EPSILON = 0.05
 
 # Flash attention saves some internal tensor for the backward pass
 # that cannot be offloaded to CPU.
@@ -48,7 +48,7 @@
         SIZE, NUM_HEADS, params_dtype=torch.bfloat16
     ),
     "transformer_layer": lambda: te.TransformerLayer(
-        SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0
+        SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0
     ),
 }
 
@@ -97,7 +97,8 @@ def _measure_memory_between_forward_and_backward(models, fp8_recipe, cpu_offload
         ), offload_context:
             tensor = model(tensor)
         tensor = sync_function(tensor)
-
+
+    import gc; gc.collect()
     max_mem_used = torch.cuda.memory_allocated() / (1024**2)
     torch.cuda.synchronize()
 
@@ -119,7 +120,6 @@ def test_cpu_offload(fp8_recipe, model_key) -> None:
     the difference being the size of the FP8 cache that is not offloaded to the CPU.
     We also expect this memory consumption to be smaller than in scenario (1).
     """
-
     model_cls = model_types[model_key]
     models_list = [model_cls() for _ in range(NUM_LAYERS)]
 
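The test now forces a garbage-collection pass before reading the CUDA allocator, so CUDA tensors kept alive only by dead Python references do not inflate the measured memory. A minimal sketch of that measurement pattern, assuming a plain CUDA workload rather than the actual TE test harness (the toy layer below is illustrative only):

import gc

import torch

def allocated_mib() -> float:
    """Currently allocated CUDA memory in MiB, after dropping dead references."""
    # A CUDA tensor reachable only through a collectable reference cycle still
    # counts toward memory_allocated(); collect first so the reading is honest.
    gc.collect()
    torch.cuda.synchronize()
    return torch.cuda.memory_allocated() / (1024**2)

# Illustrative usage (hypothetical toy layer, not the models under test):
layer = torch.nn.Linear(64, 64, device="cuda", dtype=torch.bfloat16)
out = layer(torch.randn(8, 64, device="cuda", dtype=torch.bfloat16))
print(f"allocated after forward: {allocated_mib():.2f} MiB")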

transformer_engine/pytorch/cpu_offload.py

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ def mark_activation_offload(*tensors):
         if isinstance(tensor, torch.Tensor):
             tensor.activation_offloading = True
         else:
-            data_tensors = tensor.get_data_tensors()
+            data_tensors = tensor.get_data_tensors(scaling_factors=True)
             for tensor in data_tensors:
                 if tensor is not None:
                     tensor.activation_offloading = True
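This one-argument change means mark_activation_offload now asks quantized tensors for their scale-inverse buffers as well, so those buffers get the same activation_offloading flag as the data buffers. The flag itself is just a Python attribute set on each torch.Tensor, which the offload machinery later checks; a minimal sketch of the marking step, assuming a plain tensor standing in for a scale-inverse buffer:

import torch

# Stand-in for a scale-inverse buffer returned by get_data_tensors(scaling_factors=True).
scale_inv = torch.ones(1, dtype=torch.float32)

# Same flag mark_activation_offload sets; torch.Tensor accepts arbitrary Python attributes.
scale_inv.activation_offloading = True
assert getattr(scale_inv, "activation_offloading", False)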

transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py

Lines changed: 3 additions & 1 deletion
@@ -112,8 +112,10 @@ def restore_from_saved(
         self._columnwise_scale_inv = tensors[3]
         return tensors[4:]
 
-    def get_data_tensors(self):
+    def get_data_tensors(self, scaling_factors=False):
         """Get this Tensor's data."""
+        if scaling_factors:
+            return self._rowwise_data, self._columnwise_data, self._rowwise_scale_inv, self._columnwise_scale_inv
         return self._rowwise_data, self._columnwise_data
 
     def _transpose_dq_columnwise_output(self, columnwise_dq: torch.Tensor) -> torch.Tensor:

transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py

Lines changed: 3 additions & 1 deletion
@@ -128,8 +128,10 @@ def restore_from_saved(
         self._scale_inv = tensors[2]
         return tensors[3:]
 
-    def get_data_tensors(self):
+    def get_data_tensors(self, scaling_factors=False):
         """Get this Tensor's data."""
+        if scaling_factors:
+            return self._data, self._transpose, self._scale_inv
         return self._data, self._transpose
 
     def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
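With the new keyword argument the tuple returned by get_data_tensors grows when scaling factors are requested: data and transpose only by default, plus the scale-inverse when scaling_factors=True (the blockwise and MXFP8 bases mirror this with rowwise/columnwise buffers and both scale-inverses). A hedged sketch of the two arities, using a stand-in class rather than the real Float8 tensor base:

import torch

class _FakeFloat8Tensor:
    """Stand-in mirroring the get_data_tensors contract shown in this commit."""

    def __init__(self):
        self._data = torch.zeros(4, 4, dtype=torch.uint8)     # quantized payload
        self._transpose = None                                 # may be absent
        self._scale_inv = torch.ones(1, dtype=torch.float32)  # scaling factor

    def get_data_tensors(self, scaling_factors=False):
        if scaling_factors:
            return self._data, self._transpose, self._scale_inv
        return self._data, self._transpose

t = _FakeFloat8Tensor()
assert len(t.get_data_tensors()) == 2                      # old behaviour: data only
assert len(t.get_data_tensors(scaling_factors=True)) == 3  # data plus scale-inverse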

transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py

Lines changed: 3 additions & 1 deletion
@@ -131,8 +131,10 @@ def restore_from_saved(
         self._columnwise_scale_inv = tensors[3]
         return tensors[4:]
 
-    def get_data_tensors(self):
+    def get_data_tensors(self, scaling_factors=False):
         """Get this Tensor's data."""
+        if scaling_factors:
+            return self._rowwise_data, self._columnwise_data, self._rowwise_scale_inv, self._columnwise_scale_inv
         return self._rowwise_data, self._columnwise_data
 
     def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
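Because the tuple length now depends on the flag (two or four elements for the blockwise and MXFP8 bases, two or three for the plain Float8 base), callers are safest iterating over the result rather than destructuring a fixed number of values, which is what mark_activation_offload does. A small caller-side sketch; the helper name and flag below are hypothetical, not TE API:

def offloadable_buffers(quantized_tensor, include_scales=True):
    """Collect the non-None buffers a quantized tensor exposes for offloading."""
    # Iterating tolerates the 2-, 3- or 4-element returns of get_data_tensors().
    return [
        t
        for t in quantized_tensor.get_data_tensors(scaling_factors=include_scales)
        if t is not None
    ]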
