From 8243df6c8c49a1ba605bd8d8b3d65fa939dc898b Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 6 Dec 2022 07:35:56 -0800 Subject: [PATCH] move beartype to methods instead --- imagen_pytorch/imagen_pytorch.py | 4 +++- imagen_pytorch/version.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/imagen_pytorch/imagen_pytorch.py b/imagen_pytorch/imagen_pytorch.py index 166bfab..aacb440 100644 --- a/imagen_pytorch/imagen_pytorch.py +++ b/imagen_pytorch/imagen_pytorch.py @@ -1748,7 +1748,6 @@ def __init__(self, *args, **kwargs): # main imagen ddpm class, which is a cascading DDPM from Ho et al. -@beartype class Imagen(nn.Module): def __init__( self, @@ -2167,6 +2166,7 @@ def p_sample_loop( @torch.no_grad() @eval_decorator + @beartype def sample( self, texts: List[str] = None, @@ -2326,6 +2326,7 @@ def sample( return pil_images[output_index] # now you have a bunch of pillow images you can just .save(/where/ever/you/want.png) + @beartype def p_losses( self, unet: Union[Unet, Unet3D, NullUnet, DistributedDataParallel], @@ -2453,6 +2454,7 @@ def p_losses( return losses.mean() + @beartype def forward( self, images, diff --git a/imagen_pytorch/version.py b/imagen_pytorch/version.py index bac6c20..4da6044 100644 --- a/imagen_pytorch/version.py +++ b/imagen_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.17.0' +__version__ = '1.17.1'