From 603690b3ea5c656dcf228543d91e3f594f0807d0 Mon Sep 17 00:00:00 2001 From: Doyeon Kim Date: Tue, 16 Apr 2024 21:10:19 -0700 Subject: [PATCH] Refer to `Method` for weight-only quantization PiperOrigin-RevId: 625548761 --- .../compiler/mlir/quantization/common/BUILD | 2 + .../common/attrs_and_constraints.cc | 29 ++++ .../common/attrs_and_constraints.h | 14 ++ .../common/attrs_and_constraints_test.cc | 152 ++++++++++++++++++ .../common/lift_as_function_call.h | 4 - .../stablehlo/cc/pass_pipeline.cc | 2 - .../stablehlo/passes/insert_weight_param.cc | 14 +- .../quantization/stablehlo/passes/passes.td | 8 - .../stablehlo/passes/quantization_patterns.cc | 126 +++++++-------- .../stablehlo/passes/quantization_patterns.h | 28 +--- .../quantization/stablehlo/passes/quantize.cc | 27 +--- .../passes/quantize_composite_functions.cc | 13 +- .../tests/passes/insert_weight_param.mlir | 41 +++-- .../tests/passes/quantize/quantize.mlir | 4 +- .../passes/quantize/quantize_weight_only.mlir | 4 +- ...ntize_composite_functions_weight_only.mlir | 4 +- 16 files changed, 313 insertions(+), 159 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/common/BUILD b/tensorflow/compiler/mlir/quantization/common/BUILD index da122b67993af7..5faa358598811c 100644 --- a/tensorflow/compiler/mlir/quantization/common/BUILD +++ b/tensorflow/compiler/mlir/quantization/common/BUILD @@ -145,6 +145,7 @@ cc_library( deps = [ ":uniform_quantized_types", "//tensorflow/compiler/mlir/quantization/common/quantization_lib", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:xla_call_module_attrs", @@ -155,6 +156,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:protobuf", "@stablehlo//:stablehlo_ops", ], ) diff --git a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.cc b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.cc index 540eff26685968..e116341eb79f71 100644 --- a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.cc +++ b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.cc @@ -34,12 +34,16 @@ limitations under the License. 
#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep #include "tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/xla_call_module_attrs.h" +#include "tsl/platform/protobuf.h" namespace mlir::quant { using ::mlir::stablehlo::DotGeneralOp; +using ::stablehlo::quantization::Method; +using ::tsl::protobuf::TextFormat; bool HasStaticShape(Value value) { auto shaped_type = value.getType().dyn_cast(); @@ -174,4 +178,29 @@ std::optional GetDotGeneralQuantizationDim( return filter_rank - 1; } +bool HasWeightOnlyPtqMethod(Operation& op) { + if (auto quantization_method_txtpb = + op.getAttrOfType(kQuantizationMethodAttr)) { + Method method; + if (TextFormat::ParseFromString(quantization_method_txtpb.getValue().str(), + &method)) { + return method.has_weight_only_ptq(); + } + } + return false; +} + +bool IsWeightOnlyQuantizableOp(const Operation& op) { + if (auto call_op = dyn_cast(op)) { + StringRef entry_function_name = GetEntryFunctionName(call_op); + return ContainsConvOrDot(entry_function_name) && + HasWeightOnlyPtqMethod(*call_op); + } + return false; +} + +bool ContainsConvOrDot(StringRef str) { + return str.contains("conv") || str.contains("dot_general"); +} + } // namespace mlir::quant diff --git a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h index 490a77a3b73ffa..dfbe3c2d45e267 100644 --- a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h +++ b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h @@ -42,6 +42,10 @@ namespace mlir::quant { constexpr char kAttrMapAttribute[] = "attr_map"; +// Name of the string attribute attached to `XlaCallModuleOp`, which is the +// textproto representation of `Method`. +inline constexpr StringRef kQuantizationMethodAttr = "_quantization_method"; + // Permutation from the NHWC tensor format to NCHW. This is an inverse // permutation of `kNchwToNhwcPermutation`. inline constexpr std::array kNhwcToNchwPermutation = {0, 3, 1, 2}; @@ -248,6 +252,16 @@ absl::StatusOr IsDotGeneralFullyConnected( std::optional GetDotGeneralQuantizationDim( ::mlir::stablehlo::DotGeneralOp dot_general_op); +// Checks if the `Method` attatched to the given op has `WeightOnlyPtq`. +bool HasWeightOnlyPtqMethod(Operation& op); + +// Checks if an op is a `tf.XlaCallModule` op, contains 'conv' or 'dot_general' +// in its name and has `Method` with `WeightOnlyPtq`. +bool IsWeightOnlyQuantizableOp(const Operation& op); + +// Checks if a `StringRef` contains 'conv' or 'dot_general'. 
+bool ContainsConvOrDot(StringRef str);
+
 }  // namespace mlir::quant
 
 #endif  // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_ATTRS_AND_CONSTRAINTS_H_
diff --git a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints_test.cc b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints_test.cc
index ca0df77f81b51c..041ce43eba20ac 100644
--- a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints_test.cc
+++ b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints_test.cc
@@ -98,6 +98,21 @@ constexpr absl::string_view kModuleXlaCallModule = R"mlir(
   }
 )mlir";
 
+constexpr absl::string_view kModuleDotWeightOnlyPtq = R"mlir(
+  module {
+    func.func @main(%arg0: tensor<?x2xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<?x2xf32>) {
+      %0 = stablehlo.constant dense<[-0.211145893, -0.708605706]> : tensor<2xf32>
+      %1 = stablehlo.constant dense<[[-0.630731344, 0.54962182], [0.180364341, -0.764542698]]> : tensor<2x2xf32>
+      %2 = "tf.XlaCallModule"(%arg0, %1, %0) <{Sout = [#tf_type.shape<?x2>], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", _quantization_method = "weight_only_ptq { }"} : (tensor<?x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<?x2xf32>
+      return %2 : tensor<?x2xf32>
+    }
+    func.func private @composite_dot_general_fn_1(%arg0: tensor<?x2xf32>, %arg1: tensor<2x2xf32>, %arg2: tensor<2xf32>) -> tensor<?x2xf32> attributes {_from_xla_call_module, tf_quant.composite_function} {
+      %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<?x2xf32>, tensor<2x2xf32>) -> tensor<?x2xf32>
+      return %0 : tensor<?x2xf32>
+    }
+  }
+)mlir";
+
 constexpr absl::string_view kModuleXlaCallModuleNoEntryNoQuantTrait = R"mlir(
   module {
     func.func @main(%arg0: tensor<?x2xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<?x2xf32>) {
@@ -526,5 +541,142 @@ TEST_F(AttrsAndConstraintsTest, DotGeneralBatchMatmulReturnsNullQuantDim) {
   EXPECT_THAT(GetDotGeneralQuantizationDim(dot_general_op), Eq(std::nullopt));
 }
 
+TEST_F(AttrsAndConstraintsTest, HasWeightOnlyPtqMethodExists) {
+  OwningOpRef<ModuleOp> module_op =
+      ParseModuleOpString(kModuleDotWeightOnlyPtq);
+  ASSERT_TRUE(module_op);
+
+  func::FuncOp main_fn = FindMainFuncOp(*module_op);
+  ASSERT_THAT(main_fn, NotNull());
+
+  auto call_op = *main_fn.getOps<TF::XlaCallModuleOp>().begin();
+  EXPECT_TRUE(HasWeightOnlyPtqMethod(*call_op));
+}
+
+TEST_F(AttrsAndConstraintsTest, HasWeightOnlyPtqMethodDifferentMethod) {
+  const absl::string_view kModuleDotNoQuantization = R"mlir(
+    module {
+      func.func @main(%arg0: tensor<?x2xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<?x2xf32>) {
+        %0 = stablehlo.constant dense<[-0.211145893, -0.708605706]> : tensor<2xf32>
+        %1 = stablehlo.constant dense<[[-0.630731344, 0.54962182], [0.180364341, -0.764542698]]> : tensor<2x2xf32>
+        %2 = "tf.XlaCallModule"(%arg0, %1, %0) <{Sout = [#tf_type.shape<?x2>], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", _quantization_method = "no_quantization { }"} : (tensor<?x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<?x2xf32>
+        return %2 : tensor<?x2xf32>
+      }
+      func.func private @composite_dot_general_fn_1(%arg0: tensor<?x2xf32>, %arg1: tensor<2x2xf32>, %arg2: tensor<2xf32>) -> tensor<?x2xf32> attributes {_from_xla_call_module, tf_quant.composite_function} {
+        %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<?x2xf32>, tensor<2x2xf32>) -> tensor<?x2xf32>
+        return %0 : tensor<?x2xf32>
+      }
+    }
+  )mlir";
+  OwningOpRef<ModuleOp> module_op =
+      ParseModuleOpString(kModuleDotNoQuantization);
+  ASSERT_TRUE(module_op);
+
+  func::FuncOp main_fn = FindMainFuncOp(*module_op);
+  ASSERT_THAT(main_fn, NotNull());
+
+  auto call_op = *main_fn.getOps<TF::XlaCallModuleOp>().begin();
+  EXPECT_FALSE(HasWeightOnlyPtqMethod(*call_op));
+}
+
+TEST_F(AttrsAndConstraintsTest, HasWeightOnlyPtqMethodNoMethod) {
+  OwningOpRef<ModuleOp> module_op = ParseModuleOpString(kModuleXlaCallModule);
+  ASSERT_TRUE(module_op);
+
+  func::FuncOp main_fn = FindMainFuncOp(*module_op);
+  ASSERT_THAT(main_fn, NotNull());
+
+  auto call_op = *main_fn.getOps<TF::XlaCallModuleOp>().begin();
+  EXPECT_FALSE(HasWeightOnlyPtqMethod(*call_op));
+}
+
+TEST_F(AttrsAndConstraintsTest, IsWeightOnlyQuantizableOpDot) {
+  OwningOpRef<ModuleOp> module_op =
+      ParseModuleOpString(kModuleDotWeightOnlyPtq);
+  ASSERT_TRUE(module_op);
+
+  func::FuncOp main_fn = FindMainFuncOp(*module_op);
+  ASSERT_THAT(main_fn, NotNull());
+
+  auto call_op = *main_fn.getOps<TF::XlaCallModuleOp>().begin();
+  EXPECT_TRUE(IsWeightOnlyQuantizableOp(*call_op));
+}
+
+TEST_F(AttrsAndConstraintsTest, IsWeightOnlyQuantizableOpNotTfXlaCallModuleOp) {
+  const absl::string_view kModulePartitionedCallDot = R"mlir(
+    module {
+      func.func @main(%arg0: tensor<?x2xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<?x2xf32>) {
+        %0 = stablehlo.constant dense<[-0.211145893, -0.708605706]> : tensor<2xf32>
+        %1 = stablehlo.constant dense<[[-0.630731344, 0.54962182], [0.180364341, -0.764542698]]> : tensor<2x2xf32>
+        %2 = "tf.PartitionedCall"(%arg0, %1, %0) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_dot_general_fn_1, _quantization_method = "weight_only_ptq { }"} : (tensor<?x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<?x2xf32>
+        return %2 : tensor<?x2xf32>
+      }
+      func.func private @composite_dot_general_fn_1(%arg0: tensor<?x2xf32>, %arg1: tensor<2x2xf32>, %arg2: tensor<2xf32>) -> tensor<?x2xf32> attributes {_from_xla_call_module, tf_quant.composite_function} {
+        %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<?x2xf32>, tensor<2x2xf32>) -> tensor<?x2xf32>
+        return %0 : tensor<?x2xf32>
+      }
+    }
+  )mlir";
+  OwningOpRef<ModuleOp> module_op =
+      ParseModuleOpString(kModulePartitionedCallDot);
+  ASSERT_TRUE(module_op);
+
+  func::FuncOp main_fn = FindMainFuncOp(*module_op);
+  ASSERT_THAT(main_fn, NotNull());
+
+  auto call_op = *main_fn.getOps<TF::PartitionedCallOp>().begin();
+  EXPECT_FALSE(IsWeightOnlyQuantizableOp(*call_op));
+}
+
+TEST_F(AttrsAndConstraintsTest, IsWeightOnlyQuantizableOpNoConvNoDot) {
+  constexpr absl::string_view kModuleXlaCallModule = R"mlir(
+    module {
+      func.func @main(%arg0: tensor<?x2xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<?x2xf32>) {
+        %0 = stablehlo.constant dense<[-0.211145893, -0.708605706]> : tensor<2xf32>
+        %1 = stablehlo.constant dense<[[-0.630731344, 0.54962182], [0.180364341, -0.764542698]]> : tensor<2x2xf32>
+        %2 = "tf.XlaCallModule"(%arg0, %1, %0) <{Sout = [#tf_type.shape<?x2>], module = "", version = 9 : i64}> {_entry_function = @composite_fn_1, _original_entry_function = "composite_fn_1", _tfl_quant_trait = "fully_quantizable", _quantization_method = "weight_only_ptq { }"} : (tensor<?x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<?x2xf32>
+        return %2 : tensor<?x2xf32>
+      }
+      func.func private @composite_fn_1(%arg0: tensor<?x2xf32>, %arg1: tensor<2x2xf32>, %arg2: tensor<2xf32>) -> tensor<?x2xf32> attributes {_from_xla_call_module, tf_quant.composite_function} {
+        return %arg0 : tensor<?x2xf32>
+      }
+    }
+  )mlir";
+  OwningOpRef<ModuleOp> module_op = ParseModuleOpString(kModuleXlaCallModule);
+  ASSERT_TRUE(module_op);
+
+  func::FuncOp main_fn = FindMainFuncOp(*module_op);
+  ASSERT_THAT(main_fn, NotNull());
+
+  auto call_op = *main_fn.getOps<TF::XlaCallModuleOp>().begin();
+  EXPECT_FALSE(IsWeightOnlyQuantizableOp(*call_op));
+}
+
+TEST_F(AttrsAndConstraintsTest, ContainsConvOrDotTrue) {
+  OwningOpRef<ModuleOp> module_op =
+      ParseModuleOpString(kModuleDotWeightOnlyPtq);
+  ASSERT_TRUE(module_op);
+
+  func::FuncOp main_fn = FindMainFuncOp(*module_op);
+  ASSERT_THAT(main_fn, NotNull());
+
+  auto call_op = *main_fn.getOps<TF::XlaCallModuleOp>().begin();
+  const StringRef function_name = GetEntryFunctionName(call_op);
+  EXPECT_TRUE(ContainsConvOrDot(function_name));
+}
+
+TEST_F(AttrsAndConstraintsTest, ContainsConvOrDotFalse) {
+  OwningOpRef<ModuleOp> module_op =
+      ParseModuleOpString(kModuleXlaCallModuleNoEntryNoQuantTrait);
+  ASSERT_TRUE(module_op);
+
+  func::FuncOp main_fn = FindMainFuncOp(*module_op);
+  ASSERT_THAT(main_fn, NotNull());
+
+  auto call_op = *main_fn.getOps<TF::XlaCallModuleOp>().begin();
+  const StringRef function_name = GetEntryFunctionName(call_op);
+  EXPECT_FALSE(ContainsConvOrDot(function_name));
+}
+
 }  // namespace
 }  // namespace mlir::quant
diff --git a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h
index bfef9a13df1a01..655eb3103eab54 100644
--- a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h
+++ b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h
@@ -43,10 +43,6 @@ constexpr StringRef kCompositeFuncPrefix = "composite_";
 inline constexpr StringRef kOriginalStablehloEntryFunctionAttrName =
     "_original_entry_function";
 
-// Name of the string attribute attached to `XlaCallModuleOp`, which is the
-// textproto representation of `Method`.
-inline constexpr StringRef kQuantizationMethodAttr = "_quantization_method";
-
 // FunctionCallOpType to be generated as the function call operator when
 // function lifting will happen.
 enum FunctionCallOpType { TFPartitionedCallOp = 0, TFXlaCallModuleOp = 1 };
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc
index 622ff502c01ed9..ea20a8875ded5a 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc
@@ -64,7 +64,6 @@ void AddPostCalibrationPasses(OpPassManager& pm,
   options.enable_per_channel_quantized_weight_ = true;
   // For debugging purposes.
   options.mlir_dump_file_name_ = "quantize_composite_functions";
-  options.enable_weight_only_ = false;
   options.merge_fusion_with_dequantize_ =
       pipeline_config.merge_fusion_with_dequantize();
 
@@ -101,7 +100,6 @@ void AddWeightOnlyQuantizationPasses(
   QuantizeCompositeFunctionsPassOptions options;
   // For debugging purposes.
   options.mlir_dump_file_name_ = "quantize_composite_functions";
-  options.enable_weight_only_ = true;
   pm.addPass(createQuantizeCompositeFunctionsPass(options));
 
   // Add an inliner pass to inline quantized StableHLO functions.
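
For context before the next file: the pipelines above no longer toggle a pass-level `enable_weight_only_` flag; whether a unit is weight-only quantized is now decided per op by parsing the `Method` textproto carried in its `_quantization_method` attribute. A minimal sketch of that decision, assuming the TensorFlow/TSL build environment (the proto and TextFormat headers are the same ones this patch adds to attrs_and_constraints.cc; the standalone main() is illustrative only, not part of the patch):

#include <iostream>
#include <string>

#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"
#include "tsl/platform/protobuf.h"

int main() {
  // Attribute value as it appears on a `tf.XlaCallModuleOp`, e.g.
  // _quantization_method = "weight_only_ptq { }".
  const std::string attr_value = "weight_only_ptq { }";

  ::stablehlo::quantization::Method method;
  if (!::tsl::protobuf::TextFormat::ParseFromString(attr_value, &method)) {
    // An unparseable method is treated as "not weight-only", mirroring
    // HasWeightOnlyPtqMethod returning false.
    std::cout << "no parseable method\n";
    return 1;
  }
  // Dispatch on the populated field of the `Method` oneof.
  if (method.has_weight_only_ptq()) {
    std::cout << "weight-only PTQ: quantize weights, keep activations float\n";
  } else if (method.has_static_range_ptq()) {
    std::cout << "static-range PTQ: quantize activations and weights\n";
  }
  return 0;
}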
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc
index 9fb1e9e985d15e..d785cd5bf4d970 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc
@@ -97,23 +97,13 @@ class InsertWeightParamPattern
       return false;
     }
     Operation* user = operand.getOwner();
-    if (isa<TF::XlaCallModuleOp>(user)) {
-      auto call_op = cast<TF::XlaCallModuleOp>(user);
-      const StringRef function_name = GetEntryFunctionName(call_op);
-      const bool is_conv_or_dot = function_name.contains("conv") ||
-                                  function_name.contains("dot_general");
-      const bool has_quant_trait = HasQuantizableTrait(call_op);
-      return is_conv_or_dot && has_quant_trait;
-    }
-    return false;
+    return IsWeightOnlyQuantizableOp(*user);
   }
 
   void rewrite(Operation* op, PatternRewriter& rewriter) const override {
     Operation* quantizable_op = *op->getUsers().begin();
     DenseFPElementsAttr attr;
-    if (!matchPattern(op->getResult(0), m_Constant(&attr))) {
-      return;
-    }
+    matchPattern(op->getResult(0), m_Constant(&attr));
 
     auto quant_type = quant::GetUniformQuantizedTypeForWeight(
         attr, /*symmetric=*/false, /*num_bits=*/8, /*is_signed=*/true,
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td
index fdb7fa7941f025..1ca1738566948a 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td
@@ -63,10 +63,6 @@ def QuantizeCompositeFunctionsPass : Pass<"stablehlo-quantize-composite-functions", "mlir::ModuleOp"> {
     Option<"mlir_dump_file_name_", "mlir-dump-file-name",
         "std::optional<std::string>", /*default=*/"std::nullopt",
         "MLIR dump file name.">,
-    Option<"enable_weight_only_",
-        "enable-weight-only",
-        "bool", /*default=*/"false",
-        "Whether to produce weight-only quantized op for convolution and dot_general op.">,
     Option<"merge_fusion_with_dequantize_",
         "merge-fusion-with-dequantize",
         "bool", /*default=*/"false",
@@ -106,10 +102,6 @@ def QuantizePass : Pass<"stablehlo-quantize", "mlir::ModuleOp"> {
         "enable-per-channel-quantized-weight",
         "bool", /*default=*/"true",
         "Whether to enable per-channel quantized weights.">,
-    Option<"enable_weight_only_",
-        "enable-weight-only",
-        "bool", /*default=*/"false",
-        "Whether to produce weight-only quantized op for convolution and dot_general op.">,
   ];
   let dependentDialects = [
     "mlir::stablehlo::StablehloDialect",
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc
index 5578aede7dee3c..dd4ae2ca9ba7b2 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc
@@ -269,7 +269,8 @@ class EntryFuncBodyQuantizationPattern {
   // Returns `success()` if `entry_func_op`'s body is eligible for rewriting. At
   // this point `entry_func_op`'s signature has not been reset with quantized
   // types.
-  virtual LogicalResult match(func::FuncOp entry_func_op) const = 0;
+  virtual LogicalResult match(func::FuncOp entry_func_op,
+                              const Method& quantization_method) const = 0;
 
   // Rewrites the `entry_func_op`'s body.
   virtual void rewrite(func::FuncOp entry_func_op,
@@ -408,19 +409,20 @@ void RewriteGemmStyleOp(func::FuncOp entry_func_op, PatternRewriter& rewriter,
 class QuantizeDotGeneralOpPattern : public EntryFuncBodyQuantizationPattern {
  public:
   explicit QuantizeDotGeneralOpPattern(
-      const bool enable_per_channel_quantized_weight,
-      const bool enable_weight_only)
+      const bool enable_per_channel_quantized_weight)
       : enable_per_channel_quantized_weight_(
-            enable_per_channel_quantized_weight),
-        enable_weight_only_(enable_weight_only) {}
+            enable_per_channel_quantized_weight) {}
 
-  LogicalResult match(func::FuncOp entry_func_op) const override {
+  LogicalResult match(func::FuncOp entry_func_op,
+                      const Method& quantization_method) const override {
+    if (!quantization_method.has_static_range_ptq()) {
+      return failure();
+    }
     return MatchGemmStyleOp<DotGeneralOp>(entry_func_op);
   }
 
   void rewrite(func::FuncOp entry_func_op, const Method& quantization_method,
                PatternRewriter& rewriter) const override {
-    if (enable_weight_only_) return;
     DotGeneralOp dot_general_op = *entry_func_op.getOps<DotGeneralOp>().begin();
     const bool should_quantize_per_channel =
         enable_per_channel_quantized_weight_ &&
@@ -433,28 +435,26 @@ class QuantizeDotGeneralOpPattern : public EntryFuncBodyQuantizationPattern {
   [[deprecated(
       "Do not rely on this field for per-channel quantization. Use `Method` "
       "instead.")]] const bool enable_per_channel_quantized_weight_;
-  // TODO: b/331510853 - Deprecate boolean flag and use `Method` to perform
-  // weight-only quantization.
-  const bool enable_weight_only_;
 };
 
 // Quantizes the entry function's body containing a `ConvolutionOp`.
 class QuantizeConvolutionOpPattern : public EntryFuncBodyQuantizationPattern {
  public:
   explicit QuantizeConvolutionOpPattern(
-      const bool enable_per_channel_quantized_weight,
-      const bool enable_weight_only)
+      const bool enable_per_channel_quantized_weight)
       : enable_per_channel_quantized_weight_(
-            enable_per_channel_quantized_weight),
-        enable_weight_only_(enable_weight_only) {}
+            enable_per_channel_quantized_weight) {}
 
-  LogicalResult match(func::FuncOp entry_func_op) const override {
+  LogicalResult match(func::FuncOp entry_func_op,
+                      const Method& quantization_method) const override {
+    if (!quantization_method.has_static_range_ptq()) {
+      return failure();
+    }
     return MatchGemmStyleOp<ConvolutionOp>(entry_func_op);
   }
 
   void rewrite(func::FuncOp entry_func_op, const Method& quantization_method,
                PatternRewriter& rewriter) const override {
-    if (enable_weight_only_) return;
     RewriteGemmStyleOp<ConvolutionOp>(
         entry_func_op, rewriter,
         enable_per_channel_quantized_weight_ &&
@@ -482,19 +482,42 @@ class QuantizeConvolutionOpPattern : public EntryFuncBodyQuantizationPattern {
   [[deprecated(
       "Do not rely on this field for per-channel quantization. Use `Method` "
       "instead.")]] const bool enable_per_channel_quantized_weight_;
-  // TODO: b/331510853 - Deprecate boolean flag and use `Method` to perform
-  // weight-only quantization.
-  const bool enable_weight_only_;
+};
+
+// Quantizes the entry function's body for weight-only quantized ops.
+template <typename GemmStyleOp>
+class QuantizeWeightOnlyOpPattern : public EntryFuncBodyQuantizationPattern {
+ public:
+  explicit QuantizeWeightOnlyOpPattern(
+      const bool enable_per_channel_quantized_weight)
+      : enable_per_channel_quantized_weight_(
+            enable_per_channel_quantized_weight) {}
+
+  LogicalResult match(func::FuncOp entry_func_op,
+                      const Method& quantization_method) const override {
+    if (!quantization_method.has_weight_only_ptq()) {
+      return failure();
+    }
+    return MatchGemmStyleOp<GemmStyleOp>(entry_func_op);
+  }
+
+  void rewrite(func::FuncOp entry_func_op, const Method& quantization_method,
+               PatternRewriter& rewriter) const override {}
+
+ private:
+  [[deprecated(
+      "Do not rely on this field for per-channel quantization. Use `Method` "
+      "instead.")]] const bool enable_per_channel_quantized_weight_;
 };
 
 template <typename SingularOpT>
 class QuantizeSingularOpPattern : public EntryFuncBodyQuantizationPattern {
  public:
   explicit QuantizeSingularOpPattern(
-      const bool enable_per_channel_quantized_weight,
-      const bool enable_weight_only) {}
+      const bool enable_per_channel_quantized_weight) {}
 
-  LogicalResult match(func::FuncOp entry_func_op) const override {
+  LogicalResult match(func::FuncOp entry_func_op,
+                      const Method& quantization_method) const override {
     const auto op_iterator_range = entry_func_op.getOps<SingularOpT>();
     if (op_iterator_range.empty()) {
       LLVM_DEBUG(llvm::dbgs() << "Function does not have "
@@ -606,12 +629,10 @@ template <typename FuncBodyRewritePatternT,
 class XlaCallModuleOpToCallOp : public OpRewritePattern<TF::XlaCallModuleOp> {
  public:
   explicit XlaCallModuleOpToCallOp(
-      MLIRContext& ctx, const bool enable_per_channel_quantized_weight,
-      const bool enable_weight_only)
+      MLIRContext& ctx, const bool enable_per_channel_quantized_weight)
       : OpRewritePattern<TF::XlaCallModuleOp>(&ctx),
         enable_per_channel_quantized_weight_(
-            enable_per_channel_quantized_weight),
-        enable_weight_only_(enable_weight_only) {}
+            enable_per_channel_quantized_weight) {}
 
   LogicalResult match(TF::XlaCallModuleOp op) const override {
     ModuleOp module_op = op->getParentOfType<ModuleOp>();
@@ -625,7 +646,7 @@ class XlaCallModuleOpToCallOp : public OpRewritePattern<TF::XlaCallModuleOp> {
     if (!IsQuantizedXlaCallModuleOp(op)) return failure();
 
     // For weight-only quantization, op should be hybrid quantized.
-    if (enable_weight_only_ && !IsHybridQuantizedOp(op)) {
+    if (HasWeightOnlyPtqMethod(*op) && !IsHybridQuantizedOp(op)) {
       return failure();
     }
 
@@ -634,10 +655,9 @@ class XlaCallModuleOpToCallOp : public OpRewritePattern<TF::XlaCallModuleOp> {
       op->emitError("Failed to find a valid entry function.");
       return failure();
     }
-
-    return FuncBodyRewritePatternT(enable_per_channel_quantized_weight_,
-                                   enable_weight_only_)
-        .match(entry_func_op);
+    Method quantization_method = GetQuantizationMethodOrDefault(op);
+    return FuncBodyRewritePatternT(enable_per_channel_quantized_weight_)
+        .match(entry_func_op, quantization_method);
   }
 
   void rewrite(TF::XlaCallModuleOp xla_call_module_op,
@@ -650,8 +670,7 @@ class XlaCallModuleOpToCallOp : public OpRewritePattern<TF::XlaCallModuleOp> {
 
     ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp(
         *rewriter.getContext(), rewriter, xla_call_module_op,
-        FuncBodyRewritePatternT(enable_per_channel_quantized_weight_,
-                                enable_weight_only_),
+        FuncBodyRewritePatternT(enable_per_channel_quantized_weight_),
         quantization_method);
   }
 
@@ -659,9 +678,6 @@ class XlaCallModuleOpToCallOp : public OpRewritePattern<TF::XlaCallModuleOp> {
   [[deprecated(
       "Do not rely on this field for per-channel quantization. Use `Method` "
       "instead.")]] const bool enable_per_channel_quantized_weight_;
-  // TODO: b/331510853 - Deprecate boolean flag and use `Method` to perform
-  // weight-only quantization.
-  const bool enable_weight_only_;
 };
 
 // Quantizes op with regions such as stablehlo.reduce_window op.
@@ -937,20 +953,6 @@ bool IsConnectedWithQuantizedCompsiteFunction(Operation* same_scale_op) {
   return false;
 }
 
-template <typename GemmStyleOp>
-class QuantizeWeightOnlyOpPattern : public EntryFuncBodyQuantizationPattern {
- public:
-  explicit QuantizeWeightOnlyOpPattern(
-      const bool enable_per_channel_quantized_weight) {}
-
-  LogicalResult match(func::FuncOp entry_func_op) const override {
-    return MatchGemmStyleOp<GemmStyleOp>(entry_func_op);
-  }
-
-  void rewrite(func::FuncOp entry_func_op, const Method& quantization_method,
-               PatternRewriter& rewriter) const override {}
-};
-
 // Compute heavy patterns should be quantized for both server and ODML targets.
 // Most patterns here are useful when quantized since they are compute heavy
 // or memory bound.
@@ -958,13 +960,18 @@ void PopulateCommonQuantizationPatterns(
     MLIRContext& ctx, RewritePatternSet& patterns,
     const bool enable_per_channel_quantized_weight) {
   patterns.add<XlaCallModuleOpToCallOp<QuantizeConvolutionOpPattern>>(
-      ctx, enable_per_channel_quantized_weight, /*enable_weight_only=*/false);
+      ctx, enable_per_channel_quantized_weight);
   patterns.add<XlaCallModuleOpToCallOp<QuantizeDotGeneralOpPattern>>(
-      ctx, enable_per_channel_quantized_weight, /*enable_weight_only=*/false);
+      ctx, enable_per_channel_quantized_weight);
+  patterns
+      .add<XlaCallModuleOpToCallOp<QuantizeWeightOnlyOpPattern<ConvolutionOp>>>(
+          ctx, enable_per_channel_quantized_weight);
+  patterns
+      .add<XlaCallModuleOpToCallOp<QuantizeWeightOnlyOpPattern<DotGeneralOp>>>(
+          ctx, enable_per_channel_quantized_weight);
   // TODO: b/307620772 - Per-channel quantization for gather.
   patterns.add<XlaCallModuleOpToCallOp<QuantizeSingularOpPattern<GatherOp>>>(
-      ctx, /*enable_per_channel_quantized_weight=*/false,
-      /*enable_weight_only=*/false);
+      ctx, /*enable_per_channel_quantized_weight=*/false);
   // Populate pattern for quantization of ops with regions such as
   // `stablehlo.reduce_window` op.
   patterns.add<QuantizeOpWithRegionPattern>(ctx);
 }
 
 void PopulateAllQuantizablePatterns(MLIRContext& ctx,
                                     RewritePatternSet& patterns) {
   patterns.add<XlaCallModuleOpToCallOp<QuantizeSingularOpPattern<AddOp>>>(
-      ctx, /*enable_per_channel_quantized_weight=*/false,
-      /*enable_weight_only=*/false);
-}
-
-void PopulateQuantizeWeightOnlyPatterns(MLIRContext& ctx,
-                                        RewritePatternSet& patterns) {
-  patterns.add<XlaCallModuleOpToCallOp<QuantizeDotGeneralOpPattern>,
-               XlaCallModuleOpToCallOp<QuantizeConvolutionOpPattern>>(
-      ctx, /*enable_per_channel_quantized_weight*/ false,
-      /*enable_weight_only=*/true);
+      ctx, /*enable_per_channel_quantized_weight=*/false);
 }
 
 }  // namespace mlir::quant::stablehlo
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h
index 67eb267c1d9037..b8ebe592c41f21 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h
@@ -40,6 +40,7 @@ limitations under the License.
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "stablehlo/dialect/StablehloOps.h"  // from @stablehlo
+#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h"
 #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h"
 #include "tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.h"
 #include "tensorflow/core/framework/types.pb.h"
@@ -59,18 +60,8 @@ bool IsConnectedWithQuantizedCompsiteFunction(Operation* same_scale_op);
 // quantization parameters are annotated by the QuantizeOp/DequantizeOp pairs.
 // Each matched pattern is rewritten by its quantized alternative.
 //
-// The concrete pattern, extends from this base pattern, can specify whether it
-// allows weight-only quantization. If it is allowed, for operand/result that is
-// not adjacent to dequantize/quantize op, it remains as float. For
-// operand/result that is adjacent to dequantize/quantize, it is quantized.
-// Weight-only quantization can be used to generate both weight-only
-// quantization and dynamic range quantization. The condition for allowing
-// weight-only quantization or not for an op can be specified in the below
-// function:
-//
-// static bool AllowWeightOnlyQuantization(Operation& op)
-//
-// This is a templatized `OpRewritePattern`.
+// The quantization method is determined by the `_quantization_method`
+// attribute attached to each quantizable unit.
 //
 // Template constraints are imposed as follows:
 //
@@ -159,6 +150,9 @@ class StableHloQuantizationPattern : public OpRewritePattern<RootOpT> {
       return failure();
     }
 
+    const bool weight_only_quantizable =
+        IsWeightOnlyQuantizableOp(*candidate_op);
+
     // Collect all the quantized inputs and "clone" the matched op by these
     // inputs.
     SmallVector<Value, 4> inputs;
@@ -178,8 +172,7 @@ class StableHloQuantizationPattern : public OpRewritePattern<RootOpT> {
       // If the operand is an integer tensor, then it doesn't require the
       // DequantizeOp in the pattern.
       inputs.push_back(operand);
-    } else if (static_cast<const ConcreteT*>(this)
-                   ->AllowWeightOnlyQuantization(*candidate_op)) {
+    } else if (weight_only_quantizable) {
       inputs.push_back(operand);
     } else {
       return failure();
@@ -215,8 +208,7 @@ class StableHloQuantizationPattern : public OpRewritePattern<RootOpT> {
       // D op in the pattern.
       outputs_replaced.insert({result, enumerated_result.index()});
       output_types.push_back(result.getType());
-    } else if (static_cast<const ConcreteT*>(this)
-                   ->AllowWeightOnlyQuantization(*candidate_op)) {
+    } else if (weight_only_quantizable) {
       outputs_replaced.insert({result, enumerated_result.index()});
       output_types.push_back(result.getType());
     } else {
@@ -260,10 +252,6 @@ void PopulateCommonQuantizationPatterns(
 void PopulateAllQuantizablePatterns(MLIRContext& ctx,
                                     RewritePatternSet& patterns);
 
-// Populates pattern weight-only quantization.
-void PopulateQuantizeWeightOnlyPatterns(MLIRContext& ctx,
-                                        RewritePatternSet& patterns);
-
 }  // namespace mlir::quant::stablehlo
 
 #endif  // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_QUANTIZATION_PATTERNS_H_
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc
index 0000057402886f..86dbae8e4181f9 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc
@@ -77,35 +77,14 @@ struct StableHloQuantizationReverse
                                   quantfork::QuantizeCastOp>(ctx) {}
 };
 
-bool IsHybridQuantizableOp(Operation& op) {
-  auto call_op = cast<TF::XlaCallModuleOp>(op);
-  if (call_op == nullptr) return false;
-  StringRef entry_function_name = GetEntryFunctionName(call_op);
-  return entry_function_name.contains("conv") ||
-         entry_function_name.contains("dot_general");
-}
-
-// Quantization rewrite pattern using DQ as the root op.
-struct StableHloQuantizationWeightOnly
-    : public StableHloQuantizationBase<StableHloQuantizationWeightOnly> {
-  explicit StableHloQuantizationWeightOnly(MLIRContext* ctx)
-      : StableHloQuantizationBase(ctx) {}
-
-  static bool AllowWeightOnlyQuantization(Operation& op) {
-    return IsHybridQuantizableOp(op);
-  }
-};
-
 class QuantizePass : public impl::QuantizePassBase<QuantizePass> {
  public:
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(QuantizePass)
 
   using impl::QuantizePassBase<QuantizePass>::QuantizePassBase;
 
-  explicit QuantizePass(const bool enable_per_channel_quantized_weight,
-                        const bool enable_weight_only) {
+  explicit QuantizePass(const bool enable_per_channel_quantized_weight) {
     enable_per_channel_quantized_weight_ = enable_per_channel_quantized_weight;
-    enable_weight_only_ = enable_weight_only;
   }
 
  private:
@@ -118,10 +97,6 @@ void QuantizePass::runOnOperation() {
   RewritePatternSet patterns(&ctx);
   patterns.add<StableHloQuantization, StableHloQuantizationReverse>(&ctx);
 
-  if (enable_weight_only_) {
-    patterns.add<StableHloQuantizationWeightOnly>(&ctx);
-    PopulateQuantizeWeightOnlyPatterns(ctx, patterns);
-  }
   PopulateCommonQuantizationPatterns(ctx, patterns,
                                      enable_per_channel_quantized_weight_);
 
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc
index 1efc5d40c7ce20..0328b02c68c609 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc
@@ -55,10 +55,8 @@ class QuantizeCompositeFunctionsPass
       QuantizeCompositeFunctionsPass>::QuantizeCompositeFunctionsPassBase;
 
   explicit QuantizeCompositeFunctionsPass(
-      const bool enable_per_channel_quantized_weight,
-      const bool enable_weight_only) {
+      const bool enable_per_channel_quantized_weight) {
     enable_per_channel_quantized_weight_ = enable_per_channel_quantized_weight;
-    enable_weight_only_ = enable_weight_only;
   }
 
  private:
@@ -80,9 +78,10 @@ void QuantizeCompositeFunctionsPass::runOnOperation() {
   // Change this to user-given bit width once we have custom configuration.
   options.bit_width_ = 8;
 
-  if (enable_weight_only_) {
-    pm.addNestedPass<func::FuncOp>(createInsertWeightParamPass());
-  }
+  // Insert quantization parameters for the weights of ops whose
+  // `_quantization_method` specifies `weight_only_ptq`.
+  pm.addNestedPass<func::FuncOp>(createInsertWeightParamPass());
+
   // PrepareQuantizePass uses SymbolTable to fetch relevant GEMM ops for
   // determining quantization attributes. This requires module-level context.
   pm.addPass(createPrepareQuantizePass(options));
@@ -90,7 +89,7 @@ void QuantizeCompositeFunctionsPass::runOnOperation() {
   QuantizePassOptions quantize_options;
   quantize_options.enable_per_channel_quantized_weight_ =
       enable_per_channel_quantized_weight_;
-  quantize_options.enable_weight_only_ = enable_weight_only_;
+
   // QuantizePass modifies FuncOps referenced outside of its given scope
   // and therefore requires a module-level context.
pm.addPass(createQuantizePass(quantize_options)); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/insert_weight_param.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/insert_weight_param.mlir index 89ff96efecf471..a8b694b41b8cc5 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/insert_weight_param.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/insert_weight_param.mlir @@ -1,14 +1,14 @@ // RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-insert-weight-param | FileCheck %s // Test that q/dq pair is inserted between constant and XlaCallModule op -// with quantizable trait and function name containing conv. +// with `weight_only_ptq` method and function name containing conv. func.func @qdq_for_conv_weight(%arg0: tensor<1x3x2x3xf32>) -> tensor<1x2x2x2xf32> attributes {tf._original_func_name = "main_0"} { %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> %0 = "tf.XlaCallModule"(%arg0, %cst) { Sout = [#tf_type.shape<1x2x2x2>], _entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", - _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", + _stablehlo_module_attrs = {}, _quantization_method = "weight_only_ptq { }", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64 @@ -21,20 +21,20 @@ func.func @qdq_for_conv_weight(%arg0: tensor<1x3x2x3xf32>) -> tensor<1x2x2x2xf32 // CHECK: %[[CST:.+]] = "tf.Const"() <{value = dense<3.000000e-01> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> // CHECK: %[[Q:.+]] = "quantfork.qcast"(%[[CST]]) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform> // CHECK: %[[DQ:.+]] = "quantfork.dcast"(%[[Q]]) : (tensor<2x3x3x2x!quant.uniform>) -> tensor<2x3x3x2xf32> -// CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG_0]], %[[DQ]]) <{Sout = [#tf_type.shape<1x2x2x2>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x3x2x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x2x2x2xf32> +// CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG_0]], %[[DQ]]) <{Sout = [#tf_type.shape<1x2x2x2>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _quantization_method = "weight_only_ptq { }", _stablehlo_module_attrs = {}, device = ""} : (tensor<1x3x2x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x2x2x2xf32> // CHECK: return %[[CALL]] : tensor<1x2x2x2xf32> // ----- // Test that q/dq pair is inserted between constant and XlaCallModule op -// with quantizable trait and function name containing dot_general. +// with `weight_only_ptq` method and function name containing dot_general. 
func.func @qdq_for_dot_general_weight(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<2x3xf32>} : () -> tensor<2x3xf32> %0 = "tf.XlaCallModule"(%arg0, %cst) { Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", - _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", + _quantization_method = "weight_only_ptq { }", _stablehlo_module_attrs = {}, device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64 @@ -47,7 +47,7 @@ func.func @qdq_for_dot_general_weight(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> // CHECK: %[[CST:.+]] = "tf.Const"() <{value = dense<3.000000e-01> : tensor<2x3xf32>}> : () -> tensor<2x3xf32> // CHECK: %[[Q:.+]] = "quantfork.qcast"(%[[CST]]) : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform> // CHECK: %[[DQ:.+]] = "quantfork.dcast"(%[[Q]]) : (tensor<2x3x!quant.uniform>) -> tensor<2x3xf32> -// CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG_0]], %[[DQ]]) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> +// CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG_0]], %[[DQ]]) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _quantization_method = "weight_only_ptq { }", _stablehlo_module_attrs = {}, device = ""} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> // CHECK: return %[[CALL]] : tensor<1x3xf32> // ----- @@ -59,7 +59,7 @@ func.func @no_qdq_except_conv_and_dot_general(%arg0: tensor<2x3x2xi64>) -> tenso %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<3x4x2xf32>} : () -> tensor<3x4x2xf32> %0 = "tf.XlaCallModule"(%cst, %arg0) { Sout = [#tf_type.shape<1x3>], _entry_function = @composite_gather_fn, - _original_entry_function = "composite_gather_fn", + _original_entry_function = "composite_gather_fn", _quantization_method = "weight_only_ptq { }", _stablehlo_module_attrs = {}, device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64 @@ -81,7 +81,7 @@ func.func @no_qdq_for_non_weight_constant(%arg0: tensor<1x2xf32>, %arg1: tensor< %0 = "tf.XlaCallModule"(%arg0, %arg1, %cst) { Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_with_bias_fn, _original_entry_function = "composite_dot_general_with_bias_fn", - _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", + _stablehlo_module_attrs = {}, _quantization_method = "weight_only_ptq { }", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64 @@ -96,7 +96,7 @@ func.func @no_qdq_for_non_weight_constant(%arg0: tensor<1x2xf32>, %arg1: tensor< // ----- // Test that q/dq pair is not inserted between constant and XlaCallModule op -// without quantizable trait. +// without `weight_only_ptq` method. 
 func.func @no_qdq_for_not_quantizable_call(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} {
   %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
@@ -116,6 +116,27 @@ func.func @no_qdq_for_not_quantizable_call(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} {
 
 // -----
 
+// Test that q/dq pair is not inserted between constant and XlaCallModule op
+// with a different method.
+
+func.func @no_qdq_for_different_method(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} {
+  %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
+  %0 = "tf.XlaCallModule"(%arg0, %cst) {
+    Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn,
+    _original_entry_function = "composite_dot_general_fn",
+    _stablehlo_module_attrs = {}, device = "", dim_args_spec = [],
+    disabled_checks = [], has_token_input_output = false, module = "",
+    platforms = [], _quantization_method = "static_range_ptq { }", version = 5 : i64
+  } : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
+  return %0 : tensor<1x3xf32>
+}
+
+// CHECK-LABEL: func.func @no_qdq_for_different_method
+// CHECK-NOT: quantfork.qcast
+// CHECK-NOT: quantfork.dcast
+
+// -----
+
 // Test that q/dq pair is not inserted when constant has multiple users.
 
 func.func @no_qdq_for_multiple_users(%arg0: tensor<2x2xf32>) -> tensor<2x3xf32> attributes {tf._original_func_name = "main_0"} {
   %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
   %0 = "tf.XlaCallModule"(%arg0, %cst) {
     Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn,
     _original_entry_function = "composite_dot_general_fn",
-    _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable",
+    _stablehlo_module_attrs = {}, _quantization_method = "weight_only_ptq { }",
     device = "", dim_args_spec = [],
     disabled_checks = [], has_token_input_output = false, module = "",
     platforms = [], version = 5 : i64
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize.mlir
index 9f2e371fb4d1d6..79f44e10b03e46 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize.mlir
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize.mlir
@@ -40,7 +40,7 @@ module attributes {tf_saved_model.semantics} {
 
 // CHECK-LABEL: quantize_simple_xla_call_module_no_operand
   func.func private @quantize_simple_xla_call_module_no_operand() -> tensor<1x3xf32> {
-    %0 = "tf.XlaCallModule"() {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : () -> tensor<1x3xf32>
+    %0 = "tf.XlaCallModule"() {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : () -> tensor<1x3xf32>
    %1 = "quantfork.qcast"(%0) {volatile} :
(tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> %2 = "quantfork.dcast"(%1) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> return %2 : tensor<1x3xf32> @@ -63,7 +63,7 @@ module attributes {tf_saved_model.semantics} { %4 = "quantfork.dcast"(%3) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> // expected-error @+2 {{Failed to find a valid entry function}} // expected-error @+1 {{'tf.XlaCallModule' op operand #0 must be variadic of tensor of tf.dtype values}} - %5 = "tf.XlaCallModule"(%4, %2) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %5 = "tf.XlaCallModule"(%4, %2) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> %6 = "quantfork.qcast"(%5) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> %7 = "quantfork.dcast"(%6) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> return %7 : tensor<1x3xf32> diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_weight_only.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_weight_only.mlir index 15330b0b79b800..a6f0111d2c8293 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_weight_only.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_weight_only.mlir @@ -1,4 +1,4 @@ -// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-quantize=enable-weight-only=true | FileCheck %s +// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-quantize | FileCheck %s // Test that hybrid quantized dot_general is produced when q/dq pair only exists // for weight. 
@@ -41,7 +41,7 @@ module attributes {tf_saved_model.semantics} {
     %cst = stablehlo.constant dense<3.000000e-01> : tensor<2x3x3x2xf32>
     %0 = "quantfork.qcast"(%cst) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform>
     %1 = "quantfork.dcast"(%0) : (tensor<2x3x3x2x!quant.uniform>) -> tensor<2x3x3x2xf32>
-    %2 = "tf.XlaCallModule"(%arg0, %1) <{Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32>
+    %2 = "tf.XlaCallModule"(%arg0, %1) <{Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _quantization_method = "weight_only_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32>
     return %2 : tensor<1x3x4x2xf32>
   }
 
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions_weight_only.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions_weight_only.mlir
index dbe192bbb55cde..a614ee0af36adc 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions_weight_only.mlir
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions_weight_only.mlir
@@ -1,5 +1,5 @@
 // RUN: stablehlo-quant-opt %s -split-input-file -verify-diagnostics \
-// RUN:   -stablehlo-quantize-composite-functions=enable-weight-only=true | FileCheck --check-prefix=CHECK %s
+// RUN:   -stablehlo-quantize-composite-functions | FileCheck --check-prefix=CHECK %s
 
-// Test that weight-only quantized dot_general op is produced when
-// enable-weight-only is set to true.
+// Test that weight-only quantized dot_general op is produced when the op's
+// `_quantization_method` specifies `weight_only_ptq`.
@@ -37,7 +37,7 @@ module attributes {tf_saved_model.semantics} {
   func.func private @quantize_conv_fn(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} {
     %0 = stablehlo.constant dense<3.000000e-01> : tensor<2x3x3x2xf32>
-    %1 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32>
+    %1 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _quantization_method = "weight_only_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32>
     return %1 : tensor<1x3x4x2xf32>
   }
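
A closing usage note on the predicate the tests above exercise: `ContainsConvOrDot` is a plain substring test, so any entry function name embedding "conv" or "dot_general" qualifies for weight-only quantization. A self-contained sketch, re-implemented here with std::string_view so it compiles without the MLIR dependencies (the real helper in attrs_and_constraints.h takes llvm::StringRef):

#include <cassert>
#include <string_view>

// Mirrors the substring semantics of mlir::quant::ContainsConvOrDot.
bool ContainsConvOrDot(std::string_view str) {
  return str.find("conv") != std::string_view::npos ||
         str.find("dot_general") != std::string_view::npos;
}

int main() {
  assert(ContainsConvOrDot("composite_conv_with_bias_fn_1"));  // conv fusions match
  assert(ContainsConvOrDot("composite_dot_general_fn_1"));     // dot_general matches
  assert(!ContainsConvOrDot("composite_gather_fn_1"));         // gather is excluded
  return 0;
}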