diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py
index b74b726a0..c24b912d8 100644
--- a/torchrec/distributed/batched_embedding_kernel.py
+++ b/torchrec/distributed/batched_embedding_kernel.py
@@ -624,6 +624,9 @@ def step(self, closure: Any = None) -> None:
     def set_optimizer_step(self, step: int) -> None:
         self._emb_module.set_optimizer_step(step)
 
+    def update_hyper_parameters(self, params_dict: Dict[str, Any]) -> None:
+        self._emb_module.update_hyper_parameters(params_dict)
+
 
 def _gen_named_parameters_by_table_ssd(
     emb_module: SSDTableBatchedEmbeddingBags,
diff --git a/torchrec/distributed/global_settings.py b/torchrec/distributed/global_settings.py
index fd86ac4bb..6d13ffd3e 100644
--- a/torchrec/distributed/global_settings.py
+++ b/torchrec/distributed/global_settings.py
@@ -30,3 +30,7 @@ def construct_sharded_tensor_from_metadata_enabled() -> bool:
     return (
         os.environ.get(TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA_ENV, "0") == "1"
     )
+
+
+def enable_construct_sharded_tensor_from_metadata() -> None:
+    os.environ[TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA_ENV] = "1"