Add save_on_exception option to ModelCheckpoint #20916

Open
vsey wants to merge 28 commits into master from feat/ModelCheckpointException

Commits
0f73167
add saving of checkpoint if an exception is raised
vsey Jun 19, 2025
136e59a
import callback to checkpoint test file
vsey Jun 19, 2025
e0dae53
add test for exception in training callbacks
vsey Jun 19, 2025
2113acc
split test for save checkpoint on exception for exceptions in tra…
vsey Jun 19, 2025
7d750e6
add extra condition for checking if we should save on exception
vsey Jun 19, 2025
34e598a
add saving of checkpoint on exception if the exception occurs in a va…
vsey Jun 19, 2025
d4d933b
add test for save model checkpoint on exception for exception in train…
vsey Jun 19, 2025
9f6063b
disable trainer prog bar for test of model checkpoint on exception
vsey Jun 19, 2025
02477d5
model checkpoint on exception split trainer setup over two lines
vsey Jun 19, 2025
e5b0498
remove trailing bracket from should_save_on_exception condition
vsey Jun 19, 2025
8bc93e2
Merge branch 'master' into feat/ModelCheckpointException
vsey Jun 19, 2025
985c1e1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 19, 2025
f0502ec
Update save checkpoint on exception tests to use a shorter, more preci…
vsey Jun 19, 2025
99af7ed
switch default of save checkpoint on exception to false to not i…
vsey Jun 20, 2025
c092385
checkpoint on exception put callback tests into a pytest parametrization
vsey Jun 20, 2025
904bd74
change doc string to reflect new default value for save on exception …
vsey Jun 20, 2025
3a3204e
checkpoint on exception add test function for exception in callback
vsey Jun 20, 2025
0b1eb77
Merge branch 'master' into feat/ModelCheckpointException
vsey Jun 20, 2025
467c57b
add prefix option to generate checkpoint file name
vsey Jun 20, 2025
8ba6381
add exception prefix to checkpoints saved on exception
vsey Jun 20, 2025
3076ea1
add test for checkpoint name prefix
vsey Jun 21, 2025
d78ea3e
add test for exceptions at different positions in a model
vsey Jun 21, 2025
42bbac1
add description to on exception hook in model checkpoint
vsey Jun 21, 2025
c4b8063
add test to check saving on exception in all relevant callback posit…
vsey Jun 21, 2025
2ca6dab
also print exception when saving checkpoint
vsey Jun 21, 2025
9e9e580
test checkpointing on exception in various model steps
vsey Jun 21, 2025
d2f74e9
remove dividers in test_model_checkpoint
vsey Jun 21, 2025
ac33670
add test for run conditions for save checkpoint on exception
vsey Jun 21, 2025
35 changes: 33 additions & 2 deletions src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -97,6 +97,7 @@ class ModelCheckpoint(Checkpoint):
collisions unless ``enable_version_counter`` is set to False. The version counter is unrelated to the top-k
ranking of the checkpoint, and we recommend formatting the filename to include the monitored metric to avoid
collisions.
save_on_exception: Whether to save a checkpoint when an exception is raised. Default: ``False``.
mode: one of {min, max}.
If ``save_top_k != 0``, the decision to overwrite the current save file is made
based on either the maximization or the minimization of the monitored quantity.
@@ -213,6 +214,7 @@ class ModelCheckpoint(Checkpoint):
CHECKPOINT_JOIN_CHAR = "-"
CHECKPOINT_EQUALS_CHAR = "="
CHECKPOINT_NAME_LAST = "last"
CHECKPOINT_EXCEPTION_PREFIX = "exception"
FILE_EXTENSION = ".ckpt"
STARTING_VERSION = 1

@@ -224,6 +226,7 @@ def __init__(
verbose: bool = False,
save_last: Optional[Union[bool, Literal["link"]]] = None,
save_top_k: int = 1,
save_on_exception: bool = False,
save_weights_only: bool = False,
mode: str = "min",
auto_insert_metric_name: bool = True,
@@ -238,6 +241,7 @@
self.verbose = verbose
self.save_last = save_last
self.save_top_k = save_top_k
self.save_on_exception = save_on_exception
self.save_weights_only = save_weights_only
self.auto_insert_metric_name = auto_insert_metric_name
self._save_on_train_epoch_end = save_on_train_epoch_end
@@ -338,6 +342,19 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul
self._save_topk_checkpoint(trainer, monitor_candidates)
self._save_last_checkpoint(trainer, monitor_candidates)

@override
def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: Exception) -> None:
"""Save a checkpoint when an exception is raised."""
if self._should_save_on_exception(trainer):
monitor_candidates = self._monitor_candidates(trainer)
filepath = self.format_checkpoint_name(metrics=monitor_candidates, prefix=self.CHECKPOINT_EXCEPTION_PREFIX)
self._save_checkpoint(trainer, filepath)
self._save_last_checkpoint(trainer, monitor_candidates)
rank_zero_info(
f"An {type(exception).__name__} was raised with message: "
f"{str(exception)}, saved checkpoint to {filepath}"
)

@override
def state_dict(self) -> dict[str, Any]:
return {
@@ -426,6 +443,14 @@ def _should_skip_saving_checkpoint(self, trainer: "pl.Trainer") -> bool:
or self._last_global_step_saved == trainer.global_step # already saved at the last step
)

def _should_save_on_exception(self, trainer: "pl.Trainer") -> bool:
return (
self.save_on_exception
and not bool(trainer.fast_dev_run) # disable checkpointing with fast_dev_run
and not trainer.sanity_checking # don't save anything during sanity check
and self._last_global_step_saved != trainer.global_step # skip if a checkpoint was already saved at this step
)

def _should_save_on_train_epoch_end(self, trainer: "pl.Trainer") -> bool:
if self._save_on_train_epoch_end is not None:
return self._save_on_train_epoch_end
@@ -571,7 +596,11 @@ def _format_checkpoint_name(
return filename

def format_checkpoint_name(
self, metrics: dict[str, Tensor], filename: Optional[str] = None, ver: Optional[int] = None
self,
metrics: dict[str, Tensor],
filename: Optional[str] = None,
prefix: Optional[str] = None,
ver: Optional[int] = None,
) -> str:
"""Generate a filename according to the defined template.

@@ -603,7 +632,9 @@ def format_checkpoint_name(

"""
filename = filename or self.filename
filename = self._format_checkpoint_name(filename, metrics, auto_insert_metric_name=self.auto_insert_metric_name)
filename = self._format_checkpoint_name(
filename, metrics, prefix, auto_insert_metric_name=self.auto_insert_metric_name
)

if ver is not None:
filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}"))
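
To illustrate the `prefix` argument added to `format_checkpoint_name`, a small sketch (the join via `CHECKPOINT_JOIN_CHAR` mirrors the existing `_format_checkpoint_name` helper; the printed paths are assumptions, not captured output):

    import torch
    from lightning.pytorch.callbacks import ModelCheckpoint

    ckpt = ModelCheckpoint(dirpath="checkpoints/", filename="{epoch}-{step}")
    metrics = {"epoch": torch.tensor(3), "step": torch.tensor(1200)}

    # Existing behavior, no prefix:
    print(ckpt.format_checkpoint_name(metrics))
    # -> checkpoints/epoch=3-step=1200.ckpt

    # New in this PR: an optional prefix, joined with CHECKPOINT_JOIN_CHAR ("-"):
    print(ckpt.format_checkpoint_name(metrics, prefix=ModelCheckpoint.CHECKPOINT_EXCEPTION_PREFIX))
    # -> checkpoints/exception-epoch=3-step=1200.ckpt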