From a21cfca320bddeef120618ceff9563778b5cbd94 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Fri, 7 Mar 2025 08:43:01 +0100 Subject: [PATCH] [mlir][IR] Deprecate `match` and `rewrite` functions (#130031) Deprecate the `match` and `rewrite` functions. They mainly exist for historic reasons. This PR also updates all remaining uses of in the MLIR codebase. This is addressing a [comment](https://github.com/llvm/llvm-project/pull/129861#pullrequestreview-2662696084) on an earlier PR. Note for LLVM integration: `SplitMatchAndRewrite` will be deleted soon, update your patterns to use `matchAndRewrite` instead of separate `match` / `rewrite`. --------- Co-authored-by: Jakub Kuderski --- mlir/docs/DialectConversion.md | 14 +- mlir/docs/PatternRewriter.md | 33 +---- mlir/docs/Tutorials/QuickstartRewrites.md | 19 --- .../mlir/Conversion/LLVMCommon/Pattern.h | 5 + mlir/include/mlir/IR/PatternMatch.h | 12 ++ .../mlir/Transforms/DialectConversion.h | 6 + .../ArithToAMDGPU/ArithToAMDGPU.cpp | 132 +++++++++--------- .../Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 37 +++-- .../Transforms/EmulateUnsupportedFloats.cpp | 22 ++- .../Transforms/IntRangeOptimizations.cpp | 22 +-- .../Transforms/VectorTransferOpTransforms.cpp | 27 ++-- 11 files changed, 158 insertions(+), 171 deletions(-) diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md index abacd5a82c61e..f67d1411b3065 100644 --- a/mlir/docs/DialectConversion.md +++ b/mlir/docs/DialectConversion.md @@ -179,13 +179,13 @@ updated/remapped operands of an operation, such as when the types of results defined by an operation have changed. The general Rewrite Patterns can no longer be used in these situations, as the types of the operands of the operation being matched will not correspond with those expected by the user. This pattern -provides, as an additional argument to the `matchAndRewrite` and `rewrite` -methods, the list of operands that the operation should use after conversion. If -an operand was the result of a non-converted operation, for example if it was -already legal, the original operand is used. This means that the operands -provided always have a 1-1 non-null correspondence with the operands on the -operation. The original operands of the operation are still intact and may be -inspected as normal. These patterns also utilize a special `PatternRewriter`, +provides, as an additional argument to the `matchAndRewrite` method, the list +of operands that the operation should use after conversion. If an operand was +the result of a non-converted operation, for example if it was already legal, +the original operand is used. This means that the operands provided always have +a 1-1 non-null correspondence with the operands on the operation. The original +operands of the operation are still intact and may be inspected as normal. +These patterns also utilize a special `PatternRewriter`, `ConversionPatternRewriter`, that provides special hooks for use with the conversion infrastructure. diff --git a/mlir/docs/PatternRewriter.md b/mlir/docs/PatternRewriter.md index af0f56466e0cb..105a554b95851 100644 --- a/mlir/docs/PatternRewriter.md +++ b/mlir/docs/PatternRewriter.md @@ -48,13 +48,9 @@ operation type, a special tag must be provided to make the intent explicit: ### `matchAndRewrite` implementation This is the chunk of code that matches a given root `Operation` and performs a -rewrite of the IR. A `RewritePattern` can specify this implementation either via -the `matchAndRewrite` method or via separate `match` and `rewrite` methods when -deriving from `RewritePattern::SplitMatchAndRewrite`. When using the combined -`matchAndRewrite` method, no IR mutation should take place before the match is -deemed successful. The combined `matchAndRewrite` is useful when non-trivially -recomputable information is required by the matching and rewriting phase. See -below for examples: +rewrite of the IR. A `RewritePattern` can specify this implementation via the +`matchAndRewrite` method. No IR mutation should take place before the match is +deemed successful. See below for examples: ```c++ class MyPattern : public RewritePattern { @@ -67,21 +63,6 @@ public: MyPattern(PatternBenefit benefit) : RewritePattern(benefit, MatchAnyOpTypeTag()) {} - /// In this section, the `match` and `rewrite` implementation is specified - /// using the separate hooks. - LogicalResult match(Operation *op) const override { - // The `match` method returns `success()` if the pattern is a match, failure - // otherwise. - // ... - } - void rewrite(Operation *op, PatternRewriter &rewriter) const override { - // The `rewrite` method performs mutations on the IR rooted at `op` using - // the provided rewriter. All mutations must go through the provided - // rewriter. - } - - /// In this section, the `match` and `rewrite` implementation is specified - /// using a single hook. LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { // The `matchAndRewrite` method performs both the matching and the mutation. // Note that the match must reach a successful point before IR mutation may @@ -92,12 +73,6 @@ public: #### Restrictions -Within the `match` section of a pattern, the following constraints apply: - -* No mutation of the IR is allowed. - -Within the `rewrite` section of a pattern, the following constraints apply: - * All IR mutations, including creation, *must* be performed by the given `PatternRewriter`. This class provides hooks for performing all of the possible mutations that may take place within a pattern. For example, this @@ -107,8 +82,6 @@ Within the `rewrite` section of a pattern, the following constraints apply: * The root operation is required to either be: updated in-place, replaced, or erased. * `matchAndRewrite` must return "success" if and only if the IR was modified. - `match` must return "success" if and only if the IR is going to be modified - during `rewrite`. ### Application Recursion diff --git a/mlir/docs/Tutorials/QuickstartRewrites.md b/mlir/docs/Tutorials/QuickstartRewrites.md index 604148bd9c600..493f9d5687374 100644 --- a/mlir/docs/Tutorials/QuickstartRewrites.md +++ b/mlir/docs/Tutorials/QuickstartRewrites.md @@ -216,25 +216,6 @@ In case ODS patterns and `matchAndRewrite`-style functions are not sufficient you can also specify rewrites as a general set of `RewritePattern`s: ```c++ -/// Multi-step rewrite using "match" and "rewrite". This allows for separating -/// the concerns of matching and rewriting. -struct ConvertTFLeakyRelu : public RewritePattern { - ConvertTFLeakyRelu(MLIRContext *context) - : RewritePattern("tf.LeakyRelu", 1, context) {} - - LogicalResult match(Operation *op) const override { - return success(); - } - - void rewrite(Operation *op, PatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp( - op, op->getResult(0).getType(), op->getOperand(0), - /*alpha=*/op->getAttrOfType("alpha")); - } -}; - -/// Single-step rewrite with "matchAndRewrite". This allows for performing the -/// rewrite immediately upon a successful match. struct ConvertTFLeakyRelu : public RewritePattern { ConvertTFLeakyRelu(MLIRContext *context) : RewritePattern("tf.LeakyRelu", 1, context) {} diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h index 8f82176f3b75f..e78f174ff8586 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h @@ -40,6 +40,8 @@ LogicalResult oneToOneRewrite( /// during the entire pattern lifetime. class ConvertToLLVMPattern : public ConversionPattern { public: + /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of + /// separate `match` and `rewrite`. using SplitMatchAndRewrite = detail::ConversionSplitMatchAndRewriteImpl; @@ -149,6 +151,9 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { using OpAdaptor = typename SourceOp::Adaptor; using OneToNOpAdaptor = typename SourceOp::template GenericAdaptor>; + + /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of + /// separate `match` and `rewrite`. using SplitMatchAndRewrite = detail::ConversionSplitMatchAndRewriteImpl< ConvertOpToLLVMPattern>; diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 792b50d38817e..d1f00c34f87b4 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -237,6 +237,9 @@ class Pattern { namespace detail { /// Helper class that derives from a RewritePattern class and provides separate /// `match` and `rewrite` entry points instead of a combined `matchAndRewrite`. +/// +/// This class is deprecated. Use `matchAndRewrite` instead of separate `match` +/// and `rewrite`. template class SplitMatchAndRewriteImpl : public PatternT { using PatternT::PatternT; @@ -268,6 +271,9 @@ class SplitMatchAndRewriteImpl : public PatternT { class RewritePattern : public Pattern { public: using OperationT = Operation *; + + /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of + /// separate `match` and `rewrite`. using SplitMatchAndRewrite = detail::SplitMatchAndRewriteImpl; virtual ~RewritePattern() = default; @@ -350,6 +356,9 @@ struct OpOrInterfaceRewritePatternBase : public RewritePattern { template struct OpRewritePattern : public detail::OpOrInterfaceRewritePatternBase { + + /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of + /// separate `match` and `rewrite`. using SplitMatchAndRewrite = detail::SplitMatchAndRewriteImpl>; @@ -368,6 +377,9 @@ struct OpRewritePattern template struct OpInterfaceRewritePattern : public detail::OpOrInterfaceRewritePatternBase { + + /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of + /// separate `match` and `rewrite`. using SplitMatchAndRewrite = detail::SplitMatchAndRewriteImpl>; diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 120709bbe5b67..f54397e942ae0 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -598,6 +598,9 @@ class ConversionPattern : public RewritePattern { using OperationT = Operation *; using OpAdaptor = ArrayRef; using OneToNOpAdaptor = ArrayRef; + + /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of + /// separate `match` and `rewrite`. using SplitMatchAndRewrite = detail::ConversionSplitMatchAndRewriteImpl; @@ -669,6 +672,9 @@ class OpConversionPattern : public ConversionPattern { using OpAdaptor = typename SourceOp::Adaptor; using OneToNOpAdaptor = typename SourceOp::template GenericAdaptor>; + + /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of + /// separate `match` and `rewrite`. using SplitMatchAndRewrite = detail::ConversionSplitMatchAndRewriteImpl>; diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp index 734c4839f9a10..27be54728c1a1 100644 --- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp +++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp @@ -41,48 +41,46 @@ struct ArithToAMDGPUConversionPass final void runOnOperation() override; }; -struct ExtFOnFloat8RewritePattern final - : OpRewritePattern::SplitMatchAndRewrite { - using SplitMatchAndRewrite::SplitMatchAndRewrite; +struct ExtFOnFloat8RewritePattern final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; Chipset chipset; ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset) - : SplitMatchAndRewrite::SplitMatchAndRewrite(ctx), chipset(chipset) {} + : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {} - LogicalResult match(arith::ExtFOp op) const override; - void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override; + LogicalResult matchAndRewrite(arith::ExtFOp op, + PatternRewriter &rewriter) const override; }; -struct TruncFToFloat8RewritePattern final - : OpRewritePattern::SplitMatchAndRewrite { +struct TruncFToFloat8RewritePattern final : OpRewritePattern { bool saturateFP8 = false; TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8, Chipset chipset) - : SplitMatchAndRewrite::SplitMatchAndRewrite(ctx), - saturateFP8(saturateFP8), chipset(chipset) {} + : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8), + chipset(chipset) {} Chipset chipset; - LogicalResult match(arith::TruncFOp op) const override; - void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override; + LogicalResult matchAndRewrite(arith::TruncFOp op, + PatternRewriter &rewriter) const override; }; struct TruncfToFloat16RewritePattern final - : public OpRewritePattern::SplitMatchAndRewrite { + : public OpRewritePattern { - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(arith::TruncFOp op) const override; - void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override; + LogicalResult matchAndRewrite(arith::TruncFOp op, + PatternRewriter &rewriter) const override; }; } // end namespace -static LogicalResult isSupportedF8(Type elementType, Chipset chipset) { +static bool isSupportedF8(Type elementType, Chipset chipset) { if (chipset == kGfx942) - return success(isa(elementType)); + return isa(elementType); if (hasOcpFp8(chipset)) - return success(isa(elementType)); - return failure(); + return isa(elementType); + return false; } static Value castF32To(Type elementType, Value f32, Location loc, @@ -96,35 +94,36 @@ static Value castF32To(Type elementType, Value f32, Location loc, llvm_unreachable("The only 32-bit float type is f32"); } -LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const { +LogicalResult +ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op, + PatternRewriter &rewriter) const { Type inType = op.getIn().getType(); - if (auto inVecType = dyn_cast(inType)) { + auto inVecType = dyn_cast(inType); + if (inVecType) { if (inVecType.isScalable()) return failure(); inType = inVecType.getElementType(); } - return isSupportedF8(inType, chipset); -} + if (!isSupportedF8(inType, chipset)) + return failure(); -void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op, - PatternRewriter &rewriter) const { Location loc = op.getLoc(); Value in = op.getIn(); Type outElemType = getElementTypeOrSelf(op.getOut().getType()); - auto inType = dyn_cast(in.getType()); - if (!inType) { + if (!inVecType) { Value asFloat = rewriter.create( loc, rewriter.getF32Type(), in, 0); Value result = castF32To(outElemType, asFloat, loc, rewriter); - return rewriter.replaceOp(op, result); + rewriter.replaceOp(op, result); + return success(); } - int64_t numElements = inType.getNumElements(); + int64_t numElements = inVecType.getNumElements(); Value zero = rewriter.create( loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0)); VectorType outType = cast(op.getOut().getType()); - if (inType.getShape().empty()) { + if (inVecType.getShape().empty()) { Value zerodSplat = rewriter.createOrFold(loc, outType, zero); Value scalarIn = @@ -133,17 +132,18 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op, rewriter.create(loc, outElemType, scalarIn); Value result = rewriter.create(loc, scalarExt, zerodSplat, ArrayRef{}); - return rewriter.replaceOp(op, result); + rewriter.replaceOp(op, result); + return success(); } VectorType flatTy = VectorType::get(SmallVector{numElements}, outType.getElementType()); Value result = rewriter.createOrFold(loc, flatTy, zero); - if (inType.getRank() > 1) { - inType = VectorType::get(SmallVector{numElements}, - inType.getElementType()); - in = rewriter.create(loc, inType, in); + if (inVecType.getRank() > 1) { + inVecType = VectorType::get(SmallVector{numElements}, + inVecType.getElementType()); + in = rewriter.create(loc, inVecType, in); } for (int64_t i = 0; i < numElements; i += 4) { @@ -158,11 +158,12 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op, } } - if (inType.getRank() != outType.getRank()) { + if (inVecType.getRank() != outType.getRank()) { result = rewriter.create(loc, outType, result); } rewriter.replaceOp(op, result); + return success(); } static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) { @@ -222,12 +223,15 @@ static Value clampInput(PatternRewriter &rewriter, Location loc, return res; } -LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const { +LogicalResult +TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op, + PatternRewriter &rewriter) const { // Only supporting default rounding mode as of now. if (op.getRoundingmodeAttr()) return failure(); Type outType = op.getOut().getType(); - if (auto outVecType = dyn_cast(outType)) { + auto outVecType = dyn_cast(outType); + if (outVecType) { if (outVecType.isScalable()) return failure(); outType = outVecType.getElementType(); @@ -237,11 +241,9 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const { // Conversion between 8-bit floats is not supported with truncation enabled. return failure(); - return isSupportedF8(outType, chipset); -} + if (!isSupportedF8(outType, chipset)) + return failure(); -void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op, - PatternRewriter &rewriter) const { Location loc = op.getLoc(); Value in = op.getIn(); Type outElemType = getElementTypeOrSelf(op.getOut().getType()); @@ -255,13 +257,14 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op, loc, truncResType, asFloat, /*sourceB=*/nullptr, 0, /*existing=*/nullptr); Value result = rewriter.create(loc, asF8s, 0); - return rewriter.replaceOp(op, result); + rewriter.replaceOp(op, result); + return success(); } - VectorType outType = cast(op.getOut().getType()); - int64_t numElements = outType.getNumElements(); + + int64_t numElements = outVecType.getNumElements(); Value zero = rewriter.create( loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0)); - if (outType.getShape().empty()) { + if (outVecType.getShape().empty()) { Value scalarIn = rewriter.create(loc, in, ArrayRef{}); // Recurse to send the 0-D vector case to the 1-D vector case @@ -269,11 +272,12 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op, rewriter.create(loc, outElemType, scalarIn); Value result = rewriter.create(loc, scalarTrunc, zero, ArrayRef{}); - return rewriter.replaceOp(op, result); + rewriter.replaceOp(op, result); + return success(); } VectorType flatTy = VectorType::get(SmallVector{numElements}, - outType.getElementType()); + outVecType.getElementType()); Value result = rewriter.createOrFold(loc, flatTy, zero); if (inVectorTy.getRank() > 1) { @@ -303,26 +307,27 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op, result, i, 1); } - if (inVectorTy.getRank() != outType.getRank()) { - result = rewriter.create(loc, outType, result); + if (inVectorTy.getRank() != outVecType.getRank()) { + result = rewriter.create(loc, outVecType, result); } rewriter.replaceOp(op, result); + return success(); } -LogicalResult TruncfToFloat16RewritePattern::match(arith::TruncFOp op) const { +LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite( + arith::TruncFOp op, PatternRewriter &rewriter) const { Type outType = op.getOut().getType(); Type inputType = getElementTypeOrSelf(op.getIn()); - if (auto outVecType = dyn_cast(outType)) { + auto outVecType = dyn_cast(outType); + if (outVecType) { if (outVecType.isScalable()) return failure(); outType = outVecType.getElementType(); } - return success(outType.isF16() && inputType.isF32()); -} + if (!(outType.isF16() && inputType.isF32())) + return failure(); -void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op, - PatternRewriter &rewriter) const { Location loc = op.getLoc(); Value in = op.getIn(); Type outElemType = getElementTypeOrSelf(op.getOut().getType()); @@ -335,13 +340,13 @@ void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op, Value asF16s = rewriter.create(loc, truncResType, in, sourceB); Value result = rewriter.create(loc, asF16s, 0); - return rewriter.replaceOp(op, result); + rewriter.replaceOp(op, result); + return success(); } - VectorType outType = cast(op.getOut().getType()); - int64_t numElements = outType.getNumElements(); + int64_t numElements = outVecType.getNumElements(); Value zero = rewriter.createOrFold( loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0)); - Value result = rewriter.createOrFold(loc, outType, zero); + Value result = rewriter.createOrFold(loc, outVecType, zero); if (inVectorTy.getRank() > 1) { inVectorTy = VectorType::get(SmallVector{numElements}, @@ -371,11 +376,12 @@ void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op, result, i, 1); } - if (inVectorTy.getRank() != outType.getRank()) { - result = rewriter.create(loc, outType, result); + if (inVectorTy.getRank() != outVecType.getRank()) { + result = rewriter.create(loc, outVecType, result); } rewriter.replaceOp(op, result); + return success(); } void mlir::arith::populateArithToAMDGPUConversionPatterns( diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 80310ce56a51b..fe0ee11d84adb 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -657,11 +657,12 @@ struct RankOpLowering : public ConvertOpToLLVMPattern { } }; -struct MemRefCastOpLowering - : public ConvertOpToLLVMPattern::SplitMatchAndRewrite { - using SplitMatchAndRewrite::SplitMatchAndRewrite; +struct MemRefCastOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - LogicalResult match(memref::CastOp memRefCastOp) const override { + LogicalResult + matchAndRewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { Type srcType = memRefCastOp.getOperand().getType(); Type dstType = memRefCastOp.getType(); @@ -671,30 +672,22 @@ struct MemRefCastOpLowering // perform a sanity check that the underlying structs are the same. Once op // semantics are relaxed we can revisit. if (isa(srcType) && isa(dstType)) - return success(typeConverter->convertType(srcType) == - typeConverter->convertType(dstType)); - - // At least one of the operands is unranked type - assert(isa(srcType) || - isa(dstType)); + if (typeConverter->convertType(srcType) != + typeConverter->convertType(dstType)) + return failure(); // Unranked to unranked cast is disallowed - return !(isa(srcType) && - isa(dstType)) - ? success() - : failure(); - } + if (isa(srcType) && isa(dstType)) + return failure(); - void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto srcType = memRefCastOp.getOperand().getType(); - auto dstType = memRefCastOp.getType(); auto targetStructType = typeConverter->convertType(memRefCastOp.getType()); auto loc = memRefCastOp.getLoc(); // For ranked/ranked case, just keep the original descriptor. - if (isa(srcType) && isa(dstType)) - return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()}); + if (isa(srcType) && isa(dstType)) { + rewriter.replaceOp(memRefCastOp, {adaptor.getSource()}); + return success(); + } if (isa(srcType) && isa(dstType)) { // Casting ranked to unranked memref type @@ -733,6 +726,8 @@ struct MemRefCastOpLowering } else { llvm_unreachable("Unsupported unranked memref to unranked memref cast"); } + + return success(); } }; diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp index f105534626082..62022bfb7df1e 100644 --- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp @@ -40,36 +40,33 @@ struct EmulateUnsupportedFloatsPass void runOnOperation() override; }; -struct EmulateFloatPattern final : ConversionPattern::SplitMatchAndRewrite { +struct EmulateFloatPattern final : ConversionPattern { EmulateFloatPattern(const TypeConverter &converter, MLIRContext *ctx) - : ConversionPattern::SplitMatchAndRewrite( + : ConversionPattern::ConversionPattern( converter, Pattern::MatchAnyOpTypeTag(), 1, ctx) {} - LogicalResult match(Operation *op) const override; - void rewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override; + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override; }; } // end namespace -LogicalResult EmulateFloatPattern::match(Operation *op) const { +LogicalResult EmulateFloatPattern::matchAndRewrite( + Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { if (getTypeConverter()->isLegal(op)) return failure(); // The rewrite doesn't handle cloning regions. if (op->getNumRegions() != 0) return failure(); - return success(); -} -void EmulateFloatPattern::rewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const { Location loc = op->getLoc(); const TypeConverter *converter = getTypeConverter(); SmallVector resultTypes; if (failed(converter->convertTypes(op->getResultTypes(), resultTypes))) { // Note to anyone looking for this error message: this is a "can't happen". // If you're seeing it, there's a bug. - op->emitOpError("type conversion failed in float emulation"); - return; + return op->emitOpError("type conversion failed in float emulation"); } Operation *expandedOp = rewriter.create(loc, op->getName().getIdentifier(), operands, resultTypes, @@ -84,6 +81,7 @@ void EmulateFloatPattern::rewrite(Operation *op, ArrayRef operands, } } rewriter.replaceOp(op, newResults); + return success(); } void mlir::arith::populateEmulateUnsupportedFloatsConversions( diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp index 6da28ddeede3c..f866c91ef6e39 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp @@ -115,14 +115,14 @@ class DataFlowListener : public RewriterBase::Listener { /// and replace their uses with that constant. Return success() if all results /// where thus replaced and the operation is erased. Also replace any block /// arguments with their constant values. -struct MaterializeKnownConstantValues - : public RewritePattern::SplitMatchAndRewrite { +struct MaterializeKnownConstantValues : public RewritePattern { MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s) - : RewritePattern::SplitMatchAndRewrite(Pattern::MatchAnyOpTypeTag(), - /*benefit=*/1, context), + : RewritePattern::RewritePattern(Pattern::MatchAnyOpTypeTag(), + /*benefit=*/1, context), solver(s) {} - LogicalResult match(Operation *op) const override { + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { if (matchPattern(op, m_Constant())) return failure(); @@ -131,7 +131,8 @@ struct MaterializeKnownConstantValues }; bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing); if (op->getNumRegions() == 0) - return success(hasConstantResults); + if (!hasConstantResults) + return failure(); bool hasConstantRegionArgs = false; for (Region ®ion : op->getRegions()) { for (Block &block : region.getBlocks()) { @@ -139,10 +140,9 @@ struct MaterializeKnownConstantValues llvm::any_of(block.getArguments(), needsReplacing); } } - return success(hasConstantResults || hasConstantRegionArgs); - } + if (!hasConstantResults && !hasConstantRegionArgs) + return failure(); - void rewrite(Operation *op, PatternRewriter &rewriter) const override { bool replacedAll = (op->getNumResults() != 0); for (Value v : op->getResults()) replacedAll &= @@ -150,7 +150,7 @@ struct MaterializeKnownConstantValues v.use_empty()); if (replacedAll && isOpTriviallyDead(op)) { rewriter.eraseOp(op); - return; + return success(); } PatternRewriter::InsertionGuard guard(rewriter); @@ -162,6 +162,8 @@ struct MaterializeKnownConstantValues } } } + + return success(); } private: diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index 2413a4126f3f7..8c9e2d889808a 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -772,17 +772,16 @@ class FlattenContiguousRowMajorTransferWritePattern /// `vector.extract` and `vector.extract_element`. template class RewriteScalarExtractOfTransferReadBase - : public OpRewritePattern::SplitMatchAndRewrite { - using Base = typename OpRewritePattern::SplitMatchAndRewrite; + : public OpRewritePattern { + using Base = OpRewritePattern; public: RewriteScalarExtractOfTransferReadBase(MLIRContext *context, PatternBenefit benefit, bool allowMultipleUses) - : Base::SplitMatchAndRewrite(context, benefit), - allowMultipleUses(allowMultipleUses) {} + : Base(context, benefit), allowMultipleUses(allowMultipleUses) {} - LogicalResult match(VectorExtractOp extractOp) const override { + LogicalResult match(VectorExtractOp extractOp) const { auto xferOp = extractOp.getVector().template getDefiningOp(); if (!xferOp) @@ -828,8 +827,11 @@ class RewriteScalarExtractElementOfTransferRead using RewriteScalarExtractOfTransferReadBase:: RewriteScalarExtractOfTransferReadBase; - void rewrite(vector::ExtractElementOp extractOp, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(vector::ExtractElementOp extractOp, + PatternRewriter &rewriter) const override { + if (failed(match(extractOp))) + return failure(); + // Construct scalar load. auto loc = extractOp.getLoc(); auto xferOp = extractOp.getVector().getDefiningOp(); @@ -856,6 +858,8 @@ class RewriteScalarExtractElementOfTransferRead rewriter.replaceOpWithNewOp( extractOp, xferOp.getSource(), newIndices); } + + return success(); } }; @@ -872,8 +876,11 @@ class RewriteScalarExtractOfTransferRead using RewriteScalarExtractOfTransferReadBase:: RewriteScalarExtractOfTransferReadBase; - void rewrite(vector::ExtractOp extractOp, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(vector::ExtractOp extractOp, + PatternRewriter &rewriter) const override { + if (failed(match(extractOp))) + return failure(); + // Construct scalar load. auto xferOp = extractOp.getVector().getDefiningOp(); SmallVector newIndices(xferOp.getIndices().begin(), @@ -899,6 +906,8 @@ class RewriteScalarExtractOfTransferRead rewriter.replaceOpWithNewOp( extractOp, xferOp.getSource(), newIndices); } + + return success(); } };