diff --git a/torchrec/modules/embedding_modules.py b/torchrec/modules/embedding_modules.py
index 0fd0dcc45..e1d48972e 100644
--- a/torchrec/modules/embedding_modules.py
+++ b/torchrec/modules/embedding_modules.py
@@ -163,6 +163,7 @@ def __init__(
         self._device: torch.device = (
             device if device is not None else torch.device("cpu")
         )
+        self._dtypes: List[int] = []
 
         table_names = set()
         for embedding_config in tables:
@@ -182,6 +183,7 @@ def __init__(
                 include_last_offset=True,
                 dtype=dtype,
             )
+            self._dtypes.append(embedding_config.data_type.value)
 
             if not embedding_config.feature_names:
                 embedding_config.feature_names = [embedding_config.name]
@@ -217,10 +219,19 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
         for i, embedding_bag in enumerate(self.embedding_bags.values()):
             for feature_name in self._feature_names[i]:
                 f = feature_dict[feature_name]
+                per_sample_weights: Optional[torch.Tensor] = None
+                if self._is_weighted:
+                    per_sample_weights = (
+                        f.weights().half()
+                        if self._dtypes[i] == DataType.FP16.value
+                        else f.weights()
+                    )
                 res = embedding_bag(
                     input=f.values(),
                     offsets=f.offsets(),
-                    per_sample_weights=f.weights() if self._is_weighted else None,
+                    per_sample_weights=per_sample_weights
+                    if self._is_weighted
+                    else None,
                 ).float()
                 pooled_embeddings.append(res)
         return KeyedTensor(
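
For context (outside the patch itself): PyTorch's embedding_bag kernels expect per_sample_weights to have the same dtype as the bag's weight, so a weighted lookup into an FP16 table fails if the fp32 weights carried by the KeyedJaggedTensor are passed through unchanged; the added half() cast above guards against exactly that. Below is a minimal, self-contained sketch of the same cast against a plain nn.EmbeddingBag. The sizes and tensors are made up for illustration, and it assumes a PyTorch build where fp16 embedding_bag is supported on the target device.

```python
# Standalone illustration (hypothetical sizes), not part of the patch.
import torch
import torch.nn as nn

# An FP16 table, mirroring the dtype=torch.float16 branch in __init__.
bag = nn.EmbeddingBag(
    num_embeddings=10,
    embedding_dim=4,
    mode="sum",
    include_last_offset=True,
    dtype=torch.float16,
)

values = torch.tensor([1, 2, 4, 4, 3])  # flattened ids for two bags
offsets = torch.tensor([0, 2, 5])       # include_last_offset-style offsets
weights = torch.rand(5)                 # fp32 per-sample weights, as a KJT stores them

# Cast the weights to half so they match the table's dtype, as the diff does
# for DataType.FP16 tables; passing the fp32 tensor directly would trip a
# dtype-mismatch error inside the embedding_bag kernel.
pooled = bag(values, offsets, per_sample_weights=weights.half()).float()
print(pooled.shape)  # torch.Size([2, 4])
```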