diff --git a/src/compiler/precompute_prune.cc b/src/compiler/precompute_prune.cc index 850a961fc..c3656b640 100644 --- a/src/compiler/precompute_prune.cc +++ b/src/compiler/precompute_prune.cc @@ -27,6 +27,25 @@ nnvm::Graph PrecomputePrune(nnvm::Graph src) { // number of edges that are not variable int non_var_edge = 0; + auto replace_pruned_entry = [&] (const NodeEntry& e) { + if (!entry_var.count(e)) { + if (!e.node->is_variable()) { + ++non_var_edge; + } + nnvm::NodePtr var = nnvm::Node::Create(); + var->attrs.name = e.node->attrs.name; + if (e.node->num_outputs() != 1) { + var->attrs.name += "_output" + std::to_string(e.index); + } + entry_var.emplace(e, var); + CHECK(!unique_name.count(var->attrs.name)); + unique_name.insert(var->attrs.name); + return nnvm::NodeEntry{var, 0, 0}; + } else { + return nnvm::NodeEntry{entry_var.at(e), 0, 0}; + } + }; + DFSVisit(src.outputs, [&](const nnvm::NodePtr& n) { bool can_be_pruned = true; if (n->is_variable()) { @@ -47,20 +66,7 @@ nnvm::Graph PrecomputePrune(nnvm::Graph src) { // scan again to find edge nodes, skip variables for (auto& e : n->inputs) { if (pruned.count(e.node.get())) { - if (!entry_var.count(e)) { - if (!e.node->is_variable()) { - ++non_var_edge; - } - nnvm::NodePtr var = nnvm::Node::Create(); - var->attrs.name = e.node->attrs.name; - if (e.node->num_outputs() != 1) { - var->attrs.name += "_output" + std::to_string(e.index); - } - entry_var.emplace(e, var); - CHECK(!unique_name.count(var->attrs.name)); - unique_name.insert(var->attrs.name); - } - e = nnvm::NodeEntry{entry_var.at(e), 0, 0}; + e = replace_pruned_entry(e); } } } @@ -71,6 +77,12 @@ nnvm::Graph PrecomputePrune(nnvm::Graph src) { return src; } + for (auto& e : src.outputs) { + if (pruned.count(e.node.get())) { + e = replace_pruned_entry(e); + } + } + nnvm::Graph pre_graph; pre_graph.outputs.reserve(entry_var.size()); std::vector output_names;