This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit d015c2e

add PrepareFloat8ModuleInput for sequence parallel
When applying Sequence Parallel to a module whose input projection has two or more linear layers consuming the same input, we often want to transform that input from Shard to Replicate once (a single allgather) and then reuse the allgathered result across those layers. For fp8, the cast has to happen before the Shard -> Replicate redistribution so that the allgather itself runs in fp8. This PR subclasses PrepareModuleInput to add the fp8 casting logic, so we perform an fp8 allgather instead of a bf16 allgather followed by a cast for the computation. It also adjusts the test cases to exercise a realistic FFN under sequence parallel. A sketch of the ordering follows.
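To make the ordering concrete, here is a minimal sketch (not part of the commit): local_inp, mesh, and fwd_config are placeholder names, and cast_to_float8_e4m3fn is assumed to be the same casting helper already used in float8_tensor_parallel.py below.

from torch.distributed._tensor import DTensor, Replicate, Shard

# without this PR: redistribute first, so the allgather runs in the original (e.g. bf16) dtype
dt_inp = DTensor.from_local(local_inp, mesh, (Shard(1),), run_check=False)
dt_inp = dt_inp.redistribute(placements=(Replicate(),))  # allgather in bf16
dt_inp = cast_to_float8_e4m3fn(dt_inp, fwd_config)       # cast after the comm

# with PrepareFloat8ModuleInput: cast first, so the allgather itself runs in fp8
dt_inp = DTensor.from_local(local_inp, mesh, (Shard(1),), run_check=False)
dt_inp = cast_to_float8_e4m3fn(dt_inp, fwd_config)       # DTensor(Float8Tensor)
dt_inp = dt_inp.redistribute(placements=(Replicate(),))  # allgather in fp8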
1 parent cdb7867 commit d015c2e

File tree: 2 files changed (+101, -14 lines)
  float8_experimental/float8_tensor_parallel.py
  test/test_dtensor.py

float8_experimental/float8_tensor_parallel.py

Lines changed: 75 additions & 1 deletion
@@ -5,7 +5,7 @@
 )
 from torch.distributed._tensor import DTensor
 from torch.distributed.device_mesh import DeviceMesh
-from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
+from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, PrepareModuleInput
 
 # subclass the ColwiseParallel and RowwiseParallel classes
 # to add the float8 support
@@ -109,3 +109,77 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
             )
 
         return super()._apply(module, device_mesh)
+
+
+class PrepareFloat8ModuleInput(PrepareModuleInput):
+    # subclass the PrepareModuleInput classes, the only difference is that after we prepare
+    # the input DTensor, we cast the input to DTensor(Float8Tensor)
+    def _prepare_input_fn(self, inputs, device_mesh):
+        if self.input_layouts is None:
+            return inputs
+        prepared_inputs = []
+        if not isinstance(inputs, tuple):
+            inputs = (inputs,)
+        if len(inputs) != len(self.input_layouts):
+            raise ValueError("module inputs and input_layouts should have same length!")
+
+        assert self.desired_input_layouts is not None, "desired module inputs should not be None!"
+        for inp, input_layout, desired_layout in zip(inputs, self.input_layouts, self.desired_input_layouts):
+            if input_layout is not None:
+                if isinstance(inp, DTensor):
+                    # TODO: re-enable the check once we fix the compile path
+                    # assert inp.placements[0] == input_layout
+                    dt_inp = inp
+                else:
+                    dt_inp = DTensor.from_local(inp, device_mesh, (input_layout,), run_check=False)
+
+                dt_inp = cast_to_float8_e4m3fn(
+                    dt_inp, self.fwd_linear_config
+                )  # DTensor(Float8Tensor)
+                if desired_layout is not None and input_layout != desired_layout:
+                    # i.e. Shard -> Replicate: allgather
+                    dt_inp = dt_inp.redistribute(placements=(desired_layout,))
+                prepared_inputs.append(dt_inp.to_local() if self.use_local_output else dt_inp)
+            else:
+                prepared_inputs.append(inp)
+        return tuple(prepared_inputs)
+
+    def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh):
+        prepared_arg_inputs = self._prepare_input_fn(inputs, device_mesh)
+        prepared_kwarg_inputs = {}
+        for kwarg_key in kwarg_inputs.keys():
+            kwarg_val = kwarg_inputs[kwarg_key]
+            input_layout = None
+            if kwarg_key in self.input_kwarg_layouts:
+                input_layout = self.input_kwarg_layouts[kwarg_key]
+                assert isinstance(kwarg_val, torch.Tensor), f"input of key {kwarg_key} to the module should be a Tensor!"
+                kwarg_val = DTensor.from_local(kwarg_val, device_mesh, (input_layout,), run_check=False)
+
+                kwarg_val = cast_to_float8_e4m3fn(
+                    kwarg_val, self.fwd_linear_config
+                )  # DTensor(Float8Tensor)
+                if kwarg_key in self.desired_input_kwarg_layouts:
+                    desired_layout = self.desired_input_kwarg_layouts[kwarg_key]
+                    if desired_layout != input_layout:
+                        kwarg_val = kwarg_val.redistribute(placements=(desired_layout,))
+
+                prepared_kwarg_inputs[kwarg_key] = kwarg_val.to_local() if self.use_local_output else kwarg_val
+            else:
+                prepared_kwarg_inputs[kwarg_key] = kwarg_val
+
+        return (prepared_arg_inputs, prepared_kwarg_inputs)
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
+        # search for ScaledMM configs for all the submodules and make sure they are the same
+        fwd_linear_config = None
+        for mod in module.modules():
+            if isinstance(mod, Float8DynamicLinear):
+                if fwd_linear_config is None:
+                    fwd_linear_config = mod.forward_config
+                else:
+                    assert fwd_linear_config == mod.forward_config, "All the Float8DynamicLinear modules should have same forward config!"
+
+        self.fwd_linear_config = fwd_linear_config
+        super()._apply(module, device_mesh)
+        return module
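For reference, here is a minimal sketch of how the new style is expected to be wired up with parallelize_module, mirroring the sequence-parallel plan in the test change below (sp_model and mesh are placeholders; the ffn.w1 / ffn.w2 / ffn.out_proj names come from the test's FeedForward model):

from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import parallelize_module

from float8_experimental.float8_tensor_parallel import (
    Float8ColwiseParallel,
    Float8RowwiseParallel,
    PrepareFloat8ModuleInput,
)

sp_model = parallelize_module(
    sp_model,
    mesh,
    {
        # cast the ffn input to fp8 once and allgather it in fp8 (Shard(1) -> Replicate()),
        # so both w1 and w2 reuse the same allgathered Float8Tensor input
        "ffn": PrepareFloat8ModuleInput(
            input_layouts=Shard(1), desired_input_layouts=Replicate()
        ),
        "ffn.w1": Float8ColwiseParallel(),
        "ffn.w2": Float8ColwiseParallel(),
        "ffn.out_proj": Float8RowwiseParallel(
            output_layouts=Shard(1), use_local_output=False
        ),
    },
)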

test/test_dtensor.py

Lines changed: 26 additions & 13 deletions
@@ -12,6 +12,7 @@
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from float8_experimental.float8_dynamic_linear import (
     Float8DynamicLinear,
@@ -22,6 +23,7 @@
 from float8_experimental.float8_tensor_parallel import (
     Float8ColwiseParallel,
     Float8RowwiseParallel,
+    PrepareFloat8ModuleInput
 )
 from float8_experimental.float8_utils import tensor_to_scale
 from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard
@@ -38,17 +40,25 @@ def setup_distributed():
     return device_mesh
 
 
-class ToyModel(nn.Module):
+class FeedForward(nn.Module):
     """MLP based model"""
 
+    def __init__(self):
+        super(FeedForward, self).__init__()
+        self.w1 = nn.Linear(16, 32, bias=False)
+        self.w2 = nn.Linear(16, 32, bias=False)
+        self.out_proj = nn.Linear(32, 16, bias=False)
+
+    def forward(self, x):
+        return self.out_proj(F.silu(self.w1(x)) * self.w2(x))
+
+class ToyModel(nn.Module):
     def __init__(self):
         super(ToyModel, self).__init__()
-        self.in_proj = nn.Linear(16, 32)
-        self.relu = nn.ReLU()
-        self.out_proj = nn.Linear(32, 16)
+        self.ffn = FeedForward()
 
     def forward(self, x):
-        return self.out_proj(self.relu(self.in_proj(x)))
+        return self.ffn(x)
 
 
 def test_scaled_mm(mesh: DeviceMesh, size=16):
@@ -182,8 +192,9 @@ def test_fp8_mlp_tensor_parallelism_base(
         tp_model,
         mesh,
         {
-            "in_proj": Float8ColwiseParallel(),
-            "out_proj": Float8RowwiseParallel(),
+            "ffn.w1": Float8ColwiseParallel(),
+            "ffn.w2": Float8ColwiseParallel(),
+            "ffn.out_proj": Float8RowwiseParallel(),
         },
     )
 
@@ -192,17 +203,19 @@ def test_fp8_mlp_tensor_parallelism_base(
         sp_model,
         mesh,
         {
-            "in_proj": Float8ColwiseParallel(input_layouts=Shard(0)),
-            "out_proj": Float8RowwiseParallel(
-                output_layouts=Shard(0), use_local_output=False
+            "ffn": PrepareFloat8ModuleInput(input_layouts=Shard(1), desired_input_layouts=Replicate()),
+            "ffn.w1": Float8ColwiseParallel(),
+            "ffn.w2": Float8ColwiseParallel(),
+            "ffn.out_proj": Float8RowwiseParallel(
+                output_layouts=Shard(1), use_local_output=False
             ),
         },
     )
 
     if compile:
         tp_model = torch.compile(tp_model)
 
-    x_fp32 = torch.rand(size * 2, size, device=device, requires_grad=False)
+    x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
     x_fp32_tp_input = x_fp32.clone()
     x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)])
 
@@ -215,10 +228,10 @@ def test_fp8_mlp_tensor_parallelism_base(
     torch.testing.assert_close(tp_out, global_out)
     torch.testing.assert_close(sp_out.full_tensor(), global_out)
     torch.testing.assert_close(
-        tp_model.in_proj.weight.grad, sp_model.in_proj.weight.grad
+        tp_model.ffn.w1.weight.grad, sp_model.ffn.w1.weight.grad
     )
     torch.testing.assert_close(
-        tp_model.out_proj.weight.grad, sp_model.out_proj.weight.grad
+        tp_model.ffn.out_proj.weight.grad, sp_model.ffn.out_proj.weight.grad
     )
 
 