Skip to content

Commit

Permalink
Use UniquePtr in VmtCopy class
Browse files Browse the repository at this point in the history
  • Loading branch information
danielkrupinski committed Aug 1, 2023
1 parent d89e33b commit e0f1b44
Showing 1 changed file with 6 additions and 19 deletions.
25 changes: 6 additions & 19 deletions Source/Vmt/VmtCopy.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,23 @@
#include <cstdint>
#include <cstddef>

#include <MemoryAllocation/MemoryAllocator.h>
#include <MemoryAllocation/UniquePtr.h>
#include <Platform/TypeInfoPrecedingVmt.h>
#include "VmtLength.h"

class VmtCopy {
public:
VmtCopy(std::uintptr_t* vmt, VmtLength length) noexcept
: originalVmt{ vmt }
, length{ static_cast<std::size_t>(length) }
, replacementVmtWithTypeInfo{ allocateReplacementVmtWithTypeInfo() }
, replacementVmtWithTypeInfo{ mem::makeUniqueForOverwrite<std::uintptr_t[]>(static_cast<std::size_t>(length) + platform::lengthOfTypeInfoPrecedingVmt) }
{
copyOriginalVmt();
}

[[nodiscard]] std::uintptr_t* getReplacementVmt() const noexcept
{
if (replacementVmtWithTypeInfo) [[likely]]
return replacementVmtWithTypeInfo + platform::lengthOfTypeInfoPrecedingVmt;
return replacementVmtWithTypeInfo.get() + platform::lengthOfTypeInfoPrecedingVmt;
return nullptr;
}

Expand All @@ -30,30 +29,18 @@ class VmtCopy {
return originalVmt;
}

~VmtCopy() noexcept
{
if (replacementVmtWithTypeInfo)
MemoryAllocator::deallocate(reinterpret_cast<std::byte*>(replacementVmtWithTypeInfo), MemoryAllocator::memoryFor<std::uintptr_t[]>(lengthWithTypeInfo()));
}

private:
[[nodiscard]] std::uintptr_t* allocateReplacementVmtWithTypeInfo() const noexcept
{
return new (MemoryAllocator::allocate(MemoryAllocator::memoryFor<std::uintptr_t[]>(lengthWithTypeInfo()))) std::uintptr_t[lengthWithTypeInfo()];
}

void copyOriginalVmt() const noexcept
{
if (replacementVmtWithTypeInfo) [[likely]]
std::copy_n(originalVmt - platform::lengthOfTypeInfoPrecedingVmt, lengthWithTypeInfo(), replacementVmtWithTypeInfo);
std::copy_n(originalVmt - platform::lengthOfTypeInfoPrecedingVmt, lengthWithTypeInfo(), replacementVmtWithTypeInfo.get());
}

[[nodiscard]] std::size_t lengthWithTypeInfo() const noexcept
{
return length + platform::lengthOfTypeInfoPrecedingVmt;
return replacementVmtWithTypeInfo.get_deleter().getNumberOfElements();
}

std::uintptr_t* originalVmt;
std::size_t length;
std::uintptr_t* replacementVmtWithTypeInfo;
UniquePtr<std::uintptr_t[]> replacementVmtWithTypeInfo;
};

0 comments on commit e0f1b44

Please sign in to comment.