Skip to content

Commit

Permalink
Refactor min/max quantize files
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 723259068
  • Loading branch information
paulinesho authored and copybara-github committed Feb 6, 2025
1 parent 09eec5a commit 053d748
Show file tree
Hide file tree
Showing 34 changed files with 1,110 additions and 928 deletions.
97 changes: 36 additions & 61 deletions ai_edge_quantizer/algorithm_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
"""Quantizer Algorithm Manager Interface."""

import enum
import functools
from ai_edge_quantizer import algorithm_manager_api
from ai_edge_quantizer import default_policy
from ai_edge_quantizer import qtyping
from ai_edge_quantizer.algorithms.nonlinear_quantize import float_casting
from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize
from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize

_TFLOpName = qtyping.TFLOperationName
Expand Down Expand Up @@ -54,7 +56,7 @@ class AlgorithmName(str, enum.Enum):
# Register MIN_MAX_UNIFORM_QUANT algorithm.
register_op_quant_config_validation_func(
AlgorithmName.MIN_MAX_UNIFORM_QUANT,
naive_min_max_quantize.check_op_quantization_config,
common_quantize.check_op_quantization_config,
)

# Register a config check policy for MIN_MAX_UNIFORM_QUANT algorithm.
Expand All @@ -63,71 +65,44 @@ class AlgorithmName(str, enum.Enum):
default_policy.DEFAULT_CONFIG_CHECK_POLICY,
)


for op_name, materialize_func in zip(
(
_TFLOpName.INPUT,
_TFLOpName.OUTPUT,
_TFLOpName.FULLY_CONNECTED,
_TFLOpName.BATCH_MATMUL,
_TFLOpName.CONV_2D,
_TFLOpName.DEPTHWISE_CONV_2D,
_TFLOpName.CONV_2D_TRANSPOSE,
_TFLOpName.RESHAPE,
_TFLOpName.AVERAGE_POOL_2D,
_TFLOpName.EMBEDDING_LOOKUP,
_TFLOpName.SOFTMAX,
_TFLOpName.TANH,
_TFLOpName.TRANSPOSE,
_TFLOpName.GELU,
_TFLOpName.ADD,
_TFLOpName.SUB,
_TFLOpName.MUL,
_TFLOpName.MEAN,
_TFLOpName.RSQRT,
_TFLOpName.CONCATENATION,
_TFLOpName.STRIDED_SLICE,
_TFLOpName.SPLIT,
_TFLOpName.LOGISTIC, # Sigmoid
_TFLOpName.SLICE,
_TFLOpName.SUM,
_TFLOpName.SELECT_V2,
),
(
naive_min_max_quantize.materialize_input,
naive_min_max_quantize.materialize_output,
naive_min_max_quantize.materialize_fc_conv,
naive_min_max_quantize.materialize_batch_matmul,
naive_min_max_quantize.materialize_fc_conv,
naive_min_max_quantize.materialize_fc_conv,
naive_min_max_quantize.materialize_conv2d_transpose,
naive_min_max_quantize.materialize_reshape,
naive_min_max_quantize.materialize_average_pool_2d,
naive_min_max_quantize.materialize_embedding_lookup,
naive_min_max_quantize.materialize_softmax_and_logistic,
naive_min_max_quantize.materialize_tanh,
naive_min_max_quantize.materialize_transpose,
naive_min_max_quantize.materialize_gelu,
naive_min_max_quantize.materialize_add,
naive_min_max_quantize.materialize_sub,
naive_min_max_quantize.materialize_mul,
naive_min_max_quantize.materialize_mean,
naive_min_max_quantize.materialize_rsqrt,
naive_min_max_quantize.materialize_concatenation,
naive_min_max_quantize.materialize_strided_slice,
naive_min_max_quantize.materialize_split,
naive_min_max_quantize.materialize_softmax_and_logistic,
naive_min_max_quantize.materialize_slice,
naive_min_max_quantize.materialize_sum,
naive_min_max_quantize.materialize_select_v2,
),
):
# Maps each TFL op name to the common_quantize materialize function used for
# that op. The loop that follows registers every entry under
# AlgorithmName.MIN_MAX_UNIFORM_QUANT, wrapping the function in
# functools.partial so that naive_min_max_quantize.get_tensor_quant_params is
# bound as its first positional argument.
#
# Several ops intentionally share one materialize function:
#   - FULLY_CONNECTED, CONV_2D and DEPTHWISE_CONV_2D use materialize_fc_conv.
#   - SOFTMAX and LOGISTIC use materialize_softmax_and_logistic.
MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT = {
_TFLOpName.INPUT: common_quantize.materialize_input,
_TFLOpName.OUTPUT: common_quantize.materialize_output,
_TFLOpName.FULLY_CONNECTED: common_quantize.materialize_fc_conv,
_TFLOpName.BATCH_MATMUL: common_quantize.materialize_batch_matmul,
_TFLOpName.CONV_2D: common_quantize.materialize_fc_conv,
_TFLOpName.DEPTHWISE_CONV_2D: common_quantize.materialize_fc_conv,
_TFLOpName.CONV_2D_TRANSPOSE: common_quantize.materialize_conv2d_transpose,
_TFLOpName.RESHAPE: common_quantize.materialize_reshape,
_TFLOpName.AVERAGE_POOL_2D: common_quantize.materialize_average_pool_2d,
_TFLOpName.EMBEDDING_LOOKUP: common_quantize.materialize_embedding_lookup,
_TFLOpName.SOFTMAX: common_quantize.materialize_softmax_and_logistic,
_TFLOpName.TANH: common_quantize.materialize_tanh,
_TFLOpName.TRANSPOSE: common_quantize.materialize_transpose,
_TFLOpName.GELU: common_quantize.materialize_gelu,
_TFLOpName.ADD: common_quantize.materialize_add,
_TFLOpName.SUB: common_quantize.materialize_sub,
_TFLOpName.MUL: common_quantize.materialize_mul,
_TFLOpName.MEAN: common_quantize.materialize_mean,
_TFLOpName.RSQRT: common_quantize.materialize_rsqrt,
_TFLOpName.CONCATENATION: common_quantize.materialize_concatenation,
_TFLOpName.STRIDED_SLICE: common_quantize.materialize_strided_slice,
_TFLOpName.SPLIT: common_quantize.materialize_split,
# LOGISTIC is the sigmoid op.
_TFLOpName.LOGISTIC: common_quantize.materialize_softmax_and_logistic,
_TFLOpName.SLICE: common_quantize.materialize_slice,
_TFLOpName.SUM: common_quantize.materialize_sum,
_TFLOpName.SELECT_V2: common_quantize.materialize_select_v2,
}
for op_name, materialize_func in MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT.items():
register_quantized_op(
AlgorithmName.MIN_MAX_UNIFORM_QUANT,
op_name,
naive_min_max_quantize.init_qsvs,
calibration_func=naive_min_max_quantize.min_max_calibrate,
materialize_func=materialize_func,
materialize_func=functools.partial(
materialize_func,
naive_min_max_quantize.get_tensor_quant_params,
),
)

# Register FLOAT_CASTING algorithm.
Expand Down
Loading

0 comments on commit 053d748

Please sign in to comment.