PyTorch Implementation of RealNVP

A basic PyTorch implementation of RealNVP for higher-dimensional images, based on the paper Density estimation using Real NVP (Dinh et al.).

The pseudocode for the RealNVP class is:

Preprocess() # dequantization and logit trick from RealNVP (Dinh et al.) Section 4.1 (see utils.py)

for _ in range(4):
  AffineCouplingWithCheckerboard() # Figure 3 (left) in Dinh et al.
  ActNorm() # described in Glow (Kingma & Dhariwal) Section 3.1
Squeeze() # [b, c, h, w] --> [b, c*4, h//2, w//2]

for _ in range(3):
  AffineCouplingWithChannel() # Figure 3 (right) in Dinh et al.
  ActNorm()
Unsqueeze() # [b, c*4, h//2, w//2] --> [b, c, h, w]

for _ in range(3):
  AffineCouplingWithCheckerboard()
  ActNorm()
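The class pseudocode leaves its building blocks abstract. Below is a minimal NumPy sketch of them, assuming NCHW arrays: the dequantization plus logit preprocessing together with its log-determinant, the checkerboard and channel-wise masks, Squeeze/Unsqueeze, and ActNorm with data-dependent initialization. The function names, α = 0.05, and the symmetric squashing α + (1 − 2α)·p are assumptions for illustration, not code copied from this repository's utils.py; in the actual model these would be PyTorch modules with learnable ActNorm parameters.

```python
import numpy as np


def preprocess(x, alpha=0.05, rng=None):
    """Dequantize pixels in [0, 255] and apply the logit trick (assumed variant).

    Returns y = logit(alpha + (1 - 2 * alpha) * (x + u) / 256) with u ~ U[0, 1),
    plus the per-example log|det dy/dx| summed over all pixels.
    """
    rng = np.random.default_rng() if rng is None else rng
    p = (x + rng.uniform(size=x.shape)) / 256.0    # dequantized, in (0, 1)
    q = alpha + (1.0 - 2.0 * alpha) * p            # squashed into [alpha, 1 - alpha]
    y = np.log(q) - np.log1p(-q)                   # logit(q)
    # Elementwise dy/dx = (1 - 2 * alpha) / 256 / (q * (1 - q)).
    log_det = np.log1p(-2.0 * alpha) - np.log(256.0) - np.log(q) - np.log1p(-q)
    return y, log_det.reshape(x.shape[0], -1).sum(axis=1)


def checkerboard_mask(h, w, flip=False):
    """[1, 1, h, w] mask with 1 where row + col is odd (swapped if flip=True)."""
    m = (np.arange(h)[:, None] + np.arange(w)[None, :]) % 2
    m = m.astype(np.float64).reshape(1, 1, h, w)
    return 1.0 - m if flip else m


def channel_mask(c, flip=False):
    """[1, c, 1, 1] mask with 1 on the first c // 2 channels (swapped if flip=True)."""
    m = np.zeros((1, c, 1, 1))
    m[:, : c // 2] = 1.0
    return 1.0 - m if flip else m


def squeeze(x):
    """[b, c, h, w] -> [b, c*4, h//2, w//2]: each 2x2 patch becomes 4 channels."""
    b, c, h, w = x.shape
    x = x.reshape(b, c, h // 2, 2, w // 2, 2)
    x = x.transpose(0, 1, 3, 5, 2, 4)              # [b, c, 2, 2, h//2, w//2]
    return x.reshape(b, c * 4, h // 2, w // 2)


def unsqueeze(x):
    """[b, c*4, h//2, w//2] -> [b, c, h, w]; exact inverse of squeeze."""
    b, c4, h2, w2 = x.shape
    x = x.reshape(b, c4 // 4, 2, 2, h2, w2)
    x = x.transpose(0, 1, 4, 2, 5, 3)              # [b, c, h//2, 2, w//2, 2]
    return x.reshape(b, c4 // 4, h2 * 2, w2 * 2)


class ActNorm:
    """Per-channel affine map y = x * exp(log_scale) + bias (Glow, Section 3.1).

    Parameters are initialized from the first batch so that its output has
    zero mean and unit variance per channel; in training they would then be learned.
    """

    def __init__(self, eps=1e-6):
        self.eps = eps
        self.log_scale = None
        self.bias = None

    def forward(self, x):
        if self.log_scale is None:                 # data-dependent initialization
            mean = x.mean(axis=(0, 2, 3), keepdims=True)
            std = x.std(axis=(0, 2, 3), keepdims=True)
            self.log_scale = -np.log(std + self.eps)
            self.bias = -mean * np.exp(self.log_scale)
        y = x * np.exp(self.log_scale) + self.bias
        h, w = x.shape[2:]
        log_det = h * w * self.log_scale.sum()     # identical for every example
        return y, np.full(x.shape[0], log_det)
```

Because Squeeze is a pure permutation of entries, it contributes nothing to the log-determinant; only the preprocessing, the couplings, and ActNorm do.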

The pseudocode for the coupling layers is:

ResnetBlock: n_filters
  h = x
  h = conv2d(n_filters, n_filters, (1,1), stride=1, padding=0)(h)
  h = relu(h)
  h = conv2d(n_filters, n_filters, (3,3), stride=1, padding=1)(h)
  h = relu(h)
  h = conv2d(n_filters, n_filters, (1,1), stride=1, padding=0)(h)
  return h + x

SimpleResnet: n_filters = 128, n_blocks = 8, n_out
  conv2d(in_channels, n_filters, (3,3), stride=1, padding=1)
  apply n_blocks (8) ResnetBlocks with n_filters (128)
  relu()
  conv2d(n_filters, n_out, (3,3), stride=1, padding=1)

AffineCoupling(x, mask):
  x_ = x * mask
  log_scale, t = torch.chunk(SimpleResnet(x_), 2, dim=1)
  t = t * (1.0 - mask)
  log_scale = log_scale * (1.0 - mask)
  z = x * torch.exp(log_scale) + t
  log_det_jacobian = log_scale
  return z, log_det_jacobian
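As a rough illustration, the coupling-layer pseudocode translates to PyTorch modules along these lines. This is a sketch, not this repository's code: the constructor signatures, the `inverse` method, and summing the log-determinant per example (the pseudocode returns the elementwise `log_scale`) are assumptions. No tanh or other bound is placed on `log_scale`, matching the pseudocode.

```python
import torch
import torch.nn as nn


class ResnetBlock(nn.Module):
    """1x1 -> 3x3 -> 1x1 convolutions with ReLUs, plus an identity skip."""

    def __init__(self, n_filters):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(n_filters, n_filters, kernel_size=1, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(n_filters, n_filters, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(n_filters, n_filters, kernel_size=1, stride=1, padding=0),
        )

    def forward(self, x):
        return self.net(x) + x


class SimpleResnet(nn.Module):
    """conv3x3 -> n_blocks ResnetBlocks -> ReLU -> conv3x3 to n_out channels."""

    def __init__(self, in_channels, n_out, n_filters=128, n_blocks=8):
        super().__init__()
        layers = [nn.Conv2d(in_channels, n_filters, kernel_size=3, stride=1, padding=1)]
        layers += [ResnetBlock(n_filters) for _ in range(n_blocks)]
        layers += [nn.ReLU(), nn.Conv2d(n_filters, n_out, kernel_size=3, stride=1, padding=1)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


class AffineCoupling(nn.Module):
    """Entries where mask == 1 pass through unchanged and condition the network;
    entries where mask == 0 are scaled by exp(log_scale) and shifted by t."""

    def __init__(self, mask, in_channels, n_filters=128, n_blocks=8):
        super().__init__()
        # e.g. a [1, 1, h, w] checkerboard mask or a [1, c, 1, 1] channel mask
        self.register_buffer("mask", mask)
        self.net = SimpleResnet(in_channels, 2 * in_channels, n_filters, n_blocks)

    def _scale_and_shift(self, x):
        log_scale, t = torch.chunk(self.net(x * self.mask), 2, dim=1)
        return log_scale * (1.0 - self.mask), t * (1.0 - self.mask)

    def forward(self, x):
        log_scale, t = self._scale_and_shift(x)
        z = x * torch.exp(log_scale) + t
        log_det = log_scale.flatten(1).sum(dim=1)  # per-example log|det dz/dx|
        return z, log_det

    def inverse(self, z):
        # z * mask == x * mask, so the network sees exactly the same input.
        log_scale, t = self._scale_and_shift(z)
        return (z - t) * torch.exp(-log_scale)
```

Because the masked entries of `z` equal those of `x`, `inverse` can recompute the same `log_scale` and `t` from `z` alone; this is what lets the flow be run backwards to generate samples without inverting the ResNet itself.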

Results

The dataset used is CIFAR-10, and samples were generated after 5, 10, 15 and 20 epochs of training. The generated CIFAR-10 samples are as follows:

References

  1. fmu2/realNVP: https://github.com/fmu2/realNVP
  2. Implementation of Real NVP in PyTorch: https://github.com/chrischute/real-nvp
  3. CS294-158 Deep Unsupervised Learning (Spring 2020), Assignment 2: https://sites.google.com/view/berkeley-cs294-158-sp20/home