diff --git a/adafruit_minimqtt/adafruit_minimqtt.py b/adafruit_minimqtt/adafruit_minimqtt.py index e74d391..27413f6 100644 --- a/adafruit_minimqtt/adafruit_minimqtt.py +++ b/adafruit_minimqtt/adafruit_minimqtt.py @@ -353,7 +353,6 @@ def _get_connect_socket(self, host: str, port: int, *, timeout: int = 1): connect_host = host sock.settimeout(timeout) - last_exception = None try: sock.connect((connect_host, port)) except MemoryError as exc: @@ -363,10 +362,9 @@ def _get_connect_socket(self, host: str, port: int, *, timeout: int = 1): raise TemporaryError from exc except OSError as exc: sock.close() - last_exception = exc - - if last_exception: - raise last_exception + self.logger.warning(f"Failed to connect: {exc}") + # Do not consider this for back-off. + raise TemporaryError from exc self._backwards_compatible_sock = not hasattr(sock, "recv_into") return sock @@ -543,10 +541,6 @@ def connect( except TemporaryError as e: self.logger.warning(f"temporary error when connecting: {e}") backoff = False - except OSError as e: - last_exception = e - self.logger.info(f"failed to connect: {e}") - backoff = True except MMQTTException as e: last_exception = e self.logger.info(f"MMQT error: {e}") diff --git a/tests/test_port_ssl.py b/tests/test_port_ssl.py index 8474b56..6263dbf 100644 --- a/tests/test_port_ssl.py +++ b/tests/test_port_ssl.py @@ -20,7 +20,7 @@ class PortSslSetup(TestCase): def test_default_port(self) -> None: """verify default port value and that TLS is not used""" host = "127.0.0.1" - port = 1883 + expected_port = 1883 with patch.object(socket.socket, "connect") as connect_mock: ssl_context = ssl.create_default_context() @@ -31,14 +31,15 @@ def test_default_port(self) -> None: connect_retries=1, ) + connect_mock.side_effect = OSError ssl_mock = Mock() ssl_context.wrap_socket = ssl_mock with self.assertRaises(MQTT.MMQTTException): - expected_port = port mqtt_client.connect() ssl_mock.assert_not_called() + connect_mock.assert_called() # Assuming the repeated calls will have the same arguments. connect_mock.assert_has_calls([call((host, expected_port))])