This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit

precompute scale
weifengpy committed Jul 10, 2024
1 parent e12c973 commit 9ef67fb
Showing 3 changed files with 26 additions and 32 deletions.
33 changes: 13 additions & 20 deletions float8_experimental/float8_dynamic_utils.py
@@ -19,12 +19,7 @@
     tensor_already_casted_to_fp8,
     to_fp8_no_autograd,
 )
-from float8_experimental.float8_utils import (
-    amax_to_scale,
-    e4m3_dtype,
-    e5m2_dtype,
-    tensor_to_scale,
-)
+from float8_experimental.float8_utils import e4m3_dtype, e5m2_dtype, tensor_to_scale
 from torch._prims_common import suggest_memory_format


@@ -91,7 +86,7 @@ def __new__(
         cls,
         tensor: torch.Tensor,
         mm_config: ScaledMMConfig,
-        amax: Optional[torch.Tensor] = None,
+        precomputed_scale: Optional[torch.Tensor] = None,
     ):
         return torch.Tensor._make_wrapper_subclass(
             cls,
@@ -110,14 +105,14 @@ def __init__(
         self,
         tensor: torch.Tensor,
         mm_config: ScaledMMConfig,
-        amax: Optional[torch.Tensor] = None,
+        precomputed_scale: Optional[torch.Tensor] = None,
     ):
         self._tensor = tensor
         self._mm_config = mm_config
         # for dynamic scaling
-        # `precompute_float8_amax_for_fsdp` calculates amax
+        # `precompute_float8_scale_for_fsdp` calculates scales
         # for all float8 parameters after optimizer step
-        self._precomputed_amax = amax
+        self._precomputed_scale = precomputed_scale

     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
@@ -146,8 +141,8 @@ def unwrap(t):
         )

     def __tensor_flatten__(self):
-        if self._precomputed_amax:
-            return ["_tensor", "_precomputed_amax"], self._mm_config
+        if self._precomputed_scale:
+            return ["_tensor", "_precomputed_scale"], self._mm_config
         else:
             return ["_tensor"], self._mm_config

@@ -157,21 +152,19 @@ def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
         return WeightWithDynamicFloat8CastTensor(
             inner_tensors["_tensor"],
             mm_config,
-            getattr(inner_tensors, "_precomputed_amax", None),
+            getattr(inner_tensors, "_precomputed_scale", None),
         )

     def __repr__(self):
         return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"

     def fsdp_pre_all_gather(self, mesh):
-        if self._precomputed_amax is not None:
-            scale = amax_to_scale(
-                self._precomputed_amax,
-                torch.float8_e4m3fn,
-                self._precomputed_amax.dtype,
-            )
+        if self._precomputed_scale is not None:
             float8_tensor = Float8Tensor.to_float8(
-                self._tensor, scale, torch.float8_e4m3fn, mm_config=self._mm_config
+                self._tensor,
+                self._precomputed_scale,
+                torch.float8_e4m3fn,
+                mm_config=self._mm_config,
             )
         else:
             float8_tensor = cast_to_float8_e4m3_dynamic(
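Net effect of this file's changes: fsdp_pre_all_gather no longer converts a stored amax with amax_to_scale at all-gather time; it consumes the stored scale directly. A condensed sketch of the branch it now takes (not repo code: the helper name and import paths are mine, and the cast_to_float8_e4m3_dynamic argument list is truncated in the diff above and assumed here):

import torch
# Module paths below are assumed, not taken from the diff.
from float8_experimental.float8_dynamic_utils import cast_to_float8_e4m3_dynamic
from float8_experimental.float8_tensor import Float8Tensor


def cast_weight_for_all_gather(w):
    # `w` stands for the WeightWithDynamicFloat8CastTensor wrapper in the diff above.
    if w._precomputed_scale is not None:
        # Reuse the scale computed once per optimizer step by
        # precompute_float8_scale_for_fsdp; no amax_to_scale call is needed here.
        float8_tensor = Float8Tensor.to_float8(
            w._tensor,
            w._precomputed_scale,
            torch.float8_e4m3fn,
            mm_config=w._mm_config,
        )
    else:
        # Fall back to dynamic casting, which computes amax and scale on the spot
        # (argument list assumed; it is truncated in the diff above).
        float8_tensor = cast_to_float8_e4m3_dynamic(w._tensor, w._mm_config)
    return float8_tensor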
4 changes: 1 addition & 3 deletions float8_experimental/float8_utils.py
@@ -34,9 +34,7 @@

 @torch.no_grad()
 def amax_to_scale(
-    amax: torch.Tensor,
-    float8_dtype: torch.dtype,
-    orig_dtype: torch.dtype,
+    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
 ):
     """Converts the amax value of a tensor to the fp8 scale.
     Args:
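The change above only reflows the signature of amax_to_scale; the function itself is unchanged, and float8_dynamic_utils.py simply stops importing it on the FSDP pre-all-gather path. For orientation, the scale it produces corresponds to the math that compute_scales in fsdp_utils.py (next file) now performs inline: the fp8 dtype's maximum representable value divided by the EPS-clamped amax, clamped again when working in float16. A minimal sketch with an illustrative amax value (not from the repo):

import torch

# Illustrative scale math, mirroring the inline computation added to
# compute_scales in the next file; the amax value here is made up.
amax = torch.tensor(0.75, dtype=torch.float32)        # max(abs(weight)), pre-clamped to EPS
scale = torch.finfo(torch.float8_e4m3fn).max / amax   # e4m3fn max (448.0) / amax
if amax.dtype is torch.float16:
    # keep the scale representable when working in float16 with a tiny amax
    scale = torch.clamp(scale, max=torch.finfo(torch.float16).max)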
21 changes: 12 additions & 9 deletions float8_experimental/fsdp_utils.py
Expand Up @@ -10,14 +10,14 @@
from float8_experimental.float8_utils import EPS


def precompute_float8_amax_for_fsdp(module: nn.Module) -> None:
def precompute_float8_scale_for_fsdp(module: nn.Module) -> None:
"""
Calculate amax for all float8 parameters after optimizer step
Calculate scale for all float8 parameters after optimizer step
It performs a single all-reduce instead of many all-reduces for each parameter
Exmaple usage:
model(input).sum().backward()
optim.step()
precompute_float8_amax_for_fsdp(model)
precompute_float8_scale_for_fsdp(model)
"""
from torch.distributed._tensor import DTensor

Expand All @@ -38,20 +38,23 @@ def precompute_float8_amax_for_fsdp(module: nn.Module) -> None:
]
weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears]

def compute_amaxes(weights: List[DTensor]):
def compute_scales(weights: List[DTensor]):
# inf-norm is equivalent to max(abs(w))
max_weights = torch._foreach_norm(weights, ord=math.inf) # Partial
amax_tensor = torch.vstack(max_weights) # Partial
# clamp is dispatched through DTensor
# it will issue a single all-reduce
amax_tensor = torch.clamp(amax_tensor, EPS) # Replicate
amaxes = torch.split(amax_tensor, 1) # Replicate
return amaxes
scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor # Replicate
if amax_tensor.dtype is torch.float16:
scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
scales = torch.split(scale_tensor, 1) # Replicate
return scales

if weights:
amaxes = compute_amaxes(weights)
for amax, float8_linear in zip(amaxes, float8_linears):
float8_linear.weight._local_tensor._precomputed_amax = amax._local_tensor
scales = compute_scales(weights)
for scale, float8_linear in zip(scales, float8_linears):
float8_linear.weight._local_tensor._precomputed_scale = scale._local_tensor
else:
warnings.warn(
"Calling precompute_float8_weights without any weights using FSDP fp8 all-gather!"
Expand Down
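The docstring above already gives the intended call pattern; expanded into a training step it looks roughly like the following (a hedged sketch: the train_step wrapper and its arguments are illustrative and not part of this commit; only precompute_float8_scale_for_fsdp and its module path come from the diff).

from float8_experimental.fsdp_utils import precompute_float8_scale_for_fsdp


def train_step(model, optim, batch):
    # `model` is assumed to be FSDP-wrapped with float8 dynamic-scaling linears.
    loss = model(batch).sum()
    loss.backward()
    optim.step()
    optim.zero_grad()
    # One batched all-reduce computes scales for every float8 weight and stores
    # them on each weight as _precomputed_scale, so the next all-gather can cast
    # to float8 without recomputing amax per parameter.
    precompute_float8_scale_for_fsdp(model)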
