From ca377b40223702f5ba3d3f8dc6ffb8d771c6917f Mon Sep 17 00:00:00 2001 From: LambdaWill <574819595@qq.com> Date: Mon, 14 Aug 2017 17:48:50 +0800 Subject: [PATCH 1/5] add start_epoch --- train.lua | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/train.lua b/train.lua index 1d658475..d5c66a18 100644 --- a/train.lua +++ b/train.lua @@ -138,11 +138,40 @@ end -- load saved models and finetune +local start_epoch = 1 if opt.continue_train == 1 then print('loading previously trained netG...') netG = util.load(paths.concat(opt.checkpoints_dir, opt.name, 'latest_net_G.t7'), opt) print('loading previously trained netD...') netD = util.load(paths.concat(opt.checkpoints_dir, opt.name, 'latest_net_D.t7'), opt) + local current_dir = paths.concat(opt.checkpoints_dir, opt.name) + + local current_dir = paths.concat(opt.checkpoints_dir, opt.name) + + local continue_txt = 'continue.txt' + if io.open(continue_txt,'r') ~= nil then + os.execute('rm -r '..continue_txt) + end + + os.execute('cd '..current_dir..';'..'ls -d *.t7 | tee '..continue_txt) + local file_continue = io.open(current_dir..'/'..continue_txt,'r') + + local latest_saved_num = 0 + for line in file_continue:lines() do + local st, _ = string.find(line, '%d_net_G.t7') + if st then -- avoid latest.t7 + local tmp = tonumber(string.sub(line,1, st)) + if tmp > latest_saved_num then + latest_saved_num = tmp + end + end + end + if latest_saved_num == 0 then + print('Warning there seems no number saved model, so just train with starting epoch 1') + end + print('using the number of latest saved model approximate lastest model, it seems no better way...') + print('Epoch starting at '..latest_saved_num) + start_epoch = latest_saved_num else print('define model netG...') netG = defineG(input_nc, output_nc, ngf) @@ -333,7 +362,7 @@ local plot_data = {} local plot_win local counter = 0 -for epoch = 1, opt.niter do +for epoch = start_epoch, opt.niter do epoch_tm:reset() for i = 1, math.min(data:size(), opt.ntrain), opt.batchSize do tm:reset() From 8ed195970656f459ae9e89067ec190a4f0144d09 Mon Sep 17 00:00:00 2001 From: LambdaWill <574819595@qq.com> Date: Mon, 14 Aug 2017 18:34:53 +0800 Subject: [PATCH 2/5] add option for precise_model_num --- train.lua | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/train.lua b/train.lua index d5c66a18..babb80b1 100644 --- a/train.lua +++ b/train.lua @@ -40,6 +40,7 @@ opt = { display_freq = 100, -- display the current results every display_freq iterations save_display_freq = 5000, -- save the current display of results every save_display_freq_iterations continue_train=0, -- if continue training, load the latest model: 1: true, 0: false + precise_model_num = false, -- if true, load the maximum number of saved models, else load the latest. serial_batches = 0, -- if 1, takes images in order to make batches, otherwise takes them randomly serial_batch_iter = 1, -- iter into serial image list checkpoints_dir = './checkpoints', -- models are saved here @@ -140,12 +141,6 @@ end -- load saved models and finetune local start_epoch = 1 if opt.continue_train == 1 then - print('loading previously trained netG...') - netG = util.load(paths.concat(opt.checkpoints_dir, opt.name, 'latest_net_G.t7'), opt) - print('loading previously trained netD...') - netD = util.load(paths.concat(opt.checkpoints_dir, opt.name, 'latest_net_D.t7'), opt) - local current_dir = paths.concat(opt.checkpoints_dir, opt.name) - local current_dir = paths.concat(opt.checkpoints_dir, opt.name) local continue_txt = 'continue.txt' @@ -155,7 +150,6 @@ if opt.continue_train == 1 then os.execute('cd '..current_dir..';'..'ls -d *.t7 | tee '..continue_txt) local file_continue = io.open(current_dir..'/'..continue_txt,'r') - local latest_saved_num = 0 for line in file_continue:lines() do local st, _ = string.find(line, '%d_net_G.t7') @@ -172,6 +166,15 @@ if opt.continue_train == 1 then print('using the number of latest saved model approximate lastest model, it seems no better way...') print('Epoch starting at '..latest_saved_num) start_epoch = latest_saved_num + + local load_model_prefix = tostring(latest_saved_num) + if opt.precise_model_num == false then + load_model_prefix = 'latest' + end + print('loading previously trained netG...') + netG = util.load(paths.concat(opt.checkpoints_dir, opt.name, load_model_prefix..'_net_G.t7'), opt) + print('loading previously trained netD...') + netD = util.load(paths.concat(opt.checkpoints_dir, opt.name, load_model_prefix..'_net_D.t7'), opt) else print('define model netG...') netG = defineG(input_nc, output_nc, ngf) From 28d6c50d5c249748692f08c88ec080600658c9c6 Mon Sep 17 00:00:00 2001 From: LambdaWill <574819595@qq.com> Date: Tue, 15 Aug 2017 10:22:35 +0800 Subject: [PATCH 3/5] fix bugs --- train.lua | 67 ++++++++++++++++++++++++++++++------------------------- 1 file changed, 37 insertions(+), 30 deletions(-) diff --git a/train.lua b/train.lua index babb80b1..71859627 100644 --- a/train.lua +++ b/train.lua @@ -139,47 +139,54 @@ end -- load saved models and finetune -local start_epoch = 1 +local start_epoch = 0 if opt.continue_train == 1 then local current_dir = paths.concat(opt.checkpoints_dir, opt.name) local continue_txt = 'continue.txt' if io.open(continue_txt,'r') ~= nil then - os.execute('rm -r '..continue_txt) + os.execute('rm -r '..continue_txt) end os.execute('cd '..current_dir..';'..'ls -d *.t7 | tee '..continue_txt) local file_continue = io.open(current_dir..'/'..continue_txt,'r') + local file_content = io.open(current_dir..'/'..continue_txt,'r') local latest_saved_num = 0 - for line in file_continue:lines() do - local st, _ = string.find(line, '%d_net_G.t7') - if st then -- avoid latest.t7 - local tmp = tonumber(string.sub(line,1, st)) - if tmp > latest_saved_num then - latest_saved_num = tmp - end - end - end - if latest_saved_num == 0 then - print('Warning there seems no number saved model, so just train with starting epoch 1') - end - print('using the number of latest saved model approximate lastest model, it seems no better way...') - print('Epoch starting at '..latest_saved_num) - start_epoch = latest_saved_num - - local load_model_prefix = tostring(latest_saved_num) - if opt.precise_model_num == false then - load_model_prefix = 'latest' + if file_content_all ~= '' then + for line in file_continue:lines() do + local st, _ = string.find(line, '%d_net_G.t7') + if st then -- avoid latest.t7 + local tmp = tonumber(string.sub(line,1, st)) + if tmp > latest_saved_num then + latest_saved_num = tmp + end + end + end + if latest_saved_num == 0 then + print('Warning: it seems that no models with numbers pretrained, so just train with index 1') + end + print('Using the number of latest saved model approximate lastest model, it seems no better way...') + print('Epoch starting at '..latest_saved_num + 1) + start_epoch = latest_saved_num + + local load_model_prefix = tostring(latest_saved_num) + if opt.precise_model_num == false then + load_model_prefix = 'latest' + end + print('loading previously trained netG...') + netG = util.load(paths.concat(opt.checkpoints_dir, opt.name, load_model_prefix..'_net_G.t7'), opt) + print('loading previously trained netD...') + netD = util.load(paths.concat(opt.checkpoints_dir, opt.name, load_model_prefix..'_net_D.t7'), opt) + else + error('no pretrained model, you\'d better train from scratch') end - print('loading previously trained netG...') - netG = util.load(paths.concat(opt.checkpoints_dir, opt.name, load_model_prefix..'_net_G.t7'), opt) - print('loading previously trained netD...') - netD = util.load(paths.concat(opt.checkpoints_dir, opt.name, load_model_prefix..'_net_D.t7'), opt) + file_continue:close() + file_content:close() else - print('define model netG...') - netG = defineG(input_nc, output_nc, ngf) - print('define model netD...') - netD = defineD(input_nc, output_nc, ndf) + print('define model netG...') + netG = defineG(input_nc, output_nc, ngf) + print('define model netD...') + netD = defineD(input_nc, output_nc, ndf) end print(netG) @@ -365,7 +372,7 @@ local plot_data = {} local plot_win local counter = 0 -for epoch = start_epoch, opt.niter do +for epoch = start_epoch+1, opt.niter do epoch_tm:reset() for i = 1, math.min(data:size(), opt.ntrain), opt.batchSize do tm:reset() From dc4df00b5822bdd104d5d1fb9781b9818ea23363 Mon Sep 17 00:00:00 2001 From: LambdaWill <574819595@qq.com> Date: Sun, 20 Aug 2017 12:23:56 +0800 Subject: [PATCH 4/5] fix missing code --- train.lua | 1 + 1 file changed, 1 insertion(+) diff --git a/train.lua b/train.lua index 71859627..121daffa 100644 --- a/train.lua +++ b/train.lua @@ -152,6 +152,7 @@ if opt.continue_train == 1 then local file_continue = io.open(current_dir..'/'..continue_txt,'r') local file_content = io.open(current_dir..'/'..continue_txt,'r') local latest_saved_num = 0 + local file_content_all = file_content:read('*a') if file_content_all ~= '' then for line in file_continue:lines() do local st, _ = string.find(line, '%d_net_G.t7') From 1cde2af830a640381d5b9e23fd6084d6aaa55c05 Mon Sep 17 00:00:00 2001 From: LambdaWill <574819595@qq.com> Date: Wed, 11 Oct 2017 21:22:35 +0800 Subject: [PATCH 5/5] Very rubost automatically loading model. --- train.lua | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/train.lua b/train.lua index 121daffa..edf89f4a 100644 --- a/train.lua +++ b/train.lua @@ -163,16 +163,29 @@ if opt.continue_train == 1 then end end end + local load_model_prefix = nil if latest_saved_num == 0 then - print('Warning: it seems that no models with numbers pretrained, so just train with index 1') + load_model_prefix = 'latest' + print('Warning: it seems that no models whose names contains numbers pretrained, so just train with index 1') + end + if load_model_prefix == nil then + load_model_prefix = tostring(latest_saved_num) end - print('Using the number of latest saved model approximate lastest model, it seems no better way...') print('Epoch starting at '..latest_saved_num + 1) start_epoch = latest_saved_num - local load_model_prefix = tostring(latest_saved_num) - if opt.precise_model_num == false then - load_model_prefix = 'latest' + local exist_latest = io.open(current_dir..'/'..'latest_net_G.t7') + if opt.precise_model_num == true then + if latest_saved_num == 0 then + error('no models whose names contains numbers pretrained') + end + else + if exist_latest == nil then + error('No \'latest\' models saved!') + else + load_model_prefix = 'latest' + exist_latest:close() + end end print('loading previously trained netG...') netG = util.load(paths.concat(opt.checkpoints_dir, opt.name, load_model_prefix..'_net_G.t7'), opt)