From 13b20fee3f116842107dd0b2e78c2e4542c99194 Mon Sep 17 00:00:00 2001 From: SJW Date: Thu, 27 Jun 2024 19:17:46 +0000 Subject: [PATCH] * invert order of loads and local_stores --- .../ReorderInstructions.cpp | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp index 70caa21f4020..f46b5a2d6460 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp @@ -147,14 +147,13 @@ class TritonAMDGPUReorderInstructionsPass moveAfter(op, argOp); }); SmallVector moveOps; - // Move local stores early if dependence distance greater than - // one iteration. - m.walk([&](triton::gpu::LocalStoreOp op) { moveOps.push_back(op); }); - // Move global loads early (prefetch). These should be first in - // the block since they have the longest latency. + // Move global loads early to prefetch. m.walk([&](triton::LoadOp op) { moveOps.push_back(op); }); + // Move local_stores early if dependence distance greater than + // one iteration. Best perf on GEMM when these precede global loads. + m.walk([&](triton::gpu::LocalStoreOp op) { moveOps.push_back(op); }); for (auto op : moveOps) { - // 0. gather DFG + // 0. Gather use-def chain in block. Block *block = op->getBlock(); SmallVector dfg{op}; bool leadsToLoad = gatherDFG(op, block, dfg); @@ -163,9 +162,12 @@ class TritonAMDGPUReorderInstructionsPass if (auto ld = dyn_cast(op)) src = ld.getPtr(); auto ip = findEarlyInsertionPoint(block, op, src); - // Remove ops that already precede the insertion point. - llvm::erase_if( - dfg, [&](Operation *op) { return !ip->isBeforeInBlock(op); }); + // Remove ops that already precede the insertion point. This + // is done before moves happen to avoid N^2 complexity in + // `Operation::isBeforeInBlock`. + llvm::erase_if(dfg, + [&](Operation *op) { return !ip->isBeforeInBlock(op); }); + // Move ops to insertion point. for (auto *op : dfg) op->moveAfter(block, ip); }