diff --git a/arroyo/processing/strategies/filter.py b/arroyo/processing/strategies/filter.py index bf5be9f8..3f857f9f 100644 --- a/arroyo/processing/strategies/filter.py +++ b/arroyo/processing/strategies/filter.py @@ -3,7 +3,7 @@ from typing import Callable, MutableMapping, Optional, Union, cast from arroyo.commit import CommitPolicy, CommitPolicyState -from arroyo.processing.strategies.abstract import ProcessingStrategy +from arroyo.processing.strategies.abstract import MessageRejected, ProcessingStrategy from arroyo.types import ( FILTERED_PAYLOAD, FilteredPayload, @@ -73,7 +73,10 @@ def submit( ) -> None: assert not self.__closed + policy = self.__commit_policy_state now = time.time() + if policy is not None and policy.should_commit(now, self.__uncommitted_offsets): + self.__flush_uncommitted_offsets(now, can_backpressure=True) if not isinstance(message.payload, FilteredPayload) and self.__test_function( cast(Message[TStrategyPayload], message) @@ -86,19 +89,21 @@ def submit( if self.__commit_policy_state is not None: self.__uncommitted_offsets.update(message.committable) - policy = self.__commit_policy_state - - if policy is not None and policy.should_commit(now, message.committable): - self.__flush_uncommitted_offsets(now) - - def __flush_uncommitted_offsets(self, now: float) -> None: + def __flush_uncommitted_offsets(self, now: float, can_backpressure: bool) -> None: if not self.__uncommitted_offsets: return new_message: Message[Union[FilteredPayload, TStrategyPayload]] = Message( Value(FILTERED_PAYLOAD, self.__uncommitted_offsets) ) - self.__next_step.submit(new_message) + try: + self.__next_step.submit(new_message) + except MessageRejected: + if can_backpressure: + raise + # We have little to gain from reattempting the submission. + # Filtering is not supposed to be that expensive. + return if self.__commit_policy_state is not None: self.__commit_policy_state.did_commit(now, self.__uncommitted_offsets) @@ -115,6 +120,8 @@ def terminate(self) -> None: self.__next_step.terminate() def join(self, timeout: Optional[float] = None) -> None: - self.__flush_uncommitted_offsets(time.time()) + # We cannot let MessageRejected propagate here. join() is not supposed + # to raise this exception at all. + self.__flush_uncommitted_offsets(time.time(), can_backpressure=False) self.__next_step.close() self.__next_step.join(timeout=timeout) diff --git a/arroyo/types.py b/arroyo/types.py index 2f1c4dc6..4472fc56 100644 --- a/arroyo/types.py +++ b/arroyo/types.py @@ -35,6 +35,9 @@ class FilteredPayload: def __eq__(self, other: Any) -> bool: return isinstance(other, FilteredPayload) + def __repr__(self) -> str: + return "" + FILTERED_PAYLOAD = FilteredPayload() @@ -64,7 +67,7 @@ def __repr__(self) -> str: # ``__slots__`` for performance reasons. The class variable names # would conflict with the instance slot names, causing an error. - if type(self.payload) in (float, int): + if type(self.payload) in (float, int, bool, FilteredPayload): # For the case where value is a float or int, the repr is small and # therefore safe. This is very useful in tests. # diff --git a/tests/processing/strategies/test_filter.py b/tests/processing/strategies/test_filter.py index 3fe21e3f..f3268f47 100644 --- a/tests/processing/strategies/test_filter.py +++ b/tests/processing/strategies/test_filter.py @@ -1,8 +1,10 @@ from datetime import datetime -from typing import Union from unittest.mock import Mock, call +import pytest + from arroyo.commit import CommitPolicy +from arroyo.processing.strategies.abstract import MessageRejected from arroyo.processing.strategies.filter import FilterStep from arroyo.types import ( FILTERED_PAYLOAD, @@ -78,19 +80,22 @@ def test_function(message: Message[bool]) -> bool: # partitions, and is flushing them all out since this is the third message # and according to our commit policy we are supposed to commit at this # point, roughly. - expected_filter_message: Message[Union[FilteredPayload, bool]] = Message( - Value( - FILTERED_PAYLOAD, - {Partition(topic, 1): 1, Partition(topic, 0): 2}, + assert next_step.submit.mock_calls == [ + call( + Message( + Value( + FILTERED_PAYLOAD, {Partition(topic, 1): 1, Partition(topic, 0): 1} + ) + ) ) - ) - assert next_step.submit.mock_calls == [call(expected_filter_message)] + ] next_step.submit.reset_mock() - # Since all offsets have been recently flushed, join()/shutdown should not - # send an additional filter message filter_step.join() - assert next_step.submit.call_count == 0 + assert next_step.submit.mock_calls == [ + call(Message(Value(FILTERED_PAYLOAD, {Partition(topic, 0): 2}))) + ] + next_step.submit.reset_mock() fail_message = Message(Value(False, {Partition(topic, 0): 3}, now)) filter_step.submit(fail_message) @@ -160,3 +165,88 @@ def test_function(message: Message[bool]) -> bool: call(Message(Value(True, {Partition(topic, 1): 3}, now))), call(Message(Value(True, {Partition(topic, 1): 5}, now))), ] + + +def test_backpressure_in_join() -> None: + topic = Topic("topic") + next_step = Mock() + next_step.submit.side_effect = [None] * 6 + [MessageRejected] # type: ignore + + now = datetime.now() + + def test_function(message: Message[bool]) -> bool: + return message.payload + + filter_step = FilterStep( + test_function, next_step, commit_policy=CommitPolicy(None, 3) + ) + + filter_step.submit(Message(Value(True, {Partition(topic, 1): 1}, now))) + filter_step.submit(Message(Value(False, {Partition(topic, 1): 2}, now))) + filter_step.submit(Message(Value(True, {Partition(topic, 1): 3}, now))) + filter_step.submit(Message(Value(False, {Partition(topic, 1): 4}, now))) + filter_step.submit(Message(Value(True, {Partition(topic, 1): 5}, now))) + filter_step.submit(Message(Value(False, {Partition(topic, 1): 6}, now))) + + filter_step.join() + + assert next_step.submit.mock_calls == [ + call(Message(Value(True, {Partition(topic, 1): 1}, now))), + call(Message(Value(True, {Partition(topic, 1): 3}, now))), + call(Message(Value(FILTERED_PAYLOAD, {Partition(topic, 1): 4}))), + call(Message(Value(True, {Partition(topic, 1): 5}, now))), + call(Message(Value(FILTERED_PAYLOAD, {Partition(topic, 1): 6}))), + ] + + +def test_backpressure_in_submit() -> None: + """ + Assert that MessageRejected is propagated for the right messages, and + handled correctly in join() (i.e. suppressed) + """ + topic = Topic("topic") + next_step = Mock() + next_step.submit.side_effect = [ + MessageRejected, + None, + MessageRejected, + MessageRejected, + None, + ] + + now = datetime.now() + + def test_function(message: Message[bool]) -> bool: + return message.payload + + filter_step = FilterStep( + test_function, next_step, commit_policy=CommitPolicy(None, 3) + ) + + with pytest.raises(MessageRejected): + filter_step.submit(Message(Value(True, {Partition(topic, 1): 1}, now))) + + filter_step.submit(Message(Value(True, {Partition(topic, 1): 1}, now))) + + filter_step.submit(Message(Value(False, {Partition(topic, 1): 2}, now))) + + assert next_step.submit.mock_calls == [ + call(Message(Value(True, {Partition(topic, 1): 1}, now))), + call(Message(Value(True, {Partition(topic, 1): 1}, now))), + ] + + next_step.submit.mock_calls.clear() + + filter_step.join() + + assert next_step.submit.mock_calls == [ + call(Message(Value(FILTERED_PAYLOAD, {Partition(topic, 1): 2}))), + ] + + next_step.submit.mock_calls.clear() + + filter_step.join() + + assert next_step.submit.mock_calls == [ + call(Message(Value(FILTERED_PAYLOAD, {Partition(topic, 1): 2}))), + ]