Skip to content

Commit

Permalink
Add option of calculating EPE split by motion magnitude (#183)
Browse files Browse the repository at this point in the history
  • Loading branch information
NeelayS authored Feb 10, 2022
1 parent be1f34a commit 10d19f6
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
17 changes: 14 additions & 3 deletions ezflow/utils/metrics.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch


def endpointerror(pred, target):
def endpointerror(pred, target, multi_magnitude=False):
"""
Endpoint error
Expand All @@ -24,6 +24,17 @@ def endpointerror(pred, target):
"""Ignore valid mask for EPE calculation."""
target = target[:, :2, :, :]

epe = torch.norm(target - pred, p=2, dim=1).mean()
epe = torch.norm(pred - target, p=2, dim=1)

return epe
if not multi_magnitude:
return epe.mean().item()

epe = epe.view(-1)
multi_magnitude_epe = {
"epe": epe.mean().item(),
"1px": (epe < 1).float().mean().item(),
"3px": (epe < 3).float().mean().item(),
"5px": (epe < 5).float().mean().item(),
}

return multi_magnitude_epe
3 changes: 3 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ def test_endpointerror():
target = torch.rand(4, 2, 256, 256)
_ = endpointerror(pred, target)

multi_magnitude_epe = endpointerror(pred, target, multi_magnitude=True)
assert isinstance(multi_magnitude_epe, dict)

target = torch.rand(
4, 3, 256, 256
) # Ignore valid mask for EPE calculation if target contains it
Expand Down

0 comments on commit 10d19f6

Please sign in to comment.