Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][IR] Move match and rewrite functions into separate class #129861

Merged

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Mar 5, 2025

The vast majority of rewrite / conversion patterns uses a combined matchAndRewrite instead of separate match and rewrite functions.

This PR optimizes the code base for the most common case where users implement a combined matchAndRewrite. There are no longer any match and rewrite functions in RewritePattern, ConversionPattern and their derived classes. Instead, there is a SplitMatchAndRewriteImpl class that implements matchAndRewrite in terms of match and rewrite.

Details:

  • The RewritePattern and ConversionPattern classes are simpler (fewer functions). Especially the ConversionPattern class, which now has 5 fewer functions. (There were various rewrite overloads to account for 1:1 / 1:N patterns.)
  • There is a new class SplitMatchAndRewriteImpl that derives from RewritePattern / OpRewritePatern / ..., along with a type alias RewritePattern::SplitMatchAndRewrite for convenience.
  • Fewer llvm_unreachable are needed throughout the code base. Instead, we can use pure virtual functions. (In cases where users previously had to implement rewrite or matchAndRewrite, etc.)
  • This PR may also improve the number of -Woverload-virtual warnings that are produced by GCC. (To be confirmed...)

Note for LLVM integration: Patterns with separate match / rewrite implementations, must derive from X::SplitMatchAndRewrite instead of X.

@llvmbot
Copy link
Member

llvmbot commented Mar 5, 2025

@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-flang-codegen
@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-backend-amdgpu

Author: Matthias Springer (matthias-springer)

Changes

The vast majority of rewrite / conversion patterns uses a combined matchAndRewrite instead of separate match and rewrite functions.

This PR optimizes the code base for the most common case where users implement a combined matchAndRewrite. There are no longer any match and rewrite functions in RewritePattern, ConversionPattern and their derived classes. Instead, there is a SplitMatchAndRewriteImpl class that implements matchAndRewrite in terms of match and rewrite.

Details:

  • The RewritePattern and ConversionPattern classes are simpler (fewer functions). Especially the ConversionPattern class, which now has 5 fewer functions. (There were various rewrite overloads to account for 1:1 / 1:N patterns.)
  • There is a new class SplitMatchAndRewriteImpl that derives from RewritePattern / OpRewritePatern / ..., along with a type alias RewritePattern::SplitMatchAndRewrite for convenience.
  • Fewer llvm_unreachable are needed throughout the code base. Instead, we can use pure virtual functions. (In cases where users previously had to implement rewrite or matchAndRewrite, etc.)
  • This PR may also improve the number of -Woverload-virtual warnings that are produced by GCC. (To be confirmed...)

Patch is 27.44 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/129861.diff

10 Files Affected:

  • (modified) mlir/include/mlir/Conversion/LLVMCommon/Pattern.h (+8-32)
  • (modified) mlir/include/mlir/IR/PatternMatch.h (+42-44)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+69-72)
  • (modified) mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (+10-8)
  • (modified) mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp (+3-7)
  • (modified) mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp (+3-2)
  • (modified) mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp (+4-2)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+3-3)
  • (modified) mlir/lib/IR/PatternMatch.cpp (-9)
  • (modified) mlir/unittests/IR/PatternMatchTest.cpp (+5)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 86ea87b55af1c..8f82176f3b75f 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -40,6 +40,9 @@ LogicalResult oneToOneRewrite(
 /// during the entire pattern lifetime.
 class ConvertToLLVMPattern : public ConversionPattern {
 public:
+  using SplitMatchAndRewrite =
+      detail::ConversionSplitMatchAndRewriteImpl<ConvertToLLVMPattern>;
+
   ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context,
                        const LLVMTypeConverter &typeConverter,
                        PatternBenefit benefit = 1);
@@ -142,9 +145,12 @@ class ConvertToLLVMPattern : public ConversionPattern {
 template <typename SourceOp>
 class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
 public:
+  using OperationT = SourceOp;
   using OpAdaptor = typename SourceOp::Adaptor;
   using OneToNOpAdaptor =
       typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
+  using SplitMatchAndRewrite = detail::ConversionSplitMatchAndRewriteImpl<
+      ConvertOpToLLVMPattern<SourceOp>>;
 
   explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter,
                                   PatternBenefit benefit = 1)
@@ -153,19 +159,6 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
                              benefit) {}
 
   /// Wrappers around the RewritePattern methods that pass the derived op type.
-  void rewrite(Operation *op, ArrayRef<Value> operands,
-               ConversionPatternRewriter &rewriter) const final {
-    auto sourceOp = cast<SourceOp>(op);
-    rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
-  }
-  void rewrite(Operation *op, ArrayRef<ValueRange> operands,
-               ConversionPatternRewriter &rewriter) const final {
-    auto sourceOp = cast<SourceOp>(op);
-    rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
-  }
-  LogicalResult match(Operation *op) const final {
-    return match(cast<SourceOp>(op));
-  }
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
@@ -180,28 +173,12 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
                            rewriter);
   }
 
-  /// Rewrite and Match methods that operate on the SourceOp type. These must be
+  /// Methods that operate on the SourceOp type. One of these must be
   /// overridden by the derived pattern class.
-  virtual LogicalResult match(SourceOp op) const {
-    llvm_unreachable("must override match or matchAndRewrite");
-  }
-  virtual void rewrite(SourceOp op, OpAdaptor adaptor,
-                       ConversionPatternRewriter &rewriter) const {
-    llvm_unreachable("must override rewrite or matchAndRewrite");
-  }
-  virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
-                       ConversionPatternRewriter &rewriter) const {
-    SmallVector<Value> oneToOneOperands =
-        getOneToOneAdaptorOperands(adaptor.getOperands());
-    rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
-  }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const {
-    if (failed(match(op)))
-      return failure();
-    rewrite(op, adaptor, rewriter);
-    return success();
+    llvm_unreachable("matchAndRewrite is not implemented");
   }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
@@ -212,7 +189,6 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
   }
 
 private:
