Skip to content

Commit

Permalink
#268: Add support for deserialization when called from the base type
Browse files Browse the repository at this point in the history
  • Loading branch information
thearusable committed Sep 24, 2024
1 parent 0fb2e1b commit 4962d6b
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 10 deletions.
50 changes: 48 additions & 2 deletions src/checkpoint/dispatch/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,42 @@ struct Standard {
static SerialByteType* allocate();
};

/**
* \struct Prefixed
*
* \brief Traversal for polymorphic types prefixed with the vrt::TypeIdx
*/
struct Prefixed {
/**
* \brief Traverse a \c target of type \c T recursively with a general \c
* TraverserT that gets applied to each element.
* Allows to traverse only part of the data.
*
* \param[in,out] target the target to traverse
* \param[in] len the len of the target. If > 1, \c target is an array
* \param[in] check_type the flag to control type validation
* \param[in] check_mem the flag to control memory validation
* \param[in] args the args to pass to the traverser for construction
*
* \return the traverser after traversal is complete
*/
template <typename T, typename TraverserT, typename... Args>
static TraverserT traverse(T& target, SerialSizeType len, bool check_type, bool check_mem, Args&&... args);

/**
* \brief Unpack \c T from packed byte-buffer \c mem
*
* \param[in] t_buf bytes holding a serialized \c T
* \param[in] check_type the flag to control type validation
* \param[in] check_mem the flag to control memory validation
* \param[in] args arguments to the unpacker's constructor
*
* \return a pointer to an unpacked \c T
*/
template <typename T, typename UnpackerT, typename... Args>
static T* unpack(T* t_buf, bool check_type = true, bool check_mem = true, Args&&... args);
};

