[AMD] Refactor decompose-unsupported-amd-conversions pass
This PR:
- Simplifies the pass code by reusing common code (see the selection-strategy sketch below)
- Introduces support for 3D tensors in the mfma -> dot conversion (enabled by the common code from the item above)
- Adds more lit tests for the decompose-unsupported-amd-conversions pass
binarman committed Jul 30, 2024
1 parent 2db5668 commit b23e154
Showing 3 changed files with 76 additions and 101 deletions.
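Before the per-file diffs, here is a self-contained sketch of the strategy the refactored pass implements: enumerate the power-of-two factorizations of the warp count across the tensor rank, estimate the LDS (shared memory) each candidate decomposition would need, and keep the cheapest candidate that fits under the hardware limit. The cost function below is a stand-in for the real estimates (getCvtOpLDSUsage / estimateResourcesForReplacement in the diff); all names and numbers are illustrative, not taken from the Triton sources.

#include <algorithm>
#include <cstdio>
#include <vector>

// Enumerate all ways to split a power-of-two warp count across `rank` dims.
static void factorizePowerOf2(int n, int rank, std::vector<int> &current,
                              std::vector<std::vector<int>> &result) {
  if (rank == 1) {
    current.push_back(n);
    result.push_back(current);
    current.pop_back();
    return;
  }
  for (int f = 1; f <= n; f *= 2) {
    current.push_back(f);
    factorizePowerOf2(n / f, rank - 1, current, result);
    current.pop_back();
  }
}

// Stand-in for the pass's LDS estimate: pretend squarer warp grids need
// less scratch memory than elongated ones.
static int fakeLDSUsage(const std::vector<int> &warpsPerCTA) {
  int maxDim = 1;
  for (int w : warpsPerCTA)
    maxDim = std::max(maxDim, w);
  return maxDim * 1024;
}

int main() {
  const int numWarps = 8, rank = 2, sharedMemoryLimit = 4096;
  std::vector<int> current;
  std::vector<std::vector<int>> candidates;
  factorizePowerOf2(numWarps, rank, current, candidates);

  // Mirror the selection loop in the diff: track the cheapest candidate
  // that fits in the LDS budget; bail out if none does.
  int minIdx = -1, minLDS = 2 * sharedMemoryLimit;
  for (int i = 0; i < (int)candidates.size(); ++i) {
    int lds = fakeLDSUsage(candidates[i]);
    if (lds <= sharedMemoryLimit && lds < minLDS) {
      minLDS = lds;
      minIdx = i;
    }
  }
  if (minIdx == -1) {
    std::puts("no decomposition fits; leave the conversion as is");
    return 0;
  }
  std::printf("picked warpsPerCTA = {");
  for (int w : candidates[minIdx])
    std::printf(" %d", w);
  std::printf(" }, estimated %d bytes\n", minLDS);
  return 0;
}

With numWarps = 8 and rank = 2 this prints "picked warpsPerCTA = { 2 4 }, estimated 4096 bytes"; the real pass performs the same walk over layout attributes instead of plain vectors.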
33 changes: 33 additions & 0 deletions test/Conversion/amd/decompose-unsupported-conversions-gfx9.mlir
@@ -0,0 +1,33 @@
// RUN: triton-opt %s --split-input-file --decompose-unsupported-amd-conversions=arch=gfx942 | FileCheck %s

// CHECK: #[[DST_ENC:.+]] = #triton_gpu.blocked<{{.*}}>
// CHECK: #[[SRC_ENC:.+]] = #triton_gpu.amd_mfma<{{.*}}>
// CHECK: #[[TMP_ENC:.+]] = #triton_gpu.amd_mfma<{{.*}}>
// CHECK: large_tensor_conversion
#src = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [32, 32], isTransposed = false}>
#dst = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
  tt.func @large_tensor_conversion(%arg0: tensor<128x128xf32, #src>) {
    // CHECK: %0 = triton_gpu.convert_layout %arg0 : tensor<128x128xf32, #[[SRC_ENC]]> -> tensor<128x128xf32, #[[TMP_ENC]]>
    // CHECK: %1 = triton_gpu.convert_layout %0 : tensor<128x128xf32, #[[TMP_ENC]]> -> tensor<128x128xf32, #[[DST_ENC]]>
    %0 = triton_gpu.convert_layout %arg0 : tensor<128x128xf32, #src> -> tensor<128x128xf32, #dst>
    tt.return
  }
}

// -----

// CHECK: #[[DST_ENC:.+]] = #triton_gpu.blocked<{{.*}}>
// CHECK: #[[SRC_ENC:.+]] = #triton_gpu.amd_mfma<{{.*}}>
// CHECK: #[[TMP_ENC:.+]] = #triton_gpu.amd_mfma<{{.*}}>
// CHECK: large_tensor_3d_conversion
#src = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 1, 2], instrShape = [32, 32], isTransposed = false}>
#dst = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 64, 1], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
  tt.func @large_tensor_3d_conversion(%arg0: tensor<2x128x64xf32, #src>) {
    // CHECK: %0 = triton_gpu.convert_layout %arg0 : tensor<2x128x64xf32, #[[SRC_ENC]]> -> tensor<2x128x64xf32, #[[TMP_ENC]]>
    // CHECK: %1 = triton_gpu.convert_layout %0 : tensor<2x128x64xf32, #[[TMP_ENC]]> -> tensor<2x128x64xf32, #[[DST_ENC]]>
    %0 = triton_gpu.convert_layout %arg0 : tensor<2x128x64xf32, #src> -> tensor<2x128x64xf32, #dst>
    tt.return
  }
}
20 changes: 19 additions & 1 deletion test/Conversion/amd/decompose-unsupported-conversions.mlir
@@ -1,8 +1,9 @@
-// RUN: triton-opt %s --split-input-file --decompose-unsupported-amd-conversions=arch=gfx942 | FileCheck %s
+// RUN: triton-opt %s --split-input-file --decompose-unsupported-amd-conversions=arch=gfx1130 | FileCheck %s

// CHECK: #[[BLOCKED:.+]] = #triton_gpu.blocked<{{.*}}>
// CHECK: #[[WMMA:.+]] = #triton_gpu.amd_wmma<{{.*}}>
// CHECK: #[[SHARED:.+]] = #triton_gpu.shared<{{.*}}>
+// CHECK: wmma_to_wmma_dot_op
#mma = #triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
  tt.func @wmma_to_wmma_dot_op(%arg0: tensor<16x16xf16, #mma>) {
@@ -13,3 +14,20 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
    tt.return
  }
}

+// -----
+
+// CHECK: #[[BLOCKED:.+]] = #triton_gpu.blocked<{{.*}}>
+// CHECK: #[[WMMA:.+]] = #triton_gpu.amd_wmma<{{.*}}>
+// CHECK: #[[SHARED:.+]] = #triton_gpu.shared<{{.*}}>
+// CHECK: wmma_to_wmma_dot3d_op
+#mma = #triton_gpu.amd_wmma<{warpsPerCTA = [2, 2, 2]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
+  tt.func @wmma_to_wmma_dot3d_op(%arg0: tensor<2x16x16xf16, #mma>) {
+    // CHECK: %[[SRC_BLOCKED:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<2x16x16xf16, #[[WMMA]]> -> tensor<2x16x16xf16, #[[BLOCKED]]>
+    // CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !tt.memdesc<2x16x16xf16, #[[SHARED]], #triton_gpu.shared_memory>
+    // CHECK-NEXT: %[[DST_DOT_OP:.+]] = triton_gpu.local_load %[[INT_SHARED]] : {{.*}} -> tensor<2x16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA]], kWidth = 16}>>
+    %0 = triton_gpu.convert_layout %arg0 : tensor<2x16x16xf16, #mma> -> tensor<2x16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
+    tt.return
+  }
+}
124 changes: 24 additions & 100 deletions third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp
@@ -1,11 +1,11 @@
#include "OptimizeLDSUtility.h"
#include "TargetInfo.h"
#include "TritonAMDGPUToLLVM/Passes.h"
#include "mlir/Pass/Pass.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Analysis/Utility.h"
#include "triton/Conversion/TritonGPUToLLVM/Patterns.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <numeric>
@@ -20,78 +20,6 @@ namespace triton {

namespace {

-constexpr int kPtrBitWidth = 64;
-
-static void addAttrs(Operation *op, ArrayRef<mlir::NamedAttribute> attrs) {
-  for (const NamedAttribute attr : attrs)
-    op->setAttr(attr.getName(), attr.getValue());
-}
-
-static int getCvtOpLDSUsage(triton::gpu::ConvertLayoutOp &cvtOp) {
-  auto scratchConfig = mlir::triton::getScratchConfigForCvt(
-      cvtOp.getSrc().getType(), cvtOp.getType());
-  unsigned elems = getNumScratchElements(scratchConfig.paddedRepShape);
-  auto srcType = cvtOp.getSrc().getType();
-  auto bytes =
-      isa<triton::PointerType>(srcType.getElementType())
-          ? elems * kPtrBitWidth / 8
-          : elems * std::max<int>(8, srcType.getElementTypeBitWidth()) / 8;
-
-  return bytes;
-}
-
-static std::vector<std::pair<int, int>> factorizePowerOf2(int n) {
-  assert(llvm::isPowerOf2_32(n));
-  int x = log2(n);
-  std::vector<std::pair<int, int>> pairs;
-
-  for (int i = 0; i <= x / 2; ++i) {
-    int j = x - i;
-    pairs.push_back({pow(2, i), pow(2, j)});
-    pairs.push_back({pow(2, j), pow(2, i)});
-  }
-
-  return pairs;
-}
-
-static std::pair<triton::gpu::ConvertLayoutOp, triton::gpu::ConvertLayoutOp>
-createNewConvertOps(ModuleOp &mod, OpBuilder &builder,
-                    triton::gpu::ConvertLayoutOp &cvtOp,
-                    std::pair<unsigned, unsigned> warpsPerCta) {
-  unsigned warpsPerCtaX = warpsPerCta.first;
-  unsigned warpsPerCtaY = warpsPerCta.second;
-  auto srcType = cvtOp.getSrc().getType();
-  auto dstType = cvtOp.getType();
-
-  auto newDstType = RankedTensorType::get(
-      dstType.getShape(), dstType.getElementType(), dstType.getEncoding());
-  RankedTensorType newSrcType;
-  if (auto srcMfma =
-          dyn_cast<triton::gpu::AMDMfmaEncodingAttr>(srcType.getEncoding())) {
-    auto newMfmaEnc = triton::gpu::AMDMfmaEncodingAttr::get(
-        mod.getContext(), srcMfma.getVersionMajor(), srcMfma.getVersionMinor(),
-        {warpsPerCtaX, warpsPerCtaY}, srcMfma.getMDim(), srcMfma.getNDim(),
-        srcMfma.getIsTransposed(), srcMfma.getCTALayout());
-
-    newSrcType = RankedTensorType::get(srcType.getShape(),
-                                       srcType.getElementType(), newMfmaEnc);
-  } else if (auto srcWmma = dyn_cast<triton::gpu::AMDWmmaEncodingAttr>(
-                 srcType.getEncoding())) {
-    auto newWmmaEnc = triton::gpu::AMDWmmaEncodingAttr::get(
-        mod.getContext(), {warpsPerCtaX, warpsPerCtaY}, srcWmma.getCTALayout());
-
-    newSrcType = RankedTensorType::get(srcType.getShape(),
-                                       srcType.getElementType(), newWmmaEnc);
-  }
-
-  auto tmpCvt = builder.create<triton::gpu::ConvertLayoutOp>(
-      cvtOp.getLoc(), newSrcType, cvtOp.getSrc());
-  auto newEpilogueCvt = builder.create<triton::gpu::ConvertLayoutOp>(
-      cvtOp.getLoc(), newDstType, tmpCvt);
-
-  return std::make_pair(tmpCvt, newEpilogueCvt);
-}
-
struct DecomposeUnsupportedAMDConversions
: public mlir::triton::impl::DecomposeUnsupportedAMDConversionsBase<
DecomposeUnsupportedAMDConversions> {
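The helpers deleted above were moved into shared code (OptimizeLDSUtility, per the new include and the mlir::triton::AMD:: call sites in the next hunk). One detail worth restating from the removed getCvtOpLDSUsage: the scratch-buffer size is the element count times the element width in bytes, where pointers count as 64 bits (kPtrBitWidth) and sub-byte element types are rounded up to a full byte. A standalone restatement of just that arithmetic, with the element count supplied directly instead of derived from a scratch config:

#include <algorithm>
#include <cassert>

constexpr int kPtrBitWidth = 64;

// Byte size of the LDS scratch buffer for a layout conversion, following
// the removed getCvtOpLDSUsage: pointer elements are 64-bit, and any
// element narrower than a byte (e.g. i1) still occupies a full byte.
int scratchBytes(unsigned elems, int elemBitWidth, bool isPointer) {
  int bits = isPointer ? kPtrBitWidth : std::max(8, elemBitWidth);
  return elems * bits / 8;
}

int main() {
  assert(scratchBytes(1024, 32, false) == 4096); // f32: 4 bytes per element
  assert(scratchBytes(1024, 1, false) == 1024);  // i1 rounds up to one byte
  assert(scratchBytes(1024, 64, true) == 8192);  // pointers use kPtrBitWidth
  return 0;
}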
@@ -171,52 +99,48 @@ struct DecomposeUnsupportedAMDConversions
        return;
      }

-      auto currLDSUsage = getCvtOpLDSUsage(cvtOp);
+      auto currLDSUsage = mlir::triton::AMD::getCvtOpLDSUsage(cvtOp);
      if (currLDSUsage <= sharedMemoryLimit) {
        return;
      }

      unsigned numWarps = triton::gpu::getNumWarpsPerCTA(srcEnc);

-      triton::gpu::ConvertLayoutOp tmpCvt;
-      triton::gpu::ConvertLayoutOp newEpilogueCvt;
-
      // Find all possible shapes of WarpsPerCTA by finding all possible
      // factorizations of numWarps. Pick shape for which both conversions in
-      // decomposition use LDS less than limit and for which sum of LDS usage
-      // is minimal. If no such shape exists, do not decompose.
+      // decomposition use LDS less than sharedMemoryLimit and for which sum of
+      // LDS usage is minimal. If no such shape exists, do not decompose.
      unsigned minLDSUsage = 2 * sharedMemoryLimit;
      int minIdx = -1;
-      auto factorizedNumWarps = factorizePowerOf2(numWarps);
+      int rank = dstBlocked.getWarpsPerCTA().size();
+      auto factorizedNumWarps =
+          mlir::triton::AMD::factorizePowerOf2(numWarps, rank);

+      SmallVector<Attribute> tmpLayouts;
      for (int i = 0; i < factorizedNumWarps.size(); i++) {
-        auto warpsPerCTAPair = factorizedNumWarps[i];
-        std::tie(tmpCvt, newEpilogueCvt) =
-            createNewConvertOps(mod, builder, cvtOp, warpsPerCTAPair);
-
-        int tmpCvtLDS = getCvtOpLDSUsage(tmpCvt);
-        int newCvtLDS = getCvtOpLDSUsage(newEpilogueCvt);
-        if (tmpCvtLDS <= sharedMemoryLimit && newCvtLDS <= sharedMemoryLimit) {
-          int LDSUsage = tmpCvtLDS + newCvtLDS;
-          if (LDSUsage < minLDSUsage) {
-            minLDSUsage = LDSUsage;
-            minIdx = i;
-          }
+        auto warpsPerCTA = factorizedNumWarps[i];
+        tmpLayouts.push_back(
+            mlir::triton::AMD::createTmpLayout(srcEnc, warpsPerCTA));
+      }
+
+      for (int i = 0; i < tmpLayouts.size(); i++) {
+        auto resources = mlir::triton::AMD::estimateResourcesForReplacement(
+            builder, cvtOp, tmpLayouts[i]);
+        if (resources.LDS <= sharedMemoryLimit && resources.LDS < minLDSUsage) {
+          minLDSUsage = resources.LDS;
+          minIdx = i;
        }
-        newEpilogueCvt.erase();
-        tmpCvt.erase();
      }

-      if (minIdx == -1) {
+      if (minIdx == -1 || minLDSUsage > sharedMemoryLimit) {
        return;
      }

-      assert(minIdx >= 0 && minIdx < factorizedNumWarps.size());
-      auto warpsPerCTAPair = factorizedNumWarps[minIdx];
-      std::tie(tmpCvt, newEpilogueCvt) =
-          createNewConvertOps(mod, builder, cvtOp, warpsPerCTAPair);
+      assert(minIdx >= 0 && minIdx < tmpLayouts.size());
+      auto replacementCvts = mlir::triton::AMD::createNewConvertOps(
+          builder, cvtOp, tmpLayouts[minIdx]);

-      cvtOp.replaceAllUsesWith(newEpilogueCvt.getResult());
+      cvtOp.replaceAllUsesWith(replacementCvts.second.getResult());
      cvtOp.erase();
    });
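The new helpers createTmpLayout, estimateResourcesForReplacement, and the two-argument createNewConvertOps are not shown in this commit's diff; they live in the shared OptimizeLDSUtility. As a purely hypothetical sketch, reconstructed from the attribute constructors visible in the removed createNewConvertOps above, createTmpLayout plausibly clones the source MFMA/WMMA encoding with the candidate warpsPerCTA; this is not the commit's actual code and will not compile outside the Triton tree:

// Hypothetical reconstruction, not from the commit: clone an MFMA/WMMA
// encoding with a new warpsPerCTA, mirroring the removed createNewConvertOps.
Attribute createTmpLayout(Attribute enc, ArrayRef<unsigned> warpsPerCTA) {
  MLIRContext *ctx = enc.getContext();
  if (auto mfma = dyn_cast<triton::gpu::AMDMfmaEncodingAttr>(enc))
    return triton::gpu::AMDMfmaEncodingAttr::get(
        ctx, mfma.getVersionMajor(), mfma.getVersionMinor(), warpsPerCTA,
        mfma.getMDim(), mfma.getNDim(), mfma.getIsTransposed(),
        mfma.getCTALayout());
  if (auto wmma = dyn_cast<triton::gpu::AMDWmmaEncodingAttr>(enc))
    return triton::gpu::AMDWmmaEncodingAttr::get(ctx, warpsPerCTA,
                                                 wmma.getCTALayout());
  return {}; // other encodings: caller skips the decomposition
}

Taking warpsPerCTA as an array rather than an (x, y) pair is what lets the same path serve both 2D and the new 3D cases, which matches the 3D mfma -> dot support mentioned in the commit message.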
