Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wip/bundle adjustment #78

Merged
merged 18 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion parallax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import os

__version__ = "0.37.20"
__version__ = "0.37.21"

# allow multiple OpenMP instances
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
2 changes: 1 addition & 1 deletion parallax/axis_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

# Set logger name
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.setLevel(logging.WARNING)

class AxisFilter(QObject):
"""Class representing no filter."""
Expand Down
6 changes: 3 additions & 3 deletions parallax/bundle_adjustment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

# Set logger name
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.setLevel(logging.WARNING)

class BALProblem:
def __init__(self, model, file_path):
Expand Down Expand Up @@ -184,7 +184,7 @@ def optimize(self, print_result=True):
self.opt_points = opt_params[12 * n_cams:].reshape(n_pts, 3)

if print_result:
print(f"\n************** Optimization completed. **************************")
print(f"\n*********** Optimization completed **************")
# Compute initial residuals
initial_residuals = self.residuals(initial_params)
initial_residuals_sum = np.sum(initial_residuals**2)
Expand All @@ -196,7 +196,7 @@ def optimize(self, print_result=True):
opt_residuals_sum = np.sum(opt_residuals**2)
average_residual = opt_residuals_sum / len(self.bal_problem.observations)
print(f"** After BA, Average residual of reproj: {np.round(average_residual, 2)} **")
print(f"******************************************************************")
print(f"****************************************************")

logger.debug(f"Optimized camera parameters: {self.opt_camera_params}")

Expand Down
14 changes: 11 additions & 3 deletions parallax/coords_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,17 @@ def func(self, x, measured_pts, global_pts, reflect_z=False):
def avg_error(self, x, measured_pts, global_pts, reflect_z=False):
"""Calculates the total error for the optimization."""
error_values = self.func(x, measured_pts, global_pts, reflect_z)
mean_squared_error = np.mean(error_values**2)
average_error = np.sqrt(mean_squared_error)
return average_error

# Calculate the L2 error for each point
l2_errors = np.zeros(len(global_pts))
for i in range(len(global_pts)):
error_vector = error_values[i * 3: (i + 1) * 3]
l2_errors[i] = np.linalg.norm(error_vector)

# Calculate the average L2 error
average_l2_error = np.mean(l2_errors)

return average_l2_error

def fit_params(self, measured_pts, global_pts):
"""Fits parameters to minimize the error defined in func"""
Expand Down
4 changes: 4 additions & 0 deletions parallax/main_window_wip.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,3 +772,7 @@ def save_user_configs(self):
width = self.width()
height = self.height()
self.user_setting.save_user_configs(nColumn, directory, width, height)

