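Question: the two losses below both involve the discriminator and both take pred_fake as input, but their targets differ. Won't this cause a conflict at each optimization step? #42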

Open
cjt222 opened this issue Jun 10, 2021 · 1 comment

Comments

@cjt222
cjt222 commented Jun 10, 2021

No description provided.

@cjt222
Author

cjt222 commented Jun 10, 2021

def compute_generator_loss(self, input_semantics, real_image):
    G_losses = {}

    fake_image = self.generate_fake(input_semantics)

    pred_fake, pred_real = self.discriminate(input_semantics, fake_image, real_image)

    G_losses['GAN'] = self.criterionGAN(pred_fake, True,
                                        for_discriminator=False)

    if not self.opt.no_ganFeat_loss:
        num_D = len(pred_fake)
        GAN_Feat_loss = self.FloatTensor(1).fill_(0)
        for i in range(num_D):  # for each discriminator
            # last output is the final prediction, so we exclude it
            num_intermediate_outputs = len(pred_fake[i]) - 1
            for j in range(num_intermediate_outputs):  # for each layer output
                unweighted_loss = self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach())
                GAN_Feat_loss += unweighted_loss * self.opt.lambda_feat / num_D
        G_losses['GAN_Feat'] = GAN_Feat_loss
    
    h ,w = fake_image.shape[-2:]
    if not self.opt.no_vgg_loss and min(w,h)>=64:
        G_losses['VGG'] = self.criterionVGG(fake_image, real_image) \
                          * self.opt.lambda_vgg

    return G_losses, fake_image

def compute_discriminator_loss(self, input_semantics, real_image):
    D_losses = {}
    with torch.no_grad():
        fake_image = self.generate_fake(input_semantics)
        fake_image = fake_image.detach()
        fake_image.requires_grad_()

    pred_fake, pred_real = self.discriminate(
        input_semantics, fake_image, real_image)

    D_losses['D_Fake'] = self.criterionGAN(pred_fake, False,
                                           for_discriminator=True)
    D_losses['D_real'] = self.criterionGAN(pred_real, True,
                                           for_discriminator=True)

    return D_losses

gan_mode is ls, so the MSE loss is used. Here G_losses['GAN'] and D_losses['D_Fake'] appear to perform opposite operations.
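
To make the "opposite operations" observation concrete, here is a minimal sketch with a toy scalar model. It is not the repository's code: `w_g`, `w_d`, `z`, `lr` and the helper functions are hypothetical, and it assumes the usual GAN setup in which each loss is stepped by its own optimizer. The two losses do share pred_fake and use opposite targets, but the target-True loss only moves the generator's parameter, while the target-False loss treats the fake sample as a constant (the analogue of fake_image.detach() above) and only moves the discriminator's parameter.

```python
# Toy scalar model, NOT the repository's networks (all names are hypothetical):
#   generator      G(z) = w_g * z
#   discriminator  D(x) = w_d * x
# With gan_mode "ls" the loss is the squared error to the target label.

def ls_loss(pred, target_is_real):
    """Least-squares GAN loss for a single scalar prediction."""
    target = 1.0 if target_is_real else 0.0
    return (pred - target) ** 2

def ls_loss_grad(pred, target_is_real):
    """Derivative of ls_loss with respect to pred."""
    target = 1.0 if target_is_real else 0.0
    return 2.0 * (pred - target)

w_g, w_d, z, lr = 0.5, 0.6, 1.0, 0.1

fake = w_g * z          # generated sample
pred_fake = w_d * fake  # discriminator score for the generated sample

# Generator step: target is True; the gradient passes through D but
# only w_g is updated (chain rule: d pred / d w_g = w_d * z).
g_loss = ls_loss(pred_fake, True)
grad_w_g = ls_loss_grad(pred_fake, True) * w_d * z
new_w_g = w_g - lr * grad_w_g

# Discriminator step: target is False; `fake` is treated as a constant,
# so only w_d is updated (d pred / d w_d = fake).
d_fake_loss = ls_loss(pred_fake, False)
grad_w_d = ls_loss_grad(pred_fake, False) * fake
new_w_d = w_d - lr * grad_w_d

# Score of the fake sample after each step, taken separately.
pred_after_g_step = w_d * new_w_g * z
pred_after_d_step = new_w_d * fake

print(pred_fake, pred_after_g_step, pred_after_d_step)
```

In this sketch the generator step raises the score of the fake sample and the discriminator step lowers it. That opposition is the adversarial objective itself rather than two conflicting targets applied to the same parameters.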
