Skip to content

Commit

Permalink
move beartype to methods instead
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 6, 2022
1 parent 33483c1 commit 8243df6
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 3 additions & 1 deletion imagen_pytorch/imagen_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1748,7 +1748,6 @@ def __init__(self, *args, **kwargs):

# main imagen ddpm class, which is a cascading DDPM from Ho et al.

@beartype
class Imagen(nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -2167,6 +2166,7 @@ def p_sample_loop(

@torch.no_grad()
@eval_decorator
@beartype
def sample(
self,
texts: List[str] = None,
Expand Down Expand Up @@ -2326,6 +2326,7 @@ def sample(

return pil_images[output_index] # now you have a bunch of pillow images you can just .save(/where/ever/you/want.png)

@beartype
def p_losses(
self,
unet: Union[Unet, Unet3D, NullUnet, DistributedDataParallel],
Expand Down Expand Up @@ -2453,6 +2454,7 @@ def p_losses(

return losses.mean()

@beartype
def forward(
self,
images,
Expand Down
2 changes: 1 addition & 1 deletion imagen_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.17.0'
__version__ = '1.17.1'

0 comments on commit 8243df6

Please sign in to comment.