Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AMD] Always swap operands of mfma and use mfma.transposed layout #4767

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ SmallVector<unsigned> getWarpOrder(Attribute layout);

SmallVector<unsigned> getOrder(Attribute layout);

SmallVector<unsigned> getThreadOrder(Attribute layout);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add some doc to both this one and the above getOrder so that it's easier to know the difference?


CTALayoutAttr getCTALayout(Attribute layout);

SmallVector<unsigned> getCTAsPerCGA(Attribute layout);
Expand Down
4 changes: 2 additions & 2 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ SmallVector<unsigned> getParentOrder(Attribute layout) {
if (auto sliceEncoding = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
return getParentOrder(sliceEncoding.getParent());
}
return getOrder(layout);
return getThreadOrder(layout);
}

} // namespace
Expand Down Expand Up @@ -75,7 +75,7 @@ unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
threadOffset = threadsPerWarp[sliceLayout.getDim()];
} else {
auto threadsPerWarp = getThreadsPerWarp(srcLayout);
auto order = getOrder(srcLayout);
auto order = getThreadOrder(srcLayout);
for (unsigned i = 0; i < order.size(); i++) {
if (order[i] == axis)
break;
Expand Down
3 changes: 2 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ using namespace mlir::triton;
using ::mlir::LLVM::delinearize;
using ::mlir::LLVM::linearize;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getThreadOrder;
using ::mlir::triton::gpu::getTotalElemsPerThread;

namespace {
Expand Down Expand Up @@ -282,7 +283,7 @@ struct ReduceOpConversion

auto threadsPerWarp =
triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout, srcShape);
auto order = getOrder(srcLayout);
auto order = getThreadOrder(srcLayout);
SmallVector<Value> multiDimLaneId =
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
Value laneIdAxis = multiDimLaneId[axis];
Expand Down
21 changes: 12 additions & 9 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,14 +256,6 @@ SmallVector<unsigned> getOrder(Attribute layout) {
auto rank = distributedLayout.getWarpsPerCTA().size();
SmallVector<unsigned> order(rank);
std::iota(order.rbegin(), order.rend(), 0);
auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(layout);
if (!mfmaLayout)
return order;
// For transposed MFMA layouts, we swap M and N dimensions, which is
// always the first two in order; as we can have an optional batch
// dimension following them.
if (mfmaLayout.getIsTransposed())
std::swap(order[0], order[1]);
return order;
}
if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
Expand All @@ -290,6 +282,14 @@ SmallVector<unsigned> getOrder(Attribute layout) {
return {};
};

SmallVector<unsigned> getThreadOrder(Attribute layout) {
  // Thread order is only meaningful for distributed layouts; anything else
  // is a programming error, mirroring the handling in getOrder above.
  auto distributedLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout);
  if (!distributedLayout)
    llvm::report_fatal_error("Unimplemented usage of getThreadOrder");
  return distributedLayout.getThreadOrder();
};

CTALayoutAttr getCTALayout(Attribute layout) {
if (auto distributedLayout =
mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
Expand Down Expand Up @@ -1536,7 +1536,10 @@ SmallVector<unsigned> AMDMfmaEncodingAttr::getWarpOrder() const {
return ::getWarpOrder(*this);
}
SmallVector<unsigned> AMDMfmaEncodingAttr::getThreadOrder() const {
  auto order = ::getOrder(*this);
  // For transposed MFMA layouts, swap the M and N dimensions, which are
  // always the first two entries of the order; an optional batch dimension
  // may follow them.
  if (getIsTransposed())
    std::swap(order[0], order[1]);
  return order;
}
SmallVector<unsigned> AMDMfmaEncodingAttr::getThreadsPerWarp() const {
unsigned rows, cols;
Expand Down
14 changes: 14 additions & 0 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,13 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
{{kRegister, {{0, 1}, {0, 2}, {0, 8}, /*gap*/ {0, 16}}},
{kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, /*gap*/ {0, 4}}}},
{outDimNames[order[0]], outDimNames[order[1]]});
// For the mfma.transposed layout, the element ownership among threads is
// "transposed" within each warp.
if (getIsTransposed())
tileLayout = LinearLayout(
{{kRegister, {{1, 0}, {2, 0}, {8, 0}, /*gap*/ {16, 0}}},
{kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, /*gap*/ {4, 0}}}},
{outDimNames[order[0]], outDimNames[order[1]]});
} else {
assert(getMDim() == 16);
// For mfma with 16x16 output, each of the 64 threads holds 4 elements.
Expand All @@ -521,6 +528,13 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
{{kRegister, {{0, 1}, {0, 2}}},
{kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 4}, {0, 8}}}},
{outDimNames[order[0]], outDimNames[order[1]]});
// For the mfma.transposed layout, the element ownership among threads is
// "transposed" within each warp.
if (getIsTransposed())
tileLayout = LinearLayout(
{{kRegister, {{1, 0}, {2, 0}}},
{kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, /*gap*/ {4, 0}, {8, 0}}}},
{outDimNames[order[0]], outDimNames[order[1]]});
}
if (hasBatchDim) {
assert(order[2] == 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -269,23 +269,6 @@ class BlockedToMFMA : public RewritePattern {
: RewritePattern(tt::DotOp::getOperationName(), 2, context),
mfmaVersion(mfmaVersion), enforcedNonKDim(nonKDim), kPack(kPack) {}

// Returns true if `dotOp` participates in a chain of dots, i.e. another
// tt::DotOp appears in the combined backward/forward slice of this op,
// restricted to operations in the same region.
bool isChainDot(tt::DotOp &dotOp) const {
// Only consider ops living in the same region as `dotOp`; ops in other
// regions are filtered out of the slice computation.
auto filter = [&dotOp](Operation *op) {
return op->getParentRegion() == dotOp->getParentRegion();
};
ForwardSliceOptions fwdOpt;
fwdOpt.filter = filter;
BackwardSliceOptions bwdOpt;
// Do not traverse through block arguments when walking backwards.
bwdOpt.omitBlockArguments = true;
bwdOpt.filter = filter;
// Union of the backward and forward slices around `dotOp`.
auto slices = getSlice(dotOp, bwdOpt, fwdOpt);
for (Operation *op : slices) {
// Any other dot in the slice makes this a chained dot.
if (isa<tt::DotOp>(op) && (op != dotOp))
return true;
}
return false;
}

bool isSecondDot(tt::DotOp &dotOp) const {
auto filter = [&dotOp](Operation *op) {
return op->getParentRegion() == dotOp->getParentRegion();
Expand Down Expand Up @@ -400,11 +383,10 @@ class BlockedToMFMA : public RewritePattern {
auto warpsPerTile =
warpsPerTileMFMA(dotOp, retShape, numWarps, {mDim, nDim});

bool isTransposed = isChainDot(dotOp);
mfmaEnc = ttg::AMDMfmaEncodingAttr::get(
oldRetType.getContext(),
/*versionMajor*/ mfmaVersion, /*versionMinor*/ 0, warpsPerTile,
/*instrShape*/ mDim, nDim, isTransposed, CTALayout);
/*instrShape*/ mDim, nDim, /*isTransposed*/ true, CTALayout);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add some comments in the above of why we want to always set it as transposed? It makes it easier to follow for others reading the code.


Type mfmaAccType;
if (oldRetType.getElementType().isIntOrIndex())
Expand Down
8 changes: 4 additions & 4 deletions unittest/Dialect/TritonGPU/DialectTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -559,15 +559,15 @@ TEST_F(AMDMfmaLayoutTest, mfma32) {

auto tmfma2d = createTransposedMFMA(32, 32, {2, 4});
ASSERT_THAT(tmfma2d.getThreadOrder(), testing::ElementsAre(0u, 1u));
ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(0u, 1u));
ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(1u, 0u));

auto mfma3d = createMFMA(32, 32, {2, 4, 1});
ASSERT_THAT(mfma3d.getThreadOrder(), testing::ElementsAre(2u, 1u, 0u));
ASSERT_THAT(mfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u));

auto tmfma3d = createTransposedMFMA(32, 32, {2, 4, 1});
ASSERT_THAT(tmfma3d.getThreadOrder(), testing::ElementsAre(1u, 2u, 0u));
ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(1u, 2u, 0u));
ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u));
}

