fix order related issue with reduceOp
In general, we should use getThreadOrder in most places where getOrder is
currently called. Note that order and threadOrder can differ, and they do
for the mfma.transposed layout.
zhanglx13 committed Sep 20, 2024
1 parent 9d60f31 commit c525ee0
Showing 5 changed files with 22 additions and 8 deletions.
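
To make the order/threadOrder distinction concrete, here is a minimal standalone sketch (not part of the commit; plain C++ with no MLIR dependency, and the helper names are invented for illustration). It mirrors the swap this commit adds to AMDMfmaEncodingAttr::getThreadOrder and reproduces the orders asserted in the updated unit tests.

#include <algorithm>
#include <cstdio>
#include <vector>

// Default order for an MFMA layout of the given rank: fastest-varying
// dimension first, i.e. {rank-1, ..., 1, 0}.
std::vector<unsigned> defaultOrder(unsigned rank) {
  std::vector<unsigned> order(rank);
  for (unsigned i = 0; i < rank; ++i)
    order[i] = rank - 1 - i;
  return order;
}

// Mirrors the logic added to AMDMfmaEncodingAttr::getThreadOrder below:
// the transposed variant swaps the two fastest-varying dimensions for
// threads, while the plain order stays the same.
std::vector<unsigned> mfmaThreadOrder(unsigned rank, bool isTransposed) {
  auto order = defaultOrder(rank);
  if (isTransposed)
    std::swap(order[0], order[1]);
  return order;
}

int main() {
  auto print = [](const std::vector<unsigned> &order) {
    for (unsigned d : order)
      std::printf("%u ", d);
    std::printf("\n");
  };
  print(defaultOrder(2));           // 1 0  (order and non-transposed thread order)
  print(mfmaThreadOrder(2, true));  // 0 1  (matches the updated 2-D tests)
  print(mfmaThreadOrder(3, true));  // 1 2 0  (matches the updated 3-D tests)
}
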
2 changes: 2 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -79,6 +79,8 @@ SmallVector<unsigned> getWarpOrder(Attribute layout);

SmallVector<unsigned> getOrder(Attribute layout);

+SmallVector<unsigned> getThreadOrder(Attribute layout);
+
CTALayoutAttr getCTALayout(Attribute layout);

SmallVector<unsigned> getCTAsPerCGA(Attribute layout);
4 changes: 2 additions & 2 deletions lib/Analysis/Utility.cpp
@@ -36,7 +36,7 @@ SmallVector<unsigned> getParentOrder(Attribute layout) {
  if (auto sliceEncoding = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
    return getParentOrder(sliceEncoding.getParent());
  }
-  return getOrder(layout);
+  return getThreadOrder(layout);
}

} // namespace
@@ -75,7 +75,7 @@ unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
    threadOffset = threadsPerWarp[sliceLayout.getDim()];
  } else {
    auto threadsPerWarp = getThreadsPerWarp(srcLayout);
-    auto order = getOrder(srcLayout);
+    auto order = getThreadOrder(srcLayout);
    for (unsigned i = 0; i < order.size(); i++) {
      if (order[i] == axis)
        break;
3 changes: 2 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
@@ -12,6 +12,7 @@ using namespace mlir::triton;
using ::mlir::LLVM::delinearize;
using ::mlir::LLVM::linearize;
using ::mlir::triton::gpu::getOrder;
+using ::mlir::triton::gpu::getThreadOrder;
using ::mlir::triton::gpu::getTotalElemsPerThread;

namespace {
@@ -282,7 +283,7 @@ struct ReduceOpConversion

    auto threadsPerWarp =
        triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout, srcShape);
-    auto order = getOrder(srcLayout);
+    auto order = getThreadOrder(srcLayout);
    SmallVector<Value> multiDimLaneId =
        delinearize(rewriter, loc, laneId, threadsPerWarp, order);
    Value laneIdAxis = multiDimLaneId[axis];
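
The hunk above passes the order into delinearize to split the linear lane id into per-dimension coordinates. Below is a minimal sketch (plain integers rather than MLIR Values, with a made-up threadsPerWarp shape and lane id) of why the choice of order matters there; it assumes order[0] names the fastest-varying dimension, which is how orders returned by getThreadOrder are conventionally read.

#include <cstdio>
#include <vector>

// Simplified stand-in for the delinearize helper: splits a linear id into
// per-dimension coordinates, taking order[0] as the fastest-varying dimension.
std::vector<unsigned> delinearizeLane(unsigned linear,
                                      const std::vector<unsigned> &shape,
                                      const std::vector<unsigned> &order) {
  std::vector<unsigned> multiDim(shape.size());
  for (unsigned dim : order) {
    multiDim[dim] = linear % shape[dim];
    linear /= shape[dim];
  }
  return multiDim;
}

int main() {
  // Made-up 64-lane warp shape: 32 threads along dim 0, 2 along dim 1.
  std::vector<unsigned> threadsPerWarp = {32, 2};
  unsigned laneId = 33;
  auto a = delinearizeLane(laneId, threadsPerWarp, {0, 1});  // thread order
  auto b = delinearizeLane(laneId, threadsPerWarp, {1, 0});  // tensor order
  std::printf("with {0, 1}: (%u, %u)\n", a[0], a[1]);  // (1, 1)
  std::printf("with {1, 0}: (%u, %u)\n", b[0], b[1]);  // (16, 1)
  // The coordinates disagree, so indexing multiDimLaneId[axis] with the wrong
  // order would hand the reduction the wrong lane coordinate.
}
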
13 changes: 12 additions & 1 deletion lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -282,6 +282,14 @@ SmallVector<unsigned> getOrder(Attribute layout) {
  return {};
};

+SmallVector<unsigned> getThreadOrder(Attribute layout) {
+  if (auto distributedLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout))
+    return distributedLayout.getThreadOrder();
+  else
+    llvm::report_fatal_error("Unimplemented usage of getThreadOrder");
+  return {};
+};
+
CTALayoutAttr getCTALayout(Attribute layout) {
  if (auto distributedLayout =
          mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
@@ -1528,7 +1536,10 @@ SmallVector<unsigned> AMDMfmaEncodingAttr::getWarpOrder() const {
  return ::getWarpOrder(*this);
}
SmallVector<unsigned> AMDMfmaEncodingAttr::getThreadOrder() const {
-  return ::getOrder(*this);
+  auto order = ::getOrder(*this);
+  if (getIsTransposed())
+    std::swap(order[0], order[1]);
+  return order;
}
SmallVector<unsigned> AMDMfmaEncodingAttr::getThreadsPerWarp() const {
  unsigned rows, cols;
8 changes: 4 additions & 4 deletions unittest/Dialect/TritonGPU/DialectTest.cpp
@@ -558,15 +558,15 @@ TEST_F(AMDMfmaLayoutTest, mfma32) {
  ASSERT_THAT(mfma2d.getWarpOrder(), testing::ElementsAre(1u, 0u));

  auto tmfma2d = createTransposedMFMA(32, 32, {2, 4});
-  ASSERT_THAT(tmfma2d.getThreadOrder(), testing::ElementsAre(1u, 0u));
+  ASSERT_THAT(tmfma2d.getThreadOrder(), testing::ElementsAre(0u, 1u));
  ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(1u, 0u));

  auto mfma3d = createMFMA(32, 32, {2, 4, 1});
  ASSERT_THAT(mfma3d.getThreadOrder(), testing::ElementsAre(2u, 1u, 0u));
  ASSERT_THAT(mfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u));

  auto tmfma3d = createTransposedMFMA(32, 32, {2, 4, 1});
-  ASSERT_THAT(tmfma3d.getThreadOrder(), testing::ElementsAre(2u, 1u, 0u));
+  ASSERT_THAT(tmfma3d.getThreadOrder(), testing::ElementsAre(1u, 2u, 0u));
  ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u));
}

@@ -576,15 +576,15 @@ TEST_F(AMDMfmaLayoutTest, mfma16) {
  ASSERT_THAT(mfma2d.getWarpOrder(), testing::ElementsAre(1u, 0u));

  auto tmfma2d = createTransposedMFMA(16, 16, {2, 4});
-  ASSERT_THAT(tmfma2d.getThreadOrder(), testing::ElementsAre(1u, 0u));
+  ASSERT_THAT(tmfma2d.getThreadOrder(), testing::ElementsAre(0u, 1u));
  ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(1u, 0u));

  auto mfma3d = createMFMA(16, 16, {2, 4, 1});
  ASSERT_THAT(mfma3d.getThreadOrder(), testing::ElementsAre(2u, 1u, 0u));
  ASSERT_THAT(mfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u));

  auto tmfma3d = createTransposedMFMA(16, 16, {2, 4, 1});
-  ASSERT_THAT(tmfma3d.getThreadOrder(), testing::ElementsAre(2u, 1u, 0u));
+  ASSERT_THAT(tmfma3d.getThreadOrder(), testing::ElementsAre(1u, 2u, 0u));
  ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u));
}

