diff --git a/src/checkpoint/dispatch/dispatch.h b/src/checkpoint/dispatch/dispatch.h index 8470a0e1..91dba3d6 100644 --- a/src/checkpoint/dispatch/dispatch.h +++ b/src/checkpoint/dispatch/dispatch.h @@ -188,50 +188,6 @@ struct Standard { static SerialByteType* allocate(); }; -/** - * \struct Prefixed - * - * \brief Traversal for polymorphic types prefixed with the vrt::TypeIdx - */ -struct Prefixed { - /** - * \brief Traverse a \c target of type \c T recursively with a general \c - * TraverserT that gets applied to each element. - * Allows to traverse only part of the data. - * - * \param[in,out] target the target to traverse - * \param[in] len the len of the target. If > 1, \c target is an array - * \param[in] check_type the flag to control type validation - * \param[in] check_mem the flag to control memory validation - * \param[in] args the args to pass to the traverser for construction - * - * \return the traverser after traversal is complete - */ - template - static TraverserT traverse(T& target, SerialSizeType len, bool check_type, bool check_mem, Args&&... args); - - /** - * \brief Unpack \c T from packed byte-buffer \c mem - * - * \param[in] mem bytes holding a serialized \c T - * \param[in] check_type the flag to control type validation - * \param[in] check_mem the flag to control memory validation - * \param[in] args arguments to the unpacker's constructor - * - * \return a pointer to an unpacked \c T - */ - template - static T* unpack(T* mem, bool check_type, bool check_mem, Args&&... args); - - /** - * \brief Check if prefix is valid - * - * \param[in] prefix the prefix to be validated - */ - template - static void validatePrefix(vrt::TypeIdx prefix); -}; - template buffer::ImplReturnType packBuffer( T& target, SerialSizeType size, BufferObtainFnType fn diff --git a/src/checkpoint/dispatch/dispatch.impl.h b/src/checkpoint/dispatch/dispatch.impl.h index 35a1db3d..efa3d659 100644 --- a/src/checkpoint/dispatch/dispatch.impl.h +++ b/src/checkpoint/dispatch/dispatch.impl.h @@ -71,7 +71,7 @@ struct serialization_error : public std::runtime_error { }; template -TraverserT& withTypeIdx(TraverserT& t, bool check_type = true) { +TraverserT& withTypeIdx(TraverserT& t) { using CleanT = typename CleanType::CleanT; using DispatchType = typename TraverserT::template DispatcherType; @@ -87,9 +87,9 @@ TraverserT& withTypeIdx(TraverserT& t, bool check_type = true) { auto val = cleanType(&serTypeIdx); ap(t, val, len); - if (check_type && - (typeregistry::validateIndex(serTypeIdx) == false || - thisTypeIdx != serTypeIdx) + if ( + typeregistry::validateIndex(serTypeIdx) == false || + thisTypeIdx != serTypeIdx ) { auto const err = std::string("Unpacking wrong type, got=") + typeregistry::getTypeNameForIdx(thisTypeIdx) + @@ -105,7 +105,7 @@ TraverserT& withTypeIdx(TraverserT& t, bool check_type = true) { } template -TraverserT& withMemUsed(TraverserT& t, SerialSizeType len, bool check_mem = true) { +TraverserT& withMemUsed(TraverserT& t, SerialSizeType len) { using DispatchType = typename TraverserT::template DispatcherType; SerializerDispatch ap; @@ -120,7 +120,7 @@ TraverserT& withMemUsed(TraverserT& t, SerialSizeType len, bool check_mem = true auto val = cleanType(&serMemUsed); ap(t, val, memUsedLen); - if (check_mem && memUsed != serMemUsed) { + if (t.shouldValidateMemory() && memUsed != serMemUsed) { using CleanT = typename CleanType::CleanT; std::string msg = "For type '" + typeregistry::getTypeName() + "' serialization used " + std::to_string(serMemUsed) + @@ -133,37 +133,6 @@ TraverserT& withMemUsed(TraverserT& t, SerialSizeType len, bool check_mem = true return t; } -template -TraverserT Prefixed::traverse(T& target, SerialSizeType len, bool check_type, bool check_mem, Args&&... args) { - using CleanT = typename CleanType::CleanT; - using DispatchType = - typename TraverserT::template DispatcherType; - - TraverserT t(std::forward(args)...); - - withTypeIdx(t, check_type); - - auto val = cleanType(&target); - SerializerDispatch ap; - - #if defined(SERIALIZATION_ERROR_CHECKING) - try { - ap(t, val, len); - } catch (serialization_error const& err) { - auto const depth = err.depth_ + 1; - auto const what = std::string(err.what()) + "\n#" + std::to_string(depth) + - " " + typeregistry::getTypeName(); - throw serialization_error(what, depth); - } - #else - ap(t, val, len); - #endif - - withMemUsed(t, 1, check_mem); - - return t; -} - template TraverserT& Traverse::with(T& target, TraverserT& t, SerialSizeType len) { using CleanT = typename CleanType::CleanT; @@ -252,12 +221,6 @@ T* Standard::unpack(T* t_buf, Args&&... args) { return t_buf; } -template -T* Prefixed::unpack(T* t_buf, bool check_type, bool check_mem, Args&&... args) { - Prefixed::traverse(*t_buf, 1, check_type, check_mem, std::forward(args)...); - return t_buf; -} - template T* Standard::construct(SerialByteType* mem) { return Traverse::reconstruct(mem); @@ -313,9 +276,7 @@ serializeType(T& target, BufferObtainFnType fn) { } template -typename std::enable_if< - !vrt::VirtualSerializeTraits::has_virtual_serialize, - T*>::type +typename std::enable_if::has_virtual_serialize, T*>::type deserializeType(SerialByteType* data, SerialByteType* allocBuf) { auto mem = allocBuf ? allocBuf : Standard::allocate(); auto t_buf = std::unique_ptr(Standard::construct(mem)); @@ -351,14 +312,17 @@ typename std::enable_if< vrt::VirtualSerializeTraits::has_virtual_serialize, buffer::ImplReturnType>::type serializeType(T& target, BufferObtainFnType fn) { + using BaseType = vrt::checkpoint_base_type_t; + using PrefixedType = PrefixedType; + auto prefixed = PrefixedType(&target); - auto len = Standard::size(prefixed); + auto len = Standard::size(prefixed); debug_checkpoint("serializeType: len=%ld\n", len); - return packBuffer(prefixed, len, fn); + return packBuffer(prefixed, len, fn); } template -void Prefixed::validatePrefix(vrt::TypeIdx prefix) { +void validatePrefix(vrt::TypeIdx prefix) { if (!vrt::objregistry::isValidIdx(prefix)) { std::string const err = std::string("Unpacking invalid prefix type (") + std::to_string(prefix) + std::string(") from object registry for type=") + @@ -368,28 +332,28 @@ void Prefixed::validatePrefix(vrt::TypeIdx prefix) { } template -typename std::enable_if< - vrt::VirtualSerializeTraits::has_virtual_serialize, - T*>::type +typename std::enable_if::has_virtual_serialize, T*>::type deserializeType(SerialByteType* data, SerialByteType* allocBuf) { using BaseType = vrt::checkpoint_base_type_t; + using PrefixedType = PrefixedType; - auto prefix_mem = Standard::allocate(); - auto prefix_buf = std::unique_ptr(Standard::construct(prefix_mem)); - // Unpack TypeIdx, ignore checks for type and memory used - unpacking will only use a part of the data - vrt::TypeIdx* prefix = - Prefixed::unpack>(prefix_buf.get(), false, false, data); - prefix_buf.release(); + auto prefix_mem = allocBuf ? allocBuf : vrt::objregistry::allocateConcreteType(0); + auto prefix_buf = vrt::objregistry::constructConcreteType(0, prefix_mem); + auto prefix_struct = PrefixedType(prefix_buf); + // Disable memory check during first unpacking. + // Unpacking BaseType will always result in memory amount missmatch between serialization/deserialization + auto* prefix = + Standard::unpack>(&prefix_struct, data, false); + delete prefix_buf; - Prefixed::validatePrefix(*prefix); + validatePrefix(prefix->prefix_); // allocate memory based on the readed TypeIdx - auto mem = allocBuf ? allocBuf : vrt::objregistry::allocateConcreteType(*prefix); - auto t_buf = vrt::objregistry::constructConcreteType(*prefix, mem); + auto mem = allocBuf ? allocBuf : vrt::objregistry::allocateConcreteType(prefix->prefix_); + auto t_buf = vrt::objregistry::constructConcreteType(prefix->prefix_, mem); auto prefixed = PrefixedType(t_buf); - // Unpack PrefixedType, ignore checks for unpacked type and execute checks for memory used auto* traverser = - Prefixed::unpack>(&prefixed, false, true, data); + Standard::unpack>(&prefixed, data); return static_cast(traverser->target_); } diff --git a/src/checkpoint/serializers/base_serializer.h b/src/checkpoint/serializers/base_serializer.h index e64251cc..5bcd5904 100644 --- a/src/checkpoint/serializers/base_serializer.h +++ b/src/checkpoint/serializers/base_serializer.h @@ -203,6 +203,13 @@ struct BaseSerializer { */ void setVirtualDisabled(bool val) { virtual_disabled_ = val; } + /** + * \brief Check if used memory should be validated + * + * \return whether memory should be validated + */ + bool shouldValidateMemory() const { return true; } + protected: ModeType cur_mode_ = ModeType::Invalid; /**< The current mode */ bool virtual_disabled_ = false; /**< Virtual serialization disabled */ diff --git a/src/checkpoint/serializers/unpacker.h b/src/checkpoint/serializers/unpacker.h index 85e9a850..4c31f82b 100644 --- a/src/checkpoint/serializers/unpacker.h +++ b/src/checkpoint/serializers/unpacker.h @@ -54,7 +54,7 @@ template struct UnpackerBuffer : MemorySerializer { using BufferPtrType = std::unique_ptr; - explicit UnpackerBuffer(SerialByteType* buf); + explicit UnpackerBuffer(SerialByteType* buf, bool validate = true); template explicit UnpackerBuffer(Args&&... args); @@ -62,11 +62,15 @@ struct UnpackerBuffer : MemorySerializer { void contiguousBytes(void* ptr, SerialSizeType size, SerialSizeType num_elms); SerialSizeType usedBufferSize() const; + bool shouldValidateMemory() const; + private: // Size of the actually used memory (for error checking) SerialSizeType usedSize_ = 0; BufferPtrType buffer_ = nullptr; + + bool validate_memory_ = true; }; using Unpacker = UnpackerBuffer; diff --git a/src/checkpoint/serializers/unpacker.impl.h b/src/checkpoint/serializers/unpacker.impl.h index ad2429b2..73bab01e 100644 --- a/src/checkpoint/serializers/unpacker.impl.h +++ b/src/checkpoint/serializers/unpacker.impl.h @@ -54,9 +54,10 @@ namespace checkpoint { template -UnpackerBuffer::UnpackerBuffer(SerialByteType* buf) +UnpackerBuffer::UnpackerBuffer(SerialByteType* buf, bool validate) : MemorySerializer(ModeType::Unpacking), - buffer_(std::make_unique(buf, 0)) + buffer_(std::make_unique(buf, 0)), + validate_memory_(validate) { MemorySerializer::initializeBuffer(buffer_->getBuffer()); @@ -103,6 +104,11 @@ SerialSizeType UnpackerBuffer::usedBufferSize() const { return usedSize_; } +template +bool UnpackerBuffer::shouldValidateMemory() const { + return validate_memory_; +} + } /* end namespace checkpoint */ #endif /*INCLUDED_SRC_CHECKPOINT_SERIALIZERS_UNPACKER_IMPL_H*/