
Fix rfactor replay for DID loop split #3543

Open · wants to merge 6 commits into main
Conversation

wujingyue (Collaborator) commented on Dec 7, 2024

For #2563

naoyam (Collaborator) commented on Dec 10, 2024

What does the actual fusion that is passed to the reduction scheduler look like?

@wujingyue wujingyue force-pushed the wjy/rs branch 2 times, most recently from 6d03163 to 66a3363 Compare December 10, 2024 15:56
Base automatically changed from wjy/rs to main December 11, 2024 02:26
github-actions bot commented on Feb 10, 2025

Review updated until commit f45bf68

Description

  • Update rfactor replay logic for DID loop split

  • Enhance test coverage with dynamic sizes

  • Correct tutorial comments and remove redundant checks

  • Improve allreduce test with dynamic tensor sizes


Changes walkthrough 📝

Relevant files:

Bug fix
  csrc/scheduler/reduction.cpp — Update 3D reduction error check
    • Update error check for 3D reduction schedules with DIDx sharding
    +2/-1
  csrc/tensor_view.cpp — Remove redundant rfactor check
    • Remove redundant check for rfactor on the same view
    +0/-6

Enhancement
  csrc/transform_rfactor.cpp — Simplify IterType assignment
    • Simplify IterType assignment in ReplayRFactor
    +2/-2
  tests/python/test_communication.py — Enhance allreduce test
    • Enhance allreduce test with dynamic tensor sizes and rfactor
    • Add device mesh and parallelization for rfactor output
    +18/-4

Documentation
  tests/cpp/test_tutorial.cpp — Update tutorial comments
    • Update comments and print method in reduction rfactor test
    +4/-4

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Possible Issue

The new error check may be stricter than intended. The original check, isSharded(reduction_tv), tests whether the tensor view is sharded at all. The new check, getShardedLoopAxis(reduction_tv, ParallelType::DIDx) >= 0, looks for a loop axis parallelized specifically on DIDx, so it may not cover every case in which the tensor view is sharded.

    !(rparams->schedule_3D &&
      getShardedLoopAxis(reduction_tv, ParallelType::DIDx) >= 0),
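To make the concern concrete, here is a toy Python model of the two checks. These classes and functions are hypothetical stand-ins, not the nvFuser API; they only mimic the distinction between "sharded on any device dimension" and "has a loop axis on one specific device dimension":

```python
from dataclasses import dataclass

@dataclass
class IterDomain:
    # Hypothetical stand-in for nvFuser's IterDomain; only the
    # parallel type matters for this illustration.
    parallel_type: str  # e.g. "DIDx", "DIDy", "Serial"

def is_sharded(loop_domain):
    # Analogue of the original check: any device-parallel axis counts.
    return any(d.parallel_type.startswith("DID") for d in loop_domain)

def sharded_loop_axis(loop_domain, ptype):
    # Analogue of the new check: position of an axis with one specific
    # parallel type, or -1 if there is none.
    for i, d in enumerate(loop_domain):
        if d.parallel_type == ptype:
            return i
    return -1

# A tensor sharded on DIDy is sharded, yet has no DIDx loop axis,
# so the two checks disagree.
loop = [IterDomain("DIDy"), IterDomain("Serial")]
print(is_sharded(loop))                 # True
print(sharded_loop_axis(loop, "DIDx"))  # -1
```

Whether such a DIDy-only case can actually reach this code path in the scheduler is exactly what a reviewer would want to confirm.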
Removed Check

The check for domain()->hasRoot() was removed. It ensured that rfactor is not called twice on the same view, which may matter for the correctness of tensor view transformations.

        definition() != nullptr &&
            (definition()
                 ->isStrictlyOneOf<ReductionOp, MmaOp, MatmulOp, LinearOp>()),
        "Error rfactoring ",
        this,
        " its definition is either a nullptr or not a reduction.");
    NVF_CHECK(
        !definition()->isA<GroupedReductionOp>(),
        "For GroupedReductionOp, use TensorView::rFactor(const std::vector<int64_t>& axes, const std::vector<TensorView*>& tvs)");
Hardcoded Values

The test test_allreduce uses hardcoded values for the tensor dimensions (m = d * 2, n = 3). This might limit the test's ability to catch issues with different input sizes. Consider using more dynamic or parameterized test cases.

    def test_allreduce(multidevice_test):
        d = multidevice_test.size
        mesh = nvfuser.DeviceMesh(range(d))

        class Model(FusionDefinition):
            def definition(self):
                self.inp = self.define_tensor(
                    (-1, -1), contiguity=True, dtype=DataType.Float
                )
                self.out = self.ops.sum(self.inp, [0])
                self.add_output(self.out)

            def multidevice_schedule(self):
                self.sched.split(self.inp, 0, d, False)
                self.sched.split(self.out, 0, d, False)
                out_local = self.sched.rfactor(self.out, [1])

                self.sched._set_device_mesh(self.inp, mesh)
                self.sched._set_device_mesh(self.out, mesh)
                self.sched._set_device_mesh(out_local, mesh)

                self.sched.parallelize(self.inp, 0, nvfuser.ParallelType.mesh_x)
                self.sched.parallelize(out_local, 0, nvfuser.ParallelType.mesh_x)

                self.sched.set_allocation_as_loop(self.inp)
                self.sched.set_allocation_as_loop(out_local)
                self.sched.set_allocation_as_loop(self.out)

        m = d * 2
        n = 3
        unsharded = torch.randn(m, n)
        sharded = multidevice_test.shard_tensor(unsharded, 0, mesh)

        fd = Model()
        outputs = fd.execute([sharded])
        torch.testing.assert_close(outputs[0].local.cpu(), unsharded.sum(0))
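One way to address this point, sketched in plain Python with hypothetical names (the real test would feed each derived shape into the fusion above rather than just validating it):

```python
# A hedged sketch of parameterizing the sizes rather than hardcoding
# m = d * 2, n = 3. The only hard constraint is that the sharded axis
# divides evenly across the d devices.
def make_shape(d, rows_per_device, n):
    m = d * rows_per_device
    return m, n

# Several (rows_per_device, n) cases, including a single-row shard
# and a single-column tensor.
cases = [(1, 3), (2, 5), (8, 1)]
for rows_per_device, n in cases:
    m, n = make_shape(d=4, rows_per_device=rows_per_device, n=n)
    assert m % 4 == 0  # shardable on the first axis
    print(m, n)
```

In the actual suite this would more naturally be a pytest parameterization over the tensor sizes.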

wujingyue force-pushed the wjy/rfactor branch 2 times, most recently from 678b0dd to 237980f on February 14, 2025 00:19

wujingyue (Collaborator, Author) commented:

!test

wujingyue changed the title from "Attempt to use rFactor for allreduce" to "Fix rfactor replay for DID loop split" on Feb 14, 2025
wujingyue marked this pull request as ready for review on February 14, 2025 00:33
@@ -779,12 +779,6 @@ TensorView* TensorView::rFactor(const std::vector<int64_t>& axes) {
    "Error rfactoring ",
    this,
    " its definition is either a nullptr or not a reduction.");
// For hopper matmuls, the mma_result logical domain is reordered as [M, N, K]

wujingyue (Collaborator, Author) commented:

I removed this check because we now expect rFactor to be called by both the inter- and intra-GPU schedulers.

@@ -121,12 +121,12 @@ class ReplayRFactor : public ReplayTransformations {
// rfactored domains. If it isn't involved in the rfactor, it's no
// longer a reduction domain
std::optional<IterType> outer_iter_type;
if (s->outer()->isReduction() && !rfactor_dep_ids_.count(s->outer())) {

wujingyue (Collaborator, Author) commented:

Without this, I ran into an error with the following local reduction:

    in: root/logical=[i{n}], loop=[iDIDx{d}, i{n/d}]
    out = reduction(in): root=[r{n}], logical/loop=[iDIDx{d}, r{n/d}]

The reduction scheduler tries to schedule out on TIDx:

    out: root=[r{n}], logical=[iDIDx{d}, r{n/d}], loop=[iDIDx{d}, r{n/d/blockDim.x}, rTIDx{blockDim.x}]

and then rFactors axis 1, i.e., r{n/d/blockDim.x}.

rFactor tries to replay all transforms using ReplayRFactor on a new, identical root domain [r{n}]. Without this change, the outer split by d produced rDIDx{d} instead of iDIDx{d}.
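The rule being fixed can be sketched as a toy Python model (these are not the nvFuser classes, just an illustration of the decision): when a reduction root domain is split during the rFactor replay, a split output keeps its reduction type only if it feeds one of the rfactored axes; otherwise it is no longer reduced by this tensor and must become an iteration domain.

```python
def split_output_type(input_is_reduction, output_feeds_rfactor):
    # Toy model of the IterType choice in ReplayRFactor for one
    # output of a split. "output_feeds_rfactor" models membership
    # in the rfactor dependency set (rfactor_dep_ids_).
    if input_is_reduction and not output_feeds_rfactor:
        return "Iteration"  # the fix: drop the reduction-ness
    return "Reduction" if input_is_reduction else "Iteration"

# Replaying root=[r{n}] outer-split by d, where only the inner axis
# r{n/d} is a dependency of the rfactored axes: the outer axis must
# come back as iDIDx{d}, not rDIDx{d}.
outer = split_output_type(input_is_reduction=True, output_feeds_rfactor=False)
inner = split_output_type(input_is_reduction=True, output_feeds_rfactor=True)
print(outer, inner)  # Iteration Reduction
```

This matches the example above: the DID loop split sits outside the rfactored axis, so its output axis is exempt from the reduction.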

wujingyue (Collaborator, Author) commented:

!test

wujingyue (Collaborator, Author) commented:

> What does the actual fusion that is passed to the reduction scheduler look like?

I finally debugged this through. PTAL!

Labels: None yet
Projects: None yet
2 participants