-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathdataloader_mtl.py
73 lines (54 loc) · 2.31 KB
/
dataloader_mtl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch.utils.data as data
import pandas as pd
import os
from PIL import Image
class ArtDatasetMTL(data.Dataset):
    """Multi-task dataset yielding an image and its four attribute class ids.

    Each sample is read from a tab-separated file with the columns
    ``IMAGE_FILE``, ``TYPE``, ``SCHOOL``, ``TIMEFRAME`` and ``AUTHOR``. The
    attribute strings are mapped to integer ids through the vocabularies
    supplied in ``att2i``; strings missing from a vocabulary map to that
    vocabulary's ``'UNK'`` entry.
    """

    # Attribute of ``args_dict`` that names the CSV file of each split.
    _CSV_ATTR_BY_SET = {
        'train': 'csvtrain',
        'val': 'csvval',
        'test': 'csvtest',
    }

    def __init__(self, args_dict, set, att2i, transform=None):
        """
        Args:
            args_dict: parameters object exposing ``dir_dataset``,
                ``dir_images`` and the ``csvtrain`` / ``csvval`` / ``csvtest``
                attributes (only the one matching ``set`` is read).
            set: 'train', 'val', 'test'
            att2i: list of attribute vocabularies as
                [type2idx, school2idx, time2idx, author2idx]
            transform: data transform applied to the loaded RGB image

        Raises:
            ValueError: if ``set`` is not one of 'train', 'val', 'test'.
        """
        self.args_dict = args_dict
        self.set = set

        # Fail early with a clear message instead of an UnboundLocalError
        # further down when the split name is misspelled.
        if self.set not in self._CSV_ATTR_BY_SET:
            raise ValueError(
                "set must be one of 'train', 'val', 'test'; got %r" % (self.set,)
            )

        # Load data
        csv_name = getattr(args_dict, self._CSV_ATTR_BY_SET[self.set])
        textfile = os.path.join(args_dict.dir_dataset, csv_name)
        df = pd.read_csv(textfile, delimiter='\t', encoding='Cp1252')

        self.imagefolder = os.path.join(args_dict.dir_dataset, args_dict.dir_images)
        self.transform = transform

        self.type_vocab = att2i[0]
        self.school_vocab = att2i[1]
        self.time_vocab = att2i[2]
        self.author_vocab = att2i[3]

        self.imageurls = list(df['IMAGE_FILE'])
        self.type = list(df['TYPE'])
        self.school = list(df['SCHOOL'])
        self.time = list(df['TIMEFRAME'])
        self.author = list(df['AUTHOR'])

    def __len__(self):
        """Return the number of samples (rows read from the CSV file)."""
        return len(self.imageurls)

    def class_from_name(self, vocab, name):
        """Return the class id of ``name`` in ``vocab``.

        Falls back to ``vocab['UNK']`` when ``name`` is not in the vocabulary.
        ``'UNK'`` is only looked up on a miss, so a vocabulary without it
        raises ``KeyError`` only for unknown names.
        """
        # ``dict.has_key`` does not exist in Python 3; ``in`` works everywhere.
        if name in vocab:
            return vocab[name]
        return vocab['UNK']

    def __getitem__(self, index):
        """Load sample ``index``.

        Returns:
            A pair ``([image], [type_id, school_id, time_id, author_id])``
            where ``image`` is the RGB image after ``transform`` (if any).
        """
        # Load image & apply transformation.
        # os.path.join inserts the separator only when ``imagefolder`` does
        # not already end with one, so both 'Images' and 'Images/' work.
        imagepath = os.path.join(self.imagefolder, self.imageurls[index])
        # ``convert`` returns a new, fully loaded image, so the underlying
        # file can be closed right away instead of waiting for GC.
        with Image.open(imagepath) as raw_image:
            image = raw_image.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        # Attribute class
        type_idclass = self.class_from_name(self.type_vocab, self.type[index])
        school_idclass = self.class_from_name(self.school_vocab, self.school[index])
        time_idclass = self.class_from_name(self.time_vocab, self.time[index])
        author_idclass = self.class_from_name(self.author_vocab, self.author[index])

        return [image], [type_idclass, school_idclass, time_idclass, author_idclass]