@@ -766,22 +766,30 @@ def test_ckpt_every_n_train_steps(tmp_path):
766
766
767
767
def test_model_checkpoint_save_on_exception_in_training_step(tmp_path):
    """Test that the checkpoint is saved when an exception is raised in training_step.

    A ``BoringModel`` subclass raises ``RuntimeError("Trouble!")`` from
    ``training_step`` for the batch with index 1. The ``ModelCheckpoint``
    callback is created with ``save_on_exception=True`` and the test asserts
    that ``step=1.ckpt`` exists in ``tmp_path`` once ``trainer.fit`` has
    propagated the error.
    """

    class TroubledModel(BoringModel):
        def training_step(self, batch, batch_idx):
            if batch_idx == 1:
                raise RuntimeError("Trouble!")

    model = TroubledModel()
    # NOTE(review): every_n_epochs=4 presumably keeps the periodic checkpoint from
    # firing before the failure, so the asserted file can only come from the
    # exception path -- confirm against ModelCheckpoint's scheduling rules.
    checkpoint_callback = ModelCheckpoint(
        dirpath=tmp_path,
        filename="{step}",
        save_on_exception=True,
        every_n_epochs=4,
    )
    trainer = Trainer(
        default_root_dir=tmp_path,
        callbacks=[checkpoint_callback],
        max_epochs=5,
        logger=False,
        enable_progress_bar=False,
    )
    with pytest.raises(RuntimeError, match="Trouble!"):
        trainer.fit(model)

    # The directory listing is reported only on failure instead of being
    # printed unconditionally on every run.
    assert os.path.isfile(tmp_path / "step=1.ckpt"), (
        f"expected 'step=1.ckpt' in {tmp_path}, found: {sorted(os.listdir(tmp_path))}"
    )
782
788
789
+
783
790
def test_model_checkpoint_save_on_exception_in_validation_step (tmp_path ):
784
791
"""Test that the checkpoint is saved when an exception is raised in validation_step."""
792
+
785
793
class TroubledModel (BoringModel ):
786
794
def validation_step (self , batch , batch_idx ):
787
795
if not trainer .sanity_checking and batch_idx == 0 :
@@ -790,40 +798,57 @@ def validation_step(self, batch, batch_idx):
790
798
model = TroubledModel ()
791
799
epoch_length = 64
792
800
checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
793
- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback ],
794
- max_epochs = 5 , logger = False , enable_progress_bar = False )
801
+ trainer = Trainer (
802
+ default_root_dir = tmp_path ,
803
+ callbacks = [checkpoint_callback ],
804
+ max_epochs = 5 ,
805
+ logger = False ,
806
+ enable_progress_bar = False ,
807
+ )
795
808
with pytest .raises (RuntimeError , match = "Trouble!" ):
796
809
trainer .fit (model )
797
810
assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
798
811
799
812
800
813
def test_model_checkpoint_save_on_exception_in_train_callback_on_train_batch_start(tmp_path):
    """Check that a checkpoint is written when a callback raises in ``on_train_batch_start``.

    The failing callback raises ``RuntimeError("Trouble!")`` for the batch with
    index 1; afterwards ``step=1.ckpt`` must exist inside ``tmp_path``.
    """

    class RaiseOnSecondBatchStart(Callback):
        """Callback that raises when the batch with index 1 is about to start."""

        def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
            if batch_idx != 1:
                return
            raise RuntimeError("Trouble!")

    boring_model = BoringModel()
    ckpt_cb = ModelCheckpoint(
        dirpath=tmp_path,
        filename="{step}",
        save_on_exception=True,
        every_n_epochs=4,
    )
    trainer_kwargs = {
        "default_root_dir": tmp_path,
        "callbacks": [ckpt_cb, RaiseOnSecondBatchStart()],
        "max_epochs": 5,
        "logger": False,
        "enable_progress_bar": False,
    }
    runner = Trainer(**trainer_kwargs)

    # The error must still propagate to the caller of ``fit``.
    with pytest.raises(RuntimeError, match="Trouble!"):
        runner.fit(boring_model)

    expected_ckpt = tmp_path / "step=1.ckpt"
    assert os.path.isfile(expected_ckpt)
