From 31613de9cf22b2915cb39bfb043d957d513bd1cd Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Tue, 31 Dec 2024 12:44:50 +0100 Subject: [PATCH] [mlir][ArmSME] Migrate `arm-sme-vector-legalization` to dialect conversion (#121101) Use the regular dialect conversion driver instead of the 1:N dialect conversion driver. The 1:N dialect conversion driver will be removed soon. --- .../ArmSME/Transforms/VectorLegalization.cpp | 94 +++++++++++-------- 1 file changed, 56 insertions(+), 38 deletions(-) diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp index 61767f3b21c9c3..12c65a72babcb8 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp @@ -17,7 +17,7 @@ #include "mlir/Dialect/ArmSME/Transforms/Passes.h" #include "mlir/Dialect/ArmSME/Utils/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" #include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -25,7 +25,8 @@ #include "mlir/Dialect/SCF/Transforms/Patterns.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" -#include "mlir/Transforms/OneToNTypeConversion.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #define DEBUG_TYPE "arm-sme-vector-legalization" @@ -172,12 +173,12 @@ int getNumberOfSMETilesForVectorType(VectorType type) { /// Legalize `arith.constant dense<...>` splat operations to fit within SME /// tiles by decomposing them into tile-sized operations. 
struct LegalizeArithConstantOpsByDecomposition - : public OneToNOpConversionPattern { - using OneToNOpConversionPattern::OneToNOpConversionPattern; + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { + ConversionPatternRewriter &rewriter) const override { auto vectorType = dyn_cast(constantOp.getType()); auto denseAttr = dyn_cast(constantOp.getValueAttr()); if (!vectorType || !denseAttr || !denseAttr.isSplat()) @@ -191,8 +192,8 @@ struct LegalizeArithConstantOpsByDecomposition auto tileCount = getNumberOfSMETilesForVectorType(vectorType); auto tileSplat = rewriter.create( constantOp.getLoc(), denseAttr.resizeSplat(smeTileType)); - rewriter.replaceOp(constantOp, SmallVector(tileCount, tileSplat), - adaptor.getResultMapping()); + SmallVector repl(tileCount, tileSplat); + rewriter.replaceOpWithMultiple(constantOp, {repl}); return success(); } @@ -201,12 +202,13 @@ struct LegalizeArithConstantOpsByDecomposition /// Legalize `vector.outerproduct` operations to fit within SME tiles by /// decomposing them into tile-sized operations. 
struct LegalizeVectorOuterProductOpsByDecomposition - : public OneToNOpConversionPattern { - using OneToNOpConversionPattern::OneToNOpConversionPattern; + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(vector::OuterProductOp outerProductOp, OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { + matchAndRewrite(vector::OuterProductOp outerProductOp, + OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { auto vectorType = outerProductOp.getResultVectorType(); if (!isMultipleOfSMETileVectorType(vectorType)) return rewriter.notifyMatchFailure(outerProductOp, @@ -219,6 +221,7 @@ struct LegalizeVectorOuterProductOpsByDecomposition auto maskOp = outerProductOp.getMaskingOp(); mask = maskOp.getMask(); rootOp = maskOp; + rewriter.setInsertionPoint(rootOp); } if (!isSupportedMaskOp(mask)) @@ -248,7 +251,7 @@ struct LegalizeVectorOuterProductOpsByDecomposition resultSMETiles.push_back(maskedOuterProduct->getResult(0)); } - rewriter.replaceOp(rootOp, resultSMETiles, adaptor.getResultMapping()); + rewriter.replaceOpWithMultiple(rootOp, {resultSMETiles}); return success(); } }; @@ -259,12 +262,12 @@ struct LegalizeVectorOuterProductOpsByDecomposition // (invalid). This pattern matches on `vector.mask` then calls into the // `vector.outerproduct` pattern to work around this issue. 
struct LegalizeMaskedVectorOuterProductOpsByDecomposition - : public OneToNOpConversionPattern { - using OneToNOpConversionPattern::OneToNOpConversionPattern; + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { + matchAndRewrite(vector::MaskOp maskOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { if (auto outerProductOp = llvm::dyn_cast_or_null( maskOp.getMaskableOp())) { LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(), @@ -279,12 +282,12 @@ struct LegalizeMaskedVectorOuterProductOpsByDecomposition /// Legalize `vector.transfer_read` operations to fit within SME tiles by /// decomposing them into tile-sized operations. struct LegalizeTransferReadOpsByDecomposition - : public OneToNOpConversionPattern { - using OneToNOpConversionPattern::OneToNOpConversionPattern; + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { + matchAndRewrite(vector::TransferReadOp readOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { auto vectorType = readOp.getVectorType(); if (!isMultipleOfSMETileVectorType(vectorType)) return rewriter.notifyMatchFailure(readOp, @@ -319,7 +322,7 @@ struct LegalizeTransferReadOpsByDecomposition resultSMETiles.push_back(smeRead); } - rewriter.replaceOp(readOp, resultSMETiles, adaptor.getResultMapping()); + rewriter.replaceOpWithMultiple(readOp, {resultSMETiles}); return success(); } }; @@ -327,12 +330,12 @@ struct LegalizeTransferReadOpsByDecomposition /// Legalize `vector.transfer_write` operations to fit within SME tiles by /// decomposing them into tile-sized operations. 
struct LegalizeTransferWriteOpsByDecomposition - : public OneToNOpConversionPattern { - using OneToNOpConversionPattern::OneToNOpConversionPattern; + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { + matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { auto vectorType = writeOp.getVectorType(); if (!isMultipleOfSMETileVectorType(vectorType)) return rewriter.notifyMatchFailure(writeOp, @@ -409,12 +412,12 @@ struct LegalizeTransferWriteOpsByDecomposition /// } /// ``` struct LegalizeMultiTileTransferWriteAsStoreLoop - : public OneToNOpConversionPattern { - using OneToNOpConversionPattern::OneToNOpConversionPattern; + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { + matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { if (writeOp.hasPureTensorSemantics()) return rewriter.notifyMatchFailure( writeOp, "TODO: tensor semantics are unsupported"); @@ -936,10 +939,16 @@ struct VectorLegalizationPass return success(); }); - patterns.add(context); + // Apply preprocessing patterns. 
+ RewritePatternSet rewritePatterns(context); + rewritePatterns.add(context); + if (failed( + applyPatternsGreedily(getOperation(), std::move(rewritePatterns)))) + return signalPassFailure(); + // Note: These two patterns are added with a high benefit to ensure: // - Masked outer products are handled before unmasked ones // - Multi-tile writes are lowered as a store loop (if possible) @@ -950,11 +959,20 @@ struct VectorLegalizationPass LegalizeVectorOuterProductOpsByDecomposition, LegalizeTransferReadOpsByDecomposition, LegalizeTransferWriteOpsByDecomposition>(converter, context); - populateFuncTypeConversionPatterns(converter, patterns); - scf::populateSCFStructuralOneToNTypeConversions(converter, patterns); - - if (failed(applyPartialOneToNConversion(getOperation(), converter, - std::move(patterns)))) + populateFunctionOpInterfaceTypeConversionPattern(patterns, + converter); + populateCallOpTypeConversionPattern(patterns, converter); + populateReturnOpTypeConversionPattern(patterns, converter); + scf::populateSCFStructuralTypeConversions(converter, patterns); + + ConversionTarget target(getContext()); + target.markUnknownOpDynamicallyLegal( + [&](Operation *op) { return converter.isLegal(op); }); + target.addDynamicallyLegalOp([&](func::FuncOp op) { + return converter.isSignatureLegal(op.getFunctionType()); + }); + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) return signalPassFailure(); } };