
Commit 8002a00

SSYernar authored and facebook-github-bot committed
Add JIT and Variable Batch Support to Benchmark (#3131)
Summary:
Pull Request resolved: #3131

This update introduces an option to apply Just-In-Time (JIT) compilation in the training pipeline configuration for performance comparison. It also adds support for variable batch sizes, including the generation of Variable Batch KeyedJaggedTensor (VB-KJT) inputs.

Differential Revision: D76928208
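The two new knobs live on the benchmark's config dataclasses, shown in the diffs below. A minimal sketch of how they might be populated; field names follow the diff, but constructing these configs by hand (rather than through the benchmark's CLI wiring) and the "sparse" registry key are assumptions:

# Sketch only: field names are taken from the diff below; manual construction
# and the "sparse" pipeline key are assumptions.
from torchrec.distributed.benchmark.benchmark_train_pipeline import (
    ModelSelectionConfig,
    PipelineConfig,
)

model_selection = ModelSelectionConfig(
    batch_size=8192,                 # fallback used when batch_sizes is None
    batch_sizes=[1024, 2048, 4096],  # per-batch sizes -> variable-batch (VB-KJT) inputs
)
pipeline_config = PipelineConfig(
    pipeline="sparse",  # assumed key in the pipeline registry (_pipeline_cls)
    apply_jit=True,     # benchmarks WithJIT and WithoutJIT variants back to back
)

Note that when batch_sizes is provided, its length must match the run options' num_batches (asserted in runner below).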
1 parent 4225394 · commit 8002a00

File tree: 2 files changed, +78 −37 lines

torchrec/distributed/benchmark/benchmark_pipeline_utils.py

Lines changed: 23 additions & 14 deletions
@@ -62,6 +62,7 @@ class BaseModelConfig(ABC):
 
     # Common parameters for all model types
     batch_size: int
+    batch_sizes: Optional[List[int]]
     num_float_features: int
     feature_pooling_avg: int
     use_offsets: bool
@@ -283,6 +284,7 @@ def generate_pipeline(
     model: nn.Module,
     opt: torch.optim.Optimizer,
     device: torch.device,
+    apply_jit: bool = False,
 ) -> Union[TrainPipelineBase, TrainPipelineSparseDist]:
     """
     Generate a training pipeline instance based on the configuration.
@@ -303,6 +305,8 @@
         model (nn.Module): The model to be trained.
         opt (torch.optim.Optimizer): The optimizer to use for training.
         device (torch.device): The device to run the training on.
+        apply_jit (bool): Whether to apply JIT (Just-In-Time) compilation to the model.
+            Default is False.
 
     Returns:
         Union[TrainPipelineBase, TrainPipelineSparseDist]: An instance of the
@@ -324,20 +328,28 @@ def generate_pipeline(
 
     if pipeline_type == "semi":
         return TrainPipelineSemiSync(
-            model=model, optimizer=opt, device=device, start_batch=0
+            model=model,
+            optimizer=opt,
+            device=device,
+            start_batch=0,
+            apply_jit=apply_jit,
         )
     elif pipeline_type == "fused":
         return TrainPipelineFusedSparseDist(
             model=model,
             optimizer=opt,
             device=device,
             emb_lookup_stream=emb_lookup_stream,
+            apply_jit=apply_jit,
         )
-    elif pipeline_type in _pipeline_cls:
-        Pipeline = _pipeline_cls[pipeline_type]
-        return Pipeline(model=model, optimizer=opt, device=device)
+    elif pipeline_type == "base":
+        assert apply_jit is False, "JIT is not supported for base pipeline"
+
+        return TrainPipelineBase(model=model, optimizer=opt, device=device)
     else:
-        raise RuntimeError(f"unknown pipeline option {pipeline_type}")
+        Pipeline = _pipeline_cls[pipeline_type]
+        # pyre-ignore[28]
+        return Pipeline(model=model, optimizer=opt, device=device, apply_jit=apply_jit)
 
 
 def generate_planner(
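For illustration, the new call surface of generate_pipeline under the dispatch above, as a hedged sketch: the "sparse" registry key is assumed to resolve through _pipeline_cls to a class that accepts apply_jit, and sharded_model/optimizer are stand-ins for the caller's objects.

import torch

# Hypothetical call; "sparse" and the surrounding variables are assumptions.
pipeline = generate_pipeline(
    pipeline_type="sparse",
    emb_lookup_stream="data_dist",  # only consumed by the "fused" pipeline type
    model=sharded_model,
    opt=optimizer,
    device=torch.device("cuda"),
    apply_jit=True,  # pipeline_type="base" would reject this via the assert above
)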
@@ -347,8 +359,7 @@ def generate_planner(
     weighted_tables: Optional[List[EmbeddingBagConfig]],
     sharding_type: ShardingType,
     compute_kernel: EmbeddingComputeKernel,
-    num_batches: int,
-    batch_size: int,
+    batch_sizes: List[int],
     pooling_factors: Optional[List[float]],
     num_poolings: Optional[List[float]],
 ) -> Union[EmbeddingShardingPlanner, HeteroEmbeddingShardingPlanner]:
@@ -362,8 +373,7 @@
         weighted_tables: List of weighted embedding tables
         sharding_type: Strategy for sharding embedding tables
         compute_kernel: Compute kernel to use for embedding tables
-        num_batches: Number of batches to process
-        batch_size: Size of each batch
+        batch_sizes: Sizes of each batch
         pooling_factors: Pooling factors for each feature of the table
         num_poolings: Number of poolings for each feature of the table
 
@@ -375,15 +385,14 @@
     """
     # Create parameter constraints for tables
     constraints = {}
+    num_batches = len(batch_sizes)
 
     if pooling_factors is None:
         pooling_factors = [POOLING_FACTOR] * num_batches
 
     if num_poolings is None:
         num_poolings = [NUM_POOLINGS] * num_batches
 
-    batch_sizes = [batch_size] * num_batches
-
     assert (
         len(pooling_factors) == num_batches and len(num_poolings) == num_batches
     ), "The length of pooling_factors and num_poolings must match the number of batches."
@@ -481,7 +490,7 @@ def generate_data(
     tables: List[EmbeddingBagConfig],
     weighted_tables: List[EmbeddingBagConfig],
     model_config: BaseModelConfig,
-    num_batches: int,
+    batch_sizes: List[int],
 ) -> List[ModelInput]:
     """
     Generate model input data for benchmarking.
@@ -499,7 +508,7 @@
 
     return [
         ModelInput.generate(
-            batch_size=model_config.batch_size,
+            batch_size=batch_size,
             tables=tables,
             weighted_tables=weighted_tables,
             num_float_features=model_config.num_float_features,
@@ -517,5 +526,5 @@ def generate_data(
             ),
             pin_memory=model_config.pin_memory,
         )
-        for _ in range(num_batches)
+        for batch_size in batch_sizes
     ]
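Since the comprehension is now driven by batch_sizes, each generated ModelInput carries its own batch dimension. A short usage sketch; tables, weighted_tables, and model_config are assumed to be built by the caller, as in the benchmark runner:

# One ModelInput per entry in batch_sizes; the differing sizes are what
# produce the variable-batch (VB-KJT) inputs downstream.
bench_inputs = generate_data(
    tables=tables,
    weighted_tables=weighted_tables,
    model_config=model_config,
    batch_sizes=[1024, 2048, 4096],
)
assert len(bench_inputs) == 3  # one input bundle per requested batch size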

torchrec/distributed/benchmark/benchmark_train_pipeline.py

Lines changed: 55 additions & 23 deletions
@@ -136,10 +136,13 @@ class PipelineConfig:
         emb_lookup_stream (str): The stream to use for embedding lookups.
             Only used by certain pipeline types (e.g., "fused").
             Default is "data_dist".
+        apply_jit (bool): Whether to apply JIT (Just-In-Time) compilation to the model.
+            Default is False.
     """
 
     pipeline: str = "base"
     emb_lookup_stream: str = "data_dist"
+    apply_jit: bool = False
 
 
 @dataclass
@@ -148,6 +151,7 @@ class ModelSelectionConfig:
 
     # Common config for all model types
     batch_size: int = 8192
+    batch_sizes: Optional[List[int]] = None
     num_float_features: int = 10
     feature_pooling_avg: int = 10
     use_offsets: bool = False
@@ -200,6 +204,7 @@ def main(
     model_config = create_model_config(
         model_name=model_selection.model_name,
         batch_size=model_selection.batch_size,
+        batch_sizes=model_selection.batch_sizes,
        num_float_features=model_selection.num_float_features,
         feature_pooling_avg=model_selection.feature_pooling_avg,
         use_offsets=model_selection.use_offsets,
@@ -266,6 +271,15 @@ def runner(
         compute_device=ctx.device.type,
     )
 
+    batch_sizes = model_config.batch_sizes
+
+    if batch_sizes is None:
+        batch_sizes = [model_config.batch_size] * run_option.num_batches
+    else:
+        assert (
+            len(batch_sizes) == run_option.num_batches
+        ), "The length of batch_sizes must match the number of batches."
+
     # Create a planner for sharding based on the specified type
     planner = generate_planner(
         planner_type=run_option.planner_type,
@@ -274,16 +288,15 @@
         weighted_tables=weighted_tables,
         sharding_type=run_option.sharding_type,
         compute_kernel=run_option.compute_kernel,
-        num_batches=run_option.num_batches,
-        batch_size=model_config.batch_size,
+        batch_sizes=batch_sizes,
         pooling_factors=run_option.pooling_factors,
         num_poolings=run_option.num_poolings,
     )
     bench_inputs = generate_data(
         tables=tables,
         weighted_tables=weighted_tables,
         model_config=model_config,
-        num_batches=run_option.num_batches,
+        batch_sizes=batch_sizes,
     )
 
     sharded_model, optimizer = generate_sharded_model_and_optimizer(
@@ -299,14 +312,6 @@
         },
         planner=planner,
     )
-    pipeline = generate_pipeline(
-        pipeline_type=pipeline_config.pipeline,
-        emb_lookup_stream=pipeline_config.emb_lookup_stream,
-        model=sharded_model,
-        opt=optimizer,
-        device=ctx.device,
-    )
-    pipeline.progress(iter(bench_inputs))
 
     def _func_to_benchmark(
         bench_inputs: List[ModelInput],
@@ -320,20 +325,47 @@ def _func_to_benchmark(
         except StopIteration:
             break
 
-    result = benchmark_func(
-        name=type(pipeline).__name__,
-        bench_inputs=bench_inputs,  # pyre-ignore
-        prof_inputs=bench_inputs,  # pyre-ignore
-        num_benchmarks=5,
-        num_profiles=2,
-        profile_dir=run_option.profile,
-        world_size=run_option.world_size,
-        func_to_benchmark=_func_to_benchmark,
-        benchmark_func_kwargs={"model": sharded_model, "pipeline": pipeline},
-        rank=rank,
+    # Run comparison if apply_jit is True, otherwise run single benchmark
+    jit_configs = (
+        [(True, "WithJIT"), (False, "WithoutJIT")]
+        if pipeline_config.apply_jit
+        else [(False, "")]
     )
+    results = []
+
+    for apply_jit, jit_suffix in jit_configs:
+        pipeline = generate_pipeline(
+            pipeline_type=pipeline_config.pipeline,
+            emb_lookup_stream=pipeline_config.emb_lookup_stream,
+            model=sharded_model,
+            opt=optimizer,
+            device=ctx.device,
+            apply_jit=apply_jit,
+        )
+        pipeline.progress(iter(bench_inputs))
+
+        name = (
+            f"{type(pipeline).__name__}{jit_suffix}"
+            if jit_suffix
+            else type(pipeline).__name__
+        )
+        result = benchmark_func(
+            name=name,
+            bench_inputs=bench_inputs,  # pyre-ignore
+            prof_inputs=bench_inputs,  # pyre-ignore
+            num_benchmarks=5,
+            num_profiles=2,
+            profile_dir=run_option.profile,
+            world_size=run_option.world_size,
+            func_to_benchmark=_func_to_benchmark,
+            benchmark_func_kwargs={"model": sharded_model, "pipeline": pipeline},
+            rank=rank,
+        )
+        results.append(result)
+
     if rank == 0:
-        print(result)
+        for result in results:
+            print(result)
 
 
 if __name__ == "__main__":
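The JIT comparison logic in runner distills to a small, self-contained pattern. In the sketch below the helper name is hypothetical, but the tuples mirror the diff exactly:

from typing import List, Tuple


def jit_variants(apply_jit: bool) -> List[Tuple[bool, str]]:
    # Hypothetical helper mirroring the loop above: when a JIT comparison is
    # requested, both variants run, with suffixes disambiguating their names.
    if apply_jit:
        return [(True, "WithJIT"), (False, "WithoutJIT")]
    return [(False, "")]


# With a sparse-dist pipeline this yields benchmark names like
# "TrainPipelineSparseDistWithJIT" vs. "TrainPipelineSparseDistWithoutJIT";
# with apply_jit=False, the plain pipeline class name is used.
print(jit_variants(True))  # [(True, 'WithJIT'), (False, 'WithoutJIT')]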