814
833
815
834
816
835
def test_model_checkpoint_save_on_exception_in_train_callback_on_train_batch_end (tmp_path ):
817
836
"""Test that the checkpoint is saved when an exception is raised in a callback on train_batch_end."""
837
+
818
838
class TroublemakerOnTrainBatchEnd (Callback ):
819
839
def on_train_batch_end (self , trainer , pl_module , outputs , batch , batch_idx ):
820
840
if batch_idx == 1 :
821
841
raise RuntimeError ("Trouble!" )
822
842
823
843
model = BoringModel ()
824
844
checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
825
- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnTrainBatchEnd ()],
826
- max_epochs = 5 , logger = False , enable_progress_bar = False )
845
+ trainer = Trainer (
846
+ default_root_dir = tmp_path ,
847
+ callbacks = [checkpoint_callback , TroublemakerOnTrainBatchEnd ()],
848
+ max_epochs = 5 ,
849
+ logger = False ,
850
+ enable_progress_bar = False ,
851
+ )
827
852
with pytest .raises (RuntimeError , match = "Trouble!" ):
828
853
trainer .fit (model )
829
854
@@ -832,6 +857,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
832
857
833
858
def test_model_checkpoint_save_on_exception_in_train_callback_on_train_epoch_start (tmp_path ):
834
859
"""Test that the checkpoint is saved when an exception is raised in a callback on train_epoch_start."""
860
+
835
861
class TroublemakerOnTrainEpochStart (Callback ):
836
862
def on_train_epoch_start (self , trainer , pl_module ):
837
863
if trainer .current_epoch == 1 :
@@ -840,15 +866,21 @@ def on_train_epoch_start(self, trainer, pl_module):
840
866
model = BoringModel ()
841
867
epoch_length = 64
842
868
checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
843
- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnTrainEpochStart ()],
844
- max_epochs = 5 , logger = False , enable_progress_bar = False )
869
+ trainer = Trainer (
870
+ default_root_dir = tmp_path ,
871
+ callbacks = [checkpoint_callback , TroublemakerOnTrainEpochStart ()],
872
+ max_epochs = 5 ,
873
+ logger = False ,
874
+ enable_progress_bar = False ,
875
+ )
845
876
with pytest .raises (RuntimeError , match = "Trouble!" ):
846
877
trainer .fit (model )
847
878
assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
848
879
849
880
850
881
def test_model_checkpoint_save_on_exception_in_train_callback_on_train_epoch_end (tmp_path ):
851
882
"""Test that the checkpoint is saved when an exception is raised in a callback on train_epoch_end."""
883
+
852
884
class TroublemakerOnTrainEpochEnd (Callback ):
853
885
def on_train_epoch_end (self , trainer , pl_module ):
854
886
if trainer .current_epoch == 1 :
@@ -857,49 +889,67 @@ def on_train_epoch_end(self, trainer, pl_module):
857
889
model = BoringModel ()
858
890
epoch_length = 64
859
891
checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
860
- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnTrainEpochEnd ()],
861
- max_epochs = 5 , logger = False , enable_progress_bar = False )
892
+ trainer = Trainer (
893
+ default_root_dir = tmp_path ,
894
+ callbacks = [checkpoint_callback , TroublemakerOnTrainEpochEnd ()],
895
+ max_epochs = 5 ,
896
+ logger = False ,
897
+ enable_progress_bar = False ,
898
+ )
862
899
with pytest .raises (RuntimeError , match = "Trouble!" ):
863
900
trainer .fit (model )
864
- assert os .path .isfile (tmp_path / f"step={ 2 * epoch_length } .ckpt" )
901
+ assert os .path .isfile (tmp_path / f"step={ 2 * epoch_length } .ckpt" )
865
902
866
903
867
904
def test_model_checkpoint_save_on_exception_in_val_callback(tmp_path):
    """Check that a checkpoint is written when a callback raises in ``on_validation_batch_start``.

    The failing callback ignores the sanity-check pass and raises
    ``RuntimeError("Trouble!")`` for the validation batch with index 1. The test
    then expects a checkpoint named after step 64 in ``tmp_path``.
    """

    class RaiseOnSecondValBatchStart(Callback):
        """Callback that raises before validation batch 1, outside sanity checking."""

        def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx):
            if trainer.sanity_checking or batch_idx != 1:
                return
            raise RuntimeError("Trouble!")

    boring_model = BoringModel()
    # NOTE(review): presumably the number of training steps in one BoringModel
    # epoch -- confirm against the BoringModel dataloader length.
    steps_per_epoch = 64
    ckpt_cb = ModelCheckpoint(
        dirpath=tmp_path,
        filename="{step}",
        save_on_exception=True,
        every_n_epochs=4,
    )
    runner = Trainer(
        default_root_dir=tmp_path,
        callbacks=[ckpt_cb, RaiseOnSecondValBatchStart()],
        max_epochs=5,
        logger=False,
        enable_progress_bar=False,
    )

    with pytest.raises(RuntimeError, match="Trouble!"):
        runner.fit(boring_model)

    expected_ckpt = tmp_path / f"step={steps_per_epoch}.ckpt"
    assert os.path.isfile(expected_ckpt)
