diff --git a/aws_sqs_consumer/consumer.py b/aws_sqs_consumer/consumer.py index 28f1244..6e22079 100644 --- a/aws_sqs_consumer/consumer.py +++ b/aws_sqs_consumer/consumer.py @@ -6,6 +6,9 @@ import boto3 import time import traceback +import threading +import atexit +from uuid import uuid4 from typing import List from .error import SQSException @@ -27,7 +30,10 @@ def __init__( batch_size=1, wait_time_seconds=1, visibility_timeout_seconds=None, - polling_wait_time_ms=0 + polling_wait_time_ms=0, + daemon: bool = True, + thread_name: str = "consumer", + threaded: bool = True ): self.queue_url = queue_url self.attribute_names = attribute_names @@ -41,6 +47,10 @@ def __init__( self.wait_time_seconds = wait_time_seconds self.visibility_timeout_seconds = visibility_timeout_seconds self.polling_wait_time_ms = polling_wait_time_ms + self.daemon = daemon + self.thread_name_prefix = "aws_sqs_thread" + thread_name + self._sqs_thread = None + self.threaded = threaded if region: self._sqs_client = sqs_client or boto3.client( "sqs", region_name=region) @@ -52,6 +62,7 @@ def __init__( raise Exception("Please specify the region parameter or set \ AWS_DEFAULT_REGION env variable.") self._running = False + atexit.register(self.stop) def handle_message(self, message: Message): """ @@ -107,7 +118,6 @@ def start(self): """ Start the consumer. """ - # TODO: Figure out threading/daemon self._running = True while self._running: response = self._sqs_client.receive_message( @@ -131,8 +141,24 @@ def stop(self): """ Stop the consumer. """ - # TODO: There's no way to invoke this other than a separate thread. self._running = False + if not self.daemon: + self._sqs_thread.join() + + def start_consumer(self): + """ + Starts the process of receiving sqs messages either in main + thread (if threaded=False) or separate thread (if threaded=True) + depending on threaded. + """ + if not self.threaded: + self.start() + else: + thread_name = self.thread_name_prefix + str(uuid4()) + self._sqs_thread = threading.Thread(target=self.start, + name=thread_name, + daemon=self.daemon) + self._sqs_thread.start() def _process_message(self, message: Message): try: diff --git a/pyproject.toml b/pyproject.toml index e73098c..81a2624 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "aws_sqs_consumer" -version = "0.0.15" +version = "0.0.16" description = "AWS SQS Consumer" authors = ["Hexmos Technology "] license = "MIT"