forked from luizapozzobon/pytorch_mpiifacegaze
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
34 lines (29 loc) · 933 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import numpy as np
import torch
class AverageMeter:
    """Track the most recent value and a count-weighted running average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics back to zero."""
        self.val = 0  # most recent value passed to update()
        self.avg = 0  # sum / count; stays 0 until something has been counted
        self.sum = 0  # running sum of val * num
        self.count = 0  # running total of num

    def update(self, val, num=1):
        """Record ``val`` as the mean of ``num`` samples.

        Args:
            val: the latest value; it contributes ``val * num`` to the running sum.
            num: how many samples ``val`` stands for (weight); defaults to 1.
        """
        self.val = val
        self.sum += val * num
        self.count += num
        # update(val, 0) on a fresh/reset meter would otherwise divide by zero;
        # in that case the average is left unchanged.
        if self.count:
            self.avg = self.sum / self.count
def convert_to_unit_vector(angles):
    """Convert pairs of angles (radians) into 3D unit direction vectors.

    Index 0 of the last dimension is used as the elevation (pitch) term and
    index 1 as the azimuth (yaw) term of the parametrization
    ``(-cos(p) * sin(y), -sin(p), -cos(p) * cos(y))``.

    Args:
        angles: tensor whose last dimension has size 2, e.g. shape ``(N, 2)``.
            Values go straight into ``torch.sin``/``torch.cos``, so they are
            interpreted as radians.

    Returns:
        Tuple ``(x, y, z)`` of tensors, each with the leading shape of
        ``angles``, forming unit-length vectors.
    """
    pitch = angles[..., 0]
    yaw = angles[..., 1]
    cos_pitch = torch.cos(pitch)
    x = -cos_pitch * torch.sin(yaw)
    y = -torch.sin(pitch)
    # z must carry cos(pitch) like x does; using cos(yaw) twice tilts the
    # direction whenever yaw is non-zero (normalization cannot undo that).
    z = -cos_pitch * torch.cos(yaw)
    # Analytically the vector is already unit length; renormalize to absorb
    # rounding. Division is out-of-place so tensors autograd saved for the
    # squaring above are not modified in place.
    norm = torch.sqrt(x ** 2 + y ** 2 + z ** 2)
    return x / norm, y / norm, z / norm
def compute_angle_error(preds, labels):
    """Return the angle, in degrees, between predicted and target directions.

    Both inputs are converted with ``convert_to_unit_vector`` and compared via
    the arccosine of their dot product.

    Args:
        preds: predicted angle tensor in the layout ``convert_to_unit_vector``
            accepts (two angles per sample, radians).
        labels: ground-truth angle tensor in the same layout as ``preds``.

    Returns:
        Tensor of per-sample angular errors in degrees.
    """
    pred_x, pred_y, pred_z = convert_to_unit_vector(preds)
    label_x, label_y, label_z = convert_to_unit_vector(labels)
    cosine = pred_x * label_x + pred_y * label_y + pred_z * label_z
    # Floating-point rounding can push the dot product of two unit vectors
    # slightly outside [-1, 1], where acos returns NaN (e.g. for an exact
    # match); clamp back into the valid domain.
    cosine = torch.clamp(cosine, -1.0, 1.0)
    return torch.acos(cosine) * 180 / np.pi