support delayed scaling of weight in float8 all-gather (#312)
Summary:
Pull Request resolved: #312

Adds support for delayed scaling in FSDP2 float8 all-gather. In detail:
1. add `WeightWithDelayedFloat8CastTensor`; note that we don't reuse
   code with the dynamic version because I'd rather not deal with
   plumbing optional tensors through dynamo. We can try that in a
   separate PR later.
2. wire `Float8Linear` to use (1)
3. add weight amax syncing back, since we need it for float8 all-gather
4. add test coverage for eager mode numerics

Next up (in separate PRs): validating numerics on a real training run and
taking a look at performance.
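
For context, a minimal end-to-end sketch of how these pieces are meant to be
used together. This is an illustration, not code from this PR: it assumes the
FSDP2 `fully_shard` API, that `float8_experimental.config` is importable, and
that `Float8Linear.from_float` accepts a `scaling_type_w` keyword; names and
signatures may differ in your checkout, and process-group setup is omitted.

import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard  # FSDP2

from float8_experimental import config
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import (
    linear_requires_sync,
    sync_float8_amax_and_scale_history,
)

# opt into float8 all-gather *before* swapping, so from_float wraps the weights
config.enable_fsdp_fp8_all_gather = True

# (assumes torch.distributed is already initialized with a CUDA backend)
model = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(3)]).cuda()
for name, child in list(model.named_children()):
    # hypothetical swap step; from_float is assumed to accept scaling_type_w
    setattr(
        model,
        name,
        Float8Linear.from_float(child, scaling_type_w=TensorScalingType.DELAYED),
    )
fully_shard(model)  # weight shards are now all-gathered in float8

optim = torch.optim.SGD(model.parameters(), lr=1e-3)
for _ in range(3):
    loss = model(torch.randn(16, 1024, device="cuda")).sum()
    loss.backward()
    # delayed scaling requires syncing amax/scale history every iteration,
    # and after this PR that sync covers the weight amax as well (item 3)
    if linear_requires_sync(scaling_type_w=TensorScalingType.DELAYED):
        sync_float8_amax_and_scale_history(model)
    optim.step()
    optim.zero_grad()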

Reviewed By: awgu

Differential Revision: D59685258

fbshipit-source-id: 9ff18d7649cc6e0e3c9e2a64a30a5ff8bc4108be
vkuzo authored and facebook-github-bot committed Jul 12, 2024
1 parent 3fe7c4a commit de93990
Showing 5 changed files with 320 additions and 52 deletions.
85 changes: 54 additions & 31 deletions float8_experimental/float8_linear.py
@@ -34,7 +34,10 @@
tensor_to_amax,
)

from float8_experimental.fsdp_utils import WeightWithDynamicFloat8CastTensor
from float8_experimental.fsdp_utils import (
WeightWithDelayedFloat8CastTensor,
WeightWithDynamicFloat8CastTensor,
)


def _maybe_initialize_amaxes_scales_for_float8_cast(
@@ -316,28 +319,30 @@ def cast_w_to_float8(
self, w: torch.Tensor, is_amax_initialized: bool
) -> torch.Tensor:
if self.scaling_type_w is TensorScalingType.DELAYED:
scale_fn_name = self.recipe.scale_fn_name
_maybe_initialize_amaxes_scales_for_float8_cast(
w,
self.fp8_amax_w,
self.fp8_amax_history_w,
self.fp8_scale_w,
scale_fn_name,
e4m3_dtype,
is_amax_initialized,
reduce_amax=False,
)

w_fp8 = Float8Tensor.to_float8(
w,
self.fp8_scale_w,
e4m3_dtype,
self.fp8_amax_w,
self.forward_config,
)
if isinstance(self.weight, Float8Tensor): # cast by FSDP
w_fp8 = self.weight
else:
scale_fn_name = self.recipe.scale_fn_name
_maybe_initialize_amaxes_scales_for_float8_cast(
w,
self.fp8_amax_w,
self.fp8_amax_history_w,
self.fp8_scale_w,
scale_fn_name,
e4m3_dtype,
is_amax_initialized,
reduce_amax=False,
)

w_fp8 = Float8Tensor.to_float8(
w,
self.fp8_scale_w,
e4m3_dtype,
self.fp8_amax_w,
self.forward_config,
)
else:
assert self.scaling_type_w is TensorScalingType.DYNAMIC
# TODO(future): also support FSDP integration in delayed scaling path
if isinstance(self.weight, Float8Tensor): # cast by FSDP
w_fp8 = self.weight
else:
@@ -436,18 +441,36 @@ def from_float(
scaling_type_dL_dY=scaling_type_dL_dY,
emulate=emulate,
)
if (
scaling_type_w == TensorScalingType.DYNAMIC
and config.enable_fsdp_fp8_all_gather
):
new_mod.weight = torch.nn.Parameter(
WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config)
)
else:
assert not config.enable_fsdp_fp8_all_gather, "unsupported"
new_mod.weight = mod.weight
new_mod.weight = mod.weight
new_mod.bias = mod.bias
# need to create buffers again when moving from meta device to
# real device
new_mod.create_buffers()

# If FSDP float8 all-gather is on, wrap the weight in a float8-aware
# tensor subclass. This must happen last because:
# 1. weight needs to be on the correct device to create the buffers
# 2. buffers need to be already created for the delayed scaling version
# of the weight wrapper to be initialized
if config.enable_fsdp_fp8_all_gather:
if scaling_type_w is TensorScalingType.DYNAMIC:
new_mod.weight = torch.nn.Parameter(
WeightWithDynamicFloat8CastTensor(
new_mod.weight,
new_mod.forward_config,
)
)
else:
assert scaling_type_w is TensorScalingType.DELAYED
new_mod.weight = torch.nn.Parameter(
WeightWithDelayedFloat8CastTensor(
new_mod.weight,
new_mod.fp8_amax_w,
new_mod.fp8_amax_history_w,
new_mod.fp8_scale_w,
new_mod.forward_config,
new_mod.is_amax_initialized,
)
)

return new_mod
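
As a side note, a small sketch of what the new wrapping step amounts to in
the delayed-scaling case. This restates the code above rather than adding a
new API; the helper name is hypothetical. The wrapper captures references to
the module's existing amax/history/scale buffers, which is exactly why it can
only be applied after `create_buffers()` has run on the real device.

import torch
from float8_experimental.fsdp_utils import WeightWithDelayedFloat8CastTensor

def wrap_weight_for_fp8_all_gather(lin) -> None:
    # hypothetical helper mirroring the delayed branch of from_float above;
    # `lin` is a Float8Linear whose buffers already exist on the right device
    for buf in ("fp8_amax_w", "fp8_amax_history_w", "fp8_scale_w"):
        assert hasattr(lin, buf), f"{buf} must be created before wrapping"
    lin.weight = torch.nn.Parameter(
        WeightWithDelayedFloat8CastTensor(
            lin.weight,
            lin.fp8_amax_w,
            lin.fp8_amax_history_w,
            lin.fp8_scale_w,
            lin.forward_config,
            lin.is_amax_initialized,
        )
    )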
13 changes: 7 additions & 6 deletions float8_experimental/float8_linear_utils.py
@@ -289,11 +289,10 @@ def inner_func():
), "Mismatched lengths of amax tensors."

if dist.is_initialized():
# Combine all the amax tensors into one tensor and reduce it
# Note: do not reduce the weight values, because FSDP already ensures
# the weight values on all ranks are the same after all-gather.
all_amax_tensors = torch.cat(
fp8_amax_x_tensor_list + fp8_amax_dL_dY_tensor_list
fp8_amax_x_tensor_list
+ fp8_amax_w_tensor_list
+ fp8_amax_dL_dY_tensor_list
)
all_reduced_amax_tensor = all_reduce(
all_amax_tensors, "MAX", list(range(dist.get_world_size()))
@@ -302,12 +301,14 @@ def inner_func():
all_reduced_amax_tensor = all_reduced_amax_tensor.wait()

(
reduced_fp8_amax_tensor,
reduced_fp8_amax_x_tensor,
reduced_fp8_amax_w_tensor,
reduced_fp8_amax_dL_dY_tensor,
) = torch.split(all_reduced_amax_tensor, len(fp8_amax_x_tensor_list))

for idx, child in enumerate(fp8_layers):
child.fp8_amax_x.copy_(reduced_fp8_amax_tensor[idx])
child.fp8_amax_x.copy_(reduced_fp8_amax_x_tensor[idx])
child.fp8_amax_w.copy_(reduced_fp8_amax_w_tensor[idx])
child.fp8_amax_dL_dY.copy_(reduced_fp8_amax_dL_dY_tensor[idx])

# We create two stacked tensor groups, one for the amax history and one for the current scales
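
The change above folds the weight amax tensors back into the single
cross-rank reduction: with float8 all-gather, each rank casts only its local
weight shard before the collective, so the shard-local weight amax must be
max-reduced across ranks just like the activation and gradient amaxes. A
generic sketch of the concat / single-all-reduce / split pattern, assuming an
initialized process group (it uses the plain in-place collective for brevity
rather than the async functional variant used above):

from typing import List

import torch
import torch.distributed as dist

def sync_amaxes(
    amax_x: List[torch.Tensor],
    amax_w: List[torch.Tensor],
    amax_dL_dY: List[torch.Tensor],
) -> None:
    n = len(amax_x)
    # one flat tensor -> a single collective instead of 3 * n tiny all-reduces
    flat = torch.cat(amax_x + amax_w + amax_dL_dY)
    dist.all_reduce(flat, op=dist.ReduceOp.MAX)
    reduced_x, reduced_w, reduced_dL_dY = torch.split(flat, n)
    for i in range(n):
        amax_x[i].copy_(reduced_x[i])
        amax_w[i].copy_(reduced_w[i])
        amax_dL_dY[i].copy_(reduced_dL_dY[i])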
181 changes: 180 additions & 1 deletion float8_experimental/fsdp_utils.py
Expand Up @@ -18,7 +18,7 @@
ScaledMMConfig,
)

from float8_experimental.float8_utils import EPS
from float8_experimental.float8_utils import e4m3_dtype, EPS
from torch._prims_common import suggest_memory_format


@@ -189,3 +189,182 @@ def fsdp_post_all_gather(
out._scale = scale
return
return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,)


class WeightWithDelayedFloat8CastTensor(torch.Tensor):
@staticmethod
def __new__(
cls,
tensor: torch.Tensor,
amax_buffer: torch.Tensor,
amax_history_buffer: torch.Tensor,
scale_buffer: torch.Tensor,
mm_config: ScaledMMConfig,
is_amax_initialized: bool,
):
return torch.Tensor._make_wrapper_subclass(
cls,
tensor.size(),
strides=tensor.stride(),
storage_offset=tensor.storage_offset(),
memory_format=suggest_memory_format(tensor),
dtype=tensor.dtype,
layout=tensor.layout,
device=tensor.device,
pin_memory=tensor.is_pinned(),
requires_grad=tensor.requires_grad,
)

def __init__(
self,
tensor: torch.Tensor,
amax_buffer: torch.Tensor,
amax_history_buffer: torch.Tensor,
scale_buffer: torch.Tensor,
mm_config: ScaledMMConfig,
is_amax_initialized: bool,
):
self._tensor = tensor
self._amax_buffer = amax_buffer
self._amax_history_buffer = amax_history_buffer
self._scale_buffer = scale_buffer
self._mm_config = mm_config

# Note: is_amax_initialized is not a buffer, to avoid exposing
# data-dependent control flow to dynamo
# TODO(future PR): add serialization for this flag
self.is_amax_initialized = is_amax_initialized

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
if func == torch.ops.aten.detach.default:
return WeightWithDelayedFloat8CastTensor(
args[0]._tensor,
args[0]._amax_buffer,
args[0]._amax_history_buffer,
args[0]._scale_buffer,
args[0]._mm_config,
args[0].is_amax_initialized,
)
mm_config: Optional[ScaledMMConfig] = None
amax_buffer: Optional[torch.Tensor] = None
amax_history_buffer: Optional[torch.Tensor] = None
scale_buffer: Optional[torch.Tensor] = None
is_amax_initialized: Optional[bool] = None

def unwrap(t):
nonlocal mm_config
if mm_config is None:
mm_config = t._mm_config
else:
mm_config = merge_mm_configs(mm_config, t._mm_config)
nonlocal amax_buffer
if amax_buffer is None:
amax_buffer = t._amax_buffer
nonlocal amax_history_buffer
if amax_history_buffer is None:
amax_history_buffer = t._amax_history_buffer
nonlocal scale_buffer
if scale_buffer is None:
scale_buffer = t._scale_buffer
nonlocal is_amax_initialized
if is_amax_initialized is None:
is_amax_initialized = t.is_amax_initialized
return t._tensor

args, kwargs = pytree.tree_map_only(
WeightWithDelayedFloat8CastTensor, unwrap, (args, kwargs or {})
)
out = func(*args, **kwargs)
if func not in _ops_to_preserve_subclass:
return out
return pytree.tree_map_only(
torch.Tensor,
lambda x: WeightWithDelayedFloat8CastTensor(
x,
amax_buffer,
amax_history_buffer,
scale_buffer,
mm_config,
is_amax_initialized,
),
out,
)

def __tensor_flatten__(self):
return (
[
"_tensor",
"_amax_buffer",
"_amax_history_buffer",
"_scale_buffer",
],
{
"mm_config": self._mm_config,
"is_amax_initialized": is_amax_initialized,
},
)

@staticmethod
def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
return WeightWithDelayedFloat8CastTensor(
inner_tensors["_tensor"],
inner_tensors["_amax_buffer"],
inner_tensors["_amax_history_buffer"],
inner_tensors["_scale_buffer"],
metadata["mm_config"],
metadata["is_amax_initialized"],
)

def __repr__(self):
return f"WeightWithDelayedFloat8CastTensor(tensor={self._tensor}, amax_buffer={self._amax_buffer}, scale_buffer={self._scale_buffer}, mm_config={self._mm_config})"

def fsdp_pre_all_gather(self, mesh):
# initialize if needed
# TODO(before land): ensure settings are consistent between Float8Linear and here
if not self.is_amax_initialized:
from float8_experimental.float8_linear import (
_maybe_initialize_amaxes_scales_for_float8_cast,
)

_maybe_initialize_amaxes_scales_for_float8_cast(
self._tensor,
self._amax_buffer,
self._amax_history_buffer,
self._scale_buffer,
"max", # TODO(before land): read this from parent
e4m3_dtype,
self.is_amax_initialized,
reduce_amax=True,
)
self.is_amax_initialized = True

# this will:
# 1. cast the tensor to float8 using `_scale_buffer`
# 2. populate `_amax_buffer` inplace
# TODO(future PR): clean up all the casting functions and clearly
# separate dynamic vs delayed, tech debt has accumulated
float8_tensor = Float8Tensor.to_float8(
self._tensor,
self._scale_buffer,
e4m3_dtype,
self._amax_buffer,
self._mm_config,
)
return (float8_tensor._data,), (float8_tensor._scale,)

def fsdp_post_all_gather(
self,
all_gather_outputs: Tuple[torch.Tensor, ...],
metadata: Any,
param_dtype: torch.dtype,
*,
out: Optional[torch.Tensor] = None,
):
(data,) = all_gather_outputs
(scale,) = metadata
if out is not None:
assert isinstance(out, Float8Tensor), f"{type(out)}"
out._scale = scale
return
return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,)
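
To make the control flow concrete, here is a conceptual single-process sketch
of how FSDP2 drives the two hooks defined above; it only calls methods shown
in this diff. The real implementation shards `data`, runs the collective
across the mesh, and may reuse preallocated `out` buffers. `mesh=None` is
passed only because the delayed-scaling pre-hook above does not read it, and
the wrapper is assumed to already have its amax initialized (otherwise the
pre-hook would try to reduce the amax across ranks).

import torch

def simulate_fp8_all_gather(weight_wrapper, param_dtype=torch.bfloat16):
    # pre-all-gather: cast the local weight shard to float8 using the
    # buffered scale, returning raw fp8 data plus metadata (the scale)
    (data,), (scale,) = weight_wrapper.fsdp_pre_all_gather(mesh=None)
    # ... in real FSDP2, `data` is all-gathered across ranks here ...
    # post-all-gather: rebuild a Float8Tensor from the gathered data and
    # metadata; Float8Linear.forward then uses it directly without re-casting
    fp8_weight, (inner,) = weight_wrapper.fsdp_post_all_gather(
        (data,), (scale,), param_dtype
    )
    return fp8_weight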
18 changes: 16 additions & 2 deletions test/test_fsdp2/test_fsdp2_common.py
@@ -6,6 +6,11 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import (
linear_requires_sync,
sync_float8_amax_and_scale_history,
)
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp


@@ -17,6 +22,7 @@ def check_parity_no_mp(
fsdp_optim: torch.optim.Optimizer,
local_inp: torch.Tensor,
precompute: bool = False,
scaling_type_w: TensorScalingType = TensorScalingType.DYNAMIC,
):
for iter_idx in range(10):
losses: List[torch.Tensor] = []
@@ -28,10 +34,18 @@
for param in model.parameters():
dist.all_reduce(param.grad)
param.grad.div_(dist.get_world_size())
# TODO(future): add amax syncing once delayed scaling is supported

if linear_requires_sync(scaling_type_w=scaling_type_w):
sync_float8_amax_and_scale_history(model)

optim.step()
if model is fsdp_model and precompute:
if (
model is fsdp_model
and precompute
and scaling_type_w is TensorScalingType.DYNAMIC
):
precompute_float8_dynamic_scale_for_fsdp(model)

test_cls.assertEqual(losses[0], losses[1])


