-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Semantic Labeling #203
base: ros2-devel
Are you sure you want to change the base?
Semantic Labeling #203
Changes from 1 commit
8ed6683
dcb74b7
c611e78
4fb0258
5d81216
bfcaeaa
19e9275
c29520d
4e7391f
9c43fc4
31551e4
3ec7b50
9dc9a40
929e570
e1ebf8b
704caa1
4f9305d
024c71c
c78cd4a
e503800
648a46e
3032f65
85f9577
0049598
e9fd4d5
9d52d98
30bc036
4d3b27c
94af48e
29ed345
23577ae
3688541
195b123
b8a4ccb
5363732
d73c983
2326742
b95ac8e
4bf52ea
9382b67
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# The interface for a service that takes in a list of input labels | ||
# describing the food items on a plate and returns a sentence caption that | ||
# compiles these labels, to be used as a query for GroundingDINO detection. | ||
|
||
# A list of semantic labels corresponding to each of the masks of detected | ||
# items in the image | ||
string[] input_labels | ||
--- | ||
# A sentence caption compiling the semantic labels used as a query for | ||
# GroundingDINO to perform bounding box detections. | ||
string caption | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,8 @@ | ||
""" | ||
This file defines the SegmentAllItems class, which launches an action server that | ||
segments all food items in the latest image and defines each segmentation with a semantic | ||
label using GPT-4V, GroundingDINO, and Segment Anything. | ||
takes a list of labels describing food items on a plate and returns segmentation masks | ||
of all food items in the latest image and defines each segmentation with a semantic label | ||
using a pipeline of foundation models including GPT-4o, GroundingDINO, and SegmentAnything. | ||
""" | ||
|
||
# Standard imports | ||
|
@@ -41,6 +42,7 @@ | |
# Local imports | ||
from ada_feeding_msgs.action import SegmentAllItems | ||
from ada_feeding_msgs.msg import Mask | ||
from ada_feeding_msgs.srv import GenerateCaption | ||
from ada_feeding_perception.helpers import ( | ||
BoundingBox, | ||
bbox_from_mask, | ||
|
@@ -56,7 +58,7 @@ class SegmentAllItemsNode(Node): | |
""" | ||
The SegmentAllItemsNode launches an action server that segments all food | ||
items in the latest image and defines each segmentation with a semantic | ||
label using GPT-4V, GroundingDINO, and Segment Anything. | ||
label using GPT-4o, GroundingDINO, and SegmentAnything. | ||
""" | ||
|
||
def __init__(self): | ||
|
@@ -75,6 +77,7 @@ def __init__(self): | |
seg_model_base_url, | ||
groundingdino_config_name, | ||
groundingdino_model_name, | ||
groundingdino_model_base_url, | ||
model_dir, | ||
self.use_efficient_sam, | ||
self.rate_hz, | ||
|
@@ -96,7 +99,7 @@ def __init__(self): | |
groundingdino_model_path = os.path.join(model_dir, groundingdino_model_name) | ||
if not os.path.isfile(groundingdino_model_path): | ||
self.get_logger().info("Model checkpoint does not exist. Downloading...") | ||
download_checkpoint(groundingdino_model_name, model_dir) | ||
download_checkpoint(groundingdino_model_name, model_dir, groundingdino_model_base_url) | ||
self.get_logger().info(f"Model checkpoint downloaded {groundingdino_model_path}.") | ||
|
||
# Set the path to the GroundingDINO configurations file in the model directory | ||
|
@@ -169,6 +172,14 @@ def __init__(self): | |
self.active_goal_request_lock = threading.Lock() | ||
self.active_goal_request = None | ||
|
||
# Create the service that invokes GPT-4o | ||
self.srv = self.create_service( | ||
GenerateCaption, | ||
"~/invoke_gpt4o", | ||
self.invoke_gpt4o_callback, | ||
callback_group=MutuallyExclusiveCallbackGroup(), | ||
) | ||
|
||
# Create the Action Server. | ||
# Note: remapping action names does not work: https://github.com/ros2/ros2/issues/1312 | ||
self._action_server = ActionServer( | ||
|
@@ -208,6 +219,7 @@ def read_params( | |
efficient_sam_model_base_url, | ||
groundingdino_config_name, | ||
groundingdino_model_name, | ||
groundingdino_model_base_url, | ||
model_dir, | ||
use_efficient_sam, | ||
rate_hz, | ||
|
@@ -399,6 +411,7 @@ def read_params( | |
seg_model_base_url, | ||
groundingdino_config_name.value, | ||
groundingdino_model_name.value, | ||
groundingdino_model_base_url.value, | ||
model_dir.value, | ||
use_efficient_sam.value, | ||
rate_hz.value, | ||
|
@@ -593,6 +606,45 @@ def cancel_callback(self, _: ServerGoalHandle) -> CancelResponse: | |
""" | ||
self.get_logger().info("Cancelling the goal request...") | ||
return CancelResponse.ACCEPT | ||
|
||
def invoke_gpt4o_callback(
    self, request: GenerateCaption.Request, response: GenerateCaption.Response
) -> GenerateCaption.Response:
    """
    Callback function for the GPT-4o service. This function takes in a list
    of string labels describing the foods on an image as input and returns a
    caption generated by GPT-4o that compiles these labels into a descriptive
    sentence used as a query for GroundingDINO.

    Parameters
    ----------
    request: The given request message. Its `input_labels` are forwarded to
        GPT-4o together with the latest camera image.
    response: The created response message.

    Returns
    ----------
    response: The updated response message based on the request. `caption` is
        only set when GPT-4o produced a text caption; the response is returned
        unmodified if the image or camera info is not available yet, if the
        GPT-4o call raises, or if it returns no text.
    """
    # Snapshot the latest image and camera info under their respective locks
    with self.latest_img_msg_lock:
        latest_img_msg = self.latest_img_msg
    with self.camera_info_lock:
        camera_info = self.camera_info

    # Check if the image and camera info are available
    if latest_img_msg is None or camera_info is None:
        self.get_logger().error("Image or camera info not available.")
        return response

    # Convert the image message to a CV2 image
    image = ros_msg_to_cv2_image(latest_img_msg)

    # Run GPT-4o to generate a caption for the image. This is a call to an
    # external model, so this callback is the boundary where any failure is
    # logged and turned into an unmodified response instead of an exception
    # escaping the service callback.
    try:
        vlm_query = self.run_gpt4o(image, request.input_labels)
    except Exception as exc:  # pylint: disable=broad-exception-caught
        self.get_logger().error(f"GPT-4o caption generation failed: {exc}")
        return response

    # NOTE(review): run_gpt4o returns the completion's message content, which
    # is presumably None when the model returns no text -- confirm against the
    # client library. Assigning a non-string to `caption` would fail, so guard.
    if not isinstance(vlm_query, str):
        self.get_logger().error("GPT-4o returned no caption text.")
        return response

    self.get_logger().info(f"GPT-4o Query: {vlm_query}")

    # Set the response message to the caption generated by GPT-4o
    response.caption = vlm_query

    return response
|
||
def run_gpt4o(self, image: npt.NDArray, labels_list: list[str]): | ||
""" | ||
|
@@ -639,14 +691,14 @@ def run_gpt4o(self, image: npt.NDArray, labels_list: list[str]): | |
|
||
Here are some sample responses that convey how you should format your responses: | ||
|
||
"Food items including grapes, strawberries, blueberries, melon chunks, and | ||
carrots on a small, blue plate." | ||
Food items including grapes, strawberries, blueberries, melon chunks, and | ||
carrots on a small, blue plate. | ||
|
||
"Food items including strips of grilled meat and seasoned cucumber | ||
spears arranged on a light gray plate." | ||
Food items including strips of grilled meat and seasoned cucumber | ||
spears arranged on a light gray plate. | ||
|
||
"Food items including baked chicken pieces, black olives, bell pepper slices, | ||
and artichoke on a plate." | ||
Food items including baked chicken pieces, black olives, bell pepper slices, | ||
and artichoke on a plate. | ||
""" | ||
|
||
# Run GPT-4o to generate a sentence caption given the user query, system query, | ||
|
@@ -666,6 +718,7 @@ def run_gpt4o(self, image: npt.NDArray, labels_list: list[str]): | |
]}], | ||
) | ||
|
||
# Get the caption generated by GPT-4o | ||
vlm_query = response.choices[0].message.content | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Related to error handling, is it possible that |
||
|
||
return vlm_query | ||
|
@@ -1111,11 +1164,6 @@ async def run_vision_pipeline(self, image_msg: Image, caption: str): | |
# Run Open-GroundingDINO on the image | ||
bbox_predictions = self.run_grounding_dino(image, caption, self.box_threshold, self.text_threshold) | ||
|
||
# Run GPT-4o on the image and caption to generate a sentence prompt | ||
# for the food items detected in the image | ||
gpt4o_query = self.run_gpt4o(image, caption) | ||
self.get_logger().info(f"GPT-4o Query: {gpt4o_query}") | ||
|
||
# Publish a visualization of the GroundingDINO predictions, if the visualization | ||
# flag is set to true | ||
if self.viz_groundingdino: | ||
|
@@ -1200,16 +1248,15 @@ async def execute_callback( | |
feedback = SegmentAllItems.Feedback() | ||
while ( | ||
rclpy.ok() | ||
and not goal_handle.is_cancel_requested() | ||
and not goal_handle.is_cancel_requested | ||
and not vision_pipeline_task.done() | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Are there cases where the vision pipeline could hang? If so, I'd recommend adding a timeout to the action (maybe in the action message itself) to be robust to that. |
||
): | ||
feedback.elapsed_time = ((self.get_clock().now() - starting_time).nanoseconds / | ||
1e9).to_msg() | ||
feedback.elapsed_time = (self.get_clock().now() - starting_time).to_msg() | ||
goal_handle.publish_feedback(feedback) | ||
rate.sleep() | ||
|
||
# If there is a cancel request, cancel the vision pipeline task | ||
if goal_handle.is_cancel_requested(): | ||
if goal_handle.is_cancel_requested: | ||
self.get_logger().info("Goal cancelled.") | ||
goal_handle.canceled() | ||
result = SegmentAllItems.Result() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: add newlines at the end of files. (I know not all files have it, but in general it is a best practice, so we should enforce it on new/modified files.)