        benchmark_torch_function,
        benchmark_vbe,
        fill_random_scale_bias,
+       warmup,
    )
else:
    from fbgemm_gpu.bench.bench_utils import (
@@ -87,12 +88,28 @@
        benchmark_torch_function,
        benchmark_vbe,
        fill_random_scale_bias,
+       warmup,
    )


logging.basicConfig(level=logging.DEBUG)


+def kineto_trace_profiler(p: profile, trace_info: tuple[str, str, str, str]) -> float:
+    phase, trace_url, tbe_type, kern_name = trace_info
+    p.export_chrome_trace(
+        trace_url.format(tbe_type=tbe_type, phase=phase, ospid=os.getpid())
+    )
+    kernel_time = 0
+    for event in p.key_averages():
+        # Sum the total time of forward kernel runs
+        if kern_name in event.key:
+            kernel_time += event.device_time
+    assert kernel_time > 0
+    print(f"Total CUDA time: {kernel_time:.2f}")
+    return kernel_time
+
+
@click.group()
def cli() -> None:
    pass
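The new `kineto_trace_profiler` extends the old trace handler: besides exporting a Chrome trace, it sums `device_time` (microseconds) over every profiler event whose key contains the target kernel name and returns the total, so callers can derive achieved bandwidth from it. Below is a minimal, self-contained sketch of the `torch.profiler` pattern it plugs into; the kernel filter `"aten::mm"` and the trace path are illustrative stand-ins, not values from this diff:

import torch
from torch.profiler import profile

def on_ready(p: profile) -> None:
    # Fired when the profiling context exits; same hook the benchmark uses.
    p.export_chrome_trace("/tmp/example_trace.json")
    # Aggregate device time over matching kernels, as kineto_trace_profiler
    # does for "embedding_codegen_forward".
    total_us = sum(e.device_time for e in p.key_averages() if "aten::mm" in e.key)
    print(f"matmul device time: {total_us:.2f} us")

with profile(on_trace_ready=on_ready):
    a = torch.randn(256, 256)
    (a @ a).sum()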
@@ -323,7 +340,6 @@ def device( # noqa C901
    logging.info(
        f"Accessed weights per batch: {B * sum(Ds) * L * param_size_multiplier / 1.0e9: .2f} GB"
    )
-
    requests = generate_requests(
        iters,
        B,
@@ -1135,6 +1151,7 @@ def nbit_cpu( # noqa C901
@click.option("--iters", default=100)
@click.option("--runs-of-iters", default=5)
@click.option("--warmup-runs", default=2)
+@click.option("--warmup-ms", type=int, default=None)
@click.option("--output-dtype", type=SparseType, default=SparseType.FP16)
@click.option("--report-aibench", is_flag=True)
@click.option("--run-reference", is_flag=True, default=False)
@@ -1148,6 +1165,17 @@ def nbit_cpu( # noqa C901
    type=str,
    default="{tbe_type}_tbe_{phase}_trace_{ospid}.json",
)
+@click.option(
+    "--warmup-runs",
+    default=2,
+    help="Number of warmup runs. Ignored if --warmup-ms is set.",
+)
+@click.option(
+    "--warmup-ms",
+    type=int,
+    default=None,
+    help="Warmup duration in milliseconds. Takes precedence over --warmup-runs.",
+)
def nbit_device(  # noqa C901
    alpha: float,
    bag_size: int,
@@ -1168,7 +1196,6 @@ def nbit_device( # noqa C901
    check_median: bool,
    iters: int,
    runs_of_iters: int,
-    warmup_runs: int,
    output_dtype: SparseType,
    report_aibench: bool,
    run_reference: bool,
@@ -1178,6 +1205,8 @@ def nbit_device( # noqa C901
    fp8_exponent_bias: Optional[int],
    export_trace: bool,
    trace_url: str,
+    warmup_runs: int,
+    warmup_ms: Optional[int],
) -> None:
    np.random.seed(42)
    torch.manual_seed(42)
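Moving `warmup_runs` to the end of the signature (and appending `warmup_ms`) is safe because click invokes the command callback with keyword arguments matched by parameter name, not by position. A hypothetical, minimal demonstration of that mapping:

import click

@click.command()
@click.option("--warmup-runs", default=2)
@click.option("--warmup-ms", type=int, default=None)
def cmd(warmup_ms, warmup_runs):  # parameter order differs from the decorators
    click.echo(f"runs={warmup_runs} ms={warmup_ms}")

if __name__ == "__main__":
    cmd()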
@@ -1295,6 +1324,7 @@ def nbit_device( # noqa C901
            per_sample_weights,
        ),
        check_median=check_median,
+        warmup_ms=warmup_ms,
    )

    # free up GPU memory
@@ -1324,18 +1354,6 @@ def nbit_device( # noqa C901
        f"Memory Usage For Pruning: {mem_for_pruning / 1.0e9:.0f} GB"
    )

-    # Get trace for one run of iter
-    tbe_type: str = "split"
-
-    def _kineto_trace_handler(p: profile, phase: str) -> None:
-        p.export_chrome_trace(
-            trace_url.format(tbe_type=tbe_type, phase=phase, ospid=os.getpid())
-        )
-
-    # pyre-ignore[3]
-    def context_factory(on_trace_ready: Callable[[profile], None]):
-        return profile(on_trace_ready=on_trace_ready) if export_trace else nullcontext()
-
    requests = generate_requests(
        iters,
        B,
@@ -1353,7 +1371,35 @@ def context_factory(on_trace_ready: Callable[[profile], None]):
        for req in requests
    ]

-    with context_factory(lambda p: _kineto_trace_handler(p, "fwd")):
+    # pyre-ignore[3]
+    def context_factory(on_trace_ready: Callable[[profile], None]):
+        return profile(on_trace_ready=on_trace_ready) if export_trace else nullcontext()
+
+    # Get a trace for one run of iters
+    tbe_type: str = "split"
+    # input to kineto_trace_profiler
+    trace_info = ("fwd", trace_url, tbe_type, "embedding_codegen_forward")
+    time_dict = {"kernel_time": None}  # dict to hold the kernel time
+
+    # warm up right before profiling; warmup_ms takes priority over warmup_runs
+    if warmup_ms or warmup_runs:
+        warmup(
+            requests[0],
+            # pyre-ignore[6]
+            warmup_ms,
+            warmup_runs,
+            lambda indices, offsets, per_sample_weights: emb.forward(
+                indices.int(),
+                offsets.int(),
+                per_sample_weights,
+            ),
+        )
+
+    with context_factory(
+        # pyre-ignore[6]
+        lambda p: time_dict.update(kernel_time=kineto_trace_profiler(p, trace_info))
+    ):
        # forward
        time_per_iter = benchmark_requests(
            requests,
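`warmup` is imported from `bench_utils` at the top of this diff; its implementation isn't shown here, but the call site pins down the contract: one request, a millisecond budget, a run count, and the forward callable, with `warmup_ms` taking priority over `warmup_runs`. (`time_dict` is a plain dict because the `on_trace_ready` lambda cannot rebind an enclosing local, but it can mutate one via `update`.) A hypothetical re-implementation of the helper, assuming each request is an `(indices, offsets, per_sample_weights)` tuple as the lambdas above suggest:

import time
from typing import Callable, Optional, Tuple

import torch

Request = Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]

def warmup_sketch(
    request: Request,
    warmup_ms: Optional[int],
    warmup_runs: int,
    func: Callable[..., torch.Tensor],
) -> None:
    indices, offsets, per_sample_weights = request
    if warmup_ms:
        # Time-budgeted warmup: loop until warmup_ms of wall time has elapsed.
        start = time.time()
        while (time.time() - start) * 1000.0 < warmup_ms:
            func(indices, offsets, per_sample_weights)
    else:
        # Count-based warmup: a fixed number of forward calls.
        for _ in range(warmup_runs):
            func(indices, offsets, per_sample_weights)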
@@ -1364,6 +1410,21 @@ def context_factory(on_trace_ready: Callable[[profile], None]):
            ),
            check_median=check_median,
        )
+
+    if export_trace:
+        kernel_time = time_dict["kernel_time"]
+        # pyre-ignore[58]
+        bandwidth = read_write_bytes / kernel_time / 1.0e3
+
+        logging.info(
+            f"kineto profiled stats: "
+            f"{weights_precision} Forward, B: {B}, "
+            f"E: {E}, T: {T}, D: {D}, L: {L}, W: {weighted}, "
+            f"BW: {bandwidth: .2f} GB/s, "  # noqa: B950
+            f"Time: {kernel_time:.0f}us, "
+            f"Memory Usage For Pruning: {mem_for_pruning / 1.0e9:.0f} GB"
+        )
+
    # free up GPU memory
    del requests
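The `/ 1.0e3` in the new bandwidth line is a unit conversion, not a fudge factor: `key_averages()` reports device time in microseconds, and 1 GB/s equals 10^3 bytes per microsecond. A quick sanity check with hypothetical numbers:

read_write_bytes = 64.0e9  # assume 64 GB moved per pass (hypothetical)
kernel_time_us = 32_000.0  # assume 32 ms of summed kernel time (hypothetical)
# bytes/us divided by 1e3 yields GB/s: 2e6 bytes/us -> 2000 GB/s
bandwidth_gbps = read_write_bytes / kernel_time_us / 1.0e3
assert bandwidth_gbps == 2000.0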
@@ -1465,12 +1526,28 @@ def context_factory(on_trace_ready: Callable[[profile], None]):
@click.option("--check-median", is_flag=True, default=True)
@click.option("--iters", default=100)
@click.option("--runs-of-iters", default=5)
-@click.option("--warmup-runs", default=2)
@click.option("--output-dtype", type=SparseType, default=SparseType.FP16)
@click.option("--report-aibench", is_flag=True)
@click.option("--fp8-exponent-bits", type=int, default=None)
@click.option("--fp8-exponent-bias", type=int, default=None)
@click.option("--use-cpu", is_flag=True, default=False)
+@click.option("--export-trace", is_flag=True, default=False)
+@click.option(
+    "--trace-url",
+    type=str,
+    default="{tbe_type}_tbe_spec_{phase}_trace_{ospid}.json",
+)
+@click.option(
+    "--warmup-runs",
+    default=2,
+    help="Number of warmup runs. Ignored if --warmup-ms is set.",
+)
+@click.option(
+    "--warmup-ms",
+    type=int,
+    default=None,
+    help="Warmup duration in milliseconds. Takes precedence over --warmup-runs.",
+)
def nbit_device_with_spec(  # noqa C901
    alpha: float,
    bag_size_list: str,
@@ -1490,12 +1567,15 @@ def nbit_device_with_spec( # noqa C901
    check_median: bool,
    iters: int,
    runs_of_iters: int,
-    warmup_runs: int,
    output_dtype: SparseType,
    report_aibench: bool,
    fp8_exponent_bits: Optional[int],
    fp8_exponent_bias: Optional[int],
    use_cpu: bool,
+    export_trace: bool,
+    trace_url: str,
+    warmup_runs: int,
+    warmup_ms: Optional[int],
) -> None:
    np.random.seed(42)
    torch.manual_seed(42)
@@ -1607,6 +1687,7 @@ def nbit_device_with_spec( # noqa C901
    )

    times = []
+    kineto_request = []
    for i in range(runs_of_iters):
        # Generate a request for each table then combine
@@ -1683,8 +1764,13 @@ def nbit_device_with_spec( # noqa C901
                per_sample_weights,
            ),
            check_median=check_median,
+            warmup_ms=warmup_ms,
        )

+        # keep the requests from the last iteration for the kineto profile benchmark
+        if i == runs_of_iters - 1:
+            kineto_request = requests
+
        # free up memory
        del requests
@@ -1712,6 +1798,63 @@ def nbit_device_with_spec( # noqa C901
        f"Memory Usage For Pruning: {mem_for_pruning / 1.0e9:.0f} GB"
    )

+    # pyre-ignore[3]
+    def context_factory(on_trace_ready: Callable[[profile], None]):
+        return profile(on_trace_ready=on_trace_ready) if export_trace else nullcontext()
+
+    if not use_cpu:
+        # profile with kineto
+        tbe_type: str = "split"
+        time_dict = {"kernel_time": None}  # shared variable to hold the kernel time
+        trace_info = ("fwd", trace_url, tbe_type, "embedding_codegen_forward")
+
+        # warm up right before profiling; warmup_ms takes priority over warmup_runs
+        if warmup_ms or warmup_runs:
+            warmup(
+                kineto_request[0],
+                # pyre-ignore[6]
+                warmup_ms,
+                warmup_runs,
+                lambda indices, offsets, per_sample_weights: emb.forward(
+                    indices.int(),
+                    offsets.int(),
+                    per_sample_weights,
+                ),
+            )
+
+        with context_factory(
+            # pyre-ignore[6]
+            lambda p: time_dict.update(kernel_time=kineto_trace_profiler(p, trace_info))
+        ):
+            # forward
+            time_per_iter = benchmark_requests(
+                kineto_request,
+                lambda indices, offsets, per_sample_weights: emb.forward(
+                    indices.int(),
+                    offsets.int(),
+                    per_sample_weights,
+                ),
+                check_median=check_median,
+            )
+
+        if export_trace:
+            kernel_time = time_dict["kernel_time"]
+            # pyre-ignore[6]
+            bandwidth = read_write_bytes / kernel_time / 1.0e3
+
+            logging.info(
+                f"kineto profiled stats: "
+                f"{weights_precision} Forward, B: {B}, "
+                f"E: {E}, T: {T}, D: {D}, L: {L}, W: {weighted}, "
+                f"BW: {bandwidth: .2f} GB/s, "  # noqa: B950
+                f"Time: {kernel_time:.0f}us, "
+                f"Memory Usage For Pruning: {mem_for_pruning / 1.0e9:.0f} GB"
+            )
+
+        # free up memory
+        del kineto_request
+
    if report_aibench and haveAIBench:
        print(
            emitMetric(
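Taken together, the new flags can be smoke-tested with click's test runner. This is a hedged sketch: the module name assumes the diff is against fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py, and click derives the nbit-device-with-spec command name from the function name:

from click.testing import CliRunner

# Assumed module path; adjust to wherever the benchmark script lives.
from split_table_batched_embeddings_benchmark import cli

result = CliRunner().invoke(
    cli,
    [
        "nbit-device-with-spec",
        "--export-trace",
        "--trace-url", "{tbe_type}_tbe_spec_{phase}_trace_{ospid}.json",
        "--warmup-ms", "500",  # takes precedence over --warmup-runs
    ],
)
print(result.output)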