diff --git a/eye_tracking/constants.py b/eye_tracking/constants.py
deleted file mode 100644
index 29d2f1d4..00000000
--- a/eye_tracking/constants.py
+++ /dev/null
@@ -1,30 +0,0 @@
-"""
-Constants file
-Author: Hugo Burton
-Last Updated: 21/08/2024
-"""
-
-MAPPINGS_FOLDER = "mappings"
-
-
-# Landmark Mapping
-
-EXPECTED_EYE_POINT_KEYS = ["centre", "right", "top", "left", "bottom"]
-
-# Window
-
-EYE_TRACKING_WINDOW_NAME = "Eye Tracking"
-
-# Keys
-
-ENTER_KEY = "\r"
-BACKSPACE_KEY = "\x08"
-ESCAPE_KEY = "\x1b"
-
-
-# Buttons
-
-BUTTON_WIDTH = 300
-BUTTON_HEIGHT = 40
-BUTTON_PADDING_X = 10
-BUTTON_PADDING_Y = 5
diff --git a/eye_tracking/gaze/__main__.py b/eye_tracking/gaze/__main__.py
deleted file mode 100644
index a952ff50..00000000
--- a/eye_tracking/gaze/__main__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-import main
-
-main.main()
diff --git a/eye_tracking/gaze/common/face_model.py b/eye_tracking/gaze/common/face_model.py
index 9a7484ef..ff041e9b 100644
--- a/eye_tracking/gaze/common/face_model.py
+++ b/eye_tracking/gaze/common/face_model.py
@@ -84,20 +84,14 @@ def compute_3d_pose(self, face: Face) -> None:
         rot = face.head_pose_rot.as_matrix()  # Has units of radians
         face.model3d = self.LANDMARKS @ rot.T + face.head_position  # This is the 3D model of the face in world coordinates

