From 149afd7ac08fd92889b8a63a65ffc743d15ccae5 Mon Sep 17 00:00:00 2001
From: xuyou314 <714098455@qq.com>
Date: Tue, 29 Jan 2019 23:26:29 +0800
Subject: [PATCH] Replace the deprecated torch.autograd.Variable API with
 torch.tensor(..., requires_grad=True)

---
 VAE/vanilla_vae/vae_pytorch.py | 32 +++++++++++++++-----------------
 1 file changed, 15 insertions(+), 17 deletions(-)

diff --git a/VAE/vanilla_vae/vae_pytorch.py b/VAE/vanilla_vae/vae_pytorch.py
index fb09577..ff24d07 100644
--- a/VAE/vanilla_vae/vae_pytorch.py
+++ b/VAE/vanilla_vae/vae_pytorch.py
@@ -6,7 +6,6 @@
 import matplotlib.pyplot as plt
 import matplotlib.gridspec as gridspec
 import os
-from torch.autograd import Variable
 
 from tensorflow.examples.tutorials.mnist import input_data
 
@@ -23,19 +22,20 @@ def xavier_init(size):
     in_dim = size[0]
     xavier_stddev = 1. / np.sqrt(in_dim / 2.)
-    return Variable(torch.randn(*size) * xavier_stddev, requires_grad=True)
+    n = np.random.randn(*size) * xavier_stddev
+    return torch.tensor(n, requires_grad=True, dtype=torch.float32)
 
 
 # =============================== Q(z|X) ======================================
 
 Wxh = xavier_init(size=[X_dim, h_dim])
-bxh = Variable(torch.zeros(h_dim), requires_grad=True)
+bxh = torch.zeros(h_dim, requires_grad=True)
 
 Whz_mu = xavier_init(size=[h_dim, Z_dim])
-bhz_mu = Variable(torch.zeros(Z_dim), requires_grad=True)
+bhz_mu = torch.zeros(Z_dim, requires_grad=True)
 
 Whz_var = xavier_init(size=[h_dim, Z_dim])
-bhz_var = Variable(torch.zeros(Z_dim), requires_grad=True)
+bhz_var = torch.zeros(Z_dim, requires_grad=True)
 
 
 def Q(X):
@@ -46,17 +46,17 @@ def Q(X):
 
 
 def sample_z(mu, log_var):
-    eps = Variable(torch.randn(mb_size, Z_dim))
+    eps = torch.randn(mb_size, Z_dim)
     return mu + torch.exp(log_var / 2) * eps
 
 
 # =============================== P(X|z) ======================================
 
 Wzh = xavier_init(size=[Z_dim, h_dim])
-bzh = Variable(torch.zeros(h_dim), requires_grad=True)
+bzh = torch.zeros(h_dim, requires_grad=True)
 
 Whx = xavier_init(size=[h_dim, X_dim])
-bhx = Variable(torch.zeros(X_dim), requires_grad=True)
+bhx = torch.zeros(X_dim, requires_grad=True)
 
 
 def P(z):
@@ -74,7 +74,7 @@ def P(z):
 
 for it in range(100000):
     X, _ = mnist.train.next_batch(mb_size)
-    X = Variable(torch.from_numpy(X))
+    X = torch.from_numpy(X)
 
     # Forward
     z_mu, z_var = Q(X)
@@ -85,23 +85,21 @@ def P(z):
     recon_loss = nn.binary_cross_entropy(X_sample, X, size_average=False) / mb_size
     kl_loss = torch.mean(0.5 * torch.sum(torch.exp(z_var) + z_mu**2 - 1. - z_var, 1))
     loss = recon_loss + kl_loss
-
     # Backward
     loss.backward()
 
     # Update
     solver.step()
-
+    solver.zero_grad()
     # Housekeeping
-    for p in params:
-        if p.grad is not None:
-            data = p.grad.data
-            p.grad = Variable(data.new().resize_as_(data).zero_())
+    # for p in params:
+    #     if p.grad is not None:
+    #         data = p.grad.data
+    #         p.grad = data.new().resize_as_(data).zero_()
 
     # Print and plot every now and then
     if it % 1000 == 0:
-        print('Iter-{}; Loss: {:.4}'.format(it, loss.data[0]))
-
+        print('Iter-{}; Loss: {:.4}'.format(it, loss.data.item()))
 
         samples = P(z).data.numpy()[:16]
 
         fig = plt.figure(figsize=(4, 4))
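
Reviewer note: below is a minimal standalone sketch, not part of the patch, of the
migration pattern this commit applies, assuming PyTorch >= 0.4. The 784/128 sizes
and the Adam optimizer are stand-ins for the script's X_dim/h_dim and solver,
which sit outside the hunks shown above.

    import numpy as np
    import torch

    def xavier_init(size):
        in_dim = size[0]
        xavier_stddev = 1. / np.sqrt(in_dim / 2.)
        n = np.random.randn(*size) * xavier_stddev
        # torch.tensor copies the NumPy array and, with requires_grad=True,
        # returns a leaf tensor the optimizer can update -- the modern
        # replacement for Variable(..., requires_grad=True).
        return torch.tensor(n, requires_grad=True, dtype=torch.float32)

    W = xavier_init(size=[784, 128])          # 784/128 assumed for X_dim/h_dim
    b = torch.zeros(128, requires_grad=True)  # factory functions take requires_grad directly

    out = torch.sigmoid(torch.randn(4, 784) @ W + b).sum()
    out.backward()
    assert W.grad is not None and b.grad is not None  # gradients still flow

    opt = torch.optim.Adam([W, b], lr=1e-3)  # stand-in for the script's solver
    opt.step()
    opt.zero_grad()  # replaces the manual per-parameter grad reset the patch comments out

Calling zero_grad() on the optimizer after step() matches the patch: it zeroes
every registered parameter's .grad in place, so the old loop that rebuilt each
p.grad by hand (and needed Variable) becomes unnecessary.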