[BACKEND] Dedup elementwise in LLVM IR based on constancy (triton-lang#2512)

### Summary

When Triton GPU IR is lowered into LLVM IR, we can make use of the
constancy information about the results of elementwise ops to
deduplicate otherwise redundant computation. That is the contribution of
this PR: constancy is checked and, where possible, a single value in the
LLVM IR is reused in place of several separately computed, equal values.

The change is beneficial for PyTorch 2 / TorchInductor-generated
Triton code, where the leftmost sub-indices extracted from the flat index
by div / mod operations can be equal across neighboring elements, given a
sufficiently large power-of-two factor in the rightmost dimension(s). This
makes the computation producing those sub-indices redundant, and under the
necessary constancy conditions that indexing arithmetic can be
deduplicated. We observe up to a 29% decrease in the latency of some of
our jagged tensor kernels.
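
To illustrate with a minimal standalone sketch (hypothetical sizes, not taken from any particular kernel): with 8 contiguous elements per thread and a rightmost dimension of 256, the leftmost sub-index `idx / 256` is identical for all 8 elements, so a single division can be reused, while `idx % 256` still varies per element.

```cpp
// Minimal sketch with assumed sizes; not part of this commit.
#include <cstdio>

int main() {
  const int sizePerThread = 8;   // contiguous elements held by one thread
  const int rightmostDim = 256;  // power-of-two factor in the rightmost dim
  const int threadBase = 1024;   // flat index of this thread's first element

  for (int i = 0; i < sizePerThread; ++i) {
    int idx = threadBase + i;
    int row = idx / rightmostDim; // same for all 8 elements -> deduplicable
    int col = idx % rightmostDim; // varies per element -> kept
    std::printf("idx=%d row=%d col=%d\n", idx, row, col);
  }
  return 0;
}
```

Running the sketch prints the same row for all eight indices; that repeated division is exactly the redundancy removed here.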
aakhundov authored and zhanglx13 committed Nov 9, 2023
1 parent ac4ee36 commit 6d45d6c
Showing 2 changed files with 220 additions and 27 deletions.
175 changes: 148 additions & 27 deletions lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -1186,8 +1186,118 @@ class ElementwiseOpConversionBase
using OpAdaptor = typename SourceOp::Adaptor;

explicit ElementwiseOpConversionBase(
TritonGPUToLLVMTypeConverter &typeConverter, PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
TritonGPUToLLVMTypeConverter &typeConverter,
ModuleAxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit),
axisAnalysisPass(axisAnalysisPass) {}

// Try to deduplicate the resultVals based on the
// constancy properties of the result discovered by
// the axis analysis pass. If possible, redundant
// computation is eliminated.
SmallVector<Value> maybeDeduplicate(SourceOp op,
SmallVector<Value> resultVals) const {
if (!isMemoryEffectFree(op))
// the op has side effects: can't dedup
return resultVals;
SmallVector<Value> results = op->getResults();
if (results.size() == 0 || results.size() > 1)
// there must be exactly 1 result
return resultVals;
Value result = results[0];
Type type = result.getType();
if (!type)
return resultVals;
RankedTensorType rtType = type.dyn_cast<RankedTensorType>();
if (!rtType)
// the result must be a tensor
return resultVals;
Attribute encoding = rtType.getEncoding();
if (!encoding)
// encoding not available
return resultVals;
if (!encoding.dyn_cast<triton::gpu::BlockedEncodingAttr>() &&
!encoding.dyn_cast<triton::gpu::SliceEncodingAttr>()) {
// TODO: constraining the encoding type here is necessary
// to avoid crashes in the triton::gpu::getElemsPerThread
// call below, observed in test_core::test_fp8_dot_acc
return resultVals;
}

SmallVector<unsigned> elemsPerThread =
triton::gpu::getElemsPerThread(rtType);
int rank = elemsPerThread.size();
if (product<unsigned>(elemsPerThread) != resultVals.size())
return resultVals;
AxisInfo *axisInfo = axisAnalysisPass.getAxisInfo(result);
if (!axisInfo)
// axis info (e.g., constancy) not available
return resultVals;
SmallVector<unsigned> sizePerThread =
triton::gpu::getSizePerThread(encoding);
if (rank != sizePerThread.size())
return resultVals;

SmallVector<int64_t> constancy = axisInfo->getConstancy();
if (rank != constancy.size())
return resultVals;
bool hasConstancy = false;
for (int i = 0; i < rank; ++i) {
if (constancy[i] > sizePerThread[i]) {
if (constancy[i] % sizePerThread[i] != 0)
// constancy is not evenly covered by sizePerThread
return resultVals;
// can't move the values across different
// "sizePerThread"-sized blocks
constancy[i] = sizePerThread[i];
}
if (elemsPerThread[i] < 1 || constancy[i] < 1)
return resultVals;
if (!(elemsPerThread[i] % constancy[i] == 0 ||
constancy[i] % elemsPerThread[i] == 0))
// either the constancy along each dimension must fit
// into the elemsPerThread or the other way around
return resultVals;
if (constancy[i] > 1)
hasConstancy = true;
}
if (!hasConstancy)
// nothing to deduplicate
return resultVals;

if (rank > 1) {
// reorder the shape and constancy vectors by the axis order:
// from the fastest-changing to the slowest-changing axis
SmallVector<unsigned> order = triton::gpu::getOrder(encoding);
if (rank != order.size())
return resultVals;
ArrayRef<unsigned> orderRef(order);
elemsPerThread = reorder(ArrayRef<unsigned>(elemsPerThread), orderRef);
constancy = reorder(ArrayRef<int64_t>(constancy), orderRef);
}

SmallVector<unsigned> strides(rank, 1);
for (int i = 1; i < rank; ++i) {
strides[i] = strides[i - 1] * elemsPerThread[i - 1];
}
SmallVector<Value> dedupResultVals;
dedupResultVals.reserve(resultVals.size());
for (int i = 0; i < resultVals.size(); ++i) {
// each coordinate of the orig_idx is "coarsened" using the
// constancy along this dimension: the resulting dedup_idx
// points to the reused value in the original resultVals
int orig_idx = i;
int dedup_idx = 0;
for (int j = 0; j < rank; ++j) {
int coord_j = orig_idx % elemsPerThread[j];
dedup_idx += (coord_j / constancy[j] * constancy[j]) * strides[j];
orig_idx /= elemsPerThread[j];
}
dedupResultVals.push_back(resultVals[dedup_idx]);
}

return dedupResultVals;
}

LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
@@ -1230,6 +1340,7 @@ class ElementwiseOpConversionBase
auto argTy = op->getOperand(0).getType();
resultVals = reorderValues(resultVals, argTy, resultTy);
}
resultVals = maybeDeduplicate(op, resultVals);
resultVals =
packI32(resultVals, resultTy, rewriter, loc, this->getTypeConverter());
resultVals = this->getTypeConverter()->packMfmaOperand(resultVals, resultTy, rewriter, loc);
@@ -1241,6 +1352,9 @@ class ElementwiseOpConversionBase
return success();
}

protected:
ModuleAxisInfoAnalysis &axisAnalysisPass;

private:
int computeCapability;
};
@@ -1272,8 +1386,9 @@ struct FpToFpOpConversion
triton::FpToFpOp, FpToFpOpConversion>::ElementwiseOpConversionBase;

explicit FpToFpOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,
ModuleAxisInfoAnalysis &axisAnalysisPass,
int computeCapability, PatternBenefit benefit = 1)
: ElementwiseOpConversionBase(typeConverter, benefit),
: ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit),
computeCapability(computeCapability) {}

static Value convertBf16ToFp32(Location loc,
@@ -2097,12 +2212,14 @@ void populateElementwiseOpToLLVMPatterns(
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
int computeCapability, PatternBenefit benefit) {
#define POPULATE_TERNARY_OP(SRC_OP, DST_OP) \
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>( \
typeConverter, axisInfoAnalysis, benefit);
POPULATE_TERNARY_OP(arith::SelectOp, LLVM::SelectOp)
#undef POPULATE_TERNARY_OP

#define POPULATE_BINARY_OP(SRC_OP, DST_OP) \
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>( \
typeConverter, axisInfoAnalysis, benefit);
POPULATE_BINARY_OP(arith::SubIOp, LLVM::SubOp) // -
POPULATE_BINARY_OP(arith::AddIOp, LLVM::AddOp) // +
POPULATE_BINARY_OP(arith::MulIOp, LLVM::MulOp) // *
@@ -2126,7 +2243,8 @@ void populateElementwiseOpToLLVMPatterns(
#undef POPULATE_BINARY_OP

#define POPULATE_UNARY_OP(SRC_OP, DST_OP) \
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>( \
typeConverter, axisInfoAnalysis, benefit);
POPULATE_UNARY_OP(arith::TruncIOp, LLVM::TruncOp)
POPULATE_UNARY_OP(arith::ExtSIOp, LLVM::SExtOp)
POPULATE_UNARY_OP(arith::ExtUIOp, LLVM::ZExtOp)
@@ -2142,29 +2260,32 @@ void populateElementwiseOpToLLVMPatterns(
POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp)
#undef POPULATE_UNARY_OP

patterns.add<AbsIOpConversion>(typeConverter, benefit);
patterns.add<AbsFOpConversion>(typeConverter, benefit);
patterns.add<CmpIOpConversion>(typeConverter, benefit);
patterns.add<CmpFOpConversion>(typeConverter, benefit);

patterns.add<FDivOpConversion>(typeConverter, benefit);
patterns.add<FSubOpConversion>(typeConverter, benefit);
patterns.add<FAddOpConversion>(typeConverter, benefit);
patterns.add<FMulOpConversion>(typeConverter, benefit);

patterns.add<ExtFOpConversion>(typeConverter, benefit);
patterns.add<TruncFOpConversion>(typeConverter, benefit);
patterns.add<FPToSIOpConversion>(typeConverter, benefit);
patterns.add<SIToFPOpConversion>(typeConverter, benefit);
patterns.add<IndexCastOpLowering>(typeConverter, benefit);

patterns.add<FpToFpOpConversion>(typeConverter, computeCapability, benefit);

patterns.add<ExternElementwiseOpConversion>(typeConverter, benefit);
patterns.add<ElementwiseInlineAsmOpConversion>(typeConverter, benefit);
patterns.add<AbsIOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<AbsFOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<CmpIOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<CmpFOpConversion>(typeConverter, axisInfoAnalysis, benefit);

patterns.add<FDivOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<FSubOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<FAddOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<FMulOpConversion>(typeConverter, axisInfoAnalysis, benefit);

patterns.add<ExtFOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<TruncFOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<FPToSIOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<SIToFPOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<IndexCastOpLowering>(typeConverter, axisInfoAnalysis, benefit);

patterns.add<FpToFpOpConversion>(typeConverter, axisInfoAnalysis,
computeCapability, benefit);

patterns.add<ExternElementwiseOpConversion>(typeConverter, axisInfoAnalysis,
benefit);
patterns.add<ElementwiseInlineAsmOpConversion>(typeConverter,
axisInfoAnalysis, benefit);
// ExpOpConversionApprox will try using ex2.approx if the input type is
// FP32. For other input types, ExpOpConversionApprox will return failure and
// ElementwiseOpConversion<math::ExpOp, math::ExpOp> defined below will call
// __nv_expf for higher-precision calculation
patterns.add<ExpOpConversionApprox>(typeConverter, benefit);
patterns.add<ExpOpConversionApprox>(typeConverter, axisInfoAnalysis, benefit);
}
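
As a side note for readers of the diff above, the index coarsening at the end of maybeDeduplicate can be reproduced in isolation. The sketch below uses assumed values (elemsPerThread = {8}, constancy = {4}) and is not part of the committed code; it shows how each original per-thread index maps to the deduplicated index whose value is reused.

```cpp
// Standalone sketch of the coarsened-index mapping; assumed shapes only.
#include <cstdio>
#include <vector>

int main() {
  // Assumed per-thread shape: 8 elements along one axis, constancy 4 along it.
  std::vector<unsigned> elemsPerThread = {8};
  std::vector<unsigned> constancy = {4};
  int rank = static_cast<int>(elemsPerThread.size());

  // Per-thread strides, as in maybeDeduplicate (fastest-changing axis first).
  std::vector<unsigned> strides(rank, 1);
  for (int i = 1; i < rank; ++i)
    strides[i] = strides[i - 1] * elemsPerThread[i - 1];

  unsigned total = 1;
  for (unsigned e : elemsPerThread)
    total *= e;

  for (unsigned i = 0; i < total; ++i) {
    unsigned orig_idx = i, dedup_idx = 0;
    for (int j = 0; j < rank; ++j) {
      unsigned coord_j = orig_idx % elemsPerThread[j];
      // Coarsen the coordinate to the start of its constancy block.
      dedup_idx += (coord_j / constancy[j] * constancy[j]) * strides[j];
      orig_idx /= elemsPerThread[j];
    }
    std::printf("%u -> %u\n", i, dedup_idx); // 0-3 map to 0, 4-7 map to 4
  }
  return 0;
}
```

With these assumed values, indices 0-3 all read the value computed for index 0 and indices 4-7 read the value for index 4, so only two of the eight elementwise results are actually materialized.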
72 changes: 72 additions & 0 deletions test/Conversion/dedup-by-constancy.mlir
@@ -0,0 +1,72 @@
// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm="target=nvvm" --llvm-optimize-for-nvvm-target | FileCheck %s

// CHECK-LABEL: dedup_by_constancy_full
// CHECK-COUNT-5: llvm.add
// CHECK-NOT: llvm.add
// CHECK: llvm.icmp "slt"
// CHECK-NOT: llvm.icmp "slt"
// CHECK: llvm.sdiv
// CHECK-NOT: llvm.sdiv
// CHECK: llvm.getelementptr %arg0[[[REGISTER:%[0-9]+]]]
// CHECK-COUNT-7: llvm.getelementptr %arg0[[[REGISTER]]]
// CHECK-NOT: llvm.getelementptr %arg0[[[REGISTER]]]
#blocked = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @dedup_by_constancy_full(%arg0: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} {
%cst = arith.constant dense<256> : tensor<1024xi32, #blocked>
%c1024_i32 = arith.constant 1024 : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c1024_i32 : i32
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
%3 = tt.splat %1 : (i32) -> tensor<1024xi32, #blocked>
%4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
%5 = tt.splat %arg2 : (i32) -> tensor<1024xi32, #blocked>
%6 = arith.cmpi slt, %4, %5 : tensor<1024xi32, #blocked>
%7 = arith.divsi %4, %cst : tensor<1024xi32, #blocked>
%8 = tt.splat %arg0 : (!tt.ptr<f16, 1>) -> tensor<1024x!tt.ptr<f16, 1>, #blocked>
%9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr<f16, 1>, #blocked>, tensor<1024xi32, #blocked>
%10 = tt.load %9, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf16, #blocked>
%11 = tt.splat %arg1 : (!tt.ptr<f16, 1>) -> tensor<1024x!tt.ptr<f16, 1>, #blocked>
%12 = tt.addptr %11, %4 : tensor<1024x!tt.ptr<f16, 1>, #blocked>, tensor<1024xi32, #blocked>
tt.store %12, %10, %6 {cache = 1 : i32, evict = 1 : i32} : tensor<1024xf16, #blocked>
tt.return
}
}

// -----

// CHECK-LABEL: dedup_by_constancy_partial
// CHECK-COUNT-8: llvm.add
// CHECK-NOT: llvm.add
// CHECK: llvm.icmp "slt"
// CHECK-NOT: llvm.icmp "slt"
// CHECK-COUNT-2: llvm.sdiv
// CHECK-NOT: llvm.sdiv
// CHECK: llvm.getelementptr %arg0[[[REGISTER1:%[0-9]+]]]
// CHECK-COUNT-3: llvm.getelementptr %arg0[[[REGISTER1]]]
// CHECK-NOT: llvm.getelementptr %arg0[[[REGISTER1]]]
// CHECK: llvm.getelementptr %arg0[[[REGISTER2:%[0-9]+]]]
// CHECK-COUNT-3: llvm.getelementptr %arg0[[[REGISTER2]]]
// CHECK-NOT: llvm.getelementptr %arg0[[[REGISTER2]]]
#blocked = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @dedup_by_constancy_partial(%arg0: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} {
%cst = arith.constant dense<4> : tensor<1024xi32, #blocked>
%c1024_i32 = arith.constant 1024 : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c1024_i32 : i32
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
%3 = tt.splat %1 : (i32) -> tensor<1024xi32, #blocked>
%4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
%5 = tt.splat %arg2 : (i32) -> tensor<1024xi32, #blocked>
%6 = arith.cmpi slt, %4, %5 : tensor<1024xi32, #blocked>
%7 = arith.divsi %4, %cst : tensor<1024xi32, #blocked>
%8 = tt.splat %arg0 : (!tt.ptr<f16, 1>) -> tensor<1024x!tt.ptr<f16, 1>, #blocked>
%9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr<f16, 1>, #blocked>, tensor<1024xi32, #blocked>
%10 = tt.load %9, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf16, #blocked>
%11 = tt.splat %arg1 : (!tt.ptr<f16, 1>) -> tensor<1024x!tt.ptr<f16, 1>, #blocked>
%12 = tt.addptr %11, %4 : tensor<1024x!tt.ptr<f16, 1>, #blocked>, tensor<1024xi32, #blocked>
tt.store %12, %10, %6 {cache = 1 : i32, evict = 1 : i32} : tensor<1024xf16, #blocked>
tt.return
}
}
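
The expected counts in the two tests above follow from the per-thread layout: with sizePerThread = 8, dividing the contiguous index by 256 makes the result constant over blocks that cover all 8 elements a thread holds, so a single llvm.sdiv suffices (dedup_by_constancy_full), whereas dividing by 4 makes it constant only over blocks of 4, leaving two distinct llvm.sdiv results per thread (dedup_by_constancy_partial). A small sketch of that arithmetic, assuming the divisor-driven constancy is simply capped at sizePerThread (and that the base index is divisible by the divisor, as it is here):

```cpp
// Sketch of the expected distinct sdiv counts in the tests above; assumes
// constancy from the divisor is capped at sizePerThread. Not committed code.
#include <algorithm>
#include <cstdio>

int main() {
  const int sizePerThread = 8; // from the #blocked layout above
  const int divisors[] = {256, 4};
  for (int divisor : divisors) {
    int constancy = std::min(divisor, sizePerThread);
    int distinctSdivs = sizePerThread / constancy;
    std::printf("divisor=%d -> %d distinct llvm.sdiv result(s) per thread\n",
                divisor, distinctSdivs);
  }
  return 0;
}
```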
