Skip to content

Commit

Permalink
Automatically reduce batch size for small validation/test splits
Browse files: browse the repository at this point in the history
  • Loading branch information
antihutka committed May 29, 2017
1 parent fff1c6b commit 314928f
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 6 deletions.
1 change: 1 addition & 0 deletions eval.lua
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ local loss = 0
for i = 1, num do
print(string.format('%s batch %d / %d', opt.split, i, num))
local x, y = loader:nextBatch(opt.split)
N = x:size(1)
x = x:type(dtype)
y = y:type(dtype):view(N * T)
local scores = model:forward(x):view(N * T, -1)
Expand Down
5 changes: 3 additions & 2 deletions train.lua
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,10 @@ for i = start_i + 1, num_iterations do
-- Validation pass: sum the criterion loss over `num_val` batches drawn from
-- the 'val' split, then divide by `num_val` to get the mean loss per batch.
-- NOTE(review): this hunk is rendered without +/- markers. The hunk header
-- reports 3 additions and 2 deletions, so the two `N * T` statements below
-- are presumably the removed (pre-commit) lines and the `N_v` statements the
-- added ones -- confirm against the repository before reading this as code.
local val_loss = 0
for j = 1, num_val do
local xv, yv = loader:nextBatch('val')
-- Take the batch size from the first dimension of the batch that was
-- actually returned, rather than from the configured N.
local N_v = xv:size(1)
xv = xv:type(dtype)
-- Presumably the pre-commit version: flattens using the configured batch size N.
yv = yv:type(dtype):view(N * T)
local scores = model:forward(xv):view(N * T, -1)
-- Presumably the post-commit version: flattens targets to N_v * T elements and
-- the model output to N_v * T rows so both match the batch actually loaded.
yv = yv:type(dtype):view(N_v * T)
local scores = model:forward(xv):view(N_v * T, -1)
-- Add this batch's criterion value to the running total.
val_loss = val_loss + crit:forward(scores, yv)
end
-- Convert the running total into the average over all validation batches.
val_loss = val_loss / num_val
Expand Down
13 changes: 9 additions & 4 deletions util/DataLoader.lua
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,21 @@ function DataLoader:__init(kwargs)
self.split_sizes = {}
for split, v in pairs(splits) do
local num = v:nElement()
local extra = num % (N * T)
local N_cur = N
if (N * T > num - 1) then
N_cur = math.floor((num - 1) / T)
print(string.format("Not enough %s data, reducing batch size to %d", split, N_cur))
end
local extra = num % (N_cur * T)

-- Ensure that `vy` is non-empty
if extra == 0 then
extra = N * T
extra = N_cur * T
end

-- Chop out the extra bits at the end to make it evenly divide
local vx = v[{{1, num - extra}}]:view(N, -1, T):transpose(1, 2):clone()
local vy = v[{{2, num - extra + 1}}]:view(N, -1, T):transpose(1, 2):clone()
local vx = v[{{1, num - extra}}]:view(N_cur, -1, T):transpose(1, 2):clone()
local vy = v[{{2, num - extra + 1}}]:view(N_cur, -1, T):transpose(1, 2):clone()

self.x_splits[split] = vx
self.y_splits[split] = vy
Expand Down

0 comments on commit 314928f

Please sign in to comment.