PX4-OpticalFlow impl
dakejahl committed Jan 7, 2025
1 parent 5f8c483 commit e5b6194
Showing 6 changed files with 103 additions and 171 deletions.
2 changes: 1 addition & 1 deletion Tools/simulation/gz
20 changes: 20 additions & 0 deletions src/modules/simulation/gz_plugin/CMakeLists.txt
@@ -9,6 +9,22 @@ gz_find_package(gz-sensors8 REQUIRED)
gz_find_package(gz-transport12 REQUIRED)
find_package(OpenCV REQUIRED)

include(ExternalProject)

ExternalProject_Add(OpticalFlow
GIT_REPOSITORY https://github.com/PX4/PX4-OpticalFlow.git
GIT_TAG master
PREFIX ${CMAKE_BINARY_DIR}/OpticalFlow
INSTALL_DIR ${CMAKE_BINARY_DIR}/OpticalFlow/install
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=<INSTALL_DIR>
BUILD_BYPRODUCTS ${CMAKE_BINARY_DIR}/OpticalFlow/install/lib/libOpticalFlow.so
)

ExternalProject_Get_Property(OpticalFlow install_dir)

set(OpticalFlow_INCLUDE_DIRS ${install_dir}/include)
set(OpticalFlow_LIBS ${install_dir}/lib/libOpticalFlow.so)

add_library(${PROJECT_NAME} SHARED
OpticalFlow.cpp
OpticalFlowSystem.cpp
@@ -22,11 +38,15 @@ target_link_libraries(${PROJECT_NAME}
PUBLIC gz-sim8::gz-sim8
PUBLIC gz-transport12::gz-transport12
PUBLIC ${OpenCV_LIBS}
PUBLIC ${OpticalFlow_LIBS}
# PUBLIC ${PROTOBUF_LIBRARIES}
)

target_include_directories(${PROJECT_NAME}
PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}
PUBLIC ${CMAKE_CURRENT_BINARY_DIR}
PUBLIC ${OpenCV_INCLUDE_DIRS}
PUBLIC ${OpticalFlow_INCLUDE_DIRS}
)

add_dependencies(${PROJECT_NAME} OpticalFlow)
202 changes: 62 additions & 140 deletions src/modules/simulation/gz_plugin/OpticalFlow.cpp
@@ -7,7 +7,7 @@

using namespace custom;

bool OpticalFlow::Load(const sdf::Sensor &_sdf)
bool OpticalFlowSensor::Load(const sdf::Sensor &_sdf)
{
auto type = gz::sensors::customType(_sdf);
if ("optical_flow" != type) {
@@ -17,188 +17,110 @@ bool OpticalFlow::Load(const sdf::Sensor &_sdf)

gz::sensors::Sensor::Load(_sdf);

this->pub = this->node.Advertise<sensor_msgs::msgs::OpticalFlow>(this->Topic());
_publisher = _node.Advertise<sensor_msgs::msgs::OpticalFlow>(this->Topic());
gzdbg << "Advertising optical flow data on: " << this->Topic() << std::endl;

// Get camera topic from our sensor config
auto elem = _sdf.Element();
auto opticalFlowElem = elem->GetElement("gz:optical_flow");
auto cameraTopic = opticalFlowElem->Get<std::string>("camera_topic");
auto camera_topic = opticalFlowElem->Get<std::string>("camera_topic");

std::string topic;
int image_width = 0;
int image_height = 0;
int update_rate = 0;
float hfov = 0;

// Get FOV from the actual camera sensor's config
auto cameraElem = elem->GetParent()->GetElement("sensor");
while (cameraElem) {
if (cameraElem->Get<std::string>("name") == "flow_camera") {
auto camera = cameraElem->GetElement("camera");
this->horizontal_fov = camera->GetElement("horizontal_fov")->Get<double>();
this->vertical_fov = this->horizontal_fov * 0.75; // Assume 4:3 aspect ratio
auto sensorElem = elem->GetParent()->GetElement("sensor");
while (sensorElem) {
if (sensorElem->Get<std::string>("name") == "flow_camera") {

auto cameraElem = sensorElem->GetElement("camera");
update_rate = sensorElem->GetElement("update_rate")->Get<int>();
hfov = cameraElem->GetElement("horizontal_fov")->Get<double>();

auto imageElem = cameraElem->GetElement("image");
image_width = imageElem->GetElement("width")->Get<int>();
image_height = imageElem->GetElement("height")->Get<int>();
break;
}
cameraElem = cameraElem->GetNextElement("sensor");
sensorElem = sensorElem->GetNextElement("sensor");
}

gzdbg << "Using camera FOV - horizontal: " << this->horizontal_fov
<< " vertical: " << this->vertical_fov << std::endl;
gzdbg << "image_width: " << image_width << std::endl;
gzdbg << "image_height: " << image_height << std::endl;
gzdbg << "update_rate: " << update_rate << std::endl;
gzdbg << "hfov: " << hfov << std::endl;

// Subscribe to camera
gzdbg << "Subscribing to camera topic: " << cameraTopic << std::endl;
if (!this->node.Subscribe(cameraTopic, &OpticalFlow::OnImage, this)) {
gzerr << "Failed to subscribe to camera topic: " << cameraTopic << std::endl;
gzdbg << "Subscribing to camera topic: " << camera_topic << std::endl;
if (!_node.Subscribe(camera_topic, &OpticalFlowSensor::OnImage, this)) {
gzerr << "Failed to subscribe to camera topic: " << camera_topic << std::endl;
return false;
}

this->lastUpdateTime = std::chrono::steady_clock::now();
// TODO: get from sdf
float focal_length = (image_width / 2.0f) / tan(hfov / 2.0f);

// Create OpticalFlow
_optical_flow = std::make_shared<OpticalFlowOpenCV>(focal_length, focal_length, update_rate, image_width, image_height);

return true;
}
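
The focal-length expression above is the standard pinhole-camera relation: half the image width divided by tan(hfov / 2). A minimal sketch of the arithmetic, using an assumed 640-pixel image width and a 1.74 rad horizontal FOV (placeholder values, not taken from this commit):

#include <cmath>
#include <cstdio>

int main()
{
	// Assumed example values; the plugin reads these from the camera SDF instead.
	const int image_width = 640;
	const float hfov = 1.74f; // radians

	// Pinhole model: the focal length (in pixels) that maps half the image
	// width onto half the horizontal field of view.
	const float focal_length = (image_width / 2.0f) / std::tan(hfov / 2.0f);

	std::printf("focal_length = %.1f px\n", focal_length); // ~270 px
	return 0;
}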

void OpticalFlow::OnImage(const gz::msgs::Image &_msg)
void OpticalFlowSensor::OnImage(const gz::msgs::Image &image_msg)
{
if (_msg.width() == 0 || _msg.height() == 0) {
if (image_msg.width() == 0 || image_msg.height() == 0) {
gzerr << "Invalid image dimensions" << std::endl;
return;
}

cv::Mat frame;
if (image_msg.pixel_format_type() == gz::msgs::PixelFormatType::RGB_INT8) {
cv::Mat temp(image_msg.height(), image_msg.width(), CV_8UC3);
std::memcpy(temp.data, image_msg.data().c_str(), image_msg.data().size());
cv::cvtColor(temp, _last_image_gray, cv::COLOR_RGB2GRAY);

} else if (image_msg.pixel_format_type() == gz::msgs::PixelFormatType::L_INT8) {
std::memcpy(_last_image_gray.data, image_msg.data().c_str(), image_msg.data().size());

// Convert image to grayscale
if (_msg.pixel_format_type() == gz::msgs::PixelFormatType::RGB_INT8) {
frame = cv::Mat(_msg.height(), _msg.width(), CV_8UC3, (void *)_msg.data().c_str());
cv::cvtColor(frame, frame, cv::COLOR_RGB2GRAY);
} else if (_msg.pixel_format_type() == gz::msgs::PixelFormatType::L_INT8) {
frame = cv::Mat(_msg.height(), _msg.width(), CV_8UC1, (void *)_msg.data().c_str());
} else {
gzerr << "Unsupported image format" << std::endl;
return;
}

// Preprocess image
// cv::GaussianBlur(frame, frame, this->blur_size, this->blur_sigma);

// // Scale down for performance
// cv::Mat scaled_frame;
// cv::resize(frame, scaled_frame, cv::Size(), this->scale_factor, this->scale_factor);

// ProcessFlow(scaled_frame);
ProcessFlow(frame);
}

void OpticalFlow::ProcessFlow(const cv::Mat &current_frame)
{
if (!flow_initialized) {
current_frame.copyTo(prevFrame);
flow_initialized = true;
return;
}

// Detect features in previous frame
std::vector<cv::Point2f> current_points;
std::vector<uchar> status;
std::vector<float> err;

if (prev_points.empty()) {
cv::goodFeaturesToTrack(prevFrame, prev_points, max_corners, quality_level, min_distance);
}

if (prev_points.empty()) {
gzwarn << "No features detected in previous frame" << std::endl;
current_frame.copyTo(prevFrame);
return;
}

// Calculate optical flow
cv::calcOpticalFlowPyrLK(prevFrame, current_frame, prev_points, current_points, status, err);

// Filter valid points and calculate flow
std::vector<cv::Point2f> good_old, good_new;
for (size_t i = 0; i < status.size(); i++) {
if (status[i]) {
good_old.push_back(prev_points[i]);
good_new.push_back(current_points[i]);
}
}

if (good_new.empty() || good_old.empty()) {
gzwarn << "No valid flow vectors" << std::endl;
quality = 0;
current_frame.copyTo(prevFrame);
prev_points.clear();
return;
}

// Calculate average flow
cv::Point2f mean_flow(0, 0);
for (size_t i = 0; i < good_new.size(); i++) {
mean_flow += good_new[i] - good_old[i];
}
mean_flow = mean_flow * (1.0f / good_new.size());

// Convert to radians using FOV and resolution
// double rad_per_pixel_x = horizontal_fov / (current_frame.cols / double(scale_factor));
// double rad_per_pixel_y = vertical_fov / (current_frame.rows / double(scale_factor));
double rad_per_pixel_x = horizontal_fov / current_frame.cols;
double rad_per_pixel_y = vertical_fov / current_frame.rows;

integrated_x = (double)mean_flow.x * rad_per_pixel_x;
integrated_y = (double)mean_flow.y * rad_per_pixel_y;

// Calculate quality metric
std::vector<float> flow_magnitudes;
for (size_t i = 0; i < good_new.size(); i++) {
cv::Point2f flow = good_new[i] - good_old[i];
flow_magnitudes.push_back(cv::norm(flow));
}
// Store current timestamp for integration time calculation
uint32_t current_timestamp = (image_msg.header().stamp().sec() * 1000000ULL +
image_msg.header().stamp().nsec() / 1000ULL) & 0xFFFFFFFF;

float avg_magnitude = 0;
if (!flow_magnitudes.empty()) {
avg_magnitude = std::accumulate(flow_magnitudes.begin(), flow_magnitudes.end(), 0.0f)
/ flow_magnitudes.size();
if (_last_image_timestamp != 0) {
_integration_time_us = (current_timestamp - _last_image_timestamp) & 0xFFFFFFFF;
}

// Compute quality based on flow consistency and magnitude
float std_dev = 0;
for (float mag : flow_magnitudes) {
std_dev += (mag - avg_magnitude) * (mag - avg_magnitude);
}
std_dev = std_dev > 0 ? sqrt(std_dev / flow_magnitudes.size()) : 0;

// Higher quality when flow is consistent (low std_dev) and has reasonable magnitude
quality = std::min(255.0f, (avg_magnitude * 100.0f) / (std_dev + 1.0f));

// Check for excessive motion
if (std::abs(integrated_x) > M_PI_2 || std::abs(integrated_y) > M_PI_2) {
gzwarn << "Excessive motion detected" << std::endl;
quality = 0;
}

// Update state for next iteration
current_frame.copyTo(prevFrame);
prev_points = good_new; // Use current good points for next iteration

flow_updated = true;
_last_image_timestamp = current_timestamp;
_new_image_available = true;
}
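
The timestamp handling above truncates the image header time to 32-bit microseconds and later subtracts two such stamps; because the arithmetic is unsigned, the difference stays correct even when the counter wraps. A small self-contained illustration of that wrap behaviour, with example values chosen to straddle the 32-bit boundary (not taken from this commit):

#include <cstdint>
#include <cstdio>

// Truncate a sec/nsec stamp to 32-bit microseconds, mirroring OnImage() above.
static uint32_t to_usec32(uint64_t sec, uint64_t nsec)
{
	return static_cast<uint32_t>((sec * 1000000ULL + nsec / 1000ULL) & 0xFFFFFFFF);
}

int main()
{
	// Two frames 25 ms apart; the second one falls just past the 32-bit wrap point.
	const uint32_t prev = to_usec32(4294, 957296000ULL); // 4'294'957'296 us
	const uint32_t curr = to_usec32(4294, 982296000ULL); // wraps around to 15'000 us
	const uint32_t integration_time_us = curr - prev;    // unsigned math -> 25'000 us
	std::printf("integration_time_us = %u\n", integration_time_us);
	return 0;
}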

bool OpticalFlow::Update(const std::chrono::steady_clock::duration &_now)
bool OpticalFlowSensor::Update(const std::chrono::steady_clock::duration &_now)
{
if (!flow_updated) {
if (!_new_image_available) {
return true;
}

auto currentTime = std::chrono::steady_clock::now();
auto deltaTime = std::chrono::duration_cast<std::chrono::microseconds>(
currentTime - this->lastUpdateTime);

sensor_msgs::msgs::OpticalFlow msg;
msg.set_time_usec(std::chrono::duration_cast<std::chrono::microseconds>(_now).count());
msg.set_integration_time_us(deltaTime.count());
msg.set_integrated_x(this->integrated_x);
msg.set_integrated_y(this->integrated_y);
msg.set_quality(this->quality);
msg.set_time_usec(_last_image_timestamp);

int quality = _optical_flow->calcFlow(_last_image_gray.data, _last_image_timestamp, _integration_time_us, _flow_x, _flow_y);

msg.set_integrated_x(_flow_x);
msg.set_integrated_y(_flow_y);
msg.set_integration_time_us(_integration_time_us);
msg.set_quality(quality);

if (!this->pub.Publish(msg)) {
if (!_publisher.Publish(msg)) {
gzwarn << "Failed to publish optical flow message" << std::endl;
}

flow_updated = false;
this->lastUpdateTime = currentTime;
_new_image_available = false;
return true;
}
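
For orientation, a condensed standalone sketch of how the plugin drives the PX4-OpticalFlow library, using only the OpticalFlowOpenCV constructor and calcFlow() calls that appear in the diff above; the image size, update rate, and FOV are placeholder values, and the exact signatures should be checked against flow_opencv.hpp:

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <memory>

#include <opencv2/opencv.hpp>
#include "flow_opencv.hpp" // from PX4/PX4-OpticalFlow

int main()
{
	// Placeholder camera parameters; the plugin reads these from the SDF instead.
	const int width = 640, height = 480, update_rate = 30;
	const float hfov = 1.74f;
	const float focal = (width / 2.0f) / std::tan(hfov / 2.0f);

	auto flow = std::make_shared<OpticalFlowOpenCV>(focal, focal, update_rate, width, height);

	cv::Mat gray(height, width, CV_8UC1, cv::Scalar(0)); // grayscale frame, as prepared in OnImage()
	const uint32_t timestamp_us = 33333;                  // image timestamp in microseconds
	int dt_us = 33333;                                    // integration time, as computed in OnImage()
	float flow_x = 0.0f, flow_y = 0.0f;

	// Returns a quality value; flow_x / flow_y receive the integrated flow.
	const int quality = flow->calcFlow(gray.data, timestamp_us, dt_us, flow_x, flow_y);
	std::printf("quality = %d, flow_x = %f, flow_y = %f\n", quality, flow_x, flow_y);
	return 0;
}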
46 changes: 18 additions & 28 deletions src/modules/simulation/gz_plugin/OpticalFlow.hpp
@@ -8,47 +8,37 @@
#include <opencv2/opencv.hpp>
#include <numeric>
#include <vector>
#include <memory>

#include "flow_opencv.hpp"

namespace custom
{
class OpticalFlow : public gz::sensors::Sensor
class OpticalFlowSensor : public gz::sensors::Sensor
{
public:
virtual bool Load(const sdf::Sensor &_sdf) override;
virtual bool Update(const std::chrono::steady_clock::duration &_now) override;

private:
void OnImage(const gz::msgs::Image &_msg);
void ProcessFlow(const cv::Mat &current_frame);

cv::Mat prevFrame;
gz::transport::Node node;
gz::transport::Node::Publisher pub;

// Camera parameters
double horizontal_fov{0.79}; // Default FOV in radians
double vertical_fov{0.6}; // Default FOV in radians

// Flow computation parameters
const int max_corners{100};
const double quality_level{0.3};
const double min_distance{7.0};
gz::transport::Node _node;
gz::transport::Node::Publisher _publisher;

// Flow state
double integrated_x{0.0};
double integrated_y{0.0};
double quality{0.0};
std::chrono::steady_clock::time_point lastUpdateTime;
// Flow
std::shared_ptr<OpticalFlowOpenCV> _optical_flow {nullptr};
float _flow_x {0.0f};
float _flow_y {0.0f};
int _integration_time_us;

// Image processing parameters
const cv::Size blur_size{5, 5};
const double blur_sigma{1.5};
const float scale_factor{0.5}; // Scale image down for performance
// Camera
double _horizontal_fov {0.0};
double _vertical_fov {0.0};

bool flow_updated{false};

// Previous points for optical flow
std::vector<cv::Point2f> prev_points;
bool flow_initialized{false};
cv::Mat _last_image_gray;
uint32_t _last_image_timestamp {0};
bool _new_image_available {false};
};

} // end namespace custom
2 changes: 1 addition & 1 deletion src/modules/simulation/gz_plugin/OpticalFlowSystem.cpp
@@ -37,7 +37,7 @@ void OpticalFlowSystem::PreUpdate(const gz::sim::UpdateInfo &, gz::sim::EntityCo
}

gz::sensors::SensorFactory sensorFactory;
auto sensor = sensorFactory.CreateSensor<custom::OpticalFlow>(data);
auto sensor = sensorFactory.CreateSensor<custom::OpticalFlowSensor>(data);

if (sensor == nullptr) {
gzerr << "Failed to create optical flow sensor [" << sensorScopedName << "]" << std::endl;
2 changes: 1 addition & 1 deletion src/modules/simulation/gz_plugin/OpticalFlowSystem.hpp
@@ -21,6 +21,6 @@ class OpticalFlowSystem:
private:
void RemoveSensorEntities(const gz::sim::EntityComponentManager &_ecm);

std::unordered_map<gz::sim::Entity, std::shared_ptr<OpticalFlow>> entitySensorMap;
std::unordered_map<gz::sim::Entity, std::shared_ptr<OpticalFlowSensor>> entitySensorMap;
};
} // end namespace custom
