diff --git a/src/AsyncMqttClient.cpp b/src/AsyncMqttClient.cpp index 5e80256..5840190 100644 --- a/src/AsyncMqttClient.cpp +++ b/src/AsyncMqttClient.cpp @@ -29,8 +29,14 @@ AsyncMqttClient::AsyncMqttClient() , _willQos(0) , _willRetain(false) #if ASYNC_TCP_SSL_ENABLED +#ifdef ESP8266 , _secureServerFingerprints() #endif +#endif +, _wsfilter(nullptr) +, _wsEnabled(false) +, _wsUri("/") +, _wsHost(nullptr) , _onConnectUserCallbacks() , _onDisconnectUserCallbacks() , _onSubscribeUserCallbacks() @@ -130,6 +136,7 @@ AsyncMqttClient& AsyncMqttClient::setSecure(bool secure) { return *this; } +#ifdef ESP8266 AsyncMqttClient& AsyncMqttClient::addServerFingerprint(const uint8_t* fingerprint) { std::array newFingerprint; memcpy(newFingerprint.data(), fingerprint, SHA1_SIZE); @@ -137,6 +144,32 @@ AsyncMqttClient& AsyncMqttClient::addServerFingerprint(const uint8_t* fingerprin return *this; } #endif +#ifdef ESP32 +AsyncMqttClient& AsyncMqttClient::setRootCa(const char* rootca, const size_t len) { + _client.setRootCa(rootca, len); + return *this; +} +AsyncMqttClient& AsyncMqttClient::setClientCert(const char* cli_cert, const size_t cli_cert_len, const char* cli_key, const size_t cli_key_len) { + _client.setClientCert(cli_cert, cli_cert_len); + _client.setClientKey(cli_key, cli_key_len); + return *this; +} +AsyncMqttClient& AsyncMqttClient::setPsk(const char* psk_ident, const char* psk) { + _client.setPsk(psk_ident, psk); + return *this; +} +#endif +#endif + +AsyncMqttClient& AsyncMqttClient::setWsEnabled(bool wsenabled) { + _wsEnabled = wsenabled; + return *this; +} + +AsyncMqttClient& AsyncMqttClient::setWsUri(const char * uri) { + _wsUri = uri; + return *this; +} AsyncMqttClient& AsyncMqttClient::onConnect(AsyncMqttClientInternals::OnConnectUserCallback callback) { _onConnectUserCallbacks.push_back(callback); @@ -187,10 +220,11 @@ void AsyncMqttClient::_clear() { void AsyncMqttClient::_onConnect() { log_i("TCP conn, MQTT CONNECT"); #if ASYNC_TCP_SSL_ENABLED +#ifdef ESP8266 if 
(_secure && _secureServerFingerprints.size() > 0) { + bool sslFoundFingerprint = false; SSL* clientSsl = _client.getSSL(); - bool sslFoundFingerprint = false; for (std::array fingerprint : _secureServerFingerprints) { if (ssl_match_fingerprint(clientSsl, fingerprint.data()) == SSL_OK) { sslFoundFingerprint = true; @@ -205,6 +239,29 @@ void AsyncMqttClient::_onConnect() { } } #endif +#endif + + if (_wsEnabled) { + static const char * ws_mqtt_protos[1] = { "mqtt" }; + const char * tmpHostPtr; + + if (_wsHost != NULL) delete[] _wsHost; + _wsHost = NULL; + if (_useIp) { + auto sIP = _ip.toString(); + _wsHost = new char[sIP.length() + 1]; + strcpy(_wsHost, sIP.c_str()); + tmpHostPtr = _wsHost; + } else { + tmpHostPtr = _host; + } + + _wsfilter = new AsyncMqttClientInternals::WebsocketFilter( + tmpHostPtr, + _wsUri, + 1, ws_mqtt_protos); + } + AsyncMqttClientInternals::OutPacket* msg = new AsyncMqttClientInternals::ConnectOutPacket(_cleanSession, _username, @@ -227,6 +284,13 @@ void AsyncMqttClient::_onDisconnect() { _clear(); for (auto callback : _onDisconnectUserCallbacks) callback(_disconnectReason); + + if (_wsfilter != NULL) { + delete _wsfilter; + _wsfilter = NULL; + } + if (_wsHost != NULL) delete[] _wsHost; + _wsHost = NULL; } /* @@ -247,6 +311,69 @@ void AsyncMqttClient::_onAck(size_t len) { void AsyncMqttClient::_onData(char* data, size_t len) { log_i("data rcv (%u)", len); + + if (_wsfilter == NULL) { + _onMQTTData(data, len); + return; + } + + // Space occupied by data already fetched into websocket will be reused to + // hold MQTT payload data. 
+ uint8_t * mqttdata = (uint8_t *)data; + uint8_t * outbufptr = mqttdata; + size_t outputavail = 0; + size_t mqttavail = 0; + do { + size_t proc_len = _wsfilter->addDataFromStream(len, (uint8_t *)data); + log_d("pushed %u of %u bytes into websocket filter", proc_len, len); + data += proc_len; + len -= proc_len; + outputavail += proc_len; + + if (proc_len <= 0) { + log_w("websocket filter rx full?"); + } else { + while (_wsfilter->isPacketDataAvailable()) { + size_t mqttbytes = 0; + uint64_t pktoffset = 0; + bool pktisbinary = false; + bool pktlastfetch = false; + _wsfilter->fetchPacketData(outputavail, outbufptr, mqttbytes, + pktoffset, pktisbinary, pktlastfetch); + log_d("fetched %u/%u mqtt bytes pktoff=%u binary=%u last=%u", mqttbytes, proc_len, + (uint32_t)pktoffset, pktisbinary ? 1 : 0, pktlastfetch ? 1 : 0); + mqttavail += mqttbytes; + outbufptr += mqttbytes; + outputavail -= mqttbytes; + } + } + + // Check for any handshake or protocol errors + if (_wsfilter->getError() != AsyncMqttClientInternals::NO_ERROR) { + switch (_wsfilter->getError()) { + case AsyncMqttClientInternals::HANDSHAKE_FAILED: + log_w("Websocket handshake failure"); + break; + case AsyncMqttClientInternals::PROTOCOL_ERROR: + log_w("Websocket protocol violation"); + break; + case AsyncMqttClientInternals::REMOTE_CTRL_CLOSE: + log_w("Remote side closed websocket: code %hu", _wsfilter->getCloseCode()); + if (_wsfilter->getCloseReason() != NULL) { + log_w("Reason text %s", _wsfilter->getCloseReason()); + } + break; + } + disconnect(true); + break; + } else if (len <= 0) { + if (mqttavail > 0) _onMQTTData((char *)mqttdata, mqttavail); + } + } while (len > 0); +} + +void AsyncMqttClient::_onMQTTData(char* data, size_t len) { + log_i("data rcv (%u)", len); size_t currentBytePosition = 0; char currentByte; _lastServerActivity = millis(); @@ -395,23 +522,44 @@ void AsyncMqttClient::_handleQueue() { SEMAPHORE_TAKE(); // On ESP32, onDisconnect is called within the close()-call. 
So we need to make sure we don't lock bool disconnect = false; + bool wserror = false; + bool canflushqueue = true; + + if (_wsfilter != NULL) { + canflushqueue = _flushWebsocketTX(wserror); + if (!canflushqueue) { + if (wserror) disconnect = true; + } + } - while (_head && _client.space() > 10) { // safe but arbitrary value, send at least 10 bytes + while (canflushqueue && _head && _client.space() > 10) { // safe but arbitrary value, send at least 10 bytes // 1. try to send if (_head->size() > _sent) { - // On SSL the TCP library returns the total amount of bytes, not just the unencrypted payload length. - // So we calculate the amount to be written ourselves. - size_t willSend = std::min(_head->size() - _sent, _client.space()); - size_t realSent = _client.add(reinterpret_cast(_head->data(_sent)), willSend, ASYNC_WRITE_FLAG_COPY); // flag is set by LWIP anyway, added for clarity - _sent += willSend; - (void)realSent; - _client.send(); - _lastClientActivity = millis(); - #if ASYNC_TCP_SSL_ENABLED - log_i("snd #%u: (tls: %u) %u/%u", _head->packetType(), realSent, _sent, _head->size()); - #else - log_i("snd #%u: %u/%u", _head->packetType(), _sent, _head->size()); - #endif + if (_wsfilter != NULL) { + if (_sent == 0 && !_wsfilter->isOpenPacket()) _wsfilter->startPacket(true); + size_t realSent = _wsfilter->addPacketData(_head->size() - _sent, _head->data(_sent), true); + canflushqueue = _flushWebsocketTX(wserror); + if (!canflushqueue) { + // At this point, any failure to flush is a websocket error + if (wserror) disconnect = true; + } + _sent += realSent; + _lastClientActivity = millis(); + } else { + // On SSL the TCP library returns the total amount of bytes, not just the unencrypted payload length. + // So we calculate the amount to be written ourselves. 
+ size_t willSend = std::min(_head->size() - _sent, _client.space()); + size_t realSent = _client.add(reinterpret_cast(_head->data(_sent)), willSend, ASYNC_WRITE_FLAG_COPY); // flag is set by LWIP anyway, added for clarity + _sent += willSend; + (void)realSent; + _client.send(); + _lastClientActivity = millis(); + #if ASYNC_TCP_SSL_ENABLED + log_i("snd #%u: (tls: %u) %u/%u", _head->packetType(), realSent, _sent, _head->size()); + #else + log_i("snd #%u: %u/%u", _head->packetType(), _sent, _head->size()); + #endif + } if (_head->packetType() == AsyncMqttClientInternals::PacketType.DISCONNECT) { disconnect = true; } @@ -434,11 +582,58 @@ void AsyncMqttClient::_handleQueue() { SEMAPHORE_GIVE(); if (disconnect) { - log_i("snd DISCONN, disconnecting"); + if (wserror) + log_i("websocket error, disconnecting"); + else + log_i("snd DISCONN, disconnecting"); _client.close(); } } +bool AsyncMqttClient::_flushWebsocketTX(bool & disconnect) +{ + // Flush as much of the wsfilter TX buffer as possible, required for handshake + while (_client.space() > 10 && _wsfilter->pendingStreamOutputLen() > 0) { + log_d("client.space = %u pending ws output = %u", _client.space(), _wsfilter->pendingStreamOutputLen()); + uint8_t * wsbuffer = NULL; + size_t wsbytes = 0; + _wsfilter->fetchDataPtrForStream(wsbuffer, wsbytes); + wsbytes = std::min(wsbytes, _client.space()); + size_t wsbytes_sent = _client.add((const char *)wsbuffer, wsbytes, ASYNC_WRITE_FLAG_COPY); + if (wsbytes_sent > 0) { + _client.send(); + _wsfilter->discardFetchedData(wsbytes_sent); + } else { + break; + } + } + + switch (_wsfilter->getState()) { + case AsyncMqttClientInternals::HANDSHAKE_TX: + case AsyncMqttClientInternals::HANDSHAKE_RX: + // Still negotiating websocket handshake + log_v("still negotiating handshake"); + return false; + case AsyncMqttClientInternals::HANDSHAKE_ERR: + log_w("Websocket handshake failure"); + disconnect = true; + return false; + case AsyncMqttClientInternals::WEBSOCKET_ERROR: + log_w("Websocket 
protocol violation"); + disconnect = true; + return false; + case AsyncMqttClientInternals::WEBSOCKET_CLOSED: + log_w("Remote side closed websocket: code %hu", _wsfilter->getCloseCode()); + if (_wsfilter->getCloseReason() != NULL) { + log_w("Reason text %s", _wsfilter->getCloseReason()); + } + disconnect = true; + return false; + } + + return true; +} + void AsyncMqttClient::_clearQueue(bool keepSessionData) { SEMAPHORE_TAKE(); AsyncMqttClientInternals::OutPacket* packet = _head; diff --git a/src/AsyncMqttClient.hpp b/src/AsyncMqttClient.hpp index 1e81103..35b68be 100644 --- a/src/AsyncMqttClient.hpp +++ b/src/AsyncMqttClient.hpp @@ -19,7 +19,9 @@ #endif #if ASYNC_TCP_SSL_ENABLED +#ifdef ESP8266 #include +#endif #define SHA1_SIZE 20 #endif @@ -50,6 +52,8 @@ #include "AsyncMqttClient/Packets/Out/Unsubscribe.hpp" #include "AsyncMqttClient/Packets/Out/Publish.hpp" +#include "AsyncMqttClient/WebsocketFilter.hpp" + class AsyncMqttClient { public: AsyncMqttClient(); @@ -65,8 +69,17 @@ class AsyncMqttClient { AsyncMqttClient& setServer(const char* host, uint16_t port); #if ASYNC_TCP_SSL_ENABLED AsyncMqttClient& setSecure(bool secure); +#ifdef ESP8266 AsyncMqttClient& addServerFingerprint(const uint8_t* fingerprint); #endif +#ifdef ESP32 + AsyncMqttClient& setRootCa(const char* rootca, const size_t len); + AsyncMqttClient& setClientCert(const char* cli_cert, const size_t cli_cert_len, const char* cli_key, const size_t cli_key_len); + AsyncMqttClient& setPsk(const char* psk_ident, const char* psk); +#endif +#endif + AsyncMqttClient& setWsEnabled(bool wsenabled); + AsyncMqttClient& setWsUri(const char * uri); AsyncMqttClient& onConnect(AsyncMqttClientInternals::OnConnectUserCallback callback); AsyncMqttClient& onDisconnect(AsyncMqttClientInternals::OnDisconnectUserCallback callback); @@ -121,9 +134,16 @@ class AsyncMqttClient { bool _willRetain; #if ASYNC_TCP_SSL_ENABLED +#ifdef ESP8266 std::vector> _secureServerFingerprints; +#endif #endif + 
AsyncMqttClientInternals::WebsocketFilter * _wsfilter; + bool _wsEnabled; + const char* _wsUri; + char* _wsHost; + std::vector _onConnectUserCallbacks; std::vector _onDisconnectUserCallbacks; std::vector _onSubscribeUserCallbacks; @@ -156,12 +176,15 @@ class AsyncMqttClient { void _onData(char* data, size_t len); void _onPoll(); + void _onMQTTData(char* data, size_t len); + // QUEUE void _insert(AsyncMqttClientInternals::OutPacket* packet); // for PUBREL void _addFront(AsyncMqttClientInternals::OutPacket* packet); // for CONNECT void _addBack(AsyncMqttClientInternals::OutPacket* packet); // all the rest void _handleQueue(); void _clearQueue(bool keepSessionData); + bool _flushWebsocketTX(bool & disconnect); // MQTT void _onPingResp(); diff --git a/src/AsyncMqttClient/WebsocketFilter.cpp b/src/AsyncMqttClient/WebsocketFilter.cpp new file mode 100644 index 0000000..4d70398 --- /dev/null +++ b/src/AsyncMqttClient/WebsocketFilter.cpp @@ -0,0 +1,723 @@ +#include +#include +#include + +//#include +#include "esp_idf_version.h" +#include "esp32-hal-log.h" +#include "mbedtls/base64.h" +#include "mbedtls/sha1.h" + +#include "WebsocketFilter.hpp" + +using AsyncMqttClientInternals::WebsocketFilter; + +#define MIN_BUFSIZ 256 + +#define WS_OP_DATA_CONT 0 +#define WS_OP_DATA_TEXT 1 +#define WS_OP_DATA_BINARY 2 +#define WS_OP_CTRL_START 8 +#define WS_OP_CTRL_CLOSE 8 +#define WS_OP_CTRL_PING 9 +#define WS_OP_CTRL_PONG 10 + +WebsocketFilter::WebsocketFilter(const char * hostname, const char * wsurl, + uint32_t n_protos, const char * protos[], + size_t rxbufsiz, size_t txbufsiz) +{ + _state = HANDSHAKE_TX; // Starting handshake process + _err = NO_ERROR; // No error (yet) + + if (rxbufsiz < MIN_BUFSIZ) rxbufsiz = MIN_BUFSIZ; + if (txbufsiz < MIN_BUFSIZ) txbufsiz = MIN_BUFSIZ; + _txbuf = new uint8_t[txbufsiz]; _txbufsiz = txbufsiz; _txused = 0; + _rxbuf = new uint8_t[rxbufsiz]; _rxbufsiz = rxbufsiz; _rxused = 0; + + // Initialize header response check variables + _num_respHdrs = 0; + 
_hs_upgrade_websocket = false; + _hs_connection_upgrade = false; + _hs_sec_websocket_accept = false; + + _rx_inheader = false; + _rx_textframe = false; + _rx_binframe = false; + _rx_lastframe = true; + _rx_packetoffset = 0; + _rx_maskidx = 0; + _rx_close_code = 0; + _rx_close_reason = NULL; + _pendingPong = false; + _pongData = NULL; + _pongDataLen = 0; + + _tx_opcode = WS_OP_DATA_CONT; + _tx_pktopen = false; + + // Calculate a 16-byte random value for Sec-WebSocket-Key header + uint32_t randomkey[4]; size_t dummy_olen; + for (auto i = 0; i < 4; i++) randomkey[i] = (uint32_t)random(); + memset(_base64_key, 0, sizeof(_base64_key)); + mbedtls_base64_encode( + (unsigned char *)_base64_key, sizeof(_base64_key), + &dummy_olen, + (unsigned char *)randomkey, 16); + + // Start at beginning of first handshake string + _sent_handshake[0] = 0; + _sent_handshake[1] = 0; + + // Allocate list of handshake strings to send + _num_handshakeStrings = 8; + if (n_protos > 0 && protos != NULL) _num_handshakeStrings += 1 + 2 * n_protos; + _handshakeStrings = new const char *[_num_handshakeStrings]; + + _handshakeStrings[0] = "GET "; + _handshakeStrings[1] = wsurl; + _handshakeStrings[2] = " HTTP/1.1\r\nHost: "; + _handshakeStrings[3] = hostname; + _handshakeStrings[4] = "\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: "; + _handshakeStrings[5] = _base64_key; + _handshakeStrings[6] = "\r\nSec-WebSocket-Version: 13\r\n"; + auto idx = 7; + if (n_protos > 0 && protos != NULL) { + _handshakeStrings[idx] = "Sec-WebSocket-Protocol: "; + idx++; + for (auto i = 0; i < n_protos; i++) { + _handshakeStrings[idx] = protos[i]; + idx++; + _handshakeStrings[idx] = (i < n_protos - 1) ? 
", " : "\r\n"; + idx++; + } + } + _handshakeStrings[idx] = "\r\n"; + + _advanceHandshakeTX(); +} + +WebsocketFilter::~WebsocketFilter() +{ + if (_handshakeStrings != NULL) { + delete[] _handshakeStrings; + _handshakeStrings = NULL; + _num_handshakeStrings = 0; + } + + if (_rx_close_reason != NULL) delete[] _rx_close_reason; + _rx_close_reason = NULL; + + if (_pongData != NULL) delete[] _pongData; + _pongData = NULL; + + delete[] _rxbuf; + delete[] _txbuf; +} + +void WebsocketFilter::_advanceHandshakeTX(void) +{ + if (_handshakeStrings == NULL) return; + if (_state != HANDSHAKE_TX) return; + +#define HS_STRIDX (_sent_handshake[0]) +#define HS_STROFF (_sent_handshake[1]) +#define HS_STR (_handshakeStrings[_sent_handshake[0]]) + + // Copy as many of the remaining strings of handshake into output buffer as + // possible, until running out of space or end of strings. + while (_txused < _txbufsiz && HS_STRIDX < _num_handshakeStrings) { + auto len = strlen(HS_STR); + auto len_copy = len - HS_STROFF; + auto txavail = _txbufsiz - _txused; + if (len_copy > txavail) len_copy = txavail; + memcpy(_txbuf + _txused, HS_STR + HS_STROFF, len_copy); + _txused += len_copy; + HS_STROFF += len_copy; + + // Go to next string of handshake + if (HS_STROFF >= len) { + HS_STRIDX++; + HS_STROFF = 0; + } + } + + if (HS_STRIDX >= _num_handshakeStrings) { + // End of handshake queued, switching to handshake response + _state = HANDSHAKE_RX; + delete[] _handshakeStrings; + _handshakeStrings = NULL; + _num_handshakeStrings = 0; + } +} + +void WebsocketFilter::_discardTxData(size_t n_bytes) +{ + if (n_bytes <= 0) return; + if (n_bytes < _txused) { + memmove(_txbuf, _txbuf + n_bytes, _txused - n_bytes); + _txused -= n_bytes; + } else { + _txused = 0; + } +} + +void WebsocketFilter::_discardRxData(size_t n_bytes) +{ + if (n_bytes <= 0) return; + if (n_bytes < _rxused) { + memmove(_rxbuf, _rxbuf + n_bytes, _rxused - n_bytes); + _rxused -= n_bytes; + } else { + _rxused = 0; + } +} + +void 
WebsocketFilter::fetchDataPtrForStream(uint8_t * & buffer, size_t & n_bytes) +{ + buffer = _txbuf; + n_bytes = _txused; +} + +void WebsocketFilter::discardFetchedData(size_t n_bytes) +{ + if (n_bytes > _txused) n_bytes = _txused; + if (n_bytes <= 0) return; + + _discardTxData(n_bytes); + + if (_state == HANDSHAKE_TX) { + // Some space was freed from buffer, advance handshake + _advanceHandshakeTX(); + } + + if (_state == WEBSOCKET_OPEN && _pendingPong) { + // Check if pending pong can be enqueued in freed space + if (2 + _pongDataLen <= _txbufsiz - _txused) { + _enqueueOutgoingFrame(WS_OP_CTRL_PONG, _pongDataLen, + (_pongDataLen > 0) ? _pongData : NULL, false, true); + if (_pongData != NULL) delete[] _pongData; + _pongData = NULL; + _pongDataLen = 0; + _pendingPong = false; + } + } +} + +bool WebsocketFilter::fetchDataForStream(size_t max_size, uint8_t * buffer, size_t & n_bytes) +{ + if (max_size <= 0 || buffer == NULL) return false; + + n_bytes = _txused; + if (n_bytes > max_size) n_bytes = max_size; + if (n_bytes <= 0) return true; + + memcpy(buffer, _txbuf, n_bytes); + discardFetchedData(n_bytes); + + return true; +} + +size_t WebsocketFilter::addDataFromStream(size_t size, const uint8_t * buffer) +{ + if (size > _rxbufsiz - _rxused) size = _rxbufsiz - _rxused; + if (size > 0) { + memcpy(_rxbuf + _rxused, buffer, size); + _rxused += size; + + if (_state == HANDSHAKE_RX) { + // New data is available, advance handshake response check + _runHandshakeResponseCheck(); + } + + if (_state == WEBSOCKET_OPEN && _rxused > 0) { + // Received frame header or data, analyse and remove frame headers + _runFrameDataParse(); + } + } + + return size; +} + +void WebsocketFilter::_runHandshakeResponseCheck(void) +{ + // Each iteration requires at least one complete header to be present at + // the beginning of the rx buffer. 
+ while (_state == HANDSHAKE_RX && _rxused > 0) { + log_v("_num_respHdrs = %u _rxused = %u", _num_respHdrs, _rxused); + // A null character is not allowed anywhere in the response. Finding one + // means the handshake failed (maybe stream is not plaintext HTTP?) + if (NULL != memchr(_rxbuf, '\0', _rxused)) { + log_w("NULL found in handshake response (non-HTTP stream?)"); + _state = HANDSHAKE_ERR; + _err = HANDSHAKE_FAILED; + break; + } + + // A complete header must terminate in \r\n. No solitary \r is allowed. + uint8_t * cr = (uint8_t *)memchr(_rxbuf, '\r', _rxused); + if (cr == NULL) { + // \r was not found. Maybe we need to wait for more data, unless + // buffer is full, which means the "header" is overlong, or (again) + // not a plaintext HTTP response. + if (_rxused >= _rxbufsiz) { + log_w("overlong header (1) in handshake response (non-HTTP stream?)"); + _state = HANDSHAKE_ERR; + _err = HANDSHAKE_FAILED; + } + break; + } + if ((cr - _rxbuf) == _rxused - 1) { + // \r found at the very edge of the response. 
Maybe we need to wait + // for more data, unless buffer is full, which means the "header" + // is overlong + if (_rxused >= _rxbufsiz) { + log_w("overlong header (2) in handshake response (non-HTTP stream?)"); + _state = HANDSHAKE_ERR; + _err = HANDSHAKE_FAILED; + } + break; + } + + // Now it is safe to check whether next character is \n + if (*(cr + 1) != '\n') { + log_w("solitary CR in handshake response (non-HTTP stream?)"); + _state = HANDSHAKE_ERR; + _err = HANDSHAKE_FAILED; + break; + } + + char * hdr = (char *)_rxbuf; + size_t hdrlen = (cr - _rxbuf) + 2; // Length of header including \r\n + + *cr = '\0'; // Overwrite CR with null to treat header as C-string + log_v("Examining hdr: [%s]", hdr); + + if (_num_respHdrs == 0) { + // First header MUST be the response: + // HTTP/1.1 101 Switching Protocols + if (strstr(hdr, "HTTP/1.1 101 ") != hdr) { + log_w("invalid HTTP response: %s", hdr); + _state = HANDSHAKE_ERR; + _err = HANDSHAKE_FAILED; + break; + } + } else { + if (strlen(hdr) > 0) { + // Rest of headers must be of the form + // KEY: VALUE + char * k = hdr; + char * v = strstr(hdr, ": "); + if (v == NULL) { + log_w("invalid HTTP header: %s", hdr); + _state = HANDSHAKE_ERR; + _err = HANDSHAKE_FAILED; + break; + } + *v = '\0'; v += 2; + while (*v == ' ') v++; + + // So... what header do we have here...? 
+ if (0 == strcasecmp(k, "Upgrade")) { + // This header MUST exist and MUST have the value: websocket + if (0 != strcasecmp(v, "websocket")) { + // Invalid Upgrade header + log_w("invalid Upgrade header: %s", v); + _state = HANDSHAKE_ERR; + _err = HANDSHAKE_FAILED; + break; + } + _hs_upgrade_websocket = true; + } else if (0 == strcasecmp(k, "Connection")) { + // This header MUST exist and MUST have the value: Upgrade + if (0 != strcasecmp(v, "Upgrade")) { + // Invalid Connection header + log_w("invalid Connection header: %s", v); + _state = HANDSHAKE_ERR; + _err = HANDSHAKE_FAILED; + break; + } + _hs_connection_upgrade = true; + } else if (0 == strcasecmp(k, "Sec-WebSocket-Accept")) { + // This header MUST exist, and its value MUST match the base64 + // encoding of the SHA-1 digest of our Sec-WebSocket-Key plus the GUID. + char wskey[61]; + strcpy(wskey, _base64_key); + strcat(wskey, "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); + + unsigned char sha1_output[20]; +#if ESP_IDF_VERSION_MAJOR < 5 + mbedtls_sha1_ret((unsigned char *)wskey, 60, sha1_output); +#else + mbedtls_sha1((unsigned char *)wskey, 60, sha1_output); +#endif + + char expected_base64[29]; size_t dummy_olen; + memset(expected_base64, 0, sizeof(expected_base64)); + mbedtls_base64_encode( + (unsigned char *)expected_base64, sizeof(expected_base64), + &dummy_olen, + (unsigned char *)sha1_output, sizeof(sha1_output)); + + if (0 != strcmp(v, expected_base64)) { + // Invalid Sec-WebSocket-Accept + log_w("invalid Sec-WebSocket-Accept: expected %s found %s", expected_base64, v); + _state = HANDSHAKE_ERR; + _err = HANDSHAKE_FAILED; + break; + } + _hs_sec_websocket_accept = true; + } else if (0 == strcasecmp(k, "Sec-WebSocket-Protocol")) { + // Ignore protocol for now... + log_d("(ignored protocol) %s: %s", k, v); + } else { + log_v("(ignored header) %s: %s", k, v); + } + } else { + // Reached end of response. 
All checks must have passed + if (!_hs_upgrade_websocket) { + log_w("missing header Upgrade: websocket"); + _state = HANDSHAKE_ERR; + _err = HANDSHAKE_FAILED; + } else if (!_hs_connection_upgrade) { + log_w("missing header Connection: Upgrade"); + _state = HANDSHAKE_ERR; + _err = HANDSHAKE_FAILED; + } else if (!_hs_sec_websocket_accept) { + log_w("missing header Sec-WebSocket-Accept"); + _state = HANDSHAKE_ERR; + _err = HANDSHAKE_FAILED; + } else { + // Success! Packet exchange can start now. + _state = WEBSOCKET_OPEN; + _rx_inheader = true; + } + } + } + + // Header processed, may be discarded now + _num_respHdrs++; + _discardRxData(hdrlen); + } +} + +void WebsocketFilter::_runFrameDataParse() +{ + while (_state == WEBSOCKET_OPEN && _rxused > 0) { + + if (_rx_inheader) { + // Need to accumulate one complete header to evaluate further + if (_rxused < 2) return; + + // Check if there is enough data for header + size_t hdrlength = 2; + uint8_t mask = _rxbuf[1]; + if ((mask & 0x7f) == 0x7f) { + hdrlength = 10; + } else if ((mask & 0x7f) == 0x7e) { + hdrlength = 4; + } + if (_rxused < hdrlength + ((mask & 0x80) ? 
4 : 0)) return; + + _rx_opcode = _rxbuf[0]; + _rx_framelen = 0; + if ((mask & 0x7f) == 0x7f) { + // Extract 64-bit frame length, network byte order + for (auto i = 0; i < 8; i++) _rx_framelen = (_rx_framelen << 8) | _rxbuf[i+2]; + } else if ((mask & 0x7f) == 0x7e) { + // Extract 16-bit frame length, network byte order + for (auto i = 0; i < 2; i++) _rx_framelen = (_rx_framelen << 8) | _rxbuf[i+2]; + } else { + // 7-bit frame length + _rx_framelen = mask & 0x7f; + } + log_v("hdrlength=%u frame length=%u", hdrlength, (uint32_t)_rx_framelen); + + // Extract last-frame flag, clear reserved flags (ignored) + bool lastframe = ((_rx_opcode & 0x80) != 0); + _rx_opcode &= 0x0F; + + // Use identity mask unless flag set by server (NON-COMPLIANT) + memset(_rx_maskdata, 0, 4); + if (mask & 0x80) { + log_w("mask data not compliant for client recv!"); + memcpy(_rx_maskdata, _rxbuf + hdrlength, 4); + } + _rx_maskidx = 0; + + _discardRxData(hdrlength + ((mask & 0x80) ? 4 : 0)); + _rx_inheader = false; + + switch (_rx_opcode) { + case WS_OP_DATA_TEXT: // Start of text frame + log_v("start of text frame"); + _rx_binframe = false; + _rx_textframe = true; + if (!_rx_lastframe) { + log_e("expecting continuation frame"); + _state = WEBSOCKET_ERROR; + _err = PROTOCOL_ERROR; + } + _rx_lastframe = lastframe; + _rx_packetoffset = 0; + break; + case WS_OP_DATA_BINARY: // Start of binary frame + log_v("start of binary frame"); + _rx_binframe = true; + _rx_textframe = false; + if (!_rx_lastframe) { + log_e("expecting continuation frame"); + _state = WEBSOCKET_ERROR; + _err = PROTOCOL_ERROR; + } + _rx_lastframe = lastframe; + _rx_packetoffset = 0; + break; + case WS_OP_DATA_CONT: // Continuation frame + log_v("continuation frame"); + _rx_lastframe = lastframe; + if (!_rx_textframe && !_rx_binframe) { + log_e("continuation frame without start frame!"); + _state = WEBSOCKET_ERROR; + _err = PROTOCOL_ERROR; + } + break; + case WS_OP_CTRL_CLOSE: + case WS_OP_CTRL_PING: + case WS_OP_CTRL_PONG: + 
log_v("control frame 0x%02x", _rx_opcode); + if (!lastframe) { + log_e("fragmented control frame 0x%02x", _rx_opcode); + _state = WEBSOCKET_ERROR; + _err = PROTOCOL_ERROR; + } + if (_rx_framelen > 125) { + log_e("overlong control frame 0x%02x", _rx_opcode); + _state = WEBSOCKET_ERROR; + _err = PROTOCOL_ERROR; + } + break; + default: + log_e("unimplemented opcode 0x%02x", _rx_opcode); + _state = WEBSOCKET_ERROR; + _err = PROTOCOL_ERROR; + break; + } + } + + if (!_rx_inheader && _state == WEBSOCKET_OPEN) { + // Some data from payload, picked only if control code. According to + // RFC 6455, any payload for a control code must be no greater than 125 + // bytes. + if (_rx_opcode >= WS_OP_CTRL_START && _rxused >= _rx_framelen) { + switch (_rx_opcode) { + case WS_OP_CTRL_CLOSE: + if (_rx_framelen >= 2) { + // 2 byte code, rest is text explanation + _rx_close_code = (((uint16_t)_rxbuf[0]) << 8) | _rxbuf[1]; + if (_rx_framelen > 2) { + _rx_close_reason = new char[_rx_framelen - 2 + 1]; + memcpy(_rx_close_reason, _rxbuf + 2, _rx_framelen - 2); + _rx_close_reason[_rx_framelen - 2] = '\0'; + } + } + _state = WEBSOCKET_CLOSED; + _err = REMOTE_CTRL_CLOSE; + break; + case WS_OP_CTRL_PING: + // Check if PONG packet can be immediately enqueued + if (2 + _rx_framelen <= _txbufsiz - _txused) { + // Enqueue PONG packet immediately + _enqueueOutgoingFrame(WS_OP_CTRL_PONG, _rx_framelen, + (_rx_framelen > 0) ? 
_rxbuf : NULL, false, true); + } else { + // Copy PING payload to be enqueued later + _pendingPong = true; + if (_pongData != NULL) delete[] _pongData; + _pongData = NULL; + _pongDataLen = _rx_framelen; + if (_pongDataLen > 0) { + _pongData = new uint8_t[_pongDataLen]; + memcpy(_pongData, _rxbuf, _pongDataLen); + } + } + break; + case WS_OP_CTRL_PONG: + // Ignored + break; + } + + if (_rx_framelen > 0) _discardRxData(_rx_framelen); + _rx_framelen = 0; + _rx_inheader = true; // Process next RX frame header + } + } + + if (!_rx_inheader) break; + } +} + +bool WebsocketFilter::isPacketDataAvailable(void) +{ + if (_state != WEBSOCKET_OPEN) return false; + if (_err != NO_ERROR) return false; + return !_rx_inheader && (_rxused > 0); +} + +void WebsocketFilter::fetchPacketData(size_t max_size, uint8_t * buffer, + size_t & n_bytes, uint64_t & packet_offset, bool & packet_binary, + bool & last_fetch) +{ + n_bytes = 0; + packet_offset = _rx_packetoffset; + packet_binary = true; + last_fetch = false; + + if (max_size <= 0 || buffer == NULL) return; + if (_rx_textframe) packet_binary = false; + if (_rx_binframe) packet_binary = true; + + while (!last_fetch && !_rx_inheader && _state == WEBSOCKET_OPEN && max_size > 0 && _rxused > 0) { + size_t copylen = max_size; + if (copylen > _rx_framelen) copylen = _rx_framelen; + if (copylen > _rxused) copylen = _rxused; + + for (auto i = 0; i < copylen; i++) { + buffer[i] = _rxbuf[i] ^ _rx_maskdata[_rx_maskidx]; + _rx_maskidx = (_rx_maskidx + 1) & 3; + } + buffer += copylen; + max_size -= copylen; + n_bytes += copylen; + _rx_packetoffset += copylen; + _rx_framelen -= copylen; + + _discardRxData(copylen); + + if (_rx_framelen <= 0) { + // Copied all data from current frame. Any following data MUST be + // the header of the next frame. 
+ _rx_inheader = true; + if (_rx_lastframe) { + // Last frame of packet + last_fetch = true; + _rx_textframe = false; + _rx_binframe = false; + } else { + // Not the last frame, might be more data to copy + } + + _runFrameDataParse(); + } + } +} + +bool WebsocketFilter::_enqueueOutgoingFrame(uint8_t opcode, uint32_t len, + const uint8_t * payload, bool masked, bool lastframe) +{ + if (len > 0 && payload == NULL) return false; + + // Calculate required length for full frame + uint32_t headerlen = 2; + if (len > 65535) + headerlen += 8; + else if (len > 125) + headerlen += 2; + if (masked) headerlen += 4; + if (headerlen + len > _txbufsiz - _txused) return false; + + uint8_t * hdr = _txbuf + _txused; + uint8_t * p = hdr + headerlen; + + hdr[0] = opcode; if (lastframe) hdr[0] |= 0x80; + if (len > 65535) { + hdr[1] = 127; + hdr[2] = hdr[3] = hdr[4] = hdr[5] = 0; + hdr[6] = (len >> 24) & 0xFF; + hdr[7] = (len >> 16) & 0xFF; + hdr[8] = (len >> 8) & 0xFF; + hdr[9] = len & 0xFF; + } else if (len > 125) { + hdr[1] = 126; + hdr[2] = (len >> 8) & 0xFF; + hdr[3] = len & 0xFF; + } else { + hdr[1] = len; + } + if (masked) { + hdr[1] |= 0x80; + + uint8_t * tx_maskdata = p - 4; + *((uint32_t *)tx_maskdata) = random(); + log_v("mask = 0x%08x", *((uint32_t *)tx_maskdata)); + + for (auto i = 0; i < len; i++) p[i] = payload[i] ^ tx_maskdata[i & 3]; + } else { + if (len > 0) memcpy(p, payload, len); + } + + _txused += headerlen + len; + + return true; +} + +bool WebsocketFilter::startPacket(bool packet_binary) +{ + if (_state != WEBSOCKET_OPEN) return false; + if (_tx_pktopen) return false; + uint8_t n_opcode = packet_binary ? 
WS_OP_DATA_BINARY : WS_OP_DATA_TEXT; + if (_tx_opcode != WS_OP_DATA_CONT && _tx_opcode != n_opcode) return false; + _tx_opcode = n_opcode; + _tx_pktopen = true; + return true; +} + +size_t WebsocketFilter::addPacketData(size_t size, const uint8_t * buffer, bool endOfPacket) +{ + if (_state != WEBSOCKET_OPEN) { + log_w("invalid websocket state %u", _state); + return 0; + } + if (size == 0 || buffer == NULL) { + log_e("invalid size or ptr"); + return 0; + } + if (!_tx_pktopen) { + log_w("attempt to add data to not-opened packet!"); + return 0; + } + + uint8_t fixedhdr = 2 + 4; // two-byte minimum plus mask data + size_t txavail = _txbufsiz - _txused; + + if (txavail <= fixedhdr) return 0; // Not enough space for even minimum header + + size_t copylen = txavail - fixedhdr; + if (copylen > size) copylen = size; + + if (copylen > 65535U) { + // 8-byte size field required, check if it fits + if (fixedhdr + 8 + copylen > txavail) { + copylen = txavail - fixedhdr - 8; + } + } + if (copylen > 125 && copylen <= 65535U) { + // 2-byte size field required, check if it fits + if (fixedhdr + 2 + copylen > txavail) { + copylen = txavail - fixedhdr - 2; + } + } + + // Cannot signal end-of-packet unless indicated by parameter AND all + // supplied data actually fits in available space + if (endOfPacket && copylen < size) endOfPacket = false; + + if (!_enqueueOutgoingFrame(_tx_opcode, copylen, buffer, true, endOfPacket)) { + log_e("BUG: incorrect calculation of required space: size=%u copylen=%u txavail=%u", + size, copylen, txavail); + return 0; + } + _tx_opcode = WS_OP_DATA_CONT; + if (endOfPacket) _tx_pktopen = false; + + return copylen; +} \ No newline at end of file diff --git a/src/AsyncMqttClient/WebsocketFilter.hpp b/src/AsyncMqttClient/WebsocketFilter.hpp new file mode 100644 index 0000000..8a7d8bf --- /dev/null +++ b/src/AsyncMqttClient/WebsocketFilter.hpp @@ -0,0 +1,156 @@ +#pragma once + +#include + +namespace AsyncMqttClientInternals { + +typedef enum { + NO_ERROR = 0, + 
namespace AsyncMqttClientInternals {

// Error status of a websocket, as returned by WebsocketFilter::getError().
typedef enum {
    NO_ERROR = 0,
    HANDSHAKE_FAILED,       // Websocket handshake failed
    REMOTE_CTRL_CLOSE,      // Remote side sent a close packet
    PROTOCOL_ERROR          // Invalid condition encountered in received frame
} WebsocketError;

// Lifecycle state of a websocket, as returned by WebsocketFilter::getState().
typedef enum {
    HANDSHAKE_TX,           // Client is sending HTTP handshake
    HANDSHAKE_RX,           // Client is receiving HTTP response to handshake
    HANDSHAKE_ERR,          // Handshake failed at some point

    WEBSOCKET_OPEN,         // Handshake succeeded, websocket is sending/receiving packets
    WEBSOCKET_CLOSED,       // Received close packet from remote side
    WEBSOCKET_ERROR         // Websocket protocol error
} WebsocketState;

// Client-side websocket layer that sits between a raw byte stream and the
// application: raw stream bytes go in through addDataFromStream() and come
// out as packet data through fetchPacketData(); outgoing packets go in
// through startPacket()/addPacketData() and come out as raw stream bytes
// through fetchDataForStream() or fetchDataPtrForStream().
//
// Instances are not copyable: the class stores raw buffer pointers and
// declares a destructor. NOTE(review): the destructor is presumably what
// releases those buffers (its implementation is not visible here); if so, a
// memberwise copy would end in a double free, hence the deleted copy
// operations below.
class WebsocketFilter
{
private:
    WebsocketState _state;  // Current state of websocket
    WebsocketError _err;    // Current error status

    // Outgoing (TX) raw byte buffer: storage, capacity, bytes in use
    uint8_t * _txbuf;
    size_t _txbufsiz;
    size_t _txused;

    // Incoming (RX) raw byte buffer: storage, capacity, bytes in use
    uint8_t * _rxbuf;
    size_t _rxbufsiz;
    size_t _rxused;

    // List of HTTP strings for handshake
    const char ** _handshakeStrings;
    size_t _num_handshakeStrings;
    size_t _sent_handshake[2];  // index,offset of sent handshake

    // base64 representation of 16-byte random value (24 characters) plus
    // terminating null, for Sec-WebSocket-Key header.
    char _base64_key[25];

    // Variables for response checks
    unsigned int _num_respHdrs;
    bool _hs_upgrade_websocket;
    bool _hs_connection_upgrade;
    bool _hs_sec_websocket_accept;

    // RX: are we expecting a frame header?
    bool _rx_inheader;
    bool _rx_textframe;
    bool _rx_binframe;
    bool _rx_lastframe;
    uint8_t _rx_opcode;
    uint64_t _rx_framelen;      // Length of payload portion of frame, without headers/mask
    uint64_t _rx_packetoffset;
    uint8_t _rx_maskdata[4];
    uint8_t _rx_maskidx;

    // Close reason handling
    uint16_t _rx_close_code;
    char * _rx_close_reason;

    // Pending PONG handling
    bool _pendingPong;
    uint8_t * _pongData;
    uint8_t _pongDataLen;

    // TX: opcode to use for the next outgoing data frame, and whether a
    // packet started by startPacket() is still open
    uint8_t _tx_opcode;
    bool _tx_pktopen;

    void _discardTxData(size_t);
    void _discardRxData(size_t);
    void _advanceHandshakeTX(void);
    void _runHandshakeResponseCheck(void);

    void _runFrameDataParse(void);
    bool _enqueueOutgoingFrame(uint8_t opcode, uint32_t len, const uint8_t * payload,
        bool masked, bool lastframe);
public:
    WebsocketFilter(const char * hostname, const char * wsurl = "/",
        uint32_t n_protos = 0, const char * protos[] = nullptr,
        size_t rxbufsiz = 256, size_t txbufsiz = 256);
    ~WebsocketFilter();

    // Non-copyable, see class comment.
    WebsocketFilter(const WebsocketFilter &) = delete;
    WebsocketFilter & operator=(const WebsocketFilter &) = delete;

    WebsocketState getState(void) const { return _state; }
    WebsocketError getError(void) const { return _err; }

    uint16_t getCloseCode(void) const { return _rx_close_code; }
    const char * getCloseReason(void) const { return _rx_close_reason; }

    // Push raw bytes received from a byte stream into websocket RX parser.
    // Returns number of bytes actually consumed from buffer, which might be
    // zero if the internal buffer is full.
    size_t addDataFromStream(size_t size, const uint8_t * buffer);

    // Query number of pending bytes to be sent to stream
    size_t pendingStreamOutputLen(void) const { return _txused; }

    // Fetch raw bytes to be sent to byte stream from websocket TX. This
    // places the number of bytes actually placed into the buffer into n_bytes,
    // up to the limit of max_size. After fetching the outgoing data, this data
    // is discarded and the space reused for further processing. Returns true
    // for success (including n_bytes <- 0 for empty buffer), or false for
    // failure (usually a pending error).
    bool fetchDataForStream(size_t max_size, uint8_t * buffer, size_t & n_bytes);

    // Fetch a pointer to the internal buffer of raw bytes queued for output.
    // Unlike fetchDataForStream(), this method will skip the copy to the
    // externally-supplied buffer. However, the app must now remember to call
    // discardFetchedData() with the total amount of bytes that were successfully
    // written to the output stream. After calling this function, the buffer
    // is guaranteed to be valid and unchanged up to n_bytes, until the call to
    // discardFetchedData().
    void fetchDataPtrForStream(uint8_t * & buffer, size_t & n_bytes);

    // Discard data that was written to output stream. Companion to fetchDataPtrForStream().
    void discardFetchedData(size_t n_bytes);


    // Check if some INCOMING frame data is available to be fetched
    bool isPacketDataAvailable(void);

    // Fetch some of the current available packet data:
    // - max_size, buffer: (INPUT) size and location of output buffer
    // - n_bytes: number of bytes actually fetched, up to max_size
    // - packet_offset: offset from start of packet of fetched data
    // - packet_binary: true if packet is binary, false if text
    // - last_fetch: true if fetch served last of packet data
    //
    // This method will merge together multiple frames belonging to the same
    // fragmented packet, if applicable.
    // After calling this function, n_bytes of packet data are discarded.
    void fetchPacketData(size_t max_size, uint8_t * buffer, size_t & n_bytes,
        uint64_t & packet_offset, bool & packet_binary, bool & last_fetch);



    // Start an outgoing packet, with type of either text or binary
    bool startPacket(bool packet_binary);

    bool isOpenPacket(void) const { return _tx_pktopen; }

    // Add data to the outgoing packet started by startPacket(). This returns
    // the number of bytes actually consumed from the buffer and made into a
    // websocket frame. The app SHOULD call fetchDataForStream() at least once
    // after calling this function, in order to fetch resulting packet data and
    // free space for further processing. The endOfPacket flag MUST be set
    // if the buffer contains the last byte of data to be sent to the stream.
    // NOTE: if the endOfPacket flag is set, BUT the method returns less than
    // the specified buffer size, the app MUST fetch some data to be sent to
    // the stream using fetchDataForStream() in order to free some space, then
    // retry with the remaining unsent buffer, and endOfPacket again set.
    size_t addPacketData(size_t size, const uint8_t * buffer, bool endOfPacket);
};

} // namespace AsyncMqttClientInternals