Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Client connect timeout support #89

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions arduino/libraries/WiFi/src/WiFiClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@

#include "WiFiClient.h"

extern "C" {
#include "esp_log.h"
}

WiFiClient::WiFiClient() :
WiFiClient(-1)
Expand Down Expand Up @@ -64,15 +67,59 @@ int WiFiClient::connect(IPAddress ip, uint16_t port)
addr.sin_addr.s_addr = (uint32_t)ip;
addr.sin_port = htons(port);

if (_connTimeout == 0) {
if (lwip_connect_r(_socket, (struct sockaddr*)&addr, sizeof(addr)) < 0) {
lwip_close_r(_socket);
_socket = -1;
return 0;
}
}

int nonBlocking = 1;
lwip_ioctl_r(_socket, FIONBIO, &nonBlocking);

if (_connTimeout > 0) {
int res = lwip_connect_r(_socket, (struct sockaddr*)&addr, sizeof(addr));
if (res < 0 && errno != EINPROGRESS) {
ESP_LOGW("WiFiClient", "connect on socket %d, errno: %d, \"%s\"", _socket, errno, strerror(errno));
lwip_close_r(_socket);
_socket = -1;
return 0;
}

struct timeval tv;
tv.tv_sec = _connTimeout / 1000;
tv.tv_usec = (_connTimeout % 1000) * 1000;

fd_set fdset;
FD_ZERO(&fdset);
FD_SET(_socket, &fdset);

res = select(_socket + 1, nullptr, &fdset, nullptr, &tv);
if (res < 0) {
ESP_LOGW("WiFiClient", "select on socket %d, errno: %d, \"%s\"", _socket, errno, strerror(errno));
lwip_close_r(_socket);
return 0;
}
if (res == 0) {
ESP_LOGW("WiFiClient", "select returned due to timeout %d ms for socket %d", _connTimeout, _socket);
lwip_close_r(_socket);
return 0;
}
int sockerr;
socklen_t len = (socklen_t) sizeof(int);
res = lwip_getsockopt(_socket, SOL_SOCKET, SO_ERROR, &sockerr, &len);
if (res < 0) {
ESP_LOGW("WiFiClient", "getsockopt on socket %d, errno: %d, \"%s\"", _socket, errno, strerror(errno));
lwip_close_r(_socket);
return 0;
}
if (sockerr != 0) {
ESP_LOGW("WiFiClient", "socket error on socket %d, errno: %d, \"%s\"", _socket, sockerr, strerror(sockerr));
lwip_close_r(_socket);
return 0;
}
}
return 1;
}

