From de635ed531aed1855dc17981f41c5b5698667f60 Mon Sep 17 00:00:00 2001 From: AlexWang1900 <60679873+AlexWang1900@users.noreply.github.com> Date: Wed, 17 Jan 2024 00:46:24 +0800 Subject: [PATCH] Add files via upload --- LICENSE | 21 +++++ README.md | 67 ++++++++++++++++ codebook.py | 32 ++++++++ data/readme | 2 + decoder.py | 37 +++++++++ discriminator.py | 29 +++++++ encoder.py | 33 ++++++++ helper.py | 109 +++++++++++++++++++++++++ lpips.py | 162 +++++++++++++++++++++++++++++++++++++ lpips_old.py | 149 ++++++++++++++++++++++++++++++++++ mingpt.py | 171 ++++++++++++++++++++++++++++++++++++++++ sample_transformer.py | 49 ++++++++++++ training_transformer.py | 99 +++++++++++++++++++++++ training_vqgan.py | 120 ++++++++++++++++++++++++++++ transformer.py | 132 +++++++++++++++++++++++++++++++ utils.py | 102 ++++++++++++++++++++++++ vgg_lpips/vgg.pth | Bin 0 -> 7289 bytes vqgan.py | 62 +++++++++++++++ 18 files changed, 1376 insertions(+) create mode 100644 LICENSE create mode 100644 README.md create mode 100644 codebook.py create mode 100644 data/readme create mode 100644 decoder.py create mode 100644 discriminator.py create mode 100644 encoder.py create mode 100644 helper.py create mode 100644 lpips.py create mode 100644 lpips_old.py create mode 100644 mingpt.py create mode 100644 sample_transformer.py create mode 100644 training_transformer.py create mode 100644 training_vqgan.py create mode 100644 transformer.py create mode 100644 utils.py create mode 100644 vgg_lpips/vgg.pth create mode 100644 vqgan.py diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..e9fd3ec --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 Dominic Rampas + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..f940351 --- /dev/null +++ b/README.md @@ -0,0 +1,67 @@ +## Note: +Code Tutorial + Implementation Tutorial + + + Qries + + + + Qries + + +# VQGAN +Vector Quantized Generative Adversarial Networks (VQGAN) is a generative model for image modeling. It was introduced in [Taming Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2012.09841). The concept is build upon two stages. The first stage learns in an autoencoder-like fashion by encoding images into a low-dimensional latent space, then applying vector quantization by making use of a codebook. Afterwards, the quantized latent vectors are projected back to the original image space by using a decoder. Encoder and Decoder are fully convolutional. 
The second stage learns a transformer over the latent space: over the course of training it learns which codebook vectors tend to occur together and which do not. The transformer can then be used autoregressively to generate previously unseen images from the data distribution. + +## Results for First Stage (Reconstruction): + + +### Epoch 1: + + + + +### Epoch 50: + + + + + +## Results for Second Stage (Generating new Images): +Original Left | Reconstruction Middle Left | Completion Middle Right | New Image Right +### Epoch 1: + + + +### Epoch 100: + + + +Note: training the model for longer gives better results. +
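+## Sampling sketch:
+After both stages are trained, generating new images boils down to the few lines below. This is a condensed sketch of ```sample_transformer.py``` from this patch; the checkpoint name is a placeholder and ```args``` is assumed to be the argparse namespace built in ```training_transformer.py```.
+```python
+import torch
+from torchvision import utils as vutils
+from transformer import VQGANTransformer
+
+transformer = VQGANTransformer(args).to("cuda")
+transformer.load_state_dict(torch.load("checkpoints/transformer_99.pt"))  # placeholder checkpoint
+
+start_indices = torch.zeros((4, 0)).long().to("cuda")   # empty context -> sample from scratch
+sos_tokens = torch.zeros((4, 1)).long().to("cuda")      # start-of-sequence token (sos_token = 0)
+indices = transformer.sample(start_indices, sos_tokens, steps=64)  # one code per cell of the 8x8 latent grid
+imgs = transformer.z_to_image(indices)                  # decode the codes with the frozen VQGAN
+vutils.save_image(imgs.mul(0.5).add(0.5), "sample.png", nrow=4)
+```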
+ +## Train VQGAN on your own data: +### Training First Stage +1. (optional) Configure Hyperparameters in ```training_vqgan.py``` +2. Set path to dataset in ```training_vqgan.py``` +3. ```python training_vqgan.py``` + +### Training Second Stage +1. (optional) Configure Hyperparameters in ```training_transformer.py``` +2. Set path to dataset in ```training_transformer.py``` +3. ```python training_transformer.py``` + + +## Citation +```bibtex +@misc{esser2021taming, + title={Taming Transformers for High-Resolution Image Synthesis}, + author={Patrick Esser and Robin Rombach and Björn Ommer}, + year={2021}, + eprint={2012.09841}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} +``` diff --git a/codebook.py b/codebook.py new file mode 100644 index 0000000..c03b427 --- /dev/null +++ b/codebook.py @@ -0,0 +1,32 @@ +import torch +import torch.nn as nn + + +class Codebook(nn.Module): + def __init__(self, args): + super(Codebook, self).__init__() + self.num_codebook_vectors = args.num_codebook_vectors#1024 + self.latent_dim = args.latent_dim#256 + self.beta = args.beta#0.25 + + self.embedding = nn.Embedding(self.num_codebook_vectors, self.latent_dim)#1024,256 + self.embedding.weight.data.uniform_(-1.0 / self.num_codebook_vectors, 1.0 / self.num_codebook_vectors) + + def forward(self, z):# z_hat:([8, 256, 8, 8]) + z = z.permute(0, 2, 3, 1).contiguous()#([8, 8, 8, 256]) + z_flattened = z.view(-1, self.latent_dim)#([512, 256]) 64,256 + + d = torch.sum(z_flattened**2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight**2, dim=1) - \ + 2*(torch.matmul(z_flattened, self.embedding.weight.t()))#[512, 1024]) + + min_encoding_indices = torch.argmin(d, dim=1)# 512 pick one in 1024 for each in 512 + z_q = self.embedding(min_encoding_indices).view(z.shape)#([8, 8, 8, 256]) + + loss = torch.mean((z_q.detach() - z)**2) + self.beta * torch.mean((z_q - z.detach())**2) + #z_q = z_q.detach() + z_q = z + (z_q - z).detach()#([8, 8, 8, 256]) copy gradients,foward = zq,backward = z + + z_q = z_q.permute(0, 3, 1, 2)#([8, 256, 8, 8]) + + return z_q, min_encoding_indices, loss diff --git a/data/readme b/data/readme new file mode 100644 index 0000000..4443e3c --- /dev/null +++ b/data/readme @@ -0,0 +1,2 @@ +put images here + diff --git a/decoder.py b/decoder.py new file mode 100644 index 0000000..ca65821 --- /dev/null +++ b/decoder.py @@ -0,0 +1,37 @@ +import torch.nn as nn +from helper import ResidualBlock, NonLocalBlock, UpSampleBlock, GroupNorm, Swish + + +class Decoder(nn.Module): + def __init__(self, args): + super(Decoder, self).__init__() + channels = [512, 256, 256, 128, 128] + attn_resolutions = [8]#16 + num_res_blocks = 3 + resolution = 16 + + in_channels = channels[0] + layers = [nn.Conv2d(args.latent_dim, in_channels, 3, 1, 1), + ResidualBlock(in_channels, in_channels), + NonLocalBlock(in_channels), + ResidualBlock(in_channels, in_channels)] + + for i in range(len(channels)): + out_channels = channels[i] + for j in range(num_res_blocks): + layers.append(ResidualBlock(in_channels, out_channels)) + in_channels = out_channels + if resolution in attn_resolutions: + layers.append(NonLocalBlock(in_channels)) + if i != 0: + layers.append(UpSampleBlock(in_channels)) + resolution *= 2 + + layers.append(GroupNorm(in_channels)) + layers.append(Swish()) + layers.append(nn.Conv2d(in_channels, args.image_channels, 3, 1, 1)) + self.model = nn.Sequential(*layers) + + def forward(self, x): + return self.model(x) + diff --git a/discriminator.py b/discriminator.py new file mode 100644 index 0000000..addc605 --- /dev/null 
+++ b/discriminator.py @@ -0,0 +1,29 @@ +""" +PatchGAN Discriminator (https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py#L538) +""" + +import torch.nn as nn + + +class Discriminator(nn.Module): + def __init__(self, args, num_filters_last=64, n_layers=3): + super(Discriminator, self).__init__() + + layers = [nn.Conv2d(args.image_channels, num_filters_last, 4, 2, 1), nn.LeakyReLU(0.2)] + num_filters_mult = 1 + + for i in range(1, n_layers + 1): + num_filters_mult_last = num_filters_mult + num_filters_mult = min(2 ** i, 8) + layers += [ + nn.Conv2d(num_filters_last * num_filters_mult_last, num_filters_last * num_filters_mult, 4, + 2 if i < n_layers else 1, 1, bias=False), + nn.BatchNorm2d(num_filters_last * num_filters_mult), + nn.LeakyReLU(0.2, True) + ] + + layers.append(nn.Conv2d(num_filters_last * num_filters_mult, 1, 4, 1, 1)) + self.model = nn.Sequential(*layers) + + def forward(self, x): + return self.model(x) diff --git a/encoder.py b/encoder.py new file mode 100644 index 0000000..fb25be9 --- /dev/null +++ b/encoder.py @@ -0,0 +1,33 @@ +import torch.nn as nn +from helper import ResidualBlock, NonLocalBlock, DownSampleBlock, UpSampleBlock, GroupNorm, Swish + + +class Encoder(nn.Module): + def __init__(self, args): + super(Encoder, self).__init__() + channels = [128, 128, 128, 256, 256, 512] + attn_resolutions = [8]#16 + num_res_blocks = 2 + resolution = 256 + layers = [nn.Conv2d(args.image_channels, channels[0], 3, 1, 1)]# 3,128 + for i in range(len(channels)-1): + in_channels = channels[i] + out_channels = channels[i + 1] + for j in range(num_res_blocks): + layers.append(ResidualBlock(in_channels, out_channels)) + in_channels = out_channels + if resolution in attn_resolutions: + layers.append(NonLocalBlock(in_channels)) + if i != len(channels)-2: + layers.append(DownSampleBlock(channels[i+1])) + resolution //= 2 + layers.append(ResidualBlock(channels[-1], channels[-1]))#UNET first half + layers.append(NonLocalBlock(channels[-1])) + layers.append(ResidualBlock(channels[-1], channels[-1])) + layers.append(GroupNorm(channels[-1])) + layers.append(Swish()) + layers.append(nn.Conv2d(channels[-1], args.latent_dim, 3, 1, 1)) + self.model = nn.Sequential(*layers) + + def forward(self, x): + return self.model(x) \ No newline at end of file diff --git a/helper.py b/helper.py new file mode 100644 index 0000000..bea49b5 --- /dev/null +++ b/helper.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class GroupNorm(nn.Module): + def __init__(self, channels): + super(GroupNorm, self).__init__() + self.gn = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True) + + def forward(self, x): + return self.gn(x) + + +class Swish(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class ResidualBlock(nn.Module): + def __init__(self, in_channels, out_channels): + super(ResidualBlock, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.block = nn.Sequential( + GroupNorm(in_channels), + Swish(), + nn.Conv2d(in_channels, out_channels, 3, 1, 1), + GroupNorm(out_channels), + Swish(), + nn.Conv2d(out_channels, out_channels, 3, 1, 1) + ) + + if in_channels != out_channels: + self.channel_up = nn.Conv2d(in_channels, out_channels, 1, 1, 0) + + def forward(self, x): + if self.in_channels != self.out_channels: + return self.channel_up(x) + self.block(x) + else: + return x + self.block(x) + + +class UpSampleBlock(nn.Module): + def __init__(self, channels): + 
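+        # Doubles the spatial resolution: forward() nearest-neighbour upsamples by 2x with
+        # F.interpolate and then smooths the result with a 3x3 convolution (self.conv below).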
super(UpSampleBlock, self).__init__() + self.conv = nn.Conv2d(channels, channels, 3, 1, 1) + + def forward(self, x): + x = F.interpolate(x, scale_factor=2.0) + return self.conv(x) + + +class DownSampleBlock(nn.Module): + def __init__(self, channels): + super(DownSampleBlock, self).__init__() + self.conv = nn.Conv2d(channels, channels, 3, 2, 0) + + def forward(self, x): + pad = (0, 1, 0, 1) + x = F.pad(x, pad, mode="constant", value=0) + return self.conv(x) + + +class NonLocalBlock(nn.Module): + def __init__(self, channels): + super(NonLocalBlock, self).__init__() + self.in_channels = channels + + self.gn = GroupNorm(channels) + self.q = nn.Conv2d(channels, channels, 1, 1, 0) + self.k = nn.Conv2d(channels, channels, 1, 1, 0) + self.v = nn.Conv2d(channels, channels, 1, 1, 0) + self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0) + + def forward(self, x): + h_ = self.gn(x) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, h, w = q.shape + + q = q.reshape(b, c, h*w) + q = q.permute(0, 2, 1) + k = k.reshape(b, c, h*w) + v = v.reshape(b, c, h*w) + + attn = torch.bmm(q, k) + attn = attn * (int(c)**(-0.5)) + attn = F.softmax(attn, dim=2) + attn = attn.permute(0, 2, 1) + + A = torch.bmm(v, attn) + A = A.reshape(b, c, h, w) + + return x + A + + + + + + + + + + + + diff --git a/lpips.py b/lpips.py new file mode 100644 index 0000000..953807d --- /dev/null +++ b/lpips.py @@ -0,0 +1,162 @@ +import os +import torch +import torch.nn as nn +#from torchvision.models import vgg16 +from torchvision import models +from collections import namedtuple +import requests +from tqdm import tqdm +import hashlib + + +URL_MAP = { + "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" +} + +CKPT_MAP = { + "vgg_lpips": "vgg.pth" +} + +MD5_MAP = { + "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" +} + +def download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + +def md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + +def get_ckpt_path(name, root): + assert name in URL_MAP + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path): + print(f"Downloading {name} model from {URL_MAP[name]} to {path}") + download(URL_MAP[name], path) + md5 = md5_hash(path) + assert md5 == MD5_MAP[name], md5 + return path + + +class LPIPS(nn.Module): + # Learned perceptual metric + def __init__(self, use_dropout=True): + super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vg16 features + self.net = vgg16(pretrained=True, requires_grad=False)#weights=VGG16_Weights.DEFAULT + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + self.load_from_pretrained() + for param in self.parameters(): + param.requires_grad = False + + def load_from_pretrained(self, name="vgg_lpips"): + ckpt = get_ckpt_path(name, "vgg_lpips") + self.load_state_dict(torch.load(ckpt, 
map_location=torch.device("cpu")), strict=False) + print("loaded pretrained LPIPS loss from {}".format(ckpt)) + + @classmethod + def from_pretrained(cls, name="vgg_lpips"): + if name != "vgg_lpips": + raise NotImplementedError + model = cls() + ckpt = get_ckpt_path(name) + model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) + return model + + def forward(self, input, target): + in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + feats0, feats1, diffs = {}, {}, {} + lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + for kk in range(len(self.chns)): + feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] + val = res[0] + for l in range(1, len(self.chns)): + val += res[l] + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """ A single linear layer which does a 1x1 conv """ + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + layers = [nn.Dropout(), ] if (use_dropout) else [] + layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] + self.model = nn.Sequential(*layers) + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_FEATURES).features#weights=models.VGG16_Weights.IMAGENET1K_FEATURES #pretrained=pretrained + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) + out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + return out + + +def normalize_tensor(x,eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) + return x/(norm_factor+eps) + + +def spatial_average(x, keepdim=True): + return x.mean([2,3],keepdim=keepdim) diff --git a/lpips_old.py b/lpips_old.py new file mode 100644 index 0000000..3105e0f --- /dev/null +++ b/lpips_old.py @@ -0,0 +1,149 @@ +import os +import torch +import torch.nn as nn +from torchvision.models import 
vgg16 +from collections import namedtuple +import requests +from tqdm import tqdm +import hashlib + + +URL_MAP = { + "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" +} + +CKPT_MAP = { + "vgg_lpips": "vgg.pth" +} + +MD5_MAP = { + "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" +} + +def download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + +def md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + +def get_ckpt_path(name, root): + assert name in URL_MAP + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path): + print(f"Downloading {name} model from {URL_MAP[name]} to {path}") + download(URL_MAP[name], path) + md5 = md5_hash(path) + assert md5 == MD5_MAP[name], md5 + return path + + +class LPIPS(nn.Module): + def __init__(self): + super(LPIPS, self).__init__() + self.scaling_layer = ScalingLayer() + self.channels = [64, 128, 256, 512, 512] + self.vgg = VGG16() + self.lins = nn.ModuleList([ + NetLinLayer(self.channels[0]), + NetLinLayer(self.channels[1]), + NetLinLayer(self.channels[2]), + NetLinLayer(self.channels[3]), + NetLinLayer(self.channels[4]) + ]) + + self.load_from_pretrained() + + for param in self.parameters(): + param.requires_grad = False + + def load_from_pretrained(self, name="vgg_lpips"): + ckpt = get_ckpt_path(name, "vgg_lpips") # add by wang for test + self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=True) + pass + + def forward(self, real_x, fake_x): + features_real = self.vgg(self.scaling_layer(real_x)) + features_fake = self.vgg(self.scaling_layer(fake_x)) + diffs = {} + + for i in range(len(self.channels)): + diffs[i] = (norm_tensor(features_real[i]) - norm_tensor(features_fake[i])) ** 2 + + return sum([spatial_average(self.lins[i].model(diffs[i])) for i in range(len(self.channels))]) + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer("shift", torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer("scale", torch.Tensor([.458, .448, .450])[None, :, None, None]) + + def forward(self, x): + return (x - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + def __init__(self, in_channels, out_channels=1): + super(NetLinLayer, self).__init__() + self.model = nn.Sequential( + nn.Dropout(), + nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False) + ) + + +class VGG16(nn.Module): + def __init__(self): + super(VGG16, self).__init__() + vgg_pretrained_features = vgg16(pretrained=True).features + slices = [vgg_pretrained_features[i] for i in range(30)] + self.slice1 = nn.Sequential(*slices[0:4]) + self.slice2 = nn.Sequential(*slices[4:9]) + self.slice3 = nn.Sequential(*slices[9:16]) + self.slice4 = nn.Sequential(*slices[16:23]) + self.slice5 = nn.Sequential(*slices[23:30]) + + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x): + h = self.slice1(x) + h_relu1 = h + h = self.slice2(h) + h_relu2 = h + h = self.slice3(h) + h_relu3 = h + h = self.slice4(h) + h_relu4 = h + h = self.slice5(h) + h_relu5 = h + vgg_outputs = namedtuple("VGGOutputs", ['relu1_2', 
'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) + return vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) + + +def norm_tensor(x): + """ + Normalize images by their length to make them unit vector? + :param x: batch of images + :return: normalized batch of images + """ + norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) + return x / (norm_factor + 1e-10) + + +def spatial_average(x): + """ + imgs have: batch_size x channels x width x height --> average over width and height channel + :param x: batch of images + :return: averaged images along width and height + """ + return x.mean([2, 3], keepdim=True) diff --git a/mingpt.py b/mingpt.py new file mode 100644 index 0000000..af88d5c --- /dev/null +++ b/mingpt.py @@ -0,0 +1,171 @@ +""" +taken from: https://github.com/karpathy/minGPT/ +GPT model: +- the initial stem consists of a combination of token encoding and a positional encoding +- the meat of it is a uniform sequence of Transformer blocks + - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block + - all blocks feed into a central residual pathway similar to resnets +- the final decoder is a linear projection into a vanilla Softmax classifier +""" + +import math +import torch +import torch.nn as nn +from torch.nn import functional as F + + +class GPTConfig: + """ base GPT config, params common to all GPT versions """ + embd_pdrop = 0.1 + resid_pdrop = 0.1 + attn_pdrop = 0.1 + + def __init__(self, vocab_size, block_size, **kwargs): + self.vocab_size = vocab_size + self.block_size = block_size + for k, v in kwargs.items(): + setattr(self, k, v) + + +class CausalSelfAttention(nn.Module): + """ + A vanilla multi-head masked self-attention layer with a projection at the end. + It is possible to use torch.nn.MultiheadAttention here but I am including an + explicit implementation here to show that there is nothing too scary here. 
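+    The causal mask is a lower-triangular block_size x block_size matrix, so position t can
+    only attend to positions <= t; if config.n_unmasked is set, the first n_unmasked positions
+    form an unmasked prefix whose tokens may attend to each other freely.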
+ """ + + def __init__(self, config): + super().__init__() + assert config.n_embd % config.n_head == 0 + # key, query, value projections for all heads + self.key = nn.Linear(config.n_embd, config.n_embd) + self.query = nn.Linear(config.n_embd, config.n_embd) + self.value = nn.Linear(config.n_embd, config.n_embd) + # regularization + self.attn_drop = nn.Dropout(config.attn_pdrop) + self.resid_drop = nn.Dropout(config.resid_pdrop) + # output projection + self.proj = nn.Linear(config.n_embd, config.n_embd) + # causal mask to ensure that attention is only applied to the left in the input sequence + mask = torch.tril(torch.ones(config.block_size, + config.block_size)) + if hasattr(config, "n_unmasked"): + mask[:config.n_unmasked, :config.n_unmasked] = 1 + self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size)) + self.n_head = config.n_head + + def forward(self, x, layer_past=None): + B, T, C = x.size() + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + present = torch.stack((k, v)) + if layer_past is not None: + past_key, past_value = layer_past + k = torch.cat((past_key, k), dim=-2) + v = torch.cat((past_value, v), dim=-2) + + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + if layer_past is None: + att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf')) + + att = F.softmax(att, dim=-1) + att = self.attn_drop(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_drop(self.proj(y)) + return y, present # TODO: check that this does not break anything + + +class Block(nn.Module): + """ an unassuming Transformer block """ + + def __init__(self, config): + super().__init__() + self.ln1 = nn.LayerNorm(config.n_embd) + self.ln2 = nn.LayerNorm(config.n_embd) + self.attn = CausalSelfAttention(config) + self.mlp = nn.Sequential( + nn.Linear(config.n_embd, 4 * config.n_embd), + nn.GELU(), # nice + nn.Linear(4 * config.n_embd, config.n_embd), + nn.Dropout(config.resid_pdrop), + ) + + def forward(self, x, layer_past=None, return_present=False): + # TODO: check that training still works + if return_present: + assert not self.training + # layer past: tuple of length two with B, nh, T, hs + attn, present = self.attn(self.ln1(x), layer_past=layer_past) + + x = x + attn + x = x + self.mlp(self.ln2(x)) + if layer_past is not None or return_present: + return x, present + return x + + +class GPT(nn.Module): + """ the full GPT language model, with a context size of block_size """ + + def __init__(self, vocab_size, block_size, n_layer=12, n_head=8, n_embd=256, + embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0): + super().__init__() + config = GPTConfig(vocab_size=vocab_size, block_size=block_size, + embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop, + n_layer=n_layer, n_head=n_head, n_embd=n_embd, + n_unmasked=n_unmasked) + # input embedding stem + self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd) + self.pos_emb = nn.Parameter(torch.zeros(1, 
config.block_size, config.n_embd)) # 512 x 1024 + self.drop = nn.Dropout(config.embd_pdrop) + # transformer + self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)]) + # decoder head + self.ln_f = nn.LayerNorm(config.n_embd) + self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + self.block_size = config.block_size + self.apply(self._init_weights) + self.config = config + + def get_block_size(self): + return self.block_size + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def forward(self, idx, embeddings=None): + token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector + + if embeddings is not None: # prepend explicit embeddings + token_embeddings = torch.cat((embeddings, token_embeddings), dim=1) + + t = token_embeddings.shape[1] + assert t <= self.block_size, "Cannot forward, model block size is exhausted." + position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector + x = self.drop(token_embeddings + position_embeddings) + x = self.blocks(x) + x = self.ln_f(x) + logits = self.head(x) + + return logits, None + + + + + + diff --git a/sample_transformer.py b/sample_transformer.py new file mode 100644 index 0000000..39a688f --- /dev/null +++ b/sample_transformer.py @@ -0,0 +1,49 @@ +import os +import argparse +import torch +from torchvision import utils as vutils +from transformer import VQGANTransformer +from tqdm import tqdm + + +parser = argparse.ArgumentParser(description="VQGAN") +parser.add_argument('--latent-dim', type=int, default=256, help='Latent dimension n_z.') +parser.add_argument('--image-size', type=int, default=128, help='Image height and width.)') +parser.add_argument('--num-codebook-vectors', type=int, default=1024, help='Number of codebook vectors.') +parser.add_argument('--beta', type=float, default=0.25, help='Commitment loss scalar.') +parser.add_argument('--image-channels', type=int, default=3, help='Number of channels of images.') +parser.add_argument('--dataset-path', type=str, default='./data', help='Path to data.') +parser.add_argument('--checkpoint-path', type=str, default='./checkpoints/last_ckpt.pt', + help='Path to checkpoint.') +parser.add_argument('--device', type=str, default="cuda", help='Which device the training is on') +parser.add_argument('--batch-size', type=int, default=32, help='Input batch size for training.') +parser.add_argument('--epochs', type=int, default=100, help='Number of epochs to train.') +parser.add_argument('--learning-rate', type=float, default=2.25e-05, help='Learning rate.') +parser.add_argument('--beta1', type=float, default=0.5, help='Adam beta param.') +parser.add_argument('--beta2', type=float, default=0.9, help='Adam beta param.') +parser.add_argument('--disc-start', type=int, default=10000, help='When to start the discriminator.') +parser.add_argument('--disc-factor', type=float, default=1., help='Weighting factor for the Discriminator.') +parser.add_argument('--l2-loss-factor', type=float, default=1., + help='Weighting factor for reconstruction loss.') +parser.add_argument('--perceptual-loss-factor', type=float, default=1., + help='Weighting factor for perceptual loss.') + +parser.add_argument('--pkeep', type=float, default=0.5, help='Percentage for how 
much latent codes to keep.') +parser.add_argument('--sos-token', type=int, default=0, help='Start of Sentence token.') + +args = parser.parse_args() +args.dataset_path = './data/FFHQ_128' +args.checkpoint_path = "./checkpoints/vqgan_epoch_99.pt" + +n = 100 +transformer = VQGANTransformer(args).to("cuda") +transformer.load_state_dict(torch.load(os.path.join("checkpoints", "transformer_99.pt"))) +print("Loaded state dict of Transformer") + +for i in tqdm(range(n)): + start_indices = torch.zeros((4, 0)).long().to("cuda")#([1, 0]) + sos_tokens = torch.ones(start_indices.shape[0], 1) * 0#[1,1] + sos_tokens = sos_tokens.long().to("cuda") + sample_indices = transformer.sample(start_indices, sos_tokens, steps=64)#256 + sampled_imgs = transformer.z_to_image(sample_indices) + vutils.save_image(sampled_imgs.mul(0.5).add(0.5), os.path.join("results", "transformer", f"transformer_{i}.png"), nrow=4)#,format="png" diff --git a/training_transformer.py b/training_transformer.py new file mode 100644 index 0000000..fb72f6f --- /dev/null +++ b/training_transformer.py @@ -0,0 +1,99 @@ +import os +import numpy as np +from tqdm import tqdm +import argparse +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import utils as vutils +from transformer import VQGANTransformer +from utils import load_data, plot_images + + +class TrainTransformer: + def __init__(self, args): + self.model = VQGANTransformer(args).to(device=args.device) + self.optim = self.configure_optimizers() + + self.train(args) + + def configure_optimizers(self): + decay, no_decay = set(), set() + whitelist_weight_modules = (nn.Linear, ) + blacklist_weight_modules = (nn.LayerNorm, nn.Embedding) + + for mn, m in self.model.transformer.named_modules(): + for pn, p in m.named_parameters(): + fpn = f"{mn}.{pn}" if mn else pn + + if pn.endswith("bias"): + no_decay.add(fpn) + + elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules): + decay.add(fpn) + + elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules): + no_decay.add(fpn) + + no_decay.add("pos_emb") + + param_dict = {pn: p for pn, p in self.model.transformer.named_parameters()} + + optim_groups = [ + {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01}, + {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, + ] + + optimizer = torch.optim.AdamW(optim_groups, lr=4.5e-06, betas=(0.9, 0.95)) + return optimizer + + def train(self, args): + train_dataset = load_data(args) + for epoch in range(args.epochs): + with tqdm(range(len(train_dataset))) as pbar: + for i, imgs in zip(pbar, train_dataset):#([20, 3, 128, 128]) + self.optim.zero_grad() + imgs = imgs.to(device=args.device)#([20, 3, 128, 128]) + logits, targets = self.model(imgs)#([20, 64, 1024]),([20, 64]) + loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1)) + loss.backward() + self.optim.step() + pbar.set_postfix(Transformer_Loss=np.round(loss.cpu().detach().numpy().item(), 4)) + pbar.update(0) + log, sampled_imgs = self.model.log_images(imgs[0][None]) + vutils.save_image(sampled_imgs, os.path.join("results", f"transformer_{epoch}.jpg"), nrow=4) + #plot_images(log) + if epoch >80: + torch.save(self.model.state_dict(), os.path.join("checkpoints", f"transformer_{epoch}.pt")) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="VQGAN") + parser.add_argument('--latent-dim', type=int, default=256, help='Latent dimension n_z.') + parser.add_argument('--image-size', type=int, 
default=128, help='Image height and width.)') + parser.add_argument('--num-codebook-vectors', type=int, default=1024, help='Number of codebook vectors.') + parser.add_argument('--beta', type=float, default=0.25, help='Commitment loss scalar.') + parser.add_argument('--image-channels', type=int, default=3, help='Number of channels of images.') + parser.add_argument('--dataset-path', type=str, default='./data', help='Path to data.') + parser.add_argument('--checkpoint-path', type=str, default='./checkpoints/last_ckpt.pt', help='Path to checkpoint.') + parser.add_argument('--device', type=str, default="cuda", help='Which device the training is on') + parser.add_argument('--batch-size', type=int, default=32, help='Input batch size for training.')#20 + parser.add_argument('--epochs', type=int, default=100, help='Number of epochs to train.') + parser.add_argument('--learning-rate', type=float, default=2.25e-05, help='Learning rate.') + parser.add_argument('--beta1', type=float, default=0.5, help='Adam beta param.') + parser.add_argument('--beta2', type=float, default=0.9, help='Adam beta param.') + parser.add_argument('--disc-start', type=int, default=10000, help='When to start the discriminator.') + parser.add_argument('--disc-factor', type=float, default=1., help='Weighting factor for the Discriminator.') + parser.add_argument('--l2-loss-factor', type=float, default=1., help='Weighting factor for reconstruction loss.') + parser.add_argument('--perceptual-loss-factor', type=float, default=1., help='Weighting factor for perceptual loss.') + + parser.add_argument('--pkeep', type=float, default=0.5, help='Percentage for how much latent codes to keep.') + parser.add_argument('--sos-token', type=int, default=0, help='Start of Sentence token.') + + args = parser.parse_args() + args.dataset_path = './data/FFHQ_128'#r"C:\Users\dome\datasets\flowers" + args.checkpoint_path = './checkpoints/vqgan_epoch_99.pt'#r".\checkpoints\vqgan_last_ckpt.pt" + + train_transformer = TrainTransformer(args) + + diff --git a/training_vqgan.py b/training_vqgan.py new file mode 100644 index 0000000..2c70799 --- /dev/null +++ b/training_vqgan.py @@ -0,0 +1,120 @@ +import os +import argparse +from tqdm import tqdm +import numpy as np +import torch +import torch.nn.functional as F +from torchvision import utils as vutils +from discriminator import Discriminator +from lpips import LPIPS +from vqgan import VQGAN +from utils import load_data, weights_init + + +class TrainVQGAN: + def __init__(self, args): + self.vqgan = VQGAN(args).to(device=args.device) + self.discriminator = Discriminator(args).to(device=args.device) + self.discriminator.apply(weights_init) + self.perceptual_loss = LPIPS().eval().to(device=args.device) + self.opt_vq, self.opt_disc = self.configure_optimizers(args) + + self.prepare_training() + + self.train(args) + + def configure_optimizers(self, args): + lr = args.learning_rate + opt_vq = torch.optim.Adam( + list(self.vqgan.encoder.parameters()) + + list(self.vqgan.decoder.parameters()) + + list(self.vqgan.codebook.parameters()) + + list(self.vqgan.quant_conv.parameters()) + + list(self.vqgan.post_quant_conv.parameters()), + lr=lr, eps=1e-08, betas=(args.beta1, args.beta2) + ) + opt_disc = torch.optim.Adam(self.discriminator.parameters(), + lr=lr, eps=1e-08, betas=(args.beta1, args.beta2)) + + return opt_vq, opt_disc + + @staticmethod + def prepare_training(): + os.makedirs("results", exist_ok=True) + os.makedirs("checkpoints", exist_ok=True) + + def train(self, args): + train_dataset = load_data(args) + 
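+        # Each step updates two networks with separate Adam optimizers: the VQGAN, trained on
+        # L1 reconstruction + LPIPS perceptual + codebook losses plus an adversarial term that is
+        # scaled by the adaptive weight lambda and switched on after disc_start steps, and the
+        # PatchGAN discriminator, trained with a hinge loss on real vs. reconstructed images.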
steps_per_epoch = len(train_dataset)#11667 + for epoch in range(args.epochs):#100 + with tqdm(range(len(train_dataset))) as pbar: + for i, imgs in zip(pbar, train_dataset): + imgs = imgs.to(device=args.device)#[8, 3, 128, 128]) + decoded_images, _, q_loss = self.vqgan(imgs) + + disc_real = self.discriminator(imgs)#([8, 1, 14, 14]) + disc_fake = self.discriminator(decoded_images)#([8, 1, 14, 14]) + + disc_factor = self.vqgan.adopt_weight(args.disc_factor, epoch*steps_per_epoch+i, threshold=args.disc_start)#1000 + + perceptual_loss = self.perceptual_loss(imgs, decoded_images) + rec_loss = torch.abs(imgs - decoded_images) + perceptual_rec_loss = args.perceptual_loss_factor * perceptual_loss + args.rec_loss_factor * rec_loss + perceptual_rec_loss = perceptual_rec_loss.mean() + g_loss = -torch.mean(disc_fake) + + λ = self.vqgan.calculate_lambda(perceptual_rec_loss, g_loss) + vq_loss = perceptual_rec_loss + q_loss + disc_factor * λ * g_loss + + d_loss_real = torch.mean(F.relu(1. - disc_real)) + d_loss_fake = torch.mean(F.relu(1. + disc_fake)) + gan_loss = disc_factor * 0.5*(d_loss_real + d_loss_fake) + + self.opt_vq.zero_grad() + vq_loss.backward(retain_graph=True) #retain_graph=True + + self.opt_disc.zero_grad() + gan_loss.backward() + + self.opt_vq.step() + self.opt_disc.step() + + if i % 1000 == 0: + with torch.no_grad(): + real_fake_images = torch.cat((imgs[:4].mul(0.5).add(0.5), decoded_images.mul(0.5).add(0.5)[:4])) + vutils.save_image(real_fake_images, os.path.join("results", f"{epoch}_{i}.jpg"), nrow=4) + + pbar.set_postfix( + VQ_Loss=np.round(vq_loss.cpu().detach().numpy().item(), 5), + GAN_Loss=np.round(gan_loss.cpu().detach().numpy().item(), 3) + ) + pbar.update(0) + torch.save(self.vqgan.state_dict(), os.path.join("checkpoints", f"vqgan_epoch_{epoch}.pt")) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="VQGAN") + parser.add_argument('--latent-dim', type=int, default=256, help='Latent dimension n_z (default: 256)') + parser.add_argument('--image-size', type=int, default=128, help='Image height and width (default: 256)') + parser.add_argument('--num-codebook-vectors', type=int, default=1024, help='Number of codebook vectors (default: 256)') + parser.add_argument('--beta', type=float, default=0.25, help='Commitment loss scalar (default: 0.25)') + parser.add_argument('--image-channels', type=int, default=3, help='Number of channels of images (default: 3)') + parser.add_argument('--dataset-path', type=str, default='/data', help='Path to data (default: /data)') + parser.add_argument('--device', type=str, default="cuda", help='Which device the training is on') + parser.add_argument('--batch-size', type=int, default=16, help='Input batch size for training (default: 6)') + parser.add_argument('--epochs', type=int, default=100, help='Number of epochs to train (default: 50)') + parser.add_argument('--learning-rate', type=float, default=2.25e-05, help='Learning rate (default: 0.0002)') + parser.add_argument('--beta1', type=float, default=0.5, help='Adam beta param (default: 0.0)') + parser.add_argument('--beta2', type=float, default=0.9, help='Adam beta param (default: 0.999)') + parser.add_argument('--disc-start', type=int, default=10000, help='When to start the discriminator (default: 0)') + parser.add_argument('--disc-factor', type=float, default=1., help='') + parser.add_argument('--rec-loss-factor', type=float, default=1., help='Weighting factor for reconstruction loss.') + parser.add_argument('--perceptual-loss-factor', type=float, default=1., 
help='Weighting factor for perceptual loss.') + + args = parser.parse_args() + args.dataset_path = './data/FFHQ_128'#r"C:\Users\dome\datasets\flowers" + + train_vqgan = TrainVQGAN(args) + + + diff --git a/transformer.py b/transformer.py new file mode 100644 index 0000000..e558623 --- /dev/null +++ b/transformer.py @@ -0,0 +1,132 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from mingpt import GPT +from vqgan import VQGAN + + +class VQGANTransformer(nn.Module): + def __init__(self, args): + super(VQGANTransformer, self).__init__() + + self.sos_token = args.sos_token #0 + + self.vqgan = self.load_vqgan(args) + + transformer_config = { + "vocab_size": args.num_codebook_vectors, + "block_size": 512, + "n_layer": 24, + "n_head": 16, + "n_embd": 1024 + } + self.transformer = GPT(**transformer_config) + + self.pkeep = args.pkeep + + @staticmethod + def load_vqgan(args): + model = VQGAN(args) + model.load_checkpoint(args.checkpoint_path) #add by wang for test + model = model.eval() + return model + + @torch.no_grad() + def encode_to_z(self, x): + quant_z, indices, _ = self.vqgan.encode(x) + indices = indices.view(quant_z.shape[0], -1) + return quant_z, indices + + @torch.no_grad() + def z_to_image(self, indices, p1=8, p2=8):#16,16 + ix_to_vectors = self.vqgan.codebook.embedding(indices).reshape(indices.shape[0], p1, p2, 256) + ix_to_vectors = ix_to_vectors.permute(0, 3, 1, 2) + image = self.vqgan.decode(ix_to_vectors) + return image + + def forward(self, x):#([20, 3, 128, 128]) + _, indices = self.encode_to_z(x)#([20, 64]) + + sos_tokens = torch.ones(x.shape[0], 1) * self.sos_token#([20, 1]) value = 0 + sos_tokens = sos_tokens.long().to("cuda") + + mask = torch.bernoulli(self.pkeep * torch.ones(indices.shape, device=indices.device))#([20, 64]) 0.5 + mask = mask.round().to(dtype=torch.int64) + random_indices = torch.randint_like(indices, self.transformer.config.vocab_size)#([20, 64]) + new_indices = mask * indices + (1 - mask) * random_indices# mixing + + new_indices = torch.cat((sos_tokens, new_indices), dim=1)#([20, 65]) start of sentence. 
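+        # Next-token prediction with teacher forcing: the transformer is fed the SOS token plus
+        # the (partially randomized) codes shifted right by one, and is trained to predict the
+        # original, uncorrupted codebook indices at every position.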
+ + target = indices#([20, 64]) + + logits, _ = self.transformer(new_indices[:, :-1])#([20, 64, 1024]) + + return logits, target + + def top_k_logits(self, logits, k): + v, ix = torch.topk(logits, k) + out = logits.clone() + out[out < v[..., [-1]]] = -float("inf") + return out + + @torch.no_grad() + def sample(self, x, c, steps, temperature=1.0, top_k=100):#[1,0],([1, 1]),256,100 + self.transformer.eval() + x = torch.cat((c, x), dim=1)#[1,1] + for k in range(steps): + logits, _ = self.transformer(x)#[1,1,1024] + logits = logits[:, -1, :] / temperature#[1,1024] + + if top_k is not None: + logits = self.top_k_logits(logits, top_k)#[1,1024] + + probs = F.softmax(logits, dim=-1)#([1, 1024]) + + ix = torch.multinomial(probs, num_samples=1)#[1,1] + + x = torch.cat((x, ix), dim=1) + #([1, 257]) + x = x[:, c.shape[1]:]#[1,256] + self.transformer.train() + return x + + @torch.no_grad() + def log_images(self, x): + log = dict() + + _, indices = self.encode_to_z(x) + sos_tokens = torch.ones(x.shape[0], 1) * self.sos_token + sos_tokens = sos_tokens.long().to("cuda") + + start_indices = indices[:, :indices.shape[1] // 2] + sample_indices = self.sample(start_indices, sos_tokens, steps=indices.shape[1] - start_indices.shape[1]) + half_sample = self.z_to_image(sample_indices) + + start_indices = indices[:, :0] + sample_indices = self.sample(start_indices, sos_tokens, steps=indices.shape[1]) + full_sample = self.z_to_image(sample_indices) + + x_rec = self.z_to_image(indices) + + log["input"] = x.mul(0.5).add(0.5) + log["rec"] = x_rec.mul(0.5).add(0.5) + log["half_sample"] = half_sample.mul(0.5).add(0.5) + log["full_sample"] = full_sample.mul(0.5).add(0.5) + + return log, torch.concat((x.mul(0.5).add(0.5), x_rec.mul(0.5).add(0.5), half_sample.mul(0.5).add(0.5), full_sample.mul(0.5).add(0.5))) + + + + + + + + + + + + + + + + diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..719f6d0 --- /dev/null +++ b/utils.py @@ -0,0 +1,102 @@ +import os +import albumentations +import numpy as np +import torch.nn as nn +from PIL import Image +from torch.utils.data import Dataset, DataLoader +import matplotlib.pyplot as plt +import torchvision +import torch +# --------------------------------------------- # +# Data Utils +# --------------------------------------------- # +class ImageDataset(Dataset): + def __init__(self, image_paths): + super().__init__() + image_paths = os.listdir(image_paths) + #print( image_paths[0]) + image_paths = ["./data/FFHQ_128"+f"/{el}" for el in image_paths] + self.image_paths = image_paths + """self.transform = albumentations.Compose([ + albumentations.RandomCrop(height=128, width=128), + albumentations.augmentations.transforms.HorizontalFlip(p=0.5) + ])""" + self.transform = torchvision.transforms.RandomHorizontalFlip() + self.normalise = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + + def __len__(self): + return len(self.image_paths) + + def __getitem__(self, index): + + image = Image.open(self.image_paths[index]) + image = np.array(image)/255.0 + image = torch.from_numpy(image).float().permute(2, 0, 1) + image = self.transform(image) + image = self.normalise(image) + return image + + +class ImagePaths(Dataset): + def __init__(self, path, size=None): + self.size = size + + self.images = [os.path.join(path, file) for file in os.listdir(path)] + self._length = len(self.images) + + self.rescaler = albumentations.SmallestMaxSize(max_size=self.size) + self.cropper = albumentations.CenterCrop(height=self.size, width=self.size) + self.preprocessor = 
albumentations.Compose([self.rescaler, self.cropper]) + + def __len__(self): + return self._length + + def preprocess_image(self, image_path): + image = Image.open(image_path) + if not image.mode == "RGB": + image = image.convert("RGB") + image = np.array(image).astype(np.uint8) + image = self.preprocessor(image=image)["image"] + image = (image / 127.5 - 1.0).astype(np.float32) + image = image.transpose(2, 0, 1) + return image + + def __getitem__(self, i): + example = self.preprocess_image(self.images[i]) + return example + + +def load_data(args): + #train_data = ImagePaths(args.dataset_path, size=128) + train_data = ImageDataset(args.dataset_path) + train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True,num_workers = 8, + pin_memory = True) + return train_loader + + +# --------------------------------------------- # +# Module Utils +# for Encoder, Decoder etc. +# --------------------------------------------- # + +def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find('BatchNorm') != -1: + nn.init.normal_(m.weight.data, 1.0, 0.02) + nn.init.constant_(m.bias.data, 0) + + +def plot_images(images): + x = images["input"] + reconstruction = images["rec"] + half_sample = images["half_sample"] + full_sample = images["full_sample"] + + fig, axarr = plt.subplots(1, 4) + axarr[0].imshow(x.cpu().detach().numpy()[0].transpose(1, 2, 0)) + axarr[1].imshow(reconstruction.cpu().detach().numpy()[0].transpose(1, 2, 0)) + axarr[2].imshow(half_sample.cpu().detach().numpy()[0].transpose(1, 2, 0)) + axarr[3].imshow(full_sample.cpu().detach().numpy()[0].transpose(1, 2, 0)) + plt.show() diff --git a/vgg_lpips/vgg.pth b/vgg_lpips/vgg.pth new file mode 100644 index 0000000000000000000000000000000000000000..47e943cfacabf7040b4af8cf4084ab91177f1b88 GIT binary patch literal 7289 zcma)Ad00)`|8AZLg;1#^Iu)8n=lv+^ra`GxNTpG8?K-9D)M-A=GeV@u6f#8S5*L{w z(>2^{7NtxXuIYE~P51k~-@kt6dG^_RuXjK1yPmz?^?ue~JdpyKh_CB^`K%VUZQ-O@ zJdp)2L42lg3vn-v)y7!)x}f7Z&N&=nyZMl666 z9UBlb%Qv1A8WAVs8yn;w9~u$p%L$5#i;ne8Fpy!ySuA(SIKg9nD}oq_q330OiHP>& zxD7SRFp};v`er7krluyQ#s-X(yF@^Ipr3^vBV8bD>BMny5Mnt9J6H-1VGe8PEV#;d z;`iGwPlm74XF#Tu& z{g@$oga4(M`)_&^Gb1B&BU1xL{vZ0W1tLPi!}JS;1fBn@`#&#+1;z;k#t#V?{;z<- ze+w8J8=DyDndvbT{t=ic5D@u~fXJ}G-xtFIlLP{aLjp$sE1>jm0X-8lQ$rI*`5%GF z{~rOBpr9OAMwQE`xl8*-266lX{WyNiln{Y&)CI*m?ARH&?({z@= zlmbIBnp}nz^0QHkjkkU`jFeY5abQm%lA{(}o9%IgBEQZ0^{|5CLHkZLLNJB~n%p5LbIoxP9 zr1aB1%v>%rZy2&3A{%y=9%I90ei;UB{|)Lh^SO-OFlaxdWXL#h8OPzq1w%?deZ)9% z8RubW;Skxd#q<~#F5@~3y8Ro}XBKf8_hHauNXdxtv4-hZW?APtd4`#dl^ap}&ojSryZy@1pF^Xip0gf-ms;Z&jn>rRady;747)pK6gk#5Mp>CH9 zrR;S?_4QR~>hi;nMi;zje}r(!MRZH(0km@YFm&09@y=V}XnzpNZ|5T6YXbrUrsCQ7 zQq;6>Maqmegr};~HyuA@9ZZM+3I)2c{WH#=ybPUvx^Osk0f8wg*wSCiz7*y}DuESv zyt@+~Y6>*_p*cmn*Fj%s0y#|l|TmJ@0P6sxfeuI~n{=m&gme|x=fTt2ZSkkl)DKhfpYrGMsU+}@i zzGuJ9zm6LRMCfVgH7t|Tz=X(=FdnSKyge=8y$Pp$gZ0?lF`N3S7>dDaSl5(-`RDv- zaOy))*h3hXsbRb5M`Tw|L+qdGkn{1OPfxF4(;WtD>wmyBT>^z^M%Z&P3y-%2pniWT zx<%LHrSmIzck!`QXBjSdZpS006Th^+Bq!My>=u_^d~i@C3mtWe?A3(pp#+G%6(gD{ zM)TwDKymABq;<22vsRsYoMs|v#(WsNykYm*9zx!dwfOW%n(AdXLt%>~=@`Z#UGRSH zjjO<`#)-7;%VYNb?>QLN$3srJ1cKS;uv*Ln3KuS7=R8U3{QM3zUE8tG-~xgts-thW zCH8e`Q&9IU$c#M$>)gwbQ6GTzx@L@@?SzZsi=be76%#fsrs+TCl5|r6)}}wf6NfFZ z_gIS^2QGqF+lbQ>_>jM=NOylPftJQYG+a%`w;*x4v09nVg@JBPuER`?EPQnM4d*B9 zV;_w^gpnF4=-6@;;(tV8+T~S9zC4D~a;i|Mx&bY%a`bSKKPFhp)5+L!#An}Uzt5;c zOA8@$#9?@fSW-r|JC)wj#k1Ag=>8stGc9E(eKnm*BDgs9WH0zy&mr|)8WInxpe?lq 
z1EG1)eCLC64&&fgIEu_-+i>E*To{AS+cJjw>#jrf{b~5$jX|l}D_nY<2-VyXR3`qN?Ofr8 z*vY3b@lg?sq!y5zUKtu?jp$Qc4YvGZKu;5->9M~iJ?RNXg{2hjd^4WP8Xh6cKLGi& z4zZOJ8t~O>3lS}`AmpE&PggaYp(CzB4Z_j*mS919n|C6+ z=Lu}@j~amOfl9bHHx@4kb@DJ{B~5QxcF+R%7J9pOLjvf00}(5S?sM;F@g zy@HFf4I(7f^@Oebax~7jN}+VcS@bmYLAC7yP83SP!B7wD-ygxJd^L!V&BWVN=5U(c z1SgX#Y(w+4IJAp!L|KK#rrN_`%3~<~`U4XxpW*6QZ)}p(qi#7R`tsD8BDGs!mYjey zmkvU>N1i_2*2R`9F4Uc7kDHbX5a}+$jcF4g{D%VCPiWFWmoRlaZNcWRilk{2iZvQh z5bYR4>%{m-DXzxsjth7>*Bi4$eud4BDKHGbhK+Nj5F8pqsGdsVkABBGMuK>z8hG6z zNzeHsspM`0csir$L*ZfkxbY4;_pNa)=Lgu#e%$HNLD5nr($1cR6(hbu(*7A%&9lHH zX<0HFH=RxwmO_#W;Zj+P_Q#X)h5Zc2W5!drkS(1*IS8lOY^0|4v&|H^$Uk)!_2=&( zSNs(W=ZVvAZ@N(zsYwA-wD4{go36Qg;q!G}yb6AeD1&9VY^@886Y8+*??>;Kn<((S zf{Q*UV5J)ex0rk+9W#M!lN!cydALhn*jp|K&E@eh=$wl^)^#}SFAdHiY4Yzpi|4yC z@Xk>Q7CK(YTn~sA?uN&QQ|R=rfTfH+l50M|Z`1?aLOOD`Mc~_*lc@W14|-Ltv5MCV z5s|m})^AFr2fC9wnkwH*(elohxaJjv?*lwUJlCbI*OjQTQwYB`_OKn^7o)Wq(3H4< z+Go?~jeiXK|FB2gjH{T-Hz#??5%j>>4u^j)!yH9Nhy>3>d6yocH57B}d!e*pFFLd| zpwQNXUYn`t)Dxo*PHJeX_o6^Q20u1A!yxA>LcU&rj(ai&?{37x*M`*AJ{|X>K0~F_ z1W#^$VP8Jzgwger=}eb83EwcIwIW>*ZaIWUtZ!^{J8#5=e8z@>r|cCI`*7;!6ZU$Q ztr%AE{Q5AWB5euxWrI1Ulf%W@rl+N5lzNdxKak=6$$w4z#uv~4etCDv}SuCX5MMayjVozrzGRN?Jqd4V}rrw z_xMvj4h<)S=+8ZcNORSqoxUZ=vyMl{;!x^|D1@PZ95fQfl0lvd{dw>m2Fz*^%c{cu z>r3!$PdWUL1yk5h(I63v?bl#O+dQh67z4Kp3FuFDLacfn ziK;~7_qBK7c5E`$%6p^pSOMY{IT+tE7q?rSY24BXI6R(A@A?!d`f)a67hULR2m;HHi|Y3gI6PnvHjkl0Dm4&) z7mYuUP}j$B}qZDMHO(@-b>qn0n6G!*Z7inUAQ$XGFqw zc?`B6sGykfk8$*SI(V;sKtpUZENvo5t}PDg12K5*AA%m?4E)*R4fi*KvAtJ-Nxo^= z9>YL(<0-7km`d*FRmk2d2J_Mf(LDYPb%&os!>tT>jJSmFktvvBUkZ(sJm?Ge;6gwN za^ll*-ysO~$7j%-RzdzhO9yvTCFa?$f~RvBB1Q{Q-$psQBbAJi15q#!%tZNv3<%%J z$KIGGeEYH#LOoG%p1U3+q)ITcJqp)dlHe|$33G8il#<#Jo05&)1=g@r<%54a4g0J- zaBf-&s?PF&kqpdgwBfbffE39V??2DPJSkNg8#|E{losGs$ui7-I0p--W#O~x3G8#{ zL(=jbp3mQf(N}`8`DhCc@ojLhb1WV1VjwD>1Fe-jjGM8FI!rfWlIH@Pij70*h#(xC z9e}SIGPL(!xXz>e2UG4*kXk_GJ~~b z``nCbeTtzk$PeX#G`KmK(Dv;?NMku5SNJKWEzP1H!!+>6)?&0Y7groZG40@c*gBNs zO#CHi9$taqZN1o+{Tgbwl2OxpjU;RGG5T~Ws(zhIx`KBw@k=NsmAm7ahhTkVV!A#3D|M$7xEm-!&BWmLa{&9KeS!Ip8Yh2+pn1D?> z;V3(hhz`>X%ubV~CuuM7VQ(QUKJ)Nk+5?mvEWneQk`&l+9Y#75)To|^ytr(^e4mG+ zw@Gkz(xm`{E4bupgqq3{te5qHLQyfg3_J0))r$;PAHh1-Hu8KE4gLWhuD;|W-A9wo zX(!^-f{~rhI2XU>@NrwN1WVYX zNa=V6#(gozy}lH3xL1QHL zuGkz9_@J;D#bIu!dYKIQ!b5O57=|_v!Rew6|^wNolTCL zQn9;Hn%vsk>FwGoOp2^S*bW~OaZbS?FAZ~*vM@5G9hWaA<5p%h9o-y(K9&$!ww6Ko zo*+jf0&(Pd6Nzm&fh?N|G;vE992YUj%jO}rI2PeOQlz?X4U%UR()!#|w0@1mVp>4= z;?yt9hE|_BgBg5R^cTlL zCbIzf*Ba5X{~o^AmmpxLE7i6M>U7#Ps`D2m9qwqFm5@)a`bJdBFpyri9g7+YsCN|) zum329ommlF!ffecy$->s5}gmSV4YTh;*H~I*7;CuJC%y*nk#YEEf1}(+1NO?0LwRf zAhtdl*B=QHt6z|deOwqN^N|*)NUxUmfbUie*-&Ge9r zXjT%MBlTe`K20qKD=8gw8;a2v8US6M3F#(QP`9B!c3I58NBCZrf)lukfD#x)^ zX{dkQg88Ro$@)Y!!m{h&B3Fo#m)c}waSmL7eDfbn5f)p|JqI9V0L>c_TI*(Zi&+}$?T??8up5WDq zqQ)_U9M8?@7Z~8K}gNF)MpH#+Vy+o`` z&V^}WG;|*DDN?wRGDAjS*?VnpLv!GF#tf_S>Y(s2w7i*w`MWBREf_byHLIa=PZJXPMVQzZ zg@|S?_)m35cC$Hr+l#TUwG_`DC&SKJ6&*LNFfyVD?w4e7rON@$C+s1q&PLr=A56Xa zoqajl4G{)w5Vk)TJvaF{<&+PNViP<|4u{th!Fu_p4<1=9z@YdP>>5Q_$|=UJ7y8(> z+yG&#ig8BZgW5YP=zU>KE&i_foFWIc%gRW3F9T;r7Yo<7la$E|dUyav#G8pvTfG?qPFrC@Nj$+9| zNO=~YFA=KQ4&_RXG3XT z7TN`w_VFtZ#Z5_gUdxAmN-i!WC?a897C!jKLVDXuv_39G@UwVWZRcZ}HXkMR-new2 z9Bm`xfK!TEu+`YxR|1}EI(D~b;Zfj9oNUcS(bilXD9eX(M=5rm%!gW79wPK| zvBWzYJ+JHt_Tu^gi==E2l;1xAF0U~Xvz z0)qKqf2zhJwK%+C@^N5mH99U9;`SeF&=in|2iKm{#>EBL^)d)@zvLlRDGfzu7K0@o zD%c|_hSVz%mENt|g=O6%SIHRXCLBi`)EM^fgzbHL6PRxl@jQ z!{w0PpA2JjE<^&?plkUCK^7)rurM0j2nI8v6R=U>)3#An;9o7myO|t>TrNT9m2CV# z0mk&_A^vJPx(?=pua^bgEFMIoim+@=1vU!yKN5Fl;m;kp@K20DiC}#d+!T)c9>Fkw 
zU<21N(U{K8N2zu^+DwYEM!pEPp(S8?+;FrYA7vIrc)f2ub+Do^uEZZ!t(B+~U5yKR zJY;;2MCF$p%$och;ty+~C0m2FWfHWw=^b8vEcHkuzKK;EJP za-UbgtttrLxWO=azY49pbMQ=B;NiUyX#BVa><@9c=+DPO`BF69%7m)O3Y5QKu$j+; zvwtB1Ywd9BbvR@OGH_hO9nJS+A<<`m%lAv5VVnbrqk=J+l!ddRm9UBR!l-o|s9A8( PyEGODb&D{1dnW!5&PYuU literal 0 HcmV?d00001 diff --git a/vqgan.py b/vqgan.py new file mode 100644 index 0000000..221ef51 --- /dev/null +++ b/vqgan.py @@ -0,0 +1,62 @@ +import torch +import torch.nn as nn +from encoder import Encoder +from decoder import Decoder +from codebook import Codebook + + +class VQGAN(nn.Module): + def __init__(self, args): + super(VQGAN, self).__init__() + self.encoder = Encoder(args).to(device=args.device)# UNET FIRST HALF WITH ATTENTION + self.decoder = Decoder(args).to(device=args.device)# UNET SEC HALF WITH ATTETNION + self.codebook = Codebook(args).to(device=args.device) + self.quant_conv = nn.Conv2d(args.latent_dim, args.latent_dim, 1).to(device=args.device)#256,256 + self.post_quant_conv = nn.Conv2d(args.latent_dim, args.latent_dim, 1).to(device=args.device)#256,256 + + def forward(self, imgs): + encoded_images = self.encoder(imgs)#[8, 256, 8, 8]) + quant_conv_encoded_images = self.quant_conv(encoded_images)#([8, 256, 8, 8]) + codebook_mapping, codebook_indices, q_loss = self.codebook(quant_conv_encoded_images) + post_quant_conv_mapping = self.post_quant_conv(codebook_mapping)#([8, 256, 8, 8]) + decoded_images = self.decoder(post_quant_conv_mapping)#([8, 3, 128, 128]) + + return decoded_images, codebook_indices, q_loss#([8, 3, 128, 128]),[512],0.05 + + def encode(self, imgs): + encoded_images = self.encoder(imgs) + quant_conv_encoded_images = self.quant_conv(encoded_images) + codebook_mapping, codebook_indices, q_loss = self.codebook(quant_conv_encoded_images) + return codebook_mapping, codebook_indices, q_loss + + def decode(self, z): + post_quant_conv_mapping = self.post_quant_conv(z) + decoded_images = self.decoder(post_quant_conv_mapping) + return decoded_images + + def calculate_lambda(self, perceptual_loss, gan_loss): + last_layer = self.decoder.model[-1] + last_layer_weight = last_layer.weight + perceptual_loss_grads = torch.autograd.grad(perceptual_loss, last_layer_weight, retain_graph=True)[0] + gan_loss_grads = torch.autograd.grad(gan_loss, last_layer_weight, retain_graph=True)[0] + + λ = torch.norm(perceptual_loss_grads) / (torch.norm(gan_loss_grads) + 1e-4) + λ = torch.clamp(λ, 0, 1e4).detach() + return 0.8 * λ + + @staticmethod + def adopt_weight(disc_factor, i, threshold, value=0.): + if i < threshold: + disc_factor = value + return disc_factor + + def load_checkpoint(self, path): + self.load_state_dict(torch.load(path)) + + + + + + + +
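A note on using the first-stage model on its own: the snippet below is a rough sketch, not part of the patch. The argparse namespace mirrors the defaults in ```training_vqgan.py```, the checkpoint name is a placeholder, and the random tensor stands in for a normalised image.
```python
from argparse import Namespace
import torch
from vqgan import VQGAN

# Hyperparameters mirror the training_vqgan.py defaults; running on CPU is an assumption.
args = Namespace(latent_dim=256, image_channels=3, num_codebook_vectors=1024,
                 beta=0.25, device="cpu")

vqgan = VQGAN(args)
vqgan.load_checkpoint("./checkpoints/vqgan_epoch_99.pt")  # placeholder checkpoint path
vqgan.eval()

with torch.no_grad():
    x = torch.randn(1, 3, 128, 128)    # stand-in for a normalised image in [-1, 1]
    recon, indices, q_loss = vqgan(x)  # reconstruction, 64 codebook indices (8x8 grid), codebook/commitment loss
```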