Use KTRegroupAsDict to Replace KeyedTensor.regroup_as_dict in EBC_sparse_arch (pytorch#2272)

Summary:
Pull Request resolved: pytorch#2272

# context
* Currently in APF a class method `KeyedTensor.regroup_as_dict` is used for permuting and regrouping the pooled embeddings.
* This function is not very efficient in training because every call recalculates the metadata arguments needed by the fbgemm operator. Moreover, it also needs a host-to-device data transfer to move the metadata tensors to GPU: [codepointer](https://fburl.com/code/fmhkg6sn)
```
# metadata calculation for permute, offsets, etc.; all are tensors
permute, inv_permute, offsets, inv_offsets, splits = _remap_to_groups(
    keys, lengths, groups
)
values = torch.concat(values, dim=1)
device = values.device
permuted_values = torch.ops.fbgemm.permute_pooled_embs_auto_grad(
    values,
    _pin_and_move(offsets, device),  # needs to use pinned memory and
    _pin_and_move(permute, device),  # initiates several H2D transfers
    _pin_and_move(inv_offsets, device),
    _pin_and_move(inv_permute, device),
)
```
* However, these metadata tensors do not change during training, so we can cache the results from the first batch and re-use them.
* The recommended usage is `KTRegroupAsDict`, a module for the permuting and regrouping work. This module stores the metadata tensors as instance variables during the first forward pass, then re-uses them directly afterwards.
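The cache-on-first-forward pattern described above can be sketched in plain Python. This is a minimal illustration, not the torchrec implementation: lists stand in for tensors, and the names `compute_metadata` and `RegroupAsDict` are illustrative, not the actual API.

```python
def compute_metadata(keys, lengths, groups):
    """Map each key to its (start, end) slice in the concatenated values,
    then list, per group, the slices to gather in group order."""
    offsets = {}
    start = 0
    for key, length in zip(keys, lengths):
        offsets[key] = (start, start + length)
        start += length
    return {name: [offsets[k] for k in ks] for name, ks in groups.items()}


class RegroupAsDict:
    """Regroup concatenated per-key values into named groups, caching the
    slice metadata computed on the first call. The cache is valid across
    batches because each feature's length is fixed by the config."""

    def __init__(self, groups):
        self.groups = groups
        self._metadata = None  # filled on the first forward pass

    def __call__(self, keys, lengths, values):
        if self._metadata is None:  # first batch: compute and cache
            self._metadata = compute_metadata(keys, lengths, self.groups)
        # later batches skip straight to the gather
        return {
            name: [v for s, e in slices for v in values[s:e]]
            for name, slices in self._metadata.items()
        }
```

For example, with `keys = ["f1", "f2", "f3"]`, `lengths = [2, 1, 2]`, and `groups = {"a": ["f1", "f3"], "b": ["f2"]}`, calling the module on `values = [1, 2, 3, 4, 5]` yields `{"a": [1, 2, 4, 5], "b": [3]}`, and the second call re-uses the cached slices. In the real module the cached objects are GPU-resident tensors, which is what saves the repeated `_pin_and_move` H2D transfers.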
NOTE: This `KTRegroupAsDict` module is also IR-compatible with the custom-op approach in D57578012.

# numbers with IG FM model
* [baseline trace](https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree%2Ftraces%2Fdynocli%2Faps-ig_ctr_aps_vanilla_baseline-a6e954af60%2F1%2Frank-0.Aug_04_00_01_16.3621.pt.trace.json.gz&bucket=aps_traces), [experimental trace](https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree%2Ftraces%2Fdynocli%2Faps-ig_ctr_aps_vanilla_both_diffs-3503324cb2%2F4%2Frank-0.Aug_03_20_16_01.3608.pt.trace.json.gz&bucket=aps_traces)
* the GPU runtime saving is about 1ms, which is consistent with the benchmark results; the overall duration per batch is 200ms, so the improvement is about 0.5% {F1792260523}
* the CPU runtime saving is 1.3ms {F1792260686}

# additional notes
1. we don't expect any change in results; the two approaches should produce exactly the same output.
2. `regroup_as_dict` works on pooled embeddings, which are the results of the embedding table lookup, so each feature's length is constant as defined in the config.

Reviewed By: really121

Differential Revision: D43405610