From 8503381d340b302caccea2e96e6bb53e0d5f8a20 Mon Sep 17 00:00:00 2001 From: ImNotAVirus <17680522+ImNotAVirus@users.noreply.github.com> Date: Fri, 2 Oct 2020 23:12:46 +0200 Subject: [PATCH] :recycle: Small code refactoring --- __init__.py | 10 ++-- delphi_analyser.py | 90 +++++++++++++++++++--------------- examples/create_vmt_structs.py | 8 +-- examples/list_vmts.py | 8 +-- examples/vmt_visualizer.py | 10 ++-- 5 files changed, 69 insertions(+), 57 deletions(-) diff --git a/__init__.py b/__init__.py index 6fd5675..5fda4fa 100644 --- a/__init__.py +++ b/__init__.py @@ -3,7 +3,7 @@ from binaryninja.enums import MessageBoxButtonSet, MessageBoxIcon, MessageBoxButtonResult from .bnhelpers import BNHelpers -from .delphi_analyser import ClassFinder, DelphiVMT +from .delphi_analyser import DelphiAnalyzer, DelphiVMT class AnalyzeDelphiVmtsTask(BackgroundTaskThread): @@ -17,8 +17,8 @@ def __init__(self, bv: BinaryView, tag_type: Tag, delphi_version: int): def run(self): self._bv.begin_undo_actions() - finder = ClassFinder(self._bv, self._delphi_version) - finder.update_analysis_and_wait(self.analyze_callback) + analyzer = DelphiAnalyzer(self._bv, self._delphi_version) + analyzer.update_analysis_and_wait(self.analyze_callback) self._bv.commit_undo_actions() self._bv.update_analysis() @@ -41,8 +41,8 @@ def analyze_callback(self, delphi_vmt: DelphiVMT): if function.name.startswith('sub_') or function.name == 'vmt' + delphi_vmt.class_name: self._bv.remove_user_function(function) - # TODO: Clean that later (define a property for VMT tables) - for table_addr in delphi_vmt._get_vmt_tables_addr(): + # Same here + for table_addr in delphi_vmt.table_list.keys(): for function in self._bv.get_functions_containing(table_addr): if function.name.startswith('sub_'): self._bv.remove_user_function(function) diff --git a/delphi_analyser.py b/delphi_analyser.py index b35af8e..e687600 100644 --- a/delphi_analyser.py +++ b/delphi_analyser.py @@ -30,21 +30,25 @@ def __init__(self, bv: BinaryView, delphi_version: int, address: int): self._class_name = '' self._instance_size = 0 self._parent_vmt = 0 + self._table_list: Mapping[int, str] = {} self._virtual_methods: Mapping[int, str] = {} if not self._check_self_ptr(): return - if not self._parse_name(): + if not self._resolve_name(): return - if not self._parse_instance_size(): + if not self._resolve_instance_size(): return - if not self._parse_parent_vmt(): + if not self._resolve_parent_vmt(): return - if not self._parse_virtual_methods(): + if not self._resolve_table_list(): + return + + if not self._resolve_virtual_methods(): return self._is_valid = True @@ -57,7 +61,9 @@ def __repr__(self): def __str__(self): if not self._is_valid: return f'' - return f'<{self._class_name} start=0x{self.start:08X} instance_size=0x{self._instance_size:X}>' + + return (f'<{self._class_name} start=0x{self.start:08X} ' + f'instance_size=0x{self._instance_size:X}>') ## Properties @@ -82,6 +88,10 @@ def instance_size(self) -> int: def parent_vmt(self) -> int: return self._parent_vmt + @property + def table_list(self) -> Mapping[int, str]: + return self._table_list + @property def virtual_methods(self) -> Mapping[int, str]: return self._virtual_methods @@ -107,7 +117,7 @@ def br_offset(self) -> int: ## Public API def seek_to_code(self, address: int) -> bool: - if not self._isValidCodeAdr(address): + if not self._is_valid_code_addr(address): return False self._br.seek(address) @@ -115,7 +125,7 @@ def seek_to_code(self, address: int) -> bool: def seek_to_code_offset(self, offset: int) -> bool: - if not self._isValidCodeAdr(self._code_section.start + offset): + if not self._is_valid_code_addr(self._code_section.start + offset): return False self._br.seek(self._code_section.start + offset) @@ -123,7 +133,7 @@ def seek_to_code_offset(self, offset: int) -> bool: def seek_to_vmt_offset(self, offset: int) -> bool: - if not self._isValidCodeAdr(self._vmt_address + offset): + if not self._is_valid_code_addr(self._vmt_address + offset): return False self._br.seek(self._vmt_address + offset) @@ -148,7 +158,7 @@ def _check_self_ptr(self) -> bool: return self_ptr == self._vmt_address - def _parse_name(self) -> bool: + def _resolve_name(self) -> bool: class_name_addr = self._get_class_name_addr() if class_name_addr is None: @@ -172,7 +182,7 @@ def _parse_name(self) -> bool: return True - def _parse_instance_size(self) -> bool: + def _resolve_instance_size(self) -> bool: if not self.seek_to_vmt_offset(self._vmt_offsets.cVmtInstanceSize): return False @@ -180,7 +190,7 @@ def _parse_instance_size(self) -> bool: return True - def _parse_parent_vmt(self) -> bool: + def _resolve_parent_vmt(self) -> bool: if not self.seek_to_vmt_offset(self._vmt_offsets.cVmtParent): return False @@ -188,7 +198,7 @@ def _parse_parent_vmt(self) -> bool: return True - def _parse_virtual_methods(self) -> bool: + def _resolve_virtual_methods(self) -> bool: class_name_addr = self._get_class_name_addr() if class_name_addr is None: @@ -197,7 +207,7 @@ def _parse_virtual_methods(self) -> bool: address_size = self._bv.address_size offsets = self.vmt_offsets.__dict__.items() offset_map = {y:x for x, y in offsets} - tables_addr = self._get_vmt_tables_addr() + tables_addr = self._table_list.keys() if not self.seek_to_vmt_offset(self._vmt_offsets.cVmtParent + address_size): return False @@ -208,7 +218,7 @@ def _parse_virtual_methods(self) -> bool: if field_value == 0: continue - if not self._isValidCodeAdr(field_value): + if not self._is_valid_code_addr(field_value): prev_offset = self._br.offset - address_size raise RuntimeError(f'Invalid code address deteted at 0x{prev_offset:08X} ' '({self.class_name})\n If you think it\'s a bug, please open an issue on ' @@ -227,7 +237,31 @@ def _parse_virtual_methods(self) -> bool: return True - def _isValidCodeAdr(self, addy: int, allow_null=False) -> bool: + def _resolve_table_list(self) -> bool: + if not self.seek_to_vmt_offset(self.vmt_offsets.cVmtIntfTable): + return False + + offsets = self._vmt_offsets.__dict__.items() + offset_map = {y:x[4:] for x, y in offsets} + + stop_at = self._vmt_address + self._vmt_offsets.cVmtClassName + + while self._br.offset != stop_at: + prev_br_offset = self._br.offset + address = self._br.read32() + + if address < 1: + continue + + if not self._is_valid_code_addr(address): + raise RuntimeError('Invalid table address detected') + + self._table_list[address] = offset_map[prev_br_offset - self._vmt_address] + + return True + + + def _is_valid_code_addr(self, addy: int, allow_null=False) -> bool: if addy == 0: return allow_null return addy >= self._code_section.start and addy < self._code_section.end @@ -239,34 +273,13 @@ def _get_class_name_addr(self) -> Union[None, int]: class_name_addr = self._br.read32() - if not self._isValidCodeAdr(class_name_addr): + if not self._is_valid_code_addr(class_name_addr): return None return class_name_addr - def _get_vmt_tables_addr(self) -> Union[None, List[int]]: - if not self.seek_to_vmt_offset(self.vmt_offsets.cVmtIntfTable): - return - - result = [] - stop_at = self._vmt_address + self.vmt_offsets.cVmtClassName - - while self._br.offset != stop_at: - address = self._br.read32() - - if address < 1: - continue - - if not self._isValidCodeAdr(address): - raise RuntimeError('Invalid table address detected') - - result.append(address) - - return result - - -class ClassFinder(object): +class DelphiAnalyzer(object): ''' TODO: Doc ''' @@ -286,7 +299,6 @@ def __init__(self, bv: BinaryView, delphi_version: int): def delphi_version(self) -> int: return self._delphi_version - @property def vmt_list(self) -> List[DelphiVMT]: return self._vmt_list diff --git a/examples/create_vmt_structs.py b/examples/create_vmt_structs.py index cd760fd..b46022b 100644 --- a/examples/create_vmt_structs.py +++ b/examples/create_vmt_structs.py @@ -5,7 +5,7 @@ import sys from os import path -# from delphi_ninja.delphi_analyser import ClassFinder, DelphiVMT +# from delphi_ninja.delphi_analyser import DelphiAnalyzer, DelphiVMT # from delphi_ninja.bnlogger import BNLogger # from delphi_ninja.bnhelpers import BNHelpers @@ -14,7 +14,7 @@ module_parent = path.dirname(module_dir) sys.path.insert(0, module_parent) delphi_ninja = importlib.import_module(module_name) -ClassFinder = delphi_ninja.delphi_analyser.ClassFinder +DelphiAnalyzer = delphi_ninja.delphi_analyser.DelphiAnalyzer DelphiVMT = delphi_ninja.delphi_analyser.DelphiVMT BNLogger = delphi_ninja.bnlogger.BNLogger BNHelpers = delphi_ninja.bnhelpers.BNHelpers @@ -44,8 +44,8 @@ def main(target: str, delphi_version: int): BNLogger.log('-----------------------------') BNLogger.log('Searching for VMT...') - finder = ClassFinder(bv, delphi_version) - finder.update_analysis_and_wait(lambda vmt: analyze_callback(vmt, bv)) + analyzer = DelphiAnalyzer(bv, delphi_version) + analyzer.update_analysis_and_wait(lambda vmt: analyze_callback(vmt, bv)) bv.update_analysis_and_wait() diff --git a/examples/list_vmts.py b/examples/list_vmts.py index f55a3e2..24d7cb1 100644 --- a/examples/list_vmts.py +++ b/examples/list_vmts.py @@ -6,7 +6,7 @@ from os import path # from delphi_ninja.bnlogger import BNLogger -# from delphi_ninja.delphi_analyser import ClassFinder +# from delphi_ninja.delphi_analyser import DelphiAnalyzer module_dir = path.dirname(path.dirname(path.abspath(__file__))) module_name = path.basename(module_dir) @@ -14,7 +14,7 @@ sys.path.insert(0, module_parent) delphi_ninja = importlib.import_module(module_name) BNLogger = delphi_ninja.bnlogger.BNLogger -ClassFinder = delphi_ninja.delphi_analyser.ClassFinder +DelphiAnalyzer = delphi_ninja.delphi_analyser.DelphiAnalyzer def main(target: str, delphi_version: int): @@ -36,8 +36,8 @@ def main(target: str, delphi_version: int): BNLogger.log('-----------------------------') BNLogger.log('Searching for VMT...') - finder = ClassFinder(bv, delphi_version) - finder.update_analysis_and_wait(lambda vmt: BNLogger.log(vmt)) + analyzer = DelphiAnalyzer(bv, delphi_version) + analyzer.update_analysis_and_wait(lambda vmt: BNLogger.log(vmt)) if __name__ == '__main__': diff --git a/examples/vmt_visualizer.py b/examples/vmt_visualizer.py index 27d055e..9a63770 100644 --- a/examples/vmt_visualizer.py +++ b/examples/vmt_visualizer.py @@ -8,7 +8,7 @@ from os import path # from delphi_ninja.bnlogger import BNLogger -# from delphi_ninja.delphi_analyser import ClassFinder, DelphiVMT +# from delphi_ninja.delphi_analyser import DelphiAnalyzer, DelphiVMT module_dir = path.dirname(path.dirname(path.abspath(__file__))) module_name = path.basename(module_dir) @@ -17,7 +17,7 @@ delphi_ninja = importlib.import_module(module_name) BNLogger = delphi_ninja.bnlogger.BNLogger -ClassFinder = delphi_ninja.delphi_analyser.ClassFinder +DelphiAnalyzer = delphi_ninja.delphi_analyser.DelphiAnalyzer DelphiVMT = delphi_ninja.delphi_analyser.DelphiVMT @@ -58,11 +58,11 @@ def main(target: str, delphi_version: int): BNLogger.log('File loaded') BNLogger.log('Searching for VMT...') - finder = ClassFinder(bv, delphi_version) - finder.update_analysis_and_wait() + analyzer = DelphiAnalyzer(bv, delphi_version) + analyzer.update_analysis_and_wait() BNLogger.log('Creating Graph...') - vmt_map = {vmt.start:vmt for vmt in finder.vmt_list} + vmt_map = {vmt.start:vmt for vmt in analyzer.vmt_list} create_graph(vmt_map)