-    def compute_face_eye_centers(self, face: Face, mode: str) -> None:
+    def compute_face_eye_centers(self, face: Face) -> None:
         """Compute the centers of the face and eyes.

-        In the case of MPIIFaceGaze, the face center is defined as the
+        The face center is defined as the
         average coordinates of the six points at the corners of both
-        eyes and the mouth. In the case of ETH-XGaze, it's defined as
-        the average coordinates of the six points at the corners of both
-        eyes and the nose. The eye centers are defined as the average
-        coordinates of the corners of each eye.
+        eyes and the mouth.
""" - if mode == "ETH-XGaze": - face.center = face.model3d[np.concatenate([self.REYE_INDICES, self.LEYE_INDICES, self.NOSE_INDICES])].mean(axis=0) - else: - face.center = face.model3d[np.concatenate([self.REYE_INDICES, self.LEYE_INDICES, self.MOUTH_INDICES])].mean(axis=0) + face.center = face.model3d[np.concatenate([self.REYE_INDICES, self.LEYE_INDICES, self.MOUTH_INDICES])].mean(axis=0) # Face centre is world coordinates in 3D with units metres relative to the camera face.reye.center = face.model3d[self.REYE_INDICES].mean(axis=0) diff --git a/eye_tracking/gaze/gaze_detector.py b/eye_tracking/gaze/gaze_detector.py index b6f4ca17..9cad6ea7 100644 --- a/eye_tracking/gaze/gaze_detector.py +++ b/eye_tracking/gaze/gaze_detector.py @@ -271,18 +271,17 @@ def _draw_face_template_model(self, face: Face) -> None: def _display_normalized_image(self, face: Face) -> None: if not self.config.demo.display_on_screen: return + if not self.show_normalized_image: return - if self.config.mode == "MPIIGaze": - reye = face.reye.normalized_image - leye = face.leye.normalized_image - normalized = np.hstack([reye, leye]) - elif self.config.mode in ["MPIIFaceGaze", "ETH-XGaze"]: - normalized = face.normalized_image - else: - raise ValueError + + reye = face.reye.normalized_image + leye = face.leye.normalized_image + normalized = np.hstack([reye, leye]) + if self.config.demo.use_camera: normalized = utils.flip_image(normalized) + cv2.imshow("normalized", normalized) def _draw_gaze_vector(self, face: Face) -> None: @@ -291,37 +290,30 @@ def _draw_gaze_vector(self, face: Face) -> None: length = self.config.demo.gaze_visualization_length - if self.config.mode == "MPIIGaze": - for key in [FacePartsName.REYE, FacePartsName.LEYE]: - eye = getattr(face, key.name.lower()) - end_point = eye.center + length * eye.gaze_vector # eye.gaze_vector.z is always -1. We scale by length - self.visualizer.draw_3d_line(eye.center, end_point) + for key in [FacePartsName.REYE, FacePartsName.LEYE]: + eye = getattr(face, key.name.lower()) + end_point = eye.center + length * eye.gaze_vector # eye.gaze_vector.z is always -1. 
We scale by length + self.visualizer.draw_3d_line(eye.center, end_point) - pitch, yaw = np.rad2deg(eye.vector_to_angle(eye.gaze_vector)) - logger.info(f"[{key.name.lower()}] pitch: {pitch:.2f}, yaw: {yaw:.2f}") + pitch, yaw = np.rad2deg(eye.vector_to_angle(eye.gaze_vector)) + logger.info(f"[{key.name.lower()}] pitch: {pitch:.2f}, yaw: {yaw:.2f}") - average_eye_distance = (face.reye.distance + face.leye.distance) / 2 - average_eye_center = (face.reye.center + face.leye.center) / 2 - average_gaze_vector = (face.reye.gaze_vector + face.leye.gaze_vector) / 2 + average_eye_distance = (face.reye.distance + face.leye.distance) / 2 + average_eye_center = (face.reye.center + face.leye.center) / 2 + average_gaze_vector = (face.reye.gaze_vector + face.leye.gaze_vector) / 2 - end_point = average_eye_center + length * average_gaze_vector - self.visualizer.draw_3d_line(average_eye_center, end_point) + end_point = average_eye_center + length * average_gaze_vector + self.visualizer.draw_3d_line(average_eye_center, end_point) - # Draw the point on the screen the user is looking at - point_on_screen = average_eye_center + (average_eye_distance * 0.9) * average_gaze_vector - point_on_screen[1] *= 0.15 # Scale y-coordinate + # Draw the point on the screen the user is looking at + point_on_screen = average_eye_center + (average_eye_distance * 0.9) * average_gaze_vector + point_on_screen[1] *= 0.15 # Scale y-coordinate - # Update buffer and calculate smoothed point - self.point_buffer.append(point_on_screen) - if len(self.point_buffer) > self.point_on_screen_smoothing_factor: - self.point_buffer.pop(0) # Remove oldest point + # Update buffer and calculate smoothed point + self.point_buffer.append(point_on_screen) + if len(self.point_buffer) > self.point_on_screen_smoothing_factor: + self.point_buffer.pop(0) # Remove oldest point - smoothed_point = np.mean(self.point_buffer, axis=0) + smoothed_point = np.mean(self.point_buffer, axis=0) - self.visualizer.draw_3d_points(np.array([smoothed_point]), color=(0, 255, 0), size=10, clamp_to_screen=True) - elif self.config.mode in ["MPIIFaceGaze", "ETH-XGaze"]: - self.visualizer.draw_3d_line(face.center, face.center + length * face.gaze_vector) - pitch, yaw = np.rad2deg(face.vector_to_angle(face.gaze_vector)) - logger.info(f"[face] pitch: {pitch:.2f}, yaw: {yaw:.2f}") - else: - raise ValueError + self.visualizer.draw_3d_points(np.array([smoothed_point]), color=(0, 255, 0), size=10, clamp_to_screen=True) diff --git a/eye_tracking/gaze/gaze_estimator.py b/eye_tracking/gaze/gaze_estimator.py index 722bfe41..61c76126 100644 --- a/eye_tracking/gaze/gaze_estimator.py +++ b/eye_tracking/gaze/gaze_estimator.py @@ -36,7 +36,7 @@ def __init__(self, config: DictConfig): self.camera, self._normalized_camera, self._config.gaze_estimator.normalized_camera_distance ) self._gaze_estimation_model = self._load_model() - self._transform = create_transform(config) + self._transform = create_transform() def _load_model(self) -> torch.nn.Module: model = create_model(self._config) @@ -55,36 +55,32 @@ def detect_faces_raw(self, image: np.ndarray) -> List[np.ndarray]: def estimate_gaze(self, image: np.ndarray, face: Face) -> None: self._face_model3d.estimate_head_pose(face, self.camera) self._face_model3d.compute_3d_pose(face) - self._face_model3d.compute_face_eye_centers(face, self._config.mode) - - if self._config.mode == "MPIIGaze": - for key in self.EYE_KEYS: - eye = getattr(face, key.name.lower()) - self._head_pose_normalizer.normalize(image, eye) - self._run_mpiigaze_model(face) - elif 
self._config.mode == "MPIIFaceGaze": - self._head_pose_normalizer.normalize(image, face) - self._run_mpiifacegaze_model(face) - elif self._config.mode == "ETH-XGaze": - self._head_pose_normalizer.normalize(image, face) - self._run_ethxgaze_model(face) - else: - raise ValueError + self._face_model3d.compute_face_eye_centers(face) + + for key in self.EYE_KEYS: + eye = getattr(face, key.name.lower()) + self._head_pose_normalizer.normalize(image, eye) + + self._run_mpiigaze_model(face) @torch.no_grad() def _run_mpiigaze_model(self, face: Face) -> None: images = [] head_poses = [] + for key in self.EYE_KEYS: eye = getattr(face, key.name.lower()) image = eye.normalized_image normalized_head_pose = eye.normalized_head_rot2d + if key == FacePartsName.REYE: image = utils.flip_image(image).copy() normalized_head_pose *= np.array([1, -1]) + image = self._transform(image) images.append(image) head_poses.append(normalized_head_pose) + images = torch.stack(images) head_poses = np.array(head_poses).astype(np.float32) head_poses = torch.from_numpy(head_poses) @@ -98,33 +94,9 @@ def _run_mpiigaze_model(self, face: Face) -> None: for i, key in enumerate(self.EYE_KEYS): eye = getattr(face, key.name.lower()) eye.normalized_gaze_angles = predictions[i] + if key == FacePartsName.REYE: eye.normalized_gaze_angles *= np.array([1, -1]) + eye.angle_to_vector() eye.denormalize_gaze_vector() - - @torch.no_grad() - def _run_mpiifacegaze_model(self, face: Face) -> None: - image = self._transform(face.normalized_image).unsqueeze(0) - - device = torch.device(self._config.device) - image = image.to(device) - prediction = self._gaze_estimation_model(image) - prediction = prediction.cpu().numpy() - - face.normalized_gaze_angles = prediction[0] - face.angle_to_vector() - face.denormalize_gaze_vector() - - @torch.no_grad() - def _run_ethxgaze_model(self, face: Face) -> None: - image = self._transform(face.normalized_image).unsqueeze(0) - - device = torch.device(self._config.device) - image = image.to(device) - prediction = self._gaze_estimation_model(image) - prediction = prediction.cpu().numpy() - - face.normalized_gaze_angles = prediction[0] - face.angle_to_vector() - face.denormalize_gaze_vector() diff --git a/eye_tracking/gaze/head_pose_estimation/face_landmark_estimator.py b/eye_tracking/gaze/head_pose_estimation/face_landmark_estimator.py index 5df85fea..a6a56608 100644 --- a/eye_tracking/gaze/head_pose_estimation/face_landmark_estimator.py +++ b/eye_tracking/gaze/head_pose_estimation/face_landmark_estimator.py @@ -1,7 +1,5 @@ from typing import List -import face_alignment -import face_alignment.detection.sfd import mediapipe import numpy as np from omegaconf import DictConfig @@ -12,81 +10,21 @@ class LandmarkEstimator: def __init__(self, config: DictConfig): self.mode = config.face_detector.mode - if self.mode == "dlib": - raise NotImplementedError("Dlib is not supported for landmark estimation") - elif self.mode == "face_alignment_dlib": - raise NotImplementedError("Dlib is not supported for landmark estimation") - elif self.mode == "face_alignment_sfd": - self.detector = face_alignment.detection.sfd.sfd_detector.SFDDetector(device=config.device) - self.predictor = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False, device=config.device) - elif self.mode == "mediapipe": - self.detector = mediapipe.solutions.face_mesh.FaceMesh( - max_num_faces=config.face_detector.mediapipe_max_num_faces, - static_image_mode=config.face_detector.mediapipe_static_image_mode, - refine_landmarks=True, # Adds 
eye pupil landmarks (468-477) - ) - else: - raise ValueError + self.detector = mediapipe.solutions.face_mesh.FaceMesh( + max_num_faces=config.face_detector.mediapipe_max_num_faces, + static_image_mode=config.face_detector.mediapipe_static_image_mode, + refine_landmarks=True, # Adds eye pupil landmarks (468-477) + ) def detect_faces(self, image: np.ndarray) -> List[Face]: - if self.mode == "dlib": - return self._detect_faces_dlib(image) - elif self.mode == "face_alignment_dlib": - return self._detect_faces_face_alignment_dlib(image) - elif self.mode == "face_alignment_sfd": - return self._detect_faces_face_alignment_sfd(image) - elif self.mode == "mediapipe": - return self._detect_faces_mediapipe(image) - else: - raise ValueError - - def detect_faces_raw(self, image: np.ndarray) -> List[np.ndarray]: - if self.mode == "mediapipe": - return self._detect_faces_mediapipe_raw(image) - else: - raise ValueError - - def _detect_faces_dlib(self, image: np.ndarray) -> List[Face]: - bboxes = self.detector(self._get_bgr_frame(image), 0) - detected = [] - for bbox in bboxes: - predictions = self.predictor(self._get_bgr_frame(image), bbox) - landmarks = np.array([(pt.x, pt.y) for pt in predictions.parts()], dtype=np.float32) - bbox = np.array([[bbox.left(), bbox.top()], [bbox.right(), bbox.bottom()]], dtype=np.float32) - detected.append(Face(bbox, landmarks)) - return detected - - def _detect_faces_face_alignment_dlib(self, image: np.ndarray) -> List[Face]: - bboxes = self.detector(self._get_bgr_frame(image), 0) - bboxes = [[bbox.left(), bbox.top(), bbox.right(), bbox.bottom()] for bbox in bboxes] - predictions = self.predictor.get_landmarks(self._get_bgr_frame(image), detected_faces=bboxes) - if predictions is None: - predictions = [] - detected = [] - for bbox, landmarks in zip(bboxes, predictions): - bbox = np.array(bbox, dtype=np.float32).reshape(2, 2) - detected.append(Face(bbox, landmarks)) - return detected - - def _detect_faces_face_alignment_sfd(self, image: np.ndarray) -> List[Face]: - bboxes = self.detector.detect_from_image(self._get_bgr_frame(image).copy()) - bboxes = [bbox[:4] for bbox in bboxes] - predictions = self.predictor.get_landmarks(self._get_bgr_frame(image), detected_faces=bboxes) - if predictions is None: - predictions = [] - detected = [] - for bbox, landmarks in zip(bboxes, predictions): - bbox = np.array(bbox, dtype=np.float32).reshape(2, 2) - detected.append(Face(bbox, landmarks)) - return detected - - def _detect_faces_mediapipe(self, image: np.ndarray) -> List[Face]: """ Calculated landmarks scaled to the image size with a bounding box + :param image: RGB image + :return: List of faces """ h, w = image.shape[:2] - faces_landmarks = self._detect_faces_mediapipe_raw(image) + faces_landmarks = self._detect_faces_raw(image) detected = [] if faces_landmarks: for face in faces_landmarks: @@ -96,9 +34,17 @@ def _detect_faces_mediapipe(self, image: np.ndarray) -> List[Face]: detected.append(Face(bbox, pts)) return detected - def _detect_faces_mediapipe_raw(self, image: np.ndarray) -> List[np.ndarray]: + def detect_faces_raw(self, image: np.ndarray) -> List[np.ndarray]: + if self.mode == "mediapipe": + return self._detect_faces_raw(image) + else: + raise ValueError + + def _detect_faces_raw(self, image: np.ndarray) -> List[np.ndarray]: """ - Returns landmarks as they come from the mediapipe model + Returns landmarks as they come from the mediapipe model (not scaled to the image size) + :param image: RGB image + :return: List of faces landmarks """ predictions = 
self.detector.process(self._get_bgr_frame(image)) faces_landmarks = [] diff --git a/eye_tracking/gaze/main.py b/eye_tracking/gaze/main.py index a4be48d6..5c7d7ffa 100644 --- a/eye_tracking/gaze/main.py +++ b/eye_tracking/gaze/main.py @@ -6,17 +6,6 @@ import torch from omegaconf import DictConfig, OmegaConf -from gaze.gaze_detector import GazeDetector -from gaze.utils import ( - check_path_all, - download_dlib_pretrained_model, - download_ethxgaze_model, - download_mpiifacegaze_model, - download_mpiigaze_model, - expanduser_all, - generate_dummy_camera_params, -) - logger = logging.getLogger(__name__) @@ -27,22 +16,7 @@ def parse_args() -> argparse.Namespace: type=str, help="Config file. When using a config file, all the other " "commandline arguments are ignored. " - "See https://github.com/hysts/pytorch_mpiigaze_demo/ptgaze/data/configs/eth-xgaze.yaml", - ) - parser.add_argument( - "--mode", - type=str, - choices=["mpiigaze", "mpiifacegaze", "eth-xgaze"], - help="With 'mpiigaze', MPIIGaze model will be used. " - "With 'mpiifacegaze', MPIIFaceGaze model will be used. " - "With 'eth-xgaze', ETH-XGaze model will be used.", - ) - parser.add_argument( - "--face-detector", - type=str, - default="mediapipe", - choices=["dlib", "face_alignment_dlib", "face_alignment_sfd", "mediapipe"], - help="The method used to detect faces and find face landmarks " "(default: 'mediapipe')", + "See https://github.com/hysts/pytorch_mpiigaze_demo/ptgaze/data/configs/mpiigaze.yaml", ) parser.add_argument("--device", type=str, choices=["cpu", "cuda"], help="Device used for model inference.") parser.add_argument("--image", type=str, help="Path to an input image file.") @@ -63,14 +37,7 @@ def parse_args() -> argparse.Namespace: def load_mode_config(args: argparse.Namespace) -> DictConfig: package_root = pathlib.Path(__file__).parent.resolve() - if args.mode == "mpiigaze": - path = package_root / "data/configs/mpiigaze.yaml" - elif args.mode == "mpiifacegaze": - path = package_root / "data/configs/mpiifacegaze.yaml" - elif args.mode == "eth-xgaze": - path = package_root / "data/configs/eth-xgaze.yaml" - else: - raise ValueError + path = package_root / "data/configs/mpiigaze.yaml" config = OmegaConf.load(path) config.PACKAGE_ROOT = package_root.as_posix() @@ -103,37 +70,3 @@ def load_mode_config(args: argparse.Namespace) -> DictConfig: config.demo.output_dir = "outputs" return config - - -def main(): - args = parse_args() - if args.debug: - logging.getLogger("ptgaze").setLevel(logging.DEBUG) - - if args.config: - config = OmegaConf.load(args.config) - elif args.mode: - config = load_mode_config(args) - else: - raise ValueError("You need to specify one of '--mode' or '--config'.") - expanduser_all(config) - if config.gaze_estimator.use_dummy_camera_params: - generate_dummy_camera_params(config) - - OmegaConf.set_readonly(config, True) - logger.info(OmegaConf.to_yaml(config)) - - if config.face_detector.mode == "dlib": - download_dlib_pretrained_model() - if args.mode: - if config.mode == "MPIIGaze": - download_mpiigaze_model() - elif config.mode == "MPIIFaceGaze": - download_mpiifacegaze_model() - elif config.mode == "ETH-XGaze": - download_ethxgaze_model() - - check_path_all(config) - - gaze_detector = GazeDetector(config) - gaze_detector.run() diff --git a/eye_tracking/gaze/models/__init__.py b/eye_tracking/gaze/models/__init__.py index 5e233734..3a7a0d38 100644 --- a/eye_tracking/gaze/models/__init__.py +++ b/eye_tracking/gaze/models/__init__.py @@ -1,20 +1,14 @@ import importlib -import timm import torch from 
omegaconf import DictConfig def create_model(config: DictConfig) -> torch.nn.Module: mode = config.mode - if mode in ['MPIIGaze', 'MPIIFaceGaze']: - module = importlib.import_module( - f'ptgaze.models.{mode.lower()}.{config.model.name}') - model = module.Model(config) - elif mode == 'ETH-XGaze': - model = timm.create_model(config.model.name, num_classes=2) - else: - raise ValueError + module = importlib.import_module(f"ptgaze.models.{mode.lower()}.{config.model.name}") + model = module.Model(config) + device = torch.device(config.device) model.to(device) return model diff --git a/eye_tracking/gaze/models/mpiifacegaze/__init__.py b/eye_tracking/gaze/models/mpiifacegaze/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/eye_tracking/gaze/models/mpiifacegaze/backbones/__init__.py b/eye_tracking/gaze/models/mpiifacegaze/backbones/__init__.py deleted file mode 100644 index d4845df0..00000000 --- a/eye_tracking/gaze/models/mpiifacegaze/backbones/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -import importlib - -import torch.nn as nn -from omegaconf import DictConfig - - -def create_backbone(config: DictConfig) -> nn.Module: - backbone_name = config.model.backbone.name - module = importlib.import_module( - f'ptgaze.models.mpiifacegaze.backbones.{backbone_name}') - return module.Model(config) diff --git a/eye_tracking/gaze/models/mpiifacegaze/backbones/resnet_simple.py b/eye_tracking/gaze/models/mpiifacegaze/backbones/resnet_simple.py deleted file mode 100644 index bc732b00..00000000 --- a/eye_tracking/gaze/models/mpiifacegaze/backbones/resnet_simple.py +++ /dev/null @@ -1,48 +0,0 @@ -import torch -import torchvision -from omegaconf import DictConfig - - -class Model(torchvision.models.ResNet): - def __init__(self, config: DictConfig): - block_name = config.model.backbone.resnet_block - if block_name == 'basic': - block = torchvision.models.resnet.BasicBlock - elif block_name == 'bottleneck': - block = torchvision.models.resnet.Bottleneck - else: - raise ValueError - layers = list(config.model.backbone.resnet_layers) + [1] - super().__init__(block, layers) - del self.layer4 - del self.avgpool - del self.fc - - pretrained_name = config.model.backbone.pretrained - if pretrained_name: - state_dict = torch.hub.load_state_dict_from_url( - torchvision.models.resnet.model_urls[pretrained_name]) - self.load_state_dict(state_dict, strict=False) - # While the pretrained models of torchvision are trained - # using images with RGB channel order, in this repository - # images are treated as BGR channel order. - # Therefore, reverse the channel order of the first - # convolutional layer. 
-            module = self.conv1
-            module.weight.data = module.weight.data[:, [2, 1, 0]]
-
-        with torch.no_grad():
-            data = torch.zeros((1, 3, 224, 224), dtype=torch.float32)
-            features = self.forward(data)
-            self.n_features = features.shape[1]
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.conv1(x)
-        x = self.bn1(x)
-        x = self.relu(x)
-        x = self.maxpool(x)
-
-        x = self.layer1(x)
-        x = self.layer2(x)
-        x = self.layer3(x)
-        return x
diff --git a/eye_tracking/gaze/models/mpiifacegaze/resnet_simple.py b/eye_tracking/gaze/models/mpiifacegaze/resnet_simple.py
deleted file mode 100644
index 5302008e..00000000
--- a/eye_tracking/gaze/models/mpiifacegaze/resnet_simple.py
+++ /dev/null
@@ -1,31 +0,0 @@
-from typing import Optional, Tuple, Union
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from omegaconf import DictConfig
-
-from .backbones import create_backbone
-
-
-class Model(nn.Module):
-    def __init__(self, config: DictConfig):
-        super().__init__()
-        self.feature_extractor = create_backbone(config)
-        n_channels = self.feature_extractor.n_features
-
-        self.conv = nn.Conv2d(n_channels,
-                              1,
-                              kernel_size=1,
-                              stride=1,
-                              padding=0)
-        # This model assumes the input image size is 224x224.
-        self.fc = nn.Linear(n_channels * 14**2, 2)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.feature_extractor(x)
-        y = F.relu(self.conv(x))
-        x = x * y
-        x = x.view(x.size(0), -1)
-        x = self.fc(x)
-        return x
diff --git a/eye_tracking/gaze/transforms.py b/eye_tracking/gaze/transforms.py
index d7bc8f9e..98728eb4 100644
--- a/eye_tracking/gaze/transforms.py
+++ b/eye_tracking/gaze/transforms.py
@@ -1,39 +1,7 @@
 from typing import Any

