Skip to content

Commit

Permalink
Merge pull request #330 from fmartiescofet/implement_multiple_test_dl
Browse files Browse the repository at this point in the history
Feat: Implement multiple test dataloaders in all tasks
  • Loading branch information
Joao-L-S-Almeida authored Jan 27, 2025
2 parents 375cc4a + e641e7c commit 9140dc1
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 18 deletions.
5 changes: 3 additions & 2 deletions terratorch/tasks/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,9 @@ def on_validation_epoch_end(self) -> None:
self.val_metrics.reset()

def on_test_epoch_end(self) -> None:
self.log_dict(self.test_metrics.compute(), sync_dist=True)
self.test_metrics.reset()
for metrics in self.test_metrics:
self.log_dict(metrics.compute(), sync_dist=True)
metrics.reset()

def _do_plot_samples(self, batch_index):
if not self.plot_on_val: # dont plot if self.plot_on_val is 0
Expand Down
30 changes: 25 additions & 5 deletions terratorch/tasks/classification_tasks.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import partial
from typing import Any
import logging
import lightning
Expand Down Expand Up @@ -34,6 +35,7 @@ class ClassificationTask(TerraTorchTask):
- Does not have any callbacks by default (TorchGeo tasks do early stopping by default)
- Allows the setting of optimizers in the constructor
- It provides mIoU with both Micro and Macro averaging
- Allows to evaluate on multiple test dataloaders
.. note::
* 'Micro' averaging suits overall performance evaluation but may not reflect
Expand Down Expand Up @@ -63,6 +65,7 @@ def __init__(
freeze_backbone: bool = False, # noqa: FBT001, FBT002
freeze_decoder: bool = False, # noqa: FBT002, FBT001
class_names: list[str] | None = None,
test_dataloaders_names: list[str] | None = None,
lr_overrides: dict[str, float] | None = None,
) -> None:
"""Constructor
Expand Down Expand Up @@ -99,6 +102,9 @@ def __init__(
freeze_decoder (bool, optional): Whether to freeze the decoder and segmentation head. Defaults to False.
class_names (list[str] | None, optional): List of class names passed to metrics for better naming.
Defaults to numeric ordering.
test_dataloaders_names (list[str] | None, optional): Names used to differentiate metrics when
multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None,
which assumes only one test dataloader is used.
lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr in specific
parameters. The key should be a substring of the parameter names (it will check the substring is
contained in the parameter name)and the value should be the new lr. Defaults to None.
Expand All @@ -121,7 +127,9 @@ def __init__(
self.model = model

self.train_loss_handler = LossHandler(self.train_metrics.prefix)
self.test_loss_handler = LossHandler(self.test_metrics.prefix)
self.test_loss_handler: list[LossHandler] = []
for metrics in self.test_metrics:
self.test_loss_handler.append(LossHandler(metrics.prefix))
self.val_loss_handler = LossHandler(self.val_metrics.prefix)
self.monitor = f"{self.val_metrics.prefix}loss"

Expand Down Expand Up @@ -191,7 +199,12 @@ def configure_metrics(self) -> None:
)
self.train_metrics = metrics.clone(prefix="train/")
self.val_metrics = metrics.clone(prefix="val/")
self.test_metrics = metrics.clone(prefix="test/")
if self.hparams["test_dataloaders_names"] is not None:
self.test_metrics = nn.ModuleList(
[metrics.clone(prefix=f"test/{dl_name}/") for dl_name in self.hparams["test_dataloaders_names"]]
)
else:
self.test_metrics = nn.ModuleList([metrics.clone(prefix="test/")])

def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
"""Compute the train loss and additional metrics.
Expand Down Expand Up @@ -245,10 +258,17 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
other_keys = batch.keys() - {"image", "label", "filename"}
rest = {k: batch[k] for k in other_keys}
model_output: ModelOutput = self(x, **rest)
loss = self.test_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
self.test_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
if dataloader_idx >= len(self.test_loss_handler):
msg = "You are returning more than one test dataloader but not defining enough test_dataloaders_names."
raise ValueError(msg)
loss = self.test_loss_handler[dataloader_idx].compute_loss(model_output, y, self.criterion, self.aux_loss)
self.test_loss_handler[dataloader_idx].log_loss(
partial(self.log, add_dataloader_idx=False), # We don't need the dataloader idx as prefixes are different
loss_dict=loss,
batch_size=x.shape[0],
)
y_hat_hard = to_class_prediction(model_output)
self.test_metrics.update(y_hat_hard, y)
self.test_metrics[dataloader_idx].update(y_hat_hard, y)

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
"""Compute the predicted class probabilities.
Expand Down
37 changes: 31 additions & 6 deletions terratorch/tasks/regression_tasks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""This module contains the regression task and its auxiliary classes."""

