Support tensor of indices as loop iter-arg in sequences of pointer arithmetic (#180)

This PR adds support for tensors of indices that are updated in each loop
iteration while also being used in pointer arithmetic sequences.

## Approach

As with pointer types, the PtrAnalysis pre-pass eagerly generates
`tts.get_structured_state` ops for tensors of integers. Importantly, we do not
need to know at this point whether these values will eventually be used in a
pointer arithmetic sequence: any values that are not will be removed later in
the process.
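To make this concrete, here is a rough sketch of what the pre-pass does for each candidate value (simplified and illustrative, not the exact pass code; the accessor `getStructured()` is the one generated from the op's result name):

```cpp
#include "mlir/IR/Builders.h"
#include "mlir/IR/Value.h"
#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"

// Simplified sketch: wrap a candidate value (pointer or tensor of indices)
// in tts.get_structured_state without checking how it is used.
static void wrapCandidate(mlir::OpBuilder &builder, mlir::Value v) {
  builder.setInsertionPointAfterValue(v);
  // The first result has the same type as `v`; the variadic results expose
  // the PtrState fields (offsets and strides).
  auto stateOp =
      builder.create<mlir::tts::GetStructuredStateOp>(v.getLoc(), v);
  // Route all existing users of `v` through the wrapper (except the
  // wrapper itself).
  v.replaceAllUsesExcept(stateOp.getStructured(), stateOp.getOperation());
}
```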

This approach can easily be extended to other kinds of values that might be
used in pointer arithmetic sequences. At a high level,
`tts.get_structured_state` can always be used to "wrap" a triton value. The op
returns two kinds of values: the first is always of the same type as the
wrapped value, while the remaining values expose the fields of `PtrState` that
are necessary for codegen in `scf.for`.
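This split is visible through the accessors generated from the result names (`$structured`, `$offsets`, `$strides` in the op definition further down); a minimal sketch, with includes as in the sketch above:

```cpp
// Sketch: the first result mirrors the wrapped value's type, while the
// variadic results carry the PtrState fields as index values.
void inspectState(mlir::tts::GetStructuredStateOp stateOp) {
  mlir::Value structured = stateOp.getStructured(); // same type as input
  mlir::ValueRange offsets = stateOp.getOffsets();  // index-typed offsets
  mlir::ValueRange strides = stateOp.getStrides();  // index-typed strides
  (void)structured; (void)offsets; (void)strides;
}
```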

Users of the original triton value are redirected to this first result. With
this approach, even if the original value ends up not being used in a pointer
arithmetic sequence, reverting the IR to its original form is trivial: delete
the `tts.get_structured_state` op and forward the original triton value to its
users again.
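The revert is then a purely local rewrite; a minimal sketch (includes as above):

```cpp
// Sketch: undo the wrapping for values that never participate in a
// pointer arithmetic sequence.
static void unwrap(mlir::tts::GetStructuredStateOp stateOp) {
  // Forward the original triton value back to all users of the wrapper...
  stateOp.getStructured().replaceAllUsesWith(stateOp->getOperand(0));
  // ...and delete the now-dead op.
  stateOp->erase();
}
```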

The remaining return values expose the fields of `PtrState` needed to generate
code in loops (offsets and strides). Within a loop, for every wrapped triton
value carried as an iter-arg at index `i`, the corresponding per-iteration
offsets are at index `i + 1` and the strides at index `i + 2`.
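For the common 1-D case (a single offset and a single stride per wrapped value), reading the state back inside the loop looks roughly like this sketch:

```cpp
#include "mlir/Dialect/SCF/IR/SCF.h"

// Sketch: recover the per-iteration PtrState pieces from the loop's
// region iter-args, given the index `i` of the wrapped value.
void readIterArgState(mlir::scf::ForOp forOp, size_t i) {
  mlir::Value wrapped = forOp.getRegionIterArgs()[i];     // same type as the triton value
  mlir::Value offset  = forOp.getRegionIterArgs()[i + 1]; // index-typed offset
  mlir::Value stride  = forOp.getRegionIterArgs()[i + 2]; // index-typed stride
  (void)wrapped; (void)offset; (void)stride;
}
```

This is the same layout the mask handling below relies on when it reads `argIndex + 1` out of the loop's iter-args.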

## Changes

+ Updated the pre-pass to insert `tts.get_structured_state` ops that
wrap tensors of indices
+ With tensors of indices now appearing in loops, we have to manually visit
the `tts.get_structured_state` ops to generate the ops that update `PtrState`.
Previously this was unnecessary because triton pointers always have a
`tt.addptr` at the end of each loop, right before the values are yielded,
which triggers generation of the state-update ops
+ Improved the logic for determining whether a loop iter-arg should have its
`PtrState` updated: we do a BFS-like scan starting from the return values of
`tts.get_structured_state` ops to determine whether an iter-arg originates
from a value that may need its `PtrState` populated (see the sketch after
this list)
+ Added preliminary support for mask sequences being updated in a loop; this
is a bit of a hack and will need a more robust implementation if these use
cases appear more frequently
+ Added tests for various scenarios
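For reference, here is a simplified sketch of the BFS-like scan (the actual logic lives in `PtrAnalysis::initializeMaybeStructuredArgs`; the standalone helper shown here is illustrative only):

```cpp
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallVector.h"
#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"

// Sketch: seed a worklist with the results of all tts.get_structured_state
// ops, then follow values through scf.for init-args, iter-args, results,
// and yields. Every value reached may need its PtrState populated.
void collectMaybeStructuredArgs(
    mlir::Operation *root, mlir::DenseSet<mlir::Value> &maybeStructuredArgs) {
  llvm::SmallVector<mlir::Value> worklist;
  root->walk([&](mlir::tts::GetStructuredStateOp stateOp) {
    worklist.push_back(stateOp.getStructured());
  });
  while (!worklist.empty()) {
    mlir::Value v = worklist.pop_back_val();
    if (!maybeStructuredArgs.insert(v).second)
      continue; // already visited
    for (mlir::OpOperand &use : v.getUses()) {
      if (auto forOp = llvm::dyn_cast<mlir::scf::ForOp>(use.getOwner())) {
        // `v` is an init-arg: track the tied iter-arg and loop result.
        unsigned idx = use.getOperandNumber() - 3; // skip lb, ub, step
        worklist.push_back(forOp.getRegionIterArgs()[idx]);
        worklist.push_back(forOp->getResult(idx));
      } else if (auto yld =
                     llvm::dyn_cast<mlir::scf::YieldOp>(use.getOwner())) {
        if (auto forOp =
                llvm::dyn_cast<mlir::scf::ForOp>(yld->getParentOp())) {
          // `v` is yielded: the tied iter-arg and result carry it onward.
          unsigned idx = use.getOperandNumber();
          worklist.push_back(forOp.getRegionIterArgs()[idx]);
          worklist.push_back(forOp->getResult(idx));
        }
      }
    }
  }
}
```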
nhat-nguyen authored Oct 15, 2024
1 parent 5bd61a0 commit 177a624
Showing 16 changed files with 880 additions and 91 deletions.
4 changes: 4 additions & 0 deletions include/triton-shared/Analysis/MaskAnalysis.h
@@ -14,6 +14,7 @@

#include "mlir/Support/LogicalResult.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "llvm/Support/LogicalResult.h"

#include <utility>

@@ -137,6 +138,9 @@ struct MaskState {
  // dimension that contains the range.
  LogicalResult parseExpandDims(triton::ExpandDimsOp expandDimsOp,
                                const Location loc, OpBuilder &builder);

  LogicalResult parseLoopIterArg(Value v, const Location loc,
                                 OpBuilder &builder);
};

} // namespace triton
9 changes: 5 additions & 4 deletions include/triton-shared/AnalysisStructured/PtrAnalysis.h
@@ -12,6 +12,8 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"

#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
@@ -90,11 +92,10 @@ class PtrAnalysis {
      scf::ForOp forOp, size_t ptrArgIndex, const PtrState &state,
      llvm::function_ref<Value(scf::ForOp op, size_t)> getReplacementVal);

public:
  using IndexMapSet = std::map<int, std::set<int>>;
  DenseSet<Value> maybeStructuredArgs;

  IndexMapSet levelToBlockArgIndex;
  int level = 0;

public:
  void initializeMaybeStructuredArgs(Operation *op);

  llvm::SmallDenseMap<Value, PtrState> knownPtrs;
@@ -8,7 +8,9 @@ def TritonToStructured : Pass<"triton-to-structured", "mlir::ModuleOp"> {
  let constructor = "triton::createTritonToStructuredPass()";
  let options = [
    Option<"runPrepassOnly", "run-prepass-only", "bool", /*default*/"false",
           "Only run the pre-processing pass which inserts tts.get_structured_state ops used in scf.for">,
    Option<"skipPrepass", "skip-prepass", "bool", /*default*/"false",
           "Skip the prepass">
  ];
}

@@ -126,11 +126,11 @@ def TTS_GetStructuredStateOp : TTS_Op<"get_structured_state", [AttrSizedResultSe
  let summary = "Placeholder for the structured pointer states computed during PtrAnalysis.";
  let description = "Used to pass the offsets and strides to scf.for op to simplify IR rewrites.";

  let arguments = (ins TT_PtrLike:$ptr);
  let results = (outs TT_PtrLike:$structuredPtr, Variadic<Index>:$offsets, Variadic<Index>:$strides);
  let arguments = (ins AnyTypeOf<[TT_PtrLike, I32Tensor]>:$input);
  let results = (outs AnyTypeOf<[TT_PtrLike, I32Tensor]>:$structured, Variadic<Index>:$offsets, Variadic<Index>:$strides);

  let builders = [
    OpBuilder<(ins "Value":$ptr)>,
    OpBuilder<(ins "Value":$input)>,
  ];

  let extraClassDeclaration = [{
72 changes: 68 additions & 4 deletions lib/Analysis/MaskAnalysis.cpp
@@ -8,12 +8,22 @@
#include "triton-shared/Analysis/MaskAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Support/LogicalResult.h"

#include "triton-shared/Analysis/OpFoldResultUtils.h"

#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

#include "mlir/Transforms/DialectConversion.h"

#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
#include <cassert>

namespace mlir {

namespace triton {
@@ -38,6 +48,8 @@ LogicalResult MaskState::parse(Value operand, const Location loc,
    return this->parseSplat(op, loc, builder);
  } else if (auto op = operand.getDefiningOp<triton::ExpandDimsOp>()) {
    return this->parseExpandDims(op, loc, builder);
  } else if (!operand.getDefiningOp()) {
    return this->parseLoopIterArg(operand, loc, builder);
  } else if (auto op = operand.getDefiningOp<arith::ExtSIOp>()) {
    return this->parseExtSI(op, loc, builder);
  } else {
@@ -109,8 +121,8 @@ static memref::SubViewOp createSubview(Value src, Location loc, OpBuilder &b,
// + + |
// +++++++++++++++++-------
//
// If we simply take the subview of `buffer_tmp`, this requires an extra buffer
// to just hold the temporary result.
// If we simply take the subview of `buffer_tmp`, this requires an extra
// buffer to just hold the temporary result.
//
// So we can subview into block1 and block2 directly. There are 2 cases:
// + subview only spans block1
@@ -131,8 +143,8 @@
// Let (row, col1) and (row, col2) be the dimensions of block1 and block2,
// respectively.
//
// Let (rowFull, colFull), (rowView1, colView1) and (rowView2, colView2) be the
// dimensions of the full subview, sv1, and sv2, respectively.
// Let (rowFull, colFull), (rowView1, colView1) and (rowView2, colView2) be
// the dimensions of the full subview, sv1, and sv2, respectively.
//
// + colView1 = min(colFull, col1)
// + colView2 = colFull - colView1
@@ -342,6 +354,58 @@ LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc,
  return success();
}

LogicalResult MaskState::parseLoopIterArg(Value v, const Location loc,
                                          OpBuilder &builder) {
  assert(!v.getDefiningOp());

  auto forOp = llvm::dyn_cast<scf::ForOp>(v.getParentRegion()->getParentOp());

  if (!forOp) {
    return failure();
  }

  // TODO: This implementation does not work with nested loops
  if (forOp->getParentOfType<scf::ForOp>()) {
    return failure();
  }

  auto it = llvm::find(forOp.getRegionIterArgs(), v);
  if (it == forOp.getRegionIterArgs().end()) {
    return failure();
  }

  auto argIndex = std::distance(forOp.getRegionIterArgs().begin(), it);
  auto initArg = forOp.getInitArgs()[argIndex];
  if (auto getStateOp = initArg.getDefiningOp<tts::GetStructuredStateOp>()) {
    auto tritonValue = getStateOp->getOperand(0);
    MaskState lhsState;
    if (failed(lhsState.parse(tritonValue, loc, builder))) {
      return failure();
    }

    // This is a bit of a hack!!
    //
    // The offsets and dimensions of a MaskState can now depend on a loop's
    // iter-arg.
    //
    // Because the PtrAnalysis's pre-pass already sets up the offsets,
    // we can create a new MaskState for each loop iteration by adding the
    // original MaskState with the current iter-arg, which is at
    // `argIndex + 1`.
    //
    // This will not work for nested loop scenarios, which would need a
    // more robust implementation.
    if (failed(this->addStateScalar(
            lhsState, forOp.getRegionIterArgs()[argIndex + 1], loc,
            builder))) {
      return failure();
    }

    return success();
  }

  return failure();
}

LogicalResult MaskState::parseMakeRange(triton::MakeRangeOp rangeOp,
                                        const Location loc,
                                        OpBuilder &builder) {