Skip to content

Commit

Permalink
improve chain dot checking (#480)
Browse files Browse the repository at this point in the history
* improve the chained dot check
  • Loading branch information
scxiao authored Jan 31, 2024
1 parent 8fa7cf3 commit 6aa0111
Show file tree
Hide file tree
Showing 2 changed files with 228 additions and 27 deletions.
62 changes: 35 additions & 27 deletions lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,26 +52,51 @@ int getMfmaVersion(MatrixCoreVersion matrixCoreVer) {
return 0;
}

SmallVector<unsigned, 2>
warpsPerTile(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
SmallVector<int64_t, 2> shapePerWarp) {
// TODO: needs to be updated with appropriate shapePerWarp etc.
static bool isTransposeChainDotPattern(tt::DotOp &dotOp) {
auto filter = [&dotOp](Operation *op) {
return op->getParentRegion() == dotOp->getParentRegion();
};
mlir::ForwardSliceOptions fwdOpt;
fwdOpt.filter = filter;
mlir::SetVector<mlir::Operation*> fwdSlices;
mlir::getForwardSlice(static_cast<mlir::Operation*>(dotOp), &fwdSlices, fwdOpt);
for (Operation *op : fwdSlices) {
// ensure output of the first dot is the operand 0 of the second dot
if (isa<tt::DotOp>(op) && (op != dotOp)) {
auto dOp = dyn_cast<tt::DotOp>(op);
auto oper0 = dOp.getOperand(0).getDefiningOp();
if(std::find(fwdSlices.begin(), fwdSlices.end(), oper0) != fwdSlices.end()) {
return true;
}
}
}

mlir::BackwardSliceOptions bwdOpt;
bwdOpt.omitBlockArguments = true;
bwdOpt.filter = filter;
auto slices = mlir::getSlice(dotOp, bwdOpt, fwdOpt);
for (Operation *op : slices)
if (isa<tt::DotOp>(op) && (op != dotOp))
return {(unsigned)numWarps, 1};
mlir::SetVector<mlir::Operation*> bwdSlices;
// search backward of the operand 0 of the dot
mlir::Operation* oper0 = dotOp.getOperand(0).getDefiningOp();
mlir::getBackwardSlice(oper0, &bwdSlices, bwdOpt);
for (Operation *op : bwdSlices) {
if (isa<tt::DotOp>(op) && (op != dotOp)) {
return true;
}
}

return false;
}

SmallVector<unsigned, 2>
warpsPerTile(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
SmallVector<int64_t, 2> shapePerWarp) {
// TODO: needs to be updated with appropriate shapePerWarp etc.
if (isTransposeChainDotPattern(dotOp)) {
return {(unsigned)numWarps, 1};
}

SmallVector<int64_t, 2> tensorShape = {shape[0], shape[1]};
SmallVector<unsigned, 2> ret = {1, 1};

do {
if (ret[0] * ret[1] >= numWarps)
break;
Expand Down Expand Up @@ -113,23 +138,6 @@ class BlockedToMFMA : public mlir::RewritePattern {
: mlir::RewritePattern(tt::DotOp::getOperationName(), 2, context),
mfmaVersion(mfmaVersion), enforcedNonKDim(nonKDim) {}

bool isChainDot(tt::DotOp &dotOp) const {
auto filter = [&dotOp](Operation *op) {
return op->getParentRegion() == dotOp->getParentRegion();
};
mlir::ForwardSliceOptions fwdOpt;
fwdOpt.filter = filter;
mlir::BackwardSliceOptions bwdOpt;
bwdOpt.omitBlockArguments = true;
bwdOpt.filter = filter;
auto slices = mlir::getSlice(dotOp, bwdOpt, fwdOpt);
for (Operation *op : slices) {
if (isa<tt::DotOp>(op) && (op != dotOp))
return true;
}
return false;
}

/// @brief Choose MFMA instruction parameters
/// @param dot target dot operation
/// @return pair {nonKDim, kDim} sizes of one MFMA instruction arguments
Expand Down Expand Up @@ -229,7 +237,7 @@ class BlockedToMFMA : public mlir::RewritePattern {
auto warpsPerTile =
warpsPerTileMFMA(dotOp, retShape, numWarps, {mDim, nDim});

bool isTransposed = isChainDot(dotOp);
bool isTransposed = isTransposeChainDotPattern(dotOp);
mfmaEnc = ttg::MfmaEncodingAttr::get(
oldRetType.getContext(),
/*versionMajor*/ mfmaVersion, /*versionMinor*/ 0, warpsPerTile,
Expand Down
Loading

0 comments on commit 6aa0111

Please sign in to comment.