refactor _nccl_logical_fusion kernel using ccl::CclComm
Flowingsun007 committed Jan 2, 2025
1 parent 3cb872a commit e5777b9
Showing 1 changed file with 90 additions and 59 deletions.
149 changes: 90 additions & 59 deletions oneflow/user/kernels/nccl_logical_fusion_kernel.cpp
@@ -14,13 +14,19 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

#include "collective_communication/include/all_gather.h"
#include "collective_communication/include/all_reduce.h"
#include "collective_communication/include/collective_communication.h"
#include "collective_communication/include/reduce_scatter.h"
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/device/nccl_util.h"
#include "oneflow/core/job/eager_nccl_comm_manager.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/ep/include/primitive/permute.h"
#include "oneflow/core/ep/cuda/cuda_stream.h"
#include "oneflow/user/ops/nccl_logical_util.h"
#include "oneflow/user/kernels/collective_communication/include/send.h"
#include "oneflow/user/kernels/collective_communication/include/recv.h"

#if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700

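The headers added above expose OneFlow's device-agnostic collective-communication (ccl) layer used throughout this refactor: each collective is obtained from the ccl::NewCollectiveCommunication<T> factory and launched on an ep::Stream. A minimal sketch of the pattern, mirroring the calls introduced in the hunks below (stream, in_ptr, out_ptr, elem_cnt, dtype, and ccl_comm are placeholder names, not part of this diff):

// Sketch: build an AllReduce for the current device/dtype and launch it.
// Factory and Launch signatures follow the changed lines in this commit.
std::unique_ptr<ccl::AllReduce> all_reduce =
    ccl::NewCollectiveCommunication<ccl::AllReduce>(
        stream->device_type(), dtype, ccl::ReduceType::kSum);
all_reduce->Launch(stream, in_ptr, out_ptr, elem_cnt, ccl_comm);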
@@ -83,9 +89,9 @@ class NcclLogicalFusionKernelState : public user_op::OpKernelState {
}
~NcclLogicalFusionKernelState() override = default;

- ncclComm_t comm() {
+ ccl::CclComm ccl_comm() {
if (!is_init_) { InitComm(); }
- return comm_;
+ return ccl_comm_;
}

int64_t num_ranks() {
@@ -169,8 +175,7 @@ class NcclLogicalFusionKernelState : public user_op::OpKernelState {
}

EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());
- comm_ =
-     comm_mgr->As<EagerNcclCommMgr>()->GetCommForDeviceAndStreamName(device_set, stream_name_);
+ ccl_comm_ = comm_mgr->GetCclCommForDeviceAndStreamName(device_set, stream_name_);
is_init_ = true;
}

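Together with the accessor change in the previous hunk, the kernel state now lazily materializes a ccl::CclComm from the comm manager on first use instead of holding a raw ncclComm_t. A condensed paraphrase of the pattern (the construction of device_set is elided here, as it is in the hunk above):

// Condensed sketch of the lazy-init pattern in NcclLogicalFusionKernelState.
ccl::CclComm ccl_comm() {
  if (!is_init_) { InitComm(); }
  return ccl_comm_;
}
void InitComm() {
  // ... device_set and stream_name_ are derived from the op's placement ...
  EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());
  ccl_comm_ = comm_mgr->GetCclCommForDeviceAndStreamName(device_set, stream_name_);
  is_init_ = true;
}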
@@ -277,7 +282,7 @@ class NcclLogicalFusionKernelState : public user_op::OpKernelState {
std::vector<int64_t> dst_split_axis_list_;
std::vector<size_t> tmp_buffer_offset_;
std::vector<size_t> tmp_buffer_size_;
- ncclComm_t comm_{};
+ ccl::CclComm ccl_comm_{};
};

class NcclLogicalFusionKernel final : public user_op::OpKernel {
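A detail that recurs throughout the large hunk below: reductions on DataType::kBool switch from a sum to a max reduction, since max over {0, 1} implements a logical OR while summing booleans is not meaningful. The two-line selection repeated in each reducing branch is:

// Reduce-type selection as introduced by this commit (repeated per branch).
ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum;
if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; }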
@@ -425,109 +430,135 @@ void DoNcclComputeByNcclTypeInGroup(const void* pack_to_ptr, void* unpack_from_ptr,
const std::string& nccl_type, const user_op::Tensor* in,
user_op::Tensor* out, user_op::KernelComputeContext* ctx,
NcclLogicalFusionKernelState* kernel_state, const int32_t i,
- const ncclComm_t& comm) {
+ ccl::CclComm ccl_comm) {
+ std::unique_ptr<ccl::Send> ccl_send =
+     ccl::NewCollectiveCommunication<ccl::Send>(ctx->stream()->device_type(), in->data_type());
+ std::unique_ptr<ccl::Recv> ccl_recv =
+     ccl::NewCollectiveCommunication<ccl::Recv>(ctx->stream()->device_type(), in->data_type());
+
const int64_t num_ranks = kernel_state->num_ranks();
VLOG(3) << "[NcclLogicalFusion] op: " << ctx->op_name() << " , i= " << i
<< ", stream: " << kernel_state->stream_name() << " Try launch nccl_type: " << nccl_type;
if (nccl_type == "_nccl_logical_all_reduce") {
CHECK(in->dptr() == pack_to_ptr);
CHECK(out->mut_dptr() == unpack_from_ptr);
- ncclRedOp_t reduce_type = ncclRedOp_t::ncclSum;
- if (in->data_type() == DataType::kBool) { reduce_type = ncclRedOp_t::ncclMax; }
- OF_NCCL_CHECK(ncclAllReduce(in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(),
-                             GetNcclDataType(in->data_type()), reduce_type, comm,
-                             ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
+ ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum;
+ if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; }
+ std::unique_ptr<ccl::AllReduce> ccl_all_reduce =
+     ccl::NewCollectiveCommunication<ccl::AllReduce>(ctx->stream()->device_type(),
+                                                     in->data_type(), ccl_reduce_type);
+ ccl_all_reduce->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(),
+                        ccl_comm);
+
} else if (nccl_type == "_nccl_logical_reduce_scatter") {
CHECK(in->dptr() == pack_to_ptr);
CHECK(out->mut_dptr() == unpack_from_ptr);
CHECK_EQ(in->shape_view().elem_cnt(), out->shape_view().elem_cnt() * num_ranks);
- ncclRedOp_t reduce_type = ncclRedOp_t::ncclSum;
- if (in->data_type() == DataType::kBool) { reduce_type = ncclRedOp_t::ncclMax; }
- OF_NCCL_CHECK(ncclReduceScatter(in->dptr(), out->mut_dptr(), out->shape_view().elem_cnt(),
-                                 GetNcclDataType(in->data_type()), reduce_type, comm,
-                                 ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
+ ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum;
+ if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; }
+ std::unique_ptr<ccl::ReduceScatter> ccl_reduce_scatter =
+     ccl::NewCollectiveCommunication<ccl::ReduceScatter>(ctx->stream()->device_type(),
+                                                         in->data_type(), ccl_reduce_type);
+ ccl_reduce_scatter->Launch(ctx->stream(), in->dptr(), out->mut_dptr(),
+                            out->shape_view().elem_cnt(), ccl_comm);
} else if (nccl_type == "_nccl_logical_all_gather") {
CHECK(in->dptr() == pack_to_ptr);
CHECK(out->mut_dptr() == unpack_from_ptr);
CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt());
- OF_NCCL_CHECK(ncclAllGather(in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(),
-                             GetNcclDataType(in->data_type()), comm,
-                             ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
+
+ std::unique_ptr<ccl::AllGather> ccl_all_gather =
+     ccl::NewCollectiveCommunication<ccl::AllGather>(ctx->stream()->device_type(),
+                                                     in->data_type());
+ ccl_all_gather->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(),
+                        ccl_comm);
} else if (nccl_type == "_nccl_logical_all_gather_noncontinuous") {
CHECK(in->dptr() == pack_to_ptr);
CHECK(out->mut_dptr() != unpack_from_ptr); // do unpack from ptr -> out
CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt());
- OF_NCCL_CHECK(ncclAllGather(in->dptr(), unpack_from_ptr, in->shape_view().elem_cnt(),
-                             GetNcclDataType(in->data_type()), comm,
-                             ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
+ std::unique_ptr<ccl::AllGather> ccl_all_gather =
+     ccl::NewCollectiveCommunication<ccl::AllGather>(ctx->stream()->device_type(),
+                                                     in->data_type());
+ ccl_all_gather->Launch(ctx->stream(), in->dptr(), unpack_from_ptr, in->shape_view().elem_cnt(),
+                        ccl_comm);
} else if (nccl_type == "_nccl_logical_reduce_scatter_noncontinuous") {
CHECK(in->dptr() != pack_to_ptr); // do in -> pack to ptr
CHECK(out->mut_dptr() == unpack_from_ptr);
- ncclRedOp_t reduce_type = ncclRedOp_t::ncclSum;
- if (in->data_type() == DataType::kBool) { reduce_type = ncclRedOp_t::ncclMax; }
- OF_NCCL_CHECK(ncclReduceScatter(pack_to_ptr, out->mut_dptr(), out->shape_view().elem_cnt(),
-                                 GetNcclDataType(in->data_type()), reduce_type, comm,
-                                 ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
+ ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum;
+ if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; }
+ std::unique_ptr<ccl::ReduceScatter> ccl_reduce_scatter =
+     ccl::NewCollectiveCommunication<ccl::ReduceScatter>(ctx->stream()->device_type(),
+                                                         in->data_type(), ccl_reduce_type);
+ ccl_reduce_scatter->Launch(ctx->stream(), pack_to_ptr, out->mut_dptr(),
+                            out->shape_view().elem_cnt(), ccl_comm);
} else if (nccl_type == "_nccl_logical_s2s") {
const int64_t elem_cnt = in->shape_view().elem_cnt();
const int64_t dtype_size = GetSizeOfDataType(in->data_type());
const int64_t elem_per_chunk = elem_cnt / num_ranks;
const int64_t chunk_size = elem_per_chunk * dtype_size;
for (int64_t j = 0; j < num_ranks; ++j) {
- OF_NCCL_CHECK(ncclSend(reinterpret_cast<const void*>(
-                            reinterpret_cast<const char*>(pack_to_ptr) + j * chunk_size),
-                        elem_per_chunk, GetNcclDataType(in->data_type()), j, comm,
-                        ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
- OF_NCCL_CHECK(ncclRecv(
+ ccl_send->Launch(ctx->stream(),
+                  reinterpret_cast<const void*>(reinterpret_cast<const char*>(pack_to_ptr)
+                                                + j * chunk_size),
+                  elem_per_chunk, j, ccl_comm);
+ ccl_recv->Launch(
+     ctx->stream(),
reinterpret_cast<void*>(reinterpret_cast<char*>(unpack_from_ptr) + j * chunk_size),
-     elem_per_chunk, GetNcclDataType(in->data_type()), j, comm,
-     ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
+     elem_per_chunk, j, ccl_comm);
}
} else if (nccl_type == "_nccl_logical_2D_same_dim0_all_reduce") {
CHECK(in->dptr() == pack_to_ptr);
CHECK(out->mut_dptr() == unpack_from_ptr);
- ncclRedOp_t reduce_type = ncclRedOp_t::ncclSum;
- if (in->data_type() == DataType::kBool) { reduce_type = ncclRedOp_t::ncclMax; }
- OF_NCCL_CHECK(ncclAllReduce(in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(),
-                             GetNcclDataType(in->data_type()), reduce_type, comm,
-                             ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
+ ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum;
+ if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; }
+ std::unique_ptr<ccl::AllReduce> ccl_all_reduce =
+     ccl::NewCollectiveCommunication<ccl::AllReduce>(ctx->stream()->device_type(),
+                                                     in->data_type(), ccl_reduce_type);
+ ccl_all_reduce->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(),
+                        ccl_comm);
} else if (nccl_type == "_nccl_logical_2D_same_dim0_all_gather") {
CHECK(in->dptr() == pack_to_ptr);
CHECK(out->mut_dptr() == unpack_from_ptr);
CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt());
- OF_NCCL_CHECK(ncclAllGather(in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(),
-                             GetNcclDataType(in->data_type()), comm,
-                             ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
+ std::unique_ptr<ccl::AllGather> ccl_all_gather =
+     ccl::NewCollectiveCommunication<ccl::AllGather>(ctx->stream()->device_type(),
+                                                     in->data_type());
+ ccl_all_gather->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(),
+                        ccl_comm);
} else if (nccl_type == "_nccl_logical_2D_same_dim0_all_gather_noncontinuous") {
CHECK(in->dptr() == pack_to_ptr);
CHECK(out->mut_dptr() != unpack_from_ptr); // do unpack from ptr -> out
CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt());
- OF_NCCL_CHECK(ncclAllGather(in->dptr(), unpack_from_ptr, in->shape_view().elem_cnt(),
-                             GetNcclDataType(in->data_type()), comm,
-                             ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
+ std::unique_ptr<ccl::AllGather> ccl_all_gather =
+     ccl::NewCollectiveCommunication<ccl::AllGather>(ctx->stream()->device_type(),
+                                                     in->data_type());
+ ccl_all_gather->Launch(ctx->stream(), in->dptr(), unpack_from_ptr, in->shape_view().elem_cnt(),
+                        ccl_comm);
} else if (nccl_type == "_nccl_logical_2D_same_dim0_all2all") {
const int64_t elem_cnt = in->shape_view().elem_cnt();
const int64_t dtype_size = GetSizeOfDataType(in->data_type());
const int64_t elem_per_chunk = elem_cnt / num_ranks;
const int64_t chunk_size = elem_per_chunk * dtype_size;
for (int64_t j = 0; j < num_ranks; ++j) {
- OF_NCCL_CHECK(ncclSend(reinterpret_cast<const void*>(
-                            reinterpret_cast<const char*>(pack_to_ptr) + j * chunk_size),
-                        elem_per_chunk, GetNcclDataType(in->data_type()), j, comm,
-                        ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
- OF_NCCL_CHECK(ncclRecv(
+ ccl_send->Launch(ctx->stream(),
+                  reinterpret_cast<const void*>(reinterpret_cast<const char*>(pack_to_ptr)
+                                                + j * chunk_size),
+                  elem_per_chunk, j, ccl_comm);
+ ccl_recv->Launch(
+     ctx->stream(),
reinterpret_cast<void*>(reinterpret_cast<char*>(unpack_from_ptr) + j * chunk_size),
-     elem_per_chunk, GetNcclDataType(in->data_type()), j, comm,
-     ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
+     elem_per_chunk, j, ccl_comm);
}
} else if (nccl_type == "_nccl_logical_2D_same_dim1_all_reduce") {
CHECK(in->dptr() == pack_to_ptr);
CHECK(out->mut_dptr() == unpack_from_ptr);
- ncclRedOp_t reduce_type = ncclRedOp_t::ncclSum;
- if (in->data_type() == DataType::kBool) { reduce_type = ncclRedOp_t::ncclMax; }
- OF_NCCL_CHECK(ncclAllReduce(in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(),
-                             GetNcclDataType(in->data_type()), reduce_type, comm,
-                             ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
+ ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum;
+ if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; }
+ std::unique_ptr<ccl::AllReduce> ccl_all_reduce =
+     ccl::NewCollectiveCommunication<ccl::AllReduce>(ctx->stream()->device_type(),
+                                                     in->data_type(), ccl_reduce_type);
+ ccl_all_reduce->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(),
+                        ccl_comm);
+
} else {
UNIMPLEMENTED();
}
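The two all-to-all branches (_nccl_logical_s2s and _nccl_logical_2D_same_dim0_all2all) keep their original structure: the packed buffer is split into num_ranks equal chunks, and chunk j is exchanged with rank j via paired point-to-point calls, now routed through ccl::Send/ccl::Recv. A distilled sketch of that loop (the helper name LaunchChunkedAll2All is hypothetical, not part of this diff):

// Hypothetical helper condensing the chunked all-to-all loop used above.
// Each rank sends chunk j to rank j and receives its own chunk from rank j.
void LaunchChunkedAll2All(ep::Stream* stream, ccl::Send* send, ccl::Recv* recv,
                          const void* pack_to_ptr, void* unpack_from_ptr,
                          int64_t elem_cnt, int64_t dtype_size, int64_t num_ranks,
                          ccl::CclComm ccl_comm) {
  const int64_t elem_per_chunk = elem_cnt / num_ranks;     // elements per peer
  const int64_t chunk_size = elem_per_chunk * dtype_size;  // bytes per peer
  for (int64_t j = 0; j < num_ranks; ++j) {
    send->Launch(stream, reinterpret_cast<const char*>(pack_to_ptr) + j * chunk_size,
                 elem_per_chunk, j, ccl_comm);
    recv->Launch(stream, reinterpret_cast<char*>(unpack_from_ptr) + j * chunk_size,
                 elem_per_chunk, j, ccl_comm);
  }
}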
@@ -663,15 +694,15 @@ void NcclLogicalFusionKernel::Compute(user_op::KernelComputeContext* ctx,
}

// NOTE(chengcheng): init nccl comm need before ncclGroupStart.
- ncclComm_t comm = kernel_state->comm();
+ ccl::CclComm ccl_comm = kernel_state->ccl_comm();

// do nccl compute in group
OF_NCCL_CHECK(ncclGroupStart());
for (int32_t i = 0; i < nccl_num; ++i) {
const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", i);
user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", i);
DoNcclComputeByNcclTypeInGroup(pack_to_ptr_list.at(i), unpack_from_ptr_list.at(i),
-                                nccl_type_list.at(i), in, out, ctx, kernel_state, i, comm);
+                                nccl_type_list.at(i), in, out, ctx, kernel_state, i, ccl_comm);
}
OF_NCCL_CHECK(ncclGroupEnd());

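Note that the per-op launches remain bracketed by ncclGroupStart()/ncclGroupEnd(), so on the NCCL backend each Launch inside the loop is only enqueued and the fused collectives fire together at group end; the comm is fetched before the group because lazy comm creation must not happen inside it (per the NOTE above). A condensed paraphrase of the calling pattern:

// Condensed paraphrase: fetch the ccl comm first, then launch all fused
// collectives inside one NCCL group so they execute as a single fused step.
ccl::CclComm ccl_comm = kernel_state->ccl_comm();
OF_NCCL_CHECK(ncclGroupStart());
for (int32_t i = 0; i < nccl_num; ++i) {
  DoNcclComputeByNcclTypeInGroup(pack_to_ptr_list.at(i), unpack_from_ptr_list.at(i),
                                 nccl_type_list.at(i), in, out, ctx, kernel_state, i, ccl_comm);
}
OF_NCCL_CHECK(ncclGroupEnd());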
