Skip to content

Commit

Permalink
offer convenient way to return list of pillow images for saving
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 25, 2022
1 parent c468b4f commit 42d87b8
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
9 changes: 7 additions & 2 deletions imagen_pytorch/imagen_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,7 +1143,8 @@ def sample(
batch_size = 1,
cond_scale = 1.,
lowres_sample_noise_level = None,
stop_at_unet_number = None
stop_at_unet_number = None,
return_pil_images = False
):
device = next(self.parameters()).device

Expand Down Expand Up @@ -1196,7 +1197,11 @@ def sample(
if exists(stop_at_unet_number) and stop_at_unet_number == unet_number:
break

return img
if not return_pil_images:
return img

pil_images = list(map(T.ToPILImage(), img.unbind(dim = 0)))
return pil_images # now you have a bunch of pillow images you can just .save(/where/ever/you/want.png)

def p_losses(self, unet, x_start, times, *, noise_scheduler, lowres_cond_img = None, lowres_aug_times = None, text_embeds = None, text_mask = None, noise = None, learned_variance = False, clip_denoised = False):
noise = default(noise, lambda: torch.randn_like(x_start))
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'imagen-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.24',
version = '0.0.25',
license='MIT',
description = 'Imagen - unprecedented photorealism × deep level of language understanding',
author = 'Phil Wang',
Expand Down

0 comments on commit 42d87b8

Please sign in to comment.