forked from edenton/drnet
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing 12 changed files with 255 additions and 6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
logs | ||
*.swp |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
require 'torch' | ||
require 'paths' | ||
require 'image' | ||
require 'utils' | ||
debugger = require 'fb.debugger' | ||
|
||
|
||
-- Dataset class for the KTH action-recognition videos, registered with
-- torch's class system under the public name 'KTHLoader'.
local KTHDataset = torch.class('KTHLoader')

-- Register a stub 'dataLoader' class exactly once; torch.class raises an
-- error if the same class name is registered twice.
if not torch.getmetatable('dataLoader') then
  torch.class('dataLoader')
end
--- Load the KTH metadata file and build the class list.
-- @param opt        options table; reads opt.dataRoot (dataset directory)
-- @param data_type  split name, e.g. 'train' or 'test' (selects <split>_meta.t7)
-- Side effects: prints per-class video/frame statistics.
function KTHDataset:__init(opt, data_type)
  self.data_type = data_type
  self.opt = opt or {}
  self.path = self.opt.dataRoot
  self.data = torch.load(('%s/%s_meta.t7'):format(self.path, data_type))
  self.classes = {}
  for c, _ in pairs(self.data) do
    table.insert(self.classes, c)
  end

  print(('\n<loaded KTH %s data>'):format(data_type))
  local N = 0
  -- math.huge instead of the old hard-coded 100: a dataset whose shortest
  -- sequence exceeds 100 frames would otherwise report min length 100.
  local shortest = math.huge
  local longest = 0
  for _, c in pairs(self.classes) do
    local n = 0
    local data = self.data[c]
    for i = 1, #data do
      for d = 1, #data[i].indices do
        local len = data[i].indices[d][2] - data[i].indices[d][1] + 1
        if len < 0 then debugger.enter() end
        shortest = math.min(shortest, len)
        longest = math.max(longest, len)
      end
      n = n + data[i].n
    end
    -- BUG FIX: N was accumulated inside the per-video loop as `N = N + n`
    -- while n is a running total, over-counting frames quadratically.
    -- Add each class's frame count exactly once.
    N = N + n
    print(('%s: %d videos (%d total frames)'):format(c, #data, n))
  end
  self.N = N
  print('total frame: ' .. N)
  print(('min seq length = %d frames'):format(shortest))
  print(('max seq length = %d frames'):format(longest))
end
|
||
--- Total number of frames across the whole split (computed in __init).
function KTHDataset:size()
  return self.N
end
|
||
--- Fill x in place with a random sub-sequence of T frames.
-- @param x      Tensor of size (T, ...); one frame copied per row
-- @param delta  optional max frame step; the step is drawn uniformly from
--               [1, delta] (falls back to opt.delta, then 1)
-- @return false when no valid sub-sequence fits; otherwise true and the
--         index of the sampled class within self.classes.
function KTHDataset:getSequence(x, delta)
  local delta = math.random(1, delta or self.opt.delta or 1)
  -- BUG FIX: the class index was never captured, so the second return
  -- value (c_idx) was always nil.
  local c_idx = math.random(#self.classes)
  local c = self.classes[c_idx]
  local vid = self.data[c][math.random(#self.data[c])]
  local seq = math.random(#vid.indices)
  local seq_length = vid.indices[seq][2] - vid.indices[seq][1] + 1
  local basename = ('%s/%s/%s/'):format(self.path, c, vid.vid)

  -- Shrink the step until T frames fit inside the chosen sequence.
  local T = x:size(1)
  while T*delta > seq_length do
    delta = delta - 1
    if delta < 1 then return false end
  end

  -- BUG FIX: '+ 1' keeps math.random's interval non-empty when
  -- T*delta == seq_length (math.random(0) raises "interval is empty").
  -- Largest frame touched is start + seq_length - delta, still in range.
  local offset = math.random(seq_length - T*delta + 1)
  local start = vid.indices[seq][1]
  for t = 1, T do
    local tt = start + offset + (t-1)*delta - 1
    local img = image.load(('%s/image-%03d_%dx%d.png'):format(basename, tt, self.opt.imageSize, self.opt.imageSize))[1]
    x[t]:copy(img)
  end
  return true, c_idx
end
|
||
--- Build a batch of n sequences of length T.
-- Returns a table of T tensors, each of size (n, unpack(opt.geometry)):
-- entry t holds frame t for every batch element.
function KTHDataset:getBatch(n, T, delta)
  -- Scratch tensor for one sequence, reused for every batch element.
  local seq = torch.Tensor(T, unpack(self.opt.geometry))
  local batch = {}
  for t = 1, T do
    batch[t] = torch.Tensor(n, unpack(self.opt.geometry))
  end
  for i = 1, n do
    -- getSequence can fail (sequence too short for T*delta): retry.
    repeat until self:getSequence(seq, delta)
    for t = 1, T do
      batch[t][i]:copy(seq[t])
    end
  end
  return batch
end
|
||
--- Render a contact sheet of sampled sequences (one sequence per row)
-- and save it to fname.
function KTHDataset:plotSeq(fname)
  print('plotting sequence: ' .. fname)
  local T, rows = 30, 50
  local to_plot = {}
  for _ = 1, rows do
    local batch = self:getBatch(1, T)
    for t = 1, T do
      to_plot[#to_plot + 1] = batch[t][1]
    end
  end
  image.save(fname, image.toDisplayTensor{input=to_plot, scaleeach=false, nrow=T})
end
|
||
--- Save a sample sequence sheet under <opt.save>/data/.
function KTHDataset:plot()
  local savedir = self.opt.save .. '/data/'
  -- Quote the path: os.execute goes through the shell, so an unquoted
  -- opt.save containing spaces (or shell metacharacters) would break the
  -- command or worse.
  os.execute('mkdir -p "' .. savedir .. '"')
  self:plotSeq(savedir .. '/' .. self.data_type .. '_seq.png')
end
|
||
-- Module-level loader globals consumed by the training code. opt_t is set
-- when this file is loaded inside a worker thread (data/threads.lua);
-- otherwise the plain global opt is used.
trainLoader = KTHLoader(opt_t or opt, 'train')
valLoader   = KTHLoader(opt_t or opt, 'test')
File renamed without changes.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
-- Wrap the dataset-specific train/val loaders in threaded datasources.
require 'data.threads'

opt.T = opt.T or opt.maxStep or (opt.Tpast + opt.Tfuture)
local opt_tt = opt -- need local var, opt is global

--- Build a ThreadedDatasource whose worker threads load the dataset module
-- and hand back the loader that module defines under the global `name`
-- ('trainLoader' or 'valLoader'). Factored out: the two constructions were
-- previously duplicated verbatim.
local function make_threaded_loader(name)
  return ThreadedDatasource(
    function()
      require 'pl'
      -- each worker thread gets its own copy of the options table
      opt_t = tablex.copy(opt_tt)
      require(('data.%s'):format(opt_t.dataset))
      return _G[name]
    end,
    {
      nThreads = opt_tt.nThreads,
      dataPool = math.ceil(opt_tt.dataPool / 10),
      dataWarmup = math.ceil(opt_tt.dataWarmup / 10),
    })
end

trainLoader = make_threaded_loader('trainLoader')
valLoader = make_threaded_loader('valLoader')

trainLoader:warm()
valLoader:warm()

cutorch.synchronize()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
--[[ | ||
batchSize and T are fixed throughout. | ||
--]] | ||
|
||
local threads = require 'threads'

-- Datasource wrapper that keeps a pool of pre-loaded batches filled by
-- background worker threads. (`parent` is nil here: the class has no base.)
local ThreadedDatasource, parent = torch.class('ThreadedDatasource')
|
||
--- Spin up the worker thread pool.
-- @param getDatasourceFun  closure run in each worker; must return the
--                          per-thread datasource object
-- @param params            table: nThreads, dataPool, dataWarmup
function ThreadedDatasource:__init(getDatasourceFun, params)
  self.nThreads = math.min(params.nThreads or 4, 2) -- XXX: fix (thread count currently capped at 2)
  local opt = opt -- capture the global opt as an upvalue for the init closure below

  self.pool_size = params.dataPool or opt.dataPool
  self.dataWarmup = params.dataWarmup or opt.dataWarmup
  self.pool = {} -- pool of ready batches, filled by fetch_batch callbacks
  --threads.Threads.serialization('threads.sharedserialize') --TODO
  self.threads = threads.Threads(self.nThreads,
    function(threadid)
      require 'torch'
      require 'math'
      require 'os'
      opt_t = opt -- expose options to thread-side code as the global opt_t
      -- print(opt_t)
      -- seed each worker differently so threads draw different samples
      torch.manualSeed(threadid*os.clock())
      math.randomseed(threadid*os.clock()*1.7)
      torch.setnumthreads(1) -- one BLAS thread per worker
      threadid_t = threadid
      datasource_t = getDatasourceFun()
    end)
  self.threads:synchronize() -- wait for every worker to finish initializing
  self.threads:specific(false)
end
|
||
--- Pre-fill the batch pool before training starts, then leave a few jobs
-- queued so the workers keep producing in the background.
function ThreadedDatasource:warm()
  print("Warming up batch pool...")
  for i = 1, self.dataWarmup do
    self:fetch_batch()

    -- Don't let the job queue get too big: drain it every 2*nThreads
    -- submissions. BUG FIX: the old `i % self.nThreads * 2 == 0` parses as
    -- `(i % self.nThreads) * 2 == 0`, i.e. every nThreads submissions.
    if i % (self.nThreads * 2) == 0 then
      self.threads:synchronize()
    end
    xlua.progress(i, self.dataWarmup)
  end

  -- get them working in the background
  for i = 1, self.nThreads * 2 do
    self:fetch_batch()
  end
end
|
||
--- Enqueue one asynchronous batch request.
-- A worker thread produces the batch; the main-thread callback stores it
-- in the pool, evicting a random entry once the pool is full.
function ThreadedDatasource:fetch_batch()
  local produce = function()
    collectgarbage()
    collectgarbage()
    return table.pack(datasource_t:getBatch(opt_t.batchSize, opt_t.T))
  end
  local store = function(batch)
    collectgarbage()
    collectgarbage()
    if #self.pool >= self.pool_size then
      -- pool full: overwrite a uniformly random slot
      self.pool[math.random(1, #self.pool)] = batch
    else
      self.pool[#self.pool + 1] = batch
    end
  end
  self.threads:addjob(produce, store)
end
|
||
--- Ask a worker thread to render its datasource's sample plot.
-- Blocks (via dojob) until one queued job has completed.
function ThreadedDatasource:plot()
  local job = function()
    collectgarbage()
    collectgarbage()
    datasource_t:plot()
  end
  self.threads:addjob(job, function() end)
  self.threads:dojob()
end
|
||
--- Return a batch sampled uniformly at random from the in-memory pool.
-- Non-blocking: if a worker has finished a batch, fold it into the pool
-- and immediately queue a replacement job; otherwise reuse pooled batches.
-- NOTE(review): assumes the pool is non-empty (warm() ran first) —
-- math.random(1, 0) would error on an empty pool; confirm callers.
function ThreadedDatasource:getBatch()
  if self.threads:haserror() then
    print("ThreadedDatasource: There is an error in a thread")
    self.threads:terminate()
    os.exit(0)
  end

  -- queue has something for us
  -- dojob to put the newly loaded batch into the pool
  if self.threads.mainqueue.isempty == 0 then
    self.threads:dojob()
    self:fetch_batch()
  end
  local batch_to_use = math.random(1, #self.pool)
  return unpack(self.pool[batch_to_use])
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters