Skip to content

Commit

Permalink
Update triton-shared 10/2023 (#9)
Browse files Browse the repository at this point in the history
This update includes various fixes and the following features:
- preliminary support for block-pointers
  - block-pointers are lowered to memref buffers similarly to traditional triton pointer loads
- support lowering for `triton.get_num_programs` by introducing extra arguments when launching triton kernels similarly to `triton.get_program_id`
- improved lowering of `triton.reduce`, which previously only supported float values
- preliminary support for pointer arithmetic involving the modulo operator
  - the modulo operator use case is seen in the tutorial matmul example, where pointer offsets are being modded to prevent loading out-of-bound values; in such case, these values are wrapped around to the beginning of the buffer
  • Loading branch information
nhat-nguyen authored Oct 12, 2023
1 parent 40584de commit 959825b
Show file tree
Hide file tree
Showing 40 changed files with 2,171 additions and 368 deletions.
12 changes: 12 additions & 0 deletions include/triton-shared/Analysis/MaskAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

#include "triton/Dialect/Triton/IR/Dialect.h"

#include <utility>

namespace mlir {

class ConversionPatternRewriter;
Expand Down Expand Up @@ -64,6 +66,16 @@ struct MaskState {
memref::SubViewOp getSubview(Value source, const Location loc,
ConversionPatternRewriter &rewriter) const;

std::pair<memref::SubViewOp, memref::SubViewOp>
getSideBySideSubviews(memref::ReinterpretCastOp chunk1,
memref::ReinterpretCastOp chunk2, const Location loc,
ConversionPatternRewriter &rewriter) const;

std::pair<memref::SubViewOp, memref::SubViewOp>
getStackedSubviews(memref::ReinterpretCastOp chunk1,
memref::ReinterpretCastOp chunk2, const Location loc,
ConversionPatternRewriter &rewriter) const;

private:
// -------
// Utility functions to operate on MaskState
Expand Down
19 changes: 13 additions & 6 deletions include/triton-shared/Analysis/OpFoldResultUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,41 @@

namespace mlir {

class ConversionPatternRewriter;
class OpBuilder;

// Return integer if ofr is an IntegerAttr. Note that this function differs
// from getConstantIntValue, which returns an integer if ofr is the constant
// result of an operation too.
std::optional<int64_t> getIntAttr(const OpFoldResult ofr);

// Create a value of index type if necessary from an OpFoldResult.
Value ofrToIndexValue(const OpFoldResult ofr, const Location loc, OpBuilder &b);

// Create a vector of values of index type if necessary from an array of
// OpFoldResults.
SmallVector<Value> ofrsToIndexValues(ArrayRef<OpFoldResult> ofrs,
const Location loc, OpBuilder &b);

// Process addition of two OFRs. If both OFRs are Integer Attributes, result
// is an Integer Attribute. Otherwise, insert the arith.addi instruction if
// needed and use its result Value.
OpFoldResult addOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
const Location loc, ConversionPatternRewriter &rewriter);
const Location loc, OpBuilder &b);

// Produce result = lhs - rhs. If both OFRs are Integer Attributes, result
// is an Integer Attribute. Otherwise, insert the arith.subi instruction if
// needed and use its result Value.
OpFoldResult subOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
const Location loc, ConversionPatternRewriter &rewriter);
const Location loc, OpBuilder &b);

// Process multiplication of two OFRs. If both OFRs are Integer Attributes,
// result is an Integer Attribute. Otherwise, insert the arith.muli
// instruction if needed and use its result Value.
OpFoldResult mulOFRValue(const OpFoldResult lhs, const Value rhs,
const Location loc,
ConversionPatternRewriter &rewriter);
const Location loc, OpBuilder &b);

OpFoldResult minOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
const Location loc, ConversionPatternRewriter &rewriter);
const Location loc, OpBuilder &b);

} // namespace mlir

Expand Down
66 changes: 62 additions & 4 deletions include/triton-shared/Analysis/PtrAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,47 @@ class ConversionPatternRewriter;

