Skip to content

Commit

Permalink
Attempt to use rFactor for allreduce
Browse files Browse the repository at this point in the history
  • Loading branch information
wujingyue committed Feb 10, 2025
1 parent 5611117 commit 678b0dd
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 1 deletion.
3 changes: 2 additions & 1 deletion csrc/scheduler/reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1549,7 +1549,8 @@ void scheduleReduction(Fusion* fusion, const ReductionParams* rparams) {
}

NVF_ERROR(
!(rparams->schedule_3D && isSharded(reduction_tv)),
!(rparams->schedule_3D &&
getShardedLoopAxis(reduction_tv, ParallelType::DIDx) >= 0),
"Multidevice nvFuser does not support 3D reduction schedules");

auto dim_analysis = scheduler_utils::canonicalDimReduction(
Expand Down
2 changes: 2 additions & 0 deletions csrc/tensor_view.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -779,12 +779,14 @@ TensorView* TensorView::rFactor(const std::vector<int64_t>& axes) {
"Error rfactoring ",
this,
" its definition is either a nullptr or not a reduction.");
#if 0
// For hopper matmuls, the mma_result logical domain is reordered as [M, N, K]
// using commitLeafToLogical. Thus, the original logical domain is moved to
// the root domain.
NVF_CHECK(
definition()->isA<MmaOp>() || !domain()->hasRoot(),
"Cannot call rfactor on the same view twice.");
#endif
NVF_CHECK(
!definition()->isA<GroupedReductionOp>(),
"For GroupedReductionOp, use TensorView::rFactor(const std::vector<int64_t>& axes, const std::vector<TensorView*>& tvs)");
Expand Down
39 changes: 39 additions & 0 deletions tests/python/test_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,45 @@ def multidevice_schedule(self):
torch.testing.assert_close(output.local.cpu(), unsharded.sum(0))


@pytest.mark.mpi
def test_allreduce_rfactor(multidevice_test):
d = multidevice_test.size
mesh = nvfuser.DeviceMesh(range(d))
n = 3
k = d * 5

class Model(FusionDefinition):
def definition(self):
self.inp = self.define_tensor(
(k, n), contiguity=True, dtype=DataType.Float
)
self.out = self.ops.sum(self.inp, [0])
self.add_output(self.out)

def multidevice_schedule(self):
self.sched.split(self.inp, 0, d, False)
self.sched.split(self.out, 0, d, False)
out_local = self.sched.rfactor(self.out, [1])

self.sched._set_device_mesh(self.inp, mesh)
self.sched._set_device_mesh(self.out, mesh)
self.sched._set_device_mesh(out_local, mesh)

self.sched.parallelize(self.inp, 0, nvfuser.ParallelType.mesh_x)
self.sched.parallelize(out_local, 0, nvfuser.ParallelType.mesh_x)

self.sched.set_allocation_as_loop(self.inp)
self.sched.set_allocation_as_loop(out_local)
self.sched.set_allocation_as_loop(self.out)

unsharded = torch.randn(k, n)
sharded = multidevice_test.shard_tensor(unsharded, 0, mesh)

fd = Model()
outputs = fd.execute([sharded])
torch.testing.assert_close(outputs[0].local.cpu(), unsharded.sum(0))


@pytest.mark.mpi
def test_reduce_scatter(multidevice_test):
d = multidevice_test.size
Expand Down

0 comments on commit 678b0dd

Please sign in to comment.