diff --git a/.gitignore b/.gitignore index df22783..b06fc71 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,8 @@ cmake_install.cmake vcpkg_installed/ CMakeFiles/ +.vscode/launch.json + *.exe *.dll diff --git a/.gitmodules b/.gitmodules index 6489456..52e4ba3 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,3 @@ [submodule "binaryninja-api"] path = binaryninja-api - url = https://github.com/Vector35/binaryninja-api + url = https://github.com/Vector35/binaryninja-api \ No newline at end of file diff --git a/README.md b/README.md index f4113c1..bf1ab75 100644 --- a/README.md +++ b/README.md @@ -7,91 +7,85 @@ Parses and symbolizes MSVC RTTI information in [Binary Ninja]. Arguably the most import function of symbolizing RTTI information is the virtual function tables. The listing below is the symbolized view of `simple.cpp` (found in test\bins). ```c -void* data_140010320 = ParentA_objLocator // ParentA -struct ParentA_vfTable = +void* data_140010320 = ParentA::`RTTI Complete Object Locator +struct ParentA::VTable ParentA::`vftable = { void* (* const vFunc_0)(void* arg1, int32_t arg2) = ParentB::vFunc_0 - int64_t (* const vFunc_1)() = ParentA::vFunc_1 - int64_t (* const vFunc_2)() = ParentA::vFunc_2 + int64_t (* const vFunc_1)() __pure = ParentA::vFunc_1 + int64_t (* const vFunc_2)() __pure = ParentA::vFunc_2 } -void* data_140010340 = ParentB_objLocator // ParentB -struct ParentB_vfTable = +void* data_140010340 = ParentB::`RTTI Complete Object Locator +struct ParentB::VTable ParentB::`vftable = { void* (* const vFunc_0)(void* arg1, int32_t arg2) = ParentB::vFunc_0 - int64_t (* const vFunc_1)() = ParentB::vFunc_1 + int64_t (* const vFunc_1)() __pure = ParentB::vFunc_1 } -void* data_140010358 = SomeClass_objLocator // SomeClass : ParentA, ParentB -struct SomeClass_vfTable = +void* data_140010358 = SomeClass::`RTTI Complete Object Locator +struct SomeClass::VTable SomeClass::`vftable = { void* (* const vFunc_0)(void* arg1, int32_t arg2) = SomeClass::vFunc_0 - 
int64_t (* const vFunc_1)() = ParentA::vFunc_1 - int64_t (* const vFunc_2)() = ParentA::vFunc_2 - int64_t (* const vFunc_3)() = SomeClass::vFunc_3 + int64_t (* const vFunc_1)() __pure = ParentA::vFunc_1 + int64_t (* const vFunc_2)() __pure = ParentA::vFunc_2 + int64_t (* const vFunc_3)() __pure = SomeClass::vFunc_3 } -void* data_140010380 = SomeClass_objLocator // __offset(8) SomeClass : ParentA, ParentB -struct SomeClass_vfTable = +void* data_140010380 = SomeClass::`RTTI Complete Object Locator{for `ParentB} +struct ParentB::VTable SomeClass::`vftable{for `ParentB} = { - int64_t (* const vFunc_0)(int64_t arg1, int32_t arg2) = SomeClass::vFunc_0 - int64_t (* const vFunc_1)() = ParentB::vFunc_1 -} -void* data_140010398 = type_info_objLocator // type_info -struct type_info_vfTable = -{ - void*** (* const vFunc_0)(void*** arg1, char arg2) = type_info::vFunc_0 -} -void* data_1400103a8 = std::exception_objLocator // std::exception -struct std::exception_vfTable = -{ - void*** (* const vFunc_0)(void*** arg1, char arg2) = std::exception::vFunc_0 - int64_t (* const vFunc_1)() = std::exception::vFunc_1 + void* (* const vFunc_0)(void* arg1, int32_t arg2) = SomeClass::vFunc_0 + int64_t (* const vFunc_1)() __pure = ParentB::vFunc_1 } ``` ## Example Constructor Listing -Based off the information collected from the RTTI scan, we can deduce constructors and create types and symbolize their structures. Using the [type inheritence](https://binary.ninja/2023/05/03/3.4-finally-freed.html#inherited-types) in [Binary Ninja] we can make these types easily composable. The listing below shows the fully symbolized constructor function for `SomeClass` in `simple.cpp` (found in test\bins), as well as the accompanying auto created type. -```c -struct __base(ParentA, 0) __base(ParentB, 0) __data_var_refs SomeClass +Based off the information collected from the RTTI scan, we can deduce constructors and create types and symbolize their structures. 
Using the [type inheritance](https://binary.ninja/2023/05/03/3.4-finally-freed.html#inherited-types) in [Binary Ninja] we can make these types easily composable. The listing below shows the fully symbolized constructor function for `Bird` in `overrides.cpp` (found in test\bins), as well as the accompanying auto created type. + +```cpp +class __base(Animal, 0) __base(Flying, 0) Bird { - struct SomeClass_vfTable* vft_SomeClass; - struct __ptr_offset(0x8) SomeClass_vfTable* vft_ptr_offset(0x8) SomeClass; + struct `Bird::VTable`* vtable; + char const* field_8; + struct `Flying::VTable`* vtable_Flying; + int32_t field_18; + __padding char _1C[4]; + int32_t field_20; }; -struct SomeClass* SomeClass::SomeClass(struct SomeClass* this) +class Bird* Bird::Bird(class Bird* this, int32_t arg2) { - arg_8 = this; - ParentA::ParentA(arg_8); - ParentB::ParentB(&arg_8->vft_ptr_offset(0x8) SomeClass); - arg_8->vft_SomeClass = &SomeClass_vfTable; - arg_8->vft_ptr_offset(0x8) SomeClass = &__ptr_offset(0x8) SomeClass_vfTable; - return arg_8; + Animal::Animal(this); + Flying::Flying(&this->vtable_Flying); + this->vtable = &Bird::`vftable'; + this->vtable_Flying = &Bird::`vftable'{for `Flying}; + this->field_8 = "A bird"; + this->field_18 = 0x58; + this->field_20 = arg2; + return this; } ``` ## Example Virtual Function Listing -Using the newly created constructor object type in [Example Constructor Listing](#example-constructor-listing) we can apply it to all virtual functions as the first parameter. The listing below shows a fully symbolized virtual function for `SomeClass` in `simple.cpp` (found in test\bins). +Using the newly created constructor object type in [Example Constructor Listing](#example-constructor-listing) we can apply it to all virtual functions as the first parameter. The listing below shows a fully symbolized virtual function for `Bird` in `overrides.cpp` (found in test\bins).
+ ```c -struct SomeClass* SomeClass::vFunc_0(struct SomeClass* this, int32_t arg2) +uint64_t Bird::vFunc_0(class Bird* this) { - sub_140001140(this); - if ((arg2 & 1) != 0) + int32_t var_18 = 0; + uint64_t field_20; + while (true) { - j_sub_140005f3c(this); + field_20 = ((uint64_t)this->field_20); + if (var_18 >= field_20) + { + break; + } + fputs("Tweet!"); + var_18 = (var_18 + 1); } - return this; + return field_20; } -``` - -## TODO -- ~~Identify virtual functions~~ and integrate with component view. -- Provide a UI to view associated classes. -- Graphviz support. -- Automatic scan on binary open. -- Provide CI for releasing new versions automatically and on all platforms. -- Provide better logging. -- Fixup cross references between defined symbols and their RVA pointers. -- Provide statistics on discovered functions and other useful information after completion. +``` [Binary Ninja]: https://binary.ninja diff --git a/binaryninja-api b/binaryninja-api index 7598688..84c7821 160000 --- a/binaryninja-api +++ b/binaryninja-api @@ -1 +1 @@ -Subproject commit 7598688466960427890036590239565364310171 +Subproject commit 84c782175a002a0a62ba1b00234de6b89380c9fc diff --git a/include/base_class_array.h b/include/base_class_array.h index ebb09e8..1a1a102 100644 --- a/include/base_class_array.h +++ b/include/base_class_array.h @@ -17,6 +17,8 @@ class BaseClassArray BaseClassArray(BinaryView* view, uint64_t address, int32_t length); std::vector GetBaseClassDescriptors(); + BaseClassDescriptor GetRootClassDescriptor(); Ref GetType(); - Ref CreateSymbol(std::string name, std::string rawName); + Ref CreateSymbol(); + std::string GetSymbolName(); }; \ No newline at end of file diff --git a/include/base_class_descriptor.h b/include/base_class_descriptor.h index 4e7cb42..b714ec2 100644 --- a/include/base_class_descriptor.h +++ b/include/base_class_descriptor.h @@ -22,9 +22,10 @@ class BaseClassDescriptor int32_t m_where_pdispValue; int32_t m_where_vdispValue; uint32_t 
m_attributesValue; - int32_t m_pClassHeirarchyDescriptorValue; + int32_t m_pClassHierarchyDescriptorValue; BaseClassDescriptor(BinaryView* view, uint64_t address); TypeDescriptor GetTypeDescriptor(); - Ref CreateSymbol(std::string name); + Ref CreateSymbol(); + std::string GetSymbolName(); }; \ No newline at end of file diff --git a/include/class_heirarchy_descriptor.h b/include/class_hierarchy_descriptor.h similarity index 58% rename from include/class_heirarchy_descriptor.h rename to include/class_hierarchy_descriptor.h index 7ea2aa9..dfa661f 100644 --- a/include/class_heirarchy_descriptor.h +++ b/include/class_hierarchy_descriptor.h @@ -6,9 +6,9 @@ using namespace BinaryNinja; -Ref GetClassHeirarchyDescriptorType(); +Ref GetClassHierarchyDescriptorType(); -class ClassHeirarchyDescriptor +class ClassHierarchyDescriptor { private: Ref m_view; @@ -22,7 +22,9 @@ class ClassHeirarchyDescriptor uint32_t m_numBaseClassesValue; int32_t m_pBaseClassArrayValue; - ClassHeirarchyDescriptor(BinaryView* view, uint64_t address); + ClassHierarchyDescriptor(BinaryView* view, uint64_t address); BaseClassArray GetBaseClassArray(); - Ref CreateSymbol(std::string name, std::string rawName); + BaseClassDescriptor GetRootBaseClassDescriptor(); + Ref CreateSymbol(); + std::string GetSymbolName(); }; \ No newline at end of file diff --git a/include/constructor.h b/include/constructor.h index a2e5f8c..3525f32 100644 --- a/include/constructor.h +++ b/include/constructor.h @@ -12,14 +12,14 @@ class Constructor private: Ref m_view; Ref m_func; - std::string GetRawName(); public: Constructor(BinaryView* view, Function* func); bool IsValid(); std::string GetName(); - std::vector GetInnerConstructors(); std::optional GetRootVirtualFunctionTable(); + size_t AddTag(); Ref CreateObjectType(); Ref CreateSymbol(); + std::string GetSymbolName(); }; \ No newline at end of file diff --git a/include/object_locator.h b/include/object_locator.h index 1c11688..8fd32ff 100644 --- a/include/object_locator.h 
+++ b/include/object_locator.h @@ -3,15 +3,13 @@ #include #include "type_descriptor.h" -#include "class_heirarchy_descriptor.h" +#include "class_hierarchy_descriptor.h" #include "virtual_function_table.h" using namespace BinaryNinja; constexpr auto COL_SIG_REV1 = 1; -Ref GetCompleteObjectLocatorType(BinaryView* view); - class CompleteObjectLocator { private: @@ -25,15 +23,19 @@ class CompleteObjectLocator uint32_t m_offsetValue; uint32_t m_cdOffsetValue; int32_t m_pTypeDescriptorValue; - int32_t m_pClassHeirarchyDescriptorValue; + int32_t m_pClassHierarchyDescriptorValue; int32_t m_pSelfValue; CompleteObjectLocator(BinaryView* view, uint64_t address); - std::string GetUniqueName(); TypeDescriptor GetTypeDescriptor(); - ClassHeirarchyDescriptor GetClassHeirarchyDescriptor(); + ClassHierarchyDescriptor GetClassHierarchyDescriptor(); std::optional GetVirtualFunctionTable(); bool IsValid(); bool IsSubObject(); - Ref CreateSymbol(std::string name, std::string rawName); + std::optional GetSubObjectTypeDescriptor(); + Ref GetType(); + std::string GetAssociatedClassName(); + Ref CreateSymbol(); + std::string GetSymbolName(); + std::string GetClassName(); }; \ No newline at end of file diff --git a/include/type_descriptor.h b/include/type_descriptor.h index d2a67db..1d9f7ad 100644 --- a/include/type_descriptor.h +++ b/include/type_descriptor.h @@ -20,5 +20,6 @@ class TypeDescriptor TypeDescriptor(BinaryView* view, uint64_t address); std::string GetDemangledName(); Ref GetType(); - Ref CreateSymbol(std::string name, std::string rawName); + Ref CreateSymbol(); + std::string GetSymbolName(); }; \ No newline at end of file diff --git a/include/utils.h b/include/utils.h index fcee0b1..0b3936e 100644 --- a/include/utils.h +++ b/include/utils.h @@ -8,4 +8,11 @@ using namespace BinaryNinja; uint64_t ReadIntWithSize(BinaryReader* reader, size_t size); std::string IntToHex(uint64_t val); Ref GetConstructorTagType(BinaryView* view); -Ref GetVirtualFunctionTableTagType(BinaryView* 
view); \ No newline at end of file +Ref GetVirtualFunctionTableTagType(BinaryView* view); +Ref GetVirtualFunctionTagType(BinaryView* view); +Ref GetCOLocatorTagType(BinaryView* view); +Ref GetPointerTypeChildStructure(Ref ptrType); +std::optional GetSSAVariableUnscopedDefinition(Ref mlil, SSAVariable var); +std::optional WalkToSSAVariableOffset(Ref mlil, SSAVariable var); +std::vector GetSSAVariablesForVariable(Ref func, const Variable& var); +uint64_t ResolveRelPointer(BinaryView* view, uint64_t ptrVal); \ No newline at end of file diff --git a/include/virtual_function.h b/include/virtual_function.h index 6258af6..983006c 100644 --- a/include/virtual_function.h +++ b/include/virtual_function.h @@ -18,5 +18,5 @@ class VirtualFunction // NOTE: If you create for example, a vfunc that is able to be deduped, two vtables will point to that function, we // don't want to mislead by renaming the function to the last of those vtables associated class. bool IsUnique(); - Ref CreateSymbol(std::string name, std::string rawName); + Ref CreateSymbol(std::string name); }; \ No newline at end of file diff --git a/include/virtual_function_table.h b/include/virtual_function_table.h index ffbd87d..2b12dba 100644 --- a/include/virtual_function_table.h +++ b/include/virtual_function_table.h @@ -19,6 +19,8 @@ class VirtualFunctionTable VirtualFunctionTable(BinaryView* view, uint64_t address); std::vector GetVirtualFunctions(); CompleteObjectLocator GetCOLocator(); - Ref GetType(std::string name, std::string idName); - Ref CreateSymbol(std::string name, std::string rawName); + Ref GetType(); + Ref CreateSymbol(); + std::string GetSymbolName(); + std::string GetTypeName(); }; \ No newline at end of file diff --git a/llvm-demangle/CMakeLists.txt b/llvm-demangle/CMakeLists.txt index 2fe3eed..5a2556d 100644 --- a/llvm-demangle/CMakeLists.txt +++ b/llvm-demangle/CMakeLists.txt @@ -10,4 +10,6 @@ set_target_properties(llvm-demangle PROPERTIES POSITION_INDEPENDENT_CODE ON CXX_STANDARD 17) 
+set_property(TARGET ${PROJECT_NAME} PROPERTY POSITION_INDEPENDENT_CODE ON) + target_include_directories(llvm-demangle PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include") diff --git a/src/base_class_array.cpp b/src/base_class_array.cpp index 658c62a..d6b10a8 100644 --- a/src/base_class_array.cpp +++ b/src/base_class_array.cpp @@ -1,6 +1,7 @@ #include -#include "class_heirarchy_descriptor.h" +#include "class_hierarchy_descriptor.h" +#include "utils.h" using namespace BinaryNinja; @@ -19,12 +20,18 @@ std::vector BaseClassArray::GetBaseClassDescriptors() for (size_t i = 0; i < m_length; i++) { - baseClassDescriptors.emplace_back(BaseClassDescriptor(m_view, m_view->GetStart() + reader.Read32())); + baseClassDescriptors.emplace_back( + BaseClassDescriptor(m_view, ResolveRelPointer(m_view, (uint64_t)reader.Read32()))); } return baseClassDescriptors; } +BaseClassDescriptor BaseClassArray::GetRootClassDescriptor() +{ + return GetBaseClassDescriptors().front(); +} + Ref BaseClassArray::GetType() { StructureBuilder baseClassArrayBuilder; @@ -34,10 +41,16 @@ Ref BaseClassArray::GetType() return TypeBuilder::StructureType(&baseClassArrayBuilder).Finalize(); } -Ref BaseClassArray::CreateSymbol(std::string name, std::string rawName) +Ref BaseClassArray::CreateSymbol() { - Ref baseClassArraySym = new Symbol {DataSymbol, name, name, rawName, m_address}; + Ref baseClassArraySym = new Symbol {DataSymbol, GetSymbolName(), m_address}; m_view->DefineUserSymbol(baseClassArraySym); - m_view->DefineDataVariable(m_address, GetType()); + m_view->DefineUserDataVariable(m_address, GetType()); return baseClassArraySym; +} + +// Example: Animal::`RTTI Base Class Array' +std::string BaseClassArray::GetSymbolName() +{ + return GetRootClassDescriptor().GetTypeDescriptor().GetDemangledName() + "::`RTTI Base Class Array'"; } \ No newline at end of file diff --git a/src/base_class_descriptor.cpp b/src/base_class_descriptor.cpp index f4c68dd..6872521 100644 --- a/src/base_class_descriptor.cpp +++ 
b/src/base_class_descriptor.cpp @@ -1,9 +1,30 @@ #include #include "base_class_descriptor.h" +#include "utils.h" using namespace BinaryNinja; +Ref GetPMDType(BinaryView* view) +{ + Ref typeCache = view->GetTypeById("msvc_PMD"); + + if (typeCache == nullptr) + { + Ref intType = Type::IntegerType(4, true); + + StructureBuilder pmdBuilder; + pmdBuilder.AddMember(intType, "mdisp"); + pmdBuilder.AddMember(intType, "pdisp"); + pmdBuilder.AddMember(intType, "vdisp"); + + view->DefineType("msvc_PMD", QualifiedName("_PMD"), TypeBuilder::StructureType(&pmdBuilder).Finalize()); + typeCache = view->GetTypeById("msvc_PMD"); + } + + return typeCache; +} + Ref GetBaseClassDescriptorType(BinaryView* view) { Ref typeCache = view->GetTypeById("msvc_RTTIBaseClassDescriptor"); @@ -16,15 +37,12 @@ Ref GetBaseClassDescriptorType(BinaryView* view) StructureBuilder baseClassDescriptorBuilder; baseClassDescriptorBuilder.AddMember(intType, "pTypeDescriptor"); baseClassDescriptorBuilder.AddMember(uintType, "numContainedBases"); - baseClassDescriptorBuilder.AddMember(intType, "where_mdisp"); - baseClassDescriptorBuilder.AddMember(intType, "where_pdisp"); - baseClassDescriptorBuilder.AddMember(intType, "where_vdisp"); + baseClassDescriptorBuilder.AddMember(GetPMDType(view), "where"); baseClassDescriptorBuilder.AddMember(uintType, "attributes"); - baseClassDescriptorBuilder.AddMember(intType, "pClassHeirarchyDescriptor"); + baseClassDescriptorBuilder.AddMember(intType, "pClassDescriptor"); view->DefineType("msvc_RTTIBaseClassDescriptor", QualifiedName("_RTTIBaseClassDescriptor"), TypeBuilder::StructureType(&baseClassDescriptorBuilder).Finalize()); - typeCache = view->GetTypeById("msvc_RTTIBaseClassDescriptor"); } @@ -44,19 +62,26 @@ BaseClassDescriptor::BaseClassDescriptor(BinaryView* view, uint64_t address) m_where_pdispValue = (int32_t)reader.Read32(); m_where_vdispValue = (int32_t)reader.Read32(); m_attributesValue = reader.Read32(); - m_pClassHeirarchyDescriptorValue = 
(int32_t)reader.Read32(); + m_pClassHierarchyDescriptorValue = (int32_t)reader.Read32(); } TypeDescriptor BaseClassDescriptor::GetTypeDescriptor() { - // NOTE: No signature value attached to `BaseClassDescriptor` so must be relative? - return TypeDescriptor(m_view, m_view->GetStart() + m_pTypeDescriptorValue); + return TypeDescriptor(m_view, ResolveRelPointer(m_view, m_pTypeDescriptorValue)); } -Ref BaseClassDescriptor::CreateSymbol(std::string name) +Ref BaseClassDescriptor::CreateSymbol() { - Ref baseClassDescriptorSym = new Symbol {DataSymbol, name, m_address}; + Ref baseClassDescriptorSym = new Symbol {DataSymbol, GetSymbolName(), m_address}; m_view->DefineUserSymbol(baseClassDescriptorSym); - m_view->DefineDataVariable(m_address, GetBaseClassDescriptorType(m_view)); + m_view->DefineUserDataVariable(m_address, GetBaseClassDescriptorType(m_view)); return baseClassDescriptorSym; +} + +// Example: Animal::`RTTI Base Class Descriptor at (0,-1,0,64)' +std::string BaseClassDescriptor::GetSymbolName() +{ + return GetTypeDescriptor().GetDemangledName() + "::`RTTI Base Class Descriptor at (" + + std::to_string(m_where_mdispValue) + "," + std::to_string(m_where_pdispValue) + "," + + std::to_string(m_where_vdispValue) + "," + std::to_string(m_attributesValue) + ")'"; } \ No newline at end of file diff --git a/src/class_heirarchy_descriptor.cpp b/src/class_heirarchy_descriptor.cpp deleted file mode 100644 index 6e07498..0000000 --- a/src/class_heirarchy_descriptor.cpp +++ /dev/null @@ -1,56 +0,0 @@ -#include - -#include "class_heirarchy_descriptor.h" -#include "utils.h" - -using namespace BinaryNinja; - -Ref GetClassHeirarchyDescriptorType(BinaryView* view) -{ - Ref typeCache = view->GetTypeById("msvc_RTTIClassHeirarchyDescriptor"); - - if (typeCache == nullptr) - { - Ref uintType = Type::IntegerType(4, false); - Ref intType = Type::IntegerType(4, true); - - StructureBuilder classHeirarchyDescriptorBuilder; - classHeirarchyDescriptorBuilder.AddMember(uintType, 
"signature"); - classHeirarchyDescriptorBuilder.AddMember(uintType, "attributes"); - classHeirarchyDescriptorBuilder.AddMember(uintType, "numBaseClasses"); - classHeirarchyDescriptorBuilder.AddMember(intType, "pBaseClassArray"); - - view->DefineType("msvc_RTTIClassHeirarchyDescriptor", QualifiedName("_RTTIClassHeirarchyDescriptor"), - TypeBuilder::StructureType(&classHeirarchyDescriptorBuilder).Finalize()); - - typeCache = view->GetTypeById("msvc_RTTIClassHeirarchyDescriptor"); - } - - return typeCache; -} - -ClassHeirarchyDescriptor::ClassHeirarchyDescriptor(BinaryView* view, uint64_t address) -{ - BinaryReader reader = BinaryReader(view); - reader.Seek(address); - - m_view = view; - m_address = address; - m_signatureValue = reader.Read32(); - m_attributesValue = reader.Read32(); - m_numBaseClassesValue = reader.Read32(); - m_pBaseClassArrayValue = (int32_t)reader.Read32(); -} - -BaseClassArray ClassHeirarchyDescriptor::GetBaseClassArray() -{ - return BaseClassArray(m_view, m_view->GetStart() + m_pBaseClassArrayValue, m_numBaseClassesValue); -} - -Ref ClassHeirarchyDescriptor::CreateSymbol(std::string name, std::string rawName) -{ - Ref classDescSym = new Symbol {DataSymbol, name, name, rawName, m_address}; - m_view->DefineUserSymbol(classDescSym); - m_view->DefineDataVariable(m_address, GetClassHeirarchyDescriptorType(m_view)); - return classDescSym; -} \ No newline at end of file diff --git a/src/class_hierarchy_descriptor.cpp b/src/class_hierarchy_descriptor.cpp new file mode 100644 index 0000000..10d8f2d --- /dev/null +++ b/src/class_hierarchy_descriptor.cpp @@ -0,0 +1,63 @@ +#include + +#include "class_hierarchy_descriptor.h" +#include "utils.h" + +using namespace BinaryNinja; + +Ref GetClassHierarchyDescriptorType(BinaryView* view) +{ + Ref typeCache = view->GetTypeById("msvc_RTTIClassHierarchyDescriptor"); + + if (typeCache == nullptr) + { + Ref uintType = Type::IntegerType(4, false); + Ref intType = Type::IntegerType(4, true); + + StructureBuilder 
classHierarchyDescriptorBuilder; + classHierarchyDescriptorBuilder.AddMember(uintType, "signature"); + classHierarchyDescriptorBuilder.AddMember(uintType, "attributes"); + classHierarchyDescriptorBuilder.AddMember(uintType, "numBaseClasses"); + classHierarchyDescriptorBuilder.AddMember(intType, "pBaseClassArray"); + + view->DefineType("msvc_RTTIClassHierarchyDescriptor", QualifiedName("_RTTIClassHierarchyDescriptor"), + TypeBuilder::StructureType(&classHierarchyDescriptorBuilder).Finalize()); + + typeCache = view->GetTypeById("msvc_RTTIClassHierarchyDescriptor"); + } + + return typeCache; +} + +ClassHierarchyDescriptor::ClassHierarchyDescriptor(BinaryView* view, uint64_t address) +{ + BinaryReader reader = BinaryReader(view); + reader.Seek(address); + + m_view = view; + m_address = address; + m_signatureValue = reader.Read32(); + m_attributesValue = reader.Read32(); + m_numBaseClassesValue = reader.Read32(); + m_pBaseClassArrayValue = (int32_t)reader.Read32(); +} + +BaseClassArray ClassHierarchyDescriptor::GetBaseClassArray() +{ + return BaseClassArray(m_view, ResolveRelPointer(m_view, m_pBaseClassArrayValue), m_numBaseClassesValue); +} + +Ref ClassHierarchyDescriptor::CreateSymbol() +{ + Ref classDescSym = new Symbol {DataSymbol, GetSymbolName(), m_address}; + m_view->DefineUserSymbol(classDescSym); + m_view->DefineUserDataVariable(m_address, GetClassHierarchyDescriptorType(m_view)); + return classDescSym; +} + +// Example: Animal::`RTTI Class Hierarchy Descriptor' +std::string ClassHierarchyDescriptor::GetSymbolName() +{ + return GetBaseClassArray().GetRootClassDescriptor().GetTypeDescriptor().GetDemangledName() + + "::`RTTI Class Hierarchy Descriptor'"; +} \ No newline at end of file diff --git a/src/constructor.cpp b/src/constructor.cpp index b0c086d..0963ffe 100644 --- a/src/constructor.cpp +++ b/src/constructor.cpp @@ -2,6 +2,7 @@ #include #include "constructor.h" +#include "utils.h" using namespace BinaryNinja; @@ -13,241 +14,169 @@ 
Constructor::Constructor(BinaryView* view, Function* func) bool Constructor::IsValid() { - if (GetRootVirtualFunctionTable().has_value()) - { - // Check to make sure we have the first param as a _vfTable** type... - auto paramVars = m_func->GetParameterVariables(); - if (!paramVars->empty()) - { - auto registeredName = m_func->GetVariableType(paramVars->front())->GetRegisteredName(); - if (registeredName == nullptr - && registeredName->GetName().GetString().find("_vfTable") != std::string::npos) - return true; - } - } - - return false; + // TODO: Check to make sure we have the first param as a _vfTable** type... + // TODO: x86 getting constant pointer recognition failures, see: binaryninja-api issue 4399. + return GetRootVirtualFunctionTable().has_value(); } std::string Constructor::GetName() { - return GetRootVirtualFunctionTable().value().GetCOLocator().GetTypeDescriptor().GetDemangledName(); + return GetRootVirtualFunctionTable()->GetCOLocator().GetTypeDescriptor().GetDemangledName(); } -std::string Constructor::GetRawName() -{ - return GetRootVirtualFunctionTable().value().GetCOLocator().GetUniqueName(); -} - -std::vector Constructor::GetInnerConstructors() +std::optional Constructor::GetRootVirtualFunctionTable() { - std::vector innerConstructors = {}; + Ref vftTagType = GetVirtualFunctionTableTagType(m_view); + Ref mlil = m_func->GetMediumLevelIL()->GetSSAForm(); - for (auto callSite : m_func->GetCallSites()) + for (auto& block : mlil->GetBasicBlocks()) { - for (auto calleeAddr : m_view->GetCallees(callSite)) + for (size_t instIdx = block->GetStart(); instIdx < block->GetEnd(); instIdx++) { - auto calleeFuncs = m_view->GetAnalysisFunctionsForAddress(calleeAddr); - if (calleeFuncs.empty()) - continue; + MediumLevelILInstruction inst = (*mlil)[instIdx]; - for (auto calleeFunc : calleeFuncs) + if (inst.operation == MLIL_STORE_SSA) { - // TODO: Get rid of second constructor creation - if (Constructor(m_view, calleeFunc).IsValid()) + auto destExpr = 
inst.GetDestExpr(); + auto sourceExpr = inst.GetSourceExpr(); + DataVariable sourceDataVar = {}; + + if (m_view->GetDataVariableAtAddress(sourceExpr.GetValue().value, sourceDataVar)) { - innerConstructors.emplace_back(Constructor(m_view, calleeFunc)); + for (auto& tag : m_view->GetDataTags(sourceDataVar.address)) + { + if (tag->GetType() != vftTagType) + continue; + + switch (destExpr.operation) + { + case MLIL_VAR_SSA: + return VirtualFunctionTable(m_view, sourceDataVar.address); + } + } } } } } - return innerConstructors; + return std::nullopt; } -std::optional Constructor::GetRootVirtualFunctionTable() +size_t Constructor::AddTag() { - Ref mlil = m_func->GetMediumLevelIL(); - if (mlil == nullptr) - return std::nullopt; - Ref mlilssa = mlil->GetSSAForm(); - if (mlilssa == nullptr) - return std::nullopt; - std::optional rootVftAddr = std::nullopt; - - // TODO: Shouldn't we make sure that the vfTable is assigned to the return value? - for (auto& block : mlilssa->GetBasicBlocks()) + auto tagType = GetConstructorTagType(m_view); + auto constructorName = GetName(); + m_func->CreateUserFunctionTag(tagType, constructorName, true); + + size_t tagCount = 0; + for (auto funcTags : m_view->GetAllTagReferencesOfType(tagType)) { - for (size_t instIdx = block->GetStart(); instIdx < block->GetEnd(); instIdx++) + if (funcTags.tag->GetData() == constructorName) { - MediumLevelILInstruction inst = (*mlilssa)[instIdx]; - - inst.VisitExprs([&](const MediumLevelILInstruction& expr) { - switch (expr.operation) - { - case MLIL_STORE_SSA: - auto destExpr = expr.GetDestExpr(); - auto sourceExpr = expr.GetSourceExpr(); - DataVariable sourceDataVar = {}; - - if (m_view->GetDataVariableAtAddress(sourceExpr.GetValue().value, sourceDataVar)) - { - auto registeredName = sourceDataVar.type->GetRegisteredName(); - if (registeredName == nullptr) - return true; - // TODO: Check here to make sure we are setting constructorObjVar for a var that comes - // from the first param. 
- if (registeredName->GetName().GetString().find("_vfTable") != std::string::npos) - { - switch (destExpr.operation) - { - case MLIL_VAR_SSA: - rootVftAddr = sourceDataVar.address; - // Stop visiting. - return false; - } - } - } - } - return true; - }); - - // TODO: This is just awful. - if (rootVftAddr.has_value()) - break; + tagCount++; } - - // TODO: This is just awful. - if (rootVftAddr.has_value()) - break; - } - - if (rootVftAddr.has_value()) - { - return VirtualFunctionTable(m_view, rootVftAddr.value()); - } - else - { - return std::nullopt; } + return tagCount; } -// TODO: We need to check for heap allocated constructors, we should first identify all constructors with a -// reference to a vfTable then once we do that a second pass should check all callers of those constructor functions and -// see if they call `operator new` with the returned pointer being used as arg1 of those constructors. - -// TODO: OR we could just get the return var and create fields for every access... - -/* -void* rax = operator new(0x18) -struct SomeClass_vfTable** var_20 -if (rax == 0) - var_20 = nullptr -else - __builtin_memset(s: rax, c: 0, n: 0x18) - var_20 = std::_Future_error_category::_Future_error_category(rax) - // array accesses after this are fields we need to populate! -*/ - Ref Constructor::CreateObjectType() { - std::string idName = GetRawName(); - Ref typeCache = m_view->GetTypeById("msvc_" + idName); + Ref vftTagType = GetVirtualFunctionTableTagType(m_view); + QualifiedName typeName = QualifiedName(GetName()); + Ref typeCache = Type::NamedType(m_view, typeName); + + // TODO: What if one constructor has fields defined that another one doesnt, does that even happen? TLDR: We need a + // way to add fields to already existing constructor. 
- if (typeCache == nullptr) + // TODO: We need to get fields, we should assume the ownership of fields by using the bases offset to put all fields + // that are below the next vtable to the last assigned vtable + + if (m_view->GetTypeByName(typeName) == nullptr) { StructureBuilder objBuilder = {}; - objBuilder.SetPropagateDataVariableReferences(true); - // Collect all object fields. + objBuilder.SetStructureType(ClassStructureType); Ref mlil = m_func->GetMediumLevelIL()->GetSSAForm(); + + // Collect all object vtables and set their appropriate types. for (auto& block : mlil->GetBasicBlocks()) { for (size_t instIdx = block->GetStart(); instIdx < block->GetEnd(); instIdx++) { MediumLevelILInstruction inst = (*mlil)[instIdx]; - std::optional constructorObjVar = {}; - inst.VisitExprs([&](const MediumLevelILInstruction& expr) { - switch (expr.operation) + if (inst.operation == MLIL_STORE_SSA) + { + auto destExpr = inst.GetDestExpr(); + auto sourceExpr = inst.GetSourceExpr(); + DataVariable sourceDataVar = {}; + if (m_view->GetDataVariableAtAddress(sourceExpr.GetValue().value, sourceDataVar)) { - case MLIL_STORE_SSA: - auto destExpr = expr.GetDestExpr(); - int offset = 0; - SSAVariable assignedVar = {}; - switch (destExpr.operation) - { - case MLIL_VAR_SSA: - // NOTE: Constructor obj must give a 0 offset assignment in its constructor func. 
- assignedVar = destExpr.GetRawOperandAsSSAVariable(0); - break; - case MLIL_ADD: - auto leftExpr = destExpr.GetLeftExpr(); - auto rightExpr = destExpr.GetRightExpr(); - assignedVar = leftExpr.GetRawOperandAsSSAVariable(0); - offset = rightExpr.GetValue().value; - break; - }; - - auto sourceExpr = expr.GetSourceExpr(); - DataVariable sourceDataVar = {}; - - if (m_view->GetDataVariableAtAddress(sourceExpr.GetValue().value, sourceDataVar)) + for (auto& tag : m_view->GetDataTags(sourceDataVar.address)) { - auto registeredName = sourceDataVar.type->GetRegisteredName(); - if (registeredName == nullptr) - return true; - if (registeredName->GetName().GetString().find("_vfTable") != std::string::npos) + if (tag->GetType() != vftTagType) + continue; + + int64_t offset = 0; + switch (destExpr.operation) { - // TODO: Check here to make sure we are setting constructorObjVar for a var that comes - // from the first param. - switch (destExpr.operation) - { - case MLIL_VAR_SSA: - constructorObjVar = destExpr.GetRawOperandAsSSAVariable(0); - break; - case MLIL_ADD: - auto leftExpr = destExpr.GetLeftExpr(); - constructorObjVar = leftExpr.GetRawOperandAsSSAVariable(0); - break; - } - - VirtualFunctionTable currentVft = VirtualFunctionTable(m_view, sourceDataVar.address); - CompleteObjectLocator currentVftCOLocator = currentVft.GetCOLocator(); - std::string currentVftName = currentVftCOLocator.GetUniqueName(); - if (currentVftCOLocator.IsSubObject()) - currentVftName.erase(0, 2); + case MLIL_ADD: + auto leftExpr = destExpr.GetLeftExpr(); + auto rightExpr = destExpr.GetRightExpr(); + offset = rightExpr.GetValue().value; + break; + } + + VirtualFunctionTable currentVft = VirtualFunctionTable(m_view, sourceDataVar.address); + CompleteObjectLocator currentVftCOLocator = currentVft.GetCOLocator(); + if (offset != 0) + { objBuilder.AddMemberAtOffset( - Type::PointerType(m_view->GetAddressSize(), - m_view->GetTypeByRef(sourceDataVar.type->GetRegisteredName())), - "vft_" + currentVftName, 
offset); + Type::PointerType(m_view->GetAddressSize(), sourceDataVar.type), + "vtable_" + currentVftCOLocator.GetAssociatedClassName(), offset); } - else if (constructorObjVar.has_value() && assignedVar == constructorObjVar.value()) + else { - // TODO: Check to see if destExpr uses the `constructorObjVar` (TODO: Dont we do that?) objBuilder.AddMemberAtOffset( - sourceDataVar.type, "field_" + std::to_string(offset), offset); + Type::PointerType(m_view->GetAddressSize(), sourceDataVar.type), "vtable", offset); } } } - return true; - }); + } } } std::vector innerStructures = {}; - for (auto& innerConstructor : GetInnerConstructors()) + for (auto callSite : m_func->GetCallSites()) { - // TODO: What happens if a field exists in an inherited class and the root class? - LogDebug("inner -> %s", innerConstructor.GetName().c_str()); - innerStructures.emplace_back(BaseStructure(innerConstructor.CreateObjectType(), 0)); + for (auto calleeAddr : m_view->GetCallees(callSite)) + { + auto calleeFuncs = m_view->GetAnalysisFunctionsForAddress(calleeAddr); + if (calleeFuncs.empty()) + continue; + + auto innerConstructor = Constructor(m_view, calleeFuncs.front()); + if (!innerConstructor.IsValid()) + continue; + + std::vector varRefs = + m_func->GetMediumLevelILVariableReferencesFrom(m_view->GetDefaultArchitecture(), callSite.addr); + if (varRefs.empty()) + continue; + + // TODO: What happens if a field exists in an inherited class and the root class? (ANSWER: It is from + // the inherited class!) 
+ auto innerTy = innerConstructor.CreateObjectType(); + + innerStructures.emplace_back(BaseStructure(innerTy->GetNamedTypeReference(), + innerConstructor.GetRootVirtualFunctionTable()->GetCOLocator().m_offsetValue, innerTy->GetWidth())); + } } objBuilder.SetBaseStructures(innerStructures); - m_view->DefineType( - "msvc_" + idName, QualifiedName(GetName()), TypeBuilder::StructureType(&objBuilder).Finalize()); + m_view->DefineUserType(typeName, TypeBuilder::StructureType(&objBuilder).Finalize()); - typeCache = m_view->GetTypeById("msvc_" + idName); + typeCache = Type::NamedType(m_view, typeName); } return typeCache; @@ -255,11 +184,19 @@ Ref Constructor::CreateObjectType() Ref Constructor::CreateSymbol() { - std::string symName = GetName(); - std::string symRawName = GetRawName(); - std::string funcName = symName + "::" + symName; - Ref newFuncSym = - new Symbol {FunctionSymbol, funcName, funcName, symRawName + "::" + symRawName, m_func->GetStart()}; + Ref newFuncSym = new Symbol {FunctionSymbol, GetSymbolName(), m_func->GetStart()}; m_view->DefineUserSymbol(newFuncSym); return newFuncSym; +} + +std::string Constructor::GetSymbolName() +{ + std::string symName = GetName(); + // If this constructor is not the first of its class, add a counter to the end of it. 
+ size_t tagCount = AddTag(); + if (tagCount > 1) + { + return symName + "::" + symName + "_" + std::to_string(tagCount); + } + return symName + "::" + symName; } \ No newline at end of file diff --git a/src/object_locator.cpp b/src/object_locator.cpp index 2ab5c2b..6e2daca 100644 --- a/src/object_locator.cpp +++ b/src/object_locator.cpp @@ -5,33 +5,6 @@ using namespace BinaryNinja; -Ref GetCompleteObjectLocatorType(BinaryView* view) -{ - Ref typeCache = view->GetTypeById("msvc_RTTICompleteObjectLocator"); - - if (typeCache == nullptr) - { - Ref uintType = Type::IntegerType(4, false); - Ref intType = Type::IntegerType(4, true); - - StructureBuilder completeObjectLocatorBuilder; - // TODO: make signature an enum with COL_SIG_REV0 & COL_SIG_REV1? - completeObjectLocatorBuilder.AddMember(uintType, "signature"); - completeObjectLocatorBuilder.AddMember(uintType, "offset"); - completeObjectLocatorBuilder.AddMember(uintType, "cdOffset"); - completeObjectLocatorBuilder.AddMember(intType, "pTypeDescriptor"); - completeObjectLocatorBuilder.AddMember(intType, "pClassHeirarchyDescriptor"); - completeObjectLocatorBuilder.AddMember(intType, "pSelf"); - - view->DefineType("msvc_RTTICompleteObjectLocator", QualifiedName("_RTTICompleteObjectLocator"), - TypeBuilder::StructureType(&completeObjectLocatorBuilder).Finalize()); - - typeCache = view->GetTypeById("msvc_RTTICompleteObjectLocator"); - } - - return typeCache; -} - CompleteObjectLocator::CompleteObjectLocator(BinaryView* view, uint64_t address) { BinaryReader reader = BinaryReader(view); @@ -43,20 +16,15 @@ CompleteObjectLocator::CompleteObjectLocator(BinaryView* view, uint64_t address) m_offsetValue = reader.Read32(); m_cdOffsetValue = reader.Read32(); m_pTypeDescriptorValue = (int32_t)reader.Read32(); - m_pClassHeirarchyDescriptorValue = (int32_t)reader.Read32(); - m_pSelfValue = (int32_t)reader.Read32(); -} - -std::string CompleteObjectLocator::GetUniqueName() -{ - std::string uniqueName = 
GetTypeDescriptor().GetDemangledName(); - - if (m_offsetValue != 0) + m_pClassHierarchyDescriptorValue = (int32_t)reader.Read32(); + if (m_signatureValue == COL_SIG_REV1) { - uniqueName = "__ptr_offset(0x" + IntToHex(m_offsetValue) + ") " + uniqueName; + m_pSelfValue = (int32_t)reader.Read32(); + } + else + { + m_pSelfValue = 0; } - - return uniqueName; } TypeDescriptor CompleteObjectLocator::GetTypeDescriptor() @@ -66,11 +34,11 @@ TypeDescriptor CompleteObjectLocator::GetTypeDescriptor() return TypeDescriptor(m_view, m_pTypeDescriptorValue); } -ClassHeirarchyDescriptor CompleteObjectLocator::GetClassHeirarchyDescriptor() +ClassHierarchyDescriptor CompleteObjectLocator::GetClassHierarchyDescriptor() { if (m_signatureValue == COL_SIG_REV1) - return ClassHeirarchyDescriptor(m_view, m_view->GetStart() + m_pClassHeirarchyDescriptorValue); - return ClassHeirarchyDescriptor(m_view, m_pClassHeirarchyDescriptorValue); + return ClassHierarchyDescriptor(m_view, m_view->GetStart() + m_pClassHierarchyDescriptorValue); + return ClassHierarchyDescriptor(m_view, m_pClassHierarchyDescriptorValue); } std::optional CompleteObjectLocator::GetVirtualFunctionTable() @@ -92,11 +60,14 @@ bool CompleteObjectLocator::IsValid() if (m_signatureValue == COL_SIG_REV1) { + if (m_pSelfValue != m_address - startAddr) + return false; + // Relative addrs if (m_pTypeDescriptorValue + startAddr > endAddr) return false; - if (m_pClassHeirarchyDescriptorValue + startAddr > endAddr) + if (m_pClassHierarchyDescriptorValue + startAddr > endAddr) return false; } else @@ -105,26 +76,98 @@ bool CompleteObjectLocator::IsValid() if (m_pTypeDescriptorValue < startAddr || m_pTypeDescriptorValue > endAddr) return false; - if (m_pClassHeirarchyDescriptorValue < startAddr || m_pClassHeirarchyDescriptorValue > endAddr) + if (m_pClassHierarchyDescriptorValue < startAddr || m_pClassHierarchyDescriptorValue > endAddr) return false; } - if (m_pSelfValue != m_address - startAddr) - return false; - return true; } -// 
NOTE: If COLocator is a sub object then we need to retrieve the bool CompleteObjectLocator::IsSubObject() { return m_offsetValue > 0; } -Ref CompleteObjectLocator::CreateSymbol(std::string name, std::string rawName) +// TODO: This fails sometimes, figure out what causes this. +std::optional CompleteObjectLocator::GetSubObjectTypeDescriptor() +{ + if (!IsSubObject()) + return std::nullopt; + + for (auto baseClassDescs : GetClassHierarchyDescriptor().GetBaseClassArray().GetBaseClassDescriptors()) + { + if (m_offsetValue == baseClassDescs.m_where_mdispValue) + { + return baseClassDescs.GetTypeDescriptor(); + } + } + + return std::nullopt; +} + +Ref CompleteObjectLocator::GetType() +{ + Ref typeCache = m_view->GetTypeById("msvc_RTTICompleteObjectLocator" + m_signatureValue); + + if (typeCache == nullptr) + { + Ref uintType = Type::IntegerType(4, false); + Ref intType = Type::IntegerType(4, true); + + StructureBuilder completeObjectLocatorBuilder; + // TODO: make signature an enum with COL_SIG_REV0 & COL_SIG_REV1? 
+ completeObjectLocatorBuilder.AddMember(uintType, "signature"); + completeObjectLocatorBuilder.AddMember(uintType, "offset"); + completeObjectLocatorBuilder.AddMember(uintType, "cdOffset"); + completeObjectLocatorBuilder.AddMember(intType, "pTypeDescriptor"); + completeObjectLocatorBuilder.AddMember(intType, "pClassHierarchyDescriptor"); + + if (m_signatureValue == COL_SIG_REV1) + { + completeObjectLocatorBuilder.AddMember(intType, "pSelf"); + } + + m_view->DefineType("msvc_RTTICompleteObjectLocator" + m_signatureValue, + QualifiedName("_RTTICompleteObjectLocator"), + TypeBuilder::StructureType(&completeObjectLocatorBuilder).Finalize()); + + typeCache = m_view->GetTypeById("msvc_RTTICompleteObjectLocator" + m_signatureValue); + } + + return typeCache; +} + +std::string CompleteObjectLocator::GetAssociatedClassName() { - Ref COLocSym = new Symbol {DataSymbol, name, name, rawName, m_address}; + if (IsSubObject()) + { + if (auto subObjectTypeDesc = GetSubObjectTypeDescriptor()) + { + return subObjectTypeDesc->GetDemangledName(); + } + } + return GetTypeDescriptor().GetDemangledName(); +} + +Ref CompleteObjectLocator::CreateSymbol() +{ + Ref COLocSym = new Symbol {DataSymbol, GetSymbolName(), m_address}; m_view->DefineUserSymbol(COLocSym); - m_view->DefineDataVariable(m_address, GetCompleteObjectLocatorType(m_view)); + m_view->DefineUserDataVariable(m_address, GetType()); return COLocSym; } + +std::string CompleteObjectLocator::GetSymbolName() +{ + std::string symName = GetTypeDescriptor().GetDemangledName() + "::`RTTI Complete Object Locator'"; + if (IsSubObject()) + { + symName = symName + "{for `" + GetSubObjectTypeDescriptor()->GetDemangledName() + "'}"; + } + return symName; +} + +std::string CompleteObjectLocator::GetClassName() +{ + return GetTypeDescriptor().GetDemangledName(); +} \ No newline at end of file diff --git a/src/plugin.cpp b/src/plugin.cpp index c17426a..f2c52ad 100644 --- a/src/plugin.cpp +++ b/src/plugin.cpp @@ -13,154 +13,105 @@ const size_t 
COLocatorSize = 24; void CreateConstructorsAtFunction(BinaryView* view, Function* func) { Constructor constructor = Constructor(view, func); - if (!constructor.IsValid()) + // Skip invalid & already tagged constructors. + if (!constructor.IsValid() || !func->GetFunctionTagsOfType(GetConstructorTagType(view)).empty()) return; - LogDebug("Attempting to create constructor -> %s", constructor.GetName().c_str()); - - // Apply this to the return type. + // TODO: Apply this to the return type. Ref objType = constructor.CreateObjectType(); - // TODO: This will immediately go and unmerge in the function body sometimes! - auto paramVars = func->GetParameterVariables(); - func->CreateUserVariable(paramVars->front(), Type::PointerType(view->GetAddressSize(), objType), "this"); - // TODO: Update all root vfuncs. - for (auto vFuncs : constructor.GetRootVirtualFunctionTable().value().GetVirtualFunctions()) + // TODO: Doing any changes to the func here do not get applied... + + auto newVFuncType = [](BinaryView* bv, Ref funcType, Ref thisType) { + auto newFuncType = TypeBuilder(funcType); + auto adjustedParams = newFuncType.GetParameters(); + if (adjustedParams.empty()) + adjustedParams.push_back({}); + adjustedParams.at(0) = FunctionParameter("this", Type::PointerType(bv->GetAddressSize(), thisType)); + newFuncType.SetParameters(adjustedParams); + return newFuncType.Finalize(); + }; + + func->SetUserType(newVFuncType(view, func->GetType(), objType)); + for (auto vFunc : constructor.GetRootVirtualFunctionTable()->GetVirtualFunctions()) { - auto paramVars = vFuncs.m_func->GetParameterVariables(); - // TODO: This does not add a param like it should... - if (paramVars->empty()) - paramVars->push_back({}); - vFuncs.m_func->CreateUserVariable( - paramVars->front(), Type::PointerType(view->GetAddressSize(), objType), "this"); + vFunc.m_func->SetUserType(newVFuncType(view, vFunc.m_func->GetType(), objType)); } // Apply to function name. 
constructor.CreateSymbol(); - - // Tag function as a constructor. - func->CreateUserFunctionTag(GetConstructorTagType(view), constructor.GetName()); } void CreateSymbolsFromCOLocatorAddress(BinaryView* view, uint64_t address) { - CompleteObjectLocator objLocator = CompleteObjectLocator(view, address); - if (!objLocator.IsValid()) - return; - - std::string shortName = objLocator.GetUniqueName(); - std::string rawName = shortName; - - ClassHeirarchyDescriptor classDesc = objLocator.GetClassHeirarchyDescriptor(); - TypeDescriptor typeDesc = objLocator.GetTypeDescriptor(); - - std::optional vfTableOpt = objLocator.GetVirtualFunctionTable(); - if (!vfTableOpt.has_value()) + CompleteObjectLocator coLocator = CompleteObjectLocator(view, address); + if (!coLocator.IsValid()) { - LogWarn("Failed to get VFT for %s", shortName.c_str()); + LogError("Invalid Colocator! %x", coLocator.m_address); return; } - VirtualFunctionTable vfTable = vfTableOpt.value(); + ClassHierarchyDescriptor classDesc = coLocator.GetClassHierarchyDescriptor(); + TypeDescriptor typeDesc = coLocator.GetTypeDescriptor(); BaseClassArray baseClassArray = classDesc.GetBaseClassArray(); - std::vector baseClassDescriptors = baseClassArray.GetBaseClassDescriptors(); + VirtualFunctionTable vfTable = coLocator.GetVirtualFunctionTable().value(); - if (objLocator.m_offsetValue != 0) + for (auto&& baseClassDesc : baseClassArray.GetBaseClassDescriptors()) { - rawName = "__offset(" + std::to_string(objLocator.m_offsetValue) + ") " + rawName; + baseClassDesc.CreateSymbol(); } - // TODO: Cleanup this! 
- if (!baseClassDescriptors.empty()) - { - std::string inheritenceName = " : "; - bool first = true; - for (auto&& baseClassDesc : baseClassDescriptors) - { - std::string demangledBaseClassDescName = baseClassDesc.GetTypeDescriptor().GetDemangledName(); - baseClassDesc.CreateSymbol(demangledBaseClassDescName + "_baseClassDesc"); - if (demangledBaseClassDescName != shortName) - { - if (first) - { - inheritenceName.append(demangledBaseClassDescName); - first = false; - } - else - { - inheritenceName.append(", " + demangledBaseClassDescName); - } - } - } - - if (first == false) - { - rawName.append(inheritenceName); - } - } - - LogDebug("Creating symbols for %s...", shortName.c_str()); - - - // TODO: If option is enabled, create a new structure for this class and define the vtable structures and - // everything. (so vfuncs are resolved...) - size_t vFuncIdx = 0; + auto vftTagType = GetVirtualFunctionTagType(view); for (auto&& vFunc : vfTable.GetVirtualFunctions()) { // TODO: Check to see if function already changed by user, if not, don't modify? - // rename them, demangledName::funcName - if (vFunc.IsUnique()) + // Must be owned by the class, no inheritence, OR must be unique to the vtable. + if (coLocator.GetClassHierarchyDescriptor().m_numBaseClassesValue <= 1 || vFunc.IsUnique()) { - vFunc.m_func->SetComment("Unique to " + shortName); - vFunc.CreateSymbol( - shortName + "::vFunc_" + std::to_string(vFuncIdx), rawName + "::vFunc_" + std::to_string(vFuncIdx)); + // Remove "Unresolved ownership" tag. + vFunc.m_func->RemoveUserFunctionTagsOfType(vftTagType); + vFunc.m_func->CreateUserFunctionTag(vftTagType, "Resolved to " + coLocator.GetClassName(), true); + vFunc.CreateSymbol(coLocator.GetClassName() + "::vFunc_" + std::to_string(vFuncIdx)); } - else + else if (vFunc.m_func->GetUserFunctionTagsOfType(vftTagType).empty()) { - // Must be owned by the class, no inheritence. 
- if (classDesc.m_numBaseClassesValue <= 1) - { - vFunc.CreateSymbol( - shortName + "::vFunc_" + std::to_string(vFuncIdx), rawName + "::vFunc_" + std::to_string(vFuncIdx)); - } + vFunc.m_func->CreateUserFunctionTag(vftTagType, "Unresolved ownership", true); } vFuncIdx++; } - // Set comment showing raw name. - size_t addrSize = view->GetAddressSize(); - std::vector objLocatorRefs = view->GetDataReferences(objLocator.m_address); - if (!objLocatorRefs.empty()) - view->SetCommentForAddress(objLocatorRefs.front(), rawName); - - objLocator.CreateSymbol(shortName + "_objLocator", rawName + "_objLocator"); - vfTable.CreateSymbol(shortName + "_vfTable", rawName + "_vfTable"); - typeDesc.CreateSymbol(shortName + "_typeDesc", rawName + "_typeDesc"); - classDesc.CreateSymbol(shortName + "_classDesc", rawName + "_classDesc"); - baseClassArray.CreateSymbol(shortName + "_classArray", rawName + "_classArray"); + coLocator.CreateSymbol(); + vfTable.CreateSymbol(); + typeDesc.CreateSymbol(); + classDesc.CreateSymbol(); + baseClassArray.CreateSymbol(); + // Add tag to objLocator... + view->CreateUserDataTag(coLocator.m_address, GetCOLocatorTagType(view), coLocator.GetClassName()); // Add tag to vfTable... - view->CreateUserDataTag(vfTable.m_address, GetVirtualFunctionTableTagType(view), shortName); + view->CreateUserDataTag(vfTable.m_address, GetVirtualFunctionTableTagType(view), vfTable.GetSymbolName()); } void ScanRTTIView(BinaryView* view) { uint64_t bvStartAddr = view->GetStart(); - - auto undo = view->BeginUndoActions(); + std::string undoId = view->BeginUndoActions(); view->BeginBulkModifySymbols(); + BinaryReader optReader = BinaryReader(view); + // Scan data sections for colocators. 
for (Ref segment : view->GetSegments()) { if (segment->GetFlags() & (SegmentReadable | SegmentContainsData | SegmentDenyExecute | SegmentDenyWrite)) { - BinaryReader optReader = BinaryReader(view); LogDebug("Attempting to find CompleteObjectLocators in segment %x", segment->GetStart()); + // TODO: Check to see if they are always aligned, if so currAddr += addrSize for (uint64_t currAddr = segment->GetStart(); currAddr < segment->GetEnd() - COLocatorSize; currAddr++) { optReader.Seek(currAddr); - if (optReader.Read32() == 1) + uint32_t sigVal = optReader.Read32(); + if (sigVal == 1) { optReader.SeekRelative(16); if (optReader.Read32() == currAddr - bvStartAddr) @@ -168,34 +119,141 @@ void ScanRTTIView(BinaryView* view) CreateSymbolsFromCOLocatorAddress(view, currAddr); } } + else if (sigVal == 0) + { + // Check ?AV + optReader.SeekRelative(8); + uint64_t typeDescNameAddr = optReader.Read32() + 8; + if (typeDescNameAddr > view->GetStart() && typeDescNameAddr < view->GetEnd()) + { + // Make sure we do not read across segment boundary. 
+ auto typeDescSegment = view->GetSegmentAt(typeDescNameAddr); + if (typeDescSegment != nullptr && typeDescSegment->GetEnd() - typeDescNameAddr > 4) + { + optReader.Seek(typeDescNameAddr); + if (optReader.ReadString(4) == ".?AV") + { + CreateSymbolsFromCOLocatorAddress(view, currAddr); + } + } + } + } } } } view->EndBulkModifySymbols(); - view->CommitUndoActions(undo); + view->CommitUndoActions(undoId); + view->Reanalyze(); } void ScanConstructorView(BinaryView* view) { - auto undo = view->BeginUndoActions(); + std::string undoId = view->BeginUndoActions(); view->BeginBulkModifySymbols(); - std::vector doneFuncs = {}; + std::vector funcRefs = {}; for (auto vtableTag : view->GetAllTagReferencesOfType(GetVirtualFunctionTableTagType(view))) { for (auto codeRef : view->GetCodeReferences(vtableTag.addr)) { uint64_t funcStart = codeRef.func->GetStart(); - if (std::find(doneFuncs.begin(), doneFuncs.end(), funcStart) != doneFuncs.end()) + if (std::find(funcRefs.begin(), funcRefs.end(), funcStart) != funcRefs.end()) continue; - doneFuncs.push_back(funcStart); + funcRefs.push_back(funcStart); CreateConstructorsAtFunction(view, codeRef.func); } } view->EndBulkModifySymbols(); - view->CommitUndoActions(undo); + view->CommitUndoActions(undoId); + view->Reanalyze(); +} + +void ScanClassFieldsView(BinaryView* view) +{ + std::string undoId = view->BeginUndoActions(); + view->BeginBulkModifySymbols(); + + auto applyFieldAccessesToNamedType = [](BinaryView* bv, Ref targetType) { + auto typeName = targetType->GetStructureName(); + bool newMemberAdded = false; + auto appliedStruct = bv->CreateStructureFromOffsetAccess(typeName, &newMemberAdded); + if (newMemberAdded) + { + bv->DefineUserType(typeName, targetType->WithReplacedStructure(targetType->GetStructure(), appliedStruct)); + } + }; + + for (auto coLocatorTag : view->GetAllTagReferencesOfType(GetCOLocatorTagType(view))) + { + auto coLocator = CompleteObjectLocator(view, coLocatorTag.addr); + + for (auto baseClassDesc : 
coLocator.GetClassHierarchyDescriptor().GetBaseClassArray().GetBaseClassDescriptors()) + { + auto baseClassType = + view->GetTypeByName(QualifiedName(baseClassDesc.GetTypeDescriptor().GetDemangledName())); + if (baseClassType != nullptr) + { + applyFieldAccessesToNamedType(view, baseClassType); + } + } + + auto classType = view->GetTypeByName(QualifiedName(coLocator.GetClassName())); + if (classType != nullptr) + { + applyFieldAccessesToNamedType(view, classType); + } + } + + view->EndBulkModifySymbols(); + view->CommitUndoActions(undoId); + view->Reanalyze(); +} + +void GenerateConstructorGraphViz(BinaryView* view) +{ + std::stringstream out; + out << "digraph Constructors {node [shape=record];\n"; + + auto tagType = GetConstructorTagType(view); + for (auto constructorTag : view->GetAllTagReferencesOfType(tagType)) + { + auto classNamedType = GetPointerTypeChildStructure( + constructorTag.func->GetVariableType(constructorTag.func->GetParameterVariables()->front())); + if (classNamedType == nullptr) + { + LogWarn("class with data (%s) has invalid this param", constructorTag.tag->GetData().c_str()); + continue; + } + + auto className = classNamedType->GetTypeName().GetString(); + out << '"' << className << '"' << " [label=\"{" << className; + + auto classType = view->GetTypeById(classNamedType->GetNamedTypeReference()->GetTypeId()); + if (!classType->IsStructure()) + { + LogWarn("class %s has invalid this param, not a structure...", className.c_str()); + continue; + } + + auto classTypeStruct = view->GetTypeById(classNamedType->GetNamedTypeReference()->GetTypeId())->GetStructure(); + for (auto classMember : classTypeStruct->GetMembersIncludingInherited(view)) + { + // TODO: Handle inherited by adding an arrow to the real struct i guess? (use ports?) 
+ out << "|{0x" << IntToHex(classMember.member.offset) << "|" << classMember.member.name << "}"; + } + + out << "}\"];\n"; + + for (auto baseStruct : classTypeStruct->GetBaseStructures()) + { + out << '"' << className << "\"->\"" << baseStruct.type->GetName().GetString() << "\";\n"; + } + } + + out << "}"; + view->ShowPlainTextReport("MSVC Constructor GraphViz DOT", out.str()); } extern "C" @@ -204,17 +262,11 @@ extern "C" BINARYNINJAPLUGIN bool CorePluginInit() { - // Ref settings = Settings::Instance(); - // settings->RegisterGroup("msvc", "MSVC"); - // settings->RegisterSetting("msvc.autosearch", R"~({ - // "title" : "Automatically Scan RTTI", - // "type" : "boolean", - // "default" : true, - // "description" : "Automatically search and symbolize RTTI information" - // })~"); - - PluginCommand::Register("Find MSVC RTTI", "Scans for all RTTI in view.", &ScanRTTIView); - PluginCommand::Register("Find MSVC Constructors", "Scans for all constructors in view.", &ScanConstructorView); + PluginCommand::Register("MSVC\\Find RTTI", "Scans for all RTTI in view.", &ScanRTTIView); + PluginCommand::Register("MSVC\\Find Constructors", "Scans for all constructors in view.", &ScanConstructorView); + PluginCommand::Register("MSVC\\Find Class Fields", "Scans for all class fields in view.", &ScanClassFieldsView); + PluginCommand::Register("MSVC\\Generate Constructors Graphviz", + "Makes a graph from all the available MSVC constructors.", &GenerateConstructorGraphViz); return true; } diff --git a/src/type_descriptor.cpp b/src/type_descriptor.cpp index d1ec89b..bfb4edf 100644 --- a/src/type_descriptor.cpp +++ b/src/type_descriptor.cpp @@ -21,8 +21,10 @@ TypeDescriptor::TypeDescriptor(BinaryView* view, uint64_t address) std::string TypeDescriptor::GetDemangledName() { - std::string demangledNameValue = - std::string(llvm::microsoftDemangle(m_nameValue.c_str(), nullptr, nullptr, nullptr, nullptr)); + char* msDemangle = llvm::microsoftDemangle(m_nameValue.c_str(), nullptr, nullptr, 
nullptr, nullptr); + if (msDemangle == nullptr) + return m_nameValue; // TODO: Not good. + std::string demangledNameValue = std::string(msDemangle); size_t beginFind = demangledNameValue.find_first_of(" "); if (beginFind != std::string::npos) @@ -51,10 +53,16 @@ Ref TypeDescriptor::GetType() return TypeBuilder::StructureType(&typeDescriptorBuilder).Finalize(); } -Ref TypeDescriptor::CreateSymbol(std::string name, std::string rawName) +Ref TypeDescriptor::CreateSymbol() { - Ref typeDescSym = new Symbol {DataSymbol, name, name, rawName, m_address}; + Ref typeDescSym = new Symbol {DataSymbol, GetSymbolName(), m_address}; m_view->DefineUserSymbol(typeDescSym); - m_view->DefineDataVariable(m_address, GetType()); + m_view->DefineUserDataVariable(m_address, GetType()); return typeDescSym; +} + +// Example: class Animal `RTTI Type Descriptor' +std::string TypeDescriptor::GetSymbolName() +{ + return "class " + GetDemangledName() + " `RTTI Type Descriptor'"; } \ No newline at end of file diff --git a/src/utils.cpp b/src/utils.cpp index ffe5688..6f518ea 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -1,4 +1,5 @@ #include +#include using namespace BinaryNinja; @@ -47,4 +48,64 @@ Ref GetVirtualFunctionTableTagType(BinaryView* view) view->AddTagType(tagType); } return tagType; -} \ No newline at end of file +} + +Ref GetVirtualFunctionTagType(BinaryView* view) +{ + Ref tagType = view->GetTagType("MSVC Virtual Function"); + if (tagType == nullptr) + { + tagType = new TagType(view, "MSVC Virtual Function", "☝️"); + view->AddTagType(tagType); + } + return tagType; +} + +Ref GetCOLocatorTagType(BinaryView* view) +{ + Ref tagType = view->GetTagType("MSVC Complete Object Locator"); + if (tagType == nullptr) + { + tagType = new TagType(view, "MSVC Complete Object Locator", "✨"); + view->AddTagType(tagType); + } + return tagType; +} + +Ref GetPointerTypeChildStructure(Ref ptrType) +{ + if (!ptrType->IsPointer()) + return 0; + Ref childType = ptrType->GetChildType().GetValue(); + 
while (childType->IsPointer()) + { + childType = childType->GetChildType().GetValue(); + } + return childType; +} + +uint64_t ResolveRelPointer(BinaryView* view, uint64_t ptrVal) +{ + switch (view->GetAddressSize()) + { + case 8: + return view->GetStart() + ptrVal; + case 4: + return ptrVal; + default: + // TODO: Handle this correctly. + return 0; + } +} + +// void GetAllMembersForStructure(BinaryView* view, Ref func, StructureBuilder structBuilder, Variable var) +// { +// auto mlil = func->GetMediumLevelIL(); +// for (auto varRef : func->GetMediumLevelILVariableReferences(var)) +// { +// auto inst = mlil[varRef.exprId]; +// } + +// // TODO: Scan the function for accesses to the var, make sure this works correctly with inherited classes, i.e. make +// // sure they dont already exist. +// } \ No newline at end of file diff --git a/src/virtual_function.cpp b/src/virtual_function.cpp index 15c4880..aca112c 100644 --- a/src/virtual_function.cpp +++ b/src/virtual_function.cpp @@ -16,9 +16,13 @@ bool VirtualFunction::IsUnique() return m_view->GetDataReferences(m_func->GetStart()).size() == 1; } -Ref VirtualFunction::CreateSymbol(std::string name, std::string rawName) +// TODO: IsThunk +// TODO: IsConstructor? +// TODO: IsDestructor? 
+ +Ref VirtualFunction::CreateSymbol(std::string name) { - Ref newFuncSym = new Symbol {FunctionSymbol, name, name, rawName, m_func->GetStart()}; + Ref newFuncSym = new Symbol {FunctionSymbol, name, m_func->GetStart()}; m_view->DefineUserSymbol(newFuncSym); return newFuncSym; } \ No newline at end of file diff --git a/src/virtual_function_table.cpp b/src/virtual_function_table.cpp index ccc3257..0cda896 100644 --- a/src/virtual_function_table.cpp +++ b/src/virtual_function_table.cpp @@ -36,12 +36,7 @@ std::vector VirtualFunctionTable::GetVirtualFunctions() { LogInfo("Discovered function from vtable reference -> %x", vFuncAddr); m_view->CreateUserFunction(m_view->GetDefaultPlatform(), vFuncAddr); - funcs = m_view->GetAnalysisFunctionsForAddress(vFuncAddr); - if (funcs.empty()) - { - LogWarn("vFunc does not point to function -> %x", vFuncAddr); - break; - } + funcs.emplace_back(m_view->GetAnalysisFunctionsForAddress(vFuncAddr).front()); } else { @@ -50,10 +45,12 @@ std::vector VirtualFunctionTable::GetVirtualFunctions() } } - vFuncs.emplace_back(VirtualFunction(m_view, vFuncAddr, funcs.front())); + for (auto func : funcs) + { + vFuncs.emplace_back(VirtualFunction(m_view, m_address, func)); + } } - return vFuncs; } @@ -63,34 +60,55 @@ CompleteObjectLocator VirtualFunctionTable::GetCOLocator() return CompleteObjectLocator(m_view, dataRefs.front()); } -Ref VirtualFunctionTable::GetType(std::string name, std::string idName) +Ref VirtualFunctionTable::GetType() { - Ref typeCache = m_view->GetTypeById("msvc_" + idName); + QualifiedName typeName = QualifiedName(GetTypeName()); + Ref typeCache = Type::NamedType(m_view, typeName); - if (typeCache == nullptr) + if (m_view->GetTypeByName(typeName) == nullptr) { size_t addrSize = m_view->GetAddressSize(); - StructureBuilder vftBuilder; + StructureBuilder vftBuilder = {}; + vftBuilder.SetPropagateDataVariableReferences(true); size_t vFuncIdx = 0; for (auto&& vFunc : GetVirtualFunctions()) { + // TODO: This needs to be fixed, must 
update vfunc type to this ptr to our structure. vftBuilder.AddMember( Type::PointerType(addrSize, vFunc.m_func->GetType(), true), "vFunc_" + std::to_string(vFuncIdx)); vFuncIdx++; } - m_view->DefineType("msvc_" + idName, QualifiedName(name), TypeBuilder::StructureType(&vftBuilder).Finalize()); + m_view->DefineUserType(typeName, TypeBuilder::StructureType(&vftBuilder).Finalize()); - typeCache = m_view->GetTypeById("msvc_" + idName); + typeCache = Type::NamedType(m_view, typeName); } return typeCache; } -Ref VirtualFunctionTable::CreateSymbol(std::string name, std::string rawName) +Ref VirtualFunctionTable::CreateSymbol() { - Ref newFuncSym = new Symbol {DataSymbol, name, name, rawName, m_address}; + Ref newFuncSym = new Symbol {DataSymbol, GetSymbolName(), m_address}; m_view->DefineUserSymbol(newFuncSym); - m_view->DefineDataVariable(m_address, GetType(name, rawName)); + m_view->DefineUserDataVariable(m_address, GetType()); return newFuncSym; +} + +// Example: Animal::`vftable' +// If subobject this will return: Bird::`vftable'{for `Flying'} +std::string VirtualFunctionTable::GetSymbolName() +{ + auto coLocator = GetCOLocator(); + std::string className = coLocator.GetClassName(); + if (coLocator.IsSubObject()) + return className + "::`vftable'" + "{for `" + coLocator.GetAssociatedClassName() + "'}"; + return className + "::`vftable'"; +} + +// Example: Animal::VTable +// If subobject this will return the type name of the subobject type. 
+std::string VirtualFunctionTable::GetTypeName() +{ + return GetCOLocator().GetAssociatedClassName() + "::VTable"; } \ No newline at end of file diff --git a/test/bins/overrides.cpp b/test/bins/overrides.cpp new file mode 100644 index 0000000..9528628 --- /dev/null +++ b/test/bins/overrides.cpp @@ -0,0 +1,89 @@ +#include + +class Animal +{ +public: + const char* name; + + virtual void make_sound() = 0; + + virtual void approach() { make_sound(); } + + void greet() + { + printf("You:\n"); + printf("Hello %s!\n", name); + printf("%s:\n", name); + make_sound(); + } +}; + +class Flying +{ +public: + int max_airspeed; + + virtual void fly() { puts("Up, up, and away!"); } +}; + +class Dog : public Animal +{ +public: + int bark_count; + + virtual void make_sound() override + { + puts("Woof!"); + bark_count++; + } +}; + +class Cat : public Animal +{ +public: + int nap_count; + + virtual void make_sound() override { puts("Meow!"); } + + virtual void approach() override + { + if (nap_count) + puts("Zzzz..."); + else + make_sound(); + } + + virtual void nap() { nap_count++; } +}; + +class Lion : public Cat +{ +public: + virtual void make_sound() override { puts("Roar!"); } +}; + +class Bird : public Animal, public Flying +{ +public: + int song_length; + + Bird(int song_length) + { + this->name = "A bird"; + this->max_airspeed = 88; + this->song_length = song_length; + } + + virtual void make_sound() override + { + for (int i = 0; i < song_length; i++) + puts("Tweet!"); + } + + virtual void approach() override { fly(); } +}; + +int main() +{ + Bird birdObj = Bird(5); +} \ No newline at end of file