[BACKEND] Pipeliner refactoring (triton-lang#2565)
Refactor the pipeliner pass to make it more generic. The main change is that
the pipeliner is now broken into two pieces: one that computes a modulo
schedule and creates async ops based on the IR, and an expander that generates
the pipelined IR based on that modulo schedule.
The advantage of separating the two pieces is that it allows us to create
different schedules without having to change the expander, and it allows for
more complex schedules.
For now the schedule generated for the matmul case roughly matches the
schedule picked by the previous pipeliner in order to avoid changes.
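
For illustration only (this sketch is not part of the commit): a modulo schedule can be thought of as an assignment of each loop-body operation to a pipeline stage, and the expander as the machinery that replays that assignment as a prologue, a steady-state kernel, and an epilogue. Below is a toy, self-contained C++ sketch with made-up op names ("async_copy", "dot") and a hypothetical two-stage schedule; it only models the scheduling idea, not the actual Triton IR transformation.

#include <cstdio>
#include <string>
#include <utility>
#include <vector>

// Toy model of software pipelining: each op is assigned a stage; stage-0 work
// for iteration i+1 overlaps with stage-1 work for iteration i, and so on.
int main() {
  // Hypothetical two-stage schedule for a matmul-like loop:
  // stage 0 = async copy of the next tile, stage 1 = dot on the current tile.
  std::vector<std::pair<std::string, int>> schedule = {{"async_copy", 0},
                                                       {"dot", 1}};
  const int numStages = 2, numIters = 4;

  // Expander: the first cycles form the prologue (only early stages run), the
  // middle cycles are the steady-state kernel, and the last cycles drain the
  // pipeline as the epilogue.
  for (int cycle = 0; cycle < numIters + numStages - 1; ++cycle)
    for (const auto &entry : schedule) {
      int iter = cycle - entry.second; // iteration this stage instance serves
      if (iter >= 0 && iter < numIters)
        std::printf("cycle %d: %s(iter %d)\n", cycle, entry.first.c_str(),
                    iter);
    }
  return 0;
}

Running it prints one prologue cycle (only the copy), three steady-state cycles (the copy for the next iteration overlapped with the dot for the current one), and one epilogue cycle (only the final dot), which is the shape of IR the expander produces from a schedule.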

This also creates a different sequence of insert/extract slice ops for the
alloc. We should probably change shared alloc to use memory semantics.
ThomasRaoux authored and zhanglx13 committed Nov 9, 2023
1 parent 6d45d6c commit d1de5aa
Showing 14 changed files with 1,990 additions and 1,911 deletions.
3 changes: 2 additions & 1 deletion include/triton/Dialect/TritonGPU/Transforms/Utility.h
@@ -111,6 +111,8 @@ bool isExpensiveLoadOrStore(Operation *op);

bool canFoldIntoConversion(Operation *op, Attribute targetEncoding);

// Replace ForOp with a new ForOp with extra operands. The YieldOp is not
// updated and needs to be updated separately for the loop to be correct.
scf::ForOp replaceForOpWithNewSignature(OpBuilder &rewriter, scf::ForOp loop,
                                        ValueRange newIterOperands);

@@ -143,7 +145,6 @@ Value linearize(OpBuilder &b, Location loc, ArrayRef<Value> multiDim,

Value linearize(OpBuilder &b, Location loc, ArrayRef<Value> multiDim,
                ArrayRef<unsigned> shape);

} // namespace mlir

#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
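
A minimal caller-side sketch of the contract documented above (hypothetical code, not part of this commit): the helper returns a loop with the extra iter operands, and the caller must then extend the yield by hand so the loop verifies. Names such as forOp, extraInit, and carriedValue are placeholders.

// Hypothetical usage inside an MLIR pass; assumes forOp (scf::ForOp),
// extraInit and carriedValue (mlir::Value) are already in scope.
OpBuilder builder(forOp);
SmallVector<Value> extraOperands{extraInit};
scf::ForOp newForOp =
    replaceForOpWithNewSignature(builder, forOp, extraOperands);

// replaceForOpWithNewSignature leaves the terminator untouched, so append the
// value carried for the new iter_arg to the yield explicitly.
auto yieldOp = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
yieldOp->insertOperands(yieldOp->getNumOperands(), ValueRange{carriedValue});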
32 changes: 26 additions & 6 deletions lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -2205,18 +2205,37 @@ struct IndexCastOpLowering
}
};

struct SelectOpConversion
    : ElementwiseOpConversionBase<mlir::arith::SelectOp, SelectOpConversion> {
  using Base =
      ElementwiseOpConversionBase<mlir::arith::SelectOp, SelectOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  SmallVector<Value> createDestOps(mlir::arith::SelectOp op, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter,
                                   Type elemTy, MultipleOperandsRange operands,
                                   Location loc) const {
    std::array<Value, 3> llvmOperands;
    if (operands[0].size() == 2) {
      // Case of scalar condition with tensor operands.
      assert(op.getCondition().getType().isInteger(1));
      llvmOperands = {adaptor.getCondition(), operands[0][0], operands[0][1]};
    } else {
      llvmOperands = {operands[0][0], operands[0][1], operands[0][2]};
    }
    return {rewriter.create<LLVM::SelectOp>(
        loc, llvmOperands[1].getType(), llvmOperands,
        adaptor.getAttributes().getValue())};
  }
};

void populateElementwiseOpToLLVMPatterns(
    TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
    ModuleAllocation &allocation,
    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
    int computeCapability, PatternBenefit benefit) {
#define POPULATE_TERNARY_OP(SRC_OP, DST_OP)                                    \
  patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(                       \
      typeConverter, axisInfoAnalysis, benefit);
  POPULATE_TERNARY_OP(arith::SelectOp, LLVM::SelectOp)
#undef POPULATE_TERNARY_OP

#define POPULATE_BINARY_OP(SRC_OP, DST_OP)                                     \
  patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(                       \
      typeConverter, axisInfoAnalysis, benefit);
@@ -2270,6 +2289,7 @@ void populateElementwiseOpToLLVMPatterns(
patterns.add<FAddOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<FMulOpConversion>(typeConverter, axisInfoAnalysis, benefit);

patterns.add<SelectOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<ExtFOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<TruncFOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<FPToSIOpConversion>(typeConverter, axisInfoAnalysis, benefit);
4 changes: 3 additions & 1 deletion lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
@@ -6,7 +6,9 @@ add_mlir_dialect_library(TritonGPUTransforms
OptimizeDotOperands.cpp
OptimizeEpilogue.cpp
OptimizeThreadLocality.cpp
Pipeline.cpp
Pipeliner/MatmulLoopPipeline.cpp
Pipeliner/PipelineExpander.cpp
Pipeliner/SoftwarePipeliner.cpp
Prefetch.cpp
RemoveLayoutConversions.cpp
ReorderInstructions.cpp
