File tree 1 file changed +5
-3
lines changed
1 file changed +5
-3
lines changed Original file line number Diff line number Diff line change 4
4
import torch .nn as nn
5
5
import torch .nn .init as init
6
6
from torchvision import models
7
- from torchvision .models .vgg import model_urls
7
+ from torchvision .models .vgg import VGG16_BN_Weights
8
8
9
9
def init_weights (modules ):
10
10
for m in modules :
@@ -22,8 +22,10 @@ def init_weights(modules):
22
22
class vgg16_bn (torch .nn .Module ):
23
23
def __init__ (self , pretrained = True , freeze = True ):
24
24
super (vgg16_bn , self ).__init__ ()
25
- model_urls ['vgg16_bn' ] = model_urls ['vgg16_bn' ].replace ('https://' , 'http://' )
26
- vgg_pretrained_features = models .vgg16_bn (pretrained = pretrained ).features
25
+ # Use the weights parameter based on the pretrained flag
26
+ weights = VGG16_BN_Weights .IMAGENET1K_V1 if pretrained else None
27
+ vgg_pretrained_features = models .vgg16_bn (weights = weights ).features
28
+
27
29
self .slice1 = torch .nn .Sequential ()
28
30
self .slice2 = torch .nn .Sequential ()
29
31
self .slice3 = torch .nn .Sequential ()
You can’t perform that action at this time.
0 commit comments