diff --git a/benchmarks/profile_linear_float8.py b/benchmarks/profile_linear_float8.py
index 407913af..d27ae174 100644
--- a/benchmarks/profile_linear_float8.py
+++ b/benchmarks/profile_linear_float8.py
@@ -10,6 +10,7 @@
 from typing import Callable, Optional
 
 import fire
+import functools
 import torch
 
 from float8_experimental.float8_linear_utils import (
@@ -19,6 +20,7 @@
     sync_float8_amax_and_scale_history,
 )
 from torch.profiler import profile, ProfilerActivity, record_function
+from torch._inductor.utils import do_bench_using_profiling
 
 
 @dataclass
@@ -73,6 +75,10 @@ def profile_function(
     if config.file_path is None:
         print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
 
+    full_func = functools.partial(func, *args, **kwargs)
+    latency = do_bench_using_profiling(full_func)
+    print(f"{func=}, {latency=}")
+
    return prof
 
 
@@ -134,24 +140,24 @@ def main(profile_path: Path, compile: bool, linear_type: str):
 
     def ref_forw_backward(x):
         if params.layer_norm:
-            with record_function("layer_norm"):
-                x = ln(x)
-        with record_function("forward"):
-            out = linear_ref(x)
-        with record_function("backward"):
-            out.sum().backward()
+            # with record_function("layer_norm"):
+            x = ln(x)
+        # with record_function("forward"):
+        out = linear_ref(x)
+        # with record_function("backward"):
+        out.sum().backward()
 
     def float8_forw_backward(x):
         if linear_requires_sync(linear_type):
-            with record_function("scale_amax_and_scales"):
-                sync_float8_amax_and_scale_history(linear_float8)
+            # with record_function("scale_amax_and_scales"):
+            sync_float8_amax_and_scale_history(linear_float8)
         if params.layer_norm:
-            with record_function("layer_norm"):
-                x = ln(x)
-        with record_function("forward"):
-            out = linear_float8(x)
-        with record_function("backward"):
-            out.sum().backward()
+            # with record_function("layer_norm"):
+            x = ln(x)
+        # with record_function("forward"):
+        out = linear_float8(x)
+        # with record_function("backward"):
+        out.sum().backward()
 
     if transformer_engine_installed:
         # Create an FP8 recipe. Note: All input args are optional.
@@ -170,15 +176,20 @@ def te_forw_backward(x):
         out.sum().backward()
 
     if params.torch_compile:
-        ref_forw_backward = torch.compile(ref_forw_backward)
+        # ref_forw_backward = torch.compile(ref_forw_backward)
         float8_forw_backward = torch.compile(float8_forw_backward)
         # Compiling TE_linear fails but they are already compiling under the hood
         # if transformer_engine_installed:
         #     te_forw_backward = torch.compile(te_forw_backward)
 
+    def wrapper_float8(x):
+        if linear_requires_sync(linear_type):
+            sync_float8_amax_and_scale_history(linear_float8)
+        float8_forw_backward(x)
+
     for _ in range(5):
         ref_forw_backward(input_tensor)
-        float8_forw_backward(input_tensor)
+        wrapper_float8(input_tensor)
         if transformer_engine_installed:
             te_forw_backward(input_tensor)
 
@@ -189,7 +200,7 @@
     )
     profile_function(profile_config, ref_forw_backward, input_tensor)
 
-    # # Profile Float8 Linear
+    # Profile Float8 Linear
     float8_string = f"linear_float8_M_{params.M}_K_{params.K}_N_{params.N}_input_bias_{params.input_bias}_compile_{params.torch_compile}_{linear_type}.json"
     profile_config = ProfileConfig(
         str(profile_path / float8_string),
@@ -198,7 +209,7 @@
         warmup_iters=5,
         sync=True,
     )
-    profile_function(profile_config, float8_forw_backward, input_tensor)
+    profile_function(profile_config, wrapper_float8, input_tensor)
 
     te_string = f"linear_transformer_engine_M_{params.M}_K_{params.K}_N_{params.N}_input_bias_{params.input_bias}.json"
     if transformer_engine_installed:
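
For context, the timing path this diff adds to `profile_function` boils down to the pattern below: bind the profiled function's arguments with `functools.partial` so the benchmarked callable takes no arguments, then hand it to `do_bench_using_profiling`. A minimal sketch, assuming a PyTorch build that ships `torch._inductor.utils.do_bench_using_profiling` and a CUDA device; the `matmul` toy function is a hypothetical stand-in for the profiled `func`.

```python
# Minimal sketch of the benchmarking pattern added to profile_function.
# Assumes torch._inductor.utils.do_bench_using_profiling is available
# (recent PyTorch) and that a CUDA device is present.
import functools

import torch
from torch._inductor.utils import do_bench_using_profiling


def matmul(a, b):
    return a @ b


a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")

# Bind the arguments up front so the benchmarked callable is zero-argument,
# mirroring functools.partial(func, *args, **kwargs) in the diff.
full_func = functools.partial(matmul, a, b)

# do_bench_using_profiling runs the callable under the profiler and returns
# a latency estimate (milliseconds per call in current PyTorch sources).
latency = do_bench_using_profiling(full_func)
print(f"{latency=}")
```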
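The `wrapper_float8` indirection splits the float8 step so that only the forward/backward runs under `torch.compile`, while the amax/scale-history sync stays eager. A self-contained sketch of that split, with a plain `nn.Linear` and a no-op `dummy_sync` standing in for the float8 linear and `sync_float8_amax_and_scale_history`:

```python
# Sketch of the wrapper pattern: keep stateful bookkeeping (the sync) outside
# the compiled region and compile only the forward/backward path.
import torch
import torch.nn as nn

linear = nn.Linear(1024, 1024, device="cuda")


def dummy_sync(module):
    # Placeholder for sync_float8_amax_and_scale_history(linear_float8).
    pass


def forw_backward(x):
    out = linear(x)
    out.sum().backward()


forw_backward = torch.compile(forw_backward)


def wrapper(x):
    dummy_sync(linear)  # runs eagerly, outside the compiled graph
    forw_backward(x)


x = torch.randn(4096, 1024, device="cuda", requires_grad=True)
for _ in range(5):
    wrapper(x)
```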