Allow for modifying the scaled_mm compute (#144)
Summary:
This does two things:
1. Creates a new named tuple type, `ScaledMMConfig`, that controls the behavior of the `scaled_mm` op. It carries `emulate`, `fast_accumulation`, and `fp8_out_dtype` (the latter is not currently used). It replaces the `emulate` argument, threads the config through all of the relevant infrastructure, and updates the tests accordingly.
2. Adds the fp8 fast-accumulation mode and enables it for the forward pass but not the backward pass; a minimal sketch of the resulting configs is shown after this list.
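
For orientation, here is a minimal sketch of what the config and the forward/backward split look like. This is an illustration inferred from the diff below, not the exact definition in `float8_experimental/float8_tensor.py`; the field names and defaults are assumptions.

```python
from collections import namedtuple

# Hypothetical sketch of the config named tuple. The real definition lives in
# float8_experimental/float8_tensor.py and may use different field names/defaults.
ScaledMMConfig = namedtuple(
    "ScaledMMConfig",
    ["emulate", "use_fast_accum", "fp8_out_dtype"],  # third field is currently unused
    defaults=[False, False, False],
)

# The split introduced by this PR: fast accumulation is turned on for the
# forward matmuls (when not emulating) and left off for the backward matmuls.
emulate = False
forward_config = ScaledMMConfig(emulate, True if not emulate else False, False)
backward_config = ScaledMMConfig(emulate, False, False)
```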

### Performance
With `use_fast_accum` enabled in the forward pass, measured with the linear_float8 benchmark (`benchmarks/bench_linear_float8.py`):

![image](https://github.com/pytorch-labs/float8_experimental/assets/32754868/8510814e-88d0-402c-9676-d4afe8fef2a0)

|    | shape (M, K, N)      | Speedup (use_fast_accum=False) | Speedup (use_fast_accum=True) | Gain (%) |
|---:|:---------------------|-------------------------------:|------------------------------:|---------:|
|  0 | (16384, 1024, 8192) |             1.19086  |            1.26397  |           6.13912 |
|  1 | (16384, 3584, 8192) |             1.42227  |            1.48921  |           4.70629 |
|  2 | (16384, 8192, 1280) |             0.970685 |            0.986167 |           1.59497 |
|  3 | (16384, 8192, 7168) |             1.50755  |            1.54886  |           2.74022 |
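
As a quick sanity check (not part of the benchmark script), the gain column is just the ratio of the two speedups expressed as a percentage:

```python
# Percentage gain = (speedup with fast accum / speedup without - 1) * 100,
# reproduced from the table above.
rows = {
    "(16384, 1024, 8192)": (1.19086, 1.26397),
    "(16384, 3584, 8192)": (1.42227, 1.48921),
    "(16384, 8192, 1280)": (0.970685, 0.986167),
    "(16384, 8192, 7168)": (1.50755, 1.54886),
}
for shape, (without_fast_accum, with_fast_accum) in rows.items():
    gain = (with_fast_accum / without_fast_accum - 1) * 100
    print(f"{shape}: {gain:.2f}% gain")  # ~6.14, ~4.71, ~1.59, ~2.74
```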

Pull Request resolved: #144

Reviewed By: vkuzo

Differential Revision: D55906764

Pulled By: drisspg

fbshipit-source-id: c6c7f7d5f7831bc594c8e70c71d9ab0e0c90755c
drisspg authored and facebook-github-bot committed Apr 9, 2024
1 parent 14da04f commit 31877bb
Showing 10 changed files with 219 additions and 116 deletions.
17 changes: 10 additions & 7 deletions README.md
@@ -90,13 +90,17 @@ for _ in range(N_ITER):
optimizer.step()
```

# code tips
# 🧭 Code Organization

* `float8_experimental/float8_linear.py` - `Float8Linear` (main user facing entry point for delayed scaling)
* `float8_experimental/float8_dynamic_linear.py` - `Float8DynamicLinear` (main user facing entry point for dynamic scaling)
* `float8_experimental/float8_tensor.py` - `Float8Tensor`, which allows `Float8Linear` to abide by the `x.dtype == x.grad.dtype` restriction
* `float8_experimental/float8_linear.py`
- `Float8Linear` (main user facing entry point for delayed scaling)
* `float8_experimental/float8_dynamic_linear.py`
- `Float8DynamicLinear` (main user facing entry point for dynamic scaling)
* `float8_experimental/float8_tensor.py`
- `Float8Tensor`, which allows `Float8Linear` to abide by the `x.dtype == x.grad.dtype` restriction
- `ScaledMMConfig` defines the semantics for matmul in the forward and backwards pass

# testing
# Testing

```bash
# run single-GPU unit tests
@@ -117,7 +121,7 @@ pytest test/test_compile.py
./test/run_everything.sh
```

# benchmarking
# Benchmarking

```bash
# benchmark the torch._scaled_mm function on LLaMa 2 70B shapes
@@ -130,4 +134,3 @@ pytest test/test_compile.py

# License
PyTorch has a BSD 3-Clause License, as found in the LICENSE file.

22 changes: 15 additions & 7 deletions benchmarks/bench_linear_float8.py
@@ -16,6 +16,7 @@
import torch.utils.benchmark as benchmark
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history
from float8_experimental.float8_tensor import ScaledMMConfig
from tqdm import tqdm

# estimating TOPs for matmuls in fp32, fp16, fp8
@@ -54,8 +55,8 @@ class Experiment:
ref_time_sec: float
float8_time_sec: float
dtype: torch.dtype
compiled: bool = False
float_8_dtype: Optional[torch.dtype] = torch.float8_e4m3fn
compiled: bool
use_fast_accum: bool

# 3 Times since we are calculating forward backward
@property
@@ -74,7 +75,7 @@ def float8_tops_sec(self):

@property
def float8_pct_top_peak(self):
return self.float8_tops_sec / dtype_to_peak_tops[self.float_8_dtype]
return self.float8_tops_sec / dtype_to_peak_tops[torch.float8_e4m3fn]


def main(
@@ -95,9 +96,10 @@ def main(
}
input_bias = False
ref_dtypes = [torch.bfloat16, torch.float16]
use_fast_accum = [True, False]
experiment_list: List[Experiment] = []
for idx, (dtype, (name, (K, N))) in enumerate(
tqdm(list(product(ref_dtypes, name_to_shapes_70b.items())))
for idx, (dtype, fast_accum, (name, (K, N))) in enumerate(
tqdm(list(product(ref_dtypes, use_fast_accum, name_to_shapes_70b.items())))
):
if n_limit is not None and idx >= n_limit:
break
@@ -108,6 +110,10 @@ def main(
linear_float8 = Float8Linear.from_float(
copy.deepcopy(linear_ref), emulate=False
)
if fast_accum:
linear_float8.forward_config = ScaledMMConfig(False, True, False)
else:
linear_float8.forward_config = ScaledMMConfig(False, False, False)

bsz, seq_len = 4, 4096
M = bsz * seq_len
@@ -155,6 +161,7 @@ def wrapper(*args, **kwargs):
float8_time,
dtype,
compile,
use_fast_accum=fast_accum,
)
print(experiment)
print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec)
@@ -168,7 +175,7 @@ def wrapper(*args, **kwargs):
"N",
"ref_dtype",
"compiled",
"fp8_dtype",
"use_fast_accum",
"ref_time_sec",
"pt_fp8_time_sec",
"ref_tops_sec",
@@ -186,7 +193,7 @@ def wrapper(*args, **kwargs):
experiment.shape[2],
experiment.dtype,
experiment.compiled,
experiment.float_8_dtype,
experiment.use_fast_accum,
experiment.ref_time_sec,
experiment.float8_time_sec,
experiment.ref_tops_sec,
@@ -214,6 +221,7 @@ def wrapper(*args, **kwargs):
"shape",
"ref_dtype",
"compiled",
"use_fast_accum",
"ref_time_sec",
"pt_fp8_time_sec",
"pt_fp8_speedup",
14 changes: 8 additions & 6 deletions float8_experimental/float8_dynamic_linear.py
@@ -10,6 +10,7 @@

from float8_experimental.float8_tensor import (
Float8Tensor,
ScaledMMConfig,
tensor_already_casted_to_fp8,
to_fp8_no_autograd,
)
@@ -27,9 +28,9 @@ class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
def forward(
ctx,
tensor,
emulate: bool,
mm_config: ScaledMMConfig,
):
ctx.emulate = emulate
ctx.mm_config = mm_config
return tensor

@staticmethod
@@ -39,7 +40,7 @@ def backward(ctx, gradY):
return gradY, None
gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2)
fp8_tensor = to_fp8_no_autograd(
gradY, gradY_scale, torch.float8_e5m2, ctx.emulate
gradY, gradY_scale, torch.float8_e5m2, mm_config=ctx.mm_config
)
return fp8_tensor, None

@@ -73,11 +74,11 @@ def cast_to_float8_e4m3fn(self, inpt_tensor: torch.Tensor) -> Float8Tensor:
return inpt_tensor
scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
return Float8Tensor.to_float8(
inpt_tensor, scale, torch.float8_e4m3fn, emulate=self.emulate
inpt_tensor, scale, torch.float8_e4m3fn, mm_config=self.forward_config
)

def cast_to_float8_e5m2_bw(self, gradY: torch.Tensor) -> torch.Tensor:
return NoopFwToFloat8E5M2Bw.apply(gradY, self.emulate)
return NoopFwToFloat8E5M2Bw.apply(gradY, self.backward_config)

@classmethod
def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
@@ -97,5 +98,6 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
new_mod = cls(**super_kwargs)
new_mod.weight = mod.weight
new_mod.bias = mod.bias
new_mod.emulate = emulate
new_mod.forward_config = ScaledMMConfig(emulate, True if not emulate else False)
new_mod.backward_config = ScaledMMConfig(emulate, False)
return new_mod
42 changes: 28 additions & 14 deletions float8_experimental/float8_linear.py
@@ -20,7 +20,11 @@

import torch

from float8_experimental.float8_tensor import Float8Tensor, to_fp8_no_autograd
from float8_experimental.float8_tensor import (
Float8Tensor,
ScaledMMConfig,
to_fp8_no_autograd,
)

from float8_experimental.float8_utils import (
amax_history_to_scale,
@@ -73,12 +77,12 @@ def forward(
fp8_scale_dL_dY,
scale_fn_name,
is_amax_initialized,
emulate: bool,
mm_config: ScaledMMConfig,
):
ctx.save_for_backward(fp8_amax_dL_dY, fp8_amax_history_dL_dY, fp8_scale_dL_dY)
ctx.scale_fn_name = scale_fn_name
ctx.is_amax_initialized = is_amax_initialized
ctx.emulate = emulate
ctx.mm_config = mm_config
return tensor

@staticmethod
@@ -99,7 +103,9 @@ def backward(ctx, go):

fp8_amax_dL_dY.fill_(tensor_to_amax(go))

res = to_fp8_no_autograd(go, fp8_scale_dL_dY, torch.float8_e5m2, ctx.emulate)
res = to_fp8_no_autograd(
go, fp8_scale_dL_dY, torch.float8_e5m2, mm_config=ctx.mm_config
)
empty_grads = None, None, None, None, None, None
return res, *empty_grads

@@ -154,8 +160,9 @@ def __init__(self, *args, **kwargs):
)
self.register_always_float32_buffer("fp8_scale_dL_dY", torch.tensor([1.0]))

# Whether to emulate the fp8 matmul logic in float32
self.emulate = False
# Defines the behavior of the matmul in the forward and backward pass
self.forward_config = ScaledMMConfig()
self.backward_config = ScaledMMConfig()

# Note: is_amax_initialized is not a buffer to avoid data dependent
# control flow visible to dynamo
@@ -216,7 +223,11 @@ def cast_x_to_float8(
is_amax_initialized,
)
x_fp8 = Float8Tensor.to_float8(
x, self.fp8_scale_x, torch.float8_e4m3fn, self.fp8_amax_x, self.emulate
x,
self.fp8_scale_x,
torch.float8_e4m3fn,
self.fp8_amax_x,
self.forward_config,
)
return x_fp8

@@ -239,13 +250,11 @@ def cast_w_to_float8(
self.fp8_scale_w,
torch.float8_e4m3fn,
self.fp8_amax_w,
self.emulate,
self.forward_config,
)
return w_fp8

def cast_y_to_float8_in_bw(
self, y: torch.Tensor, emulate: bool = False
) -> torch.Tensor:
def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
scale_fn_name = self.recipe.scale_fn_name
y = NoopFwToFloat8E5M2Bw.apply(
y,
@@ -254,7 +263,7 @@ def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
self.fp8_scale_dL_dY,
scale_fn_name,
self.is_amax_initialized,
emulate,
self.backward_config,
)
return y

@@ -295,7 +304,7 @@ def forward(self, x):
y = torch.matmul(x_fp8, w_fp8.t())

# Cast gradY to float8_e5m2 during backward
y = self.cast_y_to_float8_in_bw(y, self.emulate)
y = self.cast_y_to_float8_in_bw(y)

if self.bias is not None:
y = y + self.bias.to(y.dtype)
@@ -318,7 +327,12 @@ def from_float(cls, mod, emulate: bool = False):
new_mod = cls(mod.in_features, mod.out_features, bias=False)
new_mod.weight = mod.weight
new_mod.bias = mod.bias
new_mod.emulate = emulate

# Defines the behavior of the matmul in the forward and backward
# Forward we use fast_accum, backwards we do not
new_mod.forward_config = ScaledMMConfig(emulate, True if not emulate else False)
new_mod.backward_config = ScaledMMConfig(emulate, False)

# I think its okay to send all params and buffers to device
new_mod.to(mod.weight.device)
return new_mod
46 changes: 35 additions & 11 deletions float8_experimental/float8_ops.py
@@ -8,7 +8,11 @@
import torch

from float8_experimental.float8_python_api import addmm_float8_unwrapped
from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_tensor import (
Float8Tensor,
merge_mm_configs,
ScaledMMConfig,
)
from float8_experimental.float8_utils import is_row_major
from torch.utils._pytree import tree_map

@@ -41,7 +45,9 @@ def decorator(func):
)
def float8_desugar_op(aten_op, args, kwargs=None):
new_data = aten_op(args[0]._data, *args[1:], **kwargs)
return Float8Tensor(new_data, args[0]._scale, args[0]._orig_dtype, args[0]._emulate)
return Float8Tensor(
new_data, args[0]._scale, args[0]._orig_dtype, args[0]._mm_config
)


@implements([aten.sum.dim_IntList])
@@ -89,13 +95,22 @@ def float8_mm(aten_op, args, kwargs=None):
)
a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b)
output_dtype = a._orig_dtype
if a._emulate:
assert a._emulate == b._emulate
a_mm_config: ScaledMMConfig = a._mm_config
b_mm_config: ScaledMMConfig = b._mm_config
mm_config: ScaledMMConfig = merge_mm_configs(a_mm_config, b_mm_config)
if mm_config.emulate:
return torch.ops.aten.mm_float8_emulated(
a._data, a._scale, b._data, b._scale, output_dtype
)[0]
tensor_out, amax = addmm_float8_unwrapped(
a_data, a_scale, b_data, b_scale, output_dtype, output_scale=None, bias=None
a_data,
a_scale,
b_data,
b_scale,
output_dtype,
output_scale=None,
bias=None,
use_fast_accum=mm_config.use_fast_accum,
)
return tensor_out

@@ -113,14 +128,23 @@ def float8_addmm(aten_op, args, kwargs=None):
a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b)
output_dtype = a._orig_dtype
assert bias.dtype == output_dtype, "bias dtype must match output dtype"
if a._emulate:
assert a._emulate == b._emulate
a_mm_config: ScaledMMConfig = a._mm_config
b_mm_config: ScaledMMConfig = b._mm_config
mm_config: ScaledMMConfig = merge_mm_configs(a_mm_config, b_mm_config)
if mm_config.emulate:
out = torch.ops.aten.mm_float8_emulated(
a._data, a._scale, b._data, b._scale, output_dtype
)[0]
return out + bias
tensor_out, amax = addmm_float8_unwrapped(
a_data, a_scale, b_data, b_scale, output_dtype, output_scale=None, bias=bias
a_data,
a_scale,
b_data,
b_scale,
output_dtype,
output_scale=None,
bias=bias,
use_fast_accum=mm_config.use_fast_accum,
)
return tensor_out

@@ -145,7 +169,7 @@ def autocast_to_copy(aten_op, args, kwargs=None):
torch.bfloat16,
}, "Only support floating point conversion for autocast w/ Float8Tensor"
return Float8Tensor(
args[0]._data, args[0]._scale, kwargs["dtype"], args[0]._emulate
args[0]._data, args[0]._scale, kwargs["dtype"], args[0]._mm_config
)


@@ -170,7 +194,7 @@ def allgather_fp8(aten_op, args, kwargs=None):
fp8_out = aten_op(fp8_data, *args[1:], **kwargs)
fp8_out = fp8_out.view(fp8_input._data.dtype)
return Float8Tensor(
fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._emulate
fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config
)


@@ -182,5 +206,5 @@ def wait_tensor_fp8(aten_op, args, kwargs=None):
fp8_data = fp8_input._data
fp8_out = aten_op(fp8_data, *args[1:], **kwargs)
return Float8Tensor(
fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._emulate
fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config
)