diff --git a/float8_experimental/float8_dynamic_utils.py b/float8_experimental/float8_dynamic_utils.py
index ecd64fd..3cdbaf9 100644
--- a/float8_experimental/float8_dynamic_utils.py
+++ b/float8_experimental/float8_dynamic_utils.py
@@ -19,12 +19,7 @@
     tensor_already_casted_to_fp8,
     to_fp8_no_autograd,
 )
-from float8_experimental.float8_utils import (
-    amax_to_scale,
-    e4m3_dtype,
-    e5m2_dtype,
-    tensor_to_scale,
-)
+from float8_experimental.float8_utils import e4m3_dtype, e5m2_dtype, tensor_to_scale
 from torch._prims_common import suggest_memory_format
 
 
@@ -91,7 +86,7 @@ def __new__(
         cls,
         tensor: torch.Tensor,
         mm_config: ScaledMMConfig,
-        amax: Optional[torch.Tensor] = None,
+        precomputed_scale: Optional[torch.Tensor] = None,
     ):
         return torch.Tensor._make_wrapper_subclass(
             cls,
@@ -110,14 +105,14 @@ def __init__(
         self,
         tensor: torch.Tensor,
         mm_config: ScaledMMConfig,
-        amax: Optional[torch.Tensor] = None,
+        precomputed_scale: Optional[torch.Tensor] = None,
     ):
         self._tensor = tensor
         self._mm_config = mm_config
         # for dynamic scaling
-        # `precompute_float8_amax_for_fsdp` calculates amax
+        # `precompute_float8_scale_for_fsdp` calculates scales
         # for all float8 parameters after optimizer step
-        self._precomputed_amax = amax
+        self._precomputed_scale = precomputed_scale
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
@@ -146,8 +141,8 @@ def unwrap(t):
         )
 
     def __tensor_flatten__(self):
-        if self._precomputed_amax:
-            return ["_tensor", "_precomputed_amax"], self._mm_config
+        if self._precomputed_scale:
+            return ["_tensor", "_precomputed_scale"], self._mm_config
         else:
             return ["_tensor"], self._mm_config
 
@@ -157,21 +152,19 @@ def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
         return WeightWithDynamicFloat8CastTensor(
             inner_tensors["_tensor"],
             mm_config,
-            getattr(inner_tensors, "_precomputed_amax", None),
+            getattr(inner_tensors, "_precomputed_scale", None),
         )
 
     def __repr__(self):
         return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"
 
     def fsdp_pre_all_gather(self, mesh):
-        if self._precomputed_amax is not None:
-            scale = amax_to_scale(
-                self._precomputed_amax,
-                torch.float8_e4m3fn,
-                self._precomputed_amax.dtype,
-            )
+        if self._precomputed_scale is not None:
             float8_tensor = Float8Tensor.to_float8(
-                self._tensor, scale, torch.float8_e4m3fn, mm_config=self._mm_config
+                self._tensor,
+                self._precomputed_scale,
+                torch.float8_e4m3fn,
+                mm_config=self._mm_config,
             )
         else:
             float8_tensor = cast_to_float8_e4m3_dynamic(
diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py
index ad5ffe1..2be568e 100644
--- a/float8_experimental/float8_utils.py
+++ b/float8_experimental/float8_utils.py
@@ -34,9 +34,7 @@
 
 @torch.no_grad()
 def amax_to_scale(
-    amax: torch.Tensor,
-    float8_dtype: torch.dtype,
-    orig_dtype: torch.dtype,
+    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
 ):
     """Converts the amax value of a tensor to the fp8 scale.
     Args:
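Note: `_precomputed_scale` now stores the final scale rather than the raw amax, so `fsdp_pre_all_gather` can pass it straight to `Float8Tensor.to_float8` and no longer needs `amax_to_scale`. Below is a minimal sketch of the per-weight scale math that `compute_scales` (in the fsdp_utils.py diff that follows) performs; the helper name and the `eps` default are illustrative and not part of this patch.

import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max


def scale_from_amax(amax: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    # clamp amax away from zero; in the real code this clamp is dispatched
    # through DTensor and is what issues the single all-reduce
    amax = torch.clamp(amax, min=eps)
    scale = E4M3_MAX / amax
    # fp16 cannot represent the largest fp32 scales, so cap them
    if amax.dtype is torch.float16:
        scale = torch.clamp(scale, max=torch.finfo(torch.float16).max)
    return scale


# amax is still max(abs(w)), computed via the inf-norm of the weight
w = torch.randn(4096, 4096)
print(scale_from_amax(torch.linalg.vector_norm(w, ord=float("inf"))))
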
diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py
index bb21e88..e06ec66 100644
--- a/float8_experimental/fsdp_utils.py
+++ b/float8_experimental/fsdp_utils.py
@@ -10,14 +10,14 @@
 from float8_experimental.float8_utils import EPS
 
 
-def precompute_float8_amax_for_fsdp(module: nn.Module) -> None:
+def precompute_float8_scale_for_fsdp(module: nn.Module) -> None:
     """
-    Calculate amax for all float8 parameters after optimizer step
+    Calculate scale for all float8 parameters after optimizer step
     It performs a single all-reduce instead of many all-reduces
     for each parameter
     Example usage:
         model(input).sum().backward()
         optim.step()
-        precompute_float8_amax_for_fsdp(model)
+        precompute_float8_scale_for_fsdp(model)
     """
     from torch.distributed._tensor import DTensor
@@ -38,20 +38,23 @@ def precompute_float8_amax_for_fsdp(module: nn.Module) -> None:
     ]
     weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears]
 
-    def compute_amaxes(weights: List[DTensor]):
+    def compute_scales(weights: List[DTensor]):
         # inf-norm is equivalent to max(abs(w))
         max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
         amax_tensor = torch.vstack(max_weights)  # Partial
         # clamp is dispatched through DTensor
         # it will issue a single all-reduce
         amax_tensor = torch.clamp(amax_tensor, EPS)  # Replicate
-        amaxes = torch.split(amax_tensor, 1)  # Replicate
-        return amaxes
+        scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor  # Replicate
+        if amax_tensor.dtype is torch.float16:
+            scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
+        scales = torch.split(scale_tensor, 1)  # Replicate
+        return scales
 
     if weights:
-        amaxes = compute_amaxes(weights)
-        for amax, float8_linear in zip(amaxes, float8_linears):
-            float8_linear.weight._local_tensor._precomputed_amax = amax._local_tensor
+        scales = compute_scales(weights)
+        for scale, float8_linear in zip(scales, float8_linears):
+            float8_linear.weight._local_tensor._precomputed_scale = scale._local_tensor
     else:
         warnings.warn(
             "Calling precompute_float8_weights without any weights using FSDP fp8 all-gather!"
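
For reference, a minimal sketch of where the renamed helper sits in a training step, following the ordering shown in the docstring above. Model construction, float8 conversion, and FSDP wrapping are assumed to have happened already and are not part of this patch.

import torch
import torch.nn as nn
from float8_experimental.fsdp_utils import precompute_float8_scale_for_fsdp


def train_step(model: nn.Module, optim: torch.optim.Optimizer, batch: torch.Tensor) -> None:
    # forward/backward and optimizer update as usual
    model(batch).sum().backward()
    optim.step()
    optim.zero_grad()
    # after the weights change, recompute the float8 scales once for the
    # whole module: a single all-reduce instead of one per parameter; each
    # weight's _precomputed_scale is then used by the next fsdp_pre_all_gather
    precompute_float8_scale_for_fsdp(model)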