Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
FBGEMM kernel for KeyedTensor (PooledEmbedding) permute mapping
Summary: # context * current we have a working function `permute_pooled_embs_auto_grad` to do a full permute of KTs, including forward and backward * it has several limitations: a) it has to be a full permute, duplicates are not supported; b) in the main [use case](https://fburl.com/code/89od0rqm) there has to be a torch.concat on the input KTs, which is not very efficient; c) the function output a single KT which requires a split operation * there is some attempt to support duplicated outputs, but the backward doesn't work * this diff is trying to create a new kernel (named `multi_permute_pooled_embedding`) to support a multiple-KT to multiple-KT mapping operation with backward support # operator example usage * used in python ``` # test inputs: 3 KTs with batch_size=2048 batch_size = 2048 keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] lengths = [[96, 256], [512, 128, 768], [1024]] values = [ torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True) for lens in lengths ] # target outputs: 4 KTs with re-arranged keys (features), duplicates are allowed groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] # accessorial arguments to the op/kernel permutes, in_lengths, out_lengths = _multi_remap_to_groups( keys, lengths, groups ) # arguments outputs = torch.ops.fbgemm.permute_multi_embedding( values, # list of tensors (on device) permutes.to(device=torch.device("cuda")), # tensor on device out_lengths.tolist(), # List[int] on CPU in_lengths.to(device=torch.device("cuda")), # tensor on device out_lengths.to(device=torch.device("cuda")), # tensor on device ) ``` * values ``` permutes = tensor( [ [0, 0, 0, 0, 3, 4], # f1 [1, 0, 0, 3, 5, 0], # f3 [0, 1, 3, 0, 4, 0], # f2 [1, 2, 5, 0, 6, 0], # f4 [0, 2, 0, 6, 3, -6], # f1 [2, 2, 0, 9, 8, 0], # f6 [0, 3, 0, 0, 3, -8], # f1 [1, 3, 11, 3, 7, 0], # f5 ] ) ``` # details 1. from the above example usage, we can clean see that the operatior takes in the following: a) values: List[torch.Tensor], which represents the input KTs b) permutes: torch.Tensor, which contains the permute information, will be explained later c) output_lengths_list: List[int], the lengths of the output tensors (KTs), which is needed to allocate memory on device ahead d) in_lengths: torch.Tensor, lengths of input tensors, which is on device e) out_lengths: torch.Tensor, lengths of output tensors, which is on device 2. the operator returns a list of tensors, which represents the permuted KTs 3. `permute` is the most critical argument in this operator: a) 2-D tensor b) each row represents key (feature) permute move c) a permute move = [input_tensor_id, output_tensor_id, input_start_idx, output_start_idx, feature_length, jump] d) jump is used in backward when a key (feature) from the input tensor is mapped to multiple places in the output tensors # performance notes The good: 1. the algorithm is designed in a way that it doesn't need to know in advance whether the 1-to-N mapping exists in the permutes. 2. `_all_keys_used_once` is no longer needed 3. no longer need a torch.cat before calling the old operator The same bad: 1. it requires several HtoD communications (move tensor to device): a) 3 tensors, which are `permutes`, `input_lengths`, and `output_lengths`. Those tensors needs to be on the device so that the cuda kernels has access to it. b) 2 lists of (scalar_t*) pointers, input and output tensor lists. c) Didn't find a good way to let the kernel knows the address of the lists of input/output tensors, because the lists are also need to be on the device. 2. tensor.contiguous for the backward function, it looks like the grad from the backward are somehow not contiguous Differential Revision: D57055616
- Loading branch information