implementation of fbgemm op - permute_multi_embedding (#2120)
Summary:
X-link: pytorch/FBGEMM#2738


# context
* currently we have a working function `permute_pooled_embs_auto_grad` that does a full permute of KTs, including forward and backward
* it has several limitations:
a) it has to be a full permute; duplicates are not supported;
b) in the main [use case](https://fburl.com/code/89od0rqm), a `torch.concat` of the input KTs is required, which is not very efficient;
c) the function outputs a single KT, which then requires a split operation
* there have been attempts to support duplicated outputs, but the backward pass doesn't work
* this diff creates a new kernel (named `permute_multi_embedding`) that supports a multiple-KT to multiple-KT mapping operation with backward support

# notes
* this diff focuses on the implementation and tests of the operator
* performance analysis and benchmarks are in the next diff

# operator example usage
* usage in Python
```
# test inputs: 3 KTs with batch_size=2048
batch_size = 2048
keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
lengths = [[96, 256], [512, 128, 768], [1024]]
values = [
    torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True)
    for lens in lengths
]

# target outputs: 4 KTs with re-arranged keys (features), duplicates are allowed
groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]

# auxiliary arguments for the op/kernel
permutes, in_lengths, out_lengths = _multi_remap_to_groups(
    keys, lengths, groups
)

# call the operator
outputs = torch.ops.fbgemm.permute_multi_embedding(
    values, permutes, in_lengths, out_lengths
)
```
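* each output KT has shape `(batch_size, sum of its group's key lengths)`; for the groups above, `outputs[0]` (keys `f1` and `f3`) is `2048 x (96 + 512) = 2048 x 608`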
* permutes (the values below correspond to the smaller per-key lengths `[[3, 4], [5, 6, 7], [8]]` used in the unit tests, rather than the lengths in the example above)
```
permutes = tensor(
            [
                [0, 0, 0, 0, 3, 4],  # f1
                [1, 0, 0, 3, 5, 0],  # f3
                [0, 1, 3, 0, 4, 0],  # f2
                [1, 2, 5, 0, 6, 0],  # f4
                [0, 2, 0, 6, 3, -6],  # f1
                [2, 2, 0, 9, 8, 0],  # f6
                [0, 3, 0, 0, 3, -8],  # f1
                [1, 3, 11, 3, 7, 0],  # f5
            ]
)
```
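* reading the first row: `[0, 0, 0, 0, 3, 4]` means copy the key at offset 0 of input tensor 0 (here `f1`, length 3) to offset 0 of output tensor 0; the trailing 4 points to row 4, the next row that writes `f1`, forming the jump chain used by the backward pass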

# details
1. from the example usage above, we can see that the operator takes in the following:
a) values: List[torch.Tensor], which represents the input KTs
b) permutes: torch.Tensor, which contains the permute information, explained below
c) output_lengths_list: List[int], the lengths of the output tensors (KTs), which is needed to allocate memory on the device ahead of time
d) in_lengths: torch.Tensor, lengths of the input tensors, resident on the device
e) out_lengths: torch.Tensor, lengths of the output tensors, resident on the device
2. the operator returns a list of tensors, which represent the permuted KTs
3. `permutes` is the most critical argument of this operator:
a) it is a 2-D tensor
b) each row represents one key (feature) permute move
c) a permute move = [input_tensor_id, output_tensor_id, input_start_idx, output_start_idx, feature_length, jump]
d) jump is used in the backward pass when a key (feature) from an input tensor is mapped to multiple places in the output tensors (a pure-Python sketch of the forward semantics follows in item 4)
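4. a minimal pure-Python sketch of the forward semantics, illustrative only and not the actual kernel: it reuses `values` plus the flat `permutes` list and `out_lengths` returned by `_multi_remap_to_groups` (6 ints per permuted key), and mirrors the reference computation used in the unit tests; the jump column is ignored here since it only drives gradient accumulation in backward
```
# illustrative reference for the forward pass, not the actual kernel
refs = [[] for _ in out_lengths]
for i in range(len(permutes) // 6):
    in_idx, out_idx, in_start, _, length, _ = permutes[i * 6 : i * 6 + 6]
    refs[out_idx].append(values[in_idx][:, in_start : in_start + length])
# each expected output is the horizontal concat of its group's key slices
expected = [torch.cat(ref, dim=1) for ref in refs]
```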

Differential Revision: D57055616
TroyGarden authored and facebook-github-bot committed Jun 22, 2024
1 parent 530bf04 commit 18b184c
Showing 2 changed files with 247 additions and 2 deletions.
82 changes: 82 additions & 0 deletions torchrec/sparse/jagged_tensor.py
@@ -36,6 +6,12 @@
    torch.ops.load_library(
        "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_cpu"
    )
    torch.ops.load_library(
        "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_cpu"
    )
    torch.ops.load_library(
        "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_gpu"
    )
except OSError:
    pass

@@ -240,6 +246,82 @@ def _remap_to_groups(
    return permute, inv_permute, offsets, inv_offsets, splits


def _multi_remap_to_groups(
    keys: List[List[str]],
    key_lengths: List[List[int]],
    groups: List[List[str]],
) -> Tuple[List[int], List[int], List[int]]:
    """
    Given the keys and lengths per key of each input tensor and the desired output groups,
    return a flat permute list (6 ints per permuted key:
    [input_tensor_idx, output_tensor_idx, input_start, output_start, length, jump]),
    plus the total length of each input tensor and of each output tensor.
    """
    # key => (tensor_idx, key_index)
    key_map: Dict[str, Tuple[int, int]] = {
        key: (tensor_idx, key_idx)
        for tensor_idx, tensor in enumerate(keys)
        for key_idx, key in enumerate(tensor)
    }

    # [offsets per tensor]
    in_offsets: List[List[int]] = [[] for _ in key_lengths]
    for i, tensor in enumerate(key_lengths):
        in_offsets[i] = _cumsum(tensor)
    in_lengths: List[int] = [sum(lengths) for lengths in key_lengths]

    # set total_permutes as the jump stop sign
    total_permutes: int = sum(len(tensor) for tensor in groups)
    out_lengths: List[int] = [0] * len(groups)

    # [input_tensor_idx, output_tensor_idx, input_start, output_start, length, jump]
    permute_param = 6
    permutes: List[int] = [0] * (total_permutes * permute_param)

    # record the last seen index, so that we can make the jump from last_seen to current
    last_seen: Dict[str, int] = {}
    permute_idx = 0
    for output_tensor_idx, output_tensor in enumerate(groups):
        output_start = 0
        for output_key in output_tensor:
            input_tensor_idx, input_key_idx = key_map[output_key]
            input_start = in_offsets[input_tensor_idx][input_key_idx]
            length = key_lengths[input_tensor_idx][input_key_idx]

            # add jump data
            if output_key not in last_seen:
                jump = 0  # don't need to jump yet
                # positive as a potential jump start
                last_seen[output_key] = permute_idx
            else:
                prev = last_seen[output_key]
                if prev >= 0:  # positive ==> it's a jump start
                    # jump to current idx, positive as the jump start
                    permutes[prev * permute_param + 5] = permute_idx
                else:  # it's already in a jump sequence, mark as negative
                    permutes[-prev * permute_param + 5] = -permute_idx
                # mark last_seen negative since it's already in jump
                last_seen[output_key] = -permute_idx
                # it's a potential jump stop
                jump = -total_permutes

            permutes[permute_idx * permute_param : permute_idx * permute_param + 6] = [
                input_tensor_idx,
                output_tensor_idx,
                input_start,
                output_start,
                length,
                jump,
            ]
            permute_idx += 1
            output_start += length
        out_lengths[output_tensor_idx] = output_start

    return (
        permutes,
        in_lengths,
        out_lengths,
    )


def _values_string(values: torch.Tensor, start: int, end: int) -> str:
    size = values.size()
    if len(size) == 1:
167 changes: 165 additions & 2 deletions torchrec/sparse/tests/test_jagged_tensor.py
@@ -16,6 +16,7 @@
from torch.testing import FileCheck
from torchrec.fx import symbolic_trace
from torchrec.sparse.jagged_tensor import (
    _multi_remap_to_groups,
    _regroup_keyed_tensors,
    ComputeJTDictToKJT,
    ComputeKJTToJTDict,
@@ -1374,6 +1375,170 @@ def test_permute_vb(self) -> None:
        )
        self.assertEqual(permuted_jag_tensor.weights_or_none(), None)

    def test_multi_remap_to_group(self) -> None:
        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
        lengths = [[3, 4], [5, 6, 7], [8]]
        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
        permutes, in_lengths, out_lengths = _multi_remap_to_groups(
            keys, lengths, groups
        )
        ref_permutes = [
            [0, 0, 0, 0, 3, 4],  # f1, jump to 4, as a start
            [1, 0, 0, 3, 5, 0],  # f3
            [0, 1, 3, 0, 4, 0],  # f2
            [1, 2, 5, 0, 6, 0],  # f4
            [0, 2, 0, 6, 3, -6],  # f1 jump to 6, as in a jump sequence
            [2, 2, 0, 9, 8, 0],  # f6
            [0, 3, 0, 0, 3, -8],  # f1 jump stop, as out of boundary
            [1, 3, 11, 3, 7, 0],  # f5
        ]
        self.assertEqual(permutes, [i for p in ref_permutes for i in p])
        self.assertEqual(in_lengths, [7, 18, 8])
        self.assertEqual(out_lengths, [8, 4, 17, 10])

    def test_multi_permute_forward_cpu(self) -> None:
        batch_size = 5
        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
        lengths = [[3, 4], [5, 6, 7], [8]]
        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
        values = [
            torch.randn(batch_size, sum(lens), device="cpu", requires_grad=True)
            for lens in lengths
        ]
        permutes, in_lengths, out_lengths = _multi_remap_to_groups(
            keys, lengths, groups
        )
        refs = [[] for _ in groups]
        for i in range(len(permutes) // 6):
            in_idx, out_idx, in_start, _, length, _ = permutes[i * 6 : i * 6 + 6]
            refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)])
        refs = [torch.cat(ref, dim=1) for ref in refs]
        outputs = torch.ops.fbgemm.permute_multi_embedding(
            values, permutes, in_lengths, out_lengths
        )
        for out, ref in zip(outputs, refs):
            self.assertTrue(torch.allclose(out, ref))

    def test_multi_permute_forward_meta(self) -> None:
        batch_size = 5
        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
        lengths = [[3, 4], [5, 6, 7], [8]]
        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
        values = [
            torch.randn(batch_size, sum(lens), device="meta", requires_grad=True)
            for lens in lengths
        ]
        permutes, in_lengths, out_lengths = _multi_remap_to_groups(
            keys, lengths, groups
        )
        refs = [[] for _ in groups]
        for i in range(len(permutes) // 6):
            in_idx, out_idx, in_start, _, length, _ = permutes[i * 6 : i * 6 + 6]
            refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)])
        refs = [torch.cat(ref, dim=1) for ref in refs]
        outputs = torch.ops.fbgemm.permute_multi_embedding(
            values, permutes, in_lengths, out_lengths
        )
        for out, ref in zip(outputs, refs):
            self.assertEqual(out.shape, ref.shape)

    def test_multi_permute_forward_gpu(self) -> None:
        batch_size = 5
        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
        lengths = [[3, 4], [5, 6, 7], [8]]
        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
        values = [
            torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True)
            for lens in lengths
        ]
        permutes, in_lengths, out_lengths = _multi_remap_to_groups(
            keys, lengths, groups
        )
        refs = [[] for _ in groups]
        for i in range(len(permutes) // 6):
            in_idx, out_idx, in_start, _, length, _ = permutes[i * 6 : i * 6 + 6]
            refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)])
        refs = [torch.cat(ref, dim=1) for ref in refs]
        outputs = torch.ops.fbgemm.permute_multi_embedding(
            values, permutes, in_lengths, out_lengths
        )
        for out, ref in zip(outputs, refs):
            self.assertTrue(torch.allclose(out, ref))

    def test_multi_permute_backward_cpu(self) -> None:
        batch_size = 5
        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
        lengths = [[3, 4], [5, 6, 7], [8]]
        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
        values = [
            torch.randn(batch_size, sum(lens), device="cpu", requires_grad=True)
            for lens in lengths
        ]
        ref_values = [v.detach() for v in values]
        for v in ref_values:
            v.requires_grad = True
        permutes, in_lengths, out_lengths = _multi_remap_to_groups(
            keys, lengths, groups
        )
        refs = [[] for _ in groups]
        for i in range(len(permutes) // 6):
            in_idx, out_idx, in_start, _, length, _ = permutes[i * 6 : i * 6 + 6]
            refs[out_idx].append(ref_values[in_idx][:, in_start : (in_start + length)])
        refs = [torch.cat(ref, dim=1) for ref in refs]
        outputs = torch.ops.fbgemm.permute_multi_embedding(
            values, permutes, in_lengths, out_lengths
        )
        for out, ref in zip(outputs, refs):
            self.assertTrue(torch.allclose(out, ref))

        ref_loss, loss = refs[0].sum(), outputs[0].sum()
        for i in range(1, len(refs)):
            ref_loss += (i + 1.1) * refs[i].sum()
            loss += (i + 1.1) * outputs[i].sum()
        ref_loss.backward()
        loss.backward()
        for val, ref in zip(values, ref_values):
            val_grad, ref_grad = val.grad, ref.grad
            assert isinstance(val_grad, torch.Tensor)
            self.assertTrue(torch.allclose(val_grad, ref_grad))

    def test_multi_permute_backward_gpu(self) -> None:
        batch_size = 2048
        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
        lengths = [[96, 256], [512, 128, 768], [1024]]
        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
        values = [
            torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True)
            for lens in lengths
        ]
        ref_values = [v.detach() for v in values]
        for v in ref_values:
            v.requires_grad = True
        permutes, in_lengths, out_lengths = _multi_remap_to_groups(
            keys, lengths, groups
        )
        refs = [[] for _ in groups]
        for i in range(len(permutes) // 6):
            in_idx, out_idx, in_start, _, length, _ = permutes[i * 6 : i * 6 + 6]
            refs[out_idx].append(ref_values[in_idx][:, in_start : (in_start + length)])
        refs = [torch.cat(ref, dim=1) for ref in refs]
        outputs = torch.ops.fbgemm.permute_multi_embedding(
            values, permutes, in_lengths, out_lengths
        )
        for out, ref in zip(outputs, refs):
            self.assertTrue(torch.allclose(out, ref))

        ref_loss, loss = refs[0].sum(), outputs[0].sum()
        for i in range(1, len(refs)):
            ref_loss += (i + 1.1) * refs[i].sum()
            loss += (i + 1.1) * outputs[i].sum()
        ref_loss.backward()
        loss.backward()
        for val, ref in zip(values, ref_values):
            val_grad, ref_grad = val.grad, ref.grad
            assert isinstance(val_grad, torch.Tensor)
            self.assertTrue(torch.allclose(val_grad, ref_grad))

    def test_permute_duplicates(self) -> None:
        values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
        lengths = torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3, 0])
@@ -1650,8 +1815,6 @@ def test_string_vb(self) -> None:
            stride_per_key_per_rank=stride_per_key_per_rank,
        )

        print(str(jag_tensor))

        self.assertEqual(
            str(jag_tensor),
            """\