forked from triton-lang/triton
-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[BACKEND] Dedup elementwise in LLVM IR based on constancy (triton-lan…
…g#2512) ### Summary When Triton GPU IR is lowered into LLVM IR, we can make use of the constancy information about the result of the elementwise ops to deduplicate otherwise redundant computation. That is the contribution of this PR: the constancy is checked and, if possible, some of the values in LLVM IR are reused multiple times instead of computing equal values separately. The change is beneficial for the PyTorch 2 / TorchInductor-generated Triton code, as the leftmost sub-indices extracted from the flat index by div / mod operations can be equal, given sufficiently large 2^n factor in the rightmost rightmost dimension(s). This makes the computation resulting in those sub-indices redundant. Consequently, under the necessary constancy conditions, the redundant indexing arithmetics can be deduplicated. We observe up to 29% decrease in the latency of some of our jagged tensor kernels
- Loading branch information
Showing
2 changed files
with
220 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm="target=nvvm" --llvm-optimize-for-nvvm-target | FileCheck %s | ||
|
||
// CHECK-LABEL: dedup_by_constancy_full | ||
// CHECK-COUNT-5: llvm.add | ||
// CHECK-NOT: llvm.add | ||
// CHECK: llvm.icmp "slt" | ||
// CHECK-NOT: llvm.icmp "slt" | ||
// CHECK: llvm.sdiv | ||
// CHECK-NOT: llvm.sdiv | ||
// CHECK: llvm.getelementptr %arg0[[[REGISTER:%[0-9]+]]] | ||
// CHECK-COUNT-7: llvm.getelementptr %arg0[[[REGISTER]]] | ||
// CHECK-NOT: llvm.getelementptr %arg0[[[REGISTER]]] | ||
#blocked = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> | ||
module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { | ||
tt.func public @dedup_by_constancy_full(%arg0: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { | ||
%cst = arith.constant dense<256> : tensor<1024xi32, #blocked> | ||
%c1024_i32 = arith.constant 1024 : i32 | ||
%0 = tt.get_program_id x : i32 | ||
%1 = arith.muli %0, %c1024_i32 : i32 | ||
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> | ||
%3 = tt.splat %1 : (i32) -> tensor<1024xi32, #blocked> | ||
%4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> | ||
%5 = tt.splat %arg2 : (i32) -> tensor<1024xi32, #blocked> | ||
%6 = arith.cmpi slt, %4, %5 : tensor<1024xi32, #blocked> | ||
%7 = arith.divsi %4, %cst : tensor<1024xi32, #blocked> | ||
%8 = tt.splat %arg0 : (!tt.ptr<f16, 1>) -> tensor<1024x!tt.ptr<f16, 1>, #blocked> | ||
%9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr<f16, 1>, #blocked>, tensor<1024xi32, #blocked> | ||
%10 = tt.load %9, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf16, #blocked> | ||
%11 = tt.splat %arg1 : (!tt.ptr<f16, 1>) -> tensor<1024x!tt.ptr<f16, 1>, #blocked> | ||
%12 = tt.addptr %11, %4 : tensor<1024x!tt.ptr<f16, 1>, #blocked>, tensor<1024xi32, #blocked> | ||
tt.store %12, %10, %6 {cache = 1 : i32, evict = 1 : i32} : tensor<1024xf16, #blocked> | ||
tt.return | ||
} | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: dedup_by_constancy_partial | ||
// CHECK-COUNT-8: llvm.add | ||
// CHECK-NOT: llvm.add | ||
// CHECK: llvm.icmp "slt" | ||
// CHECK-NOT: llvm.icmp "slt" | ||
// CHECK-COUNT-2: llvm.sdiv | ||
// CHECK-NOT: llvm.sdiv | ||
// CHECK: llvm.getelementptr %arg0[[[REGISTER1:%[0-9]+]]] | ||
// CHECK-COUNT-3: llvm.getelementptr %arg0[[[REGISTER1]]] | ||
// CHECK-NOT: llvm.getelementptr %arg0[[[REGISTER1]]] | ||
// CHECK: llvm.getelementptr %arg0[[[REGISTER2:%[0-9]+]]] | ||
// CHECK-COUNT-3: llvm.getelementptr %arg0[[[REGISTER2]]] | ||
// CHECK-NOT: llvm.getelementptr %arg0[[[REGISTER2]]] | ||
#blocked = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> | ||
module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { | ||
tt.func public @dedup_by_constancy_partial(%arg0: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { | ||
%cst = arith.constant dense<4> : tensor<1024xi32, #blocked> | ||
%c1024_i32 = arith.constant 1024 : i32 | ||
%0 = tt.get_program_id x : i32 | ||
%1 = arith.muli %0, %c1024_i32 : i32 | ||
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> | ||
%3 = tt.splat %1 : (i32) -> tensor<1024xi32, #blocked> | ||
%4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> | ||
%5 = tt.splat %arg2 : (i32) -> tensor<1024xi32, #blocked> | ||
%6 = arith.cmpi slt, %4, %5 : tensor<1024xi32, #blocked> | ||
%7 = arith.divsi %4, %cst : tensor<1024xi32, #blocked> | ||
%8 = tt.splat %arg0 : (!tt.ptr<f16, 1>) -> tensor<1024x!tt.ptr<f16, 1>, #blocked> | ||
%9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr<f16, 1>, #blocked>, tensor<1024xi32, #blocked> | ||
%10 = tt.load %9, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf16, #blocked> | ||
%11 = tt.splat %arg1 : (!tt.ptr<f16, 1>) -> tensor<1024x!tt.ptr<f16, 1>, #blocked> | ||
%12 = tt.addptr %11, %4 : tensor<1024x!tt.ptr<f16, 1>, #blocked>, tensor<1024xi32, #blocked> | ||
tt.store %12, %10, %6 {cache = 1 : i32, evict = 1 : i32} : tensor<1024xf16, #blocked> | ||
tt.return | ||
} | ||
} |