Skip to content

Add encrypted TLS connection support using certificates or PSK #266

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
227 changes: 211 additions & 16 deletions src/AsyncMqttClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,14 @@ AsyncMqttClient::AsyncMqttClient()
, _willQos(0)
, _willRetain(false)
#if ASYNC_TCP_SSL_ENABLED
#ifdef ESP8266
, _secureServerFingerprints()
#endif
#endif
, _wsfilter(nullptr)
, _wsEnabled(false)
, _wsUri("/")
, _wsHost(nullptr)
, _onConnectUserCallbacks()
, _onDisconnectUserCallbacks()
, _onSubscribeUserCallbacks()
Expand Down Expand Up @@ -130,13 +136,40 @@ AsyncMqttClient& AsyncMqttClient::setSecure(bool secure) {
return *this;
}

#ifdef ESP8266
AsyncMqttClient& AsyncMqttClient::addServerFingerprint(const uint8_t* fingerprint) {
std::array<uint8_t, SHA1_SIZE> newFingerprint;
memcpy(newFingerprint.data(), fingerprint, SHA1_SIZE);
_secureServerFingerprints.push_back(newFingerprint);
return *this;
}
#endif
#ifdef ESP32
AsyncMqttClient& AsyncMqttClient::setRootCa(const char* rootca, const size_t len) {
_client.setRootCa(rootca, len);
return *this;
}
AsyncMqttClient& AsyncMqttClient::setClientCert(const char* cli_cert, const size_t cli_cert_len, const char* cli_key, const size_t cli_key_len) {
_client.setClientCert(cli_cert, cli_cert_len);
_client.setClientKey(cli_key, cli_key_len);
return *this;
}
AsyncMqttClient& AsyncMqttClient::setPsk(const char* psk_ident, const char* psk) {
_client.setPsk(psk_ident, psk);
return *this;
}
#endif
#endif

AsyncMqttClient& AsyncMqttClient::setWsEnabled(bool wsenabled) {
_wsEnabled = wsenabled;
return *this;
}

AsyncMqttClient& AsyncMqttClient::setWsUri(const char * uri) {
_wsUri = uri;
return *this;
}

