diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py
index 6b7b2831a2e04..5c2d77af57904 100644
--- a/src/lightning/pytorch/callbacks/model_checkpoint.py
+++ b/src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -97,6 +97,7 @@ class ModelCheckpoint(Checkpoint):
             collisions unless ``enable_version_counter`` is set to False. The version counter is unrelated to the top-k
             ranking of the checkpoint, and we recommend formatting the filename to include the monitored metric to avoid
             collisions.
+        save_on_exception: Whether to save a checkpoint when an exception is raised. Default: ``False``.
         mode: one of {min, max}. If ``save_top_k != 0``, the decision to overwrite the current save file is made
             based on either the maximization or the minimization of the monitored quantity.
@@ -224,6 +225,7 @@ def __init__(
         verbose: bool = False,
         save_last: Optional[Union[bool, Literal["link"]]] = None,
         save_top_k: int = 1,
+        save_on_exception: bool = False,
         save_weights_only: bool = False,
         mode: str = "min",
         auto_insert_metric_name: bool = True,
@@ -238,6 +240,7 @@ def __init__(
         self.verbose = verbose
         self.save_last = save_last
         self.save_top_k = save_top_k
+        self.save_on_exception = save_on_exception
         self.save_weights_only = save_weights_only
         self.auto_insert_metric_name = auto_insert_metric_name
         self._save_on_train_epoch_end = save_on_train_epoch_end
@@ -338,6 +341,20 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul
         self._save_topk_checkpoint(trainer, monitor_candidates)
         self._save_last_checkpoint(trainer, monitor_candidates)
 
+    @override
+    def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None:
+        """Save a checkpoint when an exception is raised."""
+        if not self._should_save_on_exception(trainer):
+            return
+        monitor_candidates = self._monitor_candidates(trainer)
+        filepath = self.format_checkpoint_name(metrics=monitor_candidates)
+        self._save_checkpoint(trainer, filepath)
+        self._save_last_checkpoint(trainer, monitor_candidates)
+        rank_zero_info(
+            f"An exception of type {type(exception).__name__} was raised with message: {exception}, "
+            f"saved checkpoint to {filepath}"
+        )
+
     @override
     def state_dict(self) -> dict[str, Any]:
         return {
@@ -426,6 +443,14 @@ def _should_skip_saving_checkpoint(self, trainer: "pl.Trainer") -> bool:
             or self._last_global_step_saved == trainer.global_step  # already saved at the last step
         )
 
+    def _should_save_on_exception(self, trainer: "pl.Trainer") -> bool:
+        return (
+            self.save_on_exception
+            and not bool(trainer.fast_dev_run)  # disable checkpointing with fast_dev_run
+            and not trainer.sanity_checking  # don't save anything during sanity check
+            and self._last_global_step_saved != trainer.global_step  # don't save if already saved at this step
+        )
+
     def _should_save_on_train_epoch_end(self, trainer: "pl.Trainer") -> bool:
         if self._save_on_train_epoch_end is not None:
             return self._save_on_train_epoch_end
@@ -538,7 +563,7 @@ def _format_checkpoint_name(
         self,
         filename: Optional[str],
         metrics: dict[str, Tensor],
-        prefix: str = "",
+        prefix: Optional[str] = None,
         auto_insert_metric_name: bool = True,
     ) -> str:
         if not filename:
@@ -565,13 +590,17 @@ def _format_checkpoint_name(
                 metrics[name] = torch.tensor(0)
         filename = filename.format(metrics)
 
-        if prefix:
+        if prefix is not None:
             filename = self.CHECKPOINT_JOIN_CHAR.join([prefix, filename])
 
         return filename
 
     def format_checkpoint_name(
-        self, metrics: dict[str, Tensor], filename: Optional[str] = None, ver: Optional[int] = None
+        self,
+        metrics: dict[str, Tensor],
+        filename: Optional[str] = None,
+        prefix: Optional[str] = None,
+        ver: Optional[int] = None,
     ) -> str:
         """Generate a filename according to the defined template.
@@ -603,7 +632,9 @@ def format_checkpoint_name(
         """
         filename = filename or self.filename
-        filename = self._format_checkpoint_name(filename, metrics, auto_insert_metric_name=self.auto_insert_metric_name)
+        filename = self._format_checkpoint_name(
+            filename, metrics, prefix, auto_insert_metric_name=self.auto_insert_metric_name
+        )
         if ver is not None:
             filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}"))
diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
index 7b17498865889..9eeb3381d75e6 100644
--- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
+++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
@@ -35,7 +35,7 @@
 import lightning.pytorch as pl
 from lightning.fabric.utilities.cloud_io import _load as pl_load
 from lightning.pytorch import Trainer, seed_everything
-from lightning.pytorch.callbacks import ModelCheckpoint
+from lightning.pytorch.callbacks import Callback, ModelCheckpoint
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomIterableDataset
 from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
@@ -453,6 +453,12 @@ def test_model_checkpoint_format_checkpoint_name(tmp_path, monkeypatch):
     ckpt_name = ckpt.format_checkpoint_name({}, ver=3)
     assert ckpt_name == str(tmp_path / "name-v3.ckpt")
 
+    # with prefix
+    ckpt_name = ModelCheckpoint(monitor="early_stop_on", dirpath=tmp_path, filename="name").format_checkpoint_name(
+        {}, prefix="test"
+    )
+    assert ckpt_name == str(tmp_path / "test-name.ckpt")
+
     # using slashes
     ckpt = ModelCheckpoint(monitor="early_stop_on", dirpath=None, filename="{epoch}_{val/loss:.5f}")
     ckpt_name = ckpt.format_checkpoint_name({"epoch": 4, "val/loss": 0.03})
@@ -764,6 +770,420 @@ def test_ckpt_every_n_train_steps(tmp_path):
     assert set(os.listdir(tmp_path)) == set(expected)
 
 
+def test_model_checkpoint_on_exception_run_condition(tmp_path):
+    """Test the conditions under which no checkpoint is saved when an exception is raised."""
+
+    # Don't save a checkpoint if the sanity check fails
+    class TroubledModelSanityCheck(BoringModel):
+        def on_validation_start(self) -> None:
+            if self.trainer.sanity_checking:
+                raise RuntimeError("Trouble!")
+
+    model = TroubledModelSanityCheck()
+    checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="sanity_check", save_on_exception=True)
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        num_sanity_val_steps=4,
+        limit_train_batches=2,
+        callbacks=[checkpoint_callback],
+        max_epochs=2,
+        logger=False,
+    )
+
+    with pytest.raises(RuntimeError, match="Trouble!"):
+        trainer.fit(model)
+    assert not os.path.isfile(tmp_path / "exception-sanity_check.ckpt")
+
+    # Don't save a checkpoint if a fast dev run fails
+    class TroubledModelFastDevRun(BoringModel):
+        def on_train_batch_start(self, batch, batch_idx) -> None:
+            if self.trainer.fast_dev_run and batch_idx == 1:
+                raise RuntimeError("Trouble!")
+
+    model = TroubledModelFastDevRun()
+    checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="fast_dev_run", save_on_exception=True)
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        fast_dev_run=2,
+        limit_train_batches=2,
+        callbacks=[checkpoint_callback],
+        max_epochs=2,
+        logger=False,
+    )
+
+    with pytest.raises(RuntimeError, match="Trouble!"):
+        trainer.fit(model)
+    assert not os.path.isfile(tmp_path / "exception-fast_dev_run.ckpt")
+
+    # Don't save a checkpoint if one was already saved at this step
+    class TroubledModelAlreadySavedCheckpoint(BoringModel):
+        def on_train_batch_start(self, batch, batch_idx) -> None:
+            if self.trainer.global_step == 1:
+                raise RuntimeError("Trouble!")
+
+    model = TroubledModelAlreadySavedCheckpoint()
+    checkpoint_callback = ModelCheckpoint(
+        dirpath=tmp_path, filename="already_saved", save_on_exception=True, every_n_train_steps=1
+    )
+    trainer = Trainer(
+        default_root_dir=tmp_path, limit_train_batches=2, callbacks=[checkpoint_callback], max_epochs=2, logger=False
+    )
+
+    with pytest.raises(RuntimeError, match="Trouble!"):
+        trainer.fit(model)
+
+    assert not os.path.isfile(tmp_path / "exception-already_saved.ckpt")
+    assert os.path.isfile(tmp_path / "already_saved.ckpt")
+
+
+class TroubledModelInTrainingStep(BoringModel):
+    def training_step(self, batch, batch_idx):
+        if batch_idx == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledModelInValidationStep(BoringModel):
+    def validation_step(self, batch, batch_idx):
+        if not self.trainer.sanity_checking and batch_idx == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledModelBackward(BoringModel):
+    def backward(self, loss):
+        if self.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledModelOnBeforeBackward(BoringModel):
+    def on_before_backward(self, loss):
+        if self.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledModelOnAfterBackward(BoringModel):
+    def on_after_backward(self):
+        if self.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledModelOnBeforeZeroGrad(BoringModel):
+    def on_before_zero_grad(self, optimizer):
+        if self.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledModelOnFitEnd(BoringModel):
+    def on_fit_end(self):
+        raise RuntimeError("Trouble!")
+
+
+class TroubledModelOnTrainEnd(BoringModel):
+    def on_train_end(self):
+        raise RuntimeError("Trouble!")
+
+
+class TroubledModelOnValidationStart(BoringModel):
+    def on_validation_start(self):
+        if not self.trainer.sanity_checking and self.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledModelOnValidationEnd(BoringModel):
+    def on_validation_end(self):
+        if not self.trainer.sanity_checking:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledModelOnTrainBatchStart(BoringModel):
+    def on_train_batch_start(self, batch, batch_idx):
+        if batch_idx == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledModelOnTrainBatchEnd(BoringModel):
+    def on_train_batch_end(self, outputs, batch, batch_idx):
+        if batch_idx == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledModelOnTrainEpochStart(BoringModel):
+    def on_train_epoch_start(self):
+        if self.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledModelOnTrainEpochEnd(BoringModel):
+    def on_train_epoch_end(self):
+        if self.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledModelOnValidationBatchStart(BoringModel):
+    def on_validation_batch_start(self, batch, batch_idx):
+        if not self.trainer.sanity_checking and batch_idx == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledModelOnValidationBatchEnd(BoringModel):
+    def on_validation_batch_end(self, outputs, batch, batch_idx):
+        if not self.trainer.sanity_checking and batch_idx == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledModelOnValidationEpochStart(BoringModel):
+    def on_validation_epoch_start(self):
+        if not self.trainer.sanity_checking and self.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledModelOnValidationEpochEnd(BoringModel):
+    def on_validation_epoch_end(self):
+        if not self.trainer.sanity_checking and self.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledModelOnValidationModelEval(BoringModel):
+    def on_validation_model_eval(self):
+        if not self.trainer.sanity_checking and self.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledModelOnValidationModelTrain(BoringModel):
+    def on_validation_model_train(self):
+        if not self.trainer.sanity_checking and self.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledModelOnBeforeOptimizerStep(BoringModel):
+    def on_before_optimizer_step(self, optimizer):
+        if self.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledModelConfigureGradientClipping(BoringModel):
+    def configure_gradient_clipping(self, optimizer, gradient_clip_val=None, gradient_clip_algorithm=None):
+        if self.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledModelOptimizerStep(BoringModel):
+    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure=None):
+        optimizer.step(closure=optimizer_closure)
+        if self.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledModelOptimizerZeroGrad(BoringModel):
+    def optimizer_zero_grad(self, epoch, batch_idx, optimizer):
+        if self.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+@pytest.mark.parametrize(
+    "TroubledModel",
+    [
+        TroubledModelInTrainingStep,
+        TroubledModelInValidationStep,
+        TroubledModelBackward,
+        TroubledModelOnBeforeBackward,
+        TroubledModelOnAfterBackward,
+        TroubledModelOnBeforeZeroGrad,
+        TroubledModelOnFitEnd,
+        TroubledModelOnTrainEnd,
+        TroubledModelOnValidationStart,
+        TroubledModelOnValidationEnd,
+        TroubledModelOnTrainBatchStart,
+        TroubledModelOnTrainBatchEnd,
+        TroubledModelOnTrainEpochStart,
+        TroubledModelOnTrainEpochEnd,
+        TroubledModelOnValidationBatchStart,
+        TroubledModelOnValidationBatchEnd,
+        TroubledModelOnValidationEpochStart,
+        TroubledModelOnValidationEpochEnd,
+        TroubledModelOnValidationModelEval,
+        TroubledModelOnValidationModelTrain,
+        TroubledModelOnBeforeOptimizerStep,
+        TroubledModelConfigureGradientClipping,
+        TroubledModelOptimizerStep,
+        TroubledModelOptimizerZeroGrad,
+    ],
+)
+def test_model_checkpoint_on_exception_parametrized(tmp_path, TroubledModel):
+    """Test that a checkpoint is saved when an exception is raised in the LightningModule."""
+    model = TroubledModel()
+
+    checkpoint_callback = ModelCheckpoint(
+        dirpath=tmp_path, filename="exception", save_on_exception=True, every_n_epochs=7
+    )
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        callbacks=[checkpoint_callback],
+        limit_train_batches=2,
+        max_epochs=4,
+        logger=False,
+        enable_progress_bar=False,
+    )
+
+    with pytest.raises(RuntimeError, match="Trouble!"):
+        trainer.fit(model)
+
+    checkpoint_path = tmp_path / "exception.ckpt"
+
+    assert os.path.isfile(checkpoint_path)
+    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
+    assert checkpoint["state_dict"] is not None
+    assert checkpoint["state_dict"] != {}
+
+
+class TroubledCallbackOnFitEnd(Callback):
+    def on_fit_end(self, trainer, pl_module):
+        raise RuntimeError("Trouble!")
+
+
+class TroubledCallbackOnTrainBatchStart(Callback):
+    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
+        if batch_idx == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledCallbackOnTrainBatchEnd(Callback):
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+        if batch_idx == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledCallbackOnTrainEpochStart(Callback):
+    def on_train_epoch_start(self, trainer, pl_module):
+        if trainer.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledCallbackOnTrainEpochEnd(Callback):
+    def on_train_epoch_end(self, trainer, pl_module):
+        if trainer.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledCallbackOnValidationEpochStart(Callback):
+    def on_validation_epoch_start(self, trainer, pl_module):
+        if not trainer.sanity_checking and trainer.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledCallbackOnValidationEpochEnd(Callback):
+    def on_validation_epoch_end(self, trainer, pl_module):
+        if not trainer.sanity_checking and trainer.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledCallbackOnValidationBatchStart(Callback):
+    def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx):
+        if not trainer.sanity_checking and batch_idx == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledCallbackOnValidationBatchEnd(Callback):
+    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+        if not trainer.sanity_checking and batch_idx == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledCallbackOnTrainEnd(Callback):
+    def on_train_end(self, trainer, pl_module):
+        raise RuntimeError("Trouble!")
+
+
+class TroubledCallbackOnValidationStart(Callback):
+    def on_validation_start(self, trainer, pl_module):
+        if not trainer.sanity_checking:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledCallbackOnValidationEnd(Callback):
+    def on_validation_end(self, trainer, pl_module):
+        if not trainer.sanity_checking:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledCallbackOnBeforeBackward(Callback):
+    def on_before_backward(self, trainer, pl_module, loss):
+        if trainer.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledCallbackOnAfterBackward(Callback):
+    def on_after_backward(self, trainer, pl_module):
+        if trainer.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledCallbackOnBeforeOptimizerStep(Callback):
+    def on_before_optimizer_step(self, trainer, pl_module, optimizer):
+        if trainer.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledCallbackOnBeforeZeroGrad(Callback):
+    def on_before_zero_grad(self, trainer, pl_module, optimizer):
+        if trainer.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+@pytest.mark.parametrize(
+    "TroubledCallback",
+    [
+        TroubledCallbackOnFitEnd,
+        TroubledCallbackOnTrainBatchStart,
+        TroubledCallbackOnTrainBatchEnd,
+        TroubledCallbackOnTrainEpochStart,
+        TroubledCallbackOnTrainEpochEnd,
+        TroubledCallbackOnValidationEpochStart,
+        TroubledCallbackOnValidationEpochEnd,
+        TroubledCallbackOnValidationBatchStart,
+        TroubledCallbackOnValidationBatchEnd,
+        TroubledCallbackOnTrainEnd,
+        TroubledCallbackOnValidationStart,
+        TroubledCallbackOnValidationEnd,
+        TroubledCallbackOnBeforeBackward,
+        TroubledCallbackOnAfterBackward,
+        TroubledCallbackOnBeforeOptimizerStep,
+        TroubledCallbackOnBeforeZeroGrad,
+    ],
+)
+def test_model_checkpoint_on_exception_in_other_callbacks(tmp_path, TroubledCallback):
+    """Test that a checkpoint is saved when an exception is raised in another callback."""
+    model = BoringModel()
+    troubled_callback = TroubledCallback()
+
+    checkpoint_callback = ModelCheckpoint(
+        dirpath=tmp_path, filename="exception", save_on_exception=True, every_n_epochs=7
+    )
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        callbacks=[checkpoint_callback, troubled_callback],
+        max_epochs=4,
+        limit_train_batches=2,
+        logger=False,
+        enable_progress_bar=False,
+    )
+
+    with pytest.raises(RuntimeError, match="Trouble!"):
+        trainer.fit(model)
+
+    checkpoint_path = tmp_path / "exception.ckpt"
+
+    assert os.path.isfile(checkpoint_path)
+    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
+    assert checkpoint["state_dict"] is not None
+    assert checkpoint["state_dict"] != {}
+
+
 @mock.patch("lightning.pytorch.callbacks.model_checkpoint.time")
 def test_model_checkpoint_train_time_interval(mock_datetime, tmp_path) -> None:
     """Tests that the checkpoints are saved at the specified time interval."""
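
Usage note (not part of the patch): a minimal sketch of how the new save_on_exception flag and the prefix argument of format_checkpoint_name introduced above would be exercised. The dirpath, the filename template, and the metric values are illustrative assumptions, not taken from the diff.

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import ModelCheckpoint
    from lightning.pytorch.demos.boring_classes import BoringModel

    # Ask the callback to write a checkpoint if fit() is interrupted by an exception.
    checkpoint_callback = ModelCheckpoint(
        dirpath="checkpoints/",  # illustrative path
        filename="{epoch}-{step}",
        save_on_exception=True,
    )
    trainer = Trainer(max_epochs=2, limit_train_batches=2, callbacks=[checkpoint_callback])
    trainer.fit(BoringModel())

    # The new prefix argument is joined to the formatted name with CHECKPOINT_JOIN_CHAR ("-"),
    # e.g. this yields ".../exception-epoch=0-step=2.ckpt" with the template above.
    name = checkpoint_callback.format_checkpoint_name({"epoch": 0, "step": 2}, prefix="exception")
    print(name)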