Skip to content

Commit

Permalink
replace hamburger with linear attention
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 21, 2020
1 parent 6efb240 commit 5756417
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 37 deletions.
4 changes: 2 additions & 2 deletions lightweight_gan/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def train_from_folder(
evaluate_every = 1000,
generate = False,
generate_interpolation = False,
hamburger_res_layers = [32],
attn_res_layers = [32],
sle_spatial = False,
disc_output_size = 1,
interpolation_num_steps = 100,
Expand All @@ -104,7 +104,7 @@ def train_from_folder(
models_dir = models_dir,
batch_size = batch_size,
gradient_accumulate_every = gradient_accumulate_every,
hamburger_res_layers = cast_list(hamburger_res_layers),
attn_res_layers = cast_list(attn_res_layers),
sle_spatial = sle_spatial,
disc_output_size = disc_output_size,
image_size = image_size,
Expand Down
68 changes: 36 additions & 32 deletions lightweight_gan/lightweight_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
from einops import rearrange
from pytorch_fid import fid_score

from hamburger_pytorch import Hamburger
from adabelief_pytorch import AdaBelief
from gsa_pytorch import GSA

# asserts

Expand Down Expand Up @@ -136,6 +136,15 @@ def forward(self, x):
fn = self.fn if random() < self.prob else self.fn_else
return fn(x)

class Rezero(nn.Module):
    """Wrap a callable and scale its output by a learnable scalar gate.

    The gate ``g`` is a zero-dimensional ``nn.Parameter`` initialised to
    ``1e-3``, so the wrapped branch starts out contributing very little and
    its weight is learned during training.
    """

    def __init__(self, fn):
        """Store the wrapped callable ``fn`` and create the scalar gate."""
        super().__init__()
        self.fn = fn
        # Small, non-zero starting value for the learnable gate.
        initial_gate = torch.tensor(1e-3)
        self.g = nn.Parameter(initial_gate)

    def forward(self, x):
        """Return ``fn(x)`` multiplied by the gate ``g``."""
        branch_out = self.fn(x)
        return self.g * branch_out

# dataset

def convert_image_to(img_type, image):
Expand Down Expand Up @@ -288,7 +297,7 @@ def __init__(
fmap_max = 512,
fmap_inverse_coef = 12,
transparent = False,
hamburger_res_layers = [],
attn_res_layers = [],
sle_spatial = False
):
super().__init__()
Expand Down Expand Up @@ -319,17 +328,14 @@ def __init__(
self.sle_map = list(filter(lambda t: t[0] <= resolution and t[1] <= resolution, self.sle_map))
self.sle_map = dict(self.sle_map)

self.num_layers_spatial_res = 4
self.num_layers_spatial_res = 2

for (res, (chan_in, chan_out)) in zip(self.res_layers, in_out_features):
image_width = 2 ** res

hamburger = None
if image_width in hamburger_res_layers:
hamburger = Hamburger(
dim = chan_in,
n = image_width ** 2
)
attn = None
if image_width in attn_res_layers:
attn = Rezero(GSA(dim = chan_in, rel_pos_length = image_width))

sle = None
if res in self.sle_map:
Expand All @@ -354,7 +360,7 @@ def __init__(
),
sle,
sle_spatial,
hamburger
attn
])
self.layers.append(layer)

Expand All @@ -368,13 +374,13 @@ def forward(self, x):
residuals = dict()
spatial_residuals = dict()

for (res, (up, sle, sle_spatial, hamburger)) in zip(self.res_layers, self.layers):
for (res, (up, sle, sle_spatial, attn)) in zip(self.res_layers, self.layers):
if exists(sle_spatial):
spatial_res = sle_spatial(x)
spatial_residuals[res + self.num_layers_spatial_res] = spatial_res

if exists(hamburger):
x = hamburger(x) + x
if exists(attn):
x = attn(x) + x

x = up(x)

Expand Down Expand Up @@ -431,7 +437,7 @@ def __init__(
fmap_inverse_coef = 12,
transparent = False,
disc_output_size = 5,
hamburger_res_layers = []
attn_res_layers = []
):
super().__init__()
resolution = log2(image_size)
Expand Down Expand Up @@ -466,15 +472,13 @@ def __init__(
))

self.residual_layers = nn.ModuleList([])

for (res, ((_, chan_in), (_, chan_out))) in zip(non_residual_resolutions, chan_in_out):
image_width = 2 ** resolution

