Skip to content

Commit

Permalink
updated function signature
Browse files Browse the repository at this point in the history
  • Loading branch information
rsk97 committed Mar 9, 2023
1 parent 6db0581 commit ecfef5a
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion MoG.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@

if i%5000 == 0:
if method == 'ConsOpt':
plot_eigens(i, gen_net, disc_net, params, gamma, path, device)
plot_eigens(i, gen_net, disc_net, params, gamma, path, batch_size, z_dim, sigma, criterion, device)
plot_kde(i, method, sigma, gen_net, path, device, batch_size, z_dim, real_input=False)

gen_path = join(path, 'Models', 'gen_' + method + "_" + str(i) + '.pt')
Expand Down
6 changes: 3 additions & 3 deletions utils/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@



def plot_eigens(iteration, gen, disc, params, gamma, path, device):
_, _, fake_d_out_gen, fake_d_out_disc, fake_d_out, real_d_out = batch_net_outputs()
_, _, gen_loss, disc_loss = net_losses(fake_d_out_gen, fake_d_out_disc, fake_d_out, real_d_out)
def plot_eigens(iteration, gen, disc, params, gamma, path, batch_size, z_dim, sigma, criterion, device):
_, _, fake_d_out_gen, fake_d_out_disc, fake_d_out, real_d_out = batch_net_outputs(gen, disc, batch_size, z_dim, sigma, device)
_, _, gen_loss, disc_loss = net_losses(criterion, fake_d_out_gen, fake_d_out_disc, fake_d_out, real_d_out)
p_count = torch.cat([x.flatten() for x in params]).shape[0]

gen.zero_grad()
Expand Down

0 comments on commit ecfef5a

Please sign in to comment.