From d39e6f09dfb882093aeb6cc2fe380fdfdc0fee4f Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Thu, 19 Nov 2020 16:53:27 -0800 Subject: [PATCH] mirror channel excitation with adaptive max|avg 1d --- lightweight_gan/lightweight_gan.py | 15 +++++++++++---- lightweight_gan/version.py | 2 +- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/lightweight_gan/lightweight_gan.py b/lightweight_gan/lightweight_gan.py index 43cbe8b..5e792c4 100644 --- a/lightweight_gan/lightweight_gan.py +++ b/lightweight_gan/lightweight_gan.py @@ -263,16 +263,23 @@ def forward(self, x): class SpatialSLE(nn.Module): def __init__(self): super().__init__() + self.avg_pool = nn.AdaptiveAvgPool1d(4) + self.max_pool = nn.AdaptiveMaxPool1d(4) + self.net = nn.Sequential( - nn.ConvTranspose2d(2, 4, 4, stride = 2, padding = 1), + nn.ConvTranspose2d(8, 4, 4, stride = 2, padding = 1), nn.LeakyReLU(0.1), nn.Conv2d(4, 1, 3, padding = 1), nn.Sigmoid() ) def forward(self, x): - pooled_max, _ = x.max(dim = 1, keepdim = True) - pooled_avg = x.mean(dim = 1, keepdim = True) - return self.net(torch.cat((pooled_max, pooled_avg), dim = 1)) + b, c, h, w = x.shape + seq = rearrange(x, 'b c h w -> b (h w) c') + pooled_avg = self.avg_pool(seq) + pooled_max = self.max_pool(seq) + pooled_seq = torch.cat((pooled_avg, pooled_max), dim = 2) + x = rearrange(pooled_seq, 'b (h w) c -> b c h w', h = h, w = w) + return self.net(x) class Generator(nn.Module): def __init__( diff --git a/lightweight_gan/version.py b/lightweight_gan/version.py index ef72cc0..4ca39e7 100644 --- a/lightweight_gan/version.py +++ b/lightweight_gan/version.py @@ -1 +1 @@ -__version__ = '0.8.1' +__version__ = '0.8.2'