Skip to content

Commit

Permalink
Add chunking of binaries when writing with msgpack (#2068)
Browse files Browse the repository at this point in the history
  • Loading branch information
pfultz2 authored Sep 14, 2023
1 parent fbd12bd commit f50ba41
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 10 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ jobs:
mkdir build
cd build
CXX=/opt/rocm/llvm/bin/clang++ CC=/opt/rocm/llvm/bin/clang cmake \
-DMIGRAPHX_DISABLE_LARGE_BUFFER_TESTS=On \
-DBUILD_DEV=On \
-DCMAKE_CXX_COMPILER_LAUNCHER=/usr/local/bin/ccache \
-DCMAKE_C_COMPILER_LAUNCHER=/usr/local/bin/ccache \
Expand Down Expand Up @@ -365,6 +366,7 @@ jobs:
rbuild build -d cget -s gh -T check \
-DCMAKE_BUILD_TYPE=${{matrix.configuration}} \
-DMIGRAPHX_ENABLE_PYTHON=${{matrix.configuration == 'release' && 'On' || 'Off'}} \
-DMIGRAPHX_DISABLE_LARGE_BUFFER_TESTS=On \
-DBUILD_DEV=On \
-DCMAKE_CXX_FLAGS_DEBUG="-g1 -Os -fdebug-prefix-map=$PWD=. -fdebug-types-section -fno-omit-frame-pointer -fsanitize=undefined -fno-sanitize-recover=undefined" \
-DCMAKE_CXX_FLAGS_CODECOV="-g1 -Og -fdebug-prefix-map=$PWD=. -fdebug-types-section -fprofile-arcs -ftest-coverage -fno-omit-frame-pointer" \
Expand Down Expand Up @@ -481,6 +483,7 @@ jobs:
rbuild build -d cget -s gh -T check \
-DCMAKE_BUILD_TYPE=${{matrix.configuration}} \
-DMIGRAPHX_ENABLE_PYTHON=${{matrix.configuration == 'release' && 'On' || 'Off'}} \
-DMIGRAPHX_DISABLE_LARGE_BUFFER_TESTS=On \
-DBUILD_DEV=On \
-DCMAKE_CXX_FLAGS_DEBUG="-g1 -Os -fdebug-prefix-map=$PWD=. -fdebug-types-section -fno-omit-frame-pointer -fsanitize=undefined -fno-sanitize-recover=undefined" \
-DCMAKE_CXX_FLAGS_CODECOV="-g1 -Og -fdebug-prefix-map=$PWD=. -fdebug-types-section -fprofile-arcs -ftest-coverage -fno-omit-frame-pointer" \
Expand Down
64 changes: 56 additions & 8 deletions src/msgpack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,33 @@
#include <migraphx/serialize.hpp>
#include <msgpack.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

/// Maximum number of bytes placed into a single msgpack `bin` chunk.
/// One less than the largest 32-bit value: leave an extra byte for error checking.
constexpr std::size_t msgpack_size_limit = std::numeric_limits<uint32_t>::max() - 1;

/// Number of chunks that msgpack_chunk_for_each() produces for a range.
///
/// @param r  Any range exposing `size()`.
/// @return   ceil(size / msgpack_size_limit), but never less than 1: an empty
///           range is still emitted as a single empty chunk.
template <class Range>
std::size_t msgpack_chunk_size(const Range& r)
{
    const std::size_t n = r.size();
    // `n - 1` is unsigned and would wrap around for an empty range, yielding a
    // bogus chunk count that does not match the single empty chunk that
    // msgpack_chunk_for_each() actually emits.
    if(n == 0)
        return 1;
    return 1 + (n - 1) / msgpack_size_limit;
}

/// Split [start, last) into consecutive sub-ranges of at most
/// msgpack_size_limit elements and invoke `f(chunk_begin, chunk_end)` on each.
///
/// `f` is always called at least once; for an empty input it receives an empty
/// sub-range. The number of calls equals msgpack_chunk_size() of the range.
///
/// @param start  Iterator to the first element.
/// @param last   Iterator one past the final element (must be reachable from start).
/// @param f      Callable taking two iterators delimiting one chunk.
template <class Iterator, class F>
void msgpack_chunk_for_each(Iterator start, Iterator last, F f)
{
    // The distance is non-negative for a valid range; the cast only makes the
    // signed -> unsigned conversion of the comparison explicit.
    while(static_cast<std::size_t>(std::distance(start, last)) > msgpack_size_limit)
    {
        auto next = std::next(start, msgpack_size_limit);
        f(start, next);
        start = next;
    }
    // Remaining tail (possibly empty when the input itself was empty).
    f(start, last);
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

namespace msgpack {
MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
{
Expand Down Expand Up @@ -63,16 +90,31 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
break;
}
case msgpack::type::BIN: {
// For backwards compatibility
v = migraphx::value::binary{o.via.bin.ptr, o.via.bin.size};
break;
}
case msgpack::type::ARRAY: {
migraphx::value r = migraphx::value::array{};
std::for_each(
o.via.array.ptr,
o.via.array.ptr + o.via.array.size,
[&](const msgpack::object& so) { r.push_back(so.as<migraphx::value>()); });
v = r;
if(o.via.array.size != 0 and o.via.array.ptr->type == msgpack::type::BIN)
{
auto bin = migraphx::value::binary{};
std::for_each(
o.via.array.ptr,
o.via.array.ptr + o.via.array.size,
[&](const msgpack::object& so) {
bin.insert(bin.end(), so.via.bin.ptr, so.via.bin.ptr + so.via.bin.size);
});
v = bin;
}
else
{
migraphx::value r = migraphx::value::array{};
std::for_each(
o.via.array.ptr,
o.via.array.ptr + o.via.array.size,
[&](const msgpack::object& so) { r.push_back(so.as<migraphx::value>()); });
v = r;
}
break;
}
case msgpack::type::MAP: {
Expand Down Expand Up @@ -102,8 +144,12 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
{
const auto* data = reinterpret_cast<const char*>(x.data());
auto size = x.size();
o.pack_bin(size);
o.pack_bin_body(data, size);
o.pack_array(migraphx::msgpack_chunk_size(x));
migraphx::msgpack_chunk_for_each(
data, data + size, [&](const char* start, const char* last) {
o.pack_bin(last - start);
o.pack_bin_body(start, last - start);
});
return o;
}
};
Expand All @@ -129,6 +175,8 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
o.pack_array(0);
return;
}
if(v.size() > migraphx::msgpack_size_limit)
MIGRAPHX_THROW("Size is too large for msgpack");
if(not v.front().get_key().empty())
{
o.pack_map(v.size());
Expand Down
2 changes: 1 addition & 1 deletion src/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,7 @@ std::string get_migraphx_version()
program file version is for the data structure or format of the MXR file. Version should be bumped
if any changes occur to the format of the MXR file.
*/
const int program_file_version = 6;
const int program_file_version = 7;

value program::to_value() const
{
Expand Down
5 changes: 5 additions & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ set(CTEST_PARALLEL_LEVEL ${N} CACHE STRING "CTest parallel level")
add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -j ${CTEST_PARALLEL_LEVEL} -C ${CMAKE_CFG_INTDIR} --timeout 5000)
add_custom_target(tests)

# Cache option (default Off). When turned on, MIGRAPHX_DISABLE_LARGE_BUFFER_TESTS
# is passed to the compiler as a preprocessor definition so that test sources
# can compile out sections guarded by `#ifndef MIGRAPHX_DISABLE_LARGE_BUFFER_TESTS`.
set(MIGRAPHX_DISABLE_LARGE_BUFFER_TESTS Off CACHE BOOL "")
if(MIGRAPHX_DISABLE_LARGE_BUFFER_TESTS)
add_compile_definitions(MIGRAPHX_DISABLE_LARGE_BUFFER_TESTS)
endif()

find_program(MIGRAPHX_GDB gdb)

if(MIGRAPHX_GDB)
Expand Down
73 changes: 72 additions & 1 deletion test/msgpack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,37 @@
#include <migraphx/value.hpp>
#include <msgpack.hpp>
#include <map>
#include <numeric>
#include "test.hpp"

// Generic reference writer: serializes `src` to `os` directly with msgpack::pack.
// The condition passed to MIGRAPHX_REQUIRES excludes std::vector<std::uint8_t> and
// types derived from it, which are handled by the non-template overload instead
// (presumably MIGRAPHX_REQUIRES is an enable_if-style constraint -- confirm in its header).
template <class T, MIGRAPHX_REQUIRES(not std::is_base_of<std::vector<std::uint8_t>, T>{})>
void write_msgpack(std::ostream& os, const T& src)
{
msgpack::pack(os, src);
}
// Reference writer for byte vectors: the bytes are written as a list of chunks,
// each holding at most `limit` bytes, rather than as one byte vector.
// Inputs that fit within the limit become a single-chunk list; larger inputs
// are split into exactly two chunks.
void write_msgpack(std::ostream& os, const std::vector<std::uint8_t>& src)
{
    const auto limit = std::numeric_limits<uint32_t>::max() - 1;
    std::vector<std::vector<std::uint8_t>> chunks;
    if(src.size() <= limit)
    {
        // Fits in one chunk: the list holds a single copy of the input.
        chunks.push_back(src);
    }
    else
    {
        const auto split = src.begin() + limit;
        // Only test two chunks
        assert(std::distance(split, src.end()) < limit);
        chunks.emplace_back(src.begin(), split);
        chunks.emplace_back(split, src.end());
    }
    write_msgpack(os, chunks);
}

template <class T>
std::vector<char> msgpack_buffer(const T& src)
{
std::stringstream buffer;
msgpack::pack(buffer, src);
write_msgpack(buffer, src);
buffer.seekg(0);
std::string str = buffer.str();
return std::vector<char>(str.data(), str.data() + str.size()); // NOLINT
Expand Down Expand Up @@ -147,4 +171,51 @@ TEST_CASE(test_msgpack_array_class)
EXPECT(migraphx::from_msgpack(buffer) == v);
}

// Checks that a small binary value packs to the same bytes as the reference
// writer (msgpack_buffer) and that unpacking those bytes yields the original.
TEST_CASE(test_msgpack_binary)
{
    migraphx::value::binary payload{64};
    // Fill with increasing values starting at 1 so every byte is distinguishable.
    std::iota(payload.begin(), payload.end(), 1);
    auto packed = migraphx::to_msgpack(payload);
    EXPECT(packed == msgpack_buffer(payload));
    EXPECT(migraphx::from_msgpack(packed) == payload);
}

#ifndef MIGRAPHX_DISABLE_LARGE_BUFFER_TESTS
// Round-trips a binary value of 4 * 1024^3 + 2 bytes through to_msgpack and
// from_msgpack. The first half is filled with `fill_value` and the second half
// with `fill_value + 1`, and both halves are verified after unpacking.
TEST_CASE(test_msgpack_large_binary1)
{
    const std::size_t n    = 4LL * 1024 * 1024 * 1024 + 2;
    const char fill_value  = 2;
    const std::size_t half = n / 2;
    migraphx::value v;
    {
        // Nested scopes end the lifetime of each large buffer as soon as it is
        // no longer needed.
        std::vector<char> packed;
        {
            migraphx::value::binary bin{n};
            const auto mid = bin.begin() + half;
            std::fill(bin.begin(), mid, fill_value);
            std::fill(mid, bin.end(), fill_value + 1);
            packed = migraphx::to_msgpack(std::move(bin));
        }
        v = migraphx::from_msgpack(packed);
    }
    EXPECT(v.is_binary());
    EXPECT(v.get_binary().size() == n);
    // First half must hold fill_value ...
    EXPECT(std::all_of(v.get_binary().begin(), v.get_binary().begin() + half, [](auto c) {
        return c == fill_value;
    }));
    // ... and the second half fill_value + 1.
    EXPECT(std::all_of(v.get_binary().begin() + half, v.get_binary().end(), [](auto c) {
        return c == fill_value + 1;
    }));
}

// Packs a binary value of 4 * 1024^3 + 2 bytes holding a repeating byte
// pattern and compares the result with the reference writer (msgpack_buffer).
TEST_CASE(test_msgpack_binary2)
{
    const std::size_t n = 4LL * 1024 * 1024 * 1024 + 2;
    migraphx::value::binary bin{n};
    // Pattern 1, 2, ..., 255, 0, 1, ... (running counter modulo 256).
    std::size_t counter = 0;
    std::generate(bin.begin(), bin.end(), [&] { return ++counter % 256; });
    EXPECT(migraphx::to_msgpack(bin) == msgpack_buffer(bin));
}
#endif
// Test executable entry point: forwards the command-line arguments to test::run.
int main(int argc, const char* argv[]) { test::run(argc, argv); }

0 comments on commit f50ba41

Please sign in to comment.