From 98549b0b1d25f50c43a3bdca5ea078634b4f628c Mon Sep 17 00:00:00 2001
From: Gabriel de Quadros Ligneul <gabriel.ligneul@olxbr.com>
Date: Fri, 13 Mar 2020 12:43:12 -0300
Subject: [PATCH] Add on_connection_fail callback

---
 barterdude/__init__.py                        |  4 +++-
 barterdude/hooks/__init__.py                  |  7 +++++++
 barterdude/hooks/healthcheck.py               | 16 ++++++++++++++-
 barterdude/hooks/logging.py                   |  8 ++++++++
 .../hooks/metrics/prometheus/__init__.py      |  5 +++++
 .../hooks/metrics/prometheus/definitions.py   | 20 +++++++++++++++++++
 barterdude/monitor.py                         |  7 +++++++
 requirements/requirements_base.txt            |  2 +-
 tests/test__init__.py                         |  7 +++++--
 tests/test_hooks/test__init__.py              |  2 ++
 tests/test_hooks/test_healthcheck.py          |  8 ++++++++
 tests/test_hooks/test_logging.py              | 18 +++++++++++++++++
 .../test_prometheus/test__init__.py           | 10 ++++++++++
 tests/test_monitor.py                         | 11 ++++++++++
 tests_integration/helpers.py                  |  3 +++
 .../test_rabbitmq_integration.py              | 18 +++++++++++++++++
 16 files changed, 141 insertions(+), 5 deletions(-)

diff --git a/barterdude/__init__.py b/barterdude/__init__.py
index 094728c..48a91e4 100644
--- a/barterdude/__init__.py
+++ b/barterdude/__init__.py
@@ -57,7 +57,9 @@ async def process_message(message: RabbitMQMessage):
                 type=RouteTypes.AMQP_RABBITMQ,
                 options={
                     Options.BULK_SIZE: coroutines,
-                    Options.BULK_FLUSH_INTERVAL: bulk_flush_interval
+                    Options.BULK_FLUSH_INTERVAL: bulk_flush_interval,
+                    Options.CONNECTION_FAIL_CALLBACK:
+                        monitor.dispatch_on_connection_fail,
                 }
             )
             async def wrapper(messages: RabbitMQMessage):
diff --git a/barterdude/hooks/__init__.py b/barterdude/hooks/__init__.py
index 229023c..ec3c4a6 100644
--- a/barterdude/hooks/__init__.py
+++ b/barterdude/hooks/__init__.py
@@ -17,6 +17,10 @@ async def on_fail(self, message: RabbitMQMessage, error: Exception):
     async def before_consume(self, message: RabbitMQMessage):
         '''Called before consuming the message'''
 
+    @abstractmethod
+    async def on_connection_fail(self, error: Exception, retries: int):
+        '''Called when the consumer fails to connect to the broker'''
+
 
 class HttpHook(BaseHook):
     def __init__(self, barterdude: BarterDude, path: str):
@@ -37,3 +41,6 @@ async def on_fail(self, message: RabbitMQMessage, error: Exception):
 
     async def before_consume(self, message: RabbitMQMessage):
         raise NotImplementedError
+
+    async def on_connection_fail(self, error: Exception, retries: int):
+        raise NotImplementedError
diff --git a/barterdude/hooks/healthcheck.py b/barterdude/hooks/healthcheck.py
index f9bffeb..172cee2 100644
--- a/barterdude/hooks/healthcheck.py
+++ b/barterdude/hooks/healthcheck.py
@@ -27,13 +27,16 @@ def __init__(
         barterdude: BarterDude,
         path: str = "/healthcheck",
         success_rate: float = 0.95,
-        health_window: float = 60.0  # seconds
+        health_window: float = 60.0,  # seconds
+        max_connection_fails: int = 3
     ):
         self.__success_rate = success_rate
         self.__health_window = health_window
         self.__success = deque()
         self.__fail = deque()
         self.__force_fail = False
+        self.__connection_fails = 0
+        self.__max_connection_fails = max_connection_fails
         super(Healthcheck, self).__init__(barterdude, path)
 
     def force_fail(self):
