This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

update matmul and linear benchmarks with 7B and 13B shapes #179

Open · wants to merge 1 commit into base: main
49 changes: 49 additions & 0 deletions benchmarks/bench_constants.py
@@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import torch

# estimating TOPs for matmuls in fp32, fp16, fp8
# assuming A * B = C, with A being M * K, B being K * N, C being M * N

# H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/
h100_peak_flops_float32 = 67e12
h100_peak_flops_fp16_tc = 1979e12
h100_peak_tops_float8_tc = 3958e12

dtype_to_peak_tops = {
torch.float32: h100_peak_flops_float32,
torch.float16: h100_peak_flops_fp16_tc,
torch.bfloat16: h100_peak_flops_fp16_tc,
torch.float8_e4m3fn: h100_peak_tops_float8_tc,
torch.float8_e5m2: h100_peak_tops_float8_tc,
}

name_to_shapes = {
# LLaMa 2 70B single-node weight shapes
# assumes fused attn.wqkv and ffn.w13
# source: https://fburl.com/gsheet/g8onr7rh
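# e.g. with model dim 8192, 64 query heads, 8 KV heads of head_dim 128, and
# 8-way tensor parallelism across the node (assumed), the fused wqkv weight
# per GPU is (8192, (8192 + 2 * 8 * 128) // 8) = (8192, 1280)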
"70B": {
"attn.wqkv": (8192, 1280),
"attn.w0": (1024, 8192),
"ffn.w13": (8192, 7168),
"ffn.w2": (3584, 8192),
},
# source: LLaMa 2 7B def, unfused ffn
"7B": {
"attn.wqkv": (4096, 12288),
"attn.w0": (4096, 4096),
"ffn.w1_or_w3": (4096, 11008),
"ffn.w2": (11008, 4096),
},
# source: LLaMa 2 13B def, unfused ffn
"13B": {
"attn.wqkv": (5120, 15360),
"attn.w0": (5120, 5120),
"ffn.w1_or_w3": (5120, 13824),
"ffn.w2": (13824, 5120),
},
}
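For orientation, a minimal sketch of how these constants combine into a percent-of-peak estimate. It is illustrative only: the shape key, batch/sequence sizes, and measured time are placeholders, and 2 * M * K * N FLOPs is assumed for a single forward matmul.

import torch

import bench_constants as bc

# illustrative shape: LLaMa 2 7B ffn.w1_or_w3 at bsz=1, seq_len=4096 (placeholders)
K, N = bc.name_to_shapes["7B"]["ffn.w1_or_w3"]
M = 1 * 4096

flops = 2 * M * K * N  # one forward matmul: A (M x K) @ B (K x N)
measured_time_sec = 1.0e-3  # placeholder: substitute a benchmarked time
achieved_flops_sec = flops / measured_time_sec
pct_of_peak = achieved_flops_sec / bc.dtype_to_peak_tops[torch.bfloat16]
print(f"achieved {achieved_flops_sec:.2e} FLOP/s, {pct_of_peak:.1%} of bf16 tensor-core peak")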
87 changes: 54 additions & 33 deletions benchmarks/bench_linear_float8.py
@@ -10,10 +10,13 @@
from pathlib import Path
from typing import Callable, List, Optional, Tuple

import bench_constants as bc

import pandas as pd

import torch
import torch.utils.benchmark as benchmark
from float8_experimental.dynamic_linear.dynamic_float8_linear import Float8DynamicLinear
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history
from tqdm import tqdm
@@ -28,22 +31,6 @@
except ImportError:
print("transformer_engine not installed and we won't compare against this")

# estimating TOPs for matmuls in fp32, fp16, fp8
# assuming A * B = C, with A being M * K, B being K * N, C being M * N

# H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/
h100_peak_flops_float32 = 67e12
h100_peak_flops_fp16_tc = 1979e12
h100_peak_tops_float8_tc = 3958e12

dtype_to_peak_tops = {
torch.float32: h100_peak_flops_float32,
torch.float16: h100_peak_flops_fp16_tc,
torch.bfloat16: h100_peak_flops_fp16_tc,
torch.float8_e4m3fn: h100_peak_tops_float8_tc,
torch.float8_e5m2: h100_peak_tops_float8_tc,
}


def benchmark_torch_function_in_microseconds(
func: Callable,
@@ -63,6 +50,7 @@ class Experiment:
shape: Tuple[int, int, int]
ref_time_sec: float
float8_time_sec: float
float8_dynamic_time_sec: float
dtype: torch.dtype
compiled: bool = False
float_8_dtype: Optional[torch.dtype] = torch.float8_e4m3fn
@@ -76,7 +64,7 @@ def ref_tops_sec(self):

@property
def ref_pct_top_peak(self):
return self.ref_tops_sec / dtype_to_peak_tops[self.dtype]
return self.ref_tops_sec / bc.dtype_to_peak_tops[self.dtype]

@property
def float8_tops_sec(self):
@@ -85,7 +73,7 @@ def float8_tops_sec(self):

@property
def float8_pct_top_peak(self):
return self.float8_tops_sec / dtype_to_peak_tops[self.float_8_dtype]
return self.float8_tops_sec / bc.dtype_to_peak_tops[self.float_8_dtype]

@property
def te_tops_sec(self):
@@ -98,7 +86,7 @@ def te_tops_sec(self):
@property
def te_pct_top_peak(self):
if self.te_tops_sec is not None:
return self.te_tops_sec / dtype_to_peak_tops[self.float_8_dtype]
return self.te_tops_sec / bc.dtype_to_peak_tops[self.float_8_dtype]
else:
return None

@@ -107,24 +95,27 @@ def main(
sweep_path: Path,
compile: bool,
n_limit: Optional[int] = None,
llama_model_size: str = "70B",
):
device = "cuda"
print(f"Compile is set to | {compile}")
print("model size:", llama_model_size)

name_to_shapes = bc.name_to_shapes[llama_model_size]
if llama_model_size == "70B":
# common distributed setup, single GPU numbers
bsz, seq_len = 4, 4096
else:
# debug single gpu setup
bsz, seq_len = 1, 4096

# LLaMa 2 70B single-node weight shapes
# assumes fused attn.wqkv and ffn.w13
# source: https://fburl.com/gsheet/g8onr7rh
name_to_shapes_70b = {
"attn.wqkv": (8192, 1280),
"attn.w0": (1024, 8192),
"ffn.w13": (8192, 7168),
"ffn.w2": (3584, 8192),
}
input_bias = False
ref_dtypes = [torch.bfloat16, torch.float16]
ref_dtypes = [
torch.bfloat16,
]
experiment_list: List[Experiment] = []
for idx, (dtype, (name, (K, N))) in enumerate(
tqdm(list(product(ref_dtypes, name_to_shapes_70b.items())))
tqdm(list(product(ref_dtypes, name_to_shapes.items())))
):
if n_limit is not None and idx >= n_limit:
break
@@ -136,7 +127,10 @@ def main(
copy.deepcopy(linear_ref), emulate=False
)

bsz, seq_len = 4, 4096
linear_dynamic_float8 = Float8DynamicLinear.from_float(
copy.deepcopy(linear_ref), emulate=False
)

M = bsz * seq_len
input_tensor = torch.randn(M, K, device=device, dtype=dtype, requires_grad=True)
ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward()
@@ -145,6 +139,10 @@ def float8_forw_backward():
sync_float8_amax_and_scale_history(linear_float8)
linear_float8(input_tensor).sum().backward()

float8_dynamic_forw_backward = (
lambda: linear_dynamic_float8(input_tensor).sum().backward()
)
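# note: Float8DynamicLinear computes scales from the current tensors on each
# call, so unlike the delayed-scaling float8_forw_backward above there is no
# amax/scale history to sync before running it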

if transformer_engine_installed:
# Use the same recipe as float8_linear.DelayedScalingRecipe
fp8_format = recipe.Format.HYBRID
@@ -169,19 +167,23 @@ def wrapper(*args, **kwargs):

ref_forw_backward = n_times(REPEAT_N, ref_forw_backward)
float8_forw_backward = n_times(REPEAT_N, float8_forw_backward)
float8_dynamic_forw_backward = n_times(REPEAT_N, float8_dynamic_forw_backward)
if transformer_engine_installed:
te_forw_backward = n_times(REPEAT_N, te_forw_backward)

if compile:
ref_forw_backward = torch.compile(ref_forw_backward)
float8_forw_backward = torch.compile(float8_forw_backward)
float8_dynamic_forw_backward = torch.compile(float8_dynamic_forw_backward)
# Compiling TE_linear fails but they are already compiling under the hood
# if transformer_engine_installed:
# te_forw_backward = torch.compile(te_forw_backward)

# warmup
for _ in range(5):
ref_forw_backward()
float8_forw_backward()
float8_dynamic_forw_backward()
if transformer_engine_installed:
te_forw_backward()

@@ -195,6 +197,11 @@ def wrapper(*args, **kwargs):
* 1e-6
/ REPEAT_N
)
float8_dynamic_time = (
benchmark_torch_function_in_microseconds(float8_dynamic_forw_backward)
* 1e-6
/ REPEAT_N
)
if transformer_engine_installed:
te_time_sec = (
benchmark_torch_function_in_microseconds(te_forw_backward)
@@ -208,12 +215,17 @@ def wrapper(*args, **kwargs):
(M, K, N),
ref_time,
float8_time,
float8_dynamic_time,
dtype,
compile,
te_time_sec=te_time_sec,
)
print(experiment)
print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec)
print(
"float8 dynamic speedup",
experiment.ref_time_sec / experiment.float8_dynamic_time_sec,
)
if transformer_engine_installed:
print("te speedup", experiment.ref_time_sec / experiment.te_time_sec)
experiment_list.append(experiment)
@@ -229,6 +241,7 @@ def wrapper(*args, **kwargs):
"fp8_dtype",
"ref_time_sec",
"pt_fp8_time_sec",
"pt_fp8_dynamic_time_sec",
"te_fp8_time_sec",
"ref_tops_sec",
"ref_pct_top_peak",
@@ -250,6 +263,7 @@ def wrapper(*args, **kwargs):
experiment.float_8_dtype,
experiment.ref_time_sec,
experiment.float8_time_sec,
experiment.float8_dynamic_time_sec,
experiment.te_time_sec,
experiment.ref_tops_sec,
experiment.ref_pct_top_peak,
@@ -262,6 +276,9 @@ def wrapper(*args, **kwargs):

data_pd = pd.DataFrame(data, columns=headers)
data_pd["pt_fp8_speedup"] = data_pd["ref_time_sec"] / data_pd["pt_fp8_time_sec"]
data_pd["pt_fp8_dynamic_speedup"] = (
data_pd["ref_time_sec"] / data_pd["pt_fp8_dynamic_time_sec"]
)
if transformer_engine_installed:
data_pd["te_fp8_speedup"] = data_pd["ref_time_sec"] / data_pd["te_fp8_time_sec"]
else:
@@ -280,12 +297,13 @@ def wrapper(*args, **kwargs):
[
"name",
"shape",
"ref_dtype",
"compiled",
"ref_time_sec",
"pt_fp8_time_sec",
"pt_fp8_dynamic_time_sec",
"te_fp8_time_sec",
"pt_fp8_speedup",
"pt_fp8_dynamic_speedup",
"te_fp8_speedup",
]
]
@@ -301,9 +319,12 @@ def invoke_main() -> None:
parser.add_argument("-o", "--output_path", type=str, required=True)
parser.add_argument("--compile", action="store_true")
parser.add_argument("-n", "--n_limit", type=int, required=False)
parser.add_argument(
"--llama_model_size", type=str, required=True, choices=["70B", "7B", "13B"]
)
args = parser.parse_args()
output_path = Path(args.output_path)
main(output_path, args.compile, args.n_limit)
main(output_path, args.compile, args.n_limit, args.llama_model_size)


if __name__ == "__main__":
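A hypothetical invocation of the updated linear benchmark; the output path and model size below are placeholders, and the flag names follow invoke_main above.

# command-line form (run from the repository root):
#   python benchmarks/bench_linear_float8.py -o /tmp/linear_fp8_7B.csv --compile --llama_model_size 7B
# or, calling main() directly with equivalent arguments:
from pathlib import Path

from bench_linear_float8 import main  # assumes benchmarks/ is on sys.path

main(Path("/tmp/linear_fp8_7B.csv"), compile=True, n_limit=None, llama_model_size="7B")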
53 changes: 21 additions & 32 deletions benchmarks/bench_matmul.py
@@ -7,29 +7,15 @@
import itertools
from typing import Optional

import bench_constants as bc

import fire
import pandas as pd

import torch
import torch.nn as nn
import torch.utils.benchmark as benchmark

# estimating TOPs for matmuls in fp32, fp16, fp8
# assuming A * B = C, with A being M * K, B being K * N, C being M * N

# H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/
h100_peak_flops_float32 = 67e12
h100_peak_flops_fp16_tc = 989e12
h100_peak_tops_float8_tc = 1979e12

dtype_to_peak_tops = {
torch.float32: h100_peak_flops_float32,
torch.float16: h100_peak_flops_fp16_tc,
torch.bfloat16: h100_peak_flops_fp16_tc,
torch.float8_e4m3fn: h100_peak_tops_float8_tc,
torch.float8_e5m2: h100_peak_tops_float8_tc,
}


def benchmark_fn_in_sec(f, *args, **kwargs):
# Manual warmup
@@ -50,25 +36,25 @@ def do_benchmarks(tops, peak_tops, f, *args, **kwargs):


@torch.inference_mode()
def run(n_limit: Optional[int] = None):
def run(
llama_model_size: Optional[str] = "70B",
n_limit: Optional[int] = None,
output_path: Optional[str] = None,
):
print("model size", llama_model_size)
device = "cuda"

# LLaMa 2 70B single-node weight shapes
# assumes fused attn.wqkv and ffn.w13
# source: https://fburl.com/gsheet/g8onr7rh
name_to_shapes_70b = {
"attn.wqkv": (8192, 1280),
"attn.w0": (1024, 8192),
"ffn.w13": (8192, 7168),
"ffn.w2": (3584, 8192),
}

headers = ("name", "shape", "dtype", "ref_time_s", "fp8_time_s", "fp8_speedup")
results = []

name_to_shapes = name_to_shapes_70b
bsz_and_seq_len = ((4, 4096),)
dtypes = torch.bfloat16, torch.float16
name_to_shapes = bc.name_to_shapes[llama_model_size]
if llama_model_size == "70B":
# common distributed setup, single GPU numbers
bsz_and_seq_len = ((4, 4096),)
else:
# debug single gpu setup
bsz_and_seq_len = ((1, 4096),)
dtypes = (torch.bfloat16,)

for idx, (dtype, (name, (K, N))) in enumerate(
itertools.product(dtypes, name_to_shapes.items())
@@ -88,7 +74,7 @@ def run(n_limit: Optional[int] = None):
A = torch.randn(M, K, device=device, dtype=dtype)
m_ref = nn.Sequential(nn.Linear(K, N, dtype=dtype, device=device, bias=False))
ref_time_sec, ref_tops_sec, ref_pct_top_peak = do_benchmarks(
tops, dtype_to_peak_tops[dtype], m_ref, A
tops, bc.dtype_to_peak_tops[dtype], m_ref, A
)
print(
f"{dtype} time_sec {ref_time_sec:.2E}, tops/sec {ref_tops_sec:.2E}, pct_peak {ref_pct_top_peak:.3f}"
@@ -106,7 +92,7 @@ def do_matmul(A, B):
return torch._scaled_mm(A, B, out_dtype=d3, use_fast_accum=False)

fp8_time_sec, fp8_tops_sec, fp8_pct_top_peak = do_benchmarks(
tops, dtype_to_peak_tops[d1], do_matmul, A, B
tops, bc.dtype_to_peak_tops[d1], do_matmul, A, B
)
print(
f"fp8 time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}"
@@ -127,6 +113,9 @@ def do_matmul(A, B):

data_pd = pd.DataFrame(results, columns=headers)
print(data_pd)
if output_path is not None:
with open(output_path, mode="w") as file:
data_pd.to_csv(file)


def main() -> None:
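Similarly, a hypothetical invocation of the matmul benchmark: if main() wires run() up with fire (the fire import above suggests this, but its body is not shown here), the new parameters become CLI flags. The output path is a placeholder.

# command-line form (run from the repository root):
#   python benchmarks/bench_matmul.py --llama_model_size 13B --output_path /tmp/matmul_fp8_13B.csv
# or, calling run() directly with equivalent arguments:
from bench_matmul import run  # assumes benchmarks/ is on sys.path

run(llama_model_size="13B", n_limit=None, output_path="/tmp/matmul_fp8_13B.csv")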