@@ -1,31 +1,28 @@
 import math
-import warnings
 from typing import List

 import torch
 import torch.nn as nn
 from float8_experimental.float8_dynamic_utils import WeightWithDynamicFloat8CastTensor
-from float8_experimental.float8_linear import Float8Linear
-from float8_experimental.float8_linear_utils import linear_requires_sync
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_utils import EPS


-def precompute_float8_scale_for_fsdp(module: nn.Module) -> None:
+@torch.no_grad()
+def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
     """
-    Calculate scale for all float8 parameters after optimizer step
-    It performs a single all-reduce instead of many all-reduces for each parameter
-    Exmaple usage:
+    Calculate scale dynamically for all float8 parameters.
+    This should be run after the optimizer step. It performs a single all-reduce to compute the
+    scales for all float8 weights.
+    Example usage:
         model(input).sum().backward()
         optim.step()
-        precompute_float8_scale_for_fsdp(model)
+        precompute_float8_dynamic_scale_for_fsdp(model)
     """
     from torch.distributed._tensor import DTensor

     if any(
-        isinstance(m, Float8Linear)
-        and linear_requires_sync(
-            m.scaling_type_x, m.scaling_type_w, m.scaling_type_dL_dY
-        )
+        isinstance(m, Float8Linear) and m.scaling_type_w is TensorScalingType.DELAYED
         for m in module.modules()
     ):
         raise NotImplementedError("Only supports delayed scaling")
@@ -38,24 +35,18 @@ def precompute_float8_scale_for_fsdp(module: nn.Module) -> None:
     ]
     weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears]

-    def compute_scales(weights: List[DTensor]):
-        # inf-norm is equivalent to max(abs(w))
-        max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
-        amax_tensor = torch.vstack(max_weights)  # Partial
-        # clamp is dispatched through DTensor
-        # it will issue a single all-reduce
-        amax_tensor = torch.clamp(amax_tensor, EPS)  # Replicate
-        scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor  # Replicate
-        if amax_tensor.dtype is torch.float16:
-            scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
-        scales = torch.split(scale_tensor, 1)  # Replicate
-        return scales
+    if not weights:
+        return

-    if weights:
-        scales = compute_scales(weights)
-        for scale, float8_linear in zip(scales, float8_linears):
-            float8_linear.weight._local_tensor._precomputed_scale = scale._local_tensor
-    else:
-        warnings.warn(
-            "Calling precompute_float8_weights without any weights using FSDP fp8 all-gather!"
-        )
+    # inf-norm is equivalent to max(abs(w))
+    max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
+    amax_tensor = torch.vstack(max_weights)  # Partial
+    # clamp is dispatched through DTensor
+    # it will issue a single all-reduce
+    amax_tensor = torch.clamp(amax_tensor, EPS)  # Replicate
+    scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor  # Replicate
+    if amax_tensor.dtype is torch.float16:
+        scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
+    scales = torch.split(scale_tensor, 1)  # Replicate
+    for scale, float8_linear in zip(scales, float8_linears):
+        float8_linear.weight._local_tensor._precomputed_scale = scale._local_tensor
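
For context, a minimal sketch of where this helper fits in a training loop, assuming an FSDP-wrapped model whose Float8Linear weights use dynamic scaling with fp8 all-gather. Only precompute_float8_dynamic_scale_for_fsdp and its placement after optim.step() come from the docstring above; model, optim, and dataloader are hypothetical placeholders.

# Hypothetical usage sketch: call the precompute once per optimizer step so the
# weight scales are ready for the next fp8 all-gather via a single all-reduce,
# instead of one all-reduce per parameter.
for batch in dataloader:
    model(batch).sum().backward()
    optim.step()
    optim.zero_grad()
    precompute_float8_dynamic_scale_for_fsdp(model)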