-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrainer_utils.py
48 lines (41 loc) · 1.48 KB
/
trainer_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# Common
import torch
import numpy as np
def prob2Ent(prob):
    """Return the per-element, class-normalized entropy terms of ``prob``.

    Each element ``p`` is mapped to ``-p * log2(p) / log2(C)``, where ``C``
    is the size of dimension 1. The result has the same shape as ``prob``
    and is NOT summed. Summing it over dim 1 gives the row entropy divided
    by its maximum ``log2(C)``, i.e. a value in [0, 1] when each row is a
    probability distribution.

    Args:
        prob: 2-D tensor of shape (N, C). Presumably per-sample class
            probabilities (e.g. softmax output) -- TODO confirm with callers.

    Returns:
        Tensor with the same shape as ``prob``. All zeros when ``C <= 1``,
        because a single-class distribution carries no uncertainty.
    """
    _, c = prob.size()
    if c <= 1:
        # log2(1) == 0 would make the normalizer divide by zero (NaN).
        return torch.zeros_like(prob)
    # The 1e-30 offset keeps log2 finite where prob == 0 (in float32/64),
    # so 0 * log2(tiny) contributes 0 instead of NaN.
    return -torch.mul(prob, torch.log2(prob + 1e-30)) / np.log2(c)
def get_current_consistency_weight(weight, epoch, rampup):
    """Scale ``weight`` by the sigmoid ramp-up factor for ``epoch``.

    Implements the consistency ramp-up schedule from
    https://arxiv.org/abs/1610.02242: the factor comes from
    ``sigmoid_rampup(epoch, rampup)`` and multiplies ``weight``.
    """
    ramp_factor = sigmoid_rampup(epoch, rampup)
    return weight * ramp_factor
def sigmoid_rampup(current, rampup_length):
    """Return a sigmoid-shaped ramp value in (0, 1] for step ``current``.

    ``current`` is clipped into ``[0, rampup_length]`` and mapped to
    ``exp(-5 * (1 - current / rampup_length) ** 2)``. That gives exp(-5) at
    the start and exactly 1.0 once the ramp is complete. A ramp length of 0
    short-circuits to 1.0.
    """
    if rampup_length == 0:
        return 1.0
    clipped = np.clip(current, 0.0, rampup_length)
    remaining = 1.0 - clipped / rampup_length
    exponent = -5.0 * remaining * remaining
    return float(np.exp(exponent))
def weightmap(pred1, pred2, eps=1e-8):
    """Return ``1 - cosine_similarity`` between matching rows of two tensors.

    Args:
        pred1: 2-D tensor of shape (N, C).
        pred2: tensor with the same shape as ``pred1``.
        eps: lower bound applied to the product of the row norms, so an
            all-zero row yields a finite value (1.0) instead of NaN.

    Returns:
        Tensor of shape (N, 1). Values lie in [0, 2]: 0 for rows pointing
        in the same direction, 1 for orthogonal (or all-zero) rows.
    """
    n = pred1.size(0)
    dot = torch.sum(pred1 * pred2, 1).view(n, 1)
    norm_prod = (torch.norm(pred1, 2, 1) * torch.norm(pred2, 2, 1)).view(n, 1)
    # Without the clamp, a zero-norm row gives 0 / 0 = NaN.
    return 1.0 - dot / torch.clamp(norm_prod, min=eps)
def square_distance(src, dst):
    """
    Calculate the squared Euclidean distance between every pair of points.

    Uses the expansion
        ||s - d||^2 = sum(s**2, dim=-1) + sum(d**2, dim=-1) - 2 * s . d
    so all N x M pairs come from a single matrix multiply.
    Input:
        src: source points, [N, C]
        dst: target points, [M, C]
    Output:
        dist: per-pair squared distance, [N, M], clamped to be >= 0
    """
    N, _ = src.shape
    M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(1, 0))
    dist += torch.sum(src ** 2, -1).view(N, 1)
    dist += torch.sum(dst ** 2, -1).view(1, M)
    # Cancellation in the expansion can leave tiny negative values for
    # (near-)coincident points. A squared distance cannot be negative, and
    # a later sqrt of a negative value would yield NaN.
    return torch.clamp(dist, min=0)