Skip to content
This repository was archived by the owner on Jan 3, 2023. It is now read-only.

Commit a39cd8f

Browse files
sachinmuradiavijit-nervana
authored andcommitted
Sarkars/nmsv4 (#489)
1 parent f401b54 commit a39cd8f

File tree

2 files changed

+84
-1
lines changed

2 files changed

+84
-1
lines changed

src/ngraph_builder.cc

+75
Original file line numberDiff line numberDiff line change
@@ -2404,6 +2404,80 @@ static Status TranslateMaxPoolGradOp(
24042404
return Status::OK();
24052405
}
24062406

2407+
static Status TranslateNonMaxSuppressionV4Op(
2408+
const Node* op, const std::vector<const Tensor*>& static_input_map,
2409+
Builder::OpMap& ng_op_map) {
2410+
shared_ptr<ng::Node> ng_boxes, ng_scores;
2411+
TF_RETURN_IF_ERROR(GetInputNodes(ng_op_map, op, &ng_boxes, &ng_scores,
2412+
nullptr, nullptr, nullptr));
2413+
2414+
std::vector<int> max_output_size;
2415+
TF_RETURN_IF_ERROR(
2416+
GetStaticInputVector(op, 2, static_input_map, &max_output_size));
2417+
std::vector<float> iou_threshold;
2418+
TF_RETURN_IF_ERROR(
2419+
GetStaticInputVector(op, 3, static_input_map, &iou_threshold));
2420+
2421+
std::vector<float> score_threshold;
2422+
TF_RETURN_IF_ERROR(
2423+
GetStaticInputVector(op, 4, static_input_map, &score_threshold));
2424+
2425+
bool pad_to_max_output_size;
2426+
if (GetNodeAttr(op->attrs(), "pad_to_max_output_size",
2427+
&pad_to_max_output_size) != Status::OK()) {
2428+
pad_to_max_output_size = false;
2429+
}
2430+
// max_output_size must be scalar
2431+
if (max_output_size.size() != 1) {
2432+
return errors::InvalidArgument(
2433+
"NonMaxSuppressionV4 Op: max_output_size of nms must be scalar ",
2434+
max_output_size.size());
2435+
}
2436+
// iou_threshold must be scalar
2437+
if (iou_threshold.size() != 1) {
2438+
return errors::InvalidArgument(
2439+
"NonMaxSuppressionV4 Op: iou_threshold of nms must be scalar ",
2440+
iou_threshold.size());
2441+
}
2442+
2443+
// score_threshold must be scalar
2444+
if (score_threshold.size() != 1) {
2445+
return errors::InvalidArgument(
2446+
"NonMaxSuppressionV4 Op: score_threshold of nms must be scalar ",
2447+
score_threshold.size());
2448+
}
2449+
2450+
std::string backend_name;
2451+
TF_RETURN_IF_ERROR(ngraph_bridge::GetNodeBackend(op, &backend_name));
2452+
2453+
if (backend_name != "NNPI") {
2454+
return errors::Internal("In translating NonMaxSuppressionV4 op ",
2455+
op->name(), " found requested backend ",
2456+
backend_name, " which is unsupported");
2457+
}
2458+
2459+
ng::runtime::Backend* backend = BackendManager::GetBackend(backend_name);
2460+
2461+
shared_ptr<ng::Node> ng_nmsv4 = backend->get_backend_op(
2462+
"NonMaxSuppressionV4", &ng_boxes, &ng_scores,
2463+
(size_t)(max_output_size[0]), (float)(iou_threshold[0]),
2464+
(float)score_threshold[0], (bool)pad_to_max_output_size);
2465+
if (ng_nmsv4 == nullptr) {
2466+
return errors::Internal("In translating NonMaxSuppressionV4 op ",
2467+
op->name(),
2468+
" backend could not return valid ngraph node");
2469+
}
2470+
shared_ptr<ngraph::Node> ng_selected_indices =
2471+
ConstructNgNode<ngraph::op::GetOutputElement>(op->name(), ng_nmsv4, 0);
2472+
shared_ptr<ngraph::Node> ng_valid_output =
2473+
ConstructNgNode<ngraph::op::GetOutputElement>(op->name(), ng_nmsv4, 1);
2474+
2475+
SaveNgOp(ng_op_map, op->name(), ng_selected_indices);
2476+
SaveNgOp(ng_op_map, op->name(), ng_valid_output);
2477+
2478+
return Status::OK();
2479+
}
2480+
24072481
static Status TranslateReduceOp(
24082482
const Node* op, const std::vector<const Tensor*>& static_input_map,
24092483
Builder::OpMap& ng_op_map,
@@ -4357,6 +4431,7 @@ const static std::map<
43574431
{"MaxPool", TranslateMaxPoolOp},
43584432
{"MaxPool3D", TranslateMaxPool3DOp},
43594433
{"MaxPoolGrad", TranslateMaxPoolGradOp},
4434+
{"NonMaxSuppressionV4", TranslateNonMaxSuppressionV4Op},
43604435
{"Mean", TranslateMeanOp},
43614436
{"Min", TranslateDirectReduceOp<ng::op::Min>},
43624437
{"Minimum", TranslateBinaryOp<ngraph::op::Minimum>},

src/ngraph_mark_for_clustering.cc

+9-1
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,8 @@ Status MarkForClustering(Graph* graph,
460460
type_constraint_map["Minimum"]["T"] = NGraphNumericDTypes();
461461
type_constraint_map["Mul"]["T"] = NGraphNumericDTypes();
462462
type_constraint_map["Neg"]["T"] = NGraphNumericDTypes();
463+
type_constraint_map["NonMaxSuppressionV4"]["T"] = {
464+
DT_FLOAT}; // TF allows half too
463465
type_constraint_map["OneHot"]["T"] = NGraphDTypes();
464466
type_constraint_map["Pack"]["T"] = NGraphDTypes();
465467
type_constraint_map["Pad"]["T"] = NGraphDTypes();
@@ -574,6 +576,7 @@ Status MarkForClustering(Graph* graph,
574576
set_attributes_map["Max"] = SetStaticInputs({1});
575577
set_attributes_map["Mean"] = SetStaticInputs({1});
576578
set_attributes_map["Min"] = SetStaticInputs({1});
579+
set_attributes_map["NonMaxSuppressionV4"] = SetStaticInputs({2, 3, 4});
577580
set_attributes_map["OneHot"] = SetStaticInputs({1});
578581
set_attributes_map["Pad"] = SetStaticInputs({1});
579582
set_attributes_map["Prod"] = SetStaticInputs({1});
@@ -626,7 +629,6 @@ Status MarkForClustering(Graph* graph,
626629
" is not supported");
627630
}
628631
current_backend = backend_env;
629-
// TODO: set backend. Then don't use current_backend
630632
}
631633

632634
// Right now it cannot be inside the if(!initialized) block, because it is
@@ -639,6 +641,12 @@ Status MarkForClustering(Graph* graph,
639641
return Status::OK();
640642
};
641643

644+
confirmation_function_map["NonMaxSuppressionV4"] = [&current_backend](
645+
Node* n, bool* result) {
646+
*result = (current_backend == "NNPI");
647+
return Status::OK();
648+
};
649+
642650
std::unordered_map<string, int> no_support_histogram;
643651
std::unordered_map<string, int> fail_confirmation_histogram;
644652
std::unordered_map<string, int> fail_constraint_histogram;

0 commit comments

Comments
 (0)