Skip to content

Commit

Permalink
Workaround for AVG reduction: oneCCL has no native AVG op, so emulate it with SUM followed by division by world size
Browse files Browse the repository at this point in the history
  • Loading branch information
Chao1Han committed Nov 14, 2024
1 parent 0aedc00 commit 65e0d9d
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 0 deletions.
10 changes: 10 additions & 0 deletions test/distributed/test_c10d_ops_xccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,16 @@ def allreduce(tensors, op):
tensors[0],
)

# Avg
tensors = [torch.tensor([self.rank + 1.0]).xpu(local_device_id)]

allreduce(tensors, c10d.ReduceOp.AVG)
ndev = self.world_size
self.assertEqual(
torch.tensor([ndev * (ndev + 1.0) / (2.0 * ndev)]),
tensors[0],
)

# Product
tensors = [torch.tensor([self.rank + 1]).xpu(local_device_id)]

Expand Down
44 changes: 44 additions & 0 deletions torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) {
// Map sum to max for bool tensors to avoid overflow issues with sum.
return ccl::reduction::max;
}
// Workaround: oneCCL does not support AVG, so map it to SUM here and divide by world size after the collective.
if (reduceOp == ReduceOp::AVG) {
return ccl::reduction::sum;
}
return xcclOps.at(reduceOp);
} catch (const std::out_of_range&) {
C10_THROW_ERROR(
Expand Down Expand Up @@ -894,6 +898,11 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_impl(
xcclReduceOp,
comm,
ccl::create_stream(stream.queue()));
// Workaround: oneCCL does not support AVG — the collective ran as SUM, so divide by world size to get the average.
if (opts.reduceOp == ReduceOp::AVG) {
auto divisor = getSize();
output.div_(divisor);
}
return;
},
OpType::ALLREDUCE,
Expand Down Expand Up @@ -942,6 +951,11 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
xcclReduceOp,
comm,
ccl::create_stream(stream.queue()));
// Workaround: oneCCL does not support AVG — the collective ran as SUM, so divide by world size to get the average.
if (opts.reduceOp == ReduceOp::AVG) {
auto divisor = getSize();
output.div_(divisor);
}
return;
},
OpType::ALLREDUCE,
Expand Down Expand Up @@ -988,6 +1002,11 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_coalesced(
xcclReduceOp,
comm,
ccl::create_stream(stream.queue()));
// Workaround: oneCCL does not support AVG — the collective ran as SUM, so divide by world size to get the average.
if (opts.reduceOp == ReduceOp::AVG) {
auto divisor = getSize();
output.div_(divisor);
}
return;
},
OpType::COALESCED,
Expand Down Expand Up @@ -1117,6 +1136,11 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce(
root,
comm,
ccl::create_stream(stream.queue()));
// Workaround: oneCCL does not support AVG — the reduce ran as SUM; only the root rank holds the result, so only it divides by world size.
if (opts.reduceOp == ReduceOp::AVG && getRank() == root) {
auto divisor = getSize();
output.div_(divisor);
}
return;
},
OpType::REDUCE,
Expand Down Expand Up @@ -1150,6 +1174,11 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_oop(
root,
comm,
ccl::create_stream(stream.queue()));
// Workaround: oneCCL does not support AVG — the reduce ran as SUM; only the root rank holds the result, so only it divides by world size.
if (opts.reduceOp == ReduceOp::AVG && getRank() == root) {
auto divisor = getSize();
output.div_(divisor);
}
return;
},
OpType::REDUCE,
Expand Down Expand Up @@ -1370,6 +1399,11 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter(
xcclReduceOp,
comm,
ccl::create_stream(stream.queue()));
// Workaround: oneCCL does not support AVG — the collective ran as SUM, so divide by world size to get the average.
if (opts.reduceOp == ReduceOp::AVG) {
auto divisor = getSize();
output.div_(divisor);
}
return;
},
[&](at::xpu::XPUStream& Stream,
Expand Down Expand Up @@ -1453,6 +1487,11 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_scatter_base(
xcclReduceOp,
comm,
ccl::create_stream(stream.queue()));
// Workaround: oneCCL does not support AVG — the collective ran as SUM, so divide by world size to get the average.
if (opts.reduceOp == ReduceOp::AVG) {
auto divisor = getSize();
output.div_(divisor);
}
return;
},
OpType::_REDUCE_SCATTER_BASE,
Expand Down Expand Up @@ -1482,6 +1521,11 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter_tensor_coalesced(
xcclReduceOp,
comm,
ccl::create_stream(stream.queue()));
// Workaround: oneCCL does not support AVG — the collective ran as SUM, so divide by world size to get the average.
if (opts.reduceOp == ReduceOp::AVG) {
auto divisor = getSize();
output.div_(divisor);
}
return;
},
OpType::COALESCED,
Expand Down

0 comments on commit 65e0d9d

Please sign in to comment.