utils.py

import torch
import torch.nn.functional as F

# Default device for tensors created inside these utilities.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def temporal_loss(slots):
    """Temporal consistency loss for slots of shape (batch, frames, num_slots, slot_dim)."""
    bs, nf, ns, ds = slots.shape
    # Pair each frame's slots with the next frame's slots.
    slot0 = slots[:, :-1, :, :].reshape(bs * (nf - 1), ns, ds)
    slot1 = slots[:, 1:, :, :].reshape(bs * (nf - 1), ns, ds)
    slot1 = slot1.permute(0, 2, 1)
    # Similarity of every slot at frame t to every slot at frame t + 1, softmaxed per slot.
    scores = torch.einsum('bij,bjk->bik', slot0, slot1)
    scores = scores.softmax(dim=1) + 0.0001
    # Target: each slot should match itself in the next frame (identity matrix).
    gt = torch.eye(ns).unsqueeze(0).repeat(bs * (nf - 1), 1, 1).to(device)
    return torch.mean((gt - scores) ** 2)
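
# The loss above pushes the softmaxed slot-to-slot similarity between
# consecutive frames toward the identity matrix, i.e. slot i at frame t should
# stay most similar to slot i at frame t + 1.
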

def adjusted_rand_index(true_mask, pred_mask):
    """Batched adjusted Rand index.

    true_mask: one-hot ground-truth assignments, shape (batch, n_points, n_true_groups).
    pred_mask: predicted (soft or one-hot) assignments, shape (batch, n_points, n_pred_groups).
    """
    _, n_points, n_true_groups = true_mask.shape
    n_pred_groups = pred_mask.shape[-1]
    assert not (n_points <= n_true_groups and n_points <= n_pred_groups), (
        "adjusted_rand_index requires n_groups < n_points. We don't handle the "
        "special cases that can occur when you have one cluster per datapoint.")
    true_group_ids = torch.argmax(true_mask, -1)
    pred_group_ids = torch.argmax(pred_mask, -1)
    true_mask_oh = true_mask.to(torch.float32)
    pred_mask_oh = F.one_hot(pred_group_ids, n_pred_groups).to(torch.float32)
    # Number of points per batch element (reuses the name as a float tensor).
    n_points = torch.sum(true_mask_oh, dim=[1, 2]).to(torch.float32)
    # Contingency table n_ij: co-occurrence counts of predicted vs. true groups.
    nij = torch.einsum('bji,bjk->bki', pred_mask_oh, true_mask_oh)
    a = torch.sum(nij, dim=1)
    b = torch.sum(nij, dim=2)
    rindex = torch.sum(nij * (nij - 1), dim=[1, 2])
    aindex = torch.sum(a * (a - 1), dim=1)
    bindex = torch.sum(b * (b - 1), dim=1)
    expected_rindex = aindex * bindex / (n_points * (n_points - 1))
    max_rindex = (aindex + bindex) / 2
    ari = (rindex - expected_rindex) / (max_rindex - expected_rindex + 1e-12)
    # Special case: if both segmentations put every point in one cluster, define ARI = 1.
    _all_equal = lambda values: torch.all(torch.eq(values, values[..., :1]), dim=-1)
    both_single_cluster = torch.logical_and(_all_equal(true_group_ids), _all_equal(pred_group_ids))
    return torch.where(both_single_cluster, torch.ones_like(ari), ari)
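
# A random segmentation scores approximately 0 under this index and a perfect
# match scores 1; the final torch.where returns 1 when both segmentations
# collapse to a single cluster, where the formula would otherwise be 0/0.
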

def build_grid(resolution):
    """Build a positional grid for a (height, width) resolution."""
    ranges = [torch.linspace(0.0, 1.0, steps=res) for res in resolution]
    grid = torch.meshgrid(*ranges)  # default 'ij' indexing
    grid = torch.stack(grid, dim=-1)
    grid = torch.reshape(grid, [resolution[0], resolution[1], -1])
    grid = grid.unsqueeze(0)
    return torch.cat([grid, 1.0 - grid], dim=-1)
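
# build_grid((H, W)) yields a (1, H, W, 4) tensor of normalized coordinates and
# their complements, the usual input to a soft positional embedding (how this
# repo consumes it is an assumption here).
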

def first(x):
    """Return the first element of an iterable."""
    return next(iter(x))


def only(x):
    """Return the sole element of an iterable, asserting there is exactly one."""
    materialized_x = list(x)
    assert len(materialized_x) == 1
    return materialized_x[0]
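

if __name__ == "__main__":
    # Minimal smoke test (an added sketch, not part of the original file);
    # all shapes below are illustrative assumptions.
    slots = torch.randn(2, 5, 7, 64, device=device)  # (batch, frames, slots, dim)
    print("temporal_loss:", temporal_loss(slots).item())

    true_mask = F.one_hot(torch.randint(0, 4, (2, 100)), 4).float()
    pred_mask = torch.rand(2, 100, 7).softmax(dim=-1)
    print("adjusted_rand_index:", adjusted_rand_index(true_mask, pred_mask))

    grid = build_grid((8, 8))
    print("build_grid output shape:", tuple(grid.shape))  # (1, 8, 8, 4)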