Commit
add new GAN stability measure (zero centered gp on fake images as well) out of Cornell and Brown university
lucidrains committed Jan 12, 2025
1 parent 5ff7b57 commit 72ba807
Showing 3 changed files with 30 additions and 10 deletions.
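
For context: the cited paper ("The GAN is dead; long live the GAN! A Modern GAN Baseline") regularizes the discriminator with zero-centered gradient penalties on both the real batch (R1) and the generated batch (R2), rather than on reals only. Below is a minimal standalone sketch of that idea; the toy discriminator and tensor shapes are illustrative and are not this repository's code.

```python
import torch
from torch import nn
from torch.autograd import grad as torch_grad

def zero_centered_gp(images, logits, weight = 10.):
    # squared L2 norm of d(logits)/d(images), penalized toward zero
    gradients, = torch_grad(outputs = logits.sum(), inputs = images, create_graph = True)
    return weight * gradients.reshape(gradients.shape[0], -1).square().sum(dim = 1).mean()

# toy discriminator just to make the sketch self-contained
D = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 1))

real_images = torch.randn(4, 3, 32, 32).requires_grad_()
fake_images = torch.randn(4, 3, 32, 32).requires_grad_()  # stand-in for G(z).detach()

r1 = zero_centered_gp(real_images, D(real_images))  # penalty on reals
r2 = zero_centered_gp(fake_images, D(fake_images))  # penalty on fakes (new in this commit)
disc_loss_reg = r1 + r2                              # added to the main discriminator loss
```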
9 changes: 9 additions & 0 deletions README.md
````diff
@@ -490,3 +490,12 @@ Thank you to Matthew Mann for his inspiring [simple port](https://github.com/man
 primaryClass = {cs.CV}
 }
 ```
+
+```bibtex
+@inproceedings{Huang2025TheGI,
+    title = {The GAN is dead; long live the GAN! A Modern GAN Baseline},
+    author = {Yiwen Huang and Aaron Gokaslan and Volodymyr Kuleshov and James Tompkin},
+    year = {2025},
+    url = {https://api.semanticscholar.org/CorpusID:275405495}
+}
+```
````
4 changes: 2 additions & 2 deletions setup.py
```diff
@@ -26,14 +26,14 @@
   ],
   install_requires=[
     'aim',
-    'einops>=0.7.0',
+    'einops>=0.8.0',
     'contrastive_learner>=0.1.0',
     'fire',
     'kornia>=0.5.4',
     'numpy',
     'retry',
     'tqdm',
-    'torch',
+    'torch>=2.2',
     'torchvision',
     'pillow',
     'vector-quantize-pytorch==0.1.0'
```
27 changes: 19 additions & 8 deletions stylegan2_pytorch/stylegan2_pytorch.py
```diff
@@ -232,14 +232,14 @@ def loss_backwards(fp16, loss, optimizer, loss_id, **kwargs):
     else:
         loss.backward(**kwargs)
 
-def gradient_penalty(images, output, weight = 10):
+def gradient_penalty(images, output, weight = 10, center = 0.):
     batch_size = images.shape[0]
     gradients = torch_grad(outputs=output, inputs=images,
                            grad_outputs=torch.ones(output.size(), device=images.device),
                            create_graph=True, retain_graph=True, only_inputs=True)[0]
 
     gradients = gradients.reshape(batch_size, -1)
-    return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()
+    return weight * ((gradients.norm(2, dim=1) - center) ** 2).mean()
 
 def calc_pl_lengths(styles, images):
     device = images.device
```
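
With the new `center` argument the helper covers both regimes: `center = 1.` reproduces the previous behavior (gradient norms pulled toward 1, as in WGAN-GP), while the new default `center = 0.` pushes them toward zero. A usage sketch, assuming `images` was created with gradients enabled and `output = D(images)`:

```python
one_centered = gradient_penalty(images, output, center = 1.)  # old behavior
zero_centered = gradient_penalty(images, output)              # new default, center = 0.
```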
```diff
@@ -396,15 +396,23 @@ def __init__(self, D, image_size):
         super().__init__()
         self.D = D
 
-    def forward(self, images, prob = 0., types = [], detach = False):
+    def forward(self, images, prob = 0., types = [], detach = False, return_aug_images = False, input_requires_grad = False):
         if random() < prob:
             images = random_hflip(images, prob=0.5)
             images = DiffAugment(images, types=types)
 
         if detach:
             images = images.detach()
 
-        return self.D(images)
+        if input_requires_grad:
+            images.requires_grad_()
+
+        logits = self.D(images)
+
+        if not return_aug_images:
+            return logits
+
+        return images, logits
 
 # stylegan2 classes
```
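The two new flags let the trainer recover the exact tensor the discriminator saw (after random flips and DiffAugment) and enable gradients on it only after augmentation, which is what the penalty must differentiate against. A hedged usage sketch, with `fake_images` and `aug_kwargs` as placeholder names:

```python
# requires_grad_() is applied after augmentation inside the wrapper, so the
# penalty below differentiates w.r.t. the images D actually received
aug_images, fake_logits = D_aug(
    fake_images,
    detach = True,               # block gradients into the generator
    return_aug_images = True,    # also return the augmented inputs
    input_requires_grad = True,  # enable grads for the penalty term
    **aug_kwargs
)
gp_fake = gradient_penalty(aug_images, fake_logits)
```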
```diff
@@ -1030,10 +1038,13 @@ def train(self):
             w_styles = styles_def_to_tensor(w_space)
 
             generated_images = G(w_styles, noise)
-            fake_output, fake_q_loss = D_aug(generated_images.clone().detach(), detach = True, **aug_kwargs)
+            generated_images, (fake_output, fake_q_loss) = D_aug(generated_images.clone().detach(), return_aug_images = True, input_requires_grad = apply_gradient_penalty, detach = True, **aug_kwargs)
 
             image_batch = next(self.loader).cuda(self.rank)
-            image_batch.requires_grad_()
+
+            if apply_gradient_penalty:
+                image_batch.requires_grad_()
 
             real_output, real_q_loss = D_aug(image_batch, **aug_kwargs)
 
             real_output_loss = real_output
```
```diff
@@ -1053,7 +1064,7 @@
                 disc_loss = disc_loss + quantize_loss
 
             if apply_gradient_penalty:
-                gp = gradient_penalty(image_batch, real_output)
+                gp = gradient_penalty(image_batch, real_output) + gradient_penalty(generated_images, fake_output)
                 self.last_gp_loss = gp.clone().detach().item()
                 self.track(self.last_gp_loss, 'GP')
                 disc_loss = disc_loss + gp
```
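
Spelled out with the helper's defaults made explicit, the new penalty line sums a zero-centered term on the real batch (R1) and one on the augmented fakes (R2):

```python
# equivalent to the single line above, with defaults written out
r1 = gradient_penalty(image_batch, real_output, weight = 10, center = 0.)
r2 = gradient_penalty(generated_images, fake_output, weight = 10, center = 0.)
gp = r1 + r2
```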
```diff
@@ -1382,7 +1393,7 @@ def load(self, num = -1):
         self.steps = name * self.save_every
 
-        load_data = torch.load(self.model_name(name))
+        load_data = torch.load(self.model_name(name), weights_only = True)
 
         if 'version' in load_data:
             print(f"loading from version {load_data['version']}")
```
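Separately, checkpoint loading now opts into `weights_only = True`, which restricts `torch.load` to tensors and primitive containers instead of arbitrary pickled objects; this pairs with the `torch>=2.2` floor added in setup.py. A minimal sketch with a hypothetical checkpoint path:

```python
import torch

# weights_only = True refuses to unpickle arbitrary Python objects,
# mitigating code execution from untrusted checkpoint files
# (the path below is hypothetical)
load_data = torch.load('./models/default/model_150.pt', weights_only = True)
```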
