diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index c7fafa6b5721..34db6536e429 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -230,6 +230,8 @@ compared to 1*64 when the hasLeadingOffset is false. // ---- begin GFX908/GFX90A ---- if (auto mfmaEnc = dotOpEnc.getParent().dyn_cast()) { int kDimNum = dotOpEnc.getOpIdx() == 0 ? 1 : 0; + if (needTrans) + kDimNum = 1 - kDimNum; bool isKDimInner = (order[0] == kDimNum); if (isKDimInner) { const int numBanks = 32;