diff --git a/eule/core.py b/eule/core.py index dfa1724..cb76f9c 100644 --- a/eule/core.py +++ b/eule/core.py @@ -1,6 +1,7 @@ """Main module.""" from typing import List, Dict, Union +from multiprocessing import Pool from copy import deepcopy from warnings import warn from reprlib import repr @@ -106,8 +107,66 @@ def euler_generator( sets_[set_key] = [] set_keys = cleared_set_keys(sets_) + +def euler_generator_worker(args): + sets, set_keys, set_key = args + results = [] + this_set = sets[set_key] + + other_keys = [key for key in set_keys if key != set_key] + if not this_set or not other_keys: + return results + + # Complementary sets + csets = {key: sets[key] for key in other_keys} + + for euler_tuple, celements in euler_generator(csets): + comb_elems = difference(celements, this_set) + + if comb_elems: + sorted_comb_key = ordered_tuplify(euler_tuple) + results.append((sorted_comb_key, comb_elems)) + + # Update sets + for euler_set_key in sorted_comb_key: + sets[euler_set_key] = difference(sets[euler_set_key], comb_elems) + + comb_elems = intersection(celements, this_set) + if comb_elems: + comb_key = update_ordered_tuple(euler_tuple, set_key) + results.append((comb_key, comb_elems)) + + for euler_set_key in comb_key: + sets[euler_set_key] = difference(sets[euler_set_key], comb_elems) + sets[set_key] = difference(sets[set_key], comb_elems) + + if sets[set_key]: + results.append(((set_key,), sets[set_key])) + sets[set_key] = [] + + return results + +def euler_generator_parallel(sets: SetsType): + sets_ = deepcopy(sets) + sets_ = validate_euler_generator_input(sets_) + set_keys = cleared_set_keys(sets_) + + if len(set_keys) == 1: + comb_key = set_keys[0] + comb_elements = list(sets_.values())[0] + yield ((comb_key,), comb_elements) + return + + with Pool() as pool: + results = pool.map(euler_generator_worker, [(sets_, set_keys, key) for key in set_keys]) + for result in results: + for res in result: + yield res + +def euler_parallel(sets: SetsType): + return dict(euler_generator_parallel(sets)) def euler( sets: SetsType diff --git a/tests/test_core.py b/tests/test_core.py index 2564823..40c1edd 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -6,7 +6,7 @@ from eule.operations import intersection from eule.core import euler_generator, euler, \ - euler_keys, euler_boundaries, Euler + euler_keys, euler_boundaries, Euler, euler_parallel from eule.utils import sequence_to_set from .fixtures import \ @@ -104,6 +104,7 @@ def test_euler(test_sets, euler_sets): } assert euler(setified_test_sets) == setified_euler_sets + assert euler_parallel(setified_test_sets) == setified_euler_sets def test_euler_keys(sets, euler_sets_keys):