Skip to content

Commit 03a3ef3

Browse files
authored
MSCCL Multithreaded regression root cause fix (#1347)
* Make sure the target device is used for MSCCL
* Enable single process mode by default to use MSCCL in MT
* Create a per-rank state when GPUs share a thread
1 parent 105ff16 commit 03a3ef3

File tree

3 files changed

+15
-5
lines changed

3 files changed

+15
-5
lines changed

src/misc/msccl/msccl_lifecycle.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
RCCL_PARAM(MscclEnabled, "MSCCL_ENABLE", 1);
3030
RCCL_PARAM(MscclForceEnabled, "MSCCL_FORCE_ENABLE", 0);
31-
RCCL_PARAM(MscclEnableSingleProcess, "MSCCL_ENABLE_SINGLE_PROCESS", 0);
31+
RCCL_PARAM(MscclEnableSingleProcess, "MSCCL_ENABLE_SINGLE_PROCESS", 1);
3232
static const char* mscclAlgoFilePathEnv = "MSCCL_ALGO_FILE_PATH";
3333

3434
bool mscclEnabled() {

src/misc/msccl/msccl_status.cc

+6-4
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,20 @@ static mutex rankStatesMutex;
2727
static unordered_map<int, shared_ptr<mscclRankState>> rankStates;
2828

2929
static inline mscclRankState& mscclGetRankState(int rank) {
30-
static thread_local shared_ptr<mscclRankState> threadRankState = make_shared<mscclRankState>();
31-
30+
// In the unlikely case of negative rank, return a per-thread state
3231
if (rank < 0) {
32+
static thread_local shared_ptr<mscclRankState> threadRankState(new mscclRankState());
3333
return *threadRankState;
3434
}
3535

3636
lock_guard<mutex> lock(rankStatesMutex);
3737

3838
auto rankStateIt = rankStates.find(rank);
3939
if (rankStateIt == rankStates.end()) {
40-
rankStateIt = rankStates.insert(make_pair(rank, make_shared<mscclRankState>(*threadRankState))).first;
41-
rankStateIt->second->rank = rank;
40+
// Create a per rank threadRankState rather than per thread
41+
shared_ptr<mscclRankState> newthreadRankState(new mscclRankState());
42+
newthreadRankState->rank = rank;
43+
rankStateIt = rankStates.insert(make_pair(rank, newthreadRankState)).first;
4244
}
4345
return *(rankStateIt->second);
4446
}

src/msccl.cc

+8
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,14 @@ ncclResult_t mscclRunAlgo_impl(
6161
struct mscclAlgo* hostAlgo = status.hostAlgos[mscclAlgoHandle];
6262
struct mscclAlgo* devAlgo = status.devAlgos[mscclAlgoHandle];
6363

64+
// NCCL adds a lot of guarantees that target device is getting used
65+
// in its group management code, which we entirely skip when MSCCL is used
66+
// Therefore, in single thread multiGPU mode
67+
// setting the device is critical to be sure
68+
// communication is done on the intended device
69+
70+
CUDACHECK(hipSetDevice(comm->cudaDev));
71+
6472
NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream));
6573

6674
NCCLCHECK(mscclSetupCount(hostAlgo, comm, count, dataType));

0 commit comments

Comments
 (0)