add PrepareFloat8ModuleInput for sequence parallel
When applying Sequence Parallel to a module with two or more linear layers for its
input projection, we often want to transform from Shard to Replicate once
(a single all-gather) and then reuse the all-gathered result. For fp8 we need
to do the casting before the Shard -> Replicate redistribution so that the
all-gather itself runs in fp8.

This PR subclasses PrepareModuleInput to add the fp8 casting logic, ensuring
we run an fp8 all-gather instead of a bf16 all-gather followed by a cast for
the computation.

Also adjusts the test cases to exercise the real FFN case for sequence
parallel.
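
For reference, a minimal sketch of how the new style is intended to be wired up
for the FFN case (module names mirror the test below; the mesh size, toy_model,
and the omitted fp8 conversion step are illustrative assumptions):

    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import parallelize_module
    from torch.distributed._tensor import Replicate, Shard
    from float8_experimental.float8_tensor_parallel import (
        Float8ColwiseParallel,
        Float8RowwiseParallel,
        PrepareFloat8ModuleInput,
    )

    # toy_model: a module with an "ffn" submodule (w1/w2 input projections and
    # an out_proj) whose Linear layers have already been converted to
    # Float8DynamicLinear; the conversion step is omitted here.
    mesh = init_device_mesh("cuda", (4,))  # illustrative 4-GPU mesh
    sp_model = parallelize_module(
        toy_model,
        mesh,
        {
            # cast to fp8 first, then redistribute Shard(1) -> Replicate, so
            # the all-gather runs in fp8 and its result is reused by w1 and w2
            "ffn": PrepareFloat8ModuleInput(
                input_layouts=Shard(1), desired_input_layouts=Replicate()
            ),
            "ffn.w1": Float8ColwiseParallel(),
            "ffn.w2": Float8ColwiseParallel(),
            "ffn.out_proj": Float8RowwiseParallel(output_layouts=Shard(1)),
        },
    )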
wanchaol committed Jun 9, 2024
1 parent cdb7867 commit d015c2e
Showing 2 changed files with 101 additions and 14 deletions.
76 changes: 75 additions & 1 deletion float8_experimental/float8_tensor_parallel.py
@@ -5,7 +5,7 @@
)
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, PrepareModuleInput

# subclass the ColwiseParallel and RowwiseParallel classes
# to add the float8 support
@@ -109,3 +109,77 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
)

return super()._apply(module, device_mesh)


class PrepareFloat8ModuleInput(PrepareModuleInput):
# subclass the PrepareModuleInput classes, the only difference is that after we prepare
# the input DTensor, we cast the input to DTensor(Float8Tensor)
def _prepare_input_fn(self, inputs, device_mesh):
if self.input_layouts is None:
return inputs
prepared_inputs = []
if not isinstance(inputs, tuple):
inputs = (inputs,)
if len(inputs) != len(self.input_layouts):
raise ValueError("module inputs and input_layouts should have same length!")

assert self.desired_input_layouts is not None, "desired module inputs should not be None!"
for inp, input_layout, desired_layout in zip(inputs, self.input_layouts, self.desired_input_layouts):
if input_layout is not None:
if isinstance(inp, DTensor):
# TODO: re-enable the check once we fix the compile path
# assert inp.placements[0] == input_layout
dt_inp = inp
else:
dt_inp = DTensor.from_local(inp, device_mesh, (input_layout,), run_check=False)

dt_inp = cast_to_float8_e4m3fn(
dt_inp, self.fwd_linear_config
) # DTensor(Float8Tensor)
if desired_layout is not None and input_layout != desired_layout:
# i.e. Shard -> Replicate: allgather
dt_inp = dt_inp.redistribute(placements=(desired_layout,))
prepared_inputs.append(dt_inp.to_local() if self.use_local_output else dt_inp)
else:
prepared_inputs.append(inp)
return tuple(prepared_inputs)

def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh):
prepared_arg_inputs = self._prepare_input_fn(inputs, device_mesh)
prepared_kwarg_inputs = {}
for kwarg_key in kwarg_inputs.keys():
kwarg_val = kwarg_inputs[kwarg_key]
input_layout = None
if kwarg_key in self.input_kwarg_layouts:
input_layout = self.input_kwarg_layouts[kwarg_key]
assert isinstance(kwarg_val, torch.Tensor), f"input of key {kwarg_key} to the module should be a Tensor!"
kwarg_val = DTensor.from_local(kwarg_val, device_mesh, (input_layout,), run_check=False)

kwarg_val = cast_to_float8_e4m3fn(
kwarg_val, self.fwd_linear_config
) # DTensor(Float8Tensor)
if kwarg_key in self.desired_input_kwarg_layouts:
desired_layout = self.desired_input_kwarg_layouts[kwarg_key]
if desired_layout != input_layout:
kwarg_val = kwarg_val.redistribute(placements=(desired_layout,))

prepared_kwarg_inputs[kwarg_key] = kwarg_val.to_local() if self.use_local_output else kwarg_val
else:
prepared_kwarg_inputs[kwarg_key] = kwarg_val

return (prepared_arg_inputs, prepared_kwarg_inputs)

def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
# search for ScaledMM configs for all the submodules and make sure they are the same
fwd_linear_config = None
for mod in module.modules():
if isinstance(mod, Float8DynamicLinear):
if fwd_linear_config is None:
fwd_linear_config = mod.forward_config
else:
assert fwd_linear_config == mod.forward_config, "All the Float8DynamicLinear modules should have same forward config!"

self.fwd_linear_config = fwd_linear_config
super()._apply(module, device_mesh)
return module
39 changes: 26 additions & 13 deletions test/test_dtensor.py
@@ -12,6 +12,7 @@

import torch
import torch.nn as nn
import torch.nn.functional as F

from float8_experimental.float8_dynamic_linear import (
Float8DynamicLinear,
@@ -22,6 +23,7 @@
from float8_experimental.float8_tensor_parallel import (
Float8ColwiseParallel,
Float8RowwiseParallel,
PrepareFloat8ModuleInput
)
from float8_experimental.float8_utils import tensor_to_scale
from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard
@@ -38,17 +40,25 @@ def setup_distributed():
return device_mesh


class ToyModel(nn.Module):
class FeedForward(nn.Module):
"""MLP based model"""

def __init__(self):
super(FeedForward, self).__init__()
self.w1 = nn.Linear(16, 32, bias=False)
self.w2 = nn.Linear(16, 32, bias=False)
self.out_proj = nn.Linear(32, 16, bias=False)

def forward(self, x):
return self.out_proj(F.silu(self.w1(x)) * self.w2(x))

class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.in_proj = nn.Linear(16, 32)
self.relu = nn.ReLU()
self.out_proj = nn.Linear(32, 16)
self.ffn = FeedForward()

def forward(self, x):
return self.out_proj(self.relu(self.in_proj(x)))
return self.ffn(x)


def test_scaled_mm(mesh: DeviceMesh, size=16):
@@ -182,8 +192,9 @@ def test_fp8_mlp_tensor_parallelism_base(
tp_model,
mesh,
{
"in_proj": Float8ColwiseParallel(),
"out_proj": Float8RowwiseParallel(),
"ffn.w1": Float8ColwiseParallel(),
"ffn.w2": Float8ColwiseParallel(),
"ffn.out_proj": Float8RowwiseParallel(),
},
)

@@ -192,17 +203,19 @@
sp_model,
mesh,
{
"in_proj": Float8ColwiseParallel(input_layouts=Shard(0)),
"out_proj": Float8RowwiseParallel(
output_layouts=Shard(0), use_local_output=False
"ffn": PrepareFloat8ModuleInput(input_layouts=Shard(1), desired_input_layouts=Replicate()),
"ffn.w1": Float8ColwiseParallel(),
"ffn.w2": Float8ColwiseParallel(),
"ffn.out_proj": Float8RowwiseParallel(
output_layouts=Shard(1), use_local_output=False
),
},
)

if compile:
tp_model = torch.compile(tp_model)

x_fp32 = torch.rand(size * 2, size, device=device, requires_grad=False)
x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
x_fp32_tp_input = x_fp32.clone()
x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)])

@@ -215,10 +228,10 @@
torch.testing.assert_close(tp_out, global_out)
torch.testing.assert_close(sp_out.full_tensor(), global_out)
torch.testing.assert_close(
tp_model.in_proj.weight.grad, sp_model.in_proj.weight.grad
tp_model.ffn.w1.weight.grad, sp_model.ffn.w1.weight.grad
)
torch.testing.assert_close(
tp_model.out_proj.weight.grad, sp_model.out_proj.weight.grad
tp_model.ffn.out_proj.weight.grad, sp_model.ffn.out_proj.weight.grad
)