namespace triton {

// Describes a dimension whose pointer offsets wrap around a buffer:
// `size` is the extent (in elements) at which offsets wrap, and `offset`
// is the starting offset within that buffer. Used to model Triton pointer
// arithmetic involving the modulo operator, where out-of-bound offsets
// wrap back to the beginning of the buffer.
struct ModuloState {
  Value size;
  OpFoldResult offset;

  // Prefer `= default` over a user-provided empty body: idiomatic, and it
  // keeps the constructor eligible for compiler optimizations on
  // value-initialization.
  ModuloState() = default;
  ModuloState(Value size, OpFoldResult offset) : size{size}, offset{offset} {}

  // Attribute name used to tag wraparound memory accesses, plus the two
  // recognized attribute values distinguishing the supported wraparound
  // layouts (row-wise "stacked" vs column-wise "side_by_side").
  static constexpr char const *WraparoundAttr = "ptr.wraparound_type";
  static constexpr char const *WraparoundStacked = "stacked";
  static constexpr char const *WraparoundSideBySide = "side_by_side";
};

// Data structure used to decode pointer arithmetics and potentially to be
// translate it into memref. offsets, sizes, and strides are in unit of elements
// in a linearly laid-out memory, which is the same as pointer arithmetic
// operations in Triton language. scalar is a shortcut used when the entire
// state describes a single scalar value. source is the base pointer.
struct PtrState {
class PtrState {

OpFoldResult
accumulateTargetOffset(Location loc,
ConversionPatternRewriter &rewriter) const;

public:
SmallVector<OpFoldResult> offsets;
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides;

SmallVector<std::optional<ModuloState>> modulos;

Value source;
Value scalar;

int64_t getRank() const;

bool isEmpty() const;

bool hasModulo() const;

MemRefType getResultMemrefType(MLIRContext *context, int64_t offset,
ArrayRef<int64_t> resultShape) const;

// Process addition of two PtrStates.
void addState(const PtrState &lhsState, const PtrState &rhsState,
Location loc, ConversionPatternRewriter &rewriter);
Expand All @@ -48,9 +73,17 @@ struct PtrState {

// Produce a reinterpret cast based on the current PtrState. Additional
// instructions may be inserted in calculating the final offset.
memref::ReinterpretCastOp createCastOp(ArrayRef<int64_t> resultShape,
const Location loc,
ConversionPatternRewriter &rewriter);
memref::ReinterpretCastOp
createCastOp(ArrayRef<int64_t> resultShape, const Location loc,
ConversionPatternRewriter &rewriter) const;

SmallVector<memref::ReinterpretCastOp>
createSideBySideCastOps(ArrayRef<int64_t> resultShape, const Location loc,
ConversionPatternRewriter &rewriter) const;

SmallVector<memref::ReinterpretCastOp>
createStackedCastOps(ArrayRef<int64_t> resultShape, const Location loc,
ConversionPatternRewriter &rewriter) const;
};

class PtrAnalysis {
Expand Down Expand Up @@ -95,6 +128,16 @@ class PtrAnalysis {
ConversionPatternRewriter &rewriter,
const llvm::SmallDenseMap<Value, PtrState> &knownPtrs);

static void
visitOperandRem(arith::RemSIOp mulOp, PtrState &state, const Location loc,
ConversionPatternRewriter &rewriter,
const llvm::SmallDenseMap<Value, PtrState> &knownPtrs);

static void visitOperandUnrealizedCast(
UnrealizedConversionCastOp op, PtrState &state, const Location loc,
ConversionPatternRewriter &rewriter,
const llvm::SmallDenseMap<Value, PtrState> &knownPtrs);

// Operand is the result of make_range.
// Main assumptions:
// start, end, and shape are all statically known
Expand Down Expand Up @@ -156,6 +199,11 @@ class PtrAnalysis {
ConversionPatternRewriter &rewriter,
const llvm::SmallDenseMap<Value, PtrState> &knownPtrs);

static void visitOperandMakeTensorPtr(
triton::MakeTensorPtrOp makeTensorPtrOp, PtrState &state,
const Location loc, ConversionPatternRewriter &rewriter,
const llvm::SmallDenseMap<Value, PtrState> &knownPtrs);

// Operand is the result of addptr.
// Main assumptions:
// The ptr field should populate the source field
Expand All @@ -177,6 +225,16 @@ class PtrAnalysis {
const Location loc, ConversionPatternRewriter &rewriter,
const llvm::SmallDenseMap<Value, PtrState> &knownPtrs);

// Operand is the result of tt.advance.
// Main assumptions:
// The source of the tt.advance has been mapped to a reinterpret_cast
// Expected result:
// Directly grab all corresponding fields from reinterpret_cast.
// Add the offsets multiplied by the strides to the final offsets.
static void rewriteAdvanceOp(triton::AdvanceOp op,
ConversionPatternRewriter &rewriter,
llvm::SmallDenseMap<Value, PtrState> &knownPtrs);

// Parse the state of AddPtrOp, insert any instruction needed to
// calculate strides and offsets, build PtrState for this operand, and record
// PtrState for knownPtrs.
Expand Down
3 changes: 2 additions & 1 deletion include/triton-shared/Analysis/UseAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ struct UseInfo : public dataflow::AbstractSparseLattice {
}
case UseType::MixUse:
return ChangeResult::NoChange;
default:
llvm_unreachable("bad type");
}
llvm_unreachable("bad type");
}

ChangeResult meet(const AbstractSparseLattice &other) override {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ void populateTritonToLinalgCanonicalizationPatterns(
RewritePatternSet &patterns);

void populateTritonToLinalgConversionPatterns(TypeConverter &typeConverter,
RewritePatternSet &patterns);
RewritePatternSet &patterns,
unsigned int launchGridRank);

} // namespace triton
} // namespace mlir
Expand Down
115 changes: 114 additions & 1 deletion lib/Analysis/MaskAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,120 @@ MaskState::getSubview(Value source, const Location loc,
source, offsets, dims, strides);
}

// Emit a memref.subview of `src` covering the region described by
// `offsets`/`sizes`/`strides`, letting SubViewOp infer the result memref
// type from the source type.
static memref::SubViewOp createSubview(Value src, Location loc, OpBuilder &b,
                                       ArrayRef<OpFoldResult> offsets,
                                       ArrayRef<OpFoldResult> sizes,
                                       ArrayRef<OpFoldResult> strides) {
  auto sourceType = cast<MemRefType>(src.getType());
  auto inferredType =
      memref::SubViewOp::inferResultType(sourceType, offsets, sizes, strides);
  return b.create<memref::SubViewOp>(loc, cast<MemRefType>(inferredType), src,
                                     offsets, sizes, strides);
}

// Assume block1 wraps around and the remainder is block2.
//
// |----------------------|
// | | |
// | block2 | block1 |
// | | |
// |----------------------|
//
// Once we copy the chunks in order, the end result is block1 followed by
// block2.
//
// buffer_tmp:
//
// |----------------------|
// | | |
// | block1 | block2 |
// | | |
// |----------------------|
//
// Assume we have the following subview:
//
// +++++++++++++++++-------
// + + |
// + subview + |
// + + |
// +++++++++++++++++-------
//
// If we simply take the subview of `buffer_tmp`, this requires an extra buffer
// to just hold the temporary result.
//
// So we can subview into block1 and block2 directly. There are 2 cases:
// + subview only spans block1
// + subview spans both block1 and block2, creating sv1 and sv2 (illustrated
// below for case when we wrap around side-by-side)
//
// |----------------------------------------|
// | |
// | col2 col1 |
// |++++++--------| |+++++++++++++++
// | sv2 + block2 | | block1 & sv1 +
// |++++++--------| |+++++++++++++++
// | |
// |----------------------------------------|
//
// For simplicity, assume we only wrap around side-by-side.
//
// Let (row, col1) and (row, col2) be the dimensions of block1 and block2,
// respectively.
//
// Let (rowFull, colFull), (rowView1, colView1) and (rowView2, colView2) be the
// dimensions of the full subview, sv1, and sv2, respectively.
//
// + colView1 = min(colFull, col1)
// + colView2 = colFull - colView1
// + rowView1 = rowView2 = row = rowFull
std::pair<memref::SubViewOp, memref::SubViewOp>
MaskState::getSideBySideSubviews(memref::ReinterpretCastOp block1,
                                 memref::ReinterpretCastOp block2,
                                 const Location loc,
                                 ConversionPatternRewriter &rewriter) const {
  // Side-by-side wraparound is only modeled for rank-2 accesses.
  assert(block1.getResultRank() == 2 && block2.getResultRank() == 2 &&
         getRank() == 2);

  // Full extents of the masked region.
  const OpFoldResult rowFull = dims[0];
  const OpFoldResult colFull = dims[1];

  // Columns served by the first chunk, capped at its own width; the second
  // chunk covers whatever columns remain.
  const OpFoldResult colsFromBlock1 =
      minOFRs(block1.getMixedSizes()[1], colFull, loc, rewriter);
  const OpFoldResult colsFromBlock2 =
      subOFRs(colFull, colsFromBlock1, loc, rewriter);

  // Both subviews start at the origin of their chunk and share block1's
  // strides.
  SmallVector<OpFoldResult> zeroOffsets(getRank(), rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> sharedStrides = block1.getMixedStrides();

  auto view1 = createSubview(block1.getResult(), loc, rewriter, zeroOffsets,
                             {rowFull, colsFromBlock1}, sharedStrides);
  auto view2 = createSubview(block2.getResult(), loc, rewriter, zeroOffsets,
                             {rowFull, colsFromBlock2}, sharedStrides);
  return std::make_pair(view1, view2);
}

// Stacked wraparound: block1 wraps around row-wise and block2 holds the
// remaining rows, so the masked region is split vertically between the two
// chunks. Mirrors getSideBySideSubviews with rows and columns exchanged.
std::pair<memref::SubViewOp, memref::SubViewOp> MaskState::getStackedSubviews(
    memref::ReinterpretCastOp block1, memref::ReinterpretCastOp block2,
    const Location loc, ConversionPatternRewriter &rewriter) const {
  // Stacked wraparound is only modeled for rank-2 accesses.
  assert(block1.getResultRank() == 2 && block2.getResultRank() == 2 &&
         getRank() == 2);

  // Full extents of the masked region.
  const OpFoldResult rowFull = dims[0];
  const OpFoldResult colFull = dims[1];

  // Rows served by the first chunk, capped at its own height; the second
  // chunk covers whatever rows remain.
  const OpFoldResult rowsFromBlock1 =
      minOFRs(block1.getMixedSizes()[0], rowFull, loc, rewriter);
  const OpFoldResult rowsFromBlock2 =
      subOFRs(rowFull, rowsFromBlock1, loc, rewriter);

  // Both subviews start at the origin of their chunk and share block1's
  // strides.
  SmallVector<OpFoldResult> zeroOffsets(getRank(), rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> sharedStrides = block1.getMixedStrides();

  auto view1 = createSubview(block1.getResult(), loc, rewriter, zeroOffsets,
                             {rowsFromBlock1, colFull}, sharedStrides);
  auto view2 = createSubview(block2.getResult(), loc, rewriter, zeroOffsets,
                             {rowsFromBlock2, colFull}, sharedStrides);
  return std::make_pair(view1, view2);
}

LogicalResult MaskState::addStateScalar(const MaskState &state,
const OpFoldResult scalar, Location loc,
ConversionPatternRewriter &rewriter) {
Expand Down Expand Up @@ -132,7 +246,6 @@ LogicalResult MaskState::parseConstant(arith::ConstantOp constOp,
auto constAttr = rewriter.getIndexAttr(value.getSExtValue());
auto op = arith::ConstantOp::materialize(rewriter, constAttr,
rewriter.getIndexType(), loc);

this->scalar = op.getValue();
} else {
auto value = constOp.getValue().cast<IntegerAttr>().getInt();
Expand Down
Loading

0 comments on commit 959825b

Please sign in to comment.