Commit

in paper, they simply used adam, no weight decay needed, with a warning not to use adafactor
lucidrains committed May 25, 2022
1 parent 0e29784 commit fd6763e
Showing 3 changed files with 62 additions and 2 deletions.
60 changes: 59 additions & 1 deletion README.md
@@ -68,16 +68,74 @@ images = imagen.sample(texts = [
images.shape # (3, 3, 256, 256)
```

The `ImagenTrainer` wrapper class automatically maintains the exponential moving averages of all the U-nets in the cascading DDPM each time you call `update`.

```python
import torch
from imagen_pytorch import Unet, Imagen, ImagenTrainer

# unet for imagen

unet1 = Unet(
    dim = 32,
    cond_dim = 512,
    channels = 3,
    dim_mults = (1, 2, 4, 8)
).cuda()

unet2 = Unet(
    dim = 32,
    cond_dim = 512,
    channels = 3,
    dim_mults = (1, 2, 4, 8)
).cuda()

# imagen, which contains the unets above (base unet and super-resolution ones)

imagen = Imagen(
    unets = (unet1, unet2),
    text_encoder_name = 't5-large',
    image_sizes = (64, 256),
    timesteps = 100,
    cond_drop_prob = 0.5
).cuda()

# wrap imagen with the trainer class

trainer = ImagenTrainer(imagen)

# mock images (get a lot of this) and text encodings from large T5

text_embeds = torch.randn(4, 256, 1024).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# feed images into imagen, training each unet in the cascade

for i in (1, 2):
    loss = trainer(images, text_embeds = text_embeds, unet_number = i)
    trainer.update(unet_number = i)

# do the above for many many many many steps
# now you can sample an image based on the text embeddings from the cascading ddpm

images = trainer.sample(texts = [
    'a puppy looking anxiously at a giant donut on the table',
    'the milky way galaxy in the style of monet'
], cond_scale = 2.)

images.shape # (2, 3, 256, 256)
```
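
For intuition about what `update` maintains, here is a minimal sketch of an exponential moving average over model weights. It is not the `ImagenTrainer` implementation: the helper name `ema_update`, the decay value `beta`, and the use of plain floats in place of tensors are all assumptions made for illustration.

```python
# Hypothetical illustration only: the real trainer averages U-net tensors,
# and its decay value and update schedule may differ from what is shown here.

def ema_update(ema_params, online_params, beta = 0.99):
    """Blend the current (online) weights into the running average, in place."""
    for name, value in online_params.items():
        ema_params[name] = beta * ema_params[name] + (1.0 - beta) * value
    return ema_params

online = {'w': 1.0, 'b': -2.0}
ema = dict(online)  # the average starts as a copy of the initial weights

# pretend an optimizer step moved the online weights
online['w'], online['b'] = 2.0, 0.0

# the averaged copy moves only a fraction (1 - beta) of the way toward them
ema_update(ema, online, beta = 0.9)
```

Sampling from the averaged copy rather than the raw weights is the usual reason to keep such an average, since it smooths out step-to-step noise in training.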

## Todo

- [x] use huggingface transformers for T5-small text embeddings
- [x] add dynamic thresholding
- [x] add dynamic thresholding to the DALLE2 and video-diffusion repositories as well
- [x] allow for one to set T5-large (and perhaps a small factory method to take in any huggingface transformer)
- [x] add the lowres noise level with the pseudocode in the appendix, and figure out what this sweep is that they do at inference time
- [x] port over some training code from DALLE2
- [ ] separate unet into base unet and SR3 unet
- [ ] build whatever efficient unet they came up with
- [ ] port over some training code from DALLE2
- [ ] figure out if learned variance was used at all, and remove it if it was inconsequential
- [ ] switch to continuous timesteps instead of discretized, as it seems that is what they used for all stages

1 change: 1 addition & 0 deletions imagen_pytorch/__init__.py
@@ -1 +1,2 @@
from imagen_pytorch.imagen_pytorch import Imagen, Unet
from imagen_pytorch.trainer import ImagenTrainer
3 changes: 2 additions & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'imagen-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.10',
version = '0.0.11',
license='MIT',
description = 'Imagen - unprecedented photorealism × deep level of language understanding',
author = 'Phil Wang',
@@ -21,6 +21,7 @@
'einops>=0.4',
'einops-exts',
'kornia',
'numpy',
'resize-right',
'torch>=1.6',
'torchvision',
