-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit de635ed
Showing
18 changed files
with
1,376 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2022 Dominic Rampas | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
## Note: | ||
Code Tutorial + Implementation Tutorial | ||
|
||
<a href="https://www.youtube.com/watch?v=wcqLFDXaDO8"> | ||
<img alt="Qries" src="https://user-images.githubusercontent.com/61938694/154516539-90e2d4d0-4383-41f4-ad32-4c6d67bd2442.jpg" | ||
width="300"> | ||
</a> | ||
|
||
<a href="https://www.youtube.com/watch?v=_Br5WRwUz_U"> | ||
<img alt="Qries" src="https://user-images.githubusercontent.com/61938694/154628085-eede604f-442d-4bdb-a1ed-5ad3264e5aa0.jpg" | ||
width="300"> | ||
</a> | ||
|
||
# VQGAN | ||
Vector Quantized Generative Adversarial Networks (VQGAN) is a generative model for image modeling. It was introduced in [Taming Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2012.09841). The concept is built upon two stages. The first stage learns in an autoencoder-like fashion by encoding images into a low-dimensional latent space, then applying vector quantization by making use of a codebook. Afterwards, the quantized latent vectors are projected back to the original image space by using a decoder. Encoder and Decoder are fully convolutional. The second stage learns a transformer for the latent space. Over the course of training it learns which codebook vectors go along together and which do not. This can then be used in an autoregressive fashion to generate previously unseen images from the data distribution. | ||
|
||
## Results for First Stage (Reconstruction): | ||
|
||
|
||
### 1. Epoch: | ||
|
||
<img src="https://user-images.githubusercontent.com/61938694/154057590-3f457a92-42dd-4912-bb1e-9278a6ae99cc.jpg" width="500"> | ||
|
||
|
||
### 50. Epoch: | ||
|
||
<img src="https://user-images.githubusercontent.com/61938694/154057511-266fa6ce-5c45-4660-b669-1dca0841823f.jpg" width="500"> | ||
|
||
|
||
|
||
## Results for Second Stage (Generating new Images): | ||
Original Left | Reconstruction Middle Left | Completion Middle Right | New Image Right | ||
### 1. Epoch: | ||
|
||
<img src="https://user-images.githubusercontent.com/61938694/154058167-9627c71c-d180-449a-ba18-19a85843cee2.jpg" width="500"> | ||
|
||
### 100. Epoch: | ||
|
||
<img src="https://user-images.githubusercontent.com/61938694/154058563-700292b6-8fbb-4ba1-b4d7-5955030e4489.jpg" width="500"> | ||
|
||
Note: Let the model train for even longer to get better results. | ||
|
||
<hr> | ||
|
||
## Train VQGAN on your own data: | ||
### Training First Stage | ||
1. (optional) Configure Hyperparameters in ```training_vqgan.py``` | ||
2. Set path to dataset in ```training_vqgan.py``` | ||
3. ```python training_vqgan.py``` | ||
|
||
### Training Second Stage | ||
1. (optional) Configure Hyperparameters in ```training_transformer.py``` | ||
2. Set path to dataset in ```training_transformer.py``` | ||
3. ```python training_transformer.py``` | ||
|
||
|
||
## Citation | ||
```bibtex | ||
@misc{esser2021taming, | ||
title={Taming Transformers for High-Resolution Image Synthesis}, | ||
author={Patrick Esser and Robin Rombach and Björn Ommer}, | ||
year={2021}, | ||
eprint={2012.09841}, | ||
archivePrefix={arXiv}, | ||
primaryClass={cs.CV} | ||
} | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import torch | ||
import torch.nn as nn | ||
|
||
|
||
class Codebook(nn.Module):
    """Vector-quantization codebook for VQGAN.

    Snaps each continuous latent vector produced by the encoder to its
    nearest codebook embedding and computes the codebook/commitment loss
    with a straight-through gradient estimator.
    """

    def __init__(self, args):
        super(Codebook, self).__init__()
        # Number of discrete codes and the dimensionality of each code.
        self.num_codebook_vectors = args.num_codebook_vectors
        self.latent_dim = args.latent_dim
        # Weight of the commitment term in the quantization loss.
        self.beta = args.beta

        self.embedding = nn.Embedding(self.num_codebook_vectors, self.latent_dim)
        # Small uniform init, as in the original VQ-VAE implementation.
        self.embedding.weight.data.uniform_(-1.0 / self.num_codebook_vectors, 1.0 / self.num_codebook_vectors)

    def forward(self, z):
        """Quantize latents ``z`` of shape (B, latent_dim, H, W).

        Returns:
            z_q: quantized tensor, same shape as ``z``, with gradients
                 copied straight through from ``z``.
            min_encoding_indices: flat (B*H*W,) indices of chosen codes.
            loss: codebook + beta-weighted commitment loss (scalar).
        """
        # Channels-last so every spatial position is one latent vector.
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.latent_dim)

        # Squared euclidean distance via ||a-b||^2 = ||a||^2 + ||b||^2 - 2ab
        # between every latent vector and every codebook entry.
        d = torch.sum(z_flattened**2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight**2, dim=1) - \
            2*(torch.matmul(z_flattened, self.embedding.weight.t()))

        # Nearest codebook entry per spatial position.
        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)

        # Codebook loss (moves embeddings toward encoder outputs) plus
        # commitment loss (moves encoder outputs toward embeddings).
        loss = torch.mean((z_q.detach() - z)**2) + self.beta * torch.mean((z_q - z.detach())**2)

        # Straight-through estimator: forward pass yields z_q, backward
        # pass routes gradients through z unchanged.
        z_q = z + (z_q - z).detach()

        # Back to channels-first layout.
        z_q = z_q.permute(0, 3, 1, 2)

        return z_q, min_encoding_indices, loss
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
put images here | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import torch.nn as nn | ||
from helper import ResidualBlock, NonLocalBlock, UpSampleBlock, GroupNorm, Swish | ||
|
||
|
||
class Decoder(nn.Module):
    """VQGAN decoder: maps quantized latents back to image space.

    Mirrors the encoder: a res/attn/res bottleneck on the latent grid,
    then upsampling stages of residual blocks, finishing with a 3x3 conv
    to ``args.image_channels`` output channels.
    """

    def __init__(self, args):
        super(Decoder, self).__init__()
        # Channel widths per stage, from the latent bottleneck outward.
        channels = [512, 256, 256, 128, 128]
        # Resolutions at which a NonLocalBlock would be inserted.
        # NOTE(review): resolution starts at 16 and only grows, so 8 is
        # never reached and no attention is added inside the loop
        # (upstream VQGAN uses 16 here) — confirm this is intended.
        attn_resolutions = [8]
        num_res_blocks = 3
        resolution = 16  # assumed spatial size of the latent grid — TODO confirm

        in_channels = channels[0]
        # Bottleneck: project latent_dim up, then res -> attention -> res.
        layers = [nn.Conv2d(args.latent_dim, in_channels, 3, 1, 1),
                  ResidualBlock(in_channels, in_channels),
                  NonLocalBlock(in_channels),
                  ResidualBlock(in_channels, in_channels)]

        for i in range(len(channels)):
            out_channels = channels[i]
            for j in range(num_res_blocks):
                layers.append(ResidualBlock(in_channels, out_channels))
                in_channels = out_channels
                if resolution in attn_resolutions:
                    layers.append(NonLocalBlock(in_channels))
            if i != 0:
                # Every stage after the first doubles spatial resolution.
                layers.append(UpSampleBlock(in_channels))
                resolution *= 2

        # Final norm/activation and projection to image channels.
        layers.append(GroupNorm(in_channels))
        layers.append(Swish())
        layers.append(nn.Conv2d(in_channels, args.image_channels, 3, 1, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Decode a latent batch (B, latent_dim, h, w) to an image batch."""
        return self.model(x)
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
""" | ||
PatchGAN Discriminator (https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py#L538) | ||
""" | ||
|
||
import torch.nn as nn | ||
|
||
|
||
class Discriminator(nn.Module):
    """PatchGAN discriminator (pix2pix/CycleGAN architecture).

    A stack of strided 4x4 convolutions that classifies overlapping image
    patches as real or fake, emitting a single-channel score map rather
    than one scalar per image.
    """

    def __init__(self, args, num_filters_last=64, n_layers=3):
        super(Discriminator, self).__init__()

        # First layer has no normalization, per the reference implementation.
        layers = [nn.Conv2d(args.image_channels, num_filters_last, 4, 2, 1), nn.LeakyReLU(0.2)]

        num_filters_mult = 1
        for layer_idx in range(1, n_layers + 1):
            prev_mult = num_filters_mult
            num_filters_mult = min(2 ** layer_idx, 8)  # cap channel growth at 8x
            # Last inner layer uses stride 1 so the receptive field grows
            # without shrinking the output map further.
            stride = 2 if layer_idx < n_layers else 1
            layers += [
                nn.Conv2d(num_filters_last * prev_mult, num_filters_last * num_filters_mult, 4,
                          stride, 1, bias=False),
                nn.BatchNorm2d(num_filters_last * num_filters_mult),
                nn.LeakyReLU(0.2, True)
            ]

        # Final projection to a 1-channel patch score map.
        layers.append(nn.Conv2d(num_filters_last * num_filters_mult, 1, 4, 1, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import torch.nn as nn | ||
from helper import ResidualBlock, NonLocalBlock, DownSampleBlock, UpSampleBlock, GroupNorm, Swish | ||
|
||
|
||
class Encoder(nn.Module):
    """VQGAN encoder: downsamples an image batch into a latent feature map.

    Residual blocks at each resolution with stride-2 downsampling between
    stages, a res/attn/res bottleneck, then a 3x3 conv to ``latent_dim``.
    """

    def __init__(self, args):
        super(Encoder, self).__init__()
        # Channel widths per stage; adjacent pairs define each stage's blocks.
        channels = [128, 128, 128, 256, 256, 512]
        # Resolutions at which a NonLocalBlock would be inserted.
        # NOTE(review): with a 256 input the loop visits 256..16, so 8 is
        # never hit and no attention is added inside the loop (upstream
        # VQGAN uses 16 here) — confirm this is intended.
        attn_resolutions = [8]
        num_res_blocks = 2
        resolution = 256  # assumed input image size — TODO confirm
        layers = [nn.Conv2d(args.image_channels, channels[0], 3, 1, 1)]  # stem: image -> 128 channels
        for i in range(len(channels)-1):
            in_channels = channels[i]
            out_channels = channels[i + 1]
            for j in range(num_res_blocks):
                layers.append(ResidualBlock(in_channels, out_channels))
                in_channels = out_channels
                if resolution in attn_resolutions:
                    layers.append(NonLocalBlock(in_channels))
            if i != len(channels)-2:
                # Halve spatial resolution between all but the last stage.
                layers.append(DownSampleBlock(channels[i+1]))
                resolution //= 2
        # Bottleneck: res -> attention -> res, then norm/activation/projection.
        layers.append(ResidualBlock(channels[-1], channels[-1]))
        layers.append(NonLocalBlock(channels[-1]))
        layers.append(ResidualBlock(channels[-1], channels[-1]))
        layers.append(GroupNorm(channels[-1]))
        layers.append(Swish())
        layers.append(nn.Conv2d(channels[-1], args.latent_dim, 3, 1, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Encode an image batch (B, image_channels, H, W) to latents."""
        return self.model(x)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
|
||
class GroupNorm(nn.Module):
    """Group normalization with the settings used throughout this model.

    Fixed at 32 groups (so ``channels`` must be divisible by 32), eps 1e-6,
    with learnable affine scale and shift.
    """

    def __init__(self, channels):
        super(GroupNorm, self).__init__()
        self.gn = nn.GroupNorm(32, channels, 1e-6, True)

    def forward(self, x):
        return self.gn(x)
|
||
|
||
class Swish(nn.Module):
    """Swish (SiLU) activation: ``f(x) = x * sigmoid(x)``."""

    def forward(self, x):
        return torch.sigmoid(x) * x
|
||
|
||
class ResidualBlock(nn.Module):
    """Pre-norm residual block: two GroupNorm/Swish/3x3-conv stacks plus a skip.

    When the input and output channel counts differ, the skip path goes
    through a 1x1 conv so the shapes match for the addition.
    """

    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.block = nn.Sequential(
            GroupNorm(in_channels),
            Swish(),
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            GroupNorm(out_channels),
            Swish(),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1)
        )

        # 1x1 projection for the skip path, created only when needed.
        if in_channels != out_channels:
            self.channel_up = nn.Conv2d(in_channels, out_channels, 1, 1, 0)

    def forward(self, x):
        skip = self.channel_up(x) if self.in_channels != self.out_channels else x
        return skip + self.block(x)
|
||
|
||
class UpSampleBlock(nn.Module):
    """Doubles spatial resolution via interpolation, then applies a 3x3 conv."""

    def __init__(self, channels):
        super(UpSampleBlock, self).__init__()
        self.conv = nn.Conv2d(channels, channels, 3, 1, 1)

    def forward(self, x):
        upsampled = F.interpolate(x, scale_factor=2.0)
        return self.conv(upsampled)
|
||
|
||
class DownSampleBlock(nn.Module):
    """Halves spatial resolution with a stride-2 3x3 conv.

    Uses asymmetric (right/bottom only) zero padding instead of the conv's
    own symmetric padding, matching the reference VQGAN implementation.
    """

    def __init__(self, channels):
        super(DownSampleBlock, self).__init__()
        self.conv = nn.Conv2d(channels, channels, 3, 2, 0)

    def forward(self, x):
        # (left, right, top, bottom) — pad one pixel on right and bottom.
        padded = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
|
||
|
||
class NonLocalBlock(nn.Module):
    """Self-attention (non-local) block over spatial positions.

    Single-head scaled dot-product attention where every spatial location
    attends to every other, added back to the input through a residual
    connection.

    Fix: the original code constructed ``self.proj_out`` but never applied
    it, leaving a dead, never-trained parameter and diverging from the
    reference taming-transformers AttnBlock, which projects the attention
    output before the residual add. The projection is now applied.
    """

    def __init__(self, channels):
        super(NonLocalBlock, self).__init__()
        self.in_channels = channels

        self.gn = GroupNorm(channels)
        # 1x1 convs act as per-position linear projections for q/k/v.
        self.q = nn.Conv2d(channels, channels, 1, 1, 0)
        self.k = nn.Conv2d(channels, channels, 1, 1, 0)
        self.v = nn.Conv2d(channels, channels, 1, 1, 0)
        # Output projection applied to the attention result.
        self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0)

    def forward(self, x):
        # Pre-norm, then project to queries/keys/values.
        h_ = self.gn(x)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        b, c, h, w = q.shape

        # Flatten spatial dims: q -> (b, hw, c); k, v -> (b, c, hw).
        q = q.reshape(b, c, h*w)
        q = q.permute(0, 2, 1)
        k = k.reshape(b, c, h*w)
        v = v.reshape(b, c, h*w)

        # Scaled dot-product attention; softmax over key positions.
        attn = torch.bmm(q, k)
        attn = attn * (int(c)**(-0.5))
        attn = F.softmax(attn, dim=2)
        attn = attn.permute(0, 2, 1)

        # Weighted sum of values: (b, c, hw_k) x (b, hw_k, hw_q) -> (b, c, hw_q).
        A = torch.bmm(v, attn)
        A = A.reshape(b, c, h, w)

        # Project the attention output, then add the residual.
        return x + self.proj_out(A)
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
Oops, something went wrong.