Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
lynnagara committed Dec 4, 2023
1 parent 7a2a871 commit e73be70
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 26 deletions.
60 changes: 46 additions & 14 deletions arroyo/processing/strategies/run_task_with_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,44 @@ def parallel_run_task_worker_apply(
return ParallelRunTaskResult(next_index_to_process, valid_messages_transformed)


class MultiprocessingPool:
"""
Multiprocessing pool for the RunTaskWithMultiprocessing strategy.
It can be re-used each time the strategy is created on assignments.
:param num_processes: The number of processes to spawn.
:param initializer: A function to run at the beginning of each subprocess.
Subprocesses are spawned without any of the state of the parent
process, they are entirely new Python interpreters. You might want to
re-initialize your Django application here.
"""

def __init__(
self,
num_processes: int,
initializer: Optional[Callable[[], None]] = None,
) -> None:
self.__pool = Pool(
num_processes,
initializer=partial(parallel_worker_initializer, initializer),
context=multiprocessing.get_context("spawn"),
)
self.__num_processes = num_processes

@property
def num_processes(self) -> int:
return self.__num_processes

def apply_async(self, *args: Any, **kwargs: Any) -> Any:
return self.__pool.apply_async(*args, **kwargs)

def terminate(self) -> None:
self.__pool.terminate()


class RunTaskWithMultiprocessing(
ProcessingStrategy[Union[FilteredPayload, TStrategyPayload]],
Generic[TStrategyPayload, TResult],
Expand All @@ -291,10 +329,13 @@ class RunTaskWithMultiprocessing(
:param function: The function to use for transforming.
:param next_step: The processing strategy to forward transformed messages to.
:param num_processes: The number of processes to spawn.
:param max_batch_size: Wait at most for this many messages before "closing" a batch.
:param max_batch_time: Wait at most for this many seconds before closing a batch.
:param pool: The multiprocessing pool to use for parallel processing. The same pool
instance can be re-used each time ``RunTaskWithMultiprocessing`` is created on
rebalance.
:param input_block_size: For each subprocess, a shared memory buffer of
``input_block_size`` is allocated. This value should be at least
`message_size * max_batch_size` large, where `message_size` is the expected
Expand Down Expand Up @@ -325,11 +366,6 @@ class RunTaskWithMultiprocessing(
:param max_output_block_size: Same as `max_input_block_size` but for output
blocks.
:param initializer: A function to run at the beginning of each subprocess.
Subprocesses are spawned without any of the state of the parent
process, they are entirely new Python interpreters. You might want to
re-initialize your Django application here.
Number of processes
~~~~~~~~~~~~~~~~~~~
Expand Down Expand Up @@ -424,14 +460,13 @@ def __init__(
self,
function: Callable[[Message[TStrategyPayload]], TResult],
next_step: ProcessingStrategy[Union[FilteredPayload, TResult]],
num_processes: int,
max_batch_size: int,
max_batch_time: float,
pool: MultiprocessingPool,
input_block_size: Optional[int] = None,
output_block_size: Optional[int] = None,
max_input_block_size: Optional[int] = None,
max_output_block_size: Optional[int] = None,
initializer: Optional[Callable[[], None]] = None,
) -> None:
self.__transform_function = function
self.__next_step = next_step
Expand All @@ -446,11 +481,8 @@ def __init__(
self.__shared_memory_manager = SharedMemoryManager()
self.__shared_memory_manager.start()

self.__pool = Pool(
num_processes,
initializer=partial(parallel_worker_initializer, initializer),
context=multiprocessing.get_context("spawn"),
)
self.__pool = pool
num_processes = self.__pool.num_processes

self.__input_blocks = [
self.__shared_memory_manager.SharedMemory(
Expand Down Expand Up @@ -799,7 +831,7 @@ def join(self, timeout: Optional[float] = None) -> None:
raise

logger.debug("Waiting for %s...", self.__pool)
self.__pool.terminate()
# self.__pool.terminate()

self.__shared_memory_manager.shutdown()

Expand Down
3 changes: 2 additions & 1 deletion tests/processing/strategies/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from arroyo.processing.strategies.run_task import RunTask
from arroyo.processing.strategies.run_task_in_threads import RunTaskInThreads
from arroyo.processing.strategies.run_task_with_multiprocessing import (
MultiprocessingPool,
RunTaskWithMultiprocessing,
)
from arroyo.types import (
Expand Down Expand Up @@ -72,9 +73,9 @@ def run_task_with_multiprocessing_factory(
return RunTaskWithMultiprocessing(
partial(run_task_function, raises_invalid_message),
next_step=next_step,
num_processes=4,
max_batch_size=10,
max_batch_time=60,
pool=MultiprocessingPool(num_processes=4),
input_block_size=16384,
output_block_size=16384,
)
Expand Down
68 changes: 57 additions & 11 deletions tests/processing/strategies/test_run_task_with_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from arroyo.processing.strategies import MessageRejected
from arroyo.processing.strategies.run_task_with_multiprocessing import (
MessageBatch,
MultiprocessingPool,
RunTaskWithMultiprocessing,
ValueTooLarge,
parallel_run_task_worker_apply,
Expand Down Expand Up @@ -200,9 +201,9 @@ def test_parallel_transform_step() -> None:
transform_step = RunTaskWithMultiprocessing(
transform_payload_expand,
next_step,
num_processes=worker_processes,
max_batch_size=5,
max_batch_time=60,
pool=MultiprocessingPool(worker_processes),
input_block_size=16384,
output_block_size=16384,
)
Expand Down Expand Up @@ -290,9 +291,9 @@ def test_parallel_run_task_terminate_workers() -> None:
transform_step = RunTaskWithMultiprocessing(
transform_payload_expand, # doesn't matter
next_step,
num_processes=worker_processes,
max_batch_size=5,
max_batch_time=60,
pool=MultiprocessingPool(worker_processes),
input_block_size=4096,
output_block_size=4096,
)
Expand Down Expand Up @@ -335,9 +336,9 @@ def test_message_rejected_multiple() -> None:
strategy = RunTaskWithMultiprocessing(
count_calls,
next_step,
num_processes=1,
max_batch_size=1,
max_batch_time=60,
pool=MultiprocessingPool(num_processes=1),
input_block_size=4096,
output_block_size=4096,
)
Expand Down Expand Up @@ -461,9 +462,9 @@ def test_regression_join_timeout_one_message() -> None:
strategy = RunTaskWithMultiprocessing(
run_sleep,
next_step,
num_processes=1,
max_batch_size=1,
max_batch_time=60,
pool=MultiprocessingPool(num_processes=1),
input_block_size=4096,
output_block_size=4096,
)
Expand All @@ -490,8 +491,8 @@ def test_regression_join_timeout_many_messages() -> None:
strategy = RunTaskWithMultiprocessing(
run_sleep,
next_step,
num_processes=1,
max_batch_size=1,
pool=MultiprocessingPool(num_processes=1),
max_batch_time=60,
input_block_size=4096,
output_block_size=4096,
Expand Down Expand Up @@ -525,9 +526,9 @@ def test_input_block_resizing_max_size() -> None:
strategy = RunTaskWithMultiprocessing(
run_multiply_times_two,
next_step,
num_processes=2,
max_batch_size=NUM_MESSAGES,
max_batch_time=60,
pool=MultiprocessingPool(num_processes=2),
input_block_size=None,
output_block_size=INPUT_SIZE // 2,
max_input_block_size=16000,
Expand All @@ -552,12 +553,13 @@ def test_input_block_resizing_without_limits() -> None:
NUM_MESSAGES = INPUT_SIZE // MSG_SIZE
next_step = Mock()

pool = MultiprocessingPool(num_processes=2)
strategy = RunTaskWithMultiprocessing(
run_multiply_times_two,
next_step,
num_processes=2,
max_batch_size=NUM_MESSAGES,
max_batch_time=60,
pool=pool,
input_block_size=None,
output_block_size=INPUT_SIZE // 2,
)
Expand Down Expand Up @@ -588,9 +590,9 @@ def test_output_block_resizing_max_size() -> None:
strategy = RunTaskWithMultiprocessing(
run_multiply_times_two,
next_step,
num_processes=2,
max_batch_size=NUM_MESSAGES,
max_batch_time=60,
pool=MultiprocessingPool(num_processes=2),
input_block_size=INPUT_SIZE,
output_block_size=None,
max_output_block_size=16000,
Expand Down Expand Up @@ -619,9 +621,9 @@ def test_output_block_resizing_without_limits() -> None:
strategy = RunTaskWithMultiprocessing(
run_multiply_times_two,
next_step,
num_processes=2,
max_batch_size=NUM_MESSAGES,
max_batch_time=60,
pool=MultiprocessingPool(num_processes=2),
input_block_size=INPUT_SIZE,
output_block_size=None,
)
Expand Down Expand Up @@ -665,9 +667,9 @@ def test_multiprocessing_with_invalid_message() -> None:
strategy = RunTaskWithMultiprocessing(
message_processor_raising_invalid_message,
next_step,
num_processes=2,
max_batch_size=1,
max_batch_time=60,
pool=MultiprocessingPool(num_processes=2),
)

strategy.submit(
Expand All @@ -690,9 +692,9 @@ def test_reraise_invalid_message() -> None:
strategy = RunTaskWithMultiprocessing(
run_multiply_times_two,
next_step,
num_processes=2,
max_batch_size=1,
max_batch_time=60,
pool=MultiprocessingPool(num_processes=2),
)

strategy.submit(Message(Value(KafkaPayload(None, b"x" * 10, []), {}, now)))
Expand All @@ -703,3 +705,47 @@ def test_reraise_invalid_message() -> None:
next_step.poll.reset_mock(side_effect=True)
strategy.close()
strategy.join()


def slow_func(message: Message[int]) -> int:
time.sleep(0.2)
return message.payload


def test_reuse_pool() -> None:
# To be reused in strategy_one and strategy_two
pool = MultiprocessingPool(num_processes=2)
next_step = Mock()

strategy_one = RunTaskWithMultiprocessing(
slow_func,
next_step,
max_batch_size=2,
max_batch_time=5,
pool=pool,
)

strategy_one.submit(Message(Value(10, committable={})))

strategy_one.close()

# Join with timeout=0.0 to ensure there will be unprocessed pending messages
# in the first batch
strategy_one.join(0.0)

strategy_two = RunTaskWithMultiprocessing(
slow_func,
next_step,
max_batch_size=2,
max_batch_time=5,
pool=pool,
)

strategy_two.submit(Message(Value(10, committable={})))

strategy_two.close()

# Join with no timeout so the pending task will complete and message gets submitted
strategy_two.join()

assert next_step.submit.call_count == 1

0 comments on commit e73be70

Please sign in to comment.