-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmst.py
134 lines (101 loc) · 2.85 KB
/
mst.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from heapq import heappush, heappop
from sys import maxsize
"""Disjoint set union for kruskal algorithm"""
class Dsu:
    """Disjoint set union (union-find) used by Kruskal's algorithm.

    Tracks a partition of the integer labels ``0..V`` inclusive (``V + 1``
    slots), so node labels from 0 up to ``V`` are accepted.
    """

    def __init__(self, V):
        # parent[x] == x marks x as the representative (root) of its set.
        self.parent = list(range(0, V + 1))
        # Number of elements in each tree; only meaningful at roots.
        # Used for union by size so trees stay shallow.
        self._size = [1] * (V + 1)

    def find(self, u):
        """Return the representative (root) of the set containing ``u``.

        Iterative with full path compression: deep parent chains are walked
        in a loop instead of by recursion, so large inputs cannot hit
        Python's recursion limit.
        """
        parent = self.parent
        root = u
        while parent[root] != root:
            root = parent[root]
        # Second pass: re-point every node on the walked path at the root.
        while parent[u] != root:
            parent[u], u = root, parent[u]
        return root

    def is_same(self, u, v):
        """Return True if ``u`` and ``v`` currently belong to the same set."""
        return self.find(u) == self.find(v)

    def join(self, u, v):
        """Merge the sets containing ``u`` and ``v`` (no-op if already merged)."""
        u = self.find(u)
        v = self.find(v)
        if u == v:
            return
        # Union by size: attach the smaller tree beneath the larger one.
        if self._size[u] > self._size[v]:
            u, v = v, u
        self.parent[u] = v
        self._size[v] += self._size[u]
def transformEdge(edges):
    """Convert positional ``[u, v, w]`` edges into edge dicts.

    Element 0 of each edge becomes ``'first_node'``, element 1
    ``'second_node'`` and element 2 ``'weight'``; any further elements are
    ignored. Returns a new list with one dict per input edge, in input order.
    """
    return [
        {
            'first_node': item[0],
            'second_node': item[1],
            'weight': item[2],
        }
        for item in edges
    ]
def kruskal_mst(V, edges):
    """Build a minimum spanning forest with Kruskal's algorithm.

    Args:
        V: highest node label; the ``Dsu(V)`` used here covers labels ``0..V``.
        edges: iterable of edge dicts with ``'first_node'``, ``'second_node'``
            and ``'weight'`` keys (as produced by ``transformEdge``).
            The input is not modified.

    Returns:
        ``(mst, cost)``: the accepted edge dicts in the order they were
        chosen (non-decreasing weight) and the sum of their weights.
    """
    components = Dsu(V)
    chosen = []
    total = 0
    # sorted() is stable, so equal-weight edges keep their input order.
    for candidate in sorted(edges, key=lambda item: item['weight']):
        a = candidate['first_node']
        b = candidate['second_node']
        # Endpoints already connected: this edge would close a cycle.
        if components.is_same(a, b):
            continue
        components.join(a, b)
        chosen.append(candidate)
        total += candidate['weight']
    return chosen, total
class AdjNode:
    """Adjacency-list entry and priority-queue item.

    Holds a neighbour label ``node`` and the connecting edge's ``weight``.
    Instances order by ``weight`` alone, which is what lets them be pushed
    onto a ``heapq`` heap.
    """

    def __init__(self, node, weight):
        self.node, self.weight = node, weight

    def __lt__(self, other):
        # Only "<" is defined; ordering ignores the node label entirely.
        return self.weight < other.weight
def prim_mst(edges, V, root=1):
    """Compute a minimum spanning tree with Prim's algorithm (lazy binary heap).

    Args:
        edges: iterable of edge dicts with ``'first_node'``, ``'second_node'``
            and ``'weight'`` keys (as produced by ``transformEdge``); each edge
            is treated as undirected.
        V: highest node label; labels ``0..V`` are valid.
        root: node the tree is grown from. Defaults to ``1``, the behaviour
            before this parameter existed.

    Returns:
        ``(mst, cost)`` in the same shape as ``kruskal_mst``: the tree edges
        as dicts, in the order they were added (``'first_node'`` is the
        endpoint already in the tree), and their total weight. Only the
        connected component containing ``root`` is spanned.

    The total cost is also printed, as before.
    """
    # Undirected adjacency list: every edge is recorded at both endpoints.
    graph = [[] for _ in range(V + 1)]
    for edge in edges:
        u, v, w = edge['first_node'], edge['second_node'], edge['weight']
        graph[u].append(AdjNode(v, w))
        graph[v].append(AdjNode(u, w))

    parent = [-1] * (V + 1)        # tree neighbour that gave each node its best edge
    min_dist = [maxsize] * (V + 1)  # lightest known edge weight into each node
    in_mst = set()
    mst = []
    mst_sum = 0

    min_dist[root] = 0  # the start node costs nothing to add
    queue = []
    heappush(queue, AdjNode(root, 0))
    while queue:
        entry = heappop(queue)
        node, weight = entry.node, entry.weight
        # Lazy deletion: heavier, stale entries for an already-added node are skipped.
        if node in in_mst:
            continue
        in_mst.add(node)
        mst_sum += weight
        if node != root:
            # parent[node] only changes on a strict decrease of min_dist[node],
            # i.e. together with the lightest entry pushed for node -- the one
            # just popped -- so it is the other endpoint of this tree edge.
            mst.append({
                'first_node': parent[node],
                'second_node': node,
                'weight': weight,
            })
        for adj_node in graph[node]:
            adj, w = adj_node.node, adj_node.weight
            # Push only when this edge strictly improves adj's best connection.
            if adj not in in_mst and w < min_dist[adj]:
                min_dist[adj] = w
                parent[adj] = node
                heappush(queue, AdjNode(adj, w))
    print(mst_sum)
    return mst, mst_sum
if __name__ == "__main__":
    # Demo: a 7-node undirected graph given as [u, v, weight] triples.
    node_count = 7
    raw_edges = [
        [1, 2, 1],
        [1, 3, 1],
        [1, 5, 2],
        [2, 6, 1],
        [2, 4, 2],
        [2, 3, 2],
        [3, 4, 1],
        [4, 5, 1],
        [5, 6, 2],
        [5, 7, 1],
        [6, 7, 1],
    ]
    edge_dicts = transformEdge(raw_edges)

    # Kruskal: print each chosen edge, then the total cost.
    tree, total = kruskal_mst(node_count, edge_dicts)
    for tree_edge in tree:
        print(tree_edge)
    print(total)

    # Prim prints its own total cost.
    prim_mst(edge_dicts, node_count)