diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h
index 3b012a630541..74ea99b58891 100644
--- a/include/triton/Dialect/TritonGPU/IR/Dialect.h
+++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -75,9 +75,32 @@ getThreadsPerWarpWithUniqueData(Attribute layout,
 SmallVector<unsigned>
 getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape);
 
+// Returns the dimensions of the tensor from minor (fast-varying) to
+// major (slow-varying). For blocked, mma, and dotOperand layouts, the
+// elements live in registers, but the order still refers to the memory
+// layout of the original tensor in global memory.
+// For shared layouts, the order refers to which dimension of the original
+// tensor is contiguous in shared memory.
+SmallVector<unsigned> getOrder(Attribute layout);
+
+// Returns the dimensions along which warp IDs are distributed.
+// warpsPerCTA only describes the warp layout in the CTA, e.g.,
+// warpsPerCTA = [2, 4] means there are 2 warps along dim0 and 4 warps
+// along dim1. warpOrder gives the specific order used when distributing
+// warp IDs. E.g., warpOrder = [0, 1] means the warp IDs are laid out as:
+//   [warp0 warp2 warp4 warp6]
+//   [warp1 warp3 warp5 warp7]
+// Note that in most cases getWarpOrder and getOrder return the same result,
+// but this is not guaranteed.
 SmallVector<unsigned> getWarpOrder(Attribute layout);
 
-SmallVector<unsigned> getOrder(Attribute layout);
+// Returns the dimensions along which thread IDs are distributed.
+// Similar to warpOrder, threadOrder is needed to pin down the specific
+// thread distribution within the warp.
+// Note that in most cases getThreadOrder and getOrder return the same
+// result, but this is not guaranteed. One exception is the mfma.transposed
+// layout, for which getOrder returns [1, 0] but getThreadOrder returns [0, 1].
+SmallVector<unsigned> getThreadOrder(Attribute layout);
 
 CTALayoutAttr getCTALayout(Attribute layout);
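To make the warpOrder semantics in the new comments concrete: a linear warp ID is split into per-dimension coordinates, with order[0] varying fastest. A minimal standalone sketch (delinearizeIndex below is a hypothetical helper written for illustration, not the delinearize used in the codebase):

    #include <cstdio>
    #include <vector>

    // Split `id` along `dims`, with order[0] the fastest-varying dimension.
    std::vector<unsigned> delinearizeIndex(unsigned id,
                                           const std::vector<unsigned> &dims,
                                           const std::vector<unsigned> &order) {
      std::vector<unsigned> coords(dims.size());
      for (unsigned dim : order) {
        coords[dim] = id % dims[dim];
        id /= dims[dim];
      }
      return coords;
    }

    int main() {
      std::vector<unsigned> warpsPerCTA = {2, 4};
      std::vector<unsigned> warpOrder = {0, 1}; // dim0 varies fastest
      for (unsigned w = 0; w < 8; ++w) {
        auto c = delinearizeIndex(w, warpsPerCTA, warpOrder);
        std::printf("warp%u -> (%u, %u)\n", w, c[0], c[1]);
      }
    }

With warpsPerCTA = [2, 4] and warpOrder = [0, 1] this prints warp0 -> (0, 0), warp1 -> (1, 0), warp2 -> (0, 1), ..., reproducing the [warp0 warp2 warp4 warp6] / [warp1 warp3 warp5 warp7] figure in the comment above.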
diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
index 56630c731858..b9468aa3e380 100644
--- a/lib/Analysis/Utility.cpp
+++ b/lib/Analysis/Utility.cpp
@@ -36,7 +36,7 @@ SmallVector<unsigned> getParentOrder(Attribute layout) {
   if (auto sliceEncoding = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
     return getParentOrder(sliceEncoding.getParent());
  }
-  return getOrder(layout);
+  return getThreadOrder(layout);
 }
 } // namespace
 
@@ -75,7 +75,7 @@ unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
     threadOffset = threadsPerWarp[sliceLayout.getDim()];
   } else {
     auto threadsPerWarp = getThreadsPerWarp(srcLayout);
-    auto order = getOrder(srcLayout);
+    auto order = getThreadOrder(srcLayout);
     for (unsigned i = 0; i < order.size(); i++) {
       if (order[i] == axis)
         break;
diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
index 16c9991a17b0..414328be50cf 100644
--- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
@@ -9,6 +9,7 @@ using namespace mlir::triton;
 using ::mlir::LLVM::delinearize;
 using ::mlir::LLVM::linearize;
 using ::mlir::triton::gpu::getOrder;
+using ::mlir::triton::gpu::getThreadOrder;
 using ::mlir::triton::gpu::getTotalElemsPerThread;
 
 namespace {
@@ -271,7 +272,7 @@ struct ReduceOpConversion
     auto threadsPerWarp =
         triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout, srcShape);
-    auto order = getOrder(srcLayout);
+    auto order = getThreadOrder(srcLayout);
     SmallVector<Value> multiDimLaneId =
         delinearize(rewriter, loc, laneId, threadsPerWarp, order);
     Value laneIdAxis = multiDimLaneId[axis];
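The reason the reduction lowering above now delinearizes the lane ID with getThreadOrder rather than getOrder: the same lane lands on a different multi-dimensional coordinate, and hence a different offset along the reduction axis, when the order changes. A hedged sketch (the shape [4, 16] is illustrative, not an actual mfma tile; delin models LLVM::delinearize, whose real signature takes a rewriter, a location, and Values):

    #include <cstdio>
    #include <vector>

    // Fastest-first delinearization, as used by the reduction lowering.
    std::vector<unsigned> delin(unsigned id, const std::vector<unsigned> &dims,
                                const std::vector<unsigned> &order) {
      std::vector<unsigned> coords(dims.size());
      for (unsigned d : order) {
        coords[d] = id % dims[d];
        id /= dims[d];
      }
      return coords;
    }

    int main() {
      std::vector<unsigned> threadsPerWarp = {4, 16};
      unsigned lane = 5;
      auto a = delin(lane, threadsPerWarp, {1, 0}); // memory order
      auto b = delin(lane, threadsPerWarp, {0, 1}); // thread order
      std::printf("[1,0]: (%u, %u)\n", a[0], a[1]); // (0, 5)
      std::printf("[0,1]: (%u, %u)\n", b[0], b[1]); // (1, 1)
    }

For a layout like mfma.transposed, where getOrder and getThreadOrder disagree, using the memory order here would compute the wrong laneIdAxis.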
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index a454fef56674..48f31bdf2a9d 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -256,14 +256,6 @@ SmallVector<unsigned> getOrder(Attribute layout) {
     auto rank = distributedLayout.getWarpsPerCTA().size();
     SmallVector<unsigned> order(rank);
     std::iota(order.rbegin(), order.rend(), 0);
-    auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(layout);
-    if (!mfmaLayout)
-      return order;
-    // For transposed MFMA layouts, we swap M and N dimensions, which is
-    // always the first two in order; as we can have an optional batch
-    // dimension following them.
-    if (mfmaLayout.getIsTransposed())
-      std::swap(order[0], order[1]);
     return order;
   }
   if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
@@ -290,6 +282,14 @@ SmallVector<unsigned> getOrder(Attribute layout) {
   return {};
 };
 
+SmallVector<unsigned> getThreadOrder(Attribute layout) {
+  if (auto distributedLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout))
+    return distributedLayout.getThreadOrder();
+  else
+    llvm::report_fatal_error("Unimplemented usage of getThreadOrder");
+  return {};
+};
+
 CTALayoutAttr getCTALayout(Attribute layout) {
   if (auto distributedLayout =
           mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
@@ -1536,7 +1536,10 @@ SmallVector<unsigned> AMDMfmaEncodingAttr::getWarpOrder() const {
   return ::getWarpOrder(*this);
 }
 SmallVector<unsigned> AMDMfmaEncodingAttr::getThreadOrder() const {
-  return ::getOrder(*this);
+  auto order = ::getOrder(*this);
+  if (getIsTransposed())
+    std::swap(order[0], order[1]);
+  return order;
 }
 SmallVector<unsigned> AMDMfmaEncodingAttr::getThreadsPerWarp() const {
   unsigned rows, cols;
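The new AMDMfmaEncodingAttr::getThreadOrder can be modeled in isolation: start from the default minor-to-major order and swap the first two entries when the layout is transposed. A standalone sketch (mfmaThreadOrder is a hypothetical stand-in for the attribute method above):

    #include <cstdio>
    #include <numeric>
    #include <utility>
    #include <vector>

    std::vector<unsigned> mfmaThreadOrder(unsigned rank, bool isTransposed) {
      std::vector<unsigned> order(rank);
      std::iota(order.rbegin(), order.rend(), 0); // [rank-1, ..., 1, 0]
      if (isTransposed)
        std::swap(order[0], order[1]); // swap the two fastest dims
      return order;
    }

    int main() {
      for (unsigned d : mfmaThreadOrder(2, true))
        std::printf("%u ", d); // 0 1   (while getOrder stays [1, 0])
      std::printf("\n");
      for (unsigned d : mfmaThreadOrder(3, true))
        std::printf("%u ", d); // 1 2 0 (the batch dim stays last)
      std::printf("\n");
    }

This matches the tmfma2d/tmfma3d expectations in the unit tests further down.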
diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
index 286b1eac519c..f576ce215417 100644
--- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
+++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -507,6 +507,13 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
         {{kRegister, {{0, 1}, {0, 2}, {0, 8}, /*gap*/ {0, 16}}},
          {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, /*gap*/ {0, 4}}}},
         {outDimNames[order[0]], outDimNames[order[1]]});
+    // For the mfma.transposed layout, the element ownership among threads is
+    // "transposed" within each warp.
+    if (getIsTransposed())
+      tileLayout = LinearLayout(
+          {{kRegister, {{1, 0}, {2, 0}, {8, 0}, /*gap*/ {16, 0}}},
+           {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, /*gap*/ {4, 0}}}},
+          {outDimNames[order[0]], outDimNames[order[1]]});
   } else {
     assert(getMDim() == 16);
     // For mfma with 16x16 output, each of the 64 threads holds 4 elements.
@@ -521,6 +528,13 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
         {{kRegister, {{0, 1}, {0, 2}}},
          {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 4}, {0, 8}}}},
         {outDimNames[order[0]], outDimNames[order[1]]});
+    // For the mfma.transposed layout, the element ownership among threads is
+    // "transposed" within each warp.
+    if (getIsTransposed())
+      tileLayout = LinearLayout(
+          {{kRegister, {{1, 0}, {2, 0}}},
+           {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, /*gap*/ {4, 0}, {8, 0}}}},
+          {outDimNames[order[0]], outDimNames[order[1]]});
   }
   if (hasBatchDim) {
     assert(order[2] == 0);
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
index bf976a8138dc..21b74ecf99fa 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
@@ -269,23 +269,6 @@ class BlockedToMFMA : public RewritePattern {
       : RewritePattern(tt::DotOp::getOperationName(), 2, context),
         mfmaVersion(mfmaVersion), enforcedNonKDim(nonKDim), kPack(kPack) {}
 
-  bool isChainDot(tt::DotOp &dotOp) const {
-    auto filter = [&dotOp](Operation *op) {
-      return op->getParentRegion() == dotOp->getParentRegion();
-    };
-    ForwardSliceOptions fwdOpt;
-    fwdOpt.filter = filter;
-    BackwardSliceOptions bwdOpt;
-    bwdOpt.omitBlockArguments = true;
-    bwdOpt.filter = filter;
-    auto slices = getSlice(dotOp, bwdOpt, fwdOpt);
-    for (Operation *op : slices) {
-      if (isa<tt::DotOp>(op) && (op != dotOp))
-        return true;
-    }
-    return false;
-  }
-
   bool isSecondDot(tt::DotOp &dotOp) const {
     auto filter = [&dotOp](Operation *op) {
       return op->getParentRegion() == dotOp->getParentRegion();
@@ -400,11 +383,12 @@ class BlockedToMFMA : public RewritePattern {
     auto warpsPerTile =
         warpsPerTileMFMA(dotOp, retShape, numWarps, {mDim, nDim});
 
-    bool isTransposed = isChainDot(dotOp);
+    // Always use the transposed mfma layout. This enables larger
+    // vectorization for global store instructions.
     mfmaEnc = ttg::AMDMfmaEncodingAttr::get(
         oldRetType.getContext(),
         /*versionMajor*/ mfmaVersion, /*versionMinor*/ 0, warpsPerTile,
-        /*instrShape*/ mDim, nDim, isTransposed, CTALayout);
+        /*instrShape*/ mDim, nDim, /*isTransposed*/ true, CTALayout);
 
     Type mfmaAccType;
     if (oldRetType.getElementType().isIntOrIndex())
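As a sanity check on the transposed 16x16 tile bases above: a linear layout maps an index to the XOR of the bases selected by its set bits, so the tile can be evaluated by hand. A minimal sketch (standalone; it does not use the real LinearLayout class, and coordinates are written in the (order[0], order[1]) output basis used by toLinearLayout):

    #include <cstdio>
    #include <vector>

    struct Vec2 { int a, b; };

    // XOR-combine the bases selected by the set bits of `id`.
    Vec2 apply(unsigned id, const std::vector<Vec2> &bases) {
      Vec2 out{0, 0};
      for (unsigned i = 0; i < bases.size(); ++i)
        if (id & (1u << i)) { out.a ^= bases[i].a; out.b ^= bases[i].b; }
      return out;
    }

    int main() {
      // Transposed mfma16 tile bases from the patch above.
      std::vector<Vec2> reg = {{1, 0}, {2, 0}};
      std::vector<Vec2> lane = {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {4, 0}, {8, 0}};
      unsigned laneId = 37; // bits 0, 2, 5 -> {0,1} ^ {0,4} ^ {8,0} = (8, 5)
      Vec2 l = apply(laneId, lane);
      for (unsigned r = 0; r < 4; ++r) {
        Vec2 e = apply(r, reg);
        std::printf("lane %u reg %u -> (%d, %d)\n", laneId, r,
                    l.a ^ e.a, l.b ^ e.b); // (8,5) (9,5) (10,5) (11,5)
      }
    }

The four registers of a lane are consecutive along order[0], which is dim1 now that getOrder always returns [1, 0] for mfma; contiguous per-thread elements along the fastest-varying dimension are what permit the wider global stores mentioned in AccelerateAMDMatmul.cpp.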
diff --git a/unittest/Dialect/TritonGPU/DialectTest.cpp b/unittest/Dialect/TritonGPU/DialectTest.cpp
index 67b1d9e9bce0..e3f521f1b3da 100644
--- a/unittest/Dialect/TritonGPU/DialectTest.cpp
+++ b/unittest/Dialect/TritonGPU/DialectTest.cpp
@@ -559,7 +559,7 @@ TEST_F(AMDMfmaLayoutTest, mfma32) {
 
   auto tmfma2d = createTransposedMFMA(32, 32, {2, 4});
   ASSERT_THAT(tmfma2d.getThreadOrder(), testing::ElementsAre(0u, 1u));
-  ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(0u, 1u));
+  ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(1u, 0u));
 
   auto mfma3d = createMFMA(32, 32, {2, 4, 1});
   ASSERT_THAT(mfma3d.getThreadOrder(), testing::ElementsAre(2u, 1u, 0u));
@@ -567,7 +567,7 @@ TEST_F(AMDMfmaLayoutTest, mfma32) {
 
   auto tmfma3d = createTransposedMFMA(32, 32, {2, 4, 1});
   ASSERT_THAT(tmfma3d.getThreadOrder(), testing::ElementsAre(1u, 2u, 0u));
-  ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(1u, 2u, 0u));
+  ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u));
 }
 
 TEST_F(AMDMfmaLayoutTest, mfma16) {
@@ -577,7 +577,7 @@ TEST_F(AMDMfmaLayoutTest, mfma16) {
 
   auto tmfma2d = createTransposedMFMA(16, 16, {2, 4});
   ASSERT_THAT(tmfma2d.getThreadOrder(), testing::ElementsAre(0u, 1u));
-  ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(0u, 1u));
+  ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(1u, 0u));
 
   auto mfma3d = createMFMA(16, 16, {2, 4, 1});
   ASSERT_THAT(mfma3d.getThreadOrder(), testing::ElementsAre(2u, 1u, 0u));
@@ -585,7 +585,7 @@ TEST_F(AMDMfmaLayoutTest, mfma16) {
 
   auto tmfma3d = createTransposedMFMA(16, 16, {2, 4, 1});
   ASSERT_THAT(tmfma3d.getThreadOrder(), testing::ElementsAre(1u, 2u, 0u));
-  ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(1u, 2u, 0u));
+  ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u));
 }
 
 } // anonymous namespace
diff --git a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
index 0b7a0f78211d..7d918602a705 100644
--- a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
+++ b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
@@ -529,14 +529,14 @@ TEST_F(LinearLayoutConversionsTest, MFMA32_2x4Warps) {
             LinearLayout(
                 {{S("register"), {{0, 1}, {0, 2}, {0, 8}, {0, 16}}},
                  {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 4}}},
-                 {S("warp"), {{32, 0}, {0, 0}, {0, 0}}},
+                 {S("warp"), {{0, 0}, {0, 0}, {32, 0}}},
                  {S("block"), {}}},
                 {S("dim0"), S("dim1")}));
   EXPECT_EQ(toLinearLayout({128, 128}, mfmaT),
             LinearLayout(
                 {{S("register"), {{0, 1}, {0, 2}, {0, 8}, {0, 16}, {64, 0}}},
                  {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 4}}},
-                 {S("warp"), {{32, 0}, {0, 32}, {0, 64}}},
+                 {S("warp"), {{0, 32}, {0, 64}, {32, 0}}},
                  {S("block"), {}}},
                 {S("dim0"), S("dim1")}));
 }
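The updated warp bases in the last test can be read the same way: at shape {128, 128}, the transposed layout's warp bases {{0, 32}, {0, 64}, {32, 0}} place warp bits 0 and 1 along dim1 and bit 2 along dim0, so warp IDs vary fastest along dim1, consistent with the getWarpOrder() == [1, 0] expectations in DialectTest.cpp. A small sketch (same XOR evaluation as above, assumed semantics of the LinearLayout bases; it does not use the real LinearLayout class):

    #include <cstdio>
    #include <utility>
    #include <vector>

    // Evaluate the warp input dimension: XOR of bases at set bits of `w`.
    std::pair<int, int> warpCoord(unsigned w,
                                  const std::vector<std::pair<int, int>> &bases) {
      std::pair<int, int> out{0, 0};
      for (unsigned b = 0; b < bases.size(); ++b)
        if (w & (1u << b)) {
          out.first ^= bases[b].first;
          out.second ^= bases[b].second;
        }
      return out;
    }

    int main() {
      // New warp bases for the transposed mfma32 layout at shape {128, 128}.
      std::vector<std::pair<int, int>> bases = {{0, 32}, {0, 64}, {32, 0}};
      for (unsigned w = 0; w < 8; ++w) {
        auto [d0, d1] = warpCoord(w, bases);
        std::printf("warp%u -> (%d, %d)\n", w, d0, d1);
      }
    }

This prints warp1 -> (0, 32), warp2 -> (0, 64), warp3 -> (0, 96), and only warp4 steps along dim0 with (32, 0), i.e. dim1 varies fastest.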