-import cv2
 import torchvision.transforms as T
-from omegaconf import DictConfig


-def create_transform(config: DictConfig) -> Any:
-    if config.mode == 'MPIIGaze':
-        return T.ToTensor()
-    elif config.mode == 'MPIIFaceGaze':
-        return _create_mpiifacegaze_transform(config)
-    elif config.mode == 'ETH-XGaze':
-        return _create_ethxgaze_transform(config)
-    else:
-        raise ValueError
-
-
-def _create_mpiifacegaze_transform(config: DictConfig) -> Any:
-    size = tuple(config.gaze_estimator.image_size)
-    transform = T.Compose([
-        T.Lambda(lambda x: cv2.resize(x, size)),
-        T.ToTensor(),
-        T.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224,
-                                                     0.229]),  # BGR
-    ])
-    return transform
-
-
-def _create_ethxgaze_transform(config: DictConfig) -> Any:
-    size = tuple(config.gaze_estimator.image_size)
-    transform = T.Compose([
-        T.Lambda(lambda x: cv2.resize(x, size)),
-        T.Lambda(lambda x: x[:, :, ::-1].copy()),  # BGR -> RGB
-        T.ToTensor(),
-        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224,
-                                                     0.225]),  # RGB
-    ])
-    return transform
+def create_transform() -> Any:
+    return T.ToTensor()
diff --git a/eye_tracking/gaze/utils.py b/eye_tracking/gaze/utils.py
index 427f88b3..6d7ebbed 100644
--- a/eye_tracking/gaze/utils.py
+++ b/eye_tracking/gaze/utils.py
@@ -38,27 +38,6 @@ def get_ptgaze_model_dir() -> pathlib.Path:
     return model_dir


