Skip to content

Commit

Permalink
cleanup and add viz
Browse files Browse the repository at this point in the history
  • Loading branch information
meshvaD committed Oct 4, 2024
1 parent 39eee83 commit fc23ea3
Show file tree
Hide file tree
Showing 12 changed files with 321 additions and 9 deletions.
17 changes: 15 additions & 2 deletions src/perception/tracking/dets_2d_3d/config/nuscenes_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,22 @@ dets_2d_3d_node:
lidar_frame: LIDAR_TOP

clustering_params:
car:
default:
clustering_distances: [5., 30., 45., 60.]
clustering_thresholds: [0.5, 1.1, 1.6, 2.1, 2.6]
cluster_size_min: 20.
cluster_size_min: 5.
cluster_size_max: 100000.
cluster_merge_threshold: 1.5
person:
clustering_distances: [5., 30., 45., 60.]
clustering_thresholds: [3., 3.5, 4.5, 5.0, 5.6]
cluster_size_min: 2.
cluster_size_max: 100.
cluster_merge_threshold: 1.5
traffic_light:
clustering_distances: [5., 30., 45., 60.]
clustering_thresholds: [3., 3.5, 4.5, 5.0, 5.6]
cluster_size_min: 2.
cluster_size_max: 100.
cluster_merge_threshold: 1.5

Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class TrackingNode : public rclcpp::Node {

/// Look up `key` in `m`, falling back to the entry stored under "default".
/// Returns a value-constructed T if neither `key` nor "default" exists.
/// Takes the map by const reference — the previous by-value signature copied
/// the entire map (and key string) on every lookup, and `m["default"]`
/// default-inserted into the copy.
template <typename T>
T getDefaultOrValue(const std::map<std::string, T>& m, const std::string& key) {
  if (auto it = m.find(key); it != m.end()) return it->second;
  if (auto def = m.find("default"); def != m.end()) return def->second;
  return T{};  // same observable result as the old m["default"] on a missing key
}
};
15 changes: 9 additions & 6 deletions src/perception/tracking/dets_2d_3d/src/dets_2d_3d_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ std::map<std::string, ClusteringParams> TrackingNode::initializeClusteringParams
std::string field_type = key.substr(split_loc + 1);

ClusteringParams params;
if (auto it = clusteringParams.find(det_type); it != clusteringParams.end()) {
params = it->second;
}