@@ -48,12 +51,23 @@ async def on_success(self, message: RabbitMQMessage):
     async def on_fail(self, message: RabbitMQMessage, error: Exception):
         self.__fail.append(time())
 
+    async def on_connection_fail(self, error: Exception, retries: int):
+        self.__connection_fails = retries
+
     async def __call__(self, req: web.Request):
         if self.__force_fail:
             return _response(500, {
                 "message": "Healthcheck fail called manually"
             })
 
+        if self.__connection_fails >= self.__max_connection_fails:
+            return _response(500, {
+                "message": (
+                    "Reached max connection fails "
+                    f"({self.__max_connection_fails})"
+                )
+            })
+
         old_timestamp = time() - self.__health_window
         success = _remove_old(self.__success, old_timestamp)
         fail = _remove_old(self.__fail, old_timestamp)
diff --git a/barterdude/hooks/logging.py b/barterdude/hooks/logging.py
index 542ee92..cf18408 100644
--- a/barterdude/hooks/logging.py
+++ b/barterdude/hooks/logging.py
@@ -30,3 +30,11 @@ async def on_fail(self, message: RabbitMQMessage, error: Exception):
             "exception": repr(error),
             "traceback": format_tb(error.__traceback__),
         })
+
+    async def on_connection_fail(self, error: Exception, retries: int):
+        logger.error({
+            "message": "Failed to connect to the broker",
+            "retries": retries,
+            "exception": repr(error),
+            "traceback": format_tb(error.__traceback__),
+        })
diff --git a/barterdude/hooks/metrics/prometheus/__init__.py b/barterdude/hooks/metrics/prometheus/__init__.py
index 378b81b..571611e 100644
--- a/barterdude/hooks/metrics/prometheus/__init__.py
+++ b/barterdude/hooks/metrics/prometheus/__init__.py
@@ -71,6 +71,11 @@ async def on_success(self, message: RabbitMQMessage):
     async def on_fail(self, message: RabbitMQMessage, error: Exception):
         await self._on_complete(message, self.__definitions.FAIL, error)
 