-def download_dlib_pretrained_model() -> None:
-    logger.debug("Called download_dlib_pretrained_model()")
-
-    model_dir = get_ptgaze_model_dir()
-    dlib_model_dir = model_dir / "dlib"
-    dlib_model_dir.mkdir(exist_ok=True, parents=True)
-    dlib_model_path = dlib_model_dir / "shape_predictor_68_face_landmarks.dat"
-    logger.debug(f"Update config.face_detector.dlib_model_path to {dlib_model_path.as_posix()}")
-
-    if dlib_model_path.exists():
-        logger.debug(f"dlib pretrained model {dlib_model_path.as_posix()} already exists.")
-        return
-
-    logger.debug("Download the dlib pretrained model")
-    bz2_path = dlib_model_path.as_posix() + ".bz2"
-    torch.hub.download_url_to_file("http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2", bz2_path)
-    with bz2.BZ2File(bz2_path, "rb") as f_in, open(dlib_model_path, "wb") as f_out:
-        data = f_in.read()
-        f_out.write(data)
-
-
 def download_mpiigaze_model() -> pathlib.Path:
     logger.debug("Called _download_mpiigaze_model()")
     model_dir = get_ptgaze_model_dir()
@@ -75,38 +54,6 @@ def download_mpiigaze_model() -> pathlib.Path:
     return output_path


