This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 31877bb

drisspg authored and facebook-github-bot committed
Allow for modifying the scaled_mm compute (#144)

Summary:

This does two things:

1. Creates a new namedtuple type `ScaledMMConfig` that is used to control the behavior of the scaled_mm op. It includes `emulate`, `fast_accumulation`, and `fp8_out_dtype` (the latter is not currently used). It replaces the `emulate` arg, threads it through all of the relevant infra, and updates the tests accordingly.
2. Adds the fp8 fast-accumulation mode and enables it for the forward pass but not the backward pass.

### Performance

With `use_fast_accum` enabled in the forward pass, using the linear_float8 benchmark:

![image](https://github.com/pytorch-labs/float8_experimental/assets/32754868/8510814e-88d0-402c-9676-d4afe8fef2a0)

|    | shape               | Speedup_with_False | Speedup_with_True | Percentage_Gain |
|---:|:--------------------|-------------------:|------------------:|----------------:|
|  0 | (16384, 1024, 8192) | 1.19086            | 1.26397           | 6.13912         |
|  1 | (16384, 3584, 8192) | 1.42227            | 1.48921           | 4.70629         |
|  2 | (16384, 8192, 1280) | 0.970685           | 0.986167          | 1.59497         |
|  3 | (16384, 8192, 7168) | 1.50755            | 1.54886           | 2.74022         |

Pull Request resolved: #144

Reviewed By: vkuzo

Differential Revision: D55906764

Pulled By: drisspg

fbshipit-source-id: c6c7f7d5f7831bc594c8e70c71d9ab0e0c90755c
1 parent 14da04f commit 31877bb
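
For readers skimming the diffs below, here is a minimal sketch of what a named tuple like `ScaledMMConfig` could look like and how the forward/backward split described in the summary might use it. The field names `emulate` and `use_fast_accum` appear in the diff; the third output-related field and the defaults are assumptions, so treat this as illustration rather than the repo's actual definition.

```python
from collections import namedtuple

# Sketch only: the real ScaledMMConfig lives in float8_experimental/float8_tensor.py
# and may differ in field names, ordering, or defaults.
ScaledMMConfig = namedtuple(
    "ScaledMMConfig",
    ["emulate", "use_fast_accum", "fp8_output"],  # fp8_output is assumed, not yet used
    defaults=[False, False, False],
)

# Forward pass: real fp8 matmul with fast accumulation enabled.
forward_config = ScaledMMConfig(emulate=False, use_fast_accum=True, fp8_output=False)

# Backward pass: fast accumulation stays off, matching the summary above.
backward_config = ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False)
```

Carrying both flags in one immutable config lets the op dispatch read a single object per tensor instead of a loose `emulate` boolean.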

File tree

10 files changed: +219 lines, -116 lines

README.md

Lines changed: 10 additions & 7 deletions
@@ -90,13 +90,17 @@ for _ in range(N_ITER):
     optimizer.step()
 ```
 
-# code tips
+# 🧭 Code Organization
 
-* `float8_experimental/float8_linear.py` - `Float8Linear` (main user facing entry point for delayed scaling)
-* `float8_experimental/float8_dynamic_linear.py` - `Float8DynamicLinear` (main user facing entry point for dynamic scaling)
-* `float8_experimental/float8_tensor.py` - `Float8Tensor`, which allows `Float8Linear` to abide by the `x.dtype == x.grad.dtype` restriction
+* `float8_experimental/float8_linear.py`
+  - `Float8Linear` (main user facing entry point for delayed scaling)
+* `float8_experimental/float8_dynamic_linear.py`
+  - `Float8DynamicLinear` (main user facing entry point for dynamic scaling)
+* `float8_experimental/float8_tensor.py`
+  - `Float8Tensor`, which allows `Float8Linear` to abide by the `x.dtype == x.grad.dtype` restriction
+  - `ScaledMMConfig` defines the semantics for matmul in the forward and backwards pass
 
-# testing
+# Testing
 
 ```bash
 # run single-GPU unit tests
@@ -117,7 +121,7 @@ pytest test/test_compile.py
 ./test/run_everything.sh
 ```
 
-# benchmarking
+# Benchmarking
 
 ```bash
 # benchmark the torch._scaled_mm function on LLaMa 2 70B shapes
@@ -130,4 +134,3 @@ pytest test/test_compile.py
 
 # License
 PyTorch has a BSD 3-Clause License, as found in the LICENSE file.
-
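
The reorganized list above names the two user-facing entry points. As a rough, hedged usage sketch (layer sizes, device, and dtype are illustrative, not taken from the README), converting a plain `nn.Linear` via `from_float` looks roughly like this:

```python
import copy

import torch
import torch.nn as nn

from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear import Float8Linear

# Start from an ordinary high-precision linear layer.
m_ref = nn.Linear(1024, 4096, bias=False, device="cuda", dtype=torch.bfloat16)

# Delayed-scaling entry point.
m_delayed = Float8Linear.from_float(copy.deepcopy(m_ref), emulate=False)

# Dynamic-scaling entry point.
m_dynamic = Float8DynamicLinear.from_float(copy.deepcopy(m_ref), emulate=False)

x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)
y = m_dynamic(x)  # the matmul runs through the fp8 path internally
```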

benchmarks/bench_linear_float8.py

Lines changed: 15 additions & 7 deletions
@@ -16,6 +16,7 @@
 import torch.utils.benchmark as benchmark
 from float8_experimental.float8_linear import Float8Linear
 from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history
+from float8_experimental.float8_tensor import ScaledMMConfig
 from tqdm import tqdm
 
 # estimating TOPs for matmuls in fp32, fp16, fp8
@@ -54,8 +55,8 @@ class Experiment:
     ref_time_sec: float
     float8_time_sec: float
     dtype: torch.dtype
-    compiled: bool = False
-    float_8_dtype: Optional[torch.dtype] = torch.float8_e4m3fn
+    compiled: bool
+    use_fast_accum: bool
 
     # 3 Times since we are calculating forward backward
     @property
@@ -74,7 +75,7 @@ def float8_tops_sec(self):
 
     @property
     def float8_pct_top_peak(self):
-        return self.float8_tops_sec / dtype_to_peak_tops[self.float_8_dtype]
+        return self.float8_tops_sec / dtype_to_peak_tops[torch.float8_e4m3fn]
 
 
 def main(
@@ -95,9 +96,10 @@ def main(
     }
     input_bias = False
     ref_dtypes = [torch.bfloat16, torch.float16]
+    use_fast_accum = [True, False]
     experiment_list: List[Experiment] = []
-    for idx, (dtype, (name, (K, N))) in enumerate(
-        tqdm(list(product(ref_dtypes, name_to_shapes_70b.items())))
+    for idx, (dtype, fast_accum, (name, (K, N))) in enumerate(
+        tqdm(list(product(ref_dtypes, use_fast_accum, name_to_shapes_70b.items())))
     ):
         if n_limit is not None and idx >= n_limit:
             break
@@ -108,6 +110,10 @@ def main(
         linear_float8 = Float8Linear.from_float(
             copy.deepcopy(linear_ref), emulate=False
         )
+        if fast_accum:
+            linear_float8.forward_config = ScaledMMConfig(False, True, False)
+        else:
+            linear_float8.forward_config = ScaledMMConfig(False, False, False)
 
         bsz, seq_len = 4, 4096
         M = bsz * seq_len
@@ -155,6 +161,7 @@ def wrapper(*args, **kwargs):
             float8_time,
             dtype,
             compile,
+            use_fast_accum=fast_accum,
         )
         print(experiment)
         print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec)
@@ -168,7 +175,7 @@ def wrapper(*args, **kwargs):
         "N",
         "ref_dtype",
         "compiled",
-        "fp8_dtype",
+        "use_fast_accum",
         "ref_time_sec",
         "pt_fp8_time_sec",
         "ref_tops_sec",
@@ -186,7 +193,7 @@ def wrapper(*args, **kwargs):
                 experiment.shape[2],
                 experiment.dtype,
                 experiment.compiled,
-                experiment.float_8_dtype,
+                experiment.use_fast_accum,
                 experiment.ref_time_sec,
                 experiment.float8_time_sec,
                 experiment.ref_tops_sec,
@@ -214,6 +221,7 @@ def wrapper(*args, **kwargs):
         "shape",
         "ref_dtype",
         "compiled",
+        "use_fast_accum",
         "ref_time_sec",
         "pt_fp8_time_sec",
         "pt_fp8_speedup",

float8_experimental/float8_dynamic_linear.py

Lines changed: 8 additions & 6 deletions
@@ -10,6 +10,7 @@
 
 from float8_experimental.float8_tensor import (
     Float8Tensor,
+    ScaledMMConfig,
     tensor_already_casted_to_fp8,
     to_fp8_no_autograd,
 )
@@ -27,9 +28,9 @@ class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
     def forward(
         ctx,
         tensor,
-        emulate: bool,
+        mm_config: ScaledMMConfig,
     ):
-        ctx.emulate = emulate
+        ctx.mm_config = mm_config
         return tensor
 
     @staticmethod
@@ -39,7 +40,7 @@ def backward(ctx, gradY):
             return gradY, None
         gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2)
         fp8_tensor = to_fp8_no_autograd(
-            gradY, gradY_scale, torch.float8_e5m2, ctx.emulate
+            gradY, gradY_scale, torch.float8_e5m2, mm_config=ctx.mm_config
         )
         return fp8_tensor, None
 
@@ -73,11 +74,11 @@ def cast_to_float8_e4m3fn(self, inpt_tensor: torch.Tensor) -> Float8Tensor:
             return inpt_tensor
         scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
         return Float8Tensor.to_float8(
-            inpt_tensor, scale, torch.float8_e4m3fn, emulate=self.emulate
+            inpt_tensor, scale, torch.float8_e4m3fn, mm_config=self.forward_config
         )
 
     def cast_to_float8_e5m2_bw(self, gradY: torch.Tensor) -> torch.Tensor:
-        return NoopFwToFloat8E5M2Bw.apply(gradY, self.emulate)
+        return NoopFwToFloat8E5M2Bw.apply(gradY, self.backward_config)
 
     @classmethod
     def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
@@ -97,5 +98,6 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
         new_mod = cls(**super_kwargs)
         new_mod.weight = mod.weight
         new_mod.bias = mod.bias
-        new_mod.emulate = emulate
+        new_mod.forward_config = ScaledMMConfig(emulate, True if not emulate else False)
+        new_mod.backward_config = ScaledMMConfig(emulate, False)
         return new_mod
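
To make the new forward/backward split concrete: after `from_float`, the module carries two configs instead of a single `emulate` flag. The snippet below is a hedged sketch (attribute names follow the diff; the comments state what the `from_float` logic above implies for `emulate=False`):

```python
import copy

import torch
import torch.nn as nn

from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

m = nn.Linear(1024, 1024, bias=False, device="cuda", dtype=torch.bfloat16)
m_fp8 = Float8DynamicLinear.from_float(copy.deepcopy(m), emulate=False)

# Forward config: emulate=False, fast accumulation enabled.
print(m_fp8.forward_config)
# Backward config: emulate=False, fast accumulation disabled, so the e5m2
# gradient matmul accumulates at higher precision.
print(m_fp8.backward_config)
```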

float8_experimental/float8_linear.py

Lines changed: 28 additions & 14 deletions
@@ -20,7 +20,11 @@
 
 import torch
 
-from float8_experimental.float8_tensor import Float8Tensor, to_fp8_no_autograd
+from float8_experimental.float8_tensor import (
+    Float8Tensor,
+    ScaledMMConfig,
+    to_fp8_no_autograd,
+)
 
 from float8_experimental.float8_utils import (
     amax_history_to_scale,
@@ -73,12 +77,12 @@ def forward(
         fp8_scale_dL_dY,
         scale_fn_name,
         is_amax_initialized,
-        emulate: bool,
+        mm_config: ScaledMMConfig,
     ):
         ctx.save_for_backward(fp8_amax_dL_dY, fp8_amax_history_dL_dY, fp8_scale_dL_dY)
         ctx.scale_fn_name = scale_fn_name
         ctx.is_amax_initialized = is_amax_initialized
-        ctx.emulate = emulate
+        ctx.mm_config = mm_config
         return tensor
 
     @staticmethod
@@ -99,7 +103,9 @@ def backward(ctx, go):
 
         fp8_amax_dL_dY.fill_(tensor_to_amax(go))
 
-        res = to_fp8_no_autograd(go, fp8_scale_dL_dY, torch.float8_e5m2, ctx.emulate)
+        res = to_fp8_no_autograd(
+            go, fp8_scale_dL_dY, torch.float8_e5m2, mm_config=ctx.mm_config
+        )
         empty_grads = None, None, None, None, None, None
         return res, *empty_grads
 
@@ -154,8 +160,9 @@ def __init__(self, *args, **kwargs):
         )
         self.register_always_float32_buffer("fp8_scale_dL_dY", torch.tensor([1.0]))
 
-        # Whether to emulate the fp8 matmul logic in float32
-        self.emulate = False
+        # Defines the behavior of the matmul in the forward and backward pass
+        self.forward_config = ScaledMMConfig()
+        self.backward_config = ScaledMMConfig()
 
         # Note: is_amax_initialized is not a buffer to avoid data dependent
         # control flow visible to dynamo
@@ -216,7 +223,11 @@ def cast_x_to_float8(
             is_amax_initialized,
         )
         x_fp8 = Float8Tensor.to_float8(
-            x, self.fp8_scale_x, torch.float8_e4m3fn, self.fp8_amax_x, self.emulate
+            x,
+            self.fp8_scale_x,
+            torch.float8_e4m3fn,
+            self.fp8_amax_x,
+            self.forward_config,
         )
         return x_fp8
 
@@ -239,13 +250,11 @@ def cast_w_to_float8(
             self.fp8_scale_w,
             torch.float8_e4m3fn,
             self.fp8_amax_w,
-            self.emulate,
+            self.forward_config,
         )
         return w_fp8
 
-    def cast_y_to_float8_in_bw(
-        self, y: torch.Tensor, emulate: bool = False
-    ) -> torch.Tensor:
+    def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
         scale_fn_name = self.recipe.scale_fn_name
         y = NoopFwToFloat8E5M2Bw.apply(
             y,
@@ -254,7 +263,7 @@ def cast_y_to_float8_in_bw(
             self.fp8_scale_dL_dY,
             scale_fn_name,
             self.is_amax_initialized,
-            emulate,
+            self.backward_config,
         )
         return y
 
@@ -295,7 +304,7 @@ def forward(self, x):
         y = torch.matmul(x_fp8, w_fp8.t())
 
         # Cast gradY to float8_e5m2 during backward
-        y = self.cast_y_to_float8_in_bw(y, self.emulate)
+        y = self.cast_y_to_float8_in_bw(y)
 
         if self.bias is not None:
             y = y + self.bias.to(y.dtype)
@@ -318,7 +327,12 @@ def from_float(cls, mod, emulate: bool = False):
         new_mod = cls(mod.in_features, mod.out_features, bias=False)
         new_mod.weight = mod.weight
         new_mod.bias = mod.bias
-        new_mod.emulate = emulate
+
+        # Defines the behavior of the matmul in the forward and backward
+        # Forward we use fast_accum, backwards we do not
+        new_mod.forward_config = ScaledMMConfig(emulate, True if not emulate else False)
+        new_mod.backward_config = ScaledMMConfig(emulate, False)
+
         # I think its okay to send all params and buffers to device
         new_mod.to(mod.weight.device)
         return new_mod
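
The `True if not emulate else False` expression above is just `not emulate`: fast accumulation is requested only for the forward matmul, and only when the fp8 compute is real rather than emulated in float32. A hedged sketch of the same decision as a standalone helper (`make_configs` is hypothetical, not part of the repo):

```python
from float8_experimental.float8_tensor import ScaledMMConfig

def make_configs(emulate: bool):
    # Mirrors the from_float logic in the diff above.
    forward_config = ScaledMMConfig(emulate, not emulate)  # fast accum only when not emulating
    backward_config = ScaledMMConfig(emulate, False)       # never fast accum in backward
    return forward_config, backward_config

fwd_cfg, bwd_cfg = make_configs(emulate=False)
```

Keeping fast accumulation out of the backward pass trades a little speed for gradient accuracy, which matches the summary's note that the mode is enabled for the forward path only.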

float8_experimental/float8_ops.py

Lines changed: 35 additions & 11 deletions
@@ -8,7 +8,11 @@
 import torch
 
 from float8_experimental.float8_python_api import addmm_float8_unwrapped
-from float8_experimental.float8_tensor import Float8Tensor
+from float8_experimental.float8_tensor import (
+    Float8Tensor,
+    merge_mm_configs,
+    ScaledMMConfig,
+)
 from float8_experimental.float8_utils import is_row_major
 from torch.utils._pytree import tree_map
 
@@ -41,7 +45,9 @@ def decorator(func):
 )
 def float8_desugar_op(aten_op, args, kwargs=None):
     new_data = aten_op(args[0]._data, *args[1:], **kwargs)
-    return Float8Tensor(new_data, args[0]._scale, args[0]._orig_dtype, args[0]._emulate)
+    return Float8Tensor(
+        new_data, args[0]._scale, args[0]._orig_dtype, args[0]._mm_config
+    )
 
 
 @implements([aten.sum.dim_IntList])
@@ -89,13 +95,22 @@ def float8_mm(aten_op, args, kwargs=None):
     )
     a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b)
     output_dtype = a._orig_dtype
-    if a._emulate:
-        assert a._emulate == b._emulate
+    a_mm_config: ScaledMMConfig = a._mm_config
+    b_mm_config: ScaledMMConfig = b._mm_config
+    mm_config: ScaledMMConfig = merge_mm_configs(a_mm_config, b_mm_config)
+    if mm_config.emulate:
         return torch.ops.aten.mm_float8_emulated(
             a._data, a._scale, b._data, b._scale, output_dtype
         )[0]
     tensor_out, amax = addmm_float8_unwrapped(
-        a_data, a_scale, b_data, b_scale, output_dtype, output_scale=None, bias=None
+        a_data,
+        a_scale,
+        b_data,
+        b_scale,
+        output_dtype,
+        output_scale=None,
+        bias=None,
+        use_fast_accum=mm_config.use_fast_accum,
     )
     return tensor_out
 
@@ -113,14 +128,23 @@ def float8_addmm(aten_op, args, kwargs=None):
     a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b)
     output_dtype = a._orig_dtype
     assert bias.dtype == output_dtype, "bias dtype must match output dtype"
-    if a._emulate:
-        assert a._emulate == b._emulate
+    a_mm_config: ScaledMMConfig = a._mm_config
+    b_mm_config: ScaledMMConfig = b._mm_config
+    mm_config: ScaledMMConfig = merge_mm_configs(a_mm_config, b_mm_config)
+    if mm_config.emulate:
         out = torch.ops.aten.mm_float8_emulated(
             a._data, a._scale, b._data, b._scale, output_dtype
         )[0]
         return out + bias
     tensor_out, amax = addmm_float8_unwrapped(
-        a_data, a_scale, b_data, b_scale, output_dtype, output_scale=None, bias=bias
+        a_data,
+        a_scale,
+        b_data,
+        b_scale,
+        output_dtype,
+        output_scale=None,
+        bias=bias,
+        use_fast_accum=mm_config.use_fast_accum,
     )
     return tensor_out
 
@@ -145,7 +169,7 @@ def autocast_to_copy(aten_op, args, kwargs=None):
         torch.bfloat16,
     }, "Only support floating point conversion for autocast w/ Float8Tensor"
     return Float8Tensor(
-        args[0]._data, args[0]._scale, kwargs["dtype"], args[0]._emulate
+        args[0]._data, args[0]._scale, kwargs["dtype"], args[0]._mm_config
    )
 
 
@@ -170,7 +194,7 @@ def allgather_fp8(aten_op, args, kwargs=None):
     fp8_out = aten_op(fp8_data, *args[1:], **kwargs)
     fp8_out = fp8_out.view(fp8_input._data.dtype)
     return Float8Tensor(
-        fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._emulate
+        fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config
     )
 
 
@@ -182,5 +206,5 @@ def wait_tensor_fp8(aten_op, args, kwargs=None):
     fp8_data = fp8_input._data
     fp8_out = aten_op(fp8_data, *args[1:], **kwargs)
     return Float8Tensor(
-        fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._emulate
+        fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config
     )
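
`merge_mm_configs` reconciles the configs carried by the two `Float8Tensor` operands before dispatching to either the emulated matmul or `addmm_float8_unwrapped`. Its definition is not part of this diff, so the version below is only a guess at a plausible policy (emulation must match, fast accumulation only if both sides request it), not the repo's implementation:

```python
from float8_experimental.float8_tensor import ScaledMMConfig

def merge_mm_configs_sketch(a: ScaledMMConfig, b: ScaledMMConfig) -> ScaledMMConfig:
    # Assumption: both operands must agree on emulation, and fast accumulation
    # is only used when both operands allow it.
    assert a.emulate == b.emulate, "mixed emulated/real fp8 operands are not supported"
    return ScaledMMConfig(a.emulate, a.use_fast_accum and b.use_fast_accum)
```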
