From 4962d6b8dbcc6ce3e62d81a4e6c0f0425e7baab7 Mon Sep 17 00:00:00 2001 From: Arkadiusz Szczepkowicz Date: Tue, 23 Jul 2024 15:10:39 +0200 Subject: [PATCH] #268: Add support for deserialization when called from the base type --- src/checkpoint/dispatch/dispatch.h | 50 +++++++++++- src/checkpoint/dispatch/dispatch.impl.h | 103 ++++++++++++++++++++++-- 2 files changed, 143 insertions(+), 10 deletions(-) diff --git a/src/checkpoint/dispatch/dispatch.h b/src/checkpoint/dispatch/dispatch.h index e5c6da1b..46e0a2a3 100644 --- a/src/checkpoint/dispatch/dispatch.h +++ b/src/checkpoint/dispatch/dispatch.h @@ -188,6 +188,42 @@ struct Standard { static SerialByteType* allocate(); }; +/** + * \struct Prefixed + * + * \brief Traversal for polymorphic types prefixed with the vrt::TypeIdx + */ +struct Prefixed { + /** + * \brief Traverse a \c target of type \c T recursively with a general \c + * TraverserT that gets applied to each element. + * Allows to traverse only part of the data. + * + * \param[in,out] target the target to traverse + * \param[in] len the len of the target. If > 1, \c target is an array + * \param[in] check_type the flag to control type validation + * \param[in] check_mem the flag to control memory validation + * \param[in] args the args to pass to the traverser for construction + * + * \return the traverser after traversal is complete + */ + template + static TraverserT traverse(T& target, SerialSizeType len, bool check_type, bool check_mem, Args&&... args); + + /** + * \brief Unpack \c T from packed byte-buffer \c mem + * + * \param[in] t_buf bytes holding a serialized \c T + * \param[in] check_type the flag to control type validation + * \param[in] check_mem the flag to control memory validation + * \param[in] args arguments to the unpacker's constructor + * + * \return a pointer to an unpacked \c T + */ + template + static T* unpack(T* t_buf, bool check_type = true, bool check_mem = true, Args&&... args); +}; + template buffer::ImplReturnType packBuffer( T& target, SerialSizeType size, BufferObtainFnType fn @@ -197,10 +233,20 @@ template inline void serializeArray(Serializer& s, T* array, SerialSizeType const len); template -buffer::ImplReturnType serializeType(T& target, BufferObtainFnType fn = nullptr); +typename std::enable_if::value && vrt::VirtualSerializeTraits::has_virtual_serialize, buffer::ImplReturnType>::type +serializeType(T& target, BufferObtainFnType fn = nullptr); + +template +typename std::enable_if::value || !vrt::VirtualSerializeTraits::has_virtual_serialize, buffer::ImplReturnType>::type +serializeType(T& target, BufferObtainFnType fn = nullptr); + +template +typename std::enable_if::value && vrt::VirtualSerializeTraits::has_virtual_serialize, T*>::type +deserializeType(SerialByteType* data, SerialByteType* allocBuf = nullptr); template -T* deserializeType(SerialByteType* data, SerialByteType* allocBuf = nullptr); +typename std::enable_if::value || !vrt::VirtualSerializeTraits::has_virtual_serialize, T*>::type +deserializeType(SerialByteType* data, SerialByteType* allocBuf = nullptr); template void deserializeType(InPlaceTag, SerialByteType* data, T* t); diff --git a/src/checkpoint/dispatch/dispatch.impl.h b/src/checkpoint/dispatch/dispatch.impl.h index 44aac6a2..d88e3b00 100644 --- a/src/checkpoint/dispatch/dispatch.impl.h +++ b/src/checkpoint/dispatch/dispatch.impl.h @@ -71,7 +71,7 @@ struct serialization_error : public std::runtime_error { }; template -TraverserT& withTypeIdx(TraverserT& t) { +TraverserT& withTypeIdx(TraverserT& t, bool check_type = true) { using CleanT = typename CleanType::CleanT; using DispatchType = typename TraverserT::template DispatcherType; @@ -87,9 +87,9 @@ TraverserT& withTypeIdx(TraverserT& t) { auto val = cleanType(&serTypeIdx); ap(t, val, len); - if ( - typeregistry::validateIndex(serTypeIdx) == false || - thisTypeIdx != serTypeIdx + if (check_type && + (typeregistry::validateIndex(serTypeIdx) == false || + thisTypeIdx != serTypeIdx) ) { auto const err = std::string("Unpacking wrong type, got=") + typeregistry::getTypeNameForIdx(thisTypeIdx) + @@ -105,7 +105,7 @@ TraverserT& withTypeIdx(TraverserT& t) { } template -TraverserT& withMemUsed(TraverserT& t, SerialSizeType len) { +TraverserT& withMemUsed(TraverserT& t, SerialSizeType len, bool check_mem = true) { using DispatchType = typename TraverserT::template DispatcherType; SerializerDispatch ap; @@ -120,7 +120,7 @@ TraverserT& withMemUsed(TraverserT& t, SerialSizeType len) { auto val = cleanType(&serMemUsed); ap(t, val, memUsedLen); - if (memUsed != serMemUsed) { + if (check_mem && memUsed != serMemUsed) { using CleanT = typename CleanType::CleanT; std::string msg = "For type '" + typeregistry::getTypeName() + "' serialization used " + std::to_string(serMemUsed) + @@ -133,6 +133,37 @@ TraverserT& withMemUsed(TraverserT& t, SerialSizeType len) { return t; } +template +TraverserT Prefixed::traverse(T& target, SerialSizeType len, bool check_type, bool check_mem, Args&&... args) { + using CleanT = typename CleanType::CleanT; + using DispatchType = + typename TraverserT::template DispatcherType; + + TraverserT t(std::forward(args)...); + + withTypeIdx(t, check_type); + + auto val = cleanType(&target); + SerializerDispatch ap; + + #if defined(SERIALIZATION_ERROR_CHECKING) + try { + ap(t, val, len); + } catch (serialization_error const& err) { + auto const depth = err.depth_ + 1; + auto const what = std::string(err.what()) + "\n#" + std::to_string(depth) + + " " + typeregistry::getTypeName(); + throw serialization_error(what, depth); + } + #else + ap(t, val, len); + #endif + + withMemUsed(t, 1, check_mem); + + return t; +} + template TraverserT& Traverse::with(T& target, TraverserT& t, SerialSizeType len) { using CleanT = typename CleanType::CleanT; @@ -221,6 +252,12 @@ T* Standard::unpack(T* t_buf, Args&&... args) { return t_buf; } +template +T* Prefixed::unpack(T* t_buf, bool check_type, bool check_mem, Args&&... args) { + Prefixed::traverse(*t_buf, 1, check_type, check_mem, std::forward(args)...); + return t_buf; +} + template T* Standard::construct(SerialByteType* mem) { return Traverse::reconstruct(mem); @@ -266,14 +303,17 @@ packBuffer(T& target, SerialSizeType size, BufferObtainFnType fn) { } template -buffer::ImplReturnType serializeType(T& target, BufferObtainFnType fn) { +typename std::enable_if::value || !vrt::VirtualSerializeTraits::has_virtual_serialize, buffer::ImplReturnType>::type +serializeType(T& target, BufferObtainFnType fn) { auto len = Standard::size(target); debug_checkpoint("serializeType: len=%ld\n", len); return packBuffer(target, len, fn); } template -T* deserializeType(SerialByteType* data, SerialByteType* allocBuf) { +typename std::enable_if::value + || !vrt::VirtualSerializeTraits::has_virtual_serialize, T*>::type +deserializeType(SerialByteType* data, SerialByteType* allocBuf) { auto mem = allocBuf ? allocBuf : Standard::allocate(); auto t_buf = std::unique_ptr(Standard::construct(mem)); T* traverser = @@ -282,11 +322,58 @@ T* deserializeType(SerialByteType* data, SerialByteType* allocBuf) { return traverser; } +// TODO: this also needs to be updated template void deserializeType(InPlaceTag, SerialByteType* data, T* t) { Standard::unpack>(t, data); } +template +struct PrefixedType { + explicit PrefixedType(T* target) : target_(target) { + prefix_ = target->_checkpointDynamicTypeIndex(); + } + + vrt::TypeIdx prefix_; + T* target_; + + template + void serialize(SerializerT& s) { + s | prefix_; + s | *target_; + } +}; + +template +typename std::enable_if::value && vrt::VirtualSerializeTraits::has_virtual_serialize, buffer::ImplReturnType>::type +serializeType(T& target, BufferObtainFnType fn) { + auto prefixed = PrefixedType(&target); + auto len = Standard::size, Sizer>(prefixed); + debug_checkpoint("serializeType: len=%ld\n", len); + return packBuffer>(prefixed, len, fn); +} + +template +typename std::enable_if::value + && vrt::VirtualSerializeTraits::has_virtual_serialize, T*>::type +deserializeType(SerialByteType* data, SerialByteType* allocBuf) { + using BaseType = vrt::checkpoint_base_type_t; + + auto prefix_mem = Standard::allocate(); + auto prefix_buf = std::unique_ptr(Standard::construct(prefix_mem)); + vrt::TypeIdx* prefix = + Prefixed::unpack>(prefix_buf.get(), false, false, data); + prefix_buf.release(); + + auto mem = allocBuf ? allocBuf : vrt::objregistry::allocateConcreteType(*prefix); + auto t_buf = vrt::objregistry::constructConcreteType(*prefix, mem); + auto prefixed = PrefixedType(t_buf); + + auto* traverser = + Prefixed::unpack>(&prefixed, false, true, data); + return static_cast(traverser->target_); +} + }} /* end namespace checkpoint::dispatch */ #endif /*INCLUDED_SRC_CHECKPOINT_DISPATCH_DISPATCH_IMPL_H*/