Commit c9ef78f

update

Chao1Han committed Sep 4, 2024
1 parent 2e21d4f commit c9ef78f
Showing 4 changed files with 16 additions and 51 deletions.
1 change: 1 addition & 0 deletions build_variables.bzl
@@ -786,6 +786,7 @@ libtorch_python_cuda_sources = libtorch_python_cuda_core_sources + [
]

libtorch_python_xpu_sources = [
"torch/csrc/xpu/xccl.cpp",
"torch/csrc/xpu/Event.cpp",
"torch/csrc/xpu/Module.cpp",
"torch/csrc/xpu/Stream.cpp",
1 change: 1 addition & 0 deletions torch/csrc/distributed/c10d/ProcessGroup.hpp
@@ -70,6 +70,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
UCC = 3,
MPI = 4,
CUSTOM = 5,
+XCCL = 6,
};

// Not used, set for backwards compatibility and only used for TypeDef in
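
Note: the new XCCL = 6 entry extends ProcessGroup::BackendType alongside the existing backends. A minimal C++ sketch of dispatching on the new value (a standalone mirror of the enum for illustration; backendTypeToString is a hypothetical helper, not part of this commit, and the values below UCC are assumed from the upstream enum):

    #include <string>

    // Standalone mirror of ProcessGroup::BackendType, for illustration only.
    enum class BackendType {
      UNDEFINED = 0, GLOO = 1, NCCL = 2, UCC = 3, MPI = 4, CUSTOM = 5, XCCL = 6,
    };

    // Hypothetical helper: map a backend type to its registry name.
    std::string backendTypeToString(BackendType t) {
      switch (t) {
        case BackendType::GLOO:   return "gloo";
        case BackendType::NCCL:   return "nccl";
        case BackendType::UCC:    return "ucc";
        case BackendType::MPI:    return "mpi";
        case BackendType::XCCL:   return "xccl"; // matches XCCL_BACKEND_NAME below
        case BackendType::CUSTOM: return "custom";
        default:                  return "undefined";
      }
    }
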
50 changes: 11 additions & 39 deletions torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
@@ -3,7 +3,7 @@
#include <mutex>
#include <sstream>

-#ifdef USE_C10D_XCCL
+// #ifdef USE_C10D_XCCL
#include <exception>
#include <map>
#include <stdexcept>
@@ -68,38 +68,6 @@ ccl::datatype getXcclDataType(at::ScalarType type) {
static std::mutex xcclCommDevIdxMapMutex;
static std::unordered_map<std::shared_ptr<xcclComm_t>, int> xcclCommDevIdxMap;

-// template <
-// template <typename, typename, typename, typename, typename>
-// class WorkXCCL,
-// typename RunF,
-// typename CommType,
-// typename InputType,
-// typename OutputType,
-// typename attr_t>
-// c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> make_work_ccl(
-// const std::vector<InputType>& inputs,
-// const std::vector<OutputType>& outputs,
-// RunF f,
-// CommType& comms,
-// attr_t& attr,
-// int rank,
-// c10d::OpType op_type) {
-// c10::intrusive_ptr<WorkCCL<RunF, CommType, InputType, OutputType, attr_t>>
-// ret_ptr = c10::make_intrusive<
-// WorkCCL<RunF, CommType, InputType, OutputType, attr_t>>(
-// inputs, outputs, f, comms, attr, rank, op_type);
-// return ret_ptr;
-// }
-
-// ProcessGroupXCCL::WorkXCCL::WorkXCCL(
-// std::vector<std::vector<at::Tensor>> outputTensors,
-// int rank,
-// c10d::OpType opType,
-// const c10::optional<std::vector<at::Tensor>>& inputTensors)
-// : Work(rank, opType, nullptr, inputTensors),
-// outputTensors_(std::move(outputTensors)),
-// future_(createFutureAsOutput(outputTensors)) {}
-
ProcessGroupXCCL::WorkXCCL::WorkXCCL(
at::Device& device,
int rank,
@@ -116,6 +84,11 @@ ProcessGroupXCCL::WorkXCCL::WorkXCCL(const WorkXCCL& w)

ProcessGroupXCCL::WorkXCCL::~WorkXCCL() = default;

+bool ProcessGroupXCCL::WorkXCCL::wait(std::chrono::milliseconds timeout) {
+  synchronize();
+  return true;
+}
+
c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupXCCL::WorkXCCL::
getFuture() {
return future_;
@@ -267,8 +240,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
input,
output,
fn,
-[](std::vector<ccl::stream>&) {},
-[](std::vector<ccl::stream>&) {},
+[](at::xpu::XPUStream&,
+c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>& work) {},
+[](at::xpu::XPUStream&,
+c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>& work) {},
opType);
}

@@ -306,9 +281,6 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
OpType::ALLREDUCE);
}

-// c10::intrusive_ptr<Work> barrier(
-// const BarrierOptions& opts = BarrierOptions()) override;
-
} // namespace c10d

-#endif // USE_C10D_XCCL
+// #endif // USE_C10D_XCCL
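
Note: the net effect of the .cpp changes is that WorkXCCL::wait() now blocks by delegating to synchronize() and always reports success, and the collective() pre/post hooks take an at::xpu::XPUStream& plus the WorkXCCL handle instead of std::vector<ccl::stream>&. A caller-side sketch under those semantics (the tensor/device setup is illustrative, and the default AllreduceOptions argument is assumed from the usual c10d Backend signature):

    #include <torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp>

    // Assumes a constructed ProcessGroupXCCL `pg` on an XPU build.
    // allreduce() updates the input tensors in place.
    void runAllreduce(c10d::ProcessGroupXCCL& pg) {
      std::vector<at::Tensor> tensors{at::ones(
          {8}, at::TensorOptions().dtype(at::kFloat).device(at::kXPU))};
      auto work = pg.allreduce(tensors);
      work->wait();              // new behavior: synchronize(), then return true
      work->getFuture()->wait(); // equivalent completion via the stored future_
    }
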
15 changes: 3 additions & 12 deletions torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp
@@ -7,7 +7,7 @@
#include <unistd.h>
#endif

-#ifdef USE_C10D_XCCL
+// #ifdef USE_C10D_XCCL

#include <oneapi/ccl.hpp>
#include <torch/csrc/xpu/xccl.h>
@@ -35,7 +35,7 @@ namespace c10d {
constexpr const char* XCCL_BACKEND_NAME = "xccl";
using namespace torch::xpu::xccl;

-class ProcessGroupXCCL : public Backend {
+class TORCH_XPU_API ProcessGroupXCCL : public Backend {
public:
class WorkXCCL : public Work {
public:
@@ -82,22 +82,13 @@ class ProcessGroupXCCL : public Backend {
void synchronize() override;

bool wait(std::chrono::milliseconds timeout = kNoTimeout) override;
-// void wait() {
-// std::unique_lock<std::timed_mutex> lock(mutex_);
-// for (auto& event : events_) {
-// event.wait();
-// }
-// events_.clear();
-// }

c10::intrusive_ptr<c10::ivalue::Future> getFuture() override {
return future_;
}

std::vector<at::Tensor> result() override {
TORCH_CHECK(false, "ProcessGroupXCCL::WorkXCCL::result not implemented");
-// return outputTensors_.empty() ? std::vector<at::Tensor>()
-// : outputTensors_[0];
}

protected:
@@ -147,4 +138,4 @@

} // namespace c10d

-#endif // USE_C10D_XCCL
+// #endif // USE_C10D_XCCL
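
Note: the header now exports the class with TORCH_XPU_API so that code outside the XPU library (for example the Python bindings built from the newly listed torch/csrc/xpu/xccl.cpp) can link against ProcessGroupXCCL. A simplified model of what such an export macro expands to (an assumption: the real definition comes from the shared c10 export helpers, and TORCH_XPU_BUILD_MAIN_LIB is a hypothetical build flag):

    #if defined(_WIN32)
    #if defined(TORCH_XPU_BUILD_MAIN_LIB)
    #define TORCH_XPU_API __declspec(dllexport) // building the XPU library itself
    #else
    #define TORCH_XPU_API __declspec(dllimport) // consuming it from another DSO
    #endif
    #else
    #define TORCH_XPU_API __attribute__((visibility("default")))
    #endif
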
