diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py
index b5fe73002..7382841f8 100644
--- a/torchrec/sparse/jagged_tensor.py
+++ b/torchrec/sparse/jagged_tensor.py
@@ -1975,8 +1975,21 @@ def permute(
                 indices_tensor,
                 self.weights_or_none(),
             )
+        elif is_torchdynamo_compiling() and not torch.jit.is_scripting():
+            (
+                permuted_lengths,
+                permuted_values,
+                permuted_weights,
+            ) = torch.ops.fbgemm.permute_2D_sparse_data_input1D(
+                indices_tensor,
+                self.lengths(),
+                self.values(),
+                self.stride(),
+                self.weights_or_none(),
+                permuted_length_per_key_sum,
+            )
         else:
             (
                 permuted_lengths,
                 permuted_values,
                 permuted_weights,
@@ -2357,8 +2370,24 @@ def dist_init(
             single_batch_per_rank = all(
                 s == stride for s in stride_per_rank
             )
-
-            if single_batch_per_rank:
+            if (
+                single_batch_per_rank
+                and is_torchdynamo_compiling()
+                and not torch.jit.is_scripting()
+            ):
+                (
+                    lengths,
+                    values,
+                    weights,
+                ) = torch.ops.fbgemm.permute_2D_sparse_data_input1D(
+                    torch.jit._unwrap_optional(recat),
+                    lengths,
+                    values,
+                    stride,
+                    weights,
+                    values.numel(),
+                )
+            elif single_batch_per_rank:
                 (
                     lengths,
                     values,
diff --git a/torchrec/sparse/tests/test_jagged_tensor_gpu.py b/torchrec/sparse/tests/test_jagged_tensor_gpu.py
index b32075f8d..0f89c697e 100644
--- a/torchrec/sparse/tests/test_jagged_tensor_gpu.py
+++ b/torchrec/sparse/tests/test_jagged_tensor_gpu.py
@@ -11,7 +11,11 @@
 import unittest
 
 import torch
-from torchrec.sparse.jagged_tensor import _regroup_keyed_tensors, KeyedTensor
+from torchrec.sparse.jagged_tensor import (
+    _regroup_keyed_tensors,
+    KeyedJaggedTensor,
+    KeyedTensor,
+)
 from torchrec.sparse.tests.utils import build_groups, build_kts
 from torchrec.test_utils import skip_if_asan_class
 
@@ -111,3 +115,128 @@ def test_regroup_backward(self) -> None:
 
         torch.allclose(actual_kt_0_grad, expected_kt_0_grad)
         torch.allclose(actual_kt_1_grad, expected_kt_1_grad)
+
+    # pyre-ignore
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 0,
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    def test_permute(self) -> None:
+        values = torch.tensor(
+            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device
+        )
+        lengths = torch.tensor([0, 2, 0, 1, 1, 1, 0, 3, 0], device=self.device)
+        keys = ["index_0", "index_1", "index_2"]
+
+        jag_tensor = KeyedJaggedTensor.from_lengths_sync(
+            values=values,
+            keys=keys,
+            lengths=lengths,
+        )
+        indices = [1, 0, 2]
+        permuted_jag_tensor = jag_tensor.permute(indices)
+
+        self.assertEqual(permuted_jag_tensor.keys(), ["index_1", "index_0", "index_2"])
+        self.assertEqual(
+            permuted_jag_tensor.offset_per_key(),
+            [0, 3, 5, 8],
+        )
+        self.assertEqual(
+            permuted_jag_tensor.values().tolist(),
+            [3.0, 4.0, 5.0, 1.0, 2.0, 6.0, 7.0, 8.0],
+        )
+        self.assertEqual(
+            permuted_jag_tensor.lengths().tolist(), [1, 1, 1, 0, 2, 0, 0, 3, 0]
+        )
+        self.assertEqual(permuted_jag_tensor.weights_or_none(), None)
+
+    # pyre-ignore
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 0,
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    def test_permute_vb(self) -> None:
+        values = torch.tensor(
+            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device
+        )
+        lengths = torch.tensor([1, 0, 1, 3, 0, 1, 0, 2, 0], device=self.device)
+        keys = ["index_0", "index_1", "index_2"]
+        stride_per_key_per_rank = [[2], [4], [3]]
+
+        jag_tensor = KeyedJaggedTensor.from_lengths_sync(
+            values=values,
+            keys=keys,
+            lengths=lengths,
+            stride_per_key_per_rank=stride_per_key_per_rank,
+        )
+
+        indices = [1, 0, 2]
+        permuted_jag_tensor = jag_tensor.permute(indices)
+
+        self.assertEqual(permuted_jag_tensor.keys(), ["index_1", "index_0", "index_2"])
+        self.assertEqual(
+            permuted_jag_tensor.offset_per_key(),
+            [0, 5, 6, 8],
+        )
+        self.assertEqual(
+            permuted_jag_tensor.values().tolist(),
+            [2.0, 3.0, 4.0, 5.0, 6.0, 1.0, 7.0, 8.0],
+        )
+        self.assertEqual(
+            permuted_jag_tensor.lengths().tolist(), [1, 3, 0, 1, 1, 0, 0, 2, 0]
+        )
+        self.assertEqual(permuted_jag_tensor.weights_or_none(), None)
+
+    # pyre-ignore
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 0,
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    def test_permute_duplicates(self) -> None:
+        values = torch.tensor(
+            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device
+        )
+        lengths = torch.tensor([0, 2, 0, 1, 1, 1, 0, 3, 0], device=self.device)
+        keys = ["index_0", "index_1", "index_2"]
+
+        jag_tensor = KeyedJaggedTensor.from_lengths_sync(
+            values=values,
+            keys=keys,
+            lengths=lengths,
+        )
+
+        indices = [1, 0, 2, 1, 1]
+        permuted_jag_tensor = jag_tensor.permute(indices)
+
+        self.assertEqual(
+            permuted_jag_tensor.keys(),
+            ["index_1", "index_0", "index_2", "index_1", "index_1"],
+        )
+        self.assertEqual(
+            permuted_jag_tensor.offset_per_key(),
+            [0, 3, 5, 8, 11, 14],
+        )
+        self.assertEqual(
+            permuted_jag_tensor.values().tolist(),
+            [
+                3.0,
+                4.0,
+                5.0,
+                1.0,
+                2.0,
+                6.0,
+                7.0,
+                8.0,
+                3.0,
+                4.0,
+                5.0,
+                3.0,
+                4.0,
+                5.0,
+            ],
+        )
+        self.assertEqual(
+            permuted_jag_tensor.lengths().tolist(),
+            [1, 1, 1, 0, 2, 0, 0, 3, 0, 1, 1, 1, 1, 1, 1],
+        )
+        self.assertEqual(permuted_jag_tensor.weights_or_none(), None)