Skip to content

Commit

Permalink
refactor child specific helper functions of build
Browse files Browse the repository at this point in the history
  • Loading branch information
DaniBodor committed Sep 17, 2023
1 parent 6a7adec commit a1f99f3
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 66 deletions.
78 changes: 29 additions & 49 deletions deeprank2/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class DeepRankQuery:

def __post_init__(self):
self._model_id = os.path.splitext(os.path.basename(self.pdb_path))[0]
self.variant = None # not used for PPI, overwritten for SRV

if self.resolution not in VALID_RESOLUTIONS:
raise ValueError(f"Invalid resolution given ({self.resolution}). Must be one of {VALID_RESOLUTIONS}")
Expand Down Expand Up @@ -173,7 +174,7 @@ def __repr__(self) -> str:

def build(
self,
feature_modules: list[str],
feature_modules: list[str | ModuleType],
) -> Graph:
"""Builds the graph from the .PDB structure.
Expand All @@ -184,19 +185,30 @@ def build(
Returns:
:class:`Graph`: The resulting :class:`Graph` object with all the features and targets.
"""
# TODO: Should this be an internal method, since it is only called by QueryCollection and not by the user?

try:
feature_modules = [importlib.import_module('deeprank2.features.' + name) for name in feature_modules]
except TypeError:
if isinstance(feature_modules, ModuleType):
feature_modules = [feature_modules]
elif isinstance(feature_modules, str):
feature_modules = [importlib.import_module('deeprank2.features.' + feature_modules)]
# TODO: I would prefer `_set_feature_modules` to return a list of `ModuleType` instead of `str`,
# but this leads to an exception in the pool function in `QueryCollection.process` that I don't know how to solve.
# Tests currently pass `ModuleType`s directly, which is clearer than passing strings,
# and that is the reason both input types are allowed.

feature_modules = [importlib.import_module('deeprank2.features.' + module)
if isinstance(module, str) else module
for module in feature_modules]
self._pssm_required = conservation in feature_modules

return self._child_build(feature_modules)
# TODO: my gut feeling is that the building can be unified further, but it is
# not trivial. I will leave it like this for now.
graph = self._build_helper()

# add target and feature data to the graph
self._set_graph_targets(graph)
for feature_module in feature_modules:
feature_module.add_features(self.pdb_path, graph, self.variant)

return graph

def _child_build(self, feature_modules: list[ModuleType]):
def _build_helper(self) -> Graph:
raise NotImplementedError("Must be defined in child classes.")
def get_query_id(self) -> str:
raise NotImplementedError("Must be defined in child classes.")
Expand Down Expand Up @@ -237,27 +249,19 @@ def get_query_id(self) -> str:
+ f"{self.wildtype_amino_acid.name}->{self.variant_amino_acid.name}:{self.model_id}"
)

def _child_build(
self,
feature_modules: list[ModuleType],
) -> Graph:
"""Builds the graph from the .PDB structure.
Args:
feature_modules (list[str]): the feature modules used to build the graph.
These must be filenames existing inside `deeprank2.features` subpackage.
def _build_helper(self) -> Graph:
"""Helper function to build a graph for SRV queries.
Returns:
:class:`Graph`: The resulting :class:`Graph` object with all the features and targets.
"""

# load .PDB structure
structure: PDBStructure = self._load_structure()
structure = self._load_structure()

# find the variant residue and its surroundings
variant_residue: Residue = None
for residue in structure.get_chain(self.variant_chain_id).residues:
residue: Residue
if (
residue.number == self.variant_residue_number
and residue.insertion_code == self.insertion_code
Expand All @@ -268,7 +272,7 @@ def _child_build(
raise ValueError(
f"Residue not found in {self.pdb_path}: {self.variant_chain_id} {self.residue_id}"
)
variant = SingleResidueVariant(variant_residue, self.variant_amino_acid)
self.variant = SingleResidueVariant(variant_residue, self.variant_amino_acid)
residues = get_surrounding_residues(structure, variant_residue, self.radius)

# build the graph
Expand All @@ -287,16 +291,10 @@ def _child_build(
graph = build_atomic_graph(atoms, self.get_query_id(), self.distance_cutoff)
#TODO: check if this works with a set instead of a list


else:
raise NotImplementedError(f"No function exists to build graphs with resolution of {self.resolution}.")
graph.center = get_residue_center(variant_residue)

# add target and feature data to the graph
self._set_graph_targets(graph)
for feature_module in feature_modules:
feature_module.add_features(self.pdb_path, graph, variant)

return graph


Expand Down Expand Up @@ -325,21 +323,15 @@ def get_query_id(self) -> str:
+ f"{self.chain_ids[0]}-{self.chain_ids[1]}:{self.model_id}"
)

def _child_build(
self,
feature_modules: list[ModuleType],
) -> Graph:
def _build_helper(self) -> Graph:
#TODO: check how much of this is common with SRV and move it to parent class
"""Builds the graph from the .PDB structure.
Args:
feature_modules (list[ModuleType]): Each must implement the :py:func:`add_features` function.
"""Helper function to build a graph for PPI queries.
Returns:
:class:`Graph`: The resulting :class:`Graph` object with all the features and targets.
"""

