@@ -214,6 +214,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
             binder.tensorResultType(resultType))
           return failure();
 
+        auto loc = binder.getLoc();
         Value operand = operands[0];
         Value scale = operands[1];
         Value zeropoint = operands[2];
@@ -225,33 +226,61 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
           return rewriter.notifyMatchFailure(binder.op,
                                              "requires known result dtype");
 
-        if (scaleTy.getSizes().size() == 0) {
-          auto qTensorTy = getQTorchTypeFromTorchIntType(resultType);
-          if (!qTensorTy) {
-            return rewriter.notifyMatchFailure(binder.op,
-                                               "unsupported result dtype");
-          }
+        auto resultETy = resultType.getDtype();
 
-          auto torchqTy = Torch::getScalarTypeForType(qTensorTy.getDtype());
+        bool rank0 = scaleTy.getSizes().size() == 0;
+        bool length1 =
+            scaleTy.getSizes().size() == 1 && scaleTy.getSizes()[0] == 1;
 
-          Value tyConst = rewriter.create<Torch::ConstantIntOp>(
-              binder.getLoc(), rewriter.getType<Torch::IntType>(),
-              rewriter.getIntegerAttr(rewriter.getIntegerType(64),
-                                      static_cast<int64_t>(torchqTy)));
-
-          scale = rewriter.create<Torch::AtenItemOp>(
-              binder.getLoc(), rewriter.getType<Torch::FloatType>(), scale);
-          zeropoint = rewriter.create<Torch::AtenItemOp>(
-              binder.getLoc(), rewriter.getType<Torch::IntType>(), zeropoint);
-
-          auto quantize = rewriter.create<Torch::AtenQuantizePerTensorOp>(
-              binder.getLoc(), qTensorTy, operand, scale, zeropoint, tyConst);
-          rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(
-              binder.op, resultType, quantize);
+        if (!rank0 && !length1)
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "unimplemented: non-scalar scale");
+
+        auto qTensorTy = getQTorchTypeFromTorchIntType(resultType);
+        if (!qTensorTy) {
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "unsupported result dtype");
+        }
+
+        auto torchqTy = Torch::getScalarTypeForType(qTensorTy.getDtype());
+
+        Value tyConst = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                                    static_cast<int64_t>(torchqTy)));
+
+        scale = rewriter.create<Torch::AtenItemOp>(
+            loc, rewriter.getType<Torch::FloatType>(), scale);
+
+        bool fpResult = isa<mlir::FloatType>(resultETy);
+        Type zeropointTy = rewriter.getType<Torch::IntType>();
+        if (fpResult)
+          zeropointTy = rewriter.getType<Torch::FloatType>();
+        zeropoint =
+            rewriter.create<Torch::AtenItemOp>(loc, zeropointTy, zeropoint);
+
+        if (fpResult) {
+          Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
+          Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
+          Value one = rewriter.create<Torch::ConstantFloatOp>(
+              loc, rewriter.getF64FloatAttr(1.0));
+          Value div = rewriter.create<Torch::AtenDivScalarOp>(
+              loc, operand.getType(), operand, scale);
+          Value add = rewriter.create<Torch::AtenAddScalarOp>(
+              loc, operand.getType(), div, zeropoint, one);
+
+          rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
+              binder.op, resultType, add, tyConst,
+              /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
+              /*memory_format=*/none);
           return success();
         }
 
-        return failure();
+        auto quantize = rewriter.create<Torch::AtenQuantizePerTensorOp>(
+            loc, qTensorTy, operand, scale, zeropoint, tyConst);
+        rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(binder.op, resultType,
+                                                          quantize);
+        return success();
       });
   patterns.onOp(
       "QLinearConv", 1,