
[PyTorch] Draft of new activation offloading API #1762


Draft · wants to merge 12 commits into main
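The diff below replaces the old `te.get_cpu_offload_context` context-manager API with a `CPUOffload` wrapper plus an `offload()` marker. A minimal usage sketch, pieced together from the updated tests (names come from this diff; semantics may still change while the PR is a draft):

```python
import torch
from transformer_engine.pytorch.cpu_offload import offload, CPUOffload

def compute(inp):
    x = torch.randn_like(inp)
    offload(x)                  # mark x to be offloaded to CPU after the forward pass
    return inp * x              # x is saved for backward, so offloading it frees GPU memory

cpu_offload = CPUOffload()
compute = cpu_offload(compute)  # wrap the forward computation

inp = torch.randn(8, 512, 512, dtype=torch.bfloat16, device="cuda", requires_grad=True)
y = compute(inp)
cpu_offload.sync_before_bwd()   # wait for device-to-host copies to finish before backward
y.sum().backward()              # offloaded tensors are brought back to the GPU as needed
```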
234 changes: 160 additions & 74 deletions tests/pytorch/test_cpu_offloading.py
@@ -3,23 +3,29 @@
# See LICENSE for license information.

import os
from contextlib import nullcontext
import pytest
import torch

import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.common import recipe
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.tensor.quantized_tensor import prepare_for_saving, restore_from_saved
from transformer_engine.pytorch.cpu_offload import offload, _manual_reload, CPUOffload

# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_available, reason_for_no_fp8_block = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
)

fp8_recipes = [
None, # non-fp8
# recipe.MXFP8BlockScaling(), - offloading of scale-inverse tensors does not work yet
recipe.MXFP8BlockScaling(),
recipe.Float8CurrentScaling(),
recipe.DelayedScaling(),
recipe.Float8BlockScaling(),
]

SIZE = 512
@@ -48,97 +54,177 @@
),
}


def _get_input():
return torch.empty((128, SIZE, SIZE), dtype=torch.bfloat16).cuda()

return torch.empty((8, SIZE, SIZE), dtype=torch.bfloat16, requires_grad=True).cuda()

def test_auto_offload():
def run(offload_enabled):
inp = _get_input()

def compute(input_tensor):
x = _get_input()
if offload_enabled:
offload(x)
y = input_tensor * x
# x is needed for the backward pass, so it will be saved.
return y

cpu_offload = CPUOffload()
if offload_enabled:
compute = cpu_offload(compute)

y = compute(inp)
cpu_offload.sync_before_bwd()

memory_allocated = torch.cuda.memory_allocated() / (1024**2)
y.sum().backward() # for sanity check
return memory_allocated

# x is offloaded to CPU when offload_enabled is True, so the run with offloading
# should allocate at least SIZE * SIZE * 2 / (1024 ** 2) MB less GPU memory.
assert run(True) < run(False)
assert run(False) - run(True) > (SIZE * SIZE * 2 / (1024 ** 2)) - EPSILON

def _tensor_size(x):
if type(x) == torch.Tensor:
return x.numel() * x.element_size() / (1024 ** 2)
elif type(x) == te.float8_tensor.Float8Tensor:
return x._data.numel() * x._data.element_size() / (1024 ** 2)
elif type(x) == te.tensor._internal.float8_tensor_base.Float8TensorBase:
return x._data.numel() * x._data.element_size() / (1024 ** 2)
else:
raise ValueError(f"Unknown tensor type: {type(x)}")

tensor_empty_constructrs = {
"tensor": lambda: torch.empty((SIZE, SIZE), dtype=torch.bfloat16).cuda(),
"float8tensor": lambda: te.float8_tensor.Float8Tensor(
data=torch.empty((SIZE, SIZE), dtype=torch.uint8, device="cuda"),
fp8_scale_inv=torch.tensor(1.0).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
shape=(SIZE, SIZE),
dtype=torch.bfloat16,
device="cuda",
),
"float8tensorbase": lambda: te.tensor._internal.float8_tensor_base.Float8TensorBase(
data=torch.empty((SIZE, SIZE), dtype=torch.uint8, device="cuda"),
fp8_scale_inv=torch.tensor(1.0).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
shape=(SIZE, SIZE),
dtype=torch.bfloat16,
),
}

def _get_fp8_weight_cache_size(models, fp8_recipe):
@pytest.mark.parametrize("x_tensor_type", tensor_empty_constructrs.keys())
def test_manual_offload(x_tensor_type):
class Function(torch.autograd.Function):
@staticmethod
def forward(ctx, input_tensor, x, offload_enabled):
if offload_enabled:
offload(x, manual_reload=True)

tensors, tensor_objects = prepare_for_saving(x, input_tensor)
ctx.tensor_objects = tensor_objects
ctx.save_for_backward(*tensors)
ctx.offload_enabled = offload_enabled
return input_tensor

@staticmethod
def backward(ctx, _):
torch.cuda.synchronize()
x, input_tensor = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
if ctx.offload_enabled:
if hasattr(x, "device"):
assert x.device.type == "cpu"
x = _manual_reload(x)
#if hasattr(x, "device"):
# assert x.device.type == "cuda"
return input_tensor, None, None

def run(offload_enabled):
inp = _get_input()
def compute(input_tensor):
x = tensor_empty_constructrs[x_tensor_type]()
return Function.apply(input_tensor, x, offload_enabled)

cpu_offload = CPUOffload()
if offload_enabled:
compute = cpu_offload(compute)

y = compute(inp)
cpu_offload.sync_before_bwd()

memory_allocated = torch.cuda.memory_allocated() / (1024**2)
y.sum().backward() # for sanity check
return memory_allocated

# x will be offloaded to CPU when offload_enabled is True
assert run(True) < run(False)
diff = run(False) - run(True)
assert abs(diff - _tensor_size(tensor_empty_constructrs[x_tensor_type]())) < EPSILON
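Outside the test harness, the manual-reload contract exercised by `test_manual_offload` boils down to: call `offload(x, manual_reload=True)` before saving `x` in `forward`, and call `_manual_reload(x)` after restoring it in `backward`. A stripped-down sketch of that pattern (the `OffloadAwareFn` name is hypothetical; only names from this diff are used otherwise):

```python
import torch
from transformer_engine.pytorch.cpu_offload import offload, _manual_reload
from transformer_engine.pytorch.tensor.quantized_tensor import prepare_for_saving, restore_from_saved

class OffloadAwareFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, x):
        offload(x, manual_reload=True)   # offload x, but skip the automatic reload
        tensors, tensor_objects = prepare_for_saving(x, inp)
        ctx.tensor_objects = tensor_objects
        ctx.save_for_backward(*tensors)
        return inp

    @staticmethod
    def backward(ctx, grad_output):
        x, inp = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
        x = _manual_reload(x)            # explicitly bring x back to the GPU
        return grad_output, None
```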

def _get_fp8_weight_cache_size(model, fp8_recipe):
"""
Calculate the total FP8 weight cache size (in MB) for a model.
"""
if fp8_recipe is None:
return 0

params_bytes = 0
for model in models:
for name, param in model.named_parameters():
if "weight" in name:
params_bytes += param.numel()
for name, param in model.named_parameters():
if "weight" in name:
params_bytes += param.numel()

# One byte for the row-wise copy and one byte for the column-wise copy,
# hence multiply by 2 and convert to MB.
# MXFP8 additionally stores 1 byte of scale per 32 elements.
factor_for_scale_inv_tensor = (1 + 1 / 32) if fp8_recipe.mxfp8() else 1
return (2 * params_bytes * factor_for_scale_inv_tensor) / (1024**2)
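As a quick sanity check of the formula, here is a worked example for a single 512 x 512 weight with the MXFP8 recipe (an illustration only, not part of the test):

```python
SIZE = 512
params_bytes = SIZE * SIZE                        # 262,144 FP8 elements at 1 byte each
mxfp8_factor = 1 + 1 / 32                         # one extra scale byte per 32 elements
cache_mb = 2 * params_bytes * mxfp8_factor / (1024 ** 2)
print(cache_mb)                                   # 0.515625 MB (0.5 MB without MXFP8 scales)
```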


def _measure_memory_between_forward_and_backward(models, fp8_recipe, cpu_offload):
tensor = _get_input()
if cpu_offload:
offload_context, sync_function = te.get_cpu_offload_context(
enabled=True,
num_layers=len(models) - 1,
model_layers=len(models),
offload_activations=True,
offload_weights=False,
)
else:
offload_context = nullcontext()
sync_function = lambda x: x

for model in models:
with te.fp8_autocast(
enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe
), offload_context:
tensor = model(tensor)
tensor = sync_function(tensor)

max_mem_used = torch.cuda.memory_allocated() / (1024**2)
torch.cuda.synchronize()

return max_mem_used


@pytest.mark.parametrize("layer_type", model_types.keys())
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model_key", model_types.keys())
def test_cpu_offload(fp8_recipe, model_key) -> None:
"""
We run three configurations:
(1) No offloading: All activations remain on the GPU between forward and backward passes.
(2) No offloading (one layer): Only the first layer's activations remain on the GPU between
forward and backward passes.
(3) With offloading (all layers): Only the last layer's activations remain on the GPU
between forward and backward passes, while all other layers are offloaded to the CPU.

We expect the memory consumption of configurations (2) and (3) to be similar, with
the difference being the size of the FP8 cache that is not offloaded to the CPU.
We also expect this memory consumption to be smaller than in scenario (1).
"""

model_cls = model_types[model_key]
models_list = [model_cls() for _ in range(NUM_LAYERS)]

if fp8_recipe and not fp8_available:
def test_cpu_offload_on_layers(layer_type, fp8_recipe):
if not fp8_available and fp8_recipe is not None:
pytest.skip(reason_for_no_fp8)
if fp8_recipe is not None:
if fp8_recipe.mxfp8() and not mxfp8_available:
if not mxfp8_available and fp8_recipe.mxfp8():
pytest.skip(reason_for_no_mxfp8)

without_offloading = _measure_memory_between_forward_and_backward(
models_list, fp8_recipe, False
)
without_offloading_one_layer = _measure_memory_between_forward_and_backward(
models_list[:1], fp8_recipe, False
)
with_offloading = _measure_memory_between_forward_and_backward(models_list, fp8_recipe, True)

assert with_offloading < without_offloading

# The only difference between the memory consumption of with_offloading
# and without_offloading_one_layer should be the size of the FP8 weights cache,
# which is not offloaded to the CPU.
memory_consumption_diff = abs(with_offloading - without_offloading_one_layer)
if not fp8_block_available and fp8_recipe.float8_block_scaling():
pytest.skip(reason_for_no_fp8_block)
model = model_types[layer_type]()

def _get_memory(offload_enabled):
def comp(inp):
with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
y = model(inp)
return y
def run_comp(f):
inp = _get_input()
return f(inp)
cpu_offload = CPUOffload()
if offload_enabled:
comp = cpu_offload(comp)
y = run_comp(comp)
cpu_offload.sync_before_bwd()

memory_allocated = torch.cuda.memory_allocated() / (1024**2)

y.sum().backward()
return memory_allocated

# warm up
_get_memory(False)
_get_memory(True)
_get_memory(False)
initial_memory = torch.cuda.memory_allocated() / (1024**2)

with_offload = _get_memory(True)
without_offload = _get_memory(False)
print(f"initial_memory: {initial_memory}, with_offload: {with_offload}, without_offload: {without_offload}")
assert with_offload < without_offload
diff = with_offload - initial_memory
output_size = _tensor_size(_get_input())
assert (
memory_consumption_diff < _get_fp8_weight_cache_size(models_list[1:], fp8_recipe) + EPSILON
diff < _get_fp8_weight_cache_size(model, fp8_recipe) + EPSILON + output_size
)

# TODO: prepare more general tests.
13 changes: 6 additions & 7 deletions tests/pytorch/test_sanity.py
@@ -30,7 +30,7 @@
TransformerLayer,
RMSNorm,
LayerNorm,
get_cpu_offload_context,
CPUOffload
)
from transformer_engine.common import recipe
import transformer_engine_torch as tex
@@ -289,15 +289,14 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
_disable_wgrads(block)

if cpu_offload:
offload_context, sync_function = get_cpu_offload_context(enabled=True)
else:
offload_context = nullcontext()
sync_function = lambda x: x
cpu_offload = CPUOffload()
block = cpu_offload(block)

use_fp8 = fp8_recipe is not None
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), offload_context:
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
te_out = block(te_inp_hidden_states)
te_out = sync_function(te_out)
if cpu_offload:
cpu_offload.sync_before_bwd()
loss = te_out.sum()
loss.backward()
torch.cuda.synchronize()
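Taken together, the `test_sanity` changes show the migration from the old context-manager API to the new wrapper. A sketch of the before/after pattern, assuming a plain `te.Linear` stands in for the test block:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.cpu_offload import CPUOffload

block = te.Linear(512, 512, params_dtype=torch.bfloat16)  # any TE module or callable
inp = torch.randn(8, 512, dtype=torch.bfloat16, device="cuda", requires_grad=True)

# Old API (removed in this draft): context manager plus a synchronization function.
#   offload_context, sync_function = te.get_cpu_offload_context(enabled=True)
#   with offload_context:
#       out = block(inp)
#   out = sync_function(out)

# New API: wrap the forward callable and synchronize explicitly before backward.
cpu_offload = CPUOffload()
block = cpu_offload(block)
out = block(inp)
cpu_offload.sync_before_bwd()   # ensure device-to-host copies finished before backward
out.sum().backward()
```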
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/__init__.py
@@ -105,7 +105,7 @@ def _load_library():
from transformer_engine.pytorch.graph import make_graphed_callables
from transformer_engine.pytorch.distributed import checkpoint
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
from transformer_engine.pytorch.cpu_offload import get_cpu_offload_context
from transformer_engine.pytorch.cpu_offload import CPUOffload
from transformer_engine.pytorch import ops
from transformer_engine.pytorch import optimizers
from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy
@@ -683,12 +683,12 @@ def forward(
)
else:
from transformer_engine.pytorch.cpu_offload import (
CPUOffloadEnabled,
mark_activation_offload,
is_cpu_offload_enabled,
offload,
)

if CPUOffloadEnabled:
mark_activation_offload(
if is_cpu_offload_enabled():
offload(
query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv
)

Expand Down Expand Up @@ -1054,19 +1054,19 @@ def forward(
ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))

from transformer_engine.pytorch.cpu_offload import (
CPUOffloadEnabled,
mark_activation_offload,
is_cpu_offload_enabled,
offload,
)

if CPUOffloadEnabled:
if is_cpu_offload_enabled():
if ctx.fp8:
tensor_list = fp8_tensors
else:
tensor_list = [q, k, v, out_save]

qkv_layout = "sbhd_sbhd_sbhd"
mark_activation_offload(*tensor_list)
mark_activation_offload(*aux_ctx_tensors)
offload(*tensor_list)
offload(*aux_ctx_tensors)

ctx.is_input_fp8 = is_input_fp8
ctx.is_output_fp8 = is_output_fp8
@@ -1117,15 +1117,15 @@ def forward(
cp_stream=self.cp_stream,
cp_comm_type=self.cp_comm_type,
fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
fp8_meta=self.fp8_meta,
fp8_meta=self.fp8_meta,
quantizers=self.quantizers,
pad_between_seqs=pad_between_seqs,
inference_params=inference_params,
)

from transformer_engine.pytorch.cpu_offload import CPUOffloadEnabled
from transformer_engine.pytorch.cpu_offload import is_cpu_offload_enabled

if CPUOffloadEnabled:
if is_cpu_offload_enabled():
warnings.warn(
"Attention activation Offloading is only implemented"
"with Flash Attention and Fused Attention!"