882
925
883
926
884
927
def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_batch_end(tmp_path):
    """Check that a checkpoint is written when a callback raises in ``on_validation_batch_end``.

    The failing callback skips the sanity-check pass and raises
    ``RuntimeError("Trouble!")`` once the validation batch with index 1 has
    finished. The test then expects a checkpoint named after step 64 in
    ``tmp_path``.
    """

    class RaiseOnSecondValBatchEnd(Callback):
        """Callback that raises after validation batch 1, outside sanity checking."""

        def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
            should_fail = (not trainer.sanity_checking) and batch_idx == 1
            if should_fail:
                raise RuntimeError("Trouble!")

    boring_model = BoringModel()
    # NOTE(review): presumably the number of training steps in one BoringModel
    # epoch -- confirm against the BoringModel dataloader length.
    steps_per_epoch = 64
    ckpt_cb = ModelCheckpoint(
        dirpath=tmp_path,
        filename="{step}",
        save_on_exception=True,
        every_n_epochs=4,
    )
    all_callbacks = [ckpt_cb, RaiseOnSecondValBatchEnd()]
    runner = Trainer(
        default_root_dir=tmp_path,
        callbacks=all_callbacks,
        max_epochs=5,
        logger=False,
        enable_progress_bar=False,
    )

    with pytest.raises(RuntimeError, match="Trouble!"):
        runner.fit(boring_model)

    ckpt_name = f"step={steps_per_epoch}.ckpt"
    assert os.path.isfile(tmp_path / ckpt_name)
