diff --git a/torchrec/distributed/tests/test_fp_embeddingbag.py b/torchrec/distributed/tests/test_fp_embeddingbag.py
index 6cfed8715..efcdc20dd 100644
--- a/torchrec/distributed/tests/test_fp_embeddingbag.py
+++ b/torchrec/distributed/tests/test_fp_embeddingbag.py
@@ -335,18 +335,22 @@ def test_sharding_ebc(
             use_fp_collection=use_fp_collection,
         )

+    @unittest.skipIf(
+        torch.cuda.device_count() <= 1,
+        "Not enough GPUs, this test requires at least two GPUs",
+    )
     @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None)
     # pyre-ignore
-    @given(use_fp_collection=st.booleans())
-    def test_sharding_fp_ebc_from_meta(self, use_fp_collection: bool) -> None:
+    @given(use_fp_collection=st.booleans(), backend=st.sampled_from(["nccl", "gloo"]))
+    def test_sharding_fp_ebc_from_meta(
+        self, use_fp_collection: bool, backend: str
+    ) -> None:
         embedding_bag_config, kjt_input_per_rank = get_configs_and_kjt_inputs()

         self._run_multi_process_test(
             callable=_test_sharding_from_meta,
             world_size=2,
             tables=embedding_bag_config,
             sharder=FeatureProcessedEmbeddingBagCollectionSharder(),
-            backend="nccl"
-            if (torch.cuda.is_available() and torch.cuda.device_count() >= 2)
-            else "gloo",
+            backend=backend,
             use_fp_collection=use_fp_collection,
         )
diff --git a/torchrec/distributed/tests/test_mc_embedding.py b/torchrec/distributed/tests/test_mc_embedding.py
index 98732a284..98259684a 100644
--- a/torchrec/distributed/tests/test_mc_embedding.py
+++ b/torchrec/distributed/tests/test_mc_embedding.py
@@ -11,6 +11,7 @@

 import torch
 import torch.nn as nn
+from hypothesis import given, settings, strategies as st
 from torchrec.distributed.embedding import ShardedEmbeddingCollection
 from torchrec.distributed.mc_embedding import (
     ManagedCollisionEmbeddingCollectionSharder,
@@ -256,13 +257,14 @@ def _test_sharding_and_remapping(  # noqa C901

 @skip_if_asan_class
 class ShardedMCEmbeddingCollectionParallelTest(MultiProcessTestBase):
-
-    # pyre-ignore
     @unittest.skipIf(
         torch.cuda.device_count() <= 1,
         "Not enough GPUs, this test requires at least two GPUs",
     )
-    def test_uneven_sharding(self) -> None:
+    # pyre-ignore
+    @given(backend=st.sampled_from(["nccl"]))
+    @settings(deadline=20000)
+    def test_uneven_sharding(self, backend: str) -> None:
         WORLD_SIZE = 2

         embedding_config = [
@@ -285,15 +287,17 @@ def test_uneven_sharding(self) -> None:
             world_size=WORLD_SIZE,
             tables=embedding_config,
             sharder=ManagedCollisionEmbeddingCollectionSharder(),
-            backend="nccl",
+            backend=backend,
         )

-    # pyre-ignore
     @unittest.skipIf(
         torch.cuda.device_count() <= 1,
         "Not enough GPUs, this test requires at least two GPUs",
     )
-    def test_even_sharding(self) -> None:
+    # pyre-ignore
+    @given(backend=st.sampled_from(["nccl"]))
+    @settings(deadline=20000)
+    def test_even_sharding(self, backend: str) -> None:
         WORLD_SIZE = 2

         embedding_config = [
@@ -316,15 +320,17 @@ def test_even_sharding(self) -> None:
             world_size=WORLD_SIZE,
             tables=embedding_config,
             sharder=ManagedCollisionEmbeddingCollectionSharder(),
-            backend="nccl",
+            backend=backend,
         )

-    # pyre-ignore
     @unittest.skipIf(
         torch.cuda.device_count() <= 1,
         "Not enough GPUs, this test requires at least two GPUs",
     )
-    def test_sharding_zch_mc_ec(self) -> None:
+    # pyre-ignore
+    @given(backend=st.sampled_from(["nccl"]))
+    @settings(deadline=20000)
+    def test_sharding_zch_mc_ec(self, backend: str) -> None:

         WORLD_SIZE = 2
@@ -420,15 +426,17 @@ def test_sharding_zch_mc_ec(self) -> None:
             kjt_input_per_rank=kjt_input_per_rank,
             kjt_out_per_iter_per_rank=kjt_out_per_iter_per_rank,
             sharder=ManagedCollisionEmbeddingCollectionSharder(),
-            backend="nccl",
+            backend=backend,
         )

-    # pyre-ignore
     @unittest.skipIf(
         torch.cuda.device_count() <= 1,
         "Not enough GPUs, this test requires at least two GPUs",
     )
-    def test_sharding_zch_mch_mc_ec(self) -> None:
+    # pyre-ignore
+    @given(backend=st.sampled_from(["nccl"]))
+    @settings(deadline=20000)
+    def test_sharding_zch_mch_mc_ec(self, backend: str) -> None:

         WORLD_SIZE = 2
@@ -553,5 +561,5 @@ def test_sharding_zch_mch_mc_ec(self) -> None:
             kjt_input_per_rank=kjt_input_per_rank,
             kjt_out_per_iter_per_rank=kjt_out_per_iter_per_rank,
             sharder=ManagedCollisionEmbeddingCollectionSharder(),
-            backend="nccl",
+            backend=backend,
         )
