Skip to content

Commit

Permalink
Implement condensation tentatively
Browse files Browse the repository at this point in the history
  • Loading branch information
kazuki0824 committed Dec 22, 2024
1 parent 25077aa commit 5045623
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 4 deletions.
1 change: 1 addition & 0 deletions rustworkx/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ from .rustworkx import number_connected_components as number_connected_component
from .rustworkx import number_weakly_connected_components as number_weakly_connected_components
from .rustworkx import node_connected_component as node_connected_component
from .rustworkx import strongly_connected_components as strongly_connected_components
from .rustworkx import condensation as condensation
from .rustworkx import weakly_connected_components as weakly_connected_components
from .rustworkx import digraph_adjacency_matrix as digraph_adjacency_matrix
from .rustworkx import graph_adjacency_matrix as graph_adjacency_matrix
Expand Down
1 change: 1 addition & 0 deletions rustworkx/rustworkx.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def number_connected_components(graph: PyGraph, /) -> int: ...
def number_weakly_connected_components(graph: PyDiGraph, /) -> bool: ...
def node_connected_component(graph: PyGraph, node: int, /) -> set[int]: ...
def strongly_connected_components(graph: PyDiGraph, /) -> list[list[int]]: ...
def condensation(graph: PyDiGraph, /, sccs=None) -> PyDiGraph: ...
def weakly_connected_components(graph: PyDiGraph, /) -> list[set[int]]: ...
def digraph_adjacency_matrix(
graph: PyDiGraph[_S, _T],
Expand Down
81 changes: 77 additions & 4 deletions src/connectivity/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,11 @@ use super::{
};

use hashbrown::{HashMap, HashSet};
use petgraph::algo;
use petgraph::algo::condensation;
use petgraph::graph::DiGraph;
use petgraph::graph::{DiGraph, IndexType};
use petgraph::stable_graph::NodeIndex;
use petgraph::unionfind::UnionFind;
use petgraph::visit::{EdgeRef, IntoEdgeReferences, NodeCount, NodeIndexable, Visitable};
use petgraph::{algo, Graph};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyDict;
Expand All @@ -35,6 +34,7 @@ use rayon::prelude::*;

use ndarray::prelude::*;
use numpy::IntoPyArray;
use petgraph::prelude::StableGraph;

use crate::iterators::{
AllPairsMultiplePathMapping, BiconnectedComponents, Chains, EdgeList, NodeIndices,
Expand Down Expand Up @@ -114,6 +114,79 @@ pub fn strongly_connected_components(graph: &digraph::PyDiGraph) -> Vec<Vec<usiz
.collect()
}

fn condensation_inner<N, E, Ty, Ix>(
py: &Python,
g: Graph<N, E, Ty, Ix>,
make_acyclic: bool,
sccs: Option<Vec<Vec<usize>>>,
) -> StableGraph<PyObject, PyObject, Ty, Ix>
where
Ty: EdgeType,
Ix: IndexType,
N: ToPyObject,
E: ToPyObject,
{
// Don't use into_iter to avoid extra allocations
let sccs = if let Some(sccs) = sccs {
sccs.iter()
.map(|row| row.iter().map(|x| NodeIndex::new(*x)).collect())
.collect()
} else {
algo::kosaraju_scc(&g)
};

let mut condensed: StableGraph<Vec<N>, E, Ty, Ix> =
StableGraph::with_capacity(sccs.len(), g.edge_count());

// Build a map from old indices to new ones.
let mut node_map = vec![NodeIndex::end(); g.node_count()];
for comp in sccs {
let new_nix = condensed.add_node(Vec::new());
for nix in comp {
node_map[nix.index()] = new_nix;
}
}

// Consume nodes and edges of the old graph and insert them into the new one.
let (nodes, edges) = g.into_nodes_edges();
for (nix, node) in nodes.into_iter().enumerate() {
condensed[node_map[nix]].push(node.weight);
}
for edge in edges {
let source = node_map[edge.source().index()];
let target = node_map[edge.target().index()];
if make_acyclic {
if source != target {
condensed.update_edge(source, target, edge.weight);
}
} else {
condensed.add_edge(source, target, edge.weight);
}
}
condensed.map(|_, w| w.to_object(*py), |_, w| w.to_object(*py))
}

#[pyfunction]
#[pyo3(text_signature = "(graph, /, sccs=None)", signature=(graph, sccs=None))]
pub fn condensation(
py: Python,
graph: &digraph::PyDiGraph,
sccs: Option<Vec<Vec<usize>>>,
) -> digraph::PyDiGraph {
let g = graph.graph.clone();

let condensed = condensation_inner(&py, g.into(), true, sccs);

digraph::PyDiGraph {
graph: condensed,
cycle_state: algo::DfsSpace::default(),
check_cycle: false,
node_removed: false,
multigraph: true,
attrs: py.None(),
}
}

/// Return the first cycle encountered during DFS of a given PyDiGraph,
/// empty list is returned if no cycle is found
///
Expand Down Expand Up @@ -295,7 +368,7 @@ pub fn is_semi_connected(graph: &digraph::PyDiGraph) -> PyResult<bool> {
temp_graph.add_edge(node_map[source.index()], node_map[target.index()], ());
}

let condensed = condensation(temp_graph, true);
let condensed = algo::condensation(temp_graph, true);
let n = condensed.node_count();
let weight_fn =
|_: petgraph::graph::EdgeReference<()>| Ok::<usize, std::convert::Infallible>(1usize);
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,7 @@ fn rustworkx(py: Python<'_>, m: &Bound<PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(cycle_basis))?;
m.add_wrapped(wrap_pyfunction!(simple_cycles))?;
m.add_wrapped(wrap_pyfunction!(strongly_connected_components))?;
m.add_wrapped(wrap_pyfunction!(condensation))?;
m.add_wrapped(wrap_pyfunction!(digraph_dfs_edges))?;
m.add_wrapped(wrap_pyfunction!(graph_dfs_edges))?;
m.add_wrapped(wrap_pyfunction!(digraph_find_cycle))?;
Expand Down
52 changes: 52 additions & 0 deletions tests/digraph/test_strongly_connected.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,55 @@ def test_number_strongly_connected_big(self):
node = G.add_node(i)
G.add_child(node, str(i), {})
self.assertEqual(len(rustworkx.strongly_connected_components(G)), 200000)


class TestCondensation(unittest.TestCase):
def setUp(self):
# グラフをセットアップ
self.graph = rustworkx.PyDiGraph()
self.node_a = self.graph.add_node("a")
self.node_b = self.graph.add_node("b")
self.node_c = self.graph.add_node("c")
self.node_d = self.graph.add_node("d")
self.node_e = self.graph.add_node("e")
self.node_f = self.graph.add_node("f")
self.node_g = self.graph.add_node("g")
self.node_h = self.graph.add_node("h")

# エッジを追加
self.graph.add_edge(self.node_a, self.node_b, "a->b")
self.graph.add_edge(self.node_b, self.node_c, "b->c")
self.graph.add_edge(self.node_c, self.node_d, "c->d")
self.graph.add_edge(self.node_d, self.node_a, "d->a") # サイクル: a -> b -> c -> d -> a

self.graph.add_edge(self.node_b, self.node_e, "b->e")

self.graph.add_edge(self.node_e, self.node_f, "e->f")
self.graph.add_edge(self.node_f, self.node_g, "f->g")
self.graph.add_edge(self.node_g, self.node_h, "g->h")
self.graph.add_edge(self.node_h, self.node_e, "h->e") # サイクル: e -> f -> g -> h -> e

def test_condensation(self):
# condensation関数を呼び出し
condensed_graph = rustworkx.condensation(self.graph)

# ノード数を確認(2つのサイクルが1つずつのノードに縮約される)
self.assertEqual(
len(condensed_graph.node_indices()), 2
) # [SCC(a, b, c, d), SCC(e, f, g, h)]

# エッジ数を確認
self.assertEqual(
len(condensed_graph.edge_indices()), 1
) # Edge: [SCC(a, b, c, d)] -> [SCC(e, f, g, h)]

# 縮約されたノードの内容を確認
nodes = list(condensed_graph.nodes())
scc1 = nodes[0]
scc2 = nodes[1]
self.assertTrue(set(scc1) == {"a", "b", "c", "d"} or set(scc2) == {"a", "b", "c", "d"})
self.assertTrue(set(scc1) == {"e", "f", "g", "h"} or set(scc2) == {"e", "f", "g", "h"})

# エッジの内容を確認
weight = condensed_graph.edges()[0]
self.assertIn("b->e", weight) # 縮約後のグラフにおいて、正しいエッジが残っていることを確認

0 comments on commit 5045623

Please sign in to comment.