diff --git a/xla/service/gpu/nccl_p2p_thunk_common.h b/xla/service/gpu/nccl_p2p_thunk_common.h index 25814402491fc..2f37306ef4284 100644 --- a/xla/service/gpu/nccl_p2p_thunk_common.h +++ b/xla/service/gpu/nccl_p2p_thunk_common.h @@ -67,9 +67,10 @@ std::enable_if_t || std::is_same_v, CollectiveOpGroupMode> GetGroupModeForSendRecv(OpT op) { - return GetCollectiveOpGroupMode(op.getChannelHandle().getHandle() > 1, - std::nullopt) - .value(); + // return GetCollectiveOpGroupMode(op.getChannelHandle().getHandle() > 1, + // std::nullopt) + // .value(); + return CollectiveOpGroupMode::kFlattenedID; } // Constructs the NcclP2PConfig for Send and Recv. @@ -104,15 +105,16 @@ GetNcclP2PConfigForSendRecv(OpT op, int64_t replica_count, } // All execution instances of a send/recv together form a replica group. - const int64_t num_participants = - config.group_mode == CollectiveOpGroupMode::kCrossReplica - ? replica_count - : partition_count; + // const int64_t num_participants = + // config.group_mode == CollectiveOpGroupMode::kCrossReplica + // ? replica_count + // : partition_count; + // const int64_t num_participants = 2; config.replica_groups.emplace_back(); ReplicaGroup& replica_group = config.replica_groups.front(); - for (int i = 0; i < num_participants; ++i) { - replica_group.add_replica_ids(i); - } + // for (int i = 0; i < num_participants; ++i) { + // replica_group.add_replica_ids(i); + // } auto source_target_pairs = GetSourceTargetPairs(op.getFrontendAttributes()); TF_CHECK_OK(source_target_pairs.status()); @@ -125,8 +127,14 @@ GetNcclP2PConfigForSendRecv(OpT op, int64_t replica_count, source; p2p_config.id_to_source_target.insert({source, {}}).first->second.target = target; + if (source <= target) { + replica_group.add_replica_ids(source); + replica_group.add_replica_ids(target); + } else { + replica_group.add_replica_ids(target); + replica_group.add_replica_ids(source); + } } - return p2p_config; } diff --git a/xla/service/gpu/nccl_recv_thunk.cc b/xla/service/gpu/nccl_recv_thunk.cc index a50bdf52f5d07..c83a7e7274fea 100644 --- a/xla/service/gpu/nccl_recv_thunk.cc +++ b/xla/service/gpu/nccl_recv_thunk.cc @@ -121,21 +121,21 @@ Status RunRecv(NcclP2PConfig::SourceTargetMapEntry source_target, se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(&stream); // Receive data from the source peer to the destination buffer. - if (source_id) { - VLOG(3) << absl::StreamFormat( - "%s : Calling ncclRecv(recvbuff=%p, count=%d, peer=%d comm=%p, " - "stream=%p)", - device_string, dest_addr.opaque(), element_count, *source_id, - static_cast(comm), gpu_stream); - XLA_CUDA_RETURN_IF_ERROR(ncclRecv(dest_addr.opaque(), element_count, dtype, - *source_id, comm, gpu_stream)); - } else { - // If there is no source peer, i.e. no sender to this instance, zero out - // the destination buffer. - VLOG(3) << absl::StreamFormat("%s : collective-Permute: Issuing MemZero", - device_string); - stream.ThenMemZero(&dest_addr, dest_addr.size()); - } + // if (source_id) { + VLOG(3) << absl::StreamFormat( + "%s : Calling ncclRecv(recvbuff=%p, count=%d, peer=%d comm=%p, " + "stream=%p)", + device_string, dest_addr.opaque(), element_count, *source_id, + static_cast(comm), gpu_stream); + XLA_CUDA_RETURN_IF_ERROR(ncclRecv(dest_addr.opaque(), element_count, dtype, + *source_id, comm, gpu_stream)); + // } else { + // // If there is no source peer, i.e. no sender to this instance, zero out + // // the destination buffer. + // VLOG(3) << absl::StreamFormat("%s : collective-Permute: Issuing MemZero", + // device_string); + // stream.ThenMemZero(&dest_addr, dest_addr.size()); + // } return OkStatus(); #else // XLA_ENABLE_XCCL return Unimplemented( diff --git a/xla/service/gpu/nccl_send_thunk.cc b/xla/service/gpu/nccl_send_thunk.cc index 159089bd977f3..76c793b2cf49c 100644 --- a/xla/service/gpu/nccl_send_thunk.cc +++ b/xla/service/gpu/nccl_send_thunk.cc @@ -121,15 +121,15 @@ Status RunSend(NcclP2PConfig::SourceTargetMapEntry source_target, se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(&stream); // Send source buffer to target peer if needed. - if (target_id) { - VLOG(3) << absl::StreamFormat( - "%s : Calling ncclSend(sendbuff=%p, count=%d, peer=%d " - "comm=%p, stream=%p)", - device_string, src_addr.opaque(), element_count, *target_id, - static_cast(comm), gpu_stream); - XLA_CUDA_RETURN_IF_ERROR(ncclSend(src_addr.opaque(), element_count, dtype, - *target_id, comm, gpu_stream)); - } + // if (target_id) { + VLOG(3) << absl::StreamFormat( + "%s : Calling ncclSend(sendbuff=%p, count=%d, peer=%d " + "comm=%p, stream=%p)", + device_string, src_addr.opaque(), element_count, *target_id, + static_cast(comm), gpu_stream); + XLA_CUDA_RETURN_IF_ERROR(ncclSend(src_addr.opaque(), element_count, dtype, + *target_id, comm, gpu_stream)); + // } return OkStatus(); #else // XLA_ENABLE_XCCL return Unimplemented(