Merge pull request #1232 from David-rn/feat/keypoints-from-mediapipe
Feat/keypoints from mediapipe
SkalskiP authored Jun 13, 2024
2 parents 3d5437b + 07cb0d8 commit f96e776
Showing 2 changed files with 228 additions and 46 deletions.
236 changes: 190 additions & 46 deletions supervision/keypoint/core.py
@@ -18,12 +18,13 @@ class KeyPoints:
The `sv.KeyPoints` class in the Supervision library standardizes results from
various keypoint detection and pose estimation models into a consistent format. This
class simplifies data manipulation and filtering, providing a uniform API for
integration with Supervision [keypoints annotators](/keypoint/annotators).
=== "Ultralytics"
Use [`sv.KeyPoints.from_ultralytics`](/keypoint/core/#supervision.keypoint.core.KeyPoints.from_ultralytics)
method, which accepts a [YOLOv8](https://github.com/ultralytics/ultralytics)
pose result.
```python
import cv2
@@ -32,10 +33,60 @@ class simplifies data manipulation and filtering, providing a uniform API for
image = cv2.imread(<SOURCE_IMAGE_PATH>)
model = YOLO('yolov8s-pose.pt')
result = model(image)[0]
key_points = sv.KeyPoints.from_ultralytics(result)
```
=== "Inference"
Use [`sv.KeyPoints.from_inference`](/keypoint/core/#supervision.keypoint.core.KeyPoints.from_inference)
method, which accepts an [Inference](https://inference.roboflow.com/) pose result.
```python
import cv2
import supervision as sv
from inference import get_model
image = cv2.imread(<SOURCE_IMAGE_PATH>)
model = get_model(model_id=<POSE_MODEL_ID>, api_key=<ROBOFLOW_API_KEY>)
result = model.infer(image)[0]
key_points = sv.KeyPoints.from_inference(result)
```
=== "MediaPipe"
Use [`sv.KeyPoints.from_mediapipe`](/keypoint/core/#supervision.keypoint.core.KeyPoints.from_mediapipe)
method, which accepts a [MediaPipe](https://github.com/google-ai-edge/mediapipe)
pose result.
```python
import cv2
import mediapipe as mp
import supervision as sv
image = cv2.imread(<SOURCE_IMAGE_PATH>)
image_height, image_width, _ = image.shape
mediapipe_image = mp.Image(
    image_format=mp.ImageFormat.SRGB,
    data=cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

options = mp.tasks.vision.PoseLandmarkerOptions(
    base_options=mp.tasks.BaseOptions(
        model_asset_path="pose_landmarker_heavy.task"
    ),
    running_mode=mp.tasks.vision.RunningMode.IMAGE,
    num_poses=2)

PoseLandmarker = mp.tasks.vision.PoseLandmarker
with PoseLandmarker.create_from_options(options) as landmarker:
    pose_landmarker_result = landmarker.detect(mediapipe_image)

key_points = sv.KeyPoints.from_mediapipe(
    pose_landmarker_result, (image_width, image_height))
```
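Once loaded, a `sv.KeyPoints` object can be passed straight to the keypoints annotators. A minimal sketch, assuming `image` and `key_points` come from any of the tabs above and that `sv.EdgeAnnotator` and `sv.VertexAnnotator` are used with default parameters:
```python
import supervision as sv

# draw skeleton edges first, then keypoint vertices on top
edge_annotator = sv.EdgeAnnotator()
vertex_annotator = sv.VertexAnnotator()

annotated_image = edge_annotator.annotate(
    scene=image.copy(), key_points=key_points)
annotated_image = vertex_annotator.annotate(
    scene=annotated_image, key_points=key_points)
```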
Attributes:
xy (np.ndarray): An array of shape `(n, m, 2)` containing
`n` detected poses, each composed of `m` keypoint coordinates in format `[x, y]`
@@ -63,7 +114,7 @@ def __post_init__(self):

def __len__(self) -> int:
"""
Returns the number of keypoints in the `sv.KeyPoints` object.
"""
return len(self.xy)

@@ -106,18 +157,17 @@ def from_inference(cls, inference_result: Union[dict, Any]) -> KeyPoints:
"""
Create a `sv.KeyPoints` object from the [Roboflow](https://roboflow.com/)
API inference result or the [Inference](https://inference.roboflow.com/)
package results.
Args:
inference_result (dict, Any): The result from the
Roboflow API or Inference package containing predictions with keypoints.
Returns:
A `sv.KeyPoints` object containing the keypoint coordinates, class IDs,
class names, and confidences of each keypoint.
Examples:
```python
import cv2
import supervision as sv
Expand Down Expand Up @@ -151,8 +201,6 @@ class IDs, and confidences of each keypoint.
"You can retrieve it like so: inference_result = model.infer(image)[0]"
)

# Unpack the result if received from inference.get_model,
# rather than inference_sdk.InferenceHTTPClient
with suppress(AttributeError):
    inference_result = inference_result.dict(exclude_none=True, by_alias=True)
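This unpacking bridges the two Inference client flavors: results from `inference.get_model` are pydantic-style objects exposing `.dict()`, while `inference_sdk.InferenceHTTPClient` already returns plain dicts, which the `suppress(AttributeError)` lets pass through untouched. A sketch of both paths, with `<POSE_MODEL_ID>` and `<ROBOFLOW_API_KEY>` as placeholders:
```python
from inference import get_model
from inference_sdk import InferenceHTTPClient

# object-style result: has .dict(), unpacked by the suppress block above
model = get_model(model_id=<POSE_MODEL_ID>, api_key=<ROBOFLOW_API_KEY>)
object_result = model.infer(image)[0]

# plain-dict result: .dict() raises AttributeError, which is suppressed
client = InferenceHTTPClient(
    api_url="https://detect.roboflow.com", api_key=<ROBOFLOW_API_KEY>)
dict_result = client.infer(image, model_id=<POSE_MODEL_ID>)

# both forms are accepted
key_points = sv.KeyPoints.from_inference(object_result)
key_points = sv.KeyPoints.from_inference(dict_result)
```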

@@ -185,29 +233,115 @@ class IDs, and confidences of each keypoint.
data=data,
)

@classmethod
def from_mediapipe(
    cls, mediapipe_results, resolution_wh: Tuple[int, int]
) -> KeyPoints:
"""
Creates a `sv.KeyPoints` instance from a
[MediaPipe](https://github.com/google-ai-edge/mediapipe)
pose landmark detection inference result.
Args:
mediapipe_results (Union[PoseLandmarkerResult, SolutionOutputs]):
The output results from MediaPipe. Both the `PoseLandmarker` inference
result and the legacy `Pose` solution result are supported.
resolution_wh (Tuple[int, int]): A tuple of the form `(width, height)`
representing the resolution of the frame.
Returns:
A `sv.KeyPoints` object containing the keypoint coordinates and
confidences of each keypoint.
!!! tip
Before you start, download model bundles from the
[MediaPipe website](https://ai.google.dev/edge/mediapipe/solutions/vision/pose_landmarker/index#models).
Examples:
```python
import cv2
import mediapipe as mp
import supervision as sv
image = cv2.imread(<SOURCE_IMAGE_PATH>)
image_height, image_width, _ = image.shape
mediapipe_image = mp.Image(
    image_format=mp.ImageFormat.SRGB,
    data=cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

options = mp.tasks.vision.PoseLandmarkerOptions(
    base_options=mp.tasks.BaseOptions(
        model_asset_path="pose_landmarker_heavy.task"
    ),
    running_mode=mp.tasks.vision.RunningMode.IMAGE,
    num_poses=2)

PoseLandmarker = mp.tasks.vision.PoseLandmarker
with PoseLandmarker.create_from_options(options) as landmarker:
    pose_landmarker_result = landmarker.detect(mediapipe_image)

key_points = sv.KeyPoints.from_mediapipe(
    pose_landmarker_result, (image_width, image_height))
```
""" # noqa: E501 // docs
results = mediapipe_results.pose_landmarks
if not isinstance(mediapipe_results.pose_landmarks, list):
    if mediapipe_results.pose_landmarks is None:
        results = []
    else:
        # the legacy Pose solution returns a single NormalizedLandmarkList
        results = [
            [landmark for landmark in mediapipe_results.pose_landmarks.landmark]
        ]

if len(results) == 0:
    return cls.empty()

xy = []
confidence = []
for pose in results:
    prediction_xy = []
    prediction_confidence = []
    for landmark in pose:
        # landmarks are normalized to [0, 1]; scale to pixel coordinates
        keypoint_xy = [
            landmark.x * resolution_wh[0],
            landmark.y * resolution_wh[1],
        ]
        prediction_xy.append(keypoint_xy)
        # landmark visibility doubles as the keypoint confidence score
        prediction_confidence.append(landmark.visibility)

    xy.append(prediction_xy)
    confidence.append(prediction_confidence)

return cls(
    xy=np.array(xy, dtype=np.float32),
    confidence=np.array(confidence, dtype=np.float32),
)
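The `isinstance` branch above is what makes the legacy solution output work: a legacy result is not a list, and its landmarks are wrapped into a single-pose list before scaling. A minimal sketch of that legacy path, assuming the `mp.solutions.pose.Pose` interface (which returns at most one pose per image):
```python
import cv2
import mediapipe as mp
import supervision as sv

image = cv2.imread(<SOURCE_IMAGE_PATH>)
image_height, image_width, _ = image.shape

# the legacy solution API consumes an RGB numpy array directly
with mp.solutions.pose.Pose(static_image_mode=True) as pose:
    pose_result = pose.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

key_points = sv.KeyPoints.from_mediapipe(
    pose_result, (image_width, image_height))
```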

@classmethod
def from_ultralytics(cls, ultralytics_results) -> KeyPoints:
"""
Creates a `sv.KeyPoints` instance from a
[YOLOv8](https://github.com/ultralytics/ultralytics) pose inference result.
Args:
ultralytics_results (ultralytics.engine.results.Keypoints):
The output Results instance from YOLOv8.
Returns:
A `sv.KeyPoints` object containing the keypoint coordinates, class IDs,
class names, and confidences of each keypoint.
Examples:
```python
import cv2
import supervision as sv
from ultralytics import YOLO
image = cv2.imread(<SOURCE_IMAGE_PATH>)
model = YOLO('yolov8s-pose.pt')
result = model(image)[0]
key_points = sv.KeyPoints.from_ultralytics(result)
```
"""
if ultralytics_results.keypoints.xy.numel() == 0:
@@ -224,16 +358,18 @@ def from_ultralytics(cls, ultralytics_results) -> KeyPoints:
@classmethod
def from_yolo_nas(cls, yolo_nas_results) -> KeyPoints:
"""
Create a `sv.KeyPoints` instance from [YOLO-NAS](https://github.com/Deci-AI/super-gradients/blob/master/YOLONAS-POSE.md)
pose inference results.
Args:
yolo_nas_results (ImagePoseEstimationPrediction): The output object from
YOLO NAS.
Returns:
A `sv.KeyPoints` object containing the keypoint coordinates, class IDs,
class names, and confidences of each keypoint.
Examples:
```python
import cv2
import torch
Expand All @@ -243,13 +379,13 @@ def from_yolo_nas(cls, yolo_nas_results) -> KeyPoints:
image = cv2.imread(<SOURCE_IMAGE_PATH>)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = super_gradients.training.models.get(
    "yolo_nas_pose_s", pretrained_weights="coco_pose").to(device)
results = model.predict(image, conf=0.1)
key_points = sv.KeyPoints.from_yolo_nas(results)
```
"""
""" # noqa: E501 // docs
if len(yolo_nas_results.prediction.poses) == 0:
    return cls.empty()

@@ -282,34 +418,41 @@ def __getitem__(
    self, index: Union[int, slice, List[int], np.ndarray, str]
) -> Union[KeyPoints, List, np.ndarray, None]:
"""
Get a subset of the `sv.KeyPoints` object or access an item from its data field.

When provided with an integer, slice, list of integers, or a numpy array, this
method returns a new `sv.KeyPoints` object that represents a subset of the
original `sv.KeyPoints`. When provided with a string, it accesses the
corresponding item in the data dictionary.
Args:
index (Union[int, slice, List[int], np.ndarray, str]): The index, indices,
or key to access a subset of the `sv.KeyPoints` or an item from the
data.
Returns:
A subset of the `sv.KeyPoints` object or an item from the data field.
Examples:
```python
import supervision as sv
key_points = sv.KeyPoints()
# access the first keypoint using an integer index
key_points[0]

# access the first 10 keypoints using index slice
key_points[0:10]

# access selected keypoints using a list of indices
key_points[[0, 2, 4]]

# access keypoints with selected class_id
key_points[key_points.class_id == 0]

# access keypoints with confidence greater than 0.5
key_points[key_points.confidence > 0.5]
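
# access an item stored in the data dictionary by string key
# (assumes a 'class_name' entry was populated beforehand)
key_points['class_name']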
```
"""
if isinstance(index, str):
@@ -325,13 +468,13 @@ def __getitem__(

def __setitem__(self, key: str, value: Union[np.ndarray, List]):
"""
Set a value in the data dictionary of the `sv.KeyPoints` object.
Args:
key (str): The key in the data dictionary to set.
value (Union[np.ndarray, List]): The value to set for the key.
Examples:
```python
import cv2
import supervision as sv
@@ -343,7 +486,7 @@ def __setitem__(self, key: str, value: Union[np.ndarray, List]):
result = model(image)[0]
keypoints = sv.KeyPoints.from_ultralytics(result)
keypoints['class_name'] = [
    model.model.names[class_id]
    for class_id
    in keypoints.class_id
]
@@ -364,12 +507,13 @@ def empty(cls) -> KeyPoints:
Create an empty `sv.KeyPoints` object with no keypoints.
Returns:
An empty `sv.KeyPoints` object.
Examples:
```python
import supervision as sv

key_points = sv.KeyPoints.empty()
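
# an empty instance contains no keypoints
assert len(key_points) == 0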
```
"""
return cls(xy=np.empty((0, 0, 2), dtype=np.float32))