Fix rfactor replay for DID loop split #3543
base: main
Conversation
What does the actual fusion that is passed to the reduction scheduler look like?
!test
@@ -779,12 +779,6 @@ TensorView* TensorView::rFactor(const std::vector<int64_t>& axes) {
      "Error rfactoring ",
      this,
      " its definition is either a nullptr or not a reduction.");
  // For hopper matmuls, the mma_result logical domain is reordered as [M, N, K]
I removed this check because we now expect rFactor to be called by both inter- and intra-GPU schedulers.
@@ -121,12 +121,12 @@ class ReplayRFactor : public ReplayTransformations {
  // rfactored domains. If it isn't involved in the rfactor, it's no
  // longer a reduction domain
  std::optional<IterType> outer_iter_type;
  if (s->outer()->isReduction() && !rfactor_dep_ids_.count(s->outer())) {
Without this, I ran into an error with the following local reduction:

in: root/logical=[i{n}], loop=[iDIDx{d}, i{n/d}]
out = reduction(in): root=[r{n}], logical/loop=[iDIDx{d}, r{n/d}]

The reduction scheduler tries to schedule out on TIDx:

out: root=[r{n}], logical=[iDIDx{d}, r{n/d}], loop=[iDIDx{d}, r{n/d/blockDim.x}, rTIDx{blockDim.x}]

and then rFactor axis 1, i.e., r{n/d/blockDim.x}.

rFactor tries to replay all transforms using ReplayRFactor on a new, identical root domain [r{n}]. Without this change, the outer-split by d produced rDIDx{d} instead of iDIDx{d}.
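For concreteness, here is a minimal hand-written sketch of that scenario using nvFuser's C++ scheduling API. It is an illustration only, not this PR's test or the exact path the multidevice scheduler takes: the header paths, the `makeContigTensor` test helper, the concrete values of `d` and `blockDim.x`, and the use of an explicit first `rFactor` to give `out` the logical domain `[iDIDx{d}, r{n/d}]` are all assumptions.

```cpp
// Sketch only: header paths and the makeContigTensor test helper are
// assumed; device-mesh setup needed to actually run this is omitted.
#include <fusion.h>
#include <ops/all_ops.h>

using namespace nvfuser;

void didLoopSplitReductionSketch() {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // in: root/logical = [i{n}]
  TensorView* in = makeContigTensor(1);
  fusion.addInput(in);

  // Full reduction: root/logical = [r{n}]
  TensorView* sum_all = sum(in, {0});
  fusion.addOutput(sum_all);

  // DID loop split on the input: loop = [iDIDx{d}, i{n/d}]
  const int64_t d = 2;  // assumed number of devices
  in->split(0, d, /*inner_split=*/false);
  in->axis(0)->parallelize(ParallelType::DIDx);

  // Inter-GPU step: outer-split the reduction by d and rFactor the inner
  // axis, so `out` is the per-device partial reduction with
  // root = [r{n}] and logical/loop = [iDIDx{d}, r{n/d}], while `sum_all`
  // keeps the cross-device reduction rDIDx{d}.
  sum_all->split(0, d, /*inner_split=*/false);
  TensorView* out = sum_all->rFactor({1});
  out->axis(0)->parallelize(ParallelType::DIDx);
  sum_all->axis(0)->parallelize(ParallelType::DIDx);

  // Intra-GPU step: the reduction scheduler splits the remaining reduction
  // axis by blockDim.x and binds it to TIDx, giving
  // loop = [iDIDx{d}, r{n/d/blockDim.x}, rTIDx{blockDim.x}].
  const int64_t bdimx = 128;  // stand-in for blockDim.x
  out->split(1, bdimx);
  out->axis(2)->parallelize(ParallelType::TIDx);

  // rFactor axis 1. ReplayRFactor replays the splits above on a fresh root
  // [r{n}]; before this fix the replayed outer split came back as rDIDx{d}
  // instead of iDIDx{d}.
  TensorView* out_rf = out->rFactor({1});
  (void)out_rf;
}
```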
!test

I finally debugged this through. PTAL!
For #2563