Skip to content

Commit

Permalink
Fix xla.gpu.send/recv (#2)
Browse files Browse the repository at this point in the history
* update p2p thunk

* fix nccl p2p

---------

Co-authored-by: mochen.bmc <[email protected]>
  • Loading branch information
wbmc and mochen.bmc authored Dec 5, 2023
1 parent 5b044b1 commit 77c1547
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 26 deletions.
2 changes: 2 additions & 0 deletions xla/service/gpu/nccl_collective_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ struct NcclCollectiveConfig {
RendezvousKey::CollectiveOpKind collective_op_kind;
int64_t op_id;
CollectiveOpGroupMode group_mode;
int partition_count;
int replica_count;

template <typename OpT>
void SetCollectiveOpKindAndID(OpT op);
Expand Down
21 changes: 6 additions & 15 deletions xla/service/gpu/nccl_p2p_thunk_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,16 +105,14 @@ GetNcclP2PConfigForSendRecv(OpT op, int64_t replica_count,
}

// All execution instances of a send/recv together form a replica group.
// const int64_t num_participants =
// config.group_mode == CollectiveOpGroupMode::kCrossReplica
// ? replica_count
// : partition_count;
// const int64_t num_participants = 2;
config.replica_count = replica_count;
config.partition_count = partition_count;
const int64_t num_participants = replica_count * partition_count;
config.replica_groups.emplace_back();
ReplicaGroup& replica_group = config.replica_groups.front();
// for (int i = 0; i < num_participants; ++i) {
// replica_group.add_replica_ids(i);
// }
for (int i = 0; i < num_participants; ++i) {
replica_group.add_replica_ids(i);
}

auto source_target_pairs = GetSourceTargetPairs(op.getFrontendAttributes());
TF_CHECK_OK(source_target_pairs.status());
Expand All @@ -127,13 +125,6 @@ GetNcclP2PConfigForSendRecv(OpT op, int64_t replica_count,
source;
p2p_config.id_to_source_target.insert({source, {}}).first->second.target =
target;
if (source <= target) {
replica_group.add_replica_ids(source);
replica_group.add_replica_ids(target);
} else {
replica_group.add_replica_ids(target);
replica_group.add_replica_ids(source);
}
}
return p2p_config;
}
Expand Down
12 changes: 9 additions & 3 deletions xla/service/gpu/nccl_recv_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,16 @@ Status NcclRecvThunk::RunNcclCollective(const ExecuteParams& params,
TF_ASSIGN_OR_RETURN(
const DeviceAssignment::LogicalID current_logical_id,
params.nccl_params.device_assn->LogicalIdForDevice(global_device_id));
// const int64_t current_id =
// config_.config.group_mode == CollectiveOpGroupMode::kCrossReplica
// ? current_logical_id.replica_id
// : current_logical_id.computation_id;
const int64_t current_id =
config_.config.group_mode == CollectiveOpGroupMode::kCrossReplica
? current_logical_id.replica_id
: current_logical_id.computation_id;
current_logical_id.replica_id * config_.config.partition_count +
current_logical_id.computation_id;
VLOG(3) << "Performing Recv, replica_id: " << current_logical_id.replica_id
<< ", partition_count: " << config_.config.partition_count
<< ", computation_id: " << current_logical_id.computation_id;
std::string device_string = GetDeviceString(params.nccl_params);

const NcclP2PConfig::SourceTargetMapEntry source_target =
Expand Down
14 changes: 10 additions & 4 deletions xla/service/gpu/nccl_send_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,16 @@ Status NcclSendThunk::RunNcclCollective(const ExecuteParams& params,
TF_ASSIGN_OR_RETURN(
const DeviceAssignment::LogicalID current_logical_id,
params.nccl_params.device_assn->LogicalIdForDevice(global_device_id));
// const int64_t current_id =
// config_.config.group_mode == CollectiveOpGroupMode::kCrossReplica
// ? current_logical_id.replica_id
// : current_logical_id.computation_id;
const int64_t current_id =
config_.config.group_mode == CollectiveOpGroupMode::kCrossReplica
? current_logical_id.replica_id
: current_logical_id.computation_id;
current_logical_id.replica_id * config_.config.partition_count +
current_logical_id.computation_id;
VLOG(3) << "Performing Send, replica_id: " << current_logical_id.replica_id
<< ", partition_count: " << config_.config.partition_count
<< ", computation_id: " << current_logical_id.computation_id;
std::string device_string = GetDeviceString(params.nccl_params);

const NcclP2PConfig::SourceTargetMapEntry source_target =
Expand All @@ -103,7 +109,7 @@ Status RunSend(NcclP2PConfig::SourceTargetMapEntry source_target,
// to which this instance will copy its data.

int device_ordinal = stream.parent()->device_ordinal();
VLOG(3) << "Performing collective permute from device ordinal: "
VLOG(3) << "Performing Send from device ordinal: "
<< device_ordinal << "current_id " << current_id;

const std::optional<int64_t> target_id = source_target.target;
Expand Down
22 changes: 18 additions & 4 deletions xla/service/gpu/runtime/collectives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -265,10 +265,24 @@ absl::Status P2PImplCommon(const ServiceExecutableRunOptions* run_options,
TF_ASSIGN_OR_RETURN(DeviceAssignment::LogicalID current_logical_id,
params.device_assn->LogicalIdForDevice(global_device_id));

const int64_t current_id = static_cast<CollectiveOpGroupMode>(group_mode) ==
CollectiveOpGroupMode::kCrossReplica
? current_logical_id.replica_id
: current_logical_id.computation_id;
int64_t current_id = 0;
switch (static_cast<CollectiveOpGroupMode>(group_mode)) {
case CollectiveOpGroupMode::kFlattenedID: {
int replica_count = params.device_assn->replica_count();
int computation_count = params.device_assn->computation_count();
current_id = current_logical_id.replica_id * computation_count +
current_logical_id.computation_id;
break;
}
case CollectiveOpGroupMode::kCrossReplica: {
current_id = current_logical_id.replica_id;
break;
}
default: {
current_id = current_logical_id.computation_id;
break;
}
}

NcclP2PConfig::IdToSourceTargetMap id_to_source_target;
for (int i = 0; i < source_peers.size(); ++i) {
Expand Down

0 comments on commit 77c1547

Please sign in to comment.