Skip to content

Commit

Permalink
mirror channel excitation with adaptive max|avg 1d
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 20, 2020
1 parent 5eb9eb5 commit d39e6f0
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
15 changes: 11 additions & 4 deletions lightweight_gan/lightweight_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,16 +263,23 @@ def forward(self, x):
class SpatialSLE(nn.Module):
def __init__(self):
super().__init__()
self.avg_pool = nn.AdaptiveAvgPool1d(4)
self.max_pool = nn.AdaptiveMaxPool1d(4)

self.net = nn.Sequential(
nn.ConvTranspose2d(2, 4, 4, stride = 2, padding = 1),
nn.ConvTranspose2d(8, 4, 4, stride = 2, padding = 1),
nn.LeakyReLU(0.1),
nn.Conv2d(4, 1, 3, padding = 1),
nn.Sigmoid()
)
def forward(self, x):
pooled_max, _ = x.max(dim = 1, keepdim = True)
pooled_avg = x.mean(dim = 1, keepdim = True)
return self.net(torch.cat((pooled_max, pooled_avg), dim = 1))
b, c, h, w = x.shape
seq = rearrange(x, 'b c h w -> b (h w) c')
pooled_avg = self.avg_pool(seq)
pooled_max = self.max_pool(seq)
pooled_seq = torch.cat((pooled_avg, pooled_max), dim = 2)
x = rearrange(pooled_seq, 'b (h w) c -> b c h w', h = h, w = w)
return self.net(x)

class Generator(nn.Module):
def __init__(
Expand Down
2 changes: 1 addition & 1 deletion lightweight_gan/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.8.1'
__version__ = '0.8.2'

0 comments on commit d39e6f0

Please sign in to comment.