From db19653c78c38b298b83b953260a1afe1afcdb95 Mon Sep 17 00:00:00 2001 From: Dani Bodor Date: Thu, 7 Sep 2023 04:09:49 +0200 Subject: [PATCH] stuff --- deeprank2/query.py | 34 ++++++++++++++-------------------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/deeprank2/query.py b/deeprank2/query.py index 70c7e64b7..59b3425f9 100644 --- a/deeprank2/query.py +++ b/deeprank2/query.py @@ -150,14 +150,7 @@ def _load_structure(self, load_pssms: bool) -> PDBStructure: # read the pssm if load_pssms: - _check_pssm(self.pdb_path, self.pssm_paths, suppress = self.suppress_pssm_errors) - for chain in structure.chains: - chain: Chain - if chain.id in self.pssm_paths: - pssm_path = self.pssm_paths[chain.id] - - with open(pssm_path, "rt", encoding="utf-8") as f: - chain.pssm = parse_pssm(f, chain) + _load_ppi_pssms(self.pssm_paths, structure, self.pdb_path, self.suppress_pssm_errors) return structure @@ -191,6 +184,7 @@ def __repr__(self) -> str: return f"{type(self)}({self.get_query_id()})" def build(self, feature_modules: List[ModuleType]) -> Graph: + #TODO: convert feature_modules to list of not already the case raise NotImplementedError("Must be defined in child classes.") def get_query_id(self) -> str: raise NotImplementedError("Must be defined in child classes.") @@ -436,7 +430,7 @@ def build( feature_modules = [feature_modules] structure: PDBStructure = self._load_structure(load_pssms) - # find the variant residue + # find the variant residue and its surroundings variant_residue = None for residue in structure.get_chain(self.variant_chain_id).residues: residue: Residue @@ -456,7 +450,6 @@ def build( # build the graph if self.resolution == 'residue': graph = build_residue_graph(residues, self.get_query_id(), self.distance_cutoff) - elif self.resolution == 'atomic': residues.append(variant_residue) atoms = set([]) @@ -473,12 +466,10 @@ def build( else: raise NotImplementedError(f"No function exists to build graphs with resolution of {self.resolution}.") - graph.center = get_residue_center(variant_residue) # add data to the graph self._set_graph_targets(graph) - for feature_module in feature_modules: feature_module.add_features(self.pdb_path, graph, variant) @@ -487,19 +478,21 @@ def build( def _load_ppi_pssms( pssm_paths: Optional[Dict[str, str]], - chain_ids: List[str], structure: PDBStructure, pdb_path: str, suppress_error: bool, ): _check_pssm(pdb_path, pssm_paths, suppress_error, verbosity = 0) - for chain_id in chain_ids: - if chain_id in pssm_paths: - chain = structure.get_chain(chain_id) - pssm_path = pssm_paths[chain_id] + for chain in structure.chains: + chain: Chain + if chain.id in pssm_paths: + pssm_path = pssm_paths[chain.id] + with open(pssm_path, "rt", encoding="utf-8") as f: chain.pssm = parse_pssm(f, chain) + + @dataclass(kw_only=True) class ProteinProteinInterfaceQuery(DeepRankQuery): """A query that builds a protein-protein interface graph.""" @@ -525,7 +518,6 @@ def get_query_id(self) -> str: + f"{self.chain_ids[0]}-{self.chain_ids[1]}:{self.model_id}" ) - def build( self, feature_modules: List[ModuleType] | ModuleType, @@ -544,13 +536,16 @@ def build( if len(contact_atoms) == 0: raise ValueError("no contact atoms found") + # build the graph if self.resolution == 'atomic': graph = build_atomic_graph(contact_atoms, self.get_query_id(), self.distance_cutoff) - elif self.resolution == 'residue': residues_selected = {atom.residue for atom in contact_atoms} graph = build_residue_graph(list(residues_selected), self.get_query_id(), self.distance_cutoff) #TODO: check whether this works with a set instead of a list + else: + raise NotImplementedError(f"No function exists to build graphs with resolution of {self.resolution}.") + graph.center = np.mean([atom.position for atom in contact_atoms], axis=0) # add data to the graph self._set_graph_targets(graph) @@ -563,7 +558,6 @@ def build( feature_modules = [feature_modules] if conservation in feature_modules: _load_ppi_pssms(self.pssm_paths, - [self.chain_ids[0], self.chain_ids[1]], structure, self.pdb_path, suppress_error=self.suppress_pssm_errors)