From 3048b8751b96c57866ea2fe5ff75df083b824ffc Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Thu, 3 Dec 2015 18:21:41 -0500 Subject: [PATCH 1/4] [extra supervision] initial commit --- python/new_train.py | 190 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 190 insertions(+) create mode 100644 python/new_train.py diff --git a/python/new_train.py b/python/new_train.py new file mode 100644 index 00000000..cc0cf900 --- /dev/null +++ b/python/new_train.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python +__doc__ = """ + +Jingpeng Wu , 2015 +""" +import time +from front_end import * +import cost_fn +import test +import utils +import zstatistics +import os + +def main( conf_file='config.cfg', logfile=None ): + #%% parameters + print "reading config parameters..." + config, pars = zconfig.parser( conf_file ) + + if pars.has_key('logging') and pars['logging']: + print "recording configuration file..." + zconfig.record_config_file( pars ) + + logfile = zlog.make_logfile_name( pars ) + + #%% create and initialize the network + if pars['train_load_net'] and os.path.exists(pars['train_load_net']): + print "loading network..." + net = znetio.load_network( pars ) + # load existing learning curve + lc = zstatistics.CLearnCurve( pars['train_load_net'] ) + # the last iteration we want to continue training + iter_last = lc.get_last_it() + else: + if pars['train_seed_net'] and os.path.exists(pars['train_seed_net']): + print "seeding network..." + net = znetio.load_network( pars, is_seed=True ) + else: + print "initializing network..." + net = znetio.init_network( pars ) + # initalize a learning curve + lc = zstatistics.CLearnCurve() + iter_last = lc.get_last_it() + + # show field of view + print "field of view: ", net.get_fov() + + # total voxel number of output volumes + vn = utils.get_total_num(net.get_outputs_setsz()) + + # set some parameters + print 'setting up the network...' + eta = pars['eta'] + net.set_eta( pars['eta'] ) + net.set_momentum( pars['momentum'] ) + net.set_weight_decay( pars['weight_decay'] ) + + # initialize samples + outsz = pars['train_outsz'] + print "\n\ncreate train samples..." + smp_trn = zsample.CSamples(config, pars, pars['train_range'], net, outsz, logfile) + print "\n\ncreate test samples..." + smp_tst = zsample.CSamples(config, pars, pars['test_range'], net, outsz, logfile) + + # initialization + elapsed = 0 + err = 0.0 # cost energy + cls = 0.0 # pixel classification error + re = 0.0 # rand error + # number of voxels which accumulate error + # (if a mask exists) + num_mask_voxels = 0 + + if pars['is_malis']: + malis_cls = 0.0 + + print "start training..." + start = time.time() + total_time = 0.0 + print "start from ", iter_last+1 + + #Saving initialized network + if iter_last+1 == 1: + znetio.save_network(net, pars['train_save_net'], num_iters=0) + + for i in xrange(iter_last+1, pars['Max_iter']+1): + # get random sub volume from sample + vol_ins, lbl_outs, msks, wmsks = smp_trn.get_random_sample() + + # forward pass + # apply the transformations in memory rather than array view + vol_ins = utils.make_continuous(vol_ins, dtype=pars['dtype']) + props = net.forward( vol_ins ) + + # cost function and accumulate errors + props, cerr, grdts = pars['cost_fn']( props, lbl_outs, msks ) + err += cerr + cls += cost_fn.get_cls(props, lbl_outs) + num_mask_voxels += utils.sum_over_dict(msks) + + # gradient reweighting + grdts = utils.dict_mul( grdts, msks ) + grdts = utils.dict_mul( grdts, wmsks ) + + if pars['is_malis'] : + malis_weights, rand_errors = cost_fn.malis_weight(pars, props, lbl_outs) + grdts = utils.dict_mul(grdts, malis_weights) + # accumulate the rand error + re += rand_errors.values()[0] + malis_cls_dict = utils.get_malis_cls( props, lbl_outs, malis_weights ) + malis_cls += malis_cls_dict.values()[0] + + + total_time += time.time() - start + start = time.time() + + # test the net + if i%pars['Num_iter_per_test']==0: + lc = test.znn_test(net, pars, smp_tst, vn, i, lc) + + if i%pars['Num_iter_per_show']==0: + # normalize + if utils.dict_mask_empty(msks): + err = err / vn / pars['Num_iter_per_show'] + cls = cls / vn / pars['Num_iter_per_show'] + else: + err = err / num_mask_voxels / pars['Num_iter_per_show'] + cls = cls / num_mask_voxels / pars['Num_iter_per_show'] + + lc.append_train(i, err, cls) + + # time + elapsed = total_time / pars['Num_iter_per_show'] + + if pars['is_malis']: + re = re / pars['Num_iter_per_show'] + lc.append_train_rand_error( re ) + malis_cls = malis_cls / pars['Num_iter_per_show'] + lc.append_train_malis_cls( malis_cls ) + + show_string = "iteration %d, err: %.3f, cls: %.3f, re: %.6f, mc: %.3f, elapsed: %.1f s/iter, learning rate: %.6f"\ + %(i, err, cls, re, malis_cls, elapsed, eta ) + else: + show_string = "iteration %d, err: %.3f, cls: %.3f, elapsed: %.1f s/iter, learning rate: %.6f"\ + %(i, err, cls, elapsed, eta ) + + if pars.has_key('logging') and pars['logging']: + utils.write_to_log(logfile, show_string) + print show_string + + # reset err and cls + err = 0 + cls = 0 + re = 0 + num_mask_voxels = 0 + + if pars['is_malis']: + malis_cls = 0 + + # reset time + total_time = 0 + start = time.time() + + if i%pars['Num_iter_per_annealing']==0: + # anneal factor + eta = eta * pars['anneal_factor'] + net.set_eta(eta) + + if i%pars['Num_iter_per_save']==0: + # save network + znetio.save_network(net, pars['train_save_net'], num_iters=i) + lc.save( pars, elapsed ) + if pars['is_malis']: + utils.save_malis(malis_weights, pars['train_save_net'], num_iters=i) + + # run backward pass + grdts = utils.make_continuous(grdts, dtype=pars['dtype']) + net.backward( grdts ) + + +if __name__ == '__main__': + """ + usage + ------ + python train.py path/to/config.cfg + """ + import sys + if len(sys.argv)>1: + main( sys.argv[1] ) + else: + main() From 928e84e6afdba0e0b41402cd630c86602605708d Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Fri, 15 Jan 2016 15:28:37 -0500 Subject: [PATCH 2/4] fix issue#39 --- src/include/network/parallel/network.hpp | 57 +++++++++++++----------- 1 file changed, 32 insertions(+), 25 deletions(-) diff --git a/src/include/network/parallel/network.hpp b/src/include/network/parallel/network.hpp index f9852c45..60a8f7a8 100644 --- a/src/include/network/parallel/network.hpp +++ b/src/include/network/parallel/network.hpp @@ -123,40 +123,47 @@ class network #endif - void fov_pass(nnodes* n, const vec3i& fov, const vec3i& fsize ) + void fov_pass(nnodes* n, vec3i fov, const vec3i& fsize ) { if ( n->fov != vec3i::zero ) { ZI_ASSERT(n->fsize==fsize); - ZI_ASSERT(n->fov==fov); + if ( n->fov == fov ) + { + return; + } + else + { + fov[0] = std::max(n->fov[0],fov[0]); + fov[1] = std::max(n->fov[1],fov[1]); + fov[2] = std::max(n->fov[2],fov[2]); + } } - else + + for ( auto& e: n->out ) { - for ( auto& e: n->out ) + e->in_fsize = fsize; + } + n->fov = fov; + n->fsize = fsize; + for ( auto& e: n->in ) + { + if ( e->pool ) { - e->in_fsize = fsize; + vec3i new_fov = fov * e->width; + vec3i new_fsize = e->width * fsize; + fov_pass(e->in, new_fov, new_fsize); } - n->fov = fov; - n->fsize = fsize; - for ( auto& e: n->in ) + else if ( e->crop ) { - if ( e->pool ) - { - vec3i new_fov = fov * e->width; - vec3i new_fsize = e->width * fsize; - fov_pass(e->in, new_fov, new_fsize); - } - else if ( e->crop ) - { - // FoV doesn't change - fov_pass(e->in, fov, fsize + e->width - vec3i::one); - } - else - { - vec3i new_fov = (fov - vec3i::one) * e->stride + e->width; - vec3i new_fsize = (e->width-vec3i::one) * e->in_stride + fsize; - fov_pass(e->in, new_fov, new_fsize); - } + // FoV doesn't change + fov_pass(e->in, fov, fsize + e->width - vec3i::one); + } + else + { + vec3i new_fov = (fov - vec3i::one) * e->stride + e->width; + vec3i new_fsize = (e->width-vec3i::one) * e->in_stride + fsize; + fov_pass(e->in, new_fov, new_fsize); } } } From a3530f164d671fa250cda3844c589075ba190645 Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Fri, 15 Jan 2016 15:31:01 -0500 Subject: [PATCH 3/4] fix issue#42 --- python/front_end/zsample.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/python/front_end/zsample.py b/python/front_end/zsample.py index ebea7584..4b65df93 100644 --- a/python/front_end/zsample.py +++ b/python/front_end/zsample.py @@ -56,21 +56,22 @@ def __init__(self, config, pars, sample_id, net, \ outsz, setsz_in, fov, is_forward=is_forward ) self.imgs[name] = self.ins[name].data - print "\ncreate label image class..." self.lbls = dict() self.msks = dict() self.outs = dict() - for name,setsz_out in self.setsz_outs.iteritems(): - #Allowing for users to abstain from specifying labels - if not config.has_option(self.sec_name, name): - continue - #Finding the section of the config file - imid = config.getint(self.sec_name, name) - imsec_name = "label%d" % (imid,) - self.outs[name] = ConfigOutputLabel( config, pars, imsec_name, \ - outsz, setsz_out, fov) - self.lbls[name] = self.outs[name].data - self.msks[name] = self.outs[name].msk + if not is_forward: + print "\ncreate label image class..." + for name,setsz_out in self.setsz_outs.iteritems(): + #Allowing for users to abstain from specifying labels + if not config.has_option(self.sec_name, name): + continue + #Finding the section of the config file + imid = config.getint(self.sec_name, name) + imsec_name = "label%d" % (imid,) + self.outs[name] = ConfigOutputLabel( config, pars, imsec_name, \ + outsz, setsz_out, fov) + self.lbls[name] = self.outs[name].data + self.msks[name] = self.outs[name].msk if not is_forward: self._prepare_training() From e6cd2be6192b56cee6992b89ae0d7205f412ec22 Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Fri, 15 Jan 2016 15:33:41 -0500 Subject: [PATCH 4/4] Delete new_train.py --- python/new_train.py | 190 -------------------------------------------- 1 file changed, 190 deletions(-) delete mode 100644 python/new_train.py diff --git a/python/new_train.py b/python/new_train.py deleted file mode 100644 index cc0cf900..00000000 --- a/python/new_train.py +++ /dev/null @@ -1,190 +0,0 @@ -#!/usr/bin/env python -__doc__ = """ - -Jingpeng Wu , 2015 -""" -import time -from front_end import * -import cost_fn -import test -import utils -import zstatistics -import os - -def main( conf_file='config.cfg', logfile=None ): - #%% parameters - print "reading config parameters..." - config, pars = zconfig.parser( conf_file ) - - if pars.has_key('logging') and pars['logging']: - print "recording configuration file..." - zconfig.record_config_file( pars ) - - logfile = zlog.make_logfile_name( pars ) - - #%% create and initialize the network - if pars['train_load_net'] and os.path.exists(pars['train_load_net']): - print "loading network..." - net = znetio.load_network( pars ) - # load existing learning curve - lc = zstatistics.CLearnCurve( pars['train_load_net'] ) - # the last iteration we want to continue training - iter_last = lc.get_last_it() - else: - if pars['train_seed_net'] and os.path.exists(pars['train_seed_net']): - print "seeding network..." - net = znetio.load_network( pars, is_seed=True ) - else: - print "initializing network..." - net = znetio.init_network( pars ) - # initalize a learning curve - lc = zstatistics.CLearnCurve() - iter_last = lc.get_last_it() - - # show field of view - print "field of view: ", net.get_fov() - - # total voxel number of output volumes - vn = utils.get_total_num(net.get_outputs_setsz()) - - # set some parameters - print 'setting up the network...' - eta = pars['eta'] - net.set_eta( pars['eta'] ) - net.set_momentum( pars['momentum'] ) - net.set_weight_decay( pars['weight_decay'] ) - - # initialize samples - outsz = pars['train_outsz'] - print "\n\ncreate train samples..." - smp_trn = zsample.CSamples(config, pars, pars['train_range'], net, outsz, logfile) - print "\n\ncreate test samples..." - smp_tst = zsample.CSamples(config, pars, pars['test_range'], net, outsz, logfile) - - # initialization - elapsed = 0 - err = 0.0 # cost energy - cls = 0.0 # pixel classification error - re = 0.0 # rand error - # number of voxels which accumulate error - # (if a mask exists) - num_mask_voxels = 0 - - if pars['is_malis']: - malis_cls = 0.0 - - print "start training..." - start = time.time() - total_time = 0.0 - print "start from ", iter_last+1 - - #Saving initialized network - if iter_last+1 == 1: - znetio.save_network(net, pars['train_save_net'], num_iters=0) - - for i in xrange(iter_last+1, pars['Max_iter']+1): - # get random sub volume from sample - vol_ins, lbl_outs, msks, wmsks = smp_trn.get_random_sample() - - # forward pass - # apply the transformations in memory rather than array view - vol_ins = utils.make_continuous(vol_ins, dtype=pars['dtype']) - props = net.forward( vol_ins ) - - # cost function and accumulate errors - props, cerr, grdts = pars['cost_fn']( props, lbl_outs, msks ) - err += cerr - cls += cost_fn.get_cls(props, lbl_outs) - num_mask_voxels += utils.sum_over_dict(msks) - - # gradient reweighting - grdts = utils.dict_mul( grdts, msks ) - grdts = utils.dict_mul( grdts, wmsks ) - - if pars['is_malis'] : - malis_weights, rand_errors = cost_fn.malis_weight(pars, props, lbl_outs) - grdts = utils.dict_mul(grdts, malis_weights) - # accumulate the rand error - re += rand_errors.values()[0] - malis_cls_dict = utils.get_malis_cls( props, lbl_outs, malis_weights ) - malis_cls += malis_cls_dict.values()[0] - - - total_time += time.time() - start - start = time.time() - - # test the net - if i%pars['Num_iter_per_test']==0: - lc = test.znn_test(net, pars, smp_tst, vn, i, lc) - - if i%pars['Num_iter_per_show']==0: - # normalize - if utils.dict_mask_empty(msks): - err = err / vn / pars['Num_iter_per_show'] - cls = cls / vn / pars['Num_iter_per_show'] - else: - err = err / num_mask_voxels / pars['Num_iter_per_show'] - cls = cls / num_mask_voxels / pars['Num_iter_per_show'] - - lc.append_train(i, err, cls) - - # time - elapsed = total_time / pars['Num_iter_per_show'] - - if pars['is_malis']: - re = re / pars['Num_iter_per_show'] - lc.append_train_rand_error( re ) - malis_cls = malis_cls / pars['Num_iter_per_show'] - lc.append_train_malis_cls( malis_cls ) - - show_string = "iteration %d, err: %.3f, cls: %.3f, re: %.6f, mc: %.3f, elapsed: %.1f s/iter, learning rate: %.6f"\ - %(i, err, cls, re, malis_cls, elapsed, eta ) - else: - show_string = "iteration %d, err: %.3f, cls: %.3f, elapsed: %.1f s/iter, learning rate: %.6f"\ - %(i, err, cls, elapsed, eta ) - - if pars.has_key('logging') and pars['logging']: - utils.write_to_log(logfile, show_string) - print show_string - - # reset err and cls - err = 0 - cls = 0 - re = 0 - num_mask_voxels = 0 - - if pars['is_malis']: - malis_cls = 0 - - # reset time - total_time = 0 - start = time.time() - - if i%pars['Num_iter_per_annealing']==0: - # anneal factor - eta = eta * pars['anneal_factor'] - net.set_eta(eta) - - if i%pars['Num_iter_per_save']==0: - # save network - znetio.save_network(net, pars['train_save_net'], num_iters=i) - lc.save( pars, elapsed ) - if pars['is_malis']: - utils.save_malis(malis_weights, pars['train_save_net'], num_iters=i) - - # run backward pass - grdts = utils.make_continuous(grdts, dtype=pars['dtype']) - net.backward( grdts ) - - -if __name__ == '__main__': - """ - usage - ------ - python train.py path/to/config.cfg - """ - import sys - if len(sys.argv)>1: - main( sys.argv[1] ) - else: - main()