Skip to content

Commit

Permalink
iocp_asio 实现 WSASendTo WSARecvFrom CancelIo
Browse files Browse the repository at this point in the history
  • Loading branch information
microcai committed Dec 24, 2024
1 parent b9bd961 commit 725e042
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 87 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,17 +78,17 @@ of what wepoll does.
| PostQueuedCompletionStatus |||
| WSASend |||
| WSARecv |||
| WSASendTo || 实现中 |
| WSARecvFrom || 实现中 |
| WSASendTo || |
| WSARecvFrom || |
| AcceptEx |||
| WSAConnectEx |||
| DisconnectEx |||
| CreateFileA |||
| CreateFileW |||
| ReadFile || API 可用,为同步模拟 |
| WriteFile || API 可用,为同步模拟 |
| ReadFile || API 可用,为同步模拟 |
| WriteFile || API 可用,为同步模拟 |
| CloseHandle |||
| CancelIo || 不支持 |
| CancelIo || |
| CancelIoEx |||


Expand Down
37 changes: 33 additions & 4 deletions iocp_asio/src/internal_iocp_struct.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,20 @@ struct SOCKET_emu_class final : public base_handle
udp_sock& udp_socket() { return std::get<udp_sock>(sock_); };
acceptor& accept_socket() { return std::get<acceptor>(sock_); };

auto cancel_all()
{
return std::visit([](auto& s)
{
if constexpr (!std::is_same_v<decltype(s), normal_file&>)
{
asio::error_code ignore_ec;
s.cancel(ignore_ec);
return true;
}
return false;
}, sock_);
}

template<typename Handler>
void async_accept(SOCKET_emu_class* into, Handler&& handler)
{
Expand All @@ -210,8 +224,8 @@ struct SOCKET_emu_class final : public base_handle
accept_sock.async_accept(into->tcp_socket(), std::forward<Handler>(handler));
}

template<typename Handler>
void async_connect(asio::ip::address addr, asio::ip::port_type port, Handler&& handler)
template<typename Protocal, typename Handler>
void async_connect(asio::ip::basic_endpoint<Protocal> endpoint, Handler&& handler)
{
assert(_iocp);
if (std::holds_alternative<normal_file>(sock_))
Expand All @@ -232,11 +246,11 @@ struct SOCKET_emu_class final : public base_handle

if (type == SOCK_STREAM)
{
tcp_socket().async_connect(asio::ip::tcp::endpoint{addr, port}, std::forward<Handler>(handler));
tcp_socket().async_connect(asio::ip::tcp::endpoint{endpoint.address(), endpoint.port()}, std::forward<Handler>(handler));
}
else if (type == SOCK_DGRAM)
{
udp_socket().async_connect(asio::ip::udp::endpoint{addr, port}, std::forward<Handler>(handler));
udp_socket().async_connect(asio::ip::udp::endpoint{endpoint.address(), endpoint.port()}, std::forward<Handler>(handler));
}
}

Expand Down Expand Up @@ -275,6 +289,21 @@ struct SOCKET_emu_class final : public base_handle
}
}

template<typename Buffer, typename Protocol, typename Handler>
void async_receive_from(Buffer&& buf, asio::ip::basic_endpoint<Protocol>& endpoint, Handler&& handler)
{
assert(type == SOCK_DGRAM);

if (type == SOCK_DGRAM)
{
udp_socket().async_receive_from(buf, endpoint, std::forward<Handler>(handler));
}
else
{
handler(asio::error::make_error_code(asio::error::operation_not_supported), 0);
}
}

};

}
151 changes: 73 additions & 78 deletions iocp_asio/src/iocp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,17 +158,11 @@ IOCP_DECL BOOL WINAPI CancelIo(_In_ HANDLE hFile)

iocp_handle_emu_class* iocp = s->_iocp;

return FALSE;
return s->cancel_all();
}

