Skip to content

Commit

Permalink
Merge pull request #78 from AllenNeuralDynamics/wip/bundle-adjustment
Browse files Browse the repository at this point in the history
Wip/bundle adjustment
  • Loading branch information
jsiegle authored Aug 13, 2024
2 parents 2ea0700 + be8d7af commit 0fa4522
Show file tree
Hide file tree
Showing 14 changed files with 526 additions and 135 deletions.
2 changes: 1 addition & 1 deletion parallax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import os

__version__ = "0.37.20"
__version__ = "0.37.21"

# allow multiple OpenMP instances
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
2 changes: 1 addition & 1 deletion parallax/axis_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

# Set logger name
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.setLevel(logging.WARNING)

class AxisFilter(QObject):
"""Class representing no filter."""
Expand Down
6 changes: 3 additions & 3 deletions parallax/bundle_adjustment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

# Set logger name
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.setLevel(logging.WARNING)

class BALProblem:
def __init__(self, model, file_path):
Expand Down Expand Up @@ -184,7 +184,7 @@ def optimize(self, print_result=True):
self.opt_points = opt_params[12 * n_cams:].reshape(n_pts, 3)

if print_result:
print(f"\n************** Optimization completed. **************************")
print(f"\n*********** Optimization completed **************")
# Compute initial residuals
initial_residuals = self.residuals(initial_params)
initial_residuals_sum = np.sum(initial_residuals**2)
Expand All @@ -196,7 +196,7 @@ def optimize(self, print_result=True):
opt_residuals_sum = np.sum(opt_residuals**2)
average_residual = opt_residuals_sum / len(self.bal_problem.observations)
print(f"** After BA, Average residual of reproj: {np.round(average_residual, 2)} **")
print(f"******************************************************************")
print(f"****************************************************")

logger.debug(f"Optimized camera parameters: {self.opt_camera_params}")

Expand Down
14 changes: 11 additions & 3 deletions parallax/coords_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,17 @@ def func(self, x, measured_pts, global_pts, reflect_z=False):
def avg_error(self, x, measured_pts, global_pts, reflect_z=False):
"""Calculates the total error for the optimization."""
error_values = self.func(x, measured_pts, global_pts, reflect_z)
mean_squared_error = np.mean(error_values**2)
average_error = np.sqrt(mean_squared_error)
return average_error

# Calculate the L2 error for each point
l2_errors = np.zeros(len(global_pts))
for i in range(len(global_pts)):
error_vector = error_values[i * 3: (i + 1) * 3]
l2_errors[i] = np.linalg.norm(error_vector)

# Calculate the average L2 error
average_l2_error = np.mean(l2_errors)

return average_l2_error

def fit_params(self, measured_pts, global_pts):
"""Fits parameters to minimize the error defined in func"""
Expand Down
4 changes: 4 additions & 0 deletions parallax/main_window_wip.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,3 +772,7 @@ def save_user_configs(self):
width = self.width()
height = self.height()
self.user_setting.save_user_configs(nColumn, directory, width, height)

