7
7
import os
8
8
import numpy as np
9
9
import random
10
+ import torch
11
+ from torch .autograd import Variable
10
12
import skimage
11
13
import skimage .io
12
14
import scipy .misc
13
15
16
+ from torchvision import transforms as trn
17
+ preprocess = trn .Compose ([
18
+ #trn.ToTensor(),
19
+ trn .Normalize ([0.485 , 0.456 , 0.406 ], [0.229 , 0.224 , 0.225 ])
20
+ ])
21
+
22
+ from misc .resnet_utils import myResnet
23
+ import misc .resnet as resnet
24
+
25
+ resnet = resnet .resnet101 ()
26
+ resnet .load_state_dict (torch .load ('/home-nfs/rluo/rluo/model/pytorch-resnet/resnet101.pth' ))
27
+ my_resnet = myResnet (resnet )
28
+ my_resnet .cuda ()
29
+ my_resnet .eval ()
30
+
14
31
class DataLoaderRaw ():
15
32
16
33
def __init__ (self , opt ):
@@ -65,7 +82,8 @@ def get_batch(self, split, batch_size=None):
65
82
batch_size = batch_size or self .batch_size
66
83
67
84
# pick an index of the datapoint to load next
68
- img_batch = np .ndarray ([batch_size , 224 ,224 ,3 ], dtype = 'float32' )
85
+ fc_batch = np .ndarray ((batch_size , 2048 ), dtype = 'float32' )
86
+ att_batch = np .ndarray ((batch_size , 14 , 14 , 2048 ), dtype = 'float32' )
69
87
max_index = self .N
70
88
wrapped = False
71
89
infos = []
@@ -85,15 +103,22 @@ def get_batch(self, split, batch_size=None):
85
103
img = img [:,:,np .newaxis ]
86
104
img = img .concatenate ((img , img , img ), axis = 2 )
87
105
88
- img_batch [i ] = img [16 :240 , 16 :240 , :].astype ('float32' )/ 255.0
106
+ img = img .astype ('float32' )/ 255.0
107
+ img = torch .from_numpy (img .transpose ([2 ,0 ,1 ])).cuda ()
108
+ img = Variable (preprocess (img ), volatile = True )
109
+ tmp_fc , tmp_att = my_resnet (img )
110
+
111
+ fc_batch [i ] = tmp_fc .data .cpu ().float ().numpy ()
112
+ att_batch [i ] = tmp_att .data .cpu ().float ().numpy ()
89
113
90
114
info_struct = {}
91
115
info_struct ['id' ] = self .ids [ri ]
92
116
info_struct ['file_path' ] = self .files [ri ]
93
117
infos .append (info_struct )
94
118
95
119
data = {}
96
- data ['images' ] = img_batch
120
+ data ['fc_feats' ] = fc_batch
121
+ data ['att_feats' ] = att_batch
97
122
data ['bounds' ] = {'it_pos_now' : self .iterator , 'it_max' : self .N , 'wrapped' : wrapped }
98
123
data ['infos' ] = infos
99
124
0 commit comments