diff --git a/tests/test_datamodules.py b/tests/test_datamodules.py index 2ed465b..17256ef 100644 --- a/tests/test_datamodules.py +++ b/tests/test_datamodules.py @@ -27,20 +27,13 @@ from itertools import combinations from torch.utils.data import RandomSampler, SequentialSampler from aidsorb.utils import pcd_from_dir -from aidsorb.data import prepare_data, Collator +from aidsorb.data import prepare_data, Collator, get_names from aidsorb.transforms import Center, RandomRotation from aidsorb.datamodules import PCDDataModule class TestPCDDataModule(unittest.TestCase): def setUp(self): - self.tempdir = tempfile.TemporaryDirectory(dir='/tmp') - self.outname = os.path.join(self.tempdir.name, 'pcds.npz') - self.split_ratio = [3, 1, 2] - - pcd_from_dir(dirname='tests/structures', outname=self.outname) - prepare_data(source=self.outname, split_ratio=self.split_ratio) - # Arguments for the datamodule. self.train_size = 2 self.train_trans_x = Center() @@ -51,13 +44,13 @@ def setUp(self): self.eval_bs = 2 self.config_dataloaders = { 'pin_memory': True, - 'num_workers': 2, + 'num_workers': 4, 'collate_fn': Collator() } # Instantiate the datamodule. self.dm = PCDDataModule( - path_to_X=self.outname, + path_to_X='tests/dummy/toy_project/pcd_data', path_to_Y='tests/dummy/toy_dataset.csv', index_col='id', labels=['y2', 'y3'], @@ -76,9 +69,12 @@ def setUp(self): def test_datasets(self): # Check that the datasets have the correct size. + val_names = get_names('tests/dummy/toy_project/validation.json') + test_names = get_names('tests/dummy/toy_project/test.json') + self.assertEqual(len(self.dm.train_dataset), self.train_size) - self.assertEqual(len(self.dm.validation_dataset), self.split_ratio[1]) - self.assertEqual(len(self.dm.test_dataset), self.split_ratio[2]) + self.assertEqual(len(self.dm.validation_dataset), len(val_names)) + self.assertEqual(len(self.dm.test_dataset), len(test_names)) # The pairwise intersections must be the empty set. for ds_comb in combinations([ @@ -101,7 +97,6 @@ def test_datasets(self): self.assertIs(ds.transform_y, self.trans_y) - def test_dataloaders(self): dataloaders = [ self.dm.train_dataloader(), @@ -114,8 +109,6 @@ def test_dataloaders(self): self.dm.test_dataset, ] - passed_collate_fn = self.config_dataloaders['collate_fn'] - for i, dl in enumerate(dataloaders): # Check that dataloaders use appropriate settings. if i == 0: @@ -125,7 +118,7 @@ def test_dataloaders(self): self.assertIsInstance(dl.sampler, SequentialSampler) self.assertEqual(dl.batch_size, self.eval_bs) - self.assertEqual(dl.collate_fn, passed_collate_fn) + self.assertEqual(dl.collate_fn, self.config_dataloaders['collate_fn']) # Check that collate function is used properly. for x, y in dl: @@ -137,4 +130,4 @@ def test_dataloaders(self): self.assertIs(dl.dataset, ds) def tearDown(self): - self.tempdir.cleanup() + ...