diff --git a/torchrec/distributed/benchmark/benchmark_inference.py b/torchrec/distributed/benchmark/benchmark_inference.py index 9f28394a7..0e7c1622d 100644 --- a/torchrec/distributed/benchmark/benchmark_inference.py +++ b/torchrec/distributed/benchmark/benchmark_inference.py @@ -8,336 +8,28 @@ #!/usr/bin/env python3 import argparse -import gc import logging import os import time -from dataclasses import dataclass - -from enum import Enum -from typing import Dict, List, Tuple +from typing import List import torch -from torch.autograd.profiler import record_function -from torchrec.distributed.embedding_types import EmbeddingComputeKernel, ShardingType -from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology -from torchrec.distributed.planner.enumerators import EmbeddingEnumerator -from torchrec.distributed.planner.shard_estimators import ( - EmbeddingPerfEstimator, - EmbeddingStorageEstimator, +from torchrec.distributed.benchmark.benchmark_utils import ( + benchmark_module, + BenchmarkResult, + CompileMode, + get_tables, ) -from torchrec.distributed.shard import _shard_modules - +from torchrec.distributed.embedding_types import EmbeddingComputeKernel, ShardingType from torchrec.distributed.test_utils.infer_utils import TestQuantEBCSharder -from torchrec.distributed.test_utils.test_model import ModelInput -from torchrec.distributed.types import DataType, ShardingEnv -from torchrec.fx import symbolic_trace -from torchrec.modules.embedding_configs import EmbeddingBagConfig from torchrec.quant.embedding_modules import ( EmbeddingBagCollection as QuantEmbeddingBagCollection, ) -from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor logger: logging.Logger = logging.getLogger() -# TODO: Workaround for torchscript silent failure issue -# https://fb.workplace.com/groups/1405155842844877/permalink/7840375302656200/ -torch._C._jit_override_can_fuse_on_cpu(False) -torch._C._jit_override_can_fuse_on_gpu(False) -torch._C._jit_set_texpr_fuser_enabled(False) -torch._C._jit_set_nvfuser_enabled(False) - - -class TerminalStage(Enum): - FP = 0 - QUANT = 1 - FXJIT_QUANT = 2 - SHARDED_QUANT = 3 - FXJIT_SHARDED_QUANT = 4 - - -class EBCWrapper(torch.nn.Module): - "Wrapper Module for benchmarking TorchRec Inference Modules" - - def __init__(self, module: torch.nn.Module) -> None: - super().__init__() - self._module = module - - def forward(self, input: KeyedJaggedTensor) -> KeyedTensor: - return self._module.forward(input) - - -@dataclass -class BenchmarkResult: - "Class for holding results of benchmark runs" - short_name: str - total_duration_sec: float - std_sec: float - max_mem_allocated: List[int] - - -def print_device_memory_allocated_status(log: str) -> None: - for di in range(torch.cuda.device_count()): - device = torch.device(f"cuda:{di}") - logger.info( - f"cuda.memory_allocated[{log}][{device}]:{torch.cuda.memory.memory_allocated(device) // 1024 // 1024} Mb" - ) - - -UNSHARDED_TERMINAL_STAGES: List[TerminalStage] = [ - TerminalStage.QUANT, - TerminalStage.FXJIT_QUANT, -] - -SHARDED_TERMINAL_STAGES: List[TerminalStage] = [ - TerminalStage.SHARDED_QUANT, - TerminalStage.FXJIT_SHARDED_QUANT, -] - - -def _model_ebc( - tables: List[EmbeddingBagConfig], - quant_device: torch.device, - device: torch.device, - quant_state_dict_split_scale_bias: bool, - sharding_type: ShardingType, - world_size: int, - batch_size: int, - terminal_stage: TerminalStage, - inputs: List[KeyedJaggedTensor], -) -> torch.nn.Module: - logging.info(f" _model_ebc.BEGIN[{terminal_stage}]") - 
print_device_memory_allocated_status(f"_model_ebc.BEGIN {terminal_stage}") - - wrapped_module = EBCWrapper( - QuantEmbeddingBagCollection( - tables=tables, - is_weighted=False, - device=quant_device, - quant_state_dict_split_scale_bias=quant_state_dict_split_scale_bias, - ) - ) - - print_device_memory_allocated_status(f"_model_ebc.WRAPPED_MODULE {terminal_stage}") - - if terminal_stage == TerminalStage.QUANT: - return wrapped_module - - if terminal_stage == TerminalStage.FXJIT_QUANT: - wrapped_module(inputs[0]) - graph_module = symbolic_trace( - wrapped_module, leaf_modules=["IntNBitTableBatchedEmbeddingBagsCodegen"] - ) - scripted_module = torch.jit.script(graph_module) - return scripted_module - - print_device_memory_allocated_status(f"_model_ebc.BEFORE_SHARDING {terminal_stage}") - sharder = TestQuantEBCSharder( - sharding_type=sharding_type.value, - kernel_type=EmbeddingComputeKernel.QUANT.value, - shardable_params=[table.name for table in tables], - ) - - topology: Topology = Topology(world_size=world_size, compute_device="cuda") - planner = EmbeddingShardingPlanner( - topology=topology, - batch_size=batch_size, - enumerator=EmbeddingEnumerator( - topology=topology, - batch_size=batch_size, - estimator=[ - EmbeddingPerfEstimator(topology=topology, is_inference=True), - EmbeddingStorageEstimator(topology=topology), - ], - ), - ) - - # pyre-ignore [6] - plan = planner.plan(wrapped_module, [sharder]) - - sharded_module = _shard_modules( - module=wrapped_module, - # pyre-ignore [6] - sharders=[sharder], - device=device, - plan=plan, - env=ShardingEnv.from_local(world_size=topology.world_size, rank=0), - ) - print_device_memory_allocated_status(f"_model_ebc.AFTER_SHARDING {terminal_stage}") - - if terminal_stage == TerminalStage.SHARDED_QUANT: - return sharded_module - - sharded_module(inputs[0]) - sharded_traced_module = symbolic_trace( - sharded_module, leaf_modules=["IntNBitTableBatchedEmbeddingBagsCodegen"] - ) - - sharded_scripted_module = torch.jit.script(sharded_traced_module) - return sharded_scripted_module - - -def benchmark( - name: str, - model: torch.nn.Module, - warmup_inputs: List[KeyedJaggedTensor], - bench_inputs: List[KeyedJaggedTensor], - prof_inputs: List[KeyedJaggedTensor], - batch_size: int, - world_size: int, - output_dir: str, - num_benchmarks: int, -) -> BenchmarkResult: - model.training = False - max_mem_allocated: List[int] = [] - logger.info(f" BENCHMARK_MODEL[{name}]:\n{model}") - - for _input in warmup_inputs: - model(_input) - - # Reset memory for measurement - for di in range(torch.cuda.device_count()): - torch.cuda.reset_max_memory_allocated(torch.device(f"cuda:{di}")) - - # Measure time taken for batches in bench_inputs - start = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)] - end = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)] - - gc.disable() - - for i in range(num_benchmarks): - start[i].record() - for _input in bench_inputs: - model(_input) - end[i].record() - gc.collect() - - for di in range(torch.cuda.device_count()): - torch.cuda.synchronize(torch.device(f"cuda:{di}")) - - # TODO: First Benchmark Run for Eager Mode produces outlier - # Start counting after first as workaround for standard deviation - elapsed_time = torch.tensor( - [si.elapsed_time(ei) for si, ei in zip(start[1:], end[1:])] - ) - - total_duration_sec = elapsed_time.mean().item() * 1e-3 # time in seconds - std_sec = elapsed_time.std().item() * 1e-3 # time in seconds - gc.enable() - - for di in range(world_size): - b = 
torch.cuda.max_memory_allocated(torch.device(f"cuda:{di}")) - max_mem_allocated.append(b // 1024 // 1024) - - # pyre-ignore[2] - def trace_handler(prof) -> None: - total_average = prof.profiler.total_average() - logger.info(f" TOTAL_AVERAGE:\n{name}\n{total_average}") - dir_path: str = output_dir - trace_file: str = f"{dir_path}/trace-{name}.json" - stacks_cpu_file = f"{dir_path}/stacks-cpu-{name}.stacks" - stacks_cuda_file = f"{dir_path}/stacks-cuda-{name}.stacks" - logger.info(f" PROFILE[{name}].chrome_trace:{trace_file}") - - prof.export_chrome_trace(trace_file) - prof.export_stacks(stacks_cpu_file, "self_cpu_time_total") - prof.export_stacks(stacks_cuda_file, "self_cuda_time_total") - - # - git clone https://github.com/brendangregg/FlameGraph - # - cd FlameGraph - # - ./flamegraph.pl --title "CPU time" --countname "us." profiler.stacks > perf_viz.svg - - with torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - record_shapes=True, - profile_memory=True, - with_stack=True, - with_flops=True, - with_modules=True, - on_trace_ready=trace_handler, - ) as p: - for _input in prof_inputs: - with record_function("## forward ##"): - model(_input) - p.step() - for di in range(torch.cuda.device_count()): - torch.cuda.synchronize(torch.device(f"cuda:{di}")) - - return BenchmarkResult( - short_name=name, - total_duration_sec=total_duration_sec, - std_sec=std_sec, - max_mem_allocated=max_mem_allocated, - ) - - -def benchmark_type_name( - terminal_stage: TerminalStage, sharding_type: ShardingType -) -> str: - if terminal_stage in UNSHARDED_TERMINAL_STAGES: - name = "unsharded-qebc" - if terminal_stage == TerminalStage.FXJIT_QUANT: - name += "-fxjit-quant" - else: - if sharding_type == ShardingType.ROW_WISE: - name = "rw-sharded-qebc" - elif sharding_type == ShardingType.COLUMN_WISE: - name = "cw-sharded-qebc" - else: - name = "tw-sharded-qebc" - - if terminal_stage == TerminalStage.FXJIT_SHARDED_QUANT: - name += "-fxjit-quant" - - return name - - -def run_benchmark( - tables: List[EmbeddingBagConfig], - quant_device: torch.device, - device: torch.device, - sharding_type: ShardingType, - world_size: int, - batch_size: int, - terminal_stage: TerminalStage, - warmup_inputs: List[KeyedJaggedTensor], - bench_inputs: List[KeyedJaggedTensor], - prof_inputs: List[KeyedJaggedTensor], - output_dir: str, - num_benchmarks: int, -) -> BenchmarkResult: - module = _model_ebc( - terminal_stage=terminal_stage, - tables=tables, - quant_device=quant_device, - device=device, - quant_state_dict_split_scale_bias=True, - sharding_type=sharding_type, - world_size=world_size, - batch_size=batch_size, - inputs=bench_inputs, - ) - - name = benchmark_type_name(terminal_stage, sharding_type) - - return benchmark( - name, - module, - warmup_inputs, - bench_inputs, - prof_inputs, - batch_size, - world_size=world_size, - output_dir=output_dir, - num_benchmarks=num_benchmarks, - ) - - def init_argparse() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() @@ -352,38 +44,6 @@ def init_argparse() -> argparse.ArgumentParser: return parser -def get_tables_and_input( - table_sizes: List[Tuple[int, int]], - num_inputs: int, - batch_size: int, - world_size: int, -) -> Tuple[List[EmbeddingBagConfig], List[KeyedJaggedTensor]]: - tables: List[EmbeddingBagConfig] = [ - EmbeddingBagConfig( - num_embeddings=num_embeddings, - embedding_dim=embedding_dim, - name="table_" + str(i), - feature_names=["feature_" + str(i)], - data_type=DataType.INT8, - ) - for i, 
(num_embeddings, embedding_dim) in enumerate(table_sizes) - ] - - inputs: List[KeyedJaggedTensor] = [] - for _ in range(num_inputs): - model_input = ModelInput.generate( - batch_size=batch_size, - world_size=world_size, - num_float_features=0, - tables=tables, - weighted_tables=[], - long_indices=False, - )[1][0] - inputs.append(model_input.idlist_features.to(torch.device("cuda:0"))) - - return tables, inputs - - def write_report( benchmark_results: List[BenchmarkResult], report_file: str, @@ -392,7 +52,8 @@ def write_report( ) -> None: for benchmark_res in benchmark_results: - avg_dur_s, std_dur_s = benchmark_res.total_duration_sec, benchmark_res.std_sec + avg_dur_s = benchmark_res.elapsed_time.mean().item() * 1e-3 # time in seconds + std_dur_s = benchmark_res.elapsed_time.std().item() * 1e-3 # time in seconds qps = int(num_requests / avg_dur_s) @@ -412,24 +73,13 @@ def write_report( logger.info(f"Report written to {report_file}:\n{report_str}") -def main() -> None: +def benchmark_qebc() -> None: parser = init_argparse() args = parser.parse_args() datetime_sfx: str = time.strftime("%Y%m%dT%H%M%S") - warmup_iters: int = args.warmup_iters - bench_iters: int = args.bench_iters - prof_iters: int = args.prof_iters - num_inputs_to_gen: int = warmup_iters + bench_iters + prof_iters - - world_size: int = args.world_size - batch_size: int = args.batch_size - - output_dir: str = args.output_dir - - num_benchmarks = args.num_benchmarks - + output_dir = args.output_dir if not os.path.exists(output_dir): # Create output directory if not exist os.mkdir(output_dir) @@ -445,96 +95,66 @@ def main() -> None: ShardingType.COLUMN_WISE, ] + BENCH_COMPILE_MODES = [ + CompileMode.EAGER, + CompileMode.FX_SCRIPT, + ] + table_sizes = [ (40_000_000, 256), (4_000_000, 256), (1_000_000, 256), ] - tables, inputs = get_tables_and_input( - table_sizes, num_inputs_to_gen, batch_size, world_size - ) - print_device_memory_allocated_status("Memory After Generating Inputs") - - warmup_inputs = inputs[:warmup_iters] - bench_inputs = inputs[warmup_iters : (warmup_iters + bench_iters)] - prof_inputs = inputs[-prof_iters:] - tables_info = "\nTABLE SIZES QUANT:" for i, (num, dim) in enumerate(table_sizes): mb = int(float(num * dim) / 1024 / 1024) tables_info += f"\nTABLE[{i}][{num:9}, {dim:4}] u8: {mb:6}Mb" - report: str = f"REPORT BENCHMARK {datetime_sfx} world_size:{world_size} batch_size:{batch_size//1000}k\n" + report: str = f"REPORT BENCHMARK {datetime_sfx} world_size:{args.world_size} batch_size:{args.batch_size//1000}k\n" + report += "Module: QuantEmbeddingBagCollection\n" report += tables_info report += "\n" - benchmark_results: List[BenchmarkResult] = [] - - # enable_reference_cycle_detector() - - # TEST UNSHARDED - for terminal_stage in UNSHARDED_TERMINAL_STAGES: - benchmark_type = benchmark_type_name(terminal_stage, ShardingType.TABLE_WISE) - logging.info( - f"\n\n###### Running QEBC Benchmark Type: {benchmark_type} ######\n" - ) - - res = run_benchmark( - terminal_stage=terminal_stage, - tables=tables, - quant_device=torch.device("cuda:0"), - device=torch.device("cuda:0"), - sharding_type=ShardingType.TABLE_WISE, - world_size=world_size, - batch_size=batch_size, - warmup_inputs=warmup_inputs, - bench_inputs=bench_inputs, - prof_inputs=prof_inputs, - output_dir=output_dir, - num_benchmarks=num_benchmarks, - ) - - # Reference cycles present with torch.fx.GraphModule - gc.collect() - benchmark_results.append(res) - print_device_memory_allocated_status("Memory Post Benchmarking") - - # TEST SHARDED - for sharding_type in 
BENCH_SHARDING_TYPES: - for terminal_stage in SHARDED_TERMINAL_STAGES: - benchmark_type = benchmark_type_name(terminal_stage, sharding_type) - logging.info( - f"\n\n###### Running QEBC Benchmark Type: {benchmark_type} ######\n" - ) - res = run_benchmark( - terminal_stage=terminal_stage, - tables=tables, - quant_device=torch.device("cpu"), - device=torch.device("cuda:0"), - sharding_type=sharding_type, - world_size=world_size, - batch_size=batch_size, - warmup_inputs=warmup_inputs, - bench_inputs=bench_inputs, - prof_inputs=prof_inputs, - output_dir=output_dir, - num_benchmarks=num_benchmarks, - ) - - gc.collect() - benchmark_results.append(res) - print_device_memory_allocated_status("Memory Post Benchmarking") - - num_requests = len(bench_inputs) * batch_size * num_benchmarks + num_requests = args.bench_iters * args.batch_size * args.num_benchmarks report += f"num_requests:{num_requests:8}\n" - report_file: str = f"{output_dir}/run.report" + tables = get_tables(table_sizes) + sharder = TestQuantEBCSharder( + sharding_type="", + kernel_type=EmbeddingComputeKernel.QUANT.value, + shardable_params=[table.name for table in tables], + ) + + module = QuantEmbeddingBagCollection( + tables=tables, + is_weighted=False, + device=torch.device("cpu"), + quant_state_dict_split_scale_bias=True, + ) + + args_kwargs = { + argname: getattr(args, argname) + for argname in dir(args) + # Don't include output_dir since output_dir was modified + if not argname.startswith("_") and argname != "output_dir" + } + + benchmark_results = benchmark_module( + module=module, + sharder=sharder, + sharding_types=BENCH_SHARDING_TYPES, + compile_modes=BENCH_COMPILE_MODES, + tables=tables, + output_dir=output_dir, + **args_kwargs, + ) + write_report(benchmark_results, report_file, report, num_requests) if __name__ == "__main__": logging.basicConfig() logging.getLogger().setLevel(logging.DEBUG) - main() + benchmark_qebc() diff --git a/torchrec/distributed/benchmark/benchmark_utils.py b/torchrec/distributed/benchmark/benchmark_utils.py new file mode 100644 index 000000000..c083cf3b5 --- /dev/null +++ b/torchrec/distributed/benchmark/benchmark_utils.py @@ -0,0 +1,443 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+
+#!/usr/bin/env python3
+
+import copy
+import gc
+import logging
+from dataclasses import dataclass
+
+from enum import Enum
+from typing import List, Tuple, TypeVar
+
+import torch
+from torch.autograd.profiler import record_function
+from torchrec.distributed.embedding_types import ShardingType
+
+from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
+from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
+from torchrec.distributed.planner.shard_estimators import (
+    EmbeddingPerfEstimator,
+    EmbeddingStorageEstimator,
+)
+from torchrec.distributed.shard import _shard_modules
+from torchrec.distributed.test_utils.test_model import ModelInput
+
+from torchrec.distributed.types import DataType, ModuleSharder, ShardingEnv
+from torchrec.fx import symbolic_trace
+from torchrec.modules.embedding_configs import EmbeddingBagConfig
+from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
+
+logger: logging.Logger = logging.getLogger()
+
+
+class CompileMode(Enum):
+    EAGER = "eager"
+    FX_SCRIPT = "fx_script"
+
+
+@dataclass
+class BenchmarkResult:
+    "Class for holding results of benchmark runs"
+    short_name: str
+    elapsed_time: torch.Tensor
+    max_mem_allocated: List[int]
+
+
+class EBCWrapper(torch.nn.Module):
+    """
+    Wrapper Module for benchmarking Modules
+
+    Args:
+        module: module to benchmark
+
+    Call Args:
+        input: KeyedJaggedTensor KJT input to module
+
+    Returns:
+        output: KT output from module
+
+    Example:
+        table_0 = EmbeddingBagConfig(
+            name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
+        )
+        table_1 = EmbeddingBagConfig(
+            name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"]
+        )
+        ebc = EmbeddingBagCollection(tables=[table_0, table_1])
+
+        features = KeyedJaggedTensor(
+            keys=["f1", "f2"],
+            values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
+            offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
+        )
+
+        ebc.qconfig = torch.quantization.QConfig(
+            activation=torch.quantization.PlaceholderObserver.with_args(
+                dtype=torch.qint8
+            ),
+            weight=torch.quantization.PlaceholderObserver.with_args(dtype=torch.qint8),
+        )
+
+        qebc = QuantEmbeddingBagCollection.from_float(ebc)
+
+        wrapped_module = EBCWrapper(qebc)
+        quantized_embeddings = wrapped_module(features)
+    """
+
+    def __init__(self, module: torch.nn.Module) -> None:
+        super().__init__()
+        self._module = module
+
+    def forward(self, input: KeyedJaggedTensor) -> KeyedTensor:
+        """
+        Args:
+            input (KeyedJaggedTensor): KJT of form [F X B X L].
+ + Returns: + KeyedTensor + """ + return self._module.forward(input) + + +T = TypeVar("T", bound=torch.nn.Module) + + +def get_tables( + table_sizes: List[Tuple[int, int]], +) -> List[EmbeddingBagConfig]: + tables: List[EmbeddingBagConfig] = [ + EmbeddingBagConfig( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + data_type=DataType.INT8, + ) + for i, (num_embeddings, embedding_dim) in enumerate(table_sizes) + ] + + return tables + + +def get_inputs( + tables: List[EmbeddingBagConfig], + batch_size: int, + world_size: int, + num_inputs: int, +) -> List[KeyedJaggedTensor]: + inputs: List[KeyedJaggedTensor] = [] + for _ in range(num_inputs): + model_input = ModelInput.generate( + batch_size=batch_size, + world_size=world_size, + num_float_features=0, + tables=tables, + weighted_tables=[], + long_indices=False, + )[1][0] + inputs.append(model_input.idlist_features.to(torch.device("cuda:0"))) + + return inputs + + +def transform_module( + module: torch.nn.Module, + device: torch.device, + inputs: List[KeyedJaggedTensor], + sharder: ModuleSharder[T], + sharding_type: ShardingType, + compile_mode: CompileMode, + world_size: int, + batch_size: int, +) -> torch.nn.Module: + def fx_script_module(eager_module: torch.nn.Module) -> torch.nn.Module: + eager_module(inputs[0]) + graph_module = symbolic_trace( + eager_module, leaf_modules=["IntNBitTableBatchedEmbeddingBagsCodegen"] + ) + scripted_module = torch.jit.script(graph_module) + return scripted_module + + topology: Topology = Topology(world_size=world_size, compute_device="cuda") + planner = EmbeddingShardingPlanner( + topology=topology, + batch_size=batch_size, + enumerator=EmbeddingEnumerator( + topology=topology, + batch_size=batch_size, + estimator=[ + EmbeddingPerfEstimator(topology=topology, is_inference=True), + EmbeddingStorageEstimator(topology=topology), + ], + ), + ) + + # Don't want to modify the module outright + # Since module is on cpu, won't cause cuda oom. 
+ copied_module = copy.deepcopy(module) + # pyre-ignore [6] + plan = planner.plan(copied_module, [sharder]) + + sharded_module = _shard_modules( + module=copied_module, + # pyre-ignore [6] + sharders=[sharder], + device=device, + plan=plan, + env=ShardingEnv.from_local(world_size=topology.world_size, rank=0), + ) + + if compile_mode == CompileMode.FX_SCRIPT: + return fx_script_module(sharded_module) + else: + return sharded_module + + +def benchmark( + name: str, + model: torch.nn.Module, + warmup_inputs: List[KeyedJaggedTensor], + bench_inputs: List[KeyedJaggedTensor], + prof_inputs: List[KeyedJaggedTensor], + world_size: int, + output_dir: str, + num_benchmarks: int, +) -> BenchmarkResult: + model.training = False + max_mem_allocated: List[int] = [] + logger.info(f" BENCHMARK_MODEL[{name}]:\n{model}") + + for _input in warmup_inputs: + model(_input) + + # Reset memory for measurement + for di in range(torch.cuda.device_count()): + torch.cuda.reset_max_memory_allocated(torch.device(f"cuda:{di}")) + + # Measure time taken for batches in bench_inputs + start = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)] + end = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)] + + for i in range(num_benchmarks): + start[i].record() + for _input in bench_inputs: + model(_input) + end[i].record() + + for di in range(torch.cuda.device_count()): + torch.cuda.synchronize(torch.device(f"cuda:{di}")) + + # TODO: First Benchmark Run for Eager Mode produces outlier + # Start counting after first as workaround for standard deviation + elapsed_time = torch.tensor( + [si.elapsed_time(ei) for si, ei in zip(start[1:], end[1:])] + ) + + for di in range(world_size): + b = torch.cuda.max_memory_allocated(torch.device(f"cuda:{di}")) + max_mem_allocated.append(b // 1024 // 1024) + + # pyre-ignore[2] + def trace_handler(prof) -> None: + total_average = prof.profiler.total_average() + logger.info(f" TOTAL_AVERAGE:\n{name}\n{total_average}") + dir_path: str = output_dir + + if dir_path == "": + return + + trace_file: str = f"{dir_path}/trace-{name}.json" + stacks_cpu_file = f"{dir_path}/stacks-cpu-{name}.stacks" + stacks_cuda_file = f"{dir_path}/stacks-cuda-{name}.stacks" + logger.info(f" PROFILE[{name}].chrome_trace:{trace_file}") + + prof.export_chrome_trace(trace_file) + prof.export_stacks(stacks_cpu_file, "self_cpu_time_total") + prof.export_stacks(stacks_cuda_file, "self_cuda_time_total") + + # - git clone https://github.com/brendangregg/FlameGraph + # - cd FlameGraph + # - ./flamegraph.pl --title "CPU time" --countname "us." 
profiler.stacks > perf_viz.svg
+
+    with torch.profiler.profile(
+        activities=[
+            torch.profiler.ProfilerActivity.CPU,
+            torch.profiler.ProfilerActivity.CUDA,
+        ],
+        record_shapes=True,
+        profile_memory=True,
+        with_stack=True,
+        with_flops=True,
+        with_modules=True,
+        on_trace_ready=trace_handler,
+    ) as p:
+        for _input in prof_inputs:
+            with record_function("## forward ##"):
+                model(_input)
+            p.step()
+        for di in range(torch.cuda.device_count()):
+            torch.cuda.synchronize(torch.device(f"cuda:{di}"))
+
+    return BenchmarkResult(
+        short_name=name,
+        elapsed_time=elapsed_time,
+        max_mem_allocated=max_mem_allocated,
+    )
+
+
+def benchmark_type_name(compile_mode: CompileMode, sharding_type: ShardingType) -> str:
+    if sharding_type == ShardingType.TABLE_WISE:
+        name = "tw-sharded"
+    elif sharding_type == ShardingType.ROW_WISE:
+        name = "rw-sharded"
+    elif sharding_type == ShardingType.COLUMN_WISE:
+        name = "cw-sharded"
+    else:
+        raise Exception(f"Unknown sharding type {sharding_type}")
+
+    if compile_mode == CompileMode.EAGER:
+        name += "-eager"
+    elif compile_mode == CompileMode.FX_SCRIPT:
+        name += "-fxjit"
+
+    return name
+
+
+def init_module_and_run_benchmark(
+    module: torch.nn.Module,
+    sharder: ModuleSharder[T],
+    device: torch.device,
+    sharding_type: ShardingType,
+    compile_mode: CompileMode,
+    world_size: int,
+    batch_size: int,
+    warmup_inputs: List[KeyedJaggedTensor],
+    bench_inputs: List[KeyedJaggedTensor],
+    prof_inputs: List[KeyedJaggedTensor],
+    output_dir: str,
+    num_benchmarks: int,
+) -> BenchmarkResult:
+    """
+    There are a couple of caveats here as to why the module has to be initialized
+    here:
+    1. Device. To accurately track memory usage, when sharding modules the initial
+       placement of the module should be on CPU. This is to avoid double counting
+       memory allocations and also to prevent CUDA OOMs.
+    2. Garbage Collector. Since torch.fx.GraphModule has circular references,
+       garbage collection is funky and can lead to OOMs.
Because this function is called from the loop over
+       compile modes and sharding types, returning only the benchmark result means
+       the reference to the transformed module is dropped when this function
+       returns, instead of lingering across loop iterations.
+    """
+
+    module = transform_module(
+        module=module,
+        device=device,
+        inputs=warmup_inputs,
+        sharder=sharder,
+        sharding_type=sharding_type,
+        compile_mode=compile_mode,
+        world_size=world_size,
+        batch_size=batch_size,
+    )
+
+    name = benchmark_type_name(compile_mode, sharding_type)
+
+    return benchmark(
+        name,
+        module,
+        warmup_inputs,
+        bench_inputs,
+        prof_inputs,
+        world_size=world_size,
+        output_dir=output_dir,
+        num_benchmarks=num_benchmarks,
+    )
+
+
+def benchmark_module(
+    module: torch.nn.Module,
+    sharder: ModuleSharder[T],
+    sharding_types: List[ShardingType],
+    compile_modes: List[CompileMode],
+    tables: List[EmbeddingBagConfig],
+    warmup_iters: int = 20,
+    bench_iters: int = 2000,
+    prof_iters: int = 20,
+    batch_size: int = 2048,
+    world_size: int = 2,
+    num_benchmarks: int = 9,
+    output_dir: str = "",
+) -> List[BenchmarkResult]:
+    """
+    Args:
+        module: Eager-mode module to be benchmarked
+        sharder: Module sharder used to shard the module
+        sharding_types: Sharding types to be benchmarked
+        compile_modes: Compilation modes to be benchmarked
+        tables: EmbeddingBagConfigs for the tables used to generate inputs
+        warmup_iters: Number of iterations to run before the timed benchmark
+        bench_iters: Number of iterations to run during the timed benchmark
+        prof_iters: Number of iterations to run under the profiler after benchmarking
+        batch_size: Batch size used in the model
+        world_size: Number of devices to shard the module across
+        num_benchmarks: How many times to run over benchmark inputs for statistics
+        output_dir: Directory to output profiler outputs (traces, stacks)
+
+    Returns:
+        A list of BenchmarkResults
+    """
+
+    # logging.info(f"###### Benchmarking Module: {module} ######\n")
+    logging.info(f"Warmup iterations: {warmup_iters}")
+    logging.info(f"Benchmark iterations: {bench_iters}")
+    logging.info(f"Profile iterations: {prof_iters}")
+    logging.info(f"Batch Size: {batch_size}")
+    logging.info(f"World Size: {world_size}")
+    logging.info(f"Number of Benchmarks: {num_benchmarks}")
+    logging.info(f"Output Directory: {output_dir}")
+
+    benchmark_results: List[BenchmarkResult] = []
+    num_inputs_to_gen: int = warmup_iters + bench_iters + prof_iters
+    inputs = get_inputs(tables, batch_size, world_size, num_inputs_to_gen)
+
+    warmup_inputs = inputs[:warmup_iters]
+    bench_inputs = inputs[warmup_iters : (warmup_iters + bench_iters)]
+    prof_inputs = inputs[-prof_iters:]
+
+    wrapped_module = EBCWrapper(module)
+
+    for sharding_type in sharding_types:
+        for compile_mode in compile_modes:
+            # Test sharders should have a singular sharding_type
+            # pyre-ignore [16]
+            sharder._sharding_type = sharding_type.value
+
+            benchmark_type = benchmark_type_name(compile_mode, sharding_type)
+            logging.info(
+                f"\n\n###### Running Benchmark Type: {benchmark_type} ######\n"
+            )
+            res = init_module_and_run_benchmark(
+                module=wrapped_module,
+                sharder=sharder,
+                # TODO: GPU hardcoded for now, expand if needed for heterogeneous hardware
+                device=torch.device("cuda:0"),
+                sharding_type=sharding_type,
+                compile_mode=compile_mode,
+                world_size=world_size,
+                batch_size=batch_size,
+                warmup_inputs=warmup_inputs,
+                bench_inputs=bench_inputs,
+                prof_inputs=prof_inputs,
+                num_benchmarks=num_benchmarks,
+                output_dir=output_dir,
+            )
+
+            gc.collect()
+            benchmark_results.append(res)
+
+    return benchmark_results
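
Not part of the diff: a minimal usage sketch of the refactored benchmark_utils API, mirroring benchmark_qebc above. It assumes a CUDA host with at least two GPUs (for world_size=2); the table sizes and iteration counts are illustrative, not the benchmark defaults.

# Sketch: drive benchmark_module directly with a quantized EBC (illustrative values).
import torch

from torchrec.distributed.benchmark.benchmark_utils import (
    benchmark_module,
    CompileMode,
    get_tables,
)
from torchrec.distributed.embedding_types import EmbeddingComputeKernel, ShardingType
from torchrec.distributed.test_utils.infer_utils import TestQuantEBCSharder
from torchrec.quant.embedding_modules import (
    EmbeddingBagCollection as QuantEmbeddingBagCollection,
)

# Small illustrative tables; the real benchmark uses much larger (num_embeddings, dim) pairs.
tables = get_tables([(1_000_000, 128), (100_000, 128)])

sharder = TestQuantEBCSharder(
    sharding_type="",  # overwritten per sharding type inside benchmark_module
    kernel_type=EmbeddingComputeKernel.QUANT.value,
    shardable_params=[table.name for table in tables],
)

module = QuantEmbeddingBagCollection(
    tables=tables,
    is_weighted=False,
    device=torch.device("cpu"),  # keep on CPU; benchmark_utils shards onto cuda:0
    quant_state_dict_split_scale_bias=True,
)

results = benchmark_module(
    module=module,
    sharder=sharder,
    sharding_types=[ShardingType.TABLE_WISE, ShardingType.ROW_WISE],
    compile_modes=[CompileMode.EAGER, CompileMode.FX_SCRIPT],
    tables=tables,
    warmup_iters=10,
    bench_iters=100,
    prof_iters=10,
    batch_size=512,
    world_size=2,
    num_benchmarks=5,
)

for res in results:
    # elapsed_time holds milliseconds per pass over the benchmark inputs.
    avg_ms = res.elapsed_time.mean().item()
    print(f"{res.short_name}: {avg_ms:.2f} ms/pass, peak mem {res.max_mem_allocated} MB")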