From 65e0d9d7946716a829c04777954b7ab134bdf472 Mon Sep 17 00:00:00 2001
From: hanchao
Date: Thu, 14 Nov 2024 08:48:03 +0000
Subject: [PATCH] Work around missing AVG reduction support in oneCCL

---
 test/distributed/test_c10d_ops_xccl.py    | 10 +++++
 .../distributed/c10d/ProcessGroupXCCL.cpp | 44 +++++++++++++++++++
 2 files changed, 54 insertions(+)

diff --git a/test/distributed/test_c10d_ops_xccl.py b/test/distributed/test_c10d_ops_xccl.py
index 279ec0eb03ecf8..9784cf3a5c0bea 100644
--- a/test/distributed/test_c10d_ops_xccl.py
+++ b/test/distributed/test_c10d_ops_xccl.py
@@ -155,6 +155,16 @@ def allreduce(tensors, op):
             tensors[0],
         )
 
+        # Avg
+        tensors = [torch.tensor([self.rank + 1.0]).xpu(local_device_id)]
+
+        allreduce(tensors, c10d.ReduceOp.AVG)
+        ndev = self.world_size
+        self.assertEqual(
+            torch.tensor([ndev * (ndev + 1.0) / (2.0 * ndev)]),
+            tensors[0],
+        )
+
         # Product
         tensors = [torch.tensor([self.rank + 1]).xpu(local_device_id)]
 
diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
index f202f8916f89fd..b2a900c92b8c0b 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
@@ -147,6 +147,10 @@ ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) {
       // Map sum to max for bool tensors to avoid overflow issues with sum.
       return ccl::reduction::max;
     }
+    // Workaround: oneCCL does not support AVG; reduce with SUM instead.
+    if (reduceOp == ReduceOp::AVG) {
+      return ccl::reduction::sum;
+    }
     return xcclOps.at(reduceOp);
   } catch (const std::out_of_range&) {
     C10_THROW_ERROR(
@@ -894,6 +898,11 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_impl(
             xcclReduceOp,
             comm,
             ccl::create_stream(stream.queue()));
+        // Workaround: oneCCL lacks AVG; divide the SUM result by world size.
+        if (opts.reduceOp == ReduceOp::AVG) {
+          auto divisor = getSize();
+          output.div_(divisor);
+        }
         return;
       },
       OpType::ALLREDUCE,
@@ -942,6 +951,11 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
             xcclReduceOp,
             comm,
             ccl::create_stream(stream.queue()));
+        // Workaround: oneCCL lacks AVG; divide the SUM result by world size.
+        if (opts.reduceOp == ReduceOp::AVG) {
+          auto divisor = getSize();
+          output.div_(divisor);
+        }
         return;
       },
       OpType::ALLREDUCE,
@@ -988,6 +1002,11 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_coalesced(
             xcclReduceOp,
             comm,
             ccl::create_stream(stream.queue()));
+        // Workaround: oneCCL lacks AVG; divide the SUM result by world size.
+        if (opts.reduceOp == ReduceOp::AVG) {
+          auto divisor = getSize();
+          output.div_(divisor);
+        }
         return;
       },
       OpType::COALESCED,
@@ -1117,6 +1136,11 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce(
             root,
             comm,
             ccl::create_stream(stream.queue()));
+        // Workaround: oneCCL lacks AVG; divide the SUM result on the root.
+        if (opts.reduceOp == ReduceOp::AVG && getRank() == root) {
+          auto divisor = getSize();
+          output.div_(divisor);
+        }
         return;
       },
       OpType::REDUCE,
@@ -1150,6 +1174,11 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_oop(
             root,
             comm,
             ccl::create_stream(stream.queue()));
+        // Workaround: oneCCL lacks AVG; divide the SUM result on the root.
+        if (opts.reduceOp == ReduceOp::AVG && getRank() == root) {
+          auto divisor = getSize();
+          output.div_(divisor);
+        }
         return;
       },
       OpType::REDUCE,
@@ -1370,6 +1399,11 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter(
             xcclReduceOp,
             comm,
             ccl::create_stream(stream.queue()));
+        // Workaround: oneCCL lacks AVG; divide the SUM result by world size.
+        if (opts.reduceOp == ReduceOp::AVG) {
+          auto divisor = getSize();
+          output.div_(divisor);
+        }
         return;
       },
       [&](at::xpu::XPUStream& Stream,
@@ -1453,6 +1487,11 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_scatter_base(
             xcclReduceOp,
             comm,
             ccl::create_stream(stream.queue()));
+        // Workaround: oneCCL lacks AVG; divide the SUM result by world size.
+        if (opts.reduceOp == ReduceOp::AVG) {
+          auto divisor = getSize();
+          output.div_(divisor);
+        }
         return;
       },
       OpType::_REDUCE_SCATTER_BASE,
@@ -1482,6 +1521,11 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter_tensor_coalesced(
             xcclReduceOp,
             comm,
             ccl::create_stream(stream.queue()));
+        // Workaround: oneCCL lacks AVG; divide the SUM result by world size.
+        if (opts.reduceOp == ReduceOp::AVG) {
+          auto divisor = getSize();
+          output.div_(divisor);
+        }
         return;
       },
       OpType::COALESCED,
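
A minimal standalone sketch of the arithmetic the workaround and the new test rely on (plain Python, no torch or XPU runtime; the helper name emulated_avg_allreduce is made up for illustration): each rank contributes rank + 1, oneCCL reduces with SUM, and the trailing output.div_(getSize()) turns the sum into the mean (world_size + 1) / 2 that the test asserts.

def emulated_avg_allreduce(world_size: int) -> float:
    # One contribution per rank, mirroring torch.tensor([self.rank + 1.0]) in the test.
    contributions = [rank + 1.0 for rank in range(world_size)]
    total = sum(contributions)   # what ccl::reduction::sum would produce across ranks
    return total / world_size    # the post-hoc output.div_(getSize()) step

# Matches the expected value asserted in the new test case:
# ndev * (ndev + 1.0) / (2.0 * ndev) == (ndev + 1.0) / 2
for ndev in (2, 4, 8):
    assert emulated_avg_allreduce(ndev) == ndev * (ndev + 1.0) / (2.0 * ndev)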