Skip to content

Commit 44c14e8

Browse files
authored
Adding test case for conv per channel with QDQ format (microsoft#13041)
**Description**: Adding test case for conv per channel with QDQ format
1 parent 2ae33b3 commit 44c14e8

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

onnxruntime/test/python/quantization/test_qdq.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,15 @@ def test_qdq_extra_options_2(self):
236236

237237

238238
class TestQDQFormatConv(TestQDQFormat):
239+
def check_per_channel_counts(self, model_path, channel_count: int, axis: int = 0):
240+
model = onnx.load(Path(model_path))
241+
for initializer in model.graph.initializer:
242+
dims = initializer.dims
243+
# skip if initializer is not a weight
244+
if len(dims) > 0:
245+
self.assertGreater(len(dims), axis)
246+
self.assertEqual(channel_count, dims[axis])
247+
239248
def construct_model_conv(self, output_model_path, input_shape, weight_shape, output_shape, has_bias):
240249
# (input)
241250
# |
@@ -281,8 +290,9 @@ def verify_quantize_conv(self, has_bias, per_channel, is_quant_type_int8=False):
281290
model_fp32_path = "conv_fp32.{}.{}.onnx".format(has_bias, per_channel)
282291
model_int8_qdq_path = "conv_quant_qdq.{}.{}.onnx".format(has_bias, per_channel)
283292
model_int8_qop_path = "conv_quant_qop.{}.{}.onnx".format(has_bias, per_channel)
293+
channel_count = 16
284294
data_reader = self.input_feeds(1, {"input": [1, 8, 33, 33]})
285-
self.construct_model_conv(model_fp32_path, [1, 8, 33, 33], [16, 8, 3, 3], [1, 16, 31, 31], has_bias)
295+
self.construct_model_conv(model_fp32_path, [1, 8, 33, 33], [channel_count, 8, 3, 3], [1, 16, 31, 31], has_bias)
286296
quantize_static(
287297
model_fp32_path,
288298
model_int8_qdq_path,
@@ -301,6 +311,8 @@ def verify_quantize_conv(self, has_bias, per_channel, is_quant_type_int8=False):
301311
"DequantizeLinear": 4 if has_bias else 3,
302312
}
303313
check_op_type_count(self, model_int8_qdq_path, **qdq_nodes)
314+
if per_channel:
315+
self.check_per_channel_counts(model_int8_qdq_path, channel_count)
304316
check_model_correctness(self, model_fp32_path, model_int8_qdq_path, data_reader.get_next())
305317

306318
data_reader.rewind()

0 commit comments

Comments
 (0)