forked from zalandoresearch/SWARM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpooling.py
112 lines (71 loc) · 2.87 KB
/
pooling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import numpy as np
import torch
import torch.nn as nn
class Pooling(nn.Module):
    """Abstract base class for SWARM pooling layers.

    Sub-classes map a per-entity feature tensor to a pooled tensor with
    ``n_out`` channels, optionally respecting an entity mask.

    Args:
        n_in: number of input channels per entity.
        n_out: number of output channels per entity.
        n_dim: number of spatial/entity dimensions, must be 1 or 2.
    """

    def __init__(self, n_in, n_out, n_dim):
        super().__init__()
        self.n_in = n_in
        self.n_out = n_out
        self.n_dim = n_dim
        # Only 1-D (N, C, E) and 2-D (N, C, E1, E2) layouts are supported.
        assert self.n_dim == 1 or self.n_dim == 2

    def forward(self, x, mask):
        # x is (N, n_in, E) or (N, n_in, E1, E2)
        # mask is (N, E) or (N, E1, E2)
        # Fixed typo in the error message ("bas class" -> "base class").
        raise NotImplementedError("Pooling is only an abstract base class")
class Mean(Pooling):
    """Mean pooling: every entity receives the (optionally masked) mean of
    all entities, broadcast back to the input's spatial layout.

    Channel count is unchanged, so ``n_in`` must equal ``n_out``.
    """

    def __init__(self, n_in, n_out, n_dim):
        super().__init__(n_in, n_out, n_dim)
        # Averaging cannot change the number of channels.
        assert n_in == n_out

    def forward(self, x, mask=None):
        # x: (N, n_in, E) or (N, n_in, E1, E2)
        # mask (optional): (N, E) or (N, E1, E2)
        shape = x.size()
        dims = 2 if self.n_dim == 1 else (2, 3)
        if mask is None:
            # Plain mean over the spatial dimensions, broadcast back.
            pooled = x.mean(dim=dims, keepdim=True).expand(shape)
        else:
            # Masked mean: sum only the unmasked entities, then divide by
            # the number of unmasked entities.
            m = mask.view((shape[0], 1, *shape[2:])).float()
            total = (x * m).sum(dim=dims, keepdim=True).expand(shape)
            count = m.sum(dim=dims, keepdim=True).expand(shape)
            pooled = total / count
        return pooled.view(shape)
class Causal(Pooling):
    """Causal (prefix) mean pooling: entity ``i`` receives the running mean
    of entities ``0..i`` in flattened spatial order, so no entity sees its
    successors. Channel count is unchanged.
    """

    def __init__(self, n_in, n_out, n_dim):
        super().__init__(n_in, n_out, n_dim)
        # Averaging cannot change the number of channels.
        assert n_in == n_out

    def forward(self, x, mask=None):
        if mask is not None:
            raise NotImplementedError("Causal pooling is not yet implemented for masked input!")
        shape = x.size()
        # 1. flatten all spatial dimensions into one entity axis
        flat = x.view((shape[0], self.n_in, -1))
        # 2. prefix sums divided by the running count 1..n give causal means
        n_entities = np.prod(shape[2:])
        denom = (torch.arange(n_entities, device=flat.device).float() + 1.0).view(1, 1, -1)
        prefix_mean = torch.cumsum(flat, dim=2) / denom
        # 3. restore the original spatial layout
        return prefix_mean.view(shape)
class PoolingMaps(Pooling):
    """Attention-style pooling.

    The channel axis of the input is split into three groups:
    ``[n_out content | n_slices read-logits | n_slices write-weights]``.
    For each slice, content is pooled with a softmax over the read-logits
    and redistributed to every entity scaled by its write-weight; the
    slices are then summed.
    """

    def __init__(self, n_in, n_slices, n_dim):
        # 2*n_slices input channels are consumed as attention maps, so only
        # n_in - 2*n_slices content channels remain in the output.
        super().__init__(n_in, n_in - 2 * n_slices, n_dim)
        self.n_slices = n_slices

    def forward(self, x, mask=None):
        # x: (N, n_in, E...) — channel axis holds content, then read-logits,
        # then write-weights (see class docstring).
        assert x.size(1) == self.n_in
        split = self.n_in - 2 * self.n_slices
        content = x[:, :split]
        logits = x[:, split:-self.n_slices]
        weights = x[:, -self.n_slices:]
        if mask is not None:
            # log(0) = -inf removes masked entities from the softmax.
            logits = logits + torch.log(mask.unsqueeze(1).float())
        # Softmax over all flattened spatial positions, per slice.
        flat = logits.view(logits.size(0), logits.size(1), -1)
        attn = torch.softmax(flat, dim=2).view(logits.size())
        # (N, n_slices, n_out, E...): attention-weighted content per slice.
        weighted = content.unsqueeze(1) * attn.unsqueeze(2)
        # NOTE(review): dim=3 reduces only the first spatial axis; this is
        # correct for n_dim == 1 but looks unverified for n_dim == 2.
        pooled = weighted.sum(dim=3, keepdim=True)  # (N, n_slices, n_out, 1)
        # Redistribute to entities via the write-weights, summing slices.
        out = (pooled * weights.unsqueeze(2)).sum(dim=1)
        return out