Add QEC to Benchmark Example (#1640)
Summary:
Pull Request resolved: #1640

This adds benchmarking of QEC (QuantEmbeddingCollection) alongside QEBC (QuantEmbeddingBagCollection) in benchmark_inference.py, so that both the pooled (QEBC) and unpooled (QEC) modules can be benchmarked.
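
A rough usage sketch (hedged: the flag names and defaults come from the argparse setup in this diff; the direct Namespace-based calls mirror what main() does and are illustrative only):

import argparse

from torchrec.distributed.benchmark.benchmark_inference import (
    benchmark_qebc,
    benchmark_qec,
)

# Defaults as defined in init_argparse_and_args() in this diff.
args = argparse.Namespace(
    warmup_iters=20,
    bench_iters=500,
    prof_iters=20,
    batch_size=2048,
    world_size=2,
    output_dir="/var/tmp/torchrec-bench",
    num_benchmarks=9,
)

# main() creates the timestamped output directories before dispatching;
# when calling these directly, the directories must already exist.
qebc_results = benchmark_qebc(args, output_dir=f"{args.output_dir}/qebc")
qec_results = benchmark_qec(args, output_dir=f"{args.output_dir}/qec")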

Reviewed By: gnahzg

Differential Revision: D52874402

fbshipit-source-id: ef670e8ffe3588c52b527c0fec8edf478339016e
PaulZhang12 authored and facebook-github-bot committed Jan 18, 2024
1 parent 719f678 commit a6bad42
Showing 2 changed files with 208 additions and 69 deletions.
169 changes: 118 additions & 51 deletions torchrec/distributed/benchmark/benchmark_inference.py
@@ -11,7 +11,8 @@
import logging
import os
import time
from typing import List
from functools import partial
from typing import List, Tuple

import torch

@@ -22,26 +23,49 @@
get_tables,
)
from torchrec.distributed.embedding_types import EmbeddingComputeKernel, ShardingType
from torchrec.distributed.test_utils.infer_utils import TestQuantEBCSharder
from torchrec.distributed.test_utils.infer_utils import (
TestQuantEBCSharder,
TestQuantECSharder,
)
from torchrec.quant.embedding_modules import (
EmbeddingBagCollection as QuantEmbeddingBagCollection,
EmbeddingCollection as QuantEmbeddingCollection,
)

logger: logging.Logger = logging.getLogger()


def init_argparse() -> argparse.ArgumentParser:
def init_argparse_and_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()

parser.add_argument("--warmup_iters", type=int, default=20)
parser.add_argument("--bench_iters", type=int, default=2000)
parser.add_argument("--bench_iters", type=int, default=500)
parser.add_argument("--prof_iters", type=int, default=20)
parser.add_argument("--batch_size", type=int, default=2048)
parser.add_argument("--world_size", type=int, default=2)
parser.add_argument("--output_dir", type=str, default="/var/tmp/torchrec-bench")
parser.add_argument("--num_benchmarks", type=int, default=9)

return parser
args = parser.parse_args()
return args


BENCH_SHARDING_TYPES: List[ShardingType] = [
ShardingType.TABLE_WISE,
ShardingType.ROW_WISE,
ShardingType.COLUMN_WISE,
]

BENCH_COMPILE_MODES: List[CompileMode] = [
CompileMode.EAGER,
CompileMode.FX_SCRIPT,
]

TABLE_SIZES: List[Tuple[int, int]] = [
(40_000_000, 256),
(4_000_000, 256),
(1_000_000, 256),
]
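
For scale, a worked instance of the report's per-table size line (u8, so one byte per element; a sketch of the arithmetic main() uses below):

num, dim = 40_000_000, 256
mb = int(float(num * dim) / 1024 / 1024)  # 10_240_000_000 bytes -> 9765 (MiB, printed as "Mb")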


def write_report(
@@ -73,61 +97,49 @@ def write_report(
logger.info(f"Report written to {report_file}:\n{report_str}")


def benchmark_qebc() -> None:
parser = init_argparse()
args = parser.parse_args()

datetime_sfx: str = time.strftime("%Y%m%dT%H%M%S")
def benchmark_qec(args: argparse.Namespace, output_dir: str) -> List[BenchmarkResult]:
tables = get_tables(TABLE_SIZES, is_pooled=False)
sharder = TestQuantECSharder(
sharding_type="",
kernel_type=EmbeddingComputeKernel.QUANT.value,
shardable_params=[table.name for table in tables],
)

output_dir = args.output_dir
if not os.path.exists(output_dir):
# Create the output directory if it does not exist
os.mkdir(output_dir)
module = QuantEmbeddingCollection(
# pyre-ignore [6]
tables=tables,
device=torch.device("cpu"),
quant_state_dict_split_scale_bias=True,
)

output_dir += f"/run_{datetime_sfx}"
if not os.path.exists(output_dir):
# Place all outputs under the datetime folder
os.mkdir(output_dir)
args_kwargs = {
argname: getattr(args, argname)
for argname in dir(args)
# Don't include output_dir since output_dir was modified
if not argname.startswith("_") and argname != "output_dir"
}

BENCH_SHARDING_TYPES = [
ShardingType.TABLE_WISE,
ShardingType.ROW_WISE,
ShardingType.COLUMN_WISE,
]

BENCH_COMPILE_MODES = [
CompileMode.EAGER,
CompileMode.FX_SCRIPT,
]

table_sizes = [
(40_000_000, 256),
(4_000_000, 256),
(1_000_000, 256),
]

tables_info = "\nTABLE SIZES QUANT:"
for i, (num, dim) in enumerate(table_sizes):
mb = int(float(num * dim) / 1024 / 1024)
tables_info += f"\nTABLE[{i}][{num:9}, {dim:4}] u8: {mb:6}Mb"

report: str = f"REPORT BENCHMARK {datetime_sfx} world_size:{args.world_size} batch_size:{args.batch_size//1000}k\n"
report += "Module: QuantEmbeddingBagCollection\n"
report += tables_info
report += "\n"
return benchmark_module(
module=module,
sharder=sharder,
sharding_types=BENCH_SHARDING_TYPES,
compile_modes=BENCH_COMPILE_MODES,
tables=tables,
output_dir=output_dir,
**args_kwargs,
)

num_requests = args.bench_iters * args.batch_size * args.num_benchmarks
report += f"num_requests:{num_requests:8}\n"
report_file: str = f"{output_dir}/run.report"

tables = get_tables(table_sizes)
def benchmark_qebc(args: argparse.Namespace, output_dir: str) -> List[BenchmarkResult]:
tables = get_tables(TABLE_SIZES)
sharder = TestQuantEBCSharder(
sharding_type="",
kernel_type=EmbeddingComputeKernel.QUANT.value,
shardable_params=[table.name for table in tables],
)

module = QuantEmbeddingBagCollection(
# pyre-ignore [6]
tables=tables,
is_weighted=False,
device=torch.device("cpu"),
@@ -141,7 +153,7 @@ def benchmark_qebc() -> None:
if not argname.startswith("_") and argname != "output_dir"
}

benchmark_results = benchmark_module(
return benchmark_module(
module=module,
sharder=sharder,
sharding_types=BENCH_SHARDING_TYPES,
@@ -151,10 +163,65 @@
**args_kwargs,
)

write_report(benchmark_results, report_file, report, num_requests)

def main() -> None:
args: argparse.Namespace = init_argparse_and_args()

num_requests = args.bench_iters * args.batch_size * args.num_benchmarks
datetime_sfx: str = time.strftime("%Y%m%dT%H%M%S")

output_dir = args.output_dir
if not os.path.exists(output_dir):
# Create the output directory if it does not exist
os.mkdir(output_dir)

benchmark_results_per_module = []
write_report_funcs_per_module = []

for module_name in ["QuantEmbeddingBagCollection", "QuantEmbeddingCollection"]:
output_dir = args.output_dir + f"/run_{datetime_sfx}"
if module_name == "QuantEmbeddingBagCollection":
output_dir += "_qebc"
benchmark_func = benchmark_qebc
else:
output_dir += "_qec"
benchmark_func = benchmark_qec

if not os.path.exists(output_dir):
# Place all outputs under the datetime folder
os.mkdir(output_dir)

tables_info = "\nTABLE SIZES QUANT:"
for i, (num, dim) in enumerate(TABLE_SIZES):
mb = int(float(num * dim) / 1024 / 1024)
tables_info += f"\nTABLE[{i}][{num:9}, {dim:4}] u8: {mb:6}Mb"

report: str = f"REPORT BENCHMARK {datetime_sfx} world_size:{args.world_size} batch_size:{args.batch_size}\n"
report += f"Module: {module_name}\n"
report += tables_info
report += "\n"

num_requests = args.bench_iters * args.batch_size * args.num_benchmarks
report += f"num_requests:{num_requests:8}\n"
report_file: str = f"{output_dir}/run.report"

# Save results to output them once benchmarking is all done
benchmark_results_per_module.append(benchmark_func(args, output_dir))
write_report_funcs_per_module.append(
partial(
write_report,
report_file=report_file,
report_str=report,
num_requests=num_requests,
)
)

for i, write_report_func in enumerate(write_report_funcs_per_module):
write_report_func(benchmark_results_per_module[i])


if __name__ == "__main__":
logging.basicConfig()
logging.getLogger().setLevel(logging.DEBUG)
benchmark_qebc()

main()
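
With the defaults above, one invocation of main() now writes one report per module, each under its own timestamped, suffixed directory (illustrative layout; the timestamp is invented):

/var/tmp/torchrec-bench/run_20240118T120000_qebc/run.report
/var/tmp/torchrec-bench/run_20240118T120000_qec/run.report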
108 changes: 90 additions & 18 deletions torchrec/distributed/benchmark/benchmark_utils.py
@@ -13,7 +13,7 @@
from dataclasses import dataclass

from enum import Enum
from typing import List, Tuple, TypeVar
from typing import Dict, List, Tuple, TypeVar, Union

import torch
from torch.autograd.profiler import record_function
@@ -30,8 +30,8 @@

from torchrec.distributed.types import DataType, ModuleSharder, ShardingEnv
from torchrec.fx import symbolic_trace
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor

logger: logging.Logger = logging.getLogger()

@@ -49,6 +49,63 @@ class BenchmarkResult:
max_mem_allocated: List[int]


class ECWrapper(torch.nn.Module):
"""
Wrapper Module for benchmarking EC Modules
Args:
module: module to benchmark
Call Args:
input: KeyedJaggedTensor KJT input to module
Returns:
output: KT output from module
Example:
e1_config = EmbeddingConfig(
name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
)
e2_config = EmbeddingConfig(
name="t2", embedding_dim=3, num_embeddings=10, feature_names=["f2"]
)
ec = EmbeddingCollection(tables=[e1_config, e2_config])
features = KeyedJaggedTensor(
keys=["f1", "f2"],
values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
)
ec.qconfig = torch.quantization.QConfig(
activation=torch.quantization.PlaceholderObserver.with_args(
dtype=torch.qint8
),
weight=torch.quantization.PlaceholderObserver.with_args(dtype=torch.qint8),
)
qec = QuantEmbeddingCollection.from_float(ec)
wrapped_module = ECWrapper(qec)
quantized_embeddings = wrapped_module(features)
"""

def __init__(self, module: torch.nn.Module) -> None:
super().__init__()
self._module = module

def forward(self, input: KeyedJaggedTensor) -> Dict[str, JaggedTensor]:
"""
Args:
input (KeyedJaggedTensor): KJT of form [F X B X L].
Returns:
Dict[str, JaggedTensor]
"""
return self._module.forward(input)


class EBCWrapper(torch.nn.Module):
"""
Wrapper Module for benchmarking Modules
@@ -109,24 +166,36 @@ def forward(self, input: KeyedJaggedTensor) -> KeyedTensor:


def get_tables(
table_sizes: List[Tuple[int, int]],
) -> List[EmbeddingBagConfig]:
tables: List[EmbeddingBagConfig] = [
EmbeddingBagConfig(
num_embeddings=num_embeddings,
embedding_dim=embedding_dim,
name="table_" + str(i),
feature_names=["feature_" + str(i)],
data_type=DataType.INT8,
)
for i, (num_embeddings, embedding_dim) in enumerate(table_sizes)
]
table_sizes: List[Tuple[int, int]], is_pooled: bool = True
) -> Union[List[EmbeddingBagConfig], List[EmbeddingConfig]]:
if is_pooled:
tables: List[EmbeddingBagConfig] = [
EmbeddingBagConfig(
num_embeddings=num_embeddings,
embedding_dim=embedding_dim,
name="table_" + str(i),
feature_names=["feature_" + str(i)],
data_type=DataType.INT8,
)
for i, (num_embeddings, embedding_dim) in enumerate(table_sizes)
]
else:
tables: List[EmbeddingConfig] = [
EmbeddingConfig(
num_embeddings=num_embeddings,
embedding_dim=embedding_dim,
name="table_" + str(i),
feature_names=["feature_" + str(i)],
data_type=DataType.INT8,
)
for i, (num_embeddings, embedding_dim) in enumerate(table_sizes)
]

return tables


def get_inputs(
tables: List[EmbeddingBagConfig],
tables: Union[List[EmbeddingBagConfig], List[EmbeddingConfig]],
batch_size: int,
world_size: int,
num_inputs: int,
@@ -366,7 +435,7 @@ def benchmark_module(
sharder: ModuleSharder[T],
sharding_types: List[ShardingType],
compile_modes: List[CompileMode],
tables: List[EmbeddingBagConfig],
tables: Union[List[EmbeddingBagConfig], List[EmbeddingConfig]],
warmup_iters: int = 20,
bench_iters: int = 2000,
prof_iters: int = 20,
@@ -409,7 +478,10 @@ def benchmark_module(
bench_inputs = inputs[warmup_iters : (warmup_iters + bench_iters)]
prof_inputs = inputs[-prof_iters:]

wrapped_module = EBCWrapper(module)
if isinstance(tables[0], EmbeddingBagConfig):
wrapped_module = EBCWrapper(module)
else:
wrapped_module = ECWrapper(module)

for sharding_type in sharding_types:
for compile_mode in compile_modes:
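
Taken together, the benchmark_utils changes can be exercised with a minimal sketch like the following (hedged: the table size and the wrap_for_benchmark helper are illustrative and not part of this diff; get_tables, EBCWrapper, and ECWrapper are the functions and classes shown above):

from torchrec.modules.embedding_configs import EmbeddingBagConfig

# Pooled path (QEBC) keeps the default EmbeddingBagConfig tables;
# is_pooled=False yields EmbeddingConfig tables for the unpooled QEC path.
pooled_tables = get_tables([(1_000_000, 256)])
unpooled_tables = get_tables([(1_000_000, 256)], is_pooled=False)

def wrap_for_benchmark(tables, module):
    # Hypothetical helper mirroring the selection logic benchmark_module now uses.
    if isinstance(tables[0], EmbeddingBagConfig):
        return EBCWrapper(module)  # pooled: forward returns a KeyedTensor
    return ECWrapper(module)  # unpooled: forward returns Dict[str, JaggedTensor]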
