Loss does not converge when training LIQE; what could be the problem? #229

Open
liuchenxin666 opened this issue Dec 2, 2024 · 4 comments

Comments

@liuchenxin666

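```python
import pyiqa

# `device`, `img`, `mos`, and `loss_iqa` are defined elsewhere (not shown in the issue)
model = pyiqa.create_metric('liqe', device=device, pretrained=False)
model.to(device)

pred = model(img)
loss = loss_iqa(pred.squeeze(), mos.float().detach())
```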

@chaofengc
Owner

chaofengc commented Dec 3, 2024

The model runs in inference mode by default, so gradients are disabled to save computation. To train, pass the `as_loss=True` argument: `pyiqa.create_metric('liqe', device=device, as_loss=True, pretrained=False)`
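
A minimal sketch of why the loss cannot move in inference mode, assuming a tiny `torch.nn.Linear` as a hypothetical stand-in for the quality network (it is not the LIQE model, and the data are random): a forward pass under `torch.set_grad_enabled(False)` returns a tensor with no autograd graph, so nothing can be back-propagated to the weights.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for a quality network; NOT the real LIQE model.
net = nn.Linear(8, 1)
x = torch.randn(4, 8)   # fake batch of features
mos = torch.rand(4)     # fake MOS labels

# Inference-style forward: gradient tracking disabled.
with torch.set_grad_enabled(False):
    pred_no_grad = net(x).squeeze(-1)

# Training-style forward: gradient tracking enabled.
with torch.set_grad_enabled(True):
    pred_grad = net(x).squeeze(-1)

print(pred_no_grad.requires_grad)  # False
print(pred_grad.requires_grad)     # True

# Only the prediction computed with gradients enabled can drive an update.
loss = nn.functional.mse_loss(pred_grad, mos)
loss.backward()
has_grad = net.weight.grad is not None
```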

@liuchenxin666
Author

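However, when `as_loss=True` is passed, the `forward()` function of `InferenceModel` never receives the MOS values, so how does `weight_reduce_loss()` compute the loss?
Also, when `as_loss=True`, `forward()` no longer returns the predicted scores. Does that mean I can no longer compute SROCC and PLCC on the test or validation set?
So I do not quite understand what the `as_loss=True` argument is for. Could I simply modify the `forward()` function of `InferenceModel` so that it returns `output`, as below?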

    def forward(self, target, ref=None, **kwargs):
        device = self.device

        with torch.set_grad_enabled(self.as_loss):

            if self.metric_name == 'fid':
                output = self.net(target, ref, device=device, **kwargs)
            elif self.metric_name == 'inception_score':
                output = self.net(target, device=device, **kwargs)
            else:
                if not torch.is_tensor(target):
                    target = imread2tensor(target, rgb=True)
                    target = target.unsqueeze(0)
                    if self.metric_mode == 'FR':
                        assert ref is not None, 'Please specify reference image for Full Reference metric'
                        ref = imread2tensor(ref, rgb=True)
                        ref = ref.unsqueeze(0)
                        self.is_valid_input(ref)
                
                self.is_valid_input(target)

                if self.metric_mode == 'FR':
                    assert ref is not None, 'Please specify reference image for Full Reference metric'
                    output = self.net(target.to(device), ref.to(device), **kwargs)
                elif self.metric_mode == 'NR':
                    output = self.net(target.to(device), **kwargs)

        # if self.as_loss:
        #     if isinstance(output, tuple):
        #         output = output[0]
        #     return weight_reduce_loss(output, self.loss_weight, self.loss_reduction)
        # else:
        #     return output
        return output
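
On the validation question above: SROCC and PLCC only need the predicted scores and the MOS values, so they can be computed outside the model from collected predictions. The following is a minimal pure-Python sketch with made-up numbers; the helper names are hypothetical, and it skips the nonlinear mapping that some IQA protocols fit before computing PLCC.

```python
import math

def pearson(x, y):
    """Pearson linear correlation coefficient (PLCC) of two equal-length lists."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)

def average_ranks(values):
    """1-based ranks; tied values share the mean of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j to the end of the run of values tied with order[i].
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        mean_rank = (i + j) / 2 + 1
        for k in range(i, j + 1):
            ranks[order[k]] = mean_rank
        i = j + 1
    return ranks

def spearman(x, y):
    """Spearman rank-order correlation (SROCC): Pearson computed on the ranks."""
    return pearson(average_ranks(x), average_ranks(y))

# Made-up numbers for illustration only.
preds = [0.2, 0.5, 0.4, 0.9, 0.7]
mos = [1.0, 2.5, 3.0, 4.5, 4.0]
srocc = spearman(preds, mos)
plcc = pearson(preds, mos)
print(round(srocc, 4))  # 0.9
```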

@liuchenxin666
Author

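In addition, my understanding is that `weight_reduce_loss()` simply averages the predicted scores in `output` over the batch. If I have misunderstood anything, I would be grateful for your guidance. Thank you very much!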
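
The reading described above can be sketched as follows. This is a hypothetical helper written for illustration, not pyiqa's `weight_reduce_loss()`; it assumes a scalar weight and the common `"mean"`, `"sum"`, and `"none"` reduction modes. Note that no MOS value enters the computation: the per-image outputs are only scaled and reduced.

```python
import torch

def reduce_scores(scores, weight=1.0, reduction="mean"):
    """Hypothetical helper: scale per-sample values, then reduce over the batch."""
    scores = scores * weight
    if reduction == "mean":
        return scores.mean()
    if reduction == "sum":
        return scores.sum()
    if reduction == "none":
        return scores
    raise ValueError(f"unknown reduction: {reduction}")

batch_scores = torch.tensor([0.2, 0.4, 0.9])  # made-up per-image outputs
reduced = reduce_scores(batch_scores)          # a single scalar tensor
unreduced = reduce_scores(batch_scores, reduction="none")
```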

@chaofengc
Owner

chaofengc commented Dec 4, 2024

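The loss here means using an already trained model as the loss function for another model that is being trained (for example, an image super-resolution model). If you want to train the IQA model itself, please refer to the training code in this repository.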
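
A minimal sketch of that usage pattern, under stated assumptions: `metric` is a small frozen network standing in for a trained quality model, `restorer` is a toy convolution standing in for a super-resolution model, and a higher score is assumed to mean better quality. None of these are pyiqa objects. The point is that gradients flow through the frozen scorer into the model being trained, while the scorer itself is never updated.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for a trained quality scorer; its weights are frozen.
metric = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1), nn.Sigmoid())
for p in metric.parameters():
    p.requires_grad_(False)
metric.eval()

# Toy model being trained (stand-in for a super-resolution network).
restorer = nn.Conv2d(3, 3, kernel_size=3, padding=1)
optimizer = torch.optim.SGD(restorer.parameters(), lr=0.1)

degraded = torch.rand(2, 3, 8, 8)  # fake input batch
metric_before = [p.clone() for p in metric.parameters()]
restorer_before = restorer.weight.detach().clone()

output = restorer(degraded)
score = metric(output)        # per-image quality score in (0, 1)
loss = (1.0 - score).mean()   # minimizing this raises the score

optimizer.zero_grad()
loss.backward()               # gradients pass through the frozen scorer
optimizer.step()              # only the restorer is updated

restorer_changed = not torch.equal(restorer_before, restorer.weight)
metric_unchanged = all(
    torch.equal(a, b) for a, b in zip(metric_before, metric.parameters())
)
```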
