diff --git a/models/real_nvp/coupling_layer.py b/models/real_nvp/coupling_layer.py
index 3641fe7..fc4957d 100644
--- a/models/real_nvp/coupling_layer.py
+++ b/models/real_nvp/coupling_layer.py
@@ -3,7 +3,7 @@
 
 from enum import IntEnum
 from models.resnet import ResNet
-from util import BatchNormStats2d, checkerboard_mask
+from util import checkerboard_mask
 
 
 class MaskType(IntEnum):
@@ -37,10 +37,6 @@ def __init__(self, in_channels, mid_channels, num_blocks, mask_type, reverse_mas
         # Learnable scale for s
         self.scale = nn.utils.weight_norm(Scalar())
 
-        # Out batch norm
-        self.norm = BatchNormStats2d(in_channels)
-        self.use_norm = True
-
     def forward(self, x, sldj=None, reverse=True):
         if self.mask_type == MaskType.CHECKERBOARD:
             # Checkerboard mask
@@ -54,11 +50,6 @@ def forward(self, x, sldj=None, reverse=True):
 
             # Scale and translate
             if reverse:
-                if self.use_norm:
-                    m, v = self.norm(x * (1 - b), training=False)
-                    log_v = v.log()
-                    x = x * (.5 * log_v * (1 - b)).exp() + m * (1 - b)
-
                 inv_exp_s = s.mul(-1).exp()
                 if torch.isnan(inv_exp_s).any():
                     raise RuntimeError('Scale factor has NaN entries')
@@ -71,12 +62,6 @@ def forward(self, x, sldj=None, reverse=True):
 
                 # Add log-determinant of the Jacobian
                 sldj += s.view(s.size(0), -1).sum(-1)
-
-                if self.use_norm:
-                    m, v = self.norm(x * (1 - b), self.training)
-                    log_v = v.log()
-                    x = (x - m * (1 - b)) * (-.5 * log_v * (1 - b)).exp()
-                    sldj -= (.5 * log_v * (1 - b)).view(log_v.size(0), -1).sum(-1)
         else:
             # Channel-wise mask
             if self.reverse_mask:
@@ -90,11 +75,6 @@ def forward(self, x, sldj=None, reverse=True):
 
             # Scale and translate
             if reverse:
-                if self.use_norm:
-                    m, v = self.norm(x_change, training=False)
-                    log_v = v.log()
-                    x_change = x_change * (.5 * log_v).exp() + m
-
                 inv_exp_s = s.mul(-1).exp()
                 if torch.isnan(inv_exp_s).any():
                     raise RuntimeError('Scale factor has NaN entries')
@@ -108,12 +88,6 @@ def forward(self, x, sldj=None, reverse=True):
 
                 # Add log-determinant of the Jacobian
                 sldj += s.view(s.size(0), -1).sum(-1)
-                if self.use_norm:
-                    m, v = self.norm(x_change, self.training)
-                    log_v = v.log()
-                    x_change = (x_change - m) * (-.5 * log_v).exp()
-                    sldj -= (.5 * log_v).view(log_v.size(0), -1).sum(-1)
-
             if self.reverse_mask:
                 x = torch.cat((x_id, x_change), dim=1)
             else:
diff --git a/models/real_nvp/real_nvp.py b/models/real_nvp/real_nvp.py
index 5c6cbd7..acb99a9 100644
--- a/models/real_nvp/real_nvp.py
+++ b/models/real_nvp/real_nvp.py
@@ -106,7 +106,6 @@ def __init__(self, scale_idx, num_scales, in_channels, mid_channels, num_blocks)
             self.next_block = _RealNVP(scale_idx + 1, num_scales, 2 * in_channels, 2 * mid_channels, num_blocks)
 
     def forward(self, x, sldj, reverse=False):
-
         if reverse:
             if not self.is_last_block:
                 # Re-squeeze -> split -> next block
diff --git a/util/__init__.py b/util/__init__.py
index ff75187..580292c 100755
--- a/util/__init__.py
+++ b/util/__init__.py
@@ -1,4 +1,4 @@
 from util.array_util import squeeze_2x2, checkerboard_mask
-from util.norm_util import BatchNormStats2d, get_norm_layer, get_param_groups, WNConv2d
+from util.norm_util import get_norm_layer, get_param_groups, WNConv2d
 from util.optim_util import bits_per_dim, clip_grad_norm
 from util.shell_util import AverageMeter