
update matmul and linear benchmarks with 7B and 13B shapes
Summary:

as titled

also removes float16 for now since it's not super important,
and moves common constants to a reusable module (benchmarks/bench_constants.py)

Test Plan:

```
matmul bench results:
https://gist.github.com/vkuzo/930219d6de1015dd00486c32eac8fa67

linear bench results:
https://gist.github.com/vkuzo/31a063df8d10794497949285319b5ccd
```
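
The gists above record results only, not the commands used. Based on the argument parsers added in this diff, plausible invocations look like the following (output paths are illustrative, the scripts are assumed to be run from the repo root, and bench_matmul's `main()` is assumed to dispatch `run()` via fire):

```
# bench_linear_float8.py (argparse): --llama_model_size is now required
python benchmarks/bench_linear_float8.py -o /tmp/linear_7B.csv --llama_model_size 7B --compile

# bench_matmul.py (fire): run() now accepts llama_model_size and output_path
python benchmarks/bench_matmul.py --llama_model_size 13B --output_path /tmp/matmul_13B.csv
```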

Reviewers:

Subscribers:

Tasks:

Tags:
vkuzo committed Jan 11, 2024
1 parent dd0c596 commit 83e1c30
Showing 3 changed files with 124 additions and 65 deletions.
49 changes: 49 additions & 0 deletions benchmarks/bench_constants.py
@@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import torch

# estimating TOPs for matmuls in fp32, fp16, fp8
# assuming A * B = C, with A being M * K, B being K * N, C being M * N

# H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/
h100_peak_flops_float32 = 67e12
h100_peak_flops_fp16_tc = 1979e12
h100_peak_tops_float8_tc = 3958e12

dtype_to_peak_tops = {
    torch.float32: h100_peak_flops_float32,
    torch.float16: h100_peak_flops_fp16_tc,
    torch.bfloat16: h100_peak_flops_fp16_tc,
    torch.float8_e4m3fn: h100_peak_tops_float8_tc,
    torch.float8_e5m2: h100_peak_tops_float8_tc,
}

name_to_shapes = {
    # LLaMa 2 70B single-node weight shapes
    # assumes fused attn.wqkv and ffn.w13
    # source: https://fburl.com/gsheet/g8onr7rh
    "70B": {
        "attn.wqkv": (8192, 1280),
        "attn.w0": (1024, 8192),
        "ffn.w13": (8192, 7168),
        "ffn.w2": (3584, 8192),
    },
    # source: LLaMa 2 7B def, unfused ffn
    "7B": {
        "attn.wqkv": (4096, 12288),
        "attn.w0": (4096, 4096),
        "ffn.w1_or_w3": (4096, 11008),
        "ffn.w2": (11008, 4096),
    },
    # source: LLaMa 2 13B def, unfused ffn
    "13B": {
        "attn.wqkv": (5120, 15360),
        "attn.w0": (5120, 5120),
        "ffn.w1_or_w3": (5120, 13824),
        "ffn.w2": (13824, 5120),
    },
}
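
A minimal sketch (not part of this commit) of how the new constants module is consumed: the two benchmark files below import it as `bench_constants as bc` and index `bc.name_to_shapes` and `bc.dtype_to_peak_tops`; the `M = bsz * seq_len` value mirrors the single-GPU setup used for the smaller models.

```
# Sketch only: look up a weight shape and a peak-throughput number from bench_constants.
import torch

import bench_constants as bc

K, N = bc.name_to_shapes["13B"]["ffn.w2"]            # (13824, 5120), LLaMa 2 13B ffn.w2
peak_flops = bc.dtype_to_peak_tops[torch.bfloat16]   # H100 bf16 tensor-core peak

M = 1 * 4096                                         # bsz * seq_len for the 7B/13B configs
ideal_sec = (2 * M * K * N) / peak_flops             # lower bound for one forward matmul at peak
print(f"ideal fwd time for 13B ffn.w2 at peak: {ideal_sec:.2e} s")
```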
87 changes: 54 additions & 33 deletions benchmarks/bench_linear_float8.py
@@ -10,10 +10,13 @@
from pathlib import Path
from typing import Callable, List, Optional, Tuple

import bench_constants as bc

import pandas as pd

import torch
import torch.utils.benchmark as benchmark
from float8_experimental.dynamic_linear.dynamic_float8_linear import Float8DynamicLinear
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history
from tqdm import tqdm
@@ -28,22 +31,6 @@
except ImportError:
print("transformer_engine not installed and we won't compare against this")

# estimating TOPs for matmuls in fp32, fp16, fp8
# assuming A * B = C, with A being M * K, B being K * N, C being M * N

# H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/
h100_peak_flops_float32 = 67e12
h100_peak_flops_fp16_tc = 1979e12
h100_peak_tops_float8_tc = 3958e12

dtype_to_peak_tops = {
torch.float32: h100_peak_flops_float32,
torch.float16: h100_peak_flops_fp16_tc,
torch.bfloat16: h100_peak_flops_fp16_tc,
torch.float8_e4m3fn: h100_peak_tops_float8_tc,
torch.float8_e5m2: h100_peak_tops_float8_tc,
}


def benchmark_torch_function_in_microseconds(
func: Callable,
@@ -63,6 +50,7 @@ class Experiment:
shape: Tuple[int, int, int]
ref_time_sec: float
float8_time_sec: float
float8_dynamic_time_sec: float
dtype: torch.dtype
compiled: bool = False
float_8_dtype: Optional[torch.dtype] = torch.float8_e4m3fn
@@ -76,7 +64,7 @@ def ref_tops_sec(self):

@property
def ref_pct_top_peak(self):
return self.ref_tops_sec / dtype_to_peak_tops[self.dtype]
return self.ref_tops_sec / bc.dtype_to_peak_tops[self.dtype]

@property
def float8_tops_sec(self):
@@ -85,7 +73,7 @@ def float8_tops_sec(self):

@property
def float8_pct_top_peak(self):
return self.float8_tops_sec / dtype_to_peak_tops[self.float_8_dtype]
return self.float8_tops_sec / bc.dtype_to_peak_tops[self.float_8_dtype]

@property
def te_tops_sec(self):
@@ -98,7 +86,7 @@ def te_tops_sec(self):
@property
def te_pct_top_peak(self):
if self.te_tops_sec is not None:
return self.te_tops_sec / dtype_to_peak_tops[self.float_8_dtype]
return self.te_tops_sec / bc.dtype_to_peak_tops[self.float_8_dtype]
else:
return None

@@ -107,24 +95,27 @@ def main(
sweep_path: Path,
compile: bool,
n_limit: Optional[int] = None,
llama_model_size: str = "70B",
):
device = "cuda"
print(f"Compile is set to | {compile}")
print("model size:", llama_model_size)

name_to_shapes = bc.name_to_shapes[llama_model_size]
if llama_model_size == "70B":
# common distributed setup, single GPU numbers
bsz, seq_len = 4, 4096
else:
# debug single gpu setup
bsz, seq_len = 1, 4096

# LLaMa 2 70B single-node weight shapes
# assumes fused attn.wqkv and ffn.w13
# source: https://fburl.com/gsheet/g8onr7rh
name_to_shapes_70b = {
"attn.wqkv": (8192, 1280),
"attn.w0": (1024, 8192),
"ffn.w13": (8192, 7168),
"ffn.w2": (3584, 8192),
}
input_bias = False
ref_dtypes = [torch.bfloat16, torch.float16]
ref_dtypes = [
torch.bfloat16,
]
experiment_list: List[Experiment] = []
for idx, (dtype, (name, (K, N))) in enumerate(
tqdm(list(product(ref_dtypes, name_to_shapes_70b.items())))
tqdm(list(product(ref_dtypes, name_to_shapes.items())))
):
if n_limit is not None and idx >= n_limit:
break
@@ -136,7 +127,10 @@ def main(
copy.deepcopy(linear_ref), emulate=False
)

bsz, seq_len = 4, 4096
linear_dynamic_float8 = Float8DynamicLinear.from_float(
copy.deepcopy(linear_ref), emulate=False
)

M = bsz * seq_len
input_tensor = torch.randn(M, K, device=device, dtype=dtype, requires_grad=True)
ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward()
@@ -145,6 +139,10 @@ def float8_forw_backward():
sync_float8_amax_and_scale_history(linear_float8)
linear_float8(input_tensor).sum().backward()

float8_dynamic_forw_backward = (
lambda: linear_dynamic_float8(input_tensor).sum().backward()
)

if transformer_engine_installed:
# Use the same recipe as float8_linear.DelayedScalingRecipe
fp8_format = recipe.Format.HYBRID
@@ -169,19 +167,23 @@ def wrapper(*args, **kwargs):

ref_forw_backward = n_times(REPEAT_N, ref_forw_backward)
float8_forw_backward = n_times(REPEAT_N, float8_forw_backward)
float8_dynamic_forw_backward = n_times(REPEAT_N, float8_dynamic_forw_backward)
if transformer_engine_installed:
te_forw_backward = n_times(REPEAT_N, te_forw_backward)

if compile:
ref_forw_backward = torch.compile(ref_forw_backward)
float8_forw_backward = torch.compile(float8_forw_backward)
float8_dynamic_forw_backward = torch.compile(float8_dynamic_forw_backward)
# Compiling TE_linear fails but they are already compiling under the hood
# if transformer_engine_installed:
# te_forw_backward = torch.compile(te_forw_backward)

# warmup
for _ in range(5):
ref_forw_backward()
float8_forw_backward()
float8_dynamic_forw_backward()
if transformer_engine_installed:
te_forw_backward()

@@ -195,6 +197,11 @@ def wrapper(*args, **kwargs):
* 1e-6
/ REPEAT_N
)
float8_dynamic_time = (
benchmark_torch_function_in_microseconds(float8_dynamic_forw_backward)
* 1e-6
/ REPEAT_N
)
if transformer_engine_installed:
te_time_sec = (
benchmark_torch_function_in_microseconds(te_forw_backward)
@@ -208,12 +215,17 @@ def wrapper(*args, **kwargs):
(M, K, N),
ref_time,
float8_time,
float8_dynamic_time,
dtype,
compile,
te_time_sec=te_time_sec,
)
print(experiment)
print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec)
print(
"float8 dynamic speedup",
experiment.ref_time_sec / experiment.float8_dynamic_time_sec,
)
if transformer_engine_installed:
print("te speedup", experiment.ref_time_sec / experiment.te_time_sec)
experiment_list.append(experiment)
@@ -229,6 +241,7 @@ def wrapper(*args, **kwargs):
"fp8_dtype",
"ref_time_sec",
"pt_fp8_time_sec",
"pt_fp8_dynamic_time_sec",
"te_fp8_time_sec",
"ref_tops_sec",
"ref_pct_top_peak",
@@ -250,6 +263,7 @@ def wrapper(*args, **kwargs):
experiment.float_8_dtype,
experiment.ref_time_sec,
experiment.float8_time_sec,
experiment.float8_dynamic_time_sec,
experiment.te_time_sec,
experiment.ref_tops_sec,
experiment.ref_pct_top_peak,
@@ -262,6 +276,9 @@ def wrapper(*args, **kwargs):

data_pd = pd.DataFrame(data, columns=headers)
data_pd["pt_fp8_speedup"] = data_pd["ref_time_sec"] / data_pd["pt_fp8_time_sec"]
data_pd["pt_fp8_dynamic_speedup"] = (
data_pd["ref_time_sec"] / data_pd["pt_fp8_dynamic_time_sec"]
)
if transformer_engine_installed:
data_pd["te_fp8_speedup"] = data_pd["ref_time_sec"] / data_pd["te_fp8_time_sec"]
else:
@@ -280,12 +297,13 @@ def wrapper(*args, **kwargs):
[
"name",
"shape",
"ref_dtype",
"compiled",
"ref_time_sec",
"pt_fp8_time_sec",
"pt_fp8_dynamic_time_sec",
"te_fp8_time_sec",
"pt_fp8_speedup",
"pt_fp8_dynamic_speedup",
"te_fp8_speedup",
]
]
@@ -301,9 +319,12 @@ def invoke_main() -> None:
parser.add_argument("-o", "--output_path", type=str, required=True)
parser.add_argument("--compile", action="store_true")
parser.add_argument("-n", "--n_limit", type=int, required=False)
parser.add_argument(
"--llama_model_size", type=str, required=True, choices=["70B", "7B", "13B"]
)
args = parser.parse_args()
output_path = Path(args.output_path)
main(output_path, args.compile, args.n_limit)
main(output_path, args.compile, args.n_limit, args.llama_model_size)


if __name__ == "__main__":
53 changes: 21 additions & 32 deletions benchmarks/bench_matmul.py
@@ -7,29 +7,15 @@
import itertools
from typing import Optional

import bench_constants as bc

import fire
import pandas as pd

import torch
import torch.nn as nn
import torch.utils.benchmark as benchmark

# estimating TOPs for matmuls in fp32, fp16, fp8
# assuming A * B = C, with A being M * K, B being K * N, C being M * N

# H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/
h100_peak_flops_float32 = 67e12
h100_peak_flops_fp16_tc = 989e12
h100_peak_tops_float8_tc = 1979e12

dtype_to_peak_tops = {
torch.float32: h100_peak_flops_float32,
torch.float16: h100_peak_flops_fp16_tc,
torch.bfloat16: h100_peak_flops_fp16_tc,
torch.float8_e4m3fn: h100_peak_tops_float8_tc,
torch.float8_e5m2: h100_peak_tops_float8_tc,
}


def benchmark_fn_in_sec(f, *args, **kwargs):
# Manual warmup
@@ -50,25 +36,25 @@ def do_benchmarks(tops, peak_tops, f, *args, **kwargs):


@torch.inference_mode()
def run(n_limit: Optional[int] = None):
def run(
llama_model_size: Optional[str] = "70B",
n_limit: Optional[int] = None,
output_path: Optional[str] = None,
):
print("model size", llama_model_size)
device = "cuda"

# LLaMa 2 70B single-node weight shapes
# assumes fused attn.wqkv and ffn.w13
# source: https://fburl.com/gsheet/g8onr7rh
name_to_shapes_70b = {
"attn.wqkv": (8192, 1280),
"attn.w0": (1024, 8192),
"ffn.w13": (8192, 7168),
"ffn.w2": (3584, 8192),
}

headers = ("name", "shape", "dtype", "ref_time_s", "fp8_time_s", "fp8_speedup")
results = []

name_to_shapes = name_to_shapes_70b
bsz_and_seq_len = ((4, 4096),)
dtypes = torch.bfloat16, torch.float16
name_to_shapes = bc.name_to_shapes[llama_model_size]
if llama_model_size == "70B":
# common distributed setup, single GPU numbers
bsz_and_seq_len = ((4, 4096),)
else:
# debug single gpu setup
bsz_and_seq_len = ((1, 4096),)
dtypes = (torch.bfloat16,)

for idx, (dtype, (name, (K, N))) in enumerate(
itertools.product(dtypes, name_to_shapes.items())
@@ -88,7 +74,7 @@ def run(n_limit: Optional[int] = None):
A = torch.randn(M, K, device=device, dtype=dtype)
m_ref = nn.Sequential(nn.Linear(K, N, dtype=dtype, device=device, bias=False))
ref_time_sec, ref_tops_sec, ref_pct_top_peak = do_benchmarks(
tops, dtype_to_peak_tops[dtype], m_ref, A
tops, bc.dtype_to_peak_tops[dtype], m_ref, A
)
print(
f"{dtype} time_sec {ref_time_sec:.2E}, tops/sec {ref_tops_sec:.2E}, pct_peak {ref_pct_top_peak:.3f}"
@@ -106,7 +92,7 @@ def do_matmul(A, B):
return torch._scaled_mm(A, B, out_dtype=d3, use_fast_accum=False)

fp8_time_sec, fp8_tops_sec, fp8_pct_top_peak = do_benchmarks(
tops, dtype_to_peak_tops[d1], do_matmul, A, B
tops, bc.dtype_to_peak_tops[d1], do_matmul, A, B
)
print(
f"fp8 time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}"
@@ -127,6 +113,9 @@ def do_matmul(A, B):

data_pd = pd.DataFrame(results, columns=headers)
print(data_pd)
if output_path is not None:
with open(output_path, mode="w") as file:
data_pd.to_csv(output_path)


def main() -> None:
