diff --git a/adafruit_minimqtt/adafruit_minimqtt.py b/adafruit_minimqtt/adafruit_minimqtt.py index 93926b4..00bace2 100644 --- a/adafruit_minimqtt/adafruit_minimqtt.py +++ b/adafruit_minimqtt/adafruit_minimqtt.py @@ -47,10 +47,13 @@ # MQTT Commands MQTT_PINGREQ = b"\xc0\0" MQTT_PINGRESP = const(0xD0) +MQTT_PUBLISH = const(0x30) MQTT_SUB = b"\x82" MQTT_UNSUB = b"\xA2" MQTT_DISCONNECT = b"\xe0\0" +MQTT_PKT_TYPE_MASK = const(0xF0) + # Variable CONNECT header [MQTT 3.1.2] MQTT_HDR_CONNECT = bytearray(b"\x04MQTT\x04\x02\0\0") @@ -210,7 +213,6 @@ def __init__( # LWT self._lw_topic = None self._lw_qos = 0 - self._lw_topic = None self._lw_msg = None self._lw_retain = False @@ -628,7 +630,7 @@ def publish(self, topic, msg, retain=False, qos=0): ), "Quality of Service Level 2 is unsupported by this library." # fixed header. [3.3.1.2], [3.3.1.3] - pub_hdr_fixed = bytearray([0x30 | retain | qos << 1]) + pub_hdr_fixed = bytearray([MQTT_PUBLISH | retain | qos << 1]) # variable header = 2-byte Topic length (big endian) pub_hdr_var = bytearray(struct.pack(">H", len(topic.encode("utf-8")))) @@ -877,7 +879,9 @@ def loop(self, timeout=0): def _wait_for_msg(self, timeout=0.1): # pylint: disable = too-many-return-statements - """Reads and processes network events.""" + """Reads and processes network events. + Return the packet type or None if there is nothing to be received. + """ # CPython socket module contains a timeout attribute if hasattr(self._socket_pool, "timeout"): try: @@ -898,7 +902,7 @@ def _wait_for_msg(self, timeout=0.1): if res in [None, b"", b"\x00"]: # If we get here, it means that there is nothing to be received return None - if res[0] == MQTT_PINGRESP: + if res[0] & MQTT_PKT_TYPE_MASK == MQTT_PINGRESP: if self.logger is not None: self.logger.debug("Got PINGRESP") sz = self._sock_exact_recv(1)[0] @@ -907,12 +911,21 @@ def _wait_for_msg(self, timeout=0.1): "Unexpected PINGRESP returned from broker: {}.".format(sz) ) return MQTT_PINGRESP - if res[0] & 0xF0 != 0x30: + + if res[0] & MQTT_PKT_TYPE_MASK != MQTT_PUBLISH: return res[0] + + # Handle only the PUBLISH packet type from now on. sz = self._recv_len() # topic length MSB & LSB topic_len = self._sock_exact_recv(2) topic_len = (topic_len[0] << 8) | topic_len[1] + + if topic_len > sz - 2: + raise MMQTTException( + f"Topic length {topic_len} in PUBLISH packet exceeds remaining length {sz} - 2" + ) + topic = self._sock_exact_recv(topic_len) topic = str(topic, "utf-8") sz -= topic_len + 2 @@ -921,12 +934,13 @@ def _wait_for_msg(self, timeout=0.1): pid = self._sock_exact_recv(2) pid = pid[0] << 0x08 | pid[1] sz -= 0x02 + # read message contents raw_msg = self._sock_exact_recv(sz) msg = raw_msg if self._use_binary_mode else str(raw_msg, "utf-8") if self.logger is not None: self.logger.debug( - "Receiving SUBSCRIBE \nTopic: %s\nMsg: %s\n", topic, raw_msg + "Receiving PUBLISH \nTopic: %s\nMsg: %s\n", topic, raw_msg ) self._handle_on_message(self, topic, msg) if res[0] & 0x06 == 0x02: @@ -935,6 +949,7 @@ def _wait_for_msg(self, timeout=0.1): self._sock.send(pkt) elif res[0] & 6 == 4: assert 0 + return res[0] def _recv_len(self):