def closeEvent(self, event):
self.model.close_all_point_meshes()
event.accept()
15 changes: 14 additions & 1 deletion parallax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
"""

from PyQt5.QtCore import QObject, pyqtSignal

from .camera import MockCamera, PySpinCamera, close_cameras, list_cameras
from .stage_listener import Stage, StageInfo

Expand All @@ -26,6 +25,9 @@ def __init__(self, version="V1", bundle_adjustment=False):
self.nMockCameras = 0
self.focos = []

# point mesh
self.point_mesh_instances = {}

# stage
self.nStages = 0
self.stages = {}
Expand Down Expand Up @@ -199,3 +201,14 @@ def save_all_camera_frames(self):
filename = 'camera%d_%s.png' % (i, camera.get_last_capture_time())
camera.save_last_image(filename)
self.msg_log.post("Saved camera frame: %s" % filename)

def add_point_mesh_instance(self, instance):
sn = instance.sn
if sn in self.point_mesh_instances.keys():
self.point_mesh_instances[sn].close()
self.point_mesh_instances[sn] = instance

def close_all_point_meshes(self):
for instance in self.point_mesh_instances.values():
instance.close()
self.point_mesh_instances.clear()
184 changes: 184 additions & 0 deletions parallax/point_mesh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
import os
import pandas as pd
import plotly.graph_objs as go
from PyQt5.QtWidgets import QWidget, QPushButton
from PyQt5.uic import loadUi
from PyQt5.QtCore import Qt
from PyQt5.QtWebEngineWidgets import QWebEngineView

package_dir = os.path.dirname(os.path.abspath(__file__))
debug_dir = os.path.join(os.path.dirname(package_dir), "debug")
ui_dir = os.path.join(os.path.dirname(package_dir), "ui")
csv_file = os.path.join(debug_dir, "points.csv")

class PointMesh(QWidget):
def __init__(self, model, file_path, sn, transM, scale, transM_BA=None, scale_BA=None, calib_completed=False):
super().__init__()
self.model = model
self.file_path = file_path
self.sn = sn
self.calib_completed = calib_completed
self.web_view = None

self.R, self.R_BA = {}, {}
self.T, self.T_BA = {}, {}
self.S, self.S_BA = {}, {}
self.points_dict = {}
self.traces = {} # Plotly trace objects
self.colors = {}
self.resizeEvent = self._on_resize

# Register this instance with the model
self.model.add_point_mesh_instance(self)

self.ui = loadUi(os.path.join(ui_dir, "point_mesh.ui"), self)
self.setWindowTitle(f"{self.sn} - Trajectory 3D View ")
self.setWindowFlags(Qt.Window | Qt.WindowMinimizeButtonHint | \
Qt.WindowMaximizeButtonHint | Qt.WindowCloseButtonHint)

self._set_transM(transM, scale)
if transM_BA is not None and scale_BA is not None and \
self.model.bundle_adjustment and self.calib_completed:
self.set_transM_BA(transM_BA, scale_BA)
self._parse_csv()
self._init_buttons()

def show(self):
self._init_ui()
self._update_canvas()
super().show() # Show the widget

def _init_ui(self):
if self.web_view is not None:
self.web_view.close()
self.web_view = QWebEngineView(self)
self.ui.verticalLayout1.addWidget(self.web_view)

def _set_transM(self, transM, scale):
self.R[self.sn] = transM[:3, :3]
self.T[self.sn] = transM[:3, 3]
self.S[self.sn] = scale[:3]

def set_transM_BA(self, transM, scale):
self.R_BA[self.sn] = transM[:3, :3]
self.T_BA[self.sn] = transM[:3, 3]
self.S_BA[self.sn] = scale[:3]

def _parse_csv(self):
self.df = pd.read_csv(self.file_path)
self.df = self.df[self.df["sn"] == self.sn] # filter by sn

self.local_pts_org = self.df[['local_x', 'local_y', 'local_z']].values
self.local_pts = self._local_to_global(self.local_pts_org, self.R[self.sn], self.T[self.sn], self.S[self.sn])
self.points_dict['local_pts'] = self.local_pts

self.global_pts = self.df[['global_x', 'global_y', 'global_z']].values
self.points_dict['global_pts'] = self.global_pts

if self.model.bundle_adjustment and self.calib_completed:
self.m_global_pts = self.df[['m_global_x', 'm_global_y', 'm_global_z']].values
self.points_dict['m_global_pts'] = self.m_global_pts

self.opt_global_pts = self.df[['opt_x', 'opt_y', 'opt_z']].values
self.points_dict['opt_global_pts'] = self.opt_global_pts

self.local_pts_BA = self._local_to_global(self.local_pts_org, self.R_BA[self.sn], self.T_BA[self.sn], self.S_BA[self.sn])
self.points_dict['local_pts_BA'] = self.local_pts_BA

# Assign unique colors to each key
color_list = ['red', 'blue', 'green', 'cyan', 'magenta']
for i, key in enumerate(self.points_dict.keys()):
self.colors[key] = color_list[i % len(color_list)]

def _local_to_global(self, local_pts, R, t, scale=None):
if scale is not None:
local_pts = local_pts * scale
global_coords_exp = R @ local_pts.T + t.reshape(-1, 1)
return global_coords_exp.T

def _init_buttons(self):
self.buttons = {}

for key in self.points_dict.keys():
button_name = self._get_button_name(key)
button = QPushButton(f'{button_name}')
button.setCheckable(True)
button.setMaximumWidth(200)
button.clicked.connect(lambda checked, key=key: self._update_plot(key, checked))
self.ui.verticalLayout2.addWidget(button)
self.buttons[key] = button

if self.model.bundle_adjustment and self.calib_completed:
keys_to_check = ['local_pts_BA', 'opt_global_pts']
else:
keys_to_check = ['local_pts', 'global_pts']

for key in keys_to_check:
self.buttons[key].setChecked(True)
self._draw_specific_points(key)

def _get_button_name(self, key):
if key == 'local_pts':
return 'stage'
elif key == 'local_pts_BA':
return 'stage (BA)'
elif key == 'global_pts':
return 'global'
elif key == 'm_global_pts':
return 'global (mean)'
elif key == 'opt_global_pts':
return 'global (BA)'
else:
return key # Default to the key if no match

def _update_plot(self, key, checked):
if checked:
self._draw_specific_points(key)
else:
self._remove_points_from_plot(key)
self._update_canvas()

def _remove_points_from_plot(self, key):
if key in self.points_dict:
del self.traces[key] # Remove from self.traces
self._update_canvas()

def _draw_specific_points(self, key):
pts = self.points_dict[key]
x_rounded = [round(x, 0) for x in pts[:, 0]]
y_rounded = [round(y, 0) for y in pts[:, 1]]
z_rounded = [round(z, 0) for z in pts[:, 2]]

scatter = go.Scatter3d(
x=x_rounded, y=y_rounded, z=z_rounded,
mode='markers+lines',
marker=dict(size=2, color=self.colors[key]),
name=self._get_button_name(key),
hoverinfo='x+y+z'
)
self.traces[key] = scatter # Store the trace in self.traces

def _update_canvas(self):
data = list(self.traces.values())
layout = go.Layout(
scene=dict(
xaxis_title='X',
yaxis_title='Y',
zaxis_title='Z'
),
margin=dict(l=0, r=0, b=0, t=0)
)
fig = go.Figure(data=data, layout=layout)
html_content = fig.to_html(include_plotlyjs='cdn')
self.web_view.setHtml(html_content)

def _on_resize(self, event):
new_size = event.size()
self.web_view.resize(new_size.width(), new_size.height())
self._update_canvas()

# Resize horizontal layout
self.ui.horizontalLayoutWidget.resize(new_size.width(), new_size.height())



Loading
Loading