diff --git a/tests/test_dataloaders.py b/tests/test_dataloaders.py index 5e2a701..dc2ce62 100644 --- a/tests/test_dataloaders.py +++ b/tests/test_dataloaders.py @@ -1,5 +1,6 @@ import pytest import torchxrayvision as xrv +from skimage.io import imread, imsave dataset_classes = [xrv.datasets.NIH_Dataset, xrv.datasets.PC_Dataset, @@ -45,3 +46,21 @@ def test_dataloader_merging_incorrect_alignment(): assert "incorrect pathology alignment" in str(excinfo.value) + +def test_resize(): + + for filename in ["16747_3_1.jpg", "covid-19-pneumonia-58-prior.jpg"] + img = imread(filename) + img = xrv.datasets.normalize(img, 255) + + # Check that images are 2D arrays + if len(img.shape) > 2: + img = img[:, :, 0] + + # Add color channel + img = img[None, :, :] + + resize_ski = xrv.datasets.XRayResizer(100, engine="skimage") + resize_cv2 = xrv.datasets.XRayResizer(100, engine="cv2") + + assert(np.allclose(resize_ski(img),resize_cv2(img)))