Skip to content

Commit

Permalink
Merge pull request #108 from decargroup/handle_2d_plot_poses
Browse files Browse the repository at this point in the history
Add basic 2D plotting case to plot_poses function
  • Loading branch information
CharlesCossette authored Jan 7, 2024
2 parents 357133c + f9b0ea8 commit 35b81c6
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 51 deletions.
145 changes: 94 additions & 51 deletions navlie/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,17 +858,21 @@ def plot_poses(
step: int = 5,
label: str = None,
linewidth=None,
plot_2d: bool =False,
):
"""
Plots position trajectory in 3D
and poses along the trajectory as triads.
Plots a pose trajectory, representing the attitudes by triads
plotted along the trajectory.
The poses may be either elements of SE(2),
representing planar 2D poses, or elements of SE(3), representing 3D poses.
Parameters
----------
poses : List[SE3State]
A list objects containing a ``position`` property (numpy array of size 3)
and an ``attitude`` (3 x 3 numpy array) property representing the rotation
matrix :math:``\mathbf{C}_{ab}``.
poses : List[Union[SE2State, SE3State]]
A list objects containing a ``position`` property and an attitude
property, representing the rotation matrix :math:``\mathbf{C}_{ab}``.
Can either be 2D or 3D poses.
ax : plt.Axes, optional
Axes to plot on, if none, 3D axes are created.
line_color : str, optional
Expand All @@ -881,23 +885,36 @@ def plot_poses(
Step size in list of poses, by default 5. If None, no triads are plotted.
label : str, optional
Optional label for the triad
plot_2d: bool, optional
Flag to plot a 3D pose trajectory in 2D bird's eye view.
"""
# TODO. handle 2D case
if isinstance(poses, GaussianResultList):
poses = poses.state

if isinstance(poses, StateWithCovariance):
poses = [poses.state]

if isinstance(poses, np.ndarray):
poses = poses.tolist()

if not isinstance(poses, list):
poses = [poses]


# Check if poses are in 2D or 3D
if poses[0].position.size == 2:
plot_2d = True

# Check if provided axes are in 3D
if ax is not None:
if ax.name == "3d":
plot_2d = False

if ax is None:
fig = plt.figure()
ax = plt.axes(projection="3d")
if plot_2d:
ax = plt.axes()
else:
ax = plt.axes(projection="3d")
else:
fig = ax.get_figure()

Expand All @@ -908,51 +925,77 @@ def plot_poses(

# Plot a line for the positions
r = np.array([pose.position for pose in poses])
ax.plot3D(r[:, 0], r[:, 1], r[:, 2], color=line_color, label=label)
if plot_2d:
ax.plot(r[:, 0], r[:, 1], color=line_color, label=label)
else:
ax.plot3D(r[:, 0], r[:, 1], r[:, 2], color=line_color, label=label)

# Plot triads using quiver
if step is not None:
C = np.array([poses[i].attitude.T for i in range(0, len(poses), step)])
r = np.array([poses[i].position for i in range(0, len(poses), step)])
x, y, z = r[:, 0], r[:, 1], r[:, 2]
ax.quiver(
x,
y,
z,
C[:, 0, 0],
C[:, 0, 1],
C[:, 0, 2],
color=colors[0],
length=arrow_length,
arrow_length_ratio=0.1,
linewidths=linewidth,
)
ax.quiver(
x,
y,
z,
C[:, 1, 0],
C[:, 1, 1],
C[:, 1, 2],
color=colors[1],
length=arrow_length,
arrow_length_ratio=0.1,
linewidths=linewidth,
)
ax.quiver(
x,
y,
z,
C[:, 2, 0],
C[:, 2, 1],
C[:, 2, 2],
color=colors[2],
length=arrow_length,
arrow_length_ratio=0.1,
linewidths=linewidth,
)
if plot_2d:
x, y = r[:, 0], r[:, 1]
ax.quiver(
x, y,
C[:, 0, 0],
C[:, 0, 1],
color=colors[0],
scale=20.0,
headwidth=2,
)

ax.quiver(
x, y,
C[:, 1, 0],
C[:, 1, 1],
color=colors[1],
scale=20.0,
headwidth=2,
)
else:
x, y, z = r[:, 0], r[:, 1], r[:, 2]
ax.quiver(
x,
y,
z,
C[:, 0, 0],
C[:, 0, 1],
C[:, 0, 2],
color=colors[0],
length=arrow_length,
arrow_length_ratio=0.1,
linewidths=linewidth,
)
ax.quiver(
x,
y,
z,
C[:, 1, 0],
C[:, 1, 1],
C[:, 1, 2],
color=colors[1],
length=arrow_length,
arrow_length_ratio=0.1,
linewidths=linewidth,
)
ax.quiver(
x,
y,
z,
C[:, 2, 0],
C[:, 2, 1],
C[:, 2, 2],
color=colors[2],
length=arrow_length,
arrow_length_ratio=0.1,
linewidths=linewidth,
)

set_axes_equal(ax)
if plot_2d:
ax.axis("equal")
else:
set_axes_equal(ax)
return fig, ax


Expand Down
51 changes: 51 additions & 0 deletions tests/unit/test_plot_poses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import numpy as np
import navlie as nav
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style("whitegrid")

def test_plot_poses_3d():
N = 500
x0 = nav.lib.SE3State([0.3, 3, 4, 0, 0, 0], direction="right")
process_model = nav.lib.BodyFrameVelocity(np.zeros(6))

dt = 0.1
T = 20
stamps = np.arange(0, T, dt)

x_traj = [x0.copy()]
u = nav.lib.VectorInput([0.1, 0.3, 0, 1, 0, 0])
x = x0.copy()
for _ in stamps:
x = process_model.evaluate(x, u, dt)
x_traj.append(x.copy())

# Plot the trajectory in 3D
fig, ax = nav.plot_poses(x_traj)

# Plot SE(3) poses in 2D
fig, ax2 = nav.plot_poses(x_traj, plot_2d=True)

def test_plot_poses_2d():
x0 = nav.lib.SE2State([0.3, 3, 0], direction="right")
process_model = nav.lib.BodyFrameVelocity(np.zeros(3))

dt = 0.1
T = 50
stamps = np.arange(0, T, dt)

x_traj = [x0.copy()]
u = nav.lib.VectorInput([0.1, 0.3, 0])
x = x0.copy()
for _ in stamps:
x = process_model.evaluate(x, u, dt)
x_traj.append(x.copy())

# Test plotting SE(2) poses
fig, ax = nav.plot_poses(x_traj)

if __name__ == "__main__":
test_plot_poses_3d()
test_plot_poses_2d()
plt.show()

0 comments on commit 35b81c6

Please sign in to comment.