diff --git a/deeprank2/query.py b/deeprank2/query.py index 9b3d54879..2365e3d5f 100644 --- a/deeprank2/query.py +++ b/deeprank2/query.py @@ -80,7 +80,7 @@ def _set_graph_targets(self, graph: Graph): for target_name, target_data in self.targets.items(): graph.targets[target_name] = target_data - def _load_structure(self, pssm_required: bool) -> PDBStructure: + def _load_structure(self) -> PDBStructure: """Build PDBStructure objects from pdb and pssm data.""" pdb = pdb2sql.pdb2sql(self.pdb_path) try: @@ -88,7 +88,7 @@ def _load_structure(self, pssm_required: bool) -> PDBStructure: finally: pdb._close() # pylint: disable=protected-access # read the pssm - if pssm_required: + if self._pssm_required: self._load_pssm_data(structure) return structure @@ -171,7 +171,32 @@ def model_id(self, value: str): def __repr__(self) -> str: return f"{type(self)}({self.get_query_id()})" - def build(self, feature_modules: list[ModuleType]) -> Graph: + def build( + self, + feature_modules: list[str], + ) -> Graph: + """Builds the graph from the .PDB structure. + + Args: + feature_modules (list[str]): the feature modules used to build the graph. + These must be filenames existing inside `deeprank2.features` subpackage. + + Returns: + :class:`Graph`: The resulting :class:`Graph` object with all the features and targets. + """ + + try: + feature_modules = [importlib.import_module('deeprank2.features.' + name) for name in feature_modules] + except TypeError: + if isinstance(feature_modules, ModuleType): + feature_modules = [feature_modules] + elif isinstance(feature_modules, str): + feature_modules = [importlib.import_module('deeprank2.features.' + feature_modules)] + self._pssm_required = conservation in feature_modules + + return self._child_build(feature_modules) + + def _child_build(self, feature_modules: list[ModuleType]): raise NotImplementedError("Must be defined in child classes.") def get_query_id(self) -> str: raise NotImplementedError("Must be defined in child classes.") @@ -212,30 +237,25 @@ def get_query_id(self) -> str: + f"{self.wildtype_amino_acid.name}->{self.variant_amino_acid.name}:{self.model_id}" ) - def build( + def _child_build( self, - feature_modules: list[ModuleType] | ModuleType, + feature_modules: list[ModuleType], ) -> Graph: - #TODO: check how much of this is common with PPI and move it to parent class """Builds the graph from the .PDB structure. Args: - feature_modules (list[ModuleType]): Each must implement the :py:func:`add_features` function. + feature_modules (list[str]): the feature modules used to build the graph. + These must be filenames existing inside `deeprank2.features` subpackage. Returns: :class:`Graph`: The resulting :class:`Graph` object with all the features and targets. """ # load .PDB structure - if isinstance(feature_modules, list): - pssm_required = conservation in feature_modules - else: - pssm_required = conservation == feature_modules - feature_modules = [feature_modules] - structure: PDBStructure = self._load_structure(pssm_required) + structure: PDBStructure = self._load_structure() # find the variant residue and its surroundings - variant_residue = None + variant_residue: Residue = None for residue in structure.get_chain(self.variant_chain_id).residues: residue: Residue if ( @@ -272,7 +292,7 @@ def build( raise NotImplementedError(f"No function exists to build graphs with resolution of {self.resolution}.") graph.center = get_residue_center(variant_residue) - # add data to the graph + # add target and feature data to the graph self._set_graph_targets(graph) for feature_module in feature_modules: feature_module.add_features(self.pdb_path, graph, variant) @@ -305,9 +325,9 @@ def get_query_id(self) -> str: + f"{self.chain_ids[0]}-{self.chain_ids[1]}:{self.model_id}" ) - def build( + def _child_build( self, - feature_modules: list[ModuleType] | ModuleType, + feature_modules: list[ModuleType], ) -> Graph: #TODO: check how much of this is common with SRV and move it to parent class """Builds the graph from the .PDB structure. @@ -319,7 +339,12 @@ def build( :class:`Graph`: The resulting :class:`Graph` object with all the features and targets. """ - contact_atoms = get_contact_atoms(self.pdb_path, self.chain_ids, self.distance_cutoff) + # find the contact atoms + contact_atoms = get_contact_atoms( + self.pdb_path, + self.chain_ids, + self.distance_cutoff + ) if len(contact_atoms) == 0: raise ValueError("no contact atoms found") @@ -341,9 +366,7 @@ def build( #TODO: unify with the way pssms are read for srv queries structure = contact_atoms[0].residue.chain.model - if not isinstance(feature_modules, list): - feature_modules = [feature_modules] - if conservation in feature_modules: + if self._pssm_required: self._load_pssm_data(structure) # add the features @@ -426,8 +449,9 @@ def _process_one_query(self, query: DeepRankQuery): # TODO: Maybe make exception catching optional, because I think it would be good to raise the error by default. output_path = f"{self._prefix}-{os.getpid()}.hdf5" #TODO: move the line below into generic build method so we can pass a list of strings here. - feature_modules = [importlib.import_module('deeprank2.features.' + name) for name in self._feature_modules] - graph = query.build(feature_modules) + # feature_modules = [importlib.import_module('deeprank2.features.' + name) for name in self._feature_modules] + # graph = query.build(feature_modules) + graph = query.build(self._feature_modules) graph.write_to_hdf5(output_path) if self._grid_settings is not None and self._grid_map_method is not None: