QOL improvements to linear benchmarking script (#278)
Summary:
Pull Request resolved: #278

1. add more command-line filters
2. add dynamic scaling
3. remove float16, since it is low priority and dropping it cuts benchmark
   time by 50%

Reviewed By: drisspg

Differential Revision: D58396927

fbshipit-source-id: 298cb3c48418d4b9dd1529fe38cf2229ab5618b7
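
A quick usage sketch of the new filters. The flag names and the `main()`
signature come from the argparse changes in the diff below; the import path,
the limit of 4, and the choice of the "ffn.w2" shape (the one visible in this
diff) are illustrative assumptions, not part of the commit:

    # CLI form (the output path is now optional):
    #   python benchmarks/bench_linear_float8.py --compile \
    #       --linear_type_filter dynamic --fast_accum_filter True \
    #       --shape_name_filter ffn.w2 -n 4
    #
    # Programmatic form, calling main() directly:
    from benchmarks.bench_linear_float8 import main  # assumed import path

    main(
        sweep_path=None,            # skip the CSV dump
        compile=True,
        n_limit=4,                  # stop after the first 4 experiments
        fast_accum_filter=True,     # benchmark only fast-accum matmuls
        shape_name_filter="ffn.w2",
        linear_type_filter="dynamic",
    )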
vkuzo authored and facebook-github-bot committed Jun 14, 2024
1 parent 5d293a7 commit 1e9add3
Showing 1 changed file with 70 additions and 19 deletions.
benchmarks/bench_linear_float8.py (89 changes: 70 additions & 19 deletions)
@@ -14,8 +14,13 @@

import torch
import torch.utils.benchmark as benchmark
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history
from float8_experimental.float8_linear_utils import (
get_float8_linear,
LinearType,
sync_float8_amax_and_scale_history,
)
from float8_experimental.float8_tensor import ScaledMMConfig
from tqdm import tqdm

@@ -35,6 +40,13 @@
torch.float8_e5m2: h100_peak_tops_float8_tc,
}

# prevent splitting columns when printing a data frame
pd.set_option("display.expand_frame_repr", False)
# print the entire data frame
pd_print_full_ctx = pd.option_context(
"display.max_rows", None, "display.max_columns", None
)


def benchmark_torch_function_in_microseconds(
func: Callable,
@@ -57,6 +69,7 @@ class Experiment:
dtype: torch.dtype
compiled: bool
use_fast_accum: bool
linear_type: str

# 3 Times since we are calculating forward backward
@property
@@ -79,9 +92,12 @@ def float8_pct_top_peak(self):


def main(
sweep_path: Path,
compile: bool,
sweep_path: Optional[Path] = None,
compile: bool = False,
n_limit: Optional[int] = None,
fast_accum_filter: Optional[bool] = None,
shape_name_filter: Optional[str] = None,
linear_type_filter: Optional[str] = None,
):
device = "cuda"
print(f"Compile is set to | {compile}")
@@ -95,20 +111,33 @@ def main(
"ffn.w2": (3584, 8192),
}
input_bias = False
ref_dtypes = [torch.bfloat16, torch.float16]
use_fast_accum = [True, False]
if fast_accum_filter is not None:
use_fast_accum = [fast_accum_filter]
else:
use_fast_accum = [True, False]
if linear_type_filter is not None:
linear_types = [linear_type_filter]
else:
linear_types = ["delayed", "dynamic"]
if shape_name_filter is not None:
k = shape_name_filter
name_to_shapes_70b = {k: name_to_shapes_70b[k]}
experiment_list: List[Experiment] = []
for idx, (dtype, fast_accum, (name, (K, N))) in enumerate(
tqdm(list(product(ref_dtypes, use_fast_accum, name_to_shapes_70b.items())))
dtype = torch.bfloat16
for idx, (fast_accum, (name, (K, N)), linear_type) in enumerate(
tqdm(list(product(use_fast_accum, name_to_shapes_70b.items(), linear_types)))
):
if n_limit is not None and idx >= n_limit:
break
linear_ref = torch.nn.Linear(K, N, bias=input_bias).to(
device=device, dtype=dtype
)
linear_type_enum = (
LinearType.DELAYED if linear_type == "delayed" else LinearType.DYNAMIC
)

linear_float8 = Float8Linear.from_float(
copy.deepcopy(linear_ref), emulate=False
linear_float8 = get_float8_linear(
linear_type_enum, copy.deepcopy(linear_ref), emulate=False
)
if fast_accum:
linear_float8.forward_config = ScaledMMConfig(False, True, False)
@@ -120,9 +149,16 @@ def main(
input_tensor = torch.randn(M, K, device=device, dtype=dtype, requires_grad=True)
ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward()

def float8_forw_backward():
sync_float8_amax_and_scale_history(linear_float8)
linear_float8(input_tensor).sum().backward()
if linear_type_enum == LinearType.DELAYED:

def float8_forw_backward():
sync_float8_amax_and_scale_history(linear_float8)
linear_float8(input_tensor).sum().backward()

else:

def float8_forw_backward():
linear_float8(input_tensor).sum().backward()

def n_times(n, fn, *args, **kwargs):
def wrapper(*args, **kwargs):
@@ -162,6 +198,7 @@ def wrapper(*args, **kwargs):
dtype,
compile,
use_fast_accum=fast_accum,
linear_type=linear_type,
)
print(experiment)
print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec)
@@ -173,6 +210,7 @@ def wrapper(*args, **kwargs):
"M",
"K",
"N",
"linear_type",
"ref_dtype",
"compiled",
"use_fast_accum",
@@ -191,6 +229,7 @@ def wrapper(*args, **kwargs):
experiment.shape[0],
experiment.shape[1],
experiment.shape[2],
experiment.linear_type,
experiment.dtype,
experiment.compiled,
experiment.use_fast_accum,
@@ -219,28 +258,40 @@ def wrapper(*args, **kwargs):
[
"name",
"shape",
"ref_dtype",
"linear_type",
"compiled",
"use_fast_accum",
"ref_time_sec",
"pt_fp8_time_sec",
"pt_fp8_speedup",
]
]
print(data_pd_simple)
with pd_print_full_ctx:
print(data_pd_simple)

sweep_path = sweep_path.with_suffix(".csv")
data_pd.to_csv(sweep_path)
if sweep_path is not None:
sweep_path = sweep_path.with_suffix(".csv")
data_pd.to_csv(sweep_path)


def invoke_main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("-o", "--output_path", type=str, required=True)
parser.add_argument("-o", "--output_path", type=str, required=False)
parser.add_argument("--compile", action="store_true")
parser.add_argument("-n", "--n_limit", type=int, required=False)
parser.add_argument("--fast_accum_filter", type=bool, required=False)
parser.add_argument("--shape_name_filter", type=str, required=False)
parser.add_argument("--linear_type_filter", type=str, required=False)
args = parser.parse_args()
output_path = Path(args.output_path)
main(output_path, args.compile, args.n_limit)
output_path = Path(args.output_path) if args.output_path is not None else None
main(
output_path,
args.compile,
args.n_limit,
args.fast_accum_filter,
args.shape_name_filter,
args.linear_type_filter,
)


if __name__ == "__main__":

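For context, a minimal sketch of the delayed-vs-dynamic split that the new
`linear_type` dimension benchmarks. The identifiers and call shapes all come
from the diff above; the M size, device, and module path are assumptions.
Delayed scaling reads scales from a recorded amax history, so the history must
be synced before each step, while dynamic scaling computes scales inline and
needs no sync call:

    import copy

    import torch
    from float8_experimental.float8_linear_utils import (
        get_float8_linear,
        LinearType,
        sync_float8_amax_and_scale_history,
    )

    M, K, N = 16, 3584, 8192  # K/N are the "ffn.w2" shape; M is assumed
    ref = torch.nn.Linear(K, N, bias=False).to(device="cuda", dtype=torch.bfloat16)
    fp8 = get_float8_linear(LinearType.DELAYED, copy.deepcopy(ref), emulate=False)
    x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16, requires_grad=True)

    # delayed scaling: sync the amax/scale history before the step;
    # with LinearType.DYNAMIC the sync call below would be omitted
    sync_float8_amax_and_scale_history(fp8)
    fp8(x).sum().backward()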