-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataloader.py
26 lines (24 loc) · 967 Bytes
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
from torchvision import datasets
from torchvision import transforms
import torch.utils.data
class perchDataloader(datasets.ImageFolder):
    """ImageFolder variant that replaces the folder class label with a position-based target.

    Extends ``torchvision.datasets.ImageFolder``. The image for each index is
    exactly what ``ImageFolder`` produces (after ``transform``), but the
    folder-derived class label is discarded and replaced as follows:

    * index 0 is the hard-coded ground-truth sample and gets target ``-1``;
    * every other index ``i`` gets target ``[i - 1]`` (a one-element list).

    Image file paths are not returned.
    """

    def __init__(self, root, transform=None):
        """Build the dataset.

        Args:
            root: Root directory handed to ``ImageFolder``.
            transform: Optional callable applied to each loaded image. ``None``
                (the ``ImageFolder`` default) leaves images untransformed.
        """
        super().__init__(root, transform)

    def __getitem__(self, index):
        """Return ``(image, target)`` for ``index`` using the target rule in the class docstring."""
        # ImageFolder returns (image, class_index); only the image is kept.
        image = super().__getitem__(index)[0]
        if index == 0:
            # The first sample is hard-coded as the ground truth.
            return (image, -1)
        # NOTE(review): this target is a list while the ground-truth target is an
        # int, so default collation yields different structures for the two
        # cases. Confirm what downstream code expects before unifying them.
        return (image, [index - 1])
# Convenience factory that wires a perchDataloader dataset into a DataLoader.
class outputData:
    """Factory for the ``DataLoader`` built on top of ``perchDataloader``."""

    @staticmethod
    def loadedData(data_dir):
        """Return a ``DataLoader`` over the images found under ``data_dir``.

        Each image is converted with ``transforms.ToTensor()``. Declared as a
        static method so it works whether called on the class
        (``outputData.loadedData(d)``) or on an instance
        (``outputData().loadedData(d)``).

        Args:
            data_dir: Root directory in the layout expected by ``ImageFolder``.

        Returns:
            A ``torch.utils.data.DataLoader`` wrapping a ``perchDataloader``.
        """
        dataset = perchDataloader(data_dir, transform=transforms.ToTensor())
        # DataLoader defaults are used (batch_size=1, no shuffling).
        # NOTE(review): the dataset mixes int and list targets, which may not
        # collate with larger batches. Confirm before changing batch_size.
        return torch.utils.data.DataLoader(dataset)