Skip to content

Commit

Permalink
disconnect on reconnect() if connected
Browse files Browse the repository at this point in the history
  • Loading branch information
vladak committed Feb 9, 2025
1 parent 1778c7c commit dc36134
Show file tree
Hide file tree
Showing 2 changed files with 214 additions and 2 deletions.
11 changes: 9 additions & 2 deletions adafruit_minimqtt/adafruit_minimqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,11 +939,18 @@ def reconnect(self, resub_topics: bool = True) -> int:
"""

self.logger.debug("Attempting to reconnect with MQTT broker")
subscribed_topics = []
if self.is_connected():
# disconnect() will reset subscribed topics so stash them now.
if resub_topics:
subscribed_topics = self._subscribed_topics.copy()
self.disconnect()

ret = self.connect()
self.logger.debug("Reconnected with broker")
if resub_topics:

if resub_topics and subscribed_topics:
self.logger.debug("Attempting to resubscribe to previously subscribed topics.")
subscribed_topics = self._subscribed_topics.copy()
self._subscribed_topics = []
while subscribed_topics:
feed = subscribed_topics.pop()
Expand Down
205 changes: 205 additions & 0 deletions tests/test_reconnect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
# SPDX-FileCopyrightText: 2025 Vladimír Kotal
#
# SPDX-License-Identifier: Unlicense

"""reconnect tests"""

import logging
import ssl
import sys

import pytest
from mocket import Mocket

import adafruit_minimqtt.adafruit_minimqtt as MQTT

if not sys.implementation.name == "circuitpython":
from typing import Optional

from circuitpython_typing.socket import (
SocketType,
SSLContextType,
)


class FakeConnectionManager:
"""
Fake ConnectionManager class
"""

def __init__(self, socket):
self._socket = socket

def get_socket( # noqa: PLR0913, Too many arguments
self,
host: str,
port: int,
proto: str,
session_id: Optional[str] = None,
*,
timeout: float = 1.0,
is_ssl: bool = False,
ssl_context: Optional[SSLContextType] = None,
) -> SocketType:
"""
Return the specified socket.
"""
return self._socket

def close_socket(self, socket) -> None:
pass


def handle_subscribe(client, user_data, topic, qos):
"""
Record topics into user data.
"""
assert topic
assert user_data["topics"] is not None
assert qos == 0

user_data["topics"].append(topic)


def handle_disconnect(client, user_data, zero):
"""
Record disconnect.
"""

user_data["disconnect"] = True


# The MQTT packet contents below were captured using Mosquitto client+server.
testdata = [
(
[],
bytearray(
[
0x20, # CONNACK
0x02,
0x00,
0x00,
0x90, # SUBACK
0x03,
0x00,
0x01,
0x00,
0x20, # CONNACK
0x02,
0x00,
0x00,
0x90, # SUBACK
0x03,
0x00,
0x02,
0x00,
]
),
),
(
[("foo/bar", 0)],
bytearray(
[
0x20, # CONNACK
0x02,
0x00,
0x00,
0x90, # SUBACK
0x03,
0x00,
0x01,
0x00,
0x20, # CONNACK
0x02,
0x00,
0x00,
0x90, # SUBACK
0x03,
0x00,
0x02,
0x00,
]
),
),
(
[("foo/bar", 0), ("bah", 0)],
bytearray(
[
0x20, # CONNACK
0x02,
0x00,
0x00,
0x90, # SUBACK
0x03,
0x00,
0x01,
0x00,
0x00,
0x20, # CONNACK
0x02,
0x00,
0x00,
0x90, # SUBACK
0x03,
0x00,
0x02,
0x00,
0x90, # SUBACK
0x03,
0x00,
0x03,
0x00,
]
),
),
]


@pytest.mark.parametrize(
"topics,to_send",
testdata,
ids=[
"no_topic",
"single_topic",
"multi_topic",
],
)
def test_reconnect(topics, to_send) -> None:
"""
Test reconnect() handling, mainly that it performs disconnect on already connected socket.
Nothing will travel over the wire, it is all fake.
"""
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

host = "localhost"
port = 1883

user_data = {"topics": [], "disconnect": False}
mqtt_client = MQTT.MQTT(
broker=host,
port=port,
ssl_context=ssl.create_default_context(),
connect_retries=1,
user_data=user_data,
)

mocket = Mocket(to_send)
mqtt_client._connection_manager = FakeConnectionManager(mocket)
mqtt_client.connect()

mqtt_client.logger = logger

if topics:
logger.info(f"subscribing to {topics}")
mqtt_client.subscribe(topics)

logger.info("reconnecting")
mqtt_client.on_subscribe = handle_subscribe
mqtt_client.on_disconnect = handle_disconnect
mqtt_client.reconnect()

assert user_data.get("disconnect") == True
assert set(user_data.get("topics")) == set([t[0] for t in topics])

0 comments on commit dc36134

Please sign in to comment.