Skip to content

Commit

Permalink
♻️ Small code refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
ImNotAVirus committed Feb 17, 2022
1 parent caa6a62 commit 8503381
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 57 deletions.
10 changes: 5 additions & 5 deletions __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from binaryninja.enums import MessageBoxButtonSet, MessageBoxIcon, MessageBoxButtonResult

from .bnhelpers import BNHelpers
from .delphi_analyser import ClassFinder, DelphiVMT
from .delphi_analyser import DelphiAnalyzer, DelphiVMT


class AnalyzeDelphiVmtsTask(BackgroundTaskThread):
Expand All @@ -17,8 +17,8 @@ def __init__(self, bv: BinaryView, tag_type: Tag, delphi_version: int):
def run(self):
self._bv.begin_undo_actions()

finder = ClassFinder(self._bv, self._delphi_version)
finder.update_analysis_and_wait(self.analyze_callback)
analyzer = DelphiAnalyzer(self._bv, self._delphi_version)
analyzer.update_analysis_and_wait(self.analyze_callback)

self._bv.commit_undo_actions()
self._bv.update_analysis()
Expand All @@ -41,8 +41,8 @@ def analyze_callback(self, delphi_vmt: DelphiVMT):
if function.name.startswith('sub_') or function.name == 'vmt' + delphi_vmt.class_name:
self._bv.remove_user_function(function)

# TODO: Clean that later (define a property for VMT tables)
for table_addr in delphi_vmt._get_vmt_tables_addr():
# Same here
for table_addr in delphi_vmt.table_list.keys():
for function in self._bv.get_functions_containing(table_addr):
if function.name.startswith('sub_'):
self._bv.remove_user_function(function)
Expand Down
90 changes: 51 additions & 39 deletions delphi_analyser.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,25 @@ def __init__(self, bv: BinaryView, delphi_version: int, address: int):
self._class_name = ''
self._instance_size = 0
self._parent_vmt = 0
self._table_list: Mapping[int, str] = {}
self._virtual_methods: Mapping[int, str] = {}

if not self._check_self_ptr():
return

if not self._parse_name():
if not self._resolve_name():
return

if not self._parse_instance_size():
if not self._resolve_instance_size():
return

if not self._parse_parent_vmt():
if not self._resolve_parent_vmt():
return

if not self._parse_virtual_methods():
if not self._resolve_table_list():
return

if not self._resolve_virtual_methods():
return

self._is_valid = True
Expand All @@ -57,7 +61,9 @@ def __repr__(self):
def __str__(self):
if not self._is_valid:
return f'<InvalidVmt address=0x{self._vmt_address:08X}>'
return f'<{self._class_name} start=0x{self.start:08X} instance_size=0x{self._instance_size:X}>'

return (f'<{self._class_name} start=0x{self.start:08X} '
f'instance_size=0x{self._instance_size:X}>')


## Properties
Expand All @@ -82,6 +88,10 @@ def instance_size(self) -> int:
def parent_vmt(self) -> int:
return self._parent_vmt

@property
def table_list(self) -> Mapping[int, str]:
return self._table_list

@property
def virtual_methods(self) -> Mapping[int, str]:
return self._virtual_methods
Expand All @@ -107,23 +117,23 @@ def br_offset(self) -> int:
## Public API

def seek_to_code(self, address: int) -> bool:
if not self._isValidCodeAdr(address):
if not self._is_valid_code_addr(address):
return False

self._br.seek(address)
return True


def seek_to_code_offset(self, offset: int) -> bool:
if not self._isValidCodeAdr(self._code_section.start + offset):
if not self._is_valid_code_addr(self._code_section.start + offset):
return False

self._br.seek(self._code_section.start + offset)
return True


def seek_to_vmt_offset(self, offset: int) -> bool:
if not self._isValidCodeAdr(self._vmt_address + offset):
if not self._is_valid_code_addr(self._vmt_address + offset):
return False

self._br.seek(self._vmt_address + offset)
Expand All @@ -148,7 +158,7 @@ def _check_self_ptr(self) -> bool:
return self_ptr == self._vmt_address


def _parse_name(self) -> bool:
def _resolve_name(self) -> bool:
class_name_addr = self._get_class_name_addr()

if class_name_addr is None:
Expand All @@ -172,23 +182,23 @@ def _parse_name(self) -> bool:
return True


def _parse_instance_size(self) -> bool:
def _resolve_instance_size(self) -> bool:
if not self.seek_to_vmt_offset(self._vmt_offsets.cVmtInstanceSize):
return False

self._instance_size = self._br.read32()
return True


def _parse_parent_vmt(self) -> bool:
def _resolve_parent_vmt(self) -> bool:
if not self.seek_to_vmt_offset(self._vmt_offsets.cVmtParent):
return False

self._parent_vmt = self._br.read32()
return True


def _parse_virtual_methods(self) -> bool:
def _resolve_virtual_methods(self) -> bool:
class_name_addr = self._get_class_name_addr()

if class_name_addr is None:
Expand All @@ -197,7 +207,7 @@ def _parse_virtual_methods(self) -> bool:
address_size = self._bv.address_size
offsets = self.vmt_offsets.__dict__.items()
offset_map = {y:x for x, y in offsets}
tables_addr = self._get_vmt_tables_addr()
tables_addr = self._table_list.keys()

if not self.seek_to_vmt_offset(self._vmt_offsets.cVmtParent + address_size):
return False
Expand All @@ -208,7 +218,7 @@ def _parse_virtual_methods(self) -> bool:
if field_value == 0:
continue

