diff --git a/.github/workflows/unit.yml b/.github/workflows/unit.yml index db88217..4533bb3 100644 --- a/.github/workflows/unit.yml +++ b/.github/workflows/unit.yml @@ -17,7 +17,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"] runs-on: ubuntu-latest steps: diff --git a/README.md b/README.md index 559bfa5..9a24223 100644 --- a/README.md +++ b/README.md @@ -113,31 +113,32 @@ Then you can run the unit tests with: python3 -m unittest ``` -`CICD` is for the GitHub Actions which run unit tests and integration tests. +The tests themselves are defined in `lambda_multiprocessing/test_main.py`. + +`CICD` is for the GitHub Actions which run the unit tests and integration tests. You probably don't need to touch those. + ## Design When you `__enter__` the pool, it creates several `Child`s. These contain the actual child `Process`es, -plus a duplex pipe to send tasks to the child and get results back. +plus a duplex pipe to send `Task`s to the child and get `Response`s back. The child process just waits for payloads to appear in the pipe. It grabs the function and arguments from it, does the work, catches any exception, then sends the exception or result back through the pipe. -Note that the function that the client gives to this library might return an Exception for some reason, -so we return either `[result, None]` or `[None, Exception]`, to differentiate. +Note that the arguments and return functions to this function could be anything. +(It's even possible that the function _returns_ an Exception, instead of throwing one.) -To close everything up when we're done, the parent sends a payload with a different structure (`payload[-1] == True`) -and then the child will gracefully exit. +To close everything up when we're done, the parent sends a different subclass of `Request`, which is `QuitSignal`. Upon receiving this, the child will gracefully exit. We keep a counter of how many tasks we've given to the child, minus how many results we've got back. When assigning work, we give it to a child chosen randomly from the set of children whose counter value is smallest. (i.e. shortest backlog) When passing the question and answer to the child and back, we pass around a UUID. -This is because the client may submit two tasks with apply_async, then request the result for the second one, -before the first. +This is because the client may submit two tasks with apply_async, then request the result for the second one, before the first. We can't assume that the next result coming back from the child is the one we want, since each child can have a backlog of a few tasks. @@ -147,3 +148,24 @@ and passing pipes through pipes is unusually slow on low-memory Lambda functions Note that `multiprocessing.Queue` doesn't work in Lambda functions. So we can't use that to distribute work amongst the children. + +### Deadlock + +Note that we must be careful about a particular deadlocking scenario, +described in [this issue](https://github.com/mdavis-xyz/lambda_multiprocessing/issues/17#issuecomment-2468560515) + +Writes to `Pipe`s are usually non-blocking. However if you're writing something large +(>90kB, IIRC) the Pipe's buffer will fill up. The writer will block, +waiting for the reader at the other end to start reading. + +The situation which previously occured was: + +* parent sends a task to the child +* child reads the task from the pipe, and starts working on it +* parent immediately sends the next task, which blocks because the object is larger than the buffer +* child tries sending the result from the first task, which blocks because the result is larger than the buffer + +In this situation both processes are mid-way through writing something, and won't continue until the other starts reading from the other pipe. +Solutions like having the parent/child check if there's data to receive before sending data won't work because those two steps are not atomic. (I did not want to introduce locks, for performance reasons.) + +The solution is that the child will read all available `Task` payloads sent from the parent, into a local buffer, without commencing work on them. Then once it receives a `RunBatchSignal`, it stops reading anything else from the parent, and starts work on the tasks in the buffer. By batching tasks in this way, we can prevent the deadlock, and also ensure that the `*async` functions are non-blocking for large payloads. diff --git a/lambda_multiprocessing/main.py b/lambda_multiprocessing/main.py index a88bf6e..cac74a5 100644 --- a/lambda_multiprocessing/main.py +++ b/lambda_multiprocessing/main.py @@ -1,10 +1,53 @@ from multiprocessing import TimeoutError, Process, Pipe from multiprocessing.connection import Connection -from typing import Any, Iterable, List, Dict, Tuple, Union +from typing import Any, Iterable, List, Dict, Tuple, Union, Optional, Callable from uuid import uuid4, UUID import random import os from time import time +from select import select +from itertools import repeat + +# This is what we send down the pipe from parent to child +class Request: + pass + +class Task(Request): + func: Callable + args: List + kwds: Dict + id: UUID + + def __init__(self, func, args=(), kwds={}, id=None): + self.func = func + self.args = args + self.kwds = kwds or {} + self.id = id or uuid4() + +class QuitSignal(Request): + pass + +class RunBatchSignal(Request): + pass + +# this is what we send back through the pipe from child to parent +class Response: + id: UUID + pass + +class SuccessResponse(Response): + result: Any + def __init__(self, id: UUID, result: Any): + self.result = result + self.id = id + +# processing the task raised an exception +# we save the exception in this object +class FailResponse(Response): + exception: Exception + def __init__(self, id: UUID, exception: Exception): + self.id = id + self.exception = exception class Child: proc: Process @@ -17,12 +60,12 @@ class Child: # does not include the termination command from parent to child queue_sz: int = 0 - # parent_conn.send() to give stuff to the child + # parent_conn.send(Request) to give stuff to the child # parent_conn.recv() to get results back from child parent_conn: Connection child_conn: Connection - result_cache: Dict[UUID, Tuple[Any, Exception]] = {} + response_cache: Dict[UUID, Tuple[Any, Exception]] = {} _closed: bool = False @@ -47,47 +90,81 @@ def __init__(self, main_proc=False): # {id: (None, err)} if func raised exception err # [None, True] -> exit gracefully (write nothing to the pipe) def spin(self) -> None: + quit_signal = False while True: - (job, quit_signal) = self.child_conn.recv() + req_buf = [] + # first read in tasks until we get a pause/quit signal + while True: + request = self.child_conn.recv() + if isinstance(request, Task): + req_buf.append(request) + elif isinstance(request, QuitSignal): + # don't quit yet. Finish what's in the buffer. + quit_signal |= True + break # stop reading new stuff from the pipe + elif isinstance(request, RunBatchSignal): + # stop reading from Pipe, process what's in the request buffer + break + result_buf = [] + for req in req_buf: + assert isinstance(req, Task) + # process the result + result = self._do_work(req) + result_buf.append(result) + + # send all the results + for result in result_buf: + self.child_conn.send(result) + if quit_signal: break - else: - (id, func, args, kwds) = job - result = self._do_work(id, func, args, kwds) - self.child_conn.send(result) self.child_conn.close() - def _do_work(self, id, func, args, kwds) -> Union[Tuple[Any, None], Tuple[None, Exception]]: + # applies the function, catching any exception if it occurs + def _do_work(self, task: Task) -> Response: try: - ret = {id: (func(*args, **kwds), None)} + result = task.func(*task.args, **task.kwds) except Exception as e: # how to handle KeyboardInterrupt? - ret = {id: (None, e)} - assert isinstance(list(ret.keys())[0], UUID) - return ret - - def submit(self, func, args=(), kwds=None) -> 'AsyncResult': + resp = FailResponse(id=task.id, exception=e) + else: + resp = SuccessResponse(id=task.id, result=result) + return resp + + # this sends a task to the child + # if as_batch=False, the child will start work on it immediately + # If as_batch=True, the child will load this into it's local buffer + # but won't start processing until we send a RunBatchSignal with self.run_batch() + def submit(self, func, args=(), kwds=None, as_batch=False) -> 'AsyncResult': if self._closed: raise ValueError("Cannot submit tasks after closure") - if kwds is None: - kwds = {} - id = uuid4() - self.parent_conn.send([(id, func, args, kwds), None]) + request = Task(func=func, args=args, kwds=kwds) + + self.parent_conn.send(request) if self.main_proc: self.child_conn.recv() - ret = self._do_work(id, func, args, kwds) + ret = self._do_work(request) self.child_conn.send(ret) + elif not as_batch: + # treat this as a batch of 1 + self.run_batch() self.queue_sz += 1 - return AsyncResult(id=id, child=self) + return AsyncResult(id=request.id, child=self) + + # non-blocking + # Tells the child to start commencing work on all tasks sent up until now + def run_batch(self): + self.parent_conn.send(RunBatchSignal()) # grab all results in the pipe from child to parent - # save them to self.result_cache - def flush(self): + # save them to self.response_cache + def flush_results(self): # watch out, when the other end is closed, a termination byte appears, so .poll() returns True while (not self.parent_conn.closed) and (self.queue_sz > 0) and self.parent_conn.poll(0): result = self.parent_conn.recv() - assert isinstance(list(result.keys())[0], UUID) - self.result_cache.update(result) + id = result.id + assert id not in self.response_cache + self.response_cache[id] = result self.queue_sz -= 1 # prevent new tasks from being submitted @@ -97,10 +174,10 @@ def close(self): if not self._closed: if not self.main_proc: # send quit signal to child - self.parent_conn.send([None, True]) + self.parent_conn.send(QuitSignal()) else: # no child process to close - self.flush() + self.flush_results() self.child_conn.close() # keep track of closure, @@ -122,7 +199,7 @@ def join(self): finally: self.proc.close() - self.flush() + self.flush_results() self.parent_conn.close() @@ -155,13 +232,13 @@ def __init__(self, id: UUID, child: Child): assert isinstance(id, UUID) self.id = id self.child = child - self.result: Union[Tuple[Any, None], Tuple[None, Exception]] = None + self.response: Result = None - # assume the result is in the self.child.result_cache - # move it into self.result + # assume the result is in the self.child.response_cache + # move it into self.response def _load(self): - self.result = self.child.result_cache[self.id] - del self.child.result_cache[self.id] # prevent memory leak + self.response = self.child.response_cache[self.id] + del self.child.response_cache[self.id] # prevent memory leak # Return the result when it arrives. # If timeout is not None and the result does not arrive within timeout seconds @@ -169,15 +246,15 @@ def _load(self): # If the remote call raised an exception then that exception will be reraised by get(). # .get() must remember the result # and return it again multiple times - # delete it from the Child.result_cache to avoid memory leak + # delete it from the Child.response_cache to avoid memory leak def get(self, timeout=None): - if self.result is not None: - (response, ex) = self.result - if ex: - raise ex - else: - return response - elif self.id in self.child.result_cache: + if self.response is not None: + if isinstance(self.response, SuccessResponse): + return self.response.result + elif isinstance(self.response, FailResponse): + assert isinstance(self.response.exception, Exception) + raise self.response.exception + elif self.id in self.child.response_cache: self._load() return self.get(0) else: @@ -189,12 +266,12 @@ def get(self, timeout=None): # Wait until the result is available or until timeout seconds pass. def wait(self, timeout=None): - start_t = time() - if self.result is None: - self.child.flush() + if self.response is None: + start_t = time() + self.child.flush_results() # the result we want might not be the next result # it might be the 2nd or 3rd next - while (self.id not in self.child.result_cache) and \ + while (self.id not in self.child.response_cache) and \ ((timeout is None) or (time() - timeout < start_t)): if timeout is None: self.child.parent_conn.poll() @@ -203,23 +280,86 @@ def wait(self, timeout=None): remaining = timeout - elapsed_so_far self.child.parent_conn.poll(remaining) if self.child.parent_conn.poll(0): - self.child.flush() + self.child.flush_results() # Return whether the call has completed. def ready(self): - self.child.flush() - return self.result or (self.id in self.child.result_cache) + self.child.flush_results() + return self.response or (self.id in self.child.response_cache) # Return whether the call completed without raising an exception. # Will raise ValueError if the result is not ready. def successful(self): - if self.result is None: + if self.response is None: if not self.ready(): raise ValueError("Result is not ready") else: self._load() - return self.result[1] is None + return isinstance(self.response, SuccessResponse) + +# map_async and starmap_async return a single AsyncResult +# which is a list of actual results +# This class aggregates many AsyncResult into one +class AsyncResultList(AsyncResult): + def __init__(self, child_results: List[AsyncResult]): + self.child_results = child_results + self.result: List[Union[Tuple[Any, None], Tuple[None, Exception]]] = None + + # assume the result is in the self.child.response_cache + # move it into self.result + def _load(self): + for c in self.child_results: + c._load() + + # Return the result when it arrives. + # If timeout is not None and the result does not arrive within timeout seconds + # then multiprocessing.TimeoutError is raised. + # If the remote call raised an exception then that exception will be reraised by get(). + # .get() must remember the result + # and return it again multiple times + # delete it from the Child.response_cache to avoid memory leak + def get(self, timeout=None): + self.wait(timeout) + assert self.ready() + + results = [] + for (i, c) in enumerate(self.child_results): + try: + result = c.get(0) + except Exception: + print(f"Exception raised for {i}th task out of {len(self.child_results)}") + for c2 in self.child_results[i+1:]: + c2.child.flush_results() + raise + + results.append(result) + return results + + # Wait until the result is available or until timeout seconds pass. + def wait(self, timeout=None): + + if timeout: + end_t = time() + timeout + else: + end_t = None + for c in self.child_results: + # Consider cumulative timeout + if timeout is not None: + timeout_remaining = end_t - time() + else: + timeout_remaining = None + + c.wait(timeout_remaining) + + # Return whether the call has completed. + def ready(self): + return all(c.ready() for c in self.child_results) + + # Return whether the call completed without raising an exception. + # Will raise ValueError if the result is not ready. + def successful(self): + return all(c.successful() for c in self.child_results) class Pool: def __init__(self, processes=None, initializer=None, initargs=None, maxtasksperchild=None, context=None): @@ -293,36 +433,106 @@ def apply_async(self, func, args=(), kwds=None, callback=None, error_callback=No if error_callback: raise NotImplementedError("error_callback not implemented") - if self._closed: - raise ValueError("Pool already closed") - if kwds is None: - kwds = {} - + results = self._apply_batch_async(func, [args], [kwds]) + assert len(results) == 1 + return results[0] - # choose the first idle process if there is one - # if not, choose the process with the shortest queue - for c in self.children: - c.flush() - min_q_sz = min(c.queue_sz for c in self.children) - c = random.choice([c for c in self.children if c.queue_sz <= min_q_sz]) - return c.submit(func, args, kwds) - def map_async(self, func, iterable, chunksize=None, callback=None, error_callback=None) -> List[AsyncResult]: + def map_async(self, func, iterable, chunksize=None, callback=None, error_callback=None) -> AsyncResult: return self.starmap_async(func, zip(iterable), chunksize, callback, error_callback) def map(self, func, iterable, chunksize=None, callback=None, error_callback=None) -> List: return self.starmap(func, zip(iterable), chunksize, callback, error_callback) - def starmap_async(self, func, iterable: Iterable[Iterable], chunksize=None, callback=None, error_callback=None) -> List[AsyncResult]: + def starmap_async(self, func, iterable: Iterable[Iterable], chunksize=None, callback=None, error_callback=None) -> AsyncResult: if chunksize: raise NotImplementedError("Haven't implemented chunksizes. Infinite chunksize only.") if callback or error_callback: raise NotImplementedError("Haven't implemented callbacks") - return [self.apply_async(func, args) for args in iterable] + + results = self._apply_batch_async(func, args_iterable=iterable, kwds_iterable=repeat({})) + result = AsyncResultList(child_results=results) + return result + + # like starmap, but has argument for keyword args + # so apply_async can call this + # (apply_async supports kwargs but starmap_async does not) + def _apply_batch_async(self, func, args_iterable: Iterable[Iterable], kwds_iterable: Iterable[Dict]) -> List[AsyncResult]: + if self._closed: + raise ValueError("Pool already closed") + + for c in self.children: + c.flush_results() + + results = [] + children_called = set() + for (args, kwds) in zip(args_iterable, kwds_iterable): + child = self._choose_child() # already flushed results + children_called.add(child) + result = child.submit(func, args, kwds, as_batch=True) + results.append(result) + + for child in children_called: + child.run_batch() + + return results + + + # return the child with the shortest queue + # if a tie, choose randomly + # You should call c.flush_results() first before calling this + def _choose_child(self) -> Child: + min_q_sz = min(c.queue_sz for c in self.children) + return random.choice([c for c in self.children if c.queue_sz <= min_q_sz]) + def starmap(self, func, iterable: Iterable[Iterable], chunksize=None, callback=None, error_callback=None) -> List[Any]: - results = self.starmap_async(func, iterable, chunksize, callback, error_callback) - return [r.get() for r in results] + if chunksize: + raise NotImplementedError("chunksize not implemented") + if callback: + raise NotImplementedError("callback not implemented") + if error_callback: + raise NotImplementedError("error_callback not implemented") + + idle_children = set(self.children) + ids = [] + pending_results: Dict[UUID, AsyncResult] = {} + + for args in iterable: + if not idle_children: + # wait for a child to become idle + # by waiting for any of the pipes from children to become readable + ready, _, _ = select([c.parent_conn for c in self.children], [], []) + + # at least one child is idle. + # check all children, read their last result from the pipe + # then issue the new task + for child in self.children: + if child.parent_conn in ready: + assert child.parent_conn.poll() + child.flush_results() + idle_children.add(child) + + + child = idle_children.pop() + async_result = child.submit(func, args) + pending_results[async_result.id] = async_result + ids.append(async_result.id) + + + if len(idle_children) < len(self.children): + # if at least one child is still working + # wait with select + ready, _, _ = select([c.parent_conn for c in self.children if c not in idle_children], [], []) + + # get all the results + # re-arranging the order + results = [] + for (i, id) in enumerate(ids): + result = pending_results[id].get() + results.append(result) + + return results def imap(self, func, iterable, chunksize=None): raise NotImplementedError("Only normal apply, map, starmap and their async equivilents have been implemented") diff --git a/lambda_multiprocessing/test_main.py b/lambda_multiprocessing/test_main.py index 9a522c7..50df0b9 100644 --- a/lambda_multiprocessing/test_main.py +++ b/lambda_multiprocessing/test_main.py @@ -1,18 +1,33 @@ import unittest -import multiprocessing +import multiprocessing, multiprocessing.pool from lambda_multiprocessing import Pool, TimeoutError, AsyncResult from time import time, sleep -from typing import Tuple +from typing import Tuple, Optional from pathlib import Path import os +import sys + import boto3 from moto import mock_aws +from lambda_multiprocessing.timeout import TimeoutManager, TestTimeoutException + +if sys.version_info < (3, 9): + # functools.cache was added in 3.9 + # define an empty decorator that doesn't do anything + # (our usage of the cache isn't essential) + def cache(func): + return func +else: + # Import the cache function from functools for Python 3.9 and above + from functools import cache # add an overhead for duration when asserting the duration of child processes # if other processes are hogging CPU, make this bigger delta = 0.1 +SEC_PER_MIN = 60 + # some simple functions to run inside the child process def square(x): return x*x @@ -29,6 +44,19 @@ def divide(a, b): def return_args_kwargs(*args, **kwargs): return {'args': args, 'kwargs': kwargs} +def return_with_sleep(x, delay=0.3): + sleep(delay) + return x + +def _raise(ex: Optional[Exception]): + if ex: + raise ex + +class ExceptionA(Exception): + pass +class ExceptionB(Exception): + pass + class TestStdLib(unittest.TestCase): @unittest.skip('Need to set up to remove /dev/shm') def test_standard_library(self): @@ -38,13 +66,25 @@ def test_standard_library(self): # add assertDuration class TestCase(unittest.TestCase): + + max_timeout = SEC_PER_MIN*2 + timeout_mgr = TimeoutManager(seconds=max_timeout) + + def setUp(self): + self.timeout_mgr.start() + + def tearDown(self): + self.timeout_mgr.stop() + # use like # with self.assertDuration(1, 2): # something # to assert that something takes between 1 to 2 seconds to run + # If the task takes forever, this assertion will not be raised + # For a potential eternal task, use timeout.TimeoutManager def assertDuration(self, min_t=None, max_t=None): class AssertDuration: - def __init__(self, test): + def __init__(self, test: unittest.TestCase): self.test = test def __enter__(self): self.start_t = time() @@ -57,10 +97,12 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.test.assertGreaterEqual(duration, min_t, f"Took less than {min_t}s to run") if max_t is not None: self.test.assertLessEqual(duration, max_t, f"Took more than {max_t}s to run") + return False return AssertDuration(self) + class TestAssertDuration(TestCase): def test(self): @@ -78,22 +120,23 @@ def test(self): sleep(t) class TestApply(TestCase): + pool_generator = Pool def test_simple(self): args = range(5) - with Pool() as p: + with self.pool_generator() as p: actual = [p.apply(square, (x,)) for x in args] with multiprocessing.Pool() as p: expected = [p.apply(square, (x,)) for x in args] self.assertEqual(actual, expected) def test_two_args(self): - with Pool() as p: + with self.pool_generator() as p: actual = p.apply(sum_two, (3, 4)) self.assertEqual(actual, sum_two(3, 4)) def test_kwargs(self): - with Pool() as p: + with self.pool_generator() as p: actual = p.apply(return_args_kwargs, (1, 2), {'x': 'X', 'y': 'Y'}) expected = { 'args': (1,2), @@ -104,10 +147,16 @@ def test_kwargs(self): } self.assertEqual(actual, expected) +# rerun all the tests with stdlib +# to confirm we have the same behavior +class TestApplyStdLib(TestApply): + pool_generator = multiprocessing.Pool + class TestApplyAsync(TestCase): + pool_generator = Pool def test_result(self): - with Pool() as p: + with self.pool_generator() as p: r = p.apply_async(square, (2,)) result = r.get(2) self.assertEqual(result, square(2)) @@ -115,7 +164,7 @@ def test_result(self): def test_twice(self): args = range(5) # now try twice with the same pool - with Pool() as p: + with self.pool_generator() as p: ret = [p.apply_async(square, (x,)) for x in args] ret.extend(p.apply_async(square, (x,)) for x in args) @@ -124,7 +173,7 @@ def test_twice(self): def test_time(self): # test the timing to confirm it's in parallel n = 4 - with Pool(n) as p: + with self.pool_generator(n) as p: with self.assertDuration(min_t=n-1, max_t=(n-1)+delta): with self.assertDuration(max_t=1): ret = [p.apply_async(sleep, (r,)) for r in range(n)] @@ -135,14 +184,14 @@ def test_unclean_exit(self): # .get after __exit__, but process finishes before __exit__ with self.assertRaises(multiprocessing.TimeoutError): t = 1 - with Pool() as p: + with self.pool_generator() as p: r = p.apply_async(square, (t,)) sleep(1) r.get() with self.assertRaises(multiprocessing.TimeoutError): t = 1 - with Pool() as p: + with self.pool_generator() as p: r = p.apply_async(square, (t,)) sleep(1) r.get(1) @@ -150,26 +199,26 @@ def test_unclean_exit(self): # __exit__ before process finishes with self.assertRaises(multiprocessing.TimeoutError): t = 1 - with Pool() as p: + with self.pool_generator() as p: r = p.apply_async(sleep, (t,)) r.get(t+1) # .get() with arg after __exit__ with self.assertRaises(multiprocessing.TimeoutError): t = 1 - with Pool() as p: + with self.pool_generator() as p: r = p.apply_async(sleep, (t,)) r.get() # .get() without arg after __exit__ def test_get_not_ready_a(self): t = 2 - with Pool() as p: + with self.pool_generator() as p: r = p.apply_async(sleep, (t,)) with self.assertRaises(multiprocessing.TimeoutError): r.get(t-1) # result not ready get def test_get_not_ready_b(self): t = 2 - with Pool() as p: + with self.pool_generator() as p: # check same exception exists from main r = p.apply_async(sleep, (t,)) with self.assertRaises(TimeoutError): @@ -177,7 +226,7 @@ def test_get_not_ready_b(self): def test_get_not_ready_c(self): t = 2 - with Pool() as p: + with self.pool_generator() as p: r = p.apply_async(sleep, (t,)) sleep(1) self.assertFalse(r.ready()) @@ -185,7 +234,7 @@ def test_get_not_ready_c(self): self.assertTrue(r.ready()) def test_wait(self): - with Pool() as p: + with self.pool_generator() as p: r = p.apply_async(square, (1,)) with self.assertDuration(max_t=delta): r.wait() @@ -209,13 +258,13 @@ def test_wait(self): ret = r.get(0) def test_get_twice(self): - with Pool() as p: + with self.pool_generator() as p: r = p.apply_async(square, (2,)) self.assertEqual(r.get(), square(2)) self.assertEqual(r.get(), square(2)) def test_successful(self): - with Pool() as p: + with self.pool_generator() as p: r = p.apply_async(square, (1,)) sleep(delta) self.assertTrue(r.successful()) @@ -233,13 +282,13 @@ def test_successful(self): self.assertFalse(r.successful()) def test_two_args(self): - with Pool() as p: + with self.pool_generator() as p: ret = p.apply_async(sum_two, (1, 2)) ret.wait() self.assertEqual(ret.get(), sum_two(1,2)) def test_kwargs(self): - with Pool() as p: + with self.pool_generator() as p: actual = p.apply_async(return_args_kwargs, (1, 2), {'x': 'X', 'y': 'Y'}).get() expected = { 'args': (1,2), @@ -252,14 +301,21 @@ def test_kwargs(self): def test_error_handling(self): with self.assertRaises(AssertionError): - with Pool() as p: + with self.pool_generator() as p: r = p.apply_async(fail, (1,)) r.get() +# retrun tests with standard library +# to confirm our behavior matches theirs +class TestApplyAsyncStdLib(TestApplyAsync): + pool_generator = multiprocessing.Pool + class TestMap(TestCase): + pool_generator = Pool + def test_simple(self): args = range(5) - with Pool() as p: + with self.pool_generator() as p: actual = p.map(square, args) self.assertIsInstance(actual, list) expected = [square(x) for x in args] @@ -268,111 +324,136 @@ def test_simple(self): def test_duration(self): n = 2 - with Pool(n) as p: + with self.pool_generator(n) as p: with self.assertDuration(min_t=(n-1)-delta, max_t=(n+1)+delta): p.map(sleep, range(n)) def test_error_handling(self): - with Pool() as p: + with self.pool_generator() as p: with self.assertRaises(AssertionError): p.map(fail, range(2)) @unittest.skip('Need to implement chunking to fix this') def test_long_iter(self): - with Pool() as p: + with self.pool_generator() as p: p.map(square, range(10**3)) def test_without_with(self): - # check that the standard library Pool - # can do .map() without `with` - p = multiprocessing.Pool(3) - ret = p.map(square, [1,2]) - self.assertEqual(ret, [1,2*2]) - p.close() - p.join() - - # now check that this library can do it - p = Pool(3) + p = self.pool_generator(3) ret = p.map(square, [1,2]) self.assertEqual(ret, [1,2*2]) p.close() p.join() - +class TestMapStdLib(TestMap): + pool_generator = multiprocessing.Pool class TestMapAsync(TestCase): - def test_simple(self): - args = range(5) - with Pool() as p: - actual = p.map_async(square, args) - self.assertIsInstance(actual, list) - for x in actual: - self.assertIsInstance(x, AsyncResult) + pool_generator = Pool - results = [a.get() for a in actual] - self.assertEqual(results, [square(e) for e in args]) + def test_simple(self): + num_payloads = 5 + args = range(num_payloads) + for num_procs in [(num_payloads+1), (num_payloads-1)]: + with self.pool_generator(num_procs) as p: + actual = p.map_async(square, args) + self.assertIsInstance(actual, (AsyncResult, multiprocessing.pool.AsyncResult)) + results = actual.get() + self.assertEqual(results, [square(e) for e in args]) def test_duration(self): - n = 2 - with Pool(n) as p: - with self.assertDuration(min_t=(n-1)-delta, max_t=(n+1)+delta): + sleep_duration = 0.5 + n_procs = 2 + num_tasks_per_proc = 3 + expected_wall_time = sleep_duration * num_tasks_per_proc + with self.pool_generator(n_procs) as p: + with self.assertDuration(min_t=expected_wall_time-delta, max_t=expected_wall_time+delta): with self.assertDuration(max_t=delta): - results = p.map_async(sleep, range(n)) - [r.get() for r in results] + results = p.map_async(sleep, [sleep_duration] * (num_tasks_per_proc * n_procs)) + results.get() + def test_error_handling(self): - with Pool() as p: + with self.pool_generator() as p: r = p.map_async(fail, range(2)) with self.assertRaises(AssertionError): - [x.get(1) for x in r] + r.get() + + def test_multi_error(self): + # standard library can raise either error + + with self.assertRaises((ExceptionA, ExceptionB)): + with self.pool_generator() as p: + r = p.map_async(_raise, (None, ExceptionA("Task 1"), None, ExceptionB("Task 3"))) + r.get() + +class TestMapAsyncStdLib(TestMapAsync): + pool_generator = multiprocessing.Pool class TestStarmap(TestCase): + pool_generator = Pool + def test(self): - with Pool() as p: + with self.pool_generator() as p: actual = p.starmap(sum_two, [(1,2), (3,4)]) expected = [(1+2), (3+4)] self.assertEqual(actual, expected) def test_error_handling(self): - with Pool() as p: + with self.pool_generator() as p: with self.assertRaises(ZeroDivisionError): p.starmap(divide, [(1,2), (3,0)]) - +class TestStarmapStdLib(TestStarmap): + pool_generator = multiprocessing.Pool class TestStarmapAsync(TestCase): + pool_generator = Pool + def test(self): - with Pool() as p: - actual = p.starmap_async(sum_two, [(1,2), (3,4)]) - self.assertIsInstance(actual, list) - actual = [r.get() for r in actual] + with self.pool_generator() as p: + response = p.starmap_async(sum_two, [(1,2), (3,4)]) + self.assertIsInstance(response, (AsyncResult, multiprocessing.pool.AsyncResult)) + actual = response.get() expected = [(1+2), (3+4)] self.assertEqual(actual, expected) def test_error_handling(self): - with Pool() as p: + with self.pool_generator() as p: results = p.starmap_async(divide, [(1,2), (3,0)]) with self.assertRaises(ZeroDivisionError): - [r.get() for r in results] + results.get() -class TestTidyUp(TestCase): + +class TestStarmapAsyncStdLib(TestStarmapAsync): + pool_generator = multiprocessing.Pool + +class TestExit(TestCase): + # only test this with our library, + # not the standard library + # because the standard library has a bug + # https://github.com/python/cpython/issues/79659 + # test that the implicit __exit__ # waits for child process to finish def test_exit(self): t = 1 with Pool() as p: - r = p.apply_async(sleep, (t,)) t1 = time() + r = p.apply_async(sleep, (t,)) t2 = time() self.assertLessEqual(abs((t2-t1)-t), delta) +class TestTidyUp(TestCase): + pool_generator = Pool + # test that .close() stops new submisssions # but does not halt existing num_processes # nor wait for them to finish def test_close(self): t = 1 - with Pool() as p: + with self.pool_generator() as p: r = p.apply_async(sleep, (t,)) with self.assertDuration(max_t=delta): p.close() @@ -381,7 +462,7 @@ def test_close(self): pass # makes traceback from __exit__ clearer def test_submit_after_close(self): - with Pool() as p: + with self.pool_generator() as p: p.close() with self.assertRaises(ValueError): p.apply_async(square, (1,)) @@ -390,7 +471,7 @@ def test_submit_after_close(self): # wait for child process to finish def test_terminate(self): with self.assertDuration(max_t=delta): - with Pool() as p: + with self.pool_generator() as p: r = p.apply_async(sleep, (1,)) t1 = time() p.terminate() @@ -398,11 +479,102 @@ def test_terminate(self): self.assertLessEqual(t2-t1, delta) def test_submit_after_terminate(self): - with Pool() as p: + with self.pool_generator() as p: p.terminate() with self.assertRaises(ValueError): p.apply_async(square, (1,)) +class TestTidyUpStdLib(TestTidyUp): + pool_generator = multiprocessing.Pool + +class TestDeadlock(TestCase): + + # test this issue: + # https://github.com/mdavis-xyz/lambda_multiprocessing/issues/17 + def test_map_deadlock(self): + + child_sleep = 0.01 + num_payloads = 6 + + # use standard library to start with + # and to measure a 'normal' duration + # (the time spend passing data between processes is longer than the sleep + # inside the worker) + expected_duration = child_sleep * num_payloads + data = [self.generate_big_data() for _ in range(num_payloads)] + start_t = time() + + with multiprocessing.Pool(processes=1) as p: + p.map(return_with_sleep, data) + + end_t = time() + stdlib_duration = end_t - start_t + + # this timeout manager doesn't work + # need to run the parent inside another process/thread? + # now our one + data = [self.generate_big_data() for _ in range(num_payloads)] + with Pool(processes=1) as p: + with TimeoutManager(stdlib_duration*2, "This Library's map deadlocked"): + try: + p.map(return_with_sleep, data) + except TestTimeoutException: + p.terminate() + raise + + def test_map_async_deadlock(self): + + child_sleep = 0.4 + num_payloads = 3 + data = [self.generate_big_data() for _ in range(num_payloads)] + + # use standard library to start with + # and to measure a 'normal' duration + # (the time spend passing data between processes is longer than the sleep + # inside the worker) + expected_duration = child_sleep * num_payloads + start_t = time() + with multiprocessing.Pool(processes=1) as p: + results = p.map_async(return_with_sleep, data) + results.get() + end_t = time() + stdlib_duration = end_t - start_t + + # now our one + with Pool(processes=1) as p: + with TimeoutManager(stdlib_duration*2, "This Library's map_async deadlocked"): + try: + results = p.map_async(return_with_sleep, data) + results.get() + except TestTimeoutException: + p.terminate() + raise + + # test that map_async returns immediately + # even when there are multiple tasks per child + # with payloads bigger than the buffer + def test_nonblocking(self): + sleep_duration = 2 + n_procs = 2 + num_tasks_per_proc = 3 + expected_wall_time = sleep_duration * num_tasks_per_proc + + args = [(self.generate_big_data(), sleep_duration)] * (num_tasks_per_proc * n_procs) + with Pool(n_procs) as p: + try: + with TimeoutManager(sleep_duration * num_tasks_per_proc * 2, "This Library's map_async deadlocked"): + with self.assertDuration(max_t=sleep_duration*0.5): + results = p.map_async(return_with_sleep, args) + results.get(expected_wall_time*1.5) + except Exception: + p.terminate() + raise + + @classmethod + @cache + def generate_big_data(cls, sz=2**24) -> bytes: + return 'x' * sz + # must be a global method to be pickleable def upload(args: Tuple[str, str, bytes]): @@ -424,11 +596,12 @@ def test_moto(self): bucket_name = 'mybucket' key = 'my-file' data = b"123" - client = boto3.client('s3') + region = os.getenv("AWS_DEFAULT_REGION", "ap-southeast-2") + client = boto3.client('s3', region_name=region) client.create_bucket( Bucket=bucket_name, CreateBucketConfiguration={ - 'LocationConstraint': 'ap-southeast-2' + 'LocationConstraint': region }, ) # upload in a different thread @@ -450,7 +623,7 @@ def test_moto(self): self.assertEqual(ret, data) class TestSlow(TestCase): - #@unittest.skip('Very slow') + @unittest.skip('Very slow') def test_memory_leak(self): for i in range(10**2): with Pool() as p: diff --git a/lambda_multiprocessing/timeout.py b/lambda_multiprocessing/timeout.py new file mode 100644 index 0000000..4b95b96 --- /dev/null +++ b/lambda_multiprocessing/timeout.py @@ -0,0 +1,43 @@ +# Timeout context manager for unit testing + +import signal +from math import ceil + +class TestTimeoutException(Exception): + """Exception raised when a test takes too long""" + pass + +class TimeoutManager: + # if a float is passed as seconds, it will be rounded up + def __init__(self, seconds: int, description = "Test timed out"): + self.seconds = ceil(seconds) + self.description = description + self.old_handler = None + + def __enter__(self): + self.start() + return self + + def start(self): + assert self.old_handler is None, "Alarm already started" + self.old_handler = signal.signal(signal.SIGALRM, self.timeout_handler) + signal.alarm(self.seconds) + + def timeout_handler(self, signum, frame): + raise TestTimeoutException(self.description) + + def stop(self): + if self.old_handler is not None: + # Disable the alarm + signal.alarm(0) + # Restore old signal handler + signal.signal(signal.SIGALRM, self.old_handler) + self.old_handler = None + + def __exit__(self, exc_type, exc_value, traceback): + self.stop() + if exc_type is TestTimeoutException: + return False # Let the TestTimeoutException exception propagate + + # Propagate any other exceptions, or continue if no exception + return False