capsule_layer.py
import torch
import torch.nn.functional as F
from torch import nn


def softmax(input, dim=1):
    # Softmax over an arbitrary dimension: move `dim` to the last position,
    # apply softmax along that last dimension, then move it back.
    transposed_input = input.transpose(dim, len(input.size()) - 1)
    softmaxed_output = F.softmax(transposed_input.contiguous().view(-1, transposed_input.size(-1)), dim=-1)
    return softmaxed_output.view(*transposed_input.size()).transpose(dim, len(input.size()) - 1)


class CapsuleLayer(nn.Module):
    def __init__(self, num_capsules, num_route_nodes, in_channels, out_channels,
                 kernel_size=None, stride=None, num_iterations=3):
        super(CapsuleLayer, self).__init__()

        self.num_capsules = num_capsules
        self.num_route_nodes = num_route_nodes
        self.num_iterations = num_iterations

        if num_route_nodes != -1:
            # Routing between capsules
            self.route_weights = nn.Parameter(
                torch.randn(num_capsules, num_route_nodes, in_channels, out_channels))
        else:
            # Lower level is a conv net
            self.capsules = nn.ModuleList(
                [nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                           stride=stride, padding=0)
                 for _ in range(num_capsules)])

    def squash(self, tensor, dim=-1):
        # Squashing non-linearity: shrinks the vector's norm into [0, 1) while
        # preserving its direction.
        squared_norm = (tensor ** 2).sum(dim=dim, keepdim=True)
        scale = squared_norm / (1 + squared_norm)
        # The small epsilon guards against division by zero for all-zero vectors.
        return scale * tensor / torch.sqrt(squared_norm + 1e-8)

    def forward(self, x):
        if self.num_route_nodes != -1:
            # Prediction vectors (inputs * weights):
            #   x[None, :, :, None, :]          -> [1, batch, num_route_nodes, 1, in_channels]
            #   route_weights[:, None, :, :, :] -> [num_capsules, 1, num_route_nodes, in_channels, out_channels]
            #   priors                          -> [num_capsules, batch, num_route_nodes, 1, out_channels]
            priors = x[None, :, :, None, :] @ self.route_weights[:, None, :, :, :]

            # Dynamic routing by agreement. Allocate the routing logits on the
            # same device as the priors.
            logits = torch.zeros_like(priors)
            for i in range(self.num_iterations):
                probs = softmax(logits, dim=2)
                outputs = self.squash((probs * priors).sum(dim=2, keepdim=True))

                if i != self.num_iterations - 1:
                    # Strengthen the coupling to capsules whose output agrees with the prior.
                    delta_logits = (priors * outputs).sum(dim=-1, keepdim=True)
                    logits = logits + delta_logits
        else:
            # Primary capsules: apply each convolution, flatten spatially,
            # and stack the capsule outputs along the last dimension.
            outputs = [capsule(x).view(x.size(0), -1, 1) for capsule in self.capsules]
            outputs = torch.cat(outputs, dim=-1)
            outputs = self.squash(outputs)

        return outputs
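

if __name__ == '__main__':
    # Usage sketch (not part of the original layer code): wires a primary
    # capsule layer into a routing capsule layer. The concrete dimensions
    # below (28x28 single-channel inputs, a 256-channel 9x9 convolution,
    # 8-D primary capsules, ten 16-D output capsules) are assumptions for
    # illustration, not values prescribed by this file.
    conv = nn.Conv2d(1, 256, kernel_size=9)
    primary_capsules = CapsuleLayer(num_capsules=8, num_route_nodes=-1,
                                    in_channels=256, out_channels=32,
                                    kernel_size=9, stride=2)
    digit_capsules = CapsuleLayer(num_capsules=10, num_route_nodes=32 * 6 * 6,
                                  in_channels=8, out_channels=16)

    x = torch.randn(4, 1, 28, 28)
    features = F.relu(conv(x))                        # [4, 256, 20, 20]
    u = primary_capsules(features)                    # [4, 1152, 8]: 1152 primary capsule vectors
    v = digit_capsules(u).squeeze().transpose(0, 1)   # [4, 10, 16]: one 16-D vector per class
    print(u.size(), v.size())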