Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][IR] Move match and rewrite functions into separate class #129861

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 2 additions & 34 deletions flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,6 @@ class ConvertFIRToLLVMPattern : public mlir::ConvertToLLVMPattern {

const fir::FIRToLLVMPassOptions &options;

using ConvertToLLVMPattern::match;
using ConvertToLLVMPattern::matchAndRewrite;
};

Expand All @@ -206,20 +205,6 @@ class FIROpConversion : public ConvertFIRToLLVMPattern {
options, benefit) {}

/// Wrappers around the RewritePattern methods that pass the derived op type.
void rewrite(mlir::Operation *op, mlir::ArrayRef<mlir::Value> operands,
mlir::ConversionPatternRewriter &rewriter) const final {
rewrite(mlir::cast<SourceOp>(op),
OpAdaptor(operands, mlir::cast<SourceOp>(op)), rewriter);
}
void rewrite(mlir::Operation *op, mlir::ArrayRef<mlir::ValueRange> operands,
mlir::ConversionPatternRewriter &rewriter) const final {
auto sourceOp = llvm::cast<SourceOp>(op);
rewrite(llvm::cast<SourceOp>(op), OneToNOpAdaptor(operands, sourceOp),
rewriter);
}
llvm::LogicalResult match(mlir::Operation *op) const final {
return match(mlir::cast<SourceOp>(op));
}
llvm::LogicalResult
matchAndRewrite(mlir::Operation *op, mlir::ArrayRef<mlir::Value> operands,
mlir::ConversionPatternRewriter &rewriter) const final {
Expand All @@ -235,28 +220,12 @@ class FIROpConversion : public ConvertFIRToLLVMPattern {
return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
rewriter);
}
/// Rewrite and Match methods that operate on the SourceOp type. These must be
/// Methods that operate on the SourceOp type. These must be
/// overridden by the derived pattern class.
virtual llvm::LogicalResult match(SourceOp op) const {
llvm_unreachable("must override match or matchAndRewrite");
}
virtual void rewrite(SourceOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
llvm_unreachable("must override rewrite or matchAndRewrite");
}
virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
llvm::SmallVector<mlir::Value> oneToOneOperands =
getOneToOneAdaptorOperands(adaptor.getOperands());
rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
}
virtual llvm::LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
if (mlir::failed(match(op)))
return mlir::failure();
rewrite(op, adaptor, rewriter);
return mlir::success();
llvm_unreachable("matchAndRewrite is not implemented");
}
virtual llvm::LogicalResult
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
Expand All @@ -268,7 +237,6 @@ class FIROpConversion : public ConvertFIRToLLVMPattern {

private:
using ConvertFIRToLLVMPattern::matchAndRewrite;
using ConvertToLLVMPattern::match;
};

/// FIR conversion pattern template
Expand Down
40 changes: 8 additions & 32 deletions mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ LogicalResult oneToOneRewrite(
/// during the entire pattern lifetime.
class ConvertToLLVMPattern : public ConversionPattern {
public:
using SplitMatchAndRewrite =
detail::ConversionSplitMatchAndRewriteImpl<ConvertToLLVMPattern>;

ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context,
const LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1);
Expand Down Expand Up @@ -142,9 +145,12 @@ class ConvertToLLVMPattern : public ConversionPattern {
template <typename SourceOp>
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
public:
using OperationT = SourceOp;
using OpAdaptor = typename SourceOp::Adaptor;
using OneToNOpAdaptor =
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
using SplitMatchAndRewrite = detail::ConversionSplitMatchAndRewriteImpl<
ConvertOpToLLVMPattern<SourceOp>>;

explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
Expand All @@ -153,19 +159,6 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
benefit) {}

/// Wrappers around the RewritePattern methods that pass the derived op type.
void rewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto sourceOp = cast<SourceOp>(op);
rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
}
void rewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
auto sourceOp = cast<SourceOp>(op);
rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
}
LogicalResult match(Operation *op) const final {
return match(cast<SourceOp>(op));
}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
Expand All @@ -180,28 +173,12 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
rewriter);
}

/// Rewrite and Match methods that operate on the SourceOp type. These must be
/// Methods that operate on the SourceOp type. One of these must be
/// overridden by the derived pattern class.
virtual LogicalResult match(SourceOp op) const {
llvm_unreachable("must override match or matchAndRewrite");
}
virtual void rewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("must override rewrite or matchAndRewrite");
}
virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
SmallVector<Value> oneToOneOperands =
getOneToOneAdaptorOperands(adaptor.getOperands());
rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
}
virtual LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (failed(match(op)))
return failure();
rewrite(op, adaptor, rewriter);
return success();
llvm_unreachable("matchAndRewrite is not implemented");
}
virtual LogicalResult
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
Expand All @@ -212,7 +189,6 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
}

private:
using ConvertToLLVMPattern::match;
using ConvertToLLVMPattern::matchAndRewrite;
};

Expand Down
86 changes: 42 additions & 44 deletions mlir/include/mlir/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,41 +234,50 @@ class Pattern {
// RewritePattern
//===----------------------------------------------------------------------===//

/// RewritePattern is the common base class for all DAG to DAG replacements.
/// There are two possible usages of this class:
/// * Multi-step RewritePattern with "match" and "rewrite"
/// - By overloading the "match" and "rewrite" functions, the user can
/// separate the concerns of matching and rewriting.
/// * Single-step RewritePattern with "matchAndRewrite"
/// - By overloading the "matchAndRewrite" function, the user can perform
/// the rewrite in the same call as the match.
///
class RewritePattern : public Pattern {
public:
virtual ~RewritePattern() = default;
namespace detail {
/// Helper class that derives from a RewritePattern class and provides separate
/// `match` and `rewrite` entry points instead of a combined `matchAndRewrite`.
template <typename PatternT>
class SplitMatchAndRewriteImpl : public PatternT {
using PatternT::PatternT;

/// Rewrite the IR rooted at the specified operation with the result of
/// this pattern, generating any new operations with the specified
/// builder. If an unexpected error is encountered (an internal
/// compiler error), it is emitted through the normal MLIR diagnostic
/// hooks and the IR is left in a valid state.
virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;
/// rewriter.
virtual void rewrite(typename PatternT::OperationT op,
PatternRewriter &rewriter) const = 0;

/// Attempt to match against code rooted at the specified operation,
/// which is the same operation code as getRootKind().
virtual LogicalResult match(Operation *op) const;
virtual LogicalResult match(typename PatternT::OperationT op) const = 0;

/// Attempt to match against code rooted at the specified operation,
/// which is the same operation code as getRootKind(). If successful, this
/// function will automatically perform the rewrite.
virtual LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
LogicalResult matchAndRewrite(typename PatternT::OperationT op,
PatternRewriter &rewriter) const final {
if (succeeded(match(op))) {
rewrite(op, rewriter);
return success();
}
return failure();
}
};
} // namespace detail

/// RewritePattern is the common base class for all DAG to DAG replacements.
/// By overloading the "matchAndRewrite" function, the user can perform the
/// rewrite in the same call as the match.
///
class RewritePattern : public Pattern {
public:
using OperationT = Operation *;
using SplitMatchAndRewrite = detail::SplitMatchAndRewriteImpl<RewritePattern>;

virtual ~RewritePattern() = default;

/// Attempt to match against code rooted at the specified operation,
/// which is the same operation code as getRootKind(). If successful, this
/// function will automatically perform the rewrite.
virtual LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const = 0;

/// This method provides a convenient interface for creating and initializing
/// derived rewrite patterns of the given type `T`.
Expand Down Expand Up @@ -317,36 +326,19 @@ namespace detail {
/// class or Interface.
template <typename SourceOp>
struct OpOrInterfaceRewritePatternBase : public RewritePattern {
using OperationT = SourceOp;
using RewritePattern::RewritePattern;

/// Wrappers around the RewritePattern methods that pass the derived op type.
void rewrite(Operation *op, PatternRewriter &rewriter) const final {
rewrite(cast<SourceOp>(op), rewriter);
}
LogicalResult match(Operation *op) const final {
return match(cast<SourceOp>(op));
}
/// Wrapper around the RewritePattern method that passes the derived op type.
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const final {
return matchAndRewrite(cast<SourceOp>(op), rewriter);
}

/// Rewrite and Match methods that operate on the SourceOp type. These must be
/// overridden by the derived pattern class.
virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const {
llvm_unreachable("must override rewrite or matchAndRewrite");
}
virtual LogicalResult match(SourceOp op) const {
llvm_unreachable("must override match or matchAndRewrite");
}
/// Method that operates on the SourceOp type. Must be overridden by the
/// derived pattern class.
virtual LogicalResult matchAndRewrite(SourceOp op,
PatternRewriter &rewriter) const {
if (succeeded(match(op))) {
rewrite(op, rewriter);
return success();
}
return failure();
}
PatternRewriter &rewriter) const = 0;
};
} // namespace detail

Expand All @@ -356,6 +348,9 @@ struct OpOrInterfaceRewritePatternBase : public RewritePattern {
template <typename SourceOp>
struct OpRewritePattern
: public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
using SplitMatchAndRewrite =
detail::SplitMatchAndRewriteImpl<OpRewritePattern<SourceOp>>;

/// Patterns must specify the root operation name they match against, and can
/// also specify the benefit of the pattern matching and a list of generated
/// ops.
Expand All @@ -371,6 +366,9 @@ struct OpRewritePattern
template <typename SourceOp>
struct OpInterfaceRewritePattern
: public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
using SplitMatchAndRewrite =
detail::SplitMatchAndRewriteImpl<OpInterfaceRewritePattern<SourceOp>>;

OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
: detail::OpOrInterfaceRewritePatternBase<SourceOp>(
Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(),
Expand Down
Loading