-def download_mpiifacegaze_model() -> pathlib.Path:
-    logger.debug("Called _download_mpiifacegaze_model()")
-    model_dir = get_ptgaze_model_dir()
-    output_dir = model_dir / "models/"
-    output_dir.mkdir(exist_ok=True, parents=True)
-    output_path = output_dir / "mpiifacegaze_resnet_simple.pth"
-    if not output_path.exists():
-        logger.debug("Download the pretrained model")
-        torch.hub.download_url_to_file(
-            "https://github.com/hysts/pytorch_mpiigaze_demo/releases/download/v0.1.0/mpiifacegaze_resnet_simple.pth", output_path.as_posix()
-        )
-    else:
-        logger.debug(f"The pretrained model {output_path} already exists.")
-    return output_path
-
-
-def download_ethxgaze_model() -> pathlib.Path:
-    logger.debug("Called _download_ethxgaze_model()")
-    model_dir = get_ptgaze_model_dir()
-    output_dir = model_dir / "models/"
-    output_dir.mkdir(exist_ok=True, parents=True)
-    output_path = output_dir / "eth-xgaze_resnet18.pth"
-    if not output_path.exists():
-        logger.debug("Download the pretrained model")
-        torch.hub.download_url_to_file(
-            "https://github.com/hysts/pytorch_mpiigaze_demo/releases/download/v0.2.2/eth-xgaze_resnet18.pth", output_path.as_posix()
-        )
-    else:
-        logger.debug(f"The pretrained model {output_path} already exists.")
-    return output_path
-
-
 def generate_dummy_camera_params(config: DictConfig) -> None:
     logger.debug("Called _generate_dummy_camera_params()")
     if config.demo.image_path:
