A hot fix to disable CE deadlock check (NVIDIA#926)
* A hot fix to disable CE deadlock check

Signed-off-by: Pavel Shamis (Pasha) <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Pavel Shamis (Pasha) <[email protected]>
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people authored Jun 14, 2024
1 parent 9416519 commit d71fc94
Showing 1 changed file with 30 additions and 24 deletions.
54 changes: 30 additions & 24 deletions transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu
@@ -2236,8 +2236,8 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds
const int peer, cudaStream_t stream) {
int peerlocal = peer % comm->nvsize;
void *flagptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, 0);
- void *ce_send_start_ptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, 1);
- void *ce_send_end_ptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, 2);
+ // void *ce_send_start_ptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, 1);
+ // void *ce_send_end_ptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, 2);
bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0);

assert(INTRANODE(peer));
@@ -2251,9 +2251,9 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds
void *dstptr = reinterpret_cast<char *>(comm->peer_ptr[dsthandler][peerlocal]) + dstoffset;

if (comm->use_ce) {
- kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_start_ptr));
+ // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_start_ptr));
CUDACHECK(cudaMemcpyAsync(dstptr, srcptr, bytes, cudaMemcpyDeviceToDevice, stream));
- kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_end_ptr));
+ // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_end_ptr));
}
SETUP_LAUNCH_CONFIG(signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, stream);
int *arg1 = &comm->send_id[peer], *arg2 = reinterpret_cast<int *>(flagptr);
@@ -2274,18 +2274,18 @@ void userbuffers_sendrecv(const int srchandler, const int dsthandler, const size
int send_peerlocal = send_peer % comm->nvsize;
int recv_peerlocal = recv_peer % comm->nvsize;
void *flagptr_send = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 0);
- void *ce_send_start_ptr = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 1);
- void *ce_send_end_ptr = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 2);
+ // void *ce_send_start_ptr = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 1);
+ // void *ce_send_end_ptr = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 2);
void *flagptr_recv = GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 0);

void *send_srcptr = reinterpret_cast<char *>(comm->mem_ptr[srchandler]) + send_offset;
void *send_dstptr =
reinterpret_cast<char *>(comm->peer_ptr[dsthandler][send_peerlocal]) + send_offset;

if (comm->use_ce) {
- kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_start_ptr));
+ // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_start_ptr));
CUDACHECK(cudaMemcpyAsync(send_dstptr, send_srcptr, bytes, cudaMemcpyDeviceToDevice, stream));
- kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_end_ptr));
+ // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_end_ptr));
}
SETUP_LAUNCH_CONFIG(signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, stream);

@@ -2302,10 +2302,12 @@ void userbuffers_sendrecv(const int srchandler, const int dsthandler, const size
uint64_t arg11 = comm->ub_timeout;
int arg12 = send_peerlocal;
int arg13 = recv_peerlocal;
- int *arg14 = reinterpret_cast<int *>(
-     comm->use_ce ? GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 1) : nullptr);
- int *arg15 = reinterpret_cast<int *>(
-     comm->use_ce ? GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 2) : nullptr);
+ int *arg14 = reinterpret_cast<int *>(0 ?  // temporary disable
+     GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 1)
+     : nullptr);
+ int *arg15 = reinterpret_cast<int *>(0 ?  // temporary disable
+     GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 2)
+     : nullptr);
void *kernelArgs[] = {reinterpret_cast<void *>(&arg1), reinterpret_cast<void *>(&arg2),
reinterpret_cast<void *>(&arg3), reinterpret_cast<void *>(&arg4),
reinterpret_cast<void *>(&arg5), reinterpret_cast<void *>(&arg6),
@@ -2328,17 +2330,17 @@ void userbuffers_sendrecv_atomic(const int srchandler, const int dsthandler,
int send_peerlocal = send_peer % comm->nvsize;
int recv_peerlocal = recv_peer % comm->nvsize;
void *flagptr_send = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 0);
- void *ce_send_start_ptr = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 1);
- void *ce_send_end_ptr = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 2);
+ // void *ce_send_start_ptr = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 1);
+ // void *ce_send_end_ptr = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 2);
void *flagptr_recv = GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 0);

void *send_srcptr = reinterpret_cast<char *>(comm->mem_ptr[srchandler]) + send_offset;
void *send_dstptr =
reinterpret_cast<char *>(comm->peer_ptr[dsthandler][send_peerlocal]) + send_offset;
if (comm->use_ce) {
- kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_start_ptr));
+ // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_start_ptr));
CUDACHECK(cudaMemcpyAsync(send_dstptr, send_srcptr, bytes, cudaMemcpyDeviceToDevice, stream));
- kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_end_ptr));
+ // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_end_ptr));
}
SETUP_LAUNCH_CONFIG(signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, stream);

@@ -2356,10 +2358,12 @@ void userbuffers_sendrecv_atomic(const int srchandler, const int dsthandler,
int arg12 = comm->ub_timeout;
int arg13 = send_peerlocal;
int arg14 = recv_peerlocal;
- int *arg15 = reinterpret_cast<int *>(
-     comm->use_ce ? GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 1) : nullptr);
- int *arg16 = reinterpret_cast<int *>(
-     comm->use_ce ? GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 2) : nullptr);
+ int *arg15 = reinterpret_cast<int *>(0 ?  // temporary disable
+     GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 1)
+     : nullptr);
+ int *arg16 = reinterpret_cast<int *>(0 ?  // temporary disable
+     GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 2)
+     : nullptr);
void *kernelArgs[] = {reinterpret_cast<void *>(&arg1), reinterpret_cast<void *>(&arg2),
reinterpret_cast<void *>(&arg3), reinterpret_cast<void *>(&arg4),
reinterpret_cast<void *>(&arg5), reinterpret_cast<void *>(&arg6),
@@ -2447,10 +2451,12 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds
comm->myrank, peer, comm->nvrank, peerlocal,
&comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler], reinterpret_cast<int *>(flagptr),
signalonly || comm->sms, comm->ub_timeout,
- reinterpret_cast<int *>(comm->use_ce ? GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 1)
-     : nullptr),
- reinterpret_cast<int *>(comm->use_ce ? GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 2)
-     : nullptr));
+ reinterpret_cast<int *>(0 ?  // temporary disable
+     GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 1)
+     : nullptr),
+ reinterpret_cast<int *>(0 ?  // temporary disable
+     GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 2)
+     : nullptr));
}
}
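
Note on the pattern being disabled: in all four functions touched above, the copy-engine (CE) path brackets its cudaMemcpyAsync with kuserbuffers_inc launches that bump per-peer start/end counters, and the receive side is handed pointers to those counters, presumably so a copy that started but never completed can be flagged as a deadlock. The hot fix comments out the increments and hard-codes the receiver's counter pointers to nullptr. The sketch below only illustrates that shape in isolation; it is not the TransformerEngine API, and counter_inc, kCeDeadlockCheck, and ce_send_copy are made-up stand-ins for kuserbuffers_inc, the disabled check, and the send path.

// Toy sketch (assumed names, not the real userbuffers code): the pattern this
// commit disables. kCeDeadlockCheck stands in for the hard-coded "0"; counter_inc
// is assumed to bump a device-visible counter that the receiving rank could poll.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void counter_inc(int *counter) { atomicAdd(counter, 1); }

static constexpr bool kCeDeadlockCheck = false;  // the hot fix in effect

void ce_send_copy(void *dst, const void *src, size_t bytes,
                  int *ce_start, int *ce_end, cudaStream_t stream) {
  if (kCeDeadlockCheck) counter_inc<<<1, 1, 0, stream>>>(ce_start);  // copy began
  cudaError_t err = cudaMemcpyAsync(dst, src, bytes, cudaMemcpyDeviceToDevice, stream);
  if (err != cudaSuccess) std::fprintf(stderr, "cudaMemcpyAsync: %s\n", cudaGetErrorString(err));
  if (kCeDeadlockCheck) counter_inc<<<1, 1, 0, stream>>>(ce_end);  // runs after the copy on this stream
}

int main() {
  const size_t bytes = 1 << 20;
  void *src = nullptr, *dst = nullptr;
  int *ce_start = nullptr, *ce_end = nullptr;
  cudaMalloc(&src, bytes);
  cudaMalloc(&dst, bytes);
  cudaMalloc(&ce_start, sizeof(int));
  cudaMalloc(&ce_end, sizeof(int));
  cudaMemset(ce_start, 0, sizeof(int));
  cudaMemset(ce_end, 0, sizeof(int));
  ce_send_copy(dst, src, bytes, ce_start, ce_end, /*stream=*/0);
  cudaDeviceSynchronize();
  cudaFree(src);
  cudaFree(dst);
  cudaFree(ce_start);
  cudaFree(ce_end);
  return 0;
}

In the diff itself, the equivalent of kCeDeadlockCheck is simply commented-out code plus a literal 0 in the ternaries, which is why the receive-side counter pointers become nullptr unconditionally.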

