
Add JIT and Variable Batch Support to Benchmark #3131

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 1 commit into base: main
37 changes: 23 additions & 14 deletions torchrec/distributed/benchmark/benchmark_pipeline_utils.py
@@ -62,6 +62,7 @@ class BaseModelConfig(ABC):
 
     # Common parameters for all model types
     batch_size: int
+    batch_sizes: Optional[List[int]]
     num_float_features: int
     feature_pooling_avg: int
     use_offsets: bool
@@ -283,6 +284,7 @@ def generate_pipeline(
     model: nn.Module,
     opt: torch.optim.Optimizer,
     device: torch.device,
+    apply_jit: bool = False,
 ) -> Union[TrainPipelineBase, TrainPipelineSparseDist]:
     """
     Generate a training pipeline instance based on the configuration.
@@ -303,6 +305,8 @@ def generate_pipeline(
         model (nn.Module): The model to be trained.
         opt (torch.optim.Optimizer): The optimizer to use for training.
         device (torch.device): The device to run the training on.
+        apply_jit (bool): Whether to apply JIT (Just-In-Time) compilation to the model.
+            Default is False.
 
     Returns:
         Union[TrainPipelineBase, TrainPipelineSparseDist]: An instance of the
@@ -324,20 +328,28 @@ def generate_pipeline(
 
     if pipeline_type == "semi":
         return TrainPipelineSemiSync(
-            model=model, optimizer=opt, device=device, start_batch=0
+            model=model,
+            optimizer=opt,
+            device=device,
+            start_batch=0,
+            apply_jit=apply_jit,
         )
     elif pipeline_type == "fused":
         return TrainPipelineFusedSparseDist(
             model=model,
             optimizer=opt,
             device=device,
             emb_lookup_stream=emb_lookup_stream,
+            apply_jit=apply_jit,
         )
-    elif pipeline_type in _pipeline_cls:
-        Pipeline = _pipeline_cls[pipeline_type]
-        return Pipeline(model=model, optimizer=opt, device=device)
+    elif pipeline_type == "base":
+        assert apply_jit is False, "JIT is not supported for base pipeline"
+
+        return TrainPipelineBase(model=model, optimizer=opt, device=device)
     else:
-        raise RuntimeError(f"unknown pipeline option {pipeline_type}")
+        Pipeline = _pipeline_cls[pipeline_type]
+        # pyre-ignore[28]
+        return Pipeline(model=model, optimizer=opt, device=device, apply_jit=apply_jit)
 
 
 def generate_planner(
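
A minimal usage sketch of the extended `generate_pipeline` signature. Hedged: `sharded_model` and `optimizer` are assumed to come from `generate_sharded_model_and_optimizer`, and `"sparse"` is a stand-in for whatever keys `_pipeline_cls` actually holds.

```python
import torch

from torchrec.distributed.benchmark.benchmark_pipeline_utils import generate_pipeline

# Sketch, not part of the diff: exercising the new apply_jit flag.
# sharded_model and optimizer are assumed to exist; "sparse" is a placeholder key.
pipeline = generate_pipeline(
    pipeline_type="sparse",         # "base" asserts apply_jit is False
    emb_lookup_stream="data_dist",  # consumed only by the "fused" pipeline
    model=sharded_model,
    opt=optimizer,
    device=torch.device("cuda"),
    apply_jit=True,                 # forwarded to the pipeline constructor
)
```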
@@ -347,8 +359,7 @@ def generate_planner(
     weighted_tables: Optional[List[EmbeddingBagConfig]],
     sharding_type: ShardingType,
     compute_kernel: EmbeddingComputeKernel,
-    num_batches: int,
-    batch_size: int,
+    batch_sizes: List[int],
     pooling_factors: Optional[List[float]],
     num_poolings: Optional[List[float]],
 ) -> Union[EmbeddingShardingPlanner, HeteroEmbeddingShardingPlanner]:
@@ -362,8 +373,7 @@ def generate_planner(
         weighted_tables: List of weighted embedding tables
         sharding_type: Strategy for sharding embedding tables
         compute_kernel: Compute kernel to use for embedding tables
-        num_batches: Number of batches to process
-        batch_size: Size of each batch
+        batch_sizes: Sizes of each batch
         pooling_factors: Pooling factors for each feature of the table
         num_poolings: Number of poolings for each feature of the table
 
@@ -375,15 +385,14 @@ def generate_planner(
     """
     # Create parameter constraints for tables
     constraints = {}
+    num_batches = len(batch_sizes)
 
     if pooling_factors is None:
         pooling_factors = [POOLING_FACTOR] * num_batches
 
     if num_poolings is None:
         num_poolings = [NUM_POOLINGS] * num_batches
 
-    batch_sizes = [batch_size] * num_batches
-
     assert (
         len(pooling_factors) == num_batches and len(num_poolings) == num_batches
     ), "The length of pooling_factors and num_poolings must match the number of batches."
@@ -481,7 +490,7 @@ def generate_data(
     tables: List[EmbeddingBagConfig],
     weighted_tables: List[EmbeddingBagConfig],
     model_config: BaseModelConfig,
-    num_batches: int,
+    batch_sizes: List[int],
 ) -> List[ModelInput]:
     """
     Generate model input data for benchmarking.
@@ -499,7 +508,7 @@ def generate_data(
 
     return [
         ModelInput.generate(
-            batch_size=model_config.batch_size,
+            batch_size=batch_size,
             tables=tables,
             weighted_tables=weighted_tables,
             num_float_features=model_config.num_float_features,
@@ -517,5 +526,5 @@ def generate_data(
             ),
             pin_memory=model_config.pin_memory,
         )
-        for _ in range(num_batches)
+        for batch_size in batch_sizes
     ]
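
With the new signature, the number of generated batches follows the list length. A hedged sketch, assuming `tables`, `weighted_tables`, and `model_config` come from the surrounding benchmark setup:

```python
from torchrec.distributed.benchmark.benchmark_pipeline_utils import generate_data

# Sketch: one ModelInput is generated per entry, at that entry's batch size.
bench_inputs = generate_data(
    tables=tables,
    weighted_tables=weighted_tables,
    model_config=model_config,      # model_config.batch_size is no longer read here
    batch_sizes=[1024, 2048, 4096],
)
assert len(bench_inputs) == 3
```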
78 changes: 55 additions & 23 deletions torchrec/distributed/benchmark/benchmark_train_pipeline.py
@@ -136,10 +136,13 @@ class PipelineConfig:
         emb_lookup_stream (str): The stream to use for embedding lookups.
             Only used by certain pipeline types (e.g., "fused").
             Default is "data_dist".
+        apply_jit (bool): Whether to apply JIT (Just-In-Time) compilation to the model.
+            Default is False.
     """
 
     pipeline: str = "base"
     emb_lookup_stream: str = "data_dist"
+    apply_jit: bool = False
 
 
 @dataclass
@@ -148,6 +151,7 @@ class ModelSelectionConfig:
 
     # Common config for all model types
     batch_size: int = 8192
+    batch_sizes: Optional[List[int]] = None
     num_float_features: int = 10
     feature_pooling_avg: int = 10
     use_offsets: bool = False
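
Taken together, the two configs select uniform or variable batches and an optional JIT A/B run. A sketch assuming the remaining dataclass fields keep their defaults:

```python
from torchrec.distributed.benchmark.benchmark_train_pipeline import (
    ModelSelectionConfig,
    PipelineConfig,
)

# Uniform batches (legacy behavior): every batch uses batch_size.
fixed = ModelSelectionConfig(batch_size=8192)

# Variable batches: the list length must match run_option.num_batches.
variable = ModelSelectionConfig(batch_sizes=[4096, 8192, 16384])

# JIT comparison: benchmarks the pipeline twice, "WithJIT" and "WithoutJIT".
jit_ab = PipelineConfig(pipeline="sparse", apply_jit=True)
```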
@@ -200,6 +204,7 @@ def main(
     model_config = create_model_config(
         model_name=model_selection.model_name,
         batch_size=model_selection.batch_size,
+        batch_sizes=model_selection.batch_sizes,
         num_float_features=model_selection.num_float_features,
         feature_pooling_avg=model_selection.feature_pooling_avg,
         use_offsets=model_selection.use_offsets,
@@ -266,6 +271,15 @@ def runner(
         compute_device=ctx.device.type,
     )
 
+    batch_sizes = model_config.batch_sizes
+
+    if batch_sizes is None:
+        batch_sizes = [model_config.batch_size] * run_option.num_batches
+    else:
+        assert (
+            len(batch_sizes) == run_option.num_batches
+        ), "The length of batch_sizes must match the number of batches."
+
     # Create a planner for sharding based on the specified type
     planner = generate_planner(
         planner_type=run_option.planner_type,
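
The defaulting and validation above are self-contained. A runnable sketch with a hypothetical helper name (not part of the PR):

```python
from typing import List, Optional

# Hypothetical helper reproducing the defaulting/validation logic above.
def resolve_batch_sizes(
    batch_sizes: Optional[List[int]], batch_size: int, num_batches: int
) -> List[int]:
    if batch_sizes is None:
        # Legacy behavior: repeat the fixed batch_size for every batch.
        return [batch_size] * num_batches
    assert (
        len(batch_sizes) == num_batches
    ), "The length of batch_sizes must match the number of batches."
    return batch_sizes

print(resolve_batch_sizes(None, 8192, 3))       # [8192, 8192, 8192]
print(resolve_batch_sizes([1024, 2048], 0, 2))  # [1024, 2048]
```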
@@ -274,16 +288,15 @@ def runner(
         weighted_tables=weighted_tables,
         sharding_type=run_option.sharding_type,
         compute_kernel=run_option.compute_kernel,
-        num_batches=run_option.num_batches,
-        batch_size=model_config.batch_size,
+        batch_sizes=batch_sizes,
         pooling_factors=run_option.pooling_factors,
         num_poolings=run_option.num_poolings,
     )
     bench_inputs = generate_data(
         tables=tables,
         weighted_tables=weighted_tables,
         model_config=model_config,
-        num_batches=run_option.num_batches,
+        batch_sizes=batch_sizes,
     )
 
     sharded_model, optimizer = generate_sharded_model_and_optimizer(
@@ -299,14 +312,6 @@ def runner(
         },
         planner=planner,
     )
-    pipeline = generate_pipeline(
-        pipeline_type=pipeline_config.pipeline,
-        emb_lookup_stream=pipeline_config.emb_lookup_stream,
-        model=sharded_model,
-        opt=optimizer,
-        device=ctx.device,
-    )
-    pipeline.progress(iter(bench_inputs))
 
     def _func_to_benchmark(
         bench_inputs: List[ModelInput],
@@ -320,20 +325,47 @@ def _func_to_benchmark(
             except StopIteration:
                 break
 
-    result = benchmark_func(
-        name=type(pipeline).__name__,
-        bench_inputs=bench_inputs,  # pyre-ignore
-        prof_inputs=bench_inputs,  # pyre-ignore
-        num_benchmarks=5,
-        num_profiles=2,
-        profile_dir=run_option.profile,
-        world_size=run_option.world_size,
-        func_to_benchmark=_func_to_benchmark,
-        benchmark_func_kwargs={"model": sharded_model, "pipeline": pipeline},
-        rank=rank,
+    # Run comparison if apply_jit is True, otherwise run single benchmark
+    jit_configs = (
+        [(True, "WithJIT"), (False, "WithoutJIT")]
+        if pipeline_config.apply_jit
+        else [(False, "")]
     )
+    results = []
+
+    for apply_jit, jit_suffix in jit_configs:
+        pipeline = generate_pipeline(
+            pipeline_type=pipeline_config.pipeline,
+            emb_lookup_stream=pipeline_config.emb_lookup_stream,
+            model=sharded_model,
+            opt=optimizer,
+            device=ctx.device,
+            apply_jit=apply_jit,
+        )
+        pipeline.progress(iter(bench_inputs))
+
+        name = (
+            f"{type(pipeline).__name__}{jit_suffix}"
+            if jit_suffix
+            else type(pipeline).__name__
+        )
+        result = benchmark_func(
+            name=name,
+            bench_inputs=bench_inputs,  # pyre-ignore
+            prof_inputs=bench_inputs,  # pyre-ignore
+            num_benchmarks=5,
+            num_profiles=2,
+            profile_dir=run_option.profile,
+            world_size=run_option.world_size,
+            func_to_benchmark=_func_to_benchmark,
+            benchmark_func_kwargs={"model": sharded_model, "pipeline": pipeline},
+            rank=rank,
+        )
+        results.append(result)
+
     if rank == 0:
-        print(result)
+        for result in results:
+            print(result)
 
 
 if __name__ == "__main__":
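
The benchmark-name scheme the loop produces is easiest to see in isolation. A runnable sketch of the selection logic; the pipeline class name here is illustrative:

```python
from typing import List, Tuple

# Sketch of the A/B selection above: two runs when apply_jit is requested,
# one unsuffixed run otherwise.
def jit_configs(apply_jit: bool) -> List[Tuple[bool, str]]:
    return [(True, "WithJIT"), (False, "WithoutJIT")] if apply_jit else [(False, "")]

for flag, suffix in jit_configs(apply_jit=True):
    print(f"TrainPipelineSparseDist{suffix}")  # e.g. TrainPipelineSparseDistWithJIT
```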