From 808a295aca1a1657fd3951cf876f33a2c9425318 Mon Sep 17 00:00:00 2001
From: Jingyue Wu
Date: Fri, 6 Dec 2024 19:40:37 -0800
Subject: [PATCH 1/7] Attempt to use rFactor for allreduce

---
 csrc/scheduler/reduction.cpp       |  3 ++-
 csrc/tensor_view.cpp               |  2 ++
 csrc/transform_rfactor.cpp         |  4 ++--
 tests/python/test_communication.py | 37 ++++++++++++++++++++++++++++++
 4 files changed, 43 insertions(+), 3 deletions(-)

diff --git a/csrc/scheduler/reduction.cpp b/csrc/scheduler/reduction.cpp
index bcbb859e3b2..e09b6253fbc 100644
--- a/csrc/scheduler/reduction.cpp
+++ b/csrc/scheduler/reduction.cpp
@@ -1549,7 +1549,8 @@ void scheduleReduction(Fusion* fusion, const ReductionParams* rparams) {
   }
 
   NVF_ERROR(
-      !(rparams->schedule_3D && isSharded(reduction_tv)),
+      !(rparams->schedule_3D &&
+          getShardedLoopAxis(reduction_tv, ParallelType::DIDx) >= 0),
       "Multidevice nvFuser does not support 3D reduction schedules");
 
   auto dim_analysis = scheduler_utils::canonicalDimReduction(
diff --git a/csrc/tensor_view.cpp b/csrc/tensor_view.cpp
index 0cb7afc4734..3038a372312 100644
--- a/csrc/tensor_view.cpp
+++ b/csrc/tensor_view.cpp
@@ -779,12 +779,14 @@ TensorView* TensorView::rFactor(const std::vector<int64_t>& axes) {
       "Error rfactoring ",
       this,
       " its definition is either a nullptr or not a reduction.");
+#if 0
   // For hopper matmuls, the mma_result logical domain is reordered as [M, N, K]
   // using commitLeafToLogical. Thus, the original logical domain is moved to
   // the root domain.
   NVF_CHECK(
       definition()->isA<MmaOp>() || !domain()->hasRoot(),
       "Cannot call rfactor on the same view twice.");
+#endif
   NVF_CHECK(
       !definition()->isA<GroupedReductionOp>(),
       "For GroupedReductionOp, use TensorView::rFactor(const std::vector<int64_t>& axes, const std::vector<TensorView*>& tvs)");
diff --git a/csrc/transform_rfactor.cpp b/csrc/transform_rfactor.cpp
index fdbbef8ae5f..3f0149b7ad4 100644
--- a/csrc/transform_rfactor.cpp
+++ b/csrc/transform_rfactor.cpp
@@ -121,12 +121,12 @@ class ReplayRFactor : public ReplayTransformations {
     // rfactored domains. If it isn't involved in the rfactor, it's no
     // longer a redunction domain
     std::optional<IterType> outer_iter_type;
-    if (s->outer()->isReduction() && !rfactor_dep_ids_.count(s->outer())) {
+    if (!rfactor_dep_ids_.count(s->outer())) {
       outer_iter_type = IterType::Iteration;
     }
 
     std::optional<IterType> inner_iter_type;
-    if (s->inner()->isReduction() && !rfactor_dep_ids_.count(s->inner())) {
+    if (!rfactor_dep_ids_.count(s->inner())) {
       inner_iter_type = IterType::Iteration;
     }
 
diff --git a/tests/python/test_communication.py b/tests/python/test_communication.py
index 55369f0f366..ee8e83ad807 100644
--- a/tests/python/test_communication.py
+++ b/tests/python/test_communication.py
@@ -70,6 +70,43 @@ def multidevice_schedule(self):
     torch.testing.assert_close(output.local.cpu(), unsharded.sum(0))
 
 
+@pytest.mark.mpi
+def test_allreduce_rfactor(multidevice_test):
+    d = multidevice_test.size
+    mesh = nvfuser.DeviceMesh(range(d))
+    n = 3
+    k = d * 5
+
+    class Model(FusionDefinition):
+        def definition(self):
+            self.inp = self.define_tensor((k, n), contiguity=True, dtype=DataType.Float)
+            self.out = self.ops.sum(self.inp, [0])
+            self.add_output(self.out)
+
+        def multidevice_schedule(self):
+            self.sched.split(self.inp, 0, d, False)
+            self.sched.split(self.out, 0, d, False)
+            out_local = self.sched.rfactor(self.out, [1])
+
+            self.sched._set_device_mesh(self.inp, mesh)
+            self.sched._set_device_mesh(self.out, mesh)
+            self.sched._set_device_mesh(out_local, mesh)
+
+            self.sched.parallelize(self.inp, 0, nvfuser.ParallelType.mesh_x)
+            self.sched.parallelize(out_local, 0, nvfuser.ParallelType.mesh_x)
+
+            self.sched.set_allocation_as_loop(self.inp)
+            self.sched.set_allocation_as_loop(out_local)
+            self.sched.set_allocation_as_loop(self.out)
+
+    unsharded = torch.randn(k, n)
+    sharded = multidevice_test.shard_tensor(unsharded, 0, mesh)
+
+    fd = Model()
+    outputs = fd.execute([sharded])
+    torch.testing.assert_close(outputs[0].local.cpu(), unsharded.sum(0))
+
+
 @pytest.mark.mpi
 def test_reduce_scatter(multidevice_test):
     d = multidevice_test.size

From 6f88e5fe81160f6fe8a2bef8261090274ae6f891 Mon Sep 17 00:00:00 2001
From: Jingyue Wu
Date: Fri, 14 Feb 2025 13:51:29 -0800
Subject: [PATCH 2/7] Remove the double rfactor check

---
 csrc/tensor_view.cpp               | 8 --------
 tests/python/test_communication.py | 6 +++---
 2 files changed, 3 insertions(+), 11 deletions(-)

diff --git a/csrc/tensor_view.cpp b/csrc/tensor_view.cpp
index 3038a372312..ecfb322b022 100644
--- a/csrc/tensor_view.cpp
+++ b/csrc/tensor_view.cpp
@@ -779,14 +779,6 @@ TensorView* TensorView::rFactor(const std::vector<int64_t>& axes) {
       "Error rfactoring ",
       this,
       " its definition is either a nullptr or not a reduction.");
-#if 0
-  // For hopper matmuls, the mma_result logical domain is reordered as [M, N, K]
-  // using commitLeafToLogical. Thus, the original logical domain is moved to
-  // the root domain.
-  NVF_CHECK(
-      definition()->isA<MmaOp>() || !domain()->hasRoot(),
-      "Cannot call rfactor on the same view twice.");
-#endif
   NVF_CHECK(
       !definition()->isA<GroupedReductionOp>(),
       "For GroupedReductionOp, use TensorView::rFactor(const std::vector<int64_t>& axes, const std::vector<TensorView*>& tvs)");
diff --git a/tests/python/test_communication.py b/tests/python/test_communication.py
index ee8e83ad807..be969d030c2 100644
--- a/tests/python/test_communication.py
+++ b/tests/python/test_communication.py
@@ -74,12 +74,12 @@ def test_allreduce_rfactor(multidevice_test):
     d = multidevice_test.size
     mesh = nvfuser.DeviceMesh(range(d))
+    m = d * 2
     n = 3
-    k = d * 5
 
     class Model(FusionDefinition):
         def definition(self):
-            self.inp = self.define_tensor((k, n), contiguity=True, dtype=DataType.Float)
+            self.inp = self.define_tensor((m, n), contiguity=True, dtype=DataType.Float)
             self.out = self.ops.sum(self.inp, [0])
             self.add_output(self.out)
 
@@ -99,7 +99,7 @@ def multidevice_schedule(self):
             self.sched.set_allocation_as_loop(out_local)
             self.sched.set_allocation_as_loop(self.out)
 
-    unsharded = torch.randn(k, n)
+    unsharded = torch.randn(m, n)
     sharded = multidevice_test.shard_tensor(unsharded, 0, mesh)
 
     fd = Model()

From 28b5aee769f60bc052321290194c7ec059cd8e49 Mon Sep 17 00:00:00 2001
From: Jingyue Wu
Date: Fri, 14 Feb 2025 13:53:16 -0800
Subject: [PATCH 3/7] Harden the test with dynamic sizes

---
 tests/python/test_communication.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/python/test_communication.py b/tests/python/test_communication.py
index be969d030c2..8cd99ff9bc0 100644
--- a/tests/python/test_communication.py
+++ b/tests/python/test_communication.py
@@ -74,12 +74,10 @@ def test_allreduce_rfactor(multidevice_test):
     d = multidevice_test.size
     mesh = nvfuser.DeviceMesh(range(d))
-    m = d * 2
-    n = 3
 
     class Model(FusionDefinition):
         def definition(self):
-            self.inp = self.define_tensor((m, n), contiguity=True, dtype=DataType.Float)
+            self.inp = self.define_tensor((-1, -1), contiguity=True, dtype=DataType.Float)
             self.out = self.ops.sum(self.inp, [0])
             self.add_output(self.out)
 
@@ -99,6 +97,8 @@ def multidevice_schedule(self):
             self.sched.set_allocation_as_loop(out_local)
             self.sched.set_allocation_as_loop(self.out)
 
+    m = d * 2
+    n = 3
     unsharded = torch.randn(m, n)
     sharded = multidevice_test.shard_tensor(unsharded, 0, mesh)
 

From f3cd3b28bda75de825b46504a8767f9bb66a10ec Mon Sep 17 00:00:00 2001
From: Jingyue Wu
Date: Fri, 14 Feb 2025 14:01:00 -0800
Subject: [PATCH 4/7] Fix tutorial comments

---
 tests/cpp/test_tutorial.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/cpp/test_tutorial.cpp b/tests/cpp/test_tutorial.cpp
index f2d0744825b..a45e3e2038b 100644
--- a/tests/cpp/test_tutorial.cpp
+++ b/tests/cpp/test_tutorial.cpp
@@ -345,11 +345,11 @@ TEST_F(Tutorial, ReductionRFactor) {
 
   // The fusion math should now look like:
   //
-  // tv0: root = logical = [i0]
-  // tv2 = reduction(tv0): root = [i0], logical = [r1/1024, i1024]
-  // tv1 = reduction(tv2): root = logical = [r1024]
+  // tv0: root = logical = [i{i0}]
+  // tv2 = reduction(tv0): root = [r{i0}], logical = [r{i0/1024}, i{1024}]
+  // tv1 = reduction(tv2): root = logical = [r{1024}]
   if (verbose_) {
-    fusion_copy.printMath();
+    fusion_copy.print();
   }
   // Notice that the reduction operation is now split into two
   // operations, where the first one takes care of the first domain, and the

From cb6f4421aa68c60a4a6bdd4f48034bf7ca802bd3 Mon Sep 17 00:00:00 2001
From: Jingyue Wu
Date: Fri, 14 Feb 2025 14:02:50 -0800
Subject: [PATCH 5/7] Remove the other test that's less realistic

---
 tests/python/test_communication.py | 25 -------------------------
 1 file changed, 25 deletions(-)

diff --git a/tests/python/test_communication.py b/tests/python/test_communication.py
index 8cd99ff9bc0..f6dd6d5928d 100644
--- a/tests/python/test_communication.py
+++ b/tests/python/test_communication.py
@@ -50,31 +50,6 @@ def test_allreduce(multidevice_test):
     d = multidevice_test.size
     mesh = nvfuser.DeviceMesh(range(d))
 
-    class Model(FusionDefinition):
-        def definition(self):
-            self.inp = self.define_tensor((d, 4), contiguity=True, dtype=DataType.Float)
-            self.out = self.ops.sum(self.inp, [0])
-            self.add_output(self.out)
-
-        def multidevice_schedule(self):
-            self.sched._set_device_mesh(self.inp, mesh)
-            self.sched._set_device_mesh(self.out, mesh)
-
-            self.sched.parallelize(self.inp, 0, nvfuser.ParallelType.mesh_x)
-
-    unsharded = torch.randn(d, 4)
-    sharded = multidevice_test.shard_tensor(unsharded, 0, mesh)
-
-    fd = Model()
-    (output,) = fd.execute([sharded])
-    torch.testing.assert_close(output.local.cpu(), unsharded.sum(0))
-
-
-@pytest.mark.mpi
-def test_allreduce_rfactor(multidevice_test):
-    d = multidevice_test.size
-    mesh = nvfuser.DeviceMesh(range(d))
-
     class Model(FusionDefinition):
         def definition(self):
             self.inp = self.define_tensor((-1, -1), contiguity=True, dtype=DataType.Float)

From f45bf689c5a72412dc0e2e2349cf5591464dc3e2 Mon Sep 17 00:00:00 2001
From: Jingyue Wu
Date: Fri, 14 Feb 2025 14:23:07 -0800
Subject: [PATCH 6/7] lint

---
 tests/python/test_communication.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/python/test_communication.py b/tests/python/test_communication.py
index f6dd6d5928d..955c556b7a8 100644
--- a/tests/python/test_communication.py
+++ b/tests/python/test_communication.py
@@ -52,7 +52,9 @@ def test_allreduce(multidevice_test):
 
     class Model(FusionDefinition):
         def definition(self):
-            self.inp = self.define_tensor((-1, -1), contiguity=True, dtype=DataType.Float)
+            self.inp = self.define_tensor(
+                (-1, -1), contiguity=True, dtype=DataType.Float
+            )
             self.out = self.ops.sum(self.inp, [0])
             self.add_output(self.out)
 

From a912e69bc9f871bc427c67757a13a74bea9d048c Mon Sep 17 00:00:00 2001
From: Jingyue Wu
Date: Wed, 19 Feb 2025 19:57:34 -0800
Subject: [PATCH 7/7] Harden the test by DID-splitting a non-outermost dimension.

This exercises sortAndRFactor for DID.
---
 tests/python/test_communication.py | 25 +++++++++++++------------
 1 file changed, 13 insertions(+), 12 deletions(-)

diff --git a/tests/python/test_communication.py b/tests/python/test_communication.py
index 955c556b7a8..26eb703ad58 100644
--- a/tests/python/test_communication.py
+++ b/tests/python/test_communication.py
@@ -53,35 +53,36 @@ def test_allreduce(multidevice_test):
     class Model(FusionDefinition):
         def definition(self):
             self.inp = self.define_tensor(
-                (-1, -1), contiguity=True, dtype=DataType.Float
+                (-1, -1, -1), contiguity=True, dtype=DataType.Float
             )
-            self.out = self.ops.sum(self.inp, [0])
+            self.out = self.ops.sum(self.inp, [1])
             self.add_output(self.out)
 
         def multidevice_schedule(self):
-            self.sched.split(self.inp, 0, d, False)
-            self.sched.split(self.out, 0, d, False)
-            out_local = self.sched.rfactor(self.out, [1])
+            self.sched.split(self.inp, 1, d, False)
+            self.sched.split(self.out, 1, d, False)
+            out_local = self.sched.rfactor(self.out, [2])
 
             self.sched._set_device_mesh(self.inp, mesh)
             self.sched._set_device_mesh(self.out, mesh)
             self.sched._set_device_mesh(out_local, mesh)
 
-            self.sched.parallelize(self.inp, 0, nvfuser.ParallelType.mesh_x)
-            self.sched.parallelize(out_local, 0, nvfuser.ParallelType.mesh_x)
+            self.sched.parallelize(self.inp, 1, nvfuser.ParallelType.mesh_x)
+            self.sched.parallelize(out_local, 1, nvfuser.ParallelType.mesh_x)
 
             self.sched.set_allocation_as_loop(self.inp)
             self.sched.set_allocation_as_loop(out_local)
             self.sched.set_allocation_as_loop(self.out)
 
-    m = d * 2
-    n = 3
-    unsharded = torch.randn(m, n)
-    sharded = multidevice_test.shard_tensor(unsharded, 0, mesh)
+    m = 2
+    k = d * 3
+    n = 5
+    unsharded = torch.randn(m, k, n)
+    sharded = multidevice_test.shard_tensor(unsharded, 1, mesh)
 
     fd = Model()
     outputs = fd.execute([sharded])
-    torch.testing.assert_close(outputs[0].local.cpu(), unsharded.sum(0))
+    torch.testing.assert_close(outputs[0].local.cpu(), unsharded.sum(1))
 
 
 @pytest.mark.mpi
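
The schedule exercised by the final test hinges on one identity: summing over an axis that has been DID-split is the same as letting each device reduce its own slice (the rfactored out_local) and then reducing the per-device partials, which is the part that lowers to an allreduce. The snippet below is an illustration only, not part of the patches above: it mimics that decomposition in plain PyTorch on a single process, with the device count d and the tensor sizes chosen arbitrarily to mirror the final test.

import torch

# Stand-in for the number of devices; any d that divides k works here.
d = 4
m, k, n = 2, d * 3, 5
unsharded = torch.randn(m, k, n)

# "DID split" of the reduced, non-outermost axis: [m, k, n] -> [m, d, k/d, n],
# so each "device" owns one contiguous slice of size k/d along the middle axis.
shards = unsharded.reshape(m, d, k // d, n)

# What the rfactored out_local computes: a local partial sum per device.
partials = shards.sum(dim=2)  # shape [m, d, n]

# What the remaining reduction over the device axis (the allreduce) computes.
allreduced = partials.sum(dim=1)  # shape [m, n]

torch.testing.assert_close(allreduced, unsharded.sum(dim=1))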