diff --git a/basenet/vgg16_bn.py b/basenet/vgg16_bn.py index f3f21a7..2e35648 100644 --- a/basenet/vgg16_bn.py +++ b/basenet/vgg16_bn.py @@ -4,7 +4,6 @@ import torch.nn as nn import torch.nn.init as init from torchvision import models -from torchvision.models.vgg import model_urls def init_weights(modules): for m in modules: @@ -22,7 +21,6 @@ def init_weights(modules): class vgg16_bn(torch.nn.Module): def __init__(self, pretrained=True, freeze=True): super(vgg16_bn, self).__init__() - model_urls['vgg16_bn'] = model_urls['vgg16_bn'].replace('https://', 'http://') vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features self.slice1 = torch.nn.Sequential() self.slice2 = torch.nn.Sequential() diff --git a/craft.py b/craft.py index 27131df..96bb591 100755 --- a/craft.py +++ b/craft.py @@ -78,8 +78,11 @@ def forward(self, x): y = self.conv_cls(feature) return y.permute(0,2,3,1), feature + + def unload(self): + del self if __name__ == '__main__': model = CRAFT(pretrained=True).cuda() output, _ = model(torch.randn(1, 3, 768, 768).cuda()) - print(output.shape) \ No newline at end of file + #print(output.shape) \ No newline at end of file diff --git a/file_utils.py b/file_utils.py index 94ab040..4ed1a20 100644 --- a/file_utils.py +++ b/file_utils.py @@ -30,7 +30,7 @@ def list_files(in_path): # gt_files.sort() return img_files, mask_files, gt_files -def saveResult(img_file, img, boxes, dirname='./result/', verticals=None, texts=None): +def saveResult(img_file, img, boxes, dirname='./result/', split=None, draw_bbox=False, verticals=None, texts=None): """ save text detection result one by one Args: img_file (str): image file name @@ -46,8 +46,12 @@ def saveResult(img_file, img, boxes, dirname='./result/', verticals=None, texts= filename, file_ext = os.path.splitext(os.path.basename(img_file)) # result directory - res_file = dirname + "res_" + filename + '.txt' - res_img_file = dirname + "res_" + filename + '.jpg' + if split is not None: + res_file = f"{dirname}res_{split}_{filename}.txt" + res_img_file = f"{dirname}res_{split}_{filename}.jpg" + else: + res_file = f"{dirname}res_{filename}.txt" + res_img_file = f"{dirname}res_{filename}.jpg" if not os.path.isdir(dirname): os.mkdir(dirname) @@ -58,9 +62,10 @@ def saveResult(img_file, img, boxes, dirname='./result/', verticals=None, texts= strResult = ','.join([str(p) for p in poly]) + '\r\n' f.write(strResult) - poly = poly.reshape(-1, 2) - cv2.polylines(img, [poly.reshape((-1, 1, 2))], True, color=(0, 0, 255), thickness=2) - ptColor = (0, 255, 255) + if draw_bbox: + poly = poly.reshape(-1, 2) + cv2.polylines(img, [poly.reshape((-1, 1, 2))], True, color=(0, 0, 255), thickness=2) + ptColor = (0, 255, 255) if verticals is not None: if verticals[i]: ptColor = (255, 0, 0) diff --git a/imgproc.py b/imgproc.py index ab09d6f..ee34833 100644 --- a/imgproc.py +++ b/imgproc.py @@ -8,13 +8,14 @@ from skimage import io import cv2 -def loadImage(img_file): +def loadImage(img_file, crop=None): img = io.imread(img_file) # RGB order if img.shape[0] == 2: img = img[0] if len(img.shape) == 2 : img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) if img.shape[2] == 4: img = img[:,:,:3] img = np.array(img) - + if crop is not None: + img = img[crop[0]:crop[1],crop[2]:crop[3],:] return img def normalizeMeanVariance(in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)): diff --git a/test.py b/test.py index 482b503..4c0cd04 100755 --- a/test.py +++ b/test.py @@ -53,6 +53,7 @@ def str2bool(v): parser.add_argument('--poly', default=False, action='store_true', help='enable polygon type') parser.add_argument('--show_time', default=False, action='store_true', help='show processing time') parser.add_argument('--test_folder', default='/data/', type=str, help='folder path to input images') +parser.add_argument('--result_folder', default='./result/', type=str, help='folder path to output images') parser.add_argument('--refine', default=False, action='store_true', help='enable link refiner') parser.add_argument('--refiner_model', default='weights/craft_refiner_CTW1500.pth', type=str, help='pretrained refiner model') @@ -62,7 +63,7 @@ def str2bool(v): """ For test images in a folder """ image_list, _, _ = file_utils.get_files(args.test_folder) -result_folder = './result/' +result_folder = args.result_folder if not os.path.isdir(result_folder): os.mkdir(result_folder)