Skip to content

Commit

Permalink
[mlir][IR] Move match and rewrite functions into separate class (#…
Browse files Browse the repository at this point in the history
…129861)

The vast majority of rewrite / conversion patterns uses a combined
`matchAndRewrite` instead of separate `match` and `rewrite` functions.

This PR optimizes the code base for the most common case where users
implement a combined `matchAndRewrite`. There are no longer any `match`
and `rewrite` functions in `RewritePattern`, `ConversionPattern` and
their derived classes. Instead, there is a `SplitMatchAndRewriteImpl`
class that implements `matchAndRewrite` in terms of `match` and
`rewrite`.

Details:
* The `RewritePattern` and `ConversionPattern` classes are simpler
(fewer functions). Especially the `ConversionPattern` class, which now
has 5 fewer functions. (There were various `rewrite` overloads to
account for 1:1 / 1:N patterns.)
* There is a new class `SplitMatchAndRewriteImpl` that derives from
`RewritePattern` / `OpRewritePatern` / ..., along with a type alias
`RewritePattern::SplitMatchAndRewrite` for convenience.
* Fewer `llvm_unreachable` are needed throughout the code base. Instead,
we can use pure virtual functions. (In cases where users previously had
to implement `rewrite` or `matchAndRewrite`, etc.)
* This PR may also improve the number of [`-Woverload-virtual`
warnings](https://discourse.llvm.org/t/matchandrewrite-hiding-virtual-functions/84933)
that are produced by GCC. (To be confirmed...)

Note for LLVM integration: Patterns with separate `match` / `rewrite`
implementations, must derive from `X::SplitMatchAndRewrite` instead of
`X`.

---------

Co-authored-by: River Riddle <[email protected]>
  • Loading branch information
matthias-springer and River707 authored Mar 6, 2025
1 parent 87976ca commit a6151f4
Show file tree
Hide file tree
Showing 12 changed files with 175 additions and 226 deletions.
36 changes: 2 additions & 34 deletions flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,6 @@ class ConvertFIRToLLVMPattern : public mlir::ConvertToLLVMPattern {

const fir::FIRToLLVMPassOptions &options;

using ConvertToLLVMPattern::match;
using ConvertToLLVMPattern::matchAndRewrite;
};

Expand All @@ -206,20 +205,6 @@ class FIROpConversion : public ConvertFIRToLLVMPattern {
options, benefit) {}

/// Wrappers around the RewritePattern methods that pass the derived op type.
void rewrite(mlir::Operation *op, mlir::ArrayRef<mlir::Value> operands,
mlir::ConversionPatternRewriter &rewriter) const final {
rewrite(mlir::cast<SourceOp>(op),
OpAdaptor(operands, mlir::cast<SourceOp>(op)), rewriter);
}
void rewrite(mlir::Operation *op, mlir::ArrayRef<mlir::ValueRange> operands,
mlir::ConversionPatternRewriter &rewriter) const final {
auto sourceOp = llvm::cast<SourceOp>(op);
rewrite(llvm::cast<SourceOp>(op), OneToNOpAdaptor(operands, sourceOp),
rewriter);
}
llvm::LogicalResult match(mlir::Operation *op) const final {
return match(mlir::cast<SourceOp>(op));
}
llvm::LogicalResult
matchAndRewrite(mlir::Operation *op, mlir::ArrayRef<mlir::Value> operands,
mlir::ConversionPatternRewriter &rewriter) const final {
Expand All @@ -235,28 +220,12 @@ class FIROpConversion : public ConvertFIRToLLVMPattern {
return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
rewriter);
}
/// Rewrite and Match methods that operate on the SourceOp type. These must be
/// Methods that operate on the SourceOp type. These must be
/// overridden by the derived pattern class.
virtual llvm::LogicalResult match(SourceOp op) const {
llvm_unreachable("must override match or matchAndRewrite");
}
virtual void rewrite(SourceOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
llvm_unreachable("must override rewrite or matchAndRewrite");
}
virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
llvm::SmallVector<mlir::Value> oneToOneOperands =
getOneToOneAdaptorOperands(adaptor.getOperands());
rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
}
virtual llvm::LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
if (mlir::failed(match(op)))
return mlir::failure();
rewrite(op, adaptor, rewriter);
return mlir::success();
llvm_unreachable("matchAndRewrite is not implemented");
}
virtual llvm::LogicalResult
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
Expand All @@ -268,7 +237,6 @@ class FIROpConversion : public ConvertFIRToLLVMPattern {

private:
using ConvertFIRToLLVMPattern::matchAndRewrite;
using ConvertToLLVMPattern::match;
};

