Skip to content

Commit

Permalink
Add packet compression
Browse files Browse the repository at this point in the history
  • Loading branch information
nekiro committed Dec 9, 2024
1 parent 7573e41 commit 259ee4f
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 22 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ endif ()

# Find packages.
find_package(OpenSSL 3.0.0 REQUIRED COMPONENTS Crypto)
find_package(ZLIB REQUIRED)

find_package(fmt 8.1.1 CONFIG)
if (NOT fmt_FOUND)
Expand Down
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ target_link_libraries(tfslib PRIVATE
fmt::fmt
OpenSSL::Crypto
pugixml::pugixml
ZLIB::ZLIB
${CMAKE_THREAD_LIBS_INIT}
${LUA_LIBRARIES}
${MYSQL_CLIENT_LIBS}
Expand Down
8 changes: 6 additions & 2 deletions src/outputmessage.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ class OutputMessage : public NetworkMessage

void writeMessageLength() { add_header(info.length); }

void addCryptoHeader(checksumMode_t mode, uint32_t& sequence)
void addCryptoHeader(checksumMode_t mode)
{
if (mode == CHECKSUM_ADLER) {
add_header(adlerChecksum(&buffer[outputBufferStart], info.length));
} else if (mode == CHECKSUM_SEQUENCE) {
add_header(sequence++);
add_header(getSequenceId());
}

writeMessageLength();
Expand All @@ -48,6 +48,9 @@ class OutputMessage : public NetworkMessage
info.position += msgLen;
}

void setSequenceId(uint32_t sequence) { sequenceId = sequence; }
uint32_t getSequenceId() const { return sequenceId; }

private:
template <typename T>
void add_header(T add)
Expand All @@ -60,6 +63,7 @@ class OutputMessage : public NetworkMessage
}

MsgSize_t outputBufferStart = INITIAL_BUFFER_POSITION;
uint32_t sequenceId;
};

namespace tfs::net {
Expand Down
39 changes: 38 additions & 1 deletion src/protocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,20 @@ bool XTEA_decrypt(NetworkMessage& msg, const xtea::round_keys& key)
void Protocol::onSendMessage(const OutputMessage_ptr& msg)
{
if (!rawMessages) {
if (encryptionEnabled && checksumMode == CHECKSUM_SEQUENCE) {
uint32_t compressionChecksum = 0;
if (msg->getLength() >= 128 && deflateMessage(*msg)) {
compressionChecksum = 0x80000000;
}

msg->setSequenceId(compressionChecksum | getNextSequenceId());
}

msg->writeMessageLength();

if (encryptionEnabled) {
XTEA_encrypt(*msg, key);
msg->addCryptoHeader(checksumMode, sequenceNumber);
msg->addCryptoHeader(checksumMode);
}
}
}
Expand Down Expand Up @@ -86,6 +95,34 @@ bool Protocol::RSA_decrypt(NetworkMessage& msg)
return msg.getByte() == 0;
}

bool Protocol::deflateMessage(OutputMessage& msg)
{
static thread_local std::vector<uint8_t> buffer(NETWORKMESSAGE_MAXSIZE);
zstream.next_in = msg.getOutputBuffer();
zstream.avail_in = msg.getLength();
zstream.next_out = buffer.data();
zstream.avail_out = buffer.size();

const auto result = deflate(&zstream, Z_FINISH);
if (result != Z_OK && result != Z_STREAM_END) {
std::cout << "Error while deflating packet data error: " << (zstream.msg ? zstream.msg : "unknown")
<< std::endl;
return false;
}

const auto size = zstream.total_out;
if (size <= 0) {
std::cout << "Deflated packet data had invalid size: " << size
<< " error: " << (zstream.msg ? zstream.msg : "unknown") << std::endl;
return false;
}

msg.reset();
msg.addBytes(reinterpret_cast<const char*>(buffer.data()), size);

return true;
}

Connection::Address Protocol::getIP() const
{
if (auto connection = getConnection()) {
Expand Down
23 changes: 22 additions & 1 deletion src/protocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,17 @@
#include "connection.h"
#include "xtea.h"

#include <zlib.h>

class Protocol : public std::enable_shared_from_this<Protocol>
{
public:
explicit Protocol(Connection_ptr connection) : connection(connection) {}
explicit Protocol(Connection_ptr connection) : connection(connection)
{
if (deflateInit2(&zstream, 6, Z_DEFLATED, -15, 8, Z_DEFAULT_STRATEGY) != Z_OK) {
std::cout << "ZLIB initialization error: " << (zstream.msg ? zstream.msg : "unknown") << std::endl;
}
}
virtual ~Protocol() = default;

// non-copyable
Expand Down Expand Up @@ -42,6 +49,16 @@ class Protocol : public std::enable_shared_from_this<Protocol>
}
}

uint32_t getNextSequenceId()
{
const auto sequence = ++sequenceNumber;
if (sequenceNumber >= std::numeric_limits<int32_t>::max()) {
sequenceNumber = 0;
}

return sequence;
}

protected:
static constexpr size_t RSA_BUFFER_LENGTH = 128;

Expand All @@ -57,6 +74,8 @@ class Protocol : public std::enable_shared_from_this<Protocol>

static bool RSA_decrypt(NetworkMessage& msg);

bool deflateMessage(OutputMessage& msg);

void setRawMessages(bool value) { rawMessages = value; }

virtual void release() {}
Expand All @@ -72,6 +91,8 @@ class Protocol : public std::enable_shared_from_this<Protocol>
bool encryptionEnabled = false;
checksumMode_t checksumMode = CHECKSUM_ADLER;
bool rawMessages = false;

z_stream zstream{};
};

#endif // FS_PROTOCOL_H
25 changes: 7 additions & 18 deletions vcpkg.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,38 +14,27 @@
},
"libmariadb",
"openssl",
"pugixml"
"pugixml",
"zlib"
],
"features": {
"http": {
"description": "Enable HTTP support",
"dependencies": [
"boost-beast",
"boost-json"
]
"dependencies": ["boost-beast", "boost-json"]
},
"lua": {
"description": "Use Lua instead of LuaJIT",
"dependencies": [
"lua"
]
"dependencies": ["lua"]
},
"luajit": {
"description": "Use LuaJIT instead of Lua",
"dependencies": [
"luajit"
]
"dependencies": ["luajit"]
},
"unit-tests": {
"description": "Build unit tests",
"dependencies": [
"boost-test"
]
"dependencies": ["boost-test"]
}
},
"default-features": [
"lua",
"http"
],
"default-features": ["lua", "http"],
"builtin-baseline": "215a2535590f1f63788ac9bd2ed58ad15e6afdff"
}

0 comments on commit 259ee4f

Please sign in to comment.