+    async def on_connection_fail(self, error: Exception, retries: int):
+        self.metrics[self.__definitions.CONNECTION_FAIL].labels(
+            **self.__labels
+        ).inc()
+
     async def __call__(self, req: web.Request):
         return web.Response(
             content_type=CONTENT_TYPE_LATEST.split(";")[0],
diff --git a/barterdude/hooks/metrics/prometheus/definitions.py b/barterdude/hooks/metrics/prometheus/definitions.py
index 3390ec5..034ae01 100644
--- a/barterdude/hooks/metrics/prometheus/definitions.py
+++ b/barterdude/hooks/metrics/prometheus/definitions.py
@@ -11,12 +11,14 @@
 class Definitions:
 
     MESSAGE_UNITS = "messages"
+    ERROR_UNITS = "errors"
     TIME_UNITS = "seconds"
     NAMESPACE = "barterdude"
     BEFORE_CONSUME = "before_consume"
     SUCCESS = "success"
     FAIL = "fail"
     TIME_MEASURE = "time_measure"
+    CONNECTION_FAIL = "connection_fail"
 
     def __init__(
         self,
@@ -51,6 +53,11 @@ def save_metrics(self):
             namespace=self.NAMESPACE,
             unit=self.TIME_UNITS,
         )
+        self._prepare_on_connection_fail(
+            self.CONNECTION_FAIL,
+            namespace=self.NAMESPACE,
+            unit=self.ERROR_UNITS,
+        )
 
     def _prepare_before_consume(
             self, name: str, namespace: str = "", unit: str = ""):
@@ -90,3 +97,16 @@ def _prepare_time_measure(
                 unit=unit,
                 registry=self.__registry,
             )
+
+    def _prepare_on_connection_fail(
+            self, state: str, namespace: str, unit: str):
+
+        self.__metrics[state] = Counter(
+                name=f"connection_fail",
+                documentation=("Number of times barterdude failed "
+                               "to connect to the AMQP broker"),
+                labelnames=self.__labelkeys,
+                namespace=namespace,
+                unit=unit,
+                registry=self.__registry
+            )
diff --git a/barterdude/monitor.py b/barterdude/monitor.py
index d417c09..8c05575 100644
--- a/barterdude/monitor.py
+++ b/barterdude/monitor.py
@@ -43,3 +43,10 @@ async def dispatch_on_success(self, message: RabbitMQMessage):
     async def dispatch_on_fail(self, message: RabbitMQMessage,
                                error: Exception):
         await gather(*self._prepare_callbacks("on_fail", message, error))
+
+    async def dispatch_on_connection_fail(
+            self, error: Exception, retries: int
+    ):
+        await gather(*self._prepare_callbacks(
+            "on_connection_fail", error, retries
+        ))
diff --git a/requirements/requirements_base.txt b/requirements/requirements_base.txt
index f73a30e..92bad54 100644
--- a/requirements/requirements_base.txt
+++ b/requirements/requirements_base.txt
@@ -1,3 +1,3 @@
-async-worker==0.11.3
+async-worker==0.12.1
 aioamqp==0.14.0
 python-json-logger==0.1.11
diff --git a/tests/test__init__.py b/tests/test__init__.py
index d3eb52b..0be66db 100644
--- a/tests/test__init__.py
+++ b/tests/test__init__.py
@@ -36,15 +36,18 @@ def test_should_create_connection(self):
         self.App.assert_called_once_with(connections=[self.connection])
 
     def test_should_call_route_when_created(self):
+        monitor = Mock()
         self.barterdude.consume_amqp(
-            ["queue"]
+            ["queue"], monitor=monitor
         )(CoroutineMock())
         self.app.route.assert_called_once_with(
             ["queue"],
             type=RouteTypes.AMQP_RABBITMQ,
             options={
                 Options.BULK_SIZE: 10,
-                Options.BULK_FLUSH_INTERVAL: 60
+                Options.BULK_FLUSH_INTERVAL: 60,
+                Options.CONNECTION_FAIL_CALLBACK:
+                    monitor.dispatch_on_connection_fail,
             }
         )
 
diff --git a/tests/test_hooks/test__init__.py b/tests/test_hooks/test__init__.py
index 105bc3d..0245010 100644
--- a/tests/test_hooks/test__init__.py
+++ b/tests/test_hooks/test__init__.py
@@ -31,3 +31,5 @@ async def test_should_fail_when_calling_unimplemented_methods(self):
             await hook.on_success(None)
         with self.assertRaises(NotImplementedError):
             await hook.on_fail(None, None)
+        with self.assertRaises(NotImplementedError):
+            await hook.on_connection_fail(None, None)
diff --git a/tests/test_hooks/test_healthcheck.py b/tests/test_hooks/test_healthcheck.py
index 39ef9be..b8114df 100644
--- a/tests/test_hooks/test_healthcheck.py
+++ b/tests/test_hooks/test_healthcheck.py
@@ -99,3 +99,11 @@ async def test_should_erase_old_messages(self):
                 '{"message": "Success rate: 0.125 (expected: 0.9)", '
                 '"fail": 7, "success": 1, "status": "fail"}'
             )
+
+    async def test_should_fail_healthcheck_when_fail_to_connect(self):
+        await self.healthcheck.on_connection_fail(None, 3)
+        response = await self.healthcheck(Mock())
+        self.assertEqual(
+            response.body._value.decode('utf-8'),
+            '{"message": "Reached max connection fails (3)", "status": "fail"}'
+        )
diff --git a/tests/test_hooks/test_logging.py b/tests/test_hooks/test_logging.py
index 6576e91..c4efc01 100644
--- a/tests/test_hooks/test_logging.py
+++ b/tests/test_hooks/test_logging.py
@@ -48,3 +48,21 @@ async def test_should_log_on_fail(self, format_tb, repr, dumps, logger):
             "exception": repr.return_value,
             "traceback": format_tb.return_value,
         })