diff --git a/torchrec/distributed/tests/test_mc_embeddingbag.py b/torchrec/distributed/tests/test_mc_embeddingbag.py
index dba476d08..55235ac13 100644
--- a/torchrec/distributed/tests/test_mc_embeddingbag.py
+++ b/torchrec/distributed/tests/test_mc_embeddingbag.py
@@ -11,6 +11,7 @@

 import torch
 import torch.nn as nn
+from hypothesis import given, settings, strategies as st
 from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
 from torchrec.distributed.mc_embeddingbag import (
     ManagedCollisionEmbeddingBagCollectionSharder,
@@ -264,12 +265,14 @@ def _test_sharding_and_remapping(  # noqa C901

 @skip_if_asan_class
 class ShardedMCEmbeddingBagCollectionParallelTest(MultiProcessTestBase):
-    # pyre-ignore
     @unittest.skipIf(
         torch.cuda.device_count() <= 1,
         "Not enough GPUs, this test requires at least two GPUs",
     )
-    def test_uneven_sharding(self) -> None:
+    # pyre-ignore
+    @given(backend=st.sampled_from(["nccl"]))
+    @settings(deadline=20000)
+    def test_uneven_sharding(self, backend: str) -> None:
         WORLD_SIZE = 2

         embedding_bag_config = [
@@ -292,15 +295,17 @@ def test_uneven_sharding(self) -> None:
             world_size=WORLD_SIZE,
             tables=embedding_bag_config,
             sharder=ManagedCollisionEmbeddingBagCollectionSharder(),
-            backend="nccl",
+            backend=backend,
         )

-    # pyre-ignore
     @unittest.skipIf(
         torch.cuda.device_count() <= 1,
         "Not enough GPUs, this test requires at least two GPUs",
     )
-    def test_even_sharding(self) -> None:
+    # pyre-ignore
+    @given(backend=st.sampled_from(["nccl"]))
+    @settings(deadline=20000)
+    def test_even_sharding(self, backend: str) -> None:
         WORLD_SIZE = 2

         embedding_bag_config = [
@@ -323,15 +328,17 @@ def test_even_sharding(self) -> None:
             world_size=WORLD_SIZE,
             tables=embedding_bag_config,
             sharder=ManagedCollisionEmbeddingBagCollectionSharder(),
-            backend="nccl",
+            backend=backend,
         )

-    # pyre-ignore
     @unittest.skipIf(
         torch.cuda.device_count() <= 1,
         "Not enough GPUs, this test requires at least two GPUs",
     )
-    def test_sharding_zch_mc_ebc(self) -> None:
+    # pyre-ignore
+    @given(backend=st.sampled_from(["nccl"]))
+    @settings(deadline=20000)
+    def test_sharding_zch_mc_ebc(self, backend: str) -> None:

         WORLD_SIZE = 2
@@ -427,15 +434,17 @@ def test_sharding_zch_mc_ebc(self) -> None:
             kjt_input_per_rank=kjt_input_per_rank,
             kjt_out_per_iter_per_rank=kjt_out_per_iter_per_rank,
             sharder=ManagedCollisionEmbeddingBagCollectionSharder(),
-            backend="nccl",
+            backend=backend,
         )

-    # pyre-ignore
     @unittest.skipIf(
         torch.cuda.device_count() <= 1,
         "Not enough GPUs, this test requires at least two GPUs",
     )
-    def test_sharding_zch_mch_mc_ebc(self) -> None:
+    # pyre-ignore
+    @given(backend=st.sampled_from(["nccl"]))
+    @settings(deadline=20000)
+    def test_sharding_zch_mch_mc_ebc(self, backend: str) -> None:

         WORLD_SIZE = 2
@@ -560,5 +569,5 @@ def test_sharding_zch_mch_mc_ebc(self) -> None:
             kjt_input_per_rank=kjt_input_per_rank,
             kjt_out_per_iter_per_rank=kjt_out_per_iter_per_rank,
             sharder=ManagedCollisionEmbeddingBagCollectionSharder(),
-            backend="nccl",
+            backend=backend,
         )