Commit

Add BatchNorm.
chrischute committed Dec 9, 2018
1 parent a5a13c2 commit 73922d0
Showing 5 changed files with 131 additions and 68 deletions.
117 changes: 83 additions & 34 deletions models/real_nvp/coupling_layer.py
@@ -3,7 +3,7 @@

from enum import IntEnum
from models.resnet import ResNet
from util import checkerboard_mask, channel_wise_mask
from util import BatchNormStats2d, checkerboard_mask


class MaskType(IntEnum):
@@ -28,49 +28,98 @@ def __init__(self, in_channels, mid_channels, num_blocks, mask_type, reverse_mask
self.mask_type = mask_type
self.reverse_mask = reverse_mask

# Build neural network for scale and translate
# Build scale and translate network
if self.mask_type == MaskType.CHANNEL_WISE:
in_channels //= 2
self.st_net = ResNet(in_channels, mid_channels, 2 * in_channels,
num_blocks=num_blocks, kernel_size=3, padding=1)

# Learnable scale for s
self.scale = nn.utils.weight_norm(Scalar())

def forward(self, x, sldj=None, reverse=True):
# Get scale and translate factors
b = self._get_mask(x)
x_b = x * b
st = self.st_net(x_b, b)
s, t = st.chunk(2, dim=1)
s = self.scale(torch.tanh(s))
s = s * (1 - b)
t = t * (1 - b)

# Scale and translate
if reverse:
inv_exp_s = s.mul(-1).exp()
if torch.isnan(inv_exp_s).any():
raise RuntimeError('Scale factor has NaN entries')
x = x_b + inv_exp_s * ((1 - b) * x - t)
else:
exp_s = s.exp()
if torch.isnan(exp_s).any():
raise RuntimeError('Scale factor has NaN entries')
x = x_b + (1 - b) * (x * exp_s + t)
# Out batch norm
self.norm = BatchNormStats2d(in_channels)
self.use_norm = True

# Add log-determinant of the Jacobian
sldj += s.view(s.size(0), -1).sum(-1)

return x, sldj

def _get_mask(self, x):
def forward(self, x, sldj=None, reverse=True):
if self.mask_type == MaskType.CHECKERBOARD:
mask = checkerboard_mask(x.size(2), x.size(3), self.reverse_mask, device=x.device)
elif self.mask_type == MaskType.CHANNEL_WISE:
mask = channel_wise_mask(x.size(1), self.reverse_mask, device=x.device)
# Checkerboard mask
b = checkerboard_mask(x.size(2), x.size(3), self.reverse_mask, device=x.device)
x_b = x * b
st = self.st_net(x_b)
s, t = st.chunk(2, dim=1)
s = self.scale(torch.tanh(s))
s = s * (1 - b)
t = t * (1 - b)

# Scale and translate
if reverse:
if self.use_norm:
m, v = self.norm(x * (1 - b), training=False)
log_v = v.log()
x = x * (.5 * log_v * (1 - b)).exp() + m * (1 - b)

inv_exp_s = s.mul(-1).exp()
if torch.isnan(inv_exp_s).any():
raise RuntimeError('Scale factor has NaN entries')
x = x * inv_exp_s - t
else:
exp_s = s.exp()
if torch.isnan(exp_s).any():
raise RuntimeError('Scale factor has NaN entries')
x = (x + t) * exp_s

# Add log-determinant of the Jacobian
sldj += s.view(s.size(0), -1).sum(-1)

if self.use_norm:
m, v = self.norm(x * (1 - b), self.training)
log_v = v.log()
x = (x - m * (1 - b)) * (-.5 * log_v * (1 - b)).exp()
sldj -= (.5 * log_v * (1 - b)).view(log_v.size(0), -1).sum(-1)
else:
raise ValueError('Mask type must be MaskType.CHECKERBOARD or MaskType.CHANNEL_WISE')
# Channel-wise mask
if self.reverse_mask:
x_id, x_change = x.chunk(2, dim=1)
else:
x_change, x_id = x.chunk(2, dim=1)

st = self.st_net(x_id)
s, t = st.chunk(2, dim=1)
s = self.scale(torch.tanh(s))

# Scale and translate
if reverse:
if self.use_norm:
m, v = self.norm(x_change, training=False)
log_v = v.log()
x_change = x_change * (.5 * log_v).exp() + m

inv_exp_s = s.mul(-1).exp()
if torch.isnan(inv_exp_s).any():
raise RuntimeError('Scale factor has NaN entries')
x_change = x_change * inv_exp_s - t
else:
exp_s = s.exp()
if torch.isnan(exp_s).any():
raise RuntimeError('Scale factor has NaN entries')
x_change = (x_change + t) * exp_s

# Add log-determinant of the Jacobian
sldj += s.view(s.size(0), -1).sum(-1)

if self.use_norm:
m, v = self.norm(x_change, self.training)
log_v = v.log()
x_change = (x_change - m) * (-.5 * log_v).exp()
sldj -= (.5 * log_v).view(log_v.size(0), -1).sum(-1)

if self.reverse_mask:
x = torch.cat((x_id, x_change), dim=1)
else:
x = torch.cat((x_change, x_id), dim=1)

return mask
return x, sldj


class Scalar(nn.Module):
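A note on the normalization added above: BatchNormStats2d standardizes the changed half of the activations with per-channel statistics, and because that rescaling is part of the invertible map, its log-determinant has to be folded into sldj, which is why the forward pass subtracts 0.5 * log(var) over the normalized elements. A minimal, self-contained sketch of that bookkeeping (the per-element mean/var tensors stand in for BatchNormStats2d's output; the function names are illustrative, not the repository's API):

import torch

def apply_output_norm(x_change, sldj, mean, var):
    # Standardize the changed half: y = (x - mean) / sqrt(var).
    log_v = var.log()
    x_change = (x_change - mean) * (-0.5 * log_v).exp()
    # dy/dx = var^(-1/2), so the log-determinant contribution is
    # -0.5 * sum(log var) over the normalized elements.
    sldj = sldj - (0.5 * log_v).view(log_v.size(0), -1).sum(-1)
    return x_change, sldj

def invert_output_norm(x_change, mean, var):
    # Inverse of apply_output_norm, used on the reverse (sampling) pass.
    log_v = var.log()
    return x_change * (0.5 * log_v).exp() + mean

On the forward pass this normalization is applied after the affine coupling and its contribution is subtracted from sldj; on the reverse pass it is undone first, which matches the ordering in the diff above.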
7 changes: 5 additions & 2 deletions models/resnet/resnet.py
@@ -18,7 +18,8 @@ class ResNet(nn.Module):
"""
def __init__(self, in_channels, mid_channels, out_channels, num_blocks, kernel_size, padding):
super(ResNet, self).__init__()
self.in_conv = WNConv2d(in_channels, mid_channels, kernel_size, padding, bias=True)
self.in_norm = nn.BatchNorm2d(in_channels, affine=False)
self.in_conv = WNConv2d(2 * in_channels, mid_channels, kernel_size, padding, bias=True)
self.in_skip = WNConv2d(mid_channels, mid_channels, kernel_size=1, padding=0, bias=True)

self.blocks = nn.ModuleList([ResidualBlock(mid_channels, mid_channels)
@@ -29,7 +30,9 @@ def __init__(self, in_channels, mid_channels, out_channels, num_blocks, kernel_s
self.out_norm = nn.BatchNorm2d(mid_channels)
self.out_conv = WNConv2d(mid_channels, out_channels, kernel_size=1, padding=0, bias=True)

def forward(self, x, b):
def forward(self, x):
x = self.in_norm(x)
x = F.relu(torch.cat((x, -x), dim=1))
x = self.in_conv(x)
x_skip = self.in_skip(x)

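The ResNet input stage now normalizes with an affine-free BatchNorm and then applies a concatenated ReLU, relu(cat(x, -x)), which keeps both signs of every feature; that is why in_conv now takes 2 * in_channels. A small illustrative module showing the same input stage (plain Conv2d here instead of the repository's WNConv2d; the class name is hypothetical):

import torch
import torch.nn as nn
import torch.nn.functional as F

class ConcatReLUStem(nn.Module):
    # Input-stage sketch: affine-free BatchNorm followed by concat-ReLU.
    def __init__(self, in_channels, mid_channels, kernel_size=3, padding=1):
        super().__init__()
        self.in_norm = nn.BatchNorm2d(in_channels, affine=False)
        self.in_conv = nn.Conv2d(2 * in_channels, mid_channels,
                                 kernel_size, padding=padding)

    def forward(self, x):
        x = self.in_norm(x)
        # Concatenating x with -x before the ReLU doubles the channel count
        # but preserves the sign information a plain ReLU would discard.
        x = F.relu(torch.cat((x, -x), dim=1))
        return self.in_conv(x)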
4 changes: 2 additions & 2 deletions util/__init__.py
@@ -1,4 +1,4 @@
from util.array_util import squeeze_2x2, checkerboard_mask, channel_wise_mask
from util.norm_util import get_norm_layer, get_param_groups, WNConv2d
from util.array_util import squeeze_2x2, checkerboard_mask
from util.norm_util import BatchNormStats2d, get_norm_layer, get_param_groups, WNConv2d
from util.optim_util import bits_per_dim, clip_grad_norm
from util.shell_util import AverageMeter
29 changes: 0 additions & 29 deletions util/array_util.py
@@ -103,32 +103,3 @@ def checkerboard_mask(height, width, reverse=False, dtype=torch.float32,
mask = mask.view(1, 1, height, width)

return mask


def channel_wise_mask(num_channels, reverse=False, dtype=torch.float32,
device=None, requires_grad=False):
"""Get a channel-wise mask. In non-reversed mask, first N/2 channels are 0,
and last N/2 channels are 1.
Args:
num_channels (int): Number of channels in the mask.
reverse (bool): If True, reverse the mask (i.e., make first N/2 channels 1).
Useful for alternating masks in RealNVP.
dtype (torch.dtype): Data type of the tensor.
device (torch.device): Device on which to construct the tensor.
requires_grad (bool): Whether the tensor requires gradient.
Returns:
mask (torch.tensor): channel-wise mask of shape (1, num_channels, 1, 1).
"""
half_channels = num_channels // 2
channel_wise = [int(i < half_channels) for i in range(num_channels)]
mask = torch.tensor(channel_wise, dtype=dtype, device=device, requires_grad=requires_grad)

if reverse:
mask = 1 - mask

# Reshape to (1, num_channels, 1, 1) for broadcasting with tensors of shape (B, C, H, W)
mask = mask.view(1, num_channels, 1, 1)

return mask
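channel_wise_mask can be dropped because the channel-wise coupling branch now splits the tensor with chunk along dim=1, which selects exactly the same halves the 0/1 mask used to. A quick equivalence check (self-contained; the helper name is illustrative):

import torch

def channel_halves(x, reverse=False):
    # Split (N, C, H, W) into (identity, changed) halves along channels,
    # matching the halves a 0/1 channel-wise mask would select.
    first, second = x.chunk(2, dim=1)
    return (first, second) if reverse else (second, first)

x = torch.randn(4, 8, 16, 16)
x_id, x_change = channel_halves(x, reverse=False)

# Non-reversed mask: first C/2 channels are 0, last C/2 channels are 1,
# so x * mask keeps only the identity half.
mask = torch.cat([torch.zeros(4), torch.ones(4)]).view(1, 8, 1, 1)
assert torch.equal(x * mask, torch.cat([torch.zeros_like(x_change), x_id], dim=1))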
42 changes: 41 additions & 1 deletion util/norm_util.py
@@ -1,6 +1,6 @@
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F


def get_norm_layer(norm_type='instance'):
@@ -61,3 +61,43 @@ def forward(self, x):
x = self.conv(x)

return x


class BatchNormStats2d(nn.Module):
"""Compute BatchNorm2d normalization statistics: `mean` and `var`.
Useful for keeping track of the sum of log-determinants of Jacobians in flow models.
Args:
num_features (int): Number of features in the input (i.e., `C` in `(N, C, H, W)`).
eps (float): Added to the denominator for numerical stability.
decay (float): The value used for the running_mean and running_var computation.
Different from conventional momentum, see `nn.BatchNorm2d` for more.
"""
def __init__(self, num_features, eps=1e-5, decay=0.1):
super(BatchNormStats2d, self).__init__()
self.eps = eps

self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
self.decay = decay

def forward(self, x, training):
# Get mean and variance per channel
if training:
channels = x.transpose(0, 1).contiguous().view(x.size(1), -1)
used_mean, used_var = channels.mean(-1), channels.var(-1)
curr_mean, curr_var = used_mean, used_var

# Update variables
self.running_mean = self.running_mean - self.decay * (self.running_mean - curr_mean)
self.running_var = self.running_var - self.decay * (self.running_var - curr_var)
else:
used_mean = self.running_mean
used_var = self.running_var

used_var = used_var + self.eps  # out-of-place so eval mode never mutates running_var

# Reshape to (N, C, H, W)
used_mean = used_mean.view(1, x.size(1), 1, 1).expand_as(x)
used_var = used_var.view(1, x.size(1), 1, 1).expand_as(x)

return used_mean, used_var
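Unlike nn.BatchNorm2d, this module returns the statistics rather than the normalized tensor, so the caller can both normalize and fold -0.5 * log(var) into the sum of log-determinants. A short usage sketch of the class defined above (shapes are illustrative, not the repository's call sites):

import torch

norm = BatchNormStats2d(num_features=6)
x = torch.randn(8, 6, 32, 32)

# Training: per-batch statistics are returned and the running buffers updated.
mean, var = norm(x, training=True)
x_hat = (x - mean) * (-0.5 * var.log()).exp()

# Evaluation / sampling: the running averages are returned instead.
mean, var = norm(x, training=False)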
