@@ -236,6 +236,15 @@ def test_qdq_extra_options_2(self):
 
 
 class TestQDQFormatConv(TestQDQFormat):
+    def check_per_channel_counts(self, model_path, channel_count: int, axis: int = 0):
+        model = onnx.load(Path(model_path))
+        for initializer in model.graph.initializer:
+            dims = initializer.dims
+            # skip if initializer is not a weight
+            if len(dims) > 0:
+                self.assertGreater(len(dims), axis)
+                self.assertEqual(channel_count, dims[axis])
+
     def construct_model_conv(self, output_model_path, input_shape, weight_shape, output_shape, has_bias):
         # (input)
         # |
@@ -281,8 +290,9 @@ def verify_quantize_conv(self, has_bias, per_channel, is_quant_type_int8=False):
         model_fp32_path = "conv_fp32.{}.{}.onnx".format(has_bias, per_channel)
         model_int8_qdq_path = "conv_quant_qdq.{}.{}.onnx".format(has_bias, per_channel)
         model_int8_qop_path = "conv_quant_qop.{}.{}.onnx".format(has_bias, per_channel)
+        channel_count = 16
         data_reader = self.input_feeds(1, {"input": [1, 8, 33, 33]})
-        self.construct_model_conv(model_fp32_path, [1, 8, 33, 33], [16, 8, 3, 3], [1, 16, 31, 31], has_bias)
+        self.construct_model_conv(model_fp32_path, [1, 8, 33, 33], [channel_count, 8, 3, 3], [1, 16, 31, 31], has_bias)
         quantize_static(
             model_fp32_path,
             model_int8_qdq_path,
@@ -301,6 +311,8 @@ def verify_quantize_conv(self, has_bias, per_channel, is_quant_type_int8=False):
             "DequantizeLinear": 4 if has_bias else 3,
         }
         check_op_type_count(self, model_int8_qdq_path, **qdq_nodes)
+        if per_channel:
+            self.check_per_channel_counts(model_int8_qdq_path, channel_count)
         check_model_correctness(self, model_fp32_path, model_int8_qdq_path, data_reader.get_next())
 
         data_reader.rewind()
0 commit comments