Automatically assign the highest saved model number to start_epoch when continuing training. #104
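In outline, the patch scans the checkpoint directory for generator checkpoints named <epoch>_net_G.t7, takes the highest epoch number it finds, and assigns it to start_epoch so training resumes from the following epoch. A minimal sketch of that idea in Lua, with a hard-coded file list purely for illustration (the diff below builds the real list from the checkpoint directory at runtime):

-- Sketch only: pick the highest epoch number out of the generator checkpoint names.
-- The file list is hypothetical; the patch derives it from opt.checkpoints_dir/opt.name.
local files = { 'latest_net_G.t7', '5_net_G.t7', '15_net_G.t7', '15_net_D.t7' }

local latest_saved_num = 0
for _, fname in ipairs(files) do
   -- matches e.g. '15_net_G.t7' but not 'latest_net_G.t7'
   local num = tonumber(string.match(fname, '^(%d+)_net_G%.t7$'))
   if num and num > latest_saved_num then
      latest_saved_num = num
   end
end

local start_epoch = latest_saved_num           -- the training loop then resumes at start_epoch + 1
print('resuming at epoch ' .. (start_epoch + 1)) -- prints: resuming at epoch 16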

Open · wants to merge 5 commits into master
71 changes: 62 additions & 9 deletions train.lua
@@ -40,6 +40,7 @@ opt = {
   display_freq = 100,          -- display the current results every display_freq iterations
   save_display_freq = 5000,    -- save the current display of results every save_display_freq iterations
   continue_train = 0,          -- if continuing training, load the latest model: 1: true, 0: false
   precise_model_num = false,   -- if true, resume from the saved model with the highest epoch number; otherwise load the 'latest' checkpoint
   serial_batches = 0,          -- if 1, takes images in order to make batches, otherwise takes them randomly
   serial_batch_iter = 1,       -- iter into serial image list
   checkpoints_dir = './checkpoints', -- models are saved here
@@ -138,16 +139,68 @@ end


-- load saved models and finetune
local start_epoch = 0
if opt.continue_train == 1 then
   local current_dir = paths.concat(opt.checkpoints_dir, opt.name)

   -- list the saved .t7 checkpoints into a temporary text file
   local continue_txt = 'continue.txt'
   local old_list = io.open(current_dir .. '/' .. continue_txt, 'r')
   if old_list ~= nil then
      old_list:close()
      os.execute('rm -f ' .. current_dir .. '/' .. continue_txt)
   end
   os.execute('cd ' .. current_dir .. '; ls -d *.t7 | tee ' .. continue_txt)

   local file_continue = io.open(current_dir .. '/' .. continue_txt, 'r')
   local file_content_all = file_continue:read('*a')
   file_continue:close()

   local latest_saved_num = 0
   if file_content_all ~= '' then
      -- find the highest epoch number among checkpoints named <N>_net_G.t7
      for line in string.gmatch(file_content_all, '[^\n]+') do
         local st, _ = string.find(line, '%d_net_G.t7')
         if st then -- skips latest_net_G.t7, which has no digit before the suffix
            local tmp = tonumber(string.sub(line, 1, st))
            if tmp > latest_saved_num then
               latest_saved_num = tmp
            end
         end
      end

      local load_model_prefix
      if latest_saved_num == 0 then
         load_model_prefix = 'latest'
         print('Warning: no numbered checkpoints found, so training restarts at epoch 1')
      else
         load_model_prefix = tostring(latest_saved_num)
      end
      print('Epoch starting at ' .. (latest_saved_num + 1))
      start_epoch = latest_saved_num

      if opt.precise_model_num == true then
         if latest_saved_num == 0 then
            error('no numbered pretrained models found')
         end
      else
         local exist_latest = io.open(current_dir .. '/' .. 'latest_net_G.t7')
         if exist_latest == nil then
            error('no \'latest\' model saved!')
         else
            load_model_prefix = 'latest'
            exist_latest:close()
         end
      end
      print('loading previously trained netG...')
      netG = util.load(paths.concat(opt.checkpoints_dir, opt.name, load_model_prefix .. '_net_G.t7'), opt)
      print('loading previously trained netD...')
      netD = util.load(paths.concat(opt.checkpoints_dir, opt.name, load_model_prefix .. '_net_D.t7'), opt)
   else
      error('no pretrained model found; train from scratch instead (set continue_train=0)')
   end
else
   print('define model netG...')
   netG = defineG(input_nc, output_nc, ngf)
   print('define model netD...')
   netD = defineD(input_nc, output_nc, ndf)
end

print(netG)
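The listing above builds the checkpoint list by shelling out to ls and parsing a temporary continue.txt. If the external command and temporary file are a concern, the same scan could be done in-process with the paths package that train.lua already uses; this is only an alternative sketch, not what the PR does, and it assumes paths.dir from the standard Torch paths module:

-- Alternative sketch: scan the checkpoint directory without os.execute or a temp file.
local current_dir = paths.concat(opt.checkpoints_dir, opt.name)
local latest_saved_num = 0
for _, fname in ipairs(paths.dir(current_dir) or {}) do -- paths.dir returns nil if the directory is missing
   local num = tonumber(string.match(fname, '^(%d+)_net_G%.t7$'))
   if num and num > latest_saved_num then
      latest_saved_num = num
   end
end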
@@ -333,7 +386,7 @@ local plot_data = {}
local plot_win

local counter = 0
for epoch = 1, opt.niter do
for epoch = start_epoch + 1, opt.niter do
   epoch_tm:reset()
   for i = 1, math.min(data:size(), opt.ntrain), opt.batchSize do
      tm:reset()
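One consequence of the loop change worth noting: opt.niter stays the total number of epochs, not the number of additional epochs, so a resumed run trains epochs start_epoch + 1 through opt.niter. A small illustration with hypothetical numbers:

-- Hypothetical: 5_net_G.t7 and 10_net_G.t7 exist on disk and niter = 200.
-- The scan sets latest_saved_num = 10, hence start_epoch = 10, and the resumed
-- run covers epochs 11 .. 200 (190 more), not another 200 from scratch.
local start_epoch = 10
local niter = 200
print(('resuming at epoch %d, %d epochs remaining'):format(start_epoch + 1, niter - start_epoch))
for epoch = start_epoch + 1, niter do
   -- training body as in train.lua
end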