diff --git a/deeprank2/query.py b/deeprank2/query.py index 2365e3d5f..d0550d278 100644 --- a/deeprank2/query.py +++ b/deeprank2/query.py @@ -62,6 +62,7 @@ class DeepRankQuery: def __post_init__(self): self._model_id = os.path.splitext(os.path.basename(self.pdb_path))[0] + self.variant = None # not used for PPI, overwritten for SRV if self.resolution not in VALID_RESOLUTIONS: raise ValueError(f"Invalid resolution given ({self.resolution}). Must be one of {VALID_RESOLUTIONS}") @@ -173,7 +174,7 @@ def __repr__(self) -> str: def build( self, - feature_modules: list[str], + feature_modules: list[str | ModuleType], ) -> Graph: """Builds the graph from the .PDB structure. @@ -184,19 +185,30 @@ def build( Returns: :class:`Graph`: The resulting :class:`Graph` object with all the features and targets. """ + # TODO: Should this be an internal method, as it is only called upon by QueryCollection and not by user? - try: - feature_modules = [importlib.import_module('deeprank2.features.' + name) for name in feature_modules] - except TypeError: - if isinstance(feature_modules, ModuleType): - feature_modules = [feature_modules] - elif isinstance(feature_modules, str): - feature_modules = [importlib.import_module('deeprank2.features.' + feature_modules)] + # TODO: I would prefer if `_set_feature_modules` would return a list of `ModuleType` instead of `str`, + # but this leads to an exception in the pool function in `QueryCollection.process`, that I don't know how to solve. + # Tests are currently passing `ModuleType`s directly, which is clearer than passing strings + # and is the reason I am allowing both input types. + + feature_modules = [importlib.import_module('deeprank2.features.' + module) + if isinstance(module, str) else module + for module in feature_modules] self._pssm_required = conservation in feature_modules - return self._child_build(feature_modules) + # TODO: my gut feeling is that the building can be unified further, but it is + # not trivial. I will leave it like this for now. + graph = self._build_helper() + + # add target and feature data to the graph + self._set_graph_targets(graph) + for feature_module in feature_modules: + feature_module.add_features(self.pdb_path, graph, self.variant) + + return graph - def _child_build(self, feature_modules: list[ModuleType]): + def _build_helper(self) -> Graph: raise NotImplementedError("Must be defined in child classes.") def get_query_id(self) -> str: raise NotImplementedError("Must be defined in child classes.") @@ -237,27 +249,19 @@ def get_query_id(self) -> str: + f"{self.wildtype_amino_acid.name}->{self.variant_amino_acid.name}:{self.model_id}" ) - def _child_build( - self, - feature_modules: list[ModuleType], - ) -> Graph: - """Builds the graph from the .PDB structure. - - Args: - feature_modules (list[str]): the feature modules used to build the graph. - These must be filenames existing inside `deeprank2.features` subpackage. + def _build_helper(self) -> Graph: + """Helper function to build a graph for SRV queries. Returns: :class:`Graph`: The resulting :class:`Graph` object with all the features and targets. """ # load .PDB structure - structure: PDBStructure = self._load_structure() + structure = self._load_structure() # find the variant residue and its surroundings variant_residue: Residue = None for residue in structure.get_chain(self.variant_chain_id).residues: - residue: Residue if ( residue.number == self.variant_residue_number and residue.insertion_code == self.insertion_code @@ -268,7 +272,7 @@ def _child_build( raise ValueError( f"Residue not found in {self.pdb_path}: {self.variant_chain_id} {self.residue_id}" ) - variant = SingleResidueVariant(variant_residue, self.variant_amino_acid) + self.variant = SingleResidueVariant(variant_residue, self.variant_amino_acid) residues = get_surrounding_residues(structure, variant_residue, self.radius) # build the graph @@ -287,16 +291,10 @@ def _child_build( graph = build_atomic_graph(atoms, self.get_query_id(), self.distance_cutoff) #TODO: check if this works with a set instead of a list - else: raise NotImplementedError(f"No function exists to build graphs with resolution of {self.resolution}.") graph.center = get_residue_center(variant_residue) - # add target and feature data to the graph - self._set_graph_targets(graph) - for feature_module in feature_modules: - feature_module.add_features(self.pdb_path, graph, variant) - return graph @@ -325,21 +323,15 @@ def get_query_id(self) -> str: + f"{self.chain_ids[0]}-{self.chain_ids[1]}:{self.model_id}" ) - def _child_build( - self, - feature_modules: list[ModuleType], - ) -> Graph: + def _build_helper(self) -> Graph: #TODO: check how much of this is common with SRV and move it to parent class - """Builds the graph from the .PDB structure. - - Args: - feature_modules (list[ModuleType]): Each must implement the :py:func:`add_features` function. + """Helper function to build a graph for PPI queries. Returns: :class:`Graph`: The resulting :class:`Graph` object with all the features and targets. """ - # find the contact atoms + # find the atoms near the contact interface contact_atoms = get_contact_atoms( self.pdb_path, self.chain_ids, @@ -359,20 +351,11 @@ def _child_build( raise NotImplementedError(f"No function exists to build graphs with resolution of {self.resolution}.") graph.center = np.mean([atom.position for atom in contact_atoms], axis=0) - # add data to the graph - self._set_graph_targets(graph) - - # read the pssm - #TODO: unify with the way pssms are read for srv queries structure = contact_atoms[0].residue.chain.model if self._pssm_required: self._load_pssm_data(structure) - # add the features - for feature_module in feature_modules: - feature_module.add_features(self.pdb_path, graph) - graph.center = np.mean([atom.position for atom in contact_atoms], axis=0) return graph @@ -448,9 +431,6 @@ def _process_one_query(self, query: DeepRankQuery): try: # TODO: Maybe make exception catching optional, because I think it would be good to raise the error by default. output_path = f"{self._prefix}-{os.getpid()}.hdf5" - #TODO: move the line below into generic build method so we can pass a list of strings here. - # feature_modules = [importlib.import_module('deeprank2.features.' + name) for name in self._feature_modules] - # graph = query.build(feature_modules) graph = query.build(self._feature_modules) graph.write_to_hdf5(output_path) diff --git a/tests/test_query.py b/tests/test_query.py index 4c643c464..eb44f43af 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -373,15 +373,15 @@ def test_incorrect_pssm_order(): # check that error is thrown for incorrect pssm with pytest.raises(ValueError): - _ = q.build(conservation) + _ = q.build([conservation]) # no error if conservation module is not used - _ = q.build(components) + _ = q.build([components]) # check that error suppression works with pytest.warns(UserWarning): q.suppress_pssm_errors = True - _ = q.build(conservation) + _ = q.build([conservation]) def test_incomplete_pssm(): @@ -396,15 +396,15 @@ def test_incomplete_pssm(): ) with pytest.raises(ValueError): - _ = q.build(conservation) + _ = q.build([conservation]) # no error if conservation module is not used - _ = q.build(components) + _ = q.build([components]) # check that error suppression works with pytest.warns(UserWarning): q.suppress_pssm_errors = True - _ = q.build(conservation) + _ = q.build([conservation]) def test_no_pssm_provided(): @@ -424,12 +424,12 @@ def test_no_pssm_provided(): ) with pytest.raises(ValueError): - _ = q_empty_dict.build(conservation) - _ = q_not_provided.build(conservation) + _ = q_empty_dict.build([conservation]) + _ = q_not_provided.build([conservation]) # no error if conservation module is not used - _ = q_empty_dict.build(components) - _ = q_not_provided.build(components) + _ = q_empty_dict.build([components]) + _ = q_not_provided.build([components]) def test_incorrect_pssm_provided(): @@ -455,12 +455,12 @@ def test_incorrect_pssm_provided(): ) with pytest.raises(FileNotFoundError): - _ = q_non_existing.build(conservation) - _ = q_missing.build(conservation) + _ = q_non_existing.build([conservation]) + _ = q_missing.build([conservation]) # no error if conservation module is not used - _ = q_non_existing.build(components) - _ = q_missing.build(components) + _ = q_non_existing.build([components]) + _ = q_missing.build([components]) def test_variant_query_multiple_chains(): @@ -480,13 +480,13 @@ def test_variant_query_multiple_chains(): # at radius 10, chain B is included in graph # no error without conservation module - graph = q.build(components) + graph = q.build([components]) assert 'B' in graph.get_all_chains() # if we rebuild the graph with conservation module it should fail with pytest.raises(FileNotFoundError): - _ = q.build(conservation) + _ = q.build([conservation]) # at radius 7, chain B is not included in graph q.radius = 7.0 - graph = q.build(conservation) + graph = q.build([conservation]) assert 'B' not in graph.get_all_chains()