if not self._isValidCodeAdr(field_value):
if not self._is_valid_code_addr(field_value):
prev_offset = self._br.offset - address_size
raise RuntimeError(f'Invalid code address deteted at 0x{prev_offset:08X} '
'({self.class_name})\n If you think it\'s a bug, please open an issue on '
Expand All @@ -227,7 +237,31 @@ def _parse_virtual_methods(self) -> bool:
return True


def _isValidCodeAdr(self, addy: int, allow_null=False) -> bool:
def _resolve_table_list(self) -> bool:
if not self.seek_to_vmt_offset(self.vmt_offsets.cVmtIntfTable):
return False

offsets = self._vmt_offsets.__dict__.items()
offset_map = {y:x[4:] for x, y in offsets}

stop_at = self._vmt_address + self._vmt_offsets.cVmtClassName

while self._br.offset != stop_at:
prev_br_offset = self._br.offset
address = self._br.read32()

if address < 1:
continue

if not self._is_valid_code_addr(address):
raise RuntimeError('Invalid table address detected')

self._table_list[address] = offset_map[prev_br_offset - self._vmt_address]

return True


def _is_valid_code_addr(self, addy: int, allow_null=False) -> bool:
if addy == 0:
return allow_null
return addy >= self._code_section.start and addy < self._code_section.end
Expand All @@ -239,34 +273,13 @@ def _get_class_name_addr(self) -> Union[None, int]:

class_name_addr = self._br.read32()

if not self._isValidCodeAdr(class_name_addr):
if not self._is_valid_code_addr(class_name_addr):
return None

return class_name_addr


def _get_vmt_tables_addr(self) -> Union[None, List[int]]:
if not self.seek_to_vmt_offset(self.vmt_offsets.cVmtIntfTable):
return

result = []
stop_at = self._vmt_address + self.vmt_offsets.cVmtClassName

while self._br.offset != stop_at:
address = self._br.read32()

if address < 1:
continue

if not self._isValidCodeAdr(address):
raise RuntimeError('Invalid table address detected')

result.append(address)

return result


class ClassFinder(object):
class DelphiAnalyzer(object):
'''
TODO: Doc
'''
Expand All @@ -286,7 +299,6 @@ def __init__(self, bv: BinaryView, delphi_version: int):
def delphi_version(self) -> int:
return self._delphi_version


@property
def vmt_list(self) -> List[DelphiVMT]:
return self._vmt_list
Expand Down
8 changes: 4 additions & 4 deletions examples/create_vmt_structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import sys
from os import path

# from delphi_ninja.delphi_analyser import ClassFinder, DelphiVMT
# from delphi_ninja.delphi_analyser import DelphiAnalyzer, DelphiVMT
# from delphi_ninja.bnlogger import BNLogger
# from delphi_ninja.bnhelpers import BNHelpers

Expand All @@ -14,7 +14,7 @@
module_parent = path.dirname(module_dir)
sys.path.insert(0, module_parent)
delphi_ninja = importlib.import_module(module_name)
ClassFinder = delphi_ninja.delphi_analyser.ClassFinder
DelphiAnalyzer = delphi_ninja.delphi_analyser.DelphiAnalyzer
DelphiVMT = delphi_ninja.delphi_analyser.DelphiVMT
BNLogger = delphi_ninja.bnlogger.BNLogger
BNHelpers = delphi_ninja.bnhelpers.BNHelpers
Expand Down Expand Up @@ -44,8 +44,8 @@ def main(target: str, delphi_version: int):
BNLogger.log('-----------------------------')
BNLogger.log('Searching for VMT...')

finder = ClassFinder(bv, delphi_version)
finder.update_analysis_and_wait(lambda vmt: analyze_callback(vmt, bv))
analyzer = DelphiAnalyzer(bv, delphi_version)
analyzer.update_analysis_and_wait(lambda vmt: analyze_callback(vmt, bv))

bv.update_analysis_and_wait()

Expand Down
8 changes: 4 additions & 4 deletions examples/list_vmts.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
from os import path

# from delphi_ninja.bnlogger import BNLogger
# from delphi_ninja.delphi_analyser import ClassFinder
# from delphi_ninja.delphi_analyser import DelphiAnalyzer

module_dir = path.dirname(path.dirname(path.abspath(__file__)))
module_name = path.basename(module_dir)
module_parent = path.dirname(module_dir)
sys.path.insert(0, module_parent)
delphi_ninja = importlib.import_module(module_name)
BNLogger = delphi_ninja.bnlogger.BNLogger
ClassFinder = delphi_ninja.delphi_analyser.ClassFinder
DelphiAnalyzer = delphi_ninja.delphi_analyser.DelphiAnalyzer


def main(target: str, delphi_version: int):
Expand All @@ -36,8 +36,8 @@ def main(target: str, delphi_version: int):
BNLogger.log('-----------------------------')
BNLogger.log('Searching for VMT...')

finder = ClassFinder(bv, delphi_version)
finder.update_analysis_and_wait(lambda vmt: BNLogger.log(vmt))
analyzer = DelphiAnalyzer(bv, delphi_version)
analyzer.update_analysis_and_wait(lambda vmt: BNLogger.log(vmt))


if __name__ == '__main__':
Expand Down
10 changes: 5 additions & 5 deletions examples/vmt_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from os import path

# from delphi_ninja.bnlogger import BNLogger
# from delphi_ninja.delphi_analyser import ClassFinder, DelphiVMT
# from delphi_ninja.delphi_analyser import DelphiAnalyzer, DelphiVMT

module_dir = path.dirname(path.dirname(path.abspath(__file__)))
module_name = path.basename(module_dir)
Expand All @@ -17,7 +17,7 @@
delphi_ninja = importlib.import_module(module_name)

BNLogger = delphi_ninja.bnlogger.BNLogger
ClassFinder = delphi_ninja.delphi_analyser.ClassFinder
DelphiAnalyzer = delphi_ninja.delphi_analyser.DelphiAnalyzer
DelphiVMT = delphi_ninja.delphi_analyser.DelphiVMT


Expand Down Expand Up @@ -58,11 +58,11 @@ def main(target: str, delphi_version: int):
BNLogger.log('File loaded')
BNLogger.log('Searching for VMT...')

finder = ClassFinder(bv, delphi_version)
finder.update_analysis_and_wait()
analyzer = DelphiAnalyzer(bv, delphi_version)
analyzer.update_analysis_and_wait()

BNLogger.log('Creating Graph...')
vmt_map = {vmt.start:vmt for vmt in finder.vmt_list}
vmt_map = {vmt.start:vmt for vmt in analyzer.vmt_list}
create_graph(vmt_map)


Expand Down

0 comments on commit 8503381

Please sign in to comment.