From f114872edf5bfcb4a7238b538dc3c6966a82482c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Villac=C3=ADs=20Lasso?= Date: Wed, 13 Oct 2021 17:20:23 -0500 Subject: [PATCH 1/9] Wrap ESP8266-specific SSL code This is required for successful compilation in ESP32. This effectively disables fingerprint functionality on ESP32. Hopefully this will be restored in a future commit. --- src/AsyncMqttClient.cpp | 4 +++- src/AsyncMqttClient.hpp | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/AsyncMqttClient.cpp b/src/AsyncMqttClient.cpp index b4375cf..4362c5b 100644 --- a/src/AsyncMqttClient.cpp +++ b/src/AsyncMqttClient.cpp @@ -188,15 +188,17 @@ void AsyncMqttClient::_onConnect() { log_i("TCP conn, MQTT CONNECT"); #if ASYNC_TCP_SSL_ENABLED if (_secure && _secureServerFingerprints.size() > 0) { + bool sslFoundFingerprint = false; +#ifdef ESP8266 SSL* clientSsl = _client.getSSL(); - bool sslFoundFingerprint = false; for (std::array fingerprint : _secureServerFingerprints) { if (ssl_match_fingerprint(clientSsl, fingerprint.data()) == SSL_OK) { sslFoundFingerprint = true; break; } } +#endif if (!sslFoundFingerprint) { _disconnectReason = AsyncMqttClientDisconnectReason::TLS_BAD_FINGERPRINT; diff --git a/src/AsyncMqttClient.hpp b/src/AsyncMqttClient.hpp index 1e81103..b826a7b 100644 --- a/src/AsyncMqttClient.hpp +++ b/src/AsyncMqttClient.hpp @@ -19,7 +19,9 @@ #endif #if ASYNC_TCP_SSL_ENABLED +#ifdef ESP8266 #include +#endif #define SHA1_SIZE 20 #endif From 1ffa2498268c6ec219c4533fc15913626f100ae2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Villac=C3=ADs=20Lasso?= Date: Wed, 13 Oct 2021 17:56:47 -0500 Subject: [PATCH 2/9] Further hide ESP8266-specific API on ESP32 No point exposing a non-working API on ESP32. --- src/AsyncMqttClient.cpp | 8 ++++++-- src/AsyncMqttClient.hpp | 4 ++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/AsyncMqttClient.cpp b/src/AsyncMqttClient.cpp index 4362c5b..d72e9c1 100644 --- a/src/AsyncMqttClient.cpp +++ b/src/AsyncMqttClient.cpp @@ -29,8 +29,10 @@ AsyncMqttClient::AsyncMqttClient() , _willQos(0) , _willRetain(false) #if ASYNC_TCP_SSL_ENABLED +#ifdef ESP8266 , _secureServerFingerprints() #endif +#endif , _onConnectUserCallbacks() , _onDisconnectUserCallbacks() , _onSubscribeUserCallbacks() @@ -130,6 +132,7 @@ AsyncMqttClient& AsyncMqttClient::setSecure(bool secure) { return *this; } +#ifdef ESP8266 AsyncMqttClient& AsyncMqttClient::addServerFingerprint(const uint8_t* fingerprint) { std::array newFingerprint; memcpy(newFingerprint.data(), fingerprint, SHA1_SIZE); @@ -137,6 +140,7 @@ AsyncMqttClient& AsyncMqttClient::addServerFingerprint(const uint8_t* fingerprin return *this; } #endif +#endif AsyncMqttClient& AsyncMqttClient::onConnect(AsyncMqttClientInternals::OnConnectUserCallback callback) { _onConnectUserCallbacks.push_back(callback); @@ -187,9 +191,9 @@ void AsyncMqttClient::_clear() { void AsyncMqttClient::_onConnect() { log_i("TCP conn, MQTT CONNECT"); #if ASYNC_TCP_SSL_ENABLED +#ifdef ESP8266 if (_secure && _secureServerFingerprints.size() > 0) { bool sslFoundFingerprint = false; -#ifdef ESP8266 SSL* clientSsl = _client.getSSL(); for (std::array fingerprint : _secureServerFingerprints) { @@ -198,7 +202,6 @@ void AsyncMqttClient::_onConnect() { break; } } -#endif if (!sslFoundFingerprint) { _disconnectReason = AsyncMqttClientDisconnectReason::TLS_BAD_FINGERPRINT; @@ -206,6 +209,7 @@ void AsyncMqttClient::_onConnect() { return; } } +#endif #endif AsyncMqttClientInternals::OutPacket* msg = new AsyncMqttClientInternals::ConnectOutPacket(_cleanSession, diff --git a/src/AsyncMqttClient.hpp b/src/AsyncMqttClient.hpp index b826a7b..a544e67 100644 --- a/src/AsyncMqttClient.hpp +++ b/src/AsyncMqttClient.hpp @@ -67,7 +67,9 @@ class AsyncMqttClient { AsyncMqttClient& setServer(const char* host, uint16_t port); #if ASYNC_TCP_SSL_ENABLED AsyncMqttClient& setSecure(bool secure); +#ifdef ESP8266 AsyncMqttClient& addServerFingerprint(const uint8_t* fingerprint); +#endif #endif AsyncMqttClient& onConnect(AsyncMqttClientInternals::OnConnectUserCallback callback); @@ -123,7 +125,9 @@ class AsyncMqttClient { bool _willRetain; #if ASYNC_TCP_SSL_ENABLED +#ifdef ESP8266 std::vector> _secureServerFingerprints; +#endif #endif std::vector _onConnectUserCallbacks; From 8708f2a6d2f1c6cb465bdcf49bdbdaa51803fa65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Villac=C3=ADs=20Lasso?= Date: Wed, 13 Oct 2021 18:11:17 -0500 Subject: [PATCH 3/9] Add support for server authentication using root CA This support simply forwards the setRootCa from AsyncTCP. Compatible with most AsyncTCP forks as well as AsyncTCPSock. --- src/AsyncMqttClient.cpp | 6 ++++++ src/AsyncMqttClient.hpp | 3 +++ 2 files changed, 9 insertions(+) diff --git a/src/AsyncMqttClient.cpp b/src/AsyncMqttClient.cpp index d72e9c1..f184a2c 100644 --- a/src/AsyncMqttClient.cpp +++ b/src/AsyncMqttClient.cpp @@ -140,6 +140,12 @@ AsyncMqttClient& AsyncMqttClient::addServerFingerprint(const uint8_t* fingerprin return *this; } #endif +#ifdef ESP32 +AsyncMqttClient& AsyncMqttClient::setRootCa(const char* rootca, const size_t len) { + _client.setRootCa(rootca, len); + return *this; +} +#endif #endif AsyncMqttClient& AsyncMqttClient::onConnect(AsyncMqttClientInternals::OnConnectUserCallback callback) { diff --git a/src/AsyncMqttClient.hpp b/src/AsyncMqttClient.hpp index a544e67..1b294de 100644 --- a/src/AsyncMqttClient.hpp +++ b/src/AsyncMqttClient.hpp @@ -70,6 +70,9 @@ class AsyncMqttClient { #ifdef ESP8266 AsyncMqttClient& addServerFingerprint(const uint8_t* fingerprint); #endif +#ifdef ESP32 + AsyncMqttClient& setRootCa(const char* rootca, const size_t len); +#endif #endif AsyncMqttClient& onConnect(AsyncMqttClientInternals::OnConnectUserCallback callback); From a61b3afe41a23ef90f7c92bd10c54fe1f9929ca6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Villac=C3=ADs=20Lasso?= Date: Wed, 13 Oct 2021 18:19:37 -0500 Subject: [PATCH 4/9] Add support for client authentication using client certificate/key As with the previous commit, this simply forwards the existing support in AsyncTCP/AsyncTCPSock. Expected to be required to attempt connecting to Google Cloud or Amazon AWS servers. --- src/AsyncMqttClient.cpp | 5 +++++ src/AsyncMqttClient.hpp | 1 + 2 files changed, 6 insertions(+) diff --git a/src/AsyncMqttClient.cpp b/src/AsyncMqttClient.cpp index f184a2c..c583567 100644 --- a/src/AsyncMqttClient.cpp +++ b/src/AsyncMqttClient.cpp @@ -145,6 +145,11 @@ AsyncMqttClient& AsyncMqttClient::setRootCa(const char* rootca, const size_t len _client.setRootCa(rootca, len); return *this; } +AsyncMqttClient& AsyncMqttClient::setClientCert(const char* cli_cert, const size_t cli_cert_len, const char* cli_key, const size_t cli_key_len) { + _client.setClientCert(cli_cert, cli_cert_len); + _client.setClientKey(cli_key, cli_key_len); + return *this; +} #endif #endif diff --git a/src/AsyncMqttClient.hpp b/src/AsyncMqttClient.hpp index 1b294de..a39cdd5 100644 --- a/src/AsyncMqttClient.hpp +++ b/src/AsyncMqttClient.hpp @@ -72,6 +72,7 @@ class AsyncMqttClient { #endif #ifdef ESP32 AsyncMqttClient& setRootCa(const char* rootca, const size_t len); + AsyncMqttClient& setClientCert(const char* cli_cert, const size_t cli_cert_len, const char* cli_key, const size_t cli_key_len); #endif #endif From d1f8ecbabdbb06380032669ba6839082ac5de260 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Villac=C3=ADs=20Lasso?= Date: Wed, 13 Oct 2021 18:24:08 -0500 Subject: [PATCH 5/9] Add support for server authentication using public shared key --- src/AsyncMqttClient.cpp | 5 +++++ src/AsyncMqttClient.hpp | 1 + 2 files changed, 6 insertions(+) diff --git a/src/AsyncMqttClient.cpp b/src/AsyncMqttClient.cpp index c583567..5a5457a 100644 --- a/src/AsyncMqttClient.cpp +++ b/src/AsyncMqttClient.cpp @@ -150,6 +150,11 @@ AsyncMqttClient& AsyncMqttClient::setClientCert(const char* cli_cert, const size _client.setClientKey(cli_key, cli_key_len); return *this; } +AsyncMqttClient& AsyncMqttClient::setPsk(const char* psk_ident, const char* psk) +{ + _client.setPsk(psk_ident, psk); + return *this; +} #endif #endif diff --git a/src/AsyncMqttClient.hpp b/src/AsyncMqttClient.hpp index a39cdd5..6cefb52 100644 --- a/src/AsyncMqttClient.hpp +++ b/src/AsyncMqttClient.hpp @@ -73,6 +73,7 @@ class AsyncMqttClient { #ifdef ESP32 AsyncMqttClient& setRootCa(const char* rootca, const size_t len); AsyncMqttClient& setClientCert(const char* cli_cert, const size_t cli_cert_len, const char* cli_key, const size_t cli_key_len); + AsyncMqttClient& setPsk(const char* psk_ident, const char* psk); #endif #endif From 9f49fdb76fd011e59587807d9853ec002ab25ebe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Villac=C3=ADs=20Lasso?= Date: Sat, 23 Oct 2021 06:48:10 -0500 Subject: [PATCH 6/9] cpplint fix 1 --- src/AsyncMqttClient.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/AsyncMqttClient.cpp b/src/AsyncMqttClient.cpp index 5a5457a..c742558 100644 --- a/src/AsyncMqttClient.cpp +++ b/src/AsyncMqttClient.cpp @@ -150,8 +150,7 @@ AsyncMqttClient& AsyncMqttClient::setClientCert(const char* cli_cert, const size _client.setClientKey(cli_key, cli_key_len); return *this; } -AsyncMqttClient& AsyncMqttClient::setPsk(const char* psk_ident, const char* psk) -{ +AsyncMqttClient& AsyncMqttClient::setPsk(const char* psk_ident, const char* psk) { _client.setPsk(psk_ident, psk); return *this; } From 759b1787c65e5caf3b412c7e98e417188b5da3cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Villac=C3=ADs=20Lasso?= Date: Tue, 26 Apr 2022 18:29:40 -0500 Subject: [PATCH 7/9] Initial version of websocket client filter for MQTT protocol This is a generic websocket filter class that has no knowledge of where the stream data is coming from or how to transmit into a stream. Actual I/O is delegated to the caller. The constructor receives optional buffer sizes, which sets the maximum frame length that can be generated by the implementation. However, there is no maximum frame length limit - the class allows streaming of the RX frame data in chunks. Websocket protocol negotiation is supported but chosen protocol is not (yet) enforced or reported. --- src/AsyncMqttClient/WebsocketFilter.cpp | 718 ++++++++++++++++++++++++ src/AsyncMqttClient/WebsocketFilter.hpp | 156 +++++ 2 files changed, 874 insertions(+) create mode 100644 src/AsyncMqttClient/WebsocketFilter.cpp create mode 100644 src/AsyncMqttClient/WebsocketFilter.hpp diff --git a/src/AsyncMqttClient/WebsocketFilter.cpp b/src/AsyncMqttClient/WebsocketFilter.cpp new file mode 100644 index 0000000..348fe47 --- /dev/null +++ b/src/AsyncMqttClient/WebsocketFilter.cpp @@ -0,0 +1,718 @@ +#include +#include +#include + +//#include +#include "esp32-hal-log.h" +#include "mbedtls/base64.h" +#include "mbedtls/sha1.h" + +#include "WebsocketFilter.hpp" + +using AsyncMqttClientInternals::WebsocketFilter; + +#define MIN_BUFSIZ 256 + +#define WS_OP_DATA_CONT 0 +#define WS_OP_DATA_TEXT 1 +#define WS_OP_DATA_BINARY 2 +#define WS_OP_CTRL_START 8 +#define WS_OP_CTRL_CLOSE 8 +#define WS_OP_CTRL_PING 9 +#define WS_OP_CTRL_PONG 10 + +WebsocketFilter::WebsocketFilter(const char * hostname, const char * wsurl, + uint32_t n_protos, const char * protos[], + size_t rxbufsiz, size_t txbufsiz) +{ + _state = HANDSHAKE_TX; // Starting handshake process + _err = NO_ERROR; // No error (yet) + + if (rxbufsiz < MIN_BUFSIZ) rxbufsiz = MIN_BUFSIZ; + if (txbufsiz < MIN_BUFSIZ) txbufsiz = MIN_BUFSIZ; + _txbuf = new uint8_t[txbufsiz]; _txbufsiz = txbufsiz; _txused = 0; + _rxbuf = new uint8_t[rxbufsiz]; _rxbufsiz = rxbufsiz; _rxused = 0; + + // Initialize header response check variables + _num_respHdrs = 0; + _hs_upgrade_websocket = false; + _hs_connection_upgrade = false; + _hs_sec_websocket_accept = false; + + _rx_inheader = false; + _rx_textframe = false; + _rx_binframe = false; + _rx_lastframe = true; + _rx_packetoffset = 0; + _rx_maskidx = 0; + _rx_close_code = 0; + _rx_close_reason = NULL; + _pendingPong = false; + _pongData = NULL; + _pongDataLen = 0; + + _tx_opcode = WS_OP_DATA_CONT; + _tx_pktopen = false; + + // Calculate a 16-byte random value for Sec-WebSocket-Key header + uint32_t randomkey[4]; size_t dummy_olen; + for (auto i = 0; i < 4; i++) randomkey[i] = (uint32_t)random(); + memset(_base64_key, 0, sizeof(_base64_key)); + mbedtls_base64_encode( + (unsigned char *)_base64_key, sizeof(_base64_key), + &dummy_olen, + (unsigned char *)randomkey, 16); + + // Start at beginning of first handshake string + _sent_handshake[0] = 0; + _sent_handshake[1] = 0; + + // Allocate list of handshake strings to send + _num_handshakeStrings = 8; + if (n_protos > 0 && protos != NULL) _num_handshakeStrings += 1 + 2 * n_protos; + _handshakeStrings = new const char *[_num_handshakeStrings]; + + _handshakeStrings[0] = "GET "; + _handshakeStrings[1] = wsurl; + _handshakeStrings[2] = " HTTP/1.1\r\nHost: "; + _handshakeStrings[3] = hostname; + _handshakeStrings[4] = "\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: "; + _handshakeStrings[5] = _base64_key; + _handshakeStrings[6] = "\r\nSec-WebSocket-Version: 13\r\n"; + auto idx = 7; + if (n_protos > 0 && protos != NULL) { + _handshakeStrings[idx] = "Sec-WebSocket-Protocol: "; + idx++; + for (auto i = 0; i < n_protos; i++) { + _handshakeStrings[idx] = protos[i]; + idx++; + _handshakeStrings[idx] = (i < n_protos - 1) ? ", " : "\r\n"; + idx++; + } + } + _handshakeStrings[idx] = "\r\n"; + + _advanceHandshakeTX(); +} + +WebsocketFilter::~WebsocketFilter() +{ + if (_handshakeStrings != NULL) { + delete[] _handshakeStrings; + _handshakeStrings = NULL; + _num_handshakeStrings = 0; + } + + if (_rx_close_reason != NULL) delete[] _rx_close_reason; + _rx_close_reason = NULL; + + if (_pongData != NULL) delete[] _pongData; + _pongData = NULL; + + delete[] _rxbuf; + delete[] _txbuf; +} + +void WebsocketFilter::_advanceHandshakeTX(void) +{ + if (_handshakeStrings == NULL) return; + if (_state != HANDSHAKE_TX) return; + +#define HS_STRIDX (_sent_handshake[0]) +#define HS_STROFF (_sent_handshake[1]) +#define HS_STR (_handshakeStrings[_sent_handshake[0]]) + + // Copy as many of the remaining strings of handshake into output buffer as + // possible, until running out of space or end of strings. + while (_txused < _txbufsiz && HS_STRIDX < _num_handshakeStrings) { + auto len = strlen(HS_STR); + auto len_copy = len - HS_STROFF; + auto txavail = _txbufsiz - _txused; + if (len_copy > txavail) len_copy = txavail; + memcpy(_txbuf + _txused, HS_STR + HS_STROFF, len_copy); + _txused += len_copy; + HS_STROFF += len_copy; + + // Go to next string of handshake + if (HS_STROFF >= len) { + HS_STRIDX++; + HS_STROFF = 0; + } + } + + if (HS_STRIDX >= _num_handshakeStrings) { + // End of handshake queued, switching to handshake response + _state = HANDSHAKE_RX; + delete[] _handshakeStrings; + _handshakeStrings = NULL; + _num_handshakeStrings = 0; + } +} + +void WebsocketFilter::_discardTxData(size_t n_bytes) +{ + if (n_bytes <= 0) return; + if (n_bytes < _txused) { + memmove(_txbuf, _txbuf + n_bytes, _txused - n_bytes); + _txused -= n_bytes; + } else { + _txused = 0; + } +} + +void WebsocketFilter::_discardRxData(size_t n_bytes) +{ + if (n_bytes <= 0) return; + if (n_bytes < _rxused) { + memmove(_rxbuf, _rxbuf + n_bytes, _rxused - n_bytes); + _rxused -= n_bytes; + } else { + _rxused = 0; + } +} + +void WebsocketFilter::fetchDataPtrForStream(uint8_t * & buffer, size_t & n_bytes) +{ + buffer = _txbuf; + n_bytes = _txused; +} + +void WebsocketFilter::discardFetchedData(size_t n_bytes) +{ + if (n_bytes > _txused) n_bytes = _txused; + if (n_bytes <= 0) return; + + _discardTxData(n_bytes); + + if (_state == HANDSHAKE_TX) { + // Some space was freed from buffer, advance handshake + _advanceHandshakeTX(); + } + + if (_state == WEBSOCKET_OPEN && _pendingPong) { + // Check if pending pong can be enqueued in freed space + if (2 + _pongDataLen <= _txbufsiz - _txused) { + _enqueueOutgoingFrame(WS_OP_CTRL_PONG, _pongDataLen, + (_pongDataLen > 0) ? _pongData : NULL, false, true); + if (_pongData != NULL) delete[] _pongData; + _pongData = NULL; + _pongDataLen = 0; + _pendingPong = false; + } + } +} + +bool WebsocketFilter::fetchDataForStream(size_t max_size, uint8_t * buffer, size_t & n_bytes) +{ + if (max_size <= 0 || buffer == NULL) return false; + + n_bytes = _txused; + if (n_bytes > max_size) n_bytes = max_size; + if (n_bytes <= 0) return true; + + memcpy(buffer, _txbuf, n_bytes); + discardFetchedData(n_bytes); + + return true; +} + +size_t WebsocketFilter::addDataFromStream(size_t size, const uint8_t * buffer) +{ + if (size > _rxbufsiz - _rxused) size = _rxbufsiz - _rxused; + if (size > 0) { + memcpy(_rxbuf + _rxused, buffer, size); + _rxused += size; + + if (_state == HANDSHAKE_RX) { + // New data is available, advance handshake response check + _runHandshakeResponseCheck(); + } + + if (_state == WEBSOCKET_OPEN && _rxused > 0) { + // Received frame header or data, analyse and remove frame headers + _runFrameDataParse(); + } + } + + return size; +} + +void WebsocketFilter::_runHandshakeResponseCheck(void) +{ + // Each iteration requires at least one complete header to be present at + // the beginning of the rx buffer. + while (_state == HANDSHAKE_RX && _rxused > 0) { + log_v("_num_respHdrs = %u _rxused = %u", _num_respHdrs, _rxused); + // A null character is not allowed anywhere in the response. Finding one + // means the handshake failed (maybe stream is not plaintext HTTP?) + if (NULL != memchr(_rxbuf, '\0', _rxused)) { + log_w("NULL found in handshake response (non-HTTP stream?)"); + _state = HANDSHAKE_ERR; + _err = HANDSHAKE_FAILED; + break; + } + + // A complete header must terminate in \r\n. No solitary \r is allowed. + uint8_t * cr = (uint8_t *)memchr(_rxbuf, '\r', _rxused); + if (cr == NULL) { + // \r was not found. Maybe we need to wait for more data, unless + // buffer is full, which means the "header" is overlong, or (again) + // not an plaintext HTTP response. + if (_rxused >= _rxbufsiz) { + log_w("overlong header (1) in handshake response (non-HTTP stream?)"); + _state = HANDSHAKE_ERR; + _err = HANDSHAKE_FAILED; + } + break; + } + if ((cr - _rxbuf) == _rxused - 1) { + // \r found at the very edge of the response. Maybe we need to wait + // for more data, unless buffer is full, which means the "header" + // is overlong + if (_rxused >= _rxbufsiz) { + log_w("overlong header (2) in handshake response (non-HTTP stream?)"); + _state = HANDSHAKE_ERR; + _err = HANDSHAKE_FAILED; + } + break; + } + + // Now it is safe to check whether next character is \n + if (*(cr + 1) != '\n') { + log_w("solitary CR in handshake response (non-HTTP stream?)"); + _state = HANDSHAKE_ERR; + _err = HANDSHAKE_FAILED; + break; + } + + char * hdr = (char *)_rxbuf; + size_t hdrlen = (cr - _rxbuf) + 2; // Length of header including \r\n + + *cr = '\0'; // Overwrite CR with null to treat header as C-string + log_v("Examining hdr: [%s]", hdr); + + if (_num_respHdrs == 0) { + // First header MUST be the response: + // HTTP/1.1 101 Switching Protocols + if (strstr(hdr, "HTTP/1.1 101 ") != hdr) { + log_w("invalid HTTP response: %s", hdr); + _state = HANDSHAKE_ERR; + _err = HANDSHAKE_FAILED; + break; + } + } else { + if (strlen(hdr) > 0) { + // Rest of headers must be of the form + // KEY: VALUE + char * k = hdr; + char * v = strstr(hdr, ": "); + if (v == NULL) { + log_w("invalid HTTP header: %s", hdr); + _state = HANDSHAKE_ERR; + _err = HANDSHAKE_FAILED; + break; + } + *v = '\0'; v += 2; + while (*v == ' ') v++; + + // So... what header do we have here...? + if (0 == strcasecmp(k, "Upgrade")) { + // This header MUST exist and MUST have the value: websocket + if (0 != strcasecmp(v, "websocket")) { + // Invalid Upgrade header + log_w("invalid Upgrade header: %s", v); + _state = HANDSHAKE_ERR; + _err = HANDSHAKE_FAILED; + break; + } + _hs_upgrade_websocket = true; + } else if (0 == strcasecmp(k, "Connection")) { + // This header MUST exist and MUST have the value: Upgrade + if (0 != strcasecmp(v, "Upgrade")) { + // Invalid Upgrade header + log_w("invalid Connection header: %s", v); + _state = HANDSHAKE_ERR; + _err = HANDSHAKE_FAILED; + break; + } + _hs_connection_upgrade = true; + } else if (0 == strcasecmp(k, "Sec-WebSocket-Accept")) { + // This header MUST exist, and its value MUST match the + // resulting base64 string for the SHA1 header. + char wskey[61]; + strcpy(wskey, _base64_key); + strcat(wskey, "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); + + unsigned char sha1_output[20]; + mbedtls_sha1_ret((unsigned char *)wskey, 60, sha1_output); + + char expected_base64[29]; size_t dummy_olen; + memset(expected_base64, 0, sizeof(expected_base64)); + mbedtls_base64_encode( + (unsigned char *)expected_base64, sizeof(expected_base64), + &dummy_olen, + (unsigned char *)sha1_output, sizeof(sha1_output)); + + if (0 != strcmp(v, expected_base64)) { + // Invalid Sec-WebSocket-Accept + log_w("invalid Sec-WebSocket-Accept: expected %s found %s", expected_base64, v); + _state = HANDSHAKE_ERR; + _err = HANDSHAKE_FAILED; + break; + } + _hs_sec_websocket_accept = true; + } else if (0 == strcasecmp(k, "Sec-WebSocket-Protocol")) { + // Ignore protocol for now... + log_d("(ignored protocol) %s: %s", k, v); + } else { + log_v("(ignored header) %s: %s", k, v); + } + } else { + // Reached end of response. All checks must have passed + if (!_hs_upgrade_websocket) { + log_w("missing header Upgrade: websocket"); + _state = HANDSHAKE_ERR; + _err = HANDSHAKE_FAILED; + } else if (!_hs_connection_upgrade) { + log_w("missing header Connection: Upgrade"); + _state = HANDSHAKE_ERR; + _err = HANDSHAKE_FAILED; + } else if (!_hs_sec_websocket_accept) { + log_w("missing header Sec-WebSocket-Accept"); + _state = HANDSHAKE_ERR; + _err = HANDSHAKE_FAILED; + } else { + // Success! Packet exchange can start now. + _state = WEBSOCKET_OPEN; + _rx_inheader = true; + } + } + } + + // Header processed, may be discarded now + _num_respHdrs++; + _discardRxData(hdrlen); + } +} + +void WebsocketFilter::_runFrameDataParse() +{ + while (_state == WEBSOCKET_OPEN && _rxused > 0) { + + if (_rx_inheader) { + // Need to accumulate one complete header to evaluate further + if (_rxused < 2) return; + + // Check if there is enough data for header + size_t hdrlength = 2; + uint8_t mask = _rxbuf[1]; + if ((mask & 0x7f) == 0x7f) { + hdrlength = 10; + } else if ((mask & 0x7f) == 0x7e) { + hdrlength = 4; + } + if (_rxused < hdrlength + ((mask & 0x80) ? 4 : 0)) return; + + _rx_opcode = _rxbuf[0]; + _rx_framelen = 0; + if ((mask & 0x7f) == 0x7f) { + // Extract 64-bit frame length, network byte order + for (auto i = 0; i < 8; i++) _rx_framelen = (_rx_framelen << 8) | _rxbuf[i+2]; + } else if ((mask & 0x7f) == 0x7e) { + // Extract 16-bit frame length, network byte order + for (auto i = 0; i < 2; i++) _rx_framelen = (_rx_framelen << 8) | _rxbuf[i+2]; + } else { + // 7-bit frame length + _rx_framelen = mask & 0x7f; + } + log_v("hdrlength=%u frame length=%u", hdrlength, (uint32_t)_rx_framelen); + + // Extract last-frame flag, clear reserved flags (ignored) + bool lastframe = ((_rx_opcode & 0x80) != 0); + _rx_opcode &= 0x0F; + + // Use identity mask unless flag set by server (NON-COMPLIANT) + memset(_rx_maskdata, 0, 4); + if (mask & 0x80) { + log_w("mask data not compliant for client recv!"); + memcpy(_rx_maskdata, _rxbuf + hdrlength, 4); + } + _rx_maskidx = 0; + + _discardRxData(hdrlength + ((mask & 0x80) ? 4 : 0)); + _rx_inheader = false; + + switch (_rx_opcode) { + case WS_OP_DATA_TEXT: // Start of text frame + log_v("start of text frame"); + _rx_binframe = false; + _rx_textframe = true; + if (!_rx_lastframe) { + log_e("expecting continuation frame"); + _state = WEBSOCKET_ERROR; + _err = PROTOCOL_ERROR; + } + _rx_lastframe = lastframe; + _rx_packetoffset = 0; + break; + case WS_OP_DATA_BINARY: // Start of binary frame + log_v("start of binary frame"); + _rx_binframe = true; + _rx_textframe = false; + if (!_rx_lastframe) { + log_e("expecting continuation frame"); + _state = WEBSOCKET_ERROR; + _err = PROTOCOL_ERROR; + } + _rx_lastframe = lastframe; + _rx_packetoffset = 0; + break; + case WS_OP_DATA_CONT: // Continuation frame + log_v("continuation frame"); + _rx_lastframe = lastframe; + if (!_rx_textframe && !_rx_binframe) { + log_e("continuation frame without start frame!"); + _state = WEBSOCKET_ERROR; + _err = PROTOCOL_ERROR; + } + break; + case WS_OP_CTRL_CLOSE: + case WS_OP_CTRL_PING: + case WS_OP_CTRL_PONG: + log_v("control frame 0x%02x", _rx_opcode); + if (!lastframe) { + log_e("fragmented control frame 0x%02x", _rx_opcode); + _state = WEBSOCKET_ERROR; + _err = PROTOCOL_ERROR; + } + if (_rx_framelen > 125) { + log_e("overlong control frame 0x%02x", _rx_opcode); + _state = WEBSOCKET_ERROR; + _err = PROTOCOL_ERROR; + } + break; + default: + log_e("unimplemented opcode 0x%02x", _rx_opcode); + _state = WEBSOCKET_ERROR; + _err = PROTOCOL_ERROR; + break; + } + } + + if (!_rx_inheader && _state == WEBSOCKET_OPEN) { + // Some data from payload, picked only if control code. According to + // RFC 6455, any payload for a control code must be no greater than 125 + // bytes. + if (_rx_opcode >= WS_OP_CTRL_START && _rxused >= _rx_framelen) { + switch (_rx_opcode) { + case WS_OP_CTRL_CLOSE: + if (_rx_framelen >= 2) { + // 2 byte code, rest is text explanation + _rx_close_code = (((uint16_t)_rxbuf[0]) << 8) | _rxbuf[1]; + if (_rx_framelen > 2) { + _rx_close_reason = new char[_rx_framelen - 2 + 1]; + memcpy(_rx_close_reason, _rxbuf + 2, _rx_framelen - 2); + _rx_close_reason[_rx_framelen - 2] = '\0'; + } + } + _state = WEBSOCKET_CLOSED; + _err = REMOTE_CTRL_CLOSE; + break; + case WS_OP_CTRL_PING: + // Check if PONG packet can be immediately enqueued + if (2 + _rx_framelen <= _txbufsiz - _txused) { + // Enqueue PONG packet immediately + _enqueueOutgoingFrame(WS_OP_CTRL_PONG, _rx_framelen, + (_rx_framelen > 0) ? _rxbuf : NULL, false, true); + } else { + // Copy PING payload to be enqueued later + _pendingPong = true; + if (_pongData != NULL) delete[] _pongData; + _pongData = NULL; + _pongDataLen = _rx_framelen; + if (_pongDataLen > 0) { + _pongData = new uint8_t[_pongDataLen]; + memcpy(_pongData, _rxbuf, _pongDataLen); + } + } + break; + case WS_OP_CTRL_PONG: + // Ignored + break; + } + + if (_rx_framelen > 0) _discardRxData(_rx_framelen); + _rx_framelen = 0; + _rx_inheader = true; // Process next RX frame header + } + } + + if (!_rx_inheader) break; + } +} + +bool WebsocketFilter::isPacketDataAvailable(void) +{ + if (_state != WEBSOCKET_OPEN) return false; + if (_err != NO_ERROR) return false; + return !_rx_inheader && (_rxused > 0); +} + +void WebsocketFilter::fetchPacketData(size_t max_size, uint8_t * buffer, + size_t & n_bytes, uint64_t & packet_offset, bool & packet_binary, + bool & last_fetch) +{ + n_bytes = 0; + packet_offset = _rx_packetoffset; + packet_binary = true; + last_fetch = false; + + if (max_size <= 0 || buffer == NULL) return; + if (_rx_textframe) packet_binary = false; + if (_rx_binframe) packet_binary = true; + + while (!last_fetch && !_rx_inheader && _state == WEBSOCKET_OPEN && max_size > 0 && _rxused > 0) { + size_t copylen = max_size; + if (copylen > _rx_framelen) copylen = _rx_framelen; + if (copylen > _rxused) copylen = _rxused; + + for (auto i = 0; i < copylen; i++) { + buffer[i] = _rxbuf[i] ^ _rx_maskdata[_rx_maskidx]; + _rx_maskidx = (_rx_maskidx + 1) & 3; + } + buffer += copylen; + max_size -= copylen; + n_bytes += copylen; + _rx_packetoffset += copylen; + _rx_framelen -= copylen; + + _discardRxData(copylen); + + if (_rx_framelen <= 0) { + // Copied all data from current frame. Any following data MUST be + // the header of the next frame. + _rx_inheader = true; + if (_rx_lastframe) { + // Last frame of packet + last_fetch = true; + _rx_textframe = false; + _rx_binframe = false; + } else { + // Not the last frame, might be more data to copy + } + + _runFrameDataParse(); + } + } +} + +bool WebsocketFilter::_enqueueOutgoingFrame(uint8_t opcode, uint32_t len, + const uint8_t * payload, bool masked, bool lastframe) +{ + if (len > 0 && payload == NULL) return false; + + // Calculate required length for full frame + uint32_t headerlen = 2; + if (len > 65535) + headerlen += 8; + else if (len > 125) + headerlen += 2; + if (masked) headerlen += 4; + if (headerlen + len > _txbufsiz - _txused) return false; + + uint8_t * hdr = _txbuf + _txused; + uint8_t * p = hdr + headerlen; + + hdr[0] = opcode; if (lastframe) hdr[0] |= 0x80; + if (len > 65535) { + hdr[1] = 127; + hdr[2] = hdr[3] = hdr[4] = hdr[5] = 0; + hdr[6] = (len >> 24) & 0xFF; + hdr[7] = (len >> 16) & 0xFF; + hdr[8] = (len >> 8) & 0xFF; + hdr[9] = len & 0xFF; + } else if (len > 125) { + hdr[1] = 126; + hdr[2] = (len >> 8) & 0xFF; + hdr[3] = len & 0xFF; + } else { + hdr[1] = len; + } + if (masked) { + hdr[1] |= 0x80; + + uint8_t * tx_maskdata = p - 4; + *((uint32_t *)tx_maskdata) = random(); + log_v("mask = 0x%08x", *((uint32_t *)tx_maskdata)); + + for (auto i = 0; i < len; i++) p[i] = payload[i] ^ tx_maskdata[i & 3]; + } else { + if (len > 0) memcpy(p, payload, len); + } + + _txused += headerlen + len; + + return true; +} + +bool WebsocketFilter::startPacket(bool packet_binary) +{ + if (_state != WEBSOCKET_OPEN) return false; + if (_tx_pktopen) return false; + uint8_t n_opcode = packet_binary ? WS_OP_DATA_BINARY : WS_OP_DATA_TEXT; + if (_tx_opcode != WS_OP_DATA_CONT && _tx_opcode != n_opcode) return false; + _tx_opcode = n_opcode; + _tx_pktopen = true; + return true; +} + +size_t WebsocketFilter::addPacketData(size_t size, const uint8_t * buffer, bool endOfPacket) +{ + if (_state != WEBSOCKET_OPEN) { + log_w("invalid websocket state %u", _state); + return 0; + } + if (size == 0 || buffer == NULL) { + log_e("invalid size or ptr"); + return 0; + } + if (!_tx_pktopen) { + log_w("attempt to add data to not-opened packet!"); + return 0; + } + + uint8_t fixedhdr = 2 + 4; // two-byte minimum plus mask data + size_t txavail = _txbufsiz - _txused; + + if (txavail <= fixedhdr) return 0; // Not enough space for even minimum header + + size_t copylen = txavail - fixedhdr; + if (copylen > size) copylen = size; + + if (copylen > 65535U) { + // 8-byte size field required, check if it fits + if (fixedhdr + 8 + copylen > txavail) { + copylen = txavail - fixedhdr - 8; + } + } + if (copylen > 125 && copylen <= 65535U) { + // 2-byte size field required, check if it fits + if (fixedhdr + 2 + copylen > txavail) { + copylen = txavail - fixedhdr - 2; + } + } + + // Cannot signal end-of-packet unless indicated by parameter AND all + // supplied data actually fits in available space + if (endOfPacket && copylen < size) endOfPacket = false; + + if (!_enqueueOutgoingFrame(_tx_opcode, copylen, buffer, true, endOfPacket)) { + log_e("BUG: incorrect calculation of required space: size=%u copylen=%u txavail=%u", + size, copylen, txavail); + return 0; + } + _tx_opcode = WS_OP_DATA_CONT; + if (endOfPacket) _tx_pktopen = false; + + return copylen; +} \ No newline at end of file diff --git a/src/AsyncMqttClient/WebsocketFilter.hpp b/src/AsyncMqttClient/WebsocketFilter.hpp new file mode 100644 index 0000000..8a7d8bf --- /dev/null +++ b/src/AsyncMqttClient/WebsocketFilter.hpp @@ -0,0 +1,156 @@ +#pragma once + +#include + +namespace AsyncMqttClientInternals { + +typedef enum { + NO_ERROR = 0, + HANDSHAKE_FAILED, // Websocket handshake failed + REMOTE_CTRL_CLOSE, // Remote side sent a close packet + PROTOCOL_ERROR // Invalid condition encountered in received frame +} WebsocketError; + +typedef enum { + HANDSHAKE_TX, // Client is sending HTTP handshake + HANDSHAKE_RX, // Client is receiving HTTP response to handshake + HANDSHAKE_ERR, // Handshake failed at some point + + WEBSOCKET_OPEN, // Handshake succeeded, websocket is sending/receiving packets + WEBSOCKET_CLOSED, // Received close packet from remote side + WEBSOCKET_ERROR // Websocket protocol error +} WebsocketState; + +class WebsocketFilter +{ +private: + WebsocketState _state; // Current state of websocket + WebsocketError _err; // Current error status + + uint8_t * _txbuf; size_t _txbufsiz; size_t _txused; + uint8_t * _rxbuf; size_t _rxbufsiz; size_t _rxused; + + // List of HTTP strings for handshake + const char ** _handshakeStrings; + size_t _num_handshakeStrings; + size_t _sent_handshake[2]; // index,offset of sent handshake + + // base64 representation of 16-bit random value (24 bytes) plus null + // for Sec-WebSocket-Key header. + char _base64_key[25]; + + // Variables for response checks + unsigned int _num_respHdrs; + bool _hs_upgrade_websocket; + bool _hs_connection_upgrade; + bool _hs_sec_websocket_accept; + + // RX: are we expecting a frame header? + bool _rx_inheader; + bool _rx_textframe; + bool _rx_binframe; + bool _rx_lastframe; + uint8_t _rx_opcode; + uint64_t _rx_framelen; // Length of payload portion of frame, without headers/mask + uint64_t _rx_packetoffset; + uint8_t _rx_maskdata[4]; + uint8_t _rx_maskidx; + + // Close reason handling + uint16_t _rx_close_code; + char * _rx_close_reason; + + // Pending PONG handling + bool _pendingPong; + uint8_t * _pongData; + uint8_t _pongDataLen; + + uint8_t _tx_opcode; + bool _tx_pktopen; + + void _discardTxData(size_t); + void _discardRxData(size_t); + void _advanceHandshakeTX(void); + void _runHandshakeResponseCheck(void); + + void _runFrameDataParse(void); + bool _enqueueOutgoingFrame(uint8_t opcode, uint32_t len, const uint8_t * payload, + bool masked, bool lastframe); +public: + WebsocketFilter(const char * hostname, const char * wsurl = "/", + uint32_t n_protos = 0, const char * protos[] = NULL, + size_t rxbufsiz = 256, size_t txbufsiz = 256); + ~WebsocketFilter(); + + WebsocketState getState(void) { return _state; } + WebsocketError getError(void) { return _err; } + + uint16_t getCloseCode(void) { return _rx_close_code; } + const char * getCloseReason(void) { return _rx_close_reason; } + + // Push raw bytes received from a byte stream into websocket RX parser. + // Returns number of bytes actually consumed from buffer, which might be + // zero if the internal buffer is full. + size_t addDataFromStream(size_t size, const uint8_t * buffer); + + // Query number of pending bytes to be sent to stream + size_t pendingStreamOutputLen(void) { return _txused; } + + // Fetch raw bytes to be sent to byte stream from websocket TX. This + // places the number of bytes actually placed into the buffer into n_bytes, + // up to the limit of max_size. After fetching the outgoing data, this data + // is discarded and the space reused for further processing. Returns true + // for success (including n_bytes <- 0 for empty buffer), or false for + // failure (usually a pending error). + bool fetchDataForStream(size_t max_size, uint8_t * buffer, size_t & n_bytes); + + // Fetch a pointer to the internal buffer of raw bytes queued for output. + // Unlike fetchDataForStream(), this method will skip the copy to the + // externally-supplied buffer. However, the app must now remember to call + // discardFetchedData() with the total amount of bytes that were successfully + // written to the output stream. After calling this function, the buffer + // is guaranteed to be valid and unchanged up to n_bytes, until the call to + // discardFetchedData(). + void fetchDataPtrForStream(uint8_t * & buffer, size_t & n_bytes); + + // Discard data that was written to output stream. Companion to fetchDataPtrForStream(). + void discardFetchedData(size_t n_bytes); + + + // Check if some INCOMING frame data is available to be fetched + bool isPacketDataAvailable(void); + + // Fetch some of the current available packet data: + // - max_bytes, buffer: (INPUT) size and location of output buffer + // - n_bytes: number of bytes actually fetched, up to max_bytes + // - packet_offset: offset from start of packet of fetched data + // - packet_binary: true if packet is binary, false if text + // - last_fetch: true if fetch served last of packet data + // + // This method will merge together multiple frames belonging to the same + // fragmented packet, if applicable. + // After calling this function, n_bytes of packet data are discarded. + void fetchPacketData(size_t max_size, uint8_t * buffer, size_t & n_bytes, + uint64_t & packet_offset, bool & packet_binary, bool & last_fetch); + + + + // Start an outgoing packet, with type of either text or binary + bool startPacket(bool packet_binary); + + bool isOpenPacket(void) { return _tx_pktopen; } + + // Add data to the outgoing packet started by startPacket(). This returns + // the number of bytes actually consumed from the buffer and made into a + // websocket frame. The app SHOULD call fetchDataForStream() at least once + // after calling this function, in order to fetch resulting packet data and + // free space for further processing. The endOfPacket flag MUST be set + // if the buffer contains the last byte of data to be sent to the stream. + // NOTE: if the endOfPacket flag is set, BUT the method returns less than + // the specified buffer size, the app MUST fetch some data to be sent to + // the stream using fetchDataForStream() in order to free some space, then + // retry with the remaining unsent buffer, and endOfPacket again set. + size_t addPacketData(size_t size, const uint8_t * buffer, bool endOfPacket); +}; + +} // namespace AsyncMqttClientInternals From 56f95e96f7e45f0bd57e8df4c33f1b58323cc174 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Villac=C3=ADs=20Lasso?= Date: Wed, 27 Apr 2022 12:52:12 -0500 Subject: [PATCH 8/9] Wire up AsyncMqttClient to use websocket protocol Publicly exposed methods: - setWsEnabled(bool): toggle use of websockets before connect() - setWsUri(const char*): set URI endpoint where websocket is exposed (default "/") Tested with Eclipse Mosquitto served behind Apache 2.4 as websocket proxy. --- src/AsyncMqttClient.cpp | 207 ++++++++++++++++++++++++++++++++++++---- src/AsyncMqttClient.hpp | 12 +++ 2 files changed, 203 insertions(+), 16 deletions(-) diff --git a/src/AsyncMqttClient.cpp b/src/AsyncMqttClient.cpp index b4375cf..7f9cd87 100644 --- a/src/AsyncMqttClient.cpp +++ b/src/AsyncMqttClient.cpp @@ -31,6 +31,10 @@ AsyncMqttClient::AsyncMqttClient() #if ASYNC_TCP_SSL_ENABLED , _secureServerFingerprints() #endif +, _wsfilter(nullptr) +, _wsEnabled(false) +, _wsUri("/") +, _wsHost(nullptr) , _onConnectUserCallbacks() , _onDisconnectUserCallbacks() , _onSubscribeUserCallbacks() @@ -138,6 +142,16 @@ AsyncMqttClient& AsyncMqttClient::addServerFingerprint(const uint8_t* fingerprin } #endif +AsyncMqttClient& AsyncMqttClient::setWsEnabled(bool wsenabled) { + _wsEnabled = wsenabled; + return *this; +} + +AsyncMqttClient& AsyncMqttClient::setWsUri(const char * uri) { + _wsUri = uri; + return *this; +} + AsyncMqttClient& AsyncMqttClient::onConnect(AsyncMqttClientInternals::OnConnectUserCallback callback) { _onConnectUserCallbacks.push_back(callback); return *this; @@ -205,6 +219,28 @@ void AsyncMqttClient::_onConnect() { } } #endif + + if (_wsEnabled) { + static const char * ws_mqtt_protos[1] = { "mqtt" }; + const char * tmpHostPtr; + + if (_wsHost != NULL) delete[] _wsHost; + _wsHost = NULL; + if (_useIp) { + auto sIP = _ip.toString(); + _wsHost = new char[sIP.length() + 1]; + strcpy(_wsHost, sIP.c_str()); + tmpHostPtr = _wsHost; + } else { + tmpHostPtr = _host; + } + + _wsfilter = new AsyncMqttClientInternals::WebsocketFilter( + tmpHostPtr, + _wsUri, + 1, ws_mqtt_protos); + } + AsyncMqttClientInternals::OutPacket* msg = new AsyncMqttClientInternals::ConnectOutPacket(_cleanSession, _username, @@ -227,6 +263,13 @@ void AsyncMqttClient::_onDisconnect() { _clear(); for (auto callback : _onDisconnectUserCallbacks) callback(_disconnectReason); + + if (_wsfilter != NULL) { + delete _wsfilter; + _wsfilter = NULL; + } + if (_wsHost != NULL) delete[] _wsHost; + _wsHost = NULL; } /* @@ -247,6 +290,69 @@ void AsyncMqttClient::_onAck(size_t len) { void AsyncMqttClient::_onData(char* data, size_t len) { log_i("data rcv (%u)", len); + + if (_wsfilter == NULL) { + _onMQTTData(data, len); + return; + } + + // Space occupied by data already fetched into websocket will be reused to + // hold MQTT payload data. + uint8_t * mqttdata = (uint8_t *)data; + uint8_t * outbufptr = mqttdata; + size_t outputavail = 0; + size_t mqttavail = 0; + do { + size_t proc_len = _wsfilter->addDataFromStream(len, (uint8_t *)data); + log_d("pushed %u of %u bytes into websocket filter", proc_len, len); + data += proc_len; + len -= proc_len; + outputavail += proc_len; + + if (proc_len <= 0) { + log_w("websocket filter rx full?"); + } else { + while (_wsfilter->isPacketDataAvailable()) { + size_t mqttbytes = 0; + uint64_t pktoffset = 0; + bool pktisbinary = false; + bool pktlastfetch = false; + _wsfilter->fetchPacketData(outputavail, outbufptr, mqttbytes, + pktoffset, pktisbinary, pktlastfetch); + log_d("fetched %u/%u mqtt bytes pktoff=%u binary=%u last=%u", mqttbytes, proc_len, + (uint32_t)pktoffset, pktisbinary ? 1 : 0, pktlastfetch ? 1 : 0); + mqttavail += mqttbytes; + outbufptr += mqttbytes; + outputavail -= mqttbytes; + } + } + + // Check for any handshake or protocol errors + if (_wsfilter->getError() != AsyncMqttClientInternals::NO_ERROR) { + switch (_wsfilter->getError()) { + case AsyncMqttClientInternals::HANDSHAKE_FAILED: + log_w("Websocket handshake failure"); + break; + case AsyncMqttClientInternals::PROTOCOL_ERROR: + log_w("Websocket protocol violation"); + break; + case AsyncMqttClientInternals::REMOTE_CTRL_CLOSE: + log_w("Remote side closed websocket: code %hu", _wsfilter->getCloseCode()); + if (_wsfilter->getCloseReason() != NULL) { + log_w("Reason text %s", _wsfilter->getCloseReason()); + } + break; + } + disconnect(true); + break; + } else if (len <= 0) { + if (mqttavail > 0) _onMQTTData((char *)mqttdata, mqttavail); + } + } while (len > 0); +} + +void AsyncMqttClient::_onMQTTData(char* data, size_t len) { + log_i("data rcv (%u)", len); size_t currentBytePosition = 0; char currentByte; _lastServerActivity = millis(); @@ -395,24 +501,46 @@ void AsyncMqttClient::_handleQueue() { SEMAPHORE_TAKE(); // On ESP32, onDisconnect is called within the close()-call. So we need to make sure we don't lock bool disconnect = false; + bool wserror = false; + bool canflushqueue = true; - while (_head && _client.space() > 10) { // safe but arbitrary value, send at least 10 bytes + if (_wsfilter != NULL) { + canflushqueue = _flushWebsocketTX(wserror); + if (!canflushqueue) { + if (wserror) disconnect = true; + } + } + + while (canflushqueue && _head && _client.space() > 10) { // safe but arbitrary value, send at least 10 bytes // 1. try to send if (_head->size() > _sent) { - // On SSL the TCP library returns the total amount of bytes, not just the unencrypted payload length. - // So we calculate the amount to be written ourselves. - size_t willSend = std::min(_head->size() - _sent, _client.space()); - size_t realSent = _client.add(reinterpret_cast(_head->data(_sent)), willSend, ASYNC_WRITE_FLAG_COPY); // flag is set by LWIP anyway, added for clarity - _sent += willSend; - (void)realSent; - _client.send(); - _lastClientActivity = millis(); - _lastPingRequestTime = 0; - #if ASYNC_TCP_SSL_ENABLED - log_i("snd #%u: (tls: %u) %u/%u", _head->packetType(), realSent, _sent, _head->size()); - #else - log_i("snd #%u: %u/%u", _head->packetType(), _sent, _head->size()); - #endif + if (_wsfilter != NULL) { + if (_sent == 0 && !_wsfilter->isOpenPacket()) _wsfilter->startPacket(true); + size_t realSent = _wsfilter->addPacketData(_head->size() - _sent, _head->data(_sent), true); + canflushqueue = _flushWebsocketTX(wserror); + if (!canflushqueue) { + // At this point, any failure to flush is a websocket error + if (wserror) disconnect = true; + } + _sent += realSent; + _lastClientActivity = millis(); + _lastPingRequestTime = 0; + } else { + // On SSL the TCP library returns the total amount of bytes, not just the unencrypted payload length. + // So we calculate the amount to be written ourselves. + size_t willSend = std::min(_head->size() - _sent, _client.space()); + size_t realSent = _client.add(reinterpret_cast(_head->data(_sent)), willSend, ASYNC_WRITE_FLAG_COPY); // flag is set by LWIP anyway, added for clarity + _sent += willSend; + (void)realSent; + _client.send(); + _lastClientActivity = millis(); + _lastPingRequestTime = 0; + #if ASYNC_TCP_SSL_ENABLED + log_i("snd #%u: (tls: %u) %u/%u", _head->packetType(), realSent, _sent, _head->size()); + #else + log_i("snd #%u: %u/%u", _head->packetType(), _sent, _head->size()); + #endif + } if (_head->packetType() == AsyncMqttClientInternals::PacketType.DISCONNECT) { disconnect = true; } @@ -435,11 +563,58 @@ void AsyncMqttClient::_handleQueue() { SEMAPHORE_GIVE(); if (disconnect) { - log_i("snd DISCONN, disconnecting"); + if (wserror) + log_i("websocket error, disconnecting"); + else + log_i("snd DISCONN, disconnecting"); _client.close(); } } +bool AsyncMqttClient::_flushWebsocketTX(bool & disconnect) +{ + // Flush as much of the wsfilter TX buffer as possible, required for handshake + while (_client.space() > 10 && _wsfilter->pendingStreamOutputLen() > 0) { + log_d("client.space = %u pending ws output = %u", _client.space(), _wsfilter->pendingStreamOutputLen()); + uint8_t * wsbuffer = NULL; + size_t wsbytes = 0; + _wsfilter->fetchDataPtrForStream(wsbuffer, wsbytes); + wsbytes = std::min(wsbytes, _client.space()); + size_t wsbytes_sent = _client.add((const char *)wsbuffer, wsbytes, ASYNC_WRITE_FLAG_COPY); + if (wsbytes_sent > 0) { + _client.send(); + _wsfilter->discardFetchedData(wsbytes_sent); + } else { + break; + } + } + + switch (_wsfilter->getState()) { + case AsyncMqttClientInternals::HANDSHAKE_TX: + case AsyncMqttClientInternals::HANDSHAKE_RX: + // Still negociating websocket handshake + log_v("still negotiating handshake"); + return false; + case AsyncMqttClientInternals::HANDSHAKE_ERR: + log_w("Websocket handshake failure"); + disconnect = true; + return false; + case AsyncMqttClientInternals::WEBSOCKET_ERROR: + log_w("Websocket protocol violation"); + disconnect = true; + return false; + case AsyncMqttClientInternals::WEBSOCKET_CLOSED: + log_w("Remote side closed websocket: code %hu", _wsfilter->getCloseCode()); + if (_wsfilter->getCloseReason() != NULL) { + log_w("Reason text %s", _wsfilter->getCloseReason()); + } + disconnect = true; + return false; + } + + return true; +} + void AsyncMqttClient::_clearQueue(bool keepSessionData) { SEMAPHORE_TAKE(); AsyncMqttClientInternals::OutPacket* packet = _head; diff --git a/src/AsyncMqttClient.hpp b/src/AsyncMqttClient.hpp index 1e81103..4a2a8f4 100644 --- a/src/AsyncMqttClient.hpp +++ b/src/AsyncMqttClient.hpp @@ -50,6 +50,8 @@ #include "AsyncMqttClient/Packets/Out/Unsubscribe.hpp" #include "AsyncMqttClient/Packets/Out/Publish.hpp" +#include "AsyncMqttClient/WebsocketFilter.hpp" + class AsyncMqttClient { public: AsyncMqttClient(); @@ -67,6 +69,8 @@ class AsyncMqttClient { AsyncMqttClient& setSecure(bool secure); AsyncMqttClient& addServerFingerprint(const uint8_t* fingerprint); #endif + AsyncMqttClient& setWsEnabled(bool wsenabled); + AsyncMqttClient& setWsUri(const char * uri); AsyncMqttClient& onConnect(AsyncMqttClientInternals::OnConnectUserCallback callback); AsyncMqttClient& onDisconnect(AsyncMqttClientInternals::OnDisconnectUserCallback callback); @@ -124,6 +128,11 @@ class AsyncMqttClient { std::vector> _secureServerFingerprints; #endif + AsyncMqttClientInternals::WebsocketFilter * _wsfilter; + bool _wsEnabled; + const char* _wsUri; + char* _wsHost; + std::vector _onConnectUserCallbacks; std::vector _onDisconnectUserCallbacks; std::vector _onSubscribeUserCallbacks; @@ -156,12 +165,15 @@ class AsyncMqttClient { void _onData(char* data, size_t len); void _onPoll(); + void _onMQTTData(char* data, size_t len); + // QUEUE void _insert(AsyncMqttClientInternals::OutPacket* packet); // for PUBREL void _addFront(AsyncMqttClientInternals::OutPacket* packet); // for CONNECT void _addBack(AsyncMqttClientInternals::OutPacket* packet); // all the rest void _handleQueue(); void _clearQueue(bool keepSessionData); + bool _flushWebsocketTX(bool & disconnect); // MQTT void _onPingResp(); From 1fe6a44df01ebf65472544ed7e2450c942e40f90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Villac=C3=ADs=20Lasso?= Date: Thu, 1 May 2025 13:19:43 -0500 Subject: [PATCH 9/9] Conditionally use updated name for mbedtls function Required for Arduino-ESP32 3.2.0 compatibility. --- src/AsyncMqttClient/WebsocketFilter.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/AsyncMqttClient/WebsocketFilter.cpp b/src/AsyncMqttClient/WebsocketFilter.cpp index 348fe47..4d70398 100644 --- a/src/AsyncMqttClient/WebsocketFilter.cpp +++ b/src/AsyncMqttClient/WebsocketFilter.cpp @@ -3,6 +3,7 @@ #include //#include +#include "esp_idf_version.h" #include "esp32-hal-log.h" #include "mbedtls/base64.h" #include "mbedtls/sha1.h" @@ -344,7 +345,11 @@ void WebsocketFilter::_runHandshakeResponseCheck(void) strcat(wskey, "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); unsigned char sha1_output[20]; +#if ESP_IDF_VERSION_MAJOR < 5 mbedtls_sha1_ret((unsigned char *)wskey, 60, sha1_output); +#else + mbedtls_sha1((unsigned char *)wskey, 60, sha1_output); +#endif char expected_base64[29]; size_t dummy_olen; memset(expected_base64, 0, sizeof(expected_base64));