From a6bad4200613d520e3bd654ac176eda3dcca5537 Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Thu, 18 Jan 2024 14:40:40 -0800 Subject: [PATCH] Add QEC to Benchmark Example (#1640) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/1640 Adding benchmarking of QEC alongside QEBC into benchmark_inference.py. This allows benchmarking the pooled and unpooled module. Reviewed By: gnahzg Differential Revision: D52874402 fbshipit-source-id: ef670e8ffe3588c52b527c0fec8edf478339016e --- .../benchmark/benchmark_inference.py | 169 ++++++++++++------ .../distributed/benchmark/benchmark_utils.py | 108 +++++++++-- 2 files changed, 208 insertions(+), 69 deletions(-) diff --git a/torchrec/distributed/benchmark/benchmark_inference.py b/torchrec/distributed/benchmark/benchmark_inference.py index 0e7c1622d..b52e4df13 100644 --- a/torchrec/distributed/benchmark/benchmark_inference.py +++ b/torchrec/distributed/benchmark/benchmark_inference.py @@ -11,7 +11,8 @@ import logging import os import time -from typing import List +from functools import partial +from typing import List, Tuple import torch @@ -22,26 +23,49 @@ get_tables, ) from torchrec.distributed.embedding_types import EmbeddingComputeKernel, ShardingType -from torchrec.distributed.test_utils.infer_utils import TestQuantEBCSharder +from torchrec.distributed.test_utils.infer_utils import ( + TestQuantEBCSharder, + TestQuantECSharder, +) from torchrec.quant.embedding_modules import ( EmbeddingBagCollection as QuantEmbeddingBagCollection, + EmbeddingCollection as QuantEmbeddingCollection, ) logger: logging.Logger = logging.getLogger() -def init_argparse() -> argparse.ArgumentParser: +def init_argparse_and_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--warmup_iters", type=int, default=20) - parser.add_argument("--bench_iters", type=int, default=2000) + parser.add_argument("--bench_iters", type=int, default=500) parser.add_argument("--prof_iters", type=int, default=20) parser.add_argument("--batch_size", type=int, default=2048) parser.add_argument("--world_size", type=int, default=2) parser.add_argument("--output_dir", type=str, default="/var/tmp/torchrec-bench") parser.add_argument("--num_benchmarks", type=int, default=9) - return parser + args = parser.parse_args() + return args + + +BENCH_SHARDING_TYPES: List[ShardingType] = [ + ShardingType.TABLE_WISE, + ShardingType.ROW_WISE, + ShardingType.COLUMN_WISE, +] + +BENCH_COMPILE_MODES: List[CompileMode] = [ + CompileMode.EAGER, + CompileMode.FX_SCRIPT, +] + +TABLE_SIZES: List[Tuple[int, int]] = [ + (40_000_000, 256), + (4_000_000, 256), + (1_000_000, 256), +] def write_report( @@ -73,54 +97,41 @@ def write_report( logger.info(f"Report written to {report_file}:\n{report_str}") -def benchmark_qebc() -> None: - parser = init_argparse() - args = parser.parse_args() - - datetime_sfx: str = time.strftime("%Y%m%dT%H%M%S") +def benchmark_qec(args: argparse.Namespace, output_dir: str) -> List[BenchmarkResult]: + tables = get_tables(TABLE_SIZES, is_pooled=False) + sharder = TestQuantECSharder( + sharding_type="", + kernel_type=EmbeddingComputeKernel.QUANT.value, + shardable_params=[table.name for table in tables], + ) - output_dir = args.output_dir - if not os.path.exists(output_dir): - # Create output directory if not exist - os.mkdir(output_dir) + module = QuantEmbeddingCollection( + # pyre-ignore [6] + tables=tables, + device=torch.device("cpu"), + quant_state_dict_split_scale_bias=True, + ) - output_dir += f"/run_{datetime_sfx}" - if not os.path.exists(output_dir): - # Place all outputs under the datetime folder - os.mkdir(output_dir) + args_kwargs = { + argname: getattr(args, argname) + for argname in dir(args) + # Don't include output_dir since output_dir was modified + if not argname.startswith("_") and argname != "output_dir" + } - BENCH_SHARDING_TYPES = [ - ShardingType.TABLE_WISE, - ShardingType.ROW_WISE, - ShardingType.COLUMN_WISE, - ] - - BENCH_COMPILE_MODES = [ - CompileMode.EAGER, - CompileMode.FX_SCRIPT, - ] - - table_sizes = [ - (40_000_000, 256), - (4_000_000, 256), - (1_000_000, 256), - ] - - tables_info = "\nTABLE SIZES QUANT:" - for i, (num, dim) in enumerate(table_sizes): - mb = int(float(num * dim) / 1024 / 1024) - tables_info += f"\nTABLE[{i}][{num:9}, {dim:4}] u8: {mb:6}Mb" - - report: str = f"REPORT BENCHMARK {datetime_sfx} world_size:{args.world_size} batch_size:{args.batch_size//1000}k\n" - report += "Module: QuantEmbeddingBagCollection\n" - report += tables_info - report += "\n" + return benchmark_module( + module=module, + sharder=sharder, + sharding_types=BENCH_SHARDING_TYPES, + compile_modes=BENCH_COMPILE_MODES, + tables=tables, + output_dir=output_dir, + **args_kwargs, + ) - num_requests = args.bench_iters * args.batch_size * args.num_benchmarks - report += f"num_requests:{num_requests:8}\n" - report_file: str = f"{output_dir}/run.report" - tables = get_tables(table_sizes) +def benchmark_qebc(args: argparse.Namespace, output_dir: str) -> List[BenchmarkResult]: + tables = get_tables(TABLE_SIZES) sharder = TestQuantEBCSharder( sharding_type="", kernel_type=EmbeddingComputeKernel.QUANT.value, @@ -128,6 +139,7 @@ def benchmark_qebc() -> None: ) module = QuantEmbeddingBagCollection( + # pyre-ignore [6] tables=tables, is_weighted=False, device=torch.device("cpu"), @@ -141,7 +153,7 @@ def benchmark_qebc() -> None: if not argname.startswith("_") and argname != "output_dir" } - benchmark_results = benchmark_module( + return benchmark_module( module=module, sharder=sharder, sharding_types=BENCH_SHARDING_TYPES, @@ -151,10 +163,65 @@ def benchmark_qebc() -> None: **args_kwargs, ) - write_report(benchmark_results, report_file, report, num_requests) + +def main() -> None: + args: argparse.Namespace = init_argparse_and_args() + + num_requests = args.bench_iters * args.batch_size * args.num_benchmarks + datetime_sfx: str = time.strftime("%Y%m%dT%H%M%S") + + output_dir = args.output_dir + if not os.path.exists(output_dir): + # Create output directory if not exist + os.mkdir(output_dir) + + benchmark_results_per_module = [] + write_report_funcs_per_module = [] + + for module_name in ["QuantEmbeddingBagCollection", "QuantEmbeddingCollection"]: + output_dir = args.output_dir + f"/run_{datetime_sfx}" + if module_name == "QuantEmbeddingBagCollection": + output_dir += "_qebc" + benchmark_func = benchmark_qebc + else: + output_dir += "_qec" + benchmark_func = benchmark_qec + + if not os.path.exists(output_dir): + # Place all outputs under the datetime folder + os.mkdir(output_dir) + + tables_info = "\nTABLE SIZES QUANT:" + for i, (num, dim) in enumerate(TABLE_SIZES): + mb = int(float(num * dim) / 1024 / 1024) + tables_info += f"\nTABLE[{i}][{num:9}, {dim:4}] u8: {mb:6}Mb" + + report: str = f"REPORT BENCHMARK {datetime_sfx} world_size:{args.world_size} batch_size:{args.batch_size}\n" + report += f"Module: {module_name}\n" + report += tables_info + report += "\n" + + num_requests = args.bench_iters * args.batch_size * args.num_benchmarks + report += f"num_requests:{num_requests:8}\n" + report_file: str = f"{output_dir}/run.report" + + # Save results to output them once benchmarking is all done + benchmark_results_per_module.append(benchmark_func(args, output_dir)) + write_report_funcs_per_module.append( + partial( + write_report, + report_file=report_file, + report_str=report, + num_requests=num_requests, + ) + ) + + for i, write_report_func in enumerate(write_report_funcs_per_module): + write_report_func(benchmark_results_per_module[i]) if __name__ == "__main__": logging.basicConfig() logging.getLogger().setLevel(logging.DEBUG) - benchmark_qebc() + + main() diff --git a/torchrec/distributed/benchmark/benchmark_utils.py b/torchrec/distributed/benchmark/benchmark_utils.py index c083cf3b5..e67c5d400 100644 --- a/torchrec/distributed/benchmark/benchmark_utils.py +++ b/torchrec/distributed/benchmark/benchmark_utils.py @@ -13,7 +13,7 @@ from dataclasses import dataclass from enum import Enum -from typing import List, Tuple, TypeVar +from typing import Dict, List, Tuple, TypeVar, Union import torch from torch.autograd.profiler import record_function @@ -30,8 +30,8 @@ from torchrec.distributed.types import DataType, ModuleSharder, ShardingEnv from torchrec.fx import symbolic_trace -from torchrec.modules.embedding_configs import EmbeddingBagConfig -from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor +from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor logger: logging.Logger = logging.getLogger() @@ -49,6 +49,63 @@ class BenchmarkResult: max_mem_allocated: List[int] +class ECWrapper(torch.nn.Module): + """ + Wrapper Module for benchmarking EC Modules + + Args: + module: module to benchmark + + Call Args: + input: KeyedJaggedTensor KJT input to module + + Returns: + output: KT output from module + + Example: + e1_config = EmbeddingConfig( + name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"] + ) + e2_config = EmbeddingConfig( + name="t2", embedding_dim=3, num_embeddings=10, feature_names=["f2"] + ) + + ec = EmbeddingCollection(tables=[e1_config, e2_config]) + + features = KeyedJaggedTensor( + keys=["f1", "f2"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), + ) + + ec.qconfig = torch.quantization.QConfig( + activation=torch.quantization.PlaceholderObserver.with_args( + dtype=torch.qint8 + ), + weight=torch.quantization.PlaceholderObserver.with_args(dtype=torch.qint8), + ) + + qec = QuantEmbeddingCollection.from_float(ecc) + + wrapped_module = ECWrapper(qec) + quantized_embeddings = wrapped_module(features) + """ + + def __init__(self, module: torch.nn.Module) -> None: + super().__init__() + self._module = module + + def forward(self, input: KeyedJaggedTensor) -> Dict[str, JaggedTensor]: + """ + Args: + input (KeyedJaggedTensor): KJT of form [F X B X L]. + + Returns: + Dict[str, JaggedTensor] + """ + return self._module.forward(input) + + class EBCWrapper(torch.nn.Module): """ Wrapper Module for benchmarking Modules @@ -109,24 +166,36 @@ def forward(self, input: KeyedJaggedTensor) -> KeyedTensor: def get_tables( - table_sizes: List[Tuple[int, int]], -) -> List[EmbeddingBagConfig]: - tables: List[EmbeddingBagConfig] = [ - EmbeddingBagConfig( - num_embeddings=num_embeddings, - embedding_dim=embedding_dim, - name="table_" + str(i), - feature_names=["feature_" + str(i)], - data_type=DataType.INT8, - ) - for i, (num_embeddings, embedding_dim) in enumerate(table_sizes) - ] + table_sizes: List[Tuple[int, int]], is_pooled: bool = True +) -> Union[List[EmbeddingBagConfig], List[EmbeddingConfig]]: + if is_pooled: + tables: List[EmbeddingBagConfig] = [ + EmbeddingBagConfig( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + data_type=DataType.INT8, + ) + for i, (num_embeddings, embedding_dim) in enumerate(table_sizes) + ] + else: + tables: List[EmbeddingConfig] = [ + EmbeddingConfig( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + data_type=DataType.INT8, + ) + for i, (num_embeddings, embedding_dim) in enumerate(table_sizes) + ] return tables def get_inputs( - tables: List[EmbeddingBagConfig], + tables: Union[List[EmbeddingBagConfig], List[EmbeddingConfig]], batch_size: int, world_size: int, num_inputs: int, @@ -366,7 +435,7 @@ def benchmark_module( sharder: ModuleSharder[T], sharding_types: List[ShardingType], compile_modes: List[CompileMode], - tables: List[EmbeddingBagConfig], + tables: Union[List[EmbeddingBagConfig], List[EmbeddingConfig]], warmup_iters: int = 20, bench_iters: int = 2000, prof_iters: int = 20, @@ -409,7 +478,10 @@ def benchmark_module( bench_inputs = inputs[warmup_iters : (warmup_iters + bench_iters)] prof_inputs = inputs[-prof_iters:] - wrapped_module = EBCWrapper(module) + if isinstance(tables[0], EmbeddingBagConfig): + wrapped_module = EBCWrapper(module) + else: + wrapped_module = ECWrapper(module) for sharding_type in sharding_types: for compile_mode in compile_modes: