From dc810568440e77a3be84f9325543d96bffaae452 Mon Sep 17 00:00:00 2001 From: Huanyu He Date: Wed, 14 Aug 2024 16:47:52 -0700 Subject: [PATCH] add test for supporting torch.float16 and torch.bfloat16 (#2300) Summary: X-link: https://github.com/pytorch/FBGEMM/pull/2992 Pull Request resolved: https://github.com/pytorch/torchrec/pull/2300 # context * We found the new operator `permute_multi_embedding` can't support `torch.float16` in an inference test * added test to cover the dtype support * before the operator change, we see the following error ``` Failures: 1) torchrec.sparse.tests.test_jagged_tensor.TestKeyedTensorRegroupOp: test_multi_permute_dtype 1) RuntimeError: expected scalar type Float but found Half File "torchrec/sparse/tests/test_jagged_tensor.py", line 2798, in test_multi_permute_dtype outputs = torch.ops.fbgemm.permute_multi_embedding( File "torch/_ops.py", line 1113, in __call__ return self._op(*args, **(kwargs or {})) ``` * suspicion is that in the cpu operator, there are tensor data access with `data_ptr` in the code, which limited the dtype could only be `float32` ``` auto outp = outputs[out_tensor][b].data_ptr() + out_offset; auto inp = inputs[in_tensor][b].data_ptr() + in_offset; ``` # changes * use `FBGEMM_DISPATCH_FLOATING_TYPES` to dispatch the dtype to template `scalar_t`. * after the change the operator can support `float16`, `bfloat16` WARNING: somehow this operator still can't support `int` types. Differential Revision: D57143637 --- torchrec/sparse/tests/test_jagged_tensor.py | 46 +++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/torchrec/sparse/tests/test_jagged_tensor.py b/torchrec/sparse/tests/test_jagged_tensor.py index b9385b5cc..27f21978b 100644 --- a/torchrec/sparse/tests/test_jagged_tensor.py +++ b/torchrec/sparse/tests/test_jagged_tensor.py @@ -2775,6 +2775,52 @@ def test_multi_permute_forward(self, device_str: str, batch_size: int) -> None: for out, ref in zip(outputs, refs): torch.testing.assert_close(out, ref) + @repeat_test( + device_str=["meta", "cpu", "cuda"], + dtype=[ + # torch.int, + # torch.uint8, + # torch.int8, + # torch.int16, + # torch.float64, + torch.float, + torch.float32, + torch.float16, + torch.bfloat16, + ], + ) + def test_multi_permute_dtype(self, device_str: str, dtype: torch.dtype) -> None: + if device_str == "cuda" and not torch.cuda.is_available(): + return + else: + device = torch.device(device_str) + batch_size = 4 + keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] + lengths = [[3, 4], [5, 6, 7], [8]] + groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] + values = [ + torch.randn(batch_size, sum(L), device=device, dtype=dtype) for L in lengths + ] + permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_arguments( + values[0], keys, lengths, groups + ) + outputs = torch.ops.fbgemm.permute_multi_embedding( + values, permutes, in_shapes, out_shapes, out_lengths + ) + + if device_str == "meta": + for out, ref in zip(outputs, out_lengths): + self.assertEqual(out.shape, (batch_size, ref)) + else: + refs = [[] for _ in groups] + for i in range(permutes.size(0)): + in_idx, out, in_start, _, length, _ = permutes[i].tolist() + refs[out].append(values[in_idx][:, in_start : (in_start + length)]) + refs = [torch.cat(ref, dim=1) for ref in refs] + for out, ref in zip(outputs, refs): + torch.testing.assert_close(out, ref) + self.assertEqual(out.dtype, ref.dtype) + @repeat_test( ["cpu", 32, [[3, 4], [5, 6, 7], [8]]], ["cuda", 128, [[96, 256], [512, 128, 768], [1024]]],