diff --git a/deepface/models/demography/Age.py b/deepface/models/demography/Age.py index 67ab3ae6..9f2052e3 100644 --- a/deepface/models/demography/Age.py +++ b/deepface/models/demography/Age.py @@ -1,3 +1,6 @@ +# stdlib dependencies +from typing import List + # 3rd party dependencies import numpy as np @@ -43,6 +46,29 @@ def predict(self, img: np.ndarray) -> np.float64: age_predictions = self.model(img, training=False).numpy()[0, :] return find_apparent_age(age_predictions) + def predicts(self, imgs: List[np.ndarray]) -> np.ndarray: + """ + Predict apparent ages of multiple faces + Args: + imgs (List[np.ndarray]): (n, 224, 224, 3) + Returns: + apparent_ages (np.ndarray): (n,) + """ + # Convert list to numpy array + imgs_:np.ndarray = np.array(imgs) + # Remove batch dimension if exists + imgs_ = imgs_.squeeze() + # Check if the input is a single image + if len(imgs_.shape) == 3: + # Add batch dimension if not exists + imgs_ = np.expand_dims(imgs_, axis=0) + # Batch prediction + age_predictions = self.model.predict_on_batch(imgs_) + apparent_ages = np.array( + [find_apparent_age(age_prediction) for age_prediction in age_predictions] + ) + return apparent_ages + def load_model( url=WEIGHTS_URL, @@ -65,7 +91,7 @@ def load_model( # -------------------------- - age_model = Model(inputs=model.input, outputs=base_model_output) + age_model = Model(inputs=model.inputs, outputs=base_model_output) # -------------------------- diff --git a/deepface/models/demography/Gender.py b/deepface/models/demography/Gender.py index ad1c15e3..f55c5719 100644 --- a/deepface/models/demography/Gender.py +++ b/deepface/models/demography/Gender.py @@ -1,3 +1,6 @@ +# stdlib dependencies +from typing import List + # 3rd party dependencies import numpy as np @@ -42,6 +45,24 @@ def predict(self, img: np.ndarray) -> np.ndarray: # return self.model.predict(img, verbose=0)[0, :] return self.model(img, training=False).numpy()[0, :] + def predicts(self, imgs: List[np.ndarray]) -> np.ndarray: + """ + Predict apparent ages of multiple faces + Args: + imgs (List[np.ndarray]): (n, 224, 224, 3) + Returns: + apparent_ages (np.ndarray): (n,) + """ + # Convert list to numpy array + imgs_:np.ndarray = np.array(imgs) + # Remove redundant dimensions + imgs_ = imgs_.squeeze() + # Check if the input is a single image + if len(imgs_.shape) == 3: + # Add batch dimension + imgs_ = np.expand_dims(imgs_, axis=0) + return self.model.predict_on_batch(imgs_) + def load_model( url=WEIGHTS_URL, @@ -64,7 +85,7 @@ def load_model( # -------------------------- - gender_model = Model(inputs=model.input, outputs=base_model_output) + gender_model = Model(inputs=model.inputs, outputs=base_model_output) # --------------------------