+
+    @patch("barterdude.hooks.logging.logger")
+    @patch("barterdude.hooks.logging.repr")
+    @patch("barterdude.hooks.logging.format_tb")
+    async def test_should_log_on_connection_fail(
+        self, format_tb, repr, logger
+    ):
+        retries = Mock()
+        exception = Exception()
+        await self.logging.on_connection_fail(exception, retries)
+        repr.assert_called_once_with(exception)
+        format_tb.assert_called_once_with(exception.__traceback__)
+        logger.error.assert_called_once_with({
+            "message": "Failed to connect to the broker",
+            "retries": retries,
+            "exception": repr.return_value,
+            "traceback": format_tb.return_value,
+        })
diff --git a/tests/test_hooks/test_metrics/test_prometheus/test__init__.py b/tests/test_hooks/test_metrics/test_prometheus/test__init__.py
index 0f3b710..0e85e80 100644
--- a/tests/test_hooks/test_metrics/test_prometheus/test__init__.py
+++ b/tests/test_hooks/test_metrics/test_prometheus/test__init__.py
@@ -139,6 +139,16 @@ async def test_should_remove_message_from_timer_on_complete(
             {}
         )
 
+    @patch("barterdude.hooks.metrics.prometheus.Prometheus.metrics")
+    async def test_should_increment_on_connection_fail_counter(self, metrics):
+        counter = Mock()
+        labels = Mock(labels=Mock(return_value=counter))
+        metrics.__getitem__ = Mock(return_value=labels)
+        await self.prometheus.on_connection_fail(Mock(), Mock())
+        self.assertEqual(labels.labels.call_count, 1)
+        labels.labels.assert_called_with(test='my_test')
+        counter.inc.assert_called_once()
+
     def test_should_call_counter(self):
         self.assertTrue(
             isinstance(self.prometheus.metrics.counter(
diff --git a/tests/test_monitor.py b/tests/test_monitor.py
index c042576..f952f05 100644
--- a/tests/test_monitor.py
+++ b/tests/test_monitor.py
@@ -37,6 +37,17 @@ async def test_should_call_hooks_on_fail(self):
         self.hook1.on_fail.assert_called_with({}, exception)
         self.hook2.on_fail.assert_called_with({}, exception)
 
+    async def test_should_call_hooks_on_connection_fail(self):
+        exception = Mock()
+        retries = Mock()
+        self.hook1.on_connection_fail = CoroutineMock()
+        self.hook2.on_connection_fail = CoroutineMock()
+        await self.monitor.dispatch_on_connection_fail(exception, retries)
+        self.hook1.on_connection_fail.assert_called_once()
+        self.hook2.on_connection_fail.assert_called_once()
+        self.hook1.on_connection_fail.assert_called_with(exception, retries)
+        self.hook2.on_connection_fail.assert_called_with(exception, retries)
+
     @patch("barterdude.monitor.logger")
     @patch("barterdude.monitor.repr")
     @patch("barterdude.monitor.format_tb")
diff --git a/tests_integration/helpers.py b/tests_integration/helpers.py
index 2dfd8c5..e2a378c 100644
--- a/tests_integration/helpers.py
+++ b/tests_integration/helpers.py
@@ -11,3 +11,6 @@ async def on_fail(self, message: RabbitMQMessage, error: Exception):
 
     async def before_consume(self, message: RabbitMQMessage):
         raise NotImplementedError
+
+    async def on_connection_fail(self, error: Exception, retries: int):
+        raise NotImplementedError
diff --git a/tests_integration/test_rabbitmq_integration.py b/tests_integration/test_rabbitmq_integration.py
index c445e0e..c9afae1 100644
--- a/tests_integration/test_rabbitmq_integration.py
+++ b/tests_integration/test_rabbitmq_integration.py
@@ -340,3 +340,21 @@ async def handler(message):
         self.assertIn("'delivery_tag': 1", cm.output[1])
         self.assertIn(f"'exception': \"{error_str}\"", cm.output[1])
         self.assertIn("'traceback': [", cm.output[1])
+
+    async def test_fails_to_connect_to_rabbitmq(self):
+        monitor = Monitor(Logging())
+
+        self.app = BarterDude(hostname="invalid_host")
+
+        @self.app.consume_amqp([self.input_queue], monitor)
+        async def handler(message):
+            pass
+
+        await self.app.startup()
+        with self.assertLogs("barterdude") as cm:
+            await asyncio.sleep(2)
+
+        self.assertIn(
+            "{'message': 'Failed to connect to the broker', 'retries': 1,",
+            cm.output[0]
+        )