899
948
900
949
901
950
def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_epoch_start (tmp_path ):
902
951
"""Test that the checkpoint is saved when an exception is raised in a callback on validation_epoch_start."""
952
+
903
953
class TroublemakerOnValidationEpochStart (Callback ):
904
954
def on_validation_epoch_start (self , trainer , pl_module ):
905
955
if not trainer .sanity_checking and trainer .current_epoch == 0 :
@@ -908,15 +958,21 @@ def on_validation_epoch_start(self, trainer, pl_module):
908
958
model = BoringModel ()
909
959
epoch_length = 64
910
960
checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
911
- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationEpochStart ()],
912
- max_epochs = 5 , logger = False , enable_progress_bar = False )
961
+ trainer = Trainer (
962
+ default_root_dir = tmp_path ,
963
+ callbacks = [checkpoint_callback , TroublemakerOnValidationEpochStart ()],
964
+ max_epochs = 5 ,
965
+ logger = False ,
966
+ enable_progress_bar = False ,
967
+ )
913
968
with pytest .raises (RuntimeError , match = "Trouble!" ):
914
969
trainer .fit (model )
915
970
assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
916
971
917
972
918
973
def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_epoch_end (tmp_path ):
919
974
"""Test that the checkpoint is saved when an exception is raised in a callback on validation_epoch_end."""
975
+
920
976
class TroublemakerOnValidationEpochEnd (Callback ):
921
977
def on_validation_epoch_end (self , trainer , pl_module ):
922
978
if not trainer .sanity_checking and trainer .current_epoch == 0 :
@@ -925,14 +981,21 @@ def on_validation_epoch_end(self, trainer, pl_module):
925
981
model = BoringModel ()
926
982
epoch_length = 64
927
983
checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
928
- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationEpochEnd ()],
929
- max_epochs = 5 , logger = False , enable_progress_bar = False )
984
+ trainer = Trainer (
985
+ default_root_dir = tmp_path ,
986
+ callbacks = [checkpoint_callback , TroublemakerOnValidationEpochEnd ()],
987
+ max_epochs = 5 ,
988
+ logger = False ,
989
+ enable_progress_bar = False ,
990
+ )
930
991
with pytest .raises (RuntimeError , match = "Trouble!" ):
931
992
trainer .fit (model )
932
993
assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
933
994
995
+
934
996
def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_start (tmp_path ):
935
997
"""Test that the checkpoint is saved when an exception is raised in a callback on validation_start."""
998
+
936
999
class TroublemakerOnValidationStart (Callback ):
937
1000
def on_validation_start (self , trainer , pl_module ):
938
1001
if not trainer .sanity_checking :
@@ -941,14 +1004,21 @@ def on_validation_start(self, trainer, pl_module):
941
1004
model = BoringModel ()
942
1005
epoch_length = 64
943
1006
checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
944
- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationStart ()],
945
- max_epochs = 5 , logger = False , enable_progress_bar = False )
1007
+ trainer = Trainer (
1008
+ default_root_dir = tmp_path ,
1009
+ callbacks = [checkpoint_callback , TroublemakerOnValidationStart ()],
1010
+ max_epochs = 5 ,
1011
+ logger = False ,
1012
+ enable_progress_bar = False ,
1013
+ )
946
1014
with pytest .raises (RuntimeError , match = "Trouble!" ):
947
1015
trainer .fit (model )
948
1016
assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
949
1017
1018
+
950
1019
def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_end (tmp_path ):
951
1020
"""Test that the checkpoint is saved when an exception is raised in a callback on validation_end."""
1021
+
952
1022
class TroublemakerOnValidationEnd (Callback ):
953
1023
def on_validation_end (self , trainer , pl_module ):
954
1024
if not trainer .sanity_checking :
@@ -957,8 +1027,13 @@ def on_validation_end(self, trainer, pl_module):
957
1027
model = BoringModel ()
958
1028
epoch_length = 64
959
1029
checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
960
- trainer = Trainer (default_root_dir = tmp_path , callbacks = [checkpoint_callback , TroublemakerOnValidationEnd ()],
961
- max_epochs = 5 , logger = False , enable_progress_bar = False )
1030
+ trainer = Trainer (
1031
+ default_root_dir = tmp_path ,
1032
+ callbacks = [checkpoint_callback , TroublemakerOnValidationEnd ()],
1033
+ max_epochs = 5 ,
1034
+ logger = False ,
1035
+ enable_progress_bar = False ,
1036
+ )
962
1037
with pytest .raises (RuntimeError , match = "Trouble!" ):
963
1038
trainer .fit (model )
964
1039
assert os .path .isfile (tmp_path / f"step={ epoch_length } .ckpt" )
0 commit comments