@@ -2404,6 +2404,80 @@ static Status TranslateMaxPoolGradOp(
2404
2404
return Status::OK ();
2405
2405
}
2406
2406
2407
+ static Status TranslateNonMaxSuppressionV4Op (
2408
+ const Node* op, const std::vector<const Tensor*>& static_input_map,
2409
+ Builder::OpMap& ng_op_map) {
2410
+ shared_ptr<ng::Node> ng_boxes, ng_scores;
2411
+ TF_RETURN_IF_ERROR (GetInputNodes (ng_op_map, op, &ng_boxes, &ng_scores,
2412
+ nullptr , nullptr , nullptr ));
2413
+
2414
+ std::vector<int > max_output_size;
2415
+ TF_RETURN_IF_ERROR (
2416
+ GetStaticInputVector (op, 2 , static_input_map, &max_output_size));
2417
+ std::vector<float > iou_threshold;
2418
+ TF_RETURN_IF_ERROR (
2419
+ GetStaticInputVector (op, 3 , static_input_map, &iou_threshold));
2420
+
2421
+ std::vector<float > score_threshold;
2422
+ TF_RETURN_IF_ERROR (
2423
+ GetStaticInputVector (op, 4 , static_input_map, &score_threshold));
2424
+
2425
+ bool pad_to_max_output_size;
2426
+ if (GetNodeAttr (op->attrs (), " pad_to_max_output_size" ,
2427
+ &pad_to_max_output_size) != Status::OK ()) {
2428
+ pad_to_max_output_size = false ;
2429
+ }
2430
+ // max_output_size must be scalar
2431
+ if (max_output_size.size () != 1 ) {
2432
+ return errors::InvalidArgument (
2433
+ " NonMaxSuppressionV4 Op: max_output_size of nms must be scalar " ,
2434
+ max_output_size.size ());
2435
+ }
2436
+ // iou_threshold must be scalar
2437
+ if (iou_threshold.size () != 1 ) {
2438
+ return errors::InvalidArgument (
2439
+ " NonMaxSuppressionV4 Op: iou_threshold of nms must be scalar " ,
2440
+ iou_threshold.size ());
2441
+ }
2442
+
2443
+ // score_threshold must be scalar
2444
+ if (score_threshold.size () != 1 ) {
2445
+ return errors::InvalidArgument (
2446
+ " NonMaxSuppressionV4 Op: score_threshold of nms must be scalar " ,
2447
+ score_threshold.size ());
2448
+ }
2449
+
2450
+ std::string backend_name;
2451
+ TF_RETURN_IF_ERROR (ngraph_bridge::GetNodeBackend (op, &backend_name));
2452
+
2453
+ if (backend_name != " NNPI" ) {
2454
+ return errors::Internal (" In translating NonMaxSuppressionV4 op " ,
2455
+ op->name (), " found requested backend " ,
2456
+ backend_name, " which is unsupported" );
2457
+ }
2458
+
2459
+ ng::runtime::Backend* backend = BackendManager::GetBackend (backend_name);
2460
+
2461
+ shared_ptr<ng::Node> ng_nmsv4 = backend->get_backend_op (
2462
+ " NonMaxSuppressionV4" , &ng_boxes, &ng_scores,
2463
+ (size_t )(max_output_size[0 ]), (float )(iou_threshold[0 ]),
2464
+ (float )score_threshold[0 ], (bool )pad_to_max_output_size);
2465
+ if (ng_nmsv4 == nullptr ) {
2466
+ return errors::Internal (" In translating NonMaxSuppressionV4 op " ,
2467
+ op->name (),
2468
+ " backend could not return valid ngraph node" );
2469
+ }
2470
+ shared_ptr<ngraph::Node> ng_selected_indices =
2471
+ ConstructNgNode<ngraph::op::GetOutputElement>(op->name (), ng_nmsv4, 0 );
2472
+ shared_ptr<ngraph::Node> ng_valid_output =
2473
+ ConstructNgNode<ngraph::op::GetOutputElement>(op->name (), ng_nmsv4, 1 );
2474
+
2475
+ SaveNgOp (ng_op_map, op->name (), ng_selected_indices);
2476
+ SaveNgOp (ng_op_map, op->name (), ng_valid_output);
2477
+
2478
+ return Status::OK ();
2479
+ }
2480
+
2407
2481
static Status TranslateReduceOp (
2408
2482
const Node* op, const std::vector<const Tensor*>& static_input_map,
2409
2483
Builder::OpMap& ng_op_map,
@@ -4357,6 +4431,7 @@ const static std::map<
4357
4431
{" MaxPool" , TranslateMaxPoolOp},
4358
4432
{" MaxPool3D" , TranslateMaxPool3DOp},
4359
4433
{" MaxPoolGrad" , TranslateMaxPoolGradOp},
4434
+ {" NonMaxSuppressionV4" , TranslateNonMaxSuppressionV4Op},
4360
4435
{" Mean" , TranslateMeanOp},
4361
4436
{" Min" , TranslateDirectReduceOp<ng::op::Min>},
4362
4437
{" Minimum" , TranslateBinaryOp<ngraph::op::Minimum>},
0 commit comments