From abaa560677fc87c90bbf8ceeb5e8d0ee586e6ac5 Mon Sep 17 00:00:00 2001
From: Hang Jung Ling <106228386+HangJung97@users.noreply.github.com>
Date: Mon, 17 Jun 2024 08:42:08 +0200
Subject: [PATCH] Add `scheduler_config` to control the scheduler step interval (#165)

---
 ascent/configs/model/nnunet.yaml              |  1 +
 .../model/scheduler_config/default.yaml       | 18 ++++++++++++++++++
 ascent/models/nnunet_module.py                | 14 ++++++++++----
 3 files changed, 29 insertions(+), 4 deletions(-)
 create mode 100644 ascent/configs/model/scheduler_config/default.yaml

diff --git a/ascent/configs/model/nnunet.yaml b/ascent/configs/model/nnunet.yaml
index 66f558f..1078f41 100644
--- a/ascent/configs/model/nnunet.yaml
+++ b/ascent/configs/model/nnunet.yaml
@@ -2,6 +2,7 @@ defaults:
   - net: unet
   - optimizer: sgd
   - scheduler: polylr
+  - scheduler_config: default
   - loss: dice_ce
 
 _target_: ascent.models.nnunet_module.nnUNetLitModule
diff --git a/ascent/configs/model/scheduler_config/default.yaml b/ascent/configs/model/scheduler_config/default.yaml
new file mode 100644
index 0000000..89e725c
--- /dev/null
+++ b/ascent/configs/model/scheduler_config/default.yaml
@@ -0,0 +1,18 @@
+# The unit of the scheduler's step size; could also be 'step'.
+# 'epoch' updates the scheduler on epoch end, whereas 'step'
+# updates it after an optimizer update.
+interval: "epoch"
+# How many epochs/steps should pass between calls to
+# `scheduler.step()`. 1 corresponds to updating the learning
+# rate after every epoch/step.
+frequency: 1
+# Metric to monitor for schedulers like `ReduceLROnPlateau`.
+monitor: "val_loss"
+# If set to `True`, enforces that the value specified in 'monitor'
+# is available when the scheduler is updated, thus stopping
+# training if not found. If set to `False`, it only produces a warning.
+strict: True
+# If using the `LearningRateMonitor` callback to monitor the
+# learning rate progress, this keyword can be used to specify
+# a custom logged name.
+name: null
diff --git a/ascent/models/nnunet_module.py b/ascent/models/nnunet_module.py
index 9844514..bd997e7 100644
--- a/ascent/models/nnunet_module.py
+++ b/ascent/models/nnunet_module.py
@@ -2,7 +2,7 @@
 import time
 from collections import OrderedDict
 from pathlib import Path
-from typing import Any, Literal, Optional, Union
+from typing import Any, Dict, Literal, Optional, Union
 
 import numpy as np
 import SimpleITK as sitk
@@ -34,6 +34,7 @@ def __init__(
         optimizer: torch.optim.Optimizer,
         loss: torch.nn.Module,
         scheduler: torch.optim.lr_scheduler._LRScheduler,
+        scheduler_config: dict[str, Any],
         tta: bool = True,
         sliding_window_overlap: float = 0.5,
         sliding_window_importance_map: bool = "gaussian",
@@ -430,15 +431,20 @@ def predict_step(self, batch: dict[str, Tensor], batch_idx: int):  # noqa: D102
 
             self.save_mask(final_preds, fname, spacing, save_dir)
 
-    def configure_optimizers(self) -> dict[Literal["optimizer", "lr_scheduler"], Any]:
+    def configure_optimizers(self) -> dict[str, Any]:
         """Configures optimizers/LR schedulers.
 
         Returns:
-            A dict with an `optimizer` key, and an optional `lr_scheduler` if a scheduler is used.
+            A dict with an `optimizer` key, and an optional `lr_scheduler` entry built from `scheduler_config`.
         """
         configured_optimizer = {"optimizer": self.hparams.optimizer(params=self.parameters())}
         if self.hparams.scheduler:
-            configured_optimizer["lr_scheduler"] = self.hparams.scheduler(
+            configured_optimizer["lr_scheduler"] = dict(self.hparams.scheduler_config)
+            if configured_optimizer["lr_scheduler"].get("frequency"):
+                configured_optimizer["lr_scheduler"]["frequency"] = int(
+                    configured_optimizer["lr_scheduler"]["frequency"]
+                )
+            configured_optimizer["lr_scheduler"]["scheduler"] = self.hparams.scheduler(
                 optimizer=configured_optimizer["optimizer"]
             )
         return configured_optimizer