Skip to content

Commit 985c1e1

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 8bc93e2 commit 985c1e1

File tree

2 files changed

+108
-34
lines changed

2 files changed

+108
-34
lines changed

src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,6 @@ def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", e
350350
self._save_last_checkpoint(trainer, monitor_candidates)
351351
rank_zero_info(f"An exception was raised saved checkpoint to {filepath}")
352352

353-
354353
@override
355354
def state_dict(self) -> dict[str, Any]:
356355
return {
@@ -441,10 +440,10 @@ def _should_skip_saving_checkpoint(self, trainer: "pl.Trainer") -> bool:
441440

442441
def _should_save_on_exception(self, trainer: "pl.Trainer") -> bool:
443442
return (
444-
self.save_on_exception
445-
and not bool(trainer.fast_dev_run) # disable checkpointing with fast_dev_run
446-
and not trainer.sanity_checking # don't save anything during sanity check
447-
and not self._last_global_step_saved == trainer.global_step # already saved at the last step
443+
self.save_on_exception
444+
and not bool(trainer.fast_dev_run) # disable checkpointing with fast_dev_run
445+
and not trainer.sanity_checking # don't save anything during sanity check
446+
and self._last_global_step_saved != trainer.global_step # already saved at the last step
448447
)
449448

450449
def _should_save_on_train_epoch_end(self, trainer: "pl.Trainer") -> bool:

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 104 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -766,22 +766,30 @@ def test_ckpt_every_n_train_steps(tmp_path):
766766

767767
def test_model_checkpoint_save_on_exception_in_training_step(tmp_path):
768768
"""Test that the checkpoint is saved when an exception is raised in training_step."""
769+
769770
class TroubledModel(BoringModel):
770771
def training_step(self, batch, batch_idx):
771772
if batch_idx == 1:
772773
raise RuntimeError("Trouble!")
773774

774775
model = TroubledModel()
775776
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
776-
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback],
777-
max_epochs=5, logger=False, enable_progress_bar=False)
777+
trainer = Trainer(
778+
default_root_dir=tmp_path,
779+
callbacks=[checkpoint_callback],
780+
max_epochs=5,
781+
logger=False,
782+
enable_progress_bar=False,
783+
)
778784
with pytest.raises(RuntimeError, match="Trouble!"):
779785
trainer.fit(model)
780786
print(os.listdir(tmp_path))
781787
assert os.path.isfile(tmp_path / "step=1.ckpt")
782788

789+
783790
def test_model_checkpoint_save_on_exception_in_validation_step(tmp_path):
784791
"""Test that the checkpoint is saved when an exception is raised in validation_step."""
792+
785793
class TroubledModel(BoringModel):
786794
def validation_step(self, batch, batch_idx):
787795
if not trainer.sanity_checking and batch_idx == 0:
@@ -790,40 +798,57 @@ def validation_step(self, batch, batch_idx):
790798
model = TroubledModel()
791799
epoch_length = 64
792800
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
793-
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback],
794-
max_epochs=5, logger=False, enable_progress_bar=False)
801+
trainer = Trainer(
802+
default_root_dir=tmp_path,
803+
callbacks=[checkpoint_callback],
804+
max_epochs=5,
805+
logger=False,
806+
enable_progress_bar=False,
807+
)
795808
with pytest.raises(RuntimeError, match="Trouble!"):
796809
trainer.fit(model)
797810
assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt")
798811

799812

800813
def test_model_checkpoint_save_on_exception_in_train_callback_on_train_batch_start(tmp_path):
801814
"""Test that the checkpoint is saved when an exception is raised in a callback on train_batch_start."""
815+
802816
class TroublemakerOnTrainBatchStart(Callback):
803817
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
804818
if batch_idx == 1:
805819
raise RuntimeError("Trouble!")
806820

807821
model = BoringModel()
808822
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
809-
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainBatchStart()],
810-
max_epochs=5, logger=False, enable_progress_bar=False)
823+
trainer = Trainer(
824+
default_root_dir=tmp_path,
825+
callbacks=[checkpoint_callback, TroublemakerOnTrainBatchStart()],
826+
max_epochs=5,
827+
logger=False,
828+
enable_progress_bar=False,
829+
)
811830
with pytest.raises(RuntimeError, match="Trouble!"):
812831
trainer.fit(model)
813832
assert os.path.isfile(tmp_path / "step=1.ckpt")
814833

815834

816835
def test_model_checkpoint_save_on_exception_in_train_callback_on_train_batch_end(tmp_path):
817836
"""Test that the checkpoint is saved when an exception is raised in a callback on train_batch_end."""
837+
818838
class TroublemakerOnTrainBatchEnd(Callback):
819839
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
820840
if batch_idx == 1:
821841
raise RuntimeError("Trouble!")
822842

