precompute scale after optimizer.step for dynamic scaling (#266)
Summary:
Goal: improve float8 all-gather perf in FSDP2 by precomputing scales for all float8 params with a single all-reduce

updated the README with API usage: call `precompute_float8_dynamic_scale_for_fsdp` inside the training loop after the optimizer step

```
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
# inside the training loop
model(input).sum().backward()
optim.step()
precompute_float8_dynamic_scale_for_fsdp(model)
```
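
For reference, a rough single-process sketch of the math this precompute performs (assumptions: plain tensors stand in for the DTensor/FSDP2-wrapped float8 weights, and `EPS = 1e-12` stands in for `float8_experimental.float8_utils.EPS`):

```
import math
import torch

EPS = 1e-12  # stand-in for float8_experimental.float8_utils.EPS

# toy stand-ins for the float8-managed weights
weights = [torch.randn(256, 256) for _ in range(4)]

# one fused kernel for all weights; the inf-norm of w equals max(abs(w))
amax = torch.vstack(torch._foreach_norm(weights, ord=math.inf))  # shape (N, 1)
amax = torch.clamp(amax, EPS)                                     # avoid div-by-zero
scales = torch.finfo(torch.float8_e4m3fn).max / amax              # one scale per weight
```

In the real implementation the weights are DTensors, so the clamp on the stacked amax is where the single all-reduce mentioned above is issued.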

unit test: `pytest -s test/test_fsdp2/test_fsdp2_eager.py -k test_transformer_parity_dynamic`

**FSDP pre-forward**: shortened from 3ms to 1.8ms by issuing 1 all-reduce instead of N small all-reduces
<img width="703" alt="Screenshot 2024-05-30 at 12 38 24 AM" src="https://github.com/pytorch-labs/float8_experimental/assets/134637289/81361471-fde4-43e4-ad83-a8c5b39f0cf1">

<img width="720" alt="Screenshot 2024-05-30 at 12 48 14 AM" src="https://github.com/pytorch-labs/float8_experimental/assets/134637289/26202869-cf7d-4427-b87f-570e5dc39324">

**Pre-computing amax**: shortened from 5ms to 1.7ms by switching from `torch._foreach_abs` + `torch.max(a)` to `torch._foreach_norm(weights, ord=math.inf)`

<img width="1075" alt="Screenshot 2024-05-30 at 12 50 17 AM" src="https://github.com/pytorch-labs/float8_experimental/assets/134637289/823fb717-8f5b-42e9-afc8-6f6c34ab45b2">

<img width="1050" alt="Screenshot 2024-05-30 at 12 49 54 AM" src="https://github.com/pytorch-labs/float8_experimental/assets/134637289/5ea15f59-ec85-456b-a28c-3e672d2cdaae">

Pull Request resolved: #266

Reviewed By: vkuzo

Differential Revision: D59562409

Pulled By: weifengpy

fbshipit-source-id: 683c4719e20f6b30f39ca9109ee29e53981a2aec
weifengpy authored and facebook-github-bot committed Jul 12, 2024
1 parent 73fd168 commit 6cba2ae
Showing 5 changed files with 122 additions and 12 deletions.
14 changes: 13 additions & 1 deletion README.md
@@ -37,6 +37,7 @@ This is the most accurate recipe as every tensor is scaled dynamically.
from float8_experimental.float8_linear_utils import (
swap_linear_with_float8_linear,
)
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
from float8_experimental.float8_linear import Float8Linear

# create model
@@ -51,7 +52,18 @@ model = FSDP(model, use_orig_params=True)
# optional: enable torch.compile for improved performance
m = torch.compile(m)

# train/finetune (not shown)
# toy training loop
for _ in range(N_ITER):
optimizer.zero_grad()
y = m(x)
y.sum().backward()
optimizer.step()

# specific to fsdp2 + dynamic scaling, when fp8 all-gather is turned on
# this method is optional but is highly recommended for performance
# it calculates scales for all parameters in a single all-reduce
precompute_float8_dynamic_scale_for_fsdp(model)

```

## float8 linear with delayed scaling
43 changes: 36 additions & 7 deletions float8_experimental/float8_dynamic_utils.py
@@ -82,7 +82,12 @@ def cast_to_float8_e5m2_dynamic_bw(

class WeightWithDynamicFloat8CastTensor(torch.Tensor):
@staticmethod
def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
def __new__(
cls,
tensor: torch.Tensor,
mm_config: ScaledMMConfig,
precomputed_scale: Optional[torch.Tensor] = None,
):
return torch.Tensor._make_wrapper_subclass(
cls,
tensor.size(),
@@ -96,9 +101,18 @@ def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
requires_grad=tensor.requires_grad,
)

def __init__(self, tensor: torch.Tensor, mm_config: ScaledMMConfig):
def __init__(
self,
tensor: torch.Tensor,
mm_config: ScaledMMConfig,
precomputed_scale: Optional[torch.Tensor] = None,
):
self._tensor = tensor
self._mm_config = mm_config
# for dynamic scaling
# `precompute_float8_dynamic_scale_for_fsdp` calculates scales
# for all float8 parameters after optimizer step
self._precomputed_scale = precomputed_scale

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
@@ -127,20 +141,35 @@ def unwrap(t):
)

def __tensor_flatten__(self):
return ["_tensor"], self._mm_config
if self._precomputed_scale:
return ["_tensor", "_precomputed_scale"], self._mm_config
else:
return ["_tensor"], self._mm_config

@staticmethod
def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
mm_config = flatten_spec
return WeightWithDynamicFloat8CastTensor(inner_tensors["_tensor"], mm_config)
return WeightWithDynamicFloat8CastTensor(
inner_tensors["_tensor"],
mm_config,
inner_tensors.get("_precomputed_scale", None),
)

def __repr__(self):
return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"

def fsdp_pre_all_gather(self, mesh):
float8_tensor = cast_to_float8_e4m3_dynamic(
self._tensor, self._mm_config, reduce_amax=True
)
if self._precomputed_scale is not None:
float8_tensor = Float8Tensor.to_float8(
self._tensor,
self._precomputed_scale,
torch.float8_e4m3fn,
mm_config=self._mm_config,
)
else:
float8_tensor = cast_to_float8_e4m3_dynamic(
self._tensor, self._mm_config, reduce_amax=True
)
return (float8_tensor._data,), (float8_tensor._scale,)

def fsdp_post_all_gather(
52 changes: 52 additions & 0 deletions float8_experimental/fsdp_utils.py
@@ -0,0 +1,52 @@
import math
from typing import List

import torch
import torch.nn as nn
from float8_experimental.float8_dynamic_utils import WeightWithDynamicFloat8CastTensor
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_utils import EPS


@torch.no_grad()
def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
"""
Calculate scale dynamically for all float8 parameters.
This should be run after the optimizer step. It performs a single all-reduce to compute the
scales for all float8 weights.
Example usage:
model(input).sum().backward()
optim.step()
precompute_float8_dynamic_scale_for_fsdp(model)
"""
from torch.distributed._tensor import DTensor

if any(
isinstance(m, Float8Linear) and m.scaling_type_w is TensorScalingType.DELAYED
for m in module.modules()
):
raise NotImplementedError("Only supports dynamic scaling")
float8_linears: List[Float8Linear] = [
m
for m in module.modules()
if isinstance(m, Float8Linear)
and isinstance(m.weight, DTensor)
and isinstance(m.weight._local_tensor, WeightWithDynamicFloat8CastTensor)
]
weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears]

if not weights:
return

# inf-norm is equivalent to max(abs(w))
max_weights = torch._foreach_norm(weights, ord=math.inf) # Partial
amax_tensor = torch.vstack(max_weights) # Partial
# clamp is dispatched through DTensor
# it will issue a single all-reduce
amax_tensor = torch.clamp(amax_tensor, EPS) # Replicate
scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor # Replicate
if amax_tensor.dtype is torch.float16:
scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
scales = torch.split(scale_tensor, 1) # Replicate
for scale, float8_linear in zip(scales, float8_linears):
float8_linear.weight._local_tensor._precomputed_scale = scale._local_tensor
4 changes: 4 additions & 0 deletions test/test_fsdp2/test_fsdp2_common.py
@@ -6,6 +6,7 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp


def check_parity_no_mp(
@@ -15,6 +16,7 @@ def check_parity_no_mp(
fsdp_model: nn.Module,
fsdp_optim: torch.optim.Optimizer,
local_inp: torch.Tensor,
precompute: bool = False,
):
for iter_idx in range(10):
losses: List[torch.Tensor] = []
@@ -28,6 +30,8 @@ param.grad.div_(dist.get_world_size())
param.grad.div_(dist.get_world_size())
# TODO(future): add amax syncing once delayed scaling is supported
optim.step()
if model is fsdp_model and precompute:
precompute_float8_dynamic_scale_for_fsdp(model)
test_cls.assertEqual(losses[0], losses[1])


21 changes: 17 additions & 4 deletions test/test_fsdp2/test_fsdp2_eager.py
@@ -86,10 +86,21 @@ def world_size(self) -> int:

@skip_if_lt_x_gpu(2)
def test_transformer_parity_dynamic(self):
for enable_fsdp_fp8_all_gather in [False, True]:
self._test_transformer_parity_dynamic(enable_fsdp_fp8_all_gather)
self.run_subtests(
{
"enable_fsdp_fp8_all_gather": [False, True],
"precompute": [False, True],
},
self._test_transformer_parity_dynamic,
)

def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool):
def _test_transformer_parity_dynamic(
self,
enable_fsdp_fp8_all_gather: bool,
precompute: bool,
):
if not enable_fsdp_fp8_all_gather and precompute:
return
# NOTE: Weight-tying does not compose with fp8 all-gather because the
# embedding weight and output linear weight are tied but only the
# latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to
@@ -109,7 +120,9 @@ def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool):
local_inp = torch.randint(
0, ref_module.tok_embeddings.weight.size(0), (16, 16), device="cuda"
)
check_parity_no_mp(self, ref_module, ref_optim, module, optim, local_inp)
check_parity_no_mp(
self, ref_module, ref_optim, module, optim, local_inp, precompute
)

@skip_if_lt_x_gpu(2)
def test_transformer_memory(self):
