@@ -136,10 +136,13 @@ class PipelineConfig:
136
136
emb_lookup_stream (str): The stream to use for embedding lookups.
137
137
Only used by certain pipeline types (e.g., "fused").
138
138
Default is "data_dist".
139
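+ apply_jit (bool): Whether to benchmark with JIT (Just-In-Time) compilation applied to the model. When True, the runner benchmarks the pipeline both with and without JIT for comparison.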
140
+ Default is False.
139
141
"""
140
142
141
143
pipeline : str = "base"
142
144
emb_lookup_stream : str = "data_dist"
145
+ apply_jit : bool = False
143
146
144
147
145
148
@dataclass
@@ -148,6 +151,7 @@ class ModelSelectionConfig:
148
151
149
152
# Common config for all model types
150
153
batch_size : int = 8192
154
+ batch_sizes : Optional [List [int ]] = None
151
155
num_float_features : int = 10
152
156
feature_pooling_avg : int = 10
153
157
use_offsets : bool = False
@@ -200,6 +204,7 @@ def main(
200
204
model_config = create_model_config (
201
205
model_name = model_selection .model_name ,
202
206
batch_size = model_selection .batch_size ,
207
+ batch_sizes = model_selection .batch_sizes ,
203
208
num_float_features = model_selection .num_float_features ,
204
209
feature_pooling_avg = model_selection .feature_pooling_avg ,
205
210
use_offsets = model_selection .use_offsets ,
@@ -266,6 +271,15 @@ def runner(
266
271
compute_device = ctx .device .type ,
267
272
)
268
273
274
+ batch_sizes = model_config .batch_sizes
275
+
276
+ if batch_sizes is None :
277
+ batch_sizes = [model_config .batch_size ] * run_option .num_batches
278
+ else :
279
+ assert (
280
+ len (batch_sizes ) == run_option .num_batches
281
+ ), "The length of batch_sizes must match the number of batches."
282
+
269
283
# Create a planner for sharding based on the specified type
270
284
planner = generate_planner (
271
285
planner_type = run_option .planner_type ,
@@ -274,16 +288,15 @@ def runner(
274
288
weighted_tables = weighted_tables ,
275
289
sharding_type = run_option .sharding_type ,
276
290
compute_kernel = run_option .compute_kernel ,
277
- num_batches = run_option .num_batches ,
278
- batch_size = model_config .batch_size ,
291
+ batch_sizes = batch_sizes ,
279
292
pooling_factors = run_option .pooling_factors ,
280
293
num_poolings = run_option .num_poolings ,
281
294
)
282
295
bench_inputs = generate_data (
283
296
tables = tables ,
284
297
weighted_tables = weighted_tables ,
285
298
model_config = model_config ,
286
- num_batches = run_option . num_batches ,
299
+ batch_sizes = batch_sizes ,
287
300
)
288
301
289
302
sharded_model , optimizer = generate_sharded_model_and_optimizer (
@@ -299,14 +312,6 @@ def runner(
299
312
},
300
313
planner = planner ,
301
314
)
302
- pipeline = generate_pipeline (
303
- pipeline_type = pipeline_config .pipeline ,
304
- emb_lookup_stream = pipeline_config .emb_lookup_stream ,
305
- model = sharded_model ,
306
- opt = optimizer ,
307
- device = ctx .device ,
308
- )
309
- pipeline .progress (iter (bench_inputs ))
310
315
311
316
def _func_to_benchmark (
312
317
bench_inputs : List [ModelInput ],
@@ -320,20 +325,47 @@ def _func_to_benchmark(
320
325
except StopIteration :
321
326
break
322
327
323
- result = benchmark_func (
324
- name = type (pipeline ).__name__ ,
325
- bench_inputs = bench_inputs , # pyre-ignore
326
- prof_inputs = bench_inputs , # pyre-ignore
327
- num_benchmarks = 5 ,
328
- num_profiles = 2 ,
329
- profile_dir = run_option .profile ,
330
- world_size = run_option .world_size ,
331
- func_to_benchmark = _func_to_benchmark ,
332
- benchmark_func_kwargs = {"model" : sharded_model , "pipeline" : pipeline },
333
- rank = rank ,
328
+ # Run comparison if apply_jit is True, otherwise run single benchmark
329
+ jit_configs = (
330
+ [(True , "WithJIT" ), (False , "WithoutJIT" )]
331
+ if pipeline_config .apply_jit
332
+ else [(False , "" )]
334
333
)
334
+ results = []
335
+
336
+ for apply_jit , jit_suffix in jit_configs :
337
+ pipeline = generate_pipeline (
338
+ pipeline_type = pipeline_config .pipeline ,
339
+ emb_lookup_stream = pipeline_config .emb_lookup_stream ,
340
+ model = sharded_model ,
341
+ opt = optimizer ,
342
+ device = ctx .device ,
343
+ apply_jit = apply_jit ,
344
+ )
345
+ pipeline .progress (iter (bench_inputs ))
346
+
347
+ name = (
348
+ f"{ type (pipeline ).__name__ } { jit_suffix } "
349
+ if jit_suffix
350
+ else type (pipeline ).__name__
351
+ )
352
+ result = benchmark_func (
353
+ name = name ,
354
+ bench_inputs = bench_inputs , # pyre-ignore
355
+ prof_inputs = bench_inputs , # pyre-ignore
356
+ num_benchmarks = 5 ,
357
+ num_profiles = 2 ,
358
+ profile_dir = run_option .profile ,
359
+ world_size = run_option .world_size ,
360
+ func_to_benchmark = _func_to_benchmark ,
361
+ benchmark_func_kwargs = {"model" : sharded_model , "pipeline" : pipeline },
362
+ rank = rank ,
363
+ )
364
+ results .append (result )
365
+
335
366
if rank == 0 :
336
- print (result )
367
+ for result in results :
368
+ print (result )
337
369
338
370
339
371
if __name__ == "__main__" :
0 commit comments