permute_multi_embs benchmark (pytorch#2238)
Summary:
Pull Request resolved: pytorch#2238

# context
* In this diff, we explore the Triton language to author the GPU kernel instead of using CUDA.
* We use the same benchmark as in `jagged_tensor_benchmark.py` to compare the different implementations.
* In this diff stack, we developed a new CUDA operator `permute_multi_embedding` in FBGEMM to perform an N-KT-input, M-KT-output permutation [[code](https://fburl.com/code/z3ck7iqi)].
* The motivation for developing this new op is that currently in production, `KT.regroup` uses a CUDA operator named `permute_pooled_embedding`, which performs a 1-KT-input, 1-KT-output permutation. To achieve the same functionality, a `torch.concat` is needed before calling the op, and a `torch.split` is called after it. The `torch.concat` is unnecessary and costs both time and memory. [[code](https://fburl.com/code/4t2d3lz1)]
* We also developed two Triton operators following the same pattern: one op handles a single (concatenated) input tensor, and one op handles multiple input tensors.

NOTE: for simplicity, we name these four operators `cuda-multi-KT-permute`, `cuda-single-KT-permute`, `triton-multi-KT-permute`, and `triton-single-KT-permute`.
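To make the concat/split overhead concrete, here is a plain-PyTorch sketch contrasting the two paths. `regroup_via_concat` mimics the production pattern (`torch.cat` → single-tensor column permutation → `torch.split`), while `regroup_multi` writes each output group directly from the input tensors. The function names and the dict-of-2-D-tensors representation are illustrative, not the actual fbgemm or torchrec API.

```python
import torch

def regroup_via_concat(keyed_tensors, groups):
    """Production-style path (sketch): concat all inputs into one big
    tensor, permute its columns, then split the result back out.
    `keyed_tensors` maps key -> 2-D tensor of shape [B, dim];
    `groups` is a list of key lists, one list per output tensor."""
    big = torch.cat(list(keyed_tensors.values()), dim=1)  # extra copy
    # record the column range each key occupies in the concatenated tensor
    offsets, col = {}, 0
    for k, t in keyed_tensors.items():
        offsets[k] = (col, col + t.shape[1])
        col += t.shape[1]
    # build the column permutation that realizes the requested grouping
    perm = torch.cat(
        [torch.arange(*offsets[k]) for grp in groups for k in grp]
    )
    permuted = big[:, perm]  # the actual permutation kernel
    sizes = [sum(keyed_tensors[k].shape[1] for k in grp) for grp in groups]
    return list(torch.split(permuted, sizes, dim=1))  # extra split

def regroup_multi(keyed_tensors, groups):
    """Multi-KT path (sketch): produce each output directly from the
    input tensors, skipping the concat/split round trips."""
    return [torch.cat([keyed_tensors[k] for k in grp], dim=1) for grp in groups]
```

Both paths produce identical outputs; the multi-KT path just avoids materializing the intermediate concatenated tensor, which is where the time and memory savings in the benchmarks below come from.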
* benchmark readings
```
cuda-multi-KT-permute    | B: 1024 | F: 1020 | device: cuda | Runtime (P90): 2.01 ms | Memory (P90): 1011.0
cuda-single-KT-permute   | B: 1024 | F: 1020 | device: cuda | Runtime (P90): 4.92 ms | Memory (P90): 1517.0
triton-multi-KT-permute  | B: 1024 | F: 1020 | device: cuda | Runtime (P90): 6.66 ms | Memory (P90): 1011.0
triton-single-KT-permute | B: 1024 | F: 1020 | device: cuda | Runtime (P90): 7.75 ms | Memory (P90): 1517.0
```

# details
* the Triton kernel design is very similar to that of the corresponding CUDA kernel
* the Triton kernel takes two `List[torch.Tensor]` arguments: the input tensor list and the output tensor list
* the Triton kernel also takes the `permutes` tensor and processes the permutations in parallel

# metrics
| function | kernel type | memory | cpu runtime | gpu runtime | net kernel runtime | torch.cat/split | metadata preparation | notes |
|---|---|---|---|---|---|---|---|---|
| multi-KT-permute | cuda | 1011.0 | 1.060 ms | 2.01 ms | 1.99 ms | No | in cpp | calls a CPU op for metadata preparation for optimal CPU runtime, including sending the tensor-address list (as a tensor) to the GPU; very efficient |
| single-KT-permute | cuda | 1517.0 | 1.793 ms | 4.92 ms | 1.99 ms | Yes | in python | optimal net kernel runtime, since it is a pure memory permutation on a single tensor |
| multi-KT-permute | triton | 1011.0 | 2.593 ms | 6.66 ms | 6.64 ms | No | in python | sending the tensor-address list (as a tensor) to the GPU is executed in Python and not as efficient |
| single-KT-permute | triton | 1517.0 | 1.669 ms | 7.75 ms | 5.00 ms | Yes | in python | likely the optimal net kernel runtime achievable in Triton |

# traces
* [files](https://drive.google.com/drive/folders/173zZMnxnhLmFKkiJDXomS7c_ui0KKeV6?usp=sharing)
```
adding: trace-[1 Op] KT_regroup_dup.json (deflated 92%)
adding: trace-[1 Op] KT_regroup.json (deflated 92%)
adding: trace-[2 Ops] permute_multi_embs_dup.json (deflated 92%)
adding: trace-[2 Ops] permute_multi_embs.json (deflated 92%)
adding: trace-[Module] KTRegroupAsDict_dup.json (deflated 95%)
adding: trace-[Module] KTRegroupAsDict.json (deflated 91%)
adding: trace-[Old Prod] permute_pooled_embs.json (deflated 92%)
adding: trace-[Prod] KeyedTensor.regroup_dup.json (deflated 95%)
adding: trace-[Prod] KeyedTensor.regroup.json (deflated 92%)
adding: trace-[pytorch generic] fallback_dup.json (deflated 95%)
adding: trace-[pytorch generic] fallback.json (deflated 95%)
adding: trace-[Triton] permute_multi_embs.json (deflated 92%)
adding: trace-[Triton] permute_pooled_embs.json (deflated 90%)
```
* cuda-multi-KT-permute {F1764404357}
* cuda-single-KT-permute {F1764401919}
* triton-multi-KT-permute {F1764399954}
* triton-single-KT-permute {F1764400438}

# reference
* [Trick: Passing tensor lists as pointer vectors](https://fburl.com/workplace/wdym5b7p)
* D39735564 for `List[Tensor]` in a Triton kernel

Differential Revision: D52354486
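For readers unfamiliar with the descriptor-driven design, below is a CPU reference sketch of the semantics the multi-KT kernels implement: each row of a `permutes` tensor describes one slice copy from an input tensor to an output tensor. The `(in_tensor, out_tensor, in_offset, out_offset, length)` row layout here is an assumption for illustration; the actual fbgemm descriptor layout may differ. The CUDA/Triton kernels process these rows in parallel, while this sketch loops sequentially.

```python
import torch

def permute_multi_embedding_ref(inputs, permutes, out_lengths, batch_size):
    """CPU reference for a multi-KT permute driven by a `permutes`
    descriptor tensor. Each row is assumed to be
    (in_tensor, out_tensor, in_offset, out_offset, length):
    copy `length` columns starting at `in_offset` of inputs[in_tensor]
    into columns starting at `out_offset` of outputs[out_tensor]."""
    outputs = [
        torch.empty(batch_size, n, dtype=inputs[0].dtype) for n in out_lengths
    ]
    # the GPU kernels assign these rows to parallel thread blocks;
    # here we just walk them in order
    for in_t, out_t, in_off, out_off, length in permutes.tolist():
        outputs[out_t][:, out_off:out_off + length] = \
            inputs[in_t][:, in_off:in_off + length]
    return outputs
```

Because each descriptor row reads directly from its source tensor, no concatenated intermediate is ever built, which matches the memory advantage of the multi-KT variants in the metrics table above.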