recognize_att.py
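"""Face attribute recognition on images, videos, or a webcam stream.

Detects and aligns faces, then classifies each face's race, gender, and age
group with a FairFace-style attribute model and draws the results on the frame.
"""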
import torch
import argparse
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm
from pathlib import Path
from torchvision import transforms as T
from typing import Union
from easyface.attributes.models import *
from easyface.utils.visualize import show_image
from easyface.utils.io import WebcamStream, VideoReader, VideoWriter, FPS
from detect_align import FaceDetectAlign

class Inference:
    def __init__(self, model: str, checkpoint: str, det_model: str, det_checkpoint: str) -> None:
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.gender_labels = ['Male', 'Female']
        self.race_labels = ['White', 'Black', 'Latino Hispanic', 'East Asian', 'Southeast Asian', 'Indian', 'Middle Eastern']
        self.age_labels = ['0-2', '3-9', '10-19', '20-29', '30-39', '40-49', '50-59', '60-69', '70+']

        # The attribute model outputs one logit per label across all three tasks (7 + 2 + 9 = 18).
        self.model = eval(model)(len(self.gender_labels) + len(self.race_labels) + len(self.age_labels))
        self.model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
        self.model = self.model.to(self.device)
        self.model.eval()

        self.align = FaceDetectAlign(det_model, det_checkpoint)

        # Aligned faces arrive as uint8 tensors; rescale to [0, 1] and normalize with ImageNet statistics.
        self.preprocess = T.Compose([
            T.Resize((224, 224)),
            T.Lambda(lambda x: x / 255),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def visualize(self, image, dets, races, genders, ages):
        # Draw a red box around each face and a white label card (gender / race / age) below it.
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        boxes = dets[:, :4].astype(int)
        for box, race, gender, age in zip(boxes, races, genders, ages):
            cv2.rectangle(image, (box[0], box[1]), (box[2], box[3]), (0, 0, 255), 2)
            cv2.rectangle(image, (box[0], box[3] + 5), (box[2] + 20, box[3] + 50), (255, 255, 255), -1)
            cv2.putText(image, gender, (box[0], box[3] + 15), cv2.FONT_HERSHEY_DUPLEX, 0.4, (0, 0, 0), lineType=cv2.LINE_AA)
            cv2.putText(image, race, (box[0], box[3] + 30), cv2.FONT_HERSHEY_DUPLEX, 0.4, (0, 0, 0), lineType=cv2.LINE_AA)
            cv2.putText(image, age, (box[0], box[3] + 45), cv2.FONT_HERSHEY_DUPLEX, 0.4, (0, 0, 0), lineType=cv2.LINE_AA)
        return image

    def postprocess(self, preds: torch.Tensor):
        # Split the 18 outputs into per-task slices (0-6 race, 7-8 gender, 9-17 age)
        # and softmax each slice independently before taking the argmax.
        race_probs, gender_probs, age_probs = preds[:, :7].softmax(dim=1), preds[:, 7:9].softmax(dim=1), preds[:, 9:18].softmax(dim=1)
        race_preds = torch.argmax(race_probs, dim=1)
        gender_preds = torch.argmax(gender_probs, dim=1)
        age_preds = torch.argmax(age_probs, dim=1)
        return [self.race_labels[idx] for idx in race_preds], [self.gender_labels[idx] for idx in gender_preds], [self.age_labels[idx] for idx in age_preds]
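
    # A minimal sketch (not in the original script) of how per-task confidences could
    # also be surfaced, assuming the same 7/2/9 logit layout; `postprocess_with_conf`
    # is a hypothetical helper, not part of the repo's API:
    #
    #   def postprocess_with_conf(self, preds: torch.Tensor):
    #       race_probs = preds[:, :7].softmax(dim=1)
    #       conf, idx = race_probs.max(dim=1)    # per-face top race probability
    #       return [(self.race_labels[i], c.item()) for i, c in zip(idx, conf)]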

    def __call__(self, img_path: Union[str, np.ndarray]):
        faces, dets, image = self.align.detect_and_align_faces(img_path, (112, 112))
        if faces is None:
            # No face found: return the unannotated frame, converted back to BGR.
            return cv2.cvtColor(image[0], cv2.COLOR_RGB2BGR)

        # NHWC uint8 faces -> NCHW, normalized, on the inference device.
        pfaces = self.preprocess(faces.permute(0, 3, 1, 2)).to(self.device)

        with torch.inference_mode():
            preds = self.model(pfaces).detach().cpu()

        races, genders, ages = self.postprocess(preds)
        image = self.visualize(image[0], dets[0], races, genders, ages)
        return image
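
# A sketch of programmatic use, assuming the checkpoints have been downloaded to the
# working directory (the file names here are placeholders, not shipped with the repo):
#
#   infer = Inference('FairFace', 'res34_fairface.pth', 'RetinaFace', 'mobilenet0.25_Final.pth')
#   annotated = infer('photo.jpg')    # BGR np.ndarray with boxes and attribute labels drawn
#   cv2.imwrite('photo_annotated.jpg', annotated)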

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--source', type=str, default='assets/asian_american.jpg')
    parser.add_argument('--model', type=str, default='FairFace')
    parser.add_argument('--checkpoint', type=str, default='/home/sithu/checkpoints/facialattributes/fairface/res34_fairface.pth')
    parser.add_argument('--det_model', type=str, default='RetinaFace')
    parser.add_argument('--det_checkpoint', type=str, default='/home/sithu/checkpoints/FR/retinaface/mobilenet0.25_Final.pth')
    args = vars(parser.parse_args())

    source = args.pop('source')
    file_path = Path(source)
    inference = Inference(**args)

    if file_path.is_file():
        if file_path.suffix in ['.mp4', '.avi', '.m4v']:
            # Video input: annotate every frame and save <source-name>_out.mp4 alongside the source.
            reader = VideoReader(str(file_path))
            writer = VideoWriter(f"{file_path.with_suffix('')}_out.mp4", reader.fps)
            for frame in tqdm(reader):
                image = inference(frame)
                writer.update(image[:, :, ::-1])    # inference returns BGR; flip back to RGB for the writer
            writer.write()
        else:
            # Image input: show the annotated result (flip BGR back to RGB for PIL).
            image = inference(str(file_path))
            image = Image.fromarray(image[:, :, ::-1]).convert('RGB')
            image.show()
    elif str(file_path) == 'webcam':
        stream = WebcamStream(0)
        fps = FPS()
        for frame in stream:
            fps.start()
            frame = inference(frame)
            fps.stop()
            cv2.imshow('frame', frame)
            # imshow only renders once the event loop runs; waitKey also lets `q` quit the stream.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    else:
        raise FileNotFoundError(f"The following file does not exist: {str(file_path)}")
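
# Example invocations (checkpoint paths are placeholders for wherever the weights live):
#   python recognize_att.py --source assets/asian_american.jpg --checkpoint <path>/res34_fairface.pth
#   python recognize_att.py --source clip.mp4     # writes clip_out.mp4 next to the source
#   python recognize_att.py --source webcam       # live preview; press q to quit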