-
Notifications
You must be signed in to change notification settings - Fork 503
/
demo.py
134 lines (102 loc) · 4.67 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from pathlib import Path
import cv2
import dlib
import numpy as np
import argparse
from contextlib import contextmanager
from keras.utils.data_utils import get_file
from model import get_model
# Default weights, downloaded by main() when --weight_file is not supplied.
pretrained_model = (
    "https://github.com/yu4u/age-gender-estimation/releases/download/v0.5/"
    "age_only_resnet50_weights.061-3.300-4.410.hdf5"
)
# Expected hash of the downloaded weights; handed to get_file() as file_hash.
modhash = "306e44200d3f632a5dccac153c2966f2"
def get_args():
    """Parse the demo's command-line options.

    Returns:
        argparse.Namespace with the attributes ``model_name``,
        ``weight_file``, ``margin`` and ``image_dir``.
    """
    description = ("This script detects faces from web cam input, "
                   "and estimates age for the detected faces.")
    parser = argparse.ArgumentParser(
        description=description,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # (flag, type, default, help) for every supported option, in CLI order.
    option_specs = (
        ("--model_name", str, "ResNet50",
         "model name: 'ResNet50' or 'InceptionResNetV2'"),
        ("--weight_file", str, None,
         "path to weight file (e.g. age_only_weights.029-4.027-5.250.hdf5)"),
        ("--margin", float, 0.4,
         "margin around detected face for age-gender estimation"),
        ("--image_dir", str, None,
         "target image directory; if set, images in image_dir are used instead of webcam"),
    )
    for flag, value_type, default, help_text in option_specs:
        parser.add_argument(flag, type=value_type, default=default, help=help_text)
    return parser.parse_args()
def draw_label(image, point, label, font=cv2.FONT_HERSHEY_SIMPLEX,
               font_scale=1, thickness=2):
    """Draw ``label`` as white text on a filled (255, 0, 0) box, in place.

    With BGR images (as produced by OpenCV) the box color is blue.

    Args:
        image: image array to draw on; modified in place.
        point: (x, y) pixel position of the text's bottom-left corner.
        label: text to render.
        font, font_scale, thickness: forwarded to the OpenCV text functions.
    """
    (text_w, text_h), _baseline = cv2.getTextSize(label, font, font_scale, thickness)
    # Clamp the anchor so the box stays inside the image's top/left edges;
    # without this, labels for faces touching (or detected past) the top
    # border are clipped or drawn entirely off-image.
    x = max(int(point[0]), 0)
    y = max(int(point[1]), text_h)
    cv2.rectangle(image, (x, y - text_h), (x + text_w, y), (255, 0, 0), cv2.FILLED)
    cv2.putText(image, label, (x, y), font, font_scale, (255, 255, 255), thickness)
@contextmanager
def video_capture(*args, **kwargs):
    """Open a ``cv2.VideoCapture`` and always ``release()`` it afterwards.

    All positional and keyword arguments are forwarded unchanged to
    ``cv2.VideoCapture``.
    """
    capture = cv2.VideoCapture(*args, **kwargs)
    try:
        yield capture
    finally:
        # Executed on normal exit, on exceptions, and on generator close.
        capture.release()
def yield_images():
    """Endlessly yield frames from video device 0, requesting 640x480.

    Raises:
        RuntimeError: when a frame cannot be read from the device.
    """
    with video_capture(0) as camera:
        camera.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
        camera.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
        while True:
            grabbed, frame = camera.read()
            if grabbed:
                yield frame
            else:
                raise RuntimeError("Failed to capture image")
def yield_images_from_dir(image_dir, max_size=640):
    """Yield every decodable image in ``image_dir``, rescaled for display.

    Entries matching ``*.*`` (non-recursive) are visited in sorted order so
    runs are reproducible; ``Path.glob`` order alone depends on the
    filesystem. Entries OpenCV cannot decode are skipped. Each image is
    resized so its longer side is ``max_size`` pixels, keeping the aspect
    ratio.

    Args:
        image_dir: directory to scan (str or path-like).
        max_size: target length, in pixels, of each image's longer side.

    Yields:
        Resized color images as loaded by ``cv2.imread``.
    """
    for image_path in sorted(Path(image_dir).glob("*.*")):
        img = cv2.imread(str(image_path), cv2.IMREAD_COLOR)
        if img is None:
            # Not an image, or unreadable: skip it.
            continue
        h, w = img.shape[:2]
        r = max_size / max(w, h)
        # Clamp each side to at least 1 px: a very thin image would otherwise
        # truncate a side to 0, and cv2.resize rejects an empty target size.
        new_size = (max(1, int(w * r)), max(1, int(h * r)))
        yield cv2.resize(img, new_size)
def _crop_faces(img, detected, margin, img_size):
    """Return an (N, img_size, img_size, 3) array of margin-expanded face crops.

    Crops are taken straight from ``img``, so call this before drawing on it.
    """
    img_h, img_w = img.shape[:2]
    faces = np.empty((len(detected), img_size, img_size, 3))
    for i, d in enumerate(detected):
        x1, y1, x2, y2 = d.left(), d.top(), d.right() + 1, d.bottom() + 1
        w, h = d.width(), d.height()
        # Grow the box by ``margin`` times its size on each side, clipped to the image.
        xw1 = max(int(x1 - margin * w), 0)
        yw1 = max(int(y1 - margin * h), 0)
        xw2 = min(int(x2 + margin * w), img_w - 1)
        yw2 = min(int(y2 + margin * h), img_h - 1)
        faces[i] = cv2.resize(img[yw1:yw2 + 1, xw1:xw2 + 1, :], (img_size, img_size))
    return faces


def _predict_ages(model, faces):
    """Return one expected age per face from the model's per-age scores."""
    results = model.predict(faces)
    # Assumes ``results`` is (N, K) with column k scoring age k; the dot
    # product with 0..K-1 then gives the score-weighted (expected) age.
    ages = np.arange(results.shape[1])
    return results.dot(ages)


def main():
    """Run the age-estimation demo on webcam frames or on a directory of images.

    Faces are found with dlib's frontal face detector, cropped with a margin,
    scored by the Keras model, and annotated with a box and the predicted
    age. Press ESC to quit; in directory mode any other key shows the next
    image.
    """
    args = get_args()
    model_name = args.model_name
    weight_file = args.weight_file
    margin = args.margin
    image_dir = args.image_dir

    if not weight_file:
        weight_file = get_file("age_only_resnet50_weights.061-3.300-4.410.hdf5", pretrained_model,
                               cache_subdir="pretrained_models",
                               file_hash=modhash, cache_dir=Path(__file__).resolve().parent)

    # for face detection
    detector = dlib.get_frontal_face_detector()

    # load model and weights
    model = get_model(model_name=model_name)
    model.load_weights(weight_file)
    img_size = model.input.shape.as_list()[1]

    image_generator = yield_images_from_dir(image_dir) if image_dir else yield_images()

    try:
        for img in image_generator:
            input_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            # detect faces using dlib detector
            detected = detector(input_img, 1)

            if len(detected) > 0:
                # Crop every face BEFORE drawing anything: a box drawn on
                # ``img`` first would end up inside this face's margin crop
                # (and possibly other faces' crops) and pollute model input.
                faces = _crop_faces(img, detected, margin, img_size)

                # predict ages of the detected faces
                predicted_ages = _predict_ages(model, faces)

                # draw results
                for d, age in zip(detected, predicted_ages):
                    cv2.rectangle(img, (d.left(), d.top()), (d.right() + 1, d.bottom() + 1),
                                  (255, 0, 0), 2)
                    draw_label(img, (d.left(), d.top()), str(int(age)))

            cv2.imshow("result", img)
            # Directory mode blocks until a key press; webcam mode polls every 30 ms.
            key = cv2.waitKey(-1) if image_dir else cv2.waitKey(30)

            if key == 27:  # ESC
                break
    finally:
        # Close the generator explicitly so the webcam's context manager
        # releases the device now, then tear down the display window.
        image_generator.close()
        cv2.destroyAllWindows()
if __name__ == '__main__':
    # Launch the demo only when run as a script, not when imported.
    main()