Separate epoch validation from step validation (#5208)
* Separate epoch validation from step validation

* update system

* test

* baked logic in callbacks

* unbake logic in callbacks

* fix the call for scheduler

* use property

* pep

* correct rebase

* gitignore

* ref

* add tests

* fix

* add early stopping test

* trigger

* chlog

* rev

* 1.3

* log

* Apply suggestions from code review

Co-authored-by: Carlos Mocholí <[email protected]>

* Update pytorch_lightning/trainer/training_loop.py

* Update CHANGELOG.md

* Apply suggestions from code review

Co-authored-by: chaton <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
5 people authored Feb 8, 2021
1 parent 3b7afb9 commit e429f97
Showing 13 changed files with 194 additions and 85 deletions.
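
At a high level, this commit makes the training loop trigger epoch-boundary validation through run_evaluation(on_epoch=True), and only step epoch-level LR schedulers and fire the checkpoint/early-stopping callbacks itself when no validation will run. The sketch below restates that control flow in plain Python; the names are illustrative stand-ins, not Lightning APIs.

def epoch_end_actions(should_check_val_on_epoch, disable_validation, num_val_batches):
    """Return the ordered actions the training loop takes once an epoch's batches are done."""
    actions = []
    if should_check_val_on_epoch:
        # Epoch-boundary validation goes through run_evaluation(on_epoch=True),
        # which also updates epoch-interval LR schedulers after the eval hooks.
        actions.append("run_evaluation(on_epoch=True)")

    # With nothing to validate, the training loop does the epoch-level work itself.
    should_skip_eval = sum(num_val_batches) == 0
    if disable_validation or should_skip_eval:
        actions.append("update_learning_rates(interval='epoch')")
        actions.append("check_checkpoint_callback(True)")
        actions.append("check_early_stopping_callback(True)")
    return actions


# e.g. a train-only run: schedulers, checkpointing and early stopping
# are handled directly by the loop
print(epoch_end_actions(False, True, [0]))
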
3 changes: 3 additions & 0 deletions .gitignore
@@ -145,3 +145,6 @@ pytorch\ lightning
test-reports/
wandb
.forked/

# ctags
tags
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -4,12 +4,14 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [1.1.8] - 2021-02-06
## [1.1.8] - 2021-02-08

### Fixed

- Separate epoch validation from step validation ([#5208](https://github.com/PyTorchLightning/pytorch-lightning/pull/5208))
- Fixed `toggle_optimizers` not handling all optimizer parameters ([#5775](https://github.com/PyTorchLightning/pytorch-lightning/pull/5775))


## [1.1.7] - 2021-02-03

### Fixed
18 changes: 0 additions & 18 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -88,9 +88,6 @@ def __init__(
self.stopped_epoch = 0
self.mode = mode
self.warned_result_obj = False
# Indicates, if eval results are used as basis for early stopping
# It is set to False initially and overwritten, if eval results have been validated
self.based_on_eval_results = False

self.__init_monitor_mode()

@@ -164,21 +161,6 @@ def on_validation_end(self, trainer, pl_module):

self._run_early_stopping_check(trainer, pl_module)

def on_validation_epoch_end(self, trainer, pl_module):
if trainer.fast_dev_run or trainer.running_sanity_check:
return

if self._validate_condition_metric(trainer.callback_metrics):
# turn off early stopping in on_train_epoch_end
self.based_on_eval_results = True

def on_train_epoch_end(self, trainer, pl_module, outputs):
# disable early stopping in train loop when there's a val loop
if self.based_on_eval_results:
return

self._run_early_stopping_check(trainer, pl_module)

def _run_early_stopping_check(self, trainer, pl_module):
"""
Checks whether the early stopping condition is met
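
With on_validation_epoch_end and on_train_epoch_end stripped from EarlyStopping, stopping on a training metric now relies on the training loop invoking the callback when no validation runs (see check_early_stopping_callback in training_loop.py further down). A hedged usage sketch, assuming a toy model and the metric name train_loss:

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping


class TrainOnlyModel(LightningModule):
    """Toy module with no validation loop; EarlyStopping watches a train metric."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self(x), y)
        # aggregate to an epoch-level value so EarlyStopping can monitor it
        self.log("train_loss", loss, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

    def train_dataloader(self):
        x, y = torch.randn(64, 32), torch.randn(64, 1)
        return DataLoader(TensorDataset(x, y), batch_size=8)


if __name__ == "__main__":
    trainer = Trainer(
        max_epochs=10,
        callbacks=[EarlyStopping(monitor="train_loss", patience=3, verbose=True)],
        num_sanity_val_steps=0,
    )
    trainer.fit(TrainOnlyModel())
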
6 changes: 3 additions & 3 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -166,7 +166,7 @@ def __init__(
self.save_top_k = save_top_k
self.save_weights_only = save_weights_only
self.period = period
self.last_global_step_saved = -1
self._last_global_step_saved = -1
self.prefix = prefix
self.current_score = None
self.best_k_models = {}
@@ -231,15 +231,15 @@ def save_checkpoint(self, trainer, pl_module):
or self.period < 1 # no models are saved
or (epoch + 1) % self.period # skip epoch
or trainer.running_sanity_check # don't save anything during sanity check
or self.last_global_step_saved == global_step # already saved at the last step
or self._last_global_step_saved == global_step # already saved at the last step
):
return

self._add_backward_monitor_support(trainer)
self._validate_monitor_key(trainer)

# track epoch when ckpt was last checked
self.last_global_step_saved = global_step
self._last_global_step_saved = global_step

# what can be monitored
monitor_candidates = self._monitor_candidates(trainer)
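
Aside from becoming private, _last_global_step_saved keeps the same job: it prevents a second save at the same global step, which matters now that on_validation_end can be called both by epoch-end validation and by the training loop's forced final save. A toy sketch of the guard, not the real callback:

class DedupSaver:
    """Minimal stand-in for the duplicate-save guard in ModelCheckpoint."""

    def __init__(self):
        self._last_global_step_saved = -1  # nothing saved yet

    def maybe_save(self, global_step):
        if self._last_global_step_saved == global_step:
            return False  # already saved at this exact step, skip
        self._last_global_step_saved = global_step
        return True       # caller would write the checkpoint here


saver = DedupSaver()
assert saver.maybe_save(10) is True
assert saver.maybe_save(10) is False  # second request at the same step is skipped
assert saver.maybe_save(11) is True
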
3 changes: 2 additions & 1 deletion pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -400,6 +400,7 @@ def save_checkpoint(self, filepath, weights_only: bool = False):
if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
rank_zero_warn(
'Warning, `hyper_parameters` dropped from checkpoint.' f' An attribute is not picklable {err}'
'Warning, `hyper_parameters` dropped from checkpoint.'
f' An attribute is not picklable {err}'
)
atomic_save(checkpoint, filepath)
13 changes: 2 additions & 11 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -71,17 +71,8 @@ def get_evaluation_dataloaders(self, max_batches):

return dataloaders, max_batches

def should_skip_evaluation(self, dataloaders, max_batches):
# skip when dataloaders aren't defined
if dataloaders is None:
return True

# enable disabling validation step with limit_val_batches = 0
should_skip = sum(max_batches) == 0
if should_skip:
return True

return False
def should_skip_evaluation(self, max_batches):
return sum(max_batches) == 0

def on_evaluation_start(self, *args, **kwargs):
if self.trainer.testing:
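
should_skip_evaluation now looks only at the batch budget: when there is no val dataloader or limit_val_batches=0, the per-dataloader limits sum to zero, so the separate None check was redundant. A minimal sketch of the check, which the training loop calls with trainer.num_val_batches:

def should_skip_evaluation(max_batches):
    # max_batches holds the per-dataloader batch limits, e.g. [0] when there
    # is nothing to validate
    return sum(max_batches) == 0


assert should_skip_evaluation([0, 0])      # nothing to run
assert not should_skip_evaluation([5, 0])  # at least one dataloader has batches
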
11 changes: 6 additions & 5 deletions pytorch_lightning/trainer/trainer.py
@@ -563,9 +563,6 @@ def train(self):
if self.max_steps and self.max_steps <= self.global_step:
return

# update LR schedulers
self.optimizer_connector.update_learning_rates(interval='epoch')

# early stopping
met_min_epochs = epoch >= self.min_epochs - 1
met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
@@ -591,7 +588,7 @@
# hook
self.train_loop.on_train_end()

def run_evaluation(self, max_batches=None):
def run_evaluation(self, max_batches=None, on_epoch=False):

# used to know if we are logging for val, test + reset cached results
self.logger_connector.set_stage(self.testing, reset=True)
@@ -603,7 +600,7 @@
dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders(max_batches)

# check if we want to skip this evaluation
if self.evaluation_loop.should_skip_evaluation(dataloaders, max_batches):
if self.evaluation_loop.should_skip_evaluation(max_batches):
return [], []

# ref model
@@ -664,6 +661,10 @@ def run_evaluation(self, max_batches=None):
# hook
self.evaluation_loop.on_evaluation_epoch_end()

# update epoch-level lr_schedulers
if on_epoch:
self.optimizer_connector.update_learning_rates(interval='epoch')

# hook
self.evaluation_loop.on_evaluation_end()

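
Moving update_learning_rates(interval='epoch') from train() into run_evaluation(on_epoch=True) means epoch-interval schedulers step after the epoch's validation hooks have produced their metrics. A hedged example of the configure_optimizers shape that benefits, assuming the model logs a val_loss metric during validation:

import torch
from pytorch_lightning import LightningModule


class PlateauModel(LightningModule):
    # training_step and validation_step omitted for brevity

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")
        lr_dict = {
            "scheduler": scheduler,
            "interval": "epoch",    # stepped once per epoch, now after epoch-end validation
            "monitor": "val_loss",  # assumed name of the metric logged in validation
        }
        return [optimizer], [lr_dict]
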
78 changes: 53 additions & 25 deletions pytorch_lightning/trainer/training_loop.py
@@ -18,7 +18,7 @@
import torch
import torch.distributed as torch_distrib

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.optimizer import LightningOptimizer
@@ -153,7 +153,7 @@ def on_train_end(self):
# trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
# when a checkpoint was saved at the last step
self.trainer.global_step -= 1
self.check_checkpoint_callback(should_save=True, is_last=True)
self.check_checkpoint_callback(should_update=True, is_last=True)
self.trainer.global_step += 1

# hook
@@ -176,18 +176,27 @@ def on_train_end(self):
model.cpu()
torch.cuda.empty_cache()

def check_checkpoint_callback(self, should_save, is_last=False):
# TODO bake this logic into the checkpoint callback
if should_save and self.trainer.checkpoint_connector.has_trained:
checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]
def check_checkpoint_callback(self, should_update, is_last=False):
# TODO bake this logic into the ModelCheckpoint callback
if should_update and self.trainer.checkpoint_connector.has_trained:
callbacks = self.trainer.checkpoint_callbacks

if is_last and any(c.save_last for c in checkpoint_callbacks):
if is_last and any(cb.save_last for cb in callbacks):
rank_zero_info("Saving latest checkpoint...")

model = self.trainer.get_model()

for callback in checkpoint_callbacks:
callback.on_validation_end(self.trainer, model)
for cb in callbacks:
cb.on_validation_end(self.trainer, model)

def check_early_stopping_callback(self, should_update):
# TODO bake this logic into the EarlyStopping callback
if should_update and self.trainer.checkpoint_connector.has_trained:
callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)]
model = self.trainer.get_model()

for cb in callbacks:
cb.on_validation_end(self.trainer, model)

def on_train_epoch_start(self, epoch):

@@ -518,7 +527,6 @@ def tbptt_split_batch(self, batch):
return splits

def run_training_epoch(self):

# get model
model = self.trainer.get_model()

@@ -531,7 +539,6 @@
# enable profiling for the dataloader
train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader)
dataloader_idx = 0
should_check_val = False
for batch_idx, (batch, is_last_batch) in train_dataloader:

self.trainer.batch_idx = batch_idx
@@ -580,11 +587,12 @@
self.trainer.checkpoint_connector.has_trained = True

# max steps reached, end training
if self.trainer.max_steps is not None and self.trainer.max_steps == self.trainer.global_step + 1:
accumulation_done = self._accumulated_batches_reached()
# Ensure accumulation across batches has completed before breaking loop
if accumulation_done:
break
if (
self.trainer.max_steps is not None
and self.trainer.max_steps == self.trainer.global_step + 1
and self._accumulated_batches_reached()
):
break

# end epoch early
# stop when the flag is changed or we've gone past the amount
@@ -595,7 +603,7 @@
self.trainer.total_batch_idx += 1

# stop epoch if we limited the number of training batches
if (batch_idx + 1) >= self.trainer.num_training_batches:
if self._num_training_batches_reached(is_last_batch):
break

# progress global step according to grads progress
@@ -612,8 +620,20 @@
self.num_optimizers
)

# when no val loop is present or fast-dev-run still need to call checkpoints
self.check_checkpoint_callback(not (should_check_val or is_overridden('validation_step', model)))
should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True)
if should_check_val:
self.trainer.run_evaluation(on_epoch=True)
# reset stage to train
self.trainer.logger_connector.set_stage("train")

should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
should_train_only = self.trainer.disable_validation or should_skip_eval

if should_train_only:
# update epoch level lr_schedulers
self.trainer.optimizer_connector.update_learning_rates(interval='epoch')
self.check_checkpoint_callback(True)
self.check_early_stopping_callback(True)

# increment the global step once
# progress global step according to grads progress
@@ -853,25 +873,33 @@ def increment_accumulated_grad_global_step(self):
def _accumulated_batches_reached(self):
return (self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0

def _num_training_batches_reached(self):
return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches
def _num_training_batches_reached(self, is_last_batch=False):
return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches or is_last_batch

def should_accumulate(self):
# checks if backward or backward + optimizer step (via closure)
accumulation_done = self._accumulated_batches_reached()
is_final_batch = self._num_training_batches_reached()
return not (accumulation_done or is_final_batch)

def should_check_val_fx(self, batch_idx, is_last_batch):
def should_check_val_fx(self, batch_idx, is_last_batch, on_epoch=False):
# decide if we should run validation
is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
can_check_val = self.trainer.enable_validation and is_val_check_epoch
should_check_val = is_val_check_batch or self.trainer.should_stop
is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf")
should_check_val = can_check_val and (should_check_val or is_last_batch_for_infinite_dataset)
epoch_end_val_check = self.trainer.val_check_batch == self.trainer.num_training_batches

should_check_val = (
(is_val_check_batch and epoch_end_val_check)
or self.trainer.should_stop
or is_last_batch_for_infinite_dataset
) if on_epoch else (
is_val_check_batch
and not epoch_end_val_check
)

return should_check_val
return should_check_val and can_check_val

def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
# enable not needing to add opt_idx to training_step
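
The reworked should_check_val_fx is the core of the separation: mid-epoch calls (on_epoch=False) only fire for val_check_interval positions that do not coincide with the epoch boundary, while the epoch-end call (on_epoch=True) covers the boundary check, early-stop requests and infinite datasets. A standalone restatement of the decision, illustrative rather than the Trainer method itself:

def should_check_val(batch_idx, is_last_batch, *, on_epoch,
                     val_check_batch, num_training_batches,
                     current_epoch, check_val_every_n_epoch,
                     enable_validation, should_stop):
    is_val_check_batch = (batch_idx + 1) % val_check_batch == 0
    is_val_check_epoch = (current_epoch + 1) % check_val_every_n_epoch == 0
    can_check_val = enable_validation and is_val_check_epoch
    is_last_batch_for_infinite_dataset = is_last_batch and val_check_batch == float("inf")
    epoch_end_val_check = val_check_batch == num_training_batches

    if on_epoch:
        # epoch boundary: the regular end-of-epoch check, a stop request,
        # or the last batch of an infinite/iterable dataset
        decision = ((is_val_check_batch and epoch_end_val_check)
                    or should_stop
                    or is_last_batch_for_infinite_dataset)
    else:
        # mid-epoch: only val_check_interval positions away from the epoch end
        decision = is_val_check_batch and not epoch_end_val_check

    return decision and can_check_val


# default settings (validate once per epoch, 5 train batches):
# batch 4 of 5 mid-epoch does not validate, the epoch-end call does
common = dict(val_check_batch=5, num_training_batches=5, current_epoch=0,
              check_val_every_n_epoch=1, enable_validation=True, should_stop=False)
assert not should_check_val(3, False, on_epoch=False, **common)
assert should_check_val(4, True, on_epoch=True, **common)
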
4 changes: 2 additions & 2 deletions tests/callbacks/test_callbacks.py
@@ -86,15 +86,15 @@ def test_trainer_callback_system(torch_save):
call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
call.on_batch_end(trainer, model),
call.on_train_batch_end(trainer, model, ANY, ANY, 2, 0),
call.on_epoch_end(trainer, model),
call.on_train_epoch_end(trainer, model, ANY),
call.on_validation_start(trainer, model),
call.on_validation_epoch_start(trainer, model),
call.on_validation_batch_start(trainer, model, ANY, 0, 0),
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
call.on_validation_epoch_end(trainer, model),
call.on_validation_end(trainer, model),
call.on_save_checkpoint(trainer, model),
call.on_epoch_end(trainer, model),
call.on_train_epoch_end(trainer, model, ANY),
call.on_train_end(trainer, model),
call.on_fit_end(trainer, model),
call.teardown(trainer, model, 'fit'),
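
The reordered expectations here capture a visible behavior change: the epoch-boundary validation hooks now run before on_epoch_end and on_train_epoch_end. A small hedged sketch of a callback that leans on the new ordering (the val_loss key is an assumption about what the model logs):

from pytorch_lightning.callbacks import Callback


class PrintValAtTrainEpochEnd(Callback):
    def on_train_epoch_end(self, trainer, pl_module, outputs):
        # validation for this epoch has already run, so its metrics are visible here
        val_loss = trainer.callback_metrics.get("val_loss")
        if val_loss is not None:
            print(f"epoch {trainer.current_epoch}: val_loss={float(val_loss):.4f}")
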
39 changes: 36 additions & 3 deletions tests/callbacks/test_early_stopping.py
@@ -113,11 +113,9 @@ def test_early_stopping_patience(tmpdir, loss_values, patience, expected_stop_ep

class ModelOverrideValidationReturn(EvalModelTemplate):
validation_return_values = torch.Tensor(loss_values)
count = 0

def validation_epoch_end(self, outputs):
loss = self.validation_return_values[self.count]
self.count += 1
loss = self.validation_return_values[self.current_epoch]
return {"test_val_loss": loss}

model = ModelOverrideValidationReturn()
@@ -133,6 +131,41 @@ def validation_epoch_end(self, outputs):
assert trainer.current_epoch == expected_stop_epoch


@pytest.mark.parametrize('validation_step', ['base', None])
@pytest.mark.parametrize(
"loss_values, patience, expected_stop_epoch",
[
([6, 5, 5, 5, 5, 5], 3, 4),
([6, 5, 4, 4, 3, 3], 1, 3),
([6, 5, 6, 5, 5, 5], 3, 4),
],
)
def test_early_stopping_patience_train(tmpdir, validation_step, loss_values, patience, expected_stop_epoch):
"""Test to ensure that early stopping is not triggered before patience is exhausted."""

class ModelOverrideTrainReturn(EvalModelTemplate):
train_return_values = torch.Tensor(loss_values)

def training_epoch_end(self, outputs):
loss = self.train_return_values[self.current_epoch]
self.log('train_loss', loss)

model = ModelOverrideTrainReturn()

if validation_step is None:
model.validation_step = None

early_stop_callback = EarlyStopping(monitor="train_loss", patience=patience, verbose=True)
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[early_stop_callback],
num_sanity_val_steps=0,
max_epochs=10,
)
trainer.fit(model)
assert trainer.current_epoch == expected_stop_epoch


def test_pickling(tmpdir):
early_stopping = EarlyStopping()

3 changes: 2 additions & 1 deletion tests/checkpointing/test_checkpoint_callback_frequency.py
@@ -56,7 +56,8 @@ def test_default_checkpoint_freq(save_mock, tmpdir, epochs, val_check_interval,
default_root_dir=tmpdir,
max_epochs=epochs,
weights_summary=None,
val_check_interval=val_check_interval
val_check_interval=val_check_interval,
progress_bar_refresh_rate=0,
)
trainer.fit(model)
