This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit ac0afb0

rename to precompute_float8_dynamic_scale_for_fsdp
Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
1 parent ba085e5 commit ac0afb0
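
The change is purely a rename of the public helper. As a hedged illustration (not part of this diff), the sketch below shows how a call site migrates; `end_of_train_step` is a hypothetical wrapper name and `model` stands in for an FSDP2-wrapped float8 model, as in the README example further down.

```python
# Illustrative sketch of the rename's effect on a call site (not part of this commit).

# before this commit:
#   from float8_experimental.fsdp_utils import precompute_float8_scale_for_fsdp
#   precompute_float8_scale_for_fsdp(model)

# after this commit:
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp


def end_of_train_step(model):
    # call once per iteration, after optimizer.step(); computes the fp8 weight
    # scales for all float8 parameters with a single all-reduce
    precompute_float8_dynamic_scale_for_fsdp(model)
```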

File tree

3 files changed (+28, -36 lines)


README.md

Lines changed: 3 additions & 2 deletions
@@ -37,6 +37,7 @@ This is the most accurate recipe as every tensor is scaled dynamically.
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
 )
+from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
 from float8_experimental.float8_linear import Float8Linear

 # create model
@@ -58,10 +59,10 @@ for _ in range(N_ITER):
     y.sum().backward()
     optimizer.step()

-    # specific to fsdp2 + float8 with dynamic scaling
+    # specific to fsdp2 + dynamic scaling, when fp8 all-gather is turned on
     # this method is optional but is highly recommended for performance
     # it calcuclates scales for all parameters in a single all-reduce
-    precompute_float8_scale_for_fsdp(model)
+    precompute_float8_dynamic_scale_for_fsdp(model)

 ```

float8_experimental/fsdp_utils.py

Lines changed: 23 additions & 32 deletions
@@ -1,31 +1,28 @@
 import math
-import warnings
 from typing import List

 import torch
 import torch.nn as nn
 from float8_experimental.float8_dynamic_utils import WeightWithDynamicFloat8CastTensor
-from float8_experimental.float8_linear import Float8Linear
-from float8_experimental.float8_linear_utils import linear_requires_sync
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_utils import EPS


-def precompute_float8_scale_for_fsdp(module: nn.Module) -> None:
+@torch.no_grad()
+def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
     """
-    Calculate scale for all float8 parameters after optimizer step
-    It performs a single all-reduce instead of many all-reduces for each parameter
-    Exmaple usage:
+    Calculate scale dynamically for all float8 parameters.
+    This should be run after the optimizer step. It performs a single all-reduce to compute the
+    scales for all float8 weights.
+    Example usage:
         model(input).sum().backward()
         optim.step()
-        precompute_float8_scale_for_fsdp(model)
+        precompute_float8_dynamic_scale_for_fsdp(model)
     """
     from torch.distributed._tensor import DTensor

     if any(
-        isinstance(m, Float8Linear)
-        and linear_requires_sync(
-            m.scaling_type_x, m.scaling_type_w, m.scaling_type_dL_dY
-        )
+        isinstance(m, Float8Linear) and m.scaling_type_w is TensorScalingType.DELAYED
         for m in module.modules()
     ):
         raise NotImplementedError("Only supports delayed scaling")
@@ -38,24 +35,18 @@ def precompute_float8_scale_for_fsdp(module: nn.Module) -> None:
     ]
     weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears]

-    def compute_scales(weights: List[DTensor]):
-        # inf-norm is equivalent to max(abs(w))
-        max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
-        amax_tensor = torch.vstack(max_weights)  # Partial
-        # clamp is dispatched through DTensor
-        # it will issue a single all-reduce
-        amax_tensor = torch.clamp(amax_tensor, EPS)  # Replicate
-        scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor  # Replicate
-        if amax_tensor.dtype is torch.float16:
-            scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
-        scales = torch.split(scale_tensor, 1)  # Replicate
-        return scales
+    if not weights:
+        return

-    if weights:
-        scales = compute_scales(weights)
-        for scale, float8_linear in zip(scales, float8_linears):
-            float8_linear.weight._local_tensor._precomputed_scale = scale._local_tensor
-    else:
-        warnings.warn(
-            "Calling precompute_float8_weights without any weights using FSDP fp8 all-gather!"
-        )
+    # inf-norm is equivalent to max(abs(w))
+    max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
+    amax_tensor = torch.vstack(max_weights)  # Partial
+    # clamp is dispatched through DTensor
+    # it will issue a single all-reduce
+    amax_tensor = torch.clamp(amax_tensor, EPS)  # Replicate
+    scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor  # Replicate
+    if amax_tensor.dtype is torch.float16:
+        scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
+    scales = torch.split(scale_tensor, 1)  # Replicate
+    for scale, float8_linear in zip(scales, float8_linears):
+        float8_linear.weight._local_tensor._precomputed_scale = scale._local_tensor
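
For intuition, the following is a minimal single-process sketch of the scale math in the function above; it is an illustration, not part of the commit. The random weights and the `1e-12` stand-in for the library's `EPS` constant are assumptions, and without DTensor there is no all-reduce: in the real code the per-shard max is Partial and the clamp dispatch is what issues the single all-reduce.

```python
import math

import torch

# Local illustration of the per-weight dynamic scale computed above.
weights = [torch.randn(16, 16), torch.randn(32, 16)]  # stand-ins for fp8-cast weights

# inf-norm per tensor is equivalent to max(abs(w))
max_weights = torch._foreach_norm(weights, ord=math.inf)
amax_tensor = torch.vstack(max_weights)
# avoid division by zero; in the DTensor version this clamp is the op that
# triggers the single all-reduce across ranks
amax_tensor = torch.clamp(amax_tensor, 1e-12)  # 1e-12 stands in for EPS
scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor
scales = torch.split(scale_tensor, 1)  # one scale per weight
print([s.item() for s in scales])
```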

test/test_fsdp2/test_fsdp2_common.py

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from float8_experimental.fsdp_utils import precompute_float8_scale_for_fsdp
+from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp


 def check_parity_no_mp(
@@ -31,7 +31,7 @@ def check_parity_no_mp(
         # TODO(future): add amax syncing once delayed scaling is supported
         optim.step()
         if model is fsdp_model and precompute:
-            precompute_float8_scale_for_fsdp(model)
+            precompute_float8_dynamic_scale_for_fsdp(model)
         test_cls.assertEqual(losses[0], losses[1])
