diff --git a/scripts/configuration_gui.py b/scripts/configuration_gui.py index c3a32c1..a44374c 100644 --- a/scripts/configuration_gui.py +++ b/scripts/configuration_gui.py @@ -888,6 +888,8 @@ def create_default_variables(self): self.save_sample_controlled_seed = [] self.delete_checkpoints_when_full_drive = True self.use_image_names_as_captions = True + self.use_offset_noise = False + self.offset_noise_weight = 0.1 self.num_samples_to_generate = 1 self.auto_balance_concept_datasets = False self.sample_width = 512 @@ -1571,6 +1573,24 @@ def create_trainer_settings_widgets(self): self.prior_loss_preservation_weight_entry = ctk.CTkEntry(self.training_frame_finetune_subframe) self.prior_loss_preservation_weight_entry.grid(row=19, column=3, sticky="w") self.prior_loss_preservation_weight_entry.insert(0, self.prior_loss_weight) + + #create contrasting light and color checkbox + self.use_offset_noise_var = tk.IntVar() + self.use_offset_noise_var.set(self.use_offset_noise) + #create label + self.offset_noise_label = ctk.CTkLabel(self.training_frame_finetune_subframe, text="With Offset Noise") + offset_noise_label_ttp = CreateToolTip(self.offset_noise_label, "Apply offset noise to latents to learn image contrast.") + self.offset_noise_label.grid(row=20, column=0, sticky="nsew") + #create checkbox + self.offset_noise_checkbox = ctk.CTkSwitch(self.training_frame_finetune_subframe, variable=self.use_offset_noise_var) + self.offset_noise_checkbox.grid(row=20, column=1, sticky="nsew") + #create prior loss preservation weight entry + self.offset_noise_weight_label = ctk.CTkLabel(self.training_frame_finetune_subframe, text="Offset Noise Weight") + offset_noise_weight_label_ttp = CreateToolTip(self.offset_noise_weight_label, "The weight of the offset noise.") + self.offset_noise_weight_label.grid(row=20, column=1, sticky="e") + self.offset_noise_weight_entry = ctk.CTkEntry(self.training_frame_finetune_subframe) + self.offset_noise_weight_entry.grid(row=20, column=3, sticky="w") + self.offset_noise_weight_entry.insert(0, self.offset_noise_weight) def create_dataset_settings_widgets(self): @@ -3070,6 +3090,8 @@ def save_config(self, config_file=None): configure['attention'] = self.attention_var.get() configure['batch_prompt_sampling'] = int(self.batch_prompt_sampling_optionmenu_var.get()) configure['shuffle_dataset_per_epoch'] = self.shuffle_dataset_per_epoch_var.get() + configure['use_offset_noise'] = self.use_offset_noise_var.get() + configure['offset_noise_weight'] = self.offset_noise_weight_entry.get() #save the configure file #if the file exists, delete it if os.path.exists(file_name): @@ -3222,6 +3244,9 @@ def load_config(self,file_name=None): self.attention_var.set(configure["attention"]) self.batch_prompt_sampling_optionmenu_var.set(str(configure['batch_prompt_sampling'])) self.shuffle_dataset_per_epoch_var.set(configure["shuffle_dataset_per_epoch"]) + self.use_offset_noise_var.set(configure["use_offset_noise"]) + self.offset_noise_weight_entry.delete(0, tk.END) + self.offset_noise_weight_entry.insert(0, configure["offset_noise_weight"]) self.update() def process_inputs(self,export=None): @@ -3291,6 +3316,9 @@ def process_inputs(self,export=None): self.attention = self.attention_var.get() self.batch_prompt_sampling = int(self.batch_prompt_sampling_optionmenu_var.get()) self.shuffle_dataset_per_epoch = self.shuffle_dataset_per_epoch_var.get() + self.use_offset_noise = self.use_offset_noise_var.get() + self.offset_noise_weight = self.offset_noise_weight_entry.get() + mode = 'normal' if self.cloud_mode == False and export == None: #check if output path exists @@ -3579,6 +3607,13 @@ def process_inputs(self,export=None): batBase += ' --use_image_names_as_captions' else: batBase += f' "--use_image_names_as_captions" ' + if self.use_offset_noise == True: + if export == 'Linux': + batBase += f' --with_offset_noise' + batBase += f' --offset_noise_weight={self.offset_noise_weight}' + else: + batBase += f' "--with_offset_noise" ' + batBase += f' "--offset_noise_weight={self.offset_noise_weight}" ' if self.auto_balance_concept_datasets == True: if export == 'Linux': batBase += ' --auto_balance_concept_datasets' diff --git a/scripts/trainer.py b/scripts/trainer.py index f5e36a1..4552610 100644 --- a/scripts/trainer.py +++ b/scripts/trainer.py @@ -245,6 +245,14 @@ def parse_args(): help="Flag to add prior preservation loss.", ) parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--with_offset_noise", + default=False, + action="store_true", + help="Flag to offset noise applied to latents.", + ) + + parser.add_argument("--offset_noise_weight", type=float, default=0.1, help="The weight of offset noise applied during training.") parser.add_argument( "--num_class_images", type=int, @@ -1369,7 +1377,13 @@ def help(event=None): if args.sample_from_batch > 0: args.batch_tokens = batch[0][5] # Sample noise that we'll add to the latents - noise = torch.randn_like(latents) + # and some extra bits to make it so that the model learns to change the zero-frequency of the component freely + # https://www.crosslabs.org/blog/diffusion-with-offset-noise + if (args.with_offset_noise == True): + noise = torch.randn_like(latents) + (args.offset_noise_weight * torch.randn(latents.shape[0], latents.shape[1], 1, 1).to(accelerator.device)) + else: + noise = torch.randn_like(latents) + bsz = latents.shape[0] # Sample a random timestep for each image timesteps = torch.randint(0, int(noise_scheduler.config.num_train_timesteps * args.max_denoising_strength), (bsz,), device=latents.device)