Skip to content

Commit

Permalink
Add numerical stability logic to arccos
Browse files Browse the repository at this point in the history
  • Loading branch information
eigenvivek committed Nov 30, 2023
1 parent f6e536e commit 79e2924
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
4 changes: 3 additions & 1 deletion diffpose/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ def forward(
r2 = pose_2.get_rotation()
rdiff = r1 @ r2.transpose(-1, -2)
trace = torch.einsum("...ii", rdiff)
return ((trace - 1) / 2).arccos()
arg = (trace - 1) / 2
arg = torch.clip(arg, -1, 1) # Ensure argument is within domain
return arg.arccos()


class GeodesicTranslation(torch.nn.Module):
Expand Down
4 changes: 3 additions & 1 deletion notebooks/api/04_metrics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,9 @@
" r2 = pose_2.get_rotation()\n",
" rdiff = r1 @ r2.transpose(-1, -2)\n",
" trace = torch.einsum(\"...ii\", rdiff)\n",
" return ((trace - 1) / 2).arccos()\n",
" arg = (trace - 1) / 2\n",
" arg = torch.clip(arg, -1, 1) # Ensure argument is within domain\n",
" return arg.arccos()\n",
"\n",
"\n",
"class GeodesicTranslation(torch.nn.Module):\n",
Expand Down

0 comments on commit 79e2924

Please sign in to comment.