Skip to content

Commit

Permalink
revert DecomposeUnsupportedConversions.cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
binarman committed Jul 5, 2024
1 parent a489d90 commit 67eaac0
Showing 1 changed file with 103 additions and 24 deletions.
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#include "OptimizeLDSUtility.h"
#include "TargetInfo.h"
#include "TritonAMDGPUToLLVM/Passes.h"
#include "mlir/Pass/Pass.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Analysis/Utility.h"
#include "triton/Conversion/TritonGPUToLLVM/Patterns.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <numeric>
Expand All @@ -20,6 +20,13 @@ namespace triton {

namespace {

constexpr int kPtrBitWidth = 64;

static void addAttrs(Operation *op, ArrayRef<mlir::NamedAttribute> attrs) {
for (const NamedAttribute attr : attrs)
op->setAttr(attr.getName(), attr.getValue());
}

static void promoteReduceOpResult(OpBuilder &builder, triton::ReduceOp op,
Value result, Type promotedType) {
// save original type
Expand Down Expand Up @@ -51,6 +58,74 @@ static void promoteReduceOpResult(OpBuilder &builder, triton::ReduceOp op,
}
}

static int getCvtOpLDSUsage(triton::gpu::ConvertLayoutOp &cvtOp) {
unsigned inVec = 0;
unsigned outVec = 0;
auto smemShape = triton::getScratchConfigForCvtLayout(cvtOp, inVec, outVec);
unsigned elems = getNumElements<unsigned>(smemShape);
auto srcType = cvtOp.getSrc().getType();
auto bytes =
isa<triton::PointerType>(srcType.getElementType())
? elems * kPtrBitWidth / 8
: elems * std::max<int>(8, srcType.getElementTypeBitWidth()) / 8;

return bytes;
}

bool isPowerOfTwo(unsigned x) { return x && (x & (x - 1)) == 0; }

static std::vector<std::pair<int, int>> factorizePowerOf2(int n) {
assert(isPowerOfTwo(n));
int x = log2(n);
std::vector<std::pair<int, int>> pairs;

for (int i = 0; i <= x / 2; ++i) {
int j = x - i;
pairs.push_back({pow(2, i), pow(2, j)});
pairs.push_back({pow(2, j), pow(2, i)});
}

return pairs;
}

static std::pair<triton::gpu::ConvertLayoutOp, triton::gpu::ConvertLayoutOp>
createNewConvertOps(ModuleOp &mod, OpBuilder &builder,
triton::gpu::ConvertLayoutOp &cvtOp,
std::pair<unsigned, unsigned> warpsPerCta) {
unsigned warpsPerCtaX = warpsPerCta.first;
unsigned warpsPerCtaY = warpsPerCta.second;
auto srcType = cvtOp.getSrc().getType();
auto dstType = cvtOp.getType();

auto newDstType = RankedTensorType::get(
dstType.getShape(), dstType.getElementType(), dstType.getEncoding());
RankedTensorType newSrcType;
if (auto srcMfma =
dyn_cast<triton::gpu::AMDMfmaEncodingAttr>(srcType.getEncoding())) {
auto newMfmaEnc = triton::gpu::AMDMfmaEncodingAttr::get(
mod.getContext(), srcMfma.getVersionMajor(), srcMfma.getVersionMinor(),
{warpsPerCtaX, warpsPerCtaY}, srcMfma.getMDim(), srcMfma.getNDim(),
srcMfma.getIsTransposed(), srcMfma.getCTALayout());

newSrcType = RankedTensorType::get(srcType.getShape(),
srcType.getElementType(), newMfmaEnc);
} else if (auto srcWmma = dyn_cast<triton::gpu::AMDWmmaEncodingAttr>(
srcType.getEncoding())) {
auto newWmmaEnc = triton::gpu::AMDWmmaEncodingAttr::get(
mod.getContext(), {warpsPerCtaX, warpsPerCtaY}, srcWmma.getCTALayout());

newSrcType = RankedTensorType::get(srcType.getShape(),
srcType.getElementType(), newWmmaEnc);
}

auto tmpCvt = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), newSrcType, cvtOp.getSrc());
auto newEpilogueCvt = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), newDstType, tmpCvt);

