Skip to content

Commit

Permalink
Merge pull request #1605 from ROCm/fix_dequantizelinear_bp
Browse files Browse the repository at this point in the history
[BACKPORT] Fix dequantizelinear definition

Backport a fix to the definition of dequantizelinear so that it can be used in MIGraphX 6.2.1
  • Loading branch information
krzysz00 authored Aug 15, 2024
2 parents 2d2ac48 + 1324731 commit 2d7c4ec
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 26 deletions.
34 changes: 14 additions & 20 deletions mlir/lib/Conversion/MIGraphXToTosa/MIGraphXToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -837,8 +837,11 @@ struct SoftmaxConverter final
};
} // namespace

// MIGraphX pseudo code:
// output[i] = static_cast<T>(input[i] - zero_pts[i]) * scales[i];
// MIGraphX implements:
// Let T = scale element type
// output[i] = (convert<T>(input[i]) - convert<T>(zero_pts[i])) * scales[i];
// For f32, this matches the ONNX reference. Dequantizing to f16, if it's ever
// done, will be less precise than the reference, but that's probably fine.
LogicalResult DeQuantizeLinearConverter::matchAndRewrite(
migraphx::DeQuantizeLinearOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Expand All @@ -847,34 +850,25 @@ LogicalResult DeQuantizeLinearConverter::matchAndRewrite(
Value output = op.getOutput();
Location loc = op->getLoc();

Value shifted = input;
Type outputType = getShapedElementTy(output);
Value upcastInput = createCastOp(rewriter, loc, outputType, input);

Value shifted = upcastInput;
if (auto bias = adaptor.getBias()) {
Type inElemTy = getShapedElementTy(input);
Type biasElemTy = getShapedElementTy(bias);
Type elementType =
inElemTy.getIntOrFloatBitWidth() <= biasElemTy.getIntOrFloatBitWidth()
? biasElemTy
: inElemTy;
if (inElemTy != elementType)
input = createCastOp(rewriter, loc, elementType, shifted);
if (biasElemTy != elementType)
bias = createCastOp(rewriter, loc, elementType, bias);
shifted =
createOpAndInfer<tosa::SubOp>(rewriter, loc, elementType, input, bias);
Value upcastBias = createCastOp(rewriter, loc, outputType, bias);
shifted = createOpAndInfer<tosa::SubOp>(rewriter, loc, outputType,
upcastInput, upcastBias);
}

Type outputType = getShapedElementTy(output);
Value upCast = createCastOp(rewriter, loc, outputType, shifted);

Value scaled = createOpAndInfer<tosa::MulOp>(rewriter, loc, outputType,
upCast, scale, /*shift=*/0);
shifted, scale, /*shift=*/0);

rewriter.replaceOp(op, scaled);
return success();
}

// MIGraphX pseudo code:
// int64_t quantized = static_cast<int32>(
// int32_t quantized = static_cast<int32>(
// std::round(input[i] / scales[i])) + zero_pts[i];
// output[i] = std::max(-128, std::min(127, quantized));
LogicalResult QuantizeLinearConverter::matchAndRewrite(
Expand Down
14 changes: 8 additions & 6 deletions mlir/test/Conversion/MIGraphXToTosa/mixr-to-tosa-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -28,28 +28,29 @@ module {
}

// CHECK-LABEL: func @dequantize_scale_bias
// CHECK: tosa.sub
// CHECK: tosa.cast{{.*}}f32
// CHECK: tosa.cast{{.*}}f32
// CHECK: tosa.sub
// CHECK: tosa.mul
func.func @dequantize_scale_bias(%arg: !migraphx.shaped<1x112x112x64xi32, 802816x7168x64x1>, %scale: !migraphx.shaped<64xf32, 1>, %bias: !migraphx.shaped<64xi32, 1>) -> !migraphx.shaped<1x112x112x64xf32, 802816x7168x64x1> attributes {kernel = "mixr"} {
%1 = migraphx.dequantizelinear %arg, %scale, %bias : <1x112x112x64xi32, 802816x7168x64x1>, <64xf32, 1>, !migraphx.shaped<64xi32, 1> -> <1x112x112x64xf32, 802816x7168x64x1>
return %1 : !migraphx.shaped<1x112x112x64xf32, 802816x7168x64x1>
}

// CHECK-LABEL: func @dequantize_wide_bias
// CHECK: tosa.cast{{.*}}i32
// CHECK: tosa.sub{{.*}}i32
// CHECK: tosa.cast{{.*}}f32
// CHECK: tosa.cast{{.*}}f32
// CHECK: tosa.sub{{.*}}f32
// CHECK: tosa.mul
func.func @dequantize_wide_bias(%arg: !migraphx.shaped<1x112x112x64xi8, 802816x7168x64x1>, %scale: !migraphx.shaped<64xf32, 1>, %bias: !migraphx.shaped<64xi32, 1>) -> !migraphx.shaped<1x112x112x64xf32, 802816x7168x64x1> attributes {kernel = "mixr"} {
%1 = migraphx.dequantizelinear %arg, %scale, %bias : <1x112x112x64xi8, 802816x7168x64x1>, <64xf32, 1>, !migraphx.shaped<64xi32, 1> -> <1x112x112x64xf32, 802816x7168x64x1>
return %1 : !migraphx.shaped<1x112x112x64xf32, 802816x7168x64x1>
}

// CHECK-LABEL: func @dequantize_wide_input
// CHECK: tosa.cast{{.*}}i32
// CHECK: tosa.sub{{.*}}i32
// CHECK: tosa.cast{{.*}}f32
// CHECK: tosa.cast{{.*}}f32
// CHECK: tosa.sub{{.*}}f32
// CHECK: tosa.mul
func.func @dequantize_wide_input(%arg: !migraphx.shaped<1x112x112x64xi32, 802816x7168x64x1>, %scale: !migraphx.shaped<64xf32, 1>, %bias: !migraphx.shaped<64xi8, 1>) -> !migraphx.shaped<1x112x112x64xf32, 802816x7168x64x1> attributes {kernel = "mixr"} {
%1 = migraphx.dequantizelinear %arg, %scale, %bias : <1x112x112x64xi32, 802816x7168x64x1>, <64xf32, 1>, !migraphx.shaped<64xi8, 1> -> <1x112x112x64xf32, 802816x7168x64x1>
Expand Down Expand Up @@ -142,8 +143,9 @@ module {

// CHECK-LABEL: func @conv_with_quant
// CHECK: tosa.conv2d{{.*}} quantization_info
// CHECK: tosa.sub
// CHECK: tosa.cast
// CHECK: tosa.cast
// CHECK: tosa.sub
// CHECK: tosa.mul
// CHECK: tosa.reciprocal
// CHECK: tosa.mul
Expand Down

0 comments on commit 2d7c4ec

Please sign in to comment.