From 9d88b7cdcbe19cc52ee383e319e2b91a01522bd9 Mon Sep 17 00:00:00 2001 From: Ross Morgan-Linial Date: Fri, 3 May 2024 10:31:07 -0700 Subject: [PATCH] Additional tag processing options Includes randomly dropping tags, and whether to treat the first tag specially. --- dreambooth/dataclasses/db_config.py | 2 ++ dreambooth/dataset/db_dataset.py | 20 +++++++++------ dreambooth/utils/gen_utils.py | 2 ++ dreambooth/utils/image_utils.py | 20 +++++++++------ index.html | 25 +++++++++++++++---- javascript/dreambooth.js | 2 ++ scripts/main.py | 11 ++++++++ .../defaults/dreambooth_model_config.json | 2 ++ templates/locales/titles_en.json | 10 ++++++++ 9 files changed, 73 insertions(+), 21 deletions(-) diff --git a/dreambooth/dataclasses/db_config.py b/dreambooth/dataclasses/db_config.py index a4ccc543..d0392dee 100644 --- a/dreambooth/dataclasses/db_config.py +++ b/dreambooth/dataclasses/db_config.py @@ -115,6 +115,8 @@ class DreamboothConfig(BaseModel): scheduler: str = "ddim" shared_diffusers_path: str = "" shuffle_tags: bool = True + drop_tags: float = 0.0 + skip_first_tag: bool = True snapshot: str = "" split_loss: bool = True src: str = "" diff --git a/dreambooth/dataset/db_dataset.py b/dreambooth/dataset/db_dataset.py index 116cf6ca..2e567e73 100644 --- a/dreambooth/dataset/db_dataset.py +++ b/dreambooth/dataset/db_dataset.py @@ -14,7 +14,7 @@ from dreambooth.dataclasses.prompt_data import PromptData from dreambooth.shared import status from dreambooth.utils.image_utils import make_bucket_resolutions, \ - closest_resolution, shuffle_tags, open_and_trim + closest_resolution, process_tags, open_and_trim from dreambooth.utils.text_utils import build_strict_tokens from helpers.mytqdm import mytqdm @@ -38,6 +38,8 @@ def __init__( resolution: int, hflip: bool, do_shuffle_tags: bool, + drop_tags: float, + skip_first_tag: bool, strict_tokens: bool, dynamic_img_norm: bool, not_pad_tokens: bool, @@ -93,6 +95,8 @@ def __init__( self.resolution = resolution self.debug_dataset = debug_dataset self.shuffle_tags = do_shuffle_tags + self.drop_tags = drop_tags + self.skip_first_tag = skip_first_tag self.not_pad_tokens = not_pad_tokens self.strict_tokens = strict_tokens self.dynamic_img_norm = dynamic_img_norm @@ -210,8 +214,8 @@ def encode_prompt(self, prompt): bs_embed = None # default declaration auto_add_special_tokens = False if self.strict_tokens else True - if self.shuffle_tags: - prompt = shuffle_tags(prompt) + if self.shuffle_tags or self.drop_tags > 0: + prompt = process_tags(prompt, self.shuffle_tags, self.drop_tags, self.skip_first_tag) for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders): if self.strict_tokens: prompt = build_strict_tokens(prompt, tokenizer.bos_token, tokenizer.eos_token) @@ -294,7 +298,7 @@ def load_image(self, image_path, caption, res): else: img = open_and_trim(image_path, res, False) image = self.image_transform(img) - if self.shuffle_tags: + if self.shuffle_tags or self.drop_tags > 0: caption, input_ids = self.cache_caption(image_path, caption) else: input_ids = self.data_cache["captions"][image_path] @@ -312,8 +316,8 @@ def cache_caption(self, image_path, caption): input_ids = None auto_add_special_tokens = False if self.strict_tokens else True if len(self.tokenizers) > 0 and (image_path not in self.data_cache["captions"] or self.debug_dataset): - if self.shuffle_tags: - caption = shuffle_tags(caption) + if self.shuffle_tags or self.drop_tags > 0: + caption = process_tags(caption, self.shuffle_tags, self.drop_tags, self.skip_first_tag) if self.strict_tokens: caption = build_strict_tokens(caption, self.tokenizers[0].bos_token, self.tokenizers[0].eos_token) if self.not_pad_tokens: @@ -324,7 +328,7 @@ def cache_caption(self, image_path, caption): input_ids = self.tokenizers[0](caption, padding='max_length', truncation=True, add_special_tokens=auto_add_special_tokens, return_tensors='pt').input_ids - if not self.shuffle_tags: + if not self.shuffle_tags and self.drop_tags <= 0: self.data_cache["captions"][image_path] = input_ids return caption, input_ids @@ -398,7 +402,7 @@ def cache_images(images, reso, p_bar: mytqdm): else: self.data_cache["latents"][img_path] = data_cache["latents"][img_path] - if not self.shuffle_tags: + if not self.shuffle_tags and self.drop_tags <= 0: if img_path not in data_cache["captions"] and not self.debug_dataset: self.cache_caption(img_path, cap) else: diff --git a/dreambooth/utils/gen_utils.py b/dreambooth/utils/gen_utils.py index 50dfea3b..13d21bee 100644 --- a/dreambooth/utils/gen_utils.py +++ b/dreambooth/utils/gen_utils.py @@ -82,6 +82,8 @@ def generate_dataset( resolution=args.resolution, hflip=args.hflip, do_shuffle_tags=args.shuffle_tags, + drop_tags=args.drop_tags, + skip_first_tag=args.skip_first_tag, strict_tokens=args.strict_tokens, dynamic_img_norm=args.dynamic_img_norm, not_pad_tokens=not args.pad_tokens, diff --git a/dreambooth/utils/image_utils.py b/dreambooth/utils/image_utils.py index 0a508488..bd71281f 100644 --- a/dreambooth/utils/image_utils.py +++ b/dreambooth/utils/image_utils.py @@ -244,22 +244,26 @@ def create_text(self, prompt, file_text, concept, is_class=True): output = re.sub(r"\\", "", output) if self.shuffle_tags: - output = shuffle_tags(output) + output = process_tags(output, True, 0, True) else: output = output.strip() return output -def shuffle_tags(caption: str): - tags = caption.split(',') - first_tag = tags.pop(0) - random.shuffle(tags) - tags.insert(0, first_tag) - output = ','.join(tags).strip() +def process_tags(caption: str, shuffle_tags: bool, drop_p: float, skip_first: bool): + tags = [t.strip() for t in caption.split(',')] + if skip_first: + first_tag = tags.pop(0) + if shuffle_tags: + random.shuffle(tags) + if drop_p > 0: + tags = [t for t in tags if random.random() >= drop_p] + if skip_first: + tags.insert(0, first_tag) + output = ', '.join(tags) return output - def get_scheduler_names(): return [scheduler.name.replace('Scheduler', '') for scheduler in KarrasDiffusionSchedulers] diff --git a/index.html b/index.html index 3d1b0602..f04d5e17 100644 --- a/index.html +++ b/index.html @@ -317,11 +317,11 @@ data-step="0.01" id="dream_randomness" data-value="0.0" data-label="DREAM randomness"> -
-
-
+
+
+
+
+
+
+
+
+ + +
+