return std::make_pair(tmpCvt, newEpilogueCvt);
}

struct DecomposeUnsupportedAMDConversions
: public mlir::triton::impl::DecomposeUnsupportedAMDConversionsBase<
DecomposeUnsupportedAMDConversions> {
Expand Down Expand Up @@ -130,48 +205,52 @@ struct DecomposeUnsupportedAMDConversions
return;
}

auto currLDSUsage = mlir::triton::AMD::getCvtOpLDSUsage(cvtOp);
auto currLDSUsage = getCvtOpLDSUsage(cvtOp);
if (currLDSUsage <= sharedMemoryLimit) {
return;
}

unsigned numWarps = triton::gpu::getNumWarpsPerCTA(srcEnc);

triton::gpu::ConvertLayoutOp tmpCvt;
triton::gpu::ConvertLayoutOp newEpilogueCvt;

// Find all possible shapes of WarpsPerCTA by finding all possible
// factorizations of numWarps. Pick shape for which both conversions in
// decomposition use LDS less than sharedMemoryLimit and for which sum of
// LDS usage is minimal. If no such shape exists, do not decompose.
// decomposition use LDS less than limit and for which sum of LDS usage
// is minimal. If no such shape exists, do not decompose.
unsigned minLDSUsage = 2 * sharedMemoryLimit;
int minIdx = -1;
int rank = dstBlocked.getWarpsPerCTA().size();
auto factorizedNumWarps =
mlir::triton::AMD::factorizePowerOf2(numWarps, rank);
auto factorizedNumWarps = factorizePowerOf2(numWarps);

SmallVector<Attribute> tmpLayouts;
for (int i = 0; i < factorizedNumWarps.size(); i++) {
auto warpsPerCTA = factorizedNumWarps[i];
tmpLayouts.push_back(
mlir::triton::AMD::createTmpLayout(srcEnc, warpsPerCTA));
}

for (int i = 0; i < tmpLayouts.size(); i++) {
auto resources = mlir::triton::AMD::estimateResourcesForReplacement(
builder, cvtOp, tmpLayouts[i]);
if (resources.LDS <= sharedMemoryLimit && resources.LDS < minLDSUsage) {
minLDSUsage = resources.LDS;
minIdx = i;
auto warpsPerCTAPair = factorizedNumWarps[i];
std::tie(tmpCvt, newEpilogueCvt) =
createNewConvertOps(mod, builder, cvtOp, warpsPerCTAPair);

int tmpCvtLDS = getCvtOpLDSUsage(tmpCvt);
int newCvtLDS = getCvtOpLDSUsage(newEpilogueCvt);
if (tmpCvtLDS <= sharedMemoryLimit && newCvtLDS <= sharedMemoryLimit) {
int LDSUsage = tmpCvtLDS + newCvtLDS;
if (LDSUsage < minLDSUsage) {
minLDSUsage = LDSUsage;
minIdx = i;
}
}
newEpilogueCvt.erase();
tmpCvt.erase();
}

if (minIdx == -1 || minLDSUsage > sharedMemoryLimit) {
if (minIdx == -1) {
return;
}

assert(minIdx >= 0 && minIdx < tmpLayouts.size());
auto replacementCvts = mlir::triton::AMD::createNewConvertOps(
builder, cvtOp, tmpLayouts[minIdx]);
assert(minIdx >= 0 && minIdx < factorizedNumWarps.size());
auto warpsPerCTAPair = factorizedNumWarps[minIdx];
std::tie(tmpCvt, newEpilogueCvt) =
createNewConvertOps(mod, builder, cvtOp, warpsPerCTAPair);

cvtOp.replaceAllUsesWith(replacementCvts.second.getResult());
cvtOp.replaceAllUsesWith(newEpilogueCvt.getResult());
cvtOp.erase();
});

Expand Down

0 comments on commit 67eaac0

Please sign in to comment.