Skip to content

Commit

Permalink
Posixify link reciever
Browse files Browse the repository at this point in the history
  • Loading branch information
madhurajayaraman committed Oct 8, 2024
1 parent c1d6aae commit db2a599
Showing 1 changed file with 97 additions and 72 deletions.
169 changes: 97 additions & 72 deletions starboard/shared/starboard/link_receiver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

#include "starboard/shared/starboard/link_receiver.h"

#include <netinet/in.h>
#include <sys/socket.h>
#include <atomic>
#include <memory>
#include <string>
Expand All @@ -35,81 +37,105 @@ namespace shared {
namespace starboard {

namespace {

#if defined(SOMAXCONN)
const int kMaxConn = SOMAXCONN;
#else
const int kMaxConn = 128;
#endif

// Creates a socket that is appropriate for binding and listening, but is not
// bound and hasn't started listening yet.
std::unique_ptr<Socket> CreateServerSocket(SbSocketAddressType address_type) {
std::unique_ptr<Socket> socket(new Socket(address_type));
if (!socket->IsValid()) {
int CreateServerSocket(SbSocketAddressType address_type) {
int socket_fd = -1;
switch (address_type) {
case kSbSocketAddressTypeIpv4:
socket_fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
break;
case kSbSocketAddressTypeIpv6:
socket_fd = socket(AF_INET6, SOCK_STREAM, IPPROTO_TCP);
break;
default:
break;
}
if (socket_fd < 0) {
SB_LOG(ERROR) << __FUNCTION__ << ": "
<< "SbSocketCreate failed";
return std::unique_ptr<Socket>();
<< "Socket create failed";
return -1;
}

if (!socket->SetReuseAddress(true)) {
int on = 1;
if (setsockopt(socket_fd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) != 0) {
SB_LOG(ERROR) << __FUNCTION__ << ": "
<< "SbSocketSetReuseAddress failed";
return std::unique_ptr<Socket>();
return -1;
}

return socket;
return socket_fd;
}

// Creates a server socket that is bound to the loopback interface.
std::unique_ptr<Socket> CreateLocallyBoundSocket(
SbSocketAddressType address_type,
int port) {
std::unique_ptr<Socket> socket = CreateServerSocket(address_type);
if (!socket) {
return std::unique_ptr<Socket>();
int CreateLocallyBoundSocket(SbSocketAddressType address_type, int port) {
int socket = CreateServerSocket(address_type);
if (socket < 0) {
return -1;
}

SbSocketAddress address = {};
bool success = GetLocalhostAddress(address_type, port, &address);
if (!success) {
socklen_t socklen;
struct sockaddr_in addr_in = {0};
int local_add_result =
getsockname(socket, reinterpret_cast<sockaddr*>(&addr_in), &socklen);

if (local_add_result < 0) {
SB_LOG(ERROR) << "GetLocalhostAddress failed";
return std::unique_ptr<Socket>();
return -1;
}
SbSocketError result = socket->Bind(&address);
if (result != kSbSocketOk) {

int bind_result =
bind(socket, reinterpret_cast<sockaddr*>(&addr_in), sizeof(sockaddr));

if (bind_result != 0) {
SB_LOG(ERROR) << __FUNCTION__ << ": "
<< "SbSocketBind to " << port << " failed: " << result;
return std::unique_ptr<Socket>();
<< "SbSocketBind to " << port << " failed: " << bind_result;
return -1;
}

return socket;
}

// Creates a server socket that is bound and listening to the loopback interface
// on the given port.
std::unique_ptr<Socket> CreateListeningSocket(SbSocketAddressType address_type,
int port) {
std::unique_ptr<Socket> socket = CreateLocallyBoundSocket(address_type, port);
if (!socket) {
return std::unique_ptr<Socket>();
int CreateListeningSocket(SbSocketAddressType address_type, int port) {
int socket = CreateLocallyBoundSocket(address_type, port);
if (socket < 0) {
return -1;
}

SbSocketError result = socket->Listen();
if (result != kSbSocketOk) {
int listen_result = listen(socket, kMaxConn);
if (listen_result != 0) {
SB_LOG(ERROR) << __FUNCTION__ << ": "
<< "SbSocketListen failed: " << result;
return std::unique_ptr<Socket>();
<< "SbSocketListen failed: " << listen_result;
return -1;
}

return socket;
}

// Gets the port socket is bound to.
bool GetBoundPort(Socket* socket, int* out_port) {
bool GetBoundPort(int socket, int* out_port) {
SB_DCHECK(out_port);
SB_DCHECK(socket);
SB_DCHECK(socket >= 0);

SbSocketAddress socket_address = {0};
bool result = socket->GetLocalAddress(&socket_address);
if (!result) {
socklen_t socklen;
struct sockaddr_in socket_address = {0};
int local_address = getsockname(
socket, reinterpret_cast<sockaddr*>(&socket_address), &socklen);

if (local_address != 0) {
return false;
}

*out_port = socket_address.port;
*out_port = socket_address.sin_port;
return true;
}

Expand Down Expand Up @@ -157,8 +183,7 @@ class LinkReceiver::Impl {
private:
// Encapsulates connection state.
struct Connection {
explicit Connection(std::unique_ptr<Socket> socket)
: socket(std::move(socket)) {}
explicit Connection(int socket_fd) {}
~Connection() {}
void FlushLink(Application* application) {
if (!data.empty()) {
Expand All @@ -167,7 +192,7 @@ class LinkReceiver::Impl {
}
}

std::unique_ptr<Socket> socket;
int socket_fd;
std::string data;
};

Expand All @@ -177,17 +202,17 @@ class LinkReceiver::Impl {

// Adds |socket| to the SbSocketWaiter to wait until ready for accepting a new
// connection.
bool AddForAccept(Socket* socket);
bool AddForAccept(int socket_fd);

// Adds the |connection| to the SbSocketWaiter to wait until ready to read
// more data.
bool AddForRead(Connection* connection);
bool AddForRead(Connection* socket);

// Called when the listening socket has a connection available to accept.
void OnAcceptReady();

// Called when the waiter reports that a socket has more data to read.
void OnReadReady(SbSocket sb_socket);
void OnReadReady(int socket_fd);

// Called when the waiter reports that a connection has more data to read.
void OnReadReady(Connection* connection);
Expand All @@ -197,11 +222,11 @@ class LinkReceiver::Impl {

// SbSocketWaiter entry points.
static void HandleAccept(SbSocketWaiter waiter,
SbSocket socket,
int socket_fd,
void* context,
int ready_interests);
static void HandleRead(SbSocketWaiter waiter,
SbSocket socket,
int socket_fd,
void* context,
int ready_interests);

Expand Down Expand Up @@ -233,10 +258,10 @@ class LinkReceiver::Impl {
Semaphore destroy_waiter_;

// The server socket listening for new connections.
std::unique_ptr<Socket> listen_socket_;
int listen_socket_;

// A map of raw SbSockets to Connection objects.
std::unordered_map<SbSocket, Connection*> connections_;
std::unordered_map<int, Connection*> connections_;
};

LinkReceiver::Impl::Impl(Application* application, int port)
Expand Down Expand Up @@ -270,11 +295,11 @@ void LinkReceiver::Impl::Run() {

listen_socket_ =
CreateListeningSocket(kSbSocketAddressTypeIpv4, specified_port_);
if (!listen_socket_ || !listen_socket_->IsValid()) {
if (listen_socket_ < 0) {
listen_socket_ =
CreateListeningSocket(kSbSocketAddressTypeIpv6, specified_port_);
}
if (!listen_socket_ || !listen_socket_->IsValid()) {
if (listen_socket_ < 0) {
SB_LOG(WARNING) << "Unable to start LinkReceiver on port "
<< specified_port_ << ".";
SbSocketWaiterDestroy(waiter_);
Expand All @@ -284,7 +309,7 @@ void LinkReceiver::Impl::Run() {
}

actual_port_ = 0;
bool result = GetBoundPort(listen_socket_.get(), &actual_port_);
bool result = GetBoundPort(listen_socket_, &actual_port_);
if (!result) {
SB_LOG(WARNING) << "Unable to get LinkReceiver bound port.";
SbSocketWaiterDestroy(waiter_);
Expand All @@ -297,7 +322,7 @@ void LinkReceiver::Impl::Run() {
snprintf(port_string, SB_ARRAY_SIZE(port_string), "%d", actual_port_);
CreateTemporaryFile("link_receiver_port", port_string, strlen(port_string));

if (!AddForAccept(listen_socket_.get())) {
if (!AddForAccept(listen_socket_)) {
quit_.store(true);
}

Expand All @@ -307,33 +332,33 @@ void LinkReceiver::Impl::Run() {
}

for (auto& entry : connections_) {
SbSocketWaiterRemove(waiter_, entry.first);
SbPosixSocketWaiterRemove(waiter_, entry.first);
delete entry.second;
}
connections_.clear();

SbSocketWaiterRemove(waiter_, listen_socket_->socket());
SbPosixSocketWaiterRemove(waiter_, listen_socket_);

// Block until destroying thread will no longer reference waiter.
destroy_waiter_.Take();
SbSocketWaiterDestroy(waiter_);
}

bool LinkReceiver::Impl::AddForAccept(Socket* socket) {
if (!SbSocketWaiterAdd(waiter_, socket->socket(), this,
&LinkReceiver::Impl::HandleAccept,
kSbSocketWaiterInterestRead, true)) {
bool LinkReceiver::Impl::AddForAccept(int socket_fd) {
if (!SbPosixSocketWaiterAdd(waiter_, socket_fd, this,
&LinkReceiver::Impl::HandleAccept,
kSbSocketWaiterInterestRead, true)) {
SB_LOG(ERROR) << __FUNCTION__ << ": "
<< "SbSocketWaiterAdd failed.";
<< "SbPosixSocketWaiterAdd failed.";
return false;
}
return true;
}

bool LinkReceiver::Impl::AddForRead(Connection* connection) {
if (!SbSocketWaiterAdd(waiter_, connection->socket->socket(), this,
&LinkReceiver::Impl::HandleRead,
kSbSocketWaiterInterestRead, false)) {
if (!SbPosixSocketWaiterAdd(waiter_, connection->socket_fd, this,
&LinkReceiver::Impl::HandleRead,
kSbSocketWaiterInterestRead, false)) {
SB_LOG(ERROR) << __FUNCTION__ << ": "
<< "SbSocketWaiterAdd failed.";
return false;
Expand All @@ -342,25 +367,25 @@ bool LinkReceiver::Impl::AddForRead(Connection* connection) {
}

void LinkReceiver::Impl::OnAcceptReady() {
std::unique_ptr<Socket> accepted_socket =
std::unique_ptr<Socket>(listen_socket_->Accept());
int accepted_socket = accept(listen_socket_, NULL, NULL);
SB_DCHECK(accepted_socket);
Connection* connection = new Connection(std::move(accepted_socket));
connections_.emplace(connection->socket->socket(), connection);
Connection* connection = new Connection(accepted_socket);
connections_.emplace(connection->socket_fd, connection);
AddForRead(connection);
}

void LinkReceiver::Impl::OnReadReady(SbSocket sb_socket) {
auto iter = connections_.find(sb_socket);
void LinkReceiver::Impl::OnReadReady(int socket_fd) {
auto iter = connections_.find(socket_fd);
SB_DCHECK(iter != connections_.end());
OnReadReady(iter->second);
}

void LinkReceiver::Impl::OnReadReady(Connection* connection) {
auto socket = connection->socket.get();
int socket = connection->socket_fd;

char data[64] = {0};
int read = socket->ReceiveFrom(data, SB_ARRAY_SIZE_INT(data), NULL);
ssize_t bytes_read = recv(socket, data, SB_ARRAY_SIZE_INT(data), NULL);
int read = static_cast<int>(bytes_read);
int last_null = 0;
for (int position = 0; position < read; ++position) {
if (data[position] == '\0' || data[position] == '\n' ||
Expand All @@ -383,7 +408,7 @@ void LinkReceiver::Impl::OnReadReady(Connection* connection) {
if (read == 0) {
// Terminate connection.
connection->FlushLink(application_);
connections_.erase(socket->socket());
connections_.erase(socket);
delete connection;
return;
}
Expand All @@ -401,7 +426,7 @@ void* LinkReceiver::Impl::RunThread(void* context) {

// static
void LinkReceiver::Impl::HandleAccept(SbSocketWaiter waiter,
SbSocket socket,
int socket,
void* context,
int ready_interests) {
SB_DCHECK(context);
Expand All @@ -410,7 +435,7 @@ void LinkReceiver::Impl::HandleAccept(SbSocketWaiter waiter,

// static
void LinkReceiver::Impl::HandleRead(SbSocketWaiter waiter,
SbSocket socket,
int socket,
void* context,
int ready_interests) {
SB_DCHECK(context);
Expand Down

0 comments on commit db2a599

Please sign in to comment.