permute_multi_embs benchmark (pytorch#2238)
Summary:
Pull Request resolved: pytorch#2238

# context
* In this diff, we explore the Triton language for authoring the GPU kernel instead of CUDA.
* We use the same benchmark in jagged_tensor_benchmark.py to compare the different implementations.
* In this diff stack, we developed a new CUDA operator `permute_multi_embedding` in FBGEMM that performs an N-KT-input to M-KT-output permutation [[code](https://fburl.com/code/z3ck7iqi)].
* The motivation for this new op is that in production, KT.regroup currently uses a CUDA operator named `permute_pooled_embedding`, which performs a 1-KT-input to 1-KT-output permutation. To achieve the same functionality, a `torch.concat` is needed before calling the op and a `torch.split` after it; the `torch.concat` is unnecessary and costs both time and memory (see the usage sketch after the benchmark readings). [[code](https://fburl.com/code/4t2d3lz1)]
* We also developed two Triton operators following the same pattern: one handles a single (concatenated) input tensor, and one handles multiple input tensors.
NOTE: for simplicity, we refer to these four operators as `cuda-multi-KT-permute`, `cuda-single-KT-permute`, `triton-multi-KT-permute`, and `triton-single-KT-permute`.
* benchmark readings
```
cuda-multi-KT-permute    | B: 1024     | F: 1020     | device: cuda     | Runtime (P90):  2.01 ms | Memory (P90): 1011.0
cuda-single-KT-permute   | B: 1024     | F: 1020     | device: cuda     | Runtime (P90):  4.92 ms | Memory (P90): 1517.0
triton-multi-KT-permute  | B: 1024     | F: 1020     | device: cuda     | Runtime (P90):  6.66 ms | Memory (P90): 1011.0
triton-single-KT-permute | B: 1024     | F: 1020     | device: cuda     | Runtime (P90):  7.75 ms | Memory (P90): 1517.0
```
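For concreteness, below is a minimal usage sketch of the two calling patterns, written against the Triton ops added in this diff (`triton_permute_pooled_embs` / `triton_permute_multi_embs` from `torchrec/sparse/triton_ops.py`); the CUDA ops follow the same single-KT vs multi-KT pattern. The keys, lengths, and shapes are illustrative assumptions, and a CUDA device is required.
```
import torch
from torchrec.sparse.triton_ops import (
    triton_permute_multi_embs,
    triton_permute_pooled_embs,
)

# two input KTs (value tensors), regrouped into two output KTs
values = [torch.randn(8, 3, device="cuda"), torch.randn(8, 5, device="cuda")]
keys = [["f1", "f2"], ["f3"]]
lengths = [[1, 2], [5]]
groups = [["f1", "f3"], ["f2"]]

# single-KT pattern: the kernel sees one concatenated input and writes one
# concatenated output, so a concat (done inside the op here) and a split are needed.
out, splits = triton_permute_pooled_embs(values, keys, lengths, groups)
single_kt_outputs = list(torch.split(out, splits, dim=1))

# multi-KT pattern: tensor list in, tensor list out; no torch.concat / torch.split needed.
multi_kt_outputs = triton_permute_multi_embs(values, keys, lengths, groups)
```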

# details
* The Triton kernel design closely follows the corresponding CUDA kernel.
* The Triton kernel takes two `List[torch.Tensor]` arguments: the input tensor list and the output tensor list.
* The Triton kernel also takes the `permutes` metadata tensor and processes all permute rows in parallel; a worked example of the `permutes` layout follows this list.
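As a reference for the `permutes` layout, here is a small worked example; the row format matches `triton_permute_multi_embs` in the diff below, while the toy keys and lengths are purely illustrative.
```
# Toy inputs: two KTs and two output groups (illustrative values).
keys    = [["f1", "f2"], ["f3"]]
lengths = [[4, 8], [16]]
groups  = [["f1", "f3"], ["f2"]]

# Each permute row is [in_tensor, out_tensor, in_offset, out_offset, length, next]:
#   f1: input tensor 0, offset 0  -> output tensor 0, offset 0,  length 4
#   f3: input tensor 1, offset 0  -> output tensor 0, offset 4,  length 16
#   f2: input tensor 0, offset 4  -> output tensor 1, offset 0,  length 8
permutes = [
    [0, 0, 0, 0, 4, 0],
    [1, 0, 0, 4, 16, 0],
    [0, 1, 4, 0, 8, 0],
]
# The kernel launches on a (B, P) grid: one program per (batch row, permute row),
# so all P copies proceed in parallel for every batch element.
```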
# metrics
| function | kernel type | memory | cpu runtime | gpu runtime | net kernel runtime | torch.cat/split | metadata preparation | notes |
|---|---|---|---|---|---|---|---|---|
| multi-KT-permute | cuda | 1011.0 | 1.060 ms | 2.01 ms | 1.99 ms | No | in cpp | calls a CPU op for metadata preparation for optimal CPU runtime, including sending the tensor address list (as a tensor) to the GPU; very efficient |
| single-KT-permute | cuda | 1517.0 | 1.793 ms | 4.92 ms | 1.99 ms | Yes | in python | optimal net kernel runtime, as it is a pure memory permutation on a single tensor |
| multi-KT-permute | triton | 1011.0 | 2.593 ms | 6.66 ms | 6.64 ms | No | in python | sending the tensor address list (as a tensor) to the GPU from Python is not as efficient (see the pointer-list sketch after this table) |
| single-KT-permute | triton | 1517.0 | 1.669 ms | 7.75 ms | 5.00 ms | Yes | in python | likely the optimal net kernel runtime achievable in Triton |
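The "tensor address list" in the notes above is the pointer-vector trick from the reference section: the multi-tensor variants pass each tensor's raw data pointer to the kernel as an int64 tensor and bitcast it back to a typed pointer inside the kernel. A minimal host-side sketch (shapes are illustrative):
```
import torch

values = [torch.randn(8, 4, device="cuda"), torch.randn(8, 16, device="cuda")]

# Device-side vector of raw data pointers, one entry per input tensor.
inputs = torch.tensor([v.data_ptr() for v in values], dtype=torch.int64).to(
    "cuda", non_blocking=True
)
# Inside the Triton kernel, each program loads its pointer and bitcasts it back to the
# element dtype before indexing:
#   in_ptr = tl.load(inputs + in_tensor).to(example.dtype, bitcast=True) + offset
# Building this metadata in Python adds CPU overhead, which is why the CUDA op prepares
# it in C++ (see the "metadata preparation" column above).
```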

# traces
* [files](https://drive.google.com/drive/folders/173zZMnxnhLmFKkiJDXomS7c_ui0KKeV6?usp=sharing)
```
  adding: trace-[1 Op] KT_regroup_dup.json (deflated 92%)
  adding: trace-[1 Op] KT_regroup.json (deflated 92%)
  adding: trace-[2 Ops] permute_multi_embs_dup.json (deflated 92%)
  adding: trace-[2 Ops] permute_multi_embs.json (deflated 92%)
  adding: trace-[Module] KTRegroupAsDict_dup.json (deflated 95%)
  adding: trace-[Module] KTRegroupAsDict.json (deflated 91%)
  adding: trace-[Old Prod] permute_pooled_embs.json (deflated 92%)
  adding: trace-[Prod] KeyedTensor.regroup_dup.json (deflated 95%)
  adding: trace-[Prod] KeyedTensor.regroup.json (deflated 92%)
  adding: trace-[pytorch generic] fallback_dup.json (deflated 95%)
  adding: trace-[pytorch generic] fallback.json (deflated 95%)
  adding: trace-[Triton] permute_multi_embs.json (deflated 92%)
  adding: trace-[Triton] permute_pooled_embs.json (deflated 90%)
```
* cuda-multi-KT-permute
 {F1764404357}
* cuda-single-KT-permute
 {F1764401919}
* triton-multi-KT-permute
 {F1764399954}
* triton-single-KT-permute
 {F1764400438}

# reference
* [Trick: Passing tensor lists as pointer vectors](https://fburl.com/workplace/wdym5b7p)
* D39735564 for passing a `List[Tensor]` to a Triton kernel

Differential Revision: D52354486
TroyGarden authored and facebook-github-bot committed Jul 20, 2024
1 parent 09d1ff2 commit 181ddd7
Showing 3 changed files with 342 additions and 0 deletions.
54 changes: 54 additions & 0 deletions torchrec/sparse/tests/jagged_tensor_benchmark.py
@@ -18,6 +18,7 @@
from torchrec.distributed.benchmark.benchmark_utils import benchmark, BenchmarkResult
from torchrec.modules.regroup import KTRegroupAsDict
from torchrec.sparse.jagged_tensor import (
_desugar_keyed_tensors,
_fbgemm_permute_pooled_embs,
_regroup_keyed_tensors,
KeyedJaggedTensor,
@@ -26,6 +27,37 @@
regroup_kts,
)
from torchrec.sparse.tests.utils import build_groups, build_kts
from torchrec.sparse.triton_ops import (
triton_permute_multi_embs,
triton_permute_pooled_embs,
)


@torch.fx.wrap
def _triton_permute_pooled_embs(
keyed_tensors: List["KeyedTensor"], groups: List[List["str"]]
) -> List[torch.Tensor]:
keys, lengths, values = _desugar_keyed_tensors(keyed_tensors)
permuted_values, splits = triton_permute_pooled_embs(
values,
keys,
lengths,
groups,
)
return list(torch.split(permuted_values, splits, dim=1))


@torch.fx.wrap
def _triton_permute_multi_embs(
keyed_tensors: List["KeyedTensor"], groups: List[List["str"]]
) -> List[torch.Tensor]:
keys, lengths, values = _desugar_keyed_tensors(keyed_tensors)
return triton_permute_multi_embs(
values,
keys,
lengths,
groups,
)


class DummyModel(torch.nn.Module):
@@ -283,6 +315,28 @@ def main(
{"keyed_tensors": kts, "groups": groups},
profile,
)
bench(
"[Triton] permute_pooled_embs",
labels,
batch_size,
n_dense + n_sparse,
device_type,
run_backward,
_triton_permute_pooled_embs,
{"keyed_tensors": kts, "groups": groups},
profile,
)
bench(
"[Triton] permute_multi_embs",
labels,
batch_size,
n_dense + n_sparse,
device_type,
run_backward,
_triton_permute_multi_embs,
{"keyed_tensors": kts, "groups": groups},
profile,
)


if __name__ == "__main__":
69 changes: 69 additions & 0 deletions torchrec/sparse/tests/test_triton_ops.py
@@ -0,0 +1,69 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict


import unittest

import torch

from torchrec.sparse.jagged_tensor import _desugar_keyed_tensors, _regroup_keyed_tensors
from torchrec.sparse.tests.utils import build_groups, build_kts
from torchrec.sparse.triton_ops import (
triton_permute_multi_embs,
triton_permute_pooled_embs,
)


class TestPermutePooledEmbs(unittest.TestCase):
# pyre-ignore[56]
@unittest.skipIf(not torch.cuda.is_available(), "Skip when CUDA is not available")
def test_triton_permute_pooled_embs_forward(self) -> None:
kts = build_kts(
dense_features=2,
sparse_features=2,
dim_dense=16,
dim_sparse=16,
batch_size=8,
device=torch.device("cuda"),
run_backward=False,
)
groups = build_groups(
kts,
4,
)
keys, lengths, values = _desugar_keyed_tensors(kts)
output, splits = triton_permute_pooled_embs(values, keys, lengths, groups)
refs = _regroup_keyed_tensors(kts, groups)
outputs = torch.split(output, splits, dim=1)
for ref, output in zip(refs, outputs):
torch.testing.assert_close(ref, output)


class TestPermuteMultiEmbs(unittest.TestCase):
# pyre-ignore[56]
@unittest.skipIf(not torch.cuda.is_available(), "Skip when CUDA is not available")
def test_triton_permute_multi_embs_forward(self) -> None:
kts = build_kts(
dense_features=2,
sparse_features=2,
dim_dense=16,
dim_sparse=16,
batch_size=8,
device=torch.device("cuda"),
run_backward=False,
)
groups = build_groups(
kts,
4,
)
keys, lengths, values = _desugar_keyed_tensors(kts)
outputs = triton_permute_multi_embs(values, keys, lengths, groups)
refs = _regroup_keyed_tensors(kts, groups)
for ref, output in zip(refs, outputs):
torch.testing.assert_close(ref, output)
219 changes: 219 additions & 0 deletions torchrec/sparse/triton_ops.py
@@ -0,0 +1,219 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import Dict, List, Tuple

import torch

# @manual=//triton:triton
import triton

# @manual=//triton:triton
import triton.language as tl


def triton_permute_pooled_embs(
values: List[torch.Tensor],
keys: List[List[str]],
lengths: List[List[int]],
groups: List[List[str]],
) -> Tuple[torch.Tensor, List[int]]:
"""
Permute the values of a KeyedTensor based on the groups.
"""
assert len(values) == len(keys)
assert len(values) == len(lengths)
P = sum(len(g) for g in groups)
B = values[0].shape[0]
device = values[0].device
in_length: int = 0
out_length: int = 0
splits: List[int] = [0] * len(groups)

# permute: [in_offset, out_offset, length, next]
permutes: List[List[int]] = [[0] * 4 for _ in range(P)]
# key -> (in_tensor, in_offset, length)
lookup: Dict[str, Tuple[int, int, int]] = {}
for i, (key, length) in enumerate(zip(keys, lengths)):
for k, l in zip(key, length):
lookup[k] = (i, in_length, l)
in_length += l

curr = 0
for j, group in enumerate(groups):
for k in group:
in_tensor, in_offset, length = lookup[k]
permutes[curr][:] = [in_offset, out_length, length, 0]
out_length += length
splits[j] += length
curr += 1

permute_tensor = torch.tensor(permutes, dtype=torch.int32).to(
device, non_blocking=True
)
output: torch.Tensor = torch.empty(B, out_length, device=device)
permute_pooled_embeddings_kernel[(B, P)](
torch.concat(values, dim=1),
output,
permute_tensor,
in_length,
out_length,
)
return output, splits


@triton.jit
def permute_pooled_embeddings_kernel(
values,
outputs,
permutes,
in_length,
out_length,
):
batch_id = tl.program_id(0)
pid = tl.program_id(1)
in_offset = tl.load(permutes + 4 * pid)
out_offset = tl.load(permutes + 4 * pid + 1)
length = tl.load(permutes + 4 * pid + 2)
BLOCK_SIZE: tl.constexpr = 32

idx = tl.arange(0, BLOCK_SIZE)
in_ptr = values + batch_id * in_length + in_offset + idx
out_ptr = outputs + batch_id * out_length + out_offset + idx

for k in range(0, length, BLOCK_SIZE):
inputs = tl.load(in_ptr + k, mask=idx < length - k)
tl.store(out_ptr + k, inputs, mask=idx < length - k)


def triton_permute_multi_embs(
values: List[torch.Tensor],
keys: List[List[str]],
lengths: List[List[int]],
groups: List[List[str]],
) -> List[torch.Tensor]:
"""
Permute the values of a KeyedTensor based on the groups.
"""
assert len(values) == len(keys)
assert len(values) == len(lengths)
P = sum(len(g) for g in groups)
B = values[0].shape[0]
device = values[0].device
in_lengths: List[int] = [0] * len(values)
out_lengths: List[int] = [0] * len(groups)

inputs: torch.Tensor = torch.tensor(
[v.data_ptr() for v in values], dtype=torch.int64
).to(device, non_blocking=True)

# permute: [in_tensor, out_tensor, in_offset, out_offset, length, next]
permutes: List[List[int]] = [[0] * 6 for _ in range(P)]
# key -> (in_tensor, in_offset, length)
lookup: Dict[str, Tuple[int, int, int]] = {}
for i, (key, length) in enumerate(zip(keys, lengths)):
for k, l in zip(key, length):
lookup[k] = (i, in_lengths[i], l)
in_lengths[i] += l

curr = 0
for out_tensor, group in enumerate(groups):
for k in group:
in_tensor, in_offset, length = lookup[k]
permutes[curr][:] = [
in_tensor,
out_tensor,
in_offset,
out_lengths[out_tensor],
length,
0,
]
out_lengths[out_tensor] += length
curr += 1

permute_tensor = torch.tensor(permutes, dtype=torch.int64).to(
device, non_blocking=True
)
outputs: List[torch.Tensor] = [
torch.empty(B, L, device=device) for L in out_lengths
]
output: torch.Tensor = torch.tensor(
[o.data_ptr() for o in outputs], dtype=torch.int64
).to(device, non_blocking=True)
in_lengths_ptr: torch.Tensor = torch.tensor(in_lengths, dtype=torch.int64).to(
device, non_blocking=True
)
out_lengths_ptr: torch.Tensor = torch.tensor(out_lengths, dtype=torch.int64).to(
device, non_blocking=True
)
permute_multi_embeddings_kernel[(B, P)](
values[0],
inputs,
output,
permute_tensor,
in_lengths_ptr,
out_lengths_ptr,
)
return outputs


@triton.jit
def permute_multi_embeddings_kernel(
example,
inputs,
output,
permutes,
in_lengths,
out_lengths,
):
batch_id = tl.program_id(0)
pid = tl.program_id(1)
in_tensor = tl.load(permutes + 6 * pid)
out_tensor = tl.load(permutes + 6 * pid + 1)
in_offset = tl.load(permutes + 6 * pid + 2)
out_offset = tl.load(permutes + 6 * pid + 3)
length = tl.load(permutes + 6 * pid + 4)

in_length = tl.load(in_lengths + in_tensor)
out_length = tl.load(out_lengths + out_tensor)

BLOCK_SIZE: tl.constexpr = 32
idx = tl.arange(0, BLOCK_SIZE)

in_ptr = (
tl.load(inputs + in_tensor).to(example.dtype, bitcast=True)
+ batch_id * in_length
+ in_offset
+ idx
)
out_ptr = (
tl.load(output + out_tensor).to(example.dtype, bitcast=True)
+ batch_id * out_length
+ out_offset
+ idx
)

for k in range(0, length, BLOCK_SIZE):
in_data = tl.load(in_ptr + k, mask=idx < length - k)
tl.store(out_ptr + k, in_data, mask=idx < length - k)


# @custom_impl("torchrec::permute_multi_embeddings", "CUDA")
# @custom_impl("torchrec::permute_multi_embeddings", "AutogradCUDA")
# def permute_multi_embeddings(
# values: List[torch.Tensor],
# keys: List[List[str]],
# lengths: List[List[int]],
# groups: List[List[str]],
# ) -> List[torch.Tensor]:
# """
# Permute the values of a KeyedTensor based on the groups.
# """
# assert len(values) == len(keys)
# assert len(values) == len(lengths)
