Commit f32deed

kth files and data loader.

edenton committed Jun 7, 2017
1 parent 4fbea8d commit f32deed
Showing 12 changed files with 255 additions and 6 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -0,0 +1,2 @@
logs
*.swp
3 changes: 1 addition & 2 deletions README.md
@@ -12,11 +12,10 @@ To train the base model run:
 ```
 th train_drnet.lua
 ```
-or
+or the model with skip connections between content encoder and decoder:
 ```
 th train_drnet_skip.lua
 ```
-for the model with skip connections between content encoder and decoder.


To train an LSTM on the pose vectors run:
115 changes: 115 additions & 0 deletions data/kth.lua
@@ -0,0 +1,115 @@
require 'torch'
require 'paths'
require 'image'
require 'utils'
debugger = require 'fb.debugger'


local KTHDataset = torch.class('KTHLoader')

-- register an empty 'dataLoader' base class once, so re-running this file doesn't error
if torch.getmetatable('dataLoader') == nil then
  torch.class('dataLoader')
end


function KTHDataset:__init(opt, data_type)
self.data_type = data_type
self.opt = opt or {}
self.path = self.opt.dataRoot
self.data = torch.load(('%s/%s_meta.t7'):format(self.path, data_type))
self.classes = {}
for c, _ in pairs(self.data) do
table.insert(self.classes, c)
end

print(('\n<loaded KTH %s data>'):format(data_type))
local N = 0
local shortest = math.huge
local longest = 0
for _, c in pairs(self.classes) do
local n = 0
local data = self.data[c]
for i = 1,#data do
for d = 1,#data[i].indices do
local len = data[i].indices[d][2] - data[i].indices[d][1] + 1
if len < 0 then debugger.enter() end
shortest = math.min(shortest, len)
longest = math.max(longest, len)
end
n = n + self.data[c][i].n
N = N + n
end
print(('%s: %d videos (%d total frames)'):format(c, #data, n))
end
self.N = N
print('total frames: ' .. N)
print(('min seq length = %d frames'):format(shortest))
print(('max seq length = %d frames'):format(longest))
end

function KTHDataset:size()
return self.N
end

function KTHDataset:getSequence(x, delta)
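  -- sample a random frame skip, class, video, and clip, then copy T frames
  -- spaced delta apart into the preallocated tensor x; returns false if no
  -- frame skip makes the clip long enough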
local delta = math.random(1, delta or self.opt.delta or 1)
local c = self.classes[math.random(#self.classes)]
local vid = self.data[c][math.random(#self.data[c])]
local seq = math.random(#vid.indices)
local seq_length = vid.indices[seq][2] - vid.indices[seq][1] + 1
local basename = ('%s/%s/%s/'):format(self.path, c, vid.vid)

local T = x:size(1)
while T*delta > seq_length do
delta = delta-1
if delta < 1 then return false end
end

local offset = math.random(seq_length-T*delta+1)
local start = vid.indices[seq][1]
for t = 1,T do
local tt = start + offset+(t-1)*delta - 1
local img = image.load(('%s/image-%03d_%dx%d.png'):format(basename, tt, self.opt.imageSize, self.opt.imageSize))[1]
x[t]:copy(img)
end
return true, c
end

function KTHDataset:getBatch(n, T, delta)
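  -- assemble a batch of n sequences: x[t] is an n x C x H x W tensor holding
  -- every sequence's frame at timestep t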
local xx = torch.Tensor(T, unpack(self.opt.geometry))
local x = {}
for t=1,T do
x[t] = torch.Tensor(n, unpack(self.opt.geometry))
end
for i = 1,n do
while not self:getSequence(xx, delta) do
end
for t=1,T do
x[t][i]:copy(xx[t])
end
end
return x
end

function KTHDataset:plotSeq(fname)
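  -- save a grid visualization: n sampled sequences, one t-frame row each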
print('plotting sequence: ' .. fname)
local to_plot = {}
local t = 30
local n = 50
for i = 1,n do
local x = self:getBatch(1, t)
for j = 1,t do
table.insert(to_plot, x[j][1])
end
end
image.save(fname, image.toDisplayTensor{input=to_plot, scaleeach=false, nrow=t})
end

function KTHDataset:plot()
local savedir = self.opt.save .. '/data/'
os.execute('mkdir -p ' .. savedir)
self:plotSeq(savedir .. '/' .. self.data_type .. '_seq.png')
end

trainLoader = KTHLoader(opt_t or opt, 'train')
valLoader = KTHLoader(opt_t or opt, 'test')
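
For context, a minimal sketch of driving this loader outside the training scripts. The option names mirror the fields the loader reads above, but the concrete values are placeholders, and the frames are assumed to be pre-extracted as `image-%03d_64x64.png` files under `dataRoot/<class>/<video>/` (only the meta files ship with this commit):

```lua
-- hypothetical standalone usage; all values below are illustrative
opt = {
  dataRoot  = 'data/kth_files',  -- holds train_meta.t7 / test_meta.t7
  imageSize = 64,
  geometry  = {1, 64, 64},       -- {channels, height, width} per frame
  delta     = 2,                 -- maximum frame skip when sampling
  save      = 'logs/kth',
}
dofile('data/kth.lua')           -- defines trainLoader and valLoader

local x = trainLoader:getBatch(16, 10)  -- batch of 16, 10 timesteps
print(#x, x[1]:size())                  -- 10, 16 x 1 x 64 x 64
```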
File renamed without changes.
Binary file added data/kth_files/test_meta.t7
Binary file not shown.
Binary file added data/kth_files/test_meta_clipped.t7
Binary file not shown.
Binary file added data/kth_files/train_meta.t7
Binary file not shown.
Binary file added data/kth_files/train_meta_clipped.t7
Binary file not shown.
33 changes: 33 additions & 0 deletions data/threaded.lua
@@ -0,0 +1,33 @@
require 'data.threads'
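-- sequence length: prefer an explicit opt.T, else maxStep, else the past+future horizon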
opt.T = opt.T or opt.maxStep or (opt.Tpast+opt.Tfuture)
local opt_tt = opt -- need local var, opt is global
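-- each loader thread re-runs the dataset file via this factory; opt is copied
-- into the thread-local opt_t because each worker has its own Lua state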
trainLoader = ThreadedDatasource(
function()
require 'pl'
opt_t = tablex.copy(opt_tt)
-- opt_t = opt_tt
require(('data.%s'):format(opt_t.dataset))
return trainLoader
end,
{
nThreads = opt_tt.nThreads,
dataPool = math.ceil(opt_tt.dataPool / 10),
dataWarmup = math.ceil(opt_tt.dataWarmup / 10),
})
valLoader = ThreadedDatasource(
function()
require 'pl'
opt_t = tablex.copy(opt_tt)
require(('data.%s'):format(opt_t.dataset))
return valLoader
end,
{
nThreads = opt_tt.nThreads,
dataPool = math.ceil(opt_tt.dataPool / 10),
dataWarmup = math.ceil(opt_tt.dataWarmup / 10),
})

trainLoader:warm()
valLoader:warm()

cutorch.synchronize()
100 changes: 100 additions & 0 deletions data/threads.lua
@@ -0,0 +1,100 @@
--[[
batchSize and T are fixed throughout.
--]]

local threads = require 'threads'
require 'xlua' -- xlua.progress is used in warm()

local ThreadedDatasource, parent = torch.class('ThreadedDatasource')

function ThreadedDatasource:__init(getDatasourceFun, params)
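  -- worker threads load batches in the background and deposit them into a
  -- fixed-size pool; getBatch() serves random samples from that pool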
self.nThreads = math.min(params.nThreads or 4, 2) -- XXX: fix
local opt = opt

self.pool_size = params.dataPool or opt.dataPool
self.dataWarmup = params.dataWarmup or opt.dataWarmup
self.pool = {}
--threads.Threads.serialization('threads.sharedserialize') --TODO
self.threads = threads.Threads(self.nThreads,
function(threadid)
require 'torch'
require 'math'
require 'os'
opt_t = opt
-- print(opt_t)
torch.manualSeed(threadid*os.clock())
math.randomseed(threadid*os.clock()*1.7)
torch.setnumthreads(1)
threadid_t = threadid
datasource_t = getDatasourceFun()
end)
self.threads:synchronize()
self.threads:specific(false)
end

function ThreadedDatasource:warm()
print("Warming up batch pool...")
for i = 1, self.dataWarmup do
self:fetch_batch()

-- don't let the job queue get too big
if i % (self.nThreads * 2) == 0 then
self.threads:synchronize()
end
xlua.progress(i, self.dataWarmup)
end

-- get them working in the background
for i = 1, self.nThreads * 2 do
self:fetch_batch()
end
end

function ThreadedDatasource:fetch_batch()
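  -- enqueue one background load; the callback runs on the main thread and
  -- appends the batch to the pool, or overwrites a random slot once full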
self.threads:addjob(
function()
collectgarbage()
collectgarbage()
return table.pack(datasource_t:getBatch(opt_t.batchSize, opt_t.T))
end,
function(batch)
collectgarbage()
collectgarbage()
if #self.pool < self.pool_size then
table.insert(self.pool, batch)
else
local replacement_index = math.random(1, #self.pool)
self.pool[replacement_index] = batch
end
end
)
end

function ThreadedDatasource:plot()
self.threads:addjob(
function()
collectgarbage()
collectgarbage()
datasource_t:plot()
end,
function()
end
)
self.threads:dojob()
end

function ThreadedDatasource:getBatch()
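  -- serve a random batch from the pool, consuming at most one finished job
  -- per call and immediately queueing a replacement load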
if self.threads:haserror() then
print("ThreadedDatasource: There is an error in a thread")
self.threads:terminate()
os.exit(0)
end

-- queue has something for us
-- dojob to put the newly loaded batch into the pool
if self.threads.mainqueue.isempty == 0 then
self.threads:dojob()
self:fetch_batch()
end
local batch_to_use = math.random(1, #self.pool)
return unpack(self.pool[batch_to_use])
end
2 changes: 1 addition & 1 deletion train_drnet.lua
@@ -319,7 +319,7 @@ function train(x_cpu)
 end
 
 if opt.nThreads > 0 then
-  dofile(('data/%s_threaded.lua'):format(opt.dataset))
+  dofile('data/threaded.lua')
 else
   dofile(('data/%s.lua'):format(opt.dataset))
 end
6 changes: 3 additions & 3 deletions train_drnet_skip.lua
@@ -317,7 +317,7 @@ function train(x_cpu)
 end
 
 if opt.nThreads > 0 then
-  dofile(('data/%s_threaded.lua'):format(opt.dataset))
+  dofile('data/threaded.lua')
 else
   dofile(('data/%s.lua'):format(opt.dataset))
 end
@@ -391,7 +391,7 @@ while true do
   if pred_mse/iter < best then
     best = pred_mse / iter
     print(('Saving best model so far (pred mse = %.4f) %s/model_best.t7'):format(pred_mse/iter, opt.save))
-    torch.save(('%s/model_best.t7'):format(opt.save), {netEC=sanitize(netEC), netEP=sanitize(netEP), opt=opt, epoch=epoch, best=best, total_iter=total_iter})
+    --torch.save(('%s/model_best.t7'):format(opt.save), {netEC=sanitize(netEC), netEP=sanitize(netEP), opt=opt, epoch=epoch, best=best, total_iter=total_iter})
   end
 
   -- plot
@@ -405,7 +405,7 @@
 
   if epoch % 1 == 0 then
     print(('Saving model %s/model.t7'):format(opt.save))
-    torch.save(('%s/model.t7'):format(opt.save), {netC=sanitize(netC), netEC=sanitize(netEC), netEP=sanitize(netEP), opt=opt, epoch=epoch, best=best, total_iter=total_iter})
+    --torch.save(('%s/model.t7'):format(opt.save), {netC=sanitize(netC), netEC=sanitize(netEC), netEP=sanitize(netEP), opt=opt, epoch=epoch, best=best, total_iter=total_iter})
   end
   epoch = epoch+1
   if epoch > opt.nEpochs then break end
