[BACKEND] Linear Layout with stmatrix part 2: support stmatrix for local_alloc ops #4763

Merged: 22 commits, merged on Oct 1, 2024
9 changes: 0 additions & 9 deletions include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h
@@ -57,15 +57,6 @@ class TargetInfoBase {
unsigned numLaneToReduce,
unsigned interleave) const = 0;

-// TODO (Keren): Remove this function once layout conversion using stmatrix is
-// handled by Linear Layout.
-virtual bool processReplicaUsingStMatrix(
-RewriterBase &rewriter, Location loc, Value smemBase,
-SmallVector<Value> &vals, RankedTensorType srcTy, Type elemTy,
-ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> origRepShape,
-ArrayRef<unsigned> outOrd, unsigned accumNumReplicates,
-int swizzleByteWidth = 0) const = 0;

virtual std::string getMulhiFuncName(Type resultElementTy) const = 0;
// Emits LLVM code with |rewriter| to print a message following the given
// format from the device. |formatStrStart| is the pointer to the start of
128 changes: 125 additions & 3 deletions include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h
@@ -113,12 +113,134 @@ LinearLayout chooseShemLayoutForRegToRegConversion(
// row0 reg[0-1] reg[4-5]
// row8 reg[2-3] reg[6-7]
//
// When `swizzleByteSize` is non-zero, the layout is constructed
// differently due to leading dimension offset and swizzling.
// There are two key concepts to understand:
//
// 1. Chunks: The leading dimension (i.e., the column dimension) is divided
// into chunks, where each chunk's size is determined by `swizzleByteSize`.
// 2. Swizzling within tiles: Each tile applies a swizzling pattern to its
// rows to optimize memory access.
//
// - Concept 1: Chunks
//
// In the swizzled layout, the leading dimension is strided by
// `swizzleByteSize`. This introduces the concept of a "chunk", where each chunk
// spans a certain number of columns.
//
// For a tile size of `stmatrix.x4` (16x16 elements), with each element being 16
// bits (2 bytes), each tile occupies 16 rows and 32 bytes per row (since 16
// elements * 2 bytes per element = 32 bytes per row).
//
// Given a `swizzleByteSize` of 128 bytes, the number of tiles per chunk can be
// calculated as:
//
//   Number of tiles per chunk = swizzleByteSize / (bytes per row)
//                             = 128 bytes / 32 bytes = 4 tiles
//
// Therefore, each chunk contains 4 tiles horizontally, spanning 64 columns
// (since each tile is 16 columns):
//
// col0-15 col16-31 col32-47 col48-63
// row0-15 tile0 tile1 tile2 tile3
//
// For a tensor of 128x128 elements (#rows x #columns) with 16-bit elements,
// the tensor can be divided into multiple chunks both horizontally and
// vertically. Chunks are laid out in memory in column-major order at chunk
// granularity, so chunk1's address immediately follows chunk0's.
//
// Assume we have 8 warps, and assign each warp a slice of 16 rows (the rows of
// one tile) by 128 columns (the width of two chunks), so that each warp handles
// one horizontal slice of the tensor. A sketch of the resulting chunk indexing
// follows the diagram below.
//
// The overall layout can be visualized as:
//
// |<- 128 * 128 bytes ->|<- 128 * 128 bytes ->|
// columns 0-63 columns 64-127
// warp0 | rows 0-15 chunk0 chunk8
// warp1 | rows 16-31 chunk1 chunk9
// warp2 | rows 32-47 chunk2 chunk10
// warp3 | rows 48-63 chunk3 chunk11
// warp4 | rows 64-79 chunk4 chunk12
// warp5 | rows 80-95 chunk5 chunk13
// warp6 | rows 96-111 chunk6 chunk14
// warp7 | rows 112-127 chunk7 chunk15
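//
// As a quick sanity check (a sketch for illustration, not part of this patch),
// the chunk index and chunk base address of an element at `(row, col)` can be
// computed as follows, assuming the 128x128 16-bit tensor and the 128-byte
// swizzle described above:
//
//   int chunkRows = 16, chunkCols = 64;         // 128 B per row / 2 B per elem
//   int chunksPerColumn = 128 / chunkRows;      // 8 chunks per chunk-column
//   int chunkIdx = (col / chunkCols) * chunksPerColumn + row / chunkRows;
//   int chunkBase = chunkIdx * chunkRows * 128; // bytes; chunk1 starts at 2048
//
// For example, `(row=16, col=0)` falls in chunk1 at byte offset 2048, which
// matches the `warp=1 -> offset=1024` basis (in elements, i.e., 2048 bytes)
// listed in the verification section below.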
//
// - Concept 2: Swizzling within tiles
//
// Within each 16x16 tile, rows are swizzled to optimize memory access patterns.
// This swizzling is similar to what's defined in `TritonGPUAttrDefs.td`, but
// applied at the level of each 16x16 tile rather than the entire tensor.
//
// Key parameters for swizzling:
//
// - `perPhase`: The number of consecutive rows that share the same swizzle
// phase (i.e., the same XOR pattern).
// - `maxPhase`: The total number of phases.
// - `vectorWidth`: The number of elements per vector, which is 8 in this case
// because `stmatrix` stores 8 contiguous elements per thread.
//
// The byte offset of an element at `(row, col)` within a tile is computed by
// XORing the element's vector index with a per-row phase:
//
//   phase  = (row / perPhase) % maxPhase
//   offset = row * swizzleByteSize
//            + (((col / vectorWidth) ^ phase) * vectorWidth
//               + col % vectorWidth) * elementSize
//
// where `elementSize` is the size of each element in bytes (2 bytes for 16-bit
// elements). For `col = 0` this reduces to:
//
//   offset = row * swizzleByteSize + (vectorWidth * phase) * elementSize
//
// For example, consider the element at index `(row=1, col=0)` in chunk0:
//
// Without swizzling:
//
// offset = row * swizzleByteSize + col * elementSize
// = 1 * 128 bytes + 0 * 2 bytes
// = 128 bytes
//
// With swizzling (assuming `perPhase=1`, `maxPhase=8`, `vectorWidth=8`):
//
//   phase  = (1 / 1) % 8 = 1
//   offset = row * swizzleByteSize + (vectorWidth * phase) * elementSize
//          = 1 * 128 bytes + (8 * 1) * 2 bytes
//          = 128 bytes + 16 bytes
//          = 144 bytes
//
// This swizzling ensures that elements are stored in a way that optimizes for
// memory bandwidth and reduces bank conflicts.
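//
// A minimal C++ sketch of this computation (a hypothetical helper for
// illustration only, not part of this patch):
//
//   int swizzledOffset(int row, int col, int perPhase, int maxPhase,
//                      int vectorWidth, int elementSize, int swizzleByteSize) {
//     int phase = (row / perPhase) % maxPhase;
//     int swizzledVec = (col / vectorWidth) ^ phase;
//     return row * swizzleByteSize +
//            (swizzledVec * vectorWidth + col % vectorWidth) * elementSize;
//   }
//
// Here `swizzledOffset(1, 0, 1, 8, 8, 2, 128)` returns 144, matching the
// worked example above.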
//
// - Verification through Linear Layout
//
// We can verify the offsets against the bases of the corresponding linear
// layout, where each element is 16 bits (2 bytes) and the offsets below are
// given in units of elements:
//
// - register=1 -> offset=1
// register=2 -> offset=2
// register=4 -> offset=4
// register=8 -> offset=16
// register=16 -> offset=32
// register=32 -> offset=8192
// - lane=1 -> offset=72
// lane=2 -> offset=144
// lane=4 -> offset=288
// lane=8 -> offset=512
// lane=16 -> offset=8
// - warp=1 -> offset=1024
// warp=2 -> offset=2048
// warp=4 -> offset=4096
//
// For index `(row=1, col=0)`, which corresponds to `reg=0` and `lane=1` in
// `warp=0`, the offset is calculated as 72 * 2 bytes = 144 bytes. The result
// matches our earlier calculation.
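//
// A linear layout maps an input index by XOR-combining the bases of the
// index's set bits. A sketch of that lookup for the `lane` dimension
// (illustrative only, assuming the bases above are stored in a plain array of
// element offsets):
//
//   int laneBases[5] = {72, 144, 288, 512, 8};
//   int offsetForLane(int lane) {
//     int offset = 0;
//     for (int b = 0; b < 5; ++b)
//       if (lane & (1 << b))
//         offset ^= laneBases[b];
//     return offset; // offsetForLane(1) == 72 elements == 144 bytes
//   }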
//
// TODO(Keren): We should replace tensorTy with a LinearLayout and the element
// bit width of the tensor in the future to support more flexible tensor
// encodings
-std::optional<LinearLayout> chooseStMatrixLayoutForRegToRegConversion(
-MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
-ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order);
std::optional<LinearLayout>
chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape,
ArrayRef<unsigned> order, int swizzleByteSize);
} // namespace mlir::triton::gpu

#endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H
18 changes: 6 additions & 12 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -215,15 +215,9 @@ struct ConvertLayoutOpConversion
if (repId != 0) {
barrier();
}
-auto successful = targetInfo.processReplicaUsingStMatrix(
-rewriter, loc, smemBase, vals, srcTy,
-getTypeConverter()->convertType(srcTy.getElementType()),
-paddedRepShape, origRepShape, outOrd, accumNumReplicates);
-if (!successful) {
-processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep,
-multiDimRepId, inVec, paddedRepShape, origRepShape,
-outOrd, vals, smemBase);
-}
processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep,
multiDimRepId, inVec, paddedRepShape, origRepShape, outOrd,
vals, smemBase);
barrier();
processReplica(loc, rewriter, /*stNotRd*/ false, dstTy, outNumCTAsEachRep,
multiDimRepId, outVec, paddedRepShape, origRepShape,
@@ -483,9 +477,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
// Input dims: [reg, lane, warp]
// Output dims: [offset, iteration]
std::optional<LinearLayout> shmemStoreLayout =
-chooseStMatrixLayoutForRegToRegConversion(
-ctx, op.getSrc().getType(), scratchConfig.repShape,
-scratchConfig.paddedRepShape, scratchConfig.order);
chooseStMatrixLayout(ctx, op.getSrc().getType(), scratchConfig.repShape,
scratchConfig.paddedRepShape, scratchConfig.order,
/*swizzleByteSize=*/0);
bool isStMatrix = shmemStoreLayout.has_value();
if (!isStMatrix) {
shmemStoreLayout = srcLayout.invertAndCompose(sharedLayout);
1 change: 0 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -116,7 +116,6 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
RankedTensorType dstTy = op.getType();
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();
-// TODO: do we need to check if src is shared ?
if (isa<SharedEncodingAttr>(srcLayout) &&
isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
dstLayout)) {
103 changes: 96 additions & 7 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -806,8 +806,8 @@ namespace {
// stmatrix. These restrictions are retained from legacy code, and we could
// relax some of them in the future.
bool canUseStMatrix(RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
-ArrayRef<unsigned> paddedRepShape,
-ArrayRef<unsigned> order) {
ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order,
int swizzleByteSize) {
auto mmaLayout =
mlir::dyn_cast<NvidiaMmaEncodingAttr>(tensorTy.getEncoding());
if (!mmaLayout || !mmaLayout.isHopper())
@@ -826,17 +826,87 @@ bool canUseStMatrix(RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
return false;
if (paddedRepShape[1] % 8 != 0)
return false;
if (swizzleByteSize != 0 && swizzleByteSize != 32 && swizzleByteSize != 64 &&
swizzleByteSize != 128)
return false;
return true;
}

-} // anonymous namespace
Collaborator: Presumably this function should have some unit tests?

Contributor (author): Yeah, I will add a test when investigating Peter's issue
(#4727). It seems there are still some problems.

std::optional<LinearLayout> chooseStMatrixLayoutLeadingOffset(
MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order,
int swizzleByteSize) {
StringAttr kReg = S("register");
StringAttr kLane = S("lane");
StringAttr kWarp = S("warp");
StringAttr kCol = S("dim1");
StringAttr kRow = S("dim0");
StringAttr kOffset = S("offset");

int perPhase;
int maxPhase;
if (swizzleByteSize == 32) {
perPhase = 4;
maxPhase = 2;
} else if (swizzleByteSize == 64) {
perPhase = 2;
maxPhase = 4;
} else if (swizzleByteSize == 128) {
perPhase = 1;
maxPhase = 8;
} else {
llvm::errs() << "Illegal swizzleByteSize: " << swizzleByteSize << "\n";
llvm::report_fatal_error("Illegal swizzleByteSize");
}

// stmatrix only supports 16-bit elements, and each vector has 8 elements
int elemBitWidth = 16;
int vecSize = 8;
int numRows = 16;
int numCols = 8 * swizzleByteSize / elemBitWidth;

// Construct a single stmatrix.x4 (16x16) tile
std::vector<std::vector<int>> basesReg = {{1, 0}, {2, 0}, {4, 0}};
std::vector<std::vector<int>> basesLane;
for (int logRow = 0; logRow < llvm::Log2_32(numRows); logRow++) {
int row = 1 << logRow;
basesLane.push_back({vecSize * ((row / perPhase) % maxPhase), row});
}
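// The last lane basis sends lanes 16-31 to the right half of the 16x16 tile
// (columns 8-15); this matches `lane=16 -> offset=8` in the header comment.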
basesLane.push_back({8, 0});

// Expand the tile's register dimension so that a warp's registers cover one
// full swizzleByteSize-wide "chunk"
for (int logChunk = 0; logChunk < llvm::Log2_32(numCols / 16); logChunk++) {
int chunk = 1 << logChunk;
basesReg.push_back({16 * chunk, 0});
}

// Construct the layout for a single chunk
LinearLayout layout =
LinearLayout({{kReg, basesReg}, {kLane, basesLane}}, {kCol, kRow});

-std::optional<LinearLayout> chooseStMatrixLayoutForRegToRegConversion(
// Expand the `warp` dimension according to warpsPerCTA.
auto mma = cast<NvidiaMmaEncodingAttr>(tensorTy.getEncoding());
layout *=
identityND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1}, {kRow, kCol})
.transposeOuts(llvm::to_vector(layout.getOutDimNames()));

// Expand the `register` dimension so that the column dimension's size matches `n`.
int n = mma.getInstrShape()[1];
int numWarpRows = layout.getOutDimSize(kRow);
layout = (layout.reshapeOuts({{kOffset, layout.getTotalOutDimSize()}}) *
LinearLayout::identity1D(n / numCols, kReg, kOffset))
.reshapeOuts({{kCol, n}, {kRow, numWarpRows}});

auto ret =
combineCtaCgaWithShape(layout, mma.getCTALayout(), tensorTy.getShape());
return ret.transposeOuts(llvm::to_vector(layout.getOutDimNames()))
.reshapeOuts({{kOffset, ret.getTotalOutDimSize()}, {S("iteration"), 1}});
}

std::optional<LinearLayout> chooseStMatrixLayoutNoLeadingOffset(
MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order) {
-if (!canUseStMatrix(tensorTy, repShape, paddedRepShape, order))
-return std::nullopt;

StringAttr kReg = S("register");
StringAttr kLane = S("lane");
StringAttr kWarp = S("warp");
@@ -866,4 +936,23 @@ std::optional<LinearLayout> chooseStMatrixLayoutForRegToRegConversion(
{{S("offset"), ret.getTotalOutDimSize()}, {S("iteration"), 1}});
}

} // anonymous namespace

std::optional<LinearLayout>
chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape,
ArrayRef<unsigned> order, int swizzleByteSize) {
if (!canUseStMatrix(tensorTy, repShape, paddedRepShape, order,
swizzleByteSize))
return std::nullopt;

if (swizzleByteSize == 0)
return chooseStMatrixLayoutNoLeadingOffset(ctx, tensorTy, repShape,
paddedRepShape, order);
else
return chooseStMatrixLayoutLeadingOffset(
ctx, tensorTy, repShape, paddedRepShape, order, swizzleByteSize);
}

} // namespace mlir::triton::gpu
9 changes: 0 additions & 9 deletions third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp
@@ -136,15 +136,6 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
return false;
}

-bool TargetInfo::processReplicaUsingStMatrix(
-RewriterBase &rewriter, Location loc, Value smemBase,
-SmallVector<Value> &vals, RankedTensorType srcTy, Type elemTy,
-ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> origRepShape,
-ArrayRef<unsigned> outOrd, unsigned accumNumReplicates,
-int swizzleByteWidth) const {
-return false;
-}

void TargetInfo::printfImpl(Value formatStrStart, int formatStrByteCount,
ValueRange args, RewriterBase &rewriter,
bool useStdErr) const {
9 changes: 0 additions & 9 deletions third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h
@@ -46,15 +46,6 @@ class TargetInfo : public mlir::triton::TargetInfoBase {
triton::ReduceOp op, unsigned numLaneToReduce,
unsigned interleave) const override;

-bool processReplicaUsingStMatrix(RewriterBase &rewriter, Location loc,
-Value smemBase, SmallVector<Value> &vals,
-RankedTensorType srcTy, Type elemTy,
-ArrayRef<unsigned> paddedRepShape,
-ArrayRef<unsigned> origRepShape,
-ArrayRef<unsigned> outOrd,
-unsigned accumNumReplicates,
-int swizzleByteWidth) const override;

std::string getMulhiFuncName(Type resultElementTy) const override;

void printf(RewriterBase &rewriter, Value formatStrStart,