Skip to content

Commit 42369d1

Browse files
committed
Fix parallel processing implementation to avoid race conditions
1 parent 8419c53 commit 42369d1

File tree

1 file changed

+67
-21
lines changed

1 file changed

+67
-21
lines changed

pydatastructs/graphs/algorithms.py

Lines changed: 67 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
"""
55
from collections import deque
66
from concurrent.futures import ThreadPoolExecutor
7+
from multiprocessing import Manager
8+
import threading
79
from pydatastructs.utils.misc_util import (
810
_comp, raise_if_backend_is_not_python, Backend, AdjacencyListGraphNode)
911
from pydatastructs.miscellaneous_data_structures import (
@@ -1407,7 +1409,7 @@ def maximum_matching(graph: Graph, algorithm: str, **kwargs) -> set:
14071409
>>> graph.add_edge('v_2', 'v_3')
14081410
>>> graph.add_edge('v_4', 'v_1')
14091411
>>> maximum_matching(graph, 'hopcroft_karp', make_undirected=True)
1410-
>>> {('v_3', 'v_2'), ('v_1', 'v_4')}
1412+
{('v_1', 'v_4'), ('v_3', 'v_2')}
14111413
14121414
References
14131415
==========
@@ -1431,6 +1433,7 @@ def maximum_matching(graph: Graph, algorithm: str, **kwargs) -> set:
14311433
return getattr(algorithms, func)(graph)
14321434

14331435
def _maximum_matching_hopcroft_karp_parallel(graph: Graph, num_threads: int) -> set:
1436+
14341437
U = set()
14351438
V = set()
14361439
bipartiteness, coloring = bipartite_coloring(graph)
@@ -1444,20 +1447,22 @@ def _maximum_matching_hopcroft_karp_parallel(graph: Graph, num_threads: int) ->
14441447
else:
14451448
V.add(node)
14461449

1447-
1448-
pair_U = {u: None for u in U}
1449-
pair_V = {v: None for v in V}
1450-
dist = {}
1450+
manager = Manager()
1451+
pair_U = manager.dict({u: None for u in U})
1452+
pair_V = manager.dict({v: None for v in V})
1453+
lock = threading.RLock()
14511454

14521455
def bfs():
14531456
queue = Queue()
1457+
dist = {}
14541458
for u in U:
14551459
if pair_U[u] is None:
14561460
dist[u] = 0
14571461
queue.append(u)
14581462
else:
14591463
dist[u] = float('inf')
14601464
dist[None] = float('inf')
1465+
14611466
while queue:
14621467
u = queue.popleft()
14631468
if dist[u] < dist[None]:
@@ -1470,36 +1475,77 @@ def bfs():
14701475
elif dist.get(alt, float('inf')) == float('inf'):
14711476
dist[alt] = dist[u] + 1
14721477
queue.append(alt)
1473-
return dist.get(None, float('inf')) != float('inf')
14741478

1475-
def dfs(u):
1479+
return dist, dist.get(None, float('inf')) != float('inf')
1480+
1481+
def dfs_worker(u, dist, local_pair_U, local_pair_V, thread_results):
1482+
if dfs(u, dist, local_pair_U, local_pair_V) and u in local_pair_U and local_pair_U[u] is not None:
1483+
thread_results.append((u, local_pair_U[u]))
1484+
return True
1485+
return False
1486+
1487+
def dfs(u, dist, local_pair_U, local_pair_V):
14761488
if u is None:
14771489
return True
1490+
14781491
for v in graph.neighbors(u):
1479-
if v.name in pair_V:
1480-
alt = pair_V[v.name]
1492+
if v.name in local_pair_V:
1493+
alt = local_pair_V[v.name]
14811494
if alt is None:
1482-
pair_V[v.name] = u
1483-
pair_U[u] = v.name
1495+
local_pair_V[v.name] = u
1496+
local_pair_U[u] = v.name
14841497
return True
14851498
elif dist.get(alt, float('inf')) == dist.get(u, float('inf')) + 1:
1486-
if dfs(alt):
1487-
pair_V[v.name] = u
1488-
pair_U[u] = v.name
1499+
if dfs(alt, dist, local_pair_U, local_pair_V):
1500+
local_pair_V[v.name] = u
1501+
local_pair_U[u] = v.name
14891502
return True
1503+
14901504
dist[u] = float('inf')
14911505
return False
14921506

14931507
matching = set()
14941508

1495-
while bfs():
1496-
unmatched_nodes = [u for u in U if pair_U[u] is None]
1509+
while True:
1510+
dist, has_path = bfs()
1511+
if not has_path:
1512+
break
14971513

1498-
with ThreadPoolExecutor(max_workers=num_threads) as Executor:
1499-
results = Executor.map(dfs, unmatched_nodes)
1514+
unmatched = [u for u in U if pair_U[u] is None]
1515+
if not unmatched:
1516+
break
1517+
1518+
batch_size = max(1, len(unmatched) // num_threads)
1519+
batches = [unmatched[i:i+batch_size] for i in range(0, len(unmatched), batch_size)]
1520+
1521+
for batch in batches:
1522+
all_results = []
1523+
1524+
with ThreadPoolExecutor(max_workers=num_threads) as executor:
1525+
futures = []
1526+
for u in batch:
1527+
local_pair_U = dict(pair_U)
1528+
local_pair_V = dict(pair_V)
1529+
thread_results = []
15001530

1501-
for u, success in zip(unmatched_nodes, results):
1502-
if success and pair_U[u] is not None:
1531+
futures.append(executor.submit(
1532+
dfs_worker, u, dist.copy(), local_pair_U, local_pair_V, thread_results
1533+
))
1534+
1535+
for future in futures:
1536+
future.result()
1537+
1538+
with lock:
1539+
for u in batch:
1540+
if pair_U[u] is None:
1541+
result = dfs(u, dist.copy(), pair_U, pair_V)
1542+
if result and pair_U[u] is not None:
1543+
matching.add((u, pair_U[u]))
1544+
1545+
with lock:
1546+
matching = set()
1547+
for u in U:
1548+
if pair_U[u] is not None:
15031549
matching.add((u, pair_U[u]))
15041550

15051551
return matching
@@ -1548,7 +1594,7 @@ def maximum_matching_parallel(graph: Graph, algorithm: str, num_threads: int, **
15481594
>>> graph.add_bidirectional_edge('v_2', 'v_3')
15491595
>>> graph.add_bidirectional_edge('v_4', 'v_1')
15501596
>>> maximum_matching_parallel(graph, 'hopcroft_karp', 1, make_undirected=True)
1551-
>>> {('v_3', 'v_2'), ('v_1', 'v_4')}
1597+
{('v_1', 'v_4'), ('v_3', 'v_2')}
15521598
15531599
References
15541600
==========

0 commit comments

Comments
 (0)