Skip to content

Commit

Permalink
[AMD] Support convert WMMA to linear layout (#4134)
Browse files Browse the repository at this point in the history
This commit adds WMMA to linear layout conversion.
  • Loading branch information
antiagainst committed Jun 13, 2024
1 parent 2329531 commit e8445b1
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 0 deletions.
65 changes: 65 additions & 0 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,68 @@ LinearLayout mfmaToLinearLayout(ArrayRef<int64_t> shape,
return combineCtaCgaWithShape(ctaLayout, mfma.getCTALayout(), shape);
}

LinearLayout wmmaToLinearLayout(ArrayRef<int64_t> shape,
                                AMDWmmaEncodingAttr wmma) {
  MLIRContext *ctx = wmma.getContext();

  int rank = shape.size();
  assert(rank == wmma.getWarpsPerCTA().size());
  bool hasBatchDim = rank == 3;

  // With a leading batch dimension, M/N shift to indices 1 and 2.
  int mIndex = hasBatchDim ? 1 : 0;
  int nIndex = hasBatchDim ? 2 : 1;
  (void)mIndex, (void)nIndex;

  SmallVector<unsigned> instrShape = wmma.getMNKDimPerWMMAInstr();
  unsigned mDim = instrShape[0];
  unsigned nDim = instrShape[1];
  (void)mDim, (void)nDim;

  assert(((shape[mIndex] == 1 || shape[mIndex] >= mDim) &&
          (shape[nIndex] == 1 || shape[nIndex] >= nDim)) &&
         "Unsupported tensor shape for given wmma layout");

  SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
  StringAttr kRegister = S("register");
  StringAttr kLane = S("lane");

  // Bases are listed from fastest varying to slowest varying, so each base
  // vector is a tuple of values mapping to matrix C's (N, M[, B]) indices.
  SmallVector<unsigned> order = triton::gpu::getOrder(wmma);

  // One WMMA instruction produces a 16x16 output tile spread over 32 lanes,
  // each lane holding 8 elements:
  //  - Register (element) dimension: a lane's 8 elements walk matrix C's M
  //    dimension, one element per row with every other row skipped (hence the
  //    first register base being {0, 2}, not {0, 1}).
  //  - Lane (thread) dimension: 16 consecutive lanes cover one full row along
  //    matrix C's N dimension; the next 16 lanes start on the following row.
  // https://github.com/ROCm/amd_matrix_instruction_calculator can print the
  // register and lane layout for AMD matrix instructions.
  LinearLayout tileLayout(
      {{kRegister, {/*gap*/ {0, 2}, {0, 4}, {0, 8}}},
       {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 1}}}},
      {outDimNames[order[0]], outDimNames[order[1]]});

  if (hasBatchDim) {
    assert(order[2] == 0);
    // Grow each base vector by one trailing value to accommodate the batch
    // dimension, which appears last in the output dims.
    tileLayout *= LinearLayout::identity1D(1, kRegister, outDimNames[order[2]]);
    tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[order[2]]);
  }

  // Every warp repeats the same register/lane sub-layout, so multiply with an
  // identity layout over the warp dimension.
  LinearLayout ctaLayout =
      tileLayout *
      identityND(S("warp"), wmma.getWarpsPerCTA(), order, outDimNames);

  return combineCtaCgaWithShape(ctaLayout, wmma.getCTALayout(), shape);
}

std::optional<LinearLayout> sliceToLinearLayout(ArrayRef<int64_t> shape,
SliceEncodingAttr slice) {
MLIRContext *ctx = slice.getContext();
Expand Down Expand Up @@ -686,6 +748,9 @@ toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
if (auto mfma = dyn_cast<AMDMfmaEncodingAttr>(layout)) {
return mfmaToLinearLayout(shape, mfma);
}
if (auto wmma = dyn_cast<AMDWmmaEncodingAttr>(layout)) {
return wmmaToLinearLayout(shape, wmma);
}
if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
if (mma.isAmpere()) {
return ampereMmaToLinearLayout(shape, mma);
Expand Down
82 changes: 82 additions & 0 deletions unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,15 @@ class LinearLayoutConversionsTest : public ::testing::Test {
isTransposed, CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd));
}

// Builds a WMMA encoding with the given warps-per-CTA and a trivial CTA
// layout: one CTA per group, no splits, and the natural [0, 1, ...] order.
AMDWmmaEncodingAttr wmma(ArrayRef<unsigned> warps) {
  size_t rank = warps.size();
  SmallVector<unsigned> ctasPerCGA(rank, 1u);
  SmallVector<unsigned> ctaSplitNum(rank, 1u);
  SmallVector<unsigned> ctaOrder(rank);
  for (unsigned i = 0; i < rank; ++i)
    ctaOrder[i] = i;
  auto ctaLayout = CTALayoutAttr::get(&ctx, ctasPerCGA, ctaSplitNum, ctaOrder);
  return AMDWmmaEncodingAttr::get(&ctx, warps, ctaLayout);
}

// Creates a slice encoding of `parent` with dimension `dim` removed.
SliceEncodingAttr slice(Attribute parent, int dim) {
  return SliceEncodingAttr::get(&ctx, dim, parent);
}
Expand Down Expand Up @@ -635,6 +644,79 @@ TEST_F(LinearLayoutConversionsTest, MFMA32_2x4x1Warps) {
{S("dim0"), S("dim1"), S("dim2")}));
}

TEST_F(LinearLayoutConversionsTest, WMMA_2x4Warps) {
  auto legacy = wmma(/*warps=*/{2, 4});

  // A single 16x16 tile exactly matches one WMMA instruction, so all warp
  // bases are zero: every warp holds a broadcast copy of the same tile.
  EXPECT_EQ(toLinearLayout({16, 16}, legacy),
            LinearLayout({{S("register"), {{2, 0}, {4, 0}, {8, 0}}},
                          {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}}},
                          {S("warp"), {{0, 0}, {0, 0}, {0, 0}}},
                          {S("block"), {}}},
                         {S("dim0"), S("dim1")}));
  // For 32x16, we need 2x1 WMMA instances. We have 2x4 warps, so we are
  // broadcasted along the warp N dimension, distributed along the warp M
  // dimension.
  EXPECT_EQ(toLinearLayout({32, 16}, legacy),
            LinearLayout({{S("register"), {{2, 0}, {4, 0}, {8, 0}}},
                          {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}}},
                          {S("warp"), {{0, 0}, {0, 0}, {16, 0}}},
                          {S("block"), {}}},
                         {S("dim0"), S("dim1")}));
  // For 16x32, we need 1x2 WMMA instances. We have 2x4 warps, so along the warp
  // N dimension, warp 0/2 gets the first distributed instance, warp 1/3 gets
  // the second distributed instance. Along the warp M dimension, all are
  // broadcasted.
  EXPECT_EQ(toLinearLayout({16, 32}, legacy),
            LinearLayout({{S("register"), {{2, 0}, {4, 0}, {8, 0}}},
                          {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}}},
                          {S("warp"), {{0, 16}, {0, 0}, {0, 0}}},
                          {S("block"), {}}},
                         {S("dim0"), S("dim1")}));
  // For 128x128, we need 8x8 WMMA instances. Given that we have 2x4 warps, each
  // warp handles 4x2 instances. So for both the warp M and N dimension, we
  // distribute. The register dimension will handle (8 x 4x2 =) 64 values--those
  // additional base vectors after the intrinsic shape are next power of two
  // values following the warp dimension, given that we are tiling cyclically
  // among warps.
  EXPECT_EQ(toLinearLayout({128, 128}, legacy),
            LinearLayout({{S("register"),
                           {{2, 0}, {4, 0}, {8, 0}, {0, 64}, {32, 0}, {64, 0}}},
                          {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}}},
                          {S("warp"), {{0, 16}, {0, 32}, {16, 0}}},
                          {S("block"), {}}},
                         {S("dim0"), S("dim1")}));
}

TEST_F(LinearLayoutConversionsTest, WMMA_2x4x1Warps) {
  auto legacy = wmma(/*warps=*/{2, 4, 1});

  // Batch size 1: all base vectors gain a leading zero batch coordinate; the
  // 16x16 part matches the 2D single-tile case with all warps broadcasting.
  EXPECT_EQ(
      toLinearLayout({1, 16, 16}, legacy),
      LinearLayout(
          {{S("register"), {{0, 2, 0}, {0, 4, 0}, {0, 8, 0}}},
           {S("lane"), {{0, 0, 1}, {0, 0, 2}, {0, 0, 4}, {0, 0, 8}, {0, 1, 0}}},
           {S("warp"), {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}},
           {S("block"), {}}},
          {S("dim0"), S("dim1"), S("dim2")}));
  // Batch size 2: the two batches are distributed along the warp M dimension,
  // giving the {1, 0, 0} warp base.
  EXPECT_EQ(
      toLinearLayout({2, 16, 16}, legacy),
      LinearLayout(
          {{S("register"), {{0, 2, 0}, {0, 4, 0}, {0, 8, 0}}},
           {S("lane"), {{0, 0, 1}, {0, 0, 2}, {0, 0, 4}, {0, 0, 8}, {0, 1, 0}}},
           {S("warp"), {{0, 0, 0}, {0, 0, 0}, {1, 0, 0}}},
           {S("block"), {}}},
          {S("dim0"), S("dim1"), S("dim2")}));
  // Batch size 8: two batches go to the two warps along M; the remaining
  // factor of 4 falls back onto the register dimension ({2,0,0} and {4,0,0}).
  EXPECT_EQ(
      toLinearLayout({8, 16, 16}, legacy),
      LinearLayout(
          {{S("register"),
            {{0, 2, 0}, {0, 4, 0}, {0, 8, 0}, {2, 0, 0}, {4, 0, 0}}},
           {S("lane"), {{0, 0, 1}, {0, 0, 2}, {0, 0, 4}, {0, 0, 8}, {0, 1, 0}}},
           {S("warp"), {{0, 0, 0}, {0, 0, 0}, {1, 0, 0}}},
           {S("block"), {}}},
          {S("dim0"), S("dim1"), S("dim2")}));
}

TEST_F(LinearLayoutConversionsTest, SliceOfBlocked) {
auto parent = blocked({2, 4}, {4, 2}, {2, 2}, {2, 2}, {2, 2}, {1, 0}, {1, 0});
EXPECT_EQ(toLinearLayout({128}, slice(parent, 0)),
Expand Down

0 comments on commit e8445b1

Please sign in to comment.