Attempt to use rFactor for allreduce
wujingyue committed Feb 14, 2025
1 parent 63482f5 commit 237980f
Showing 4 changed files with 43 additions and 3 deletions.
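
What the commit is after: for a sum whose reduced axis is sharded across devices, rFactor splits the reduction in two, so each device first reduces its local shard and only the remaining reduction over the device dimension needs communication, i.e. an allreduce. A minimal single-process sketch of that decomposition (not part of the diff; the device count d = 4 is an arbitrary value for illustration, and the split mirrors the schedule in the new test):

import torch

d, n = 4, 3            # hypothetical device count and non-reduced extent
k = d * 5              # reduced extent, divisible by d as in the test below
x = torch.randn(k, n)

# Outer-split the reduced axis by d, like sched.split(tv, 0, d, False) in the test.
x_split = x.reshape(d, k // d, n)

# rFactor of the inner axis: each "device" reduces its own chunk first ...
partials = x_split.sum(1)   # shape (d, n), one partial sum per device

# ... and the leftover reduction over the device axis is what lowers to an allreduce.
result = partials.sum(0)

torch.testing.assert_close(result, x.sum(0))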
3 changes: 2 additions & 1 deletion csrc/scheduler/reduction.cpp
@@ -1549,7 +1549,8 @@ void scheduleReduction(Fusion* fusion, const ReductionParams* rparams) {
   }
 
   NVF_ERROR(
-      !(rparams->schedule_3D && isSharded(reduction_tv)),
+      !(rparams->schedule_3D &&
+          getShardedLoopAxis(reduction_tv, ParallelType::DIDx) >= 0),
       "Multidevice nvFuser does not support 3D reduction schedules");
 
   auto dim_analysis = scheduler_utils::canonicalDimReduction(
2 changes: 2 additions & 0 deletions csrc/tensor_view.cpp
@@ -779,12 +779,14 @@ TensorView* TensorView::rFactor(const std::vector<int64_t>& axes) {
       "Error rfactoring ",
       this,
       " its definition is either a nullptr or not a reduction.");
+#if 0
   // For hopper matmuls, the mma_result logical domain is reordered as [M, N, K]
   // using commitLeafToLogical. Thus, the original logical domain is moved to
   // the root domain.
   NVF_CHECK(
       definition()->isA<MmaOp>() || !domain()->hasRoot(),
       "Cannot call rfactor on the same view twice.");
+#endif
   NVF_CHECK(
       !definition()->isA<GroupedReductionOp>(),
       "For GroupedReductionOp, use TensorView::rFactor(const std::vector<int64_t>& axes, const std::vector<TensorView*>& tvs)");
4 changes: 2 additions & 2 deletions csrc/transform_rfactor.cpp
@@ -121,12 +121,12 @@ class ReplayRFactor : public ReplayTransformations {
     // rfactored domains. If it isn't involved in the rfactor, it's no
     // longer a reduction domain
     std::optional<IterType> outer_iter_type;
-    if (s->outer()->isReduction() && !rfactor_dep_ids_.count(s->outer())) {
+    if (!rfactor_dep_ids_.count(s->outer())) {
       outer_iter_type = IterType::Iteration;
     }
 
     std::optional<IterType> inner_iter_type;
-    if (s->inner()->isReduction() && !rfactor_dep_ids_.count(s->inner())) {
+    if (!rfactor_dep_ids_.count(s->inner())) {
       inner_iter_type = IterType::Iteration;
     }
 
37 changes: 37 additions & 0 deletions tests/python/test_communication.py
@@ -70,6 +70,43 @@ def multidevice_schedule(self):
     torch.testing.assert_close(output.local.cpu(), unsharded.sum(0))
 
 
+@pytest.mark.mpi
+def test_allreduce_rfactor(multidevice_test):
+    d = multidevice_test.size
+    mesh = nvfuser.DeviceMesh(range(d))
+    n = 3
+    k = d * 5
+
+    class Model(FusionDefinition):
+        def definition(self):
+            self.inp = self.define_tensor((k, n), contiguity=True, dtype=DataType.Float)
+            self.out = self.ops.sum(self.inp, [0])
+            self.add_output(self.out)
+
+        def multidevice_schedule(self):
+            self.sched.split(self.inp, 0, d, False)
+            self.sched.split(self.out, 0, d, False)
+            out_local = self.sched.rfactor(self.out, [1])
+
+            self.sched._set_device_mesh(self.inp, mesh)
+            self.sched._set_device_mesh(self.out, mesh)
+            self.sched._set_device_mesh(out_local, mesh)
+
+            self.sched.parallelize(self.inp, 0, nvfuser.ParallelType.mesh_x)
+            self.sched.parallelize(out_local, 0, nvfuser.ParallelType.mesh_x)
+
+            self.sched.set_allocation_as_loop(self.inp)
+            self.sched.set_allocation_as_loop(out_local)
+            self.sched.set_allocation_as_loop(self.out)
+
+    unsharded = torch.randn(k, n)
+    sharded = multidevice_test.shard_tensor(unsharded, 0, mesh)
+
+    fd = Model()
+    outputs = fd.execute([sharded])
+    torch.testing.assert_close(outputs[0].local.cpu(), unsharded.sum(0))
+
+
 @pytest.mark.mpi
 def test_reduce_scatter(multidevice_test):
     d = multidevice_test.size
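
For reference, the communication pattern the schedule above is meant to produce is an ordinary local-reduce-then-allreduce. The following standalone torch.distributed sketch is not part of the commit; it replays that pattern per rank and assumes the gloo backend and a torchrun launch, both arbitrary choices for illustration:

import torch
import torch.distributed as dist


def main():
    # torchrun supplies RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT for init.
    dist.init_process_group(backend="gloo")
    rank, world = dist.get_rank(), dist.get_world_size()

    n, rows_per_rank = 3, 5
    torch.manual_seed(0)
    full = torch.randn(world * rows_per_rank, n)  # same "unsharded" input on every rank
    shard = full[rank * rows_per_rank : (rank + 1) * rows_per_rank]  # shard along dim 0

    partial = shard.sum(0)                          # the rFactor'd local reduction (out_local)
    dist.all_reduce(partial, op=dist.ReduceOp.SUM)  # the remaining device-dim reduction

    torch.testing.assert_close(partial, full.sum(0))
    dist.destroy_process_group()


if __name__ == "__main__":
    main()

Launched, for example, with: torchrun --nproc_per_node=2 this_script.py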