diff --git a/benchmarks/bench_dynamic_linear_fused_cast.py b/benchmarks/bench_dynamic_linear_fused_cast.py new file mode 100644 index 00000000..4259f2e7 --- /dev/null +++ b/benchmarks/bench_dynamic_linear_fused_cast.py @@ -0,0 +1,230 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +import argparse +import copy +from dataclasses import dataclass +from itertools import product +from pathlib import Path +from typing import Callable, List, Optional, Tuple + +import float8_experimental.config as fp8_config + +import pandas as pd + +import torch +import torch.utils.benchmark as benchmark +from float8_experimental.float8_dynamic_linear import Float8DynamicLinear +from tqdm import tqdm + +# estimating TOPs for matmuls in fp32, fp16, fp8 +# assuming A * B = C, with A being M * K, B being K * N, C being M * N + +# H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/ +h100_peak_flops_float32 = 67e12 +h100_peak_flops_fp16_tc = 1979e12 +h100_peak_tops_float8_tc = 3958e12 + +dtype_to_peak_tops = { + torch.float32: h100_peak_flops_float32, + torch.float16: h100_peak_flops_fp16_tc, + torch.bfloat16: h100_peak_flops_fp16_tc, + torch.float8_e4m3fn: h100_peak_tops_float8_tc, + torch.float8_e5m2: h100_peak_tops_float8_tc, +} + + +def benchmark_torch_function_in_microseconds( + func: Callable, + *args, + **kwargs, +) -> float: + t0 = benchmark.Timer( + stmt="func(*args, **kwargs)", + globals={"args": args, "kwargs": kwargs, "func": func}, + ) + return t0.blocked_autorange().median * 1e6 + + +@dataclass +class Experiment: + name: str + shape: Tuple[int, int, int] + ref_time_sec: float + float8_time_sec: float + dtype: torch.dtype + use_fused_cast: bool + float_8_dtype: Optional[torch.dtype] = torch.float8_e4m3fn + + # 3 Times since we are calculating forward backward + @property + def ref_tops_sec(self): + M, K, N = self.shape + return float(3 * (2 * M * K * N)) / self.ref_time_sec + + @property + def ref_pct_top_peak(self): + return self.ref_tops_sec / dtype_to_peak_tops[self.dtype] + + @property + def float8_tops_sec(self): + M, K, N = self.shape + return float(3 * (2 * M * K * N)) / self.float8_time_sec + + @property + def float8_pct_top_peak(self): + return self.float8_tops_sec / dtype_to_peak_tops[self.float_8_dtype] + + +def main( + sweep_path: Path, + n_limit: Optional[int] = None, +): + device = "cuda" + + # LLaMa 2 70B single-node weight shapes + # assumes fused attn.wqkv and ffn.w13 + name_to_shapes_70b = { + "attn.wqkv": (8192, 1280), + "attn.w0": (1024, 8192), + "ffn.w13": (8192, 7168), + "ffn.w2": (3584, 8192), + } + input_bias = False + ref_dtypes = [torch.bfloat16, torch.float32] + experiment_list: List[Experiment] = [] + fused_casts = [True, False] + for idx, (dtype, (name, (K, N)), fuse_cast) in enumerate( + tqdm(list(product(ref_dtypes, name_to_shapes_70b.items(), fused_casts))) + ): + fp8_config.use_fused_cast = fuse_cast + if n_limit is not None and idx >= n_limit: + break + linear_ref = torch.nn.Linear(K, N, bias=input_bias).to( + device=device, dtype=dtype + ) + + linear_float8 = Float8DynamicLinear.from_float( + copy.deepcopy(linear_ref), emulate=False + ) + + bsz, seq_len = 4, 4096 + M = bsz * seq_len + input_tensor = torch.randn(M, K, device=device, dtype=dtype, requires_grad=True) + ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward() + + def float8_forw_backward(): + out = 
linear_float8(input_tensor) + out.sum().backward() + + def n_times(n, fn, *args, **kwargs): + def wrapper(*args, **kwargs): + for _ in range(n): + fn(*args, **kwargs) + + return wrapper + + REPEAT_N = 100 + + ref_forw_backward = n_times(REPEAT_N, ref_forw_backward) + float8_forw_backward = n_times(REPEAT_N, float8_forw_backward) + + for _ in range(5): + ref_forw_backward() + float8_forw_backward() + + ref_time = ( + benchmark_torch_function_in_microseconds(ref_forw_backward) + * 1e-6 + / REPEAT_N + ) + float8_time = ( + benchmark_torch_function_in_microseconds(float8_forw_backward) + * 1e-6 + / REPEAT_N + ) + experiment = Experiment( + name, (M, K, N), ref_time, float8_time, dtype, fuse_cast + ) + print(experiment) + print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec) + experiment_list.append(experiment) + torch._dynamo.reset() + + headers = [ + "name", + "M", + "K", + "N", + "ref_dtype", + "fuse_cast", + "fp8_dtype", + "ref_time_sec", + "pt_fp8_time_sec", + "ref_tops_sec", + "ref_pct_top_peak", + "pt_fp8_tops_sec", + "pt_fp8_pct_top_peak", + ] + data = [] + for experiment in experiment_list: + data.append( + [ + experiment.name, + experiment.shape[0], + experiment.shape[1], + experiment.shape[2], + experiment.dtype, + experiment.use_fused_cast, + experiment.float_8_dtype, + experiment.ref_time_sec, + experiment.float8_time_sec, + experiment.ref_tops_sec, + experiment.ref_pct_top_peak, + experiment.float8_tops_sec, + experiment.float8_pct_top_peak, + ] + ) + + data_pd = pd.DataFrame(data, columns=headers) + data_pd["pt_fp8_speedup"] = data_pd["ref_time_sec"] / data_pd["pt_fp8_time_sec"] + data_pd["shape"] = ( + "(" + + data_pd["M"].astype(str) + + ", " + + data_pd["K"].astype(str) + + ", " + + data_pd["N"].astype(str) + + ")" + ) + + data_pd_simple = data_pd[ + [ + "shape", + "ref_dtype", + "fuse_cast", + "ref_time_sec", + "pt_fp8_time_sec", + "pt_fp8_speedup", + ] + ] + print(data_pd_simple) + + sweep_path = sweep_path.with_suffix(".csv") + with open(sweep_path, mode="w") as file: + data_pd.to_csv(sweep_path) + + +def invoke_main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("-o", "--output_path", type=str, required=True) + parser.add_argument("-n", "--n_limit", type=int, required=False) + args = parser.parse_args() + output_path = Path(args.output_path) + main(output_path, args.n_limit) + + +if __name__ == "__main__": + invoke_main() # pragma: no cover diff --git a/float8_experimental/config.py b/float8_experimental/config.py index 9df065bc..eb7ba22a 100644 --- a/float8_experimental/config.py +++ b/float8_experimental/config.py @@ -19,3 +19,8 @@ # TODO(before land): add test coverage for both cases # dynamic_use_activation_hooks = True # dynamic_use_activation_hooks = False + +# This is a global flag that controls whether the fused_cast kernels, +# This can offer greater performance in eager but it is still recommended +# That if you are using torch.compile to set this to False. +use_fused_cast = True diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py index 3647e185..2f65abec 100644 --- a/float8_experimental/float8_tensor.py +++ b/float8_experimental/float8_tensor.py @@ -5,10 +5,15 @@ # LICENSE file in the root directory of this source tree. 
from typing import Dict, Optional -import torch +import float8_experimental -from float8_experimental.float8_utils import tensor_to_amax, to_fp8_saturated +import torch +from float8_experimental.float8_utils import ( + tensor_to_amax, + tensor_to_scale, + to_fp8_saturated, +) from torch.distributed._tensor import DTensor aten = torch.ops.aten @@ -35,8 +40,23 @@ def to_fp8_no_autograd( float8_dtype: the float8 dtype to use emulate: whether to emulate the matmuls in fp32 """ - x_scaled = x * x_scale - bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype) + if ( + float8_experimental.config.use_fused_cast + and x.is_cuda + and x.dtype in {torch.float32, torch.bfloat16} + ): + from driss_torch import saturated_cast + + if x.dim() in {3, 4}: + prev_x_shape = x.shape + x = x.reshape(-1, x.size(-1)) + bits_fp8 = saturated_cast(x, x_scale, float8_dtype) + bits_fp8 = bits_fp8.reshape(prev_x_shape) + else: + bits_fp8 = saturated_cast(x, x_scale, float8_dtype) + else: + x_scaled = x * x_scale + bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype) if isinstance(bits_fp8, DTensor): assert isinstance( diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py index e31aa885..00f0ed03 100644 --- a/float8_experimental/float8_utils.py +++ b/float8_experimental/float8_utils.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +import float8_experimental import torch import torch.distributed as dist @@ -22,8 +23,10 @@ @torch.no_grad() -def amax_to_scale(amax, float8_dtype, orig_dtype): - scale = torch.empty_like(amax, dtype=torch.float32) +def amax_to_scale( + amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype +): + assert amax.dtype == torch.float32, "amax must be a float32 tensor" if float8_dtype == torch.float8_e4m3fn: res = E4M3_MAX_POS / torch.clamp(amax, min=EPS) else: # e5m2 @@ -34,16 +37,15 @@ def amax_to_scale(amax, float8_dtype, orig_dtype): # to care about this for float32/bfloat16. if orig_dtype is torch.float16: res = torch.clamp(res, max=FP16_MAX_POS) - scale.copy_(res) - return scale + return res @torch.no_grad() def amax_history_to_scale( - amax_history, - float8_dtype, - orig_dtype, - history_to_scale_fn_type, + amax_history: torch.Tensor, + float8_dtype: torch.dtype, + orig_dtype: torch.dtype, + history_to_scale_fn_type: str, ): if history_to_scale_fn_type == "max": amax = torch.max(amax_history) @@ -69,7 +71,12 @@ def amax_history_to_scale_stack( @torch.no_grad() def tensor_to_amax(x, distributed_reduction=False): - amax = torch.max(torch.abs(x)) + if float8_experimental.config.use_fused_cast and x.is_cuda: + from float8_experimental.fused_kernels.fused_casting_kernels import abs_max + + amax = abs_max(x) + else: + amax = x.abs().max().to(torch.float32) # If the user asked for distributed reduction, do it. 
# If the user did not ask for it, assume that it will @@ -81,8 +88,14 @@ def tensor_to_amax(x, distributed_reduction=False): @torch.no_grad() -def tensor_to_scale(x, float8_dtype): +def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype): amax = tensor_to_amax(x) + if float8_experimental.config.use_fused_cast and x.is_cuda: + from float8_experimental.fused_kernels.fused_casting_kernels import ( + abs_max_to_scale, + ) + + return abs_max_to_scale(amax, float8_dtype, x.dtype == torch.float16) return amax_to_scale(amax, float8_dtype, x.dtype) diff --git a/float8_experimental/fused_kernels/__init__.py b/float8_experimental/fused_kernels/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/float8_experimental/fused_kernels/fused_casting_kernels.py b/float8_experimental/fused_kernels/fused_casting_kernels.py new file mode 100644 index 00000000..06bb41d8 --- /dev/null +++ b/float8_experimental/fused_kernels/fused_casting_kernels.py @@ -0,0 +1,128 @@ +import torch +import triton +import triton.language as tl +from triton import next_power_of_2 + + +E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max +E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max +FP16_MAX_POS = torch.finfo(torch.float16).max +EPS = 1e-12 + + +@triton.jit +def promote_to_tensor(x): + # Addition promotes to tensor for us + return x + tl.zeros((1,), tl.int1) + + +@triton.jit +def is_floating(x): + return promote_to_tensor(x).dtype.is_floating() + + +@triton.jit +def maximum(a, b): + mask = a > b + if is_floating(a): + mask |= a != a + return tl.where(mask, a, b) + + +@triton.jit +def abs_max_kernel( + x_ptr, + out_ptr, + x_numel: int, + r_numel: int, + X_BLOCK_SIZE: tl.constexpr, + R_BLOCK_SIZE: tl.constexpr, +): + x_offset = tl.program_id(0) * X_BLOCK_SIZE + x_index = x_offset + tl.arange(0, X_BLOCK_SIZE)[:, None] + x_mask = x_index < x_numel + reduction_base = tl.arange(0, R_BLOCK_SIZE)[None, :] + acc = tl.full([X_BLOCK_SIZE, R_BLOCK_SIZE], -float("inf"), tl.float32) + for r_offset in range(0, r_numel, R_BLOCK_SIZE): + r_index = r_offset + reduction_base + r_mask = r_index < r_numel + values = tl.load( + x_ptr + (r_index + (r_numel * x_index)), + r_mask, + eviction_policy="evict_last", + other=0.0, + ).to(tl.float32) + x_abs = tl.abs(values) + x_abs_broadcasted = tl.broadcast_to(x_abs, [X_BLOCK_SIZE, R_BLOCK_SIZE]) + acc_mask = maximum(acc, x_abs_broadcasted) + acc = tl.where(x_mask, acc_mask, acc) + out = tl.max(acc, 1)[:, None] + tl.store(out_ptr + x_index, out.to(tl.float32), x_mask) + + +def abs_max(x: torch.Tensor) -> torch.Tensor: + """Calculates the global max of the absolute values of a tensor + + This kernel launches a grid of 512 threads, each thread calculates the + maximum of x.numel // 512 elements. The results are then reduced to a single + value in a follow up kernel. 
+ + Args: + x: Input tensor to calculate the abs_max for + """ + x = x.contiguous() + if x.numel() % 512 == 0: + output = torch.full( + (512, 1), -float("inf"), device=x.device, dtype=torch.float32 + ) + grid = lambda meta: (512,) + X_BLOCK_SIZE = 1 + R_BLOCK_SIZE = 1024 + r_numel = x.numel() // 512 + abs_max_kernel[grid]( + x, + output, + x_numel=512, + r_numel=r_numel, + X_BLOCK_SIZE=X_BLOCK_SIZE, + R_BLOCK_SIZE=R_BLOCK_SIZE, + ) + return output.max() + else: + return x.abs().max().to(torch.float32) + + +@triton.jit +def abs_max_to_scale_kernel_e4m3(x_ptr, out_ptr, clamp_float16): + abs_max = tl.load(x_ptr).to(tl.float32) + clamped = E4M3_MAX_POS / tl.clamp(abs_max, min=EPS, max=float("inf")) + if clamp_float16: + clamped = tl.clamp(clamped, min=EPS, max=FP16_MAX_POS) + tl.store(out_ptr, clamped) + + +@triton.jit +def abs_max_to_scale_kernel_e5m2(x_ptr, out_ptr, clamp_float16): + abs_max = tl.load(x_ptr) + clamped = E5M2_MAX_POS / tl.clamp(abs_max, min=EPS, max=float("inf")) + if clamp_float16: + clamped = tl.clamp(clamped, min=EPS, max=FP16_MAX_POS) + tl.store(out_ptr, clamped) + + +def abs_max_to_scale( + x: torch.Tensor, fp8_dtype: torch.dtype, clamp_float16: bool +) -> torch.Tensor: + assert x.numel() == 1, "Expected a single value, but got: {} elements".format( + x.numel() + ) + assert x.dtype == torch.float32, "Expected a float32 tensor, but got: {}".format( + x.dtype + ) + output = torch.empty((), device=x.device, dtype=torch.float32) + grid = lambda meta: (1,) + if fp8_dtype == torch.float8_e4m3fn: + abs_max_to_scale_kernel_e4m3[grid](x, output, clamp_float16) + else: + abs_max_to_scale_kernel_e5m2[grid](x, output, clamp_float16) + return output diff --git a/test/test_fused_kernels.py b/test/test_fused_kernels.py new file mode 100644 index 00000000..06d27104 --- /dev/null +++ b/test/test_fused_kernels.py @@ -0,0 +1,39 @@ +from typing import Tuple + +import pytest +import torch +from float8_experimental.float8_utils import amax_to_scale, tensor_to_amax +from float8_experimental.fused_kernels.fused_casting_kernels import ( + abs_max, + abs_max_to_scale, +) + + +@pytest.mark.parametrize( + "shape", + [ + (16384, 1024), + (16384, 8192), + (16384, 3584), + (8192, 1280), + (1024, 8192), + (8192, 7168), + (3584, 8192), + ], +) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) +def test_abs_max(shape: Tuple[int], dtype: torch.dtype): + x = torch.randn(shape, dtype=dtype, device="cuda") + max_abs = abs_max(x) + assert torch.allclose(max_abs, tensor_to_amax(x)) + + +@pytest.mark.parametrize("numel", [2**i for i in range(10, 20)]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("fp8_type", [torch.float8_e4m3fn, torch.float8_e5m2]) +def test_amax_to_scale(numel: int, dtype: torch.dtype, fp8_type: torch.dtype): + x = torch.randn(numel, dtype=dtype, device="cuda") + max_abs = abs_max(x) + fused = abs_max_to_scale(max_abs, fp8_type, dtype == torch.float16) + eager = amax_to_scale(max_abs, fp8_type, dtype) + assert torch.allclose(fused, eager)
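
The factor of 3 in Experiment.ref_tops_sec and Experiment.float8_tops_sec counts the backward matmuls as well as the forward one. A minimal sketch of that accounting; linear_fwd_bwd_flops is a hypothetical helper name used only for illustration:

def linear_fwd_bwd_flops(M: int, K: int, N: int) -> float:
    # Forward: one (M, K) @ (K, N) matmul, roughly 2 * M * K * N FLOPs.
    # Backward: grad_input = grad_out @ W and grad_weight = grad_out.T @ x are
    # two more matmuls of the same cost, hence the 3x multiplier used when
    # converting the measured forward+backward time into TOPS.
    return float(3 * 2 * M * K * N)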
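
config.use_fused_cast is read at cast time inside to_fp8_no_autograd, tensor_to_amax and tensor_to_scale, so it can be flipped per run the way the benchmark does. A usage sketch under the assumptions stated in config.py (fused casts help in eager mode, and are best disabled under torch.compile); the module shape and batch size are illustrative only:

import copy

import torch

import float8_experimental.config as fp8_config
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

m = torch.nn.Linear(1024, 4096, bias=False).to(device="cuda", dtype=torch.bfloat16)

# Eager: route dynamic scaling and casting through the fused kernels.
fp8_config.use_fused_cast = True
fp8_eager = Float8DynamicLinear.from_float(copy.deepcopy(m), emulate=False)
y = fp8_eager(torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16))

# torch.compile: per the note in config.py, fall back to the plain PyTorch
# cast path and let the compiler generate its own fused casts.
fp8_config.use_fused_cast = False
fp8_compiled = torch.compile(
    Float8DynamicLinear.from_float(copy.deepcopy(m), emulate=False)
)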
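
The fused branch of to_fp8_no_autograd calls driss_torch.saturated_cast (flattening 3D/4D inputs to 2D first and restoring the shape afterwards), while the else branch keeps the original x * x_scale followed by to_fp8_saturated. A minimal sketch of the semantics both branches are expected to share, assuming to_fp8_saturated clamps to the finite range of the target dtype before casting; saturated_cast_reference is an illustrative name, not a library API:

import torch

def saturated_cast_reference(
    x: torch.Tensor, x_scale: torch.Tensor, float8_dtype: torch.dtype
) -> torch.Tensor:
    # Scale into the fp8 dynamic range, saturate at the largest finite value of
    # the target dtype instead of overflowing, then cast down to fp8.
    max_pos = torch.finfo(float8_dtype).max
    x_scaled = x * x_scale  # x_scale is float32, so the product promotes to float32
    return x_scaled.clamp(min=-max_pos, max=max_pos).to(float8_dtype)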
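
amax_to_scale and the two abs_max_to_scale_kernel_* Triton kernels implement the same scalar math. A reference sketch of that math, with constants mirroring fused_casting_kernels.py; amax_to_scale_reference is an illustrative name only:

import torch

E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
FP16_MAX_POS = torch.finfo(torch.float16).max
EPS = 1e-12

def amax_to_scale_reference(
    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
) -> torch.Tensor:
    assert amax.dtype == torch.float32, "amax is produced in float32"
    max_pos = E4M3_MAX_POS if float8_dtype == torch.float8_e4m3fn else E5M2_MAX_POS
    # scale = fp8_max / amax, flooring amax at EPS so an all-zero tensor cannot
    # produce an inf scale.
    scale = max_pos / torch.clamp(amax, min=EPS)
    # Keep the scale representable in float16 when the original tensor is fp16;
    # float32/bfloat16 do not need this clamp.
    if orig_dtype is torch.float16:
        scale = torch.clamp(scale, max=FP16_MAX_POS)
    return scale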
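
abs_max_kernel plus the trailing output.max() form a two-stage reduction: 512 programs each reduce a contiguous slice of x.numel() // 512 elements to a partial abs-max (walking the slice in R_BLOCK_SIZE chunks), and the 512 partials are then reduced by a follow-up max. A rough pure-PyTorch model of that strategy, not the kernel itself; abs_max_two_stage_reference is an illustrative name:

import torch

def abs_max_two_stage_reference(x: torch.Tensor, num_programs: int = 512) -> torch.Tensor:
    # abs_max() falls back to x.abs().max() when numel is not divisible by 512;
    # this sketch only models the divisible case.
    assert x.numel() % num_programs == 0
    # Stage 1: one partial abs-max per program over its contiguous slice.
    partials = x.contiguous().view(num_programs, -1).abs().amax(dim=1).to(torch.float32)
    # Stage 2: reduce the partials to the global abs-max (a 0-dim float32 tensor).
    return partials.max()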