From 56a3f45c6d0bb6db78078b4b87f9c6496db6f6c7 Mon Sep 17 00:00:00 2001 From: Huanyu He Date: Fri, 19 Jul 2024 11:04:09 -0700 Subject: [PATCH] benchmark of fbgemm op - keyed_tensor_regroup (#2159) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2159 # context * added **fn-level** benchmark for the `regroup_keyed_tensor` * `keyed_tensor_regroup` further reduces the CPU runtime from 2.0ms to 1.3ms (35% improvement) without hurting the GPU runtime/memory usage # conclusion * CPU runtime **reduces 1/3** from 1.8 ms to 1.1 ms * GPU runtime **reduces 2/3** from 7.0 ms to 2.0 ms * GPU memory **reduces 1/3** from 1.5 K to 1.0 K * **we should migrate to the new op** unless any unknown concern/blocker # traces * [files](https://drive.google.com/drive/folders/1iiEf30LeG_i0xobMZVhmMneOQ5slmX3U?usp=drive_link) ``` [hhy@24963.od /data/sandcastle/boxes/fbsource (04ad34da3)]$ ll *.json -rw-r--r-- 1 hhy hhy 552501 Jul 10 16:01 'trace-[1 Op] KT_regroup_dup.json' -rw-r--r-- 1 hhy hhy 548847 Jul 10 16:01 'trace-[1 Op] KT_regroup.json' -rw-r--r-- 1 hhy hhy 559006 Jul 10 16:01 'trace-[2 Ops] permute_multi_embs_dup.json' -rw-r--r-- 1 hhy hhy 553199 Jul 10 16:01 'trace-[2 Ops] permute_multi_embs.json' -rw-r--r-- 1 hhy hhy 5104239 Jul 10 16:01 'trace-[Module] KTRegroupAsDict_dup.json' -rw-r--r-- 1 hhy hhy 346643 Jul 10 16:01 'trace-[Module] KTRegroupAsDict.json' -rw-r--r-- 1 hhy hhy 895096 Jul 10 16:01 'trace-[Old Prod] permute_pooled_embs.json' -rw-r--r-- 1 hhy hhy 561685 Jul 10 16:01 'trace-[Prod] KeyedTensor.regroup_dup.json' -rw-r--r-- 1 hhy hhy 559147 Jul 10 16:01 'trace-[Prod] KeyedTensor.regroup.json' -rw-r--r-- 1 hhy hhy 7958676 Jul 10 16:01 'trace-[pytorch generic] fallback_dup.json' -rw-r--r-- 1 hhy hhy 7978141 Jul 10 16:01 'trace-[pytorch generic] fallback.json' ``` * pytorch generic {F1752502508} * current prod {F1752503546} * permute_multi_embedding (2 Ops) {F1752503160} * KT.regroup (1 Op) {F1752504258} * regroupAsDict (Module) {F1752504964} * metrics 
|Operator|CPU runtime|GPU runtime|GPU memory|notes| |---|---|---|---|---| |**[fallback] pytorch generic**|3.9 ms|3.2 ms|1.0 K|CPU-bounded, allow duplicates| |**[prod] _fbgemm_permute_pooled_embs**|1.9 ms|7.1 ms|1.5 K|GPU-bounded, does **NOT** allow duplicates, PT2 non-compatible `pin_and_move`| |**[hybrid python/cu] keyed_tensor_regroup**|1.5 ms|2.0 ms|1.0 K|both GPU runtime and memory improved, **ALLOW** duplicates, PT2 friendly| |**[pure c++/cu] permute_multi_embedding**|1.0 ms|2.0 ms|1.0 K|both CPU and GPU runtime/memory improved, **ALLOW** duplicates, PT2 friendly| Differential Revision: D58907223 --- torchrec/sparse/jagged_tensor.py | 13 ++++++++ .../sparse/tests/jagged_tensor_benchmark.py | 33 ++++++++++++++++--- 2 files changed, 42 insertions(+), 4 deletions(-) diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 8f8785e2a..f3ad229e6 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -188,6 +188,19 @@ def permute_multi_embedding( return permuted_values +@torch.fx.wrap +def regroup_kts( + keyed_tensors: List["KeyedTensor"], groups: List[List["str"]] +) -> List[torch.Tensor]: + keys, lengths, values = _desugar_keyed_tensors(keyed_tensors) + return torch.ops.fbgemm.regroup_keyed_tensor( + values, + keys, + lengths, + groups, + ) + + @torch.fx.wrap def _fbgemm_permute_pooled_embs( + keyed_tensors: List["KeyedTensor"], groups: List[List["str"]] diff --git a/torchrec/sparse/tests/jagged_tensor_benchmark.py b/torchrec/sparse/tests/jagged_tensor_benchmark.py index aa426e448..b9dd12d3b 100644 --- a/torchrec/sparse/tests/jagged_tensor_benchmark.py +++ b/torchrec/sparse/tests/jagged_tensor_benchmark.py @@ -18,10 +18,12 @@ from torchrec.distributed.benchmark.benchmark_utils import benchmark, BenchmarkResult from torchrec.modules.regroup import KTRegroupAsDict from torchrec.sparse.jagged_tensor import ( + _fbgemm_permute_pooled_embs, _regroup_keyed_tensors, KeyedJaggedTensor, KeyedTensor,
permute_multi_embedding, + regroup_kts, ) from torchrec.sparse.tests.utils import build_groups, build_kts @@ -213,7 +215,7 @@ def main( ).float() groups = build_groups(kts, n_groups, duplicates=duplicates) bench( - "_regroup_keyed_tenors" + dup, + "[pytorch generic] fallback" + dup, labels, batch_size, n_dense + n_sparse, @@ -224,7 +226,7 @@ def main( profile, ) bench( - "KeyedTensor.regroup" + dup, + "[Prod] KeyedTensor.regroup" + dup, labels, batch_size, n_dense + n_sparse, @@ -235,7 +237,7 @@ def main( profile, ) bench( - "KTRegroupAsDict" + dup, + "[Module] KTRegroupAsDict" + dup, labels, batch_size, n_dense + n_sparse, @@ -248,7 +250,7 @@ def main( profile, ) bench( - "permute_multi_embs" + dup, + "[2 Ops] permute_multi_embs" + dup, labels, batch_size, n_dense + n_sparse, @@ -258,6 +260,29 @@ def main( {"keyed_tensors": kts, "groups": groups}, profile, ) + bench( + "[1 Op] KT_regroup" + dup, + labels, + batch_size, + n_dense + n_sparse, + device_type, + run_backward, + regroup_kts, + {"keyed_tensors": kts, "groups": groups}, + profile, + ) + if not duplicates: + bench( + "[Old Prod] permute_pooled_embs" + dup, + labels, + batch_size, + n_dense + n_sparse, + device_type, + run_backward, + _fbgemm_permute_pooled_embs, + {"keyed_tensors": kts, "groups": groups}, + profile, + ) if __name__ == "__main__":