TEST_F(AMDMfmaLayoutTest, mfma16) {
Expand All @@ -577,15 +577,15 @@ TEST_F(AMDMfmaLayoutTest, mfma16) {

auto tmfma2d = createTransposedMFMA(16, 16, {2, 4});
ASSERT_THAT(tmfma2d.getThreadOrder(), testing::ElementsAre(0u, 1u));
ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(0u, 1u));
ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(1u, 0u));

auto mfma3d = createMFMA(16, 16, {2, 4, 1});
ASSERT_THAT(mfma3d.getThreadOrder(), testing::ElementsAre(2u, 1u, 0u));
ASSERT_THAT(mfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u));

auto tmfma3d = createTransposedMFMA(16, 16, {2, 4, 1});
ASSERT_THAT(tmfma3d.getThreadOrder(), testing::ElementsAre(1u, 2u, 0u));
ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(1u, 2u, 0u));
ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u));
}

} // anonymous namespace
Expand Down
4 changes: 2 additions & 2 deletions unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -529,14 +529,14 @@ TEST_F(LinearLayoutConversionsTest, MFMA32_2x4Warps) {
LinearLayout(
{{S("register"), {{0, 1}, {0, 2}, {0, 8}, {0, 16}}},
{S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 4}}},
{S("warp"), {{32, 0}, {0, 0}, {0, 0}}},
{S("warp"), {{0, 0}, {0, 0}, {32, 0}}},
{S("block"), {}}},
{S("dim0"), S("dim1")}));
EXPECT_EQ(toLinearLayout({128, 128}, mfmaT),
LinearLayout(
{{S("register"), {{0, 1}, {0, 2}, {0, 8}, {0, 16}, {64, 0}}},
{S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 4}}},
{S("warp"), {{32, 0}, {0, 32}, {0, 64}}},
{S("warp"), {{0, 32}, {0, 64}, {32, 0}}},
{S("block"), {}}},
{S("dim0"), S("dim1")}));
}
Expand Down
Loading