From a35e727dbcf2fdb0b341111812f4c8d150b6607e Mon Sep 17 00:00:00 2001 From: Pauline Sho Date: Tue, 15 Oct 2024 18:41:29 -0700 Subject: [PATCH] Fix incorrect im2col size allocation with INT4 filter PiperOrigin-RevId: 686312779 --- ai_edge_quantizer/quantizer.py | 24 +++++++++++++++- ai_edge_quantizer/utils/validation_utils.py | 32 +++++++++++++++++++++ 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/ai_edge_quantizer/quantizer.py b/ai_edge_quantizer/quantizer.py index e95604f..4cf8a0f 100644 --- a/ai_edge_quantizer/quantizer.py +++ b/ai_edge_quantizer/quantizer.py @@ -16,6 +16,7 @@ """AI Edge Quantizer API.""" from collections.abc import Iterable +import copy import dataclasses import json import os @@ -239,7 +240,9 @@ def calibrate( return calib.get_model_qsvs() def quantize( - self, calibration_result: Optional[_CalibrationResult] = None + self, + calibration_result: Optional[_CalibrationResult] = None, + tfl_scales=None, ) -> QuantizationResult: """Quantizes the float model. @@ -257,6 +260,25 @@ def quantize( if not self.get_quantization_recipe(): raise RuntimeError('Can not quantize without a quantization recipe.') quant_params = self._get_quantization_params(calibration_result) + + # import pprint + # pp = pprint.PrettyPrinter(indent=4) + + # override scales + if tfl_scales is not None: + for s, params in quant_params.items(): + for ts in tfl_scales: + if ts[0] == s: + for c in params.consumers: + if c.parameters is not None and isinstance( + c.parameters, qtyping.UniformQuantParams + ): + print(c.parameters.scale) + print(ts[1]) + c.parameters.scale[:] = abs(float(ts[1])) + print(c.parameters.scale) + print('.........') + quantized_model = self._get_quantized_model(quant_params) self._result = QuantizationResult( self.get_quantization_recipe(), quantized_model diff --git a/ai_edge_quantizer/utils/validation_utils.py b/ai_edge_quantizer/utils/validation_utils.py index 19cbe1f..9102778 100644 --- a/ai_edge_quantizer/utils/validation_utils.py +++ b/ai_edge_quantizer/utils/validation_utils.py @@ -60,6 +60,22 @@ def mean_squared_difference( Raises: Value error if the two inputs don't have the same number of elements """ + if np.any(np.isnan(data1)): + print(data1) + print("data1 is nan") + if np.any(np.isnan(data2)): + print("data2 is nan") + if np.any(np.isinf(data1)): + print("data1 is inf") + if np.any(np.isinf(data2)): + print("data2 is inf") + if ( + np.any(np.isnan(data1)) + or np.any(np.isnan(data2)) + or np.any(np.isinf(data1)) + or np.any(np.isinf(data2)) + ): + return float(1234567890) data1, data2 = _preprocess_same_size_arrays(data1, data2) # special handling for tensor of size 0 if data1.size == 0: @@ -89,6 +105,22 @@ def median_diff_ratio( Raises: Value error if the two inputs don't have the same number of elements """ + if np.any(np.isnan(data1)): + print(data1) + print("data1 is nan") + if np.any(np.isnan(data2)): + print("data2 is nan") + if np.any(np.isinf(data1)): + print("data1 is inf") + if np.any(np.isinf(data2)): + print("data2 is inf") + if ( + np.any(np.isnan(data1)) + or np.any(np.isnan(data2)) + or np.any(np.isinf(data1)) + or np.any(np.isinf(data2)) + ): + return float(1234567890) data1, data2 = _preprocess_same_size_arrays(data1, data2) # special handling for tensor of size 0 if data1.size == 0: