From d7adaea1b5663f0eaaeff648b6aca1a91bc3d4cb Mon Sep 17 00:00:00 2001
From: Lyn Nagara <lyn.nagara@gmail.com>
Date: Thu, 30 Nov 2023 09:52:06 -0800
Subject: [PATCH] fix(dlq): Actually respect the DLQ limits (#309)

While working on porting this code to the Rust consumer, I noticed that the DLQ limits are not working in Python as expected. This is an attempt to fix that. Specifically, limits are being ignored. Even if one is defined, the Arroyo processor behaves as if there was no limit set and simply forwards all invalid messages to the DLQ.
---
 arroyo/dlq.py                  | 91 +++++++++++++++++++++++++++++-----
 arroyo/processing/processor.py |  2 +
 tests/test_dlq.py              | 25 ++++++++--
 3 files changed, 102 insertions(+), 16 deletions(-)

diff --git a/arroyo/dlq.py b/arroyo/dlq.py
index 62817392..6b5f0bf0 100644
--- a/arroyo/dlq.py
+++ b/arroyo/dlq.py
@@ -93,17 +93,56 @@ class DlqLimitState:
 
     def __init__(
         self,
-        limit: DlqLimit,
-        valid_messages: Optional[Mapping[Partition, int]] = None,
-        invalid_messages: Optional[Mapping[Partition, int]] = None,
-        invalid_consecutive_messages: Optional[Mapping[Partition, int]] = None,
+        limit: Optional[DlqLimit],
+        valid_messages: Optional[MutableMapping[Partition, int]] = None,
+        invalid_messages: Optional[MutableMapping[Partition, int]] = None,
+        invalid_consecutive_messages: Optional[MutableMapping[Partition, int]] = None,
+        last_invalid_offsets: Optional[MutableMapping[Partition, int]] = None,
     ) -> None:
         self.__limit = limit
         self.__valid_messages = valid_messages or {}
         self.__invalid_messages = invalid_messages or {}
         self.__invalid_consecutive_messages = invalid_consecutive_messages or {}
+        # Keep track of the last offset for the partition
+        self.__last_invalid_offsets: MutableMapping[Partition, int] = (
+            last_invalid_offsets or {}
+        )
+
+    def update_invalid_value(self, value: BrokerValue[TStrategyPayload]) -> None:
+        """
+        This method should be called (prior to should_accept) with each invalid value
+        to update the count of valid and invalid messages
+        """
+        if self.__limit is None:
+            return
+
+        partition = value.partition
+
+        last_invalid_offset = self.__last_invalid_offsets.get(partition)
+
+        if last_invalid_offset is not None:
+            if last_invalid_offset >= value.offset:
+                logger.error("Invalid message raised out of order")
+            elif last_invalid_offset == value.offset - 1:
+                self.__invalid_consecutive_messages[partition] = (
+                    self.__invalid_consecutive_messages.get(partition, 0) + 1
+                )
+            else:
+                valid_count = value.offset - last_invalid_offset + 1
+                self.__valid_messages[partition] = (
+                    self.__valid_messages.get(partition, 0) + valid_count
+                )
+                self.__invalid_consecutive_messages[value.partition] = 1
+
+            self.__invalid_messages[partition] = (
+                self.__invalid_messages.get(partition, 0) + 1
+            )
+            self.__last_invalid_offsets[partition] = value.offset
 
     def should_accept(self, value: BrokerValue[TStrategyPayload]) -> bool:
+        if self.__limit is None:
+            return True
+
         if self.__limit.max_invalid_ratio is not None:
             invalid = self.__invalid_messages.get(value.partition, 0)
             valid = self.__valid_messages.get(value.partition, 0)
@@ -143,7 +182,9 @@ def produce(
 
     @classmethod
     @abstractmethod
-    def build_initial_state(cls, limit: DlqLimit) -> DlqLimitState:
+    def build_initial_state(
+        cls, limit: Optional[DlqLimit], assignment: Mapping[Partition, int]
+    ) -> DlqLimitState:
         """
         Called on consumer start to build the current DLQ state
         """
@@ -164,7 +205,9 @@ def produce(
         return future
 
     @classmethod
-    def build_initial_state(cls, limit: DlqLimit) -> DlqLimitState:
+    def build_initial_state(
+        cls, limit: Optional[DlqLimit], assignment: Mapping[Partition, int]
+    ) -> DlqLimitState:
         return DlqLimitState(limit)
 
 
@@ -194,9 +237,15 @@ def produce(
         return self.__producer.produce(self.__topic, value.payload)
 
     @classmethod
-    def build_initial_state(cls, limit: DlqLimit) -> DlqLimitState:
-        # TODO: Build the current state by reading the DLQ topic in Kafka
-        return DlqLimitState(limit)
+    def build_initial_state(
+        cls, limit: Optional[DlqLimit], assignment: Mapping[Partition, int]
+    ) -> DlqLimitState:
+        # XXX: We assume the last offsets were invalid when starting the consumer
+        last_invalid = {
+            partition: offset - 1 for partition, offset in assignment.items()
+        }
+
+        return DlqLimitState(limit, last_invalid_offsets=last_invalid)
 
 
 @dataclass(frozen=True)
@@ -274,9 +323,13 @@ class DlqPolicyWrapper(Generic[TStrategyPayload]):
     Wraps the DLQ policy and manages the buffer of messages that are pending commit.
     """
 
-    def __init__(self, policy: DlqPolicy[TStrategyPayload]) -> None:
+    def __init__(
+        self,
+        policy: DlqPolicy[TStrategyPayload],
+    ) -> None:
         self.MAX_PENDING_FUTURES = 1000  # This is a per partition max
         self.__dlq_policy = policy
+
         self.__futures: MutableMapping[
             Partition,
             Deque[
@@ -286,6 +339,15 @@ def __init__(self, policy: DlqPolicy[TStrategyPayload]) -> None:
                 ]
             ],
         ] = defaultdict(deque)
+        self.reset_offsets({})
+
+    def reset_offsets(self, assignment: Mapping[Partition, int]) -> None:
+        """
+        Called on consumer assignment
+        """
+        self.__dlq_limit_state = self.__dlq_policy.producer.build_initial_state(
+            self.__dlq_policy.limit, assignment
+        )
 
     def produce(self, message: BrokerValue[TStrategyPayload]) -> None:
         """
@@ -303,15 +365,18 @@ def produce(self, message: BrokerValue[TStrategyPayload]) -> None:
                 values[0][1].result()
                 values.popleft()
 
-        future = self.__dlq_policy.producer.produce(message)
-        self.__futures[message.partition].append((message, future))
+        self.__dlq_limit_state.update_invalid_value(message)
+        should_accept = self.__dlq_limit_state.should_accept(message)
+        if should_accept:
+            future = self.__dlq_policy.producer.produce(message)
+            self.__futures[message.partition].append((message, future))
 
     def flush(self, committable: Mapping[Partition, int]) -> None:
         """
         Blocks until all messages up to the committable have been produced so
         they are safe to commit.
         """
-        for (partition, offset) in committable.items():
+        for partition, offset in committable.items():
             while len(self.__futures[partition]) > 0:
                 values = self.__futures[partition]
                 msg, future = values[0]
diff --git a/arroyo/processing/processor.py b/arroyo/processing/processor.py
index 660e8761..09356451 100644
--- a/arroyo/processing/processor.py
+++ b/arroyo/processing/processor.py
@@ -216,6 +216,8 @@ def _create_strategy(partitions: Mapping[Partition, int]) -> None:
         def on_partitions_assigned(partitions: Mapping[Partition, int]) -> None:
             logger.info("New partitions assigned: %r", partitions)
             self.__buffered_messages.reset()
+            if self.__dlq_policy:
+                self.__dlq_policy.reset_offsets(partitions)
             if partitions:
                 if self.__processing_strategy is not None:
                     logger.exception(
diff --git a/tests/test_dlq.py b/tests/test_dlq.py
index 1b024521..d8adfca0 100644
--- a/tests/test_dlq.py
+++ b/tests/test_dlq.py
@@ -9,6 +9,7 @@
 from arroyo.dlq import (
     BufferedMessages,
     DlqLimit,
+    DlqLimitState,
     DlqPolicy,
     DlqPolicyWrapper,
     InvalidMessage,
@@ -102,12 +103,12 @@ def test_dlq_policy_wrapper() -> None:
     dlq_policy = DlqPolicy(
         KafkaDlqProducer(broker.get_producer(), dlq_topic), DlqLimit(), None
     )
+    partition = Partition(topic, 0)
     wrapper = DlqPolicyWrapper(dlq_policy)
+    wrapper.reset_offsets({partition: 0})
     wrapper.MAX_PENDING_FUTURES = 1
     for i in range(10):
-        message = BrokerValue(
-            KafkaPayload(None, b"", []), Partition(topic, 0), i, datetime.now()
-        )
+        message = BrokerValue(KafkaPayload(None, b"", []), partition, i, datetime.now())
         wrapper.produce(message)
     wrapper.flush({partition: 11})
 
@@ -117,3 +118,21 @@ def test_invalid_message_pickleable() -> None:
     pickled_exc = pickle.dumps(exc)
     unpickled_exc = pickle.loads(pickled_exc)
     assert exc == unpickled_exc
+
+
+def test_dlq_limit_state() -> None:
+    starting_offset = 2
+    partition = Partition(Topic("test_topic"), 0)
+    last_invalid_offset = {partition: starting_offset}
+    limit = DlqLimit(None, 5)
+    state = DlqLimitState(limit, last_invalid_offsets=last_invalid_offset)
+
+    # 1 valid message followed by 4 invalid
+    for i in range(4, 9):
+        value = BrokerValue(i, partition, i, datetime.now())
+        state.update_invalid_value(value)
+        assert state.should_accept(value)
+
+    # Next message should not be accepted
+    state.update_invalid_value(BrokerValue(9, partition, 9, datetime.now()))
+    assert state.should_accept(value) == False