diff --git a/torchrec/sparse/tests/jagged_tensor_benchmark.py b/torchrec/sparse/tests/jagged_tensor_benchmark.py index b9dd12d3b..473aa1192 100644 --- a/torchrec/sparse/tests/jagged_tensor_benchmark.py +++ b/torchrec/sparse/tests/jagged_tensor_benchmark.py @@ -18,6 +18,7 @@ from torchrec.distributed.benchmark.benchmark_utils import benchmark, BenchmarkResult from torchrec.modules.regroup import KTRegroupAsDict from torchrec.sparse.jagged_tensor import ( + _desugar_keyed_tensors, _fbgemm_permute_pooled_embs, _regroup_keyed_tensors, KeyedJaggedTensor, @@ -26,6 +27,37 @@ regroup_kts, ) from torchrec.sparse.tests.utils import build_groups, build_kts +from torchrec.sparse.triton_ops import ( + triton_permute_multi_embs, + triton_permute_pooled_embs, +) + + +@torch.fx.wrap +def _triton_permute_pooled_embs( + keyed_tensors: List["KeyedTensor"], groups: List[List["str"]] +) -> List[torch.Tensor]: + keys, lengths, values = _desugar_keyed_tensors(keyed_tensors) + permuted_values, splits = triton_permute_pooled_embs( + values, + keys, + lengths, + groups, + ) + return list(torch.split(permuted_values, splits, dim=1)) + + +@torch.fx.wrap +def _triton_permute_multi_embs( + keyed_tensors: List["KeyedTensor"], groups: List[List["str"]] +) -> List[torch.Tensor]: + keys, lengths, values = _desugar_keyed_tensors(keyed_tensors) + return triton_permute_multi_embs( + values, + keys, + lengths, + groups, + ) class DummyModel(torch.nn.Module): @@ -283,6 +315,28 @@ def main( {"keyed_tensors": kts, "groups": groups}, profile, ) + bench( + "[Triton] permute_pooled_embs", + labels, + batch_size, + n_dense + n_sparse, + device_type, + run_backward, + _triton_permute_pooled_embs, + {"keyed_tensors": kts, "groups": groups}, + profile, + ) + bench( + "[Triton] permute_multi_embs", + labels, + batch_size, + n_dense + n_sparse, + device_type, + run_backward, + _triton_permute_multi_embs, + {"keyed_tensors": kts, "groups": groups}, + profile, + ) if __name__ == "__main__": diff --git a/torchrec/sparse/tests/test_triton_ops.py b/torchrec/sparse/tests/test_triton_ops.py new file mode 100644 index 000000000..90f53edae --- /dev/null +++ b/torchrec/sparse/tests/test_triton_ops.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + + +import unittest + +import torch + +from torchrec.sparse.jagged_tensor import _desugar_keyed_tensors, _regroup_keyed_tensors +from torchrec.sparse.tests.utils import build_groups, build_kts +from torchrec.sparse.triton_ops import ( + triton_permute_multi_embs, + triton_permute_pooled_embs, +) + + +class TestPermutePooledEmbs(unittest.TestCase): + # pyre-ignore[56] + @unittest.skipIf(not torch.cuda.is_available(), "Skip when CUDA is not available") + def test_triton_permute_pooled_embs_forward(self) -> None: + kts = build_kts( + dense_features=2, + sparse_features=2, + dim_dense=16, + dim_sparse=16, + batch_size=8, + device=torch.device("cuda"), + run_backward=False, + ) + groups = build_groups( + kts, + 4, + ) + keys, lengths, values = _desugar_keyed_tensors(kts) + output, splits = triton_permute_pooled_embs(values, keys, lengths, groups) + refs = _regroup_keyed_tensors(kts, groups) + outputs = torch.split(output, splits, dim=1) + for ref, output in zip(refs, outputs): + torch.testing.assert_close(ref, output) + + +class TestPermuteMultiEmbs(unittest.TestCase): + # pyre-ignore[56] + @unittest.skipIf(not torch.cuda.is_available(), "Skip when CUDA is not available") + def test_triton_permute_multi_embs_forward(self) -> None: + kts = build_kts( + dense_features=2, + sparse_features=2, + dim_dense=16, + dim_sparse=16, + batch_size=8, + device=torch.device("cuda"), + run_backward=False, + ) + groups = build_groups( + kts, + 4, + ) + keys, lengths, values = _desugar_keyed_tensors(kts) + outputs = triton_permute_multi_embs(values, keys, lengths, groups) + refs = _regroup_keyed_tensors(kts, groups) + for ref, output in zip(refs, outputs): + torch.testing.assert_close(ref, output) diff --git a/torchrec/sparse/triton_ops.py b/torchrec/sparse/triton_ops.py new file mode 100644 index 000000000..355de5206 --- /dev/null +++ b/torchrec/sparse/triton_ops.py @@ -0,0 +1,219 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +from typing import Dict, List, Tuple + +import torch + +# @manual=//triton:triton +import triton + +# @manual=//triton:triton +import triton.language as tl + + +def triton_permute_pooled_embs( + values: List[torch.Tensor], + keys: List[List[str]], + lengths: List[List[int]], + groups: List[List[str]], +) -> Tuple[torch.Tensor, List[int]]: + """ + Permute the values of a KeyedTensor based on the groups. + """ + assert len(values) == len(keys) + assert len(values) == len(lengths) + P = sum(len(g) for g in groups) + B = values[0].shape[0] + device = values[0].device + in_length: int = 0 + out_length: int = 0 + splits: List[int] = [0] * len(groups) + + # permute: [in_offset, out_offset, length, next] + permutes: List[List[int]] = [[0] * 4 for _ in range(P)] + # key -> (in_tensor, in_offset, length) + lookup: Dict[str, Tuple[int, int, int]] = {} + for i, (key, length) in enumerate(zip(keys, lengths)): + for k, l in zip(key, length): + lookup[k] = (i, in_length, l) + in_length += l + + curr = 0 + for j, group in enumerate(groups): + for k in group: + in_tensor, in_offset, length = lookup[k] + permutes[curr][:] = [in_offset, out_length, length, 0] + out_length += length + splits[j] += length + curr += 1 + + permute_tensor = torch.tensor(permutes, dtype=torch.int32).to( + device, non_blocking=True + ) + output: torch.Tensor = torch.empty(B, out_length, device=device) + permute_pooled_embeddings_kernel[(B, P)]( + torch.concat(values, dim=1), + output, + permute_tensor, + in_length, + out_length, + ) + return output, splits + + +@triton.jit +def permute_pooled_embeddings_kernel( + values, + outputs, + permutes, + in_length, + out_length, +): + batch_id = tl.program_id(0) + pid = tl.program_id(1) + in_offset = tl.load(permutes + 4 * pid) + out_offset = tl.load(permutes + 4 * pid + 1) + length = tl.load(permutes + 4 * pid + 2) + BLOCK_SIZE: tl.constexpr = 32 + + idx = tl.arange(0, BLOCK_SIZE) + in_ptr = values + batch_id * in_length + in_offset + idx + out_ptr = outputs + batch_id * out_length + out_offset + idx + + for k in range(0, length, BLOCK_SIZE): + inputs = tl.load(in_ptr + k, mask=idx < length - k) + tl.store(out_ptr + k, inputs, mask=idx < length - k) + + +def triton_permute_multi_embs( + values: List[torch.Tensor], + keys: List[List[str]], + lengths: List[List[int]], + groups: List[List[str]], +) -> List[torch.Tensor]: + """ + Permute the values of a KeyedTensor based on the groups. + """ + assert len(values) == len(keys) + assert len(values) == len(lengths) + P = sum(len(g) for g in groups) + B = values[0].shape[0] + device = values[0].device + in_lengths: List[int] = [0] * len(values) + out_lengths: List[int] = [0] * len(groups) + + inputs: torch.Tensor = torch.tensor( + [v.data_ptr() for v in values], dtype=torch.int64 + ).to(device, non_blocking=True) + + # permute: [in_tensor, out_tensor, in_offset, out_offset, length, next] + permutes: List[List[int]] = [[0] * 6 for _ in range(P)] + # key -> (in_tensor, in_offset, length) + lookup: Dict[str, Tuple[int, int, int]] = {} + for i, (key, length) in enumerate(zip(keys, lengths)): + for k, l in zip(key, length): + lookup[k] = (i, in_lengths[i], l) + in_lengths[i] += l + + curr = 0 + for out_tensor, group in enumerate(groups): + for k in group: + in_tensor, in_offset, length = lookup[k] + permutes[curr][:] = [ + in_tensor, + out_tensor, + in_offset, + out_lengths[out_tensor], + length, + 0, + ] + out_lengths[out_tensor] += length + curr += 1 + + permute_tensor = torch.tensor(permutes, dtype=torch.int64).to( + device, non_blocking=True + ) + outputs: List[torch.Tensor] = [ + torch.empty(B, L, device=device) for L in out_lengths + ] + output: torch.Tensor = torch.tensor( + [o.data_ptr() for o in outputs], dtype=torch.int64 + ).to(device, non_blocking=True) + in_lengths_ptr: torch.Tensor = torch.tensor(in_lengths, dtype=torch.int64).to( + device, non_blocking=True + ) + out_lengths_ptr: torch.Tensor = torch.tensor(out_lengths, dtype=torch.int64).to( + device, non_blocking=True + ) + permute_multi_embeddings_kernel[(B, P)]( + values[0], + inputs, + output, + permute_tensor, + in_lengths_ptr, + out_lengths_ptr, + ) + return outputs + + +@triton.jit +def permute_multi_embeddings_kernel( + example, + inputs, + output, + permutes, + in_lengths, + out_lengths, +): + batch_id = tl.program_id(0) + pid = tl.program_id(1) + in_tensor = tl.load(permutes + 6 * pid) + out_tensor = tl.load(permutes + 6 * pid + 1) + in_offset = tl.load(permutes + 6 * pid + 2) + out_offset = tl.load(permutes + 6 * pid + 3) + length = tl.load(permutes + 6 * pid + 4) + + in_length = tl.load(in_lengths + in_tensor) + out_length = tl.load(out_lengths + out_tensor) + + BLOCK_SIZE: tl.constexpr = 32 + idx = tl.arange(0, BLOCK_SIZE) + + in_ptr = ( + tl.load(inputs + in_tensor).to(example.dtype, bitcast=True) + + batch_id * in_length + + in_offset + + idx + ) + out_ptr = ( + tl.load(output + out_tensor).to(example.dtype, bitcast=True) + + batch_id * out_length + + out_offset + + idx + ) + + for k in range(0, length, BLOCK_SIZE): + in_data = tl.load(in_ptr + k, mask=idx < length - k) + tl.store(out_ptr + k, in_data, mask=idx < length - k) + + +# @custom_impl("torchrec::permute_multi_embeddings", "CUDA") +# @custom_impl("torchrec::permute_multi_embeddings", "AutogradCUDA") +# def permute_multi_embeddings( +# values: List[torch.Tensor], +# keys: List[List[str]], +# lengths: List[List[int]], +# groups: List[List[str]], +# ) -> List[torch.Tensor]: +# """ +# Permute the values of a KeyedTensor based on the groups. +# """ +# assert len(values) == len(keys) +# assert len(values) == len(lengths)