Skip to content

Commit

Permalink
Reland [TorchRec][PT2] KJT custom op for 1d lengths input (pytorch#2183)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#2183

# context
* the previously landed D59031938 was reverted because the TorchScript push schedule is behind
* this reland adds a `not torch.jit.is_scripting()` check so that the new op is not exposed to TorchScript.

Reviewed By: IvanKobzarev

Differential Revision: D59081243
  • Loading branch information
TroyGarden authored and facebook-github-bot committed Jul 9, 2024
1 parent 9c74d8a commit 6ee9a94
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 4 deletions.
34 changes: 31 additions & 3 deletions torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1975,8 +1975,20 @@ def permute(
indices_tensor,
self.weights_or_none(),
)
elif is_torchdynamo_compiling() and not torch.jit.is_scripting():
(
permuted_lengths,
permuted_values,
permuted_weights,
) = torch.ops.fbgemm.permute_2D_sparse_data_input1D(
indices_tensor,
self.lengths(),
self.values(),
self.stride(),
self.weights_or_none(),
permuted_length_per_key_sum,
)
else:

(
permuted_lengths,
permuted_values,
Expand Down Expand Up @@ -2357,8 +2369,24 @@ def dist_init(
single_batch_per_rank = all(
s == stride for s in stride_per_rank
)

if single_batch_per_rank:
if (
single_batch_per_rank
and is_torchdynamo_compiling()
and not torch.jit.is_scripting()
):
(
lengths,
values,
weights,
) = torch.ops.fbgemm.permute_2D_sparse_data_input1D(
torch.jit._unwrap_optional(recat),
lengths,
values,
stride,
weights,
values.numel(),
)
elif single_batch_per_rank:
(
lengths,
values,
Expand Down
131 changes: 130 additions & 1 deletion torchrec/sparse/tests/test_jagged_tensor_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
import unittest

import torch
from torchrec.sparse.jagged_tensor import _regroup_keyed_tensors, KeyedTensor
from torchrec.sparse.jagged_tensor import (
_regroup_keyed_tensors,
KeyedJaggedTensor,
KeyedTensor,
)
from torchrec.sparse.tests.utils import build_groups, build_kts
from torchrec.test_utils import skip_if_asan_class

Expand Down Expand Up @@ -111,3 +115,128 @@ def test_regroup_backward(self) -> None:

torch.allclose(actual_kt_0_grad, expected_kt_0_grad)
torch.allclose(actual_kt_1_grad, expected_kt_1_grad)

# pyre-ignore
@unittest.skipIf(
    torch.cuda.device_count() <= 0,
    "Not enough GPUs, this test requires at least one GPUs",
)
def test_permute(self) -> None:
    # Build a three-key KJT on the test device, reorder its keys with the
    # index list [1, 0, 2], and verify keys, per-key offsets, values,
    # lengths and (absent) weights of the result.
    kjt = KeyedJaggedTensor.from_lengths_sync(
        values=torch.tensor(
            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device
        ),
        keys=["index_0", "index_1", "index_2"],
        lengths=torch.tensor([0, 2, 0, 1, 1, 1, 0, 3, 0], device=self.device),
    )

    permuted = kjt.permute([1, 0, 2])

    # Expected layout after moving "index_1" in front of "index_0".
    expected_keys = ["index_1", "index_0", "index_2"]
    expected_offset_per_key = [0, 3, 5, 8]
    expected_values = [3.0, 4.0, 5.0, 1.0, 2.0, 6.0, 7.0, 8.0]
    expected_lengths = [1, 1, 1, 0, 2, 0, 0, 3, 0]

    self.assertEqual(permuted.keys(), expected_keys)
    self.assertEqual(permuted.offset_per_key(), expected_offset_per_key)
    self.assertEqual(permuted.values().tolist(), expected_values)
    self.assertEqual(permuted.lengths().tolist(), expected_lengths)
    # The input carried no weights, so none are expected on the output.
    self.assertEqual(permuted.weights_or_none(), None)

# pyre-ignore
@unittest.skipIf(
    torch.cuda.device_count() <= 0,
    "Not enough GPUs, this test requires at least one GPUs",
)
def test_permute_vb(self) -> None:
    # Same permutation check as test_permute, but the KJT is constructed
    # with an explicit stride_per_key_per_rank of [[2], [4], [3]], i.e. a
    # different stride for each key.
    kjt = KeyedJaggedTensor.from_lengths_sync(
        values=torch.tensor(
            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device
        ),
        keys=["index_0", "index_1", "index_2"],
        lengths=torch.tensor([1, 0, 1, 3, 0, 1, 0, 2, 0], device=self.device),
        stride_per_key_per_rank=[[2], [4], [3]],
    )

    permuted = kjt.permute([1, 0, 2])

    # Expected layout after moving "index_1" in front of "index_0".
    expected_keys = ["index_1", "index_0", "index_2"]
    expected_offset_per_key = [0, 5, 6, 8]
    expected_values = [2.0, 3.0, 4.0, 5.0, 6.0, 1.0, 7.0, 8.0]
    expected_lengths = [1, 3, 0, 1, 1, 0, 0, 2, 0]

    self.assertEqual(permuted.keys(), expected_keys)
    self.assertEqual(permuted.offset_per_key(), expected_offset_per_key)
    self.assertEqual(permuted.values().tolist(), expected_values)
    self.assertEqual(permuted.lengths().tolist(), expected_lengths)
    # The input carried no weights, so none are expected on the output.
    self.assertEqual(permuted.weights_or_none(), None)

# pyre-ignore
@unittest.skipIf(
    torch.cuda.device_count() <= 0,
    "Not enough GPUs, this test requires at least one GPUs",
)
def test_permute_duplicates(self) -> None:
    # Permute with an index list that repeats key 1 ([1, 0, 2, 1, 1]) and
    # verify that the repeated key's data shows up once per occurrence.
    kjt = KeyedJaggedTensor.from_lengths_sync(
        values=torch.tensor(
            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device
        ),
        keys=["index_0", "index_1", "index_2"],
        lengths=torch.tensor([0, 2, 0, 1, 1, 1, 0, 3, 0], device=self.device),
    )

    permuted = kjt.permute([1, 0, 2, 1, 1])

    # The first three output keys match the plain [1, 0, 2] permutation;
    # the two trailing "index_1" entries append that key's values and
    # lengths two more times.
    expected_keys = ["index_1", "index_0", "index_2"] + ["index_1"] * 2
    expected_offset_per_key = [0, 3, 5, 8, 11, 14]
    expected_values = [3.0, 4.0, 5.0, 1.0, 2.0, 6.0, 7.0, 8.0] + [3.0, 4.0, 5.0] * 2
    expected_lengths = [1, 1, 1, 0, 2, 0, 0, 3, 0] + [1, 1, 1] * 2

    self.assertEqual(permuted.keys(), expected_keys)
    self.assertEqual(permuted.offset_per_key(), expected_offset_per_key)
    self.assertEqual(permuted.values().tolist(), expected_values)
    self.assertEqual(permuted.lengths().tolist(), expected_lengths)
    # The input carried no weights, so none are expected on the output.
    self.assertEqual(permuted.weights_or_none(), None)

0 comments on commit 6ee9a94

Please sign in to comment.