Skip to content

Commit

Permalink
[p] Cover azul.queues with mypy (#6821)
Browse files Browse the repository at this point in the history
  • Loading branch information
hannes-ucsc committed Feb 3, 2025
1 parent 5f09a71 commit cc0ef6f
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 24 deletions.
3 changes: 2 additions & 1 deletion .mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ modules =
azul.lambdas,
azul.modules,
azul.oauth2,
azul.objects
azul.objects,
azul.queues
packages =
azul.openapi

Expand Down
48 changes: 25 additions & 23 deletions src/azul/queues.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import builtins
from collections import (
deque,
)
Expand All @@ -23,14 +24,15 @@
)
import os
import time
from typing import (
Any,
)

import more_itertools
from more_itertools import (
one,
)
from mypy_boto3_sqs.service_resource import (
Message,
Queue,
)

from azul import (
cached_property,
Expand All @@ -51,8 +53,6 @@

log = logging.getLogger(__name__)

Queue = Any # place-holder for boto3's SQS queue resource


class Queues:

Expand Down Expand Up @@ -92,8 +92,8 @@ def _dump(self, queue, path):
log.info(f'Finished writing {path!r}')
self._cleanup_messages(queue, messages)

def _get_messages(self, queue):
messages = []
def _get_messages(self, queue: Queue) -> builtins.list[Message]:
messages: list[Message] = []
while True:
message_batch = queue.receive_messages(AttributeNames=['All'],
MaxNumberOfMessages=10,
Expand Down Expand Up @@ -202,14 +202,14 @@ def _get_queue_lengths(self,
dictionary mapping each queue's name to the number of messages
in that queue.
"""
attributes = [
'ApproximateNumberOfMessages' + suffix
for suffix in ('', 'NotVisible', 'Delayed')
]
total, lengths = 0, {}
for queue_name, queue in queues.items():
queue.reload()
message_counts = [int(queue.attributes[attribute]) for attribute in attributes]
message_counts = [
int(queue.attributes['ApproximateNumberOfMessages']),
int(queue.attributes['ApproximateNumberOfMessagesNotVisible']),
int(queue.attributes['ApproximateNumberOfMessagesDelayed']),
]
length = sum(message_counts)
log.debug('Queue %s has %i message(s) (%i available, %i in flight and %i delayed).',
queue_name, length, *message_counts)
Expand All @@ -230,11 +230,12 @@ def wait_to_stabilize(self) -> int:
timeout = max(config.contribution_lambda_timeout(retry=True),
config.aggregation_lambda_timeout(retry=True))
queues = self.get_queues(config.work_queue_names)
total_lengths = deque(maxlen=ceil(timeout / sleep_time))
maxlen = ceil(timeout / sleep_time)
total_lengths: deque[int] = deque(maxlen=maxlen)
# Two minutes to safely accommodate SQS eventual consistency window of
# one minute. For more info, read WARNING section on
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sqs.html#SQS.Client.get_queue_attributes
assert total_lengths.maxlen * sleep_time >= 2 * 60
assert maxlen * sleep_time >= 2 * 60

while True:
# Determine queue lengths
Expand All @@ -246,7 +247,7 @@ def wait_to_stabilize(self) -> int:
list(reversed(total_lengths)))

min_num_zeros = 60 // sleep_time
assert min_num_zeros <= total_lengths.maxlen, min_num_zeros
assert min_num_zeros <= maxlen, min_num_zeros
num_total_lengths = len(total_lengths)
if num_total_lengths >= min_num_zeros:
if not any(islice(reversed(total_lengths), min_num_zeros)):
Expand Down Expand Up @@ -345,10 +346,12 @@ def _wait_for_queue_idle(self, queue: Queue):
queue.reload()

def _wait_for_queue_empty(self, queue: Queue):
# Gotta have some fun some of the time
attribute_names = tuple(map('ApproximateNumberOfMessages'.__add__, ('', 'Delayed', 'NotVisible')))
while True:
num_messages = sum(map(int, map(queue.attributes.get, attribute_names)))
num_messages = (
int(queue.attributes['ApproximateNumberOfMessages']) +
int(queue.attributes['ApproximateNumberOfMessagesDelayed']) +
int(queue.attributes['ApproximateNumberOfMessages'])
)
if num_messages == 0:
break
log.info('Queue %r still has %i messages', queue.url, num_messages)
Expand All @@ -359,16 +362,15 @@ def _manage_sqs_push(self, function_name, queue, enable: bool):
lambda_ = aws.lambda_
response = lambda_.list_event_source_mappings(FunctionName=function_name,
EventSourceArn=queue.attributes['QueueArn'])
mapping = one(response['EventSourceMappings'])
mapping_uuid = one(response['EventSourceMappings'])['UUID']

def update_():
log.info('%s push from %r to lambda function %r',
'Enabling' if enable else 'Disabling', queue.url, function_name)
lambda_.update_event_source_mapping(UUID=mapping['UUID'],
Enabled=enable)
lambda_.update_event_source_mapping(UUID=mapping_uuid, Enabled=enable)

state = one(response['EventSourceMappings'])['State']
while True:
state = mapping['State']
log.info('Push from %r to lambda function %r is in state %r.',
queue.url, function_name, state)
if state in ('Disabling', 'Enabling', 'Updating'):
Expand All @@ -386,7 +388,7 @@ def update_():
else:
raise NotImplementedError(state)
time.sleep(3)
mapping = lambda_.get_event_source_mapping(UUID=mapping['UUID'])
state = lambda_.get_event_source_mapping(UUID=mapping_uuid)['State']

def functions_by_queue(self) -> Mapping[str, str]:
"""
Expand Down

0 comments on commit cc0ef6f

Please sign in to comment.