From 28c0aae78c70db05adf37b1a1b9b311f2a901fdc Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 2 Jan 2025 10:38:54 -0800 Subject: [PATCH] [xla:collectives] Remove Send/Recv Ptr To/From peer from Communicator API Sending and receiving pointers is not a part of generic communicator API. PiperOrigin-RevId: 711464151 --- .../gpu/collectives/nccl_communicator.cc | 25 ------------ .../gpu/collectives/nccl_communicator.h | 6 --- .../xla/xla/core/collectives/communicator.h | 8 ---- third_party/xla/xla/service/gpu/runtime/BUILD | 1 + .../gpu/runtime/nccl_all_to_all_thunk.cc | 39 ++++++++++++++----- 5 files changed, 30 insertions(+), 49 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/collectives/nccl_communicator.cc b/third_party/xla/xla/backends/gpu/collectives/nccl_communicator.cc index 781c230e253b87..de27fac8a5facf 100644 --- a/third_party/xla/xla/backends/gpu/collectives/nccl_communicator.cc +++ b/third_party/xla/xla/backends/gpu/collectives/nccl_communicator.cc @@ -312,18 +312,6 @@ absl::Status NcclCommunicator::Send(se::DeviceMemoryBase send_buffer, peer.value(), comm_, se::gpu::AsGpuStreamValue(stream))); } -absl::Status NcclCommunicator::SendPtrToPeer(void* ptr, RankId peer, - const Executor& executor) { - TF_ASSIGN_OR_RETURN(se::Stream * stream, ToStream(executor)); - - VLOG(3) << absl::StreamFormat( - "Launch NCCL RecvPtrFromPeer operation on device #%d; " - "peer=%d; comm=%p; stream=%p", - stream->parent()->device_ordinal(), peer.value(), comm_, stream); - return XLA_NCCL_STATUS(ncclSend(ptr, 1, ncclUint64, peer.value(), comm_, - se::gpu::AsGpuStreamValue(stream))); -} - absl::Status NcclCommunicator::Recv(se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, size_t count, RankId peer, const Executor& executor) { @@ -343,19 +331,6 @@ absl::Status NcclCommunicator::Recv(se::DeviceMemoryBase recv_buffer, peer.value(), comm_, se::gpu::AsGpuStreamValue(stream))); } -absl::Status NcclCommunicator::RecvPtrFromPeer(void* ptr, RankId peer, - const 
Executor& executor) { - TF_ASSIGN_OR_RETURN(se::Stream * stream, ToStream(executor)); - - VLOG(3) << absl::StreamFormat( - "Launch NCCL RecvPtrFromPeer operation on device #%d; " - "peer=%d; comm=%p; stream=%p", - stream->parent()->device_ordinal(), peer.value(), comm_, stream); - - return XLA_NCCL_STATUS(ncclRecv(ptr, 1, ncclUint64, peer.value(), comm_, - se::gpu::AsGpuStreamValue(stream))); -} - std::string NcclCommunicator::ToString() const { return absl::StrFormat("NccCommunicator(ncclComm_t=%p)", comm_); } diff --git a/third_party/xla/xla/backends/gpu/collectives/nccl_communicator.h b/third_party/xla/xla/backends/gpu/collectives/nccl_communicator.h index 4ff9e79cef470b..b6dda86a8e72fd 100644 --- a/third_party/xla/xla/backends/gpu/collectives/nccl_communicator.h +++ b/third_party/xla/xla/backends/gpu/collectives/nccl_communicator.h @@ -78,15 +78,9 @@ class NcclCommunicator : public Communicator { absl::Status Send(se::DeviceMemoryBase send_buffer, PrimitiveType dtype, size_t count, RankId peer, const Executor& executor) final; - absl::Status SendPtrToPeer(void* ptr, RankId peer, - const Executor& executor) final; - absl::Status Recv(se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, size_t count, RankId peer, const Executor& executor) final; - absl::Status RecvPtrFromPeer(void* ptr, RankId peer, - const Executor& executor) final; - std::string ToString() const final; ncclComm_t comm() const { return comm_; } diff --git a/third_party/xla/xla/core/collectives/communicator.h b/third_party/xla/xla/core/collectives/communicator.h index 7d2d3cb681567b..529c5d28d79f75 100644 --- a/third_party/xla/xla/core/collectives/communicator.h +++ b/third_party/xla/xla/core/collectives/communicator.h @@ -106,19 +106,11 @@ class Communicator { PrimitiveType dtype, size_t count, RankId peer, const Executor& executor) = 0; - // Send a pointer `ptr` to rank `peer`. 
- virtual absl::Status SendPtrToPeer(void* ptr, RankId peer, - const Executor& executor) = 0; - // Receive data from rank `peer` into `recv_buff`. virtual absl::Status Recv(se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, size_t count, RankId peer, const Executor& executor) = 0; - // Receive a pointer from rank `peer` into `ptr`. - virtual absl::Status RecvPtrFromPeer(void* ptr, RankId peer, - const Executor& executor) = 0; - virtual std::string ToString() const = 0; }; diff --git a/third_party/xla/xla/service/gpu/runtime/BUILD b/third_party/xla/xla/service/gpu/runtime/BUILD index 2cefa6f3bf320b..93ec5730424146 100644 --- a/third_party/xla/xla/service/gpu/runtime/BUILD +++ b/third_party/xla/xla/service/gpu/runtime/BUILD @@ -755,6 +755,7 @@ cc_library( "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", ], diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc b/third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc index f20447d0481cd5..4e49a3b9f31320 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc @@ -25,6 +25,7 @@ limitations under the License. 
#include "absl/container/flat_hash_map.h" #include "absl/container/node_hash_map.h" #include "absl/status/status.h" +#include "absl/strings/str_format.h" #include "absl/strings/substitute.h" #include "absl/synchronization/mutex.h" #include "xla/backends/gpu/collectives/gpu_clique_key.h" @@ -277,6 +278,26 @@ absl::Status RunAllToAll(GpuCollectives* collectives, bool has_split_dimension, return collectives->GroupEnd(); } +static absl::Status SendPtrToPeer(void* ptr, RankId peer, Communicator* comm, + se::Stream& stream) { + VLOG(3) << absl::StreamFormat( + "SendPtrToPeer on device #%d; peer=%d; comm=%p; stream=%p", + stream.parent()->device_ordinal(), peer.value(), comm, &stream); + + return comm->Send(se::DeviceMemoryBase(ptr, sizeof(void*)), U64, 1, peer, + GpuCollectives::On(stream)); +} + +static absl::Status RecvPtrFromPeer(void* ptr, RankId peer, Communicator* comm, + se::Stream& stream) { + VLOG(3) << absl::StreamFormat( + "RecvPtrFromPeer on device #%d; peer=%d; comm=%p; stream=%p", + stream.parent()->device_ordinal(), peer.value(), comm, &stream); + + return comm->Recv(se::DeviceMemoryBase(ptr, sizeof(void*)), U64, 1, peer, + GpuCollectives::On(stream)); +} + absl::Status RunMemCpyAllToAll( GpuCollectives* collectives, bool has_split_dimension, std::vector& buffers, se::Stream& stream, @@ -309,11 +330,10 @@ absl::Status RunMemCpyAllToAll( peer * chunk_elements, chunk_elements); send_pointer_map[peer] = (uint64_t)recv_slice.opaque(); - TF_RETURN_IF_ERROR(comm->SendPtrToPeer( - &send_pointer_map[peer], RankId(peer), GpuCollectives::On(stream))); - TF_RETURN_IF_ERROR(comm->RecvPtrFromPeer(&receive_pointer_map[peer], - RankId(peer), - GpuCollectives::On(stream))); + TF_RETURN_IF_ERROR( + SendPtrToPeer(&send_pointer_map[peer], RankId(peer), comm, stream)); + TF_RETURN_IF_ERROR(RecvPtrFromPeer(&receive_pointer_map[peer], + RankId(peer), comm, stream)); } TF_RETURN_IF_ERROR(collectives->GroupEnd()); TF_RETURN_IF_ERROR(stream.BlockHostUntilDone()); @@ -337,11 
+357,10 @@ absl::Status RunMemCpyAllToAll( send_pointer_map[peer] = (uint64_t)buffers[peer].destination_buffer.opaque(); - TF_RETURN_IF_ERROR(comm->SendPtrToPeer( - &send_pointer_map[peer], RankId(peer), GpuCollectives::On(stream))); - TF_RETURN_IF_ERROR(comm->RecvPtrFromPeer(&receive_pointer_map[peer], - RankId(peer), - GpuCollectives::On(stream))); + TF_RETURN_IF_ERROR( + SendPtrToPeer(&send_pointer_map[peer], RankId(peer), comm, stream)); + TF_RETURN_IF_ERROR(RecvPtrFromPeer(&receive_pointer_map[peer], + RankId(peer), comm, stream)); } TF_RETURN_IF_ERROR(collectives->GroupEnd()); TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());