From ad5473110fd4de80e4606bb661cde94742534146 Mon Sep 17 00:00:00 2001 From: Matthew Davis Date: Tue, 12 Nov 2024 13:09:45 +0100 Subject: [PATCH 01/12] add unit test for deadlock issue --- lambda_multiprocessing/test_main.py | 85 +++++++++++++++++++++++++++++ lambda_multiprocessing/timeout.py | 35 ++++++++++++ 2 files changed, 120 insertions(+) create mode 100644 lambda_multiprocessing/timeout.py diff --git a/lambda_multiprocessing/test_main.py b/lambda_multiprocessing/test_main.py index 9a522c7..440f7fe 100644 --- a/lambda_multiprocessing/test_main.py +++ b/lambda_multiprocessing/test_main.py @@ -8,6 +8,7 @@ import boto3 from moto import mock_aws +from timeout import TimeoutManager, TestTimeoutException # add an overhead for duration when asserting the duration of child processes # if other processes are hogging CPU, make this bigger @@ -42,6 +43,8 @@ class TestCase(unittest.TestCase): # with self.assertDuration(1, 2): # something # to assert that something takes between 1 to 2 seconds to run + # If the task takes forever, this assertion will not be raised + # For a potential eternal task, use timeout.TimeoutManager def assertDuration(self, min_t=None, max_t=None): class AssertDuration: def __init__(self, test): @@ -403,6 +406,79 @@ def test_submit_after_terminate(self): with self.assertRaises(ValueError): p.apply_async(square, (1,)) +class TestDeadlock(TestCase): + # test this issue: + # https://github.com/mdavis-xyz/lambda_multiprocessing/issues/17 + def test_map_deadlock(self): + + child_sleep = 0.01 + num_payloads = 6 + + # use standard library to start with + # and to measure a 'normal' duration + # (the time spend passing data between processes is longer than the sleep + # inside the worker) + expected_duration = child_sleep * num_payloads + data = [self.generate_big_data() for _ in range(num_payloads)] + start_t = time() + + with multiprocessing.Pool(processes=1) as p: + p.map(self.work, data) + + end_t = time() + stdlib_duration = end_t - start_t + print(f"{stdlib_duration=}") + + # this timeout manager doesn't work + # need to run the parent inside another process/thread? + # now our one + print("Running test which might deadlock") + data = [self.generate_big_data() for _ in range(num_payloads)] + with Pool(processes=1) as p: + with TimeoutManager(stdlib_duration*2, "This Library's map deadlocked"): + try: + p.map(self.work, data) + except TestTimeoutException: + p.terminate() + raise + + def test_map_async_deadlock(self): + + child_sleep = 0.4 + num_payloads = 3 + data = [self.generate_big_data() for _ in range(num_payloads)] + + # use standard library to start with + # and to measure a 'normal' duration + # (the time spend passing data between processes is longer than the sleep + # inside the worker) + expected_duration = child_sleep * num_payloads + start_t = time() + with multiprocessing.Pool(processes=1) as p: + results = p.map_async(self.work, data) + [r.get() for r in results] + end_t = time() + stdlib_duration = end_t - start_t + + # now our one + with Pool(processes=1) as p: + with TimeoutManager(stdlib_duration*2, "This Library's map_async deadlocked"): + try: + results = p.map_async(self.work, data) + [r.get() for r in results] + except TestTimeoutException: + p.terminate() + raise + + + @classmethod + def generate_big_data(cls, sz=2**26) -> bytes: + return 'x' * sz + + @classmethod + def work(cls, x, delay=0.3): + sleep(delay) + return x # must be a global method to be pickleable def upload(args: Tuple[str, str, bytes]): @@ -457,5 +533,14 @@ def test_memory_leak(self): for j in range(10**2): p.map(square, range(10**3)) +class Timeout: + def __init__(self, seconds, message='Test Timed Out'): + self.seconds = seconds + self.message = message + + def __enter__(self): + + return self + if __name__ == '__main__': unittest.main() diff --git a/lambda_multiprocessing/timeout.py b/lambda_multiprocessing/timeout.py new file mode 100644 index 0000000..3c58fb6 --- /dev/null +++ b/lambda_multiprocessing/timeout.py @@ -0,0 +1,35 @@ +# Timeout context manager for unit testing + +import signal +from math import ceil + +class TestTimeoutException(Exception): + """Exception raised when a test takes too long""" + pass + +class TimeoutManager: + # if a float is passed as seconds, it will be rounded up + def __init__(self, seconds: int, description = "Test timed out"): + self.seconds = ceil(seconds) + self.description = description + self.old_handler = None + + def __enter__(self): + self.old_handler = signal.signal(signal.SIGALRM, self.timeout_handler) + signal.alarm(self.seconds) + return self + + def timeout_handler(self, signum, frame): + raise TestTimeoutException(self.description) + + def __exit__(self, exc_type, exc_value, traceback): + # Disable the alarm + signal.alarm(0) + # Restore old signal handler + signal.signal(signal.SIGALRM, self.old_handler) + + if exc_type is TestTimeoutException: + return False # Let the TestTimeoutException exception propagate + + # Propagate any other exceptions, or continue if no exception + return False From 8a7460ae82877f4448428337479e7cc9b36d4cc6 Mon Sep 17 00:00:00 2001 From: Matthew Davis Date: Tue, 12 Nov 2024 13:48:16 +0100 Subject: [PATCH 02/12] Update tests for map_async to match stdlib behavior --- README.md | 5 +- lambda_multiprocessing/test_main.py | 175 +++++++++++++++------------- 2 files changed, 100 insertions(+), 80 deletions(-) diff --git a/README.md b/README.md index 559bfa5..a1c0d01 100644 --- a/README.md +++ b/README.md @@ -113,9 +113,12 @@ Then you can run the unit tests with: python3 -m unittest ``` -`CICD` is for the GitHub Actions which run unit tests and integration tests. +The tests themselves are defined in `lambda_multiprocessing/test_main.py`. + +`CICD` is for the GitHub Actions which run the unit tests and integration tests. You probably don't need to touch those. + ## Design When you `__enter__` the pool, it creates several `Child`s. diff --git a/lambda_multiprocessing/test_main.py b/lambda_multiprocessing/test_main.py index 440f7fe..ad05ee3 100644 --- a/lambda_multiprocessing/test_main.py +++ b/lambda_multiprocessing/test_main.py @@ -30,6 +30,10 @@ def divide(a, b): def return_args_kwargs(*args, **kwargs): return {'args': args, 'kwargs': kwargs} +def return_with_sleep(x, delay=0.3): + sleep(delay) + return x + class TestStdLib(unittest.TestCase): @unittest.skip('Need to set up to remove /dev/shm') def test_standard_library(self): @@ -81,22 +85,23 @@ def test(self): sleep(t) class TestApply(TestCase): + pool_generator = Pool def test_simple(self): args = range(5) - with Pool() as p: + with self.pool_generator() as p: actual = [p.apply(square, (x,)) for x in args] with multiprocessing.Pool() as p: expected = [p.apply(square, (x,)) for x in args] self.assertEqual(actual, expected) def test_two_args(self): - with Pool() as p: + with self.pool_generator() as p: actual = p.apply(sum_two, (3, 4)) self.assertEqual(actual, sum_two(3, 4)) def test_kwargs(self): - with Pool() as p: + with self.pool_generator() as p: actual = p.apply(return_args_kwargs, (1, 2), {'x': 'X', 'y': 'Y'}) expected = { 'args': (1,2), @@ -107,10 +112,16 @@ def test_kwargs(self): } self.assertEqual(actual, expected) +# rerun all the tests with stdlib +# to confirm we have the same behavior +class TestApplyStdLib(TestApply): + pool_generator = multiprocessing.Pool + class TestApplyAsync(TestCase): + pool_generator = Pool def test_result(self): - with Pool() as p: + with self.pool_generator() as p: r = p.apply_async(square, (2,)) result = r.get(2) self.assertEqual(result, square(2)) @@ -118,7 +129,7 @@ def test_result(self): def test_twice(self): args = range(5) # now try twice with the same pool - with Pool() as p: + with self.pool_generator() as p: ret = [p.apply_async(square, (x,)) for x in args] ret.extend(p.apply_async(square, (x,)) for x in args) @@ -127,7 +138,7 @@ def test_twice(self): def test_time(self): # test the timing to confirm it's in parallel n = 4 - with Pool(n) as p: + with self.pool_generator(n) as p: with self.assertDuration(min_t=n-1, max_t=(n-1)+delta): with self.assertDuration(max_t=1): ret = [p.apply_async(sleep, (r,)) for r in range(n)] @@ -138,14 +149,14 @@ def test_unclean_exit(self): # .get after __exit__, but process finishes before __exit__ with self.assertRaises(multiprocessing.TimeoutError): t = 1 - with Pool() as p: + with self.pool_generator() as p: r = p.apply_async(square, (t,)) sleep(1) r.get() with self.assertRaises(multiprocessing.TimeoutError): t = 1 - with Pool() as p: + with self.pool_generator() as p: r = p.apply_async(square, (t,)) sleep(1) r.get(1) @@ -153,26 +164,26 @@ def test_unclean_exit(self): # __exit__ before process finishes with self.assertRaises(multiprocessing.TimeoutError): t = 1 - with Pool() as p: + with self.pool_generator() as p: r = p.apply_async(sleep, (t,)) r.get(t+1) # .get() with arg after __exit__ with self.assertRaises(multiprocessing.TimeoutError): t = 1 - with Pool() as p: + with self.pool_generator() as p: r = p.apply_async(sleep, (t,)) r.get() # .get() without arg after __exit__ def test_get_not_ready_a(self): t = 2 - with Pool() as p: + with self.pool_generator() as p: r = p.apply_async(sleep, (t,)) with self.assertRaises(multiprocessing.TimeoutError): r.get(t-1) # result not ready get def test_get_not_ready_b(self): t = 2 - with Pool() as p: + with self.pool_generator() as p: # check same exception exists from main r = p.apply_async(sleep, (t,)) with self.assertRaises(TimeoutError): @@ -180,7 +191,7 @@ def test_get_not_ready_b(self): def test_get_not_ready_c(self): t = 2 - with Pool() as p: + with self.pool_generator() as p: r = p.apply_async(sleep, (t,)) sleep(1) self.assertFalse(r.ready()) @@ -188,7 +199,7 @@ def test_get_not_ready_c(self): self.assertTrue(r.ready()) def test_wait(self): - with Pool() as p: + with self.pool_generator() as p: r = p.apply_async(square, (1,)) with self.assertDuration(max_t=delta): r.wait() @@ -212,13 +223,13 @@ def test_wait(self): ret = r.get(0) def test_get_twice(self): - with Pool() as p: + with self.pool_generator() as p: r = p.apply_async(square, (2,)) self.assertEqual(r.get(), square(2)) self.assertEqual(r.get(), square(2)) def test_successful(self): - with Pool() as p: + with self.pool_generator() as p: r = p.apply_async(square, (1,)) sleep(delta) self.assertTrue(r.successful()) @@ -236,13 +247,13 @@ def test_successful(self): self.assertFalse(r.successful()) def test_two_args(self): - with Pool() as p: + with self.pool_generator() as p: ret = p.apply_async(sum_two, (1, 2)) ret.wait() self.assertEqual(ret.get(), sum_two(1,2)) def test_kwargs(self): - with Pool() as p: + with self.pool_generator() as p: actual = p.apply_async(return_args_kwargs, (1, 2), {'x': 'X', 'y': 'Y'}).get() expected = { 'args': (1,2), @@ -255,14 +266,21 @@ def test_kwargs(self): def test_error_handling(self): with self.assertRaises(AssertionError): - with Pool() as p: + with self.pool_generator() as p: r = p.apply_async(fail, (1,)) r.get() +# retrun tests with standard library +# to confirm our behavior matches theirs +class TestApplyAsyncStdLib(TestApplyAsync): + pool_generator = multiprocessing.Pool + class TestMap(TestCase): + pool_generator = Pool + def test_simple(self): args = range(5) - with Pool() as p: + with self.pool_generator() as p: actual = p.map(square, args) self.assertIsInstance(actual, list) expected = [square(x) for x in args] @@ -271,111 +289,121 @@ def test_simple(self): def test_duration(self): n = 2 - with Pool(n) as p: + with self.pool_generator(n) as p: with self.assertDuration(min_t=(n-1)-delta, max_t=(n+1)+delta): p.map(sleep, range(n)) def test_error_handling(self): - with Pool() as p: + with self.pool_generator() as p: with self.assertRaises(AssertionError): p.map(fail, range(2)) @unittest.skip('Need to implement chunking to fix this') def test_long_iter(self): - with Pool() as p: + with self.pool_generator() as p: p.map(square, range(10**3)) def test_without_with(self): - # check that the standard library Pool - # can do .map() without `with` - p = multiprocessing.Pool(3) + p = self.pool_generator(3) ret = p.map(square, [1,2]) self.assertEqual(ret, [1,2*2]) p.close() p.join() - # now check that this library can do it - p = Pool(3) - ret = p.map(square, [1,2]) - self.assertEqual(ret, [1,2*2]) - p.close() - p.join() - - +class TestMapStdLib(TestMap): + pool_generator = multiprocessing.Pool class TestMapAsync(TestCase): + pool_generator = Pool + def test_simple(self): args = range(5) - with Pool() as p: + with self.pool_generator() as p: actual = p.map_async(square, args) - self.assertIsInstance(actual, list) - for x in actual: - self.assertIsInstance(x, AsyncResult) - - results = [a.get() for a in actual] + self.assertIsInstance(actual, (AsyncResult, multiprocessing.pool.AsyncResult)) + results = actual.get() self.assertEqual(results, [square(e) for e in args]) def test_duration(self): n = 2 - with Pool(n) as p: + with self.pool_generator(n) as p: with self.assertDuration(min_t=(n-1)-delta, max_t=(n+1)+delta): with self.assertDuration(max_t=delta): results = p.map_async(sleep, range(n)) - [r.get() for r in results] + results.get() def test_error_handling(self): - with Pool() as p: + with self.pool_generator() as p: r = p.map_async(fail, range(2)) with self.assertRaises(AssertionError): - [x.get(1) for x in r] + r.get() + +class TestMapAsyncStdLib(TestMapAsync): + pool_generator = multiprocessing.Pool class TestStarmap(TestCase): + pool_generator = Pool + def test(self): - with Pool() as p: + with self.pool_generator() as p: actual = p.starmap(sum_two, [(1,2), (3,4)]) expected = [(1+2), (3+4)] self.assertEqual(actual, expected) def test_error_handling(self): - with Pool() as p: + with self.pool_generator() as p: with self.assertRaises(ZeroDivisionError): p.starmap(divide, [(1,2), (3,0)]) - +class TestStarmapStdLib(TestStarmap): + pool_generator = multiprocessing.Pool class TestStarmapAsync(TestCase): + pool_generator = Pool + def test(self): - with Pool() as p: - actual = p.starmap_async(sum_two, [(1,2), (3,4)]) - self.assertIsInstance(actual, list) - actual = [r.get() for r in actual] + with self.pool_generator() as p: + response = p.starmap_async(sum_two, [(1,2), (3,4)]) + self.assertIsInstance(response, (AsyncResult, multiprocessing.pool.AsyncResult)) + actual = response.get() expected = [(1+2), (3+4)] self.assertEqual(actual, expected) def test_error_handling(self): - with Pool() as p: + with self.pool_generator() as p: results = p.starmap_async(divide, [(1,2), (3,0)]) with self.assertRaises(ZeroDivisionError): - [r.get() for r in results] + results.get() +class TestStarmapAsyncStdLib(TestStarmapAsync): + pool_generator = multiprocessing.Pool + +class TestExit(TestCase): + # only test this with our library, + # not the standard library + # because the standard library has a bug + # https://github.com/python/cpython/issues/79659 -class TestTidyUp(TestCase): # test that the implicit __exit__ # waits for child process to finish def test_exit(self): t = 1 with Pool() as p: - r = p.apply_async(sleep, (t,)) t1 = time() + r = p.apply_async(sleep, (t,)) t2 = time() + breakpoint() self.assertLessEqual(abs((t2-t1)-t), delta) +class TestTidyUp(TestCase): + pool_generator = Pool + # test that .close() stops new submisssions # but does not halt existing num_processes # nor wait for them to finish def test_close(self): t = 1 - with Pool() as p: + with self.pool_generator() as p: r = p.apply_async(sleep, (t,)) with self.assertDuration(max_t=delta): p.close() @@ -384,7 +412,7 @@ def test_close(self): pass # makes traceback from __exit__ clearer def test_submit_after_close(self): - with Pool() as p: + with self.pool_generator() as p: p.close() with self.assertRaises(ValueError): p.apply_async(square, (1,)) @@ -393,7 +421,7 @@ def test_submit_after_close(self): # wait for child process to finish def test_terminate(self): with self.assertDuration(max_t=delta): - with Pool() as p: + with self.pool_generator() as p: r = p.apply_async(sleep, (1,)) t1 = time() p.terminate() @@ -401,12 +429,16 @@ def test_terminate(self): self.assertLessEqual(t2-t1, delta) def test_submit_after_terminate(self): - with Pool() as p: + with self.pool_generator() as p: p.terminate() with self.assertRaises(ValueError): p.apply_async(square, (1,)) +class TestTidyUpStdLib(TestTidyUp): + pool_generator = multiprocessing.Pool + class TestDeadlock(TestCase): + # test this issue: # https://github.com/mdavis-xyz/lambda_multiprocessing/issues/17 def test_map_deadlock(self): @@ -423,21 +455,19 @@ def test_map_deadlock(self): start_t = time() with multiprocessing.Pool(processes=1) as p: - p.map(self.work, data) + p.map(return_with_sleep, data) end_t = time() stdlib_duration = end_t - start_t - print(f"{stdlib_duration=}") # this timeout manager doesn't work # need to run the parent inside another process/thread? # now our one - print("Running test which might deadlock") data = [self.generate_big_data() for _ in range(num_payloads)] with Pool(processes=1) as p: with TimeoutManager(stdlib_duration*2, "This Library's map deadlocked"): try: - p.map(self.work, data) + p.map(return_with_sleep, data) except TestTimeoutException: p.terminate() raise @@ -455,7 +485,7 @@ def test_map_async_deadlock(self): expected_duration = child_sleep * num_payloads start_t = time() with multiprocessing.Pool(processes=1) as p: - results = p.map_async(self.work, data) + results = p.map_async(return_with_sleep, data) [r.get() for r in results] end_t = time() stdlib_duration = end_t - start_t @@ -464,7 +494,7 @@ def test_map_async_deadlock(self): with Pool(processes=1) as p: with TimeoutManager(stdlib_duration*2, "This Library's map_async deadlocked"): try: - results = p.map_async(self.work, data) + results = p.map_async(return_with_sleep, data) [r.get() for r in results] except TestTimeoutException: p.terminate() @@ -475,10 +505,6 @@ def test_map_async_deadlock(self): def generate_big_data(cls, sz=2**26) -> bytes: return 'x' * sz - @classmethod - def work(cls, x, delay=0.3): - sleep(delay) - return x # must be a global method to be pickleable def upload(args: Tuple[str, str, bytes]): @@ -526,21 +552,12 @@ def test_moto(self): self.assertEqual(ret, data) class TestSlow(TestCase): - #@unittest.skip('Very slow') + @unittest.skip('Very slow') def test_memory_leak(self): for i in range(10**2): with Pool() as p: for j in range(10**2): p.map(square, range(10**3)) -class Timeout: - def __init__(self, seconds, message='Test Timed Out'): - self.seconds = seconds - self.message = message - - def __enter__(self): - - return self - if __name__ == '__main__': unittest.main() From 6cc9f4f519ee05e885190e0c4811ebe33c0e8b32 Mon Sep 17 00:00:00 2001 From: Matthew Davis Date: Tue, 12 Nov 2024 14:01:38 +0100 Subject: [PATCH 03/12] fix deadlock issue in map with scheduler --- lambda_multiprocessing/main.py | 50 ++++++++++++++++++++++++++++++++-- 1 file changed, 48 insertions(+), 2 deletions(-) diff --git a/lambda_multiprocessing/main.py b/lambda_multiprocessing/main.py index a88bf6e..92ccbd1 100644 --- a/lambda_multiprocessing/main.py +++ b/lambda_multiprocessing/main.py @@ -5,6 +5,8 @@ import random import os from time import time +from select import select +import signal class Child: proc: Process @@ -321,8 +323,52 @@ def starmap_async(self, func, iterable: Iterable[Iterable], chunksize=None, call return [self.apply_async(func, args) for args in iterable] def starmap(self, func, iterable: Iterable[Iterable], chunksize=None, callback=None, error_callback=None) -> List[Any]: - results = self.starmap_async(func, iterable, chunksize, callback, error_callback) - return [r.get() for r in results] + if chunksize: + raise NotImplementedError("chunksize not implemented") + if callback: + raise NotImplementedError("callback not implemented") + if error_callback: + raise NotImplementedError("error_callback not implemented") + + idle_children = set(self.children) + ids = [] + pending_results: Dict[UUID, AsyncResult] = {} + + for args in iterable: + if not idle_children: + # wait for a child to become idle + # by waiting for any of the pipes from children to become readable + ready, _, _ = select([c.parent_conn for c in self.children], [], []) + + # at least one child is idle. + # check all children, read their last result from the pipe + # then issue the new task + for child in self.children: + if child.parent_conn in ready: + assert child.parent_conn.poll() + child.flush() + idle_children.add(child) + + + child = idle_children.pop() + async_result = child.submit(func, args) + pending_results[async_result.id] = async_result + ids.append(async_result.id) + + + if len(idle_children) < len(self.children): + # if at least one child is still working + # wait with select + ready, _, _ = select([c.parent_conn for c in self.children if c not in idle_children], [], []) + + # get all the results + # re-arranging the order + results = [] + for (i, id) in enumerate(ids): + result = pending_results[id].get() + results.append(result) + + return results def imap(self, func, iterable, chunksize=None): raise NotImplementedError("Only normal apply, map, starmap and their async equivilents have been implemented") From b6894075bc91ed3bce30ec3d1f73fe2b54d1b67a Mon Sep 17 00:00:00 2001 From: Matthew Davis Date: Tue, 12 Nov 2024 14:32:40 +0100 Subject: [PATCH 04/12] Rewrite map_async to return single result --- lambda_multiprocessing/main.py | 85 +++++++++++++++++++++++++++-- lambda_multiprocessing/test_main.py | 31 ++++++++--- 2 files changed, 105 insertions(+), 11 deletions(-) diff --git a/lambda_multiprocessing/main.py b/lambda_multiprocessing/main.py index 92ccbd1..57aef81 100644 --- a/lambda_multiprocessing/main.py +++ b/lambda_multiprocessing/main.py @@ -1,6 +1,6 @@ from multiprocessing import TimeoutError, Process, Pipe from multiprocessing.connection import Connection -from typing import Any, Iterable, List, Dict, Tuple, Union +from typing import Any, Iterable, List, Dict, Tuple, Union, Optional from uuid import uuid4, UUID import random import os @@ -223,6 +223,78 @@ def successful(self): return self.result[1] is None +# map_async and starmap_async return a single AsyncResult +# which is a list of actual results +# This class aggregates many AsyncResult into one +class AsyncResultList(AsyncResult): + def __init__(self, child_results: List[AsyncResult]): + self.child_results = child_results + self.result: List[Union[Tuple[Any, None], Tuple[None, Exception]]] = None + + # assume the result is in the self.child.result_cache + # move it into self.result + def _load(self): + for c in self.child_results: + c._load() + + # Return the result when it arrives. + # If timeout is not None and the result does not arrive within timeout seconds + # then multiprocessing.TimeoutError is raised. + # If the remote call raised an exception then that exception will be reraised by get(). + # .get() must remember the result + # and return it again multiple times + # delete it from the Child.result_cache to avoid memory leak + def get(self, timeout=None): + if timeout: + end_t = time() + timeout + else: + end_t = None + + results = [] + for (i, c) in enumerate(self.child_results): + # Consider cumulative timeout + if timeout is not None: + timeout_remaining = end_t - time() + else: + timeout_remaining = None + + try: + result = c.get(timeout=timeout_remaining) + except Exception: + print(f"Exception raised for {i}th task out of {len(self.child_results)}") + # terminate remaining children + for c2 in self.child_results[i+1:]: + c2.child.terminate() + raise + + results.append(result) + return results + + # Wait until the result is available or until timeout seconds pass. + def wait(self, timeout=None): + + if timeout: + end_t = time() + timeout + else: + end_t = None + for c in self.child_results: + # Consider cumulative timeout + if timeout is not None: + timeout_remaining = end_t - time() + else: + timeout_remaining = None + + c.wait(timeout_remaining) + + # Return whether the call has completed. + def ready(self): + return all(c.ready() for c in self.child_results) + + # Return whether the call completed without raising an exception. + # Will raise ValueError if the result is not ready. + def successful(self): + return all(c.successful() for c in self.child_results) + class Pool: def __init__(self, processes=None, initializer=None, initargs=None, maxtasksperchild=None, context=None): if processes is None: @@ -309,18 +381,23 @@ def apply_async(self, func, args=(), kwds=None, callback=None, error_callback=No c = random.choice([c for c in self.children if c.queue_sz <= min_q_sz]) return c.submit(func, args, kwds) - def map_async(self, func, iterable, chunksize=None, callback=None, error_callback=None) -> List[AsyncResult]: + def map_async(self, func, iterable, chunksize=None, callback=None, error_callback=None) -> AsyncResult: return self.starmap_async(func, zip(iterable), chunksize, callback, error_callback) def map(self, func, iterable, chunksize=None, callback=None, error_callback=None) -> List: return self.starmap(func, zip(iterable), chunksize, callback, error_callback) - def starmap_async(self, func, iterable: Iterable[Iterable], chunksize=None, callback=None, error_callback=None) -> List[AsyncResult]: + def starmap_async(self, func, iterable: Iterable[Iterable], chunksize=None, callback=None, error_callback=None) -> AsyncResult: if chunksize: raise NotImplementedError("Haven't implemented chunksizes. Infinite chunksize only.") if callback or error_callback: raise NotImplementedError("Haven't implemented callbacks") - return [self.apply_async(func, args) for args in iterable] + results = [self.apply_async(func, args) for args in iterable] + + # aggregate into one result + result = AsyncResultList(child_results=results) + + return result def starmap(self, func, iterable: Iterable[Iterable], chunksize=None, callback=None, error_callback=None) -> List[Any]: if chunksize: diff --git a/lambda_multiprocessing/test_main.py b/lambda_multiprocessing/test_main.py index ad05ee3..7b8e8c1 100644 --- a/lambda_multiprocessing/test_main.py +++ b/lambda_multiprocessing/test_main.py @@ -1,8 +1,8 @@ import unittest -import multiprocessing +import multiprocessing, multiprocessing.pool from lambda_multiprocessing import Pool, TimeoutError, AsyncResult from time import time, sleep -from typing import Tuple +from typing import Tuple, Optional from pathlib import Path import os @@ -34,6 +34,15 @@ def return_with_sleep(x, delay=0.3): sleep(delay) return x +def _raise(ex: Optional[Exception]): + if ex: + raise ex + +class ExceptionA(Exception): + pass +class ExceptionB(Exception): + pass + class TestStdLib(unittest.TestCase): @unittest.skip('Need to set up to remove /dev/shm') def test_standard_library(self): @@ -338,6 +347,14 @@ def test_error_handling(self): with self.assertRaises(AssertionError): r.get() + def test_multi_error(self): + # standard library can raise either error + + with self.assertRaises((ExceptionA, ExceptionB)): + with self.pool_generator() as p: + r = p.map_async(_raise, (None, ExceptionA("Task 1"), None, ExceptionB("Task 3"))) + r.get() + class TestMapAsyncStdLib(TestMapAsync): pool_generator = multiprocessing.Pool @@ -392,7 +409,6 @@ def test_exit(self): t1 = time() r = p.apply_async(sleep, (t,)) t2 = time() - breakpoint() self.assertLessEqual(abs((t2-t1)-t), delta) class TestTidyUp(TestCase): @@ -486,7 +502,7 @@ def test_map_async_deadlock(self): start_t = time() with multiprocessing.Pool(processes=1) as p: results = p.map_async(return_with_sleep, data) - [r.get() for r in results] + results.get() end_t = time() stdlib_duration = end_t - start_t @@ -495,7 +511,7 @@ def test_map_async_deadlock(self): with TimeoutManager(stdlib_duration*2, "This Library's map_async deadlocked"): try: results = p.map_async(return_with_sleep, data) - [r.get() for r in results] + results.get() except TestTimeoutException: p.terminate() raise @@ -526,11 +542,12 @@ def test_moto(self): bucket_name = 'mybucket' key = 'my-file' data = b"123" - client = boto3.client('s3') + region = os.getenv("AWS_DEFAULT_REGION", "ap-southeast-2") + client = boto3.client('s3', region_name=region) client.create_bucket( Bucket=bucket_name, CreateBucketConfiguration={ - 'LocationConstraint': 'ap-southeast-2' + 'LocationConstraint': region }, ) # upload in a different thread From cabcc8ba1308719696070ea1254b05330b3cd7d7 Mon Sep 17 00:00:00 2001 From: Matthew Davis Date: Tue, 12 Nov 2024 15:03:42 +0100 Subject: [PATCH 05/12] add more accurate test for map_async total time --- lambda_multiprocessing/test_main.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/lambda_multiprocessing/test_main.py b/lambda_multiprocessing/test_main.py index 7b8e8c1..e91a00b 100644 --- a/lambda_multiprocessing/test_main.py +++ b/lambda_multiprocessing/test_main.py @@ -334,11 +334,14 @@ def test_simple(self): self.assertEqual(results, [square(e) for e in args]) def test_duration(self): - n = 2 - with self.pool_generator(n) as p: - with self.assertDuration(min_t=(n-1)-delta, max_t=(n+1)+delta): + sleep_duration = 0.5 + n_procs = 2 + num_tasks_per_proc = 2 + expected_wall_time = sleep_duration * num_tasks_per_proc + with self.pool_generator(n_procs) as p: + with self.assertDuration(min_t=expected_wall_time-delta, max_t=expected_wall_time+delta): with self.assertDuration(max_t=delta): - results = p.map_async(sleep, range(n)) + results = p.map_async(sleep, [1] * (num_tasks_per_proc * n_procs)) results.get() def test_error_handling(self): From 3ed1ea854e446190292b80fcbfe8b01f7a72a44a Mon Sep 17 00:00:00 2001 From: Matthew Davis Date: Tue, 12 Nov 2024 15:09:46 +0100 Subject: [PATCH 06/12] add deadlock test to check map_async returns immediately --- lambda_multiprocessing/test_main.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/lambda_multiprocessing/test_main.py b/lambda_multiprocessing/test_main.py index e91a00b..cad6192 100644 --- a/lambda_multiprocessing/test_main.py +++ b/lambda_multiprocessing/test_main.py @@ -336,13 +336,14 @@ def test_simple(self): def test_duration(self): sleep_duration = 0.5 n_procs = 2 - num_tasks_per_proc = 2 + num_tasks_per_proc = 3 expected_wall_time = sleep_duration * num_tasks_per_proc with self.pool_generator(n_procs) as p: with self.assertDuration(min_t=expected_wall_time-delta, max_t=expected_wall_time+delta): with self.assertDuration(max_t=delta): - results = p.map_async(sleep, [1] * (num_tasks_per_proc * n_procs)) + results = p.map_async(sleep, [sleep_duration] * (num_tasks_per_proc * n_procs)) results.get() + def test_error_handling(self): with self.pool_generator() as p: @@ -395,6 +396,8 @@ def test_error_handling(self): with self.assertRaises(ZeroDivisionError): results.get() + + class TestStarmapAsyncStdLib(TestStarmapAsync): pool_generator = multiprocessing.Pool @@ -519,6 +522,25 @@ def test_map_async_deadlock(self): p.terminate() raise + # test that map_async returns immediately + # even when there are multiple tasks per child + # with payloads bigger than the buffer + def test_nonblocking(self): + sleep_duration = 0.5 + n_procs = 2 + num_tasks_per_proc = 3 + expected_wall_time = sleep_duration * num_tasks_per_proc + # use a big payload, to ensure the buffer fills up from the first arg + with Pool(n_procs) as p: + try: + with TimeoutManager(sleep_duration * num_tasks_per_proc * 2, "This Library's map_async deadlocked"): + with self.assertDuration(max_t=delta): + results = p.map_async(return_with_sleep, [(self.generate_big_data(), sleep_duration)] * (num_tasks_per_proc * n_procs)) + + results.get() + except TestTimeoutException: + p.terminate() + raise @classmethod def generate_big_data(cls, sz=2**26) -> bytes: From df87f461ca5e47b617714439b4961cd069919c45 Mon Sep 17 00:00:00 2001 From: Matthew Davis Date: Wed, 13 Nov 2024 17:49:19 +0100 Subject: [PATCH 07/12] Batch tasks for async maps --- lambda_multiprocessing/main.py | 251 +++++++++++++++++++--------- lambda_multiprocessing/test_main.py | 46 +++-- lambda_multiprocessing/timeout.py | 20 ++- 3 files changed, 215 insertions(+), 102 deletions(-) diff --git a/lambda_multiprocessing/main.py b/lambda_multiprocessing/main.py index 57aef81..63a3d3f 100644 --- a/lambda_multiprocessing/main.py +++ b/lambda_multiprocessing/main.py @@ -1,12 +1,53 @@ from multiprocessing import TimeoutError, Process, Pipe from multiprocessing.connection import Connection -from typing import Any, Iterable, List, Dict, Tuple, Union, Optional +from typing import Any, Iterable, List, Dict, Tuple, Union, Optional, Callable from uuid import uuid4, UUID import random import os from time import time from select import select -import signal +from itertools import repeat + +# This is what we send down the pipe from parent to child +class Request: + pass + +class Task[Request]: + func: Callable + args: List + kwds: Dict + id: UUID + + def __init__(self, func, args=(), kwds={}, id=None): + self.func = func + self.args = args + self.kwds = kwds or {} + self.id = id or uuid4() + +class QuitSignal[Request]: + pass + +class RunBatchSignal[Request]: + pass + +# this is what we send back through the pipe from child to parent +class Response: + id: UUID + pass + +class SuccessResponse: + result: Any + def __init__(self, id: UUID, result: Any): + self.result = result + self.id = id + +# processing the task raised an exception +# we save the exception in this object +class FailResponse: + exception: Exception + def __init__(self, id: UUID, exception: Exception): + self.id = id + self.exception = exception class Child: proc: Process @@ -19,12 +60,12 @@ class Child: # does not include the termination command from parent to child queue_sz: int = 0 - # parent_conn.send() to give stuff to the child + # parent_conn.send(Request) to give stuff to the child # parent_conn.recv() to get results back from child parent_conn: Connection child_conn: Connection - result_cache: Dict[UUID, Tuple[Any, Exception]] = {} + response_cache: Dict[UUID, Tuple[Any, Exception]] = {} _closed: bool = False @@ -49,47 +90,81 @@ def __init__(self, main_proc=False): # {id: (None, err)} if func raised exception err # [None, True] -> exit gracefully (write nothing to the pipe) def spin(self) -> None: + quit_signal = False while True: - (job, quit_signal) = self.child_conn.recv() + req_buf = [] + # first read in tasks until we get a pause/quit signal + while True: + request = self.child_conn.recv() + if isinstance(request, Task): + req_buf.append(request) + elif isinstance(request, QuitSignal): + # don't quit yet. Finish what's in the buffer. + quit_signal |= True + break # stop reading new stuff from the pipe + elif isinstance(request, RunBatchSignal): + # stop reading from Pipe, process what's in the request buffer + break + result_buf = [] + for req in req_buf: + assert isinstance(req, Task) + # process the result + result = self._do_work(req) + result_buf.append(result) + + # send all the results + for result in result_buf: + self.child_conn.send(result) + if quit_signal: break - else: - (id, func, args, kwds) = job - result = self._do_work(id, func, args, kwds) - self.child_conn.send(result) self.child_conn.close() - def _do_work(self, id, func, args, kwds) -> Union[Tuple[Any, None], Tuple[None, Exception]]: + # applies the function, catching any exception if it occurs + def _do_work(self, task: Task) -> Response: try: - ret = {id: (func(*args, **kwds), None)} + result = task.func(*task.args, **task.kwds) except Exception as e: # how to handle KeyboardInterrupt? - ret = {id: (None, e)} - assert isinstance(list(ret.keys())[0], UUID) - return ret - - def submit(self, func, args=(), kwds=None) -> 'AsyncResult': + resp = FailResponse(id=task.id, exception=e) + else: + resp = SuccessResponse(id=task.id, result=result) + return resp + + # this sends a task to the child + # if as_batch=False, the child will start work on it immediately + # If as_batch=True, the child will load this into it's local buffer + # but won't start processing until we send a RunBatchSignal with self.run_batch() + def submit(self, func, args=(), kwds=None, as_batch=False) -> 'AsyncResult': if self._closed: raise ValueError("Cannot submit tasks after closure") - if kwds is None: - kwds = {} - id = uuid4() - self.parent_conn.send([(id, func, args, kwds), None]) + request = Task(func=func, args=args, kwds=kwds) + + self.parent_conn.send(request) if self.main_proc: self.child_conn.recv() - ret = self._do_work(id, func, args, kwds) + ret = self._do_work(request) self.child_conn.send(ret) + elif not as_batch: + # treat this as a batch of 1 + self.run_batch() self.queue_sz += 1 - return AsyncResult(id=id, child=self) + return AsyncResult(id=request.id, child=self) + + # non-blocking + # Tells the child to start commencing work on all tasks sent up until now + def run_batch(self): + self.parent_conn.send(RunBatchSignal()) # grab all results in the pipe from child to parent - # save them to self.result_cache - def flush(self): + # save them to self.response_cache + def flush_results(self): # watch out, when the other end is closed, a termination byte appears, so .poll() returns True while (not self.parent_conn.closed) and (self.queue_sz > 0) and self.parent_conn.poll(0): result = self.parent_conn.recv() - assert isinstance(list(result.keys())[0], UUID) - self.result_cache.update(result) + id = result.id + assert id not in self.response_cache + self.response_cache[id] = result self.queue_sz -= 1 # prevent new tasks from being submitted @@ -99,10 +174,10 @@ def close(self): if not self._closed: if not self.main_proc: # send quit signal to child - self.parent_conn.send([None, True]) + self.parent_conn.send(QuitSignal()) else: # no child process to close - self.flush() + self.flush_results() self.child_conn.close() # keep track of closure, @@ -124,7 +199,7 @@ def join(self): finally: self.proc.close() - self.flush() + self.flush_results() self.parent_conn.close() @@ -157,13 +232,13 @@ def __init__(self, id: UUID, child: Child): assert isinstance(id, UUID) self.id = id self.child = child - self.result: Union[Tuple[Any, None], Tuple[None, Exception]] = None + self.response: Result = None - # assume the result is in the self.child.result_cache - # move it into self.result + # assume the result is in the self.child.response_cache + # move it into self.response def _load(self): - self.result = self.child.result_cache[self.id] - del self.child.result_cache[self.id] # prevent memory leak + self.response = self.child.response_cache[self.id] + del self.child.response_cache[self.id] # prevent memory leak # Return the result when it arrives. # If timeout is not None and the result does not arrive within timeout seconds @@ -171,15 +246,15 @@ def _load(self): # If the remote call raised an exception then that exception will be reraised by get(). # .get() must remember the result # and return it again multiple times - # delete it from the Child.result_cache to avoid memory leak + # delete it from the Child.response_cache to avoid memory leak def get(self, timeout=None): - if self.result is not None: - (response, ex) = self.result - if ex: - raise ex - else: - return response - elif self.id in self.child.result_cache: + if self.response is not None: + if isinstance(self.response, SuccessResponse): + return self.response.result + elif isinstance(self.response, FailResponse): + assert isinstance(self.response.exception, Exception) + raise self.response.exception + elif self.id in self.child.response_cache: self._load() return self.get(0) else: @@ -191,12 +266,12 @@ def get(self, timeout=None): # Wait until the result is available or until timeout seconds pass. def wait(self, timeout=None): - start_t = time() - if self.result is None: - self.child.flush() + if self.response is None: + start_t = time() + self.child.flush_results() # the result we want might not be the next result # it might be the 2nd or 3rd next - while (self.id not in self.child.result_cache) and \ + while (self.id not in self.child.response_cache) and \ ((timeout is None) or (time() - timeout < start_t)): if timeout is None: self.child.parent_conn.poll() @@ -205,23 +280,23 @@ def wait(self, timeout=None): remaining = timeout - elapsed_so_far self.child.parent_conn.poll(remaining) if self.child.parent_conn.poll(0): - self.child.flush() + self.child.flush_results() # Return whether the call has completed. def ready(self): - self.child.flush() - return self.result or (self.id in self.child.result_cache) + self.child.flush_results() + return self.response or (self.id in self.child.response_cache) # Return whether the call completed without raising an exception. # Will raise ValueError if the result is not ready. def successful(self): - if self.result is None: + if self.response is None: if not self.ready(): raise ValueError("Result is not ready") else: self._load() - return self.result[1] is None + return isinstance(self.response, SuccessResponse) # map_async and starmap_async return a single AsyncResult # which is a list of actual results @@ -231,7 +306,7 @@ def __init__(self, child_results: List[AsyncResult]): self.child_results = child_results self.result: List[Union[Tuple[Any, None], Tuple[None, Exception]]] = None - # assume the result is in the self.child.result_cache + # assume the result is in the self.child.response_cache # move it into self.result def _load(self): for c in self.child_results: @@ -243,28 +318,19 @@ def _load(self): # If the remote call raised an exception then that exception will be reraised by get(). # .get() must remember the result # and return it again multiple times - # delete it from the Child.result_cache to avoid memory leak + # delete it from the Child.response_cache to avoid memory leak def get(self, timeout=None): - if timeout: - end_t = time() + timeout - else: - end_t = None + self.wait(timeout) + assert self.ready() results = [] for (i, c) in enumerate(self.child_results): - # Consider cumulative timeout - if timeout is not None: - timeout_remaining = end_t - time() - else: - timeout_remaining = None - try: - result = c.get(timeout=timeout_remaining) + result = c.get(0) except Exception: print(f"Exception raised for {i}th task out of {len(self.child_results)}") - # terminate remaining children for c2 in self.child_results[i+1:]: - c2.child.terminate() + c2.child.flush_results() raise results.append(result) @@ -367,19 +433,10 @@ def apply_async(self, func, args=(), kwds=None, callback=None, error_callback=No if error_callback: raise NotImplementedError("error_callback not implemented") - if self._closed: - raise ValueError("Pool already closed") - if kwds is None: - kwds = {} - + results = self._apply_batch_async(func, [args], [kwds]) + assert len(results) == 1 + return results[0] - # choose the first idle process if there is one - # if not, choose the process with the shortest queue - for c in self.children: - c.flush() - min_q_sz = min(c.queue_sz for c in self.children) - c = random.choice([c for c in self.children if c.queue_sz <= min_q_sz]) - return c.submit(func, args, kwds) def map_async(self, func, iterable, chunksize=None, callback=None, error_callback=None) -> AsyncResult: return self.starmap_async(func, zip(iterable), chunksize, callback, error_callback) @@ -392,13 +449,43 @@ def starmap_async(self, func, iterable: Iterable[Iterable], chunksize=None, call raise NotImplementedError("Haven't implemented chunksizes. Infinite chunksize only.") if callback or error_callback: raise NotImplementedError("Haven't implemented callbacks") - results = [self.apply_async(func, args) for args in iterable] - # aggregate into one result + results = self._apply_batch_async(func, args_iterable=iterable, kwds_iterable=repeat({})) result = AsyncResultList(child_results=results) - return result + # like starmap, but has argument for keyword args + # so apply_async can call this + # (apply_async supports kwargs but starmap_async does not) + def _apply_batch_async(self, func, args_iterable: Iterable[Iterable], kwds_iterable: Iterable[Dict]) -> List[AsyncResult]: + if self._closed: + raise ValueError("Pool already closed") + + for c in self.children: + c.flush_results() + + results = [] + children_called = set() + for (args, kwds) in zip(args_iterable, kwds_iterable): + child = self._choose_child() # already flushed results + children_called.add(child) + result = child.submit(func, args, kwds, as_batch=True) + results.append(result) + + for child in children_called: + child.run_batch() + + return results + + + # return the child with the shortest queue + # if a tie, choose randomly + # You should call c.flush_results() first before calling this + def _choose_child(self) -> Child: + min_q_sz = min(c.queue_sz for c in self.children) + return random.choice([c for c in self.children if c.queue_sz <= min_q_sz]) + + def starmap(self, func, iterable: Iterable[Iterable], chunksize=None, callback=None, error_callback=None) -> List[Any]: if chunksize: raise NotImplementedError("chunksize not implemented") @@ -423,7 +510,7 @@ def starmap(self, func, iterable: Iterable[Iterable], chunksize=None, callback=N for child in self.children: if child.parent_conn in ready: assert child.parent_conn.poll() - child.flush() + child.flush_results() idle_children.add(child) diff --git a/lambda_multiprocessing/test_main.py b/lambda_multiprocessing/test_main.py index cad6192..2ec9316 100644 --- a/lambda_multiprocessing/test_main.py +++ b/lambda_multiprocessing/test_main.py @@ -5,6 +5,7 @@ from typing import Tuple, Optional from pathlib import Path import os +from functools import cache import boto3 from moto import mock_aws @@ -14,6 +15,8 @@ # if other processes are hogging CPU, make this bigger delta = 0.1 +SEC_PER_MIN = 60 + # some simple functions to run inside the child process def square(x): return x*x @@ -52,6 +55,16 @@ def test_standard_library(self): # add assertDuration class TestCase(unittest.TestCase): + + max_timeout = SEC_PER_MIN*2 + timeout_mgr = TimeoutManager(seconds=max_timeout) + + def setUp(self): + self.timeout_mgr.start() + + def tearDown(self): + self.timeout_mgr.stop() + # use like # with self.assertDuration(1, 2): # something @@ -73,10 +86,12 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.test.assertGreaterEqual(duration, min_t, f"Took less than {min_t}s to run") if max_t is not None: self.test.assertLessEqual(duration, max_t, f"Took more than {max_t}s to run") + return False return AssertDuration(self) + class TestAssertDuration(TestCase): def test(self): @@ -326,12 +341,14 @@ class TestMapAsync(TestCase): pool_generator = Pool def test_simple(self): - args = range(5) - with self.pool_generator() as p: - actual = p.map_async(square, args) - self.assertIsInstance(actual, (AsyncResult, multiprocessing.pool.AsyncResult)) - results = actual.get() - self.assertEqual(results, [square(e) for e in args]) + num_payloads = 5 + args = range(num_payloads) + for num_procs in [(num_payloads+1), (num_payloads-1)]: + with self.pool_generator(num_procs) as p: + actual = p.map_async(square, args) + self.assertIsInstance(actual, (AsyncResult, multiprocessing.pool.AsyncResult)) + results = actual.get() + self.assertEqual(results, [square(e) for e in args]) def test_duration(self): sleep_duration = 0.5 @@ -526,24 +543,25 @@ def test_map_async_deadlock(self): # even when there are multiple tasks per child # with payloads bigger than the buffer def test_nonblocking(self): - sleep_duration = 0.5 + sleep_duration = 2 n_procs = 2 num_tasks_per_proc = 3 expected_wall_time = sleep_duration * num_tasks_per_proc - # use a big payload, to ensure the buffer fills up from the first arg + + args = [(self.generate_big_data(), sleep_duration)] * (num_tasks_per_proc * n_procs) with Pool(n_procs) as p: try: with TimeoutManager(sleep_duration * num_tasks_per_proc * 2, "This Library's map_async deadlocked"): - with self.assertDuration(max_t=delta): - results = p.map_async(return_with_sleep, [(self.generate_big_data(), sleep_duration)] * (num_tasks_per_proc * n_procs)) - - results.get() - except TestTimeoutException: + with self.assertDuration(max_t=sleep_duration*0.5): + results = p.map_async(return_with_sleep, args) + results.get(expected_wall_time*1.5) + except Exception: p.terminate() raise @classmethod - def generate_big_data(cls, sz=2**26) -> bytes: + @cache + def generate_big_data(cls, sz=2**24) -> bytes: return 'x' * sz diff --git a/lambda_multiprocessing/timeout.py b/lambda_multiprocessing/timeout.py index 3c58fb6..4b95b96 100644 --- a/lambda_multiprocessing/timeout.py +++ b/lambda_multiprocessing/timeout.py @@ -15,19 +15,27 @@ def __init__(self, seconds: int, description = "Test timed out"): self.old_handler = None def __enter__(self): + self.start() + return self + + def start(self): + assert self.old_handler is None, "Alarm already started" self.old_handler = signal.signal(signal.SIGALRM, self.timeout_handler) signal.alarm(self.seconds) - return self def timeout_handler(self, signum, frame): raise TestTimeoutException(self.description) - def __exit__(self, exc_type, exc_value, traceback): - # Disable the alarm - signal.alarm(0) - # Restore old signal handler - signal.signal(signal.SIGALRM, self.old_handler) + def stop(self): + if self.old_handler is not None: + # Disable the alarm + signal.alarm(0) + # Restore old signal handler + signal.signal(signal.SIGALRM, self.old_handler) + self.old_handler = None + def __exit__(self, exc_type, exc_value, traceback): + self.stop() if exc_type is TestTimeoutException: return False # Let the TestTimeoutException exception propagate From 4b91430f4efd241c4c87de28a5fd6d1fe00b7f05 Mon Sep 17 00:00:00 2001 From: Matthew Davis Date: Wed, 13 Nov 2024 17:57:00 +0100 Subject: [PATCH 08/12] explain deadlock solution --- README.md | 33 ++++++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index a1c0d01..9a24223 100644 --- a/README.md +++ b/README.md @@ -123,24 +123,22 @@ You probably don't need to touch those. When you `__enter__` the pool, it creates several `Child`s. These contain the actual child `Process`es, -plus a duplex pipe to send tasks to the child and get results back. +plus a duplex pipe to send `Task`s to the child and get `Response`s back. The child process just waits for payloads to appear in the pipe. It grabs the function and arguments from it, does the work, catches any exception, then sends the exception or result back through the pipe. -Note that the function that the client gives to this library might return an Exception for some reason, -so we return either `[result, None]` or `[None, Exception]`, to differentiate. +Note that the arguments and return functions to this function could be anything. +(It's even possible that the function _returns_ an Exception, instead of throwing one.) -To close everything up when we're done, the parent sends a payload with a different structure (`payload[-1] == True`) -and then the child will gracefully exit. +To close everything up when we're done, the parent sends a different subclass of `Request`, which is `QuitSignal`. Upon receiving this, the child will gracefully exit. We keep a counter of how many tasks we've given to the child, minus how many results we've got back. When assigning work, we give it to a child chosen randomly from the set of children whose counter value is smallest. (i.e. shortest backlog) When passing the question and answer to the child and back, we pass around a UUID. -This is because the client may submit two tasks with apply_async, then request the result for the second one, -before the first. +This is because the client may submit two tasks with apply_async, then request the result for the second one, before the first. We can't assume that the next result coming back from the child is the one we want, since each child can have a backlog of a few tasks. @@ -150,3 +148,24 @@ and passing pipes through pipes is unusually slow on low-memory Lambda functions Note that `multiprocessing.Queue` doesn't work in Lambda functions. So we can't use that to distribute work amongst the children. + +### Deadlock + +Note that we must be careful about a particular deadlocking scenario, +described in [this issue](https://github.com/mdavis-xyz/lambda_multiprocessing/issues/17#issuecomment-2468560515) + +Writes to `Pipe`s are usually non-blocking. However if you're writing something large +(>90kB, IIRC) the Pipe's buffer will fill up. The writer will block, +waiting for the reader at the other end to start reading. + +The situation which previously occured was: + +* parent sends a task to the child +* child reads the task from the pipe, and starts working on it +* parent immediately sends the next task, which blocks because the object is larger than the buffer +* child tries sending the result from the first task, which blocks because the result is larger than the buffer + +In this situation both processes are mid-way through writing something, and won't continue until the other starts reading from the other pipe. +Solutions like having the parent/child check if there's data to receive before sending data won't work because those two steps are not atomic. (I did not want to introduce locks, for performance reasons.) + +The solution is that the child will read all available `Task` payloads sent from the parent, into a local buffer, without commencing work on them. Then once it receives a `RunBatchSignal`, it stops reading anything else from the parent, and starts work on the tasks in the buffer. By batching tasks in this way, we can prevent the deadlock, and also ensure that the `*async` functions are non-blocking for large payloads. From 39bf307af4e6d9650360178c72c8a1271e7b8fe6 Mon Sep 17 00:00:00 2001 From: Matthew Davis Date: Wed, 13 Nov 2024 17:59:25 +0100 Subject: [PATCH 09/12] fix bug with class subclass syntax --- lambda_multiprocessing/main.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/lambda_multiprocessing/main.py b/lambda_multiprocessing/main.py index 63a3d3f..cac74a5 100644 --- a/lambda_multiprocessing/main.py +++ b/lambda_multiprocessing/main.py @@ -12,7 +12,7 @@ class Request: pass -class Task[Request]: +class Task(Request): func: Callable args: List kwds: Dict @@ -24,10 +24,10 @@ def __init__(self, func, args=(), kwds={}, id=None): self.kwds = kwds or {} self.id = id or uuid4() -class QuitSignal[Request]: +class QuitSignal(Request): pass -class RunBatchSignal[Request]: +class RunBatchSignal(Request): pass # this is what we send back through the pipe from child to parent @@ -35,7 +35,7 @@ class Response: id: UUID pass -class SuccessResponse: +class SuccessResponse(Response): result: Any def __init__(self, id: UUID, result: Any): self.result = result @@ -43,7 +43,7 @@ def __init__(self, id: UUID, result: Any): # processing the task raised an exception # we save the exception in this object -class FailResponse: +class FailResponse(Response): exception: Exception def __init__(self, id: UUID, exception: Exception): self.id = id From 4c3cad5a07b199e5735330c4ac2b7c08a0c45fb5 Mon Sep 17 00:00:00 2001 From: Matthew Davis Date: Wed, 13 Nov 2024 17:59:33 +0100 Subject: [PATCH 10/12] add python 3.13 to unit test CICD --- .github/workflows/unit.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/unit.yml b/.github/workflows/unit.yml index db88217..4533bb3 100644 --- a/.github/workflows/unit.yml +++ b/.github/workflows/unit.yml @@ -17,7 +17,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"] runs-on: ubuntu-latest steps: From 1a3f429d7649eaf1fa3d2a1347a8bfe4227c7aa4 Mon Sep 17 00:00:00 2001 From: Matthew Davis Date: Wed, 13 Nov 2024 18:07:32 +0100 Subject: [PATCH 11/12] fix imports so unit tests work from root and library dir --- lambda_multiprocessing/test_main.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lambda_multiprocessing/test_main.py b/lambda_multiprocessing/test_main.py index 2ec9316..37f159f 100644 --- a/lambda_multiprocessing/test_main.py +++ b/lambda_multiprocessing/test_main.py @@ -9,7 +9,7 @@ import boto3 from moto import mock_aws -from timeout import TimeoutManager, TestTimeoutException +from lambda_multiprocessing.timeout import TimeoutManager, TestTimeoutException # add an overhead for duration when asserting the duration of child processes # if other processes are hogging CPU, make this bigger @@ -39,7 +39,7 @@ def return_with_sleep(x, delay=0.3): def _raise(ex: Optional[Exception]): if ex: - raise ex + raise exfrom .timeout class ExceptionA(Exception): pass @@ -73,7 +73,7 @@ def tearDown(self): # For a potential eternal task, use timeout.TimeoutManager def assertDuration(self, min_t=None, max_t=None): class AssertDuration: - def __init__(self, test): + def __init__(self, test: unittest.TestCase): self.test = test def __enter__(self): self.start_t = time() From 68988771369443dad1ae37691d68275bbca589df Mon Sep 17 00:00:00 2001 From: Matthew Davis Date: Wed, 13 Nov 2024 18:31:33 +0100 Subject: [PATCH 12/12] make test script compatible with python 3.8 --- lambda_multiprocessing/test_main.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/lambda_multiprocessing/test_main.py b/lambda_multiprocessing/test_main.py index 37f159f..50df0b9 100644 --- a/lambda_multiprocessing/test_main.py +++ b/lambda_multiprocessing/test_main.py @@ -5,12 +5,23 @@ from typing import Tuple, Optional from pathlib import Path import os -from functools import cache +import sys + import boto3 from moto import mock_aws from lambda_multiprocessing.timeout import TimeoutManager, TestTimeoutException +if sys.version_info < (3, 9): + # functools.cache was added in 3.9 + # define an empty decorator that doesn't do anything + # (our usage of the cache isn't essential) + def cache(func): + return func +else: + # Import the cache function from functools for Python 3.9 and above + from functools import cache + # add an overhead for duration when asserting the duration of child processes # if other processes are hogging CPU, make this bigger delta = 0.1 @@ -39,7 +50,7 @@ def return_with_sleep(x, delay=0.3): def _raise(ex: Optional[Exception]): if ex: - raise exfrom .timeout + raise ex class ExceptionA(Exception): pass