diff --git a/benchmarks/bench_constants.py b/benchmarks/bench_constants.py new file mode 100644 index 00000000..dce51bba --- /dev/null +++ b/benchmarks/bench_constants.py @@ -0,0 +1,49 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +# estimating TOPs for matmuls in fp32, fp16, fp8 +# assuming A * B = C, with A being M * K, B being K * N, C being M * N + +# H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/ +h100_peak_flops_float32 = 67e12 +h100_peak_flops_fp16_tc = 1979e12 +h100_peak_tops_float8_tc = 3958e12 + +dtype_to_peak_tops = { + torch.float32: h100_peak_flops_float32, + torch.float16: h100_peak_flops_fp16_tc, + torch.bfloat16: h100_peak_flops_fp16_tc, + torch.float8_e4m3fn: h100_peak_tops_float8_tc, + torch.float8_e5m2: h100_peak_tops_float8_tc, +} + +name_to_shapes = { + # LLaMa 2 70B single-node weight shapes + # assumes fused attn.wqkv and ffn.w13 + # source: https://fburl.com/gsheet/g8onr7rh + "70B": { + "attn.wqkv": (8192, 1280), + "attn.w0": (1024, 8192), + "ffn.w13": (8192, 7168), + "ffn.w2": (3584, 8192), + }, + # source: LLaMa 2 7B def, unfused ffn + "7B": { + "attn.wqkv": (4096, 12288), + "attn.w0": (4096, 4096), + "ffn.w1_or_w3": (4096, 11008), + "ffn.w2": (11008, 4096), + }, + # source: LLaMa 2 13B def, unfused ffn + "13B": { + "attn.wqkv": (5120, 15360), + "attn.w0": (5120, 5120), + "ffn.w1_or_w3": (5120, 13824), + "ffn.w2": (13824, 5120), + }, +} diff --git a/benchmarks/bench_linear_float8.py b/benchmarks/bench_linear_float8.py index eef8f41c..90c38ca6 100644 --- a/benchmarks/bench_linear_float8.py +++ b/benchmarks/bench_linear_float8.py @@ -10,10 +10,13 @@ from pathlib import Path from typing import Callable, List, Optional, Tuple +import bench_constants as bc + import pandas as pd import torch import torch.utils.benchmark as benchmark +from float8_experimental.dynamic_linear.dynamic_float8_linear import Float8DynamicLinear from float8_experimental.float8_linear import Float8Linear from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history from tqdm import tqdm @@ -28,22 +31,6 @@ except ImportError: print("transformer_engine not installed and we won't compare against this") -# estimating TOPs for matmuls in fp32, fp16, fp8 -# assuming A * B = C, with A being M * K, B being K * N, C being M * N - -# H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/ -h100_peak_flops_float32 = 67e12 -h100_peak_flops_fp16_tc = 1979e12 -h100_peak_tops_float8_tc = 3958e12 - -dtype_to_peak_tops = { - torch.float32: h100_peak_flops_float32, - torch.float16: h100_peak_flops_fp16_tc, - torch.bfloat16: h100_peak_flops_fp16_tc, - torch.float8_e4m3fn: h100_peak_tops_float8_tc, - torch.float8_e5m2: h100_peak_tops_float8_tc, -} - def benchmark_torch_function_in_microseconds( func: Callable, @@ -63,6 +50,7 @@ class Experiment: shape: Tuple[int, int, int] ref_time_sec: float float8_time_sec: float + float8_dynamic_time_sec: float dtype: torch.dtype compiled: bool = False float_8_dtype: Optional[torch.dtype] = torch.float8_e4m3fn @@ -76,7 +64,7 @@ def ref_tops_sec(self): @property def ref_pct_top_peak(self): - return self.ref_tops_sec / dtype_to_peak_tops[self.dtype] + return self.ref_tops_sec / bc.dtype_to_peak_tops[self.dtype] @property def float8_tops_sec(self): @@ -85,7 +73,7 @@ def float8_tops_sec(self): @property def float8_pct_top_peak(self): - return self.float8_tops_sec / dtype_to_peak_tops[self.float_8_dtype] + return self.float8_tops_sec / bc.dtype_to_peak_tops[self.float_8_dtype] @property def te_tops_sec(self): @@ -98,7 +86,7 @@ def te_tops_sec(self): @property def te_pct_top_peak(self): if self.te_tops_sec is not None: - return self.te_tops_sec / dtype_to_peak_tops[self.float_8_dtype] + return self.te_tops_sec / bc.dtype_to_peak_tops[self.float_8_dtype] else: return None @@ -107,24 +95,27 @@ def main( sweep_path: Path, compile: bool, n_limit: Optional[int] = None, + llama_model_size: str = "70B", ): device = "cuda" print(f"Compile is set to | {compile}") + print("model size:", llama_model_size) + + name_to_shapes = bc.name_to_shapes[llama_model_size] + if llama_model_size == "70B": + # common distributed setup, single GPU numbers + bsz, seq_len = 4, 4096 + else: + # debug single gpu setup + bsz, seq_len = 1, 4096 - # LLaMa 2 70B single-node weight shapes - # assumes fused attn.wqkv and ffn.w13 - # source: https://fburl.com/gsheet/g8onr7rh - name_to_shapes_70b = { - "attn.wqkv": (8192, 1280), - "attn.w0": (1024, 8192), - "ffn.w13": (8192, 7168), - "ffn.w2": (3584, 8192), - } input_bias = False - ref_dtypes = [torch.bfloat16, torch.float16] + ref_dtypes = [ + torch.bfloat16, + ] experiment_list: List[Experiment] = [] for idx, (dtype, (name, (K, N))) in enumerate( - tqdm(list(product(ref_dtypes, name_to_shapes_70b.items()))) + tqdm(list(product(ref_dtypes, name_to_shapes.items()))) ): if n_limit is not None and idx >= n_limit: break @@ -136,7 +127,10 @@ def main( copy.deepcopy(linear_ref), emulate=False ) - bsz, seq_len = 4, 4096 + linear_dynamic_float8 = Float8DynamicLinear.from_float( + copy.deepcopy(linear_ref), emulate=False + ) + M = bsz * seq_len input_tensor = torch.randn(M, K, device=device, dtype=dtype, requires_grad=True) ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward() @@ -145,6 +139,10 @@ def float8_forw_backward(): sync_float8_amax_and_scale_history(linear_float8) linear_float8(input_tensor).sum().backward() + float8_dynamic_forw_backward = ( + lambda: linear_dynamic_float8(input_tensor).sum().backward() + ) + if transformer_engine_installed: # Use the same recipe as float8_linear.DelayedScalingRecipe fp8_format = recipe.Format.HYBRID @@ -169,19 +167,23 @@ def wrapper(*args, **kwargs): ref_forw_backward = n_times(REPEAT_N, ref_forw_backward) float8_forw_backward = n_times(REPEAT_N, float8_forw_backward) + float8_dynamic_forw_backward = n_times(REPEAT_N, float8_dynamic_forw_backward) if transformer_engine_installed: te_forw_backward = n_times(REPEAT_N, te_forw_backward) if compile: ref_forw_backward = torch.compile(ref_forw_backward) float8_forw_backward = torch.compile(float8_forw_backward) + float8_dynamic_forw_backward = torch.compile(float8_dynamic_forw_backward) # Compiling TE_linear fails but they are already compiling under the hood # if transformer_engine_installed: # te_forw_backward = torch.compile(te_forw_backward) + # warmup for _ in range(5): ref_forw_backward() float8_forw_backward() + float8_dynamic_forw_backward() if transformer_engine_installed: te_forw_backward() @@ -195,6 +197,11 @@ def wrapper(*args, **kwargs): * 1e-6 / REPEAT_N ) + float8_dynamic_time = ( + benchmark_torch_function_in_microseconds(float8_dynamic_forw_backward) + * 1e-6 + / REPEAT_N + ) if transformer_engine_installed: te_time_sec = ( benchmark_torch_function_in_microseconds(te_forw_backward) @@ -208,12 +215,17 @@ def wrapper(*args, **kwargs): (M, K, N), ref_time, float8_time, + float8_dynamic_time, dtype, compile, te_time_sec=te_time_sec, ) print(experiment) print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec) + print( + "float8 dynamic speedup", + experiment.ref_time_sec / experiment.float8_dynamic_time_sec, + ) if transformer_engine_installed: print("te speedup", experiment.ref_time_sec / experiment.te_time_sec) experiment_list.append(experiment) @@ -229,6 +241,7 @@ def wrapper(*args, **kwargs): "fp8_dtype", "ref_time_sec", "pt_fp8_time_sec", + "pt_fp8_dynamic_time_sec", "te_fp8_time_sec", "ref_tops_sec", "ref_pct_top_peak", @@ -250,6 +263,7 @@ def wrapper(*args, **kwargs): experiment.float_8_dtype, experiment.ref_time_sec, experiment.float8_time_sec, + experiment.float8_dynamic_time_sec, experiment.te_time_sec, experiment.ref_tops_sec, experiment.ref_pct_top_peak, @@ -262,6 +276,9 @@ def wrapper(*args, **kwargs): data_pd = pd.DataFrame(data, columns=headers) data_pd["pt_fp8_speedup"] = data_pd["ref_time_sec"] / data_pd["pt_fp8_time_sec"] + data_pd["pt_fp8_dynamic_speedup"] = ( + data_pd["ref_time_sec"] / data_pd["pt_fp8_dynamic_time_sec"] + ) if transformer_engine_installed: data_pd["te_fp8_speedup"] = data_pd["ref_time_sec"] / data_pd["te_fp8_time_sec"] else: @@ -280,12 +297,13 @@ def wrapper(*args, **kwargs): [ "name", "shape", - "ref_dtype", "compiled", "ref_time_sec", "pt_fp8_time_sec", + "pt_fp8_dynamic_time_sec", "te_fp8_time_sec", "pt_fp8_speedup", + "pt_fp8_dynamic_speedup", "te_fp8_speedup", ] ] @@ -301,9 +319,12 @@ def invoke_main() -> None: parser.add_argument("-o", "--output_path", type=str, required=True) parser.add_argument("--compile", action="store_true") parser.add_argument("-n", "--n_limit", type=int, required=False) + parser.add_argument( + "--llama_model_size", type=str, required=True, choices=["70B", "7B", "13B"] + ) args = parser.parse_args() output_path = Path(args.output_path) - main(output_path, args.compile, args.n_limit) + main(output_path, args.compile, args.n_limit, args.llama_model_size) if __name__ == "__main__": diff --git a/benchmarks/bench_matmul.py b/benchmarks/bench_matmul.py index f3e54313..7f16bc14 100644 --- a/benchmarks/bench_matmul.py +++ b/benchmarks/bench_matmul.py @@ -7,6 +7,8 @@ import itertools from typing import Optional +import bench_constants as bc + import fire import pandas as pd @@ -14,22 +16,6 @@ import torch.nn as nn import torch.utils.benchmark as benchmark -# estimating TOPs for matmuls in fp32, fp16, fp8 -# assuming A * B = C, with A being M * K, B being K * N, C being M * N - -# H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/ -h100_peak_flops_float32 = 67e12 -h100_peak_flops_fp16_tc = 989e12 -h100_peak_tops_float8_tc = 1979e12 - -dtype_to_peak_tops = { - torch.float32: h100_peak_flops_float32, - torch.float16: h100_peak_flops_fp16_tc, - torch.bfloat16: h100_peak_flops_fp16_tc, - torch.float8_e4m3fn: h100_peak_tops_float8_tc, - torch.float8_e5m2: h100_peak_tops_float8_tc, -} - def benchmark_fn_in_sec(f, *args, **kwargs): # Manual warmup @@ -50,25 +36,25 @@ def do_benchmarks(tops, peak_tops, f, *args, **kwargs): @torch.inference_mode() -def run(n_limit: Optional[int] = None): +def run( + llama_model_size: Optional[str] = "70B", + n_limit: Optional[int] = None, + output_path: Optional[str] = None, +): + print("model size", llama_model_size) device = "cuda" - # LLaMa 2 70B single-node weight shapes - # assumes fused attn.wqkv and ffn.w13 - # source: https://fburl.com/gsheet/g8onr7rh - name_to_shapes_70b = { - "attn.wqkv": (8192, 1280), - "attn.w0": (1024, 8192), - "ffn.w13": (8192, 7168), - "ffn.w2": (3584, 8192), - } - headers = ("name", "shape", "dtype", "ref_time_s", "fp8_time_s", "fp8_speedup") results = [] - name_to_shapes = name_to_shapes_70b - bsz_and_seq_len = ((4, 4096),) - dtypes = torch.bfloat16, torch.float16 + name_to_shapes = bc.name_to_shapes[llama_model_size] + if llama_model_size == "70B": + # common distributed setup, single GPU numbers + bsz_and_seq_len = ((4, 4096),) + else: + # debug single gpu setup + bsz_and_seq_len = ((1, 4096),) + dtypes = (torch.bfloat16,) for idx, (dtype, (name, (K, N))) in enumerate( itertools.product(dtypes, name_to_shapes.items()) @@ -88,7 +74,7 @@ def run(n_limit: Optional[int] = None): A = torch.randn(M, K, device=device, dtype=dtype) m_ref = nn.Sequential(nn.Linear(K, N, dtype=dtype, device=device, bias=False)) ref_time_sec, ref_tops_sec, ref_pct_top_peak = do_benchmarks( - tops, dtype_to_peak_tops[dtype], m_ref, A + tops, bc.dtype_to_peak_tops[dtype], m_ref, A ) print( f"{dtype} time_sec {ref_time_sec:.2E}, tops/sec {ref_tops_sec:.2E}, pct_peak {ref_pct_top_peak:.3f}" @@ -106,7 +92,7 @@ def do_matmul(A, B): return torch._scaled_mm(A, B, out_dtype=d3, use_fast_accum=False) fp8_time_sec, fp8_tops_sec, fp8_pct_top_peak = do_benchmarks( - tops, dtype_to_peak_tops[d1], do_matmul, A, B + tops, bc.dtype_to_peak_tops[d1], do_matmul, A, B ) print( f"fp8 time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}" @@ -127,6 +113,9 @@ def do_matmul(A, B): data_pd = pd.DataFrame(results, columns=headers) print(data_pd) + if output_path is not None: + with open(output_path, mode="w") as file: + data_pd.to_csv(output_path) def main() -> None: