This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit

fp8 benchmark testing
ipiszy committed Nov 18, 2023
1 parent 52aed83 commit d5cf228
Showing 1 changed file with 29 additions and 18 deletions.
47 changes: 29 additions & 18 deletions benchmarks/profile_linear_float8.py
@@ -10,6 +10,7 @@
 from typing import Callable, Optional
 
 import fire
+import functools
 
 import torch
 from float8_experimental.float8_linear_utils import (
@@ -19,6 +20,7 @@
     sync_float8_amax_and_scale_history,
 )
 from torch.profiler import profile, ProfilerActivity, record_function
+from torch._inductor.utils import do_bench_using_profiling
 
 
 @dataclass
@@ -73,6 +75,10 @@ def profile_function(
     if config.file_path is None:
         print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
 
+    full_func = functools.partial(func, *args, **kwargs)
+    latency = do_bench_using_profiling(full_func)
+    print(f"{func=}, {latency=}")
+
     return prof
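The new timing block above builds a zero-argument callable with functools.partial and hands it to do_bench_using_profiling. A standalone sketch of the same pattern, with made-up shapes, assuming a CUDA device and assuming that do_bench_using_profiling accepts a no-argument callable and returns a latency in milliseconds:

import functools

import torch
from torch._inductor.utils import do_bench_using_profiling

def run_linear(linear, x):
    return linear(x)

# Toy problem size; the helper profiles GPU kernels, so a CUDA device is assumed.
linear = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.bfloat16)
x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

full_func = functools.partial(run_linear, linear, x)
latency = do_bench_using_profiling(full_func)  # assumed to report milliseconds
print(f"{latency=}")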


@@ -134,24 +140,24 @@ def main(profile_path: Path, compile: bool, linear_type: str):
 
     def ref_forw_backward(x):
         if params.layer_norm:
-            with record_function("layer_norm"):
-                x = ln(x)
-        with record_function("forward"):
-            out = linear_ref(x)
-        with record_function("backward"):
-            out.sum().backward()
+            #with record_function("layer_norm"):
+            x = ln(x)
+        #with record_function("forward"):
+        out = linear_ref(x)
+        #with record_function("backward"):
+        out.sum().backward()
 
     def float8_forw_backward(x):
         if linear_requires_sync(linear_type):
-            with record_function("scale_amax_and_scales"):
-                sync_float8_amax_and_scale_history(linear_float8)
+            # with record_function("scale_amax_and_scales"):
+            sync_float8_amax_and_scale_history(linear_float8)
         if params.layer_norm:
-            with record_function("layer_norm"):
-                x = ln(x)
-        with record_function("forward"):
-            out = linear_float8(x)
-        with record_function("backward"):
-            out.sum().backward()
+            # with record_function("layer_norm"):
+            x = ln(x)
+        # with record_function("forward"):
+        out = linear_float8(x)
+        # with record_function("backward"):
+        out.sum().backward()
 
     if transformer_engine_installed:
         # Create an FP8 recipe. Note: All input args are optional.
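For context on the lines this change comments out: record_function wraps a code region so it appears as a named range in the profiler trace and in key_averages(). A small self-contained sketch, using a toy model rather than the float8 layers from this file:

import torch
from torch.profiler import profile, ProfilerActivity, record_function

def step(linear, x):
    with record_function("forward"):    # shows up as a labeled range in the trace
        out = linear(x)
    with record_function("backward"):
        out.sum().backward()

linear = torch.nn.Linear(64, 64)
x = torch.randn(32, 64)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    step(linear, x)
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))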
@@ -170,15 +176,20 @@ def te_forw_backward(x):
         out.sum().backward()
 
     if params.torch_compile:
-        ref_forw_backward = torch.compile(ref_forw_backward)
+        #ref_forw_backward = torch.compile(ref_forw_backward)
         float8_forw_backward = torch.compile(float8_forw_backward)
         # Compiling TE_linear fails but they are already compiling under the hood
         # if transformer_engine_installed:
         #     te_forw_backward = torch.compile(te_forw_backward)
 
+    def wrapper_float8(x):
+        if linear_requires_sync(linear_type):
+            sync_float8_amax_and_scale_history(linear_float8)
+        float8_forw_backward(x)
+
     for _ in range(5):
         ref_forw_backward(input_tensor)
-        float8_forw_backward(input_tensor)
+        wrapper_float8(input_tensor)
         if transformer_engine_installed:
             te_forw_backward(input_tensor)
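The benchmark loop now goes through wrapper_float8, which runs the amax/scale-history sync and then calls float8_forw_backward (compiled when params.torch_compile is set). A toy sketch of that wrapper shape, with hypothetical stand-ins for the float8 pieces:

import torch

def sync_bookkeeping():
    # stand-in for sync_float8_amax_and_scale_history(linear_float8)
    pass

def forw_backward(x):
    # stand-in for the float8 forward + backward being benchmarked
    (x * x).sum().backward()

forw_backward = torch.compile(forw_backward)

def wrapper(x):
    sync_bookkeeping()   # bookkeeping done in the wrapper
    forw_backward(x)     # the function under test

x = torch.randn(8, requires_grad=True)
for _ in range(5):
    wrapper(x)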

@@ -189,7 +200,7 @@ def te_forw_backward(x):
     )
     profile_function(profile_config, ref_forw_backward, input_tensor)
 
-    # # Profile Float8 Linear
+    # Profile Float8 Linear
     float8_string = f"linear_float8_M_{params.M}_K_{params.K}_N_{params.N}_input_bias_{params.input_bias}_compile_{params.torch_compile}_{linear_type}.json"
     profile_config = ProfileConfig(
         str(profile_path / float8_string),
@@ -198,7 +209,7 @@ def te_forw_backward(x):
         warmup_iters=5,
         sync=True,
     )
-    profile_function(profile_config, float8_forw_backward, input_tensor)
+    profile_function(profile_config, wrapper_float8, input_tensor)
 
     te_string = f"linear_transformer_engine_M_{params.M}_K_{params.K}_N_{params.N}_input_bias_{params.input_bias}.json"
     if transformer_engine_installed:
