-
Notifications
You must be signed in to change notification settings - Fork 12.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][IR] Deprecate match
and rewrite
functions
#130031
[mlir][IR] Deprecate match
and rewrite
functions
#130031
Conversation
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir-core Author: Matthias Springer (matthias-springer) ChangesDeprecate the This is addressing a comment on an earlier PR. Note for LLVM integration: Patch is 34.08 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/130031.diff 11 Files Affected:
diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index abacd5a82c61e..f67d1411b3065 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -179,13 +179,13 @@ updated/remapped operands of an operation, such as when the types of results
defined by an operation have changed. The general Rewrite Patterns can no longer
be used in these situations, as the types of the operands of the operation being
matched will not correspond with those expected by the user. This pattern
-provides, as an additional argument to the `matchAndRewrite` and `rewrite`
-methods, the list of operands that the operation should use after conversion. If
-an operand was the result of a non-converted operation, for example if it was
-already legal, the original operand is used. This means that the operands
-provided always have a 1-1 non-null correspondence with the operands on the
-operation. The original operands of the operation are still intact and may be
-inspected as normal. These patterns also utilize a special `PatternRewriter`,
+provides, as an additional argument to the `matchAndRewrite` method, the list
+of operands that the operation should use after conversion. If an operand was
+the result of a non-converted operation, for example if it was already legal,
+the original operand is used. This means that the operands provided always have
+a 1-1 non-null correspondence with the operands on the operation. The original
+operands of the operation are still intact and may be inspected as normal.
+These patterns also utilize a special `PatternRewriter`,
`ConversionPatternRewriter`, that provides special hooks for use with the
conversion infrastructure.
diff --git a/mlir/docs/PatternRewriter.md b/mlir/docs/PatternRewriter.md
index af0f56466e0cb..105a554b95851 100644
--- a/mlir/docs/PatternRewriter.md
+++ b/mlir/docs/PatternRewriter.md
@@ -48,13 +48,9 @@ operation type, a special tag must be provided to make the intent explicit:
### `matchAndRewrite` implementation
This is the chunk of code that matches a given root `Operation` and performs a
-rewrite of the IR. A `RewritePattern` can specify this implementation either via
-the `matchAndRewrite` method or via separate `match` and `rewrite` methods when
-deriving from `RewritePattern::SplitMatchAndRewrite`. When using the combined
-`matchAndRewrite` method, no IR mutation should take place before the match is
-deemed successful. The combined `matchAndRewrite` is useful when non-trivially
-recomputable information is required by the matching and rewriting phase. See
-below for examples:
+rewrite of the IR. A `RewritePattern` can specify this implementation via the
+`matchAndRewrite` method. No IR mutation should take place before the match is
+deemed successful. See below for examples:
```c++
class MyPattern : public RewritePattern {
@@ -67,21 +63,6 @@ public:
MyPattern(PatternBenefit benefit)
: RewritePattern(benefit, MatchAnyOpTypeTag()) {}
- /// In this section, the `match` and `rewrite` implementation is specified
- /// using the separate hooks.
- LogicalResult match(Operation *op) const override {
- // The `match` method returns `success()` if the pattern is a match, failure
- // otherwise.
- // ...
- }
- void rewrite(Operation *op, PatternRewriter &rewriter) const override {
- // The `rewrite` method performs mutations on the IR rooted at `op` using
- // the provided rewriter. All mutations must go through the provided
- // rewriter.
- }
-
- /// In this section, the `match` and `rewrite` implementation is specified
- /// using a single hook.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override {
// The `matchAndRewrite` method performs both the matching and the mutation.
// Note that the match must reach a successful point before IR mutation may
@@ -92,12 +73,6 @@ public:
#### Restrictions
-Within the `match` section of a pattern, the following constraints apply:
-
-* No mutation of the IR is allowed.
-
-Within the `rewrite` section of a pattern, the following constraints apply:
-
* All IR mutations, including creation, *must* be performed by the given
`PatternRewriter`. This class provides hooks for performing all of the
possible mutations that may take place within a pattern. For example, this
@@ -107,8 +82,6 @@ Within the `rewrite` section of a pattern, the following constraints apply:
* The root operation is required to either be: updated in-place, replaced, or
erased.
* `matchAndRewrite` must return "success" if and only if the IR was modified.
- `match` must return "success" if and only if the IR is going to be modified
- during `rewrite`.
### Application Recursion
diff --git a/mlir/docs/Tutorials/QuickstartRewrites.md b/mlir/docs/Tutorials/QuickstartRewrites.md
index 604148bd9c600..493f9d5687374 100644
--- a/mlir/docs/Tutorials/QuickstartRewrites.md
+++ b/mlir/docs/Tutorials/QuickstartRewrites.md
@@ -216,25 +216,6 @@ In case ODS patterns and `matchAndRewrite`-style functions are not sufficient
you can also specify rewrites as a general set of `RewritePattern`s:
```c++
-/// Multi-step rewrite using "match" and "rewrite". This allows for separating
-/// the concerns of matching and rewriting.
-struct ConvertTFLeakyRelu : public RewritePattern {
- ConvertTFLeakyRelu(MLIRContext *context)
- : RewritePattern("tf.LeakyRelu", 1, context) {}
-
- LogicalResult match(Operation *op) const override {
- return success();
- }
-
- void rewrite(Operation *op, PatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
- op, op->getResult(0).getType(), op->getOperand(0),
- /*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
- }
-};
-
-/// Single-step rewrite with "matchAndRewrite". This allows for performing the
-/// rewrite immediately upon a successful match.
struct ConvertTFLeakyRelu : public RewritePattern {
ConvertTFLeakyRelu(MLIRContext *context)
: RewritePattern("tf.LeakyRelu", 1, context) {}
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 8f82176f3b75f..e78f174ff8586 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -40,6 +40,8 @@ LogicalResult oneToOneRewrite(
/// during the entire pattern lifetime.
class ConvertToLLVMPattern : public ConversionPattern {
public:
+ /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of
+ /// separate `match` and `rewrite`.
using SplitMatchAndRewrite =
detail::ConversionSplitMatchAndRewriteImpl<ConvertToLLVMPattern>;
@@ -149,6 +151,9 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
using OpAdaptor = typename SourceOp::Adaptor;
using OneToNOpAdaptor =
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
+
+ /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of
+ /// separate `match` and `rewrite`.
using SplitMatchAndRewrite = detail::ConversionSplitMatchAndRewriteImpl<
ConvertOpToLLVMPattern<SourceOp>>;
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 792b50d38817e..45d1471ecc1ef 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -237,6 +237,9 @@ class Pattern {
namespace detail {
/// Helper class that derives from a RewritePattern class and provides separate
/// `match` and `rewrite` entry points instead of a combined `matchAndRewrite`.
+///
+/// This class is deprecated. Use `matchAndRewrite` instead of separate `match`
+/// and `rewrite`.
template <typename PatternT>
class SplitMatchAndRewriteImpl : public PatternT {
using PatternT::PatternT;
@@ -245,13 +248,15 @@ class SplitMatchAndRewriteImpl : public PatternT {
/// the same operation kind as getRootKind().
///
/// Note: This function must not modify the IR.
- virtual LogicalResult match(typename PatternT::OperationT op) const = 0;
+ virtual LogicalResult
+ match(typename PatternT::OperationT op) const = 0;
/// Rewrite the IR rooted at the specified operation with the result of
/// this pattern, generating any new operations with the specified
/// rewriter.
- virtual void rewrite(typename PatternT::OperationT op,
- PatternRewriter &rewriter) const = 0;
+ virtual void
+ rewrite(typename PatternT::OperationT op,
+ PatternRewriter &rewriter) const = 0;
LogicalResult matchAndRewrite(typename PatternT::OperationT op,
PatternRewriter &rewriter) const final {
@@ -268,6 +273,9 @@ class SplitMatchAndRewriteImpl : public PatternT {
class RewritePattern : public Pattern {
public:
using OperationT = Operation *;
+
+ /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of
+ /// separate `match` and `rewrite`.
using SplitMatchAndRewrite = detail::SplitMatchAndRewriteImpl<RewritePattern>;
virtual ~RewritePattern() = default;
@@ -350,6 +358,9 @@ struct OpOrInterfaceRewritePatternBase : public RewritePattern {
template <typename SourceOp>
struct OpRewritePattern
: public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
+
+ /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of
+ /// separate `match` and `rewrite`.
using SplitMatchAndRewrite =
detail::SplitMatchAndRewriteImpl<OpRewritePattern<SourceOp>>;
@@ -368,6 +379,9 @@ struct OpRewritePattern
template <typename SourceOp>
struct OpInterfaceRewritePattern
: public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
+
+ /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of
+ /// separate `match` and `rewrite`.
using SplitMatchAndRewrite =
detail::SplitMatchAndRewriteImpl<OpInterfaceRewritePattern<SourceOp>>;
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 120709bbe5b67..f54397e942ae0 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -598,6 +598,9 @@ class ConversionPattern : public RewritePattern {
using OperationT = Operation *;
using OpAdaptor = ArrayRef<Value>;
using OneToNOpAdaptor = ArrayRef<ValueRange>;
+
+ /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of
+ /// separate `match` and `rewrite`.
using SplitMatchAndRewrite =
detail::ConversionSplitMatchAndRewriteImpl<ConversionPattern>;
@@ -669,6 +672,9 @@ class OpConversionPattern : public ConversionPattern {
using OpAdaptor = typename SourceOp::Adaptor;
using OneToNOpAdaptor =
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
+
+ /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of
+ /// separate `match` and `rewrite`.
using SplitMatchAndRewrite =
detail::ConversionSplitMatchAndRewriteImpl<OpConversionPattern<SourceOp>>;
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 734c4839f9a10..27be54728c1a1 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -41,48 +41,46 @@ struct ArithToAMDGPUConversionPass final
void runOnOperation() override;
};
-struct ExtFOnFloat8RewritePattern final
- : OpRewritePattern<arith::ExtFOp>::SplitMatchAndRewrite {
- using SplitMatchAndRewrite::SplitMatchAndRewrite;
+struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
+ using OpRewritePattern::OpRewritePattern;
Chipset chipset;
ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset)
- : SplitMatchAndRewrite::SplitMatchAndRewrite(ctx), chipset(chipset) {}
+ : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}
- LogicalResult match(arith::ExtFOp op) const override;
- void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
+ LogicalResult matchAndRewrite(arith::ExtFOp op,
+ PatternRewriter &rewriter) const override;
};
-struct TruncFToFloat8RewritePattern final
- : OpRewritePattern<arith::TruncFOp>::SplitMatchAndRewrite {
+struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
bool saturateFP8 = false;
TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
Chipset chipset)
- : SplitMatchAndRewrite::SplitMatchAndRewrite(ctx),
- saturateFP8(saturateFP8), chipset(chipset) {}
+ : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8),
+ chipset(chipset) {}
Chipset chipset;
- LogicalResult match(arith::TruncFOp op) const override;
- void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
+ LogicalResult matchAndRewrite(arith::TruncFOp op,
+ PatternRewriter &rewriter) const override;
};
struct TruncfToFloat16RewritePattern final
- : public OpRewritePattern<arith::TruncFOp>::SplitMatchAndRewrite {
+ : public OpRewritePattern<arith::TruncFOp> {
- using SplitMatchAndRewrite::SplitMatchAndRewrite;
+ using OpRewritePattern::OpRewritePattern;
- LogicalResult match(arith::TruncFOp op) const override;
- void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
+ LogicalResult matchAndRewrite(arith::TruncFOp op,
+ PatternRewriter &rewriter) const override;
};
} // end namespace
-static LogicalResult isSupportedF8(Type elementType, Chipset chipset) {
+static bool isSupportedF8(Type elementType, Chipset chipset) {
if (chipset == kGfx942)
- return success(isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(elementType));
+ return isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(elementType);
if (hasOcpFp8(chipset))
- return success(isa<Float8E4M3FNType, Float8E5M2Type>(elementType));
- return failure();
+ return isa<Float8E4M3FNType, Float8E5M2Type>(elementType);
+ return false;
}
static Value castF32To(Type elementType, Value f32, Location loc,
@@ -96,35 +94,36 @@ static Value castF32To(Type elementType, Value f32, Location loc,
llvm_unreachable("The only 32-bit float type is f32");
}
-LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
+LogicalResult
+ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
+ PatternRewriter &rewriter) const {
Type inType = op.getIn().getType();
- if (auto inVecType = dyn_cast<VectorType>(inType)) {
+ auto inVecType = dyn_cast<VectorType>(inType);
+ if (inVecType) {
if (inVecType.isScalable())
return failure();
inType = inVecType.getElementType();
}
- return isSupportedF8(inType, chipset);
-}
+ if (!isSupportedF8(inType, chipset))
+ return failure();
-void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
- PatternRewriter &rewriter) const {
Location loc = op.getLoc();
Value in = op.getIn();
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
- auto inType = dyn_cast<VectorType>(in.getType());
- if (!inType) {
+ if (!inVecType) {
Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
loc, rewriter.getF32Type(), in, 0);
Value result = castF32To(outElemType, asFloat, loc, rewriter);
- return rewriter.replaceOp(op, result);
+ rewriter.replaceOp(op, result);
+ return success();
}
- int64_t numElements = inType.getNumElements();
+ int64_t numElements = inVecType.getNumElements();
Value zero = rewriter.create<arith::ConstantOp>(
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
VectorType outType = cast<VectorType>(op.getOut().getType());
- if (inType.getShape().empty()) {
+ if (inVecType.getShape().empty()) {
Value zerodSplat =
rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
Value scalarIn =
@@ -133,17 +132,18 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zerodSplat,
ArrayRef<int64_t>{});
- return rewriter.replaceOp(op, result);
+ rewriter.replaceOp(op, result);
+ return success();
}
VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
outType.getElementType());
Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
- if (inType.getRank() > 1) {
- inType = VectorType::get(SmallVector<int64_t>{numElements},
- inType.getElementType());
- in = rewriter.create<vector::ShapeCastOp>(loc, inType, in);
+ if (inVecType.getRank() > 1) {
+ inVecType = VectorType::get(SmallVector<int64_t>{numElements},
+ inVecType.getElementType());
+ in = rewriter.create<vector::ShapeCastOp>(loc, inVecType, in);
}
for (int64_t i = 0; i < numElements; i += 4) {
@@ -158,11 +158,12 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
}
}
- if (inType.getRank() != outType.getRank()) {
+ if (inVecType.getRank() != outType.getRank()) {
result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
}
rewriter.replaceOp(op, result);
+ return success();
}
static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
@@ -222,12 +223,15 @@ static Value clampInput(PatternRewriter &rewriter, Location loc,
return res;
}
-LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
+LogicalResult
+TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
+ PatternRewriter &rewriter) const {
// Only supporting default rounding mode as of now.
if (op.getRoundingmodeAttr())
return failure();
Type outType = op.getOut().getType();
- if (auto outVecType = dyn_cast<VectorType>(outType)) {
+ auto outVecType = dyn_cast<VectorType>(outType);
+ if (outVecType) {
if (outVecType.isScalable())
return failure();
outType = outVecType.getElementType();
@@ -237,11 +241,9 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
// Conversion between 8-bit floats is not supported with truncation enabled.
return failure();
- return isSupportedF8(outType, chipset);
-}
+ if (!isSupportedF8(outType, chipset))
+ return failure();
-void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
- PatternRewriter &rewriter) const {
Location loc = op.getLoc();
Value in = op.getIn();
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
@@ -255,13 +257,14 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
/*existing=*/nullptr);
Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0);
- return rewriter.replaceOp(op, result);
+ rewriter.replaceOp(op, result);
+ return success();
}
- VectorType outType = cast<VectorType>(op.getOut().getType());
- int64_t numElements = outType.getNumElements();
+
+ int64_t numElements = outVecType.getNumElements();
Value zero = rewriter.create<arith::ConstantOp>(
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
- if (outType.getShape().empty()) {
+ if (outVecType.getShape().empty()) {
Value scalarIn =
rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
// Recurse to send the 0-D vector case to the 1-D vector case
@@ -269,11 +272,12 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
Value result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
ArrayRef<int64_t>{});
- return rewriter.replaceOp(op, result);
+ rewriter.replaceOp(op, result);
+ return success();
}
VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
- outType.getElementType());
+ ...
[truncated]
|
@llvm/pr-subscribers-mlir-llvm Author: Matthias Springer (matthias-springer) ChangesDeprecate the This is addressing a comment on an earlier PR. Note for LLVM integration: Patch is 34.08 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/130031.diff 11 Files Affected:
diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index abacd5a82c61e..f67d1411b3065 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -179,13 +179,13 @@ updated/remapped operands of an operation, such as when the types of results
defined by an operation have changed. The general Rewrite Patterns can no longer
be used in these situations, as the types of the operands of the operation being
matched will not correspond with those expected by the user. This pattern
-provides, as an additional argument to the `matchAndRewrite` and `rewrite`
-methods, the list of operands that the operation should use after conversion. If
-an operand was the result of a non-converted operation, for example if it was
-already legal, the original operand is used. This means that the operands
-provided always have a 1-1 non-null correspondence with the operands on the
-operation. The original operands of the operation are still intact and may be
-inspected as normal. These patterns also utilize a special `PatternRewriter`,
+provides, as an additional argument to the `matchAndRewrite` method, the list
+of operands that the operation should use after conversion. If an operand was
+the result of a non-converted operation, for example if it was already legal,
+the original operand is used. This means that the operands provided always have
+a 1-1 non-null correspondence with the operands on the operation. The original
+operands of the operation are still intact and may be inspected as normal.
+These patterns also utilize a special `PatternRewriter`,
`ConversionPatternRewriter`, that provides special hooks for use with the
conversion infrastructure.
diff --git a/mlir/docs/PatternRewriter.md b/mlir/docs/PatternRewriter.md
index af0f56466e0cb..105a554b95851 100644
--- a/mlir/docs/PatternRewriter.md
+++ b/mlir/docs/PatternRewriter.md
@@ -48,13 +48,9 @@ operation type, a special tag must be provided to make the intent explicit:
### `matchAndRewrite` implementation
This is the chunk of code that matches a given root `Operation` and performs a
-rewrite of the IR. A `RewritePattern` can specify this implementation either via
-the `matchAndRewrite` method or via separate `match` and `rewrite` methods when
-deriving from `RewritePattern::SplitMatchAndRewrite`. When using the combined
-`matchAndRewrite` method, no IR mutation should take place before the match is
-deemed successful. The combined `matchAndRewrite` is useful when non-trivially
-recomputable information is required by the matching and rewriting phase. See
-below for examples:
+rewrite of the IR. A `RewritePattern` can specify this implementation via the
+`matchAndRewrite` method. No IR mutation should take place before the match is
+deemed successful. See below for examples:
```c++
class MyPattern : public RewritePattern {
@@ -67,21 +63,6 @@ public:
MyPattern(PatternBenefit benefit)
: RewritePattern(benefit, MatchAnyOpTypeTag()) {}
- /// In this section, the `match` and `rewrite` implementation is specified
- /// using the separate hooks.
- LogicalResult match(Operation *op) const override {
- // The `match` method returns `success()` if the pattern is a match, failure
- // otherwise.
- // ...
- }
- void rewrite(Operation *op, PatternRewriter &rewriter) const override {
- // The `rewrite` method performs mutations on the IR rooted at `op` using
- // the provided rewriter. All mutations must go through the provided
- // rewriter.
- }
-
- /// In this section, the `match` and `rewrite` implementation is specified
- /// using a single hook.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override {
// The `matchAndRewrite` method performs both the matching and the mutation.
// Note that the match must reach a successful point before IR mutation may
@@ -92,12 +73,6 @@ public:
#### Restrictions
-Within the `match` section of a pattern, the following constraints apply:
-
-* No mutation of the IR is allowed.
-
-Within the `rewrite` section of a pattern, the following constraints apply:
-
* All IR mutations, including creation, *must* be performed by the given
`PatternRewriter`. This class provides hooks for performing all of the
possible mutations that may take place within a pattern. For example, this
@@ -107,8 +82,6 @@ Within the `rewrite` section of a pattern, the following constraints apply:
* The root operation is required to either be: updated in-place, replaced, or
erased.
* `matchAndRewrite` must return "success" if and only if the IR was modified.
- `match` must return "success" if and only if the IR is going to be modified
- during `rewrite`.
### Application Recursion
diff --git a/mlir/docs/Tutorials/QuickstartRewrites.md b/mlir/docs/Tutorials/QuickstartRewrites.md
index 604148bd9c600..493f9d5687374 100644
--- a/mlir/docs/Tutorials/QuickstartRewrites.md
+++ b/mlir/docs/Tutorials/QuickstartRewrites.md
@@ -216,25 +216,6 @@ In case ODS patterns and `matchAndRewrite`-style functions are not sufficient
you can also specify rewrites as a general set of `RewritePattern`s:
```c++
-/// Multi-step rewrite using "match" and "rewrite". This allows for separating
-/// the concerns of matching and rewriting.
-struct ConvertTFLeakyRelu : public RewritePattern {
- ConvertTFLeakyRelu(MLIRContext *context)
- : RewritePattern("tf.LeakyRelu", 1, context) {}
-
- LogicalResult match(Operation *op) const override {
- return success();
- }
-
- void rewrite(Operation *op, PatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
- op, op->getResult(0).getType(), op->getOperand(0),
- /*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
- }
-};
-
-/// Single-step rewrite with "matchAndRewrite". This allows for performing the
-/// rewrite immediately upon a successful match.
struct ConvertTFLeakyRelu : public RewritePattern {
ConvertTFLeakyRelu(MLIRContext *context)
: RewritePattern("tf.LeakyRelu", 1, context) {}
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 8f82176f3b75f..e78f174ff8586 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -40,6 +40,8 @@ LogicalResult oneToOneRewrite(
/// during the entire pattern lifetime.
class ConvertToLLVMPattern : public ConversionPattern {
public:
+ /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of
+ /// separate `match` and `rewrite`.
using SplitMatchAndRewrite =
detail::ConversionSplitMatchAndRewriteImpl<ConvertToLLVMPattern>;
@@ -149,6 +151,9 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
using OpAdaptor = typename SourceOp::Adaptor;
using OneToNOpAdaptor =
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
+
+ /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of
+ /// separate `match` and `rewrite`.
using SplitMatchAndRewrite = detail::ConversionSplitMatchAndRewriteImpl<
ConvertOpToLLVMPattern<SourceOp>>;
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 792b50d38817e..45d1471ecc1ef 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -237,6 +237,9 @@ class Pattern {
namespace detail {
/// Helper class that derives from a RewritePattern class and provides separate
/// `match` and `rewrite` entry points instead of a combined `matchAndRewrite`.
+///
+/// This class is deprecated. Use `matchAndRewrite` instead of separate `match`
+/// and `rewrite`.
template <typename PatternT>
class SplitMatchAndRewriteImpl : public PatternT {
using PatternT::PatternT;
@@ -245,13 +248,15 @@ class SplitMatchAndRewriteImpl : public PatternT {
/// the same operation kind as getRootKind().
///
/// Note: This function must not modify the IR.
- virtual LogicalResult match(typename PatternT::OperationT op) const = 0;
+ virtual LogicalResult
+ match(typename PatternT::OperationT op) const = 0;
/// Rewrite the IR rooted at the specified operation with the result of
/// this pattern, generating any new operations with the specified
/// rewriter.
- virtual void rewrite(typename PatternT::OperationT op,
- PatternRewriter &rewriter) const = 0;
+ virtual void
+ rewrite(typename PatternT::OperationT op,
+ PatternRewriter &rewriter) const = 0;
LogicalResult matchAndRewrite(typename PatternT::OperationT op,
PatternRewriter &rewriter) const final {
@@ -268,6 +273,9 @@ class SplitMatchAndRewriteImpl : public PatternT {
class RewritePattern : public Pattern {
public:
using OperationT = Operation *;
+
+ /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of
+ /// separate `match` and `rewrite`.
using SplitMatchAndRewrite = detail::SplitMatchAndRewriteImpl<RewritePattern>;
virtual ~RewritePattern() = default;
@@ -350,6 +358,9 @@ struct OpOrInterfaceRewritePatternBase : public RewritePattern {
template <typename SourceOp>
struct OpRewritePattern
: public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
+
+ /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of
+ /// separate `match` and `rewrite`.
using SplitMatchAndRewrite =
detail::SplitMatchAndRewriteImpl<OpRewritePattern<SourceOp>>;
@@ -368,6 +379,9 @@ struct OpRewritePattern
template <typename SourceOp>
struct OpInterfaceRewritePattern
: public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
+
+ /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of
+ /// separate `match` and `rewrite`.
using SplitMatchAndRewrite =
detail::SplitMatchAndRewriteImpl<OpInterfaceRewritePattern<SourceOp>>;
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 120709bbe5b67..f54397e942ae0 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -598,6 +598,9 @@ class ConversionPattern : public RewritePattern {
using OperationT = Operation *;
using OpAdaptor = ArrayRef<Value>;
using OneToNOpAdaptor = ArrayRef<ValueRange>;
+
+ /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of
+ /// separate `match` and `rewrite`.
using SplitMatchAndRewrite =
detail::ConversionSplitMatchAndRewriteImpl<ConversionPattern>;
@@ -669,6 +672,9 @@ class OpConversionPattern : public ConversionPattern {
using OpAdaptor = typename SourceOp::Adaptor;
using OneToNOpAdaptor =
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
+
+ /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of
+ /// separate `match` and `rewrite`.
using SplitMatchAndRewrite =
detail::ConversionSplitMatchAndRewriteImpl<OpConversionPattern<SourceOp>>;
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 734c4839f9a10..27be54728c1a1 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -41,48 +41,46 @@ struct ArithToAMDGPUConversionPass final
void runOnOperation() override;
};
-struct ExtFOnFloat8RewritePattern final
- : OpRewritePattern<arith::ExtFOp>::SplitMatchAndRewrite {
- using SplitMatchAndRewrite::SplitMatchAndRewrite;
+struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
+ using OpRewritePattern::OpRewritePattern;
Chipset chipset;
ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset)
- : SplitMatchAndRewrite::SplitMatchAndRewrite(ctx), chipset(chipset) {}
+ : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}
- LogicalResult match(arith::ExtFOp op) const override;
- void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
+ LogicalResult matchAndRewrite(arith::ExtFOp op,
+ PatternRewriter &rewriter) const override;
};
-struct TruncFToFloat8RewritePattern final
- : OpRewritePattern<arith::TruncFOp>::SplitMatchAndRewrite {
+struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
bool saturateFP8 = false;
TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
Chipset chipset)
- : SplitMatchAndRewrite::SplitMatchAndRewrite(ctx),
- saturateFP8(saturateFP8), chipset(chipset) {}
+ : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8),
+ chipset(chipset) {}
Chipset chipset;
- LogicalResult match(arith::TruncFOp op) const override;
- void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
+ LogicalResult matchAndRewrite(arith::TruncFOp op,
+ PatternRewriter &rewriter) const override;
};
struct TruncfToFloat16RewritePattern final
- : public OpRewritePattern<arith::TruncFOp>::SplitMatchAndRewrite {
+ : public OpRewritePattern<arith::TruncFOp> {
- using SplitMatchAndRewrite::SplitMatchAndRewrite;
+ using OpRewritePattern::OpRewritePattern;
- LogicalResult match(arith::TruncFOp op) const override;
- void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
+ LogicalResult matchAndRewrite(arith::TruncFOp op,
+ PatternRewriter &rewriter) const override;
};
} // end namespace
-static LogicalResult isSupportedF8(Type elementType, Chipset chipset) {
+static bool isSupportedF8(Type elementType, Chipset chipset) {
if (chipset == kGfx942)
- return success(isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(elementType));
+ return isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(elementType);
if (hasOcpFp8(chipset))
- return success(isa<Float8E4M3FNType, Float8E5M2Type>(elementType));
- return failure();
+ return isa<Float8E4M3FNType, Float8E5M2Type>(elementType);
+ return false;
}
static Value castF32To(Type elementType, Value f32, Location loc,
@@ -96,35 +94,36 @@ static Value castF32To(Type elementType, Value f32, Location loc,
llvm_unreachable("The only 32-bit float type is f32");
}
-LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
+LogicalResult
+ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
+ PatternRewriter &rewriter) const {
Type inType = op.getIn().getType();
- if (auto inVecType = dyn_cast<VectorType>(inType)) {
+ auto inVecType = dyn_cast<VectorType>(inType);
+ if (inVecType) {
if (inVecType.isScalable())
return failure();
inType = inVecType.getElementType();
}
- return isSupportedF8(inType, chipset);
-}
+ if (!isSupportedF8(inType, chipset))
+ return failure();
-void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
- PatternRewriter &rewriter) const {
Location loc = op.getLoc();
Value in = op.getIn();
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
- auto inType = dyn_cast<VectorType>(in.getType());
- if (!inType) {
+ if (!inVecType) {
Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
loc, rewriter.getF32Type(), in, 0);
Value result = castF32To(outElemType, asFloat, loc, rewriter);
- return rewriter.replaceOp(op, result);
+ rewriter.replaceOp(op, result);
+ return success();
}
- int64_t numElements = inType.getNumElements();
+ int64_t numElements = inVecType.getNumElements();
Value zero = rewriter.create<arith::ConstantOp>(
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
VectorType outType = cast<VectorType>(op.getOut().getType());
- if (inType.getShape().empty()) {
+ if (inVecType.getShape().empty()) {
Value zerodSplat =
rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
Value scalarIn =
@@ -133,17 +132,18 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zerodSplat,
ArrayRef<int64_t>{});
- return rewriter.replaceOp(op, result);
+ rewriter.replaceOp(op, result);
+ return success();
}
VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
outType.getElementType());
Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
- if (inType.getRank() > 1) {
- inType = VectorType::get(SmallVector<int64_t>{numElements},
- inType.getElementType());
- in = rewriter.create<vector::ShapeCastOp>(loc, inType, in);
+ if (inVecType.getRank() > 1) {
+ inVecType = VectorType::get(SmallVector<int64_t>{numElements},
+ inVecType.getElementType());
+ in = rewriter.create<vector::ShapeCastOp>(loc, inVecType, in);
}
for (int64_t i = 0; i < numElements; i += 4) {
@@ -158,11 +158,12 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
}
}
- if (inType.getRank() != outType.getRank()) {
+ if (inVecType.getRank() != outType.getRank()) {
result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
}
rewriter.replaceOp(op, result);
+ return success();
}
static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
@@ -222,12 +223,15 @@ static Value clampInput(PatternRewriter &rewriter, Location loc,
return res;
}
-LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
+LogicalResult
+TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
+ PatternRewriter &rewriter) const {
// Only supporting default rounding mode as of now.
if (op.getRoundingmodeAttr())
return failure();
Type outType = op.getOut().getType();
- if (auto outVecType = dyn_cast<VectorType>(outType)) {
+ auto outVecType = dyn_cast<VectorType>(outType);
+ if (outVecType) {
if (outVecType.isScalable())
return failure();
outType = outVecType.getElementType();
@@ -237,11 +241,9 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
// Conversion between 8-bit floats is not supported with truncation enabled.
return failure();
- return isSupportedF8(outType, chipset);
-}
+ if (!isSupportedF8(outType, chipset))
+ return failure();
-void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
- PatternRewriter &rewriter) const {
Location loc = op.getLoc();
Value in = op.getIn();
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
@@ -255,13 +257,14 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
/*existing=*/nullptr);
Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0);
- return rewriter.replaceOp(op, result);
+ rewriter.replaceOp(op, result);
+ return success();
}
- VectorType outType = cast<VectorType>(op.getOut().getType());
- int64_t numElements = outType.getNumElements();
+
+ int64_t numElements = outVecType.getNumElements();
Value zero = rewriter.create<arith::ConstantOp>(
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
- if (outType.getShape().empty()) {
+ if (outVecType.getShape().empty()) {
Value scalarIn =
rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
// Recurse to send the 0-D vector case to the 1-D vector case
@@ -269,11 +272,12 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
Value result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
ArrayRef<int64_t>{});
- return rewriter.replaceOp(op, result);
+ rewriter.replaceOp(op, result);
+ return success();
}
VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
- outType.getElementType());
+ ...
[truncated]
|
90dcd17
to
eabdc70
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense to me
/// This class is deprecated. Use `matchAndRewrite` instead of separate `match` | ||
/// and `rewrite`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you plan to add a deprecated attribute around this in the future?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried that, but it gives warnings because matchAndRewrite
in this file calls match
and rewrite
.
Co-authored-by: Jakub Kuderski <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for cleanup!
Deprecate the
match
andrewrite
functions. They mainly exist for historic reasons. This PR also updates all remaining uses of in the MLIR codebase.This is addressing a comment on an earlier PR.
Note for LLVM integration:
SplitMatchAndRewrite
will be deleted soon, update your patterns to usematchAndRewrite
instead of separatematch
/rewrite
.