AsyncMqttClient& AsyncMqttClient::onConnect(AsyncMqttClientInternals::OnConnectUserCallback callback) {
_onConnectUserCallbacks.push_back(callback);
Expand Down Expand Up @@ -187,10 +220,11 @@ void AsyncMqttClient::_clear() {
void AsyncMqttClient::_onConnect() {
log_i("TCP conn, MQTT CONNECT");
#if ASYNC_TCP_SSL_ENABLED
#ifdef ESP8266
if (_secure && _secureServerFingerprints.size() > 0) {
bool sslFoundFingerprint = false;
SSL* clientSsl = _client.getSSL();

bool sslFoundFingerprint = false;
for (std::array<uint8_t, SHA1_SIZE> fingerprint : _secureServerFingerprints) {
if (ssl_match_fingerprint(clientSsl, fingerprint.data()) == SSL_OK) {
sslFoundFingerprint = true;
Expand All @@ -205,6 +239,29 @@ void AsyncMqttClient::_onConnect() {
}
}
#endif
#endif

if (_wsEnabled) {
static const char * ws_mqtt_protos[1] = { "mqtt" };
const char * tmpHostPtr;

if (_wsHost != NULL) delete[] _wsHost;
_wsHost = NULL;
if (_useIp) {
auto sIP = _ip.toString();
_wsHost = new char[sIP.length() + 1];
strcpy(_wsHost, sIP.c_str());
tmpHostPtr = _wsHost;
} else {
tmpHostPtr = _host;
}

_wsfilter = new AsyncMqttClientInternals::WebsocketFilter(
tmpHostPtr,
_wsUri,
1, ws_mqtt_protos);
}

AsyncMqttClientInternals::OutPacket* msg =
new AsyncMqttClientInternals::ConnectOutPacket(_cleanSession,
_username,
Expand All @@ -227,6 +284,13 @@ void AsyncMqttClient::_onDisconnect() {
_clear();

for (auto callback : _onDisconnectUserCallbacks) callback(_disconnectReason);

if (_wsfilter != NULL) {
delete _wsfilter;
_wsfilter = NULL;
}
if (_wsHost != NULL) delete[] _wsHost;
_wsHost = NULL;
}

/*
Expand All @@ -247,6 +311,69 @@ void AsyncMqttClient::_onAck(size_t len) {

void AsyncMqttClient::_onData(char* data, size_t len) {
log_i("data rcv (%u)", len);

if (_wsfilter == NULL) {
_onMQTTData(data, len);
return;
}

// Space occupied by data already fetched into websocket will be reused to
// hold MQTT payload data.
uint8_t * mqttdata = (uint8_t *)data;
uint8_t * outbufptr = mqttdata;
size_t outputavail = 0;
size_t mqttavail = 0;
do {
size_t proc_len = _wsfilter->addDataFromStream(len, (uint8_t *)data);
log_d("pushed %u of %u bytes into websocket filter", proc_len, len);
data += proc_len;
len -= proc_len;
outputavail += proc_len;

if (proc_len <= 0) {
log_w("websocket filter rx full?");
} else {
while (_wsfilter->isPacketDataAvailable()) {
size_t mqttbytes = 0;
uint64_t pktoffset = 0;
bool pktisbinary = false;
bool pktlastfetch = false;
_wsfilter->fetchPacketData(outputavail, outbufptr, mqttbytes,
pktoffset, pktisbinary, pktlastfetch);
log_d("fetched %u/%u mqtt bytes pktoff=%u binary=%u last=%u", mqttbytes, proc_len,
(uint32_t)pktoffset, pktisbinary ? 1 : 0, pktlastfetch ? 1 : 0);
mqttavail += mqttbytes;
outbufptr += mqttbytes;
outputavail -= mqttbytes;
}
}

// Check for any handshake or protocol errors
if (_wsfilter->getError() != AsyncMqttClientInternals::NO_ERROR) {
switch (_wsfilter->getError()) {
case AsyncMqttClientInternals::HANDSHAKE_FAILED:
log_w("Websocket handshake failure");
break;
case AsyncMqttClientInternals::PROTOCOL_ERROR:
log_w("Websocket protocol violation");
break;
case AsyncMqttClientInternals::REMOTE_CTRL_CLOSE:
log_w("Remote side closed websocket: code %hu", _wsfilter->getCloseCode());
if (_wsfilter->getCloseReason() != NULL) {
log_w("Reason text %s", _wsfilter->getCloseReason());
}
break;
}
disconnect(true);
break;
} else if (len <= 0) {
if (mqttavail > 0) _onMQTTData((char *)mqttdata, mqttavail);
}
} while (len > 0);
}

void AsyncMqttClient::_onMQTTData(char* data, size_t len) {
log_i("data rcv (%u)", len);
size_t currentBytePosition = 0;
char currentByte;
_lastServerActivity = millis();
Expand Down Expand Up @@ -395,23 +522,44 @@ void AsyncMqttClient::_handleQueue() {
SEMAPHORE_TAKE();
// On ESP32, onDisconnect is called within the close()-call. So we need to make sure we don't lock
bool disconnect = false;
bool wserror = false;
bool canflushqueue = true;

if (_wsfilter != NULL) {
canflushqueue = _flushWebsocketTX(wserror);
if (!canflushqueue) {
if (wserror) disconnect = true;
}
}

while (_head && _client.space() > 10) { // safe but arbitrary value, send at least 10 bytes
while (canflushqueue && _head && _client.space() > 10) { // safe but arbitrary value, send at least 10 bytes
// 1. try to send
if (_head->size() > _sent) {
// On SSL the TCP library returns the total amount of bytes, not just the unencrypted payload length.
// So we calculate the amount to be written ourselves.
size_t willSend = std::min(_head->size() - _sent, _client.space());
size_t realSent = _client.add(reinterpret_cast<const char*>(_head->data(_sent)), willSend, ASYNC_WRITE_FLAG_COPY); // flag is set by LWIP anyway, added for clarity
_sent += willSend;
(void)realSent;
_client.send();
_lastClientActivity = millis();
#if ASYNC_TCP_SSL_ENABLED
log_i("snd #%u: (tls: %u) %u/%u", _head->packetType(), realSent, _sent, _head->size());
#else
log_i("snd #%u: %u/%u", _head->packetType(), _sent, _head->size());
#endif
if (_wsfilter != NULL) {
if (_sent == 0 && !_wsfilter->isOpenPacket()) _wsfilter->startPacket(true);
size_t realSent = _wsfilter->addPacketData(_head->size() - _sent, _head->data(_sent), true);
canflushqueue = _flushWebsocketTX(wserror);
if (!canflushqueue) {
// At this point, any failure to flush is a websocket error
if (wserror) disconnect = true;
}
_sent += realSent;
_lastClientActivity = millis();
} else {
// On SSL the TCP library returns the total amount of bytes, not just the unencrypted payload length.
// So we calculate the amount to be written ourselves.
size_t willSend = std::min(_head->size() - _sent, _client.space());
size_t realSent = _client.add(reinterpret_cast<const char*>(_head->data(_sent)), willSend, ASYNC_WRITE_FLAG_COPY); // flag is set by LWIP anyway, added for clarity
_sent += willSend;
(void)realSent;
_client.send();
_lastClientActivity = millis();
#if ASYNC_TCP_SSL_ENABLED
log_i("snd #%u: (tls: %u) %u/%u", _head->packetType(), realSent, _sent, _head->size());
#else
log_i("snd #%u: %u/%u", _head->packetType(), _sent, _head->size());
#endif
}
if (_head->packetType() == AsyncMqttClientInternals::PacketType.DISCONNECT) {
disconnect = true;
}
Expand All @@ -434,11 +582,58 @@ void AsyncMqttClient::_handleQueue() {

SEMAPHORE_GIVE();
if (disconnect) {
log_i("snd DISCONN, disconnecting");
if (wserror)
log_i("websocket error, disconnecting");
else
log_i("snd DISCONN, disconnecting");
_client.close();
}
}

bool AsyncMqttClient::_flushWebsocketTX(bool & disconnect)
{
// Flush as much of the wsfilter TX buffer as possible, required for handshake
while (_client.space() > 10 && _wsfilter->pendingStreamOutputLen() > 0) {
log_d("client.space = %u pending ws output = %u", _client.space(), _wsfilter->pendingStreamOutputLen());
uint8_t * wsbuffer = NULL;
size_t wsbytes = 0;
_wsfilter->fetchDataPtrForStream(wsbuffer, wsbytes);
wsbytes = std::min(wsbytes, _client.space());
size_t wsbytes_sent = _client.add((const char *)wsbuffer, wsbytes, ASYNC_WRITE_FLAG_COPY);
if (wsbytes_sent > 0) {
_client.send();
_wsfilter->discardFetchedData(wsbytes_sent);
} else {
break;
}
}

switch (_wsfilter->getState()) {
case AsyncMqttClientInternals::HANDSHAKE_TX:
case AsyncMqttClientInternals::HANDSHAKE_RX:
// Still negociating websocket handshake
log_v("still negotiating handshake");
return false;
case AsyncMqttClientInternals::HANDSHAKE_ERR:
log_w("Websocket handshake failure");
disconnect = true;
return false;
case AsyncMqttClientInternals::WEBSOCKET_ERROR:
log_w("Websocket protocol violation");
disconnect = true;
return false;
case AsyncMqttClientInternals::WEBSOCKET_CLOSED:
log_w("Remote side closed websocket: code %hu", _wsfilter->getCloseCode());
if (_wsfilter->getCloseReason() != NULL) {
log_w("Reason text %s", _wsfilter->getCloseReason());
}
disconnect = true;
return false;
}

return true;
}

void AsyncMqttClient::_clearQueue(bool keepSessionData) {
SEMAPHORE_TAKE();
AsyncMqttClientInternals::OutPacket* packet = _head;
Expand Down
23 changes: 23 additions & 0 deletions src/AsyncMqttClient.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
#endif

#if ASYNC_TCP_SSL_ENABLED
#ifdef ESP8266
#include <tcp_axtls.h>
#endif
#define SHA1_SIZE 20
#endif

Expand Down Expand Up @@ -50,6 +52,8 @@
#include "AsyncMqttClient/Packets/Out/Unsubscribe.hpp"
#include "AsyncMqttClient/Packets/Out/Publish.hpp"

#include "AsyncMqttClient/WebsocketFilter.hpp"

class AsyncMqttClient {
public:
AsyncMqttClient();
Expand All @@ -65,8 +69,17 @@ class AsyncMqttClient {
AsyncMqttClient& setServer(const char* host, uint16_t port);
#if ASYNC_TCP_SSL_ENABLED
AsyncMqttClient& setSecure(bool secure);
#ifdef ESP8266
AsyncMqttClient& addServerFingerprint(const uint8_t* fingerprint);
#endif
#ifdef ESP32
AsyncMqttClient& setRootCa(const char* rootca, const size_t len);
AsyncMqttClient& setClientCert(const char* cli_cert, const size_t cli_cert_len, const char* cli_key, const size_t cli_key_len);
AsyncMqttClient& setPsk(const char* psk_ident, const char* psk);
#endif
#endif
AsyncMqttClient& setWsEnabled(bool wsenabled);
AsyncMqttClient& setWsUri(const char * uri);

AsyncMqttClient& onConnect(AsyncMqttClientInternals::OnConnectUserCallback callback);
AsyncMqttClient& onDisconnect(AsyncMqttClientInternals::OnDisconnectUserCallback callback);
Expand Down Expand Up @@ -121,9 +134,16 @@ class AsyncMqttClient {
bool _willRetain;

#if ASYNC_TCP_SSL_ENABLED
#ifdef ESP8266
std::vector<std::array<uint8_t, SHA1_SIZE>> _secureServerFingerprints;
#endif
#endif

AsyncMqttClientInternals::WebsocketFilter * _wsfilter;
bool _wsEnabled;
const char* _wsUri;
char* _wsHost;

std::vector<AsyncMqttClientInternals::OnConnectUserCallback> _onConnectUserCallbacks;
std::vector<AsyncMqttClientInternals::OnDisconnectUserCallback> _onDisconnectUserCallbacks;
std::vector<AsyncMqttClientInternals::OnSubscribeUserCallback> _onSubscribeUserCallbacks;
Expand Down Expand Up @@ -156,12 +176,15 @@ class AsyncMqttClient {
void _onData(char* data, size_t len);
void _onPoll();

void _onMQTTData(char* data, size_t len);

// QUEUE
void _insert(AsyncMqttClientInternals::OutPacket* packet); // for PUBREL
void _addFront(AsyncMqttClientInternals::OutPacket* packet); // for CONNECT
void _addBack(AsyncMqttClientInternals::OutPacket* packet); // all the rest
void _handleQueue();
void _clearQueue(bool keepSessionData);
bool _flushWebsocketTX(bool & disconnect);

// MQTT
void _onPingResp();
Expand Down
Loading
Loading