Skip to content

Commit

Permalink
Save weights/Generate samples improvements
Browse files Browse the repository at this point in the history
* "Save weights" and "Generate samples" happen immediately
* Generating samples won't try to load every image in the dataset when "[filewords]" is used
* Generating samples handles exceptions reading images for "[filewords]"
  • Loading branch information
RossM committed Jan 20, 2024
1 parent 39acf55 commit 05b5ba2
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 25 deletions.
23 changes: 16 additions & 7 deletions dreambooth/dataset/sample_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,27 @@ def __init__(self, config: DreamboothConfig):
elif "[filewords]" in sample_prompt:
prompts = []
images = get_images(concept.instance_data_dir)
random.shuffle(images)
getter = FilenameTextGetter(shuffle_tags)
selected = 0
for image in images:
file_text = getter.read_text(image)
prompt = getter.create_text(sample_prompt, file_text, concept, False)
img = Image.open(image)
res = img.size
closest = closest_resolution(res[0], res[1], bucket_resos)
prompts.append((prompt, closest))
try:
file_text = getter.read_text(image)
prompt = getter.create_text(sample_prompt, file_text, concept, False)
img = Image.open(image)
res = img.size
closest = closest_resolution(res[0], res[1], bucket_resos)
prompts.append((prompt, closest))
selected += 1
if selected >= required:
break
except:
pass
else:
prompts = [(sample_prompt, (config.resolution, config.resolution))]
random.shuffle(prompts)
for i in range(required):
pi = random.choice(prompts)
pi = prompts[i % len(prompts)]
pd = PromptData(
prompt=pi[0],
negative_prompt=neg,
Expand Down
37 changes: 19 additions & 18 deletions dreambooth/train_dreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -1012,7 +1012,14 @@ def check_save(is_epoch_check=False):
save_model = False
save_lora = False

if not save_canceled and not save_completed:
if save_canceled or save_completed:
logger.debug("\nSave completed/canceled.")
if global_step > 0:
save_image = True
save_model = True
if args.use_lora:
save_lora = True
elif is_epoch_check:
# Check to see if the number of epochs since last save is gt the interval
if 0 < save_model_interval <= session_epoch - last_model_save:
save_model = True
Expand All @@ -1025,26 +1032,17 @@ def check_save(is_epoch_check=False):
save_image = True
last_image_save = session_epoch

else:
logger.debug("\nSave completed/canceled.")
if global_step > 0:
save_image = True
save_model = True
if args.use_lora:
save_lora = True

save_snapshot = False

if is_epoch_check:
if shared.status.do_save_samples:
save_image = True
shared.status.do_save_samples = False
if shared.status.do_save_samples:
save_image = True
shared.status.do_save_samples = False

if shared.status.do_save_model:
if args.use_lora:
save_lora = True
save_model = True
shared.status.do_save_model = False
if shared.status.do_save_model:
if args.use_lora:
save_lora = True
save_model = True
shared.status.do_save_model = False

save_checkpoint = False
if save_model:
Expand Down Expand Up @@ -1911,6 +1909,9 @@ def lora_save_function(weights, filename):
status_handler.end(status.textinfo)
break

if status.do_save_model or status.do_save_samples:
check_save(False)

accelerator.wait_for_everyone()

args.epoch += 1
Expand Down

0 comments on commit 05b5ba2

Please sign in to comment.