Skip to content

Commit

Permalink
refactor socket for array
Browse files Browse the repository at this point in the history
  • Loading branch information
mpoegel committed Sep 1, 2024
1 parent fa5444f commit abc578d
Show file tree
Hide file tree
Showing 10 changed files with 93 additions and 80 deletions.
13 changes: 7 additions & 6 deletions basis/fwoop_array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ Array::~Array() { delete[] d_data; }
Array *Array::operator=(const Array &rhs)
{
delete[] d_data;
d_data = new uint8_t[d_actualSize];
d_data = new uint8_t[rhs.d_actualSize];
memcpy(d_data, rhs.d_data, rhs.d_size);
d_size = rhs.d_size;
return this;
}

Expand Down Expand Up @@ -54,14 +55,14 @@ void Array::enlarge(uint32_t newSize)
// new size is smaller
}

void Array::append(const char *str, uint32_t len)
void Array::append(const uint8_t *buf, uint32_t bufLen)
{
uint32_t remaining = d_actualSize - d_size;
if (len > remaining) {
enlarge(d_size + len);
if (bufLen > remaining) {
enlarge(d_size + bufLen);
}
memcpy(d_data + d_size, str, len);
d_size += len;
memcpy(d_data + d_size, buf, bufLen);
d_size += bufLen;
}

void Array::append(uint8_t d) { d_data[d_size++] = d; }
Expand Down
12 changes: 11 additions & 1 deletion basis/fwoop_array.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,16 @@ class Array {
uint8_t &operator[](uint32_t i);
const uint8_t &operator[](uint32_t i) const;
uint8_t *operator*();
const uint8_t *operator*() const;

void extend(const Array &arr);
void shrink(uint32_t newSize);
uint32_t size() const;
uint32_t maxSize() const;
void enlarge(uint32_t newSize);
void append(const std::string &str);
void append(const char *str, uint32_t len);
void append(const uint8_t *buf, uint32_t bufLen);
void append(uint8_t d);
Array subArray(uint32_t start, uint32_t end) const;
std::string toString() const;
Expand All @@ -40,10 +43,17 @@ class Array {
inline uint8_t &Array::operator[](uint32_t i) { return d_data[i]; }
inline const uint8_t &Array::operator[](uint32_t i) const { return d_data[i]; }
inline uint8_t *Array::operator*() { return d_data; }
inline const uint8_t *Array::operator*() const { return d_data; }
inline uint32_t Array::size() const { return d_size; }
inline uint32_t Array::maxSize() const { return d_actualSize; }
inline void Array::shrink(uint32_t newSize) { d_size = std::min(d_size, newSize); }
inline void Array::append(const char *str, uint32_t len) { append((uint8_t *)str, len); }
inline void Array::append(const std::string &str) { append(str.data(), str.length()); }
inline std::string Array::toString() const { return std::string((char *)d_data, d_size); }
inline void Array::clear() { memset(d_data, 0, d_actualSize); }
inline void Array::clear()
{
memset(d_data, 0, d_actualSize);
d_size = 0;
}

} // namespace fwoop
16 changes: 7 additions & 9 deletions basis/fwoop_socketio.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "fwoop_array.h"
#include <fwoop_socketio.h>
#include <system_error>

Expand All @@ -9,11 +10,8 @@ Socket::~Socket() {}

Socket::Socket(const Socket &rhs) : d_fd(rhs.d_fd) {}

std::error_code Socket::read(uint8_t *buffer, uint32_t bufferSize, uint32_t &bytesRead)
std::error_code Socket::read(Array &arr)
{
memset(buffer, 0, bufferSize);
bytesRead = 0;

struct pollfd pfd[1];
pfd[0].fd = d_fd;
pfd[0].events = POLLIN;
Expand All @@ -28,26 +26,26 @@ std::error_code Socket::read(uint8_t *buffer, uint32_t bufferSize, uint32_t &byt
}

if (pfd[0].revents & POLLIN) {
rc = ::read(d_fd, buffer, bufferSize);
rc = ::read(d_fd, *arr, arr.size());
if (0 == rc) {
// peer closed the connection
return std::error_code(errno, std::system_category());
} else if (rc < 0) {
// read error
return std::error_code(errno, std::system_category());
} else {
bytesRead = rc;
arr.shrink(rc);
}
}
return std::error_code();
}

std::error_code Socket::write(const uint8_t *out, uint32_t outLen, uint32_t &bytesWritten)
std::error_code Socket::write(const Array &arr, uint32_t &bytesWritten)
{
int rc = 0;
bytesWritten = 0;
while (bytesWritten < outLen) {
rc = ::write(d_fd, out + bytesWritten, outLen - bytesWritten);
while (bytesWritten < arr.size()) {
rc = ::write(d_fd, *arr + bytesWritten, arr.size() - bytesWritten);
if (rc < 0) {
// write failed
return std::error_code(errno, std::system_category());
Expand Down
9 changes: 5 additions & 4 deletions basis/fwoop_socketio.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include <fwoop_array.h>
#include <fwoop_log.h>

#include <cstdint>
Expand All @@ -15,7 +16,7 @@ namespace fwoop {
class Reader {
public:
~Reader() {}
virtual std::error_code read(uint8_t *buffer, uint32_t bufferSize, uint32_t &bytesRead) = 0;
virtual std::error_code read(Array &arr) = 0;
virtual void close() = 0;
};

Expand All @@ -24,7 +25,7 @@ typedef std::shared_ptr<Reader> ReaderPtr_t;
class Writer {
public:
~Writer() {}
virtual std::error_code write(const uint8_t *buffer, uint32_t bufferSize, uint32_t &bytesWritten) = 0;
virtual std::error_code write(const Array &arr, uint32_t &bytesWritten) = 0;
virtual void close() = 0;
};

Expand All @@ -47,8 +48,8 @@ class Socket : public SocketBase {
Socket(const Socket &rhs);
Socket operator=(const Socket &rhs);

std::error_code read(uint8_t *buffer, uint32_t bufferSize, uint32_t &bytesRead) override;
std::error_code write(const uint8_t *buffer, uint32_t bufferSize, uint32_t &bytesWritten) override;
std::error_code read(Array &arr) override;
std::error_code write(const Array &arr, uint32_t &bytesWritten) override;
void close() override;
};

Expand Down
65 changes: 29 additions & 36 deletions crypto/fwoop_securesocket.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "fwoop_array.h"
#include <arpa/inet.h>
#include <botan/auto_rng.h>
#include <botan/tls_client.h>
Expand Down Expand Up @@ -39,31 +40,32 @@ SecureSocket::~SecureSocket() {}

std::error_code SecureSocket::handshake()
{
uint8_t buf[16384];
Array arr(16384);
uint32_t bytesRead = 0;
std::error_code ec;
while (!d_client->is_active() && !ec) {
ec = read(buf, sizeof(buf), bytesRead);
// TODO capture lost data
ec = read(arr);
}
return ec;
}

std::error_code SecureSocket::read(uint8_t *buffer, uint32_t bufferSize, uint32_t &bytesRead)
std::error_code SecureSocket::read(Array &arr)
{
bytesRead = 0;
uint32_t bytesRead = 0;

// return early if there's already pending data
d_callbacks->readWaiting(buffer, bufferSize, bytesRead);
if (bytesRead > 0) {
d_callbacks->readWaiting(arr);
if (arr.size() > 0) {
return std::error_code();
}

size_t numMoreBytes = bufferSize;
uint32_t numMoreBytes = arr.maxSize();
struct pollfd pfd[1];
pfd[0].fd = d_fd;
pfd[0].events = POLLIN;

while (numMoreBytes > 0 && bytesRead < bufferSize) {
while (numMoreBytes > 0) {
int rc = poll(pfd, 1, 1000);
if (rc == 0) {
// poll timed out
Expand All @@ -74,38 +76,32 @@ std::error_code SecureSocket::read(uint8_t *buffer, uint32_t bufferSize, uint32_
}

if (pfd[0].revents & POLLIN) {
rc = ::read(d_fd, buffer, numMoreBytes);
rc = ::read(d_fd, *arr, numMoreBytes);
if (0 == rc) {
// peer closed the connection
return std::error_code(errno, std::system_category());
} else if (rc < 0) {
// read error
return std::error_code(errno, std::system_category());
} else {
bytesRead = rc;
}
}

Log::Debug("tls read ", bytesRead, " bytes");
bytesRead += rc;
Log::Debug("tls read ", rc, " bytes");
// decrypt the incoming data
numMoreBytes = d_client->received_data(buffer, bytesRead);
numMoreBytes = d_client->received_data(*arr, rc);
Log::Debug("need ", numMoreBytes, " more bytes to complete TLS record");
}

if (bytesRead == bufferSize) {
Log::Error("read buffer full!");
// TODO return error
numMoreBytes = std::min(arr.maxSize(), numMoreBytes);
}

// reset the buffer
memset(buffer, 0, bufferSize);
bytesRead = 0;
arr.clear();
// copy the decrypted data into the buffer
d_callbacks->readWaiting(buffer, bufferSize, bytesRead);
d_callbacks->readWaiting(arr);
return std::error_code();
}

std::error_code SecureSocket::write(const uint8_t *buffer, uint32_t bufferSize, uint32_t &bytesWritten)
std::error_code SecureSocket::write(const Array &arr, uint32_t &bytesWritten)
{
if (d_client->is_closed_for_writing()) {
Log::Error("cannot write at this time");
Expand All @@ -114,8 +110,8 @@ std::error_code SecureSocket::write(const uint8_t *buffer, uint32_t bufferSize,
if (!d_client->is_active()) {
handshake();
}
d_client->send(buffer, bufferSize);
bytesWritten = bufferSize;
d_client->send(*arr, arr.size());
bytesWritten = arr.size();
return std::error_code();
}

Expand Down Expand Up @@ -181,18 +177,17 @@ SocketBasePtr_t SecureSocketFactory::connect()
return std::make_shared<SecureSocket>(fd, config);
}

SecureCallbacks::SecureCallbacks(int fd) : d_fd(fd), d_peer_closed(false), d_readWaiting(0)
{
memset(d_readBuffer, 0, sizeof(d_readBuffer));
}
SecureCallbacks::SecureCallbacks(int fd) : d_fd(fd), d_peer_closed(false), d_readBuffer(16384), d_readWaiting(0) {}

void SecureCallbacks::readWaiting(uint8_t *buffer, uint32_t bufferSize, uint32_t &bytesRead)
void SecureCallbacks::readWaiting(Array &arr)
{
bytesRead = 0;
if (d_readWaiting > 0) {
bytesRead = std::min(bufferSize, d_readWaiting);
memcpy(buffer, d_readBuffer, bytesRead);
d_readWaiting -= bytesRead;
Log::Debug("waiting ", d_readWaiting);
// TODO remove unnecessary allocations
arr = d_readBuffer;
d_readWaiting = 0;
d_readBuffer.clear();
d_readBuffer.shrink(0);
}
}

Expand All @@ -218,9 +213,7 @@ void SecureCallbacks::tls_record_received(uint64_t seqNum, std::span<const uint8
{
// TODO we could recv more data than the read buffer size?
Log::Info("got app data [seqNum ", seqNum, "]: ", data.size(), " bytes");
uint32_t n =
sizeof(d_readBuffer) > d_readWaiting + data.size() ? data.size() : sizeof(d_readBuffer) - d_readWaiting;
memcpy(d_readBuffer + d_readWaiting, data.data(), n);
d_readBuffer.append(data.data(), data.size_bytes());
d_readWaiting += data.size();
}

Expand Down
9 changes: 5 additions & 4 deletions crypto/fwoop_securesocket.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include <fwoop_array.h>
#include <fwoop_socketio.h>
#include <fwoop_tlscredentials.h>

Expand Down Expand Up @@ -35,8 +36,8 @@ class SecureSocket : public SocketBase {
std::error_code handshake();

// from SocketBase
std::error_code read(uint8_t *buffer, uint32_t bufferSize, uint32_t &bytesRead) override;
std::error_code write(const uint8_t *buffer, uint32_t bufferSize, uint32_t &bytesWritten) override;
std::error_code read(Array &arr) override;
std::error_code write(const Array &arr, uint32_t &bytesWritten) override;
void close() override;
};

Expand Down Expand Up @@ -71,14 +72,14 @@ class SecureCallbacks : public Botan::TLS::Callbacks {
private:
int d_fd;
bool d_peer_closed;
uint8_t d_readBuffer[16384];
Array d_readBuffer;
uint32_t d_readWaiting;

public:
SecureCallbacks(int fd);
~SecureCallbacks() {}

void readWaiting(uint8_t *buffer, uint32_t bufferSize, uint32_t &bytesRead);
void readWaiting(Array &arr);

// from Botan::TLS::Callbacks
void tls_emit_data(std::span<const uint8_t> data) override;
Expand Down
17 changes: 10 additions & 7 deletions examples/tlsclient.m.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "fwoop_array.h"
#include <cstdint>
#include <cstring>
#include <iostream>
Expand All @@ -21,28 +22,30 @@ int main(int argc, char *argv[])
::sleep(1);
if (sock) {
fwoop::Log::Info("connected");
uint8_t buf[16384];
uint32_t bytesRead = 0;
uint8_t msg[] = "GET /get HTTP/1.1\r\n"
"Host: httpbin.org\r\n"
"User-Agent: fwoop/1\r\n"
"Accept: */*\r\n\r\n";
uint32_t bytesWritten = 0;
auto ec = sock->write(msg, sizeof(msg) - 1, bytesWritten);
fwoop::Array arr(sizeof(msg));
arr.append(msg, sizeof(msg) - 1);
auto ec = sock->write(arr, bytesWritten);
if (ec) {
fwoop::Log::Error("secure write failed: ", ec.message());
return 1;
}
const uint32_t bufSize = 16384;
fwoop::Array buf(bufSize);
do {
bytesRead = 0;
memset(buf, 0, sizeof(buf));
ec = sock->read(buf, sizeof(buf), bytesRead);
buf.clear();
ec = sock->read(buf);
if (ec) {
fwoop::Log::Error("secure read failed: ", ec.message());
return 1;
}
fwoop::Log::Info("readBytes=", bytesRead, " DATA: ", std::string(buf, buf + bytesRead));
} while (!ec && bytesRead > 0);
fwoop::Log::Info("readBytes=", bytesRead, " DATA: ", buf.toString());
} while (!ec && buf.size() > 0);
}

fwoop::Log::Info("done");
Expand Down
Loading

0 comments on commit abc578d

Please sign in to comment.