[xla:collectives] Remove Send/Recv Ptr To/From peer from Communicator API

Sending and receiving pointers is not part of the generic communicator API; the all-to-all thunk now implements pointer exchange on top of the generic Send and Recv.

PiperOrigin-RevId: 711464151
ezhulenev authored and tensorflower-gardener committed Jan 2, 2025
1 parent 56dc2ba commit 28c0aae
Showing 5 changed files with 30 additions and 49 deletions.
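
The pattern this commit preserves is worth spelling out: a raw pointer occupies sizeof(void*) bytes, so it can travel through the generic point-to-point primitives as a single U64 element wrapped in an se::DeviceMemoryBase. Below is a minimal sketch of that idea, composed from the two static helpers the commit adds to nccl_all_to_all_thunk.cc; the ExchangePtrWithPeer name is illustrative, not part of the XLA API.

// Sketch only: assumes the same headers and names as nccl_all_to_all_thunk.cc
// below (Communicator, RankId, GpuCollectives, se::DeviceMemoryBase,
// TF_RETURN_IF_ERROR). ExchangePtrWithPeer is a hypothetical helper, not
// upstream code.
static absl::Status ExchangePtrWithPeer(void* send_ptr, void* recv_ptr,
                                        RankId peer, Communicator* comm,
                                        se::Stream& stream) {
  // Treat each host word holding a pointer value as a 1-element U64 buffer.
  se::DeviceMemoryBase send_buf(send_ptr, sizeof(void*));
  se::DeviceMemoryBase recv_buf(recv_ptr, sizeof(void*));
  // Generic Send/Recv replace the removed SendPtrToPeer/RecvPtrFromPeer.
  TF_RETURN_IF_ERROR(
      comm->Send(send_buf, U64, 1, peer, GpuCollectives::On(stream)));
  return comm->Recv(recv_buf, U64, 1, peer, GpuCollectives::On(stream));
}

As in the thunk below, such exchanges run inside a collectives group and are followed by stream.BlockHostUntilDone() before the received pointer values are used on the host.
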
25 changes: 0 additions & 25 deletions third_party/xla/xla/backends/gpu/collectives/nccl_communicator.cc
@@ -312,18 +312,6 @@ absl::Status NcclCommunicator::Send(se::DeviceMemoryBase send_buffer,
       peer.value(), comm_, se::gpu::AsGpuStreamValue(stream)));
 }
 
-absl::Status NcclCommunicator::SendPtrToPeer(void* ptr, RankId peer,
-                                             const Executor& executor) {
-  TF_ASSIGN_OR_RETURN(se::Stream * stream, ToStream(executor));
-
-  VLOG(3) << absl::StreamFormat(
-      "Launch NCCL RecvPtrFromPeer operation on device #%d; "
-      "peer=%d; comm=%p; stream=%p",
-      stream->parent()->device_ordinal(), peer.value(), comm_, stream);
-  return XLA_NCCL_STATUS(ncclSend(ptr, 1, ncclUint64, peer.value(), comm_,
-                                  se::gpu::AsGpuStreamValue(stream)));
-}
-
 absl::Status NcclCommunicator::Recv(se::DeviceMemoryBase recv_buffer,
                                     PrimitiveType dtype, size_t count,
                                     RankId peer, const Executor& executor) {
@@ -343,19 +331,6 @@ absl::Status NcclCommunicator::Recv(se::DeviceMemoryBase recv_buffer,
       peer.value(), comm_, se::gpu::AsGpuStreamValue(stream)));
 }
 
-absl::Status NcclCommunicator::RecvPtrFromPeer(void* ptr, RankId peer,
-                                               const Executor& executor) {
-  TF_ASSIGN_OR_RETURN(se::Stream * stream, ToStream(executor));
-
-  VLOG(3) << absl::StreamFormat(
-      "Launch NCCL RecvPtrFromPeer operation on device #%d; "
-      "peer=%d; comm=%p; stream=%p",
-      stream->parent()->device_ordinal(), peer.value(), comm_, stream);
-
-  return XLA_NCCL_STATUS(ncclRecv(ptr, 1, ncclUint64, peer.value(), comm_,
-                                  se::gpu::AsGpuStreamValue(stream)));
-}
-
 std::string NcclCommunicator::ToString() const {
   return absl::StrFormat("NcclCommunicator(ncclComm_t=%p)", comm_);
 }
6 changes: 0 additions & 6 deletions third_party/xla/xla/backends/gpu/collectives/nccl_communicator.h
@@ -78,15 +78,9 @@ class NcclCommunicator : public Communicator {
   absl::Status Send(se::DeviceMemoryBase send_buffer, PrimitiveType dtype,
                     size_t count, RankId peer, const Executor& executor) final;
 
-  absl::Status SendPtrToPeer(void* ptr, RankId peer,
-                             const Executor& executor) final;
-
   absl::Status Recv(se::DeviceMemoryBase recv_buffer, PrimitiveType dtype,
                     size_t count, RankId peer, const Executor& executor) final;
 
-  absl::Status RecvPtrFromPeer(void* ptr, RankId peer,
-                               const Executor& executor) final;
-
   std::string ToString() const final;
 
   ncclComm_t comm() const { return comm_; }
8 changes: 0 additions & 8 deletions third_party/xla/xla/core/collectives/communicator.h
@@ -106,19 +106,11 @@ class Communicator {
                             PrimitiveType dtype, size_t count, RankId peer,
                             const Executor& executor) = 0;
 
-  // Send a pointer `ptr` to rank `peer`.
-  virtual absl::Status SendPtrToPeer(void* ptr, RankId peer,
-                                     const Executor& executor) = 0;
-
   // Receive data from rank `peer` into `recv_buffer`.
   virtual absl::Status Recv(se::DeviceMemoryBase recv_buffer,
                             PrimitiveType dtype, size_t count, RankId peer,
                             const Executor& executor) = 0;
 
-  // Receive a pointer from rank `peer` into `ptr`.
-  virtual absl::Status RecvPtrFromPeer(void* ptr, RankId peer,
-                                       const Executor& executor) = 0;
-
   virtual std::string ToString() const = 0;
 };
 
1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/runtime/BUILD
@@ -755,6 +755,7 @@ cc_library(
"@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
],
39 changes: 29 additions & 10 deletions third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/container/node_hash_map.h"
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "absl/strings/substitute.h"
#include "absl/synchronization/mutex.h"
#include "xla/backends/gpu/collectives/gpu_clique_key.h"
@@ -277,6 +278,26 @@ absl::Status RunAllToAll(GpuCollectives* collectives, bool has_split_dimension,
   return collectives->GroupEnd();
 }
 
+static absl::Status SendPtrToPeer(void* ptr, RankId peer, Communicator* comm,
+                                  se::Stream& stream) {
+  VLOG(3) << absl::StreamFormat(
+      "SendPtrToPeer on device #%d; peer=%d; comm=%p; stream=%p",
+      stream.parent()->device_ordinal(), peer.value(), comm, &stream);
+
+  return comm->Send(se::DeviceMemoryBase(ptr, sizeof(void*)), U64, 1, peer,
+                    GpuCollectives::On(stream));
+}
+
+static absl::Status RecvPtrFromPeer(void* ptr, RankId peer, Communicator* comm,
+                                    se::Stream& stream) {
+  VLOG(3) << absl::StreamFormat(
+      "RecvPtrFromPeer on device #%d; peer=%d; comm=%p; stream=%p",
+      stream.parent()->device_ordinal(), peer.value(), comm, &stream);
+
+  return comm->Recv(se::DeviceMemoryBase(ptr, sizeof(void*)), U64, 1, peer,
+                    GpuCollectives::On(stream));
+}
+
 absl::Status RunMemCpyAllToAll(
     GpuCollectives* collectives, bool has_split_dimension,
     std::vector<DeviceBufferPair>& buffers, se::Stream& stream,
@@ -309,11 +330,10 @@ absl::Status RunMemCpyAllToAll(
           peer * chunk_elements, chunk_elements);
       send_pointer_map[peer] = (uint64_t)recv_slice.opaque();
 
-      TF_RETURN_IF_ERROR(comm->SendPtrToPeer(
-          &send_pointer_map[peer], RankId(peer), GpuCollectives::On(stream)));
-      TF_RETURN_IF_ERROR(comm->RecvPtrFromPeer(&receive_pointer_map[peer],
-                                               RankId(peer),
-                                               GpuCollectives::On(stream)));
+      TF_RETURN_IF_ERROR(
+          SendPtrToPeer(&send_pointer_map[peer], RankId(peer), comm, stream));
+      TF_RETURN_IF_ERROR(RecvPtrFromPeer(&receive_pointer_map[peer],
+                                         RankId(peer), comm, stream));
     }
     TF_RETURN_IF_ERROR(collectives->GroupEnd());
     TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());
@@ -337,11 +357,10 @@ absl::Status RunMemCpyAllToAll(
       send_pointer_map[peer] =
           (uint64_t)buffers[peer].destination_buffer.opaque();
 
-      TF_RETURN_IF_ERROR(comm->SendPtrToPeer(
-          &send_pointer_map[peer], RankId(peer), GpuCollectives::On(stream)));
-      TF_RETURN_IF_ERROR(comm->RecvPtrFromPeer(&receive_pointer_map[peer],
-                                               RankId(peer),
-                                               GpuCollectives::On(stream)));
+      TF_RETURN_IF_ERROR(
+          SendPtrToPeer(&send_pointer_map[peer], RankId(peer), comm, stream));
+      TF_RETURN_IF_ERROR(RecvPtrFromPeer(&receive_pointer_map[peer],
+                                         RankId(peer), comm, stream));
     }
     TF_RETURN_IF_ERROR(collectives->GroupEnd());
     TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());
