From 7a8838f4c14ba9dcee537e8f0ad8966531c710b6 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Thu, 5 Aug 2021 13:02:34 -0700 Subject: [PATCH] fix image attention layernorm --- lightweight_gan/lightweight_gan.py | 12 +++++++++++- lightweight_gan/version.py | 2 +- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/lightweight_gan/lightweight_gan.py b/lightweight_gan/lightweight_gan.py index 6b8e403..e5855e7 100644 --- a/lightweight_gan/lightweight_gan.py +++ b/lightweight_gan/lightweight_gan.py @@ -156,7 +156,17 @@ def forward(self, x): fn = self.fn if random() < self.prob else self.fn_else return fn(x) -ChanNorm = partial(nn.InstanceNorm2d, affine = True) +class ChanNorm(nn.Module): + def __init__(self, dim, eps = 1e-5): + super().__init__() + self.eps = eps + self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) + self.b = nn.Parameter(torch.zeros(1, dim, 1, 1)) + + def forward(self, x): + std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt() + mean = torch.mean(x, dim = 1, keepdim = True) + return (x - mean) / (std + self.eps) * self.g + self.b class PreNorm(nn.Module): def __init__(self, dim, fn): diff --git a/lightweight_gan/version.py b/lightweight_gan/version.py index 2ff628f..1bf8d73 100644 --- a/lightweight_gan/version.py +++ b/lightweight_gan/version.py @@ -1 +1 @@ -__version__ = '0.20.3' +__version__ = '0.20.4'