Commit

Remove bn.
chrischute committed Dec 9, 2018
1 parent 3374f19 commit fc0c170
Showing 3 changed files with 2 additions and 29 deletions.
28 changes: 1 addition & 27 deletions models/real_nvp/coupling_layer.py
@@ -3,7 +3,7 @@
 
 from enum import IntEnum
 from models.resnet import ResNet
-from util import BatchNormStats2d, checkerboard_mask
+from util import checkerboard_mask
 
 
 class MaskType(IntEnum):
@@ -37,10 +37,6 @@ def __init__(self, in_channels, mid_channels, num_blocks, mask_type, reverse_mask):
         # Learnable scale for s
         self.scale = nn.utils.weight_norm(Scalar())
 
-        # Out batch norm
-        self.norm = BatchNormStats2d(in_channels)
-        self.use_norm = True
-
     def forward(self, x, sldj=None, reverse=True):
         if self.mask_type == MaskType.CHECKERBOARD:
             # Checkerboard mask
@@ -54,11 +50,6 @@ def forward(self, x, sldj=None, reverse=True):
 
             # Scale and translate
             if reverse:
-                if self.use_norm:
-                    m, v = self.norm(x * (1 - b), training=False)
-                    log_v = v.log()
-                    x = x * (.5 * log_v * (1 - b)).exp() + m * (1 - b)
-
                 inv_exp_s = s.mul(-1).exp()
                 if torch.isnan(inv_exp_s).any():
                     raise RuntimeError('Scale factor has NaN entries')
@@ -71,12 +62,6 @@ def forward(self, x, sldj=None, reverse=True):
 
                 # Add log-determinant of the Jacobian
                 sldj += s.view(s.size(0), -1).sum(-1)
-
-                if self.use_norm:
-                    m, v = self.norm(x * (1 - b), self.training)
-                    log_v = v.log()
-                    x = (x - m * (1 - b)) * (-.5 * log_v * (1 - b)).exp()
-                    sldj -= (.5 * log_v * (1 - b)).view(log_v.size(0), -1).sum(-1)
         else:
             # Channel-wise mask
             if self.reverse_mask:
@@ -90,11 +75,6 @@ def forward(self, x, sldj=None, reverse=True):
 
             # Scale and translate
             if reverse:
-                if self.use_norm:
-                    m, v = self.norm(x_change, training=False)
-                    log_v = v.log()
-                    x_change = x_change * (.5 * log_v).exp() + m
-
                 inv_exp_s = s.mul(-1).exp()
                 if torch.isnan(inv_exp_s).any():
                     raise RuntimeError('Scale factor has NaN entries')
@@ -108,12 +88,6 @@ def forward(self, x, sldj=None, reverse=True):
                 # Add log-determinant of the Jacobian
                 sldj += s.view(s.size(0), -1).sum(-1)
 
-                if self.use_norm:
-                    m, v = self.norm(x_change, self.training)
-                    log_v = v.log()
-                    x_change = (x_change - m) * (-.5 * log_v).exp()
-                    sldj -= (.5 * log_v).view(log_v.size(0), -1).sum(-1)
-
             if self.reverse_mask:
                 x = torch.cat((x_id, x_change), dim=1)
             else:
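
With the batch-norm statistics removed, both branches of the coupling layer reduce to a plain affine coupling: the changed half of the input is scaled and translated, and only the per-sample sum of s enters the log-determinant. Below is a minimal standalone sketch of that pattern, reusing the NaN check and sldj update visible in the diff. Note that affine_couple is a hypothetical helper, and the exact composition of scale and translation follows one common RealNVP convention rather than anything confirmed by this commit.

import torch

def affine_couple(x, s, t, b, sldj, reverse=False):
    # Apply the transform only where the binary mask b is 0;
    # entries with b == 1 pass through as the identity half.
    s = s * (1 - b)
    t = t * (1 - b)
    if reverse:
        # Inverse pass: undo the scale, then the translation.
        inv_exp_s = s.mul(-1).exp()
        if torch.isnan(inv_exp_s).any():
            raise RuntimeError('Scale factor has NaN entries')
        x = x * inv_exp_s - t
    else:
        # Forward pass: translate, then scale.
        exp_s = s.exp()
        if torch.isnan(exp_s).any():
            raise RuntimeError('Scale factor has NaN entries')
        x = (x + t) * exp_s
        # Log-determinant of the Jacobian: sum s over all non-batch dims.
        sldj += s.view(s.size(0), -1).sum(-1)
    return x, sldj

Since s and t are zeroed on the identity half, those entries are unchanged in both directions, so sldj accumulates contributions only from the transformed entries.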
1 change: 0 additions & 1 deletion models/real_nvp/real_nvp.py
@@ -106,7 +106,6 @@ def __init__(self, scale_idx, num_scales, in_channels, mid_channels, num_blocks):
             self.next_block = _RealNVP(scale_idx + 1, num_scales, 2 * in_channels, 2 * mid_channels, num_blocks)
 
     def forward(self, x, sldj, reverse=False):
-
         if reverse:
             if not self.is_last_block:
                 # Re-squeeze -> split -> next block
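
For context on the recursion above: a 2x2 squeeze turns a (C, H, W) tensor into (4C, H/2, W/2), and splitting off half of those channels as latent variables leaves 2C for the next block, which is why the inner _RealNVP receives 2 * in_channels and 2 * mid_channels. A small shape-bookkeeping sketch, for illustration only (scale_shapes is not code from this repository):

def scale_shapes(in_channels, height, width, num_scales):
    # Yield the (C, H, W) shape seen by each successive _RealNVP block.
    c, h, w = in_channels, height, width
    for scale_idx in range(num_scales):
        yield scale_idx, (c, h, w)
        # squeeze_2x2 trades spatial extent for channels: (4C, H/2, W/2).
        c, h, w = 4 * c, h // 2, w // 2
        # The split factors out half the channels as latents.
        c //= 2

for idx, shape in scale_shapes(3, 32, 32, num_scales=3):
    print(idx, shape)  # (3, 32, 32) -> (6, 16, 16) -> (12, 8, 8)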
2 changes: 1 addition & 1 deletion util/__init__.py
@@ -1,4 +1,4 @@
 from util.array_util import squeeze_2x2, checkerboard_mask
-from util.norm_util import BatchNormStats2d, get_norm_layer, get_param_groups, WNConv2d
+from util.norm_util import get_norm_layer, get_param_groups, WNConv2d
 from util.optim_util import bits_per_dim, clip_grad_norm
 from util.shell_util import AverageMeter
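
The coupling layer above still depends on checkerboard_mask, which survives this cleanup. As a rough, self-contained stand-in showing the kind of mask involved (the real util.checkerboard_mask signature is not shown in this diff, so the helper below is illustrative):

import torch

def make_checkerboard_mask(height, width, reverse=False):
    # Binary mask alternating 0 and 1 in a checkerboard pattern,
    # shaped (1, 1, H, W) so it broadcasts over (N, C, H, W) inputs.
    rows = torch.arange(height).view(-1, 1)
    cols = torch.arange(width).view(1, -1)
    mask = ((rows + cols) % 2).float()
    if reverse:
        mask = 1. - mask
    return mask.view(1, 1, height, width)

b = make_checkerboard_mask(4, 4)
print(b[0, 0])  # rows alternate: [0, 1, 0, 1], [1, 0, 1, 0], ...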
