From abc578d6e9226bf01aee2317ab7b4293488e62ae Mon Sep 17 00:00:00 2001 From: Matt Poegel Date: Sun, 1 Sep 2024 02:07:53 +0000 Subject: [PATCH] refactor socket for array --- basis/fwoop_array.cpp | 13 ++++--- basis/fwoop_array.h | 12 +++++- basis/fwoop_socketio.cpp | 16 ++++---- basis/fwoop_socketio.h | 9 +++-- crypto/fwoop_securesocket.cpp | 65 ++++++++++++++------------------ crypto/fwoop_securesocket.h | 9 +++-- examples/tlsclient.m.cpp | 17 +++++---- http/fwoop_httpconnhandler.cpp | 11 +++--- http/fwoop_httpconnhandler.g.cpp | 11 +++--- http/fwoop_httpserverevent.cpp | 10 ++++- 10 files changed, 93 insertions(+), 80 deletions(-) diff --git a/basis/fwoop_array.cpp b/basis/fwoop_array.cpp index 1e6f96f..2cdd5a3 100644 --- a/basis/fwoop_array.cpp +++ b/basis/fwoop_array.cpp @@ -23,8 +23,9 @@ Array::~Array() { delete[] d_data; } Array *Array::operator=(const Array &rhs) { delete[] d_data; - d_data = new uint8_t[d_actualSize]; + d_data = new uint8_t[rhs.d_actualSize]; memcpy(d_data, rhs.d_data, rhs.d_size); + d_size = rhs.d_size; return this; } @@ -54,14 +55,14 @@ void Array::enlarge(uint32_t newSize) // new size is smaller } -void Array::append(const char *str, uint32_t len) +void Array::append(const uint8_t *buf, uint32_t bufLen) { uint32_t remaining = d_actualSize - d_size; - if (len > remaining) { - enlarge(d_size + len); + if (bufLen > remaining) { + enlarge(d_size + bufLen); } - memcpy(d_data + d_size, str, len); - d_size += len; + memcpy(d_data + d_size, buf, bufLen); + d_size += bufLen; } void Array::append(uint8_t d) { d_data[d_size++] = d; } diff --git a/basis/fwoop_array.h b/basis/fwoop_array.h index c39c20e..32b0f03 100644 --- a/basis/fwoop_array.h +++ b/basis/fwoop_array.h @@ -23,13 +23,16 @@ class Array { uint8_t &operator[](uint32_t i); const uint8_t &operator[](uint32_t i) const; uint8_t *operator*(); + const uint8_t *operator*() const; void extend(const Array &arr); void shrink(uint32_t newSize); uint32_t size() const; + uint32_t maxSize() const; void enlarge(uint32_t newSize); void append(const std::string &str); void append(const char *str, uint32_t len); + void append(const uint8_t *buf, uint32_t bufLen); void append(uint8_t d); Array subArray(uint32_t start, uint32_t end) const; std::string toString() const; @@ -40,10 +43,17 @@ class Array { inline uint8_t &Array::operator[](uint32_t i) { return d_data[i]; } inline const uint8_t &Array::operator[](uint32_t i) const { return d_data[i]; } inline uint8_t *Array::operator*() { return d_data; } +inline const uint8_t *Array::operator*() const { return d_data; } inline uint32_t Array::size() const { return d_size; } +inline uint32_t Array::maxSize() const { return d_actualSize; } inline void Array::shrink(uint32_t newSize) { d_size = std::min(d_size, newSize); } +inline void Array::append(const char *str, uint32_t len) { append((uint8_t *)str, len); } inline void Array::append(const std::string &str) { append(str.data(), str.length()); } inline std::string Array::toString() const { return std::string((char *)d_data, d_size); } -inline void Array::clear() { memset(d_data, 0, d_actualSize); } +inline void Array::clear() +{ + memset(d_data, 0, d_actualSize); + d_size = 0; +} } // namespace fwoop diff --git a/basis/fwoop_socketio.cpp b/basis/fwoop_socketio.cpp index 7deeceb..fc3bddd 100644 --- a/basis/fwoop_socketio.cpp +++ b/basis/fwoop_socketio.cpp @@ -1,3 +1,4 @@ +#include "fwoop_array.h" #include #include @@ -9,11 +10,8 @@ Socket::~Socket() {} Socket::Socket(const Socket &rhs) : d_fd(rhs.d_fd) {} -std::error_code Socket::read(uint8_t *buffer, uint32_t bufferSize, uint32_t &bytesRead) +std::error_code Socket::read(Array &arr) { - memset(buffer, 0, bufferSize); - bytesRead = 0; - struct pollfd pfd[1]; pfd[0].fd = d_fd; pfd[0].events = POLLIN; @@ -28,7 +26,7 @@ std::error_code Socket::read(uint8_t *buffer, uint32_t bufferSize, uint32_t &byt } if (pfd[0].revents & POLLIN) { - rc = ::read(d_fd, buffer, bufferSize); + rc = ::read(d_fd, *arr, arr.size()); if (0 == rc) { // peer closed the connection return std::error_code(errno, std::system_category()); @@ -36,18 +34,18 @@ std::error_code Socket::read(uint8_t *buffer, uint32_t bufferSize, uint32_t &byt // read error return std::error_code(errno, std::system_category()); } else { - bytesRead = rc; + arr.shrink(rc); } } return std::error_code(); } -std::error_code Socket::write(const uint8_t *out, uint32_t outLen, uint32_t &bytesWritten) +std::error_code Socket::write(const Array &arr, uint32_t &bytesWritten) { int rc = 0; bytesWritten = 0; - while (bytesWritten < outLen) { - rc = ::write(d_fd, out + bytesWritten, outLen - bytesWritten); + while (bytesWritten < arr.size()) { + rc = ::write(d_fd, *arr + bytesWritten, arr.size() - bytesWritten); if (rc < 0) { // write failed return std::error_code(errno, std::system_category()); diff --git a/basis/fwoop_socketio.h b/basis/fwoop_socketio.h index 80bc672..bab7da7 100644 --- a/basis/fwoop_socketio.h +++ b/basis/fwoop_socketio.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include @@ -15,7 +16,7 @@ namespace fwoop { class Reader { public: ~Reader() {} - virtual std::error_code read(uint8_t *buffer, uint32_t bufferSize, uint32_t &bytesRead) = 0; + virtual std::error_code read(Array &arr) = 0; virtual void close() = 0; }; @@ -24,7 +25,7 @@ typedef std::shared_ptr ReaderPtr_t; class Writer { public: ~Writer() {} - virtual std::error_code write(const uint8_t *buffer, uint32_t bufferSize, uint32_t &bytesWritten) = 0; + virtual std::error_code write(const Array &arr, uint32_t &bytesWritten) = 0; virtual void close() = 0; }; @@ -47,8 +48,8 @@ class Socket : public SocketBase { Socket(const Socket &rhs); Socket operator=(const Socket &rhs); - std::error_code read(uint8_t *buffer, uint32_t bufferSize, uint32_t &bytesRead) override; - std::error_code write(const uint8_t *buffer, uint32_t bufferSize, uint32_t &bytesWritten) override; + std::error_code read(Array &arr) override; + std::error_code write(const Array &arr, uint32_t &bytesWritten) override; void close() override; }; diff --git a/crypto/fwoop_securesocket.cpp b/crypto/fwoop_securesocket.cpp index 0668efc..27ddef6 100644 --- a/crypto/fwoop_securesocket.cpp +++ b/crypto/fwoop_securesocket.cpp @@ -1,3 +1,4 @@ +#include "fwoop_array.h" #include #include #include @@ -39,31 +40,32 @@ SecureSocket::~SecureSocket() {} std::error_code SecureSocket::handshake() { - uint8_t buf[16384]; + Array arr(16384); uint32_t bytesRead = 0; std::error_code ec; while (!d_client->is_active() && !ec) { - ec = read(buf, sizeof(buf), bytesRead); + // TODO capture lost data + ec = read(arr); } return ec; } -std::error_code SecureSocket::read(uint8_t *buffer, uint32_t bufferSize, uint32_t &bytesRead) +std::error_code SecureSocket::read(Array &arr) { - bytesRead = 0; + uint32_t bytesRead = 0; // return early if there's already pending data - d_callbacks->readWaiting(buffer, bufferSize, bytesRead); - if (bytesRead > 0) { + d_callbacks->readWaiting(arr); + if (arr.size() > 0) { return std::error_code(); } - size_t numMoreBytes = bufferSize; + uint32_t numMoreBytes = arr.maxSize(); struct pollfd pfd[1]; pfd[0].fd = d_fd; pfd[0].events = POLLIN; - while (numMoreBytes > 0 && bytesRead < bufferSize) { + while (numMoreBytes > 0) { int rc = poll(pfd, 1, 1000); if (rc == 0) { // poll timed out @@ -74,38 +76,32 @@ std::error_code SecureSocket::read(uint8_t *buffer, uint32_t bufferSize, uint32_ } if (pfd[0].revents & POLLIN) { - rc = ::read(d_fd, buffer, numMoreBytes); + rc = ::read(d_fd, *arr, numMoreBytes); if (0 == rc) { // peer closed the connection return std::error_code(errno, std::system_category()); } else if (rc < 0) { // read error return std::error_code(errno, std::system_category()); - } else { - bytesRead = rc; } } - Log::Debug("tls read ", bytesRead, " bytes"); + bytesRead += rc; + Log::Debug("tls read ", rc, " bytes"); // decrypt the incoming data - numMoreBytes = d_client->received_data(buffer, bytesRead); + numMoreBytes = d_client->received_data(*arr, rc); Log::Debug("need ", numMoreBytes, " more bytes to complete TLS record"); - } - - if (bytesRead == bufferSize) { - Log::Error("read buffer full!"); - // TODO return error + numMoreBytes = std::min(arr.maxSize(), numMoreBytes); } // reset the buffer - memset(buffer, 0, bufferSize); - bytesRead = 0; + arr.clear(); // copy the decrypted data into the buffer - d_callbacks->readWaiting(buffer, bufferSize, bytesRead); + d_callbacks->readWaiting(arr); return std::error_code(); } -std::error_code SecureSocket::write(const uint8_t *buffer, uint32_t bufferSize, uint32_t &bytesWritten) +std::error_code SecureSocket::write(const Array &arr, uint32_t &bytesWritten) { if (d_client->is_closed_for_writing()) { Log::Error("cannot write at this time"); @@ -114,8 +110,8 @@ std::error_code SecureSocket::write(const uint8_t *buffer, uint32_t bufferSize, if (!d_client->is_active()) { handshake(); } - d_client->send(buffer, bufferSize); - bytesWritten = bufferSize; + d_client->send(*arr, arr.size()); + bytesWritten = arr.size(); return std::error_code(); } @@ -181,18 +177,17 @@ SocketBasePtr_t SecureSocketFactory::connect() return std::make_shared(fd, config); } -SecureCallbacks::SecureCallbacks(int fd) : d_fd(fd), d_peer_closed(false), d_readWaiting(0) -{ - memset(d_readBuffer, 0, sizeof(d_readBuffer)); -} +SecureCallbacks::SecureCallbacks(int fd) : d_fd(fd), d_peer_closed(false), d_readBuffer(16384), d_readWaiting(0) {} -void SecureCallbacks::readWaiting(uint8_t *buffer, uint32_t bufferSize, uint32_t &bytesRead) +void SecureCallbacks::readWaiting(Array &arr) { - bytesRead = 0; if (d_readWaiting > 0) { - bytesRead = std::min(bufferSize, d_readWaiting); - memcpy(buffer, d_readBuffer, bytesRead); - d_readWaiting -= bytesRead; + Log::Debug("waiting ", d_readWaiting); + // TODO remove unnecessary allocations + arr = d_readBuffer; + d_readWaiting = 0; + d_readBuffer.clear(); + d_readBuffer.shrink(0); } } @@ -218,9 +213,7 @@ void SecureCallbacks::tls_record_received(uint64_t seqNum, std::span d_readWaiting + data.size() ? data.size() : sizeof(d_readBuffer) - d_readWaiting; - memcpy(d_readBuffer + d_readWaiting, data.data(), n); + d_readBuffer.append(data.data(), data.size_bytes()); d_readWaiting += data.size(); } diff --git a/crypto/fwoop_securesocket.h b/crypto/fwoop_securesocket.h index 9be2d4d..84ecbe3 100644 --- a/crypto/fwoop_securesocket.h +++ b/crypto/fwoop_securesocket.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include @@ -35,8 +36,8 @@ class SecureSocket : public SocketBase { std::error_code handshake(); // from SocketBase - std::error_code read(uint8_t *buffer, uint32_t bufferSize, uint32_t &bytesRead) override; - std::error_code write(const uint8_t *buffer, uint32_t bufferSize, uint32_t &bytesWritten) override; + std::error_code read(Array &arr) override; + std::error_code write(const Array &arr, uint32_t &bytesWritten) override; void close() override; }; @@ -71,14 +72,14 @@ class SecureCallbacks : public Botan::TLS::Callbacks { private: int d_fd; bool d_peer_closed; - uint8_t d_readBuffer[16384]; + Array d_readBuffer; uint32_t d_readWaiting; public: SecureCallbacks(int fd); ~SecureCallbacks() {} - void readWaiting(uint8_t *buffer, uint32_t bufferSize, uint32_t &bytesRead); + void readWaiting(Array &arr); // from Botan::TLS::Callbacks void tls_emit_data(std::span data) override; diff --git a/examples/tlsclient.m.cpp b/examples/tlsclient.m.cpp index 776b8b2..cd1e846 100644 --- a/examples/tlsclient.m.cpp +++ b/examples/tlsclient.m.cpp @@ -1,3 +1,4 @@ +#include "fwoop_array.h" #include #include #include @@ -21,28 +22,30 @@ int main(int argc, char *argv[]) ::sleep(1); if (sock) { fwoop::Log::Info("connected"); - uint8_t buf[16384]; uint32_t bytesRead = 0; uint8_t msg[] = "GET /get HTTP/1.1\r\n" "Host: httpbin.org\r\n" "User-Agent: fwoop/1\r\n" "Accept: */*\r\n\r\n"; uint32_t bytesWritten = 0; - auto ec = sock->write(msg, sizeof(msg) - 1, bytesWritten); + fwoop::Array arr(sizeof(msg)); + arr.append(msg, sizeof(msg) - 1); + auto ec = sock->write(arr, bytesWritten); if (ec) { fwoop::Log::Error("secure write failed: ", ec.message()); return 1; } + const uint32_t bufSize = 16384; + fwoop::Array buf(bufSize); do { - bytesRead = 0; - memset(buf, 0, sizeof(buf)); - ec = sock->read(buf, sizeof(buf), bytesRead); + buf.clear(); + ec = sock->read(buf); if (ec) { fwoop::Log::Error("secure read failed: ", ec.message()); return 1; } - fwoop::Log::Info("readBytes=", bytesRead, " DATA: ", std::string(buf, buf + bytesRead)); - } while (!ec && bytesRead > 0); + fwoop::Log::Info("readBytes=", bytesRead, " DATA: ", buf.toString()); + } while (!ec && buf.size() > 0); } fwoop::Log::Info("done"); diff --git a/http/fwoop_httpconnhandler.cpp b/http/fwoop_httpconnhandler.cpp index 0c68e51..5a40133 100644 --- a/http/fwoop_httpconnhandler.cpp +++ b/http/fwoop_httpconnhandler.cpp @@ -27,15 +27,14 @@ void HttpConnHandler::operator()() Log::Debug("received http/1.1 connection"); constexpr unsigned int bufferSize = 2048; - uint8_t buffer[bufferSize]; - unsigned int bytesRead; - std::error_code ec = d_reader->read(buffer, bufferSize, bytesRead); + Array arr(bufferSize); + std::error_code ec = d_reader->read(arr); if (ec) { Log::Error("socket read failed", ec); } unsigned int bytesParsed = 0; - std::shared_ptr request = HttpRequest::parse(buffer, bytesRead, bytesParsed); + std::shared_ptr request = HttpRequest::parse(*arr, arr.size(), bytesParsed); if (!request) { Log::Error("did not receive full http request"); } @@ -48,7 +47,9 @@ void HttpConnHandler::operator()() uint32_t length; uint32_t bytesWritten; uint8_t *encResp = response.encode(length); - ec = d_writer->write(encResp, length, bytesWritten); + Array resp(length); + resp.append(encResp, length); + ec = d_writer->write(resp, bytesWritten); delete[] encResp; if (ec) { Log::Error("socket write failed, ec=", ec); diff --git a/http/fwoop_httpconnhandler.g.cpp b/http/fwoop_httpconnhandler.g.cpp index 2f29887..88d3443 100644 --- a/http/fwoop_httpconnhandler.g.cpp +++ b/http/fwoop_httpconnhandler.g.cpp @@ -1,3 +1,4 @@ +#include "fwoop_array.h" #include "fwoop_socketio.h" #include "gmock/gmock.h" #include @@ -12,10 +13,9 @@ class MockHttpReader : public fwoop::Reader { public: ~MockHttpReader() {} std::string response; - std::error_code read(uint8_t *buffer, uint32_t bufferSize, uint32_t &bytesRead) + std::error_code read(fwoop::Array &arr) { - ::memcpy(buffer, response.data(), response.length()); - bytesRead = response.length(); + arr.append(response); return std::error_code(); } MOCK_METHOD(void, close, ()); @@ -24,7 +24,7 @@ class MockHttpReader : public fwoop::Reader { class MockHttpWriter : public fwoop::Writer { public: ~MockHttpWriter() {} - MOCK_METHOD(std::error_code, write, (const uint8_t *buffer, uint32_t bufferSize, uint32_t &bytesWritten)); + MOCK_METHOD(std::error_code, write, (const fwoop::Array &arr, uint32_t &bytesWritten)); MOCK_METHOD(void, close, ()); }; @@ -46,8 +46,7 @@ TEST(HttpConnHandler, Handle) EXPECT_CALL(*reader, close()).WillOnce(::testing::Return()); EXPECT_CALL(*writer, close()).WillOnce(::testing::Return()); - EXPECT_CALL(*writer, write(::testing::_, ::testing::_, ::testing::_)) - .WillOnce(::testing::Return(std::error_code())); + EXPECT_CALL(*writer, write(::testing::_, ::testing::_)).WillOnce(::testing::Return(std::error_code())); EXPECT_CALL(mockCallback, onRequest(::testing::_, ::testing::_)).WillOnce(::testing::Return()); EXPECT_CALL(mockCallback, afterResponse(::testing::_, ::testing::_)).WillOnce(::testing::Return()); diff --git a/http/fwoop_httpserverevent.cpp b/http/fwoop_httpserverevent.cpp index a97a8ad..8ec0044 100644 --- a/http/fwoop_httpserverevent.cpp +++ b/http/fwoop_httpserverevent.cpp @@ -1,3 +1,4 @@ +#include "fwoop_array.h" #include #include #include @@ -16,7 +17,9 @@ HttpServerEvent::~HttpServerEvent() uint32_t len; uint32_t bytesWritten; uint8_t *out = finalResponse.encode(len); - auto ec = d_writer->write(out, len, bytesWritten); + Array arr(len); + arr.append(out, len); + auto ec = d_writer->write(arr, bytesWritten); delete[] out; if (ec) { Log::Warn("failed to write final response"); @@ -47,7 +50,10 @@ bool HttpServerEvent::pushEvent(const std::string &event, const std::string &dat out[offset++] = '\n'; uint32_t bytesWritten; - auto ec = d_writer->write(out, offset, bytesWritten); + // TODO refactor for Array + Array arr(offset); + arr.append(out, offset); + auto ec = d_writer->write(arr, bytesWritten); delete[] out; return ec.value() == 0; }