Skip to content

Commit

Permalink
make sure raw texts never include empty strings
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Sep 9, 2022
1 parent 038f300 commit e10e92e
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 1 deletion.
3 changes: 3 additions & 0 deletions imagen_pytorch/elucidated_imagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,8 @@ def sample(
cond_images = maybe(cast_uint8_images_to_float)(cond_images)

if exists(texts) and not exists(text_embeds) and not self.unconditional:
assert all([*map(len, texts)]), 'text cannot be empty'

with autocast(enabled = False):
text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)

Expand Down Expand Up @@ -723,6 +725,7 @@ def forward(
assert h >= target_image_size and w >= target_image_size

if exists(texts) and not exists(text_embeds) and not self.unconditional:
assert all([*map(len, texts)]), 'text cannot be empty'
assert len(texts) == len(images), 'number of text captions does not match up with the number of images given'

with autocast(enabled = False):
Expand Down
3 changes: 3 additions & 0 deletions imagen_pytorch/imagen_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2180,6 +2180,8 @@ def sample(
cond_images = maybe(cast_uint8_images_to_float)(cond_images)

if exists(texts) and not exists(text_embeds) and not self.unconditional:
assert all([*map(len, texts)]), 'text cannot be empty'

with autocast(enabled = False):
text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)

Expand Down Expand Up @@ -2473,6 +2475,7 @@ def forward(
times = noise_scheduler.sample_random_times(b, device = device)

if exists(texts) and not exists(text_embeds) and not self.unconditional:
assert all([*map(len, texts)]), 'text cannot be empty'
assert len(texts) == len(images), 'number of text captions does not match up with the number of images given'

with autocast(enabled = False):
Expand Down
2 changes: 1 addition & 1 deletion imagen_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.11.10'
__version__ = '1.11.11'

1 comment on commit e10e92e

@TheFusion21
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is intended to allow strings like this " " or wouldn't it be better of trimming first?

Please sign in to comment.