move memory copy into one_shot_all_reduce (#2770)
Summary:
Pull Request resolved: #2770

Avoid the latency of launching hipMemcpyAsync. Benchmarking showed a 3-4 us reduction, with improvements in end-to-end testing as well.

Reviewed By: sryap, jianyuh

Differential Revision: D58223358

fbshipit-source-id: c5bf36866ab5f89a8ce186bcd728d02638c12070
xw285cornell authored and facebook-github-bot committed Jun 22, 2024
1 parent 2114817 commit 7f77444
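
For context on the change: instead of staging the user input into this rank's communication buffer with a separate hipMemcpyAsync, the all-reduce kernel now performs that copy itself in a grid-stride loop, with each thread moving 16 bytes (8 BFloat16 values) per iteration as two 64-bit words. Below is a minimal, self-contained sketch of just that copy pattern; the kernel name and the use of __nv_bfloat16 in place of at::BFloat16 are illustrative assumptions, not the actual FBGEMM code.

#include <cuda_bf16.h>
#include <cstdint>

// Sketch only: grid-stride copy of bf16 data in 16-byte chunks, fused into a
// kernel so no separate host-side memcpy launch is needed.
__global__ void fused_copy_sketch(
    const __nv_bfloat16* src, // user input (ar_input in the commit)
    __nv_bfloat16* dst,       // this rank's communication buffer
    int32_t n) {              // number of bf16 elements, assumed multiple of 8
  for (size_t i = (size_t)blockDim.x * blockIdx.x * 8 + (size_t)threadIdx.x * 8;
       i < (size_t)n;
       i += (size_t)blockDim.x * gridDim.x * 8) {
    // Each iteration copies 8 bf16 values as two aligned 64-bit words.
    const uint64_t* s = reinterpret_cast<const uint64_t*>(&src[i]);
    uint64_t* d = reinterpret_cast<uint64_t*>(&dst[i]);
    d[0] = s[0];
    d[1] = s[1];
  }
}

Because the copy runs inside the already-launched kernel, one host-side launch disappears from the critical path, which is consistent with the 3-4 us reduction reported above. On ROCm the actual commit additionally uses the __builtin_nontemporal_store intrinsic for the writes into the communication buffer, as shown in the diff below.
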
Showing 1 changed file with 20 additions and 7 deletions.
27 changes: 20 additions & 7 deletions fbgemm_gpu/experimental/gen_ai/src/comm/car.cu
@@ -139,9 +139,28 @@ __global__ void one_shot_all_reduce(
int32_t flag,
std::array<int32_t*, 8> barriers,
std::array<at::BFloat16*, 8> inputs,
at::BFloat16* ar_input,
at::BFloat16* acc,
at::BFloat16* output,
int32_t N) {
// It is expensive to launch hipMemcpyAsync on ROCm
// Move data copy here. Each block copies part of input data
at::BFloat16* input = inputs[rank];
for (size_t i = blockDim.x * blockIdx.x * 8 + threadIdx.x * 8; i < N;
i += (size_t)blockDim.x * gridDim.x * 8) {
#if defined(USE_ROCM)
__builtin_nontemporal_store(
reinterpret_cast<uint64_t*>(&ar_input[i])[0], (uint64_t*)(&input[i]));
__builtin_nontemporal_store(
reinterpret_cast<uint64_t*>(&ar_input[i])[1],
(uint64_t*)(&input[i]) + 1);
#else
*reinterpret_cast<uint64_t*>(&input[i]) =
reinterpret_cast<uint64_t*>(&ar_input[i])[0];
*(reinterpret_cast<uint64_t*>(&input[i]) + 1) =
reinterpret_cast<uint64_t*>(&ar_input[i])[1];
#endif
}
// Synchronize the ranks.
volatile int32_t* barrier_d = barriers[rank];
if (threadIdx.x < kWorldSize) {
@@ -516,13 +535,6 @@ void one_shot_car_allreduce(
barriers[ii] = state->barriers_[ii].data_ptr<int32_t>();
}

AT_CUDA_CHECK(cudaMemcpyAsync(
inputs[state->rank_],
y.data_ptr<at::BFloat16>(),
y.numel() * y.element_size(),
cudaMemcpyDeviceToDevice,
at::cuda::getCurrentCUDAStream()));

constexpr int32_t N_per_thread = 8;
constexpr int32_t N_per_warp = N_per_thread * kThreadsPerWarp;
TORCH_CHECK(N % N_per_warp == 0);
@@ -555,6 +567,7 @@ void one_shot_car_allreduce(
state->flag_ * state->world_size_, \
barriers, \
inputs, \
y.data_ptr<at::BFloat16>(), \
z ? z->data_ptr<at::BFloat16>() : nullptr, \
y_allreduce.data_ptr<at::BFloat16>(), \
N); \