Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
support meta device compatability (#1740)
Summary: Pull Request resolved: #1740 Adding meta device support to KeyedJaggedTensor to make Ads APF models compatible with meta device. This allows calculating FLOPs for models locally without allocating memory on cpu/gpu. This diff makes the following changes: * passes `values` tensor to `_maybe_compute_length_per_key` * when on meta device, creates dummy values for `length_per_key` that still sum to the size of `values` tensor so that downstream torch operators are valid. The changes resolve the following error when running model forward with meta tensors (full trace P1191019670) ``` torch.sum(lengths.view(-1, stride), dim=1).tolist() NotImplementedError: Cannot copy out of meta tensor; no data! ``` ## Additional Context Loading the model on meta device allows us to load large FM models locally since meta tensors don't have values. See snippet below of a meta tensor which doesn't have values but contains valid `size` attribute which are used for FLOPs calculations: ``` torch.tensor([1, 2, 3], device=torch.device("meta")) >>> tensor(..., device='meta', size=(3,), dtype=torch.int64) ``` By creating dummy values for `length_per_key` that sum to the total length of the `values` tensor we make KJT operations valid. Furthermore, model embedding ops care about the relationship between the shape of indices and the values of lengths, but since there is no values in meta tensor, the shape stops mattering. The total number of operations remain the same as total length of `values` tensor is preserved. The test plan shows FLOPs calculations remain the same if model were loaded on cpu vs. meta device. See design doc for more detailed context: https://docs.google.com/document/d/1DteShg9A8Nts3OTu2SrSxMp-0W2spZ9-v1br73z0LIQ/edit Reviewed By: joshuadeng Differential Revision: D54403867 fbshipit-source-id: a09ad16e269d241f66b2cba3377d878db65b581d
- Loading branch information