diff --git a/README.md b/README.md index 9d8e1ce0f..1657a9c1c 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ DeepRank2 extensive documentation can be found [here](https://deeprank2.rtfd.io/ ## Installation -The package officially supports ubuntu-latest OS only, whose functioning is widely tested through the continuous integration workflows. +The package officially supports ubuntu-latest OS only, whose functioning is widely tested through the continuous integration workflows. ### Dependencies @@ -65,9 +65,9 @@ Before installing deeprank2 you need to install some dependencies. We advise to * [DSSP 4](https://swift.cmbi.umcn.nl/gv/dssp/) * Check if `dssp` is installed: `dssp --version`. If this gives an error or shows a version lower than 4: * on ubuntu 22.04 or newer: `sudo apt-get install dssp`. If the package cannot be located, first run `sudo apt-get update`. - * on older versions of ubuntu or on mac or lacking sudo priviliges: install from [here](https://github.com/pdb-redo/dssp), following the instructions listed. Alternatively, follow [this](https://github.com/PDB-REDO/libcifpp/issues/49) thread. + * on older versions of ubuntu or on mac or lacking sudo priviliges: install from [here](https://github.com/pdb-redo/dssp), following the instructions listed. Alternatively, follow [this](https://github.com/PDB-REDO/libcifpp/issues/49) thread. * [GCC](https://gcc.gnu.org/install/) - * Check if gcc is installed: `gcc --version`. If this gives an error, run `sudo apt-get install gcc`. + * Check if gcc is installed: `gcc --version`. If this gives an error, run `sudo apt-get install gcc`. * For MacOS with M1 chip users only install [the conda version of PyTables](https://www.pytables.org/usersguide/installation.html). ### Deeprank2 Package @@ -105,25 +105,24 @@ For more details, see the [extended documentation](https://deeprank2.rtfd.io/). 
### Data generation -For each protein-protein complex (or protein structure containing a SRV), a query can be created and added to the `QueryCollection` object, to be processed later on. Different types of queries exist: -- In a `ProteinProteinInterfaceResidueQuery` and `SingleResidueVariantResidueQuery`, each node represents one amino acid residue. -- In a `ProteinProteinInterfaceAtomicQuery` and `SingleResidueVariantAtomicQuery`, each node represents one atom within the amino acid residues. +For each protein-protein complex (or protein structure containing a missense variant), a `Query` can be created and added to the `QueryCollection` object, to be processed later on. Two subtypes of `Query` exist: `ProteinProteinInterfaceQuery` and `SingleResidueVariantQuery`. -A query takes as inputs: -- a `.pdb` file, representing the protein-protein structure +A `Query` takes as inputs: +- a `.pdb` file, representing the protein-protein structure, +- the resolution (`"residue"` or `"atom"`), i.e. whether each node should represent an amino acid residue or an atom, - the ids of the chains composing the structure, and - optionally, the correspondent position-specific scoring matrices (PSSMs), in the form of `.pssm` files. 
```python -from deeprank2.query import QueryCollection, ProteinProteinInterfaceResidueQuery +from deeprank2.query import QueryCollection, ProteinProteinInterfaceQuery queries = QueryCollection() # Append data points -queries.add(ProteinProteinInterfaceResidueQuery( +queries.add(ProteinProteinInterfaceQuery( pdb_path = "tests/data/pdb/1ATN/1ATN_1w.pdb", - chain_id1 = "A", - chain_id2 = "B", + resolution = "residue", + chain_ids = ["A", "B"], targets = { "binary": 0 }, @@ -132,10 +131,10 @@ queries.add(ProteinProteinInterfaceResidueQuery( "B": "tests/data/pssm/1ATN/1ATN.B.pdb.pssm" } )) -queries.add(ProteinProteinInterfaceResidueQuery( +queries.add(ProteinProteinInterfaceQuery( pdb_path = "tests/data/pdb/1ATN/1ATN_2w.pdb", - chain_id1 = "A", - chain_id2 = "B", + resolution = "residue", + chain_ids = ["A", "B"], targets = { "binary": 1 }, @@ -144,10 +143,10 @@ queries.add(ProteinProteinInterfaceResidueQuery( "B": "tests/data/pssm/1ATN/1ATN.B.pdb.pssm" } )) -queries.add(ProteinProteinInterfaceResidueQuery( +queries.add(ProteinProteinInterfaceQuery( pdb_path = "tests/data/pdb/1ATN/1ATN_3w.pdb", - chain_id1 = "A", - chain_id2 = "B", + resolution = "residue", + chain_ids = ["A", "B"], targets = { "binary": 0 }, diff --git a/deeprank2/dataset.py b/deeprank2/dataset.py index a915606b2..13c9ea804 100644 --- a/deeprank2/dataset.py +++ b/deeprank2/dataset.py @@ -346,14 +346,14 @@ def save_hist( # pylint: disable=too-many-arguments, too-many-branches, useless- for row, feat in enumerate(features_df): if isinstance(self.df[feat].values[0], np.ndarray): - if(log): + if log: log_data = np.log(np.concatenate(self.df[feat].values)) log_data[log_data == -np.inf] = 0 axs[row].hist(log_data, bins=bins) else: axs[row].hist(np.concatenate(self.df[feat].values), bins=bins) else: - if(log): + if log: log_data = np.log(self.df[feat].values) log_data[log_data == -np.inf] = 0 axs[row].hist(log_data, bins=bins) @@ -366,14 +366,14 @@ def save_hist( # pylint: disable=too-many-arguments, 
too-many-branches, useless- fig = plt.figure(figsize=figsize) ax = fig.add_subplot(111) if isinstance(self.df[features_df[0]].values[0], np.ndarray): - if(log): + if log: log_data = np.log(np.concatenate(self.df[features_df[0]].values)) log_data[log_data == -np.inf] = 0 ax.hist(log_data, bins=bins) else: ax.hist(np.concatenate(self.df[features_df[0]].values), bins=bins) else: - if(log): + if log: log_data = np.log(self.df[features_df[0]].values) log_data[log_data == -np.inf] = 0 ax.hist(log_data, bins=bins) diff --git a/deeprank2/domain/aminoacidlist.py b/deeprank2/domain/aminoacidlist.py index 8b7d7feda..b57831290 100644 --- a/deeprank2/domain/aminoacidlist.py +++ b/deeprank2/domain/aminoacidlist.py @@ -353,8 +353,6 @@ ] def convert_aa_nomenclature(aa: str, output_type: Optional[int] = None): - - # pylint: disable = raise-missing-from try: if len(aa) == 1: aa: AminoAcid = [entry for entry in amino_acids if entry.one_letter_code.lower() == aa.lower()][0] @@ -362,8 +360,8 @@ def convert_aa_nomenclature(aa: str, output_type: Optional[int] = None): aa: AminoAcid = [entry for entry in amino_acids if entry.three_letter_code.lower() == aa.lower()][0] else: aa: AminoAcid = [entry for entry in amino_acids if entry.name.lower() == aa.lower()][0] - except IndexError: - raise ValueError(f'{aa} is not a valid amino acid.') + except IndexError as e: + raise ValueError(f'{aa} is not a valid amino acid.') from e if not output_type: return aa.name diff --git a/deeprank2/features/exposure.py b/deeprank2/features/exposure.py index 8d6be16cf..07aed2fc9 100644 --- a/deeprank2/features/exposure.py +++ b/deeprank2/features/exposure.py @@ -52,25 +52,24 @@ def add_features( # pylint: disable=unused-argument signal.alarm(0) except TimeoutError as e: raise TimeoutError('Bio.PDB.ResidueDepth.get_surface timed out.') from e - else: - hse = HSExposureCA(bio_model) - - # These can only be calculated per residue, not per atom. - # So for atomic graphs, every atom gets its residue's value. 
- for node in graph.nodes: - if isinstance(node.id, Residue): - residue = node.id - elif isinstance(node.id, Atom): - atom = node.id - residue = atom.residue - else: - raise TypeError(f"Unexpected node type: {type(node.id)}") - - bio_residue = bio_model[residue.chain.id][residue.number] - node.features[Nfeat.RESDEPTH] = residue_depth(bio_residue, surface) - hse_key = (residue.chain.id, (" ", residue.number, space_if_none(residue.insertion_code))) - - if hse_key in hse: - node.features[Nfeat.HSE] = np.array(hse[hse_key], dtype=np.float64) - else: - node.features[Nfeat.HSE] = np.array((0, 0, 0), dtype=np.float64) + + # These can only be calculated per residue, not per atom. + # So for atomic graphs, every atom gets its residue's value. + hse = HSExposureCA(bio_model) + for node in graph.nodes: + if isinstance(node.id, Residue): + residue = node.id + elif isinstance(node.id, Atom): + atom = node.id + residue = atom.residue + else: + raise TypeError(f"Unexpected node type: {type(node.id)}") + + bio_residue = bio_model[residue.chain.id][residue.number] + node.features[Nfeat.RESDEPTH] = residue_depth(bio_residue, surface) + hse_key = (residue.chain.id, (" ", residue.number, space_if_none(residue.insertion_code))) + + if hse_key in hse: + node.features[Nfeat.HSE] = np.array(hse[hse_key], dtype=np.float64) + else: + node.features[Nfeat.HSE] = np.array((0, 0, 0), dtype=np.float64) diff --git a/deeprank2/query.py b/deeprank2/query.py index c97edc59c..2c55a7712 100644 --- a/deeprank2/query.py +++ b/deeprank2/query.py @@ -3,14 +3,15 @@ import os import pickle import pkgutil -import tempfile +import re import warnings +from dataclasses import MISSING, dataclass, field, fields from functools import partial from glob import glob from multiprocessing import Pool from random import randrange from types import ModuleType -from typing import Dict, Iterator, List, Optional, Union +from typing import Iterator, Literal import h5py import numpy as np @@ -20,11 +21,9 @@ from 
deeprank2.domain.aminoacidlist import convert_aa_nomenclature from deeprank2.features import components, conservation, contact from deeprank2.molstruct.aminoacid import AminoAcid -from deeprank2.molstruct.atom import Atom -from deeprank2.molstruct.residue import SingleResidueVariant +from deeprank2.molstruct.residue import Residue, SingleResidueVariant from deeprank2.molstruct.structure import PDBStructure -from deeprank2.utils.buildgraph import (add_hydrogens, get_contact_atoms, - get_structure, +from deeprank2.utils.buildgraph import (get_contact_atoms, get_structure, get_surrounding_residues) from deeprank2.utils.graph import (Graph, build_atomic_graph, build_residue_graph) @@ -33,870 +32,543 @@ _log = logging.getLogger(__name__) +VALID_RESOLUTIONS = ['atom', 'residue'] -def _check_pssm(pdb_path: str, pssm_paths: Dict[str, str], suppress: bool, verbosity: int = 0): - """Checks whether information stored in pssm file matches the corresponding pdb file. - Args: - pdb_path (str): Path to the PDB file. - pssm_paths (Dict[str, str]): The paths to the PSSM files, per chain identifier. - suppress (bool): Suppress errors and throw warnings instead. - verbosity (int): Level of verbosity of error/warning. Defaults to 0. 
- 0 (low): Only state file name where error occurred; - 1 (medium): Also state number of incorrect and missing residues; - 2 (high): Also list the incorrect residues - - Raises: - ValueError: Raised if info between pdb file and pssm file doesn't match or if no pssms were provided - """ - - if not pssm_paths: - raise ValueError('No pssm paths provided for conservation feature module.') - - pssm_data = {} - for chain in pssm_paths: - with open(pssm_paths[chain], encoding='utf-8') as f: - lines = f.readlines()[1:] - for line in lines: - pssm_data[chain + line.split()[0].zfill(4)] = convert_aa_nomenclature(line.split()[1], 3) - - # load ground truth from pdb file - pdb_truth = pdb2sql.pdb2sql(pdb_path).get_residues() - pdb_truth = {res[0] + str(res[2]).zfill(4): res[1] for res in pdb_truth if res[0] in pssm_paths} - - wrong_list = [] - missing_list = [] +@dataclass(repr=False, kw_only=True) +class Query: + """Represents one entity of interest: a single residue variant (SRV) or a protein-protein interface (PPI). - for residue in pdb_truth: - try: - if pdb_truth[residue] != pssm_data[residue]: - wrong_list.append(residue) - except KeyError: - missing_list.append(residue) - - if len(wrong_list) + len(missing_list) > 0: - error_message = f'Amino acids in PSSM files do not match pdb file for {os.path.split(pdb_path)[1]}.' - if verbosity: - if len(wrong_list) > 0: - error_message = error_message + f'\n\t{len(wrong_list)} entries are incorrect.' - if verbosity == 2: - error_message = error_message[-1] + f':\n\t{missing_list}' - if len(missing_list) > 0: - error_message = error_message + f'\n\t{len(missing_list)} entries are missing.' - if verbosity == 2: - error_message = error_message[-1] + f':\n\t{missing_list}' - - if not suppress: - raise ValueError(error_message) - - warnings.warn(error_message) - _log.warning(error_message) + :class:`Query` objects are used to generate graphs from structures, and they should be created before any model is loaded. 
+ They can have target values associated with them, which will be stored with the resulting graph. + Args: + pdb_path (str): the path to the PDB file to query. + resolution (Literal['residue', 'atom']): sets whether each node is a residue or atom. + chain_ids (list[str] | str): the chain identifier(s) of the variant residue or interacting interfaces. + Note that this does not limit the structure to residues from this/these chain(s). + pssm_paths (dict[str, str]): the name of the chain(s) (key) and path to the pssm file(s) (value). + distance_cutoff (float): the maximum distance between two nodes to generate an edge connecting them. + targets (dict[str, float]) = Name(s) (key) and target value(s) (value) associated with this query. + suppress_pssm_errors (bool): Whether or not to suppress the error raised if the .pssm files do not + match the .pdb files. If True, a warning is returned instead. + """ -class Query: + pdb_path: str + resolution: Literal['residue', 'atom'] + chain_ids: list[str] | str + pssm_paths: dict[str, str] = field(default_factory=dict) + distance_cutoff: float = None + targets: dict[str, float] = field(default_factory=dict) + suppress_pssm_errors: bool = False - def __init__(self, model_id: str, targets: Optional[Dict[str, Union[float, int]]] = None, suppress_pssm_errors: bool = False): - """Represents one entity of interest, like a single-residue variant or a protein-protein interface. + def __post_init__(self): + self._model_id = os.path.splitext(os.path.basename(self.pdb_path))[0] + self.variant = None # not used for PPI, overwritten for SRV - :class:`Query` objects are used to generate graphs from structures, and they should be created before any model is loaded. - They can have target values associated with them, which will be stored with the resulting graph. + if self.resolution not in VALID_RESOLUTIONS: + raise ValueError(f"Invalid resolution given ({self.resolution}). 
Must be one of {VALID_RESOLUTIONS}") - Args: - model_id (str): The ID of the model to load, usually a .PDB accession code. - targets (Optional[Dict[str, Union[float, int]]], optional): Target values associated with the query. Defaults to None. - suppress_pssm_errors (bool, optional): Suppress error raised if .pssm files do not match .pdb files and throw warning instead. - Defaults to False. - """ + if not isinstance(self.chain_ids, list): + self.chain_ids = [self.chain_ids] - self._model_id = model_id - self._suppress = suppress_pssm_errors + if not self.distance_cutoff: + if self.resolution == 'atom': + self.distance_cutoff = 5.5 + if self.resolution == 'residue': + self.distance_cutoff = 10 - if targets is None: - self._targets = {} - else: - self._targets = targets + # convert None to empty type (e.g. list, dict) for arguments where this is expected + for f in fields(self): + value = getattr(self, f.name) + if value is None and f.default_factory is not MISSING: + setattr(self, f.name, f.default_factory()) def _set_graph_targets(self, graph: Graph): - "Simply copies target data from query to graph." - - for target_name, target_data in self._targets.items(): + """Copy target data from query to graph.""" + for target_name, target_data in self.targets.items(): graph.targets[target_name] = target_data - def _load_structure( - self, pdb_path: str, pssm_paths: Optional[Dict[str, str]], - include_hydrogens: bool, - load_pssms: bool, - ): - "A helper function, to build the structure from .PDB and .PSSM files." 
- - # make a copy of the pdb, with hydrogens - pdb_name = os.path.basename(pdb_path) - hydrogen_pdb_file, hydrogen_pdb_path = tempfile.mkstemp( - prefix="hydrogenated-", suffix=pdb_name - ) - os.close(hydrogen_pdb_file) - - if include_hydrogens: - add_hydrogens(pdb_path, hydrogen_pdb_path) - - # read the .PDB copy - try: - pdb = pdb2sql.pdb2sql(hydrogen_pdb_path) - finally: - os.remove(hydrogen_pdb_path) - else: - pdb = pdb2sql.pdb2sql(pdb_path) - + def _load_structure(self) -> PDBStructure: + """Build PDBStructure objects from pdb and pssm data.""" + pdb = pdb2sql.pdb2sql(self.pdb_path) try: structure = get_structure(pdb, self.model_id) finally: pdb._close() # pylint: disable=protected-access - # read the pssm - if load_pssms: - _check_pssm(pdb_path, pssm_paths, suppress = self._suppress) - for chain in structure.chains: - if chain.id in pssm_paths: - pssm_path = pssm_paths[chain.id] - - with open(pssm_path, "rt", encoding="utf-8") as f: - chain.pssm = parse_pssm(f, chain) + if self._pssm_required: + self._load_pssm_data(structure) return structure + def _load_pssm_data(self, structure: PDBStructure): + self._check_pssm() + for chain in structure.chains: + if chain.id in self.pssm_paths: + pssm_path = self.pssm_paths[chain.id] + with open(pssm_path, "rt", encoding="utf-8") as f: + chain.pssm = parse_pssm(f, chain) + + def _check_pssm(self, verbosity: Literal[0,1,2] = 0): + """Checks whether information stored in pssm file matches the corresponding pdb file. + + Args: + pdb_path (str): Path to the PDB file. + pssm_paths (dict[str, str]): The paths to the PSSM files, per chain identifier. + suppress (bool): Suppress errors and throw warnings instead. + verbosity (int): Level of verbosity of error/warning. Defaults to 0. 
+ 0 (low): Only state file name where error occurred; + 1 (medium): Also state number of incorrect and missing residues; + 2 (high): Also list the incorrect residues + + Raises: + ValueError: Raised if info between pdb file and pssm file doesn't match or if no pssms were provided + """ + if not self.pssm_paths: + raise ValueError('No pssm paths provided for conservation feature module.') + + # load residues from pssm and pdb files + pssm_file_residues = {} + for chain, pssm_path in self.pssm_paths.items(): + with open(pssm_path, encoding='utf-8') as f: + lines = f.readlines()[1:] + for line in lines: + pssm_file_residues[chain + line.split()[0].zfill(4)] = convert_aa_nomenclature(line.split()[1], 3) + pdb_file_residues = {res[0] + str(res[2]).zfill(4): res[1] + for res in pdb2sql.pdb2sql(self.pdb_path).get_residues() + if res[0] in self.pssm_paths} + + # list errors + mismatches = [] + missing_entries = [] + for residue in pdb_file_residues: + try: + if pdb_file_residues[residue] != pssm_file_residues[residue]: + mismatches.append(residue) + except KeyError: + missing_entries.append(residue) + + # generate error message + if len(mismatches) + len(missing_entries) > 0: + error_message = f'Amino acids in PSSM files do not match pdb file for {os.path.split(self.pdb_path)[1]}.' + if verbosity: + if len(mismatches) > 0: + error_message = error_message + f'\n\t{len(mismatches)} entries are incorrect.' + if verbosity == 2: + error_message = error_message[:-1] + f':\n\t{mismatches}' + if len(missing_entries) > 0: + error_message = error_message + f'\n\t{len(missing_entries)} entries are missing.' + if verbosity == 2: + error_message = error_message[:-1] + f':\n\t{missing_entries}' + + # raise exception (or warning) + if not self.suppress_pssm_errors: + raise ValueError(error_message) + warnings.warn(error_message) + _log.warning(error_message) + @property def model_id(self) -> str: - "The ID of the model, usually a .PDB accession code." 
+ """The ID of the model, usually a .PDB accession code.""" return self._model_id - @model_id.setter - def model_id(self, value): + def model_id(self, value: str): self._model_id = value - @property - def targets(self) -> Dict[str, float]: - "The target values associated with the query." - return self._targets - def __repr__(self) -> str: return f"{type(self)}({self.get_query_id()})" - def build(self, feature_modules: List[ModuleType], include_hydrogens: bool = False) -> Graph: - raise NotImplementedError("Must be defined in child classes.") - def get_query_id(self) -> str: - raise NotImplementedError("Must be defined in child classes.") - - -class QueryCollection: - """ - Represents the collection of data queries. - Queries can be saved as a dictionary to easily navigate through their data. - - """ - - def __init__(self): - - self._queries = [] - self.cpu_count = None - self.ids_count = {} - - def add(self, query: Query, verbose: bool = False, warn_duplicate: bool = True): - """ - Adds a new query to the collection. - - Args: - query(:class:`Query`): Must be a :class:`Query` object, either :class:`ProteinProteinInterfaceResidueQuery` or - :class:`SingleResidueVariantAtomicQuery`. - verbose(bool, optional): For logging query IDs added, defaults to False. - warn_duplicate (bool): Log a warning before renaming if a duplicate query is identified. - - """ - query_id = query.get_query_id() - - if verbose: - _log.info(f'Adding query with ID {query_id}.') - - if query_id not in self.ids_count: - self.ids_count[query_id] = 1 - else: - self.ids_count[query_id] += 1 - new_id = query.model_id + "_" + str(self.ids_count[query_id]) - query.model_id = new_id - - if warn_duplicate: - _log.warning(f'Query with ID {query_id} has already been added to the collection. Renaming it as {query.get_query_id()}') - - self._queries.append(query) - - def export_dict(self, dataset_path: str): - """Exports the colection of all queries to a dictionary file. 
- - Args: - dataset_path (str): The path where to save the list of queries. - """ - with open(dataset_path, "wb") as pkl_file: - pickle.dump(self, pkl_file) - - @property - def queries(self) -> List[Query]: - "The list of queries added to the collection." - return self._queries - - def __contains__(self, query: Query) -> bool: - return query in self._queries - - def __iter__(self) -> Iterator[Query]: - return iter(self._queries) - - def __len__(self) -> int: - return len(self._queries) - - def _process_one_query( # pylint: disable=too-many-arguments + def build( self, - prefix: str, - feature_names: List[str], - grid_settings: Optional[GridSettings], - grid_map_method: Optional[MapMethod], - grid_augmentation_count: int, - query: Query - ): - - try: - # because only one process may access an hdf5 file at a time: - output_path = f"{prefix}-{os.getpid()}.hdf5" - - feature_modules = [ - importlib.import_module('deeprank2.features.' + name) for name in feature_names] - - graph = query.build(feature_modules) - graph.write_to_hdf5(output_path) - - if grid_settings is not None and grid_map_method is not None: - graph.write_as_grid_to_hdf5(output_path, grid_settings, grid_map_method) - - for _ in range(grid_augmentation_count): - # repeat with random augmentation - axis, angle = pdb2sql.transform.get_rot_axis_angle(randrange(100)) - augmentation = Augmentation(axis, angle) - graph.write_as_grid_to_hdf5(output_path, grid_settings, grid_map_method, augmentation) - - return None - - except (ValueError, AttributeError, KeyError, TimeoutError) as e: - _log.warning(f'\nGraph/Query with ID {query.get_query_id()} ran into an Exception ({e.__class__.__name__}: {e}),' - ' and it has not been written to the hdf5 file. More details below:') - _log.exception(e) - return None + feature_modules: list[str | ModuleType], + ) -> Graph: + """Builds the graph from the .PDB structure. 
- def process( # pylint: disable=too-many-arguments, too-many-locals, dangerous-default-value - self, - prefix: Optional[str] = None, - feature_modules: Union[ModuleType, List[ModuleType], str, List[str]] = [components, contact], - cpu_count: Optional[int] = None, - combine_output: bool = True, - grid_settings: Optional[GridSettings] = None, - grid_map_method: Optional[MapMethod] = None, - grid_augmentation_count: int = 0 - ) -> List[str]: - """ Args: - prefix (Optional[str], optional): Prefix for the output files. Defaults to None, which sets ./processed-queries- prefix. - feature_modules (Union[ModuleType, List[ModuleType], str, List[str]], optional): Features' module or list of features' modules - used to generate features (given as string or as an imported module). Each module must implement the :py:func:`add_features` function, - and features' modules can be found (or should be placed in case of a custom made feature) in `deeprank2.features` folder. - If set to 'all', all available modules in `deeprank2.features` are used to generate the features. - Defaults to only the basic feature modules `deeprank2.features.components` and `deeprank2.features.contact`. - cpu_count (Optional[int], optional): How many processes to be run simultaneously. Defaults to None, which takes all available cpu cores. - combine_output (bool, optional): For combining the HDF5 files generated by the processes. Defaults to True. - grid_settings (Optional[:class:`GridSettings`], optional): If valid together with `grid_map_method`, the grid data will be stored as well. - Defaults to None. - grid_map_method (Optional[:class:`MapMethod`], optional): If valid together with `grid_settings`, the grid data will be stored as well. - Defaults to None. - grid_augmentation_count (int, optional): Number of grid data augmentations. May not be negative be zero or a positive number. - Defaults to 0. + feature_modules (list[str]): the feature modules used to build the graph. 
+ These must be filenames existing inside `deeprank2.features` subpackage. Returns: - List[str]: The list of paths of the generated HDF5 files. + :class:`Graph`: The resulting :class:`Graph` object with all the features and targets. """ - # set defaults - if prefix is None: - prefix = "processed-queries" - elif prefix.endswith('.hdf5'): - prefix = prefix[:-5] - if cpu_count is None: - cpu_count = os.cpu_count() # returns the number of CPUs in the system - else: - cpu_count_system = os.cpu_count() - if cpu_count > cpu_count_system: - _log.warning(f'\nTried to set {cpu_count} CPUs, but only {cpu_count_system} are present in the system.') - cpu_count = cpu_count_system - self.cpu_count = cpu_count - _log.info(f'\nNumber of CPUs for processing the queries set to: {self.cpu_count}.') - - - if feature_modules == 'all': - feature_names = [modname for _, modname, _ in pkgutil.iter_modules(deeprank2.features.__path__)] - elif isinstance(feature_modules, list): - feature_names = [os.path.basename(m.__file__)[:-3] if isinstance(m,ModuleType) - else m.replace('.py','') for m in feature_modules] - elif isinstance(feature_modules, ModuleType): - feature_names = [os.path.basename(feature_modules.__file__)[:-3]] - elif isinstance(feature_modules, str): - feature_names = [feature_modules.replace('.py','')] - else: - raise ValueError(f'Feature_modules has received an invalid input type: {type(feature_modules)}.') - _log.info(f'\nSelected feature modules: {feature_names}.') - - _log.info(f'Creating pool function to process {len(self.queries)} queries...') - pool_function = partial(self._process_one_query, prefix, - feature_names, - grid_settings, grid_map_method, grid_augmentation_count) - - with Pool(self.cpu_count) as pool: - _log.info('Starting pooling...\n') - pool.map(pool_function, self.queries) - - output_paths = glob(f"{prefix}-*.hdf5") - - if combine_output: - for output_path in output_paths: - with h5py.File(f"{prefix}.hdf5",'a') as f_dest, h5py.File(output_path,'r') as 
f_src: - for key, value in f_src.items(): - _log.debug(f"copy {key} from {output_path} to {prefix}.hdf5") - f_src.copy(value, f_dest) - os.remove(output_path) - return glob(f"{prefix}.hdf5") - - return output_paths - + if not isinstance(feature_modules, list): + feature_modules = [feature_modules] + feature_modules = [importlib.import_module('deeprank2.features.' + module) + if isinstance(module, str) else module + for module in feature_modules] + self._pssm_required = conservation in feature_modules + graph = self._build_helper() -class SingleResidueVariantResidueQuery(Query): + # add target and feature data to the graph + self._set_graph_targets(graph) + for feature_module in feature_modules: + feature_module.add_features(self.pdb_path, graph, self.variant) - def __init__( # pylint: disable=too-many-arguments - self, - pdb_path: str, - chain_id: str, - residue_number: int, - insertion_code: str, - wildtype_amino_acid: AminoAcid, - variant_amino_acid: AminoAcid, - pssm_paths: Optional[Dict[str, str]] = None, - radius: float = 10.0, - distance_cutoff: Optional[float] = 4.5, - targets: Optional[Dict[str, float]] = None, - suppress_pssm_errors: bool = False, - ): - """ - Creates a residue graph from a single-residue variant in a .PDB file. + return graph - Args: - pdb_path (str): The path to the PDB file. - chain_id (str): The .PDB chain identifier of the variant residue. - residue_number (int): The number of the variant residue. - insertion_code (str): The insertion code of the variant residue, set to None if not applicable. - wildtype_amino_acid (:class:`AminoAcid`): The wildtype amino acid. - variant_amino_acid (:class:`AminoAcid`): The variant amino acid. - pssm_paths (Optional[Dict(str,str)], optional): The paths to the PSSM files, per chain identifier. Defaults to None. - radius (float, optional): In Ångström, determines how many residues will be included in the graph. Defaults to 10.0. 
- distance_cutoff (Optional[float], optional): Max distance in Ångström between a pair of atoms to consider them as an external edge in the graph. - Defaults to 4.5. - targets (Optional[Dict(str,float)], optional): Named target values associated with this query. Defaults to None. - suppress_pssm_errors (bool, optional): Suppress error raised if .pssm files do not match .pdb files and throw warning instead. - Defaults to False. - """ + def _build_helper(self) -> Graph: + raise NotImplementedError("Must be defined in child classes.") + def get_query_id(self) -> str: + raise NotImplementedError("Must be defined in child classes.") - self._pdb_path = pdb_path - self._pssm_paths = pssm_paths - model_id = os.path.splitext(os.path.basename(pdb_path))[0] +@dataclass(kw_only=True) +class SingleResidueVariantQuery(Query): + """A query that builds a single residue variant graph. + + Args (common for `Query`): + pdb_path (str): the path to the PDB file to query. + resolution (Literal['residue', 'atom']): sets whether each node is a residue or atom. + chain_ids (list[str] | str): the chain identifier of the variant residue (generally a single capital letter). + Note that this does not limit the structure to residues from this chain. + pssm_paths (dict[str, str]): the name of the chain(s) (key) and path to the pssm file(s) (value). + distance_cutoff (float): the maximum distance between two nodes to generate an edge connecting them. + targets (dict[str, float]) = Name(s) (key) and target value(s) (value) associated with this query. + suppress_pssm_errors (bool): Whether or not to suppress the error raised if the .pssm files do not + match the .pdb files. If True, a warning is returned instead. + SRV specific Args: + variant_residue_number (int): the residue number of the variant residue. + insertion_code (str | None): the insertion code of the variant residue. + wildtype_amino_acid (AminoAcid): the amino acid at above position in the wildtype protein. 
+ variant_amino_acid (AminoAcid): the amino acid at above position in the variant protein. + radius (float): all Residues within this radius (in Å) from the variant residue will + be included in the graph + """ - Query.__init__(self, model_id, targets, suppress_pssm_errors) + variant_residue_number: int + insertion_code: str | None + wildtype_amino_acid: AminoAcid + variant_amino_acid: AminoAcid + radius: float = 10.0 - self._chain_id = chain_id - self._residue_number = residue_number - self._insertion_code = insertion_code - self._wildtype_amino_acid = wildtype_amino_acid - self._variant_amino_acid = variant_amino_acid + def __post_init__(self): + super().__post_init__() # calls __post_init__ of parents - self._radius = radius - self._distance_cutoff = distance_cutoff + if len(self.chain_ids) != 1: + raise ValueError("`chain_ids` must contain exactly 1 chain for `SingleResidueVariantQuery` objects, " + + f"but {len(self.chain_ids)} were given.") + self.variant_chain_id = self.chain_ids[0] @property def residue_id(self) -> str: - "String representation of the residue number and insertion code." - - if self._insertion_code is not None: - - return f"{self._residue_number}{self._insertion_code}" - - return str(self._residue_number) + """String representation of the residue number and insertion code.""" + if self.insertion_code is not None: + return f"{self.variant_residue_number}{self.insertion_code}" + return str(self.variant_residue_number) def get_query_id(self) -> str: - "Returns the string representing the complete query ID." 
- return f"residue-graph:{self._chain_id}:{self.residue_id}:{self._wildtype_amino_acid.name}->{self._variant_amino_acid.name}:{self.model_id}" + """Returns the string representing the complete query ID.""" + return (f"{self.resolution}-srv:" + + f"{self.variant_chain_id}:{self.residue_id}:" + + f"{self.wildtype_amino_acid.name}->{self.variant_amino_acid.name}:{self.model_id}" + ) - def build(self, feature_modules: List[ModuleType], include_hydrogens: bool = False) -> Graph: - """Builds the graph from the .PDB structure. - - Args: - feature_modules (List[ModuleType]): Each must implement the :py:func:`add_features` function. - include_hydrogens (bool, optional): Whether to include hydrogens in the :class:`Graph`. Defaults to False. + def _build_helper(self) -> Graph: + """Helper function to build a graph for SRV queries. Returns: :class:`Graph`: The resulting :class:`Graph` object with all the features and targets. """ # load .PDB structure - if isinstance(feature_modules, List): - load_pssms = conservation in feature_modules - else: - load_pssms = conservation == feature_modules - structure = self._load_structure(self._pdb_path, self._pssm_paths, include_hydrogens, load_pssms) + structure = self._load_structure() - # find the variant residue - variant_residue = None - for residue in structure.get_chain(self._chain_id).residues: + # find the variant residue and its surroundings + variant_residue: Residue = None + for residue in structure.get_chain(self.variant_chain_id).residues: if ( - residue.number == self._residue_number - and residue.insertion_code == self._insertion_code + residue.number == self.variant_residue_number + and residue.insertion_code == self.insertion_code ): variant_residue = residue break - if variant_residue is None: raise ValueError( - f"Residue not found in {self._pdb_path}: {self._chain_id} {self.residue_id}" + f"Residue not found in {self.pdb_path}: {self.variant_chain_id} {self.residue_id}" ) - - # define the variant - variant = 
SingleResidueVariant(variant_residue, self._variant_amino_acid) - - # select which residues will be the graph - residues = list(get_surrounding_residues(structure, residue, self._radius)) # pylint: disable=undefined-loop-variable + self.variant = SingleResidueVariant(variant_residue, self.variant_amino_acid) + residues = get_surrounding_residues(structure, variant_residue, self.radius) # build the graph - graph = build_residue_graph( - residues, self.get_query_id(), self._distance_cutoff - ) - - # add data to the graph - self._set_graph_targets(graph) - - for feature_module in feature_modules: - feature_module.add_features(self._pdb_path, graph, variant) + if self.resolution == 'residue': + graph = build_residue_graph(residues, self.get_query_id(), self.distance_cutoff) + elif self.resolution == 'atom': + residues.append(variant_residue) + atoms = set([]) + for residue in residues: + if residue.amino_acid is not None: + for atom in residue.atoms: + atoms.add(atom) + atoms = list(atoms) + + graph = build_atomic_graph(atoms, self.get_query_id(), self.distance_cutoff) + else: + raise NotImplementedError(f"No function exists to build graphs with resolution of {self.resolution}.") graph.center = variant_residue.get_center() - return graph - - -class SingleResidueVariantAtomicQuery(Query): - def __init__( # pylint: disable=too-many-arguments - self, - pdb_path: str, - chain_id: str, - residue_number: int, - insertion_code: str, - wildtype_amino_acid: AminoAcid, - variant_amino_acid: AminoAcid, - pssm_paths: Optional[Dict[str, str]] = None, - radius: float = 10.0, - distance_cutoff: Optional[float] = 4.5, - targets: Optional[Dict[str, float]] = None, - suppress_pssm_errors: bool = False, - ): - """ - Creates an atomic graph for a single-residue variant in a .PDB file. - - Args: - pdb_path (str): The path to the .PDB file. - chain_id (str): The .PDB chain identifier of the variant residue. - residue_number (int): The number of the variant residue. 
- insertion_code (str): The insertion code of the variant residue, set to None if not applicable. - wildtype_amino_acid (:class:`AminoAcid`): The wildtype amino acid. - variant_amino_acid (:class:`AminoAcid`): The variant amino acid. - pssm_paths (Optional[Dict(str,str)], optional): The paths to the .PSSM files, per chain identifier. Defaults to None. - radius (float, optional): In Ångström, determines how many residues will be included in the graph. Defaults to 10.0. - distance_cutoff (Optional[float], optional): Max distance in Ångström between a pair of atoms to consider them as an external edge in the graph. - Defaults to 4.5. - targets (Optional[Dict(str,float)], optional): Named target values associated with this query. Defaults to None. - suppress_pssm_errors (bool, optional): Suppress error raised if .pssm files do not match .pdb files and throw warning instead. - Defaults to False. - """ - - self._pdb_path = pdb_path - self._pssm_paths = pssm_paths - - model_id = os.path.splitext(os.path.basename(pdb_path))[0] - - Query.__init__(self, model_id, targets, suppress_pssm_errors) - - self._chain_id = chain_id - self._residue_number = residue_number - self._insertion_code = insertion_code - self._wildtype_amino_acid = wildtype_amino_acid - self._variant_amino_acid = variant_amino_acid + return graph - self._radius = radius - self._distance_cutoff = distance_cutoff +@dataclass(kw_only=True) +class ProteinProteinInterfaceQuery(Query): + """A query that builds a protein-protein interface graph. - @property - def residue_id(self) -> str: - "String representation of the residue number and insertion code." + Args: + pdb_path (str): the path to the PDB file to query. + resolution (Literal['residue', 'atom']): sets whether each node is a residue or atom. + chain_ids (list[str] | str): the chain identifiers of the interacting interfaces (generally a single capital letter each). + Note that this does not limit the structure to residues from these chains. 
+ pssm_paths (dict[str, str]): the name of the chain(s) (key) and path to the pssm file(s) (value). + distance_cutoff (float): the maximum distance between two nodes to generate an edge connecting them. + targets (dict[str, float]) = Name(s) (key) and target value(s) (value) associated with this query. + suppress_pssm_errors (bool): Whether or not to suppress the error raised if the .pssm files do not + match the .pdb files. If True, a warning is returned instead. + """ - if self._insertion_code is not None: - return f"{self._residue_number}{self._insertion_code}" + def __post_init__(self): + super().__post_init__() - return str(self._residue_number) + if len(self.chain_ids) != 2: + raise ValueError("`chain_ids` must contain exactly 2 chains for `ProteinProteinInterfaceQuery` objects, " + + f"but {len(self.chain_ids)} was/were given.") def get_query_id(self) -> str: - "Returns the string representing the complete query ID." - return f"atomic-graph:{self._chain_id}:{self.residue_id}:{self._wildtype_amino_acid.name}->{self._variant_amino_acid.name}:{self.model_id}" - - def __eq__(self, other) -> bool: + """Returns the string representing the complete query ID.""" return ( - isinstance(self, type(other)) - and self.model_id == other.model_id - and self._chain_id == other._chain_id - and self.residue_id == other.residue_id - and self._wildtype_amino_acid == other._wildtype_amino_acid - and self._variant_amino_acid == other._variant_amino_acid + f"{self.resolution}-ppi:" # resolution and query type (ppi for protein protein interface) + + f"{self.chain_ids[0]}-{self.chain_ids[1]}:{self.model_id}" ) - def __hash__(self) -> hash: - return hash( - ( - self.model_id, - self._chain_id, - self.residue_id, - self._wildtype_amino_acid, - self._variant_amino_acid, - ) - ) - - @staticmethod - def _get_atom_node_key(atom) -> str: - """ - Since pickle has problems serializing the graph when the nodes are atoms, - this function can be used to generate a unique key for the atom. 
- """ - - # This should include the model, chain, residue and atom - return str(atom) - - def build(self, feature_modules: Union[ModuleType, List[ModuleType]], include_hydrogens: bool = False) -> Graph: - """Builds the graph from the .PDB structure. - - Args: - feature_modules (Union[ModuleType, List[ModuleType]]): Each must implement the :py:func:`add_features` function. - include_hydrogens (bool, optional): Whether to include hydrogens in the :class:`Graph`. Defaults to False. + def _build_helper(self) -> Graph: + """Helper function to build a graph for PPI queries. Returns: :class:`Graph`: The resulting :class:`Graph` object with all the features and targets. """ - # load .PDB structure - if isinstance(feature_modules, List): - load_pssms = conservation in feature_modules - else: - load_pssms = conservation == feature_modules - feature_modules = [feature_modules] - structure = self._load_structure(self._pdb_path, self._pssm_paths, include_hydrogens, load_pssms) - - # find the variant residue - variant_residue = None - for residue in structure.get_chain(self._chain_id).residues: - if ( - residue.number == self._residue_number - and residue.insertion_code == self._insertion_code - ): - variant_residue = residue - break - - if variant_residue is None: - raise ValueError( - f"Residue not found in {self._pdb_path}: {self._chain_id} {self.residue_id}" - ) - - # define the variant - variant = SingleResidueVariant(variant_residue, self._variant_amino_acid) - - # get the residues and atoms involved - residues = get_surrounding_residues(structure, variant_residue, self._radius) - residues.add(variant_residue) - atoms = set([]) - for residue in residues: - if residue.amino_acid is not None: - for atom in residue.atoms: - atoms.add(atom) - atoms = list(atoms) - - # build the graph - graph = build_atomic_graph( - atoms, self.get_query_id(), self._distance_cutoff + # find the atoms near the contact interface + contact_atoms = get_contact_atoms( + self.pdb_path, + 
self.chain_ids, + self.distance_cutoff ) + if len(contact_atoms) == 0: + raise ValueError("no contact atoms found") - # add data to the graph - self._set_graph_targets(graph) + # build the graph + if self.resolution == 'atom': + graph = build_atomic_graph(contact_atoms, self.get_query_id(), self.distance_cutoff) + elif self.resolution == 'residue': + residues_selected = list({atom.residue for atom in contact_atoms}) + graph = build_residue_graph(residues_selected, self.get_query_id(), self.distance_cutoff) + else: + raise NotImplementedError(f"No function exists to build graphs with resolution of {self.resolution}.") + graph.center = np.mean([atom.position for atom in contact_atoms], axis=0) - for feature_module in feature_modules: - feature_module.add_features(self._pdb_path, graph, variant) + structure = contact_atoms[0].residue.chain.model + if self._pssm_required: + self._load_pssm_data(structure) - graph.center = variant_residue.get_center() return graph -def _load_ppi_atoms(pdb_path: str, - chain_id1: str, chain_id2: str, - distance_cutoff: float, - include_hydrogens: bool) -> List[Atom]: - - # get the contact atoms - if include_hydrogens: - - pdb_name = os.path.basename(pdb_path) - hydrogen_pdb_file, hydrogen_pdb_path = tempfile.mkstemp( - prefix="hydrogenated-", suffix=pdb_name - ) - os.close(hydrogen_pdb_file) - - add_hydrogens(pdb_path, hydrogen_pdb_path) - - try: - contact_atoms = get_contact_atoms(hydrogen_pdb_path, - chain_id1, chain_id2, - distance_cutoff) - finally: - os.remove(hydrogen_pdb_path) - else: - contact_atoms = get_contact_atoms(pdb_path, - chain_id1, chain_id2, - distance_cutoff) - - if len(contact_atoms) == 0: - raise ValueError("no contact atoms found") - - return contact_atoms - - -def _load_ppi_pssms(pssm_paths: Optional[Dict[str, str]], - chains: List[str], - structure: PDBStructure, - pdb_path, - suppress_error): - - _check_pssm(pdb_path, pssm_paths, suppress_error, verbosity = 0) - for chain_id in chains: - if chain_id in 
pssm_paths: - - chain = structure.get_chain(chain_id) - - pssm_path = pssm_paths[chain_id] +class QueryCollection: + """Represents the collection of data queries that will be processed. - with open(pssm_path, "rt", encoding="utf-8") as f: - chain.pssm = parse_pssm(f, chain) + The class attributes are set either while adding queries to the collection (`_queries` + and `_ids_count`), or when processing the collection (other attributes). + Attributes: + _queries (list[:class:`Query`]): The `Query` objects in the collection. + _ids_count (dict[str, int]): The original `query_id` and the repeat number for this id. + This is used to rename the `query_id` to ensure that there are no duplicate ids. + _prefix, _cpu_count, _grid_settings, etc.: See docstring for `QueryCollection.process`. -class ProteinProteinInterfaceAtomicQuery(Query): + Notes: + Queries can be saved as a dictionary to easily navigate through their data, + using `QueryCollection.export_dict()`. + """ - def __init__( # pylint: disable=too-many-arguments + def __init__(self): + self._queries: list[Query] = [] + self._ids_count: dict[str, int] = {} + self._prefix: str | None = None + self._cpu_count: int | None = None + self._grid_settings: GridSettings | None = None + self._grid_map_method: MapMethod | None = None + self._grid_augmentation_count: int = 0 + + def add( self, - pdb_path: str, - chain_id1: str, - chain_id2: str, - pssm_paths: Optional[Dict[str, str]] = None, - distance_cutoff: Optional[float] = 5.5, - targets: Optional[Dict[str, float]] = None, - suppress_pssm_errors: bool = False, + query: Query, + verbose: bool = False, + warn_duplicate: bool = True, ): - """ - A query that builds atom-based graphs, using the residues at a protein-protein interface. + """Add a new query to the collection. Args: - pdb_path (str): The path to the .PDB file. - chain_id1 (str): The .PDB chain identifier of the first protein of interest. 
- chain_id2 (str): The .PDB chain identifier of the second protein of interest. - pssm_paths (Optional[Dict(str,str)], optional): The paths to the .PSSM files, per chain identifier. Defaults to None. - distance_cutoff (Optional[float], optional): Max distance in Ångström between two interacting atoms of the two proteins. - Defaults to 5.5. - targets (Optional[Dict(str,float)], optional): Named target values associated with this query. Defaults to None. - suppress_pssm_errors (bool, optional): Suppress error raised if .pssm files do not match .pdb files and throw warning instead. - Defaults to False. + query(:class:`Query`): The `Query` to add to the collection. + verbose(bool): For logging query IDs added. Defaults to `False`. + warn_duplicate (bool): Log a warning before renaming if a duplicate query is identified. Defaults to `True`. """ - model_id = os.path.splitext(os.path.basename(pdb_path))[0] - - Query.__init__(self, model_id, targets, suppress_pssm_errors) - - self._pdb_path = pdb_path - - self._chain_id1 = chain_id1 - self._chain_id2 = chain_id2 - - self._pssm_paths = pssm_paths - - self._distance_cutoff = distance_cutoff - - def get_query_id(self) -> str: - "Returns the string representing the complete query ID." - return f"atom-ppi:{self._chain_id1}-{self._chain_id2}:{self.model_id}" + query_id = query.get_query_id() + if verbose: + _log.info(f'Adding query with ID {query_id}.') - def __eq__(self, other) -> bool: - return ( - isinstance(self, type(other)) - and self.model_id == other.model_id - and {self._chain_id1, self._chain_id2} - == {other._chain_id1, other._chain_id2} - ) + if query_id not in self._ids_count: + self._ids_count[query_id] = 1 + else: + self._ids_count[query_id] += 1 + new_id = query.model_id + "_" + str(self._ids_count[query_id]) + query.model_id = new_id + if warn_duplicate: + _log.warning(f'Query with ID {query_id} has already been added to the collection. 
Renaming it as {query.get_query_id()}') - def __hash__(self) -> hash: - return hash((self.model_id, tuple(sorted([self._chain_id1, self._chain_id2])))) + self._queries.append(query) - def build(self, feature_modules: List[ModuleType], include_hydrogens: bool = False) -> Graph: - """Builds the graph from the .PDB structure. + def export_dict(self, dataset_path: str): + """Exports the colection of all queries to a dictionary file. Args: - feature_modules (List[ModuleType]): Each must implement the :py:func:`add_features` function. - include_hydrogens (bool, optional): Whether to include hydrogens in the :class:`Graph`. Defaults to False. - - Returns: - :class:`Graph`: The resulting :class:`Graph` object with all the features and targets. + dataset_path (str): The path where to save the list of queries. """ - contact_atoms = _load_ppi_atoms(self._pdb_path, - self._chain_id1, self._chain_id2, - self._distance_cutoff, - include_hydrogens) + with open(dataset_path, "wb") as pkl_file: + pickle.dump(self, pkl_file) - # build the graph - graph = build_atomic_graph( - contact_atoms, self.get_query_id(), self._distance_cutoff - ) + @property + def queries(self) -> list[Query]: + """The list of queries added to the collection.""" + return self._queries - # add data to the graph - self._set_graph_targets(graph) + def __contains__(self, query: Query) -> bool: + return query in self._queries - # read the pssm - structure = contact_atoms[0].residue.chain.model + def __iter__(self) -> Iterator[Query]: + return iter(self._queries) - if not isinstance(feature_modules, List): - feature_modules = [feature_modules] - if conservation in feature_modules: - _load_ppi_pssms(self._pssm_paths, - [self._chain_id1, self._chain_id2], - structure, self._pdb_path, - suppress_error=self._suppress) + def __len__(self) -> int: + return len(self._queries) - # add the features - for feature_module in feature_modules: - feature_module.add_features(self._pdb_path, graph) + def _process_one_query(self, 
query: Query): + """Only one process may access an hdf5 file at a time""" - graph.center = np.mean([atom.position for atom in contact_atoms], axis=0) - return graph + try: + output_path = f"{self._prefix}-{os.getpid()}.hdf5" + graph = query.build(self._feature_modules) + graph.write_to_hdf5(output_path) + if self._grid_settings is not None and self._grid_map_method is not None: + graph.write_as_grid_to_hdf5(output_path, self._grid_settings, self._grid_map_method) + for _ in range(self._grid_augmentation_count): + # repeat with random augmentation + axis, angle = pdb2sql.transform.get_rot_axis_angle(randrange(100)) + augmentation = Augmentation(axis, angle) + graph.write_as_grid_to_hdf5(output_path, self._grid_settings, self._grid_map_method, augmentation) -class ProteinProteinInterfaceResidueQuery(Query): + except (ValueError, AttributeError, KeyError, TimeoutError) as e: + _log.warning(f'\nGraph/Query with ID {query.get_query_id()} ran into an Exception ({e.__class__.__name__}: {e}),' + ' and it has not been written to the hdf5 file. More details below:') + _log.exception(e) - def __init__( # pylint: disable=too-many-arguments + def process( # pylint: disable=too-many-arguments, too-many-locals, dangerous-default-value self, - pdb_path: str, - chain_id1: str, - chain_id2: str, - pssm_paths: Optional[Dict[str, str]] = None, - distance_cutoff: Optional[float] = 10, - targets: Optional[Dict[str, float]] = None, - suppress_pssm_errors: bool = False, - ): - """ - A query that builds residue-based graphs, using the residues at a protein-protein interface. + prefix: str = "processed-queries", + feature_modules: list[ModuleType, str] | ModuleType | str | Literal['all'] = [components, contact], + cpu_count: int | None = None, + combine_output: bool = True, + grid_settings: GridSettings | None = None, + grid_map_method: MapMethod | None = None, + grid_augmentation_count: int = 0 + ) -> list[str]: + """Render queries into graphs (and optionally grids). 
Args: - pdb_path (str): The path to the .PDB file. - chain_id1 (str): The .PDB chain identifier of the first protein of interest. - chain_id2 (str): The .PDB chain identifier of the second protein of interest. - pssm_paths (Optional[Dict(str,str)], optional): The paths to the .PSSM files, per chain identifier. Defaults to None. - distance_cutoff (Optional[float], optional): Max distance in Ångström between two interacting residues of the two proteins. - Defaults to 10. - targets (Optional[Dict(str,float)], optional): Named target values associated with this query. Defaults to None. - suppress_pssm_errors (bool, optional): Suppress error raised if .pssm files do not match .pdb files and throw warning instead. - Defaults to False. - """ - - model_id = os.path.splitext(os.path.basename(pdb_path))[0] + prefix (str | None, optional): Prefix for naming the output files. Defaults to "processed-queries". + feature_modules (list[ModuleType] | list[str] | Literal ['all'], optional): Feature module or list of feature modules + used to generate features (given as string or as an imported module). + Each module must implement the :py:func:`add_features` function, and all feature modules must exist inside `deeprank2.features` folder. + If set to 'all', all available modules in `deeprank2.features` are used to generate the features. + Defaults to the two primary feature modules `deeprank2.features.components` and `deeprank2.features.contact`. + cpu_count (int | None, optional): The number of processes to be run in parallel (i.e. number of CPUs used), capped by + the number of CPUs available to the system. + Defaults to None, which takes all available cpu cores. + combine_output (bool, optional): + If `True` (default): all processes are combined into a single HDF5 file. + If `False`: separate HDF5 files are created for each process (i.e. for each CPU used). 
+ grid_settings (:class:`GridSettings` | None, optional): If valid together with `grid_map_method`, the grid data will be stored as well. + Defaults to None. + grid_map_method (:class:`MapMethod` | None, optional): If valid together with `grid_settings`, the grid data will be stored as well. + Defaults to None. + grid_augmentation_count (int, optional): Number of grid data augmentations (must be >= 0). + Defaults to 0. - Query.__init__(self, model_id, targets, suppress_pssm_errors) + Returns: + list[str]: The list of paths of the generated HDF5 files. + """ - self._pdb_path = pdb_path + # set defaults + self._prefix = "processed-queries" if not prefix else re.sub('.hdf5$', '', prefix) # scrape extension if present - self._chain_id1 = chain_id1 - self._chain_id2 = chain_id2 + max_cpus = os.cpu_count() + self._cpu_count = max_cpus if cpu_count is None else min(cpu_count, max_cpus) + if cpu_count and self._cpu_count < cpu_count: + _log.warning(f'\nTried to set {cpu_count} CPUs, but only {max_cpus} are present in the system.') + _log.info(f'\nNumber of CPUs for processing the queries set to: {self._cpu_count}.') - self._pssm_paths = pssm_paths + self._feature_modules = self._set_feature_modules(feature_modules) + _log.info(f'\nSelected feature modules: {self._feature_modules}.') - self._distance_cutoff = distance_cutoff + self._grid_settings = grid_settings + self._grid_map_method = grid_map_method - def get_query_id(self) -> str: - "Returns the string representing the complete query ID." 
- return f"residue-ppi:{self._chain_id1}-{self._chain_id2}:{self.model_id}" + if grid_augmentation_count < 0: + raise ValueError(f"`grid_augmentation_count` cannot be negative, but was given as {grid_augmentation_count}") + self._grid_augmentation_count = grid_augmentation_count - def __eq__(self, other) -> bool: - return ( - isinstance(self, type(other)) - and self.model_id == other.model_id - and {self._chain_id1, self._chain_id2} - == {other._chain_id1, other._chain_id2} - ) + _log.info(f'Creating pool function to process {len(self)} queries...') + pool_function = partial(self._process_one_query) + with Pool(self._cpu_count) as pool: + _log.info('Starting pooling...\n') + pool.map(pool_function, self.queries) - def __hash__(self) -> hash: - return hash((self.model_id, tuple(sorted([self._chain_id1, self._chain_id2])))) + output_paths = glob(f"{prefix}-*.hdf5") + if combine_output: + for output_path in output_paths: + with h5py.File(f"{prefix}.hdf5",'a') as f_dest, h5py.File(output_path,'r') as f_src: + for key, value in f_src.items(): + _log.debug(f"copy {key} from {output_path} to {prefix}.hdf5") + f_src.copy(value, f_dest) + os.remove(output_path) + return glob(f"{prefix}.hdf5") - def build(self, feature_modules: List[ModuleType], include_hydrogens: bool = False) -> Graph: - """Builds the graph from the .PDB structure. + return output_paths - Args: - feature_modules (List[ModuleType]): Each must implement the :py:func:`add_features` function. - include_hydrogens (bool, optional): Whether to include hydrogens in the :class:`Graph`. Defaults to False. + def _set_feature_modules( + self, + feature_modules: list[ModuleType, str] | ModuleType | str | Literal['all'] + ) -> list[str]: + """Convert `feature_modules` to list[str] irrespective of input type. - Returns: - :class:`Graph`: The resulting :class:`Graph` object with all the features and targets. + Raises: + TypeError: if an invalid input type is passed. 
""" - contact_atoms = _load_ppi_atoms(self._pdb_path, - self._chain_id1, self._chain_id2, - self._distance_cutoff, - include_hydrogens) - - atom_positions = [] - residues_selected = set([]) - for atom in contact_atoms: - atom_positions.append(atom.position) - residues_selected.add(atom.residue) - residues_selected = list(residues_selected) - - # build the graph - graph = build_residue_graph( - residues_selected, self.get_query_id(), self._distance_cutoff - ) - - # add data to the graph - self._set_graph_targets(graph) - - # read the pssm - structure = contact_atoms[0].residue.chain.model - - if not isinstance(feature_modules, List): - feature_modules = [feature_modules] - if conservation in feature_modules: - _load_ppi_pssms(self._pssm_paths, - [self._chain_id1, self._chain_id2], - structure, self._pdb_path, - suppress_error=self._suppress) - - # add the features - for feature_module in feature_modules: - feature_module.add_features(self._pdb_path, graph) - - graph.center = np.mean(atom_positions, axis=0) - return graph + if feature_modules == 'all': + return [modname for _, modname, _ in pkgutil.iter_modules(deeprank2.features.__path__)] + if isinstance(feature_modules, ModuleType): + return [os.path.basename(feature_modules.__file__)[:-3]] + if isinstance(feature_modules, str): + return [re.sub('.py$', '', feature_modules)] # scrape extension if present + if isinstance(feature_modules, list): + invalid_inputs = [type(el) for el in feature_modules if not isinstance(el, (str, ModuleType))] + if invalid_inputs: + raise TypeError(f'`feature_modules` contains invalid input ({invalid_inputs}). Only `str` and `ModuleType` are accepted.') + return [re.sub('.py$', '', m) if isinstance(m, str) + else os.path.basename(m.__file__)[:-3] # for ModuleTypes + for m in feature_modules] + raise TypeError(f'`feature_modules` has received an invalid input type: {type(feature_modules)}. 
Only `str` and `ModuleType` are accepted.') diff --git a/deeprank2/utils/buildgraph.py b/deeprank2/utils/buildgraph.py index c67477b9e..eaecf06de 100644 --- a/deeprank2/utils/buildgraph.py +++ b/deeprank2/utils/buildgraph.py @@ -1,31 +1,21 @@ import logging import os -import subprocess -from typing import List, Union +from typing import List import numpy as np +from pdb2sql import interface as get_interface +from scipy.spatial import distance_matrix + from deeprank2.domain.aminoacidlist import amino_acids from deeprank2.molstruct.atom import Atom, AtomicElement from deeprank2.molstruct.pair import Pair from deeprank2.molstruct.residue import Residue from deeprank2.molstruct.structure import Chain, PDBStructure -from pdb2sql import interface as get_interface -from scipy.spatial import distance_matrix _log = logging.getLogger(__name__) -def add_hydrogens(input_pdb_path, output_pdb_path): - """This requires reduce: https://github.com/rlabduke/reduce.""" - - with open(output_pdb_path, "wt", encoding = "utf-8") as f: - p = subprocess.run(["reduce", input_pdb_path], stdout=subprocess.PIPE, check=True) - for line in p.stdout.decode().split("\n"): - f.write(line.replace(" new", "").replace(" std", "") + "\n") - - def _add_atom_to_residue(atom, residue): - for other_atom in residue.atoms: if other_atom.name == atom.name: # Don't allow two atoms with the same name, pick the highest @@ -115,7 +105,7 @@ def _add_atom_data_to_structure(structure: PDBStructure, # pylint: disable=too- _add_atom_to_residue(atom, residue) -def get_structure(pdb, id_: str): +def get_structure(pdb, id_: str) -> PDBStructure: """Builds a structure from rows in a pdb file. 
Args: @@ -165,22 +155,26 @@ def get_structure(pdb, id_: str): def get_contact_atoms( # pylint: disable=too-many-locals pdb_path: str, - chain_id1: str, - chain_id2: str, + chain_ids: List[str], distance_cutoff: float ) -> List[Atom]: """Gets the contact atoms from pdb2sql and wraps them in python objects.""" interface = get_interface(pdb_path) try: - atom_indexes = interface.get_contact_atoms(cutoff=distance_cutoff, chain1=chain_id1, chain2=chain_id2) - rows = interface.get("x,y,z,name,element,altLoc,occ,chainID,resSeq,resName,iCode", - rowID=atom_indexes[chain_id1] + atom_indexes[chain_id2]) + atom_indexes = interface.get_contact_atoms( + cutoff=distance_cutoff, + chain1=chain_ids[0], + chain2=chain_ids[1], + ) + rows = interface.get( + "x,y,z,name,element,altLoc,occ,chainID,resSeq,resName,iCode", + rowID=atom_indexes[chain_ids[0]] + atom_indexes[chain_ids[1]] + ) finally: interface._close() # pylint: disable=protected-access pdb_name = os.path.splitext(os.path.basename(pdb_path))[0] - structure = PDBStructure(f"contact_atoms_{pdb_name}") for row in rows: @@ -288,7 +282,11 @@ def get_residue_contact_pairs( # pylint: disable=too-many-locals return residue_pairs -def get_surrounding_residues(structure: Union[Chain, PDBStructure], residue: Residue, radius: float): +def get_surrounding_residues( + structure: Chain | PDBStructure, + residue: Residue, + radius: float, +) -> list[Residue]: """Get the residues that lie within a radius around a residue. Args: @@ -318,4 +316,4 @@ def get_surrounding_residues(structure: Union[Chain, PDBStructure], residue: Res close_residues.add(structure_atom.residue) - return close_residues + return list(close_residues) diff --git a/docs/getstarted.md b/docs/getstarted.md index 12fe532b4..e028d8558 100644 --- a/docs/getstarted.md +++ b/docs/getstarted.md @@ -6,25 +6,24 @@ For more details, see the [extended documentation](https://deeprank2.rtfd.io/). 
## Data generation -For each protein-protein complex (or protein structure containing a SRV), a query can be created and added to the `QueryCollection` object, to be processed later on. Different types of queries exist: -- In a `ProteinProteinInterfaceResidueQuery` and `SingleResidueVariantResidueQuery`, each node represents one amino acid residue. -- In a `ProteinProteinInterfaceAtomicQuery` and `SingleResidueVariantAtomicQuery`, each node represents one atom within the amino acid residues. +For each protein-protein complex (or protein structure containing a missense variant), a `Query` can be created and added to the `QueryCollection` object, to be processed later on. Two subtypes of `Query` exist: `ProteinProteinInterfaceQuery` and `SingleResidueVariantQuery`. -A query takes as inputs: -- a `.pdb` file, representing the protein-protein structure +A `Query` takes as inputs: +- a `.pdb` file, representing the protein-protein structure, +- the resolution (`"residue"` or `"atom"`), i.e. whether each node should represent an amino acid residue or an atom, - the ids of the chains composing the structure, and - optionally, the correspondent position-specific scoring matrices (PSSMs), in the form of `.pssm` files. 
```python -from deeprank2.query import QueryCollection, ProteinProteinInterfaceResidueQuery +from deeprank2.query import QueryCollection, ProteinProteinInterfaceQuery queries = QueryCollection() # Append data points -queries.add(ProteinProteinInterfaceResidueQuery( +queries.add(ProteinProteinInterfaceQuery( pdb_path = "tests/data/pdb/1ATN/1ATN_1w.pdb", - chain_id1 = "A", - chain_id2 = "B", + resolution = "residue", + chain_ids = ["A", "B"], targets = { "binary": 0 }, @@ -33,10 +32,10 @@ queries.add(ProteinProteinInterfaceResidueQuery( "B": "tests/data/pssm/1ATN/1ATN.B.pdb.pssm" } )) -queries.add(ProteinProteinInterfaceResidueQuery( +queries.add(ProteinProteinInterfaceQuery( pdb_path = "tests/data/pdb/1ATN/1ATN_2w.pdb", - chain_id1 = "A", - chain_id2 = "B", + resolution = "residue", + chain_ids = ["A", "B"], targets = { "binary": 1 }, @@ -45,10 +44,10 @@ queries.add(ProteinProteinInterfaceResidueQuery( "B": "tests/data/pssm/1ATN/1ATN.B.pdb.pssm" } )) -queries.add(ProteinProteinInterfaceResidueQuery( +queries.add(ProteinProteinInterfaceQuery( pdb_path = "tests/data/pdb/1ATN/1ATN_3w.pdb", - chain_id1 = "A", - chain_id2 = "B", + resolution = "residue", + chain_ids = ["A", "B"], targets = { "binary": 0 }, diff --git a/pyproject.toml b/pyproject.toml index ba8c6aeb8..bda2eac21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,8 +56,8 @@ dependencies = [ # development dependency groups test = [ "pytest >= 7.4.0", - "pylint <= 2.15.3", - "prospector[with_pyroma] <= 1.7.7", + "pylint <= 2.17.5", + "prospector[with_pyroma] <= 1.10.2", "bump2version", "coverage", "pycodestyle", diff --git a/tests/data/hdf5/_generate_testdata.ipynb b/tests/data/hdf5/_generate_testdata.ipynb index 9ad003684..65a58720a 100644 --- a/tests/data/hdf5/_generate_testdata.ipynb +++ b/tests/data/hdf5/_generate_testdata.ipynb @@ -13,9 +13,8 @@ "PATH_TEST = ROOT / \"tests\"\n", "from deeprank2.query import (\n", " QueryCollection,\n", - " ProteinProteinInterfaceResidueQuery,\n", - " 
SingleResidueVariantResidueQuery,\n", - " ProteinProteinInterfaceAtomicQuery)\n", + " ProteinProteinInterfaceQuery,\n", + " SingleResidueVariantQuery)\n", "from deeprank2.tools.target import compute_ppi_scores\n", "from deeprank2.dataset import save_hdf5_keys\n", "from deeprank2.domain.aminoacidlist import alanine, phenylalanine\n", @@ -65,10 +64,10 @@ " for pdb_path in pdb_paths:\n", " # Append data points\n", " targets = compute_ppi_scores(pdb_path, ref_path)\n", - " queries.add(ProteinProteinInterfaceResidueQuery(\n", + " queries.add(ProteinProteinInterfaceQuery(\n", " pdb_path = pdb_path,\n", - " chain_id1 = chain_id1,\n", - " chain_id2 = chain_id2,\n", + " resolution = \"residue\",\n", + " chain_ids = [chain_id1, chain_id2],\n", " targets = targets,\n", " pssm_paths = {\n", " chain_id1: pssm_path1,\n", @@ -98,11 +97,11 @@ "outputs": [], "source": [ "# Local data\n", - "project_folder = '/home/dbodor/git/DeepRank/deeprank2/tests/data/sample_25_07122022/'\n", + "project_folder = '/home/dbodor/git/DeepRank/deeprank-core/tests/data/sample_25_07122022/'\n", "csv_file_name = 'BA_pMHCI_human_quantitative.csv'\n", "models_folder_name = 'exp_nmers_all_HLA_quantitative'\n", "data = 'pMHCI'\n", - "resolution = 'residue' # either 'residue' or 'atomic'\n", + "resolution = 'residue' # either 'residue' or 'atom'\n", "distance_cutoff = 15 # max distance in Å between two interacting residues/atoms of two proteins\n", "\n", "csv_file_path = f'{project_folder}data/external/processed/I/{csv_file_name}'\n", @@ -127,10 +126,10 @@ "print(f'Adding {len(pdb_files)} queries to the query collection ...')\n", "for i in range(len(pdb_files)):\n", " queries.add(\n", - " ProteinProteinInterfaceResidueQuery(\n", - " pdb_path = pdb_files[i], \n", - " chain_id1 = \"M\",\n", - " chain_id2 = \"P\",\n", + " ProteinProteinInterfaceQuery(\n", + " pdb_path = pdb_files[i],\n", + " resolution = \"residue\",\n", + " chain_ids = [\"M\", \"P\"],\n", " distance_cutoff = distance_cutoff,\n", " targets = 
{\n", " 'binary': int(float(bas[i]) <= 500), # binary target value\n", @@ -223,14 +222,15 @@ "queries = QueryCollection()\n", "\n", "for number in range(1, count_queries + 1):\n", - " query = SingleResidueVariantResidueQuery(\n", - " pdb_path,\n", - " \"A\",\n", - " number,\n", - " None,\n", - " alanine,\n", - " phenylalanine,\n", - " pssm_paths={\n", + " query = SingleResidueVariantQuery(\n", + " pdb_path = pdb_path,\n", + " resolution = \"residue\",\n", + " chain_ids = \"A\",\n", + " variant_residue_number = number,\n", + " insertion_code = None,\n", + " wildtype_amino_acid = alanine,\n", + " variant_amino_acid = phenylalanine,\n", + " pssm_paths = {\n", " \"A\": str(PATH_TEST / \"data/pssm/3C8P/3C8P.A.pdb.pssm\"),\n", " \"B\": str(PATH_TEST / \"data/pssm/3C8P/3C8P.B.pdb.pssm\")},\n", " targets = targets\n", @@ -270,10 +270,10 @@ "for pdb_path in pdb_paths:\n", " # Append data points\n", " targets = compute_ppi_scores(pdb_path, ref_path)\n", - " queries.add(ProteinProteinInterfaceAtomicQuery(\n", + " queries.add(ProteinProteinInterfaceQuery(\n", " pdb_path = pdb_path,\n", - " chain_id1 = chain_id1,\n", - " chain_id2 = chain_id2,\n", + " resolution=\"atom\",\n", + " chain_ids = [chain_id1, chain_id2],\n", " targets = targets,\n", " pssm_paths = {\n", " chain_id1: pssm_path1,\n", @@ -302,7 +302,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.12" + "version": "3.10.12" }, "orig_nbformat": 4, "vscode": { diff --git a/tests/domain/test_aminoacidlist.py b/tests/domain/test_aminoacidlist.py index c65700495..2ab89c50a 100644 --- a/tests/domain/test_aminoacidlist.py +++ b/tests/domain/test_aminoacidlist.py @@ -1,4 +1,5 @@ import numpy as np + from deeprank2.domain.aminoacidlist import (amino_acids, cysteine, lysine, pyrrolysine, selenocysteine) @@ -16,10 +17,10 @@ def test_all_different_onehot(): if other != amino_acid: try: assert not np.all(amino_acid.onehot == other.onehot) - except AssertionError: + except 
AssertionError as e: if other in EXCEPTIONS[0] and amino_acid in EXCEPTIONS[0]: assert np.all(amino_acid.onehot == other.onehot) elif other in EXCEPTIONS[1] and amino_acid in EXCEPTIONS[1]: assert np.all(amino_acid.onehot == other.onehot) else: - raise AssertionError(f"one-hot index {amino_acid.index} is occupied by both {amino_acid} and {other}") \ No newline at end of file + raise AssertionError(f"one-hot index {amino_acid.index} is occupied by both {amino_acid} and {other}") from e diff --git a/tests/molstruct/test_structure.py b/tests/molstruct/test_structure.py index 55633b18c..2692cf18e 100644 --- a/tests/molstruct/test_structure.py +++ b/tests/molstruct/test_structure.py @@ -2,9 +2,10 @@ import pickle from multiprocessing.connection import _ForkingPickler +from pdb2sql import pdb2sql + from deeprank2.molstruct.structure import PDBStructure from deeprank2.utils.buildgraph import get_structure -from pdb2sql import pdb2sql def _get_structure(path) -> PDBStructure: @@ -12,7 +13,7 @@ def _get_structure(path) -> PDBStructure: try: structure = get_structure(pdb, "101M") finally: - pdb._close() + pdb._close() # pylint: disable=protected-access assert structure is not None diff --git a/tests/test_dataset.py b/tests/test_dataset.py index fbc6b02bd..98223acd0 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -8,9 +8,9 @@ import h5py import numpy as np import pytest -from deeprank2.dataset import GraphDataset, GridDataset, save_hdf5_keys from torch_geometric.loader import DataLoader +from deeprank2.dataset import GraphDataset, GridDataset, save_hdf5_keys from deeprank2.domain import edgestorage as Efeat from deeprank2.domain import nodestorage as Nfeat from deeprank2.domain import targetstorage as targets @@ -321,7 +321,7 @@ def test_target_transform_graphdataset(self): ) for i in range(len(dataset)): - assert(0 <= dataset.get(i).y <= 1) + assert (0 <= dataset.get(i).y <= 1) def test_invalid_target_transform_graphdataset(self): diff --git 
a/tests/test_integration.py b/tests/test_integration.py index d8a37a58c..be5140962 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -11,8 +11,7 @@ from deeprank2.domain import targetstorage as targets from deeprank2.neuralnets.cnn.model3d import CnnClassification from deeprank2.neuralnets.gnn.ginet import GINet -from deeprank2.query import (ProteinProteinInterfaceResidueQuery, - QueryCollection) +from deeprank2.query import ProteinProteinInterfaceQuery, QueryCollection from deeprank2.tools.target import compute_ppi_scores from deeprank2.trainer import Trainer from deeprank2.utils.exporters import HDF5OutputExporter @@ -41,17 +40,15 @@ def test_cnn(): # pylint: disable=too-many-locals prefix = os.path.join(hdf5_directory, "test-queries-process") - all_targets = compute_ppi_scores(pdb_path, ref_path) - try: all_targets = compute_ppi_scores(pdb_path, ref_path) queries = QueryCollection() for _ in range(count_queries): - query = ProteinProteinInterfaceResidueQuery( - pdb_path, - chain_id1, - chain_id2, + query = ProteinProteinInterfaceQuery( + pdb_path=pdb_path, + resolution='residue', + chain_ids=[chain_id1,chain_id2], pssm_paths={chain_id1: pssm_path1, chain_id2: pssm_path2}, targets = all_targets ) @@ -129,10 +126,10 @@ def test_gnn(): # pylint: disable=too-many-locals queries = QueryCollection() for _ in range(count_queries): - query = ProteinProteinInterfaceResidueQuery( - pdb_path, - chain_id1, - chain_id2, + query = ProteinProteinInterfaceQuery( + pdb_path=pdb_path, + resolution='residue', + chain_ids=[chain_id1,chain_id2], pssm_paths={chain_id1: pssm_path1, chain_id2: pssm_path2}, targets = all_targets ) diff --git a/tests/test_query.py b/tests/test_query.py index 8c99d91ee..bad0950e8 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -1,27 +1,31 @@ import os import shutil from tempfile import mkdtemp, mkstemp +from typing import List import h5py import numpy as np import pytest -from deeprank2.dataset import GraphDataset, 
GridDataset -from deeprank2.domain.aminoacidlist import (alanine, arginine, asparagine, - cysteine, glutamate, glycine, - leucine, lysine, phenylalanine) -from deeprank2.query import (ProteinProteinInterfaceAtomicQuery, - ProteinProteinInterfaceResidueQuery, - QueryCollection, SingleResidueVariantAtomicQuery, - SingleResidueVariantResidueQuery) -from deeprank2.utils.grid import GridSettings, MapMethod +from deeprank2.dataset import GraphDataset, GridDataset from deeprank2.domain import edgestorage as Efeat from deeprank2.domain import nodestorage as Nfeat from deeprank2.domain import targetstorage as targets +from deeprank2.domain.aminoacidlist import (alanine, arginine, asparagine, + cysteine, glutamate, glycine, + leucine, lysine, phenylalanine) from deeprank2.features import components, conservation, contact, surfacearea +from deeprank2.query import (ProteinProteinInterfaceQuery, QueryCollection, + SingleResidueVariantQuery) +from deeprank2.utils.graph import Graph +from deeprank2.utils.grid import GridSettings, MapMethod -def _check_graph_makes_sense(g, node_feature_names, edge_feature_names): +def _check_graph_makes_sense( + g: Graph, + node_feature_names: List[str], + edge_feature_names: List[str], +): assert len(g.nodes) > 0, "no nodes" assert Nfeat.POSITION in g.nodes[0].features @@ -87,18 +91,17 @@ def _check_graph_makes_sense(g, node_feature_names, edge_feature_names): def test_interface_graph_residue(): - query = ProteinProteinInterfaceResidueQuery( - "tests/data/pdb/3C8P/3C8P.pdb", - "A", - "B", - { + query = ProteinProteinInterfaceQuery( + pdb_path="tests/data/pdb/3C8P/3C8P.pdb", + resolution="residue", + chain_ids=["A", "B"], + pssm_paths={ "A": "tests/data/pssm/3C8P/3C8P.A.pdb.pssm", "B": "tests/data/pssm/3C8P/3C8P.B.pdb.pssm", }, ) g = query.build([surfacearea, components, conservation, contact]) - _check_graph_makes_sense( g, [ @@ -112,21 +115,18 @@ def test_interface_graph_residue(): def test_interface_graph_atomic(): - query = 
ProteinProteinInterfaceAtomicQuery( - "tests/data/pdb/3C8P/3C8P.pdb", - "A", - "B", - { + query = ProteinProteinInterfaceQuery( + pdb_path="tests/data/pdb/3C8P/3C8P.pdb", + resolution="atom", + chain_ids=["A", "B"], + pssm_paths={ "A": "tests/data/pssm/3C8P/3C8P.A.pdb.pssm", "B": "tests/data/pssm/3C8P/3C8P.B.pdb.pssm", }, distance_cutoff=4.5, ) - # using a small cutoff here, because atomic graphs are big - g = query.build([surfacearea, components, conservation, contact]) - _check_graph_makes_sense( g, [ @@ -140,23 +140,21 @@ def test_interface_graph_atomic(): def test_variant_graph_101M(): - query = SingleResidueVariantAtomicQuery( - "tests/data/pdb/101M/101M.pdb", - "A", - 27, - None, - asparagine, - phenylalanine, - {"A": "tests/data/pssm/101M/101M.A.pdb.pssm"}, + query = SingleResidueVariantQuery( + pdb_path="tests/data/pdb/101M/101M.pdb", + resolution="atom", + chain_ids="A", + variant_residue_number=27, + insertion_code=None, + wildtype_amino_acid=asparagine, + variant_amino_acid=phenylalanine, + pssm_paths={"A": "tests/data/pssm/101M/101M.A.pdb.pssm"}, targets={targets.BINARY: 0}, radius=5.0, distance_cutoff=5.0, ) - # using a small cutoff here, because atomic graphs are big - g = query.build([surfacearea, components, conservation, contact]) - _check_graph_makes_sense( g, [ @@ -175,14 +173,15 @@ def test_variant_graph_101M(): def test_variant_graph_1A0Z(): - query = SingleResidueVariantAtomicQuery( - "tests/data/pdb/1A0Z/1A0Z.pdb", - "A", - 125, - None, - leucine, - arginine, - { + query = SingleResidueVariantQuery( + pdb_path="tests/data/pdb/1A0Z/1A0Z.pdb", + resolution="atom", + chain_ids="A", + variant_residue_number=125, + insertion_code=None, + wildtype_amino_acid=leucine, + variant_amino_acid=arginine, + pssm_paths={ "A": "tests/data/pssm/1A0Z/1A0Z.A.pdb.pssm", "B": "tests/data/pssm/1A0Z/1A0Z.B.pdb.pssm", "C": "tests/data/pssm/1A0Z/1A0Z.A.pdb.pssm", @@ -193,10 +192,7 @@ def test_variant_graph_1A0Z(): radius=5.0, ) - # using a small cutoff here, because 
atomic graphs are big - g = query.build([surfacearea, components, conservation, contact]) - _check_graph_makes_sense( g, [ @@ -215,14 +211,15 @@ def test_variant_graph_1A0Z(): def test_variant_graph_9API(): - query = SingleResidueVariantAtomicQuery( - "tests/data/pdb/9api/9api.pdb", - "A", - 310, - None, - lysine, - glutamate, - { + query = SingleResidueVariantQuery( + pdb_path="tests/data/pdb/9api/9api.pdb", + resolution="atom", + chain_ids="A", + variant_residue_number=310, + insertion_code=None, + wildtype_amino_acid=lysine, + variant_amino_acid=glutamate, + pssm_paths={ "A": "tests/data/pssm/9api/9api.A.pdb.pssm", "B": "tests/data/pssm/9api/9api.B.pdb.pssm", }, @@ -231,10 +228,7 @@ def test_variant_graph_9API(): radius=5.0, ) - # using a small cutoff here, because atomic graphs are big - g = query.build([surfacearea, components, conservation, contact]) - _check_graph_makes_sense( g, [ @@ -253,19 +247,19 @@ def test_variant_graph_9API(): def test_variant_residue_graph_101M(): - query = SingleResidueVariantResidueQuery( - "tests/data/pdb/101M/101M.pdb", - "A", - 25, - None, - glycine, - alanine, - {"A": "tests/data/pssm/101M/101M.A.pdb.pssm"}, + query = SingleResidueVariantQuery( + pdb_path="tests/data/pdb/101M/101M.pdb", + resolution="residue", + chain_ids="A", + variant_residue_number=25, + insertion_code=None, + wildtype_amino_acid=glycine, + variant_amino_acid=alanine, + pssm_paths={"A": "tests/data/pssm/101M/101M.A.pdb.pssm"}, targets={targets.BINARY: 0}, ) g = query.build([surfacearea, components, conservation, contact]) - _check_graph_makes_sense( g, [ @@ -281,66 +275,67 @@ def test_variant_residue_graph_101M(): def test_res_ppi(): - - query = ProteinProteinInterfaceResidueQuery("tests/data/pdb/3MRC/3MRC.pdb", - "M", "P") - + query = ProteinProteinInterfaceQuery( + pdb_path="tests/data/pdb/3MRC/3MRC.pdb", + resolution="residue", + chain_ids=["M", "P"], + ) g = query.build([surfacearea, contact]) - _check_graph_makes_sense(g, [Nfeat.SASA], [Efeat.ELEC]) def 
test_augmentation(): qc = QueryCollection() - qc.add(ProteinProteinInterfaceResidueQuery( - "tests/data/pdb/3C8P/3C8P.pdb", - "A", - "B", - { + qc.add(ProteinProteinInterfaceQuery( + pdb_path="tests/data/pdb/3C8P/3C8P.pdb", + resolution="residue", + chain_ids=["A", "B"], + pssm_paths={ "A": "tests/data/pssm/3C8P/3C8P.A.pdb.pssm", "B": "tests/data/pssm/3C8P/3C8P.B.pdb.pssm", }, targets={targets.BINARY: 0}, )) - qc.add(ProteinProteinInterfaceAtomicQuery( - "tests/data/pdb/3C8P/3C8P.pdb", - "A", - "B", - { + qc.add(ProteinProteinInterfaceQuery( + pdb_path="tests/data/pdb/3C8P/3C8P.pdb", + resolution="atom", + chain_ids=["A", "B"], + pssm_paths={ "A": "tests/data/pssm/3C8P/3C8P.A.pdb.pssm", "B": "tests/data/pssm/3C8P/3C8P.B.pdb.pssm", }, targets={targets.BINARY: 0}, )) - qc.add(SingleResidueVariantResidueQuery( - "tests/data/pdb/101M/101M.pdb", - "A", - 25, - None, - glycine, - alanine, - {"A": "tests/data/pssm/101M/101M.A.pdb.pssm"}, + qc.add(SingleResidueVariantQuery( + pdb_path="tests/data/pdb/101M/101M.pdb", + resolution="residue", + chain_ids="A", + variant_residue_number=25, + insertion_code=None, + wildtype_amino_acid=glycine, + variant_amino_acid=alanine, + pssm_paths={"A": "tests/data/pssm/101M/101M.A.pdb.pssm"}, targets={targets.BINARY: 0}, )) - qc.add(SingleResidueVariantAtomicQuery( - "tests/data/pdb/101M/101M.pdb", - "A", - 27, - None, - asparagine, - phenylalanine, - {"A": "tests/data/pssm/101M/101M.A.pdb.pssm"}, + qc.add(SingleResidueVariantQuery( + pdb_path="tests/data/pdb/101M/101M.pdb", + resolution="atom", + chain_ids="A", + variant_residue_number=27, + insertion_code=None, + wildtype_amino_acid=asparagine, + variant_amino_acid=phenylalanine, + pssm_paths={"A": "tests/data/pssm/101M/101M.A.pdb.pssm"}, targets={targets.BINARY: 0}, radius=3.0, )) augmentation_count = 3 grid_settings = GridSettings([20, 20, 20], [20.0, 20.0, 20.0]) - expected_entry_count = (augmentation_count + 1) * len(qc) tmp_dir = mkdtemp() @@ -366,11 +361,11 @@ def 
test_augmentation(): def test_incorrect_pssm_order(): - q = ProteinProteinInterfaceResidueQuery( - "tests/data/pdb/3C8P/3C8P.pdb", - "A", - "B", - { + q = ProteinProteinInterfaceQuery( + pdb_path="tests/data/pdb/3C8P/3C8P.pdb", + resolution="residue", + chain_ids=["A", "B"], + pssm_paths={ "A": "tests/data/pssm/3C8P_incorrect/3C8P.A.wrong_order.pdb.pssm", "B": "tests/data/pssm/3C8P/3C8P.B.pdb.pssm", }, @@ -385,16 +380,16 @@ def test_incorrect_pssm_order(): # check that error suppression works with pytest.warns(UserWarning): - q._suppress = True # pylint: disable = protected-access + q.suppress_pssm_errors = True _ = q.build(conservation) def test_incomplete_pssm(): - q = ProteinProteinInterfaceResidueQuery( - "tests/data/pdb/3C8P/3C8P.pdb", - "A", - "B", - { + q = ProteinProteinInterfaceQuery( + pdb_path="tests/data/pdb/3C8P/3C8P.pdb", + resolution="residue", + chain_ids=["A", "B"], + pssm_paths={ "A": "tests/data/pssm/3C8P/3C8P.A.pdb.pssm", "B": "tests/data/pssm/3C8P_incorrect/3C8P.B.missing_res.pdb.pssm", }, @@ -408,71 +403,72 @@ def test_incomplete_pssm(): # check that error suppression works with pytest.warns(UserWarning): - q._suppress = True # pylint: disable = protected-access + q.suppress_pssm_errors = True _ = q.build(conservation) def test_no_pssm_provided(): # pssm_paths is empty dictionary - q_empty_dict = ProteinProteinInterfaceResidueQuery( - "tests/data/pdb/3C8P/3C8P.pdb", - "A", - "B", - {}, + q_empty_dict = ProteinProteinInterfaceQuery( + pdb_path="tests/data/pdb/3C8P/3C8P.pdb", + resolution="residue", + chain_ids=["A", "B"], + pssm_paths={}, ) # pssm_paths not provided - q_not_provided = ProteinProteinInterfaceResidueQuery( - "tests/data/pdb/3C8P/3C8P.pdb", - "A", - "B", + q_not_provided = ProteinProteinInterfaceQuery( + pdb_path="tests/data/pdb/3C8P/3C8P.pdb", + resolution="residue", + chain_ids=["A", "B"], ) with pytest.raises(ValueError): - _ = q_empty_dict.build(conservation) - _ = q_not_provided.build(conservation) + _ = 
q_empty_dict.build([conservation]) + _ = q_not_provided.build([conservation]) # no error if conservation module is not used - _ = q_empty_dict.build(components) - _ = q_not_provided.build(components) + _ = q_empty_dict.build([components]) + _ = q_not_provided.build([components]) def test_incorrect_pssm_provided(): # non-existing file - q_non_existing = ProteinProteinInterfaceResidueQuery( - "tests/data/pdb/3C8P/3C8P.pdb", - "A", - "B", - { + q_non_existing = ProteinProteinInterfaceQuery( + pdb_path="tests/data/pdb/3C8P/3C8P.pdb", + resolution="residue", + chain_ids=["A", "B"], + pssm_paths={ "A": "tests/data/pssm/3C8P/3C8P.A.pdb.pssm", "B": "tests/data/pssm/3C8P_incorrect/dummy_non_existing_file.pssm", }, ) # missing file - q_missing = ProteinProteinInterfaceResidueQuery( - "tests/data/pdb/3C8P/3C8P.pdb", - "A", - "B", - { + q_missing = ProteinProteinInterfaceQuery( + pdb_path="tests/data/pdb/3C8P/3C8P.pdb", + resolution="residue", + chain_ids=["A", "B"], + pssm_paths={ "A": "tests/data/pssm/3C8P/3C8P.A.pdb.pssm", }, ) with pytest.raises(FileNotFoundError): - _ = q_non_existing.build(conservation) - _ = q_missing.build(conservation) + _ = q_non_existing.build([conservation]) + _ = q_missing.build([conservation]) # no error if conservation module is not used - _ = q_non_existing.build(components) - _ = q_missing.build(components) + _ = q_non_existing.build([components]) + _ = q_missing.build([components]) def test_variant_query_multiple_chains(): - q = SingleResidueVariantAtomicQuery( + q = SingleResidueVariantQuery( pdb_path = "tests/data/pdb/2g98/pdb2g98.pdb", - chain_id = "A", - residue_number = 14, + resolution = "atom", + chain_ids = "A", + variant_residue_number = 14, insertion_code = None, wildtype_amino_acid = arginine, variant_amino_acid = cysteine, @@ -480,7 +476,7 @@ def test_variant_query_multiple_chains(): targets = {targets.BINARY: 1}, radius = 10.0, distance_cutoff = 4.5, - ) + ) # at radius 10, chain B is included in graph # no error without 
conservation module @@ -491,6 +487,6 @@ def test_variant_query_multiple_chains(): _ = q.build(conservation) # at radius 7, chain B is not included in graph - q._radius = 7.0 # pylint: disable = protected-access + q.radius = 7.0 graph = q.build(conservation) assert 'B' not in graph.get_all_chains() diff --git a/tests/test_querycollection.py b/tests/test_querycollection.py index ed4611137..258357cdc 100644 --- a/tests/test_querycollection.py +++ b/tests/test_querycollection.py @@ -1,3 +1,4 @@ +import warnings from os.path import join from shutil import rmtree from tempfile import mkdtemp @@ -11,12 +12,12 @@ from deeprank2.domain import nodestorage as Nfeat from deeprank2.domain.aminoacidlist import alanine, phenylalanine from deeprank2.features import components, contact, surfacearea -from deeprank2.query import (ProteinProteinInterfaceResidueQuery, Query, - QueryCollection, SingleResidueVariantResidueQuery) +from deeprank2.query import (ProteinProteinInterfaceQuery, Query, + QueryCollection, SingleResidueVariantQuery) from deeprank2.tools.target import compute_ppi_scores -def _querycollection_tester( # pylint: disable = too-many-locals, dangerous-default-value +def _querycollection_tester( # pylint: disable=dangerous-default-value query_type: str, n_queries: int = 3, feature_modules: Union[ModuleType, List[ModuleType]] = [components, contact], @@ -27,7 +28,7 @@ def _querycollection_tester( # pylint: disable = too-many-locals, dangerous-defa Generic function to test QueryCollection class. Args: - query_type (str): query type to be generated. It accepts only 'ppi' (ProteinProteinInterface) or 'var' (SingleResidueVariant). + query_type (str): query type to be generated. It accepts only 'ppi' (ProteinProteinInterface) or 'srv' (SingleResidueVariant). Defaults to 'ppi'. n_queries (int): number of queries to be generated. feature_modules: module or list of feature modules (from deeprank2.features) to be passed to process. 
@@ -38,32 +39,33 @@ def _querycollection_tester( # pylint: disable = too-many-locals, dangerous-defa """ if query_type == 'ppi': - queries = [ProteinProteinInterfaceResidueQuery( - str("tests/data/pdb/3C8P/3C8P.pdb"), - "A", - "B", - pssm_paths={"A": "tests/data/pssm/3C8P/3C8P.A.pdb.pssm", "B": "tests/data/pssm/3C8P/3C8P.B.pdb.pssm"}, - ) for _ in range(n_queries)] - elif query_type == 'var': - queries = [SingleResidueVariantResidueQuery( - str("tests/data/pdb/101M/101M.pdb"), - "A", - None, # placeholder - insertion_code= None, - wildtype_amino_acid= alanine, - variant_amino_acid= phenylalanine, - pssm_paths={"A": str("tests/data/pssm/101M/101M.A.pdb.pssm")}, - ) for _ in range(n_queries)] + queries = [ProteinProteinInterfaceQuery( + pdb_path="tests/data/pdb/3C8P/3C8P.pdb", + resolution="residue", + chain_ids=["A","B"], + pssm_paths={"A": "tests/data/pssm/3C8P/3C8P.A.pdb.pssm", "B": "tests/data/pssm/3C8P/3C8P.B.pdb.pssm"}, + )] * n_queries + elif query_type == 'srv': + queries = [SingleResidueVariantQuery( + pdb_path="tests/data/pdb/101M/101M.pdb", + resolution="residue", + chain_ids="A", + variant_residue_number=None, # placeholder + insertion_code=None, + wildtype_amino_acid=alanine, + variant_amino_acid=phenylalanine, + pssm_paths={"A": "tests/data/pssm/101M/101M.A.pdb.pssm"}, + )] * n_queries else: - raise ValueError("Please insert a valid type (either ppi or var).") + raise ValueError("Please insert a valid type (either ppi or srv).") output_directory = mkdtemp() prefix = join(output_directory, "test-process-queries") collection = QueryCollection() for idx in range(n_queries): - if query_type == 'var': - queries[idx]._residue_number = idx + 1 # pylint: disable=protected-access + if query_type == 'srv': + queries[idx].variant_residue_number = idx + 1 collection.add(queries[idx]) else: collection.add(queries[idx], warn_duplicate=False) @@ -77,13 +79,16 @@ def _querycollection_tester( # pylint: disable = too-many-locals, dangerous-defa graph_names += 
list(f5.keys()) for query in collection.queries: - query_id = query.get_query_id() - assert query_id in graph_names, f"missing in output: {query_id}" + assert query.get_query_id() in graph_names, f"missing in output: {query.get_query_id()}" return collection, output_directory, output_paths -def _assert_correct_modules(output_paths: str, features: Union[str, List[str]], absent: str): +def _assert_correct_modules( + output_paths: str, + features: str | List[str], + absent: str, +): """Helper function to assert inclusion of correct features Args: @@ -117,7 +122,7 @@ def test_querycollection_process(): Tests processing method of QueryCollection class. """ - for query_type in ['ppi', 'var']: + for query_type in ['ppi', 'srv']: n_queries = 3 n_queries = 3 @@ -136,7 +141,7 @@ def test_querycollection_process_single_feature_module(): Tests processing for generating from a single feature module for following input types: ModuleType, List[ModuleType] str, List[str] """ - for query_type in ['ppi', 'var']: + for query_type in ['ppi', 'srv']: for testcase in [surfacearea, [surfacearea], 'surfacearea', ['surfacearea']]: _, output_directory, output_paths = _querycollection_tester(query_type, feature_modules=testcase) _assert_correct_modules(output_paths, Nfeat.BSA, Nfeat.HSE) @@ -154,8 +159,9 @@ def test_querycollection_process_all_features_modules(): _assert_correct_modules(output_paths, one_feature_from_each_module, 'dummy_feature') rmtree(output_directory) - _, output_directory, output_paths = _querycollection_tester('var', feature_modules='all') + _, output_directory, output_paths = _querycollection_tester('srv', feature_modules='all') _assert_correct_modules(output_paths, one_feature_from_each_module[:-1], Nfeat.IRCTOTAL) + rmtree(output_directory) @@ -164,10 +170,10 @@ def test_querycollection_process_default_features_modules(): Tests processing for generating all features. 
""" - for query_type in ['ppi', 'var']: - + for query_type in ['ppi', 'srv']: _, output_directory, output_paths = _querycollection_tester(query_type) _assert_correct_modules(output_paths, [Nfeat.RESTYPE, Efeat.DISTANCE], Nfeat.HSE) + rmtree(output_directory) @@ -176,26 +182,21 @@ def test_querycollection_process_combine_output_true(): Tests processing for combining hdf5 files into one. """ - for query_type in ['ppi', 'var']: + for query_type in ['ppi', 'srv']: modules = [surfacearea, components] - _, output_directory_t, output_paths_t = _querycollection_tester(query_type, feature_modules=modules) - _, output_directory_f, output_paths_f = _querycollection_tester(query_type, feature_modules=modules, combine_output = False, cpu_count=2) - assert len(output_paths_t) == 1 keys_t = {} with h5py.File(output_paths_t[0],'r') as file_t: for key, value in file_t.items(): keys_t[key] = value - keys_f = {} for output_path in output_paths_f: with h5py.File(output_path,'r') as file_f: for key, value in file_f.items(): keys_f[key] = value - assert keys_t == keys_f rmtree(output_directory_t) @@ -207,24 +208,24 @@ def test_querycollection_process_combine_output_false(): Tests processing for keeping all generated hdf5 files . """ - for query_type in ['ppi', 'var']: - + for query_type in ['ppi', 'srv']: cpu_count = 2 combine_output = False modules = [surfacearea, components] - - _, output_directory, output_paths = _querycollection_tester(query_type, feature_modules=modules, - cpu_count = cpu_count, combine_output = combine_output) - + _, output_directory, output_paths = _querycollection_tester( + query_type, + feature_modules=modules, + cpu_count = cpu_count, + combine_output = combine_output, + ) assert len(output_paths) == cpu_count rmtree(output_directory) def test_querycollection_duplicates_add(): - """ - Tests add method of QueryCollection class. 
- """ + """Tests add method of QueryCollection class.""" + ref_path = "tests/data/ref/1ATN/1ATN.pdb" pssm_path1 = "tests/data/pssm/1ATN/1ATN.A.pdb.pssm" pssm_path2 = "tests/data/pssm/1ATN/1ATN.B.pdb.pssm" @@ -240,19 +241,20 @@ def test_querycollection_duplicates_add(): queries = QueryCollection() - for pdb_path in pdb_paths: - # Append data points - targets = compute_ppi_scores(pdb_path, ref_path) - queries.add(ProteinProteinInterfaceResidueQuery( - pdb_path = pdb_path, - chain_id1 = chain_id1, - chain_id2 = chain_id2, - targets = targets, - pssm_paths = { - chain_id1: pssm_path1, - chain_id2: pssm_path2 - } - )) + with warnings.catch_warnings(record=UserWarning): + for pdb_path in pdb_paths: + # Append data points + targets = compute_ppi_scores(pdb_path, ref_path) + queries.add(ProteinProteinInterfaceQuery( + pdb_path = pdb_path, + resolution="residue", + chain_ids = [chain_id1, chain_id2], + targets = targets, + pssm_paths = { + chain_id1: pssm_path1, + chain_id2: pssm_path2 + } + )) #check id naming for all pdb files model_ids = [] @@ -260,7 +262,8 @@ def test_querycollection_duplicates_add(): model_ids.append(query.model_id) model_ids.sort() + # pylint: disable=protected-access assert model_ids == ['1ATN_1w', '1ATN_1w_2', '1ATN_1w_3', '1ATN_2w', '1ATN_2w_2', '1ATN_3w'] - assert queries.ids_count['residue-ppi:A-B:1ATN_1w'] == 3 - assert queries.ids_count['residue-ppi:A-B:1ATN_2w'] == 2 - assert queries.ids_count['residue-ppi:A-B:1ATN_3w'] == 1 + assert queries._ids_count['residue-ppi:A-B:1ATN_1w'] == 3 + assert queries._ids_count['residue-ppi:A-B:1ATN_2w'] == 2 + assert queries._ids_count['residue-ppi:A-B:1ATN_3w'] == 1 diff --git a/tests/utils/test_exporters.py b/tests/utils/test_exporters.py index 90da005e4..184901263 100644 --- a/tests/utils/test_exporters.py +++ b/tests/utils/test_exporters.py @@ -7,6 +7,7 @@ import h5py import pandas as pd + from deeprank2.utils.exporters import (HDF5OutputExporter, OutputExporterCollection, ScatterPlotExporter, @@ -95,7 
+96,7 @@ def test_scatter_plot(self): assert os.path.isfile(scatterplot_exporter.get_filename(epoch_number)) - def test_hdf5_output(self): + def test_hdf5_output(self): # pylint: disable=too-many-locals output_exporter = HDF5OutputExporter(self._work_dir) path_output_exporter = os.path.join(self._work_dir, 'output_exporter.hdf5') entry_names = ["entry1", "entry2", "entry3"] diff --git a/tests/utils/test_grid.py b/tests/utils/test_grid.py index 675780495..fa1c82d9c 100644 --- a/tests/utils/test_grid.py +++ b/tests/utils/test_grid.py @@ -1,101 +1,48 @@ import h5py import numpy as np -from deeprank2.query import (ProteinProteinInterfaceAtomicQuery, - ProteinProteinInterfaceResidueQuery) -from deeprank2.utils.grid import Grid, GridSettings, MapMethod - - -def test_residue_grid_orientation(): - - coord_error_margin = 1.0 # Angstrom - - points_counts = [10, 10, 10] - grid_sizes = [30.0, 30.0, 30.0] - - # Extract data from original deeprank's preprocessed file. - with h5py.File("tests/data/hdf5/original-deeprank-1ak4.hdf5", 'r') as data_file: - grid_points_group = data_file["1AK4/grid_points"] - - target_xs = grid_points_group["x"][()] - target_ys = grid_points_group["y"][()] - target_zs = grid_points_group["z"][()] - - target_center = grid_points_group["center"][()] - - # Build the atomic graph, according to this repository's code. - pdb_path = "tests/data/pdb/1ak4/1ak4.pdb" - chain_id1 = "C" - chain_id2 = "D" - distance_cutoff = 8.5 - - query = ProteinProteinInterfaceResidueQuery(pdb_path, chain_id1, chain_id2, - distance_cutoff=distance_cutoff) - - graph = query.build([]) - - # Make a grid from the graph. 
- map_method = MapMethod.FAST_GAUSSIAN - grid_settings = GridSettings(points_counts, grid_sizes) - grid = Grid("test_grid", graph.center, grid_settings) - graph.map_to_grid(grid, map_method) - - assert np.all(np.abs(target_center - grid.center) < coord_error_margin), f"\n{grid.center} != \n{target_center}" - - # Orientation must be the same as in the original deeprank. - # Check that the grid point coordinates are the same. - assert grid.xs.shape == target_xs.shape - assert np.all(np.abs(grid.xs - target_xs) < coord_error_margin), f"\n{grid.xs} != \n{target_xs}" - - assert grid.ys.shape == target_ys.shape - assert np.all(np.abs(grid.ys - target_ys) < coord_error_margin), f"\n{grid.ys} != \n{target_ys}" - - assert grid.zs.shape == target_zs.shape - assert np.all(np.abs(grid.zs - target_zs) < coord_error_margin), f"\n{grid.zs} != \n{target_zs}" +from deeprank2.query import VALID_RESOLUTIONS, ProteinProteinInterfaceQuery +from deeprank2.utils.grid import Grid, GridSettings, MapMethod -def test_atomic_grid_orientation(): +def test_grid_orientation(): coord_error_margin = 1.0 # Angstrom - points_counts = [10, 10, 10] grid_sizes = [30.0, 30.0, 30.0] # Extract data from original deeprank's preprocessed file. with h5py.File("tests/data/hdf5/original-deeprank-1ak4.hdf5", 'r') as data_file: grid_points_group = data_file["1AK4/grid_points"] - target_xs = grid_points_group["x"][()] target_ys = grid_points_group["y"][()] target_zs = grid_points_group["z"][()] - target_center = grid_points_group["center"][()] - # Build the atomic graph, according to this repository's code. - pdb_path = "tests/data/pdb/1ak4/1ak4.pdb" - chain_id1 = "C" - chain_id2 = "D" - distance_cutoff = 8.5 - - query = ProteinProteinInterfaceAtomicQuery(pdb_path, chain_id1, chain_id2, - distance_cutoff=distance_cutoff) - - graph = query.build([]) - - # Make a grid from the graph. 
- map_method = MapMethod.FAST_GAUSSIAN - grid_settings = GridSettings(points_counts, grid_sizes) - grid = Grid("test_grid", graph.center, grid_settings) - graph.map_to_grid(grid, map_method) - - assert np.all(np.abs(target_center - grid.center) < coord_error_margin), f"\n{grid.center} != \n{target_center}" - - # Orientation must be the same as in the original deeprank. - # Check that the grid point coordinates are the same. - assert grid.xs.shape == target_xs.shape - assert np.all(np.abs(grid.xs - target_xs) < coord_error_margin), f"\n{grid.xs} != \n{target_xs}" - - assert grid.ys.shape == target_ys.shape - assert np.all(np.abs(grid.ys - target_ys) < coord_error_margin), f"\n{grid.ys} != \n{target_ys}" - - assert grid.zs.shape == target_zs.shape - assert np.all(np.abs(grid.zs - target_zs) < coord_error_margin), f"\n{grid.zs} != \n{target_zs}" + for resolution in VALID_RESOLUTIONS: + print(f"Testing for {resolution} level grids.") # in case pytest fails, this will be printed. + query = ProteinProteinInterfaceQuery( + pdb_path="tests/data/pdb/1ak4/1ak4.pdb", + resolution=resolution, + chain_ids=['C', 'D'], + distance_cutoff=8.5, + ) + graph = query.build([]) + + # Make a grid from the graph. + map_method = MapMethod.FAST_GAUSSIAN + grid_settings = GridSettings(points_counts, grid_sizes) + grid = Grid("test_grid", graph.center, grid_settings) + graph.map_to_grid(grid, map_method) + + assert np.all(np.abs(target_center - grid.center) < coord_error_margin), f"\n{grid.center} != \n{target_center}" + + # Orientation must be the same as in the original deeprank. + # Check that the grid point coordinates are the same. 
+ assert grid.xs.shape == target_xs.shape + assert np.all(np.abs(grid.xs - target_xs) < coord_error_margin), f"\n{grid.xs} != \n{target_xs}" + + assert grid.ys.shape == target_ys.shape + assert np.all(np.abs(grid.ys - target_ys) < coord_error_margin), f"\n{grid.ys} != \n{target_ys}" + + assert grid.zs.shape == target_zs.shape + assert np.all(np.abs(grid.zs - target_zs) < coord_error_margin), f"\n{grid.zs} != \n{target_zs}" diff --git a/tutorials/data_generation_ppi.ipynb b/tutorials/data_generation_ppi.ipynb index 812d1bfbe..40a765eb9 100644 --- a/tutorials/data_generation_ppi.ipynb +++ b/tutorials/data_generation_ppi.ipynb @@ -71,7 +71,7 @@ "import matplotlib.image as img\n", "import matplotlib.pyplot as plt\n", "from deeprank2.query import QueryCollection\n", - "from deeprank2.query import ProteinProteinInterfaceResidueQuery, ProteinProteinInterfaceAtomicQuery\n", + "from deeprank2.query import ProteinProteinInterfaceQuery\n", "from deeprank2.features import components, contact\n", "from deeprank2.utils.grid import GridSettings, MapMethod\n", "from deeprank2.dataset import GraphDataset" @@ -155,14 +155,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "For each protein-protein complex, so for each data point, a query can be created and added to the `QueryCollection` object, to be processed later on. Different types of queries exist, based on the molecular resolution needed:\n", - "\n", - "- In a `ProteinProteinInterfaceResidueQuery` each node represents one amino acid residue.\n", - "- In a `ProteinProteinInterfaceAtomicQuery` each node represents one atom within the amino acid residues.\n", + "For each protein-protein complex, so for each data point, a query can be created and added to the `QueryCollection` object, to be processed later on.\n", "\n", "A query takes as inputs:\n", "\n", "- A `.pdb` file, representing the protein-protein structural complex.\n", + "- The resolution (`\"residue\"` or `\"atom\"`), i.e.
whether each node should represent an amino acid residue or an atom.\n", "- The ids of the two chains composing the complex. In our use case, \"M\" indicates the MHC protein chain and \"P\" the peptide chain.\n", "- The distance cutoff, which represents the maximum distance in Ångström between two interacting residues/atoms of the two proteins.\n", "- The target values associated with the query. For each query/data point, in the use case demonstrated in this tutorial will add two targets: \"BA\" and \"binary\". The first represents the actual BA value of the complex in nM, while the second represents its binary mapping, being 0 (BA > 500 nM) a not-binding complex and 1 (BA <= 500 nM) a binding one.\n", @@ -174,7 +172,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Residue-level PPI: `ProteinProteinInterfaceResidueQuery`" + "## Residue-level PPIs using `ProteinProteinInterfaceQuery`" ] }, { @@ -191,10 +189,10 @@ "count = 0\n", "for i in range(len(pdb_files)):\n", "\tqueries.add(\n", - "\t\tProteinProteinInterfaceResidueQuery(\n", - "\t\t\tpdb_path = pdb_files[i], \n", - "\t\t\tchain_id1 = \"M\",\n", - "\t\t\tchain_id2 = \"P\",\n", + "\t\tProteinProteinInterfaceQuery(\n", + "\t\t\tpdb_path = pdb_files[i],\n", + "\t\t\tresolution = \"residue\",\n", + "\t\t\tchain_ids = [\"M\", \"P\"],\n", "\t\t\tdistance_cutoff = interface_distance_cutoff,\n", "\t\t\ttargets = {\n", "\t\t\t\t'binary': int(float(bas[i]) <= 500), # binary target value\n", @@ -405,7 +403,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Atomic-level PPI: `ProteinProteinInterfaceAtomicQuery`\n", + "## Atomic-level PPIs using `ProteinProteinInterfaceQuery`\n", "\n", "Graphs can also be generated at an atomic resolution, very similarly to what has just been done for residue-level.
" ] @@ -424,10 +422,10 @@ "count = 0\n", "for i in range(len(pdb_files)):\n", "\tqueries.add(\n", - "\t\tProteinProteinInterfaceAtomicQuery(\n", - "\t\t\tpdb_path = pdb_files[i], \n", - "\t\t\tchain_id1 = \"M\",\n", - "\t\t\tchain_id2 = \"P\",\n", + "\t\tProteinProteinInterfaceQuery(\n", + "\t\t\tpdb_path = pdb_files[i],\n", + "\t\t\tresolution = \"atom\",\n", + "\t\t\tchain_ids = [\"M\", \"P\"],\n", "\t\t\tdistance_cutoff = interface_distance_cutoff,\n", "\t\t\ttargets = {\n", "\t\t\t\t'binary': int(float(bas[i]) <= 500), # binary target value\n", @@ -527,7 +525,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.16" + "version": "3.10.12" }, "orig_nbformat": 4 }, diff --git a/tutorials/data_generation_srv.ipynb b/tutorials/data_generation_srv.ipynb index accf74516..1fa693011 100644 --- a/tutorials/data_generation_srv.ipynb +++ b/tutorials/data_generation_srv.ipynb @@ -71,7 +71,7 @@ "import matplotlib.image as img\n", "import matplotlib.pyplot as plt\n", "from deeprank2.query import QueryCollection\n", - "from deeprank2.query import SingleResidueVariantResidueQuery, SingleResidueVariantAtomicQuery\n", + "from deeprank2.query import SingleResidueVariantQuery\n", "from deeprank2.domain.aminoacidlist import (alanine, arginine, asparagine,\n", " serine, glycine, leucine, aspartate,\n", " glutamine, glutamate, lysine, phenylalanine, histidine,\n", @@ -176,12 +176,11 @@ "source": [ "For each SRV, so for each data point, a query can be created and added to the `QueryCollection` object, to be processed later on.
Different types of queries exist, based on the molecular resolution needed:\n", "\n", - "- In a `SingleResidueVariantResidueQuery` each node represents one amino acid residue.\n", - "- In a `SingleResidueVariantAtomicQuery` each node represents one atom within the amino acid residues.\n", "\n", "A query takes as inputs:\n", "\n", "- A `.pdb` file, representing the protein structure containing the SRV.\n", + "- The resolution (`\"residue\"` or `\"atom\"`), i.e. whether each node should represent an amino acid residue or an atom.\n", "- The chain id of the SRV.\n", "- The residue number of the missense mutation.\n", "- The insertion code, used when two residues have the same numbering. The combination of residue numbering and insertion code defines the unique residue.\n", @@ -198,7 +197,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Residue-level SRV: `SingleResidueVariantResidueQuery`" + "## Residue-level SRV: `SingleResidueVariantQuery`" ] }, { @@ -215,10 +214,11 @@ "print(f'Adding {len(pdb_files)} queries to the query collection ...')\n", "count = 0\n", "for i in range(len(pdb_files)):\n", - "\tqueries.add(SingleResidueVariantResidueQuery(\n", + "\tqueries.add(SingleResidueVariantQuery(\n", "\t\tpdb_path = pdb_files[i],\n", - "\t\tchain_id = \"A\",\n", - "\t\tresidue_number = res_numbers[i],\n", + "\t\tresolution = \"residue\",\n", + "\t\tchain_ids = \"A\",\n", + "\t\tvariant_residue_number = res_numbers[i],\n", "\t\tinsertion_code = None,\n", "\t\twildtype_amino_acid = aa_dict[res_wildtypes[i]],\n", "\t\tvariant_amino_acid = aa_dict[res_variants[i]],\n", @@ -439,7 +439,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Atomic-level SRV: `SingleResidueVariantAtomicQuery`\n", + "## Atomic-level SRV: `SingleResidueVariantQuery`\n", "\n", "Graphs can also be generated at an atomic resolution, very similarly to what has just been done for residue-level.
" ] @@ -458,10 +458,11 @@ "print(f'Adding {len(pdb_files)} queries to the query collection ...')\n", "count = 0\n", "for i in range(len(pdb_files)):\n", - "\tqueries.add(SingleResidueVariantAtomicQuery(\n", + "\tqueries.add(SingleResidueVariantQuery(\n", "\t\tpdb_path = pdb_files[i],\n", - "\t\tchain_id = \"A\",\n", - "\t\tresidue_number = res_numbers[i],\n", + "\t\tresolution = \"atom\",\n", + "\t\tchain_ids = \"A\",\n", + "\t\tvariant_residue_number = res_numbers[i],\n", "\t\tinsertion_code = None,\n", "\t\twildtype_amino_acid = aa_dict[res_wildtypes[i]],\n", "\t\tvariant_amino_acid = aa_dict[res_variants[i]],\n", @@ -563,7 +564,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.16" + "version": "3.10.12" }, "orig_nbformat": 4 }, diff --git a/tutorials/training.ipynb b/tutorials/training.ipynb index f93108a62..0ed71ed1d 100644 --- a/tutorials/training.ipynb +++ b/tutorials/training.ipynb @@ -764,7 +764,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.16" + "version": "3.10.12" }, "orig_nbformat": 4 },