
Commit 44266ab
[onnx] Support fp8 for onnx.QuantizeLinear (llvm#3619)
We need to decompose `onnx.QuantizeLinear` directly for `fp8` result types, since the equivalent torch quantization operations do not support them.
1 parent: 8358e8c
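The fp8 path decomposes the quantization arithmetic explicitly, computing y = to_dtype(x / scale + zero_point) and casting straight to the fp8 element type instead of round-tripping through torch's quantized tensor types. A minimal sketch of the Torch IR this produces for a rank-0 scale and zero point, following the shape of the new lit test below (SSA names are illustrative):

    %dtype = torch.constant.int 24   // torch ScalarType code for f8E4M3FN
    %scale = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
    %zp = torch.aten.item %arg2 : !torch.vtensor<[],f32> -> !torch.float
    %none = torch.constant.none
    %false = torch.constant.bool false
    %one = torch.constant.float 1.000000e+00
    %div = torch.aten.div.Scalar %arg0, %scale
    %add = torch.aten.add.Scalar %div, %zp, %one
    %y = torch.aten.to.dtype %add, %dtype, %false, %false, %none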

File tree

2 files changed: +68 −22 lines

lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp (+51 −22)
@@ -214,6 +214,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
             binder.tensorResultType(resultType))
           return failure();
 
+        auto loc = binder.getLoc();
         Value operand = operands[0];
         Value scale = operands[1];
         Value zeropoint = operands[2];
@@ -225,33 +226,61 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
           return rewriter.notifyMatchFailure(binder.op,
                                              "requires known result dtype");
 
-        if (scaleTy.getSizes().size() == 0) {
-          auto qTensorTy = getQTorchTypeFromTorchIntType(resultType);
-          if (!qTensorTy) {
-            return rewriter.notifyMatchFailure(binder.op,
-                                               "unsupported result dtype");
-          }
+        auto resultETy = resultType.getDtype();
 
-          auto torchqTy = Torch::getScalarTypeForType(qTensorTy.getDtype());
+        bool rank0 = scaleTy.getSizes().size() == 0;
+        bool length1 =
+            scaleTy.getSizes().size() == 1 && scaleTy.getSizes()[0] == 1;
 
-          Value tyConst = rewriter.create<Torch::ConstantIntOp>(
-              binder.getLoc(), rewriter.getType<Torch::IntType>(),
-              rewriter.getIntegerAttr(rewriter.getIntegerType(64),
-                                      static_cast<int64_t>(torchqTy)));
-
-          scale = rewriter.create<Torch::AtenItemOp>(
-              binder.getLoc(), rewriter.getType<Torch::FloatType>(), scale);
-          zeropoint = rewriter.create<Torch::AtenItemOp>(
-              binder.getLoc(), rewriter.getType<Torch::IntType>(), zeropoint);
-
-          auto quantize = rewriter.create<Torch::AtenQuantizePerTensorOp>(
-              binder.getLoc(), qTensorTy, operand, scale, zeropoint, tyConst);
-          rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(
-              binder.op, resultType, quantize);
+        if (!rank0 && !length1)
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "unimplemented: non-scalar scale");
+
+        auto qTensorTy = getQTorchTypeFromTorchIntType(resultType);
+        if (!qTensorTy) {
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "unsupported result dtype");
+        }
+
+        auto torchqTy = Torch::getScalarTypeForType(qTensorTy.getDtype());
+
+        Value tyConst = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                                    static_cast<int64_t>(torchqTy)));
+
+        scale = rewriter.create<Torch::AtenItemOp>(
+            loc, rewriter.getType<Torch::FloatType>(), scale);
+
+        bool fpResult = isa<mlir::FloatType>(resultETy);
+        Type zeropointTy = rewriter.getType<Torch::IntType>();
+        if (fpResult)
+          zeropointTy = rewriter.getType<Torch::FloatType>();
+        zeropoint =
+            rewriter.create<Torch::AtenItemOp>(loc, zeropointTy, zeropoint);
+
+        if (fpResult) {
+          Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
+          Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
+          Value one = rewriter.create<Torch::ConstantFloatOp>(
+              loc, rewriter.getF64FloatAttr(1.0));
+          Value div = rewriter.create<Torch::AtenDivScalarOp>(
+              loc, operand.getType(), operand, scale);
+          Value add = rewriter.create<Torch::AtenAddScalarOp>(
+              loc, operand.getType(), div, zeropoint, one);
+
+          rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
+              binder.op, resultType, add, tyConst,
+              /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
+              /*memory_format=*/none);
           return success();
         }
 
-        return failure();
+        auto quantize = rewriter.create<Torch::AtenQuantizePerTensorOp>(
+            loc, qTensorTy, operand, scale, zeropoint, tyConst);
+        rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(binder.op, resultType,
+                                                          quantize);
+        return success();
       });
   patterns.onOp(
       "QLinearConv", 1,

test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir (+17 −0)
@@ -47,6 +47,23 @@ func.func @test_quantizelinear_i32(%arg0: !torch.vtensor<[6],f32>, %arg1: !torch
 
 // -----
 
+// CHECK-LABEL: @test_quantizelinear_f8
+func.func @test_quantizelinear_f8(%arg0: !torch.vtensor<[6],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],f32>) -> !torch.vtensor<[6],f8E4M3FN> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} {
+  // CHECK: %[[DTYPE:.+]] = torch.constant.int 24
+  // CHECK: %[[SCALE:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
+  // CHECK: %[[ZP:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],f32> -> !torch.float
+  // CHECK: %[[NONE:.+]] = torch.constant.none
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[ONE:.+]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[DIV:.+]] = torch.aten.div.Scalar %arg0, %[[SCALE]]
+  // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %[[DIV]], %[[ZP]], %[[ONE]]
+  // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %[[ADD]], %[[DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]]
+  %0 = torch.operator "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[6],f8E4M3FN>
+  return %0 : !torch.vtensor<[6],f8E4M3FN>
+}
+
+// -----
+
 // CHECK-LABEL: @test_qlinearconv_nobias
 func.func @test_qlinearconv_nobias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[1,1,1,1],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,7,7],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   %0 = torch.operator "onnx.QLinearConv"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!torch.vtensor<[1,1,7,7],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[1,1,1,1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,7,7],ui8>
