
Commit ef12e27

finish python ref impl for bulk alloc
Signed-off-by: zhongboz <[email protected]>
1 parent c203f52 commit ef12e27

3 files changed: +366 -4 lines changed

benchmarks/linear/benchmark_grouped_linear.py (new file)

Lines changed: 241 additions & 0 deletions
@@ -0,0 +1,241 @@
import argparse
import torch
import torch.utils.benchmark as benchmark
import pandas as pd
import pathlib

from transformer_engine.pytorch.module import GroupedLinear
from transformer_engine.common.recipe import Float8BlockScaling
from transformer_engine.pytorch.fp8 import fp8_autocast
from contextlib import nullcontext
RECIPES = {
    "bf16": None,
    "fp8_sub_channel": Float8BlockScaling(),
}


def run_linear_multiple_steps(
    layer, x, m_splits, mode, gradient, run_num_steps=1, recipe=None
):
    assert mode in ["fwd_only", "fwd_bwd"]
    fp8_context = fp8_autocast(enabled=True, fp8_recipe=recipe) if recipe is not None else nullcontext()
    # print(f"fp8_context: {fp8_context} and is it nullcontext? {isinstance(fp8_context, nullcontext)}")

    if mode == "fwd_only":
        with torch.no_grad(), fp8_context:
            for i in range(run_num_steps):
                y_q = layer.forward(
                    x,
                    m_splits,
                    is_first_microbatch=(i == 0),
                )
        return y_q
    else:
        # reset gradients
        layer.zero_grad()
        x.grad = None

        with fp8_context:
            for i in range(run_num_steps):
                label = f"step_{i}"
                torch.cuda.nvtx.range_push(label)
                y_q = layer.forward(
                    x,
                    m_splits,
                    is_first_microbatch=(i == 0),
                )
                y_q.backward(gradient)
                torch.cuda.nvtx.range_pop()

        grads_q = []
        grads_q.append(x.grad)
        # remaining derivatives are with respect to model parameters
        for p in layer.parameters():
            if p.requires_grad:
                grads_q.append(p.grad)

        return y_q, grads_q


def benchmark_linear(
    x,
    ws,
    m_splits,
    bias,
    recipe_name,
    mode,
    num_gemms=4,
):
    params_dtype = torch.bfloat16
    recipe = RECIPES[recipe_name]

    in_features = x.shape[1]
    out_features = ws[0].shape[0]
    gradient = torch.ones(
        (x.shape[0], out_features), dtype=torch.bfloat16, device=x.device
    )

    layer = GroupedLinear(
        num_gemms,
        in_features,
        out_features,
        bias=bias is not None,
        params_dtype=params_dtype,
    )

    layer = layer.to("cuda")
    with torch.no_grad():
        for i in range(num_gemms):
            weight_i = getattr(layer, f"weight{i}")
            weight_i.copy_(ws[i])
            if bias is not None:
                bias_i = getattr(layer, f"bias{i}")
                bias_i.copy_(bias)

    num_microbatches = 32

    label = f"{recipe_name}_{'grouped'}"
    torch.cuda.nvtx.range_push(label)
    timing = benchmark.Timer(
        stmt="run_linear_multiple_steps(layer, x, m_splits, mode, gradient, num_microbatches, recipe)",
        globals={
            "run_linear_multiple_steps": run_linear_multiple_steps,
            "layer": layer,
            "x": x,
            "m_splits": m_splits,
            "mode": mode,
            "gradient": gradient,
            "num_microbatches": num_microbatches,
            "recipe": recipe,
        },
        num_threads=1,
    ).blocked_autorange(min_run_time=5)
    print(f"{recipe_name}: {timing} \n")
    timing_ms = timing.median * 1000 / num_microbatches

    return timing_ms


def run_benchmark_linear(
    mkns, recipe_name, use_bias, num_gemms=4
):
    data = []
    assert not use_bias, "Bias is not supported for GroupedLinear benchmark"

    print(f"========== Benchmarking {recipe_name} ==========")
    for m, k, n in mkns:
        device = "cuda"
        x = torch.randn((m, k), dtype=torch.bfloat16, device=device, requires_grad=True)
        ws = [
            torch.randn((n, k), dtype=torch.bfloat16, device=device)
            for _ in range(num_gemms)
        ]
        assert m % num_gemms == 0
        m_splits = [m // num_gemms] * num_gemms
        # Bias is not supported for GroupedLinear benchmark
        bias = None

        # Run the benchmark
        print(f"fwd_m={m}, fwd_k={k}, fwd_n={n}")

        grouped_fwd_bwd_timing_ms = benchmark_linear(
            x,
            ws,
            m_splits,
            bias,
            recipe_name,
            mode="fwd_bwd",
            num_gemms=num_gemms,
        )

        # Append the results
        data.append(
            [
                m,
                k,
                n,
                recipe_name,
                num_gemms,
                grouped_fwd_bwd_timing_ms,
            ]
        )

    df = pd.DataFrame(
        data=data,
        columns=[
            "m",
            "k",
            "n",
            "recipe",
            "num_gemms",
            "grouped_fwd_bwd_time_ms",
        ],
    )

    print(df, "\n")
    return df


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("--profile", action="store_true", help="Enable profiling mode")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="benchmark_output/",
        help="output path for report",
    )
    args = parser.parse_args()

    use_bias = False
    # Set the MKN values to benchmark
    mkns = []
    for m in [1024]:
        # for m in [4096, 8192, 16384]:
        # for n in [1024, 2048, 4096, 8192, 16384]:
        for n in [3072]:
            for k in [4096]:
                mkns.append((m, k, n))

    # recipe_list = [
    #     "bf16", "fp8_sub_channel",
    # ]
    recipe_list = [
        "fp8_sub_channel",
    ]

    # num_gemms_list = [16, 32]
    num_gemms_list = [4]

    if args.profile:
        # nsys profile --output=./benchmarks/linear/mkn_4096_4096_4096_numgemm_1_bf16 --trace=cuda,nvtx,cudnn,cublas python benchmarks/linear/benchmark_grouped_linear.py --profile
        # nsys profile --output=./benchmarks/linear/mkn_8192_8192_8192_numgemm_32_bf16 --trace=cuda,nvtx,cudnn,cublas python benchmarks/linear/benchmark_grouped_linear.py --profile
        # nsys profile --output=./benchmarks/linear/mkn_4096_4096_4096_numgemm_8_fp8_sub_channel --trace=cuda,nvtx,cudnn,cublas python benchmarks/linear/benchmark_grouped_linear.py --profile
        # nsys profile --output=./benchmarks/linear/mkn_8192_8192_8192_numgemm_2_fp8_sub_channel --trace=cuda,nvtx,cudnn,cublas python benchmarks/linear/benchmark_grouped_linear.py --profile
        mkns = [(4096, 4096, 4096)]
        recipe_list = ["fp8_sub_channel"]
        # recipe_list = ["bf16"]
        num_gemms_list = [8]
        torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__()

    # Initialize a dataframe to store the results
    df_linears = pd.DataFrame()

    # Run the fp8 benchmarks
    for num_gemms in num_gemms_list:
        print(f"========== Benchmarking with num_gemms={num_gemms} ==========")
        for recipe_name in recipe_list:
            df = run_benchmark_linear(
                mkns,
                recipe_name,
                use_bias,
                num_gemms=num_gemms,
            )
            df_linears = pd.concat([df_linears, df])

    print(df_linears)


    if args.profile:
        torch.autograd.profiler.emit_nvtx().__exit__(None, None, None)
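
As a usage sketch (not part of the commit), the helpers above can also be driven programmatically rather than through __main__. The snippet below assumes it runs in the same module as the script, on a CUDA device that supports the blockwise FP8 recipe, and mirrors the default shape in __main__ (m=1024, k=4096, n=3072, 4 GEMMs):

    # Hypothetical driver reusing run_benchmark_linear() from the script above.
    if torch.cuda.is_available():
        df_single = run_benchmark_linear(
            mkns=[(1024, 4096, 3072)],
            recipe_name="fp8_sub_channel",
            use_bias=False,
            num_gemms=4,
        )
        # Median per-microbatch fwd+bwd time in milliseconds for the grouped layer.
        print(df_single["grouped_fwd_bwd_time_ms"].iloc[0])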

transformer_engine/pytorch/module/grouped_linear.py

Lines changed: 25 additions & 4 deletions
@@ -50,6 +50,8 @@
     restore_from_saved,
 )

+from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer, bulk_alloc_float8_blockwise_tensor
+
 __all__ = ["GroupedLinear"]


@@ -91,7 +93,8 @@ def forward(
         # Make sure input dimensions are compatible
         in_features = weights[0].shape[-1]
         assert inp.shape[-1] == in_features, "GEMM not possible"
-        inputmats = torch.split(inp.view(-1, in_features), m_splits)
+        inp_view = inp.view(-1, in_features)
+        inputmats = torch.split(inp_view, m_splits)
         if fp8:
             assert_dim_for_fp8_exec(*inputmats, *weights)

@@ -125,9 +128,25 @@ def forward(
             recipe = FP8GlobalStateManager.get_fp8_recipe()
             if hasattr(recipe, "fp8_gemm_fprop"):
                 fprop_gemm_use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator
+            # TODO(zhongbo): make bulk alloc available for all quantizers
+            output_list = bulk_alloc_float8_blockwise_tensor(inp_view, m_splits, input_quantizers) if isinstance(input_quantizers[0], Float8BlockQuantizer) else None
             inputmats = tex.fused_multi_quantize(
-                inputmats_no_fp8, None, input_quantizers, TE_DType[activation_dtype]
+                inputmats_no_fp8, output_list, input_quantizers, TE_DType[activation_dtype]
             )
+            # inputmats_ref = tex.fused_multi_quantize(
+            #     inputmats_no_fp8, None, input_quantizers, TE_DType[activation_dtype]
+            # )
+            # [DEBUG]
+            # use torch testing to directly do zero tolerance test
+            # for i in range(len(inputmats)):
+            #     tensor = inputmats[i]
+            #     tensor_ref = inputmats_ref[i]
+            #     torch.testing.assert_close(tensor._rowwise_data, tensor_ref._rowwise_data)
+            #     torch.testing.assert_close(tensor._rowwise_scale_inv, tensor_ref._rowwise_scale_inv)
+            #     torch.testing.assert_close(tensor._columnwise_data, tensor_ref._columnwise_data)
+            #     torch.testing.assert_close(tensor._columnwise_scale_inv, tensor_ref._columnwise_scale_inv)
+
+            # raise Exception("Not implemented")
             weights_fp8 = []
             bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
             # FP8 cast to workspace buffer

@@ -250,8 +269,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
        # preprocess grad_output

        grad_output = grad_output.contiguous()
+       grad_output_view = grad_output.view(-1, grad_output.shape[-1])
        grad_output_mats = torch.split(
-           grad_output.view(-1, grad_output.shape[-1]), ctx.m_splits
+           grad_output_view, ctx.m_splits
        )
        grad_output = [None] * ctx.num_gemms
        grad_biases = [None] * ctx.num_gemms

@@ -269,9 +289,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                    grad_output_mats[i], ctx.grad_output_quantizers[i]
                )
            else:
+               output_list = bulk_alloc_float8_blockwise_tensor(grad_output_view, ctx.m_splits, ctx.grad_output_quantizers) if isinstance(ctx.grad_output_quantizers[0], Float8BlockQuantizer) else None
                grad_output = tex.fused_multi_quantize(
                    grad_output_mats,
-                   None,
+                   output_list,
                    ctx.grad_output_quantizers,
                    TE_DType[ctx.activation_dtype],
                )
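
For context, the forward and backward changes above follow the same pattern: when the first quantizer is a Float8BlockQuantizer, the quantized outputs for all splits are pre-allocated in one bulk call and passed to tex.fused_multi_quantize; otherwise None is passed and the extension allocates per-split outputs itself, as before. Below is a minimal sketch of that pattern as a free-standing helper; the name quantize_splits and the import aliases are illustrative assumptions, and bulk_alloc_float8_blockwise_tensor is the helper this commit adds in float8_blockwise_tensor.py, not part of released Transformer Engine.

    # Sketch of the bulk-alloc fast path, assuming the imports below resolve the
    # same way they do inside grouped_linear.py on this branch.
    from typing import List, Optional

    import torch
    import transformer_engine_torch as tex  # same extension alias used by grouped_linear.py
    from transformer_engine.pytorch.constants import TE_DType
    from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
        Float8BlockQuantizer,
        bulk_alloc_float8_blockwise_tensor,  # added by this commit
    )

    def quantize_splits(view: torch.Tensor, m_splits: List[int], quantizers, activation_dtype: torch.dtype):
        # Bulk-allocate quantized outputs only for the blockwise (sub-channel)
        # quantizer; other quantizers keep the old behavior, where
        # fused_multi_quantize allocates its own outputs when given None.
        output_list: Optional[list] = (
            bulk_alloc_float8_blockwise_tensor(view, m_splits, quantizers)
            if isinstance(quantizers[0], Float8BlockQuantizer)
            else None
        )
        splits = torch.split(view, m_splits)
        return tex.fused_multi_quantize(splits, output_list, quantizers, TE_DType[activation_dtype])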
