Skip to content
This repository has been archived by the owner on Feb 6, 2020. It is now read-only.

Commit

Permalink
Merge pull request #46 from seung-lab/extra_supervision
Browse files Browse the repository at this point in the history
fix issue#39, issue#42

Fixes the forward pass and make the code concise.
  • Loading branch information
xiuliren committed Feb 2, 2016
2 parents 428f9ef + e6cd2be commit 3d98ca6
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 37 deletions.
25 changes: 13 additions & 12 deletions python/front_end/zsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,21 +56,22 @@ def __init__(self, config, pars, sample_id, net, \
outsz, setsz_in, fov, is_forward=is_forward )
self.imgs[name] = self.ins[name].data

print "\ncreate label image class..."
self.lbls = dict()
self.msks = dict()
self.outs = dict()
for name,setsz_out in self.setsz_outs.iteritems():
#Allowing for users to abstain from specifying labels
if not config.has_option(self.sec_name, name):
continue
#Finding the section of the config file
imid = config.getint(self.sec_name, name)
imsec_name = "label%d" % (imid,)
self.outs[name] = ConfigOutputLabel( config, pars, imsec_name, \
outsz, setsz_out, fov)
self.lbls[name] = self.outs[name].data
self.msks[name] = self.outs[name].msk
if not is_forward:
print "\ncreate label image class..."
for name,setsz_out in self.setsz_outs.iteritems():
#Allowing for users to abstain from specifying labels
if not config.has_option(self.sec_name, name):
continue
#Finding the section of the config file
imid = config.getint(self.sec_name, name)
imsec_name = "label%d" % (imid,)
self.outs[name] = ConfigOutputLabel( config, pars, imsec_name, \
outsz, setsz_out, fov)
self.lbls[name] = self.outs[name].data
self.msks[name] = self.outs[name].msk

if not is_forward:
self._prepare_training()
Expand Down
57 changes: 32 additions & 25 deletions src/include/network/parallel/network.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,40 +123,47 @@ class network
#endif


void fov_pass(nnodes* n, const vec3i& fov, const vec3i& fsize )
void fov_pass(nnodes* n, vec3i fov, const vec3i& fsize )
{
if ( n->fov != vec3i::zero )
{
ZI_ASSERT(n->fsize==fsize);
ZI_ASSERT(n->fov==fov);
if ( n->fov == fov )
{
return;
}
else
{
fov[0] = std::max(n->fov[0],fov[0]);
fov[1] = std::max(n->fov[1],fov[1]);
fov[2] = std::max(n->fov[2],fov[2]);
}
}
else

for ( auto& e: n->out )
{
for ( auto& e: n->out )
e->in_fsize = fsize;
}
n->fov = fov;
n->fsize = fsize;
for ( auto& e: n->in )
{
if ( e->pool )
{
e->in_fsize = fsize;
vec3i new_fov = fov * e->width;
vec3i new_fsize = e->width * fsize;
fov_pass(e->in, new_fov, new_fsize);
}
n->fov = fov;
n->fsize = fsize;
for ( auto& e: n->in )
else if ( e->crop )
{
if ( e->pool )
{
vec3i new_fov = fov * e->width;
vec3i new_fsize = e->width * fsize;
fov_pass(e->in, new_fov, new_fsize);
}
else if ( e->crop )
{
// FoV doesn't change
fov_pass(e->in, fov, fsize + e->width - vec3i::one);
}
else
{
vec3i new_fov = (fov - vec3i::one) * e->stride + e->width;
vec3i new_fsize = (e->width-vec3i::one) * e->in_stride + fsize;
fov_pass(e->in, new_fov, new_fsize);
}
// FoV doesn't change
fov_pass(e->in, fov, fsize + e->width - vec3i::one);
}
else
{
vec3i new_fov = (fov - vec3i::one) * e->stride + e->width;
vec3i new_fsize = (e->width-vec3i::one) * e->in_stride + fsize;
fov_pass(e->in, new_fov, new_fsize);
}
}
}
Expand Down

0 comments on commit 3d98ca6

Please sign in to comment.