IOCP_DECL BOOL WINAPI CancelIoEx(_In_ HANDLE hFile, _In_opt_ LPOVERLAPPED lpOverlapped)
{
// SOCKET_emu_class* s = dynamic_cast<SOCKET_emu_class*>(hFile);

// iocp_handle_emu_class* iocp = s->_iocp;

// std::scoped_lock<std::mutex> l(iocp->result_mutex);

reinterpret_cast<asio_operation*>(lpOverlapped->Internal)->cancel_signal.emit(asio::cancellation_type::partial);

return TRUE;
Expand Down Expand Up @@ -287,6 +281,31 @@ IOCP_DECL void GetAcceptExSockaddrs(_In_ PVOID lpOutputBuffer, _In_ DWORD dwRece
*RemoteSockaddr = reinterpret_cast<sockaddr*>(reinterpret_cast<char*>(lpOutputBuffer) + local_addr_length + 2);
}

template <typename Protocal>
static asio::ip::basic_endpoint<Protocal> from_sockaddr(const sockaddr* name)
{
asio::ip::address addr;
asio::ip::port_type port;

if ( name->sa_family == AF_INET)
{
sockaddr_in * v4addr = (sockaddr_in*) name;
asio::ip::address_v4::bytes_type native_sin_addr;
memcpy(&native_sin_addr, &(v4addr->sin_addr), 4);
addr = asio::ip::address_v4{native_sin_addr};
port = ntohs(v4addr->sin_port);
}
else if ( name->sa_family == AF_INET6)
{
sockaddr_in6 * v6addr = (sockaddr_in6*) name;
asio::ip::address_v6::bytes_type native_sin_addr;
memcpy(&native_sin_addr, &(v6addr->sin6_addr), 4);
addr = asio::ip::address_v6{native_sin_addr, v6addr->sin6_scope_id};
port = ntohs(v6addr->sin6_port);
}

return asio::ip::basic_endpoint<Protocal>{addr, port};
}

IOCP_DECL BOOL WSAConnectEx(
_In_ SOCKET socket_,
Expand Down Expand Up @@ -328,28 +347,9 @@ IOCP_DECL BOOL WSAConnectEx(
op->dwSendDataLength = dwSendDataLength;
op->sock = s;

asio::ip::address remote_ip;
asio::ip::port_type remote_port;
asio::ip::tcp::endpoint endpoint = from_sockaddr<asio::ip::tcp>(name);


if ( name->sa_family == AF_INET)
{
sockaddr_in * v4addr = (sockaddr_in*) name;
asio::ip::address_v4::bytes_type native_sin_addr;
memcpy(&native_sin_addr, &(v4addr->sin_addr), 4);
remote_ip = asio::ip::address_v4{native_sin_addr};
remote_port = ntohs(v4addr->sin_port);
}
else if ( name->sa_family == AF_INET6)
{
sockaddr_in6 * v6addr = (sockaddr_in6*) name;
asio::ip::address_v6::bytes_type native_sin_addr;
memcpy(&native_sin_addr, &(v6addr->sin6_addr), 4);
remote_ip = asio::ip::address_v6{native_sin_addr, v6addr->sin6_scope_id};
remote_port = ntohs(v6addr->sin6_port);
}

s->async_connect(remote_ip, remote_port, asio::bind_cancellation_slot(op->cancel_signal.slot(), [op, iocp](asio::error_code ec)
s->async_connect(endpoint, asio::bind_cancellation_slot(op->cancel_signal.slot(), [op, iocp](asio::error_code ec)
{
op->last_error = ec.value();

Expand Down Expand Up @@ -478,7 +478,6 @@ IOCP_DECL int WSARecv(_In_ SOCKET socket_, _Inout_ LPWSABUF lpBuffers, _In_ DWOR
return SOCKET_ERROR;
}

#if 0

IOCP_DECL int WSASendTo(
_In_ SOCKET socket_,
Expand All @@ -495,54 +494,57 @@ IOCP_DECL int WSASendTo(

if (s->_iocp == nullptr && lpOverlapped) [[unlikely]]
{
WSASetLastError(WSAEOPNOTSUPP);
WSASetLastError(EOPNOTSUPP);
return SOCKET_ERROR;
}

iocp_handle_emu_class* iocp = s->_iocp;

assert(lpOverlapped);

*lpNumberOfBytesSent = 0;
lpOverlapped->InternalHigh = (ULONG_PTR) __builtin_extract_return_addr (__builtin_return_address (0));

if (lpNumberOfBytesSent) [[likely]]
*lpNumberOfBytesSent = 0;

struct io_uring_write_op : io_uring_operations
struct write_op : asio_operation
{
std::vector<iovec> msg_iov;
msghdr msg = {};
virtual void do_complete(io_uring_cqe* cqe, DWORD* lpNumberOfBytes) override
{
if (cqe->res < 0)
WSASetLastError(-cqe->res);
}
std::vector<asio::const_buffer> buffers;
};

// now, enter IOCP emul logic
io_uring_write_op* op = io_uring_operation_allocator{}.allocate<io_uring_write_op>();
write_op* op = new write_op;
op->lpCompletionRoutine = lpCompletionRoutine;
op->overlapped_ptr = lpOverlapped;
lpOverlapped->Internal = reinterpret_cast<ULONG_PTR>(op);
op->CompletionKey = s->_completion_key;
op->msg_iov.resize(dwBufferCount);
op->msg.msg_iovlen = dwBufferCount;
op->msg.msg_iov = op->msg_iov.data();
op->msg.msg_name = (void*) lpTo;
op->msg.msg_namelen = iTolen;

op->buffers.resize(dwBufferCount);
int64_t total_send_bytes = 0;
for (int i = 0; i < dwBufferCount; i++)
{
op->msg_iov[i].iov_base = lpBuffers[i].buf;
op->msg_iov[i].iov_len = lpBuffers[i].len;
op->buffers[i] = asio::buffer(lpBuffers[i].buf, lpBuffers[i].len);
total_send_bytes += lpBuffers[i].len;
}

iocp->submit_io([&](struct io_uring_sqe* sqe)
asio::ip::udp::endpoint dest = from_sockaddr<asio::ip::udp>(lpTo);


s->async_sendto(op->buffers, dest, asio::bind_cancellation_slot(op->cancel_signal.slot(), [iocp, op](asio::error_code ec, std::size_t bytes_transfered)
{
io_uring_prep_sendmsg_zc(sqe, s->_socket_fd, &op->msg, 0);
io_uring_sqe_set_data(sqe, op);
});
op->last_error = ec.value();
op->NumberOfBytes = bytes_transfered;

std::scoped_lock<std::mutex> l(iocp->result_mutex);
iocp->results_.emplace_back(op);
}));

WSASetLastError(ERROR_IO_PENDING);
return SOCKET_ERROR;
}



IOCP_DECL int WSARecvFrom(
_In_ SOCKET socket_,
_In_ LPWSABUF lpBuffers,
Expand All @@ -567,51 +569,44 @@ IOCP_DECL int WSARecvFrom(
// *lpNumberOfBytesRecvd = 0;
assert(lpOverlapped);

struct io_uring_read_op : io_uring_operations
struct read_op : asio_operation
{
LPINT lpFromlen;
std::vector<iovec> msg_iov;
msghdr msg = {};
virtual void do_complete(io_uring_cqe* cqe, DWORD* lpNumberOfBytes) override
{
if (cqe->res < 0) [[unlikely]]
WSASetLastError(-cqe->res);
else if (cqe->res == 0) [[unlikely]]
WSASetLastError(ERROR_HANDLE_EOF);
else [[likely]]
* lpFromlen = msg.msg_namelen;
}
sockaddr* out_remote_addr;
asio::ip::udp::endpoint remote_addr;
std::vector<asio::mutable_buffer> buffers;
};

// now, enter IOCP emul logic
io_uring_read_op* op = io_uring_operation_allocator{}.allocate<io_uring_read_op>();
read_op* op = new read_op{};

op->out_remote_addr = lpFrom;
op->lpCompletionRoutine = lpCompletionRoutine;
op->overlapped_ptr = lpOverlapped;
lpOverlapped->Internal = reinterpret_cast<ULONG_PTR>(op);
op->CompletionKey = s->_completion_key;
op->msg_iov.resize(dwBufferCount);
op->msg.msg_iovlen = dwBufferCount;
op->msg.msg_iov = op->msg_iov.data();
op->msg.msg_name = lpFrom;
op->msg.msg_namelen = *lpFromlen;
op->lpFromlen = lpFromlen;

op->buffers.resize(dwBufferCount);
int64_t total_send_bytes = 0;
for (int i = 0; i < dwBufferCount; i++)
{
op->msg_iov[i].iov_base = lpBuffers[i].buf;
op->msg_iov[i].iov_len = lpBuffers[i].len;
op->buffers[i] = asio::buffer(lpBuffers[i].buf, lpBuffers[i].len);
total_send_bytes += lpBuffers[i].len;
}

iocp->submit_io([&](struct io_uring_sqe* sqe)
s->async_receive_from(op->buffers, op->remote_addr, asio::bind_cancellation_slot(op->cancel_signal.slot(), [iocp, op](asio::error_code ec, std::size_t bytes_transfered)
{
io_uring_prep_recvmsg(sqe, s->_socket_fd, &(op->msg), 0);
io_uring_sqe_set_data(sqe, op);
});
op->last_error = ec.value();
op->NumberOfBytes = bytes_transfered;

memcpy(op->out_remote_addr, op->remote_addr.data(), op->remote_addr.size());

std::scoped_lock<std::mutex> l(iocp->result_mutex);
iocp->results_.emplace_back(op);
}));

WSASetLastError(ERROR_IO_PENDING);
return SOCKET_ERROR;
}
#endif

IOCP_DECL BOOL DisconnectEx(
_In_ SOCKET hSocket,
Expand Down

0 comments on commit 725e042

Please sign in to comment.