From 48d6eac51d01443563828ee0b4ef62df48d952d2 Mon Sep 17 00:00:00 2001
From: Huanyu He
Date: Thu, 12 Sep 2024 11:30:20 -0700
Subject: [PATCH] add PT2 support for permute_multi_embedding (#2381)

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/208

X-link: https://github.com/pytorch/FBGEMM/pull/3115

Pull Request resolved: https://github.com/pytorch/torchrec/pull/2381

# context
* make the fbgemm operator `permute_multi_embedding` PT2 compatible.
* `out_lengths` is the list of sizes of all the output KTs; these sizes should be dynamic dims.
* change `out_lengths` from `std::vector` to `c10::SymIntArrayRef`, along with other type-compatibility fixes.

# ref
* previous graph breaks: P1557581728 https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmphgx6wM/rank_0/failures_and_restarts.html
* new: https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpxLBHmj/index.html

Reviewed By: IvanKobzarev

Differential Revision: D62226292

fbshipit-source-id: c826309939e0a33190b49a3aa090cbcc7515b20d
---
 torchrec/distributed/tests/test_pt2.py | 28 ++++++++++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/torchrec/distributed/tests/test_pt2.py b/torchrec/distributed/tests/test_pt2.py
index f042b8358..c14906abc 100644
--- a/torchrec/distributed/tests/test_pt2.py
+++ b/torchrec/distributed/tests/test_pt2.py
@@ -37,6 +37,7 @@
     kjt_for_pt2_tracing,
     register_fake_classes,
 )
+from torchrec.sparse.jagged_tensor import _kt_regroup_arguments
 
 try:
     # pyre-ignore
@@ -842,6 +843,33 @@ def test_permute_pooled_embs_split(self) -> None:
         inp = torch.randn(12, 3)
         _test_compile_fwd_bwd(m, inp, device)
 
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 1,
+        "Not enough GPUs, this test requires at least two GPUs",
+    )
+    def test_permute_multi_embedding(self) -> None:
+        device = "cuda"
+        batch_size = 16
+
+        def func(values, permutes, in_shapes, out_shapes, out_lengths):
+            return torch.ops.fbgemm.permute_multi_embedding(
+                values, permutes, in_shapes, out_shapes, out_lengths.tolist()
+            )
+
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[3, 4], [5, 6, 7], [8]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        values = [torch.randn(batch_size, sum(L), device=device) for L in lengths]
+        for embs in values:
+            torch._dynamo.mark_dynamic(embs, 0)
+            torch._dynamo.mark_dynamic(embs, 1)
+        permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_arguments(
+            values[0], keys, lengths, groups
+        )
+        out_lengths = torch.tensor(out_lengths, device=device, dtype=torch.int32)
+        inp = (values, permutes, in_shapes, out_shapes, out_lengths)
+        _test_compile_fwd_bwd(func, inp, device, unpack_inp=True)
+
     @unittest.skipIf(
         torch.cuda.device_count() < 1,
         "Not enough GPUs, this test requires at least one GPU",
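
For reference only, and not part of the patch above: a minimal sketch of exercising the now PT2-compatible op under `torch.compile` with dynamic dims, mirroring what the new test drives through `_test_compile_fwd_bwd`. The `regroup` wrapper name and the plain `torch.compile` call are illustrative assumptions; it presumes a CUDA device and an `fbgemm_gpu` build that includes the updated operator.

```python
# Illustrative sketch only -- not part of the patch above.
import torch
import fbgemm_gpu  # noqa: F401  # registers the fbgemm operators

from torchrec.sparse.jagged_tensor import _kt_regroup_arguments


def regroup(values, permutes, in_shapes, out_shapes, out_lengths):
    # Convert the out_lengths tensor to a Python list, as the test's func does.
    return torch.ops.fbgemm.permute_multi_embedding(
        values, permutes, in_shapes, out_shapes, out_lengths.tolist()
    )


device = "cuda"
keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
lengths = [[3, 4], [5, 6, 7], [8]]
groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
values = [torch.randn(16, sum(L), device=device) for L in lengths]
for embs in values:
    torch._dynamo.mark_dynamic(embs, 0)  # batch dim is dynamic
    torch._dynamo.mark_dynamic(embs, 1)  # embedding dim is dynamic
permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_arguments(
    values[0], keys, lengths, groups
)
out_lengths = torch.tensor(out_lengths, device=device, dtype=torch.int32)

compiled = torch.compile(regroup)
outs = compiled(values, permutes, in_shapes, out_shapes, out_lengths)
```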