diff --git a/src/luascript.cpp b/src/luascript.cpp index e7a5e88ef1..590936c318 100644 --- a/src/luascript.cpp +++ b/src/luascript.cpp @@ -2527,6 +2527,7 @@ void LuaScriptInterface::registerFunctions() registerMethod(L, "NetworkMessage", "addDouble", LuaScriptInterface::luaNetworkMessageAddDouble); registerMethod(L, "NetworkMessage", "addItem", LuaScriptInterface::luaNetworkMessageAddItem); registerMethod(L, "NetworkMessage", "addItemId", LuaScriptInterface::luaNetworkMessageAddItemId); + registerMethod(L, "NetworkMessage", "compose", LuaScriptInterface::luaNetworkMessageCompose); registerMethod(L, "NetworkMessage", "reset", LuaScriptInterface::luaNetworkMessageReset); registerMethod(L, "NetworkMessage", "seek", LuaScriptInterface::luaNetworkMessageSeek); @@ -6293,6 +6294,26 @@ int LuaScriptInterface::luaNetworkMessageAddItemId(lua_State* L) return 1; } +int LuaScriptInterface::luaNetworkMessageCompose(lua_State* L) +{ + // networkMessage:compose(otherMsg) + NetworkMessage* message = tfs::lua::getUserdata(L, 1); + if (!message) { + lua_pushnil(L); + return 1; + } + + NetworkMessage* msgToAdd = tfs::lua::getUserdata(L, 2); + if (!msgToAdd) { + lua_pushnil(L); + return 1; + } + + message->addNetworkMessage(*msgToAdd); + tfs::lua::pushBoolean(L, true); + return 1; +} + int LuaScriptInterface::luaNetworkMessageReset(lua_State* L) { // networkMessage:reset() diff --git a/src/luascript.h b/src/luascript.h index fa067cc988..e931f90d51 100644 --- a/src/luascript.h +++ b/src/luascript.h @@ -395,6 +395,7 @@ class LuaScriptInterface static int luaNetworkMessageAddDouble(lua_State* L); static int luaNetworkMessageAddItem(lua_State* L); static int luaNetworkMessageAddItemId(lua_State* L); + static int luaNetworkMessageCompose(lua_State* L); static int luaNetworkMessageReset(lua_State* L); static int luaNetworkMessageSeek(lua_State* L); diff --git a/src/networkmessage.cpp b/src/networkmessage.cpp index 319e1dd394..42b47504ba 100644 --- a/src/networkmessage.cpp +++ b/src/networkmessage.cpp @@ -10,6 +10,8 @@ #include +static constexpr size_t MAX_BODY_SIZE = 8192; + std::string NetworkMessage::getString(uint16_t stringLen /* = 0*/) { if (stringLen == 0) { @@ -42,7 +44,7 @@ void NetworkMessage::addString(std::string_view value) std::string latin1Str = boost::locale::conv::from_utf(value.data(), value.data() + value.size(), "ISO-8859-1", boost::locale::conv::skip); size_t stringLen = latin1Str.size(); - if (!canAdd(stringLen + 2) || stringLen > 8192) { + if (!canAdd(stringLen + 2) || stringLen > MAX_BODY_SIZE) { return; } @@ -61,7 +63,18 @@ void NetworkMessage::addDouble(double value, uint8_t precision /* = 2*/) void NetworkMessage::addBytes(const char* bytes, size_t size) { - if (!canAdd(size) || size > 8192) { + if (!canAdd(size) || size > MAX_BODY_SIZE) { + return; + } + + std::memcpy(buffer.data() + info.position, bytes, size); + info.position += size; + info.length += size; +} + +void NetworkMessage::addBytes(const uint8_t* bytes, size_t size) +{ + if (!canAdd(size) || size > MAX_BODY_SIZE) { return; } @@ -191,3 +204,12 @@ void NetworkMessage::addItem(const Item* item) } void NetworkMessage::addItemId(uint16_t itemId) { add(Item::items[itemId].clientId); } + +void NetworkMessage::addNetworkMessage(const NetworkMessage& networkMsg) +{ + if (!canAdd(networkMsg.getLength())) { + return; + } + + addBytes(networkMsg.getBuffer() + INITIAL_BUFFER_POSITION, networkMsg.getLength()); +} diff --git a/src/networkmessage.h b/src/networkmessage.h index 84898e19c7..9e1c4ec6bd 100644 --- a/src/networkmessage.h +++ b/src/networkmessage.h @@ -53,7 +53,14 @@ class NetworkMessage return buffer[info.position++]; } - uint8_t getPreviousByte() { return buffer[--info.position]; } + // Returns first element of body + uint8_t getPreviousByte() + { + if (info.position == INITIAL_BUFFER_POSITION) { + return buffer[INITIAL_BUFFER_POSITION]; + } + return buffer[--info.position]; + } template std::enable_if_t, T> get() noexcept @@ -100,6 +107,7 @@ class NetworkMessage } void addBytes(const char* bytes, size_t size); + void addBytes(const uint8_t* bytes, size_t size); void addPaddingBytes(size_t n); void addString(std::string_view value); @@ -111,6 +119,7 @@ class NetworkMessage void addItem(uint16_t id, uint8_t count); void addItem(const Item* item); void addItemId(uint16_t itemId); + void addNetworkMessage(const NetworkMessage& networkMsg); MsgSize_t getLength() const { return info.length; } diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt index 2e6d1b1a6d..c73ed23512 100644 --- a/src/tests/CMakeLists.txt +++ b/src/tests/CMakeLists.txt @@ -5,6 +5,7 @@ set(tests_SRC ${CMAKE_CURRENT_LIST_DIR}/test_rsa.cpp ${CMAKE_CURRENT_LIST_DIR}/test_sha1.cpp ${CMAKE_CURRENT_LIST_DIR}/test_xtea.cpp + ${CMAKE_CURRENT_LIST_DIR}/test_networkmessage.cpp ) foreach(test_src ${tests_SRC}) diff --git a/src/tests/test_networkmessage.cpp b/src/tests/test_networkmessage.cpp new file mode 100644 index 0000000000..ae155bfeeb --- /dev/null +++ b/src/tests/test_networkmessage.cpp @@ -0,0 +1,121 @@ +#define BOOST_TEST_MODULE networkmessage + +#include "../otpch.h" + +#include "../networkmessage.h" +#include "../position.h" + +#include + +BOOST_AUTO_TEST_CASE(test_networkmessage_reset) +{ + uint8_t expected = 12; + uint8_t actual = 0; + NetworkMessage msg{}; + msg.addByte(expected); + msg.setBufferPosition(0); + actual = msg.getByte(); + BOOST_TEST(msg.getLength() == 1); + BOOST_TEST(msg.getBufferPosition() == 1 + NetworkMessage::INITIAL_BUFFER_POSITION); + msg.reset(); + BOOST_TEST(msg.getLength() == 0); + BOOST_TEST(actual == expected); +} + +BOOST_AUTO_TEST_CASE(test_networkmessage_getPreviousByte) +{ + NetworkMessage msg{}; + msg.addByte(11); + msg.addByte(22); + msg.addByte(33); + msg.addByte(44); + BOOST_TEST(msg.getLength() == 4); + BOOST_TEST(msg.getPreviousByte() == 44); + BOOST_TEST(msg.getPreviousByte() == 33); + BOOST_TEST(msg.getPreviousByte() == 22); + BOOST_TEST(msg.getPreviousByte() == 11); + // overflow case + BOOST_TEST(msg.getPreviousByte() == 11); + BOOST_TEST(msg.getBufferPosition() == NetworkMessage::INITIAL_BUFFER_POSITION); +} + +BOOST_AUTO_TEST_CASE(test_networkmessage_add_get_template) +{ + NetworkMessage msg{}; + msg.add(std::numeric_limits::max()); + msg.add(std::numeric_limits::max()); + msg.add(std::numeric_limits::max()); + msg.setBufferPosition(0); + BOOST_TEST(msg.getLength() == 14); + BOOST_TEST(msg.get() == std::numeric_limits::max()); + BOOST_TEST(msg.get() == std::numeric_limits::max()); + BOOST_TEST(msg.get() == std::numeric_limits::max()); +} + +BOOST_AUTO_TEST_CASE(test_networkmessage_string) +{ + NetworkMessage msg{}; + msg.addString("test"); + msg.addString("Msg"); + msg.setBufferPosition(0); + BOOST_TEST(msg.getLength() == 11); + BOOST_TEST(msg.getString() == "test"); + BOOST_TEST(msg.getString() == "Msg"); +} + +BOOST_AUTO_TEST_CASE(test_networkmessage_position) +{ + Position position1(11, 22, 3); + Position position2(111, 222, 4); + NetworkMessage msg{}; + msg.addPosition(position1); + msg.addPosition(position2); + msg.setBufferPosition(0); + BOOST_TEST(msg.getLength() == 10); + BOOST_TEST(msg.getPosition() == position1); + BOOST_TEST(msg.getPosition() == position2); +} + +BOOST_AUTO_TEST_CASE(test_networkmessage_skip) +{ + NetworkMessage msg{}; + msg.addByte(1); + msg.addByte(2); + msg.addByte(5); + msg.setBufferPosition(0); + msg.skipBytes(2); + BOOST_TEST(msg.getLength() == 3); + BOOST_TEST(msg.getByte() == 5); +} + +BOOST_AUTO_TEST_CASE(test_networkmessage_byte) +{ + NetworkMessage msg{}; + msg.addByte(12); + msg.addByte(34); + msg.setBufferPosition(0); + BOOST_TEST(msg.getLength() == 2); + BOOST_TEST(msg.getByte() == 12); + BOOST_TEST(msg.getByte() == 34); +} + +BOOST_AUTO_TEST_CASE(test_networkmessage_bytes) +{ + uint8_t expected1[] = {1, 2, 3}; + NetworkMessage msg1{}; + msg1.addBytes(expected1, 3); + msg1.setBufferPosition(0); + BOOST_TEST(msg1.getLength() == 3); + BOOST_TEST(msg1.getByte() == expected1[0]); + BOOST_TEST(msg1.getByte() == expected1[1]); + BOOST_TEST(msg1.getByte() == expected1[2]); + + const char* expected2 = {"abc"}; + NetworkMessage msg2{}; + msg2.addBytes(expected2, 3); + msg2.setBufferPosition(0); + BOOST_TEST(msg2.getLength() == 3); + BOOST_TEST(msg2.getByte() == expected2[0]); + BOOST_TEST(msg2.getByte() == expected2[1]); + BOOST_TEST(msg2.getByte() == expected2[2]); +}