diff --git a/faust/transport/conductor.py b/faust/transport/conductor.py index cda7858f1..21826465e 100644 --- a/faust/transport/conductor.py +++ b/faust/transport/conductor.py @@ -214,6 +214,12 @@ class Conductor(ConductorT, Service): #: to call here. _tp_to_callback: MutableMapping[TP, ConsumerCallback] + #: Lock used to synchronize access to _tp_to_callback. + #: Resubscriptions and updates to the indices may modify the mapping, and + #: while that is happening, the mapping should not be accessed by message + #: handlers. + _tp_to_callback_lock: asyncio.Lock + #: Whenever a change is made, i.e. a Topic is added/removed, we notify #: the background task responsible for resubscribing. _subscription_changed: Optional[asyncio.Event] @@ -235,6 +241,7 @@ def __init__(self, app: AppT, **kwargs: Any) -> None: self._topic_name_index = defaultdict(set) self._tp_index = defaultdict(set) self._tp_to_callback = {} + self._tp_to_callback_lock = asyncio.Lock() self._acking_topics = set() self._subscription_changed = None self._subscription_done = None @@ -266,12 +273,18 @@ def _compile_message_handler(self) -> ConsumerCallback: async def on_message(message: Message) -> None: tp = TP(topic=message.topic, partition=0) - return await get_callback_for_tp(tp)(message) + async with self._tp_to_callback_lock: + callback = get_callback_for_tp(tp) + + return await callback(message) else: async def on_message(message: Message) -> None: - return await get_callback_for_tp(message.tp)(message) + async with self._tp_to_callback_lock: + callback = get_callback_for_tp(message.tp) + + return await callback(message) return on_message @@ -309,11 +322,14 @@ async def _subscriber(self) -> None: # pragma: no cover # further subscription requests will happen during the same # rebalance. await self.sleep(self._resubscribe_sleep_lock_seconds) + + # Clear the event before updating indices. This way, new events + # that get triggered during the update will be handled the next + # time around. + ev.clear() subscribed_topics = await self._update_indices() await self.app.consumer.subscribe(subscribed_topics) - # clear the subscription_changed flag, so we can wait on it again. - ev.clear() # wake-up anything waiting for the subscription to be done. notify(self._subscription_done) @@ -328,15 +344,23 @@ async def maybe_wait_for_subscriptions(self) -> None: await self._subscription_done async def _update_indices(self) -> Iterable[str]: - self._topic_name_index.clear() - self._tp_to_callback.clear() - for channel in self._topics: - if channel.internal: - await channel.maybe_declare() - for topic in channel.topics: - if channel.acks: - self._acking_topics.add(topic) - self._topic_name_index[topic].add(channel) + async with self._tp_to_callback_lock: + self._topic_name_index.clear() + self._tp_to_callback.clear() + + # Make a (shallow) copy of the topics, so new additions to the set + # won't poison the iterator. Additions can come in while this + # function yields during an await. + topics = list(self._topics) + for channel in topics: + if channel.internal: + await channel.maybe_declare() + for topic in channel.topics: + if channel.acks: + self._acking_topics.add(topic) + self._topic_name_index[topic].add(channel) + + self._update_callback_map() return self._topic_name_index @@ -418,6 +442,7 @@ def _topic_contain_unsubscribed_topics(self, topic: TopicT) -> bool: def discard(self, topic: Any) -> None: """Unregister topic from conductor.""" self._topics.discard(topic) + self._flag_changes() def _flag_changes(self) -> None: if self._subscription_changed is not None: