Skip to content

Commit

Permalink
#268: Remove validate_memory_ member from UnpackerBuffer
Browse files Browse the repository at this point in the history
  • Loading branch information
thearusable committed Sep 30, 2024
1 parent 148b93c commit 03df7fa
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 43 deletions.
46 changes: 25 additions & 21 deletions src/checkpoint/dispatch/dispatch.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ TraverserT& withMemUsed(TraverserT& t, SerialSizeType len) {
auto val = cleanType(&serMemUsed);
ap(t, val, memUsedLen);

if (t.shouldValidateMemory() && memUsed != serMemUsed) {
if (memUsed != serMemUsed) {
using CleanT = typename CleanType<T>::CleanT;
std::string msg = "For type '" + typeregistry::getTypeName<CleanT>() +
"' serialization used " + std::to_string(serMemUsed) +
Expand Down Expand Up @@ -291,46 +291,50 @@ void deserializeType(InPlaceTag, SerialByteType* data, T* t) {
Standard::unpack<T, UnpackerBuffer<buffer::UserBuffer>>(t, data);
}

template <typename T>
void validatePrefix(vrt::TypeIdx prefix) {
if (!vrt::objregistry::isValidIdx<T>(prefix)) {
std::string const err = std::string("Unpacking invalid prefix type (") +
std::to_string(prefix) + std::string(") from object registry for type=") +
std::string(typeregistry::getTypeName<T>());
throw serialization_error(err);
}
}

template <typename T>
struct PrefixedType {
using BaseType = vrt::checkpoint_base_type_t<T>;

explicit PrefixedType(T* target) : target_(target) {
// Create PrefixedType for serialization purposes
explicit PrefixedType(BaseType* target) : target_(target) {
prefix_ = target->_checkpointDynamicTypeIndex();
}

explicit PrefixedType(SerialByteType* allocBuf)
: unpack_buf_(allocBuf) {
}

vrt::TypeIdx prefix_ = 0;
T* target_ = nullptr;
SerialByteType* unpack_buf_ = nullptr;
// Create PrefixedType for deserialization purposes
explicit PrefixedType(SerialByteType* allocBuf) : unpack_buf_(allocBuf) { }

template <typename SerializerT>
void serialize(SerializerT& s) {
s | prefix_;

// Determine the correct type and allocate memory
if (s.isUnpacking()) {
validatePrefix<BaseType>(prefix_);
validatePrefix(prefix_);

auto mem = unpack_buf_ ? unpack_buf_ : vrt::objregistry::allocateConcreteType<BaseType>(prefix_);
target_ = vrt::objregistry::constructConcreteType<BaseType>(prefix_, mem);
}

s | *target_;
}

BaseType* getTarget() const {
return target_;
}

private:
void validatePrefix(vrt::TypeIdx prefix) {
if (!vrt::objregistry::isValidIdx<BaseType>(prefix)) {
std::string const err = std::string("Unpacking invalid prefix type (") +
std::to_string(prefix) + std::string(") from object registry for type=") +
std::string(typeregistry::getTypeName<BaseType>());
throw serialization_error(err);
}
}

vrt::TypeIdx prefix_ = 0;
BaseType* target_ = nullptr;
SerialByteType* unpack_buf_ = nullptr;
};

template <typename T>
Expand All @@ -355,7 +359,7 @@ deserializeType(SerialByteType* data, SerialByteType* allocBuf) {

auto prefixed = PrefixedType(allocBuf);
auto* traverser = Standard::unpack<PrefixedType, UnpackerBuffer<buffer::UserBuffer>>(&prefixed, data);
return static_cast<T*>(traverser->target_);
return static_cast<T*>(traverser->getTarget());
}

}} /* end namespace checkpoint::dispatch */
Expand Down
7 changes: 0 additions & 7 deletions src/checkpoint/serializers/base_serializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,13 +203,6 @@ struct BaseSerializer {
*/
void setVirtualDisabled(bool val) { virtual_disabled_ = val; }

/**
* \brief Check if used memory should be validated
*
* \return whether memory should be validated
*/
bool shouldValidateMemory() const { return true; }

protected:
ModeType cur_mode_ = ModeType::Invalid; /**< The current mode */
bool virtual_disabled_ = false; /**< Virtual serialization disabled */
Expand Down
6 changes: 1 addition & 5 deletions src/checkpoint/serializers/unpacker.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,23 +54,19 @@ template <typename BufferT>
struct UnpackerBuffer : MemorySerializer {
using BufferPtrType = std::unique_ptr<BufferT>;

explicit UnpackerBuffer(SerialByteType* buf, bool validate = true);
explicit UnpackerBuffer(SerialByteType* buf);

template <typename... Args>
explicit UnpackerBuffer(Args&&... args);

void contiguousBytes(void* ptr, SerialSizeType size, SerialSizeType num_elms);
SerialSizeType usedBufferSize() const;

bool shouldValidateMemory() const;

private:
// Size of the actually used memory (for error checking)
SerialSizeType usedSize_ = 0;

BufferPtrType buffer_ = nullptr;

bool validate_memory_ = true;
};

using Unpacker = UnpackerBuffer<buffer::UserBuffer>;
Expand Down
13 changes: 3 additions & 10 deletions src/checkpoint/serializers/unpacker.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,9 @@
namespace checkpoint {

template <typename BufferT>
UnpackerBuffer<BufferT>::UnpackerBuffer(SerialByteType* buf, bool validate)
UnpackerBuffer<BufferT>::UnpackerBuffer(SerialByteType* buf)
: MemorySerializer(ModeType::Unpacking),
buffer_(std::make_unique<BufferT>(buf, 0)),
validate_memory_(validate)
buffer_(std::make_unique<BufferT>(buf, 0))
{
MemorySerializer::initializeBuffer(buffer_->getBuffer());

Expand All @@ -72,8 +71,7 @@ template <typename BufferT>
template <typename... Args>
UnpackerBuffer<BufferT>::UnpackerBuffer(Args&&... args)
: MemorySerializer(ModeType::Unpacking),
buffer_(std::make_unique<BufferT>(std::forward<Args>(args)...)),
validate_memory_(true)
buffer_(std::make_unique<BufferT>(std::forward<Args>(args)...))
{
MemorySerializer::initializeBuffer(buffer_->getBuffer());

Expand Down Expand Up @@ -105,11 +103,6 @@ SerialSizeType UnpackerBuffer<BufferT>::usedBufferSize() const {
return usedSize_;
}

template <typename BufferT>
bool UnpackerBuffer<BufferT>::shouldValidateMemory() const {
return validate_memory_;
}

} /* end namespace checkpoint */

#endif /*INCLUDED_SRC_CHECKPOINT_SERIALIZERS_UNPACKER_IMPL_H*/

0 comments on commit 03df7fa

Please sign in to comment.