From dde7751af3c3eb1e4835c501201f3a04a7bd6517 Mon Sep 17 00:00:00 2001
From: Lixun Zhang
Date: Thu, 19 Sep 2024 23:04:15 -0500
Subject: [PATCH 1/3] Always swap operands of mfma and use mfma.transposed layout

Also fix the issue with getOrder for the mfma layout.
---
 lib/Dialect/TritonGPU/IR/Dialect.cpp          |  8 --------
 .../TritonGPU/IR/LinearLayoutConversions.cpp  | 14 +++++++++++++
 .../AccelerateAMDMatmul.cpp                   | 20 +------------------
 3 files changed, 15 insertions(+), 27 deletions(-)

diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index a454fef56674..3b2ea7815353 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -256,14 +256,6 @@ SmallVector<unsigned> getOrder(Attribute layout) {
     auto rank = distributedLayout.getWarpsPerCTA().size();
     SmallVector<unsigned> order(rank);
     std::iota(order.rbegin(), order.rend(), 0);
-    auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(layout);
-    if (!mfmaLayout)
-      return order;
-    // For transposed MFMA layouts, we swap M and N dimensions, which is
-    // always the first two in order; as we can have an optional batch
-    // dimension following them.
-    if (mfmaLayout.getIsTransposed())
-      std::swap(order[0], order[1]);
     return order;
   }
   if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
index 286b1eac519c..f576ce215417 100644
--- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
+++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -507,6 +507,13 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
         {{kRegister, {{0, 1}, {0, 2}, {0, 8}, /*gap*/ {0, 16}}},
          {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, /*gap*/ {0, 4}}}},
         {outDimNames[order[0]], outDimNames[order[1]]});
+    // For the mfma.transposed layout, the element ownership among threads is
+    // "transposed" within each warp.
+    if (getIsTransposed())
+      tileLayout = LinearLayout(
+          {{kRegister, {{1, 0}, {2, 0}, {8, 0}, /*gap*/ {16, 0}}},
+           {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, /*gap*/ {4, 0}}}},
+          {outDimNames[order[0]], outDimNames[order[1]]});
   } else {
     assert(getMDim() == 16);
     // For mfma with 16x16 output, each of the 64 threads holds 4 elements.
@@ -521,6 +528,13 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
         {{kRegister, {{0, 1}, {0, 2}}},
          {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 4}, {0, 8}}}},
         {outDimNames[order[0]], outDimNames[order[1]]});
+    // For the mfma.transposed layout, the element ownership among threads is
+    // "transposed" within each warp.
+    if (getIsTransposed())
+      tileLayout = LinearLayout(
+          {{kRegister, {{1, 0}, {2, 0}}},
+           {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, /*gap*/ {4, 0}, {8, 0}}}},
+          {outDimNames[order[0]], outDimNames[order[1]]});
   }
   if (hasBatchDim) {
     assert(order[2] == 0);
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
index bf976a8138dc..a6b3099c88dc 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
@@ -269,23 +269,6 @@ class BlockedToMFMA : public RewritePattern {
       : RewritePattern(tt::DotOp::getOperationName(), 2, context),
         mfmaVersion(mfmaVersion), enforcedNonKDim(nonKDim), kPack(kPack) {}

-  bool isChainDot(tt::DotOp &dotOp) const {
-    auto filter = [&dotOp](Operation *op) {
-      return op->getParentRegion() == dotOp->getParentRegion();
-    };
-    ForwardSliceOptions fwdOpt;
-    fwdOpt.filter = filter;
-    BackwardSliceOptions bwdOpt;
-    bwdOpt.omitBlockArguments = true;
-    bwdOpt.filter = filter;
-    auto slices = getSlice(dotOp, bwdOpt, fwdOpt);
-    for (Operation *op : slices) {
-      if (isa<tt::DotOp>(op) && (op != dotOp))
-        return true;
-    }
-    return false;
-  }
-
   bool isSecondDot(tt::DotOp &dotOp) const {
     auto filter = [&dotOp](Operation *op) {
       return op->getParentRegion() == dotOp->getParentRegion();
@@ -400,11 +383,10 @@ class BlockedToMFMA : public RewritePattern {
     auto warpsPerTile =
         warpsPerTileMFMA(dotOp, retShape, numWarps, {mDim, nDim});

-    bool isTransposed = isChainDot(dotOp);
     mfmaEnc = ttg::AMDMfmaEncodingAttr::get(
         oldRetType.getContext(),
         /*versionMajor*/ mfmaVersion, /*versionMinor*/ 0, warpsPerTile,
-        /*instrShape*/ mDim, nDim, isTransposed, CTALayout);
+        /*instrShape*/ mDim, nDim, /*isTransposed*/ true, CTALayout);

     Type mfmaAccType;
     if (oldRetType.getElementType().isIntOrIndex())

From 9d60f3133528e08c9d85138c522a36a249797064 Mon Sep 17 00:00:00 2001
From: Lixun Zhang
Date: Fri, 20 Sep 2024 09:59:50 -0500
Subject: [PATCH 2/3] Fix mfmaT unit tests

---
 unittest/Dialect/TritonGPU/DialectTest.cpp        | 16 ++++++++--------
 .../TritonGPU/LinearLayoutConversionsTest.cpp     |  4 ++--
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/unittest/Dialect/TritonGPU/DialectTest.cpp b/unittest/Dialect/TritonGPU/DialectTest.cpp
index 67b1d9e9bce0..8e6230bfcc80 100644
--- a/unittest/Dialect/TritonGPU/DialectTest.cpp
+++ b/unittest/Dialect/TritonGPU/DialectTest.cpp
@@ -558,16 +558,16 @@ TEST_F(AMDMfmaLayoutTest, mfma32) {
   ASSERT_THAT(mfma2d.getWarpOrder(), testing::ElementsAre(1u, 0u));

   auto tmfma2d = createTransposedMFMA(32, 32, {2, 4});
-  ASSERT_THAT(tmfma2d.getThreadOrder(), testing::ElementsAre(0u, 1u));
-  ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(0u, 1u));
+  ASSERT_THAT(tmfma2d.getThreadOrder(), testing::ElementsAre(1u, 0u));
+  ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(1u, 0u));

   auto mfma3d = createMFMA(32, 32, {2, 4, 1});
   ASSERT_THAT(mfma3d.getThreadOrder(), testing::ElementsAre(2u, 1u, 0u));
   ASSERT_THAT(mfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u));

   auto tmfma3d = createTransposedMFMA(32, 32, {2, 4, 1});
-  ASSERT_THAT(tmfma3d.getThreadOrder(), testing::ElementsAre(1u, 2u, 0u));
-  ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(1u, 2u, 0u));
+  ASSERT_THAT(tmfma3d.getThreadOrder(), testing::ElementsAre(2u, 1u, 0u));
+  ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u));
 }

 TEST_F(AMDMfmaLayoutTest, mfma16) {
@@ -576,16 +576,16 @@ TEST_F(AMDMfmaLayoutTest, mfma16) {
   ASSERT_THAT(mfma2d.getWarpOrder(), testing::ElementsAre(1u, 0u));

   auto tmfma2d = createTransposedMFMA(16, 16, {2, 4});
-  ASSERT_THAT(tmfma2d.getThreadOrder(), testing::ElementsAre(0u, 1u));
-  ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(0u, 1u));
+  ASSERT_THAT(tmfma2d.getThreadOrder(), testing::ElementsAre(1u, 0u));
+  ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(1u, 0u));

   auto mfma3d = createMFMA(16, 16, {2, 4, 1});
   ASSERT_THAT(mfma3d.getThreadOrder(), testing::ElementsAre(2u, 1u, 0u));
   ASSERT_THAT(mfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u));

   auto tmfma3d = createTransposedMFMA(16, 16, {2, 4, 1});
-  ASSERT_THAT(tmfma3d.getThreadOrder(), testing::ElementsAre(1u, 2u, 0u));
-  ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(1u, 2u, 0u));
+  ASSERT_THAT(tmfma3d.getThreadOrder(), testing::ElementsAre(2u, 1u, 0u));
+  ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u));
 }

 } // anonymous namespace
diff --git a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
index 0b7a0f78211d..7d918602a705 100644
--- a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
+++ b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
@@ -529,14 +529,14 @@ TEST_F(LinearLayoutConversionsTest, MFMA32_2x4Warps) {
             LinearLayout(
                 {{S("register"), {{0, 1}, {0, 2}, {0, 8}, {0, 16}}},
                  {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 4}}},
-                 {S("warp"), {{32, 0}, {0, 0}, {0, 0}}},
+                 {S("warp"), {{0, 0}, {0, 0}, {32, 0}}},
                  {S("block"), {}}},
                 {S("dim0"), S("dim1")}));
   EXPECT_EQ(toLinearLayout({128, 128}, mfmaT),
             LinearLayout(
                 {{S("register"), {{0, 1}, {0, 2}, {0, 8}, {0, 16}, {64, 0}}},
                  {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 4}}},
-                 {S("warp"), {{32, 0}, {0, 32}, {0, 64}}},
+                 {S("warp"), {{0, 32}, {0, 64}, {32, 0}}},
                  {S("block"), {}}},
                 {S("dim0"), S("dim1")}));
 }

From c525ee0ca9dd21318df4b7442e2a849afba98c2e Mon Sep 17 00:00:00 2001
From: Lixun Zhang
Date: Fri, 20 Sep 2024 14:12:43 -0500
Subject: [PATCH 3/3] Fix order-related issue with ReduceOp

In general, we should use getThreadOrder in most places where getOrder is
called. Note that order and threadOrder can differ, and this is the case for
the mfma.transposed layout.
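
To make the difference concrete (illustration only, not part of the change;
mfmaT here stands for any 2-D transposed MFMA encoding, and the values follow
from the unit-test expectations updated below):

    auto order = triton::gpu::getOrder(mfmaT);             // {1, 0}
    auto threadOrder = triton::gpu::getThreadOrder(mfmaT); // {0, 1}

In the transposed layout consecutive lanes walk dim0, so any lane-id
delinearization, such as the one in ReduceOpToLLVM.cpp, has to use
getThreadOrder rather than getOrder.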
---
 include/triton/Dialect/TritonGPU/IR/Dialect.h     |  2 ++
 lib/Analysis/Utility.cpp                          |  4 ++--
 lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp |  3 ++-
 lib/Dialect/TritonGPU/IR/Dialect.cpp              | 13 ++++++++++++-
 unittest/Dialect/TritonGPU/DialectTest.cpp        |  8 ++++----
 5 files changed, 22 insertions(+), 8 deletions(-)

diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h
index 3b012a630541..bb43b80e5841 100644
--- a/include/triton/Dialect/TritonGPU/IR/Dialect.h
+++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -79,6 +79,8 @@ SmallVector<unsigned> getWarpOrder(Attribute layout);

 SmallVector<unsigned> getOrder(Attribute layout);

+SmallVector<unsigned> getThreadOrder(Attribute layout);
+
 CTALayoutAttr getCTALayout(Attribute layout);

 SmallVector<unsigned> getCTAsPerCGA(Attribute layout);
diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
index 679dcc88d788..a49673b36119 100644
--- a/lib/Analysis/Utility.cpp
+++ b/lib/Analysis/Utility.cpp
@@ -36,7 +36,7 @@ SmallVector<unsigned> getParentOrder(Attribute layout) {
   if (auto sliceEncoding = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
     return getParentOrder(sliceEncoding.getParent());
   }
-  return getOrder(layout);
+  return getThreadOrder(layout);
 }
 } // namespace

@@ -75,7 +75,7 @@ unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
     threadOffset = threadsPerWarp[sliceLayout.getDim()];
   } else {
     auto threadsPerWarp = getThreadsPerWarp(srcLayout);
-    auto order = getOrder(srcLayout);
+    auto order = getThreadOrder(srcLayout);
     for (unsigned i = 0; i < order.size(); i++) {
       if (order[i] == axis)
         break;
diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
index 30de13f6a88c..ac44ab2f9d04 100644
--- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
@@ -12,6 +12,7 @@ using namespace mlir::triton;
 using ::mlir::LLVM::delinearize;
 using ::mlir::LLVM::linearize;
 using ::mlir::triton::gpu::getOrder;
+using ::mlir::triton::gpu::getThreadOrder;
 using ::mlir::triton::gpu::getTotalElemsPerThread;

 namespace {
@@ -282,7 +283,7 @@ struct ReduceOpConversion

     auto threadsPerWarp =
         triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout, srcShape);
-    auto order = getOrder(srcLayout);
+    auto order = getThreadOrder(srcLayout);
     SmallVector<Value> multiDimLaneId =
         delinearize(rewriter, loc, laneId, threadsPerWarp, order);
     Value laneIdAxis = multiDimLaneId[axis];
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index 3b2ea7815353..48f31bdf2a9d 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -282,6 +282,14 @@ SmallVector<unsigned> getOrder(Attribute layout) {
   return {};
 };

+SmallVector<unsigned> getThreadOrder(Attribute layout) {
+  if (auto distributedLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout))
+    return distributedLayout.getThreadOrder();
+  else
+    llvm::report_fatal_error("Unimplemented usage of getThreadOrder");
+  return {};
+};
+
 CTALayoutAttr getCTALayout(Attribute layout) {
   if (auto distributedLayout =
           mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
@@ -1528,7 +1536,10 @@ SmallVector<unsigned> AMDMfmaEncodingAttr::getWarpOrder() const {
   return ::getWarpOrder(*this);
 }
 SmallVector<unsigned> AMDMfmaEncodingAttr::getThreadOrder() const {
-  return ::getOrder(*this);
+  auto order = ::getOrder(*this);
+  if (getIsTransposed())
+    std::swap(order[0], order[1]);
+  return order;
 }
 SmallVector<unsigned> AMDMfmaEncodingAttr::getThreadsPerWarp() const {
   unsigned rows, cols;
diff --git a/unittest/Dialect/TritonGPU/DialectTest.cpp b/unittest/Dialect/TritonGPU/DialectTest.cpp
index 8e6230bfcc80..e3f521f1b3da 100644 --- a/unittest/Dialect/TritonGPU/DialectTest.cpp +++ b/unittest/Dialect/TritonGPU/DialectTest.cpp @@ -558,7 +558,7 @@ TEST_F(AMDMfmaLayoutTest, mfma32) { ASSERT_THAT(mfma2d.getWarpOrder(), testing::ElementsAre(1u, 0u)); auto tmfma2d = createTransposedMFMA(32, 32, {2, 4}); - ASSERT_THAT(tmfma2d.getThreadOrder(), testing::ElementsAre(1u, 0u)); + ASSERT_THAT(tmfma2d.getThreadOrder(), testing::ElementsAre(0u, 1u)); ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(1u, 0u)); auto mfma3d = createMFMA(32, 32, {2, 4, 1}); @@ -566,7 +566,7 @@ TEST_F(AMDMfmaLayoutTest, mfma32) { ASSERT_THAT(mfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u)); auto tmfma3d = createTransposedMFMA(32, 32, {2, 4, 1}); - ASSERT_THAT(tmfma3d.getThreadOrder(), testing::ElementsAre(2u, 1u, 0u)); + ASSERT_THAT(tmfma3d.getThreadOrder(), testing::ElementsAre(1u, 2u, 0u)); ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u)); } @@ -576,7 +576,7 @@ TEST_F(AMDMfmaLayoutTest, mfma16) { ASSERT_THAT(mfma2d.getWarpOrder(), testing::ElementsAre(1u, 0u)); auto tmfma2d = createTransposedMFMA(16, 16, {2, 4}); - ASSERT_THAT(tmfma2d.getThreadOrder(), testing::ElementsAre(1u, 0u)); + ASSERT_THAT(tmfma2d.getThreadOrder(), testing::ElementsAre(0u, 1u)); ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(1u, 0u)); auto mfma3d = createMFMA(16, 16, {2, 4, 1}); @@ -584,7 +584,7 @@ TEST_F(AMDMfmaLayoutTest, mfma16) { ASSERT_THAT(mfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u)); auto tmfma3d = createTransposedMFMA(16, 16, {2, 4, 1}); - ASSERT_THAT(tmfma3d.getThreadOrder(), testing::ElementsAre(2u, 1u, 0u)); + ASSERT_THAT(tmfma3d.getThreadOrder(), testing::ElementsAre(1u, 2u, 0u)); ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u)); }
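
As a sanity check on the expected mfmaT bases above, the following standalone
sketch (illustration only, not part of the patch series) evaluates a linear
layout the same way Triton's LinearLayout does: an input index selects basis
vectors by its set bits, and the output coordinate is the XOR of the selected
bases. Using the transposed mfma32 register/lane bases in (dim0, dim1)
coordinates from LinearLayoutConversionsTest, it shows consecutive lanes
walking dim0 while one lane's registers walk dim1 -- exactly the "transposed"
element ownership, consistent with getThreadOrder() == {0, 1}:

// layout_check.cpp -- illustration only, not part of the patch.
#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

using Coord = std::pair<int, int>; // (dim0, dim1)

// Evaluate the layout: XOR together the bases picked by the set bits of the
// register and lane indices.
static Coord apply(const std::vector<Coord> &regBases,
                   const std::vector<Coord> &laneBases, unsigned reg,
                   unsigned lane) {
  Coord out{0, 0};
  for (std::size_t i = 0; i < regBases.size(); ++i)
    if (reg >> i & 1) {
      out.first ^= regBases[i].first;
      out.second ^= regBases[i].second;
    }
  for (std::size_t i = 0; i < laneBases.size(); ++i)
    if (lane >> i & 1) {
      out.first ^= laneBases[i].first;
      out.second ^= laneBases[i].second;
    }
  return out;
}

int main() {
  // Transposed mfma32 bases in (dim0, dim1) coordinates, as expected by the
  // LinearLayoutConversionsTest hunk above.
  std::vector<Coord> regs = {{0, 1}, {0, 2}, {0, 8}, {0, 16}};
  std::vector<Coord> lanes = {{1, 0}, {2, 0},  {4, 0},
                              {8, 0}, {16, 0}, {0, 4}};
  // Consecutive lanes of register 0 walk dim0.
  for (unsigned lane = 0; lane < 4; ++lane) {
    Coord c = apply(regs, lanes, 0, lane);
    std::printf("reg 0, lane %u -> (%d, %d)\n", lane, c.first, c.second);
  }
  // Consecutive registers of lane 0 walk dim1.
  for (unsigned reg = 0; reg < 4; ++reg) {
    Coord c = apply(regs, lanes, reg, 0);
    std::printf("reg %u, lane 0 -> (%d, %d)\n", reg, c.first, c.second);
  }
  return 0;
}

Compiled with any C++17 compiler, it prints (0,0), (1,0), (2,0), (3,0) for
lanes 0..3 of register 0, and (0,0), (0,1), (0,2), (0,3) for registers 0..3
of lane 0.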