/// FIR conversion pattern template
Expand Down
25 changes: 15 additions & 10 deletions mlir/docs/PatternRewriter.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,22 +38,23 @@ possible cost and use the predicate to guard the match.
### Root Operation Name (Optional)

The name of the root operation that this pattern matches against. If specified,
only operations with the given root name will be provided to the `match` and
`rewrite` implementation. If not specified, any operation type may be provided.
The root operation name should be provided whenever possible, because it
simplifies the analysis of patterns when applying a cost model. To match any
only operations with the given root name will be provided to the
`matchAndRewrite` implementation. If not specified, any operation type may be
provided. The root operation name should be provided whenever possible, because
it simplifies the analysis of patterns when applying a cost model. To match any
operation type, a special tag must be provided to make the intent explicit:
`MatchAnyOpTypeTag`.

### `match` and `rewrite` implementation
### `matchAndRewrite` implementation

This is the chunk of code that matches a given root `Operation` and performs a
rewrite of the IR. A `RewritePattern` can specify this implementation either via
separate `match` and `rewrite` methods, or via a combined `matchAndRewrite`
method. When using the combined `matchAndRewrite` method, no IR mutation should
take place before the match is deemed successful. The combined `matchAndRewrite`
is useful when non-trivially recomputable information is required by the
matching and rewriting phase. See below for examples:
the `matchAndRewrite` method or via separate `match` and `rewrite` methods when
deriving from `RewritePattern::SplitMatchAndRewrite`. When using the combined
`matchAndRewrite` method, no IR mutation should take place before the match is
deemed successful. The combined `matchAndRewrite` is useful when non-trivially
recomputable information is required by the matching and rewriting phase. See
below for examples:

