Skip to content

Commit

Permalink
add PT2 support for permute_multi_embedding (pytorch#2381)
Browse files Browse the repository at this point in the history
Summary:
X-link: facebookresearch/FBGEMM#208

X-link: pytorch/FBGEMM#3115

Pull Request resolved: pytorch#2381

# context
* make fbgemm operator `permute_multi_embedding` PT2 compatible.
* `out_lengths` is the list of sizes for all the output KT, which should be dynamic dims.
* change the `out_lengths` from `std::vector<int64_t>` to `c10::SymIntArrayRef`, and other type compatibility fixes.

# ref
* previously
graph breaks: P1557581728
https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmphgx6wM/rank_0/failures_and_restarts.html
* new
https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpxLBHmj/index.html

Reviewed By: IvanKobzarev

Differential Revision: D62226292

fbshipit-source-id: c826309939e0a33190b49a3aa090cbcc7515b20d
  • Loading branch information
TroyGarden authored and facebook-github-bot committed Sep 12, 2024
1 parent 4530b72 commit 48d6eac
Showing 1 changed file with 28 additions and 0 deletions.
28 changes: 28 additions & 0 deletions torchrec/distributed/tests/test_pt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
kjt_for_pt2_tracing,
register_fake_classes,
)
from torchrec.sparse.jagged_tensor import _kt_regroup_arguments

try:
# pyre-ignore
Expand Down Expand Up @@ -842,6 +843,33 @@ def test_permute_pooled_embs_split(self) -> None:
inp = torch.randn(12, 3)
_test_compile_fwd_bwd(m, inp, device)

@unittest.skipIf(
torch.cuda.device_count() <= 1,
"Not enough GPUs, this test requires at least two GPUs",
)
def test_permute_multi_embedding(self) -> None:
device = "cuda"
batch_size = 16

def func(values, permutes, in_shapes, out_shapes, out_lengths):
return torch.ops.fbgemm.permute_multi_embedding(
values, permutes, in_shapes, out_shapes, out_lengths.tolist()
)

keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
lengths = [[3, 4], [5, 6, 7], [8]]
groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
values = [torch.randn(batch_size, sum(L), device=device) for L in lengths]
for embs in values:
torch._dynamo.mark_dynamic(embs, 0)
torch._dynamo.mark_dynamic(embs, 1)
permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_arguments(
values[0], keys, lengths, groups
)
out_lengths = torch.tensor(out_lengths, device=device, dtype=torch.int32)
inp = (values, permutes, in_shapes, out_shapes, out_lengths)
_test_compile_fwd_bwd(func, inp, device, unpack_inp=True)

@unittest.skipIf(
torch.cuda.device_count() < 1,
"Not enough GPUs, this test requires at least one GPU",
Expand Down

0 comments on commit 48d6eac

Please sign in to comment.