refactor ccl::AllGather AllReduce ReduceScatter primitive using ccl::CclComm

Flowingsun007 committed Jan 2, 2025
1 parent aac19b4 · commit 3cb872a
Showing 13 changed files with 60 additions and 0 deletions.
5 changes: 5 additions & 0 deletions oneflow/core/job/eager_ccl_comm_manager.h
@@ -30,6 +30,11 @@ class EagerCclCommMgr {
virtual void CreateCommFromPlan(const Plan& plan) = 0;
virtual bool IsAsyncLaunchCclLogicalKernel() const = 0;
virtual void SetAsyncLaunchCclLogicalKernel(bool val) = 0;
virtual ccl::CclComm GetCclCommForDevice(
const std::set<std::pair<int64_t, int64_t>>& device_set) {
ccl::CclComm ccl_comm{};
return ccl_comm;
}
virtual ccl::CclComm GetCclCommForDeviceAndStreamName(
const std::set<std::pair<int64_t, int64_t>>& device_set, const std::string& stream_name) {
ccl::CclComm ccl_comm{};
8 changes: 8 additions & 0 deletions oneflow/core/job/eager_nccl_comm_manager.cpp
@@ -156,6 +156,14 @@ ncclComm_t EagerNcclCommMgr::GetCommForDeviceAndStreamName(
return comm;
}

ccl::CclComm EagerNcclCommMgr::GetCclCommForDevice(
const std::set<std::pair<int64_t, int64_t>>& device_set) {
ncclComm_t comm = GetCommForDevice(device_set);
std::shared_ptr<ccl::CommBase> ncclCommAdapter = std::make_shared<ccl::NcclCommAdapter>(comm);
ccl::CclComm ccl_comm(ncclCommAdapter);
return ccl_comm;
}

ccl::CclComm EagerNcclCommMgr::GetCclCommForDeviceAndStreamName(
const std::set<std::pair<int64_t, int64_t>>& device_set, const std::string& stream_name) {
ncclComm_t comm = GetCommForDeviceAndStreamName(device_set, stream_name);
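For orientation, here is a minimal caller-side sketch of the two new accessors added above. It assumes the comm manager is reachable through OneFlow's Singleton registry and that the caller has already assembled device_set from (machine_id, device_id) pairs; the helper PrepareCclComm, its parameters, and the singleton.h include path are illustrative assumptions, not part of this commit.

#include <set>
#include <string>
#include <utility>

#include "oneflow/core/common/singleton.h"            // path assumed
#include "oneflow/core/job/eager_ccl_comm_manager.h"

namespace oneflow {

// Hypothetical helper: fetch a type-erased ccl::CclComm for the given ranks,
// optionally scoped to a named stream, instead of holding a raw ncclComm_t.
ccl::CclComm PrepareCclComm(const std::set<std::pair<int64_t, int64_t>>& device_set,
                            const std::string& stream_name) {
  auto* comm_mgr = Singleton<EagerCclCommMgr>::Get();  // assumed lookup path
  if (stream_name.empty()) { return comm_mgr->GetCclCommForDevice(device_set); }
  return comm_mgr->GetCclCommForDeviceAndStreamName(device_set, stream_name);
}

}  // namespace oneflow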
1 change: 1 addition & 0 deletions oneflow/core/job/eager_nccl_comm_manager.h
@@ -49,6 +49,7 @@ class EagerNcclCommMgr final : public EagerCclCommMgr {
ncclComm_t GetCommForDevice(const std::set<std::pair<int64_t, int64_t>>& device_set);
ncclComm_t GetCommForDeviceAndStreamName(const std::set<std::pair<int64_t, int64_t>>& device_set,
const std::string& stream_name);
ccl::CclComm GetCclCommForDevice(const std::set<std::pair<int64_t, int64_t>>& device_set);
ccl::CclComm GetCclCommForDeviceAndStreamName(
const std::set<std::pair<int64_t, int64_t>>& device_set, const std::string& stream_name);

@@ -95,6 +95,11 @@ class CpuAllGather final : public AllGather {
CHECK_JUST(AllGatherImpl(in, out, elem_cnt, datatype_, cpu_communication_ctx->parallel_desc()));
}

void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,
ccl::CclComm ccl_comm) const override {
UNIMPLEMENTED();
}

private:
DataType datatype_;
};
@@ -148,6 +148,11 @@ class CpuAllReduce final : public AllReduce {
cpu_communication_ctx->parallel_desc()));
}

void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,
ccl::CclComm ccl_comm) const override {
UNIMPLEMENTED();
}

private:
DataType datatype_;
ReduceType reduce_type_;
@@ -120,6 +120,11 @@ class CpuReduceScatter final : public ReduceScatter {
cpu_communication_ctx->parallel_desc()));
}

void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,
ccl::CclComm ccl_comm) const override {
UNIMPLEMENTED();
}

private:
DataType datatype_;
ReduceType reduce_type_;
@@ -40,6 +40,13 @@ class CudaAllGather final : public AllGather {
stream->As<ep::CudaStream>()->cuda_stream()));
}

virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,
ccl::CclComm ccl_comm) const override {
ncclComm_t* nccl_comm = reinterpret_cast<ncclComm_t*>(ccl_comm.getComm());
OF_NCCL_CHECK(ncclAllGather(in, out, elem_cnt, nccl_datatype_, *nccl_comm,
stream->As<ep::CudaStream>()->cuda_stream()));
}

private:
ncclDataType_t nccl_datatype_;
};
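A usage note on the NCCL-backed overload above: elem_cnt is forwarded as ncclAllGather's per-rank send count, so the output buffer must hold elem_cnt elements from every rank. The sketch below drives the primitive through the collective-communication factory; the NewCollectiveCommunication helper and the header path are assumed from the surrounding framework, and the _ph-suffixed names are placeholders.

#include <memory>

#include "oneflow/user/kernels/collective_communication/include/all_gather.h"  // path assumed

namespace oneflow {

// Hedged sketch: build the CUDA AllGather through the registry and launch it
// with a ccl::CclComm instead of a CommunicationContext.
void LaunchAllGatherSketch(ep::Stream* stream, const void* in_ph, void* out_ph,
                           size_t elem_cnt, DataType data_type, ccl::CclComm ccl_comm) {
  std::unique_ptr<ccl::AllGather> all_gather =
      ccl::NewCollectiveCommunication<ccl::AllGather>(DeviceType::kCUDA, data_type);
  // out_ph must be sized for elem_cnt elements per participating rank.
  all_gather->Launch(stream, in_ph, out_ph, elem_cnt, ccl_comm);
}

}  // namespace oneflow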
@@ -57,6 +57,13 @@ class CudaAllReduce final : public AllReduce {
stream->As<ep::CudaStream>()->cuda_stream()));
}

void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,
ccl::CclComm ccl_comm) const override {
ncclComm_t* nccl_comm = reinterpret_cast<ncclComm_t*>(ccl_comm.getComm());
OF_NCCL_CHECK(ncclAllReduce(in, out, elem_cnt, nccl_datatype_, nccl_reduce_op_, *nccl_comm,
stream->As<ep::CudaStream>()->cuda_stream()));
}

private:
ncclDataType_t nccl_datatype_;
ncclRedOp_t nccl_reduce_op_;
@@ -57,6 +57,13 @@ class CudaReduceScatter final : public ReduceScatter {
stream->As<ep::CudaStream>()->cuda_stream()));
}

virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,
ccl::CclComm ccl_comm) const override {
ncclComm_t* nccl_comm = reinterpret_cast<ncclComm_t*>(ccl_comm.getComm());
OF_NCCL_CHECK(ncclReduceScatter(in, out, elem_cnt, nccl_datatype_, nccl_reduce_op_, *nccl_comm,
stream->As<ep::CudaStream>()->cuda_stream()));
}

private:
ncclDataType_t nccl_datatype_;
ncclRedOp_t nccl_reduce_op_;
@@ -32,6 +32,9 @@ class AllGather : public CollectiveCommunication {

virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,
const std::shared_ptr<CommunicationContext>& communicator) const = 0;

virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,
ccl::CclComm ccl_comm) const = 0;
};

inline bool IsAllGatherRegistered(DeviceType device_type) {
@@ -32,6 +32,9 @@ class AllReduce : public CollectiveCommunication {

virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,
const std::shared_ptr<CommunicationContext>& communicator) const = 0;

virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,
ccl::CclComm ccl_comm) const = 0;
};

inline bool IsAllReduceRegistered(DeviceType device_type) {
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_COMMUNICATION_CONTEXT_H_
#define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_COMMUNICATION_CONTEXT_H_

#include "collective_communication.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/common/auto_registration_factory.h"

@@ -32,6 +32,9 @@ class ReduceScatter : public CollectiveCommunication {

virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,
const std::shared_ptr<CommunicationContext>& communicator) const = 0;

virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,
ccl::CclComm ccl_comm) const = 0;
};

inline bool IsReduceScatterRegistered(DeviceType device_type) {
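Taken together, each collective interface now declares a second pure-virtual Launch that takes a ccl::CclComm directly; the CUDA implementations back it with NCCL, while the CPU implementations stub it with UNIMPLEMENTED() in this commit. Below is an end-to-end sketch of how a kernel might combine the comm manager and the factory after this refactor; the Singleton lookup, the NewCollectiveCommunication helper, the ccl::kSum enumerator, the include paths, and all placeholder names are assumptions rather than part of the diff.

#include <memory>
#include <set>
#include <utility>

#include "oneflow/core/common/singleton.h"                                      // path assumed
#include "oneflow/core/job/eager_ccl_comm_manager.h"
#include "oneflow/user/kernels/collective_communication/include/all_reduce.h"   // path assumed

namespace oneflow {

// Hedged sketch of a kernel-side AllReduce over the new ccl::CclComm path:
// obtain the communicator from the manager, build the device-specific
// primitive through the registry, then launch without a CommunicationContext.
void AllReduceSketch(ep::Stream* stream, const void* in_ph, void* out_ph, size_t elem_cnt,
                     DataType data_type,
                     const std::set<std::pair<int64_t, int64_t>>& device_set) {
  auto* comm_mgr = Singleton<EagerCclCommMgr>::Get();  // assumed lookup path
  ccl::CclComm ccl_comm = comm_mgr->GetCclCommForDevice(device_set);
  std::unique_ptr<ccl::AllReduce> all_reduce =
      ccl::NewCollectiveCommunication<ccl::AllReduce>(DeviceType::kCUDA, data_type,
                                                      ccl::kSum);  // reduce-op name assumed
  all_reduce->Launch(stream, in_ph, out_ph, elem_cnt, ccl_comm);
}

}  // namespace oneflow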
