Skip to content

Commit

Permalink
lycoris student/teacher differential training should use set_multiplier instead of restore/apply_to to avoid resume bug
Browse files Browse the repository at this point in the history
  • Loading branch information
bghira committed Jan 25, 2025
1 parent a368d23 commit cf0c689
Showing 1 changed file with 22 additions and 11 deletions.
33 changes: 22 additions & 11 deletions helpers/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1146,7 +1146,8 @@ def _recalculate_training_steps(self):
"You must specify either --max_train_steps or --num_train_epochs with a value > 0"
)
self.config.num_train_epochs = math.ceil(
self.config.max_train_steps / max(self.config.num_update_steps_per_epoch, 1)
self.config.max_train_steps
/ max(self.config.num_update_steps_per_epoch, 1)
)
logger.info(
f"Calculated our maximum training steps at {self.config.max_train_steps} because we have"
Expand Down Expand Up @@ -1616,7 +1617,10 @@ def init_resume_checkpoint(self, lr_scheduler):
* self.accelerator.num_processes
)

if self.state["current_epoch"] > self.config.num_train_epochs + 1 and not self.config.ignore_final_epochs:
if (
self.state["current_epoch"] > self.config.num_train_epochs + 1
and not self.config.ignore_final_epochs
):
logger.info(
f"Reached the end ({self.state['current_epoch']} epochs) of our training run ({self.config.num_train_epochs} epochs). This run will do zero steps."
)
Expand Down Expand Up @@ -2307,7 +2311,10 @@ def train(self):
if self.config.ignore_final_epochs:
num_epochs_to_track += 1000000
for epoch in range(self.state["first_epoch"], num_epochs_to_track):
if self.state["current_epoch"] > self.config.num_train_epochs + 1 and not self.config.ignore_final_epochs:
if (
self.state["current_epoch"] > self.config.num_train_epochs + 1
and not self.config.ignore_final_epochs
):
# This might immediately end training, but that's useful for simply exporting the model.
logger.info(
f"Training run is complete ({self.config.num_train_epochs}/{self.config.num_train_epochs} epochs, {self.state['global_step']}/{self.config.max_train_steps} steps)."
Expand Down Expand Up @@ -2633,7 +2640,9 @@ def train(self):
training_logger.debug(
"Detaching LyCORIS adapter for parent prediction."
)
self.accelerator._lycoris_wrapped_network.restore()
self.accelerator._lycoris_wrapped_network.set_multiplier(
0.0
)
else:
raise ValueError(
f"Cannot train parent-student networks on {self.config.lora_type} model. Only LyCORIS is supported."
Expand All @@ -2651,7 +2660,9 @@ def train(self):
training_logger.debug(
"Attaching LyCORIS adapter for student prediction."
)
self.accelerator._lycoris_wrapped_network.apply_to()
self.accelerator._lycoris_wrapped_network.set_multiplier(
1.0
)

training_logger.debug("Predicting noise residual.")
model_pred = self.model_predict(
Expand Down Expand Up @@ -3077,18 +3088,18 @@ def train(self):
)
self.accelerator.wait_for_everyone()

if (
self.state["global_step"] >= self.config.max_train_steps
or (epoch > self.config.num_train_epochs and not self.config.ignore_final_epochs)
if self.state["global_step"] >= self.config.max_train_steps or (
epoch > self.config.num_train_epochs
and not self.config.ignore_final_epochs
):
logger.info(
f"Training has completed."
f"\n -> global_step = {self.state['global_step']}, max_train_steps = {self.config.max_train_steps}, epoch = {epoch}, num_train_epochs = {self.config.num_train_epochs}",
)
break
if (
self.state["global_step"] >= self.config.max_train_steps
or (epoch > self.config.num_train_epochs and not self.config.ignore_final_epochs)
if self.state["global_step"] >= self.config.max_train_steps or (
epoch > self.config.num_train_epochs
and not self.config.ignore_final_epochs
):
logger.info(
f"Exiting training loop. Beginning model unwind at epoch {epoch}, step {self.state['global_step']}"
Expand Down

0 comments on commit cf0c689

Please sign in to comment.