Skip to content

Commit

Permalink
Additional tag processing options
Browse files Browse the repository at this point in the history
Includes randomly dropping tags, and whether to treat the first tag specially.
  • Loading branch information
RossM committed May 3, 2024
1 parent d7dad82 commit 9d88b7c
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 21 deletions.
2 changes: 2 additions & 0 deletions dreambooth/dataclasses/db_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ class DreamboothConfig(BaseModel):
scheduler: str = "ddim"
shared_diffusers_path: str = ""
shuffle_tags: bool = True
drop_tags: float = 0.0
skip_first_tag: bool = True
snapshot: str = ""
split_loss: bool = True
src: str = ""
Expand Down
20 changes: 12 additions & 8 deletions dreambooth/dataset/db_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from dreambooth.dataclasses.prompt_data import PromptData
from dreambooth.shared import status
from dreambooth.utils.image_utils import make_bucket_resolutions, \
closest_resolution, shuffle_tags, open_and_trim
closest_resolution, process_tags, open_and_trim
from dreambooth.utils.text_utils import build_strict_tokens
from helpers.mytqdm import mytqdm

Expand All @@ -38,6 +38,8 @@ def __init__(
resolution: int,
hflip: bool,
do_shuffle_tags: bool,
drop_tags: float,
skip_first_tag: bool,
strict_tokens: bool,
dynamic_img_norm: bool,
not_pad_tokens: bool,
Expand Down Expand Up @@ -93,6 +95,8 @@ def __init__(
self.resolution = resolution
self.debug_dataset = debug_dataset
self.shuffle_tags = do_shuffle_tags
self.drop_tags = drop_tags
self.skip_first_tag = skip_first_tag
self.not_pad_tokens = not_pad_tokens
self.strict_tokens = strict_tokens
self.dynamic_img_norm = dynamic_img_norm
Expand Down Expand Up @@ -210,8 +214,8 @@ def encode_prompt(self, prompt):
bs_embed = None # default declaration

auto_add_special_tokens = False if self.strict_tokens else True
if self.shuffle_tags:
prompt = shuffle_tags(prompt)
if self.shuffle_tags or self.drop_tags > 0:
prompt = process_tags(prompt, self.shuffle_tags, self.drop_tags, self.skip_first_tag)
for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
if self.strict_tokens:
prompt = build_strict_tokens(prompt, tokenizer.bos_token, tokenizer.eos_token)
Expand Down Expand Up @@ -294,7 +298,7 @@ def load_image(self, image_path, caption, res):
else:
img = open_and_trim(image_path, res, False)
image = self.image_transform(img)
if self.shuffle_tags:
if self.shuffle_tags or self.drop_tags > 0:
caption, input_ids = self.cache_caption(image_path, caption)
else:
input_ids = self.data_cache["captions"][image_path]
Expand All @@ -312,8 +316,8 @@ def cache_caption(self, image_path, caption):
input_ids = None
auto_add_special_tokens = False if self.strict_tokens else True
if len(self.tokenizers) > 0 and (image_path not in self.data_cache["captions"] or self.debug_dataset):
if self.shuffle_tags:
caption = shuffle_tags(caption)
if self.shuffle_tags or self.drop_tags > 0:
caption = process_tags(caption, self.shuffle_tags, self.drop_tags, self.skip_first_tag)
if self.strict_tokens:
caption = build_strict_tokens(caption, self.tokenizers[0].bos_token, self.tokenizers[0].eos_token)
if self.not_pad_tokens:
Expand All @@ -324,7 +328,7 @@ def cache_caption(self, image_path, caption):
input_ids = self.tokenizers[0](caption, padding='max_length', truncation=True,
add_special_tokens=auto_add_special_tokens,
return_tensors='pt').input_ids
if not self.shuffle_tags:
if not self.shuffle_tags and self.drop_tags <= 0:
self.data_cache["captions"][image_path] = input_ids

return caption, input_ids
Expand Down Expand Up @@ -398,7 +402,7 @@ def cache_images(images, reso, p_bar: mytqdm):
else:
self.data_cache["latents"][img_path] = data_cache["latents"][img_path]

if not self.shuffle_tags:
if not self.shuffle_tags and self.drop_tags <= 0:
if img_path not in data_cache["captions"] and not self.debug_dataset:
self.cache_caption(img_path, cap)
else:
Expand Down
2 changes: 2 additions & 0 deletions dreambooth/utils/gen_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def generate_dataset(
resolution=args.resolution,
hflip=args.hflip,
do_shuffle_tags=args.shuffle_tags,
drop_tags=args.drop_tags,
skip_first_tag=args.skip_first_tag,
strict_tokens=args.strict_tokens,
dynamic_img_norm=args.dynamic_img_norm,
not_pad_tokens=not args.pad_tokens,
Expand Down
20 changes: 12 additions & 8 deletions dreambooth/utils/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,22 +244,26 @@ def create_text(self, prompt, file_text, concept, is_class=True):
output = re.sub(r"\\", "", output)

if self.shuffle_tags:
output = shuffle_tags(output)
output = process_tags(output, True, 0, True)
else:
output = output.strip()

return output


def shuffle_tags(caption: str):
tags = caption.split(',')
first_tag = tags.pop(0)
random.shuffle(tags)
tags.insert(0, first_tag)
output = ','.join(tags).strip()
def process_tags(caption: str, shuffle_tags: bool, drop_p: float, skip_first: bool):
    """Shuffle and/or randomly drop the comma-separated tags of a caption.

    Args:
        caption: Caption text whose tags are separated by ','.
        shuffle_tags: When True, the tags are put in a random order.
        drop_p: Per-tag probability of being removed. Values <= 0 disable
            dropping; because ``random.random()`` is always below 1.0, a
            value of 1.0 or more removes every eligible tag.
        skip_first: When True, the first tag is kept in front and is never
            shuffled or dropped.

    Returns:
        The surviving tags, stripped of surrounding whitespace and joined
        with ', '. An empty string when the caption contains no tags.
    """
    # Trailing, leading or doubled commas produce empty entries; discard them
    # so they cannot be shuffled into the middle of the caption or emitted as
    # dangling ', ' separators.
    tags = [tag for tag in (part.strip() for part in caption.split(',')) if tag]
    if not tags:
        return ''

    # Set the protected first tag aside before any randomisation happens.
    head = [tags.pop(0)] if skip_first else []

    if shuffle_tags:
        random.shuffle(tags)
    if drop_p > 0:
        tags = [tag for tag in tags if random.random() >= drop_p]

    return ', '.join(head + tags)


def get_scheduler_names():
    """Return each KarrasDiffusionSchedulers member name with 'Scheduler' removed."""
    names = []
    for entry in KarrasDiffusionSchedulers:
        # Strip the 'Scheduler' text from the member name.
        names.append(entry.name.replace('Scheduler', ''))
    return names

Expand Down
25 changes: 20 additions & 5 deletions index.html
Original file line number Diff line number Diff line change
Expand Up @@ -317,11 +317,11 @@
data-step="0.01" id="dream_randomness" data-value="0.0"
data-label="DREAM randomness"></div>
</div>
<div class="form-group">
<div class="dbInput db-slider" data-min="0" data-max="1.0"
data-step="0.01" id="dream_randomness2" data-value="0.0"
data-label="DREAM randomness experimental"></div>
</div>
<div class="form-group">
<div class="dbInput db-slider" data-min="0" data-max="1.0"
data-step="0.01" id="dream_randomness2" data-value="0.0"
data-label="DREAM randomness experimental"></div>
</div>
<div class="form-group">
<div class="form-check form-switch">
<input class="dbInput form-check-input" type="checkbox"
Expand Down Expand Up @@ -350,6 +350,21 @@
</label>
</div>
</div>
<div class="form-group">
<div class="dbInput db-slider" data-min="0" data-max="1.0"
data-step="0.01" id="drop_tags"
data-value="0.0"
data-label="Drop Tags"></div>
</div>
<div class="form-group">
<div class="form-check form-switch">
<input class="dbInput form-check-input" type="checkbox"
id="skip_first_tag" name="skip_first_tag" checked>
<label class="form-check-label" for="skip_first_tag">
Skip First Tag
</label>
</div>
</div>
<div class="form-group">
<div class="form-check form-switch">
<input class="dbInput form-check-input" type="checkbox"
Expand Down
2 changes: 2 additions & 0 deletions javascript/dreambooth.js
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,8 @@ let db_titles = {
"Set Gradients to None When Zeroing": "When performing the backwards pass, gradients will be set to none, instead of creating a new empty tensor. This will slightly improve VRAM.",
"Shuffle After Epoch": "When enabled, will shuffle the dataset after the first epoch. Will enable text encoder training and latent caching (More VRAM).",
"Shuffle Tags": "When enabled, tags after the first ',' in a prompt will be randomly ordered, which can potentially improve training.",
"Drop Tags": "A chance from 0 to 1 for each tag after the first separated by ',' in a prompt to be randomly dropped, which can potentially improve training.",
"Skip First Tag": "Whether to exclude the first tag when randomly shuffling and/or dropping tags.",
"Source Checkpoint": "The source checkpoint to extract for training.",
"Step Ratio of Text Encoder Training": "The number of steps per image (Epoch) to train the text encoder for. Set 0.5 for 50% of the epochs",
"Dynamic Image Normalization": "Normalizes each image separately by mean and standard deviation in your dataset. Useful to preserve likeness to your images.",
Expand Down
11 changes: 11 additions & 0 deletions scripts/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,15 @@ def on_ui_tabs():
db_shuffle_tags = gr.Checkbox(
label="Shuffle Tags", value=True
)
db_drop_tags = gr.Slider(
label="Drop Tags",
minimum=0,
maximum=1,
step=0.01,
)
db_skip_first_tag = gr.Checkbox(
label="Skip First Tag", value=True
)
db_max_token_length = gr.Slider(
label="Max Token Length",
minimum=75,
Expand Down Expand Up @@ -1571,6 +1580,8 @@ def toggle_advanced():
db_scheduler,
db_shared_diffusers_path,
db_shuffle_tags,
db_drop_tags,
db_skip_first_tag,
db_snapshot,
db_split_loss,
db_src,
Expand Down
2 changes: 2 additions & 0 deletions templates/defaults/dreambooth_model_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@
"save_state_during": false,
"scheduler": "UniPCMultistep",
"shuffle_tags": true,
"drop_tags": 0.0,
"skip_first_tag": true,
"split_loss": true,
"stop_text_encoder": 0.75,
"strict_tokens": true,
Expand Down
10 changes: 10 additions & 0 deletions templates/locales/titles_en.json
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,16 @@
"title": "Shuffle the tags before training.",
"description": "Whether or not to shuffle the order of the tags before training. Shuffling the tags can help prevent overfitting by making it harder for the model to memorize the order of the tags."
},
"drop_tags": {
"label": "Drop Tags",
"title": "Randomly drop some tags during training.",
"description": "The probability (0 to 1) of randomly dropping each tag during training. Dropping tags can make the model more robust to inputs that don't match the usual tag format."
},
"skip_first_tag": {
"label": "Skip First Tag",
"title": "Skip first tag when shuffling and dropping tags.",
"description": "Whether to skip the first tag when shuffling and/or dropping tags."
},
"strict_tokens": {
"label": "Strict Tokens",
"title": "Enable strict token checking.",
Expand Down

0 comments on commit 9d88b7c

Please sign in to comment.