
Commit

updated DCGAN arch to match the paper
rsk97 committed Mar 9, 2023
1 parent 265c082 commit 8d11a29
Showing 2 changed files with 420 additions and 358 deletions.
92 changes: 55 additions & 37 deletions MoG.py
@@ -1,9 +1,13 @@
from tqdm import tqdm
import sys, os
from os.path import join

import torch
import torch.optim as optim
from torch import autograd
from models.gan import Gen, Dis

from utils.plot import plot_eigens, plot_kde
from utils.utils import batch_net_outputs, net_losses


@@ -16,54 +20,68 @@
batch_size = 512
method = 'ConsOpt'  # 'SimGA' or 'ConsOpt'

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

path = '/home/mila/r/rohan.sukumaran/repos/TheNumericsofGANs_pytorch/results/'

if __name__ == "__main__":
    gen_net = Gen(16, 2).to(device)
    disc_net = Dis(2, 1).to(device)

    params = list(gen_net.parameters()) + list(disc_net.parameters())

    gen_opt = optim.RMSprop(gen_net.parameters(), lr=lr)
    disc_opt = optim.RMSprop(disc_net.parameters(), lr=lr)

    for i in tqdm(range(steps + 1)):

        gen_out, real_in, fake_d_out_gen, fake_d_out_disc, fake_d_out, real_d_out = batch_net_outputs()
        gen_loss_detached, disc_loss_detached, gen_loss, disc_loss = net_losses(fake_d_out_gen, fake_d_out_disc, fake_d_out, real_d_out)

        # Every 5000 steps, plot diagnostics and checkpoint both networks.
        if i % 5000 == 0:
            if method == 'ConsOpt':
                plot_eigens(i)
            plot_kde(i)

            gen_path = join(path, 'Models', 'gen_' + method + "_" + str(i) + '.pt')
            disc_path = join(path, 'Models', 'disc_' + method + "_" + str(i) + '.pt')
            torch.save(gen_net.state_dict(), gen_path)
            torch.save(disc_net.state_dict(), disc_path)

        if method == 'ConsOpt':
            # Consensus optimization: compute each player's gradient, stack the
            # two into a single vector v, and form the regularizer L = 1/2 * ||v||^2.
            gen_net.zero_grad()
            gen_grad = autograd.grad(gen_loss, gen_net.parameters(), retain_graph=True, create_graph=True)
            disc_net.zero_grad()
            disc_grad = autograd.grad(disc_loss, disc_net.parameters(), retain_graph=True, create_graph=True)

            v = list(gen_grad) + list(disc_grad)
            v = torch.cat([t.flatten() for t in v])

            L = 1/2 * torch.dot(v, v)
            jgrads = autograd.grad(L, params, retain_graph=True)

            # Generator step: gamma * grad(L) plus the generator loss gradient.
            gen_opt.zero_grad()
            for j in range(len(params)):
                params[j].grad = jgrads[j] * gamma
            gen_loss_detached.backward(retain_graph=True, create_graph=True)
            gen_opt.step()

            # Discriminator step: gamma * grad(L) plus the discriminator loss gradient.
            disc_opt.zero_grad()
            for j in range(len(params)):
                params[j].grad = jgrads[j] * gamma
            disc_loss_detached.backward(retain_graph=True, create_graph=True)
            disc_opt.step()

        else:
            # SimGA baseline: each player takes a plain gradient step on its own loss.
            gen_opt.zero_grad()
            gen_loss_detached.backward(retain_graph=True, create_graph=True)
            gen_opt.step()

            disc_opt.zero_grad()
            disc_loss_detached.backward(retain_graph=True, create_graph=True)
            disc_opt.step()
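For context, the 'ConsOpt' branch above follows consensus optimization from Mescheder, Nowozin, and Geiger, "The Numerics of GANs" (NeurIPS 2017), which the repository path refers to: both players' gradients are stacked into a single vector v, the shared regularizer L = 1/2 * ||v||^2 is formed, and each player then descends its own loss plus gamma times the gradient of L, which is the jgrads term added to both the generator and discriminator updates.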

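On the commit message itself: the DCGAN architecture described in Radford, Metz, and Chintala, "Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks" (2016), uses an all-convolutional generator and discriminator with batch normalization, ReLU activations and a Tanh output in the generator, and LeakyReLU activations in the discriminator. A minimal sketch of such a pair of modules follows; the class names, latent size, channel widths, and 64x64 RGB output resolution are illustrative assumptions and are not taken from this repository's model definitions.

import torch.nn as nn

# Illustrative DCGAN-style modules (assumed names and sizes; not this repository's code).
class DCGenerator(nn.Module):
    def __init__(self, z_dim=100, ngf=64):
        super().__init__()
        self.net = nn.Sequential(
            # 1x1 latent -> 4x4
            nn.ConvTranspose2d(z_dim, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # 4x4 -> 8x8
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # 8x8 -> 16x16
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # 16x16 -> 32x32
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # 32x32 -> 64x64 RGB image in [-1, 1]
            nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, z):
        return self.net(z.view(z.size(0), -1, 1, 1))


class DCDiscriminator(nn.Module):
    def __init__(self, ndf=64):
        super().__init__()
        self.net = nn.Sequential(
            # 64x64 -> 32x32 (no batchnorm on the input layer, per the paper)
            nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 32x32 -> 16x16
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 16x16 -> 8x8
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 8x8 -> 4x4
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # 4x4 -> one logit per image
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
        )

    def forward(self, x):
        return self.net(x).view(-1)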