diff --git a/CMakeLists.txt b/CMakeLists.txt index a0ba1fd99..9874c48c1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,6 +38,12 @@ list(APPEND CMAKE_MODULE_PATH ${TORCH_XPU_OPS_ROOT}/cmake/Modules) include(${TORCH_XPU_OPS_ROOT}/cmake/SYCL.cmake) include(${TORCH_XPU_OPS_ROOT}/cmake/BuildFlags.cmake) +option(USE_XCCL "Build with XCCL support" ON) + +if(NOT WIN32 AND USE_XCCL) + include(${TORCH_XPU_OPS_ROOT}/cmake/XCCL.cmake) +endif() + if(BUILD_TEST) add_subdirectory(${TORCH_XPU_OPS_ROOT}/test/sycl ${CMAKE_BINARY_DIR}/test_sycl) endif() diff --git a/cmake/Modules/FindXCCL.cmake b/cmake/Modules/FindXCCL.cmake new file mode 100644 index 000000000..b211af9a9 --- /dev/null +++ b/cmake/Modules/FindXCCL.cmake @@ -0,0 +1,62 @@ +# This will define the following variables: +# XCCL_FOUND : True if the system has the XCCL library. +# XCCL_INCLUDE_DIR : Include directories needed to use XCCL. +# XCCL_LIBRARY_DIR :The path to the XCCL library. +# XCCL_LIBRARY : XCCL library fullname. + +include(${CMAKE_ROOT}/Modules/FindPackageHandleStandardArgs.cmake) + +# we need source OneCCL environment before building. +set(XCCL_ROOT $ENV{CCL_ROOT}) + +# Find include path from binary. +find_file( + XCCL_INCLUDE_DIR + NAMES include + HINTS ${XCCL_ROOT} + NO_DEFAULT_PATH +) + +# Find include/oneapi path from include path. +find_file( + XCCL_INCLUDE_ONEAPI_DIR + NAMES oneapi + HINTS ${XCCL_ROOT}/include/ + NO_DEFAULT_PATH +) + +list(APPEND XCCL_INCLUDE_DIR ${XCCL_INCLUDE_ONEAPI_DIR}) + +# Find library directory from binary. +find_file( + XCCL_LIBRARY_DIR + NAMES lib + HINTS ${XCCL_ROOT} + NO_DEFAULT_PATH +) + +# Find XCCL library fullname. +find_library( + XCCL_LIBRARY + NAMES ccl + HINTS ${XCCL_LIBRARY_DIR} + NO_DEFAULT_PATH +) + +if((NOT XCCL_INCLUDE_DIR) OR (NOT XCCL_LIBRARY_DIR) OR (NOT XCCL_LIBRARY)) + set(XCCL_FOUND False) + set(XCCL_NOT_FOUND_MESSAGE "OneCCL library not found!!") + return() +endif() + +SET(CMAKE_INCLUDE_PATH ${CMAKE_INCLUDE_PATH} + "${XCCL_INCLUDE_DIR}") +SET(CMAKE_LIBRARY_PATH ${CMAKE_LIBRARY_PATH} + "${XCCL_LIBRARY_DIR}") + +find_package_handle_standard_args( + XCCL + FOUND_VAR XCCL_FOUND + REQUIRED_VARS XCCL_INCLUDE_DIR XCCL_LIBRARY_DIR XCCL_LIBRARY + REASON_FAILURE_MESSAGE "${XCCL_NOT_FOUND_MESSAGE}" +) diff --git a/cmake/XCCL.cmake b/cmake/XCCL.cmake new file mode 100644 index 000000000..ffe040291 --- /dev/null +++ b/cmake/XCCL.cmake @@ -0,0 +1,21 @@ +if(NOT __XCCL_INCLUDED) + set(__XCCL_INCLUDED TRUE) + + # XCCL_ROOT, XCCL_LIBRARY_DIR, XCCL_INCLUDE_DIR are handled by FindXCCL.cmake. 
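+  # NOTE (illustrative): FindXCCL.cmake reads CCL_ROOT from the environment,
+  # so the oneCCL environment must be sourced before configuring, e.g. on a
+  # typical oneAPI install:
+  #   source /opt/intel/oneapi/ccl/latest/env/vars.sh
+  #   cmake -DUSE_XCCL=ON ...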
+ find_package(XCCL REQUIRED) + if(NOT XCCL_FOUND) + message("${XCCL_NOT_FOUND_MESSAGE}") + return() + endif() + if(XCCL_FOUND) + add_library(torch::xccl INTERFACE IMPORTED) + set_property( + TARGET torch::xccl PROPERTY INTERFACE_INCLUDE_DIRECTORIES + ${XCCL_INCLUDE_DIR}) + set_property( + TARGET torch::xccl PROPERTY INTERFACE_LINK_LIBRARIES + ${XCCL_LIBRARY}) + set(USE_C10D_XCCL ON) + set(USE_C10D_XCCL ${USE_C10D_XCCL} PARENT_SCOPE) + endif() +endif() diff --git a/src/BuildOnLinux.cmake b/src/BuildOnLinux.cmake index 1590919c0..d0f28ad29 100644 --- a/src/BuildOnLinux.cmake +++ b/src/BuildOnLinux.cmake @@ -8,7 +8,13 @@ add_library( STATIC ${ATen_XPU_CPP_SRCS} ${ATen_XPU_NATIVE_CPP_SRCS} - ${ATen_XPU_GEN_SRCS}) + ${ATen_XPU_GEN_SRCS} + ${ATen_XPU_XCCL_SRCS}) + +if(USE_C10D_XCCL) + target_compile_definitions(torch_xpu_ops PRIVATE USE_C10D_XCCL) + target_link_libraries(torch_xpu_ops PUBLIC torch::xccl) +endif() if(BUILD_SEPARATE_OPS) foreach(sycl_src ${ATen_XPU_SYCL_SRCS}) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 0716ca5af..7a427e294 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -4,11 +4,14 @@ include(${TORCH_XPU_OPS_ROOT}/cmake/Codegen.cmake) set(ATen_XPU_CPP_SRCS) set(ATen_XPU_NATIVE_CPP_SRCS) set(ATen_XPU_SYCL_SRCS) +set(ATen_XPU_XCCL_SRCS) set(ATen_XPU_INCLUDE_DIRS ${TORCH_XPU_OPS_ROOT}/src CACHE STRING "ATen XPU Include directory") add_subdirectory(ATen) - +if(USE_C10D_XCCL) + add_subdirectory(xccl) +endif() # With the increasement of bin size, we have to split libtorch_xpu.so into # multiple libraries. Because of strict linkage requirements on Windows, # we add extra logics to resolve, 1) Cyclic dependence, 2) Make symbols visible. diff --git a/src/xccl/CMakeLists.txt b/src/xccl/CMakeLists.txt new file mode 100644 index 000000000..f147b55ca --- /dev/null +++ b/src/xccl/CMakeLists.txt @@ -0,0 +1,16 @@ +# XCCL sources + +file(GLOB xccl_h "*.hpp") +file(GLOB xccl_cpp "*.cpp") + +list(APPEND ATen_XPU_XCCL_SRCS ${xccl_cpp}) + +set(ATen_XPU_XCCL_SRCS ${ATen_XPU_XCCL_SRCS} PARENT_SCOPE) + +# Why copy the header file to the build directory? +# We want register XCCL backend to PyTorch c10d in torch/csrc/distributed/c10d/init.cpp#L27-L29. +# To align with other backends, we need to copy the header file to the build torch/csrc/distributed/c10d directory. +# Further solution is add find path for torch/csrc/distributed/c10d/init.cpp#L27-L29. 
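+# A sketch of that alternative (hypothetical, not what this PR does) would be
+# to publish the source directory on the include path instead of copying, e.g.:
+#   target_include_directories(torch_xpu_ops PUBLIC ${TORCH_XPU_OPS_ROOT}/src/xccl)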
+foreach(HEADER ${xccl_h})
+  file(COPY ${HEADER} DESTINATION "${CMAKE_BINARY_DIR}/torch/csrc/distributed/c10d")
+endforeach()
diff --git a/src/xccl/ProcessGroupXCCL.cpp b/src/xccl/ProcessGroupXCCL.cpp
new file mode 100644
index 000000000..54db563c1
--- /dev/null
+++ b/src/xccl/ProcessGroupXCCL.cpp
@@ -0,0 +1,290 @@
+#ifdef USE_C10D_XCCL
+
+#include
+#include
+#include
+
+namespace c10d {
+
+namespace {
+const std::map<ReduceOp, ccl::reduction> xcclOps = {
+    {ReduceOp::MIN, ccl::reduction::min},
+    {ReduceOp::MAX, ccl::reduction::max},
+    {ReduceOp::SUM, ccl::reduction::sum},
+    {ReduceOp::PRODUCT, ccl::reduction::prod},
+};
+
+const std::map<at::ScalarType, ccl::datatype> xcclDatatypes = {
+    {at::kByte, ccl::datatype::uint8},
+    {at::kChar, ccl::datatype::int8},
+    {at::kInt, ccl::datatype::int32},
+    {at::kLong, ccl::datatype::int64},
+    {at::kHalf, ccl::datatype::float16},
+    {at::kFloat, ccl::datatype::float32},
+    {at::kDouble, ccl::datatype::float64},
+    {at::kBFloat16, ccl::datatype::bfloat16},
+    {at::kBool, ccl::datatype::uint8},
+    // used only for non-reduction ops such as allgather
+    {at::kFloat8_e5m2, ccl::datatype::uint8},
+    {at::kFloat8_e4m3fn, ccl::datatype::uint8},
+    {at::kFloat8_e4m3fnuz, ccl::datatype::uint8},
+    {at::kFloat8_e5m2fnuz, ccl::datatype::uint8},
+};
+
+void checkXPUTensor(at::Tensor& tensor) {
+  if (!tensor.is_xpu() || tensor.is_sparse() || tensor.is_complex()) {
+    C10_THROW_ERROR(
+        ValueError, "Tensors must be XPU and dense and non-complex");
+  }
+  if (!tensor.is_contiguous(tensor.suggest_memory_format())) {
+    C10_THROW_ERROR(ValueError, "Tensors must be contiguous");
+  }
+}
+
+ccl::datatype getXcclDataType(
+    at::ScalarType type,
+    bool is_reduction_op = false) {
+  TORCH_CHECK(
+      !(isFloat8Type(type) && is_reduction_op),
+      "Float8 dtypes are not currently supported for XCCL reductions");
+  auto it = xcclDatatypes.find(type);
+  TORCH_CHECK_WITH(
+      TypeError,
+      it != xcclDatatypes.end(),
+      "Input tensor data type is not supported for XCCL process group: ",
+      type);
+  return it->second;
+}
+
+ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) {
+  try {
+    if (input.scalar_type() == at::kBool && reduceOp == ReduceOp::SUM) {
+      // Map sum to max for bool tensors to avoid overflow issues with sum.
+      return ccl::reduction::max;
+    }
+    return xcclOps.at(reduceOp);
+  } catch (const std::out_of_range&) {
+    C10_THROW_ERROR(
+        ValueError,
+        "Cannot use ReduceOp."
+ reduceOpToString(reduceOp) + " with XCCL"); + } +} + +void syncStream( + at::Device& device, + at::xpu::XPUEvent& xcclEvent, + at::xpu::XPUStream& xcclStream) { + xcclEvent.record(at::xpu::getCurrentXPUStream(device.index())); + xcclEvent.block(xcclStream); +} +} // namespace + +constexpr int64_t kSynchronizeBusyWaitMillis = 10; + +ProcessGroupXCCL::WorkXCCL::WorkXCCL( + at::Device& device, + int rank, + OpType opType, + uint64_t seq, + const char* profilingTitle, + const std::optional>& inputs) + : Work(rank, opType, profilingTitle, inputs), + device_(device), + workStartTime_(std::chrono::steady_clock::now()), + seq_(seq) { + xcclEndEvent_ = std::make_shared(); +} + +ProcessGroupXCCL::WorkXCCL::WorkXCCL(const WorkXCCL& w) + : Work(w.rank_, w.opType_), + device_(w.device_), + xcclEndEvent_(w.xcclEndEvent_), + blockingWait_(w.blockingWait_), + workStartTime_(w.workStartTime_), + seq_(w.seq_) {} + +ProcessGroupXCCL::WorkXCCL::~WorkXCCL() = default; + +bool ProcessGroupXCCL::WorkXCCL::isCompleted() { + if (xcclEndEvent_ && xcclEndEvent_->query()) { + return true; + } + return false; +} + +void ProcessGroupXCCL::WorkXCCL::synchronize() { + synchronizeInternal(kNoTimeout); +} + +void ProcessGroupXCCL::WorkXCCL::synchronizeInternal( + std::chrono::milliseconds timeout) { + auto currentStream = at::xpu::getCurrentXPUStream(device_.index()); + xcclEndEvent_->block(currentStream); + if (blockingWait_) { + while (!isCompleted()) { + auto currentTimepoint = std::chrono::steady_clock::now(); + auto timeElapsed = std::chrono::duration_cast( + currentTimepoint - workStartTime_); + if (timeElapsed >= timeout) { + std::string exceptionMsg = c10::str( + "Work ran time out after ", timeElapsed.count(), " milliseconds."); + TORCH_CHECK(false, exceptionMsg) + } + std::this_thread::sleep_for( + std::chrono::milliseconds(kSynchronizeBusyWaitMillis)); + } + } +} + +bool ProcessGroupXCCL::WorkXCCL::wait(std::chrono::milliseconds timeout) { + synchronizeInternal(timeout); + return true; +} + +ProcessGroupXCCL::ProcessGroupXCCL( + const c10::intrusive_ptr& store, + int rank, + int size) + : Backend(rank, size), store_(store), xcclCommCounter_(0) { + blockingWait_ = getCvarBool(TORCH_XCCL_BLOCKING_WAIT, false); + init(); +} + +ProcessGroupXCCL::~ProcessGroupXCCL() = default; + +c10::intrusive_ptr ProcessGroupXCCL::initWork( + at::Device& device, + int rank, + OpType opType, + const char* profilingTitle, + const std::vector& inputs, + const std::vector& outputs) { + auto r = c10::make_intrusive( + device, + rank, + opType, + seqCollective_, + profilingTitle, + std::optional>(inputs)); + return r; +} + +std::shared_ptr ProcessGroupXCCL::getXCCLComm( + const std::string& deviceKey, + at::Device& device) { + TORCH_CHECK_WITH( + DistBackendError, + !deviceKey.empty(), + "Not able to create/get " + "XCCL Communicator since the devices are empty "); + { + std::lock_guard lock(mutex_); + if (devXCCLCommMap_.find(deviceKey) != devXCCLCommMap_.end()) { + return devXCCLCommMap_[deviceKey]; + } + } + + int numRanks, rank; + numRanks = getSize(); + rank = getRank(); + + c10::impl::VirtualGuardImpl impl(device.type()); + c10::Stream stream = + impl.getStreamFromGlobalPool(device, /*isHighPriority=*/false); + sycl::queue& q = c10::xpu::XPUStream(stream).queue(); + + auto ctx = ccl::create_context(q.get_context()); + ccl::vector_class> devs_rank; + devs_rank.emplace_back(rank, ccl::create_device(q.get_device())); + + auto xccl_kvs = get_kvs(rank_, *store_); + auto comms = ccl::create_communicators(numRanks, devs_rank, ctx, 
xccl_kvs); + std::shared_ptr XCCLComm = + std::make_shared(std::move(comms[0])); + + std::lock_guard lock(mutex_); + devXCCLCommMap_.emplace(deviceKey, XCCLComm); + xcclStreamsMap_.emplace(deviceKey, std::move(stream)); + xcclEventsMap_.emplace(deviceKey, at::xpu::XPUEvent()); + + return XCCLComm; +} + +template +c10::intrusive_ptr ProcessGroupXCCL::collective( + std::vector& inputs, + std::vector& outputs, + Fn fn, + PreProcess pre, + PostProcess post, + OpType opType, + const char* profilingTitle) { + seqCollective_++; + + auto device = inputs[0].device(); + const auto key = std::to_string(device.index()); + auto comm = getXCCLComm(key, device); + + auto stream = xcclStreamsMap_.at(key); + syncStream(device, xcclEventsMap_[key], stream); + + c10::intrusive_ptr work; + work = initWork(device, rank_, opType, profilingTitle); + work->outputs_ = std::make_shared>(outputs); + + at::xpu::OptionalXPUGuard gpuGuard(device); + pre(stream, work); + for (const auto i : c10::irange(inputs.size())) { + c10::xpu::XPUCachingAllocator::recordStream( + inputs[i].storage().data_ptr(), stream); + fn(inputs[i], outputs[i], *comm, stream); + } + post(stream, work); + + work->xcclEndEvent_->record(stream); + std::vector streams = {stream.unwrap()}; + c10::MultiStreamGuard streamGuard(streams); + std::vector devices{device}; + work->future_ = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get()), devices); + work->future_->markCompleted(at::IValue(*work->outputs_)); + work->blockingWait_ = blockingWait_; + + return work; +} + +c10::intrusive_ptr ProcessGroupXCCL::allreduce( + std::vector& tensors, + const AllreduceOptions& opts) { + TORCH_CHECK( + tensors.size() == 1, "Expecting one tensor only but got multiple"); + auto tensor = tensors.back(); + checkXPUTensor(tensor); + + return collective( + tensor, + tensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + auto xcclDataType = getXcclDataType(input.scalar_type(), true); + auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + auto ccl_stream = ccl::create_stream(stream.queue()); + ccl::allreduce( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + xcclReduceOp, + comm, + ccl_stream); + return; + }, + OpType::ALLREDUCE, + "xccl:all_reduce"); +} + +} // namespace c10d + +#endif // USE_C10D_XCCL diff --git a/src/xccl/ProcessGroupXCCL.hpp b/src/xccl/ProcessGroupXCCL.hpp new file mode 100644 index 000000000..21269bd6f --- /dev/null +++ b/src/xccl/ProcessGroupXCCL.hpp @@ -0,0 +1,222 @@ +#pragma once + +#ifdef USE_C10D_XCCL +// We will define those flags in XCCL backend file instead of passing to gcc +// compiler. 
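+// Illustrative note: a single-tensor collective such as allreduce() in
+// ProcessGroupXCCL.cpp funnels into the collective<>() template declared
+// below, roughly as
+//   collective(
+//       tensor, tensor,
+//       [&](at::Tensor& input, at::Tensor& output, xcclComm_t& comm,
+//           at::xpu::XPUStream& stream) {
+//         ccl::allreduce(
+//             input.data_ptr(), output.data_ptr(), input.numel(),
+//             getXcclDataType(input.scalar_type(), true),
+//             getXcclReduceOp(opts.reduceOp, input), comm,
+//             ccl::create_stream(stream.queue()));
+//       },
+//       OpType::ALLREDUCE, "xccl:all_reduce");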
+#define CCL_ENABLE_ZE +#define CCL_ENABLE_SYCL + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +namespace c10d { + +static std::vector TORCH_XCCL_BLOCKING_WAIT = { + "TORCH_XCCL_BLOCKING_WAIT", + "XCCL_BLOCKING_WAIT"}; + +using xcclComm_t = ccl::communicator; +constexpr const char* XCCL_BACKEND_NAME = "xccl"; + +class TORCH_API ProcessGroupXCCL : public Backend { + public: + class WorkXCCL : public Work { + public: + WorkXCCL( + at::Device& device, + int rank, + OpType opType, + uint64_t seq, + const char* profilingTitle = nullptr, + const std::optional>& inputs = std::nullopt); + WorkXCCL(const WorkXCCL& w); + ~WorkXCCL() override; + + bool isCompleted() override; + + void abort() override { + TORCH_CHECK(false, "ProcessGroupXCCL::WorkXCCL::abort not implemented"); + } + + void synchronize() override; + + bool wait(std::chrono::milliseconds timeout = kNoTimeout) override; + + c10::intrusive_ptr getFuture() override { + return future_; + } + + uint64_t getSequencenumber() const override { + return seq_; + } + + std::vector result() override { + return *outputs_; + } + + protected: + at::Device device_; + std::shared_ptr xcclEndEvent_; + bool blockingWait_ = false; + std::chrono::time_point workStartTime_; + uint64_t seq_; + + private: + void synchronizeInternal(std::chrono::milliseconds timeout); + std::shared_ptr> outputs_; + c10::intrusive_ptr future_; + friend class ProcessGroupXCCL; + }; + + ProcessGroupXCCL(const c10::intrusive_ptr& store, int rank, int size); + + C10_DEPRECATED ProcessGroupXCCL( + const c10::intrusive_ptr& store, + int rank, + int size, + const std::string& groupName) + : ProcessGroupXCCL(store, rank, size) {} + + ~ProcessGroupXCCL() override; + + const std::string getBackendName() const override { + return std::string(XCCL_BACKEND_NAME); + } + + std::shared_ptr getXCCLComm( + const std::string& deviceKey, + at::Device& device); + + virtual c10::intrusive_ptr initWork( + at::Device& device, + int rank, + OpType opType, + const char* profilingTitle = nullptr, + const std::vector& inputs = {}, + const std::vector& outputs = {}); + + template + c10::intrusive_ptr collective( + at::Tensor& input, + at::Tensor& output, + Fn fn, + OpType opType, + const char* profilingTitle = nullptr) { + auto inputs = std::vector{input}; + auto outputs = std::vector{output}; + return collective( + inputs, + outputs, + fn, + [](at::xpu::XPUStream&, + c10::intrusive_ptr&) {}, + [](at::xpu::XPUStream&, + c10::intrusive_ptr&) {}, + opType); + } + + template + c10::intrusive_ptr collective( + std::vector& inputs, + std::vector& outputs, + Fn fn, + PreProcess pre, + PostProcess post, + OpType opType, + const char* profilingTitle = nullptr); + + c10::intrusive_ptr allreduce( + std::vector& tensors, + const AllreduceOptions& opts = AllreduceOptions()) override; + + void setSequenceNumberForGroup() override {} + uint64_t getSequenceNumberForGroup() override { + return seqCollective_; + } + + protected: + std::unordered_map xcclStreamsMap_; + std::unordered_map xcclEventsMap_; + std::unordered_map> devXCCLCommMap_; + c10::intrusive_ptr store_; + uint64_t xcclCommCounter_{0}; + std::mutex mutex_; + bool blockingWait_ = false; + uint64_t seqCollective_{0}; + + private: + std::mutex kvs_mutex; + + ccl::shared_ptr_class get_kvs( + int rank, + c10d::Store& store, + bool singleP2POp = false, + const std::string& p2pKey = "", + int p2pRank = 0) { + std::lock_guard lock(kvs_mutex); + ccl::shared_ptr_class kvs; + std::string 
storeKey; + if (!singleP2POp) { + storeKey = std::to_string(xcclCommCounter_++); + } else { + storeKey = p2pKey; + } + // Rank 0 broadcast the bootstrap network information to other ranks + if (rank == 0 || (singleP2POp && p2pRank == 0)) { + kvs = ccl::create_main_kvs(); + ccl::kvs::address_type main_addr = kvs->get_address(); + auto ccl_kvs_addr = + std::vector(main_addr.begin(), main_addr.end()); + store.set(storeKey, ccl_kvs_addr); + } else { + auto ccl_kvs_addr = store.get(storeKey); + if (ccl_kvs_addr.size() != ccl::kvs::address_max_size) { + throw std::runtime_error("Unexpected ccl kvs addr from the store\n"); + } + ccl::kvs::address_type main_addr; + std::copy_n( + ccl_kvs_addr.begin(), ccl::kvs::address_max_size, main_addr.begin()); + kvs = ccl::create_kvs(main_addr); + } + return kvs; + } +}; +} // namespace c10d + +namespace { +inline std::string reduceOpToString(c10d::ReduceOp op) { + switch (op) { + case c10d::ReduceOp::SUM: + return "SUM"; + case c10d::ReduceOp::PRODUCT: + return "PRODUCT"; + case c10d::ReduceOp::MIN: + return "MIN"; + case c10d::ReduceOp::MAX: + return "MAX"; + case c10d::ReduceOp::BAND: + return "BAND"; + case c10d::ReduceOp::BOR: + return "BOR"; + case c10d::ReduceOp::BXOR: + return "BXOR"; + case c10d::ReduceOp::AVG: + return "AVG"; + case c10d::ReduceOp::PREMUL_SUM: + return "PREMUL_SUM"; + default: + return "UNKNOWN"; + } +} +} // namespace +#endif // USE_C10D_XCCL diff --git a/src/xccl/Register.cpp b/src/xccl/Register.cpp new file mode 100644 index 000000000..3716c7a90 --- /dev/null +++ b/src/xccl/Register.cpp @@ -0,0 +1,313 @@ +#include +#include +#include +#include +#include + +namespace c10d { +namespace ops { +namespace { +c10::intrusive_ptr send_XPU( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + int64_t dstRank, + int64_t tag) { + auto tensor_vec = tensors.vec(); + return process_group->getBackend(c10::DeviceType::XPU) + ->send(tensor_vec, static_cast(dstRank), static_cast(tag)); +} + +c10::intrusive_ptr recv_XPU( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + int64_t srcRank, + int64_t tag) { + auto tensor_vec = tensors.vec(); + return process_group->getBackend(c10::DeviceType::XPU) + ->recv(tensor_vec, static_cast(srcRank), static_cast(tag)); +} + +c10::intrusive_ptr recv_any_source_XPU( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + int64_t tag) { + auto tensor_vec = tensors.vec(); + return process_group->getBackend(c10::DeviceType::XPU) + ->recvAnysource(tensor_vec, static_cast(tag)); +} + +c10::intrusive_ptr reduce_XPU( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + int64_t root_rank, + int64_t root_tensor, + int64_t timeout) { + auto tensor_vec = tensors.vec(); + return process_group->getBackend(c10::DeviceType::XPU) + ->reduce( + tensor_vec, + ReduceOptions{ + *reduce_op.get(), + root_rank, + root_tensor, + std::chrono::milliseconds(timeout)}); +} + +std::tuple, c10::intrusive_ptr> broadcast_XPU( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + int64_t root_rank, + int64_t root_tensor, + bool asyncOp, + int64_t timeout) { + auto tensor_vec = tensors.vec(); + auto work = process_group->getBackend(c10::DeviceType::XPU) + ->broadcast( + tensor_vec, + BroadcastOptions{ + root_rank, + root_tensor, + std::chrono::milliseconds(timeout), + asyncOp}); + return std::tuple, c10::intrusive_ptr>( + std::move(tensor_vec), work); +} + +std::tuple, c10::intrusive_ptr> allreduce_XPU( + at::TensorList 
tensors, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + const std::optional& sparse_indices, + int64_t timeout) { + auto tensor_vec = tensors.vec(); + auto work = + process_group->getBackend(c10::DeviceType::XPU) + ->allreduce( + tensor_vec, + AllreduceOptions{ + *reduce_op.get(), std::chrono::milliseconds(timeout)}); + return std::tuple, c10::intrusive_ptr>( + std::move(tensor_vec), work); +} + +c10::intrusive_ptr allreduce_coalesced_XPU( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + int64_t timeout) { + auto tensor_vec = tensors.vec(); + AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{}; + opts.reduceOp = *reduce_op.get(); + opts.timeout = std::chrono::milliseconds(timeout); + return process_group->getBackend(c10::DeviceType::XPU) + ->allreduce_coalesced(tensor_vec, opts); +} + +std::tuple>, c10::intrusive_ptr> +allgather_XPU( + const std::vector>& output_tensors, + at::TensorList input_tensors, + const c10::intrusive_ptr& process_group, + int64_t timeout) { + auto input_tensors_vec = input_tensors.vec(); + auto work = + process_group->getBackend(c10::DeviceType::XPU) + ->allgather( + const_cast>&>(output_tensors), + input_tensors_vec, + AllgatherOptions{std::chrono::milliseconds(timeout)}); + return std:: + tuple>, c10::intrusive_ptr>( + output_tensors, work); +} + +std::tuple> _allgather_base_XPU( + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const c10::intrusive_ptr& process_group, + bool asyncOp, + int64_t timeout) { + auto work = + process_group->getBackend(c10::DeviceType::XPU) + ->_allgather_base( + output_tensor, + input_tensor, + AllgatherOptions{std::chrono::milliseconds(timeout), asyncOp}); + return std::tuple>(output_tensor, work); +} + +c10::intrusive_ptr allgather_coalesced_XPU( + const std::vector>& output_lists, + const at::TensorList& input_list, + const c10::intrusive_ptr& process_group) { + auto input_list_vec = input_list.vec(); + return process_group->getBackend(c10::DeviceType::XPU) + ->allgather_coalesced( + const_cast>&>(output_lists), + input_list_vec); +} + +c10::intrusive_ptr allgather_into_tensor_coalesced_XPU( + at::TensorList outputs, + at::TensorList inputs, + const c10::intrusive_ptr& process_group) { + auto output_vec = outputs.vec(); + auto input_vec = inputs.vec(); + return process_group->getBackend(c10::DeviceType::XPU) + ->allgather_into_tensor_coalesced(output_vec, input_vec); +} + +std::tuple, c10::intrusive_ptr> reduce_scatter_XPU( + const at::TensorList& output_tensors, + const std::vector>& input_tensors, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + int64_t timeout) { + auto output_tensors_vec = output_tensors.vec(); + auto work = + process_group->getBackend(c10::DeviceType::XPU) + ->reduce_scatter( + output_tensors_vec, + const_cast>&>(input_tensors), + ReduceScatterOptions{ + *reduce_op.get(), std::chrono::milliseconds(timeout)}); + return std::tuple, c10::intrusive_ptr>( + output_tensors_vec, work); +} + +std::tuple> _reduce_scatter_base_XPU( + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + bool asyncOp, + int64_t timeout) { + auto work = process_group->getBackend(c10::DeviceType::XPU) + ->_reduce_scatter_base( + output_tensor, + input_tensor, + ReduceScatterOptions{ + *reduce_op.get(), + std::chrono::milliseconds(timeout), + asyncOp}); + return std::tuple>(output_tensor, work); +} + 
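+// Note: every wrapper in this file follows the same shape: materialize the
+// at::TensorList into a std::vector, look up the XPU backend on the
+// ProcessGroup, and forward to the corresponding ProcessGroupXCCL method with
+// an Options struct built from the op arguments, schematically:
+//   process_group->getBackend(c10::DeviceType::XPU)
+//       ->some_collective(tensor_vec, SomeOptions{...});
+// ("some_collective"/"SomeOptions" are placeholders, not real APIs.)
+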
+c10::intrusive_ptr reduce_scatter_tensor_coalesced_XPU( + at::TensorList outputs, + at::TensorList inputs, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + int64_t timeout) { + auto output_vec = outputs.vec(); + auto input_vec = inputs.vec(); + return process_group->getBackend(c10::DeviceType::XPU) + ->reduce_scatter_tensor_coalesced( + output_vec, + input_vec, + ReduceScatterOptions{ + *reduce_op.get(), std::chrono::milliseconds(timeout)}); +} + +c10::intrusive_ptr gather_XPU( + const std::vector>& output_tensors, + const at::TensorList& input_tensors, + const c10::intrusive_ptr& process_group, + int64_t root_rank, + int64_t timeout) { + auto input_tensors_vec = input_tensors.vec(); + return process_group->getBackend(c10::DeviceType::XPU) + ->gather( + const_cast>&>(output_tensors), + input_tensors_vec, + GatherOptions{root_rank, std::chrono::milliseconds(timeout)}); +} + +std::tuple, c10::intrusive_ptr> scatter_XPU( + const at::TensorList& output_tensors, + const std::vector>& input_tensors, + const c10::intrusive_ptr& process_group, + int64_t root_rank, + bool asyncOp, + int64_t timeout) { + auto output_tensors_vec = output_tensors.vec(); + auto work = + process_group->getBackend(c10::DeviceType::XPU) + ->scatter( + output_tensors_vec, + const_cast>&>(input_tensors), + ScatterOptions{ + root_rank, std::chrono::milliseconds(timeout), asyncOp}); + return std::tuple, c10::intrusive_ptr>( + std::move(output_tensors_vec), work); +} + +std::tuple, c10::intrusive_ptr> alltoall_XPU( + const at::TensorList& output_tensors, + const at::TensorList& input_tensors, + const c10::intrusive_ptr& process_group, + int64_t timeout) { + auto output_tensors_vec = output_tensors.vec(); + auto input_tensors_vec = input_tensors.vec(); + auto work = process_group->getBackend(c10::DeviceType::XPU) + ->alltoall( + output_tensors_vec, + input_tensors_vec, + AllToAllOptions{std::chrono::milliseconds(timeout)}); + return std::tuple, c10::intrusive_ptr>( + std::move(output_tensors_vec), work); +} + +c10::intrusive_ptr alltoall_base_XPU( + at::Tensor& output, + at::Tensor& input, + const c10::intrusive_ptr& process_group, + std::vector output_split_sizes, + std::vector input_split_sizes, + int64_t timeout) { + return process_group->getBackend(c10::DeviceType::XPU) + ->alltoall_base( + output, + input, + output_split_sizes, + input_split_sizes, + AllToAllOptions{std::chrono::milliseconds(timeout)}); +} + +c10::intrusive_ptr barrier_XPU( + at::Tensor /* unused */, + const c10::intrusive_ptr& process_group, + const std::vector& device_ids, + int64_t timeout) { + return process_group->getBackend(c10::DeviceType::XPU) + ->barrier(BarrierOptions{device_ids, std::chrono::milliseconds(timeout)}); +} + +TORCH_LIBRARY_IMPL(c10d, XPU, m) { + m.impl("send", send_XPU); + m.impl("recv_", recv_XPU); + m.impl("recv_any_source_", recv_any_source_XPU); + m.impl("reduce_", reduce_XPU); + m.impl("broadcast_", broadcast_XPU); + m.impl("allreduce_", allreduce_XPU); + m.impl("allreduce_coalesced_", allreduce_coalesced_XPU); + m.impl("allgather_", allgather_XPU); + m.impl("_allgather_base_", _allgather_base_XPU); + m.impl("allgather_coalesced_", allgather_coalesced_XPU); + m.impl( + "allgather_into_tensor_coalesced_", allgather_into_tensor_coalesced_XPU); + m.impl("reduce_scatter_", reduce_scatter_XPU); + m.impl("_reduce_scatter_base_", _reduce_scatter_base_XPU); + m.impl( + "reduce_scatter_tensor_coalesced_", reduce_scatter_tensor_coalesced_XPU); + m.impl("gather_", gather_XPU); + m.impl("scatter_", 
scatter_XPU);
+  m.impl("alltoall_", alltoall_XPU);
+  m.impl("alltoall_base_", alltoall_base_XPU);
+  m.impl("barrier", barrier_XPU);
+}
+} // namespace
+
+} // namespace ops
+} // namespace c10d
\ No newline at end of file
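Reviewer note: the store-based rendezvous in ProcessGroupXCCL::get_kvs() follows the usual oneCCL bootstrap pattern (rank 0 publishes the main KVS address through the c10d store, the other ranks attach to it). Below is a minimal standalone sketch of that exchange; the helper name bootstrapKvs and the store key are hypothetical, and it assumes the oneCCL and c10d Store headers named in the includes are available.

    #include <algorithm>
    #include <cstdint>
    #include <stdexcept>
    #include <string>
    #include <vector>

    #include <oneapi/ccl.hpp>
    #include <torch/csrc/distributed/c10d/Store.hpp>

    // Rank 0 creates the main KVS and publishes its address via the store;
    // every other rank fetches the address and attaches to the same KVS.
    ccl::shared_ptr_class<ccl::kvs> bootstrapKvs(int rank, c10d::Store& store) {
      const std::string storeKey = "xccl_kvs_addr"; // illustrative key only
      if (rank == 0) {
        auto kvs = ccl::create_main_kvs();
        ccl::kvs::address_type addr = kvs->get_address();
        store.set(storeKey, std::vector<uint8_t>(addr.begin(), addr.end()));
        return kvs;
      }
      auto bytes = store.get(storeKey);
      if (bytes.size() != ccl::kvs::address_max_size) {
        throw std::runtime_error("Unexpected ccl kvs addr from the store");
      }
      ccl::kvs::address_type addr{};
      std::copy_n(bytes.begin(), ccl::kvs::address_max_size, addr.begin());
      return ccl::create_kvs(addr);
    }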