Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix rfactor replay for DID loop split #3543

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion csrc/scheduler/reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1549,7 +1549,8 @@ void scheduleReduction(Fusion* fusion, const ReductionParams* rparams) {
}

NVF_ERROR(
!(rparams->schedule_3D && isSharded(reduction_tv)),
!(rparams->schedule_3D &&
getShardedLoopAxis(reduction_tv, ParallelType::DIDx) >= 0),
"Multidevice nvFuser does not support 3D reduction schedules");

auto dim_analysis = scheduler_utils::canonicalDimReduction(
Expand Down
6 changes: 0 additions & 6 deletions csrc/tensor_view.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -779,12 +779,6 @@ TensorView* TensorView::rFactor(const std::vector<int64_t>& axes) {
"Error rfactoring ",
this,
" its definition is either a nullptr or not a reduction.");
// For hopper matmuls, the mma_result logical domain is reordered as [M, N, K]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed this check because we now expect rFactor to be called by both inter- and intra-GPU schedulers.

// using commitLeafToLogical. Thus, the original logical domain is moved to
// the root domain.
NVF_CHECK(
definition()->isA<MmaOp>() || !domain()->hasRoot(),
"Cannot call rfactor on the same view twice.");
NVF_CHECK(
!definition()->isA<GroupedReductionOp>(),
"For GroupedReductionOp, use TensorView::rFactor(const std::vector<int64_t>& axes, const std::vector<TensorView*>& tvs)");
Expand Down
4 changes: 2 additions & 2 deletions csrc/transform_rfactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,12 @@ class ReplayRFactor : public ReplayTransformations {
// rfactored domains. If it isn't involved in the rfactor, it's no
// longer a redunction domain
std::optional<IterType> outer_iter_type;
if (s->outer()->isReduction() && !rfactor_dep_ids_.count(s->outer())) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without this, I ran into an error with the following local reduction:

in: root/logical=[i{n}], loop=[iDIDx{d}, i{n/d}]
out = reduction(in): root=[r{n}], logical/loop=[iDIDx{d}, r{n/d}]

The reduction scheduler tries to schedule out on TIDx

out: root=[r{n}], logical=[iDIDx{d}, r{n/d}], loop=[iDIDx{d}, r{n/d/blockDim.x}, rTIDx{blockDim.x}]

and then rFactor axis 1, i.e., r{n/d/blockDim.x}.

rFactor tries to replay all transforms using ReplayRFactor on a new, identical root domain [r{n}]. Without this change, the outer-split by d produced rDIDx{d} instead of iDIDx{d}.

if (!rfactor_dep_ids_.count(s->outer())) {
outer_iter_type = IterType::Iteration;
}

std::optional<IterType> inner_iter_type;
if (s->inner()->isReduction() && !rfactor_dep_ids_.count(s->inner())) {
if (!rfactor_dep_ids_.count(s->inner())) {
inner_iter_type = IterType::Iteration;
}

Expand Down
8 changes: 4 additions & 4 deletions tests/cpp/test_tutorial.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,11 +345,11 @@ TEST_F(Tutorial, ReductionRFactor) {

// The fusion math should now look like:
//
// tv0: root = logical = [i0]
// tv2 = reduction(tv0): root = [i0], logical = [r1/1024, i1024]
// tv1 = reduction(tv2): root = logical = [r1024]
// tv0: root = logical = [i{i0}]
// tv2 = reduction(tv0): root = [r{i0}], logical = [r{i0/1024}, i{1024}]
// tv1 = reduction(tv2): root = logical = [r{1024}]
if (verbose_) {
fusion_copy.printMath();
fusion_copy.print();
}
// Notice that the reduction operation is now split into two
// operations, where the first one takes care of the first domain, and the
Expand Down
22 changes: 18 additions & 4 deletions tests/python/test_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,22 +52,36 @@ def test_allreduce(multidevice_test):

class Model(FusionDefinition):
def definition(self):
self.inp = self.define_tensor((d, 4), contiguity=True, dtype=DataType.Float)
self.inp = self.define_tensor(
(-1, -1), contiguity=True, dtype=DataType.Float
)
self.out = self.ops.sum(self.inp, [0])
self.add_output(self.out)

def multidevice_schedule(self):
self.sched.split(self.inp, 0, d, False)
self.sched.split(self.out, 0, d, False)
out_local = self.sched.rfactor(self.out, [1])

self.sched._set_device_mesh(self.inp, mesh)
self.sched._set_device_mesh(self.out, mesh)
self.sched._set_device_mesh(out_local, mesh)

self.sched.parallelize(self.inp, 0, nvfuser.ParallelType.mesh_x)
self.sched.parallelize(out_local, 0, nvfuser.ParallelType.mesh_x)

unsharded = torch.randn(d, 4)
self.sched.set_allocation_as_loop(self.inp)
self.sched.set_allocation_as_loop(out_local)
self.sched.set_allocation_as_loop(self.out)

m = d * 2
n = 3
unsharded = torch.randn(m, n)
sharded = multidevice_test.shard_tensor(unsharded, 0, mesh)

fd = Model()
(output,) = fd.execute([sharded])
torch.testing.assert_close(output.local.cpu(), unsharded.sum(0))
outputs = fd.execute([sharded])
torch.testing.assert_close(outputs[0].local.cpu(), unsharded.sum(0))


@pytest.mark.mpi
Expand Down