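"""Benchmark forward/backward latency of Transformer Engine's GroupedLinear layer.

Measures per-microbatch time with torch.utils.benchmark for the recipes in RECIPES
(BF16 baseline or FP8 sub-channel block scaling), with NVTX ranges for optional
nsys profiling via --profile.
"""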
import argparse
import pathlib
from contextlib import nullcontext

import pandas as pd
import torch
import torch.utils.benchmark as benchmark

from transformer_engine.common.recipe import Float8BlockScaling
from transformer_engine.pytorch.fp8 import fp8_autocast
from transformer_engine.pytorch.module import GroupedLinear

RECIPES = {
    "bf16": None,
    "fp8_sub_channel": Float8BlockScaling(),
}


def run_linear_multiple_steps(
    layer, x, m_splits, mode, gradient, run_num_steps=1, recipe=None
):
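    """Run `run_num_steps` forward (and optionally backward) passes through `layer`.

    The steps are wrapped in fp8_autocast when an FP8 recipe is given. Returns the
    last output in "fwd_only" mode, or (output, [input grad, parameter grads]) in
    "fwd_bwd" mode.
    """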
    assert mode in ["fwd_only", "fwd_bwd"]
    fp8_context = (
        fp8_autocast(enabled=True, fp8_recipe=recipe) if recipe is not None else nullcontext()
    )

    if mode == "fwd_only":
        with torch.no_grad(), fp8_context:
            for i in range(run_num_steps):
                y_q = layer.forward(
                    x,
                    m_splits,
                    is_first_microbatch=(i == 0),
                )
        return y_q
    else:
        # Reset gradients left over from any previous run.
        layer.zero_grad()
        x.grad = None

        with fp8_context:
            for i in range(run_num_steps):
                label = f"step_{i}"
                torch.cuda.nvtx.range_push(label)
                y_q = layer.forward(
                    x,
                    m_splits,
                    is_first_microbatch=(i == 0),
                )
                y_q.backward(gradient)
                torch.cuda.nvtx.range_pop()

        # Gradient w.r.t. the input, followed by gradients w.r.t. the model parameters.
        grads_q = [x.grad]
        for p in layer.parameters():
            if p.requires_grad:
                grads_q.append(p.grad)

        return y_q, grads_q


def benchmark_linear(
    x,
    ws,
    m_splits,
    bias,
    recipe_name,
    mode,
    num_gemms=4,
):
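    """Time repeated GroupedLinear steps under the given recipe.

    Builds a GroupedLinear with `num_gemms` groups, copies in the provided weights
    (and optional bias), and returns the median time per microbatch in milliseconds.
    """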
    params_dtype = torch.bfloat16
    recipe = RECIPES[recipe_name]

    in_features = x.shape[1]
    out_features = ws[0].shape[0]
    gradient = torch.ones(
        (x.shape[0], out_features), dtype=torch.bfloat16, device=x.device
    )

    layer = GroupedLinear(
        num_gemms,
        in_features,
        out_features,
        bias=bias is not None,
        params_dtype=params_dtype,
    )

    layer = layer.to("cuda")
    with torch.no_grad():
        for i in range(num_gemms):
            weight_i = getattr(layer, f"weight{i}")
            weight_i.copy_(ws[i])
            if bias is not None:
                bias_i = getattr(layer, f"bias{i}")
                bias_i.copy_(bias)

    num_microbatches = 32

    label = f"{recipe_name}_grouped"
    torch.cuda.nvtx.range_push(label)
    timing = benchmark.Timer(
        stmt="run_linear_multiple_steps(layer, x, m_splits, mode, gradient, num_microbatches, recipe)",
        globals={
            "run_linear_multiple_steps": run_linear_multiple_steps,
            "layer": layer,
            "x": x,
            "m_splits": m_splits,
            "mode": mode,
            "gradient": gradient,
            "num_microbatches": num_microbatches,
            "recipe": recipe,
        },
        num_threads=1,
    ).blocked_autorange(min_run_time=5)
    torch.cuda.nvtx.range_pop()
    print(f"{recipe_name}: {timing}\n")
    # Convert the median per-call time to milliseconds per microbatch.
    timing_ms = timing.median * 1000 / num_microbatches

    return timing_ms


def run_benchmark_linear(
    mkns, recipe_name, use_bias, num_gemms=4
):
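    """Benchmark GroupedLinear for every (m, k, n) in `mkns` and return the results as a DataFrame."""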
    data = []
    assert not use_bias, "Bias is not supported for GroupedLinear benchmark"

    print(f"========== Benchmarking {recipe_name} ==========")
    for m, k, n in mkns:
        device = "cuda"
        x = torch.randn((m, k), dtype=torch.bfloat16, device=device, requires_grad=True)
        ws = [
            torch.randn((n, k), dtype=torch.bfloat16, device=device)
            for _ in range(num_gemms)
        ]
        assert m % num_gemms == 0
        m_splits = [m // num_gemms] * num_gemms
        # Bias is not supported for GroupedLinear benchmark
        bias = None

        # Run the benchmark
        print(f"fwd_m={m}, fwd_k={k}, fwd_n={n}")

        grouped_fwd_bwd_timing_ms = benchmark_linear(
            x,
            ws,
            m_splits,
            bias,
            recipe_name,
            mode="fwd_bwd",
            num_gemms=num_gemms,
        )

        # Append the results
        data.append(
            [
                m,
                k,
                n,
                recipe_name,
                num_gemms,
                grouped_fwd_bwd_timing_ms,
            ]
        )

    df = pd.DataFrame(
        data=data,
        columns=[
            "m",
            "k",
            "n",
            "recipe",
            "num_gemms",
            "grouped_fwd_bwd_time_ms",
        ],
    )

    print(df, "\n")
    return df


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("--profile", action="store_true", help="Enable profiling mode")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="benchmark_output/",
        help="output path for report",
    )
    args = parser.parse_args()

    use_bias = False
    # Set the MKN values to benchmark
    mkns = []
    for m in [1024]:
        # for m in [4096, 8192, 16384]:
        # for n in [1024, 2048, 4096, 8192, 16384]:
        for n in [3072]:
            for k in [4096]:
                mkns.append((m, k, n))

    # recipe_list = [
    #     "bf16", "fp8_sub_channel",
    # ]
    recipe_list = [
        "fp8_sub_channel",
    ]

    # num_gemms_list = [16, 32]
    num_gemms_list = [4]

    if args.profile:
        # Example nsys invocation:
        #   nsys profile --output=./benchmarks/linear/mkn_4096_4096_4096_numgemm_8_fp8_sub_channel \
        #       --trace=cuda,nvtx,cudnn,cublas \
        #       python benchmarks/linear/benchmark_grouped_linear.py --profile
        mkns = [(4096, 4096, 4096)]
        recipe_list = ["fp8_sub_channel"]
        # recipe_list = ["bf16"]
        num_gemms_list = [8]
        # Keep a reference to the NVTX context manager so it can be exited cleanly below.
        nvtx_profiler = torch.autograd.profiler.emit_nvtx(record_shapes=True)
        nvtx_profiler.__enter__()

    # Collect per-configuration results and aggregate them at the end.
    result_dfs = []

    # Run the benchmarks
    for num_gemms in num_gemms_list:
        print(f"========== Benchmarking with num_gemms={num_gemms} ==========")
        for recipe_name in recipe_list:
            df = run_benchmark_linear(
                mkns,
                recipe_name,
                use_bias,
                num_gemms=num_gemms,
            )
            result_dfs.append(df)

    df_linears = pd.concat(result_dfs, ignore_index=True)
    print(df_linears)
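
    # Note: --output_dir is parsed above but the original script never writes to it.
    # As a minimal sketch (the report filename is an assumption), persist the
    # aggregated results there as a CSV.
    output_dir = pathlib.Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    df_linears.to_csv(output_dir / "grouped_linear_benchmark.csv", index=False)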

    if args.profile:
        nvtx_profiler.__exit__(None, None, None)