From 79e29244be090748ea3eb92070796acd659ac7a4 Mon Sep 17 00:00:00 2001 From: eigenvivek Date: Thu, 30 Nov 2023 16:02:35 -0500 Subject: [PATCH] Add numerical stability logic to arccos --- diffpose/metrics.py | 4 +++- notebooks/api/04_metrics.ipynb | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/diffpose/metrics.py b/diffpose/metrics.py index 122a546..25a1ef2 100644 --- a/diffpose/metrics.py +++ b/diffpose/metrics.py @@ -87,7 +87,9 @@ def forward( r2 = pose_2.get_rotation() rdiff = r1 @ r2.transpose(-1, -2) trace = torch.einsum("...ii", rdiff) - return ((trace - 1) / 2).arccos() + arg = (trace - 1) / 2 + arg = torch.clip(arg, -1, 1) # Ensure argument is within domain + return arg.arccos() class GeodesicTranslation(torch.nn.Module): diff --git a/notebooks/api/04_metrics.ipynb b/notebooks/api/04_metrics.ipynb index c728e11..1036b15 100644 --- a/notebooks/api/04_metrics.ipynb +++ b/notebooks/api/04_metrics.ipynb @@ -201,7 +201,9 @@ " r2 = pose_2.get_rotation()\n", " rdiff = r1 @ r2.transpose(-1, -2)\n", " trace = torch.einsum(\"...ii\", rdiff)\n", - " return ((trace - 1) / 2).arccos()\n", + " arg = (trace - 1) / 2\n", + " arg = torch.clip(arg, -1, 1) # Ensure argument is within domain\n", + " return arg.arccos()\n", "\n", "\n", "class GeodesicTranslation(torch.nn.Module):\n",