diff --git a/MoG.py b/MoG.py index e7a487a..90a2637 100644 --- a/MoG.py +++ b/MoG.py @@ -43,7 +43,7 @@ if i%5000 == 0: if method == 'ConsOpt': - plot_eigens(i, gen_net, disc_net, params, gamma, path, device) + plot_eigens(i, gen_net, disc_net, params, gamma, path, batch_size, z_dim, sigma, criterion, device) plot_kde(i, method, sigma, gen_net, path, device, batch_size, z_dim, real_input=False) gen_path = join(path, 'Models', 'gen_' + method + "_" + str(i) + '.pt') diff --git a/utils/plot.py b/utils/plot.py index 089ed8e..ce759fd 100644 --- a/utils/plot.py +++ b/utils/plot.py @@ -6,9 +6,9 @@ -def plot_eigens(iteration, gen, disc, params, gamma, path, device): - _, _, fake_d_out_gen, fake_d_out_disc, fake_d_out, real_d_out = batch_net_outputs() - _, _, gen_loss, disc_loss = net_losses(fake_d_out_gen, fake_d_out_disc, fake_d_out, real_d_out) +def plot_eigens(iteration, gen, disc, params, gamma, path, batch_size, z_dim, sigma, criterion, device): + _, _, fake_d_out_gen, fake_d_out_disc, fake_d_out, real_d_out = batch_net_outputs(gen, disc, batch_size, z_dim, sigma, device) + _, _, gen_loss, disc_loss = net_losses(criterion, fake_d_out_gen, fake_d_out_disc, fake_d_out, real_d_out) p_count = torch.cat([x.flatten() for x in params]).shape[0] gen.zero_grad()