-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathimageLoader.lua
145 lines (124 loc) · 5.77 KB
/
imageLoader.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
local threadPool = require('threadPool')
local util = require('./util')
local torchUtil = require('torchUtil')
local M = {}
--- Run torchUtil.filterFileList over randomly chosen scene categories.
-- Seeds Lua's RNG from the wall clock, then makes `opt.sceneCategoryCount * 10`
-- passes. Each pass draws one category with math.random (with replacement), so
-- a given category may be filtered several times or not at all.
-- For a drawn category the input list is `<imageListBase><id>.txt` and the
-- output is `<imageListBase><id>_filtered.txt`, where `<id>` is
-- `util.zeroPad(category, 3)` (presumably a 3-digit zero-padded number --
-- confirm against util).
-- @param opt table; reads `sceneCategoryCount` and `imageListBase`
function M.filterAllFileLists(opt)
  math.randomseed(os.time())

  -- Number of passes per category; the original author named this a
  -- "multithread factor", so it presumably exists to spread work when several
  -- processes run this at once -- TODO confirm.
  local passesPerCategory = 10
  local totalPasses = opt.sceneCategoryCount * passesPerCategory

  for _ = 1, totalPasses do
    local picked = math.random(opt.sceneCategoryCount)
    local stem = opt.imageListBase .. util.zeroPad(picked, 3)
    torchUtil.filterFileList(stem .. '.txt', stem .. '_filtered.txt')
  end
