diff --git a/benchmarks/bench_linear_float8.py b/benchmarks/bench_linear_float8.py
index 58bf884..df21c30 100644
--- a/benchmarks/bench_linear_float8.py
+++ b/benchmarks/bench_linear_float8.py
@@ -14,8 +14,13 @@
 
 import torch
 import torch.utils.benchmark as benchmark
+from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
 from float8_experimental.float8_linear import Float8Linear
-from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history
+from float8_experimental.float8_linear_utils import (
+    get_float8_linear,
+    LinearType,
+    sync_float8_amax_and_scale_history,
+)
 from float8_experimental.float8_tensor import ScaledMMConfig
 from tqdm import tqdm
 
@@ -35,6 +40,13 @@
     torch.float8_e5m2: h100_peak_tops_float8_tc,
 }
 
+# prevent splitting columns when printing a data frame
+pd.set_option("display.expand_frame_repr", False)
+# print the entire data frame
+pd_print_full_ctx = pd.option_context(
+    "display.max_rows", None, "display.max_columns", None
+)
+
 
 def benchmark_torch_function_in_microseconds(
     func: Callable,
@@ -57,6 +69,7 @@ class Experiment:
     dtype: torch.dtype
     compiled: bool
     use_fast_accum: bool
+    linear_type: str
 
     # 3 Times since we are calculating forward backward
     @property
@@ -79,9 +92,12 @@ def float8_pct_top_peak(self):
 
 
 def main(
-    sweep_path: Path,
-    compile: bool,
+    sweep_path: Optional[Path] = None,
+    compile: bool = False,
     n_limit: Optional[int] = None,
+    fast_accum_filter: Optional[bool] = None,
+    shape_name_filter: Optional[str] = None,
+    linear_type_filter: Optional[str] = None,
 ):
     device = "cuda"
     print(f"Compile is set to | {compile}")
@@ -95,20 +111,33 @@ def main(
         "ffn.w2": (3584, 8192),
     }
     input_bias = False
-    ref_dtypes = [torch.bfloat16, torch.float16]
-    use_fast_accum = [True, False]
+    if fast_accum_filter is not None:
+        use_fast_accum = [fast_accum_filter]
+    else:
+        use_fast_accum = [True, False]
+    if linear_type_filter is not None:
+        linear_types = [linear_type_filter]
+    else:
+        linear_types = ["delayed", "dynamic"]
+    if shape_name_filter is not None:
+        k = shape_name_filter
+        name_to_shapes_70b = {k: name_to_shapes_70b[k]}
     experiment_list: List[Experiment] = []
-    for idx, (dtype, fast_accum, (name, (K, N))) in enumerate(
-        tqdm(list(product(ref_dtypes, use_fast_accum, name_to_shapes_70b.items())))
+    dtype = torch.bfloat16
+    for idx, (fast_accum, (name, (K, N)), linear_type) in enumerate(
+        tqdm(list(product(use_fast_accum, name_to_shapes_70b.items(), linear_types)))
     ):
         if n_limit is not None and idx >= n_limit:
             break
         linear_ref = torch.nn.Linear(K, N, bias=input_bias).to(
             device=device, dtype=dtype
         )
+        linear_type_enum = (
+            LinearType.DELAYED if linear_type == "delayed" else LinearType.DYNAMIC
+        )
 
-        linear_float8 = Float8Linear.from_float(
-            copy.deepcopy(linear_ref), emulate=False
+        linear_float8 = get_float8_linear(
+            linear_type_enum, copy.deepcopy(linear_ref), emulate=False
         )
         if fast_accum:
             linear_float8.forward_config = ScaledMMConfig(False, True, False)
@@ -120,9 +149,16 @@ def main(
         input_tensor = torch.randn(M, K, device=device, dtype=dtype, requires_grad=True)
         ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward()
 
-        def float8_forw_backward():
-            sync_float8_amax_and_scale_history(linear_float8)
-            linear_float8(input_tensor).sum().backward()
+        if linear_type_enum == LinearType.DELAYED:
+
+            def float8_forw_backward():
+                sync_float8_amax_and_scale_history(linear_float8)
+                linear_float8(input_tensor).sum().backward()
+
+        else:
+
+            def float8_forw_backward():
+                linear_float8(input_tensor).sum().backward()
 
         def n_times(n, fn, *args, **kwargs):
             def wrapper(*args, **kwargs):
@@ -162,6 +198,7 @@ def wrapper(*args, **kwargs):
             dtype,
             compile,
             use_fast_accum=fast_accum,
+            linear_type=linear_type,
         )
         print(experiment)
         print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec)
@@ -173,6 +210,7 @@ def wrapper(*args, **kwargs):
         "M",
         "K",
         "N",
+        "linear_type",
         "ref_dtype",
         "compiled",
         "use_fast_accum",
@@ -191,6 +229,7 @@ def wrapper(*args, **kwargs):
                 experiment.shape[0],
                 experiment.shape[1],
                 experiment.shape[2],
+                experiment.linear_type,
                 experiment.dtype,
                 experiment.compiled,
                 experiment.use_fast_accum,
@@ -219,7 +258,7 @@ def wrapper(*args, **kwargs):
         [
             "name",
             "shape",
-            "ref_dtype",
+            "linear_type",
             "compiled",
             "use_fast_accum",
             "ref_time_sec",
@@ -227,20 +266,32 @@ def wrapper(*args, **kwargs):
             "pt_fp8_speedup",
         ]
     ]
-    print(data_pd_simple)
+    with pd_print_full_ctx:
+        print(data_pd_simple)
 
-    sweep_path = sweep_path.with_suffix(".csv")
-    data_pd.to_csv(sweep_path)
+    if sweep_path is not None:
+        sweep_path = sweep_path.with_suffix(".csv")
+        data_pd.to_csv(sweep_path)
 
 
 def invoke_main() -> None:
     parser = argparse.ArgumentParser()
-    parser.add_argument("-o", "--output_path", type=str, required=True)
+    parser.add_argument("-o", "--output_path", type=str, required=False)
     parser.add_argument("--compile", action="store_true")
     parser.add_argument("-n", "--n_limit", type=int, required=False)
+    parser.add_argument("--fast_accum_filter", type=bool, required=False)
+    parser.add_argument("--shape_name_filter", type=str, required=False)
+    parser.add_argument("--linear_type_filter", type=str, required=False)
     args = parser.parse_args()
-    output_path = Path(args.output_path)
-    main(output_path, args.compile, args.n_limit)
+    output_path = Path(args.output_path) if args.output_path is not None else None
+    main(
+        output_path,
+        args.compile,
+        args.n_limit,
+        args.fast_accum_filter,
+        args.shape_name_filter,
+        args.linear_type_filter,
+    )
 
 
 if __name__ == "__main__":
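The core change above replaces the hard-coded Float8Linear.from_float call with get_float8_linear driven by a LinearType enum, and only calls sync_float8_amax_and_scale_history for the delayed-scaling path. A minimal sketch of that conversion path in isolation, using only names that appear in the diff; the shape and dtype values mirror the "ffn.w2" entry, and actually running it assumes a CUDA device with fp8 support (e.g. H100):

# Sketch only (not part of the diff): converting a reference nn.Linear the way the
# updated benchmark does. All names below come from the diff; shapes mirror "ffn.w2".
import copy

import torch
from float8_experimental.float8_linear_utils import get_float8_linear, LinearType
from float8_experimental.float8_tensor import ScaledMMConfig

K, N = 3584, 8192  # "ffn.w2" entry from name_to_shapes_70b
linear_ref = torch.nn.Linear(K, N, bias=False).to(device="cuda", dtype=torch.bfloat16)

# LinearType.DYNAMIC skips the sync_float8_amax_and_scale_history() call that the
# delayed-scaling path needs before every forward/backward.
linear_float8 = get_float8_linear(
    LinearType.DYNAMIC, copy.deepcopy(linear_ref), emulate=False
)

# opt into fast accumulation for the forward mm, exactly as the benchmark does
linear_float8.forward_config = ScaledMMConfig(False, True, False)

The delayed path would additionally call sync_float8_amax_and_scale_history(linear_float8) before each step, which is exactly the branching captured by the two float8_forw_backward closures in the diff.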
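Since main() now defaults every argument, the sweep can also be driven directly from Python, or narrowed from the CLI with the new filter flags. A rough usage sketch follows; the import path and the /tmp output location are assumptions, and note that --fast_accum_filter is parsed with type=bool, so any non-empty string (including "False") is treated as True.

# Sketch only: driving the updated benchmark without the CLI. The module path below
# assumes the repo root is on sys.path; adjust to how the benchmark is actually invoked.
from pathlib import Path

from benchmarks.bench_linear_float8 import main

main(
    sweep_path=Path("/tmp/float8_sweep"),  # main() appends the ".csv" suffix
    compile=True,
    fast_accum_filter=True,
    shape_name_filter="ffn.w2",    # must be a key of name_to_shapes_70b
    linear_type_filter="dynamic",  # "delayed" or "dynamic"
)

# Roughly equivalent CLI run (fast_accum_filter left unset because of the bool parsing noted above):
#   python benchmarks/bench_linear_float8.py --compile -o /tmp/float8_sweep \
#       --shape_name_filter ffn.w2 --linear_type_filter dynamic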