diff --git a/basenet/vgg16_bn.py b/basenet/vgg16_bn.py
index f3f21a7..57aa375 100644
--- a/basenet/vgg16_bn.py
+++ b/basenet/vgg16_bn.py
@@ -1,10 +1,9 @@
 from collections import namedtuple
-
 import torch
 import torch.nn as nn
 import torch.nn.init as init
 from torchvision import models
-from torchvision.models.vgg import model_urls
+from torchvision.models.vgg import VGG16_BN_Weights
 
 def init_weights(modules):
     for m in modules:
@@ -22,27 +21,30 @@ def init_weights(modules):
 class vgg16_bn(torch.nn.Module):
     def __init__(self, pretrained=True, freeze=True):
         super(vgg16_bn, self).__init__()
-        model_urls['vgg16_bn'] = model_urls['vgg16_bn'].replace('https://', 'http://')
-        vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features
+        # Use the weights parameter based on the pretrained flag
+        weights = VGG16_BN_Weights.IMAGENET1K_V1 if pretrained else None
+        vgg_pretrained_features = models.vgg16_bn(weights=weights).features
+
         self.slice1 = torch.nn.Sequential()
         self.slice2 = torch.nn.Sequential()
         self.slice3 = torch.nn.Sequential()
         self.slice4 = torch.nn.Sequential()
         self.slice5 = torch.nn.Sequential()
-        for x in range(12):         # conv2_2
+
+        for x in range(12):  # conv2_2
             self.slice1.add_module(str(x), vgg_pretrained_features[x])
-        for x in range(12, 19):         # conv3_3
+        for x in range(12, 19):  # conv3_3
             self.slice2.add_module(str(x), vgg_pretrained_features[x])
-        for x in range(19, 29):         # conv4_3
+        for x in range(19, 29):  # conv4_3
             self.slice3.add_module(str(x), vgg_pretrained_features[x])
-        for x in range(29, 39):         # conv5_3
+        for x in range(29, 39):  # conv5_3
             self.slice4.add_module(str(x), vgg_pretrained_features[x])
 
         # fc6, fc7 without atrous conv
         self.slice5 = torch.nn.Sequential(
-                nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
-                nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
-                nn.Conv2d(1024, 1024, kernel_size=1)
+            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
+            nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
+            nn.Conv2d(1024, 1024, kernel_size=1)
         )
 
         if not pretrained:
@@ -51,11 +53,11 @@ def __init__(self, pretrained=True, freeze=True):
             init_weights(self.slice3.modules())
             init_weights(self.slice4.modules())
 
-        init_weights(self.slice5.modules())        # no pretrained model for fc6 and fc7
+        init_weights(self.slice5.modules())  # no pretrained model for fc6 and fc7
 
         if freeze:
-            for param in self.slice1.parameters():      # only first conv
-                param.requires_grad= False
+            for param in self.slice1.parameters():  # only first conv
+                param.requires_grad = False
 
     def forward(self, X):
         h = self.slice1(X)
diff --git a/requirements.txt b/requirements.txt
index f4b2412..ab474cc 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
-torch==0.4.1.post2
-torchvision==0.2.1
-opencv-python==3.4.2.17
-scikit-image==0.14.2
-scipy==1.1.0
\ No newline at end of file
+torch==2.2.2
+torchvision==0.17.2
+opencv-python==4.10.0.84
+scikit-image==0.24.0
+scipy==1.14.0
diff --git a/test.py b/test.py
index 482b503..0f09c4f 100755
--- a/test.py
+++ b/test.py
@@ -47,7 +47,7 @@ def str2bool(v):
 parser.add_argument('--text_threshold', default=0.7, type=float, help='text confidence threshold')
 parser.add_argument('--low_text', default=0.4, type=float, help='text low-bound score')
 parser.add_argument('--link_threshold', default=0.4, type=float, help='link confidence threshold')
-parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda for inference')
+parser.add_argument('--cuda', default=torch.cuda.is_available(), type=str2bool, help='Use cuda for inference')
 parser.add_argument('--canvas_size', default=1280, type=int, help='image size for inference')
 parser.add_argument('--mag_ratio', default=1.5, type=float, help='image magnification ratio')
 parser.add_argument('--poly', default=False, action='store_true', help='enable polygon type')
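
For reference, the torchvision weights-enum API that this patch migrates to can be exercised on its own. A minimal standalone sketch, assuming torchvision >= 0.13 (the 224x224 input and the printed shape are only illustrative, not part of the patch):

import torch
from torchvision import models
from torchvision.models.vgg import VGG16_BN_Weights

# Mirror the `pretrained` flag handled in basenet/vgg16_bn.py:
# passing weights=None skips the download and leaves layers randomly initialized.
pretrained = True
weights = VGG16_BN_Weights.IMAGENET1K_V1 if pretrained else None
features = models.vgg16_bn(weights=weights).features

# Check the first slice boundary used above: indices [0, 12) end at conv2_2's batch norm.
x = torch.randn(1, 3, 224, 224)
out = features[:12](x)
print(out.shape)  # torch.Size([1, 128, 112, 112]) for a 224x224 input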