Skip to content

Commit e8a8cdf

Browse files
committed
fix for torchvision >= 0.17.2
see also clovaai#215
1 parent c096b69 commit e8a8cdf

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

basenet/vgg16_bn.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch.nn as nn
55
import torch.nn.init as init
66
from torchvision import models
7-
from torchvision.models.vgg import model_urls
7+
from torchvision.models.vgg import VGG16_BN_Weights
88

99
def init_weights(modules):
1010
for m in modules:
@@ -22,8 +22,10 @@ def init_weights(modules):
2222
class vgg16_bn(torch.nn.Module):
2323
def __init__(self, pretrained=True, freeze=True):
2424
super(vgg16_bn, self).__init__()
25-
model_urls['vgg16_bn'] = model_urls['vgg16_bn'].replace('https://', 'http://')
26-
vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features
25+
# Use the weights parameter based on the pretrained flag
26+
weights = VGG16_BN_Weights.IMAGENET1K_V1 if pretrained else None
27+
vgg_pretrained_features = models.vgg16_bn(weights=weights).features
28+
2729
self.slice1 = torch.nn.Sequential()
2830
self.slice2 = torch.nn.Sequential()
2931
self.slice3 = torch.nn.Sequential()

0 commit comments

Comments
 (0)