Skip to content

Commit

Permalink
整理代码
Browse files Browse the repository at this point in the history
  • Loading branch information
actboy168 committed Nov 26, 2024
1 parent ba11929 commit a62798d
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 143 deletions.
146 changes: 3 additions & 143 deletions bee/net/socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,10 @@
// clang-format off
# include <winsock2.h>
// clang-format on
# include <bee/nonstd/charconv.h>
# include <bee/utility/dynarray.h>
# include <bee/net/uds_win.h>
# include <bee/win/wtf8.h>
# include <mstcpip.h>
# include <mswsock.h>

# include <array>
# include <limits>
#else
# include <fcntl.h>
# include <netinet/in.h>
Expand All @@ -27,10 +23,6 @@
#include <bee/net/socket.h>
#include <bee/nonstd/unreachable.h>

#if defined(__MINGW32__)
# define WSA_FLAG_NO_HANDLE_INHERIT 0x80
#endif

namespace bee::net::socket {
static bool net_success(int x) noexcept {
return x == 0;
Expand All @@ -39,140 +31,8 @@ namespace bee::net::socket {
#if defined(_WIN32)
static_assert(sizeof(SOCKET) == sizeof(fd_t));

namespace fileutil {
static FILE* open(zstring_view filename, const wchar_t* mode) noexcept {
# if defined(_MSC_VER)
# pragma warning(push)
# pragma warning(disable : 4996)
# endif
return _wfopen(wtf8::u2w(filename).c_str(), mode);
# if defined(_MSC_VER)
# pragma warning(pop)
# endif
}
static size_t size(FILE* f) noexcept {
_fseeki64(f, 0, SEEK_END);
long long size = _ftelli64(f);
_fseeki64(f, 0, SEEK_SET);
return (size_t)size;
}
static size_t read(FILE* f, void* buf, size_t sz) noexcept {
return fread(buf, sizeof(char), sz, f);
}
static size_t write(FILE* f, const void* buf, size_t sz) noexcept {
return fwrite(buf, sizeof(char), sz, f);
}
static void close(FILE* f) noexcept {
fclose(f);
}
}

static std::string file_read(zstring_view filename) noexcept {
FILE* f = fileutil::open(filename, L"rb");
if (!f) {
return std::string();
}
std::string result(fileutil::size(f), '\0');
fileutil::read(f, result.data(), result.size());
fileutil::close(f);
return result;
}

static bool file_write(zstring_view filename, const std::string& value) noexcept {
FILE* f = fileutil::open(filename, L"wb");
if (!f) {
return false;
}
fileutil::write(f, value.data(), value.size());
fileutil::close(f);
return true;
}

static bool read_tcp_port(const endpoint& ep, uint16_t& tcpport) noexcept {
auto [type, path] = ep.get_unix();
if (type != un_format::pathname) {
return false;
}
auto unixpath = file_read(path);
if (unixpath.empty()) {
return false;
}
if (auto [p, ec] = std::from_chars(unixpath.data(), unixpath.data() + unixpath.size(), tcpport); ec != std::errc()) {
return false;
}
if (tcpport <= 0 || tcpport > (std::numeric_limits<uint16_t>::max)()) {
return false;
}
return true;
}

static bool write_tcp_port(zstring_view path, fd_t s) noexcept {
endpoint ep;
if (socket::getsockname(s, ep)) {
auto tcpport = ep.get_port();
std::array<char, 10> portstr;
if (auto [p, ec] = std::to_chars(portstr.data(), portstr.data() + portstr.size() - 1, tcpport); ec != std::errc()) {
return false;
} else {
p[0] = '\0';
}
return file_write(path, portstr.data());
}
return false;
}

static status u_connect(fd_t s, const endpoint& ep) noexcept {
uint16_t tcpport = 0;
if (!read_tcp_port(ep, tcpport)) {
::WSASetLastError(WSAECONNREFUSED);
return status::failed;
}
return socket::connect(s, endpoint::from_localhost(tcpport));
}

static bool u_bind(fd_t s, const endpoint& ep) {
const bool ok = socket::bind(s, endpoint::from_localhost(0));
if (!ok) {
return ok;
}
auto [type, path] = ep.get_unix();
if (type != un_format::pathname) {
::WSASetLastError(WSAENETDOWN);
return false;
}
if (!write_tcp_port(path, s)) {
::WSASetLastError(WSAENETDOWN);
return false;
}
return true;
}

static WSAPROTOCOL_INFOW UnixProtocol;
static bool supportUnixDomainSocket_() noexcept {
static GUID AF_UNIX_PROVIDER_ID = { 0xA00943D9, 0x9C2E, 0x4633, { 0x9B, 0x59, 0x00, 0x57, 0xA3, 0x16, 0x09, 0x94 } };
DWORD len = 0;
::WSAEnumProtocolsW(0, NULL, &len);
dynarray<std::byte> buf(len);
LPWSAPROTOCOL_INFOW protocols = (LPWSAPROTOCOL_INFOW)buf.data();
const int n = ::WSAEnumProtocolsW(0, protocols, &len);
if (n == SOCKET_ERROR) {
return false;
}
for (int i = 0; i < n; ++i) {
if (protocols[i].iAddressFamily == AF_UNIX && IsEqualGUID(protocols[i].ProviderId, AF_UNIX_PROVIDER_ID)) {
const fd_t fd = ::WSASocketW(PF_UNIX, SOCK_STREAM, 0, &protocols[i], 0, WSA_FLAG_NO_HANDLE_INHERIT);
if (fd == retired_fd) {
return false;
}
::closesocket(fd);
UnixProtocol = protocols[i];
return true;
}
}
return false;
}
static bool supportUnixDomainSocket() noexcept {
static bool support = supportUnixDomainSocket_();
static bool support = u_support();
return support;
}
#endif
Expand Down Expand Up @@ -287,7 +147,7 @@ namespace bee::net::socket {

static fd_t createSocket(int af, int type, int protocol, fd_flags fd_flags) noexcept {
#if defined(_WIN32)
const fd_t fd = ::WSASocketW(af, type, protocol, af == PF_UNIX ? &UnixProtocol : NULL, 0, WSA_FLAG_NO_HANDLE_INHERIT);
const fd_t fd = u_createSocket(af, type, protocol, fd_flags);
if (fd == retired_fd) {
return retired_fd;
}
Expand Down
159 changes: 159 additions & 0 deletions bee/net/uds_win.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
// clang-format off
#include <winsock2.h>
// clang-format on
#include <bee/net/endpoint.h>
#include <bee/net/socket.h>
#include <bee/net/uds_win.h>
#include <bee/nonstd/charconv.h>
#include <bee/nonstd/unreachable.h>
#include <bee/utility/dynarray.h>
#include <bee/win/wtf8.h>
#include <mstcpip.h>
#include <mswsock.h>

#include <array>
#include <limits>

#if defined(__MINGW32__)
# define WSA_FLAG_NO_HANDLE_INHERIT 0x80
#endif

namespace bee::net::socket {

namespace fileutil {
static FILE* open(zstring_view filename, const wchar_t* mode) noexcept {
#if defined(_MSC_VER)
# pragma warning(push)
# pragma warning(disable : 4996)
#endif
return _wfopen(wtf8::u2w(filename).c_str(), mode);
#if defined(_MSC_VER)
# pragma warning(pop)
#endif
}
static size_t size(FILE* f) noexcept {
_fseeki64(f, 0, SEEK_END);
long long size = _ftelli64(f);
_fseeki64(f, 0, SEEK_SET);
return (size_t)size;
}
static size_t read(FILE* f, void* buf, size_t sz) noexcept {
return fread(buf, sizeof(char), sz, f);
}
static size_t write(FILE* f, const void* buf, size_t sz) noexcept {
return fwrite(buf, sizeof(char), sz, f);
}
static void close(FILE* f) noexcept {
fclose(f);
}
}

static std::string file_read(zstring_view filename) noexcept {
FILE* f = fileutil::open(filename, L"rb");
if (!f) {
return std::string();
}
std::string result(fileutil::size(f), '\0');
fileutil::read(f, result.data(), result.size());
fileutil::close(f);
return result;
}

static bool file_write(zstring_view filename, const std::string& value) noexcept {
FILE* f = fileutil::open(filename, L"wb");
if (!f) {
return false;
}
fileutil::write(f, value.data(), value.size());
fileutil::close(f);
return true;
}

static bool read_tcp_port(const endpoint& ep, uint16_t& tcpport) noexcept {
auto [type, path] = ep.get_unix();
if (type != un_format::pathname) {
return false;
}
auto unixpath = file_read(path);
if (unixpath.empty()) {
return false;
}
if (auto [p, ec] = std::from_chars(unixpath.data(), unixpath.data() + unixpath.size(), tcpport); ec != std::errc()) {
return false;
}
if (tcpport <= 0 || tcpport > (std::numeric_limits<uint16_t>::max)()) {
return false;
}
return true;
}

static bool write_tcp_port(zstring_view path, fd_t s) noexcept {
endpoint ep;
if (socket::getsockname(s, ep)) {
auto tcpport = ep.get_port();
std::array<char, 10> portstr;
if (auto [p, ec] = std::to_chars(portstr.data(), portstr.data() + portstr.size() - 1, tcpport); ec != std::errc()) {
return false;
} else {
p[0] = '\0';
}
return file_write(path, portstr.data());
}
return false;
}

status u_connect(fd_t s, const endpoint& ep) noexcept {
uint16_t tcpport = 0;
if (!read_tcp_port(ep, tcpport)) {
::WSASetLastError(WSAECONNREFUSED);
return status::failed;
}
return socket::connect(s, endpoint::from_localhost(tcpport));
}

bool u_bind(fd_t s, const endpoint& ep) {
const bool ok = socket::bind(s, endpoint::from_localhost(0));
if (!ok) {
return ok;
}
auto [type, path] = ep.get_unix();
if (type != un_format::pathname) {
::WSASetLastError(WSAENETDOWN);
return false;
}
if (!write_tcp_port(path, s)) {
::WSASetLastError(WSAENETDOWN);
return false;
}
return true;
}

static WSAPROTOCOL_INFOW UnixProtocol;
bool u_support() noexcept {
static GUID AF_UNIX_PROVIDER_ID = { 0xA00943D9, 0x9C2E, 0x4633, { 0x9B, 0x59, 0x00, 0x57, 0xA3, 0x16, 0x09, 0x94 } };
DWORD len = 0;
::WSAEnumProtocolsW(0, NULL, &len);
dynarray<std::byte> buf(len);
LPWSAPROTOCOL_INFOW protocols = (LPWSAPROTOCOL_INFOW)buf.data();
const int n = ::WSAEnumProtocolsW(0, protocols, &len);
if (n == SOCKET_ERROR) {
return false;
}
for (int i = 0; i < n; ++i) {
if (protocols[i].iAddressFamily == AF_UNIX && IsEqualGUID(protocols[i].ProviderId, AF_UNIX_PROVIDER_ID)) {
const fd_t fd = ::WSASocketW(PF_UNIX, SOCK_STREAM, 0, &protocols[i], 0, WSA_FLAG_NO_HANDLE_INHERIT);
if (fd == retired_fd) {
return false;
}
::closesocket(fd);
UnixProtocol = protocols[i];
return true;
}
}
return false;
}

fd_t u_createSocket(int af, int type, int protocol, fd_flags fd_flags) noexcept {
return ::WSASocketW(af, type, protocol, af == PF_UNIX ? &UnixProtocol : NULL, 0, WSA_FLAG_NO_HANDLE_INHERIT);
}
}
11 changes: 11 additions & 0 deletions bee/net/uds_win.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#pragma once

#include <bee/net/fd.h>
#include <bee/net/socket.h>

namespace bee::net::socket {
status u_connect(fd_t s, const endpoint& ep) noexcept;
bool u_bind(fd_t s, const endpoint& ep);
bool u_support() noexcept;
fd_t u_createSocket(int af, int type, int protocol, fd_flags fd_flags) noexcept;
}

0 comments on commit a62798d

Please sign in to comment.