diff --git a/.bumpversion.cfg b/.bumpversion.cfg deleted file mode 100644 index 51e0082cb..000000000 --- a/.bumpversion.cfg +++ /dev/null @@ -1,17 +0,0 @@ -[bumpversion] -current_version = 3.0.5 - -[comment] -comment = The contents of this file cannot be merged with that of setup.cfg until https://github.com/c4urself/bump2version/issues/185 is resolved - -[bumpversion:file:deeprank2/__init__.py] -search = __version__ = "{current_version}" -replace = __version__ = "{new_version}" - -[bumpversion:file:pyproject.toml] -search = version = "{current_version}" -replace = version = "{new_version}" - -[bumpversion:file:CITATION.cff] -search = version: "{current_version}" -replace = version: "{new_version}" diff --git a/.bumpversion.toml b/.bumpversion.toml new file mode 100644 index 000000000..b527b8dd1 --- /dev/null +++ b/.bumpversion.toml @@ -0,0 +1,17 @@ +[tool.bumpversion] +current_version = "3.0.5" + +[[tool.bumpversion.files]] +filename = "pyproject.toml" +search = 'version = "{current_version}"' +replace = 'version = "{new_version}"' + +[[tool.bumpversion.files]] +filename = "CITATION.cff" +search = 'version: "{current_version}"' +replace = 'version: "{new_version}"' + +[[tool.bumpversion.files]] +filename = "deeprank2/__init__.py" +search = '__version__ = "{current_version}"' +replace = '__version__ = "{new_version}"' diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml index 030e4cbca..a9369bfdc 100644 --- a/.github/workflows/linting.yml +++ b/.github/workflows/linting.yml @@ -57,6 +57,6 @@ jobs: python3 --version - name: Check linting and formatting using ruff run: | - python3 -m pip install ruff + python3 -m pip install ruff==0.5.1 ruff check || (echo "Please ensure you have the latest version of ruff (`ruff -V`) installed locally." && (exit 1)) ruff format --check || (echo "Please ensure you have the latest version of ruff (`ruff -V`) installed locally." && (exit 1)) diff --git a/.github/workflows/notebooks.yml b/.github/workflows/notebooks.yml index 9f3f29020..758b30302 100644 --- a/.github/workflows/notebooks.yml +++ b/.github/workflows/notebooks.yml @@ -57,6 +57,8 @@ jobs: wget https://zenodo.org/records/13709906/files/data_raw.zip unzip data_raw.zip -d data_raw mv data_raw tutorials + echo listing files in data_raw: + ls tutorials/data_raw - name: Run tutorial notebooks run: pytest --nbmake tutorials diff --git a/.github/workflows/release_github.yml b/.github/workflows/release_github.yml new file mode 100644 index 000000000..001dfc4f4 --- /dev/null +++ b/.github/workflows/release_github.yml @@ -0,0 +1,139 @@ +name: Draft GitHub Release + +on: + workflow_dispatch: + inputs: + version_level: + description: "Semantic version level increase." + required: true + type: choice + options: + - patch + - minor + - major + +permissions: + contents: write + pull-requests: write + +jobs: + draft_release: + runs-on: "ubuntu-latest" + defaults: + run: + shell: bash -l {0} + strategy: + fail-fast: true + + steps: + - name: Display selection + run: | + echo "Branch selected: '${{ github.ref_name }}'" + echo "Release level selected: '${{ github.event.inputs.version_level }}'" + + - name: Ensure that permitted release branch was selected + if: ${{ github.ref_name == 'main' || github.ref_name == 'dev' }} + run: | + echo "Branch selected: '${{ github.ref_name }}'" + echo "Releasing from main or dev branch is not permitted, please select a different release branch." 
+ exit 1 + + - name: Check GitHub Token Validity + run: | + echo "-- Validating GitHub Token" + status_code=$(curl -o /dev/null -s -w "%{http_code}" -H "Authorization: token ${{ secrets.GH_RELEASE }}" https://api.github.com/user) + if [ "$status_code" -ne 200 ]; then + echo "Error: GitHub token is invalid or expired. Please update your token in secrets." + echo "Instructions can be found at: https://github.com/DeepRank/deeprank2/blob/main/README.dev.md#updating-the-token" + exit 1 + else + echo "GitHub token is valid." + fi + + - name: Checkout repository + uses: actions/checkout@v4 + with: + # token with admin privileges to override branch protection on main and dev + token: ${{ secrets.GH_RELEASE }} + ref: main + fetch-depth: 0 + + - name: Configure git + run: | + git config user.email "actions@github.com" + git config user.name "GitHub Actions" + git pull + + - name: Merge changes into main + run: | + git switch main + git merge origin/${{ github.ref_name }} --no-ff --no-commit + git commit --no-edit + + - name: Bump version + id: bump + run: | + echo "-- install bump-my-version" + python3 -m pip install bump-my-version + echo "-- bump the version" + bump-my-version bump ${{ github.event.inputs.version_level }} --commit --tag + echo "-- push bumped version" + echo "RELEASE_TAG=$(git describe --tags --abbrev=0)" >> $GITHUB_OUTPUT + git push --tags -f + git push + + - name: Create GitHub Release + id: create_release + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + gh release create ${{ steps.bump.outputs.RELEASE_TAG }} \ + --title="Release ${{ steps.bump.outputs.RELEASE_TAG }}" \ + --generate-notes \ + --draft + + tidy_workspace: + # only run if the job above succeeds + needs: draft_release + runs-on: "ubuntu-latest" + defaults: + run: + shell: bash -l {0} + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + # token with admin privileges to override branch protection on main and dev + token: ${{ secrets.GH_RELEASE }} + fetch-depth: 0 + + - name: Configure git + run: | + git config user.email "actions@github.com" + git config user.name "GitHub Actions" + git pull + + - name: Close PR + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + echo "-- searching for associated PR" + pr_number=$(gh pr list --head ${{ github.ref_name }} --json number --jq '.[0].number') + if [ -n "$pr_number" ]; then + echo "-- closing PR #$pr_number" + gh pr close $pr_number + else + echo "-- no open pull request found for branch ${{ github.ref_name }}" + fi + + - name: Merge updates into dev + run: | + git switch dev + git merge origin/main + git push + + - name: Delete release branch other than main or dev + run: | + echo "-- deleting branch '${{ github.ref_name }}'" + git push origin -d ${{ github.ref_name }} diff --git a/.github/workflows/release.yml b/.github/workflows/release_pypi.yml similarity index 100% rename from .github/workflows/release.yml rename to .github/workflows/release_pypi.yml diff --git a/.ruff.toml b/.ruff.toml new file mode 100644 index 000000000..80320a9cf --- /dev/null +++ b/.ruff.toml @@ -0,0 +1,59 @@ +target-version = "py310" +output-format = "concise" +line-length = 159 + +[lint] +select = ["ALL"] +pydocstyle.convention = "google" # docstring settings +ignore = [ + # Unrealistic for this code base + "PTH", # flake8-use-pathlib + "N", # naming conventions + "PLR0912", # Too many branches, + "PLR0913", # Too many arguments in function definition + "D102", # Missing docstring in public method + # Unwanted + "FBT", # Using boolean arguments + "ANN101", #
Missing type annotation for `self` in method + "ANN102", # Missing type annotation for `cls` in classmethod + "ANN204", # Missing return type annotation for special (dunder) method + "B028", # No explicit `stacklevel` keyword argument found in warning + "S105", # Possible hardcoded password + "S311", # insecure random generators + "PT011", # pytest-raises-too-broad + "SIM108", # Use ternary operator + # Unwanted docstrings + "D100", # Missing module docstring + "D104", # Missing public package docstring + "D105", # Missing docstring in magic method + "D107", # Missing docstring in `__init__` +] + +# Autofix settings +fixable = ["ALL"] +unfixable = ["F401"] # unused imports (should not disappear while editing) +extend-safe-fixes = [ + "D415", # First line should end with a period, question mark, or exclamation point + "D300", # Use triple double quotes `"""` + "D200", # One-line docstring should fit on one line + "TCH", # Format type checking only imports + "ISC001", # Implicitly concatenated strings on a single line + "EM", # Exception message variables + "RUF013", # Implicit Optional + "B006", # Mutable default argument +] + +isort.known-first-party = ["deeprank2"] + +[lint.per-file-ignores] +"tests/*" = [ + "S101", # Use of `assert` detected + "PLR2004", # Magic value used in comparison + "D101", # Missing class docstring + "D102", # Missing docstring in public method + "D103", # Missing docstring in public function + "SLF001", # private member access +] +"docs/*" = ["ALL"] +"tests/perf/*" = ["T201"] # Use of print statements +"*.ipynb" = ["T201", "E402", "D103"] diff --git a/README.dev.md b/README.dev.md index d15731b30..4d3b19f54 100644 --- a/README.dev.md +++ b/README.dev.md @@ -79,10 +79,66 @@ During the development cycle, three main supporting branches are used: ## Making a release -1. Branch from `dev` and prepare the branch for the release (e.g., removing the unnecessary dev files, fix minor bugs if necessary). -2. [Bump the version](https://github.com/DeepRank/deeprank2/blob/dev/README.dev.md#versioning). -3. Merge the release branch into `main` (and `dev`), and [run the tests](https://github.com/DeepRank/deeprank2/blob/dev/README.dev.md#running-the-tests). -4. Go to https://github.com/DeepRank/deeprank2/releases and draft a new release; create a new tag for the release, generate release notes automatically and adjust them, and finally publish the release as latest. This will trigger [a GitHub action](https://github.com/DeepRank/deeprank2/actions/workflows/release.yml) that will take care of publishing the package on PyPi. +### Automated release workflow: + +1. **IMPORTANT:** Create a PR for the release branch, targeting the `main` branch. Ensure there are no conflicts and that all checks pass successfully. Release branches are typically either traditional [release branches](https://nvie.com/posts/a-successful-git-branching-model/#release-branches), which are created from the `dev` branch, or [hotfix branches](https://nvie.com/posts/a-successful-git-branching-model/#hotfix-branches), which are created directly from the `main` branch. + - If everything goes well, this PR will automatically be closed after the draft release is created. +2. Navigate to [Draft GitHub Release](https://github.com/DeepRank/deeprank2/actions/workflows/release_github.yml) + on the [Actions](https://github.com/DeepRank/deeprank2/actions) tab. +3. On the right-hand side, you can select the level increase ("patch", "minor", or "major") and which branch to release from.
+ - [Follow semantic versioning conventions](https://semver.org/) to choose the level increase: + - `patch`: when backward compatible bug fixes were made + - `minor`: when functionality was added in a backward compatible manner + - `major`: when API-incompatible changes have been made + - Note that you cannot release from `main` (the default shown) using the automated workflow. To release from `main` + directly, you must [create the release manually](#manually-create-a-release). +4. Visit the [Actions](https://github.com/DeepRank/deeprank2/actions) tab to check whether everything went as expected. + - NOTE: there are two separate jobs in the workflow: "draft_release" and "tidy_workspace". The first creates the draft release on GitHub, while the second merges changes into `dev` and closes the PR. + - If "draft_release" fails, then there are likely merge conflicts with `main` that need to be resolved first. No release draft is created and the "tidy_workspace" job does not run. Conversely, if this job is successful, then the release branch (including a version bump) has been merged into the remote `main` branch. + - If "draft_release" is successful but "tidy_workspace" fails, then there are likely merge conflicts with `dev` that are not conflicts with `main`. In this case, the draft release is created (and changes were merged into the remote `main`). Conflicts with `dev` need to be resolved manually by the user. + - If both jobs succeed, then the draft release is created and the changes are merged into both remote `main` and `dev` without any problems, and the associated PR is closed. Also, the release branch is deleted from the remote repository. +5. Navigate to the [Releases](https://github.com/DeepRank/deeprank2/releases) tab and click on the newest draft + release that was just generated. +6. Click on the edit (pencil) icon on the right side of the draft release. +7. Check/adapt the release notes and make sure that everything is as expected. +8. Check that "Set as the latest release" is checked. +9. Click the green "Publish Release" button to convert the draft to a published release on GitHub. + - This will automatically trigger [another GitHub workflow](https://github.com/DeepRank/deeprank2/actions/workflows/release_pypi.yml) that will take care of publishing the package on PyPi. + +#### Updating the token: + +In order for the workflow above to bypass the branch protection on `main` and `dev`, a token with admin privileges for the current repo is required. Below are instructions on how to create such a token. +NOTE: the current token (associated with @DaniBodor) that allows bypassing branch protection will expire on 9 July 2025. To update the token, do the following: + +1. [Create a personal access token](https://github.com/settings/tokens/new) from a GitHub user account with admin + privileges for this repo. +2. Check all the "repo" boxes and the "workflow" box, set an expiration date, and give the token a note. +3. Click the green "Generate token" button at the bottom. +4. Copy the token immediately, as it will not be visible again later. +5. Navigate to the [secrets settings](https://github.com/DeepRank/deeprank2/settings/secrets/actions). +6. Edit the `GH_RELEASE` key, giving your access token as the new value. + +### Manually create a release: + +0. Make sure you have all required developer tools installed: `pip install -e .'[test]'`. +1. Create a `release-` branch from `main` (if there has been a hotfix) or `dev` (for a regular new production release). +2. Prepare the branch for the release (e.g., remove unnecessary dev files and fix minor bugs if necessary). Do this by ensuring that all tests pass (`pytest -v`) and that linting (`ruff check`) and formatting (`ruff format --check`) conventions are adhered to.
+3. Bump the version using [bump-my-version](https://github.com/callowayproject/bump-my-version): `bump-my-version bump <level>` + where `<level>` must be one of the following ([following semantic versioning conventions](https://semver.org/)): + - `major`: when API-incompatible changes have been made + - `minor`: when functionality was added in a backward compatible manner + - `patch`: when backward compatible bug fixes were made +4. Merge the release branch into `main` and `dev`. +5. On the [Releases page](https://github.com/DeepRank/deeprank2/releases): + 1. Click "Draft a new release" + 2. By convention, use `v<version>` as both the release title and as a tag for the release. + 3. Click "Generate release notes" to automatically load release notes from merged PRs since the last release. + 4. Adjust the notes as required. + 5. Ensure that "Set as latest release" is checked and that both other boxes are unchecked. + 6. Hit "Publish release". + - This will automatically trigger a [GitHub + workflow](https://github.com/DeepRank/deeprank2/actions/workflows/release_pypi.yml) that will take care of publishing + the package on PyPi. ## UML diff --git a/deeprank2/dataset.py b/deeprank2/dataset.py index 18796ac26..144e8d0af 100644 --- a/deeprank2/dataset.py +++ b/deeprank2/dataset.py @@ -112,7 +112,7 @@ def _check_and_inherit_train( # noqa: C901 for key in data["features_transform"].values(): if key["transform"] is None: continue - key["transform"] = eval(key["transform"]) # noqa: S307, PGH001 + key["transform"] = eval(key["transform"]) # noqa: S307 except pickle.UnpicklingError as e: msg = "The path provided to `train_source` is not a valid DeepRank2 pre-trained model."
raise ValueError(msg) from e @@ -277,7 +277,7 @@ def _filter_targets(self, grp: h5py.Group) -> bool: for operator_string in [">", "<", "==", "<=", ">=", "!="]: operation = operation.replace(operator_string, f"{target_value}" + operator_string) - if not eval(operation): # noqa: S307, PGH001 + if not eval(operation): # noqa: S307 return False elif target_condition is not None: diff --git a/deeprank2/query.py b/deeprank2/query.py index 89f171a4d..b85222d14 100644 --- a/deeprank2/query.py +++ b/deeprank2/query.py @@ -22,7 +22,7 @@ import deeprank2.features from deeprank2.domain.aminoacidlist import convert_aa_nomenclature from deeprank2.features import components, conservation, contact -from deeprank2.molstruct.residue import Residue, SingleResidueVariant +from deeprank2.molstruct.residue import SingleResidueVariant from deeprank2.utils.buildgraph import get_contact_atoms, get_structure, get_surrounding_residues from deeprank2.utils.graph import Graph from deeprank2.utils.grid import Augmentation, GridSettings, MapMethod @@ -265,12 +265,11 @@ def _build_helper(self) -> Graph: structure = self._load_structure() # find the variant residue and its surroundings - variant_residue: Residue = None for residue in structure.get_chain(self.variant_chain_id).residues: if residue.number == self.variant_residue_number and residue.insertion_code == self.insertion_code: variant_residue = residue break - if variant_residue is None: + else: # if break is not reached msg = f"Residue not found in {self.pdb_path}: {self.variant_chain_id} {self.residue_id}" raise ValueError(msg) self.variant = SingleResidueVariant(variant_residue, self.variant_amino_acid) @@ -354,19 +353,12 @@ def _build_helper(self) -> Graph: raise ValueError(msg) # build the graph - if self.resolution == "atom": - graph = Graph.build_graph( - contact_atoms, - self.get_query_id(), - self.max_edge_length, - ) - elif self.resolution == "residue": - residues_selected = list({atom.residue for atom in contact_atoms}) - graph = Graph.build_graph( - residues_selected, - self.get_query_id(), - self.max_edge_length, - ) + nodes = contact_atoms if self.resolution == "atom" else list({atom.residue for atom in contact_atoms}) + graph = Graph.build_graph( + nodes=nodes, + graph_id=self.get_query_id(), + max_edge_length=self.max_edge_length, + ) graph.center = np.mean([atom.position for atom in contact_atoms], axis=0) structure = contact_atoms[0].residue.chain.model @@ -453,7 +445,7 @@ def __iter__(self) -> Iterator[Query]: def __len__(self) -> int: return len(self._queries) - def _process_one_query(self, query: Query) -> None: + def _process_one_query(self, query: Query, log_error_traceback: bool = False) -> None: """Only one process may access an hdf5 file at a time.""" try: output_path = f"{self._prefix}-{os.getpid()}.hdf5" @@ -479,10 +471,12 @@ def _process_one_query(self, query: Query) -> None: except (ValueError, AttributeError, KeyError, TimeoutError) as e: _log.warning( - f"\nGraph/Query with ID {query.get_query_id()} ran into an Exception ({e.__class__.__name__}: {e})," - " and it has not been written to the hdf5 file. 
More details below:", + f"Graph/Query with ID {query.get_query_id()} ran into an Exception and was not written to the hdf5 file.\n" + f"Exception found: {e.__class__.__name__}: {e}.\n" + "You may proceed with your analysis, but this query will be ignored.\n", ) - _log.exception(e) + if log_error_traceback: + _log.exception(f"----Full error traceback:----\n{e}") def process( self, @@ -493,6 +487,7 @@ def process( grid_settings: GridSettings | None = None, grid_map_method: MapMethod | None = None, grid_augmentation_count: int = 0, + log_error_traceback: bool = False, ) -> list[str]: """Render queries into graphs (and optionally grids). @@ -510,6 +505,8 @@ def process( grid_settings: If valid together with `grid_map_method`, the grid data will be stored as well. Defaults to None. grid_map_method: If valid together with `grid_settings`, the grid data will be stored as well. Defaults to None. grid_augmentation_count: Number of grid data augmentations (must be >= 0). Defaults to 0. + log_error_traceback: If True, logs the full error traceback if a query fails; otherwise, only the error message is logged. + Defaults to False. Returns: The list of paths of the generated HDF5 files. @@ -536,7 +533,7 @@ def process( self._grid_augmentation_count = grid_augmentation_count _log.info(f"Creating pool function to process {len(self)} queries...") - pool_function = partial(self._process_one_query) + pool_function = partial(self._process_one_query, log_error_traceback=log_error_traceback) with Pool(self._cpu_count) as pool: _log.info("Starting pooling...\n") pool.map(pool_function, self.queries) @@ -551,6 +548,24 @@ def process( os.remove(output_path) return glob(f"{prefix}.hdf5") + n_processed = 0 + for hdf5file in output_paths: + with h5py.File(hdf5file, "r") as hdf5: + # List of all graphs in hdf5, each graph representing + # an SRV and its surrounding environment + n_processed += len(list(hdf5.keys())) + + if not n_processed: + msg = "No queries have been processed." + raise ValueError(msg) + if n_processed != len(self.queries): + _log.warning( + f"Not all queries have been processed. 
You can proceed with the analysis of {n_processed}/{len(self.queries)} queries.\n" + "Set `log_error_traceback` to True for advanced troubleshooting.", + ) + else: + _log.info(f"{n_processed} queries have been processed.") + return output_paths def _set_feature_modules(self, feature_modules: list[ModuleType, str] | ModuleType | str) -> list[str]: diff --git a/pyproject.toml b/pyproject.toml index 4172c9c1a..df3c3b431 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,21 +53,20 @@ dependencies = [ "python-louvain >= 0.16, < 1.0", "tqdm >= 4.66.4, < 5.0", "freesasa >= 2.1.1, < 3.0", - "biopython >= 1.83, < 2.0" - ] + "biopython >= 1.83, < 2.0", +] [project.optional-dependencies] -# development dependency groups test = [ "pytest >= 7.4.0, < 8.0", - "bump2version >= 1.0.1, < 2.0", + "bump-my-version >= 0.24.2, < 1.0", "coverage >= 6.5.0, < 7.0", "pycodestyle >= 2.8.0, < 3.0", "pytest-cov >= 4.1.0, < 5.0", "pytest-runner >= 6.0.0, < 7.0", "coveralls >= 3.3.1, < 4.0", - "ruff == 0.5.1" -] + "ruff == 0.6.3", +] # development dependency groups publishing = ["build", "twine", "wheel"] notebooks = ["nbmake"] @@ -88,63 +87,4 @@ include = ["deeprank2*"] [tool.pytest.ini_options] # pytest options: -ra: show summary info for all test outcomes -addopts = "-ra" - -[tool.ruff] -output-format = "concise" -line-length = 159 - -[tool.ruff.lint] -select = ["ALL"] -pydocstyle.convention = "google" # docstring settings -ignore = [ - # Unrealistic for this code base - "PTH", # flake8-use-pathlib - "N", # naming conventions - "PLR0912", # Too many branches, - "PLR0913", # Too many arguments in function definition - "D102", # Missing docstring in public method - # Unwanted - "FBT", # Using boolean arguments - "ANN101", # Missing type annotation for `self` in method - "ANN102", # Missing type annotation for `cls` in classmethod - "ANN204", # Missing return type annotation for special (dunder) method - "B028", # No explicit `stacklevel` keyword argument found in warning - "S105", # Possible hardcoded password - "S311", # insecure random generators - "PT011", # pytest-raises-too-broad - "SIM108", # Use ternary operator - # Unwanted docstrings - "D100", # Missing module docstring - "D104", # Missing public package docstring - "D105", # Missing docstring in magic method - "D107", # Missing docstring in `__init__` -] - -# Autofix settings -fixable = ["ALL"] -unfixable = ["F401"] # unused imports (should not disappear while editing) -extend-safe-fixes = [ - "D415", # First line should end with a period, question mark, or exclamation point - "D300", # Use triple double quotes `"""` - "D200", # One-line docstring should fit on one line - "TCH", # Format type checking only imports - "ISC001", # Implicitly concatenated strings on a single line - "EM", # Exception message variables - "RUF013", # Implicit Optional - "B006", # Mutable default argument -] - -isort.known-first-party = ["deeprank2"] - -[tool.ruff.lint.per-file-ignores] -"tests/*" = [ - "S101", # Use of `assert` detected - "PLR2004", # Magic value used in comparison - "D101", # Missing class docstring - "D102", # Missing docstring in public method - "D103", # Missing docstring in public function - "SLF001", # private member access -] -"docs/*" = ["ALL"] -"tests/perf/*" = ["T201"] # Use of print statements +addopts = "-ra" diff --git a/tests/data/hdf5/_generate_testdata.ipynb b/tests/data/hdf5/_generate_testdata.ipynb index b2fc2677f..762897834 100644 --- a/tests/data/hdf5/_generate_testdata.ipynb +++ b/tests/data/hdf5/_generate_testdata.ipynb @@ -15,11 +15,8 @@ 
"PATH_TEST = ROOT / \"tests\"\n", "import glob\n", "import os\n", - "import re\n", - "import sys\n", "\n", "import h5py\n", - "import numpy as np\n", "import pandas as pd\n", "\n", "from deeprank2.dataset import save_hdf5_keys\n", @@ -79,7 +76,7 @@ " chain_ids=[chain_id1, chain_id2],\n", " targets=targets,\n", " pssm_paths={chain_id1: pssm_path1, chain_id2: pssm_path2},\n", - " )\n", + " ),\n", " )\n", "\n", " # Generate graphs and save them in hdf5 files\n", @@ -128,8 +125,8 @@ "csv_data = pd.read_csv(csv_file_path)\n", "csv_data.cluster = csv_data.cluster.fillna(-1)\n", "pdb_ids_csv = [pdb_file.split(\"/\")[-1].split(\".\")[0].replace(\"-\", \"_\") for pdb_file in pdb_files]\n", - "clusters = [csv_data[pdb_id == csv_data.ID].cluster.values[0] for pdb_id in pdb_ids_csv]\n", - "bas = [csv_data[pdb_id == csv_data.ID].measurement_value.values[0] for pdb_id in pdb_ids_csv]\n", + "clusters = [csv_data[pdb_id == csv_data.ID].cluster.to_numpy()[0] for pdb_id in pdb_ids_csv]\n", + "bas = [csv_data[pdb_id == csv_data.ID].measurement_value.to_numpy()[0] for pdb_id in pdb_ids_csv]\n", "\n", "queries = QueryCollection()\n", "print(f\"Adding {len(pdb_files)} queries to the query collection ...\")\n", @@ -147,7 +144,7 @@ " \"cluster\": clusters[i],\n", " },\n", " pssm_paths={\"M\": pssm_m[i], \"P\": pssm_p[i]},\n", - " )\n", + " ),\n", " )\n", "print(\"Queries created and ready to be processed.\\n\")\n", "\n", @@ -183,7 +180,7 @@ "test_ids = []\n", "\n", "with h5py.File(hdf5_path, \"r\") as hdf5:\n", - " for key in hdf5.keys():\n", + " for key in hdf5:\n", " feature_value = float(hdf5[key][target][feature][()])\n", " if feature_value in train_clusters:\n", " train_ids.append(key)\n", @@ -192,7 +189,7 @@ " elif feature_value in test_clusters:\n", " test_ids.append(key)\n", "\n", - " if feature_value in clusters.keys():\n", + " if feature_value in clusters:\n", " clusters[int(feature_value)] += 1\n", " else:\n", " clusters[int(feature_value)] = 1\n", @@ -278,8 +275,12 @@ " targets = compute_ppi_scores(pdb_path, ref_path)\n", " queries.add(\n", " ProteinProteinInterfaceQuery(\n", - " pdb_path=pdb_path, resolution=\"atom\", chain_ids=[chain_id1, chain_id2], targets=targets, pssm_paths={chain_id1: pssm_path1, chain_id2: pssm_path2}\n", - " )\n", + " pdb_path=pdb_path,\n", + " resolution=\"atom\",\n", + " chain_ids=[chain_id1, chain_id2],\n", + " targets=targets,\n", + " pssm_paths={chain_id1: pssm_path1, chain_id2: pssm_path2},\n", + " ),\n", " )\n", "\n", "# Generate graphs and save them in hdf5 files\n", @@ -303,7 +304,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.14" }, "orig_nbformat": 4, "vscode": { diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 932e7d3c9..3d2424258 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1201,7 +1201,7 @@ def test_inherit_info_pretrained_model_graphdataset(self) -> None: for key in data["features_transform"].values(): if key["transform"] is None: continue - key["transform"] = eval(key["transform"]) # noqa: S307, PGH001 + key["transform"] = eval(key["transform"]) # noqa: S307 dataset_test_vars = vars(dataset_test) for param in dataset_test.inherited_params: diff --git a/tests/utils/test_graph.py b/tests/utils/test_graph.py index 39bd2c9c8..5ca7d4273 100644 --- a/tests/utils/test_graph.py +++ b/tests/utils/test_graph.py @@ -27,7 +27,7 @@ target_value = 1.0 -@pytest.fixture() +@pytest.fixture def graph() -> Graph: """Build a simple graph of two nodes and one edge 
in between them.""" # load the structure diff --git a/tutorials/data_generation_ppi.ipynb b/tutorials/data_generation_ppi.ipynb index 8553d4a73..8106e3637 100644 --- a/tutorials/data_generation_ppi.ipynb +++ b/tutorials/data_generation_ppi.ipynb @@ -1,549 +1,559 @@ { - "cells": [ - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Data preparation for protein-protein interfaces\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Introduction\n", - "\n", - "\n", - "\n", - "This tutorial will demonstrate the use of DeepRank2 for generating protein-protein interface (PPI) graphs and saving them as [HDF5 files](https://en.wikipedia.org/wiki/Hierarchical_Data_Format) files, using [PBD files]() of protein-protein complexes as input.\n", - "\n", - "In this data processing phase, for each protein-protein complex an interface is selected according to a distance threshold that the user can customize, and it is mapped to a graph. Nodes either represent residues or atoms, and edges are the interactions between them. Each node and edge can have several different features, which are generated and added during the processing phase as well. Optionally, the graphs can be mapped to volumetric grids (i.e., 3D image-like representations), together with their features. The mapped data are finally saved into HDF5 files, and can be used for later models' training (for details go to [training_ppi.ipynb](https://github.com/DeepRank/deeprank2/blob/main/tutorials/training_ppi.ipynb) tutorial). In particular, graphs can be used for the training of Graph Neural Networks (GNNs), and grids can be used for the training of Convolutional Neural Networks (CNNs).\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Input Data\n", - "\n", - "The example data used in this tutorial are available on Zenodo at [this record address](https://zenodo.org/record/13709906). To download the raw data used in this tutorial, please visit the link and download `data_raw.zip`. Unzip it, and save the `data_raw/` folder in the same directory as this notebook. 
The name and the location of the folder are optional but recommended, as they are the name and the location we will use to refer to the folder throughout the tutorial.\n", - "\n", - "Note that the dataset contains only 100 data points, which is not enough to develop an impactful predictive model, and the scope of its use is indeed only demonstrative and informative for the users.\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Utilities\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Libraries\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The libraries needed for this tutorial:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import pandas as pd\n", - "import glob\n", - "import h5py\n", - "import matplotlib.image as img\n", - "import matplotlib.pyplot as plt\n", - "from deeprank2.query import QueryCollection\n", - "from deeprank2.query import ProteinProteinInterfaceQuery, ProteinProteinInterfaceQuery\n", - "from deeprank2.features import components, contact\n", - "from deeprank2.utils.grid import GridSettings, MapMethod\n", - "from deeprank2.dataset import GraphDataset" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Raw files and paths\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The paths for reading raw data and saving the processed ones:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "data_path = os.path.join(\"data_raw\", \"ppi\")\n", - "processed_data_path = os.path.join(\"data_processed\", \"ppi\")\n", - "os.makedirs(os.path.join(processed_data_path, \"residue\"))\n", - "os.makedirs(os.path.join(processed_data_path, \"atomic\"))\n", - "# Flag limit_data as True if you are running on a machine with limited memory (e.g., Docker container)\n", - "limit_data = True" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "- Raw data are PDB files in `data_raw/ppi/pdb/`, which contains atomic coordinates of the protein-protein complexes of interest, so in our case of pMHC complexes.\n", - "- Target data, so in our case the BA values for the pMHC complex, are in `data_raw/ppi/BA_values.csv`.\n", - "- The final PPI processed data will be saved in `data_processed/ppi/` folder, which in turns contains a folder for residue-level data and another one for atomic-level data. 
More details about such different levels will come a few cells below.\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "`get_pdb_files_and_target_data` is an helper function used to retrieve the raw pdb files names in a list and the BA target values from a CSV containing the IDs of the PDB models as well:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def get_pdb_files_and_target_data(data_path):\n", - " csv_data = pd.read_csv(os.path.join(data_path, \"BA_values.csv\"))\n", - " pdb_files = glob.glob(os.path.join(data_path, \"pdb\", \"*.pdb\"))\n", - " pdb_files.sort()\n", - " pdb_ids_csv = [pdb_file.split(\"/\")[-1].split(\".\")[0] for pdb_file in pdb_files]\n", - " csv_data_indexed = csv_data.set_index(\"ID\")\n", - " csv_data_indexed = csv_data_indexed.loc[pdb_ids_csv]\n", - " bas = csv_data_indexed.measurement_value.values.tolist()\n", - " return pdb_files, bas\n", - "\n", - "\n", - "pdb_files, bas = get_pdb_files_and_target_data(data_path)\n", - "\n", - "if limit_data:\n", - " pdb_files = pdb_files[:15]" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## `QueryCollection` and `Query` objects\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "For each protein-protein complex, so for each data point, a query can be created and added to the `QueryCollection` object, to be processed later on.\n", - "\n", - "A query takes as inputs:\n", - "\n", - "- A `.pdb` file, representing the protein-protein structural complex.\n", - "- The resolution (`\"residue\"` or `\"atom\"`), i.e. whether each node should represent an amino acid residue or an atom.\n", - "- The ids of the two chains composing the complex. In our use case, \"M\" indicates the MHC protein chain and \"P\" the peptide chain.\n", - "- The interaction radius, which determines the threshold distance (in Ångström) for residues/atoms surrounding the interface that will be included in the graph.\n", - "- The target values associated with the query. For each query/data point, in the use case demonstrated in this tutorial will add two targets: \"BA\" and \"binary\". The first represents the actual BA value of the complex in nM, while the second represents its binary mapping, being 0 (BA > 500 nM) a not-binding complex and 1 (BA <= 500 nM) a binding one.\n", - "- The max edge distance, which is the maximum distance between two nodes to generate an edge between them.\n", - "- Optional: The correspondent [Position-Specific Scoring Matrices (PSSMs)](https://en.wikipedia.org/wiki/Position_weight_matrix), in the form of .pssm files. 
PSSMs are optional and will not be used in this tutorial.\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Residue-level PPIs using `ProteinProteinInterfaceQuery`\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "queries = QueryCollection()\n", - "\n", - "influence_radius = 8 # max distance in Å between two interacting residues/atoms of two proteins\n", - "max_edge_length = 8\n", - "\n", - "print(f\"Adding {len(pdb_files)} queries to the query collection ...\")\n", - "count = 0\n", - "for i in range(len(pdb_files)):\n", - " queries.add(\n", - " ProteinProteinInterfaceQuery(\n", - " pdb_path=pdb_files[i],\n", - " resolution=\"residue\",\n", - " chain_ids=[\"M\", \"P\"],\n", - " influence_radius=influence_radius,\n", - " max_edge_length=max_edge_length,\n", - " targets={\n", - " \"binary\": int(float(bas[i]) <= 500), # binary target value\n", - " \"BA\": bas[i], # continuous target value\n", - " },\n", - " )\n", - " )\n", - " count += 1\n", - " if count % 20 == 0:\n", - " print(f\"{count} queries added to the collection.\")\n", - "\n", - "print(\"Queries ready to be processed.\\n\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Notes on `process()` method\n", - "\n", - "Once all queries have been added to the `QueryCollection` instance, they can be processed. Main parameters of the `process()` method, include:\n", - "\n", - "- `prefix` sets the output file location.\n", - "- `feature_modules` allows you to choose which feature generating modules you want to use. By default, the basic features contained in `deeprank2.features.components` and `deeprank2.features.contact` are generated. Users can add custom features by creating a new module and placing it in the `deeprank2.feature` subpackage. A complete and detailed list of the pre-implemented features per module and more information about how to add custom features can be found [here](https://deeprank2.readthedocs.io/en/latest/features.html).\n", - " - Note that all features generated by a module will be added if that module was selected, and there is no way to only generate specific features from that module. However, during the training phase shown in `training_ppi.ipynb`, it is possible to select only a subset of available features.\n", - "- `cpu_count` can be used to specify how many processes to be run simultaneously, and will coincide with the number of HDF5 files generated. By default it takes all available CPU cores and HDF5 files are squashed into a single file using the `combine_output` setting.\n", - "- Optional: If you want to include grids in the HDF5 files, which represent the mapping of the graphs to a volumetric box, you need to define `grid_settings` and `grid_map_method`, as shown in the example below. 
If they are `None` (default), only graphs are saved.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "grid_settings = GridSettings( # None if you don't want grids\n", - " # the number of points on the x, y, z edges of the cube\n", - " points_counts=[35, 30, 30],\n", - " # x, y, z sizes of the box in Å\n", - " sizes=[1.0, 1.0, 1.0],\n", - ")\n", - "grid_map_method = MapMethod.GAUSSIAN # None if you don't want grids\n", - "\n", - "queries.process(\n", - " prefix=os.path.join(processed_data_path, \"residue\", \"proc\"),\n", - " feature_modules=[components, contact],\n", - " cpu_count=8,\n", - " combine_output=False,\n", - " grid_settings=grid_settings,\n", - " grid_map_method=grid_map_method,\n", - ")\n", - "\n", - "print(f'The queries processing is done. The generated HDF5 files are in {os.path.join(processed_data_path, \"residue\")}.')" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Exploring data\n", - "\n", - "As representative example, the following is the HDF5 structure generated by the previous code for `BA-100600.pdb`, so for one single graph, which represents one PPI, for the graph + grid case:\n", - "\n", - "```bash\n", - "└── residue-ppi:M-P:BA-100600\n", - " |\n", - " ├── edge_features\n", - " │ ├── _index\n", - " │ ├── _name\n", - " │ ├── covalent\n", - " │ ├── distance\n", - " │ ├── electrostatic\n", - " │ ├── same_chain\n", - " │ └── vanderwaals\n", - " |\n", - " ├── node_features\n", - " │ ├── _chain_id\n", - " │ ├── _name\n", - " │ ├── _position\n", - " │ ├── hb_acceptors\n", - " │ ├── hb_donors\n", - " │ ├── polarity\n", - " │ ├── res_charge\n", - " │ ├── res_mass\n", - " | ├── res_pI\n", - " | ├── res_size\n", - " | └── res_type\n", - " |\n", - " ├── grid_points\n", - " │ ├── center\n", - " │ ├── x\n", - " │ ├── y\n", - " │ └── z\n", - " |\n", - " ├── mapped_features\n", - " │ ├── _position_000\n", - " │ ├── _position_001\n", - " │ ├── _position_002\n", - " │ ├── covalent\n", - " │ ├── distance\n", - " │ ├── electrostatic\n", - " │ ├── polarity_000\n", - " │ ├── polarity_001\n", - " │ ├── polarity_002\n", - " │ ├── polarity_003\n", - " | ├── ...\n", - " | └── vanderwaals\n", - " |\n", - " └── target_values\n", - " │ ├── BA\n", - " └── binary\n", - "```\n", - "\n", - "`edge_features`, `node_features`, `mapped_features` are [HDF5 Groups](https://docs.h5py.org/en/stable/high/group.html) which contain [HDF5 Datasets](https://docs.h5py.org/en/stable/high/dataset.html) (e.g., `_index`, `electrostatic`, etc.), which in turn contains features values in the form of arrays. `edge_features` and `node_features` refer specificly to the graph representation, while `grid_points` and `mapped_features` refer to the grid mapped from the graph. Each data point generated by deeprank2 has the above structure, with the features and the target changing according to the user's settings. Features starting with `_` are present for human inspection of the data, but they are not used for training models.\n", - "\n", - "It is always a good practice to first explore the data, and then make decision about splitting them in training, test and validation sets. 
There are different possible ways for doing it.\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Pandas dataframe\n", - "\n", - "The edge and node features just generated can be explored by instantiating the `GraphDataset` object, and then using `hdf5_to_pandas` method which converts node and edge features into a [Pandas](https://pandas.pydata.org/) dataframe. Each row represents a ppi in the form of a graph.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "processed_data = glob.glob(os.path.join(processed_data_path, \"residue\", \"*.hdf5\"))\n", - "dataset = GraphDataset(processed_data, target=\"binary\")\n", - "df = dataset.hdf5_to_pandas()\n", - "df.head()" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can also generate histograms for looking at the features distributions. An example:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fname = os.path.join(processed_data_path, \"residue\", \"_\".join([\"res_mass\", \"distance\", \"electrostatic\"]))\n", - "dataset.save_hist(features=[\"res_mass\", \"distance\", \"electrostatic\"], fname=fname)\n", - "\n", - "im = img.imread(fname + \".png\")\n", - "plt.figure(figsize=(15, 10))\n", - "fig = plt.imshow(im)\n", - "fig.axes.get_xaxis().set_visible(False)\n", - "fig.axes.get_yaxis().set_visible(False)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Other tools\n", - "\n", - "- [HDFView](https://www.hdfgroup.org/downloads/hdfview/), a visual tool written in Java for browsing and editing HDF5 files.\n", - " As representative example, the following is the structure for `BA-100600.pdb` seen from HDF5View:\n", - "\n", - " \n", - "\n", - " Using this tool you can inspect the values of the features visually, for each data point.\n", - "\n", - "- Python packages such as [h5py](https://docs.h5py.org/en/stable/index.html). 
Examples:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "with h5py.File(processed_data[0], \"r\") as hdf5:\n", - " # List of all graphs in hdf5, each graph representing a ppi\n", - " ids = list(hdf5.keys())\n", - " print(f\"IDs of PPIs in {processed_data[0]}: {ids}\")\n", - " node_features = list(hdf5[ids[0]][\"node_features\"])\n", - " print(f\"Node features: {node_features}\")\n", - " edge_features = list(hdf5[ids[0]][\"edge_features\"])\n", - " print(f\"Edge features: {edge_features}\")\n", - " target_features = list(hdf5[ids[0]][\"target_values\"])\n", - " print(f\"Targets features: {target_features}\")\n", - " # Polarity feature for ids[0], numpy.ndarray\n", - " node_feat_polarity = hdf5[ids[0]][\"node_features\"][\"polarity\"][:]\n", - " print(f\"Polarity feature shape: {node_feat_polarity.shape}\")\n", - " # Electrostatic feature for ids[0], numpy.ndarray\n", - " edge_feat_electrostatic = hdf5[ids[0]][\"edge_features\"][\"electrostatic\"][:]\n", - " print(f\"Electrostatic feature shape: {edge_feat_electrostatic.shape}\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Atomic-level PPIs using `ProteinProteinInterfaceQuery`\n", - "\n", - "Graphs can also be generated at an atomic resolution, very similarly to what has just been done for residue-level.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "queries = QueryCollection()\n", - "\n", - "influence_radius = 5 # max distance in Å between two interacting residues/atoms of two proteins\n", - "max_edge_length = 5\n", - "\n", - "print(f\"Adding {len(pdb_files)} queries to the query collection ...\")\n", - "count = 0\n", - "for i in range(len(pdb_files)):\n", - " queries.add(\n", - " ProteinProteinInterfaceQuery(\n", - " pdb_path=pdb_files[i],\n", - " resolution=\"atom\",\n", - " chain_ids=[\"M\", \"P\"],\n", - " influence_radius=influence_radius,\n", - " max_edge_length=max_edge_length,\n", - " targets={\n", - " \"binary\": int(float(bas[i]) <= 500), # binary target value\n", - " \"BA\": bas[i], # continuous target value\n", - " },\n", - " )\n", - " )\n", - " count += 1\n", - " if count % 20 == 0:\n", - " print(f\"{count} queries added to the collection.\")\n", - "\n", - "print(\"Queries ready to be processed.\\n\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "grid_settings = GridSettings( # None if you don't want grids\n", - " # the number of points on the x, y, z edges of the cube\n", - " points_counts=[35, 30, 30],\n", - " # x, y, z sizes of the box in Å\n", - " sizes=[1.0, 1.0, 1.0],\n", - ")\n", - "grid_map_method = MapMethod.GAUSSIAN # None if you don't want grids\n", - "\n", - "queries.process(\n", - " prefix=os.path.join(processed_data_path, \"atomic\", \"proc\"),\n", - " feature_modules=[components, contact],\n", - " cpu_count=8,\n", - " combine_output=False,\n", - " grid_settings=grid_settings,\n", - " grid_map_method=grid_map_method,\n", - ")\n", - "\n", - "print(f'The queries processing is done. 
The generated HDF5 files are in {os.path.join(processed_data_path, \"atomic\")}.')" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Again, the data can be inspected using `hdf5_to_pandas` function.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "processed_data = glob.glob(os.path.join(processed_data_path, \"atomic\", \"*.hdf5\"))\n", - "dataset = GraphDataset(processed_data, target=\"binary\")\n", - "df = dataset.hdf5_to_pandas()\n", - "df.head()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fname = os.path.join(processed_data_path, \"atomic\", \"atom_charge\")\n", - "dataset.save_hist(features=\"atom_charge\", fname=fname)\n", - "\n", - "im = img.imread(fname + \".png\")\n", - "plt.figure(figsize=(8, 8))\n", - "fig = plt.imshow(im)\n", - "fig.axes.get_xaxis().set_visible(False)\n", - "fig.axes.get_yaxis().set_visible(False)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Note that some of the features are different from the ones generated with the residue-level queries. There are indeed features in `deeprank2.features.components` module which are generated only in atomic graphs, i.e. `atom_type`, `atom_charge`, and `pdb_occupancy`, because they don't make sense only in the atomic graphs' representation.\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "deeprank2", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Data preparation for protein-protein interfaces\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Introduction\n", + "\n", + "\n", + "\n", + "This tutorial will demonstrate the use of DeepRank2 for generating protein-protein interface (PPI) graphs and saving them as [HDF5 files](https://en.wikipedia.org/wiki/Hierarchical_Data_Format) files, using [PBD files]() of protein-protein complexes as input.\n", + "\n", + "In this data processing phase, for each protein-protein complex an interface is selected according to a distance threshold that the user can customize, and it is mapped to a graph. Nodes either represent residues or atoms, and edges are the interactions between them. Each node and edge can have several different features, which are generated and added during the processing phase as well. Optionally, the graphs can be mapped to volumetric grids (i.e., 3D image-like representations), together with their features. The mapped data are finally saved into HDF5 files, and can be used for later models' training (for details go to [training_ppi.ipynb](https://github.com/DeepRank/deeprank2/blob/main/tutorials/training_ppi.ipynb) tutorial). 
In particular, graphs can be used for the training of Graph Neural Networks (GNNs), and grids can be used for the training of Convolutional Neural Networks (CNNs).\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Input Data\n", + "\n", + "The example data used in this tutorial are available on Zenodo at [this record address](https://zenodo.org/record/7997585). To download the raw data used in this tutorial, please visit the link and download `data_raw.zip`. Unzip it, and save the `data_raw/` folder in the same directory as this notebook. The name and the location of the folder are optional but recommended, as they are the name and the location we will use to refer to the folder throughout the tutorial.\n", + "\n", + "Note that the dataset contains only 100 data points, which is not enough to develop an impactful predictive model, and the scope of its use is indeed only demonstrative and informative for the users.\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Utilities\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Libraries\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The libraries needed for this tutorial:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import contextlib\n", + "import glob\n", + "import os\n", + "from pathlib import Path\n", + "\n", + "import h5py\n", + "import matplotlib.image as img\n", + "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", + "\n", + "from deeprank2.dataset import GraphDataset\n", + "from deeprank2.features import components, contact\n", + "from deeprank2.query import ProteinProteinInterfaceQuery, QueryCollection\n", + "from deeprank2.utils.grid import GridSettings, MapMethod" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Raw files and paths\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The paths for reading raw data and saving the processed ones:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "data_path = os.path.join(\"data_raw\", \"ppi\")\n", + "processed_data_path = os.path.join(\"data_processed\", \"ppi\")\n", + "residue_data_path = os.path.join(processed_data_path, \"residue\")\n", + "atomic_data_path = os.path.join(processed_data_path, \"atomic\")\n", + "\n", + "for output_path in [residue_data_path, atomic_data_path]:\n", + " os.makedirs(output_path, exist_ok=True)\n", + " if any(Path(output_path).iterdir()):\n", + " msg = f\"Please store any required data from `./{output_path}` and delete the folder.\\nThen re-run this cell to continue.\"\n", + " raise FileExistsError(msg)\n", + "\n", + "# Flag limit_data as True if you are running on a machine with limited memory (e.g., Docker container)\n", + "limit_data = True" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- Raw data are PDB files in `data_raw/ppi/pdb/`, which contains atomic coordinates of the protein-protein complexes of interest, so in our case of pMHC complexes.\n", + "- Target data, so in our case the BA values for the pMHC complex, are in `data_raw/ppi/BA_values.csv`.\n", + "- The final PPI processed data will be saved in `data_processed/ppi/` folder, which in turns 
contains a folder for residue-level data and another one for atomic-level data. More details about such different levels will come a few cells below.\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`get_pdb_files_and_target_data` is an helper function used to retrieve the raw pdb files names in a list and the BA target values from a CSV containing the IDs of the PDB models as well:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def get_pdb_files_and_target_data(data_path: str) -> tuple[list[str], list[float]]:\n", + " csv_data = pd.read_csv(os.path.join(data_path, \"BA_values.csv\"))\n", + " pdb_files = glob.glob(os.path.join(data_path, \"pdb\", \"*.pdb\"))\n", + " pdb_files.sort()\n", + " pdb_ids_csv = [pdb_file.split(\"/\")[-1].split(\".\")[0] for pdb_file in pdb_files]\n", + " with contextlib.suppress(KeyError):\n", + " csv_data_indexed = csv_data.set_index(\"ID\")\n", + " csv_data_indexed = csv_data_indexed.loc[pdb_ids_csv]\n", + " bas = csv_data_indexed.measurement_value.tolist()\n", + "\n", + " return pdb_files, bas\n", + "\n", + "\n", + "pdb_files, bas = get_pdb_files_and_target_data(data_path)\n", + "\n", + "if limit_data:\n", + " pdb_files = pdb_files[:15]" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## `QueryCollection` and `Query` objects\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For each protein-protein complex, so for each data point, a query can be created and added to the `QueryCollection` object, to be processed later on.\n", + "\n", + "A query takes as inputs:\n", + "\n", + "- A `.pdb` file, representing the protein-protein structural complex.\n", + "- The resolution (`\"residue\"` or `\"atom\"`), i.e. whether each node should represent an amino acid residue or an atom.\n", + "- The ids of the two chains composing the complex. In our use case, \"M\" indicates the MHC protein chain and \"P\" the peptide chain.\n", + "- The interaction radius, which determines the threshold distance (in Ångström) for residues/atoms surrounding the interface that will be included in the graph.\n", + "- The target values associated with the query. For each query/data point, in the use case demonstrated in this tutorial will add two targets: \"BA\" and \"binary\". The first represents the actual BA value of the complex in nM, while the second represents its binary mapping, being 0 (BA > 500 nM) a not-binding complex and 1 (BA <= 500 nM) a binding one.\n", + "- The max edge distance, which is the maximum distance between two nodes to generate an edge between them.\n", + "- Optional: The correspondent [Position-Specific Scoring Matrices (PSSMs)](https://en.wikipedia.org/wiki/Position_weight_matrix), in the form of .pssm files. 
PSSMs are optional and will not be used in this tutorial.\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Residue-level PPIs using `ProteinProteinInterfaceQuery`\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "queries = QueryCollection()\n", + "\n", + "influence_radius = 8 # max distance in Å between two interacting residues/atoms of two proteins\n", + "max_edge_length = 8\n", + "binary_target_value = 500\n", + "\n", + "print(f\"Adding {len(pdb_files)} queries to the query collection ...\")\n", + "for i in range(len(pdb_files)):\n", + " queries.add(\n", + " ProteinProteinInterfaceQuery(\n", + " pdb_path=pdb_files[i],\n", + " resolution=\"residue\",\n", + " chain_ids=[\"M\", \"P\"],\n", + " influence_radius=influence_radius,\n", + " max_edge_length=max_edge_length,\n", + " targets={\n", + " \"binary\": int(float(bas[i]) <= binary_target_value),\n", + " \"BA\": bas[i], # continuous target value\n", + " },\n", + " ),\n", + " )\n", + " if i + 1 % 20 == 0:\n", + " print(f\"{i+1} queries added to the collection.\")\n", + "\n", + "print(f\"{i+1} queries ready to be processed.\\n\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Notes on `process()` method\n", + "\n", + "Once all queries have been added to the `QueryCollection` instance, they can be processed. Main parameters of the `process()` method, include:\n", + "\n", + "- `prefix` sets the output file location.\n", + "- `feature_modules` allows you to choose which feature generating modules you want to use. By default, the basic features contained in `deeprank2.features.components` and `deeprank2.features.contact` are generated. Users can add custom features by creating a new module and placing it in the `deeprank2.feature` subpackage. A complete and detailed list of the pre-implemented features per module and more information about how to add custom features can be found [here](https://deeprank2.readthedocs.io/en/latest/features.html).\n", + " - Note that all features generated by a module will be added if that module was selected, and there is no way to only generate specific features from that module. However, during the training phase shown in `training_ppi.ipynb`, it is possible to select only a subset of available features.\n", + "- `cpu_count` can be used to specify how many processes to be run simultaneously, and will coincide with the number of HDF5 files generated. By default it takes all available CPU cores and HDF5 files are squashed into a single file using the `combine_output` setting.\n", + "- Optional: If you want to include grids in the HDF5 files, which represent the mapping of the graphs to a volumetric box, you need to define `grid_settings` and `grid_map_method`, as shown in the example below. 
If they are `None` (default), only graphs are saved.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "grid_settings = GridSettings( # None if you don't want grids\n", + " # the number of points on the x, y, z edges of the cube\n", + " points_counts=[35, 30, 30],\n", + " # x, y, z sizes of the box in Å\n", + " sizes=[1.0, 1.0, 1.0],\n", + ")\n", + "grid_map_method = MapMethod.GAUSSIAN # None if you don't want grids\n", + "\n", + "queries.process(\n", + " prefix=os.path.join(processed_data_path, \"residue\", \"proc\"),\n", + " feature_modules=[components, contact],\n", + " cpu_count=8,\n", + " combine_output=False,\n", + " grid_settings=grid_settings,\n", + " grid_map_method=grid_map_method,\n", + ")\n", + "\n", + "print(f'The queries processing is done. The generated HDF5 files are in {os.path.join(processed_data_path, \"residue\")}.')" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Exploring data\n", + "\n", + "As representative example, the following is the HDF5 structure generated by the previous code for `BA-100600.pdb`, so for one single graph, which represents one PPI, for the graph + grid case:\n", + "\n", + "```bash\n", + "└── residue-ppi:M-P:BA-100600\n", + " |\n", + " ├── edge_features\n", + " │ ├── _index\n", + " │ ├── _name\n", + " │ ├── covalent\n", + " │ ├── distance\n", + " │ ├── electrostatic\n", + " │ ├── same_chain\n", + " │ └── vanderwaals\n", + " |\n", + " ├── node_features\n", + " │ ├── _chain_id\n", + " │ ├── _name\n", + " │ ├── _position\n", + " │ ├── hb_acceptors\n", + " │ ├── hb_donors\n", + " │ ├── polarity\n", + " │ ├── res_charge\n", + " │ ├── res_mass\n", + " | ├── res_pI\n", + " | ├── res_size\n", + " | └── res_type\n", + " |\n", + " ├── grid_points\n", + " │ ├── center\n", + " │ ├── x\n", + " │ ├── y\n", + " │ └── z\n", + " |\n", + " ├── mapped_features\n", + " │ ├── _position_000\n", + " │ ├── _position_001\n", + " │ ├── _position_002\n", + " │ ├── covalent\n", + " │ ├── distance\n", + " │ ├── electrostatic\n", + " │ ├── polarity_000\n", + " │ ├── polarity_001\n", + " │ ├── polarity_002\n", + " │ ├── polarity_003\n", + " | ├── ...\n", + " | └── vanderwaals\n", + " |\n", + " └── target_values\n", + " │ ├── BA\n", + " └── binary\n", + "```\n", + "\n", + "`edge_features`, `node_features`, `mapped_features` are [HDF5 Groups](https://docs.h5py.org/en/stable/high/group.html) which contain [HDF5 Datasets](https://docs.h5py.org/en/stable/high/dataset.html) (e.g., `_index`, `electrostatic`, etc.), which in turn contains features values in the form of arrays. `edge_features` and `node_features` refer specificly to the graph representation, while `grid_points` and `mapped_features` refer to the grid mapped from the graph. Each data point generated by deeprank2 has the above structure, with the features and the target changing according to the user's settings. Features starting with `_` are present for human inspection of the data, but they are not used for training models.\n", + "\n", + "It is always a good practice to first explore the data, and then make decision about splitting them in training, test and validation sets. 
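For instance, the graph entry IDs can be collected from the generated HDF5 files and split in a stratified way on the \"binary\" target (a minimal sketch, assuming scikit-learn is installed; the 80/20 proportion and variable names are only illustrative):\n\n```python\nimport glob\nimport os\n\nimport h5py\nfrom sklearn.model_selection import train_test_split\n\nhdf5_files = glob.glob(os.path.join(\"data_processed\", \"ppi\", \"residue\", \"*.hdf5\"))\nentries, labels = [], []\nfor fname in hdf5_files:\n    with h5py.File(fname, \"r\") as hdf5:\n        for mol in hdf5:  # each key is one PPI graph\n            entries.append(mol)\n            labels.append(int(hdf5[mol][\"target_values\"][\"binary\"][()]))\n\n# Stratified split of the entry IDs, to be reused when building training/validation/test datasets\ntrain_ids, test_ids = train_test_split(entries, test_size=0.2, stratify=labels, random_state=42)\n```\n\n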
There are different possible ways for doing it.\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Pandas dataframe\n", + "\n", + "The edge and node features just generated can be explored by instantiating the `GraphDataset` object, and then using `hdf5_to_pandas` method which converts node and edge features into a [Pandas](https://pandas.pydata.org/) dataframe. Each row represents a ppi in the form of a graph.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "processed_data = glob.glob(os.path.join(processed_data_path, \"residue\", \"*.hdf5\"))\n", + "dataset = GraphDataset(processed_data, target=\"binary\")\n", + "dataset_df = dataset.hdf5_to_pandas()\n", + "dataset_df.head()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also generate histograms for looking at the features distributions. An example:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fname = os.path.join(processed_data_path, \"residue\", \"res_mass_distance_electrostatic\")\n", + "dataset.save_hist(features=[\"res_mass\", \"distance\", \"electrostatic\"], fname=fname)\n", + "\n", + "im = img.imread(fname + \".png\")\n", + "plt.figure(figsize=(15, 10))\n", + "fig = plt.imshow(im)\n", + "fig.axes.get_xaxis().set_visible(False)\n", + "fig.axes.get_yaxis().set_visible(False)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Other tools\n", + "\n", + "- [HDFView](https://www.hdfgroup.org/downloads/hdfview/), a visual tool written in Java for browsing and editing HDF5 files.\n", + " As representative example, the following is the structure for `BA-100600.pdb` seen from HDF5View:\n", + "\n", + " \n", + "\n", + " Using this tool you can inspect the values of the features visually, for each data point.\n", + "\n", + "- Python packages such as [h5py](https://docs.h5py.org/en/stable/index.html). 
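Besides the graph groups read in the example cell below, the grid groups (`grid_points` and `mapped_features`) can be inspected in the same way (a short sketch, assuming the HDF5 files generated above; the exact feature names depend on the selected feature modules):\n\n```python\nimport glob\nimport os\n\nimport h5py\n\nprocessed_data = glob.glob(os.path.join(\"data_processed\", \"ppi\", \"residue\", \"*.hdf5\"))\nwith h5py.File(processed_data[0], \"r\") as hdf5:\n    entry = next(iter(hdf5))  # first graph in the file\n    print(f\"Grid point axes: {list(hdf5[entry]['grid_points'])}\")\n    print(f\"Grid-mapped features: {list(hdf5[entry]['mapped_features'])}\")\n```\n\n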
Examples:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with h5py.File(processed_data[0], \"r\") as hdf5:\n", + " # List of all graphs in hdf5, each graph representing a ppi\n", + " ids = list(hdf5.keys())\n", + " print(f\"IDs of PPIs in {processed_data[0]}: {ids}\")\n", + " node_features = list(hdf5[ids[0]][\"node_features\"])\n", + " print(f\"Node features: {node_features}\")\n", + " edge_features = list(hdf5[ids[0]][\"edge_features\"])\n", + " print(f\"Edge features: {edge_features}\")\n", + " target_features = list(hdf5[ids[0]][\"target_values\"])\n", + " print(f\"Targets features: {target_features}\")\n", + " # Polarity feature for ids[0], numpy.ndarray\n", + " node_feat_polarity = hdf5[ids[0]][\"node_features\"][\"polarity\"][:]\n", + " print(f\"Polarity feature shape: {node_feat_polarity.shape}\")\n", + " # Electrostatic feature for ids[0], numpy.ndarray\n", + " edge_feat_electrostatic = hdf5[ids[0]][\"edge_features\"][\"electrostatic\"][:]\n", + " print(f\"Electrostatic feature shape: {edge_feat_electrostatic.shape}\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Atomic-level PPIs using `ProteinProteinInterfaceQuery`\n", + "\n", + "Graphs can also be generated at an atomic resolution, very similarly to what has just been done for residue-level.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "queries = QueryCollection()\n", + "\n", + "influence_radius = 5 # max distance in Å between two interacting residues/atoms of two proteins\n", + "max_edge_length = 5\n", + "binary_target_value = 500\n", + "\n", + "print(f\"Adding {len(pdb_files)} queries to the query collection ...\")\n", + "for i in range(len(pdb_files)):\n", + " queries.add(\n", + " ProteinProteinInterfaceQuery(\n", + " pdb_path=pdb_files[i],\n", + " resolution=\"atom\",\n", + " chain_ids=[\"M\", \"P\"],\n", + " influence_radius=influence_radius,\n", + " max_edge_length=max_edge_length,\n", + " targets={\n", + " \"binary\": int(float(bas[i]) <= binary_target_value),\n", + " \"BA\": bas[i], # continuous target value\n", + " },\n", + " ),\n", + " )\n", + " if i + 1 % 20 == 0:\n", + " print(f\"{i+1} queries added to the collection.\")\n", + "\n", + "print(f\"{i+1} queries ready to be processed.\\n\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "grid_settings = GridSettings( # None if you don't want grids\n", + " # the number of points on the x, y, z edges of the cube\n", + " points_counts=[35, 30, 30],\n", + " # x, y, z sizes of the box in Å\n", + " sizes=[1.0, 1.0, 1.0],\n", + ")\n", + "grid_map_method = MapMethod.GAUSSIAN # None if you don't want grids\n", + "\n", + "queries.process(\n", + " prefix=os.path.join(processed_data_path, \"atomic\", \"proc\"),\n", + " feature_modules=[components, contact],\n", + " cpu_count=8,\n", + " combine_output=False,\n", + " grid_settings=grid_settings,\n", + " grid_map_method=grid_map_method,\n", + ")\n", + "\n", + "print(f'The queries processing is done. 
The generated HDF5 files are in {os.path.join(processed_data_path, \"atomic\")}.')" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Again, the data can be inspected using `hdf5_to_pandas` function.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "processed_data = glob.glob(os.path.join(processed_data_path, \"atomic\", \"*.hdf5\"))\n", + "dataset = GraphDataset(processed_data, target=\"binary\")\n", + "dataset_df = dataset.hdf5_to_pandas()\n", + "dataset_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fname = os.path.join(processed_data_path, \"atomic\", \"atom_charge\")\n", + "dataset.save_hist(features=\"atom_charge\", fname=fname)\n", + "\n", + "im = img.imread(fname + \".png\")\n", + "plt.figure(figsize=(8, 8))\n", + "fig = plt.imshow(im)\n", + "fig.axes.get_xaxis().set_visible(False)\n", + "fig.axes.get_yaxis().set_visible(False)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that some of the features are different from the ones generated with the residue-level queries. There are indeed features in `deeprank2.features.components` module which are generated only in atomic graphs, i.e. `atom_type`, `atom_charge`, and `pdb_occupancy`, because they don't make sense only in the atomic graphs' representation.\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "deeprank2", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 } diff --git a/tutorials/data_generation_srv.ipynb b/tutorials/data_generation_srv.ipynb index f8606c7e6..832543958 100644 --- a/tutorials/data_generation_srv.ipynb +++ b/tutorials/data_generation_srv.ipynb @@ -1,574 +1,583 @@ { - "cells": [ - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Data preparation for single-residue variants\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Introduction\n", - "\n", - "\n", - "\n", - "This tutorial will demonstrate the use of DeepRank2 for generating single-residue variants (SRVs) graphs and saving them as [HDF5 files](https://en.wikipedia.org/wiki/Hierarchical_Data_Format) files, using [PBD files]() of protein structures as input.\n", - "\n", - "In this data processing phase, a local neighborhood around the mutated residue is selected for each SRV according to a radius threshold that the user can customize. All atoms or residues within the threshold are mapped as the nodes to a graph and the interactions between them are the edges of the graph. Each node and edge can have several distinct (structural or physico-chemical) features, which are generated and added during the processing phase as well. Optionally, the graphs can be mapped to volumetric grids (i.e., 3D image-like representations), together with their features. Finally, the mapped data are saved as HDF5 files, which can be used for training predictive models (for details see [training_ppi.ipynb](https://github.com/DeepRank/deeprank-core/blob/main/tutorials/training_ppi.ipynb) tutorial). 
In particular, graphs can be used for the training of Graph Neural Networks (GNNs), and grids can be used for the training of Convolutional Neural Networks (CNNs).\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Input Data\n", - "\n", - "The example data used in this tutorial are available on Zenodo at [this record address](https://zenodo.org/record/13709906). To download the raw data used in this tutorial, please visit the link and download `data_raw.zip`. Unzip it, and save the `data_raw/` folder in the same directory as this notebook. The name and the location of the folder are optional but recommended, as they are the name and the location we will use to refer to the folder throughout the tutorial.\n", - "\n", - "Note that the dataset contains only 96 data points, which is not enough to develop an impactful predictive model, and the scope of its use is indeed only demonstrative and informative for the users.\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Utilities\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Libraries\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The libraries needed for this tutorial:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import pandas as pd\n", - "import glob\n", - "import h5py\n", - "import matplotlib.image as img\n", - "import matplotlib.pyplot as plt\n", - "from deeprank2.query import QueryCollection\n", - "from deeprank2.query import SingleResidueVariantQuery, SingleResidueVariantQuery\n", - "from deeprank2.domain.aminoacidlist import amino_acids_by_code\n", - "from deeprank2.features import components, contact\n", - "from deeprank2.utils.grid import GridSettings, MapMethod\n", - "from deeprank2.dataset import GraphDataset" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Raw files and paths\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The paths for reading raw data and saving the processed ones:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "data_path = os.path.join(\"data_raw\", \"srv\")\n", - "processed_data_path = os.path.join(\"data_processed\", \"srv\")\n", - "os.makedirs(os.path.join(processed_data_path, \"residue\"))\n", - "os.makedirs(os.path.join(processed_data_path, \"atomic\"))\n", - "# Flag limit_data as True if you are running on a machine with limited memory (e.g., Docker container)\n", - "limit_data = True" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "- Raw data are PDB files in `data_raw/srv/pdb/`, which contains atomic coordinates of the protein structure containing the variant.\n", - "- Target data, so in our case pathogenic versus benign labels, are in `data_raw/srv/srv_target_values_curated.csv`.\n", - "- The final SRV processed data will be saved in `data_processed/srv/` folder, which in turns contains a folder for residue-level data and another one for atomic-level data. 
More details about such different levels will come a few cells below.\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "`get_pdb_files_and_target_data` is an helper function used to retrieve the raw pdb files names, SRVs information and target values in a list from the CSV:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def get_pdb_files_and_target_data(data_path):\n", - " csv_data = pd.read_csv(os.path.join(data_path, \"srv_target_values_curated.csv\"))\n", - " pdb_files = glob.glob(os.path.join(data_path, \"pdb\", \"*.ent\"))\n", - " pdb_files.sort()\n", - " pdb_file_names = [os.path.basename(pdb_file) for pdb_file in pdb_files]\n", - " csv_data_indexed = csv_data.set_index(\"pdb_file\")\n", - " csv_data_indexed = csv_data_indexed.loc[pdb_file_names]\n", - " res_numbers = csv_data_indexed.res_number.values.tolist()\n", - " res_wildtypes = csv_data_indexed.res_wildtype.values.tolist()\n", - " res_variants = csv_data_indexed.res_variant.values.tolist()\n", - " targets = csv_data_indexed.target.values.tolist()\n", - " pdb_names = csv_data_indexed.index.values.tolist()\n", - " pdb_files = [data_path + \"/pdb/\" + pdb_name for pdb_name in pdb_names]\n", - " return pdb_files, res_numbers, res_wildtypes, res_variants, targets\n", - "\n", - "\n", - "pdb_files, res_numbers, res_wildtypes, res_variants, targets = get_pdb_files_and_target_data(data_path)\n", - "\n", - "if limit_data:\n", - " pdb_files = pdb_files[:15]" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## `QueryCollection` and `Query` objects\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "For each SRV, so for each data point, a query can be created and added to the `QueryCollection` object, to be processed later on. Different types of queries exist, based on the molecular resolution needed:\n", - "\n", - "A query takes as inputs:\n", - "\n", - "- A `.pdb` file, representing the protein structure containing the SRV.\n", - "- The resolution (`\"residue\"` or `\"atom\"`), i.e. whether each node should represent an amino acid residue or an atom.\n", - "- The chain id of the SRV.\n", - "- The residue number of the missense mutation.\n", - "- The insertion code, used when two residues have the same numbering. The combination of residue numbering and insertion code defines the unique residue.\n", - "- The wildtype amino acid.\n", - "- The variant amino acid.\n", - "- The interaction radius, which determines the threshold distance (in Ångström) for residues/atoms surrounding the mutation that will be included in the graph.\n", - "- The target values associated with the query. For each query/data point, in the use case demonstrated in this tutorial will add a 0 if the SRV belongs to the benign class, and 1 if it belongs to the pathogenic one.\n", - "- The max edge distance, which is the maximum distance between two nodes to generate an edge between them.\n", - "- Optional: The correspondent [Position-Specific Scoring Matrices (PSSMs)](https://en.wikipedia.org/wiki/Position_weight_matrix), per chain identifier, in the form of .pssm files. 
PSSMs are optional and will not be used in this tutorial.\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Residue-level SRV: `SingleResidueVariantQuery`\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "queries = QueryCollection()\n", - "\n", - "influence_radius = 10.0 # radius to select the local neighborhood around the SRV\n", - "max_edge_length = 4.5 # ??\n", - "\n", - "print(f\"Adding {len(pdb_files)} queries to the query collection ...\")\n", - "count = 0\n", - "for i in range(len(pdb_files)):\n", - " queries.add(\n", - " SingleResidueVariantQuery(\n", - " pdb_path=pdb_files[i],\n", - " resolution=\"residue\",\n", - " chain_ids=\"A\",\n", - " variant_residue_number=res_numbers[i],\n", - " insertion_code=None,\n", - " wildtype_amino_acid=amino_acids_by_code[res_wildtypes[i]],\n", - " variant_amino_acid=amino_acids_by_code[res_variants[i]],\n", - " targets={\"binary\": targets[i]},\n", - " influence_radius=influence_radius,\n", - " max_edge_length=max_edge_length,\n", - " )\n", - " )\n", - " count += 1\n", - " if count % 20 == 0:\n", - " print(f\"{count} queries added to the collection.\")\n", - "\n", - "print(f\"Queries ready to be processed.\\n\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Notes on `process()` method\n", - "\n", - "Once all queries have been added to the `QueryCollection` instance, they can be processed. Main parameters of the `process()` method, include:\n", - "\n", - "- `prefix` sets the output file location.\n", - "- `feature_modules` allows you to choose which feature generating modules you want to use. By default, the basic features contained in `deeprank2.features.components` and `deeprank2.features.contact` are generated. Users can add custom features by creating a new module and placing it in the `deeprank2.feature` subpackage. A complete and detailed list of the pre-implemented features per module and more information about how to add custom features can be found [here](https://deeprank2.readthedocs.io/en/latest/features.html).\n", - " - Note that all features generated by a module will be added if that module was selected, and there is no way to only generate specific features from that module. However, during the training phase shown in `training_ppi.ipynb`, it is possible to select only a subset of available features.\n", - "- `cpu_count` can be used to specify how many processes to be run simultaneously, and will coincide with the number of HDF5 files generated. By default it takes all available CPU cores and HDF5 files are squashed into a single file using the `combine_output` setting.\n", - "- Optional: If you want to include grids in the HDF5 files, which represent the mapping of the graphs to a volumetric box, you need to define `grid_settings` and `grid_map_method`, as shown in the example below. 
If they are `None` (default), only graphs are saved.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "grid_settings = GridSettings( # None if you don't want grids\n", - " # the number of points on the x, y, z edges of the cube\n", - " points_counts=[35, 30, 30],\n", - " # x, y, z sizes of the box in Å\n", - " sizes=[1.0, 1.0, 1.0],\n", - ")\n", - "grid_map_method = MapMethod.GAUSSIAN # None if you don't want grids\n", - "\n", - "queries.process(\n", - " prefix=os.path.join(processed_data_path, \"residue\", \"proc\"),\n", - " feature_modules=[components, contact],\n", - " cpu_count=8,\n", - " combine_output=False,\n", - " grid_settings=grid_settings,\n", - " grid_map_method=grid_map_method,\n", - ")\n", - "\n", - "print(f'The queries processing is done. The generated HDF5 files are in {os.path.join(processed_data_path, \"residue\")}.')" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Exploring data\n", - "\n", - "As representative example, the following is the HDF5 structure generated by the previous code for `pdb2ooh.ent`, so for one single graph, which represents one protein structure containing a SRV in position 112, for the graph + grid case:\n", - "\n", - "```bash\n", - "└── residue-graph:A:112:Threonine->Isoleucine:pdb2ooh\n", - " |\n", - " ├── edge_features\n", - " │ ├── _index\n", - " │ ├── _name\n", - " │ ├── covalent\n", - " │ ├── distance\n", - " │ ├── electrostatic\n", - " │ ├── same_chain\n", - " │ └── vanderwaals\n", - " |\n", - " ├── node_features\n", - " │ ├── _chain_id\n", - " │ ├── _name\n", - " │ ├── _position\n", - " │ ├── diff_charge\n", - " │ ├── diff_hb_donors\n", - " │ ├── diff_hb_acceptors\n", - " │ ├── diff_mass\n", - " │ ├── diff_pI\n", - " │ ├── diff_polarity\n", - " │ ├── diff_size\n", - " │ ├── hb_acceptors\n", - " │ ├── hb_donors\n", - " │ ├── polarity\n", - " │ ├── res_charge\n", - " │ ├── res_mass\n", - " | ├── res_pI\n", - " | ├── res_size\n", - " | ├── res_type\n", - " | └── variant_res\n", - " |\n", - " ├── grid_points\n", - " │ ├── center\n", - " │ ├── x\n", - " │ ├── y\n", - " │ └── z\n", - " |\n", - " ├── mapped_features\n", - " │ ├── _position_000\n", - " │ ├── _position_001\n", - " │ ├── _position_002\n", - " │ ├── covalent\n", - " │ ├── distance\n", - " │ ├── electrostatic\n", - " │ ├── diff_polarity_000\n", - " │ ├── diff_polarity_001\n", - " │ ├── diff_polarity_002\n", - " │ ├── diff_polarity_003\n", - " | ├── ...\n", - " | └── vanderwaals\n", - " |\n", - " └── target_values\n", - " └── binary\n", - "```\n", - "\n", - "`edge_features`, `node_features`, `mapped_features` are [HDF5 Groups](https://docs.h5py.org/en/stable/high/group.html) which contain [HDF5 Datasets](https://docs.h5py.org/en/stable/high/dataset.html) (e.g., `_index`, `electrostatic`, etc.), which in turn contains features values in the form of arrays. `edge_features` and `node_features` refer specificly to the graph representation, while `grid_points` and `mapped_features` refer to the grid mapped from the graph. Each data point generated by deeprank2 has the above structure, with the features and the target changing according to the user's settings. Features starting with `_` are present for human inspection of the data, but they are not used for training models.\n", - "\n", - "It is always a good practice to first explore the data, and then make decision about splitting them in training, test and validation sets. 
There are different possible ways for doing it.\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Pandas dataframe\n", - "\n", - "The edge and node features just generated can be explored by instantiating the `GraphDataset` object, and then using `hdf5_to_pandas` method which converts node and edge features into a [Pandas](https://pandas.pydata.org/) dataframe. Each row represents a ppi in the form of a graph.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "processed_data = glob.glob(os.path.join(processed_data_path, \"residue\", \"*.hdf5\"))\n", - "dataset = GraphDataset(processed_data, target=\"binary\")\n", - "df = dataset.hdf5_to_pandas()\n", - "df.head()" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can also generate histograms for looking at the features distributions. An example:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fname = os.path.join(processed_data_path, \"residue\", \"_\".join([\"res_mass\", \"distance\", \"electrostatic\"]))\n", - "dataset.save_hist(features=[\"res_mass\", \"distance\", \"electrostatic\"], fname=fname)\n", - "\n", - "im = img.imread(fname + \".png\")\n", - "plt.figure(figsize=(15, 10))\n", - "fig = plt.imshow(im)\n", - "fig.axes.get_xaxis().set_visible(False)\n", - "fig.axes.get_yaxis().set_visible(False)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Other tools\n", - "\n", - "- [HDFView](https://www.hdfgroup.org/downloads/hdfview/), a visual tool written in Java for browsing and editing HDF5 files.\n", - " As representative example, the following is the structure for `pdb2ooh.ent` seen from HDF5View:\n", - "\n", - " \n", - "\n", - " Using this tool you can inspect the values of the features visually, for each data point.\n", - "\n", - "- Python packages such as [h5py](https://docs.h5py.org/en/stable/index.html). 
Examples:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "with h5py.File(processed_data[0], \"r\") as hdf5:\n", - " # List of all graphs in hdf5, each graph representing\n", - " # a SRV and its sourrouding environment\n", - " ids = list(hdf5.keys())\n", - " print(f\"IDs of SRVs in {processed_data[0]}: {ids}\")\n", - " node_features = list(hdf5[ids[0]][\"node_features\"])\n", - " print(f\"Node features: {node_features}\")\n", - " edge_features = list(hdf5[ids[0]][\"edge_features\"])\n", - " print(f\"Edge features: {edge_features}\")\n", - " target_features = list(hdf5[ids[0]][\"target_values\"])\n", - " print(f\"Targets features: {target_features}\")\n", - " # Polarity feature for ids[0], numpy.ndarray\n", - " node_feat_polarity = hdf5[ids[0]][\"node_features\"][\"polarity\"][:]\n", - " print(f\"Polarity feature shape: {node_feat_polarity.shape}\")\n", - " # Electrostatic feature for ids[0], numpy.ndarray\n", - " edge_feat_electrostatic = hdf5[ids[0]][\"edge_features\"][\"electrostatic\"][:]\n", - " print(f\"Electrostatic feature shape: {edge_feat_electrostatic.shape}\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Atomic-level SRV: `SingleResidueVariantQuery`\n", - "\n", - "Graphs can also be generated at an atomic resolution, very similarly to what has just been done for residue-level.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "queries = QueryCollection()\n", - "\n", - "influence_radius = 10.0 # radius to select the local neighborhood around the SRV\n", - "max_edge_length = 4.5 # ??\n", - "\n", - "print(f\"Adding {len(pdb_files)} queries to the query collection ...\")\n", - "count = 0\n", - "for i in range(len(pdb_files)):\n", - " queries.add(\n", - " SingleResidueVariantQuery(\n", - " pdb_path=pdb_files[i],\n", - " resolution=\"atom\",\n", - " chain_ids=\"A\",\n", - " variant_residue_number=res_numbers[i],\n", - " insertion_code=None,\n", - " wildtype_amino_acid=amino_acids_by_code[res_wildtypes[i]],\n", - " variant_amino_acid=amino_acids_by_code[res_variants[i]],\n", - " targets={\"binary\": targets[i]},\n", - " influence_radius=influence_radius,\n", - " max_edge_length=max_edge_length,\n", - " )\n", - " )\n", - " count += 1\n", - " if count % 20 == 0:\n", - " print(f\"{count} queries added to the collection.\")\n", - "\n", - "print(\"Queries ready to be processed.\\n\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "grid_settings = GridSettings( # None if you don't want grids\n", - " # the number of points on the x, y, z edges of the cube\n", - " points_counts=[35, 30, 30],\n", - " # x, y, z sizes of the box in Å\n", - " sizes=[1.0, 1.0, 1.0],\n", - ")\n", - "grid_map_method = MapMethod.GAUSSIAN # None if you don't want grids\n", - "\n", - "queries.process(\n", - " prefix=os.path.join(processed_data_path, \"atomic\", \"proc\"),\n", - " feature_modules=[components, contact],\n", - " cpu_count=8,\n", - " combine_output=False,\n", - " grid_settings=grid_settings,\n", - " grid_map_method=grid_map_method,\n", - ")\n", - "\n", - "print(f'The queries processing is done. 
The generated HDF5 files are in {os.path.join(processed_data_path, \"atomic\")}.')" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Again, the data can be inspected using `hdf5_to_pandas` function.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "processed_data = glob.glob(os.path.join(processed_data_path, \"atomic\", \"*.hdf5\"))\n", - "dataset = GraphDataset(processed_data, target=\"binary\")\n", - "df = dataset.hdf5_to_pandas()\n", - "df.head()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fname = os.path.join(processed_data_path, \"atomic\", \"atom_charge\")\n", - "dataset.save_hist(features=\"atom_charge\", fname=fname)\n", - "\n", - "im = img.imread(fname + \".png\")\n", - "plt.figure(figsize=(8, 8))\n", - "fig = plt.imshow(im)\n", - "fig.axes.get_xaxis().set_visible(False)\n", - "fig.axes.get_yaxis().set_visible(False)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Note that some of the features are different from the ones generated with the residue-level queries. There are indeed features in `deeprank2.features.components` module which are generated only in atomic graphs, i.e. `atom_type`, `atom_charge`, and `pdb_occupancy`, because they don't make sense only in the atomic graphs' representation.\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "deeprankcore", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Data preparation for single-residue variants\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Introduction\n", + "\n", + "\n", + "\n", + "This tutorial will demonstrate the use of DeepRank2 for generating single-residue variants (SRVs) graphs and saving them as [HDF5 files](https://en.wikipedia.org/wiki/Hierarchical_Data_Format) files, using [PBD files]() of protein structures as input.\n", + "\n", + "In this data processing phase, a local neighborhood around the mutated residue is selected for each SRV according to a radius threshold that the user can customize. All atoms or residues within the threshold are mapped as the nodes to a graph and the interactions between them are the edges of the graph. Each node and edge can have several distinct (structural or physico-chemical) features, which are generated and added during the processing phase as well. Optionally, the graphs can be mapped to volumetric grids (i.e., 3D image-like representations), together with their features. Finally, the mapped data are saved as HDF5 files, which can be used for training predictive models (for details see [training_ppi.ipynb](https://github.com/DeepRank/deeprank-core/blob/main/tutorials/training_ppi.ipynb) tutorial). 
In particular, graphs can be used for the training of Graph Neural Networks (GNNs), and grids can be used for the training of Convolutional Neural Networks (CNNs).\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Input Data\n", + "\n", + "The example data used in this tutorial are available on Zenodo at [this record address](https://zenodo.org/record/7997585). To download the raw data used in this tutorial, please visit the link and download `data_raw.zip`. Unzip it, and save the `data_raw/` folder in the same directory as this notebook. The name and the location of the folder are optional but recommended, as they are the name and the location we will use to refer to the folder throughout the tutorial.\n", + "\n", + "Note that the dataset contains only 96 data points, which is not enough to develop an impactful predictive model, and the scope of its use is indeed only demonstrative and informative for the users.\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Utilities\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Libraries\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The libraries needed for this tutorial:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import contextlib\n", + "import glob\n", + "import os\n", + "from pathlib import Path\n", + "\n", + "import h5py\n", + "import matplotlib.image as img\n", + "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", + "\n", + "from deeprank2.dataset import GraphDataset\n", + "from deeprank2.domain.aminoacidlist import amino_acids_by_code\n", + "from deeprank2.features import components, contact\n", + "from deeprank2.query import QueryCollection, SingleResidueVariantQuery\n", + "from deeprank2.utils.grid import GridSettings, MapMethod" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Raw files and paths\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The paths for reading raw data and saving the processed ones:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "data_path = os.path.join(\"data_raw\", \"srv\")\n", + "processed_data_path = os.path.join(\"data_processed\", \"srv\")\n", + "residue_data_path = os.path.join(processed_data_path, \"residue\")\n", + "atomic_data_path = os.path.join(processed_data_path, \"atomic\")\n", + "\n", + "for output_path in [residue_data_path, atomic_data_path]:\n", + " os.makedirs(output_path, exist_ok=True)\n", + " if any(Path(output_path).iterdir()):\n", + " msg = f\"Please store any required data from `./{output_path}` and delete the folder.\\nThen re-run this cell to continue.\"\n", + " raise FileExistsError(msg)\n", + "\n", + "# Flag limit_data as True if you are running on a machine with limited memory (e.g., Docker container)\n", + "limit_data = False" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- Raw data are PDB files in `data_raw/srv/pdb/`, which contains atomic coordinates of the protein structure containing the variant.\n", + "- Target data, so in our case pathogenic versus benign labels, are in `data_raw/srv/srv_target_values_curated.csv`.\n", + "- The final SRV processed data will be saved 
in `data_processed/srv/` folder, which in turns contains a folder for residue-level data and another one for atomic-level data. More details about such different levels will come a few cells below.\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`get_pdb_files_and_target_data` is an helper function used to retrieve the raw pdb files names, SRVs information and target values in a list from the CSV:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def get_pdb_files_and_target_data(data_path: str) -> tuple[list[str], list[int], list[str], list[str], list[float]]:\n", + " csv_data = pd.read_csv(os.path.join(data_path, \"srv_target_values_curated.csv\"))\n", + " pdb_files = glob.glob(os.path.join(data_path, \"pdb\", \"*.ent\"))\n", + " pdb_files.sort()\n", + " pdb_file_names = [os.path.basename(pdb_file) for pdb_file in pdb_files]\n", + " csv_data_indexed = csv_data.set_index(\"pdb_file\")\n", + " with contextlib.suppress(KeyError):\n", + " csv_data_indexed = csv_data_indexed.loc[pdb_file_names]\n", + " res_numbers = csv_data_indexed.res_number.tolist()\n", + " res_wildtypes = csv_data_indexed.res_wildtype.tolist()\n", + " res_variants = csv_data_indexed.res_variant.tolist()\n", + " targets = csv_data_indexed.target.tolist()\n", + " pdb_names = csv_data_indexed.index.tolist()\n", + " pdb_files = [data_path + \"/pdb/\" + pdb_name for pdb_name in pdb_names]\n", + "\n", + " return pdb_files, res_numbers, res_wildtypes, res_variants, targets\n", + "\n", + "\n", + "pdb_files, res_numbers, res_wildtypes, res_variants, targets = get_pdb_files_and_target_data(data_path)\n", + "\n", + "if limit_data:\n", + " pdb_files = pdb_files[:15]" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## `QueryCollection` and `Query` objects\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For each SRV, so for each data point, a query can be created and added to the `QueryCollection` object, to be processed later on. Different types of queries exist, based on the molecular resolution needed:\n", + "\n", + "A query takes as inputs:\n", + "\n", + "- A `.pdb` file, representing the protein structure containing the SRV.\n", + "- The resolution (`\"residue\"` or `\"atom\"`), i.e. whether each node should represent an amino acid residue or an atom.\n", + "- The chain id of the SRV.\n", + "- The residue number of the missense mutation.\n", + "- The insertion code, used when two residues have the same numbering. The combination of residue numbering and insertion code defines the unique residue.\n", + "- The wildtype amino acid.\n", + "- The variant amino acid.\n", + "- The interaction radius, which determines the threshold distance (in Ångström) for residues/atoms surrounding the mutation that will be included in the graph.\n", + "- The target values associated with the query. For each query/data point, in the use case demonstrated in this tutorial will add a 0 if the SRV belongs to the benign class, and 1 if it belongs to the pathogenic one.\n", + "- The max edge distance, which is the maximum distance between two nodes to generate an edge between them.\n", + "- Optional: The correspondent [Position-Specific Scoring Matrices (PSSMs)](https://en.wikipedia.org/wiki/Position_weight_matrix), per chain identifier, in the form of .pssm files. 
PSSMs are optional and will not be used in this tutorial.\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Residue-level SRV: `SingleResidueVariantQuery`\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "queries = QueryCollection()\n", + "\n", + "influence_radius = 10.0 # radius to select the local neighborhood around the SRV\n", + "max_edge_length = 4.5 # ??\n", + "\n", + "print(f\"Adding {len(pdb_files)} queries to the query collection ...\")\n", + "for i in range(len(pdb_files)):\n", + " queries.add(\n", + " SingleResidueVariantQuery(\n", + " pdb_path=pdb_files[i],\n", + " resolution=\"residue\",\n", + " chain_ids=\"A\",\n", + " variant_residue_number=res_numbers[i],\n", + " insertion_code=None,\n", + " wildtype_amino_acid=amino_acids_by_code[res_wildtypes[i]],\n", + " variant_amino_acid=amino_acids_by_code[res_variants[i]],\n", + " targets={\"binary\": targets[i]},\n", + " influence_radius=influence_radius,\n", + " max_edge_length=max_edge_length,\n", + " ),\n", + " )\n", + " if i + 1 % 20 == 0:\n", + " print(f\"{i+1} queries added to the collection.\")\n", + "\n", + "print(f\"{i+1} queries ready to be processed.\\n\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Notes on `process()` method\n", + "\n", + "Once all queries have been added to the `QueryCollection` instance, they can be processed. Main parameters of the `process()` method, include:\n", + "\n", + "- `prefix` sets the output file location.\n", + "- `feature_modules` allows you to choose which feature generating modules you want to use. By default, the basic features contained in `deeprank2.features.components` and `deeprank2.features.contact` are generated. Users can add custom features by creating a new module and placing it in the `deeprank2.feature` subpackage. A complete and detailed list of the pre-implemented features per module and more information about how to add custom features can be found [here](https://deeprank2.readthedocs.io/en/latest/features.html).\n", + " - Note that all features generated by a module will be added if that module was selected, and there is no way to only generate specific features from that module. However, during the training phase shown in `training_ppi.ipynb`, it is possible to select only a subset of available features.\n", + "- `cpu_count` can be used to specify how many processes to be run simultaneously, and will coincide with the number of HDF5 files generated. By default it takes all available CPU cores and HDF5 files are squashed into a single file using the `combine_output` setting.\n", + "- Optional: If you want to include grids in the HDF5 files, which represent the mapping of the graphs to a volumetric box, you need to define `grid_settings` and `grid_map_method`, as shown in the example below. 
If they are `None` (default), only graphs are saved.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "grid_settings = GridSettings( # None if you don't want grids\n", + " # the number of points on the x, y, z edges of the cube\n", + " points_counts=[35, 30, 30],\n", + " # x, y, z sizes of the box in Å\n", + " sizes=[1.0, 1.0, 1.0],\n", + ")\n", + "grid_map_method = MapMethod.GAUSSIAN # None if you don't want grids\n", + "\n", + "queries.process(\n", + " prefix=os.path.join(processed_data_path, \"residue\", \"proc\"),\n", + " feature_modules=[components, contact],\n", + " cpu_count=8,\n", + " combine_output=False,\n", + " grid_settings=grid_settings,\n", + " grid_map_method=grid_map_method,\n", + ")\n", + "\n", + "print(f'The queries processing is done. The generated HDF5 files are in {os.path.join(processed_data_path, \"residue\")}.')" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Exploring data\n", + "\n", + "As representative example, the following is the HDF5 structure generated by the previous code for `pdb2ooh.ent`, so for one single graph, which represents one protein structure containing a SRV in position 112, for the graph + grid case:\n", + "\n", + "```bash\n", + "└── residue-graph:A:112:Threonine->Isoleucine:pdb2ooh\n", + " |\n", + " ├── edge_features\n", + " │ ├── _index\n", + " │ ├── _name\n", + " │ ├── covalent\n", + " │ ├── distance\n", + " │ ├── electrostatic\n", + " │ ├── same_chain\n", + " │ └── vanderwaals\n", + " |\n", + " ├── node_features\n", + " │ ├── _chain_id\n", + " │ ├── _name\n", + " │ ├── _position\n", + " │ ├── diff_charge\n", + " │ ├── diff_hb_donors\n", + " │ ├── diff_hb_acceptors\n", + " │ ├── diff_mass\n", + " │ ├── diff_pI\n", + " │ ├── diff_polarity\n", + " │ ├── diff_size\n", + " │ ├── hb_acceptors\n", + " │ ├── hb_donors\n", + " │ ├── polarity\n", + " │ ├── res_charge\n", + " │ ├── res_mass\n", + " | ├── res_pI\n", + " | ├── res_size\n", + " | ├── res_type\n", + " | └── variant_res\n", + " |\n", + " ├── grid_points\n", + " │ ├── center\n", + " │ ├── x\n", + " │ ├── y\n", + " │ └── z\n", + " |\n", + " ├── mapped_features\n", + " │ ├── _position_000\n", + " │ ├── _position_001\n", + " │ ├── _position_002\n", + " │ ├── covalent\n", + " │ ├── distance\n", + " │ ├── electrostatic\n", + " │ ├── diff_polarity_000\n", + " │ ├── diff_polarity_001\n", + " │ ├── diff_polarity_002\n", + " │ ├── diff_polarity_003\n", + " | ├── ...\n", + " | └── vanderwaals\n", + " |\n", + " └── target_values\n", + " └── binary\n", + "```\n", + "\n", + "`edge_features`, `node_features`, `mapped_features` are [HDF5 Groups](https://docs.h5py.org/en/stable/high/group.html) which contain [HDF5 Datasets](https://docs.h5py.org/en/stable/high/dataset.html) (e.g., `_index`, `electrostatic`, etc.), which in turn contains features values in the form of arrays. `edge_features` and `node_features` refer specificly to the graph representation, while `grid_points` and `mapped_features` refer to the grid mapped from the graph. Each data point generated by deeprank2 has the above structure, with the features and the target changing according to the user's settings. Features starting with `_` are present for human inspection of the data, but they are not used for training models.\n", + "\n", + "It is always a good practice to first explore the data, and then make decision about splitting them in training, test and validation sets. 
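For instance, the balance between benign (0) and pathogenic (1) labels can be checked before deciding how to split (a minimal sketch using h5py; the paths follow the folders defined above):\n\n```python\nimport glob\nimport os\nfrom collections import Counter\n\nimport h5py\n\nhdf5_files = glob.glob(os.path.join(\"data_processed\", \"srv\", \"residue\", \"*.hdf5\"))\ntargets = []\nfor fname in hdf5_files:\n    with h5py.File(fname, \"r\") as hdf5:\n        for mol in hdf5:  # each key is one SRV graph\n            targets.append(int(hdf5[mol][\"target_values\"][\"binary\"][()]))\n\n# Class balance, useful to decide on a stratified train/validation/test split\nprint(Counter(targets))\n```\n\n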
There are different possible ways for doing it.\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Pandas dataframe\n", + "\n", + "The edge and node features just generated can be explored by instantiating the `GraphDataset` object, and then using `hdf5_to_pandas` method which converts node and edge features into a [Pandas](https://pandas.pydata.org/) dataframe. Each row represents a ppi in the form of a graph.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "processed_data = glob.glob(os.path.join(processed_data_path, \"residue\", \"*.hdf5\"))\n", + "dataset = GraphDataset(processed_data, target=\"binary\")\n", + "dataset_df = dataset.hdf5_to_pandas()\n", + "dataset_df.head()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also generate histograms for looking at the features distributions. An example:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fname = os.path.join(processed_data_path, \"residue\", \"res_mass_distance_electrostatic\")\n", + "\n", + "dataset.save_hist(features=[\"res_mass\", \"distance\", \"electrostatic\"], fname=fname)\n", + "\n", + "im = img.imread(fname + \".png\")\n", + "plt.figure(figsize=(15, 10))\n", + "fig = plt.imshow(im)\n", + "fig.axes.get_xaxis().set_visible(False)\n", + "fig.axes.get_yaxis().set_visible(False)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Other tools\n", + "\n", + "- [HDFView](https://www.hdfgroup.org/downloads/hdfview/), a visual tool written in Java for browsing and editing HDF5 files.\n", + " As representative example, the following is the structure for `pdb2ooh.ent` seen from HDF5View:\n", + "\n", + " \n", + "\n", + " Using this tool you can inspect the values of the features visually, for each data point.\n", + "\n", + "- Python packages such as [h5py](https://docs.h5py.org/en/stable/index.html). 
Examples:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with h5py.File(processed_data[0], \"r\") as hdf5:\n", + " # List of all graphs in hdf5, each graph representing\n", + " # a SRV and its sourrouding environment\n", + " ids = list(hdf5.keys())\n", + " print(f\"IDs of SRVs in {processed_data[0]}: {ids}\")\n", + " node_features = list(hdf5[ids[0]][\"node_features\"])\n", + " print(f\"Node features: {node_features}\")\n", + " edge_features = list(hdf5[ids[0]][\"edge_features\"])\n", + " print(f\"Edge features: {edge_features}\")\n", + " target_features = list(hdf5[ids[0]][\"target_values\"])\n", + " print(f\"Targets features: {target_features}\")\n", + " # Polarity feature for ids[0], numpy.ndarray\n", + " node_feat_polarity = hdf5[ids[0]][\"node_features\"][\"polarity\"][:]\n", + " print(f\"Polarity feature shape: {node_feat_polarity.shape}\")\n", + " # Electrostatic feature for ids[0], numpy.ndarray\n", + " edge_feat_electrostatic = hdf5[ids[0]][\"edge_features\"][\"electrostatic\"][:]\n", + " print(f\"Electrostatic feature shape: {edge_feat_electrostatic.shape}\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Atomic-level SRV: `SingleResidueVariantQuery`\n", + "\n", + "Graphs can also be generated at an atomic resolution, very similarly to what has just been done for residue-level.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "queries = QueryCollection()\n", + "\n", + "influence_radius = 10.0 # radius to select the local neighborhood around the SRV\n", + "max_edge_length = 4.5 # ??\n", + "\n", + "print(f\"Adding {len(pdb_files)} queries to the query collection ...\")\n", + "for i in range(len(pdb_files)):\n", + " queries.add(\n", + " SingleResidueVariantQuery(\n", + " pdb_path=pdb_files[i],\n", + " resolution=\"atom\",\n", + " chain_ids=\"A\",\n", + " variant_residue_number=res_numbers[i],\n", + " insertion_code=None,\n", + " wildtype_amino_acid=amino_acids_by_code[res_wildtypes[i]],\n", + " variant_amino_acid=amino_acids_by_code[res_variants[i]],\n", + " targets={\"binary\": targets[i]},\n", + " influence_radius=influence_radius,\n", + " max_edge_length=max_edge_length,\n", + " ),\n", + " )\n", + " if i + 1 % 20 == 0:\n", + " print(f\"{i+1} queries added to the collection.\")\n", + "\n", + "print(f\"{i+1} queries ready to be processed.\\n\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "grid_settings = GridSettings( # None if you don't want grids\n", + " # the number of points on the x, y, z edges of the cube\n", + " points_counts=[35, 30, 30],\n", + " # x, y, z sizes of the box in Å\n", + " sizes=[1.0, 1.0, 1.0],\n", + ")\n", + "grid_map_method = MapMethod.GAUSSIAN # None if you don't want grids\n", + "\n", + "queries.process(\n", + " prefix=os.path.join(processed_data_path, \"atomic\", \"proc\"),\n", + " feature_modules=[components, contact],\n", + " cpu_count=8,\n", + " combine_output=False,\n", + " grid_settings=grid_settings,\n", + " grid_map_method=grid_map_method,\n", + ")\n", + "\n", + "print(f'The queries processing is done. 
The generated HDF5 files are in {os.path.join(processed_data_path, \"atomic\")}.')" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Again, the data can be inspected using `hdf5_to_pandas` function.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "processed_data = glob.glob(os.path.join(processed_data_path, \"atomic\", \"*.hdf5\"))\n", + "dataset = GraphDataset(processed_data, target=\"binary\")\n", + "dataset_df = dataset.hdf5_to_pandas()\n", + "dataset_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fname = os.path.join(processed_data_path, \"atomic\", \"atom_charge\")\n", + "dataset.save_hist(features=\"atom_charge\", fname=fname)\n", + "\n", + "im = img.imread(fname + \".png\")\n", + "plt.figure(figsize=(8, 8))\n", + "fig = plt.imshow(im)\n", + "fig.axes.get_xaxis().set_visible(False)\n", + "fig.axes.get_yaxis().set_visible(False)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that some of the features are different from the ones generated with the residue-level queries. There are indeed features in `deeprank2.features.components` module which are generated only in atomic graphs, i.e. `atom_type`, `atom_charge`, and `pdb_occupancy`, because they don't make sense only in the atomic graphs' representation.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "deeprankcore", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 } diff --git a/tutorials/training.ipynb b/tutorials/training.ipynb index 2c662b8d2..db64ae534 100644 --- a/tutorials/training.ipynb +++ b/tutorials/training.ipynb @@ -19,7 +19,7 @@ "\n", "This tutorial will demonstrate the use of DeepRank2 for training graph neural networks (GNNs) and convolutional neural networks (CNNs) using protein-protein interface (PPI) or single-residue variant (SRV) data for classification and regression predictive tasks.\n", "\n", - "This tutorial assumes that the PPI data of interest have already been generated and saved as [HDF5 files](https://en.wikipedia.org/wiki/Hierarchical_Data_Format), with the data structure that DeepRank2 expects. This data can be generated using the [data_generation_ppi.ipynb](https://github.com/DeepRank/deeprank2/blob/main/tutorials/data_generation_ppi.ipynb) tutorial or downloaded from Zenodo at [this record address](https://zenodo.org/record/8349335). For more details on the data structure, please refer to the other tutorial, which also contains a detailed description of how the data is generated from PDB files.\n", + "This tutorial assumes that the PPI data of interest have already been generated and saved as [HDF5 files](https://en.wikipedia.org/wiki/Hierarchical_Data_Format), with the data structure that DeepRank2 expects. This data can be generated using the [data_generation_ppi.ipynb](https://github.com/DeepRank/deeprank2/blob/main/tutorials/data_generation_ppi.ipynb) tutorial or downloaded from Zenodo at [this record address](https://zenodo.org/record/7997585). 
For more details on the data structure, please refer to the other tutorial, which also contains a detailed description of how the data is generated from PDB files.\n", "\n", "This tutorial assumes also a basic knowledge of the [PyTorch](https://pytorch.org/) framework, on top of which the machine learning pipeline of DeepRank2 has been developed, for which many online tutorials exist.\n" ] @@ -33,7 +33,7 @@ "\n", "If you have previously run `data_generation_ppi.ipynb` or `data_generation_srv.ipynb` notebook, then their output can be directly used as input for this tutorial.\n", "\n", - "Alternatively, preprocessed HDF5 files can be downloaded directly from Zenodo at [this record address](https://zenodo.org/record/13709906). To download the data used in this tutorial, please visit the link and download `data_processed.zip`. Unzip it, and save the `data_processed/` folder in the same directory as this notebook. The name and the location of the folder are optional but recommended, as they are the name and the location we will use to refer to the folder throughout the tutorial.\n", + "Alternatively, preprocessed HDF5 files can be downloaded directly from Zenodo at [this record address](https://zenodo.org/record/7997585). To download the data used in this tutorial, please visit the link and download `data_processed.zip`. Unzip it, and save the `data_processed/` folder in the same directory as this notebook. The name and the location of the folder are optional but recommended, as they are the name and the location we will use to refer to the folder throughout the tutorial.\n", "\n", "Note that the datasets contain only ~100 data points each, which is not enough to develop an impactful predictive model, and the scope of their use is indeed only demonstrative and informative for the users.\n" ] @@ -68,30 +68,33 @@ "metadata": {}, "outputs": [], "source": [ - "import logging\n", "import glob\n", + "import logging\n", "import os\n", + "import warnings\n", + "\n", "import h5py\n", + "import numpy as np\n", "import pandas as pd\n", - "from sklearn.model_selection import train_test_split\n", - "from sklearn.metrics import roc_curve, auc, precision_score, recall_score, accuracy_score, f1_score\n", "import plotly.express as px\n", "import torch\n", - "import numpy as np\n", + "from sklearn.metrics import accuracy_score, auc, f1_score, precision_score, recall_score, roc_curve\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "from deeprank2.dataset import GraphDataset, GridDataset\n", + "from deeprank2.neuralnets.cnn.model3d import CnnClassification\n", + "from deeprank2.neuralnets.gnn.vanilla_gnn import VanillaNetwork\n", + "from deeprank2.trainer import Trainer\n", + "from deeprank2.utils.exporters import HDF5OutputExporter\n", "\n", "np.seterr(divide=\"ignore\")\n", "np.seterr(invalid=\"ignore\")\n", - "import pandas as pd\n", "\n", "logging.basicConfig(level=logging.INFO)\n", - "from deeprank2.dataset import GraphDataset, GridDataset\n", - "from deeprank2.trainer import Trainer\n", - "from deeprank2.neuralnets.gnn.vanilla_gnn import VanillaNetwork\n", - "from deeprank2.neuralnets.cnn.model3d import CnnClassification\n", - "from deeprank2.utils.exporters import HDF5OutputExporter\n", - "import warnings\n", "\n", - "warnings.filterwarnings(\"ignore\")" + "warnings.filterwarnings(\"ignore\")\n", + "\n", + "# ruff: noqa: PD901" ] }, { @@ -146,7 +149,7 @@ "df_dict[\"target\"] = []\n", "for fname in input_data_path:\n", " with h5py.File(fname, \"r\") as hdf5:\n", - " for mol in 
hdf5.keys():\n", + " for mol in hdf5:\n", " target_value = float(hdf5[mol][\"target_values\"][\"binary\"][()])\n", " df_dict[\"entry\"].append(mol)\n", " df_dict[\"target\"].append(target_value)\n", @@ -176,7 +179,7 @@ "df_train, df_test = train_test_split(df, test_size=0.1, stratify=df.target, random_state=42)\n", "df_train, df_valid = train_test_split(df_train, test_size=0.2, stratify=df_train.target, random_state=42)\n", "\n", - "print(f\"Data statistics:\\n\")\n", + "print(\"Data statistics:\\n\")\n", "print(f\"Total samples: {len(df)}\\n\")\n", "print(f\"Training set: {len(df_train)} samples, {round(100*len(df_train)/len(df))}%\")\n", "print(f\"\\t- Class 0: {len(df_train[df_train.target == 0])} samples, {round(100*len(df_train[df_train.target == 0])/len(df_train))}%\")\n", @@ -478,12 +481,12 @@ "df = pd.concat([output_train, output_test])\n", "df_plot = df[(df.epoch == trainer.epoch_saved_model) | ((df.epoch == trainer.epoch_saved_model) & (df.phase == \"testing\"))]\n", "\n", - "for idx, set in enumerate([\"training\", \"validation\", \"testing\"]):\n", - " df_plot_phase = df_plot[(df_plot.phase == set)]\n", + "for dataset in [\"training\", \"validation\", \"testing\"]:\n", + " df_plot_phase = df_plot[(df_plot.phase == dataset)]\n", " y_true = df_plot_phase.target\n", - " y_score = np.array(df_plot_phase.output.values.tolist())[:, 1]\n", + " y_score = np.array(df_plot_phase.output.tolist())[:, 1]\n", "\n", - " print(f\"\\nMetrics for {set}:\")\n", + " print(f\"\\nMetrics for {dataset}:\")\n", " fpr_roc, tpr_roc, thr_roc = roc_curve(y_true, y_score)\n", " auc_score = auc(fpr_roc, tpr_roc)\n", " print(f\"AUC: {round(auc_score, 1)}\")\n", @@ -719,12 +722,12 @@ "df = pd.concat([output_train, output_test])\n", "df_plot = df[(df.epoch == trainer.epoch_saved_model) | ((df.epoch == trainer.epoch_saved_model) & (df.phase == \"testing\"))]\n", "\n", - "for idx, set in enumerate([\"training\", \"validation\", \"testing\"]):\n", - " df_plot_phase = df_plot[(df_plot.phase == set)]\n", + "for dataset in [\"training\", \"validation\", \"testing\"]:\n", + " df_plot_phase = df_plot[(df_plot.phase == dataset)]\n", " y_true = df_plot_phase.target\n", - " y_score = np.array(df_plot_phase.output.values.tolist())[:, 1]\n", + " y_score = np.array(df_plot_phase.output.tolist())[:, 1]\n", "\n", - " print(f\"\\nMetrics for {set}:\")\n", + " print(f\"\\nMetrics for {dataset}:\")\n", " fpr_roc, tpr_roc, thr_roc = roc_curve(y_true, y_score)\n", " auc_score = auc(fpr_roc, tpr_roc)\n", " print(f\"AUC: {round(auc_score, 1)}\")\n",