[DistDGL] fix device mismatch when calling all_to_all with gloo backend (#7409)
Rhett-Ying authored May 17, 2024
1 parent 08a7c74 commit 191681d
Showing 3 changed files with 32 additions and 26 deletions.
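The underlying constraint, as the in-code comment added to sparse_optim.py below explains, is that `gloo` works with CPU tensors while `nccl` works with GPU tensors, so buffers that feed a collective must live on a device that matches the active backend. A minimal sketch of that backend-to-device mapping (plain torch.distributed, not DGL API):

    import torch as th
    import torch.distributed as dist

    def comm_device(rank: int) -> th.device:
        # `nccl` collectives expect CUDA tensors; `gloo` collectives expect CPU tensors.
        return (
            th.device(f"cuda:{rank}")
            if dist.get_backend() == "nccl"
            else th.device("cpu")
        )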
3 changes: 3 additions & 0 deletions examples/distributed/rgcn/node_classification.py
@@ -706,6 +706,9 @@ def main(args):
     dgl.distributed.initialize(args.ip_config, use_graphbolt=args.use_graphbolt)
     if not args.standalone:
         backend = "gloo" if args.num_gpus == -1 else "nccl"
+        if args.sparse_embedding and args.dgl_sparse:
+            # `nccl` is not fully supported in DistDGL's sparse optimizer.
+            backend = "gloo"
         th.distributed.init_process_group(backend=backend)
 
     g = dgl.distributed.DistGraph(args.graph_name, part_config=args.conf_path)
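For context, the example script derives the process-group backend from the training setup, as the hunk above shows. Below is a self-contained sketch of that choice; the flag names mirror the script's arguments, but the argparse wiring and the `env://` initialization are assumptions made for illustration.

    import argparse

    import torch as th

    parser = argparse.ArgumentParser()
    parser.add_argument("--num_gpus", type=int, default=-1)
    parser.add_argument("--sparse_embedding", action="store_true")
    parser.add_argument("--dgl_sparse", action="store_true")
    args = parser.parse_args()

    backend = "gloo" if args.num_gpus == -1 else "nccl"
    if args.sparse_embedding and args.dgl_sparse:
        # `nccl` is not fully supported in DistDGL's sparse optimizer.
        backend = "gloo"
    # Assumes RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT are set in the environment.
    th.distributed.init_process_group(backend=backend, init_method="env://")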
41 changes: 27 additions & 14 deletions python/dgl/distributed/optim/pytorch/sparse_optim.py
@@ -256,7 +256,14 @@ def step(self):
         of the embeddings involved in a mini-batch to DGL's servers and update the embeddings.
         """
         with th.no_grad():
-            device = (
+            # [Rui]
+            # As `gloo` supports CPU tensors only while `nccl` supports GPU
+            # tensors only, we firstly create tensors on the corresponding
+            # devices and then copy the data to target device if needed.
+            # Please note that the target device can be different from the
+            # preferred device.
+            target_device = None
+            preferred_device = (
                 th.device(f"cuda:{self._rank}")
                 if th.distributed.get_backend() == "nccl"
                 else th.device("cpu")
@@ -283,21 +290,26 @@ def step(self):
                 # Note: we cannot skip the gradient exchange and update steps as other
                 # working processes may send gradient update requests corresponding
                 # to certain embedding to this process.
+                #
+                # [WARNING][TODO][Rui]
+                # For empty idx and grad, we blindly create data on the
+                # preferred device, which may not be the device where the
+                # embedding is stored.
                 idics = (
                     th.cat(idics, dim=0)
                     if len(idics) != 0
-                    else th.zeros((0,), dtype=th.long, device=th.device("cpu"))
+                    else th.zeros((0,), dtype=th.int64, device=preferred_device)
                 )
                 grads = (
                     th.cat(grads, dim=0)
                     if len(grads) != 0
                     else th.zeros(
                         (0, emb.embedding_dim),
                         dtype=th.float32,
-                        device=th.device("cpu"),
+                        device=preferred_device,
                     )
                 )
-                device = grads.device
+                target_device = grads.device
 
                 # will send grad to each corresponding trainer
                 if self._world_size > 1:
@@ -317,7 +329,7 @@ def step(self):
                             th.tensor(
                                 [idx_i.shape[0]],
                                 dtype=th.int64,
-                                device=device,
+                                device=preferred_device,
                             )
                         )
                         idics_list.append(idx_i)
@@ -334,7 +346,7 @@ def step(self):
                             th.tensor(
                                 [idx_j.shape[0]],
                                 dtype=th.int64,
-                                device=device,
+                                device=preferred_device,
                             )
                         )
                         idics_list.append(idx_j)
@@ -349,19 +361,22 @@ def step(self):
                     # sync information here
                     gather_list = list(
                         th.empty(
-                            [self._world_size], dtype=th.int64, device=device
+                            [self._world_size],
+                            dtype=th.int64,
+                            device=preferred_device,
                         ).chunk(self._world_size)
                     )
                     alltoall(
                         self._rank,
                         self._world_size,
                         gather_list,
                         idx_split_size,
-                        device,
                     )
                     idx_gather_list = [
                         th.empty(
-                            (int(num_emb),), dtype=idics.dtype, device=device
+                            (int(num_emb),),
+                            dtype=idics.dtype,
+                            device=preferred_device,
                         )
                         for num_emb in gather_list
                     ]
@@ -370,14 +385,13 @@ def step(self):
                         self._world_size,
                         idx_gather_list,
                         idics_list,
-                        device,
                     )
                     local_indics[name] = idx_gather_list
                     grad_gather_list = [
                         th.empty(
                             (int(num_emb), grads.shape[1]),
                             dtype=grads.dtype,
-                            device=device,
+                            device=preferred_device,
                         )
                         for num_emb in gather_list
                     ]
@@ -386,7 +400,6 @@ def step(self):
                         self._world_size,
                         grad_gather_list,
                         grad_list,
-                        device,
                     )
                     local_grads[name] = grad_gather_list
                 else:
@@ -405,8 +418,8 @@ def step(self):
             idx = th.cat(local_indics[name], dim=0)
             grad = th.cat(local_grads[name], dim=0)
             self.update(
-                idx.to(device, non_blocking=True),
-                grad.to(device, non_blocking=True),
+                idx.to(target_device, non_blocking=True),
+                grad.to(target_device, non_blocking=True),
                 emb,
             )
 
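The hunks above implement one pattern: tensors that feed a collective are created on the backend's preferred device (`cuda:<rank>` for `nccl`, CPU for `gloo`), and only the gathered result is copied to the target device where the embedding gradients actually live. The following standalone sketch illustrates that pattern with a per-rank count exchange; it is not DGL's internal API, and the scatter loop is only a rough stand-in for DGL's `alltoall_cpu` helper.

    import torch as th
    import torch.distributed as dist

    def exchange_counts(local_count: int, target_device: th.device) -> th.Tensor:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        # Allocate communication buffers on the device the backend can handle.
        preferred = (
            th.device(f"cuda:{rank}")
            if dist.get_backend() == "nccl"
            else th.device("cpu")
        )
        send = [
            th.tensor([local_count], dtype=th.int64, device=preferred)
            for _ in range(world_size)
        ]
        recv = [
            th.empty((1,), dtype=th.int64, device=preferred)
            for _ in range(world_size)
        ]
        if dist.get_backend() == "nccl":
            dist.all_to_all(recv, send)
        else:
            # `gloo` has no all_to_all, so emulate it with one scatter per source rank.
            for src in range(world_size):
                dist.scatter(recv[src], send if src == rank else None, src=src)
        # Copy the gathered counts to wherever the caller keeps its gradients.
        return th.cat(recv).to(target_device, non_blocking=True)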
14 changes: 2 additions & 12 deletions python/dgl/distributed/optim/pytorch/utils.py
@@ -62,7 +62,7 @@ def alltoallv_cpu(rank, world_size, output_tensor_list, input_tensor_list):
     th.distributed.barrier()
 
 
-def alltoall(rank, world_size, output_tensor_list, input_tensor_list, device):
+def alltoall(rank, world_size, output_tensor_list, input_tensor_list):
     """Each process scatters list of input tensors to all processes in a cluster
     and return gathered list of tensors in output list. The tensors should have the same shape.
 
@@ -76,13 +76,8 @@ def alltoall(rank, world_size, output_tensor_list, input_tensor_list, device):
         The received tensors
     input_tensor_list : List of tensor
         The tensors to exchange
-    device: th.device
-        Device of the tensors
     """
     if th.distributed.get_backend() == "nccl":
-        input_tensor_list = [
-            tensor.to(th.device(device)) for tensor in input_tensor_list
-        ]
         th.distributed.all_to_all(output_tensor_list, input_tensor_list)
     else:
         alltoall_cpu(
@@ -93,7 +88,7 @@ def alltoall(rank, world_size, output_tensor_list, input_tensor_list, device):
         )
 
 
-def alltoallv(rank, world_size, output_tensor_list, input_tensor_list, device):
+def alltoallv(rank, world_size, output_tensor_list, input_tensor_list):
     """Each process scatters list of input tensors to all processes in a cluster
     and return gathered list of tensors in output list.
 
@@ -107,13 +102,8 @@ def alltoallv(rank, world_size, output_tensor_list, input_tensor_list, device):
         The received tensors
     input_tensor_list : List of tensor
         The tensors to exchange
-    device: th.device
-        Device of the tensors
     """
     if th.distributed.get_backend() == "nccl":
-        input_tensor_list = [
-            tensor.to(th.device(device)) for tensor in input_tensor_list
-        ]
         th.distributed.all_to_all(output_tensor_list, input_tensor_list)
     else:
         alltoallv_cpu(
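With the `device` parameter removed, callers of these internal helpers are expected to allocate both the output and the input tensor lists on a device that matches the backend before calling. A hedged usage sketch (the module path follows the file location above; process-group initialization is assumed to have already happened):

    import torch as th
    import torch.distributed as dist
    from dgl.distributed.optim.pytorch.utils import alltoall

    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # Match the buffer device to the backend, as the optimizer now does.
    device = (
        th.device(f"cuda:{rank}")
        if dist.get_backend() == "nccl"
        else th.device("cpu")
    )
    send = [
        th.full((1,), rank, dtype=th.int64, device=device)
        for _ in range(world_size)
    ]
    recv = [th.empty((1,), dtype=th.int64, device=device) for _ in range(world_size)]
    alltoall(rank, world_size, recv, send)
    # recv[i] now holds the value sent by rank i.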
