Refactor model_parallel tests to allow different (device, backend) combinations (pytorch#1667)

Summary:

Refactor the model_parallel tests to take more combinations of (device, backend).

1. Don't force the device to CUDA when the backend is NCCL, or force it to CPU when the backend is Gloo; i.e., allow the (NCCL, CPU) and (Gloo, GPU) combinations (see the sketch after this list).
2. Refactor test_parameter_init to test parameter init on the (nccl, gpu) and (gloo, cpu) combinations.
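
The pattern repeated across the touched test classes, as a minimal stand-alone sketch (the underscore-prefixed class names are hypothetical stand-ins; the real logic is in the diffs below):

import unittest

import torch


class _ModelParallelTestBase(unittest.TestCase):
    # Stand-in for the refactored bases (ModelParallelTestShared, ModelParallelStateDictBase, ...).
    def setUp(self, backend: str = "nccl") -> None:
        self.backend = backend
        # The device is chosen from hardware availability, not from the backend,
        # which is what allows (Gloo, GPU); (NCCL, CPU) is ruled out by the skip below.
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        if self.backend == "nccl" and self.device == torch.device("cpu"):
            self.skipTest("NCCL not supported on CPUs.")


class _GlooVariant(_ModelParallelTestBase):
    # A concrete test file only needs to override the default backend.
    def setUp(self, backend: str = "gloo") -> None:
        super().setUp(backend=backend)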

Reviewed By: sarckk

Differential Revision: D53149924
henrylhtsang authored and facebook-github-bot committed Jan 31, 2024
1 parent 2005d7f commit 44fce44
Showing 4 changed files with 63 additions and 53 deletions.
14 changes: 7 additions & 7 deletions torchrec/distributed/test_utils/multi_process.py
@@ -36,16 +36,16 @@ def __init__(
         self.backend = backend
         self.local_size = local_size
 
-        if backend == "nccl":
-            device = torch.device(f"cuda:{rank}")
-            torch.cuda.set_device(device)
-        else:
-            device = torch.device("cpu")
-        self.device: torch.device = device
-        torch.use_deterministic_algorithms(True)
         if torch.cuda.is_available():
+            self.device: torch.device = torch.device(f"cuda:{rank}")
+            torch.cuda.set_device(self.device)
+
             torch.backends.cudnn.allow_tf32 = False
             torch.backends.cuda.matmul.allow_tf32 = False
+        else:
+            self.device: torch.device = torch.device("cpu")
+        torch.use_deterministic_algorithms(True)
 
         self.pg: Optional[dist.ProcessGroup] = None
 
     def __enter__(self) -> "MultiProcessContext":
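Since the device above is now keyed off torch.cuda.is_available() rather than the backend, a Gloo process group can drive CUDA tensors. A minimal single-process sketch of the (Gloo, GPU) combination, independent of torchrec (address and port values are arbitrary placeholders):

import os

import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")

# Backend and device are chosen independently: Gloo backend, CUDA device if present.
dist.init_process_group(backend="gloo", rank=0, world_size=1)
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Gloo supports a subset of collectives (e.g. all_reduce, broadcast) on CUDA tensors.
t = torch.ones(4, device=device)
dist.all_reduce(t)

dist.destroy_process_group()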
22 changes: 17 additions & 5 deletions torchrec/distributed/test_utils/test_model_parallel.py
@@ -29,7 +29,7 @@
 
 class ModelParallelTestShared(MultiProcessTestBase):
     @seed_and_log
-    def setUp(self) -> None:
+    def setUp(self, backend: str = "nccl") -> None:
         super().setUp()
 
         num_features = 4
@@ -76,12 +76,14 @@ def setUp(self) -> None:
                 for feature in table.feature_names
             ]
         }
+        self.backend = backend
         if torch.cuda.is_available():
             self.device = torch.device("cuda")
-            self.backend = "nccl"
         else:
             self.device = torch.device("cpu")
-            self.backend = "gloo"
 
+        if self.backend == "nccl" and self.device == torch.device("cpu"):
+            self.skipTest("NCCL not supported on CPUs.")
+
     def _test_sharding(
         self,
@@ -122,8 +124,8 @@ def _test_sharding(
 
 @skip_if_asan_class
 class ModelParallelBase(ModelParallelTestShared):
-    def setUp(self) -> None:
-        super().setUp()
+    def setUp(self, backend: str = "nccl") -> None:
+        super().setUp(backend=backend)
 
     @unittest.skipIf(
         torch.cuda.device_count() <= 1,
@@ -166,6 +168,11 @@ def test_sharding_rw(
         ],
         variable_batch_size: bool,
     ) -> None:
+        if self.backend == "gloo":
+            self.skipTest(
+                "Gloo reduce_scatter_base fallback not supported with async_op=True"
+            )
+
         sharding_type = ShardingType.ROW_WISE.value
         kernel_type = EmbeddingComputeKernel.FUSED.value
         assume(
@@ -367,6 +374,11 @@ def test_sharding_variable_batch(
         sharding_type: str,
         global_constant_batch: bool,
     ) -> None:
+        if self.backend == "gloo":
+            # error is from FBGEMM, it says CPU even if we are on GPU.
+            self.skipTest(
+                "bounds_check_indices on CPU does not support variable length (batch size)"
+            )
         self._test_sharding(
             # pyre-ignore[6]
             sharders=[
69 changes: 31 additions & 38 deletions torchrec/distributed/test_utils/test_model_parallel_base.py
@@ -178,26 +178,30 @@ def _test_sharded_forward(
 
 
 class ModelParallelSparseOnlyBase(unittest.TestCase):
-    def tearDown(self) -> None:
-        dist.destroy_process_group()
-
-    def test_sharding_ebc_as_top_level(self) -> None:
+    def setUp(self, backend: str = "nccl") -> None:
         os.environ["RANK"] = "0"
         os.environ["WORLD_SIZE"] = "1"
         os.environ["LOCAL_WORLD_SIZE"] = "1"
         os.environ["MASTER_ADDR"] = str("localhost")
         os.environ["MASTER_PORT"] = str(get_free_port())
         os.environ["NCCL_SOCKET_IFNAME"] = "lo"
 
+        self.backend = backend
         if torch.cuda.is_available():
-            curr_device = torch.device("cuda:0")
-            torch.cuda.set_device(curr_device)
-            backend = "nccl"
+            self.device = torch.device("cuda:0")
+            torch.cuda.set_device(self.device)
         else:
-            curr_device = torch.device("cpu")
-            backend = "gloo"
-        dist.init_process_group(backend=backend)
+            self.device = torch.device("cpu")
 
+        if self.backend == "nccl" and self.device == torch.device("cpu"):
+            self.skipTest("NCCL not supported on CPUs.")
+
+        dist.init_process_group(backend=self.backend)
+
+    def tearDown(self) -> None:
+        dist.destroy_process_group()
+
+    def test_sharding_ebc_as_top_level(self) -> None:
         embedding_dim = 128
         num_embeddings = 256
         ebc = EmbeddingBagCollection(
@@ -213,27 +217,11 @@ def test_sharding_ebc_as_top_level(self) -> None:
             ],
         )
 
-        model = DistributedModelParallel(ebc, device=curr_device)
+        model = DistributedModelParallel(ebc, device=self.device)
 
         self.assertTrue(isinstance(model.module, ShardedEmbeddingBagCollection))
 
     def test_sharding_fused_ebc_as_top_level(self) -> None:
-        os.environ["RANK"] = "0"
-        os.environ["WORLD_SIZE"] = "1"
-        os.environ["LOCAL_WORLD_SIZE"] = "1"
-        os.environ["MASTER_ADDR"] = str("localhost")
-        os.environ["MASTER_PORT"] = str(get_free_port())
-        os.environ["NCCL_SOCKET_IFNAME"] = "lo"
-
-        if torch.cuda.is_available():
-            curr_device = torch.device("cuda:0")
-            torch.cuda.set_device(curr_device)
-            backend = "nccl"
-        else:
-            curr_device = torch.device("cpu")
-            backend = "gloo"
-        dist.init_process_group(backend=backend)
-
         embedding_dim = 128
         num_embeddings = 256
         ebc = FusedEmbeddingBagCollection(
@@ -251,26 +239,30 @@ def test_sharding_fused_ebc_as_top_level(self) -> None:
             optimizer_kwargs={"lr": 0.02},
         )
 
-        model = DistributedModelParallel(ebc, device=curr_device)
+        model = DistributedModelParallel(ebc, device=self.device)
 
         self.assertTrue(isinstance(model.module, ShardedFusedEmbeddingBagCollection))
 
 
 class ModelParallelStateDictBase(unittest.TestCase):
-    def setUp(self) -> None:
+    def setUp(self, backend: str = "nccl") -> None:
         os.environ["RANK"] = "0"
         os.environ["WORLD_SIZE"] = "1"
         os.environ["LOCAL_WORLD_SIZE"] = "1"
         os.environ["MASTER_ADDR"] = str("localhost")
         os.environ["MASTER_PORT"] = str(get_free_port())
         os.environ["NCCL_SOCKET_IFNAME"] = "lo"
 
+        self.backend = backend
         if torch.cuda.is_available():
             self.device = torch.device("cuda:0")
-            backend = "nccl"
+            torch.cuda.set_device(self.device)
         else:
             self.device = torch.device("cpu")
-            backend = "gloo"
+
+        if self.backend == "nccl" and self.device == torch.device("cpu"):
+            self.skipTest("NCCL not supported on CPUs.")
+
         dist.init_process_group(backend=backend)
 
         num_features = 4
@@ -377,27 +369,28 @@ def __init__(self, device: str, val: float) -> None:
            def reset_parameters(self) -> None:
                nn.init.constant_(self.p, self.val)
 
-        dist.destroy_process_group()
-        dist.init_process_group(backend="gloo")
-
         # Check that already allocated parameters are left 'as is'.
-        cpu_model = MyModel(device="cpu", val=3.2)
+        unsharded_model = MyModel(device=self.device, val=3.2)
         sharded_model = DistributedModelParallel(
-            cpu_model,
+            unsharded_model,
             device=self.device,
         )
         sharded_param = next(sharded_model.parameters())
         np.testing.assert_array_equal(
-            np.array([3.2, 3.2, 3.2], dtype=np.float32), sharded_param.detach().numpy()
+            np.array([3.2, 3.2, 3.2], dtype=np.float32),
+            sharded_param.detach().cpu().numpy(),
         )
 
         # Check that parameters over 'meta' device are allocated and initialized.
         meta_model = MyModel(device="meta", val=7.5)
         sharded_model = DistributedModelParallel(
             meta_model,
             device=self.device,
         )
         sharded_param = next(sharded_model.parameters())
         np.testing.assert_array_equal(
-            np.array([7.5, 7.5, 7.5], dtype=np.float32), sharded_param.detach().numpy()
+            np.array([7.5, 7.5, 7.5], dtype=np.float32),
+            sharded_param.detach().cpu().numpy(),
         )
 
     def test_meta_device_dmp_state_dict(self) -> None:
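
A note on the assertion change above: sharded_param may now live on a GPU, and numpy cannot read CUDA tensors directly, hence the added .cpu() before .numpy(). A minimal sketch of the device-agnostic check:

import numpy as np
import torch

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
param = torch.full((3,), 3.2, device=device)

# Calling .numpy() on a CUDA tensor raises; .cpu() makes the check work on both devices.
np.testing.assert_array_equal(
    np.array([3.2, 3.2, 3.2], dtype=np.float32),
    param.detach().cpu().numpy(),
)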
11 changes: 8 additions & 3 deletions torchrec/distributed/tests/test_model_parallel_gloo.py
@@ -11,14 +11,19 @@
     ModelParallelStateDictBase,
 )
 
+# CPU tests for Gloo.
+
 
 class ModelParallelTestGloo(ModelParallelBase):
-    pass
+    def setUp(self, backend: str = "gloo") -> None:
+        super().setUp(backend=backend)
 
 
 class ModelParallelStateDictTestGloo(ModelParallelStateDictBase):
-    pass
+    def setUp(self, backend: str = "gloo") -> None:
+        super().setUp(backend=backend)
 
 
 class ModelParallelSparseOnlyTestGloo(ModelParallelSparseOnlyBase):
-    pass
+    def setUp(self, backend: str = "gloo") -> None:
+        super().setUp(backend=backend)
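
For contrast, a GPU/NCCL counterpart of this file (hypothetical file and class names; import paths assumed from the file headers above) needs no setUp override at all, because "nccl" is already the bases' default:

# Hypothetical test_model_parallel_nccl.py; import paths assumed from the diffs above.
from torchrec.distributed.test_utils.test_model_parallel import ModelParallelBase
from torchrec.distributed.test_utils.test_model_parallel_base import (
    ModelParallelSparseOnlyBase,
    ModelParallelStateDictBase,
)


class ModelParallelTestNccl(ModelParallelBase):
    pass


class ModelParallelStateDictTestNccl(ModelParallelStateDictBase):
    pass


class ModelParallelSparseOnlyTestNccl(ModelParallelSparseOnlyBase):
    pass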
