From 05b5ba2193054835a9825f0276f5675d4ddd438d Mon Sep 17 00:00:00 2001 From: Ross Morgan-Linial Date: Sat, 20 Jan 2024 03:38:00 -0800 Subject: [PATCH] Save weights/Generate samples improvements * "Save weights" and "Generate samples" happen immediately * Generating samples won't try to load every image in the dataset when "[filewords]" is used * Generating samples handles exceptions reading images for "[filewords]" --- dreambooth/dataset/sample_dataset.py | 23 +++++++++++------ dreambooth/train_dreambooth.py | 37 ++++++++++++++-------------- 2 files changed, 35 insertions(+), 25 deletions(-) diff --git a/dreambooth/dataset/sample_dataset.py b/dreambooth/dataset/sample_dataset.py index 026c7c58..4253f748 100644 --- a/dreambooth/dataset/sample_dataset.py +++ b/dreambooth/dataset/sample_dataset.py @@ -38,18 +38,27 @@ def __init__(self, config: DreamboothConfig): elif "[filewords]" in sample_prompt: prompts = [] images = get_images(concept.instance_data_dir) + random.shuffle(images) getter = FilenameTextGetter(shuffle_tags) + selected = 0 for image in images: - file_text = getter.read_text(image) - prompt = getter.create_text(sample_prompt, file_text, concept, False) - img = Image.open(image) - res = img.size - closest = closest_resolution(res[0], res[1], bucket_resos) - prompts.append((prompt, closest)) + try: + file_text = getter.read_text(image) + prompt = getter.create_text(sample_prompt, file_text, concept, False) + img = Image.open(image) + res = img.size + closest = closest_resolution(res[0], res[1], bucket_resos) + prompts.append((prompt, closest)) + selected += 1 + if selected >= required: + break + except: + pass else: prompts = [(sample_prompt, (config.resolution, config.resolution))] + random.shuffle(prompts) for i in range(required): - pi = random.choice(prompts) + pi = prompts[i % len(prompts)] pd = PromptData( prompt=pi[0], negative_prompt=neg, diff --git a/dreambooth/train_dreambooth.py b/dreambooth/train_dreambooth.py index e640c4a8..ed3bf523 100644 --- a/dreambooth/train_dreambooth.py +++ b/dreambooth/train_dreambooth.py @@ -1012,7 +1012,14 @@ def check_save(is_epoch_check=False): save_model = False save_lora = False - if not save_canceled and not save_completed: + if save_canceled or save_completed: + logger.debug("\nSave completed/canceled.") + if global_step > 0: + save_image = True + save_model = True + if args.use_lora: + save_lora = True + elif is_epoch_check: # Check to see if the number of epochs since last save is gt the interval if 0 < save_model_interval <= session_epoch - last_model_save: save_model = True @@ -1025,26 +1032,17 @@ def check_save(is_epoch_check=False): save_image = True last_image_save = session_epoch - else: - logger.debug("\nSave completed/canceled.") - if global_step > 0: - save_image = True - save_model = True - if args.use_lora: - save_lora = True - save_snapshot = False - if is_epoch_check: - if shared.status.do_save_samples: - save_image = True - shared.status.do_save_samples = False + if shared.status.do_save_samples: + save_image = True + shared.status.do_save_samples = False - if shared.status.do_save_model: - if args.use_lora: - save_lora = True - save_model = True - shared.status.do_save_model = False + if shared.status.do_save_model: + if args.use_lora: + save_lora = True + save_model = True + shared.status.do_save_model = False save_checkpoint = False if save_model: @@ -1911,6 +1909,9 @@ def lora_save_function(weights, filename): status_handler.end(status.textinfo) break + if status.do_save_model or status.do_save_samples: + check_save(False) + accelerator.wait_for_everyone() args.epoch += 1