-
Notifications
You must be signed in to change notification settings - Fork 479
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
implementation of fbgemm op - permute_multi_embedding (#2120)
Summary: X-link: pytorch/FBGEMM#2738 Pull Request resolved: #2120 # context * current we have a working function `permute_pooled_embs_auto_grad` to do a full permute of KTs, including forward and backward * it has several limitations: a) it has to be a full permute, duplicates are not supported; b) in the main [use case](https://fburl.com/code/89od0rqm) there has to be a torch.concat on the input KTs, which is not very efficient; c) the function output a single KT which requires a split operation * there is some attempt to support duplicated outputs, but the backward doesn't work * this diff is trying to create a new kernel (named `permute_multi_embedding`) to support a multiple-KT to multiple-KT mapping operation with backward support # notes * this diff focuses on the implemenation and test of the operator * performance analysis and benchmark are in the next diff # operator example usage * used in python ``` # test inputs: 3 KTs with batch_size=2048 batch_size = 2048 keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] lengths = [[96, 256], [512, 128, 768], [1024]] values = [ torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True) for lens in lengths ] # target outputs: 4 KTs with re-arranged keys (features), duplicates are allowed groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] # accessorial arguments to the op/kernel permutes, in_lengths, out_lengths = _multi_remap_to_groups( keys, lengths, groups ) # arguments outputs = torch.ops.fbgemm.permute_multi_embedding_internal_testing( values, permutes, in_lengths, out_lengths ) ``` * permutes ``` # each row represents a key (feature) permute move, which consists of the following parameters: # [input_tensor_idx, output_tensor_idx, input_key_idx, output_key_idx, key_length, magic_jump] permutes = tensor( [ [0, 0, 0, 0, 3, 4], # f1 [1, 0, 0, 3, 5, 0], # f3 [0, 1, 3, 0, 4, 0], # f2 [1, 2, 5, 0, 6, 0], # f4 [0, 2, 0, 6, 3, -6], # f1 [2, 2, 0, 9, 8, 0], # f6 [0, 3, 0, 0, 3, -8], # f1 [1, 3, 11, 3, 7, 0], # f5 ] ) ``` # details 1. from the above example usage, we can clearly see that the operatior takes in the following: a) values: List[torch.Tensor], which represents the input KTs b) permutes: torch.Tensor, which contains the permute information, will be explained later c) output_lengths_list: List[int], the lengths of the output tensors (KTs), which is needed to allocate memory on device ahead d) in_lengths: torch.Tensor, lengths of input tensors, which is on device e) out_lengths: torch.Tensor, lengths of output tensors, which is on device 2. the operator returns a list of tensors, which represents the permuted KTs 3. `permute` is the most critical argument in this operator: a) 2-D tensor b) each row represents a key (feature) permute move c) a permute move = [input_tensor_id, output_tensor_id, input_start_idx, output_start_idx, feature_length, jump] d) jump is used in backward when a key (feature) from the input tensor is mapped to multiple places in the output tensors 4. The magic_jump a) It's only used in the backward computation b) it's usually 0, means no jump c) it's non-zero when there is a duplicate in the permute, e.g., the same feature appears more than once in the output d) the `magic_jump` is the next index of the very same feature in the permute sequence with some modifications e) modification-1: `magic_jump` is positive when it's the first of its kind [Start] f) modification-2: `magic_jump` is negative when it's not the first of its kind [Continue] g) modification-3: `magic_jump` is the negative value of the length of the permute sequence when it's the last of its kind. [Stop] Reviewed By: sryap Differential Revision: D57055616 fbshipit-source-id: 16673d3a2eafab93b08d4ff3c43d54366966064a
- Loading branch information
1 parent
90f3054
commit 3db28b3
Showing
2 changed files
with
295 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters