This repository has been archived by the owner on Nov 29, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path Dataset.py
43 lines (38 loc) · 1.81 KB
/
Dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from PIL import Image
from torchvision.datasets.vision import *
from torchvision.transforms import ToTensor
import glob
import os
class ImgDataset(VisionDataset):
    """Paired low-/high-resolution image dataset for super-resolution.

    Every image file found (recursively) under ``root/HR_dir`` becomes one
    sample. Its low-resolution partner is looked up directly inside
    ``root/LR_dir`` -- sub-directories of ``HR_dir`` are not mirrored -- under
    the name ``prefix + <HR stem> + subfix + <HR extension>``. For example,
    with ``subfix='x2'`` the HR file ``HR/0001.png`` pairs with
    ``LR/0001x2.png``.

    Args:
        root: Dataset root directory, passed to ``VisionDataset``.
        transforms: Joint callable ``(img, target) -> (img, target)``.
        transform: Callable applied to the upscaled LR image only.
        target_transform: Callable applied to the HR image only.
        scale: Factor by which the LR image is bicubically upscaled in
            ``__getitem__``.
        LR_dir: Name of the low-resolution sub-directory under ``root``.
        HR_dir: Name of the high-resolution sub-directory under ``root``.
        prefix: String prepended to the HR stem to build the LR file name.
        subfix: String appended to the HR stem (before the extension) to
            build the LR file name.

    Pass either ``transforms`` or ``transform``/``target_transform``. When any
    transform is supplied it is responsible for converting the PIL images to
    tensors; otherwise both images are converted with ``ToTensor``.
    """

    # Accepted image extensions (compared case-insensitively, without the dot).
    _filetype = ('jpg', 'jpeg', 'png', 'ppm', 'bmp', 'pgm', 'tif', 'tiff', 'webp')

    def __init__(self, root: str, transforms: Optional[Callable] = None, transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None, scale: float = 2, LR_dir: str = 'LR',
                 HR_dir: str = 'HR', prefix: str = '', subfix: str = '') -> None:
        super().__init__(root, transforms, transform, target_transform)
        self.scale = scale
        # ``self.root`` is set by the base class; re-assigning the raw ``root``
        # argument here would discard any normalisation it applies (e.g. ``~``).
        self.lr_dir = os.path.join(self.root, LR_dir)
        self.hr_dir = os.path.join(self.root, HR_dir)
        # Escape the directory so characters such as ``[`` in the path are not
        # treated as wildcards, and sort so the sample order is deterministic.
        hr_files = sorted(glob.glob(os.path.join(glob.escape(self.hr_dir), '**'), recursive=True))
        self.imgs = []  # list of (lr_path, hr_path) tuples
        for hr_file in hr_files:
            if os.path.isdir(hr_file):
                continue
            stem, ext = os.path.splitext(os.path.basename(hr_file))
            # Skip non-image files; this also drops extension-less files, for
            # which no LR file name could be built.
            if ext[1:].lower() not in self._filetype:
                continue
            lr_file = os.path.join(self.lr_dir, prefix + stem + subfix + ext)
            self.imgs.append((lr_file, hr_file))
        self.length = len(self.imgs)
        self.toTensor = ToTensor()

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return the ``(upscaled LR image, HR image)`` pair at ``index``.

        Both images are loaded as RGB. The LR image is resized bicubically by
        ``self.scale`` before any transform is applied. Without transforms,
        both images are returned as tensors produced by ``ToTensor``.
        """
        lr_path, hr_path = self.imgs[index]
        # ``convert`` returns a new image, so it stays usable after the
        # ``with`` block closes the underlying file.
        with Image.open(lr_path) as lr_im:
            img = lr_im.convert('RGB')
        with Image.open(hr_path) as hr_im:
            target = hr_im.convert('RGB')
        # Upscale LR by ``scale`` (presumably to the HR size); ``round`` keeps
        # the size integral when a non-integer scale is used.
        new_size = (round(img.size[0] * self.scale), round(img.size[1] * self.scale))
        img = img.resize(new_size, Image.Resampling.BICUBIC)
        # Check ``self.transforms`` (not ``self.transform``): the base class
        # stores the joint/combined transform there, so a joint ``transforms``
        # or a lone ``target_transform`` is no longer silently ignored.
        if self.transforms is not None:
            img, target = self.transforms(img, target)
        else:
            img = self.toTensor(img)
            target = self.toTensor(target)
        return img, target

    def __len__(self) -> int:
        """Return the number of LR/HR pairs discovered at construction."""
        return self.length