diff --git a/README.md b/README.md index 1e9f47e..1ba43c0 100644 --- a/README.md +++ b/README.md @@ -68,6 +68,64 @@ images = imagen.sample(texts = [ images.shape # (3, 3, 256, 256) ``` +With the `ImagenTrainer` wrapper class, the exponential moving averages for all of the U-nets in the cascading DDPM will be automatically taken care of when calling `update` + +```python +import torch +from imagen_pytorch import Unet, Imagen, ImagenTrainer + +# unet for imagen + +unet1 = Unet( + dim = 32, + cond_dim = 512, + channels = 3, + dim_mults=(1, 2, 4, 8) +).cuda() + +unet2 = Unet( + dim = 32, + cond_dim = 512, + channels = 3, + dim_mults=(1, 2, 4, 8) +).cuda() + +# imagen, which contains the unets above (base unet and super resoluting ones) + +imagen = Imagen( + unets = (unet1, unet2), + text_encoder_name = 't5-large', + image_sizes = (64, 256), + timesteps = 100, + cond_drop_prob = 0.5 +).cuda() + +# wrap imagen with the trainer class + +trainer = ImagenTrainer(imagen) + +# mock images (get a lot of this) and text encodings from large T5 + +text_embeds = torch.randn(4, 256, 1024).cuda() +images = torch.randn(4, 3, 256, 256).cuda() + +# feed images into imagen, training each unet in the cascade + +for i in (1, 2): + loss = trainer(images, text_embeds = text_embeds, unet_number = i) + trainer.update(unet_number = i) + +# do the above for many many many many steps +# now you can sample an image based on the text embeddings from the cascading ddpm + +images = trainer.sample(texts = [ + 'a puppy looking anxiously at a giant donut on the table', + 'the milky way galaxy in the style of monet' +], cond_scale = 2.) + +images.shape # (3, 3, 256, 256) +``` + ## Todo - [x] use huggingface transformers for T5-small text embeddings @@ -75,9 +133,9 @@ images.shape # (3, 3, 256, 256) - [x] add dynamic thresholding DALLE2 and video-diffusion repository as well - [x] allow for one to set T5-large (and perhaps small factory method to take in any huggingface transformer) - [x] add the lowres noise level with the pseudocode in appendix, and figure out what is this sweep they do at inference time +- [x] port over some training code from DALLE2 - [ ] separate unet into base unet and SR3 unet - [ ] build whatever efficient unet they came up with -- [ ] port over some training code from DALLE2 - [ ] figure out if learned variance was used at all, and remove it if it was inconsequential - [ ] switch to continuous timesteps instead of discretized, as it seems that is what they used for all stages diff --git a/imagen_pytorch/__init__.py b/imagen_pytorch/__init__.py index 9d9707b..fc3f048 100644 --- a/imagen_pytorch/__init__.py +++ b/imagen_pytorch/__init__.py @@ -1 +1,2 @@ from imagen_pytorch.imagen_pytorch import Imagen, Unet +from imagen_pytorch.trainer import ImagenTrainer diff --git a/setup.py b/setup.py index 79d7f4c..af547ff 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'imagen-pytorch', packages = find_packages(exclude=[]), - version = '0.0.10', + version = '0.0.11', license='MIT', description = 'Imagen - unprecedented photorealism × deep level of language understanding', author = 'Phil Wang', @@ -21,6 +21,7 @@ 'einops>=0.4', 'einops-exts', 'kornia', + 'numpy', 'resize-right', 'torch>=1.6', 'torchvision',