hamburger = None
if image_width in hamburger_res_layers:
hamburger = Hamburger(
dim = chan_in,
n = image_width ** 2
)
attn = None
if image_width in attn_res_layers:
attn = Rezero(GSA(dim = chan_in, batch_norm = False, rel_pos_length = image_width))

self.residual_layers.append(nn.ModuleList([
nn.Sequential(
Expand All @@ -488,7 +492,7 @@ def __init__(
nn.Conv2d(chan_in, chan_out, 1),
nn.LeakyReLU(0.1),
),
hamburger
attn
]))

last_chan = features[-1][-1]
Expand Down Expand Up @@ -516,9 +520,9 @@ def forward(self, x, calc_aux_loss = False):

layer_outputs = []

for (layer, residual_layer, hamburger) in self.residual_layers:
if exists(hamburger):
x = hamburger(x) + x
for (layer, residual_layer, attn) in self.residual_layers:
if exists(attn):
x = attn(x) + x

x = layer(x) + residual_layer(x)
layer_outputs.append(x)
Expand Down Expand Up @@ -567,7 +571,7 @@ def __init__(
fmap_inverse_coef = 12,
transparent = False,
disc_output_size = 5,
hamburger_res_layers = [],
attn_res_layers = [],
sle_spatial = False,
ttur_mult = 1.,
lr = 2e-4,
Expand All @@ -584,7 +588,7 @@ def __init__(
fmap_max = fmap_max,
fmap_inverse_coef = fmap_inverse_coef,
transparent = transparent,
hamburger_res_layers = hamburger_res_layers,
attn_res_layers = attn_res_layers,
sle_spatial = sle_spatial
)

Expand All @@ -595,7 +599,7 @@ def __init__(
fmap_max = fmap_max,
fmap_inverse_coef = fmap_inverse_coef,
transparent = transparent,
hamburger_res_layers = hamburger_res_layers,
attn_res_layers = attn_res_layers,
disc_output_size = disc_output_size
)

Expand Down Expand Up @@ -658,7 +662,7 @@ def __init__(
batch_size = 4,
mixed_prob = 0.9,
gradient_accumulate_every = 1,
hamburger_res_layers = [],
attn_res_layers = [],
sle_spatial = False,
disc_output_size = 5,
lr = 2e-4,
Expand Down Expand Up @@ -711,7 +715,7 @@ def __init__(
self.generator_top_k_gamma = 0.99
self.generator_top_k_frac = 0.5

self.hamburger_res_layers = hamburger_res_layers
self.attn_res_layers = attn_res_layers
self.sle_spatial = sle_spatial
self.disc_output_size = disc_output_size

Expand Down Expand Up @@ -766,7 +770,7 @@ def init_GAN(self):
optimizer=self.optimizer,
lr = self.lr,
latent_dim = self.latent_dim,
hamburger_res_layers = self.hamburger_res_layers,
attn_res_layers = self.attn_res_layers,
sle_spatial = self.sle_spatial,
image_size = self.image_size,
ttur_mult = self.ttur_mult,
Expand All @@ -792,7 +796,7 @@ def load_config(self):
config = self.config() if not self.config_path.exists() else json.loads(self.config_path.read_text())
self.image_size = config['image_size']
self.transparent = config['transparent']
self.hamburger_res_layers = config['hamburger_res_layers']
self.attn_res_layers = config['attn_res_layers']
self.syncbatchnorm = config['syncbatchnorm']
self.disc_output_size = config['disc_output_size']
self.sle_spatial = config.pop('sle_spatial', False)
Expand All @@ -808,7 +812,7 @@ def config(self):
'syncbatchnorm': self.syncbatchnorm,
'disc_output_size': self.disc_output_size,
'optimizer': self.optimizer,
'hamburger_res_layers': self.hamburger_res_layers,
'attn_res_layers': self.attn_res_layers,
'sle_spatial': self.sle_spatial
}

Expand Down
2 changes: 1 addition & 1 deletion lightweight_gan/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.8.5'
__version__ = '0.9.0'
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
'generative adversarial networks'
],
install_requires=[
'adabelief-pytorch',
'einops>=0.3',
'fire',
'hamburger-pytorch',
'adabelief-pytorch',
'gsa_pytorch',
'numpy',
'pillow',
'pytorch-fid',
Expand Down

0 comments on commit 5756417

Please sign in to comment.