Commit de635ed: Add files via upload
AlexWang1900 authored Jan 16, 2024 (initial commit, 0 parents)
Showing 18 changed files with 1,376 additions and 0 deletions.
21 changes: 21 additions & 0 deletions LICENSE
MIT License

Copyright (c) 2022 Dominic Rampas

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
67 changes: 67 additions & 0 deletions README.md
## Note:
Code tutorial + implementation tutorial (see the videos below):

<a href="https://www.youtube.com/watch?v=wcqLFDXaDO8">
<img alt="Qries" src="https://user-images.githubusercontent.com/61938694/154516539-90e2d4d0-4383-41f4-ad32-4c6d67bd2442.jpg"
width="300">
</a>

<a href="https://www.youtube.com/watch?v=_Br5WRwUz_U">
<img alt="Qries" src="https://user-images.githubusercontent.com/61938694/154628085-eede604f-442d-4bdb-a1ed-5ad3264e5aa0.jpg"
width="300">
</a>

# VQGAN
Vector Quantized Generative Adversarial Networks (VQGAN) is a generative model for image modeling. It was introduced in [Taming Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2012.09841). The concept builds on two stages. The first stage learns in an autoencoder-like fashion: images are encoded into a low-dimensional latent space, which is then vector-quantized against a learned codebook. The quantized latent vectors are projected back to image space by a decoder. Both encoder and decoder are fully convolutional. The second stage learns a transformer over the latent space: over the course of training, it learns which codebook vectors tend to occur together and which do not. The transformer can then be used autoregressively to generate previously unseen images from the data distribution.
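Putting the two stages in code: below is a minimal sketch (not part of this commit) of how the stage-one modules further down fit together. The hyperparameter values mirror the comments in ```codebook.py```; ```image_channels=3``` and the 256x256 input size are assumptions.

```python
import torch
from argparse import Namespace

from encoder import Encoder      # all three modules are defined in this commit
from decoder import Decoder
from codebook import Codebook

# Assumed settings; 1024/256/0.25 match the comments in codebook.py.
args = Namespace(latent_dim=256, num_codebook_vectors=1024, beta=0.25, image_channels=3)

encoder, codebook, decoder = Encoder(args), Codebook(args), Decoder(args)

imgs = torch.randn(2, 3, 256, 256)      # dummy batch of 256x256 RGB images
z = encoder(imgs)                       # (2, 256, 16, 16): four stride-2 downsamplings
z_q, indices, vq_loss = codebook(z)     # snap each latent vector to its nearest codebook entry
recon = decoder(z_q)                    # (2, 3, 256, 256): reconstruction

# Stage two (not sketched here): a transformer is trained autoregressively
# on `indices`, so that new index grids -- and hence new images -- can be sampled.
```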

## Results for First Stage (Reconstruction):


### Epoch 1:

<img src="https://user-images.githubusercontent.com/61938694/154057590-3f457a92-42dd-4912-bb1e-9278a6ae99cc.jpg" width="500">


### Epoch 50:

<img src="https://user-images.githubusercontent.com/61938694/154057511-266fa6ce-5c45-4660-b669-1dca0841823f.jpg" width="500">



## Results for Second Stage (Generating New Images):
Left: original | middle left: reconstruction | middle right: completion | right: new image
### Epoch 1:

<img src="https://user-images.githubusercontent.com/61938694/154058167-9627c71c-d180-449a-ba18-19a85843cee2.jpg" width="500">

### Epoch 100:

<img src="https://user-images.githubusercontent.com/61938694/154058563-700292b6-8fbb-4ba1-b4d7-5955030e4489.jpg" width="500">

Note: training the model for longer yields better results.

<hr>

## Train VQGAN on your own data:
### Training First Stage
1. (Optional) Configure hyperparameters in ```training_vqgan.py``` (see the sketch after these lists for the kind of settings involved)
2. Set the path to your dataset in ```training_vqgan.py```
3. Run ```python training_vqgan.py```

### Training Second Stage
1. (Optional) Configure hyperparameters in ```training_transformer.py```
2. Set the path to your dataset in ```training_transformer.py```
3. Run ```python training_transformer.py```
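The training scripts themselves are not part of this commit, but the modules in it read their settings off a shared `args` object. Below is a hypothetical sketch of the hyperparameter block you would edit; every flag name is a guess, while the defaults mirror the values in the code comments further down (1024 codebook vectors, 256-dim latents, beta 0.25):

```python
# Hypothetical hyperparameter block (the real training_vqgan.py is not in
# this commit). Field names match what the modules read from `args`;
# --dataset-path is a pure guess.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--latent-dim', type=int, default=256)             # read by encoder.py, decoder.py, codebook.py
parser.add_argument('--num-codebook-vectors', type=int, default=1024)  # read by codebook.py
parser.add_argument('--beta', type=float, default=0.25)                # commitment-loss weight in codebook.py
parser.add_argument('--image-channels', type=int, default=3)           # read by encoder.py, decoder.py, discriminator.py
parser.add_argument('--dataset-path', type=str, default='./data')      # step 2 above; data/readme says "put images here"
args = parser.parse_args()
```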


## Citation
```bibtex
@misc{esser2021taming,
  title={Taming Transformers for High-Resolution Image Synthesis},
  author={Patrick Esser and Robin Rombach and Björn Ommer},
  year={2021},
  eprint={2012.09841},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```
32 changes: 32 additions & 0 deletions codebook.py
import torch
import torch.nn as nn


class Codebook(nn.Module):
    def __init__(self, args):
        super(Codebook, self).__init__()
        self.num_codebook_vectors = args.num_codebook_vectors  # e.g. 1024
        self.latent_dim = args.latent_dim                      # e.g. 256
        self.beta = args.beta                                  # commitment-loss weight, e.g. 0.25

        self.embedding = nn.Embedding(self.num_codebook_vectors, self.latent_dim)  # (1024, 256)
        self.embedding.weight.data.uniform_(-1.0 / self.num_codebook_vectors, 1.0 / self.num_codebook_vectors)

    def forward(self, z):  # z: (B, latent_dim, H, W), e.g. (8, 256, 8, 8)
        z = z.permute(0, 2, 3, 1).contiguous()     # (B, H, W, latent_dim)
        z_flattened = z.view(-1, self.latent_dim)  # (B*H*W, latent_dim)

        # Squared distance from every latent vector to every codebook entry,
        # via the expansion ||z - e||^2 = ||z||^2 + ||e||^2 - 2*z.e
        d = torch.sum(z_flattened**2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight**2, dim=1) - \
            2*(torch.matmul(z_flattened, self.embedding.weight.t()))  # (B*H*W, num_codebook_vectors)

        min_encoding_indices = torch.argmin(d, dim=1)             # nearest codebook entry per latent vector
        z_q = self.embedding(min_encoding_indices).view(z.shape)  # (B, H, W, latent_dim)

        # Codebook loss plus beta-weighted commitment loss
        loss = torch.mean((z_q.detach() - z)**2) + self.beta * torch.mean((z_q - z.detach())**2)

        # Straight-through estimator: forward pass returns z_q,
        # but gradients flow back to z as if quantization were the identity
        z_q = z + (z_q - z).detach()

        z_q = z_q.permute(0, 3, 1, 2)  # back to (B, latent_dim, H, W)

        return z_q, min_encoding_indices, loss
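The one-liner ```z_q = z + (z_q - z).detach()``` is the straight-through estimator, and it is easy to verify in isolation. A quick standalone check (not part of the commit):

```python
# Straight-through check: the forward value equals z_q, but the gradient
# flows to z as if quantization were the identity.
import torch

z = torch.randn(1, 4, requires_grad=True)
z_q = torch.randn(1, 4)                  # stand-in for the codebook lookup
out = z + (z_q - z).detach()

assert torch.allclose(out, z_q)          # forward: out is exactly z_q
out.sum().backward()
assert torch.allclose(z.grad, torch.ones_like(z))  # backward: d(out)/dz = 1
```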
2 changes: 2 additions & 0 deletions data/readme
Put your training images here.

37 changes: 37 additions & 0 deletions decoder.py
import torch.nn as nn
from helper import ResidualBlock, NonLocalBlock, UpSampleBlock, GroupNorm, Swish


class Decoder(nn.Module):
def __init__(self, args):
super(Decoder, self).__init__()
channels = [512, 256, 256, 128, 128]
        attn_resolutions = [8]  # trailing "#16" suggests the original value; with [8] the in-loop attention below never fires, since resolution starts at 16 and only doubles
num_res_blocks = 3
resolution = 16

in_channels = channels[0]
layers = [nn.Conv2d(args.latent_dim, in_channels, 3, 1, 1),
ResidualBlock(in_channels, in_channels),
NonLocalBlock(in_channels),
ResidualBlock(in_channels, in_channels)]

for i in range(len(channels)):
out_channels = channels[i]
for j in range(num_res_blocks):
layers.append(ResidualBlock(in_channels, out_channels))
in_channels = out_channels
if resolution in attn_resolutions:
layers.append(NonLocalBlock(in_channels))
if i != 0:
layers.append(UpSampleBlock(in_channels))
resolution *= 2

layers.append(GroupNorm(in_channels))
layers.append(Swish())
layers.append(nn.Conv2d(in_channels, args.image_channels, 3, 1, 1))
self.model = nn.Sequential(*layers)

def forward(self, x):
return self.model(x)

29 changes: 29 additions & 0 deletions discriminator.py
"""
PatchGAN Discriminator (https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py#L538)
"""

import torch.nn as nn


class Discriminator(nn.Module):
def __init__(self, args, num_filters_last=64, n_layers=3):
super(Discriminator, self).__init__()

layers = [nn.Conv2d(args.image_channels, num_filters_last, 4, 2, 1), nn.LeakyReLU(0.2)]
num_filters_mult = 1

for i in range(1, n_layers + 1):
num_filters_mult_last = num_filters_mult
num_filters_mult = min(2 ** i, 8)
layers += [
nn.Conv2d(num_filters_last * num_filters_mult_last, num_filters_last * num_filters_mult, 4,
2 if i < n_layers else 1, 1, bias=False),
nn.BatchNorm2d(num_filters_last * num_filters_mult),
nn.LeakyReLU(0.2, True)
]

        layers.append(nn.Conv2d(num_filters_last * num_filters_mult, 1, 4, 1, 1))  # 1-channel map of per-patch logits
self.model = nn.Sequential(*layers)

def forward(self, x):
return self.model(x)
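Because every layer is convolutional, the discriminator outputs a spatial map of real/fake logits, one per image patch, rather than a single scalar. A quick shape check (a sketch; the ```args``` namespace is a stand-in):

```python
import torch
from argparse import Namespace
from discriminator import Discriminator

args = Namespace(image_channels=3)
disc = Discriminator(args)
logits = disc(torch.randn(1, 3, 256, 256))
print(logits.shape)  # torch.Size([1, 1, 30, 30]) -- one logit per patch
```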
33 changes: 33 additions & 0 deletions encoder.py
import torch.nn as nn
from helper import ResidualBlock, NonLocalBlock, DownSampleBlock, GroupNorm, Swish


class Encoder(nn.Module):
def __init__(self, args):
super(Encoder, self).__init__()
channels = [128, 128, 128, 256, 256, 512]
        attn_resolutions = [8]  # trailing "#16" suggests the original value; resolution here takes values 256 down to 16, so with [8] the in-loop attention never fires
num_res_blocks = 2
resolution = 256
        layers = [nn.Conv2d(args.image_channels, channels[0], 3, 1, 1)]  # image_channels -> 128
for i in range(len(channels)-1):
in_channels = channels[i]
out_channels = channels[i + 1]
for j in range(num_res_blocks):
layers.append(ResidualBlock(in_channels, out_channels))
in_channels = out_channels
if resolution in attn_resolutions:
layers.append(NonLocalBlock(in_channels))
if i != len(channels)-2:
layers.append(DownSampleBlock(channels[i+1]))
resolution //= 2
        layers.append(ResidualBlock(channels[-1], channels[-1]))  # bottleneck after the U-Net-style first half: residual -> attention -> residual
layers.append(NonLocalBlock(channels[-1]))
layers.append(ResidualBlock(channels[-1], channels[-1]))
layers.append(GroupNorm(channels[-1]))
layers.append(Swish())
layers.append(nn.Conv2d(channels[-1], args.latent_dim, 3, 1, 1))
self.model = nn.Sequential(*layers)

def forward(self, x):
return self.model(x)
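With four ```DownSampleBlock```s between the six channel stages, the encoder divides the spatial resolution by 16. A quick trace (a sketch; the ```args``` values are stand-ins, and the 8x8 shapes in the codebook.py comments suggest the author trained on 128x128 inputs):

```python
import torch
from argparse import Namespace
from encoder import Encoder

args = Namespace(image_channels=3, latent_dim=256)
enc = Encoder(args)
z = enc(torch.randn(1, 3, 256, 256))
print(z.shape)  # torch.Size([1, 256, 16, 16]) -- 256 / 2^4 = 16
```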
109 changes: 109 additions & 0 deletions helper.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class GroupNorm(nn.Module):
def __init__(self, channels):
super(GroupNorm, self).__init__()
self.gn = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True)

def forward(self, x):
return self.gn(x)


class Swish(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x)


class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super(ResidualBlock, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.block = nn.Sequential(
GroupNorm(in_channels),
Swish(),
nn.Conv2d(in_channels, out_channels, 3, 1, 1),
GroupNorm(out_channels),
Swish(),
nn.Conv2d(out_channels, out_channels, 3, 1, 1)
)

if in_channels != out_channels:
self.channel_up = nn.Conv2d(in_channels, out_channels, 1, 1, 0)

def forward(self, x):
if self.in_channels != self.out_channels:
return self.channel_up(x) + self.block(x)
else:
return x + self.block(x)


class UpSampleBlock(nn.Module):
def __init__(self, channels):
super(UpSampleBlock, self).__init__()
self.conv = nn.Conv2d(channels, channels, 3, 1, 1)

def forward(self, x):
x = F.interpolate(x, scale_factor=2.0)
return self.conv(x)


class DownSampleBlock(nn.Module):
def __init__(self, channels):
super(DownSampleBlock, self).__init__()
self.conv = nn.Conv2d(channels, channels, 3, 2, 0)

def forward(self, x):
        pad = (0, 1, 0, 1)  # pad right and bottom by 1 so the stride-2, kernel-3 conv halves H and W exactly
x = F.pad(x, pad, mode="constant", value=0)
return self.conv(x)


class NonLocalBlock(nn.Module):
def __init__(self, channels):
super(NonLocalBlock, self).__init__()
self.in_channels = channels

self.gn = GroupNorm(channels)
self.q = nn.Conv2d(channels, channels, 1, 1, 0)
self.k = nn.Conv2d(channels, channels, 1, 1, 0)
self.v = nn.Conv2d(channels, channels, 1, 1, 0)
self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0)

def forward(self, x):
h_ = self.gn(x)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)

b, c, h, w = q.shape

q = q.reshape(b, c, h*w)
q = q.permute(0, 2, 1)
k = k.reshape(b, c, h*w)
v = v.reshape(b, c, h*w)

        attn = torch.bmm(q, k)          # (b, hw, hw): pairwise dot products between spatial positions
        attn = attn * (int(c)**(-0.5))  # scale by 1/sqrt(channels)
        attn = F.softmax(attn, dim=2)
        attn = attn.permute(0, 2, 1)    # transpose so the bmm below sums over the softmax axis

        A = torch.bmm(v, attn)          # attention-weighted sum of the values
        A = A.reshape(b, c, h, w)

return x + A
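The ```bmm```/```permute``` sequence in ```NonLocalBlock``` is ordinary scaled dot-product self-attention over the ```h*w``` spatial positions, wrapped in a residual connection, so the block preserves the input shape. A minimal check (a sketch, not part of the commit):

```python
import torch
from helper import NonLocalBlock

block = NonLocalBlock(64).eval()
x = torch.randn(2, 64, 8, 8)
with torch.no_grad():
    out = block(x)
print(out.shape)  # torch.Size([2, 64, 8, 8]) -- attention preserves shape
```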