This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

add option for using fused kernel #227

Draft
wants to merge 11 commits into base: main
230 changes: 230 additions & 0 deletions benchmarks/bench_dynamic_linear_fused_cast.py
@@ -0,0 +1,230 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import copy
from dataclasses import dataclass
from itertools import product
from pathlib import Path
from typing import Callable, List, Optional, Tuple

import float8_experimental.config as fp8_config

import pandas as pd

import torch
import torch.utils.benchmark as benchmark
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from tqdm import tqdm

# estimating TOPs for matmuls in fp32, fp16, fp8
# assuming A * B = C, with A being M * K, B being K * N, C being M * N

# H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/
h100_peak_flops_float32 = 67e12
h100_peak_flops_fp16_tc = 1979e12
h100_peak_tops_float8_tc = 3958e12

dtype_to_peak_tops = {
torch.float32: h100_peak_flops_float32,
torch.float16: h100_peak_flops_fp16_tc,
torch.bfloat16: h100_peak_flops_fp16_tc,
torch.float8_e4m3fn: h100_peak_tops_float8_tc,
torch.float8_e5m2: h100_peak_tops_float8_tc,
}


def benchmark_torch_function_in_microseconds(
func: Callable,
*args,
**kwargs,
) -> float:
t0 = benchmark.Timer(
stmt="func(*args, **kwargs)",
globals={"args": args, "kwargs": kwargs, "func": func},
)
return t0.blocked_autorange().median * 1e6


@dataclass
class Experiment:
name: str
shape: Tuple[int, int, int]
ref_time_sec: float
float8_time_sec: float
dtype: torch.dtype
use_fused_cast: bool
float_8_dtype: Optional[torch.dtype] = torch.float8_e4m3fn

    # 3x the matmul FLOPs, since forward + backward together take 3 GEMMs (output, dgrad, wgrad)
@property
def ref_tops_sec(self):
M, K, N = self.shape
return float(3 * (2 * M * K * N)) / self.ref_time_sec

@property
def ref_pct_top_peak(self):
return self.ref_tops_sec / dtype_to_peak_tops[self.dtype]

@property
def float8_tops_sec(self):
M, K, N = self.shape
return float(3 * (2 * M * K * N)) / self.float8_time_sec

@property
def float8_pct_top_peak(self):
return self.float8_tops_sec / dtype_to_peak_tops[self.float_8_dtype]


def main(
sweep_path: Path,
n_limit: Optional[int] = None,
):
device = "cuda"

# LLaMa 2 70B single-node weight shapes
# assumes fused attn.wqkv and ffn.w13
name_to_shapes_70b = {
"attn.wqkv": (8192, 1280),
"attn.w0": (1024, 8192),
"ffn.w13": (8192, 7168),
"ffn.w2": (3584, 8192),
}
input_bias = False
ref_dtypes = [torch.bfloat16, torch.float32]
experiment_list: List[Experiment] = []
fused_casts = [True, False]
for idx, (dtype, (name, (K, N)), fuse_cast) in enumerate(
tqdm(list(product(ref_dtypes, name_to_shapes_70b.items(), fused_casts)))
):
fp8_config.use_fused_cast = fuse_cast
if n_limit is not None and idx >= n_limit:
break
linear_ref = torch.nn.Linear(K, N, bias=input_bias).to(
device=device, dtype=dtype
)

linear_float8 = Float8DynamicLinear.from_float(
copy.deepcopy(linear_ref), emulate=False
)

bsz, seq_len = 4, 4096
M = bsz * seq_len
input_tensor = torch.randn(M, K, device=device, dtype=dtype, requires_grad=True)
ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward()

def float8_forw_backward():
out = linear_float8(input_tensor)
out.sum().backward()

def n_times(n, fn, *args, **kwargs):
def wrapper(*args, **kwargs):
for _ in range(n):
fn(*args, **kwargs)

return wrapper

REPEAT_N = 100

ref_forw_backward = n_times(REPEAT_N, ref_forw_backward)
float8_forw_backward = n_times(REPEAT_N, float8_forw_backward)

for _ in range(5):
ref_forw_backward()
float8_forw_backward()

ref_time = (
benchmark_torch_function_in_microseconds(ref_forw_backward)
* 1e-6
/ REPEAT_N
)
float8_time = (
benchmark_torch_function_in_microseconds(float8_forw_backward)
* 1e-6
/ REPEAT_N
)
experiment = Experiment(
name, (M, K, N), ref_time, float8_time, dtype, fuse_cast
)
print(experiment)
print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec)
experiment_list.append(experiment)
torch._dynamo.reset()

headers = [
"name",
"M",
"K",
"N",
"ref_dtype",
"fuse_cast",
"fp8_dtype",
"ref_time_sec",
"pt_fp8_time_sec",
"ref_tops_sec",
"ref_pct_top_peak",
"pt_fp8_tops_sec",
"pt_fp8_pct_top_peak",
]
data = []
for experiment in experiment_list:
data.append(
[
experiment.name,
experiment.shape[0],
experiment.shape[1],
experiment.shape[2],
experiment.dtype,
experiment.use_fused_cast,
experiment.float_8_dtype,
experiment.ref_time_sec,
experiment.float8_time_sec,
experiment.ref_tops_sec,
experiment.ref_pct_top_peak,
experiment.float8_tops_sec,
experiment.float8_pct_top_peak,
]
)

data_pd = pd.DataFrame(data, columns=headers)
data_pd["pt_fp8_speedup"] = data_pd["ref_time_sec"] / data_pd["pt_fp8_time_sec"]
data_pd["shape"] = (
"("
+ data_pd["M"].astype(str)
+ ", "
+ data_pd["K"].astype(str)
+ ", "
+ data_pd["N"].astype(str)
+ ")"
)

data_pd_simple = data_pd[
[
"shape",
"ref_dtype",
"fuse_cast",
"ref_time_sec",
"pt_fp8_time_sec",
"pt_fp8_speedup",
]
]
print(data_pd_simple)

    sweep_path = sweep_path.with_suffix(".csv")
    data_pd.to_csv(sweep_path)


def invoke_main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("-o", "--output_path", type=str, required=True)
parser.add_argument("-n", "--n_limit", type=int, required=False)
args = parser.parse_args()
output_path = Path(args.output_path)
main(output_path, args.n_limit)


if __name__ == "__main__":
invoke_main() # pragma: no cover
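For context, the utilization math in the Experiment dataclass above works out as follows. This is a sketch with an illustrative, not measured, timing; only the 3 * (2 * M * K * N) formula, the attn.wqkv shape, and the H100 peak number come from the file above.

# Worked example of the tops/sec and percent-of-peak math (illustrative timing).
M, K, N = 4 * 4096, 8192, 1280        # attn.wqkv shape with bsz=4, seq_len=4096
flops_per_step = 3 * 2 * M * K * N    # 3 GEMMs per step: forward, dgrad, wgrad
ref_time_sec = 1.0e-3                 # assumed 1 ms per forward+backward step
ref_tops_sec = flops_per_step / ref_time_sec   # ~1.03e15 FLOP/s
ref_pct_top_peak = ref_tops_sec / 1979e12      # vs. bf16 tensor-core peak, ~0.52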
5 changes: 5 additions & 0 deletions float8_experimental/config.py
@@ -19,3 +19,8 @@
# TODO(before land): add test coverage for both cases
# dynamic_use_activation_hooks = True
# dynamic_use_activation_hooks = False

# This is a global flag that controls whether the fused cast kernels are used.
# They can offer greater performance in eager mode, but it is recommended to
# set this flag to False if you are using torch.compile.
use_fused_cast = True
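For reference, use_fused_cast is plain module-level state, so callers toggle it by assignment before running float8 modules; the benchmark above does exactly this per configuration. A minimal sketch:

import float8_experimental.config as fp8_config

# Fused cast kernels on: typically faster in eager mode.
fp8_config.use_fused_cast = True

# Fused cast kernels off: the recommendation above when using torch.compile,
# so the cast stays visible to the compiler as plain PyTorch ops.
fp8_config.use_fused_cast = False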
28 changes: 24 additions & 4 deletions float8_experimental/float8_tensor.py
@@ -5,10 +5,15 @@
# LICENSE file in the root directory of this source tree.
from typing import Dict, Optional

import torch
import float8_experimental

from float8_experimental.float8_utils import tensor_to_amax, to_fp8_saturated
import torch

from float8_experimental.float8_utils import (
tensor_to_amax,
tensor_to_scale,
to_fp8_saturated,
)
from torch.distributed._tensor import DTensor

aten = torch.ops.aten
@@ -35,8 +40,23 @@ def to_fp8_no_autograd(
float8_dtype: the float8 dtype to use
emulate: whether to emulate the matmuls in fp32
"""
x_scaled = x * x_scale
bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype)
if (
float8_experimental.config.use_fused_cast
and x.is_cuda
and x.dtype in {torch.float32, torch.bfloat16}
):
from driss_torch import saturated_cast

if x.dim() in {3, 4}:
prev_x_shape = x.shape
x = x.reshape(-1, x.size(-1))
bits_fp8 = saturated_cast(x, x_scale, float8_dtype)
bits_fp8 = bits_fp8.reshape(prev_x_shape)
else:
bits_fp8 = saturated_cast(x, x_scale, float8_dtype)
else:
x_scaled = x * x_scale
bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype)

if isinstance(bits_fp8, DTensor):
assert isinstance(
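For context, the fused branch above is expected to match the unfused fallback semantically: scale the input, clamp ("saturate") into the float8 format's representable range, then cast. The sketch below is a reference for that behavior, assuming e4m3 with a max finite value of 448; saturated_cast itself lives in the external driss_torch extension and is not reproduced here.

import torch

E4M3_MAX_POS = 448.0  # largest finite value of torch.float8_e4m3fn

def reference_saturated_cast(x: torch.Tensor, x_scale: torch.Tensor) -> torch.Tensor:
    # Same effect as the fallback branch: to_fp8_saturated(x * x_scale, e4m3).
    x_scaled = x * x_scale
    # Clamp so out-of-range values saturate at the max finite value instead of
    # overflowing during the dtype conversion.
    x_clamped = x_scaled.clamp(min=-E4M3_MAX_POS, max=E4M3_MAX_POS)
    return x_clamped.to(torch.float8_e4m3fn)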
33 changes: 23 additions & 10 deletions float8_experimental/float8_utils.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import float8_experimental
import torch
import torch.distributed as dist

@@ -22,8 +23,10 @@


@torch.no_grad()
def amax_to_scale(amax, float8_dtype, orig_dtype):
scale = torch.empty_like(amax, dtype=torch.float32)
def amax_to_scale(
amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
):
assert amax.dtype == torch.float32, "amax must be a float32 tensor"
if float8_dtype == torch.float8_e4m3fn:
res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
else: # e5m2
@@ -34,16 +37,15 @@ def amax_to_scale(amax, float8_dtype, orig_dtype):
# to care about this for float32/bfloat16.
if orig_dtype is torch.float16:
res = torch.clamp(res, max=FP16_MAX_POS)
scale.copy_(res)
return scale
return res


@torch.no_grad()
def amax_history_to_scale(
amax_history,
float8_dtype,
orig_dtype,
history_to_scale_fn_type,
amax_history: torch.Tensor,
float8_dtype: torch.dtype,
orig_dtype: torch.dtype,
history_to_scale_fn_type: str,
):
if history_to_scale_fn_type == "max":
amax = torch.max(amax_history)
@@ -69,7 +71,12 @@ def amax_history_to_scale_stack(

@torch.no_grad()
def tensor_to_amax(x, distributed_reduction=False):
amax = torch.max(torch.abs(x))
if float8_experimental.config.use_fused_cast and x.is_cuda:
from float8_experimental.fused_kernels.fused_casting_kernels import abs_max

amax = abs_max(x)
else:
amax = x.abs().max().to(torch.float32)

# If the user asked for distributed reduction, do it.
# If the user did not ask for it, assume that it will
@@ -81,8 +88,14 @@


@torch.no_grad()
def tensor_to_scale(x, float8_dtype):
def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype):
amax = tensor_to_amax(x)
if float8_experimental.config.use_fused_cast and x.is_cuda:
from float8_experimental.fused_kernels.fused_casting_kernels import (
abs_max_to_scale,
)

return abs_max_to_scale(amax, float8_dtype, x.dtype == torch.float16)
return amax_to_scale(amax, float8_dtype, x.dtype)


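As a quick check of the scale math in amax_to_scale above, a small worked example; E4M3_MAX_POS and EPS are defined at the top of float8_utils.py, and the values used here are assumptions rather than copies of those constants.

import torch

E4M3_MAX_POS = 448.0  # assumed max finite value of float8_e4m3fn
EPS = 1e-12           # assumed small epsilon guarding against division by zero

amax = torch.tensor(2.0, dtype=torch.float32)
scale = E4M3_MAX_POS / torch.clamp(amax, min=EPS)  # tensor(224.)
# A tensor whose absolute values are at most 2.0, multiplied by this scale,
# just fills the e4m3 range: 2.0 * 224.0 == 448.0 == E4M3_MAX_POS.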