diff --git a/deeprank2/query.py b/deeprank2/query.py
index 0ea73d0ab..9b3d54879 100644
--- a/deeprank2/query.py
+++ b/deeprank2/query.py
@@ -3,6 +3,7 @@
 import os
 import pickle
 import pkgutil
+import re
 import warnings
 from dataclasses import MISSING, dataclass, field, fields
 from functools import partial
@@ -354,41 +355,38 @@ def build(
 
 
 class QueryCollection:
-    """
-    Represents the collection of data queries.
-    Queries can be saved as a dictionary to easily navigate through their data.
+    """Represents the collection of data queries.
+
+    Queries can be saved as a dictionary to easily navigate through their data.
     """
 
     def __init__(self):
         self._queries = []
-        self.cpu_count = None
-        self.ids_count = {}
+        self._ids_count = {}
 
-    def add(self, query: DeepRankQuery, verbose: bool = False, warn_duplicate: bool = True):
-        """
-        Adds a new query to the collection.
+    def add(
+        self,
+        query: DeepRankQuery,
+        verbose: bool = False,
+        warn_duplicate: bool = True,
+    ):
+        """Add a new query to the collection.
 
         Args:
-            query(:class:`DeepRankQuery`): Must be a :class:`DeepRankQuery` object, either :class:`ProteinProteinInterfaceResidueQuery` or
-                :class:`SingleResidueVariantAtomicQuery`.
-            verbose(bool, optional): For logging query IDs added, defaults to False.
-            warn_duplicate (bool): Log a warning before renaming if a duplicate query is identified.
-
+            query (:class:`DeepRankQuery`): The `DeepRankQuery` to add to the collection.
+            verbose (bool): Whether to log the ID of the added query. Defaults to `False`.
+            warn_duplicate (bool): Whether to log a warning before renaming if a duplicate query is identified. Defaults to `True`.
         """
 
         query_id = query.get_query_id()
-
         if verbose:
             _log.info(f'Adding query with ID {query_id}.')
 
-        if query_id not in self.ids_count:
-            self.ids_count[query_id] = 1
+        if query_id not in self._ids_count:
+            self._ids_count[query_id] = 1
         else:
-            self.ids_count[query_id] += 1
-            new_id = query.model_id + "_" + str(self.ids_count[query_id])
+            self._ids_count[query_id] += 1
+            new_id = query.model_id + "_" + str(self._ids_count[query_id])
             query.model_id = new_id
-
             if warn_duplicate:
                 _log.warning(f'DeepRankQuery with ID {query_id} has already been added to the collection. Renaming it to {query.get_query_id()}.')
 
@@ -408,6 +406,11 @@ def queries(self) -> list[DeepRankQuery]:
         """The list of queries added to the collection."""
         return self._queries
 
+    @property
+    def ids_count(self) -> dict[str, int]:
+        """The number of times each query ID has been added to the collection."""
+        return self._ids_count
+
     def __contains__(self, query: DeepRankQuery) -> bool:
         return query in self._queries
 
@@ -417,68 +420,59 @@ def __iter__(self) -> Iterator[DeepRankQuery]:
     def __len__(self) -> int:
         return len(self._queries)
 
-    def _process_one_query( # pylint: disable=too-many-arguments
-        self,
-        prefix: str,
-        feature_names: list[str],
-        grid_settings: GridSettings | None,
-        grid_map_method: MapMethod | None,
-        grid_augmentation_count: int,
-        query: DeepRankQuery
-    ):
-
+    def _process_one_query(self, query: DeepRankQuery):
+        """Process a single query into a process-specific HDF5 file, since only one process may access an HDF5 file at a time."""
         try:
-            # because only one process may access an hdf5 file at a time:
-            output_path = f"{prefix}-{os.getpid()}.hdf5"
-
-            feature_modules = [
-                importlib.import_module('deeprank2.features.' + name) for name in feature_names]
-
+            # TODO: Maybe make exception catching optional, because it would be good to raise the error by default.
+            output_path = f"{self._prefix}-{os.getpid()}.hdf5"
+            # TODO: move the line below into the generic build method, so we can pass a list of strings here.
+            feature_modules = [importlib.import_module('deeprank2.features.' + name) for name in self._feature_modules]
 
             graph = query.build(feature_modules)
             graph.write_to_hdf5(output_path)
 
-            if grid_settings is not None and grid_map_method is not None:
-                graph.write_as_grid_to_hdf5(output_path, grid_settings, grid_map_method)
-
-                for _ in range(grid_augmentation_count):
+            if self._grid_settings is not None and self._grid_map_method is not None:
+                graph.write_as_grid_to_hdf5(output_path, self._grid_settings, self._grid_map_method)
+                for _ in range(self._grid_augmentation_count):
                     # repeat with random augmentation
                     axis, angle = pdb2sql.transform.get_rot_axis_angle(randrange(100))
                     augmentation = Augmentation(axis, angle)
-                    graph.write_as_grid_to_hdf5(output_path, grid_settings, grid_map_method, augmentation)
-
-            return None
+                    graph.write_as_grid_to_hdf5(output_path, self._grid_settings, self._grid_map_method, augmentation)
 
         except (ValueError, AttributeError, KeyError, TimeoutError) as e:
             _log.warning(f'\nGraph/DeepRankQuery with ID {query.get_query_id()} ran into an Exception ({e.__class__.__name__}: {e}),'
                          ' and it has not been written to the hdf5 file. More details below:')
             _log.exception(e)
-            return None
 
     def process( # pylint: disable=too-many-arguments, too-many-locals, dangerous-default-value
         self,
-        prefix: str | None = None,
-        feature_modules: ModuleType | list[ModuleType] | str | list[str] | Literal['all'] = [components, contact],
+        prefix: str = "processed-queries",
+        feature_modules: list[ModuleType | str] | ModuleType | str | Literal['all'] = [components, contact],
        cpu_count: int | None = None,
         combine_output: bool = True,
         grid_settings: GridSettings | None = None,
         grid_map_method: MapMethod | None = None,
         grid_augmentation_count: int = 0
     ) -> list[str]:
-        """
+        """Render queries into graphs (and optionally grids).
+
         Args:
-            prefix (str | None, optional): Prefix for the output files. Defaults to None, which sets ./processed-queries- prefix.
-            feature_modules (ModuleType | list[ModuleType] | str | list[str] | Literal['all'], optional): Features' module or list of features' modules
-                used to generate features (given as string or as an imported module). Each module must implement the :py:func:`add_features` function,
-                and features' modules can be found (or should be placed in case of a custom made feature) in `deeprank2.features` folder.
+            prefix (str, optional): Prefix for naming the output files. Defaults to "processed-queries".
+            feature_modules (list[ModuleType | str] | ModuleType | str | Literal['all'], optional): Feature module or list of feature modules
+                used to generate features (given as a string or as an imported module).
+                Each module must implement the :py:func:`add_features` function, and all feature modules must live inside the `deeprank2.features` folder.
                 If set to 'all', all available modules in `deeprank2.features` are used to generate the features.
-                Defaults to only the basic feature modules `deeprank2.features.components` and `deeprank2.features.contact`.
-            cpu_count (int | None, optional): How many processes to be run simultaneously. Defaults to None, which takes all available cpu cores.
-            combine_output (bool, optional): For combining the HDF5 files generated by the processes. Defaults to True.
+                Defaults to the two primary feature modules `deeprank2.features.components` and `deeprank2.features.contact`.
+            cpu_count (int | None, optional): The number of processes to be run in parallel (i.e. the number of CPUs used), capped by
+                the number of CPUs available to the system.
+                Defaults to None, which uses all available CPU cores.
+            combine_output (bool, optional):
+                If `True` (default): all output files are combined into a single HDF5 file.
+                If `False`: separate HDF5 files are created for each process (i.e. for each CPU used).
             grid_settings (:class:`GridSettings` | None, optional): If valid together with `grid_map_method`, the grid data will be stored as well.
                 Defaults to None.
             grid_map_method (:class:`MapMethod` | None, optional): If valid together with `grid_settings`, the grid data will be stored as well.
                 Defaults to None.
-            grid_augmentation_count (int, optional): Number of grid data augmentations. May not be negative be zero or a positive number.
+            grid_augmentation_count (int, optional): Number of grid data augmentations (must be >= 0).
                 Defaults to 0.
 
         Returns:
@@ -486,45 +480,32 @@ def process( # pylint: disable=too-many-arguments, too-many-locals, dangerous-default-value
         """
 
         # set defaults
-        if prefix is None:
-            prefix = "processed-queries"
-        elif prefix.endswith('.hdf5'):
-            prefix = prefix[:-5]
-        if cpu_count is None:
-            cpu_count = os.cpu_count() # returns the number of CPUs in the system
-        else:
-            cpu_count_system = os.cpu_count()
-            if cpu_count > cpu_count_system:
-                _log.warning(f'\nTried to set {cpu_count} CPUs, but only {cpu_count_system} are present in the system.')
-                cpu_count = cpu_count_system
-        self.cpu_count = cpu_count
-        _log.info(f'\nNumber of CPUs for processing the queries set to: {self.cpu_count}.')
+        self._prefix = "processed-queries" if not prefix else re.sub(r'\.hdf5$', '', prefix) # strip the extension if present
+        max_cpus = os.cpu_count()
+        self._cpu_count = max_cpus if cpu_count is None else min(cpu_count, max_cpus)
+        if cpu_count and self._cpu_count < cpu_count:
+            _log.warning(f'\nTried to set {cpu_count} CPUs, but only {max_cpus} are present in the system.')
+        _log.info(f'\nNumber of CPUs for processing the queries set to: {self._cpu_count}.')
 
-        if feature_modules == 'all':
-            feature_names = [modname for _, modname, _ in pkgutil.iter_modules(deeprank2.features.__path__)]
-        elif isinstance(feature_modules, list):
-            feature_names = [os.path.basename(m.__file__)[:-3] if isinstance(m,ModuleType)
-                             else m.replace('.py','') for m in feature_modules]
-        elif isinstance(feature_modules, ModuleType):
-            feature_names = [os.path.basename(feature_modules.__file__)[:-3]]
-        elif isinstance(feature_modules, str):
-            feature_names = [feature_modules.replace('.py','')]
-        else:
-            raise ValueError(f'Feature_modules has received an invalid input type: {type(feature_modules)}.')
-        _log.info(f'\nSelected feature modules: {feature_names}.')
+        self._feature_modules = self._set_feature_modules(feature_modules)
+        _log.info(f'\nSelected feature modules: {self._feature_modules}.')
 
-        _log.info(f'Creating pool function to process {len(self.queries)} queries...')
-        pool_function = partial(self._process_one_query, prefix,
-                                feature_names,
-                                grid_settings, grid_map_method, grid_augmentation_count)
+        # TODO: it would be nice if all of the below could be part of a GridSettings object
+        self._grid_settings = grid_settings
+        self._grid_map_method = grid_map_method
 
-        with Pool(self.cpu_count) as pool:
+        if grid_augmentation_count < 0:
+            raise ValueError(f"`grid_augmentation_count` may not be negative, but was given as {grid_augmentation_count}.")
+        self._grid_augmentation_count = grid_augmentation_count
+
+        _log.info(f'Preparing to process {len(self)} queries...')
+        # all per-query state now lives on `self`, so `_process_one_query` can be mapped directly
+        with Pool(self._cpu_count) as pool:
            _log.info('Starting pooling...\n')
-            pool.map(pool_function, self.queries)
+            pool.map(self._process_one_query, self.queries)
 
-        output_paths = glob(f"{prefix}-*.hdf5")
-
+        output_paths = glob(f"{self._prefix}-*.hdf5")
         if combine_output:
             for output_path in output_paths:
-                with h5py.File(f"{prefix}.hdf5",'a') as f_dest, h5py.File(output_path,'r') as f_src:
+                with h5py.File(f"{self._prefix}.hdf5", 'a') as f_dest, h5py.File(output_path, 'r') as f_src:
@@ -535,3 +516,27 @@ def process( # pylint: disable=too-many-arguments, too-many-locals, dangerous-default-value
-            return glob(f"{prefix}.hdf5")
+            return glob(f"{self._prefix}.hdf5")
 
         return output_paths
+
+    def _set_feature_modules(
+        self,
+        feature_modules: list[ModuleType | str] | ModuleType | str | Literal['all']
+    ) -> list[str]:
+        """Convert `feature_modules` to a list of module-name strings, irrespective of the input type.
+
+        Raises:
+            TypeError: if an invalid input type is passed.
+        """
+        if feature_modules == 'all':
+            return [modname for _, modname, _ in pkgutil.iter_modules(deeprank2.features.__path__)]
+        if isinstance(feature_modules, ModuleType):
+            return [os.path.basename(feature_modules.__file__)[:-3]]
+        if isinstance(feature_modules, str):
+            return [re.sub(r'\.py$', '', feature_modules)] # strip the extension if present
+        if isinstance(feature_modules, list):
+            invalid_inputs = [type(el) for el in feature_modules if not isinstance(el, (str, ModuleType))]
+            if invalid_inputs:
+                raise TypeError(f'`feature_modules` contains invalid inputs ({invalid_inputs}). Only `str` and `ModuleType` are accepted.')
+            return [re.sub(r'\.py$', '', m) if isinstance(m, str)
+                    else os.path.basename(m.__file__)[:-3] # for ModuleTypes
+                    for m in feature_modules]
+        raise TypeError(f'`feature_modules` has received an invalid input type: {type(feature_modules)}. Only `str` and `ModuleType` are accepted.')
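
Note for reviewers: a minimal usage sketch of the refactored `QueryCollection` API, not part of the diff. Query construction is elided, since the query classes are untouched by this PR; `some_query` is a placeholder for any `DeepRankQuery` instance.

    from deeprank2.query import QueryCollection
    from deeprank2.features import components, contact

    collection = QueryCollection()

    # Adding the same query twice exercises the duplicate-renaming path in `add`:
    # the second copy is logged (warn_duplicate defaults to True) and its model_id
    # is suffixed with the occurrence count tracked in `ids_count`.
    collection.add(some_query, verbose=True)
    collection.add(some_query)

    # Each worker process writes "my-queries-<pid>.hdf5"; with combine_output=True
    # (the default) these are merged into a single "my-queries.hdf5".
    output_paths = collection.process(
        prefix="my-queries.hdf5",  # the ".hdf5" extension is stripped
        feature_modules=[components, contact],
        cpu_count=4,
    )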
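The new `_set_feature_modules` helper normalizes every accepted input form to a list of module-name strings before the pool workers re-import them. A sketch of the expected equivalences, written against the code above rather than taken from the PR's test suite:

    from deeprank2.features import components
    from deeprank2.query import QueryCollection

    collection = QueryCollection()
    assert collection._set_feature_modules('components') == ['components']
    assert collection._set_feature_modules('components.py') == ['components']  # extension stripped
    assert collection._set_feature_modules(components) == ['components']       # module object
    assert collection._set_feature_modules([components, 'contact']) == ['components', 'contact']
    # 'all' expands to every module found in deeprank2.features; anything else
    # (e.g. an int, or a list containing one) raises TypeError.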