diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index dc4033b4e..090544a48 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -1156,6 +1156,10 @@ def __init__( **ssd_tbe_params, ).to(device) + logger.info( + f"tbe_unique_id:{self._emb_module.tbe_unique_id} => table name to count dict:{self.table_name_to_count}" + ) + self._optim: KeyValueEmbeddingFusedOptimizer = KeyValueEmbeddingFusedOptimizer( config, self._emb_module,