Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support gpu send/recv thunk #1

Merged
merged 1 commit into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions xla/service/gpu/nccl_p2p_thunk_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,10 @@ std::enable_if_t<std::is_same_v<OpT, mlir::lmhlo::SendOp> ||
std::is_same_v<OpT, mlir::lmhlo::RecvOp>,
CollectiveOpGroupMode>
GetGroupModeForSendRecv(OpT op) {
return GetCollectiveOpGroupMode(op.getChannelHandle().getHandle() > 1,
std::nullopt)
.value();
// return GetCollectiveOpGroupMode(op.getChannelHandle().getHandle() > 1,
// std::nullopt)
// .value();
return CollectiveOpGroupMode::kFlattenedID;
}

// Constructs the NcclP2PConfig for Send and Recv.
Expand Down Expand Up @@ -104,15 +105,16 @@ GetNcclP2PConfigForSendRecv(OpT op, int64_t replica_count,
}

// All execution instances of a send/recv together form a replica group.
const int64_t num_participants =
config.group_mode == CollectiveOpGroupMode::kCrossReplica
? replica_count
: partition_count;
// const int64_t num_participants =
// config.group_mode == CollectiveOpGroupMode::kCrossReplica
// ? replica_count
// : partition_count;
// const int64_t num_participants = 2;
config.replica_groups.emplace_back();
ReplicaGroup& replica_group = config.replica_groups.front();
for (int i = 0; i < num_participants; ++i) {
replica_group.add_replica_ids(i);
}
// for (int i = 0; i < num_participants; ++i) {
// replica_group.add_replica_ids(i);
// }

auto source_target_pairs = GetSourceTargetPairs(op.getFrontendAttributes());
TF_CHECK_OK(source_target_pairs.status());
Expand All @@ -125,8 +127,14 @@ GetNcclP2PConfigForSendRecv(OpT op, int64_t replica_count,
source;
p2p_config.id_to_source_target.insert({source, {}}).first->second.target =
target;
if (source <= target) {
replica_group.add_replica_ids(source);
replica_group.add_replica_ids(target);
} else {
replica_group.add_replica_ids(target);
replica_group.add_replica_ids(source);
}
}

return p2p_config;
}

Expand Down
30 changes: 15 additions & 15 deletions xla/service/gpu/nccl_recv_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,21 +121,21 @@ Status RunRecv(NcclP2PConfig::SourceTargetMapEntry source_target,
se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(&stream);

// Receive data from the source peer to the destination buffer.
if (source_id) {
VLOG(3) << absl::StreamFormat(
"%s : Calling ncclRecv(recvbuff=%p, count=%d, peer=%d comm=%p, "
"stream=%p)",
device_string, dest_addr.opaque(), element_count, *source_id,
static_cast<const void*>(comm), gpu_stream);
XLA_CUDA_RETURN_IF_ERROR(ncclRecv(dest_addr.opaque(), element_count, dtype,
*source_id, comm, gpu_stream));
} else {
// If there is no source peer, i.e. no sender to this instance, zero out
// the destination buffer.
VLOG(3) << absl::StreamFormat("%s : collective-Permute: Issuing MemZero",
device_string);
stream.ThenMemZero(&dest_addr, dest_addr.size());
}
// if (source_id) {
VLOG(3) << absl::StreamFormat(
"%s : Calling ncclRecv(recvbuff=%p, count=%d, peer=%d comm=%p, "
"stream=%p)",
device_string, dest_addr.opaque(), element_count, *source_id,
static_cast<const void*>(comm), gpu_stream);
XLA_CUDA_RETURN_IF_ERROR(ncclRecv(dest_addr.opaque(), element_count, dtype,
*source_id, comm, gpu_stream));
// } else {
// // If there is no source peer, i.e. no sender to this instance, zero out
// // the destination buffer.
// VLOG(3) << absl::StreamFormat("%s : collective-Permute: Issuing MemZero",
// device_string);
// stream.ThenMemZero(&dest_addr, dest_addr.size());
// }
return OkStatus();
#else // XLA_ENABLE_XCCL
return Unimplemented(
Expand Down
18 changes: 9 additions & 9 deletions xla/service/gpu/nccl_send_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,15 +121,15 @@ Status RunSend(NcclP2PConfig::SourceTargetMapEntry source_target,
se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(&stream);

// Send source buffer to target peer if needed.
if (target_id) {
VLOG(3) << absl::StreamFormat(
"%s : Calling ncclSend(sendbuff=%p, count=%d, peer=%d "
"comm=%p, stream=%p)",
device_string, src_addr.opaque(), element_count, *target_id,
static_cast<const void*>(comm), gpu_stream);
XLA_CUDA_RETURN_IF_ERROR(ncclSend(src_addr.opaque(), element_count, dtype,
*target_id, comm, gpu_stream));
}
// if (target_id) {
VLOG(3) << absl::StreamFormat(
"%s : Calling ncclSend(sendbuff=%p, count=%d, peer=%d "
"comm=%p, stream=%p)",
device_string, src_addr.opaque(), element_count, *target_id,
static_cast<const void*>(comm), gpu_stream);
XLA_CUDA_RETURN_IF_ERROR(ncclSend(src_addr.opaque(), element_count, dtype,
*target_id, comm, gpu_stream));
// }
return OkStatus();
#else // XLA_ENABLE_XCCL
return Unimplemented(
Expand Down
Loading