diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py
index 3fa3dd499..402ecafb2 100644
--- a/torchrec/distributed/batched_embedding_kernel.py
+++ b/torchrec/distributed/batched_embedding_kernel.py
@@ -666,6 +666,8 @@ def __init__(
         super().__init__(config, pg, device)
 
         weights_precision = data_type_to_sparse_type(config.data_type)
+        fused_params = config.fused_params or {}
+        output_dtype = fused_params.get("output_dtype", SparseType.FP32)
         self._emb_module: DenseTableBatchedEmbeddingBagsCodegen = (
             DenseTableBatchedEmbeddingBagsCodegen(
                 list(zip(self._local_rows, self._local_cols)),
@@ -675,6 +677,7 @@ def __init__(
                 or device.type == "cpu"
                 or not torch.cuda.is_available(),
                 weights_precision=weights_precision,
+                output_dtype=output_dtype,
             )
         )
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = dict(
@@ -958,6 +961,8 @@ def __init__(
         super().__init__(config, pg, device)
 
         weights_precision = data_type_to_sparse_type(config.data_type)
+        fused_params = config.fused_params or {}
+        output_dtype = fused_params.get("output_dtype", SparseType.FP32)
         self._emb_module: DenseTableBatchedEmbeddingBagsCodegen = (
             DenseTableBatchedEmbeddingBagsCodegen(
                 list(zip(self._local_rows, self._local_cols)),
@@ -967,6 +972,7 @@ def __init__(
                 or device.type == "cpu"
                 or not torch.cuda.is_available(),
                 weights_precision=weights_precision,
+                output_dtype=output_dtype,
             )
         )
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = dict(
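
For context, a minimal sketch of the lookup behavior the two added lines introduce: `output_dtype` is read from the table config's `fused_params` and forwarded to `DenseTableBatchedEmbeddingBagsCodegen`, defaulting to FP32 when the key (or the whole dict) is absent. The standalone `fused_params` dict below is hypothetical; `SparseType` is the same fbgemm_gpu enum the diff references. This assumes fbgemm_gpu is installed.

```python
from fbgemm_gpu.split_embedding_configs import SparseType

# Hypothetical fused_params, as a planner/sharder might populate it.
fused_params = {"output_dtype": SparseType.FP16}

# Mirrors the added lines: fall back to FP32 when fused_params is None
# or "output_dtype" is not set.
output_dtype = (fused_params or {}).get("output_dtype", SparseType.FP32)
assert output_dtype == SparseType.FP16

# With no fused_params at all, the dense kernel keeps its prior FP32 output.
assert (None or {}).get("output_dtype", SparseType.FP32) == SparseType.FP32
```

The same two lines appear in both hunks because both dense batched kernels (the embedding and embedding-bag variants) construct their own `DenseTableBatchedEmbeddingBagsCodegen` and previously left `output_dtype` at the FBGEMM default.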