diff --git a/dreambooth/utils/image_utils.py b/dreambooth/utils/image_utils.py index 197088bd..6525240f 100644 --- a/dreambooth/utils/image_utils.py +++ b/dreambooth/utils/image_utils.py @@ -255,16 +255,15 @@ def process_tags(caption: str, shuffle_tags: bool, drop_p: float, skip_first: bo tags = [t.strip() for t in caption.split(',')] if skip_first: first_tag = tags.pop(0) - if shuffle_tags: - random.shuffle(tags) if drop_p > 0: # Randomly drop more tags if there are a lot of tags in the image tag_cap = random.randint(10, 30) - if len(tags) > tag_cap: - drop_p = 1 - (1 - drop_p) * tag_cap / len(tags) + drop_p = max(drop_p, 1 - tag_cap / len(tags)) tags = [t for t in tags if random.random() >= drop_p] - if skip_first: + if shuffle_tags: + random.shuffle(tags) + if skip_first and first_tag != "*" and not first_tag in tags: tags.insert(0, first_tag) output = ', '.join(tags) return output