From cf0c6892bf389a5a8a9024cdebd1783f9e2a6696 Mon Sep 17 00:00:00 2001
From: bghira
Date: Sat, 25 Jan 2025 16:27:01 -0600
Subject: [PATCH] lycoris student/teacher differential training should use
 set_multiplier instead of restore/apply_to to avoid resume bug

---
 helpers/training/trainer.py | 33 ++++++++++++++++++++++-----------
 1 file changed, 22 insertions(+), 11 deletions(-)

diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py
index e864ee4c..d0e3472c 100644
--- a/helpers/training/trainer.py
+++ b/helpers/training/trainer.py
@@ -1146,7 +1146,8 @@ def _recalculate_training_steps(self):
                 "You must specify either --max_train_steps or --num_train_epochs with a value > 0"
             )
         self.config.num_train_epochs = math.ceil(
-            self.config.max_train_steps / max(self.config.num_update_steps_per_epoch, 1)
+            self.config.max_train_steps
+            / max(self.config.num_update_steps_per_epoch, 1)
         )
         logger.info(
             f"Calculated our maximum training steps at {self.config.max_train_steps} because we have"
@@ -1616,7 +1617,10 @@ def init_resume_checkpoint(self, lr_scheduler):
                 * self.accelerator.num_processes
             )
 
-        if self.state["current_epoch"] > self.config.num_train_epochs + 1 and not self.config.ignore_final_epochs:
+        if (
+            self.state["current_epoch"] > self.config.num_train_epochs + 1
+            and not self.config.ignore_final_epochs
+        ):
             logger.info(
                 f"Reached the end ({self.state['current_epoch']} epochs) of our training run ({self.config.num_train_epochs} epochs). This run will do zero steps."
             )
@@ -2307,7 +2311,10 @@ def train(self):
         if self.config.ignore_final_epochs:
             num_epochs_to_track += 1000000
         for epoch in range(self.state["first_epoch"], num_epochs_to_track):
-            if self.state["current_epoch"] > self.config.num_train_epochs + 1 and not self.config.ignore_final_epochs:
+            if (
+                self.state["current_epoch"] > self.config.num_train_epochs + 1
+                and not self.config.ignore_final_epochs
+            ):
                 # This might immediately end training, but that's useful for simply exporting the model.
                 logger.info(
                     f"Training run is complete ({self.config.num_train_epochs}/{self.config.num_train_epochs} epochs, {self.state['global_step']}/{self.config.max_train_steps} steps)."
@@ -2633,7 +2640,9 @@ def train(self):
                             training_logger.debug(
                                 "Detaching LyCORIS adapter for parent prediction."
                             )
-                            self.accelerator._lycoris_wrapped_network.restore()
+                            self.accelerator._lycoris_wrapped_network.set_multiplier(
+                                0.0
+                            )
                         else:
                             raise ValueError(
                                 f"Cannot train parent-student networks on {self.config.lora_type} model. Only LyCORIS is supported."
@@ -2651,7 +2660,9 @@ def train(self):
                             training_logger.debug(
                                 "Attaching LyCORIS adapter for student prediction."
                             )
-                            self.accelerator._lycoris_wrapped_network.apply_to()
+                            self.accelerator._lycoris_wrapped_network.set_multiplier(
+                                1.0
+                            )
 
                         training_logger.debug("Predicting noise residual.")
                         model_pred = self.model_predict(
@@ -3077,18 +3088,18 @@ def train(self):
             )
             self.accelerator.wait_for_everyone()
 
-            if (
-                self.state["global_step"] >= self.config.max_train_steps
-                or (epoch > self.config.num_train_epochs and not self.config.ignore_final_epochs)
+            if self.state["global_step"] >= self.config.max_train_steps or (
+                epoch > self.config.num_train_epochs
+                and not self.config.ignore_final_epochs
             ):
                 logger.info(
                     f"Training has completed."
                     f"\n -> global_step = {self.state['global_step']}, max_train_steps = {self.config.max_train_steps}, epoch = {epoch}, num_train_epochs = {self.config.num_train_epochs}",
                 )
                 break
-        if (
-            self.state["global_step"] >= self.config.max_train_steps
-            or (epoch > self.config.num_train_epochs and not self.config.ignore_final_epochs)
+        if self.state["global_step"] >= self.config.max_train_steps or (
+            epoch > self.config.num_train_epochs
+            and not self.config.ignore_final_epochs
         ):
             logger.info(
                 f"Exiting training loop. Beginning model unwind at epoch {epoch}, step {self.state['global_step']}"