Fix dtype mismatch between weight and per_sample_weights (#1758)
Summary:
Pull Request resolved: #1758

# Background
T176105639
|case|embedding bag weight|per_sample_weights|nn.EmbeddingBag, device="cpu"|nn.EmbeddingBag, device="cuda"|nn.EmbeddingBag, device="meta"|fbgemm lookup|
|---|---|---|---|---|---|---|
|A|fp32|fp32|good|good|good|good|
|B|fp16|fp32|Error: "Expected tensor for argument #1 'weight' to have the same type as tensor for argument #1 'per_sample_weights'; but type torch.HalfTensor does not equal torch.FloatTensor"|Error: "expected scalar type Half but found Float"|fails the [check](https://fburl.com/code/ng9pv1vp) that forces weight dtype == per_sample_weights dtype|good|
|C|fp16|fp16|good|good|good|good now with D54370192. Previous error: P1046999270, RuntimeError: "expected scalar type Float but found Half" from the fbgemm call|

Notebook showing the nn.EmbeddingBag forward errors: N5007274.
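
For quick local reproduction, here is a minimal sketch of cases B and C with plain nn.EmbeddingBag (illustrative only; it assumes a PyTorch version where `nn.EmbeddingBag` accepts a `dtype` argument, and the tensor contents are made up):

```python
import torch
import torch.nn as nn

# fp16 embedding bag weight, as in cases B and C of the table above.
eb = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode="sum", dtype=torch.float16)

values = torch.tensor([1, 2, 3, 4])
offsets = torch.tensor([0, 2])          # two bags of two values each
weights_fp32 = torch.rand(4)            # case B: fp32 per_sample_weights
weights_fp16 = weights_fp32.half()      # case C: fp16 per_sample_weights

try:
    eb(values, offsets, per_sample_weights=weights_fp32)     # case B: dtype mismatch
except RuntimeError as err:
    print("case B:", err)

out = eb(values, offsets, per_sample_weights=weights_fp16)   # case C: dtypes agree
print("case C:", out.dtype)  # torch.float16
```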

Currently we are in case A: users must set `use_fp32_embedding` in training to force the embedding bag dtype to fp32. What users actually want is case B, with fp16 embedding bag weights to reduce memory usage. But when they delete `use_fp32_embedding`, they fail the [check that forces weight dtype == per_sample_weights dtype](https://www.internalfb.com/code/fbsource/[e750b9f69f7f758682000804409456103510078c]/fbcode/caffe2/torch/_meta_registrations.py?lines=3521-3524) in the meta registration.

Therefore, this diff aims for case C: make the dtype of per_sample_weights match the embedding module weight. Now that the fbgemm lookup backend supports Half per_sample_weights (D54370192), this diff introduces a `dtype` argument in all feature processor classes and initializes per_sample_weights in the passed dtype.
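
As a rough illustration of that pattern (hypothetical class, not the actual TorchRec feature processor API; the name and arguments are made up), the processor receives the embedding weight dtype at construction and materializes per_sample_weights directly in that dtype, so no cast is needed downstream:

```python
import torch
from torch import nn

class PositionWeightedSketch(nn.Module):
    """Illustrative only: a feature-processor-like module that is told the
    embedding table dtype and emits per_sample_weights in that dtype."""

    def __init__(self, max_feature_length: int, dtype: torch.dtype = torch.float32) -> None:
        super().__init__()
        # Position weights are created in the embedding weight's dtype
        # (fp16 for case C, fp32 for case A).
        self.position_weight = nn.Parameter(torch.ones(max_feature_length, dtype=dtype))

    def forward(self, lengths: torch.Tensor) -> torch.Tensor:
        # One weight per value, indexed by its position within each bag.
        positions = torch.cat([torch.arange(int(n)) for n in lengths])
        return self.position_weight[positions]
```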

# Reference diffs to resolve this issue

Diff 1: D52591217
This passes the embedding bag dtype to the feature processor so that per_sample_weights gets the same dtype as the embedding bag weight. However, `is_meta` also has to be passed, because fbgemm did not yet support fp16 per_sample_weights (see the table above), so users were forced to make per_sample_weights fp16 only on the meta device. The solution requires too many hacks.

Diff 2: D53232739
Basically the same approach as diff 1 (D52591217), except that the hack lives in the TorchRec library: an `if` is added in EBC and PEA so that when the embedding bag weight is fp16, per_sample_weights is forced to fp16 as well.

Reviewed By: henrylhtsang

Differential Revision: D54526190

fbshipit-source-id: 969cb64c4af345ea222e8a2c7e5be0d9af0d0ae3
ge0405 authored and facebook-github-bot committed Mar 12, 2024
1 parent 1cd088f commit 5856c4d
13 changes: 12 additions & 1 deletion torchrec/modules/embedding_modules.py
@@ -162,6 +162,7 @@ def __init__(
         self.embedding_bags: nn.ModuleDict = nn.ModuleDict()
         self._embedding_bag_configs = tables
         self._lengths_per_embedding: List[int] = []
+        self._dtypes: List[int] = []

         table_names = set()
         for embedding_config in tables:
@@ -183,6 +184,7 @@ def __init__(
                 )
             if device is None:
                 device = self.embedding_bags[embedding_config.name].weight.device
+            self._dtypes.append(embedding_config.data_type.value)

             if not embedding_config.feature_names:
                 embedding_config.feature_names = [embedding_config.name]
@@ -219,10 +221,19 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
         for i, embedding_bag in enumerate(self.embedding_bags.values()):
             for feature_name in self._feature_names[i]:
                 f = feature_dict[feature_name]
+                per_sample_weights: Optional[torch.Tensor] = None
+                if self._is_weighted:
+                    per_sample_weights = (
+                        f.weights().half()
+                        if self._dtypes[i] == DataType.FP16.value
+                        else f.weights()
+                    )
                 res = embedding_bag(
                     input=f.values(),
                     offsets=f.offsets(),
-                    per_sample_weights=f.weights() if self._is_weighted else None,
+                    per_sample_weights=(
+                        per_sample_weights if self._is_weighted else None
+                    ),
                 ).float()
                 pooled_embeddings.append(res)
         return KeyedTensor(
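
A hedged usage sketch of the patched module, assuming the standard TorchRec `EmbeddingBagConfig`, `EmbeddingBagCollection`, and `KeyedJaggedTensor` signatures (the table and feature names are illustrative): with an FP16 table and fp32 KJT weights, the forward above casts per_sample_weights to half before calling the embedding bag.

```python
import torch
from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# One FP16 table with weighted pooling (case C end to end).
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="t1",
            num_embeddings=10,
            embedding_dim=4,
            feature_names=["f1"],
            data_type=DataType.FP16,
        )
    ],
    is_weighted=True,
)

features = KeyedJaggedTensor(
    keys=["f1"],
    values=torch.tensor([1, 2, 3, 4]),
    lengths=torch.tensor([2, 2]),
    weights=torch.rand(4),   # fp32; the patched forward casts it to half for the FP16 table
)

pooled = ebc(features)       # KeyedTensor; pooled values come back as fp32 (.float())
print(pooled["f1"].shape)    # torch.Size([2, 4])
```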
