Skip to content

Commit

Permalink
Tracking converter for ROS2 Humble (#505)
Browse files Browse the repository at this point in the history
* Added code for tracking with OAK: ros msgs, tracking converters and tracking examples

---------

Co-authored-by: Daniel Silva <[email protected]>
  • Loading branch information
daniqsilva25 and daniqsilva25 authored Mar 11, 2024
1 parent 21efd39 commit 98d495f
Show file tree
Hide file tree
Showing 13 changed files with 1,124 additions and 0 deletions.
2 changes: 2 additions & 0 deletions depthai_bridge/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ file(GLOB LIB_SRC
"src/ImuConverter.cpp"
"src/TFPublisher.cpp"
"src/TrackedFeaturesConverter.cpp"
"src/TrackDetectionConverter.cpp"
"src/TrackSpatialDetectionConverter.cpp"
)

add_library(${PROJECT_NAME} SHARED ${LIB_SRC})
Expand Down
60 changes: 60 additions & 0 deletions depthai_bridge/include/depthai_bridge/TrackDetectionConverter.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#pragma once

#include <deque>
#include <memory>
#include <string>

#include "depthai/pipeline/datatype/Tracklets.hpp"
#include "depthai_ros_msgs/msg/track_detection2_d_array.hpp"
#include "rclcpp/time.hpp"
#include "vision_msgs/msg/detection2_d_array.hpp"

namespace dai {

namespace ros {

class TrackDetectionConverter {
public:
TrackDetectionConverter(std::string frameName, int width, int height, bool normalized = false, float thresh = 0.0, bool getBaseDeviceTimestamp = false);
~TrackDetectionConverter();

/**
* @brief Handles cases in which the ROS time shifts forward or backward
* Should be called at regular intervals or on-change of ROS time, depending
* on monitoring.
*
*/
void updateRosBaseTime();

/**
* @brief Commands the converter to automatically update the ROS base time on message conversion based on variable
*
* @param update: bool whether to automatically update the ROS base time on message conversion
*/
void setUpdateRosBaseTimeOnToRosMsg(bool update = true) {
_updateRosBaseTimeOnToRosMsg = update;
}

void toRosMsg(std::shared_ptr<dai::Tracklets> trackData, std::deque<depthai_ros_msgs::msg::TrackDetection2DArray>& opDetectionMsgs);

depthai_ros_msgs::msg::TrackDetection2DArray::SharedPtr toRosMsgPtr(std::shared_ptr<dai::Tracklets> trackData);

private:
int _width, _height;
const std::string _frameName;
bool _normalized;
float _thresh;
std::chrono::time_point<std::chrono::steady_clock> _steadyBaseTime;
rclcpp::Time _rosBaseTime;
bool _getBaseDeviceTimestamp;
// For handling ROS time shifts and debugging
int64_t _totalNsChange{0};
// Whether to update the ROS base time on each message conversion
bool _updateRosBaseTimeOnToRosMsg{false};
};

} // namespace ros

namespace rosBridge = ros;

} // namespace dai
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#pragma once

#include <deque>
#include <memory>
#include <string>

#include "depthai/pipeline/datatype/Tracklets.hpp"
#include "depthai_ros_msgs/msg/track_detection2_d_array.hpp"
#include "rclcpp/time.hpp"
#include "vision_msgs/msg/detection2_d_array.hpp"

namespace dai {

namespace ros {

class TrackSpatialDetectionConverter {
public:
TrackSpatialDetectionConverter(
std::string frameName, int width, int height, bool normalized = false, float thresh = 0.0, bool getBaseDeviceTimestamp = false);
~TrackSpatialDetectionConverter();

/**
* @brief Handles cases in which the ROS time shifts forward or backward
* Should be called at regular intervals or on-change of ROS time, depending
* on monitoring.
*
*/
void updateRosBaseTime();

/**
* @brief Commands the converter to automatically update the ROS base time on message conversion based on variable
*
* @param update: bool whether to automatically update the ROS base time on message conversion
*/
void setUpdateRosBaseTimeOnToRosMsg(bool update = true) {
_updateRosBaseTimeOnToRosMsg = update;
}

void toRosMsg(std::shared_ptr<dai::Tracklets> trackData, std::deque<depthai_ros_msgs::msg::TrackDetection2DArray>& opDetectionMsgs);

depthai_ros_msgs::msg::TrackDetection2DArray::SharedPtr toRosMsgPtr(std::shared_ptr<dai::Tracklets> trackData);

private:
int _width, _height;
const std::string _frameName;
bool _normalized;
float _thresh;
std::chrono::time_point<std::chrono::steady_clock> _steadyBaseTime;
rclcpp::Time _rosBaseTime;
bool _getBaseDeviceTimestamp;
// For handling ROS time shifts and debugging
int64_t _totalNsChange{0};
// Whether to update the ROS base time on each message conversion
bool _updateRosBaseTimeOnToRosMsg{false};
};

} // namespace ros

namespace rosBridge = ros;

} // namespace dai
94 changes: 94 additions & 0 deletions depthai_bridge/src/TrackDetectionConverter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#include "depthai_bridge/TrackDetectionConverter.hpp"

#include "depthai/depthai.hpp"
#include "depthai_bridge/depthaiUtility.hpp"

namespace dai {

namespace ros {

TrackDetectionConverter::TrackDetectionConverter(std::string frameName, int width, int height, bool normalized, float thresh, bool getBaseDeviceTimestamp)
: _frameName(frameName),
_width(width),
_height(height),
_normalized(normalized),
_thresh(thresh),
_steadyBaseTime(std::chrono::steady_clock::now()),
_getBaseDeviceTimestamp(getBaseDeviceTimestamp) {
_rosBaseTime = rclcpp::Clock().now();
}

TrackDetectionConverter::~TrackDetectionConverter() = default;

void TrackDetectionConverter::updateRosBaseTime() {
updateBaseTime(_steadyBaseTime, _rosBaseTime, _totalNsChange);
}

void TrackDetectionConverter::toRosMsg(std::shared_ptr<dai::Tracklets> trackData, std::deque<depthai_ros_msgs::msg::TrackDetection2DArray>& opDetectionMsgs) {
// setting the header
std::chrono::_V2::steady_clock::time_point tstamp;
if(_getBaseDeviceTimestamp)
tstamp = trackData->getTimestampDevice();
else
tstamp = trackData->getTimestamp();

depthai_ros_msgs::msg::TrackDetection2DArray opDetectionMsg;
opDetectionMsg.header.stamp = getFrameTime(_rosBaseTime, _steadyBaseTime, tstamp);
opDetectionMsg.header.frame_id = _frameName;
opDetectionMsg.detections.resize(trackData->tracklets.size());

// publishing
for(int i = 0; i < trackData->tracklets.size(); ++i) {
dai::Tracklet t = trackData->tracklets[i];
dai::Rect roi;
float xMin, yMin, xMax, yMax;

if(_normalized)
roi = t.roi;
else
roi = t.roi.denormalize(_width, _height);

xMin = roi.topLeft().x;
yMin = roi.topLeft().y;
xMax = roi.bottomRight().x;
yMax = roi.bottomRight().y;

float xSize = xMax - xMin;
float ySize = yMax - yMin;
float xCenter = xMin + xSize / 2.;
float yCenter = yMin + ySize / 2.;

opDetectionMsg.detections[i].results.resize(1);

opDetectionMsg.detections[i].results[0].hypothesis.class_id = std::to_string(t.label);
opDetectionMsg.detections[i].results[0].hypothesis.score = _thresh;

opDetectionMsg.detections[i].bbox.center.position.x = xCenter;
opDetectionMsg.detections[i].bbox.center.position.y = yCenter;
opDetectionMsg.detections[i].bbox.size_x = xSize;
opDetectionMsg.detections[i].bbox.size_y = ySize;

opDetectionMsg.detections[i].is_tracking = true;
std::stringstream track_id_str;
track_id_str << "" << t.id;
opDetectionMsg.detections[i].tracking_id = track_id_str.str();
opDetectionMsg.detections[i].tracking_age = t.age;
opDetectionMsg.detections[i].tracking_status = static_cast<int32_t>(t.status);
}

opDetectionMsgs.push_back(opDetectionMsg);
}

depthai_ros_msgs::msg::TrackDetection2DArray::SharedPtr TrackDetectionConverter::toRosMsgPtr(std::shared_ptr<dai::Tracklets> trackData) {
std::deque<depthai_ros_msgs::msg::TrackDetection2DArray> msgQueue;
toRosMsg(trackData, msgQueue);
auto msg = msgQueue.front();

depthai_ros_msgs::msg::TrackDetection2DArray::SharedPtr ptr = std::make_shared<depthai_ros_msgs::msg::TrackDetection2DArray>(msg);

return ptr;
}

} // namespace ros

} // namespace dai
101 changes: 101 additions & 0 deletions depthai_bridge/src/TrackSpatialDetectionConverter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
#include "depthai_bridge/TrackSpatialDetectionConverter.hpp"

#include "depthai/depthai.hpp"
#include "depthai_bridge/depthaiUtility.hpp"

namespace dai {

namespace ros {

TrackSpatialDetectionConverter::TrackSpatialDetectionConverter(
std::string frameName, int width, int height, bool normalized, float thresh, bool getBaseDeviceTimestamp)
: _frameName(frameName),
_width(width),
_height(height),
_normalized(normalized),
_thresh(thresh),
_steadyBaseTime(std::chrono::steady_clock::now()),
_getBaseDeviceTimestamp(getBaseDeviceTimestamp) {
_rosBaseTime = rclcpp::Clock().now();
}

TrackSpatialDetectionConverter::~TrackSpatialDetectionConverter() = default;

void TrackSpatialDetectionConverter::updateRosBaseTime() {
updateBaseTime(_steadyBaseTime, _rosBaseTime, _totalNsChange);
}

void TrackSpatialDetectionConverter::toRosMsg(std::shared_ptr<dai::Tracklets> trackData,
std::deque<depthai_ros_msgs::msg::TrackDetection2DArray>& opDetectionMsgs) {
// setting the header
std::chrono::_V2::steady_clock::time_point tstamp;
if(_getBaseDeviceTimestamp)
tstamp = trackData->getTimestampDevice();
else
tstamp = trackData->getTimestamp();

depthai_ros_msgs::msg::TrackDetection2DArray opDetectionMsg;
opDetectionMsg.header.stamp = getFrameTime(_rosBaseTime, _steadyBaseTime, tstamp);
opDetectionMsg.header.frame_id = _frameName;
opDetectionMsg.detections.resize(trackData->tracklets.size());

// publishing
for(int i = 0; i < trackData->tracklets.size(); ++i) {
dai::Tracklet t = trackData->tracklets[i];
dai::Rect roi;
float xMin, yMin, xMax, yMax;

if(_normalized)
roi = t.roi;
else
roi = t.roi.denormalize(_width, _height);

xMin = roi.topLeft().x;
yMin = roi.topLeft().y;
xMax = roi.bottomRight().x;
yMax = roi.bottomRight().y;

float xSize = xMax - xMin;
float ySize = yMax - yMin;
float xCenter = xMin + xSize / 2.;
float yCenter = yMin + ySize / 2.;

opDetectionMsg.detections[i].results.resize(1);

opDetectionMsg.detections[i].results[0].hypothesis.class_id = std::to_string(t.label);
opDetectionMsg.detections[i].results[0].hypothesis.score = _thresh;

opDetectionMsg.detections[i].bbox.center.position.x = xCenter;
opDetectionMsg.detections[i].bbox.center.position.y = yCenter;
opDetectionMsg.detections[i].bbox.size_x = xSize;
opDetectionMsg.detections[i].bbox.size_y = ySize;

opDetectionMsg.detections[i].is_tracking = true;
std::stringstream track_id_str;
track_id_str << "" << t.id;
opDetectionMsg.detections[i].tracking_id = track_id_str.str();
opDetectionMsg.detections[i].tracking_age = t.age;
opDetectionMsg.detections[i].tracking_status = static_cast<int32_t>(t.status);

// converting mm to meters since per ros rep-103 lenght should always be in meters
opDetectionMsg.detections[i].results[0].pose.pose.position.x = t.spatialCoordinates.x / 1000.0;
opDetectionMsg.detections[i].results[0].pose.pose.position.y = t.spatialCoordinates.y / 1000.0;
opDetectionMsg.detections[i].results[0].pose.pose.position.z = t.spatialCoordinates.z / 1000.0;
}

opDetectionMsgs.push_back(opDetectionMsg);
}

depthai_ros_msgs::msg::TrackDetection2DArray::SharedPtr TrackSpatialDetectionConverter::toRosMsgPtr(std::shared_ptr<dai::Tracklets> trackData) {
std::deque<depthai_ros_msgs::msg::TrackDetection2DArray> msgQueue;
toRosMsg(trackData, msgQueue);
auto msg = msgQueue.front();

depthai_ros_msgs::msg::TrackDetection2DArray::SharedPtr ptr = std::make_shared<depthai_ros_msgs::msg::TrackDetection2DArray>(msg);

return ptr;
}

} // namespace ros

} // namespace dai
6 changes: 6 additions & 0 deletions depthai_examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,15 @@ dai_add_node_ros2(feature_tracker src/feature_tracker_publisher.cpp)
dai_add_node_ros2(stereo_node src/stereo_publisher.cpp)
dai_add_node_ros2(yolov4_spatial_node src/yolov4_spatial_publisher.cpp)
dai_add_node_ros2(yolov4_node src/yolov4_publisher.cpp)
dai_add_node_ros2(tracker_yolov4_node src/tracker_yolov4_publisher.cpp)
dai_add_node_ros2(tracker_yolov4_spatial_node src/tracker_yolov4_spatial_publisher.cpp)

target_compile_definitions(mobilenet_node PRIVATE BLOB_NAME="${mobilenet_blob_name}")
target_compile_definitions(yolov4_spatial_node PRIVATE BLOB_NAME="${tiny_yolo_v4_blob_name}")
target_compile_definitions(yolov4_node PRIVATE BLOB_NAME="${tiny_yolo_v4_blob_name}")
target_compile_definitions(stereo_inertial_node PRIVATE BLOB_NAME="${tiny_yolo_v4_blob_name}")
target_compile_definitions(tracker_yolov4_node PRIVATE BLOB_NAME="${tiny_yolo_v4_blob_name}")
target_compile_definitions(tracker_yolov4_spatial_node PRIVATE BLOB_NAME="${tiny_yolo_v4_blob_name}")

if($ENV{ROS_DISTRO} STREQUAL "galactic")
target_compile_definitions(rgb_stereo_node PRIVATE IS_GALACTIC)
Expand All @@ -128,6 +132,8 @@ install(TARGETS
yolov4_spatial_node
yolov4_node
feature_tracker
tracker_yolov4_node
tracker_yolov4_spatial_node
DESTINATION lib/${PROJECT_NAME})

ament_package()
Expand Down
Loading

0 comments on commit 98d495f

Please sign in to comment.