
Commit 379db5f

amirakb89 authored and facebook-github-bot committed
Profile with kineto and warmup for more accurate benchmarking (#3585)
Summary: Pull Request resolved: #3585

X-link: facebookresearch/FBGEMM#667

**Summary**
This PR introduces:
- A new warm-up method to ensure sufficient GPU preparation before benchmarking.
- Benchmark time calculation using the Kineto profiler to measure the time and bandwidth of the inference forward kernels.

**Motivation**
In small benchmark cases, kernel launch and synchronization overheads can be significant relative to the actual kernel runtime. By leveraging the Kineto profiler:
- These overheads are eliminated.
- Users get a more accurate estimate of the forward kernel's execution time and bandwidth.

For small kernels, iteration-based warm-up might not be sufficient. By leveraging time-based warm-up, users can be confident that the GPU has warmed up enough.

**Test instruction**
The command below shows how to use these features:

python bench/split_table_batched_embeddings_benchmark.py nbit-device-with-spec --export-trace --warmup-ms 50

Pull Request resolved: #3580

Reviewed By: leitian

Differential Revision: D68292871

Pulled By: q10

fbshipit-source-id: 0a90cddcf07780164e38ac1b945717d8456947c0
1 parent 21d1260 commit 379db5f

File tree: 2 files changed, +203 -27 lines


fbgemm_gpu/bench/bench_utils.py

Lines changed: 43 additions & 10 deletions
@@ -30,6 +30,28 @@
 logging.basicConfig(level=logging.DEBUG)


+def warmup(
+    request: TBERequest,
+    warmup_ms: int,
+    warmup_runs: int,
+    func: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor],
+    bwd_only: bool = False,
+    grad: Optional[torch.Tensor] = None,
+) -> None:
+    indices, offsets, weights = request.unpack_3()
+    if warmup_ms:
+        start_time_ms = time.time() * 1000
+        while time.time() * 1000 - start_time_ms < warmup_ms:
+            out = func(indices, offsets, weights)
+            if bwd_only:
+                out.backward(grad)
+    else:
+        for _ in range(warmup_runs):
+            out = func(indices, offsets, weights)
+            if bwd_only:
+                out.backward(grad)
+
+
 def benchmark_torch_function(  # noqa: C901
     # pyre-fixme[2]: Parameter must be annotated.
     f,
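The helper prioritizes the millisecond budget: when warmup_ms is set, it replays the first request until the budget is spent; otherwise it falls back to a fixed number of runs. Below is a minimal self-contained sketch of the same pattern, using a stand-in matmul workload rather than a real TBERequest (warmup_sketch and the tensor sizes are illustrative, not part of the diff):

    import time

    import torch

    def warmup_sketch(fn, warmup_ms=None, warmup_runs=2):
        # Time-based warmup takes priority over iteration-based warmup,
        # mirroring the warmup() helper above
        if warmup_ms:
            start_ms = time.time() * 1000
            while time.time() * 1000 - start_ms < warmup_ms:
                fn()
        else:
            for _ in range(warmup_runs):
                fn()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(1024, 1024, device=device)
    warmup_sketch(lambda: x @ x, warmup_ms=50)  # roughly 50 ms of warmup launches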
@@ -159,19 +181,30 @@ def benchmark_requests(
     # Can be used to clear model's stats after warmup for example.
     callback_after_warmup: Optional[Callable[[], None]] = None,
     periodic_logs: bool = False,
+    warmup_ms: Optional[int] = None,
 ) -> float:
     times = []
-
     # Run at least one warmup iteration to avoid the long cudaLaunchKernel time
-    # for the first kernel
-    num_warmups = num_warmups + 1 if num_warmups >= 0 else 1
-
-    if num_warmups > 0:
-        indices, offsets, weights = requests[0].unpack_3()
-        for _ in range(num_warmups):
-            out = func(indices, offsets, weights)
-            if bwd_only:
-                out.backward(grad)
+    # for the first kernel if warmup_ms > 0
+    # warmup_ms is prioritized over num_warmups
+
+    if warmup_ms is None:
+        num_warmups = num_warmups + 1 if num_warmups >= 0 else 1
+
+    # warm-up the GPU before profiling
+    warmup(
+        requests[0],
+        # pyre-ignore[6]
+        warmup_ms,
+        num_warmups,
+        lambda indices, offsets, per_sample_weights: func(
+            indices,
+            offsets,
+            per_sample_weights,
+        ),
+        bwd_only=bwd_only,
+        grad=grad,
+    )

     if callback_after_warmup is not None:
         callback_after_warmup()
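Callers opt into the time-based mode per invocation through the new keyword. A hedged sketch of a call site, mirroring how the benchmark script below passes it (here requests and emb stand for the TBERequest list and the embedding module, which are set up elsewhere):

    time_per_iter = benchmark_requests(
        requests,
        lambda indices, offsets, per_sample_weights: emb.forward(
            indices.int(),
            offsets.int(),
            per_sample_weights,
        ),
        check_median=True,
        warmup_ms=50,  # if None, the iteration-based num_warmups path is used
    )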

fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py

Lines changed: 160 additions & 17 deletions
@@ -77,6 +77,7 @@
         benchmark_torch_function,
         benchmark_vbe,
         fill_random_scale_bias,
+        warmup,
     )
 else:
     from fbgemm_gpu.bench.bench_utils import (
@@ -87,12 +88,28 @@
         benchmark_torch_function,
         benchmark_vbe,
         fill_random_scale_bias,
+        warmup,
     )


 logging.basicConfig(level=logging.DEBUG)


+def kineto_trace_profiler(p: profile, trace_info: tuple[str, str, str, str]) -> float:
+    phase, trace_url, tbe_type, kern_name = trace_info
+    p.export_chrome_trace(
+        trace_url.format(tbe_type=tbe_type, phase=phase, ospid=os.getpid())
+    )
+    kernel_time = 0
+    for event in p.key_averages():
+        # Sum the total time of forward kernel runs
+        if kern_name in event.key:
+            kernel_time += event.device_time
+    assert kernel_time > 0
+    print(f"Total CUDA time: {kernel_time:.2f} ")
+    return kernel_time
+
+
 @click.group()
 def cli() -> None:
     pass
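The helper leans on standard torch.profiler APIs: export_chrome_trace writes a Chrome-viewable trace file, and key_averages() aggregates profiler events by name so the device time of matching kernels can be summed. A self-contained sketch of the same technique on a generic workload (the "gemm" substring is an assumed kernel-name filter for illustration; a CUDA device is required):

    import torch
    from torch.profiler import ProfilerActivity, profile

    x = torch.randn(2048, 2048, device="cuda")

    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as p:
        for _ in range(10):
            y = x @ x
        torch.cuda.synchronize()  # make sure all kernels finish inside the profile

    # Open in chrome://tracing or the Perfetto UI
    p.export_chrome_trace("matmul_trace.json")

    # Sum device time (in microseconds) over aggregated events whose name
    # matches the kernel of interest, as kineto_trace_profiler does above
    kernel_time = sum(e.device_time for e in p.key_averages() if "gemm" in e.key)
    print(f"Total CUDA time: {kernel_time:.2f} us")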
@@ -323,7 +340,6 @@ def device(  # noqa C901
     logging.info(
         f"Accessed weights per batch: {B * sum(Ds) * L * param_size_multiplier / 1.0e9: .2f} GB"
     )
-
     requests = generate_requests(
         iters,
         B,
@@ -1135,6 +1151,7 @@ def nbit_cpu(  # noqa C901
 @click.option("--iters", default=100)
 @click.option("--runs-of-iters", default=5)
 @click.option("--warmup-runs", default=2)
+@click.option("--warmup-ms", type=int, default=None)
 @click.option("--output-dtype", type=SparseType, default=SparseType.FP16)
 @click.option("--report-aibench", is_flag=True)
 @click.option("--run-reference", is_flag=True, default=False)
@@ -1148,6 +1165,17 @@ def nbit_cpu(  # noqa C901
     type=str,
     default="{tbe_type}_tbe_{phase}_trace_{ospid}.json",
 )
+@click.option(
+    "--warmup-runs",
+    default=2,
+    help="Number of warmup runs. Ignored if --warmup-ms is set.",
+)
+@click.option(
+    "--warmup-ms",
+    type=int,
+    default=None,
+    help="Warmup duration in milliseconds. Disables the --run-nums option.",
+)
 def nbit_device(  # noqa C901
     alpha: float,
     bag_size: int,
@@ -1168,7 +1196,6 @@ def nbit_device(  # noqa C901
     check_median: bool,
     iters: int,
     runs_of_iters: int,
-    warmup_runs: int,
     output_dtype: SparseType,
     report_aibench: bool,
     run_reference: bool,
@@ -1178,6 +1205,8 @@ def nbit_device(  # noqa C901
     fp8_exponent_bias: Optional[int],
     export_trace: bool,
     trace_url: str,
+    warmup_runs: int,
+    warmup_ms: Optional[int],
 ) -> None:
     np.random.seed(42)
     torch.manual_seed(42)
@@ -1295,6 +1324,7 @@ def nbit_device(  # noqa C901
             per_sample_weights,
         ),
         check_median=check_median,
+        warmup_ms=warmup_ms,
     )

     # free up GPU memory
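With both flags registered on nbit_device, either warmup mode can be chosen from the command line, for example (same entry point as the test instruction in the commit message; the exact flag combination here is illustrative):

    python bench/split_table_batched_embeddings_benchmark.py nbit-device \
        --export-trace --warmup-ms 50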
@@ -1324,18 +1354,6 @@
         f"Memory Usage For Pruning: {mem_for_pruning / 1.0e9:.0f} GB"
     )

-    # Get trace for one run of iter
-    tbe_type: str = "split"
-
-    def _kineto_trace_handler(p: profile, phase: str) -> None:
-        p.export_chrome_trace(
-            trace_url.format(tbe_type=tbe_type, phase=phase, ospid=os.getpid())
-        )
-
-    # pyre-ignore[3]
-    def context_factory(on_trace_ready: Callable[[profile], None]):
-        return profile(on_trace_ready=on_trace_ready) if export_trace else nullcontext()
-
     requests = generate_requests(
         iters,
         B,
@@ -1353,7 +1371,35 @@ def context_factory(on_trace_ready: Callable[[profile], None]):
         for req in requests
     ]

-    with context_factory(lambda p: _kineto_trace_handler(p, "fwd")):
+    # pyre-ignore[3]
+    def context_factory(on_trace_ready: Callable[[profile], None]):
+        return profile(on_trace_ready=on_trace_ready) if export_trace else nullcontext()
+
+    # Get trace for one run of iter
+    tbe_type: str = "split"
+    # input of the kineto_trace_profiler
+    trace_info = ("fwd", trace_url, tbe_type, "embedding_codegen_forward")
+    time_dict = {"kernel_time": None}  # dict to hold the kernel time
+
+    # warm-up right before profiling
+    # warmup_ms prioritized over warmup_runs
+    if warmup_ms or warmup_runs:
+        warmup(
+            requests[0],
+            # pyre-ignore[6]
+            warmup_ms,
+            warmup_runs,
+            lambda indices, offsets, per_sample_weights: emb.forward(
+                indices.int(),
+                offsets.int(),
+                per_sample_weights,
+            ),
+        )
+
+    with context_factory(
+        # pyre-ignore[6]
+        lambda p: time_dict.update(kernel_time=kineto_trace_profiler(p, trace_info))
+    ):
         # forward
         time_per_iter = benchmark_requests(
             requests,
@@ -1364,6 +1410,21 @@ def context_factory(on_trace_ready: Callable[[profile], None]):
             ),
             check_median=check_median,
         )
+
+    if export_trace:
+        kernel_time = time_dict["kernel_time"]
+        # pyre-ignore[58]
+        bandwidth = read_write_bytes / kernel_time / 1.0e3
+
+        logging.info(
+            f"kineto profiled stats: "
+            f"{weights_precision} Forward, B: {B}, "
+            f"E: {E}, T: {T}, D: {D}, L: {L}, W: {weighted}, "
+            f"BW: {bandwidth: .2f} GB/s, "  # noqa: B950
+            f"Time: {kernel_time:.0f}us, "
+            f"Memory Usage For Pruning: {mem_for_pruning / 1.0e9:.0f} GB"
+        )
+
     # free up GPU memory
     del requests
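The division by 1.0e3 is a unit conversion: Kineto reports kernel_time in microseconds, and bytes per microsecond equals 10^6 bytes per second, so one more factor of 10^3 lands on GB/s. A quick sanity check with made-up numbers:

    read_write_bytes = 4.0e9  # hypothetical: 4 GB read + written per benchmark run
    kernel_time = 2.0e6       # hypothetical: 2 s of summed kernel time, in us

    # (bytes / us) = 1e6 bytes/s, so dividing by a further 1e3 gives GB/s
    bandwidth = read_write_bytes / kernel_time / 1.0e3
    print(f"BW: {bandwidth:.2f} GB/s")  # -> BW: 2.00 GB/s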

@@ -1465,12 +1526,28 @@ def context_factory(on_trace_ready: Callable[[profile], None]):
 @click.option("--check-median", is_flag=True, default=True)
 @click.option("--iters", default=100)
 @click.option("--runs-of-iters", default=5)
-@click.option("--warmup-runs", default=2)
 @click.option("--output-dtype", type=SparseType, default=SparseType.FP16)
 @click.option("--report-aibench", is_flag=True)
 @click.option("--fp8-exponent-bits", type=int, default=None)
 @click.option("--fp8-exponent-bias", type=int, default=None)
 @click.option("--use-cpu", is_flag=True, default=False)
+@click.option("--export-trace", is_flag=True, default=False)
+@click.option(
+    "--trace-url",
+    type=str,
+    default="{tbe_type}_tbe_spec_{phase}_trace_{ospid}.json",
+)
+@click.option(
+    "--warmup-runs",
+    default=2,
+    help="Number of warmup runs. Ignored if --warmup-ms is set.",
+)
+@click.option(
+    "--warmup-ms",
+    type=int,
+    default=None,
+    help="Warmup duration in milliseconds. Disables the --run-nums option.",
+)
 def nbit_device_with_spec(  # noqa C901
     alpha: float,
     bag_size_list: str,
@@ -1490,12 +1567,15 @@ def nbit_device_with_spec(  # noqa C901
     check_median: bool,
     iters: int,
     runs_of_iters: int,
-    warmup_runs: int,
     output_dtype: SparseType,
     report_aibench: bool,
     fp8_exponent_bits: Optional[int],
     fp8_exponent_bias: Optional[int],
     use_cpu: bool,
+    export_trace: bool,
+    trace_url: str,
+    warmup_runs: int,
+    warmup_ms: Optional[int],
 ) -> None:
     np.random.seed(42)
     torch.manual_seed(42)
@@ -1607,6 +1687,7 @@ def nbit_device_with_spec(  # noqa C901
     )

     times = []
+    kineto_request = []
     for i in range(runs_of_iters):
         # Generate a request for each table then combine
         all_requests = {
@@ -1683,8 +1764,13 @@ def nbit_device_with_spec(  # noqa C901
                 per_sample_weights,
             ),
             check_median=check_median,
+            warmup_ms=warmup_ms,
         )

+        # copy the request of last iteration for kineto profile benchmark
+        if i == runs_of_iters - 1:
+            kineto_request = requests
+
         # free up memory
         del requests
@@ -1712,6 +1798,63 @@ def nbit_device_with_spec(  # noqa C901
         f"Memory Usage For Pruning: {mem_for_pruning / 1.0e9:.0f} GB"
     )

+    # pyre-ignore[3]
+    def context_factory(on_trace_ready: Callable[[profile], None]):
+        return profile(on_trace_ready=on_trace_ready) if export_trace else nullcontext()
+
+    if not use_cpu:
+        # profile with kineto
+        tbe_type: str = "split"
+        time_dict = {"kernel_time": None}  # Shared variable to hold the kernel time
+        trace_info = ("fwd", trace_url, tbe_type, "embedding_codegen_forward")
+
+        # warm-up right before profiling
+        # warmup_ms prioritized over warmup_runs
+        if warmup_ms or warmup_runs:
+            warmup(
+                kineto_request[0],
+                # pyre-ignore[6]
+                warmup_ms,
+                warmup_runs,
+                lambda indices, offsets, per_sample_weights: emb.forward(
+                    indices.int(),
+                    offsets.int(),
+                    per_sample_weights,
+                ),
+            )
+
+        with context_factory(
+            # pyre-ignore[6]
+            lambda p: time_dict.update(kernel_time=kineto_trace_profiler(p, trace_info))
+        ):
+            # forward
+            time_per_iter = benchmark_requests(
+                kineto_request,
+                lambda indices, offsets, per_sample_weights: emb.forward(
+                    indices.int(),
+                    offsets.int(),
+                    per_sample_weights,
+                ),
+                check_median=check_median,
+            )
+
+        if export_trace:
+            kernel_time = time_dict["kernel_time"]
+            # pyre-ignore[6]
+            bandwidth = read_write_bytes / kernel_time / 1.0e3
+
+            logging.info(
+                f"kineto profiled stats: "
+                f"{weights_precision} Forward, B: {B}, "
+                f"E: {E}, T: {T}, D: {D}, L: {L}, W: {weighted}, "
+                f"BW: {bandwidth: .2f} GB/s, "  # noqa: B950
+                f"Time: {kernel_time:.0f}us, "
+                f"Memory Usage For Pruning: {mem_for_pruning / 1.0e9:.0f} GB"
+            )
+
+        # free up memory
+        del kineto_request
+
 if report_aibench and haveAIBench:
     print(
         emitMetric(
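The exported trace lands in the file named by --trace-url (here {tbe_type}_tbe_spec_{phase}_trace_{ospid}.json, i.e. split_tbe_spec_fwd_trace_<pid>.json for the forward pass). Like any export_chrome_trace output, it can be inspected in chrome://tracing or the Perfetto UI.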
