diff --git a/rcdt_detection/nodes/object_detection.py b/rcdt_detection/nodes/object_detection.py index 65385d0..6c9c1f4 100755 --- a/rcdt_detection/nodes/object_detection.py +++ b/rcdt_detection/nodes/object_detection.py @@ -19,7 +19,7 @@ from sensor_msgs.msg import Image, CameraInfo from rcdt_detection.image_manipulation import ( - ros_image_to_cv2_image, + ros_image_to_cv2_image_sliced, segmentation_mask_to_binary_mask, single_to_three_channel, ) @@ -146,10 +146,8 @@ def process_rgb_image( If it is, the image coordinates of its centroid are calculated. """ - rgb_image = ros_image_to_cv2_image(message) - + rgb_image = ros_image_to_cv2_image_sliced(message) segmentation_result = segment_image(model=segmentation_model, image=rgb_image) - ros_logger.info(f"Segmented {len(segmentation_result.masks)} objects.") centroid_image_coordinates = [] for mask in segmentation_result.masks: @@ -171,7 +169,7 @@ def process_depth_image( intrinsics: rs2.intrinsics, ) -> list[Point]: """Calculate world coordinate relative to the camera of image coordinates.""" - depth_image = ros_image_to_cv2_image(message) + depth_image = ros_image_to_cv2_image_sliced(message) world_coordinates = [] for image_coordinate in image_coordinates: diff --git a/rcdt_detection/rcdt_detection/image_manipulation.py b/rcdt_detection/rcdt_detection/image_manipulation.py index 1c7a8a8..ebc5d75 100644 --- a/rcdt_detection/rcdt_detection/image_manipulation.py +++ b/rcdt_detection/rcdt_detection/image_manipulation.py @@ -3,29 +3,51 @@ # SPDX-License-Identifier: Apache-2.0 import cv2 -import numpy as np +from numpy import array, uint8 import torch from sensor_msgs.msg import Image from cv_bridge import CvBridge +from math import floor, ceil cv_bridge = CvBridge() def ros_image_to_cv2_image( image_message: Image, desired_encoding: str = "passthrough" -) -> np.array: +) -> array: """Convert ROS image message to cv2 image.""" return cv_bridge.imgmsg_to_cv2(image_message, desired_encoding=desired_encoding) -def segmentation_mask_to_binary_mask(mask: torch.Tensor) -> np.array: +def slice_image_to_stride(image: array, stride: int = 32) -> array: + """Slice image to confirm to a given stride length""" + rows = image.shape[0] + cols = image.shape[1] + row_dist = rows % stride + col_dist = cols % stride + return image[ + slice(floor(row_dist / 2), rows - ceil(row_dist / 2)), + slice(floor(col_dist / 2), cols - ceil(col_dist / 2)), + ] + + +def ros_image_to_cv2_image_sliced( + image_message: Image, desired_encoding: str = "passthrough", stride: int = 32 +) -> array: + """slice an image so that its dimensions are a multiple of stride""" + return slice_image_to_stride( + ros_image_to_cv2_image(image_message, desired_encoding), stride + ) + + +def segmentation_mask_to_binary_mask(mask: torch.Tensor) -> array: """Convert given mask to np.array with range [0, 255], dtype=uint8, and dimensions [height, width, channels].""" - binary_mask = mask.data.cpu().numpy().astype(np.uint8) + binary_mask = mask.data.cpu().numpy().astype(uint8) binary_mask = binary_mask * 255 binary_mask = binary_mask.transpose(1, 2, 0) return binary_mask -def single_to_three_channel(image: np.array) -> np.array: +def single_to_three_channel(image: array) -> array: """Convert given single-channel image to three-channel image.""" return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)