-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
129 lines (102 loc) · 4.06 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# PREDICT A SINGLE IMAGE
# import argparse
# import glob
# import os
# import sys
# import cv2
# import torch
# from PIL import Image
# from vietocr.tool.config import Cfg
# from vietocr.tool.predictor import Predictor
# def main():
# parser = argparse.ArgumentParser(description="Example script with a required command line argument",
# formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# parser.add_argument("-F", "--file", type=str, help="input file or folder path", required=True)
# args = parser.parse_args()
# # Check if the provided path is a file or a folder
# if os.path.isfile(args.file):
# img_paths = [args.file]
# elif os.path.isdir(args.file):
# img_paths = glob.glob(os.path.join(args.file, '*.jpg'))
# else:
# print("Error: The provided path is neither a file nor a folder.")
# sys.exit(1)
# # Open the annotation file for writing
# with open("annotation.txt", "w", encoding="utf-8") as f:
# # Configure VietOCR
# config = Cfg.load_config_from_name('vgg_transformer')
# if torch.cuda.is_available():
# config['device'] = "cuda:0"
# else:
# config['device'] = "cpu"
# config['cnn']['pretrained'] = True
# config['trainer']['batch_size'] = 2048
# print(config)
# recognitor = Predictor(config)
# # Process each image
# for img_path in img_paths:
# img = cv2.imread(img_path)
# img = Image.fromarray(img)
# rec_result = recognitor.predict(img)
# f.write(os.path.basename(img_path) + "\t" + rec_result + "\n")
# # print(img_path + "\t" + rec_result)
# print("Processing completed. Results written to annotation.txt.")
# if __name__ == "__main__":
# main()
# PREDICT BATCH OF IMAGES
import argparse
import glob
import os
import sys
from tqdm import tqdm
import cv2
import torch
from PIL import Image
from vietocr.tool.config import Cfg
from vietocr.tool.predictor import Predictor
def _parse_args():
    """Parse and validate the command-line arguments of the batch OCR script."""
    parser = argparse.ArgumentParser(
        description=(
            "Recognize text with VietOCR in an image file, or in every .jpg image "
            "directly inside a folder, and write one '<filename>\\t<text>' line per "
            "image to the output file."
        ),
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "-F", "--file", type=str, help="input file or folder path", required=True
    )
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        default="khaisinh_kethon.txt",
        help="path of the tab-separated results file",
    )
    parser.add_argument(
        "-b",
        "--batch-size",
        type=int,
        default=256,
        help="images per predict_batch call; lower it if you run out of memory",
    )
    args = parser.parse_args()
    # range(..., step=0) raises and negative steps yield nothing, so reject early.
    if args.batch_size < 1:
        parser.error("--batch-size must be a positive integer")
    return args


def _collect_image_paths(path):
    """Return the image paths to process for *path*; exit with status 1 if invalid.

    A file is returned as-is. For a folder, the files matching ``*.jpg`` directly
    inside it are returned, sorted so the output order is deterministic
    (``glob`` itself guarantees no particular order).
    """
    if os.path.isfile(path):
        return [path]
    if os.path.isdir(path):
        return sorted(glob.glob(os.path.join(path, "*.jpg")))
    # sys.exit with a string prints it to stderr and exits with status 1.
    sys.exit("Error: The provided path is neither a file nor a folder.")


def _build_predictor(batch_size):
    """Create a VietOCR ``vgg_transformer`` predictor on ``cuda:0`` if available, else CPU."""
    config = Cfg.load_config_from_name("vgg_transformer")
    config["device"] = "cuda:0" if torch.cuda.is_available() else "cpu"
    config["cnn"]["pretrained"] = True
    config["trainer"]["batch_size"] = batch_size
    return Predictor(config)


def _load_image(path):
    """Read *path* as an RGB PIL image, or return None if OpenCV cannot decode it."""
    img = cv2.imread(path)
    if img is None:
        return None
    # cv2.imread returns BGR channel order; PIL images (and VietOCR) expect RGB.
    return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))


def main():
    """Run VietOCR over the input image(s) and write one result line per image.

    Each line of the output file is ``<basename>\\t<recognized text>``. Images that
    OpenCV cannot read are skipped with a warning on stderr.
    """
    args = _parse_args()
    img_paths = _collect_image_paths(args.file)
    recognitor = _build_predictor(args.batch_size)

    with open(args.output, "w", encoding="utf-8") as f:
        for start in tqdm(
            range(0, len(img_paths), args.batch_size), desc="Processing", unit="batch"
        ):
            batch_paths = []
            batch_images = []
            for path in img_paths[start : start + args.batch_size]:
                img = _load_image(path)
                if img is None:
                    # tqdm.write keeps the progress bar intact while printing.
                    tqdm.write(
                        f"Warning: could not read image, skipped: {path}",
                        file=sys.stderr,
                    )
                    continue
                batch_paths.append(path)
                batch_images.append(img)
            if not batch_images:
                continue
            batch_results = recognitor.predict_batch(batch_images)
            for path, result in zip(batch_paths, batch_results):
                f.write(f"{os.path.basename(path)}\t{result}\n")
    print(f"Processing completed. Results written to {args.output}.")


if __name__ == "__main__":
    main()