Skip to content

Commit

Permalink
minor improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
DaniBodor committed Oct 31, 2023
1 parent ff24a4f commit ab408a5
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 25 deletions.
24 changes: 3 additions & 21 deletions deeprank2/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,20 +187,11 @@ def build(
Returns:
:class:`Graph`: The resulting :class:`Graph` object with all the features and targets.
"""
# TODO: Should this be an internal method, as it is only called upon by QueryCollection and not by user?

# TODO: I would prefer if `_set_feature_modules` would return a list of `ModuleType` instead of `str`,
# but this leads to an exception in the pool function in `QueryCollection.process`, that I don't know how to solve.
# Tests are currently passing `ModuleType`s directly, which is clearer than passing strings
# and is the reason I am allowing both input types.

feature_modules = [importlib.import_module('deeprank2.features.' + module)
if isinstance(module, str) else module
for module in feature_modules]
self._pssm_required = conservation in feature_modules

# TODO: my gut feeling is that the building can be unified further, but it is
# not trivial. I will leave it like this for now.
graph = self._build_helper()

# add target and feature data to the graph
Expand Down Expand Up @@ -373,18 +364,16 @@ def _build_helper(self) -> Graph:
if self.resolution == 'atom':
graph = build_atomic_graph(contact_atoms, self.get_query_id(), self.distance_cutoff)
elif self.resolution == 'residue':
residues_selected = {atom.residue for atom in contact_atoms}
graph = build_residue_graph(list(residues_selected), self.get_query_id(), self.distance_cutoff)
residues_selected = list({atom.residue for atom in contact_atoms})
graph = build_residue_graph(residues_selected, self.get_query_id(), self.distance_cutoff)
else:
raise NotImplementedError(f"No function exists to build graphs with resolution of {self.resolution}.")
graph.center = np.mean([atom.position for atom in contact_atoms], axis=0)

structure = contact_atoms[0].residue.chain.model

if self._pssm_required:
self._load_pssm_data(structure)

graph.center = np.mean([atom.position for atom in contact_atoms], axis=0)
return graph


Expand Down Expand Up @@ -441,15 +430,8 @@ def export_dict(self, dataset_path: str):
@property
def queries(self) -> list[DeepRankQuery]:
"""The list of queries added to the collection."""

return self._queries

@property
def ids_count(self) -> list[DeepRankQuery]:
"""The list of queries added to the collection."""

return self._ids_count

def __contains__(self, query: DeepRankQuery) -> bool:
return query in self._queries

Expand Down Expand Up @@ -534,7 +516,7 @@ def process( # pylint: disable=too-many-arguments, too-many-locals, dangerous-de
self._grid_map_method = grid_map_method

if grid_augmentation_count < 0:
raise ValueError(f"`grid_augmentation_count` may not be negative, but was given as {grid_augmentation_count}")
raise ValueError(f"`grid_augmentation_count` cannot be negative, but was given as {grid_augmentation_count}")
self._grid_augmentation_count = grid_augmentation_count

_log.info(f'Creating pool function to process {len(self)} queries...')
Expand Down
7 changes: 4 additions & 3 deletions tests/test_querycollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,8 @@ def test_querycollection_duplicates_add():
model_ids.append(query.model_id)
model_ids.sort()

# pylint: disable=protected-access
assert model_ids == ['1ATN_1w', '1ATN_1w_2', '1ATN_1w_3', '1ATN_2w', '1ATN_2w_2', '1ATN_3w']
assert queries.ids_count['residue-ppi:A-B:1ATN_1w'] == 3
assert queries.ids_count['residue-ppi:A-B:1ATN_2w'] == 2
assert queries.ids_count['residue-ppi:A-B:1ATN_3w'] == 1
assert queries._ids_count['residue-ppi:A-B:1ATN_1w'] == 3
assert queries._ids_count['residue-ppi:A-B:1ATN_2w'] == 2
assert queries._ids_count['residue-ppi:A-B:1ATN_3w'] == 1
2 changes: 1 addition & 1 deletion tests/utils/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from deeprank2.utils.grid import Grid, GridSettings, MapMethod


def test_residue_grid_orientation():
def test_grid_orientation():
coord_error_margin = 1.0 # Angstrom
points_counts = [10, 10, 10]
grid_sizes = [30.0, 30.0, 30.0]
Expand Down

0 comments on commit ab408a5

Please sign in to comment.