You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
Hello, I extracted your BetaVAE_H model and loss function and then trained the model on CIFAR-10. However, even after 10000 epochs of training, the quality of recon_img is still very poor. Is there anything else I haven't taken into account? Please help me. The code I used is listed below:
`from torch import nn
from torch.nn import init
from torch.autograd import Variable
def kaiming_init(m):
    """Initialise a single module ``m`` in place according to its type.

    Suitable for ``model.apply(kaiming_init)``; modules of any other type are
    left untouched.

    * ``nn.Linear`` / ``nn.Conv2d``: Kaiming-normal weights, zero bias.
    * ``nn.BatchNorm1d`` / ``nn.BatchNorm2d``: weight set to 1, bias set to 0.

    Args:
        m: the module to initialise.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        # ``init.kaiming_normal`` (no trailing underscore) is a deprecated alias
        # that emits a warning; the in-place ``kaiming_normal_`` is the current API.
        init.kaiming_normal_(m.weight)
        if m.bias is not None:
            init.zeros_(m.bias)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        # Batch-norm layers built with affine=False have weight/bias set to None,
        # so both must be guarded (the original only guarded the bias).
        if m.weight is not None:
            init.ones_(m.weight)
        if m.bias is not None:
            init.zeros_(m.bias)
class View(nn.Module):
    """Module that reshapes its input to a fixed ``size`` via ``Tensor.view``.

    Lets a reshape step be placed inside ``nn.Sequential`` pipelines.
    """

    # The pasted version named these ``init`` (markdown stripped the dunders),
    # so ``View(size)`` forwarded ``size`` to ``nn.Module.__init__`` and failed.
    def __init__(self, size):
        super(View, self).__init__()
        self.size = size  # target shape passed to Tensor.view (may contain -1)

    def forward(self, tensor):
        """Return ``tensor`` viewed as ``self.size`` (no copy; layout must allow it)."""
        return tensor.view(self.size)
Hello, I extracted your BetaVAE_H model and loss function and then trained the model on CIFAR-10. However, even after 10000 epochs of training, the quality of recon_img is still very poor. Is there anything else I haven't taken into account? Please help me. The code I used is listed below:
`from torch import nn
from torch.nn import init
from torch.autograd import Variable
def reparametrize(mu, logvar):
    """Draw ``z = mu + exp(logvar / 2) * eps`` with ``eps ~ N(0, 1)`` elementwise.

    This is the reparameterisation trick: the noise tensor is created without
    gradient tracking, so gradients reach only ``mu`` and ``logvar``.

    Args:
        mu: posterior means.
        logvar: posterior log-variances, same shape as ``mu``.

    Returns:
        A sample with the same shape, dtype and device as ``mu``/``logvar``.
    """
    std = logvar.div(2).exp()
    # ``Variable(std.data.new(...))`` is the pre-0.4 API; ``new_empty`` gives a
    # detached tensor with the same dtype/device and draws noise identically.
    eps = std.new_empty(std.size()).normal_()
    return mu + std * eps
