Skip to content

Commit

Permalink
[mlir][ArmSME] Migrate arm-sme-vector-legalization to dialect conversion (llvm#121101)
Browse files Browse the repository at this point in the history

Use the regular dialect conversion driver instead of the 1:N dialect
conversion driver. The 1:N dialect conversion driver will be removed
soon.
  • Loading branch information
matthias-springer authored Dec 31, 2024
1 parent f0d6017 commit 31613de
Showing 1 changed file with 56 additions and 38 deletions.
94 changes: 56 additions & 38 deletions mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,16 @@
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Transforms/OneToNTypeConversion.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "arm-sme-vector-legalization"

Expand Down Expand Up @@ -172,12 +173,12 @@ int getNumberOfSMETilesForVectorType(VectorType type) {
/// Legalize `arith.constant dense<value>` splat operations to fit within SME
/// tiles by decomposing them into tile-sized operations.
struct LegalizeArithConstantOpsByDecomposition
: public OneToNOpConversionPattern<arith::ConstantOp> {
using OneToNOpConversionPattern::OneToNOpConversionPattern;
: public OpConversionPattern<arith::ConstantOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
OneToNPatternRewriter &rewriter) const override {
ConversionPatternRewriter &rewriter) const override {
auto vectorType = dyn_cast<VectorType>(constantOp.getType());
auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
if (!vectorType || !denseAttr || !denseAttr.isSplat())
Expand All @@ -191,8 +192,8 @@ struct LegalizeArithConstantOpsByDecomposition
auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
auto tileSplat = rewriter.create<arith::ConstantOp>(
constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
rewriter.replaceOp(constantOp, SmallVector<Value>(tileCount, tileSplat),
adaptor.getResultMapping());
SmallVector<Value> repl(tileCount, tileSplat);
rewriter.replaceOpWithMultiple(constantOp, {repl});

return success();
}
Expand All @@ -201,12 +202,13 @@ struct LegalizeArithConstantOpsByDecomposition
/// Legalize `vector.outerproduct` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
struct LegalizeVectorOuterProductOpsByDecomposition
: public OneToNOpConversionPattern<vector::OuterProductOp> {
using OneToNOpConversionPattern::OneToNOpConversionPattern;
: public OpConversionPattern<vector::OuterProductOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(vector::OuterProductOp outerProductOp, OpAdaptor adaptor,
OneToNPatternRewriter &rewriter) const override {
matchAndRewrite(vector::OuterProductOp outerProductOp,
OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto vectorType = outerProductOp.getResultVectorType();
if (!isMultipleOfSMETileVectorType(vectorType))
return rewriter.notifyMatchFailure(outerProductOp,
Expand All @@ -219,6 +221,7 @@ struct LegalizeVectorOuterProductOpsByDecomposition
auto maskOp = outerProductOp.getMaskingOp();
mask = maskOp.getMask();
rootOp = maskOp;
rewriter.setInsertionPoint(rootOp);
}

if (!isSupportedMaskOp(mask))
Expand Down Expand Up @@ -248,7 +251,7 @@ struct LegalizeVectorOuterProductOpsByDecomposition
resultSMETiles.push_back(maskedOuterProduct->getResult(0));
}

rewriter.replaceOp(rootOp, resultSMETiles, adaptor.getResultMapping());
rewriter.replaceOpWithMultiple(rootOp, {resultSMETiles});
return success();
}
};
Expand All @@ -259,12 +262,12 @@ struct LegalizeVectorOuterProductOpsByDecomposition
// (invalid). This pattern matches on `vector.mask` then calls into the
// `vector.outerproduct` pattern to work around this issue.
struct LegalizeMaskedVectorOuterProductOpsByDecomposition
: public OneToNOpConversionPattern<vector::MaskOp> {
using OneToNOpConversionPattern::OneToNOpConversionPattern;
: public OpConversionPattern<vector::MaskOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
OneToNPatternRewriter &rewriter) const override {
matchAndRewrite(vector::MaskOp maskOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (auto outerProductOp = llvm::dyn_cast_or_null<vector::OuterProductOp>(
maskOp.getMaskableOp())) {
LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(),
Expand All @@ -279,12 +282,12 @@ struct LegalizeMaskedVectorOuterProductOpsByDecomposition
/// Legalize `vector.transfer_read` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
struct LegalizeTransferReadOpsByDecomposition
: public OneToNOpConversionPattern<vector::TransferReadOp> {
using OneToNOpConversionPattern::OneToNOpConversionPattern;
: public OpConversionPattern<vector::TransferReadOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
OneToNPatternRewriter &rewriter) const override {
matchAndRewrite(vector::TransferReadOp readOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto vectorType = readOp.getVectorType();
if (!isMultipleOfSMETileVectorType(vectorType))
return rewriter.notifyMatchFailure(readOp,
Expand Down Expand Up @@ -319,20 +322,20 @@ struct LegalizeTransferReadOpsByDecomposition
resultSMETiles.push_back(smeRead);
}

rewriter.replaceOp(readOp, resultSMETiles, adaptor.getResultMapping());
rewriter.replaceOpWithMultiple(readOp, {resultSMETiles});
return success();
}
};

/// Legalize `vector.transfer_write` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
struct LegalizeTransferWriteOpsByDecomposition
: public OneToNOpConversionPattern<vector::TransferWriteOp> {
using OneToNOpConversionPattern::OneToNOpConversionPattern;
: public OpConversionPattern<vector::TransferWriteOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
OneToNPatternRewriter &rewriter) const override {
matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto vectorType = writeOp.getVectorType();
if (!isMultipleOfSMETileVectorType(vectorType))
return rewriter.notifyMatchFailure(writeOp,
Expand Down Expand Up @@ -409,12 +412,12 @@ struct LegalizeTransferWriteOpsByDecomposition
/// }
/// ```
struct LegalizeMultiTileTransferWriteAsStoreLoop
: public OneToNOpConversionPattern<vector::TransferWriteOp> {
using OneToNOpConversionPattern::OneToNOpConversionPattern;
: public OpConversionPattern<vector::TransferWriteOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
OneToNPatternRewriter &rewriter) const override {
matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (writeOp.hasPureTensorSemantics())
return rewriter.notifyMatchFailure(
writeOp, "TODO: tensor semantics are unsupported");
Expand Down Expand Up @@ -936,10 +939,16 @@ struct VectorLegalizationPass
return success();
});

patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
LiftIllegalVectorTransposeToMemory,
ConvertIllegalShapeCastOpsToTransposes,
LowerIllegalTransposeStoreViaZA>(context);
// Apply preprocessing patterns.
RewritePatternSet rewritePatterns(context);
rewritePatterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
LiftIllegalVectorTransposeToMemory,
ConvertIllegalShapeCastOpsToTransposes,
LowerIllegalTransposeStoreViaZA>(context);
if (failed(
applyPatternsGreedily(getOperation(), std::move(rewritePatterns))))
return signalPassFailure();

// Note: These two patterns are added with a high benefit to ensure:
// - Masked outer products are handled before unmasked ones
// - Multi-tile writes are lowered as a store loop (if possible)
Expand All @@ -950,11 +959,20 @@ struct VectorLegalizationPass
LegalizeVectorOuterProductOpsByDecomposition,
LegalizeTransferReadOpsByDecomposition,
LegalizeTransferWriteOpsByDecomposition>(converter, context);
populateFuncTypeConversionPatterns(converter, patterns);
scf::populateSCFStructuralOneToNTypeConversions(converter, patterns);

if (failed(applyPartialOneToNConversion(getOperation(), converter,
std::move(patterns))))
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
converter);
populateCallOpTypeConversionPattern(patterns, converter);
populateReturnOpTypeConversionPattern(patterns, converter);
scf::populateSCFStructuralTypeConversions(converter, patterns);

ConversionTarget target(getContext());
target.markUnknownOpDynamicallyLegal(
[&](Operation *op) { return converter.isLegal(op); });
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
return converter.isSignatureLegal(op.getFunctionType());
});
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
}
};
Expand Down

0 comments on commit 31613de

Please sign in to comment.