diff --git a/BUILD.bazel b/BUILD.bazel index 6057c87dd7..3a7f4375f1 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -1058,6 +1058,7 @@ cc_library( "stablehlo/transforms/StablehloConvertToSignless.cpp", "stablehlo/transforms/StablehloLegalizeCompositeToCall.cpp", "stablehlo/transforms/StablehloLegalizeDeprecatedOps.cpp", + "stablehlo/transforms/StablehloLegalizeQDQToQuantizedOp.cpp", "stablehlo/transforms/StablehloLegalizeQuantToMath.cpp", "stablehlo/transforms/StablehloLegalizeQuantizedOpToQDQ.cpp", "stablehlo/transforms/StablehloLegalizeToVhlo.cpp", diff --git a/docs/generated/stablehlo_passes.md b/docs/generated/stablehlo_passes.md index 6613140556..bf92504799 100755 --- a/docs/generated/stablehlo_passes.md +++ b/docs/generated/stablehlo_passes.md @@ -85,6 +85,34 @@ long-term supported counterparts. ``` -fail-on-unused : Fail on (mostly) unused ops that are deprecated without any fallback. ``` +### `-stablehlo-legalize-qdq-to-quantized-op` + +_Fuse (de-quantize, floating-point operation and quantize) pattern into StableHLO quantized operation_ + +Fuse (de-quantize, floating-point operation and quantize) pattern into StableHLO quantized operation +Note: The pass does not delete any preexisting op. +For example, the following program + +```mlir +func.func @add(%arg0: tensor<16x16x!quant.uniform>) -> tensor<16x16x!quant.uniform> { + %0 = stablehlo.uniform_dequantize %arg0 : (tensor<16x16x!quant.uniform>) -> tensor<16x16xf32> + %1 = stablehlo.abs %0 : tensor<16x16xf32> + %2 = stablehlo.uniform_quantize %1 : (tensor<16x16xf32>) -> tensor<16x16x!quant.uniform> + func.return %2 : tensor<16x16x!quant.uniform> +} +``` + +Will become: + +```mlir +func.func @add(%arg0: tensor<16x16x!quant.uniform>) -> tensor<16x16x!quant.uniform> { + %0 = stablehlo.uniform_dequantize %arg0 : (tensor<16x16x!quant.uniform>) -> tensor<16x16xf32> + %1 = stablehlo.abs %0 : tensor<16x16xf32> + %2 = stablehlo.abs %arg0 : tensor<16x16x!quant.uniform> + %3 = stablehlo.uniform_quantize %1 : (tensor<16x16xf32>) -> tensor<16x16x!quant.uniform> + return %2 : tensor<16x16x!quant.uniform> +} +``` ### `-stablehlo-legalize-quant-to-math` _Convert from StableHLO quantized ops to StableHLO primitive math ops._ @@ -129,7 +157,7 @@ func.func @add(%arg0: tensor, %arg1: tensor) -> tensor { ``` ### `-stablehlo-legalize-quantized-op-to-qdq` -_Decompose StableHLO quantized ops using uniform quantize/dequantize ops._ +_Decompose quantized StableHLO operation to (de-quantize, floating-point operation and quantize) pattern._ Decompose StableHLO quantized programs using uniform quantize/dequantize operations. For example, the following program diff --git a/stablehlo/tests/transforms/stablehlo_legalize_qdq_to_quantized_op.mlir b/stablehlo/tests/transforms/stablehlo_legalize_qdq_to_quantized_op.mlir new file mode 100644 index 0000000000..9bf2e6f72f --- /dev/null +++ b/stablehlo/tests/transforms/stablehlo_legalize_qdq_to_quantized_op.mlir @@ -0,0 +1,129 @@ +// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file -allow-unregistered-dialect --stablehlo-legalize-qdq-to-quantized-op | FileCheck %s --check-prefixes=CHECK + +// ----- + +// CHECK-LABEL @compose_quantized_abs_op +// CHECK: %[[abs0:.*]] = stablehlo.abs %arg0 : tensor<16x16x!quant.uniform> +// CHECK-NEXT: return %[[abs0]] : tensor<16x16x!quant.uniform> +func.func @compose_quantized_abs_op(%arg0: tensor<16x16x!quant.uniform>) -> tensor<16x16x!quant.uniform> { + %0 = stablehlo.uniform_dequantize %arg0 : (tensor<16x16x!quant.uniform>) -> tensor<16x16xf32> + %1 = stablehlo.abs %0 : tensor<16x16xf32> + %2 = stablehlo.uniform_quantize %1 : (tensor<16x16xf32>) -> tensor<16x16x!quant.uniform> + func.return %2 : tensor<16x16x!quant.uniform> +} + +// ----- + +// CHECK-LABEL @failed_to_match_uniform_quant_op_operand_not_defined_by_op +// CHECK: %0 = stablehlo.uniform_quantize %arg0 : (tensor<16x16xf32>) -> tensor<16x16x!quant.uniform> +// CHECK-NEXT: return %0 : tensor<16x16x!quant.uniform> +func.func @failed_to_match_uniform_quant_op_operand_not_defined_by_op(%arg0: tensor<16x16xf32>) -> tensor<16x16x!quant.uniform> { + %0 = stablehlo.uniform_quantize %arg0 : (tensor<16x16xf32>) -> tensor<16x16x!quant.uniform> + func.return %0 : tensor<16x16x!quant.uniform> +} + +// ----- + +// CHECK-LABEL @failed_to_match_op_with_region +// CHECK: %0 = "stablehlo.all_reduce"(%arg0){{.*}}: tensor<1x2xi64>}> ({ +// CHECK-NEXT: ^bb0(%arg1: tensor, %arg2: tensor): +// CHECK-NEXT: %2 = stablehlo.add %arg1, %arg2 : tensor +// CHECK-NEXT: stablehlo.return %2 : tensor +// CHECK-NEXT: }) : (tensor<4xf32>) -> tensor<4xf32> +// CHECK-NEXT: %1 = stablehlo.uniform_quantize %0 : (tensor<4xf32>) -> tensor<4x!quant.uniform> +// CHECK-NEXT: return %1 : tensor<4x!quant.uniform> + +func.func @failed_to_match_op_with_region(%operand0 : tensor<4xf32>) -> (tensor<4x!quant.uniform>) { + %0 = stablehlo.uniform_quantize %operand0 : (tensor<4xf32>) -> tensor<4x!quant.uniform> + %1 = stablehlo.uniform_dequantize %0 : (tensor<4x!quant.uniform>) -> tensor<4xf32> + %2 = "stablehlo.all_reduce"(%operand0) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %3 = stablehlo.add %arg0, %arg1 : tensor + stablehlo.return %3 : tensor + }) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, + channel_handle = #stablehlo.channel_handle + } : (tensor<4xf32>) -> tensor<4xf32> + %4 = stablehlo.uniform_quantize %2 : (tensor<4xf32>) -> tensor<4x!quant.uniform> + return %4 : tensor<4x!quant.uniform> +} + +// ----- + +// CHECK-LABEL failed_to_match_varidic_op +// CHECK: %0 = stablehlo.uniform_quantize %arg0 : (tensor<8x2xf32>) -> tensor<8x2x!quant.uniform> +// CHECK-NEXT: %1 = stablehlo.uniform_dequantize %0 : (tensor<8x2x!quant.uniform>) -> tensor<8x2xf32> +// CHECK-NEXT: %2 = stablehlo.uniform_quantize %arg1 : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> +// CHECK-NEXT: %3 = stablehlo.uniform_dequantize %2 : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> +// CHECK-NEXT: %4:2 = "stablehlo.all_gather"(%1, %3) {{.*}} : (tensor<8x2xf32>, tensor<2x2xf32>) -> (tensor<8x8xf32>, tensor<2x4xf32>) +// CHECK-NEXT: %5 = stablehlo.uniform_quantize %4#0 : (tensor<8x8xf32>) -> tensor<8x8x!quant.uniform> +// CHECK-NEXT: return %5, %4#1 : tensor<8x8x!quant.uniform>, tensor<2x4xf32> +func.func @failed_to_match_varidic_op(%arg0: tensor<8x2xf32>, %arg1: tensor<2x2xf32>) -> (tensor<8x8x!quant.uniform>, tensor<2x4xf32>) { + %0 = stablehlo.uniform_quantize %arg0 : (tensor<8x2xf32>) -> tensor<8x2x!quant.uniform> + %1 = stablehlo.uniform_dequantize %0 : (tensor<8x2x!quant.uniform>) -> tensor<8x2xf32> + %2 = stablehlo.uniform_quantize %arg1 : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> + %3 = stablehlo.uniform_dequantize %2 : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> + %4:2 = "stablehlo.all_gather"(%1, %3) { + all_gather_dim = 1 : i64, + channel_handle = #stablehlo.channel_handle, + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> + } : (tensor<8x2xf32>, tensor<2x2xf32>) -> (tensor<8x8xf32>, tensor<2x4xf32>) + %5 = stablehlo.uniform_quantize %4#0 : (tensor<8x8xf32>) -> tensor<8x8x!quant.uniform> + func.return %5, %4#1 : tensor<8x8x!quant.uniform>, tensor<2x4xf32> +} + +// ----- + +// CHECK-LABEL @failed_to_match_operand_of_compute_op_already_quantized +// CHECK: %0 = stablehlo.uniform_quantize %arg0 : (tensor<1x8x8x207xf32>) -> tensor<1x8x8x207x!quant.uniform> +// CHECK-NEXT: %1 = stablehlo.uniform_dequantize %0 : (tensor<1x8x8x207x!quant.uniform>) -> tensor<1x8x8x207xf32> +// CHECK-NEXT: %2 = stablehlo.abs %arg1 : tensor<3x3x207x16x!quant.uniform> +// CHECK-NEXT: %3 = stablehlo.convolution(%1, %2) {{.*}} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16x!quant.uniform>) -> tensor<1x8x8x16xf32> +// CHECK-NEXT: %4 = stablehlo.uniform_quantize %3 : (tensor<1x8x8x16xf32>) -> tensor<1x8x8x16x!quant.uniform> +// CHECK-NEXT: return %4 : tensor<1x8x8x16x!quant.uniform> +func.func @failed_to_match_operand_of_compute_op_already_quantized(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16x!quant.uniform>) -> tensor<1x8x8x16x!quant.uniform> { + %0 = stablehlo.uniform_quantize %arg0 : (tensor<1x8x8x207xf32>) -> tensor<1x8x8x207x!quant.uniform> + %1 = stablehlo.uniform_dequantize %0 : (tensor<1x8x8x207x!quant.uniform>) -> tensor<1x8x8x207xf32> + %2 = stablehlo.abs %arg1 : tensor<3x3x207x16x!quant.uniform> + %3 = stablehlo.convolution(%1, %2) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} + {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} : + (tensor<1x8x8x207xf32>, tensor<3x3x207x16x!quant.uniform>) -> tensor<1x8x8x16xf32> + %4 = stablehlo.uniform_quantize %3 : (tensor<1x8x8x16xf32>) -> tensor<1x8x8x16x!quant.uniform> + func.return %4 : tensor<1x8x8x16x!quant.uniform> +} + +// ----- + +// CHECK-LABEL @failed_to_match_operand_not_defined_by_op +// CHECK: %0 = stablehlo.uniform_quantize %arg1 : (tensor<16x16xf32>) -> tensor<16x16x!quant.uniform> +// CHECK-NEXT: %1 = stablehlo.uniform_dequantize %0 : (tensor<16x16x!quant.uniform>) -> tensor<16x16xf32> +// CHECK-NEXT: %2 = stablehlo.add %arg0, %1 : tensor<16x16xf32> +// CHECK-NEXT: %3 = stablehlo.uniform_quantize %2 : (tensor<16x16xf32>) -> tensor<16x16x!quant.uniform> +// CHECK-NEXT: return %3 : tensor<16x16x!quant.uniform> +func.func @failed_to_match_operand_not_defined_by_op(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16x!quant.uniform> { + %1 = stablehlo.uniform_quantize %arg1 : (tensor<16x16xf32>) -> tensor<16x16x!quant.uniform> + %2 = stablehlo.uniform_dequantize %1 : (tensor<16x16x!quant.uniform>) -> tensor<16x16xf32> + %3 = stablehlo.add %arg0, %2 : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> + %4 = stablehlo.uniform_quantize %3 : (tensor<16x16xf32>) -> tensor<16x16x!quant.uniform> + func.return %4: tensor<16x16x!quant.uniform> +} + +// ----- + +// CHECK-LABEL @failed_to_match_defining_op_is_not_a_uniform_dequantized_op +// CHECK: %0 = stablehlo.abs %arg0 : tensor<16x16xf32> +// CHECK-NEXT: %1 = stablehlo.uniform_quantize %arg1 : (tensor<16x16xf32>) -> tensor<16x16x!quant.uniform> +// CHECK-NEXT: %2 = stablehlo.uniform_dequantize %1 : (tensor<16x16x!quant.uniform>) -> tensor<16x16xf32> +// CHECK-NEXT: %3 = stablehlo.add %0, %2 : tensor<16x16xf32> +// CHECK-NEXT: %4 = stablehlo.uniform_quantize %3 : (tensor<16x16xf32>) -> tensor<16x16x!quant.uniform> +// CHECK-NEXT: return %4 : tensor<16x16x!quant.uniform> +func.func @failed_to_match_defining_op_is_not_a_uniform_dequantized_op(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16x!quant.uniform> { + %0 = stablehlo.abs %arg0 : tensor<16x16xf32> + %1 = stablehlo.uniform_quantize %arg1 : (tensor<16x16xf32>) -> tensor<16x16x!quant.uniform> + %2 = stablehlo.uniform_dequantize %1 : (tensor<16x16x!quant.uniform>) -> tensor<16x16xf32> + %3 = stablehlo.add %0, %2 : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> + %4 = stablehlo.uniform_quantize %3 : (tensor<16x16xf32>) -> tensor<16x16x!quant.uniform> + func.return %4: tensor<16x16x!quant.uniform> +} diff --git a/stablehlo/transforms/CMakeLists.txt b/stablehlo/transforms/CMakeLists.txt index c2a008e941..72044d730f 100644 --- a/stablehlo/transforms/CMakeLists.txt +++ b/stablehlo/transforms/CMakeLists.txt @@ -41,6 +41,7 @@ add_mlir_dialect_library(StablehloPasses StablehloLegalizeDeprecatedOps.cpp StablehloLegalizeQuantToMath.cpp StablehloLegalizeQuantizedOpToQDQ.cpp + StablehloLegalizeQDQToQuantizedOp.cpp StablehloLegalizeToVhlo.cpp StablehloRefineArguments.cpp StablehloRefineShapes.cpp diff --git a/stablehlo/transforms/Passes.h b/stablehlo/transforms/Passes.h index ca9d00812e..0945b5a285 100644 --- a/stablehlo/transforms/Passes.h +++ b/stablehlo/transforms/Passes.h @@ -72,6 +72,11 @@ void populateStablehloLegalizeQuantizedOpToQDQPatterns( RewritePatternSet *patterns, MLIRContext *context, PatternBenefit benefit = 1); +/// Collection of rewrite patterns for composing quantized StableHLO operations +/// using unform dequantize/quantize operations. +void populateStablehloLegalizeQDQToQuantizedOpPatterns( + RewritePatternSet *patterns, MLIRContext *context); + /// A subset of folding patterns for StableHLO that is necessary for shape /// refinement. void populateStablehloShapeFolderPatterns(RewritePatternSet *patterns, diff --git a/stablehlo/transforms/Passes.td b/stablehlo/transforms/Passes.td index 2c817b65ce..e186c0903c 100644 --- a/stablehlo/transforms/Passes.td +++ b/stablehlo/transforms/Passes.td @@ -229,7 +229,7 @@ def StablehloLegalizeQuantToMathPass : Pass<"stablehlo-legalize-quant-to-math", } def StablehloLegalizeQuantizedOpToQDQPass : Pass<"stablehlo-legalize-quantized-op-to-qdq", "mlir::func::FuncOp"> { - let summary = "Decompose StableHLO quantized ops using uniform quantize/dequantize ops."; + let summary = "Decompose quantized StableHLO operation to (de-quantize, floating-point operation and quantize) pattern."; let description = [{ Decompose StableHLO quantized programs using uniform quantize/dequantize @@ -258,3 +258,37 @@ def StablehloLegalizeQuantizedOpToQDQPass : Pass<"stablehlo-legalize-quantized-o "mlir::stablehlo::StablehloDialect", ]; } + +def StablehloLegalizeQDQToQuantizedOpPass : Pass<"stablehlo-legalize-qdq-to-quantized-op", "mlir::func::FuncOp"> { + let summary = "Fuse (de-quantize, floating-point operation and quantize) pattern into StableHLO quantized operation"; + + let description = [{ + Fuse (de-quantize, floating-point operation and quantize) pattern into StableHLO quantized operation + Note: The pass does not delete any preexisting op. + For example, the following program + + ```mlir + func.func @add(%arg0: tensor<16x16x!quant.uniform>) -> tensor<16x16x!quant.uniform> { + %0 = stablehlo.uniform_dequantize %arg0 : (tensor<16x16x!quant.uniform>) -> tensor<16x16xf32> + %1 = stablehlo.abs %0 : tensor<16x16xf32> + %2 = stablehlo.uniform_quantize %1 : (tensor<16x16xf32>) -> tensor<16x16x!quant.uniform> + func.return %2 : tensor<16x16x!quant.uniform> + } + ``` + + Will become: + + ```mlir + func.func @add(%arg0: tensor<16x16x!quant.uniform>) -> tensor<16x16x!quant.uniform> { + %0 = stablehlo.uniform_dequantize %arg0 : (tensor<16x16x!quant.uniform>) -> tensor<16x16xf32> + %1 = stablehlo.abs %0 : tensor<16x16xf32> + %2 = stablehlo.abs %arg0 : tensor<16x16x!quant.uniform> + %3 = stablehlo.uniform_quantize %1 : (tensor<16x16xf32>) -> tensor<16x16x!quant.uniform> + return %2 : tensor<16x16x!quant.uniform> + } + ``` + }]; + let dependentDialects = [ + "mlir::stablehlo::StablehloDialect", + ]; +} diff --git a/stablehlo/transforms/StablehloLegalizeQDQToQuantizedOp.cpp b/stablehlo/transforms/StablehloLegalizeQDQToQuantizedOp.cpp new file mode 100644 index 0000000000..fba540988b --- /dev/null +++ b/stablehlo/transforms/StablehloLegalizeQDQToQuantizedOp.cpp @@ -0,0 +1,137 @@ +/* Copyright 2024 The StableHLO Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" // Include for TypeConverter +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/transforms/Passes.h" + +namespace mlir { +namespace stablehlo { + +#define GEN_PASS_DEF_STABLEHLOLEGALIZEQDQTOQUANTIZEDOPPASS +#include "stablehlo/transforms/Passes.h.inc" + +namespace { + +bool isAnyQuantizedTypes(TypeRange types) { + return llvm::any_of(types, [](Type type) { + return isa(getElementTypeOrSelf(type)); + }); +} + +struct QuantizedStablehloQDQToQuantizedOpConversion + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(stablehlo::UniformQuantizeOp quantOp, + PatternRewriter& rewriter) const override { + // Matching sequence of ops: + // UniformDequantizeOp -> non-quantized Op -> UniformQuantizeOp. + // Start matching from a UniformQuantizeOp (`quantOp`). + // Get the Op (`computeOp`) which defines the inputs of the `quantOp`. + // Verify all inputs of the `computeOp` are produced by + // UniformDequantizeOp (`dequantOp`). + // Note: The pass does not delete any prexisting op. + auto* computeOp = quantOp->getOperand(0).getDefiningOp(); + if (!computeOp) + return rewriter.notifyMatchFailure( + quantOp, "requires operand to be defined by an op"); + + if (computeOp->getNumRegions() != 0) + return rewriter.notifyMatchFailure(computeOp, + "ops with regions are not supported"); + + if (computeOp->getNumResults() > 1) + return rewriter.notifyMatchFailure( + computeOp, "ops with variadic results are not supported"); + + if (isAnyQuantizedTypes(computeOp->getOperandTypes())) + return rewriter.notifyMatchFailure(computeOp, + "requires non quantized operands"); + + // Collect quantized operands and result types to rewrite. + // All operands and results must be quantized + llvm::SmallVector quantizedComputeOpOperands; + for (const Value& operand : computeOp->getOperands()) { + auto* definingOp = operand.getDefiningOp(); + if (!definingOp) + return rewriter.notifyMatchFailure( + computeOp, "requires operand to be defined by an op"); + + auto dequantOp = dyn_cast(definingOp); + if (!dequantOp) + return rewriter.notifyMatchFailure( + definingOp, + "requires operand to be defined by an stablehlo.uniform_dequantize " + "op"); + + quantizedComputeOpOperands.push_back(dequantOp->getOperand(0)); + } + + rewriter.setInsertionPointAfter(computeOp); + OperationState newState(computeOp->getLoc(), + computeOp->getName().getStringRef(), + quantizedComputeOpOperands, + quantOp->getResultTypes(), computeOp->getAttrs()); + Operation* quantizedComputeOp = rewriter.create(newState); + + // Now that `computeOp` is quantized, replace all uses of the `quantOp` + // with the `quantizedComputeOp`'s result. + quantOp.getResult().replaceAllUsesWith(quantizedComputeOp->getResult(0)); + + return success(); + } +}; + +class StablehloLegalizeQDQToQuantizedOpPass + : public impl::StablehloLegalizeQDQToQuantizedOpPassBase< + StablehloLegalizeQDQToQuantizedOpPass> { + public: + LogicalResult initialize(MLIRContext* context) override { + RewritePatternSet patterns_(context); + populateStablehloLegalizeQDQToQuantizedOpPatterns(&patterns_, context); + patterns = std::move(patterns_); + return success(); + } + + void runOnOperation() override { + auto func = getOperation(); + if (failed(applyPatternsAndFoldGreedily(func, patterns, config))) { + func.emitError( + "Failed to converge StablehloLegalizeQDQToQuantizedOpPass in ") + << config.maxIterations << " iterations"; + signalPassFailure(); + } + } + + private: + FrozenRewritePatternSet patterns; + GreedyRewriteConfig config; +}; + +} // namespace + +void populateStablehloLegalizeQDQToQuantizedOpPatterns( + RewritePatternSet* patterns, MLIRContext* context) { + patterns->add(context); +} + +} // namespace stablehlo +} // namespace mlir