diff --git a/.github/workflows/pypi-release.yml b/.github/workflows/pypi-release.yml
index 6736ac1..375dada 100644
--- a/.github/workflows/pypi-release.yml
+++ b/.github/workflows/pypi-release.yml
@@ -10,7 +10,7 @@ jobs:
   build-sdist:
     runs-on: ubuntu-20.04
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
       - name: Verify tag matches version
         run: |
          set -ex
@@ -19,7 +19,7 @@
           if [[ "$version" != "$tag" ]]; then
             exit 1
           fi
-      - uses: actions/setup-python@v2
+      - uses: actions/setup-python@v4
         with:
           python-version: 3.9
       - run: |
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 9d724d6..bc4ae55 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -21,9 +21,9 @@ jobs:
         python-version: [3.8, 3.9, "3.10", "3.11"]
         experimental: [false]
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
       - name: Set up Python ${{ matrix.python-version }}
-        uses: coqui-ai/setup-python@pip-cache-key-py-ver
+        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
          architecture: x64
diff --git a/trainer/callbacks.py b/trainer/callbacks.py
index 5da4100..3a407d7 100644
--- a/trainer/callbacks.py
+++ b/trainer/callbacks.py
@@ -7,6 +7,8 @@ def __init__(self) -> None:
         self.callbacks_on_init_end = []
         self.callbacks_on_epoch_start = []
         self.callbacks_on_epoch_end = []
+        self.callbacks_on_train_epoch_start = []
+        self.callbacks_on_train_epoch_end = []
         self.callbacks_on_train_step_start = []
         self.callbacks_on_train_step_end = []
         self.callbacks_on_keyboard_interrupt = []
@@ -21,6 +23,10 @@ def parse_callbacks_dict(self, callbacks_dict: Dict[str, Callable]) -> None:
             self.callbacks_on_epoch_start.append(value)
         elif key == "on_epoch_end":
             self.callbacks_on_epoch_end.append(value)
+        elif key == "on_train_epoch_start":
+            self.callbacks_on_train_epoch_start.append(value)
+        elif key == "on_train_epoch_end":
+            self.callbacks_on_train_epoch_end.append(value)
         elif key == "on_train_step_start":
             self.callbacks_on_train_step_start.append(value)
         elif key == "on_train_step_end":
@@ -102,6 +108,42 @@ def on_epoch_end(self, trainer) -> None:
         for callback in self.callbacks_on_epoch_end:
             callback(trainer)

+    def on_train_epoch_start(self, trainer) -> None:
+        if hasattr(trainer.model, "module"):
+            if hasattr(trainer.model.module, "on_train_epoch_start"):
+                trainer.model.module.on_train_epoch_start(trainer)
+        else:
+            if hasattr(trainer.model, "on_train_epoch_start"):
+                trainer.model.on_train_epoch_start(trainer)
+
+        if hasattr(trainer.criterion, "on_train_epoch_start"):
+            trainer.criterion.on_train_epoch_start(trainer)
+
+        if hasattr(trainer.optimizer, "on_train_epoch_start"):
+            trainer.optimizer.on_train_epoch_start(trainer)
+
+        if self.callbacks_on_train_epoch_start:
+            for callback in self.callbacks_on_train_epoch_start:
+                callback(trainer)
+
+    def on_train_epoch_end(self, trainer) -> None:
+        if hasattr(trainer.model, "module"):
+            if hasattr(trainer.model.module, "on_train_epoch_end"):
+                trainer.model.module.on_train_epoch_end(trainer)
+        else:
+            if hasattr(trainer.model, "on_train_epoch_end"):
+                trainer.model.on_train_epoch_end(trainer)
+
+        if hasattr(trainer.criterion, "on_train_epoch_end"):
+            trainer.criterion.on_train_epoch_end(trainer)
+
+        if hasattr(trainer.optimizer, "on_train_epoch_end"):
+            trainer.optimizer.on_train_epoch_end(trainer)
+
+        if self.callbacks_on_train_epoch_end:
+            for callback in self.callbacks_on_train_epoch_end:
+                callback(trainer)
+
     @staticmethod
     def before_backward_pass(trainer, loss_dict) -> None:
         if hasattr(trainer.model, "module"):
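The callbacks.py changes above add `on_train_epoch_start`/`on_train_epoch_end` hooks that fire once per epoch around the training loop only, and register user callbacks through the same `callbacks` dict that `parse_callbacks_dict` consumes. A minimal sketch of wiring them up follows, assuming the standard `Trainer(TrainerArgs(), config, output_path, model=..., callbacks=...)` entry point; `config` and `model` here are hypothetical placeholders for your own objects.

```python
# A minimal sketch, assuming the standard Trainer entry point; `config` and
# `model` are placeholders, not objects defined in this diff.
from trainer import Trainer, TrainerArgs


def print_train_epoch_start(trainer):
    # Fired at the top of Trainer.train_epoch(), before the first train step.
    print(f">> starting train epoch {trainer.epochs_done}")


def print_train_epoch_end(trainer):
    # Fired after the last train step, before the per-epoch scheduler step.
    print(f">> finished train epoch {trainer.epochs_done} at step {trainer.total_steps_done}")


config = ...  # your TrainerConfig subclass instance (placeholder)
model = ...   # your model instance (placeholder)

trainer = Trainer(
    TrainerArgs(),
    config,
    output_path="runs/",
    model=model,
    callbacks={
        "on_train_epoch_start": print_train_epoch_start,
        "on_train_epoch_end": print_train_epoch_end,
    },
)
trainer.fit()
```

The trainer.py hunks below call these hooks from `train_epoch` and add the related guards.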
diff --git a/trainer/trainer.py b/trainer/trainer.py
index 9c2c9ab..cc74024 100644
--- a/trainer/trainer.py
+++ b/trainer/trainer.py
@@ -443,6 +443,10 @@ def __init__(  # pylint: disable=dangerous-default-value
         if not self.config.log_model_step:
             self.config.log_model_step = self.config.save_step

+        # make sure that start_with_eval is disabled if eval is disabled
+        if not self.config.run_eval and self.start_with_eval:
+            self.start_with_eval = False
+
         self.total_steps_done = 0
         self.epochs_done = 0
         self.restore_step = 0
@@ -525,6 +529,16 @@ def __init__(  # pylint: disable=dangerous-default-value
         # setup optimizer
         self.optimizer = self.get_optimizer(self.model, self.config)

+        # Raise an error for a multiple-optimizer setup with grad accumulation but no custom optimize method
+        if (
+            self.grad_accum_steps != 1
+            and isinstance(self.optimizer, list)
+            and not isimplemented(self.model, "optimize")
+        ):
+            raise ValueError(
+                " [!] Coqui Trainer does not support grad_accum_steps for a multiple-optimizer setup. Set grad_accum_steps to 1 or implement a custom `optimize` method in your model that deals with dangling gradients in a multiple-optimizer setup!"
+            )
+
         # CALLBACK
         self.callbacks = TrainerCallback()
         self.callbacks.parse_callbacks_dict(callbacks)
@@ -1480,6 +1494,8 @@ def train_epoch(self) -> None:
         self.model.train()
         epoch_start_time = time.time()

+        self.callbacks.on_train_epoch_start(self)
+
         self.c_logger.print_train_start()
         loader_start_time = time.time()
         # TRAINING EPOCH -> iterate over the training samples
@@ -1502,6 +1518,8 @@
         torch.set_grad_enabled(True)
         epoch_time = time.time() - epoch_start_time

+        self.callbacks.on_train_epoch_end(self)
+
         # scheduler step
         if self.scheduler is not None and self.config.scheduler_after_epoch:
             if isinstance(self.scheduler, list):
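The new guard in `__init__` above rejects `grad_accum_steps > 1` combined with a list of optimizers unless the model implements its own `optimize`. A rough sketch of such an override is below; the `optimize(batch, trainer)` signature, the `(outputs, loss_dict)` return shape, and the `train_step` helper are assumptions inferred from the error message and the `isimplemented(self.model, "optimize")` check, not APIs confirmed by this diff.

```python
# A rough sketch only: the hook name comes from the ValueError above, but the
# signature, return values, and the `train_step` helper are assumptions.
from torch import nn


class MyMultiOptimizerModel(nn.Module):  # hypothetical model
    def optimize(self, batch, trainer):
        outputs, loss_dict = [], {}
        for idx, optimizer in enumerate(trainer.optimizer):
            out, losses = self.train_step(batch, trainer.criterion, optimizer_idx=idx)
            # Scale the loss so accumulated micro-batches match one large batch.
            (losses["loss"] / trainer.grad_accum_steps).backward()
            # Step only on accumulation boundaries so no optimizer is left
            # with dangling gradients between micro-batches.
            if (trainer.total_steps_done + 1) % trainer.grad_accum_steps == 0:
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)
            outputs.append(out)
            loss_dict[f"loss_{idx}"] = losses["loss"].detach()
        return outputs, loss_dict
```

The remaining trainer.py hunks below rework how the target loss is picked for checkpointing.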
@@ -1884,14 +1902,12 @@ def profile_fit(self, torch_profiler, epochs=None, small_run=None):
     def save_best_model(self) -> None:
         """Save the best model.
         It only saves if the current target loss is smaller than the previous one."""
-        eval_loss = None
-        if self.keep_avg_eval and len(self.keep_avg_eval.avg_values.keys()) > 0:
-            eval_loss = self._pick_target_avg_loss(self.keep_avg_eval)
+        eval_loss = self._pick_target_avg_loss(self.keep_avg_eval)
         train_loss = self._pick_target_avg_loss(self.keep_avg_train)

         # save the model and update the best_loss
         self.best_loss = save_best_model(
-            train_loss if eval_loss is None else eval_loss,
+            eval_loss if eval_loss else train_loss,
             self.best_loss,
             self.config,
             self.model,
@@ -1908,9 +1924,7 @@ def save_best_model(self) -> None:
     @rank_zero_only
     def save_checkpoint(self) -> None:
         """Save the current model checkpoint."""
-        eval_loss = None
-        if self.keep_avg_eval and len(self.keep_avg_eval.avg_values.keys()) > 0:
-            eval_loss = self._pick_target_avg_loss(self.keep_avg_eval)
+        eval_loss = self._pick_target_avg_loss(self.keep_avg_eval)
         train_loss = self._pick_target_avg_loss(self.keep_avg_train)

         save_checkpoint(
@@ -2101,18 +2115,21 @@ def _detach_loss_dict(loss_dict: Dict) -> Dict:
     def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict:
         """Pick the target loss to compare models"""
+
+        # return None if the keep_avg_target is None or empty
+        if keep_avg_target is None or len(list(keep_avg_target.avg_values.keys())) == 0:
+            return None
+
         target_avg_loss = None

         # return if target loss defined in the model config
         # if not available in Dict use loss_1 as by default loss
         if "target_loss" in self.config and self.config.target_loss:
             if f"avg_{self.config.target_loss}" in keep_avg_target.avg_values.keys():
                 return keep_avg_target[f"avg_{self.config.target_loss}"]
-            target_loss = keep_avg_target["avg_loss_1"]
-            if target_loss is None:
-                raise ValueError(
-                    " [!] Target loss not found in the keep_avg_target. You might be exiting the training loop before it is computed or set the target_loss in the model config incorrectly."
-                )
-            return target_loss
+
+            raise ValueError(
+                " [!] Target loss not found in the keep_avg_target. You might be exiting the training loop before it is computed, or the target_loss in the model config is set incorrectly."
+            )

         # take the average of loss_{optimizer_idx} as the target loss when there are multiple optimizers
         if isinstance(self.optimizer, list):
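With `_pick_target_avg_loss` now returning None for a missing or empty `keep_avg_eval`, best-model selection falls back to the train loss via `eval_loss if eval_loss else train_loss`. Which averaged loss drives that comparison can be steered through the `target_loss` field read above. A minimal sketch, assuming `TrainerConfig` is importable from the trainer package as a coqpit-style dataclass; the `loss_disc` name is a hypothetical example, not a value from this diff.

```python
# A minimal sketch, assuming `TrainerConfig` is importable from the trainer
# package; "loss_disc" is a hypothetical loss name whose "avg_loss_disc"
# average would be tracked by KeepAverage during evaluation.
from dataclasses import dataclass

from trainer import TrainerConfig


@dataclass
class MyModelConfig(TrainerConfig):
    # Checkpoints are compared on avg_loss_disc whenever it is available.
    target_loss: str = "loss_disc"
```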