# find the contact atoms
# find the atoms near the contact interface
contact_atoms = get_contact_atoms(
self.pdb_path,
self.chain_ids,
Expand All @@ -359,20 +351,11 @@ def _child_build(
raise NotImplementedError(f"No function exists to build graphs with resolution of {self.resolution}.")
graph.center = np.mean([atom.position for atom in contact_atoms], axis=0)

# add data to the graph
self._set_graph_targets(graph)

# read the pssm
#TODO: unify with the way pssms are read for srv queries
structure = contact_atoms[0].residue.chain.model

if self._pssm_required:
self._load_pssm_data(structure)

# add the features
for feature_module in feature_modules:
feature_module.add_features(self.pdb_path, graph)

graph.center = np.mean([atom.position for atom in contact_atoms], axis=0)
return graph

Expand Down Expand Up @@ -448,9 +431,6 @@ def _process_one_query(self, query: DeepRankQuery):
try:
# TODO: Maybe make exception catching optional, because I think it would be good to raise the error by default.
output_path = f"{self._prefix}-{os.getpid()}.hdf5"
#TODO: move the line below into generic build method so we can pass a list of strings here.
# feature_modules = [importlib.import_module('deeprank2.features.' + name) for name in self._feature_modules]
# graph = query.build(feature_modules)
graph = query.build(self._feature_modules)
graph.write_to_hdf5(output_path)

Expand Down
34 changes: 17 additions & 17 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,15 +373,15 @@ def test_incorrect_pssm_order():

# check that error is thrown for incorrect pssm
with pytest.raises(ValueError):
_ = q.build(conservation)
_ = q.build([conservation])

# no error if conservation module is not used
_ = q.build(components)
_ = q.build([components])

# check that error suppression works
with pytest.warns(UserWarning):
q.suppress_pssm_errors = True
_ = q.build(conservation)
_ = q.build([conservation])


def test_incomplete_pssm():
Expand All @@ -396,15 +396,15 @@ def test_incomplete_pssm():
)

with pytest.raises(ValueError):
_ = q.build(conservation)
_ = q.build([conservation])

# no error if conservation module is not used
_ = q.build(components)
_ = q.build([components])

# check that error suppression works
with pytest.warns(UserWarning):
q.suppress_pssm_errors = True
_ = q.build(conservation)
_ = q.build([conservation])


def test_no_pssm_provided():
Expand All @@ -424,12 +424,12 @@ def test_no_pssm_provided():
)

with pytest.raises(ValueError):
_ = q_empty_dict.build(conservation)
_ = q_not_provided.build(conservation)
_ = q_empty_dict.build([conservation])
_ = q_not_provided.build([conservation])

# no error if conservation module is not used
_ = q_empty_dict.build(components)
_ = q_not_provided.build(components)
_ = q_empty_dict.build([components])
_ = q_not_provided.build([components])


def test_incorrect_pssm_provided():
Expand All @@ -455,12 +455,12 @@ def test_incorrect_pssm_provided():
)

with pytest.raises(FileNotFoundError):
_ = q_non_existing.build(conservation)
_ = q_missing.build(conservation)
_ = q_non_existing.build([conservation])
_ = q_missing.build([conservation])

# no error if conservation module is not used
_ = q_non_existing.build(components)
_ = q_missing.build(components)
_ = q_non_existing.build([components])
_ = q_missing.build([components])


def test_variant_query_multiple_chains():
Expand All @@ -480,13 +480,13 @@ def test_variant_query_multiple_chains():

# at radius 10, chain B is included in graph
# no error without conservation module
graph = q.build(components)
graph = q.build([components])
assert 'B' in graph.get_all_chains()
# if we rebuild the graph with conservation module it should fail
with pytest.raises(FileNotFoundError):
_ = q.build(conservation)
_ = q.build([conservation])

# at radius 7, chain B is not included in graph
q.radius = 7.0
graph = q.build(conservation)
graph = q.build([conservation])
assert 'B' not in graph.get_all_chains()

0 comments on commit a1f99f3

Please sign in to comment.