Skip to content

Commit

Permalink
Refer to Method for weight-only quantization
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 625548761
  • Loading branch information
doyeonkim0 authored and tensorflower-gardener committed Apr 17, 2024
1 parent e0d6269 commit 603690b
Show file tree
Hide file tree
Showing 16 changed files with 313 additions and 159 deletions.
2 changes: 2 additions & 0 deletions tensorflow/compiler/mlir/quantization/common/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ cc_library(
deps = [
":uniform_quantized_types",
"//tensorflow/compiler/mlir/quantization/common/quantization_lib",
"//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc",
"//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:xla_call_module_attrs",
Expand All @@ -155,6 +156,7 @@ cc_library(
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@local_tsl//tsl/platform:protobuf",
"@stablehlo//:stablehlo_ops",
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,16 @@ limitations under the License.
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep
#include "tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/xla_call_module_attrs.h"
#include "tsl/platform/protobuf.h"

namespace mlir::quant {

using ::mlir::stablehlo::DotGeneralOp;
using ::stablehlo::quantization::Method;
using ::tsl::protobuf::TextFormat;

bool HasStaticShape(Value value) {
auto shaped_type = value.getType().dyn_cast<ShapedType>();
Expand Down Expand Up @@ -174,4 +178,29 @@ std::optional<int64_t> GetDotGeneralQuantizationDim(
return filter_rank - 1;
}

bool HasWeightOnlyPtqMethod(Operation& op) {
if (auto quantization_method_txtpb =
op.getAttrOfType<StringAttr>(kQuantizationMethodAttr)) {
Method method;
if (TextFormat::ParseFromString(quantization_method_txtpb.getValue().str(),
&method)) {
return method.has_weight_only_ptq();
}
}
return false;
}

bool IsWeightOnlyQuantizableOp(const Operation& op) {
if (auto call_op = dyn_cast<TF::XlaCallModuleOp>(op)) {
StringRef entry_function_name = GetEntryFunctionName(call_op);
return ContainsConvOrDot(entry_function_name) &&
HasWeightOnlyPtqMethod(*call_op);
}
return false;
}

bool ContainsConvOrDot(StringRef str) {
return str.contains("conv") || str.contains("dot_general");
}

} // namespace mlir::quant
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ namespace mlir::quant {

constexpr char kAttrMapAttribute[] = "attr_map";

// Name of the string attribute attached to `XlaCallModuleOp`, which is the
// textproto representation of `Method`.
inline constexpr StringRef kQuantizationMethodAttr = "_quantization_method";

// Permutation from the NHWC tensor format to NCHW. This is an inverse
// permutation of `kNchwToNhwcPermutation`.
inline constexpr std::array<int64_t, 4> kNhwcToNchwPermutation = {0, 3, 1, 2};
Expand Down Expand Up @@ -248,6 +252,16 @@ absl::StatusOr<bool> IsDotGeneralFullyConnected(
std::optional<int64_t> GetDotGeneralQuantizationDim(
::mlir::stablehlo::DotGeneralOp dot_general_op);

// Checks if the `Method` attatched to the given op has `WeightOnlyPtq`.
bool HasWeightOnlyPtqMethod(Operation& op);

// Checks if an op is a `tf.XlaCallModule` op, contains 'conv' or 'dot_general'
// in its name and has `Method` with `WeightOnlyPtq`.
bool IsWeightOnlyQuantizableOp(const Operation& op);

// Checks if a `StringRef` contains 'conv' or 'dot_general'.
bool ContainsConvOrDot(StringRef str);

} // namespace mlir::quant

#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_ATTRS_AND_CONSTRAINTS_H_
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,21 @@ constexpr absl::string_view kModuleXlaCallModule = R"mlir(
}
)mlir";

constexpr absl::string_view kModuleDotWeightOnlyPtq = R"mlir(
module {
func.func @main(%arg0: tensor<?x2xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<?x2xf32>) {
%0 = stablehlo.constant dense<[-0.211145893, -0.708605706]> : tensor<2xf32>
%1 = stablehlo.constant dense<[[-0.630731344, 0.54962182], [0.180364341, -0.764542698]]> : tensor<2x2xf32>
%2 = "tf.XlaCallModule"(%arg0, %1, %0) <{Sout = [#tf_type.shape<?x2>], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", _quantization_method = "weight_only_ptq { }"} : (tensor<?x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<?x2xf32>
return %2 : tensor<?x2xf32>
}
func.func private @composite_dot_general_fn_1(%arg0: tensor<?x2xf32>, %arg1: tensor<2x2xf32>, %arg2: tensor<2xf32>) -> tensor<?x2xf32> attributes {_from_xla_call_module, tf_quant.composite_function} {
%0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<?x2xf32>, tensor<2x2xf32>) -> tensor<?x2xf32>
return %0 : tensor<?x2xf32>
}
}
)mlir";

constexpr absl::string_view kModuleXlaCallModuleNoEntryNoQuantTrait = R"mlir(
module {
func.func @main(%arg0: tensor<?x2xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<?x2xf32>) {
Expand Down Expand Up @@ -526,5 +541,142 @@ TEST_F(AttrsAndConstraintsTest, DotGeneralBatchMatmulReturnsNullQuantDim) {
EXPECT_THAT(GetDotGeneralQuantizationDim(dot_general_op), Eq(std::nullopt));
}

TEST_F(AttrsAndConstraintsTest, HasWeightOnlyPtqMethodExists) {
OwningOpRef<ModuleOp> module_op =
ParseModuleOpString(kModuleDotWeightOnlyPtq);
ASSERT_TRUE(module_op);

func::FuncOp main_fn = FindMainFuncOp(*module_op);
ASSERT_THAT(main_fn, NotNull());

auto call_op = *main_fn.getOps<TF::XlaCallModuleOp>().begin();
EXPECT_TRUE(HasWeightOnlyPtqMethod(*call_op));
}

TEST_F(AttrsAndConstraintsTest, HasWeightOnlyPtqMethodDifferentMethod) {
const absl::string_view kModuleDotNoQuantization = R"mlir(
module {
func.func @main(%arg0: tensor<?x2xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<?x2xf32>) {
%0 = stablehlo.constant dense<[-0.211145893, -0.708605706]> : tensor<2xf32>
%1 = stablehlo.constant dense<[[-0.630731344, 0.54962182], [0.180364341, -0.764542698]]> : tensor<2x2xf32>
%2 = "tf.XlaCallModule"(%arg0, %1, %0) <{Sout = [#tf_type.shape<?x2>], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", _quantization_method = "no_quantization { }"} : (tensor<?x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<?x2xf32>
return %2 : tensor<?x2xf32>
}
func.func private @composite_dot_general_fn_1(%arg0: tensor<?x2xf32>, %arg1: tensor<2x2xf32>, %arg2: tensor<2xf32>) -> tensor<?x2xf32> attributes {_from_xla_call_module, tf_quant.composite_function} {
%0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<?x2xf32>, tensor<2x2xf32>) -> tensor<?x2xf32>
return %0 : tensor<?x2xf32>
}
}
)mlir";
OwningOpRef<ModuleOp> module_op =
ParseModuleOpString(kModuleDotNoQuantization);
ASSERT_TRUE(module_op);

func::FuncOp main_fn = FindMainFuncOp(*module_op);
ASSERT_THAT(main_fn, NotNull());

auto call_op = *main_fn.getOps<TF::XlaCallModuleOp>().begin();
EXPECT_FALSE(HasWeightOnlyPtqMethod(*call_op));
}

TEST_F(AttrsAndConstraintsTest, HasWeightOnlyPtqMethodNoMethod) {
OwningOpRef<ModuleOp> module_op = ParseModuleOpString(kModuleXlaCallModule);
ASSERT_TRUE(module_op);

func::FuncOp main_fn = FindMainFuncOp(*module_op);
ASSERT_THAT(main_fn, NotNull());

auto call_op = *main_fn.getOps<TF::XlaCallModuleOp>().begin();
EXPECT_FALSE(HasWeightOnlyPtqMethod(*call_op));
}

TEST_F(AttrsAndConstraintsTest, IsWeightOnlyQuantizableOpDot) {
OwningOpRef<ModuleOp> module_op =
ParseModuleOpString(kModuleDotWeightOnlyPtq);
ASSERT_TRUE(module_op);

func::FuncOp main_fn = FindMainFuncOp(*module_op);
ASSERT_THAT(main_fn, NotNull());

auto call_op = *main_fn.getOps<TF::XlaCallModuleOp>().begin();
EXPECT_TRUE(IsWeightOnlyQuantizableOp(*call_op));
}

TEST_F(AttrsAndConstraintsTest, IsWeightOnlyQuantizableOpNotTfXlaCallModuleOp) {
const absl::string_view kModulePartitionedCallDot = R"mlir(
module {
func.func @main(%arg0: tensor<?x2xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<?x2xf32>) {
%0 = stablehlo.constant dense<[-0.211145893, -0.708605706]> : tensor<2xf32>
%1 = stablehlo.constant dense<[[-0.630731344, 0.54962182], [0.180364341, -0.764542698]]> : tensor<2x2xf32>
%2 = "tf.PartitionedCall"(%arg0, %1, %0) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_dot_general_fn_1, _quantization_method = "weight_only_ptq { }"} : (tensor<?x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<?x2xf32>
return %2 : tensor<?x2xf32>
}
func.func private @composite_dot_general_fn_1(%arg0: tensor<?x2xf32>, %arg1: tensor<2x2xf32>, %arg2: tensor<2xf32>) -> tensor<?x2xf32> attributes {_from_xla_call_module, tf_quant.composite_function} {
%0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<?x2xf32>, tensor<2x2xf32>) -> tensor<?x2xf32>
return %0 : tensor<?x2xf32>
}
}
)mlir";
OwningOpRef<ModuleOp> module_op =
ParseModuleOpString(kModulePartitionedCallDot);
ASSERT_TRUE(module_op);

func::FuncOp main_fn = FindMainFuncOp(*module_op);
ASSERT_THAT(main_fn, NotNull());

auto call_op = *main_fn.getOps<TF::PartitionedCallOp>().begin();
EXPECT_FALSE(IsWeightOnlyQuantizableOp(*call_op));
}

TEST_F(AttrsAndConstraintsTest, IsWeightOnlyQuantizableOpNoConvNoDot) {
constexpr absl::string_view kModuleXlaCallModule = R"mlir(
module {
func.func @main(%arg0: tensor<?x2xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<?x2xf32>) {
%0 = stablehlo.constant dense<[-0.211145893, -0.708605706]> : tensor<2xf32>
%1 = stablehlo.constant dense<[[-0.630731344, 0.54962182], [0.180364341, -0.764542698]]> : tensor<2x2xf32>
%2 = "tf.XlaCallModule"(%arg0, %1, %0) <{Sout = [#tf_type.shape<?x2>], module = "", version = 9 : i64}> {_entry_function = @composite_fn_1, _original_entry_function = "composite_fn_1", _tfl_quant_trait = "fully_quantizable", _quantization_method = "weight_only_ptq { }"} : (tensor<?x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<?x2xf32>
return %2 : tensor<?x2xf32>
}
func.func private @composite_fn_1(%arg0: tensor<?x2xf32>, %arg1: tensor<2x2xf32>, %arg2: tensor<2xf32>) -> tensor<?x2xf32> attributes {_from_xla_call_module, tf_quant.composite_function} {
return %arg0 : tensor<?x2xf32>
}
}
)mlir";
OwningOpRef<ModuleOp> module_op = ParseModuleOpString(kModuleXlaCallModule);
ASSERT_TRUE(module_op);

func::FuncOp main_fn = FindMainFuncOp(*module_op);
ASSERT_THAT(main_fn, NotNull());

auto call_op = *main_fn.getOps<TF::XlaCallModuleOp>().begin();
EXPECT_FALSE(IsWeightOnlyQuantizableOp(*call_op));
}

TEST_F(AttrsAndConstraintsTest, ContainsConvOrDotTrue) {
OwningOpRef<ModuleOp> module_op =
ParseModuleOpString(kModuleDotWeightOnlyPtq);
ASSERT_TRUE(module_op);

func::FuncOp main_fn = FindMainFuncOp(*module_op);
ASSERT_THAT(main_fn, NotNull());

auto call_op = *main_fn.getOps<TF::XlaCallModuleOp>().begin();
const StringRef function_name = GetEntryFunctionName(call_op);
EXPECT_TRUE(ContainsConvOrDot(function_name));
}

TEST_F(AttrsAndConstraintsTest, ContainsConvOrDotFalse) {
OwningOpRef<ModuleOp> module_op =
ParseModuleOpString(kModuleXlaCallModuleNoEntryNoQuantTrait);
ASSERT_TRUE(module_op);

func::FuncOp main_fn = FindMainFuncOp(*module_op);
ASSERT_THAT(main_fn, NotNull());

auto call_op = *main_fn.getOps<TF::XlaCallModuleOp>().begin();
const StringRef function_name = GetEntryFunctionName(call_op);
EXPECT_FALSE(ContainsConvOrDot(function_name));
}

} // namespace
} // namespace mlir::quant
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,6 @@ constexpr StringRef kCompositeFuncPrefix = "composite_";
inline constexpr StringRef kOriginalStablehloEntryFunctionAttrName =
"_original_entry_function";

// Name of the string attribute attached to `XlaCallModuleOp`, which is the
// textproto representation of `Method`.
inline constexpr StringRef kQuantizationMethodAttr = "_quantization_method";

// FunctionCallOpType to be generated as the function call operator when
// function lifting will happen.
enum FunctionCallOpType { TFPartitionedCallOp = 0, TFXlaCallModuleOp = 1 };
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ void AddPostCalibrationPasses(OpPassManager& pm,
options.enable_per_channel_quantized_weight_ = true;
// For debugging purposes.
options.mlir_dump_file_name_ = "quantize_composite_functions";
options.enable_weight_only_ = false;
options.merge_fusion_with_dequantize_ =
pipeline_config.merge_fusion_with_dequantize();

Expand Down Expand Up @@ -101,7 +100,6 @@ void AddWeightOnlyQuantizationPasses(
QuantizeCompositeFunctionsPassOptions options;
// For debugging purposes.
options.mlir_dump_file_name_ = "quantize_composite_functions";
options.enable_weight_only_ = true;
pm.addPass(createQuantizeCompositeFunctionsPass(options));

// Add an inliner pass to inline quantized StableHLO functions.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,23 +97,13 @@ class InsertWeightParamPattern
return false;
}
Operation* user = operand.getOwner();
if (isa<TF::XlaCallModuleOp>(user)) {
auto call_op = cast<TF::XlaCallModuleOp>(user);
const StringRef function_name = GetEntryFunctionName(call_op);
const bool is_conv_or_dot = function_name.contains("conv") ||
function_name.contains("dot_general");
const bool has_quant_trait = HasQuantizableTrait(call_op);
return is_conv_or_dot && has_quant_trait;
}
return false;
return IsWeightOnlyQuantizableOp(*user);
}

void rewrite(Operation* op, PatternRewriter& rewriter) const override {
Operation* quantizable_op = *op->getUsers().begin();
DenseFPElementsAttr attr;
if (!matchPattern(op->getResult(0), m_Constant(&attr))) {
return;
}
matchPattern(op->getResult(0), m_Constant(&attr));
auto quant_type =
quant::GetUniformQuantizedTypeForWeight(
attr, /*symmetric=*/false, /*num_bits=*/8, /*is_signed=*/true,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,6 @@ def QuantizeCompositeFunctionsPass : Pass<"stablehlo-quantize-composite-function
Option<"mlir_dump_file_name_", "mlir-dump-file-name",
"std::optional<std::string>", /*default=*/"std::nullopt",
"MLIR dump file name.">,
Option<"enable_weight_only_",
"enable-weight-only",
"bool", /*default=*/"false",
"Whether to produce weight-only quantized op for convolution and dot_general op.">,
Option<"merge_fusion_with_dequantize_",
"merge-fusion-with-dequantize",
"bool", /*default=*/"false",
Expand Down Expand Up @@ -106,10 +102,6 @@ def QuantizePass : Pass<"stablehlo-quantize", "mlir::ModuleOp"> {
"enable-per-channel-quantized-weight",
"bool", /*default=*/"true",
"Whether to enable per-channel quantized weights.">,
Option<"enable_weight_only_",
"enable-weight-only",
"bool", /*default=*/"false",
"Whether to produce weight-only quantized op for convolution and dot_general op.">,
];
let dependentDialects = [
"mlir::stablehlo::StablehloDialect",
Expand Down
Loading

0 comments on commit 603690b

Please sign in to comment.