From d7603ab15f64e982025de298d7d577660d5b9dce Mon Sep 17 00:00:00 2001
From: "Lin, Peiyong"
Date: Thu, 9 Jan 2025 19:14:27 +0000
Subject: [PATCH] Add center_point_box=1 support in NonMaxSuppression.

When center_point_box=1, the supplied boxes come in
[x_center, y_center, width, height] format; this patch converts them to
[x1, y1, x2, y2] format before they are consumed.
---
 .../TorchOnnxToTorch/DefaultDomainGtoP.cpp    | 55 +++++++++++++-
 .../TorchOnnxToTorch/simple_ops_g_to_p.mlir   | 73 +++++++++++++++++++
 2 files changed, 124 insertions(+), 4 deletions(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index 12d8683bc9d12..d5b9abd69f3f5 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -3697,11 +3697,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             binder.tensorResultType(resultType))
           return failure();
 
-        // TODO: Add support for non-zero center_point_box value.
-        if (centerPointBox != 0)
+        if (centerPointBox != 0 && centerPointBox != 1)
           return rewriter.notifyMatchFailure(
-              binder.op, "unimplemented: expected center_point_box "
-                         "attribute value to be 0");
+              binder.op, "expected center_point_box attribute to be 0 or 1");
 
         // TODO: Support multiple batches and classes
         // Squeeze the boxes and scores tensor.
@@ -3727,6 +3725,55 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
               "failed to squeeze scores tensor");
         boxes = squeezedBoxes.value();
         scores = squeezedScores.value();
+        if (centerPointBox == 1) {
+          // When center_point_box is 1, the box data is supplied as
+          // [[x_center, y_center, width, height], ...]. Slice it to
+          // [[x_center, y_center], ...] and [[width, height], ...],
+          // calculate the [[x1, y1], ...] and [[x2, y2], ...], and concatenate
+          // to [[x1, y1, x2, y2], ...]
+          Type boxType = boxes.getType();
+          Value const0 = rewriter.create<Torch::ConstantIntOp>(
+              loc, rewriter.getI64IntegerAttr(0));
+          Value const1 = rewriter.create<Torch::ConstantIntOp>(
+              loc, rewriter.getI64IntegerAttr(1));
+          Value const2 = rewriter.create<Torch::ConstantIntOp>(
+              loc, rewriter.getI64IntegerAttr(2));
+          Value const4 = rewriter.create<Torch::ConstantIntOp>(
+              loc, rewriter.getI64IntegerAttr(4));
+          Value const2F = rewriter.create<Torch::ConstantFloatOp>(
+              loc, rewriter.getF64FloatAttr(2.0));
+
+          // extract scaled ranges for regions of interest
+          auto tmpShape = SmallVector<int64_t>{Torch::kUnknownSize, 2};
+          auto tmpTensorType = rewriter.getType<Torch::ValueTensorType>(
+              tmpShape,
+              cast<Torch::ValueTensorType>(boxes.getType()).getDtype());
+          Value centers = rewriter.create<Torch::AtenSliceTensorOp>(
+              loc, tmpTensorType, boxes, const1, const0, const2, const1);
+          Value sizes = rewriter.create<Torch::AtenSliceTensorOp>(
+              loc, tmpTensorType, boxes, const1, const2, const4, const1);
+          Value const2FTensor = Torch::createRank0Tensor(
+              rewriter, loc,
+              Torch::ValueTensorType::get(binder.op->getContext(), std::nullopt,
+                                          rewriter.getF64Type()),
+              const2F);
+          Value halfSizes = rewriter.create<Torch::AtenDivTensorOp>(
+              loc, sizes.getType(), sizes, const2FTensor);
+          Value x1y1s = rewriter.create<Torch::AtenSubTensorOp>(
+              loc, centers.getType(), centers, halfSizes, const1);
+          Value x2y2s = rewriter.create<Torch::AtenAddTensorOp>(
+              loc, centers.getType(), centers, halfSizes, const1);
+
+          Type listElemType =
+              cast<Torch::BaseTensorType>(resultType)
+                  .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
+                                        /*optionalDtype=*/nullptr);
+          Type listType = Torch::ListType::get(listElemType);
+          Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
+              loc, listType, SmallVector<Value>{x1y1s, x2y2s});
+          boxes = rewriter.create<Torch::AtenCatOp>(loc, boxType, tensorList,
+                                                    const1);
+        }
 
         // TODO: Support score_threshold input
         // Filter out the boxes if the score < score_threshold
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
index 30b85e63ab0fb..3d2b5e72fa10b 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -2145,6 +2145,79 @@ func.func @test_nonmaxsuppression_single_box(%arg0: !torch.vtensor<[1,1,4],f32>,
   return %0 : !torch.vtensor<[1,3],si64>
 }
 
+// CHECK-LABEL: func.func @test_nonmaxsuppression_center_point_box(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,4],f32>,
+// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1,1,1],f32>,
+// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[1],si64>,
+// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[1],f32>,
+// CHECK-SAME: %[[VAL_4:.*]]: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+func.func @test_nonmaxsuppression_center_point_box(%arg0: !torch.vtensor<[1,1,4],f32>, %arg1: !torch.vtensor<[1,1,1],f32>, %arg2: !torch.vtensor<[1],si64>, %arg3: !torch.vtensor<[1],f32>, %arg4: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[VAL_5:.*]] = torch.constant.int 0
+  // CHECK: %[[VAL_6:.*]] = torch.constant.int 1
+  // CHECK: %[[VAL_7:.*]] = torch.aten.size.int %[[VAL_0]], %[[VAL_5]] : !torch.vtensor<[1,1,4],f32>, !torch.int -> !torch.int
+  // CHECK: %[[VAL_8:.*]] = torch.aten.eq.int %[[VAL_7]], %[[VAL_6]] : !torch.int, !torch.int -> !torch.bool
+  // CHECK: torch.runtime.assert %[[VAL_8]], "squeeze operation possible for dim only when input_shape[dim] == 1."
+  // CHECK: %[[VAL_9:.*]] = torch.aten.squeeze.dim %[[VAL_0]], %[[VAL_5]] : !torch.vtensor<[1,1,4],f32>, !torch.int -> !torch.vtensor<[1,4],f32>
+  // CHECK: %[[VAL_10:.*]] = torch.constant.int 0
+  // CHECK: %[[VAL_11:.*]] = torch.constant.int 1
+  // CHECK: %[[VAL_12:.*]] = torch.aten.size.int %[[VAL_1]], %[[VAL_10]] : !torch.vtensor<[1,1,1],f32>, !torch.int -> !torch.int
+  // CHECK: %[[VAL_13:.*]] = torch.aten.eq.int %[[VAL_12]], %[[VAL_11]] : !torch.int, !torch.int -> !torch.bool
+  // CHECK: torch.runtime.assert %[[VAL_13]], "squeeze operation possible for dim only when input_shape[dim] == 1."
+  // CHECK: %[[VAL_14:.*]] = torch.aten.squeeze.dim %[[VAL_1]], %[[VAL_10]] : !torch.vtensor<[1,1,1],f32>, !torch.int -> !torch.vtensor<[1,1],f32>
+  // CHECK: %[[VAL_15:.*]] = torch.constant.int 0
+  // CHECK: %[[VAL_16:.*]] = torch.constant.int 1
+  // CHECK: %[[VAL_17:.*]] = torch.aten.size.int %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.int
+  // CHECK: %[[VAL_18:.*]] = torch.aten.eq.int %[[VAL_17]], %[[VAL_16]] : !torch.int, !torch.int -> !torch.bool
+  // CHECK: torch.runtime.assert %[[VAL_18]], "squeeze operation possible for dim only when input_shape[dim] == 1."
+  // CHECK: %[[VAL_19:.*]] = torch.aten.squeeze.dim %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1],f32>
+  // CHECK: %[[VAL_20:.*]] = torch.constant.int 0
+  // CHECK: %[[VAL_21:.*]] = torch.constant.int 1
+  // CHECK: %[[VAL_22:.*]] = torch.constant.int 2
+  // CHECK: %[[VAL_23:.*]] = torch.constant.int 4
+  // CHECK: %[[VAL_24:.*]] = torch.constant.float 2.000000e+00
+  // CHECK: %[[VAL_25:.*]] = torch.aten.slice.Tensor %[[VAL_9]], %[[VAL_21]], %[[VAL_20]], %[[VAL_22]], %[[VAL_21]] : !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,2],f32>
+  // CHECK: %[[VAL_26:.*]] = torch.aten.slice.Tensor %[[VAL_9]], %[[VAL_21]], %[[VAL_22]], %[[VAL_23]], %[[VAL_21]] : !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,2],f32>
+  // CHECK: %[[VAL_27:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+  // CHECK: %[[VAL_28:.*]] = torch.constant.none
+  // CHECK: %[[VAL_29:.*]] = torch.constant.int 7
+  // CHECK: %[[VAL_30:.*]] = torch.aten.full %[[VAL_27]], %[[VAL_24]], %[[VAL_29]], %[[VAL_28]], %[[VAL_28]], %[[VAL_28]] : !torch.list<int>, !torch.float, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f64>
+  // CHECK: %[[VAL_31:.*]] = torch.aten.div.Tensor %[[VAL_26]], %[[VAL_30]] : !torch.vtensor<[?,2],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[?,2],f32>
+  // CHECK: %[[VAL_32:.*]] = torch.aten.sub.Tensor %[[VAL_25]], %[[VAL_31]], %[[VAL_21]] : !torch.vtensor<[?,2],f32>, !torch.vtensor<[?,2],f32>, !torch.int -> !torch.vtensor<[?,2],f32>
+  // CHECK: %[[VAL_33:.*]] = torch.aten.add.Tensor %[[VAL_25]], %[[VAL_31]], %[[VAL_21]] : !torch.vtensor<[?,2],f32>, !torch.vtensor<[?,2],f32>, !torch.int -> !torch.vtensor<[?,2],f32>
+  // CHECK: %[[VAL_34:.*]] = torch.prim.ListConstruct %[[VAL_32]], %[[VAL_33]] : (!torch.vtensor<[?,2],f32>, !torch.vtensor<[?,2],f32>) -> !torch.list<vtensor>
+  // CHECK: %[[VAL_35:.*]] = torch.aten.cat %[[VAL_34]], %[[VAL_21]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,4],f32>
+  // CHECK: %[[VAL_36:.*]] = torch.aten.item %[[VAL_4]] : !torch.vtensor<[1],f32> -> !torch.float
+  // CHECK: %[[VAL_37:.*]] = torch.aten.min %[[VAL_19]] : !torch.vtensor<[1],f32> -> !torch.vtensor<[],f32>
+  // CHECK: %[[VAL_38:.*]] = torch.aten.item %[[VAL_37]] : !torch.vtensor<[],f32> -> !torch.float
+  // CHECK: %[[VAL_39:.*]] = torch.aten.ge.float %[[VAL_38]], %[[VAL_36]] : !torch.float, !torch.float -> !torch.bool
+  // CHECK: torch.runtime.assert %[[VAL_39]], "unimplemented: score_threshold should be <= min(scores)"
+  // CHECK: %[[VAL_40:.*]] = torch.constant.int 0
+  // CHECK: %[[VAL_41:.*]] = torch.constant.int 1
+  // CHECK: %[[VAL_42:.*]] = torch.constant.float 0.000000e+00
+  // CHECK: %[[VAL_43:.*]] = torch.aten.item %[[VAL_3]] : !torch.vtensor<[1],f32> -> !torch.float
+  // CHECK: %[[VAL_44:.*]] = torch.aten.item %[[VAL_2]] : !torch.vtensor<[1],si64> -> !torch.int
+  // CHECK: %[[VAL_45:.*]] = torch.torchvision.nms %[[VAL_35]], %[[VAL_19]], %[[VAL_43]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[?],si64>
+  // CHECK: %[[VAL_46:.*]] = torch.aten.size.int %[[VAL_45]], %[[VAL_40]] : !torch.vtensor<[?],si64>, !torch.int -> !torch.int
+  // CHECK: %[[VAL_47:.*]] = torch.aten.gt.int %[[VAL_46]], %[[VAL_44]] : !torch.int, !torch.int -> !torch.bool
+  // CHECK: %[[VAL_48:.*]] = torch.prim.If %[[VAL_47]] -> (!torch.vtensor<[1],si64>) {
+  // CHECK: %[[VAL_49:.*]] = torch.aten.slice.Tensor %[[VAL_45]], %[[VAL_40]], %[[VAL_40]], %[[VAL_44]], %[[VAL_41]] : !torch.vtensor<[?],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
+  // CHECK: torch.prim.If.yield %[[VAL_49]] : !torch.vtensor<[1],si64>
+  // CHECK: } else {
+  // CHECK: %[[VAL_50:.*]] = torch.tensor_static_info_cast %[[VAL_45]] : !torch.vtensor<[?],si64> to !torch.vtensor<[1],si64>
+  // CHECK: torch.prim.If.yield %[[VAL_50]] : !torch.vtensor<[1],si64>
+  // CHECK: }
+  // CHECK: %[[VAL_51:.*]] = torch.aten.unsqueeze %[[VAL_48]], %[[VAL_41]] : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
+  // CHECK: %[[VAL_52:.*]] = torch.aten.size.int %[[VAL_51]], %[[VAL_40]] : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.int
+  // CHECK: %[[VAL_53:.*]] = torch.constant.int 2
+  // CHECK: %[[VAL_54:.*]] = torch.prim.ListConstruct %[[VAL_52]], %[[VAL_53]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[VAL_55:.*]] = torch.constant.none
+  // CHECK: %[[VAL_56:.*]] = torch.aten.zeros %[[VAL_54]], %[[VAL_55]], %[[VAL_55]], %[[VAL_55]], %[[VAL_55]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64>
+  // CHECK: %[[VAL_57:.*]] = torch.prim.ListConstruct %[[VAL_56]], %[[VAL_51]] : (!torch.vtensor<[1,2],si64>, !torch.vtensor<[1,1],si64>) -> !torch.list<vtensor>
+  // CHECK: %[[VAL_58:.*]] = torch.aten.cat %[[VAL_57]], %[[VAL_41]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,3],si64>
+  // CHECK: return %[[VAL_58]] : !torch.vtensor<[1,3],si64>
+  %0 = torch.operator "onnx.NonMaxSuppression"(%arg0, %arg1, %arg2, %arg3, %arg4) {torch.onnx.center_point_box = 1 : si64} : (!torch.vtensor<[1,1,4],f32>, !torch.vtensor<[1,1,1],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64>
+  return %0 : !torch.vtensor<[1,3],si64>
+}
 
 // -----
 
 // CHECK-LABEL: func.func @test_mwm
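
Note on the box-format conversion performed by this patch: for each box the lowering computes x1 = x_center - width/2, y1 = y_center - height/2, x2 = x_center + width/2, y2 = y_center + height/2, expressed as the slice/div/sub/add/cat sequence above so that every box is converted at once with tensor ops. The standalone C++ sketch below is illustrative only; CenterBox, CornerBox, and toCorners are hypothetical names and are not part of this patch or of torch-mlir.

#include <cstdio>

// [x_center, y_center, width, height], as supplied when center_point_box=1.
struct CenterBox { float xc, yc, w, h; };
// [x1, y1, x2, y2], the format expected by the downstream nms lowering.
struct CornerBox { float x1, y1, x2, y2; };

// Mirrors the emitted ops: halfSizes = sizes / 2, x1y1s = centers - halfSizes,
// x2y2s = centers + halfSizes, then the two halves are concatenated per box.
static CornerBox toCorners(const CenterBox &b) {
  const float halfW = b.w / 2.0f;
  const float halfH = b.h / 2.0f;
  return {b.xc - halfW, b.yc - halfH, b.xc + halfW, b.yc + halfH};
}

int main() {
  // Example: a box centered at (0.5, 0.5) with width 1 and height 1
  // becomes the unit square [0, 0, 1, 1].
  const CenterBox box{0.5f, 0.5f, 1.0f, 1.0f};
  const CornerBox c = toCorners(box);
  std::printf("[%g, %g, %g, %g]\n", c.x1, c.y1, c.x2, c.y2);
  return 0;
}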