Expand Down
3 changes: 3 additions & 0 deletions arduino/libraries/WiFi/src/WiFiClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ class WiFiClient : public Client {
virtual /*IPAddress*/uint32_t remoteIP();
virtual uint16_t remotePort();

void setConnectionTimeout(uint16_t timeout) {_connTimeout = timeout;}

// using Print::write;

protected:
Expand All @@ -59,6 +61,7 @@ class WiFiClient : public Client {

private:
int _socket;
uint16_t _connTimeout = 0;
};

#endif // WIFICLIENT_H
85 changes: 84 additions & 1 deletion arduino/libraries/WiFi/src/WiFiSSLClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@

#include "WiFiSSLClient.h"

extern "C" {
#include "esp_log.h"
}

class __Guard {
public:
__Guard(SemaphoreHandle_t handle) {
Expand Down Expand Up @@ -50,6 +54,8 @@ WiFiSSLClient::WiFiSSLClient() :
_mbedMutex = xSemaphoreCreateRecursiveMutex();
}

static int net_connect( mbedtls_net_context *ctx, const char *host, const char *port, int proto, uint16_t timeout);

int WiFiSSLClient::connect(const char* host, uint16_t port, bool sni)
{
synchronized {
Expand Down Expand Up @@ -113,7 +119,8 @@ int WiFiSSLClient::connect(const char* host, uint16_t port, bool sni)
char portStr[6];
itoa(port, portStr, 10);

if (mbedtls_net_connect(&_netContext, host, portStr, MBEDTLS_NET_PROTO_TCP) != 0) {
if (_connTimeout ? net_connect(&_netContext, host, portStr, MBEDTLS_NET_PROTO_TCP, _connTimeout)
: mbedtls_net_connect(&_netContext, host, portStr, MBEDTLS_NET_PROTO_TCP)) {
stop();
return 0;
}
Expand Down Expand Up @@ -293,3 +300,79 @@ uint16_t WiFiSSLClient::remotePort()

return ntohs(((struct sockaddr_in *)&addr)->sin_port);
}


/*
* based on mbedtls_net_connect, but with timeout support
*/
int net_connect(mbedtls_net_context *ctx, const char *host, const char *port, int proto, uint16_t timeout) {
int ret;
struct addrinfo hints, *addr_list, *cur;

/* Do name resolution with both IPv6 and IPv4 */
memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = proto == MBEDTLS_NET_PROTO_UDP ? SOCK_DGRAM : SOCK_STREAM;
hints.ai_protocol =
proto == MBEDTLS_NET_PROTO_UDP ? IPPROTO_UDP : IPPROTO_TCP;

if ( getaddrinfo( host, port, &hints, &addr_list ) != 0) {
return ( MBEDTLS_ERR_NET_UNKNOWN_HOST);
}

/* Try the sockaddrs until a connection succeeds */
ret = MBEDTLS_ERR_NET_UNKNOWN_HOST;
for (cur = addr_list; cur != NULL; cur = cur->ai_next) {
int fd = socket(cur->ai_family, cur->ai_socktype, cur->ai_protocol);

if (fd < 0) {
ret = MBEDTLS_ERR_NET_SOCKET_FAILED;
continue;
}

mbedtls_net_context tmpCtx;
tmpCtx.fd = fd;
mbedtls_net_set_nonblock(&tmpCtx);

int res = connect(fd, cur->ai_addr, cur->ai_addrlen);
if (res < 0 && errno != EINPROGRESS) {
ESP_LOGW("WiFiSSLClient", "connect on fd %d, errno: %d, \"%s\"", fd, errno, strerror(errno));
} else {
struct timeval tv;
tv.tv_sec = timeout / 1000;
tv.tv_usec = (timeout % 1000) * 1000;

fd_set fdset;
FD_ZERO(&fdset);
FD_SET(fd, &fdset);

res = select(fd + 1, nullptr, &fdset, nullptr, &tv);
if (res < 0) {
ESP_LOGW("WiFiSSLClient", "select on fd %d, errno: %d, \"%s\"", fd, errno, strerror(errno));
} else if (res == 0) {
ESP_LOGW("WiFiSSLClient", "select returned due to timeout %d ms for fd %d", timeout, fd);
} else {
int sockerr;
socklen_t len = (socklen_t) sizeof(int);
res = getsockopt(fd, SOL_SOCKET, SO_ERROR, &sockerr, &len);
if (res < 0) {
ESP_LOGW("WiFiSSLClient", "getsockopt on fd %d, errno: %d, \"%s\"", fd, errno, strerror(errno));
} else if (sockerr != 0) {
ESP_LOGW("WiFiSSLClient", "socket error on fd %d, errno: %d, \"%s\"", fd, sockerr, strerror(sockerr));
} else {
ctx->fd = fd; // connected!
ret = 0;
mbedtls_net_set_block(ctx); // back to blocking for SSL handshake
break;
}
}
}
close(fd);
ret = MBEDTLS_ERR_NET_CONNECT_FAILED;
}

freeaddrinfo(addr_list);

return (ret);
}

3 changes: 3 additions & 0 deletions arduino/libraries/WiFi/src/WiFiSSLClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ class WiFiSSLClient /*: public Client*/ {
virtual /*IPAddress*/uint32_t remoteIP();
virtual uint16_t remotePort();

void setConnectionTimeout(uint16_t timeout) {_connTimeout = timeout;}

private:
int connect(const char* host, uint16_t port, bool sni);

Expand All @@ -69,6 +71,7 @@ class WiFiSSLClient /*: public Client*/ {
mbedtls_x509_crt _caCrt;
bool _connected;
int _peek;
uint16_t _connTimeout = 0;

SemaphoreHandle_t _mbedMutex;
};
Expand Down
10 changes: 10 additions & 0 deletions main/CommandHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,7 @@ int startClientTcp(const uint8_t command[], uint8_t response[])
uint16_t port;
uint8_t socket;
uint8_t type;
uint16_t timeout = 0;

memset(host, 0x00, sizeof(host));

Expand All @@ -611,11 +612,16 @@ int startClientTcp(const uint8_t command[], uint8_t response[])
port = ntohs(port);
socket = command[13 + command[3]];
type = command[15 + command[3]];
if (command[2] == 6) { // optional sixth parameter
timeout = (uint16_t) command[17 + command[3]] << 8 | command[18 + command[3]];
}
}

if (type == 0x00) {
int result;

tcpClients[socket].setConnectionTimeout(timeout);

if (host[0] != '\0') {
result = tcpClients[socket].connect(host, port);
} else {
Expand Down Expand Up @@ -660,6 +666,8 @@ int startClientTcp(const uint8_t command[], uint8_t response[])
} else if (type == 0x02) {
int result;

tlsClients[socket].setConnectionTimeout(timeout);

if (host[0] != '\0') {
result = tlsClients[socket].connect(host, port);
} else {
Expand All @@ -684,6 +692,8 @@ int startClientTcp(const uint8_t command[], uint8_t response[])

configureECCx08();

static_cast<WiFiClient*>(bearsslClient.getClient())->setConnectionTimeout(timeout);

if (host[0] != '\0') {
result = bearsslClient.connect(host, port);
} else {
Expand Down