From c9ef78fdc5c8872246c74e5a1949d5a7c94726c5 Mon Sep 17 00:00:00 2001
From: hanchao
Date: Wed, 4 Sep 2024 04:25:38 +0000
Subject: [PATCH] update

---
 build_variables.bzl                          |  1 +
 torch/csrc/distributed/c10d/ProcessGroup.hpp |  1 +
 .../distributed/c10d/ProcessGroupXCCL.cpp    | 50 ++++---------------
 .../distributed/c10d/ProcessGroupXCCL.hpp    | 15 ++----
 4 files changed, 16 insertions(+), 51 deletions(-)

diff --git a/build_variables.bzl b/build_variables.bzl
index 80a575324aa8b3..55a3f0023b571f 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -786,6 +786,7 @@ libtorch_python_cuda_sources = libtorch_python_cuda_core_sources + [
 ]
 
 libtorch_python_xpu_sources = [
+    "torch/csrc/xpu/xccl.cpp",
     "torch/csrc/xpu/Event.cpp",
     "torch/csrc/xpu/Module.cpp",
     "torch/csrc/xpu/Stream.cpp",
diff --git a/torch/csrc/distributed/c10d/ProcessGroup.hpp b/torch/csrc/distributed/c10d/ProcessGroup.hpp
index acf8c9c354a76b..85142caf0ac7c7 100644
--- a/torch/csrc/distributed/c10d/ProcessGroup.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroup.hpp
@@ -70,6 +70,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
     UCC = 3,
     MPI = 4,
     CUSTOM = 5,
+    XCCL = 6,
   };
 
   // Not used, set for backwards compatibility and only used for TypeDef in
diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
index 5e2e179d32af37..8be7c6451fcdd0 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
@@ -3,7 +3,7 @@
 #include 
 #include 
 
-#ifdef USE_C10D_XCCL
+// #ifdef USE_C10D_XCCL
 #include 
 #include 
 #include 
@@ -68,38 +68,6 @@ ccl::datatype getXcclDataType(at::ScalarType type) {
 static std::mutex xcclCommDevIdxMapMutex;
 static std::unordered_map<std::shared_ptr<xcclComm_t>, int> xcclCommDevIdxMap;
 
-// template <
-//     template <typename>
-//     class WorkXCCL,
-//     typename RunF,
-//     typename CommType,
-//     typename InputType,
-//     typename OutputType,
-//     typename attr_t>
-// c10::intrusive_ptr<Work> make_work_ccl(
-//     const std::vector<InputType>& inputs,
-//     const std::vector<OutputType>& outputs,
-//     RunF f,
-//     CommType& comms,
-//     attr_t& attr,
-//     int rank,
-//     c10d::OpType op_type) {
-//   c10::intrusive_ptr<WorkXCCL<RunF, CommType, InputType, OutputType>>
-//       ret_ptr = c10::make_intrusive<
-//           WorkCCL<RunF, CommType, InputType, OutputType>>(
-//           inputs, outputs, f, comms, attr, rank, op_type);
-//   return ret_ptr;
-// }
-
-// ProcessGroupXCCL::WorkXCCL::WorkXCCL(
-//     std::vector<std::vector<at::Tensor>> outputTensors,
-//     int rank,
-//     c10d::OpType opType,
-//     const c10::optional<std::vector<at::Tensor>>& inputTensors)
-//     : Work(rank, opType, nullptr, inputTensors),
-//       outputTensors_(std::move(outputTensors)),
-//       future_(createFutureAsOutput(outputTensors)) {}
-
 ProcessGroupXCCL::WorkXCCL::WorkXCCL(
     at::Device& device,
     int rank,
@@ -116,6 +84,11 @@ ProcessGroupXCCL::WorkXCCL::WorkXCCL(const WorkXCCL& w)
 
 ProcessGroupXCCL::WorkXCCL::~WorkXCCL() = default;
 
+bool ProcessGroupXCCL::WorkXCCL::wait(std::chrono::milliseconds timeout) {
+  synchronize();
+  return true;
+}
+
 c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupXCCL::WorkXCCL::
     getFuture() {
   return future_;
@@ -267,8 +240,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
       input,
       output,
       fn,
-      [](std::vector<at::xpu::XPUStream>&) {},
-      [](std::vector<at::xpu::XPUStream>&) {},
+      [](at::xpu::XPUStream&,
+         c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>& work) {},
+      [](at::xpu::XPUStream&,
+         c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>& work) {},
       opType);
 }
 
@@ -306,9 +281,6 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
       OpType::ALLREDUCE);
 }
 
-// c10::intrusive_ptr<Work> barrier(
-//     const BarrierOptions& opts = BarrierOptions()) override;
-
 } // namespace c10d
 
-#endif // USE_C10D_XCCL
+// #endif // USE_C10D_XCCL
diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp
index 02eddb7acb8ec0..d14d677205ecbb 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp
@@ -7,7 +7,7 @@
 #include 
 #endif
 
-#ifdef USE_C10D_XCCL
+// #ifdef USE_C10D_XCCL
 
 #include 
 #include 
@@ -35,7 +35,7 @@ namespace c10d {
 constexpr const char* XCCL_BACKEND_NAME = "xccl";
 using namespace torch::xpu::xccl;
 
-class ProcessGroupXCCL : public Backend {
+class TORCH_XPU_API ProcessGroupXCCL : public Backend {
  public:
   class WorkXCCL : public Work {
    public:
@@ -82,13 +82,6 @@ class ProcessGroupXCCL : public Backend {
     void synchronize() override;
 
     bool wait(std::chrono::milliseconds timeout = kNoTimeout) override;
-    // void wait() {
-    //   std::unique_lock<std::mutex> lock(mutex_);
-    //   for (auto& event : events_) {
-    //     event.wait();
-    //   }
-    //   events_.clear();
-    // }
 
     c10::intrusive_ptr<c10::ivalue::Future> getFuture() override {
       return future_;
@@ -96,8 +89,6 @@ class ProcessGroupXCCL : public Backend {
 
     std::vector<at::Tensor> result() override {
       TORCH_CHECK(false, "ProcessGroupXCCL::WorkXCCL::result not implemented");
-      // return outputTensors_.empty() ? std::vector<at::Tensor>()
-      //                               : outputTensors_[0];
     }
 
   protected:
@@ -147,4 +138,4 @@ class ProcessGroupXCCL : public Backend {
 
 } // namespace c10d
 
-#endif // USE_C10D_XCCL
+// #endif // USE_C10D_XCCL
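
The collective() change above replaces the old vector-of-streams pre/post hooks with callbacks that receive the single at::xpu::XPUStream the collective is enqueued on plus the WorkXCCL object tracking it. Below is a minimal caller-side sketch of the new shape, assuming the collective() overload and WorkXCCL members shown in the diff; the hook bodies and the `input`/`output`/`fn` variables are illustrative placeholders, not part of the patch:

    // Hypothetical use inside a ProcessGroupXCCL member function.
    auto work = collective(
        input,
        output,
        fn,
        // pre-hook: runs before the collective is enqueued on `stream`,
        // e.g. to stash inputs on `work` so they outlive the async op.
        [](at::xpu::XPUStream& stream,
           c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>& work) {},
        // post-hook: runs after enqueue, e.g. to record an end event on
        // `stream` that WorkXCCL::synchronize() can later block on.
        [](at::xpu::XPUStream& stream,
           c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>& work) {},
        OpType::ALLREDUCE);
    // The new WorkXCCL::wait() override simply calls synchronize() and
    // returns true, so this blocks until the collective has completed.
    work->wait();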