Skip to content

Commit

Permalink
[cleanup] Depend directly on graph.h, not the deprecated ebert_graph.h
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 730643865
  • Loading branch information
ericastor authored and copybara-github committed Feb 25, 2025
1 parent f5bd14e commit dcb54da
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 51 deletions.
4 changes: 2 additions & 2 deletions xls/passes/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1081,8 +1081,8 @@ cc_library(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_ortools//ortools/graph:ebert_graph",
"@com_google_ortools//ortools/graph:max_flow",
"@com_google_ortools//ortools/graph",
"@com_google_ortools//ortools/graph:generic_max_flow",
],
)

Expand Down
45 changes: 22 additions & 23 deletions xls/passes/dataflow_graph_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,14 @@
#include "xls/ir/ternary.h"
#include "xls/ir/topo_sort.h"
#include "xls/passes/query_engine.h"
#include "ortools/graph/ebert_graph.h"
#include "ortools/graph/max_flow.h"
#include "ortools/graph/generic_max_flow.h"

namespace xls {

namespace {

using ::operations_research::GenericMaxFlow;

// Returns whether the given node *originates* data that is potentially unknown.
bool IsDataOriginating(Node* node) {
return node->OpIn({Op::kReceive, Op::kRegisterRead, Op::kParam,
Expand All @@ -57,25 +58,23 @@ DataflowGraphAnalysis::DataflowGraphAnalysis(FunctionBase* f,
: nodes_(TopoSort(f)) {
CHECK_LT(
nodes_.size(),
static_cast<size_t>(
(std::numeric_limits<operations_research::NodeIndex>::max() >> 1) -
1));
static_cast<size_t>((std::numeric_limits<NodeIndex>::max() >> 1) - 1));
node_to_index_.reserve(nodes_.size());
for (size_t i = 0; i < nodes_.size(); ++i) {
node_to_index_[nodes_[i]] = i;
}

graph_ = std::make_unique<Graph>();
graph_->AddNode(kSourceIndex);
absl::flat_hash_map<operations_research::ArcIndex, int64_t> arc_capacities;
absl::flat_hash_map<Node*, operations_research::ArcIndex> internal_arcs;
absl::flat_hash_map<Node*, operations_research::ArcIndex> source_arcs;
absl::flat_hash_map<Node*, operations_research::ArcIndex> sink_arcs;
absl::flat_hash_map<ArcIndex, int64_t> arc_capacities;
absl::flat_hash_map<Node*, ArcIndex> internal_arcs;
absl::flat_hash_map<Node*, ArcIndex> source_arcs;
absl::flat_hash_map<Node*, ArcIndex> sink_arcs;
for (size_t i = 0; i < nodes_.size(); ++i) {
Node* node = nodes_[i];

operations_research::NodeIndex v_in = InIndex(i);
operations_research::NodeIndex v_out = OutIndex(i);
NodeIndex v_in = InIndex(i);
NodeIndex v_out = OutIndex(i);
graph_->AddNode(v_in);
graph_->AddNode(v_out);

Expand Down Expand Up @@ -110,9 +109,9 @@ DataflowGraphAnalysis::DataflowGraphAnalysis(FunctionBase* f,
}
}

std::vector<operations_research::ArcIndex> arc_permutation;
std::vector<ArcIndex> arc_permutation;
graph_->Build(&arc_permutation);
auto permuted = [&](operations_research::ArcIndex arc) {
auto permuted = [&](ArcIndex arc) {
return arc < arc_permutation.size() ? arc_permutation[arc] : arc;
};
for (const auto& [arc, capacity] : arc_capacities) {
Expand All @@ -127,8 +126,8 @@ DataflowGraphAnalysis::DataflowGraphAnalysis(FunctionBase* f,
for (const auto& [node, arc] : sink_arcs) {
sink_arcs_[node] = permuted(arc);
}
max_flow_ = std::make_unique<operations_research::GenericMaxFlow<Graph>>(
graph_.get(), kSourceIndex, kSinkIndex);
max_flow_ = std::make_unique<GenericMaxFlow<Graph>>(graph_.get(),
kSourceIndex, kSinkIndex);
for (auto& [arc, capacity] : arc_capacities_) {
max_flow_->SetArcCapacity(arc, capacity);
}
Expand All @@ -148,15 +147,15 @@ absl::Status DataflowGraphAnalysis::SolveFor(Node* node) {
return absl::OkStatus();
}
switch (max_flow_->status()) {
case operations_research::GenericMaxFlow<Graph>::NOT_SOLVED:
case GenericMaxFlow<Graph>::NOT_SOLVED:
return absl::InternalError("Max flow solver failed to solve");
case operations_research::GenericMaxFlow<Graph>::OPTIMAL:
case GenericMaxFlow<Graph>::OPTIMAL:
return absl::InternalError("Max flow solver reported an unknown failure");
case operations_research::GenericMaxFlow<Graph>::INT_OVERFLOW:
case GenericMaxFlow<Graph>::INT_OVERFLOW:
return absl::InternalError("Possible overflow in max flow solver");
case operations_research::GenericMaxFlow<Graph>::BAD_INPUT:
case GenericMaxFlow<Graph>::BAD_INPUT:
return absl::InternalError("Bad input to max flow solver");
case operations_research::GenericMaxFlow<Graph>::BAD_RESULT:
case GenericMaxFlow<Graph>::BAD_RESULT:
return absl::InternalError("Bad result from max flow solver");
}
return absl::InternalError(
Expand All @@ -179,11 +178,11 @@ absl::StatusOr<std::vector<Node*>> DataflowGraphAnalysis::GetMinCutFor(
return std::vector<Node*>({});
}

std::vector<operations_research::NodeIndex> min_cut_indices;
std::vector<NodeIndex> min_cut_indices;
max_flow_->GetSourceSideMinCut(&min_cut_indices);

absl::flat_hash_set<Node*> min_cut_nodes;
for (operations_research::NodeIndex index : min_cut_indices) {
for (NodeIndex index : min_cut_indices) {
if (index == kSourceIndex) {
for (const auto& [source, arc_index] : source_arcs_) {
if (max_flow_->Flow(arc_index) > 0) {
Expand All @@ -205,7 +204,7 @@ absl::StatusOr<std::vector<Node*>> DataflowGraphAnalysis::GetMinCutFor(
// is not in the min cut.
if (absl::c_any_of(
max_flow_->graph()->OutgoingArcs(OutIndex(min_cut_node)),
[&](operations_research::ArcIndex out_arc) {
[&](ArcIndex out_arc) {
Node* target =
nodes_[TopoIndex(max_flow_->graph()->Head(out_arc))];
return !min_cut_nodes.contains(target) &&
Expand Down
46 changes: 20 additions & 26 deletions xls/passes/dataflow_graph_analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
#include "absl/status/statusor.h"
#include "xls/ir/node.h"
#include "xls/passes/query_engine.h"
#include "ortools/graph/ebert_graph.h"
#include "ortools/graph/max_flow.h"
#include "ortools/graph/graph.h"
#include "ortools/graph/generic_max_flow.h"

namespace xls {

Expand All @@ -49,46 +49,40 @@ class DataflowGraphAnalysis {
absl::StatusOr<int64_t> GetUnknownBitsFor(Node* node);

private:
using ArcIndex = int32_t;
using NodeIndex = int32_t;

Node* current_sink_ = nullptr;
absl::Status SolveFor(Node* node);

static constexpr operations_research::NodeIndex kSourceIndex = 0;
static constexpr operations_research::NodeIndex kSinkIndex = 1;
static constexpr NodeIndex kSourceIndex = 0;
static constexpr NodeIndex kSinkIndex = 1;

static operations_research::NodeIndex InIndex(size_t topo_index) {
return static_cast<operations_research::NodeIndex>(2 * topo_index + 2);
}
operations_research::NodeIndex InIndex(Node* node) {
return InIndex(node_to_index_[node]);
static NodeIndex InIndex(size_t topo_index) {
return static_cast<NodeIndex>(2 * topo_index + 2);
}
NodeIndex InIndex(Node* node) { return InIndex(node_to_index_[node]); }

bool IsOutIndex(operations_research::NodeIndex index) {
return index > 1 && (index % 2) == 1;
}
static operations_research::NodeIndex OutIndex(size_t topo_index) {
return static_cast<operations_research::NodeIndex>(2 * topo_index + 3);
}
operations_research::NodeIndex OutIndex(Node* node) {
return OutIndex(node_to_index_[node]);
bool IsOutIndex(NodeIndex index) { return index > 1 && (index % 2) == 1; }
static NodeIndex OutIndex(size_t topo_index) {
return static_cast<NodeIndex>(2 * topo_index + 3);
}
NodeIndex OutIndex(Node* node) { return OutIndex(node_to_index_[node]); }

static size_t TopoIndex(operations_research::NodeIndex index) {
return (index - 2) >> 1;
}
static size_t TopoIndex(NodeIndex index) { return (index - 2) >> 1; }
size_t TopoIndex(Node* node) { return node_to_index_[node]; }

using Graph = ::util::ReverseArcStaticGraph<operations_research::NodeIndex,
operations_research::ArcIndex>;
using Graph = ::util::ReverseArcStaticGraph<NodeIndex, ArcIndex>;

const std::vector<Node*> nodes_;
absl::flat_hash_map<Node*, size_t> node_to_index_;

std::unique_ptr<Graph> graph_;
absl::flat_hash_map<operations_research::ArcIndex, int64_t> arc_capacities_;
absl::flat_hash_map<ArcIndex, int64_t> arc_capacities_;

absl::flat_hash_map<Node*, operations_research::ArcIndex> internal_arcs_;
absl::flat_hash_map<Node*, operations_research::ArcIndex> source_arcs_;
absl::flat_hash_map<Node*, operations_research::ArcIndex> sink_arcs_;
absl::flat_hash_map<Node*, ArcIndex> internal_arcs_;
absl::flat_hash_map<Node*, ArcIndex> source_arcs_;
absl::flat_hash_map<Node*, ArcIndex> sink_arcs_;

std::unique_ptr<operations_research::GenericMaxFlow<Graph>> max_flow_;

Expand Down

0 comments on commit dcb54da

Please sign in to comment.