def closeEvent(self, event):
self.model.close_all_point_meshes()
event.accept()
15 changes: 14 additions & 1 deletion parallax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
"""

from PyQt5.QtCore import QObject, pyqtSignal

from .camera import MockCamera, PySpinCamera, close_cameras, list_cameras
from .stage_listener import Stage, StageInfo

Expand All @@ -26,6 +25,9 @@ def __init__(self, version="V1", bundle_adjustment=False):
self.nMockCameras = 0
self.focos = []

# point mesh
self.point_mesh_instances = {}

# stage
self.nStages = 0
self.stages = {}
Expand Down Expand Up @@ -199,3 +201,14 @@ def save_all_camera_frames(self):
filename = 'camera%d_%s.png' % (i, camera.get_last_capture_time())
camera.save_last_image(filename)
self.msg_log.post("Saved camera frame: %s" % filename)

def add_point_mesh_instance(self, instance):
sn = instance.sn
if sn in self.point_mesh_instances.keys():
self.point_mesh_instances[sn].close()
self.point_mesh_instances[sn] = instance

def close_all_point_meshes(self):
for instance in self.point_mesh_instances.values():
instance.close()
self.point_mesh_instances.clear()
184 changes: 184 additions & 0 deletions parallax/point_mesh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
import os
import pandas as pd
import plotly.graph_objs as go
from PyQt5.QtWidgets import QWidget, QPushButton
from PyQt5.uic import loadUi
from PyQt5.QtCore import Qt
from PyQt5.QtWebEngineWidgets import QWebEngineView

package_dir = os.path.dirname(os.path.abspath(__file__))
debug_dir = os.path.join(os.path.dirname(package_dir), "debug")
ui_dir = os.path.join(os.path.dirname(package_dir), "ui")
csv_file = os.path.join(debug_dir, "points.csv")

class PointMesh(QWidget):
def __init__(self, model, file_path, sn, transM, scale, transM_BA=None, scale_BA=None, calib_completed=False):
super().__init__()
self.model = model
self.file_path = file_path
self.sn = sn
self.calib_completed = calib_completed
self.web_view = None

self.R, self.R_BA = {}, {}
self.T, self.T_BA = {}, {}
self.S, self.S_BA = {}, {}
self.points_dict = {}
self.traces = {} # Plotly trace objects
self.colors = {}
self.resizeEvent = self._on_resize

# Register this instance with the model
self.model.add_point_mesh_instance(self)

self.ui = loadUi(os.path.join(ui_dir, "point_mesh.ui"), self)
self.setWindowTitle(f"{self.sn} - Trajectory 3D View ")
self.setWindowFlags(Qt.Window | Qt.WindowMinimizeButtonHint | \
Qt.WindowMaximizeButtonHint | Qt.WindowCloseButtonHint)

self._set_transM(transM, scale)
if transM_BA is not None and scale_BA is not None and \
self.model.bundle_adjustment and self.calib_completed:
self.set_transM_BA(transM_BA, scale_BA)
self._parse_csv()
self._init_buttons()

def show(self):
self._init_ui()
self._update_canvas()
super().show() # Show the widget

def _init_ui(self):
if self.web_view is not None:
self.web_view.close()
self.web_view = QWebEngineView(self)
self.ui.verticalLayout1.addWidget(self.web_view)

def _set_transM(self, transM, scale):
self.R[self.sn] = transM[:3, :3]
self.T[self.sn] = transM[:3, 3]
self.S[self.sn] = scale[:3]

def set_transM_BA(self, transM, scale):
self.R_BA[self.sn] = transM[:3, :3]
self.T_BA[self.sn] = transM[:3, 3]
self.S_BA[self.sn] = scale[:3]

def _parse_csv(self):
self.df = pd.read_csv(self.file_path)
self.df = self.df[self.df["sn"] == self.sn] # filter by sn

self.local_pts_org = self.df[['local_x', 'local_y', 'local_z']].values
self.local_pts = self._local_to_global(self.local_pts_org, self.R[self.sn], self.T[self.sn], self.S[self.sn])
self.points_dict['local_pts'] = self.local_pts

self.global_pts = self.df[['global_x', 'global_y', 'global_z']].values
self.points_dict['global_pts'] = self.global_pts

if self.model.bundle_adjustment and self.calib_completed:
self.m_global_pts = self.df[['m_global_x', 'm_global_y', 'm_global_z']].values
self.points_dict['m_global_pts'] = self.m_global_pts

self.opt_global_pts = self.df[['opt_x', 'opt_y', 'opt_z']].values
self.points_dict['opt_global_pts'] = self.opt_global_pts

self.local_pts_BA = self._local_to_global(self.local_pts_org, self.R_BA[self.sn], self.T_BA[self.sn], self.S_BA[self.sn])
self.points_dict['local_pts_BA'] = self.local_pts_BA

# Assign unique colors to each key
color_list = ['red', 'blue', 'green', 'cyan', 'magenta']
for i, key in enumerate(self.points_dict.keys()):
self.colors[key] = color_list[i % len(color_list)]

def _local_to_global(self, local_pts, R, t, scale=None):
if scale is not None:
local_pts = local_pts * scale
global_coords_exp = R @ local_pts.T + t.reshape(-1, 1)
return global_coords_exp.T

def _init_buttons(self):
self.buttons = {}

for key in self.points_dict.keys():
button_name = self._get_button_name(key)
button = QPushButton(f'{button_name}')
button.setCheckable(True)
button.setMaximumWidth(200)
button.clicked.connect(lambda checked, key=key: self._update_plot(key, checked))
self.ui.verticalLayout2.addWidget(button)
self.buttons[key] = button

if self.model.bundle_adjustment and self.calib_completed:
keys_to_check = ['local_pts_BA', 'opt_global_pts']
else:
keys_to_check = ['local_pts', 'global_pts']

for key in keys_to_check:
self.buttons[key].setChecked(True)
self._draw_specific_points(key)

def _get_button_name(self, key):
if key == 'local_pts':
return 'stage'
elif key == 'local_pts_BA':
return 'stage (BA)'
elif key == 'global_pts':
return 'global'
elif key == 'm_global_pts':
return 'global (mean)'
elif key == 'opt_global_pts':
return 'global (BA)'
else:
return key # Default to the key if no match

def _update_plot(self, key, checked):
if checked:
self._draw_specific_points(key)
else:
self._remove_points_from_plot(key)
self._update_canvas()

def _remove_points_from_plot(self, key):
if key in self.points_dict:
del self.traces[key] # Remove from self.traces
self._update_canvas()

def _draw_specific_points(self, key):
pts = self.points_dict[key]
x_rounded = [round(x, 0) for x in pts[:, 0]]
y_rounded = [round(y, 0) for y in pts[:, 1]]
z_rounded = [round(z, 0) for z in pts[:, 2]]

scatter = go.Scatter3d(
x=x_rounded, y=y_rounded, z=z_rounded,
mode='markers+lines',
marker=dict(size=2, color=self.colors[key]),
name=self._get_button_name(key),
hoverinfo='x+y+z'
)
self.traces[key] = scatter # Store the trace in self.traces

def _update_canvas(self):
data = list(self.traces.values())
layout = go.Layout(
scene=dict(
xaxis_title='X',
yaxis_title='Y',
zaxis_title='Z'
),
margin=dict(l=0, r=0, b=0, t=0)
)
fig = go.Figure(data=data, layout=layout)
html_content = fig.to_html(include_plotlyjs='cdn')
self.web_view.setHtml(html_content)

def _on_resize(self, event):
new_size = event.size()
self.web_view.resize(new_size.width(), new_size.height())
self._update_canvas()

# Resize horizontal layout
self.ui.horizontalLayoutWidget.resize(new_size.width(), new_size.height())



Loading

0 comments on commit 0fa4522

Please sign in to comment.