
Commit e275632
add frequency channel attention
lucidrains committed Jan 16, 2021
1 parent eba115b commit e275632
Showing 4 changed files with 86 additions and 10 deletions.
README.md — 10 additions, 0 deletions
@@ -241,6 +241,16 @@ If you want the current state of the art GAN, you can find it at https://github.
}
```

```bibtex
@misc{qin2020fcanet,
title={FcaNet: Frequency Channel Attention Networks},
author={Zequn Qin and Pengyi Zhang and Fei Wu and Xi Li},
year={2020},
eprint={2012.11879},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```

```bibtex
@misc{sinha2020topk,
lightweight_gan/cli.py — 2 additions, 0 deletions
@@ -95,6 +95,7 @@ def train_from_folder(
generate_interpolation = False,
aug_test = False,
attn_res_layers = [32],
freq_chan_attn = False,
disc_output_size = 1,
antialias = False,
interpolation_num_steps = 100,
@@ -120,6 +121,7 @@
batch_size = batch_size,
gradient_accumulate_every = gradient_accumulate_every,
attn_res_layers = cast_list(attn_res_layers),
freq_chan_attn = freq_chan_attn,
disc_output_size = disc_output_size,
antialias = antialias,
image_size = image_size,
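The new flag is plumbed straight from the CLI into the trainer. As an illustrative usage sketch (my example, not part of this diff; it assumes `Trainer` is exported at the package top level, and uses only the constructor kwargs and the `set_data_src` method visible elsewhere in this file set):

```python
# hypothetical usage of the new option; parameter values are made up
from lightweight_gan import Trainer

trainer = Trainer(
    image_size = 512,
    attn_res_layers = [32],   # existing attention option
    freq_chan_attn = True     # new: FcaNet-style skip-layer excitation
)
trainer.set_data_src('./path/to/images')
```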
lightweight_gan/lightweight_gan.py — 73 additions, 9 deletions
@@ -28,7 +28,7 @@
from lightweight_gan.version import __version__

from tqdm import tqdm
from einops import rearrange
from einops import rearrange, reduce

from adabelief_pytorch import AdaBelief
from gsa_pytorch import GSA
@@ -289,10 +289,12 @@ def forward(self, images, prob = 0., types = [], detach = False, **kwargs):
def upsample(scale_factor = 2):
return nn.Upsample(scale_factor = scale_factor)

# classes
# squeeze excitation classes

# global context network
# https://arxiv.org/abs/2012.13375
# similar to squeeze-excite, but with a simplified attention pooling and a subsequent layer norm

class GlobalContext(nn.Module):
def __init__(
self,
@@ -317,6 +319,52 @@ def forward(self, x):
out = out.unsqueeze(-1)
return self.net(out)
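The class body is collapsed in this view; as a minimal sketch of the "simplified attention pooling" the comment refers to (an assumption based on the linked GCNet paper, not a verbatim excerpt of this class): a 1×1 convolution scores every spatial position, a softmax over positions turns the scores into pooling weights, and the weighted sum stands in for global average pooling before the excitation net.

```python
import torch
from torch import nn, einsum

def attention_pool(x, to_logit):
    # x: (b, c, h, w); to_logit: 1x1 conv mapping c channels -> 1 attention logit
    attn = to_logit(x).flatten(2).softmax(dim = -1)             # (b, 1, h*w) pooling weights
    out = einsum('b i n, b c n -> b c i', attn, x.flatten(2))   # attention-weighted sum over positions
    return out.unsqueeze(-1)                                    # (b, c, 1, 1), same shape as avg-pool

pooled = attention_pool(torch.randn(2, 64, 16, 16), nn.Conv2d(64, 1, 1))  # -> (2, 64, 1, 1)
```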

# frequency channel attention
# https://arxiv.org/abs/2012.11879

def get_1d_dct(i, freq, L):
result = math.cos(math.pi * freq * (i + 0.5) / L) / math.sqrt(L)
return result * (1 if freq == 0 else math.sqrt(2))

def get_dct_weights(width, channel, fidx_u, fidx_v):
dct_weights = torch.zeros(1, channel, width, width)
c_part = channel // len(fidx_u)

for i, (u_x, v_y) in enumerate(zip(fidx_u, fidx_v)):
for x in range(width):
for y in range(width):
coor_value = get_1d_dct(x, u_x, width) * get_1d_dct(y, v_y, width)
dct_weights[:, i * c_part: (i + 1) * c_part, x, y] = coor_value

return dct_weights
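An illustrative check (not part of the commit, run against the two functions above) of what `get_dct_weights` precomputes: one 2D DCT basis filter per channel group, shaped `(1, channel, width, width)`. The frequency-(0, 0) filter is constant, i.e. it reduces to the plain global average pooling that ordinary squeeze-excite uses.

```python
import torch

# two frequency pairs split across 16 channels: (0, 0) for the first 8, (1, 2) for the rest
w = get_dct_weights(width = 8, channel = 16, fidx_u = [0, 1], fidx_v = [0, 2])
print(w.shape)                                                     # torch.Size([1, 16, 8, 8])
print(torch.allclose(w[0, 0], w[0, 0, 0, 0] * torch.ones(8, 8)))   # True: (0, 0) is a constant filter
```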

class FCANet(nn.Module):
def __init__(
self,
*,
chan_in,
chan_out,
reduction = 4,
width
):
super().__init__()
freq_w, freq_h = ([0] * 8), list(range(8)) # in the paper, 16 frequencies seemed ideal
dct_weights = get_dct_weights(width, chan_in, [*freq_w, *freq_h], [*freq_h, *freq_w])
self.register_buffer('dct_weights', dct_weights)

self.net = nn.Sequential(
nn.Conv2d(chan_in, chan_out // reduction, 1),
nn.LeakyReLU(0.1),
nn.Conv2d(chan_out // reduction, chan_out, 1),
nn.Sigmoid()
)

def forward(self, x):
x = reduce(x * self.dct_weights, 'b c (h h1) (w w1) -> b c h1 w1', 'sum', h1 = 1, w1 = 1)
return self.net(x)
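A shape-level usage sketch (my example; the channel sizes are made up): FCANet pools each channel with its assigned DCT filter instead of a plain mean, then maps the pooled vector to per-channel sigmoid gates, the same interface GlobalContext exposes, so either can modulate the higher-resolution skip features.

```python
import torch

fca = FCANet(chan_in = 256, chan_out = 128, width = 16)
feats = torch.randn(4, 256, 16, 16)   # resolution must match `width` of the DCT filters
gates = fca(feats)                    # (4, 128, 1, 1), values in (0, 1)
skip = torch.randn(4, 128, 128, 128)  # hypothetical higher-resolution features to excite
out = skip * gates                    # broadcast channel-wise gating
```

Note that `chan_in` should be divisible by the 16 frequency pairs, since channels are split evenly across them; in the Generator below, `width = 2 ** (res + 1)` matches the DCT filter size to the resolution of the feature map feeding each skip-layer excitation.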

# generative adversarial network

class Generator(nn.Module):
def __init__(
self,
@@ -327,7 +375,8 @@ def __init__(
fmap_inverse_coef = 12,
transparent = False,
greyscale = False,
attn_res_layers = []
attn_res_layers = [],
freq_chan_attn = False
):
super().__init__()
resolution = log2(image_size)
@@ -378,10 +427,17 @@ def __init__(
residual_layer = self.sle_map[res]
sle_chan_out = self.res_to_feature_map[residual_layer - 1][-1]

sle = GlobalContext(
chan_in = chan_out,
chan_out = sle_chan_out
)
if freq_chan_attn:
sle = FCANet(
chan_in = chan_out,
chan_out = sle_chan_out,
width = 2 ** (res + 1)
)
else:
sle = GlobalContext(
chan_in = chan_out,
chan_out = sle_chan_out
)

layer = nn.ModuleList([
nn.Sequential(
@@ -636,6 +692,7 @@ def __init__(
greyscale = False,
disc_output_size = 5,
attn_res_layers = [],
freq_chan_attn = False,
ttur_mult = 1.,
lr = 2e-4,
rank = 0,
@@ -652,7 +709,8 @@ def __init__(
fmap_inverse_coef = fmap_inverse_coef,
transparent = transparent,
greyscale = greyscale,
attn_res_layers = attn_res_layers
attn_res_layers = attn_res_layers,
freq_chan_attn = freq_chan_attn
)

self.G = Generator(**G_kwargs)
@@ -729,6 +787,7 @@ def __init__(
gp_weight = 10,
gradient_accumulate_every = 1,
attn_res_layers = [],
freq_chan_attn = False,
disc_output_size = 5,
antialias = False,
lr = 2e-4,
@@ -796,6 +855,8 @@ def __init__(
self.generator_top_k_frac = 0.5

self.attn_res_layers = attn_res_layers
self.freq_chan_attn = freq_chan_attn

self.disc_output_size = disc_output_size
self.antialias = antialias

@@ -860,6 +921,7 @@ def init_GAN(self):
lr = self.lr,
latent_dim = self.latent_dim,
attn_res_layers = self.attn_res_layers,
freq_chan_attn = self.freq_chan_attn,
image_size = self.image_size,
ttur_mult = self.ttur_mult,
fmap_max = self.fmap_max,
@@ -889,6 +951,7 @@ def load_config(self):
self.disc_output_size = config['disc_output_size']
self.greyscale = config.pop('greyscale', False)
self.attn_res_layers = config.pop('attn_res_layers', [])
self.freq_chan_attn = config.pop('freq_chan_attn', False)
self.optimizer = config.pop('optimizer', 'adam')
self.fmap_max = config.pop('fmap_max', 512)
del self.GAN
@@ -902,7 +965,8 @@ def config(self):
'syncbatchnorm': self.syncbatchnorm,
'disc_output_size': self.disc_output_size,
'optimizer': self.optimizer,
'attn_res_layers': self.attn_res_layers
'attn_res_layers': self.attn_res_layers,
'freq_chan_attn': self.freq_chan_attn
}

def set_data_src(self, folder):
lightweight_gan/version.py — 1 addition, 1 deletion
@@ -1 +1 @@
__version__ = '0.16.4'
__version__ = '0.17.0'
