Skip to content

Commit

Permalink
Merge pull request #9 from lucidrains/pw/add-spatial-excitation-residual
Browse files Browse the repository at this point in the history
Pw/add spatial excitation residual
  • Loading branch information
lucidrains authored Nov 20, 2020
2 parents fa3bb6b + d39e6f0 commit 9fcfb7e
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 9 deletions.
28 changes: 20 additions & 8 deletions lightweight_gan/lightweight_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,29 +245,41 @@ def __init__(
chan_out
):
super().__init__()
self.avg_pool = nn.AdaptiveAvgPool2d((4, 4))
self.max_pool = nn.AdaptiveMaxPool2d((4, 4))

chan_intermediate = chan_in // 2
self.net = nn.Sequential(
nn.AdaptiveAvgPool2d((4, 4)),
nn.Conv2d(chan_in, chan_out, 4),
nn.Conv2d(chan_in * 2, chan_intermediate, 4),
nn.LeakyReLU(0.1),
nn.Conv2d(chan_out, chan_out, 1),
nn.Conv2d(chan_intermediate, chan_out, 1),
nn.Sigmoid()
)
def forward(self, x):
return self.net(x)
pooled_avg = self.avg_pool(x)
pooled_max = self.max_pool(x)
return self.net(torch.cat((pooled_max, pooled_avg), dim = 1))

class SpatialSLE(nn.Module):
def __init__(self):
super().__init__()
self.avg_pool = nn.AdaptiveAvgPool1d(4)
self.max_pool = nn.AdaptiveMaxPool1d(4)

self.net = nn.Sequential(
nn.ConvTranspose2d(2, 4, 4, stride = 2, padding = 1),
nn.ConvTranspose2d(8, 4, 4, stride = 2, padding = 1),
nn.LeakyReLU(0.1),
nn.Conv2d(4, 1, 3, padding = 1),
nn.Sigmoid()
)
def forward(self, x):
max_pool, _ = x.max(dim = 1, keepdim = True)
avg_pool = x.mean(dim = 1, keepdim = True)
return self.net(torch.cat((max_pool, avg_pool), dim = 1))
b, c, h, w = x.shape
seq = rearrange(x, 'b c h w -> b (h w) c')
pooled_avg = self.avg_pool(seq)
pooled_max = self.max_pool(seq)
pooled_seq = torch.cat((pooled_avg, pooled_max), dim = 2)
x = rearrange(pooled_seq, 'b (h w) c -> b c h w', h = h, w = w)
return self.net(x)

class Generator(nn.Module):
def __init__(
Expand Down
2 changes: 1 addition & 1 deletion lightweight_gan/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.8.0'
__version__ = '0.8.2'

0 comments on commit 9fcfb7e

Please sign in to comment.