from collections.abc import Sequence
from functools import partial
from typing import Any

import logging
Expand Down Expand Up @@ -130,7 +131,8 @@ class PixelwiseRegressionTask(TerraTorchTask):
- Accepts the specification of a model factory
- Logs metrics per class
- Does not have any callbacks by default (TorchGeo tasks do early stopping by default)
- Allows the setting of optimizers in the constructor"""
- Allows the setting of optimizers in the constructor
- Allows to evaluate on multiple test dataloaders"""

def __init__(
self,
Expand All @@ -153,6 +155,7 @@ def __init__(
freeze_decoder: bool = False, # noqa: FBT001, FBT002
plot_on_val: bool | int = 10,
tiled_inference_parameters: TiledInferenceParameters | None = None,
test_dataloaders_names: list[str] | None = None,
lr_overrides: dict[str, float] | None = None,
) -> None:
"""Constructor
Expand Down Expand Up @@ -188,6 +191,9 @@ def __init__(
If true, log every epoch. Defaults to 10. If int, will plot every plot_on_val epochs.
tiled_inference_parameters (TiledInferenceParameters | None, optional): Inference parameters
used to determine if inference is done on the whole image or through tiling.
test_dataloaders_names (list[str] | None, optional): Names used to differentiate metrics when
multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None,
which assumes only one test dataloader is used.
lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr in specific
parameters. The key should be a substring of the parameter names (it will check the substring is
contained in the parameter name)and the value should be the new lr. Defaults to None.
Expand All @@ -211,7 +217,9 @@ def __init__(
self.model = model

self.train_loss_handler = LossHandler(self.train_metrics.prefix)
self.test_loss_handler = LossHandler(self.test_metrics.prefix)
self.test_loss_handler: list[LossHandler] = []
for metrics in self.test_metrics:
self.test_loss_handler.append(LossHandler(metrics.prefix))
self.val_loss_handler = LossHandler(self.val_metrics.prefix)
self.monitor = f"{self.val_metrics.prefix}loss"
self.plot_on_val = int(plot_on_val)
Expand Down Expand Up @@ -258,7 +266,17 @@ def wrap_metrics_with_ignore_index(metrics):

self.train_metrics = MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="train/")
self.val_metrics = MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="val/")
self.test_metrics = MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="test/")
if self.hparams["test_dataloaders_names"] is not None:
self.test_metrics = nn.ModuleList(
[
MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix=f"test/{dl_name}/")
for dl_name in self.hparams["test_dataloaders_names"]
]
)
else:
self.test_metrics = nn.ModuleList(
[MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="test/")]
)

def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
"""Compute the train loss and additional metrics.
Expand Down Expand Up @@ -336,10 +354,17 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
other_keys = batch.keys() - {"image", "mask", "filename"}
rest = {k: batch[k] for k in other_keys}
model_output: ModelOutput = self(x, **rest)
loss = self.test_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
self.test_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
if dataloader_idx >= len(self.test_loss_handler):
msg = "You are returning more than one test dataloader but not defining enough test_dataloaders_names."
raise ValueError(msg)
loss = self.test_loss_handler[dataloader_idx].compute_loss(model_output, y, self.criterion, self.aux_loss)
self.test_loss_handler[dataloader_idx].log_loss(
partial(self.log, add_dataloader_idx=False), # We don't need the dataloader idx as prefixes are different
loss_dict=loss,
batch_size=x.shape[0],
)
y_hat = model_output.output
self.test_metrics.update(y_hat, y)
self.test_metrics[dataloader_idx].update(y_hat, y)

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
"""Compute the predicted class probabilities.
Expand Down
5 changes: 0 additions & 5 deletions terratorch/tasks/segmentation_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,11 +268,6 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
y_hat_hard = to_segmentation_prediction(model_output)
self.test_metrics[dataloader_idx].update(y_hat_hard, y)

def on_test_epoch_end(self) -> None:
for metrics in self.test_metrics:
self.log_dict(metrics.compute(), sync_dist=True)
metrics.reset()

def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
"""Compute the validation loss and additional metrics.
Args:
Expand Down

0 comments on commit 9140dc1

Please sign in to comment.