Skip to content

Commit

Permalink
#268: Add prefix validation to protect against requests with invalid …
Browse files Browse the repository at this point in the history
…types
  • Loading branch information
thearusable committed Sep 24, 2024
1 parent 62517d2 commit dfaf97f
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 9 deletions.
16 changes: 12 additions & 4 deletions src/checkpoint/dispatch/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,14 @@ struct Prefixed {
*/
template <typename T, typename UnpackerT, typename... Args>
static T* unpack(T* mem, bool check_type, bool check_mem, Args&&... args);

/**
* \brief Check if prefix is valid
*
* \param[in] prefix the prefix to be validated
*/
template <typename T>
static void validatePrefix(vrt::TypeIdx prefix);
};

template <typename T>
Expand All @@ -233,19 +241,19 @@ template <typename Serializer, typename T>
inline void serializeArray(Serializer& s, T* array, SerialSizeType const len);

template <typename T>
typename std::enable_if<std::is_class<T>::value && vrt::VirtualSerializeTraits<T>::has_virtual_serialize, buffer::ImplReturnType>::type
typename std::enable_if<vrt::VirtualSerializeTraits<T>::has_virtual_serialize, buffer::ImplReturnType>::type
serializeType(T& target, BufferObtainFnType fn = nullptr);

template <typename T>
typename std::enable_if<!std::is_class<T>::value || !vrt::VirtualSerializeTraits<T>::has_virtual_serialize, buffer::ImplReturnType>::type
typename std::enable_if<!vrt::VirtualSerializeTraits<T>::has_virtual_serialize, buffer::ImplReturnType>::type
serializeType(T& target, BufferObtainFnType fn = nullptr);

template <typename T>
typename std::enable_if<std::is_class<T>::value && vrt::VirtualSerializeTraits<T>::has_virtual_serialize, T*>::type
typename std::enable_if<vrt::VirtualSerializeTraits<T>::has_virtual_serialize, T*>::type
deserializeType(SerialByteType* data, SerialByteType* allocBuf = nullptr);

template <typename T>
typename std::enable_if<!std::is_class<T>::value || !vrt::VirtualSerializeTraits<T>::has_virtual_serialize, T*>::type
typename std::enable_if<!vrt::VirtualSerializeTraits<T>::has_virtual_serialize, T*>::type
deserializeType(SerialByteType* data, SerialByteType* allocBuf = nullptr);

template <typename T>
Expand Down
20 changes: 16 additions & 4 deletions src/checkpoint/dispatch/dispatch.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ packBuffer(T& target, SerialSizeType size, BufferObtainFnType fn) {

template <typename T>
typename std::enable_if<
!std::is_class<T>::value || !vrt::VirtualSerializeTraits<T>::has_virtual_serialize,
!vrt::VirtualSerializeTraits<T>::has_virtual_serialize,
buffer::ImplReturnType>::type
serializeType(T& target, BufferObtainFnType fn) {
auto len = Standard::size<T, Sizer>(target);
Expand All @@ -314,7 +314,7 @@ serializeType(T& target, BufferObtainFnType fn) {

template <typename T>
typename std::enable_if<
!std::is_class<T>::value || !vrt::VirtualSerializeTraits<T>::has_virtual_serialize,
!vrt::VirtualSerializeTraits<T>::has_virtual_serialize,
T*>::type
deserializeType(SerialByteType* data, SerialByteType* allocBuf) {
auto mem = allocBuf ? allocBuf : Standard::allocate<T>();
Expand Down Expand Up @@ -348,7 +348,7 @@ struct PrefixedType {

template <typename T>
typename std::enable_if<
std::is_class<T>::value && vrt::VirtualSerializeTraits<T>::has_virtual_serialize,
vrt::VirtualSerializeTraits<T>::has_virtual_serialize,
buffer::ImplReturnType>::type
serializeType(T& target, BufferObtainFnType fn) {
auto prefixed = PrefixedType(&target);
Expand All @@ -357,9 +357,19 @@ serializeType(T& target, BufferObtainFnType fn) {
return packBuffer<decltype(prefixed)>(prefixed, len, fn);
}

template <typename T>
void Prefixed::validatePrefix(vrt::TypeIdx prefix) {
if (!vrt::objregistry::isValidIdx<T>(prefix)) {
std::string const err = std::string("Unpacking invalid prefix type (") +
std::to_string(prefix) + std::string(") from object registry for type=") +
std::string(typeregistry::getTypeName<T>());
throw serialization_error(err);
}
}

template <typename T>
typename std::enable_if<
std::is_class<T>::value && vrt::VirtualSerializeTraits<T>::has_virtual_serialize,
vrt::VirtualSerializeTraits<T>::has_virtual_serialize,
T*>::type
deserializeType(SerialByteType* data, SerialByteType* allocBuf) {
using BaseType = vrt::checkpoint_base_type_t<T>;
Expand All @@ -371,6 +381,8 @@ deserializeType(SerialByteType* data, SerialByteType* allocBuf) {
Prefixed::unpack<vrt::TypeIdx, UnpackerBuffer<buffer::UserBuffer>>(prefix_buf.get(), false, false, data);
prefix_buf.release();

Prefixed::validatePrefix<BaseType>(*prefix);

// allocate memory based on the readed TypeIdx
auto mem = allocBuf ? allocBuf : vrt::objregistry::allocateConcreteType<BaseType>(*prefix);
auto t_buf = vrt::objregistry::constructConcreteType<BaseType>(*prefix, mem);
Expand Down
5 changes: 5 additions & 0 deletions src/checkpoint/dispatch/vrt/object_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,11 @@ inline auto getObjIdx(TypeIdx han) {
return getRegistry<T>().at(han).idx_;
}

template <typename T>
inline auto isValidIdx(TypeIdx han) {
return getRegistry<T>().size() > static_cast<std::size_t>(han);
}

template <typename T>
inline auto getSizeConcreteType(TypeIdx han) {
return getRegistry<T>().at(han).size_;
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_polymorphic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ void testPolymorphicTypes(int val) {
EXPECT_EQ(val, out->getVal());
}

TEST_F(TestPolymorphic, test_polumorphic_type) {
TEST_F(TestPolymorphic, test_polymorphic_type) {
testPolymorphicTypes<Derived2, Derived2>(5);
testPolymorphicTypes<Derived1, Derived2>(50);
testPolymorphicTypes<Base, Derived2>(500);
Expand Down

0 comments on commit dfaf97f

Please sign in to comment.