Skip to content

Commit

Permalink
an extra assert to make sure images are not integer type
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Aug 27, 2022
1 parent 48361f6 commit ceb23d6
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 1 deletion.
3 changes: 3 additions & 0 deletions imagen_pytorch/elucidated_imagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
default,
cast_tuple,
cast_uint8_images_to_float,
is_float_dtype,
eval_decorator,
check_shape,
pad_tuple_to_length,
Expand Down Expand Up @@ -684,6 +685,8 @@ def forward(
images = cast_uint8_images_to_float(images)
cond_images = maybe(cast_uint8_images_to_float)(cond_images)

assert is_float_dtype(images.dtype), f'images tensor needs to be floats but {images.dtype} dtype found instead'

unet_index = unet_number - 1

unet = default(unet, lambda: self.get_unet(unet_number))
Expand Down
5 changes: 5 additions & 0 deletions imagen_pytorch/imagen_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def cast_tuple(val, length = None):

return output

def is_float_dtype(dtype):
return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])

def cast_uint8_images_to_float(images):
if not images.dtype == torch.uint8:
return images
Expand Down Expand Up @@ -2411,6 +2414,8 @@ def forward(
images = cast_uint8_images_to_float(images)
cond_images = maybe(cast_uint8_images_to_float)(cond_images)

assert is_float_dtype(images.dtype), f'images tensor needs to be floats but {images.dtype} dtype found instead'

unet_index = unet_number - 1

unet = default(unet, lambda: self.get_unet(unet_number))
Expand Down
2 changes: 1 addition & 1 deletion imagen_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.11.0'
__version__ = '1.11.1'

0 comments on commit ceb23d6

Please sign in to comment.