diff --git a/torchrec/distributed/test_utils/multi_process.py b/torchrec/distributed/test_utils/multi_process.py index 7542ac853..3580eda0b 100644 --- a/torchrec/distributed/test_utils/multi_process.py +++ b/torchrec/distributed/test_utils/multi_process.py @@ -36,16 +36,15 @@ def __init__( self.backend = backend self.local_size = local_size - if backend == "nccl": - device = torch.device(f"cuda:{rank}") - torch.cuda.set_device(device) - else: - device = torch.device("cpu") - self.device: torch.device = device - torch.use_deterministic_algorithms(True) if torch.cuda.is_available(): + self.device: torch.device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(self.device) + torch.backends.cudnn.allow_tf32 = False torch.backends.cuda.matmul.allow_tf32 = False + else: + self.device: torch.device = torch.device("cpu") + torch.use_deterministic_algorithms(True) self.pg: Optional[dist.ProcessGroup] = None def __enter__(self) -> "MultiProcessContext": diff --git a/torchrec/distributed/test_utils/test_model_parallel.py b/torchrec/distributed/test_utils/test_model_parallel.py index 778b2c875..c8dc3c2d1 100644 --- a/torchrec/distributed/test_utils/test_model_parallel.py +++ b/torchrec/distributed/test_utils/test_model_parallel.py @@ -29,7 +29,7 @@ class ModelParallelTestShared(MultiProcessTestBase): @seed_and_log - def setUp(self) -> None: + def setUp(self, backend: str = "nccl") -> None: super().setUp() num_features = 4 @@ -76,12 +76,11 @@ def setUp(self) -> None: for feature in table.feature_names ] } + self.backend = backend if torch.cuda.is_available(): self.device = torch.device("cuda") - self.backend = "nccl" else: self.device = torch.device("cpu") - self.backend = "gloo" def _test_sharding( self, @@ -122,8 +121,8 @@ def _test_sharding( @skip_if_asan_class class ModelParallelBase(ModelParallelTestShared): - def setUp(self) -> None: - super().setUp() + def setUp(self, backend: str = "nccl") -> None: + super().setUp(backend=backend) @unittest.skipIf( torch.cuda.device_count() <= 1, diff --git a/torchrec/distributed/test_utils/test_model_parallel_base.py b/torchrec/distributed/test_utils/test_model_parallel_base.py index ded796f6b..47fff68c3 100644 --- a/torchrec/distributed/test_utils/test_model_parallel_base.py +++ b/torchrec/distributed/test_utils/test_model_parallel_base.py @@ -178,10 +178,7 @@ def _test_sharded_forward( class ModelParallelSparseOnlyBase(unittest.TestCase): - def tearDown(self) -> None: - dist.destroy_process_group() - - def test_sharding_ebc_as_top_level(self) -> None: + def setUp(self, backend: str = "nccl") -> None: os.environ["RANK"] = "0" os.environ["WORLD_SIZE"] = "1" os.environ["LOCAL_WORLD_SIZE"] = "1" @@ -189,15 +186,19 @@ def test_sharding_ebc_as_top_level(self) -> None: os.environ["MASTER_PORT"] = str(get_free_port()) os.environ["NCCL_SOCKET_IFNAME"] = "lo" + self.backend = backend if torch.cuda.is_available(): - curr_device = torch.device("cuda:0") - torch.cuda.set_device(curr_device) - backend = "nccl" + self.device = torch.device("cuda:0") + torch.cuda.set_device(self.device) else: - curr_device = torch.device("cpu") - backend = "gloo" - dist.init_process_group(backend=backend) + self.device = torch.device("cpu") + dist.init_process_group(backend=self.backend) + + def tearDown(self) -> None: + dist.destroy_process_group() + + def test_sharding_ebc_as_top_level(self) -> None: embedding_dim = 128 num_embeddings = 256 ebc = EmbeddingBagCollection( @@ -213,27 +214,11 @@ def test_sharding_ebc_as_top_level(self) -> None: ], ) - model = DistributedModelParallel(ebc, device=curr_device) + model = DistributedModelParallel(ebc, device=self.device) self.assertTrue(isinstance(model.module, ShardedEmbeddingBagCollection)) def test_sharding_fused_ebc_as_top_level(self) -> None: - os.environ["RANK"] = "0" - os.environ["WORLD_SIZE"] = "1" - os.environ["LOCAL_WORLD_SIZE"] = "1" - os.environ["MASTER_ADDR"] = str("localhost") - os.environ["MASTER_PORT"] = str(get_free_port()) - os.environ["NCCL_SOCKET_IFNAME"] = "lo" - - if torch.cuda.is_available(): - curr_device = torch.device("cuda:0") - torch.cuda.set_device(curr_device) - backend = "nccl" - else: - curr_device = torch.device("cpu") - backend = "gloo" - dist.init_process_group(backend=backend) - embedding_dim = 128 num_embeddings = 256 ebc = FusedEmbeddingBagCollection( @@ -251,26 +236,26 @@ def test_sharding_fused_ebc_as_top_level(self) -> None: optimizer_kwargs={"lr": 0.02}, ) - model = DistributedModelParallel(ebc, device=curr_device) + model = DistributedModelParallel(ebc, device=self.device) self.assertTrue(isinstance(model.module, ShardedFusedEmbeddingBagCollection)) class ModelParallelStateDictBase(unittest.TestCase): - def setUp(self) -> None: + def setUp(self, backend: str = "nccl") -> None: os.environ["RANK"] = "0" os.environ["WORLD_SIZE"] = "1" os.environ["LOCAL_WORLD_SIZE"] = "1" os.environ["MASTER_ADDR"] = str("localhost") os.environ["MASTER_PORT"] = str(get_free_port()) os.environ["NCCL_SOCKET_IFNAME"] = "lo" + + self.backend = backend if torch.cuda.is_available(): self.device = torch.device("cuda:0") - backend = "nccl" torch.cuda.set_device(self.device) else: self.device = torch.device("cpu") - backend = "gloo" dist.init_process_group(backend=backend) num_features = 4 @@ -377,27 +362,28 @@ def __init__(self, device: str, val: float) -> None: def reset_parameters(self) -> None: nn.init.constant_(self.p, self.val) - dist.destroy_process_group() - dist.init_process_group(backend="gloo") - # Check that already allocated parameters are left 'as is'. - cpu_model = MyModel(device="cpu", val=3.2) + unsharded_model = MyModel(device=self.device, val=3.2) sharded_model = DistributedModelParallel( - cpu_model, + unsharded_model, + device=self.device, ) sharded_param = next(sharded_model.parameters()) np.testing.assert_array_equal( - np.array([3.2, 3.2, 3.2], dtype=np.float32), sharded_param.detach().numpy() + np.array([3.2, 3.2, 3.2], dtype=np.float32), + sharded_param.detach().cpu().numpy(), ) # Check that parameters over 'meta' device are allocated and initialized. meta_model = MyModel(device="meta", val=7.5) sharded_model = DistributedModelParallel( meta_model, + device=self.device, ) sharded_param = next(sharded_model.parameters()) np.testing.assert_array_equal( - np.array([7.5, 7.5, 7.5], dtype=np.float32), sharded_param.detach().numpy() + np.array([7.5, 7.5, 7.5], dtype=np.float32), + sharded_param.detach().cpu().numpy(), ) def test_meta_device_dmp_state_dict(self) -> None: