diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 70fc7012993c..2836b779bbb9 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -489,6 +489,68 @@ LinearLayout mfmaToLinearLayout(ArrayRef shape, return combineCtaCgaWithShape(ctaLayout, mfma.getCTALayout(), shape); } +LinearLayout wmmaToLinearLayout(ArrayRef shape, + AMDWmmaEncodingAttr wmma) { + int rank = shape.size(); + assert(rank == wmma.getWarpsPerCTA().size()); + + bool hasBatchDim = rank == 3; + int mIndex = 0 + hasBatchDim; + int nIndex = 1 + hasBatchDim; + (void)mIndex, (void)nIndex; + + SmallVector mnkDim = wmma.getMNKDimPerWMMAInstr(); + unsigned mDim = mnkDim[0], nDim = mnkDim[1]; + (void)mDim, (void)nDim; + + assert(((shape[mIndex] == 1 || shape[mIndex] >= mDim) && + (shape[nIndex] == 1 || shape[nIndex] >= nDim)) && + "Unsupported tensor shape for given wmma layout"); + + MLIRContext *ctx = wmma.getContext(); + SmallVector outDimNames = standardOutDimNames(ctx, rank); + + StringAttr kRegister = S("register"); + StringAttr kLane = S("lane"); + + // https://github.com/ROCm/amd_matrix_instruction_calculator can print the + // register and lane layout for mfma instructions. + + // We use the order from fastest varying to slowest varying. So each base + // vector is a tuple of values mapping to matrix C's (N, M[, B]) indices. + SmallVector order = triton::gpu::getOrder(wmma); + + // For wmma with 16x16 output, each of the 32 threads holds 8 elements. + // + // For the register (i.e., element) dimension, these 8 elements are along + // the matrix C's M dimension, with 1 consecutive elements spanning 1 row + // and then the next 1 row being a gap. + // + // For the lane (i.e., thread) dimension, these threads are along the + // matrix C's N dimension, with 16 consecutive threads covering a whole + // row and the next 16 threads start at the next row. + LinearLayout tileLayout( + {{kRegister, {/*gap*/ {0, 2}, {0, 4}, {0, 8}}}, + {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 1}}}}, + {outDimNames[order[0]], outDimNames[order[1]]}); + + if (hasBatchDim) { + assert(order[2] == 0); + // Extend the base vector with one value to accomodate for the batch + // dimension, which appears at the last. + tileLayout *= LinearLayout::identity1D(1, kRegister, outDimNames[order[2]]); + tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[order[2]]); + } + + // And each warp takes the same register and lane sub-layout. So mulitply with + // an identity layout for the warp. + LinearLayout warpLayout = + identityND(S("warp"), wmma.getWarpsPerCTA(), order, outDimNames); + LinearLayout ctaLayout = tileLayout * warpLayout; + + return combineCtaCgaWithShape(ctaLayout, wmma.getCTALayout(), shape); +} + std::optional sliceToLinearLayout(ArrayRef shape, SliceEncodingAttr slice) { MLIRContext *ctx = slice.getContext(); @@ -686,6 +748,9 @@ toLinearLayout(ArrayRef shape, Attribute layout, if (auto mfma = dyn_cast(layout)) { return mfmaToLinearLayout(shape, mfma); } + if (auto wmma = dyn_cast(layout)) { + return wmmaToLinearLayout(shape, wmma); + } if (auto mma = dyn_cast(layout)) { if (mma.isAmpere()) { return ampereMmaToLinearLayout(shape, mma); diff --git a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp index 4fb62c5485d0..ecd16095acae 100644 --- a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp +++ b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp @@ -50,6 +50,15 @@ class LinearLayoutConversionsTest : public ::testing::Test { isTransposed, CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd)); } + AMDWmmaEncodingAttr wmma(ArrayRef warps) { + SmallVector cpg(warps.size(), 1u); + SmallVector cSplit(warps.size(), 1u); + SmallVector cOrd(warps.size()); + std::iota(cOrd.begin(), cOrd.end(), 0); + return AMDWmmaEncodingAttr::get( + &ctx, warps, CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd)); + } + SliceEncodingAttr slice(Attribute parent, int dim) { return SliceEncodingAttr::get(&ctx, dim, parent); } @@ -635,6 +644,79 @@ TEST_F(LinearLayoutConversionsTest, MFMA32_2x4x1Warps) { {S("dim0"), S("dim1"), S("dim2")})); } +TEST_F(LinearLayoutConversionsTest, WMMA_2x4Warps) { + auto legacy = wmma(/*warps=*/{2, 4}); + + EXPECT_EQ(toLinearLayout({16, 16}, legacy), + LinearLayout({{S("register"), {{2, 0}, {4, 0}, {8, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + // For 32x16, we need 2x1 WMMA instances. We have 2x4 warps, so we are + // broadcasted along the warp N dimension, distributed along the warp M + // dimension. + EXPECT_EQ(toLinearLayout({32, 16}, legacy), + LinearLayout({{S("register"), {{2, 0}, {4, 0}, {8, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + // For 16x32, we need 1x2 WMMA instances. We have 2x4 warps, so along the warp + // N dimension, warp 0/2 gets the first distributed instance, warp 1/3 gets + // the second distributed instance. Along the warp M dimension, all are + // broadcasted. + EXPECT_EQ(toLinearLayout({16, 32}, legacy), + LinearLayout({{S("register"), {{2, 0}, {4, 0}, {8, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}}}, + {S("warp"), {{0, 16}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + // For 128x128, we need 8x8 WMMA instances. Given that we have 2x4 warps, each + // warp handles 4x2 instances. So for both the warp M and N dimension, we + // distribute. The register dimension will handle (8 x 4x2 =) 64 values--those + // additonal base vectors after the intrinsic shape are next power of two + // values following the warp dimension, given that we are tiling cyclically + // among warps. + EXPECT_EQ(toLinearLayout({128, 128}, legacy), + LinearLayout({{S("register"), + {{2, 0}, {4, 0}, {8, 0}, {0, 64}, {32, 0}, {64, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}}}, + {S("warp"), {{0, 16}, {0, 32}, {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, WMMA_2x4x1Warps) { + auto legacy = wmma(/*warps=*/{2, 4, 1}); + + EXPECT_EQ( + toLinearLayout({1, 16, 16}, legacy), + LinearLayout( + {{S("register"), {{0, 2, 0}, {0, 4, 0}, {0, 8, 0}}}, + {S("lane"), {{0, 0, 1}, {0, 0, 2}, {0, 0, 4}, {0, 0, 8}, {0, 1, 0}}}, + {S("warp"), {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); + EXPECT_EQ( + toLinearLayout({2, 16, 16}, legacy), + LinearLayout( + {{S("register"), {{0, 2, 0}, {0, 4, 0}, {0, 8, 0}}}, + {S("lane"), {{0, 0, 1}, {0, 0, 2}, {0, 0, 4}, {0, 0, 8}, {0, 1, 0}}}, + {S("warp"), {{0, 0, 0}, {0, 0, 0}, {1, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); + EXPECT_EQ( + toLinearLayout({8, 16, 16}, legacy), + LinearLayout( + {{S("register"), + {{0, 2, 0}, {0, 4, 0}, {0, 8, 0}, {2, 0, 0}, {4, 0, 0}}}, + {S("lane"), {{0, 0, 1}, {0, 0, 2}, {0, 0, 4}, {0, 0, 8}, {0, 1, 0}}}, + {S("warp"), {{0, 0, 0}, {0, 0, 0}, {1, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); +} + TEST_F(LinearLayoutConversionsTest, SliceOfBlocked) { auto parent = blocked({2, 4}, {4, 2}, {2, 2}, {2, 2}, {2, 2}, {1, 0}, {1, 0}); EXPECT_EQ(toLinearLayout({128}, slice(parent, 0)),