Skip to content

Commit

Permalink
Allow an Ansi build of StreamServer.
Browse files Browse the repository at this point in the history
This still uses a Unicode build of SSLServer so to permit that SSLServer no longer uses MFC - it uses ATL instead, which leads to a few minor changes as method names or parameters are a little different. A fw Windows calls are now explicitly "W" (wide) versions.

There's now a "Debug Ansi" build which builds SimpleClient, SteamClient and StreamServer using multibyte (ansi) and SSLClient and SSLServer using Unicode.
  • Loading branch information
david-maw committed Jan 27, 2020
1 parent 6beb14d commit 34553ba
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 45 deletions.
12 changes: 6 additions & 6 deletions SSLServer/Include/Listener.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ class ISocketStream;
class CListener
{
public:
enum ErrorType {
enum class ErrorType {
NoError,
UnknownError,
SocketInuse,
Expand All @@ -20,19 +20,19 @@ class CListener
HANDLE m_hSocketEvents[FD_SETSIZE]{};
int m_iNumListenSockets{ 0 };
CCriticalSection m_WorkerCountLock;
CWinThread * m_ListenerThread{ nullptr };
static UINT __cdecl Worker(LPVOID);
static UINT __cdecl ListenerWorker(LPVOID);
uintptr_t m_ListenerThread { 0 };
static void __cdecl Worker(LPVOID);
static void __cdecl ListenerWorker(LPVOID);
void Listen();
std::function<void(ISocketStream * StreamSock)> m_actualwork;
public:
static void LogWarning(const WCHAR* const);
static void LogWarning(const CHAR* const);
int m_WorkerCount{ 0 };
CEvent m_StopEvent{ FALSE, TRUE };
CEvent m_StopEvent{ TRUE, FALSE };
// Initialize the listener
ErrorType Initialize(int TCPSocket);
std::function<SECURITY_STATUS(PCCERT_CONTEXT & pCertContext, LPCTSTR pszSubjectName)> SelectServerCert;
std::function<SECURITY_STATUS(PCCERT_CONTEXT & pCertContext, LPCWSTR pszSubjectName)> SelectServerCert;
std::function<bool(PCCERT_CONTEXT pCertContext, const bool trusted)> ClientCertAcceptable;
void EndListening();
void BeginListening(std::function<void(ISocketStream * StreamSock)> actualwork);
Expand Down
46 changes: 22 additions & 24 deletions SSLServer/Listener.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,24 +29,22 @@ CListener::~CListener()

// This is the individual worker process, all it does is start, change its name to something useful,
// then call the Lambda function passed in via the BeginListening method
UINT __cdecl CListener::Worker(void * v)
void __cdecl CListener::Worker(LPVOID v)
{
std::unique_ptr<CSSLServer> SSLServer(reinterpret_cast<CSSLServer*>(v));
SetThreadName("Connection Worker");
// Invoke the caller provided function defining the work to do, passing an interface which
// allows the user code to send and receive messages and so on.
(SSLServer->GetListener()->m_actualwork)(SSLServer->GetSocketStream());
return 0;
}

// Worker process for connection listening
UINT __cdecl CListener::ListenerWorker(LPVOID v)
void __cdecl CListener::ListenerWorker(LPVOID v)
{
auto * Listener = static_cast<CListener*>(v); // See _beginthread call for parameter definition

SetThreadName("Listener");
Listener->Listen();
return 0;
}

// Initialize the listener, set up the socket to listen on, or return an error
Expand All @@ -56,7 +54,7 @@ CListener::ErrorType CListener::Initialize(int TCPSocket)

WSADATA wsadata;
if (WSAStartup(MAKEWORD(1, 1), &wsadata))
return UnknownError;
return CListener::ErrorType::UnknownError;

// Get list of addresses to listen on
ADDRINFOT Hints, *AddrInfo, *AI;
Expand All @@ -69,7 +67,7 @@ CListener::ErrorType CListener::Initialize(int TCPSocket)
WCHAR MsgText[100];
StringCchPrintf(MsgText, _countof(MsgText), L"getaddressinfo error: %i", GetLastError());
LogWarning(MsgText);
return UnknownError;
return CListener::ErrorType::UnknownError;
}

// Create one or more passive sockets to listen on
Expand All @@ -92,14 +90,14 @@ CListener::ErrorType CListener::Initialize(int TCPSocket)
nullptr); // no name

if (!(m_hSocketEvents[i]))
return UnknownError;
return CListener::ErrorType::UnknownError;

// StringCchPrintf(MsgText, _countof(MsgText), L"::OnInit Created m_hSocketEvents[%d], handle=%d"), i, m_hSocketEvents[i];
// LogWarning(MsgText);

m_iListenSockets[i] = WSASocket(AI->ai_family, SOCK_STREAM, 0, nullptr, 0, WSA_FLAG_OVERLAPPED);
if (m_iListenSockets[i] == INVALID_SOCKET)
return SocketUnusable;
return CListener::ErrorType::SocketUnusable;

// StringCchPrintf(MsgText, _countof(MsgText), L"::OnInit binding m_iListenSockets[%d] to sa_family=%u sa_data=%s len=%d"), i, AI->ai_addr->sa_family, AI->ai_addr->sa_data, AI->ai_addrlen;
// LogWarning(MsgText);
Expand All @@ -108,15 +106,15 @@ CListener::ErrorType CListener::Initialize(int TCPSocket)
if (rc)
{
if (WSAGetLastError() == WSAEADDRINUSE)
return SocketInuse;
return CListener::ErrorType::SocketInuse;
else
return SocketUnusable;
return CListener::ErrorType::SocketUnusable;
}

if (listen(m_iListenSockets[i], 10))
return SocketUnusable;
return CListener::ErrorType::SocketUnusable;
if (WSAEventSelect(m_iListenSockets[i], m_hSocketEvents[i], FD_ACCEPT))
return SocketUnusable;
return CListener::ErrorType::SocketUnusable;
i++;
}