-  using ConvertToLLVMPattern::match;
   using ConvertToLLVMPattern::matchAndRewrite;
 };
 
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 2ab0405043a54..9055dc6ed7fc1 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -234,41 +234,50 @@ class Pattern {
 // RewritePattern
 //===----------------------------------------------------------------------===//
 
-/// RewritePattern is the common base class for all DAG to DAG replacements.
-/// There are two possible usages of this class:
-///   * Multi-step RewritePattern with "match" and "rewrite"
-///     - By overloading the "match" and "rewrite" functions, the user can
-///       separate the concerns of matching and rewriting.
-///   * Single-step RewritePattern with "matchAndRewrite"
-///     - By overloading the "matchAndRewrite" function, the user can perform
-///       the rewrite in the same call as the match.
-///
-class RewritePattern : public Pattern {
-public:
-  virtual ~RewritePattern() = default;
+namespace detail {
+/// Helper class that derives from a RewritePattern class and provides separate
+/// `match` and `rewrite` entry points instead of a combined `matchAndRewrite`.
+template <typename PatternT>
+class SplitMatchAndRewriteImpl : public PatternT {
+  using PatternT::PatternT;
 
   /// Rewrite the IR rooted at the specified operation with the result of
   /// this pattern, generating any new operations with the specified
-  /// builder.  If an unexpected error is encountered (an internal
-  /// compiler error), it is emitted through the normal MLIR diagnostic
-  /// hooks and the IR is left in a valid state.
-  virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;
+  /// rewriter.
+  virtual void rewrite(typename PatternT::OperationT op,
+                       PatternRewriter &rewriter) const = 0;
 
   /// Attempt to match against code rooted at the specified operation,
   /// which is the same operation code as getRootKind().
-  virtual LogicalResult match(Operation *op) const;
+  virtual LogicalResult match(typename PatternT::OperationT op) const = 0;
 
-  /// Attempt to match against code rooted at the specified operation,
-  /// which is the same operation code as getRootKind(). If successful, this
-  /// function will automatically perform the rewrite.
-  virtual LogicalResult matchAndRewrite(Operation *op,
-                                        PatternRewriter &rewriter) const {
+  LogicalResult matchAndRewrite(typename PatternT::OperationT op,
+                                PatternRewriter &rewriter) const final {
     if (succeeded(match(op))) {
       rewrite(op, rewriter);
       return success();
     }
     return failure();
   }
+};
+} // namespace detail
+
+/// RewritePattern is the common base class for all DAG to DAG replacements.
+/// By overloading the "matchAndRewrite" function, the user can perform the
+/// rewrite in the same call as the match.
+///
+class RewritePattern : public Pattern {
+public:
+  using OperationT = Operation *;
+  using SplitMatchAndRewrite = detail::SplitMatchAndRewriteImpl<RewritePattern>;
+
+  virtual ~RewritePattern() = default;
+
+  /// Attempt to match against code rooted at the specified operation,
+  /// which is the same operation code as getRootKind(). If successful, this
+  /// function will automatically perform the rewrite.
+  virtual LogicalResult matchAndRewrite(Operation *op,
+                                        PatternRewriter &rewriter) const = 0;
 
   /// This method provides a convenient interface for creating and initializing
   /// derived rewrite patterns of the given type `T`.
@@ -317,36 +326,19 @@ namespace detail {
 /// class or Interface.
 template <typename SourceOp>
 struct OpOrInterfaceRewritePatternBase : public RewritePattern {
+  using OperationT = SourceOp;
   using RewritePattern::RewritePattern;
 
-  /// Wrappers around the RewritePattern methods that pass the derived op type.
-  void rewrite(Operation *op, PatternRewriter &rewriter) const final {
-    rewrite(cast<SourceOp>(op), rewriter);
-  }
-  LogicalResult match(Operation *op) const final {
-    return match(cast<SourceOp>(op));
-  }
+  /// Wrapper around the RewritePattern method that passes the derived op type.
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const final {
     return matchAndRewrite(cast<SourceOp>(op), rewriter);
   }
 
-  /// Rewrite and Match methods that operate on the SourceOp type. These must be
-  /// overridden by the derived pattern class.
-  virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const {
-    llvm_unreachable("must override rewrite or matchAndRewrite");
-  }
-  virtual LogicalResult match(SourceOp op) const {
-    llvm_unreachable("must override match or matchAndRewrite");
-  }
+  /// Method that operates on the SourceOp type. Must be overridden by the
+  /// derived pattern class.
   virtual LogicalResult matchAndRewrite(SourceOp op,
-                                        PatternRewriter &rewriter) const {
-    if (succeeded(match(op))) {
-      rewrite(op, rewriter);
-      return success();
-    }
-    return failure();
-  }
+                                        PatternRewriter &rewriter) const = 0;
 };
 } // namespace detail
 
@@ -356,6 +348,9 @@ struct OpOrInterfaceRewritePatternBase : public RewritePattern {
 template <typename SourceOp>
 struct OpRewritePattern
     : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
+  using SplitMatchAndRewrite =
+      detail::SplitMatchAndRewriteImpl<OpRewritePattern<SourceOp>>;
+
   /// Patterns must specify the root operation name they match against, and can
   /// also specify the benefit of the pattern matching and a list of generated
   /// ops.
@@ -371,6 +366,9 @@ struct OpRewritePattern
 template <typename SourceOp>
 struct OpInterfaceRewritePattern
     : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
+  using SplitMatchAndRewrite =
+      detail::SplitMatchAndRewriteImpl<OpInterfaceRewritePattern<SourceOp>>;
+
   OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
       : detail::OpOrInterfaceRewritePatternBase<SourceOp>(
             Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(),
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 9a6975dcf8dfa..d705480cb137e 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -528,24 +528,72 @@ class TypeConverter {
 // Conversion Patterns
 //===----------------------------------------------------------------------===//
 
+namespace detail {
+/// Helper class that derives from a ConversionRewritePattern class and
+/// provides separate `match` and `rewrite` entry points instead of a combined
+/// `matchAndRewrite`.
+template <typename PatternT>
+class ConversionSplitMatchAndRewriteImpl : public PatternT {
+  using PatternT::PatternT;
+
+  /// Rewrite the IR rooted at the specified operation with the result of
+  /// this pattern, generating any new operations with the specified
+  /// rewriter.
+  virtual void rewrite(typename PatternT::OperationT op,
+                       typename PatternT::OpAdaptor adaptor,
+                       ConversionPatternRewriter &rewriter) const {
+    // One of the two `rewrite` functions must be implemented.
+    llvm_unreachable("rewrite is not implemented");
+  }
+
+  virtual void rewrite(typename PatternT::OperationT op,
+                       typename PatternT::OneToNOpAdaptor adaptor,
+                       ConversionPatternRewriter &rewriter) const {
+    if constexpr (std::is_same<typename PatternT::OpAdaptor,
+                               ArrayRef<Value>>::value) {
+      rewrite(op, PatternT::getOneToOneAdaptorOperands(adaptor), rewriter);
+    } else {
+      SmallVector<Value> oneToOneOperands =
+          PatternT::getOneToOneAdaptorOperands(adaptor.getOperands());
+      rewrite(op, typename PatternT::OpAdaptor(oneToOneOperands, adaptor),
+              rewriter);
+    }
+  }
+
+  /// Attempt to match against code rooted at the specified operation,
+  /// which is the same operation code as getRootKind().
+  virtual LogicalResult match(typename PatternT::OperationT op) const = 0;
+
+  LogicalResult
+  matchAndRewrite(typename PatternT::OperationT op,
+                  typename PatternT::OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    llvm_unreachable("1:1 matchAndRewrite entry point is never used");
+  }
+
+  LogicalResult
+  matchAndRewrite(typename PatternT::OperationT op,
+                  typename PatternT::OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    if (succeeded(match(op))) {
+      rewrite(op, adaptor, rewriter);
+      return success();
+    }
+    return failure();
+  }
+};
+} // namespace detail
+
 /// Base class for the conversion patterns. This pattern class enables type
 /// conversions, and other uses specific to the conversion framework. As such,
 /// patterns of this type can only be used with the 'apply*' methods below.
 class ConversionPattern : public RewritePattern {
 public:
-  /// Hook for derived classes to implement rewriting. `op` is the (first)
-  /// operation matched by the pattern, `operands` is a list of the rewritten
-  /// operand values that are passed to `op`, `rewriter` can be used to emit the
-  /// new operations. This function should not fail. If some specific cases of
-  /// the operation are not supported, these cases should not be matched.
-  virtual void rewrite(Operation *op, ArrayRef<Value> operands,
-                       ConversionPatternRewriter &rewriter) const {
-    llvm_unreachable("unimplemented rewrite");
-  }
-  virtual void rewrite(Operation *op, ArrayRef<ValueRange> operands,
-                       ConversionPatternRewriter &rewriter) const {
-    rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
-  }
+  using OperationT = Operation *;
+  using OpAdaptor = ArrayRef<Value>;
+  using OneToNOpAdaptor = ArrayRef<ValueRange>;
+  using SplitMatchAndRewrite =
+      detail::ConversionSplitMatchAndRewriteImpl<ConversionPattern>;
 
   /// Hook for derived classes to implement combined matching and rewriting.
   /// This overload supports only 1:1 replacements. The 1:N overload is called
@@ -554,10 +602,7 @@ class ConversionPattern : public RewritePattern {
   virtual LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const {
-    if (failed(match(op)))
-      return failure();
-    rewrite(op, operands, rewriter);
-    return success();
+    llvm_unreachable("matchAndRewrite is not implemented");
   }
 
   /// Hook for derived classes to implement combined matching and rewriting.
@@ -606,9 +651,6 @@ class ConversionPattern : public RewritePattern {
 protected:
   /// An optional type converter for use by this pattern.
   const TypeConverter *typeConverter = nullptr;
-
-private:
-  using RewritePattern::rewrite;
 };
 
 /// OpConversionPattern is a wrapper around ConversionPattern that allows for
@@ -617,9 +659,12 @@ class ConversionPattern : public RewritePattern {
 template <typename SourceOp>
 class OpConversionPattern : public ConversionPattern {
 public:
+  using OperationT = SourceOp;
   using OpAdaptor = typename SourceOp::Adaptor;
   using OneToNOpAdaptor =
       typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
+  using SplitMatchAndRewrite =
+      detail::ConversionSplitMatchAndRewriteImpl<OpConversionPattern<SourceOp>>;
 
   OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
       : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
@@ -630,19 +675,6 @@ class OpConversionPattern : public ConversionPattern {
 
   /// Wrappers around the ConversionPattern methods that pass the derived op
   /// type.
-  LogicalResult match(Operation *op) const final {
-    return match(cast<SourceOp>(op));
-  }
-  void rewrite(Operation *op, ArrayRef<Value> operands,
-               ConversionPatternRewriter &rewriter) const final {
-    auto sourceOp = cast<SourceOp>(op);
-    rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
-  }
-  void rewrite(Operation *op, ArrayRef<ValueRange> operands,
-               ConversionPatternRewriter &rewriter) const final {
-    auto sourceOp = cast<SourceOp>(op);
-    rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
-  }
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
@@ -657,28 +689,12 @@ class OpConversionPattern : public ConversionPattern {
                            rewriter);
   }
 
-  /// Rewrite and Match methods that operate on the SourceOp type. These must be
+  /// Methods that operate on the SourceOp type. One of these must be
   /// overridden by the derived pattern class.
-  virtual LogicalResult match(SourceOp op) const {
-    llvm_unreachable("must override match or matchAndRewrite");
-  }
-  virtual void rewrite(SourceOp op, OpAdaptor adaptor,
-                       ConversionPatternRewriter &rewriter) const {
-    llvm_unreachable("must override matchAndRewrite or a rewrite method");
-  }
-  virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
-                       ConversionPatternRewriter &rewriter) const {
-    SmallVector<Value> oneToOneOperands =
-        getOneToOneAdaptorOperands(adaptor.getOperands());
-    rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
-  }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const {
-    if (failed(match(op)))
-      return failure();
-    rewrite(op, adaptor, rewriter);
-    return success();
+    llvm_unreachable("matchAndRewrite is not implemented");
   }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
@@ -708,14 +724,6 @@ class OpInterfaceConversionPattern : public ConversionPattern {
 
   /// Wrappers around the ConversionPattern methods that pass the derived op
   /// type.
-  void rewrite(Operation *op, ArrayRef<Value> operands,
-               ConversionPatternRewriter &rewriter) const final {
-    rewrite(cast<SourceOp>(op), operands, rewriter);
-  }
-  void rewrite(Operation *op, ArrayRef<ValueRange> operands,
-               ConversionPatternRewriter &rewriter) const final {
-    rewrite(cast<SourceOp>(op), operands, rewriter);
-  }
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
@@ -727,23 +735,12 @@ class OpInterfaceConversionPattern : public ConversionPattern {
     return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
   }
 
-  /// Rewrite and Match methods that operate on the SourceOp type. These must be
+  /// Methods that operate on the SourceOp type. One of these must be
   /// overridden by the derived pattern class.
-  virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
-                       ConversionPatternRewriter &rewriter) const {
-    llvm_unreachable("must override matchAndRewrite or a rewrite method");
-  }
-  virtual void rewrite(SourceOp op, ArrayRef<ValueRange> operands,
-                       ConversionPatternRewriter &rewriter) const {
-    rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
-  }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const {
-    if (failed(match(op)))
-      return failure();
-    rewrite(op, operands, rewriter);
-    return success();
+    llvm_unreachable("matchAndRewrite is not implemented");
   }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index cba71740f9380..734c4839f9a10 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -41,23 +41,25 @@ struct ArithToAMDGPUConversionPass final
   void runOnOperation() override;
 };
 
-struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
-  using OpRewritePattern::OpRewritePattern;
+struct ExtFOnFloat8RewritePattern final
+    : OpRewritePattern<arith::ExtFOp>::SplitMatchAndRewrite {
+  using SplitMatchAndRewrite::SplitMatchAndRewrite;
 
   Chipset chipset;
   ExtFOnFloat8...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Mar 5, 2025

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

The vast majority of rewrite / conversion patterns uses a combined matchAndRewrite instead of separate match and rewrite functions.

This PR optimizes the code base for the most common case where users implement a combined matchAndRewrite. There are no longer any match and rewrite functions in RewritePattern, ConversionPattern and their derived classes. Instead, there is a SplitMatchAndRewriteImpl class that implements matchAndRewrite in terms of match and rewrite.

Details:

  • The RewritePattern and ConversionPattern classes are simpler (fewer functions). Especially the ConversionPattern class, which now has 5 fewer functions. (There were various rewrite overloads to account for 1:1 / 1:N patterns.)
  • There is a new class SplitMatchAndRewriteImpl that derives from RewritePattern / OpRewritePatern / ..., along with a type alias RewritePattern::SplitMatchAndRewrite for convenience.
  • Fewer llvm_unreachable are needed throughout the code base. Instead, we can use pure virtual functions. (In cases where users previously had to implement rewrite or matchAndRewrite, etc.)
  • This PR may also improve the number of -Woverload-virtual warnings that are produced by GCC. (To be confirmed...)

Patch is 27.44 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/129861.diff

10 Files Affected:

  • (modified) mlir/include/mlir/Conversion/LLVMCommon/Pattern.h (+8-32)
  • (modified) mlir/include/mlir/IR/PatternMatch.h (+42-44)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+69-72)
  • (modified) mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (+10-8)
  • (modified) mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp (+3-7)
  • (modified) mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp (+3-2)
  • (modified) mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp (+4-2)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+3-3)
  • (modified) mlir/lib/IR/PatternMatch.cpp (-9)
  • (modified) mlir/unittests/IR/PatternMatchTest.cpp (+5)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 86ea87b55af1c..8f82176f3b75f 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -40,6 +40,9 @@ LogicalResult oneToOneRewrite(
 /// during the entire pattern lifetime.
 class ConvertToLLVMPattern : public ConversionPattern {
 public:
+  using SplitMatchAndRewrite =
+      detail::ConversionSplitMatchAndRewriteImpl<ConvertToLLVMPattern>;
+
   ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context,
                        const LLVMTypeConverter &typeConverter,
                        PatternBenefit benefit = 1);
@@ -142,9 +145,12 @@ class ConvertToLLVMPattern : public ConversionPattern {
 template <typename SourceOp>
 class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
 public:
+  using OperationT = SourceOp;
   using OpAdaptor = typename SourceOp::Adaptor;
   using OneToNOpAdaptor =
       typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
+  using SplitMatchAndRewrite = detail::ConversionSplitMatchAndRewriteImpl<
+      ConvertOpToLLVMPattern<SourceOp>>;
 
   explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter,
                                   PatternBenefit benefit = 1)
@@ -153,19 +159,6 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
                              benefit) {}
 
   /// Wrappers around the RewritePattern methods that pass the derived op type.
-  void rewrite(Operation *op, ArrayRef<Value> operands,
-               ConversionPatternRewriter &rewriter) const final {
-    auto sourceOp = cast<SourceOp>(op);
-    rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
-  }
-  void rewrite(Operation *op, ArrayRef<ValueRange> operands,
-               ConversionPatternRewriter &rewriter) const final {
-    auto sourceOp = cast<SourceOp>(op);
-    rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
-  }
-  LogicalResult match(Operation *op) const final {
-    return match(cast<SourceOp>(op));
-  }
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
@@ -180,28 +173,12 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
                            rewriter);
   }
 
-  /// Rewrite and Match methods that operate on the SourceOp type. These must be
+  /// Methods that operate on the SourceOp type. One of these must be
   /// overridden by the derived pattern class.
-  virtual LogicalResult match(SourceOp op) const {
-    llvm_unreachable("must override match or matchAndRewrite");
-  }
-  virtual void rewrite(SourceOp op, OpAdaptor adaptor,
-                       ConversionPatternRewriter &rewriter) const {
-    llvm_unreachable("must override rewrite or matchAndRewrite");
-  }
-  virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
-                       ConversionPatternRewriter &rewriter) const {
-    SmallVector<Value> oneToOneOperands =
-        getOneToOneAdaptorOperands(adaptor.getOperands());
-    rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
-  }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const {
-    if (failed(match(op)))
-      return failure();
-    rewrite(op, adaptor, rewriter);
-    return success();
+    llvm_unreachable("matchAndRewrite is not implemented");
   }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
@@ -212,7 +189,6 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
   }
 
 private:
-  using ConvertToLLVMPattern::match;
   using ConvertToLLVMPattern::matchAndRewrite;
 };
 
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 2ab0405043a54..9055dc6ed7fc1 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -234,41 +234,50 @@ class Pattern {
 // RewritePattern
 //===----------------------------------------------------------------------===//
 
-/// RewritePattern is the common base class for all DAG to DAG replacements.
-/// There are two possible usages of this class:
-///   * Multi-step RewritePattern with "match" and "rewrite"
-///     - By overloading the "match" and "rewrite" functions, the user can
-///       separate the concerns of matching and rewriting.
-///   * Single-step RewritePattern with "matchAndRewrite"
-///     - By overloading the "matchAndRewrite" function, the user can perform
-///       the rewrite in the same call as the match.
-///
-class RewritePattern : public Pattern {
-public:
-  virtual ~RewritePattern() = default;
+namespace detail {
+/// Helper class that derives from a RewritePattern class and provides separate
+/// `match` and `rewrite` entry points instead of a combined `matchAndRewrite`.
+template <typename PatternT>
+class SplitMatchAndRewriteImpl : public PatternT {
+  using PatternT::PatternT;
 
   /// Rewrite the IR rooted at the specified operation with the result of
   /// this pattern, generating any new operations with the specified
-  /// builder.  If an unexpected error is encountered (an internal
-  /// compiler error), it is emitted through the normal MLIR diagnostic
-  /// hooks and the IR is left in a valid state.
-  virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;
+  /// rewriter.
+  virtual void rewrite(typename PatternT::OperationT op,
+                       PatternRewriter &rewriter) const = 0;
 
   /// Attempt to match against code rooted at the specified operation,
   /// which is the same operation code as getRootKind().
-  virtual LogicalResult match(Operation *op) const;
+  virtual LogicalResult match(typename PatternT::OperationT op) const = 0;
 
-  /// Attempt to match against code rooted at the specified operation,
-  /// which is the same operation code as getRootKind(). If successful, this
-  /// function will automatically perform the rewrite.
-  virtual LogicalResult matchAndRewrite(Operation *op,
-                                        PatternRewriter &rewriter) const {
+  LogicalResult matchAndRewrite(typename PatternT::OperationT op,
+                                PatternRewriter &rewriter) const final {
     if (succeeded(match(op))) {
       rewrite(op, rewriter);
       return success();
     }
     return failure();
   }
+};
+} // namespace detail
+
+/// RewritePattern is the common base class for all DAG to DAG replacements.
+/// By overloading the "matchAndRewrite" function, the user can perform the
+/// rewrite in the same call as the match.
+///
+class RewritePattern : public Pattern {
+public:
+  using OperationT = Operation *;
+  using SplitMatchAndRewrite = detail::SplitMatchAndRewriteImpl<RewritePattern>;
+
+  virtual ~RewritePattern() = default;
+
+  /// Attempt to match against code rooted at the specified operation,
+  /// which is the same operation code as getRootKind(). If successful, this
+  /// function will automatically perform the rewrite.
+  virtual LogicalResult matchAndRewrite(Operation *op,
+                                        PatternRewriter &rewriter) const = 0;
 
   /// This method provides a convenient interface for creating and initializing
   /// derived rewrite patterns of the given type `T`.
@@ -317,36 +326,19 @@ namespace detail {
 /// class or Interface.
 template <typename SourceOp>
 struct OpOrInterfaceRewritePatternBase : public RewritePattern {
+  using OperationT = SourceOp;
   using RewritePattern::RewritePattern;
 
-  /// Wrappers around the RewritePattern methods that pass the derived op type.
-  void rewrite(Operation *op, PatternRewriter &rewriter) const final {
-    rewrite(cast<SourceOp>(op), rewriter);
-  }
-  LogicalResult match(Operation *op) const final {
-    return match(cast<SourceOp>(op));
-  }
+  /// Wrapper around the RewritePattern method that passes the derived op type.
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const final {
     return matchAndRewrite(cast<SourceOp>(op), rewriter);
   }
 
-  /// Rewrite and Match methods that operate on the SourceOp type. These must be
-  /// overridden by the derived pattern class.
-  virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const {
-    llvm_unreachable("must override rewrite or matchAndRewrite");
-  }
-  virtual LogicalResult match(SourceOp op) const {
-    llvm_unreachable("must override match or matchAndRewrite");
-  }
+  /// Method that operates on the SourceOp type. Must be overridden by the
+  /// derived pattern class.
   virtual LogicalResult matchAndRewrite(SourceOp op,
-                                        PatternRewriter &rewriter) const {
-    if (succeeded(match(op))) {
-      rewrite(op, rewriter);
-      return success();
-    }
-    return failure();
-  }
+                                        PatternRewriter &rewriter) const = 0;
 };
 } // namespace detail
 
@@ -356,6 +348,9 @@ struct OpOrInterfaceRewritePatternBase : public RewritePattern {
 template <typename SourceOp>
 struct OpRewritePattern
     : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
+  using SplitMatchAndRewrite =
+      detail::SplitMatchAndRewriteImpl<OpRewritePattern<SourceOp>>;
+
   /// Patterns must specify the root operation name they match against, and can
   /// also specify the benefit of the pattern matching and a list of generated
   /// ops.
@@ -371,6 +366,9 @@ struct OpRewritePattern
 template <typename SourceOp>
 struct OpInterfaceRewritePattern
     : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
+  using SplitMatchAndRewrite =
+      detail::SplitMatchAndRewriteImpl<OpInterfaceRewritePattern<SourceOp>>;
+
   OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
       : detail::OpOrInterfaceRewritePatternBase<SourceOp>(
             Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(),
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 9a6975dcf8dfa..d705480cb137e 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -528,24 +528,72 @@ class TypeConverter {
 // Conversion Patterns
 //===----------------------------------------------------------------------===//
 
+namespace detail {
+/// Helper class that derives from a ConversionRewritePattern class and
+/// provides separate `match` and `rewrite` entry points instead of a combined
+/// `matchAndRewrite`.
+template <typename PatternT>
+class ConversionSplitMatchAndRewriteImpl : public PatternT {
+  using PatternT::PatternT;
+
+  /// Rewrite the IR rooted at the specified operation with the result of
+  /// this pattern, generating any new operations with the specified
+  /// rewriter.
+  virtual void rewrite(typename PatternT::OperationT op,
+                       typename PatternT::OpAdaptor adaptor,
+                       ConversionPatternRewriter &rewriter) const {
+    // One of the two `rewrite` functions must be implemented.
+    llvm_unreachable("rewrite is not implemented");
+  }
+
+  virtual void rewrite(typename PatternT::OperationT op,
+                       typename PatternT::OneToNOpAdaptor adaptor,
+                       ConversionPatternRewriter &rewriter) const {
+    if constexpr (std::is_same<typename PatternT::OpAdaptor,
+                               ArrayRef<Value>>::value) {
+      rewrite(op, PatternT::getOneToOneAdaptorOperands(adaptor), rewriter);
+    } else {
+      SmallVector<Value> oneToOneOperands =
+          PatternT::getOneToOneAdaptorOperands(adaptor.getOperands());
+      rewrite(op, typename PatternT::OpAdaptor(oneToOneOperands, adaptor),
+              rewriter);
+    }
+  }
+
+  /// Attempt to match against code rooted at the specified operation,
+  /// which is the same operation code as getRootKind().
+  virtual LogicalResult match(typename PatternT::OperationT op) const = 0;
+
+  LogicalResult
+  matchAndRewrite(typename PatternT::OperationT op,
+                  typename PatternT::OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    llvm_unreachable("1:1 matchAndRewrite entry point is never used");
+  }
+
+  LogicalResult
+  matchAndRewrite(typename PatternT::OperationT op,
+                  typename PatternT::OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    if (succeeded(match(op))) {
+      rewrite(op, adaptor, rewriter);
+      return success();
+    }
+    return failure();
+  }
+};
+} // namespace detail
+
 /// Base class for the conversion patterns. This pattern class enables type
 /// conversions, and other uses specific to the conversion framework. As such,
 /// patterns of this type can only be used with the 'apply*' methods below.
 class ConversionPattern : public RewritePattern {
 public:
-  /// Hook for derived classes to implement rewriting. `op` is the (first)
-  /// operation matched by the pattern, `operands` is a list of the rewritten
-  /// operand values that are passed to `op`, `rewriter` can be used to emit the
-  /// new operations. This function should not fail. If some specific cases of
-  /// the operation are not supported, these cases should not be matched.
-  virtual void rewrite(Operation *op, ArrayRef<Value> operands,
-                       ConversionPatternRewriter &rewriter) const {
-    llvm_unreachable("unimplemented rewrite");
-  }
-  virtual void rewrite(Operation *op, ArrayRef<ValueRange> operands,
-                       ConversionPatternRewriter &rewriter) const {
-    rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
-  }
+  using OperationT = Operation *;
+  using OpAdaptor = ArrayRef<Value>;
+  using OneToNOpAdaptor = ArrayRef<ValueRange>;
+  using SplitMatchAndRewrite =
+      detail::ConversionSplitMatchAndRewriteImpl<ConversionPattern>;
 
   /// Hook for derived classes to implement combined matching and rewriting.
   /// This overload supports only 1:1 replacements. The 1:N overload is called
@@ -554,10 +602,7 @@ class ConversionPattern : public RewritePattern {
   virtual LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const {
-    if (failed(match(op)))
-      return failure();
-    rewrite(op, operands, rewriter);
-    return success();
+    llvm_unreachable("matchAndRewrite is not implemented");
   }
 
   /// Hook for derived classes to implement combined matching and rewriting.
@@ -606,9 +651,6 @@ class ConversionPattern : public RewritePattern {
 protected:
   /// An optional type converter for use by this pattern.
   const TypeConverter *typeConverter = nullptr;
-
-private:
-  using RewritePattern::rewrite;
 };
 
 /// OpConversionPattern is a wrapper around ConversionPattern that allows for
@@ -617,9 +659,12 @@ class ConversionPattern : public RewritePattern {
 template <typename SourceOp>
 class OpConversionPattern : public ConversionPattern {
 public:
+  using OperationT = SourceOp;
   using OpAdaptor = typename SourceOp::Adaptor;
   using OneToNOpAdaptor =
       typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
+  using SplitMatchAndRewrite =
+      detail::ConversionSplitMatchAndRewriteImpl<OpConversionPattern<SourceOp>>;
 
   OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
       : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
@@ -630,19 +675,6 @@ class OpConversionPattern : public ConversionPattern {
 
   /// Wrappers around the ConversionPattern methods that pass the derived op
   /// type.
-  LogicalResult match(Operation *op) const final {
-    return match(cast<SourceOp>(op));
-  }
-  void rewrite(Operation *op, ArrayRef<Value> operands,
-               ConversionPatternRewriter &rewriter) const final {
-    auto sourceOp = cast<SourceOp>(op);
-    rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
-  }
-  void rewrite(Operation *op, ArrayRef<ValueRange> operands,
-               ConversionPatternRewriter &rewriter) const final {
-    auto sourceOp = cast<SourceOp>(op);
-    rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
-  }
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
@@ -657,28 +689,12 @@ class OpConversionPattern : public ConversionPattern {
                            rewriter);
   }
 
-  /// Rewrite and Match methods that operate on the SourceOp type. These must be
+  /// Methods that operate on the SourceOp type. One of these must be
   /// overridden by the derived pattern class.
-  virtual LogicalResult match(SourceOp op) const {
-    llvm_unreachable("must override match or matchAndRewrite");
-  }
-  virtual void rewrite(SourceOp op, OpAdaptor adaptor,
-                       ConversionPatternRewriter &rewriter) const {
-    llvm_unreachable("must override matchAndRewrite or a rewrite method");
-  }
-  virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
-                       ConversionPatternRewriter &rewriter) const {
-    SmallVector<Value> oneToOneOperands =
-        getOneToOneAdaptorOperands(adaptor.getOperands());
-    rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
-  }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const {
-    if (failed(match(op)))
-      return failure();
-    rewrite(op, adaptor, rewriter);
-    return success();
+    llvm_unreachable("matchAndRewrite is not implemented");
   }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
@@ -708,14 +724,6 @@ class OpInterfaceConversionPattern : public ConversionPattern {
 
   /// Wrappers around the ConversionPattern methods that pass the derived op
   /// type.
-  void rewrite(Operation *op, ArrayRef<Value> operands,
-               ConversionPatternRewriter &rewriter) const final {
-    rewrite(cast<SourceOp>(op), operands, rewriter);
-  }
-  void rewrite(Operation *op, ArrayRef<ValueRange> operands,
-               ConversionPatternRewriter &rewriter) const final {
-    rewrite(cast<SourceOp>(op), operands, rewriter);
-  }
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
@@ -727,23 +735,12 @@ class OpInterfaceConversionPattern : public ConversionPattern {
     return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
   }
 
-  /// Rewrite and Match methods that operate on the SourceOp type. These must be
+  /// Methods that operate on the SourceOp type. One of these must be
   /// overridden by the derived pattern class.
-  virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
-                       ConversionPatternRewriter &rewriter) const {
-    llvm_unreachable("must override matchAndRewrite or a rewrite method");
-  }
-  virtual void rewrite(SourceOp op, ArrayRef<ValueRange> operands,
-                       ConversionPatternRewriter &rewriter) const {
-    rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
-  }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const {
-    if (failed(match(op)))
-      return failure();
-    rewrite(op, operands, rewriter);
-    return success();
+    llvm_unreachable("matchAndRewrite is not implemented");
   }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index cba71740f9380..734c4839f9a10 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -41,23 +41,25 @@ struct ArithToAMDGPUConversionPass final
   void runOnOperation() override;
 };
 
-struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
-  using OpRewritePattern::OpRewritePattern;
+struct ExtFOnFloat8RewritePattern final
+    : OpRewritePattern<arith::ExtFOp>::SplitMatchAndRewrite {
+  using SplitMatchAndRewrite::SplitMatchAndRewrite;
 
   Chipset chipset;
   ExtFOnFloat8...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Mar 5, 2025

@llvm/pr-subscribers-mlir-arith

Author: Matthias Springer (matthias-springer)

Changes

The vast majority of rewrite / conversion patterns uses a combined matchAndRewrite instead of separate match and rewrite functions.

This PR optimizes the code base for the most common case where users implement a combined matchAndRewrite. There are no longer any match and rewrite functions in RewritePattern, ConversionPattern and their derived classes. Instead, there is a SplitMatchAndRewriteImpl class that implements matchAndRewrite in terms of match and rewrite.

Details:

  • The RewritePattern and ConversionPattern classes are simpler (fewer functions). Especially the ConversionPattern class, which now has 5 fewer functions. (There were various rewrite overloads to account for 1:1 / 1:N patterns.)
  • There is a new class SplitMatchAndRewriteImpl that derives from RewritePattern / OpRewritePatern / ..., along with a type alias RewritePattern::SplitMatchAndRewrite for convenience.
  • Fewer llvm_unreachable are needed throughout the code base. Instead, we can use pure virtual functions. (In cases where users previously had to implement rewrite or matchAndRewrite, etc.)
  • This PR may also improve the number of -Woverload-virtual warnings that are produced by GCC. (To be confirmed...)

Patch is 27.44 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/129861.diff

10 Files Affected:

  • (modified) mlir/include/mlir/Conversion/LLVMCommon/Pattern.h (+8-32)
  • (modified) mlir/include/mlir/IR/PatternMatch.h (+42-44)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+69-72)
  • (modified) mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (+10-8)
  • (modified) mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp (+3-7)
  • (modified) mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp (+3-2)
  • (modified) mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp (+4-2)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+3-3)
  • (modified) mlir/lib/IR/PatternMatch.cpp (-9)
  • (modified) mlir/unittests/IR/PatternMatchTest.cpp (+5)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 86ea87b55af1c..8f82176f3b75f 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -40,6 +40,9 @@ LogicalResult oneToOneRewrite(
 /// during the entire pattern lifetime.
 class ConvertToLLVMPattern : public ConversionPattern {
 public:
+  using SplitMatchAndRewrite =
+      detail::ConversionSplitMatchAndRewriteImpl<ConvertToLLVMPattern>;
+
   ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context,
                        const LLVMTypeConverter &typeConverter,
                        PatternBenefit benefit = 1);
@@ -142,9 +145,12 @@ class ConvertToLLVMPattern : public ConversionPattern {
 template <typename SourceOp>
 class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
 public:
+  using OperationT = SourceOp;
   using OpAdaptor = typename SourceOp::Adaptor;
   using OneToNOpAdaptor =
       typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
+  using SplitMatchAndRewrite = detail::ConversionSplitMatchAndRewriteImpl<
+      ConvertOpToLLVMPattern<SourceOp>>;
 
   explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter,
                                   PatternBenefit benefit = 1)
@@ -153,19 +159,6 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
                              benefit) {}
 
   /// Wrappers around the RewritePattern methods that pass the derived op type.
-  void rewrite(Operation *op, ArrayRef<Value> operands,
-               ConversionPatternRewriter &rewriter) const final {
-    auto sourceOp = cast<SourceOp>(op);
-    rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
-  }
-  void rewrite(Operation *op, ArrayRef<ValueRange> operands,
-               ConversionPatternRewriter &rewriter) const final {
-    auto sourceOp = cast<SourceOp>(op);
-    rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
-  }
-  LogicalResult match(Operation *op) const final {
-    return match(cast<SourceOp>(op));
-  }
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
@@ -180,28 +173,12 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
                            rewriter);
   }
 
-  /// Rewrite and Match methods that operate on the SourceOp type. These must be
+  /// Methods that operate on the SourceOp type. One of these must be
   /// overridden by the derived pattern class.
-  virtual LogicalResult match(SourceOp op) const {
-    llvm_unreachable("must override match or matchAndRewrite");
-  }
-  virtual void rewrite(SourceOp op, OpAdaptor adaptor,
-                       ConversionPatternRewriter &rewriter) const {
-    llvm_unreachable("must override rewrite or matchAndRewrite");
-  }
-  virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
-                       ConversionPatternRewriter &rewriter) const {
-    SmallVector<Value> oneToOneOperands =
-        getOneToOneAdaptorOperands(adaptor.getOperands());
-    rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
-  }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const {
-    if (failed(match(op)))
-      return failure();
-    rewrite(op, adaptor, rewriter);
-    return success();
+    llvm_unreachable("matchAndRewrite is not implemented");
   }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
@@ -212,7 +189,6 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
   }
 
 private:
-  using ConvertToLLVMPattern::match;
   using ConvertToLLVMPattern::matchAndRewrite;
 };
 
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 2ab0405043a54..9055dc6ed7fc1 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -234,41 +234,50 @@ class Pattern {
 // RewritePattern
 //===----------------------------------------------------------------------===//
 
-/// RewritePattern is the common base class for all DAG to DAG replacements.
-/// There are two possible usages of this class:
-///   * Multi-step RewritePattern with "match" and "rewrite"
-///     - By overloading the "match" and "rewrite" functions, the user can
-///       separate the concerns of matching and rewriting.
-///   * Single-step RewritePattern with "matchAndRewrite"
-///     - By overloading the "matchAndRewrite" function, the user can perform
-///       the rewrite in the same call as the match.
-///
-class RewritePattern : public Pattern {
-public:
-  virtual ~RewritePattern() = default;
+namespace detail {
+/// Helper class that derives from a RewritePattern class and provides separate
+/// `match` and `rewrite` entry points instead of a combined `matchAndRewrite`.
+template <typename PatternT>
+class SplitMatchAndRewriteImpl : public PatternT {
+  using PatternT::PatternT;
 
   /// Rewrite the IR rooted at the specified operation with the result of
   /// this pattern, generating any new operations with the specified
-  /// builder.  If an unexpected error is encountered (an internal
-  /// compiler error), it is emitted through the normal MLIR diagnostic
-  /// hooks and the IR is left in a valid state.
-  virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;
+  /// rewriter.
+  virtual void rewrite(typename PatternT::OperationT op,
+                       PatternRewriter &rewriter) const = 0;
 
   /// Attempt to match against code rooted at the specified operation,
   /// which is the same operation code as getRootKind().
-  virtual LogicalResult match(Operation *op) const;
+  virtual LogicalResult match(typename PatternT::OperationT op) const = 0;
 
-  /// Attempt to match against code rooted at the specified operation,
-  /// which is the same operation code as getRootKind(). If successful, this
-  /// function will automatically perform the rewrite.
-  virtual LogicalResult matchAndRewrite(Operation *op,
-                                        PatternRewriter &rewriter) const {
+  LogicalResult matchAndRewrite(typename PatternT::OperationT op,
+                                PatternRewriter &rewriter) const final {
     if (succeeded(match(op))) {
       rewrite(op, rewriter);
       return success();
     }
     return failure();
   }
+};
+} // namespace detail
+
+/// RewritePattern is the common base class for all DAG to DAG replacements.
+/// By overloading the "matchAndRewrite" function, the user can perform the
+/// rewrite in the same call as the match.
+///
+class RewritePattern : public Pattern {
+public:
+  using OperationT = Operation *;
+  using SplitMatchAndRewrite = detail::SplitMatchAndRewriteImpl<RewritePattern>;
+
+  virtual ~RewritePattern() = default;
+
+  /// Attempt to match against code rooted at the specified operation,
+  /// which is the same operation code as getRootKind(). If successful, this
+  /// function will automatically perform the rewrite.
+  virtual LogicalResult matchAndRewrite(Operation *op,
+                                        PatternRewriter &rewriter) const = 0;
 
   /// This method provides a convenient interface for creating and initializing
   /// derived rewrite patterns of the given type `T`.
@@ -317,36 +326,19 @@ namespace detail {
 /// class or Interface.
 template <typename SourceOp>
 struct OpOrInterfaceRewritePatternBase : public RewritePattern {
+  using OperationT = SourceOp;
   using RewritePattern::RewritePattern;
 
-  /// Wrappers around the RewritePattern methods that pass the derived op type.
-  void rewrite(Operation *op, PatternRewriter &rewriter) const final {
-    rewrite(cast<SourceOp>(op), rewriter);
-  }
-  LogicalResult match(Operation *op) const final {
-    return match(cast<SourceOp>(op));
-  }
+  /// Wrapper around the RewritePattern method that passes the derived op type.
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const final {
     return matchAndRewrite(cast<SourceOp>(op), rewriter);
   }
 
-  /// Rewrite and Match methods that operate on the SourceOp type. These must be
-  /// overridden by the derived pattern class.
-  virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const {
-    llvm_unreachable("must override rewrite or matchAndRewrite");
-  }
-  virtual LogicalResult match(SourceOp op) const {
-    llvm_unreachable("must override match or matchAndRewrite");
-  }
+  /// Method that operates on the SourceOp type. Must be overridden by the
+  /// derived pattern class.
   virtual LogicalResult matchAndRewrite(SourceOp op,
-                                        PatternRewriter &rewriter) const {
-    if (succeeded(match(op))) {
-      rewrite(op, rewriter);
-      return success();
-    }
-    return failure();
-  }
+                                        PatternRewriter &rewriter) const = 0;
 };
 } // namespace detail
 
@@ -356,6 +348,9 @@ struct OpOrInterfaceRewritePatternBase : public RewritePattern {
 template <typename SourceOp>
 struct OpRewritePattern
     : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
+  using SplitMatchAndRewrite =
+      detail::SplitMatchAndRewriteImpl<OpRewritePattern<SourceOp>>;
+
   /// Patterns must specify the root operation name they match against, and can
   /// also specify the benefit of the pattern matching and a list of generated
   /// ops.
@@ -371,6 +366,9 @@ struct OpRewritePattern
 template <typename SourceOp>
 struct OpInterfaceRewritePattern
     : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
+  using SplitMatchAndRewrite =
+      detail::SplitMatchAndRewriteImpl<OpInterfaceRewritePattern<SourceOp>>;
+
   OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
       : detail::OpOrInterfaceRewritePatternBase<SourceOp>(
             Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(),
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 9a6975dcf8dfa..d705480cb137e 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -528,24 +528,72 @@ class TypeConverter {
 // Conversion Patterns
 //===----------------------------------------------------------------------===//
 
+namespace detail {
+/// Helper class that derives from a ConversionRewritePattern class and
+/// provides separate `match` and `rewrite` entry points instead of a combined
+/// `matchAndRewrite`.
+template <typename PatternT>
+class ConversionSplitMatchAndRewriteImpl : public PatternT {
+  using PatternT::PatternT;
+
+  /// Rewrite the IR rooted at the specified operation with the result of
+  /// this pattern, generating any new operations with the specified
+  /// rewriter.
+  virtual void rewrite(typename PatternT::OperationT op,
+                       typename PatternT::OpAdaptor adaptor,
+                       ConversionPatternRewriter &rewriter) const {
+    // One of the two `rewrite` functions must be implemented.
+    llvm_unreachable("rewrite is not implemented");
+  }
+
+  virtual void rewrite(typename PatternT::OperationT op,
+                       typename PatternT::OneToNOpAdaptor adaptor,
+                       ConversionPatternRewriter &rewriter) const {
+    if constexpr (std::is_same<typename PatternT::OpAdaptor,
+                               ArrayRef<Value>>::value) {
+      rewrite(op, PatternT::getOneToOneAdaptorOperands(adaptor), rewriter);
+    } else {
+      SmallVector<Value> oneToOneOperands =
+          PatternT::getOneToOneAdaptorOperands(adaptor.getOperands());
+      rewrite(op, typename PatternT::OpAdaptor(oneToOneOperands, adaptor),
+              rewriter);
+    }
+  }
+
+  /// Attempt to match against code rooted at the specified operation,
+  /// which is the same operation code as getRootKind().
+  virtual LogicalResult match(typename PatternT::OperationT op) const = 0;
+
+  LogicalResult
+  matchAndRewrite(typename PatternT::OperationT op,
+                  typename PatternT::OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    llvm_unreachable("1:1 matchAndRewrite entry point is never used");
+  }
+
+  LogicalResult
+  matchAndRewrite(typename PatternT::OperationT op,
+                  typename PatternT::OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    if (succeeded(match(op))) {
+      rewrite(op, adaptor, rewriter);
+      return success();
+    }
+    return failure();
+  }
+};
+} // namespace detail
+
 /// Base class for the conversion patterns. This pattern class enables type
 /// conversions, and other uses specific to the conversion framework. As such,
 /// patterns of this type can only be used with the 'apply*' methods below.
 class ConversionPattern : public RewritePattern {
 public:
-  /// Hook for derived classes to implement rewriting. `op` is the (first)
-  /// operation matched by the pattern, `operands` is a list of the rewritten
-  /// operand values that are passed to `op`, `rewriter` can be used to emit the
-  /// new operations. This function should not fail. If some specific cases of
-  /// the operation are not supported, these cases should not be matched.
-  virtual void rewrite(Operation *op, ArrayRef<Value> operands,
-                       ConversionPatternRewriter &rewriter) const {
-    llvm_unreachable("unimplemented rewrite");
-  }
-  virtual void rewrite(Operation *op, ArrayRef<ValueRange> operands,
-                       ConversionPatternRewriter &rewriter) const {
-    rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
-  }
+  using OperationT = Operation *;
+  using OpAdaptor = ArrayRef<Value>;
+  using OneToNOpAdaptor = ArrayRef<ValueRange>;
+  using SplitMatchAndRewrite =
+      detail::ConversionSplitMatchAndRewriteImpl<ConversionPattern>;
 
   /// Hook for derived classes to implement combined matching and rewriting.
   /// This overload supports only 1:1 replacements. The 1:N overload is called
@@ -554,10 +602,7 @@ class ConversionPattern : public RewritePattern {
   virtual LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const {
-    if (failed(match(op)))
-      return failure();
-    rewrite(op, operands, rewriter);
-    return success();
+    llvm_unreachable("matchAndRewrite is not implemented");
   }
 
   /// Hook for derived classes to implement combined matching and rewriting.
@@ -606,9 +651,6 @@ class ConversionPattern : public RewritePattern {
 protected:
   /// An optional type converter for use by this pattern.
   const TypeConverter *typeConverter = nullptr;
-
-private:
-  using RewritePattern::rewrite;
 };
 
 /// OpConversionPattern is a wrapper around ConversionPattern that allows for
@@ -617,9 +659,12 @@ class ConversionPattern : public RewritePattern {
 template <typename SourceOp>
 class OpConversionPattern : public ConversionPattern {
 public:
+  using OperationT = SourceOp;
   using OpAdaptor = typename SourceOp::Adaptor;
   using OneToNOpAdaptor =
       typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
+  using SplitMatchAndRewrite =
+      detail::ConversionSplitMatchAndRewriteImpl<OpConversionPattern<SourceOp>>;
 
   OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
       : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
@@ -630,19 +675,6 @@ class OpConversionPattern : public ConversionPattern {
 
   /// Wrappers around the ConversionPattern methods that pass the derived op
   /// type.
-  LogicalResult match(Operation *op) const final {
-    return match(cast<SourceOp>(op));
-  }
-  void rewrite(Operation *op, ArrayRef<Value> operands,
-               ConversionPatternRewriter &rewriter) const final {
-    auto sourceOp = cast<SourceOp>(op);
-    rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
-  }
-  void rewrite(Operation *op, ArrayRef<ValueRange> operands,
-               ConversionPatternRewriter &rewriter) const final {
-    auto sourceOp = cast<SourceOp>(op);
-    rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
-  }
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
@@ -657,28 +689,12 @@ class OpConversionPattern : public ConversionPattern {
                            rewriter);
   }
 
-  /// Rewrite and Match methods that operate on the SourceOp type. These must be
+  /// Methods that operate on the SourceOp type. One of these must be
   /// overridden by the derived pattern class.
-  virtual LogicalResult match(SourceOp op) const {
-    llvm_unreachable("must override match or matchAndRewrite");
-  }
-  virtual void rewrite(SourceOp op, OpAdaptor adaptor,
-                       ConversionPatternRewriter &rewriter) const {
-    llvm_unreachable("must override matchAndRewrite or a rewrite method");
-  }
-  virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
-                       ConversionPatternRewriter &rewriter) const {
-    SmallVector<Value> oneToOneOperands =
-        getOneToOneAdaptorOperands(adaptor.getOperands());
-    rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
-  }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const {
-    if (failed(match(op)))
-      return failure();
-    rewrite(op, adaptor, rewriter);
-    return success();
+    llvm_unreachable("matchAndRewrite is not implemented");
   }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
@@ -708,14 +724,6 @@ class OpInterfaceConversionPattern : public ConversionPattern {
 
   /// Wrappers around the ConversionPattern methods that pass the derived op
   /// type.
-  void rewrite(Operation *op, ArrayRef<Value> operands,
-               ConversionPatternRewriter &rewriter) const final {
-    rewrite(cast<SourceOp>(op), operands, rewriter);
-  }
-  void rewrite(Operation *op, ArrayRef<ValueRange> operands,
-               ConversionPatternRewriter &rewriter) const final {
-    rewrite(cast<SourceOp>(op), operands, rewriter);
-  }
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
@@ -727,23 +735,12 @@ class OpInterfaceConversionPattern : public ConversionPattern {
     return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
   }
 
-  /// Rewrite and Match methods that operate on the SourceOp type. These must be
+  /// Methods that operate on the SourceOp type. One of these must be
   /// overridden by the derived pattern class.
-  virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
-                       ConversionPatternRewriter &rewriter) const {
-    llvm_unreachable("must override matchAndRewrite or a rewrite method");
-  }
-  virtual void rewrite(SourceOp op, ArrayRef<ValueRange> operands,
-                       ConversionPatternRewriter &rewriter) const {
-    rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
-  }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const {
-    if (failed(match(op)))
-      return failure();
-    rewrite(op, operands, rewriter);
-    return success();
+    llvm_unreachable("matchAndRewrite is not implemented");
   }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index cba71740f9380..734c4839f9a10 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -41,23 +41,25 @@ struct ArithToAMDGPUConversionPass final
   void runOnOperation() override;
 };
 
-struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
-  using OpRewritePattern::OpRewritePattern;
+struct ExtFOnFloat8RewritePattern final
+    : OpRewritePattern<arith::ExtFOp>::SplitMatchAndRewrite {
+  using SplitMatchAndRewrite::SplitMatchAndRewrite;
 
   Chipset chipset;
   ExtFOnFloat8...
[truncated]

@fschlimb
Copy link
Contributor

fschlimb commented Mar 5, 2025

I can confirm that the gcc (14) warning about hidden functions rewrite and matchAndRewrite are gone with this PR.

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir flang:codegen labels Mar 5, 2025
Copy link
Member

@jpienaar jpienaar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the removal of those unreachable instances in the base, this looks like good cleanup.

So for folks with the existing split one, it's like a 3 line change to use the split one instead?

@matthias-springer matthias-springer force-pushed the users/matthias-springer/split-match-and-rewrite branch from 2fb9e11 to 5e025f8 Compare March 5, 2025 14:36
@matthias-springer
Copy link
Member Author

So for folks with the existing split one, it's like a 3 line change to use the split one instead?

That's right. Users should derive from RewritePattern::SplitMatchAndRewrite instead of RewritePattern, etc. That's it.

Copy link
Contributor

@River707 River707 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems reasonable to me. This does raise a question though, what actually needs separate match and rewrite these days? Do we even need to carry around this code? Seems like we could just deprecate this and remove it (it's mostly legacy from the early days anyways).

@matthias-springer matthias-springer merged commit a6151f4 into main Mar 6, 2025
7 of 11 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/split-match-and-rewrite branch March 6, 2025 07:48
matthias-springer added a commit that referenced this pull request Mar 7, 2025
Deprecate the `match` and `rewrite` functions. They mainly exist for
historic reasons. This PR also updates all remaining uses of in the MLIR
codebase.

This is addressing a
[comment](#129861 (review))
on an earlier PR.

Note for LLVM integration: `SplitMatchAndRewrite` will be deleted soon,
update your patterns to use `matchAndRewrite` instead of separate
`match` / `rewrite`.

---------

Co-authored-by: Jakub Kuderski <[email protected]>
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Mar 7, 2025
…031)

Deprecate the `match` and `rewrite` functions. They mainly exist for
historic reasons. This PR also updates all remaining uses of in the MLIR
codebase.

This is addressing a
[comment](llvm/llvm-project#129861 (review))
on an earlier PR.

Note for LLVM integration: `SplitMatchAndRewrite` will be deleted soon,
update your patterns to use `matchAndRewrite` instead of separate
`match` / `rewrite`.

---------

Co-authored-by: Jakub Kuderski <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants