Commit 0d4ea3f

[core][distributed] use tcp store directly (#10275)
Signed-off-by: youkaichao <[email protected]>
1 parent: 112fa0b

File tree

tests/distributed/test_utils.py
vllm/distributed/utils.py

2 files changed: +29 -25 lines changed


tests/distributed/test_utils.py

Lines changed: 16 additions & 10 deletions
@@ -43,12 +43,15 @@ def test_cuda_device_count_stateless():
 
 
 def cpu_worker(rank, WORLD_SIZE, port1, port2):
-    pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}",
+    pg1 = StatelessProcessGroup.create(host="127.0.0.1",
+                                       port=port1,
                                        rank=rank,
                                        world_size=WORLD_SIZE)
     if rank <= 2:
-        pg2 = StatelessProcessGroup.create(
-            init_method=f"tcp://127.0.0.1:{port2}", rank=rank, world_size=3)
+        pg2 = StatelessProcessGroup.create(host="127.0.0.1",
+                                           port=port2,
+                                           rank=rank,
+                                           world_size=3)
     data = torch.tensor([rank])
     data = pg1.broadcast_obj(data, src=2)
     assert data.item() == 2
@@ -62,14 +65,17 @@ def cpu_worker(rank, WORLD_SIZE, port1, port2):
 
 def gpu_worker(rank, WORLD_SIZE, port1, port2):
     torch.cuda.set_device(rank)
-    pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}",
+    pg1 = StatelessProcessGroup.create(host="127.0.0.1",
+                                       port=port1,
                                        rank=rank,
                                        world_size=WORLD_SIZE)
     pynccl1 = PyNcclCommunicator(pg1, device=rank)
     pynccl1.disabled = False
     if rank <= 2:
-        pg2 = StatelessProcessGroup.create(
-            init_method=f"tcp://127.0.0.1:{port2}", rank=rank, world_size=3)
+        pg2 = StatelessProcessGroup.create(host="127.0.0.1",
+                                           port=port2,
+                                           rank=rank,
+                                           world_size=3)
         pynccl2 = PyNcclCommunicator(pg2, device=rank)
         pynccl2.disabled = False
     data = torch.tensor([rank]).cuda()
@@ -89,7 +95,8 @@ def gpu_worker(rank, WORLD_SIZE, port1, port2):
 
 
 def broadcast_worker(rank, WORLD_SIZE, port1, port2):
-    pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}",
+    pg1 = StatelessProcessGroup.create(host="127.0.0.1",
+                                       port=port1,
                                        rank=rank,
                                        world_size=WORLD_SIZE)
     if rank == 2:
@@ -101,16 +108,15 @@ def broadcast_worker(rank, WORLD_SIZE, port1, port2):
 
 
 def allgather_worker(rank, WORLD_SIZE, port1, port2):
-    pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}",
+    pg1 = StatelessProcessGroup.create(host="127.0.0.1",
+                                       port=port1,
                                        rank=rank,
                                        world_size=WORLD_SIZE)
     data = pg1.all_gather_obj(rank)
     assert data == list(range(WORLD_SIZE))
     pg1.barrier()
 
 
-# TODO: investigate why this test is flaky. It hangs during initialization.
-@pytest.mark.skip("Skip the test because it is flaky.")
 @multi_gpu_test(num_gpus=4)
 @pytest.mark.parametrize(
     "worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker])

vllm/distributed/utils.py

Lines changed: 13 additions & 15 deletions
@@ -9,7 +9,7 @@
 from typing import Any, Deque, Dict, Optional, Sequence, Tuple
 
 import torch
-from torch.distributed.rendezvous import rendezvous
+from torch.distributed import TCPStore
 
 import vllm.envs as envs
 from vllm.logger import init_logger
@@ -97,7 +97,6 @@ class StatelessProcessGroup:
     group. Only use it to communicate metadata between processes.
     For data-plane communication, create NCCL-related objects.
     """
-    prefix: str
     rank: int
     world_size: int
    store: torch._C._distributed_c10d.Store
@@ -127,7 +126,7 @@ def __post_init__(self):
     def send_obj(self, obj: Any, dst: int):
         """Send an object to a destination rank."""
         self.expire_data()
-        key = f"{self.prefix}/send_to/{dst}/{self.send_dst_counter[dst]}"
+        key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
         self.store.set(key, pickle.dumps(obj))
         self.send_dst_counter[dst] += 1
         self.entries.append((key, time.time()))
@@ -147,8 +146,7 @@ def recv_obj(self, src: int) -> Any:
         """Receive an object from a source rank."""
         obj = pickle.loads(
             self.store.get(
-                f"{self.prefix}/send_to/{self.rank}/{self.recv_src_counter[src]}"
-            ))
+                f"send_to/{self.rank}/{self.recv_src_counter[src]}"))
         self.recv_src_counter[src] += 1
         return obj
 
@@ -159,14 +157,14 @@ def broadcast_obj(self, obj: Optional[Any], src: int) -> Any:
         """
         if self.rank == src:
             self.expire_data()
-            key = (f"{self.prefix}/broadcast_from/{src}/"
+            key = (f"broadcast_from/{src}/"
                    f"{self.broadcast_send_counter}")
             self.store.set(key, pickle.dumps(obj))
             self.broadcast_send_counter += 1
             self.entries.append((key, time.time()))
             return obj
         else:
-            key = (f"{self.prefix}/broadcast_from/{src}/"
+            key = (f"broadcast_from/{src}/"
                    f"{self.broadcast_recv_src_counter[src]}")
             recv_obj = pickle.loads(self.store.get(key))
             self.broadcast_recv_src_counter[src] += 1
@@ -194,7 +192,8 @@ def barrier(self):
 
     @staticmethod
     def create(
-        init_method: str,
+        host: str,
+        port: int,
         rank: int,
         world_size: int,
         data_expiration_seconds: int = 3600,
@@ -214,15 +213,14 @@ def create(
        can call `StatelessProcessGroup.create` to form a group, and then process A, B,
        C, and D can call `StatelessProcessGroup.create` to form another group.
        """  # noqa
-        from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT
-        timeout = _DEFAULT_PG_TIMEOUT
-
-        store, rank, world_size = next(
-            rendezvous(init_method, rank, world_size, timeout=timeout))
-        store.set_timeout(timeout)
+        store = TCPStore(
+            host_name=host,
+            port=port,
+            world_size=world_size,
+            is_master=(rank == 0),
+        )
 
         return StatelessProcessGroup(
-            prefix=init_method,
             rank=rank,
             world_size=world_size,
             store=store,
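
As a standalone illustration (a sketch, not code from the commit), the new create() path reduces to constructing a torch.distributed.TCPStore directly, with rank 0 acting as the server and the remaining ranks connecting as clients; the keyword names match the diff above.

from torch.distributed import TCPStore


def make_store(host: str, port: int, rank: int, world_size: int) -> TCPStore:
    # Mirrors the change above: rank 0 hosts the TCP store,
    # every other rank connects to it as a client.
    return TCPStore(
        host_name=host,
        port=port,
        world_size=world_size,
        is_master=(rank == 0),
    )

Because the store is built directly, the torch.distributed rendezvous machinery and its default process-group timeout are no longer involved, and store keys drop the init_method prefix since each group now talks to its own store.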
