Skip to content

Commit

Permalink
Simplify detect faces code
Browse files Browse the repository at this point in the history
  • Loading branch information
hugs7 committed Aug 23, 2024
1 parent 683f4be commit 544df18
Showing 1 changed file with 13 additions and 14 deletions.
27 changes: 13 additions & 14 deletions eye_tracking/gaze/head_pose_estimation/face_landmark_estimator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from typing import List

import face_alignment
import face_alignment.detection.sfd
import mediapipe
import numpy as np
from omegaconf import DictConfig
Expand All @@ -19,21 +17,14 @@ def __init__(self, config: DictConfig):
)

def detect_faces(self, image: np.ndarray) -> List[Face]:
return self._detect_faces_mediapipe(image)

def detect_faces_raw(self, image: np.ndarray) -> List[np.ndarray]:
if self.mode == "mediapipe":
return self._detect_faces_mediapipe_raw(image)
else:
raise ValueError

def _detect_faces_mediapipe(self, image: np.ndarray) -> List[Face]:
"""
Calculated landmarks scaled to the image size with a bounding box
:param image: RGB image
:return: List of faces
"""

h, w = image.shape[:2]
faces_landmarks = self._detect_faces_mediapipe_raw(image)
faces_landmarks = self._detect_faces_raw(image)
detected = []
if faces_landmarks:
for face in faces_landmarks:
Expand All @@ -43,9 +34,17 @@ def _detect_faces_mediapipe(self, image: np.ndarray) -> List[Face]:
detected.append(Face(bbox, pts))
return detected

def _detect_faces_mediapipe_raw(self, image: np.ndarray) -> List[np.ndarray]:
def detect_faces_raw(self, image: np.ndarray) -> List[np.ndarray]:
if self.mode == "mediapipe":
return self._detect_faces_raw(image)
else:
raise ValueError

def _detect_faces_raw(self, image: np.ndarray) -> List[np.ndarray]:
"""
Returns landmarks as they come from the mediapipe model
Returns landmarks as they come from the mediapipe model (not scaled to the image size)
:param image: RGB image
:return: List of faces landmarks
"""
predictions = self.detector.process(self._get_bgr_frame(image))
faces_landmarks = []
Expand Down

0 comments on commit 544df18

Please sign in to comment.