end
--- Create the loader state consumed by M.sampleBatch.
-- Prints the list base path, creates a thread pool via
-- threadPool.makeThreadPool(opt), and reads one `_filtered.txt` file list per
-- category (1 .. opt.sceneCategoryCount) with util.readAllLines.
-- @param opt table; reads `imageListBase` and `sceneCategoryCount`, and is
--   stored on the result unchanged
-- @return table with fields `donkeys` (the thread pool), `opt`, and
--   `imageLists` (array of per-category lists, indexed by category)
function M.makeImageLoader(opt)
  print('Initializing images from: ' .. opt.imageListBase)

  local loader = {
    donkeys = threadPool.makeThreadPool(opt),
    opt = opt,
    imageLists = {},
  }

  local lists = loader.imageLists
  for category = 1, opt.sceneCategoryCount do
    local listPath = opt.imageListBase .. util.zeroPad(category, 3) .. '_filtered.txt'
    lists[#lists + 1] = util.readAllLines(listPath)
  end

  return loader
end
--- Load an image as a 3-channel float tensor and rescale it so that its
-- shorter side equals `opt.imageSize`, keeping the aspect ratio.
-- An image that is already exactly imageSize x imageSize is returned as loaded.
-- @param path image file path passed to image.load
-- @param opt table; reads `imageSize`
-- @return the loaded (and possibly rescaled) image tensor
local function loadAndResizeImage(path, opt)
  local target = opt.imageSize
  local img = image.load(path, 3, 'float')
  -- Dimension 2 is the height and dimension 3 the width.
  local height, width = img:size(2), img:size(3)

  if height == target and width == target then
    return img
  end

  -- image.scale takes (src, width, height): pin the shorter side to `target`
  -- and stretch the longer one by the same factor.
  if width < height then
    return image.scale(img, target, target * height / width)
  end
  return image.scale(img, target * width / height, target)
end
--- Load an image and jitter it: random square crop plus a random horizontal flip.
-- The image is first loaded and resized by loadAndResizeImage, then a
-- `opt.cropSize` x `opt.cropSize` window is cut out at a uniformly random
-- position and mirrored with probability 0.5.
-- @param path image file path
-- @param opt table; reads `cropSize` (and whatever loadAndResizeImage reads)
-- @return the cropped, possibly flipped image tensor (3 x cropSize x cropSize)
-- Raises an error if the resized image is smaller than the crop window.
local function loadAndCropImage(path, opt)
  local oW, oH = opt.cropSize, opt.cropSize
  collectgarbage()
  local input = loadAndResizeImage(path, opt)
  local iW, iH = input:size(3), input:size(2)

  -- Fail early with a readable message instead of an opaque crop/assert error.
  if iW < oW or iH < oH then
    error(string.format(
      'loadAndCropImage: %s is %sx%s after resizing, smaller than the %sx%s crop',
      tostring(path), tostring(iW), tostring(iH), tostring(oW), tostring(oH)))
  end

  -- image.crop takes a zero-based top-left corner and an exclusive
  -- bottom-right corner, so valid offsets are 0 .. (image size - crop size).
  -- Draw uniformly over that whole range (offset 0 included); the min() guards
  -- against the generator returning exactly the upper bound. When the image
  -- already matches the crop size the only possible offset is 0.
  local maxH, maxW = iH - oH, iW - oW
  local h1 = math.min(math.floor(torch.uniform(0, maxH + 1)), maxH)
  local w1 = math.min(math.floor(torch.uniform(0, maxW + 1)), maxW)

  local out = image.crop(input, w1, h1, w1 + oW, h1 + oH)
  assert(out:size(3) == oW)
  assert(out:size(2) == oH)

  -- do hflip with probability 0.5
  if torch.uniform() > 0.5 then out = image.hflip(out) end
  return out
end
--- Assemble one batch of randomly sampled images and their derived tensors.
-- For each of the `opt.batchSize` slots, a category and then a file from that
-- category's list are drawn with math.random. The per-image work is queued on
-- `imageLoader.donkeys` and the function returns after `donkeys:synchronize()`.
-- NOTE(review): addjob presumably runs its first function on a worker and the
-- second on the calling thread with the first one's results, as in the torch
-- `threads` library -- confirm against threadPool.
-- The full-resolution LAB image and the AB-channel targets that earlier
-- revisions also produced are currently disabled.
-- @param imageLoader table with `opt`, `imageLists` and `donkeys`
--   (the fields M.makeImageLoader sets)
-- @return table with tensors:
--   images               batchSize x 3 x cropSize x cropSize (float)
--   grayscaleInputs      batchSize x 1 x cropSize x cropSize (float)
--   RGBTargets           batchSize x 3 x halfCropSize x halfCropSize (float)
--   thumbnails           batchSize x 3 x thumbnailSize x thumbnailSize (float)
--   normalizedThumbnails batchSize x 3 x thumbnailSize x thumbnailSize (float)
--   classLabels          batchSize (int), the sampled category index per slot
function M.sampleBatch(imageLoader)
  local opt = imageLoader.opt
  local imageLists = imageLoader.imageLists
  local donkeys = imageLoader.donkeys

  local count = opt.batchSize
  local cropSide, halfSide, thumbSide = opt.cropSize, opt.halfCropSize, opt.thumbnailSize

  -- Output tensors are allocated up front and filled slot by slot below.
  local batch = {
    images = torch.FloatTensor(count, 3, cropSide, cropSide),
    grayscaleInputs = torch.FloatTensor(count, 1, cropSide, cropSide),
    RGBTargets = torch.FloatTensor(count, 3, halfSide, halfSide),
    thumbnails = torch.FloatTensor(count, 3, thumbSide, thumbSide),
    normalizedThumbnails = torch.FloatTensor(count, 3, thumbSide, thumbSide),
    classLabels = torch.IntTensor(count),
  }

  for slot = 1, count do
    -- Pick a category first, then an image within that category.
    local category = math.random(#imageLists)
    batch.classLabels[slot] = category
    local candidates = imageLists[category]
    local imagePath = candidates[math.random(#candidates)]

    donkeys:addjob(
      -- Job: load one image and derive every per-image tensor from it.
      function()
        local cropped = loadAndCropImage(imagePath, opt)

        -- Luminance channel; the original comment notes y is in 0 - 1, so
        -- subtracting 0.5 centres it on zero.
        local gray = image.rgb2y(cropped)
        gray:add(-0.5)

        -- Half-resolution colour target, preprocessed on a copy.
        local halfRes = image.scale(cropped, opt.halfCropSize, opt.halfCropSize)
        local rgbTarget = torchUtil.caffePreprocess(halfRes:clone())

        -- Thumbnail in LAB space, raw and normalized.
        local smallRgb = image.scale(cropped, opt.thumbnailSize, opt.thumbnailSize)
        local labThumb = image.rgb2lab(smallRgb)
        local labThumbNorm = torchUtil.normalizeLab(labThumb)

        return cropped, gray, rgbTarget, labThumb, labThumbNorm
      end,
      -- Completion callback: copy the job's results into this slot.
      function(cropped, gray, rgbTarget, labThumb, labThumbNorm)
        batch.images[slot] = cropped
        batch.grayscaleInputs[slot] = gray
        batch.RGBTargets[slot] = rgbTarget
        batch.thumbnails[slot] = labThumb
        batch.normalizedThumbnails[slot] = labThumbNorm
      end)
  end

  donkeys:synchronize()
  return batch
end
return M