diff --git a/ai_edge_quantizer/algorithm_manager.py b/ai_edge_quantizer/algorithm_manager.py index 046c69e..050c2c0 100644 --- a/ai_edge_quantizer/algorithm_manager.py +++ b/ai_edge_quantizer/algorithm_manager.py @@ -16,10 +16,12 @@ """Quantizer Algorithm Manager Interface.""" import enum +import functools from ai_edge_quantizer import algorithm_manager_api from ai_edge_quantizer import default_policy from ai_edge_quantizer import qtyping from ai_edge_quantizer.algorithms.nonlinear_quantize import float_casting +from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize _TFLOpName = qtyping.TFLOperationName @@ -54,7 +56,7 @@ class AlgorithmName(str, enum.Enum): # Register MIN_MAX_UNIFORM_QUANT algorithm. register_op_quant_config_validation_func( AlgorithmName.MIN_MAX_UNIFORM_QUANT, - naive_min_max_quantize.check_op_quantization_config, + common_quantize.check_op_quantization_config, ) # Register a config check policy for MIN_MAX_UNIFORM_QUANT algorithm. @@ -63,71 +65,44 @@ class AlgorithmName(str, enum.Enum): default_policy.DEFAULT_CONFIG_CHECK_POLICY, ) - -for op_name, materialize_func in zip( - ( - _TFLOpName.INPUT, - _TFLOpName.OUTPUT, - _TFLOpName.FULLY_CONNECTED, - _TFLOpName.BATCH_MATMUL, - _TFLOpName.CONV_2D, - _TFLOpName.DEPTHWISE_CONV_2D, - _TFLOpName.CONV_2D_TRANSPOSE, - _TFLOpName.RESHAPE, - _TFLOpName.AVERAGE_POOL_2D, - _TFLOpName.EMBEDDING_LOOKUP, - _TFLOpName.SOFTMAX, - _TFLOpName.TANH, - _TFLOpName.TRANSPOSE, - _TFLOpName.GELU, - _TFLOpName.ADD, - _TFLOpName.SUB, - _TFLOpName.MUL, - _TFLOpName.MEAN, - _TFLOpName.RSQRT, - _TFLOpName.CONCATENATION, - _TFLOpName.STRIDED_SLICE, - _TFLOpName.SPLIT, - _TFLOpName.LOGISTIC, # Sigmoid - _TFLOpName.SLICE, - _TFLOpName.SUM, - _TFLOpName.SELECT_V2, - ), - ( - naive_min_max_quantize.materialize_input, - naive_min_max_quantize.materialize_output, - naive_min_max_quantize.materialize_fc_conv, - naive_min_max_quantize.materialize_batch_matmul, - naive_min_max_quantize.materialize_fc_conv, - naive_min_max_quantize.materialize_fc_conv, - naive_min_max_quantize.materialize_conv2d_transpose, - naive_min_max_quantize.materialize_reshape, - naive_min_max_quantize.materialize_average_pool_2d, - naive_min_max_quantize.materialize_embedding_lookup, - naive_min_max_quantize.materialize_softmax_and_logistic, - naive_min_max_quantize.materialize_tanh, - naive_min_max_quantize.materialize_transpose, - naive_min_max_quantize.materialize_gelu, - naive_min_max_quantize.materialize_add, - naive_min_max_quantize.materialize_sub, - naive_min_max_quantize.materialize_mul, - naive_min_max_quantize.materialize_mean, - naive_min_max_quantize.materialize_rsqrt, - naive_min_max_quantize.materialize_concatenation, - naive_min_max_quantize.materialize_strided_slice, - naive_min_max_quantize.materialize_split, - naive_min_max_quantize.materialize_softmax_and_logistic, - naive_min_max_quantize.materialize_slice, - naive_min_max_quantize.materialize_sum, - naive_min_max_quantize.materialize_select_v2, - ), -): +MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT = { + _TFLOpName.INPUT: common_quantize.materialize_input, + _TFLOpName.OUTPUT: common_quantize.materialize_output, + _TFLOpName.FULLY_CONNECTED: common_quantize.materialize_fc_conv, + _TFLOpName.BATCH_MATMUL: common_quantize.materialize_batch_matmul, + _TFLOpName.CONV_2D: common_quantize.materialize_fc_conv, + _TFLOpName.DEPTHWISE_CONV_2D: common_quantize.materialize_fc_conv, + _TFLOpName.CONV_2D_TRANSPOSE: 
common_quantize.materialize_conv2d_transpose, + _TFLOpName.RESHAPE: common_quantize.materialize_reshape, + _TFLOpName.AVERAGE_POOL_2D: common_quantize.materialize_average_pool_2d, + _TFLOpName.EMBEDDING_LOOKUP: common_quantize.materialize_embedding_lookup, + _TFLOpName.SOFTMAX: common_quantize.materialize_softmax_and_logistic, + _TFLOpName.TANH: common_quantize.materialize_tanh, + _TFLOpName.TRANSPOSE: common_quantize.materialize_transpose, + _TFLOpName.GELU: common_quantize.materialize_gelu, + _TFLOpName.ADD: common_quantize.materialize_add, + _TFLOpName.SUB: common_quantize.materialize_sub, + _TFLOpName.MUL: common_quantize.materialize_mul, + _TFLOpName.MEAN: common_quantize.materialize_mean, + _TFLOpName.RSQRT: common_quantize.materialize_rsqrt, + _TFLOpName.CONCATENATION: common_quantize.materialize_concatenation, + _TFLOpName.STRIDED_SLICE: common_quantize.materialize_strided_slice, + _TFLOpName.SPLIT: common_quantize.materialize_split, + _TFLOpName.LOGISTIC: common_quantize.materialize_softmax_and_logistic, + _TFLOpName.SLICE: common_quantize.materialize_slice, + _TFLOpName.SUM: common_quantize.materialize_sum, + _TFLOpName.SELECT_V2: common_quantize.materialize_select_v2, +} +for op_name, materialize_func in MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT.items(): register_quantized_op( AlgorithmName.MIN_MAX_UNIFORM_QUANT, op_name, naive_min_max_quantize.init_qsvs, calibration_func=naive_min_max_quantize.min_max_calibrate, - materialize_func=materialize_func, + materialize_func=functools.partial( + materialize_func, + naive_min_max_quantize.get_tensor_quant_params, + ), ) # Register FLOAT_CASTING algorithm. diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py b/ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py new file mode 100644 index 0000000..d058193 --- /dev/null +++ b/ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py @@ -0,0 +1,629 @@ +# Copyright 2024 The AI Edge Quantizer Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Quantization helpers common to all uniform quantization algorithms.""" + +from typing import Any, Callable +import numpy as np +from ai_edge_quantizer import qtyping +from ai_edge_quantizer.algorithms.uniform_quantize import uniform_quantize_tensor +from ai_edge_quantizer.algorithms.utils import common_utils +from ai_edge_quantizer.utils import tfl_flatbuffer_utils + +_TFLOpName = qtyping.TFLOperationName +_QuantTransformation = qtyping.QuantTransformation +_OpQuantConstraint = common_utils.OpQuantConstraint +_ComputePrecision = qtyping.ComputePrecision + + +def check_op_quantization_config( + op_name: _TFLOpName, + op_quant_config: qtyping.OpQuantizationConfig, + config_check_policy: qtyping.ConfigCheckPolicyDict, +) -> None: + """Checks the op quantization config. + + Args: + op_name: The name of the op. + op_quant_config: The quantization config for the op. 
+ config_check_policy: The policy to check the op quantization config. + + Raises: + ValueError: If the op quantization config is invalid. + """ + if op_quant_config.weight_tensor_config is None: + raise ValueError( + "Weight tensor quantization is required for min/max uniform" + " quantization." + ) + if op_quant_config.weight_tensor_config.dtype != qtyping.TensorDataType.INT: + raise ValueError( + "Weights need to have integer type for min/max uniform quantization. If" + " you wish to perform float casting quantization (e.g., fp16 weight" + " only), please set algorithm key as 'float_casting'." + ) + + if op_quant_config.min_weight_elements < 0: + raise ValueError( + f"min_weight_elements must be non-negative for op: {op_name} with" + f" config: {op_quant_config}." + ) + + if op_quant_config.compute_precision in [ + _ComputePrecision.INTEGER, + _ComputePrecision.FLOAT, + ]: + # Use policy-based mechanism to validate op. + common_utils.check_if_valid_op_config( + op_name, op_quant_config, config_check_policy + ) + common_utils.check_subchannel_config(op_name, op_quant_config) + + +def materialize_input( + get_tensor_quant_params_fn: Callable[..., Any], + op_info: qtyping.OpInfo, + graph_info: qtyping.GraphInfo, + tensor_name_to_qsv: dict[str, Any], +) -> list[qtyping.TensorTransformationParams]: + """Materialize tensors in the virtual input op.""" + return common_utils.materialize_standard_op( + op_info, + graph_info, + tensor_name_to_qsv, + get_tensor_quant_params_fn, + ) + + +def materialize_output( + get_tensor_quant_params_fn: Callable[..., Any], + op_info: qtyping.OpInfo, + graph_info: qtyping.GraphInfo, + tensor_name_to_qsv: dict[str, Any], +) -> list[qtyping.TensorTransformationParams]: + """Materialize tensors in the virtual output op.""" + return common_utils.materialize_standard_op( + op_info, + graph_info, + tensor_name_to_qsv, + get_tensor_quant_params_fn, + ) + + +def materialize_add( + get_tensor_quant_params_fn: Callable[..., Any], + op_info: qtyping.OpInfo, + graph_info: qtyping.GraphInfo, + tensor_name_to_qsv: dict[str, Any], +) -> list[qtyping.TensorTransformationParams]: + """Materialize tensors in tfl.add.""" + return common_utils.materialize_standard_op( + op_info, + graph_info, + tensor_name_to_qsv, + get_tensor_quant_params_fn, + ) + + +def materialize_sub( + get_tensor_quant_params_fn: Callable[..., Any], + op_info: qtyping.OpInfo, + graph_info: qtyping.GraphInfo, + tensor_name_to_qsv: dict[str, Any], +) -> list[qtyping.TensorTransformationParams]: + """Materialize tensors in tfl.sub.""" + return common_utils.materialize_standard_op( + op_info, + graph_info, + tensor_name_to_qsv, + get_tensor_quant_params_fn, + ) + + +def materialize_mul( + get_tensor_quant_params_fn: Callable[..., Any], + op_info: qtyping.OpInfo, + graph_info: qtyping.GraphInfo, + tensor_name_to_qsv: dict[str, Any], +) -> list[qtyping.TensorTransformationParams]: + """Materialize tensors in tfl.mul.""" + return common_utils.materialize_standard_op( + op_info, + graph_info, + tensor_name_to_qsv, + get_tensor_quant_params_fn, + ) + + +def materialize_softmax_and_logistic( + get_tensor_quant_params_fn: Callable[..., Any], + op_info: qtyping.OpInfo, + graph_info: qtyping.GraphInfo, + tensor_name_to_qsv: dict[str, Any], +) -> list[qtyping.TensorTransformationParams]: + """Materialize tensors in tfl.softmax and tfl.logistic.""" + # Hard code scales and zp values as they are hard coded in TFL kernels. 
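+ # For int8, scale 1/256 with zero point -128 maps the kernel's [0, 1) + # output range onto the full [-128, 127] range (q = 256 * x - 128); for + # int16, scale 1/32768 with zero point 0 represents [0, 1) in the + # non-negative half of the symmetric range.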
+ # Softmax: + # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/activations.cc#L548 + # Logistic: + # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/activations.cc#L421 + output_activation_constraints = { + 8: qtyping.UniformQuantParams( + num_bits=8, + quantized_dimension=None, + scale=np.array(1.0 / 256), + zero_point=np.array(-128), + symmetric=False, + ), + 16: qtyping.UniformQuantParams( + num_bits=16, + quantized_dimension=None, + scale=np.array(1.0 / 32768), + zero_point=np.array(0), + ), + } + + return common_utils.materialize_op_with_output_activation_constraint( + op_info, + graph_info, + tensor_name_to_qsv, + output_activation_constraints, + get_tensor_quant_params_fn, + ) + + +def materialize_batch_matmul( + get_tensor_quant_params_fn: Callable[..., Any], + op_info: qtyping.OpInfo, + graph_info: qtyping.GraphInfo, + tensor_name_to_qsv: dict[str, Any], +) -> list[qtyping.TensorTransformationParams]: + """Materialize tensors in tfl.batch_matmul.""" + return common_utils.materialize_standard_op( + op_info, + graph_info, + tensor_name_to_qsv, + get_tensor_quant_params_fn, + ) + + +def materialize_embedding_lookup( + get_tensor_quant_params_fn: Callable[..., Any], + op_info: qtyping.OpInfo, + graph_info: qtyping.GraphInfo, + tensor_name_to_qsv: dict[str, Any], +) -> list[qtyping.TensorTransformationParams]: + """Materialize tensors in tfl.embedding_lookup.""" + return common_utils.materialize_standard_op( + op_info, + graph_info, + tensor_name_to_qsv, + get_tensor_quant_params_fn, + inputs_to_ignore=[0], # Lookup index does not need to be quantized. + ) + + +def materialize_reshape( + get_tensor_quant_params_fn: Callable[..., Any], + op_info: qtyping.OpInfo, + graph_info: qtyping.GraphInfo, + tensor_name_to_qsv: dict[str, Any], +) -> list[qtyping.TensorTransformationParams]: + """Materialize tensors in tfl.reshape.""" + return common_utils.materialize_standard_op( + op_info, + graph_info, + tensor_name_to_qsv, + get_tensor_quant_params_fn, + constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE, + inputs_to_ignore=[1], # Shape tensor does not need to be quantized. + ) + + +def materialize_average_pool_2d( + get_tensor_quant_params_fn: Callable[..., Any], + op_info: qtyping.OpInfo, + graph_info: qtyping.GraphInfo, + tensor_name_to_qsv: dict[str, Any], +) -> list[qtyping.TensorTransformationParams]: + """Materialize tensors in tfl.average_pool_2d.""" + return common_utils.materialize_standard_op( + op_info, + graph_info, + tensor_name_to_qsv, + get_tensor_quant_params_fn, + constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE, + ) + + +def _materialize_bias_for_conv_ops( + op_info: qtyping.OpInfo, + graph_info: qtyping.GraphInfo, + op_tensor_params: list[qtyping.TensorTransformationParams], + op_input_index: int = 0, + op_weight_index: int = 1, + op_bias_index: int = 2, +): + """Materializes bias tensors in conv ops by updating `op_tensor_params`. + + Args: + op_info: Aggregated information about the op (e.g., quantization config). + graph_info: Graph information needed to perform quantization for the op. + op_tensor_params: Partially populated quantization configuration for the + tensors associated with the op in the order of input, weight, output. + op_input_index: Index for the input tensor in the op. + op_weight_index: Index for the weight tensor in the op. + op_bias_index: Index for the bias tensor in the op. 
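+ + Note: under SRQ the fused bias is quantized symmetrically (zero point 0) + with a scale derived from the input and weight scales (see + uniform_quantize_tensor.symmetric_quantize_bias_tensor); for DRQ and + weight-only modes the bias is left in float.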
+ """ + _, _, bias_tensor, _ = tfl_flatbuffer_utils.parse_fc_bmm_conv_tensors( + op_info.op, + graph_info.subgraph_tensors, + op_input_index, + op_weight_index, + op_bias_index, + ) + if bias_tensor is not None: + bias_quant_params = None + # Fused bias needs to be quantized for SRQ. + # Check if SRQ. + if ( + op_info.op_quant_config.compute_precision == _ComputePrecision.INTEGER + and op_info.op_quant_config.activation_tensor_config is not None + ): + bias_content = tfl_flatbuffer_utils.get_tensor_data( + bias_tensor, + graph_info.buffers, + ) + bias_quant_params = ( + uniform_quantize_tensor.symmetric_quantize_bias_tensor( + bias_content, + op_tensor_params[op_input_index].consumers[0].parameters, + op_tensor_params[op_weight_index].consumers[0].parameters, + ) + ) + # We only quantize bias under SRQ. Setting is_constant=True for SRQ only + # to avoid quantize bias for DRQ and weight-only cases. + is_constant = ( + # Check if SRQ. + op_info.op_quant_config.compute_precision == _ComputePrecision.INTEGER + and op_info.op_quant_config.activation_tensor_config is not None + ) + op_tensor_params[op_bias_index] = ( + common_utils.get_tensor_transformation_params( + tfl_flatbuffer_utils.get_tensor_name(bias_tensor), + op_info, + is_inbounding_tensor=True, + quant_params=bias_quant_params, + is_constant=is_constant, + ) + ) + + +def _are_weights_too_small( + op_info: qtyping.OpInfo, + graph_info: qtyping.GraphInfo, + weight_index: int, +) -> bool: + """Checks if weights are too small to be quantized.""" + tensor = graph_info.subgraph_tensors[op_info.op.inputs[weight_index]] + tensor_data = tfl_flatbuffer_utils.get_tensor_data( + tensor, + graph_info.buffers, + ) + return ( + tensor_data is not None + and np.size(tensor_data) < op_info.op_quant_config.min_weight_elements + ) + + +def materialize_slice( + get_tensor_quant_params_fn: Callable[..., Any], + op_info: qtyping.OpInfo, + graph_info: qtyping.GraphInfo, + tensor_name_to_qsv: dict[str, Any], +) -> list[qtyping.TensorTransformationParams]: + """Materialize tensors in tfl.slice.""" + return common_utils.materialize_standard_op( + op_info, + graph_info, + tensor_name_to_qsv, + get_tensor_quant_params_fn, + constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE, + inputs_to_ignore=[ + 1, + 2, + ], # Begin and size indices do not need to be quantized. + ) + + +def materialize_select_v2( + get_tensor_quant_params_fn: Callable[..., Any], + op_info: qtyping.OpInfo, + graph_info: qtyping.GraphInfo, + tensor_name_to_qsv: dict[str, Any], +) -> list[qtyping.TensorTransformationParams]: + """Materialize tensors in tfl.select_v2.""" + return common_utils.materialize_standard_op( + op_info, + graph_info, + tensor_name_to_qsv, + get_tensor_quant_params_fn, + constraint=_OpQuantConstraint.SAME_AS_OUTPUT_SCALE, + inputs_to_ignore=[ + 0, + ], # Condition tensor does not need to be quantized. + ) + + +def materialize_sum( + get_tensor_quant_params_fn: Callable[..., Any], + op_info: qtyping.OpInfo, + graph_info: qtyping.GraphInfo, + tensor_name_to_qsv: dict[str, Any], +) -> list[qtyping.TensorTransformationParams]: + """Materialize tensors in tfl.sum.""" + return common_utils.materialize_standard_op( + op_info, + graph_info, + tensor_name_to_qsv, + get_tensor_quant_params_fn, + constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE, + inputs_to_ignore=[1], # Axis index does not need to be quantized. 
+ ) + + +def materialize_fc_conv( + get_tensor_quant_params_fn: Callable[..., Any], + op_info: qtyping.OpInfo, + graph_info: qtyping.GraphInfo, + tensor_name_to_qsv: dict[str, Any], + input_index: int = 0, + weight_index: int = 1, + bias_index: int = 2, +) -> list[qtyping.TensorTransformationParams]: + """Materialize tensors in fully_connected, conv_2d and depthwise_conv_2d. + + Args: + get_tensor_quant_params_fn: A function to get the quantization parameters + for a tensor. + op_info: Aggregated information about the op (e.g., quantization config). + graph_info: Graph information needed to perform quantization for the op. + tensor_name_to_qsv: A map of tensor name to quantization parameters. + input_index: Index for the input tensor in the op. + weight_index: Index for the weight tensor in the op. + bias_index: Index for the bias tensor in the op. + + Returns: + Quantization configuration for the tensors associated with the op (e.g., + weights, bias). + """ + ignored_inputs = [bias_index] # Bias tensor is quantized separately. + if _are_weights_too_small(op_info, graph_info, weight_index): + ignored_inputs.append(weight_index) + + op_tensor_params = common_utils.materialize_standard_op( + op_info, + graph_info, + tensor_name_to_qsv, + get_tensor_quant_params_fn, + inputs_to_ignore=ignored_inputs, + ) + + _materialize_bias_for_conv_ops( + op_info, + graph_info, + op_tensor_params, + op_input_index=input_index, + op_weight_index=weight_index, + op_bias_index=bias_index, + ) + + return op_tensor_params + + +def materialize_conv2d_transpose( + get_tensor_quant_params_fn: Callable[..., Any], + op_info: qtyping.OpInfo, + graph_info: qtyping.GraphInfo, + tensor_name_to_qsv: dict[str, Any], +) -> list[qtyping.TensorTransformationParams]: + """Materialize tensors in tfl.conv2d_transpose. + + Args: + get_tensor_quant_params_fn: A function to get the quantization parameters + for a tensor. + op_info: Aggregated information about the op (e.g., quantization config). + graph_info: Graph information needed to perform quantization for the op. + tensor_name_to_qsv: A map of tensor name to quantization parameters. + + Returns: + Quantization configuration for the tensors associated with the op (e.g., + weights, bias). + """ + ignored_shape_index = 0 + weight_index = 1 + input_index = 2 + bias_index = 3 + + ignored_inputs = [ + ignored_shape_index, + bias_index, # Bias tensor is quantized separately. + ] + if _are_weights_too_small(op_info, graph_info, weight_index): + ignored_inputs.append(weight_index) + + op_tensor_params = common_utils.materialize_standard_op( + op_info, + graph_info, + tensor_name_to_qsv, + get_tensor_quant_params_fn, + inputs_to_ignore=ignored_inputs, + ) + if len(op_tensor_params) < 2: + raise ValueError( + "Materialize standard op should return at least two tensors for" + " conv2d_transpose." 
+ ) + _materialize_bias_for_conv_ops( + op_info, + graph_info, + op_tensor_params, + op_input_index=input_index, + op_weight_index=weight_index, + op_bias_index=bias_index, + ) + + return op_tensor_params + + +def materialize_tanh( + get_tensor_quant_params_fn: Callable[..., Any], + op_info: qtyping.OpInfo, + graph_info: qtyping.GraphInfo, + tensor_name_to_qsv: dict[str, Any], +) -> list[qtyping.TensorTransformationParams]: + """Materialize tensors in tfl.tanh.""" + # Hard code scales and zero point values as they are hard coded in: + # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/ir/tfl_ops.td#L3430 + output_activation_constraints = {} + for num_bits in [8, 16]: + output_activation_constraints[num_bits] = qtyping.UniformQuantParams( + num_bits=num_bits, + quantized_dimension=None, + scale=np.array(1.0 / (1 << (num_bits - 1))), + zero_point=np.array(0), + # Activation is always asymmetric for 8 bit and symmetric for 16 bits. + symmetric=num_bits == 16, + ) + return common_utils.materialize_op_with_output_activation_constraint( + op_info, + graph_info, + tensor_name_to_qsv, + output_activation_constraints, + get_tensor_quant_params_fn, + ) + + +def materialize_transpose( + get_tensor_quant_params_fn: Callable[..., Any], + op_info: qtyping.OpInfo, + graph_info: qtyping.GraphInfo, + tensor_name_to_qsv: dict[str, Any], +) -> list[qtyping.TensorTransformationParams]: + """Materialize tensors in tfl.transpose.""" + return common_utils.materialize_standard_op( + op_info, + graph_info, + tensor_name_to_qsv, + get_tensor_quant_params_fn, + constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE, + inputs_to_ignore=[1], # Permutation tensor does not need to be quantized. + ) + + +def materialize_gelu( + get_tensor_quant_params_fn: Callable[..., Any], + op_info: qtyping.OpInfo, + graph_info: qtyping.GraphInfo, + tensor_name_to_qsv: dict[str, Any], +) -> list[qtyping.TensorTransformationParams]: + """Materialize tensors in tfl.gelu.""" + return common_utils.materialize_standard_op( + op_info, + graph_info, + tensor_name_to_qsv, + get_tensor_quant_params_fn, + ) + + +def materialize_strided_slice( + get_tensor_quant_params_fn: Callable[..., Any], + op_info: qtyping.OpInfo, + graph_info: qtyping.GraphInfo, + tensor_name_to_qsv: dict[str, Any], +) -> list[qtyping.TensorTransformationParams]: + """Materialize tensors in tfl.strided_slice.""" + return common_utils.materialize_standard_op( + op_info, + graph_info, + tensor_name_to_qsv, + get_tensor_quant_params_fn, + constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE, + inputs_to_ignore=[1, 2, 3], # Ignore the begin, end, and strides tensors. + ) + + +def materialize_mean( + get_tensor_quant_params_fn: Callable[..., Any], + op_info: qtyping.OpInfo, + graph_info: qtyping.GraphInfo, + tensor_name_to_qsv: dict[str, Any], +) -> list[qtyping.TensorTransformationParams]: + """Materialize tensors in tfl.mean.""" + return common_utils.materialize_standard_op( + op_info, + graph_info, + tensor_name_to_qsv, + get_tensor_quant_params_fn, + inputs_to_ignore=[1], # Axis tensor does not need to be quantized. 
+ ) + + +def materialize_rsqrt( + get_tensor_quant_params_fn: Callable[..., Any], + op_info: qtyping.OpInfo, + graph_info: qtyping.GraphInfo, + tensor_name_to_qsv: dict[str, Any], +) -> list[qtyping.TensorTransformationParams]: + """Materialize tensors in tfl.rsqrt.""" + return common_utils.materialize_standard_op( + op_info, + graph_info, + tensor_name_to_qsv, + get_tensor_quant_params_fn, + ) + + +def materialize_concatenation( + get_tensor_quant_params_fn: Callable[..., Any], + op_info: qtyping.OpInfo, + graph_info: qtyping.GraphInfo, + tensor_name_to_qsv: dict[str, Any], +) -> list[qtyping.TensorTransformationParams]: + """Materialize tensors in tfl.concatenation.""" + return common_utils.materialize_standard_op( + op_info, + graph_info, + tensor_name_to_qsv, + get_tensor_quant_params_fn, + constraint=_OpQuantConstraint.SAME_AS_OUTPUT_SCALE, + ) + + +def materialize_split( + get_tensor_quant_params_fn: Callable[..., Any], + op_info: qtyping.OpInfo, + graph_info: qtyping.GraphInfo, + tensor_name_to_qsv: dict[str, Any], +) -> list[qtyping.TensorTransformationParams]: + """Materialize tensors in tfl.split.""" + return common_utils.materialize_standard_op( + op_info, + graph_info, + tensor_name_to_qsv, + get_tensor_quant_params_fn, + constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE, + inputs_to_ignore=[0], # Split dimension does not need to be quantized. + ) diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py new file mode 100644 index 0000000..610f31a --- /dev/null +++ b/ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py @@ -0,0 +1,74 @@ +# Copyright 2024 The AI Edge Quantizer Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import os + +from absl.testing import parameterized +import numpy as np + +from tensorflow.python.platform import googletest +from ai_edge_quantizer import default_policy +from ai_edge_quantizer import qtyping +from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize +from ai_edge_quantizer.utils import test_utils +from ai_edge_quantizer.utils import tfl_flatbuffer_utils + +_TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile("../../tests/models") +_TFLOpName = qtyping.TFLOperationName +_TensorQuantConfig = qtyping.TensorQuantizationConfig + + +class CommonQuantizeTest(parameterized.TestCase): + """Tests for general quantize functions. + """ + + def setUp(self): + super().setUp() + np.random.seed(666) + self._test_model_path = os.path.join( + _TEST_DATA_PREFIX_PATH, "conv_fc_mnist.tflite" + ) + self._test_model = tfl_flatbuffer_utils.read_model(self._test_model_path) + # The test model has one subgraph for now. 
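+ # GraphInfo bundles the subgraph tensors and model buffers that the + # materialize and init functions read tensor data from.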
+ self._graph_info = qtyping.GraphInfo( + subgraph_tensors=self._test_model.subgraphs[0].tensors, + buffers=self._test_model.buffers, + ) + self._tensor_name_to_qsv = {} + + def test_check_op_quantization_config_with_negative_min_weight_elements_raises_error( + self, + ): + op_quant_config = qtyping.OpQuantizationConfig( + weight_tensor_config=_TensorQuantConfig( + num_bits=8, + granularity=qtyping.QuantGranularity.CHANNELWISE, + ), + compute_precision=qtyping.ComputePrecision.INTEGER, # DRQ. + min_weight_elements=-1, + ) + with self.assertRaisesWithPredicateMatch( + ValueError, + lambda err: "min_weight_elements must be non-negative" in str(err), + ): + common_quantize.check_op_quantization_config( + _TFLOpName.FULLY_CONNECTED, + op_quant_config, + default_policy.DEFAULT_CONFIG_CHECK_POLICY, + ) + + +if __name__ == "__main__": + googletest.main() diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py index cfe6b4c..dec7ad0 100644 --- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py @@ -19,559 +19,213 @@ import numpy as np from ai_edge_quantizer import qtyping from ai_edge_quantizer.algorithms.uniform_quantize import uniform_quantize_tensor -from ai_edge_quantizer.algorithms.utils import min_max_quantize_utils as utils +from ai_edge_quantizer.algorithms.utils import common_utils from ai_edge_quantizer.utils import tfl_flatbuffer_utils ALGORITHM_KEY = "min_max_uniform_quantize" _TFLOpName = qtyping.TFLOperationName _QuantTransformation = qtyping.QuantTransformation -_OpQuantConstraint = utils.OpQuantConstraint -_ComputePrecision = qtyping.ComputePrecision +_IntType = uniform_quantize_tensor.IntType -def check_op_quantization_config( - op_name: _TFLOpName, - op_quant_config: qtyping.OpQuantizationConfig, - config_check_policy: qtyping.ConfigCheckPolicyDict, -) -> None: - """Checks the op quantization config. - - Args: - op_name: The name of the op. - op_quant_config: The quantization config for the op. - config_check_policy: The policy to check the op quantization config. - - Raises: - ValueError: If the op quantization config is invalid. - """ - if op_quant_config.weight_tensor_config is None: - raise ValueError( - "Weight tensor quantization is required for min/max uniform" - " quantization." - ) - if op_quant_config.weight_tensor_config.dtype != qtyping.TensorDataType.INT: - raise ValueError( - "Weights need to have integer type for min/max uniform quantization. If" - " you wish to perform float casting quantization (e.g., fp16 weight" - " only), please set algorithm key as 'float_casting'." - ) - - if op_quant_config.min_weight_elements < 0: - raise ValueError( - f"min_weight_elements must be non-negative for op: {op_name} with" - f" config: {op_quant_config}." - ) - - if op_quant_config.compute_precision in [ - _ComputePrecision.INTEGER, - _ComputePrecision.FLOAT, - ]: - # Use policy-based mechanism to validate op. 
- utils.check_if_valid_op_config( - op_name, op_quant_config, config_check_policy - ) - utils.check_subchannel_config(op_name, op_quant_config) - - -def materialize_input( - op_info: qtyping.OpInfo, - graph_info: qtyping.GraphInfo, - tensor_name_to_qsv: dict[str, Any], -) -> list[qtyping.TensorTransformationParams]: - """Materialize tensors in the virtual input op.""" - return utils.materialize_standard_op( - op_info, - graph_info, - tensor_name_to_qsv, - ) - - -def materialize_output( - op_info: qtyping.OpInfo, - graph_info: qtyping.GraphInfo, - tensor_name_to_qsv: dict[str, Any], -) -> list[qtyping.TensorTransformationParams]: - """Materialize tensors in the virtual output op.""" - return utils.materialize_standard_op( - op_info, - graph_info, - tensor_name_to_qsv, - ) - - -def materialize_add( +def _init_tensor_min_max( + tensor_data: Optional[np.ndarray], op_info: qtyping.OpInfo, - graph_info: qtyping.GraphInfo, - tensor_name_to_qsv: dict[str, Any], -) -> list[qtyping.TensorTransformationParams]: - """Materialize tensors in tfl.add.""" - return utils.materialize_standard_op( - op_info, - graph_info, - tensor_name_to_qsv, - ) - - -def materialize_sub( - op_info: qtyping.OpInfo, - graph_info: qtyping.GraphInfo, - tensor_name_to_qsv: dict[str, Any], -) -> list[qtyping.TensorTransformationParams]: - """Materialize tensors in tfl.sub.""" - return utils.materialize_standard_op( - op_info, - graph_info, - tensor_name_to_qsv, - ) - - -def materialize_mul( - op_info: qtyping.OpInfo, - graph_info: qtyping.GraphInfo, - tensor_name_to_qsv: dict[str, Any], -) -> list[qtyping.TensorTransformationParams]: - """Materialize tensors in tfl.mul.""" - return utils.materialize_standard_op( - op_info, - graph_info, - tensor_name_to_qsv, - ) - - -def materialize_softmax_and_logistic( - op_info: qtyping.OpInfo, - graph_info: qtyping.GraphInfo, - tensor_name_to_qsv: dict[str, Any], -) -> list[qtyping.TensorTransformationParams]: - """Materialize tensors in tfl.softmax and tfl.logistic.""" - # Hard code scales and zp values as they are hard coded in TFL kernels. - # Softmax: - # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/activations.cc#L548 - # Logistic: - # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/activations.cc#L421 - output_activation_constraints = { - 8: qtyping.UniformQuantParams( - num_bits=8, - quantized_dimension=None, - scale=np.array(1.0 / 256), - zero_point=np.array(-128), - symmetric=False, - ), - 16: qtyping.UniformQuantParams( - num_bits=16, - quantized_dimension=None, - scale=np.array(1.0 / 32768), - zero_point=np.array(0), - ), - } - - return utils.materialize_op_with_output_activation_constraint( - op_info, - graph_info, - tensor_name_to_qsv, - output_activation_constraints, - ) - - -def materialize_batch_matmul( - op_info: qtyping.OpInfo, - graph_info: qtyping.GraphInfo, - tensor_name_to_qsv: dict[str, Any], -) -> list[qtyping.TensorTransformationParams]: - """Materialize tensors in tfl.batch_matmul.""" - return utils.materialize_standard_op( - op_info, - graph_info, - tensor_name_to_qsv, - ) - - -def materialize_embedding_lookup( - op_info: qtyping.OpInfo, - graph_info: qtyping.GraphInfo, - tensor_name_to_qsv: dict[str, Any], -) -> list[qtyping.TensorTransformationParams]: - """Materialize tensors in tfl.embedding_lookup.""" - return utils.materialize_standard_op( - op_info, - graph_info, - tensor_name_to_qsv, - inputs_to_ignore=[0], # Lookup index does not need to be quantized. 
- ) - - -def materialize_reshape( - op_info: qtyping.OpInfo, - graph_info: qtyping.GraphInfo, - tensor_name_to_qsv: dict[str, Any], -) -> list[qtyping.TensorTransformationParams]: - """Materialize tensors in tfl.reshape.""" - return utils.materialize_standard_op( - op_info, - graph_info, - tensor_name_to_qsv, - constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE, - inputs_to_ignore=[1], # Shape tensor does not need to be quantized. - ) - - -def materialize_average_pool_2d( - op_info: qtyping.OpInfo, - graph_info: qtyping.GraphInfo, - tensor_name_to_qsv: dict[str, Any], -) -> list[qtyping.TensorTransformationParams]: - """Materialize tensors in tfl.average_pool_2d.""" - return utils.materialize_standard_op( - op_info, - graph_info, - tensor_name_to_qsv, - constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE, - ) - - -def _materialize_bias_for_conv_ops( - op_info: qtyping.OpInfo, - graph_info: qtyping.GraphInfo, - op_tensor_params: list[qtyping.TensorTransformationParams], - op_input_index: int = 0, - op_weight_index: int = 1, - op_bias_index: int = 2, -): - """Materializes bias tensors in conv ops by updating `op_tensor_params`. - - Args: - op_info: Aggregated information about the op (e.g., quantization config). - graph_info: Graph information needed to perform quantization for the op. - op_tensor_params: Partially populated quantization configuration for the - tensors associated with the op in the order of input, weight, output. - op_input_index: Index for the input tensor in the op. - op_weight_index: Index for the weight tensor in the op. - op_bias_index: Index for the bias tensor in the op. - """ - _, _, bias_tensor, _ = tfl_flatbuffer_utils.parse_fc_bmm_conv_tensors( - op_info.op, - graph_info.subgraph_tensors, - op_input_index, - op_weight_index, - op_bias_index, - ) - if bias_tensor is not None: - bias_quant_params = None - # Fused bias needs to be quantized for SRQ. - # Check if SRQ. +) -> qtyping.QSV: + """Initialize the min/max for a tensor.""" + if tensor_data is None: + return {} + else: + quantized_dim = None if ( - op_info.op_quant_config.compute_precision == _ComputePrecision.INTEGER - and op_info.op_quant_config.activation_tensor_config is not None + op_info.op_quant_config.weight_tensor_config is not None + and op_info.op_quant_config.weight_tensor_config.granularity + == qtyping.QuantGranularity.BLOCKWISE ): - bias_content = tfl_flatbuffer_utils.get_tensor_data( - bias_tensor, - graph_info.buffers, - ) - bias_quant_params = ( - uniform_quantize_tensor.symmetric_quantize_bias_tensor( - bias_content, - op_tensor_params[op_input_index].consumers[0].parameters, - op_tensor_params[op_weight_index].consumers[0].parameters, - ) + # TODO(b/346612503): emulate subchannel only supports fully connected, + # will skip special handling. Once we have a spec, we can change this. + block_size = op_info.op_quant_config.weight_tensor_config.block_size + # assuming tensor is 2D, which is correct for FULLY_CONNECTED + transposed_tensor_data = np.transpose(tensor_data, (1, 0)) + if transposed_tensor_data.shape[0] % block_size: + raise ValueError( + f"Block size {block_size} does not divide channel dimension" + f" {transposed_tensor_data.shape[0]}." + ) + reshaped_tensor_data = np.reshape( + transposed_tensor_data, + ( + 1, + int(transposed_tensor_data.shape[0] / block_size), + block_size, + transposed_tensor_data.shape[1], + ), ) - # We only quantize bias under SRQ. Setting is_constant=True for SRQ only - # to avoid quantize bias for DRQ and weight-only cases. - is_constant = ( - # Check if SRQ. 
- op_info.op_quant_config.compute_precision == _ComputePrecision.INTEGER - and op_info.op_quant_config.activation_tensor_config is not None - ) - op_tensor_params[op_bias_index] = utils.get_tensor_transformation_params( - tfl_flatbuffer_utils.get_tensor_name(bias_tensor), - op_info, - is_inbounding_tensor=True, - quant_params=bias_quant_params, - is_constant=is_constant, + return { + "min": np.min(reshaped_tensor_data, axis=(0, 1, 2), keepdims=True), + "max": np.max(reshaped_tensor_data, axis=(0, 1, 2), keepdims=True), + } + if ( + op_info.op_quant_config.weight_tensor_config is not None + and op_info.op_quant_config.weight_tensor_config.granularity + == qtyping.QuantGranularity.CHANNELWISE + ): + if op_info.op_name == _TFLOpName.BATCH_MATMUL: + quantized_dim = common_utils.get_bmm_weight_quantized_dim( + tensor_data, adj_y=op_info.op.builtinOptions.adjY + ) + else: + quantized_dim = tfl_flatbuffer_utils.TFL_OP_TO_WEIGHT_QUANTIZED_DIM.get( + op_info.op_name, None + ) + reduce_dims = common_utils.get_reduce_dims( + quantized_dim, list(tensor_data.shape) ) + return { + "min": np.min(tensor_data, axis=reduce_dims, keepdims=True), + "max": np.max(tensor_data, axis=reduce_dims, keepdims=True), + } -def _are_weights_too_small( - op_info: qtyping.OpInfo, - graph_info: qtyping.GraphInfo, - weight_index: int, -) -> bool: - """Checks if weights are too small to be quantized.""" - tensor = graph_info.subgraph_tensors[op_info.op.inputs[weight_index]] - tensor_data = tfl_flatbuffer_utils.get_tensor_data( - tensor, - graph_info.buffers, - ) - return ( - tensor_data is not None - and np.size(tensor_data) < op_info.op_quant_config.min_weight_elements - ) - - -def materialize_slice( - op_info: qtyping.OpInfo, - graph_info: qtyping.GraphInfo, - tensor_name_to_qsv: dict[str, Any], -) -> list[qtyping.TensorTransformationParams]: - """Materialize tensors in tfl.slice.""" - return utils.materialize_standard_op( - op_info, - graph_info, - tensor_name_to_qsv, - constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE, - inputs_to_ignore=[ - 1, - 2, - ], # Begin and size indices do not need to be quantized. - ) - - -def materialize_select_v2( - op_info: qtyping.OpInfo, - graph_info: qtyping.GraphInfo, - tensor_name_to_qsv: dict[str, Any], -) -> list[qtyping.TensorTransformationParams]: - """Materialize tensors in tfl.select_v2.""" - return utils.materialize_standard_op( - op_info, - graph_info, - tensor_name_to_qsv, - constraint=_OpQuantConstraint.SAME_AS_OUTPUT_SCALE, - inputs_to_ignore=[ - 0, - ], # Condition tensor does not need to be quantized. - ) - - -def materialize_sum( - op_info: qtyping.OpInfo, - graph_info: qtyping.GraphInfo, - tensor_name_to_qsv: dict[str, Any], -) -> list[qtyping.TensorTransformationParams]: - """Materialize tensors in tfl.sum.""" - return utils.materialize_standard_op( - op_info, - graph_info, - tensor_name_to_qsv, - constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE, - inputs_to_ignore=[1], # Axis index does not need to be quantized. - ) - - -def materialize_fc_conv( - op_info: qtyping.OpInfo, - graph_info: qtyping.GraphInfo, - tensor_name_to_qsv: dict[str, Any], - input_index: int = 0, - weight_index: int = 1, - bias_index: int = 2, -) -> list[qtyping.TensorTransformationParams]: - """Materialize tensors in fully_connected, conv_2d and depthwise_conv_2d. +def _tensor_zp_scale_from_min_max( + min_value, max_value, num_bits: int, symmetric: bool +): + """Get zero point and scale from min and max value. Args: - op_info: Aggregated information about the op (e.g., quantization config). 
- graph_info: Graph information needed to perform quantization for the op. - tensor_name_to_qsv: A map of tensor name to quantization parameters. - input_index: Index for the input tensor in the op. - weight_index: Index for the weight tensor in the op. - bias_index: Index for the bias tensor in the op. + min_value: The minimum value of the tensor (channel-wise supported). + max_value: The maximum value of the tensor (channel-wise supported). + num_bits: The number of bits of the tensor. + symmetric: Whether the tensor is symmetric. Returns: - Quantization configuration for the tensors associated with the op (e.g., - weights, bias). + The zero point and scale of the tensor. """ - ignored_inputs = [bias_index] # Bias tensor is quantized separately. - if _are_weights_too_small(op_info, graph_info, weight_index): - ignored_inputs.append(weight_index) - - op_tensor_params = utils.materialize_standard_op( - op_info, - graph_info, - tensor_name_to_qsv, - inputs_to_ignore=ignored_inputs, - ) - - _materialize_bias_for_conv_ops( - op_info, - graph_info, - op_tensor_params, - op_input_index=input_index, - op_weight_index=weight_index, - op_bias_index=bias_index, + # TODO: b/332574603 - support unsigned data type. + qtype = uniform_quantize_tensor.IntType( + num_bits, + signed=True, ) - - return op_tensor_params - - -def materialize_conv2d_transpose( + qmin, qmax = uniform_quantize_tensor.get_quantized_range(qtype) + min_bound = 1e-4 # 1e-6 precision for int8 and 1e-8 for int16. + + if symmetric: + bound = np.maximum(np.abs(min_value), np.abs(max_value)) + bound = np.maximum(bound, min_bound) + if not qtype.signed: + half_q = (qmax - 1) / 2 + scale = bound / half_q + zp = np.ones_like(scale) * (half_q + 1) + else: + scale = bound / qmax + zp = np.zeros_like(scale, dtype=np.int32) + + else: + # Include 0 to the range to support zero-padding. + # See: https://arxiv.org/pdf/1712.05877.pdf + # This ensures bound_min <= 0 <= bound_max. + bound_max = np.maximum(max_value, np.zeros_like(max_value)) + bound_min = np.minimum(min_value, np.zeros_like(min_value)) + bound = np.maximum(bound_max - bound_min, min_bound) + scale = bound / (qmax - qmin) + zp = qmin - bound_min / scale + zp = np.rint(zp) + + # It's safe to cast zp to qtype without clipping because we can infer + # qmin <= zp <= qmax from bound_min <= 0 <= bound_max. + zp = uniform_quantize_tensor.assign_quantized_type(zp, qtype) + return zp, scale + + +def get_tensor_quant_params( op_info: qtyping.OpInfo, - graph_info: qtyping.GraphInfo, - tensor_name_to_qsv: dict[str, Any], -) -> list[qtyping.TensorTransformationParams]: - """Materialize tensors in tfl.conv2d_transpose. + tensor_quant_config: qtyping.TensorQuantizationConfig, + tensor_content: Optional[np.ndarray] = None, + tensor_qsv: Optional[dict[str, Any]] = None, +) -> qtyping.UniformQuantParams: + """Get the quantization parameters for a tensor. Args: - op_info: Aggregated information about the op (e.g., quantization config). - graph_info: Graph information needed to perform quantization for the op. - tensor_name_to_qsv: A map of tensor name to quantization parameters. + op_info: aggregated information about the op (e.g., quantization config). + tensor_quant_config: the quantization config for the tensor. + tensor_content: the content of the tensor. + tensor_qsv: a dictionary containing the min/max of the tensor. Returns: - Quantization configuration for the tensors associated with the op (e.g., - weights, bias). + The quantization parameters for the tensor.
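+ + Example (illustrative only; the config values below are placeholders, + not part of this change): + + weight_config = qtyping.TensorQuantizationConfig( + num_bits=8, + symmetric=True, + granularity=qtyping.QuantGranularity.CHANNELWISE, + ) + params = get_tensor_quant_params(op_info, weight_config, weight_data) + + Note: algorithm_manager binds this function to the shared common_quantize + materialize functions as get_tensor_quant_params_fn via functools.partial, + keeping those implementations algorithm-agnostic.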
""" - ignored_shape_index = 0 - weight_index = 1 - input_index = 2 - bias_index = 3 - - ignored_inputs = [ - ignored_shape_index, - bias_index, # Bias tensor is quantized separately. - ] - if _are_weights_too_small(op_info, graph_info, weight_index): - ignored_inputs.append(weight_index) + # Get quant params. + if tensor_qsv is None: + if tensor_content is not None: + # We need min/max to calculate quantization parameters, which + # should be collected during the calibration process. However, + # weight-only and DRQ do not require calibration, thus it is + # possible that this information is missing here. In that case we + # collect min/max on the spot. + tensor_min_max = _init_tensor_min_max( + tensor_content, + op_info, + ) + else: + raise ValueError( + f"{op_info.op_name}(index: {op_info.subgraph_op_index}) not found in" + " tensor_name_to_qsv. Check if the correct calibration results are" + " passed into the ParamsGenerator." + ) + else: + tensor_min_max = tensor_qsv - op_tensor_params = utils.materialize_standard_op( - op_info, - graph_info, - tensor_name_to_qsv, - inputs_to_ignore=ignored_inputs, - ) - if len(op_tensor_params) < 2: + if "min" not in tensor_min_max or "max" not in tensor_min_max: raise ValueError( - "Materialize standard op should return at least two tensors for" - " conv2d_transpose." - ) - _materialize_bias_for_conv_ops( - op_info, - graph_info, - op_tensor_params, - op_input_index=input_index, - op_weight_index=weight_index, - op_bias_index=bias_index, - ) - - return op_tensor_params - - -def materialize_tanh( - op_info: qtyping.OpInfo, - graph_info: qtyping.GraphInfo, - tensor_name_to_qsv: dict[str, Any], -) -> list[qtyping.TensorTransformationParams]: - """Materialize tensors in tfl.tanh.""" - # Hard code scales and zero point values as they are hard coded in: - # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/ir/tfl_ops.td#L3430 - output_activation_constraints = {} - for num_bits in [8, 16]: - output_activation_constraints[num_bits] = qtyping.UniformQuantParams( - num_bits=num_bits, - quantized_dimension=None, - scale=np.array(1.0 / (1 << (num_bits - 1))), - zero_point=np.array(0), - # Activation is always asymmetric for 8 bit and symmetric for 16 bits. - symmetric=num_bits == 16, + "min and max must be provided to produce tensor quantization" + " parameters. Check if the correct calibration results are passed into" + " the ParamsGenerator." ) - return utils.materialize_op_with_output_activation_constraint( - op_info, graph_info, tensor_name_to_qsv, output_activation_constraints + zp, scale = _tensor_zp_scale_from_min_max( + tensor_min_max["min"], + tensor_min_max["max"], + tensor_quant_config.num_bits, + tensor_quant_config.symmetric, ) - - -def materialize_transpose( - op_info: qtyping.OpInfo, - graph_info: qtyping.GraphInfo, - tensor_name_to_qsv: dict[str, Any], -) -> list[qtyping.TensorTransformationParams]: - """Materialize tensors in tfl.transpose.""" - return utils.materialize_standard_op( - op_info, - graph_info, - tensor_name_to_qsv, - constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE, - inputs_to_ignore=[1], # Permutation tensor does not need to be quantized. 
- ) - - -def materialize_gelu( - op_info: qtyping.OpInfo, - graph_info: qtyping.GraphInfo, - tensor_name_to_qsv: dict[str, Any], -) -> list[qtyping.TensorTransformationParams]: - """Materialize tensors in tfl.gelu.""" - return utils.materialize_standard_op( - op_info, - graph_info, - tensor_name_to_qsv, - ) - - -def materialize_strided_slice( - op_info: qtyping.OpInfo, - graph_info: qtyping.GraphInfo, - tensor_name_to_qsv: dict[str, Any], -) -> list[qtyping.TensorTransformationParams]: - """Materialize tensors in tfl.strided_slice.""" - return utils.materialize_standard_op( - op_info, - graph_info, - tensor_name_to_qsv, - constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE, - inputs_to_ignore=[1, 2, 3], # Ignore the begin, end, and strides tensors. - ) - - -def materialize_mean( - op_info: qtyping.OpInfo, - graph_info: qtyping.GraphInfo, - tensor_name_to_qsv: dict[str, Any], -) -> list[qtyping.TensorTransformationParams]: - """Materialize tensors in tfl.mean.""" - return utils.materialize_standard_op( - op_info, - graph_info, - tensor_name_to_qsv, - inputs_to_ignore=[1], # Axis tensor does not need to be quantized. - ) - - -def materialize_rsqrt( - op_info: qtyping.OpInfo, - graph_info: qtyping.GraphInfo, - tensor_name_to_qsv: dict[str, Any], -) -> list[qtyping.TensorTransformationParams]: - """Materialize tensors in tfl.rsqrt.""" - return utils.materialize_standard_op( - op_info, - graph_info, - tensor_name_to_qsv, - ) - - -def materialize_concatenation( - op_info: qtyping.OpInfo, - graph_info: qtyping.GraphInfo, - tensor_name_to_qsv: dict[str, Any], -) -> list[qtyping.TensorTransformationParams]: - """Materialize tensors in tfl.concatenation.""" - return utils.materialize_standard_op( - op_info, - graph_info, - tensor_name_to_qsv, - constraint=_OpQuantConstraint.SAME_AS_OUTPUT_SCALE, + quantized_dim = None + if tensor_quant_config.granularity == qtyping.QuantGranularity.CHANNELWISE: + if op_info.op_name == _TFLOpName.BATCH_MATMUL: + quantized_dim = common_utils.get_bmm_weight_quantized_dim( + tensor_content, adj_y=op_info.op.builtinOptions.adjY + ) + else: + quantized_dim = tfl_flatbuffer_utils.TFL_OP_TO_WEIGHT_QUANTIZED_DIM[ + op_info.op_name + ] + quant_params = qtyping.UniformQuantParams( + scale=scale, + zero_point=zp, + num_bits=tensor_quant_config.num_bits, + symmetric=tensor_quant_config.symmetric, + quantized_dimension=quantized_dim, ) - - -def materialize_split( - op_info: qtyping.OpInfo, - graph_info: qtyping.GraphInfo, - tensor_name_to_qsv: dict[str, Any], -) -> list[qtyping.TensorTransformationParams]: - """Materialize tensors in tfl.split.""" - return utils.materialize_standard_op( - op_info, - graph_info, - tensor_name_to_qsv, - constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE, - inputs_to_ignore=[0], # Split dimension does not need to be quantized. + if tensor_content is None: + return quant_params + if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE: + quantized_vars = ( + uniform_quantize_tensor.uniform_quantize_for_emulated_subchannel( + tensor_content, quant_params, tensor_quant_config.block_size + ) + ) + else: + quantized_vars = uniform_quantize_tensor.uniform_quantize( + tensor_content, quant_params + ) + # Update with quantized values.
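+ # Attaching quantized_data to the returned params lets later + # transformation passes reuse the pre-quantized constant instead of + # re-quantizing it.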
+ return qtyping.UniformQuantParams( + scale=scale, + zero_point=zp, + num_bits=tensor_quant_config.num_bits, + symmetric=tensor_quant_config.symmetric, + quantized_dimension=quantized_dim, + quantized_data=quantized_vars, ) @@ -601,18 +255,22 @@ def init_qsvs( if tensor_idx != -1 and i not in inputs_to_ignore: tensor = graph_info.subgraph_tensors[tensor_idx] tensor_name = tfl_flatbuffer_utils.get_tensor_name(tensor) - op_qsvs[tensor_name] = utils.init_tensor_min_max( - tensor, - graph_info, + tensor_data = tfl_flatbuffer_utils.get_tensor_data( + tensor, graph_info.buffers + ) + op_qsvs[tensor_name] = _init_tensor_min_max( + tensor_data, op_info, ) for i, tensor_idx in enumerate(op_info.op.outputs): if tensor_idx != -1 and i not in outputs_to_ignore: tensor = graph_info.subgraph_tensors[tensor_idx] tensor_name = tfl_flatbuffer_utils.get_tensor_name(tensor) - op_qsvs[tensor_name] = utils.init_tensor_min_max( - tensor, - graph_info, + tensor_data = tfl_flatbuffer_utils.get_tensor_data( + tensor, graph_info.buffers + ) + op_qsvs[tensor_name] = _init_tensor_min_max( + tensor_data, op_info, ) return op_qsvs diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/add_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/add_test.py index f9b0251..f6b4974 100644 --- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/add_test.py +++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/add_test.py @@ -20,7 +20,7 @@ from tensorflow.python.platform import googletest from ai_edge_quantizer import qtyping -from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize +from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize from ai_edge_quantizer.algorithms.uniform_quantize.naive_min_max_quantize_op_tests import test_utils as naive_min_max_test_utils from ai_edge_quantizer.utils import test_utils from ai_edge_quantizer.utils import tfl_flatbuffer_utils @@ -94,7 +94,7 @@ def test_materialize_srq_add_succeeds( op_info, self._graph_info, self._op_test_info, - naive_min_max_quantize.materialize_add, + common_quantize.materialize_add, ) @parameterized.named_parameters( @@ -139,7 +139,7 @@ def test_materialize_srq_add1_constant_input_succeeds( op_info, self._graph_info, self._op_test_info, - naive_min_max_quantize.materialize_add, + common_quantize.materialize_add, ) if __name__ == "__main__": diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/average_pool_2d_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/average_pool_2d_test.py index 9888303..a5eaee5 100644 --- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/average_pool_2d_test.py +++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/average_pool_2d_test.py @@ -20,7 +20,7 @@ from tensorflow.python.platform import googletest from ai_edge_quantizer import qtyping -from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize +from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize from ai_edge_quantizer.algorithms.uniform_quantize.naive_min_max_quantize_op_tests import test_utils as naive_min_max_test_utils from ai_edge_quantizer.utils import test_utils from ai_edge_quantizer.utils import tfl_flatbuffer_utils @@ -101,7 +101,7 @@ def test_materialize_average_pool_2d_succeeds(self, activation_tensor_config): op_info, 
self._graph_info, self._op_test_info, - naive_min_max_quantize.materialize_average_pool_2d, + common_quantize.materialize_average_pool_2d, same_input_output_params=True, ) diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/batch_matmul_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/batch_matmul_test.py index 2e8d6d6..c421311 100644 --- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/batch_matmul_test.py +++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/batch_matmul_test.py @@ -20,7 +20,7 @@ from tensorflow.python.platform import googletest from ai_edge_quantizer import qtyping -from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize +from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize from ai_edge_quantizer.algorithms.uniform_quantize.naive_min_max_quantize_op_tests import test_utils as naive_min_max_test_utils from ai_edge_quantizer.utils import test_utils from ai_edge_quantizer.utils import tfl_flatbuffer_utils @@ -96,7 +96,7 @@ def test_batch_matmul_adjy_false_srq_succeeds( op_info, self._graph_info, self._op_test_info, - naive_min_max_quantize.materialize_batch_matmul, + common_quantize.materialize_batch_matmul, ) @parameterized.named_parameters( @@ -137,7 +137,7 @@ def test_batch_matmul_adjy_true_srq_succeeds( op_info, self._graph_info, self._op_test_info, - naive_min_max_quantize.materialize_batch_matmul, + common_quantize.materialize_batch_matmul, ) @@ -217,7 +217,7 @@ def test_batch_matmul_adjy_false_succeeds( op_info, self._graph_info, self._op_test_info, - naive_min_max_quantize.materialize_batch_matmul, + common_quantize.materialize_batch_matmul, ) @parameterized.product( @@ -273,7 +273,7 @@ def test_batch_matmul_adjy_true_succeeds( op_info, self._graph_info, self._op_test_info, - naive_min_max_quantize.materialize_batch_matmul, + common_quantize.materialize_batch_matmul, ) diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/concatenation_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/concatenation_test.py index e645cd9..ce86cec 100644 --- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/concatenation_test.py +++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/concatenation_test.py @@ -20,7 +20,7 @@ from tensorflow.python.platform import googletest from ai_edge_quantizer import qtyping -from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize +from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize from ai_edge_quantizer.algorithms.uniform_quantize.naive_min_max_quantize_op_tests import test_utils as naive_min_max_test_utils from ai_edge_quantizer.utils import test_utils from ai_edge_quantizer.utils import tfl_flatbuffer_utils @@ -99,7 +99,7 @@ def test_materialize_concatenation_succeeds(self, activation_tensor_config): op_info, self._graph_info, self._op_test_info, - naive_min_max_quantize.materialize_concatenation, + common_quantize.materialize_concatenation, same_input_output_params=True, ) diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/conv2d_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/conv2d_test.py index ef4928f..1abd54e 100644 --- 
a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/conv2d_test.py +++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/conv2d_test.py @@ -20,7 +20,7 @@ from tensorflow.python.platform import googletest from ai_edge_quantizer import qtyping -from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize +from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize from ai_edge_quantizer.algorithms.uniform_quantize.naive_min_max_quantize_op_tests import test_utils as naive_min_max_test_utils from ai_edge_quantizer.utils import test_utils from ai_edge_quantizer.utils import tfl_flatbuffer_utils @@ -117,7 +117,7 @@ def test_materialize_weight_only_drq_conv2d_succeeds( op_info, self._graph_info, self._op_test_info, - naive_min_max_quantize.materialize_fc_conv, + common_quantize.materialize_fc_conv, ) @parameterized.product( @@ -164,7 +164,7 @@ def test_materialize_srq_conv2d_succeeds( op_info, self._graph_info, self._op_test_info, - naive_min_max_quantize.materialize_fc_conv, + common_quantize.materialize_fc_conv, ) @parameterized.named_parameters( @@ -193,7 +193,7 @@ def test_materialize_conv2d_quantizes_weights_larger_than_min_weight_elements_fo min_weight_elements=min_weight_elements, graph_info=self._graph_info, op_test_info=self._op_test_info, - materialization_func=naive_min_max_quantize.materialize_fc_conv, + materialization_func=common_quantize.materialize_fc_conv, expect_weights_quantized=expect_weights_quantized, ) diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/conv2d_transpose_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/conv2d_transpose_test.py index 885c8ff..ce46551 100644 --- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/conv2d_transpose_test.py +++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/conv2d_transpose_test.py @@ -20,7 +20,7 @@ from tensorflow.python.platform import googletest from ai_edge_quantizer import qtyping -from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize +from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize from ai_edge_quantizer.algorithms.uniform_quantize.naive_min_max_quantize_op_tests import test_utils as naive_min_max_test_utils from ai_edge_quantizer.utils import test_utils from ai_edge_quantizer.utils import tfl_flatbuffer_utils @@ -116,7 +116,7 @@ def test_materialize_weight_only_drq_conv2d_transpose_succeeds( op_info, self._graph_info, self._op_test_info, - naive_min_max_quantize.materialize_conv2d_transpose, + common_quantize.materialize_conv2d_transpose, bias_quantized_dim=None, input_index=2, weight_index=1, @@ -164,7 +164,7 @@ def test_materialize_srq_conv2d_transpose_succeeds( op_info, self._graph_info, self._op_test_info, - naive_min_max_quantize.materialize_conv2d_transpose, + common_quantize.materialize_conv2d_transpose, bias_quantized_dim=None, input_index=2, weight_index=1, @@ -198,7 +198,7 @@ def test_materialize_conv2d_transpose_quantizes_weights_larger_than_min_weight_e min_weight_elements=min_weight_elements, graph_info=self._graph_info, op_test_info=self._op_test_info, - materialization_func=naive_min_max_quantize.materialize_conv2d_transpose, + materialization_func=common_quantize.materialize_conv2d_transpose, expect_weights_quantized=expect_weights_quantized, ) diff --git 
a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/depthwise_conv2d_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/depthwise_conv2d_test.py index e94898e..365a453 100644 --- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/depthwise_conv2d_test.py +++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/depthwise_conv2d_test.py @@ -20,7 +20,7 @@ from tensorflow.python.platform import googletest from ai_edge_quantizer import qtyping -from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize +from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize from ai_edge_quantizer.algorithms.uniform_quantize.naive_min_max_quantize_op_tests import test_utils as naive_min_max_test_utils from ai_edge_quantizer.utils import test_utils from ai_edge_quantizer.utils import tfl_flatbuffer_utils @@ -113,7 +113,7 @@ def test_materialize_weight_only_drq_depthwise_conv2d_succeeds( op_info, self._graph_info, self._op_test_info, - naive_min_max_quantize.materialize_fc_conv, + common_quantize.materialize_fc_conv, ) @parameterized.parameters(8, 16) @@ -156,7 +156,7 @@ def test_materialize_srq_depthwise_conv2d_succeeds( op_info, self._graph_info, self._op_test_info, - naive_min_max_quantize.materialize_fc_conv, + common_quantize.materialize_fc_conv, ) @parameterized.named_parameters( @@ -185,7 +185,7 @@ def _test_materialize_depthwise_conv2d_quantizes_weights_larger_than_min_weight_ min_weight_elements=min_weight_elements, graph_info=self._graph_info, op_test_info=self._op_test_info, - materialization_func=naive_min_max_quantize.materialize_fc_conv, + materialization_func=common_quantize.materialize_fc_conv, expect_weights_quantized=expect_weights_quantized, ) diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/embedding_lookup_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/embedding_lookup_test.py index d02dd2d..7050efa 100644 --- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/embedding_lookup_test.py +++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/embedding_lookup_test.py @@ -20,7 +20,7 @@ from tensorflow.python.platform import googletest from ai_edge_quantizer import qtyping -from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize +from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize from ai_edge_quantizer.algorithms.uniform_quantize.naive_min_max_quantize_op_tests import test_utils as naive_min_max_test_utils from ai_edge_quantizer.utils import test_utils from ai_edge_quantizer.utils import tfl_flatbuffer_utils @@ -111,7 +111,7 @@ def test_embedding_lookup_succeeds( op_info, self._graph_info, self._op_test_info, - naive_min_max_quantize.materialize_embedding_lookup, + common_quantize.materialize_embedding_lookup, ) diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/fully_connected_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/fully_connected_test.py index 921eb5b..4e65c6d 100644 --- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/fully_connected_test.py +++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/fully_connected_test.py @@ -20,7 +20,7 @@ from tensorflow.python.platform import googletest from 
ai_edge_quantizer import qtyping -from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize +from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize from ai_edge_quantizer.algorithms.uniform_quantize.naive_min_max_quantize_op_tests import test_utils as naive_min_max_test_utils from ai_edge_quantizer.utils import test_utils from ai_edge_quantizer.utils import tfl_flatbuffer_utils @@ -123,7 +123,7 @@ def test_materialize_fully_connected_succeeds( op_info, self._graph_info, self._op_test_info, - naive_min_max_quantize.materialize_fc_conv, + common_quantize.materialize_fc_conv, ) @parameterized.named_parameters( @@ -152,7 +152,7 @@ def test_materialize_fully_connected_quantizes_weights_larger_than_min_weight_el min_weight_elements=min_weight_elements, graph_info=self._graph_info, op_test_info=self._op_test_info, - materialization_func=naive_min_max_quantize.materialize_fc_conv, + materialization_func=common_quantize.materialize_fc_conv, expect_weights_quantized=expect_weights_quantized, ) diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/gelu_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/gelu_test.py index d90d441..9b52bfa 100644 --- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/gelu_test.py +++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/gelu_test.py @@ -20,7 +20,7 @@ from tensorflow.python.platform import googletest from ai_edge_quantizer import qtyping -from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize +from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize from ai_edge_quantizer.algorithms.uniform_quantize.naive_min_max_quantize_op_tests import test_utils as naive_min_max_test_utils from ai_edge_quantizer.utils import test_utils from ai_edge_quantizer.utils import tfl_flatbuffer_utils @@ -98,7 +98,7 @@ def test_materialize_gelu_succeeds(self, activation_tensor_config): op_info, self._graph_info, self._op_test_info, - naive_min_max_quantize.materialize_gelu, + common_quantize.materialize_gelu, ) diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/input_output_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/input_output_test.py index d773ae1..59ae221 100644 --- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/input_output_test.py +++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/input_output_test.py @@ -20,6 +20,7 @@ from tensorflow.python.platform import googletest from ai_edge_quantizer import qtyping +from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize from ai_edge_quantizer.algorithms.uniform_quantize.naive_min_max_quantize_op_tests import test_utils as naive_min_max_test_utils from ai_edge_quantizer.utils import test_utils @@ -97,8 +98,11 @@ def test_materialize_input_float( subgraph_op_index=-1, # Virtual op, no real id. op_quant_config=op_quant_config, ) - quantization_params = naive_min_max_quantize.materialize_input( - op_info, self._graph_info, self._model_qsv + quantization_params = common_quantize.materialize_input( + naive_min_max_quantize.get_tensor_quant_params, + op_info, + self._graph_info, + self._model_qsv, ) # Only one input tensor for the test model. 
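        # The shared entry point now takes the algorithm's quant-params
        # callable as its first argument; op_info, graph_info, and the qsv
        # map keep their previous positions.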
self.assertLen(quantization_params, 1) @@ -141,8 +145,11 @@ def test_materialize_output_float( subgraph_op_index=-1, # Virtual op, no real id. op_quant_config=op_quant_config, ) - quantization_params = naive_min_max_quantize.materialize_output( - op_info, self._graph_info, self._model_qsv + quantization_params = common_quantize.materialize_output( + naive_min_max_quantize.get_tensor_quant_params, + op_info, + self._graph_info, + self._model_qsv, ) # Only one output tensor for the test model. self.assertLen(quantization_params, 1) @@ -189,8 +196,11 @@ def test_materialize_input_integer( subgraph_op_index=-1, # Virtual op, no real id. op_quant_config=op_quant_config, ) - quantization_params = naive_min_max_quantize.materialize_input( - op_info, self._graph_info, self._model_qsv + quantization_params = common_quantize.materialize_input( + naive_min_max_quantize.get_tensor_quant_params, + op_info, + self._graph_info, + self._model_qsv, ) # Only one input tensor for the test model. self.assertLen(quantization_params, 1) @@ -240,8 +250,11 @@ def test_materialize_output_integer( subgraph_op_index=-1, # Virtual op, no real id. op_quant_config=op_quant_config, ) - quantization_params = naive_min_max_quantize.materialize_output( - op_info, self._graph_info, self._model_qsv + quantization_params = common_quantize.materialize_output( + naive_min_max_quantize.get_tensor_quant_params, + op_info, + self._graph_info, + self._model_qsv, ) # Only one output tensor for the test model. self.assertLen(quantization_params, 1) diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/logistic_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/logistic_test.py index afa68e8..e63c1e8 100644 --- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/logistic_test.py +++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/logistic_test.py @@ -20,7 +20,7 @@ from tensorflow.python.platform import googletest from ai_edge_quantizer import qtyping -from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize +from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize from ai_edge_quantizer.algorithms.uniform_quantize.naive_min_max_quantize_op_tests import test_utils as naive_min_max_test_utils from ai_edge_quantizer.utils import test_utils from ai_edge_quantizer.utils import tfl_flatbuffer_utils @@ -98,7 +98,7 @@ def test_materialize_logistics_succeeds(self, activation_tensor_config): op_info, self._graph_info, self._op_test_info, - naive_min_max_quantize.materialize_softmax_and_logistic, + common_quantize.materialize_softmax_and_logistic, ) diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/mean_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/mean_test.py index 9816f2e..e5e06ec 100644 --- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/mean_test.py +++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/mean_test.py @@ -20,7 +20,7 @@ from tensorflow.python.platform import googletest from ai_edge_quantizer import qtyping -from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize +from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize from ai_edge_quantizer.algorithms.uniform_quantize.naive_min_max_quantize_op_tests import test_utils as naive_min_max_test_utils from 
ai_edge_quantizer.utils import test_utils from ai_edge_quantizer.utils import tfl_flatbuffer_utils @@ -99,7 +99,7 @@ def test_materialize_mean_succeeds( op_info, self._graph_info, self._op_test_info, - naive_min_max_quantize.materialize_mean, + common_quantize.materialize_mean, inputs_to_ignore=[1], # Ignore axis tensor. ) diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/mul_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/mul_test.py index 2921cdd..f1c5288 100644 --- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/mul_test.py +++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/mul_test.py @@ -20,7 +20,7 @@ from tensorflow.python.platform import googletest from ai_edge_quantizer import qtyping -from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize +from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize from ai_edge_quantizer.algorithms.uniform_quantize.naive_min_max_quantize_op_tests import test_utils as naive_min_max_test_utils from ai_edge_quantizer.utils import test_utils from ai_edge_quantizer.utils import tfl_flatbuffer_utils @@ -94,7 +94,7 @@ def test_materialize_srq_mul_succeeds( op_info, self._graph_info, self._op_test_info, - naive_min_max_quantize.materialize_mul, + common_quantize.materialize_mul, ) @parameterized.named_parameters( @@ -139,7 +139,7 @@ def test_materialize_srq_mul2_constant_input_succeeds( op_info, self._graph_info, self._op_test_info, - naive_min_max_quantize.materialize_mul, + common_quantize.materialize_mul, ) diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/reshape_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/reshape_test.py index bdf0901..f8cb72e 100644 --- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/reshape_test.py +++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/reshape_test.py @@ -20,7 +20,7 @@ from tensorflow.python.platform import googletest from ai_edge_quantizer import qtyping -from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize +from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize from ai_edge_quantizer.algorithms.uniform_quantize.naive_min_max_quantize_op_tests import test_utils as naive_min_max_test_utils from ai_edge_quantizer.utils import test_utils from ai_edge_quantizer.utils import tfl_flatbuffer_utils @@ -99,7 +99,7 @@ def test_materialize_reshape_succeeds(self, activation_tensor_config): op_info, self._graph_info, self._op_test_info, - naive_min_max_quantize.materialize_reshape, + common_quantize.materialize_reshape, same_input_output_params=True, inputs_to_ignore=[1], # Shape tensor does not need to be quantized. 
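            # Ignored operands are handled by the "materialize ignored
            # tensors" path in materialize_standard_op; the static shape
            # tensor itself carries no values to calibrate or quantize.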
) diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/rsqrt_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/rsqrt_test.py index 419a771..2fb76d0 100644 --- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/rsqrt_test.py +++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/rsqrt_test.py @@ -20,7 +20,7 @@ from tensorflow.python.platform import googletest from ai_edge_quantizer import qtyping -from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize +from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize from ai_edge_quantizer.algorithms.uniform_quantize.naive_min_max_quantize_op_tests import test_utils as naive_min_max_test_utils from ai_edge_quantizer.utils import test_utils from ai_edge_quantizer.utils import tfl_flatbuffer_utils @@ -91,7 +91,7 @@ def test_materialize_softmax_succeeds( op_info, self._graph_info, self._op_test_info, - naive_min_max_quantize.materialize_rsqrt, + common_quantize.materialize_rsqrt, ) diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/select_v2_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/select_v2_test.py index f6c020d..4c2ab94 100644 --- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/select_v2_test.py +++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/select_v2_test.py @@ -20,7 +20,7 @@ from tensorflow.python.platform import googletest from ai_edge_quantizer import qtyping -from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize +from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize from ai_edge_quantizer.algorithms.uniform_quantize.naive_min_max_quantize_op_tests import test_utils as naive_min_max_test_utils from ai_edge_quantizer.utils import test_utils from ai_edge_quantizer.utils import tfl_flatbuffer_utils @@ -97,7 +97,7 @@ def test_materialize_select_v2_succeeds(self, num_bits): op_info, self._graph_info, self._op_test_info, - naive_min_max_quantize.materialize_select_v2, + common_quantize.materialize_select_v2, same_input_output_params=True, inputs_to_ignore=[0], ) diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/slice_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/slice_test.py index 33374de..8a39b78 100644 --- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/slice_test.py +++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/slice_test.py @@ -20,7 +20,7 @@ from tensorflow.python.platform import googletest from ai_edge_quantizer import qtyping -from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize +from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize from ai_edge_quantizer.algorithms.uniform_quantize.naive_min_max_quantize_op_tests import test_utils as naive_min_max_test_utils from ai_edge_quantizer.utils import test_utils from ai_edge_quantizer.utils import tfl_flatbuffer_utils @@ -97,7 +97,7 @@ def test_materialize_slice_succeeds(self, num_bits): op_info, self._graph_info, self._op_test_info, - naive_min_max_quantize.materialize_slice, + common_quantize.materialize_slice, same_input_output_params=True, inputs_to_ignore=[1, 2], ) diff --git 
a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/softmax_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/softmax_test.py index c55c042..92614d1 100644 --- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/softmax_test.py +++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/softmax_test.py @@ -20,7 +20,7 @@ from tensorflow.python.platform import googletest from ai_edge_quantizer import qtyping -from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize +from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize from ai_edge_quantizer.algorithms.uniform_quantize.naive_min_max_quantize_op_tests import test_utils as naive_min_max_test_utils from ai_edge_quantizer.utils import test_utils from ai_edge_quantizer.utils import tfl_flatbuffer_utils @@ -98,7 +98,7 @@ def test_materialize_softmax_succeeds(self, activation_tensor_config): op_info, self._graph_info, self._op_test_info, - naive_min_max_quantize.materialize_softmax_and_logistic, + common_quantize.materialize_softmax_and_logistic, ) diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/split_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/split_test.py index 1449e48..e1fa8a7 100644 --- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/split_test.py +++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/split_test.py @@ -20,7 +20,7 @@ from tensorflow.python.platform import googletest from ai_edge_quantizer import qtyping -from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize +from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize from ai_edge_quantizer.algorithms.uniform_quantize.naive_min_max_quantize_op_tests import test_utils as naive_min_max_test_utils from ai_edge_quantizer.utils import test_utils from ai_edge_quantizer.utils import tfl_flatbuffer_utils @@ -100,7 +100,7 @@ def test_materialize_split_succeeds(self, activation_tensor_config): op_info, self._graph_info, self._op_test_info, - naive_min_max_quantize.materialize_split, + common_quantize.materialize_split, same_input_output_params=True, inputs_to_ignore=[0], # Ignore split dimension tensor. 
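            # SPLIT fans one input out to several outputs, so every output is
            # expected to reuse the input scale/zero-point
            # (same_input_output_params=True above).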
) diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/strided_slice_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/strided_slice_test.py index 7492f1d..0eafeb9 100644 --- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/strided_slice_test.py +++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/strided_slice_test.py @@ -20,7 +20,7 @@ from tensorflow.python.platform import googletest from ai_edge_quantizer import qtyping -from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize +from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize from ai_edge_quantizer.algorithms.uniform_quantize.naive_min_max_quantize_op_tests import test_utils as naive_min_max_test_utils from ai_edge_quantizer.utils import test_utils from ai_edge_quantizer.utils import tfl_flatbuffer_utils @@ -97,7 +97,7 @@ def test_materialize_strided_slice_srq_succeeds( op_info, self._graph_info, self._op_test_info, - naive_min_max_quantize.materialize_strided_slice, + common_quantize.materialize_strided_slice, same_input_output_params=True, inputs_to_ignore=[1, 2, 3], ) diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/sub_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/sub_test.py index 00f40dd..b251408 100644 --- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/sub_test.py +++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/sub_test.py @@ -20,7 +20,7 @@ from tensorflow.python.platform import googletest from ai_edge_quantizer import qtyping -from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize +from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize from ai_edge_quantizer.algorithms.uniform_quantize.naive_min_max_quantize_op_tests import test_utils as naive_min_max_test_utils from ai_edge_quantizer.utils import test_utils from ai_edge_quantizer.utils import tfl_flatbuffer_utils @@ -94,7 +94,7 @@ def test_materialize_srq_sub_succeeds( op_info, self._graph_info, self._op_test_info, - naive_min_max_quantize.materialize_sub, + common_quantize.materialize_sub, ) @parameterized.named_parameters( @@ -139,7 +139,7 @@ def test_materialize_srq_sub1_constant_input_succeeds( op_info, self._graph_info, self._op_test_info, - naive_min_max_quantize.materialize_sub, + common_quantize.materialize_sub, ) diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/sum_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/sum_test.py index 7cff7df..3afce85 100644 --- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/sum_test.py +++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/sum_test.py @@ -20,7 +20,7 @@ from tensorflow.python.platform import googletest from ai_edge_quantizer import qtyping -from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize +from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize from ai_edge_quantizer.algorithms.uniform_quantize.naive_min_max_quantize_op_tests import test_utils as naive_min_max_test_utils from ai_edge_quantizer.utils import test_utils from ai_edge_quantizer.utils import tfl_flatbuffer_utils @@ -95,7 +95,7 @@ def test_materialize_sum_succeeds(self, num_bits): op_info, 
self._graph_info, self._op_test_info, - naive_min_max_quantize.materialize_sum, + common_quantize.materialize_sum, same_input_output_params=True, inputs_to_ignore=[1], # Ignore axis tensor. ) diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/tanh_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/tanh_test.py index 8061b00..9b69e3f 100644 --- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/tanh_test.py +++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/tanh_test.py @@ -20,7 +20,7 @@ from tensorflow.python.platform import googletest from ai_edge_quantizer import qtyping -from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize +from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize from ai_edge_quantizer.algorithms.uniform_quantize.naive_min_max_quantize_op_tests import test_utils as naive_min_max_test_utils from ai_edge_quantizer.utils import test_utils from ai_edge_quantizer.utils import tfl_flatbuffer_utils @@ -98,7 +98,7 @@ def test_materialize_softmax_succeeds(self, activation_tensor_config): op_info, self._graph_info, self._op_test_info, - naive_min_max_quantize.materialize_tanh, + common_quantize.materialize_tanh, ) diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/test_utils.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/test_utils.py index 80119ee..1c49072 100644 --- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/test_utils.py +++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/test_utils.py @@ -23,6 +23,7 @@ import numpy as np from ai_edge_quantizer import qtyping +from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize from ai_edge_quantizer.algorithms.uniform_quantize import uniform_quantize_tensor from ai_edge_quantizer.utils import tfl_flatbuffer_utils @@ -241,7 +242,10 @@ def _test_no_weights_op( num_outputs=num_outputs, ) tensor_quant_params = materialization_func( - op_info, graph_info, self._tensor_name_to_qsv + naive_min_max_quantize.get_tensor_quant_params, + op_info, + graph_info, + self._tensor_name_to_qsv, ) self.assertLen(tensor_quant_params, num_inputs + num_outputs) @@ -327,7 +331,10 @@ def _test_fc_bmm_conv( op_test_info=op_test_info, ) tensor_quant_params = materialization_func( - op_info, graph_info, self._tensor_name_to_qsv + naive_min_max_quantize.get_tensor_quant_params, + op_info, + graph_info, + self._tensor_name_to_qsv, ) _, weight_tensor, bias_tensor, _ = ( @@ -635,7 +642,10 @@ def _test_materialize_fn_quantizes_weights_larger_than_min_weight_elements_for_w ), ) _, weight_quant_params, *_ = materialization_func( - op_info, graph_info, self._tensor_name_to_qsv + naive_min_max_quantize.get_tensor_quant_params, + op_info, + graph_info, + self._tensor_name_to_qsv, ) self.assertEqual( weight_quant_params.tensor_name, op_test_info.op_tensor_names["weight"] diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/transpose_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/transpose_test.py index 5c2d07d..904e921 100644 --- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/transpose_test.py +++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/transpose_test.py @@ -20,7 +20,7 @@ from 
tensorflow.python.platform import googletest from ai_edge_quantizer import qtyping -from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize +from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize from ai_edge_quantizer.algorithms.uniform_quantize.naive_min_max_quantize_op_tests import test_utils as naive_min_max_test_utils from ai_edge_quantizer.utils import test_utils from ai_edge_quantizer.utils import tfl_flatbuffer_utils @@ -99,7 +99,7 @@ def test_materialize_transpose_succeeds(self, activation_tensor_config): op_info, self._graph_info, self._op_test_info, - naive_min_max_quantize.materialize_transpose, + common_quantize.materialize_transpose, same_input_output_params=True, inputs_to_ignore=[1], # Ignore permutation tensor. ) diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py index 20bd2ce..227a7e6 100644 --- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py @@ -19,7 +19,6 @@ import numpy as np from tensorflow.python.platform import googletest -from ai_edge_quantizer import default_policy from ai_edge_quantizer import qtyping from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize from ai_edge_quantizer.utils import test_utils @@ -158,27 +157,6 @@ def test_min_max_calibrate(self): self.assertNotIn("arith.constant1", op_qsvs) self.assertNotIn("arith.constant2", op_qsvs) - def test_check_op_quantization_config_with_negative_min_weight_elements_raises_error( - self, - ): - op_quant_config = qtyping.OpQuantizationConfig( - weight_tensor_config=_TensorQuantConfig( - num_bits=8, - granularity=qtyping.QuantGranularity.CHANNELWISE, - ), - compute_precision=qtyping.ComputePrecision.INTEGER, # DRQ. - min_weight_elements=-1, - ) - with self.assertRaisesWithPredicateMatch( - ValueError, - lambda err: "min_weight_elements must be non-negative" in str(err), - ): - naive_min_max_quantize.check_op_quantization_config( - _TFLOpName.FULLY_CONNECTED, - op_quant_config, - default_policy.DEFAULT_CONFIG_CHECK_POLICY, - ) - if __name__ == "__main__": googletest.main() diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py b/ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py index 3536dcb..babab1b 100644 --- a/ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +++ b/ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py @@ -285,56 +285,6 @@ def symmetric_quantize_bias_tensor( ) -def tensor_zp_scale_from_min_max( - min_value, max_value, num_bits: int, symmetric: bool -): - """Get zero point and scale from min and max value. - - Args: - min_value: The minimum value of the tensor (channel-wise supported). - max_value: The maximum value of the tensor (channel-wise supported). - num_bits: The number of bits of the tensor. - symmetric: Whether the tensor is symmetric. - - Returns: - The zero point and scale of the tensor. - """ - # TODO: b/332574603 - support unsigned data type. - qtype = IntType( - num_bits, - signed=True, - ) - qmin, qmax = get_quantized_range(qtype) - min_bound = 1e-4 # 1e-6 precision for int8 and 1e-8 for int16. 
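    # Worked example for the symmetric signed path below: with num_bits=8
    # (qmax=127) and min/max of [-0.5, 1.0], bound = max(|-0.5|, |1.0|) = 1.0,
    # so scale = 1.0 / 127 ≈ 0.00787 and zero_point = 0.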
- - if symmetric: - bound = np.maximum(np.abs(min_value), np.abs(max_value)) - bound = np.maximum(bound, min_bound) - if not qtype.signed: - half_q = (qmax - 1) / 2 - scale = bound / half_q - zp = np.ones_like(scale) * (half_q + 1) - else: - scale = bound / qmax - zp = np.zeros_like(scale, dtype=np.int32) - - else: - # Include 0 to the range to support zero-padding. - # See: https://arxiv.org/pdf/1712.05877.pdf - # This ensures bound_min <= 0 <= bound_max. - bound_max = np.maximum(max_value, np.zeros_like(max_value)) - bound_min = np.minimum(min_value, np.zeros_like(min_value)) - bound = np.maximum(bound_max - bound_min, min_bound) - scale = bound / (qmax - qmin) - zp = qmin - bound_min / scale - zp = np.rint(zp) - - # It's safe to cast zp to qtype without clipping because we can infer - # qmin <= zp <= qmax from bound_min <= 0 <= bound_max. - zp = assign_quantized_type(zp, qtype) - return zp, scale - - def _is_valid_quantization_params( tensor_data: np.ndarray, quantization_params: qtyping.UniformQuantParams, diff --git a/ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py b/ai_edge_quantizer/algorithms/utils/common_utils.py similarity index 82% rename from ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py rename to ai_edge_quantizer/algorithms/utils/common_utils.py index 4345f28..66375d7 100644 --- a/ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +++ b/ai_edge_quantizer/algorithms/utils/common_utils.py @@ -13,12 +13,12 @@ # limitations under the License. # ============================================================================== -"""Utils for min/max based quantization.""" +"""Common utils for uniform quantization algorithms.""" from collections.abc import Sequence import dataclasses import enum -from typing import Any, Optional +from typing import Any, Callable, Optional import numpy as np from ai_edge_quantizer import qtyping from ai_edge_quantizer.algorithms.uniform_quantize import uniform_quantize_tensor @@ -29,16 +29,8 @@ _QuantTransformation = qtyping.QuantTransformation _IntType = uniform_quantize_tensor.IntType -_SUPPORTED_WEIGHT_ONLY_OPS = frozenset([ - _TFLOpName.FULLY_CONNECTED, - _TFLOpName.CONV_2D, - _TFLOpName.BATCH_MATMUL, - _TFLOpName.EMBEDDING_LOOKUP, - _TFLOpName.DEPTHWISE_CONV_2D, - _TFLOpName.CONV_2D_TRANSPOSE, -]) -_SUPPORTED_DRQ_OPS = frozenset([ +_DRQ_OR_WEIGHT_ONLY_OPS = frozenset([ _TFLOpName.FULLY_CONNECTED, _TFLOpName.CONV_2D, _TFLOpName.BATCH_MATMUL, @@ -46,6 +38,7 @@ _TFLOpName.DEPTHWISE_CONV_2D, _TFLOpName.CONV_2D_TRANSPOSE, ]) + _SUPPORTED_SUBCHANNEL_OPS = frozenset([ _TFLOpName.FULLY_CONNECTED, ]) @@ -139,73 +132,21 @@ class OpQuantConstraint(enum.Enum): SAME_AS_OUTPUT_SCALE = 2 -def init_tensor_min_max( - tensor: Any, - graph_info: qtyping.GraphInfo, - op_info: qtyping.OpInfo, -): - """Initialize the min/max for a tensor.""" - tensor_data = tfl_flatbuffer_utils.get_tensor_data(tensor, graph_info.buffers) - # Initial values for non-constant tensors. - if tensor_data is None: - return {} - # Real min/max for constant tensors. - else: - quantized_dim = None - if ( - op_info.op_quant_config.weight_tensor_config is not None - and op_info.op_quant_config.weight_tensor_config.granularity - == qtyping.QuantGranularity.BLOCKWISE - ): - # TODO(b/346612503): emulate subchannel only supports fully connected, - # will skip special handling. Once we have a spec, we can change this. 
- block_size = op_info.op_quant_config.weight_tensor_config.block_size - # assuming tensor is 2D, which is correct for FULLY_CONNECTED - transposed_tensor_data = np.transpose(tensor_data, (1, 0)) - if transposed_tensor_data.shape[0] % block_size: - raise ValueError( - f"Block size {block_size} does not divide channel dimension" - f" {transposed_tensor_data.shape[0]}." - ) - reshaped_tensor_data = np.reshape( - transposed_tensor_data, - ( - 1, - int(transposed_tensor_data.shape[0] / block_size), - block_size, - transposed_tensor_data.shape[1], - ), - ) - return { - "min": np.min(reshaped_tensor_data, axis=(0, 1, 2), keepdims=True), - "max": np.max(reshaped_tensor_data, axis=(0, 1, 2), keepdims=True), - } - if ( - op_info.op_quant_config.weight_tensor_config is not None - and op_info.op_quant_config.weight_tensor_config.granularity - == qtyping.QuantGranularity.CHANNELWISE - ): - if op_info.op_name == _TFLOpName.BATCH_MATMUL: - quantized_dim = _get_bmm_weight_quantized_dim( - tensor_data, adj_y=op_info.op.builtinOptions.adjY - ) - else: - quantized_dim = tfl_flatbuffer_utils.TFL_OP_TO_WEIGHT_QUANTIZED_DIM.get( - op_info.op_name, None - ) - reduce_dims = _get_reduce_dims(quantized_dim, tensor.shape) - return { - "min": np.min(tensor_data, axis=reduce_dims, keepdims=True), - "max": np.max(tensor_data, axis=reduce_dims, keepdims=True), - } - - def _get_tensor_transformation_params_wrapper( tensor: Any, is_inbounding_tensor: bool, op_info: qtyping.OpInfo, graph_info: qtyping.GraphInfo, tensor_name_to_qsv: dict[str, Any], + get_tensor_quant_params_fn: Callable[ + [ + qtyping.OpInfo, # op_info + qtyping.TensorQuantizationConfig, # tensor_quant_config + Optional[np.ndarray], # tensor_data + Optional[dict[str, Any]], # tensor qsv + ], + qtyping.UniformQuantParams, + ], quant_params=None, ) -> qtyping.TensorTransformationParams: """Util to get tensor transformation params. @@ -216,6 +157,8 @@ def _get_tensor_transformation_params_wrapper( op_info: Aggregated information about the op (e.g., quantization config). graph_info: Graph information needed to perform quantization for the op. tensor_name_to_qsv: A map of tensor name to quantization parameters. + get_tensor_quant_params_fn: Function to get quantization parameters for the + tensor. quant_params: Quantization parameters for the tensor. Returns: @@ -229,37 +172,15 @@ def _get_tensor_transformation_params_wrapper( tensor_quant_config = op_info.op_quant_config.activation_tensor_config is_constant = tensor_data is not None # Use weight configuration if it is supported. - if is_constant and op_info.op_name in frozenset.union( - _SUPPORTED_WEIGHT_ONLY_OPS, _SUPPORTED_DRQ_OPS - ): + if is_constant and op_info.op_name in _DRQ_OR_WEIGHT_ONLY_OPS: tensor_quant_config = op_info.op_quant_config.weight_tensor_config # Get quant params. if quant_params is None and tensor_quant_config is not None: - if tensor_name not in tensor_name_to_qsv: - if is_constant: - # We need min/max to calculate quantization parameters, which - # should be collected during the calibration process. However, - # weight-only and DRQ do not require calibration, thus it is - # possible that this information is missing here. In that case we - # collect min/max on the spot. - tensor_min_max = init_tensor_min_max( - tensor, - graph_info, - op_info, - ) - else: - raise ValueError( - f"Tensor {tensor_name} not found in tensor_name_to_qsv. Check" - " if the correct calibration results are passed into the" - " ParamsGenerator." 
- ) - else: - tensor_min_max = tensor_name_to_qsv[tensor_name] - quant_params = _get_tensor_quant_params( + quant_params = get_tensor_quant_params_fn( op_info, - tensor_min_max, tensor_quant_config, - tensor_content=tensor_data, + tensor_data, + tensor_name_to_qsv.get(tensor_name), ) return get_tensor_transformation_params( tensor_name, @@ -277,6 +198,7 @@ def _materialize_op_tensors( op_info: qtyping.OpInfo, graph_info: qtyping.GraphInfo, tensor_name_to_qsv: dict[str, Any], + get_tensor_quant_params_fn: Callable[..., Any], quant_params=None, ) -> None: """Util to materialize op tensors. Appends the results to op_tensor_params. @@ -289,6 +211,8 @@ def _materialize_op_tensors( op_info: Aggregated information about the op (e.g., quantization config). graph_info: Graph information needed to perform quantization for the op. tensor_name_to_qsv: A map of tensor name to quantization parameters. + get_tensor_quant_params_fn: Function to get quantization parameters for the + tensor. quant_params: Quantization parameters for the tensor. """ for tensor in op_tensors: @@ -298,7 +222,8 @@ def _materialize_op_tensors( op_info, graph_info, tensor_name_to_qsv, - quant_params, + get_tensor_quant_params_fn=get_tensor_quant_params_fn, + quant_params=quant_params, ) op_tensor_params.append(tensor_params) @@ -309,6 +234,7 @@ def _get_single_tensor_params( op_info: qtyping.OpInfo, graph_info: qtyping.GraphInfo, tensor_name_to_qsv: dict[str, Any], + get_tensor_quant_params_fn: Callable[..., Any], ) -> qtyping.TensorTransformationParams: """Util to get single tensor params. @@ -318,6 +244,8 @@ def _get_single_tensor_params( op_info: Aggregated information about the op (e.g., quantization config). graph_info: Graph information needed to perform quantization for the op. tensor_name_to_qsv: A map of tensor name to quantization parameters. + get_tensor_quant_params_fn: Function to get quantization parameters for the + tensor. Returns: Transformation parameters for the tensor. @@ -336,6 +264,7 @@ def _get_single_tensor_params( op_info, graph_info, tensor_name_to_qsv, + get_tensor_quant_params_fn=get_tensor_quant_params_fn, ) @@ -345,6 +274,7 @@ def _materialize_standard_op_with_same_as_input_scale( op_info: qtyping.OpInfo, graph_info: qtyping.GraphInfo, tensor_name_to_qsv: dict[str, Any], + get_tensor_quant_params_fn: Callable[..., Any], ) -> list[qtyping.TensorTransformationParams]: """Materialize tensors in an op with same as input scale constraint. @@ -354,6 +284,8 @@ def _materialize_standard_op_with_same_as_input_scale( op_info: Aggregated information about the op (e.g., quantization config). graph_info: Graph information needed to perform quantization for the op. tensor_name_to_qsv: A map of tensor name to quantization parameters. + get_tensor_quant_params_fn: Function to get quantization parameters for the + tensor. Returns: Quantization configuration for the tensors associated with the op (e.g., @@ -367,6 +299,7 @@ def _materialize_standard_op_with_same_as_input_scale( op_info=op_info, graph_info=graph_info, tensor_name_to_qsv=tensor_name_to_qsv, + get_tensor_quant_params_fn=get_tensor_quant_params_fn, ) op_tensor_params.append(input_tensor_params) # Use input quantization params for all output tensors. 
@@ -377,6 +310,7 @@ def _materialize_standard_op_with_same_as_input_scale( op_info=op_info, graph_info=graph_info, tensor_name_to_qsv=tensor_name_to_qsv, + get_tensor_quant_params_fn=get_tensor_quant_params_fn, quant_params=input_tensor_params.consumers[0].parameters, ) # Change output qsv to be the same as input qsv. This is safe since TFL @@ -396,6 +330,7 @@ def _materialize_standard_op_with_same_as_output_scale( op_info: qtyping.OpInfo, graph_info: qtyping.GraphInfo, tensor_name_to_qsv: dict[str, Any], + get_tensor_quant_params_fn: Callable[..., Any], ) -> list[qtyping.TensorTransformationParams]: """Materialize tensors in an op with same as output scale constraint. @@ -405,6 +340,8 @@ def _materialize_standard_op_with_same_as_output_scale( op_info: Aggregated information about the op (e.g., quantization config). graph_info: Graph information needed to perform quantization for the op. tensor_name_to_qsv: A map of tensor name to quantization parameters. + get_tensor_quant_params_fn: Function to get quantization parameters for the + tensor. Returns: Quantization configuration for the tensors associated with the op (e.g., @@ -418,6 +355,7 @@ def _materialize_standard_op_with_same_as_output_scale( op_info=op_info, graph_info=graph_info, tensor_name_to_qsv=tensor_name_to_qsv, + get_tensor_quant_params_fn=get_tensor_quant_params_fn, ) # Use output quantization params for all input tensors. if output_tensor_params.producer is None: @@ -431,6 +369,7 @@ def _materialize_standard_op_with_same_as_output_scale( op_info=op_info, graph_info=graph_info, tensor_name_to_qsv=tensor_name_to_qsv, + get_tensor_quant_params_fn=get_tensor_quant_params_fn, quant_params=quant_params, ) op_tensor_params.append(output_tensor_params) @@ -444,6 +383,7 @@ def _materialize_standard_op_no_constraint( op_info: qtyping.OpInfo, graph_info: qtyping.GraphInfo, tensor_name_to_qsv: dict[str, Any], + get_tensor_quant_params_fn: Callable[..., Any], ) -> list[qtyping.TensorTransformationParams]: """Materialize tensors in an op with no constraint. @@ -453,6 +393,8 @@ def _materialize_standard_op_no_constraint( op_info: Aggregated information about the op (e.g., quantization config). graph_info: Graph information needed to perform quantization for the op. tensor_name_to_qsv: A map of tensor name to quantization parameters. + get_tensor_quant_params_fn: Function to get quantization parameters for the + tensor. Returns: Quantization configuration for the tensors associated with the op (e.g., @@ -466,6 +408,7 @@ def _materialize_standard_op_no_constraint( op_info=op_info, graph_info=graph_info, tensor_name_to_qsv=tensor_name_to_qsv, + get_tensor_quant_params_fn=get_tensor_quant_params_fn, ) _materialize_op_tensors( op_tensor_params, @@ -474,6 +417,7 @@ def _materialize_standard_op_no_constraint( op_info=op_info, graph_info=graph_info, tensor_name_to_qsv=tensor_name_to_qsv, + get_tensor_quant_params_fn=get_tensor_quant_params_fn, ) return op_tensor_params @@ -696,6 +640,7 @@ def materialize_standard_op( op_info: qtyping.OpInfo, graph_info: qtyping.GraphInfo, tensor_name_to_qsv: dict[str, Any], + get_tensor_quant_params_fn: Callable[..., Any], constraint: OpQuantConstraint = OpQuantConstraint.NO_CONSTRAIN, inputs_to_ignore: Optional[Sequence[int]] = None, outputs_to_ignore: Optional[Sequence[int]] = None, @@ -709,6 +654,8 @@ def materialize_standard_op( op_info: Aggregated information about the op (e.g., quantization config). graph_info: Graph information needed to perform quantization for the op. 
tensor_name_to_qsv: A map of tensor name to quantization parameters. + get_tensor_quant_params_fn: Function to get quantization parameters for the + tensor. constraint: The constraint for materializing the op. inputs_to_ignore: Input tensor indices to ignore. outputs_to_ignore: Output tensor indices to ignore. @@ -747,15 +694,30 @@ def materialize_standard_op( tensor_params = [] # Every tensor is ignored. elif constraint == OpQuantConstraint.SAME_AS_INPUT_SCALE: tensor_params = _materialize_standard_op_with_same_as_input_scale( - input_tensors, output_tensors, op_info, graph_info, tensor_name_to_qsv + input_tensors, + output_tensors, + op_info, + graph_info, + tensor_name_to_qsv, + get_tensor_quant_params_fn, ) elif constraint == OpQuantConstraint.SAME_AS_OUTPUT_SCALE: tensor_params = _materialize_standard_op_with_same_as_output_scale( - input_tensors, output_tensors, op_info, graph_info, tensor_name_to_qsv + input_tensors, + output_tensors, + op_info, + graph_info, + tensor_name_to_qsv, + get_tensor_quant_params_fn, ) else: tensor_params = _materialize_standard_op_no_constraint( - input_tensors, output_tensors, op_info, graph_info, tensor_name_to_qsv + input_tensors, + output_tensors, + op_info, + graph_info, + tensor_name_to_qsv, + get_tensor_quant_params_fn, ) # Materialize ignored tensors. @@ -781,6 +743,7 @@ def materialize_op_with_output_activation_constraint( graph_info: qtyping.GraphInfo, tensor_name_to_qsv: dict[str, Any], output_activation_constraints: dict[int, qtyping.UniformQuantParams], + get_tensor_quant_params_fn: Callable[..., Any], ) -> list[qtyping.TensorTransformationParams]: """Materialize tensors in an op with output activation constraint. @@ -795,6 +758,8 @@ def materialize_op_with_output_activation_constraint( tensor_name_to_qsv: A map of tensor name to quantization parameters. output_activation_constraints: A map of output activation num bits to quantization parameters. + get_tensor_quant_params_fn: Function to get quantization parameters for the + tensor. Returns: Quantization configuration for the tensors associated with the op (e.g., @@ -812,9 +777,7 @@ def materialize_op_with_output_activation_constraint( ) tensor_params = materialize_standard_op( - op_info, - graph_info, - tensor_name_to_qsv, + op_info, graph_info, tensor_name_to_qsv, get_tensor_quant_params_fn ) output_tensor_params = tensor_params[-1] @@ -951,76 +914,7 @@ def get_tensor_transformation_params( ) -def _get_tensor_quant_params( - op_info: qtyping.OpInfo, - tensor_min_max: dict[str, Any], - tensor_quant_config: qtyping.TensorQuantizationConfig, - tensor_content: Optional[np.ndarray] = None, -) -> qtyping.UniformQuantParams: - """Get the quantization parameters for a tensor. - - Args: - op_info: aggregated information about the op (e.g., quantization config). - tensor_min_max: the min/max of the tensor. - tensor_quant_config: the quantization config for the tensor. - tensor_content: the content of the tensor. - - Returns: - The quantization parameters for the tensor. - """ - if "min" not in tensor_min_max or "max" not in tensor_min_max: - raise ValueError( - "min and max must be provided to produce tensor quantization" - " parameters. Check if the correct calibration results are passed into" - " the ParamsGenerator." 
- ) - zp, scale = uniform_quantize_tensor.tensor_zp_scale_from_min_max( - tensor_min_max["min"], - tensor_min_max["max"], - tensor_quant_config.num_bits, - tensor_quant_config.symmetric, - ) - quantized_dim = None - if tensor_quant_config.granularity == qtyping.QuantGranularity.CHANNELWISE: - if op_info.op_name == _TFLOpName.BATCH_MATMUL: - quantized_dim = _get_bmm_weight_quantized_dim( - tensor_content, adj_y=op_info.op.builtinOptions.adjY - ) - else: - quantized_dim = tfl_flatbuffer_utils.TFL_OP_TO_WEIGHT_QUANTIZED_DIM[ - op_info.op_name - ] - quant_params = qtyping.UniformQuantParams( - scale=scale, - zero_point=zp, - num_bits=tensor_quant_config.num_bits, - symmetric=tensor_quant_config.symmetric, - quantized_dimension=quantized_dim, - ) - if tensor_content is None: - return quant_params - if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE: - quantized_vars = ( - uniform_quantize_tensor.uniform_quantize_for_emulated_subchannel( - tensor_content, quant_params, tensor_quant_config.block_size - ) - ) - else: - quantized_vars = uniform_quantize_tensor.uniform_quantize( - tensor_content, quant_params - ) - # Update with quantized values. - return qtyping.UniformQuantParams( - scale=scale, - zero_point=zp, - num_bits=tensor_quant_config.num_bits, - symmetric=tensor_quant_config.symmetric, - quantized_dimension=quantized_dim, - quantized_data=quantized_vars, - ) - - -def _get_reduce_dims( +def get_reduce_dims( quantized_dim: Optional[int], tensor_shape: list[int], ) -> Optional[tuple[int, ...]]: @@ -1034,7 +928,7 @@ def _get_reduce_dims( return tuple(reduce_dims) -def _get_bmm_weight_quantized_dim( +def get_bmm_weight_quantized_dim( weight_tensor_data: np.ndarray, adj_y: bool ) -> int: """Get the quantized dimension for batch matmul.""" diff --git a/ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py b/ai_edge_quantizer/algorithms/utils/common_utils_test.py similarity index 93% rename from ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py rename to ai_edge_quantizer/algorithms/utils/common_utils_test.py index a88f65f..80231ef 100644 --- a/ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +++ b/ai_edge_quantizer/algorithms/utils/common_utils_test.py @@ -18,7 +18,7 @@ from tensorflow.python.platform import googletest from ai_edge_quantizer import default_policy from ai_edge_quantizer import qtyping -from ai_edge_quantizer.algorithms.utils import min_max_quantize_utils +from ai_edge_quantizer.algorithms.utils import common_utils _ComputePrecision = qtyping.ComputePrecision _QuantTransformation = qtyping.QuantTransformation @@ -52,7 +52,7 @@ def test_get_tensor_transformations( compute_precision=compute_precision, explicit_dequantize=explicit_dequantize, ) - transformations = min_max_quantize_utils.get_tensor_transformations( + transformations = common_utils.get_tensor_transformations( op_quant_config, is_inbounding_tensor, is_constant ) # Check if WEIGHT_ONLY. @@ -120,7 +120,7 @@ def test_check_weight_only_config_raises_when_invalid_config(self, op_name): with self.assertRaisesWithPredicateMatch( ValueError, lambda err: error_message in str(err) ): - min_max_quantize_utils.check_if_valid_op_config( + common_utils.check_if_valid_op_config( op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY ) @@ -146,7 +146,7 @@ def test_check_drq_config_succeeds( ), compute_precision=_ComputePrecision.INTEGER, # DRQ. 
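        # INTEGER compute precision with no activation_tensor_config is
        # treated as dynamic-range quantization (DRQ); adding an activation
        # config makes the same compute precision static-range (SRQ).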
) - min_max_quantize_utils.check_if_valid_op_config( + common_utils.check_if_valid_op_config( op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY ) @@ -166,7 +166,7 @@ def test_check_drq_config_unsupported_op_raise_error(self, op_name): with self.assertRaisesWithPredicateMatch( ValueError, lambda err: error_message in str(err) ): - min_max_quantize_utils.check_if_valid_op_config( + common_utils.check_if_valid_op_config( op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY ) @@ -186,7 +186,7 @@ def test_check_drq_config_wrong_bits_raise_error(self, op_name): with self.assertRaisesWithPredicateMatch( ValueError, lambda err: error_message in str(err) ): - min_max_quantize_utils.check_if_valid_op_config( + common_utils.check_if_valid_op_config( op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY ) @@ -207,7 +207,7 @@ def test_check_drq_config_asymmetric_weights_raise_error(self, op_name): with self.assertRaisesWithPredicateMatch( ValueError, lambda err: error_message in str(err) ): - min_max_quantize_utils.check_if_valid_op_config( + common_utils.check_if_valid_op_config( op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY ) @@ -220,7 +220,7 @@ def test_check_drq_config_with_non_default_min_weight_elements_succeeds(self): compute_precision=_ComputePrecision.INTEGER, # DRQ. min_weight_elements=100, ) - min_max_quantize_utils.check_if_valid_op_config( + common_utils.check_if_valid_op_config( _TFLOpName.CONV_2D, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY ) @@ -255,7 +255,7 @@ def test_check_srq_config_succeeds( ), compute_precision=_ComputePrecision.INTEGER, # SRQ. ) - min_max_quantize_utils.check_if_valid_op_config( + common_utils.check_if_valid_op_config( op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY ) @@ -275,7 +275,7 @@ def test_check_srq_config_unsupported_op_raise_error(self): with self.assertRaisesWithPredicateMatch( ValueError, lambda err: error_message in str(err) ): - min_max_quantize_utils.check_if_valid_op_config( + common_utils.check_if_valid_op_config( _TFLOpName.CUSTOM_OP, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY ) @@ -297,7 +297,7 @@ def test_check_srq_config_wrong_act_bits_config_raise_error(self): with self.assertRaisesWithPredicateMatch( ValueError, lambda err: error_message in str(err) ): - min_max_quantize_utils.check_if_valid_op_config( + common_utils.check_if_valid_op_config( _TFLOpName.FULLY_CONNECTED, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY, @@ -321,7 +321,7 @@ def test_check_srq_config_asym_int16_act_raise_error(self): with self.assertRaisesWithPredicateMatch( ValueError, lambda err: error_message in str(err) ): - min_max_quantize_utils.check_if_valid_op_config( + common_utils.check_if_valid_op_config( _TFLOpName.FULLY_CONNECTED, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY, @@ -345,7 +345,7 @@ def test_check_srq_config_wrong_weight_bits_raise_error(self): with self.assertRaisesWithPredicateMatch( ValueError, lambda err: error_message in str(err) ): - min_max_quantize_utils.check_if_valid_op_config( + common_utils.check_if_valid_op_config( _TFLOpName.FULLY_CONNECTED, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY, @@ -368,7 +368,7 @@ def test_check_srq_config_asym_weight_raise_error(self): with self.assertRaisesWithPredicateMatch( ValueError, lambda err: error_message in str(err) ): - min_max_quantize_utils.check_if_valid_op_config( + common_utils.check_if_valid_op_config( _TFLOpName.FULLY_CONNECTED, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY, @@ -425,7 +425,7 @@ def test_check_supported_int4_config_succeeds( compute_precision == 
_ComputePrecision.INTEGER and op_quant_config.activation_tensor_config is None ): - min_max_quantize_utils.check_if_valid_op_config( + common_utils.check_if_valid_op_config( op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY ) # Check if WEIGHT_ONLY. @@ -439,7 +439,7 @@ def test_check_supported_int4_config_succeeds( compute_precision == _ComputePrecision.INTEGER and op_quant_config.activation_tensor_config is not None ): - min_max_quantize_utils.check_if_valid_op_config( + common_utils.check_if_valid_op_config( op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY ) @@ -477,11 +477,11 @@ def test_check_unsupported_int4_config_raise_error( with self.assertRaises(ValueError): if is_drq: - min_max_quantize_utils.check_if_valid_op_config( + common_utils.check_if_valid_op_config( op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY ) elif not is_drq: - min_max_quantize_utils.check_if_valid_op_config( + common_utils.check_if_valid_op_config( op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY ) @@ -500,11 +500,12 @@ def test_materialize_op_with_output_activation_constraint_fails_for_multiple_out with self.assertRaisesRegex( ValueError, "only supports ops with a single output tensor" ): - min_max_quantize_utils.materialize_op_with_output_activation_constraint( + common_utils.materialize_op_with_output_activation_constraint( op_info=mock_op_info, graph_info=qtyping.GraphInfo([], []), tensor_name_to_qsv={}, output_activation_constraints={}, + get_tensor_quant_params_fn=lambda *args: [], )
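Note: a minimal sketch (hypothetical names; not part of this change) of how another algorithm could reuse the shared materialize functions by supplying its own quant-params callable. The callable must match the signature documented in _get_tensor_transformation_params_wrapper, i.e. (op_info, tensor_quant_config, tensor_data, tensor_qsv) -> UniformQuantParams; registration then binds it once, the same way the updated op tests pass naive_min_max_quantize.get_tensor_quant_params as the first argument.

    import functools
    from typing import Any, Optional

    import numpy as np

    from ai_edge_quantizer import qtyping
    from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize


    def my_get_tensor_quant_params(
        op_info: qtyping.OpInfo,
        tensor_quant_config: qtyping.TensorQuantizationConfig,
        tensor_data: Optional[np.ndarray] = None,
        tensor_qsv: Optional[dict[str, Any]] = None,
    ) -> qtyping.UniformQuantParams:
      # Hypothetical symmetric per-tensor scheme driven by calibrated min/max.
      # Assumes qsv "min"/"max" are present (as init_qsvs/calibration provide).
      low = np.asarray(tensor_qsv["min"], dtype=np.float32)
      high = np.asarray(tensor_qsv["max"], dtype=np.float32)
      qmax = 2 ** (tensor_quant_config.num_bits - 1) - 1
      scale = np.maximum(np.maximum(np.abs(low), np.abs(high)), 1e-4) / qmax
      return qtyping.UniformQuantParams(
          scale=scale,
          zero_point=np.zeros_like(scale, dtype=np.int32),
          num_bits=tensor_quant_config.num_bits,
          symmetric=True,
          quantized_dimension=None,
      )


    # Bind the callable once at registration time; the materialize functions
    # then keep the (op_info, graph_info, tensor_name_to_qsv) call shape that
    # the op tests exercise.
    materialize_add = functools.partial(
        common_quantize.materialize_add, my_get_tensor_quant_params
    )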