forked from mfms-ncsu/Matroid-Parity
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbase_graph.py
executable file
·108 lines (83 loc) · 3.88 KB
/
base_graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import networkx as nx
import union_find as uf
ELEMENT_ID_KEY = 'ELEMENT_ID_KEY'
PAIR_ID_KEY = 'PAIR_ID_KEY'
BaseEdge = tuple[int, int, dict[str, int]]
class BaseGraph:
"""
This class is used as an interface with `networkx`,
in the perspective of providing a library-less graph implementation.
You should not use self.nx_instance if possible.
"""
def __init__(self):
self.nx_instance: nx.MultiGraph = None
self.elements: dict[int, BaseEdge] = {}
self.max_element_id = 0
def init_instance(self, nx_instance=None):
if not nx_instance:
self.nx_instance = nx.MultiGraph()
else:
self.nx_instance = nx_instance
def add_edge(self, u: int, v: int, edge_id: int):
self.nx_instance.add_edge(u, v, ELEMENT_ID_KEY=edge_id, PAIR_ID_KEY=(edge_id)//2)
self.elements[edge_id] = (u, v, {ELEMENT_ID_KEY: edge_id, PAIR_ID_KEY: (edge_id)//2})
self.max_element_id = max(edge_id, self.max_element_id)
def edges(self) -> list[BaseEdge]:
return list(self.nx_instance.edges(data=True))
def nodes(self) -> list[int]:
return list(self.nx_instance.nodes())
def adjacent_edges(self, edge: BaseEdge) -> list[BaseEdge]:
return list(self.nx_instance.edges([edge[0], edge[1]], data=True))
def get_spanning_forest(self, matching: list[int] = None) -> None | tuple[list[BaseEdge], list[BaseEdge], dict[int, tuple[int, int]]]:
"""
Get a spanning maximum forest of the graph, used for completing a matching into a basis.
Returns the forest, the elements not in the forest as well as a rooted representation
"""
uf_set = list(range(max(list(self.nx_instance.nodes()))+1))
if not matching:
matching = []
forest = []
non_forest = []
# Establish the base forest, with no singletons
for edge in self.edges():
if not edge[2][ELEMENT_ID_KEY] in matching:
non_forest.append(edge)
else:
if uf.uf_find(uf_set, edge[0]) == uf.uf_find(uf_set, edge[1]):
return None # matching has cycle
uf.uf_union(uf_set, edge[0], edge[1])
forest.append(edge)
next_singleton_id = self.max_element_id + 1
# Add the singletons necessary to have a spanning forest
for edge in self.edges():
if uf.uf_find(uf_set, edge[0]) == uf.uf_find(uf_set, edge[1]):
continue
uf.uf_union(uf_set, edge[0], edge[1])
singleton = (edge[0], edge[1], {ELEMENT_ID_KEY: next_singleton_id, PAIR_ID_KEY: None})
next_singleton_id = next_singleton_id + 1
forest.append(singleton)
# Compute a rooted forest representation by greedily doing Depth-First explorations
remaining = self.nodes()
edges_remaining = forest[:]
parent = {}
stack = []
while len(remaining) > 0: # pick a root
root = remaining.pop()
parent[root] = (root, None)
stack.append(root)
while len(stack) > 0: # do a DFS from that root
current = stack.pop()
edges_to_remove = []
for e in edges_remaining:
if not (e[0] == current or e[1] == current):
continue
endpoint = e[0] if e[1] == current else e[1]
if endpoint in parent.keys():
continue
parent[endpoint] = (current, e[2][ELEMENT_ID_KEY])
stack.append(endpoint)
remaining.remove(endpoint)
edges_to_remove.append(e)
for e in edges_to_remove:
edges_remaining.remove(e)
return forest, non_forest, parent