Skip to content

Commit

Permalink
Add scheduler_config to control the scheduler step interval (#165)
Browse files Browse the repository at this point in the history
  • Loading branch information
HangJung97 authored Jun 17, 2024
1 parent 4ceef9b commit abaa560
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 4 deletions.
1 change: 1 addition & 0 deletions ascent/configs/model/nnunet.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ defaults:
- net: unet
- optimizer: sgd
- scheduler: polylr
- scheduler_config: default
- loss: dice_ce

_target_: ascent.models.nnunet_module.nnUNetLitModule
Expand Down
18 changes: 18 additions & 0 deletions ascent/configs/model/scheduler_config/default.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# The unit of the scheduler's step size, could also be 'step'.
# 'epoch' updates the scheduler on epoch end whereas 'step'
# updates it after a optimizer update.
interval: "epoch"
# How many epochs/steps should pass between calls to
# `scheduler.step()`. 1 corresponds to updating the learning
# rate after every epoch/step.
frequency: 1
# Metric to to monitor for schedulers like `ReduceLROnPlateau`
monitor: "val_loss"
# If set to `True`, will enforce that the value specified 'monitor'
# is available when the scheduler is updated, thus stopping
# training if not found. If set to `False`, it will only produce a warning
strict: True
# If using the `LearningRateMonitor` callback to monitor the
# learning rate progress, this keyword can be used to specify
# a custom logged name
name: null
14 changes: 10 additions & 4 deletions ascent/models/nnunet_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import time
from collections import OrderedDict
from pathlib import Path
from typing import Any, Literal, Optional, Union
from typing import Any, Dict, Literal, Optional, Union

import numpy as np
import SimpleITK as sitk
Expand Down Expand Up @@ -34,6 +34,7 @@ def __init__(
optimizer: torch.optim.Optimizer,
loss: torch.nn.Module,
scheduler: torch.optim.lr_scheduler._LRScheduler,
scheduler_config: dict[str, Any],
tta: bool = True,
sliding_window_overlap: float = 0.5,
sliding_window_importance_map: bool = "gaussian",
Expand Down Expand Up @@ -430,15 +431,20 @@ def predict_step(self, batch: dict[str, Tensor], batch_idx: int): # noqa: D102

self.save_mask(final_preds, fname, spacing, save_dir)

def configure_optimizers(self) -> dict[Literal["optimizer", "lr_scheduler"], Any]:
def configure_optimizers(self) -> dict[str, Any]:
"""Configures optimizers/LR schedulers.
Returns:
A dict with an `optimizer` key, and an optional `lr_scheduler` if a scheduler is used.
A dict with an `optimizer` key, and an optional `scheduler_config` if a scheduler is used.
"""
configured_optimizer = {"optimizer": self.hparams.optimizer(params=self.parameters())}
if self.hparams.scheduler:
configured_optimizer["lr_scheduler"] = self.hparams.scheduler(
configured_optimizer["lr_scheduler"] = dict(self.hparams.scheduler_config)
if configured_optimizer["lr_scheduler"].get("frequency"):
configured_optimizer["lr_scheduler"]["frequency"] = int(
configured_optimizer["lr_scheduler"]["frequency"]
)
configured_optimizer["lr_scheduler"]["scheduler"] = self.hparams.scheduler(
optimizer=configured_optimizer["optimizer"]
)
return configured_optimizer
Expand Down

0 comments on commit abaa560

Please sign in to comment.