From 6a95cd0663e5d19b0e5e70a38f913e39a69992d4 Mon Sep 17 00:00:00 2001 From: Tao Peng Date: Mon, 14 Nov 2022 08:05:55 +0000 Subject: [PATCH] [Runtime] Implement general multi-stream partition policy. We convert the graph to its minimum equivalent graph, then convert that to a bipartite graph. Next we find a maximum matching of the bipartite graph. Finally, we split the graph into subgraphs according to the maximum matching. --- tensorflow/core/BUILD | 1 + tensorflow/core/graph/stream_subgraph.cc | 284 ++++++++++++++++-- tensorflow/core/graph/stream_subgraph.h | 6 + tensorflow/core/graph/stream_subgraph_test.cc | 54 ++++ 4 files changed, 319 insertions(+), 26 deletions(-) create mode 100644 tensorflow/core/graph/stream_subgraph_test.cc diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 95616ebaf29..45af055ff8f 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -4443,6 +4443,7 @@ tf_cc_tests( "graph/optimizer_cse_test.cc", "graph/optimizer_fusion_engine_test.cc", "graph/star_server_graph_partition_test.cc", + "graph/stream_subgraph_test.cc", "graph/subgraph_test.cc", "graph/tensor_id_test.cc", "graph/validate_test.cc", diff --git a/tensorflow/core/graph/stream_subgraph.cc b/tensorflow/core/graph/stream_subgraph.cc index 56fe090470c..07559e2132e 100644 --- a/tensorflow/core/graph/stream_subgraph.cc +++ b/tensorflow/core/graph/stream_subgraph.cc @@ -22,6 +22,8 @@ limitations under the License. 
namespace tensorflow { namespace stream_subgraph { +using DAG = std::vector>; +using Bigraph = std::vector>; namespace { @@ -43,52 +45,212 @@ std::string GetDeviceNamePrefix(const std::string& device_name) { return device_name_prefix; } -} // namesapce +DAG GraphToDAG(const Graph* g) { + DAG dag; + dag.resize(g->num_node_ids()); + for (auto node : g->nodes()) { + for (auto edge : node->out_edges()) { + int dst_id = edge->dst()->id(); + dag[node->id()].push_back(dst_id); + } + } -void MarkStreamSubGraph(Graph* g, const MultiStreamOptions& opt) { - int num_streams = opt.multi_stream_num(); - MultiStreamPartitionPolicy policy = opt.partition_policy(); + return dag; +} - if (policy == MultiStreamPartitionPolicy::EMBEDDING_GRAPH_PARTITION) { - MarkEmbeddingGraph(g, num_streams); +void DFS(int curr, const DAG& graph, + std::vector& visited) { + visited[curr] = true; + const std::vector& adjacent_nodes = graph[curr]; + for (auto n : adjacent_nodes) { + if (!visited[n]) { + DFS(n, graph, visited); + } } } -void MarkEmbeddingGraph(Graph* g, int num_streams) { - bool train_graph = false; +// TODO: Optimize the algorithm (currently one full DFS per node) +std::vector> GetReachableNodes(const DAG& dag) { + std::vector> reachable_nodes; + int num_nodes = dag.size(); + for (int i = 0; i < num_nodes; i++) { + std::vector reachable(num_nodes, false); + DFS(i, dag, reachable); + reachable[i] = false; + reachable_nodes.push_back(std::move(reachable)); + } + + return reachable_nodes; +} + +// Get the minimum equivalent graph: drop edge i->x when x is reachable from another remaining child of i +DAG GetMEG(const DAG& dag) { + const auto& reachable_nodes = GetReachableNodes(dag); + int num_nodes = dag.size(); + DAG meg = dag; + for (int i = 0; i < num_nodes; i++) { + auto& meg_child_nodes = meg[i]; + auto& child_nodes = dag[i]; + for (auto child : child_nodes) { + if (std::find(meg_child_nodes.begin(), + meg_child_nodes.end(), child) == + meg_child_nodes.end()) { + continue; + } + for (auto another : child_nodes) { + if (reachable_nodes[child][another]) { + auto it = 
std::find(meg_child_nodes.begin(), + meg_child_nodes.end(), another); + if (it != meg_child_nodes.end()) { + meg_child_nodes.erase(it); + } + } + } + } + } + + return meg; +} + +Bigraph MEGToBigraph(const DAG& meg) { + Bigraph bigraph; + int num_nodes = meg.size(); + for (int i = 0; i < num_nodes; i++) { + std::vector adjacency(num_nodes, false); + for (auto child : meg[i]) { + adjacency[child] = true; + } + bigraph.push_back(std::move(adjacency)); + } + + return bigraph; +} + +Bigraph DAGToBigraph(const DAG& dag) { + Bigraph bigraph; + int num_nodes = dag.size(); + for (int i = 0; i < num_nodes; i++) { + std::vector reachable(num_nodes, false); + DFS(i, dag, reachable); + reachable[i] = false; + bigraph.push_back(std::move(reachable)); + } + + return bigraph; +} + +DAG BuildStreamDAG( + const DAG& dag, + const std::vector>& stream_chains) { + const auto& reachable_nodes = GetReachableNodes(dag); + DAG stream_dag; + for (int i = 0; i < stream_chains.size(); i++) { + std::vector ensuing_streams; + auto chain_end = stream_chains[i][1]; + for (int j = 0; j < stream_chains.size(); j++) { + auto chain_begin = stream_chains[j][0]; + if (reachable_nodes[chain_end][chain_begin]) { + ensuing_streams.push_back(j); + } + } + stream_dag.push_back(ensuing_streams); + } + + return stream_dag; +} + +bool FindMatching(int start, const Bigraph& graph, + std::vector& visited, + std::vector& match_status) { + int num = graph[0].size(); + for (int i = 0; i < num; i++) { + if (graph[start][i] && !visited[i]) { + visited[i] = true; + int curr_match = match_status[i]; + if (match_status[i] == -1 || + FindMatching(curr_match, graph, visited, match_status)) { + match_status[i] = start; + return true; + } + } + } + + return false; +} + +std::vector MaximumMatching(const Bigraph& graph) { + int num = graph[0].size(); + std::vector match_result(num, -1); + int num_bigraph = graph.size(); + for (int i = 0; i < num_bigraph; i++) { + std::vector visited(num, false); + FindMatching(i, graph, 
visited, match_result); + } + + return match_result; +} + +std::tuple, std::vector>, int> +GetMapping(const std::vector& matching) { + int num_nodes = matching.size(); + std::vector> chains; + for(int i = 0; i < num_nodes; i++) { + auto it = std::find(matching.begin(), matching.end(), i); + if (it == matching.end()) { + chains.push_back({i, i}); + } + } + + int group_num = 0; + std::vector mapping(num_nodes, -1); + for (auto& chain : chains) { + int group_id = group_num++; + int curr = chain[1]; + while (true) { + mapping[curr] = group_id; + if (matching[curr] == -1) { + chain[0] = curr; + break; + } else { + curr = matching[curr]; + } + } + } + + return std::make_tuple(mapping, chains, group_num); +} + +} // namespace + +void MarkStreamSubGraph(Graph* g, const MultiStreamOptions& opt) { // trained graph if (!g->IsTrainingGraph()) { return; } - //for (Node* n : g->nodes()) { // if (n->type_string() == "IsVariableInitialized" && // n->name() != "global_step/IsVariableInitialized") { - // return; + // return; // } //} + int num_streams = opt.multi_stream_num(); + MultiStreamPartitionPolicy policy = opt.partition_policy(); + if (policy == MultiStreamPartitionPolicy::EMBEDDING_GRAPH_PARTITION) { + MarkEmbeddingGraph(g, num_streams); + } else if (policy == MultiStreamPartitionPolicy::FULL_GRAPH_PARTITION) { + MarkFullGraph(g, num_streams); + } else { + // Unrecognized policy + return; + } + std::unordered_map name_to_node; - // User marked subgraph for (Node* n : g->nodes()) { name_to_node[n->name()] = n; - - if (n->assigned_device_name().find("device:GPU:") == std::string::npos || - n->def().attr().find("_stream_id") == n->def().attr().end()) { - continue; - } - - int stream_id = n->def().attr().at("_stream_id").i(); - std::string required_device = - GetDeviceNamePrefix(n->assigned_device_name()) + - std::to_string(stream_id); - if (n->assigned_device_name() != required_device) { - n->set_assigned_device_name(required_device); - } } - // Colocate nodes 
std::unordered_map> node_colocate_childs; std::unordered_set colocate_nodes; @@ -130,7 +292,6 @@ void MarkEmbeddingGraph(Graph* g, int num_streams) { continue; } - //std::vector edges_to_delete; std::vector in_edges(n->in_edges().begin(), n->in_edges().end()); for (const Edge* e : in_edges) { @@ -161,5 +322,76 @@ void MarkEmbeddingGraph(Graph* g, int num_streams) { } } +// Return the stream id vector, which is indexed by node id +std::vector GenerateNodeStreamId(const Graph* graph) { + // Partition the nodes into chains using a maximum matching. + const auto& dag = GraphToDAG(graph); + const auto& meg = GetMEG(dag); + const auto& bigraph = MEGToBigraph(meg); + const auto& matching = MaximumMatching(bigraph); + const auto& result = GetMapping(matching); + std::vector node_to_chain = std::get<0>(result); + + // Rematch the chains: some chains can share the same stream id. + const auto& stream_chains = std::get<1>(result); + const auto& stream_dag = BuildStreamDAG(meg, stream_chains); + const auto& stream_bigraph = DAGToBigraph(stream_dag); + const auto& rematching = MaximumMatching(stream_bigraph); + const auto& remapping = GetMapping(rematching); + std::vector chain_to_stream = std::get<0>(remapping); + + std::vector stream_ids(node_to_chain.size(), -1); + for (int node_id = 0; node_id < node_to_chain.size(); ++node_id) { + stream_ids[node_id] = chain_to_stream[node_to_chain[node_id]]; + } + + return stream_ids; +} + +void MarkFullGraph(Graph* g, int num_streams) { + std::vector node_stream_ids = GenerateNodeStreamId(g); + + std::unordered_map name_to_node; + for (Node* n : g->nodes()) { + name_to_node[n->name()] = n; + + if (n->assigned_device_name().find("device:GPU:") == + std::string::npos) { + continue; + } + + int stream_id = node_stream_ids[n->id()] % num_streams; + n->AddAttr("_stream_id", stream_id); + + std::string required_device = + GetDeviceNamePrefix(n->assigned_device_name()) + + std::to_string(stream_id); + if (n->assigned_device_name() != required_device) { + 
n->set_assigned_device_name(required_device); + } + } +} + +void MarkEmbeddingGraph(Graph* g, int num_streams) { + std::unordered_map name_to_node; + // User marked subgraph + for (Node* n : g->nodes()) { + name_to_node[n->name()] = n; + + if (n->assigned_device_name().find("device:GPU:") == std::string::npos || + n->def().attr().find("_stream_id") == n->def().attr().end()) { + continue; + } + + int stream_id = n->def().attr().at("_stream_id").i(); + std::string required_device = + GetDeviceNamePrefix(n->assigned_device_name()) + + std::to_string(stream_id); + if (n->assigned_device_name() != required_device) { + n->set_assigned_device_name(required_device); + } + } +} + } // namespace stream_subgraph } // namespace tensorflow diff --git a/tensorflow/core/graph/stream_subgraph.h b/tensorflow/core/graph/stream_subgraph.h index 28a3415084b..ac568a121a8 100644 --- a/tensorflow/core/graph/stream_subgraph.h +++ b/tensorflow/core/graph/stream_subgraph.h @@ -34,6 +34,12 @@ void MarkStreamSubGraph(Graph* g, const MultiStreamOptions& opt); // Assign embedding graphs stream. void MarkEmbeddingGraph(Graph* g, int num_streams); +// Automatically split the full graph into subgraphs, +// and assign a stream to each subgraph. 
+void MarkFullGraph(Graph* g, int num_streams); +// Return the stream id vector, which is indexed by node id +std::vector GenerateNodeStreamId(const Graph* graph); + } // namespace stream_subgraph } // namespace tensorflow diff --git a/tensorflow/core/graph/stream_subgraph_test.cc b/tensorflow/core/graph/stream_subgraph_test.cc new file mode 100644 index 00000000000..8159fb1e02b --- /dev/null +++ b/tensorflow/core/graph/stream_subgraph_test.cc @@ -0,0 +1,54 @@ +#include "tensorflow/core/graph/stream_subgraph.h" + +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +TEST(GenerateNodeStreamId, TestGraph) { + Graph graph(OpRegistry::Global()); + std::vector nodes; + nodes.push_back(graph.source_node()); + nodes.push_back(graph.sink_node()); + for (auto edge : graph.source_node()->out_edges()) { + graph.RemoveEdge(edge); + } + for (int i = 0; i < 5; ++i) { + Node* node; + TF_CHECK_OK(NodeBuilder(strings::StrCat("v", i+1), "NoOp").Finalize(&graph, &node)); + nodes.push_back(node); + } + + graph.AddEdge(nodes[0], 0, nodes[1], 0); + graph.AddEdge(nodes[0], 1, nodes[2], 0); + graph.AddEdge(nodes[0], 2, nodes[3], 0); + graph.AddEdge(nodes[0], 3, nodes[5], 0); + + graph.AddEdge(nodes[1], 0, nodes[4], 0); + + graph.AddEdge(nodes[2], 0, nodes[4], 1); + graph.AddEdge(nodes[2], 1, nodes[6], 0); + graph.AddEdge(nodes[2], 2, nodes[5], 1); + + graph.AddEdge(nodes[3], 0, nodes[5], 2); + + graph.AddEdge(nodes[4], 0, nodes[6], 1); + + auto mapping = stream_subgraph::GenerateNodeStreamId(&graph); + + EXPECT_EQ(mapping.size(), 7); + EXPECT_EQ(mapping[0], mapping[1]); + EXPECT_EQ(mapping[1], mapping[4]); + EXPECT_EQ(mapping[4], mapping[6]); + EXPECT_EQ(mapping[2], mapping[5]); + for (int i = 0; i < mapping.size(); ++i) { + VLOG(2) << i+1 << ": " << mapping[i]; + } +} + +} // 
namespace +} // namespace tensorflow