Skip to content

Commit

Permalink
benchmark of fbgemm op - keyed_tensor_regroup (pytorch#2159)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#2159

# context
* added **fn-level** benchmark for the `regroup_keyed_tensor`
* `keyed_tensor_regroup` further reduces the CPU runtime from 2.0ms to 1.3ms (35% improvement) without hurting the GPU runtime/memory usage

# conclusion
* CPU runtime **reduces 1/3** from 1.8 ms to 1.1 ms
* GPU runtime **reduces 2/3** from 7.0 ms to 2.0 ms
* GPU memory **reduces 1/3** from 1.5 K to 1.0 K
* **we should migrate to the new op** unless any unknown concern/blocker

# traces
* [files](https://drive.google.com/drive/folders/1iiEf30LeG_i0xobMZVhmMneOQ5slmX3U?usp=drive_link)
```
[[email protected] /data/sandcastle/boxes/fbsource (04ad34da3)]$ ll *.json
-rw-r--r-- 1 hhy hhy  552501 Jul 10 16:01 'trace-[1 Op] KT_regroup_dup.json'
-rw-r--r-- 1 hhy hhy  548847 Jul 10 16:01 'trace-[1 Op] KT_regroup.json'
-rw-r--r-- 1 hhy hhy  559006 Jul 10 16:01 'trace-[2 Ops] permute_multi_embs_dup.json'
-rw-r--r-- 1 hhy hhy  553199 Jul 10 16:01 'trace-[2 Ops] permute_multi_embs.json'
-rw-r--r-- 1 hhy hhy 5104239 Jul 10 16:01 'trace-[Module] KTRegroupAsDict_dup.json'
-rw-r--r-- 1 hhy hhy  346643 Jul 10 16:01 'trace-[Module] KTRegroupAsDict.json'
-rw-r--r-- 1 hhy hhy  895096 Jul 10 16:01 'trace-[Old Prod] permute_pooled_embs.json'
-rw-r--r-- 1 hhy hhy  561685 Jul 10 16:01 'trace-[Prod] KeyedTensor.regroup_dup.json'
-rw-r--r-- 1 hhy hhy  559147 Jul 10 16:01 'trace-[Prod] KeyedTensor.regroup.json'
-rw-r--r-- 1 hhy hhy 7958676 Jul 10 16:01 'trace-[pytorch generic] fallback_dup.json'
-rw-r--r-- 1 hhy hhy 7978141 Jul 10 16:01 'trace-[pytorch generic] fallback.json'
```
* pytorch generic
 {F1752502508}
* current prod
 {F1752503546}
* permute_multi_embedding (2 Ops)
 {F1752503160}
* KT.regroup (1 Op)
 {F1752504258}
* regroupAsDict (Module)
{F1752504964}
* metrics
|Operator|CPU runtime|GPU runtime|GPU memory|notes|
|---|---|---|---|---|
|**[fallback] pytorch generic**|3.9 ms|3.2 ms|1.0 K|CPU-bounded, allow duplicates|
|**[prod] _fbgemm_permute_pooled_embs**|1.9 ms|7.1 ms|1.5 K|GPU-boudned, does **NOT** allow duplicates, PT2 non-compatible `pin_and_move`|
|**[hybrid python/cu] keyed_tensor_regroup**|1.5 ms|2.0 ms|1.0 K|both GPU runtime and memory improved, **ALLOW** duplicates, PT2 friendly|
|**[pure c++/cu] permute_multi_embedding**|1.0 ms|2.0 ms|1.0 K|both CPU and GPU runtime/memory improved, **ALLOW** duplicates, PT2 friendly|

Differential Revision: D58907223
  • Loading branch information
Huanyu He authored and facebook-github-bot committed Jul 19, 2024
1 parent 4f114bc commit 56a3f45
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 4 deletions.
13 changes: 13 additions & 0 deletions torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,19 @@ def permute_multi_embedding(
return permuted_values


@torch.fx.wrap
def regroup_kts(
keyed_tensors: List["KeyedTensor"], groups: List[List["str"]]
) -> List[torch.Tensor]:
keys, lengths, values = _desugar_keyed_tensors(keyed_tensors)
return torch.ops.fbgemm.regroup_keyed_tensor(
values,
keys,
lengths,
groups,
)


@torch.fx.wrap
def _fbgemm_permute_pooled_embs(
keyed_tensors: List["KeyedTensor"], groups: List[List["str"]]
Expand Down
33 changes: 29 additions & 4 deletions torchrec/sparse/tests/jagged_tensor_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
from torchrec.distributed.benchmark.benchmark_utils import benchmark, BenchmarkResult
from torchrec.modules.regroup import KTRegroupAsDict
from torchrec.sparse.jagged_tensor import (
_fbgemm_permute_pooled_embs,
_regroup_keyed_tensors,
KeyedJaggedTensor,
KeyedTensor,
permute_multi_embedding,
regroup_kts,
)
from torchrec.sparse.tests.utils import build_groups, build_kts

Expand Down Expand Up @@ -213,7 +215,7 @@ def main(
).float()
groups = build_groups(kts, n_groups, duplicates=duplicates)
bench(
"_regroup_keyed_tenors" + dup,
"[pytorch generic] fallback" + dup,
labels,
batch_size,
n_dense + n_sparse,
Expand All @@ -224,7 +226,7 @@ def main(
profile,
)
bench(
"KeyedTensor.regroup" + dup,
"[Prod] KeyedTensor.regroup" + dup,
labels,
batch_size,
n_dense + n_sparse,
Expand All @@ -235,7 +237,7 @@ def main(
profile,
)
bench(
"KTRegroupAsDict" + dup,
"[Module] KTRegroupAsDict" + dup,
labels,
batch_size,
n_dense + n_sparse,
Expand All @@ -248,7 +250,7 @@ def main(
profile,
)
bench(
"permute_multi_embs" + dup,
"[2 Ops] permute_multi_embs" + dup,
labels,
batch_size,
n_dense + n_sparse,
Expand All @@ -258,6 +260,29 @@ def main(
{"keyed_tensors": kts, "groups": groups},
profile,
)
bench(
"[1 Op] KT_regroup" + dup,
labels,
batch_size,
n_dense + n_sparse,
device_type,
run_backward,
regroup_kts,
{"keyed_tensors": kts, "groups": groups},
profile,
)
if not duplicates:
bench(
"[Old Prod] permute_pooled_embs" + dup,
labels,
batch_size,
n_dense + n_sparse,
device_type,
run_backward,
_fbgemm_permute_pooled_embs,
{"keyed_tensors": kts, "groups": groups},
profile,
)


if __name__ == "__main__":
Expand Down

0 comments on commit 56a3f45

Please sign in to comment.