Skip to content

Commit

Permalink
added slicing of image to conform to stride length
Browse files Browse the repository at this point in the history
  • Loading branch information
YuriVanWarmerdam committed Nov 19, 2024
1 parent 878ef96 commit 987f26b
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 10 deletions.
8 changes: 3 additions & 5 deletions rcdt_detection/nodes/object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from sensor_msgs.msg import Image, CameraInfo

from rcdt_detection.image_manipulation import (
ros_image_to_cv2_image,
ros_image_to_cv2_image_sliced,
segmentation_mask_to_binary_mask,
single_to_three_channel,
)
Expand Down Expand Up @@ -146,10 +146,8 @@ def process_rgb_image(
If it is, the image coordinates of its centroid are calculated.
"""
rgb_image = ros_image_to_cv2_image(message)

rgb_image = ros_image_to_cv2_image_sliced(message)
segmentation_result = segment_image(model=segmentation_model, image=rgb_image)
ros_logger.info(f"Segmented {len(segmentation_result.masks)} objects.")

centroid_image_coordinates = []
for mask in segmentation_result.masks:
Expand All @@ -171,7 +169,7 @@ def process_depth_image(
intrinsics: rs2.intrinsics,
) -> list[Point]:
"""Calculate world coordinate relative to the camera of image coordinates."""
depth_image = ros_image_to_cv2_image(message)
depth_image = ros_image_to_cv2_image_sliced(message)

world_coordinates = []
for image_coordinate in image_coordinates:
Expand Down
32 changes: 27 additions & 5 deletions rcdt_detection/rcdt_detection/image_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,51 @@
# SPDX-License-Identifier: Apache-2.0

import cv2
import numpy as np
from numpy import array, uint8
import torch
from sensor_msgs.msg import Image
from cv_bridge import CvBridge
from math import floor, ceil

cv_bridge = CvBridge()


def ros_image_to_cv2_image(
    image_message: Image, desired_encoding: str = "passthrough"
) -> array:
    """Convert a ROS image message to a cv2 image.

    Args:
        image_message: The ROS ``sensor_msgs`` image message to convert.
        desired_encoding: Encoding requested from the conversion; forwarded
            unchanged to ``CvBridge.imgmsg_to_cv2``.

    Returns:
        The image produced by the module-level ``cv_bridge`` instance.
    """
    return cv_bridge.imgmsg_to_cv2(image_message, desired_encoding=desired_encoding)


def segmentation_mask_to_binary_mask(mask: torch.Tensor) -> np.array:
def slice_image_to_stride(image: array, stride: int = 32) -> array:
    """Center-crop an image so its height and width conform to a stride length.

    The surplus rows and columns (``size % stride``) are trimmed as evenly as
    possible from opposite edges; when the surplus is odd, the extra row or
    column is removed from the bottom/right edge. Only the first two axes are
    cropped, so trailing axes (e.g. channels) are left untouched.

    Cropping shifts pixel coordinates: index ``(r, c)`` of the result
    corresponds to ``(r + floor(row_dist / 2), c + floor(col_dist / 2))`` in
    the input. An axis shorter than ``stride`` is cropped to length zero.

    Args:
        image: Array with at least two axes (rows, columns).
        stride: Positive length that both cropped dimensions must be a
            multiple of.

    Returns:
        A slice of ``image`` whose first two dimensions are multiples of
        ``stride``.

    Raises:
        ValueError: If ``stride`` is not positive.
    """
    # A zero stride would raise a bare ZeroDivisionError below and a negative
    # one would silently select the wrong region, so reject both up front.
    if stride <= 0:
        raise ValueError(f"stride must be a positive integer, got {stride}")

    rows = image.shape[0]
    cols = image.shape[1]
    # Number of surplus rows/columns beyond the largest multiple of stride.
    row_dist = rows % stride
    col_dist = cols % stride
    return image[
        floor(row_dist / 2) : rows - ceil(row_dist / 2),
        floor(col_dist / 2) : cols - ceil(col_dist / 2),
    ]


def ros_image_to_cv2_image_sliced(
    image_message: Image, desired_encoding: str = "passthrough", stride: int = 32
) -> array:
    """Convert a ROS image message to a cv2 image cropped to a multiple of stride.

    The message is first converted with ``ros_image_to_cv2_image`` and the
    result is then cropped by ``slice_image_to_stride``.
    """
    cv2_image = ros_image_to_cv2_image(image_message, desired_encoding)
    cropped_image = slice_image_to_stride(cv2_image, stride)
    return cropped_image


def segmentation_mask_to_binary_mask(mask: torch.Tensor) -> array:
    """Convert given mask to a uint8 array scaled by 255 with its first axis moved last.

    The mask data is moved to the CPU, converted to a NumPy array, cast to
    ``uint8`` and multiplied by 255. The axes are then reordered with
    ``transpose(1, 2, 0)``, so a three-axis input of shape ``(a, b, c)`` comes
    back as ``(b, c, a)``.

    NOTE(review): presumably the input is laid out [channels, height, width]
    with values in {0, 1}, giving a [height, width, channels] result in
    {0, 255} -- confirm against the segmentation model's output.
    """
    binary_mask = mask.data.cpu().numpy().astype(uint8)
    binary_mask = binary_mask * 255
    # Move the leading axis to the end.
    binary_mask = binary_mask.transpose(1, 2, 0)
    return binary_mask


def single_to_three_channel(image: array) -> array:
    """Convert given single-channel image to three-channel image.

    The conversion is done by ``cv2.cvtColor`` with ``cv2.COLOR_GRAY2BGR``.
    """
    return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)

0 comments on commit 987f26b

Please sign in to comment.