Commit 5ca2bd9

Adds code
1 parent 37f7fce commit 5ca2bd9

6 files changed: +1041 −2 lines changed

README.md (+72 −2)
# Deep Generative Models for Distribution-Preserving Lossy Compression

<p align='center'>
<img src='figs/visuals.jpeg' width='440'/>
</p>

### [[Paper]](https://arxiv.org/abs/1805.11057) [[Citation]](#citation)
PyTorch implementation of **Deep Generative Models for Distribution-Preserving Lossy Compression** (NIPS 2018), a framework that unifies generative models and lossy compression. The resulting models behave like generative models at zero bitrate, almost perfectly reconstruct the training data at sufficiently high bitrates, and smoothly interpolate between generation and reconstruction at intermediate bitrates (cf. the figure above; the numbers indicate the rate in bits per pixel).
## Prerequisites

- Python 3 (tested with Python 3.6.4)
- PyTorch (version 0.4.1)
- [tensorboardX](https://github.com/lanpa/tensorboardX)
## Training

The training procedure consists of two steps:

1. Learn a generative model of the data.
2. Learn a rate-constrained encoder and a stochastic mapping into the latent space of the fixed generative model by minimizing distortion.

The `train.py` script performs both of these steps.
To learn the generative model, we consider [Wasserstein GAN with gradient penalty (WGAN-GP)](https://arxiv.org/abs/1704.00028), [Wasserstein Autoencoder (WAE)](https://arxiv.org/abs/1711.01558), and a combination of the two termed Wasserstein++. The following examples show how to train these models as in the experiments in the paper, using the [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) data set (see `train.py` for a description of the flags).
WGAN-GP:

    python train.py --dataset celeba --dataroot /path/to/traindata/ --testroot /path/to/testdata/ --cuda --nz 128 \
        --sigmasqz 1.0 --lr_eg 0.0001 --lr_di 0.0001 --beta1 0.5 --beta2 0.9 --niter 165 --check_every 100 \
        --workers 6 --outf /path/to/results/ --batchSize 64 --test_every 100 --addsamples 10000 --manualSeed 321 \
        --wganloss
35+
36+
WAE:
37+
38+
python train.py --dataset celeba --dataroot /path/to/traindata/ --testroot /path/to/testdata/ --cuda --nz 128 \
39+
--sigmasqz 1.0 --lr_eg 0.001 --niter 55 --decay_steps 30 50 --decay_gamma 0.4 --check_every 100 \
40+
--workers 8 --recloss --mmd --bnz --outf /path/to/results/ --lbd 100 --batchSize 256 --detenc --useenc \
41+
--test_every 20 --addsamples 10000 --manualSeed 321
42+
43+
Wasserstein++:
44+
45+
python train.py --dataset celeba --dataroot /path/to/traindata/ --testroot /path/to/testdata/ --cuda --nz 128 \
46+
--sigmasqz 1.0 --lr_eg 0.0003 --niter 165 --decay_steps 100 140 --decay_gamma 0.4 --check_every 100 \
47+
--workers 6 --recloss --mmd --bnz --outf /path/to/results/ --lbd 100 --batchSize 256 --detenc --useenc \
48+
--test_every 20 --addsamples 10000 --manualSeed 321 --wganloss --useencdist --lbd_di 0.000025 --intencprior
To learn the rate-constrained encoder and the stochastic mapping, run the following (parameters again for the experiment on the CelebA data set):

    python train.py --dataset celeba --dataroot /path/to/traindata/ --testroot /path/to/testdata/ --cuda --nz 128 \
        --sigmasqz 1.0 --lr_eg 0.001 --niter 55 --decay_steps 30 50 --decay_gamma 0.4 --check_every 100 \
        --workers 6 --recloss --mmd --bnz --batchSize 256 --useenc --comp --freezedec --test_every 100 \
        --addsamples 10000 --manualSeed 321 --outf /path/to/results/ --netG /path/to/trained/generator \
        --nresenc 2 --lbd 300 --ncenc 8
Here, `--ncenc` determines the number of channels at the encoder output (and hence the bitrate) and `--lbd` determines the regularization strength of the MMD penalty on the latent space (which has to be adapted as a function of the bitrate).
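
As a back-of-the-envelope rate estimate (a sketch, not part of the repo): with the 16x downsampling in `models.py`, the quantized representation has `ncenc` channels of spatial size (H/16)x(W/16), and fixed-rate coding spends log2(#levels) bits per symbol. The binary levels below are the defaults from `scalar_quantizer.py`; `train.py` may configure a different number of levels.

```python
import math

def bits_per_pixel(ncenc, num_levels, img_h=64, img_w=64):
    # Encoder output: ncenc channels at (img_h//16) x (img_w//16) (see models.py),
    # each symbol coded with log2(num_levels) bits (fixed-rate coding assumed).
    num_symbols = ncenc * (img_h // 16) * (img_w // 16)
    return num_symbols * math.log2(num_levels) / (img_h * img_w)

print(bits_per_pixel(ncenc=8, num_levels=2))  # 0.03125 bpp for the command above
```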
In the paper we also consider the [LSUN bedrooms](https://github.com/fyu/lsun) data set. We provide the flag `--lsun_custom_split`, which splits off 10k samples from the LSUN training set (the LSUN testing set is too small to compute the FID score used to assess sample quality). Otherwise, training on the LSUN data set is as outlined above (with different parameters).
62+
63+
64+
## Citation
65+
66+
If you use this code for your research, please cite this paper:
67+
68+
@inproceedings{tschannen2018deep,
69+
Author = {Tschannen, Michael and Agustsson, Eirikur and Lucic, Mario},
70+
Title = {Deep Generative Models for Distribution-Preserving Lossy Compression},
71+
Booktitle = {Advances in Neural Information Processing Systems (NIPS)},
72+
Year = {2018}}

figs/visuals.jpeg (480 KB)
models.py (+210)
```python
import math

import torch
from torch import nn
from torch.autograd import Variable

from resblock import BasicBlock
from scalar_quantizer import quantize


# Encoder and stochastic function (B in the paper)
class _netE(nn.Module):
    def __init__(self, nc, nz, ngf, kernel=2, padding=1, img_width=64, img_height=64,
                 quant_levels=None, do_comp=False, ncenc=8, nresenc=0, detenc=False,
                 noisedelta=0.5, bnz=False, ngpu=1):
        super(_netE, self).__init__()
        self.ngpu = ngpu
        self.detenc = detenc or not do_comp
        self.noisedelta = noisedelta
        self.nfmodelz = math.ceil(nz / ((img_height//16) * (img_width//16))) + ncenc
        self.ncenc = ncenc

        model_down_list = [
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ngf, kernel, 2, padding, bias=False),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.Conv2d(ngf, ngf * 2, kernel, 2, padding, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.Conv2d(ngf * 2, ngf * 4, kernel, 2, padding, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.Conv2d(ngf * 4, ngf * 8, kernel, 2, padding, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True)
        ]
        # state size. (ngf*8) x 4 x 4

        # quantize if in compression mode
        if do_comp:
            model_down_list += [
                nn.Conv2d(ngf * 8, ncenc, 3, 1, 1, bias=True),
                quantize(quant_levels)
            ]

        self.model_down = nn.Sequential(*model_down_list)

        # stochastic function mapping the compressed representation to the
        # latent space of the generator (B in the paper)
        if do_comp:
            model_z_list = [
                nn.ConvTranspose2d(ncenc, ngf * 8, 3, 1, 1, bias=True) if detenc
                else nn.ConvTranspose2d(self.nfmodelz, ngf * 8, 3, 1, 1, bias=True)
            ]
        else:
            model_z_list = []

        if nresenc > 0:
            model_z_list += [BasicBlock(ngf * 8, ngf * 8) for _ in range(nresenc)]

        model_z_list += [nn.Conv2d(ngf * 8, nz, (img_height//16, img_width//16), 1, 0, bias=False)]

        # batchnorm to facilitate prior matching
        if bnz:
            model_z_list += [nn.BatchNorm2d(nz)]

        self.model_z = nn.Sequential(*model_z_list)

    def forward(self, input):
        use_cuda = isinstance(input.data, torch.cuda.FloatTensor)
        if use_cuda and self.ngpu > 1:
            out_down = nn.parallel.data_parallel(self.model_down, input, range(self.ngpu))
        else:
            out_down = self.model_down(input)

        if not self.detenc:
            # feed noise of appropriate dimension when using the stochastic function
            out_down_pad_size = list(out_down.size())
            out_down_pad_size[1] = self.nfmodelz - self.ncenc
            out_down_pad = torch.zeros(out_down_pad_size)
            out_down_pad.uniform_(-self.noisedelta, self.noisedelta)
            if use_cuda:
                out_down_pad = out_down_pad.cuda()
            out_down = torch.cat([out_down, Variable(out_down_pad)], 1)

        if use_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.model_z, out_down, range(self.ngpu))
        else:
            output = self.model_z(out_down)

        return output


# Standard DCGAN-type generator/decoder
class _netG(nn.Module):
    def __init__(self, nc, nz, ngf, kernel=2, padding=1, output_padding=0,
                 img_width=64, img_height=64, nresdec=0, ngpu=1):
        super(_netG, self).__init__()
        self.ngpu = ngpu

        # input is z, going into a convolution
        main_list = [nn.ConvTranspose2d(nz, ngf * 8, (img_height//16, img_width//16), 1, 0, bias=False),
                     nn.BatchNorm2d(ngf * 8),
                     nn.ReLU(True)]

        if nresdec > 0:
            main_list += [BasicBlock(ngf * 8, ngf * 8) for _ in range(nresdec)]

        main_list += [
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, kernel, 2, padding, output_padding, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, kernel, 2, padding, output_padding, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, kernel, 2, padding, output_padding, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, ngf, kernel, 2, padding, output_padding, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.Conv2d(ngf, nc, 3, 1, 1, bias=True),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        ]

        self.main = nn.Sequential(*main_list)

    def forward(self, input):
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)

        return output


# MLP discriminator in z-space
class _netDz(nn.Module):
    def __init__(self, nz, ndf=512, ndl=5, ngpu=0, avbtrick=False, sigmasq=1):
        super(_netDz, self).__init__()
        self.ngpu = ngpu
        self.avbtrick = avbtrick
        self.sigmasqz = sigmasq
        self.nz = nz

        # (ndl-2) hidden layers of width ndf between the input and output layers
        layers = [[nn.Linear(ndf, ndf), nn.ReLU(True)] for _ in range(ndl - 2)]

        layers = [nn.Linear(nz, ndf), nn.ReLU(True)] \
            + sum(layers, []) \
            + [nn.Linear(ndf, 1)]

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)

        # Nowozin trick from the WAE paper, only valid for a Gaussian prior
        if self.avbtrick:
            output = output - torch.norm(input, p=2, dim=1, keepdim=True)**2 / 2 / self.sigmasqz \
                - 0.5 * math.log(2 * math.pi) \
                - 0.5 * self.nz * math.log(self.sigmasqz)

        return output.view(-1, 1).squeeze(1)


# DCGAN-style discriminator in image space
class _netDim(nn.Module):
    def __init__(self, nc=3, ndf=64, kernel=2, padding=1, img_width=64, img_height=64, ngpu=1):
        super(_netDim, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 3, 1, 1, bias=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf, kernel, 2, padding, bias=False),
            nn.LayerNorm([ndf, img_height//2, img_width//2]),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, kernel, 2, padding, bias=False),
            nn.LayerNorm([ndf * 2, img_height//4, img_width//4]),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, kernel, 2, padding, bias=False),
            nn.LayerNorm([ndf * 4, img_height//8, img_width//8]),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, kernel, 2, padding, bias=False),
            nn.LayerNorm([ndf * 8, img_height//16, img_width//16]),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, (img_height//16, img_width//16), 1, 0, bias=False),
        )

    def forward(self, input):
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)

        return output.view(-1, 1)
```
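
A minimal smoke test wiring the encoder and generator together (a hypothetical sketch, not part of the repo; `kernel=4, padding=1` match the shape comments in the code rather than the defaults, and the binary quantization levels are the defaults from `scalar_quantizer.py`):

```python
import torch
from models import _netE, _netG

# Hypothetical configuration: 64x64 RGB images, 128-d latent space, 8-channel code.
netE = _netE(nc=3, nz=128, ngf=64, kernel=4, padding=1, do_comp=True,
             ncenc=8, nresenc=2, quant_levels=[-1.0, 1.0], detenc=False)
netG = _netG(nc=3, nz=128, ngf=64, kernel=4, padding=1)

x = torch.randn(4, 3, 64, 64)    # batch of 64x64 RGB images in [-1, 1]
z = netE(x)                      # quantized code mapped into the latent space
x_hat = netG(z)                  # reconstruction/generation (Tanh output)
print(z.shape, x_hat.shape)      # (4, 128, 1, 1), (4, 3, 64, 64)
```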

resblock.py (+38)
```python
import torch.nn as nn


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
```
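
As instantiated in `models.py`, the block always keeps the width constant with the identity shortcut; a quick shape check (hypothetical values):

```python
import torch
from resblock import BasicBlock

block = BasicBlock(512, 512)   # identity shortcut, stride 1, as used in models.py
h = torch.randn(2, 512, 4, 4)
print(block(h).shape)          # torch.Size([2, 512, 4, 4])
```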

scalar_quantizer.py (+31)
```python
import torch
import torch.nn as nn
from torch.autograd import Variable


class quantize(nn.Module):
    def __init__(self, levels=[-1.0, 1.0], sigma=1.0):
        super(quantize, self).__init__()
        self.levels = levels
        self.sigma = sigma

    def forward(self, input):
        levels = input.data.new(self.levels)
        xsize = list(input.size())

        # Compute the differentiable soft-quantized version
        input = input.view(*(xsize + [1]))
        level_var = Variable(levels, requires_grad=False)
        dist = torch.pow(input - level_var, 2)
        output = torch.sum(level_var * nn.functional.softmax(-self.sigma * dist, dim=-1), dim=-1)

        # Compute the hard quantization (invisible to autograd)
        _, symbols = torch.min(dist.data, dim=-1, keepdim=True)
        for _ in range(len(xsize)):
            levels.unsqueeze_(0)
        levels = levels.expand(*(xsize + [len(self.levels)]))

        quant = levels.gather(-1, symbols.long()).squeeze_(dim=-1)

        # Replace the activations in the soft variable with the hard-quantized version
        output.data = quant

        return output
```
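
A quick sanity check of the soft-to-hard behavior (a sketch with arbitrary values): the forward pass returns the hard nearest-level assignment, while gradients flow through the softmax-weighted soft assignment.

```python
import torch
from scalar_quantizer import quantize

q = quantize(levels=[-1.0, 0.0, 1.0], sigma=1.0)
x = torch.tensor([[-0.9, -0.2, 0.3, 1.4]], requires_grad=True)

y = q(x)
print(y.data)        # hard nearest-level values: [-1., 0., 0., 1.]

y.sum().backward()   # gradients come from the differentiable soft version
print(x.grad)
```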
