Skip to content

Commit

Permalink
define separate parent and child build methods
Browse files Browse the repository at this point in the history
  • Loading branch information
DaniBodor committed Sep 17, 2023
1 parent 8c82e97 commit 6a7adec
Showing 1 changed file with 47 additions and 23 deletions.
70 changes: 47 additions & 23 deletions deeprank2/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,15 @@ def _set_graph_targets(self, graph: Graph):
for target_name, target_data in self.targets.items():
graph.targets[target_name] = target_data

def _load_structure(self, pssm_required: bool) -> PDBStructure:
def _load_structure(self) -> PDBStructure:
"""Build PDBStructure objects from pdb and pssm data."""
pdb = pdb2sql.pdb2sql(self.pdb_path)
try:
structure = get_structure(pdb, self.model_id)
finally:
pdb._close() # pylint: disable=protected-access
# read the pssm
if pssm_required:
if self._pssm_required:
self._load_pssm_data(structure)

return structure
Expand Down Expand Up @@ -171,7 +171,32 @@ def model_id(self, value: str):
def __repr__(self) -> str:
return f"{type(self)}({self.get_query_id()})"

def build(self, feature_modules: list[ModuleType]) -> Graph:
def build(
self,
feature_modules: list[str],
) -> Graph:
"""Builds the graph from the .PDB structure.
Args:
feature_modules (list[str]): the feature modules used to build the graph.
These must be filenames existing inside `deeprank2.features` subpackage.
Returns:
:class:`Graph`: The resulting :class:`Graph` object with all the features and targets.
"""

try:
feature_modules = [importlib.import_module('deeprank2.features.' + name) for name in feature_modules]
except TypeError:
if isinstance(feature_modules, ModuleType):
feature_modules = [feature_modules]
elif isinstance(feature_modules, str):
feature_modules = [importlib.import_module('deeprank2.features.' + feature_modules)]
self._pssm_required = conservation in feature_modules

return self._child_build(feature_modules)

def _child_build(self, feature_modules: list[ModuleType]):
raise NotImplementedError("Must be defined in child classes.")
def get_query_id(self) -> str:
raise NotImplementedError("Must be defined in child classes.")
Expand Down Expand Up @@ -212,30 +237,25 @@ def get_query_id(self) -> str:
+ f"{self.wildtype_amino_acid.name}->{self.variant_amino_acid.name}:{self.model_id}"
)

def build(
def _child_build(
self,
feature_modules: list[ModuleType] | ModuleType,
feature_modules: list[ModuleType],
) -> Graph:
#TODO: check how much of this is common with PPI and move it to parent class
"""Builds the graph from the .PDB structure.
Args:
feature_modules (list[ModuleType]): Each must implement the :py:func:`add_features` function.
feature_modules (list[str]): the feature modules used to build the graph.
These must be filenames existing inside `deeprank2.features` subpackage.
Returns:
:class:`Graph`: The resulting :class:`Graph` object with all the features and targets.
"""

# load .PDB structure
if isinstance(feature_modules, list):
pssm_required = conservation in feature_modules
else:
pssm_required = conservation == feature_modules
feature_modules = [feature_modules]
structure: PDBStructure = self._load_structure(pssm_required)
structure: PDBStructure = self._load_structure()

# find the variant residue and its surroundings
variant_residue = None
variant_residue: Residue = None
for residue in structure.get_chain(self.variant_chain_id).residues:
residue: Residue
if (
Expand Down Expand Up @@ -272,7 +292,7 @@ def build(
raise NotImplementedError(f"No function exists to build graphs with resolution of {self.resolution}.")
graph.center = get_residue_center(variant_residue)

# add data to the graph
# add target and feature data to the graph
self._set_graph_targets(graph)
for feature_module in feature_modules:
feature_module.add_features(self.pdb_path, graph, variant)
Expand Down Expand Up @@ -305,9 +325,9 @@ def get_query_id(self) -> str:
+ f"{self.chain_ids[0]}-{self.chain_ids[1]}:{self.model_id}"
)

def build(
def _child_build(
self,
feature_modules: list[ModuleType] | ModuleType,
feature_modules: list[ModuleType],
) -> Graph:
#TODO: check how much of this is common with SRV and move it to parent class
"""Builds the graph from the .PDB structure.
Expand All @@ -319,7 +339,12 @@ def build(
:class:`Graph`: The resulting :class:`Graph` object with all the features and targets.
"""

contact_atoms = get_contact_atoms(self.pdb_path, self.chain_ids, self.distance_cutoff)
# find the contact atoms
contact_atoms = get_contact_atoms(
self.pdb_path,
self.chain_ids,
self.distance_cutoff
)
if len(contact_atoms) == 0:
raise ValueError("no contact atoms found")

Expand All @@ -341,9 +366,7 @@ def build(
#TODO: unify with the way pssms are read for srv queries
structure = contact_atoms[0].residue.chain.model

if not isinstance(feature_modules, list):
feature_modules = [feature_modules]
if conservation in feature_modules:
if self._pssm_required:
self._load_pssm_data(structure)

# add the features
Expand Down Expand Up @@ -426,8 +449,9 @@ def _process_one_query(self, query: DeepRankQuery):
# TODO: Maybe make exception catching optional, because I think it would be good to raise the error by default.
output_path = f"{self._prefix}-{os.getpid()}.hdf5"
#TODO: move the line below into generic build method so we can pass a list of strings here.
feature_modules = [importlib.import_module('deeprank2.features.' + name) for name in self._feature_modules]
graph = query.build(feature_modules)
# feature_modules = [importlib.import_module('deeprank2.features.' + name) for name in self._feature_modules]
# graph = query.build(feature_modules)
graph = query.build(self._feature_modules)
graph.write_to_hdf5(output_path)

if self._grid_settings is not None and self._grid_map_method is not None:
Expand Down

0 comments on commit 6a7adec

Please sign in to comment.