Host nodes: DepthMerger. #163

Open · wants to merge 2 commits into base: main
3 changes: 3 additions & 0 deletions depthai_nodes/nodes/__init__.py
@@ -0,0 +1,3 @@
from .depth_merger import DepthMerger

__all__ = ["DepthMerger"]
146 changes: 146 additions & 0 deletions depthai_nodes/nodes/depth_merger.py
@@ -0,0 +1,146 @@
from typing import Union

import depthai as dai

from depthai_nodes.ml.messages import ImgDetectionExtended, ImgDetectionsExtended

from .host_spatials_calc import HostSpatialsCalc


class DepthMerger(dai.node.HostNode):
"""DepthMerger is a custom host node for merging 2D detections with depth
information to produce spatial detections.

Attributes
----------
output : dai.Node.Output
The output of the DepthMerger node containing dai.SpatialImgDetections.
shrinking_factor : float
Fraction of the bounding box to shrink from each side before depth is sampled. 0 means no shrinking; for example, 0.1 removes 10% of the width from the left and right and 10% of the height from the top and bottom.

Usage
-----
depth_merger = pipeline.create(DepthMerger).build(
output_2d=nn.out,
output_depth=stereo.depth,
calib_data=device.readCalibration()
)
"""

def __init__(self, shrinking_factor: float = 0) -> None:
super().__init__()

self.output = self.createOutput(
possibleDatatypes=[
dai.Node.DatatypeHierarchy(dai.DatatypeEnum.SpatialImgDetections, True)
]
)

self.shrinking_factor = shrinking_factor

def build(
self,
output_2d: dai.Node.Output,
output_depth: dai.Node.Output,
calib_data: dai.CalibrationHandler,
depth_alignment_socket: dai.CameraBoardSocket = dai.CameraBoardSocket.CAM_A,
shrinking_factor: float = 0,
) -> "DepthMerger":
self.link_args(output_2d, output_depth)
self.shrinking_factor = shrinking_factor
self.host_spatials_calc = HostSpatialsCalc(calib_data, depth_alignment_socket)
return self

def process(self, message_2d: dai.Buffer, depth: dai.ImgFrame) -> None:
spatial_dets = self._transform(message_2d, depth)
self.output.send(spatial_dets)

def _transform(
self, message_2d: dai.Buffer, depth: dai.ImgFrame
) -> Union[dai.SpatialImgDetections, dai.SpatialImgDetection]:
"""Transforms 2D detections into spatial detections based on the depth frame."""
if isinstance(message_2d, dai.ImgDetection):
return self._detection_to_spatial(message_2d, depth)
elif isinstance(message_2d, dai.ImgDetections):
return self._detections_to_spatial(message_2d, depth)
elif isinstance(message_2d, ImgDetectionExtended):
return self._detection_to_spatial(message_2d, depth)
elif isinstance(message_2d, ImgDetectionsExtended):
return self._detections_to_spatial(message_2d, depth)
else:
raise ValueError(f"Unknown message type: {type(message_2d)}")

def _detection_to_spatial(
self,
detection: Union[dai.ImgDetection, ImgDetectionExtended],
depth: dai.ImgFrame,
) -> dai.SpatialImgDetection:
"""Converts a single 2D detection into a spatial detection using the depth
frame."""
depth_frame = depth.getCvFrame()
x_len = depth_frame.shape[1]
y_len = depth_frame.shape[0]
rect = (
detection.rotated_rect.getOuterRect()
if isinstance(detection, ImgDetectionExtended)
else (detection.xmin, detection.ymin, detection.xmax, detection.ymax)
)
xmin, ymin, xmax, ymax = rect
# Compute width and height before shrinking so that both sides of the box
# are shrunk by the same amount.
width = xmax - xmin
height = ymax - ymin
xmin += width * self.shrinking_factor
ymin += height * self.shrinking_factor
xmax -= width * self.shrinking_factor
ymax -= height * self.shrinking_factor
roi = [
self._get_index(xmin, x_len),
self._get_index(ymin, y_len),
self._get_index(xmax, x_len),
self._get_index(ymax, y_len),
]
spatials = self.host_spatials_calc.calc_spatials(depth, roi)

spatial_img_detection = dai.SpatialImgDetection()
spatial_img_detection.xmin = xmin
spatial_img_detection.ymin = ymin
spatial_img_detection.xmax = xmax
spatial_img_detection.ymax = ymax
spatial_img_detection.spatialCoordinates = dai.Point3f(
spatials["x"], spatials["y"], spatials["z"]
)

spatial_img_detection.confidence = detection.confidence
spatial_img_detection.label = 0 if detection.label == -1 else detection.label
return spatial_img_detection

def _detections_to_spatial(
self,
detections: Union[dai.ImgDetections, ImgDetectionsExtended],
depth: dai.ImgFrame,
) -> dai.SpatialImgDetections:
"""Converts multiple 2D detections into spatial detections using the depth
frame."""
new_dets = dai.SpatialImgDetections()
new_dets.detections = [
self._detection_to_spatial(d, depth) for d in detections.detections
]
new_dets.setSequenceNum(detections.getSequenceNum())
new_dets.setTimestamp(detections.getTimestamp())
return new_dets

def _get_index(self, relative_coord: float, dimension_len: int) -> int:
"""Converts a relative coordinate to an absolute index within the given
dimension length."""
bounded_coord = min(1, relative_coord)
return max(0, int(bounded_coord * dimension_len) - 1)
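
To make the Usage example in the class docstring concrete, here is a minimal wiring sketch. It assumes `pipeline` is an existing dai.Pipeline, `nn` an already-built detection network node, `stereo` a StereoDepth node, and `device` a connected dai.Device; those objects are placeholders, and only the DepthMerger wiring itself comes from this PR.

import depthai as dai
from depthai_nodes.nodes import DepthMerger

# Sketch only: `pipeline`, `nn`, `stereo`, and `device` are assumed to already exist.
depth_merger = pipeline.create(DepthMerger).build(
    output_2d=nn.out,                     # 2D detections (e.g. ImgDetections or ImgDetectionsExtended)
    output_depth=stereo.depth,            # depth frames aligned to the chosen socket
    calib_data=device.readCalibration(),
    depth_alignment_socket=dai.CameraBoardSocket.CAM_A,
    shrinking_factor=0.1,                 # shrink each box by 10% per side before sampling depth
)

# depth_merger.output emits dai.SpatialImgDetections and can be linked to any
# downstream node or read on the host.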
164 changes: 164 additions & 0 deletions depthai_nodes/nodes/host_spatials_calc.py
@@ -0,0 +1,164 @@
from typing import Callable, Dict, List

import depthai as dai
import numpy as np


class HostSpatialsCalc:
"""HostSpatialsCalc is a helper class for calculating spatial coordinates from depth
data.

Attributes
----------
calibData : dai.CalibrationHandler
Calibration data handler for the device.
depth_alignment_socket : dai.CameraBoardSocket
The camera socket used for depth alignment.
delta : int
The delta value for ROI calculation. Default is 5, which means a 10x10 pixel region around the point is used for depth averaging.
thresh_low : int
The lower threshold for depth values, in millimeters. Default is 200 (20 cm).
thresh_high : int
The upper threshold for depth values, in millimeters. Default is 30000 (30 m).
"""

# Calibration data is passed in directly; no device object is needed.
def __init__(
self,
calib_data: dai.CalibrationHandler,
depth_alignment_socket: dai.CameraBoardSocket = dai.CameraBoardSocket.CAM_A,
delta: int = 5,
thresh_low: int = 200,
thresh_high: int = 30000,
):
self.calibData = calib_data
self.depth_alignment_socket = depth_alignment_socket

self.delta = delta
self.thresh_low = thresh_low
self.thresh_high = thresh_high

def setLowerThreshold(self, threshold_low: int) -> None:
"""Sets the lower threshold for depth values.

@param threshold_low: The lower threshold for depth values.
@type threshold_low: int
"""
if not isinstance(threshold_low, int):
if isinstance(threshold_low, float):
threshold_low = int(threshold_low)
else:
raise TypeError(
"Threshold has to be an integer or float! Got {}".format(
type(threshold_low)
)
)
self.thresh_low = threshold_low

def setUpperThreshold(self, threshold_high: int) -> None:
"""Sets the upper threshold for depth values.

@param threshold_high: The upper threshold for depth values.
@type threshold_high: int
"""
if not isinstance(threshold_high, int):
if isinstance(threshold_high, float):
threshold_high = int(threshold_high)
else:
raise TypeError(
"Threshold has to be an integer or float! Got {}".format(
type(threshold_high)
)
)
self.thresh_high = threshold_high

def setDeltaRoi(self, delta: int) -> None:
"""Sets the delta value for ROI calculation.

@param delta: The delta value for ROI calculation.
@type delta: int
"""
if not isinstance(delta, int):
if isinstance(delta, float):
delta = int(delta)
else:
raise TypeError(
"Delta has to be an integer or float! Got {}".format(type(delta))
)
self.delta = delta

def _check_input(self, roi: List[int], frame: np.ndarray) -> List[int]:
"""Checks if the input is ROI or point and converts point to ROI if necessary.

@param roi: The region of interest (ROI) or point.
@type roi: List[int]
@param frame: The depth frame.
@type frame: np.ndarray
@return: The region of interest (ROI).
@rtype: List[int]
"""
if len(roi) == 4:
return roi
if len(roi) != 2:
raise ValueError(
"You have to pass either ROI (4 values) or point (2 values)!"
)
# Limit the point so ROI won't be outside the frame
x = min(max(roi[0], self.delta), frame.shape[1] - self.delta)
y = min(max(roi[1], self.delta), frame.shape[0] - self.delta)
return [x - self.delta, y - self.delta, x + self.delta, y + self.delta]

# roi has to be list of ints
def calc_spatials(
self,
depthData: dai.ImgFrame,
roi: List[int],
averaging_method: Callable = np.mean,
) -> Dict[str, float]:
"""Calculates spatial coordinates from depth data within the specified ROI.

@param depthData: The depth data.
@type depthData: dai.ImgFrame
@param roi: The region of interest (ROI) or point.
@type roi: List[int]
@param averaging_method: The method for averaging the depth values.
@type averaging_method: callable
@return: The spatial coordinates.
@rtype: Dict[str, float]
"""
depthFrame = depthData.getFrame()

roi = self._check_input(
roi, depthFrame
) # If point was passed, convert it to ROI
xmin, ymin, xmax, ymax = roi

# Calculate the average depth in the ROI.
depthROI = depthFrame[ymin:ymax, xmin:xmax]
inRange = (self.thresh_low <= depthROI) & (depthROI <= self.thresh_high)

averageDepth = averaging_method(depthROI[inRange])

centroid = np.array( # Get centroid of the ROI
[
int((xmax + xmin) / 2),
int((ymax + ymin) / 2),
]
)

K = self.calibData.getCameraIntrinsics(
cameraId=self.depth_alignment_socket,
resizeWidth=depthFrame.shape[1],
resizeHeight=depthFrame.shape[0],
)
K = np.array(K)
K_inv = np.linalg.inv(K)
homogenous_coords = np.array([centroid[0], centroid[1], 1])
spatial_coords = averageDepth * K_inv.dot(homogenous_coords)

spatials = {
"x": spatial_coords[0],
"y": spatial_coords[1],
"z": spatial_coords[2],
}
return spatials
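
The final lines of calc_spatials implement a standard pinhole back-projection: with the averaged ROI depth Z and the ROI centroid (u, v) in pixels, the 3D point is Z * K_inv @ [u, v, 1]. The standalone numpy sketch below uses made-up intrinsics purely for intuition; in HostSpatialsCalc the matrix K comes from calibData.getCameraIntrinsics().

import numpy as np

# Hypothetical intrinsics for illustration only (fx = fy = 800, cx = 320, cy = 200, in pixels).
K = np.array([
    [800.0,   0.0, 320.0],
    [  0.0, 800.0, 200.0],
    [  0.0,   0.0,   1.0],
])
Z = 1500.0       # average depth of the ROI, in millimeters
u, v = 400, 240  # ROI centroid, in pixels

point = Z * np.linalg.inv(K) @ np.array([u, v, 1.0])
# point[0] = Z * (u - cx) / fx = 150.0   (offset along the image x axis, mm)
# point[1] = Z * (v - cy) / fy = 75.0    (offset along the image y axis, mm)
# point[2] = Z = 1500.0
print(point)  # approximately [150., 75., 1500.]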