diff --git a/benchmarks/bench_padding.py b/benchmarks/bench_padding.py
new file mode 100644
index 0000000..af036d6
--- /dev/null
+++ b/benchmarks/bench_padding.py
@@ -0,0 +1,204 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import fire
+
+import torch
+from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
+from float8_experimental.float8_utils import pad_tensor_for_matmul
+from tabulate import tabulate
+from torch._inductor.utils import do_bench_using_profiling
+from tqdm import tqdm
+
+# estimating TOPs for matmuls in fp32, fp16, fp8
+# assuming A * B = C, with A being M * K, B being K * N, C being M * N
+
+# H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/
+h100_peak_flops_float32 = 67e12
+h100_peak_flops_fp16_tc = 1979e12
+h100_peak_tops_float8_tc = 3958e12
+
+dtype_to_peak_tops = {
+    torch.float32: h100_peak_flops_float32,
+    torch.float16: h100_peak_flops_fp16_tc,
+    torch.bfloat16: h100_peak_flops_fp16_tc,
+    torch.float8_e4m3fn: h100_peak_tops_float8_tc,
+    torch.float8_e5m2: h100_peak_tops_float8_tc,
+}
+
+
+def benchmark_fn_in_usec(f, *args, **kwargs):
+    no_args = lambda: f(*args, **kwargs)
+    time = do_bench_using_profiling(no_args)
+    return time * 1e3
+
+
+def get_tops_info(tops, time, peak_tops):
+    time_sec = time / 1e6
+    tops_sec = float(tops) / time_sec
+    pct_top_peak = tops_sec / peak_tops
+    return tops_sec, pct_top_peak
+
+
+def do_fp8_matmul(A, B, fp8_dtype, out_dtype):
+    scale_a = torch.tensor([1], device="cuda", dtype=torch.float32)
+    scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)
+
+    a_config = ScaledMMConfig(
+        emulate=False, use_fast_accum=True, fp8_output=True, pad_inner_dim=True
+    )
+    b_config = ScaledMMConfig(
+        emulate=False, use_fast_accum=True, fp8_output=True, pad_inner_dim=True
+    )
+
+    a_fp8 = Float8Tensor.to_float8(A, scale_a, fp8_dtype, mm_config=a_config)
+    b_fp8 = Float8Tensor.to_float8(B, scale_b, fp8_dtype, mm_config=b_config)
+
+    return a_fp8 @ b_fp8
+
+
+def do_fp8_pad_first_matmul(A, B, fp8_dtype, out_dtype):
+    # Breaks with compile due to trying to pad on fp8 dtype
+    # return do_fp8_matmul(A, B, fp8_dtype, out_dtype)
+    A_pad = pad_tensor_for_matmul(A, dims=1)  # mem copy
+    B_pad = pad_tensor_for_matmul(B, dims=0)  # mem copy
+
+    scale_a = torch.tensor([1], device="cuda", dtype=torch.float32)
+    scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)
+
+    A_pad = A_pad.to(fp8_dtype)  # mem copy
+    B_pad = B_pad.to(fp8_dtype)  # mem copy
+
+    B_pad = B_pad.t().contiguous().t()  # mem copy
+
+    return torch._scaled_mm(
+        A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype, use_fast_accum=True
+    )
+
+
+def do_hp_matmul(A, B):
+    return torch.matmul(A, B)
+
+
+def do_aligned_bf16_matmul(A, B):
+    A_pad = pad_tensor_for_matmul(A, dims=1)
+    B_pad = pad_tensor_for_matmul(B, dims=0)
+    return torch.matmul(A_pad, B_pad)
+
+
+@dataclass
+class Experiment_config:
+    M: int
+    K: int
+    N: int
+    output_dtype: torch.dtype
+    fp8_dtype: torch.dtype
+
+    def __iter__(self):
+        return iter((self.M, self.K, self.N, self.output_dtype, self.fp8_dtype))
+
+
+def gen_configs():
+    shapes = [
+        (8193, 2501, 5008),
+        (65, 253, 4096),
+        (1023, 1029, 2512),
+        (4095, 511, 10000),
+        (2047, 3073, 8192),
+        (511, 769, 7504),
+        (127, 4097, 12288),
+        (32769, 15, 15024),
+        (9217, 8191, 20480),
+        (16385, 1025, 25008),
+    ]
+    output_dtype = torch.bfloat16
+    fp8_dtype = torch.float8_e4m3fn
+    return [Experiment_config(*shape, output_dtype, fp8_dtype) for shape in shapes]
+
+
+@torch.no_grad()
+def run(compile: bool = False, n_limit: Optional[int] = None):
+    device = "cuda"
+    experiments = gen_configs()
+    results = []
+    tops_table = []
+    tops_headers = [
+        "Shape",
+        "Ref Dtype",
+        "Ref Tops",
+        "Aligned BF16 Tops",
+        "FP8 Tops",
+        "Ref % Peak",
+        "Aligned BF16 % Peak",
+        "FP8 % Peak",
+    ]
+
+    for experiment in tqdm(experiments):
+        M, K, N, output_dtype, fp8_dtype = experiment
+        tops = 2 * M * N * K
+
+        A_base = torch.rand(M, K, device=device, dtype=output_dtype)
+        B_base = torch.rand(K, N, device=device, dtype=output_dtype)
+
+        hp_func = torch.compile(do_hp_matmul) if compile else do_hp_matmul
+        aligned_bf16_func = (
+            torch.compile(do_aligned_bf16_matmul) if compile else do_aligned_bf16_matmul
+        )
+        fp8_func = torch.compile(do_fp8_pad_first_matmul) if compile else do_fp8_matmul
+
+        ref_time = benchmark_fn_in_usec(hp_func, A_base, B_base)
+        aligned_bf16_time = benchmark_fn_in_usec(aligned_bf16_func, A_base, B_base)
+        fp8_time = benchmark_fn_in_usec(
+            fp8_func, A_base, B_base, fp8_dtype, output_dtype
+        )
+
+        ref_tops_sec, ref_pct_top_peak = get_tops_info(
+            tops, ref_time, dtype_to_peak_tops[output_dtype]
+        )
+        aligned_bf16_tops_sec, aligned_bf16_pct_top_peak = get_tops_info(
+            tops, aligned_bf16_time, dtype_to_peak_tops[torch.bfloat16]
+        )
+        fp8_tops_sec, fp8_pct_top_peak = get_tops_info(
+            tops, fp8_time, dtype_to_peak_tops[fp8_dtype]
+        )
+        tops_table.append(
+            [
+                f"({M}x{K}x{N})",
+                f"{output_dtype}",
+                f"{ref_tops_sec:.2E}",
+                f"{aligned_bf16_tops_sec:.2E}",
+                f"{fp8_tops_sec:.2E}",
+                f"{ref_pct_top_peak:.3f}",
+                f"{aligned_bf16_pct_top_peak:.3f}",
+                f"{fp8_pct_top_peak:.3f}",
+            ]
+        )
+        results.append(
+            [
+                (M, K, N),
+                output_dtype,
+                ref_time,
+                aligned_bf16_time,
+                fp8_time,
+                ref_time / aligned_bf16_time,
+                ref_time / fp8_time,
+            ]
+        )
+
+    print("TOPs".center(80, "*"))
+    print(tabulate(tops_table, headers=tops_headers))
+    print("Speed Results".center(80, "*"))
+    headers = [
+        "Shape",
+        "Ref Dtype",
+        "Ref Time",
+        "Aligned BF16 Time",
+        "FP8 Time",
+        "Aligned BF16 Speedup",
+        "FP8 Speedup",
+    ]
+    print(tabulate(results, headers=headers, tablefmt="grid"))
+
+
+if __name__ == "__main__":
+    fire.Fire(run)
diff --git a/float8_experimental/config.py b/float8_experimental/config.py
index 41b278c..99574c0 100644
--- a/float8_experimental/config.py
+++ b/float8_experimental/config.py
@@ -23,3 +23,9 @@
 # If True, use 'fnuz' float8 types for calculations.
 # Currently, ROCm only supports fnuz variants.
 use_fnuz_dtype = False
+
+# If True, then prior to the fp8 scaled matmul we pad the inner dimension of
+# a (dim 1) and b (dim 0) with zeros, since `torch._scaled_mm` has the strong
+# constraint that the inner dimension K of an M x K @ K x N matmul must be a
+# multiple of 16. Padding can cause a memory spike, so we keep this off by default.
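+#
+# Minimal usage sketch (illustrative only): set the flag before converting
+# modules, since the `from_float` constructors read `config.pad_inner_dim`
+# into each module's ScaledMMConfig at construction time.
+#
+#     import float8_experimental.config as config
+#     config.pad_inner_dim = True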
+pad_inner_dim = False
diff --git a/float8_experimental/float8_dynamic_linear.py b/float8_experimental/float8_dynamic_linear.py
index 0d4dbc0..ef0be7f 100644
--- a/float8_experimental/float8_dynamic_linear.py
+++ b/float8_experimental/float8_dynamic_linear.py
@@ -88,8 +88,19 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
             "bias": False,
         }
         new_mod = cls(**super_kwargs)
-        new_mod.forward_config = ScaledMMConfig(emulate, not bool(emulate))
-        new_mod.backward_config = ScaledMMConfig(emulate, False)
+
+        new_mod.forward_config = ScaledMMConfig(
+            emulate=emulate,
+            use_fast_accum=not bool(emulate),
+            fp8_output=False,
+            pad_inner_dim=config.pad_inner_dim,
+        )
+        new_mod.backward_config = ScaledMMConfig(
+            emulate=emulate,
+            use_fast_accum=False,
+            fp8_output=False,
+            pad_inner_dim=config.pad_inner_dim,
+        )
         if config.enable_fsdp_fp8_all_gather:
             new_mod.weight = nn.Parameter(
                 WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config)
diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py
index 35c03c0..3b3caed 100644
--- a/float8_experimental/float8_linear.py
+++ b/float8_experimental/float8_linear.py
@@ -347,6 +347,10 @@ def from_float(cls, mod, emulate: bool = False):
         new_mod.create_buffers()
         # Defines the behavior of the matmul in the forward and backward
         # Forward we use fast_accum, backwards we do not
-        new_mod.forward_config = ScaledMMConfig(emulate, True if not emulate else False)
-        new_mod.backward_config = ScaledMMConfig(emulate, False)
+        new_mod.forward_config = ScaledMMConfig(
+            emulate, True if not emulate else False, False, config.pad_inner_dim
+        )
+        new_mod.backward_config = ScaledMMConfig(
+            emulate, False, False, config.pad_inner_dim
+        )
         return new_mod
diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py
index ffe6491..ea2cb67 100644
--- a/float8_experimental/float8_ops.py
+++ b/float8_experimental/float8_ops.py
@@ -13,7 +13,8 @@
     merge_mm_configs,
     ScaledMMConfig,
 )
-from float8_experimental.float8_utils import is_row_major
+from float8_experimental.float8_utils import is_row_major, pad_tensor_for_matmul
+
 from torch.utils._pytree import tree_map
 
 aten = torch.ops.aten
@@ -121,6 +122,16 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):
     a_scale = a._scale
     b_data = b._data
 
+    if a._mm_config.pad_inner_dim:
+        assert (
+            b._mm_config.pad_inner_dim
+        ), "Both mm configs must have pad_inner_dim set to True"
+        assert a._data.size(1) == b._data.size(
+            0
+        ), f"Inner dims must match for mm, got {a._data.size(1)} and {b._data.size(0)}"
+        a_data = pad_tensor_for_matmul(a_data, dims=1)
+        b_data = pad_tensor_for_matmul(b_data, dims=0)
+
     if not is_row_major(a_data.stride()):
         a_data = a_data.contiguous()
     if is_row_major(b_data.stride()):
diff --git a/float8_experimental/float8_python_api.py b/float8_experimental/float8_python_api.py
index 3b752d7..d8aa081 100644
--- a/float8_experimental/float8_python_api.py
+++ b/float8_experimental/float8_python_api.py
@@ -9,7 +9,6 @@
 to simplify the product code.
""" - from typing import Optional import float8_experimental.float8_aten_api # noqa diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py index 5c8e9a8..4644a4b 100644 --- a/float8_experimental/float8_tensor.py +++ b/float8_experimental/float8_tensor.py @@ -23,10 +23,11 @@ # emulate: whether to emulate the matmuls in fp32 # use_fast_accum: whether to use the fast-accumulation option for scaled_mm # fp8_output: whether to output the result of the scaled_mm in fp8 +# pad_inner_dim: whether to pad the inner dimension of a and b with 0s. This is needed for matmuls not aligned to 16. ScaledMMConfig = namedtuple( "ScaledMMConfig", - ["emulate", "use_fast_accum", "fp8_output"], - defaults=[False, False, False], + ["emulate", "use_fast_accum", "fp8_output", "pad_inner_dim"], + defaults=[False, False, False, False], ) @@ -48,6 +49,7 @@ def merge_mm_configs( emulate=a_mm_config.emulate, use_fast_accum=a_mm_config.use_fast_accum and b_mm_config.use_fast_accum, fp8_output=a_mm_config.fp8_output and b_mm_config.fp8_output, + pad_inner_dim=a_mm_config.pad_inner_dim and b_mm_config.pad_inner_dim, ) diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py index f9ae70a..f6d95a9 100644 --- a/float8_experimental/float8_utils.py +++ b/float8_experimental/float8_utils.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from typing import Literal, Tuple +from typing import Iterable, Literal, Tuple, Union import float8_experimental.config as config @@ -179,3 +179,69 @@ def fp8_tensor_statistics( def is_row_major(stride): assert len(stride) == 2, "is_row_major only supports 2D tensors" return stride[0] > stride[1] and stride[1] == 1 + + +def _get_min_alignment(size: int, alignment_value: int) -> int: + """ + Returns the minimum alignment value that is greater than or equal to the given size. + + Args: + size: The size of the data to be aligned. + alignment_value: The alignment value to be used. + + Returns: + int: The minimum alignment value that is greater than or equal to the given size. + + Usage: + ``` + >>> _get_min_alignment(10, 8) + 16 + ``` + """ + if size % alignment_value == 0: + return size + return (1 + (size // alignment_value)) * alignment_value + + +def pad_tensor_for_matmul( + tensor: torch.Tensor, dims: Union[int, Iterable[int]] +) -> torch.Tensor: + """ + Pads a 2D tensor with zeros to ensure that its dimensions are multiples of 16, which is required `torch._scaled_mm` + + Args: + tensor: The tensor to pad. + both: Whether to pad both dimensions or just the second dimension. + + Returns: + torch.Tensor: The padded tensor. 
+
+    Usage:
+    ```
+    >>> pad_tensor_for_matmul(torch.randn((10, 10)), dims=0).shape
+    torch.Size([16, 10])
+    >>> pad_tensor_for_matmul(torch.randn((10, 10)), dims=1).shape
+    torch.Size([10, 16])
+    >>> pad_tensor_for_matmul(torch.randn((10, 10)), dims=(0, 1)).shape
+    torch.Size([16, 16])
+    ```
+    """
+    assert tensor.dim() == 2
+    dim1, dim2 = tensor.shape
+
+    if isinstance(dims, int):
+        dims = (dims,)
+
+    # Calculate aligned dimensions based on the specified dims
+    dim1_aligned = _get_min_alignment(dim1, 16) if 0 in dims else dim1
+    dim2_aligned = _get_min_alignment(dim2, 16) if 1 in dims else dim2
+
+    # Check if padding is needed for either dimension
+    if dim1 == dim1_aligned and dim2 == dim2_aligned:
+        return tensor
+
+    # Calculate padding values for both dimensions
+    pad_dim1 = dim1_aligned - dim1
+    pad_dim2 = dim2_aligned - dim2
+
+    return torch.nn.functional.pad(tensor, (0, pad_dim2, 0, pad_dim1))
diff --git a/test/test_base.py b/test/test_base.py
index da9da87..b688ccb 100644
--- a/test/test_base.py
+++ b/test/test_base.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 import itertools
 import random
+import re
 import unittest
 import warnings
 
@@ -387,6 +388,59 @@ def test_merge_configs(self):
         assert c.use_fast_accum is True
         assert c.fp8_output is False
 
+    @unittest.skipIf(
+        not is_H100,
+        "H100 not available",
+    )
+    @pytest.mark.parametrize(
+        "base_dtype", [torch.float16, torch.bfloat16, torch.float32]
+    )
+    @pytest.mark.parametrize("use_fast_accum", [True, False])
+    def test_pad_inner_dim(self, base_dtype, use_fast_accum):
+        torch.manual_seed(42)
+        input_dtype = torch.float8_e4m3fn
+        compare_type = torch.float32
+
+        a = torch.randn(16, 41, device="cuda", dtype=base_dtype)
+        b = torch.randn(41, 128, device="cuda", dtype=base_dtype)
+
+        a_scale = tensor_to_scale(a, input_dtype).float()
+        b_scale = tensor_to_scale(b, input_dtype).float()
+
+        a_fp8 = Float8Tensor.to_float8(a, a_scale, input_dtype)
+        b_fp8 = Float8Tensor.to_float8(b, b_scale, input_dtype)
+
+        with pytest.raises(
+            RuntimeError,
+            match=re.escape(
+                "Expected trailing dimension of mat1 to be divisible by 16 but got mat1 shape: (16x41."
+            ),
+        ):
+            a_fp8 @ b_fp8
+
+        pad_config = ScaledMMConfig(False, use_fast_accum, False, True)
+
+        a_fp8 = Float8Tensor.to_float8(a, a_scale, input_dtype, mm_config=pad_config)
+        b_fp8 = Float8Tensor.to_float8(b, b_scale, input_dtype, mm_config=pad_config)
+        out_padded = a_fp8 @ b_fp8
+        out_padded = out_padded.to(compare_type)
+
+        emulated_config = ScaledMMConfig(True, use_fast_accum, False, False)
+        a_fp8 = Float8Tensor.to_float8(
+            a, a_scale, input_dtype, mm_config=emulated_config
+        )
+        b_fp8 = Float8Tensor.to_float8(
+            b, b_scale, input_dtype, mm_config=emulated_config
+        )
+        out_emulated = a_fp8 @ b_fp8
+        out_emulated = out_emulated.to(compare_type)
+
+        if base_dtype in {torch.bfloat16, torch.float16}:
+            atol, rtol = 7e-2, 7e-2
+        else:
+            atol, rtol = 2e-3, 2e-3
+        torch.testing.assert_close(out_padded, out_emulated, atol=atol, rtol=rtol)
+
 
 class TestNumerics:
     @pytest.mark.parametrize(