Expand All @@ -125,32 +123,32 @@ CListener::ErrorType CListener::Initialize(int TCPSocket)
// StringCchPrintf(MsgText, _countof(MsgText), L"::OnInit no errors, m_iNumListenSockets = %d"), m_iNumListenSockets;
// LogWarning(MsgText);

return NoError;
return CListener::ErrorType::NoError;
}

// Start listening for connections, if a timeout is specified keep listening until then
void CListener::BeginListening(std::function<void(ISocketStream * StreamSock)> actualwork)
{
m_actualwork = actualwork;
m_ListenerThread = AfxBeginThread(ListenerWorker, this);
m_ListenerThread = _beginthread(ListenerWorker, 0, this);
}

void CListener::IncrementWorkerCount(int i)
{
m_WorkerCountLock.Lock();
m_WorkerCountLock.Enter();
m_WorkerCount += i;
m_WorkerCountLock.Unlock();
m_WorkerCountLock.Leave();
}

// Stop listening, tells the listener thread it can stop, then waits for it to terminate
void CListener::EndListening()
{
m_StopEvent.SetEvent();
m_StopEvent.Set();
if (m_ListenerThread)
{
WaitForSingleObject(m_ListenerThread->m_hThread, INFINITE); // Will auto delete
WaitForSingleObject((HANDLE)m_ListenerThread, INFINITE); // Will auto delete
}
m_ListenerThread = nullptr;
m_ListenerThread = 0;
}

// Log a warning
Expand Down Expand Up @@ -220,22 +218,22 @@ void CListener::Listen()

