diff --git a/torchrec/distributed/model_parallel.py b/torchrec/distributed/model_parallel.py
index 6eff1c3a0..0fae8d51c 100644
--- a/torchrec/distributed/model_parallel.py
+++ b/torchrec/distributed/model_parallel.py
@@ -219,7 +219,6 @@ def __init__(
         init_data_parallel: bool = True,
         init_parameters: bool = True,
         data_parallel_wrapper: Optional[DataParallelWrapper] = None,
-        separate_pg_for_ddp: bool = False,
     ) -> None:
         super().__init__()
         torch._C._log_api_usage_once(f"torchrec.distributed.{self.__class__.__name__}")
@@ -233,11 +232,6 @@ def __init__(
             assert pg is not None, "Process group is not initialized"
             env = ShardingEnv.from_process_group(pg)
         self._env: ShardingEnv = env
-        if separate_pg_for_ddp:
-            ddp_pg = dist.new_group()
-            self.ddp_env: ShardingEnv = ShardingEnv.from_process_group(ddp_pg)
-        else:
-            self.ddp_env: ShardingEnv = self._env

         if device is None:
             device = torch.device("cpu")
@@ -309,7 +303,7 @@ def init_data_parallel(self) -> None:
         # Allocate any 'meta' tensors
         if self.init_parameters:
             self._init_parameters(self._dmp_wrapped_module)
-        self._data_parallel_wrapper.wrap(self, self.ddp_env, self.device)
+        self._data_parallel_wrapper.wrap(self, self._env, self.device)
         self._ddp_wrapped = True

    def copy(
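
For context, a minimal construction sketch (not part of the patch) of how DistributedModelParallel is used after this change: there is no separate_pg_for_ddp argument anymore, so the DDP wrap performed by init_data_parallel() reuses the module's own ShardingEnv (and therefore its process group). MyModel is a hypothetical stand-in module.

    # Sketch only; assumes torch.distributed has already been set up via init_process_group.
    import torch
    import torch.distributed as dist
    from torchrec.distributed.model_parallel import DistributedModelParallel

    device = (
        torch.device("cuda", torch.cuda.current_device())
        if torch.cuda.is_available()
        else torch.device("cpu")
    )
    model = MyModel()  # hypothetical nn.Module containing sharded TorchRec modules

    # env defaults to ShardingEnv.from_process_group(dist.GroupMember.WORLD);
    # with this patch the same env is passed to the data-parallel wrapper.
    dmp = DistributedModelParallel(
        module=model,
        device=device,
        init_data_parallel=True,  # wraps the remaining dense parameters with DDP
    )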