Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

About the quality of recon_img #22

Open
zhuyingce opened this issue Oct 6, 2023 · 0 comments
Open

About the quality of recon_img #22

zhuyingce opened this issue Oct 6, 2023 · 0 comments

Comments

@zhuyingce
Copy link

Hello, I tried to extract your BetaVAE_H model and loss function and then trained the model on CIFAR-10. But after 10000 epochs of training, the quality of recon_img is still very poor. Is there anything else I should have considered? Please help me. The code I use is listed as follows:
`from torch import nn
from torch.nn import init
from torch.autograd import Variable

def reparametrize(mu, logvar):
    """Sample z ~ N(mu, sigma^2) using the reparameterization trick.

    ``logvar`` holds log(sigma^2), so sigma = exp(logvar / 2).  Returns a
    tensor of the same shape as ``mu`` through which gradients flow to both
    ``mu`` and ``logvar``.
    """
    std = logvar.div(2).exp()
    # torch.randn_like replaces the deprecated Variable/.data.new idiom;
    # it draws from the same default generator, so behavior is unchanged.
    eps = torch.randn_like(std)
    return mu + std * eps

def kaiming_init(m):
    """Initialize a module in place.

    Linear/Conv2d layers get Kaiming-normal weights and zero bias;
    BatchNorm1d/2d layers get affine weight 1 and bias 0.  Modules of any
    other type are left untouched.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        # kaiming_normal was renamed kaiming_normal_ (explicit in-place
        # suffix) in torch 0.4 and later removed; use the supported name.
        init.kaiming_normal_(m.weight)
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        m.weight.data.fill_(1)
        if m.bias is not None:
            m.bias.data.fill_(0)

class View(nn.Module):
    """Reshape layer: forward() returns the input viewed with the stored size.

    Note: the pasted code had ``init`` instead of ``__init__`` (markdown
    swallowed the double underscores).  With a plain ``init`` method,
    ``nn.Module.__init__`` never runs and constructing ``View(...)`` with
    an argument raises a TypeError — the constructor must be ``__init__``.
    """

    def __init__(self, size):
        super(View, self).__init__()
        self.size = size  # tuple passed straight to Tensor.view

    def forward(self, tensor):
        return tensor.view(self.size)

class BetaVAE_H(nn.Module):
    """Model proposed in original beta-VAE paper (Higgins et al, ICLR, 2017).

    The encoder maps a (B, nc, 64, 64) image to a (B, 2*z_dim) vector holding
    the concatenated mean and log-variance of q(z|x); the decoder maps a
    z_dim-dimensional latent back to (B, nc, 64, 64) logits.  The decoder has
    no final activation — callers are expected to apply sigmoid themselves
    (e.g. inside the reconstruction loss).
    """

    def __init__(self, z_dim=10, nc=3):
        """z_dim: latent dimensionality; nc: number of image channels."""
        super(BetaVAE_H, self).__init__()
        self.z_dim = z_dim
        self.nc = nc
        self.encoder = nn.Sequential(
            nn.Conv2d(nc, 32, 4, 2, 1),  # B,  32, 32, 32
            nn.ReLU(True),
            nn.Conv2d(32, 32, 4, 2, 1),  # B,  32, 16, 16
            nn.ReLU(True),
            nn.Conv2d(32, 64, 4, 2, 1),  # B,  64,  8,  8
            nn.ReLU(True),
            nn.Conv2d(64, 64, 4, 2, 1),  # B,  64,  4,  4
            nn.ReLU(True),
            nn.Conv2d(64, 256, 4, 1),  # B, 256,  1,  1
            nn.ReLU(True),
            View((-1, 256 * 1 * 1)),  # B, 256
            nn.Linear(256, z_dim * 2),  # B, z_dim*2
        )
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, 256),  # B, 256
            View((-1, 256, 1, 1)),  # B, 256,  1,  1
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 64, 4),  # B,  64,  4,  4
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 64, 4, 2, 1),  # B,  64,  8,  8
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),  # B,  32, 16, 16
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 32, 4, 2, 1),  # B,  32, 32, 32
            nn.ReLU(True),
            nn.ConvTranspose2d(32, nc, 4, 2, 1),  # B, nc, 64, 64
        )

        self.weight_init()

    def weight_init(self):
        """Apply kaiming_init to every layer of the encoder and decoder."""
        for block in self._modules:
            for m in self._modules[block]:
                kaiming_init(m)

    def forward(self, x):
        """Encode x, sample a latent, decode it.

        Returns (x_recon, mu, logvar); x_recon holds raw logits.
        """
        distributions = self._encode(x)
        # First z_dim channels are the mean, the rest the log-variance.
        mu = distributions[:, :self.z_dim]
        logvar = distributions[:, self.z_dim:]
        z = reparametrize(mu, logvar)
        x_recon = self._decode(z)

        return x_recon, mu, logvar

    def _encode(self, x):
        return self.encoder(x)

    def _decode(self, z):
        return self.decoder(z)

import torch
from torch import optim
from torch.utils.data import DataLoader

from beta_vae import BetaVAE_H
import torch.nn.functional as F
from torchvision import datasets, transforms

def recon_loss(x, x_recon):
    """Pixel-wise mean-squared reconstruction loss.

    ``x_recon`` is expected to be raw decoder logits; it is squashed to
    (0, 1) with sigmoid before being compared against ``x``, which should
    already lie in [0, 1] (e.g. ToTensor output).  Returns a scalar tensor.
    """
    # F.sigmoid is deprecated; torch.sigmoid is the supported spelling.
    x_recon = torch.sigmoid(x_recon)
    rec_loss = F.mse_loss(x_recon, x)
    return rec_loss

def kld_loss(mu, logvar):
    """KL( N(mu, exp(logvar)) || N(0, I) ).

    Sums the closed-form KL over latent dimensions, then averages over the
    batch.  Returns a 1-element tensor (keepdim mean).  4-D inputs of shape
    (B, z, 1, 1) coming from a conv head are flattened to (B, z) first.
    """
    # .dim() on the tensor itself replaces the dated .data.ndimension().
    if mu.dim() == 4:
        mu = mu.view(mu.size(0), mu.size(1))
    if logvar.dim() == 4:
        logvar = logvar.view(logvar.size(0), logvar.size(1))

    klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    # Sum over latent dims, mean over batch; keepdim gives a (1,) tensor.
    total_kld = klds.sum(1).mean(0, True)
    return total_kld

def train(epochs=1000, batch_size=128, z_dim=32, device='cuda:2', lr=1e-4, beta=10):
    """Train BetaVAE_H on CIFAR-10 (images resized to 64x64).

    epochs/batch_size/z_dim/lr are the usual hyperparameters; ``beta``
    weights the KL term of the beta-VAE objective; ``device`` selects the
    CUDA device (or 'cpu').  Downloads the dataset on first use.
    """
    dataset = datasets.CIFAR10(root='../dataset/cifar10', train=True, transform=transforms.Compose([
        transforms.Resize(64),  # the 64x64-input architecture requires upscaling 32x32 CIFAR
        transforms.ToTensor()
    ]), download=True)
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=2)

    model = BetaVAE_H(z_dim=z_dim, nc=3).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        # Running sums for the epoch.  NOTE(review): these are never
        # printed, logged, or returned anywhere visible — adding a log
        # line here would make training progress observable.
        loss, r_loss, k_loss = (0, 0, 0)
        for idx, (images, _) in enumerate(dataloader):
            images = images.to(device)
            x_recon, mu, logvar = model(images)
            rec_loss = recon_loss(images, x_recon)
            kl_loss = kld_loss(mu, logvar)

            # beta-VAE objective: reconstruction + beta * KL.
            beta_vae_loss = rec_loss + beta * kl_loss
            optimizer.zero_grad()
            beta_vae_loss.backward()
            loss += beta_vae_loss.item()
            r_loss += rec_loss.item()
            k_loss += kl_loss.item()
            optimizer.step()

`

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant