Fix dtype mismatch between weight and per_sample_weights (#1758)
Summary:
Pull Request resolved: #1758

# Background

T176105639

| case | embedding bag weight | per_sample_weights | nn.EmbeddingBag, device="cpu" | nn.EmbeddingBag, device="cuda" | nn.EmbeddingBag, device="meta" | fbgemm lookup |
|---|---|---|---|---|---|---|
| A | fp32 | fp32 | good | good | good | good |
| B | fp16 | fp32 | Error: `Expected tensor for argument #1 'weight' to have the same type as tensor for argument #1 'per_sample_weights'; but type torch.HalfTensor does not equal torch.FloatTensor` | Error: `expected scalar type Half but found Float` | fails the [check](https://fburl.com/code/ng9pv1vp) that forces weight dtype == per_sample_weights dtype | good |
| C | fp16 | fp16 | good | good | good | good now with D54370192 (previous error: P1046999270, `RuntimeError: expected scalar type Float but found Half` from the fbgemm call) |

Notebook showing the nn.EmbeddingBag forward errors: N5007274 (a minimal repro sketch is also included at the end of this summary).

Currently we are in case A: users have to set `use_fp32_embedding` in training to force the embedding bag dtype to be fp32. What users actually want is to use fp16 as the embedding bag weight to reduce memory usage. However, when they remove `use_fp32_embedding`, they hit the [check that forces weight dtype == per_sample_weights dtype](https://www.internalfb.com/code/fbsource/[e750b9f69f7f758682000804409456103510078c]/fbcode/caffe2/torch/_meta_registrations.py?lines=3521-3524) in meta_registrations (case B).

Therefore, this diff aims to achieve case C: make the embedding module weight and per_sample_weights share the same dtype. Now that the fbgemm lookup backend supports Half per_sample_weights (D54370192), this diff introduces a `dtype` argument in all feature processor classes and initializes per_sample_weights with the passed dtype.

# Reference diffs to resolve this issue

Diff 1: D52591217. This passes the embedding bag dtype to the feature processor so that per_sample_weights gets the same dtype as the embedding bag weight. However, `is_meta` also had to be passed, because fbgemm did not yet support fp16 per_sample_weights (see the table above), so per_sample_weights could only be made fp16 when the tensor was on the meta device. That solution required too many hacks.

Diff 2: D53232739. Essentially the same as Diff 1 (D52591217), except the hack is added in the TorchRec library: an `if` in EBC and PEA forces per_sample_weights to fp16 whenever the embedding bag weight is fp16.

Reviewed By: henrylhtsang

Differential Revision: D54526190

fbshipit-source-id: 969cb64c4af345ea222e8a2c7e5be0d9af0d0ae3
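A minimal repro sketch of the table above. Sizes, names, and the printed output are illustrative and not taken from the diff; exact error messages and fp16-on-CPU support may vary across PyTorch versions, so this only mirrors the behavior the table documents:

```python
import torch
import torch.nn as nn

# Illustrative sizes; not taken from the diff.
num_embeddings, embedding_dim = 10, 4
indices = torch.tensor([1, 2, 3, 4])
offsets = torch.tensor([0, 2])

# Case A: fp32 weight + fp32 per_sample_weights -- works everywhere.
eb_fp32 = nn.EmbeddingBag(num_embeddings, embedding_dim, mode="sum")
out_a = eb_fp32(indices, offsets, per_sample_weights=torch.rand(4))

# Case B: fp16 weight + fp32 per_sample_weights -- the dtype-mismatch
# error described in the table above.
eb_fp16 = nn.EmbeddingBag(num_embeddings, embedding_dim, mode="sum").half()
try:
    eb_fp16(indices, offsets, per_sample_weights=torch.rand(4))
except RuntimeError as err:
    print(f"case B fails: {err}")

# Case C (what this diff enables in TorchRec): per_sample_weights is
# initialized with the same dtype as the embedding bag weight.
out_c = eb_fp16(
    indices, offsets, per_sample_weights=torch.rand(4, dtype=torch.float16)
)
print(out_a.dtype, out_c.dtype)  # torch.float32 torch.float16
```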