diff --git a/src/beignet/nn/__init__.py b/src/beignet/nn/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/beignet/nn/functional/__init__.py b/src/beignet/nn/functional/__init__.py new file mode 100644 index 0000000000..4d7a256e54 --- /dev/null +++ b/src/beignet/nn/functional/__init__.py @@ -0,0 +1,5 @@ +from ._torsion_angle_loss import torsion_angle_loss + +__all__ = [ + "torsion_angle_loss", +] diff --git a/src/beignet/nn/functional/_torsion_angle_loss.py b/src/beignet/nn/functional/_torsion_angle_loss.py new file mode 100644 index 0000000000..b9c4d333e6 --- /dev/null +++ b/src/beignet/nn/functional/_torsion_angle_loss.py @@ -0,0 +1,36 @@ +from typing import Tuple + +import torch +from torch import Tensor + + +def torsion_angle_loss(input, target: Tuple[Tensor, Tensor]) -> Tensor: + """ + + Parameters + ---------- + input + target + + Returns + ------- + + """ + a = input / torch.norm(input, dim=-1, keepdim=True) + + b, c = target + + x = torch.mean( + torch.minimum( + torch.square(torch.norm(a - b, dim=-1)), + torch.square(torch.norm(a - c, dim=-1)), + ), + dim=[-1, -2], + ) + + y = torch.mean( + torch.abs(torch.norm(input, dim=-1) - 1), + dim=[-1, -2], + ) + + return x + 0.02 * y diff --git a/tests/beignet/nn/functional/test__torsion_angle_loss.py b/tests/beignet/nn/functional/test__torsion_angle_loss.py new file mode 100644 index 0000000000..e513675676 --- /dev/null +++ b/tests/beignet/nn/functional/test__torsion_angle_loss.py @@ -0,0 +1,17 @@ +import beignet.nn.functional +import torch + + +def test_torsion_angle_loss(): + input = torch.ones([1, 1, 7, 2]) + + target = torch.zeros([1, 1, 7, 2]), torch.zeros([1, 1, 7, 2]) + + output = beignet.nn.functional.torsion_angle_loss(input, target) + + torch.testing.assert_close( + output, + torch.tensor([1.0]), + rtol=0.01, + atol=0.01, + )