This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

add PrepareFloat8ModuleInput for sequence parallel #275

Closed
wants to merge 6 commits
97 changes: 96 additions & 1 deletion float8_experimental/float8_tensor_parallel.py
@@ -1,11 +1,16 @@
import torch
import torch.nn as nn
from float8_experimental.float8_dynamic_linear import (
cast_to_float8_e4m3fn,
cast_to_float8_e5m2_bw,
)
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torch.distributed.tensor.parallel import (
ColwiseParallel,
PrepareModuleInput,
RowwiseParallel,
)

# subclass the ColwiseParallel and RowwiseParallel classes
# to add the float8 support
@@ -109,3 +114,93 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
)

return super()._apply(module, device_mesh)


class PrepareFloat8ModuleInput(PrepareModuleInput):
Contributor:
nit: maybe we can have e4m3 in the name, and maybe add a TODO to support the AMD version of e4m3 eventually?

Contributor:
maybe a quick docblock to explain that this is ensuring the float8 cast happens before the all-gather if there are multiple float8 users of the input activation?

Contributor (Author), replying to "maybe we can have e4m3 in the name, and maybe add a TODO to support the AMD version of e4m3 eventually":
I wonder what your thoughts are on these two choices: 1. make e4m3 appear in the name of this class, or 2. make this class constructor take an additional fp8 dtype argument, i.e. `float8_dtype=torch.float8_e4m3fn`, defaulting to the e4m3fn dtype, so that later we can add the AMD version of e4m3 by passing a different `float8_dtype` arg?

Contributor, replying to "make this class constructor take an additional argument of fp8 dtype":
sgtm
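For illustration only (editor's sketch, not part of this PR's diff): the agreed-upon constructor argument keeps e4m3fn as the default and leaves room for the AMD e4m3 variant later. The use of torch.float8_e4m3fnuz as that future dtype is an assumption; with this PR, any dtype other than torch.float8_e4m3fn raises NotImplementedError.

import torch
from torch.distributed._tensor import Replicate, Shard
from float8_experimental.float8_tensor_parallel import PrepareFloat8ModuleInput

# Default today: cast the prepared input to torch.float8_e4m3fn.
prepare_fp8_input = PrepareFloat8ModuleInput(
    input_layouts=Shard(1),
    desired_input_layouts=Replicate(),
    float8_dtype=torch.float8_e4m3fn,
    # float8_dtype=torch.float8_e4m3fnuz,  # hypothetical AMD variant; rejected by this PR
)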

# Subclass of PrepareModuleInput that adds fp8-specific logic: after preparing the input
# DTensor, we cast it to DTensor(Float8Tensor).
# This ensures the float8 cast happens before the all-gather (i.e. Shard -> Replicate),
# so that if there are multiple float8 users of the input activation, the fp8 all-gather
# is performed only once.
# FP8 Args:
# float8_dtype (torch.dtype, optional): which float8 dtype to cast the module input to;
# currently only torch.float8_e4m3fn is supported. Default: torch.float8_e4m3fn.
# fwd_config_submodule_fqn (str, optional): the FQN of the submodule that contains the forward
# config used for the float8 cast. If not specified, we search the submodules for
# Float8DynamicLinear instances and use their forward config; in that case all such modules
# must share the same forward config.

def __init__(
self,
*,
input_layouts=None,
desired_input_layouts=None,
input_kwarg_layouts=None,
desired_input_kwarg_layouts=None,
use_local_output=False,
float8_dtype=torch.float8_e4m3fn,
fwd_config_submodule_fqn=None,
):
super().__init__(
input_layouts=input_layouts,
desired_input_layouts=desired_input_layouts,
input_kwarg_layouts=input_kwarg_layouts,
desired_input_kwarg_layouts=desired_input_kwarg_layouts,
use_local_output=use_local_output,
)

# fp8 specific fields
self.float8_dtype = float8_dtype
self.fwd_config_submodule_fqn = fwd_config_submodule_fqn

if self.float8_dtype != torch.float8_e4m3fn:
raise NotImplementedError(
"PrepareFloat8ModuleInput only support casting to float8_e4m3fn for now"
)

def _prepare_input_arg(self, input, mesh, input_layout, desired_layout):
if input_layout is not None:
if isinstance(input, DTensor):
# TODO: re-enable the check once we fix the compile path
# assert inp.placements[0] == input_layout
dt_inp = input
else:
assert isinstance(
input, torch.Tensor
), "expecting input to be a torch.Tensor!"
dt_inp = DTensor.from_local(
input, mesh, (input_layout,), run_check=False
)

dt_inp = cast_to_float8_e4m3fn(
dt_inp, self.fwd_linear_config
) # DTensor(Float8Tensor)
if desired_layout is not None and input_layout != desired_layout:
dt_inp = dt_inp.redistribute(placements=(desired_layout,))

return dt_inp.to_local() if self.use_local_output else dt_inp
else:
return input

def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

fwd_linear_config = None
if self.fwd_config_submodule_fqn is not None:
fwd_linear = module.get_submodule(self.fwd_config_submodule_fqn)
assert isinstance(fwd_linear, Float8DynamicLinear)
fwd_linear_config = fwd_linear.forward_config
else:
# search for ScaledMM configs for all the submodules and make sure they are the same
for mod in module.modules():
if isinstance(mod, Float8DynamicLinear):
if fwd_linear_config is None:
fwd_linear_config = mod.forward_config
else:
assert (
fwd_linear_config == mod.forward_config
), "All the Float8DynamicLinear modules should have same forward config!"

self.fwd_linear_config = fwd_linear_config
super()._apply(module, device_mesh)
return module
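As a quick orientation before the test changes below: a minimal usage sketch of the new class in a sequence-parallel plan. It mirrors the plan added to test/test_dtensor.py; `model` (with its linears already swapped to Float8DynamicLinear) and a 1-D `device_mesh` are assumed to exist.

from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import parallelize_module
from float8_experimental.float8_tensor_parallel import (
    Float8ColwiseParallel,
    Float8RowwiseParallel,
    PrepareFloat8ModuleInput,
)

# Cast the ffn input to float8 before the Shard(1) -> Replicate all-gather,
# so w1 and w2 share a single fp8 all-gather of the activation.
model = parallelize_module(
    model,
    device_mesh,
    {
        "ffn": PrepareFloat8ModuleInput(
            input_layouts=Shard(1), desired_input_layouts=Replicate()
        ),
        "ffn.w1": Float8ColwiseParallel(),
        "ffn.w2": Float8ColwiseParallel(),
        "ffn.out_proj": Float8RowwiseParallel(
            output_layouts=Shard(1), use_local_output=False
        ),
    },
)

Compared with gathering the activation in its original precision and casting inside each linear, this performs the all-gather only once for w1 and w2, and for a bf16 activation it also halves the all-gather payload.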
75 changes: 62 additions & 13 deletions test/test_dtensor.py
@@ -12,6 +12,7 @@

import torch
import torch.nn as nn
import torch.nn.functional as F

from float8_experimental.float8_dynamic_linear import (
Float8DynamicLinear,
@@ -22,6 +23,7 @@
from float8_experimental.float8_tensor_parallel import (
Float8ColwiseParallel,
Float8RowwiseParallel,
PrepareFloat8ModuleInput,
)
from float8_experimental.float8_utils import tensor_to_scale
from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard
@@ -38,17 +40,26 @@ def setup_distributed():
return device_mesh


class ToyModel(nn.Module):
class FeedForward(nn.Module):
"""MLP based model"""

def __init__(self):
super(FeedForward, self).__init__()
self.w1 = nn.Linear(16, 32, bias=False)
self.w2 = nn.Linear(16, 32, bias=False)
self.out_proj = nn.Linear(32, 16, bias=False)

def forward(self, x):
return self.out_proj(F.silu(self.w1(x)) * self.w2(x))


class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.in_proj = nn.Linear(16, 32)
self.relu = nn.ReLU()
self.out_proj = nn.Linear(32, 16)
self.ffn = FeedForward()

def forward(self, x):
return self.out_proj(self.relu(self.in_proj(x)))
return self.ffn(x)


def test_scaled_mm(mesh: DeviceMesh, size=16):
@@ -182,8 +193,9 @@ def test_fp8_mlp_tensor_parallelism_base(
tp_model,
mesh,
{
"in_proj": Float8ColwiseParallel(),
"out_proj": Float8RowwiseParallel(),
"ffn.w1": Float8ColwiseParallel(),
"ffn.w2": Float8ColwiseParallel(),
"ffn.out_proj": Float8RowwiseParallel(),
},
)

@@ -192,17 +204,46 @@ def test_fp8_mlp_tensor_parallelism_base(
sp_model,
mesh,
{
"in_proj": Float8ColwiseParallel(input_layouts=Shard(0)),
"out_proj": Float8RowwiseParallel(
output_layouts=Shard(0), use_local_output=False
"ffn": PrepareFloat8ModuleInput(
input_layouts=Shard(1), desired_input_layouts=Replicate()
),
"ffn.w1": Float8ColwiseParallel(),
"ffn.w2": Float8ColwiseParallel(),
"ffn.out_proj": Float8RowwiseParallel(
output_layouts=Shard(1), use_local_output=False
),
},
)

# PrepareFloat8ModuleInput with specific submodule fqn
sp_model2 = copy.deepcopy(toy_model)
sp_model2 = swap_linear_with_float8_linear(
sp_model2, Float8DynamicLinear, emulate=True
)

sp_model2 = parallelize_module(
sp_model2,
mesh,
{
"ffn": PrepareFloat8ModuleInput(
input_layouts=Shard(1),
desired_input_layouts=Replicate(),
fwd_config_submodule_fqn="w2",
),
"ffn.w1": Float8ColwiseParallel(),
"ffn.w2": Float8ColwiseParallel(),
"ffn.out_proj": Float8RowwiseParallel(
output_layouts=Shard(1), use_local_output=False
),
},
)

if compile:
tp_model = torch.compile(tp_model)
sp_model = torch.compile(sp_model)
sp_model2 = torch.compile(sp_model2)

x_fp32 = torch.rand(size * 2, size, device=device, requires_grad=False)
x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
x_fp32_tp_input = x_fp32.clone()
x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)])

@@ -214,11 +255,19 @@ def test_fp8_mlp_tensor_parallelism_base(
global_out.sum().backward()
torch.testing.assert_close(tp_out, global_out)
torch.testing.assert_close(sp_out.full_tensor(), global_out)
torch.testing.assert_close(tp_model.ffn.w1.weight.grad, sp_model.ffn.w1.weight.grad)
torch.testing.assert_close(
tp_model.ffn.out_proj.weight.grad, sp_model.ffn.out_proj.weight.grad
)

sp_out2 = sp_model2(x_fp32_sp_input)
sp_out2.sum().backward()
torch.testing.assert_close(sp_out2.full_tensor(), global_out)
torch.testing.assert_close(
tp_model.in_proj.weight.grad, sp_model.in_proj.weight.grad
tp_model.ffn.w1.weight.grad, sp_model2.ffn.w1.weight.grad
)
torch.testing.assert_close(
tp_model.out_proj.weight.grad, sp_model.out_proj.weight.grad
tp_model.ffn.out_proj.weight.grad, sp_model2.ffn.out_proj.weight.grad
)
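One note for reading the test: the body of setup_distributed() is elided by the diff. A typical implementation for these DTensor tests (an assumption, not necessarily the repo's exact code) builds a 1-D device mesh from the torchrun-provided world size:

import os

import torch
from torch.distributed.device_mesh import init_device_mesh


def setup_distributed():
    # Assumes a torchrun launch, which sets WORLD_SIZE; seed identically on every rank
    # so each rank constructs the same toy model and input.
    world_size = int(os.environ["WORLD_SIZE"])
    device_mesh = init_device_mesh("cuda", (world_size,))
    torch.manual_seed(42)
    return device_mesh

The test itself would then be launched with something like torchrun --nproc_per_node 4 test/test_dtensor.py (the exact launcher used by the repo may differ).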

