-
Notifications
You must be signed in to change notification settings - Fork 191
/
Copy pathPrediction.py
106 lines (88 loc) · 3.45 KB
/
Prediction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import cv2
import itertools, os, time
import numpy as np
from Model import get_Model
from parameter import letters
import argparse
from keras import backend as K
# Switch the Keras backend to the test phase (0) before the model is built
# below, so any layers whose behaviour differs between training and
# inference (e.g. dropout, if the model has any) run in inference mode.
K.set_learning_phase(0)
# Region code (first character of a plate label) -> Korean region prefix.
# Each value ends with a space, which separates it from the digits that
# follow when label_to_hangul concatenates the pieces. The code 'Z'
# (no region) is not listed here; label_to_hangul maps it to ''.
Region = {
    "A": "서울 ",  # Seoul
    "B": "경기 ",  # Gyeonggi
    "C": "인천 ",  # Incheon
    "D": "강원 ",  # Gangwon
    "E": "충남 ",  # Chungnam
    "F": "대전 ",  # Daejeon
    "G": "충북 ",  # Chungbuk
    "H": "부산 ",  # Busan
    "I": "울산 ",  # Ulsan
    "J": "대구 ",  # Daegu
    "K": "경북 ",  # Gyeongbuk
    "L": "경남 ",  # Gyeongnam
    "M": "전남 ",  # Jeonnam
    "N": "광주 ",  # Gwangju
    "O": "전북 ",  # Jeonbuk
    "P": "제주 ",  # Jeju
}
# Two-letter romanized key -> Hangul syllable used in the plate label.
# The keys match the keystrokes for each syllable on a Korean 2-set
# (Dubeolsik) keyboard: first letter = initial consonant, second letter =
# vowel (k=ㅏ, j=ㅓ, h=ㅗ, n=ㅜ).
Hangul = {
    # ㅇ (d)
    "dk": "아", "dj": "어", "dh": "오", "dn": "우",
    # ㅂ (q)
    "qk": "바", "qj": "버", "qh": "보", "qn": "부",
    # ㄷ (e)
    "ek": "다", "ej": "더", "eh": "도", "en": "두",
    # ㄱ (r)
    "rk": "가", "rj": "거", "rh": "고", "rn": "구",
    # ㅈ (w)
    "wk": "자", "wj": "저", "wh": "조", "wn": "주",
    # ㅁ (a)
    "ak": "마", "aj": "머", "ah": "모", "an": "무",
    # ㄴ (s)
    "sk": "나", "sj": "너", "sh": "노", "sn": "누",
    # ㄹ (f)
    "fk": "라", "fj": "러", "fh": "로", "fn": "루",
    # ㅅ (t)
    "tk": "사", "tj": "서", "th": "소", "tn": "수",
    # ㅎ (g)
    "gj": "허",
}
def decode_label(out, charset=None, skip=2):
    """Greedy (best-path) CTC decoding of one network output into a string.

    Args:
        out: Network output indexed as ``out[0, t, c]`` (batch item 0, time
            step ``t``, class ``c``). The original comment gave its shape as
            (1, 32, 42); confirm against the model.
        charset: Sequence mapping a class index to a character. Defaults to
            the module-level ``letters``. Any index ``>= len(charset)`` is
            treated as the CTC blank and dropped.
        skip: Number of leading time steps to discard before decoding. The
            original code always discarded 2, presumably because the first
            RNN outputs are unreliable (TODO confirm).

    Returns:
        The decoded label string ('' if nothing survives decoding).
    """
    if charset is None:
        charset = letters
    # Most likely class at each remaining time step.
    best_path = np.argmax(out[0, skip:], axis=1)
    # Collapse runs of the same class (repeats separated by a blank survive)...
    collapsed = [k for k, _ in itertools.groupby(best_path)]
    # ...then drop blanks / out-of-range indices and build the string once.
    return ''.join(charset[i] for i in collapsed if i < len(charset))
def label_to_hangul(label):  # eng -> hangul
    """Convert a romanized plate label into its Korean display form.

    The label is split positionally as
    ``<region code: 1 char><2 chars><hangul key: 2 chars><rest>``.
    The region code is looked up in ``Region`` ('Z' maps to no prefix), and
    the two-letter key is looked up in ``Hangul``. Codes that are not in the
    tables are kept unchanged, so a mis-predicted label is still printable.

    Args:
        label: Romanized label string, e.g. a file name stem or decoder output.

    Returns:
        The converted string. An empty or short label is handled without
        raising (an empty label yields '').
    """
    # Slice instead of index so an empty decoder output does not raise IndexError.
    region = label[:1]
    two_num = label[1:3]
    hangul = label[3:5]
    four_num = label[5:]
    region = '' if region == 'Z' else Region.get(region, region)
    hangul = Hangul.get(hangul, hangul)
    return region + two_num + hangul + four_num
parser = argparse.ArgumentParser()
parser.add_argument("-w", "--weight", help="weight file directory",
                    type=str, default="Final_weight.hdf5")
parser.add_argument("-t", "--test_img", help="Test image directory",
                    type=str, default="./DB/test/")
args = parser.parse_args()

# Get CRNN model
model = get_Model(training=False)
try:
    model.load_weights(args.weight)
    print("...Previous weight data...")
except OSError as err:
    # Only a missing or unreadable file means "no weights"; other errors
    # (e.g. an architecture mismatch) propagate with their real message.
    raise Exception("No weight file!") from err

test_dir = args.test_img
# Sorted so the per-image output order is reproducible across runs.
test_imgs = sorted(os.listdir(test_dir))
total = 0
acc = 0
letter_total = 0
letter_acc = 0
start = time.time()
for test_img in test_imgs:
    # os.path.join works whether or not test_dir ends with a separator.
    img = cv2.imread(os.path.join(test_dir, test_img), cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread returns None (rather than raising) for files it cannot
        # decode, e.g. non-image files or subdirectories.
        print("Skipping unreadable file: %s" % test_img)
        continue
    # The ground-truth label is the file name without its extension.
    true_text = os.path.splitext(test_img)[0]

    # Same preprocessing as before: resize, scale [0, 255] -> [-1, 1],
    # transpose, then add channel and batch axes.
    img_pred = img.astype(np.float32)
    img_pred = cv2.resize(img_pred, (128, 64))
    img_pred = (img_pred / 255.0) * 2.0 - 1.0
    img_pred = img_pred.T
    img_pred = np.expand_dims(img_pred, axis=-1)
    img_pred = np.expand_dims(img_pred, axis=0)

    net_out_value = model.predict(img_pred)
    pred_texts = decode_label(net_out_value)

    # Position-wise character matches over the overlapping prefix; the
    # denominator uses the longer string so missing/extra chars count as errors.
    letter_acc += sum(p == t for p, t in zip(pred_texts, true_text))
    letter_total += max(len(pred_texts), len(true_text))
    if pred_texts == true_text:
        acc += 1
    total += 1
    print('Predicted: %s / True: %s' % (label_to_hangul(pred_texts), label_to_hangul(true_text)))

end = time.time()
total_time = (end - start)
if total == 0:
    # Avoid ZeroDivisionError when the directory has no readable images.
    print("No readable test images found in %s" % test_dir)
else:
    print("Time : ", total_time / total)
    print("ACC : ", acc / total)
    print("letter ACC : ", letter_acc / letter_total if letter_total else 0.0)