diff --git a/lib/Dialect/TritonGPU/Transforms/AMDReorderInstructions.cpp b/lib/Dialect/TritonGPU/Transforms/AMDReorderInstructions.cpp
index 8e672cba8861..e802e44c7679 100644
--- a/lib/Dialect/TritonGPU/Transforms/AMDReorderInstructions.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/AMDReorderInstructions.cpp
@@ -1,3 +1,26 @@
+/*
+ * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining
+ * a copy of this software and associated documentation files
+ * (the "Software"), to deal in the Software without restriction,
+ * including without limitation the rights to use, copy, modify, merge,
+ * publish, distribute, sublicense, and/or sell copies of the Software,
+ * and to permit persons to whom the Software is furnished to do so,
+ * subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be
+ * included in all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+ * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ */
+
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/Verifier.h"
@@ -23,7 +46,11 @@ class TritonAMDGPUReorderInstructionsPass
     mlir::DominanceInfo dom(m);
 
     for (auto operand : operands) {
-      operandsSorted.push_back(operand);
+      // Sort only operands for which a defining op can be fetched. This
+      // excludes, for example, block arguments.
+      if (operand.getDefiningOp()) {
+        operandsSorted.push_back(operand);
+      }
     }
 
     if (operandsSorted.size() == 1) {
@@ -34,10 +61,8 @@ class TritonAMDGPUReorderInstructionsPass
               [&](const Value &a, const Value &b) {
                 Operation *operandA = a.getDefiningOp();
                 Operation *operandB = b.getDefiningOp();
-                if (operandA && operandB) {
-                  return dom.dominates(operandA, operandB);
-                }
-                return false;
+                assert(operandA && operandB);
+                return dom.dominates(operandA, operandB);
               });
   }
@@ -97,18 +122,22 @@ class TritonAMDGPUReorderInstructionsPass
     SmallVector<Value> operandsSorted;
     sortOperandsByDominance(operands, operandsSorted);
 
-    if (!operandsSorted.empty() &&
-        operandsSorted[operandsSorted.size() - 1].getDefiningOp()) {
-
-      moveAfter(op, operandsSorted[operandsSorted.size() - 1].getDefiningOp());
-      if (failed(mlir::verify(m))) {
-        assert(false);
+    if (!operandsSorted.empty()) {
+      auto dominantOperandOp =
+          operandsSorted[operandsSorted.size() - 1].getDefiningOp();
+      if (dominantOperandOp) {
+        moveAfter(op, dominantOperandOp);
+        assert(succeeded(mlir::verify(m)));
       }
     }
     movedOperations.push_back(op);
   }
 
+  // Moves the Q tensor of the Flash Attention algorithm out of the "main"
+  // flash attention loop. Since the Q tensor is loop invariant, this ensures
+  // that loading Q, transforming it, and the related layout conversions
+  // happen only once.
   void moveQTensorOutOfTheLoop(ModuleOp m) {
     m.walk([&](tt::DotOp dotOp) {
       if (isFAChainDot(dotOp)) {
diff --git a/python/perf-kernels/06-fused-attention-transV.py b/python/perf-kernels/06-fused-attention-transV.py
index cf27d874f687..a3bd99074de5 100644
--- a/python/perf-kernels/06-fused-attention-transV.py
+++ b/python/perf-kernels/06-fused-attention-transV.py
@@ -564,7 +564,6 @@ def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False):
 
         ## restore the grid for bwd kernel
         best_config = _attn_fwd.get_best_config()
-        # print(best_config)
         block_m = int(best_config.__str__().split(",")[0].split("BLOCK_M:")[1])
         grid = (triton.cdiv(q.shape[2], block_m), q.shape[0] * q.shape[1], 1)
 
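
For reference, a minimal standalone sketch of the BLOCK_M parsing kept in the Python hunk above. The sample string below is made up; it only assumes that str(best_config) begins with a "BLOCK_M: <value>" field, which is what the kernel code relies on.

# Minimal sketch of the BLOCK_M extraction kept above. The string below is a
# made-up stand-in for str(best_config); the real autotuner output is assumed
# to start with a "BLOCK_M: <value>" field.
sample_best_config = "BLOCK_M: 128, BLOCK_N: 64, num_warps: 4, num_stages: 1"

# Take the first comma-separated field ("BLOCK_M: 128"), keep the text after
# the "BLOCK_M:" prefix, and convert it to an integer.
block_m = int(sample_best_config.split(",")[0].split("BLOCK_M:")[1])
assert block_m == 128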