diff --git a/arroyo/processing/strategies/run_task_with_multiprocessing.py b/arroyo/processing/strategies/run_task_with_multiprocessing.py
index 1fc93c9e..1137d34b 100644
--- a/arroyo/processing/strategies/run_task_with_multiprocessing.py
+++ b/arroyo/processing/strategies/run_task_with_multiprocessing.py
@@ -279,6 +279,44 @@ def parallel_run_task_worker_apply(
     return ParallelRunTaskResult(next_index_to_process, valid_messages_transformed)
 
 
+class MultiprocessingPool:
+    """
+    Multiprocessing pool for the RunTaskWithMultiprocessing strategy.
+    It can be re-used each time the strategy is re-created on a rebalance.
+
+    :param num_processes: The number of processes to spawn.
+
+    :param initializer: A function to run at the beginning of each subprocess.
+
+        Subprocesses are spawned without any of the state of the parent
+        process; they are entirely new Python interpreters. You might want to
+        re-initialize your Django application here.
+
+    """
+
+    def __init__(
+        self,
+        num_processes: int,
+        initializer: Optional[Callable[[], None]] = None,
+    ) -> None:
+        self.__pool = Pool(
+            num_processes,
+            initializer=partial(parallel_worker_initializer, initializer),
+            context=multiprocessing.get_context("spawn"),
+        )
+        self.__num_processes = num_processes
+
+    @property
+    def num_processes(self) -> int:
+        return self.__num_processes
+
+    def apply_async(self, *args: Any, **kwargs: Any) -> Any:
+        return self.__pool.apply_async(*args, **kwargs)
+
+    def terminate(self) -> None:
+        self.__pool.terminate()
+
+
 class RunTaskWithMultiprocessing(
     ProcessingStrategy[Union[FilteredPayload, TStrategyPayload]],
     Generic[TStrategyPayload, TResult],
@@ -291,10 +329,13 @@ class RunTaskWithMultiprocessing(
 
     :param function: The function to use for transforming.
     :param next_step: The processing strategy to forward transformed messages to.
-    :param num_processes: The number of processes to spawn.
     :param max_batch_size: Wait at most for this many messages before "closing" a batch.
     :param max_batch_time: Wait at most for this many seconds before closing a batch.
+
+    :param pool: The multiprocessing pool to use for parallel processing. The same pool
+        instance can be re-used each time ``RunTaskWithMultiprocessing`` is created on
+        a rebalance.
 
     :param input_block_size: For each subprocess, a shared memory buffer of
         ``input_block_size`` is allocated. This value should be at least
         `message_size * max_batch_size` large, where `message_size` is the expected
@@ -325,11 +366,6 @@ class RunTaskWithMultiprocessing(
 
     :param max_output_block_size: Same as `max_input_block_size` but for output
         blocks.
 
-    :param initializer: A function to run at the beginning of each subprocess.
-
-        Subprocesses are spawned without any of the state of the parent
-        process, they are entirely new Python interpreters. You might want to
-        re-initialize your Django application here.
     Number of processes
     ~~~~~~~~~~~~~~~~~~~
@@ -424,14 +460,13 @@ def __init__(
         self,
         function: Callable[[Message[TStrategyPayload]], TResult],
         next_step: ProcessingStrategy[Union[FilteredPayload, TResult]],
-        num_processes: int,
         max_batch_size: int,
         max_batch_time: float,
+        pool: MultiprocessingPool,
         input_block_size: Optional[int] = None,
         output_block_size: Optional[int] = None,
         max_input_block_size: Optional[int] = None,
         max_output_block_size: Optional[int] = None,
-        initializer: Optional[Callable[[], None]] = None,
     ) -> None:
         self.__transform_function = function
         self.__next_step = next_step
@@ -446,11 +481,8 @@ def __init__(
         self.__shared_memory_manager = SharedMemoryManager()
         self.__shared_memory_manager.start()
 
-        self.__pool = Pool(
-            num_processes,
-            initializer=partial(parallel_worker_initializer, initializer),
-            context=multiprocessing.get_context("spawn"),
-        )
+        self.__pool = pool
+        num_processes = self.__pool.num_processes
 
         self.__input_blocks = [
             self.__shared_memory_manager.SharedMemory(
@@ -799,7 +831,8 @@ def join(self, timeout: Optional[float] = None) -> None:
                 raise
 
         logger.debug("Waiting for %s...", self.__pool)
-        self.__pool.terminate()
+        # The pool is not terminated here: it is owned by the caller and may
+        # be re-used by the next strategy instance after a rebalance.
 
         self.__shared_memory_manager.shutdown()
 
diff --git a/tests/processing/strategies/test_all.py b/tests/processing/strategies/test_all.py
index 328deb0b..b06c89a4 100644
--- a/tests/processing/strategies/test_all.py
+++ b/tests/processing/strategies/test_all.py
@@ -26,6 +26,7 @@
 from arroyo.processing.strategies.run_task import RunTask
 from arroyo.processing.strategies.run_task_in_threads import RunTaskInThreads
 from arroyo.processing.strategies.run_task_with_multiprocessing import (
+    MultiprocessingPool,
     RunTaskWithMultiprocessing,
 )
 from arroyo.types import (
@@ -72,9 +73,9 @@ def run_task_with_multiprocessing_factory(
     return RunTaskWithMultiprocessing(
         partial(run_task_function, raises_invalid_message),
         next_step=next_step,
-        num_processes=4,
         max_batch_size=10,
         max_batch_time=60,
+        pool=MultiprocessingPool(num_processes=4),
         input_block_size=16384,
         output_block_size=16384,
     )
diff --git a/tests/processing/strategies/test_run_task_with_multiprocessing.py b/tests/processing/strategies/test_run_task_with_multiprocessing.py
index 1aa8bd4f..9ef26b7b 100644
--- a/tests/processing/strategies/test_run_task_with_multiprocessing.py
+++ b/tests/processing/strategies/test_run_task_with_multiprocessing.py
@@ -12,6 +12,7 @@
 from arroyo.processing.strategies import MessageRejected
 from arroyo.processing.strategies.run_task_with_multiprocessing import (
     MessageBatch,
+    MultiprocessingPool,
     RunTaskWithMultiprocessing,
     ValueTooLarge,
     parallel_run_task_worker_apply,
@@ -200,9 +201,9 @@ def test_parallel_transform_step() -> None:
     transform_step = RunTaskWithMultiprocessing(
         transform_payload_expand,
         next_step,
-        num_processes=worker_processes,
         max_batch_size=5,
         max_batch_time=60,
+        pool=MultiprocessingPool(worker_processes),
         input_block_size=16384,
         output_block_size=16384,
     )
@@ -290,9 +291,9 @@ def test_parallel_run_task_terminate_workers() -> None:
     transform_step = RunTaskWithMultiprocessing(
         transform_payload_expand,  # doesn't matter
         next_step,
-        num_processes=worker_processes,
         max_batch_size=5,
         max_batch_time=60,
+        pool=MultiprocessingPool(worker_processes),
         input_block_size=4096,
         output_block_size=4096,
     )
@@ -335,9 +336,9 @@ def test_message_rejected_multiple() -> None:
     strategy = RunTaskWithMultiprocessing(
         count_calls,
         next_step,
-        num_processes=1,
         max_batch_size=1,
         max_batch_time=60,
+        pool=MultiprocessingPool(num_processes=1),
         input_block_size=4096,
         output_block_size=4096,
     )
@@ -461,9 +462,9 @@ def test_regression_join_timeout_one_message() -> None:
     strategy = RunTaskWithMultiprocessing(
         run_sleep,
         next_step,
-        num_processes=1,
         max_batch_size=1,
         max_batch_time=60,
+        pool=MultiprocessingPool(num_processes=1),
         input_block_size=4096,
         output_block_size=4096,
     )
@@ -490,8 +491,8 @@ def test_regression_join_timeout_many_messages() -> None:
     strategy = RunTaskWithMultiprocessing(
         run_sleep,
         next_step,
-        num_processes=1,
         max_batch_size=1,
+        pool=MultiprocessingPool(num_processes=1),
         max_batch_time=60,
         input_block_size=4096,
         output_block_size=4096,
@@ -525,9 +526,9 @@ def test_input_block_resizing_max_size() -> None:
     strategy = RunTaskWithMultiprocessing(
         run_multiply_times_two,
         next_step,
-        num_processes=2,
         max_batch_size=NUM_MESSAGES,
         max_batch_time=60,
+        pool=MultiprocessingPool(num_processes=2),
         input_block_size=None,
         output_block_size=INPUT_SIZE // 2,
         max_input_block_size=16000,
@@ -552,12 +553,13 @@ def test_input_block_resizing_without_limits() -> None:
     NUM_MESSAGES = INPUT_SIZE // MSG_SIZE
 
     next_step = Mock()
+    pool = MultiprocessingPool(num_processes=2)
     strategy = RunTaskWithMultiprocessing(
         run_multiply_times_two,
         next_step,
-        num_processes=2,
         max_batch_size=NUM_MESSAGES,
         max_batch_time=60,
+        pool=pool,
         input_block_size=None,
         output_block_size=INPUT_SIZE // 2,
     )
@@ -588,9 +590,9 @@ def test_output_block_resizing_max_size() -> None:
     strategy = RunTaskWithMultiprocessing(
         run_multiply_times_two,
         next_step,
-        num_processes=2,
         max_batch_size=NUM_MESSAGES,
         max_batch_time=60,
+        pool=MultiprocessingPool(num_processes=2),
         input_block_size=INPUT_SIZE,
         output_block_size=None,
         max_output_block_size=16000,
@@ -619,9 +621,9 @@ def test_output_block_resizing_without_limits() -> None:
     strategy = RunTaskWithMultiprocessing(
         run_multiply_times_two,
         next_step,
-        num_processes=2,
         max_batch_size=NUM_MESSAGES,
         max_batch_time=60,
+        pool=MultiprocessingPool(num_processes=2),
         input_block_size=INPUT_SIZE,
         output_block_size=None,
     )
@@ -665,9 +667,9 @@ def test_multiprocessing_with_invalid_message() -> None:
     strategy = RunTaskWithMultiprocessing(
         message_processor_raising_invalid_message,
         next_step,
-        num_processes=2,
         max_batch_size=1,
         max_batch_time=60,
+        pool=MultiprocessingPool(num_processes=2),
     )
 
     strategy.submit(
@@ -690,9 +692,9 @@ def test_reraise_invalid_message() -> None:
     strategy = RunTaskWithMultiprocessing(
         run_multiply_times_two,
         next_step,
-        num_processes=2,
         max_batch_size=1,
         max_batch_time=60,
+        pool=MultiprocessingPool(num_processes=2),
     )
 
     strategy.submit(Message(Value(KafkaPayload(None, b"x" * 10, []), {}, now)))
@@ -703,3 +705,47 @@ def test_reraise_invalid_message() -> None:
     next_step.poll.reset_mock(side_effect=True)
     strategy.close()
     strategy.join()
+
+
+def slow_func(message: Message[int]) -> int:
+    time.sleep(0.2)
+    return message.payload
+
+
+def test_reuse_pool() -> None:
+    # To be reused in strategy_one and strategy_two
+    pool = MultiprocessingPool(num_processes=2)
+    next_step = Mock()
+
+    strategy_one = RunTaskWithMultiprocessing(
+        slow_func,
+        next_step,
+        max_batch_size=2,
+        max_batch_time=5,
+        pool=pool,
+    )
+
+    strategy_one.submit(Message(Value(10, committable={})))
+
+    strategy_one.close()
+
+    # Join with timeout=0.0 to ensure there will be unprocessed pending messages
+    # in the first batch
+    strategy_one.join(0.0)
+
+    strategy_two = RunTaskWithMultiprocessing(
+        slow_func,
+        next_step,
+        max_batch_size=2,
+        max_batch_time=5,
+        pool=pool,
+    )
+
+    strategy_two.submit(Message(Value(10, committable={})))
+
+    strategy_two.close()
+
+    # Join with no timeout so the pending task completes and the message gets submitted
+    strategy_two.join()
+
+    assert next_step.submit.call_count == 1
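
For illustration, here is how the new API might be wired into a consumer. This is a minimal sketch, assuming arroyo's usual `ProcessingStrategyFactory` and `CommitOffsets` APIs; the `transform` function and the `shutdown` hook are hypothetical stand-ins, not part of this patch. The pool is constructed once in the factory, passed to every `RunTaskWithMultiprocessing` created on assignment, and terminated by its owner at consumer shutdown, since `join()` no longer does so:

```python
from typing import Mapping

from arroyo.backends.kafka import KafkaPayload
from arroyo.processing.strategies import (
    CommitOffsets,
    ProcessingStrategy,
    ProcessingStrategyFactory,
)
from arroyo.processing.strategies.run_task_with_multiprocessing import (
    MultiprocessingPool,
    RunTaskWithMultiprocessing,
)
from arroyo.types import Commit, Message, Partition


def transform(message: Message[KafkaPayload]) -> KafkaPayload:
    # Hypothetical CPU-bound work executed in a worker subprocess.
    return message.payload


class TransformStrategyFactory(ProcessingStrategyFactory[KafkaPayload]):
    def __init__(self) -> None:
        # Created once: the worker subprocesses now survive rebalances
        # instead of being torn down and respawned with each assignment.
        self.__pool = MultiprocessingPool(num_processes=4)

    def create_with_partitions(
        self, commit: Commit, partitions: Mapping[Partition, int]
    ) -> ProcessingStrategy[KafkaPayload]:
        # A fresh strategy is built on every assignment, re-using the pool.
        return RunTaskWithMultiprocessing(
            transform,
            next_step=CommitOffsets(commit),
            max_batch_size=10,
            max_batch_time=1.0,
            pool=self.__pool,
        )

    def shutdown(self) -> None:
        # The pool is owned by the factory, so tear it down here when the
        # consumer exits for good. (Illustrative hook: call it from your
        # application's teardown path.)
        self.__pool.terminate()
```

This split of responsibilities is the point of the patch: `RunTaskWithMultiprocessing.join()` only drains in-flight work, while the lifetime of the subprocesses belongs to whoever created the `MultiprocessingPool`.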