From 314928fd64aa2e337f2b52c5b7f96df6cb0407d4 Mon Sep 17 00:00:00 2001 From: Michal Povinsky Date: Mon, 29 May 2017 10:30:15 +0200 Subject: [PATCH] Automatically reduce batch size for small validation/test splits --- eval.lua | 1 + train.lua | 5 +++-- util/DataLoader.lua | 13 +++++++++---- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/eval.lua b/eval.lua index 591fa523..1d5e9b6b 100644 --- a/eval.lua +++ b/eval.lua @@ -53,6 +53,7 @@ local loss = 0 for i = 1, num do print(string.format('%s batch %d / %d', opt.split, i, num)) local x, y = loader:nextBatch(opt.split) + N = x:size(1) x = x:type(dtype) y = y:type(dtype):view(N * T) local scores = model:forward(x):view(N * T, -1) diff --git a/train.lua b/train.lua index fa00af16..52210ec8 100644 --- a/train.lua +++ b/train.lua @@ -205,9 +205,10 @@ for i = start_i + 1, num_iterations do local val_loss = 0 for j = 1, num_val do local xv, yv = loader:nextBatch('val') + local N_v = xv:size(1) xv = xv:type(dtype) - yv = yv:type(dtype):view(N * T) - local scores = model:forward(xv):view(N * T, -1) + yv = yv:type(dtype):view(N_v * T) + local scores = model:forward(xv):view(N_v * T, -1) val_loss = val_loss + crit:forward(scores, yv) end val_loss = val_loss / num_val diff --git a/util/DataLoader.lua b/util/DataLoader.lua index 3722601c..531f526e 100644 --- a/util/DataLoader.lua +++ b/util/DataLoader.lua @@ -24,16 +24,21 @@ function DataLoader:__init(kwargs) self.split_sizes = {} for split, v in pairs(splits) do local num = v:nElement() - local extra = num % (N * T) + local N_cur = N + if (N * T > num - 1) then + N_cur = math.floor((num - 1) / T) + print(string.format("Not enough %s data, reducing batch size to %d", split, N_cur)) + end + local extra = num % (N_cur * T) -- Ensure that `vy` is non-empty if extra == 0 then - extra = N * T + extra = N_cur * T end -- Chop out the extra bits at the end to make it evenly divide - local vx = v[{{1, num - extra}}]:view(N, -1, T):transpose(1, 2):clone() - local vy = v[{{2, num - extra + 1}}]:view(N, -1, T):transpose(1, 2):clone() + local vx = v[{{1, num - extra}}]:view(N_cur, -1, T):transpose(1, 2):clone() + local vy = v[{{2, num - extra + 1}}]:view(N_cur, -1, T):transpose(1, 2):clone() self.x_splits[split] = vx self.y_splits[split] = vy