if (field_type == "clustering_distances")
params.clustering_distances = iter->second.as_double_array();
else if (field_type == "clustering_thresholds")
Expand Down Expand Up @@ -94,7 +98,7 @@ void TrackingNode::readCameraInfo(const sensor_msgs::msg::CameraInfo::SharedPtr
void TrackingNode::receiveDetections(const vision_msgs::msg::Detection2DArray::SharedPtr msg) {
if (!transformInited_) {
try {
transform_ = tf_buffer_->lookupTransform(cameraFrame_, lidarFrame_, msg->header.stamp);
transform_ = tf_buffer_->lookupTransform(cameraFrame_, lidarFrame_, tf2::TimePointZero);
transformInited_ = true;
} catch (const tf2::TransformException &ex) {
RCLCPP_INFO(this->get_logger(), "Could not transform %s", ex.what());
Expand Down Expand Up @@ -147,11 +151,10 @@ void TrackingNode::receiveDetections(const vision_msgs::msg::Detection2DArray::S
}

// clustering
auto clusterAndBBoxes = ProjectionUtils::getClusteredBBoxes(
inlierPoints, getDefaultOrValue<ClusteringParams>(
clusteringParams, det.results[0].hypothesis.class_id.c_str()));
std::vector<std::shared_ptr<Cluster>> clusters =
clusterAndBBoxes.first; // needed? for viz purposes only
auto params = getDefaultOrValue<ClusteringParams>(

Check failure on line 154 in src/perception/tracking/dets_2d_3d/src/dets_2d_3d_node.cpp

View workflow job for this annotation

GitHub Actions / clang_format

src/perception/tracking/dets_2d_3d/src/dets_2d_3d_node.cpp#L154

code should be clang-formatted [-Wclang-format-violations]
clusteringParams, det.results[0].hypothesis.class_id.c_str());

Check failure on line 155 in src/perception/tracking/dets_2d_3d/src/dets_2d_3d_node.cpp

View workflow job for this annotation

GitHub Actions / clang_format

src/perception/tracking/dets_2d_3d/src/dets_2d_3d_node.cpp#L155

code should be clang-formatted [-Wclang-format-violations]
auto clusterAndBBoxes = ProjectionUtils::getClusteredBBoxes(inlierPoints, params);
std::vector<std::shared_ptr<Cluster>> clusters = clusterAndBBoxes.first; // needed? for viz purposes only

Check failure on line 157 in src/perception/tracking/dets_2d_3d/src/dets_2d_3d_node.cpp

View workflow job for this annotation

GitHub Actions / clang_format

src/perception/tracking/dets_2d_3d/src/dets_2d_3d_node.cpp#L157

code should be clang-formatted [-Wclang-format-violations]
std::vector<vision_msgs::msg::BoundingBox3D> allBBoxes = clusterAndBBoxes.second;

if (clusters.size() == 0 || allBBoxes.size() == 0) continue;
Expand Down
3 changes: 3 additions & 0 deletions src/perception/tracking/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
setuptools==65.5.1
numpy==1.24.2
scipy
8 changes: 8 additions & 0 deletions src/perception/tracking/tracking_viz/config/basic_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
tracking_viz_node:
ros__parameters:
image_topic: /annotated_img
publish_viz_topic: /annotated_3d_img
det_3d_topic: /detections_3d
camera_info_topic: /CAM_FRONT/camera_info
camera_frame: CAM_FRONT
lidar_frame: LIDAR_TOP
23 changes: 23 additions & 0 deletions src/perception/tracking/tracking_viz/launch/basic_viz.launch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from launch import LaunchDescription
from launch_ros.actions import Node
from ament_index_python.packages import get_package_share_directory
import os


def generate_launch_description():
    """Launch the tracking_viz node with its YAML parameter file.

    Returns:
        LaunchDescription containing the single tracking_viz node.
    """
    # Parameter file installed to the package's share directory by setup.py.
    config = os.path.join(
        get_package_share_directory('tracking_viz'),
        'config',
        'basic_config.yaml'
    )

    tracking_viz_node = Node(
        package='tracking_viz',
        executable='tracking_viz_node',
        name='tracking_viz_node',
        parameters=[config],
        arguments=['--ros-args', '--log-level', 'info']
    )

    # NOTE: the previous version built an unused `ld = LaunchDescription()`
    # and discarded it; removed as dead code.
    return LaunchDescription([tracking_viz_node])
24 changes: 24 additions & 0 deletions src/perception/tracking/tracking_viz/package.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>tracking_viz</name>
<version>0.0.0</version>
<description>The tracking viz package</description>
<maintainer email="[email protected]">Steven Gong</maintainer>
<license>TODO: License declaration</license>

<depend>cv_bridge</depend>
<depend>sensor_msgs</depend>
<depend>vision_msgs</depend>
<depend>std_msgs</depend>

<test_depend>ament_copyright</test_depend>
<test_depend>ament_flake8</test_depend>
<test_depend>ament_pep257</test_depend>
<test_depend>python3-pytest</test_depend>

<export>
<build_type>ament_python</build_type>
</export>

</package>
Empty file.
4 changes: 4 additions & 0 deletions src/perception/tracking/tracking_viz/setup.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[develop]
script_dir=$base/lib/tracking_viz
[install]
install_scripts=$base/lib/tracking_viz
30 changes: 30 additions & 0 deletions src/perception/tracking/tracking_viz/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from setuptools import setup
import os
from glob import glob

package_name = 'tracking_viz'

# Install targets: the ament resource-index marker, the package manifest,
# and the launch/config files under the package's share directory.
data_files = [
    ('share/ament_index/resource_index/packages',
     ['resource/' + package_name]),
    ('share/' + package_name, ['package.xml']),
    (os.path.join('share', package_name, 'launch'), glob('launch/*.launch.py')),
    (os.path.join('share', package_name, 'config'), glob('config/*.yaml')),
]

setup(
    name=package_name,
    version='0.0.0',
    packages=[package_name],
    data_files=data_files,
    install_requires=['setuptools'],
    zip_safe=True,
    maintainer='Meshva',
    maintainer_email='[email protected]',
    description='TODO: Package description',
    license='TODO: License declaration',
    tests_require=['pytest'],
    entry_points={
        # ros2 run tracking_viz tracking_viz_node -> draw_tracks.main()
        'console_scripts': [
            'tracking_viz_node = tracking_viz.draw_tracks:main'
        ],
    },
)
Empty file.
204 changes: 204 additions & 0 deletions src/perception/tracking/tracking_viz/tracking_viz/draw_tracks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
import rclpy
from rclpy.node import Node
import os

from geometry_msgs.msg import Pose
from sensor_msgs.msg import Image, CameraInfo
from vision_msgs.msg import Detection3DArray

from rclpy.qos import QoSProfile, QoSReliabilityPolicy, QoSHistoryPolicy

import tf2_geometry_msgs
from tf2_ros import TransformException
from tf2_ros.buffer import Buffer
from tf2_ros.transform_listener import TransformListener

import cv2
from cv_bridge import CvBridge, CvBridgeError
import numpy as np

import time
from collections import deque
from multiprocessing import Lock
from scipy.spatial.transform import Rotation
from random import randint

mutex = Lock()

class DrawBasicDetections(Node):
    """Projects 3D detections onto camera images for visualization.

    Subscribes to an image topic, a Detection3DArray topic, and a CameraInfo
    topic; looks up the lidar->camera transform once from TF, projects each
    3D bounding box's corners into the image plane, draws the box edges, and
    publishes the annotated image.
    """

    def __init__(self):
        super().__init__("tracking_viz_node")

        self.get_logger().info("Creating tracking viz node...")

        self.declare_parameter("image_topic", "/image")
        self.declare_parameter("publish_viz_topic", "/annotated_3d_det_img")
        self.declare_parameter("det_3d_topic", "/det_3d_topic")
        self.declare_parameter("camera_info_topic", "/camera_info")
        self.declare_parameter("camera_frame", "/camera")
        self.declare_parameter("lidar_frame", "/lidar")

        self.image_topic = self.get_parameter("image_topic").value
        self.publish_viz_topic = self.get_parameter("publish_viz_topic").value
        self.det_3d_topic = self.get_parameter("det_3d_topic").value
        self.camera_info_topic = self.get_parameter("camera_info_topic").value
        self.camera_frame = self.get_parameter("camera_frame").value
        self.lidar_frame = self.get_parameter("lidar_frame").value

        # Queues of messages not yet paired up for drawing (guarded by mutex).
        self.unprocessed_images = deque()
        self.unprocessed_dets = deque()
        self.camera_info = None  # 3x4 projection matrix, set once from CameraInfo
        self.transform = None    # lidar -> camera transform, fetched once from TF

        self.tf_buffer = Buffer()
        self.tf_listener = TransformListener(self.tf_buffer, self)

        self.cv_bridge = CvBridge()

        # All three subscriptions share the same reliable keep-last-10 QoS.
        reliable_qos = QoSProfile(
            reliability=QoSReliabilityPolicy.RELIABLE,
            history=QoSHistoryPolicy.KEEP_LAST,
            depth=10,
        )

        self.image_subscription = self.create_subscription(
            Image,
            self.image_topic,
            self.image_callback,
            qos_profile=reliable_qos,
        )

        self.det_3d_subscription = self.create_subscription(
            Detection3DArray,
            self.det_3d_topic,
            self.det_3d_callback,
            qos_profile=reliable_qos,
        )

        self.camera_info_subscription = self.create_subscription(
            CameraInfo,
            self.camera_info_topic,
            self.camera_info_callback,
            qos_profile=reliable_qos,
        )

        self.viz_publisher = self.create_publisher(Image, self.publish_viz_topic, 10)

    def image_callback(self, msg):
        """Queue an incoming image and attempt to draw one pair."""
        with mutex:
            self.unprocessed_images.append(msg)
        # try_draw acquires the (non-reentrant) lock itself, so call it
        # outside the `with` block to avoid self-deadlock.
        self.try_draw()

    def det_3d_callback(self, msg):
        """Queue incoming 3D detections, lazily fetch the lidar->camera
        transform, and attempt to draw one pair."""
        with mutex:
            self.unprocessed_dets.append(msg)

        # Fetch the lidar -> camera transform once (latest available).
        if self.transform is None:
            try:
                self.transform = self.tf_buffer.lookup_transform(
                    self.camera_frame,
                    self.lidar_frame,
                    rclpy.time.Time())
            except TransformException as ex:
                self.get_logger().info(
                    f'Could not transform from {self.lidar_frame} to {self.camera_frame}: {ex}')
                return

        self.try_draw()

    def camera_info_callback(self, msg):
        """Store the 3x4 camera projection matrix once, then unsubscribe."""
        self.camera_info = np.array(msg.p).reshape(3, 4)

        self.get_logger().info(f"GOT CAMERA INFO... {self.camera_info}")

        # Intrinsics are static; one message is enough.
        self.destroy_subscription(self.camera_info_subscription)

    def try_draw(self):
        """Pop one (image, detections) pair, draw every projected 3D box,
        and publish the annotated image. No-op until both the transform and
        camera intrinsics are available and both queues are non-empty."""
        if self.transform is None or self.camera_info is None:
            return

        # Check-then-pop must be atomic: the previous version tested the
        # queues outside the lock and popped inside, so a concurrent callback
        # could drain a queue between the check and the pop.
        with mutex:
            if not self.unprocessed_images or not self.unprocessed_dets:
                return
            image_msg = self.unprocessed_images.popleft()
            det_3d_msg = self.unprocessed_dets.popleft()

        try:
            image = self.cv_bridge.imgmsg_to_cv2(image_msg, desired_encoding="passthrough")
        except CvBridgeError as e:
            self.get_logger().error(str(e))
            return

        self.get_logger().info("PROCESSING IMAGE + DET3D...")

        for det_msg in det_3d_msg.detections:
            bbox = det_msg.bbox

            center = np.array(
                [bbox.center.position.x, bbox.center.position.y, bbox.center.position.z])
            rot = Rotation.from_quat([bbox.center.orientation.x, bbox.center.orientation.y,
                                      bbox.center.orientation.z, bbox.center.orientation.w])
            size = np.array([bbox.size.x, bbox.size.y, bbox.size.z])

            # All 8 corners of the oriented box (rotated offsets from center).
            # NOTE(review): offsets use the full `size` vector, so the drawn
            # box spans twice the message extents unless the publisher emits
            # half-extents — confirm against the 3D detection publisher.
            corner_signs = ((-1, 1, 1), (-1, -1, 1), (-1, -1, -1), (-1, 1, -1),
                            (1, 1, 1), (1, -1, 1), (1, -1, -1), (1, 1, -1))
            vert = [center + rot.apply(np.multiply(size, np.array(signs)))
                    for signs in corner_signs]

            color = (randint(0, 255), randint(0, 255), randint(0, 255))
            verts_2d = []

            # Project each 3D vertex to 2D pixel coordinates.
            for v in vert:
                v_msg = Pose()
                v_msg.position.x = v[0]
                v_msg.position.y = v[1]
                v_msg.position.z = v[2]

                # lidar frame -> camera frame
                v_trans = tf2_geometry_msgs.do_transform_pose(v_msg, self.transform)
                v_trans = np.array(
                    [v_trans.position.x, v_trans.position.y, v_trans.position.z, 1])

                # 3D camera frame -> 2D image plane (perspective divide)
                v_2d = self.camera_info @ v_trans
                v_2d = np.array([int(v_2d[0] / v_2d[2]), int(v_2d[1] / v_2d[2])])
                verts_2d.append(v_2d)

            # Draw the 12 box edges: two 4-vertex faces plus the connectors.
            for i in range(4):
                image = cv2.line(image, verts_2d[i], verts_2d[(i+1) % 4], color, 10)          # face 1
                image = cv2.line(image, verts_2d[i+4], verts_2d[(i+1) % 4 + 4], color, 10)    # face 2
                image = cv2.line(image, verts_2d[i], verts_2d[i+4], color, 10)                # connect faces

        self.publish_viz(image, image_msg)

    def publish_viz(self, cv_img, msg):
        """Publish the annotated image, preserving the source image's header."""
        imgmsg = self.cv_bridge.cv2_to_imgmsg(cv_img, "bgr8")
        imgmsg.header.stamp = msg.header.stamp
        imgmsg.header.frame_id = msg.header.frame_id
        self.viz_publisher.publish(imgmsg)

def main(args=None):
    """Entry point: create the visualization node and spin until shutdown."""
    rclpy.init(args=args)

    node = DrawBasicDetections()
    rclpy.spin(node)

    node.destroy_node()
    rclpy.shutdown()

if __name__ == "__main__":
main()

0 comments on commit fc23ea3

Please sign in to comment.