@@ -147,8 +94,6 @@ def _expanduser(path: str) -> str:


 def expanduser_all(config: DictConfig) -> None:
-    if hasattr(config.face_detector, "dlib_model_path"):
-        config.face_detector.dlib_model_path = _expanduser(config.face_detector.dlib_model_path)
     config.gaze_estimator.checkpoint = _expanduser(config.gaze_estimator.checkpoint)
     config.gaze_estimator.camera_params = _expanduser(config.gaze_estimator.camera_params)
     config.gaze_estimator.normalized_camera_params = _expanduser(config.gaze_estimator.normalized_camera_params)
@@ -170,8 +115,6 @@ def _check_path(config: DictConfig, key: str) -> None:


 def check_path_all(config: DictConfig) -> None:
-    if config.face_detector.mode == "dlib":
-        _check_path(config, "face_detector.dlib_model_path")
     _check_path(config, "gaze_estimator.checkpoint")
     _check_path(config, "gaze_estimator.camera_params")
     _check_path(config, "gaze_estimator.normalized_camera_params")
diff --git a/eye_tracking/init.py b/eye_tracking/init.py
index fc8ccb10..fc028ba4 100644
--- a/eye_tracking/init.py
+++ b/eye_tracking/init.py
@@ -9,9 +9,6 @@ import logging

 from gaze.utils import (
     check_path_all,
-    download_dlib_pretrained_model,
-    download_ethxgaze_model,
-    download_mpiifacegaze_model,
     download_mpiigaze_model,
     expanduser_all,
     generate_dummy_camera_params,
@@ -22,21 +19,14 @@
 logger = logging.getLogger(__name__)


-def init_ptgaze_config(mode: str) -> DictConfig:
+def init_ptgaze_config() -> DictConfig:
     """
     Custom config initialiser for ptgaze
     """
     package_root = pathlib.Path(__file__).parent.resolve()
     ptgaze_package_root = package_root / "gaze"

-    if mode == "mpiigaze":
-        path = ptgaze_package_root / "data/configs/mpiigaze.yaml"
-    elif mode == "mpiifacegaze":
-        path = ptgaze_package_root / "data/configs/mpiifacegaze.yaml"
-    elif mode == "eth-xgaze":
-        path = ptgaze_package_root / "data/configs/eth-xgaze.yaml"
-    else:
-        raise ValueError(f"Incorrect mode selected: {mode}")
+    path = ptgaze_package_root / "data/configs/mpiigaze.yaml"

     logger.info(f"Loading config from {path}")
     config = OmegaConf.load(path)
@@ -52,7 +42,7 @@ def init_ptgaze() -> DictConfig:

     :return DictConfig: The ptgaze config
     """
-    config = init_ptgaze_config("mpiigaze")
+    config = init_ptgaze_config()
     expanduser_all(config)

     if config.gaze_estimator.use_dummy_camera_params:
@@ -61,15 +51,7 @@
     OmegaConf.set_readonly(config, True)
     logger.info(OmegaConf.to_yaml(config))

-    if config.face_detector.mode == "dlib":
-        download_dlib_pretrained_model()
-
-    if config.mode == "MPIIGaze":
-        download_mpiigaze_model()
-    elif config.mode == "MPIIFaceGaze":
-        download_mpiifacegaze_model()
-    elif config.mode == "ETH-XGaze":
-        download_ethxgaze_model()
+    download_mpiigaze_model()

     check_path_all(config)