From 3db28b3d84a6cf89f0414d37602b28365674a961 Mon Sep 17 00:00:00 2001
From: Huanyu He
Date: Tue, 9 Jul 2024 11:06:37 -0700
Subject: [PATCH] implementation of fbgemm op - permute_multi_embedding (#2120)

Summary:
X-link: https://github.com/pytorch/FBGEMM/pull/2738

Pull Request resolved: https://github.com/pytorch/torchrec/pull/2120

# context
* currently we have a working function `permute_pooled_embs_auto_grad` that performs a full permute of KTs, including forward and backward
* it has several limitations: a) it has to be a full permute; duplicates are not supported; b) in the main [use case](https://fburl.com/code/89od0rqm) there has to be a torch.concat on the input KTs, which is not very efficient; c) the function outputs a single KT, which then requires a split operation
* there was an attempt to support duplicated outputs, but the backward pass doesn't work
* this diff creates a new kernel (named `permute_multi_embedding`) that supports a multiple-KT to multiple-KT mapping operation with backward support

# notes
* this diff focuses on the implementation and tests of the operator
* performance analysis and benchmarks are in the next diff

# operator example usage
* used in python
```
# test inputs: 3 KTs with batch_size=2048
batch_size = 2048
keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
lengths = [[96, 256], [512, 128, 768], [1024]]
values = [
    torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True)
    for lens in lengths
]

# target outputs: 4 KTs with re-arranged keys (features), duplicates are allowed
groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]

# auxiliary arguments to the op/kernel
permutes, in_lengths, out_lengths = _multi_remap_to_groups(
    keys, lengths, groups
)

# call the op
outputs = torch.ops.fbgemm.permute_multi_embedding_internal_testing(
    values, permutes, in_lengths, out_lengths
)
```
* permutes
```
# each row represents a key (feature) permute move, which consists of the following parameters:
# [input_tensor_idx, output_tensor_idx, input_start_idx, output_start_idx, key_length, magic_jump]
permutes = tensor(
    [
        [0, 0, 0, 0, 3, 4],  # f1
        [1, 0, 0, 3, 5, 0],  # f3
        [0, 1, 3, 0, 4, 0],  # f2
        [1, 2, 5, 0, 6, 0],  # f4
        [0, 2, 0, 6, 3, -6],  # f1
        [2, 2, 0, 9, 8, 0],  # f6
        [0, 3, 0, 0, 3, -8],  # f1
        [1, 3, 11, 3, 7, 0],  # f5
    ]
)
```
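* reference in plain PyTorch: each `permutes` row can be replayed eagerly by slicing the source input tensor and concatenating the slices per output tensor; the unit tests in this diff build their reference outputs the same way. The helper below is only an illustrative sketch of that reference construction (the helper name and signature are made up for this note), not the fused kernel:
```
import torch
from typing import List


def reference_permute(
    values: List[torch.Tensor], permutes: torch.Tensor, num_outputs: int
) -> List[torch.Tensor]:
    # illustration only: replay each permute move by taking the [batch, length]
    # slice of the source input tensor described by the row
    refs: List[List[torch.Tensor]] = [[] for _ in range(num_outputs)]
    for in_idx, out_idx, in_start, _out_start, length, _jump in permutes.tolist():
        refs[out_idx].append(values[in_idx][:, in_start : in_start + length])
    # rows are emitted in output order, so one concat per output KT matches the
    # [output_start, output_start + length) placement done by the kernel
    return [torch.cat(slices, dim=1) for slices in refs]
```
* with the example above, `reference_permute(values, permutes, len(groups))[i]` should match `outputs[i]`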
# details
1. from the example usage above, we can see that the operator takes the following inputs:
   a) values: List[torch.Tensor], which represents the input KTs
   b) permutes: torch.Tensor, which contains the permute information, explained in items 3 and 4 below
   c) output_lengths_list: List[int], the lengths of the output tensors (KTs), which is needed to allocate memory on the device ahead of time
   d) in_lengths: torch.Tensor, lengths of the input tensors, resident on the device
   e) out_lengths: torch.Tensor, lengths of the output tensors, resident on the device
2. the operator returns a list of tensors, which represent the permuted KTs
3. `permutes` is the most critical argument in this operator:
   a) it is a 2-D tensor
   b) each row represents a key (feature) permute move
   c) a permute move = [input_tensor_idx, output_tensor_idx, input_start_idx, output_start_idx, feature_length, jump]
   d) jump is used in the backward pass when a key (feature) from the input tensor is mapped to multiple places in the output tensors
4. the `magic_jump` (a sketch of the backward traversal follows in the next item):
   a) it is only used in the backward computation
   b) it is usually 0, meaning no jump
   c) it is non-zero when there is a duplicate in the permute, e.g., the same feature appears more than once in the output
   d) the `magic_jump` points to the index of the next occurrence of the very same feature in the permute sequence, with some modifications
   e) modification-1: `magic_jump` is positive when it's the first of its kind [Start]
   f) modification-2: `magic_jump` is negative when it's not the first of its kind [Continue]
   g) modification-3: `magic_jump` is the negative value of the length of the permute sequence when it's the last of its kind [Stop]
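5. for illustration, the gradient accumulation that `magic_jump` encodes can be sketched in eager PyTorch as below; the helper name, its signature, and the plain-Python chain walk are assumptions made for this note only, not the actual FBGEMM kernel implementation:
```
import torch
from typing import List


def reference_backward(
    output_grads: List[torch.Tensor],  # d(loss)/d(output KT), one per output tensor
    permutes: torch.Tensor,  # the [n, 6] permute rows shown above
    in_lengths: List[int],  # total feature length of each input tensor
) -> List[torch.Tensor]:
    # illustration only: the real op performs this accumulation inside the kernel
    rows = permutes.tolist()
    total = len(rows)  # magic_jump == -total marks the end of a duplicate chain
    batch = output_grads[0].shape[0]
    in_grads = [output_grads[0].new_zeros(batch, n) for n in in_lengths]
    for in_idx, out_idx, in_start, out_start, length, jump in rows:
        if jump < 0:
            # [Continue]/[Stop] marker: this copy is accumulated by the chain walk
            # that starts at the first occurrence of the same feature, skip it here
            continue
        grad = output_grads[out_idx][:, out_start : out_start + length]
        nxt = jump  # positive jump: index of the next copy of this feature (0 = none)
        while nxt > 0:
            _, n_out, _, n_out_start, n_len, n_jump = rows[nxt]
            grad = grad + output_grads[n_out][:, n_out_start : n_out_start + n_len]
            # the stored marker is either -(next row index) [Continue] or -total [Stop]
            nxt = -n_jump if -n_jump != total else 0
        in_grads[in_idx][:, in_start : in_start + length] += grad
    return in_grads
```
   the chain walk means only the first occurrence of a duplicated feature writes the summed gradient back to its input slice; later occurrences are skipped because their rows carry the negative [Continue]/[Stop] markers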
Reviewed By: sryap

Differential Revision: D57055616

fbshipit-source-id: 16673d3a2eafab93b08d4ff3c43d54366966064a
---
 torchrec/sparse/jagged_tensor.py            | 108 +++++++++++
 torchrec/sparse/tests/test_jagged_tensor.py | 189 +++++++++++++++++++-
 2 files changed, 295 insertions(+), 2 deletions(-)

diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py
index 476592397..14bc48123 100644
--- a/torchrec/sparse/jagged_tensor.py
+++ b/torchrec/sparse/jagged_tensor.py
@@ -36,6 +36,12 @@
     torch.ops.load_library(
         "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_cpu"
     )
+    torch.ops.load_library(
+        "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_cpu"
+    )
+    torch.ops.load_library(
+        "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_gpu"
+    )
 except OSError:
     pass
 
@@ -164,6 +170,24 @@ def _all_keys_used_once(
     return len(key_set) == len(group_set) == len(flat_keys) == len(flat_groups)
 
 
+@torch.fx.wrap
+def permute_multi_embedding(
+    keyed_tensors: List["KeyedTensor"], groups: List[List["str"]]
+) -> List[torch.Tensor]:
+    keys, lengths, values = _desugar_keyed_tensors(keyed_tensors)
+    permutes, in_shape, out_shape, out_lengths = _kt_regroup_permutes(
+        values[0], keys, lengths, groups
+    )
+    permuted_values = torch.ops.fbgemm.permute_multi_embedding(
+        values,
+        permutes,
+        in_shape,
+        out_shape,
+        out_lengths,
+    )
+    return permuted_values
+
+
 @torch.fx.wrap
 def _fbgemm_permute_pooled_embs(
     keyed_tensors: List["KeyedTensor"], groups: List[List["str"]]
@@ -240,6 +264,90 @@ def _remap_to_groups(
     return permute, inv_permute, offsets, inv_offsets, splits
 
 
+def _kt_regroup_permutes(
+    value: torch.Tensor,
+    keys: List[List[str]],
+    key_lengths: List[List[int]],
+    groups: List[List[str]],
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
+    """
+    returns: permutes, in_shapes, out_shapes, out_lengths
+    """
+    # key => (tensor_idx, key_index)
+    key_map: Dict[str, Tuple[int, int]] = {
+        key: (tensor_idx, key_idx)
+        for tensor_idx, tensor in enumerate(keys)
+        for key_idx, key in enumerate(tensor)
+    }
+
+    # [offsets per tensor]
+    in_offsets: List[List[int]] = [[] for _ in key_lengths]
+    for i, tensor in enumerate(key_lengths):
+        in_offsets[i] = _cumsum(tensor)
+    in_lengths: List[int] = [sum(lengths) for lengths in key_lengths]
+
+    # set total_permutes as the jump stop sign
+    total_permutes: int = sum(len(tensor) for tensor in groups)
+    out_lengths: List[int] = [0] * len(groups)
+
+    # [input_tensor_idx, output_tensor_idx, input_start, output_start, length, jump]
+    permute_param = 6
+    permutes: List[List[int]] = [[0] * permute_param for _ in range(total_permutes)]
+
+    # record the last seen index, so that we can make the jump from last_seen to current
+    last_seen: Dict[str, int] = {}
+    permute_idx = 0
+    for output_tensor_idx, output_tensor in enumerate(groups):
+        output_start = 0
+        for output_key in output_tensor:
+            input_tensor_idx, input_key_idx = key_map[output_key]
+            input_start = in_offsets[input_tensor_idx][input_key_idx]
+            length = key_lengths[input_tensor_idx][input_key_idx]
+
+            # add jump data
+            if output_key not in last_seen:
+                jump = 0  # don't need to jump yet
+                # positive as a potential jump start
+                last_seen[output_key] = permute_idx
+            else:
+                prev = last_seen[output_key]
+                if prev >= 0:  # positive ==> it's a jump start
+                    # jump to current idx, positive as the jump start
+                    permutes[prev][5] = permute_idx
+                else:  # it's already in a jump sequence, mark as negative
+                    permutes[-prev][5] = -permute_idx
+                # mark last_seen negative since it's already in jump
+                last_seen[output_key] = -permute_idx
+                # it's a potential jump stop
+                jump = -total_permutes
+
+            permutes[permute_idx][:] = [
+                input_tensor_idx,
+                output_tensor_idx,
+                input_start,
+                output_start,
+                length,
+                jump,
+            ]
+            permute_idx += 1
+            output_start += length
+        out_lengths[output_tensor_idx] = output_start
+
+    permute_tensor = torch.tensor(permutes, dtype=torch.int32)
+    in_shapes = torch.tensor(in_lengths, dtype=torch.int32)
+    out_shapes = torch.tensor(out_lengths, dtype=torch.int32)
+    device = value.device
+    permute_tensor = _pin_and_move(permute_tensor, device)
+    in_shapes = _pin_and_move(in_shapes, device)
+    out_shapes = _pin_and_move(out_shapes, device)
+    return (
+        permute_tensor,
+        in_shapes,
+        out_shapes,
+        out_lengths,
+    )
+
+
 def _values_string(values: torch.Tensor, start: int, end: int) -> str:
     size = values.size()
     if len(size) == 1:
diff --git a/torchrec/sparse/tests/test_jagged_tensor.py b/torchrec/sparse/tests/test_jagged_tensor.py
index b52b34a3c..e632eb83d 100644
--- a/torchrec/sparse/tests/test_jagged_tensor.py
+++ b/torchrec/sparse/tests/test_jagged_tensor.py
@@ -16,6 +16,7 @@ from torch.testing import FileCheck
 from torchrec.fx import symbolic_trace
 from torchrec.sparse.jagged_tensor import (
+    _kt_regroup_permutes,
     _regroup_keyed_tensors,
     ComputeJTDictToKJT,
     ComputeKJTToJTDict,
@@ -1397,6 +1398,192 @@ def test_permute_vb(self) -> None:
         )
         self.assertEqual(permuted_jag_tensor.weights_or_none(), None)
 
+    def test_kt_regroup_permutes(self) -> None:
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[3, 4], [5, 6, 7], [8]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        for device in ["cpu", "meta", "cuda"]:
+            if device == "cuda" and not torch.cuda.is_available():
+                continue
+            device = torch.device(device)
+            permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_permutes(
+                torch.empty(0, device=device), keys, lengths, groups
+            )
+            ref_permutes = [
+                [0, 0, 0, 0, 3, 4],  # f1, jump to 4, as a start
+                [1, 0, 0, 3, 5, 0],  # f3
+                [0, 1, 3, 0, 4, 0],  # f2
+                [1, 2, 5, 0, 6, 0],  # f4
+                [0, 2, 0, 6, 3, -6],  # f1 jump to 6, as in a jump sequence
+                [2, 2, 0, 9, 8, 0],  # f6
+                [0, 3, 0, 0, 3, -8],  # f1 jump stop, as out of boundary
+                [1, 3, 11, 3, 7, 0],  # f5
+            ]
+            if device.type == "meta":
+                self.assertEqual(
+                    permutes.shape, (len(ref_permutes), len(ref_permutes[0]))
+                )
+                self.assertEqual(in_shapes.shape, (3,))
+                self.assertEqual(out_shapes.shape, (4,))
+            else:
+                self.assertTrue(
+                    torch.equal(
+                        permutes,
+                        torch.tensor(ref_permutes, dtype=torch.int32, device=device),
+                    )
+                )
+                self.assertEqual(in_shapes.tolist(), [7, 18, 8])
+                self.assertEqual(out_shapes.tolist(), [8, 4, 17, 10])
+            self.assertEqual(out_lengths, [8, 4, 17, 10])
+
+    def test_multi_permute_forward_cpu(self) -> None:
+        batch_size = 32
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[3, 4], [5, 6, 7], [8]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
"f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] + values = [ + torch.randn(batch_size, sum(lens), device="cpu", requires_grad=True) + for lens in lengths + ] + permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_permutes( + values[0], keys, lengths, groups + ) + refs = [[] for _ in groups] + for i in range(permutes.size(0)): + in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist() + refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)]) + refs = [torch.cat(ref, dim=1) for ref in refs] + outputs = torch.ops.fbgemm.permute_multi_embedding( + values, permutes, in_shapes, out_shapes, out_lengths + ) + for out, ref in zip(outputs, refs): + self.assertTrue(torch.allclose(out, ref)) + + def test_multi_permute_forward_meta(self) -> None: + batch_size = 32 + keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] + lengths = [[3, 4], [5, 6, 7], [8]] + groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] + values = [ + torch.randn(batch_size, sum(lens), device="meta", requires_grad=True) + for lens in lengths + ] + permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_permutes( + values[0], keys, lengths, groups + ) + outputs = torch.ops.fbgemm.permute_multi_embedding( + values, permutes, in_shapes, out_shapes, out_lengths + ) + for out, ref in zip(outputs, out_lengths): + self.assertEqual(out.shape, (batch_size, ref)) + + # pyre-ignore[56] + @unittest.skipIf( + torch.cuda.device_count() <= 0, + "CUDA is not available", + ) + def test_multi_permute_forward_gpu(self) -> None: + batch_size = 1024 + keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] + lengths = [[96, 256], [512, 128, 768], [1024]] + groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] + values = [ + torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True) + for lens in lengths + ] + permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_permutes( + values[0], keys, lengths, groups + ) + refs = [[] for _ in groups] + for i in range(permutes.size(0)): + in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist() + refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)]) + refs = [torch.cat(ref, dim=1) for ref in refs] + outputs = torch.ops.fbgemm.permute_multi_embedding( + values, permutes, in_shapes, out_shapes, out_lengths + ) + for out, ref in zip(outputs, refs): + self.assertTrue(torch.allclose(out, ref)) + + def test_multi_permute_backward_cpu(self) -> None: + batch_size = 32 + keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] + lengths = [[3, 4], [5, 6, 7], [8]] + groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] + values = [ + torch.randn(batch_size, sum(lens), device="cpu", requires_grad=True) + for lens in lengths + ] + ref_values = [v.detach() for v in values] + for v in ref_values: + v.requires_grad = True + permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_permutes( + values[0], keys, lengths, groups + ) + refs = [[] for _ in groups] + for i in range(permutes.size(0)): + in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist() + refs[out_idx].append(ref_values[in_idx][:, in_start : (in_start + length)]) + refs = [torch.cat(ref, dim=1) for ref in refs] + outputs = torch.ops.fbgemm.permute_multi_embedding( + values, permutes, in_shapes, out_shapes, out_lengths + ) + for out, ref in zip(outputs, refs): + self.assertTrue(torch.allclose(out, ref)) + + ref_loss, loss = refs[0].sum(), outputs[0].sum() + for i in range(1, len(refs)): + ref_loss += (i + 1.1) * refs[i].sum() + loss += (i + 1.1) * outputs[i].sum() + 
+        ref_loss.backward()
+        loss.backward()
+        for val, ref in zip(values, ref_values):
+            val_grad, ref_grad = val.grad, ref.grad
+            assert isinstance(val_grad, torch.Tensor)
+            self.assertTrue(torch.allclose(val_grad, ref_grad))
+
+    # pyre-ignore[56]
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 0,
+        "CUDA is not available",
+    )
+    def test_multi_permute_backward_gpu(self) -> None:
+        batch_size = 2048
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[96, 256], [512, 128, 768], [1024]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        values = [
+            torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True)
+            for lens in lengths
+        ]
+        ref_values = [v.detach() for v in values]
+        for v in ref_values:
+            v.requires_grad = True
+        permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_permutes(
+            values[0], keys, lengths, groups
+        )
+        refs = [[] for _ in groups]
+        for i in range(permutes.size(0)):
+            in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist()
+            refs[out_idx].append(ref_values[in_idx][:, in_start : (in_start + length)])
+        refs = [torch.cat(ref, dim=1) for ref in refs]
+        outputs = torch.ops.fbgemm.permute_multi_embedding(
+            values, permutes, in_shapes, out_shapes, out_lengths
+        )
+        for out, ref in zip(outputs, refs):
+            self.assertTrue(torch.allclose(out, ref))
+
+        ref_loss, loss = refs[0].sum(), outputs[0].sum()
+        for i in range(1, len(refs)):
+            ref_loss += (i + 1.1) * refs[i].sum()
+            loss += (i + 1.1) * outputs[i].sum()
+        ref_loss.backward()
+        loss.backward()
+        for val, ref in zip(values, ref_values):
+            val_grad, ref_grad = val.grad, ref.grad
+            assert isinstance(val_grad, torch.Tensor)
+            self.assertTrue(torch.allclose(val_grad, ref_grad))
+
     def test_permute_duplicates(self) -> None:
         values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
         lengths = torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3, 0])
@@ -1673,8 +1860,6 @@ def test_string_vb(self) -> None:
             stride_per_key_per_rank=stride_per_key_per_rank,
         )
 
-        print(str(jag_tensor))
-
         self.assertEqual(
             str(jag_tensor),
             """\