Skip to content

Commit

Permalink
Merge pull request opencv#16983 from dkurt:dnn_tf_prelu
Browse files Browse the repository at this point in the history
  • Loading branch information
alalek committed Apr 28, 2020
2 parents dc1b1f2 + 25ec4ce commit 5da4bb7
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 2 deletions.
60 changes: 60 additions & 0 deletions modules/dnn/src/tensorflow/tf_graph_simplifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,26 @@ class FlattenShapeSubgraph : public Subgraph
}
};

class FlattenProdSubgraph : public Subgraph
{
public:
FlattenProdSubgraph()
{
int input = addNodeToMatch("");
int shape = addNodeToMatch("Shape", input);
int stack = addNodeToMatch("Const");
int stack_1 = addNodeToMatch("Const");
int stack_2 = addNodeToMatch("Const");
int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
int prod = addNodeToMatch("Prod", strided_slice, addNodeToMatch("Const"));
int shape_pack = addNodeToMatch("Const");
int pack = addNodeToMatch("Pack", shape_pack, prod);
addNodeToMatch("Reshape", input, pack);

setFusedNode("Flatten", input);
}
};

// K.layers.Softmax
class SoftMaxKerasSubgraph : public Subgraph
{
Expand Down Expand Up @@ -629,6 +649,36 @@ class KerasMVNSubgraph : public TFSubgraph
}
};

class PReLUSubgraph : public TFSubgraph
{
public:
PReLUSubgraph(bool negativeScales_) : negativeScales(negativeScales_)
{
int input = addNodeToMatch("");
int scales = addNodeToMatch("Const");
int neg = addNodeToMatch("Neg", input);
int relu_neg = addNodeToMatch("Relu", neg);
int finalScales = negativeScales ? addNodeToMatch("Neg", scales) : scales;
int mul = addNodeToMatch("Mul", finalScales, relu_neg);
int relu_pos = addNodeToMatch("Relu", input);
addNodeToMatch("Add", relu_pos, mul);
setFusedNode("PReLU", input, scales);
}

virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef* fusedNode,
std::vector<tensorflow::NodeDef*>& inputNodes) CV_OVERRIDE
{
if (!negativeScales)
{
Mat scales = getTensorContent(inputNodes[1]->attr().at("value").tensor(), /*copy*/false);
scales *= -1;
}
}

private:
bool negativeScales;
};

void simplifySubgraphs(tensorflow::GraphDef& net)
{
std::vector<Ptr<Subgraph> > subgraphs;
Expand All @@ -649,6 +699,16 @@ void simplifySubgraphs(tensorflow::GraphDef& net)
subgraphs.push_back(Ptr<Subgraph>(new SoftMaxSlimV2Subgraph()));
subgraphs.push_back(Ptr<Subgraph>(new ReshapeAsShapeSubgraph()));
subgraphs.push_back(Ptr<Subgraph>(new KerasMVNSubgraph()));
subgraphs.push_back(Ptr<Subgraph>(new PReLUSubgraph(true)));
subgraphs.push_back(Ptr<Subgraph>(new PReLUSubgraph(false)));
subgraphs.push_back(Ptr<Subgraph>(new FlattenProdSubgraph()));

for (int i = 0; i < net.node_size(); ++i)
{
tensorflow::NodeDef* layer = net.mutable_node(i);
if (layer->op() == "AddV2")
layer->set_op("Add");
}

simplifySubgraphs(Ptr<ImportGraphWrapper>(new TFGraphWrapper(net)), subgraphs);
}
Expand Down
12 changes: 11 additions & 1 deletion modules/dnn/src/tensorflow/tf_importer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1231,6 +1231,7 @@ void TFImporter::populateNet(Net dstNet)
// Only NHWC <-> NCHW permutations are allowed. OpenCV is always
// keep NCHW layout this way.
int inpLayout = getDataLayout(layer.input(0), data_layouts);
std::string type = "Identity";
if (inpLayout == DATA_LAYOUT_NHWC)
{
if (permData[0] == 0 && permData[1] == 3 && permData[2] == 1 && permData[3] == 2)
Expand All @@ -1245,6 +1246,15 @@ void TFImporter::populateNet(Net dstNet)
// in OpenCV: NCHW->NCHW
data_layouts[name] = DATA_LAYOUT_NHWC;
}
else if (permData[0] == 0 && permData[1] == 3 && permData[2] == 2 && permData[3] == 1)
{
// in TensorFlow: NHWC->NCWH
// in OpenCV: NCHW->NCWH
int permData[] = {0, 1, 3, 2};
layerParams.set("order", DictValue::arrayInt<int*>(permData, perm.total()));
data_layouts[name] = DATA_LAYOUT_NCHW; // we keep track NCHW because channels position only matters
type = "Permute";
}
else
CV_Error(Error::StsParseError, "Only NHWC <-> NCHW permutations are allowed.");
}
Expand All @@ -1265,7 +1275,7 @@ void TFImporter::populateNet(Net dstNet)
else
CV_Error(Error::StsParseError, "Only NHWC <-> NCHW permutations are allowed.");
}
int id = dstNet.addLayer(name, "Identity", layerParams);
int id = dstNet.addLayer(name, type, layerParams);
layer_id[name] = id;
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
}
Expand Down
16 changes: 15 additions & 1 deletion modules/dnn/test/test_tf_importer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -956,11 +956,25 @@ TEST_P(Test_TensorFlow_layers, resize_bilinear)
runTensorFlowNet("resize_bilinear_factor");
}

TEST_P(Test_TensorFlow_layers, tf2_keras)
TEST_P(Test_TensorFlow_layers, tf2_dense)
{
runTensorFlowNet("tf2_dense");
}

TEST_P(Test_TensorFlow_layers, tf2_prelu)
{
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH)
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NGRAPH);
runTensorFlowNet("tf2_prelu");
}

TEST_P(Test_TensorFlow_layers, tf2_permute_nhwc_ncwh)
{
runTensorFlowNet("tf2_permute_nhwc_ncwh");
}

TEST_P(Test_TensorFlow_layers, squeeze)
{
#if defined(INF_ENGINE_RELEASE)
Expand Down

0 comments on commit 5da4bb7

Please sign in to comment.