From 4c3f880738bcf394fe0994c7324331c8c61acdb5 Mon Sep 17 00:00:00 2001
From: Satire_Y <973919719@qq.com>
Date: Mon, 16 Oct 2023 15:37:14 +0800
Subject: [PATCH] add reset+ema for vq-vae

---
 .vscode/launch.json |  16 ++
 configs/vq_vae.yaml |  13 +-
 experiment.py       |   7 +-
 models/__init__.py  |   1 +
 models/quantizer.py | 400 ++++++++++++++++++++++++++++++++++++++++++++
 models/vq_vae.py    | 115 +++++++------
 run.py              |   8 +-
 7 files changed, 497 insertions(+), 63 deletions(-)
 create mode 100644 .vscode/launch.json
 create mode 100644 models/quantizer.py

diff --git a/.vscode/launch.json b/.vscode/launch.json
new file mode 100644
index 00000000..8127bc7e
--- /dev/null
+++ b/.vscode/launch.json
@@ -0,0 +1,16 @@
+{
+    // Use IntelliSense to learn about possible attributes.
+    // Hover to view descriptions of existing attributes.
+    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "Python: Current File",
+            "type": "python",
+            "request": "launch",
+            "program": "${file}",
+            "console": "integratedTerminal",
+            "justMyCode": false
+        }
+    ]
+}
\ No newline at end of file
diff --git a/configs/vq_vae.yaml b/configs/vq_vae.yaml
index 505425f7..b8e04ca3 100644
--- a/configs/vq_vae.yaml
+++ b/configs/vq_vae.yaml
@@ -1,21 +1,22 @@
 model_params:
-  name: 'VQVAE'
+  name: "VQVAE"
   in_channels: 3
+  quantizer: "ema_reset"
   embedding_dim: 64
   num_embeddings: 512
   img_size: 64
-  beta: 0.25
+  beta: 1.0
+  mu: 0.99
 
 data_params:
   data_path: "Data/"
   train_batch_size: 64
-  val_batch_size:  64
+  val_batch_size: 64
   patch_size: 64
   num_workers: 4
-
+
 exp_params:
-  LR: 0.005
+  LR: 0.0003
   weight_decay: 0.0
   scheduler_gamma: 0.0
   kld_weight: 0.00025
@@ -27,4 +28,4 @@ trainer_params:
 
 logging_params:
   save_dir: "logs/"
-  name: 'VQVAE'
+  name: "VQVAE"
diff --git a/experiment.py b/experiment.py
index 8763a006..2d5130f9 100644
--- a/experiment.py
+++ b/experiment.py
@@ -66,7 +66,12 @@ def sample_images(self):
         test_input, test_label = next(iter(self.trainer.datamodule.test_dataloader()))
         test_input = test_input.to(self.curr_device)
         test_label = test_label.to(self.curr_device)
-
+        vutils.save_image(test_input.data,
+                          os.path.join(self.logger.log_dir,
+                                       "Origin",
+                                       f"origin_{self.logger.name}_Epoch_{self.current_epoch}.png"),
+                          normalize=True,
+                          nrow=12)
 #         test_input, test_label = batch
         recons = self.model.generate(test_input, labels = test_label)
         vutils.save_image(recons.data,
diff --git a/models/__init__.py b/models/__init__.py
index 3f310f29..3c3d1d61 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -21,6 +21,7 @@
 from .vq_vae import *
 from .betatc_vae import *
 from .dip_vae import *
+from .quantizer import *
 
 
 # Aliases
diff --git a/models/quantizer.py b/models/quantizer.py
new file mode 100644
index 00000000..06034026
--- /dev/null
+++ b/models/quantizer.py
@@ -0,0 +1,400 @@
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+from .types_ import *
+
+
+class Quantizer(nn.Module):
+    """
+    Reference:
+    [1] https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py
+    """
+    def __init__(self,
+                 num_embeddings: int,
+                 embedding_dim: int,
+                 beta: float = 0.25):
+        super(Quantizer, self).__init__()
+        self.K = num_embeddings
+        self.D = embedding_dim
+        self.beta = beta
+
+        self.embedding = nn.Embedding(self.K, self.D)
+        self.embedding.weight.data.uniform_(-1 / self.K, 1 / self.K)
+
+    def forward(self, latents: Tensor) -> Tensor:
+        latents = latents.permute(0, 2, 3, 1).contiguous()  # [B x D x H x W] -> [B x H x W x D]
+        latents_shape = latents.shape
+        flat_latents = latents.view(-1, self.D)  # [BHW x D]
+
+        # Compute L2 distance between latents and embedding weights
+        dist = torch.sum(flat_latents ** 2, dim=1, keepdim=True) + \
+               torch.sum(self.embedding.weight ** 2, dim=1) - \
+               2 * torch.matmul(flat_latents, self.embedding.weight.t())  # [BHW x K]
+
+        # Get the encoding that has the min distance
+        encoding_inds = torch.argmin(dist, dim=1).unsqueeze(1)  # [BHW, 1]
+
+        # Convert to one-hot encodings
+        device = latents.device
+        encoding_one_hot = torch.zeros(encoding_inds.size(0), self.K, device=device)
+        encoding_one_hot.scatter_(1, encoding_inds, 1)  # [BHW x K]
+
+        # Quantize the latents
+        quantized_latents = torch.matmul(encoding_one_hot, self.embedding.weight)  # [BHW, D]
+        quantized_latents = quantized_latents.view(latents_shape)  # [B x H x W x D]
+
+        # Compute the VQ Losses
+        commitment_loss = F.mse_loss(quantized_latents.detach(), latents)
+        embedding_loss = F.mse_loss(quantized_latents, latents.detach())
+
+        vq_loss = commitment_loss * self.beta + embedding_loss
+
+        # Add the residue back to the latents
+        quantized_latents = latents + (quantized_latents - latents).detach()
+
+        return quantized_latents.permute(0, 3, 1, 2).contiguous(), vq_loss  # [B x D x H x W]
+
+
+class QuantizeReset(nn.Module):
+    def __init__(self, nb_code, code_dim):
+        super().__init__()
+        self.nb_code = nb_code
+        self.code_dim = code_dim
+        self.reset_codebook()
+        self.codebook = nn.Parameter(torch.randn(nb_code, code_dim))
+
+    def reset_codebook(self):
+        self.init = False
+        self.code_count = None
+
+    def _tile(self, x):
+        nb_code_x, code_dim = x.shape
+        if nb_code_x < self.nb_code:
+            n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x
+            std = 0.01 / np.sqrt(code_dim)
+            out = x.repeat(n_repeats, 1)
+            out = out + torch.randn_like(out) * std
+        else:
+            out = x
+        return out
+
+    def init_codebook(self, x):
+        out = self._tile(x)
+        self.codebook = nn.Parameter(out[:self.nb_code])
+        self.code_count = torch.ones(self.nb_code, device=self.codebook.device)
+        self.init = True
+
+    @torch.no_grad()
+    def compute_perplexity(self, code_idx):
+        # Calculate new centres
+        code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device)  # nb_code, N * L
+        code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1)
+
+        code_count = code_onehot.sum(dim=-1)  # nb_code
+        prob = code_count / torch.sum(code_count)
+        perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
+        return perplexity
+
+    def update_codebook(self, x, code_idx):
+        # [K, BHW]
+        code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device)
+        code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)
+        # [K, ]
+        code_count = code_onehot.sum(dim=-1)  # nb_code
+
+        out = self._tile(x)
+        code_rand = out[:self.nb_code]
+
+        # Update centres
+        self.code_count = code_count  # nb_code
+        usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float()
+
+        self.codebook.data = usage * self.codebook.data + (1 - usage) * code_rand
+        prob = code_count / torch.sum(code_count)
+        perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
+
+        return perplexity
+
+    def preprocess(self, x):
+        # [B x D x H x W] -> [B x H x W x D] -> [BHW x D]
+        x = x.permute(0, 2, 3, 1).contiguous()
+        shape = x.shape
+        x = x.view(-1, x.shape[-1])
+        return x, shape
+
+    def quantize(self, x):
+        # Calculate latent code x_l
+        k_w = self.codebook.t()
+        distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + \
+                   torch.sum(k_w ** 2, dim=0, keepdim=True)
+        _, code_idx = torch.min(distance, dim=-1)
+        return code_idx
+
+    def dequantize(self, code_idx):
+        # [BHW, D]
+        x = F.embedding(code_idx, self.codebook)
+        return x
+
+    def forward(self, x):
+        # Preprocess
+        x, x_shape = self.preprocess(x)  # [BHW, D]
+        # Init codebook if not inited
+        if self.training and not self.init:
+            self.init_codebook(x)
+        # Quantize and dequantize through the bottleneck
+        # [BHW]
+        code_idx = self.quantize(x)
+        # [BHW, D]
+        x_d = self.dequantize(code_idx)
+        # Update embeddings
+        if self.training:
+            perplexity = self.update_codebook(x, code_idx)
+        else:
+            perplexity = self.compute_perplexity(code_idx)
+
+        # Loss
+        commit_loss = F.mse_loss(x, x_d.detach())
+
+        # Passthrough (straight-through estimator)
+        x_d = x + (x_d - x).detach()
+
+        # Postprocess
+        x_d = x_d.view(x_shape).permute(0, 3, 1, 2).contiguous()  # [B, D, H, W]
+
+        return x_d, commit_loss, perplexity
+
+
+class QuantizeEMA(nn.Module):
+    def __init__(self, nb_code, code_dim, mu):
+        super().__init__()
+        self.nb_code = nb_code
+        self.code_dim = code_dim
+        self.mu = mu
+        self.reset_codebook()
+
+    def reset_codebook(self):
+        self.init = False
+        self.code_sum = None
+        self.code_count = None
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim).to(device))
+
+    def _tile(self, x):
+        nb_code_x, code_dim = x.shape
+        if nb_code_x < self.nb_code:
+            n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x
+            std = 0.01 / np.sqrt(code_dim)
+            out = x.repeat(n_repeats, 1)
+            out = out + torch.randn_like(out) * std
+        else:
+            out = x
+        return out
+
+    def init_codebook(self, x):
+        out = self._tile(x)
+        self.codebook = out[:self.nb_code]
+        self.code_sum = self.codebook.clone()
+        self.code_count = torch.ones(self.nb_code, device=self.codebook.device)
+        self.init = True
+
+    @torch.no_grad()
+    def compute_perplexity(self, code_idx):
+        # Calculate new centres
+        code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device)  # nb_code, N * L
+        code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1)
+
+        code_count = code_onehot.sum(dim=-1)  # nb_code
+        prob = code_count / torch.sum(code_count)
+        perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
+        return perplexity
+
+    @torch.no_grad()
+    def update_codebook(self, x, code_idx):
+        # [K, BHW]
+        code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device)  # nb_code, N * L
+        code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)
+        # [K, D]
+        code_sum = torch.matmul(code_onehot, x)
+        # [K]
+        code_count = code_onehot.sum(dim=-1)  # nb_code
+
+        # Update centres with an exponential moving average
+        self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum  # nb_code, D
+        self.code_count = self.mu * self.code_count + (1. - self.mu) * code_count  # nb_code
+        # [K, D]
+        code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1)
+
+        self.codebook = code_update
+        prob = code_count / torch.sum(code_count)
+        perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
+
+        return perplexity
+
+    def preprocess(self, x):
+        # [B x D x H x W] -> [B x H x W x D] -> [BHW x D]
+        x = x.permute(0, 2, 3, 1).contiguous()
+        shape = x.shape
+        x = x.view(-1, x.shape[-1])
+        return x, shape
+
+    def quantize(self, x):
+        # Calculate latent code x_l
+        k_w = self.codebook.t()
+        distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + \
+                   torch.sum(k_w ** 2, dim=0, keepdim=True)  # [BHW, K]
+        _, code_idx = torch.min(distance, dim=-1)
+        return code_idx
+
+    def dequantize(self, code_idx):
+        x = F.embedding(code_idx, self.codebook)
+        return x
+
+    def forward(self, x):
+        # Preprocess
+        x, x_shape = self.preprocess(x)  # [BHW, D]
+
+        # Init codebook if not inited
+        if self.training and not self.init:
+            self.init_codebook(x)
+
+        # Quantize and dequantize through the bottleneck
+        code_idx = self.quantize(x)
+        x_d = self.dequantize(code_idx)
+
+        # Update embeddings
+        if self.training:
+            perplexity = self.update_codebook(x, code_idx)
+        else:
+            perplexity = self.compute_perplexity(code_idx)
+
+        # Loss
+        commit_loss = F.mse_loss(x, x_d.detach())
+
+        # Passthrough (straight-through estimator)
+        x_d = x + (x_d - x).detach()
+
+        # Postprocess
+        x_d = x_d.view(x_shape).permute(0, 3, 1, 2).contiguous()  # [B, D, H, W]
+
+        return x_d, commit_loss, perplexity
+
+
+class QuantizeEMAReset(nn.Module):
+    def __init__(self, nb_code, code_dim, mu):
+        super().__init__()
+        self.nb_code = nb_code
+        self.code_dim = code_dim
+        self.mu = mu
+        self.reset_codebook()
+
+    def reset_codebook(self):
+        self.init = False
+        self.code_sum = None
+        self.code_count = None
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim).to(device))
+
+    def _tile(self, x):
+        nb_code_x, code_dim = x.shape
+        if nb_code_x < self.nb_code:
+            n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x
+            std = 0.01 / np.sqrt(code_dim)
+            out = x.repeat(n_repeats, 1)
+            out = out + torch.randn_like(out) * std
+        else:
+            out = x
+        return out
+
+    def init_codebook(self, x):
+        out = self._tile(x)
+        self.codebook = out[:self.nb_code]
+        self.code_sum = self.codebook.clone()
+        self.code_count = torch.ones(self.nb_code, device=self.codebook.device)
+        self.init = True
+
+    @torch.no_grad()
+    def compute_perplexity(self, code_idx):
+        # Calculate new centres
+        code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device)  # nb_code, N * L
+        code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1)
+
+        code_count = code_onehot.sum(dim=-1)  # nb_code
+        prob = code_count / torch.sum(code_count)
+        perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
+        return perplexity
+
+    @torch.no_grad()
+    def update_codebook(self, x, code_idx):
+        # [K, BHW]
+        code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device)
+        code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)
+
+        # [K, D] and [K]
+        code_sum = torch.matmul(code_onehot, x)
+        code_count = code_onehot.sum(dim=-1)
+
+        out = self._tile(x)
+        code_rand = out[:self.nb_code]
+
+        # Update centres with an exponential moving average
+        self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum
+        self.code_count = self.mu * self.code_count + (1. - self.mu) * code_count
+
+        usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float()
+        code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1)
+
+        # Reset unused codes to randomly drawn encoder outputs
+        self.codebook = usage * code_update + (1 - usage) * code_rand
+        prob = code_count / torch.sum(code_count)
+        perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
+
+        return perplexity
+
+    def preprocess(self, x):
+        # [B x D x H x W] -> [B x H x W x D] -> [BHW x D]
+        x = x.permute(0, 2, 3, 1).contiguous()
+        shape = x.shape
+        x = x.view(-1, x.shape[-1])
+        return x, shape
+
+    def quantize(self, x):
+        # Calculate latent code x_l
+        k_w = self.codebook.t()
+        distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + \
+                   torch.sum(k_w ** 2, dim=0, keepdim=True)  # [BHW, K]
+        _, code_idx = torch.min(distance, dim=-1)
+        return code_idx
+
+    def dequantize(self, code_idx):
+        x = F.embedding(code_idx, self.codebook)
+        return x
+
+    def forward(self, x):
+        # Preprocess
+        x, x_shape = self.preprocess(x)  # [BHW, D]
+
+        # Init codebook if not inited
+        if self.training and not self.init:
+            self.init_codebook(x)
+
+        # Quantize and dequantize through the bottleneck
+        code_idx = self.quantize(x)
+        x_d = self.dequantize(code_idx)
+
+        # Update embeddings
+        if self.training:
+            perplexity = self.update_codebook(x, code_idx)
+        else:
+            perplexity = self.compute_perplexity(code_idx)
+
+        # Loss
+        commit_loss = F.mse_loss(x, x_d.detach())
+
+        # Passthrough (straight-through estimator)
+        x_d = x + (x_d - x).detach()
+
+        # Postprocess
+        x_d = x_d.view(x_shape).permute(0, 3, 1, 2).contiguous()  # [B, D, H, W]
+
+        return x_d, commit_loss, perplexity
\ No newline at end of file
diff --git a/models/vq_vae.py b/models/vq_vae.py
index bd2249c6..c98c4e1f 100644
--- a/models/vq_vae.py
+++ b/models/vq_vae.py
@@ -2,57 +2,59 @@
 from models import BaseVAE
 from torch import nn
 from torch.nn import functional as F
+from .quantizer import *
 from .types_ import *
 
 
-class VectorQuantizer(nn.Module):
-    """
-    Reference:
-    [1] https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py
-    """
-    def __init__(self,
-                 num_embeddings: int,
-                 embedding_dim: int,
-                 beta: float = 0.25):
-        super(VectorQuantizer, self).__init__()
-        self.K = num_embeddings
-        self.D = embedding_dim
-        self.beta = beta
-
-        self.embedding = nn.Embedding(self.K, self.D)
-        self.embedding.weight.data.uniform_(-1 / self.K, 1 / self.K)
-
-    def forward(self, latents: Tensor) -> Tensor:
-        latents = latents.permute(0, 2, 3, 1).contiguous()  # [B x D x H x W] -> [B x H x W x D]
-        latents_shape = latents.shape
-        flat_latents = latents.view(-1, self.D)  # [BHW x D]
-
-        # Compute L2 distance between latents and embedding weights
-        dist = torch.sum(flat_latents ** 2, dim=1, keepdim=True) + \
-               torch.sum(self.embedding.weight ** 2, dim=1) - \
-               2 * torch.matmul(flat_latents, self.embedding.weight.t())  # [BHW x K]
-
-        # Get the encoding that has the min distance
-        encoding_inds = torch.argmin(dist, dim=1).unsqueeze(1)  # [BHW, 1]
-
-        # Convert to one-hot encodings
-        device = latents.device
-        encoding_one_hot = torch.zeros(encoding_inds.size(0), self.K, device=device)
-        encoding_one_hot.scatter_(1, encoding_inds, 1)  # [BHW x K]
-
-        # Quantize the latents
-        quantized_latents = torch.matmul(encoding_one_hot, self.embedding.weight)  # [BHW, D]
-        quantized_latents = quantized_latents.view(latents_shape)  # [B x H x W x D]
-
-        # Compute the VQ Losses
-        commitment_loss = F.mse_loss(quantized_latents.detach(), latents)
-        embedding_loss = F.mse_loss(quantized_latents, latents.detach())
-
-        vq_loss = commitment_loss * self.beta + embedding_loss
-
-        # Add the residue back to the latents
-        quantized_latents = latents + (quantized_latents - latents).detach()
-
-        return quantized_latents.permute(0, 3, 1, 2).contiguous(), vq_loss  # [B x D x H x W]
+# VQ-VAE Codebook
+# class VectorQuantizer(nn.Module):
+#     """
+#     Reference:
+#     [1] https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py
+#     """
+#     def __init__(self,
+#                  num_embeddings: int,
+#                  embedding_dim: int,
+#                  beta: float = 0.25):
+#         super(VectorQuantizer, self).__init__()
+#         self.K = num_embeddings
+#         self.D = embedding_dim
+#         self.beta = beta
+
+#         self.embedding = nn.Embedding(self.K, self.D)
+#         self.embedding.weight.data.uniform_(-1 / self.K, 1 / self.K)
+
+#     def forward(self, latents: Tensor) -> Tensor:
+#         latents = latents.permute(0, 2, 3, 1).contiguous()  # [B x D x H x W] -> [B x H x W x D]
+#         latents_shape = latents.shape
+#         flat_latents = latents.view(-1, self.D)  # [BHW x D]
+
+#         # Compute L2 distance between latents and embedding weights
+#         dist = torch.sum(flat_latents ** 2, dim=1, keepdim=True) + \
+#                torch.sum(self.embedding.weight ** 2, dim=1) - \
+#                2 * torch.matmul(flat_latents, self.embedding.weight.t())  # [BHW x K]
+
+#         # Get the encoding that has the min distance
+#         encoding_inds = torch.argmin(dist, dim=1).unsqueeze(1)  # [BHW, 1]
+
+#         # Convert to one-hot encodings
+#         device = latents.device
+#         encoding_one_hot = torch.zeros(encoding_inds.size(0), self.K, device=device)
+#         encoding_one_hot.scatter_(1, encoding_inds, 1)  # [BHW x K]
+
+#         # Quantize the latents
+#         quantized_latents = torch.matmul(encoding_one_hot, self.embedding.weight)  # [BHW, D]
+#         quantized_latents = quantized_latents.view(latents_shape)  # [B x H x W x D]
+
+#         # Compute the VQ Losses
+#         commitment_loss = F.mse_loss(quantized_latents.detach(), latents)
+#         embedding_loss = F.mse_loss(quantized_latents, latents.detach())
+
+#         vq_loss = commitment_loss * self.beta + embedding_loss
+
+#         # Add the residue back to the latents
+#         quantized_latents = latents + (quantized_latents - latents).detach()
+
+#         return quantized_latents.permute(0, 3, 1, 2).contiguous(), vq_loss  # [B x D x H x W]
 
 
 class ResidualLayer(nn.Module):
@@ -74,10 +76,12 @@ class VQVAE(BaseVAE):
 
     def __init__(self,
                  in_channels: int,
+                 quantizer: str,
                  embedding_dim: int,
                  num_embeddings: int,
                  hidden_dims: List = None,
                  beta: float = 0.25,
+                 mu: float = 0.99,
                  img_size: int = 64,
                  **kwargs) -> None:
         super(VQVAE, self).__init__()
@@ -121,9 +125,14 @@ def __init__(self,
 
         self.encoder = nn.Sequential(*modules)
 
-        self.vq_layer = VectorQuantizer(num_embeddings,
-                                        embedding_dim,
-                                        self.beta)
+        if quantizer == "ema_reset":
+            self.vq_layer = QuantizeEMAReset(num_embeddings, embedding_dim, mu=mu)
+        elif quantizer == "orig":
+            self.vq_layer = Quantizer(num_embeddings, embedding_dim, beta=beta)
+        elif quantizer == "ema":
+            self.vq_layer = QuantizeEMA(num_embeddings, embedding_dim, mu=mu)
+        elif quantizer == "reset":
+            self.vq_layer = QuantizeReset(num_embeddings, embedding_dim)
 
         # Build Decoder
         modules = []
@@ -188,7 +197,7 @@ def decode(self, z: Tensor) -> Tensor:
 
     def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
         encoding = self.encode(input)[0]
-        quantized_inputs, vq_loss = self.vq_layer(encoding)
+        quantized_inputs, vq_loss, *_ = self.vq_layer(encoding)
         return [self.decode(quantized_inputs), input, vq_loss]
 
     def loss_function(self,
diff --git a/run.py b/run.py
index 160ed762..82a0db50 100644
--- a/run.py
+++ b/run.py
@@ -12,6 +12,8 @@
 from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
 from dataset import VAEDataset
 from pytorch_lightning.plugins import DDPPlugin
+import setproctitle
+setproctitle.setproctitle('hychen')
 
 
 parser = argparse.ArgumentParser(description='Generic runner for VAE models')
@@ -19,7 +21,7 @@
                     dest="filename",
                     metavar='FILE',
                     help =  'path to the config file',
-                    default='configs/vae.yaml')
+                    default='configs/vq_vae.yaml')
 
 args = parser.parse_args()
 with open(args.filename, 'r') as file:
@@ -31,7 +33,7 @@
 
 tb_logger =  TensorBoardLogger(save_dir=config['logging_params']['save_dir'],
                                name=config['model_params']['name'],)
-
+tb_logger.log_hyperparams(config)
 
 # For reproducibility
 seed_everything(config['exp_params']['manual_seed'], True)
@@ -56,7 +58,7 @@
 
 Path(f"{tb_logger.log_dir}/Samples").mkdir(exist_ok=True, parents=True)
 Path(f"{tb_logger.log_dir}/Reconstructions").mkdir(exist_ok=True, parents=True)
-
+Path(f"{tb_logger.log_dir}/Origin").mkdir(exist_ok=True, parents=True)
 
 print(f"======= Training {config['model_params']['name']} =======")
 runner.fit(experiment, datamodule=data)
\ No newline at end of file
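
Usage sketch (not part of the patch): a minimal smoke test for the new quantizer selection, assuming the PyTorch-VAE layout this patch targets (a VQVAE class importable from the models package, and the keys shown in configs/vq_vae.yaml above). The parameter values simply mirror that config; the batch size and random input below are illustrative only.

    import torch
    from models import VQVAE

    # quantizer may be "orig", "ema", "reset", or "ema_reset" (see models/quantizer.py)
    model = VQVAE(in_channels=3,
                  quantizer="ema_reset",
                  embedding_dim=64,
                  num_embeddings=512,
                  img_size=64,
                  beta=1.0,
                  mu=0.99)
    model.train()  # the EMA/reset codebooks initialise themselves on the first training forward pass

    x = torch.randn(8, 3, 64, 64)        # dummy batch of 64x64 RGB images
    recons, inputs, vq_loss = model(x)   # VQVAE.forward returns [reconstruction, input, vq_loss]
    print(recons.shape, vq_loss.item())

The straight-through quantizers keep the encoder gradient path intact, so this forward pass also backpropagates normally if a reconstruction loss is added on top.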