From 6d45d6c67d943833b25362ae450164e9dab64c86 Mon Sep 17 00:00:00 2001
From: Adnan Akhundov
Date: Wed, 25 Oct 2023 17:25:29 +0200
Subject: [PATCH] [BACKEND] Dedup elementwise in LLVM IR based on constancy
 (#2512)

### Summary

When Triton GPU IR is lowered into LLVM IR, we can make use of the constancy
information about the results of elementwise ops to deduplicate otherwise
redundant computation. That is the contribution of this PR: the constancy is
checked and, where possible, a single value in the LLVM IR is reused several
times instead of computing equal values separately.

The change is beneficial for PyTorch 2 / TorchInductor-generated Triton code,
as the leftmost sub-indices extracted from the flat index by div / mod
operations can be equal, given a sufficiently large 2^n factor in the
rightmost dimension(s). For example, with a rightmost dimension of size 256,
the quotient `flat_index // 256` is constant across every 256 consecutive
elements. This makes the computation producing those sub-indices redundant.
Consequently, under the necessary constancy conditions, the redundant indexing
arithmetic can be deduplicated. We observe up to a 29% decrease in the latency
of some of our jagged tensor kernels.
---
 .../TritonGPUToLLVM/ElementwiseOpToLLVM.cpp   | 175 +++++++++++++++---
 test/Conversion/dedup-by-constancy.mlir       |  72 +++++++
 2 files changed, 220 insertions(+), 27 deletions(-)
 create mode 100644 test/Conversion/dedup-by-constancy.mlir

diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
index da5696871f3a..3568518645dd 100644
--- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -1186,8 +1186,118 @@ class ElementwiseOpConversionBase
   using OpAdaptor = typename SourceOp::Adaptor;

   explicit ElementwiseOpConversionBase(
-      TritonGPUToLLVMTypeConverter &typeConverter, PatternBenefit benefit = 1)
-      : ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
+      TritonGPUToLLVMTypeConverter &typeConverter,
+      ModuleAxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit = 1)
+      : ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit),
+        axisAnalysisPass(axisAnalysisPass) {}
+
+  // Try to deduplicate the resultVals based on the constancy properties of
+  // the result discovered by the axis analysis pass. If possible, redundant
+  // computation is eliminated.
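+  //
+  // For example, with elemsPerThread = [8] and constancy = [4] along the
+  // only dimension, resultVals[0..3] are all replaced by resultVals[0] and
+  // resultVals[4..7] by resultVals[4]; the duplicate computations become
+  // dead and can be cleaned up afterwards.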
+  SmallVector<Value> maybeDeduplicate(SourceOp op,
+                                      SmallVector<Value> resultVals) const {
+    if (!isMemoryEffectFree(op))
+      // the op has side effects: can't dedup
+      return resultVals;
+    SmallVector<Value> results = op->getResults();
+    if (results.size() == 0 || results.size() > 1)
+      // there must be exactly 1 result
+      return resultVals;
+    Value result = results[0];
+    Type type = result.getType();
+    if (!type)
+      return resultVals;
+    RankedTensorType rtType = type.dyn_cast<RankedTensorType>();
+    if (!rtType)
+      // the result must be a tensor
+      return resultVals;
+    Attribute encoding = rtType.getEncoding();
+    if (!encoding)
+      // encoding not available
+      return resultVals;
+    if (!encoding.dyn_cast<triton::gpu::BlockedEncodingAttr>() &&
+        !encoding.dyn_cast<triton::gpu::SliceEncodingAttr>()) {
+      // TODO: constraining the encoding type here is necessary
+      // to avoid crashes in the triton::gpu::getElemsPerThread
+      // call below happening in test_core::test_fp8_dot_acc
+      return resultVals;
+    }
+
+    SmallVector<unsigned> elemsPerThread =
+        triton::gpu::getElemsPerThread(rtType);
+    int rank = elemsPerThread.size();
+    if (product(elemsPerThread) != resultVals.size())
+      return resultVals;
+    AxisInfo *axisInfo = axisAnalysisPass.getAxisInfo(result);
+    if (!axisInfo)
+      // axis info (e.g., constancy) not available
+      return resultVals;
+    SmallVector<unsigned> sizePerThread =
+        triton::gpu::getSizePerThread(encoding);
+    if (rank != sizePerThread.size())
+      return resultVals;
+
+    SmallVector<int64_t> constancy = axisInfo->getConstancy();
+    if (rank != constancy.size())
+      return resultVals;
+    bool hasConstancy = false;
+    for (int i = 0; i < rank; ++i) {
+      if (constancy[i] > sizePerThread[i]) {
+        if (constancy[i] % sizePerThread[i] != 0)
+          // constancy is not evenly covered by sizePerThread
+          return resultVals;
+        // can't move the values across different
+        // "sizePerThread"-sized blocks
+        constancy[i] = sizePerThread[i];
+      }
+      if (elemsPerThread[i] < 1 || constancy[i] < 1)
+        return resultVals;
+      if (!(elemsPerThread[i] % constancy[i] == 0 ||
+            constancy[i] % elemsPerThread[i] == 0))
+        // either the constancy along each dimension must fit
+        // into the elemsPerThread or the other way around
+        return resultVals;
+      if (constancy[i] > 1)
+        hasConstancy = true;
+    }
+    if (!hasConstancy)
+      // nothing to deduplicate
+      return resultVals;
+
+    if (rank > 1) {
+      // reorder the shape and constancy vectors by the axis order:
+      // from the fastest-changing to the slowest-changing axis
+      SmallVector<unsigned> order = triton::gpu::getOrder(encoding);
+      if (rank != order.size())
+        return resultVals;
+      ArrayRef<unsigned> orderRef(order);
+      elemsPerThread = reorder(ArrayRef<unsigned>(elemsPerThread), orderRef);
+      constancy = reorder(ArrayRef<int64_t>(constancy), orderRef);
+    }
+
+    SmallVector<unsigned> strides(rank, 1);
+    for (int i = 1; i < rank; ++i) {
+      strides[i] = strides[i - 1] * elemsPerThread[i - 1];
+    }
+    SmallVector<Value> dedupResultVals;
+    dedupResultVals.reserve(resultVals.size());
+    for (int i = 0; i < resultVals.size(); ++i) {
+      // each coordinate of the orig_idx is "coarsened" using the
+      // constancy along that dimension: the resulting dedup_idx
+      // points to the reused value in the original resultVals
+      int orig_idx = i;
+      int dedup_idx = 0;
+      for (int j = 0; j < rank; ++j) {
+        int coord_j = orig_idx % elemsPerThread[j];
+        dedup_idx += (coord_j / constancy[j] * constancy[j]) * strides[j];
+        orig_idx /= elemsPerThread[j];
+      }
+      dedupResultVals.push_back(resultVals[dedup_idx]);
+    }
+
+    return dedupResultVals;
+  }

   LogicalResult
   matchAndRewrite(SourceOp op, OpAdaptor adaptor,
@@ -1230,6 +1340,7 @@ class ElementwiseOpConversionBase
       auto argTy = op->getOperand(0).getType();
       resultVals =
          reorderValues(resultVals, argTy, resultTy);
     }
+    resultVals = maybeDeduplicate(op, resultVals);
     resultVals = packI32(resultVals, resultTy, rewriter, loc,
                          this->getTypeConverter());
     resultVals = this->getTypeConverter()->packMfmaOperand(resultVals, resultTy,
                                                            rewriter, loc);
@@ -1241,6 +1352,9 @@ class ElementwiseOpConversionBase
     return success();
   }

+protected:
+  ModuleAxisInfoAnalysis &axisAnalysisPass;
+
 private:
   int computeCapability;
 };

@@ -1272,8 +1386,9 @@ struct FpToFpOpConversion
       triton::FpToFpOp, FpToFpOpConversion>::ElementwiseOpConversionBase;

   explicit FpToFpOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,
+                              ModuleAxisInfoAnalysis &axisAnalysisPass,
                               int computeCapability, PatternBenefit benefit = 1)
-      : ElementwiseOpConversionBase(typeConverter, benefit),
+      : ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit),
         computeCapability(computeCapability) {}

   static Value convertBf16ToFp32(Location loc,
@@ -2097,12 +2212,14 @@ void populateElementwiseOpToLLVMPatterns(
     ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
     int computeCapability, PatternBenefit benefit) {
 #define POPULATE_TERNARY_OP(SRC_OP, DST_OP)                                    \
-  patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
+  patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(                       \
+      typeConverter, axisInfoAnalysis, benefit);
   POPULATE_TERNARY_OP(arith::SelectOp, LLVM::SelectOp)
 #undef POPULATE_TERNARY_OP

 #define POPULATE_BINARY_OP(SRC_OP, DST_OP)                                     \
-  patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
+  patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(                       \
+      typeConverter, axisInfoAnalysis, benefit);
   POPULATE_BINARY_OP(arith::SubIOp, LLVM::SubOp) // -
   POPULATE_BINARY_OP(arith::AddIOp, LLVM::AddOp) // +
   POPULATE_BINARY_OP(arith::MulIOp, LLVM::MulOp) // *
@@ -2126,7 +2243,8 @@ void populateElementwiseOpToLLVMPatterns(
 #undef POPULATE_BINARY_OP

 #define POPULATE_UNARY_OP(SRC_OP, DST_OP)                                      \
-  patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
+  patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(                       \
+      typeConverter, axisInfoAnalysis, benefit);
   POPULATE_UNARY_OP(arith::TruncIOp, LLVM::TruncOp)
   POPULATE_UNARY_OP(arith::ExtSIOp, LLVM::SExtOp)
   POPULATE_UNARY_OP(arith::ExtUIOp, LLVM::ZExtOp)
@@ -2142,29 +2260,32 @@ void populateElementwiseOpToLLVMPatterns(
   POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp)
 #undef POPULATE_UNARY_OP

-  patterns.add(typeConverter, benefit);
-  patterns.add(typeConverter, benefit);
-  patterns.add(typeConverter, benefit);
-  patterns.add(typeConverter, benefit);
-
-  patterns.add(typeConverter, benefit);
-  patterns.add(typeConverter, benefit);
-  patterns.add(typeConverter, benefit);
-  patterns.add(typeConverter, benefit);
-
-  patterns.add(typeConverter, benefit);
-  patterns.add(typeConverter, benefit);
-  patterns.add(typeConverter, benefit);
-  patterns.add(typeConverter, benefit);
-  patterns.add(typeConverter, benefit);
-
-  patterns.add(typeConverter, computeCapability, benefit);
-
-  patterns.add(typeConverter, benefit);
-  patterns.add(typeConverter, benefit);
+  patterns.add(typeConverter, axisInfoAnalysis, benefit);
+  patterns.add(typeConverter, axisInfoAnalysis, benefit);
+  patterns.add(typeConverter, axisInfoAnalysis, benefit);
+  patterns.add(typeConverter, axisInfoAnalysis, benefit);
+
+  patterns.add(typeConverter, axisInfoAnalysis, benefit);
+  patterns.add(typeConverter, axisInfoAnalysis, benefit);
+  patterns.add(typeConverter, axisInfoAnalysis, benefit);
+  patterns.add(typeConverter, axisInfoAnalysis, benefit);
+
+  patterns.add(typeConverter, axisInfoAnalysis, benefit);
+  patterns.add(typeConverter, axisInfoAnalysis, benefit);
+  patterns.add(typeConverter, axisInfoAnalysis, benefit);
+  patterns.add(typeConverter, axisInfoAnalysis,
+               benefit);
+  patterns.add(typeConverter, axisInfoAnalysis, benefit);
+
+  patterns.add(typeConverter, axisInfoAnalysis,
+               computeCapability, benefit);
+
+  patterns.add(typeConverter, axisInfoAnalysis,
+               benefit);
+  patterns.add(typeConverter,
+               axisInfoAnalysis, benefit);
   // ExpOpConversionApprox will try using ex2.approx if the input type is
   // FP32. For other input types, ExpOpConversionApprox will return failure and
   // ElementwiseOpConversion<math::ExpOp, math::ExpOp> defined below will call
   // __nv_expf for higher-precision calculation
-  patterns.add(typeConverter, benefit);
+  patterns.add(typeConverter, axisInfoAnalysis, benefit);
 }
diff --git a/test/Conversion/dedup-by-constancy.mlir b/test/Conversion/dedup-by-constancy.mlir
new file mode 100644
index 000000000000..455a71548fef
--- /dev/null
+++ b/test/Conversion/dedup-by-constancy.mlir
@@ -0,0 +1,72 @@
+// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm="target=nvvm" --llvm-optimize-for-nvvm-target | FileCheck %s
+
+// CHECK-LABEL: dedup_by_constancy_full
+// CHECK-COUNT-5: llvm.add
+// CHECK-NOT: llvm.add
+// CHECK: llvm.icmp "slt"
+// CHECK-NOT: llvm.icmp "slt"
+// CHECK: llvm.sdiv
+// CHECK-NOT: llvm.sdiv
+// CHECK: llvm.getelementptr %arg0[[[REGISTER:%[0-9]+]]]
+// CHECK-COUNT-7: llvm.getelementptr %arg0[[[REGISTER]]]
+// CHECK-NOT: llvm.getelementptr %arg0[[[REGISTER]]]
+#blocked = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
+module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
+  tt.func public @dedup_by_constancy_full(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} {
+    %cst = arith.constant dense<256> : tensor<1024xi32, #blocked>
+    %c1024_i32 = arith.constant 1024 : i32
+    %0 = tt.get_program_id x : i32
+    %1 = arith.muli %0, %c1024_i32 : i32
+    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
+    %3 = tt.splat %1 : (i32) -> tensor<1024xi32, #blocked>
+    %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
+    %5 = tt.splat %arg2 : (i32) -> tensor<1024xi32, #blocked>
+    %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32, #blocked>
+    %7 = arith.divsi %4, %cst : tensor<1024xi32, #blocked>
+    %8 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<1024x!tt.ptr<f16>, #blocked>
+    %9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr<f16>, #blocked>, tensor<1024xi32, #blocked>
+    %10 = tt.load %9, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf16, #blocked>
+    %11 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<1024x!tt.ptr<f16>, #blocked>
+    %12 = tt.addptr %11, %4 : tensor<1024x!tt.ptr<f16>, #blocked>, tensor<1024xi32, #blocked>
+    tt.store %12, %10, %6 {cache = 1 : i32, evict = 1 : i32} : tensor<1024xf16, #blocked>
+    tt.return
+  }
+}
+
+// -----
+
+// CHECK-LABEL: dedup_by_constancy_partial
+// CHECK-COUNT-8: llvm.add
+// CHECK-NOT: llvm.add
+// CHECK: llvm.icmp "slt"
+// CHECK-NOT: llvm.icmp "slt"
+// CHECK-COUNT-2: llvm.sdiv
+// CHECK-NOT: llvm.sdiv
+// CHECK: llvm.getelementptr %arg0[[[REGISTER1:%[0-9]+]]]
+// CHECK-COUNT-3: llvm.getelementptr %arg0[[[REGISTER1]]]
+// CHECK-NOT: llvm.getelementptr %arg0[[[REGISTER1]]]
+// CHECK: llvm.getelementptr %arg0[[[REGISTER2:%[0-9]+]]]
+// CHECK-COUNT-3: llvm.getelementptr %arg0[[[REGISTER2]]]
+// CHECK-NOT: llvm.getelementptr %arg0[[[REGISTER2]]]
+#blocked = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
+module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
+  tt.func public @dedup_by_constancy_partial(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} {
+    %cst = arith.constant dense<4> : tensor<1024xi32, #blocked>
+    %c1024_i32 = arith.constant 1024 : i32
+    %0 = tt.get_program_id x : i32
+    %1 = arith.muli %0, %c1024_i32 : i32
+    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
+    %3 = tt.splat %1 : (i32) -> tensor<1024xi32, #blocked>
+    %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
+    %5 = tt.splat %arg2 : (i32) -> tensor<1024xi32, #blocked>
+    %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32, #blocked>
+    %7 = arith.divsi %4, %cst : tensor<1024xi32, #blocked>
+    %8 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<1024x!tt.ptr<f16>, #blocked>
+    %9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr<f16>, #blocked>, tensor<1024xi32, #blocked>
+    %10 = tt.load %9, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf16, #blocked>
+    %11 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<1024x!tt.ptr<f16>, #blocked>
+    %12 = tt.addptr %11, %4 : tensor<1024x!tt.ptr<f16>, #blocked>, tensor<1024xi32, #blocked>
+    tt.store %12, %10, %6 {cache = 1 : i32, evict = 1 : i32} : tensor<1024xf16, #blocked>
+    tt.return
+  }
+}
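For context, the TorchInductor-style index arithmetic this optimization targets looks roughly like the following Triton kernel. This is an illustrative sketch only (kernel and argument names are made up, not part of the patch); it mirrors the `dedup_by_constancy_full` test above, where the flat offsets are divided by 256 to recover a row index.

```python
# Illustrative sketch: names and the divisor are hypothetical, chosen to
# match the dedup_by_constancy_full test case above.
import triton
import triton.language as tl


@triton.jit
def gather_rows(in_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)  # flat element index
    mask = offs < n
    # The quotient is constant over every 256 consecutive offsets, so with
    # sizePerThread = 8 all eight values held by a thread are equal and the
    # div / addptr chain can be emitted once per thread instead of once per
    # element after dedup-by-constancy.
    row = offs // 256
    val = tl.load(in_ptr + row, mask=mask)
    tl.store(out_ptr + offs, val, mask=mask)
```

With `BLOCK = 1024` and the blocked layout from the test, this produces essentially the TTGIR of the first test case.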