@@ -4,6 +4,8 @@
 """
 from collections import deque
 from concurrent.futures import ThreadPoolExecutor
+from multiprocessing import Manager
+import threading
 from pydatastructs.utils.misc_util import (
     _comp, raise_if_backend_is_not_python, Backend, AdjacencyListGraphNode)
 from pydatastructs.miscellaneous_data_structures import (
@@ -1407,7 +1409,7 @@ def maximum_matching(graph: Graph, algorithm: str, **kwargs) -> set:
     >>> graph.add_edge('v_2', 'v_3')
     >>> graph.add_edge('v_4', 'v_1')
     >>> maximum_matching(graph, 'hopcroft_karp', make_undirected=True)
-    >>> {('v_3', 'v_2'), ('v_1', 'v_4')}
+    >>> {('v_1', 'v_4'), ('v_3', 'v_2')}

     References
     ==========
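The hunk above shows only the tail of the docstring example; the vertex and edge construction lines sit above it and are unchanged by this diff. A minimal end-to-end sketch of the same call, assuming the usual pydatastructs `Graph` factory and `AdjacencyListGraphNode` re-exports (only the two edges visible in the hunk are reproduced here):

```python
# Hedged reconstruction of the docstring example; any setup beyond the two
# edges visible in the hunk is an assumption, not taken from the diff.
from pydatastructs import Graph, AdjacencyListGraphNode
from pydatastructs.graphs.algorithms import maximum_matching

vertices = [AdjacencyListGraphNode(f'v_{i}') for i in range(1, 5)]
graph = Graph(*vertices)
graph.add_edge('v_2', 'v_3')
graph.add_edge('v_4', 'v_1')

matching = maximum_matching(graph, 'hopcroft_karp', make_undirected=True)
print(matching)   # e.g. {('v_1', 'v_4'), ('v_3', 'v_2')}
```

The result is a plain Python set, so the printed order can differ from the literal shown in the docstring.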
@@ -1431,6 +1433,7 @@ def maximum_matching(graph: Graph, algorithm: str, **kwargs) -> set:
     return getattr(algorithms, func)(graph)

 def _maximum_matching_hopcroft_karp_parallel(graph: Graph, num_threads: int) -> set:
+
     U = set()
     V = set()
     bipartiteness, coloring = bipartite_coloring(graph)
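This hunk ends with the bipartite check; the following hunk splits the vertices into `U` and `V` by colour. Purely as an illustration of the `(is_bipartite, colouring)` shape that `bipartite_coloring(graph)` is unpacked into here (this is not the pydatastructs implementation), a BFS 2-colouring over a plain adjacency dict looks like:

```python
from collections import deque

def two_colour(adj):
    # Illustrative BFS 2-colouring: returns (is_bipartite, colour) where
    # colour maps each vertex to 0 or 1, mirroring how the diff splits
    # vertices into U (one colour) and V (the other).
    colour = {}
    for source in adj:
        if source in colour:
            continue
        colour[source] = 0
        queue = deque([source])
        while queue:
            node = queue.popleft()
            for nbr in adj[node]:
                if nbr not in colour:
                    colour[nbr] = 1 - colour[node]
                    queue.append(nbr)
                elif colour[nbr] == colour[node]:
                    return False, colour
    return True, colour

ok, colour = two_colour({'v_1': ['v_2', 'v_4'], 'v_2': ['v_1', 'v_3'],
                         'v_3': ['v_2'], 'v_4': ['v_1']})
U = {n for n, c in colour.items() if c == 0}   # {'v_1', 'v_3'}
V = {n for n, c in colour.items() if c == 1}   # {'v_2', 'v_4'}
```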
@@ -1444,20 +1447,22 @@ def _maximum_matching_hopcroft_karp_parallel(graph: Graph, num_threads: int) -> set:
         else:
             V.add(node)

-
-    pair_U = {u: None for u in U}
-    pair_V = {v: None for v in V}
-    dist = {}
+    manager = Manager()
+    pair_U = manager.dict({u: None for u in U})
+    pair_V = manager.dict({v: None for v in V})
+    lock = threading.RLock()

     def bfs():
         queue = Queue()
+        dist = {}
         for u in U:
             if pair_U[u] is None:
                 dist[u] = 0
                 queue.append(u)
             else:
                 dist[u] = float('inf')
         dist[None] = float('inf')
+
         while queue:
             u = queue.popleft()
             if dist[u] < dist[None]:
@@ -1470,36 +1475,77 @@ def bfs():
                         elif dist.get(alt, float('inf')) == float('inf'):
                             dist[alt] = dist[u] + 1
                             queue.append(alt)
-        return dist.get(None, float('inf')) != float('inf')

-    def dfs(u):
+        return dist, dist.get(None, float('inf')) != float('inf')
+
+    def dfs_worker(u, dist, local_pair_U, local_pair_V, thread_results):
+        if dfs(u, dist, local_pair_U, local_pair_V) and u in local_pair_U and local_pair_U[u] is not None:
+            thread_results.append((u, local_pair_U[u]))
+            return True
+        return False
+
+    def dfs(u, dist, local_pair_U, local_pair_V):
         if u is None:
             return True
+
         for v in graph.neighbors(u):
-            if v.name in pair_V:
-                alt = pair_V[v.name]
+            if v.name in local_pair_V:
+                alt = local_pair_V[v.name]
                 if alt is None:
-                    pair_V[v.name] = u
-                    pair_U[u] = v.name
+                    local_pair_V[v.name] = u
+                    local_pair_U[u] = v.name
                     return True
                 elif dist.get(alt, float('inf')) == dist.get(u, float('inf')) + 1:
-                    if dfs(alt):
-                        pair_V[v.name] = u
-                        pair_U[u] = v.name
+                    if dfs(alt, dist, local_pair_U, local_pair_V):
+                        local_pair_V[v.name] = u
+                        local_pair_U[u] = v.name
                         return True
+
         dist[u] = float('inf')
         return False

     matching = set()

-    while bfs():
-        unmatched_nodes = [u for u in U if pair_U[u] is None]
+    while True:
+        dist, has_path = bfs()
+        if not has_path:
+            break

-        with ThreadPoolExecutor(max_workers=num_threads) as Executor:
-            results = Executor.map(dfs, unmatched_nodes)
+        unmatched = [u for u in U if pair_U[u] is None]
+        if not unmatched:
+            break
+
+        batch_size = max(1, len(unmatched) // num_threads)
+        batches = [unmatched[i:i + batch_size] for i in range(0, len(unmatched), batch_size)]
+
+        for batch in batches:
+            all_results = []
+
+            with ThreadPoolExecutor(max_workers=num_threads) as executor:
+                futures = []
+                for u in batch:
+                    local_pair_U = dict(pair_U)
+                    local_pair_V = dict(pair_V)
+                    thread_results = []

-        for u, success in zip(unmatched_nodes, results):
-            if success and pair_U[u] is not None:
+                    futures.append(executor.submit(
+                        dfs_worker, u, dist.copy(), local_pair_U, local_pair_V, thread_results
+                    ))
+
+                for future in futures:
+                    future.result()
+
+            with lock:
+                for u in batch:
+                    if pair_U[u] is None:
+                        result = dfs(u, dist.copy(), pair_U, pair_V)
+                        if result and pair_U[u] is not None:
+                            matching.add((u, pair_U[u]))
+
+    with lock:
+        matching = set()
+        for u in U:
+            if pair_U[u] is not None:
                 matching.add((u, pair_U[u]))

     return matching
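For reference, the single-threaded Hopcroft-Karp phase loop that the function above parallelises can be sketched over a plain adjacency dict as follows; the names are illustrative and independent of the pydatastructs `Graph` API. Each `bfs` call layers the graph from the currently free `U`-vertices, and `dfs` then augments only along shortest alternating paths; the version in the diff distributes those `dfs` calls over a `ThreadPoolExecutor` in batches, giving each task copies of the pairing dicts and reconciling the shared pairing under a lock afterwards.

```python
from collections import deque

def hopcroft_karp(adj, U, V):
    # adj maps each u in U to an iterable of neighbours in V.
    INF = float('inf')
    pair_U = {u: None for u in U}
    pair_V = {v: None for v in V}
    dist = {}

    def bfs():
        # Layer the graph from the free U-vertices; dist[None] becomes
        # finite exactly when an augmenting path exists in this phase.
        queue = deque()
        for u in U:
            if pair_U[u] is None:
                dist[u] = 0
                queue.append(u)
            else:
                dist[u] = INF
        dist[None] = INF
        while queue:
            u = queue.popleft()
            if dist[u] < dist[None]:
                for v in adj[u]:
                    nxt = pair_V[v]
                    if dist.get(nxt, INF) == INF:
                        dist[nxt] = dist[u] + 1
                        queue.append(nxt)
        return dist[None] != INF

    def dfs(u):
        # Augment along the BFS layering only (shortest paths first).
        if u is None:
            return True
        for v in adj[u]:
            nxt = pair_V[v]
            if dist.get(nxt, INF) == dist[u] + 1 and dfs(nxt):
                pair_V[v] = u
                pair_U[u] = v
                return True
        dist[u] = INF
        return False

    while bfs():
        for u in U:
            if pair_U[u] is None:
                dfs(u)
    return {(u, v) for u, v in pair_U.items() if v is not None}

# Example on the graph from the docstring:
# hopcroft_karp({'v_1': ['v_2', 'v_4'], 'v_3': ['v_2']}, {'v_1', 'v_3'}, {'v_2', 'v_4'})
# -> {('v_1', 'v_4'), ('v_3', 'v_2')}
```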
@@ -1548,7 +1594,7 @@ def maximum_matching_parallel(graph: Graph, algorithm: str, num_threads: int, **kwargs) -> set:
     >>> graph.add_bidirectional_edge('v_2', 'v_3')
     >>> graph.add_bidirectional_edge('v_4', 'v_1')
     >>> maximum_matching_parallel(graph, 'hopcroft_karp', 1, make_undirected=True)
-    >>> {('v_3', 'v_2'), ('v_1', 'v_4')}
+    >>> {('v_1', 'v_4'), ('v_3', 'v_2')}

     References
     ==========
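As a hedged sanity check on the new entry point (the import paths and the `Graph` factory call are assumptions, mirroring the docstring example above), the parallel and sequential implementations should produce matchings of the same size on the same graph:

```python
from pydatastructs import Graph, AdjacencyListGraphNode
from pydatastructs.graphs.algorithms import (
    maximum_matching, maximum_matching_parallel)

vertices = [AdjacencyListGraphNode(f'v_{i}') for i in range(1, 5)]
graph = Graph(*vertices)
graph.add_bidirectional_edge('v_2', 'v_3')
graph.add_bidirectional_edge('v_4', 'v_1')

sequential = maximum_matching(graph, 'hopcroft_karp', make_undirected=True)
parallel = maximum_matching_parallel(graph, 'hopcroft_karp', 2,
                                     make_undirected=True)
assert len(sequential) == len(parallel)   # both should be maximum matchings
```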