diff --git a/.github/workflows/notebooks.yml b/.github/workflows/notebooks.yml index b1e59ea9d..f4e9d1f1a 100644 --- a/.github/workflows/notebooks.yml +++ b/.github/workflows/notebooks.yml @@ -54,9 +54,11 @@ jobs: - name: Download the data for the tutorials shell: bash -l {0} run: | - wget https://zenodo.org/records/8349335/files/data_raw.zip + wget https://zenodo.org/records/13709906/files/data_raw.zip unzip data_raw.zip -d data_raw mv data_raw tutorials + echo listing files in data_raw: + ls tutorials/data_raw - name: Run tutorial notebooks run: pytest --nbmake tutorials diff --git a/deeprank2/dataset.py b/deeprank2/dataset.py index 18796ac26..144e8d0af 100644 --- a/deeprank2/dataset.py +++ b/deeprank2/dataset.py @@ -112,7 +112,7 @@ def _check_and_inherit_train( # noqa: C901 for key in data["features_transform"].values(): if key["transform"] is None: continue - key["transform"] = eval(key["transform"]) # noqa: S307, PGH001 + key["transform"] = eval(key["transform"]) # noqa: S307 except pickle.UnpicklingError as e: msg = "The path provided to `train_source` is not a valid DeepRank2 pre-trained model." raise ValueError(msg) from e @@ -277,7 +277,7 @@ def _filter_targets(self, grp: h5py.Group) -> bool: for operator_string in [">", "<", "==", "<=", ">=", "!="]: operation = operation.replace(operator_string, f"{target_value}" + operator_string) - if not eval(operation): # noqa: S307, PGH001 + if not eval(operation): # noqa: S307 return False elif target_condition is not None: diff --git a/deeprank2/query.py b/deeprank2/query.py index 89f171a4d..b85222d14 100644 --- a/deeprank2/query.py +++ b/deeprank2/query.py @@ -22,7 +22,7 @@ import deeprank2.features from deeprank2.domain.aminoacidlist import convert_aa_nomenclature from deeprank2.features import components, conservation, contact -from deeprank2.molstruct.residue import Residue, SingleResidueVariant +from deeprank2.molstruct.residue import SingleResidueVariant from deeprank2.utils.buildgraph import get_contact_atoms, get_structure, get_surrounding_residues from deeprank2.utils.graph import Graph from deeprank2.utils.grid import Augmentation, GridSettings, MapMethod @@ -265,12 +265,11 @@ def _build_helper(self) -> Graph: structure = self._load_structure() # find the variant residue and its surroundings - variant_residue: Residue = None for residue in structure.get_chain(self.variant_chain_id).residues: if residue.number == self.variant_residue_number and residue.insertion_code == self.insertion_code: variant_residue = residue break - if variant_residue is None: + else: # if break is not reached msg = f"Residue not found in {self.pdb_path}: {self.variant_chain_id} {self.residue_id}" raise ValueError(msg) self.variant = SingleResidueVariant(variant_residue, self.variant_amino_acid) @@ -354,19 +353,12 @@ def _build_helper(self) -> Graph: raise ValueError(msg) # build the graph - if self.resolution == "atom": - graph = Graph.build_graph( - contact_atoms, - self.get_query_id(), - self.max_edge_length, - ) - elif self.resolution == "residue": - residues_selected = list({atom.residue for atom in contact_atoms}) - graph = Graph.build_graph( - residues_selected, - self.get_query_id(), - self.max_edge_length, - ) + nodes = contact_atoms if self.resolution == "atom" else list({atom.residue for atom in contact_atoms}) + graph = Graph.build_graph( + nodes=nodes, + graph_id=self.get_query_id(), + max_edge_length=self.max_edge_length, + ) graph.center = np.mean([atom.position for atom in contact_atoms], axis=0) structure = 
contact_atoms[0].residue.chain.model @@ -453,7 +445,7 @@ def __iter__(self) -> Iterator[Query]: def __len__(self) -> int: return len(self._queries) - def _process_one_query(self, query: Query) -> None: + def _process_one_query(self, query: Query, log_error_traceback: bool = False) -> None: """Only one process may access an hdf5 file at a time.""" try: output_path = f"{self._prefix}-{os.getpid()}.hdf5" @@ -479,10 +471,12 @@ def _process_one_query(self, query: Query) -> None: except (ValueError, AttributeError, KeyError, TimeoutError) as e: _log.warning( - f"\nGraph/Query with ID {query.get_query_id()} ran into an Exception ({e.__class__.__name__}: {e})," - " and it has not been written to the hdf5 file. More details below:", + f"Graph/Query with ID {query.get_query_id()} ran into an Exception and was not written to the hdf5 file.\n" + f"Exception found: {e.__class__.__name__}: {e}.\n" + "You may proceed with your analysis, but this query will be ignored.\n", ) - _log.exception(e) + if log_error_traceback: + _log.exception(f"----Full error traceback:----\n{e}") def process( self, @@ -493,6 +487,7 @@ def process( grid_settings: GridSettings | None = None, grid_map_method: MapMethod | None = None, grid_augmentation_count: int = 0, + log_error_traceback: bool = False, ) -> list[str]: """Render queries into graphs (and optionally grids). @@ -510,6 +505,8 @@ def process( grid_settings: If valid together with `grid_map_method`, the grid data will be stored as well. Defaults to None. grid_map_method: If valid together with `grid_settings`, the grid data will be stored as well. Defaults to None. grid_augmentation_count: Number of grid data augmentations (must be >= 0). Defaults to 0. + log_error_traceback: If True, logs the full error traceback when a query fails. Otherwise, only the error message is logged. + Defaults to False. Returns: The list of paths of the generated HDF5 files. @@ -536,7 +533,7 @@ def process( self._grid_augmentation_count = grid_augmentation_count _log.info(f"Creating pool function to process {len(self)} queries...") - pool_function = partial(self._process_one_query, log_error_traceback=log_error_traceback) with Pool(self._cpu_count) as pool: _log.info("Starting pooling...\n") pool.map(pool_function, self.queries) @@ -551,6 +548,24 @@ def process( os.remove(output_path) return glob(f"{prefix}.hdf5") + n_processed = 0 + for hdf5file in output_paths: + with h5py.File(hdf5file, "r") as hdf5: + # List of all graphs in hdf5, each graph representing + # an SRV and its surrounding environment + n_processed += len(list(hdf5.keys())) + + if not n_processed: + msg = "No queries have been processed." + raise ValueError(msg) + if n_processed != len(self.queries): + _log.warning( + f"Not all queries have been processed. 
You can proceed with the analysis of {n_processed}/{len(self.queries)} queries.\n" + "Set `log_error_traceback` to True for advanced troubleshooting.", + ) + else: + _log.info(f"{n_processed} queries have been processed.") + return output_paths def _set_feature_modules(self, feature_modules: list[ModuleType, str] | ModuleType | str) -> list[str]: diff --git a/pyproject.toml b/pyproject.toml index 4172c9c1a..4172d7cf7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,8 +53,8 @@ dependencies = [ "python-louvain >= 0.16, < 1.0", "tqdm >= 4.66.4, < 5.0", "freesasa >= 2.1.1, < 3.0", - "biopython >= 1.83, < 2.0" - ] + "biopython >= 1.83, < 2.0", +] [project.optional-dependencies] # development dependency groups @@ -66,7 +66,7 @@ test = [ "pytest-cov >= 4.1.0, < 5.0", "pytest-runner >= 6.0.0, < 7.0", "coveralls >= 3.3.1, < 4.0", - "ruff == 0.5.1" + "ruff == 0.6.3", ] publishing = ["build", "twine", "wheel"] notebooks = ["nbmake"] @@ -88,7 +88,7 @@ include = ["deeprank2*"] [tool.pytest.ini_options] # pytest options: -ra: show summary info for all test outcomes -addopts = "-ra" +addopts = "-ra" [tool.ruff] output-format = "concise" @@ -148,3 +148,4 @@ isort.known-first-party = ["deeprank2"] ] "docs/*" = ["ALL"] "tests/perf/*" = ["T201"] # Use of print statements +"*.ipynb" = ["T201", "E402", "D103"] diff --git a/tests/data/hdf5/_generate_testdata.ipynb b/tests/data/hdf5/_generate_testdata.ipynb index b2fc2677f..762897834 100644 --- a/tests/data/hdf5/_generate_testdata.ipynb +++ b/tests/data/hdf5/_generate_testdata.ipynb @@ -15,11 +15,8 @@ "PATH_TEST = ROOT / \"tests\"\n", "import glob\n", "import os\n", - "import re\n", - "import sys\n", "\n", "import h5py\n", - "import numpy as np\n", "import pandas as pd\n", "\n", "from deeprank2.dataset import save_hdf5_keys\n", @@ -79,7 +76,7 @@ " chain_ids=[chain_id1, chain_id2],\n", " targets=targets,\n", " pssm_paths={chain_id1: pssm_path1, chain_id2: pssm_path2},\n", - " )\n", + " ),\n", " )\n", "\n", " # Generate graphs and save them in hdf5 files\n", @@ -128,8 +125,8 @@ "csv_data = pd.read_csv(csv_file_path)\n", "csv_data.cluster = csv_data.cluster.fillna(-1)\n", "pdb_ids_csv = [pdb_file.split(\"/\")[-1].split(\".\")[0].replace(\"-\", \"_\") for pdb_file in pdb_files]\n", - "clusters = [csv_data[pdb_id == csv_data.ID].cluster.values[0] for pdb_id in pdb_ids_csv]\n", - "bas = [csv_data[pdb_id == csv_data.ID].measurement_value.values[0] for pdb_id in pdb_ids_csv]\n", + "clusters = [csv_data[pdb_id == csv_data.ID].cluster.to_numpy()[0] for pdb_id in pdb_ids_csv]\n", + "bas = [csv_data[pdb_id == csv_data.ID].measurement_value.to_numpy()[0] for pdb_id in pdb_ids_csv]\n", "\n", "queries = QueryCollection()\n", "print(f\"Adding {len(pdb_files)} queries to the query collection ...\")\n", @@ -147,7 +144,7 @@ " \"cluster\": clusters[i],\n", " },\n", " pssm_paths={\"M\": pssm_m[i], \"P\": pssm_p[i]},\n", - " )\n", + " ),\n", " )\n", "print(\"Queries created and ready to be processed.\\n\")\n", "\n", @@ -183,7 +180,7 @@ "test_ids = []\n", "\n", "with h5py.File(hdf5_path, \"r\") as hdf5:\n", - " for key in hdf5.keys():\n", + " for key in hdf5:\n", " feature_value = float(hdf5[key][target][feature][()])\n", " if feature_value in train_clusters:\n", " train_ids.append(key)\n", @@ -192,7 +189,7 @@ " elif feature_value in test_clusters:\n", " test_ids.append(key)\n", "\n", - " if feature_value in clusters.keys():\n", + " if feature_value in clusters:\n", " clusters[int(feature_value)] += 1\n", " else:\n", " clusters[int(feature_value)] = 1\n", @@ -278,8 +275,12 @@ " 
targets = compute_ppi_scores(pdb_path, ref_path)\n", " queries.add(\n", " ProteinProteinInterfaceQuery(\n", - " pdb_path=pdb_path, resolution=\"atom\", chain_ids=[chain_id1, chain_id2], targets=targets, pssm_paths={chain_id1: pssm_path1, chain_id2: pssm_path2}\n", - " )\n", + " pdb_path=pdb_path,\n", + " resolution=\"atom\",\n", + " chain_ids=[chain_id1, chain_id2],\n", + " targets=targets,\n", + " pssm_paths={chain_id1: pssm_path1, chain_id2: pssm_path2},\n", + " ),\n", " )\n", "\n", "# Generate graphs and save them in hdf5 files\n", @@ -303,7 +304,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.14" }, "orig_nbformat": 4, "vscode": { diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 932e7d3c9..3d2424258 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1201,7 +1201,7 @@ def test_inherit_info_pretrained_model_graphdataset(self) -> None: for key in data["features_transform"].values(): if key["transform"] is None: continue - key["transform"] = eval(key["transform"]) # noqa: S307, PGH001 + key["transform"] = eval(key["transform"]) # noqa: S307 dataset_test_vars = vars(dataset_test) for param in dataset_test.inherited_params: diff --git a/tests/utils/test_graph.py b/tests/utils/test_graph.py index 39bd2c9c8..5ca7d4273 100644 --- a/tests/utils/test_graph.py +++ b/tests/utils/test_graph.py @@ -27,7 +27,7 @@ target_value = 1.0 -@pytest.fixture() +@pytest.fixture def graph() -> Graph: """Build a simple graph of two nodes and one edge in between them.""" # load the structure diff --git a/tutorials/data_generation_ppi.ipynb b/tutorials/data_generation_ppi.ipynb index 2d1d9650a..98cd77fa3 100644 --- a/tutorials/data_generation_ppi.ipynb +++ b/tutorials/data_generation_ppi.ipynb @@ -29,7 +29,7 @@ "source": [ "### Input Data\n", "\n", - "The example data used in this tutorial are available on Zenodo at [this record address](https://zenodo.org/record/8349335). To download the raw data used in this tutorial, please visit the link and download `data_raw.zip`. Unzip it, and save the `data_raw/` folder in the same directory as this notebook. The name and the location of the folder are optional but recommended, as they are the name and the location we will use to refer to the folder throughout the tutorial.\n", + "The example data used in this tutorial are available on Zenodo at [this record address](https://zenodo.org/record/7997585). To download the raw data used in this tutorial, please visit the link and download `data_raw.zip`. Unzip it, and save the `data_raw/` folder in the same directory as this notebook. 
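A minimal sketch of fetching and unpacking the tutorial data programmatically with the Python standard library, assuming the direct-download URL follows the `records/<id>/files/data_raw.zip` pattern used by the CI workflow above (adjust the record ID to the record linked in this tutorial):

```python
# Hypothetical helper: download and unpack data_raw.zip next to this notebook.
# The record ID below is an assumption taken from the CI workflow's wget call.
import urllib.request
import zipfile

record_id = "13709906"
url = f"https://zenodo.org/records/{record_id}/files/data_raw.zip"

urllib.request.urlretrieve(url, "data_raw.zip")
with zipfile.ZipFile("data_raw.zip") as archive:
    archive.extractall("data_raw")  # creates the data_raw/ folder used below
```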
The name and the location of the folder are optional but recommended, as they are the name and the location we will use to refer to the folder throughout the tutorial.\n", "\n", "Note that the dataset contains only 100 data points, which is not enough to develop an impactful predictive model, and the scope of its use is indeed only demonstrative and informative for the users.\n" ] @@ -64,17 +64,20 @@ "metadata": {}, "outputs": [], "source": [ - "import os\n", - "import pandas as pd\n", + "import contextlib\n", "import glob\n", + "import os\n", + "from pathlib import Path\n", + "\n", "import h5py\n", "import matplotlib.image as img\n", "import matplotlib.pyplot as plt\n", - "from deeprank2.query import QueryCollection\n", - "from deeprank2.query import ProteinProteinInterfaceQuery, ProteinProteinInterfaceQuery\n", + "import pandas as pd\n", + "\n", + "from deeprank2.dataset import GraphDataset\n", "from deeprank2.features import components, contact\n", - "from deeprank2.utils.grid import GridSettings, MapMethod\n", - "from deeprank2.dataset import GraphDataset" + "from deeprank2.query import ProteinProteinInterfaceQuery, QueryCollection\n", + "from deeprank2.utils.grid import GridSettings, MapMethod" ] }, { @@ -101,8 +104,15 @@ "source": [ "data_path = os.path.join(\"data_raw\", \"ppi\")\n", "processed_data_path = os.path.join(\"data_processed\", \"ppi\")\n", - "os.makedirs(os.path.join(processed_data_path, \"residue\"))\n", - "os.makedirs(os.path.join(processed_data_path, \"atomic\"))\n", + "residue_data_path = os.path.join(processed_data_path, \"residue\")\n", + "atomic_data_path = os.path.join(processed_data_path, \"atomic\")\n", + "\n", + "for output_path in [residue_data_path, atomic_data_path]:\n", + " os.makedirs(output_path, exist_ok=True)\n", + " if any(Path(output_path).iterdir()):\n", + " msg = f\"Please store any required data from `./{output_path}` and delete the folder.\\nThen re-run this cell to continue.\"\n", + " raise FileExistsError(msg)\n", + "\n", "# Flag limit_data as True if you are running on a machine with limited memory (e.g., Docker container)\n", "limit_data = True" ] @@ -131,14 +141,16 @@ "metadata": {}, "outputs": [], "source": [ - "def get_pdb_files_and_target_data(data_path):\n", + "def get_pdb_files_and_target_data(data_path: str) -> tuple[list[str], list[float]]:\n", " csv_data = pd.read_csv(os.path.join(data_path, \"BA_values.csv\"))\n", " pdb_files = glob.glob(os.path.join(data_path, \"pdb\", \"*.pdb\"))\n", " pdb_files.sort()\n", " pdb_ids_csv = [pdb_file.split(\"/\")[-1].split(\".\")[0] for pdb_file in pdb_files]\n", - " csv_data_indexed = csv_data.set_index(\"ID\")\n", + " with contextlib.suppress(KeyError):\n", + " csv_data_indexed = csv_data.set_index(\"ID\")\n", " csv_data_indexed = csv_data_indexed.loc[pdb_ids_csv]\n", - " bas = csv_data_indexed.measurement_value.values.tolist()\n", + " bas = csv_data_indexed.measurement_value.tolist()\n", + "\n", " return pdb_files, bas\n", "\n", "\n", @@ -192,9 +204,9 @@ "\n", "influence_radius = 8 # max distance in Å between two interacting residues/atoms of two proteins\n", "max_edge_length = 8\n", + "binary_target_value = 500\n", "\n", "print(f\"Adding {len(pdb_files)} queries to the query collection ...\")\n", - "count = 0\n", "for i in range(len(pdb_files)):\n", " queries.add(\n", " ProteinProteinInterfaceQuery(\n", @@ -204,16 +216,15 @@ " influence_radius=influence_radius,\n", " max_edge_length=max_edge_length,\n", " targets={\n", - " \"binary\": int(float(bas[i]) <= 500), # binary target value\n", + " 
\"binary\": int(float(bas[i]) <= binary_target_value),\n", " \"BA\": bas[i], # continuous target value\n", " },\n", - " )\n", + " ),\n", " )\n", - " count += 1\n", - " if count % 20 == 0:\n", - " print(f\"{count} queries added to the collection.\")\n", + " if i + 1 % 20 == 0:\n", + " print(f\"{i+1} queries added to the collection.\")\n", "\n", - "print(\"Queries ready to be processed.\\n\")" + "print(f\"{i+1} queries ready to be processed.\\n\")" ] }, { @@ -340,8 +351,8 @@ "source": [ "processed_data = glob.glob(os.path.join(processed_data_path, \"residue\", \"*.hdf5\"))\n", "dataset = GraphDataset(processed_data, target=\"binary\")\n", - "df = dataset.hdf5_to_pandas()\n", - "df.head()" + "dataset_df = dataset.hdf5_to_pandas()\n", + "dataset_df.head()" ] }, { @@ -358,7 +369,7 @@ "metadata": {}, "outputs": [], "source": [ - "fname = os.path.join(processed_data_path, \"residue\", \"_\".join([\"res_mass\", \"distance\", \"electrostatic\"]))\n", + "fname = os.path.join(processed_data_path, \"residue\", \"res_mass_distance_electrostatic\")\n", "dataset.save_hist(features=[\"res_mass\", \"distance\", \"electrostatic\"], fname=fname)\n", "\n", "im = img.imread(fname + \".png\")\n", @@ -429,9 +440,9 @@ "\n", "influence_radius = 5 # max distance in Å between two interacting residues/atoms of two proteins\n", "max_edge_length = 5\n", + "binary_target_value = 500\n", "\n", "print(f\"Adding {len(pdb_files)} queries to the query collection ...\")\n", - "count = 0\n", "for i in range(len(pdb_files)):\n", " queries.add(\n", " ProteinProteinInterfaceQuery(\n", @@ -441,16 +452,15 @@ " influence_radius=influence_radius,\n", " max_edge_length=max_edge_length,\n", " targets={\n", - " \"binary\": int(float(bas[i]) <= 500), # binary target value\n", + " \"binary\": int(float(bas[i]) <= binary_target_value),\n", " \"BA\": bas[i], # continuous target value\n", " },\n", - " )\n", + " ),\n", " )\n", - " count += 1\n", - " if count % 20 == 0:\n", - " print(f\"{count} queries added to the collection.\")\n", + " if i + 1 % 20 == 0:\n", + " print(f\"{i+1} queries added to the collection.\")\n", "\n", - "print(\"Queries ready to be processed.\\n\")" + "print(f\"{i+1} queries ready to be processed.\\n\")" ] }, { @@ -495,8 +505,8 @@ "source": [ "processed_data = glob.glob(os.path.join(processed_data_path, \"atomic\", \"*.hdf5\"))\n", "dataset = GraphDataset(processed_data, target=\"binary\")\n", - "df = dataset.hdf5_to_pandas()\n", - "df.head()" + "dataset_df = dataset.hdf5_to_pandas()\n", + "dataset_df.head()" ] }, { @@ -540,7 +550,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.14" }, "orig_nbformat": 4 }, diff --git a/tutorials/data_generation_srv.ipynb b/tutorials/data_generation_srv.ipynb index 1a68f31ad..7dcde58ab 100644 --- a/tutorials/data_generation_srv.ipynb +++ b/tutorials/data_generation_srv.ipynb @@ -29,7 +29,7 @@ "source": [ "### Input Data\n", "\n", - "The example data used in this tutorial are available on Zenodo at [this record address](https://zenodo.org/record/8349335). To download the raw data used in this tutorial, please visit the link and download `data_raw.zip`. Unzip it, and save the `data_raw/` folder in the same directory as this notebook. 
The name and the location of the folder are optional but recommended, as they are the name and the location we will use to refer to the folder throughout the tutorial.\n", + "The example data used in this tutorial are available on Zenodo at [this record address](https://zenodo.org/record/7997585). To download the raw data used in this tutorial, please visit the link and download `data_raw.zip`. Unzip it, and save the `data_raw/` folder in the same directory as this notebook. The name and the location of the folder are optional but recommended, as they are the name and the location we will use to refer to the folder throughout the tutorial.\n", "\n", "Note that the dataset contains only 96 data points, which is not enough to develop an impactful predictive model, and the scope of its use is indeed only demonstrative and informative for the users.\n" ] @@ -64,18 +64,21 @@ "metadata": {}, "outputs": [], "source": [ - "import os\n", - "import pandas as pd\n", + "import contextlib\n", "import glob\n", + "import os\n", + "from pathlib import Path\n", + "\n", "import h5py\n", "import matplotlib.image as img\n", "import matplotlib.pyplot as plt\n", - "from deeprank2.query import QueryCollection\n", - "from deeprank2.query import SingleResidueVariantQuery, SingleResidueVariantQuery\n", + "import pandas as pd\n", + "\n", + "from deeprank2.dataset import GraphDataset\n", "from deeprank2.domain.aminoacidlist import amino_acids_by_code\n", "from deeprank2.features import components, contact\n", - "from deeprank2.utils.grid import GridSettings, MapMethod\n", - "from deeprank2.dataset import GraphDataset" + "from deeprank2.query import QueryCollection, SingleResidueVariantQuery\n", + "from deeprank2.utils.grid import GridSettings, MapMethod" ] }, { @@ -102,10 +105,17 @@ "source": [ "data_path = os.path.join(\"data_raw\", \"srv\")\n", "processed_data_path = os.path.join(\"data_processed\", \"srv\")\n", - "os.makedirs(os.path.join(processed_data_path, \"residue\"))\n", - "os.makedirs(os.path.join(processed_data_path, \"atomic\"))\n", + "residue_data_path = os.path.join(processed_data_path, \"residue\")\n", + "atomic_data_path = os.path.join(processed_data_path, \"atomic\")\n", + "\n", + "for output_path in [residue_data_path, atomic_data_path]:\n", + " os.makedirs(output_path, exist_ok=True)\n", + " if any(Path(output_path).iterdir()):\n", + " msg = f\"Please store any required data from `./{output_path}` and delete the folder.\\nThen re-run this cell to continue.\"\n", + " raise FileExistsError(msg)\n", + "\n", "# Flag limit_data as True if you are running on a machine with limited memory (e.g., Docker container)\n", - "limit_data = True" + "limit_data = False" ] }, { @@ -132,19 +142,21 @@ "metadata": {}, "outputs": [], "source": [ - "def get_pdb_files_and_target_data(data_path):\n", - " csv_data = pd.read_csv(os.path.join(data_path, \"srv_target_values.csv\"))\n", + "def get_pdb_files_and_target_data(data_path: str) -> tuple[list[str], list[int], list[str], list[str], list[float]]:\n", + " csv_data = pd.read_csv(os.path.join(data_path, \"srv_target_values_curated.csv\"))\n", " pdb_files = glob.glob(os.path.join(data_path, \"pdb\", \"*.ent\"))\n", " pdb_files.sort()\n", " pdb_file_names = [os.path.basename(pdb_file) for pdb_file in pdb_files]\n", " csv_data_indexed = csv_data.set_index(\"pdb_file\")\n", - " csv_data_indexed = csv_data_indexed.loc[pdb_file_names]\n", - " res_numbers = csv_data_indexed.res_number.values.tolist()\n", - " res_wildtypes = csv_data_indexed.res_wildtype.values.tolist()\n", - " 
res_variants = csv_data_indexed.res_variant.values.tolist()\n", - " targets = csv_data_indexed.target.values.tolist()\n", - " pdb_names = csv_data_indexed.index.values.tolist()\n", + " with contextlib.suppress(KeyError):\n", + " csv_data_indexed = csv_data_indexed.loc[pdb_file_names]\n", + " res_numbers = csv_data_indexed.res_number.tolist()\n", + " res_wildtypes = csv_data_indexed.res_wildtype.tolist()\n", + " res_variants = csv_data_indexed.res_variant.tolist()\n", + " targets = csv_data_indexed.target.tolist()\n", + " pdb_names = csv_data_indexed.index.tolist()\n", " pdb_files = [data_path + \"/pdb/\" + pdb_name for pdb_name in pdb_names]\n", + "\n", " return pdb_files, res_numbers, res_wildtypes, res_variants, targets\n", "\n", "\n", @@ -204,7 +216,6 @@ "max_edge_length = 4.5 # ??\n", "\n", "print(f\"Adding {len(pdb_files)} queries to the query collection ...\")\n", - "count = 0\n", "for i in range(len(pdb_files)):\n", " queries.add(\n", " SingleResidueVariantQuery(\n", @@ -218,13 +229,12 @@ " targets={\"binary\": targets[i]},\n", " influence_radius=influence_radius,\n", " max_edge_length=max_edge_length,\n", - " )\n", + " ),\n", " )\n", - " count += 1\n", - " if count % 20 == 0:\n", - " print(f\"{count} queries added to the collection.\")\n", + " if i + 1 % 20 == 0:\n", + " print(f\"{i+1} queries added to the collection.\")\n", "\n", - "print(f\"Queries ready to be processed.\\n\")" + "print(f\"{i+1} queries ready to be processed.\\n\")" ] }, { @@ -358,8 +368,8 @@ "source": [ "processed_data = glob.glob(os.path.join(processed_data_path, \"residue\", \"*.hdf5\"))\n", "dataset = GraphDataset(processed_data, target=\"binary\")\n", - "df = dataset.hdf5_to_pandas()\n", - "df.head()" + "dataset_df = dataset.hdf5_to_pandas()\n", + "dataset_df.head()" ] }, { @@ -376,7 +386,8 @@ "metadata": {}, "outputs": [], "source": [ - "fname = os.path.join(processed_data_path, \"residue\", \"_\".join([\"res_mass\", \"distance\", \"electrostatic\"]))\n", + "fname = os.path.join(processed_data_path, \"residue\", \"res_mass_distance_electrostatic\")\n", + "\n", "dataset.save_hist(features=[\"res_mass\", \"distance\", \"electrostatic\"], fname=fname)\n", "\n", "im = img.imread(fname + \".png\")\n", @@ -450,7 +461,6 @@ "max_edge_length = 4.5 # ??\n", "\n", "print(f\"Adding {len(pdb_files)} queries to the query collection ...\")\n", - "count = 0\n", "for i in range(len(pdb_files)):\n", " queries.add(\n", " SingleResidueVariantQuery(\n", @@ -464,13 +474,12 @@ " targets={\"binary\": targets[i]},\n", " influence_radius=influence_radius,\n", " max_edge_length=max_edge_length,\n", - " )\n", + " ),\n", " )\n", - " count += 1\n", - " if count % 20 == 0:\n", - " print(f\"{count} queries added to the collection.\")\n", + " if i + 1 % 20 == 0:\n", + " print(f\"{i+1} queries added to the collection.\")\n", "\n", - "print(\"Queries ready to be processed.\\n\")" + "print(f\"{i+1} queries ready to be processed.\\n\")" ] }, { @@ -515,8 +524,8 @@ "source": [ "processed_data = glob.glob(os.path.join(processed_data_path, \"atomic\", \"*.hdf5\"))\n", "dataset = GraphDataset(processed_data, target=\"binary\")\n", - "df = dataset.hdf5_to_pandas()\n", - "df.head()" + "dataset_df = dataset.hdf5_to_pandas()\n", + "dataset_df.head()" ] }, { @@ -565,7 +574,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.14" }, "orig_nbformat": 4 }, diff --git a/tutorials/training.ipynb b/tutorials/training.ipynb index 499c3f8e6..b3445cc41 100644 --- a/tutorials/training.ipynb 
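With the `deeprank2/query.py` changes above, a failing query is skipped with a warning and its traceback is only logged on request; a minimal usage sketch of the new flag, assuming the pre-existing `prefix` and `feature_modules` arguments of `process()` and placeholder file paths:

```python
# Sketch of QueryCollection.process() with the new log_error_traceback flag;
# prefix/feature_modules are assumed from the existing signature, and the
# PDB path is a placeholder.
from deeprank2.features import components, contact
from deeprank2.query import ProteinProteinInterfaceQuery, QueryCollection

queries = QueryCollection()
queries.add(
    ProteinProteinInterfaceQuery(
        pdb_path="data_raw/ppi/pdb/BA-100600.pdb",  # placeholder
        resolution="residue",
        chain_ids=["M", "P"],
        targets={"binary": 0},
    ),
)

# Failed queries are skipped with a warning; set log_error_traceback=True to
# also log each failure's full traceback for troubleshooting.
hdf5_paths = queries.process(
    prefix="data_processed/ppi/residue/proc",
    feature_modules=[components, contact],
    log_error_traceback=True,
)
print(hdf5_paths)
```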
+++ b/tutorials/training.ipynb @@ -1,770 +1,773 @@ { - "cells": [ - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Training Neural Networks\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Introduction\n", - "\n", - "\n", - "\n", - "This tutorial will demonstrate the use of DeepRank2 for training graph neural networks (GNNs) and convolutional neural networks (CNNs) using protein-protein interface (PPI) or single-residue variant (SRV) data for classification and regression predictive tasks.\n", - "\n", - "This tutorial assumes that the PPI data of interest have already been generated and saved as [HDF5 files](https://en.wikipedia.org/wiki/Hierarchical_Data_Format), with the data structure that DeepRank2 expects. This data can be generated using the [data_generation_ppi.ipynb](https://github.com/DeepRank/deeprank2/blob/main/tutorials/data_generation_ppi.ipynb) tutorial or downloaded from Zenodo at [this record address](https://zenodo.org/record/8349335). For more details on the data structure, please refer to the other tutorial, which also contains a detailed description of how the data is generated from PDB files.\n", - "\n", - "This tutorial assumes also a basic knowledge of the [PyTorch](https://pytorch.org/) framework, on top of which the machine learning pipeline of DeepRank2 has been developed, for which many online tutorials exist.\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Input data\n", - "\n", - "If you have previously run `data_generation_ppi.ipynb` or `data_generation_srv.ipynb` notebook, then their output can be directly used as input for this tutorial.\n", - "\n", - "Alternatively, preprocessed HDF5 files can be downloaded directly from Zenodo at [this record address](https://zenodo.org/record/8349335). To download the data used in this tutorial, please visit the link and download `data_processed.zip`. Unzip it, and save the `data_processed/` folder in the same directory as this notebook. 
The name and the location of the folder are optional but recommended, as they are the name and the location we will use to refer to the folder throughout the tutorial.\n", - "\n", - "Note that the datasets contain only ~100 data points each, which is not enough to develop an impactful predictive model, and the scope of their use is indeed only demonstrative and informative for the users.\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Utilities\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Libraries\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The libraries needed for this tutorial:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import logging\n", - "import glob\n", - "import os\n", - "import h5py\n", - "import pandas as pd\n", - "from sklearn.model_selection import train_test_split\n", - "from sklearn.metrics import roc_curve, auc, precision_score, recall_score, accuracy_score, f1_score\n", - "import plotly.express as px\n", - "import torch\n", - "import numpy as np\n", - "\n", - "np.seterr(divide=\"ignore\")\n", - "np.seterr(invalid=\"ignore\")\n", - "import pandas as pd\n", - "\n", - "logging.basicConfig(level=logging.INFO)\n", - "from deeprank2.dataset import GraphDataset, GridDataset\n", - "from deeprank2.trainer import Trainer\n", - "from deeprank2.neuralnets.gnn.vanilla_gnn import VanillaNetwork\n", - "from deeprank2.neuralnets.cnn.model3d import CnnClassification\n", - "from deeprank2.utils.exporters import HDF5OutputExporter\n", - "import warnings\n", - "\n", - "warnings.filterwarnings(\"ignore\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Paths and sets\n", - "\n", - "The paths for reading the processed data:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "data_type = \"ppi\"\n", - "level = \"residue\"\n", - "processed_data_path = os.path.join(\"data_processed\", data_type, level)\n", - "input_data_path = glob.glob(os.path.join(processed_data_path, \"*.hdf5\"))\n", - "output_path = os.path.join(\"data_processed\", data_type, level) # for saving predictions results" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The `data_type` can be either \"ppi\" or \"srv\", depending on which application the user is most interested in. 
The `level` can be either \"residue\" or \"atomic\", and refers to the structural resolution, where each node either represents a single residue or a single atom from the molecular structure.\n", - "\n", - "In this tutorial, we will use PPI residue-level data by default, but the same code can be applied to SRV or/and atomic-level data with no changes, apart from setting `data_type` and `level` parameters in the cell above.\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "A Pandas DataFrame containing data points' IDs and the binary target values can be defined:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "df_dict = {}\n", - "df_dict[\"entry\"] = []\n", - "df_dict[\"target\"] = []\n", - "for fname in input_data_path:\n", - " with h5py.File(fname, \"r\") as hdf5:\n", - " for mol in hdf5.keys():\n", - " target_value = float(hdf5[mol][\"target_values\"][\"binary\"][()])\n", - " df_dict[\"entry\"].append(mol)\n", - " df_dict[\"target\"].append(target_value)\n", - "\n", - "df = pd.DataFrame(data=df_dict)\n", - "df.head()" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "As explained in [data_generation_ppi.ipynb](https://github.com/DeepRank/deeprank2/blob/main/tutorials/data_generation_ppi.ipynb), for each data point there are two targets: \"BA\" and \"binary\". The first represents the strength of the interaction between two molecules that bind reversibly (interact) in nM, while the second represents its binary mapping, being 0 (BA > 500 nM) a not-binding complex and 1 (BA <= 500 nM) binding one.\n", - "\n", - "For SRVs, each data point has a single target, \"binary\", which is 0 if the SRV is considered benign, and 1 if it is pathogenic, as explained in [data_generation_srv.ipynb](https://github.com/DeepRank/deeprank-core/blob/main/tutorials/data_generation_srv.ipynb).\n", - "\n", - "The pandas DataFrame `df` is used only to split data points into training, validation and test sets according to the \"binary\" target - using target stratification to keep the proportion of 0s and 1s constant among the different sets. 
Training and validation sets will be used during the training for updating the network weights, while the test set will be held out as an independent test and will be used later for the model evaluation.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "df_train, df_test = train_test_split(df, test_size=0.1, stratify=df.target, random_state=42)\n", - "df_train, df_valid = train_test_split(df_train, test_size=0.2, stratify=df_train.target, random_state=42)\n", - "\n", - "print(f\"Data statistics:\\n\")\n", - "print(f\"Total samples: {len(df)}\\n\")\n", - "print(f\"Training set: {len(df_train)} samples, {round(100*len(df_train)/len(df))}%\")\n", - "print(f\"\\t- Class 0: {len(df_train[df_train.target == 0])} samples, {round(100*len(df_train[df_train.target == 0])/len(df_train))}%\")\n", - "print(f\"\\t- Class 1: {len(df_train[df_train.target == 1])} samples, {round(100*len(df_train[df_train.target == 1])/len(df_train))}%\")\n", - "print(f\"Validation set: {len(df_valid)} samples, {round(100*len(df_valid)/len(df))}%\")\n", - "print(f\"\\t- Class 0: {len(df_valid[df_valid.target == 0])} samples, {round(100*len(df_valid[df_valid.target == 0])/len(df_valid))}%\")\n", - "print(f\"\\t- Class 1: {len(df_valid[df_valid.target == 1])} samples, {round(100*len(df_valid[df_valid.target == 1])/len(df_valid))}%\")\n", - "print(f\"Testing set: {len(df_test)} samples, {round(100*len(df_test)/len(df))}%\")\n", - "print(f\"\\t- Class 0: {len(df_test[df_test.target == 0])} samples, {round(100*len(df_test[df_test.target == 0])/len(df_test))}%\")\n", - "print(f\"\\t- Class 1: {len(df_test[df_test.target == 1])} samples, {round(100*len(df_test[df_test.target == 1])/len(df_test))}%\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Classification example\n", - "\n", - "A GNN and a CNN can be trained for a classification predictive task, which consists in predicting the \"binary\" target values.\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### GNN\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### GraphDataset\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "For training GNNs the user can create `GraphDataset` instances. This class inherits from `DeeprankDataset` class, which in turns inherits from `Dataset` [PyTorch geometric class](https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/data/dataset.html), a base class for creating graph datasets.\n", - "\n", - "A few notes about `GraphDataset` parameters:\n", - "\n", - "- By default, all features contained in the HDF5 files are used, but the user can specify `node_features` and `edge_features` in `GraphDataset` if not all of them are needed. See the [docs](https://deeprank2.readthedocs.io/en/latest/features.html) for more details about all the possible pre-implemented features.\n", - "- For regression, `task` should be set to `regress` and the `target` to `BA`, which is a continuous variable and therefore suitable for regression tasks.\n", - "- For the `GraphDataset` class it is possible to define a dictionary to indicate which transformations to apply to the features, being the transformations lambda functions and/or standardization.\n", - " - If the `standardize` key is `True`, standardization is applied after transformation. 
Standardization consists in applying the following formula on each feature's value: ${x' = \\frac{x - \\mu}{\\sigma}}$, being ${\\mu}$ the mean and ${\\sigma}$ the standard deviation. Standardization is a scaling method where the values are centered around mean with a unit standard deviation.\n", - " - The transformation to apply can be speficied as a lambda function as a value of the key `transform`, which defaults to `None`.\n", - " - Since in the provided example standardization is applied, the training features' means and standard deviations need to be used for scaling validation and test sets. For doing so, `train_source` parameter is used. When `train_source` parameter is set, it will be used to scale the validation/testing sets. You need to pass `features_transform` to the training dataset only, since in other cases it will be ignored and only the one of `train_source` will be considered.\n", - " - Note that transformations have not currently been implemented for the `GridDataset` class.\n", - " - In the example below a logarithmic transformation and then the standardization are applied to all the features. It is also possible to use specific features as keys for indicating that transformation and/or standardization need to be apply to few features only.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "target = \"binary\"\n", - "task = \"classif\"\n", - "node_features = [\"res_type\"]\n", - "edge_features = [\"distance\"]\n", - "features_transform = {\"all\": {\"transform\": lambda x: np.cbrt(x), \"standardize\": True}}\n", - "\n", - "print(\"Loading training data...\")\n", - "dataset_train = GraphDataset(\n", - " hdf5_path=input_data_path,\n", - " subset=list(df_train.entry), # selects only data points with ids in df_train.entry\n", - " node_features=node_features,\n", - " edge_features=edge_features,\n", - " features_transform=features_transform,\n", - " target=target,\n", - " task=task,\n", - ")\n", - "print(\"\\nLoading validation data...\")\n", - "dataset_val = GraphDataset(\n", - " hdf5_path=input_data_path,\n", - " subset=list(df_valid.entry), # selects only data points with ids in df_valid.entry\n", - " train_source=dataset_train,\n", - ")\n", - "print(\"\\nLoading test data...\")\n", - "dataset_test = GraphDataset(\n", - " hdf5_path=input_data_path,\n", - " subset=list(df_test.entry), # selects only data points with ids in df_test.entry\n", - " train_source=dataset_train,\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Trainer\n", - "\n", - "The class `Trainer` implements training, validation and testing of PyTorch-based neural networks.\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "A few notes about `Trainer` parameters:\n", - "\n", - "- `neuralnet` can be any neural network class that inherits from `torch.nn.Module`, and it shouldn't be specific to regression or classification in terms of output shape. The `Trainer` class takes care of formatting the output shape according to the task. This tutorial uses a simple network, `VanillaNetwork` (implemented in `deeprank2.neuralnets.gnn.vanilla_gnn`). 
All GNN architectures already implemented in the pakcage can be found [here](https://github.com/DeepRank/deeprank-core/tree/main/deeprank2/neuralnets/gnn) and can be used for training or as a basis for implementing new ones.\n", - "- `class_weights` is used for classification tasks only and assigns class weights based on the training dataset content to account for any potential inbalance between the classes. In this case the dataset is balanced (50% 0 and 50% 1), so it is not necessary to use it. It defaults to False.\n", - "- `cuda` and `ngpu` are used for indicating whether to use CUDA and how many GPUs. By default, CUDA is not used and `ngpu` is 0.\n", - "- The user can specify a deeprank2 exporter or a custom one in `output_exporters` parameter, together with the path where to save the results. Exporters are used for storing predictions information collected later on during training and testing. Later the results saved by `HDF5OutputExporter` will be read and evaluated.\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "##### Training\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "trainer = Trainer(\n", - " neuralnet=VanillaNetwork,\n", - " dataset_train=dataset_train,\n", - " dataset_val=dataset_val,\n", - " dataset_test=dataset_test,\n", - " output_exporters=[HDF5OutputExporter(os.path.join(output_path, f\"gnn_{task}\"))],\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The default optimizer is `torch.optim.Adam`. It is possible to specify optimizer's parameters or to use another PyTorch optimizer object:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "optimizer = torch.optim.SGD\n", - "lr = 1e-3\n", - "weight_decay = 0.001\n", - "\n", - "trainer.configure_optimizers(optimizer, lr, weight_decay)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The default loss function for classification is `torch.nn.CrossEntropyLoss` and for regression it is `torch.nn.MSELoss`. It is also possible to set some other PyTorch loss functions by using `Trainer.set_lossfunction` method, although not all are currently implemented.\n", - "\n", - "Then the model can be trained using the `train()` method of the `Trainer` class.\n", - "\n", - "A few notes about `train()` method parameters:\n", - "\n", - "- `earlystop_patience`, `earlystop_maxgap` and `min_epoch` are used for controlling early stopping logic. `earlystop_patience` indicates the number of epochs after which the training ends if the validation loss does not improve. `earlystop_maxgap` indicated the maximum difference allowed between validation and training loss, and `min_epoch` is the minimum number of epochs to be reached before evaluating `maxgap`.\n", - "- If `validate` is set to `True`, validation is performed on an independent dataset, which has been called `dataset_val` few cells above. If set to `False`, validation is performed on the training dataset itself (not recommended).\n", - "- `num_workers` can be set for indicating how many subprocesses to use for data loading. 
The default is 0 and it means that the data will be loaded in the main process.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "epochs = 20\n", - "batch_size = 8\n", - "earlystop_patience = 5\n", - "earlystop_maxgap = 0.1\n", - "min_epoch = 10\n", - "\n", - "trainer.train(\n", - " nepoch=epochs,\n", - " batch_size=batch_size,\n", - " earlystop_patience=earlystop_patience,\n", - " earlystop_maxgap=earlystop_maxgap,\n", - " min_epoch=min_epoch,\n", - " validate=True,\n", - " filename=os.path.join(output_path, f\"gnn_{task}\", \"model.pth.tar\"),\n", - ")\n", - "\n", - "epoch = trainer.epoch_saved_model\n", - "print(f\"Model saved at epoch {epoch}\")\n", - "pytorch_total_params = sum(p.numel() for p in trainer.model.parameters())\n", - "print(f\"Total # of parameters: {pytorch_total_params}\")\n", - "pytorch_trainable_params = sum(p.numel() for p in trainer.model.parameters() if p.requires_grad)\n", - "print(f\"Total # of trainable parameters: {pytorch_trainable_params}\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "##### Testing\n", - "\n", - "And the trained model can be tested on `dataset_test`:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "trainer.test()" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "##### Results visualization\n", - "\n", - "Finally, the results saved by `HDF5OutputExporter` can be inspected, which can be found in the `data/ppi/gnn_classif/` folder in the form of an HDF5 file, `output_exporter.hdf5`. Note that the folder contains the saved pre-trained model as well.\n", - "\n", - "`output_exporter.hdf5` contains [HDF5 Groups](https://docs.h5py.org/en/stable/high/group.html) which refer to each phase, e.g. training and testing if both are run, only one of them otherwise. Training phase includes validation results as well. This HDF5 file can be read as a Pandas Dataframe:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "output_train = pd.read_hdf(os.path.join(output_path, f\"gnn_{task}\", \"output_exporter.hdf5\"), key=\"training\")\n", - "output_test = pd.read_hdf(os.path.join(output_path, f\"gnn_{task}\", \"output_exporter.hdf5\"), key=\"testing\")\n", - "output_train.head()" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The dataframes contain `phase`, `epoch`, `entry`, `output`, `target`, and `loss` columns, and can be easily used to visualize the results.\n", - "\n", - "For classification tasks, the `output` column contains a list of probabilities that each class occurs, and each list sums to 1 (for more details, please see documentation on the [softmax function](https://pytorch.org/docs/stable/generated/torch.nn.functional.softmax.html)). Note that the order of the classes in the list depends on the `classes` attribute of the DeeprankDataset instances. 
For classification tasks, if `classes` is not specified (as in this example case), it is defaulted to [0, 1].\n", - "\n", - "The loss across the epochs can be plotted for the training and the validation sets:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fig = px.line(output_train, x=\"epoch\", y=\"loss\", color=\"phase\", markers=True)\n", - "\n", - "fig.add_vline(x=trainer.epoch_saved_model, line_width=3, line_dash=\"dash\", line_color=\"green\")\n", - "\n", - "fig.update_layout(\n", - " xaxis_title=\"Epoch #\",\n", - " yaxis_title=\"Loss\",\n", - " title=\"Loss vs epochs - GNN training\",\n", - " width=700,\n", - " height=400,\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "And now a few metrics of interest for classification tasks can be printed out: the [area under the ROC curve](https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve) (AUC), and for a threshold of 0.5 the [precision, recall, accuracy and f1 score](https://en.wikipedia.org/wiki/Precision_and_recall#Definition).\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [ - "skip-execution" - ] - }, - "outputs": [], - "source": [ - "threshold = 0.5\n", - "df = pd.concat([output_train, output_test])\n", - "df_plot = df[(df.epoch == trainer.epoch_saved_model) | ((df.epoch == trainer.epoch_saved_model) & (df.phase == \"testing\"))]\n", - "\n", - "for idx, set in enumerate([\"training\", \"validation\", \"testing\"]):\n", - " df_plot_phase = df_plot[(df_plot.phase == set)]\n", - " y_true = df_plot_phase.target\n", - " y_score = np.array(df_plot_phase.output.values.tolist())[:, 1]\n", - "\n", - " print(f\"\\nMetrics for {set}:\")\n", - " fpr_roc, tpr_roc, thr_roc = roc_curve(y_true, y_score)\n", - " auc_score = auc(fpr_roc, tpr_roc)\n", - " print(f\"AUC: {round(auc_score, 1)}\")\n", - " print(f\"Considering a threshold of {threshold}\")\n", - " y_pred = (y_score > threshold) * 1\n", - " print(f\"- Precision: {round(precision_score(y_true, y_pred), 1)}\")\n", - " print(f\"- Recall: {round(recall_score(y_true, y_pred), 1)}\")\n", - " print(f\"- Accuracy: {round(accuracy_score(y_true, y_pred), 1)}\")\n", - " print(f\"- F1: {round(f1_score(y_true, y_pred), 1)}\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Note that the poor performance of this network is due to the small number of datapoints used in this tutorial. For a more reliable network we suggest using a number of data points on the order of at least tens of thousands.\n", - "\n", - "The same exercise can be repeated but using grids instead of graphs and CNNs instead of GNNs.\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### CNN\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### GridDataset\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "For training CNNs the user can create `GridDataset` instances.\n", - "\n", - "A few notes about `GridDataset` parameters:\n", - "\n", - "- By default, all features contained in the HDF5 files are used, but the user can specify `features` in `GridDataset` if not all of them are needed. 
Since grids features are derived from node and edge features mapped from graphs to grid, the easiest way to see which features are available is to look at the HDF5 file, as explained in detail in `data_generation_ppi.ipynb` and `data_generation_srv.ipynb`, section \"Other tools\".\n", - "- As is the case for a `GraphDataset`, `task` can be assigned to `regress` and `target` to `BA` to perform a regression task. As mentioned previously, we do not provide sample data to perform a regression task for SRVs.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "target = \"binary\"\n", - "task = \"classif\"\n", - "\n", - "print(\"Loading training data...\")\n", - "dataset_train = GridDataset(\n", - " hdf5_path=input_data_path,\n", - " subset=list(df_train.entry), # selects only data points with ids in df_train.entry\n", - " target=target,\n", - " task=task,\n", - ")\n", - "print(\"\\nLoading validation data...\")\n", - "dataset_val = GridDataset(\n", - " hdf5_path=input_data_path,\n", - " subset=list(df_valid.entry), # selects only data points with ids in df_valid.entry\n", - " train_source=dataset_train,\n", - ")\n", - "print(\"\\nLoading test data...\")\n", - "dataset_test = GridDataset(\n", - " hdf5_path=input_data_path,\n", - " subset=list(df_test.entry), # selects only data points with ids in df_test.entry\n", - " train_source=dataset_train,\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Trainer\n", - "\n", - "As for graphs, the class `Trainer` is used for training, validation and testing of the PyTorch-based CNN.\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "- Also in this case, `neuralnet` can be any neural network class that inherits from `torch.nn.Module`, and it shouldn't be specific to regression or classification in terms of output shape. This tutorial uses `CnnClassification` (implemented in `deeprank2.neuralnets.cnn.model3d`). 
All CNN architectures already implemented in the pakcage can be found [here](https://github.com/DeepRank/deeprank2/tree/main/deeprank2/neuralnets/cnn) and can be used for training or as a basis for implementing new ones.\n", - "- The rest of the `Trainer` parameters can be used as explained already for graphs.\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "##### Training\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "optimizer = torch.optim.SGD\n", - "lr = 1e-3\n", - "weight_decay = 0.001\n", - "epochs = 20\n", - "batch_size = 8\n", - "earlystop_patience = 5\n", - "earlystop_maxgap = 0.1\n", - "min_epoch = 10\n", - "\n", - "trainer = Trainer(\n", - " neuralnet=CnnClassification,\n", - " dataset_train=dataset_train,\n", - " dataset_val=dataset_val,\n", - " dataset_test=dataset_test,\n", - " output_exporters=[HDF5OutputExporter(os.path.join(output_path, f\"cnn_{task}\"))],\n", - ")\n", - "\n", - "trainer.configure_optimizers(optimizer, lr, weight_decay)\n", - "\n", - "trainer.train(\n", - " nepoch=epochs,\n", - " batch_size=batch_size,\n", - " earlystop_patience=earlystop_patience,\n", - " earlystop_maxgap=earlystop_maxgap,\n", - " min_epoch=min_epoch,\n", - " validate=True,\n", - " filename=os.path.join(output_path, f\"cnn_{task}\", \"model.pth.tar\"),\n", - ")\n", - "\n", - "epoch = trainer.epoch_saved_model\n", - "print(f\"Model saved at epoch {epoch}\")\n", - "pytorch_total_params = sum(p.numel() for p in trainer.model.parameters())\n", - "print(f\"Total # of parameters: {pytorch_total_params}\")\n", - "pytorch_trainable_params = sum(p.numel() for p in trainer.model.parameters() if p.requires_grad)\n", - "print(f\"Total # of trainable parameters: {pytorch_trainable_params}\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "##### Testing\n", - "\n", - "And the trained model can be tested on `dataset_test`:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "trainer.test()" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "##### Results visualization\n", - "\n", - "As for GNNs, the results saved by `HDF5OutputExporter` can be inspected, and are saved in the `data/ppi/cnn_classif/` or `data/srv/cnn_classif/` folder in the form of an HDF5 file, `output_exporter.hdf5`, together with the saved pre-trained model.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "output_train = pd.read_hdf(os.path.join(output_path, f\"cnn_{task}\", \"output_exporter.hdf5\"), key=\"training\")\n", - "output_test = pd.read_hdf(os.path.join(output_path, f\"cnn_{task}\", \"output_exporter.hdf5\"), key=\"testing\")\n", - "output_train.head()" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Also in this case, the loss across the epochs can be plotted for the training and the validation sets:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fig = px.line(output_train, x=\"epoch\", y=\"loss\", color=\"phase\", markers=True)\n", - "\n", - "fig.add_vline(x=trainer.epoch_saved_model, line_width=3, line_dash=\"dash\", line_color=\"green\")\n", - "\n", - "fig.update_layout(\n", - " xaxis_title=\"Epoch #\",\n", - " yaxis_title=\"Loss\",\n", - " title=\"Loss vs 
epochs - CNN training\",\n", - " width=700,\n", - " height=400,\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "And some metrics of interest for classification tasks:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "threshold = 0.5\n", - "df = pd.concat([output_train, output_test])\n", - "df_plot = df[(df.epoch == trainer.epoch_saved_model) | ((df.epoch == trainer.epoch_saved_model) & (df.phase == \"testing\"))]\n", - "\n", - "for idx, set in enumerate([\"training\", \"validation\", \"testing\"]):\n", - " df_plot_phase = df_plot[(df_plot.phase == set)]\n", - " y_true = df_plot_phase.target\n", - " y_score = np.array(df_plot_phase.output.values.tolist())[:, 1]\n", - "\n", - " print(f\"\\nMetrics for {set}:\")\n", - " fpr_roc, tpr_roc, thr_roc = roc_curve(y_true, y_score)\n", - " auc_score = auc(fpr_roc, tpr_roc)\n", - " print(f\"AUC: {round(auc_score, 1)}\")\n", - " print(f\"Considering a threshold of {threshold}\")\n", - " y_pred = (y_score > threshold) * 1\n", - " print(f\"- Precision: {round(precision_score(y_true, y_pred), 1)}\")\n", - " print(f\"- Recall: {round(recall_score(y_true, y_pred), 1)}\")\n", - " print(f\"- Accuracy: {round(accuracy_score(y_true, y_pred), 1)}\")\n", - " print(f\"- F1: {round(f1_score(y_true, y_pred), 1)}\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "It's important to note that the dataset used in this analysis is not sufficiently large to provide conclusive and reliable insights. Depending on your specific application, you might find regression, classification, GNNs, and/or CNNs to be valuable options. Feel free to choose the approach that best aligns with your particular problem!\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "deeprank2", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Training Neural Networks\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Introduction\n", + "\n", + "\n", + "\n", + "This tutorial will demonstrate the use of DeepRank2 for training graph neural networks (GNNs) and convolutional neural networks (CNNs) using protein-protein interface (PPI) or single-residue variant (SRV) data for classification and regression predictive tasks.\n", + "\n", + "This tutorial assumes that the PPI data of interest have already been generated and saved as [HDF5 files](https://en.wikipedia.org/wiki/Hierarchical_Data_Format), with the data structure that DeepRank2 expects. This data can be generated using the [data_generation_ppi.ipynb](https://github.com/DeepRank/deeprank2/blob/main/tutorials/data_generation_ppi.ipynb) tutorial or downloaded from Zenodo at [this record address](https://zenodo.org/record/8349335). 
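Each processed HDF5 file holds one group per data point (graph), with its targets stored under `target_values`; a short sketch of inspecting such a file with `h5py` (the file name is a placeholder):

```python
# Peek at the HDF5 structure DeepRank2 expects: one group per graph, with
# targets under "target_values". The file name is a placeholder.
import h5py

with h5py.File("data_processed/ppi/residue/proc-0.hdf5", "r") as hdf5:
    for mol in hdf5:  # each key identifies one data point
        print(mol, "->", float(hdf5[mol]["target_values"]["binary"][()]))
```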
For more details on the data structure, please refer to the other tutorial, which also contains a detailed description of how the data is generated from PDB files.\n", + "\n", + "This tutorial assumes also a basic knowledge of the [PyTorch](https://pytorch.org/) framework, on top of which the machine learning pipeline of DeepRank2 has been developed, for which many online tutorials exist.\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Input data\n", + "\n", + "If you have previously run `data_generation_ppi.ipynb` or `data_generation_srv.ipynb` notebook, then their output can be directly used as input for this tutorial.\n", + "\n", + "Alternatively, preprocessed HDF5 files can be downloaded directly from Zenodo at [this record address](https://zenodo.org/record/7997585). To download the data used in this tutorial, please visit the link and download `data_processed.zip`. Unzip it, and save the `data_processed/` folder in the same directory as this notebook. The name and the location of the folder are optional but recommended, as they are the name and the location we will use to refer to the folder throughout the tutorial.\n", + "\n", + "Note that the datasets contain only ~100 data points each, which is not enough to develop an impactful predictive model, and the scope of their use is indeed only demonstrative and informative for the users.\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Utilities\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Libraries\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The libraries needed for this tutorial:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import glob\n", + "import logging\n", + "import os\n", + "import warnings\n", + "\n", + "import h5py\n", + "import numpy as np\n", + "import pandas as pd\n", + "import plotly.express as px\n", + "import torch\n", + "from sklearn.metrics import accuracy_score, auc, f1_score, precision_score, recall_score, roc_curve\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "from deeprank2.dataset import GraphDataset, GridDataset\n", + "from deeprank2.neuralnets.cnn.model3d import CnnClassification\n", + "from deeprank2.neuralnets.gnn.vanilla_gnn import VanillaNetwork\n", + "from deeprank2.trainer import Trainer\n", + "from deeprank2.utils.exporters import HDF5OutputExporter\n", + "\n", + "np.seterr(divide=\"ignore\")\n", + "np.seterr(invalid=\"ignore\")\n", + "\n", + "logging.basicConfig(level=logging.INFO)\n", + "\n", + "warnings.filterwarnings(\"ignore\")\n", + "\n", + "# ruff: noqa: PD901" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Paths and sets\n", + "\n", + "The paths for reading the processed data:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "data_type = \"ppi\"\n", + "level = \"residue\"\n", + "processed_data_path = os.path.join(\"data_processed\", data_type, level)\n", + "input_data_path = glob.glob(os.path.join(processed_data_path, \"*.hdf5\"))\n", + "output_path = os.path.join(\"data_processed\", data_type, level) # for saving predictions results" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `data_type` can be either 
\"ppi\" or \"srv\", depending on which application the user is most interested in. The `level` can be either \"residue\" or \"atomic\", and refers to the structural resolution, where each node either represents a single residue or a single atom from the molecular structure.\n", + "\n", + "In this tutorial, we will use PPI residue-level data by default, but the same code can be applied to SRV or/and atomic-level data with no changes, apart from setting `data_type` and `level` parameters in the cell above.\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A Pandas DataFrame containing data points' IDs and the binary target values can be defined:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df_dict = {}\n", + "df_dict[\"entry\"] = []\n", + "df_dict[\"target\"] = []\n", + "for fname in input_data_path:\n", + " with h5py.File(fname, \"r\") as hdf5:\n", + " for mol in hdf5:\n", + " target_value = float(hdf5[mol][\"target_values\"][\"binary\"][()])\n", + " df_dict[\"entry\"].append(mol)\n", + " df_dict[\"target\"].append(target_value)\n", + "\n", + "df = pd.DataFrame(data=df_dict)\n", + "df.head()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As explained in [data_generation_ppi.ipynb](https://github.com/DeepRank/deeprank2/blob/main/tutorials/data_generation_ppi.ipynb), for each data point there are two targets: \"BA\" and \"binary\". The first represents the strength of the interaction between two molecules that bind reversibly (interact) in nM, while the second represents its binary mapping, being 0 (BA > 500 nM) a not-binding complex and 1 (BA <= 500 nM) binding one.\n", + "\n", + "For SRVs, each data point has a single target, \"binary\", which is 0 if the SRV is considered benign, and 1 if it is pathogenic, as explained in [data_generation_srv.ipynb](https://github.com/DeepRank/deeprank-core/blob/main/tutorials/data_generation_srv.ipynb).\n", + "\n", + "The pandas DataFrame `df` is used only to split data points into training, validation and test sets according to the \"binary\" target - using target stratification to keep the proportion of 0s and 1s constant among the different sets. 
Training and validation sets will be used during the training for updating the network weights, while the test set will be held out as an independent test and will be used later for the model evaluation.\n" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": null, +   "metadata": {}, +   "outputs": [], +   "source": [ +    "df_train, df_test = train_test_split(df, test_size=0.1, stratify=df.target, random_state=42)\n", +    "df_train, df_valid = train_test_split(df_train, test_size=0.2, stratify=df_train.target, random_state=42)\n", +    "\n", +    "print(\"Data statistics:\\n\")\n", +    "print(f\"Total samples: {len(df)}\\n\")\n", +    "print(f\"Training set: {len(df_train)} samples, {round(100*len(df_train)/len(df))}%\")\n", +    "print(f\"\\t- Class 0: {len(df_train[df_train.target == 0])} samples, {round(100*len(df_train[df_train.target == 0])/len(df_train))}%\")\n", +    "print(f\"\\t- Class 1: {len(df_train[df_train.target == 1])} samples, {round(100*len(df_train[df_train.target == 1])/len(df_train))}%\")\n", +    "print(f\"Validation set: {len(df_valid)} samples, {round(100*len(df_valid)/len(df))}%\")\n", +    "print(f\"\\t- Class 0: {len(df_valid[df_valid.target == 0])} samples, {round(100*len(df_valid[df_valid.target == 0])/len(df_valid))}%\")\n", +    "print(f\"\\t- Class 1: {len(df_valid[df_valid.target == 1])} samples, {round(100*len(df_valid[df_valid.target == 1])/len(df_valid))}%\")\n", +    "print(f\"Testing set: {len(df_test)} samples, {round(100*len(df_test)/len(df))}%\")\n", +    "print(f\"\\t- Class 0: {len(df_test[df_test.target == 0])} samples, {round(100*len(df_test[df_test.target == 0])/len(df_test))}%\")\n", +    "print(f\"\\t- Class 1: {len(df_test[df_test.target == 1])} samples, {round(100*len(df_test[df_test.target == 1])/len(df_test))}%\")" +   ] +  }, +  { +   "attachments": {}, +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "## Classification example\n", +    "\n", +    "A GNN and a CNN can be trained for a classification predictive task, which consists in predicting the \"binary\" target values.\n" +   ] +  }, +  { +   "attachments": {}, +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "### GNN\n" +   ] +  }, +  { +   "attachments": {}, +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "#### GraphDataset\n" +   ] +  }, +  { +   "attachments": {}, +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "For training GNNs the user can create `GraphDataset` instances. This class inherits from the `DeeprankDataset` class, which in turn inherits from the `Dataset` [PyTorch geometric class](https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/data/dataset.html), a base class for creating graph datasets.\n", +    "\n", +    "A few notes about `GraphDataset` parameters:\n", +    "\n", +    "- By default, all features contained in the HDF5 files are used, but the user can specify `node_features` and `edge_features` in `GraphDataset` if not all of them are needed. See the [docs](https://deeprank2.readthedocs.io/en/latest/features.html) for more details about all the possible pre-implemented features.\n", +    "- For regression, `task` should be set to `regress` and the `target` to `BA`, which is a continuous variable and therefore suitable for regression tasks.\n", +    "- For the `GraphDataset` class it is possible to define a dictionary to indicate which transformations to apply to the features; the transformations can be lambda functions and/or standardization.\n", +    "  - If the `standardize` key is `True`, standardization is applied after transformation. 
Standardization consists in applying the following formula on each feature's value: ${x' = \\frac{x - \\mu}{\\sigma}}$, where ${\\mu}$ is the mean and ${\\sigma}$ is the standard deviation. Standardization is a scaling method where the values are centered around the mean with a unit standard deviation.\n", +    "  - The transformation to apply can be specified as a lambda function as a value of the key `transform`, which defaults to `None`.\n", +    "  - Since in the provided example standardization is applied, the training features' means and standard deviations need to be used for scaling the validation and test sets. For doing so, the `train_source` parameter is used: when it is set, the training set's statistics are used to scale the validation/testing sets. You need to pass `features_transform` to the training dataset only; for the other datasets it is ignored and the one from `train_source` is used instead.\n", +    "  - Note that transformations have not currently been implemented for the `GridDataset` class.\n", +    "  - In the example below a cube-root transformation followed by standardization is applied to all the features. It is also possible to use specific features as keys, to indicate that the transformation and/or standardization should be applied to only a few features (see the sketch after the next cell).\n" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": null, +   "metadata": {}, +   "outputs": [], +   "source": [ +    "target = \"binary\"\n", +    "task = \"classif\"\n", +    "node_features = [\"res_type\"]\n", +    "edge_features = [\"distance\"]\n", +    "features_transform = {\"all\": {\"transform\": lambda x: np.cbrt(x), \"standardize\": True}}\n", +    "\n", +    "print(\"Loading training data...\")\n", +    "dataset_train = GraphDataset(\n", +    "    hdf5_path=input_data_path,\n", +    "    subset=list(df_train.entry), # selects only data points with ids in df_train.entry\n", +    "    node_features=node_features,\n", +    "    edge_features=edge_features,\n", +    "    features_transform=features_transform,\n", +    "    target=target,\n", +    "    task=task,\n", +    ")\n", +    "print(\"\\nLoading validation data...\")\n", +    "dataset_val = GraphDataset(\n", +    "    hdf5_path=input_data_path,\n", +    "    subset=list(df_valid.entry), # selects only data points with ids in df_valid.entry\n", +    "    train_source=dataset_train,\n", +    ")\n", +    "print(\"\\nLoading test data...\")\n", +    "dataset_test = GraphDataset(\n", +    "    hdf5_path=input_data_path,\n", +    "    subset=list(df_test.entry), # selects only data points with ids in df_test.entry\n", +    "    train_source=dataset_train,\n", +    ")" +   ] +  },
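+  { +   "attachments": {}, +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "As a minimal, optional sketch of the per-feature syntax mentioned above, the dictionary below applies a log transformation plus standardization to one feature and standardization only to another. It could be passed to the training `GraphDataset` through `features_transform`, exactly as done with the `all` key above. Note that the feature names used as keys here are illustrative assumptions: replace them with features actually present in your processed HDF5 files.\n" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": null, +   "metadata": {}, +   "outputs": [], +   "source": [ +    "# Illustrative sketch only: per-feature transformations.\n", +    "# The feature names below are assumptions; replace them with features present in your HDF5 files.\n", +    "features_transform_per_feature = {\n", +    "    \"bsa\": {\"transform\": lambda x: np.log(x + 1), \"standardize\": True}, # log-transform, then standardize\n", +    "    \"electrostatic\": {\"transform\": None, \"standardize\": True}, # standardize only\n", +    "}" +   ] +  },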
+  { +   "attachments": {}, +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "#### Trainer\n", +    "\n", +    "The class `Trainer` implements training, validation and testing of PyTorch-based neural networks.\n" +   ] +  }, +  { +   "attachments": {}, +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "A few notes about `Trainer` parameters:\n", +    "\n", +    "- `neuralnet` can be any neural network class that inherits from `torch.nn.Module`, and it shouldn't be specific to regression or classification in terms of output shape. The `Trainer` class takes care of formatting the output shape according to the task. This tutorial uses a simple network, `VanillaNetwork` (implemented in `deeprank2.neuralnets.gnn.vanilla_gnn`). All GNN architectures already implemented in the package can be found [here](https://github.com/DeepRank/deeprank-core/tree/main/deeprank2/neuralnets/gnn) and can be used for training or as a basis for implementing new ones.\n", +    "- `class_weights` is used for classification tasks only and assigns class weights based on the training dataset content to account for any potential imbalance between the classes. In this case the dataset is balanced (50% 0 and 50% 1), so it is not necessary to use it. It defaults to False.\n", +    "- `cuda` and `ngpu` are used for indicating whether to use CUDA and how many GPUs. By default, CUDA is not used and `ngpu` is 0.\n", +    "- The user can specify a deeprank2 exporter or a custom one in the `output_exporters` parameter, together with the path where the results should be saved. Exporters are used for storing prediction information collected later on during training and testing. Later the results saved by `HDF5OutputExporter` will be read and evaluated.\n" +   ] +  }, +  { +   "attachments": {}, +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "##### Training\n" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": null, +   "metadata": {}, +   "outputs": [], +   "source": [ +    "trainer = Trainer(\n", +    "    neuralnet=VanillaNetwork,\n", +    "    dataset_train=dataset_train,\n", +    "    dataset_val=dataset_val,\n", +    "    dataset_test=dataset_test,\n", +    "    output_exporters=[HDF5OutputExporter(os.path.join(output_path, f\"gnn_{task}\"))],\n", +    ")" +   ] +  }, +  { +   "attachments": {}, +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "The default optimizer is `torch.optim.Adam`. It is possible to specify the optimizer's parameters or to use another PyTorch optimizer object:\n" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": null, +   "metadata": {}, +   "outputs": [], +   "source": [ +    "optimizer = torch.optim.SGD\n", +    "lr = 1e-3\n", +    "weight_decay = 0.001\n", +    "\n", +    "trainer.configure_optimizers(optimizer, lr, weight_decay)" +   ] +  }, +  { +   "attachments": {}, +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "The default loss function for classification is `torch.nn.CrossEntropyLoss` and for regression it is `torch.nn.MSELoss`. It is also possible to set some other PyTorch loss functions by using the `Trainer.set_lossfunction` method, although not all of them are currently implemented.\n", +    "\n", +    "Then the model can be trained using the `train()` method of the `Trainer` class.\n", +    "\n", +    "A few notes about `train()` method parameters:\n", +    "\n", +    "- `earlystop_patience`, `earlystop_maxgap` and `min_epoch` are used for controlling the early stopping logic. `earlystop_patience` indicates the number of epochs after which the training ends if the validation loss does not improve. `earlystop_maxgap` indicates the maximum difference allowed between validation and training loss, and `min_epoch` is the minimum number of epochs to be reached before evaluating `maxgap`.\n", +    "- If `validate` is set to `True`, validation is performed on an independent dataset, which has been called `dataset_val` a few cells above. If set to `False`, validation is performed on the training dataset itself (not recommended).\n", +    "- `num_workers` can be set to indicate how many subprocesses to use for data loading. 
The default is 0 and it means that the data will be loaded in the main process.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "epochs = 20\n", + "batch_size = 8\n", + "earlystop_patience = 5\n", + "earlystop_maxgap = 0.1\n", + "min_epoch = 10\n", + "\n", + "trainer.train(\n", + " nepoch=epochs,\n", + " batch_size=batch_size,\n", + " earlystop_patience=earlystop_patience,\n", + " earlystop_maxgap=earlystop_maxgap,\n", + " min_epoch=min_epoch,\n", + " validate=True,\n", + " filename=os.path.join(output_path, f\"gnn_{task}\", \"model.pth.tar\"),\n", + ")\n", + "\n", + "epoch = trainer.epoch_saved_model\n", + "print(f\"Model saved at epoch {epoch}\")\n", + "pytorch_total_params = sum(p.numel() for p in trainer.model.parameters())\n", + "print(f\"Total # of parameters: {pytorch_total_params}\")\n", + "pytorch_trainable_params = sum(p.numel() for p in trainer.model.parameters() if p.requires_grad)\n", + "print(f\"Total # of trainable parameters: {pytorch_trainable_params}\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### Testing\n", + "\n", + "And the trained model can be tested on `dataset_test`:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainer.test()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### Results visualization\n", + "\n", + "Finally, the results saved by `HDF5OutputExporter` can be inspected, which can be found in the `data/ppi/gnn_classif/` folder in the form of an HDF5 file, `output_exporter.hdf5`. Note that the folder contains the saved pre-trained model as well.\n", + "\n", + "`output_exporter.hdf5` contains [HDF5 Groups](https://docs.h5py.org/en/stable/high/group.html) which refer to each phase, e.g. training and testing if both are run, only one of them otherwise. Training phase includes validation results as well. This HDF5 file can be read as a Pandas Dataframe:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "output_train = pd.read_hdf(os.path.join(output_path, f\"gnn_{task}\", \"output_exporter.hdf5\"), key=\"training\")\n", + "output_test = pd.read_hdf(os.path.join(output_path, f\"gnn_{task}\", \"output_exporter.hdf5\"), key=\"testing\")\n", + "output_train.head()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The dataframes contain `phase`, `epoch`, `entry`, `output`, `target`, and `loss` columns, and can be easily used to visualize the results.\n", + "\n", + "For classification tasks, the `output` column contains a list of probabilities that each class occurs, and each list sums to 1 (for more details, please see documentation on the [softmax function](https://pytorch.org/docs/stable/generated/torch.nn.functional.softmax.html)). Note that the order of the classes in the list depends on the `classes` attribute of the DeeprankDataset instances. 
For classification tasks, if `classes` is not specified (as in this example case), it is defaulted to [0, 1].\n", + "\n", + "The loss across the epochs can be plotted for the training and the validation sets:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig = px.line(output_train, x=\"epoch\", y=\"loss\", color=\"phase\", markers=True)\n", + "\n", + "fig.add_vline(x=trainer.epoch_saved_model, line_width=3, line_dash=\"dash\", line_color=\"green\")\n", + "\n", + "fig.update_layout(\n", + " xaxis_title=\"Epoch #\",\n", + " yaxis_title=\"Loss\",\n", + " title=\"Loss vs epochs - GNN training\",\n", + " width=700,\n", + " height=400,\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And now a few metrics of interest for classification tasks can be printed out: the [area under the ROC curve](https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve) (AUC), and for a threshold of 0.5 the [precision, recall, accuracy and f1 score](https://en.wikipedia.org/wiki/Precision_and_recall#Definition).\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "skip-execution" + ] + }, + "outputs": [], + "source": [ + "threshold = 0.5\n", + "df = pd.concat([output_train, output_test])\n", + "df_plot = df[(df.epoch == trainer.epoch_saved_model) | ((df.epoch == trainer.epoch_saved_model) & (df.phase == \"testing\"))]\n", + "\n", + "for dataset in [\"training\", \"validation\", \"testing\"]:\n", + " df_plot_phase = df_plot[(df_plot.phase == dataset)]\n", + " y_true = df_plot_phase.target\n", + " y_score = np.array(df_plot_phase.output.tolist())[:, 1]\n", + "\n", + " print(f\"\\nMetrics for {dataset}:\")\n", + " fpr_roc, tpr_roc, thr_roc = roc_curve(y_true, y_score)\n", + " auc_score = auc(fpr_roc, tpr_roc)\n", + " print(f\"AUC: {round(auc_score, 1)}\")\n", + " print(f\"Considering a threshold of {threshold}\")\n", + " y_pred = (y_score > threshold) * 1\n", + " print(f\"- Precision: {round(precision_score(y_true, y_pred), 1)}\")\n", + " print(f\"- Recall: {round(recall_score(y_true, y_pred), 1)}\")\n", + " print(f\"- Accuracy: {round(accuracy_score(y_true, y_pred), 1)}\")\n", + " print(f\"- F1: {round(f1_score(y_true, y_pred), 1)}\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that the poor performance of this network is due to the small number of datapoints used in this tutorial. For a more reliable network we suggest using a number of data points on the order of at least tens of thousands.\n", + "\n", + "The same exercise can be repeated but using grids instead of graphs and CNNs instead of GNNs.\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### CNN\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### GridDataset\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For training CNNs the user can create `GridDataset` instances.\n", + "\n", + "A few notes about `GridDataset` parameters:\n", + "\n", + "- By default, all features contained in the HDF5 files are used, but the user can specify `features` in `GridDataset` if not all of them are needed. 
Since grid features are derived from node and edge features mapped from graphs to grids, the easiest way to see which features are available is to look at the HDF5 file, as explained in detail in `data_generation_ppi.ipynb` and `data_generation_srv.ipynb`, section \"Other tools\". A small sketch of selecting a subset of grid features is given after the dataset cell below.\n", +    "- As is the case for a `GraphDataset`, `task` can be assigned to `regress` and `target` to `BA` to perform a regression task. As mentioned previously, we do not provide sample data to perform a regression task for SRVs.\n" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": null, +   "metadata": {}, +   "outputs": [], +   "source": [ +    "target = \"binary\"\n", +    "task = \"classif\"\n", +    "\n", +    "print(\"Loading training data...\")\n", +    "dataset_train = GridDataset(\n", +    "    hdf5_path=input_data_path,\n", +    "    subset=list(df_train.entry), # selects only data points with ids in df_train.entry\n", +    "    target=target,\n", +    "    task=task,\n", +    ")\n", +    "print(\"\\nLoading validation data...\")\n", +    "dataset_val = GridDataset(\n", +    "    hdf5_path=input_data_path,\n", +    "    subset=list(df_valid.entry), # selects only data points with ids in df_valid.entry\n", +    "    train_source=dataset_train,\n", +    ")\n", +    "print(\"\\nLoading test data...\")\n", +    "dataset_test = GridDataset(\n", +    "    hdf5_path=input_data_path,\n", +    "    subset=list(df_test.entry), # selects only data points with ids in df_test.entry\n", +    "    train_source=dataset_train,\n", +    ")" +   ] +  },
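+  { +   "attachments": {}, +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "As an optional, purely illustrative sketch of the `features` parameter mentioned above, the cell below lists the grid features available in one of the HDF5 files and then builds an extra `GridDataset` restricted to a subset of them. The \"mapped_features\" group name used for the lookup is an assumption: confirm it by inspecting your own files, as described in the data generation tutorials.\n" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": null, +   "metadata": {}, +   "outputs": [], +   "source": [ +    "# Optional, illustrative sketch: restrict a GridDataset to a subset of the available grid features.\n", +    "# The \"mapped_features\" group name is an assumption; inspect your HDF5 files (see the data\n", +    "# generation tutorials, section \"Other tools\") to confirm where the grid features are stored.\n", +    "with h5py.File(input_data_path[0], \"r\") as hdf5:\n", +    "    first_entry = next(iter(hdf5.keys()))\n", +    "    available_grid_features = list(hdf5[first_entry][\"mapped_features\"].keys())\n", +    "print(f\"Available grid features: {available_grid_features}\")\n", +    "\n", +    "dataset_train_subset = GridDataset(\n", +    "    hdf5_path=input_data_path,\n", +    "    subset=list(df_train.entry),\n", +    "    features=available_grid_features[:3], # keep only the first three features, as an example\n", +    "    target=target,\n", +    "    task=task,\n", +    ")" +   ] +  },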
+  { +   "attachments": {}, +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "#### Trainer\n", +    "\n", +    "As for graphs, the class `Trainer` is used for training, validation and testing of the PyTorch-based CNN.\n" +   ] +  }, +  { +   "attachments": {}, +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "- Also in this case, `neuralnet` can be any neural network class that inherits from `torch.nn.Module`, and it shouldn't be specific to regression or classification in terms of output shape. This tutorial uses `CnnClassification` (implemented in `deeprank2.neuralnets.cnn.model3d`). All CNN architectures already implemented in the package can be found [here](https://github.com/DeepRank/deeprank2/tree/main/deeprank2/neuralnets/cnn) and can be used for training or as a basis for implementing new ones.\n", +    "- The rest of the `Trainer` parameters can be used as explained already for graphs.\n" +   ] +  }, +  { +   "attachments": {}, +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "##### Training\n" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": null, +   "metadata": {}, +   "outputs": [], +   "source": [ +    "optimizer = torch.optim.SGD\n", +    "lr = 1e-3\n", +    "weight_decay = 0.001\n", +    "epochs = 20\n", +    "batch_size = 8\n", +    "earlystop_patience = 5\n", +    "earlystop_maxgap = 0.1\n", +    "min_epoch = 10\n", +    "\n", +    "trainer = Trainer(\n", +    "    neuralnet=CnnClassification,\n", +    "    dataset_train=dataset_train,\n", +    "    dataset_val=dataset_val,\n", +    "    dataset_test=dataset_test,\n", +    "    output_exporters=[HDF5OutputExporter(os.path.join(output_path, f\"cnn_{task}\"))],\n", +    ")\n", +    "\n", +    "trainer.configure_optimizers(optimizer, lr, weight_decay)\n", +    "\n", +    "trainer.train(\n", +    "    nepoch=epochs,\n", +    "    batch_size=batch_size,\n", +    "    earlystop_patience=earlystop_patience,\n", +    "    earlystop_maxgap=earlystop_maxgap,\n", +    "    min_epoch=min_epoch,\n", +    "    validate=True,\n", +    "    filename=os.path.join(output_path, f\"cnn_{task}\", \"model.pth.tar\"),\n", +    ")\n", +    "\n", +    "epoch = trainer.epoch_saved_model\n", +    "print(f\"Model saved at epoch {epoch}\")\n", +    "pytorch_total_params = sum(p.numel() for p in trainer.model.parameters())\n", +    "print(f\"Total # of parameters: {pytorch_total_params}\")\n", +    "pytorch_trainable_params = sum(p.numel() for p in trainer.model.parameters() if p.requires_grad)\n", +    "print(f\"Total # of trainable parameters: {pytorch_trainable_params}\")" +   ] +  }, +  { +   "attachments": {}, +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "##### Testing\n", +    "\n", +    "And the trained model can be tested on `dataset_test`:\n" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": null, +   "metadata": {}, +   "outputs": [], +   "source": [ +    "trainer.test()" +   ] +  }, +  { +   "attachments": {}, +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "##### Results visualization\n", +    "\n", +    "As for GNNs, the results saved by `HDF5OutputExporter` can be inspected. They are saved in the `cnn_classif/` subfolder of the `output_path` defined above, in the form of an HDF5 file, `output_exporter.hdf5`, together with the saved pre-trained model.\n" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": null, +   "metadata": {}, +   "outputs": [], +   "source": [ +    "output_train = pd.read_hdf(os.path.join(output_path, f\"cnn_{task}\", \"output_exporter.hdf5\"), key=\"training\")\n", +    "output_test = pd.read_hdf(os.path.join(output_path, f\"cnn_{task}\", \"output_exporter.hdf5\"), key=\"testing\")\n", +    "output_train.head()" +   ] +  }, +  { +   "attachments": {}, +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "Also in this case, the loss across the epochs can be plotted for the training and the validation sets:\n" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": null, +   "metadata": {}, +   "outputs": [], +   "source": [ +    "fig = px.line(output_train, x=\"epoch\", y=\"loss\", color=\"phase\", markers=True)\n", +    "\n", +    "fig.add_vline(x=trainer.epoch_saved_model, line_width=3, line_dash=\"dash\", line_color=\"green\")\n", +    "\n", +    "fig.update_layout(\n", +    "    xaxis_title=\"Epoch #\",\n", +    "    yaxis_title=\"Loss\",\n", +    "    title=\"Loss vs 
epochs - CNN training\",\n", + " width=700,\n", + " height=400,\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And some metrics of interest for classification tasks:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "threshold = 0.5\n", + "df = pd.concat([output_train, output_test])\n", + "df_plot = df[(df.epoch == trainer.epoch_saved_model) | ((df.epoch == trainer.epoch_saved_model) & (df.phase == \"testing\"))]\n", + "\n", + "for dataset in [\"training\", \"validation\", \"testing\"]:\n", + " df_plot_phase = df_plot[(df_plot.phase == dataset)]\n", + " y_true = df_plot_phase.target\n", + " y_score = np.array(df_plot_phase.output.tolist())[:, 1]\n", + "\n", + " print(f\"\\nMetrics for {dataset}:\")\n", + " fpr_roc, tpr_roc, thr_roc = roc_curve(y_true, y_score)\n", + " auc_score = auc(fpr_roc, tpr_roc)\n", + " print(f\"AUC: {round(auc_score, 1)}\")\n", + " print(f\"Considering a threshold of {threshold}\")\n", + " y_pred = (y_score > threshold) * 1\n", + " print(f\"- Precision: {round(precision_score(y_true, y_pred), 1)}\")\n", + " print(f\"- Recall: {round(recall_score(y_true, y_pred), 1)}\")\n", + " print(f\"- Accuracy: {round(accuracy_score(y_true, y_pred), 1)}\")\n", + " print(f\"- F1: {round(f1_score(y_true, y_pred), 1)}\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "It's important to note that the dataset used in this analysis is not sufficiently large to provide conclusive and reliable insights. Depending on your specific application, you might find regression, classification, GNNs, and/or CNNs to be valuable options. Feel free to choose the approach that best aligns with your particular problem!\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "deeprank2", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 }