From e5777b99e6223aca272df7d610e4771e1f8132db Mon Sep 17 00:00:00 2001
From: luyang
Date: Thu, 2 Jan 2025 14:05:14 +0000
Subject: [PATCH] refactor _nccl_logical_fusion kernel using ccl::CclComm

---
 .../kernels/nccl_logical_fusion_kernel.cpp    | 149 +++++++++++-------
 1 file changed, 90 insertions(+), 59 deletions(-)

diff --git a/oneflow/user/kernels/nccl_logical_fusion_kernel.cpp b/oneflow/user/kernels/nccl_logical_fusion_kernel.cpp
index aeb906b6387..662426c5ea1 100644
--- a/oneflow/user/kernels/nccl_logical_fusion_kernel.cpp
+++ b/oneflow/user/kernels/nccl_logical_fusion_kernel.cpp
@@ -14,6 +14,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
+#include "collective_communication/include/all_gather.h"
+#include "collective_communication/include/all_reduce.h"
+#include "collective_communication/include/collective_communication.h"
+#include "collective_communication/include/reduce_scatter.h"
 #include "oneflow/core/framework/framework.h"
 #include "oneflow/core/device/nccl_util.h"
 #include "oneflow/core/job/eager_nccl_comm_manager.h"
@@ -21,6 +25,8 @@ limitations under the License.
 #include "oneflow/core/ep/include/primitive/permute.h"
 #include "oneflow/core/ep/cuda/cuda_stream.h"
 #include "oneflow/user/ops/nccl_logical_util.h"
+#include "oneflow/user/kernels/collective_communication/include/send.h"
+#include "oneflow/user/kernels/collective_communication/include/recv.h"
 
 #if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700
 
@@ -83,9 +89,9 @@ class NcclLogicalFusionKernelState : public user_op::OpKernelState {
   }
   ~NcclLogicalFusionKernelState() override = default;
 
-  ncclComm_t comm() {
+  ccl::CclComm ccl_comm() {
     if (!is_init_) { InitComm(); }
-    return comm_;
+    return ccl_comm_;
   }
 
   int64_t num_ranks() {
@@ -169,8 +175,7 @@ class NcclLogicalFusionKernelState : public user_op::OpKernelState {
     }
 
     EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());
-    comm_ =
-        comm_mgr->As<EagerNcclCommMgr>()->GetCommForDeviceAndStreamName(device_set, stream_name_);
+    ccl_comm_ = comm_mgr->GetCclCommForDeviceAndStreamName(device_set, stream_name_);
     is_init_ = true;
   }
 
@@ -277,7 +282,7 @@ class NcclLogicalFusionKernelState : public user_op::OpKernelState {
   std::vector<int64_t> dst_split_axis_list_;
   std::vector<int64_t> tmp_buffer_offset_;
   std::vector<int64_t> tmp_buffer_size_;
-  ncclComm_t comm_{};
+  ccl::CclComm ccl_comm_{};
 };
 
 class NcclLogicalFusionKernel final : public user_op::OpKernel {
@@ -425,109 +430,135 @@ void DoNcclComputeByNcclTypeInGroup(const void* pack_to_ptr, void* unpack_from_ptr,
                                     const std::string& nccl_type, const user_op::Tensor* in,
                                     user_op::Tensor* out, user_op::KernelComputeContext* ctx,
                                     NcclLogicalFusionKernelState* kernel_state, const int32_t i,
-                                    const ncclComm_t& comm) {
+                                    ccl::CclComm ccl_comm) {
+  std::unique_ptr<ccl::Send> ccl_send =
+      ccl::NewCollectiveCommunication<ccl::Send>(ctx->stream()->device_type(), in->data_type());
+  std::unique_ptr<ccl::Recv> ccl_recv =
+      ccl::NewCollectiveCommunication<ccl::Recv>(ctx->stream()->device_type(), in->data_type());
   const int64_t num_ranks = kernel_state->num_ranks();
   VLOG(3) << "[NcclLogicalFusion] op: " << ctx->op_name() << " , i= " << i
           << ", stream: " << kernel_state->stream_name() << " Try launch nccl_type: " << nccl_type;
   if (nccl_type == "_nccl_logical_all_reduce") {
     CHECK(in->dptr() == pack_to_ptr);
     CHECK(out->mut_dptr() == unpack_from_ptr);
-    ncclRedOp_t reduce_type = ncclRedOp_t::ncclSum;
-    if (in->data_type() == DataType::kBool) { reduce_type = ncclRedOp_t::ncclMax; }
-    OF_NCCL_CHECK(ncclAllReduce(in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(),
-                                GetNcclDataType(in->data_type()), reduce_type, comm,
-                                ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
+    ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum;
+    if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; }
+    std::unique_ptr<ccl::AllReduce> ccl_all_reduce =
+        ccl::NewCollectiveCommunication<ccl::AllReduce>(ctx->stream()->device_type(),
+                                                        in->data_type(), ccl_reduce_type);
+    ccl_all_reduce->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(),
+                           ccl_comm);
+
   } else if (nccl_type == "_nccl_logical_reduce_scatter") {
     CHECK(in->dptr() == pack_to_ptr);
     CHECK(out->mut_dptr() == unpack_from_ptr);
     CHECK_EQ(in->shape_view().elem_cnt(), out->shape_view().elem_cnt() * num_ranks);
-    ncclRedOp_t reduce_type = ncclRedOp_t::ncclSum;
-    if (in->data_type() == DataType::kBool) { reduce_type = ncclRedOp_t::ncclMax; }
-    OF_NCCL_CHECK(ncclReduceScatter(in->dptr(), out->mut_dptr(), out->shape_view().elem_cnt(),
-                                    GetNcclDataType(in->data_type()), reduce_type, comm,
-                                    ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
+    ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum;
+    if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; }
+    std::unique_ptr<ccl::ReduceScatter> ccl_reduce_scatter =
+        ccl::NewCollectiveCommunication<ccl::ReduceScatter>(ctx->stream()->device_type(),
+                                                            in->data_type(), ccl_reduce_type);
+    ccl_reduce_scatter->Launch(ctx->stream(), in->dptr(), out->mut_dptr(),
+                               out->shape_view().elem_cnt(), ccl_comm);
   } else if (nccl_type == "_nccl_logical_all_gather") {
     CHECK(in->dptr() == pack_to_ptr);
     CHECK(out->mut_dptr() == unpack_from_ptr);
     CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt());
-    OF_NCCL_CHECK(ncclAllGather(in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(),
-                                GetNcclDataType(in->data_type()), comm,
-                                ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
+
+    std::unique_ptr<ccl::AllGather> ccl_all_gather =
+        ccl::NewCollectiveCommunication<ccl::AllGather>(ctx->stream()->device_type(),
+                                                        in->data_type());
+    ccl_all_gather->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(),
+                           ccl_comm);
  } else if (nccl_type == "_nccl_logical_all_gather_noncontinuous") {
     CHECK(in->dptr() == pack_to_ptr);
     CHECK(out->mut_dptr() != unpack_from_ptr);  // do unpack from ptr -> out
     CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt());
-    OF_NCCL_CHECK(ncclAllGather(in->dptr(), unpack_from_ptr, in->shape_view().elem_cnt(),
-                                GetNcclDataType(in->data_type()), comm,
-                                ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
+    std::unique_ptr<ccl::AllGather> ccl_all_gather =
+        ccl::NewCollectiveCommunication<ccl::AllGather>(ctx->stream()->device_type(),
+                                                        in->data_type());
+    ccl_all_gather->Launch(ctx->stream(), in->dptr(), unpack_from_ptr, in->shape_view().elem_cnt(),
+                           ccl_comm);
   } else if (nccl_type == "_nccl_logical_reduce_scatter_noncontinuous") {
     CHECK(in->dptr() != pack_to_ptr);  // do in -> pack to ptr
     CHECK(out->mut_dptr() == unpack_from_ptr);
-    ncclRedOp_t reduce_type = ncclRedOp_t::ncclSum;
-    if (in->data_type() == DataType::kBool) { reduce_type = ncclRedOp_t::ncclMax; }
-    OF_NCCL_CHECK(ncclReduceScatter(pack_to_ptr, out->mut_dptr(), out->shape_view().elem_cnt(),
-                                    GetNcclDataType(in->data_type()), reduce_type, comm,
-                                    ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
+    ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum;
+    if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; }
+    std::unique_ptr<ccl::ReduceScatter> ccl_reduce_scatter =
+        ccl::NewCollectiveCommunication<ccl::ReduceScatter>(ctx->stream()->device_type(),
+                                                            in->data_type(), ccl_reduce_type);
+    ccl_reduce_scatter->Launch(ctx->stream(), pack_to_ptr, out->mut_dptr(),
+                               out->shape_view().elem_cnt(), ccl_comm);
   } else if (nccl_type == "_nccl_logical_s2s") {
     const int64_t elem_cnt = in->shape_view().elem_cnt();
     const int64_t dtype_size = GetSizeOfDataType(in->data_type());
     const int64_t elem_per_chunk = elem_cnt / num_ranks;
     const int64_t chunk_size = elem_per_chunk * dtype_size;
     for (int64_t j = 0; j < num_ranks; ++j) {
-      OF_NCCL_CHECK(ncclSend(reinterpret_cast<const void*>(
-                                 reinterpret_cast<const char*>(pack_to_ptr) + j * chunk_size),
-                             elem_per_chunk, GetNcclDataType(in->data_type()), j, comm,
-                             ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
-      OF_NCCL_CHECK(ncclRecv(
+      ccl_send->Launch(ctx->stream(),
+                       reinterpret_cast<const void*>(reinterpret_cast<const char*>(pack_to_ptr) +
+                                                     j * chunk_size),
+                       elem_per_chunk, j, ccl_comm);
+      ccl_recv->Launch(
+          ctx->stream(),
           reinterpret_cast<void*>(reinterpret_cast<char*>(unpack_from_ptr) + j * chunk_size),
-          elem_per_chunk, GetNcclDataType(in->data_type()), j, comm,
-          ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
+          elem_per_chunk, j, ccl_comm);
     }
   } else if (nccl_type == "_nccl_logical_2D_same_dim0_all_reduce") {
     CHECK(in->dptr() == pack_to_ptr);
     CHECK(out->mut_dptr() == unpack_from_ptr);
-    ncclRedOp_t reduce_type = ncclRedOp_t::ncclSum;
-    if (in->data_type() == DataType::kBool) { reduce_type = ncclRedOp_t::ncclMax; }
-    OF_NCCL_CHECK(ncclAllReduce(in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(),
-                                GetNcclDataType(in->data_type()), reduce_type, comm,
-                                ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
+    ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum;
+    if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; }
+    std::unique_ptr<ccl::AllReduce> ccl_all_reduce =
+        ccl::NewCollectiveCommunication<ccl::AllReduce>(ctx->stream()->device_type(),
+                                                        in->data_type(), ccl_reduce_type);
+    ccl_all_reduce->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(),
+                           ccl_comm);
   } else if (nccl_type == "_nccl_logical_2D_same_dim0_all_gather") {
     CHECK(in->dptr() == pack_to_ptr);
     CHECK(out->mut_dptr() == unpack_from_ptr);
     CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt());
-    OF_NCCL_CHECK(ncclAllGather(in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(),
-                                GetNcclDataType(in->data_type()), comm,
-                                ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
+    std::unique_ptr<ccl::AllGather> ccl_all_gather =
+        ccl::NewCollectiveCommunication<ccl::AllGather>(ctx->stream()->device_type(),
+                                                        in->data_type());
+    ccl_all_gather->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(),
+                           ccl_comm);
   } else if (nccl_type == "_nccl_logical_2D_same_dim0_all_gather_noncontinuous") {
     CHECK(in->dptr() == pack_to_ptr);
     CHECK(out->mut_dptr() != unpack_from_ptr);  // do unpack from ptr -> out
     CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt());
-    OF_NCCL_CHECK(ncclAllGather(in->dptr(), unpack_from_ptr, in->shape_view().elem_cnt(),
-                                GetNcclDataType(in->data_type()), comm,
-                                ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
+    std::unique_ptr<ccl::AllGather> ccl_all_gather =
+        ccl::NewCollectiveCommunication<ccl::AllGather>(ctx->stream()->device_type(),
+                                                        in->data_type());
+    ccl_all_gather->Launch(ctx->stream(), in->dptr(), unpack_from_ptr, in->shape_view().elem_cnt(),
+                           ccl_comm);
   } else if (nccl_type == "_nccl_logical_2D_same_dim0_all2all") {
     const int64_t elem_cnt = in->shape_view().elem_cnt();
     const int64_t dtype_size = GetSizeOfDataType(in->data_type());
     const int64_t elem_per_chunk = elem_cnt / num_ranks;
     const int64_t chunk_size = elem_per_chunk * dtype_size;
     for (int64_t j = 0; j < num_ranks; ++j) {
-      OF_NCCL_CHECK(ncclSend(reinterpret_cast<const void*>(
-                                 reinterpret_cast<const char*>(pack_to_ptr) + j * chunk_size),
-                             elem_per_chunk, GetNcclDataType(in->data_type()), j, comm,
-                             ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
-      OF_NCCL_CHECK(ncclRecv(
+      ccl_send->Launch(ctx->stream(),
+                       reinterpret_cast<const void*>(reinterpret_cast<const char*>(pack_to_ptr) +
+                                                     j * chunk_size),
+                       elem_per_chunk, j, ccl_comm);
+      ccl_recv->Launch(
+          ctx->stream(),
           reinterpret_cast<void*>(reinterpret_cast<char*>(unpack_from_ptr) + j * chunk_size),
-          elem_per_chunk, GetNcclDataType(in->data_type()), j, comm,
-          ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
+          elem_per_chunk, j, ccl_comm);
     }
   } else if (nccl_type == "_nccl_logical_2D_same_dim1_all_reduce") {
     CHECK(in->dptr() == pack_to_ptr);
     CHECK(out->mut_dptr() == unpack_from_ptr);
-    ncclRedOp_t reduce_type = ncclRedOp_t::ncclSum;
-    if (in->data_type() == DataType::kBool) { reduce_type = ncclRedOp_t::ncclMax; }
-    OF_NCCL_CHECK(ncclAllReduce(in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(),
-                                GetNcclDataType(in->data_type()), reduce_type, comm,
-                                ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
+    ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum;
+    if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; }
+    std::unique_ptr<ccl::AllReduce> ccl_all_reduce =
+        ccl::NewCollectiveCommunication<ccl::AllReduce>(ctx->stream()->device_type(),
+                                                        in->data_type(), ccl_reduce_type);
+    ccl_all_reduce->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(),
+                           ccl_comm);
+
   } else {
     UNIMPLEMENTED();
   }
@@ -663,7 +694,7 @@ void NcclLogicalFusionKernel::Compute(user_op::KernelComputeContext* ctx,
   }
 
   // NOTE(chengcheng): init nccl comm need before ncclGroupStart.
-  ncclComm_t comm = kernel_state->comm();
+  ccl::CclComm ccl_comm = kernel_state->ccl_comm();
 
   // do nccl compute in group
   OF_NCCL_CHECK(ncclGroupStart());
@@ -671,7 +702,7 @@ void NcclLogicalFusionKernel::Compute(user_op::KernelComputeContext* ctx,
     const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", i);
     user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", i);
     DoNcclComputeByNcclTypeInGroup(pack_to_ptr_list.at(i), unpack_from_ptr_list.at(i),
-                                   nccl_type_list.at(i), in, out, ctx, kernel_state, i, comm);
+                                   nccl_type_list.at(i), in, out, ctx, kernel_state, i, ccl_comm);
   }
   OF_NCCL_CHECK(ncclGroupEnd());
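
Notes:

The refactor follows one pattern across all fused collectives: the kernel state now
hands out a device-agnostic ccl::CclComm (no more downcasting the comm manager via
As<EagerNcclCommMgr>()), each branch picks a ccl::ReduceType where one is needed,
constructs the collective for the current device type and dtype through the
ccl::NewCollectiveCommunication<T> factory, and launches it on the kernel's stream.
A minimal sketch of that pattern as a standalone helper (LaunchCclAllReduce is
illustrative, not part of the patch):

  #include "oneflow/user/kernels/collective_communication/include/all_reduce.h"
  #include "oneflow/user/kernels/collective_communication/include/collective_communication.h"

  namespace oneflow {

  // Mirrors the call sequence the patch repeats in each all-reduce branch.
  void LaunchCclAllReduce(user_op::KernelComputeContext* ctx, const user_op::Tensor* in,
                          user_op::Tensor* out, ccl::CclComm ccl_comm) {
    // Bool has no native sum reduction; max over {0, 1} gives logical-or.
    ccl::ReduceType reduce_type = ccl::ReduceType::kSum;
    if (in->data_type() == DataType::kBool) { reduce_type = ccl::ReduceType::kMax; }
    // The factory dispatches on device type, so the same code path can serve
    // CUDA and any other backend that registers a ccl::AllReduce.
    std::unique_ptr<ccl::AllReduce> all_reduce =
        ccl::NewCollectiveCommunication<ccl::AllReduce>(ctx->stream()->device_type(),
                                                        in->data_type(), reduce_type);
    all_reduce->Launch(ctx->stream(), in->dptr(), out->mut_dptr(),
                       in->shape_view().elem_cnt(), ccl_comm);
  }

  }  // namespace oneflow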
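
The two all-to-all branches (_nccl_logical_s2s and _nccl_logical_2D_same_dim0_all2all)
keep their hand-rolled structure: each rank posts one Send and one Recv per peer,
exchanging equal-sized chunks at offset j * chunk_size. A sketch of that loop under
the same assumptions (LaunchCclAll2All is a hypothetical helper; the surrounding
kernel guarantees that elem_cnt divides evenly by num_ranks):

  // Pairwise chunk exchange: peer j receives our j-th input chunk and fills our
  // j-th output chunk, so every rank ends up with one chunk from each peer.
  void LaunchCclAll2All(user_op::KernelComputeContext* ctx, const void* pack_to_ptr,
                        void* unpack_from_ptr, int64_t elem_cnt, int64_t num_ranks,
                        DataType dtype, ccl::CclComm ccl_comm) {
    std::unique_ptr<ccl::Send> send =
        ccl::NewCollectiveCommunication<ccl::Send>(ctx->stream()->device_type(), dtype);
    std::unique_ptr<ccl::Recv> recv =
        ccl::NewCollectiveCommunication<ccl::Recv>(ctx->stream()->device_type(), dtype);
    const int64_t elem_per_chunk = elem_cnt / num_ranks;
    const int64_t chunk_size = elem_per_chunk * GetSizeOfDataType(dtype);
    for (int64_t j = 0; j < num_ranks; ++j) {
      send->Launch(ctx->stream(), reinterpret_cast<const char*>(pack_to_ptr) + j * chunk_size,
                   elem_per_chunk, j, ccl_comm);
      recv->Launch(ctx->stream(), reinterpret_cast<char*>(unpack_from_ptr) + j * chunk_size,
                   elem_per_chunk, j, ccl_comm);
    }
  }

Compute() still brackets all launches with ncclGroupStart()/ncclGroupEnd(), so the
per-peer Send/Recv pairs are issued as one fused NCCL group operation rather than
num_ranks blocking point-to-point calls.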