From 88f7a8715b0a02beab0e669c8fa7675ccb31f014 Mon Sep 17 00:00:00 2001 From: drisspg Date: Mon, 26 Feb 2024 14:36:18 -0800 Subject: [PATCH 01/11] add option for using fused kernel --- float8_experimental/config.py | 2 ++ float8_experimental/float8_tensor.py | 18 ++++++++++++++---- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/float8_experimental/config.py b/float8_experimental/config.py index 9df065bc..8fbbb092 100644 --- a/float8_experimental/config.py +++ b/float8_experimental/config.py @@ -19,3 +19,5 @@ # TODO(before land): add test coverage for both cases # dynamic_use_activation_hooks = True # dynamic_use_activation_hooks = False + +use_fused_cast = True diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py index 3647e185..3f9012c7 100644 --- a/float8_experimental/float8_tensor.py +++ b/float8_experimental/float8_tensor.py @@ -5,10 +5,15 @@ # LICENSE file in the root directory of this source tree. from typing import Dict, Optional -import torch +import float8_experimental.config as fp8_config -from float8_experimental.float8_utils import tensor_to_amax, to_fp8_saturated +import torch +from float8_experimental.float8_utils import ( + tensor_to_amax, + tensor_to_scale, + to_fp8_saturated, +) from torch.distributed._tensor import DTensor aten = torch.ops.aten @@ -35,8 +40,13 @@ def to_fp8_no_autograd( float8_dtype: the float8 dtype to use emulate: whether to emulate the matmuls in fp32 """ - x_scaled = x * x_scale - bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype) + if fp8_config.use_fused_cast and x.device == "cuda": + from driss_torch import saturated_cast + + bits_fp8 = saturated_cast(x, float8_dtype, x_scale) + else: + x_scaled = x * x_scale + bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype) if isinstance(bits_fp8, DTensor): assert isinstance( From fe3485b1f97d15e3380c2739184d974d9e76413f Mon Sep 17 00:00:00 2001 From: drisspg Date: Mon, 26 Feb 2024 18:51:30 -0800 Subject: [PATCH 02/11] more support --- benchmarks/bench_dynamic_linear_fused_cast.py | 233 ++++++++++++++++++ float8_experimental/float8_tensor.py | 5 +- 2 files changed, 235 insertions(+), 3 deletions(-) create mode 100644 benchmarks/bench_dynamic_linear_fused_cast.py diff --git a/benchmarks/bench_dynamic_linear_fused_cast.py b/benchmarks/bench_dynamic_linear_fused_cast.py new file mode 100644 index 00000000..c82109f8 --- /dev/null +++ b/benchmarks/bench_dynamic_linear_fused_cast.py @@ -0,0 +1,233 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
+import argparse +import copy +from dataclasses import dataclass +from itertools import product +from pathlib import Path +from typing import Callable, List, Optional, Tuple + +import pandas as pd + +import torch +import torch.utils.benchmark as benchmark +from float8_experimental.float8_dynamic_linear import Float8DynamicLinear +import float8_experimental.config as fp8_config +from tqdm import tqdm + +# estimating TOPs for matmuls in fp32, fp16, fp8 +# assuming A * B = C, with A being M * K, B being K * N, C being M * N + +# H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/ +h100_peak_flops_float32 = 67e12 +h100_peak_flops_fp16_tc = 1979e12 +h100_peak_tops_float8_tc = 3958e12 + +dtype_to_peak_tops = { + torch.float32: h100_peak_flops_float32, + torch.float16: h100_peak_flops_fp16_tc, + torch.bfloat16: h100_peak_flops_fp16_tc, + torch.float8_e4m3fn: h100_peak_tops_float8_tc, + torch.float8_e5m2: h100_peak_tops_float8_tc, +} + + +def benchmark_torch_function_in_microseconds( + func: Callable, + *args, + **kwargs, +) -> float: + t0 = benchmark.Timer( + stmt="func(*args, **kwargs)", + globals={"args": args, "kwargs": kwargs, "func": func}, + ) + return t0.blocked_autorange().median * 1e6 + + +@dataclass +class Experiment: + name: str + shape: Tuple[int, int, int] + ref_time_sec: float + float8_time_sec: float + dtype: torch.dtype + use_fused_cast: bool + float_8_dtype: Optional[torch.dtype] = torch.float8_e4m3fn + + # 3 Times since we are calculating forward backward + @property + def ref_tops_sec(self): + M, K, N = self.shape + return float(3 * (2 * M * K * N)) / self.ref_time_sec + + @property + def ref_pct_top_peak(self): + return self.ref_tops_sec / dtype_to_peak_tops[self.dtype] + + @property + def float8_tops_sec(self): + M, K, N = self.shape + return float(3 * (2 * M * K * N)) / self.float8_time_sec + + @property + def float8_pct_top_peak(self): + return self.float8_tops_sec / dtype_to_peak_tops[self.float_8_dtype] + + +def main( + sweep_path: Path, + n_limit: Optional[int] = None, +): + device = "cuda" + + # LLaMa 2 70B single-node weight shapes + # assumes fused attn.wqkv and ffn.w13 + name_to_shapes_70b = { + "attn.wqkv": (8192, 1280), + "attn.w0": (1024, 8192), + "ffn.w13": (8192, 7168), + "ffn.w2": (3584, 8192), + } + input_bias = False + ref_dtypes = [torch.bfloat16, torch.float32] + experiment_list: List[Experiment] = [] + fused_casts = [True, False] + for idx, (dtype, (name, (K, N)), fuse_cast) in enumerate( + tqdm(list(product(ref_dtypes, name_to_shapes_70b.items(), fused_casts))) + ): + fp8_config.use_fused_cast = fuse_cast + if n_limit is not None and idx >= n_limit: + break + linear_ref = torch.nn.Linear(K, N, bias=input_bias).to( + device=device, dtype=dtype + ) + + linear_float8 = Float8DynamicLinear.from_float( + copy.deepcopy(linear_ref), emulate=False + ) + + bsz, seq_len = 4, 4096 + M = bsz * seq_len + input_tensor = torch.randn(M, K, device=device, dtype=dtype, requires_grad=True) + ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward() + + def float8_forw_backward(): + linear_float8(input_tensor).sum().backward() + + def n_times(n, fn, *args, **kwargs): + def wrapper(*args, **kwargs): + for _ in range(n): + fn(*args, **kwargs) + + return wrapper + + REPEAT_N = 100 + + ref_forw_backward = n_times(REPEAT_N, ref_forw_backward) + float8_forw_backward = n_times(REPEAT_N, float8_forw_backward) + + for _ in range(5): + ref_forw_backward() + float8_forw_backward() + + ref_time = ( + 
benchmark_torch_function_in_microseconds(ref_forw_backward) + * 1e-6 + / REPEAT_N + ) + float8_time = ( + benchmark_torch_function_in_microseconds(float8_forw_backward) + * 1e-6 + / REPEAT_N + ) + experiment = Experiment( + name, + (M, K, N), + ref_time, + float8_time, + dtype, + fuse_cast + ) + print(experiment) + print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec) + experiment_list.append(experiment) + torch._dynamo.reset() + + headers = [ + "name", + "M", + "K", + "N", + "ref_dtype", + "fuse_cast", + "fp8_dtype", + "ref_time_sec", + "pt_fp8_time_sec", + "ref_tops_sec", + "ref_pct_top_peak", + "pt_fp8_tops_sec", + "pt_fp8_pct_top_peak", + ] + data = [] + for experiment in experiment_list: + data.append( + [ + experiment.name, + experiment.shape[0], + experiment.shape[1], + experiment.shape[2], + experiment.dtype, + experiment.use_fused_cast, + experiment.float_8_dtype, + experiment.ref_time_sec, + experiment.float8_time_sec, + experiment.ref_tops_sec, + experiment.ref_pct_top_peak, + experiment.float8_tops_sec, + experiment.float8_pct_top_peak, + ] + ) + + data_pd = pd.DataFrame(data, columns=headers) + data_pd["pt_fp8_speedup"] = data_pd["ref_time_sec"] / data_pd["pt_fp8_time_sec"] + data_pd["shape"] = ( + "(" + + data_pd["M"].astype(str) + + ", " + + data_pd["K"].astype(str) + + ", " + + data_pd["N"].astype(str) + + ")" + ) + + data_pd_simple = data_pd[ + [ + "shape", + "ref_dtype", + "fuse_cast", + "ref_time_sec", + "pt_fp8_time_sec", + "pt_fp8_speedup", + ] + ] + print(data_pd_simple) + + sweep_path = sweep_path.with_suffix(".csv") + with open(sweep_path, mode="w") as file: + data_pd.to_csv(sweep_path) + + +def invoke_main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("-o", "--output_path", type=str, required=True) + parser.add_argument("-n", "--n_limit", type=int, required=False) + args = parser.parse_args() + output_path = Path(args.output_path) + main(output_path, args.n_limit) + + +if __name__ == "__main__": + invoke_main() # pragma: no cover diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py index 3f9012c7..059e058e 100644 --- a/float8_experimental/float8_tensor.py +++ b/float8_experimental/float8_tensor.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
from typing import Dict, Optional -import float8_experimental.config as fp8_config +import float8_experimental import torch @@ -40,9 +40,8 @@ def to_fp8_no_autograd( float8_dtype: the float8 dtype to use emulate: whether to emulate the matmuls in fp32 """ - if fp8_config.use_fused_cast and x.device == "cuda": + if float8_experimental.config.use_fused_cast and x.is_cuda and x.dtype in {torch.float32, torch.bfloat16}: from driss_torch import saturated_cast - bits_fp8 = saturated_cast(x, float8_dtype, x_scale) else: x_scaled = x * x_scale From d61b1f8bd243bad72f889e306fa583beebbf4211 Mon Sep 17 00:00:00 2001 From: drisspg Date: Tue, 27 Feb 2024 12:41:57 -0800 Subject: [PATCH 03/11] add two more kernels --- benchmarks/bench_dynamic_linear_fused_cast.py | 10 +-- float8_experimental/float8_tensor.py | 9 ++- float8_experimental/float8_utils.py | 14 +++- float8_experimental/fused_kernels/__init__.py | 0 .../fused_kernels/fused_casting_kernels.py | 71 +++++++++++++++++++ test/test_fused_kernels.py | 26 +++++++ 6 files changed, 120 insertions(+), 10 deletions(-) create mode 100644 float8_experimental/fused_kernels/__init__.py create mode 100644 float8_experimental/fused_kernels/fused_casting_kernels.py create mode 100644 test/test_fused_kernels.py diff --git a/benchmarks/bench_dynamic_linear_fused_cast.py b/benchmarks/bench_dynamic_linear_fused_cast.py index c82109f8..0f2a7a76 100644 --- a/benchmarks/bench_dynamic_linear_fused_cast.py +++ b/benchmarks/bench_dynamic_linear_fused_cast.py @@ -10,12 +10,13 @@ from pathlib import Path from typing import Callable, List, Optional, Tuple +import float8_experimental.config as fp8_config + import pandas as pd import torch import torch.utils.benchmark as benchmark from float8_experimental.float8_dynamic_linear import Float8DynamicLinear -import float8_experimental.config as fp8_config from tqdm import tqdm # estimating TOPs for matmuls in fp32, fp16, fp8 @@ -144,12 +145,7 @@ def wrapper(*args, **kwargs): / REPEAT_N ) experiment = Experiment( - name, - (M, K, N), - ref_time, - float8_time, - dtype, - fuse_cast + name, (M, K, N), ref_time, float8_time, dtype, fuse_cast ) print(experiment) print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec) diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py index 059e058e..a2292a32 100644 --- a/float8_experimental/float8_tensor.py +++ b/float8_experimental/float8_tensor.py @@ -40,9 +40,14 @@ def to_fp8_no_autograd( float8_dtype: the float8 dtype to use emulate: whether to emulate the matmuls in fp32 """ - if float8_experimental.config.use_fused_cast and x.is_cuda and x.dtype in {torch.float32, torch.bfloat16}: + if ( + float8_experimental.config.use_fused_cast + and x.is_cuda + and x.dtype in {torch.float32, torch.bfloat16} + ): from driss_torch import saturated_cast - bits_fp8 = saturated_cast(x, float8_dtype, x_scale) + + bits_fp8 = saturated_cast(x, x_scale, float8_dtype) else: x_scaled = x * x_scale bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype) diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py index e31aa885..14724e48 100644 --- a/float8_experimental/float8_utils.py +++ b/float8_experimental/float8_utils.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. 
+import float8_experimental import torch import torch.distributed as dist @@ -69,7 +70,12 @@ def amax_history_to_scale_stack( @torch.no_grad() def tensor_to_amax(x, distributed_reduction=False): - amax = torch.max(torch.abs(x)) + if float8_experimental.config.use_fused_cast and x.is_cuda: + from float8_experimental.fused_kernels.fused_casting_kernels import abs_max + + amax = abs_max(x) + else: + amax = x.abs().max() # If the user asked for distributed reduction, do it. # If the user did not ask for it, assume that it will @@ -83,6 +89,12 @@ def tensor_to_amax(x, distributed_reduction=False): @torch.no_grad() def tensor_to_scale(x, float8_dtype): amax = tensor_to_amax(x) + if float8_experimental.config.use_fused_cast and x.is_cuda: + from float8_experimental.fused_kernels.fused_casting_kernels import ( + abs_max_to_scale, + ) + + return abs_max_to_scale(amax, float8_dtype, x.dtype == torch.float16) return amax_to_scale(amax, float8_dtype, x.dtype) diff --git a/float8_experimental/fused_kernels/__init__.py b/float8_experimental/fused_kernels/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/float8_experimental/fused_kernels/fused_casting_kernels.py b/float8_experimental/fused_kernels/fused_casting_kernels.py new file mode 100644 index 00000000..2d2b3f94 --- /dev/null +++ b/float8_experimental/fused_kernels/fused_casting_kernels.py @@ -0,0 +1,71 @@ +import torch +import triton +import triton.language as tl +from triton import next_power_of_2 + + +E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max +E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max +FP16_MAX_POS = torch.finfo(torch.float16).max +EPS = 1e-12 + + +@triton.jit +def abs_max_kernel(x_ptr, out_ptr, n_elements: int, BLOCK_SIZE: tl.constexpr): + offset_base = tl.arange(0, BLOCK_SIZE)[None, :] + acc = tl.full([1, BLOCK_SIZE], -float("inf"), tl.float32) + for offset in range(0, n_elements, BLOCK_SIZE): + index = offset + offset_base + mask = index < n_elements + x = tl.load(x_ptr + index, mask, eviction_policy="evict_first", other=0.0) + x_broadcast = tl.broadcast_to(x, [1, BLOCK_SIZE]) + x_abs = tl.abs(x_broadcast) + acc = tl.maximum(acc, x_abs) + out = tl.max(acc, 1)[:, None] + tl.store(out_ptr + (tl.full([1, 1], 0, tl.int32)), out.to(tl.float32)) + + +def abs_max(x: torch.Tensor) -> torch.Tensor: + "Calculates the global max of the absolute values of a tensor" + output = torch.empty((), device=x.device, dtype=torch.float32) + n_elements = x.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + BLOCK_SIZE = 1024 + abs_max_kernel[grid](x, output, n_elements=n_elements, BLOCK_SIZE=BLOCK_SIZE) + return output + + +@triton.jit +def abs_max_to_scale_kernel_e4m3(x_ptr, out_ptr, clamp_float16): + abs_max = tl.load(x_ptr).to(tl.float32) + clamped = E4M3_MAX_POS / tl.clamp(abs_max, min=EPS, max=float("inf")) + if clamp_float16: + clamped = tl.clamp(clamped, min=EPS, max=FP16_MAX_POS) + tl.store(out_ptr, clamped) + + +@triton.jit +def abs_max_to_scale_kernel_e5m2(x_ptr, out_ptr, clamp_float16): + abs_max = tl.load(x_ptr) + clamped = E5M2_MAX_POS / tl.clamp(abs_max, min=EPS, max=float("inf")) + if clamp_float16: + clamped = tl.clamp(clamped, min=EPS, max=FP16_MAX_POS) + tl.store(out_ptr, clamped) + + +def abs_max_to_scale( + x: torch.Tensor, fp8_dtype: torch.dtype, clamp_float16: bool +) -> torch.Tensor: + assert x.numel() == 1, "Expected a single value, but got: {} elements".format( + x.numel() + ) + assert x.dtype == torch.float32, "Expected a float32 tensor, but got: {}".format( + x.dtype + ) + 
output = torch.empty((), device=x.device, dtype=torch.float32) + grid = lambda meta: (1,) + if fp8_dtype == torch.float8_e4m3fn: + abs_max_to_scale_kernel_e4m3[grid](x, output, clamp_float16) + else: + abs_max_to_scale_kernel_e5m2[grid](x, output, clamp_float16) + return output diff --git a/test/test_fused_kernels.py b/test/test_fused_kernels.py new file mode 100644 index 00000000..d4853740 --- /dev/null +++ b/test/test_fused_kernels.py @@ -0,0 +1,26 @@ +import pytest +import torch +from float8_experimental.float8_utils import amax_to_scale, tensor_to_amax +from float8_experimental.fused_kernels.fused_casting_kernels import ( + abs_max, + abs_max_to_scale, +) + + +@pytest.mark.parametrize("numel", [2**i for i in range(10, 20)]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) +def test_abs_max(numel: int, dtype: torch.dtype): + x = torch.randn(numel, dtype=dtype, device="cuda") + max_abs = abs_max(x) + assert torch.allclose(max_abs, tensor_to_amax(x)) + + +@pytest.mark.parametrize("numel", [2**i for i in range(10, 20)]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("fp8_type", [torch.float8_e4m3fn, torch.float8_e5m2]) +def test_amax_to_scale(numel: int, dtype: torch.dtype, fp8_type: torch.dtype): + x = torch.randn(numel, dtype=dtype, device="cuda") + max_abs = abs_max(x) + fused = abs_max_to_scale(max_abs, fp8_type, dtype == torch.float16) + eager = amax_to_scale(max_abs, fp8_type, dtype) + assert torch.allclose(fused, eager) From 548dceb25fc667c72d3d1a3d0a2be2fc2a5f1650 Mon Sep 17 00:00:00 2001 From: drisspg Date: Tue, 27 Feb 2024 12:57:36 -0800 Subject: [PATCH 04/11] change the test shapes --- test/test_fused_kernels.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/test/test_fused_kernels.py b/test/test_fused_kernels.py index d4853740..06d27104 100644 --- a/test/test_fused_kernels.py +++ b/test/test_fused_kernels.py @@ -1,3 +1,5 @@ +from typing import Tuple + import pytest import torch from float8_experimental.float8_utils import amax_to_scale, tensor_to_amax @@ -7,10 +9,21 @@ ) -@pytest.mark.parametrize("numel", [2**i for i in range(10, 20)]) +@pytest.mark.parametrize( + "shape", + [ + (16384, 1024), + (16384, 8192), + (16384, 3584), + (8192, 1280), + (1024, 8192), + (8192, 7168), + (3584, 8192), + ], +) @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) -def test_abs_max(numel: int, dtype: torch.dtype): - x = torch.randn(numel, dtype=dtype, device="cuda") +def test_abs_max(shape: Tuple[int], dtype: torch.dtype): + x = torch.randn(shape, dtype=dtype, device="cuda") max_abs = abs_max(x) assert torch.allclose(max_abs, tensor_to_amax(x)) From e6a1e3c88c8190f684ff833bfa7fc6db33877ccd Mon Sep 17 00:00:00 2001 From: drisspg Date: Tue, 27 Feb 2024 14:11:27 -0800 Subject: [PATCH 05/11] split reduction --- .../fused_kernels/fused_casting_kernels.py | 76 +++++++++++++++---- 1 file changed, 60 insertions(+), 16 deletions(-) diff --git a/float8_experimental/fused_kernels/fused_casting_kernels.py b/float8_experimental/fused_kernels/fused_casting_kernels.py index 2d2b3f94..47c37ef9 100644 --- a/float8_experimental/fused_kernels/fused_casting_kernels.py +++ b/float8_experimental/fused_kernels/fused_casting_kernels.py @@ -11,28 +11,72 @@ @triton.jit -def abs_max_kernel(x_ptr, out_ptr, n_elements: int, BLOCK_SIZE: tl.constexpr): - offset_base = tl.arange(0, BLOCK_SIZE)[None, :] - acc = tl.full([1, BLOCK_SIZE], -float("inf"), tl.float32) - for offset in range(0, 
n_elements, BLOCK_SIZE): - index = offset + offset_base - mask = index < n_elements - x = tl.load(x_ptr + index, mask, eviction_policy="evict_first", other=0.0) - x_broadcast = tl.broadcast_to(x, [1, BLOCK_SIZE]) - x_abs = tl.abs(x_broadcast) - acc = tl.maximum(acc, x_abs) +def promote_to_tensor(x): + # Addition promotes to tensor for us + return x + tl.zeros((1,), tl.int1) + + +@triton.jit +def is_floating(x): + return promote_to_tensor(x).dtype.is_floating() + + +@triton.jit +def maximum(a, b): + mask = a > b + if is_floating(a): + mask |= a != a + return tl.where(mask, a, b) + + +@triton.jit +def abs_max_kernel( + x_ptr, + out_ptr, + x_numel: int, + r_numel: int, + X_BLOCK_SIZE: tl.constexpr, + R_BLOCK_SIZE: tl.constexpr, +): + x_offset = tl.program_id(0) * X_BLOCK_SIZE + x_index = x_offset + tl.arange(0, X_BLOCK_SIZE)[:, None] + x_mask = x_index < x_numel + reduction_base = tl.arange(0, R_BLOCK_SIZE)[None, :] + acc = tl.full([X_BLOCK_SIZE, R_BLOCK_SIZE], -float("inf"), tl.float32) + for r_offset in range(0, r_numel, R_BLOCK_SIZE): + r_index = r_offset + reduction_base + r_mask = r_index < r_numel + values = tl.load( + x_ptr + (r_index + (r_numel * x_index)), + x_mask & r_mask, + eviction_policy="evict_last", + other=0.0, + ).to(tl.float32) + x_abs = tl.abs(values) + x_abs_broadcasted = tl.broadcast_to(x_abs, [X_BLOCK_SIZE, R_BLOCK_SIZE]) + acc_mask = maximum(acc, x_abs_broadcasted) + acc = tl.where(x_mask, acc_mask, acc) out = tl.max(acc, 1)[:, None] - tl.store(out_ptr + (tl.full([1, 1], 0, tl.int32)), out.to(tl.float32)) + tl.store(out_ptr + x_index, out.to(tl.float32), x_mask) def abs_max(x: torch.Tensor) -> torch.Tensor: "Calculates the global max of the absolute values of a tensor" - output = torch.empty((), device=x.device, dtype=torch.float32) + output = torch.empty((512, 1), device=x.device, dtype=torch.float32) n_elements = x.numel() - grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - BLOCK_SIZE = 1024 - abs_max_kernel[grid](x, output, n_elements=n_elements, BLOCK_SIZE=BLOCK_SIZE) - return output + grid = lambda meta: (512,) + X_BLOCK_SIZE = 1 + R_BLOCK_SIZE = 64 + r_numel = n_elements // 512 + abs_max_kernel[grid]( + x, + output, + x_numel=512, + r_numel=r_numel, + X_BLOCK_SIZE=X_BLOCK_SIZE, + R_BLOCK_SIZE=R_BLOCK_SIZE, + ) + return output.max() @triton.jit From da2ab5692dd1e405dbc3a08f43b2568543322128 Mon Sep 17 00:00:00 2001 From: drisspg Date: Tue, 27 Feb 2024 15:33:46 -0800 Subject: [PATCH 06/11] use bigger reduce block --- float8_experimental/fused_kernels/fused_casting_kernels.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/float8_experimental/fused_kernels/fused_casting_kernels.py b/float8_experimental/fused_kernels/fused_casting_kernels.py index 47c37ef9..b6ede030 100644 --- a/float8_experimental/fused_kernels/fused_casting_kernels.py +++ b/float8_experimental/fused_kernels/fused_casting_kernels.py @@ -28,7 +28,6 @@ def maximum(a, b): mask |= a != a return tl.where(mask, a, b) - @triton.jit def abs_max_kernel( x_ptr, @@ -64,9 +63,9 @@ def abs_max(x: torch.Tensor) -> torch.Tensor: "Calculates the global max of the absolute values of a tensor" output = torch.empty((512, 1), device=x.device, dtype=torch.float32) n_elements = x.numel() - grid = lambda meta: (512,) + grid = lambda meta: (meta["X_BLOCK_SIZE"],) X_BLOCK_SIZE = 1 - R_BLOCK_SIZE = 64 + R_BLOCK_SIZE = 1024 r_numel = n_elements // 512 abs_max_kernel[grid]( x, From 4526eb8e7787039bd3b6d15ec7de97aad3375067 Mon Sep 17 00:00:00 2001 From: drisspg Date: Tue, 27 Feb 2024 
17:45:22 -0800
Subject: [PATCH 07/11] need to handle broadcasted grads

---
 benchmarks/bench_dynamic_linear_fused_cast.py               | 3 ++-
 float8_experimental/float8_tensor.py                        | 1 +
 float8_experimental/fused_kernels/fused_casting_kernels.py  | 1 +
 3 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/benchmarks/bench_dynamic_linear_fused_cast.py b/benchmarks/bench_dynamic_linear_fused_cast.py
index 0f2a7a76..4259f2e7 100644
--- a/benchmarks/bench_dynamic_linear_fused_cast.py
+++ b/benchmarks/bench_dynamic_linear_fused_cast.py
@@ -116,7 +116,8 @@ def main(
         ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward()
 
         def float8_forw_backward():
-            linear_float8(input_tensor).sum().backward()
+            out = linear_float8(input_tensor)
+            out.sum().backward()
 
         def n_times(n, fn, *args, **kwargs):
             def wrapper(*args, **kwargs):
diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py
index a2292a32..c5088b6e 100644
--- a/float8_experimental/float8_tensor.py
+++ b/float8_experimental/float8_tensor.py
@@ -48,6 +48,7 @@ def to_fp8_no_autograd(
         from driss_torch import saturated_cast
 
         bits_fp8 = saturated_cast(x, x_scale, float8_dtype)
+
     else:
         x_scaled = x * x_scale
         bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype)
diff --git a/float8_experimental/fused_kernels/fused_casting_kernels.py b/float8_experimental/fused_kernels/fused_casting_kernels.py
index b6ede030..11341113 100644
--- a/float8_experimental/fused_kernels/fused_casting_kernels.py
+++ b/float8_experimental/fused_kernels/fused_casting_kernels.py
@@ -28,6 +28,7 @@ def maximum(a, b):
         mask |= a != a
     return tl.where(mask, a, b)
 
+
 @triton.jit
 def abs_max_kernel(
     x_ptr,

From df940ae9eaaaf3af1ad0d3c2f74fdca575a98bd8 Mon Sep 17 00:00:00 2001
From: drisspg
Date: Tue, 27 Feb 2024 20:16:59 -0800
Subject: [PATCH 08/11] this is failing in backward for some reason with
 device context failure

---
 float8_experimental/config.py       | 3 +++
 float8_experimental/float8_utils.py | 1 -
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/float8_experimental/config.py b/float8_experimental/config.py
index 8fbbb092..eb7ba22a 100644
--- a/float8_experimental/config.py
+++ b/float8_experimental/config.py
@@ -20,4 +20,7 @@
 # dynamic_use_activation_hooks = True
 # dynamic_use_activation_hooks = False
 
+# This is a global flag that controls whether to use the fused casting kernels.
+# They can offer greater performance in eager mode, but it is still recommended
+# to set this to False if you are using torch.compile.
use_fused_cast = True diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py index 14724e48..f8589827 100644 --- a/float8_experimental/float8_utils.py +++ b/float8_experimental/float8_utils.py @@ -72,7 +72,6 @@ def amax_history_to_scale_stack( def tensor_to_amax(x, distributed_reduction=False): if float8_experimental.config.use_fused_cast and x.is_cuda: from float8_experimental.fused_kernels.fused_casting_kernels import abs_max - amax = abs_max(x) else: amax = x.abs().max() From 7c0778ab44142fd44ad287f411a024e3f07ce887 Mon Sep 17 00:00:00 2001 From: drisspg Date: Wed, 28 Feb 2024 12:02:33 -0800 Subject: [PATCH 09/11] my abs_max is busted --- float8_experimental/float8_tensor.py | 9 ++++- float8_experimental/float8_utils.py | 9 ++++- .../fused_kernels/fused_casting_kernels.py | 37 +++++++++++-------- 3 files changed, 35 insertions(+), 20 deletions(-) diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py index c5088b6e..bbe8e73c 100644 --- a/float8_experimental/float8_tensor.py +++ b/float8_experimental/float8_tensor.py @@ -47,8 +47,13 @@ def to_fp8_no_autograd( ): from driss_torch import saturated_cast - bits_fp8 = saturated_cast(x, x_scale, float8_dtype) - + if x.dim() in {3, 4}: + prev_x_shape = x.shape + x = x.view(-1, x.size(-1)) + bits_fp8 = saturated_cast(x, x_scale, float8_dtype) + bits_fp8 = bits_fp8.view(prev_x_shape) + else: + bits_fp8 = saturated_cast(x, x_scale, float8_dtype) else: x_scaled = x * x_scale bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype) diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py index f8589827..438d1b4c 100644 --- a/float8_experimental/float8_utils.py +++ b/float8_experimental/float8_utils.py @@ -70,11 +70,16 @@ def amax_history_to_scale_stack( @torch.no_grad() def tensor_to_amax(x, distributed_reduction=False): - if float8_experimental.config.use_fused_cast and x.is_cuda: + if False and float8_experimental.config.use_fused_cast and x.is_cuda: from float8_experimental.fused_kernels.fused_casting_kernels import abs_max + amax = abs_max(x) + diff = abs_max(x) - x.abs().max().to(torch.float32) + assert ( + diff.item() == 0 + ), f"Expected {amax} to be equal to {x.abs().max().to(torch.float32)} but got {diff}" else: - amax = x.abs().max() + amax = x.abs().max().to(torch.float32) # If the user asked for distributed reduction, do it. 
# If the user did not ask for it, assume that it will diff --git a/float8_experimental/fused_kernels/fused_casting_kernels.py b/float8_experimental/fused_kernels/fused_casting_kernels.py index 11341113..83eb646d 100644 --- a/float8_experimental/fused_kernels/fused_casting_kernels.py +++ b/float8_experimental/fused_kernels/fused_casting_kernels.py @@ -48,7 +48,7 @@ def abs_max_kernel( r_mask = r_index < r_numel values = tl.load( x_ptr + (r_index + (r_numel * x_index)), - x_mask & r_mask, + r_mask, eviction_policy="evict_last", other=0.0, ).to(tl.float32) @@ -62,21 +62,26 @@ def abs_max_kernel( def abs_max(x: torch.Tensor) -> torch.Tensor: "Calculates the global max of the absolute values of a tensor" - output = torch.empty((512, 1), device=x.device, dtype=torch.float32) - n_elements = x.numel() - grid = lambda meta: (meta["X_BLOCK_SIZE"],) - X_BLOCK_SIZE = 1 - R_BLOCK_SIZE = 1024 - r_numel = n_elements // 512 - abs_max_kernel[grid]( - x, - output, - x_numel=512, - r_numel=r_numel, - X_BLOCK_SIZE=X_BLOCK_SIZE, - R_BLOCK_SIZE=R_BLOCK_SIZE, - ) - return output.max() + x = x.contiguous() + if x.numel() % 512 == 0: + output = torch.full( + (512, 1), -float("inf"), device=x.device, dtype=torch.float32 + ) + grid = lambda meta: (meta["X_BLOCK_SIZE"],) + X_BLOCK_SIZE = 1 + R_BLOCK_SIZE = 1024 + r_numel = x.numel() // 512 + abs_max_kernel[grid]( + x, + output, + x_numel=512, + r_numel=r_numel, + X_BLOCK_SIZE=X_BLOCK_SIZE, + R_BLOCK_SIZE=R_BLOCK_SIZE, + ) + return output.max() + else: + return x.abs().max().to(torch.float32) @triton.jit From 1dd45736ead3cc2fba3c871c3129a88a46b0f579 Mon Sep 17 00:00:00 2001 From: drisspg Date: Wed, 28 Feb 2024 12:18:38 -0800 Subject: [PATCH 10/11] fix kernel args --- float8_experimental/float8_tensor.py | 4 ++-- float8_experimental/float8_utils.py | 6 +----- .../fused_kernels/fused_casting_kernels.py | 12 ++++++++++-- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py index bbe8e73c..2f65abec 100644 --- a/float8_experimental/float8_tensor.py +++ b/float8_experimental/float8_tensor.py @@ -49,9 +49,9 @@ def to_fp8_no_autograd( if x.dim() in {3, 4}: prev_x_shape = x.shape - x = x.view(-1, x.size(-1)) + x = x.reshape(-1, x.size(-1)) bits_fp8 = saturated_cast(x, x_scale, float8_dtype) - bits_fp8 = bits_fp8.view(prev_x_shape) + bits_fp8 = bits_fp8.reshape(prev_x_shape) else: bits_fp8 = saturated_cast(x, x_scale, float8_dtype) else: diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py index 438d1b4c..21574fba 100644 --- a/float8_experimental/float8_utils.py +++ b/float8_experimental/float8_utils.py @@ -70,14 +70,10 @@ def amax_history_to_scale_stack( @torch.no_grad() def tensor_to_amax(x, distributed_reduction=False): - if False and float8_experimental.config.use_fused_cast and x.is_cuda: + if float8_experimental.config.use_fused_cast and x.is_cuda: from float8_experimental.fused_kernels.fused_casting_kernels import abs_max amax = abs_max(x) - diff = abs_max(x) - x.abs().max().to(torch.float32) - assert ( - diff.item() == 0 - ), f"Expected {amax} to be equal to {x.abs().max().to(torch.float32)} but got {diff}" else: amax = x.abs().max().to(torch.float32) diff --git a/float8_experimental/fused_kernels/fused_casting_kernels.py b/float8_experimental/fused_kernels/fused_casting_kernels.py index 83eb646d..06bb41d8 100644 --- a/float8_experimental/fused_kernels/fused_casting_kernels.py +++ b/float8_experimental/fused_kernels/fused_casting_kernels.py @@ -61,13 
+61,21 @@ def abs_max_kernel( def abs_max(x: torch.Tensor) -> torch.Tensor: - "Calculates the global max of the absolute values of a tensor" + """Calculates the global max of the absolute values of a tensor + + This kernel launches a grid of 512 threads, each thread calculates the + maximum of x.numel // 512 elements. The results are then reduced to a single + value in a follow up kernel. + + Args: + x: Input tensor to calculate the abs_max for + """ x = x.contiguous() if x.numel() % 512 == 0: output = torch.full( (512, 1), -float("inf"), device=x.device, dtype=torch.float32 ) - grid = lambda meta: (meta["X_BLOCK_SIZE"],) + grid = lambda meta: (512,) X_BLOCK_SIZE = 1 R_BLOCK_SIZE = 1024 r_numel = x.numel() // 512 From 110ec4b07adf838fd6ca4289349683204631e203 Mon Sep 17 00:00:00 2001 From: drisspg Date: Tue, 5 Mar 2024 11:01:54 -0800 Subject: [PATCH 11/11] cast to fp32 in amax --- float8_experimental/float8_utils.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py index 21574fba..00f0ed03 100644 --- a/float8_experimental/float8_utils.py +++ b/float8_experimental/float8_utils.py @@ -23,8 +23,10 @@ @torch.no_grad() -def amax_to_scale(amax, float8_dtype, orig_dtype): - scale = torch.empty_like(amax, dtype=torch.float32) +def amax_to_scale( + amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype +): + assert amax.dtype == torch.float32, "amax must be a float32 tensor" if float8_dtype == torch.float8_e4m3fn: res = E4M3_MAX_POS / torch.clamp(amax, min=EPS) else: # e5m2 @@ -35,16 +37,15 @@ def amax_to_scale(amax, float8_dtype, orig_dtype): # to care about this for float32/bfloat16. if orig_dtype is torch.float16: res = torch.clamp(res, max=FP16_MAX_POS) - scale.copy_(res) - return scale + return res @torch.no_grad() def amax_history_to_scale( - amax_history, - float8_dtype, - orig_dtype, - history_to_scale_fn_type, + amax_history: torch.Tensor, + float8_dtype: torch.dtype, + orig_dtype: torch.dtype, + history_to_scale_fn_type: str, ): if history_to_scale_fn_type == "max": amax = torch.max(amax_history) @@ -87,7 +88,7 @@ def tensor_to_amax(x, distributed_reduction=False): @torch.no_grad() -def tensor_to_scale(x, float8_dtype): +def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype): amax = tensor_to_amax(x) if float8_experimental.config.use_fused_cast and x.is_cuda: from float8_experimental.fused_kernels.fused_casting_kernels import (
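
Illustrative usage (not part of the patches above): a minimal sketch of how the fused
kernels introduced in this series are expected to agree with the eager reference path.
It assumes a CUDA device with triton available, mirrors test/test_fused_kernels.py, and
uses only functions added in these patches; the tensor shape is arbitrary but chosen so
that numel is divisible by 512, the case the fused abs_max kernel handles directly.

import torch

import float8_experimental.config as fp8_config
from float8_experimental.float8_utils import amax_to_scale, tensor_to_amax
from float8_experimental.fused_kernels.fused_casting_kernels import (
    abs_max,
    abs_max_to_scale,
)

x = torch.randn(4096, 8192, device="cuda", dtype=torch.bfloat16)

# Stage 1: blockwise partial abs-max in triton (512 partials), final max in torch.
amax_fused = abs_max(x)  # 0-dim float32 tensor
amax_eager = x.abs().max().to(torch.float32)
assert torch.allclose(amax_fused, amax_eager)

# With the config flag set, tensor_to_amax dispatches to the same fused kernel
# for CUDA tensors.
fp8_config.use_fused_cast = True
assert torch.allclose(tensor_to_amax(x), amax_eager)

# Stage 2: turn the amax into an fp8 scale; clamp_float16 only matters for fp16 inputs.
scale_fused = abs_max_to_scale(amax_fused, torch.float8_e4m3fn, False)
scale_eager = amax_to_scale(amax_eager, torch.float8_e4m3fn, x.dtype)
assert torch.allclose(scale_fused, scale_eager)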