@@ -2117,41 +2117,73 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
             binder.tensorResultType(resultType))
           return failure();
 
+        auto loc = binder.getLoc();
         Value operand = operands[0];
         Value scale = operands[1];
         Value zeropoint = operands[2];
 
         auto operandTy = cast<Torch::ValueTensorType>(operand.getType());
 
+        auto operandETy = operandTy.getDtype();
         auto scaleTy = dyn_cast<Torch::ValueTensorType>(scale.getType());
         if (!scaleTy || !scaleTy.hasSizes())
           return rewriter.notifyMatchFailure(binder.op, "requires known rank");
         if (!resultType.hasDtype())
           return rewriter.notifyMatchFailure(binder.op,
                                              "requires known result dtype");
-        if (scaleTy.getSizes().size() == 0 ||
-            (scaleTy.getSizes().size() == 1 && scaleTy.getSizes()[0] == 1)) {
-          auto qTensorTy = getQTorchTypeFromTorchIntType(operandTy);
-          if (!qTensorTy) {
-            return rewriter.notifyMatchFailure(binder.op,
-                                               "unsupported result dtype");
-          }
 
-          scale = rewriter.create<Torch::AtenItemOp>(
-              binder.getLoc(), rewriter.getType<Torch::FloatType>(), scale);
-          zeropoint = rewriter.create<Torch::AtenItemOp>(
-              binder.getLoc(), rewriter.getType<Torch::IntType>(), zeropoint);
+        bool rank0 = scaleTy.getSizes().size() == 0;
+        bool length1 =
+            scaleTy.getSizes().size() == 1 && scaleTy.getSizes()[0] == 1;
+
+        if (!rank0 && !length1)
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "unimplemented: non-scalar scale");
+        auto qTensorTy = getQTorchTypeFromTorchIntType(operandTy);
+        if (!qTensorTy) {
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "unsupported result dtype");
+        }
+
+        scale = rewriter.create<Torch::AtenItemOp>(
+            loc, rewriter.getType<Torch::FloatType>(), scale);
+
+        bool fpOperand = isa<mlir::FloatType>(operandETy);
+        Type zeropointTy = rewriter.getType<Torch::IntType>();
+        if (fpOperand)
+          zeropointTy = rewriter.getType<Torch::FloatType>();
+
+        zeropoint =
+            rewriter.create<Torch::AtenItemOp>(loc, zeropointTy, zeropoint);
 
-          auto quantize =
-              rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
-                  binder.getLoc(), qTensorTy, operand, scale, zeropoint);
-          rewriter.replaceOpWithNewOp<Torch::AtenDequantizeSelfOp>(
-              binder.op, resultType, quantize);
+        if (fpOperand) {
+          Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
+          Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
+          auto tyVal = Torch::getScalarTypeForType(resultType.getDtype());
+          Value tyConst = rewriter.create<Torch::ConstantIntOp>(
+              loc, rewriter.getType<Torch::IntType>(),
+              rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                                      static_cast<int64_t>(tyVal)));
+          Value toDtype = rewriter.create<Torch::AtenToDtypeOp>(
+              loc, resultType, operand, tyConst,
+              /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
+              /*memory_format=*/none);
+
+          Value one = rewriter.create<Torch::ConstantFloatOp>(
+              loc, rewriter.getF64FloatAttr(1.0));
+          Value sub = rewriter.create<Torch::AtenSubScalarOp>(
+              loc, resultType, toDtype, zeropoint, one);
+          rewriter.replaceOpWithNewOp<Torch::AtenMulScalarOp>(
+              binder.op, resultType, sub, scale);
           return success();
         }
 
-        return rewriter.notifyMatchFailure(binder.op,
-                                           "unimplemented: non-scalar scale");
+        auto quantize =
+            rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
+                loc, qTensorTy, operand, scale, zeropoint);
+        rewriter.replaceOpWithNewOp<Torch::AtenDequantizeSelfOp>(
+            binder.op, resultType, quantize);
+        return success();
       });
   patterns.onOp("Div", 7,
                 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
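For context on what the diff computes: ONNX DequantizeLinear is defined as y = (x - zero_point) * scale. The integer path keeps the existing route through Aten_MakePerTensorQuantizedTensorOp and AtenDequantizeSelfOp, while the new fpOperand branch casts the operand to the result dtype and expands the formula directly as AtenSubScalarOp followed by AtenMulScalarOp, reading zero_point out as a float scalar instead of an int. Below is a minimal standalone sketch of that arithmetic; the values and file name are invented for illustration, not taken from the patch or its tests.

// dequant_sketch.cpp -- illustrative only; compile with any C++ compiler.
#include <cstdint>
#include <cstdio>

int main() {
  const float scale = 0.5f; // per-tensor scale, extracted via AtenItemOp in the diff

  // Integer operand: zero_point is read out as a Torch::IntType scalar.
  const int8_t xq = 42;
  const int8_t zpInt = 2;
  const float yInt = (static_cast<float>(xq) - static_cast<float>(zpInt)) * scale;
  std::printf("int path:   %.3f\n", yInt); // (42 - 2) * 0.5 = 20.000

  // Float operand (the new fpOperand branch): zero_point is itself a float,
  // which is why zeropointTy switches to Torch::FloatType in the diff.
  const float xf = 42.0f;
  const float zpFp = 2.0f;
  const float yFp = (xf - zpFp) * scale;
  std::printf("float path: %.3f\n", yFp); // (42.0 - 2.0) * 0.5 = 20.000
  return 0;
}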