Hello, loading the provided weight.pkl raises an error? #13

Open
douhaoexia opened this issue Nov 30, 2019 · 5 comments

Comments

@douhaoexia

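Hello, I would like to test with the weights.pkl file you provided, but torch.load() always raises an error.
See the screenshot:
[Screenshot: 2019-11-30 21-01-57]

I have already tried torch.load() options such as map_location, but loading still fails with the same error.
What is causing this?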

@little-one

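First rename the model folder in the code to models, then load the file with the following code:
comb = torch.load(path)
fsrnet = comb['model']
Written this way, it works.
Try it and see whether it runs.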
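
Why renaming the folder can matter: torch.load() relies on Python's pickle, and a pickled object stores only the import path of its class, not the class itself. The sketch below illustrates that behaviour with plain pickle and no PyTorch. The module name demo_models and the stand-in Generator class are assumptions made up for the illustration; the exact error in the screenshot is not known.

```python
import pickle
import sys
import types

# Hypothetical stand-in for the package the checkpoint was saved from.
demo_models = types.ModuleType("demo_models")


class Generator:
    """Stand-in for the real network class (illustration only)."""

    def __init__(self):
        self.scale = 8


# Make the class look as if it had been defined inside demo_models.
Generator.__module__ = "demo_models"
Generator.__qualname__ = "Generator"
demo_models.Generator = Generator
sys.modules["demo_models"] = demo_models

# Saving the whole object records only "demo_models.Generator" plus its state.
blob = pickle.dumps({"model": Generator()})

# Simulate a checkout where no package of that name can be imported.
del sys.modules["demo_models"]
try:
    pickle.loads(blob)
    load_error = None
except ModuleNotFoundError as exc:
    load_error = str(exc)
print("without the package:", load_error)

# Once the package is importable under the recorded name, loading succeeds.
sys.modules["demo_models"] = demo_models
comb = pickle.loads(blob)
fsrnet = comb["model"]
print("loaded:", type(fsrnet).__name__)
```

If the provided weights.pkl was saved from code whose package was named models, the same lookup would fail until a package of that name exists, which would be consistent with the rename suggested above.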

@douhaoexia
Author

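After renaming the folder to models, that error went away, but a new one appeared.
I followed the part of your train.py that loads the pretrained model and wrote this: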

weights = torch.load(path)
print(weights['model'])
pretrained_dict = weights['model'].state_dict()
model_dict = fsrnet.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
fsrnet.load_state_dict(model_dict)

The error is:
Traceback (most recent call last):
  File "test.py", line 100, in <module>
    fsrnet.load_state_dict(model_dict)
  File "/home/dou/miniconda3/envs/pytorch_101/lib/python3.6/site-packages/torch/nn/modules/module.py", line 769, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Generator:
	size mismatch for fine_SR_Decoder.0.weight: copying a param with shape torch.Size([64, 75, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 192, 3, 3]).

Could you suggest what I should change?
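
The filter in the snippet above keeps every key that exists in both dictionaries, so an entry whose name matches but whose shape differs still reaches load_state_dict() and triggers the size mismatch. A minimal sketch of a shape-aware split follows. The helper name split_by_shape is an assumption, the demo uses stand-in objects with a .shape attribute instead of real tensors, and the bias entry is hypothetical (only the weight shapes come from the traceback).

```python
from types import SimpleNamespace


def split_by_shape(pretrained_dict, model_dict):
    """Separate pretrained entries into loadable ones and shape mismatches."""
    compatible, mismatched = {}, {}
    for name, tensor in pretrained_dict.items():
        if name not in model_dict:
            continue  # key unknown to the current model: ignore it
        ckpt_shape = tuple(tensor.shape)
        model_shape = tuple(model_dict[name].shape)
        if ckpt_shape == model_shape:
            compatible[name] = tensor
        else:
            mismatched[name] = (ckpt_shape, model_shape)
    return compatible, mismatched


# Stand-ins for state_dict() entries; real code would pass the two state dicts.
pretrained = {
    "fine_SR_Decoder.0.weight": SimpleNamespace(shape=(64, 75, 3, 3)),
    "fine_SR_Decoder.0.bias": SimpleNamespace(shape=(64,)),  # hypothetical entry
}
current = {
    "fine_SR_Decoder.0.weight": SimpleNamespace(shape=(64, 192, 3, 3)),
    "fine_SR_Decoder.0.bias": SimpleNamespace(shape=(64,)),  # hypothetical entry
}

compatible, mismatched = split_by_shape(pretrained, current)
for name, (ckpt_shape, model_shape) in mismatched.items():
    print(f"{name}: checkpoint {ckpt_shape} vs model {model_shape}")
```

Updating model_dict with only the compatible entries would let load_state_dict() run, but every mismatched layer would keep its freshly initialised weights. That makes this a way to locate the problem rather than a fix for it.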

@little-one

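First of all, I am not the author; I ran into the same problem and came here to look.
I debugged this yesterday, so you can try the following approach (if I understand correctly, pretrained_dict in your code is loaded from the author's pretrained model and model_dict belongs to the new network you declared yourself, so I will use those two variables in my example):
preList = []
for i in pretrained_dict:
    preList.append(i)
Here pretrained_dict is what you get by loading the author's pretrained model, and you can build modList from model_dict in the same way. Print len(preList) and len(modList): if I remember correctly, preList has 718 entries and modList has 716. In other words, the network loaded from the pretrained model and the network declared in the code no longer have the same structure. (The pretrained network actually has one extra 1x1 convolution layer, so it carries that layer's weight and bias as well, which is why it has 2 more entries than 716.)
Finding where that layer sits is fairly easy: write some code that computes the set difference between preList and modList and prints it. The result is the name of the convolution layer, and from the naming convention you can trace it back. You can also print(fsrnet) and print(pretrainedModel) and compare the two network structures to confirm where the extra layer is. After that, manually add that convolution layer to fsrnet, load the parameters again, and it should work.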
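
The key comparison described above can be sketched as follows. The helper name diff_keys and the layer names in the demo dictionaries are hypothetical; with the real model, the two arguments would be the checkpoint's state_dict() and fsrnet.state_dict().

```python
def diff_keys(pretrained_dict, model_dict):
    """Return keys found only in the checkpoint and keys found only in the model."""
    only_in_pretrained = [k for k in pretrained_dict if k not in model_dict]
    only_in_model = [k for k in model_dict if k not in pretrained_dict]
    return only_in_pretrained, only_in_model


# Made-up layer names standing in for the real state_dict keys.
pretrained_dict = {
    "encoder.0.weight": 0,
    "encoder.0.bias": 0,
    "extra_conv.weight": 0,  # hypothetical extra 1x1 convolution
    "extra_conv.bias": 0,
}
model_dict = {
    "encoder.0.weight": 0,
    "encoder.0.bias": 0,
}

only_in_pretrained, only_in_model = diff_keys(pretrained_dict, model_dict)
print(len(pretrained_dict), len(model_dict))
print("only in checkpoint:", only_in_pretrained)
print("only in model:", only_in_model)
```

Lists are used instead of sets so that the printed names keep the order in which the layers appear, which makes it easier to see where in the network the extra layer belongs.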

@douhaoexia
Author

Thank you very much. I can change the network definition so that it loads the pretrained network.

@yangyingni

Hello, would you mind sharing your test code? Thank you very much.
