Skip to content

Commit c67bbab

Browse files
authored
[TorchFix] Update deprecated TorchVision pretrained parameters (#1193)
Update deprecated TorchVision pretrained parameters
1 parent cead596 commit c67bbab

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

cpp/transfer-learning/convert.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from torchvision import models
66

77
# Download and load the pre-trained model
8-
model = models.resnet18(pretrained=True)
8+
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
99

1010
# Set upgrading the gradients to False
1111
for param in model.parameters():

fast_neural_style/neural_style/vgg.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
class Vgg16(torch.nn.Module):
88
def __init__(self, requires_grad=False):
99
super(Vgg16, self).__init__()
10-
vgg_pretrained_features = models.vgg16(pretrained=True).features
10+
vgg_pretrained_features = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
1111
self.slice1 = torch.nn.Sequential()
1212
self.slice2 = torch.nn.Sequential()
1313
self.slice3 = torch.nn.Sequential()

0 commit comments

Comments
 (0)