def kaiming_init(m):
    """Initialise a single module ``m`` in place according to its type.

    Suitable for ``model.apply(kaiming_init)``; modules of any other type are
    left untouched.

    * ``nn.Linear`` / ``nn.Conv2d``: Kaiming-normal weights, zero bias.
    * ``nn.BatchNorm1d`` / ``nn.BatchNorm2d``: weight set to 1, bias set to 0.

    Args:
        m: the module to initialise.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        # ``init.kaiming_normal`` (no trailing underscore) is a deprecated alias
        # that emits a warning; the in-place ``kaiming_normal_`` is the current API.
        init.kaiming_normal_(m.weight)
        if m.bias is not None:
            init.zeros_(m.bias)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        # Batch-norm layers built with affine=False have weight/bias set to None,
        # so both must be guarded (the original only guarded the bias).
        if m.weight is not None:
            init.ones_(m.weight)
        if m.bias is not None:
            init.zeros_(m.bias)
class View(nn.Module):
    """Module that reshapes its input to a fixed ``size`` via ``Tensor.view``.

    Lets a reshape step be placed inside ``nn.Sequential`` pipelines.
    """

    # The pasted version named these ``init`` (markdown stripped the dunders),
    # so ``View(size)`` forwarded ``size`` to ``nn.Module.__init__`` and failed.
    def __init__(self, size):
        super(View, self).__init__()
        self.size = size  # target shape passed to Tensor.view (may contain -1)

    def forward(self, tensor):
        """Return ``tensor`` viewed as ``self.size`` (no copy; layout must allow it)."""
        return tensor.view(self.size)
class BetaVAE_H(nn.Module):
    """Model proposed in the original beta-VAE paper (Higgins et al., ICLR 2017)."""
    # NOTE(review): only the class header and docstring survived in this paste;
    # the encoder/decoder layers and forward() are not shown here, so as written
    # this class has no forward pass. Restore the full definition before use.
from torch import optim
from torch.utils.data import DataLoader
from beta_vae import BetaVAE_H
import torch.nn.functional as F
from torchvision import datasets, transforms
def recon_loss(x, x_recon):
    """Gaussian reconstruction loss between targets ``x`` and decoder logits ``x_recon``.

    The decoder output is squashed with a sigmoid so it lies in (0, 1), matching
    images produced by ``transforms.ToTensor()``. The squared error is then summed
    over every non-batch element and averaged over the batch.

    Why not ``F.mse_loss(...)`` with its default mean reduction: averaging over all
    C*H*W elements shrinks this term by that factor (about 12k for 3x64x64 inputs,
    i.e. CIFAR-10 resized to 64 as in ``train``). A KL term summed over latent
    dimensions and weighted by beta (default 10 in ``train``) keeps its scale.
    The optimiser then effectively ignores reconstruction, and outputs collapse to
    a blurry average image.

    Args:
        x: target batch; the first dimension is the batch.
        x_recon: raw decoder outputs (pre-sigmoid), same shape as ``x``.

    Returns:
        Scalar tensor: per-sample summed squared error, averaged over the batch.

    Raises:
        ValueError: if the batch is empty (the batch average would be undefined).
    """
    batch_size = x.size(0)
    if batch_size == 0:
        raise ValueError("recon_loss received an empty batch")
    # Tensor.sigmoid() replaces the deprecated F.sigmoid alias.
    x_recon = x_recon.sigmoid()
    return F.mse_loss(x_recon, x, reduction='sum').div(batch_size)
def kld_loss(mu, logvar):
    """KL divergence KL(N(mu, exp(logvar)) || N(0, I)) for a diagonal Gaussian posterior.

    The pasted version only reshaped its inputs and implicitly returned None, so
    no KL term was ever computed. This completes it with the closed form
    ``-0.5 * (1 + logvar - mu^2 - exp(logvar))`` per latent dimension.

    Args:
        mu: posterior means, shape (batch, z_dim) or (batch, z_dim, 1, 1).
        logvar: posterior log-variances, same shape as ``mu``.

    Returns:
        Scalar tensor: KL summed over latent dimensions, averaged over the batch.
        This uses the same reduction as a summed-per-sample reconstruction loss.

    Raises:
        ValueError: if the batch is empty.
    """
    batch_size = mu.size(0)
    if batch_size == 0:
        raise ValueError("kld_loss received an empty batch")
    # Conv encoders may emit (batch, z_dim, 1, 1); flatten to (batch, z_dim).
    if mu.dim() == 4:
        mu = mu.view(mu.size(0), mu.size(1))
    if logvar.dim() == 4:
        logvar = logvar.view(logvar.size(0), logvar.size(1))
    klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    return klds.sum(1).mean()
def train(epochs=1000, batch_size=128, z_dim=32, device='cuda:2', lr=1e-4, beta=10):
    """Set up the CIFAR-10 training data pipeline for beta-VAE training.

    NOTE(review): as pasted, the body stops after constructing the DataLoader.
    The model, optimiser and training loop are not shown, so ``epochs``,
    ``z_dim``, ``device``, ``lr`` and ``beta`` are unused here and the function
    returns None.
    """
    # Resize(64) upsamples the 32x32 CIFAR-10 images to 64x64 before ToTensor;
    # presumably because BetaVAE_H expects 64x64 inputs — confirm against the model.
    preprocess = transforms.Compose([
        transforms.Resize(64),
        transforms.ToTensor(),
    ])
    cifar_train = datasets.CIFAR10(
        root='../dataset/cifar10',
        train=True,
        transform=preprocess,
        download=True,
    )
    dataloader = DataLoader(
        dataset=cifar_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
    )
`
The text was updated successfully, but these errors were encountered: