
Commit ea47488
Merge pull request #636 from DeepRank/635_tutorials_dbodor
tutorials: avoid error messages in tutorial
DaniBodor authored Sep 6, 2024
2 parents 128d5f9 + e0b209a commit ea47488
Showing 10 changed files with 919 additions and 878 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/notebooks.yml
@@ -54,9 +54,11 @@ jobs:
       - name: Download the data for the tutorials
         shell: bash -l {0}
         run: |
-          wget https://zenodo.org/records/8349335/files/data_raw.zip
+          wget https://zenodo.org/records/13709906/files/data_raw.zip
           unzip data_raw.zip -d data_raw
           mv data_raw tutorials
+          echo listing files in data_raw:
+          ls tutorials/data_raw
       - name: Run tutorial notebooks
         run: pytest --nbmake tutorials
4 changes: 2 additions & 2 deletions deeprank2/dataset.py
@@ -112,7 +112,7 @@ def _check_and_inherit_train(  # noqa: C901
             for key in data["features_transform"].values():
                 if key["transform"] is None:
                     continue
-                key["transform"] = eval(key["transform"])  # noqa: S307, PGH001
+                key["transform"] = eval(key["transform"])  # noqa: S307
         except pickle.UnpicklingError as e:
             msg = "The path provided to `train_source` is not a valid DeepRank2 pre-trained model."
             raise ValueError(msg) from e
@@ -277,7 +277,7 @@ def _filter_targets(self, grp: h5py.Group) -> bool:
                 for operator_string in [">", "<", "==", "<=", ">=", "!="]:
                     operation = operation.replace(operator_string, f"{target_value}" + operator_string)

-                if not eval(operation):  # noqa: S307, PGH001
+                if not eval(operation):  # noqa: S307
                     return False

             elif target_condition is not None:
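Note on the `# noqa` change in both hunks: ruff deprecated `PGH001` as a duplicate of `S307`, so only the latter code is kept (the same ruff bump to 0.6.3 appears in pyproject.toml below). For context, a minimal sketch of the transform-rehydration pattern this code implements, with illustrative data standing in for the real pickled checkpoint:

```python
# Sketch only: `features_transform` stands in for state loaded from a
# DeepRank2 pre-trained model; the real code unpickles it from disk.
features_transform = {
    "res_size": {"transform": "lambda x: x / 10", "standardize": True},
    "res_mass": {"transform": None, "standardize": False},
}

for key in features_transform.values():
    if key["transform"] is None:
        continue
    # eval() re-creates the callable from its stored source string; S307 is
    # suppressed because the string comes from a trusted checkpoint.
    key["transform"] = eval(key["transform"])  # noqa: S307

print(features_transform["res_size"]["transform"](5.0))  # -> 0.5
```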
57 changes: 36 additions & 21 deletions deeprank2/query.py
@@ -22,7 +22,7 @@
 import deeprank2.features
 from deeprank2.domain.aminoacidlist import convert_aa_nomenclature
 from deeprank2.features import components, conservation, contact
-from deeprank2.molstruct.residue import Residue, SingleResidueVariant
+from deeprank2.molstruct.residue import SingleResidueVariant
 from deeprank2.utils.buildgraph import get_contact_atoms, get_structure, get_surrounding_residues
 from deeprank2.utils.graph import Graph
 from deeprank2.utils.grid import Augmentation, GridSettings, MapMethod
@@ -265,12 +265,11 @@ def _build_helper(self) -> Graph:
         structure = self._load_structure()

         # find the variant residue and its surroundings
-        variant_residue: Residue = None
         for residue in structure.get_chain(self.variant_chain_id).residues:
             if residue.number == self.variant_residue_number and residue.insertion_code == self.insertion_code:
                 variant_residue = residue
                 break
-        if variant_residue is None:
+        else:  # if break is not reached
             msg = f"Residue not found in {self.pdb_path}: {self.variant_chain_id} {self.residue_id}"
             raise ValueError(msg)
         self.variant = SingleResidueVariant(variant_residue, self.variant_amino_acid)
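The rewrite above uses Python's `for`/`else`: the `else` suite runs only when the loop finishes without hitting `break`, which removes the need for the `variant_residue = None` sentinel. A standalone sketch with made-up data:

```python
# for/else: `else` runs only if the loop was not exited via `break`.
def find(numbers: list[int], wanted: int) -> int:
    for n in numbers:
        if n == wanted:
            break
    else:  # if break is not reached
        msg = f"{wanted} not found"
        raise ValueError(msg)
    return n

print(find([10, 20, 30], 20))  # -> 20
# find([10, 20, 30], 99)      # would raise ValueError: 99 not found
```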
@@ -354,19 +353,12 @@ def _build_helper(self) -> Graph:
             raise ValueError(msg)

         # build the graph
-        if self.resolution == "atom":
-            graph = Graph.build_graph(
-                contact_atoms,
-                self.get_query_id(),
-                self.max_edge_length,
-            )
-        elif self.resolution == "residue":
-            residues_selected = list({atom.residue for atom in contact_atoms})
-            graph = Graph.build_graph(
-                residues_selected,
-                self.get_query_id(),
-                self.max_edge_length,
-            )
+        nodes = contact_atoms if self.resolution == "atom" else list({atom.residue for atom in contact_atoms})
+        graph = Graph.build_graph(
+            nodes=nodes,
+            graph_id=self.get_query_id(),
+            max_edge_length=self.max_edge_length,
+        )

         graph.center = np.mean([atom.position for atom in contact_atoms], axis=0)
         structure = contact_atoms[0].residue.chain.model
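This refactor collapses the duplicated `Graph.build_graph` calls into one; at residue resolution, a set comprehension deduplicates residues shared by several contact atoms. A sketch with hypothetical stand-ins for deeprank2's `Atom` and `Residue`:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Residue:  # hypothetical stand-in for deeprank2's Residue
    number: int


@dataclass(frozen=True)
class Atom:  # hypothetical stand-in for deeprank2's Atom
    residue: Residue


contact_atoms = [Atom(Residue(1)), Atom(Residue(1)), Atom(Residue(2))]
resolution = "residue"

# Atoms are used directly at atom resolution; otherwise the unique
# residues behind the contact atoms become the graph nodes.
nodes = contact_atoms if resolution == "atom" else list({a.residue for a in contact_atoms})
print(len(nodes))  # -> 2 (duplicates collapsed; set order is arbitrary)
```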
@@ -453,7 +445,7 @@ def __iter__(self) -> Iterator[Query]:
     def __len__(self) -> int:
         return len(self._queries)

-    def _process_one_query(self, query: Query) -> None:
+    def _process_one_query(self, query: Query, log_error_traceback: bool = False) -> None:
         """Only one process may access an hdf5 file at a time."""
         try:
             output_path = f"{self._prefix}-{os.getpid()}.hdf5"
@@ -479,10 +471,12 @@ def _process_one_query(self, query: Query) -> None:

         except (ValueError, AttributeError, KeyError, TimeoutError) as e:
             _log.warning(
-                f"\nGraph/Query with ID {query.get_query_id()} ran into an Exception ({e.__class__.__name__}: {e}),"
-                " and it has not been written to the hdf5 file. More details below:",
+                f"Graph/Query with ID {query.get_query_id()} ran into an Exception and was not written to the hdf5 file.\n"
+                f"Exception found: {e.__class__.__name__}: {e}.\n"
+                "You may proceed with your analysis, but this query will be ignored.\n",
             )
-            _log.exception(e)
+            if log_error_traceback:
+                _log.exception(f"----Full error traceback:----\n{e}")

     def process(
         self,
@@ -493,6 +487,7 @@ def process(
         grid_settings: GridSettings | None = None,
         grid_map_method: MapMethod | None = None,
         grid_augmentation_count: int = 0,
+        log_error_traceback: bool = False,
     ) -> list[str]:
         """Render queries into graphs (and optionally grids).
@@ -510,6 +505,8 @@ def process(
         grid_settings: If valid together with `grid_map_method`, the grid data will be stored as well. Defaults to None.
         grid_map_method: If valid together with `grid_settings`, the grid data will be stored as well. Defaults to None.
         grid_augmentation_count: Number of grid data augmentations (must be >= 0). Defaults to 0.
+        log_error_traceback: If True, logs the full error traceback when a query fails; otherwise only a short error message is logged.
+            Defaults to False.

         Returns:
             The list of paths of the generated HDF5 files.
@@ -536,7 +533,7 @@ def process(
         self._grid_augmentation_count = grid_augmentation_count

         _log.info(f"Creating pool function to process {len(self)} queries...")
-        pool_function = partial(self._process_one_query)
+        pool_function = partial(self._process_one_query, log_error_traceback=log_error_traceback)
         with Pool(self._cpu_count) as pool:
             _log.info("Starting pooling...\n")
             pool.map(pool_function, self.queries)
@@ -551,6 +548,24 @@ def process(
                 os.remove(output_path)
             return glob(f"{prefix}.hdf5")

+        n_processed = 0
+        for hdf5file in output_paths:
+            with h5py.File(hdf5file, "r") as hdf5:
+                # List of all graphs in hdf5, each graph representing
+                # a SRV and its surrounding environment
+                n_processed += len(list(hdf5.keys()))
+
+        if not n_processed:
+            msg = "No queries have been processed."
+            raise ValueError(msg)
+        if n_processed != len(self.queries):
+            _log.warning(
+                f"Not all queries have been processed. You can proceed with the analysis of {n_processed}/{len(self.queries)} queries.\n"
+                "Set `log_error_traceback` to True for advanced troubleshooting.",
+            )
+        else:
+            _log.info(f"{n_processed} queries have been processed.")
+
         return output_paths

     def _set_feature_modules(self, feature_modules: list[ModuleType, str] | ModuleType | str) -> list[str]:
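Taken together, these changes make a failing query non-fatal and quiet by default, with an opt-in traceback. A hedged usage sketch (query construction abbreviated; the path and target values are illustrative, and the keyword names follow the diff above):

```python
from deeprank2.query import ProteinProteinInterfaceQuery, QueryCollection

queries = QueryCollection()
queries.add(
    ProteinProteinInterfaceQuery(
        pdb_path="tutorials/data_raw/pdb/BA-100600.pdb",  # illustrative path
        resolution="residue",
        chain_ids=["M", "P"],
        targets={"binary": 0},  # illustrative target
    ),
)

# Default (False): a failed query logs a short warning and is skipped,
# and process() reports how many queries made it into the HDF5 files.
# True: the full traceback is logged for troubleshooting.
hdf5_paths = queries.process(log_error_traceback=True)
print(hdf5_paths)
```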
9 changes: 5 additions & 4 deletions pyproject.toml
@@ -53,8 +53,8 @@ dependencies = [
     "python-louvain >= 0.16, < 1.0",
     "tqdm >= 4.66.4, < 5.0",
     "freesasa >= 2.1.1, < 3.0",
-    "biopython >= 1.83, < 2.0"
-]
+    "biopython >= 1.83, < 2.0",
+]

 [project.optional-dependencies]
 # development dependency groups
@@ -66,7 +66,7 @@ test = [
     "pytest-cov >= 4.1.0, < 5.0",
     "pytest-runner >= 6.0.0, < 7.0",
     "coveralls >= 3.3.1, < 4.0",
-    "ruff == 0.5.1"
+    "ruff == 0.6.3",
 ]
 publishing = ["build", "twine", "wheel"]
 notebooks = ["nbmake"]
@@ -88,7 +88,7 @@ include = ["deeprank2*"]

 [tool.pytest.ini_options]
 # pytest options: -ra: show summary info for all test outcomes
-addopts = "-ra"
+addopts = "-ra"

 [tool.ruff]
 output-format = "concise"
@@ -148,3 +148,4 @@ isort.known-first-party = ["deeprank2"]
 ]
 "docs/*" = ["ALL"]
 "tests/perf/*" = ["T201"]  # Use of print statements
+"*.ipynb" = ["T201", "E402", "D103"]
25 changes: 13 additions & 12 deletions tests/data/hdf5/_generate_testdata.ipynb
@@ -15,11 +15,8 @@
 "PATH_TEST = ROOT / \"tests\"\n",
 "import glob\n",
 "import os\n",
-"import re\n",
-"import sys\n",
 "\n",
 "import h5py\n",
-"import numpy as np\n",
 "import pandas as pd\n",
 "\n",
 "from deeprank2.dataset import save_hdf5_keys\n",
@@ -79,7 +76,7 @@
 "            chain_ids=[chain_id1, chain_id2],\n",
 "            targets=targets,\n",
 "            pssm_paths={chain_id1: pssm_path1, chain_id2: pssm_path2},\n",
-"        )\n",
+"        ),\n",
 "    )\n",
 "\n",
 "    # Generate graphs and save them in hdf5 files\n",
@@ -128,8 +125,8 @@
 "csv_data = pd.read_csv(csv_file_path)\n",
 "csv_data.cluster = csv_data.cluster.fillna(-1)\n",
 "pdb_ids_csv = [pdb_file.split(\"/\")[-1].split(\".\")[0].replace(\"-\", \"_\") for pdb_file in pdb_files]\n",
-"clusters = [csv_data[pdb_id == csv_data.ID].cluster.values[0] for pdb_id in pdb_ids_csv]\n",
-"bas = [csv_data[pdb_id == csv_data.ID].measurement_value.values[0] for pdb_id in pdb_ids_csv]\n",
+"clusters = [csv_data[pdb_id == csv_data.ID].cluster.to_numpy()[0] for pdb_id in pdb_ids_csv]\n",
+"bas = [csv_data[pdb_id == csv_data.ID].measurement_value.to_numpy()[0] for pdb_id in pdb_ids_csv]\n",
 "\n",
 "queries = QueryCollection()\n",
 "print(f\"Adding {len(pdb_files)} queries to the query collection ...\")\n",
@@ -147,7 +144,7 @@
 "            \"cluster\": clusters[i],\n",
 "        },\n",
 "        pssm_paths={\"M\": pssm_m[i], \"P\": pssm_p[i]},\n",
-"        )\n",
+"        ),\n",
 "    )\n",
 "print(\"Queries created and ready to be processed.\\n\")\n",
 "\n",
@@ -183,7 +180,7 @@
 "test_ids = []\n",
 "\n",
 "with h5py.File(hdf5_path, \"r\") as hdf5:\n",
-"    for key in hdf5.keys():\n",
+"    for key in hdf5:\n",
 "        feature_value = float(hdf5[key][target][feature][()])\n",
 "        if feature_value in train_clusters:\n",
 "            train_ids.append(key)\n",
@@ -192,7 +189,7 @@
 "        elif feature_value in test_clusters:\n",
 "            test_ids.append(key)\n",
 "\n",
-"        if feature_value in clusters.keys():\n",
+"        if feature_value in clusters:\n",
 "            clusters[int(feature_value)] += 1\n",
 "        else:\n",
 "            clusters[int(feature_value)] = 1\n",
@@ -278,8 +275,12 @@
 "    targets = compute_ppi_scores(pdb_path, ref_path)\n",
 "    queries.add(\n",
 "        ProteinProteinInterfaceQuery(\n",
-"            pdb_path=pdb_path, resolution=\"atom\", chain_ids=[chain_id1, chain_id2], targets=targets, pssm_paths={chain_id1: pssm_path1, chain_id2: pssm_path2}\n",
-"        )\n",
+"            pdb_path=pdb_path,\n",
+"            resolution=\"atom\",\n",
+"            chain_ids=[chain_id1, chain_id2],\n",
+"            targets=targets,\n",
+"            pssm_paths={chain_id1: pssm_path1, chain_id2: pssm_path2},\n",
+"        ),\n",
 "    )\n",
 "\n",
 "# Generate graphs and save them in hdf5 files\n",
@@ -303,7 +304,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.10.12"
+"version": "3.10.14"
 },
 "orig_nbformat": 4,
 "vscode": {
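The notebook edits above replace pandas' `.values` attribute with the explicit `.to_numpy()` method (the spelling the pandas docs recommend) and iterate mappings directly instead of calling `.keys()`. A minimal illustration with a made-up frame:

```python
import pandas as pd

csv_data = pd.DataFrame({"ID": ["BA_100600", "BA_100700"], "cluster": [3, 7]})

pdb_id = "BA_100600"
# .values and .to_numpy() return the same ndarray here, but .to_numpy()
# makes the conversion explicit.
cluster = csv_data[csv_data.ID == pdb_id].cluster.to_numpy()[0]
print(cluster)  # -> 3
```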
2 changes: 1 addition & 1 deletion tests/test_dataset.py
@@ -1201,7 +1201,7 @@ def test_inherit_info_pretrained_model_graphdataset(self) -> None:
         for key in data["features_transform"].values():
             if key["transform"] is None:
                 continue
-            key["transform"] = eval(key["transform"])  # noqa: S307, PGH001
+            key["transform"] = eval(key["transform"])  # noqa: S307

         dataset_test_vars = vars(dataset_test)
         for param in dataset_test.inherited_params:
2 changes: 1 addition & 1 deletion tests/utils/test_graph.py
@@ -27,7 +27,7 @@
 target_value = 1.0


-@pytest.fixture()
+@pytest.fixture
 def graph() -> Graph:
     """Build a simple graph of two nodes and one edge in between them."""
     # load the structure
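`@pytest.fixture` and `@pytest.fixture()` are equivalent; ruff's PT001 rule (with its current default) flags the empty parentheses, hence this cleanup after the ruff bump. A self-contained sketch with a hypothetical fixture:

```python
import pytest


@pytest.fixture
def node_count() -> int:  # hypothetical fixture
    return 2


def test_node_count(node_count: int) -> None:
    assert node_count == 2
```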
[Diffs for the remaining 3 of the 10 changed files were not loaded.]
