Skip to content
This repository has been archived by the owner on Oct 22, 2023. It is now read-only.

Change how noise is added to latents to better learn contrasts/strong colors #101

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions scripts/configuration_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,8 @@ def create_default_variables(self):
self.save_sample_controlled_seed = []
self.delete_checkpoints_when_full_drive = True
self.use_image_names_as_captions = True
self.use_offset_noise = False
self.offset_noise_weight = 0.1
self.num_samples_to_generate = 1
self.auto_balance_concept_datasets = False
self.sample_width = 512
Expand Down Expand Up @@ -1571,6 +1573,24 @@ def create_trainer_settings_widgets(self):
self.prior_loss_preservation_weight_entry = ctk.CTkEntry(self.training_frame_finetune_subframe)
self.prior_loss_preservation_weight_entry.grid(row=19, column=3, sticky="w")
self.prior_loss_preservation_weight_entry.insert(0, self.prior_loss_weight)

#create contrasting light and color checkbox
self.use_offset_noise_var = tk.IntVar()
self.use_offset_noise_var.set(self.use_offset_noise)
#create label
self.offset_noise_label = ctk.CTkLabel(self.training_frame_finetune_subframe, text="With Offset Noise")
offset_noise_label_ttp = CreateToolTip(self.offset_noise_label, "Apply offset noise to latents to learn image contrast.")
self.offset_noise_label.grid(row=20, column=0, sticky="nsew")
#create checkbox
self.offset_noise_checkbox = ctk.CTkSwitch(self.training_frame_finetune_subframe, variable=self.use_offset_noise_var)
self.offset_noise_checkbox.grid(row=20, column=1, sticky="nsew")
#create offset noise weight label and entry
self.offset_noise_weight_label = ctk.CTkLabel(self.training_frame_finetune_subframe, text="Offset Noise Weight")
offset_noise_weight_label_ttp = CreateToolTip(self.offset_noise_weight_label, "The weight of the offset noise.")
self.offset_noise_weight_label.grid(row=20, column=1, sticky="e")
self.offset_noise_weight_entry = ctk.CTkEntry(self.training_frame_finetune_subframe)
self.offset_noise_weight_entry.grid(row=20, column=3, sticky="w")
self.offset_noise_weight_entry.insert(0, self.offset_noise_weight)


def create_dataset_settings_widgets(self):
Expand Down Expand Up @@ -3070,6 +3090,8 @@ def save_config(self, config_file=None):
configure['attention'] = self.attention_var.get()
configure['batch_prompt_sampling'] = int(self.batch_prompt_sampling_optionmenu_var.get())
configure['shuffle_dataset_per_epoch'] = self.shuffle_dataset_per_epoch_var.get()
configure['use_offset_noise'] = self.use_offset_noise_var.get()
configure['offset_noise_weight'] = self.offset_noise_weight_entry.get()
#save the configure file
#if the file exists, delete it
if os.path.exists(file_name):
Expand Down Expand Up @@ -3222,6 +3244,9 @@ def load_config(self,file_name=None):
self.attention_var.set(configure["attention"])
self.batch_prompt_sampling_optionmenu_var.set(str(configure['batch_prompt_sampling']))
self.shuffle_dataset_per_epoch_var.set(configure["shuffle_dataset_per_epoch"])
self.use_offset_noise_var.set(configure["use_offset_noise"])
self.offset_noise_weight_entry.delete(0, tk.END)
self.offset_noise_weight_entry.insert(0, configure["offset_noise_weight"])
self.update()

def process_inputs(self,export=None):
Expand Down Expand Up @@ -3291,6 +3316,9 @@ def process_inputs(self,export=None):
self.attention = self.attention_var.get()
self.batch_prompt_sampling = int(self.batch_prompt_sampling_optionmenu_var.get())
self.shuffle_dataset_per_epoch = self.shuffle_dataset_per_epoch_var.get()
self.use_offset_noise = self.use_offset_noise_var.get()
self.offset_noise_weight = self.offset_noise_weight_entry.get()

mode = 'normal'
if self.cloud_mode == False and export == None:
#check if output path exists
Expand Down Expand Up @@ -3579,6 +3607,13 @@ def process_inputs(self,export=None):
batBase += ' --use_image_names_as_captions'
else:
batBase += f' "--use_image_names_as_captions" '
if self.use_offset_noise == True:
if export == 'Linux':
batBase += f' --with_offset_noise'
batBase += f' --offset_noise_weight={self.offset_noise_weight}'
else:
batBase += f' "--with_offset_noise" '
batBase += f' "--offset_noise_weight={self.offset_noise_weight}" '
if self.auto_balance_concept_datasets == True:
if export == 'Linux':
batBase += ' --auto_balance_concept_datasets'
Expand Down
16 changes: 15 additions & 1 deletion scripts/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,14 @@ def parse_args():
help="Flag to add prior preservation loss.",
)
parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
parser.add_argument(
"--with_offset_noise",
default=False,
action="store_true",
help="Flag to offset noise applied to latents.",
)

parser.add_argument("--offset_noise_weight", type=float, default=0.1, help="The weight of offset noise applied during training.")
parser.add_argument(
"--num_class_images",
type=int,
Expand Down Expand Up @@ -1369,7 +1377,13 @@ def help(event=None):
if args.sample_from_batch > 0:
args.batch_tokens = batch[0][5]
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
# plus a per-sample, per-channel random offset (broadcast over the spatial dims) so the model learns to shift the zero-frequency (mean) component freely
# https://www.crosslabs.org/blog/diffusion-with-offset-noise
if (args.with_offset_noise == True):
noise = torch.randn_like(latents) + (args.offset_noise_weight * torch.randn(latents.shape[0], latents.shape[1], 1, 1).to(accelerator.device))
else:
noise = torch.randn_like(latents)

bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, int(noise_scheduler.config.num_train_timesteps * args.max_denoising_strength), (bsz,), device=latents.device)
Expand Down