From 959825b30aa4a9d53483e1d3acf99120b5a039d8 Mon Sep 17 00:00:00 2001
From: Nhat Nguyen
Date: Thu, 12 Oct 2023 17:18:12 -0400
Subject: [PATCH] Update triton-shared 10/2023 (#9)

This update includes various fixes and the following features:
- preliminary support for block-pointers
  - block-pointers are lowered to memref buffers similarly to traditional Triton pointer loads
- support for lowering `triton.get_num_programs` by introducing extra arguments when launching Triton kernels, similar to `triton.get_program_id`
- improved lowering of `triton.reduce`, which previously only supported float values
- preliminary support for pointer arithmetic involving the modulo operator
  - this use case appears in the tutorial matmul example, where pointer offsets are taken modulo the tensor dimensions to prevent loading out-of-bounds values; in such cases, the offsets wrap around to the beginning of the buffer
---
 include/triton-shared/Analysis/MaskAnalysis.h | 12 +
 .../Analysis/OpFoldResultUtils.h | 19 +-
 include/triton-shared/Analysis/PtrAnalysis.h | 66 +-
 include/triton-shared/Analysis/UseAnalysis.h | 3 +-
 .../TritonToLinalg/TritonToLinalg.h | 3 +-
 lib/Analysis/MaskAnalysis.cpp | 115 ++-
 lib/Analysis/OpFoldResultUtils.cpp | 92 +-
 lib/Analysis/PtrAnalysis.cpp | 790 +++++++++++++++---
 .../TritonToLinalg/TritonToLinalg.cpp | 542 +++++++++---
 .../TritonToLinalg/TritonToLinalgPass.cpp | 37 +-
 .../TritonToLinalg/addptr_2d_example.mlir | 2 +-
 .../TritonToLinalg/addptr_add_value.mlir | 2 +-
 .../TritonToLinalg/addptr_dim1.mlir | 107 +++
 .../addptr_for_accumulation.mlir | 2 +-
 .../TritonToLinalg/addptr_loopback.mlir | 2 +-
 .../addptr_mul_value_const.mlir | 6 +-
 .../TritonToLinalg/addptr_nested.mlir | 2 +-
 .../addptr_scalar_broadcast.mlir | 6 +-
 .../TritonToLinalg/addptr_scalar_for.mlir | 6 +-
 .../TritonToLinalg/addptr_scalar_for_2d.mlir | 6 +-
 .../addptr_scalar_loopback.mlir | 23 +-
 .../TritonToLinalg/addptr_scalar_nested.mlir | 10 +-
 .../TritonToLinalg/addptr_scalar_splat.mlir | 6 +-
 .../addptr_scalar_splat_2d.mlir | 6 +-
 test/Conversion/TritonToLinalg/bitcast.mlir | 3 +-
 .../TritonToLinalg/block_ptr_advance.mlir | 93 +++
 .../TritonToLinalg/convert_minmax_reduce.mlir | 126 +++
 .../TritonToLinalg/get_num_programs.mlir | 44 +
 .../TritonToLinalg/kernel-01-vector-add.mlir | 4 +-
 .../kernel-02-fused-softmax.mlir | 10 +-
 .../kernel-03-matrix-multiplication.mlir | 11 +-
 .../kernel-05-layer-norm-dwdb.mlir | 12 +-
 .../kernel-05-layer-norm-fwd.mlir | 20 +-
 .../TritonToLinalg/masked_ldst_1d.mlir | 6 +-
 .../TritonToLinalg/masked_ldst_2d.mlir | 6 +-
 .../masked_ldst_sitofp_other.mlir | 6 +-
 .../TritonToLinalg/triton_assert.mlir | 15 +
 .../wraparound_side_by_side.mlir | 132 +++
 .../TritonToLinalg/wraparound_stacked.mlir | 129 +++
 .../wraparound_unsupported_add_offset.mlir | 57 ++
 40 files changed, 2171 insertions(+), 368 deletions(-)
 create mode 100644 test/Conversion/TritonToLinalg/addptr_dim1.mlir
 create mode 100644 test/Conversion/TritonToLinalg/block_ptr_advance.mlir
 create mode 100644 test/Conversion/TritonToLinalg/convert_minmax_reduce.mlir
 create mode 100644 test/Conversion/TritonToLinalg/get_num_programs.mlir
 create mode 100644 test/Conversion/TritonToLinalg/triton_assert.mlir
 create mode 100644 test/Conversion/TritonToLinalg/wraparound_side_by_side.mlir
 create mode 100644 test/Conversion/TritonToLinalg/wraparound_stacked.mlir
 create mode 100644 test/Conversion/TritonToLinalg/wraparound_unsupported_add_offset.mlir

diff --git a/include/triton-shared/Analysis/MaskAnalysis.h
b/include/triton-shared/Analysis/MaskAnalysis.h index 53ad3d5c..65170073 100644 --- a/include/triton-shared/Analysis/MaskAnalysis.h +++ b/include/triton-shared/Analysis/MaskAnalysis.h @@ -14,6 +14,8 @@ #include "triton/Dialect/Triton/IR/Dialect.h" +#include + namespace mlir { class ConversionPatternRewriter; @@ -64,6 +66,16 @@ struct MaskState { memref::SubViewOp getSubview(Value source, const Location loc, ConversionPatternRewriter &rewriter) const; + std::pair + getSideBySideSubviews(memref::ReinterpretCastOp chunk1, + memref::ReinterpretCastOp chunk2, const Location loc, + ConversionPatternRewriter &rewriter) const; + + std::pair + getStackedSubviews(memref::ReinterpretCastOp chunk1, + memref::ReinterpretCastOp chunk2, const Location loc, + ConversionPatternRewriter &rewriter) const; + private: // ------- // Utility functions to operate on MaskState diff --git a/include/triton-shared/Analysis/OpFoldResultUtils.h b/include/triton-shared/Analysis/OpFoldResultUtils.h index 23512278..3cd82ddf 100644 --- a/include/triton-shared/Analysis/OpFoldResultUtils.h +++ b/include/triton-shared/Analysis/OpFoldResultUtils.h @@ -15,34 +15,41 @@ namespace mlir { -class ConversionPatternRewriter; +class OpBuilder; // Return integer if ofr is an IntegerAttr. Note that this function differs // from getConstantIntValue, which returns an integer if ofr is the constant // result of an operation too. std::optional getIntAttr(const OpFoldResult ofr); +// Create a value of index type if necessary from an OpFoldResult. +Value ofrToIndexValue(const OpFoldResult ofr, const Location loc, OpBuilder &b); + +// Create a vector of values of index type if necessary from an array of +// OpFoldResults. +SmallVector ofrsToIndexValues(ArrayRef ofrs, + const Location loc, OpBuilder &b); + // Process addition of two OFRs. If both OFRs are Integer Attributes, result // is an Integer Attribute. Otherwise, insert the arith.addi instruction if // needed and use its result Value. OpFoldResult addOFRs(const OpFoldResult lhs, const OpFoldResult rhs, - const Location loc, ConversionPatternRewriter &rewriter); + const Location loc, OpBuilder &b); // Produce result = lhs - rhs. If both OFRs are Integer Attributes, result // is an Integer Attribute. Otherwise, insert the arith.addi instruction if // needed and use its result Value. OpFoldResult subOFRs(const OpFoldResult lhs, const OpFoldResult rhs, - const Location loc, ConversionPatternRewriter &rewriter); + const Location loc, OpBuilder &b); // Process multiplication of two OFRs. If both OFRs are Integer Attributes, // result is an Integer Attribtue. Otherwise, insert the arith.muli // instruction if needed and use its result Value. 
OpFoldResult mulOFRValue(const OpFoldResult lhs, const Value rhs, - const Location loc, - ConversionPatternRewriter &rewriter); + const Location loc, OpBuilder &b); OpFoldResult minOFRs(const OpFoldResult lhs, const OpFoldResult rhs, - const Location loc, ConversionPatternRewriter &rewriter); + const Location loc, OpBuilder &b); } // namespace mlir diff --git a/include/triton-shared/Analysis/PtrAnalysis.h b/include/triton-shared/Analysis/PtrAnalysis.h index 30b6f50f..a96ee8cb 100644 --- a/include/triton-shared/Analysis/PtrAnalysis.h +++ b/include/triton-shared/Analysis/PtrAnalysis.h @@ -22,15 +22,35 @@ class ConversionPatternRewriter; namespace triton { +struct ModuloState { + Value size; + OpFoldResult offset; + ModuloState() {} + ModuloState(Value size, OpFoldResult offset) : size{size}, offset{offset} {} + + static constexpr char const *WraparoundAttr = "ptr.wraparound_type"; + static constexpr char const *WraparoundStacked = "stacked"; + static constexpr char const *WraparoundSideBySide = "side_by_side"; +}; + // Data structure used to decode pointer arithmetics and potentially to be // translate it into memref. offsets, sizes, and strides are in unit of elements // in a linearly laid-out memory, which is the same as pointer arithmetic // operations in Triton language. scalar is a shortcut used when the entire // state describes a single scalar value. source is the base pointer. -struct PtrState { +class PtrState { + + OpFoldResult + accumulateTargetOffset(Location loc, + ConversionPatternRewriter &rewriter) const; + +public: SmallVector offsets; SmallVector sizes; SmallVector strides; + + SmallVector> modulos; + Value source; Value scalar; @@ -38,6 +58,11 @@ struct PtrState { bool isEmpty() const; + bool hasModulo() const; + + MemRefType getResultMemrefType(MLIRContext *context, int64_t offset, + ArrayRef resultShape) const; + // Process addition of two PtrStates. void addState(const PtrState &lhsState, const PtrState &rhsState, Location loc, ConversionPatternRewriter &rewriter); @@ -48,9 +73,17 @@ struct PtrState { // Produce a reinterpret cast based on the current PtrState. Additional // instructions may be inserted in calculating the final offset. - memref::ReinterpretCastOp createCastOp(ArrayRef resultShape, - const Location loc, - ConversionPatternRewriter &rewriter); + memref::ReinterpretCastOp + createCastOp(ArrayRef resultShape, const Location loc, + ConversionPatternRewriter &rewriter) const; + + SmallVector + createSideBySideCastOps(ArrayRef resultShape, const Location loc, + ConversionPatternRewriter &rewriter) const; + + SmallVector + createStackedCastOps(ArrayRef resultShape, const Location loc, + ConversionPatternRewriter &rewriter) const; }; class PtrAnalysis { @@ -95,6 +128,16 @@ class PtrAnalysis { ConversionPatternRewriter &rewriter, const llvm::SmallDenseMap &knownPtrs); + static void + visitOperandRem(arith::RemSIOp mulOp, PtrState &state, const Location loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs); + + static void visitOperandUnrealizedCast( + UnrealizedConversionCastOp op, PtrState &state, const Location loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs); + // Operand is the result of make_range. 
// Main assumptions: // start, end, and shape are all statically known @@ -156,6 +199,11 @@ class PtrAnalysis { ConversionPatternRewriter &rewriter, const llvm::SmallDenseMap &knownPtrs); + static void visitOperandMakeTensorPtr( + triton::MakeTensorPtrOp makeTensorPtrOp, PtrState &state, + const Location loc, ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs); + // Operand is the result of addptr. // Main assumptions: // The ptr field should populate the source field @@ -177,6 +225,16 @@ class PtrAnalysis { const Location loc, ConversionPatternRewriter &rewriter, const llvm::SmallDenseMap &knownPtrs); + // Operand is the result of tt.advance. + // Main assumptions: + // The source of the tt.advance has been mapped to a reinterpret_cast + // Expected result: + // Directly grab all corresponding fields from reinterpret_cast. + // Add the offsets multiplied by the strides to the final offsets. + static void rewriteAdvanceOp(triton::AdvanceOp op, + ConversionPatternRewriter &rewriter, + llvm::SmallDenseMap &knownPtrs); + // Parse the state of AddPtrOp, insert any instruction needed to // calculate strides and offsets, build PtrState for this operand, and record // PtrState for knownPtrs. diff --git a/include/triton-shared/Analysis/UseAnalysis.h b/include/triton-shared/Analysis/UseAnalysis.h index e18088c8..d0c34da6 100644 --- a/include/triton-shared/Analysis/UseAnalysis.h +++ b/include/triton-shared/Analysis/UseAnalysis.h @@ -48,8 +48,9 @@ struct UseInfo : public dataflow::AbstractSparseLattice { } case UseType::MixUse: return ChangeResult::NoChange; + default: + llvm_unreachable("bad type"); } - llvm_unreachable("bad type"); } ChangeResult meet(const AbstractSparseLattice &other) override { diff --git a/include/triton-shared/Conversion/TritonToLinalg/TritonToLinalg.h b/include/triton-shared/Conversion/TritonToLinalg/TritonToLinalg.h index ffd2c16c..4c58e992 100644 --- a/include/triton-shared/Conversion/TritonToLinalg/TritonToLinalg.h +++ b/include/triton-shared/Conversion/TritonToLinalg/TritonToLinalg.h @@ -24,7 +24,8 @@ void populateTritonToLinalgCanonicalizationPatterns( RewritePatternSet &patterns); void populateTritonToLinalgConversionPatterns(TypeConverter &typeConverter, - RewritePatternSet &patterns); + RewritePatternSet &patterns, + unsigned int launchGridRank); } // namespace triton } // namespace mlir diff --git a/lib/Analysis/MaskAnalysis.cpp b/lib/Analysis/MaskAnalysis.cpp index 80ac29e5..2e54a151 100644 --- a/lib/Analysis/MaskAnalysis.cpp +++ b/lib/Analysis/MaskAnalysis.cpp @@ -68,6 +68,120 @@ MaskState::getSubview(Value source, const Location loc, source, offsets, dims, strides); } +static memref::SubViewOp createSubview(Value src, Location loc, OpBuilder &b, + ArrayRef offsets, + ArrayRef sizes, + ArrayRef strides) { + auto srcType = cast(src.getType()); + auto dstType = + memref::SubViewOp::inferResultType(srcType, offsets, sizes, strides); + return b.create(loc, cast(dstType), src, + offsets, sizes, strides); +} + +// Assume block1 wraps around and the remainder is block2. +// +// |----------------------| +// | | | +// | block2 | block1 | +// | | | +// |----------------------| +// +// Once we copy the chunks in order, the end result is block1 followed by +// block2. 
+// +// buffer_tmp: +// +// |----------------------| +// | | | +// | block1 | block2 | +// | | | +// |----------------------| +// +// Assume we have the following subview: +// +// +++++++++++++++++------- +// + + | +// + subview + | +// + + | +// +++++++++++++++++------- +// +// If we simply take the subview of `buffer_tmp`, this requires an extra buffer +// to just hold the temporary result. +// +// So we can subview into block1 and block2 directly. There are 2 cases: +// + subview only spans block1 +// + subview spans both block1 and block2, creating sv1 and sv2 (illustrated +// below for case when we wrap around side-by-side) +// +// |----------------------------------------| +// | | +// | col2 col1 | +// |++++++--------| |+++++++++++++++ +// | sv2 + block2 | | block1 & sv1 + +// |++++++--------| |+++++++++++++++ +// | | +// |----------------------------------------| +// +// For simplicity, assume we only wrap around side-by-side. +// +// Let (row, col1) and (row, col2) be the dimensions of block1 and block2, +// respectively. +// +// Let (rowFull, colFull), (rowView1, colView1) and (rowView2, colView2) be the +// dimensions of the full subview, sv1, and sv2, respectively. +// +// + colView1 = min(colFull, col1) +// + colView2 = colFull - colView1 +// + rowView1 = rowView2 = row = rowFull +std::pair +MaskState::getSideBySideSubviews(memref::ReinterpretCastOp block1, + memref::ReinterpretCastOp block2, + const Location loc, + ConversionPatternRewriter &rewriter) const { + + assert(block1.getResultRank() == 2 && block2.getResultRank() == 2 && + getRank() == 2); + + OpFoldResult subviewRowFull = dims[0]; + OpFoldResult subviewColFull = dims[1]; + OpFoldResult col1 = block1.getMixedSizes()[1]; + OpFoldResult subviewCol1 = minOFRs(col1, subviewColFull, loc, rewriter); + OpFoldResult subviewCol2 = + subOFRs(subviewColFull, subviewCol1, loc, rewriter); + + SmallVector offsets(getRank(), rewriter.getIndexAttr(0)); + SmallVector strides = block1.getMixedStrides(); + auto sv1 = createSubview(block1.getResult(), loc, rewriter, offsets, + {subviewRowFull, subviewCol1}, strides); + auto sv2 = createSubview(block2.getResult(), loc, rewriter, offsets, + {subviewRowFull, subviewCol2}, strides); + + return {sv1, sv2}; +} + +std::pair MaskState::getStackedSubviews( + memref::ReinterpretCastOp block1, memref::ReinterpretCastOp block2, + const Location loc, ConversionPatternRewriter &rewriter) const { + assert(block1.getResultRank() == 2 && block2.getResultRank() == 2 && + getRank() == 2); + + OpFoldResult subviewRowFull = dims[0]; + OpFoldResult subviewColFull = dims[1]; + OpFoldResult row1 = block1.getMixedSizes()[0]; + OpFoldResult subviewRow1 = minOFRs(row1, subviewRowFull, loc, rewriter); + OpFoldResult subviewRow2 = + subOFRs(subviewRowFull, subviewRow1, loc, rewriter); + + SmallVector offsets(getRank(), rewriter.getIndexAttr(0)); + SmallVector strides = block1.getMixedStrides(); + auto sv1 = createSubview(block1.getResult(), loc, rewriter, offsets, + {subviewRow1, subviewColFull}, strides); + auto sv2 = createSubview(block2.getResult(), loc, rewriter, offsets, + {subviewRow2, subviewColFull}, strides); + return {sv1, sv2}; +} + LogicalResult MaskState::addStateScalar(const MaskState &state, const OpFoldResult scalar, Location loc, ConversionPatternRewriter &rewriter) { @@ -132,7 +246,6 @@ LogicalResult MaskState::parseConstant(arith::ConstantOp constOp, auto constAttr = rewriter.getIndexAttr(value.getSExtValue()); auto op = arith::ConstantOp::materialize(rewriter, constAttr, rewriter.getIndexType(), 
loc); - this->scalar = op.getValue(); } else { auto value = constOp.getValue().cast().getInt(); diff --git a/lib/Analysis/OpFoldResultUtils.cpp b/lib/Analysis/OpFoldResultUtils.cpp index 89acd781..c98657d3 100644 --- a/lib/Analysis/OpFoldResultUtils.cpp +++ b/lib/Analysis/OpFoldResultUtils.cpp @@ -12,9 +12,6 @@ namespace mlir { -// Return integer if ofr is an IntegerAttr. Note that this function differs -// from getConstantIntValue, which returns an integer if ofr is the constant -// result of an operation too. std::optional getIntAttr(const OpFoldResult ofr) { if (ofr.is() && ofr.get().isa()) return ofr.get().dyn_cast().getInt(); @@ -22,11 +19,31 @@ std::optional getIntAttr(const OpFoldResult ofr) { return std::nullopt; } -// Process addition of two OFRs. If both OFRs are Integer Attributes, result -// is an Integer Attribute. Otherwise, insert the arith.addi instruction if -// needed and use its result Value. +Value ofrToIndexValue(const OpFoldResult ofr, const Location loc, + OpBuilder &b) { + if (Value val = ofr.dyn_cast()) { + assert(val.getType().isIndex() && "Provided ofr is of type index"); + return val; + } + + auto intVal = getIntAttr(ofr); + if (intVal.has_value()) { + return b.create(loc, b.getIndexAttr(intVal.value())); + } + llvm_unreachable("Unexpected OpFoldResult state"); + return nullptr; +} + +SmallVector ofrsToIndexValues(ArrayRef ofrs, + const Location loc, OpBuilder &b) { + return llvm::to_vector<4>( + llvm::map_range(ofrs, [&](OpFoldResult ofr) -> Value { + return ofrToIndexValue(ofr, loc, b); + })); +} + OpFoldResult addOFRs(const OpFoldResult lhs, const OpFoldResult rhs, - const Location loc, ConversionPatternRewriter &rewriter) { + const Location loc, OpBuilder &b) { auto lhsIntAttr = getIntAttr(lhs); auto rhsIntAttr = getIntAttr(rhs); @@ -38,13 +55,13 @@ OpFoldResult addOFRs(const OpFoldResult lhs, const OpFoldResult rhs, // both lhs and rhs are constants, return result directly if (lhsIntAttr && rhsIntAttr) - return rewriter.getIndexAttr(lhsIntAttr.value() + rhsIntAttr.value()); + return b.getIndexAttr(lhsIntAttr.value() + rhsIntAttr.value()); // otherwise, need to create instructions to calculate new attribute value auto lhsValue = lhs.dyn_cast(); if (lhsIntAttr) { - auto lhsOp = rewriter.create( - loc, rewriter.getIndexAttr(lhsIntAttr.value())); + auto lhsOp = + b.create(loc, b.getIndexAttr(lhsIntAttr.value())); lhsValue = lhsOp.getResult(); } else { assert(lhsValue.getType().isa()); @@ -52,21 +69,18 @@ OpFoldResult addOFRs(const OpFoldResult lhs, const OpFoldResult rhs, auto rhsValue = rhs.dyn_cast(); if (rhsIntAttr) { - auto rhsOp = rewriter.create( - loc, rewriter.getIndexAttr(rhsIntAttr.value())); + auto rhsOp = + b.create(loc, b.getIndexAttr(rhsIntAttr.value())); rhsValue = rhsOp.getResult(); } else { assert(lhsValue.getType().isa()); } - return rewriter.create(loc, lhsValue, rhsValue).getResult(); + return b.create(loc, lhsValue, rhsValue).getResult(); } -// Produce result = lhs - rhs. If both OFRs are Integer Attributes, result -// is an Integer Attribute. Otherwise, insert the arith.addi instruction if -// needed and use its result Value. 
OpFoldResult subOFRs(const OpFoldResult lhs, const OpFoldResult rhs, - const Location loc, ConversionPatternRewriter &rewriter) { + const Location loc, OpBuilder &b) { auto lhsIntAttr = getIntAttr(lhs); auto rhsIntAttr = getIntAttr(rhs); @@ -76,33 +90,29 @@ OpFoldResult subOFRs(const OpFoldResult lhs, const OpFoldResult rhs, // both lhs and rhs are constants, return result directly if (lhsIntAttr && rhsIntAttr) - return rewriter.getIndexAttr(lhsIntAttr.value() - rhsIntAttr.value()); + return b.getIndexAttr(lhsIntAttr.value() - rhsIntAttr.value()); // otherwise, need to create instructions to calculate new attribute value auto lhsValue = lhs.dyn_cast(); if (lhsIntAttr) { - auto lhsOp = rewriter.create( - loc, rewriter.getIndexAttr(lhsIntAttr.value())); + auto lhsOp = + b.create(loc, b.getIndexAttr(lhsIntAttr.value())); lhsValue = lhsOp.getResult(); } auto rhsValue = rhs.dyn_cast(); if (rhsIntAttr) { - auto rhsOp = rewriter.create( - loc, rewriter.getIndexAttr(rhsIntAttr.value())); + auto rhsOp = + b.create(loc, b.getIndexAttr(rhsIntAttr.value())); rhsValue = rhsOp.getResult(); } - auto sumOp = rewriter.create(loc, lhsValue, rhsValue); + auto sumOp = b.create(loc, lhsValue, rhsValue); return sumOp.getResult(); } -// Process multiplication of two OFRs. If both OFRs are Integer Attributes, -// result is an Integer Attribtue. Otherwise, insert the arith.muli -// instruction if needed and use its result Value. OpFoldResult mulOFRValue(const OpFoldResult lhs, const Value rhs, - const Location loc, - ConversionPatternRewriter &rewriter) { + const Location loc, OpBuilder &b) { auto lhsIntAttr = getIntAttr(lhs); auto rhsIsConst = false; @@ -131,49 +141,47 @@ OpFoldResult mulOFRValue(const OpFoldResult lhs, const Value rhs, // 0. both lhs and rhs are constants if (lhsIntAttr && rhsIsConst) - return rewriter.getIndexAttr(lhsIntAttr.value() * rhsConstValue); + return b.getIndexAttr(lhsIntAttr.value() * rhsConstValue); // 1. if lhs is constant but rhs is not if (lhsIntAttr && !rhsIsConst) { - auto lhsConstOp = rewriter.create( - loc, rewriter.getIndexAttr(lhsIntAttr.value())); - auto mulOp = - rewriter.create(loc, lhsConstOp.getResult(), rhs); + auto lhsConstOp = + b.create(loc, b.getIndexAttr(lhsIntAttr.value())); + auto mulOp = b.create(loc, lhsConstOp.getResult(), rhs); return mulOp.getResult(); } // 2. 
if lhs is not constant assert(!lhsIntAttr); - auto mulOp = rewriter.create(loc, lhs.get(), rhs); + auto mulOp = b.create(loc, lhs.get(), rhs); return mulOp.getResult(); } OpFoldResult minOFRs(const OpFoldResult lhs, const OpFoldResult rhs, - const Location loc, ConversionPatternRewriter &rewriter) { + const Location loc, OpBuilder &b) { auto lhsIntAttr = getIntAttr(lhs); auto rhsIntAttr = getIntAttr(rhs); // both lhs and rhs are constants, return result directly if (lhsIntAttr && rhsIntAttr) - return rewriter.getIndexAttr( - std::min(lhsIntAttr.value(), rhsIntAttr.value())); + return b.getIndexAttr(std::min(lhsIntAttr.value(), rhsIntAttr.value())); // otherwise, need to create instructions to calculate new attribute value auto lhsValue = lhs.dyn_cast(); if (lhsIntAttr) { - auto lhsOp = rewriter.create( - loc, rewriter.getIndexAttr(lhsIntAttr.value())); + auto lhsOp = + b.create(loc, b.getIndexAttr(lhsIntAttr.value())); lhsValue = lhsOp.getResult(); } auto rhsValue = rhs.dyn_cast(); if (rhsIntAttr) { - auto rhsOp = rewriter.create( - loc, rewriter.getIndexAttr(rhsIntAttr.value())); + auto rhsOp = + b.create(loc, b.getIndexAttr(rhsIntAttr.value())); rhsValue = rhsOp.getResult(); } - auto minOp = rewriter.create(loc, lhsValue, rhsValue); + auto minOp = b.create(loc, lhsValue, rhsValue); return minOp.getResult(); } diff --git a/lib/Analysis/PtrAnalysis.cpp b/lib/Analysis/PtrAnalysis.cpp index bd0ddb86..eb468e81 100644 --- a/lib/Analysis/PtrAnalysis.cpp +++ b/lib/Analysis/PtrAnalysis.cpp @@ -20,8 +20,33 @@ namespace mlir { namespace triton { +MemRefType PtrState::getResultMemrefType(MLIRContext *context, int64_t offset, + ArrayRef resultShape) const { + + SmallVector staticStrides; + SmallVector dynamicStrides; + dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); + + auto elementType = source.getType().cast().getElementType(); + auto layout = + StridedLayoutAttr::get(source.getContext(), offset, staticStrides); + + return MemRefType::get(resultShape, elementType, layout); +} + +OpFoldResult +PtrState::accumulateTargetOffset(Location loc, + ConversionPatternRewriter &rewriter) const { + OpFoldResult targetOffset = rewriter.getIndexAttr(0); + for (auto o : offsets) { + targetOffset = addOFRs(targetOffset, o, loc, rewriter); + } + return targetOffset; +} + int64_t PtrState::getRank() const { - assert(offsets.size() == sizes.size() && offsets.size() == strides.size()); + assert(offsets.size() == sizes.size() && offsets.size() == strides.size() && + modulos.size() == offsets.size()); return offsets.size(); } @@ -29,6 +54,10 @@ bool PtrState::isEmpty() const { return (getRank() == 0 && !source && !scalar); } +bool PtrState::hasModulo() const { + return llvm::any_of(modulos, [](auto mod) { return mod.has_value(); }); +} + void PtrState::addState(const PtrState &lhsState, const PtrState &rhsState, Location loc, ConversionPatternRewriter &rewriter) { assert(isEmpty() && lhsState.getRank() == rhsState.getRank()); @@ -56,68 +85,243 @@ void PtrState::addState(const PtrState &lhsState, const PtrState &rhsState, strides.push_back(newStride); sizes.push_back(lhsState.sizes[i]); + + assert(!lhsState.hasModulo() || + !rhsState.hasModulo() && "AddPtr where both lhs and rhs containing " + "modulo operators not supported"); + + modulos.push_back(lhsState.modulos[i].has_value() ? 
lhsState.modulos[i] + : rhsState.modulos[i]); } } void PtrState::mulState(const PtrState &lhsState, const PtrState &rhsState, const Location loc, ConversionPatternRewriter &rewriter) { - bool rhsScalar = true; assert(isEmpty() && lhsState.getRank() == rhsState.getRank()); // neither lhs nor rhs should have source, since multiplying base pointer // does not make sense assert(!(lhsState.source && rhsState.source)); - source = lhsState.source ? lhsState.source : rhsState.source; - assert((lhsState.scalar || rhsState.scalar) && !(lhsState.scalar && rhsState.scalar) && "currently does not support both tensors are effectively non-scalar"); - if (!rhsState.scalar && lhsState.scalar) - rhsScalar = false; + PtrState const *lhs = &lhsState; + PtrState const *rhs = &rhsState; - for (uint64_t i = 0; i < lhsState.sizes.size(); i++) { - OpFoldResult newOffset; - OpFoldResult newStride; - if (rhsScalar) { - newOffset = - mulOFRValue(lhsState.offsets[i], rhsState.scalar, loc, rewriter); - newStride = - mulOFRValue(lhsState.strides[i], rhsState.scalar, loc, rewriter); - } else { - newOffset = - mulOFRValue(rhsState.offsets[i], lhsState.scalar, loc, rewriter); - newStride = - mulOFRValue(rhsState.strides[i], lhsState.scalar, loc, rewriter); - } + if (!rhs->scalar && lhs->scalar) { + std::swap(lhs, rhs); + } + + for (uint64_t i = 0; i < lhs->sizes.size(); i++) { + OpFoldResult newOffset = + mulOFRValue(lhs->offsets[i], rhs->scalar, loc, rewriter); + OpFoldResult newStride = + mulOFRValue(lhs->strides[i], rhs->scalar, loc, rewriter); offsets.push_back(newOffset); strides.push_back(newStride); - sizes.push_back(lhsState.sizes[i]); + sizes.push_back(lhs->sizes[i]); } + + assert(llvm::all_of(rhsState.modulos, + [](auto rhs) { return !rhs.has_value(); })); + + modulos = lhs->modulos; +} + +SmallVector +PtrState::createStackedCastOps(ArrayRef resultShape, + const Location loc, + ConversionPatternRewriter &rewriter) const { + + assert(resultShape.size() == 2); + assert(getRank() == 2); + assert(modulos[0].has_value() && !modulos[1].has_value()); + + Value targetOffset = + ofrToIndexValue(accumulateTargetOffset(loc, rewriter), loc, rewriter); + + ////////////////////////////////////////////////////////////////////////////// + // + // Handling stacked wraparound + // + // We do not support cases where the target offset has already overflown the + // number of rows. See side-by-side wraparound for details. + // + ////////////////////////////////////////////////////////////////////////////// + // We're loading a tensor of dim (rowSize, colSize) + // d1 + d2 = rowSize + // d2 is the number of rows that overflow + // + // cols + // + // wrappedAroundOff + // --------------*------------*-------- + // | d2 | | | + // | |------------| | + // rows| | + // | | + // | targetOffset | + // | *------------| | + // | | | | + // | d1 | | | + // | | clampedOff | | + // --------------*--------------------- + // | overflow | + // *------------- + // nextOff + // + // wrappedAroundOff = targetOffset % cols + // clampedOff = (rows * strideRows) + wrappedAroundOff + // + // clampedOff - targetOffset + // d1 = -------------------- + // strideRows + + auto resultType = getResultMemrefType( + rewriter.getContext(), /* offset */ ShapedType::kDynamic, + /* result shape */ + SmallVector{ + ShapedType::kDynamic, // Row is dynamic, in most cases, this should be + // the same as the original row. The last chunk + // may be smaller due to wrapping around. + resultShape[1], // Col stays the same. 
+ }); + + Value rowSize = ofrToIndexValue(sizes[0], loc, rewriter); + Value colSize = ofrToIndexValue(sizes[1], loc, rewriter); + + Value strideRow = ofrToIndexValue(strides[0], loc, rewriter); + Value strideCol = ofrToIndexValue(strides[1], loc, rewriter); + + Value modRow = rewriter.create( + loc, rewriter.getIndexType(), modulos[0]->size); + + // First chunk + Value wrappedAroundOff = + rewriter.create(loc, targetOffset, strideRow); + Value clampedOff = rewriter.create(loc, modRow, strideRow); + clampedOff = + rewriter.create(loc, clampedOff, wrappedAroundOff); + Value d1 = rewriter.create(loc, clampedOff, targetOffset); + d1 = rewriter.create(loc, d1, strideRow); + + SmallVector sizes1{d1, colSize}; + memref::ReinterpretCastOp cast1 = rewriter.create( + loc, resultType, source, targetOffset, sizes1, + ValueRange{strideRow, strideCol}); + + // Second chunk + Value d2 = rewriter.create(loc, rowSize, d1); + SmallVector sizes2{d2, colSize}; + memref::ReinterpretCastOp cast2 = rewriter.create( + loc, resultType, source, wrappedAroundOff, sizes2, + ValueRange{strideRow, strideCol}); + + return {cast1, cast2}; +} + +SmallVector +PtrState::createSideBySideCastOps(ArrayRef resultShape, + const Location loc, + ConversionPatternRewriter &rewriter) const { + + assert(resultShape.size() == 2); + assert(getRank() == 2 && !modulos[0].has_value() && modulos[1].has_value()); + + // Accumulate final offset + Value targetOffset = + ofrToIndexValue(accumulateTargetOffset(loc, rewriter), loc, rewriter); + + ////////////////////////////////////////////////////////////////////////////// + // + // Handling side-by-side wraparound + // + // Note: We do not support cases where the target has already overflown the + // number of columns! This is because in PtrAnalysis, the offset has already + // been collapsed into a single dimension, so it is ambiguous to determine + // whether the offset actually overflows or just refers to an element on the + // subsequent rows. + // + // Same limitations apply to the stacked wraparound case. + // + ////////////////////////////////////////////////////////////////////////////// + // + // nextOffset - targetOffset = colSize + // d1 + d2 = colSize + // N + // x clampedOffset + // --------------------------*----------------*-----* + // | | nextOffset (might + // | targetOffset | overflow) + // y *----- *----------------| + // | | | | + // M |----- -----------------| + // | d2 d1 | + // -------------------------------------------- + // + // x = targetOffset % N + // nextOffset = x + colSize + // clampedOffset = min(nextOffset, N) + // d1 = clampedOffset - x + // + ////////////////////////////////////////////////////////////////////////////// + + SmallVector casts; + + auto resultType = getResultMemrefType( + rewriter.getContext(), /* offset */ ShapedType::kDynamic, + /* result shape */ + SmallVector{ + resultShape[0], // Row stays the same + ShapedType::kDynamic // Column is dynamic, in most cases, this should + // be the same as the original column. The last + // chunk may be smaller due to wrapping around. 
+ }); + + Value rowSize = ofrToIndexValue(sizes[0], loc, rewriter); + Value colSize = ofrToIndexValue(sizes[1], loc, rewriter); + + Value modN = rewriter.create(loc, rewriter.getIndexType(), + modulos[1]->size); + + SmallVector strideVals = ofrsToIndexValues(strides, loc, rewriter); + + Value x = rewriter.create(loc, targetOffset, modN); + Value y = rewriter.create(loc, targetOffset, x); + + // First chunk + Value nextOffset = rewriter.create(loc, x, colSize); + Value clampedOffset = rewriter.create(loc, nextOffset, modN); + Value d1 = rewriter.create(loc, clampedOffset, x); + SmallVector sizes1{rowSize, d1}; + auto cast1 = rewriter.create( + loc, resultType, source, targetOffset, sizes1, strideVals); + + // Second chunk + Value d2 = rewriter.create(loc, colSize, d1); + SmallVector sizes2{rowSize, d2}; + auto cast2 = rewriter.create( + loc, resultType, source, y, sizes2, strideVals); + + return {cast1, cast2}; } memref::ReinterpretCastOp PtrState::createCastOp(ArrayRef resultShape, const Location loc, - ConversionPatternRewriter &rewriter) { + ConversionPatternRewriter &rewriter) const { // Accumulate final offset - OpFoldResult targetOffset = rewriter.getIndexAttr(0); - for (auto o : offsets) - targetOffset = addOFRs(targetOffset, o, loc, rewriter); + OpFoldResult targetOffset = accumulateTargetOffset(loc, rewriter); // Create result MemRefType SmallVector staticOffset; SmallVector dynamicOffset; - SmallVector staticStrides; - SmallVector dynamicStrides; dispatchIndexOpFoldResult(targetOffset, dynamicOffset, staticOffset); - dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); - auto elementType = source.getType().cast().getElementType(); - auto layout = StridedLayoutAttr::get(source.getContext(), staticOffset[0], - staticStrides); - auto resultType = MemRefType::get(resultShape, elementType, layout); + auto resultType = + getResultMemrefType(rewriter.getContext(), staticOffset[0], resultShape); // Create reinterpret cast return rewriter.create( @@ -134,6 +338,11 @@ void PtrAnalysis::visitOperandAdd( PtrState rhsState; visitOperand(addOp.getRhs(), rhsState, loc, rewriter, knownPtrs); + if ((lhsState.getRank() == 1 && lhsState.hasModulo()) || + (rhsState.getRank() == 1 && rhsState.hasModulo())) { + assert(0 && "Current do not support this pattern: a + arange(0, K) % M"); + } + state.addState(lhsState, rhsState, loc, rewriter); } @@ -150,6 +359,23 @@ void PtrAnalysis::visitOperandMul( state.mulState(lhsState, rhsState, loc, rewriter); } +void PtrAnalysis::visitOperandRem( + arith::RemSIOp remOp, PtrState &state, const Location loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs) { + assert(state.isEmpty()); + visitOperand(remOp.getLhs(), state, loc, rewriter, knownPtrs); + assert(state.getRank() == 1 && !state.modulos.back().has_value() && + "No support for multiple modulos within an expression"); + + PtrState rhsState; + visitOperand(remOp.getRhs(), rhsState, loc, rewriter, knownPtrs); + assert(rhsState.scalar); + rhsState.scalar.dump(); + + state.modulos.back() = ModuloState(rhsState.scalar, rewriter.getIndexAttr(0)); +} + void PtrAnalysis::visitOperandMakeRange( triton::MakeRangeOp rangeOp, PtrState &state, Location loc, ConversionPatternRewriter &rewriter, @@ -165,6 +391,7 @@ void PtrAnalysis::visitOperandMakeRange( state.offsets.push_back(rewriter.getIndexAttr(start)); state.sizes.push_back(rewriter.getIndexAttr(shape[0])); state.strides.push_back(rewriter.getIndexAttr(stride)); + state.modulos.push_back(std::nullopt); } void 
PtrAnalysis::visitOperandExpandDims( @@ -186,6 +413,7 @@ void PtrAnalysis::visitOperandExpandDims( state.offsets.insert(state.offsets.begin() + axis, rewriter.getIndexAttr(0)); state.sizes.insert(state.sizes.begin() + axis, rewriter.getIndexAttr(1)); state.strides.insert(state.strides.begin() + axis, rewriter.getIndexAttr(0)); + state.modulos.insert(state.modulos.begin() + axis, std::nullopt); } void PtrAnalysis::visitOperandBroadcast( @@ -214,8 +442,6 @@ void PtrAnalysis::visitOperandBroadcast( else llvm_unreachable("unexpected dimensions used in broadcast"); } - - return; } void PtrAnalysis::visitOperandSplat( @@ -236,12 +462,13 @@ void PtrAnalysis::visitOperandSplat( state.offsets.push_back(rewriter.getIndexAttr(0)); state.sizes.push_back(rewriter.getIndexAttr(s)); state.strides.push_back(rewriter.getIndexAttr(0)); + state.modulos.push_back(std::nullopt); } } else { - // src is a memref that represent a scalar pointer; it should have one - // dimension of size 1. This happens inside a for loop that originally has - // an init arg that is a tensor of pointers; this arg would have been - // replaced by rewriteForOp. + // src is a memref that represent a scalar pointer; it should have + // one dimension of size 1. This happens inside a for loop that + // originally has an init arg that is a tensor of pointers; this arg + // would have been replaced by rewriteForOp. auto srcType = src.getType().cast(); assert(srcType.getRank() == 1 && state.getRank() == 1 && "splat MemRef source should have rank 1"); @@ -249,8 +476,8 @@ void PtrAnalysis::visitOperandSplat( getIntAttr(state.sizes[0]).value() == 1 && "splat MemRef source should have size 1"); - // Stride[0] will have value of 1 set in visitOperandAddPtr. This value will - // be represented by a constOp. Clear this value. + // Stride[0] will have value of 1 set in visitOperandAddPtr. This + // value will be represented by a constOp. Clear this value. 
state.strides[0] = rewriter.getIndexAttr(0); for (auto [i, s] : llvm::enumerate(dstShape)) { @@ -261,6 +488,7 @@ state.offsets.push_back(rewriter.getIndexAttr(0)); state.sizes.push_back(rewriter.getIndexAttr(s)); state.strides.push_back(rewriter.getIndexAttr(0)); + state.modulos.push_back(std::nullopt); } } @@ -268,8 +496,19 @@ // most dimension if (state.scalar) state.offsets[0] = state.scalar; +} - return; +void PtrAnalysis::visitOperandMakeTensorPtr( + triton::MakeTensorPtrOp makeTensorPtrOp, PtrState &state, + const Location loc, ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs) { + assert(state.isEmpty()); + auto remappedValue = rewriter.getRemappedValue(makeTensorPtrOp); + if (auto castOp = remappedValue.getDefiningOp()) { + visitOperandReintCast(castOp, state, loc, rewriter, knownPtrs); + } else { + llvm_unreachable("Expect value to be mapped to a memref.reinterpret_cast"); + } } void PtrAnalysis::visitOperandAddptr( @@ -297,6 +536,7 @@ offsetState.sizes.push_back(rewriter.getIndexAttr(1)); offsetState.offsets.push_back(offsetState.scalar); offsetState.strides.push_back(rewriter.getIndexAttr(0)); + offsetState.modulos.push_back(std::nullopt); } assert(ptrState.getRank() == offsetState.getRank() && @@ -315,6 +555,7 @@ state.sizes = reintCastOp.getMixedSizes(); state.strides = reintCastOp.getMixedStrides(); state.source = reintCastOp.getSource(); + state.modulos.append(state.sizes.size(), std::nullopt); // getMixedOffsets produces staticOffsets (which is the result of collapsing // multiple dimensions). Populate the rest of the dimensions with zeroes. @@ -322,6 +563,19 @@ for (size_t i = 1; i < state.sizes.size(); i++) { state.offsets.push_back(rewriter.getIndexAttr(0)); } + + // Regular Triton programs cannot express patterns of size 1 and non-zero + // stride; we only set it that way to make memrefs work. Set stride back to + // zero if this scenario is detected.
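+ // For example (illustrative values): a reinterpret_cast carrying sizes + // [1, 256] and strides [256, 1], as produced by rewriteAddptrOp below to + // avoid zero-stride memrefs, is restored to strides [0, 1] here.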
+ for (size_t i = 0; i < state.strides.size(); i++) { + auto strideIntAttr = getIntAttr(state.strides[i]); + auto sizeIntAttr = getIntAttr(state.sizes[i]); + + assert(sizeIntAttr); + if (sizeIntAttr.value() == 1 && strideIntAttr) { + state.strides[i] = rewriter.getIndexAttr(0); + } + } } void PtrAnalysis::visitOperand( @@ -345,12 +599,18 @@ void PtrAnalysis::visitOperand( auto remappedPtr = rewriter.getRemappedValue(operand); assert(remappedPtr); - // A scalar pointer can either be produced by AddPtrOp or a block argument + // A scalar pointer can either be produced by AddPtrOp or a block + // argument if (auto op = operand.getDefiningOp()) { - assert(operand.getDefiningOp() && - "Assume only addptr can produce a scalar pointer"); - visitOperandAddptr(cast(op), state, loc, rewriter, - knownPtrs); + if (auto addPtrOp = dyn_cast(op)) { + visitOperandAddptr(cast(op), state, loc, rewriter, + knownPtrs); + } else if (auto makeTensorOp = dyn_cast(op)) { + visitOperandMakeTensorPtr(makeTensorOp, state, loc, rewriter, + knownPtrs); + } else { + llvm_unreachable("Unexpected operand defining operation"); + } } else { state.source = remappedPtr; } @@ -373,8 +633,10 @@ void PtrAnalysis::visitOperand( visitOperandAddptr(op, state, loc, rewriter, knownPtrs); } else if (auto op = operand.getDefiningOp()) { visitOperandConstSplat(op, state, loc, rewriter, knownPtrs); + } else if (auto op = operand.getDefiningOp()) { + visitOperandRem(op, state, loc, rewriter, knownPtrs); } else { - operand.getDefiningOp()->dump(); + operand.dump(); llvm_unreachable("encountered addptr operand produced by an " "unsupported operation"); } @@ -399,7 +661,7 @@ void PtrAnalysis::visitOperandConstSplat( state.scalar = constOp; auto resultType = cast(op.getResult().getType()); - for (auto i = 0; i < resultType.getShape().size(); i++) { + for (size_t i = 0; i < resultType.getShape().size(); i++) { if (i == 0) { state.offsets.push_back(constOp.getResult()); } else { @@ -408,6 +670,7 @@ void PtrAnalysis::visitOperandConstSplat( state.sizes.push_back(rewriter.getIndexAttr(resultType.getShape()[i])); state.strides.push_back(rewriter.getIndexAttr(0)); + state.modulos.push_back(std::nullopt); } } @@ -427,6 +690,7 @@ void PtrAnalysis::rewriteAddptrOp( state.sizes.push_back(rewriter.getIndexAttr(1)); state.strides.push_back(rewriter.getIndexAttr(0)); state.offsets.push_back(state.scalar); + state.modulos.push_back(std::nullopt); } SmallVector scalarShape(1, 1); @@ -439,35 +703,141 @@ void PtrAnalysis::rewriteAddptrOp( assert(state.getRank() == 1); } - // If there are dimensions with size 1 and stride 0, replace stride 0 with 1 - // so inferResultType below works as expected. - for (size_t i = 0; i < state.sizes.size(); i++) { + knownPtrs[op.getResult()] = state; + + // If there are dimensions with size 1 and stride 0, replace 0 stride with the + // product of sizes of all lower dimensions. This avoids creating memref with + // zero stride. Note that we store the unmodified state into knownPtrs, since + // any following pointer arithmetic operations should use the original 0 + // stride. 
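+ // For example (illustrative values): a state with sizes [1, 256] and + // strides [0, 1] is emitted as a memref with strides [256, 1], while the + // copy recorded in knownPtrs keeps the original strides [0, 1].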
+ auto accum_size = 1; + for (int i = state.sizes.size() - 1; i >= 0; i--) { auto strideIntAttr = getIntAttr(state.strides[i]); auto sizeIntAttr = getIntAttr(state.sizes[i]); - if (!strideIntAttr || strideIntAttr != 0) - continue; + assert(sizeIntAttr); + if (sizeIntAttr.value() == 1 && strideIntAttr && strideIntAttr.value() == 0) + state.strides[i] = rewriter.getIndexAttr(accum_size); - if (sizeIntAttr && sizeIntAttr.value() == 1) - state.strides[i] = rewriter.getIndexAttr(1); + accum_size *= sizeIntAttr.value(); } - auto castOp = state.createCastOp(resultShape, op.getLoc(), rewriter); - LLVM_DEBUG({ - llvm::dbgs() << "cast MemRefType:\n"; - castOp.getOperation()->print(llvm::dbgs(), - OpPrintingFlags().printGenericOpForm()); - llvm::dbgs() << "\n"; - }); + Value src; - state.source = castOp.getResult(); - rewriter.replaceOp(op, castOp.getResult()); + if (llvm::any_of(state.modulos, [](auto mod) { return mod.has_value(); })) { + assert(state.modulos.size() == 2); + ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointAfter(op); - knownPtrs[op.getResult()] = state; + SmallVector casts; + StringRef type; + + if (!state.modulos[0].has_value() && state.modulos[1].has_value()) { + casts = state.createSideBySideCastOps(resultShape, op.getLoc(), rewriter); + type = ModuloState::WraparoundSideBySide; + } else if (state.modulos[0].has_value() && !state.modulos[1].has_value()) { + casts = state.createStackedCastOps(resultShape, op.getLoc(), rewriter); + type = ModuloState::WraparoundStacked; + } else { + assert(false && "not supported"); + } + auto resultType = state.getResultMemrefType( + rewriter.getContext(), ShapedType::kDynamic, resultShape); + + UnrealizedConversionCastOp combinedCast = + rewriter.create( + op.getLoc(), resultType, + ValueRange{casts[0].getResult(), casts[1].getResult(), + op.getResult()}); + + combinedCast->setAttr(ModuloState::WraparoundAttr, + rewriter.getStringAttr(type)); + + src = combinedCast.getResult(0); + + LLVM_DEBUG({ + llvm::dbgs() << "combine cast for split pointers:\n"; + combinedCast.getOperation()->print( + llvm::dbgs(), OpPrintingFlags().printGenericOpForm()); + llvm::dbgs() << "\n"; + }); + + } else { + memref::ReinterpretCastOp castOp = + state.createCastOp(resultShape, op.getLoc(), rewriter); + + src = castOp.getResult(); + + LLVM_DEBUG({ + llvm::dbgs() << "cast MemRefType:\n"; + castOp.getOperation()->print(llvm::dbgs(), + OpPrintingFlags().printGenericOpForm()); + llvm::dbgs() << "\n"; + }); + } + + state.source = src; + rewriter.replaceOp(op, src); rewriter.restoreInsertionPoint(origIp); } +void PtrAnalysis::rewriteAdvanceOp( + triton::AdvanceOp op, ConversionPatternRewriter &rewriter, + llvm::SmallDenseMap &knownPtrs) { + OpBuilder::InsertionGuard insertionGuard{rewriter}; + rewriter.setInsertionPoint(op); + auto loc = op.getLoc(); + + PtrState ptrState; + visitOperand(op.getOperand(0), ptrState, loc, rewriter, knownPtrs); + + auto incrementOffsets = op.getOffsets(); + + SmallVector newOffsets; + for (auto [increment, offset, stride] : + llvm::zip(incrementOffsets, ptrState.offsets, ptrState.strides)) { + Value offsetValue; + if (auto offsetIntAttr = getIntAttr(offset)) { + auto constOp = rewriter.create( + op.getLoc(), rewriter.getIndexAttr(0)); + offsetValue = constOp.getResult(); + } else { + offsetValue = offset.get(); + } + auto castOp = rewriter.create( + loc, rewriter.getIndexType(), increment); + auto mulOp = rewriter.create(loc, castOp.getResult(), + stride.get()); + auto addOp = + rewriter.create(loc, 
mulOp.getResult(), offsetValue); + newOffsets.push_back(addOp.getResult()); + } + + ptrState.offsets.clear(); + + for (auto offset : newOffsets) { + ptrState.offsets.push_back(offset); + } + + SmallVector scalarShape(1, 1); + ArrayRef resultShape; + auto pointerType = op.getResult().getType().cast(); + if (auto shapedType = pointerType.getPointeeType().dyn_cast()) { + resultShape = shapedType.getShape(); + } else { + // scalar pointer, should produce a one dimensional memref + resultShape = scalarShape; + assert(ptrState.getRank() == 1); + } + + auto newOp = ptrState.createCastOp(resultShape, loc, rewriter); + + rewriter.replaceOp(op, newOp.getResult()); + + knownPtrs[newOp.getResult()] = ptrState; +} + void PtrAnalysis::rewriteYieldOp( scf::YieldOp op, ConversionPatternRewriter &rewriter, const IndexMapSet &levelToBlockArgIndex, const int level, @@ -478,8 +848,11 @@ void PtrAnalysis::rewriteYieldOp( auto adaptor = scf::YieldOp::Adaptor(op); - SmallVector initArgState; + SmallVector initArgState; SmallVector operands(adaptor.getOperands()); + // Track the second chunks of modulo pointers so that we can append them to + // the yield results + SmallVector moduloSecondChunks; // For each of the init arg that we added additional Values in for loop, we // need to add corresponding Values as yield operands. The loop below gathers @@ -487,24 +860,37 @@ void PtrAnalysis::rewriteYieldOp( for (auto [i, v] : llvm::enumerate(adaptor.getOperands())) { if (auto mappedV = rewriter.getRemappedValue(v)) { // If this value is a tensor of pointers produced by AddPtrOp, - // TritonTypeConverter should have converted to MemRefType without layout - // information. Since it doesn't match with the MemRefType that we - // produced in rewriteAddptrOp (which is in canonical form with layout - // information), an unrealized_conversion_cast should have been added. We - // need to trace it back through this unrealized_conversion_cast to get - // the original reinterpret_cast. Also see comments in - // TritonTypeConverter::addConversion. + // TritonTypeConverter should have converted to MemRefType without + // layout information. Since it doesn't match with the MemRefType + // that we produced in rewriteAddptrOp (which is in canonical form + // with layout information), an unrealized_conversion_cast should + // have been added. We need to trace it back through this + // unrealized_conversion_cast to get the original reinterpret_cast. + // Also see comments in TritonTypeConverter::addConversion. // - // For TritonToLinalg, we do not use any TypeConverters, hence we can - // access the reinterpret_cast directly. - if (v.getDefiningOp()) { + // For TritonToLinalg, we do not use any TypeConverters, hence we + // can access the reinterpret_cast directly. 
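+ // For wraparound pointers, the unrealized_conversion_cast created in + // rewriteAddptrOp carries three operands: the two reinterpret_cast chunks + // and the original tt.addptr result. Only the first chunk replaces the + // yield operand in place; the second chunk is collected in + // moduloSecondChunks and appended to the yield operands at the end.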
+ if (v.getDefiningOp() || + v.getDefiningOp() || + v.getDefiningOp()) { if (auto castOp = mappedV.getDefiningOp()) { auto castInputs = castOp.getInputs(); - assert(castInputs.size() == 1 && - "only expect 1:1 mapping for unrealized_conversion_cast that " + + assert((castInputs.size() == 1 || + castOp->hasAttr(ModuloState::WraparoundAttr)) && + "only expect 1:1 mapping for " + "unrealized_conversion_cast that " "were " "automatically inserted during legalizing"); - v = castInputs[0]; + + if (castOp->hasAttr(ModuloState::WraparoundAttr)) { + v = castOp.getResult(0); + operands[i] = castInputs[0]; + moduloSecondChunks.push_back(castInputs[1]); + } else { + v = castInputs[0]; + } + } else if (auto castOp = mappedV.getDefiningOp()) { v = castOp; @@ -512,13 +898,13 @@ void PtrAnalysis::rewriteYieldOp( llvm_unreachable("mapped value defined by an unexpected op"); } } else { - // If this value is not a tensor of pointers, we will use the mapped - // value, and rely on the conversion will happen later automatically - // when we legalize loop body. + // If this value is not a tensor of pointers, we will use the + // mapped value, and rely on the conversion will happen later + // automatically when we legalize loop body. // TODO: - // The scenario where a value is a tensor of pointers but not produced - // by AddPtrOp is not supported + // The scenario where a value is a tensor of pointers but not + // produced by AddPtrOp is not supported if (mappedV.getType().isa() && mappedV.getType() .dyn_cast() @@ -537,8 +923,12 @@ void PtrAnalysis::rewriteYieldOp( continue; auto reintCastOp = v.getDefiningOp(); + auto unrealizedCastOp = v.getDefiningOp(); + assert( reintCastOp || + (unrealizedCastOp && + unrealizedCastOp->hasAttr(ModuloState::WraparoundAttr)) || (v.getType().isa() && v.getType().dyn_cast().getElementType().isa())); @@ -546,6 +936,9 @@ void PtrAnalysis::rewriteYieldOp( if (reintCastOp) { visitOperandReintCast(reintCastOp, state, op.getLoc(), rewriter, knownPtrs); + } else if (unrealizedCastOp) { + visitOperandUnrealizedCast(unrealizedCastOp, state, op.getLoc(), rewriter, + knownPtrs); } else { visitOperand(v, state, op.getLoc(), rewriter, knownPtrs); } @@ -557,9 +950,10 @@ void PtrAnalysis::rewriteYieldOp( // them to yield operands. for (auto state : initArgState) { for (auto s : state.offsets) { - // offsets can be IntAttr zeroes, since reinterpret_cast collapses them - // for the input memref, and the for loop may not update offsets other - // than offsets[0]. Create constants Values for those zeroes. + // offsets can be IntAttr zeroes, since reinterpret_cast collapses + // them for the input memref, and the for loop may not update + // offsets other than offsets[0]. Create constants Values for those + // zeroes. 
if (auto sIntAttr = getIntAttr(s)) { assert(sIntAttr.value() == 0 && "attribute offsets should be zeroes"); auto constOp = rewriter.create( @@ -571,13 +965,17 @@ void PtrAnalysis::rewriteYieldOp( } for (auto s : state.strides) { - assert(!getIntAttr(s) && - "PtrState strides for yield within for loop not expected to be " - "attribute."); + assert(!getIntAttr(s) && "PtrState strides for yield within for " + "loop not expected to be " + "attribute."); operands.push_back(s.get()); } } + for (auto chunk : moduloSecondChunks) { + operands.push_back(chunk); + } + // Yield is a terminator op that must be at the end of the function rewriter.setInsertionPointAfter(op); auto newOp = rewriter.replaceOpWithNewOp(op, operands); @@ -591,14 +989,64 @@ void PtrAnalysis::rewriteYieldOp( }); } +// From an unrealized_conversion_cast which takes in two reinterpret_casts +// representing two chunks, we need to get back the full pointer state. We +// cannot rebuild the original state from the two reinterpret_casts similarly to +// the normal case. To solve this, we attach the original addptr as the third +// operand to the unrealized_cast so that we can manually rebuild the state. +void PtrAnalysis::visitOperandUnrealizedCast( + UnrealizedConversionCastOp op, PtrState &state, const Location loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &knownPtrs) { + assert(op->hasAttr(ModuloState::WraparoundAttr) && + op.getInputs().size() == 3 && + op.getInputs()[0].getDefiningOp() && + op.getInputs()[1].getDefiningOp() && + op.getInputs()[2].getDefiningOp()); + + auto origPtr = op.getInputs()[2]; + if (knownPtrs.contains(origPtr)) { + state = knownPtrs.at(origPtr); + } else { + visitOperandAddptr(origPtr.getDefiningOp(), state, loc, + rewriter, knownPtrs); + } +} + +struct ModuloChunkInitArg { + Value reinterpretCast = nullptr; + // where in the init args is the first chunk placed + size_t initArgIndex = -1; +}; + void PtrAnalysis::rewriteForOp( scf::ForOp op, ConversionPatternRewriter &rewriter, IndexMapSet &levelToBlockArgIndex, const int level, llvm::SmallDenseMap &knownPtrs) { SmallVector newInitArgs; - SmallVector> initArgIndexState; - SmallVector> knownPtrsTmp; + SmallVector, 5> initArgIndexState; + SmallVector, 5> knownPtrsTmp; + + // If we have a load op that uses a modulo pointer, we need to insert both of + // the memref chunks to the init args. We reuse the sizes from the original + // memrefs. This data structure keeps track of where these additional init + // args should be inserted. + // + // As an example, if we have a 2D memrefs being split, we first put the first + // chunk in the order as it appears. Then, once all of the original init args + // are processed, we insert their offsets and strides, and finally the second + // chunk. + SmallVector, PtrState>, + 6> + moduloStates; + + // Amongst the init args, track the indices that map to the first chunk of a + // modulo pair. This is used to distinguish between the normal + // reinterpret_casts whose return types need to be rewritten to match what the + // for loop is yielding. + DenseSet moduloInitArgIndices; // Create a new list of init args for (auto [i, arg] : llvm::enumerate(op.getInitArgs())) { @@ -609,28 +1057,55 @@ void PtrAnalysis::rewriteForOp( // TypeConverters. 
if (mappedV && mappedV.getDefiningOp()) { auto castOp = mappedV.getDefiningOp(); - assert(castOp && "expected unrealized_conversion_cast"); - auto castInputs = castOp.getInputs(); - assert(castInputs.size() == 1 && - "only expect 1:1 mapping for unrealized_conversion_cast that were " - "automatically inserted during legalizing"); - mappedV = castInputs[0]; + if (!castOp->hasAttr(ModuloState::WraparoundAttr)) { + assert(castOp && "expected unrealized_conversion_cast"); + auto castInputs = castOp.getInputs(); + assert(castInputs.size() == 1 && + "only expect 1:1 mapping for unrealized_conversion_cast " + "that " + "were automatically inserted during legalizing"); + mappedV = castInputs[0]; + } } memref::ReinterpretCastOp reintCastOp; + UnrealizedConversionCastOp unrealizedCastOp; - // If this init arg is supposed to be remapped, use the remapped value - // instead. In addition, if this init arg is a memref created by a - // reinterpret_cast or a tensor of index, there is a chance that it will be - // used in addptr. Create PtrState for each such init arg. + // If this init arg is supposed to be remapped, use the remapped + // value instead. In addition, if this init arg is a memref created + // by a reinterpret_cast or a tensor of index, there is a chance that + // it will be used in addptr. Create PtrState for each such init arg. if (mappedV) { // TODO: // Passing a block argument pointer directly into a for loop not + // supported. assert(!(mappedV.dyn_cast() && mappedV.getType().isa()) && "cannot take pointer block argument as init arg for for loop"); - reintCastOp = mappedV.getDefiningOp(); - newInitArgs.push_back(mappedV); + if (auto op = mappedV.getDefiningOp()) { + reintCastOp = op; + newInitArgs.push_back(mappedV); + } else if (auto op = + mappedV.getDefiningOp()) { + assert(op->hasAttr(ModuloState::WraparoundAttr)); + unrealizedCastOp = op; + auto inputs = unrealizedCastOp.getInputs(); + assert(inputs.size() == 3); + + SmallVector initArgData{ + ModuloChunkInitArg{inputs[0], i}, + ModuloChunkInitArg{inputs[1]}, + }; + + moduloInitArgIndices.insert(i); + moduloStates.push_back( + std::make_tuple(unrealizedCastOp, initArgData, PtrState{})); + + newInitArgs.push_back(inputs[0]); + } else { + newInitArgs.push_back(mappedV); + } + } else { newInitArgs.push_back(arg); } @@ -639,15 +1114,18 @@ void PtrAnalysis::rewriteForOp( arg.getType().isa() && arg.getType().dyn_cast().getElementType().isa(); - if (!reintCastOp && !indexTensor) + if (!unrealizedCastOp && !reintCastOp && !indexTensor) continue; PtrState state; if (reintCastOp) { visitOperandReintCast(reintCastOp, state, op.getLoc(), rewriter, llvm::SmallDenseMap(0)); + } else if (unrealizedCastOp) { + visitOperandUnrealizedCast(unrealizedCastOp, state, op.getLoc(), rewriter, + llvm::SmallDenseMap(0)); + std::get<2>(moduloStates.back()) = state; } else { - // TODO: visitOperand(arg, state, op.getLoc(), rewriter, llvm::SmallDenseMap(0)); } @@ -666,8 +1144,8 @@ void PtrAnalysis::rewriteForOp( // them to init args for (auto [i, state] : initArgIndexState) { // For each dimension, if the corresponding offset and stride is an - // integer attribute, create a constant value and append them at the end - // of init arg list. + // integer attribute, create a constant value and append them at the + // end of init arg list. 
for (auto [j, s] : llvm::enumerate(state.offsets)) { auto sIntAttr = getIntAttr(s); if (sIntAttr) { @@ -692,25 +1170,31 @@ void PtrAnalysis::rewriteForOp( } } - // Note that we want the knownPtrs to be indexed by block arg, but we only - // have index for now. Also, the state we record is the init arg, but want - // to to use newly created block arg. These block args are not created yet. - // We will translate this mapping later. + // Note that we want the knownPtrs to be indexed by block arg, but we + // only have index for now. Also, the state we record is the init + // arg, but want to to use newly created block arg. These block args + // are not created yet. We will translate this mapping later. knownPtrsTmp.push_back(std::make_pair(i, state)); levelToBlockArgIndex[level].insert(i); - // If the original init arg is a memref produced by reinterpret_cast, create - // a new memref using new strides and offsets created above. This produces a - // canonicalized memref, which will match what the for loop generates if it - // modifies the memref. E.g., original reinterpret_cast can produce a memref - // with const stride: - // - memref<4x256xbf16, affine_map<(d0, d1)[s0, s1] -> (d0 * 256 + s0 + d1 + // If the original init arg is a memref produced by reinterpret_cast, + // create a new memref using new strides and offsets created above. + // This produces a canonicalized memref, which will match what the + // for loop generates if it modifies the memref. E.g., original + // reinterpret_cast can produce a memref with const stride: + // - memref<4x256xbf16, affine_map<(d0, d1)[s0, s1] -> (d0 * 256 + + // s0 + d1 // * s1)>> - // The new reinterpret_cast will always have dynamic stride and offset: - // - memref<4x256xbf16, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + - // d1 * s2)>> - if (auto reintCastOp = - newInitArgs[i].getDefiningOp()) { + // The new reinterpret_cast will always have dynamic stride and + // offset: + // - memref<4x256xbf16, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + // + s0 + d1 * s2)>> + // + // For init args that are the first chunk of a modulo pair, there is + // no need for the type to be rewritten because the strides and + // offsets are already dynamic. + if (!moduloInitArgIndices.contains(i) && + newInitArgs[i].getDefiningOp()) { SmallVector resultShape; for (auto s : state.sizes) { auto sIntAttr = getIntAttr(s); @@ -730,9 +1214,15 @@ void PtrAnalysis::rewriteForOp( } } + // Pass in the second chunk of each modulo pair + for (auto &[unrealizedCastOp, chunkData, state] : moduloStates) { + chunkData[1].initArgIndex = newInitArgs.size(); + newInitArgs.push_back(chunkData[1].reinterpretCast); + } + rewriter.restoreInsertionPoint(origIp); - // create a new scf::ForOp that uses updated init args and same loop body + // Create a new scf::ForOp that uses updated init args and same loop body auto newOp = rewriter.create( op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep(), newInitArgs, [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { @@ -740,9 +1230,53 @@ void PtrAnalysis::rewriteForOp( mapping.map(op.getInductionVar(), iv); mapping.map(op.getInitArgs(), newInitArgs); mapping.map(op.getRegionIterArgs(), args); + for (auto &bodyOp : op.getLoopBody().getOps()) { b.clone(bodyOp, mapping); } + + // Load op is lowered independent of the pointer, if we have a split + // pointer due to modulo, we need to "logically combine" these two + // memrefs into a single one using unrealized_cast_op. 
This way, when + // lowering the load, the pattern can detect if additional copies are + // inserted. When we are in a loop, it is more complicated because we + // have to insert a new unrealized_cast_op that combines the two memrefs + // in the init arg list. In addition, because init args hold no offset + // and size information, we have to manually insert two additional + // reinterpret_cast ops as input to this unrealized_cast_op so that the + // load have enough information to generate the corresponding copy. + OpBuilder::InsertionGuard g(b); + b.setInsertionPointToStart(b.getBlock()); + + Value zero = + rewriter.create(loc, rewriter.getIndexAttr(0)); + + for (auto &[unrealizedCastOp, chunkData, state] : moduloStates) { + SmallVector newReinterpretCasts; + for (auto &chunk : chunkData) { + auto initReintCast = + chunk.reinterpretCast + .getDefiningOp(); + + auto newReintCast = b.create( + loc, initReintCast.getResult().getType(), + args[chunk.initArgIndex], zero, initReintCast.getSizes(), + initReintCast.getStrides()); + + newReinterpretCasts.push_back(newReintCast); + } + + auto combinedCast = b.create( + loc, unrealizedCastOp.getResult(0).getType(), newReinterpretCasts, + unrealizedCastOp->getAttrs()); + + args[chunkData[0].initArgIndex].replaceUsesWithIf( + combinedCast.getResult(0), [](OpOperand &operand) { + assert(!isa(operand.getOwner()) && + "Storing to split pointers not supported"); + return isa(operand.getOwner()); + }); + } }); // Convert the book-keeping data structure to use the correct key and value. @@ -764,10 +1298,11 @@ void PtrAnalysis::rewriteForOp( auto key = newOp.getRegionIterArgs()[i]; knownPtrs.insert(std::make_pair(key, state)); } - assert(static_cast(cnt) == newOp.getRegionIterArgs().size() && + assert(static_cast(cnt + moduloStates.size()) == + newOp.getRegionIterArgs().size() && "expect to remap all new block args"); - // replace only the results that correspond to the original scf.for + // Replace only the results that correspond to the original scf.for auto resultsToReplaceWith = ResultRange( newOp.result_begin(), newOp.result_begin() + op.getNumResults()); rewriter.replaceOp(op, resultsToReplaceWith); @@ -777,6 +1312,8 @@ void PtrAnalysis::rewriteForOp( for (auto &bodyOp : newOp.getLoopBody().getOps()) { if (auto addptrOp = dyn_cast(bodyOp)) { rewriteAddptrOp(addptrOp, rewriter, knownPtrs); + } else if (auto advanceOp = dyn_cast(bodyOp)) { + rewriteAdvanceOp(advanceOp, rewriter, knownPtrs); } else if (auto forOp = dyn_cast(bodyOp)) { // TODO: // Nested for loops are not supported at the moment @@ -830,6 +1367,7 @@ Value PtrAnalysis::getScalarMemRef(Value ptr, Value memRef, const Location loc, state.offsets.push_back(rewriter.getIndexAttr(0)); state.sizes.push_back(rewriter.getIndexAttr(1)); state.strides.push_back(rewriter.getIndexAttr(1)); + state.modulos.push_back(std::nullopt); auto castOp = state.createCastOp(SmallVector(1, 1), loc, rewriter); return castOp.getResult(); } diff --git a/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp b/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp index d9d5678d..3cdcfb6c 100644 --- a/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp +++ b/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp @@ -7,6 +7,7 @@ #include "triton-shared/Conversion/TritonToLinalg/TritonToLinalg.h" #include "triton-shared/Analysis/MaskAnalysis.h" +#include "triton-shared/Analysis/OpFoldResultUtils.h" #include "triton-shared/Analysis/PtrAnalysis.h" #include "triton/Dialect/Triton/IR/Dialect.h" @@ -17,10 +18,16 @@ #include 
"mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Passes.h" +#include "llvm/ADT/SmallVectorExtras.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/MathExtras.h" #include +#include + +#define DEBUG_TYPE "triton-to-linalg" using namespace mlir; using namespace triton; @@ -34,18 +41,43 @@ using namespace triton; // Extract a scalar value from v. // If v is a scalar, return that directly. Otherwise, parse through operations -// (currently only support splat and sitofp) that produce it and to extract they -// underlying scalar value . If no scalar value can be extracted, a nullptr is -// returned. -static std::optional -getScalarValue(Value v, Location loc, ConversionPatternRewriter &rewriter) { - // Record if an sitofp op was in the chain of ops that produce the scalar - Operation *siToFp = nullptr; +// (currently only support splat, sitofp, and truncf) that produce it to +// extract the underlying scalar value. We then reconstruct the chain of +// operations that can produce this constant with the original type. If no +// scalar value can be extracted, a nullptr is returned. +static Value getScalarValue(Value operand, Location loc, + ConversionPatternRewriter &rewriter) { + SmallVector ops; + + auto reconstructScalarValue = [&](Value src) { + for (auto op = ops.rbegin(); op != ops.rend(); ++op) { + src = TypeSwitch(*op) + .Case([&](Operation *op) { + auto resType = op->getResults()[0].getType(); + if (auto shapedType = dyn_cast(resType)) { + resType = shapedType.getElementType(); + } + return rewriter.create(loc, resType, src); + }) + .Case([&](Operation *op) { + auto resType = op->getResults()[0].getType(); + if (auto shapedType = dyn_cast(resType)) { + resType = shapedType.getElementType(); + } + return rewriter.create(loc, resType, src); + }) + .Default([](Operation *op) { + llvm_unreachable("unsupported op in generating "); + return nullptr; + }); + } + return src; + }; while (true) { - if (!v.getType().dyn_cast()) { - break; - } else if (auto op = v.getDefiningOp()) { + if (!operand.getType().dyn_cast()) { + return reconstructScalarValue(operand); + } else if (auto op = operand.getDefiningOp()) { if (auto attr = op.getValue().dyn_cast()) { if (!attr.isSplat()) { InFlightDiagnostic diag = emitError(loc) @@ -56,13 +88,16 @@ getScalarValue(Value v, Location loc, ConversionPatternRewriter &rewriter) { auto elemValue = attr.getSplatValue(); auto constOp = arith::ConstantOp::materialize( rewriter, elemValue, attr.getElementType(), op.getLoc()); - v = constOp.getResult(); + return reconstructScalarValue(constOp.getResult()); } - } else if (auto op = v.getDefiningOp()) { - v = op.getSrc(); - } else if (auto op = v.getDefiningOp()) { - siToFp = op; - v = op.getIn(); + } else if (auto op = operand.getDefiningOp()) { + operand = op.getSrc(); + } else if (auto op = operand.getDefiningOp()) { + ops.push_back(op.getOperation()); + operand = op.getIn(); + } else if (auto op = operand.getDefiningOp()) { + ops.push_back(op.getOperation()); + operand = op.getIn(); } else { InFlightDiagnostic diag = emitError(loc) << "other value used in masked load produced " @@ -70,15 +105,7 @@ getScalarValue(Value v, Location loc, ConversionPatternRewriter &rewriter) { return nullptr; } } - - if (siToFp) { - auto resType = siToFp->getResult(0).getType(); - if (auto shapedType = dyn_cast(resType)) { - resType = shapedType.getElementType(); - } - return rewriter.create(loc, resType, v); - } - return v; + return nullptr; } 
static SmallVector getNParallelLoopsAttrs(unsigned n) { @@ -309,14 +336,102 @@ struct MakeRangeConverter : public OpConversionPattern { } }; +struct AdvanceConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(triton::AdvanceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + llvm::SmallDenseMap knownPtrs; + PtrState pointerState; + PtrAnalysis::rewriteAdvanceOp(op, rewriter, knownPtrs); + return success(); + } +}; + +struct MakeTensorPtrConverter + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + void populateVectorAsIndex(SmallVector &vec, + Operation::operand_range ops, + ConversionPatternRewriter &rewriter, + Location loc) const { + for (auto opnd : ops) { + if (opnd.getType().isa()) { + auto castOp = rewriter.create( + loc, rewriter.getIndexType(), opnd); + vec.push_back(castOp.getResult()); + } else { + assert(opnd.getType().isa()); + vec.push_back(opnd); + } + } + } + + LogicalResult + matchAndRewrite(triton::MakeTensorPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + PtrState pointerState; + + auto orderSize = op.getOrder().size(); + if (orderSize > 1) { + for (auto [first, second] : + llvm::zip(op.getOrder().slice(0, orderSize - 2), + op.getOrder().slice(1, orderSize - 1))) { + assert(first == second + 1 && + "Currently only support default order on block pointers"); + } + } + + pointerState.source = rewriter.getRemappedValue(op.getBase()); + populateVectorAsIndex(pointerState.offsets, op.getOffsets(), rewriter, loc); + populateVectorAsIndex(pointerState.strides, op.getStrides(), rewriter, loc); + + SmallVector newOffsets; + for (auto [offset, stride] : + llvm::zip(pointerState.offsets, pointerState.strides)) { + auto mulOp = rewriter.create(loc, offset.get(), + stride.get()); + newOffsets.push_back(mulOp.getResult()); + } + + pointerState.offsets.clear(); + + for (auto offset : newOffsets) { + pointerState.offsets.push_back(offset); + } + + ArrayRef resultShape; + auto pointerType = + op.getResult().getType().cast(); + if (auto shapedType = pointerType.getPointeeType().dyn_cast()) { + resultShape = shapedType.getShape(); + for (auto dim_size : resultShape) { + pointerState.sizes.push_back( + IntegerAttr::get(IntegerType::get(op.getContext(), 64), dim_size)); + } + } else { + // scalar pointer, should produce a one dimensional memref + SmallVector scalarShape(1, 1); + resultShape = scalarShape; + assert(pointerState.getRank() == 1); + } + + auto castOp = pointerState.createCastOp(resultShape, loc, rewriter); + rewriter.replaceOp(op, castOp.getResult()); + return success(); + } +}; + struct AddPtrConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(triton::AddPtrOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - llvm::SmallDenseMap knwonPtrs; - PtrAnalysis::rewriteAddptrOp(op, rewriter, knwonPtrs); + llvm::SmallDenseMap knownPtrs; + PtrAnalysis::rewriteAddptrOp(op, rewriter, knownPtrs); return success(); } }; @@ -331,7 +446,7 @@ struct AssertConverter : public OpConversionPattern { if (condVal.getType().isa()) { auto scalarVal = getScalarValue(op.getCondition(), op.getLoc(), rewriter); - condVal = scalarVal.value_or(condVal); + condVal = scalarVal ? 
scalarVal : condVal; } assert(condVal && condVal.getType().isa() && "Only asserts on scalars are currently supported"); @@ -347,8 +462,8 @@ struct AssertConverter : public OpConversionPattern { auto assertMessage = llvm::formatv("{0}.py:{1}: {2} Assertion `{3}` failed", op.getFile(), op.getLine(), op.getFunc(), op.getMessage()); - auto assertOp = rewriter.create(op.getLoc(), condVal, - assertMessage.str()); + rewriter.create(op.getLoc(), condVal, + assertMessage.str()); rewriter.eraseOp(op); return success(); @@ -373,6 +488,65 @@ struct LoadConverter : public OpConversionPattern { private: using OpConversionPattern::OpConversionPattern; + template + void createSideBySideCopies(SourceOpTy block1, SourceOpTy block2, Value dst, + Location loc, + ConversionPatternRewriter &rewriter) const { + static_assert((std::is_same() || + std::is_same()) && + "Expect source of split pointers to come from either " + "reinterpret_cast or subview ops"); + + auto zero = + rewriter.create(loc, rewriter.getIndexAttr(0)); + + auto block1Dst = rewriter.create( + loc, dst, /* offsets */ ValueRange{zero, zero}, + ofrsToIndexValues(block1.getMixedSizes(), loc, rewriter), + ofrsToIndexValues(block1.getMixedStrides(), loc, rewriter)); + + auto block2Dst = rewriter.create( + loc, dst, + /* offsets */ + ValueRange{zero, + ofrToIndexValue(block1.getMixedSizes()[1], loc, rewriter)}, + ofrsToIndexValues(block2.getMixedSizes(), loc, rewriter), + ofrsToIndexValues(block2.getMixedStrides(), loc, rewriter)); + + rewriter.create(loc, block1.getResult(), block1Dst); + rewriter.create(loc, block2.getResult(), block2Dst); + } + + template + void createStackedCopies(SourceOpTy block1, SourceOpTy block2, Value dst, + Location loc, + ConversionPatternRewriter &rewriter) const { + static_assert((std::is_same() || + std::is_same()) && + "Expect source of split pointers to come from either " + "reinterpret_cast or subview ops"); + + auto zero = + rewriter.create(loc, rewriter.getIndexAttr(0)); + + auto block1Dst = rewriter.create( + loc, dst, /* offsets */ ValueRange{zero, zero}, + ofrsToIndexValues(block1.getMixedSizes(), loc, rewriter), + ofrsToIndexValues(block1.getMixedStrides(), loc, rewriter)); + + auto block2Dst = rewriter.create( + loc, dst, + /* offsets */ + ValueRange{ofrToIndexValue(block1.getMixedSizes()[0], loc, rewriter), + zero}, + ofrsToIndexValues(block2.getMixedSizes(), loc, rewriter), + ofrsToIndexValues(block2.getMixedStrides(), loc, rewriter)); + + rewriter.create(loc, block1.getResult(), block1Dst); + rewriter.create(loc, block2.getResult(), block2Dst); + } + +public: LogicalResult matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -383,10 +557,6 @@ struct LoadConverter : public OpConversionPattern { // 0. Shortcut for scalar loads if (!op.getResult().getType().isa()) { - // Temporarily disbale scalar load until later passes support it - op.emitError("Scalar load is currently not supported"); - return failure(); - auto sMemRef = PtrAnalysis::getScalarMemRef(op.getPtr(), adaptor.getPtr(), loc, rewriter); auto zeroMap = AffineMap::getConstantMap(0, rewriter.getContext()); @@ -397,18 +567,54 @@ struct LoadConverter : public OpConversionPattern { } // 1. Simple case where no mask is used. - auto type = ptr.getType().cast(); + auto type = ptr.getType().dyn_cast(); + if (!type) { + // Seen when implicit broadcasting is done late in a chain of operations. + // The workaround is to broadcast the pointers early in the address + // calculation. 
A proper fix is complicated, but at least we can provide a + // better error message. + op.emitError("LoadOp expects a memref, not a memref of pointers"); + return failure(); + } + + DictionaryAttr attrs; auto tensorType = - RankedTensorType::get(type.getShape(), type.getElementType()); + RankedTensorType::get(type.getShape(), type.getElementType(), attrs); auto alloc = rewriter.create( - loc, MemRefType::get(type.getShape(), type.getElementType())); + loc, MemRefType::get(type.getShape(), type.getElementType(), + AffineMap(), attrs)); if (!mask) { assert(!other && "other value used in non-masked load"); - rewriter.create(loc, ptr, alloc); + if (auto unrealizedCast = + ptr.getDefiningOp()) { + if (auto wrapType = unrealizedCast->getAttrOfType( + ModuloState::WraparoundAttr)) { + + auto memrefs = unrealizedCast.getOperands(); + auto block1 = memrefs[0].getDefiningOp(); + auto block2 = memrefs[1].getDefiningOp(); + + if (wrapType.getValue().equals(ModuloState::WraparoundSideBySide)) { + createSideBySideCopies(block1, block2, alloc, loc, rewriter); + } else if (wrapType.getValue().equals( + ModuloState::WraparoundStacked)) { + createStackedCopies(block1, block2, alloc, loc, rewriter); + } else { + llvm_unreachable("unexpected wraparound type"); + } + } else { + llvm_unreachable("unexpected unrealized cast op"); + } + + } else { + rewriter.create(loc, ptr, alloc); + } + Value tensor = rewriter.create( loc, tensorType, alloc, true /* restrict */, true /* writable */); rewriter.replaceOp(op, tensor); + return success(); } @@ -418,22 +624,15 @@ struct LoadConverter : public OpConversionPattern { MaskState mstate; auto isContMask = mstate.parse(mask, loc, rewriter); - if (isContMask.failed()) - return failure(); - - auto castOp = ptr.getDefiningOp(); - assert(castOp); - ptr = castOp.getResult(); - - auto srcSubview = mstate.getSubview(ptr, loc, rewriter); - auto dstSubview = mstate.getSubview(alloc, loc, rewriter); + if (isContMask.failed()) { + return op.emitError("Cannot lower continuous masked loads"); + } // fill load destination with other value if (other) { auto scalarOther = getScalarValue(other, loc, rewriter); - assert(scalarOther.has_value() && - "other value used in masked load produced by " - "unsupported instruction"); + assert(scalarOther && "other value used in masked load produced by " + "unsupported instruction"); // For each dimension check if mstate.dims[i] < shape[i], or-accumulate // the result @@ -461,13 +660,44 @@ struct LoadConverter : public OpConversionPattern { // initialize with padding prior to CopyOp rewriter.create( loc, accBase, [&](OpBuilder &builder, Location loc) { - builder.create(loc, ValueRange{scalarOther.value()}, + builder.create(loc, ValueRange{scalarOther}, ValueRange{alloc}); builder.create(loc); }); } - rewriter.create(loc, srcSubview, dstSubview); + if (auto unrealizedCast = ptr.getDefiningOp()) { + if (auto wrapType = unrealizedCast->getAttrOfType( + ModuloState::WraparoundAttr)) { + + auto memrefs = unrealizedCast.getOperands(); + auto block1 = memrefs[0].getDefiningOp(); + auto block2 = memrefs[1].getDefiningOp(); + + if (wrapType.getValue().equals(ModuloState::WraparoundSideBySide)) { + auto [subview1, subview2] = + mstate.getSideBySideSubviews(block1, block2, loc, rewriter); + + createSideBySideCopies(subview1, subview2, alloc, loc, rewriter); + } else if (wrapType.getValue().equals(ModuloState::WraparoundStacked)) { + auto [subview1, subview2] = + mstate.getStackedSubviews(block1, block2, loc, rewriter); + + createStackedCopies(subview1, 
subview2, alloc, loc, rewriter); + } else { + llvm_unreachable("unexpected wraparound type"); + } + + } else { + llvm_unreachable("unexpected unrealized cast op"); + } + + } else { + memref::SubViewOp srcSubview = mstate.getSubview(ptr, loc, rewriter); + memref::SubViewOp dstSubview = mstate.getSubview(alloc, loc, rewriter); + rewriter.create(loc, srcSubview, dstSubview); + } + Value tensor = rewriter.create( loc, tensorType, alloc, true /* restrict */, true /* writable */); rewriter.replaceOp(op, tensor); @@ -603,28 +833,52 @@ struct ReduceConverter : public OpConversionPattern { private: llvm::SmallVector getRedOps(triton::ReduceOp redOp) const { auto reduceBlock = redOp.getBody(); - llvm::SmallVector ops; - for (auto &op : reduceBlock->without_terminator()) { - ops.push_back(&op); - } - return ops; + return llvm::map_to_vector(reduceBlock->without_terminator(), + [](Operation &op) { return &op; }); } bool isReductionOpSupported(Operation *redOp) const { - return isa(redOp); + return isa(redOp); } - float getRedBaseVal(Operation *redOp) const { - return llvm::TypeSwitch(redOp) - .Case([](arith::AddFOp) { return 0; }) - .Case([](arith::MaxFOp) { - return -std::numeric_limits::infinity(); - }) - .Default([](Operation *op) { - op->dump(); - llvm_unreachable("Reduction op not yet supported"); - return -1; - }); + arith::ConstantOp getRedBaseConstOp(ConversionPatternRewriter &rewriter, + Operation *redOp, + Type constantType) const { + const int64_t bitWidth = constantType.getIntOrFloatBitWidth(); + + auto attr = + llvm::TypeSwitch(redOp) + .Case([&](arith::AddFOp) { + return rewriter.getFloatAttr(constantType, 0.f); + }) + .Case([&](arith::MaxFOp) { + return rewriter.getFloatAttr( + constantType, -std::numeric_limits::infinity()); + }) + .Case([&](arith::MinSIOp) { + return rewriter.getIntegerAttr(constantType, + llvm::maxIntN(bitWidth)); + }) + .Case([&](arith::MinUIOp) { + return rewriter.getIntegerAttr(constantType, + llvm::maxUIntN(bitWidth)); + }) + .Case([&](arith::MaxSIOp) { + return rewriter.getIntegerAttr(constantType, + llvm::minIntN(bitWidth)); + }) + .Case([&](arith::MaxUIOp) { + return rewriter.getIntegerAttr(constantType, 0); + }) + .Default([](Operation *op) { + op->dump(); + llvm_unreachable("Reduction op not yet supported"); + return nullptr; + }); + + return rewriter.create(redOp->getLoc(), constantType, + attr); } bool requiresF32Conversion(const Type elemType, Operation *redOp) const { @@ -645,8 +899,9 @@ struct ReduceConverter : public OpConversionPattern { } return b.create(loc, lhs, rhs); }) - .Case([&](arith::MaxFOp) { - return b.create(loc, lhs, rhs); + .Case([&](auto redOp) { + return b.create(loc, lhs, rhs); }) .Default([](Operation *op) { op->dump(); @@ -672,7 +927,7 @@ struct ReduceConverter : public OpConversionPattern { if (reductionOps.size() != 1 || !isReductionOpSupported(reductionOps.front())) { return op.emitError("Only support lowering reduction with body " - "containing 1 maxf or addf."); + "containing 1 max(i/f) or addf."); } auto rop = reductionOps.front(); @@ -689,9 +944,8 @@ struct ReduceConverter : public OpConversionPattern { auto constantType = convertToF32Precision ? 
Float32Type::get(rewriter.getContext()) : elemType; - float accBaseVal = getRedBaseVal(rop); - auto accBase = rewriter.create( - loc, constantType, rewriter.getFloatAttr(constantType, accBaseVal)); + + auto accBaseConstOp = getRedBaseConstOp(rewriter, rop, constantType); Value initTensor; if (isVectorReduce) { @@ -703,13 +957,13 @@ struct ReduceConverter : public OpConversionPattern { // harder to match the patterns correctly). initTensor = rewriter.create( loc, RankedTensorType::get({}, constantType), ValueRange{}); - initTensor = rewriter.create(loc, accBase, initTensor, - ValueRange{}); + initTensor = rewriter.create(loc, accBaseConstOp, + initTensor, ValueRange{}); } else { Value init = rewriter.create( loc, cast(resType).getShape(), constantType); initTensor = rewriter - .create(loc, ValueRange{accBase}, + .create(loc, ValueRange{accBaseConstOp}, ValueRange{init}) .result(); } @@ -762,10 +1016,28 @@ struct ReduceConverter : public OpConversionPattern { } }; +// get_program_id and get_num_programs: +// When launching triton kernels, we pass 6 additional arguments to indicate +// num_programs and program_id. Amongst those six, we have 3 arguments +// correspond to each axis for num_programs followed by 3 additional arguments +// for program_id. +// +// For instance, with triton kernel example_kernel(a, b, c), we have: +// example_kernel( +// a, b, c, +// num_programs_axis_0, +// num_programs_axis_1, +// num_programs_axis_2, +// program_id_axis_0, +// program_id_axis_1, +// program_id_axis_2, +// ) +// struct GetProgramIDConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; - static auto constexpr LAUNCH_GRID_RANK = getMaxEnumValForProgramIDDim() + 1; + static uint32_t constexpr LAUNCH_GRID_RANK = + getMaxEnumValForProgramIDDim() + 1; public: GetProgramIDConverter(MLIRContext *context) : OpConversionPattern(context) {} @@ -774,7 +1046,9 @@ struct GetProgramIDConverter matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto axis = (uint32_t)op.getAxis(); - assert(axis < LAUNCH_GRID_RANK && "invalid program-id axis"); + assert(axis < LAUNCH_GRID_RANK && "program_id expects " + "axis to be either 0, " + "1, or 2"); auto func = op->getParentOfType(); auto numArgs = func.getNumArguments(); @@ -785,6 +1059,35 @@ struct GetProgramIDConverter } }; +struct GetNumProgramsConverter + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + +private: + static uint32_t constexpr LAUNCH_GRID_RANK = + getMaxEnumValForProgramIDDim() + 1; + +public: + GetNumProgramsConverter(MLIRContext *context) + : OpConversionPattern(context) {} + + LogicalResult + matchAndRewrite(triton::GetNumProgramsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto axis = (uint32_t)op.getAxis(); + assert(axis < LAUNCH_GRID_RANK && "program_id expects " + "axis to be either 0, " + "1, or 2"); + + auto func = op->getParentOfType(); + auto numArgs = func.getNumArguments(); + auto id = func.getArgument(numArgs - LAUNCH_GRID_RANK * 2 + axis); + + rewriter.replaceOp(op, id); + return success(); + } +}; + // Remove all Meta ops except for AddPtr which is handled by AddPtrConverter. // Use benefit == 10 to ensure that this pattern always takes precedence over // other patterns. @@ -818,13 +1121,14 @@ struct MetaOpConverter : public RewritePattern { // Convert a pair of cmpf and select to either min or max. 
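+// For example, the float variant rewrites
+//   %0 = arith.cmpf ogt, %a, %b : f32
+//   %1 = arith.select %0, %a, %b : f32
+// into %1 = arith.maxf %a, %b : f32, and the integer variant maps the
+// sgt/ugt/slt/ult predicates to maxsi/maxui/minsi/minui respectively.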
// Leave the pattern as simple as possible because triton has plans to emit // min and max directly. -struct MinMaxConverter : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +template +struct MinMaxConverter : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; MinMaxConverter(MLIRContext *context) - : OpRewritePattern(context, /*benefit=*/10) {} + : OpRewritePattern(context, /*benefit=*/10) {} - LogicalResult matchAndRewrite(arith::CmpFOp cmpOp, + LogicalResult matchAndRewrite(CmpOp cmpOp, PatternRewriter &rewriter) const final { if (!cmpOp.getResult().hasOneUse()) { return failure(); @@ -841,21 +1145,52 @@ struct MinMaxConverter : public OpRewritePattern { return failure(); } - auto pred = cmpOp.getPredicate(); - auto loc = cmpOp.getLoc(); - if (pred == arith::CmpFPredicate::OGT) { + rewriteOpWithMinMax(rewriter, cmpOp, selectOp, cmpOp.getPredicate()); + rewriter.eraseOp(cmpOp); + + return success(); + } + + void rewriteOpWithMinMax(PatternRewriter &rewriter, arith::CmpFOp cmpOp, + arith::SelectOp selectOp, + arith::CmpFPredicate pred) const { + switch (pred) { + case arith::CmpFPredicate::OGT: rewriter.replaceOpWithNewOp(selectOp, cmpOp.getLhs(), cmpOp.getRhs()); - } else if (pred == arith::CmpFPredicate::OLT) { + break; + case arith::CmpFPredicate::OLT: rewriter.replaceOpWithNewOp(selectOp, cmpOp.getLhs(), cmpOp.getRhs()); - } else { + break; + default: llvm_unreachable("Unhandled predicate"); } + } - rewriter.eraseOp(cmpOp); - - return success(); + void rewriteOpWithMinMax(PatternRewriter &rewriter, arith::CmpIOp cmpOp, + arith::SelectOp selectOp, + arith::CmpIPredicate pred) const { + switch (pred) { + case arith::CmpIPredicate::sgt: + rewriter.replaceOpWithNewOp(selectOp, cmpOp.getLhs(), + cmpOp.getRhs()); + break; + case arith::CmpIPredicate::ugt: + rewriter.replaceOpWithNewOp(selectOp, cmpOp.getLhs(), + cmpOp.getRhs()); + break; + case arith::CmpIPredicate::slt: + rewriter.replaceOpWithNewOp(selectOp, cmpOp.getLhs(), + cmpOp.getRhs()); + break; + case arith::CmpIPredicate::ult: + rewriter.replaceOpWithNewOp(selectOp, cmpOp.getLhs(), + cmpOp.getRhs()); + break; + default: + llvm_unreachable("Unhandled predicate"); + } } }; @@ -881,21 +1216,37 @@ struct DenseConstantConverter : public OpConversionPattern { } }; +struct UnrealizedCastConverter + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.eraseOp(op); + return success(); + } +}; + } // namespace void mlir::triton::populateTritonToLinalgCanonicalizationPatterns( RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); + patterns.add, MinMaxConverter>( + patterns.getContext()); } void mlir::triton::populateTritonToLinalgConversionPatterns( - TypeConverter &typeConverter, RewritePatternSet &patterns) { + TypeConverter &typeConverter, RewritePatternSet &patterns, + unsigned int launchGridRank) { populateFunctionOpInterfaceTypeConversionPattern( patterns, typeConverter); patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add( + patterns.getContext()); patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); @@ -909,6 +1260,7 @@ void 
mlir::triton::populateTritonToLinalgConversionPatterns( patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); // Note: the ordering here matters! // MetaOpConverter has PatternBenefit == 10 which should take precedence over diff --git a/lib/Conversion/TritonToLinalg/TritonToLinalgPass.cpp b/lib/Conversion/TritonToLinalg/TritonToLinalgPass.cpp index cfa6ff87..0b2b1411 100644 --- a/lib/Conversion/TritonToLinalg/TritonToLinalgPass.cpp +++ b/lib/Conversion/TritonToLinalg/TritonToLinalgPass.cpp @@ -50,18 +50,20 @@ class TritonTypeConverter : public TypeConverter { struct TritonToLinalgPass : public TritonToLinalgBase { - static uint32_t constexpr LAUNCH_GRID_RANK = - getMaxEnumValForProgramIDDim() + 1; - - // Add additional I32 arguments to represent program - // ID, one for each dimension of the launch grid - static void addProgramId(triton::FuncOp func) { + static auto constexpr LAUNCH_GRID_RANK = getMaxEnumValForProgramIDDim() + 1; + static unsigned int constexpr TRITON_PROGRAM_INFO_ARG_COUNT = + LAUNCH_GRID_RANK * 2; + + // Add additional I32 arguments to represent: + // - num_programs, 3 in total, one for each axis of the launch grid + // - program_id, 3 in total, one for each axis of the launch grid + static void addProgramInfo(triton::FuncOp func) { OpBuilder b(func); auto origFuncType = func.getFunctionType(); auto origInputTypes = origFuncType.getInputs(); SmallVector newInputTypes(origInputTypes); - newInputTypes.append(LAUNCH_GRID_RANK, b.getI32Type()); + newInputTypes.append(TRITON_PROGRAM_INFO_ARG_COUNT, b.getI32Type()); auto newFuncType = b.getFunctionType(newInputTypes, origFuncType.getResults()); @@ -72,12 +74,12 @@ struct TritonToLinalgPass : public TritonToLinalgBase { if (func.getAllArgAttrs()) { SmallVector newArgAttrs; func.getAllArgAttrs(newArgAttrs); - newArgAttrs.append(LAUNCH_GRID_RANK, DictionaryAttr()); + newArgAttrs.append(TRITON_PROGRAM_INFO_ARG_COUNT, DictionaryAttr()); func.setAllArgAttrs(newArgAttrs); } // Add the corresponding arguments to function body - for (unsigned int i = 0; i < LAUNCH_GRID_RANK; i++) { + for (unsigned int i = 0; i < TRITON_PROGRAM_INFO_ARG_COUNT; i++) { func.getBody().front().addArgument(b.getI32Type(), func.getLoc()); } } @@ -112,11 +114,12 @@ struct TritonToLinalgPass : public TritonToLinalgBase { ConversionTarget target(getContext()); TritonTypeConverter tritonTypeConverter; - target.addLegalDialect< - func::FuncDialect, arith::ArithDialect, math::MathDialect, - linalg::LinalgDialect, affine::AffineDialect, scf::SCFDialect, - cf::ControlFlowDialect, tensor::TensorDialect, - bufferization::BufferizationDialect, memref::MemRefDialect>(); + target.addLegalDialect(); target.addLegalOp(); @@ -182,11 +185,11 @@ struct TritonToLinalgPass : public TritonToLinalgBase { return !operateOnTensors; }); - triton::populateTritonToLinalgConversionPatterns(tritonTypeConverter, - patterns); + triton::populateTritonToLinalgConversionPatterns( + tritonTypeConverter, patterns, LAUNCH_GRID_RANK); for (auto func : getOperation().getOps()) - addProgramId(func); + addProgramInfo(func); if (failed(applyFullConversion(moduleOp, target, std::move(patterns)))) signalPassFailure(); diff --git a/test/Conversion/TritonToLinalg/addptr_2d_example.mlir b/test/Conversion/TritonToLinalg/addptr_2d_example.mlir index c7fc6322..b2e57dd8 100644 --- a/test/Conversion/TritonToLinalg/addptr_2d_example.mlir +++ b/test/Conversion/TritonToLinalg/addptr_2d_example.mlir @@ -45,7 +45,7 @@ 
module { } } // CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: memref<*xbf16>, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32) { +// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: memref<*xbf16>, %[[VAL_3:.*]]: i32, %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32) { // CHECK: %[[VAL_7:.*]] = arith.constant 5 : index // CHECK: %[[VAL_8:.*]] = arith.index_cast %[[VAL_3]] : i32 to index // CHECK: %[[VAL_9:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_8]]], sizes: [4, 256], strides: [1, %[[VAL_7]]] : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>> diff --git a/test/Conversion/TritonToLinalg/addptr_add_value.mlir b/test/Conversion/TritonToLinalg/addptr_add_value.mlir index 1a6121a2..918c6295 100644 --- a/test/Conversion/TritonToLinalg/addptr_add_value.mlir +++ b/test/Conversion/TritonToLinalg/addptr_add_value.mlir @@ -47,7 +47,7 @@ module { } } // CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32) { +// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32) { // CHECK-DAG: %[[VAL_7:.*]] = arith.constant 6 : index // CHECK-DAG: %[[VAL_8:.*]] = arith.constant 10 : index // CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_2]] : i32 to index diff --git a/test/Conversion/TritonToLinalg/addptr_dim1.mlir b/test/Conversion/TritonToLinalg/addptr_dim1.mlir new file mode 100644 index 00000000..f818d657 --- /dev/null +++ b/test/Conversion/TritonToLinalg/addptr_dim1.mlir @@ -0,0 +1,107 @@ +// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s + +module { + tt.func @kernel( + %arg0 : !tt.ptr, + %arg1 : i32 + ) + { + %0 = tt.make_range {end = 256 : i32, start = 0 : i32}:tensor<256xi32> + %1 = tt.expand_dims %0 {axis = 0 : i32} : (tensor<256xi32>) -> tensor<1x256xi32> + + %splat_arg0 = tt.splat %arg0 : (!tt.ptr) -> tensor<1x256x!tt.ptr> + %2 = tt.addptr %splat_arg0, %1 : tensor<1x256x!tt.ptr>, tensor<1x256xi32> + + // 1x256 pointer should have meaningful stride in outer dimension + %3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<1x256xbf16> + + %4 = tt.splat %arg1 : (i32) -> tensor<1x256xi32> + // 1x256 pointer should have meaningful stride in outer dimension + %5 = tt.addptr %2, %4 : tensor<1x256x!tt.ptr>, tensor<1x256xi32> + tt.store %5, %3 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1x256x!tt.ptr>, tensor<1x256xbf16> + + %10 = arith.constant 0.0 : bf16 + %11 = tt.splat %10 : (bf16) -> tensor<4x256xbf16> + + %c0 = arith.constant 0 : index + %c12 = arith.constant 12 : index + %c3 = arith.constant 3 : index + %i_c3 = arith.constant 3 : i32 + %c256 = arith.constant 256 : i32 + %sum_out, %_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%sum_iter = %11, %ptr = %2) -> (tensor<4x256xbf16>, tensor<1x256x!tt.ptr>) { + %bptr = tt.broadcast %ptr : (tensor<1x256x!tt.ptr>) -> tensor<4x256x!tt.ptr> + + %20 = tt.make_range {end = 4 : i32, start = 0 : i32}:tensor<4xi32> + %i_i32 = arith.index_cast %i : index to i32 + %21 = arith.muli %c256, %i_i32 : i32 + %22 = tt.splat %21 : 
(i32) -> tensor<4xi32> + %23 = arith.muli %20, %22 : tensor<4xi32> + %24 = tt.expand_dims %23 {axis = 1 : i32} : (tensor<4xi32>) -> tensor<4x1xi32> + %25 = tt.broadcast %24 : (tensor<4x1xi32>) -> tensor<4x256xi32> + + // %bptr should have zero stride and %30 should have correct stride + %30 = tt.addptr %bptr, %25 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> + %31 = tt.load %30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<4x256xbf16> + %32 = arith.addf %sum_iter, %31 : tensor<4x256xbf16> + + %40 = tt.splat %c256 : (i32) -> tensor<1x256xi32> + %41 = tt.addptr %ptr, %40 : tensor<1x256x!tt.ptr>, tensor<1x256xi32> + + scf.yield %32, %41 : tensor<4x256xbf16>, tensor<1x256x!tt.ptr> + } + + %31 = tt.make_range {end = 4 : i32, start = 0 : i32}:tensor<4xi32> + %splat_c256 = tt.splat %c256 : (i32) -> tensor<4xi32> + %32 = arith.muli %31, %splat_c256 : tensor<4xi32> + %33 = tt.expand_dims %32 {axis = 1 : i32} : (tensor<4xi32>) -> tensor<4x1xi32> + %34 = tt.broadcast %33 : (tensor<4x1xi32>) -> tensor<4x256xi32> + %35 = tt.broadcast %2 : (tensor<1x256x!tt.ptr>) -> tensor<4x256x!tt.ptr> + %36 = tt.addptr %35, %34 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> + tt.store %36, %sum_out {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<4x256x!tt.ptr>, tensor<4x256xbf16> + tt.return + } +} + +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func.func @kernel +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xbf16>, [[PARAM_1_:%.+]]: i32, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) { +// CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index +// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index +// CHECK-DAG: [[CST_256_1_:%.+]] = arith.constant 256 : i32 +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : bf16 +// CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<4x256xbf16> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : bf16) outs([[VAR_0_]] : tensor<4x256xbf16>) -> tensor<4x256xbf16> +// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [1, 256], strides: [256, 1] : memref<*xbf16> to memref<1x256xbf16, strided<[256, 1]>> +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<1x256xbf16> +// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<1x256xbf16, strided<[256, 1]>> to memref<1x256xbf16> +// CHECK-DAG: [[VAR_2_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<1x256xbf16> +// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_1_]] : i32 to index +// CHECK: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_3_]]{{.}}, sizes: [1, 256], strides: [256, 1] : memref<*xbf16> to memref<1x256xbf16, strided<[256, 1], offset: ?>> +// CHECK: memref.tensor_store [[VAR_2_]], [[VAR_reinterpret_cast_0_]] : memref<1x256xbf16, strided<[256, 1], offset: ?>> +// CHECK-DAG: [[VAR_4_:%.+]]:3 = scf.for [[VAR_arg5_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg6_:%.+]] = [[VAR_1_]], [[VAR_arg7_:%.+]] = [[CST_0_]], [[VAR_arg8_:%.+]] = [[CST_0_]]) -> (tensor<4x256xbf16>, index, index) { +// CHECK-DAG: [[VAR_5_:%.+]] = arith.index_cast [[VAR_arg5_]] : index to i32 +// 
CHECK: [[VAR_6_:%.+]] = arith.muli [[VAR_5_]], [[CST_256_1_]] : i32 +// CHECK-DAG: [[VAR_7_:%.+]] = arith.index_cast [[VAR_6_]] : i32 to index +// CHECK-DAG: [[VAR_8_:%.+]] = arith.addi [[VAR_arg7_]], [[VAR_arg8_]] : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_reinterpret_cast_2_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_8_]]{{.}}, sizes: [4, 256], strides: {{.}}[[VAR_7_]], [[CST_1_]]{{.}} : memref<*xbf16> to memref<4x256xbf16, strided<[?, ?], offset: ?>> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<4x256xbf16> +// CHECK: memref.copy [[VAR_reinterpret_cast_2_]], [[RES_1_]] : memref<4x256xbf16, strided<[?, ?], offset: ?>> to memref<4x256xbf16> +// CHECK: [[VAR_9_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<4x256xbf16> +// CHECK: [[VAR_10_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_arg6_]], [[VAR_9_]] : tensor<4x256xbf16>, tensor<4x256xbf16>) outs([[VAR_arg6_]] : tensor<4x256xbf16>) { +// CHECK: ^bb0([[in1:%.+]]: bf16, [[in2:%.+]]: bf16, [[out:%.+]]: bf16): +// CHECK: [[VAR_13_:%.+]] = arith.addf [[in1]], [[in2]] : bf16 +// CHECK: linalg.yield [[VAR_13_]] : bf16 +// CHECK: } -> tensor<4x256xbf16> +// CHECK: [[VAR_11_:%.+]] = arith.addi [[VAR_arg7_]], [[CST_256_]] : index +// CHECK: [[VAR_12_:%.+]] = arith.addi [[VAR_11_]], [[VAR_arg8_]] : index +// CHECK: scf.yield [[VAR_10_]], [[VAR_12_]], [[CST_0_]] : tensor<4x256xbf16>, index, index +// CHECK: } +// CHECK: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [4, 256], strides: {{.}}[[CST_256_]], 1] : memref<*xbf16> to memref<4x256xbf16, strided<[?, 1]>> +// CHECK: memref.tensor_store [[VAR_4_]]#0, [[VAR_reinterpret_cast_1_]] : memref<4x256xbf16, strided<[?, 1]>> +// CHECK: return +// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_for_accumulation.mlir b/test/Conversion/TritonToLinalg/addptr_for_accumulation.mlir index 95b2daae..abaf233b 100644 --- a/test/Conversion/TritonToLinalg/addptr_for_accumulation.mlir +++ b/test/Conversion/TritonToLinalg/addptr_for_accumulation.mlir @@ -59,7 +59,7 @@ module { } } // CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: memref<*xbf16>, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32, %[[VAL_7:.*]]: i32) { +// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: memref<*xbf16>, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32, %[[ARG_1O:.*]]: i32) { // CHECK-DAG: %[[VAL_8:.*]] = arith.constant 5 : index // CHECK-DAG: %[[VAL_9:.*]] = arith.constant 1 : index // CHECK-DAG: %[[VAL_10:.*]] = arith.constant 3 : index diff --git a/test/Conversion/TritonToLinalg/addptr_loopback.mlir b/test/Conversion/TritonToLinalg/addptr_loopback.mlir index 0231fc19..e3d748a4 100644 --- a/test/Conversion/TritonToLinalg/addptr_loopback.mlir +++ b/test/Conversion/TritonToLinalg/addptr_loopback.mlir @@ -39,7 +39,7 @@ module { } } // CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32) { +// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[ARG_3:.*]]: i32, %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: 
i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32) { // CHECK: %[[VAL_6:.*]] = arith.constant 6 : index // CHECK: %[[VAL_7:.*]] = arith.index_cast %[[VAL_2]] : i32 to index // CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_7]]], sizes: [4, 256], strides: [1, %[[VAL_6]]] : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>> diff --git a/test/Conversion/TritonToLinalg/addptr_mul_value_const.mlir b/test/Conversion/TritonToLinalg/addptr_mul_value_const.mlir index 9eb02f69..99812799 100644 --- a/test/Conversion/TritonToLinalg/addptr_mul_value_const.mlir +++ b/test/Conversion/TritonToLinalg/addptr_mul_value_const.mlir @@ -32,18 +32,18 @@ module { } } // CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32) { +// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[ARG_3:.*]]: i32, %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32) { // CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1 : index // CHECK-DAG: %[[VAL_7:.*]] = arith.constant 2 : index // CHECK-DAG: %[[VAL_8:.*]] = arith.constant 2048 : index -// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_3]] : i32 to index +// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[ARG_6]] : i32 to index // CHECK: %[[VAL_10:.*]] = arith.index_cast %[[VAL_2]] : i32 to index // CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_10]], %[[VAL_8]] : index // CHECK: %[[VAL_12:.*]] = arith.muli %[[VAL_10]], %[[VAL_7]] : index // CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_9]], %[[VAL_11]] : index // CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_6]] : index // CHECK: %[[VAL_15:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_13]]], sizes: [1024], strides: {{\[}}%[[VAL_14]]] : memref<*xbf16> to memref<1024xbf16, strided<[?], offset: ?>> -// CHECK: %[[VAL_16:.*]] = arith.index_cast %[[VAL_3]] : i32 to index +// CHECK: %[[VAL_16:.*]] = arith.index_cast %[[ARG_6]] : i32 to index // CHECK: %[[VAL_17:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_16]]], sizes: [1024], strides: [1] : memref<*xbf16> to memref<1024xbf16, strided<[1], offset: ?>> // CHECK: %[[VAL_18:.*]] = memref.alloc() : memref<1024xbf16> // CHECK: memref.copy %[[VAL_15]], %[[VAL_18]] : memref<1024xbf16, strided<[?], offset: ?>> to memref<1024xbf16> diff --git a/test/Conversion/TritonToLinalg/addptr_nested.mlir b/test/Conversion/TritonToLinalg/addptr_nested.mlir index f2bb18a5..c4c2ec9d 100644 --- a/test/Conversion/TritonToLinalg/addptr_nested.mlir +++ b/test/Conversion/TritonToLinalg/addptr_nested.mlir @@ -41,7 +41,7 @@ module { } } // CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) { +// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: i32, %[[ARG_2:.*]]: i32, %[[ARG_3:.*]]: i32, %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32) { // CHECK-DAG: %[[VAL_5:.*]] = arith.constant 15 : index // CHECK-DAG: %[[VAL_6:.*]] = arith.constant 5 : index // CHECK-DAG: %[[VAL_7:.*]] = arith.constant 10 : index diff --git a/test/Conversion/TritonToLinalg/addptr_scalar_broadcast.mlir b/test/Conversion/TritonToLinalg/addptr_scalar_broadcast.mlir index fb257294..19d030ef 100644 --- a/test/Conversion/TritonToLinalg/addptr_scalar_broadcast.mlir +++ 
b/test/Conversion/TritonToLinalg/addptr_scalar_broadcast.mlir @@ -45,8 +45,8 @@ module { } } // CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32, %[[VAL_7:.*]]: i32) { -// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_5]], %[[VAL_2]] : i32 +// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32, %[[ARG_10:.*]]: i32) { +// CHECK: %[[VAL_8:.*]] = arith.muli %[[ARG_8]], %[[VAL_2]] : i32 // CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_8]] : i32 to index // CHECK: %[[VAL_10:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_9]]], sizes: [1024, 1024], strides: [2, 1] : memref<*xf32> to memref<1024x1024xf32, strided<[2, 1], offset: ?>> // CHECK: %[[VAL_11:.*]] = memref.alloc() : memref<1024x1024xf32> @@ -57,7 +57,7 @@ module { // CHECK: %[[VAL_16:.*]] = math.exp %[[VAL_14]] : f32 // CHECK: linalg.yield %[[VAL_16]] : f32 // CHECK: } -> tensor<1024x1024xf32> -// CHECK: %[[VAL_17:.*]] = arith.muli %[[VAL_5]], %[[VAL_3]] : i32 +// CHECK: %[[VAL_17:.*]] = arith.muli %[[ARG_8]], %[[VAL_3]] : i32 // CHECK: %[[VAL_18:.*]] = arith.index_cast %[[VAL_17]] : i32 to index // CHECK: %[[VAL_19:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_18]]], sizes: [1024, 1024], strides: [2, 1] : memref<*xf32> to memref<1024x1024xf32, strided<[2, 1], offset: ?>> // CHECK: memref.tensor_store %[[VAL_20:.*]], %[[VAL_19]] : memref<1024x1024xf32, strided<[2, 1], offset: ?>> diff --git a/test/Conversion/TritonToLinalg/addptr_scalar_for.mlir b/test/Conversion/TritonToLinalg/addptr_scalar_for.mlir index 5a89470f..3c888240 100644 --- a/test/Conversion/TritonToLinalg/addptr_scalar_for.mlir +++ b/test/Conversion/TritonToLinalg/addptr_scalar_for.mlir @@ -35,14 +35,14 @@ module { } } // CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32, %[[VAL_7:.*]]: i32) { +// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32, %[[ARG_10:.*]]: i32) { // CHECK-DAG: %[[VAL_8:.*]] = arith.constant 3 : index // CHECK-DAG: %[[VAL_9:.*]] = arith.constant 12 : index // CHECK-DAG: %[[VAL_10:.*]] = arith.constant 0 : index // CHECK-DAG: %[[VAL_11:.*]] = arith.constant 0.000000e+00 : f32 // CHECK: %[[VAL_14:.*]] = tensor.empty() : tensor<1024xf32> // CHECK: %[[VAL_15:.*]] = linalg.fill ins(%[[VAL_11]] : f32) outs(%[[VAL_14]] : tensor<1024xf32>) -> tensor<1024xf32> -// CHECK: %[[VAL_12:.*]] = arith.muli %[[VAL_5]], %[[VAL_2]] : i32 +// CHECK: %[[VAL_12:.*]] = arith.muli %[[ARG_8]], %[[VAL_2]] : i32 // CHECK: %[[VAL_13:.*]] = arith.index_cast %[[VAL_12]] : i32 to index // CHECK: %[[VAL_16:.*]]:2 = scf.for %[[VAL_17:.*]] = %[[VAL_10]] to %[[VAL_9]] step %[[VAL_8]] iter_args(%[[VAL_18:.*]] = %[[VAL_15]], %[[VAL_19:.*]] = %[[VAL_13]]) -> 
(tensor<1024xf32>, index) { // CHECK: %[[VAL_20:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_19]]], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>> @@ -62,7 +62,7 @@ module { // CHECK: %[[VAL_33:.*]] = arith.addi %[[VAL_19]], %[[VAL_17]] : index // CHECK: scf.yield %[[VAL_34:.*]], %[[VAL_33]] : tensor<1024xf32>, index // CHECK: } -// CHECK: %[[VAL_35:.*]] = arith.muli %[[VAL_5]], %[[VAL_3]] : i32 +// CHECK: %[[VAL_35:.*]] = arith.muli %[[ARG_8]], %[[VAL_3]] : i32 // CHECK: %[[VAL_36:.*]] = arith.index_cast %[[VAL_35]] : i32 to index // CHECK: %[[VAL_37:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_36]]], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>> // CHECK: memref.tensor_store %[[VAL_38:.*]]#0, %[[VAL_37]] : memref<1024xf32, strided<[1], offset: ?>> diff --git a/test/Conversion/TritonToLinalg/addptr_scalar_for_2d.mlir b/test/Conversion/TritonToLinalg/addptr_scalar_for_2d.mlir index f95bc782..af159e08 100644 --- a/test/Conversion/TritonToLinalg/addptr_scalar_for_2d.mlir +++ b/test/Conversion/TritonToLinalg/addptr_scalar_for_2d.mlir @@ -54,7 +54,7 @@ module { } } // CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32, %[[VAL_7:.*]]: i32) { +// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32, %[[ARG_10:.*]]: i32) { // CHECK-DAG: %[[VAL_8:.*]] = arith.constant 128 : index // CHECK-DAG: %[[VAL_9:.*]] = arith.constant 3 : index // CHECK-DAG: %[[VAL_10:.*]] = arith.constant 12 : index @@ -62,7 +62,7 @@ module { // CHECK-DAG: %[[VAL_12:.*]] = arith.constant 0.000000e+00 : f32 // CHECK: %[[VAL_15:.*]] = tensor.empty() : tensor<128x128xf32> // CHECK: %[[VAL_16:.*]] = linalg.fill ins(%[[VAL_12]] : f32) outs(%[[VAL_15]] : tensor<128x128xf32>) -> tensor<128x128xf32> -// CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_5]], %[[VAL_2]] : i32 +// CHECK: %[[VAL_13:.*]] = arith.muli %[[ARG_8]], %[[VAL_2]] : i32 // CHECK: %[[VAL_14:.*]] = arith.index_cast %[[VAL_13]] : i32 to index // CHECK: %[[VAL_17:.*]]:2 = scf.for %[[VAL_18:.*]] = %[[VAL_11]] to %[[VAL_10]] step %[[VAL_9]] iter_args(%[[VAL_19:.*]] = %[[VAL_16]], %[[VAL_20:.*]] = %[[VAL_14]]) -> (tensor<128x128xf32>, index) { // CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_20]], %[[VAL_8]] : index @@ -83,7 +83,7 @@ module { // CHECK: %[[VAL_35:.*]] = arith.addi %[[VAL_20]], %[[VAL_18]] : index // CHECK: scf.yield %[[VAL_36:.*]], %[[VAL_35]] : tensor<128x128xf32>, index // CHECK: } -// CHECK: %[[VAL_37:.*]] = arith.muli %[[VAL_5]], %[[VAL_3]] : i32 +// CHECK: %[[VAL_37:.*]] = arith.muli %[[ARG_8]], %[[VAL_3]] : i32 // CHECK: %[[VAL_38:.*]] = arith.index_cast %[[VAL_37]] : i32 to index // CHECK: %[[VAL_39:.*]] = arith.addi %[[VAL_38]], %[[VAL_8]] : index // CHECK: %[[VAL_40:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_39]]], sizes: [128, 128], strides: [2, 1] : memref<*xf32> to memref<128x128xf32, strided<[2, 1], offset: ?>> diff --git a/test/Conversion/TritonToLinalg/addptr_scalar_loopback.mlir b/test/Conversion/TritonToLinalg/addptr_scalar_loopback.mlir index 
1d041d73..e5a5803f 100644 --- a/test/Conversion/TritonToLinalg/addptr_scalar_loopback.mlir +++ b/test/Conversion/TritonToLinalg/addptr_scalar_loopback.mlir @@ -1,20 +1,27 @@ -// RUN: triton-shared-opt --triton-to-linalg %s -// XFAIL: * -// Disable this test since we do not support scalar loads at the moment. +// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s module { tt.func @kernel( %arg0 : !tt.ptr, %arg1 : !tt.ptr, %arg2 : i32 - ) - { + ) { %0 = tt.addptr %arg0, %arg2 : !tt.ptr, i32 %1 = tt.addptr %arg1, %arg2 : !tt.ptr, i32 - - // expected-error @below {{Scalar load is currently not supported}} %10 = tt.load %0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: bf16 tt.store %1, %10 : bf16 - tt.return + tt.return } } + +// CHECK: module { +// CHECK: func.func @kernel(%arg0: memref<*xbf16>, %arg1: memref<*xbf16>, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) { +// CHECK: %0 = arith.index_cast %arg2 : i32 to index +// CHECK: %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%0], sizes: [1], strides: [1] : memref<*xbf16> to memref<1xbf16, strided<[1], offset: ?>> +// CHECK: %1 = arith.index_cast %arg2 : i32 to index +// CHECK: %reinterpret_cast_0 = memref.reinterpret_cast %arg1 to offset: [%1], sizes: [1], strides: [1] : memref<*xbf16> to memref<1xbf16, strided<[1], offset: ?>> +// CHECK: %2 = affine.load %reinterpret_cast[0] : memref<1xbf16, strided<[1], offset: ?>> +// CHECK: affine.store %2, %reinterpret_cast_0[0] : memref<1xbf16, strided<[1], offset: ?>> +// CHECK: return +// CHECK: } +// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_scalar_nested.mlir b/test/Conversion/TritonToLinalg/addptr_scalar_nested.mlir index ffd4fba2..b2434398 100644 --- a/test/Conversion/TritonToLinalg/addptr_scalar_nested.mlir +++ b/test/Conversion/TritonToLinalg/addptr_scalar_nested.mlir @@ -31,10 +31,10 @@ module { } } // CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32, %[[VAL_7:.*]]: i32) { -// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_5]], %[[VAL_2]] : i32 -// CHECK: %[[VAL_9:.*]] = arith.muli %[[VAL_5]], %[[VAL_3]] : i32 -// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_5]], %[[VAL_4]] : i32 +// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32, %[[ARG_10:.*]]: i32) { +// CHECK: %[[VAL_8:.*]] = arith.muli %[[ARG_8]], %[[VAL_2]] : i32 +// CHECK: %[[VAL_9:.*]] = arith.muli %[[ARG_8]], %[[VAL_3]] : i32 +// CHECK: %[[VAL_10:.*]] = arith.muli %[[ARG_8]], %[[VAL_4]] : i32 // CHECK: %[[VAL_11:.*]] = arith.index_cast %[[VAL_8]] : i32 to index // CHECK: %[[VAL_12:.*]] = arith.index_cast %[[VAL_9]] : i32 to index // CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : index @@ -49,7 +49,7 @@ module { // CHECK: %[[VAL_22:.*]] = math.exp %[[VAL_20]] : f32 // CHECK: linalg.yield %[[VAL_22]] : f32 // CHECK: } -> tensor<1024xf32> -// CHECK: %[[VAL_23:.*]] = arith.muli %[[VAL_5]], %[[VAL_3]] : i32 +// CHECK: %[[VAL_23:.*]] = arith.muli %[[ARG_8]], %[[VAL_3]] : i32 // CHECK: %[[VAL_24:.*]] = arith.index_cast %[[VAL_23]] : i32 to index // CHECK: %[[VAL_25:.*]] = memref.reinterpret_cast 
%[[VAL_0]] to offset: {{\[}}%[[VAL_24]]], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>> // CHECK: memref.tensor_store %[[VAL_26:.*]], %[[VAL_25]] : memref<1024xf32, strided<[1], offset: ?>> diff --git a/test/Conversion/TritonToLinalg/addptr_scalar_splat.mlir b/test/Conversion/TritonToLinalg/addptr_scalar_splat.mlir index def681e2..d90c4e65 100644 --- a/test/Conversion/TritonToLinalg/addptr_scalar_splat.mlir +++ b/test/Conversion/TritonToLinalg/addptr_scalar_splat.mlir @@ -25,8 +25,8 @@ module { } } // CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32, %[[VAL_7:.*]]: i32) { -// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_5]], %[[VAL_2]] : i32 +// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32, %[[ARG_10:.*]]: i32) { +// CHECK: %[[VAL_8:.*]] = arith.muli %[[ARG_8]], %[[VAL_2]] : i32 // CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_8]] : i32 to index // CHECK: %[[VAL_10:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_9]]], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>> // CHECK: %[[VAL_11:.*]] = memref.alloc() : memref<1024xf32> @@ -37,7 +37,7 @@ module { // CHECK: %[[VAL_16:.*]] = math.exp %[[VAL_14]] : f32 // CHECK: linalg.yield %[[VAL_16]] : f32 // CHECK: } -> tensor<1024xf32> -// CHECK: %[[VAL_17:.*]] = arith.muli %[[VAL_5]], %[[VAL_3]] : i32 +// CHECK: %[[VAL_17:.*]] = arith.muli %[[ARG_8]], %[[VAL_3]] : i32 // CHECK: %[[VAL_18:.*]] = arith.index_cast %[[VAL_17]] : i32 to index // CHECK: %[[VAL_19:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_18]]], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>> // CHECK: memref.tensor_store %[[VAL_20:.*]], %[[VAL_19]] : memref<1024xf32, strided<[1], offset: ?>> diff --git a/test/Conversion/TritonToLinalg/addptr_scalar_splat_2d.mlir b/test/Conversion/TritonToLinalg/addptr_scalar_splat_2d.mlir index 88cde351..4275e28b 100644 --- a/test/Conversion/TritonToLinalg/addptr_scalar_splat_2d.mlir +++ b/test/Conversion/TritonToLinalg/addptr_scalar_splat_2d.mlir @@ -33,9 +33,9 @@ module { } } // CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32, %[[VAL_7:.*]]: i32) { +// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32, %[[ARG_10:.*]]: i32) { // CHECK: %[[VAL_8:.*]] = arith.constant 128 : index -// CHECK: %[[VAL_9:.*]] = arith.muli %[[VAL_5]], %[[VAL_2]] : i32 +// CHECK: %[[VAL_9:.*]] = arith.muli %[[ARG_8]], %[[VAL_2]] : i32 // CHECK: %[[VAL_10:.*]] = arith.index_cast %[[VAL_9]] : i32 to index // CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_8]] : index // CHECK: %[[VAL_12:.*]] = 
memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_11]]], sizes: [128, 128], strides: [2, 1] : memref<*xf32> to memref<128x128xf32, strided<[2, 1], offset: ?>> @@ -47,7 +47,7 @@ module { // CHECK: %[[VAL_18:.*]] = math.exp %[[VAL_16]] : f32 // CHECK: linalg.yield %[[VAL_18]] : f32 // CHECK: } -> tensor<128x128xf32> -// CHECK: %[[VAL_19:.*]] = arith.muli %[[VAL_5]], %[[VAL_3]] : i32 +// CHECK: %[[VAL_19:.*]] = arith.muli %[[ARG_8]], %[[VAL_3]] : i32 // CHECK: %[[VAL_20:.*]] = arith.index_cast %[[VAL_19]] : i32 to index // CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_20]], %[[VAL_8]] : index // CHECK: %[[VAL_22:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_21]]], sizes: [128, 128], strides: [2, 1] : memref<*xf32> to memref<128x128xf32, strided<[2, 1], offset: ?>> diff --git a/test/Conversion/TritonToLinalg/bitcast.mlir b/test/Conversion/TritonToLinalg/bitcast.mlir index c50dccf1..ea951b8c 100644 --- a/test/Conversion/TritonToLinalg/bitcast.mlir +++ b/test/Conversion/TritonToLinalg/bitcast.mlir @@ -25,7 +25,7 @@ module { } // CHECK: module { -// CHECK: func.func @kernel(%arg0: memref<*xi32>, %arg1: memref<*xf32>, %arg2: i32, %arg3: i32, %arg4: i32) { +// CHECK: func.func @kernel(%arg0: memref<*xi32>, %arg1: memref<*xf32>, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) { // CHECK: [[RC_:%.+]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [1024], strides: [1]{{.*}} : memref<*xi32> to memref<1024xi32, strided<[1]>> // CHECK: [[RC_0_:%.+]] = memref.reinterpret_cast %arg1 to offset: [0], sizes: [1024], strides: [1]{{.*}} : memref<*xf32> to memref<1024xf32, strided<[1]>> // CHECK: [[ALLOC_:%.+]] = memref.alloc() : memref<1024xi32> @@ -41,3 +41,4 @@ module { // CHECK: return // CHECK: } // CHECK: } + diff --git a/test/Conversion/TritonToLinalg/block_ptr_advance.mlir b/test/Conversion/TritonToLinalg/block_ptr_advance.mlir new file mode 100644 index 00000000..a37046f6 --- /dev/null +++ b/test/Conversion/TritonToLinalg/block_ptr_advance.mlir @@ -0,0 +1,93 @@ +// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s + +module { + tt.func public @matmul_kernel_with_block_pointers_01234567891011(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32) { + %c64_i32 = arith.constant 64 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant 0.000000e+00 : bf16 + %c256_i32 = arith.constant 256 : i32 + %0 = arith.extsi %arg3 : i32 to i64 + %1 = arith.extsi %arg5 : i32 to i64 + %2 = arith.extsi %arg6 : i32 to i64 + %3 = arith.extsi %arg7 : i32 to i64 + %4 = tt.make_tensor_ptr %arg0, [%0, %1], [%2, %3], [%arg12, %c0_i32] {order = array} : > + %5 = tt.advance %4, [%c0_i32, %c64_i32] : > + %6 = tt.splat %cst : (bf16) -> tensor<128x64xbf16> + %7:3 = scf.for %arg14 = %c0_i32 to %arg5 step %c64_i32 iter_args(%arg15 = %6, %arg16 = %5, %arg17 = %4) -> (tensor<128x64xbf16>, !tt.ptr>, !tt.ptr>) : i32 { + %13 = tt.load %arg16 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr> -> tensor<128x64xbf16> + %14 = tt.load %arg17 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr> -> tensor<128x64xbf16> + %15 = arith.addf %13, %14 : tensor<128x64xbf16> + %16 = arith.addf %arg15, %15 : tensor<128x64xbf16> + %17 = tt.advance %arg16, [%c0_i32, %c64_i32] : > + %18 = tt.advance %arg17, [%c64_i32, %c0_i32] : > + scf.yield %16, %17, %18 : tensor<128x64xbf16>, 
!tt.ptr>, !tt.ptr> + } + %8 = arith.extsi %arg10 : i32 to i64 + %9 = arith.extsi %arg11 : i32 to i64 + %10 = arith.extsi %arg4 : i32 to i64 + %11 = arith.muli %arg13, %c256_i32 : i32 + %12 = tt.make_tensor_ptr %arg2, [%0, %10], [%8, %9], [%arg12, %11] {order = array} : > + tt.store %12, %7#0 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32} : !tt.ptr>, tensor<128x64xbf16> + tt.return + } +} + +// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)> +// CHECK: module { +// CHECK: func.func @matmul_kernel_with_block_pointers_01234567891011(%arg0: memref<*xbf16>, %arg1: memref<*xbf16>, %arg2: memref<*xbf16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32, %arg15: i32, %arg16: i32, %arg17: i32, %arg18: i32, %arg19: i32) { +// CHECK: %c64 = arith.constant 64 : index +// CHECK: %c0 = arith.constant 0 : index +// CHECK: %c256_i32 = arith.constant 256 : i32 +// CHECK: %c0_i32 = arith.constant 0 : i32 +// CHECK: %c64_i32 = arith.constant 64 : i32 +// CHECK: %cst = arith.constant 0.000000e+00 : bf16 +// CHECK: %0 = tensor.empty() : tensor<128x64xbf16> +// CHECK: %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<128x64xbf16>) -> tensor<128x64xbf16> +// CHECK: %2 = arith.index_cast %arg12 : i32 to index +// CHECK: %3 = arith.index_cast %arg6 : i32 to index +// CHECK: %4 = arith.index_cast %arg7 : i32 to index +// CHECK: %5 = arith.muli %2, %3 : index +// CHECK: %6 = arith.muli %4, %c64 : index +// CHECK: %7 = arith.addi %5, %6 : index +// CHECK: %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%7], sizes: [128, 64], strides: [%3, %4] : memref<*xbf16> to memref<128x64xbf16, strided<[?, ?], offset: ?>> +// CHECK: %reinterpret_cast_0 = memref.reinterpret_cast %arg0 to offset: [%5], sizes: [128, 64], strides: [%3, %4] : memref<*xbf16> to memref<128x64xbf16, strided<[?, ?], offset: ?>> +// CHECK: %8:7 = scf.for %arg20 = %c0_i32 to %arg5 step %c64_i32 iter_args(%arg21 = %1, %arg22 = %reinterpret_cast, %arg23 = %reinterpret_cast_0, %arg24 = %7, %arg25 = %c0, %arg26 = %5, %arg27 = %c0) -> (tensor<128x64xbf16>, memref<128x64xbf16, strided<[?, ?], offset: ?>>, memref<128x64xbf16, strided<[?, ?], offset: ?>>, index, index, index, index) : i32 { +// CHECK: %alloc = memref.alloc() : memref<128x64xbf16> +// CHECK: memref.copy %arg22, %alloc : memref<128x64xbf16, strided<[?, ?], offset: ?>> to memref<128x64xbf16> +// CHECK: %17 = bufferization.to_tensor %alloc restrict writable : memref<128x64xbf16> +// CHECK: %alloc_2 = memref.alloc() : memref<128x64xbf16> +// CHECK: memref.copy %arg23, %alloc_2 : memref<128x64xbf16, strided<[?, ?], offset: ?>> to memref<128x64xbf16> +// CHECK: %18 = bufferization.to_tensor %alloc_2 restrict writable : memref<128x64xbf16> +// CHECK: %19 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%17, %18 : tensor<128x64xbf16>, tensor<128x64xbf16>) outs(%17 : tensor<128x64xbf16>) { +// CHECK: ^bb0(%in: bf16, %in_5: bf16, %out: bf16): +// CHECK: %27 = arith.addf %in, %in_5 : bf16 +// CHECK: linalg.yield %27 : bf16 +// CHECK: } -> tensor<128x64xbf16> +// CHECK: %20 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg21, %19 : tensor<128x64xbf16>, tensor<128x64xbf16>) outs(%arg21 : tensor<128x64xbf16>) { +// CHECK: ^bb0(%in: bf16, %in_5: bf16, %out: bf16): +// CHECK: %27 = arith.addf %in, %in_5 : bf16 +// CHECK: linalg.yield %27 : bf16 +// CHECK: } -> tensor<128x64xbf16> +// 
CHECK: %21 = arith.muli %4, %c64 : index +// CHECK: %22 = arith.addi %21, %arg25 : index +// CHECK: %23 = arith.addi %arg24, %22 : index +// CHECK: %reinterpret_cast_3 = memref.reinterpret_cast %arg0 to offset: [%23], sizes: [128, 64], strides: [%3, %4] : memref<*xbf16> to memref<128x64xbf16, strided<[?, ?], offset: ?>> +// CHECK: %24 = arith.muli %3, %c64 : index +// CHECK: %25 = arith.addi %24, %arg26 : index +// CHECK: %26 = arith.addi %25, %arg27 : index +// CHECK: %reinterpret_cast_4 = memref.reinterpret_cast %arg0 to offset: [%26], sizes: [128, 64], strides: [%3, %4] : memref<*xbf16> to memref<128x64xbf16, strided<[?, ?], offset: ?>> +// CHECK: scf.yield %20, %reinterpret_cast_3, %reinterpret_cast_4, %23, %c0, %26, %c0 : tensor<128x64xbf16>, memref<128x64xbf16, strided<[?, ?], offset: ?>>, memref<128x64xbf16, strided<[?, ?], offset: ?>>, index, index, index, index +// CHECK: } +// CHECK: %9 = arith.muli %arg13, %c256_i32 : i32 +// CHECK: %10 = arith.index_cast %arg12 : i32 to index +// CHECK: %11 = arith.index_cast %9 : i32 to index +// CHECK: %12 = arith.index_cast %arg10 : i32 to index +// CHECK: %13 = arith.index_cast %arg11 : i32 to index +// CHECK: %14 = arith.muli %10, %12 : index +// CHECK: %15 = arith.muli %11, %13 : index +// CHECK: %16 = arith.addi %14, %15 : index +// CHECK: %reinterpret_cast_1 = memref.reinterpret_cast %arg2 to offset: [%16], sizes: [128, 64], strides: [%12, %13] : memref<*xbf16> to memref<128x64xbf16, strided<[?, ?], offset: ?>> +// CHECK: memref.tensor_store %8#0, %reinterpret_cast_1 : memref<128x64xbf16, strided<[?, ?], offset: ?>> +// CHECK: return +// CHECK: } +// CHECK: } diff --git a/test/Conversion/TritonToLinalg/convert_minmax_reduce.mlir b/test/Conversion/TritonToLinalg/convert_minmax_reduce.mlir new file mode 100644 index 00000000..066ef866 --- /dev/null +++ b/test/Conversion/TritonToLinalg/convert_minmax_reduce.mlir @@ -0,0 +1,126 @@ +// RUN: triton-shared-opt --triton-to-linalg --split-input-file %s | FileCheck %s +module { + tt.func public @minmax_sgt(%arg0: !tt.ptr) { + %cst_0 = arith.constant dense<0> : tensor<4096xi32> + %63 = "tt.reduce"(%cst_0) ({ + ^bb0(%arg14: i32, %arg15: i32): + %69 = arith.cmpi sgt, %arg14, %arg15 : i32 + %70 = arith.select %69, %arg14, %arg15 : i32 + tt.reduce.return %70 : i32 + }) {axis = 0 : i32} : (tensor<4096xi32>) -> i32 + tt.store %arg0, %63 {cache = 1 : i32, evict = 1 : i32} : i32 + tt.return + } +} + +// CHECK: func.func @minmax_sgt(%[[VAL_0:.*]]: memref<*xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32) { +// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<4096xi32> +// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%c0{{.*}} : i32) outs(%[[VAL_7]] : tensor<4096xi32>) -> tensor<4096xi32> +// CHECK: %[[VAL_9:.*]] = bufferization.alloc_tensor() : tensor +// CHECK: %[[VAL_10:.*]] = tensor.insert %c-2147483648{{.*}} into %[[VAL_9]][] : tensor +// CHECK: %[[VAL_11:.*]] = linalg.reduce ins(%[[VAL_8]] : tensor<4096xi32>) outs(%[[VAL_10]] : tensor) dimensions = [0] +// CHECK: (%in: i32, %init: i32) { +// CHECK: %[[VAL_12:.*]] = arith.maxsi %in, %init : i32 +// CHECK: linalg.yield %[[VAL_12]] : i32 +// CHECK: } +// CHECK: %[[VAL_12:.*]] = tensor.extract %[[VAL_11]][] : tensor +// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1]>> +// CHECK: affine.store %[[VAL_12]], %[[VAL_13]][0] : memref<1xi32, strided<[1]>> +// CHECK: return +// CHECK: } + +// ----- 
+ +module { + tt.func public @minmax_ugt(%arg0: !tt.ptr) { + %cst_0 = arith.constant dense<0> : tensor<4096xi32> + %63 = "tt.reduce"(%cst_0) ({ + ^bb0(%arg14: i32, %arg15: i32): + %69 = arith.cmpi ugt, %arg14, %arg15 : i32 + %70 = arith.select %69, %arg14, %arg15 : i32 + tt.reduce.return %70 : i32 + }) {axis = 0 : i32} : (tensor<4096xi32>) -> i32 + tt.store %arg0, %63 {cache = 1 : i32, evict = 1 : i32} : i32 + tt.return + } +} + +// CHECK: func.func @minmax_ugt(%[[VAL_0:.*]]: memref<*xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32) { +// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<4096xi32> +// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%c0{{.*}} : i32) outs(%[[VAL_7]] : tensor<4096xi32>) -> tensor<4096xi32> +// CHECK: %[[VAL_9:.*]] = bufferization.alloc_tensor() : tensor +// CHECK: %[[VAL_10:.*]] = tensor.insert %c0{{.*}} into %[[VAL_9]][] : tensor +// CHECK: %[[VAL_11:.*]] = linalg.reduce ins(%[[VAL_8]] : tensor<4096xi32>) outs(%[[VAL_10]] : tensor) dimensions = [0] +// CHECK: (%in: i32, %init: i32) { +// CHECK: %[[VAL_12:.*]] = arith.maxui %in, %init : i32 +// CHECK: linalg.yield %[[VAL_12]] : i32 +// CHECK: } +// CHECK: %[[VAL_12:.*]] = tensor.extract %[[VAL_11]][] : tensor +// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1]>> +// CHECK: affine.store %[[VAL_12]], %[[VAL_13]][0] : memref<1xi32, strided<[1]>> +// CHECK: return +// CHECK: } + +// ----- + +module { + tt.func public @minmax_slt(%arg0: !tt.ptr) { + %cst_0 = arith.constant dense<0> : tensor<4096xi32> + %63 = "tt.reduce"(%cst_0) ({ + ^bb0(%arg14: i32, %arg15: i32): + %69 = arith.cmpi slt, %arg14, %arg15 : i32 + %70 = arith.select %69, %arg14, %arg15 : i32 + tt.reduce.return %70 : i32 + }) {axis = 0 : i32} : (tensor<4096xi32>) -> i32 + tt.store %arg0, %63 {cache = 1 : i32, evict = 1 : i32} : i32 + tt.return + } +} + +// CHECK: func.func @minmax_slt(%[[VAL_0:.*]]: memref<*xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32) { +// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<4096xi32> +// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%c0{{.*}} : i32) outs(%[[VAL_7]] : tensor<4096xi32>) -> tensor<4096xi32> +// CHECK: %[[VAL_9:.*]] = bufferization.alloc_tensor() : tensor +// CHECK: %[[VAL_10:.*]] = tensor.insert %c2147483647{{.*}} into %[[VAL_9]][] : tensor +// CHECK: %[[VAL_11:.*]] = linalg.reduce ins(%[[VAL_8]] : tensor<4096xi32>) outs(%[[VAL_10]] : tensor) dimensions = [0] +// CHECK: (%in: i32, %init: i32) { +// CHECK: %[[VAL_12:.*]] = arith.minsi %in, %init : i32 +// CHECK: linalg.yield %[[VAL_12]] : i32 +// CHECK: } +// CHECK: %[[VAL_12:.*]] = tensor.extract %[[VAL_11]][] : tensor +// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1]>> +// CHECK: affine.store %[[VAL_12]], %[[VAL_13]][0] : memref<1xi32, strided<[1]>> +// CHECK: return +// CHECK: } + +// ----- + +module { + tt.func public @minmax_ult(%arg0: !tt.ptr) { + %cst_0 = arith.constant dense<0> : tensor<4096xi32> + %63 = "tt.reduce"(%cst_0) ({ + ^bb0(%arg14: i32, %arg15: i32): + %69 = arith.cmpi ult, %arg14, %arg15 : i32 + %70 = arith.select %69, %arg14, %arg15 : i32 + tt.reduce.return %70 : i32 + }) {axis = 0 : i32} : (tensor<4096xi32>) -> i32 + tt.store %arg0, %63 {cache = 1 : i32, evict = 1 : i32} : i32 + tt.return + } +} + +// CHECK: 
func.func @minmax_ult(%[[VAL_0:.*]]: memref<*xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32) { +// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<4096xi32> +// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%c0{{.*}} : i32) outs(%[[VAL_7]] : tensor<4096xi32>) -> tensor<4096xi32> +// CHECK: %[[VAL_9:.*]] = bufferization.alloc_tensor() : tensor +// CHECK: %[[VAL_10:.*]] = tensor.insert %c-1{{.*}} into %[[VAL_9]][] : tensor +// CHECK: %[[VAL_11:.*]] = linalg.reduce ins(%[[VAL_8]] : tensor<4096xi32>) outs(%[[VAL_10]] : tensor) dimensions = [0] +// CHECK: (%in: i32, %init: i32) { +// CHECK: %[[VAL_12:.*]] = arith.minui %in, %init : i32 +// CHECK: linalg.yield %[[VAL_12]] : i32 +// CHECK: } +// CHECK: %[[VAL_12:.*]] = tensor.extract %[[VAL_11]][] : tensor +// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1]>> +// CHECK: affine.store %[[VAL_12]], %[[VAL_13]][0] : memref<1xi32, strided<[1]>> +// CHECK: return +// CHECK: } \ No newline at end of file diff --git a/test/Conversion/TritonToLinalg/get_num_programs.mlir b/test/Conversion/TritonToLinalg/get_num_programs.mlir new file mode 100644 index 00000000..15de818e --- /dev/null +++ b/test/Conversion/TritonToLinalg/get_num_programs.mlir @@ -0,0 +1,44 @@ +// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s + +module { + tt.func public @num_programs(%arg0: !tt.ptr) { + %0 = tt.get_num_programs {axis = 0 : i32} : i32 + %1 = tt.get_num_programs {axis = 1 : i32} : i32 + %2 = tt.get_num_programs {axis = 2 : i32} : i32 + %3 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32> + %4 = tt.make_range {end = 2 : i32, start = 1 : i32} : tensor<1xi32> + %5 = tt.make_range {end = 3 : i32, start = 2 : i32} : tensor<1xi32> + %6 = tt.splat %arg0 : (!tt.ptr) -> tensor<1x!tt.ptr> + %7 = tt.addptr %6, %3 : tensor<1x!tt.ptr>, tensor<1xi32> + %8 = tt.splat %0 : (i32) -> tensor<1xi32> + tt.store %7, %8 {cache = 1 : i32, evict = 1 : i32} : tensor<1xi32> + %9 = tt.addptr %6, %4 : tensor<1x!tt.ptr>, tensor<1xi32> + %10 = tt.splat %1 : (i32) -> tensor<1xi32> + tt.store %9, %10 {cache = 1 : i32, evict = 1 : i32} : tensor<1xi32> + %11 = tt.addptr %6, %5 : tensor<1x!tt.ptr>, tensor<1xi32> + %12 = tt.splat %2 : (i32) -> tensor<1xi32> + tt.store %11, %12 {cache = 1 : i32, evict = 1 : i32} : tensor<1xi32> + tt.return + } +} + +// CHECK: module { +// CHECK: func.func @num_programs(%arg0: memref<*xi32>, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32) { +// CHECK: %c2 = arith.constant 2 : index +// CHECK: %c1 = arith.constant 1 : index +// CHECK: %c0 = arith.constant 0 : index +// CHECK: %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%c0], sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1], offset: ?>> +// CHECK: %0 = tensor.empty() : tensor<1xi32> +// CHECK: %1 = linalg.fill ins(%arg1 : i32) outs(%0 : tensor<1xi32>) -> tensor<1xi32> +// CHECK: memref.tensor_store %1, %reinterpret_cast : memref<1xi32, strided<[1], offset: ?>> +// CHECK: %reinterpret_cast_0 = memref.reinterpret_cast %arg0 to offset: [%c1], sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1], offset: ?>> +// CHECK: %2 = tensor.empty() : tensor<1xi32> +// CHECK: %3 = linalg.fill ins(%arg2 : i32) outs(%2 : tensor<1xi32>) -> tensor<1xi32> +// CHECK: memref.tensor_store %3, %reinterpret_cast_0 : memref<1xi32, strided<[1], offset: ?>> +// CHECK: %reinterpret_cast_1 
= memref.reinterpret_cast %arg0 to offset: [%c2], sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1], offset: ?>> +// CHECK: %4 = tensor.empty() : tensor<1xi32> +// CHECK: %5 = linalg.fill ins(%arg3 : i32) outs(%4 : tensor<1xi32>) -> tensor<1xi32> +// CHECK: memref.tensor_store %5, %reinterpret_cast_1 : memref<1xi32, strided<[1], offset: ?>> +// CHECK: return +// CHECK: } +// CHECK: } diff --git a/test/Conversion/TritonToLinalg/kernel-01-vector-add.mlir b/test/Conversion/TritonToLinalg/kernel-01-vector-add.mlir index 776fc563..b53ab2fb 100644 --- a/test/Conversion/TritonToLinalg/kernel-01-vector-add.mlir +++ b/test/Conversion/TritonToLinalg/kernel-01-vector-add.mlir @@ -26,10 +26,10 @@ module { // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> // CHECK-LABEL: func.func @add_kernel_01234 -// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: memref<*xf32>, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32) { +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: memref<*xf32>, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32) { // CHECK-DAG: [[CST_1024_:%.+]] = arith.constant 1024 : index // CHECK-DAG: [[CST_1024_1_:%.+]] = arith.constant 1024 : i32 -// CHECK: [[VAR_0_:%.+]] = arith.muli [[PARAM_4_]], [[CST_1024_1_]] : i32 +// CHECK: [[VAR_0_:%.+]] = arith.muli [[PARAM_7_]], [[CST_1024_1_]] : i32 // CHECK: [[VAR_1_:%.+]] = arith.index_cast [[VAR_0_]] : i32 to index // CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [1024], strides: [1]{{.*}} : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>> // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<1024xf32> diff --git a/test/Conversion/TritonToLinalg/kernel-02-fused-softmax.mlir b/test/Conversion/TritonToLinalg/kernel-02-fused-softmax.mlir index e0689c65..7b95588d 100644 --- a/test/Conversion/TritonToLinalg/kernel-02-fused-softmax.mlir +++ b/test/Conversion/TritonToLinalg/kernel-02-fused-softmax.mlir @@ -40,22 +40,22 @@ module { // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> // CHECK-LABEL: func.func @softmax_kernel_012345 -// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) { +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32, [[PARAM_10_:%.+]]: i32) { // CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF800000 : f32 -// CHECK-DAG: [[VAR_0_:%.+]] = arith.muli [[PARAM_5_]], [[PARAM_2_]] : i32 +// CHECK-DAG: [[VAR_0_:%.+]] = arith.muli [[PARAM_8_]], [[PARAM_2_]] : i32 // CHECK: [[VAR_1_:%.+]] = arith.index_cast [[VAR_0_]] : i32 to index // CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [128], strides: [1]{{.*}} : memref<*xf32> to memref<128xf32, strided<[1], offset: ?>> // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : 
memref<128xf32> // CHECK-DAG: [[VAR_2_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index // CHECK: [[VAR_3_:%.+]] = arith.minsi [[VAR_2_]], [[CST_128_]] : index -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_3_]]{{.}} [1]{{.*}} : memref<128xf32, strided<[1], offset: ?>> to memref> -// CHECK-DAG: [[VAR_subview_1_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_3_]]{{.}} [1] : memref<128xf32> to memref> // CHECK-DAG: [[VAR_4_:%.+]] = arith.cmpi slt, [[VAR_3_]], [[CST_128_]] : index // CHECK: scf.if [[VAR_4_]] { // CHECK: linalg.fill ins([[CST_0_]] : f32) outs([[RES_]] : memref<128xf32>) // CHECK: } +// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_3_]]{{.}} [1]{{.*}} : memref<128xf32, strided<[1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_1_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_3_]]{{.}} [1] : memref<128xf32> to memref> // CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_]]_1 : memref> to memref> // CHECK-DAG: [[VAR_5_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<128xf32> // CHECK-DAG: [[VAR_6_:%.+]] = bufferization.alloc_tensor() : tensor @@ -93,7 +93,7 @@ module { // CHECK: [[VAR_19_4_:%.+]] = arith.divf [[in_1]], [[in_2]] : f32 // CHECK: linalg.yield [[VAR_19_4_]] : f32 // CHECK: } -> tensor<128xf32> -// CHECK: [[VAR_15_:%.+]] = arith.muli [[PARAM_5_]], [[PARAM_3_]] : i32 +// CHECK: [[VAR_15_:%.+]] = arith.muli [[PARAM_8_]], [[PARAM_3_]] : i32 // CHECK: [[VAR_16_:%.+]] = arith.index_cast [[VAR_15_]] : i32 to index // CHECK-DAG: [[VAR_reinterpret_cast_5_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_16_]]{{.}}, sizes: [128], strides: [1]{{.*}} : memref<*xf32> to memref<128xf32, strided<[1], offset: ?>> // CHECK-DAG: [[VAR_17_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index diff --git a/test/Conversion/TritonToLinalg/kernel-03-matrix-multiplication.mlir b/test/Conversion/TritonToLinalg/kernel-03-matrix-multiplication.mlir index 92426a2e..62964933 100644 --- a/test/Conversion/TritonToLinalg/kernel-03-matrix-multiplication.mlir +++ b/test/Conversion/TritonToLinalg/kernel-03-matrix-multiplication.mlir @@ -98,7 +98,7 @@ module { // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func.func @matmul_kernel_0123456789101112131415 -// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xbf16>, [[PARAM_1_:%.+]]: memref<*xbf16>, [[PARAM_2_:%.+]]: memref<*xbf16>, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32, [[PARAM_10_:%.+]]: i32, [[PARAM_11_:%.+]]: i32, [[PARAM_12_:%.+]]: i32, [[PARAM_13_:%.+]]: i32, [[PARAM_14_:%.+]]: i32) { +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xbf16>, [[PARAM_1_:%.+]]: memref<*xbf16>, [[PARAM_2_:%.+]]: memref<*xbf16>, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32, [[PARAM_10_:%.+]]: i32, [[PARAM_11_:%.+]]: i32, [[PARAM_12_:%.+]]: i32, [[PARAM_13_:%.+]]: i32, [[PARAM_14_:%.+]]: i32, [[PARAM_15_:%.+]]: i32, [[PARAM_16_:%.+]]: i32, [[PARAM_17_:%.+]]: i32) { // CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index // CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index @@ -125,14 +125,13 @@ module { // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_7_:%.+]] = arith.divsi [[VAR_6_]], [[CST_64_]] : i32 // 
CHECK-DAG: [[VAR_8_:%.+]] = arith.muli [[VAR_5_]], [[CST_8_]] : i32 -// CHECK: [[VAR_9_:%.+]] = arith.divsi [[PARAM_12_]], [[VAR_8_]] : i32 +// CHECK: [[VAR_9_:%.+]] = arith.divsi [[PARAM_15_]], [[VAR_8_]] : i32 // CHECK: [[VAR_10_:%.+]] = arith.muli [[VAR_9_]], [[CST_8_]] : i32 // CHECK: [[VAR_11_:%.+]] = arith.subi [[VAR_3_]], [[VAR_10_]] : i32 -// CHECK: [[VAR_12_:%.+]] = arith.cmpi slt, [[VAR_11_]], [[CST_8_]] : i32 -// CHECK: [[VAR_13_:%.+]] = arith.select [[VAR_12_]], [[VAR_11_]], [[CST_8_]] : i32 -// CHECK: [[VAR_14_:%.+]] = arith.remsi [[PARAM_12_]], [[VAR_13_]] : i32 +// CHECK: [[VAR_13_:%.+]] = arith.minsi [[VAR_11_]], [[CST_8_]] : i32 +// CHECK: [[VAR_14_:%.+]] = arith.remsi [[PARAM_15_]], [[VAR_13_]] : i32 // CHECK-DAG: [[VAR_15_:%.+]] = arith.addi [[VAR_10_]], [[VAR_14_]] : i32 -// CHECK-DAG: [[VAR_16_:%.+]] = arith.remsi [[PARAM_12_]], [[VAR_8_]] : i32 +// CHECK-DAG: [[VAR_16_:%.+]] = arith.remsi [[PARAM_15_]], [[VAR_8_]] : i32 // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_17_:%.+]] = arith.divsi [[VAR_16_]], [[VAR_13_]] : i32 // CHECK-DAG: [[VAR_18_:%.+]] = arith.muli [[VAR_15_]], [[CST_128_1_]] : i32 diff --git a/test/Conversion/TritonToLinalg/kernel-05-layer-norm-dwdb.mlir b/test/Conversion/TritonToLinalg/kernel-05-layer-norm-dwdb.mlir index b8d1d94e..67852650 100644 --- a/test/Conversion/TritonToLinalg/kernel-05-layer-norm-dwdb.mlir +++ b/test/Conversion/TritonToLinalg/kernel-05-layer-norm-dwdb.mlir @@ -62,7 +62,7 @@ module { // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func.func @_layer_norm_bwd_dwdb_0123456 -// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: memref<*xf32>, [[PARAM_3_:%.+]]: memref<*xf32>, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32) { +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: memref<*xf32>, [[PARAM_3_:%.+]]: memref<*xf32>, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32, [[PARAM_10_:%.+]]: i32, [[PARAM_11_:%.+]]: i32) { // CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index // CHECK-DAG: [[CST_256_1_:%.+]] = arith.constant 256 : i32 // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 @@ -70,7 +70,7 @@ module { // CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<256x256xf32> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[VAR_0_]] : tensor<256x256xf32>) -> tensor<256x256xf32> -// CHECK-DAG: [[VAR_2_:%.+]] = arith.muli [[PARAM_6_]], [[CST_256_1_]] : i32 +// CHECK-DAG: [[VAR_2_:%.+]] = arith.muli [[PARAM_9_]], [[CST_256_1_]] : i32 // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_3_:%.+]]:2 = scf.for [[VAR_arg9_:%.+]] = [[CST_0_]] to [[PARAM_4_]] step [[CST_256_1_]] iter_args([[VAR_arg10_:%.+]] = [[VAR_1_]], [[VAR_arg11_:%.+]] = [[VAR_1_]]) -> (tensor<256x256xf32>, tensor<256x256xf32>) : i32 { // CHECK-DAG: [[VAR_20_:%.+]] = arith.index_cast [[VAR_arg9_]] : i32 to index @@ -95,14 +95,14 @@ module { // CHECK-DAG: [[VAR_34_:%.+]] = arith.subi [[VAR_33_]], [[VAR_30_]] : index // CHECK-DAG: [[VAR_35_:%.+]] = arith.minsi [[VAR_29_]], [[CST_256_]] : index // CHECK: [[VAR_36_:%.+]] = arith.minsi [[VAR_34_]], [[CST_256_]] : index -// CHECK-DAG: [[VAR_subview_5_:%.+]] = memref.subview [[VAR_reinterpret_cast_4_]][0, 0] {{.}}[[VAR_35_]], 
[[VAR_36_]]{{.}} [1, 1] : memref<256x256xf32, strided<[?, 1], offset: ?>> to memref> -// CHECK-DAG: [[VAR_subview_6_:%.+]] = memref.subview [[RES_]][0, 0] {{.}}[[VAR_35_]], [[VAR_36_]]{{.}} [1, 1] : memref<256x256xf32> to memref> // CHECK-DAG: [[VAR_37_:%.+]] = arith.cmpi slt, [[VAR_35_]], [[CST_256_]] : index // CHECK-DAG: [[VAR_38_:%.+]] = arith.cmpi slt, [[VAR_36_]], [[CST_256_]] : index // CHECK: [[VAR_39_:%.+]] = arith.ori [[VAR_37_]], [[VAR_38_]] : i1 // CHECK: scf.if [[VAR_39_]] { // CHECK: linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[RES_]] : memref<256x256xf32>) // CHECK: } +// CHECK-DAG: [[VAR_subview_5_:%.+]] = memref.subview [[VAR_reinterpret_cast_4_]][0, 0] {{.}}[[VAR_35_]], [[VAR_36_]]{{.}} [1, 1] : memref<256x256xf32, strided<[?, 1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_6_:%.+]] = memref.subview [[RES_]][0, 0] {{.}}[[VAR_35_]], [[VAR_36_]]{{.}} [1, 1] : memref<256x256xf32> to memref> // CHECK: memref.copy [[VAR_subview_5_]], [[VAR_subview_6_]] : memref> to memref> // CHECK: [[VAR_40_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<256x256xf32> // CHECK: [[VAR_41_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_arg10_]], [[VAR_40_]] : tensor<256x256xf32>, tensor<256x256xf32>) outs([[VAR_arg10_]] : tensor<256x256xf32>) { @@ -132,14 +132,14 @@ module { // CHECK-DAG: [[VAR_56_:%.+]] = arith.subi [[VAR_55_]], [[VAR_52_]] : index // CHECK-DAG: [[VAR_57_:%.+]] = arith.minsi [[VAR_51_]], [[CST_256_]] : index // CHECK: [[VAR_58_:%.+]] = arith.minsi [[VAR_56_]], [[CST_256_]] : index -// CHECK-DAG: [[VAR_subview_9_:%.+]] = memref.subview [[VAR_reinterpret_cast_7_]][0, 0] {{.}}[[VAR_57_]], [[VAR_58_]]{{.}} [1, 1] : memref<256x256xf32, strided<[?, 1], offset: ?>> to memref> -// CHECK-DAG: [[VAR_subview_10_:%.+]] = memref.subview [[RES_1_]][0, 0] {{.}}[[VAR_57_]], [[VAR_58_]]{{.}} [1, 1] : memref<256x256xf32> to memref> // CHECK-DAG: [[VAR_59_:%.+]] = arith.cmpi slt, [[VAR_57_]], [[CST_256_]] : index // CHECK-DAG: [[VAR_60_:%.+]] = arith.cmpi slt, [[VAR_58_]], [[CST_256_]] : index // CHECK: [[VAR_61_:%.+]] = arith.ori [[VAR_59_]], [[VAR_60_]] : i1 // CHECK: scf.if [[VAR_61_]] { // CHECK: linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[RES_1_]] : memref<256x256xf32>) // CHECK: } +// CHECK-DAG: [[VAR_subview_9_:%.+]] = memref.subview [[VAR_reinterpret_cast_7_]][0, 0] {{.}}[[VAR_57_]], [[VAR_58_]]{{.}} [1, 1] : memref<256x256xf32, strided<[?, 1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_10_:%.+]] = memref.subview [[RES_1_]][0, 0] {{.}}[[VAR_57_]], [[VAR_58_]]{{.}} [1, 1] : memref<256x256xf32> to memref> // CHECK: memref.copy [[VAR_subview_9_]], [[VAR_subview_10_]] : memref> to memref> // CHECK: [[VAR_62_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<256x256xf32> // CHECK: [[VAR_63_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_arg11_]], [[VAR_62_]] : tensor<256x256xf32>, tensor<256x256xf32>) outs([[VAR_arg11_]] : tensor<256x256xf32>) { diff --git a/test/Conversion/TritonToLinalg/kernel-05-layer-norm-fwd.mlir b/test/Conversion/TritonToLinalg/kernel-05-layer-norm-fwd.mlir index 5d58729f..5c703c1e 100644 --- a/test/Conversion/TritonToLinalg/kernel-05-layer-norm-fwd.mlir +++ b/test/Conversion/TritonToLinalg/kernel-05-layer-norm-fwd.mlir @@ -90,7 +90,7 @@ module { // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> // CHECK-LABEL: func.func @_layer_norm_fwd_fused_0123456789 -// 
CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: memref<*xf32>, [[PARAM_3_:%.+]]: memref<*xf32>, [[PARAM_4_:%.+]]: memref<*xf32>, [[PARAM_5_:%.+]]: memref<*xf32>, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: f32, [[PARAM_9_:%.+]]: i32, [[PARAM_10_:%.+]]: i32, [[PARAM_11_:%.+]]: i32) { +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: memref<*xf32>, [[PARAM_3_:%.+]]: memref<*xf32>, [[PARAM_4_:%.+]]: memref<*xf32>, [[PARAM_5_:%.+]]: memref<*xf32>, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: f32, [[PARAM_9_:%.+]]: i32, [[PARAM_10_:%.+]]: i32, [[PARAM_11_:%.+]]: i32, [[PARAM_12_:%.+]]: i32, [[PARAM_13_:%.+]]: i32, [[PARAM_14_:%.+]]: i32) { // CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index // CHECK-DAG: [[CST_1_dot_000000_:%.+]] = arith.constant 1.000000e+00 : f32 // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 @@ -99,7 +99,7 @@ module { // CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<256xf32> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[VAR_0_]] : tensor<256xf32>) -> tensor<256xf32> -// CHECK-DAG: [[VAR_2_:%.+]] = arith.muli [[PARAM_9_]], [[PARAM_6_]] : i32 +// CHECK-DAG: [[VAR_2_:%.+]] = arith.muli [[PARAM_12_]], [[PARAM_6_]] : i32 // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_3_:%.+]] = scf.for [[VAR_arg12_:%.+]] = [[CST_0_]] to [[PARAM_7_]] step [[CST_256_1_]] iter_args([[VAR_arg13_:%.+]] = [[VAR_1_]]) -> (tensor<256xf32>) : i32 { // CHECK-DAG: [[VAR_25_:%.+]] = arith.index_cast [[VAR_2_]] : i32 to index @@ -113,12 +113,12 @@ module { // CHECK-DAG: [[VAR_30_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index // CHECK: [[VAR_31_:%.+]] = arith.minsi [[VAR_29_]], [[VAR_30_]] : index // CHECK: [[VAR_32_:%.+]] = arith.subi [[VAR_31_]], [[VAR_28_]] : index -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_5_]][0] {{.}}[[VAR_32_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref> -// CHECK-DAG: [[VAR_subview_6_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_32_]]{{.}} [1] : memref<256xf32> to memref> // CHECK-DAG: [[VAR_33_:%.+]] = arith.cmpi slt, [[VAR_32_]], [[CST_256_]] : index // CHECK: scf.if [[VAR_33_]] { // CHECK: linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[RES_]] : memref<256xf32>) // CHECK: } +// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_5_]][0] {{.}}[[VAR_32_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_6_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_32_]]{{.}} [1] : memref<256xf32> to memref> // CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_]]_6 : memref> to memref> // CHECK: [[VAR_34_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<256xf32> // CHECK: [[VAR_35_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_arg13_]], [[VAR_34_]] : tensor<256xf32>, tensor<256xf32>) outs([[VAR_arg13_]] : tensor<256xf32>) { @@ -176,12 +176,12 @@ module { // CHECK-DAG: [[VAR_35_1_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index // CHECK: [[VAR_36_1_:%.+]] = arith.minsi [[VAR_34_1_]], [[VAR_35_1_]] : index // CHECK: [[VAR_37_:%.+]] = arith.subi [[VAR_36_1_]], [[VAR_33_1_]] : index -// CHECK-DAG: [[VAR_subview_1_:%.+]] = memref.subview [[VAR_reinterpret_cast_5_1_]][0] {{.}}[[VAR_37_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to 
memref> -// CHECK-DAG: [[VAR_subview_6_1_:%.+]] = memref.subview [[RES_1_]][0] {{.}}[[VAR_37_]]{{.}} [1] : memref<256xf32> to memref> // CHECK-DAG: [[VAR_38_:%.+]] = arith.cmpi slt, [[VAR_37_]], [[CST_256_]] : index // CHECK: scf.if [[VAR_38_]] { // CHECK: linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[RES_1_]] : memref<256xf32>) // CHECK: } +// CHECK-DAG: [[VAR_subview_1_:%.+]] = memref.subview [[VAR_reinterpret_cast_5_1_]][0] {{.}}[[VAR_37_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_6_1_:%.+]] = memref.subview [[RES_1_]][0] {{.}}[[VAR_37_]]{{.}} [1] : memref<256xf32> to memref> // CHECK: memref.copy [[VAR_subview_1_]], [[VAR_subview_1_]]_6 : memref> to memref> // CHECK: [[VAR_39_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<256xf32> // CHECK: [[VAR_40_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_39_]], [[VAR_12_]] : tensor<256xf32>, tensor<256xf32>) outs([[VAR_39_]] : tensor<256xf32>) { @@ -218,10 +218,10 @@ module { // CHECK: [[VAR_16_:%.+]] = arith.addf [[VAR_15_]], [[PARAM_8_]] : f32 // CHECK: [[VAR_17_:%.+]] = math.sqrt [[VAR_16_]] : f32 // CHECK-DAG: [[VAR_18_:%.+]] = arith.divf [[CST_1_dot_000000_]], [[VAR_17_]] : f32 -// CHECK-DAG: [[VAR_19_:%.+]] = arith.index_cast [[PARAM_9_]] : i32 to index +// CHECK-DAG: [[VAR_19_:%.+]] = arith.index_cast [[PARAM_12_]] : i32 to index // CHECK: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_4_]] to offset: {{.}}[[VAR_19_]]{{.}}, sizes: [1], strides: [1] : memref<*xf32> to memref<1xf32, strided<[1], offset: ?>> // CHECK: affine.store [[VAR_6_]], [[VAR_reinterpret_cast_]][0] : memref<1xf32, strided<[1], offset: ?>> -// CHECK: [[VAR_20_:%.+]] = arith.index_cast [[PARAM_9_]] : i32 to index +// CHECK: [[VAR_20_:%.+]] = arith.index_cast [[PARAM_12_]] : i32 to index // CHECK: [[VAR_reinterpret_cast_4_:%.+]] = memref.reinterpret_cast [[PARAM_5_]] to offset: {{.}}[[VAR_20_]]{{.}}, sizes: [1], strides: [1] : memref<*xf32> to memref<1xf32, strided<[1], offset: ?>> // CHECK: affine.store [[VAR_18_]], [[VAR_reinterpret_cast_4_]][0] : memref<1xf32, strided<[1], offset: ?>> // CHECK: [[VAR_21_:%.+]] = tensor.empty() : tensor<256xf32> @@ -267,12 +267,12 @@ module { // CHECK-DAG: [[VAR_44_6_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index // CHECK: [[VAR_45_:%.+]] = arith.minsi [[VAR_43_1_]], [[VAR_44_6_]] : index // CHECK: [[VAR_46_:%.+]] = arith.subi [[VAR_45_]], [[VAR_42_1_]] : index -// CHECK-DAG: [[VAR_subview_13_:%.+]] = memref.subview [[VAR_reinterpret_cast_11_]][0] {{.}}[[VAR_46_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref> -// CHECK-DAG: [[VAR_subview_14_:%.+]] = memref.subview [[RES_4_]][0] {{.}}[[VAR_46_]]{{.}} [1] : memref<256xf32> to memref> // CHECK-DAG: [[VAR_47_:%.+]] = arith.cmpi slt, [[VAR_46_]], [[CST_256_]] : index // CHECK: scf.if [[VAR_47_]] { // CHECK: linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[RES_4_]] : memref<256xf32>) // CHECK: } +// CHECK-DAG: [[VAR_subview_13_:%.+]] = memref.subview [[VAR_reinterpret_cast_11_]][0] {{.}}[[VAR_46_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_14_:%.+]] = memref.subview [[RES_4_]][0] {{.}}[[VAR_46_]]{{.}} [1] : memref<256xf32> to memref> // CHECK: memref.copy [[VAR_subview_13_]], [[VAR_subview_14_]] : memref> to memref> // CHECK: [[VAR_48_:%.+]] = bufferization.to_tensor [[RES_4_]] restrict writable : memref<256xf32> // CHECK: [[VAR_49_:%.+]] = linalg.generic 
{indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_48_]], [[VAR_22_]] : tensor<256xf32>, tensor<256xf32>) outs([[VAR_48_]] : tensor<256xf32>) { diff --git a/test/Conversion/TritonToLinalg/masked_ldst_1d.mlir b/test/Conversion/TritonToLinalg/masked_ldst_1d.mlir index 912dcc7f..81ff6822 100644 --- a/test/Conversion/TritonToLinalg/masked_ldst_1d.mlir +++ b/test/Conversion/TritonToLinalg/masked_ldst_1d.mlir @@ -20,7 +20,7 @@ module { } } // CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32) { +// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[ARG_3:.*]]: i32, %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32) { // CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0xFF80 : bf16 // CHECK-DAG: %[[VAL_7:.*]] = arith.constant 128 : index // CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [128], strides: [1] : memref<*xbf16> to memref<128xbf16, strided<[1]>> @@ -28,12 +28,12 @@ module { // CHECK: %[[VAL_10:.*]] = memref.alloc() : memref<128xbf16> // CHECK: %[[VAL_11:.*]] = arith.index_cast %[[VAL_2]] : i32 to index // CHECK: %[[VAL_12:.*]] = arith.minsi %[[VAL_11]], %[[VAL_7]] : index -// CHECK: %[[VAL_13:.*]] = memref.subview %[[VAL_8]][0] {{\[}}%[[VAL_12]]] [1] : memref<128xbf16, strided<[1]>> to memref> -// CHECK: %[[VAL_14:.*]] = memref.subview %[[VAL_10]][0] {{\[}}%[[VAL_12]]] [1] : memref<128xbf16> to memref> // CHECK: %[[VAL_15:.*]] = arith.cmpi slt, %[[VAL_12]], %[[VAL_7]] : index // CHECK: scf.if %[[VAL_15]] { // CHECK: linalg.fill ins(%[[VAL_6]] : bf16) outs(%[[VAL_10]] : memref<128xbf16>) // CHECK: } +// CHECK: %[[VAL_13:.*]] = memref.subview %[[VAL_8]][0] {{\[}}%[[VAL_12]]] [1] : memref<128xbf16, strided<[1]>> to memref> +// CHECK: %[[VAL_14:.*]] = memref.subview %[[VAL_10]][0] {{\[}}%[[VAL_12]]] [1] : memref<128xbf16> to memref> // CHECK: memref.copy %[[VAL_13]], %[[VAL_14]] : memref> to memref> // CHECK: %[[VAL_16:.*]] = bufferization.to_tensor %[[VAL_10]] restrict writable : memref<128xbf16> // CHECK: %[[VAL_17:.*]] = arith.index_cast %[[VAL_2]] : i32 to index diff --git a/test/Conversion/TritonToLinalg/masked_ldst_2d.mlir b/test/Conversion/TritonToLinalg/masked_ldst_2d.mlir index 7492d03e..aaa6294f 100644 --- a/test/Conversion/TritonToLinalg/masked_ldst_2d.mlir +++ b/test/Conversion/TritonToLinalg/masked_ldst_2d.mlir @@ -62,7 +62,7 @@ module { } } // CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32) { +// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32) { // CHECK-DAG: %[[VAL_7:.*]] = arith.constant 3074 : index // CHECK-DAG: %[[VAL_8:.*]] = arith.constant 1024 : index // CHECK-DAG: %[[VAL_9:.*]] = arith.constant 3 : index @@ -83,14 +83,14 @@ module { // CHECK: %[[VAL_24:.*]] = arith.subi %[[VAL_23]], %[[VAL_9]] : index // CHECK: %[[VAL_25:.*]] = arith.minsi %[[VAL_21]], %[[VAL_12]] : index // CHECK: %[[VAL_26:.*]] = arith.minsi %[[VAL_24]], %[[VAL_11]] : index -// CHECK: %[[VAL_27:.*]] = memref.subview %[[VAL_16]][0, 0] {{\[}}%[[VAL_25]], %[[VAL_26]]] [1, 1] : 
memref<128x256xbf16, strided<[1, ?], offset: ?>> to memref> -// CHECK: %[[VAL_28:.*]] = memref.subview %[[VAL_18]][0, 0] {{\[}}%[[VAL_25]], %[[VAL_26]]] [1, 1] : memref<128x256xbf16> to memref> // CHECK: %[[VAL_29:.*]] = arith.cmpi slt, %[[VAL_25]], %[[VAL_12]] : index // CHECK: %[[VAL_30:.*]] = arith.cmpi slt, %[[VAL_26]], %[[VAL_11]] : index // CHECK: %[[VAL_31:.*]] = arith.ori %[[VAL_29]], %[[VAL_30]] : i1 // CHECK: scf.if %[[VAL_31]] { // CHECK: linalg.fill ins(%[[VAL_15]] : bf16) outs(%[[VAL_18]] : memref<128x256xbf16>) // CHECK: } +// CHECK: %[[VAL_27:.*]] = memref.subview %[[VAL_16]][0, 0] {{\[}}%[[VAL_25]], %[[VAL_26]]] [1, 1] : memref<128x256xbf16, strided<[1, ?], offset: ?>> to memref> +// CHECK: %[[VAL_28:.*]] = memref.subview %[[VAL_18]][0, 0] {{\[}}%[[VAL_25]], %[[VAL_26]]] [1, 1] : memref<128x256xbf16> to memref> // CHECK: memref.copy %[[VAL_27]], %[[VAL_28]] : memref> to memref> // CHECK: %[[VAL_32:.*]] = bufferization.to_tensor %[[VAL_18]] restrict writable : memref<128x256xbf16> // CHECK: %[[VAL_33:.*]] = arith.index_cast %[[VAL_2]] : i32 to index diff --git a/test/Conversion/TritonToLinalg/masked_ldst_sitofp_other.mlir b/test/Conversion/TritonToLinalg/masked_ldst_sitofp_other.mlir index 12c6bd92..09008d6a 100644 --- a/test/Conversion/TritonToLinalg/masked_ldst_sitofp_other.mlir +++ b/test/Conversion/TritonToLinalg/masked_ldst_sitofp_other.mlir @@ -22,7 +22,7 @@ module { } } // CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32) { +// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[ARG_3:.*]]: i32, %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32) { // CHECK-DAG: %[[VAL_6:.*]] = arith.constant 7.000000e+00 : bf16 // CHECK-DAG: %[[VAL_7:.*]] = arith.constant 128 : index // CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [128], strides: [1] : memref<*xbf16> to memref<128xbf16, strided<[1]>> @@ -30,12 +30,12 @@ module { // CHECK: %[[VAL_10:.*]] = memref.alloc() : memref<128xbf16> // CHECK: %[[VAL_11:.*]] = arith.index_cast %[[VAL_2]] : i32 to index // CHECK: %[[VAL_12:.*]] = arith.minsi %[[VAL_11]], %[[VAL_7]] : index -// CHECK: %[[VAL_13:.*]] = memref.subview %[[VAL_8]][0] {{\[}}%[[VAL_12]]] [1] : memref<128xbf16, strided<[1]>> to memref> -// CHECK: %[[VAL_14:.*]] = memref.subview %[[VAL_10]][0] {{\[}}%[[VAL_12]]] [1] : memref<128xbf16> to memref> // CHECK: %[[VAL_15:.*]] = arith.cmpi slt, %[[VAL_12]], %[[VAL_7]] : index // CHECK: scf.if %[[VAL_15]] { // CHECK: linalg.fill ins(%[[VAL_6]] : bf16) outs(%[[VAL_10]] : memref<128xbf16>) // CHECK: } +// CHECK: %[[VAL_13:.*]] = memref.subview %[[VAL_8]][0] {{\[}}%[[VAL_12]]] [1] : memref<128xbf16, strided<[1]>> to memref> +// CHECK: %[[VAL_14:.*]] = memref.subview %[[VAL_10]][0] {{\[}}%[[VAL_12]]] [1] : memref<128xbf16> to memref> // CHECK: memref.copy %[[VAL_13]], %[[VAL_14]] : memref> to memref> // CHECK: %[[VAL_16:.*]] = bufferization.to_tensor %[[VAL_10]] restrict writable : memref<128xbf16> // CHECK: %[[VAL_17:.*]] = arith.index_cast %[[VAL_2]] : i32 to index diff --git a/test/Conversion/TritonToLinalg/triton_assert.mlir b/test/Conversion/TritonToLinalg/triton_assert.mlir new file mode 100644 index 00000000..5a6824c2 --- /dev/null +++ b/test/Conversion/TritonToLinalg/triton_assert.mlir @@ -0,0 +1,15 @@ +// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck 
%s +tt.func public @assert_lol(%arg0: i32) { + %c0_i32 = arith.constant 0 : i32 + %0 = arith.cmpi sgt, %arg0, %c0_i32 : i32 + %1 = tt.splat %0 : (i1) -> tensor<1xi1> + tt.assert %1, "lol", "", "", 0 : tensor<1xi1> + tt.return +} + +// CHECK: func.func @assert_lol(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32) { +// CHECK: %c0_i32 = arith.constant 0 : i32 +// CHECK: %0 = arith.cmpi sgt, %arg0, %c0_i32 : i32 +// CHECK: cf.assert %0, ".py:0: Assertion `lol` failed" +// CHECK: return +// CHECK: } diff --git a/test/Conversion/TritonToLinalg/wraparound_side_by_side.mlir b/test/Conversion/TritonToLinalg/wraparound_side_by_side.mlir new file mode 100644 index 00000000..ffe59823 --- /dev/null +++ b/test/Conversion/TritonToLinalg/wraparound_side_by_side.mlir @@ -0,0 +1,132 @@ +// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s + +module { + tt.func public @wrap_side_by_side_masked_loop_01234567(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) { + %cst = arith.constant dense<-9.900000e+01> : tensor<4x4xf32> + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c2_i32 = arith.constant 2 : i32 + %cst_0 = arith.constant dense<2> : tensor<4x1xi32> + %cst_1 = arith.constant dense<6> : tensor<4xi32> + %cst_2 = arith.constant dense<2> : tensor<4xi32> + %c4_i32 = arith.constant 4 : i32 + %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %1 = arith.addi %0, %cst_2 : tensor<4xi32> + %2 = arith.addi %0, %cst_1 : tensor<4xi32> + %3 = tt.splat %arg3 : (i32) -> tensor<4xi32> + %4 = arith.remsi %2, %3 : tensor<4xi32> + %5 = tt.expand_dims %1 {axis = 1 : i32} : (tensor<4xi32>) -> tensor<4x1xi32> + %6 = tt.splat %arg4 : (i32) -> tensor<4x1xi32> + %7 = arith.muli %5, %6 : tensor<4x1xi32> + %8 = tt.expand_dims %4 {axis = 0 : i32} : (tensor<4xi32>) -> tensor<1x4xi32> + %9 = tt.splat %arg5 : (i32) -> tensor<1x4xi32> + %10 = arith.muli %8, %9 : tensor<1x4xi32> + %11 = tt.broadcast %7 : (tensor<4x1xi32>) -> tensor<4x4xi32> + %12 = tt.broadcast %10 : (tensor<1x4xi32>) -> tensor<4x4xi32> + %13 = arith.addi %11, %12 : tensor<4x4xi32> + %14 = tt.splat %arg0 : (!tt.ptr) -> tensor<4x4x!tt.ptr> + %15 = tt.addptr %14, %13 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> + %16 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<4xi32>) -> tensor<4x1xi32> + %17 = tt.splat %arg6 : (i32) -> tensor<4x1xi32> + %18 = arith.muli %17, %16 : tensor<4x1xi32> + %19 = tt.splat %arg1 : (!tt.ptr) -> tensor<4x1x!tt.ptr> + %20 = tt.addptr %19, %18 : tensor<4x1x!tt.ptr>, tensor<4x1xi32> + %21 = tt.expand_dims %0 {axis = 0 : i32} : (tensor<4xi32>) -> tensor<1x4xi32> + %22 = tt.splat %arg7 : (i32) -> tensor<1x4xi32> + %23 = arith.muli %22, %21 : tensor<1x4xi32> + %24 = tt.broadcast %20 : (tensor<4x1x!tt.ptr>) -> tensor<4x4x!tt.ptr> + %25 = tt.broadcast %23 : (tensor<1x4xi32>) -> tensor<4x4xi32> + %26 = tt.addptr %24, %25 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> + %27 = arith.cmpi slt, %16, %cst_0 : tensor<4x1xi32> + %28 = tt.broadcast %27 : (tensor<4x1xi1>) -> tensor<4x4xi1> + %29 = arith.muli %arg4, %c4_i32 : i32 + %30 = tt.splat %29 : (i32) -> tensor<4x4xi32> + %31 = arith.muli %arg5, %c4_i32 : i32 + %32 = tt.splat %31 : (i32) -> tensor<4x4xi32> + %33:2 = scf.for %arg8 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg9 = %15, %arg10 = %26) -> (tensor<4x4x!tt.ptr>, tensor<4x4x!tt.ptr>) : i32 { + %34 = tt.load %arg9, %28, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<4x4xf32> + tt.store %arg10, %34 {cache 
= 1 : i32, evict = 1 : i32} : tensor<4x4xf32> + %35 = tt.addptr %arg9, %30 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> + %36 = tt.addptr %arg10, %32 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> + scf.yield %35, %36 : tensor<4x4x!tt.ptr>, tensor<4x4x!tt.ptr> + } + tt.return + } +} + +// CHECK-LABEL: func.func @wrap_side_by_side_masked_loop_01234567 +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32, [[PARAM_10_:%.+]]: i32, [[PARAM_11_:%.+]]: i32, [[PARAM_12_:%.+]]: i32, [[PARAM_13_:%.+]]: i32) { +// CHECK-DAG: [[CST_minus_9_dot_900000_:%.+]] = arith.constant -9.900000e+01 : f32 +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index +// CHECK-DAG: [[CST_6_:%.+]] = arith.constant 6 : index +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32 +// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32 +// CHECK-DAG: [[CST_4_1_:%.+]] = arith.constant 4 : i32 +// CHECK-DAG: [[CST_2_1_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_1_:%.+]] = arith.muli [[VAR_0_]], [[CST_2_1_]] : index +// CHECK-DAG: [[VAR_2_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index +// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index +// CHECK: [[VAR_4_:%.+]] = arith.muli [[VAR_3_]], [[CST_6_]] : index +// CHECK: [[VAR_5_:%.+]] = arith.addi [[VAR_1_]], [[VAR_4_]] : index +// CHECK: [[VAR_6_:%.+]] = arith.remsi [[VAR_5_]], [[VAR_2_]] : index +// CHECK-DAG: [[VAR_7_:%.+]] = arith.subi [[VAR_5_]], [[VAR_6_]] : index +// CHECK-DAG: [[VAR_8_:%.+]] = arith.addi [[VAR_6_]], [[CST_4_]] : index +// CHECK: [[VAR_9_:%.+]] = arith.minsi [[VAR_8_]], [[VAR_2_]] : index +// CHECK: [[VAR_10_:%.+]] = arith.subi [[VAR_9_]], [[VAR_6_]] : index +// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_5_]]{{.}}, sizes: {{.}}[[CST_4_]], [[VAR_10_]]{{.}}, strides: {{.}}[[VAR_0_]], [[VAR_3_]]{{.}} : memref<*xf32> to memref<4x?xf32, strided<[?, ?], offset: ?>> +// CHECK-DAG: [[VAR_11_:%.+]] = arith.subi [[CST_4_]], [[VAR_10_]] : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_7_]]{{.}}, sizes: {{.}}[[CST_4_]], [[VAR_11_]]{{.}}, strides: {{.}}[[VAR_0_]], [[VAR_3_]]{{.}} : memref<*xf32> to memref<4x?xf32, strided<[?, ?], offset: ?>> +// CHECK-DAG: [[VAR_12_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index +// CHECK-DAG: [[VAR_13_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index +// CHECK-DAG: [[VAR_14_:%.+]] = arith.muli [[PARAM_4_]], [[CST_4_1_]] : i32 +// CHECK-DAG: [[VAR_15_:%.+]] = arith.muli [[PARAM_5_]], [[CST_4_1_]] : i32 +// CHECK-DAG: [[VAR_16_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_17_:%.+]] = arith.muli [[VAR_16_]], [[CST_2_1_]] : index +// CHECK-DAG: [[VAR_18_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index +// CHECK-DAG: [[VAR_19_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_20_:%.+]] = arith.muli [[VAR_19_]], [[CST_6_]] : index +// CHECK-DAG: 
[[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[CST_0_]]{{.}}, sizes: [4, 4], strides: {{.}}[[VAR_12_]], [[VAR_13_]]{{.}} : memref<*xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_21_:%.+]]:6 = scf.for [[VAR_arg14_:%.+]] = [[CST_0_1_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg15_:%.+]] = [[VAR_reinterpret_cast_]], [[VAR_arg16_:%.+]] = [[VAR_reinterpret_cast_]]_1, [[VAR_arg17_:%.+]] = [[VAR_17_]], [[VAR_arg18_:%.+]] = [[CST_0_]], [[VAR_arg19_:%.+]] = [[CST_0_]], [[VAR_arg20_:%.+]] = [[VAR_reinterpret_cast_]]_0) -> (memref<4x?xf32, strided<[?, ?], offset: ?>>, memref<4x4xf32, strided<[?, ?], offset: ?>>, index, index, index, memref<4x?xf32, strided<[?, ?], offset: ?>>) : i32 { +// CHECK-DAG: [[VAR_reinterpret_cast_2_:%.+]] = memref.reinterpret_cast [[VAR_arg15_]] to offset: {{.}}[[CST_0_]]{{.}}, sizes: {{.}}[[CST_4_]], [[VAR_10_]]{{.}}, strides: {{.}}[[VAR_0_]], [[VAR_3_]]{{.}} : memref<4x?xf32, strided<[?, ?], offset: ?>> to memref<4x?xf32, strided<[?, ?], offset: ?>> +// CHECK-DAG: [[VAR_reinterpret_cast_3_:%.+]] = memref.reinterpret_cast [[VAR_arg20_]] to offset: {{.}}[[CST_0_]]{{.}}, sizes: {{.}}[[CST_4_]], [[VAR_11_]]{{.}}, strides: {{.}}[[VAR_0_]], [[VAR_3_]]{{.}} : memref<4x?xf32, strided<[?, ?], offset: ?>> to memref<4x?xf32, strided<[?, ?], offset: ?>> +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4x4xf32> +// CHECK: linalg.fill ins([[CST_minus_9_dot_900000_]] : f32) outs([[RES_]] : memref<4x4xf32>) +// CHECK: [[VAR_22_:%.+]] = arith.minsi [[VAR_10_]], [[CST_4_]] : index +// CHECK-DAG: [[VAR_23_:%.+]] = arith.subi [[CST_4_]], [[VAR_22_]] : index +// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_2_]][0, 0] [2, [[VAR_22_]]{{.}} {{.}}[[VAR_0_]], [[VAR_3_]]{{.}} : memref<4x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?], offset: ?>> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[VAR_reinterpret_cast_3_]][0, 0] [2, [[VAR_23_]]{{.}} {{.}}[[VAR_0_]], [[VAR_3_]]{{.}} : memref<4x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?], offset: ?>> +// CHECK-DAG: [[VAR_subview_5_:%.+]] = memref.subview [[RES_]][0, 0] [2, [[VAR_22_]]{{.}} {{.}}[[VAR_0_]], [[VAR_3_]]{{.}} : memref<4x4xf32> to memref<2x?xf32, strided<[?, ?]>> +// CHECK-DAG: [[VAR_subview_6_:%.+]] = memref.subview [[RES_]][0, [[VAR_22_]]{{.}} [2, [[VAR_23_]]{{.}} {{.}}[[VAR_0_]], [[VAR_3_]]{{.}} : memref<4x4xf32> to memref<2x?xf32, strided<[?, ?], offset: ?>> +// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_]]_5 : memref<2x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?]>> +// CHECK: memref.copy [[VAR_subview_4_]], [[VAR_subview_6_]] : memref<2x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?], offset: ?>> +// CHECK: [[VAR_24_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4x4xf32> +// CHECK: memref.tensor_store [[VAR_24_]], [[VAR_arg16_]] : memref<4x4xf32, strided<[?, ?], offset: ?>> +// CHECK: [[VAR_25_:%.+]] = arith.index_cast [[VAR_14_]] : i32 to index +// CHECK: [[VAR_26_:%.+]] = arith.addi [[VAR_arg17_]], [[VAR_25_]] : index +// CHECK: [[VAR_27_:%.+]] = arith.addi [[VAR_26_]], [[VAR_20_]] : index +// CHECK: [[VAR_28_:%.+]] = arith.remsi [[VAR_27_]], [[VAR_18_]] : index +// CHECK-DAG: [[VAR_29_:%.+]] = arith.subi [[VAR_27_]], [[VAR_28_]] : index +// CHECK-DAG: [[VAR_30_:%.+]] = arith.addi [[VAR_28_]], [[CST_4_]] : index 
+// CHECK: [[VAR_31_:%.+]] = arith.minsi [[VAR_30_]], [[VAR_18_]] : index
+// CHECK: [[VAR_32_:%.+]] = arith.subi [[VAR_31_]], [[VAR_28_]] : index
+// CHECK-DAG: [[VAR_reinterpret_cast_7_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_27_]]{{.}}, sizes: {{.}}[[CST_4_]], [[VAR_32_]]{{.}}, strides: {{.}}[[VAR_16_]], [[VAR_19_]]{{.}} : memref<*xf32> to memref<4x?xf32, strided<[?, ?], offset: ?>>
+// CHECK-DAG: [[VAR_33_:%.+]] = arith.subi [[CST_4_]], [[VAR_32_]] : index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_reinterpret_cast_8_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_29_]]{{.}}, sizes: {{.}}[[CST_4_]], [[VAR_33_]]{{.}}, strides: {{.}}[[VAR_16_]], [[VAR_19_]]{{.}} : memref<*xf32> to memref<4x?xf32, strided<[?, ?], offset: ?>>
+// CHECK-DAG: [[VAR_34_:%.+]] = arith.index_cast [[VAR_15_]] : i32 to index
+// CHECK: [[VAR_35_:%.+]] = arith.addi [[VAR_arg18_]], [[VAR_34_]] : index
+// CHECK: [[VAR_36_:%.+]] = arith.addi [[VAR_35_]], [[VAR_arg19_]] : index
+// CHECK: [[VAR_reinterpret_cast_9_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_36_]]{{.}}, sizes: [4, 4], strides: {{.}}[[VAR_12_]], [[VAR_13_]]{{.}} : memref<*xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>>
+// CHECK: scf.yield [[VAR_reinterpret_cast_7_]], [[VAR_reinterpret_cast_9_]], [[VAR_26_]], [[VAR_36_]], [[CST_0_]], [[VAR_reinterpret_cast_8_]] : memref<4x?xf32, strided<[?, ?], offset: ?>>, memref<4x4xf32, strided<[?, ?], offset: ?>>, index, index, index, memref<4x?xf32, strided<[?, ?], offset: ?>>
+// CHECK: }
+// CHECK: return
+// CHECK: }
diff --git a/test/Conversion/TritonToLinalg/wraparound_stacked.mlir b/test/Conversion/TritonToLinalg/wraparound_stacked.mlir
new file mode 100644
index 00000000..37676ab6
--- /dev/null
+++ b/test/Conversion/TritonToLinalg/wraparound_stacked.mlir
@@ -0,0 +1,129 @@
+// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s
+
+module {
+  tt.func public @wrap_stacked_masked_loop_01234567(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) {
+    %cst = arith.constant dense<-9.900000e+01> : tensor<4x4xf32>
+    %c1_i32 = arith.constant 1 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %c2_i32 = arith.constant 2 : i32
+    %cst_0 = arith.constant dense<3> : tensor<1x4xi32>
+    %cst_1 = arith.constant dense<3> : tensor<4xi32>
+    %cst_2 = arith.constant dense<2> : tensor<4xi32>
+    %c4_i32 = arith.constant 4 : i32
+    %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
+    %1 = arith.addi %0, %cst_2 : tensor<4xi32>
+    %2 = tt.splat %arg2 : (i32) -> tensor<4xi32>
+    %3 = arith.remsi %1, %2 : tensor<4xi32>
+    %4 = arith.addi %0, %cst_1 : tensor<4xi32>
+    %5 = tt.expand_dims %3 {axis = 1 : i32} : (tensor<4xi32>) -> tensor<4x1xi32>
+    %6 = tt.splat %arg4 : (i32) -> tensor<4x1xi32>
+    %7 = arith.muli %5, %6 : tensor<4x1xi32>
+    %8 = tt.expand_dims %4 {axis = 0 : i32} : (tensor<4xi32>) -> tensor<1x4xi32>
+    %9 = tt.splat %arg5 : (i32) -> tensor<1x4xi32>
+    %10 = arith.muli %8, %9 : tensor<1x4xi32>
+    %11 = tt.broadcast %7 : (tensor<4x1xi32>) -> tensor<4x4xi32>
+    %12 = tt.broadcast %10 : (tensor<1x4xi32>) -> tensor<4x4xi32>
+    %13 = arith.addi %11, %12 : tensor<4x4xi32>
+    %14 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<4x4x!tt.ptr<f32>>
+    %15 = tt.addptr %14, %13 : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
+    %16 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<4xi32>) -> tensor<4x1xi32>
+    %17 = tt.splat %arg6 : (i32) -> tensor<4x1xi32>
+    %18 = arith.muli %17, %16 : tensor<4x1xi32>
+    %19 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<4x1x!tt.ptr<f32>>
+    %20 = tt.addptr %19, %18 : tensor<4x1x!tt.ptr<f32>>, tensor<4x1xi32>
+    %21 = tt.expand_dims %0 {axis = 0 : i32} : (tensor<4xi32>) -> tensor<1x4xi32>
+    %22 = tt.splat %arg7 : (i32) -> tensor<1x4xi32>
+    %23 = arith.muli %22, %21 : tensor<1x4xi32>
+    %24 = tt.broadcast %20 : (tensor<4x1x!tt.ptr<f32>>) -> tensor<4x4x!tt.ptr<f32>>
+    %25 = tt.broadcast %23 : (tensor<1x4xi32>) -> tensor<4x4xi32>
+    %26 = tt.addptr %24, %25 : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
+    %27 = arith.cmpi slt, %21, %cst_0 : tensor<1x4xi32>
+    %28 = tt.broadcast %27 : (tensor<1x4xi1>) -> tensor<4x4xi1>
+    %29 = arith.muli %arg5, %c4_i32 : i32
+    %30 = tt.splat %29 : (i32) -> tensor<4x4xi32>
+    %31:2 = scf.for %arg8 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg9 = %15, %arg10 = %26) -> (tensor<4x4x!tt.ptr<f32>>, tensor<4x4x!tt.ptr<f32>>) : i32 {
+      %32 = tt.load %arg9, %28, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<4x4xf32>
+      tt.store %arg10, %32 {cache = 1 : i32, evict = 1 : i32} : tensor<4x4xf32>
+      %33 = tt.addptr %arg9, %30 : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
+      %34 = tt.addptr %arg10, %30 : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
+      scf.yield %33, %34 : tensor<4x4x!tt.ptr<f32>>, tensor<4x4x!tt.ptr<f32>>
+    }
+    tt.return
+  }
+}
+
+// CHECK-LABEL: func.func @wrap_stacked_masked_loop_01234567
+// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32, [[PARAM_10_:%.+]]: i32, [[PARAM_11_:%.+]]: i32, [[PARAM_12_:%.+]]: i32, [[PARAM_13_:%.+]]: i32) {
+// CHECK-DAG: [[CST_minus_9_dot_900000_:%.+]] = arith.constant -9.900000e+01 : f32
+// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index
+// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index
+// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32
+// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : i32
+// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32
+// CHECK-DAG: [[CST_4_1_:%.+]] = arith.constant 4 : i32
+// CHECK-DAG: [[CST_2_1_:%.+]] = arith.constant 2 : index
+// CHECK-DAG: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index
+// CHECK-DAG: [[VAR_1_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_2_:%.+]] = arith.muli [[VAR_1_]], [[CST_2_1_]] : index
+// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index
+// CHECK: [[VAR_4_:%.+]] = arith.muli [[VAR_3_]], [[CST_3_]] : index
+// CHECK: [[VAR_5_:%.+]] = arith.addi [[VAR_2_]], [[VAR_4_]] : index
+// CHECK-DAG: [[VAR_6_:%.+]] = arith.remsi [[VAR_5_]], [[VAR_1_]] : index
+// CHECK-DAG: [[VAR_7_:%.+]] = arith.muli [[VAR_0_]], [[VAR_1_]] : index
+// CHECK: [[VAR_8_:%.+]] = arith.addi [[VAR_7_]], [[VAR_6_]] : index
+// CHECK: [[VAR_9_:%.+]] = arith.subi [[VAR_8_]], [[VAR_5_]] : index
+// CHECK: [[VAR_10_:%.+]] = arith.divsi [[VAR_9_]], [[VAR_1_]] : index
+// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_5_]]{{.}}, sizes: {{.}}[[VAR_10_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_1_]], [[VAR_3_]]{{.}} : memref<*xf32> to memref<?x4xf32, strided<[?, ?], offset: ?>>
+// CHECK-DAG: [[VAR_11_:%.+]] = arith.subi [[CST_4_]], [[VAR_10_]] : index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_6_]]{{.}}, sizes: {{.}}[[VAR_11_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_1_]], [[VAR_3_]]{{.}} : memref<*xf32> to memref<?x4xf32, strided<[?, ?], offset: ?>>
+// CHECK-DAG: [[VAR_12_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index
+// CHECK-DAG: [[VAR_13_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index
+// CHECK-DAG: [[VAR_14_:%.+]] = arith.muli [[PARAM_5_]], [[CST_4_1_]] : i32
+// CHECK-DAG: [[VAR_15_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index
+// CHECK-DAG: [[VAR_16_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_17_:%.+]] = arith.muli [[VAR_16_]], [[CST_2_1_]] : index
+// CHECK-DAG: [[VAR_18_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_19_:%.+]] = arith.muli [[VAR_18_]], [[CST_3_]] : index
+// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[CST_0_]]{{.}}, sizes: [4, 4], strides: {{.}}[[VAR_12_]], [[VAR_13_]]{{.}} : memref<*xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>>
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_20_:%.+]]:6 = scf.for [[VAR_arg14_:%.+]] = [[CST_0_1_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg15_:%.+]] = [[VAR_reinterpret_cast_]], [[VAR_arg16_:%.+]] = [[VAR_reinterpret_cast_]]_1, [[VAR_arg17_:%.+]] = [[VAR_17_]], [[VAR_arg18_:%.+]] = [[CST_0_]], [[VAR_arg19_:%.+]] = [[CST_0_]], [[VAR_arg20_:%.+]] = [[VAR_reinterpret_cast_]]_0) -> (memref<?x4xf32, strided<[?, ?], offset: ?>>, memref<4x4xf32, strided<[?, ?], offset: ?>>, index, index, index, memref<?x4xf32, strided<[?, ?], offset: ?>>) : i32 {
+// CHECK-DAG: [[VAR_reinterpret_cast_2_:%.+]] = memref.reinterpret_cast [[VAR_arg15_]] to offset: {{.}}[[CST_0_]]{{.}}, sizes: {{.}}[[VAR_10_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_1_]], [[VAR_3_]]{{.}} : memref<?x4xf32, strided<[?, ?], offset: ?>> to memref<?x4xf32, strided<[?, ?], offset: ?>>
+// CHECK-DAG: [[VAR_reinterpret_cast_3_:%.+]] = memref.reinterpret_cast [[VAR_arg20_]] to offset: {{.}}[[CST_0_]]{{.}}, sizes: {{.}}[[VAR_11_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_1_]], [[VAR_3_]]{{.}} : memref<?x4xf32, strided<[?, ?], offset: ?>> to memref<?x4xf32, strided<[?, ?], offset: ?>>
+// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4x4xf32>
+// CHECK: linalg.fill ins([[CST_minus_9_dot_900000_]] : f32) outs([[RES_]] : memref<4x4xf32>)
+// CHECK: [[VAR_21_:%.+]] = arith.minsi [[VAR_10_]], [[CST_4_]] : index
+// CHECK-DAG: [[VAR_22_:%.+]] = arith.subi [[CST_4_]], [[VAR_21_]] : index
+// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_2_]][0, 0] {{.}}[[VAR_21_]], 3] {{.}}[[VAR_1_]], [[VAR_3_]]{{.}} : memref<?x4xf32, strided<[?, ?], offset: ?>> to memref<?x3xf32, strided<[?, ?], offset: ?>>
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[VAR_reinterpret_cast_3_]][0, 0] {{.}}[[VAR_22_]], 3] {{.}}[[VAR_1_]], [[VAR_3_]]{{.}} : memref<?x4xf32, strided<[?, ?], offset: ?>> to memref<?x3xf32, strided<[?, ?], offset: ?>>
+// CHECK-DAG: [[VAR_subview_5_:%.+]] = memref.subview [[RES_]][0, 0] {{.}}[[VAR_21_]], 3] {{.}}[[VAR_1_]], [[VAR_3_]]{{.}} : memref<4x4xf32> to memref<?x3xf32, strided<[?, ?]>>
+// CHECK-DAG: [[VAR_subview_6_:%.+]] = memref.subview [[RES_]]{{.}}[[VAR_21_]], 0] {{.}}[[VAR_22_]], 3] {{.}}[[VAR_1_]], [[VAR_3_]]{{.}} : memref<4x4xf32> to memref<?x3xf32, strided<[?, ?], offset: ?>>
+// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_]]_5 : memref<?x3xf32, strided<[?, ?], offset: ?>> to memref<?x3xf32, strided<[?, ?]>>
+// CHECK: memref.copy [[VAR_subview_4_]], [[VAR_subview_6_]] : memref<?x3xf32, strided<[?, ?], offset: ?>> to memref<?x3xf32, strided<[?, ?], offset: ?>>
+// CHECK: [[VAR_23_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4x4xf32>
+// CHECK: memref.tensor_store [[VAR_23_]], [[VAR_arg16_]] : memref<4x4xf32, strided<[?, ?], offset: ?>>
+// CHECK: [[VAR_24_:%.+]] = arith.index_cast [[VAR_14_]] : i32 to index
+// CHECK: [[VAR_25_:%.+]] = arith.addi [[VAR_arg17_]], [[VAR_24_]] : index
+// CHECK: [[VAR_26_:%.+]] = arith.addi [[VAR_25_]], [[VAR_19_]] : index
+// CHECK-DAG: [[VAR_27_:%.+]] = arith.remsi [[VAR_26_]], [[VAR_16_]] : index
+// CHECK-DAG: [[VAR_28_:%.+]] = arith.muli [[VAR_15_]], [[VAR_16_]] : index
+// CHECK: [[VAR_29_:%.+]] = arith.addi [[VAR_28_]], [[VAR_27_]] : index
+// CHECK: [[VAR_30_:%.+]] = arith.subi [[VAR_29_]], [[VAR_26_]] : index
+// CHECK: [[VAR_31_:%.+]] = arith.divsi [[VAR_30_]], [[VAR_16_]] : index
+// CHECK-DAG: [[VAR_reinterpret_cast_7_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_26_]]{{.}}, sizes: {{.}}[[VAR_31_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_16_]], [[VAR_18_]]{{.}} : memref<*xf32> to memref<?x4xf32, strided<[?, ?], offset: ?>>
+// CHECK-DAG: [[VAR_32_:%.+]] = arith.subi [[CST_4_]], [[VAR_31_]] : index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_reinterpret_cast_8_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_27_]]{{.}}, sizes: {{.}}[[VAR_32_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_16_]], [[VAR_18_]]{{.}} : memref<*xf32> to memref<?x4xf32, strided<[?, ?], offset: ?>>
+// CHECK-DAG: [[VAR_33_:%.+]] = arith.index_cast [[VAR_14_]] : i32 to index
+// CHECK: [[VAR_34_:%.+]] = arith.addi [[VAR_arg18_]], [[VAR_33_]] : index
+// CHECK: [[VAR_35_:%.+]] = arith.addi [[VAR_34_]], [[VAR_arg19_]] : index
+// CHECK: [[VAR_reinterpret_cast_9_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_35_]]{{.}}, sizes: [4, 4], strides: {{.}}[[VAR_12_]], [[VAR_13_]]{{.}} : memref<*xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>>
+// CHECK: scf.yield [[VAR_reinterpret_cast_7_]], [[VAR_reinterpret_cast_9_]], [[VAR_25_]], [[VAR_35_]], [[CST_0_]], [[VAR_reinterpret_cast_8_]] : memref<?x4xf32, strided<[?, ?], offset: ?>>, memref<4x4xf32, strided<[?, ?], offset: ?>>, index, index, index, memref<?x4xf32, strided<[?, ?], offset: ?>>
+// CHECK: }
+// CHECK: return
+// CHECK: }
diff --git a/test/Conversion/TritonToLinalg/wraparound_unsupported_add_offset.mlir b/test/Conversion/TritonToLinalg/wraparound_unsupported_add_offset.mlir
new file mode 100644
index 00000000..b0bcfd74
--- /dev/null
+++ b/test/Conversion/TritonToLinalg/wraparound_unsupported_add_offset.mlir
@@ -0,0 +1,57 @@
+// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s
+// XFAIL: *
+// We currently do not support this kind of modulo pattern:
+// (a + arange(0, K)) % M
+module {
+  tt.func public @wrap_side_by_side_masked_loop_01234567(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) {
+    %cst = arith.constant dense<-9.900000e+01> : tensor<4x4xf32>
+    %c1_i32 = arith.constant 1 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %c2_i32 = arith.constant 2 : i32
+    %cst_0 = arith.constant dense<2> : tensor<4x1xi32>
+    %cst_1 = arith.constant dense<6> : tensor<4xi32>
+    %cst_2 = arith.constant dense<2> : tensor<4xi32>
+    %c4_i32 = arith.constant 4 : i32
+    %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
+    %1 = arith.addi %0, %cst_2 : tensor<4xi32>
+    %2 = tt.splat %arg3 : (i32) -> tensor<4xi32>
+    %3 = arith.remsi %0, %2 : tensor<4xi32>
+    %4 = arith.addi %3, %cst_1 : tensor<4xi32>
+    %5 = tt.expand_dims %1 {axis = 1 : i32} : (tensor<4xi32>) -> tensor<4x1xi32>
+    %6 = tt.splat %arg4 : (i32) -> tensor<4x1xi32>
+    %7 = arith.muli %5, %6 : tensor<4x1xi32>
+    %8 = tt.expand_dims %4 {axis = 0 : i32} : (tensor<4xi32>) -> tensor<1x4xi32>
+    %9 = tt.splat %arg5 : (i32) -> tensor<1x4xi32>
+    %10 = arith.muli %8, %9 : tensor<1x4xi32>
+    %11 = tt.broadcast %7 : (tensor<4x1xi32>) -> tensor<4x4xi32>
+    %12 = tt.broadcast %10 : (tensor<1x4xi32>) -> tensor<4x4xi32>
+    %13 = arith.addi %11, %12 : tensor<4x4xi32>
+    %14 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<4x4x!tt.ptr<f32>>
+    %15 = tt.addptr %14, %13 : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
+    %16 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<4xi32>) -> tensor<4x1xi32>
+    %17 = tt.splat %arg6 : (i32) -> tensor<4x1xi32>
+    %18 = arith.muli %17, %16 : tensor<4x1xi32>
+    %19 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<4x1x!tt.ptr<f32>>
+    %20 = tt.addptr %19, %18 : tensor<4x1x!tt.ptr<f32>>, tensor<4x1xi32>
+    %21 = tt.expand_dims %0 {axis = 0 : i32} : (tensor<4xi32>) -> tensor<1x4xi32>
+    %22 = tt.splat %arg7 : (i32) -> tensor<1x4xi32>
+    %23 = arith.muli %22, %21 : tensor<1x4xi32>
+    %24 = tt.broadcast %20 : (tensor<4x1x!tt.ptr<f32>>) -> tensor<4x4x!tt.ptr<f32>>
+    %25 = tt.broadcast %23 : (tensor<1x4xi32>) -> tensor<4x4xi32>
+    %26 = tt.addptr %24, %25 : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
+    %27 = arith.cmpi slt, %16, %cst_0 : tensor<4x1xi32>
+    %28 = tt.broadcast %27 : (tensor<4x1xi1>) -> tensor<4x4xi1>
+    %29 = arith.muli %arg4, %c4_i32 : i32
+    %30 = tt.splat %29 : (i32) -> tensor<4x4xi32>
+    %31 = arith.muli %arg5, %c4_i32 : i32
+    %32 = tt.splat %31 : (i32) -> tensor<4x4xi32>
+    %33:2 = scf.for %arg8 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg9 = %15, %arg10 = %26) -> (tensor<4x4x!tt.ptr<f32>>, tensor<4x4x!tt.ptr<f32>>) : i32 {
+      %34 = tt.load %arg9, %28, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<4x4xf32>
+      tt.store %arg10, %34 {cache = 1 : i32, evict = 1 : i32} : tensor<4x4xf32>
+      %35 = tt.addptr %arg9, %30 : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
+      %36 = tt.addptr %arg10, %32 : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
+      scf.yield %35, %36 : tensor<4x4x!tt.ptr<f32>>, tensor<4x4x!tt.ptr<f32>>
+    }
+    tt.return
+  }
+}
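
Note (not part of the patch): the wraparound tests above all start from Triton-level offsets that are wrapped with a modulo before being turned into pointers. For reference, the sketch below shows the kind of kernel source that produces such IR; it is illustrative only, and the kernel name, arguments (wrap_rows_kernel, M, N, stride_m, stride_n, BLOCK_M, BLOCK_N), and constants are hypothetical rather than taken from any test in this change. It mirrors the stacked case: row offsets are taken modulo M, column offsets get a plain additive offset, and the masked load uses -99.0 as the fill value.

    # Illustrative sketch only: a Triton kernel with modulo-wrapped row offsets,
    # similar in shape to the IR exercised by the wraparound lit tests above.
    import triton
    import triton.language as tl

    @triton.jit
    def wrap_rows_kernel(in_ptr, out_ptr, M, N,
                         stride_m, stride_n,
                         BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
        # Row offsets wrap around to the start of the buffer instead of
        # running past the end; column offsets are simply shifted.
        rows = (2 + tl.arange(0, BLOCK_M)) % M
        cols = 3 + tl.arange(0, BLOCK_N)
        ptrs = in_ptr + rows[:, None] * stride_m + cols[None, :] * stride_n
        mask = cols[None, :] < N
        vals = tl.load(ptrs, mask=mask, other=-99.0)
        out_ptrs = out_ptr + tl.arange(0, BLOCK_M)[:, None] * stride_m \
                           + tl.arange(0, BLOCK_N)[None, :] * stride_n
        tl.store(out_ptrs, vals)

For an access like this, the wrapped dimension is what shows up in the CHECK lines above as a pair of memref.reinterpret_cast ops plus two subview/copy pairs: one chunk covering the in-bounds tail of the buffer and a second chunk covering the part that wraps back to the beginning.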