Refactor model_parallel tests to allow different (device, backend) combinations

Summary:
Refactor the model_parallel tests so they can run more (device, backend) combinations.

1. Stop forcing the device to CUDA when the backend is NCCL, and to CPU when the backend is Gloo; i.e. allow the (NCCL, CPU) and (Gloo, GPU) combinations (a sketch of the resulting device selection follows the Differential Revision line below).
2. Refactor test_parameter_init to exercise parameter initialization on both the (nccl, gpu) and (gloo, cpu) combinations.

Differential Revision: D53149924
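
For orientation, here is a minimal sketch of the device-selection pattern the three changed files converge on. It is illustrative only: the helper name _select_device is not part of the diff, and the rank-indexed CUDA device mirrors the multi_process.py hunk below.

import torch

def _select_device(rank: int) -> torch.device:
    # The device now follows hardware availability, while the process-group
    # backend is passed around separately; the harness itself no longer rules
    # out the (NCCL, CPU) or (Gloo, GPU) pairings named in the summary.
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)
        return device
    return torch.device("cpu")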
henrylhtsang authored and facebook-github-bot committed Jan 30, 2024
1 parent 29edfe9 commit 8e73513
Showing 3 changed files with 34 additions and 50 deletions.
13 changes: 6 additions & 7 deletions torchrec/distributed/test_utils/multi_process.py
@@ -36,16 +36,15 @@ def __init__(
        self.backend = backend
        self.local_size = local_size

-        if backend == "nccl":
-            device = torch.device(f"cuda:{rank}")
-            torch.cuda.set_device(device)
-        else:
-            device = torch.device("cpu")
-        self.device: torch.device = device
-        torch.use_deterministic_algorithms(True)
        if torch.cuda.is_available():
+            self.device: torch.device = torch.device(f"cuda:{rank}")
+            torch.cuda.set_device(self.device)
+
            torch.backends.cudnn.allow_tf32 = False
            torch.backends.cuda.matmul.allow_tf32 = False
+        else:
+            self.device: torch.device = torch.device("cpu")
+        torch.use_deterministic_algorithms(True)
        self.pg: Optional[dist.ProcessGroup] = None

    def __enter__(self) -> "MultiProcessContext":
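A hedged usage sketch of the refactored context manager follows. It assumes MultiProcessContext also accepts rank and world_size keyword arguments and that the caller sets the master address/port environment variables; neither is shown in the hunk above.

import os

import torch

from torchrec.distributed.test_utils.multi_process import MultiProcessContext
from torchrec.test_utils import get_free_port

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(get_free_port())

# The backend is taken at face value, while the device is derived from
# torch.cuda.is_available(), so a gloo process group on a GPU host now
# yields a cuda device instead of being forced onto the CPU.
with MultiProcessContext(rank=0, world_size=1, backend="gloo") as ctx:
    expected = "cuda" if torch.cuda.is_available() else "cpu"
    assert ctx.device.type == expected
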
9 changes: 4 additions & 5 deletions torchrec/distributed/test_utils/test_model_parallel.py
@@ -29,7 +29,7 @@

class ModelParallelTestShared(MultiProcessTestBase):
    @seed_and_log
-    def setUp(self) -> None:
+    def setUp(self, backend: str = "nccl") -> None:
        super().setUp()

        num_features = 4
@@ -76,12 +76,11 @@ def setUp(self) -> None:
                for feature in table.feature_names
            ]
        }
+        self.backend = backend
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
-            self.backend = "nccl"
        else:
            self.device = torch.device("cpu")
-            self.backend = "gloo"

    def _test_sharding(
        self,
@@ -122,8 +121,8 @@ def _test_sharding(

@skip_if_asan_class
class ModelParallelBase(ModelParallelTestShared):
-    def setUp(self) -> None:
-        super().setUp()
+    def setUp(self, backend: str = "nccl") -> None:
+        super().setUp(backend=backend)

    @unittest.skipIf(
        torch.cuda.device_count() <= 1,
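Because setUp now threads the backend through to the shared fixture, a downstream suite can pin the process-group backend while the device continues to follow hardware availability. A hypothetical example (this subclass is not part of the commit):

from torchrec.distributed.test_utils.test_model_parallel import ModelParallelBase

class ModelParallelGlooBase(ModelParallelBase):
    # Inherit every sharding test unchanged, but build process groups over gloo.
    def setUp(self) -> None:
        super().setUp(backend="gloo")
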
62 changes: 24 additions & 38 deletions torchrec/distributed/test_utils/test_model_parallel_base.py
@@ -178,26 +178,27 @@ def _test_sharded_forward(


class ModelParallelSparseOnlyBase(unittest.TestCase):
-    def tearDown(self) -> None:
-        dist.destroy_process_group()
-
-    def test_sharding_ebc_as_top_level(self) -> None:
+    def setUp(self, backend: str = "nccl") -> None:
        os.environ["RANK"] = "0"
        os.environ["WORLD_SIZE"] = "1"
        os.environ["LOCAL_WORLD_SIZE"] = "1"
        os.environ["MASTER_ADDR"] = str("localhost")
        os.environ["MASTER_PORT"] = str(get_free_port())
        os.environ["NCCL_SOCKET_IFNAME"] = "lo"

+        self.backend = backend
        if torch.cuda.is_available():
-            curr_device = torch.device("cuda:0")
-            torch.cuda.set_device(curr_device)
-            backend = "nccl"
+            self.device = torch.device("cuda:0")
+            torch.cuda.set_device(self.device)
        else:
-            curr_device = torch.device("cpu")
-            backend = "gloo"
-        dist.init_process_group(backend=backend)
+            self.device = torch.device("cpu")
+
+        dist.init_process_group(backend=self.backend)
+
+    def tearDown(self) -> None:
+        dist.destroy_process_group()

+    def test_sharding_ebc_as_top_level(self) -> None:
        embedding_dim = 128
        num_embeddings = 256
        ebc = EmbeddingBagCollection(
@@ -213,27 +214,11 @@ def test_sharding_ebc_as_top_level(self) -> None:
            ],
        )

-        model = DistributedModelParallel(ebc, device=curr_device)
+        model = DistributedModelParallel(ebc, device=self.device)

        self.assertTrue(isinstance(model.module, ShardedEmbeddingBagCollection))

    def test_sharding_fused_ebc_as_top_level(self) -> None:
-        os.environ["RANK"] = "0"
-        os.environ["WORLD_SIZE"] = "1"
-        os.environ["LOCAL_WORLD_SIZE"] = "1"
-        os.environ["MASTER_ADDR"] = str("localhost")
-        os.environ["MASTER_PORT"] = str(get_free_port())
-        os.environ["NCCL_SOCKET_IFNAME"] = "lo"
-
-        if torch.cuda.is_available():
-            curr_device = torch.device("cuda:0")
-            torch.cuda.set_device(curr_device)
-            backend = "nccl"
-        else:
-            curr_device = torch.device("cpu")
-            backend = "gloo"
-        dist.init_process_group(backend=backend)
-
        embedding_dim = 128
        num_embeddings = 256
        ebc = FusedEmbeddingBagCollection(
@@ -251,26 +236,26 @@ def test_sharding_fused_ebc_as_top_level(self) -> None:
            optimizer_kwargs={"lr": 0.02},
        )

-        model = DistributedModelParallel(ebc, device=curr_device)
+        model = DistributedModelParallel(ebc, device=self.device)

        self.assertTrue(isinstance(model.module, ShardedFusedEmbeddingBagCollection))


class ModelParallelStateDictBase(unittest.TestCase):
-    def setUp(self) -> None:
+    def setUp(self, backend: str = "nccl") -> None:
        os.environ["RANK"] = "0"
        os.environ["WORLD_SIZE"] = "1"
        os.environ["LOCAL_WORLD_SIZE"] = "1"
        os.environ["MASTER_ADDR"] = str("localhost")
        os.environ["MASTER_PORT"] = str(get_free_port())
        os.environ["NCCL_SOCKET_IFNAME"] = "lo"

+        self.backend = backend
        if torch.cuda.is_available():
            self.device = torch.device("cuda:0")
-            backend = "nccl"
+            torch.cuda.set_device(self.device)
        else:
            self.device = torch.device("cpu")
-            backend = "gloo"
        dist.init_process_group(backend=backend)

        num_features = 4
@@ -377,27 +362,28 @@ def __init__(self, device: str, val: float) -> None:
            def reset_parameters(self) -> None:
                nn.init.constant_(self.p, self.val)

-        dist.destroy_process_group()
-        dist.init_process_group(backend="gloo")
-
        # Check that already allocated parameters are left 'as is'.
-        cpu_model = MyModel(device="cpu", val=3.2)
+        unsharded_model = MyModel(device=self.device, val=3.2)
        sharded_model = DistributedModelParallel(
-            cpu_model,
+            unsharded_model,
+            device=self.device,
        )
        sharded_param = next(sharded_model.parameters())
        np.testing.assert_array_equal(
-            np.array([3.2, 3.2, 3.2], dtype=np.float32), sharded_param.detach().numpy()
+            np.array([3.2, 3.2, 3.2], dtype=np.float32),
+            sharded_param.detach().cpu().numpy(),
        )

        # Check that parameters over 'meta' device are allocated and initialized.
        meta_model = MyModel(device="meta", val=7.5)
        sharded_model = DistributedModelParallel(
            meta_model,
+            device=self.device,
        )
        sharded_param = next(sharded_model.parameters())
        np.testing.assert_array_equal(
-            np.array([7.5, 7.5, 7.5], dtype=np.float32), sharded_param.detach().numpy()
+            np.array([7.5, 7.5, 7.5], dtype=np.float32),
+            sharded_param.detach().cpu().numpy(),
        )

    def test_meta_device_dmp_state_dict(self) -> None:
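The refactored assertions copy to host before comparing because, once the device is no longer pinned to CPU, the sharded parameter may live on a GPU. A minimal standalone illustration of that pattern (plain torch and numpy, not tied to DistributedModelParallel):

import numpy as np
import torch

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
param = torch.nn.Parameter(torch.full((3,), 3.2, device=device))

# Tensor.numpy() raises on CUDA tensors and on tensors that require grad,
# so detach and move to CPU before handing the values to numpy.
np.testing.assert_array_equal(
    np.array([3.2, 3.2, 3.2], dtype=np.float32),
    param.detach().cpu().numpy(),
)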