auto SSLServer = CSSLServer::Create(iReadSocket, this);
if (SSLServer && SSLServer->IsConnected)
AfxBeginThread(Worker, SSLServer);
_beginthread(Worker, 0, SSLServer);
else
delete SSLServer;
iReadSocket = INVALID_SOCKET;
}
// Either we're done, or there has been a problem, wait for all the worker threads to terminate
Sleep(500);
m_WorkerCountLock.Lock();
m_WorkerCountLock.Enter();
while (m_WorkerCount)
{
m_WorkerCountLock.Unlock();
m_WorkerCountLock.Leave();
Sleep(1000);
DebugMsg("Waiting for all workers to terminate: worker thread count = %i", m_WorkerCount);
m_WorkerCountLock.Lock();
m_WorkerCountLock.Enter();
};
m_WorkerCountLock.Unlock();
m_WorkerCountLock.Leave();
if ((iReadSocket != NULL) && (iReadSocket != INVALID_SOCKET))
closesocket(iReadSocket);
DebugMsg("End Listen method");
Expand Down
8 changes: 4 additions & 4 deletions SSLServer/SSLServer.vcxproj
Original file line number Diff line number Diff line change
Expand Up @@ -31,30 +31,30 @@
<UseDebugLibraries>true</UseDebugLibraries>
<PlatformToolset>v142</PlatformToolset>
<CharacterSet>Unicode</CharacterSet>
<UseOfMfc>Dynamic</UseOfMfc>
<UseOfMfc>false</UseOfMfc>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
<ConfigurationType>StaticLibrary</ConfigurationType>
<UseDebugLibraries>true</UseDebugLibraries>
<PlatformToolset>v142</PlatformToolset>
<CharacterSet>Unicode</CharacterSet>
<UseOfMfc>Dynamic</UseOfMfc>
<UseOfMfc>false</UseOfMfc>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'" Label="Configuration">
<ConfigurationType>StaticLibrary</ConfigurationType>
<UseDebugLibraries>false</UseDebugLibraries>
<PlatformToolset>v142</PlatformToolset>
<WholeProgramOptimization>true</WholeProgramOptimization>
<CharacterSet>Unicode</CharacterSet>
<UseOfMfc>Dynamic</UseOfMfc>
<UseOfMfc>false</UseOfMfc>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
<ConfigurationType>StaticLibrary</ConfigurationType>
<UseDebugLibraries>false</UseDebugLibraries>
<PlatformToolset>v142</PlatformToolset>
<WholeProgramOptimization>true</WholeProgramOptimization>
<CharacterSet>Unicode</CharacterSet>
<UseOfMfc>Dynamic</UseOfMfc>
<UseOfMfc>false</UseOfMfc>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
<ImportGroup Label="ExtensionSettings">
Expand Down
7 changes: 4 additions & 3 deletions SSLServer/framework.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,12 @@ const bool debug = false;
#define VC_EXTRALEAN // Exclude rarely-used stuff from Windows headers
#endif

#include <afx.h>
#include <afxwin.h> // MFC core and standard components
#include <afxmt.h>
#include <atlsync.h>
#include <atltime.h>
#include <Windows.h>

#include <WS2tcpip.h>
#pragma comment(lib, "Ws2_32.lib")
#define SECURITY_WIN32
#include <security.h>
#include <strsafe.h>
Expand Down
8 changes: 4 additions & 4 deletions Samples/StreamServer/StreamServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ using namespace std;
// This method is called when the first client tries to connect in order to allow a certificate to be selected to send to the client
// It has to wait for the client connect request because the client tells the server what identity it expects it to present
// This is called SNI (Server Name Indication) and it is a relatively new (it began to become available about 2005) SSL/TLS feature
SECURITY_STATUS SelectServerCert(PCCERT_CONTEXT & pCertContext, LPCTSTR pszSubjectName)
SECURITY_STATUS SelectServerCert(PCCERT_CONTEXT & pCertContext, LPCWSTR pszSubjectName)
{
SECURITY_STATUS status = SEC_E_INVALID_HANDLE;

Expand Down Expand Up @@ -48,11 +48,11 @@ bool ClientCertAcceptable(PCCERT_CONTEXT pCertContext, const bool trusted)
// This function simply runs arbitrary code and returns process information to the caller, it's just a handy utility function
bool RunApp(std::wstring app, PROCESS_INFORMATION& pi)
{ // Not strictly needed but it makes testing easier
STARTUPINFO si = {};
STARTUPINFOW si = {};
si.cb = sizeof si;
ZeroMemory(&pi, sizeof(pi));
#pragma warning(suppress:6335)
if (CreateProcess(nullptr, &app[0], nullptr, FALSE, 0, CREATE_NEW_CONSOLE, nullptr, nullptr, &si, &pi))
if (CreateProcessW(nullptr, &app[0], nullptr, FALSE, 0, CREATE_NEW_CONSOLE, nullptr, nullptr, &si, &pi))
return true;
else
{
Expand All @@ -65,7 +65,7 @@ void RunClient(std::wstring toHost = L"", PROCESS_INFORMATION * ppi = nullptr)
{
cout << "Initiating a client instance for testing.\n" << endl;
WCHAR acPathName[MAX_PATH + 1];
GetModuleFileName(nullptr, acPathName, _countof(acPathName));
GetModuleFileNameW(nullptr, acPathName, _countof(acPathName));
std::wstring appName(acPathName);
const auto len = appName.find_last_of(L'\\');
appName = appName.substr(0, len + 1) + L"StreamClient.exe " + toHost;
Expand Down
Loading

0 comments on commit 34553ba

Please sign in to comment.