Commit 8358e8c

[onnx] Add support for fp8 onnx.DequantizeLinear (llvm#3617)
Fp8 needs a slightly different dequantization path, because the `torch` dequantize operation does not support `fp8` types.
Parent: 880e64b

File tree: 2 files changed (+66, -18)

lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp (+50, -18)
@@ -2117,41 +2117,73 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
             binder.tensorResultType(resultType))
           return failure();
 
+        auto loc = binder.getLoc();
         Value operand = operands[0];
         Value scale = operands[1];
         Value zeropoint = operands[2];
 
         auto operandTy = cast<Torch::ValueTensorType>(operand.getType());
 
+        auto operandETy = operandTy.getDtype();
         auto scaleTy = dyn_cast<Torch::ValueTensorType>(scale.getType());
         if (!scaleTy || !scaleTy.hasSizes())
           return rewriter.notifyMatchFailure(binder.op, "requires known rank");
         if (!resultType.hasDtype())
           return rewriter.notifyMatchFailure(binder.op,
                                              "requires known result dtype");
-        if (scaleTy.getSizes().size() == 0 ||
-            (scaleTy.getSizes().size() == 1 && scaleTy.getSizes()[0] == 1)) {
-          auto qTensorTy = getQTorchTypeFromTorchIntType(operandTy);
-          if (!qTensorTy) {
-            return rewriter.notifyMatchFailure(binder.op,
-                                               "unsupported result dtype");
-          }
 
-          scale = rewriter.create<Torch::AtenItemOp>(
-              binder.getLoc(), rewriter.getType<Torch::FloatType>(), scale);
-          zeropoint = rewriter.create<Torch::AtenItemOp>(
-              binder.getLoc(), rewriter.getType<Torch::IntType>(), zeropoint);
+        bool rank0 = scaleTy.getSizes().size() == 0;
+        bool length1 =
+            scaleTy.getSizes().size() == 1 && scaleTy.getSizes()[0] == 1;
+
+        if (!rank0 && !length1)
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "unimplemented: non-scalar scale");
+        auto qTensorTy = getQTorchTypeFromTorchIntType(operandTy);
+        if (!qTensorTy) {
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "unsupported result dtype");
+        }
+
+        scale = rewriter.create<Torch::AtenItemOp>(
+            loc, rewriter.getType<Torch::FloatType>(), scale);
+
+        bool fpOperand = isa<mlir::FloatType>(operandETy);
+        Type zeropointTy = rewriter.getType<Torch::IntType>();
+        if (fpOperand)
+          zeropointTy = rewriter.getType<Torch::FloatType>();
+
+        zeropoint =
+            rewriter.create<Torch::AtenItemOp>(loc, zeropointTy, zeropoint);
 
-          auto quantize =
-              rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
-                  binder.getLoc(), qTensorTy, operand, scale, zeropoint);
-          rewriter.replaceOpWithNewOp<Torch::AtenDequantizeSelfOp>(
-              binder.op, resultType, quantize);
+        if (fpOperand) {
+          Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
+          Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
+          auto tyVal = Torch::getScalarTypeForType(resultType.getDtype());
+          Value tyConst = rewriter.create<Torch::ConstantIntOp>(
+              loc, rewriter.getType<Torch::IntType>(),
+              rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                                      static_cast<int64_t>(tyVal)));
+          Value toDtype = rewriter.create<Torch::AtenToDtypeOp>(
+              loc, resultType, operand, tyConst,
+              /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
+              /*memory_format=*/none);
+
+          Value one = rewriter.create<Torch::ConstantFloatOp>(
+              loc, rewriter.getF64FloatAttr(1.0));
+          Value sub = rewriter.create<Torch::AtenSubScalarOp>(
+              loc, resultType, toDtype, zeropoint, one);
+          rewriter.replaceOpWithNewOp<Torch::AtenMulScalarOp>(
+              binder.op, resultType, sub, scale);
           return success();
         }
 
-        return rewriter.notifyMatchFailure(binder.op,
-                                           "unimplemented: non-scalar scale");
+        auto quantize =
+            rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
+                loc, qTensorTy, operand, scale, zeropoint);
+        rewriter.replaceOpWithNewOp<Torch::AtenDequantizeSelfOp>(
+            binder.op, resultType, quantize);
+        return success();
       });
   patterns.onOp("Div", 7,
                 [](OpBinder binder, ConversionPatternRewriter &rewriter) {

test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir (+16)
@@ -800,6 +800,22 @@ func.func @test_dequantizelinear_i32(%arg0: !torch.vtensor<[6],si32>, %arg1: !to
 
 // -----
 
+// CHECK-LABEL: @test_dequantizelinear_fp8
+func.func @test_dequantizelinear_fp8(%arg0: !torch.vtensor<[6],f8E4M3FN>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],f8E4M3FN>) -> !torch.vtensor<[6],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} {
+  // CHECK: %[[SCALE:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
+  // CHECK: %[[ZP:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],f8E4M3FN> -> !torch.float
+  // CHECK: %[[NONE:.+]] = torch.constant.none
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[DTY:.+]] = torch.constant.int 6
+  // CHECK: %[[TO:.+]] = torch.aten.to.dtype %arg0, %[[DTY]], %[[FALSE]], %[[FALSE]], %[[NONE]]
+  // CHECK: %[[ONE:.+]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[SUB:.+]] = torch.aten.sub.Scalar %[[TO]], %[[ZP]], %[[ONE]]
+  // CHECK: %[[MUL:.+]] = torch.aten.mul.Scalar %[[SUB]], %[[SCALE]]
+  %0 = torch.operator "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],f8E4M3FN>, !torch.vtensor<[],f32>, !torch.vtensor<[],f8E4M3FN>) -> !torch.vtensor<[6],f32>
+  return %0 : !torch.vtensor<[6],f32>
+}
+
+// -----
 
 // CHECK-LABEL: @test_div_bcast
 func.func @test_div_bcast(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
