diff --git a/python/front_end/zsample.py b/python/front_end/zsample.py index ebea7584..4b65df93 100644 --- a/python/front_end/zsample.py +++ b/python/front_end/zsample.py @@ -56,21 +56,22 @@ def __init__(self, config, pars, sample_id, net, \ outsz, setsz_in, fov, is_forward=is_forward ) self.imgs[name] = self.ins[name].data - print "\ncreate label image class..." self.lbls = dict() self.msks = dict() self.outs = dict() - for name,setsz_out in self.setsz_outs.iteritems(): - #Allowing for users to abstain from specifying labels - if not config.has_option(self.sec_name, name): - continue - #Finding the section of the config file - imid = config.getint(self.sec_name, name) - imsec_name = "label%d" % (imid,) - self.outs[name] = ConfigOutputLabel( config, pars, imsec_name, \ - outsz, setsz_out, fov) - self.lbls[name] = self.outs[name].data - self.msks[name] = self.outs[name].msk + if not is_forward: + print "\ncreate label image class..." + for name,setsz_out in self.setsz_outs.iteritems(): + #Allowing for users to abstain from specifying labels + if not config.has_option(self.sec_name, name): + continue + #Finding the section of the config file + imid = config.getint(self.sec_name, name) + imsec_name = "label%d" % (imid,) + self.outs[name] = ConfigOutputLabel( config, pars, imsec_name, \ + outsz, setsz_out, fov) + self.lbls[name] = self.outs[name].data + self.msks[name] = self.outs[name].msk if not is_forward: self._prepare_training() diff --git a/src/include/network/parallel/network.hpp b/src/include/network/parallel/network.hpp index f9852c45..60a8f7a8 100644 --- a/src/include/network/parallel/network.hpp +++ b/src/include/network/parallel/network.hpp @@ -123,40 +123,47 @@ class network #endif - void fov_pass(nnodes* n, const vec3i& fov, const vec3i& fsize ) + void fov_pass(nnodes* n, vec3i fov, const vec3i& fsize ) { if ( n->fov != vec3i::zero ) { ZI_ASSERT(n->fsize==fsize); - ZI_ASSERT(n->fov==fov); + if ( n->fov == fov ) + { + return; + } + else + { + fov[0] = std::max(n->fov[0],fov[0]); + fov[1] = std::max(n->fov[1],fov[1]); + fov[2] = std::max(n->fov[2],fov[2]); + } } - else + + for ( auto& e: n->out ) { - for ( auto& e: n->out ) + e->in_fsize = fsize; + } + n->fov = fov; + n->fsize = fsize; + for ( auto& e: n->in ) + { + if ( e->pool ) { - e->in_fsize = fsize; + vec3i new_fov = fov * e->width; + vec3i new_fsize = e->width * fsize; + fov_pass(e->in, new_fov, new_fsize); } - n->fov = fov; - n->fsize = fsize; - for ( auto& e: n->in ) + else if ( e->crop ) { - if ( e->pool ) - { - vec3i new_fov = fov * e->width; - vec3i new_fsize = e->width * fsize; - fov_pass(e->in, new_fov, new_fsize); - } - else if ( e->crop ) - { - // FoV doesn't change - fov_pass(e->in, fov, fsize + e->width - vec3i::one); - } - else - { - vec3i new_fov = (fov - vec3i::one) * e->stride + e->width; - vec3i new_fsize = (e->width-vec3i::one) * e->in_stride + fsize; - fov_pass(e->in, new_fov, new_fsize); - } + // FoV doesn't change + fov_pass(e->in, fov, fsize + e->width - vec3i::one); + } + else + { + vec3i new_fov = (fov - vec3i::one) * e->stride + e->width; + vec3i new_fsize = (e->width-vec3i::one) * e->in_stride + fsize; + fov_pass(e->in, new_fov, new_fsize); } } }