 )
 from torch.distributed._tensor import DTensor
 from torch.distributed.device_mesh import DeviceMesh
-from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
+from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, PrepareModuleInput

 # subclass the ColwiseParallel and RowwiseParallel classes
 # to add the float8 support
@@ -109,3 +109,77 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
             )

         return super()._apply(module, device_mesh)
+
+
+class PrepareFloat8ModuleInput(PrepareModuleInput):
+    # Subclass of PrepareModuleInput; the only difference is that after we prepare
+    # the input DTensor, we cast it to DTensor(Float8Tensor).
+    def _prepare_input_fn(self, inputs, device_mesh):
+        if self.input_layouts is None:
+            return inputs
+        prepared_inputs = []
+        if not isinstance(inputs, tuple):
+            inputs = (inputs,)
+        if len(inputs) != len(self.input_layouts):
+            raise ValueError("module inputs and input_layouts should have same length!")
+
+        assert self.desired_input_layouts is not None, "desired module inputs should not be None!"
+        for inp, input_layout, desired_layout in zip(inputs, self.input_layouts, self.desired_input_layouts):
+            if input_layout is not None:
+                if isinstance(inp, DTensor):
+                    # TODO: re-enable the check once we fix the compile path
+                    # assert inp.placements[0] == input_layout
+                    dt_inp = inp
+                else:
+                    dt_inp = DTensor.from_local(inp, device_mesh, (input_layout,), run_check=False)
+
+                dt_inp = cast_to_float8_e4m3fn(
+                    dt_inp, self.fwd_linear_config
+                )  # DTensor(Float8Tensor)
+                if desired_layout is not None and input_layout != desired_layout:
+                    # e.g. Shard -> Replicate: all-gather
+                    dt_inp = dt_inp.redistribute(placements=(desired_layout,))
+                prepared_inputs.append(dt_inp.to_local() if self.use_local_output else dt_inp)
+            else:
+                prepared_inputs.append(inp)
+        return tuple(prepared_inputs)
+
+    def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh):
+        prepared_arg_inputs = self._prepare_input_fn(inputs, device_mesh)
+        prepared_kwarg_inputs = {}
+        for kwarg_key in kwarg_inputs.keys():
+            kwarg_val = kwarg_inputs[kwarg_key]
+            input_layout = None
+            if kwarg_key in self.input_kwarg_layouts:
+                input_layout = self.input_kwarg_layouts[kwarg_key]
+                assert isinstance(kwarg_val, torch.Tensor), f"input of key {kwarg_key} to the module should be a Tensor!"
+                kwarg_val = DTensor.from_local(kwarg_val, device_mesh, (input_layout,), run_check=False)
+
+                kwarg_val = cast_to_float8_e4m3fn(
+                    kwarg_val, self.fwd_linear_config
+                )  # DTensor(Float8Tensor)
+                if kwarg_key in self.desired_input_kwarg_layouts:
+                    desired_layout = self.desired_input_kwarg_layouts[kwarg_key]
+                    if desired_layout != input_layout:
+                        kwarg_val = kwarg_val.redistribute(placements=(desired_layout,))
+
+                prepared_kwarg_inputs[kwarg_key] = kwarg_val.to_local() if self.use_local_output else kwarg_val
+            else:
+                prepared_kwarg_inputs[kwarg_key] = kwarg_val
+
+        return (prepared_arg_inputs, prepared_kwarg_inputs)
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
+        # search for the ScaledMM config across all the submodules and make sure they are the same
+        fwd_linear_config = None
+        for mod in module.modules():
+            if isinstance(mod, Float8DynamicLinear):
+                if fwd_linear_config is None:
+                    fwd_linear_config = mod.forward_config
+                else:
+                    assert fwd_linear_config == mod.forward_config, "All the Float8DynamicLinear modules should have same forward config!"
+
+        self.fwd_linear_config = fwd_linear_config
+        super()._apply(module, device_mesh)
+        return module
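
For context, a minimal usage sketch of the style added by this diff, not part of the commit: it assumes the float8 TP styles are named Float8ColwiseParallel / Float8RowwiseParallel and live in float8_experimental.float8_tensor_parallel, that torch.distributed is already initialized on 8 ranks, and that `model` has an `attention` block with wq/wk/wv/wo linears already swapped to Float8DynamicLinear; all of those names are illustrative. The point of PrepareFloat8ModuleInput is that the block input is cast to DTensor(Float8Tensor) once, so the three column-wise projections share a single cast (and a single all-gather) instead of each casting their input separately.

# Usage sketch with assumed names; see the lead-in above for the assumptions.
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

# Assumed module path and float8 TP style names, for illustration only.
from float8_experimental.float8_tensor_parallel import (
    Float8ColwiseParallel,
    Float8RowwiseParallel,
    PrepareFloat8ModuleInput,
)

tp_mesh = init_device_mesh("cuda", (8,))  # e.g. 8-way tensor parallel

# `model.attention` gets a single float8 cast of its (sequence-sharded) input,
# redistributed to Replicate before the colwise q/k/v projections consume it.
plan = {
    "attention": PrepareFloat8ModuleInput(
        input_layouts=(Shard(1),),
        desired_input_layouts=(Replicate(),),
    ),
    "attention.wq": Float8ColwiseParallel(),
    "attention.wk": Float8ColwiseParallel(),
    "attention.wv": Float8ColwiseParallel(),
    "attention.wo": Float8RowwiseParallel(output_layouts=Shard(1)),
}

model = parallelize_module(model, tp_mesh, plan)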