From a6151f4e237075919c12a120c391a8b6c6a5000c Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Thu, 6 Mar 2025 08:48:51 +0100 Subject: [PATCH] [mlir][IR] Move `match` and `rewrite` functions into separate class (#129861) The vast majority of rewrite / conversion patterns uses a combined `matchAndRewrite` instead of separate `match` and `rewrite` functions. This PR optimizes the code base for the most common case where users implement a combined `matchAndRewrite`. There are no longer any `match` and `rewrite` functions in `RewritePattern`, `ConversionPattern` and their derived classes. Instead, there is a `SplitMatchAndRewriteImpl` class that implements `matchAndRewrite` in terms of `match` and `rewrite`. Details: * The `RewritePattern` and `ConversionPattern` classes are simpler (fewer functions). Especially the `ConversionPattern` class, which now has 5 fewer functions. (There were various `rewrite` overloads to account for 1:1 / 1:N patterns.) * There is a new class `SplitMatchAndRewriteImpl` that derives from `RewritePattern` / `OpRewritePatern` / ..., along with a type alias `RewritePattern::SplitMatchAndRewrite` for convenience. * Fewer `llvm_unreachable` are needed throughout the code base. Instead, we can use pure virtual functions. (In cases where users previously had to implement `rewrite` or `matchAndRewrite`, etc.) * This PR may also improve the number of [`-Woverload-virtual` warnings](https://discourse.llvm.org/t/matchandrewrite-hiding-virtual-functions/84933) that are produced by GCC. (To be confirmed...) Note for LLVM integration: Patterns with separate `match` / `rewrite` implementations, must derive from `X::SplitMatchAndRewrite` instead of `X`. --------- Co-authored-by: River Riddle --- .../flang/Optimizer/CodeGen/FIROpPatterns.h | 36 +---- mlir/docs/PatternRewriter.md | 25 +-- .../mlir/Conversion/LLVMCommon/Pattern.h | 40 +---- mlir/include/mlir/IR/PatternMatch.h | 94 +++++------ .../mlir/Transforms/DialectConversion.h | 147 +++++++++--------- .../ArithToAMDGPU/ArithToAMDGPU.cpp | 18 ++- .../Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 10 +- .../Transforms/EmulateUnsupportedFloats.cpp | 5 +- .../Transforms/IntRangeOptimizations.cpp | 6 +- .../Transforms/VectorTransferOpTransforms.cpp | 6 +- mlir/lib/IR/PatternMatch.cpp | 9 -- mlir/unittests/IR/PatternMatchTest.cpp | 5 + 12 files changed, 175 insertions(+), 226 deletions(-) diff --git a/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h b/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h index 35749dae5d7e9..53d16323beddf 100644 --- a/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h +++ b/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h @@ -187,7 +187,6 @@ class ConvertFIRToLLVMPattern : public mlir::ConvertToLLVMPattern { const fir::FIRToLLVMPassOptions &options; - using ConvertToLLVMPattern::match; using ConvertToLLVMPattern::matchAndRewrite; }; @@ -206,20 +205,6 @@ class FIROpConversion : public ConvertFIRToLLVMPattern { options, benefit) {} /// Wrappers around the RewritePattern methods that pass the derived op type. - void rewrite(mlir::Operation *op, mlir::ArrayRef operands, - mlir::ConversionPatternRewriter &rewriter) const final { - rewrite(mlir::cast(op), - OpAdaptor(operands, mlir::cast(op)), rewriter); - } - void rewrite(mlir::Operation *op, mlir::ArrayRef operands, - mlir::ConversionPatternRewriter &rewriter) const final { - auto sourceOp = llvm::cast(op); - rewrite(llvm::cast(op), OneToNOpAdaptor(operands, sourceOp), - rewriter); - } - llvm::LogicalResult match(mlir::Operation *op) const final { - return match(mlir::cast(op)); - } llvm::LogicalResult matchAndRewrite(mlir::Operation *op, mlir::ArrayRef operands, mlir::ConversionPatternRewriter &rewriter) const final { @@ -235,28 +220,12 @@ class FIROpConversion : public ConvertFIRToLLVMPattern { return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter); } - /// Rewrite and Match methods that operate on the SourceOp type. These must be + /// Methods that operate on the SourceOp type. These must be /// overridden by the derived pattern class. - virtual llvm::LogicalResult match(SourceOp op) const { - llvm_unreachable("must override match or matchAndRewrite"); - } - virtual void rewrite(SourceOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const { - llvm_unreachable("must override rewrite or matchAndRewrite"); - } - virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const { - llvm::SmallVector oneToOneOperands = - getOneToOneAdaptorOperands(adaptor.getOperands()); - rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter); - } virtual llvm::LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { - if (mlir::failed(match(op))) - return mlir::failure(); - rewrite(op, adaptor, rewriter); - return mlir::success(); + llvm_unreachable("matchAndRewrite is not implemented"); } virtual llvm::LogicalResult matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor, @@ -268,7 +237,6 @@ class FIROpConversion : public ConvertFIRToLLVMPattern { private: using ConvertFIRToLLVMPattern::matchAndRewrite; - using ConvertToLLVMPattern::match; }; /// FIR conversion pattern template diff --git a/mlir/docs/PatternRewriter.md b/mlir/docs/PatternRewriter.md index 9df4647299010..af0f56466e0cb 100644 --- a/mlir/docs/PatternRewriter.md +++ b/mlir/docs/PatternRewriter.md @@ -38,22 +38,23 @@ possible cost and use the predicate to guard the match. ### Root Operation Name (Optional) The name of the root operation that this pattern matches against. If specified, -only operations with the given root name will be provided to the `match` and -`rewrite` implementation. If not specified, any operation type may be provided. -The root operation name should be provided whenever possible, because it -simplifies the analysis of patterns when applying a cost model. To match any +only operations with the given root name will be provided to the +`matchAndRewrite` implementation. If not specified, any operation type may be +provided. The root operation name should be provided whenever possible, because +it simplifies the analysis of patterns when applying a cost model. To match any operation type, a special tag must be provided to make the intent explicit: `MatchAnyOpTypeTag`. -### `match` and `rewrite` implementation +### `matchAndRewrite` implementation This is the chunk of code that matches a given root `Operation` and performs a rewrite of the IR. A `RewritePattern` can specify this implementation either via -separate `match` and `rewrite` methods, or via a combined `matchAndRewrite` -method. When using the combined `matchAndRewrite` method, no IR mutation should -take place before the match is deemed successful. The combined `matchAndRewrite` -is useful when non-trivially recomputable information is required by the -matching and rewriting phase. See below for examples: +the `matchAndRewrite` method or via separate `match` and `rewrite` methods when +deriving from `RewritePattern::SplitMatchAndRewrite`. When using the combined +`matchAndRewrite` method, no IR mutation should take place before the match is +deemed successful. The combined `matchAndRewrite` is useful when non-trivially +recomputable information is required by the matching and rewriting phase. See +below for examples: ```c++ class MyPattern : public RewritePattern { @@ -105,6 +106,10 @@ Within the `rewrite` section of a pattern, the following constraints apply: `eraseOp`) should be used instead. * The root operation is required to either be: updated in-place, replaced, or erased. +* `matchAndRewrite` must return "success" if and only if the IR was modified. + `match` must return "success" if and only if the IR is going to be modified + during `rewrite`. + ### Application Recursion diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h index 86ea87b55af1c..8f82176f3b75f 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h @@ -40,6 +40,9 @@ LogicalResult oneToOneRewrite( /// during the entire pattern lifetime. class ConvertToLLVMPattern : public ConversionPattern { public: + using SplitMatchAndRewrite = + detail::ConversionSplitMatchAndRewriteImpl; + ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, const LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1); @@ -142,9 +145,12 @@ class ConvertToLLVMPattern : public ConversionPattern { template class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { public: + using OperationT = SourceOp; using OpAdaptor = typename SourceOp::Adaptor; using OneToNOpAdaptor = typename SourceOp::template GenericAdaptor>; + using SplitMatchAndRewrite = detail::ConversionSplitMatchAndRewriteImpl< + ConvertOpToLLVMPattern>; explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1) @@ -153,19 +159,6 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { benefit) {} /// Wrappers around the RewritePattern methods that pass the derived op type. - void rewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - auto sourceOp = cast(op); - rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter); - } - void rewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - auto sourceOp = cast(op); - rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter); - } - LogicalResult match(Operation *op) const final { - return match(cast(op)); - } LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { @@ -180,28 +173,12 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { rewriter); } - /// Rewrite and Match methods that operate on the SourceOp type. These must be + /// Methods that operate on the SourceOp type. One of these must be /// overridden by the derived pattern class. - virtual LogicalResult match(SourceOp op) const { - llvm_unreachable("must override match or matchAndRewrite"); - } - virtual void rewrite(SourceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - llvm_unreachable("must override rewrite or matchAndRewrite"); - } - virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - SmallVector oneToOneOperands = - getOneToOneAdaptorOperands(adaptor.getOperands()); - rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter); - } virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - if (failed(match(op))) - return failure(); - rewrite(op, adaptor, rewriter); - return success(); + llvm_unreachable("matchAndRewrite is not implemented"); } virtual LogicalResult matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor, @@ -212,7 +189,6 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { } private: - using ConvertToLLVMPattern::match; using ConvertToLLVMPattern::matchAndRewrite; }; diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 2ab0405043a54..792b50d38817e 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -234,41 +234,52 @@ class Pattern { // RewritePattern //===----------------------------------------------------------------------===// -/// RewritePattern is the common base class for all DAG to DAG replacements. -/// There are two possible usages of this class: -/// * Multi-step RewritePattern with "match" and "rewrite" -/// - By overloading the "match" and "rewrite" functions, the user can -/// separate the concerns of matching and rewriting. -/// * Single-step RewritePattern with "matchAndRewrite" -/// - By overloading the "matchAndRewrite" function, the user can perform -/// the rewrite in the same call as the match. -/// -class RewritePattern : public Pattern { -public: - virtual ~RewritePattern() = default; +namespace detail { +/// Helper class that derives from a RewritePattern class and provides separate +/// `match` and `rewrite` entry points instead of a combined `matchAndRewrite`. +template +class SplitMatchAndRewriteImpl : public PatternT { + using PatternT::PatternT; + + /// Attempt to match against IR rooted at the specified operation, which is + /// the same operation kind as getRootKind(). + /// + /// Note: This function must not modify the IR. + virtual LogicalResult match(typename PatternT::OperationT op) const = 0; /// Rewrite the IR rooted at the specified operation with the result of /// this pattern, generating any new operations with the specified - /// builder. If an unexpected error is encountered (an internal - /// compiler error), it is emitted through the normal MLIR diagnostic - /// hooks and the IR is left in a valid state. - virtual void rewrite(Operation *op, PatternRewriter &rewriter) const; - - /// Attempt to match against code rooted at the specified operation, - /// which is the same operation code as getRootKind(). - virtual LogicalResult match(Operation *op) const; + /// rewriter. + virtual void rewrite(typename PatternT::OperationT op, + PatternRewriter &rewriter) const = 0; - /// Attempt to match against code rooted at the specified operation, - /// which is the same operation code as getRootKind(). If successful, this - /// function will automatically perform the rewrite. - virtual LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const { + LogicalResult matchAndRewrite(typename PatternT::OperationT op, + PatternRewriter &rewriter) const final { if (succeeded(match(op))) { rewrite(op, rewriter); return success(); } return failure(); } +}; +} // namespace detail + +/// RewritePattern is the common base class for all DAG to DAG replacements. +class RewritePattern : public Pattern { +public: + using OperationT = Operation *; + using SplitMatchAndRewrite = detail::SplitMatchAndRewriteImpl; + + virtual ~RewritePattern() = default; + + /// Attempt to match against code rooted at the specified operation, + /// which is the same operation code as getRootKind(). If successful, perform + /// the rewrite. + /// + /// Note: Implementations must modify the IR if and only if the function + /// returns "success". + virtual LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const = 0; /// This method provides a convenient interface for creating and initializing /// derived rewrite patterns of the given type `T`. @@ -317,36 +328,19 @@ namespace detail { /// class or Interface. template struct OpOrInterfaceRewritePatternBase : public RewritePattern { + using OperationT = SourceOp; using RewritePattern::RewritePattern; - /// Wrappers around the RewritePattern methods that pass the derived op type. - void rewrite(Operation *op, PatternRewriter &rewriter) const final { - rewrite(cast(op), rewriter); - } - LogicalResult match(Operation *op) const final { - return match(cast(op)); - } + /// Wrapper around the RewritePattern method that passes the derived op type. LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final { return matchAndRewrite(cast(op), rewriter); } - /// Rewrite and Match methods that operate on the SourceOp type. These must be - /// overridden by the derived pattern class. - virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const { - llvm_unreachable("must override rewrite or matchAndRewrite"); - } - virtual LogicalResult match(SourceOp op) const { - llvm_unreachable("must override match or matchAndRewrite"); - } + /// Method that operates on the SourceOp type. Must be overridden by the + /// derived pattern class. virtual LogicalResult matchAndRewrite(SourceOp op, - PatternRewriter &rewriter) const { - if (succeeded(match(op))) { - rewrite(op, rewriter); - return success(); - } - return failure(); - } + PatternRewriter &rewriter) const = 0; }; } // namespace detail @@ -356,6 +350,9 @@ struct OpOrInterfaceRewritePatternBase : public RewritePattern { template struct OpRewritePattern : public detail::OpOrInterfaceRewritePatternBase { + using SplitMatchAndRewrite = + detail::SplitMatchAndRewriteImpl>; + /// Patterns must specify the root operation name they match against, and can /// also specify the benefit of the pattern matching and a list of generated /// ops. @@ -371,6 +368,9 @@ struct OpRewritePattern template struct OpInterfaceRewritePattern : public detail::OpOrInterfaceRewritePatternBase { + using SplitMatchAndRewrite = + detail::SplitMatchAndRewriteImpl>; + OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit = 1) : detail::OpOrInterfaceRewritePatternBase( Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(), diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 9a6975dcf8dfa..120709bbe5b67 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -528,24 +528,78 @@ class TypeConverter { // Conversion Patterns //===----------------------------------------------------------------------===// +namespace detail { +/// Helper class that derives from a ConversionRewritePattern class and +/// provides separate `match` and `rewrite` entry points instead of a combined +/// `matchAndRewrite`. +template +class ConversionSplitMatchAndRewriteImpl : public PatternT { + using PatternT::PatternT; + + /// Attempt to match against IR rooted at the specified operation, which is + /// the same operation kind as getRootKind(). + /// + /// Note: This function must not modify the IR. + virtual LogicalResult match(typename PatternT::OperationT op) const = 0; + + /// Rewrite the IR rooted at the specified operation with the result of + /// this pattern, generating any new operations with the specified + /// rewriter. + virtual void rewrite(typename PatternT::OperationT op, + typename PatternT::OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // One of the two `rewrite` functions must be implemented. + llvm_unreachable("rewrite is not implemented"); + } + + virtual void rewrite(typename PatternT::OperationT op, + typename PatternT::OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + if constexpr (std::is_same>::value) { + rewrite(op, PatternT::getOneToOneAdaptorOperands(adaptor), rewriter); + } else { + SmallVector oneToOneOperands = + PatternT::getOneToOneAdaptorOperands(adaptor.getOperands()); + rewrite(op, typename PatternT::OpAdaptor(oneToOneOperands, adaptor), + rewriter); + } + } + + LogicalResult + matchAndRewrite(typename PatternT::OperationT op, + typename PatternT::OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + if (succeeded(match(op))) { + rewrite(op, adaptor, rewriter); + return success(); + } + return failure(); + } + + LogicalResult + matchAndRewrite(typename PatternT::OperationT op, + typename PatternT::OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + // Users would normally override this function in conversion patterns to + // implement a 1:1 pattern. Patterns that are derived from this class have + // separate `match` and `rewrite` functions, so this `matchAndRewrite` + // overload is obsolete. + llvm_unreachable("this function is unreachable"); + } +}; +} // namespace detail + /// Base class for the conversion patterns. This pattern class enables type /// conversions, and other uses specific to the conversion framework. As such, /// patterns of this type can only be used with the 'apply*' methods below. class ConversionPattern : public RewritePattern { public: - /// Hook for derived classes to implement rewriting. `op` is the (first) - /// operation matched by the pattern, `operands` is a list of the rewritten - /// operand values that are passed to `op`, `rewriter` can be used to emit the - /// new operations. This function should not fail. If some specific cases of - /// the operation are not supported, these cases should not be matched. - virtual void rewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const { - llvm_unreachable("unimplemented rewrite"); - } - virtual void rewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const { - rewrite(op, getOneToOneAdaptorOperands(operands), rewriter); - } + using OperationT = Operation *; + using OpAdaptor = ArrayRef; + using OneToNOpAdaptor = ArrayRef; + using SplitMatchAndRewrite = + detail::ConversionSplitMatchAndRewriteImpl; /// Hook for derived classes to implement combined matching and rewriting. /// This overload supports only 1:1 replacements. The 1:N overload is called @@ -554,10 +608,7 @@ class ConversionPattern : public RewritePattern { virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { - if (failed(match(op))) - return failure(); - rewrite(op, operands, rewriter); - return success(); + llvm_unreachable("matchAndRewrite is not implemented"); } /// Hook for derived classes to implement combined matching and rewriting. @@ -606,9 +657,6 @@ class ConversionPattern : public RewritePattern { protected: /// An optional type converter for use by this pattern. const TypeConverter *typeConverter = nullptr; - -private: - using RewritePattern::rewrite; }; /// OpConversionPattern is a wrapper around ConversionPattern that allows for @@ -617,9 +665,12 @@ class ConversionPattern : public RewritePattern { template class OpConversionPattern : public ConversionPattern { public: + using OperationT = SourceOp; using OpAdaptor = typename SourceOp::Adaptor; using OneToNOpAdaptor = typename SourceOp::template GenericAdaptor>; + using SplitMatchAndRewrite = + detail::ConversionSplitMatchAndRewriteImpl>; OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1) : ConversionPattern(SourceOp::getOperationName(), benefit, context) {} @@ -630,19 +681,6 @@ class OpConversionPattern : public ConversionPattern { /// Wrappers around the ConversionPattern methods that pass the derived op /// type. - LogicalResult match(Operation *op) const final { - return match(cast(op)); - } - void rewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - auto sourceOp = cast(op); - rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter); - } - void rewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - auto sourceOp = cast(op); - rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter); - } LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { @@ -657,28 +695,12 @@ class OpConversionPattern : public ConversionPattern { rewriter); } - /// Rewrite and Match methods that operate on the SourceOp type. These must be + /// Methods that operate on the SourceOp type. One of these must be /// overridden by the derived pattern class. - virtual LogicalResult match(SourceOp op) const { - llvm_unreachable("must override match or matchAndRewrite"); - } - virtual void rewrite(SourceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - llvm_unreachable("must override matchAndRewrite or a rewrite method"); - } - virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - SmallVector oneToOneOperands = - getOneToOneAdaptorOperands(adaptor.getOperands()); - rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter); - } virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - if (failed(match(op))) - return failure(); - rewrite(op, adaptor, rewriter); - return success(); + llvm_unreachable("matchAndRewrite is not implemented"); } virtual LogicalResult matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor, @@ -708,14 +730,6 @@ class OpInterfaceConversionPattern : public ConversionPattern { /// Wrappers around the ConversionPattern methods that pass the derived op /// type. - void rewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - rewrite(cast(op), operands, rewriter); - } - void rewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - rewrite(cast(op), operands, rewriter); - } LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { @@ -727,23 +741,12 @@ class OpInterfaceConversionPattern : public ConversionPattern { return matchAndRewrite(cast(op), operands, rewriter); } - /// Rewrite and Match methods that operate on the SourceOp type. These must be + /// Methods that operate on the SourceOp type. One of these must be /// overridden by the derived pattern class. - virtual void rewrite(SourceOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const { - llvm_unreachable("must override matchAndRewrite or a rewrite method"); - } - virtual void rewrite(SourceOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const { - rewrite(op, getOneToOneAdaptorOperands(operands), rewriter); - } virtual LogicalResult matchAndRewrite(SourceOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { - if (failed(match(op))) - return failure(); - rewrite(op, operands, rewriter); - return success(); + llvm_unreachable("matchAndRewrite is not implemented"); } virtual LogicalResult matchAndRewrite(SourceOp op, ArrayRef operands, diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp index cba71740f9380..734c4839f9a10 100644 --- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp +++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp @@ -41,23 +41,25 @@ struct ArithToAMDGPUConversionPass final void runOnOperation() override; }; -struct ExtFOnFloat8RewritePattern final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct ExtFOnFloat8RewritePattern final + : OpRewritePattern::SplitMatchAndRewrite { + using SplitMatchAndRewrite::SplitMatchAndRewrite; Chipset chipset; ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset) - : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {} + : SplitMatchAndRewrite::SplitMatchAndRewrite(ctx), chipset(chipset) {} LogicalResult match(arith::ExtFOp op) const override; void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override; }; -struct TruncFToFloat8RewritePattern final : OpRewritePattern { +struct TruncFToFloat8RewritePattern final + : OpRewritePattern::SplitMatchAndRewrite { bool saturateFP8 = false; TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8, Chipset chipset) - : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8), - chipset(chipset) {} + : SplitMatchAndRewrite::SplitMatchAndRewrite(ctx), + saturateFP8(saturateFP8), chipset(chipset) {} Chipset chipset; LogicalResult match(arith::TruncFOp op) const override; @@ -65,9 +67,9 @@ struct TruncFToFloat8RewritePattern final : OpRewritePattern { }; struct TruncfToFloat16RewritePattern final - : public OpRewritePattern { + : public OpRewritePattern::SplitMatchAndRewrite { - using OpRewritePattern::OpRewritePattern; + using SplitMatchAndRewrite::SplitMatchAndRewrite; LogicalResult match(arith::TruncFOp op) const override; void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override; diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 3646416def810..80310ce56a51b 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -363,11 +363,6 @@ struct LoadStoreOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; using ConvertOpToLLVMPattern::isConvertibleAndHasIdentityMaps; using Base = LoadStoreOpLowering; - - LogicalResult match(Derived op) const override { - MemRefType type = op.getMemRefType(); - return isConvertibleAndHasIdentityMaps(type) ? success() : failure(); - } }; /// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be @@ -662,8 +657,9 @@ struct RankOpLowering : public ConvertOpToLLVMPattern { } }; -struct MemRefCastOpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; +struct MemRefCastOpLowering + : public ConvertOpToLLVMPattern::SplitMatchAndRewrite { + using SplitMatchAndRewrite::SplitMatchAndRewrite; LogicalResult match(memref::CastOp memRefCastOp) const override { Type srcType = memRefCastOp.getOperand().getType(); diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp index 836ebb65e7d17..f105534626082 100644 --- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp @@ -40,9 +40,10 @@ struct EmulateUnsupportedFloatsPass void runOnOperation() override; }; -struct EmulateFloatPattern final : ConversionPattern { +struct EmulateFloatPattern final : ConversionPattern::SplitMatchAndRewrite { EmulateFloatPattern(const TypeConverter &converter, MLIRContext *ctx) - : ConversionPattern(converter, Pattern::MatchAnyOpTypeTag(), 1, ctx) {} + : ConversionPattern::SplitMatchAndRewrite( + converter, Pattern::MatchAnyOpTypeTag(), 1, ctx) {} LogicalResult match(Operation *op) const override; void rewrite(Operation *op, ArrayRef operands, diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp index 5982f5f55549e..6da28ddeede3c 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp @@ -115,9 +115,11 @@ class DataFlowListener : public RewriterBase::Listener { /// and replace their uses with that constant. Return success() if all results /// where thus replaced and the operation is erased. Also replace any block /// arguments with their constant values. -struct MaterializeKnownConstantValues : public RewritePattern { +struct MaterializeKnownConstantValues + : public RewritePattern::SplitMatchAndRewrite { MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s) - : RewritePattern(Pattern::MatchAnyOpTypeTag(), /*benefit=*/1, context), + : RewritePattern::SplitMatchAndRewrite(Pattern::MatchAnyOpTypeTag(), + /*benefit=*/1, context), solver(s) {} LogicalResult match(Operation *op) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index f13e54901f690..2413a4126f3f7 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -772,14 +772,14 @@ class FlattenContiguousRowMajorTransferWritePattern /// `vector.extract` and `vector.extract_element`. template class RewriteScalarExtractOfTransferReadBase - : public OpRewritePattern { - using Base = OpRewritePattern; + : public OpRewritePattern::SplitMatchAndRewrite { + using Base = typename OpRewritePattern::SplitMatchAndRewrite; public: RewriteScalarExtractOfTransferReadBase(MLIRContext *context, PatternBenefit benefit, bool allowMultipleUses) - : Base::OpRewritePattern(context, benefit), + : Base::SplitMatchAndRewrite(context, benefit), allowMultipleUses(allowMultipleUses) {} LogicalResult match(VectorExtractOp extractOp) const override { diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index 286f47ce69136..3e3c06bebf142 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -87,15 +87,6 @@ Pattern::Pattern(const void *rootValue, RootKind rootKind, // RewritePattern //===----------------------------------------------------------------------===// -void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const { - llvm_unreachable("need to implement either matchAndRewrite or one of the " - "rewrite functions!"); -} - -LogicalResult RewritePattern::match(Operation *op) const { - llvm_unreachable("need to implement either match or matchAndRewrite!"); -} - /// Out-of-line vtable anchor. void RewritePattern::anchor() {} diff --git a/mlir/unittests/IR/PatternMatchTest.cpp b/mlir/unittests/IR/PatternMatchTest.cpp index 75d5228c82d99..1c67bfc284d32 100644 --- a/mlir/unittests/IR/PatternMatchTest.cpp +++ b/mlir/unittests/IR/PatternMatchTest.cpp @@ -19,6 +19,11 @@ struct AnOpRewritePattern : OpRewritePattern { AnOpRewritePattern(MLIRContext *context) : OpRewritePattern(context, /*benefit=*/1, /*generatedNames=*/{test::OpB::getOperationName()}) {} + + LogicalResult matchAndRewrite(test::OpA op, + PatternRewriter &rewriter) const override { + return failure(); + } }; TEST(OpRewritePatternTest, GetGeneratedNames) { MLIRContext context;