diff --git a/ezflow/utils/metrics.py b/ezflow/utils/metrics.py index 6cc5ec1b..b8abae73 100644 --- a/ezflow/utils/metrics.py +++ b/ezflow/utils/metrics.py @@ -1,7 +1,7 @@ import torch -def endpointerror(pred, target): +def endpointerror(pred, target, multi_magnitude=False): """ Endpoint error @@ -24,6 +24,17 @@ def endpointerror(pred, target): """Ignore valid mask for EPE calculation.""" target = target[:, :2, :, :] - epe = torch.norm(target - pred, p=2, dim=1).mean() + epe = torch.norm(pred - target, p=2, dim=1) - return epe + if not multi_magnitude: + return epe.mean().item() + + epe = epe.view(-1) + multi_magnitude_epe = { + "epe": epe.mean().item(), + "1px": (epe < 1).float().mean().item(), + "3px": (epe < 3).float().mean().item(), + "5px": (epe < 5).float().mean().item(), + } + + return multi_magnitude_epe diff --git a/tests/test_utils.py b/tests/test_utils.py index ffa871e8..6db711e4 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -9,6 +9,9 @@ def test_endpointerror(): target = torch.rand(4, 2, 256, 256) _ = endpointerror(pred, target) + multi_magnitude_epe = endpointerror(pred, target, multi_magnitude=True) + assert isinstance(multi_magnitude_epe, dict) + target = torch.rand( 4, 3, 256, 256 ) # Ignore valid mask for EPE calculation if target contains it