```c++
class MyPattern : public RewritePattern {
Expand Down Expand Up @@ -105,6 +106,10 @@ Within the `rewrite` section of a pattern, the following constraints apply:
`eraseOp`) should be used instead.
* The root operation is required to either be: updated in-place, replaced, or
erased.
* `matchAndRewrite` must return "success" if and only if the IR was modified.
`match` must return "success" if and only if the IR is going to be modified
during `rewrite`.
### Application Recursion
Expand Down
40 changes: 8 additions & 32 deletions mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ LogicalResult oneToOneRewrite(
/// during the entire pattern lifetime.
class ConvertToLLVMPattern : public ConversionPattern {
public:
using SplitMatchAndRewrite =
detail::ConversionSplitMatchAndRewriteImpl<ConvertToLLVMPattern>;

ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context,
const LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1);
Expand Down Expand Up @@ -142,9 +145,12 @@ class ConvertToLLVMPattern : public ConversionPattern {
template <typename SourceOp>
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
public:
using OperationT = SourceOp;
using OpAdaptor = typename SourceOp::Adaptor;
using OneToNOpAdaptor =
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
using SplitMatchAndRewrite = detail::ConversionSplitMatchAndRewriteImpl<
ConvertOpToLLVMPattern<SourceOp>>;

explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
Expand All @@ -153,19 +159,6 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
benefit) {}

/// Wrappers around the RewritePattern methods that pass the derived op type.
void rewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto sourceOp = cast<SourceOp>(op);
rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
}
void rewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
auto sourceOp = cast<SourceOp>(op);
rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
}
LogicalResult match(Operation *op) const final {
return match(cast<SourceOp>(op));
}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
Expand All @@ -180,28 +173,12 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
rewriter);
}

/// Rewrite and Match methods that operate on the SourceOp type. These must be
/// Methods that operate on the SourceOp type. One of these must be
/// overridden by the derived pattern class.
virtual LogicalResult match(SourceOp op) const {
llvm_unreachable("must override match or matchAndRewrite");
}
virtual void rewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("must override rewrite or matchAndRewrite");
}
virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
SmallVector<Value> oneToOneOperands =
getOneToOneAdaptorOperands(adaptor.getOperands());
rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
}
virtual LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (failed(match(op)))
return failure();
rewrite(op, adaptor, rewriter);
return success();
llvm_unreachable("matchAndRewrite is not implemented");
}
virtual LogicalResult
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
Expand All @@ -212,7 +189,6 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
}

private:
using ConvertToLLVMPattern::match;
using ConvertToLLVMPattern::matchAndRewrite;
};

Expand Down
94 changes: 47 additions & 47 deletions mlir/include/mlir/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,41 +234,52 @@ class Pattern {
// RewritePattern
//===----------------------------------------------------------------------===//

/// RewritePattern is the common base class for all DAG to DAG replacements.
/// There are two possible usages of this class:
/// * Multi-step RewritePattern with "match" and "rewrite"
/// - By overloading the "match" and "rewrite" functions, the user can
/// separate the concerns of matching and rewriting.
/// * Single-step RewritePattern with "matchAndRewrite"
/// - By overloading the "matchAndRewrite" function, the user can perform
/// the rewrite in the same call as the match.
///
class RewritePattern : public Pattern {
public:
virtual ~RewritePattern() = default;
namespace detail {
/// Helper class that derives from a RewritePattern class and provides separate
/// `match` and `rewrite` entry points instead of a combined `matchAndRewrite`.
template <typename PatternT>
class SplitMatchAndRewriteImpl : public PatternT {
using PatternT::PatternT;

/// Attempt to match against IR rooted at the specified operation, which is
/// the same operation kind as getRootKind().
///
/// Note: This function must not modify the IR.
virtual LogicalResult match(typename PatternT::OperationT op) const = 0;

/// Rewrite the IR rooted at the specified operation with the result of
/// this pattern, generating any new operations with the specified
/// builder. If an unexpected error is encountered (an internal
/// compiler error), it is emitted through the normal MLIR diagnostic
/// hooks and the IR is left in a valid state.
virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;

/// Attempt to match against code rooted at the specified operation,
/// which is the same operation code as getRootKind().
virtual LogicalResult match(Operation *op) const;
/// rewriter.
virtual void rewrite(typename PatternT::OperationT op,
PatternRewriter &rewriter) const = 0;

/// Attempt to match against code rooted at the specified operation,
/// which is the same operation code as getRootKind(). If successful, this
/// function will automatically perform the rewrite.
virtual LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
LogicalResult matchAndRewrite(typename PatternT::OperationT op,
PatternRewriter &rewriter) const final {
if (succeeded(match(op))) {
rewrite(op, rewriter);
return success();
}
return failure();
}
};
} // namespace detail

/// RewritePattern is the common base class for all DAG to DAG replacements.
class RewritePattern : public Pattern {
public:
using OperationT = Operation *;
using SplitMatchAndRewrite = detail::SplitMatchAndRewriteImpl<RewritePattern>;

virtual ~RewritePattern() = default;

/// Attempt to match against code rooted at the specified operation,
/// which is the same operation code as getRootKind(). If successful, perform
/// the rewrite.
///
/// Note: Implementations must modify the IR if and only if the function
/// returns "success".
virtual LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const = 0;

/// This method provides a convenient interface for creating and initializing
/// derived rewrite patterns of the given type `T`.
Expand Down Expand Up @@ -317,36 +328,19 @@ namespace detail {
/// class or Interface.
template <typename SourceOp>
struct OpOrInterfaceRewritePatternBase : public RewritePattern {
using OperationT = SourceOp;
using RewritePattern::RewritePattern;

/// Wrappers around the RewritePattern methods that pass the derived op type.
void rewrite(Operation *op, PatternRewriter &rewriter) const final {
rewrite(cast<SourceOp>(op), rewriter);
}
LogicalResult match(Operation *op) const final {
return match(cast<SourceOp>(op));
}
/// Wrapper around the RewritePattern method that passes the derived op type.
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const final {
return matchAndRewrite(cast<SourceOp>(op), rewriter);
}

/// Rewrite and Match methods that operate on the SourceOp type. These must be
/// overridden by the derived pattern class.
virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const {
llvm_unreachable("must override rewrite or matchAndRewrite");
}
virtual LogicalResult match(SourceOp op) const {
llvm_unreachable("must override match or matchAndRewrite");
}
/// Method that operates on the SourceOp type. Must be overridden by the
/// derived pattern class.
virtual LogicalResult matchAndRewrite(SourceOp op,
PatternRewriter &rewriter) const {
if (succeeded(match(op))) {
rewrite(op, rewriter);
return success();
}
return failure();
}
PatternRewriter &rewriter) const = 0;
};
} // namespace detail

Expand All @@ -356,6 +350,9 @@ struct OpOrInterfaceRewritePatternBase : public RewritePattern {
template <typename SourceOp>
struct OpRewritePattern
: public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
using SplitMatchAndRewrite =
detail::SplitMatchAndRewriteImpl<OpRewritePattern<SourceOp>>;

/// Patterns must specify the root operation name they match against, and can
/// also specify the benefit of the pattern matching and a list of generated
/// ops.
Expand All @@ -371,6 +368,9 @@ struct OpRewritePattern
template <typename SourceOp>
struct OpInterfaceRewritePattern
: public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
using SplitMatchAndRewrite =
detail::SplitMatchAndRewriteImpl<OpInterfaceRewritePattern<SourceOp>>;

OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
: detail::OpOrInterfaceRewritePatternBase<SourceOp>(
Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(),
Expand Down
Loading

0 comments on commit a6151f4

Please sign in to comment.