From 29cff1189cfa3b9ba3f2226e5798f40c5464ffb6 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Tue, 1 Oct 2024 17:14:49 +0200 Subject: [PATCH 1/2] close temporary file for tox21molnet download before retrieving url --- chebai/preprocessing/datasets/tox21.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/chebai/preprocessing/datasets/tox21.py b/chebai/preprocessing/datasets/tox21.py index 98d78009..79cd35e9 100644 --- a/chebai/preprocessing/datasets/tox21.py +++ b/chebai/preprocessing/datasets/tox21.py @@ -56,14 +56,16 @@ def processed_file_names(self) -> List[str]: def download(self) -> None: """Downloads and extracts the dataset.""" - with NamedTemporaryFile("rb") as gout: - request.urlretrieve( - "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/tox21.csv.gz", - gout.name, - ) - with gzip.open(gout.name) as gfile: - with open(os.path.join(self.raw_dir, "tox21.csv"), "wt") as fout: - fout.write(gfile.read().decode()) + gout = NamedTemporaryFile("rb") + gout.close() + + request.urlretrieve( + "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/tox21.csv.gz", + gout.name, + ) + with gzip.open(gout.name) as gfile: + with open(os.path.join(self.raw_dir, "tox21.csv"), "wt") as fout: + fout.write(gfile.read().decode()) def setup_processed(self) -> None: """Processes and splits the dataset.""" From e0a794ef76fd793b60469ebd88b539ea2e0bc410 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 7 Nov 2024 11:29:44 +0100 Subject: [PATCH 2/2] add test for tox21molnet --- tests/unit/dataset_classes/testTox21MolNet.py | 181 ++++++++++++++++++ 1 file changed, 181 insertions(+) create mode 100644 tests/unit/dataset_classes/testTox21MolNet.py diff --git a/tests/unit/dataset_classes/testTox21MolNet.py b/tests/unit/dataset_classes/testTox21MolNet.py new file mode 100644 index 00000000..86cbb752 --- /dev/null +++ b/tests/unit/dataset_classes/testTox21MolNet.py @@ -0,0 +1,181 @@ +import unittest +from typing import List +from unittest.mock import MagicMock, mock_open, patch + +import torch + +from chebai.preprocessing.datasets.tox21 import Tox21MolNet +from chebai.preprocessing.reader import ChemDataReader +from tests.unit.mock_data.tox_mock_data import Tox21MolNetMockData + + +class TestTox21MolNet(unittest.TestCase): + @classmethod + @patch("os.makedirs", return_value=None) + def setUpClass(cls, mock_makedirs: MagicMock) -> None: + """ + Initialize a Tox21MolNet instance for testing. + + Args: + mock_makedirs (MagicMock): Mocked `os.makedirs` function. + """ + Tox21MolNet.READER = ChemDataReader + cls.data_module = Tox21MolNet() + + @patch( + "builtins.open", + new_callable=mock_open, + read_data=Tox21MolNetMockData.get_raw_data(), + ) + def test_load_data_from_file(self, mock_open_file: mock_open) -> None: + """ + Test the `_load_data_from_file` method for correct output. + + Args: + mock_open_file (mock_open): Mocked open function to simulate file reading. + """ + actual_data = self.data_module._load_data_from_file("fake/file/path.csv") + + first_instance = next(actual_data) + + # Check for required keys + required_keys = ["features", "labels", "ident"] + for key in required_keys: + self.assertIn( + key, first_instance, f"'{key}' key is missing in the output data." + ) + + self.assertTrue( + all(isinstance(feature, int) for feature in first_instance["features"]), + "Not all elements in 'features' are integers.", + ) + + # Check that 'features' can be converted to a tensor + features = first_instance["features"] + try: + tensor_features = torch.tensor(features) + self.assertTrue( + tensor_features.ndim > 0, + "'features' should be convertible to a non-empty tensor.", + ) + except Exception as e: + self.fail(f"'features' cannot be converted to a tensor: {str(e)}") + + @patch( + "builtins.open", + new_callable=mock_open, + read_data=Tox21MolNetMockData.get_raw_data(), + ) + @patch("torch.save") + def test_setup_processed_simple_split( + self, + mock_torch_save: MagicMock, + mock_open_file: mock_open, + ) -> None: + """ + Test the `setup_processed` method for basic data splitting and saving. + + Args: + mock_torch_save (MagicMock): Mocked `torch.save` function to avoid actual file writes. + mock_open_file (mock_open): Mocked `open` function to simulate file reading. + """ + self.data_module.setup_processed() + + # Verify if torch.save was called for each split (train, test, validation) + self.assertEqual( + mock_torch_save.call_count, 3, "Expected torch.save to be called 3 times." + ) + call_args_list = mock_torch_save.call_args_list + self.assertIn("test", call_args_list[0][0][1], "Missing 'test' split.") + self.assertIn("train", call_args_list[1][0][1], "Missing 'train' split.") + self.assertIn( + "validation", call_args_list[2][0][1], "Missing 'validation' split." + ) + + # Check for non-overlap between train, test, and validation splits + test_split: List[str] = [d["ident"] for d in call_args_list[0][0][0]] + train_split: List[str] = [d["ident"] for d in call_args_list[1][0][0]] + validation_split: List[str] = [d["ident"] for d in call_args_list[2][0][0]] + + self.assertTrue( + set(train_split).isdisjoint(test_split), + "Overlap detected between the train and test splits.", + ) + self.assertTrue( + set(train_split).isdisjoint(validation_split), + "Overlap detected between the train and validation splits.", + ) + self.assertTrue( + set(test_split).isdisjoint(validation_split), + "Overlap detected between the test and validation splits.", + ) + + @patch.object( + Tox21MolNet, + "_load_data_from_file", + return_value=Tox21MolNetMockData.get_processed_grouped_data(), + ) + @patch("torch.save") + def test_setup_processed_with_group_split( + self, mock_torch_save: MagicMock, mock_load_file: MagicMock + ) -> None: + """ + Test the `setup_processed` method for group-based splitting and saving. + + Args: + mock_torch_save (MagicMock): Mocked `torch.save` function to avoid actual file writes. + mock_load_file (MagicMock): Mocked `_load_data_from_file` to provide custom data. + """ + self.data_module.train_split = 0.5 + self.data_module.setup_processed() + + # Verify if torch.save was called for each split + self.assertEqual( + mock_torch_save.call_count, 3, "Expected torch.save to be called 3 times." + ) + call_args_list = mock_torch_save.call_args_list + self.assertIn("test", call_args_list[0][0][1], "Missing 'test' split.") + self.assertIn("train", call_args_list[1][0][1], "Missing 'train' split.") + self.assertIn( + "validation", call_args_list[2][0][1], "Missing 'validation' split." + ) + + # Check for non-overlap between train, test, and validation splits (based on 'ident') + test_split: List[str] = [d["ident"] for d in call_args_list[0][0][0]] + train_split: List[str] = [d["ident"] for d in call_args_list[1][0][0]] + validation_split: List[str] = [d["ident"] for d in call_args_list[2][0][0]] + + self.assertTrue( + set(train_split).isdisjoint(test_split), + "Overlap detected between the train and test splits (based on 'ident').", + ) + self.assertTrue( + set(train_split).isdisjoint(validation_split), + "Overlap detected between the train and validation splits (based on 'ident').", + ) + self.assertTrue( + set(test_split).isdisjoint(validation_split), + "Overlap detected between the test and validation splits (based on 'ident').", + ) + + # Check for non-overlap between train, test, and validation splits (based on 'group') + test_split_grp: List[str] = [d["group"] for d in call_args_list[0][0][0]] + train_split_grp: List[str] = [d["group"] for d in call_args_list[1][0][0]] + validation_split_grp: List[str] = [d["group"] for d in call_args_list[2][0][0]] + + self.assertTrue( + set(train_split_grp).isdisjoint(test_split_grp), + "Overlap detected between the train and test splits (based on 'group').", + ) + self.assertTrue( + set(train_split_grp).isdisjoint(validation_split_grp), + "Overlap detected between the train and validation splits (based on 'group').", + ) + self.assertTrue( + set(test_split_grp).isdisjoint(validation_split_grp), + "Overlap detected between the test and validation splits (based on 'group').", + ) + + +if __name__ == "__main__": + unittest.main()