diff --git a/samples/tcp/CMakeLists.txt b/samples/tcp/CMakeLists.txt index f62c6531..5d77c95e 100644 --- a/samples/tcp/CMakeLists.txt +++ b/samples/tcp/CMakeLists.txt @@ -1,6 +1,7 @@ MakeApp(TcpClient) MakeApp(TcpServer) MakeApp(TcpRepeater) +MakeApp(NonBlockingExperiment) MakeSample(Sample01_server_client Tcp) add_dependencies(TcpSample01_server_client TcpClient TcpServer) diff --git a/samples/tcp/NonBlockingExperiment.cpp b/samples/tcp/NonBlockingExperiment.cpp new file mode 100644 index 00000000..11734ad5 --- /dev/null +++ b/samples/tcp/NonBlockingExperiment.cpp @@ -0,0 +1,91 @@ +#include +#include + +#include +#include +#include +#include +#include +#include +using namespace std; + +using namespace MinimalSocket; + +const std::uint16_t PORT = 44553; +const string MESSAGE = "Hello from the sender"; + +struct Logger { + Logger(const string &name) : name_{name} {} + + template void log(const Args &...args) { + std::stringstream buff; + (pack(buff, args), ...); + std::scoped_lock lock{getMtx()}; + std::cout << '|' << name_ << "|: " << buff.str() << std::endl; + } + +private: + template static void pack(std::stringstream &recipient, T val) { + recipient << val; + } + + static std::mutex &getMtx() { + static std::mutex res = std::mutex{}; + return res; + } + + string name_; +}; + +void server_loop(std::atomic_bool &done) { + Logger logger{"Server"}; + + std::optional connection; + { + tcp::TcpServer server{PORT, MinimalSocket::AddressFamily::IP_V4}; + if (!server.open()) + throw std::runtime_error{"unable to open the server"}; + done = true; + logger.log("listening"); + connection.emplace(server.acceptNewClient()); + } + logger.log("connected"); + while (true) { + logger.log("sending"); + connection->send(MESSAGE); + logger.log("sleeping ... "); + std::this_thread::sleep_for(std::chrono::milliseconds{1500}); + } +} + +void client_loop() { + Logger logger{"Client"}; + + tcp::TcpClient connection{Address{PORT}}; + + logger.log("connecting"); + connection.open(); + logger.log("connected"); + + while (true) { + auto res = connection.receive(MESSAGE.size()); + if (res.empty()) + logger.log("nothing"); + else + logger.log("received `", res, '`'); + std::this_thread::sleep_for(std::chrono::milliseconds{100}); + } +} + +int main() { + std::atomic_bool done = false; + std::thread server([&done]() { server_loop(done); }); + while (!done.load()) { + } + + client_loop(); + + server.join(); + + return EXIT_SUCCESS; +} diff --git a/src/header/MinimalSocket/Error.h b/src/header/MinimalSocket/Error.h index 0a568db4..88456062 100644 --- a/src/header/MinimalSocket/Error.h +++ b/src/header/MinimalSocket/Error.h @@ -56,6 +56,6 @@ class SocketError : public ErrorCodeHolder, public Error { class TimeOutError : public Error { public: - TimeOutError() : Error("Timeout"){}; + TimeOutError() : Error("Timeout reached"){}; }; } // namespace MinimalSocket diff --git a/src/header/MinimalSocket/core/Receiver.h b/src/header/MinimalSocket/core/Receiver.h index b6fffef9..2a0d1c9e 100644 --- a/src/header/MinimalSocket/core/Receiver.h +++ b/src/header/MinimalSocket/core/Receiver.h @@ -13,19 +13,11 @@ #include namespace MinimalSocket { -class ReceiverBase : public virtual Socket { +class ReceiverWithTimeout : public virtual Socket { protected: - template - void lazyUpdateAndUseTimeout(const Timeout &to, Pred what) { - std::scoped_lock lock{receive_mtx}; - updateTimeout_(to); - what(receive_timeout); - } - -private: void updateTimeout_(const Timeout &timeout); - std::mutex receive_mtx; +private: Timeout receive_timeout = NULL_TIMEOUT; }; @@ -36,7 +28,7 @@ class ReceiverBase : public virtual Socket { * receive, they will be satisfited one at a time, as an internal mutex must be * locked before starting to receive. */ -class Receiver : public ReceiverBase { +class ReceiverBlocking : public ReceiverWithTimeout { public: /** * @param message the buffer that will store the received bytes. @@ -63,6 +55,33 @@ class Receiver : public ReceiverBase { */ std::string receive(std::size_t expected_max_bytes, const Timeout &timeout = NULL_TIMEOUT); + +private: + std::mutex recv_mtx; +}; + +class ReceiverNonBlocking : public virtual Socket { +public: + std::size_t receive(BufferView message); + + std::string receive(std::size_t expected_max_bytes); + +private: + std::mutex recv_mtx; +}; + +template class Receiver {}; +template <> class Receiver : public ReceiverBlocking {}; +template <> class Receiver : public ReceiverNonBlocking {}; + +struct ReceiveResult { + Address sender; + std::size_t received_bytes; +}; + +struct ReceiveStringResult { + Address sender; + std::string received_message; }; /** @@ -72,12 +91,8 @@ class Receiver : public ReceiverBase { * receive, they will be satisfited one at a time, as an internal mutex must be * locked before starting to receive. */ -class ReceiverUnkownSender : public ReceiverBase { +class ReceiverUnkownSenderBlocking : public ReceiverWithTimeout { public: - struct ReceiveResult { - Address sender; - std::size_t received_bytes; - }; /** * @param message the buffer that will store the received bytes. * @param timeout the timeout to consider. A NULL_TIMEOUT means actually to @@ -90,10 +105,6 @@ class ReceiverUnkownSender : public ReceiverBase { std::optional receive(BufferView message, const Timeout &timeout = NULL_TIMEOUT); - struct ReceiveStringResult { - Address sender; - std::string received_message; - }; /** * @brief Similar to ReceiverUnkownSender::receive(Buffer &, const Timeout &), * but internally building the recipient buffer which is converted into a @@ -110,5 +121,24 @@ class ReceiverUnkownSender : public ReceiverBase { std::optional receive(std::size_t expected_max_bytes, const Timeout &timeout = NULL_TIMEOUT); + +private: + std::mutex recv_mtx; +}; + +class ReceiverUnkownSenderNonBlocking : public virtual Socket { +public: + std::optional receive(BufferView message); + + std::optional receive(std::size_t expected_max_bytes); + +private: + std::mutex recv_mtx; }; + +template class ReceiverUnkownSender {}; +template <> +class ReceiverUnkownSender : public ReceiverUnkownSenderBlocking {}; +template <> +class ReceiverUnkownSender : public ReceiverUnkownSenderNonBlocking {}; } // namespace MinimalSocket diff --git a/src/header/MinimalSocket/core/SocketContext.h b/src/header/MinimalSocket/core/SocketContext.h index d6614f35..03ed9c3f 100644 --- a/src/header/MinimalSocket/core/SocketContext.h +++ b/src/header/MinimalSocket/core/SocketContext.h @@ -8,6 +8,7 @@ #pragma once #include +#include #include #include @@ -98,4 +99,22 @@ class RemoteAddressFamilyAware { private: std::atomic remote_address_family; }; + +class BlockingMode : virtual public Socket { +public: + BlockingMode &operator=(const BlockingMode &o) { + mode = o.mode; + return *this; + } + + bool isBlocking() const { return mode; } + +protected: + BlockingMode(bool mode) : mode{mode} {} + + void setUp(); + +private: + bool mode; +}; } // namespace MinimalSocket diff --git a/src/header/MinimalSocket/tcp/TcpClient.h b/src/header/MinimalSocket/tcp/TcpClient.h index 4e35d618..39e8116a 100644 --- a/src/header/MinimalSocket/tcp/TcpClient.h +++ b/src/header/MinimalSocket/tcp/TcpClient.h @@ -13,14 +13,15 @@ #include namespace MinimalSocket::tcp { -class TcpClient : public NonCopiable, - public Openable, - public Sender, - public Receiver, - public RemoteAddressAware { -public: - TcpClient(TcpClient &&o); - TcpClient &operator=(TcpClient &&o); +class TcpClientBase : public NonCopiable, + public Openable, + public BlockingMode, + public Sender, + public RemoteAddressAware { +protected: + TcpClientBase(TcpClientBase &&o); + + void stealBase(TcpClientBase &o); /** * @brief The connection to the server is not asked in this c'tor which @@ -28,15 +29,32 @@ class TcpClient : public NonCopiable, * when calling open(...) * @param server_address the server to reach when opening this socket */ - TcpClient(const Address &server_address); + TcpClientBase(const Address &server_address, bool block_mode); -protected: void open_() override; }; +template +class TcpClient : public TcpClientBase, public Receiver { +public: + TcpClient(const Address &server_address) + : TcpClientBase{server_address, BlockMode} {} + + TcpClient(TcpClient &&o) + : TcpClientBase{std::forward(o)} {} + TcpClient &operator=(TcpClient &&o) { + this->stealBase(o); + return *this; + } +}; + /** * @return a client ready to ask the connection to the same server. * Beware that a closed socket is returned, which can be later opened. */ -TcpClient clone(const TcpClient &o); +template +TcpClient clone(const TcpClient &o) { + return TcpClient{o.getRemoteAddress()}; +} + } // namespace MinimalSocket::tcp diff --git a/src/header/MinimalSocket/tcp/TcpServer.h b/src/header/MinimalSocket/tcp/TcpServer.h index 62d0da1c..3278de54 100644 --- a/src/header/MinimalSocket/tcp/TcpServer.h +++ b/src/header/MinimalSocket/tcp/TcpServer.h @@ -16,33 +16,50 @@ #include namespace MinimalSocket::tcp { -class TcpServer; +class TcpServerBase; +class TcpConnectionNonBlocking; + /** * @brief Handler of an already established connection with a client, on the * server side. * An istance of this object is created calling TcpServer::acceptNewClient(). */ -class TcpConnection : public NonCopiable, - public Sender, - public Receiver, - public RemoteAddressAware { - friend class TcpServer; +class TcpConnectionBlocking : public NonCopiable, + public Sender, + public Receiver, + public RemoteAddressAware { + friend class TcpServerBase; public: - TcpConnection(TcpConnection &&o); - TcpConnection &operator=(TcpConnection &&o); + TcpConnectionBlocking(TcpConnectionBlocking &&o); + TcpConnectionBlocking &operator=(TcpConnectionBlocking &&o); + + TcpConnectionNonBlocking turnToNonBlocking(); private: - TcpConnection(const Address &remote_address); + TcpConnectionBlocking(const Address &remote_address); }; -class TcpServer : public NonCopiable, - public PortToBindAware, - public RemoteAddressFamilyAware, - public Openable { +class TcpConnectionNonBlocking : public NonCopiable, + public Sender, + public Receiver, + public RemoteAddressAware { public: - TcpServer(TcpServer &&o); - TcpServer &operator=(TcpServer &&o); + TcpConnectionNonBlocking(TcpConnectionNonBlocking &&o); + TcpConnectionNonBlocking &operator=(TcpConnectionNonBlocking &&o); + + TcpConnectionNonBlocking(TcpConnectionBlocking &&connection); +}; + +class TcpServerBase : public NonCopiable, + public PortToBindAware, + public RemoteAddressFamilyAware, + public Openable, + public BlockingMode { +protected: + TcpServerBase(TcpServerBase &&o); + + void stealBase(TcpServerBase &o); /** * @brief The port is not reserved in this c'tor which @@ -54,24 +71,10 @@ class TcpServer : public NonCopiable, * @param accepted_client_family family of the client that will ask the * connection to this server */ - TcpServer(Port port_to_bind = ANY_PORT, - AddressFamily accepted_client_family = AddressFamily::IP_V4); - - /** - * @brief Wait till accepting the connection from a new client. This is a - * blocking operation. - */ - TcpConnection acceptNewClient(); // blocking - - /** - * @brief Wait till accepting the connection from a new client. In case such a - * connection is not asked within the specified timeout, a nullopt is - * returned. - * @param timeout the timeout to consider. A NULL_TIMEOUT means actually to - * begin a blocking accept. - */ - std::optional acceptNewClient(const Timeout &timeout); + TcpServerBase(Port port_to_bind, AddressFamily accepted_client_family, + bool block_mode); +public: /** * @param queue_size the backlog size to assume when the server will be * opened, refer also to @@ -84,11 +87,82 @@ class TcpServer : public NonCopiable, protected: void open_() override; + struct AcceptedSocket; + void acceptClient_(AcceptedSocket &recipient); + + static TcpConnectionBlocking makeClient(const AcceptedSocket &acceptedSocket); + private: // maximum number of clients waiting for the connection to be // accepted std::atomic client_queue_size = 50; +}; +class AcceptorBlocking : public TcpServerBase { +public: + /** + * @brief Wait till accepting the connection from a new client. This is a + * blocking operation. + */ + TcpConnectionBlocking acceptNewClient(); + + TcpConnectionNonBlocking acceptNewNonBlockingClient() { + return acceptNewClient().turnToNonBlocking(); + } + +protected: + template + AcceptorBlocking(Args &&...args) + : TcpServerBase{std::forward(args)...} {} + +private: std::mutex accept_mtx; }; + +class AcceptorNonBlocking : public TcpServerBase { +public: + std::optional acceptNewClient(); + + std::optional acceptNewNonBlockingClient() { + auto client = acceptNewClient(); + if (client.has_value()) { + return client->turnToNonBlocking(); + } + return std::nullopt; + } + +protected: + template + AcceptorNonBlocking(Args &&...args) + : TcpServerBase{std::forward(args)...} {} + +private: + std::mutex accept_mtx; +}; + +template class Acceptor {}; +template <> class Acceptor : public AcceptorBlocking { +protected: + template + Acceptor(Args &&...args) : AcceptorBlocking{std::forward(args)...} {} +}; +template <> class Acceptor : public AcceptorNonBlocking { +protected: + template + Acceptor(Args &&...args) : AcceptorNonBlocking{std::forward(args)...} {} +}; + +template class TcpServer : public Acceptor { +public: + TcpServer(Port port_to_bind = 0, + AddressFamily accepted_client_family = AddressFamily::IP_V4) + : Acceptor{port_to_bind, accepted_client_family, BlockMode} {} + + TcpServer(TcpServer &&o) + : Acceptor{std::forward>(o)} {} + TcpServer &operator=(TcpServer &&o) { + this->stealBase(o); + return *this; + } +}; } // namespace MinimalSocket::tcp diff --git a/src/header/MinimalSocket/udp/UdpSocket.h b/src/header/MinimalSocket/udp/UdpSocket.h index 2d3bd86e..0e201c4f 100644 --- a/src/header/MinimalSocket/udp/UdpSocket.h +++ b/src/header/MinimalSocket/udp/UdpSocket.h @@ -19,7 +19,7 @@ namespace MinimalSocket::udp { */ static constexpr std::size_t MAX_UDP_RECV_MESSAGE = 65507; -class UdpConnected; +template class UdpConnected; /** * @brief This kind of udp is agnostic of the remote address (which can also @@ -29,15 +29,16 @@ class UdpConnected; * At the same time, this udp can send messages to any other non connected udp * sockets. */ -class UdpBinded : public NonCopiable, - public SenderTo, - public ReceiverUnkownSender, - public PortToBindAware, - public RemoteAddressFamilyAware, - public Openable { -public: - UdpBinded(UdpBinded &&o); - UdpBinded &operator=(UdpBinded &&o); +class UdpBase : public NonCopiable, + public SenderTo, + public PortToBindAware, + public RemoteAddressFamilyAware, + public Openable, + public BlockingMode { +protected: + UdpBase(UdpBase &&o); + + void stealBase(UdpBase &o); /** * @brief The port is not reserved in this c'tor which @@ -46,9 +47,14 @@ class UdpBinded : public NonCopiable, * @param port_to_bind the port to reserve by this udp * @param accepted_connection_family the kind of udp that can reach this one */ - UdpBinded(Port port_to_bind = ANY_PORT, - AddressFamily accepted_connection_family = AddressFamily::IP_V4); + UdpBase(Port port_to_bind, AddressFamily accepted_connection_family, + bool blockMode); + void open_() override; +}; + +class UdpBlocking : public UdpBase, public ReceiverUnkownSender { +public: /** * @brief Connects the udo socket to the specified remote address. * This leads to transfer the ownership of the underlying socket to the @@ -56,7 +62,7 @@ class UdpBinded : public NonCopiable, * @param remote_address the address to use for connecting the socket * @return a socket connected to the passed remote address */ - UdpConnected connect(const Address &remote_address); + UdpConnected connect(const Address &remote_address); /** * @brief similar to connect(const Address &). Here, the remote address is not @@ -67,7 +73,7 @@ class UdpBinded : public NonCopiable, * @param initial_message the initial message sent from the remote peer to * detect its address. */ - UdpConnected connect(std::string *initial_message = nullptr); + UdpConnected connect(std::string *initial_message = nullptr); /** * @brief similar to connect(std::string *initial_message), but non blocking. @@ -78,11 +84,49 @@ class UdpBinded : public NonCopiable, * @param initial_message the initial message sent from the remote peer to * detect its address. */ - std::optional connect(const Timeout &timeout, - std::string *initial_message = nullptr); + std::optional> + connect(const Timeout &timeout, std::string *initial_message = nullptr); protected: - void open_() override; + template + UdpBlocking(Args &&...args) : UdpBase{std::forward(args)...} {} +}; + +class UdpNonBlocking : public UdpBase, public ReceiverUnkownSender { +public: + UdpConnected connect(const Address &remote_address); + + std::optional> + connect(std::string *initial_message = nullptr); + +protected: + template + UdpNonBlocking(Args &&...args) : UdpBase{std::forward(args)...} {} +}; + +template class Udp_ {}; +template <> class Udp_ : public UdpBlocking { +protected: + template + Udp_(Args &&...args) : UdpBlocking{std::forward(args)...} {} +}; +template <> class Udp_ : public UdpNonBlocking { +protected: + template + Udp_(Args &&...args) : UdpNonBlocking{std::forward(args)...} {} +}; + +template class Udp : public Udp_ { +public: + Udp(Port port_to_bind = ANY_PORT, + AddressFamily accepted_connection_family = AddressFamily::IP_V4) + : Udp_{port_to_bind, accepted_connection_family, BlockMode} {} + + Udp(Udp &&o) : Udp_{std::forward>(o)} {} + Udp &operator=(Udp &&o) { + this->stealBase(o); + return *this; + } }; /** @@ -94,15 +138,16 @@ class UdpBinded : public NonCopiable, * incoming from udp sockets different from the remote address are filtered out. * At the same time, the remote address might also not exists at all. */ -class UdpConnected : public NonCopiable, - public Sender, - public Receiver, - public PortToBindAware, - public RemoteAddressAware, - public Openable { -public: - UdpConnected(UdpConnected &&o); - UdpConnected &operator=(UdpConnected &&o); +class UdpConnectedBase : public NonCopiable, + public Sender, + public PortToBindAware, + public RemoteAddressAware, + public Openable, + public BlockingMode { +protected: + UdpConnectedBase(UdpConnectedBase &&o); + + void stealBase(UdpConnectedBase &o); /** * @brief The connection to the remote address is not done in this c'tor which @@ -111,17 +156,68 @@ class UdpConnected : public NonCopiable, * @param remote_address remote address of the peer * @param port the port to reserve by this udp */ - UdpConnected(const Address &remote_address, Port port = ANY_PORT); + UdpConnectedBase(const Address &remote_address, Port port, bool blockMode); + + void open_() override; +}; + +class UdpConnectedBlocking : public UdpConnectedBase { +public: + /** + * @brief disconnect the underlying socket, generating an unbinded udp that + * reserves the same port reserved by this one. This leaves this onbject + * empty and closed. + */ + Udp disconnect(); + +protected: + template + UdpConnectedBlocking(Args &&...args) + : UdpConnectedBase{std::forward(args)...} {} +}; +class UdpConnectedNonBlocking : public UdpConnectedBase { +public: /** * @brief disconnect the underlying socket, generating an unbinded udp that * reserves the same port reserved by this one. This leaves this onbject * empty and closed. */ - UdpBinded disconnect(); + Udp disconnect(); protected: - void open_() override; + template + UdpConnectedNonBlocking(Args &&...args) + : UdpConnectedBase{std::forward(args)...} {} +}; + +template class UdpConnected_ {}; +template <> class UdpConnected_ : public UdpConnectedBlocking { +protected: + template + UdpConnected_(Args &&...args) + : UdpConnectedBlocking{std::forward(args)...} {} +}; +template <> class UdpConnected_ : public UdpConnectedNonBlocking { +protected: + template + UdpConnected_(Args &&...args) + : UdpConnectedNonBlocking{std::forward(args)...} {} +}; + +template +class UdpConnected : public UdpConnected_, + public Receiver { +public: + UdpConnected(const Address &remote_address, Port port_to_bind = ANY_PORT) + : UdpConnected_{remote_address, port_to_bind, BlockMode} {} + + UdpConnected(UdpConnected &&o) + : UdpConnected_{std::forward>(o)} {} + UdpConnected &operator=(UdpConnected &&o) { + this->stealBase(o); + return *this; + } }; /** @@ -134,17 +230,8 @@ class UdpConnected : public NonCopiable, * @param initial_message the message sent from the remote peer to detect its * address */ -UdpConnected makeUdpConnectedToUnknown(Port port, - AddressFamily accepted_connection_family, - std::string *initial_message = nullptr); - -/** - * @brief non blocking version of makeUdpConnectedToUnknown(const Port &, const - * AddressFamily &, std::string *). In case no remote peer sends at least 1 byte - * within the timeout, a nullopt is returned. - */ -std::optional +UdpConnected makeUdpConnectedToUnknown(Port port, AddressFamily accepted_connection_family, - const Timeout &timeout, std::string *initial_message = nullptr); + } // namespace MinimalSocket::udp diff --git a/src/src/SocketFunctions.cpp b/src/src/SocketFunctions.cpp index f51d6bdd..b65d4c08 100644 --- a/src/src/SocketFunctions.cpp +++ b/src/src/SocketFunctions.cpp @@ -11,6 +11,10 @@ #include "SocketFunctions.h" #include "Utils.h" +#if defined(__unix__) || defined(__APPLE__) +#include +#endif + namespace MinimalSocket { namespace { #ifdef _WIN32 @@ -133,4 +137,22 @@ void connect(SocketID socket_id, const Address &remote_address) { } }); } + +void turnToNonBlocking(SocketID socket_id) { +#ifdef _WIN32 + // https://learn.microsoft.com/en-us/windows/win32/api/winsock/nf-winsock-ioctlsocket + u_long iMode = 1; + iResult = ioctlsocket(socket_id, FIONBIO, &iMode); + throw 0; // TODO in the below function throw SocketError + if (iResult != NO_ERROR) { + printf("ioctlsocket failed with error: %ld\n", iResult); + } +#elif defined(__unix__) || defined(__APPLE__) + // https://jameshfisher.com/2017/04/05/set_socket_nonblocking/ + int flags = ::fcntl(socket_id, F_GETFL); + if (::fcntl(socket_id, F_SETFL, flags | O_NONBLOCK) == -1) { + throw Error{"Unable to set up the non blocking mode"}; + } +#endif +} } // namespace MinimalSocket \ No newline at end of file diff --git a/src/src/SocketFunctions.h b/src/src/SocketFunctions.h index eaed8e76..75f509f0 100644 --- a/src/src/SocketFunctions.h +++ b/src/src/SocketFunctions.h @@ -17,4 +17,6 @@ Port bind(SocketID socket_id, AddressFamily family, Port port, void listen(SocketID socket_id, std::size_t backlog_size); void connect(SocketID socket_id, const Address &remote_address); + +void turnToNonBlocking(SocketID socket_id); } // namespace MinimalSocket diff --git a/src/src/SocketHandler.cpp b/src/src/SocketHandler.cpp index 5e75091f..724582c4 100644 --- a/src/src/SocketHandler.cpp +++ b/src/src/SocketHandler.cpp @@ -10,6 +10,10 @@ #include "SocketHandler.h" #include "Utils.h" +#if defined(__unix__) || defined(__APPLE__) +#include +#endif + namespace MinimalSocket { #ifdef _WIN32 WSALazyInitializer::WSALazyInitializer(const WSAVersion &version) @@ -129,4 +133,5 @@ void SocketHandler::reset(SocketType type, AddressFamily family) { throw err; } } + } // namespace MinimalSocket \ No newline at end of file diff --git a/src/src/core/Receiver.cpp b/src/src/core/Receiver.cpp index 6871c2e8..21b05233 100644 --- a/src/src/core/Receiver.cpp +++ b/src/src/core/Receiver.cpp @@ -14,7 +14,7 @@ #endif namespace MinimalSocket { -void ReceiverBase::updateTimeout_(const Timeout &timeout) { +void ReceiverWithTimeout::updateTimeout_(const Timeout &timeout) { if (timeout == receive_timeout) { return; } @@ -66,81 +66,156 @@ void check_received_bytes(int &recvBytes, const Timeout &timeout) { } } // namespace -std::size_t Receiver::receive(BufferView message, const Timeout &timeout) { - std::size_t res = 0; - - lazyUpdateAndUseTimeout( - timeout, [&message, &res, this](const Timeout &timeout) { - clear(message); +std::size_t ReceiverBlocking::receive(BufferView message, + const Timeout &timeout) { + std::scoped_lock lock{recv_mtx}; + updateTimeout_(timeout); + clear(message); + + int recvBytes = ::recv(getHandler().accessId(), message.buffer, + static_cast(message.buffer_size), 0); + check_received_bytes(recvBytes, timeout); + if (recvBytes > message.buffer_size) { + // if here, the message received is probably corrupted + recvBytes = 0; + } + return static_cast(recvBytes); +} - int recvBytes = ::recv(getHandler().accessId(), message.buffer, - static_cast(message.buffer_size), 0); - check_received_bytes(recvBytes, timeout); - if (recvBytes > message.buffer_size) { - // if here, the message received is probably corrupted - recvBytes = 0; - } - res = static_cast(recvBytes); - }); +std::size_t ReceiverNonBlocking::receive(BufferView message) { + std::scoped_lock lock{recv_mtx}; + clear(message); - return res; + int recvBytes = ::recv(getHandler().accessId(), message.buffer, + static_cast(message.buffer_size), 0); + return (recvBytes == -1 || recvBytes > message.buffer_size) + ? 0 + : static_cast(recvBytes); } -std::string Receiver::receive(std::size_t expected_max_bytes, - const Timeout &timeout) { +namespace { +template +std::string receive_into_string(std::size_t expected_max_bytes, Pred pred, + const Args &...args) { std::string buffer; buffer.resize(expected_max_bytes); auto buffer_temp = makeBufferView(buffer); - auto recvBytes = receive(buffer_temp, timeout); + auto recvBytes = pred(buffer_temp, args...); buffer.resize(recvBytes); return buffer; } +} // namespace + +std::string ReceiverBlocking::receive(std::size_t expected_max_bytes, + const Timeout &timeout) { + return receive_into_string( + expected_max_bytes, + [this](BufferView message, const Timeout &timeout) { + return this->receive(message, timeout); + }, + timeout); +} -std::optional -ReceiverUnkownSender::receive(BufferView message, const Timeout &timeout) { - std::optional res; - - lazyUpdateAndUseTimeout( - timeout, [&message, &res, this](const Timeout &timeout) { - clear(message); - - char sender_address[MAX_POSSIBLE_ADDRESS_SIZE]; - SocketAddressLength sender_address_length = MAX_POSSIBLE_ADDRESS_SIZE; - - int recvBytes = - ::recvfrom(getHandler().accessId(), message.buffer, - static_cast(message.buffer_size), 0, - reinterpret_cast(&sender_address[0]), - &sender_address_length); - check_received_bytes(recvBytes, timeout); - if (recvBytes > message.buffer_size) { - // if here, the message received is probably corrupted - return; - } - if (0 == recvBytes) { - // if here, timeout was reached - return; - } - - res = ReceiveResult{ - toAddress(reinterpret_cast(sender_address)), - static_cast(recvBytes)}; - }); +std::string ReceiverNonBlocking::receive(std::size_t expected_max_bytes) { + return receive_into_string(expected_max_bytes, [this](BufferView message) { + return this->receive(message); + }); +} + +std::optional +ReceiverUnkownSenderBlocking::receive(BufferView message, + const Timeout &timeout) { + std::scoped_lock lock{recv_mtx}; + updateTimeout_(timeout); + clear(message); + + std::optional res; + + char sender_address[MAX_POSSIBLE_ADDRESS_SIZE]; + SocketAddressLength sender_address_length = MAX_POSSIBLE_ADDRESS_SIZE; + + int recvBytes = + ::recvfrom(getHandler().accessId(), message.buffer, + static_cast(message.buffer_size), 0, + reinterpret_cast(&sender_address[0]), + &sender_address_length); + check_received_bytes(recvBytes, timeout); + if (recvBytes > message.buffer_size) { + // if here, the message received is probably corrupted + return std::nullopt; + } + if (0 == recvBytes) { + // if here, timeout was reached + return std::nullopt; + } + + res = ReceiveResult{ + toAddress(reinterpret_cast(sender_address)), + static_cast(recvBytes)}; return res; } -std::optional -ReceiverUnkownSender::receive(std::size_t expected_max_bytes, - const Timeout &timeout) { +std::optional +ReceiverUnkownSenderNonBlocking::receive(BufferView message) { + std::scoped_lock lock{recv_mtx}; + clear(message); + + std::optional res; + + char sender_address[MAX_POSSIBLE_ADDRESS_SIZE]; + SocketAddressLength sender_address_length = MAX_POSSIBLE_ADDRESS_SIZE; + + int recvBytes = + ::recvfrom(getHandler().accessId(), message.buffer, + static_cast(message.buffer_size), 0, + reinterpret_cast(&sender_address[0]), + &sender_address_length); + + if (recvBytes == -1 || recvBytes > message.buffer_size) { + return std::nullopt; + } + + res = ReceiveResult{ + toAddress(reinterpret_cast(sender_address)), + static_cast(recvBytes)}; + + return res; +} + +namespace { +template +std::optional +receive_unknown_into_string(std::size_t expected_max_bytes, Pred pred, + const Args &...args) { std::string buffer; buffer.resize(expected_max_bytes); auto buffer_temp = makeBufferView(buffer); - auto result = receive(buffer_temp, timeout); + auto result = pred(buffer_temp, args...); if (!result) { return std::nullopt; } buffer.resize(result->received_bytes); return ReceiveStringResult{std::move(result->sender), std::move(buffer)}; } +} // namespace + +std::optional +ReceiverUnkownSenderBlocking::receive(std::size_t expected_max_bytes, + const Timeout &timeout) { + return receive_unknown_into_string( + expected_max_bytes, + [this](BufferView message, const Timeout &timeout) { + return this->receive(message, timeout); + }, + timeout); +} + +std::optional +ReceiverUnkownSenderNonBlocking::receive(std::size_t expected_max_bytes) { + return receive_unknown_into_string( + expected_max_bytes, + [this](BufferView message) { return this->receive(message); }); +} + } // namespace MinimalSocket diff --git a/src/src/core/SocketContext.cpp b/src/src/core/SocketContext.cpp index 6f016a2a..7d8f51dc 100644 --- a/src/src/core/SocketContext.cpp +++ b/src/src/core/SocketContext.cpp @@ -8,6 +8,9 @@ #include #include +#include "../SocketFunctions.h" +#include "../SocketHandler.h" + namespace MinimalSocket { Address RemoteAddressAware::getRemoteAddress() const { std::scoped_lock lock(remote_address_mtx); @@ -21,4 +24,9 @@ RemoteAddressAware::RemoteAddressAware(const Address &address) } } +void BlockingMode::setUp() { + if (!mode) { + turnToNonBlocking(getHandler().accessId()); + } +} } // namespace MinimalSocket diff --git a/src/src/tcp/TcpClient.cpp b/src/src/tcp/TcpClient.cpp index 740e17df..5fc8efc8 100644 --- a/src/src/tcp/TcpClient.cpp +++ b/src/src/tcp/TcpClient.cpp @@ -12,22 +12,24 @@ #include "../Utils.h" namespace MinimalSocket::tcp { -TcpClient::TcpClient(TcpClient &&o) : RemoteAddressAware(o) { this->steal(o); } -TcpClient &TcpClient::operator=(TcpClient &&o) { +TcpClientBase::TcpClientBase(TcpClientBase &&o) + : BlockingMode{o.isBlocking()}, RemoteAddressAware(o) { + this->steal(o); +} + +void TcpClientBase::stealBase(TcpClientBase &o) { this->steal(o); copy_as(*this, o); - return *this; } -TcpClient::TcpClient(const Address &server_address) - : RemoteAddressAware(server_address) {} +TcpClientBase::TcpClientBase(const Address &server_address, bool block_mode) + : BlockingMode{block_mode}, RemoteAddressAware(server_address) {} -void TcpClient::open_() { +void TcpClientBase::open_() { auto &socket = getHandler(); const auto remote_address = getRemoteAddress(); socket.reset(SocketType::TCP, remote_address.getFamily()); MinimalSocket::connect(socket.accessId(), remote_address); + this->BlockingMode::setUp(); } - -TcpClient clone(const TcpClient &o) { return TcpClient{o.getRemoteAddress()}; } } // namespace MinimalSocket::tcp diff --git a/src/src/tcp/TcpServer.cpp b/src/src/tcp/TcpServer.cpp index cb22217e..5b35acd0 100644 --- a/src/src/tcp/TcpServer.cpp +++ b/src/src/tcp/TcpServer.cpp @@ -13,22 +13,64 @@ #include "../Utils.h" namespace MinimalSocket::tcp { -TcpServer::TcpServer(TcpServer &&o) - : PortToBindAware(o), RemoteAddressFamilyAware(o) { +TcpConnectionBlocking::TcpConnectionBlocking(const Address &remote_address) + : RemoteAddressAware{remote_address} {} + +TcpConnectionBlocking::TcpConnectionBlocking(TcpConnectionBlocking &&o) + : RemoteAddressAware(o) { + this->steal(o); +} + +TcpConnectionBlocking & +TcpConnectionBlocking::operator=(TcpConnectionBlocking &&o) { + this->steal(o); + copy_as(*this, o); + return *this; +} + +TcpConnectionNonBlocking TcpConnectionBlocking::turnToNonBlocking() { + return TcpConnectionNonBlocking{std::move(*this)}; +} + +TcpConnectionNonBlocking::TcpConnectionNonBlocking(TcpConnectionNonBlocking &&o) + : RemoteAddressAware{o} { this->steal(o); } -TcpServer &TcpServer::operator=(TcpServer &&o) { + +TcpConnectionNonBlocking & +TcpConnectionNonBlocking::operator=(TcpConnectionNonBlocking &&o) { + this->steal(o); + copy_as(*this, o); + return *this; +} + +TcpConnectionNonBlocking::TcpConnectionNonBlocking( + TcpConnectionBlocking &&connection) + : RemoteAddressAware{connection} { + this->steal(connection); + turnToNonBlocking(getHandler().accessId()); +} + +TcpServerBase::TcpServerBase(TcpServerBase &&o) + : PortToBindAware(o), + RemoteAddressFamilyAware(o), BlockingMode{o.isBlocking()} { + this->steal(o); +} + +void TcpServerBase::stealBase(TcpServerBase &o) { this->steal(o); copy_as(*this, o); copy_as(*this, o); - return *this; } -TcpServer::TcpServer(Port port_to_bind, AddressFamily accepted_client_family) +TcpServerBase::TcpServerBase(Port port_to_bind, + AddressFamily accepted_client_family, + bool block_mode) : PortToBindAware(port_to_bind), - RemoteAddressFamilyAware(accepted_client_family) {} + RemoteAddressFamilyAware(accepted_client_family), BlockingMode{ + block_mode} {} -void TcpServer::open_() { +void TcpServerBase::open_() { auto &socket = getHandler(); const auto port = getPortToBind(); const auto family = getRemoteAddressFamily(); @@ -36,79 +78,79 @@ void TcpServer::open_() { auto binded_port = MinimalSocket::bind(socket.accessId(), family, port, shallBeFreePort()); setPort(binded_port); + this->BlockingMode::setUp(); MinimalSocket::listen(socket.accessId(), client_queue_size); } -void TcpServer::setClientQueueSize(const std::size_t queue_size) { +void TcpServerBase::setClientQueueSize(const std::size_t queue_size) { if (wasOpened()) { throw Error{"Can't set client queue size of an alrady opened tcp server"}; } client_queue_size = queue_size; } -TcpConnection TcpServer::acceptNewClient() { - auto temp = acceptNewClient(NULL_TIMEOUT); - return std::move(temp.value()); -} +struct TcpServerBase::AcceptedSocket { + SocketID fd = SCK_INVALID_SOCKET; + SocketAddressLength address_length = MAX_POSSIBLE_ADDRESS_SIZE; + char address[MAX_POSSIBLE_ADDRESS_SIZE]; +}; -std::optional -TcpServer::acceptNewClient(const Timeout &timeout) { - std::scoped_lock lock(accept_mtx); - if (!this->wasOpened()) { - throw Error("Tcp server was not opened before starting to accept clients"); - } +void TcpServerBase::acceptClient_(AcceptedSocket &recipient) { + auto &[accepted_client_socket_id, acceptedClientAddress_length, + acceptedClientAddress] = recipient; - char acceptedClientAddress[MAX_POSSIBLE_ADDRESS_SIZE]; - SocketAddressLength acceptedClientAddress_length = MAX_POSSIBLE_ADDRESS_SIZE; - SocketID accepted_client_socket_id = SCK_INVALID_SOCKET; - - auto accept_client = [&]() { - // accept: wait for a client to call connect and hit this server and get a - // pointer to this client. - accepted_client_socket_id = - ::accept(getHandler().accessId(), - reinterpret_cast(&acceptedClientAddress[0]), - &acceptedClientAddress_length); + // accept: wait for a client to call connect and hit this server and get a + // pointer to this client. + accepted_client_socket_id = + ::accept(getHandler().accessId(), + reinterpret_cast(&acceptedClientAddress[0]), + &acceptedClientAddress_length); + if (isBlocking()) { if (accepted_client_socket_id == SCK_INVALID_SOCKET) { auto err = SocketError{"accepting a new client"}; throw err; } - }; + } +} + +TcpConnectionBlocking +TcpServerBase::makeClient(const AcceptedSocket &acceptedSocket) { + auto accepted_client_parsed_address = toAddress( + reinterpret_cast(acceptedSocket.address)); + TcpConnectionBlocking result{accepted_client_parsed_address}; + result.getHandler().reset(acceptedSocket.fd); + return result; +} + +TcpConnectionBlocking AcceptorBlocking::acceptNewClient() { + std::scoped_lock lock(accept_mtx); + if (!this->wasOpened()) { + throw Error("Tcp server was not opened before starting to accept clients"); + } + AcceptedSocket acceptedSocket; try { - if (NULL_TIMEOUT == timeout) { - accept_client(); - } else { - try_within_timeout([&]() { accept_client(); }, - [this]() { this->resetHandler(); }, timeout); - } - } catch (const TimeOutError &) { - TcpServer reopened = TcpServer{getPortToBind(), getRemoteAddressFamily()}; - reopened.open(); - *this = std::move(reopened); - return std::nullopt; + acceptClient_(acceptedSocket); } catch (...) { std::rethrow_exception(std::current_exception()); } - auto accepted_client_parsed_address = - toAddress(reinterpret_cast(acceptedClientAddress)); - std::optional result; - auto &accepted = - result.emplace(TcpConnection{accepted_client_parsed_address}); - accepted.getHandler().reset(accepted_client_socket_id); - return result; + return makeClient(acceptedSocket); } -TcpConnection::TcpConnection(const Address &remote_address) - : RemoteAddressAware(remote_address) {} +std::optional AcceptorNonBlocking::acceptNewClient() { + std::scoped_lock lock(accept_mtx); + if (!this->wasOpened()) { + throw Error("Tcp server was not opened before starting to accept clients"); + } -TcpConnection::TcpConnection(TcpConnection &&o) : RemoteAddressAware(o) { - this->steal(o); -} -TcpConnection &TcpConnection::operator=(TcpConnection &&o) { - copy_as(*this, o); - this->steal(o); - return *this; + AcceptedSocket acceptedSocket; + acceptClient_(acceptedSocket); + + if (acceptedSocket.fd == SCK_INVALID_SOCKET) { + return std::nullopt; + } + return makeClient(acceptedSocket); } + } // namespace MinimalSocket::tcp diff --git a/src/src/udp/UdpSocket.cpp b/src/src/udp/UdpSocket.cpp index 91ddfe78..f49c48d5 100644 --- a/src/src/udp/UdpSocket.cpp +++ b/src/src/udp/UdpSocket.cpp @@ -12,35 +12,39 @@ #include "../Utils.h" namespace MinimalSocket::udp { -UdpBinded::UdpBinded(Port port_to_bind, - AddressFamily accepted_connection_family) +UdpBase::UdpBase(Port port_to_bind, AddressFamily accepted_connection_family, + bool blockMode) : PortToBindAware(port_to_bind), - RemoteAddressFamilyAware(accepted_connection_family) {} + RemoteAddressFamilyAware(accepted_connection_family), BlockingMode{ + blockMode} {} -UdpBinded::UdpBinded(UdpBinded &&o) - : PortToBindAware(o), RemoteAddressFamilyAware(o) { +UdpBase::UdpBase(UdpBase &&o) + : PortToBindAware(o), + RemoteAddressFamilyAware(o), BlockingMode{o.isBlocking()} { this->steal(o); } -UdpBinded &UdpBinded::operator=(UdpBinded &&o) { + +void UdpBase::stealBase(UdpBase &o) { copy_as(*this, o); copy_as(*this, o); + copy_as(*this, o); this->steal(o); - return *this; } -void UdpBinded::open_() { +void UdpBase::open_() { getHandler().reset(SocketType::UDP, getRemoteAddressFamily()); auto binded_port = MinimalSocket::bind(getHandler().accessId(), getRemoteAddressFamily(), getPortToBind(), shallBeFreePort()); setPort(binded_port); + this->BlockingMode::setUp(); } -UdpConnected UdpBinded::connect(const Address &remote_address) { +UdpConnected UdpBlocking::connect(const Address &remote_address) { if (remote_address.getFamily() != getRemoteAddressFamily()) { throw Error{"Passed address has invalid family"}; } - UdpConnected result(remote_address, getPortToBind()); + UdpConnected result(remote_address, getPortToBind()); if (wasOpened()) { MinimalSocket::connect(getHandler().accessId(), remote_address); } @@ -48,13 +52,13 @@ UdpConnected UdpBinded::connect(const Address &remote_address) { return std::move(result); } -UdpConnected UdpBinded::connect(std::string *initial_message) { +UdpConnected UdpBlocking::connect(std::string *initial_message) { auto result = this->connect(NULL_TIMEOUT, initial_message); return std::move(result.value()); } -std::optional UdpBinded::connect(const Timeout &timeout, - std::string *initial_message) { +std::optional> +UdpBlocking::connect(const Timeout &timeout, std::string *initial_message) { auto maybe_received = this->receive(MAX_UDP_RECV_MESSAGE, timeout); if (!maybe_received) { return std::nullopt; @@ -65,54 +69,81 @@ std::optional UdpBinded::connect(const Timeout &timeout, return connect(maybe_received->sender); } -UdpConnected::UdpConnected(const Address &remote_address, Port port) - : PortToBindAware(port), RemoteAddressAware(remote_address) {} +UdpConnected UdpNonBlocking::connect(const Address &remote_address) { + if (remote_address.getFamily() != getRemoteAddressFamily()) { + throw Error{"Passed address has invalid family"}; + } + UdpConnected result(remote_address, getPortToBind()); + if (wasOpened()) { + MinimalSocket::connect(getHandler().accessId(), remote_address); + } + this->transfer(result); + return std::move(result); +} -UdpConnected::UdpConnected(UdpConnected &&o) - : PortToBindAware(o), RemoteAddressAware(o) { +std::optional> +UdpNonBlocking::connect(std::string *initial_message) { + auto maybe_received = this->receive(MAX_UDP_RECV_MESSAGE); + if (!maybe_received) { + return std::nullopt; + } + if (nullptr != initial_message) { + *initial_message = std::move(maybe_received->received_message); + } + return connect(maybe_received->sender); +} + +UdpConnectedBase::UdpConnectedBase(const Address &remote_address, Port port, + bool blockMode) + : PortToBindAware(port), + RemoteAddressAware(remote_address), BlockingMode{blockMode} {} + +UdpConnectedBase::UdpConnectedBase(UdpConnectedBase &&o) + : PortToBindAware(o), RemoteAddressAware(o), BlockingMode{o.isBlocking()} { this->steal(o); } -UdpConnected &UdpConnected::operator=(UdpConnected &&o) { + +void UdpConnectedBase::stealBase(UdpConnectedBase &o) { copy_as(*this, o); copy_as(*this, o); + copy_as(*this, o); this->steal(o); - return *this; } -void UdpConnected::open_() { +void UdpConnectedBase::open_() { const auto &remote_address = getRemoteAddress(); getHandler().reset(SocketType::UDP, remote_address.getFamily()); auto socket_id = getHandler().accessId(); auto binded_port = MinimalSocket::bind(socket_id, remote_address.getFamily(), getPortToBind(), shallBeFreePort()); setPort(binded_port); + this->BlockingMode::setUp(); MinimalSocket::connect(socket_id, remote_address); } -UdpBinded UdpConnected::disconnect() { +Udp UdpConnectedBlocking::disconnect() { resetHandler(); - UdpBinded result(getPortToBind(), getRemoteAddress().getFamily()); + Udp result(getPortToBind(), getRemoteAddress().getFamily()); result.open(); return std::move(result); } -UdpConnected makeUdpConnectedToUnknown(Port port, - AddressFamily accepted_connection_family, - std::string *initial_message) { - auto result = makeUdpConnectedToUnknown(port, accepted_connection_family, - NULL_TIMEOUT, initial_message); - return std::move(result.value()); +Udp UdpConnectedNonBlocking::disconnect() { + resetHandler(); + Udp result(getPortToBind(), getRemoteAddress().getFamily()); + result.open(); + return std::move(result); } -std::optional +UdpConnected makeUdpConnectedToUnknown(Port port, AddressFamily accepted_connection_family, - const Timeout &timeout, std::string *initial_message) { - UdpBinded primal_socket(port, accepted_connection_family); + Udp primal_socket(port, accepted_connection_family); auto success = primal_socket.open(); if (!success) { - return std::nullopt; + throw Error{"Unable to open the primal upd socket"}; } - return primal_socket.connect(timeout, initial_message); + return primal_socket.connect(initial_message); } + } // namespace MinimalSocket::udp diff --git a/tests/ConnectionsUtils.cpp b/tests/ConnectionsUtils.cpp index e6dfdaf5..21ebd868 100644 --- a/tests/ConnectionsUtils.cpp +++ b/tests/ConnectionsUtils.cpp @@ -16,11 +16,11 @@ TcpPeers::TcpPeers(const Port &port, const AddressFamily &family) ParallelSection::biSection( [&](Barrier &br) { // server - tcp::TcpServer server(port, family); + tcp::TcpServer server(port, family); REQUIRE(server.open()); br.arrive_and_wait(); auto accepted = server.acceptNewClient(); - server_side = std::make_unique(std::move(accepted)); + server_side.emplace(std::move(accepted)); }, [&](Barrier &br) { // client @@ -29,10 +29,62 @@ TcpPeers::TcpPeers(const Port &port, const AddressFamily &family) }); } -UdpPeers::UdpPeers(const Port &port_a, const Port &port_b, - const AddressFamily &family) +template <> +UdpPeers>::UdpPeers(const Port &port_a, const Port &port_b, + const AddressFamily &family) : peer_a(port_a, family), peer_b(port_b, family) { REQUIRE(peer_a.open()); REQUIRE(peer_b.open()); } + +template <> +Address +UdpPeers>::extractRemoteAddress(const udp::Udp &subject) { + return Address{subject.getPortToBind(), subject.getRemoteAddressFamily()}; +} + +template <> +UdpPeers>::UdpPeers(const Port &port_a, const Port &port_b, + const AddressFamily &family) + : peer_a(port_a, family), peer_b(port_b, family) { + REQUIRE(peer_a.open()); + REQUIRE(peer_b.open()); +} + +template <> +Address UdpPeers>::extractRemoteAddress( + const udp::Udp &subject) { + return Address{subject.getPortToBind(), subject.getRemoteAddressFamily()}; +} + +template <> +UdpPeers>::UdpPeers(const Port &port_a, + const Port &port_b, + const AddressFamily &family) + : peer_a(Address{port_b, family}, port_a), + peer_b(Address{port_a, family}, port_b) { + REQUIRE(peer_a.open()); + REQUIRE(peer_b.open()); +} +template <> +Address UdpPeers>::extractRemoteAddress( + const udp::UdpConnected &subject) { + return subject.getRemoteAddress(); +} + +template <> +UdpPeers>::UdpPeers(const Port &port_a, + const Port &port_b, + const AddressFamily &family) + : peer_a(Address{port_b, family}, port_a), + peer_b(Address{port_a, family}, port_b) { + REQUIRE(peer_a.open()); + REQUIRE(peer_b.open()); +} +template <> +Address UdpPeers>::extractRemoteAddress( + const udp::UdpConnected &subject) { + return subject.getRemoteAddress(); +} + } // namespace MinimalSocket::test diff --git a/tests/ConnectionsUtils.h b/tests/ConnectionsUtils.h index 4204f3a5..361e0883 100644 --- a/tests/ConnectionsUtils.h +++ b/tests/ConnectionsUtils.h @@ -11,42 +11,45 @@ #include #include +#include +#include + namespace MinimalSocket::test { class TcpPeers { public: TcpPeers(const Port &port, const AddressFamily &family); - tcp::TcpConnection &getServerSide() { return *server_side; } - tcp::TcpClient &getClientSide() { return client_side; } + std::pair *> get() { + return std::make_pair(&server_side.value(), &client_side); + } private: - std::unique_ptr server_side; - tcp::TcpClient client_side; + std::optional server_side; + tcp::TcpClient client_side; }; -class UdpPeers { +template class UdpPeers { public: UdpPeers(const Port &port_a, const Port &port_b, const AddressFamily &family); - udp::UdpBinded &getPeerA() { return peer_a; } - udp::UdpBinded &getPeerB() { return peer_b; } - - Address addressPeerA() const { - return Address{peer_a.getPortToBind(), peer_a.getRemoteAddressFamily()}; - }; - Address addressPeerB() const { - return Address{peer_b.getPortToBind(), peer_b.getRemoteAddressFamily()}; - }; + std::tuple get() { + return std::make_tuple(&peer_a, extractRemoteAddress(peer_a), &peer_b, + extractRemoteAddress(peer_b)); + } private: - udp::UdpBinded peer_a; - udp::UdpBinded peer_b; + static Address extractRemoteAddress(const Udp_ &subject); + + Udp_ peer_a; + Udp_ peer_b; }; -#define UDP_PEERS(PORT_A, PORT_B, FAMILY) \ - UdpPeers peers(PORT_A, PORT_B, FAMILY); \ - auto &requester = peers.getPeerA(); \ - const auto requester_address = peers.addressPeerA(); \ - auto &responder = peers.getPeerB(); \ - const auto responder_address = peers.addressPeerB(); +#define UDP_PEERS(TYPE, FAMILY) \ + UdpPeers peers{PortFactory::get().makePort(), \ + PortFactory::get().makePort(), family}; \ + auto tmp = peers.get(); \ + auto *requester = std::get<0>(tmp); \ + auto &requester_address = std::get<1>(tmp); \ + auto *responder = std::get<2>(tmp); \ + auto &responder_address = std::get<3>(tmp); } // namespace MinimalSocket::test diff --git a/tests/ParallelSection.cpp b/tests/ParallelSection.cpp index 64002b86..c2c07dc3 100644 --- a/tests/ParallelSection.cpp +++ b/tests/ParallelSection.cpp @@ -10,7 +10,7 @@ namespace MinimalSocket::test { namespace { -std::function make_thread(Barrier &br, const Task &task) { +std::function make_task(Barrier &br, const Task &task) { return [&task = task, &br = br]() mutable { br.arrive_and_wait(); task(br); @@ -22,12 +22,12 @@ void ParallelSection::run() { if (tasks.size() < 2) { throw std::runtime_error{"invalid number of tasks for parallel region"}; } - barrier.emplace(tasks.size()); + auto &br = barrier.emplace(tasks.size()); std::vector spinners; - for (auto it = tasks.begin(); it != tasks.end() - 1; ++it) { - spinners.emplace_back(make_thread(barrier.value(), *it)); - } - spinners.emplace_back(make_thread(barrier.value(), tasks.back())); + std::for_each(tasks.begin() + 1, tasks.end(), [&](const Task &t) { + spinners.emplace_back(make_task(br, t)); + }); + make_task(br, tasks.front())(); for (auto &sp : spinners) { sp.join(); } diff --git a/tests/PortFactory.cpp b/tests/PortFactory.cpp index ad8eead1..9c9971ad 100644 --- a/tests/PortFactory.cpp +++ b/tests/PortFactory.cpp @@ -8,18 +8,22 @@ #include "PortFactory.h" namespace MinimalSocket::test { +PortFactory &PortFactory::get() { + static PortFactory res = PortFactory{}; + return res; +} + namespace { static constexpr std::uint16_t INITIAL_PORT = 9999; static constexpr std::uint16_t DELTA_PORT = 10; } // namespace -std::mutex PortFactory::port_mtx = std::mutex{}; -Port PortFactory::port = INITIAL_PORT; - Port PortFactory::makePort() { std::lock_guard lock(port_mtx); auto result = port; port += DELTA_PORT; return result; } + +PortFactory::PortFactory() : port{INITIAL_PORT} {} } // namespace MinimalSocket::test diff --git a/tests/PortFactory.h b/tests/PortFactory.h index 3d50d126..3fdec065 100644 --- a/tests/PortFactory.h +++ b/tests/PortFactory.h @@ -13,10 +13,14 @@ namespace MinimalSocket::test { class PortFactory { public: - static Port makePort(); + static PortFactory &get(); + + Port makePort(); private: - static std::mutex port_mtx; - static Port port; + PortFactory(); + + std::mutex port_mtx; + Port port; }; } // namespace MinimalSocket::test diff --git a/tests/RollingView.cpp b/tests/RollingView.cpp new file mode 100644 index 00000000..6a0900e4 --- /dev/null +++ b/tests/RollingView.cpp @@ -0,0 +1,56 @@ +#include "RollingView.h" + +namespace MinimalSocket::test { +RollingView::RollingView(const std::string &buff) : buffer{buff} {} + +RollingView::RollingView(std::size_t buff_size) { buffer.resize(buff_size); } + +void sliced_send(Sender &subject, const std::string &to_send, + std::size_t delta_send) { + RollingView buff{to_send}; + buff.forEachView(delta_send, [&](const std::string_view &view) { + bool ok = subject.send(BufferViewConst{view.data(), view.size()}); + if (!ok) { + throw std::runtime_error{"wasn't able to send all data"}; + } + return delta_send; + }); +} + +void sliced_send(SenderTo &subject, const std::string &to_send, + const Address &to_send_address, std::size_t delta_send) { + RollingView buff{to_send}; + buff.forEachView(delta_send, [&](const std::string_view &view) { + bool ok = subject.sendTo(BufferViewConst{view.data(), view.size()}, + to_send_address); + if (!ok) { + throw std::runtime_error{"wasn't able to send all data"}; + } + return delta_send; + }); +} + +std::string sliced_receive(Receiver &subject, std::size_t to_receive, + std::size_t delta_receive) { + RollingView buff{to_receive}; + buff.forEachView(delta_receive, [&](const std::string_view &view) { + BufferView buff{const_cast(view.data()), view.size()}; + return subject.receive(buff); + }); + return buff.getBuffer(); +} + +std::string sliced_receive(ReceiverUnkownSender &subject, + std::size_t to_receive, std::size_t delta_receive) { + RollingView buff{to_receive}; + buff.forEachView(delta_receive, [&](const std::string_view &view) { + BufferView buff{const_cast(view.data()), view.size()}; + auto maybe_bytes_received = subject.receive(buff); + if (!maybe_bytes_received.has_value()) { + throw std::runtime_error{"wasn'table tor receive the data"}; + } + return maybe_bytes_received->received_bytes; + }); + return buff.getBuffer(); +} +} // namespace MinimalSocket::test diff --git a/tests/RollingView.h b/tests/RollingView.h new file mode 100644 index 00000000..ae698b64 --- /dev/null +++ b/tests/RollingView.h @@ -0,0 +1,51 @@ +#pragma once + +#include +#include + +#include +#include + +namespace MinimalSocket::test { +class RollingView { +public: + RollingView(const std::string &buff); + RollingView(std::size_t buff_size); + + template + void forEachView(std::size_t offset, Pred pred) const { + std::size_t view_begin = 0; + std::size_t view_end = std::min(offset, buffer.size()); + while (true) { + std::size_t processed = pred( + std::string_view{buffer.data() + view_begin, view_end - view_begin}); + if (processed != offset) { + throw std::runtime_error{"Wrong number of bytes were processed"}; + } + if (view_end == buffer.size()) { + break; + } + view_begin = view_end; + view_end += offset; + view_end = std::min(view_end, buffer.size()); + } + } + + const auto &getBuffer() const { return buffer; } + +private: + std::string buffer; +}; + +void sliced_send(Sender &subject, const std::string &to_send, + std::size_t delta_send); + +void sliced_send(SenderTo &subject, const std::string &to_send, + const Address &to_send_address, std::size_t delta_send); + +std::string sliced_receive(Receiver &subject, std::size_t to_receive, + std::size_t delta_receive); + +std::string sliced_receive(ReceiverUnkownSender &subject, + std::size_t to_receive, std::size_t delta_receive); +} // namespace MinimalSocket::test diff --git a/tests/SlicedOps.cpp b/tests/SlicedOps.cpp deleted file mode 100644 index 2ae73b1a..00000000 --- a/tests/SlicedOps.cpp +++ /dev/null @@ -1,79 +0,0 @@ -#include "SlicedOps.h" - -namespace MinimalSocket::test { -MovingPointerBuffer::MovingPointerBuffer(const std::string &buff) - : buffer(buff) { - init(); -} - -MovingPointerBuffer::MovingPointerBuffer(const std::size_t buff_size) { - buffer.resize(buff_size); - init(); -} - -std::size_t MovingPointerBuffer::remainingBytes() const { - return buffer.size() - buffer_cursor; -} - -void MovingPointerBuffer::shift(const std::size_t stride) { - buffer_cursor += stride; - buffer_pointer += stride; -} - -void MovingPointerBuffer::init() { - buffer_cursor = 0; - buffer_pointer = buffer.data(); -} - -void sliced_send(Sender &subject, const std::string &to_send, - const std::size_t delta_send) { - MovingPointerBuffer buffer(to_send); - while (buffer.remainingBytes() != 0) { - std::size_t bytes_to_send = - std::min(delta_send, buffer.remainingBytes()); - subject.send(BufferViewConst{buffer.data(), bytes_to_send}); - buffer.shift(bytes_to_send); - } -} - -void sliced_send(SenderTo &subject, const std::string &to_send, - const Address &to_send_address, const std::size_t delta_send) { - MovingPointerBuffer buffer(to_send); - while (buffer.remainingBytes() != 0) { - std::size_t bytes_to_send = - std::min(delta_send, buffer.remainingBytes()); - subject.sendTo(BufferViewConst{buffer.data(), bytes_to_send}, - to_send_address); - buffer.shift(bytes_to_send); - } -} - -std::string sliced_receive(Receiver &subject, const std::size_t to_receive, - const std::size_t delta_receive) { - MovingPointerBuffer buffer(to_receive); - while (buffer.remainingBytes() != 0) { - std::size_t bytes_to_receive = - std::min(delta_receive, buffer.remainingBytes()); - auto bytes_received = - subject.receive(BufferView{buffer.data(), bytes_to_receive}); - buffer.shift(bytes_received); - } - return buffer.asString(); -} - -std::string sliced_receive(ReceiverUnkownSender &subject, - const std::size_t to_receive, - const std::size_t delta_receive) { - MovingPointerBuffer buffer(to_receive); - while (buffer.remainingBytes() != 0) { - std::size_t bytes_to_receive = - std::min(delta_receive, buffer.remainingBytes()); - auto maybe_bytes_received = - subject.receive(BufferView{buffer.data(), bytes_to_receive}); - if (maybe_bytes_received) { - buffer.shift(maybe_bytes_received->received_bytes); - } - } - return buffer.asString(); -} -} // namespace MinimalSocket::test diff --git a/tests/SlicedOps.h b/tests/SlicedOps.h deleted file mode 100644 index 4b15de0f..00000000 --- a/tests/SlicedOps.h +++ /dev/null @@ -1,39 +0,0 @@ -#pragma once - -#include -#include - -namespace MinimalSocket::test { -class MovingPointerBuffer { -public: - MovingPointerBuffer(const std::string &buff); - MovingPointerBuffer(const std::size_t buff_size); - - const std::string &asString() const { return buffer; } - char *data() { return buffer_pointer; } - const char *data() const { return buffer_pointer; } - - std::size_t remainingBytes() const; - void shift(const std::size_t stride); - -private: - void init(); - - std::string buffer; - std::size_t buffer_cursor; - char *buffer_pointer; -}; - -void sliced_send(Sender &subject, const std::string &to_send, - const std::size_t delta_send); - -void sliced_send(SenderTo &subject, const std::string &to_send, - const Address &to_send_address, const std::size_t delta_send); - -std::string sliced_receive(Receiver &subject, const std::size_t to_receive, - const std::size_t delta_receive); - -std::string sliced_receive(ReceiverUnkownSender &subject, - const std::size_t to_receive, - const std::size_t delta_receive); -} // namespace MinimalSocket::test diff --git a/tests/TestRobustness.cpp b/tests/TestRobustness.cpp index ab0b1713..4b26bcf2 100644 --- a/tests/TestRobustness.cpp +++ b/tests/TestRobustness.cpp @@ -32,37 +32,38 @@ template void close(SocketT &subject) { } // namespace TEST_CASE("Thread safe d'tor tcp case", "[robustness]") { - const auto port = PortFactory::makePort(); + const auto port = PortFactory::get().makePort(); const auto family = GENERATE(AddressFamily::IP_V4, AddressFamily::IP_V6); SECTION("on connected sockets") { test::TcpPeers peers(port, family); + auto [server_side, client_side] = peers.get(); SECTION("close client while receiving") { ParallelSection::biSection( - [&peers](auto &) { - CHECK(peers.getClientSide().receive(500).empty()); + [&client_side = client_side](auto &) { + CHECK(client_side->receive(500).empty()); }, - [&peers](auto &) { + [&client_side = client_side](auto &) { std::this_thread::sleep_for(std::chrono::milliseconds{50}); - close(peers.getClientSide()); + close(*client_side); }); } SECTION("close server side while receiving") { ParallelSection::biSection( - [&peers](auto &) { - CHECK(peers.getClientSide().receive(500).empty()); + [&server_side = server_side](auto &) { + CHECK(server_side->receive(500).empty()); }, - [&peers](auto &) { + [&server_side = server_side](auto &) { std::this_thread::sleep_for(std::chrono::milliseconds{50}); - close(peers.getServerSide()); + close(*server_side); }); } } SECTION("close while accepting client") { - tcp::TcpServer server(port, family); + tcp::TcpServer server(port, family); REQUIRE(server.open()); ParallelSection::biSection( [&server](auto &) { @@ -76,21 +77,20 @@ TEST_CASE("Thread safe d'tor tcp case", "[robustness]") { } TEST_CASE("Receive from multiple threads tcp case", "[robustness]") { - const auto port = PortFactory::makePort(); + const auto port = PortFactory::get().makePort(); const auto family = GENERATE(AddressFamily::IP_V4, AddressFamily::IP_V6); test::TcpPeers peers(port, family); - auto &server_side = peers.getServerSide(); - auto &client_side = peers.getClientSide(); + auto [server_side, client_side] = peers.get(); const std::size_t threads = 3; ParallelSection sections; - sections.add([&](auto &) { - client_side.send(make_repeated_message(MESSAGE, threads)); + sections.add([&client_side = client_side](auto &) { + client_side->send(make_repeated_message(MESSAGE, threads)); }); for (std::size_t t = 0; t < threads; ++t) { - sections.add([&](auto &) { - const auto received_request = server_side.receive(MESSAGE.size()); + sections.add([&server_side = server_side](auto &) { + const auto received_request = server_side->receive(MESSAGE.size()); CHECK(received_request == MESSAGE); }); } @@ -98,21 +98,21 @@ TEST_CASE("Receive from multiple threads tcp case", "[robustness]") { } TEST_CASE("Send from multiple threads tcp case", "[robustness]") { - const auto port = PortFactory::makePort(); + const auto port = PortFactory::get().makePort(); const auto family = GENERATE(AddressFamily::IP_V4, AddressFamily::IP_V6); test::TcpPeers peers(port, family); - auto &server_side = peers.getServerSide(); - auto &client_side = peers.getClientSide(); + auto [server_side, client_side] = peers.get(); const std::size_t threads = 3; ParallelSection sections; for (std::size_t t = 0; t < threads; ++t) { - sections.add([&](auto &) { client_side.send(MESSAGE); }); + sections.add( + [&client_side = client_side](auto &) { client_side->send(MESSAGE); }); } - sections.add([&](auto &) { + sections.add([&server_side = server_side](auto &) { for (std::size_t t = 0; t < threads; ++t) { - const auto received_request = server_side.receive(MESSAGE.size()); + const auto received_request = server_side->receive(MESSAGE.size()); CHECK(received_request == MESSAGE); } }); @@ -122,7 +122,7 @@ TEST_CASE("Send from multiple threads tcp case", "[robustness]") { TEST_CASE("Thread safe d'tor udp case", "[robustness]") { const auto family = GENERATE(AddressFamily::IP_V4, AddressFamily::IP_V6); - udp::UdpBinded connection(PortFactory::makePort()); + udp::Udp connection(PortFactory::get().makePort()); ParallelSection::biSection( [&](auto &) { CHECK_THROWS_AS(connection.receive(500), Error); }, @@ -132,25 +132,23 @@ TEST_CASE("Thread safe d'tor udp case", "[robustness]") { }); } -/* - TEST_CASE("Receive from multiple threads udp case", "[robustness]") { const auto family = GENERATE(AddressFamily::IP_V4, AddressFamily::IP_V6); - UDP_PEERS(PortFactory::makePort(), PortFactory::makePort(), family) + UDP_PEERS(udp::Udp, family); const std::size_t threads = 3; ParallelSection sections; sections.add([&](Barrier &br) { for (std::size_t t = 0; t < threads; ++t) { - requester.sendTo(MESSAGE, responder_address); + requester->sendTo(MESSAGE, responder_address); } br.arrive_and_wait(); }); for (std::size_t t = 0; t < threads; ++t) { sections.add([&](Barrier &br) { br.arrive_and_wait(); - const auto received_request = responder.receive(MESSAGE.size()); + const auto received_request = responder->receive(MESSAGE.size()); CHECK(received_request); CHECK(received_request->received_message == MESSAGE); }); @@ -161,39 +159,38 @@ TEST_CASE("Receive from multiple threads udp case", "[robustness]") { TEST_CASE("Send from multiple threads udp case", "[robustness]") { const auto family = GENERATE(AddressFamily::IP_V4, AddressFamily::IP_V6); - UDP_PEERS(PortFactory::makePort(), PortFactory::makePort(), family) + UDP_PEERS(udp::Udp, family); const std::size_t threads = 3; ParallelSection sections; for (std::size_t t = 0; t < threads; ++t) { sections.add([&](Barrier &br) { - requester.sendTo(MESSAGE, responder_address); + requester->sendTo(MESSAGE, responder_address); br.arrive_and_wait(); }); } sections.add([&](Barrier &br) { br.arrive_and_wait(); for (std::size_t t = 0; t < threads; ++t) { - const auto received_request = responder.receive(MESSAGE.size()); + const auto received_request = responder->receive(MESSAGE.size()); CHECK(received_request); CHECK(received_request->received_message == MESSAGE); } }); sections.run(); } -*/ TEST_CASE("Use tcp socket before opening it", "[robustness]") { - const auto port = PortFactory::makePort(); + const auto port = PortFactory::get().makePort(); const auto family = GENERATE(AddressFamily::IP_V4, AddressFamily::IP_V6); SECTION("server") { - tcp::TcpServer socket(port, family); + tcp::TcpServer socket(port, family); CHECK_THROWS_AS(socket.acceptNewClient(), Error); } SECTION("client") { - tcp::TcpClient socket(Address{port, family}); + tcp::TcpClient socket(Address{port, family}); CHECK_THROWS_AS(socket.receive(500), SocketError); CHECK_THROWS_AS(socket.send("dummy"), SocketError); } @@ -202,10 +199,10 @@ TEST_CASE("Use tcp socket before opening it", "[robustness]") { TEST_CASE("Use udp socket before opening it", "[robustness]") { const auto family = GENERATE(AddressFamily::IP_V4, AddressFamily::IP_V6); - udp::UdpBinded socket(PortFactory::makePort(), family); + udp::Udp socket(PortFactory::get().makePort(), family); CHECK_THROWS_AS(socket.receive(500), SocketError); CHECK_THROWS_AS( - socket.sendTo("dummy", Address{PortFactory::makePort(), family}), + socket.sendTo("dummy", Address{PortFactory::get().makePort(), family}), SocketError); CHECK_THROWS_AS(socket.connect(), SocketError); } diff --git a/tests/TestTCP.cpp b/tests/TestTCP.cpp index f09f7c0a..0db64172 100644 --- a/tests/TestTCP.cpp +++ b/tests/TestTCP.cpp @@ -9,7 +9,7 @@ #include "ConnectionsUtils.h" #include "ParallelSection.h" #include "PortFactory.h" -#include "SlicedOps.h" +#include "RollingView.h" using namespace MinimalSocket; using namespace MinimalSocket::tcp; @@ -21,11 +21,11 @@ static const std::string response = "Welcome"; struct SenderReceiver { Sender &sender; - Receiver &receiver; + Receiver &receiver; }; template SenderReceiver makeSenderReceiver(T &subject) { Sender &as_sender = subject; - Receiver &as_receiver = subject; + Receiver &as_receiver = subject; return SenderReceiver{as_sender, as_receiver}; } @@ -52,12 +52,12 @@ void send_response(const SenderReceiver &requester, } // namespace TEST_CASE("Establish tcp connection", "[tcp]") { - const auto port = PortFactory::makePort(); + const auto port = PortFactory::get().makePort(); const auto family = GENERATE(AddressFamily::IP_V4, AddressFamily::IP_V6); #if !defined(_WIN32) SECTION("expected failure") { - TcpClient client(Address(port, family)); + TcpClient client(Address(port, family)); CHECK_THROWS_AS(client.open(), Error); CHECK_FALSE(client.wasOpened()); } @@ -65,43 +65,40 @@ TEST_CASE("Establish tcp connection", "[tcp]") { SECTION("expected success") { test::TcpPeers peers(port, family); - auto &server_side = peers.getServerSide(); - auto &client_side = peers.getClientSide(); + auto [server_side, client_side] = peers.get(); - REQUIRE(client_side.wasOpened()); + REQUIRE(client_side->wasOpened()); const std::size_t cycles = 5; - const std::string request = "Hello"; - const std::string response = "Welcome"; SECTION("client send, server respond") { - send_response(makeSenderReceiver(client_side), - makeSenderReceiver(server_side)); + send_response(makeSenderReceiver(*client_side), + makeSenderReceiver(*server_side)); } SECTION("server send, client respond") { - send_response(makeSenderReceiver(server_side), - makeSenderReceiver(client_side)); + send_response(makeSenderReceiver(*server_side), + makeSenderReceiver(*client_side)); } SECTION("receive with timeout") { const auto timeout = Timeout{500}; SECTION("expect fail within timeout") { - auto received_request = server_side.receive(request.size(), timeout); + auto received_request = server_side->receive(request.size(), timeout); CHECK(received_request.empty()); } SECTION("expect success within timeout") { const auto wait = Timeout{250}; ParallelSection::biSection( - [&](auto &) { + [&, client_side = client_side](auto &) { std::this_thread::sleep_for(wait); - client_side.send(request); + client_side->send(request); }, - [&](auto &) { + [&, server_side = server_side](auto &) { auto received_request = - server_side.receive(request.size(), timeout); + server_side->receive(request.size(), timeout); CHECK(received_request == request); }); } @@ -110,17 +107,17 @@ TEST_CASE("Establish tcp connection", "[tcp]") { } TEST_CASE("Establish many tcp connections to same server", "[tcp]") { - const auto port = PortFactory::makePort(); + const auto port = PortFactory::get().makePort(); const auto family = GENERATE(AddressFamily::IP_V4, AddressFamily::IP_V6); - TcpServer server(port, family); + TcpServer server(port, family); server.open(); const std::size_t clients_numb = 5; SECTION("sequencial connnections") { - std::list accepted_clients; - std::list clients; + std::list accepted_clients; + std::list> clients; ParallelSection::biSection( [&](auto &) { for (std::size_t c = 0; c < clients_numb; ++c) { @@ -143,7 +140,7 @@ TEST_CASE("Establish many tcp connections to same server", "[tcp]") { } }); Task ask_connection = [&](auto &) { - TcpClient client(Address(port, family)); + TcpClient client(Address(port, family)); CHECK(client.open()); }; for (std::size_t c = 0; c < clients_numb; ++c) { @@ -154,15 +151,15 @@ TEST_CASE("Establish many tcp connections to same server", "[tcp]") { } TEST_CASE("Open multiple times tcp clients", "[tcp]") { - const auto port = PortFactory::makePort(); + const auto port = PortFactory::get().makePort(); const auto family = GENERATE(AddressFamily::IP_V4, AddressFamily::IP_V6); - TcpServer server(port, family); + TcpServer server(port, family); server.open(); std::size_t cycles = 5; - TcpClient client(Address(port, family)); + TcpClient client(Address(port, family)); for (std::size_t c = 0; c < cycles; ++c) { ParallelSection::biSection([&](auto &) { server.acceptNewClient(); }, @@ -175,33 +172,32 @@ TEST_CASE("Open multiple times tcp clients", "[tcp]") { } TEST_CASE("Open tcp client with timeout", "[tcp]") { - const auto port = PortFactory::makePort(); + const auto port = PortFactory::get().makePort(); const auto family = GENERATE(AddressFamily::IP_V4, AddressFamily::IP_V6); const auto timeout = Timeout{500}; - TcpClient client(Address(port, family)); + TcpClient client(Address(port, family)); SECTION("expect fail within timeout") { #ifdef _WIN32 CHECK_FALSE(client.open(timeout)); #else - CHECK_THROWS_AS( - client.open(timeout), - Error); // linux throw if no server tcp were previously created, while - // windows seems to does not have this check + // linux throw if no server tcp were previously created, while windows seems + // to does not have this check + CHECK_THROWS_AS(client.open(timeout), Error); #endif CHECK_FALSE(client.wasOpened()); } SECTION("expect success within timeout") { const auto wait = Timeout{250}; - TcpServer server(port, family); + TcpServer server(port, family); REQUIRE(server.open()); ParallelSection::biSection( [&](auto &) { std::this_thread::sleep_for(wait); - TcpConnection conn = server.acceptNewClient(); + auto conn = server.acceptNewClient(); auto received_request = conn.receive(request.size()); CHECK(received_request == request); }, @@ -215,7 +211,7 @@ TEST_CASE("Open tcp client with timeout", "[tcp]") { TEST_CASE("Reserve random port for tcp server", "[tcp]") { const auto family = GENERATE(AddressFamily::IP_V4, AddressFamily::IP_V6); - TcpServer server(ANY_PORT, family); + TcpServer server(ANY_PORT, family); REQUIRE(server.open()); const auto port = server.getPortToBind(); REQUIRE(port != 0); @@ -229,7 +225,7 @@ TEST_CASE("Reserve random port for tcp server", "[tcp]") { }, [&](Barrier &br) { // client - TcpClient client(Address(port, family)); + TcpClient client(Address(port, family)); br.arrive_and_wait(); REQUIRE(client.open()); REQUIRE(client.wasOpened()); @@ -237,109 +233,97 @@ TEST_CASE("Reserve random port for tcp server", "[tcp]") { }); } -TEST_CASE("Accept client with timeout", "[tcp]") { +TEST_CASE("Send Receive messages split into multiple pieces (tcp)", "[tcp]") { + const auto port = PortFactory::get().makePort(); const auto family = GENERATE(AddressFamily::IP_V4, AddressFamily::IP_V6); - const auto port = PortFactory::makePort(); - - TcpServer server(port, family); - REQUIRE(server.open()); - const auto server_address = Address(port, family); - - const auto timeout = Timeout{500}; - SECTION("expect fail within timeout") { - // connect first client - TcpClient client_first = TcpClient{server_address}; - std::unique_ptr server_side_first; - ParallelSection::biSection([&](auto &) { CHECK(client_first.open()); }, - [&](auto &) { - auto accepted = server.acceptNewClient(); - server_side_first = - std::make_unique( - std::move(accepted)); - }); + TcpPeers peers(port, family); + auto [server_side, client_side] = peers.get(); - // expect second accept to fail - CHECK_FALSE(server.acceptNewClient(timeout)); - CHECK(server.wasOpened()); + const std::string request = "This is a simulated long message"; - // check first accepted connection is still valid - ParallelSection::biSection( - [&](auto &) { - auto received_request = server_side_first->receive(request.size()); - CHECK(received_request == request); - }, - [&](auto &) { - // client - client_first.send(request); - }); + const std::size_t delta = 4; - // connect second client after accept unsuccess and check they can exchange - // messages + SECTION("split receive") { ParallelSection::biSection( - [&](Barrier &br) { - TcpClient client_second = TcpClient{server_address}; - br.arrive_and_wait(); - CHECK(client_second.open()); - client_second.send(request); - }, - [&](Barrier &br) { - br.arrive_and_wait(); - auto server_side_second = server.acceptNewClient(); - auto received_request = server_side_second.receive(request.size()); + [&, client_side = client_side](auto &) { client_side->send(request); }, + [&, server_side = server_side](auto &) { + auto received_request = + sliced_receive(*server_side, request.size(), 4); CHECK(received_request == request); }); } - SECTION("expect success within timeout") { - const auto wait = Timeout{250}; + SECTION("split send") { ParallelSection::biSection( - [&](Barrier &br) { - TcpClient client = TcpClient{server_address}; + [&, client_side = client_side](Barrier &br) { + sliced_send(*client_side, request, 4); br.arrive_and_wait(); - std::this_thread::sleep_for(wait); - CHECK(client.open()); }, - [&](Barrier &br) { + [&, server_side = server_side](Barrier &br) { br.arrive_and_wait(); - CHECK(server.acceptNewClient(timeout)); + auto received_request = server_side->receive(request.size()); + CHECK(received_request == request); }); } } -#if !defined(__APPLE__) -TEST_CASE("Send Receive messages split into multiple pieces (tcp)", "[tcp]") { - const auto port = PortFactory::makePort(); +TEST_CASE("Establish tcp connection non blocking", "[tcp]") { + const auto port = PortFactory::get().makePort(); const auto family = GENERATE(AddressFamily::IP_V4, AddressFamily::IP_V6); - TcpPeers peers(port, family); - auto &server_side = peers.getServerSide(); - auto &client_side = peers.getClientSide(); + tcp::TcpServer server{port, family}; + REQUIRE(server.open()); - const std::string request = "This is a simulated long message"; + ParallelSection::biSection( + [&](Barrier &br) { + CHECK_FALSE(server.acceptNewClient().has_value()); + br.arrive_and_wait(); + std::this_thread::sleep_for(std::chrono::milliseconds{500}); + auto accepted = server.acceptNewClient(); + REQUIRE(accepted.has_value()); + auto received_request = accepted->receive(request.size()); + CHECK(received_request == request); + }, + [&](Barrier &br) { + br.arrive_and_wait(); + TcpClient client{port}; + client.open(); + REQUIRE(client.wasOpened()); + client.send(request); + }); +} - const std::size_t delta = 4; +TEST_CASE("Receive non blocking (tcp)", "[tcp]") { + const auto port = PortFactory::get().makePort(); + const auto family = GENERATE(AddressFamily::IP_V4, AddressFamily::IP_V6); - SECTION("split receive") { - ParallelSection::biSection([&](auto &) { client_side.send(request); }, - [&](auto &) { - auto received_request = sliced_receive( - server_side, request.size(), 4); - CHECK(received_request == request); - }); + std::optional server_side; + tcp::TcpClient client_side{port}; + ParallelSection::biSection( + [&](Barrier &br) { + tcp::TcpServer server{port, family}; + REQUIRE(server.open()); + br.arrive_and_wait(); + auto accepted = server.acceptNewClient(); + server_side.emplace(accepted.turnToNonBlocking()); + }, + [&](Barrier &br) { + br.arrive_and_wait(); + REQUIRE(client_side.open()); + }); + + SECTION("client side non blocking receive") { + CHECK(client_side.receive(request.size()).empty()); + server_side->send(request); + auto received_request = client_side.receive(request.size()); + CHECK(received_request == request); } - SECTION("split send") { - ParallelSection::biSection( - [&](Barrier &br) { - sliced_send(client_side, request, 4); - br.arrive_and_wait(); - }, - [&](Barrier &br) { - br.arrive_and_wait(); - auto received_request = server_side.receive(request.size()); - CHECK(received_request == request); - }); + SECTION("server side non blocking receive") { + CHECK(server_side->receive(request.size()).empty()); + client_side.send(request); + auto received_request = server_side->receive(request.size()); + CHECK(received_request == request); } } -#endif diff --git a/tests/TestUDP.cpp b/tests/TestUDP.cpp index c393ab43..bf010c6b 100644 --- a/tests/TestUDP.cpp +++ b/tests/TestUDP.cpp @@ -6,7 +6,7 @@ #include "ConnectionsUtils.h" #include "ParallelSection.h" #include "PortFactory.h" -#include "SlicedOps.h" +#include "RollingView.h" using namespace MinimalSocket; using namespace MinimalSocket::udp; @@ -20,21 +20,22 @@ bool are_same(const Address &a, const Address &b, const AddressFamily &family) { return (family == AddressFamily::IP_V4) ? (a == b) : (a.getPort() == b.getPort()); } -} // namespace + +}; // namespace TEST_CASE("Exchange messages between UdpBinded and UdpBinded", "[udp]") { const auto family = GENERATE(AddressFamily::IP_V4, AddressFamily::IP_V6); const std::size_t cycles = 5; - UDP_PEERS(PortFactory::makePort(), PortFactory::makePort(), family); + UDP_PEERS(udp::Udp, family); ParallelSection::biSection( [&](Barrier &br) { for (std::size_t c = 0; c < cycles; ++c) { - CHECK(requester.sendTo(request, responder_address)); + CHECK(requester->sendTo(request, responder_address)); br.arrive_and_wait(); br.arrive_and_wait(); - auto received_response = requester.receive(response.size()); + auto received_response = requester->receive(response.size()); REQUIRE(received_response); CHECK(received_response->received_message == response); CHECK(are_same(received_response->sender, responder_address, family)); @@ -43,11 +44,11 @@ TEST_CASE("Exchange messages between UdpBinded and UdpBinded", "[udp]") { [&](Barrier &br) { for (std::size_t c = 0; c < cycles; ++c) { br.arrive_and_wait(); - auto received_request = responder.receive(request.size()); + auto received_request = responder->receive(request.size()); REQUIRE(received_request); CHECK(received_request->received_message == request); CHECK(are_same(received_request->sender, requester_address, family)); - responder.sendTo(response, requester_address); + responder->sendTo(response, requester_address); br.arrive_and_wait(); } }); @@ -56,7 +57,7 @@ TEST_CASE("Exchange messages between UdpBinded and UdpBinded", "[udp]") { const auto timeout = Timeout{500}; SECTION("expect fail within timeout") { - auto received_request = responder.receive(request.size(), timeout); + auto received_request = responder->receive(request.size(), timeout); CHECK_FALSE(received_request); } @@ -65,10 +66,10 @@ TEST_CASE("Exchange messages between UdpBinded and UdpBinded", "[udp]") { ParallelSection::biSection( [&](auto &) { std::this_thread::sleep_for(wait); - requester.sendTo(request, responder_address); + requester->sendTo(request, responder_address); }, [&](auto &) { - auto received_request = responder.receive(request.size(), timeout); + auto received_request = responder->receive(request.size(), timeout); REQUIRE(received_request); CHECK(received_request->received_message == request); CHECK( @@ -82,33 +83,24 @@ TEST_CASE("Exchange messages between UdpConnected and UdpConnected", "[udp]") { const auto family = GENERATE(AddressFamily::IP_V4, AddressFamily::IP_V6); const std::size_t cycles = 5; - const auto requester_port = PortFactory::makePort(); - const Address requester_address = Address(requester_port, family); - - const auto responder_port = PortFactory::makePort(); - const Address responder_address = Address(responder_port, family); - - UdpConnected requester(responder_address, requester_port); - REQUIRE(requester.open()); - UdpConnected responder(requester_address, responder_port); - REQUIRE(responder.open()); + UDP_PEERS(udp::UdpConnected, family); ParallelSection::biSection( [&](Barrier &br) { for (std::size_t c = 0; c < cycles; ++c) { - CHECK(requester.send(request)); + CHECK(requester->send(request)); br.arrive_and_wait(); br.arrive_and_wait(); - auto received_response = requester.receive(response.size()); + auto received_response = requester->receive(response.size()); CHECK(received_response == response); } }, [&](Barrier &br) { for (std::size_t c = 0; c < cycles; ++c) { br.arrive_and_wait(); - auto received_request = responder.receive(request.size()); + auto received_request = responder->receive(request.size()); CHECK(received_request == request); - responder.send(response); + responder->send(response); br.arrive_and_wait(); } }); @@ -117,7 +109,7 @@ TEST_CASE("Exchange messages between UdpConnected and UdpConnected", "[udp]") { const auto timeout = Timeout{500}; SECTION("expect fail within timeout") { - auto received_request = responder.receive(request.size(), timeout); + auto received_request = responder->receive(request.size(), timeout); CHECK(received_request.empty()); } @@ -126,10 +118,10 @@ TEST_CASE("Exchange messages between UdpConnected and UdpConnected", "[udp]") { ParallelSection::biSection( [&](auto &) { std::this_thread::sleep_for(wait); - requester.send(request); + requester->send(request); }, [&](auto &) { - auto received_request = responder.receive(request.size(), timeout); + auto received_request = responder->receive(request.size(), timeout); CHECK(received_request == request); }); } @@ -141,47 +133,38 @@ TEST_CASE( "[udp]") { const auto family = GENERATE(AddressFamily::IP_V4, AddressFamily::IP_V6); - const auto requester_port = PortFactory::makePort(); - const Address requester_address = Address(requester_port, family); - - const auto responder_port = PortFactory::makePort(); - const Address responder_address = Address(responder_port, family); - - UdpConnected requester(responder_address, requester_port); - REQUIRE(requester.open()); - UdpConnected responder(requester_address, responder_port); - REQUIRE(responder.open()); + UDP_PEERS(udp::UdpConnected, family); auto exchange_messages_before = GENERATE(true, false); if (exchange_messages_before) { ParallelSection::biSection( [&](Barrier &br) { - CHECK(requester.send(request)); + CHECK(requester->send(request)); br.arrive_and_wait(); br.arrive_and_wait(); - auto received_response = requester.receive(response.size()); + auto received_response = requester->receive(response.size()); CHECK(received_response == response); }, [&](Barrier &br) { br.arrive_and_wait(); - auto received_request = responder.receive(request.size()); + auto received_request = responder->receive(request.size()); CHECK(received_request == request); - responder.send(response); + responder->send(response); br.arrive_and_wait(); }); } - UdpBinded second_requester(PortFactory::makePort(), family); + udp::Udp second_requester(PortFactory::get().makePort(), family); REQUIRE(second_requester.open()); const auto timeout = Timeout{500}; const auto wait = Timeout{250}; ParallelSection::biSection( [&](auto &) { std::this_thread::sleep_for(wait); - second_requester.sendTo(request, Address(responder_port, family)); + second_requester.sendTo(request, responder_address); }, [&](auto &) { - auto received_request = responder.receive(request.size(), timeout); + auto received_request = responder->receive(request.size(), timeout); CHECK(received_request.empty()); }); } @@ -190,38 +173,26 @@ TEST_CASE("Metamorphosis of udp connections", "[udp]") { const auto family = GENERATE(AddressFamily::IP_V4, AddressFamily::IP_V6); const std::size_t cycles = 5; - const auto requester_port = PortFactory::makePort(); - const Address requester_address = Address(requester_port, family); - const auto responder_port = PortFactory::makePort(); - const Address responder_address = Address(responder_port, family); - - UdpBinded responder(responder_port, family); - REQUIRE(responder.open()); - - std::unique_ptr requester_only_bind = - std::make_unique(requester_port, family); - REQUIRE(requester_only_bind->open()); + UDP_PEERS(udp::Udp, family); // connect requester to responder auto deduce_sender = GENERATE(true, false); - std::unique_ptr requester_connected; + std::optional> requester_connected; if (deduce_sender) { ParallelSection::biSection( [&](Barrier &br) { - responder.sendTo("1", requester_address); + responder->sendTo("1", requester_address); br.arrive_and_wait(); }, [&](Barrier &br) { br.arrive_and_wait(); - auto socket_connected = requester_only_bind->connect(); + auto socket_connected = requester->connect(); CHECK(are_same(socket_connected.getRemoteAddress(), responder_address, family)); - requester_connected = - std::make_unique(std::move(socket_connected)); + requester_connected.emplace(std::move(socket_connected)); }); } else { - requester_connected = std::make_unique( - requester_only_bind->connect(responder_address)); + requester_connected.emplace(requester->connect(responder_address)); } REQUIRE(requester_connected->wasOpened()); @@ -240,28 +211,26 @@ TEST_CASE("Metamorphosis of udp connections", "[udp]") { [&](Barrier &br) { for (std::size_t c = 0; c < cycles; ++c) { br.arrive_and_wait(); - auto received_request = responder.receive(request.size()); + auto received_request = responder->receive(request.size()); REQUIRE(received_request); CHECK(received_request->received_message == request); - responder.sendTo(response, requester_address); + responder->sendTo(response, requester_address); br.arrive_and_wait(); } }); // try to disconnect requester - requester_only_bind = - std::make_unique(requester_connected->disconnect()); - REQUIRE(requester_only_bind->wasOpened()); + *requester = requester_connected->disconnect(); + REQUIRE(requester->wasOpened()); // try message exchange ParallelSection::biSection( [&](Barrier &br) { for (std::size_t c = 0; c < cycles; ++c) { - CHECK(requester_only_bind->sendTo(request, responder_address)); + CHECK(requester->sendTo(request, responder_address)); br.arrive_and_wait(); br.arrive_and_wait(); - auto received_response = - requester_only_bind->receive(response.size()); + auto received_response = requester->receive(response.size()); REQUIRE(received_response); CHECK(received_response->received_message == response); } @@ -269,10 +238,10 @@ TEST_CASE("Metamorphosis of udp connections", "[udp]") { [&](Barrier &br) { for (std::size_t c = 0; c < cycles; ++c) { br.arrive_and_wait(); - auto received_request = responder.receive(request.size()); + auto received_request = responder->receive(request.size()); REQUIRE(received_request); CHECK(received_request->received_message == request); - responder.sendTo(response, requester_address); + responder->sendTo(response, requester_address); br.arrive_and_wait(); } }); @@ -281,12 +250,12 @@ TEST_CASE("Metamorphosis of udp connections", "[udp]") { TEST_CASE("Open connection with timeout", "[udp]") { const auto family = GENERATE(AddressFamily::IP_V4, AddressFamily::IP_V6); - UDP_PEERS(PortFactory::makePort(), PortFactory::makePort(), family); + UDP_PEERS(udp::Udp, family); const auto timeout = Timeout{500}; SECTION("expect fail within timeout") { - CHECK_FALSE(requester.connect(timeout)); + CHECK_FALSE(requester->connect(timeout)); } SECTION("expect success within timeout") { @@ -294,10 +263,10 @@ TEST_CASE("Open connection with timeout", "[udp]") { ParallelSection::biSection( [&](auto &) { std::this_thread::sleep_for(wait); - responder.sendTo("1", requester_address); + responder->sendTo("1", requester_address); }, [&](auto &) { - auto connected_result = requester.connect(timeout); + auto connected_result = requester->connect(timeout); REQUIRE(connected_result); CHECK(are_same(connected_result->getRemoteAddress(), responder_address, family)); @@ -309,13 +278,13 @@ TEST_CASE("Reserve random port for udp connection", "[udp]") { const auto family = GENERATE(AddressFamily::IP_V4, AddressFamily::IP_V6); auto requester_port = ANY_PORT; - UdpBinded requester(requester_port, family); + udp::Udp requester(requester_port, family); REQUIRE(requester.open()); requester_port = requester.getPortToBind(); const Address requester_address = Address(requester_port, family); - auto responder_port = GENERATE(PortFactory::makePort(), ANY_PORT); - UdpBinded responder(responder_port, family); + auto responder_port = GENERATE(PortFactory::get().makePort(), ANY_PORT); + udp::Udp responder(responder_port, family); REQUIRE(responder.open()); responder_port = responder.getPortToBind(); const Address responder_address = Address(responder_port, family); @@ -342,12 +311,10 @@ TEST_CASE("Reserve random port for udp connection", "[udp]") { } /* - -TEST_CASE("Send Receive messages split into multiple pieces (udp)", - "[udp][!mayfail]") { +TEST_CASE("Send Receive messages split into multiple pieces (udp)", "[udp]") { const auto family = GENERATE(AddressFamily::IP_V4, AddressFamily::IP_V6); - UDP_PEERS(PortFactory::makePort(), PortFactory::makePort(), family); + UDP_PEERS(udp::Udp, family); const std::string request = "This is a simulated long message"; @@ -357,13 +324,13 @@ TEST_CASE("Send Receive messages split into multiple pieces (udp)", SECTION("split receive") { ParallelSection::biSection( [&](Barrier &br) { - requester.sendTo(request, responder_address); + requester->sendTo(request, responder_address); br.arrive_and_wait(); }, [&](Barrier &br) { br.arrive_and_wait(); auto received_request = - sliced_receive(responder, request.size(), 4); + sliced_receive(*responder, request.size(), 4); CHECK(received_request == request); }); } @@ -371,12 +338,12 @@ TEST_CASE("Send Receive messages split into multiple pieces (udp)", SECTION("split send") { ParallelSection::biSection( [&](Barrier &br) { - sliced_send(requester, request, responder_address, 4); + sliced_send(*requester, request, responder_address, 4); br.arrive_and_wait(); }, [&](Barrier &br) { br.arrive_and_wait(); - auto received_request = responder.receive(request.size()); + auto received_request = responder->receive(request.size()); CHECK(received_request); CHECK(received_request->received_message == request); }); @@ -384,8 +351,8 @@ TEST_CASE("Send Receive messages split into multiple pieces (udp)", } SECTION("connected") { - auto requester_conn = requester.connect(responder_address); - auto responder_conn = responder.connect(requester_address); + auto requester_conn = requester->connect(responder_address); + auto responder_conn = responder->connect(requester_address); SECTION("split receive") { ParallelSection::biSection( [&](Barrier &br) { @@ -414,5 +381,30 @@ TEST_CASE("Send Receive messages split into multiple pieces (udp)", } } } - */ + +TEST_CASE("Receive from unknown non blocking", "[udp]") { + const auto port = PortFactory::get().makePort(); + const auto family = GENERATE(AddressFamily::IP_V4, AddressFamily::IP_V6); + + UDP_PEERS(udp::Udp, family); + + CHECK_FALSE(responder->receive(request.size()).has_value()); + requester->sendTo(request, responder_address); + auto received_request = responder->receive(request.size()); + REQUIRE(received_request.has_value()); + CHECK(received_request->received_message == request); + CHECK(received_request->sender == requester_address); +} + +TEST_CASE("Receive non blocking (udp)", "[udp]") { + const auto port = PortFactory::get().makePort(); + const auto family = GENERATE(AddressFamily::IP_V4, AddressFamily::IP_V6); + + UDP_PEERS(udp::UdpConnected, family); + + CHECK(responder->receive(request.size()).empty()); + requester->send(request); + auto received_request = responder->receive(request.size()); + CHECK(received_request == request); +}