From dcb54daca88b380751822ee4269c33f7d8c69260 Mon Sep 17 00:00:00 2001 From: Eric Astor Date: Mon, 24 Feb 2025 16:42:43 -0800 Subject: [PATCH] [cleanup] Depend directly on graph.h, not the deprecated ebert_graph.h PiperOrigin-RevId: 730643865 --- xls/passes/BUILD | 4 +-- xls/passes/dataflow_graph_analysis.cc | 45 +++++++++++++------------- xls/passes/dataflow_graph_analysis.h | 46 ++++++++++++--------------- 3 files changed, 44 insertions(+), 51 deletions(-) diff --git a/xls/passes/BUILD b/xls/passes/BUILD index b2b5490607..09b8605e24 100644 --- a/xls/passes/BUILD +++ b/xls/passes/BUILD @@ -1081,8 +1081,8 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_ortools//ortools/graph:ebert_graph", - "@com_google_ortools//ortools/graph:max_flow", + "@com_google_ortools//ortools/graph", + "@com_google_ortools//ortools/graph:generic_max_flow", ], ) diff --git a/xls/passes/dataflow_graph_analysis.cc b/xls/passes/dataflow_graph_analysis.cc index fb0de54820..e990f2af5f 100644 --- a/xls/passes/dataflow_graph_analysis.cc +++ b/xls/passes/dataflow_graph_analysis.cc @@ -37,13 +37,14 @@ #include "xls/ir/ternary.h" #include "xls/ir/topo_sort.h" #include "xls/passes/query_engine.h" -#include "ortools/graph/ebert_graph.h" -#include "ortools/graph/max_flow.h" +#include "ortools/graph/generic_max_flow.h" namespace xls { namespace { +using ::operations_research::GenericMaxFlow; + // Returns whether the given node *originates* data that is potentially unknown. bool IsDataOriginating(Node* node) { return node->OpIn({Op::kReceive, Op::kRegisterRead, Op::kParam, @@ -57,9 +58,7 @@ DataflowGraphAnalysis::DataflowGraphAnalysis(FunctionBase* f, : nodes_(TopoSort(f)) { CHECK_LT( nodes_.size(), - static_cast( - (std::numeric_limits::max() >> 1) - - 1)); + static_cast((std::numeric_limits::max() >> 1) - 1)); node_to_index_.reserve(nodes_.size()); for (size_t i = 0; i < nodes_.size(); ++i) { node_to_index_[nodes_[i]] = i; @@ -67,15 +66,15 @@ DataflowGraphAnalysis::DataflowGraphAnalysis(FunctionBase* f, graph_ = std::make_unique(); graph_->AddNode(kSourceIndex); - absl::flat_hash_map arc_capacities; - absl::flat_hash_map internal_arcs; - absl::flat_hash_map source_arcs; - absl::flat_hash_map sink_arcs; + absl::flat_hash_map arc_capacities; + absl::flat_hash_map internal_arcs; + absl::flat_hash_map source_arcs; + absl::flat_hash_map sink_arcs; for (size_t i = 0; i < nodes_.size(); ++i) { Node* node = nodes_[i]; - operations_research::NodeIndex v_in = InIndex(i); - operations_research::NodeIndex v_out = OutIndex(i); + NodeIndex v_in = InIndex(i); + NodeIndex v_out = OutIndex(i); graph_->AddNode(v_in); graph_->AddNode(v_out); @@ -110,9 +109,9 @@ DataflowGraphAnalysis::DataflowGraphAnalysis(FunctionBase* f, } } - std::vector arc_permutation; + std::vector arc_permutation; graph_->Build(&arc_permutation); - auto permuted = [&](operations_research::ArcIndex arc) { + auto permuted = [&](ArcIndex arc) { return arc < arc_permutation.size() ? arc_permutation[arc] : arc; }; for (const auto& [arc, capacity] : arc_capacities) { @@ -127,8 +126,8 @@ DataflowGraphAnalysis::DataflowGraphAnalysis(FunctionBase* f, for (const auto& [node, arc] : sink_arcs) { sink_arcs_[node] = permuted(arc); } - max_flow_ = std::make_unique>( - graph_.get(), kSourceIndex, kSinkIndex); + max_flow_ = std::make_unique>(graph_.get(), + kSourceIndex, kSinkIndex); for (auto& [arc, capacity] : arc_capacities_) { max_flow_->SetArcCapacity(arc, capacity); } @@ -148,15 +147,15 @@ absl::Status DataflowGraphAnalysis::SolveFor(Node* node) { return absl::OkStatus(); } switch (max_flow_->status()) { - case operations_research::GenericMaxFlow::NOT_SOLVED: + case GenericMaxFlow::NOT_SOLVED: return absl::InternalError("Max flow solver failed to solve"); - case operations_research::GenericMaxFlow::OPTIMAL: + case GenericMaxFlow::OPTIMAL: return absl::InternalError("Max flow solver reported an unknown failure"); - case operations_research::GenericMaxFlow::INT_OVERFLOW: + case GenericMaxFlow::INT_OVERFLOW: return absl::InternalError("Possible overflow in max flow solver"); - case operations_research::GenericMaxFlow::BAD_INPUT: + case GenericMaxFlow::BAD_INPUT: return absl::InternalError("Bad input to max flow solver"); - case operations_research::GenericMaxFlow::BAD_RESULT: + case GenericMaxFlow::BAD_RESULT: return absl::InternalError("Bad result from max flow solver"); } return absl::InternalError( @@ -179,11 +178,11 @@ absl::StatusOr> DataflowGraphAnalysis::GetMinCutFor( return std::vector({}); } - std::vector min_cut_indices; + std::vector min_cut_indices; max_flow_->GetSourceSideMinCut(&min_cut_indices); absl::flat_hash_set min_cut_nodes; - for (operations_research::NodeIndex index : min_cut_indices) { + for (NodeIndex index : min_cut_indices) { if (index == kSourceIndex) { for (const auto& [source, arc_index] : source_arcs_) { if (max_flow_->Flow(arc_index) > 0) { @@ -205,7 +204,7 @@ absl::StatusOr> DataflowGraphAnalysis::GetMinCutFor( // is not in the min cut. if (absl::c_any_of( max_flow_->graph()->OutgoingArcs(OutIndex(min_cut_node)), - [&](operations_research::ArcIndex out_arc) { + [&](ArcIndex out_arc) { Node* target = nodes_[TopoIndex(max_flow_->graph()->Head(out_arc))]; return !min_cut_nodes.contains(target) && diff --git a/xls/passes/dataflow_graph_analysis.h b/xls/passes/dataflow_graph_analysis.h index ec0d15203c..1a3d2291f9 100644 --- a/xls/passes/dataflow_graph_analysis.h +++ b/xls/passes/dataflow_graph_analysis.h @@ -26,8 +26,8 @@ #include "absl/status/statusor.h" #include "xls/ir/node.h" #include "xls/passes/query_engine.h" -#include "ortools/graph/ebert_graph.h" -#include "ortools/graph/max_flow.h" +#include "ortools/graph/graph.h" +#include "ortools/graph/generic_max_flow.h" namespace xls { @@ -49,46 +49,40 @@ class DataflowGraphAnalysis { absl::StatusOr GetUnknownBitsFor(Node* node); private: + using ArcIndex = int32_t; + using NodeIndex = int32_t; + Node* current_sink_ = nullptr; absl::Status SolveFor(Node* node); - static constexpr operations_research::NodeIndex kSourceIndex = 0; - static constexpr operations_research::NodeIndex kSinkIndex = 1; + static constexpr NodeIndex kSourceIndex = 0; + static constexpr NodeIndex kSinkIndex = 1; - static operations_research::NodeIndex InIndex(size_t topo_index) { - return static_cast(2 * topo_index + 2); - } - operations_research::NodeIndex InIndex(Node* node) { - return InIndex(node_to_index_[node]); + static NodeIndex InIndex(size_t topo_index) { + return static_cast(2 * topo_index + 2); } + NodeIndex InIndex(Node* node) { return InIndex(node_to_index_[node]); } - bool IsOutIndex(operations_research::NodeIndex index) { - return index > 1 && (index % 2) == 1; - } - static operations_research::NodeIndex OutIndex(size_t topo_index) { - return static_cast(2 * topo_index + 3); - } - operations_research::NodeIndex OutIndex(Node* node) { - return OutIndex(node_to_index_[node]); + bool IsOutIndex(NodeIndex index) { return index > 1 && (index % 2) == 1; } + static NodeIndex OutIndex(size_t topo_index) { + return static_cast(2 * topo_index + 3); } + NodeIndex OutIndex(Node* node) { return OutIndex(node_to_index_[node]); } - static size_t TopoIndex(operations_research::NodeIndex index) { - return (index - 2) >> 1; - } + static size_t TopoIndex(NodeIndex index) { return (index - 2) >> 1; } size_t TopoIndex(Node* node) { return node_to_index_[node]; } - using Graph = ::util::ReverseArcStaticGraph; + using Graph = ::util::ReverseArcStaticGraph; const std::vector nodes_; absl::flat_hash_map node_to_index_; std::unique_ptr graph_; - absl::flat_hash_map arc_capacities_; + absl::flat_hash_map arc_capacities_; - absl::flat_hash_map internal_arcs_; - absl::flat_hash_map source_arcs_; - absl::flat_hash_map sink_arcs_; + absl::flat_hash_map internal_arcs_; + absl::flat_hash_map source_arcs_; + absl::flat_hash_map sink_arcs_; std::unique_ptr> max_flow_;