From d71fc946c7ec3488056925685484cf904df77456 Mon Sep 17 00:00:00 2001
From: "Pavel Shamis (Pasha)"
Date: Fri, 14 Jun 2024 10:27:28 -0700
Subject: [PATCH] A hot fix to disable CE deadlock check (#926)

* A hot fix to disable CE deadlock check

Signed-off-by: Pavel Shamis (Pasha)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Pavel Shamis (Pasha)
Signed-off-by: Kirthi Shankar Sivamani
Co-authored-by: Kirthi Shankar Sivamani
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 .../pytorch/csrc/userbuffers/userbuffers.cu | 54 ++++++++++---------
 1 file changed, 30 insertions(+), 24 deletions(-)

diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu
index 7632e69a0a..716874176d 100644
--- a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu
+++ b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu
@@ -2236,8 +2236,8 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds
                       const int peer, cudaStream_t stream) {
   int peerlocal = peer % comm->nvsize;
   void *flagptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, 0);
-  void *ce_send_start_ptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, 1);
-  void *ce_send_end_ptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, 2);
+  // void *ce_send_start_ptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, 1);
+  // void *ce_send_end_ptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, 2);
   bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0);
 
   assert(INTRANODE(peer));
@@ -2251,9 +2251,9 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds
     void *dstptr = reinterpret_cast<char *>(comm->peer_ptr[dsthandler][peerlocal]) + dstoffset;
 
     if (comm->use_ce) {
-      kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_start_ptr));
+      // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_start_ptr));
       CUDACHECK(cudaMemcpyAsync(dstptr, srcptr, bytes, cudaMemcpyDeviceToDevice, stream));
-      kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_end_ptr));
+      // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_end_ptr));
     }
     SETUP_LAUNCH_CONFIG(signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, stream);
     int *arg1 = &comm->send_id[peer], *arg2 = reinterpret_cast<int *>(flagptr);
@@ -2274,8 +2274,8 @@ void userbuffers_sendrecv(const int srchandler, const int dsthandler, const size
   int send_peerlocal = send_peer % comm->nvsize;
   int recv_peerlocal = recv_peer % comm->nvsize;
   void *flagptr_send = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 0);
-  void *ce_send_start_ptr = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 1);
-  void *ce_send_end_ptr = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 2);
+  // void *ce_send_start_ptr = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 1);
+  // void *ce_send_end_ptr = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 2);
   void *flagptr_recv = GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 0);
 
   void *send_srcptr = reinterpret_cast<char *>(comm->mem_ptr[srchandler]) + send_offset;
@@ -2283,9 +2283,9 @@ void userbuffers_sendrecv(const int srchandler, const int dsthandler, const size
       reinterpret_cast<char *>(comm->peer_ptr[dsthandler][send_peerlocal]) + send_offset;
 
   if (comm->use_ce) {
-    kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_start_ptr));
+    // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_start_ptr));
     CUDACHECK(cudaMemcpyAsync(send_dstptr, send_srcptr, bytes, cudaMemcpyDeviceToDevice, stream));
-    kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_end_ptr));
+    // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_end_ptr));
   }
   SETUP_LAUNCH_CONFIG(signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, stream);
 
@@ -2302,10 +2302,12 @@ void userbuffers_sendrecv(const int srchandler, const int dsthandler, const size
   uint64_t arg11 = comm->ub_timeout;
   int arg12 = send_peerlocal;
   int arg13 = recv_peerlocal;
-  int *arg14 = reinterpret_cast<int *>(
-      comm->use_ce ? GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 1) : nullptr);
-  int *arg15 = reinterpret_cast<int *>(
-      comm->use_ce ? GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 2) : nullptr);
+  int *arg14 = reinterpret_cast<int *>(0 ?  // temporary disable
+                                           GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 1)
+                                           : nullptr);
+  int *arg15 = reinterpret_cast<int *>(0 ?  // temporary disable
+                                           GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 2)
+                                           : nullptr);
   void *kernelArgs[] = {reinterpret_cast<void *>(&arg1), reinterpret_cast<void *>(&arg2),
                         reinterpret_cast<void *>(&arg3), reinterpret_cast<void *>(&arg4),
                         reinterpret_cast<void *>(&arg5), reinterpret_cast<void *>(&arg6),
@@ -2328,17 +2330,17 @@ void userbuffers_sendrecv_atomic(const int srchandler, const int dsthandler,
   int send_peerlocal = send_peer % comm->nvsize;
   int recv_peerlocal = recv_peer % comm->nvsize;
   void *flagptr_send = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 0);
-  void *ce_send_start_ptr = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 1);
-  void *ce_send_end_ptr = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 2);
+  // void *ce_send_start_ptr = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 1);
+  // void *ce_send_end_ptr = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 2);
   void *flagptr_recv = GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 0);
 
   void *send_srcptr = reinterpret_cast<char *>(comm->mem_ptr[srchandler]) + send_offset;
   void *send_dstptr =
       reinterpret_cast<char *>(comm->peer_ptr[dsthandler][send_peerlocal]) + send_offset;
 
   if (comm->use_ce) {
-    kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_start_ptr));
+    // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_start_ptr));
     CUDACHECK(cudaMemcpyAsync(send_dstptr, send_srcptr, bytes, cudaMemcpyDeviceToDevice, stream));
-    kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_end_ptr));
+    // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_end_ptr));
   }
   SETUP_LAUNCH_CONFIG(signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, stream);
@@ -2356,10 +2358,12 @@ void userbuffers_sendrecv_atomic(const int srchandler, const int dsthandler,
   int arg12 = comm->ub_timeout;
   int arg13 = send_peerlocal;
   int arg14 = recv_peerlocal;
-  int *arg15 = reinterpret_cast<int *>(
-      comm->use_ce ? GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 1) : nullptr);
-  int *arg16 = reinterpret_cast<int *>(
-      comm->use_ce ? GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 2) : nullptr);
+  int *arg15 = reinterpret_cast<int *>(0 ?  // temporary disable
+                                           GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 1)
+                                           : nullptr);
+  int *arg16 = reinterpret_cast<int *>(0 ?  // temporary disable
+                                           GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 2)
+                                           : nullptr);
   void *kernelArgs[] = {reinterpret_cast<void *>(&arg1), reinterpret_cast<void *>(&arg2),
                         reinterpret_cast<void *>(&arg3), reinterpret_cast<void *>(&arg4),
                         reinterpret_cast<void *>(&arg5), reinterpret_cast<void *>(&arg6),
@@ -2447,10 +2451,12 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds
         comm->myrank, peer, comm->nvrank, peerlocal,
         &comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler], reinterpret_cast<int *>(flagptr),
         signalonly || comm->sms, comm->ub_timeout,
-        reinterpret_cast<int *>(comm->use_ce ? GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 1)
-                                             : nullptr),
-        reinterpret_cast<int *>(comm->use_ce ? GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 2)
-                                             : nullptr));
+        reinterpret_cast<int *>(0 ?  // temporary disable
+                                    GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 1)
+                                    : nullptr),
+        reinterpret_cast<int *>(0 ?  // temporary disable
+                                    GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 2)
+                                    : nullptr));
   }
 }
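
Note on the instrumentation being disabled (illustrative, not part of the patch): the CE deadlock check bracketed each copy-engine transfer on the sender side with two one-thread kernel launches on the same stream, one bumping a "start" flag before the cudaMemcpyAsync and one bumping an "end" flag after it, presumably so the receiving side could detect a transfer that started but never completed. The standalone CUDA sketch below is a minimal illustration of that bracketing pattern under stated assumptions; flag_inc and the host-side check are hypothetical stand-ins, not kuserbuffers_inc or the Transformer Engine receiver-side check.

// Minimal sketch of bracketing a copy-engine transfer with start/end counters.
// flag_inc is a hypothetical stand-in for kuserbuffers_inc.
#include <cassert>
#include <cstdio>
#include <cuda_runtime.h>

// A single thread bumps a device-visible counter.
__global__ void flag_inc(int *flag) { atomicAdd(flag, 1); }

#define CUDA_CHECK(call)                                              \
  do {                                                                \
    cudaError_t err_ = (call);                                        \
    if (err_ != cudaSuccess) {                                        \
      fprintf(stderr, "CUDA error %s at %s:%d\n",                     \
              cudaGetErrorString(err_), __FILE__, __LINE__);          \
      return 1;                                                       \
    }                                                                 \
  } while (0)

int main() {
  const size_t bytes = 1 << 20;
  void *src = nullptr, *dst = nullptr;
  int *flags = nullptr;  // flags[0] = copies started, flags[1] = copies finished
  cudaStream_t stream;

  CUDA_CHECK(cudaMalloc(&src, bytes));
  CUDA_CHECK(cudaMalloc(&dst, bytes));
  CUDA_CHECK(cudaMallocManaged(&flags, 2 * sizeof(int)));
  flags[0] = flags[1] = 0;
  CUDA_CHECK(cudaStreamCreate(&stream));

  // The bracketed transfer: start counter, CE copy, end counter, all on one stream.
  flag_inc<<<1, 1, 0, stream>>>(&flags[0]);
  CUDA_CHECK(cudaMemcpyAsync(dst, src, bytes, cudaMemcpyDeviceToDevice, stream));
  flag_inc<<<1, 1, 0, stream>>>(&flags[1]);

  CUDA_CHECK(cudaStreamSynchronize(stream));
  // A watchdog would treat flags[0] != flags[1] as a copy that never completed.
  printf("started=%d finished=%d\n", flags[0], flags[1]);
  assert(flags[0] == flags[1]);

  CUDA_CHECK(cudaStreamDestroy(stream));
  CUDA_CHECK(cudaFree(src));
  CUDA_CHECK(cudaFree(dst));
  CUDA_CHECK(cudaFree(flags));
  return 0;
}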