diff --git a/torchacc/dist/__init__.py b/torchacc/dist/__init__.py index a35de3a..ccc6fac 100644 --- a/torchacc/dist/__init__.py +++ b/torchacc/dist/__init__.py @@ -89,8 +89,9 @@ def init_nccl_context(config) -> None: _NCCL_CONTEXT_INITED = True + def rendezvous(tag, payload=b'', replicas=[]): - """Waits for all the mesh clients to reach the named rendezvous. + """Waits for all the mesh clients to reach the named rendezvous. We use the rendezvous api of xla directly. Args: @@ -104,4 +105,4 @@ def rendezvous(tag, payload=b'', replicas=[]): The payloads exchanged by all the other cores, with the payload of core ordinal `i` at position `i` in the returned tuple. """ - return xm.rendezvous(payload, replicas or None, tag=tag) + return xm.rendezvous(payload, replicas or None, tag=tag)