9
9
from typing import Any , Deque , Dict , Optional , Sequence , Tuple
10
10
11
11
import torch
12
- from torch .distributed . rendezvous import rendezvous
12
+ from torch .distributed import TCPStore
13
13
14
14
import vllm .envs as envs
15
15
from vllm .logger import init_logger
@@ -97,7 +97,6 @@ class StatelessProcessGroup:
97
97
group. Only use it to communicate metadata between processes.
98
98
For data-plane communication, create NCCL-related objects.
99
99
"""
100
- prefix : str
101
100
rank : int
102
101
world_size : int
103
102
store : torch ._C ._distributed_c10d .Store
@@ -127,7 +126,7 @@ def __post_init__(self):
127
126
def send_obj (self , obj : Any , dst : int ):
128
127
"""Send an object to a destination rank."""
129
128
self .expire_data ()
130
- key = f"{ self . prefix } / send_to/{ dst } /{ self .send_dst_counter [dst ]} "
129
+ key = f"send_to/{ dst } /{ self .send_dst_counter [dst ]} "
131
130
self .store .set (key , pickle .dumps (obj ))
132
131
self .send_dst_counter [dst ] += 1
133
132
self .entries .append ((key , time .time ()))
@@ -147,8 +146,7 @@ def recv_obj(self, src: int) -> Any:
147
146
"""Receive an object from a source rank."""
148
147
obj = pickle .loads (
149
148
self .store .get (
150
- f"{ self .prefix } /send_to/{ self .rank } /{ self .recv_src_counter [src ]} "
151
- ))
149
+ f"send_to/{ self .rank } /{ self .recv_src_counter [src ]} " ))
152
150
self .recv_src_counter [src ] += 1
153
151
return obj
154
152
@@ -159,14 +157,14 @@ def broadcast_obj(self, obj: Optional[Any], src: int) -> Any:
159
157
"""
160
158
if self .rank == src :
161
159
self .expire_data ()
162
- key = (f"{ self . prefix } / broadcast_from/{ src } /"
160
+ key = (f"broadcast_from/{ src } /"
163
161
f"{ self .broadcast_send_counter } " )
164
162
self .store .set (key , pickle .dumps (obj ))
165
163
self .broadcast_send_counter += 1
166
164
self .entries .append ((key , time .time ()))
167
165
return obj
168
166
else :
169
- key = (f"{ self . prefix } / broadcast_from/{ src } /"
167
+ key = (f"broadcast_from/{ src } /"
170
168
f"{ self .broadcast_recv_src_counter [src ]} " )
171
169
recv_obj = pickle .loads (self .store .get (key ))
172
170
self .broadcast_recv_src_counter [src ] += 1
@@ -194,7 +192,8 @@ def barrier(self):
194
192
195
193
@staticmethod
196
194
def create (
197
- init_method : str ,
195
+ host : str ,
196
+ port : int ,
198
197
rank : int ,
199
198
world_size : int ,
200
199
data_expiration_seconds : int = 3600 ,
@@ -214,15 +213,14 @@ def create(
214
213
can call `StatelessProcessGroup.create` to form a group, and then process A, B,
215
214
C, and D can call `StatelessProcessGroup.create` to form another group.
216
215
""" # noqa
217
- from torch . _C . _distributed_c10d import _DEFAULT_PG_TIMEOUT
218
- timeout = _DEFAULT_PG_TIMEOUT
219
-
220
- store , rank , world_size = next (
221
- rendezvous ( init_method , rank , world_size , timeout = timeout ))
222
- store . set_timeout ( timeout )
216
+ store = TCPStore (
217
+ host_name = host ,
218
+ port = port ,
219
+ world_size = world_size ,
220
+ is_master = ( rank == 0 ),
221
+ )
223
222
224
223
return StatelessProcessGroup (
225
- prefix = init_method ,
226
224
rank = rank ,
227
225
world_size = world_size ,
228
226
store = store ,
0 commit comments