From 0f4fee68e1f8a14e199004ab129182dbcdafee1f Mon Sep 17 00:00:00 2001
From: Xing Liu
Date: Fri, 1 Mar 2024 22:58:08 -0800
Subject: [PATCH] Support output dtype for embeddings (#1744)

Summary:
As title

Reviewed By: jiaqizhai

Differential Revision: D54337769

Privacy Context Container: L1183554
---
 torchrec/distributed/batched_embedding_kernel.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py
index 3fa3dd499..402ecafb2 100644
--- a/torchrec/distributed/batched_embedding_kernel.py
+++ b/torchrec/distributed/batched_embedding_kernel.py
@@ -666,6 +666,8 @@ def __init__(
         super().__init__(config, pg, device)
 
         weights_precision = data_type_to_sparse_type(config.data_type)
+        fused_params = config.fused_params or {}
+        output_dtype = fused_params.get("output_dtype", SparseType.FP32)
         self._emb_module: DenseTableBatchedEmbeddingBagsCodegen = (
             DenseTableBatchedEmbeddingBagsCodegen(
                 list(zip(self._local_rows, self._local_cols)),
@@ -675,6 +677,7 @@ def __init__(
                 or device.type == "cpu"
                 or not torch.cuda.is_available(),
                 weights_precision=weights_precision,
+                output_dtype=output_dtype,
             )
         )
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = dict(
@@ -958,6 +961,8 @@ def __init__(
         super().__init__(config, pg, device)
 
         weights_precision = data_type_to_sparse_type(config.data_type)
+        fused_params = config.fused_params or {}
+        output_dtype = fused_params.get("output_dtype", SparseType.FP32)
         self._emb_module: DenseTableBatchedEmbeddingBagsCodegen = (
             DenseTableBatchedEmbeddingBagsCodegen(
                 list(zip(self._local_rows, self._local_cols)),
@@ -967,6 +972,7 @@ def __init__(
                 or device.type == "cpu"
                 or not torch.cuda.is_available(),
                 weights_precision=weights_precision,
+                output_dtype=output_dtype,
             )
         )
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = dict(
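
For context, a minimal sketch of what the two added lines do in each patched kernel: the config's optional fused_params dict is treated as empty when None, and "output_dtype" falls back to SparseType.FP32 when unset. The resolve_output_dtype helper below is hypothetical (not a torchrec API) and only assumes fbgemm_gpu's SparseType is importable.

# Sketch only: mirrors the patch's lookup of "output_dtype" from fused_params.
from typing import Any, Dict, Optional

from fbgemm_gpu.split_embedding_configs import SparseType


def resolve_output_dtype(fused_params: Optional[Dict[str, Any]]) -> SparseType:
    # Hypothetical helper, not part of the patch: treat a missing fused_params
    # dict as empty and fall back to FP32, as the patched kernels do.
    fused_params = fused_params or {}
    return fused_params.get("output_dtype", SparseType.FP32)


# Requesting FP16 activations from the batched embedding kernel.
assert resolve_output_dtype({"output_dtype": SparseType.FP16}) == SparseType.FP16
# Unset fused_params keeps the FP32 default, i.e. the pre-patch behavior.
assert resolve_output_dtype(None) == SparseType.FP32

Defaulting to SparseType.FP32 means configs that never set "output_dtype" in fused_params keep producing FP32 embedding outputs, so the change is backward compatible.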