@@ -796,12 +796,13 @@ def validation_step(self, batch, batch_idx):
796
796
raise RuntimeError ("Trouble!" )
797
797
798
798
model = TroubledModel ()
799
- epoch_length = 64
799
+ epoch_length = 2
800
800
checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
801
801
trainer = Trainer (
802
802
default_root_dir = tmp_path ,
803
803
callbacks = [checkpoint_callback ],
804
804
max_epochs = 5 ,
805
+ limit_train_batches = epoch_length ,
805
806
logger = False ,
806
807
enable_progress_bar = False ,
807
808
)
@@ -864,12 +865,13 @@ def on_train_epoch_start(self, trainer, pl_module):
864
865
raise RuntimeError ("Trouble!" )
865
866
866
867
model = BoringModel ()
867
- epoch_length = 64
868
+ epoch_length = 2
868
869
checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
869
870
trainer = Trainer (
870
871
default_root_dir = tmp_path ,
871
872
callbacks = [checkpoint_callback , TroublemakerOnTrainEpochStart ()],
872
873
max_epochs = 5 ,
874
+ limit_train_batches = epoch_length ,
873
875
logger = False ,
874
876
enable_progress_bar = False ,
875
877
)
@@ -887,12 +889,13 @@ def on_train_epoch_end(self, trainer, pl_module):
887
889
raise RuntimeError ("Trouble!" )
888
890
889
891
model = BoringModel ()
890
- epoch_length = 64
892
+ epoch_length = 2
891
893
checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
892
894
trainer = Trainer (
893
895
default_root_dir = tmp_path ,
894
896
callbacks = [checkpoint_callback , TroublemakerOnTrainEpochEnd ()],
895
897
max_epochs = 5 ,
898
+ limit_train_batches = epoch_length ,
896
899
logger = False ,
897
900
enable_progress_bar = False ,
898
901
)
@@ -956,12 +959,13 @@ def on_validation_epoch_start(self, trainer, pl_module):
956
959
raise RuntimeError ("Trouble!" )
957
960
958
961
model = BoringModel ()
959
- epoch_length = 64
962
+ epoch_length = 2
960
963
checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
961
964
trainer = Trainer (
962
965
default_root_dir = tmp_path ,
963
966
callbacks = [checkpoint_callback , TroublemakerOnValidationEpochStart ()],
964
967
max_epochs = 5 ,
968
+ limit_train_batches = epoch_length ,
965
969
logger = False ,
966
970
enable_progress_bar = False ,
967
971
)
@@ -979,12 +983,13 @@ def on_validation_epoch_end(self, trainer, pl_module):
979
983
raise RuntimeError ("Trouble!" )
980
984
981
985
model = BoringModel ()
982
- epoch_length = 64
986
+ epoch_length = 2
983
987
checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
984
988
trainer = Trainer (
985
989
default_root_dir = tmp_path ,
986
990
callbacks = [checkpoint_callback , TroublemakerOnValidationEpochEnd ()],
987
991
max_epochs = 5 ,
992
+ limit_train_batches = epoch_length ,
988
993
logger = False ,
989
994
enable_progress_bar = False ,
990
995
)
@@ -1002,12 +1007,13 @@ def on_validation_start(self, trainer, pl_module):
1002
1007
raise RuntimeError ("Trouble!" )
1003
1008
1004
1009
model = BoringModel ()
1005
- epoch_length = 64
1010
+ epoch_length = 2
1006
1011
checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
1007
1012
trainer = Trainer (
1008
1013
default_root_dir = tmp_path ,
1009
1014
callbacks = [checkpoint_callback , TroublemakerOnValidationStart ()],
1010
1015
max_epochs = 5 ,
1016
+ limit_train_batches = epoch_length ,
1011
1017
logger = False ,
1012
1018
enable_progress_bar = False ,
1013
1019
)
@@ -1025,12 +1031,13 @@ def on_validation_end(self, trainer, pl_module):
1025
1031
raise RuntimeError ("Trouble!" )
1026
1032
1027
1033
model = BoringModel ()
1028
- epoch_length = 64
1034
+ epoch_length = 2
1029
1035
checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
1030
1036
trainer = Trainer (
1031
1037
default_root_dir = tmp_path ,
1032
1038
callbacks = [checkpoint_callback , TroublemakerOnValidationEnd ()],
1033
1039
max_epochs = 5 ,
1040
+ limit_train_batches = epoch_length ,
1034
1041
logger = False ,
1035
1042
enable_progress_bar = False ,
1036
1043
)
0 commit comments