823843
model = BoringModel()
824844
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
825-
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainBatchEnd()],
826-
max_epochs=5, logger=False, enable_progress_bar=False)
845+
trainer = Trainer(
846+
default_root_dir=tmp_path,
847+
callbacks=[checkpoint_callback, TroublemakerOnTrainBatchEnd()],
848+
max_epochs=5,
849+
logger=False,
850+
enable_progress_bar=False,
851+
)
827852
with pytest.raises(RuntimeError, match="Trouble!"):
828853
trainer.fit(model)
829854

@@ -832,6 +857,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
832857

833858
def test_model_checkpoint_save_on_exception_in_train_callback_on_train_epoch_start(tmp_path):
834859
"""Test that the checkpoint is saved when an exception is raised in a callback on train_epoch_start."""
860+
835861
class TroublemakerOnTrainEpochStart(Callback):
836862
def on_train_epoch_start(self, trainer, pl_module):
837863
if trainer.current_epoch == 1:
@@ -840,15 +866,21 @@ def on_train_epoch_start(self, trainer, pl_module):
840866
model = BoringModel()
841867
epoch_length = 64
842868
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
843-
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainEpochStart()],
844-
max_epochs=5, logger=False, enable_progress_bar=False)
869+
trainer = Trainer(
870+
default_root_dir=tmp_path,
871+
callbacks=[checkpoint_callback, TroublemakerOnTrainEpochStart()],
872+
max_epochs=5,
873+
logger=False,
874+
enable_progress_bar=False,
875+
)
845876
with pytest.raises(RuntimeError, match="Trouble!"):
846877
trainer.fit(model)
847878
assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt")
848879

849880

850881
def test_model_checkpoint_save_on_exception_in_train_callback_on_train_epoch_end(tmp_path):
851882
"""Test that the checkpoint is saved when an exception is raised in a callback on train_epoch_end."""
883+
852884
class TroublemakerOnTrainEpochEnd(Callback):
853885
def on_train_epoch_end(self, trainer, pl_module):
854886
if trainer.current_epoch == 1:
@@ -857,49 +889,67 @@ def on_train_epoch_end(self, trainer, pl_module):
857889
model = BoringModel()
858890
epoch_length = 64
859891
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
860-
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainEpochEnd()],
861-
max_epochs=5, logger=False, enable_progress_bar=False)
892+
trainer = Trainer(
893+
default_root_dir=tmp_path,
894+
callbacks=[checkpoint_callback, TroublemakerOnTrainEpochEnd()],
895+
max_epochs=5,
896+
logger=False,
897+
enable_progress_bar=False,
898+
)
862899
with pytest.raises(RuntimeError, match="Trouble!"):
863900
trainer.fit(model)
864-
assert os.path.isfile(tmp_path / f"step={2*epoch_length}.ckpt")
901+
assert os.path.isfile(tmp_path / f"step={2 * epoch_length}.ckpt")
865902

866903

867904
def test_model_checkpoint_save_on_exception_in_val_callback(tmp_path):
868905
"""Test that the checkpoint is saved when an exception is raised in a callback on validation_batch_start."""
906+
869907
class TroublemakerOnValidationBatchStart(Callback):
870908
def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx):
871-
if not trainer.sanity_checking and batch_idx == 1:
872-
raise RuntimeError("Trouble!")
909+
if not trainer.sanity_checking and batch_idx == 1:
910+
raise RuntimeError("Trouble!")
873911

874912
model = BoringModel()
875913
epoch_length = 64
876914
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
877-
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationBatchStart()],
878-
max_epochs=5, logger=False, enable_progress_bar=False)
915+
trainer = Trainer(
916+
default_root_dir=tmp_path,
917+
callbacks=[checkpoint_callback, TroublemakerOnValidationBatchStart()],
918+
max_epochs=5,
919+
logger=False,
920+
enable_progress_bar=False,
921+
)
879922
with pytest.raises(RuntimeError, match="Trouble!"):
880923
trainer.fit(model)
881924
assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt")
882925

883926

884927
def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_batch_end(tmp_path):
885928
"""Test that the checkpoint is saved when an exception is raised in a callback on validation_batch_end."""
929+
886930
class TroublemakerOnValidationBatchEnd(Callback):
887931
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
888-
if not trainer.sanity_checking and batch_idx == 1:
889-
raise RuntimeError("Trouble!")
932+
if not trainer.sanity_checking and batch_idx == 1:
933+
raise RuntimeError("Trouble!")
890934

