diff --git a/.bumpversion.cfg b/.bumpversion.cfg
index 650ce1223..26f291b47 100644
--- a/.bumpversion.cfg
+++ b/.bumpversion.cfg
@@ -1,5 +1,5 @@
[bumpversion]
-current_version = 3.0.0
+current_version = 3.0.1
[comment]
comment = The contents of this file cannot be merged with that of setup.cfg until https://github.com/c4urself/bump2version/issues/185 is resolved
diff --git a/CITATION.cff b/CITATION.cff
index ac8ea0c9b..9b36bc2d1 100644
--- a/CITATION.cff
+++ b/CITATION.cff
@@ -68,5 +68,5 @@ keywords:
- DeepRank
license: Apache-2.0
commit: 4e8823758ba03f824b4281f5689cb6a335ab2f6c
-version: "3.0.0"
-date-released: "2024-01-25"
+version: "3.0.1"
+date-released: "2024-02-22"
diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst
index f6a3dfc95..7169f5a1a 100644
--- a/CONTRIBUTING.rst
+++ b/CONTRIBUTING.rst
@@ -37,8 +37,8 @@ You want to make some kind of change to the code base
#. if needed, fork the repository to your own Github profile and create your own feature branch off of the latest main commit. While working on your feature branch, make sure to stay up to date with the main branch by pulling in changes, possibly from the 'upstream' repository (follow the instructions `here `__ and `here `__);
#. make sure the existing tests still work by running ``python setup.py test``;
#. add your own tests (if necessary);
-#. ensure the code is correctly linted (`ruff .`) and formatted (`ruff format .`);
-#. see our `developer's readme `` for detailed information on our style conventions, etc.;
+#. ensure the code is correctly linted (``ruff .``) and formatted (``ruff format .``);
+#. see our `developer's readme `_ for detailed information on our style conventions, etc.;
#. update or expand the documentation;
#. `push `_ your feature branch to (your fork of) the DeepRank2 repository on GitHub;
#. create the pull request, e.g. following the instructions `here `__.
diff --git a/README.dev.md b/README.dev.md
index 256c222ec..f4eb2b56f 100644
--- a/README.dev.md
+++ b/README.dev.md
@@ -79,7 +79,7 @@ During the development cycle, three main supporting branches are used:
## Making a release
-1. Branch from `dev` and prepare the branch for the release (e.g., removing the unnecessary dev files such as the current one, fix minor bugs if necessary).
+1. Branch from `dev` and prepare the branch for the release (e.g., removing the unnecessary dev files, fix minor bugs if necessary).
2. [Bump the version](https://github.com/DeepRank/deeprank2/blob/dev/README.dev.md#versioning).
3. Verify that the information in `CITATION.cff` is correct (update the release date), and that `.zenodo.json` contains equivalent data.
4. Merge the release branch into `main` (and `dev`), and [run the tests](https://github.com/DeepRank/deeprank2/blob/dev/README.dev.md#running-the-tests).
diff --git a/deeprank2/__init__.py b/deeprank2/__init__.py
index 528787cfc..055276878 100644
--- a/deeprank2/__init__.py
+++ b/deeprank2/__init__.py
@@ -1 +1 @@
-__version__ = "3.0.0"
+__version__ = "3.0.1"
diff --git a/deeprank2/trainer.py b/deeprank2/trainer.py
index 70f650466..56b5637ad 100644
--- a/deeprank2/trainer.py
+++ b/deeprank2/trainer.py
@@ -947,7 +947,7 @@ def _save_model(self) -> dict[str, Any]:
if key["transform"] is None:
continue
str_expr = inspect.getsource(key["transform"])
- match = re.search(r"\'transform\':.*(lambda.*).*,.*\'standardize\'.*", str_expr).group(1)
+ match = re.search(r"[\"|\']transform[\"|\']:.*(lambda.*).*,.*[\"|\']standardize[\"|\'].*", str_expr).group(1)
key["transform"] = match
state = {
diff --git a/pyproject.toml b/pyproject.toml
index 6482eb385..e589828a4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "deeprank2"
-version = "3.0.0"
+version = "3.0.1"
description = "DeepRank2 is an open-source deep learning framework for data mining of protein-protein interfaces or single-residue missense variants."
readme = "README.md"
requires-python = ">=3.10"
diff --git a/tests/features/test_irc.py b/tests/features/test_irc.py
index e5cb26fcc..282bfcfc4 100644
--- a/tests/features/test_irc.py
+++ b/tests/features/test_irc.py
@@ -32,7 +32,7 @@ def test_irc_atom() -> None:
pdb_path = "tests/data/pdb/1A0Z/1A0Z.pdb"
graph, _ = build_testgraph(
pdb_path=pdb_path,
- detail="residue",
+ detail="atom",
influence_radius=4.5,
max_edge_length=4.5,
)
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index 591cb7cf3..8efe29229 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -86,6 +86,7 @@ def _model_base_test(
dataset_train,
dataset_val,
dataset_test,
+ cuda=use_cuda,
output_exporters=output_exporters,
)
@@ -94,20 +95,6 @@ def _model_base_test(
for parameter in trainer.model.parameters():
assert parameter.is_cuda, f"{parameter} is not cuda"
- data = dataset_train.get(0)
-
- for name, data_tensor in (
- ("x", data.x),
- ("y", data.y),
- (Efeat.INDEX, data.edge_index),
- ("edge_attr", data.edge_attr),
- (Nfeat.POSITION, data.pos),
- ("cluster0", data.cluster0),
- ("cluster1", data.cluster1),
- ):
- if data_tensor is not None:
- assert data_tensor.is_cuda, f"data.{name} is not cuda"
-
with warnings.catch_warnings(record=UserWarning):
trainer.train(
nepoch=3,
@@ -774,6 +761,37 @@ def test_test_method_pretrained_model_on_dataset_without_target(self) -> None:
assert output.target.unique().tolist()[0] is None
assert output.loss.unique().tolist()[0] is None
+ def test_graph_save_and_load_model(self) -> None:
+ test_data_graph = "tests/data/hdf5/test.hdf5"
+ n = 10
+ features_transform = {
+ Nfeat.RESTYPE: {"transform": lambda x: x / 2, "standardize": True},
+ Nfeat.BSA: {"transform": None, "standardize": False},
+ }
+
+ dataset = GraphDataset(
+ hdf5_path=test_data_graph,
+ node_features=[Nfeat.RESTYPE, Nfeat.POLARITY, Nfeat.BSA],
+ target=targets.BINARY,
+ task=targets.CLASSIF,
+ features_transform=features_transform,
+ )
+ trainer = Trainer(NaiveNetwork, dataset)
+ # during the training the model is saved
+ trainer.train(nepoch=2, batch_size=2, filename=self.save_path)
+ assert trainer.features_transform == features_transform
+
+ # load the model into a new GraphDataset instance
+ dataset_test = GraphDataset(
+ hdf5_path="tests/data/hdf5/test.hdf5",
+ train_source=self.save_path,
+ )
+
+ # Check if the features_transform is correctly loaded from the saved model
+ assert dataset_test.features_transform[Nfeat.RESTYPE]["transform"](n) == n / 2 # the only way to test the transform in this case is to apply it
+ assert dataset_test.features_transform[Nfeat.RESTYPE]["standardize"] == features_transform[Nfeat.RESTYPE]["standardize"]
+ assert dataset_test.features_transform[Nfeat.BSA] == features_transform[Nfeat.BSA]
+
if __name__ == "__main__":
unittest.main()