From 25ec4ce6f184ee3cd780b9503a57ad31f7680c74 Mon Sep 17 00:00:00 2001 From: Dmitry Kurtaev Date: Sat, 4 Apr 2020 20:27:59 +0300 Subject: [PATCH] PReLU from Tensorflow --- .../src/tensorflow/tf_graph_simplifier.cpp | 60 +++++++++++++++++++ modules/dnn/src/tensorflow/tf_importer.cpp | 12 +++- modules/dnn/test/test_tf_importer.cpp | 16 ++++- 3 files changed, 86 insertions(+), 2 deletions(-) diff --git a/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp b/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp index b0978c2ace58..1afed2cf464b 100644 --- a/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp +++ b/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp @@ -223,6 +223,26 @@ class FlattenShapeSubgraph : public Subgraph } }; +class FlattenProdSubgraph : public Subgraph +{ +public: + FlattenProdSubgraph() + { + int input = addNodeToMatch(""); + int shape = addNodeToMatch("Shape", input); + int stack = addNodeToMatch("Const"); + int stack_1 = addNodeToMatch("Const"); + int stack_2 = addNodeToMatch("Const"); + int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2); + int prod = addNodeToMatch("Prod", strided_slice, addNodeToMatch("Const")); + int shape_pack = addNodeToMatch("Const"); + int pack = addNodeToMatch("Pack", shape_pack, prod); + addNodeToMatch("Reshape", input, pack); + + setFusedNode("Flatten", input); + } +}; + // K.layers.Softmax class SoftMaxKerasSubgraph : public Subgraph { @@ -629,6 +649,36 @@ class KerasMVNSubgraph : public TFSubgraph } }; +class PReLUSubgraph : public TFSubgraph +{ +public: + PReLUSubgraph(bool negativeScales_) : negativeScales(negativeScales_) + { + int input = addNodeToMatch(""); + int scales = addNodeToMatch("Const"); + int neg = addNodeToMatch("Neg", input); + int relu_neg = addNodeToMatch("Relu", neg); + int finalScales = negativeScales ? addNodeToMatch("Neg", scales) : scales; + int mul = addNodeToMatch("Mul", finalScales, relu_neg); + int relu_pos = addNodeToMatch("Relu", input); + addNodeToMatch("Add", relu_pos, mul); + setFusedNode("PReLU", input, scales); + } + + virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef* fusedNode, + std::vector& inputNodes) CV_OVERRIDE + { + if (!negativeScales) + { + Mat scales = getTensorContent(inputNodes[1]->attr().at("value").tensor(), /*copy*/false); + scales *= -1; + } + } + +private: + bool negativeScales; +}; + void simplifySubgraphs(tensorflow::GraphDef& net) { std::vector > subgraphs; @@ -649,6 +699,16 @@ void simplifySubgraphs(tensorflow::GraphDef& net) subgraphs.push_back(Ptr(new SoftMaxSlimV2Subgraph())); subgraphs.push_back(Ptr(new ReshapeAsShapeSubgraph())); subgraphs.push_back(Ptr(new KerasMVNSubgraph())); + subgraphs.push_back(Ptr(new PReLUSubgraph(true))); + subgraphs.push_back(Ptr(new PReLUSubgraph(false))); + subgraphs.push_back(Ptr(new FlattenProdSubgraph())); + + for (int i = 0; i < net.node_size(); ++i) + { + tensorflow::NodeDef* layer = net.mutable_node(i); + if (layer->op() == "AddV2") + layer->set_op("Add"); + } simplifySubgraphs(Ptr(new TFGraphWrapper(net)), subgraphs); } diff --git a/modules/dnn/src/tensorflow/tf_importer.cpp b/modules/dnn/src/tensorflow/tf_importer.cpp index 0dd21770a4b2..534ceff3df45 100644 --- a/modules/dnn/src/tensorflow/tf_importer.cpp +++ b/modules/dnn/src/tensorflow/tf_importer.cpp @@ -1231,6 +1231,7 @@ void TFImporter::populateNet(Net dstNet) // Only NHWC <-> NCHW permutations are allowed. OpenCV is always // keep NCHW layout this way. int inpLayout = getDataLayout(layer.input(0), data_layouts); + std::string type = "Identity"; if (inpLayout == DATA_LAYOUT_NHWC) { if (permData[0] == 0 && permData[1] == 3 && permData[2] == 1 && permData[3] == 2) @@ -1245,6 +1246,15 @@ void TFImporter::populateNet(Net dstNet) // in OpenCV: NCHW->NCHW data_layouts[name] = DATA_LAYOUT_NHWC; } + else if (permData[0] == 0 && permData[1] == 3 && permData[2] == 2 && permData[3] == 1) + { + // in TensorFlow: NHWC->NCWH + // in OpenCV: NCHW->NCWH + int permData[] = {0, 1, 3, 2}; + layerParams.set("order", DictValue::arrayInt(permData, perm.total())); + data_layouts[name] = DATA_LAYOUT_NCHW; // we keep track NCHW because channels position only matters + type = "Permute"; + } else CV_Error(Error::StsParseError, "Only NHWC <-> NCHW permutations are allowed."); } @@ -1265,7 +1275,7 @@ void TFImporter::populateNet(Net dstNet) else CV_Error(Error::StsParseError, "Only NHWC <-> NCHW permutations are allowed."); } - int id = dstNet.addLayer(name, "Identity", layerParams); + int id = dstNet.addLayer(name, type, layerParams); layer_id[name] = id; connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0); } diff --git a/modules/dnn/test/test_tf_importer.cpp b/modules/dnn/test/test_tf_importer.cpp index 6738f1b91063..1f95dcfa2a8b 100644 --- a/modules/dnn/test/test_tf_importer.cpp +++ b/modules/dnn/test/test_tf_importer.cpp @@ -956,11 +956,25 @@ TEST_P(Test_TensorFlow_layers, resize_bilinear) runTensorFlowNet("resize_bilinear_factor"); } -TEST_P(Test_TensorFlow_layers, tf2_keras) +TEST_P(Test_TensorFlow_layers, tf2_dense) { runTensorFlowNet("tf2_dense"); } +TEST_P(Test_TensorFlow_layers, tf2_prelu) +{ + if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019) + applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER); + if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH) + applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NGRAPH); + runTensorFlowNet("tf2_prelu"); +} + +TEST_P(Test_TensorFlow_layers, tf2_permute_nhwc_ncwh) +{ + runTensorFlowNet("tf2_permute_nhwc_ncwh"); +} + TEST_P(Test_TensorFlow_layers, squeeze) { #if defined(INF_ENGINE_RELEASE)