From 882cf683267486cb2c0c9f7576f60cccd19729b0 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Sat, 21 Dec 2024 12:16:13 -0800 Subject: [PATCH] [Cherry-pick] NVLS support for NCCL API (#410) (#425) Co-authored-by: Qinghua Zhou Co-authored-by: Changho Hwang --- apps/nccl/include/nccl.h | 9 +++++ apps/nccl/src/nccl.cu | 62 +++++++++++++++++++++++++++++++++- src/executor/execution_plan.cc | 14 ++++++-- src/executor/executor.cc | 15 ++++---- src/include/execution_plan.hpp | 2 +- 5 files changed, 91 insertions(+), 11 deletions(-) diff --git a/apps/nccl/include/nccl.h b/apps/nccl/include/nccl.h index 7f507927b..bfdb22697 100644 --- a/apps/nccl/include/nccl.h +++ b/apps/nccl/include/nccl.h @@ -69,6 +69,15 @@ typedef struct ncclConfig_v21700 { NCCL_CONFIG_UNDEF_INT /* splitShare */ \ } +/* NCCL malloc and free function for all types of NCCL optimizations + * (e.g. user buffer registration). The actual allocated size might + * be larger than requested due to granularity requirement. */ +ncclResult_t ncclMemAlloc(void** ptr, size_t size); +ncclResult_t pncclMemAlloc(void** ptr, size_t size); + +ncclResult_t ncclMemFree(void* ptr); +ncclResult_t pncclMemFree(void* ptr); + /* Return the NCCL_VERSION_CODE of the NCCL library in the supplied integer. * This integer is coded with the MAJOR, MINOR and PATCH level of the * NCCL library diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu index cd75edfea..fe240de77 100644 --- a/apps/nccl/src/nccl.cu +++ b/apps/nccl/src/nccl.cu @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -33,6 +34,9 @@ // mscclpp::Transport::IB3, mscclpp::Transport::IB4, mscclpp::Transport::IB5, // mscclpp::Transport::IB6, mscclpp::Transport::IB7}; +// Declare the global map to store associations between raw pointer and shared pointer +static std::unordered_map> ptrMap; + struct channelKey { const void* buff; size_t bytes; @@ -113,7 +117,7 @@ static size_t ncclTypeSize(ncclDataType_t type) { return 0; } -double parseSize(const char* value) { +static double parseSize(const char* value) { std::string valueStr(value); std::istringstream iss(valueStr); long long int units; @@ -644,3 +648,59 @@ NCCL_API ncclResult_t ncclGroupEnd() { // Do nothing return ncclSuccess; } + +NCCL_API ncclResult_t ncclCommRegister(const ncclComm_t, void*, size_t, void**) { + // TODO: Implementation + return ncclSuccess; +} + +NCCL_API ncclResult_t ncclCommDeregister(const ncclComm_t, void*) { + // TODO: Implementation + return ncclSuccess; +} + +ncclResult_t ncclMemAlloc(void** ptr, size_t size) { + // Allocate memory using mscclpp::allocSharedPhysicalCuda + if (ptr == nullptr || size == 0) { + return ncclInvalidArgument; + } + std::shared_ptr sharedPtr; + try { + if (mscclpp::isNvlsSupported()) { + sharedPtr = mscclpp::allocSharedPhysicalCuda(size); + } else { + sharedPtr = mscclpp::allocExtSharedCuda(size); + } + if (sharedPtr == nullptr) { + return ncclSystemError; + } + } catch (const mscclpp::Error& e) { + if (e.getErrorCode() == mscclpp::ErrorCode::InvalidUsage) { + return ncclInvalidUsage; + } else { + return ncclInternalError; + } + } catch (const mscclpp::CudaError& e) { + return ncclUnhandledCudaError; + } catch (const mscclpp::CuError& e) { + return ncclUnhandledCudaError; + } catch (const mscclpp::BaseError& e) { + return ncclInternalError; + } + ptrMap[sharedPtr.get()] = sharedPtr; + + // Return the pointer + *ptr = sharedPtr.get(); + return ncclSuccess; +} + +ncclResult_t ncclMemFree(void* ptr) { + auto ptrIt = ptrMap.find(ptr); + if (ptrIt != ptrMap.end()) { + ptrMap.erase(ptrIt); + return ncclSuccess; + } + + // Pointer not found + return ncclInvalidUsage; +} diff --git a/src/executor/execution_plan.cc b/src/executor/execution_plan.cc index 162c3ef1c..37b7cfda8 100644 --- a/src/executor/execution_plan.cc +++ b/src/executor/execution_plan.cc @@ -141,7 +141,17 @@ std::vector ExecutionPlan::Impl::getUnpairedChannelInfos(int rank, return unpaired; } -std::vector ExecutionPlan::Impl::getNvlsInfos(int rank) const { return this->nvlsInfos.at(rank); } +std::vector ExecutionPlan::Impl::getNvlsInfos(int rank, size_t sendBuffserSize, size_t recvBufferSize) const { + if (sendBuffserSize == 0 && recvBufferSize == 0) { + return this->nvlsInfos.at(rank); + } + size_t chunkSize = this->getUpperBoundChunkSize(rank, sendBuffserSize, recvBufferSize); + std::vector infos = this->nvlsInfos.at(rank); + for (auto& info : infos) { + info.bufferSize = info.bufferSize * chunkSize; + } + return infos; +} std::vector ExecutionPlan::Impl::getConnectedPeers(int rank) const { std::set peers; @@ -272,7 +282,7 @@ void ExecutionPlan::Impl::parseChannels( NvlsInfo info; info.bufferType = convertToBufferType(channel["buff"]); for (const auto& group : channel["rankGroups"]) { - info.bufferSize = (int)group["size"] * this->getUpperBoundChunkSize(rank, this->inputSize, this->outputSize); + info.bufferSize = (int)group["size"]; info.ranks.clear(); for (int rank : group["ranks"]) { info.ranks.push_back(rank); diff --git a/src/executor/executor.cc b/src/executor/executor.cc index b8e8c6af3..d2e5ac7e2 100644 --- a/src/executor/executor.cc +++ b/src/executor/executor.cc @@ -180,10 +180,10 @@ struct Executor::Impl { context.scratchBufferSize = scratchBufferSize; context.proxyService = std::make_shared(); context.nthreadsPerBlock = plan.impl_->getNThreadsPerBlock(); - this->setupConnections(context, rank, plan); + this->setupConnections(context, rank, plan, sendMemRange, recvMemRange); this->setupRegisteredMemories(context, sendbuff, recvbuff, sendMemRange, recvMemRange, rank, plan); this->setupChannels(context, sendbuff, recvbuff, sendMemRange, recvMemRange, rank, plan); - this->setupNvlsChannels(context, sendbuff, recvbuff, rank, plan); + this->setupNvlsChannels(context, sendbuff, recvbuff, sendMemRange, recvMemRange, rank, plan); this->setupDeviceExecutionPlan(context, devicePlanKey, rank, plan); context.deviceExecutionPlansBuffers[devicePlanKey] = allocExtSharedCuda(context.deviceExecutionPlans[devicePlanKey].size() * sizeof(DeviceExecutionPlan)); @@ -214,7 +214,8 @@ struct Executor::Impl { return flags; }; - void setupConnections(ExecutionContext& context, int rank, const ExecutionPlan& plan) { + void setupConnections(ExecutionContext& context, int rank, const ExecutionPlan& plan, size_t sendBufferSize, + size_t recvBufferSize) { std::vector connectedPeers = plan.impl_->getConnectedPeers(rank); std::vector>> connectionFutures; for (int peer : connectedPeers) { @@ -227,7 +228,7 @@ struct Executor::Impl { context.connections[connectedPeers[i]] = connectionFutures[i].get(); } - std::vector nvlsInfos = plan.impl_->getNvlsInfos(rank); + std::vector nvlsInfos = plan.impl_->getNvlsInfos(rank, sendBufferSize, recvBufferSize); for (const NvlsInfo& info : nvlsInfos) { std::shared_ptr nvlsConnection = mscclpp::connectNvlsCollective(this->comm, info.ranks, info.bufferSize); @@ -351,9 +352,9 @@ struct Executor::Impl { } } - void setupNvlsChannels(ExecutionContext& context, void* sendbuff, void* recvbuff, int rank, - const ExecutionPlan& plan) { - std::vector nvlsInfos = plan.impl_->getNvlsInfos(rank); + void setupNvlsChannels(ExecutionContext& context, void* sendbuff, void* recvbuff, size_t sendBufferSize, + size_t recvBufferSize, int rank, const ExecutionPlan& plan) { + std::vector nvlsInfos = plan.impl_->getNvlsInfos(rank, sendBufferSize, recvBufferSize); for (size_t i = 0; i < nvlsInfos.size(); i++) { std::shared_ptr nvlsConnection = context.nvlsConnections[i]; NvlsInfo info = nvlsInfos[i]; diff --git a/src/include/execution_plan.hpp b/src/include/execution_plan.hpp index 3af585508..95c3aadd1 100644 --- a/src/include/execution_plan.hpp +++ b/src/include/execution_plan.hpp @@ -69,7 +69,7 @@ struct ExecutionPlan::Impl { std::vector getChannelInfos(int rank, BufferType bufferType) const; std::vector getChannelInfosByDstRank(int rank, BufferType bufferType) const; std::vector getUnpairedChannelInfos(int rank, int worldSize, ChannelType channelType); - std::vector getNvlsInfos(int rank) const; + std::vector getNvlsInfos(int rank, size_t sendBuffserSize = 0, size_t recvBufferSize = 0) const; std::vector getConnectedPeers(int rank) const; std::vector getConnectedBufferTypes(int rank) const; size_t getScratchBufferSize(int rank, size_t inputSize, size_t outputSize) const;