Skip to content

Commit 314928f

Browse files
committed
Automatically reduce batch size for small validation/test splits
1 parent fff1c6b commit 314928f

File tree

3 files changed

+13
-6
lines changed

3 files changed

+13
-6
lines changed

eval.lua

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ local loss = 0
5353
for i = 1, num do
5454
print(string.format('%s batch %d / %d', opt.split, i, num))
5555
local x, y = loader:nextBatch(opt.split)
56+
N = x:size(1)
5657
x = x:type(dtype)
5758
y = y:type(dtype):view(N * T)
5859
local scores = model:forward(x):view(N * T, -1)

train.lua

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,9 +205,10 @@ for i = start_i + 1, num_iterations do
205205
local val_loss = 0
206206
for j = 1, num_val do
207207
local xv, yv = loader:nextBatch('val')
208+
local N_v = xv:size(1)
208209
xv = xv:type(dtype)
209-
yv = yv:type(dtype):view(N * T)
210-
local scores = model:forward(xv):view(N * T, -1)
210+
yv = yv:type(dtype):view(N_v * T)
211+
local scores = model:forward(xv):view(N_v * T, -1)
211212
val_loss = val_loss + crit:forward(scores, yv)
212213
end
213214
val_loss = val_loss / num_val

util/DataLoader.lua

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,21 @@ function DataLoader:__init(kwargs)
2424
self.split_sizes = {}
2525
for split, v in pairs(splits) do
2626
local num = v:nElement()
27-
local extra = num % (N * T)
27+
local N_cur = N
28+
if (N * T > num - 1) then
29+
N_cur = math.floor((num - 1) / T)
30+
print(string.format("Not enough %s data, reducing batch size to %d", split, N_cur))
31+
end
32+
local extra = num % (N_cur * T)
2833

2934
-- Ensure that `vy` is non-empty
3035
if extra == 0 then
31-
extra = N * T
36+
extra = N_cur * T
3237
end
3338

3439
-- Chop out the extra bits at the end to make it evenly divide
35-
local vx = v[{{1, num - extra}}]:view(N, -1, T):transpose(1, 2):clone()
36-
local vy = v[{{2, num - extra + 1}}]:view(N, -1, T):transpose(1, 2):clone()
40+
local vx = v[{{1, num - extra}}]:view(N_cur, -1, T):transpose(1, 2):clone()
41+
local vy = v[{{2, num - extra + 1}}]:view(N_cur, -1, T):transpose(1, 2):clone()
3742

3843
self.x_splits[split] = vx
3944
self.y_splits[split] = vy

0 commit comments

Comments
 (0)