[AMD][CanonicalizePtr] Add a series of fixes for the new pipeliner (#4743)

This PR fixes several issues with the `CanonicalizePointer` pass and
the new pipeliner:
- Don't traverse the same nodes twice
- Don't assume the operations to delete are in the correct order; instead,
  drop all references to the ops before deleting them
- Add support for the select operation (+test), which is used when dealing
  with multiple buffers (this part has been coauthored with @sjw36); see
  the sketch below
giuseros authored Sep 19, 2024
1 parent ad0cdfb commit 3ae95a8
Showing 2 changed files with 128 additions and 16 deletions.
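
For context on the select fix: the pass tracks each tensor of pointers as a
"fat pointer", i.e. a scalar base plus a tensor offset, and the new rewrite
turns a select over a tensor of pointers into a select over each component.
A minimal before/after sketch in the style of the tests below (the SSA names
and the i64 offset width are illustrative, not taken from the diff):

// Before: one select over the whole tensor of pointers
%ptr = arith.select %cond, %ptrsA, %ptrsB : tensor<1024x!tt.ptr<f32>, #blocked>

// After: select the scalar base and the tensor offset separately, then
// re-materialize the tensor of pointers with splat + addptr
%base  = arith.select %cond, %baseA, %baseB : !tt.ptr<f32>
%off   = arith.select %cond, %offA, %offB : tensor<1024xi64, #blocked>
%splat = tt.splat %base : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
%ptr2  = tt.addptr %splat, %off : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi64, #blocked>

When the offset is known to fit in 32 bits, the pass additionally truncates
it before the tt.addptr, which is what the arith.trunci in the @select test
below checks.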
64 changes: 62 additions & 2 deletions test/TritonGPU/amd/amd-canonicalize-pointers.mlir
@@ -271,7 +271,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war

#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
-// CHECK-LABEL tt.func @whileOp
+// CHECK-LABEL: tt.func @whileOp
tt.func @whileOp(%arg0: !tt.ptr<f32>, %init : tensor<1024xf32, #blocked>, %cond : i1)-> tensor<1024xf32, #blocked>{
%c1024_i32 = arith.constant 1024 : i32
%c0 = arith.constant 0: index
@@ -307,7 +307,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war

#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
-// CHECK-LABEL tt.func @condBranch
+// CHECK-LABEL: tt.func @condBranch
tt.func @condBranch(%arg0 : !tt.ptr<f32>, %i1 : i1) -> tensor<1024xf32, #blocked>{
%c1024_i32 = arith.constant 1024 : i32
%c0 = arith.constant 0: index
@@ -445,6 +445,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// The scalar offset we want is (%1*%arg1), while the variable offset should be (%2*%arg1 + %8)
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: tt.func public @matmul_kernel
tt.func public @matmul_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}) {
%c128_i32 = arith.constant 128 : i32
%0 = tt.get_program_id x : i32
@@ -486,3 +487,62 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
tt.return
}
}


// -----
#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: tt.func @select
tt.func @select(%arg0 : !tt.ptr<f32>, %i1 : i1) -> tensor<1024xf32, #blocked>{
%c1024_i32 = arith.constant 1024 : i32
%c0 = arith.constant 0: index
%c128 = arith.constant 128: index
%c1 = arith.constant 1 : index
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c1024_i32 : i32
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
// CHECK: %[[scalarOffset:.*]] = arith.addi {{.*}}, {{.*}} : i32
// CHECK: %[[variableOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked>
// CHECK: %[[baseOffset:.*]] = tt.splat %{{.*}} : i64
// CHECK: %[[scalarPtr:.*]] = tt.addptr %arg0, %[[scalarOffset]]
// CHECK: %[[extVariableOffset:.*]] = arith.extsi %[[variableOffset]]
%3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked>
%4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
%5 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
%6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
// CHECK: %[[offset2:.*]] = arith.addi %[[extVariableOffset]], %[[baseOffset]]
// CHECK: %[[scalarPtr1:.*]] = arith.select %arg1, %arg0, %[[scalarPtr]]
// CHECK: %[[offset0:.*]] = arith.select %arg1, {{.*}}, %[[offset2]]
// CHECK: %[[offset1:.*]] = arith.trunci %[[offset0]]
// CHECK: %[[ptr:.*]] = tt.splat %[[scalarPtr1]]
// CHECK: tt.addptr %[[ptr]], %[[offset1]]
%7 = arith.select %i1, %5 , %6 : tensor<1024x!tt.ptr<f32>, #blocked>
%out = tt.load %7: tensor<1024x!tt.ptr<f32>, #blocked>
tt.return %out : tensor<1024xf32, #blocked>
}
}

// -----
#blocked = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx1100", "triton_gpu.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: tt.func @where_kernel
tt.func @where_kernel(%arg1: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i64> {tt.divisibility = 16 : i32}){
%c0_i8 = arith.constant 0 : i8
%c1024_i32 = arith.constant 1024 : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c1024_i32 : i32
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
%3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked>
%4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
%9 = arith.cmpi ne, %c0_i8, %c0_i8 : i8
%10 = arith.select %9, %arg1, %arg2 : !tt.ptr<i64>
// CHECK: %[[selectPtr:.*]] = arith.select {{.*}} : !tt.ptr<i64>
%11 = tt.splat %10: !tt.ptr<i64> -> tensor<1024x!tt.ptr<i64>, #blocked>
%13 = tt.addptr %11, %4 : tensor<1024x!tt.ptr<i64>, #blocked>, tensor<1024xi32, #blocked>
// CHECK: %[[selectPtr0:.*]] = tt.addptr %[[selectPtr]]
// CHECK: %[[tensorPtr:.*]] = tt.splat %[[selectPtr0]]
// CHECK: tt.addptr %[[tensorPtr]]
%14 = tt.load %13 : tensor<1024x!tt.ptr<i64>, #blocked>
tt.return
}
}
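
Note: as with the existing cases in this file, the new @select and
@where_kernel tests are FileCheck tests. The file's RUN line is outside this
diff; given the `// -----` separators it would look roughly like the line
below, where the exact pass flag is an assumption rather than something shown
in the commit:

// Hypothetical RUN line (flag name assumed, not shown in this diff):
// RUN: triton-opt %s -split-input-file -tritonamdgpu-canonicalize-pointers | FileCheck %s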
80 changes: 66 additions & 14 deletions third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp
@@ -18,6 +18,7 @@
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
#include <utility>

#include "TritonAMDGPUTransforms/Passes.h"
@@ -127,6 +128,8 @@ class PointerCanonicalizer {
Value &nextPtr);
LogicalResult rewriteCondBranchOp(cf::CondBranchOp condBrOp, Location curLoc,
OpOperand *operand, Value &nextPtr);
LogicalResult rewriteSelectOp(arith::SelectOp selectOp, Location curLoc,
OpOperand *operand, Value &nextPtr);
LogicalResult rewriteBranchOp(cf::BranchOp branchOp, Location curLoc,
OpOperand *operand, Value &nextPtr);

@@ -370,7 +373,8 @@ PointerCanonicalizer::decomposeOffsetFromExpr(Location loc, Value expr,
}

// Create a tensor pointer from a fat pointer `fatPtr`. The tensor pointer is
-// obtained by splatting the scalar pointer using the `fatPtr.offset` shape.
+// obtained by splatting the `fatPtr.basePtr` using the `fatPtr.offset` shape
+// and adding the offset to it.
Value PointerCanonicalizer::createTensorPointer(FatPtr fatPtr, Location loc) {
Value basePtr = fatPtr.basePtr;
Value offset = fatPtr.offset;
@@ -380,8 +384,14 @@ Value PointerCanonicalizer::createTensorPointer(FatPtr fatPtr, Location loc) {
// Splat the scalar pointer
auto tensorPtrType = RankedTensorType::get(offsetShape, basePtr.getType(),
offsetType.getEncoding());
if (fatPtr.canNarrow)
offset = narrow64bitOffsetTo32bits(rewriter, loc, offset);

Value tensorPtr =
rewriter.create<triton::SplatOp>(loc, tensorPtrType, basePtr);

tensorPtr =
rewriter.create<triton::AddPtrOp>(loc, tensorPtrType, tensorPtr, offset);
return tensorPtr;
}
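
After this change, createTensorPointer is the single place where a fat
pointer is re-materialized as a tensor of pointers: it narrows the offset
when legal, splats the base, and adds the offset. For a narrowable 64-bit
offset the emitted IR looks roughly like this sketch (names illustrative):

// Narrow the 64-bit offset, splat the scalar base, then add the offset
%off32 = arith.trunci %off64 : tensor<1024xi64, #blocked> to tensor<1024xi32, #blocked>
%splat = tt.splat %base : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
%ptr   = tt.addptr %splat, %off32 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>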

@@ -392,17 +402,14 @@ LogicalResult PointerCanonicalizer::materializeFatPointer(Operation *op,
auto fatPtr = pointers[ptr];
Value basePtr = fatPtr.basePtr;
Value offset = fatPtr.offset;
-  if (fatPtr.canNarrow)
-    offset = narrow64bitOffsetTo32bits(rewriter, loc, offset);

// Create the tensor pointer (i.e., splat the base && add the offset)
Value newPtr = basePtr;
-  if (isa<RankedTensorType>(ptr.getType())) {
-    // Splat the base pointer
-    Value tensorPtr = createTensorPointer(fatPtr, loc);
-    // Add the tensor offset to the base pointer
-    newPtr = rewriter.create<triton::AddPtrOp>(loc, tensorPtr.getType(),
-                                               tensorPtr, offset);
-  }
+  if (isa<RankedTensorType>(ptr.getType()))
+    newPtr = createTensorPointer(fatPtr, loc);
+
+  // Save the fat pointer in the table
+  pointers[newPtr] = fatPtr;

// Map and replace the load
IRMapping mapper;
@@ -737,6 +744,43 @@ LogicalResult PointerCanonicalizer::rewriteCondBranchOp(
return success();
}

LogicalResult PointerCanonicalizer::rewriteSelectOp(arith::SelectOp selectOp,
Location curLoc,
OpOperand *curOperand,
Value &nextPtr) {
Value trueVal = selectOp.getTrueValue();
Value falseVal = selectOp.getFalseValue();
Value cond = selectOp.getCondition();
// If we didn't traverse both operands, simply materialize the pointer
if (!pointers.contains(trueVal) || !pointers.contains(falseVal))
return materializeFatPointer(selectOp, curLoc, curOperand->get());

// If both have been traversed, then we can rewrite select of pointers as a
// select of base and offset
FatPtr fatPtrT = pointers[trueVal];
FatPtr fatPtrF = pointers[falseVal];
nextPtr = selectOp.getResult();

// Simple case of a scalar select: update the base pointer
if (!isa<RankedTensorType>(selectOp.getType())) {
FatPtr fatPtr = pointers[trueVal];
pointers[nextPtr] = fatPtr.copyWithOffset(nextPtr);
nextPtr = selectOp.getResult();
return success();
}

// Rewrite `select` for base and offset
Value newBase = rewriter.create<arith::SelectOp>(
curLoc, cond, fatPtrT.basePtr, fatPtrF.basePtr);
Value newOffset = rewriter.create<arith::SelectOp>(
curLoc, cond, fatPtrT.offset, fatPtrF.offset);
assert(fatPtrT.canNarrow == fatPtrF.canNarrow);

pointers[nextPtr] = fatPtrT.copy(newBase, newOffset);
opToDelete.insert(selectOp);
return success();
}

LogicalResult PointerCanonicalizer::rewriteBranchOp(cf::BranchOp branchOp,
Location curLoc,
OpOperand *curOperand,
@@ -803,6 +847,9 @@ LogicalResult PointerCanonicalizer::rewritePointer(Value argPtr) {
.Case<cf::CondBranchOp>([&](auto condBrOp) {
res = rewriteCondBranchOp(condBrOp, curLoc, curOperand, nextPtr);
})
.Case<arith::SelectOp>([&](auto selectOp) {
res = rewriteSelectOp(selectOp, curLoc, curOperand, nextPtr);
})
.Case<cf::BranchOp>([&](auto branchOp) {
res = rewriteBranchOp(branchOp, curLoc, curOperand, nextPtr);
})
@@ -820,7 +867,8 @@ LogicalResult PointerCanonicalizer::rewritePointer(Value argPtr) {
// Keep propagating the fat pointer down the IR
if (nextPtr)
for (OpOperand &use : nextPtr.getUses())
-        queue.push_back(&use);
+        if (!opToDelete.contains(use.getOwner()))
+          queue.push_back(&use);
}
return success();
}
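
The added guard keeps uses whose owner is already scheduled for deletion out
of the work queue. Reading the @select test above as an example (this
walk-through is an inference, not something stated in the PR): %5 feeds both
the tt.addptr and the arith.select, so the select is reachable along two
def-use paths, and once it has been rewritten and put into opToDelete its
uses must not be pushed again:

%5 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
%6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
// %7 is reachable from %5 directly and again through %6; without the guard
// it would be visited, and its uses queued, a second time
%7 = arith.select %i1, %5, %6 : tensor<1024x!tt.ptr<f32>, #blocked>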
@@ -842,9 +890,13 @@ LogicalResult PointerCanonicalizer::rewriteFunction(triton::FuncOp funcOp) {
if (failed(rewritePointer(arg)))
return failure();

-  // Clean-up
-  for (Operation *op : llvm::reverse(opToDelete))
-    op->erase();
+  // Clean-up: don't assume the operations to delete are in the correct
+  // order; instead, drop all references to the ops before we delete them
+  for (Operation *op : opToDelete) {
+    op->dropAllReferences();
+    op->dropAllDefinedValueUses();
+    rewriter.eraseOp(op);
+  }
}
return success();
}
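
A note on why the clean-up loop changed: opToDelete is filled in traversal
order, and a dead op can be recorded before the (equally dead) ops that use
its results, so neither forward nor reverse iteration is guaranteed to erase
users before producers, and MLIR asserts if an op is erased while its results
still have uses. Dropping all references and result uses first makes the
erase order irrelevant. A minimal sketch of the hazard (illustrative ops, not
from the diff):

// Both ops are dead after the rewrite, but %a is still used by %b:
%a = tt.addptr %arg0, %c : !tt.ptr<f32>, i32
%b = tt.splat %a : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
// Erasing %a's defining op first would trip the "still has uses" assertion;
// dropping references and uses beforehand avoids this in any order.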