Skip to content

vectorize more quaternion operations #94

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/beignet/_apply_rotation_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def apply_rotation_matrix(
Rotated vectors.
"""
if inverse:
output = torch.einsum("ikj, ik -> ij", rotation, input)
output = torch.einsum("...kj, ...k -> ...j", rotation, input)
else:
output = torch.einsum("ijk, ik -> ij", rotation, input)
output = torch.einsum("...jk, ...k -> ...j", rotation, input)

return output
28 changes: 28 additions & 0 deletions src/beignet/_canonicalize_quaternion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import torch
from torch import Tensor


def canonicalize_quaternion(input: Tensor):
"""Canonicalize the input quaternion

Parameters
----------
input: Tensor, shape=(...,4)
Rotation quaternion

Returns
-------
output: Tensor, shape=(...,4)
Canonicalized quaternion

The caonicalized quaternion is chosen from :math:`{q, -q}`
such that the :math:`w` term is positive.
If the :math:`w` term is :math:`0`, then the rotation quaternion is
chosen such that the first non-zero term of the :math:`x`, :math:`y`,
and :math:`z` terms is positive.
"""

a, b, c, d = torch.unbind(input, dim=-1)
mask = (d == 0) & ((a == 0) & ((b == 0) & (c < 0) | (b < 0)) | (a < 0)) | (d < 0)

return torch.where(mask[..., None], -input, input)
45 changes: 13 additions & 32 deletions src/beignet/_compose_quaternion.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import torch
from torch import Tensor

from ._canonicalize_quaternion import canonicalize_quaternion


def compose_quaternion(
input: Tensor,
Expand Down Expand Up @@ -31,43 +33,22 @@ def compose_quaternion(
output : Tensor, shape=(..., 4)
Composed rotation quaternions.
"""
output = torch.empty(
[max(input.shape[0], other.shape[0]), 4],
dtype=input.dtype,
layout=input.layout,
device=input.device,
)

for j in range(max(input.shape[0], other.shape[0])):
a = input[j, 0]
b = input[j, 1]
c = input[j, 2]
d = input[j, 3]

p = other[j, 0]
q = other[j, 1]
r = other[j, 2]
s = other[j, 3]

t = output[j, 0]
u = output[j, 1]
v = output[j, 2]
w = output[j, 3]
a, b, c, d = torch.unbind(input, dim=-1)
p, q, r, s = torch.unbind(other, dim=-1)

output[j, 0] = d * p + s * a + b * r - c * q
output[j, 1] = d * q + s * b + c * p - a * r
output[j, 2] = d * r + s * c + a * q - b * p
output[j, 3] = d * s - a * p - b * q - c * r
t = d * p + s * a + b * r - c * q
u = d * q + s * b + c * p - a * r
v = d * r + s * c + a * q - b * p
w = d * s - a * p - b * q - c * r

x = torch.sqrt(t**2.0 + u**2.0 + v**2.0 + w**2.0)
output = torch.stack([t, u, v, w], dim=-1)

if x == 0.0:
output[j] = torch.nan
x = torch.sqrt(torch.sum(torch.square(output), dim=-1, keepdim=True))

output[j] = output[j] / x
output = torch.where(x == 0.0, torch.nan, output / x)

if canonical:
if w == 0 and (t == 0 and (u == 0 and v < 0 or u < 0) or t < 0) or w < 0:
output[j] = -output[j]
if canonical:
return canonicalize_quaternion(output)

return output
7 changes: 6 additions & 1 deletion src/beignet/_invert_quaternion.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from torch import Tensor

from ._canonicalize_quaternion import canonicalize_quaternion


def invert_quaternion(
input: Tensor,
Expand All @@ -26,6 +28,9 @@ def invert_quaternion(
inverted_quaternions : Tensor, shape (..., 4)
Inverted rotation quaternions.
"""
input[:, :3] = -input[:, :3]
input[..., :3] = -input[..., :3]

if canonical:
return canonicalize_quaternion(input)

return input
9 changes: 0 additions & 9 deletions src/beignet/_quaternion_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

def quaternion_identity(
size: int,
canonical: bool | None = False,
*,
out: Tensor | None = None,
dtype: torch.dtype | None = None,
Expand All @@ -20,14 +19,6 @@ def quaternion_identity(
size : int
Output size.

canonical : bool, optional
Whether to map the redundant double cover of rotation space to a unique
canonical single cover. If `True`, then the rotation quaternion is
chosen from :math:`{q, -q}` such that the :math:`w` term is positive.
If the :math:`w` term is :math:`0`, then the rotation quaternion is
chosen such that the first non-zero term of the :math:`x`, :math:`y`,
and :math:`z` terms is positive.

out : Tensor, optional
Output tensor. Default, `None`.

Expand Down
20 changes: 2 additions & 18 deletions src/beignet/_quaternion_magnitude.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,6 @@ def quaternion_magnitude(input: Tensor, canonical=False) -> Tensor:
output : Tensor, shape=(...)
Angles in radians. Magnitudes will be in the range :math:`[0, \pi]`.
"""
output = torch.empty(
input.shape[0],
dtype=input.dtype,
layout=input.layout,
device=input.device,
requires_grad=input.requires_grad,
)

for j in range(input.shape[0]):
a = input[j, 0]
b = input[j, 1]
c = input[j, 2]
d = input[j, 3]

x = torch.atan2(torch.sqrt(a**2 + b**2 + c**2), torch.abs(d))

output[j] = x * 2.0

return output
a, b, c, d = torch.unbind(input, dim=-1)
return 2 * torch.atan2(torch.sqrt(a**2 + b**2 + c**2), torch.abs(d))
40 changes: 13 additions & 27 deletions src/beignet/_quaternion_to_rotation_vector.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import torch
from torch import Tensor

from ._canonicalize_quaternion import canonicalize_quaternion


def quaternion_to_rotation_vector(
input: Tensor,
Expand All @@ -21,37 +23,21 @@ def quaternion_to_rotation_vector(
output : Tensor, shape=(..., 3)
Rotation vectors.
"""
output = torch.empty(
[input.shape[0], 3],
dtype=input.dtype,
layout=input.layout,
device=input.device,
)

for j in range(input.shape[0]):
a = input[j, 0]
b = input[j, 1]
c = input[j, 2]
d = input[j, 3]

if d == 0 and (a == 0 and (b == 0 and c < 0 or b < 0) or a < 0) or d < 0:
input[j] = -input[j]
input = canonicalize_quaternion(input)

t = input[j, 0] ** 2.0
u = input[j, 1] ** 2.0
v = input[j, 2] ** 2.0
w = input[j, 3] ** 1.0
a, b, c, d = torch.unbind(input, dim=-1)

y = 2.0 * torch.atan2(torch.sqrt(t + u + v), w)

if y < 0.001:
y = 2.0 + y**2.0 / 12 + 7 * y**2.0 * y**2.0 / 2880
else:
y = y / torch.sin(y / 2.0)
y = 2 * torch.atan2(
torch.sqrt(torch.square(a) + torch.square(b) + torch.square(c)), d
)
y2 = torch.square(y)

output[j] = input[j, :-1] * y
scale = torch.where(
y < 0.001, 2.0 + y2 / 12 + 7 * y2 * y2 / 2880, y / torch.sin(y / 2.0)
)

if degrees:
output = torch.rad2deg(output)
scale = torch.rad2deg(scale)

return output
return scale[..., None] * input[..., :-1]
4 changes: 3 additions & 1 deletion tests/beignet/structure/test__dockq.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@ def test_dockq_batched(dockq_test_data_path):
lambda x: torch.unbind(x, dim=0), results
)

assert results1 == pytest.approx(results2, abs=1e-6)
assert optree.tree_map(lambda x: x.item(), results1) == pytest.approx(
optree.tree_map(lambda x: x.item(), results2), abs=1e-6, rel=1e-6
)
assert results1["model_contacts"].item() == ref["best_result"]["BC"]["model_total"]
assert results1["native_contacts"].item() == ref["best_result"]["BC"]["nat_total"]
assert results1["shared_contacts"].item() == ref["best_result"]["BC"]["nat_correct"]
Expand Down
5 changes: 1 addition & 4 deletions tests/beignet/test__quaternion_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,12 @@ def _strategy(function):

rotation = Rotation.identity(size)

canonical = function(hypothesis.strategies.booleans())

return (
{
"size": size,
"canonical": canonical,
"dtype": torch.float64,
},
torch.from_numpy(rotation.as_quat(canonical)),
torch.from_numpy(rotation.as_quat()),
)


Expand Down
Loading