diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index b74b726a0..c24b912d8 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -624,6 +624,9 @@ def step(self, closure: Any = None) -> None: def set_optimizer_step(self, step: int) -> None: self._emb_module.set_optimizer_step(step) + def update_hyper_parameters(self, params_dict: Dict[str, Any]) -> None: + self._emb_module.update_hyper_parameters(params_dict) + def _gen_named_parameters_by_table_ssd( emb_module: SSDTableBatchedEmbeddingBags, diff --git a/torchrec/distributed/global_settings.py b/torchrec/distributed/global_settings.py index fd86ac4bb..6d13ffd3e 100644 --- a/torchrec/distributed/global_settings.py +++ b/torchrec/distributed/global_settings.py @@ -30,3 +30,7 @@ def construct_sharded_tensor_from_metadata_enabled() -> bool: return ( os.environ.get(TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA_ENV, "0") == "1" ) + + +def enable_construct_sharded_tensor_from_metadata() -> None: + os.environ[TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA_ENV] = "1"