
Commit

Address review comments and add some minor fixes
Ognjen Plavsic committed Mar 23, 2024
1 parent d0d5e0d commit c6d04b2
Showing 2 changed files with 40 additions and 12 deletions.
51 changes: 40 additions & 11 deletions lib/Dialect/TritonGPU/Transforms/AMDReorderInstructions.cpp
@@ -1,3 +1,26 @@
+/*
+ * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining
+ * a copy of this software and associated documentation files
+ * (the "Software"), to deal in the Software without restriction,
+ * including without limitation the rights to use, copy, modify, merge,
+ * publish, distribute, sublicense, and/or sell copies of the Software,
+ * and to permit persons to whom the Software is furnished to do so,
+ * subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be
+ * included in all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+ * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ */
+
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/Verifier.h"
@@ -23,7 +46,11 @@ class TritonAMDGPUReorderInstructionsPass
     mlir::DominanceInfo dom(m);
 
     for (auto operand : operands) {
-      operandsSorted.push_back(operand);
+      // Sort only operands whose defining op can be fetched; this excludes,
+      // for example, block arguments.
+      if (operand.getDefiningOp()) {
+        operandsSorted.push_back(operand);
+      }
     }
 
     if (operandsSorted.size() == 1) {
@@ -34,10 +61,8 @@ class TritonAMDGPUReorderInstructionsPass
                 [&](const Value &a, const Value &b) {
                   Operation *operandA = a.getDefiningOp();
                   Operation *operandB = b.getDefiningOp();
-                  if (operandA && operandB) {
-                    return dom.dominates(operandA, operandB);
-                  }
-                  return false;
+                  assert(operandA && operandB);
+                  return dom.dominates(operandA, operandB);
                 });
   }

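Within a single basic block, "a dominates b" reduces to "a appears earlier in program order", so the comparator above orders operands by the position of their defining ops and leaves the latest-defined operand last. A minimal Python sketch of that ordering idea (a toy model with invented op names and positions, not the MLIR pass itself):

```python
from dataclasses import dataclass
from functools import cmp_to_key

@dataclass
class Op:
    name: str
    position: int  # index within the block, i.e. program order

def dominates(a: Op, b: Op) -> bool:
    # Single-block stand-in for mlir::DominanceInfo::dominates.
    return a.position < b.position

def sort_operands_by_dominance(operands: list[Op]) -> list[Op]:
    # Mirror of the pass's comparator: "a sorts before b" iff a dominates b.
    return sorted(operands,
                  key=cmp_to_key(lambda a, b: -1 if dominates(a, b) else 1))

ops = [Op("load_k", 2), Op("load_q", 0), Op("convert_layout", 5)]
print([o.name for o in sort_operands_by_dominance(ops)])
# ['load_q', 'load_k', 'convert_layout'] -- last element is defined latest
```

As the next hunk shows, the pass then moves the consuming op right after that last operand's defining op, the earliest point at which all of its operands are available.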
@@ -97,18 +122,22 @@ class TritonAMDGPUReorderInstructionsPass
     SmallVector<Value> operandsSorted;
     sortOperandsByDominance(operands, operandsSorted);
 
-    if (!operandsSorted.empty() &&
-        operandsSorted[operandsSorted.size() - 1].getDefiningOp()) {
-
-      moveAfter(op, operandsSorted[operandsSorted.size() - 1].getDefiningOp());
-      if (failed(mlir::verify(m))) {
-        assert(false);
+    if (!operandsSorted.empty()) {
+      auto dominantOperandOp =
+          operandsSorted[operandsSorted.size() - 1].getDefiningOp();
+      if (dominantOperandOp) {
+        moveAfter(op, dominantOperandOp);
+        assert(succeeded(mlir::verify(m)));
       }
     }
 
     movedOperations.push_back(op);
   }
 
+  // Moves the Q tensor in the Flash Attention algorithm out of the "main"
+  // flash attention loop. Since the Q tensor is loop-invariant, this ensures
+  // that the Q tensor load, its transformations, and the related layout
+  // conversions happen only once.
   void moveQTensorOutOfTheLoop(ModuleOp m) {
     m.walk([&](tt::DotOp dotOp) {
       if (isFAChainDot(dotOp)) {
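The comment added above documents a loop-invariant hoist: Q does not change across iterations of the attention loop, so its load and layout conversions can happen once, before the loop. A toy Python illustration of the payoff (invented names, not the Triton pass):

```python
def load_and_convert_q():
    # Stand-in for the Q load plus layout conversion that the pass hoists.
    return [0.125 * x for x in range(8)]

def attention_unhoisted(num_iters):
    acc = 0.0
    for _ in range(num_iters):
        q = load_and_convert_q()  # redone every iteration
        acc += sum(q)
    return acc

def attention_hoisted(num_iters):
    q = load_and_convert_q()      # loop-invariant: done once
    acc = 0.0
    for _ in range(num_iters):
        acc += sum(q)
    return acc

# Same result, but the hoisted version does the Q work once instead of 4 times.
assert attention_unhoisted(4) == attention_hoisted(4)
```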
1 change: 0 additions & 1 deletion python/perf-kernels/06-fused-attention-transV.py
@@ -564,7 +564,6 @@ def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False):
 
         ## restore the grid for bwd kernel
         best_config = _attn_fwd.get_best_config()
-        # print(best_config)
         block_m = int(best_config.__str__().split(",")[0].split("BLOCK_M:")[1])
         grid = (triton.cdiv(q.shape[2], block_m), q.shape[0] * q.shape[1], 1)

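The surviving line above recovers BLOCK_M by splitting the best config's string form. A slightly more defensive way to pull the value out, sketched under the assumption that the config repr looks like "BLOCK_M: 128, BLOCK_N: 64, ..." (inferred from the split chain, not verified against Triton's API):

```python
import re

def parse_block_m(config_str: str) -> int:
    # Assumed format: "BLOCK_M: 128, BLOCK_N: 64, num_warps: 4, ..."
    match = re.search(r"BLOCK_M:\s*(\d+)", config_str)
    if match is None:
        raise ValueError(f"BLOCK_M not found in {config_str!r}")
    return int(match.group(1))

assert parse_block_m("BLOCK_M: 128, BLOCK_N: 64, num_warps: 4") == 128
```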
