FBGEMM kernel for KeyedTensor (PooledEmbedding) permute mapping
Summary:
# context
* currently we have a working function `permute_pooled_embs_auto_grad` that does a full permute of KTs, including forward and backward
* it has several limitations:
a) it has to be a full permute; duplicates are not supported;
b) in the main [use case](https://fburl.com/code/89od0rqm) there has to be a torch.concat on the input KTs, which is not very efficient;
c) the function outputs a single KT, which then requires a split operation
* there is an existing attempt to support duplicated outputs, but its backward doesn't work
* this diff creates a new kernel (exposed as `torch.ops.fbgemm.permute_multi_embedding`) to support a multiple-KT to multiple-KT mapping operation with backward support

# operator example usage
* usage in Python
```
# test inputs: 3 KTs with batch_size=2048
batch_size = 2048
keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
lengths = [[96, 256], [512, 128, 768], [1024]]
values = [
    torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True)
    for lens in lengths
]

# target outputs: 4 KTs with re-arranged keys (features), duplicates are allowed
groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]

# auxiliary arguments for the op/kernel
permutes, in_lengths, out_lengths = _multi_remap_to_groups(
    keys, lengths, groups
)

# call the operator
outputs = torch.ops.fbgemm.permute_multi_embedding(
    values, # list of tensors (on device)
    permutes.to(device=torch.device("cuda")), # tensor on device
    out_lengths.tolist(), # List[int] on CPU
    in_lengths.to(device=torch.device("cuda")), # tensor on device
    out_lengths.to(device=torch.device("cuda")), # tensor on device
)
```
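* expected result (a sanity check added here for illustration; the output widths follow directly from the grouped feature lengths, e.g. f1 + f3 = 96 + 512 = 608)
```
# the call above should return one tensor per output group
assert [list(out.shape) for out in outputs] == [
    [2048, 608],   # ["f1", "f3"] -> 96 + 512
    [2048, 256],   # ["f2"]
    [2048, 1248],  # ["f4", "f1", "f6"] -> 128 + 96 + 1024
    [2048, 864],   # ["f1", "f5"] -> 96 + 768
]
```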
* the `permutes` tensor produced by `_multi_remap_to_groups` (shown here for the smaller per-key lengths `[[3, 4], [5, 6, 7], [8]]` used in the unit tests below)
```
permutes = tensor(
            [
                [0, 0, 0, 0, 3, 4],  # f1
                [1, 0, 0, 3, 5, 0],  # f3
                [0, 1, 3, 0, 4, 0],  # f2
                [1, 2, 5, 0, 6, 0],  # f4
                [0, 2, 0, 6, 3, -6],  # f1
                [2, 2, 0, 9, 8, 0],  # f6
                [0, 3, 0, 0, 3, -8],  # f1
                [1, 3, 11, 3, 7, 0],  # f5
            ]
)
```

# details
1. from the example usage above, we can clearly see that the operator takes in the following:
a) values: List[torch.Tensor], which represents the input KTs
b) permutes: torch.Tensor, which contains the permute information, explained below
c) output_lengths_list: List[int], the lengths of the output tensors (KTs), which is needed to allocate memory on the device ahead of time
d) in_lengths: torch.Tensor, lengths of the input tensors, on the device
e) out_lengths: torch.Tensor, lengths of the output tensors, on the device
2. the operator returns a list of tensors representing the permuted KTs
3. `permutes` is the most critical argument of this operator:
a) it is a 2-D tensor
b) each row represents one key (feature) permute move
c) a permute move = [input_tensor_id, output_tensor_id, input_start_idx, output_start_idx, feature_length, jump]
d) jump is used in the backward pass when a key (feature) from an input tensor is mapped to multiple places in the output tensors: the first occurrence of such a key stores a positive jump pointing at the row of its next duplicate, later duplicates store the negated row index of the following duplicate, and the last one in the chain keeps -total_permutes as an end marker (see the forward-semantics sketch right after this list)
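
Below is a pure-Python sketch of the forward semantics of the permute rows, mirroring how the unit tests in this diff build their reference outputs; it is illustrative only and not the actual kernel:
```
import torch
from typing import List


def permute_multi_embedding_ref(
    values: List[torch.Tensor],
    permutes: torch.Tensor,
    out_lengths: List[int],
) -> List[torch.Tensor]:
    # one output tensor per output KT, allocated from out_lengths
    batch_size = values[0].size(0)
    outputs = [values[0].new_empty(batch_size, length) for length in out_lengths]
    # each permute row copies one feature slice from an input KT to an output KT;
    # the jump column only matters for gradient accumulation in the backward pass
    for in_idx, out_idx, in_start, out_start, length, _jump in permutes.tolist():
        outputs[out_idx][:, out_start : out_start + length] = values[in_idx][
            :, in_start : in_start + length
        ]
    return outputs
```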

# performance notes
The good:
1. the algorithm is designed so that it doesn't need to know in advance whether a 1-to-N mapping exists in the permutes.
2. `_all_keys_used_once` is no longer needed
3. a torch.cat on the inputs (required before calling the old operator) is no longer needed

The bad (same as before):
1. it requires several HtoD transfers (moving tensors to the device); see the caller-side sketch after this list:
a) 3 tensors, namely `permutes`, `in_lengths`, and `out_lengths`; these tensors need to be on the device so that the CUDA kernel has access to them.
b) 2 lists of (scalar_t*) pointers, one for the input tensor list and one for the output tensor list.
c) we didn't find a good way to let the kernel know the addresses of the input/output tensors other than these pointer lists, because the lists also need to be on the device.
2. tensor.contiguous() is needed in the backward function; it looks like the grads coming into backward are somehow not contiguous
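
A possible caller-side mitigation for 1.a (a sketch, not part of this diff, assuming keys/lengths/groups stay fixed across iterations): compute the metadata once, move it to the device once, and reuse it for every batch.
```
# hypothetical caching on the caller side; names are illustrative and reuse the example above
device = torch.device("cuda")
permutes, in_lengths, out_lengths = _multi_remap_to_groups(keys, lengths, groups)
permutes_dev = permutes.to(device)
in_lengths_dev = in_lengths.to(device)
out_lengths_dev = out_lengths.to(device)
out_lengths_list = out_lengths.tolist()

# the per-batch call then reuses the cached metadata, avoiding repeated HtoD copies
outputs = torch.ops.fbgemm.permute_multi_embedding(
    values, permutes_dev, out_lengths_list, in_lengths_dev, out_lengths_dev
)
```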

Differential Revision: D57055616
TroyGarden authored and facebook-github-bot committed Jun 17, 2024
1 parent be40210 commit 4958d8e
Showing 2 changed files with 235 additions and 2 deletions.
67 changes: 67 additions & 0 deletions torchrec/sparse/jagged_tensor.py
@@ -240,6 +240,73 @@ def _remap_to_groups(
    return permute, inv_permute, offsets, inv_offsets, splits


@torch.fx.wrap
def _multi_remap_to_groups(
    keys: List[List[str]],
    key_lengths: List[List[int]],
    groups: List[List[str]],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Given the keys and lengths per key of each input tensor, and the desired
    output groups, return the 2-D permutes tensor with rows of
    [input_tensor_idx, output_tensor_idx, input_start, output_start, length, jump],
    plus 1-D tensors of the input and output tensor lengths.
    """
    # key => (tensor_idx, key_index)
    key_map: Dict[str, Tuple[int, int]] = {
        key: (tensor_idx, key_idx)
        for tensor_idx, tensor in enumerate(keys)
        for key_idx, key in enumerate(tensor)
    }

    # [offsets per tensor]
    offsets_list: List[List[int]] = [_cumsum(tensor) for tensor in key_lengths]

    # [input_tensor_idx, output_tensor_idx, input_start, output_start, length, jump]
    permute_list: List[List[int]] = []
    output_lengths: List[int] = [0] * len(groups)

    total_permutes = sum(len(tensor) for tensor in groups)
    last_seen: Dict[str, int] = {}
    for output_tensor_idx, output_tensor in enumerate(groups):
        output_start = 0
        for output_key in output_tensor:
            input_tensor_idx, input_key_idx = key_map[output_key]
            input_start = offsets_list[input_tensor_idx][input_key_idx]
            length = key_lengths[input_tensor_idx][input_key_idx]

            # add jump data
            if output_key in last_seen:
                jump = last_seen[output_key]
                if jump >= 0:  # it's a jump start
                    # change previous jump to current
                    permute_list[jump][5] = len(permute_list)
                else:  # it's already in a jump sequence
                    permute_list[-jump][5] = -len(permute_list)
                last_seen[output_key] = -len(permute_list)  # it's already in jump
                jump = -total_permutes
            else:
                jump = 0
                last_seen[output_key] = len(permute_list)  # potential jump start

            permute_list.append(
                [
                    input_tensor_idx,
                    output_tensor_idx,
                    input_start,
                    output_start,
                    length,
                    jump,
                ]
            )
            output_start += length
        output_lengths[output_tensor_idx] = output_start
    permutes = torch.tensor(permute_list, dtype=torch.int64)
    in_lengths = torch.tensor(
        [offsets[-1] for offsets in offsets_list], dtype=torch.int64
    )
    out_lengths = torch.tensor(output_lengths, dtype=torch.int64)
    return permutes, in_lengths, out_lengths


def _values_string(values: torch.Tensor, start: int, end: int) -> str:
    size = values.size()
    if len(size) == 1:
170 changes: 168 additions & 2 deletions torchrec/sparse/tests/test_jagged_tensor.py
@@ -16,6 +16,7 @@
from torch.testing import FileCheck
from torchrec.fx import symbolic_trace
from torchrec.sparse.jagged_tensor import (
    _multi_remap_to_groups,
    _regroup_keyed_tensors,
    ComputeJTDictToKJT,
    ComputeKJTToJTDict,
@@ -1374,6 +1375,173 @@ def test_permute_vb(self) -> None:
        )
        self.assertEqual(permuted_jag_tensor.weights_or_none(), None)

    def test_multi_remap_to_group(self) -> None:
        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
        lengths = [[3, 4], [5, 6, 7], [8]]
        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
        res, in_lengths, out_lengths = _multi_remap_to_groups(keys, lengths, groups)
        ref = torch.tensor(
            [
                [0, 0, 0, 0, 3, 4],  # f1
                [1, 0, 0, 3, 5, 0],  # f3
                [0, 1, 3, 0, 4, 0],  # f2
                [1, 2, 5, 0, 6, 0],  # f4
                [0, 2, 0, 6, 3, -6],  # f1
                [2, 2, 0, 9, 8, 0],  # f6
                [0, 3, 0, 0, 3, -8],  # f1
                [1, 3, 11, 3, 7, 0],  # f5
            ]
        )
        self.assertEqual(in_lengths.tolist(), [7, 18, 8])
        self.assertEqual(out_lengths.tolist(), [8, 4, 17, 10])
        self.assertTrue(torch.equal(res, ref))

    def test_multi_permute_forward_cpu(self) -> None:
        batch_size = 5
        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
        lengths = [[3, 4], [5, 6, 7], [8]]
        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
        values = [
            torch.randn(batch_size, sum(lens), device="cpu", requires_grad=True)
            for lens in lengths
        ]
        permutes, in_lengths, out_lengths = _multi_remap_to_groups(
            keys, lengths, groups
        )
        refs = [[] for _ in groups]
        for in_idx, out_idx, in_start, _, length, _ in permutes.tolist():
            refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)])
        refs = [torch.cat(ref, dim=1) for ref in refs]
        outputs = torch.ops.fbgemm.permute_multi_embedding(
            values, permutes, out_lengths.tolist(), in_lengths, out_lengths
        )
        for out, ref in zip(outputs, refs):
            self.assertTrue(torch.allclose(out, ref))

    def test_multi_permute_forward_meta(self) -> None:
        batch_size = 5
        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
        lengths = [[3, 4], [5, 6, 7], [8]]
        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
        values = [
            torch.randn(batch_size, sum(lens), device="meta", requires_grad=True)
            for lens in lengths
        ]
        permutes, in_lengths, out_lengths = _multi_remap_to_groups(
            keys, lengths, groups
        )
        refs = [[] for _ in groups]
        for in_idx, out_idx, in_start, _, length, _ in permutes.tolist():
            refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)])
        refs = [torch.cat(ref, dim=1) for ref in refs]
        outputs = torch.ops.fbgemm.permute_multi_embedding(
            values, permutes, out_lengths.tolist(), in_lengths, out_lengths
        )
        for out, ref in zip(outputs, refs):
            self.assertEqual(out.shape, ref.shape)

    def test_multi_permute_forward_gpu(self) -> None:
        batch_size = 5
        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
        lengths = [[3, 4], [5, 6, 7], [8]]
        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
        values = [
            torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True)
            for lens in lengths
        ]
        permutes, in_lengths, out_lengths = _multi_remap_to_groups(
            keys, lengths, groups
        )
        refs = [[] for _ in groups]
        for in_idx, out_idx, in_start, _, length, _ in permutes.tolist():
            refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)])
        refs = [torch.cat(ref, dim=1) for ref in refs]
        outputs = torch.ops.fbgemm.permute_multi_embedding(
            values,
            permutes.to(device=torch.device("cuda")),
            out_lengths.tolist(),
            in_lengths.to(device=torch.device("cuda")),
            out_lengths.to(device=torch.device("cuda")),
        )
        for out, ref in zip(outputs, refs):
            self.assertTrue(torch.allclose(out, ref))

    def test_multi_permute_backward_cpu(self) -> None:
        batch_size = 5
        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
        lengths = [[3, 4], [5, 6, 7], [8]]
        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
        values = [
            torch.randn(batch_size, sum(lens), device="cpu", requires_grad=True)
            for lens in lengths
        ]
        ref_values = [v.detach() for v in values]
        for v in ref_values:
            v.requires_grad = True
        permutes, in_lengths, out_lengths = _multi_remap_to_groups(
            keys, lengths, groups
        )
        refs = [[] for _ in groups]
        for in_idx, out_idx, in_start, _, length, _ in permutes.tolist():
            refs[out_idx].append(ref_values[in_idx][:, in_start : (in_start + length)])
        refs = [torch.cat(ref, dim=1) for ref in refs]
        outputs = torch.ops.fbgemm.permute_multi_embedding(
            values,
            permutes,
            out_lengths.tolist(),
            in_lengths,
            out_lengths,
        )
        for out, ref in zip(outputs, refs):
            self.assertTrue(torch.allclose(out, ref))

        ref_loss = sum((i + 1.1) * ref.sum() for i, ref in enumerate(refs))
        self.assertTrue(isinstance(ref_loss, torch.Tensor))
        ref_loss.backward()
        loss = sum((i + 1.1) * out.sum() for i, out in enumerate(outputs))
        self.assertTrue(isinstance(loss, torch.Tensor))
        loss.backward()
        for val, ref in zip(values, ref_values):
            self.assertTrue(torch.allclose(val.grad, ref.grad))

    def test_multi_permute_backward_gpu(self) -> None:
        batch_size = 2048
        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
        lengths = [[96, 256], [512, 128, 768], [1024]]
        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
        values = [
            torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True)
            for lens in lengths
        ]
        ref_values = [v.detach() for v in values]
        for v in ref_values:
            v.requires_grad = True
        permutes, in_lengths, out_lengths = _multi_remap_to_groups(
            keys, lengths, groups
        )
        refs = [[] for _ in groups]
        for in_idx, out_idx, in_start, _, length, _ in permutes.tolist():
            refs[out_idx].append(ref_values[in_idx][:, in_start : (in_start + length)])
        refs = [torch.cat(ref, dim=1) for ref in refs]
        outputs = torch.ops.fbgemm.permute_multi_embedding(
            values,
            permutes.to(device=torch.device("cuda")),
            out_lengths.tolist(),
            in_lengths.to(device=torch.device("cuda")),
            out_lengths.to(device=torch.device("cuda")),
        )
        for out, ref in zip(outputs, refs):
            self.assertTrue(torch.allclose(out, ref))

        ref_loss = sum((i + 1.1) * ref.sum() for i, ref in enumerate(refs))
        self.assertTrue(isinstance(ref_loss, torch.Tensor))
        ref_loss.backward()
        loss = sum((i + 1.1) * out.sum() for i, out in enumerate(outputs))
        self.assertTrue(isinstance(loss, torch.Tensor))
        loss.backward()
        for val, ref in zip(values, ref_values):
            self.assertTrue(torch.allclose(val.grad, ref.grad))

    def test_permute_duplicates(self) -> None:
        values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
        lengths = torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3, 0])
@@ -1650,8 +1818,6 @@ def test_string_vb(self) -> None:
            stride_per_key_per_rank=stride_per_key_per_rank,
        )

        print(str(jag_tensor))

        self.assertEqual(
            str(jag_tensor),
            """\