template <typename T>
buffer::ImplReturnType packBuffer(
T& target, SerialSizeType size, BufferObtainFnType fn
Expand All @@ -197,10 +233,20 @@ template <typename Serializer, typename T>
inline void serializeArray(Serializer& s, T* array, SerialSizeType const len);

template <typename T>
buffer::ImplReturnType serializeType(T& target, BufferObtainFnType fn = nullptr);
typename std::enable_if<std::is_class<T>::value && vrt::VirtualSerializeTraits<T>::has_virtual_serialize, buffer::ImplReturnType>::type
serializeType(T& target, BufferObtainFnType fn = nullptr);

template <typename T>
typename std::enable_if<!std::is_class<T>::value || !vrt::VirtualSerializeTraits<T>::has_virtual_serialize, buffer::ImplReturnType>::type
serializeType(T& target, BufferObtainFnType fn = nullptr);

template <typename T>
typename std::enable_if<std::is_class<T>::value && vrt::VirtualSerializeTraits<T>::has_virtual_serialize, T*>::type
deserializeType(SerialByteType* data, SerialByteType* allocBuf = nullptr);

template <typename T>
T* deserializeType(SerialByteType* data, SerialByteType* allocBuf = nullptr);
typename std::enable_if<!std::is_class<T>::value || !vrt::VirtualSerializeTraits<T>::has_virtual_serialize, T*>::type
deserializeType(SerialByteType* data, SerialByteType* allocBuf = nullptr);

template <typename T>
void deserializeType(InPlaceTag, SerialByteType* data, T* t);
Expand Down
103 changes: 95 additions & 8 deletions src/checkpoint/dispatch/dispatch.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ struct serialization_error : public std::runtime_error {
};

template <typename T, typename TraverserT>
TraverserT& withTypeIdx(TraverserT& t) {
TraverserT& withTypeIdx(TraverserT& t, bool check_type = true) {
using CleanT = typename CleanType<typeregistry::DecodedIndex>::CleanT;
using DispatchType =
typename TraverserT::template DispatcherType<TraverserT, CleanT>;
Expand All @@ -87,9 +87,9 @@ TraverserT& withTypeIdx(TraverserT& t) {
auto val = cleanType(&serTypeIdx);
ap(t, val, len);

if (
typeregistry::validateIndex(serTypeIdx) == false ||
thisTypeIdx != serTypeIdx
if (check_type &&
(typeregistry::validateIndex(serTypeIdx) == false ||
thisTypeIdx != serTypeIdx)
) {
auto const err = std::string("Unpacking wrong type, got=") +
typeregistry::getTypeNameForIdx(thisTypeIdx) +
Expand All @@ -105,7 +105,7 @@ TraverserT& withTypeIdx(TraverserT& t) {
}

template <typename T, typename TraverserT>
TraverserT& withMemUsed(TraverserT& t, SerialSizeType len) {
TraverserT& withMemUsed(TraverserT& t, SerialSizeType len, bool check_mem = true) {
using DispatchType =
typename TraverserT::template DispatcherType<TraverserT, SerialSizeType>;
SerializerDispatch<TraverserT, SerialSizeType, DispatchType> ap;
Expand All @@ -120,7 +120,7 @@ TraverserT& withMemUsed(TraverserT& t, SerialSizeType len) {
auto val = cleanType(&serMemUsed);
ap(t, val, memUsedLen);

if (memUsed != serMemUsed) {
if (check_mem && memUsed != serMemUsed) {
using CleanT = typename CleanType<T>::CleanT;
std::string msg = "For type '" + typeregistry::getTypeName<CleanT>() +
"' serialization used " + std::to_string(serMemUsed) +
Expand All @@ -133,6 +133,37 @@ TraverserT& withMemUsed(TraverserT& t, SerialSizeType len) {
return t;
}

template <typename T, typename TraverserT, typename... Args>
TraverserT Prefixed::traverse(T& target, SerialSizeType len, bool check_type, bool check_mem, Args&&... args) {
using CleanT = typename CleanType<T>::CleanT;
using DispatchType =
typename TraverserT::template DispatcherType<TraverserT, CleanT>;

TraverserT t(std::forward<Args>(args)...);

withTypeIdx<CleanT>(t, check_type);

auto val = cleanType(&target);
SerializerDispatch<TraverserT, CleanT, DispatchType> ap;

#if defined(SERIALIZATION_ERROR_CHECKING)
try {
ap(t, val, len);
} catch (serialization_error const& err) {
auto const depth = err.depth_ + 1;
auto const what = std::string(err.what()) + "\n#" + std::to_string(depth) +
" " + typeregistry::getTypeName<T>();
throw serialization_error(what, depth);
}
#else
ap(t, val, len);
#endif

withMemUsed<CleanT>(t, 1, check_mem);

return t;
}

template <typename T, typename TraverserT>
TraverserT& Traverse::with(T& target, TraverserT& t, SerialSizeType len) {
using CleanT = typename CleanType<T>::CleanT;
Expand Down Expand Up @@ -221,6 +252,12 @@ T* Standard::unpack(T* t_buf, Args&&... args) {
return t_buf;
}

template <typename T, typename UnpackerT, typename... Args>
T* Prefixed::unpack(T* t_buf, bool check_type, bool check_mem, Args&&... args) {
Prefixed::traverse<T, UnpackerT>(*t_buf, 1, check_type, check_mem, std::forward<Args>(args)...);
return t_buf;
}

template <typename T>
T* Standard::construct(SerialByteType* mem) {
return Traverse::reconstruct<T>(mem);
Expand Down Expand Up @@ -266,14 +303,17 @@ packBuffer(T& target, SerialSizeType size, BufferObtainFnType fn) {
}

template <typename T>
buffer::ImplReturnType serializeType(T& target, BufferObtainFnType fn) {
typename std::enable_if<!std::is_class<T>::value || !vrt::VirtualSerializeTraits<T>::has_virtual_serialize, buffer::ImplReturnType>::type
serializeType(T& target, BufferObtainFnType fn) {
auto len = Standard::size<T, Sizer>(target);
debug_checkpoint("serializeType: len=%ld\n", len);
return packBuffer<T>(target, len, fn);
}

template <typename T>
T* deserializeType(SerialByteType* data, SerialByteType* allocBuf) {
typename std::enable_if<!std::is_class<T>::value
|| !vrt::VirtualSerializeTraits<T>::has_virtual_serialize, T*>::type
deserializeType(SerialByteType* data, SerialByteType* allocBuf) {
auto mem = allocBuf ? allocBuf : Standard::allocate<T>();
auto t_buf = std::unique_ptr<T>(Standard::construct<T>(mem));
T* traverser =
Expand All @@ -282,11 +322,58 @@ T* deserializeType(SerialByteType* data, SerialByteType* allocBuf) {
return traverser;
}

// TODO: this also needs to be updated
template <typename T>
void deserializeType(InPlaceTag, SerialByteType* data, T* t) {
Standard::unpack<T, UnpackerBuffer<buffer::UserBuffer>>(t, data);
}

template <typename T>
struct PrefixedType {
explicit PrefixedType(T* target) : target_(target) {
prefix_ = target->_checkpointDynamicTypeIndex();
}

vrt::TypeIdx prefix_;
T* target_;

template <typename SerializerT>
void serialize(SerializerT& s) {
s | prefix_;
s | *target_;
}
};

template <typename T>
typename std::enable_if<std::is_class<T>::value && vrt::VirtualSerializeTraits<T>::has_virtual_serialize, buffer::ImplReturnType>::type
serializeType(T& target, BufferObtainFnType fn) {
auto prefixed = PrefixedType(&target);
auto len = Standard::size<PrefixedType<T>, Sizer>(prefixed);
debug_checkpoint("serializeType: len=%ld\n", len);
return packBuffer<PrefixedType<T>>(prefixed, len, fn);
}

template <typename T>
typename std::enable_if<std::is_class<T>::value
&& vrt::VirtualSerializeTraits<T>::has_virtual_serialize, T*>::type
deserializeType(SerialByteType* data, SerialByteType* allocBuf) {
using BaseType = vrt::checkpoint_base_type_t<T>;

auto prefix_mem = Standard::allocate<vrt::TypeIdx>();
auto prefix_buf = std::unique_ptr<vrt::TypeIdx>(Standard::construct<vrt::TypeIdx>(prefix_mem));
vrt::TypeIdx* prefix =
Prefixed::unpack<vrt::TypeIdx, UnpackerBuffer<buffer::UserBuffer>>(prefix_buf.get(), false, false, data);
prefix_buf.release();

auto mem = allocBuf ? allocBuf : vrt::objregistry::allocateConcreteType<BaseType>(*prefix);
auto t_buf = vrt::objregistry::constructConcreteType<BaseType>(*prefix, mem);
auto prefixed = PrefixedType(t_buf);

auto* traverser =
Prefixed::unpack<decltype(prefixed), UnpackerBuffer<buffer::UserBuffer>>(&prefixed, false, true, data);
return static_cast<T*>(traverser->target_);
}

}} /* end namespace checkpoint::dispatch */

#endif /*INCLUDED_SRC_CHECKPOINT_DISPATCH_DISPATCH_IMPL_H*/

0 comments on commit 4962d6b

Please sign in to comment.