This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

rename to precompute_float8_dynamic_scale_for_fsdp
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
weifengpy committed Jul 11, 2024
1 parent ba085e5 commit ac0afb0
Showing 3 changed files with 28 additions and 36 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -37,6 +37,7 @@ This is the most accurate recipe as every tensor is scaled dynamically.
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
 )
+from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
 from float8_experimental.float8_linear import Float8Linear
 
 # create model
@@ -58,10 +59,10 @@ for _ in range(N_ITER):
     y.sum().backward()
     optimizer.step()
 
-# specific to fsdp2 + float8 with dynamic scaling
+# specific to fsdp2 + dynamic scaling, when fp8 all-gather is turned on
 # this method is optional but is highly recommended for performance
 # it calculates scales for all parameters in a single all-reduce
-precompute_float8_scale_for_fsdp(model)
+precompute_float8_dynamic_scale_for_fsdp(model)
 
 ```
 
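For orientation, here is a minimal sketch of how the renamed helper fits into an FSDP2 training loop. It is not part of this diff: the toy model, optimizer, batch shapes, process-group setup, and `fully_shard` wiring are illustrative assumptions, and the `swap_linear_with_float8_linear(model, Float8Linear)` call simply follows the imports shown in the README hunk above (the exact signature may differ across versions).

```python
# Assumes a torchrun launch with CUDA devices; fully_shard and the float8 kernels
# are distributed/GPU features.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard

from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# toy model; any stack of nn.Linear layers works
model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)).cuda()

# convert nn.Linear modules to Float8Linear with dynamic scaling
# (call mirrors the README example; treat the arguments as illustrative)
swap_linear_with_float8_linear(model, Float8Linear)

# shard parameters with FSDP2; enabling fp8 all-gather is a separate config choice
fully_shard(model)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

for _ in range(10):  # N_ITER
    optimizer.zero_grad()
    y = model(torch.randn(16, 1024, device="cuda"))
    y.sum().backward()
    optimizer.step()
    # optional but recommended: one batched all-reduce that precomputes the next
    # iteration's weight scales instead of issuing one all-reduce per parameter
    precompute_float8_dynamic_scale_for_fsdp(model)
```
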
55 changes: 23 additions & 32 deletions float8_experimental/fsdp_utils.py
@@ -1,31 +1,28 @@
 import math
-import warnings
 from typing import List
 
 import torch
 import torch.nn as nn
 from float8_experimental.float8_dynamic_utils import WeightWithDynamicFloat8CastTensor
-from float8_experimental.float8_linear import Float8Linear
-from float8_experimental.float8_linear_utils import linear_requires_sync
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_utils import EPS
 
 
-def precompute_float8_scale_for_fsdp(module: nn.Module) -> None:
+@torch.no_grad()
+def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
     """
-    Calculate scale for all float8 parameters after optimizer step
-    It performs a single all-reduce instead of many all-reduces for each parameter
-    Exmaple usage:
+    Calculate scale dynamically for all float8 parameters.
+    This should be run after the optimizer step. It performs a single all-reduce to compute the
+    scales for all float8 weights.
+    Example usage:
         model(input).sum().backward()
         optim.step()
-        precompute_float8_scale_for_fsdp(model)
+        precompute_float8_dynamic_scale_for_fsdp(model)
     """
     from torch.distributed._tensor import DTensor
 
     if any(
-        isinstance(m, Float8Linear)
-        and linear_requires_sync(
-            m.scaling_type_x, m.scaling_type_w, m.scaling_type_dL_dY
-        )
+        isinstance(m, Float8Linear) and m.scaling_type_w is TensorScalingType.DELAYED
         for m in module.modules()
     ):
         raise NotImplementedError("Only supports delayed scaling")
@@ -38,24 +35,18 @@ def precompute_float8_scale_for_fsdp(module: nn.Module) -> None:
     ]
     weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears]
 
-    def compute_scales(weights: List[DTensor]):
-        # inf-norm is equivalent to max(abs(w))
-        max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
-        amax_tensor = torch.vstack(max_weights)  # Partial
-        # clamp is dispatched through DTensor
-        # it will issue a single all-reduce
-        amax_tensor = torch.clamp(amax_tensor, EPS)  # Replicate
-        scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor  # Replicate
-        if amax_tensor.dtype is torch.float16:
-            scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
-        scales = torch.split(scale_tensor, 1)  # Replicate
-        return scales
+    if not weights:
+        return
 
-    if weights:
-        scales = compute_scales(weights)
-        for scale, float8_linear in zip(scales, float8_linears):
-            float8_linear.weight._local_tensor._precomputed_scale = scale._local_tensor
-    else:
-        warnings.warn(
-            "Calling precompute_float8_weights without any weights using FSDP fp8 all-gather!"
-        )
+    # inf-norm is equivalent to max(abs(w))
+    max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
+    amax_tensor = torch.vstack(max_weights)  # Partial
+    # clamp is dispatched through DTensor
+    # it will issue a single all-reduce
+    amax_tensor = torch.clamp(amax_tensor, EPS)  # Replicate
+    scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor  # Replicate
+    if amax_tensor.dtype is torch.float16:
+        scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
+    scales = torch.split(scale_tensor, 1)  # Replicate
+    for scale, float8_linear in zip(scales, float8_linears):
+        float8_linear.weight._local_tensor._precomputed_scale = scale._local_tensor
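
To make the batched scale computation concrete, below is a single-process sketch of the same arithmetic on plain tensors. It is an assumption-laden illustration, not library code: `EPS` is a stand-in constant and the weights are random placeholders; in the real function the weights are sharded DTensors, and the `clamp` is where the single all-reduce turns per-shard partial amaxes into replicated global amaxes.

```python
import math

import torch

EPS = 1e-12  # stand-in for float8_experimental.float8_utils.EPS

# placeholder weights; in the real function these are sharded DTensor weights
weights = [torch.randn(256, 512), torch.randn(1024, 256)]

# inf-norm of each weight == max(abs(w)), computed for all weights in one batched op
max_weights = torch._foreach_norm(weights, ord=math.inf)
amax_tensor = torch.vstack(max_weights)  # shape (num_weights, 1)

# avoid division by zero; with DTensor this clamp is the single all-reduce point
amax_tensor = torch.clamp(amax_tensor, EPS)

# float8 e4m3 max is 448.0, so the scale maps each weight's amax onto the fp8 range
scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor

# one (1, 1) scale per weight, ready to stash as a precomputed scale
scales = torch.split(scale_tensor, 1)
for w, s in zip(weights, scales):
    expected = torch.finfo(torch.float8_e4m3fn).max / w.abs().max().clamp(min=EPS)
    assert torch.allclose(s.squeeze(), expected)
```
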
4 changes: 2 additions & 2 deletions test/test_fsdp2/test_fsdp2_common.py
@@ -6,7 +6,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from float8_experimental.fsdp_utils import precompute_float8_scale_for_fsdp
+from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
 
 
 def check_parity_no_mp(
@@ -31,7 +31,7 @@ def check_parity_no_mp(
             # TODO(future): add amax syncing once delayed scaling is supported
             optim.step()
             if model is fsdp_model and precompute:
-                precompute_float8_scale_for_fsdp(model)
+                precompute_float8_dynamic_scale_for_fsdp(model)
         test_cls.assertEqual(losses[0], losses[1])


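The test change above is part of a parity check: precomputing the scales is a performance optimization only, so the FSDP2 + float8 run with precompute enabled must produce the same losses as the reference run. A hedged sketch of that pattern follows; the harness, names, and gradient handling are illustrative, not the repo's actual test code.

```python
import torch

from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp


def check_step_parity(ref_model, ref_optim, fsdp_model, fsdp_optim, inp, precompute=True):
    """Run one training step on both models and require identical losses."""
    losses = []
    for model, optim in ((ref_model, ref_optim), (fsdp_model, fsdp_optim)):
        optim.zero_grad()
        loss = model(inp).sum()
        loss.backward()
        optim.step()
        if model is fsdp_model and precompute:
            # optimization only: must not change numerics
            precompute_float8_dynamic_scale_for_fsdp(model)
        losses.append(loss.detach())
    torch.testing.assert_close(losses[0], losses[1])
```
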
