Skip to content

Commit

Permalink
test(datamodules): Reduce boilerplate with toy_project
Browse files Browse the repository at this point in the history
  • Loading branch information
adosar committed Dec 13, 2024
1 parent ed988a7 commit 3a00ebd
Showing 1 changed file with 10 additions and 17 deletions.
27 changes: 10 additions & 17 deletions tests/test_datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,13 @@
from itertools import combinations
from torch.utils.data import RandomSampler, SequentialSampler
from aidsorb.utils import pcd_from_dir
from aidsorb.data import prepare_data, Collator
from aidsorb.data import prepare_data, Collator, get_names
from aidsorb.transforms import Center, RandomRotation
from aidsorb.datamodules import PCDDataModule


class TestPCDDataModule(unittest.TestCase):
def setUp(self):
self.tempdir = tempfile.TemporaryDirectory(dir='/tmp')
self.outname = os.path.join(self.tempdir.name, 'pcds.npz')
self.split_ratio = [3, 1, 2]

pcd_from_dir(dirname='tests/structures', outname=self.outname)
prepare_data(source=self.outname, split_ratio=self.split_ratio)

# Arguments for the datamodule.
self.train_size = 2
self.train_trans_x = Center()
Expand All @@ -51,13 +44,13 @@ def setUp(self):
self.eval_bs = 2
self.config_dataloaders = {
'pin_memory': True,
'num_workers': 2,
'num_workers': 4,
'collate_fn': Collator()
}

# Instantiate the datamodule.
self.dm = PCDDataModule(
path_to_X=self.outname,
path_to_X='tests/dummy/toy_project/pcd_data',
path_to_Y='tests/dummy/toy_dataset.csv',
index_col='id',
labels=['y2', 'y3'],
Expand All @@ -76,9 +69,12 @@ def setUp(self):

def test_datasets(self):
# Check that the datasets have the correct size.
val_names = get_names('tests/dummy/toy_project/validation.json')
test_names = get_names('tests/dummy/toy_project/test.json')

self.assertEqual(len(self.dm.train_dataset), self.train_size)
self.assertEqual(len(self.dm.validation_dataset), self.split_ratio[1])
self.assertEqual(len(self.dm.test_dataset), self.split_ratio[2])
self.assertEqual(len(self.dm.validation_dataset), len(val_names))
self.assertEqual(len(self.dm.test_dataset), len(test_names))

# The pairwise intersections must be the empty set.
for ds_comb in combinations([
Expand All @@ -101,7 +97,6 @@ def test_datasets(self):

self.assertIs(ds.transform_y, self.trans_y)


def test_dataloaders(self):
dataloaders = [
self.dm.train_dataloader(),
Expand All @@ -114,8 +109,6 @@ def test_dataloaders(self):
self.dm.test_dataset,
]

passed_collate_fn = self.config_dataloaders['collate_fn']

for i, dl in enumerate(dataloaders):
# Check that dataloaders use appropriate settings.
if i == 0:
Expand All @@ -125,7 +118,7 @@ def test_dataloaders(self):
self.assertIsInstance(dl.sampler, SequentialSampler)
self.assertEqual(dl.batch_size, self.eval_bs)

self.assertEqual(dl.collate_fn, passed_collate_fn)
self.assertEqual(dl.collate_fn, self.config_dataloaders['collate_fn'])

# Check that collate function is used properly.
for x, y in dl:
Expand All @@ -137,4 +130,4 @@ def test_dataloaders(self):
self.assertIs(dl.dataset, ds)

def tearDown(self):
self.tempdir.cleanup()
...

0 comments on commit 3a00ebd

Please sign in to comment.