Skip to content

Commit f0502ec

Browse files
committed
Update save checkpoint on exception tests to use a shorter more precisly defined epoch lenght
1 parent 985c1e1 commit f0502ec

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -796,12 +796,13 @@ def validation_step(self, batch, batch_idx):
796796
raise RuntimeError("Trouble!")
797797

798798
model = TroubledModel()
799-
epoch_length = 64
799+
epoch_length = 2
800800
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
801801
trainer = Trainer(
802802
default_root_dir=tmp_path,
803803
callbacks=[checkpoint_callback],
804804
max_epochs=5,
805+
limit_train_batches=epoch_length,
805806
logger=False,
806807
enable_progress_bar=False,
807808
)
@@ -864,12 +865,13 @@ def on_train_epoch_start(self, trainer, pl_module):
864865
raise RuntimeError("Trouble!")
865866

866867
model = BoringModel()
867-
epoch_length = 64
868+
epoch_length = 2
868869
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
869870
trainer = Trainer(
870871
default_root_dir=tmp_path,
871872
callbacks=[checkpoint_callback, TroublemakerOnTrainEpochStart()],
872873
max_epochs=5,
874+
limit_train_batches=epoch_length,
873875
logger=False,
874876
enable_progress_bar=False,
875877
)
@@ -887,12 +889,13 @@ def on_train_epoch_end(self, trainer, pl_module):
887889
raise RuntimeError("Trouble!")
888890

889891
model = BoringModel()
890-
epoch_length = 64
892+
epoch_length = 2
891893
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
892894
trainer = Trainer(
893895
default_root_dir=tmp_path,
894896
callbacks=[checkpoint_callback, TroublemakerOnTrainEpochEnd()],
895897
max_epochs=5,
898+
limit_train_batches=epoch_length,
896899
logger=False,
897900
enable_progress_bar=False,
898901
)
@@ -956,12 +959,13 @@ def on_validation_epoch_start(self, trainer, pl_module):
956959
raise RuntimeError("Trouble!")
957960

958961
model = BoringModel()
959-
epoch_length = 64
962+
epoch_length = 2
960963
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
961964
trainer = Trainer(
962965
default_root_dir=tmp_path,
963966
callbacks=[checkpoint_callback, TroublemakerOnValidationEpochStart()],
964967
max_epochs=5,
968+
limit_train_batches=epoch_length,
965969
logger=False,
966970
enable_progress_bar=False,
967971
)
@@ -979,12 +983,13 @@ def on_validation_epoch_end(self, trainer, pl_module):
979983
raise RuntimeError("Trouble!")
980984

981985
model = BoringModel()
982-
epoch_length = 64
986+
epoch_length = 2
983987
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
984988
trainer = Trainer(
985989
default_root_dir=tmp_path,
986990
callbacks=[checkpoint_callback, TroublemakerOnValidationEpochEnd()],
987991
max_epochs=5,
992+
limit_train_batches=epoch_length,
988993
logger=False,
989994
enable_progress_bar=False,
990995
)
@@ -1002,12 +1007,13 @@ def on_validation_start(self, trainer, pl_module):
10021007
raise RuntimeError("Trouble!")
10031008

10041009
model = BoringModel()
1005-
epoch_length = 64
1010+
epoch_length = 2
10061011
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
10071012
trainer = Trainer(
10081013
default_root_dir=tmp_path,
10091014
callbacks=[checkpoint_callback, TroublemakerOnValidationStart()],
10101015
max_epochs=5,
1016+
limit_train_batches=epoch_length,
10111017
logger=False,
10121018
enable_progress_bar=False,
10131019
)
@@ -1025,12 +1031,13 @@ def on_validation_end(self, trainer, pl_module):
10251031
raise RuntimeError("Trouble!")
10261032

10271033
model = BoringModel()
1028-
epoch_length = 64
1034+
epoch_length = 2
10291035
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
10301036
trainer = Trainer(
10311037
default_root_dir=tmp_path,
10321038
callbacks=[checkpoint_callback, TroublemakerOnValidationEnd()],
10331039
max_epochs=5,
1040+
limit_train_batches=epoch_length,
10341041
logger=False,
10351042
enable_progress_bar=False,
10361043
)

0 commit comments

Comments
 (0)