Skip to content

Commit

Permalink
Address review feedback - 2
Browse files Browse the repository at this point in the history
  • Loading branch information
giuseros committed Sep 19, 2024
1 parent 74c76f7 commit 034c64e
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions test/TritonGPU/amd/amd-canonicalize-pointers.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -521,3 +521,28 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
tt.return %out : tensor<1024xf32, #blocked>
}
}

// -----
#blocked = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx1100", "triton_gpu.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: tt.func @where_kernel
tt.func @where_kernel(%arg1: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i64> {tt.divisibility = 16 : i32}){
%c0_i8 = arith.constant 0 : i8
%c1024_i32 = arith.constant 1024 : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c1024_i32 : i32
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
%3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked>
%4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
%9 = arith.cmpi ne, %c0_i8, %c0_i8 : i8
%10 = arith.select %9, %arg1, %arg2 : !tt.ptr<i64>
// CHECK: %[[selectPtr:.*]] = arith.select {{.*}} : !tt.ptr<i64>
%11 = tt.splat %10: !tt.ptr<i64> -> tensor<1024x!tt.ptr<i64>, #blocked>
%13 = tt.addptr %11, %4 : tensor<1024x!tt.ptr<i64>, #blocked>, tensor<1024xi32, #blocked>
// CHECK: %[[selectPtr0:.*]] = tt.addptr %[[selectPtr]]
// CHECK: %[[tensorPtr:.*]] = tt.splat %[[selectPtr0]]
// CHECK: tt.addptr %[[tensorPtr]]
%14 = tt.load %13 : tensor<1024x!tt.ptr<i64>, #blocked>
tt.return
}
}

0 comments on commit 034c64e

Please sign in to comment.