891935
model = BoringModel()
892936
epoch_length = 64
893937
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
894-
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationBatchEnd()],
895-
max_epochs=5, logger=False, enable_progress_bar=False)
938+
trainer = Trainer(
939+
default_root_dir=tmp_path,
940+
callbacks=[checkpoint_callback, TroublemakerOnValidationBatchEnd()],
941+
max_epochs=5,
942+
logger=False,
943+
enable_progress_bar=False,
944+
)
896945
with pytest.raises(RuntimeError, match="Trouble!"):
897946
trainer.fit(model)
898947
assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt")
899948

900949

901950
def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_epoch_start(tmp_path):
902951
"""Test that the checkpoint is saved when an exception is raised in a callback on validation_epoch_start."""
952+
903953
class TroublemakerOnValidationEpochStart(Callback):
904954
def on_validation_epoch_start(self, trainer, pl_module):
905955
if not trainer.sanity_checking and trainer.current_epoch == 0:
@@ -908,15 +958,21 @@ def on_validation_epoch_start(self, trainer, pl_module):
908958
model = BoringModel()
909959
epoch_length = 64
910960
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
911-
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationEpochStart()],
912-
max_epochs=5, logger=False, enable_progress_bar=False)
961+
trainer = Trainer(
962+
default_root_dir=tmp_path,
963+
callbacks=[checkpoint_callback, TroublemakerOnValidationEpochStart()],
964+
max_epochs=5,
965+
logger=False,
966+
enable_progress_bar=False,
967+
)
913968
with pytest.raises(RuntimeError, match="Trouble!"):
914969
trainer.fit(model)
915970
assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt")
916971

917972

918973
def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_epoch_end(tmp_path):
919974
"""Test that the checkpoint is saved when an exception is raised in a callback on validation_epoch_end."""
975+
920976
class TroublemakerOnValidationEpochEnd(Callback):
921977
def on_validation_epoch_end(self, trainer, pl_module):
922978
if not trainer.sanity_checking and trainer.current_epoch == 0:
@@ -925,14 +981,21 @@ def on_validation_epoch_end(self, trainer, pl_module):
925981
model = BoringModel()
926982
epoch_length = 64
927983
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
928-
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationEpochEnd()],
929-
max_epochs=5, logger=False, enable_progress_bar=False)
984+
trainer = Trainer(
985+
default_root_dir=tmp_path,
986+
callbacks=[checkpoint_callback, TroublemakerOnValidationEpochEnd()],
987+
max_epochs=5,
988+
logger=False,
989+
enable_progress_bar=False,
990+
)
930991
with pytest.raises(RuntimeError, match="Trouble!"):
931992
trainer.fit(model)
932993
assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt")
933994

995+
934996
def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_start(tmp_path):
935997
"""Test that the checkpoint is saved when an exception is raised in a callback on validation_start."""
998+
936999
class TroublemakerOnValidationStart(Callback):
9371000
def on_validation_start(self, trainer, pl_module):
9381001
if not trainer.sanity_checking:
@@ -941,14 +1004,21 @@ def on_validation_start(self, trainer, pl_module):
9411004
model = BoringModel()
9421005
epoch_length = 64
9431006
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
944-
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationStart()],
945-
max_epochs=5, logger=False, enable_progress_bar=False)
1007+
trainer = Trainer(
1008+
default_root_dir=tmp_path,
1009+
callbacks=[checkpoint_callback, TroublemakerOnValidationStart()],
1010+
max_epochs=5,
1011+
logger=False,
1012+
enable_progress_bar=False,
1013+
)
9461014
with pytest.raises(RuntimeError, match="Trouble!"):
9471015
trainer.fit(model)
9481016
assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt")
9491017

1018+
9501019
def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_end(tmp_path):
9511020
"""Test that the checkpoint is saved when an exception is raised in a callback on validation_end."""
1021+
9521022
class TroublemakerOnValidationEnd(Callback):
9531023
def on_validation_end(self, trainer, pl_module):
9541024
if not trainer.sanity_checking:
@@ -957,8 +1027,13 @@ def on_validation_end(self, trainer, pl_module):
9571027
model = BoringModel()
9581028
epoch_length = 64
9591029
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
960-
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationEnd()],
961-
max_epochs=5, logger=False, enable_progress_bar=False)
1030+
trainer = Trainer(
1031+
default_root_dir=tmp_path,
1032+
callbacks=[checkpoint_callback, TroublemakerOnValidationEnd()],
1033+
max_epochs=5,
1034+
logger=False,
1035+
enable_progress_bar=False,
1036+
)
9621037
with pytest.raises(RuntimeError, match="Trouble!"):
9631038
trainer.fit(model)
9641039
assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt")

0 commit comments

Comments
 (0)