Skip to content

Commit

Permalink
Add LDSWr, LDSRd sinking to scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
Ognjen committed Feb 21, 2024
1 parent f06329d commit 1687710
Showing 1 changed file with 22 additions and 2 deletions.
24 changes: 22 additions & 2 deletions lib/Dialect/TritonGPU/Transforms/AMDReorderInstructions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ class TritonAMDGPUReorderInstructionsPass
return std::distance(value.user_begin(), value.user_end());
}

void scheduleSlicedDot(ModuleOp m, int stages) {
void scheduleSlicedDot(ModuleOp m, int stages, bool sinkLDSRd, bool sinkLDSWr) {
SmallVector<SmallVector<Operation *>> dotChains;

m.walk([&](tt::DotOp dotOp) {
Expand Down Expand Up @@ -270,6 +270,24 @@ class TritonAMDGPUReorderInstructionsPass
operations, i == 0, 1);
}
}

if (!sinkLDSRd) {
return;
}

for (auto chain : dotChains) {
for (int i = 0; i < chain.size(); i++) {
Operation *dotOp = chain[i];
Operation *ldsRd = dotOp->getOperand(1).getDefiningOp();
assert(isLDSRead(ldsRd));
moveBefore(ldsRd, dotOp);
if (sinkLDSWr) {
Operation *ldsWr = ldsRd->getOperand(0).getDefiningOp();
assert(isLDSWrite(ldsWr));
moveBefore(ldsWr, ldsRd);
}
}
}
}

void runOnOperation() override {
Expand All @@ -278,7 +296,9 @@ class TritonAMDGPUReorderInstructionsPass

moveQTensorOutOfTheLoop(m);
int stages = 4;
scheduleSlicedDot(m, stages);
bool sinkLDSRd = true;
bool sinkLDSWr = true;
scheduleSlicedDot(m, stages, sinkLDSRd, sinkLDSWr);
}
};

Expand Down

0 comments on commit 1687710

Please sign in to comment.