Skip to content

Commit

Permalink
Merge pull request #336 from fmartiescofet/compute_statistics
Browse files Browse the repository at this point in the history
Feat: Add `compute_statistics` subcommand
  • Loading branch information
Joao-L-S-Almeida authored Jan 28, 2025
2 parents 9140dc1 + d1bacbd commit 3cfee48
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 0 deletions.
50 changes: 50 additions & 0 deletions terratorch/cli_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,11 @@
from lightning.pytorch.callbacks import BasePredictionWriter, ModelCheckpoint, RichProgressBar
from lightning.pytorch.cli import ArgsType, LightningArgumentParser, LightningCLI, SaveConfigCallback
from torchgeo.trainers import BaseTask
from torch.utils.data import DataLoader
from tqdm import tqdm

import terratorch.datamodules
from terratorch.utils import compute_mask_statistics, compute_statistics
import terratorch.tasks # noqa: F401
from terratorch.datamodules import ( # noqa: F401
GenericNonGeoClassificationDataModule,
Expand Down Expand Up @@ -407,6 +410,13 @@ def instantiate_classes(self) -> None:

import_custom_modules(custom_modules_path)

@staticmethod
def subcommands() -> dict[str, set[str]]:
    """Return the CLI subcommand table, extended with ``compute_statistics``.

    Starts from the table returned by ``LightningCLI.subcommands()`` and adds
    a ``compute_statistics`` entry whose value is the set ``{"datamodule"}``.

    Returns:
        The subcommand mapping including the additional entry.
    """
    registry = LightningCLI.subcommands()
    registry.update({"compute_statistics": {"datamodule"}})
    return registry


def build_lightning_cli(
args: ArgsType = None,
run=True, # noqa: FBT002
Expand Down Expand Up @@ -445,6 +455,7 @@ def build_lightning_cli(
# save only state_dict as well as full state. Only state_dict will be used for exporting the model
trainer_defaults={"callbacks": [CustomWriter(write_interval="batch")]},
run=run,
trainer_class=MyTrainer,
)


Expand Down Expand Up @@ -567,3 +578,42 @@ def inference(self, file_path: Path) -> torch.Tensor:
tmpdir,
)
return prediction.squeeze(0)


class MyTrainer(Trainer):
    """Trainer subclass that adds a ``compute_statistics`` method."""

    def compute_statistics(self, datamodule: LightningDataModule, **kwargs) -> None:
        """
        Compute the dataset statistics for the training dataset.

        This method will compute the mean and standard deviation of the image data and the count and percentage of
        each unique value in the masks in case these are int and the mean and standard deviation of the mask values
        in case these are floats. The statistics are computed using the entire training dataset and are written to
        the logger as YAML.

        Please note that this method assumes that there is only one train dataloader in the datamodule. The train
        transforms are removed before computing the statistics to ensure that the statistics are computed on the raw
        data without any augmentation and randomization.

        Args:
            datamodule: Datamodule providing the training data. ``setup("fit")`` is called on it, and its
                ``train_transform`` attribute (if present) is set to ``None``.
            **kwargs: Accepted for call compatibility; not used.

        Raises:
            ValueError: If ``datamodule.train_dataloader()`` does not return a single ``DataLoader``.
        """
        # remove train transforms, this may not work for all datamodules
        if hasattr(datamodule, "train_transform"):
            datamodule.train_transform = None
        datamodule.setup("fit")

        original_dataloader = datamodule.train_dataloader()
        if not isinstance(original_dataloader, DataLoader):
            msg = "DataLoader not found in datamodule.train_dataloader()"
            raise ValueError(msg)

        # Rebuild the loader so that every sample is visited exactly once, in order:
        # no shuffling and no dropped trailing batch.
        dataset = original_dataloader.dataset
        new_dataloader = DataLoader(
            dataset=dataset,
            batch_size=original_dataloader.batch_size,
            shuffle=False,
            num_workers=original_dataloader.num_workers,
            collate_fn=original_dataloader.collate_fn,
            pin_memory=original_dataloader.pin_memory,
            drop_last=False,
        )

        image_stats = compute_statistics(new_dataloader)
        logger.info("Image statistics:")
        logger.info(yaml.dump(image_stats))

        # Inspect the dataset of the loader that was already validated instead of
        # asking the datamodule to build a second train dataloader.
        if "mask" in dataset[0]:
            mask_stats = compute_mask_statistics(new_dataloader)
            logger.info("Mask statistics:")
            logger.info(yaml.dump(mask_stats))
70 changes: 70 additions & 0 deletions terratorch/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import math
from collections import Counter

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm


def compute_statistics(dataloader: DataLoader) -> dict[str, list[float]]:
    """Compute the per-band mean and standard deviation of the ``"image"`` entries.

    The data is traversed twice: once to accumulate the per-band mean and once
    to accumulate the squared deviations from that mean. Accumulation is done
    in float64. The standard deviation is the population one (divided by the
    number of values, not by that number minus one).

    Args:
        dataloader: Loader whose batches are mappings with an ``"image"``
            tensor. The band count is read from dimension 0 of the first
            dataset sample, and batches are treated as (batch, band, ...).

    Returns:
        A dict with keys ``"means"`` and ``"stds"``, each a list holding one
        float per band.
    """
    n_bands = dataloader.dataset[0]["image"].shape[0]

    def band_major(batch) -> torch.Tensor:
        # Swap batch and band dimensions, then flatten everything but the band axis.
        return batch["image"].transpose(0, 1).reshape(n_bands, -1).double()

    # Pass 1: per-band sum and value count, giving the mean.
    value_count = torch.zeros([n_bands], dtype=torch.int64)
    band_total = torch.zeros([n_bands], dtype=torch.float64)
    for batch in tqdm(dataloader, desc="Compute mean"):
        pixels = band_major(batch)
        band_total += pixels.sum(dim=1)
        value_count += pixels.shape[1]
    mean = band_total / value_count

    # Pass 2: per-band sum of squared deviations from the mean.
    mean_column = mean.unsqueeze(1)
    squared_deviation = torch.zeros(n_bands, dtype=torch.float64)
    for batch in tqdm(dataloader, desc="Compute variance"):
        squared_deviation += ((band_major(batch) - mean_column) ** 2).sum(dim=1)

    std = torch.sqrt(squared_deviation / value_count)
    return {"means": mean.numpy().tolist(), "stds": std.numpy().tolist()}


def compute_mask_statistics(dataloader: DataLoader) -> dict[int, dict[str, int | float]] | dict[str, float]:
    """Compute mask statistics, choosing the method from the mask dtype.

    The ``"mask"`` entry of the first dataset sample decides the path:
    floating-point masks are handled by ``compute_float_mask_statistics``,
    all other masks by ``compute_int_mask_statistics``.

    Args:
        dataloader: Loader whose dataset samples contain a ``"mask"`` tensor.

    Returns:
        Whatever the selected helper returns for ``dataloader``.
    """
    first_mask = dataloader.dataset[0]["mask"]
    if not torch.is_floating_point(first_mask):
        return compute_int_mask_statistics(dataloader)
    return compute_float_mask_statistics(dataloader)


def compute_int_mask_statistics(dataloader: DataLoader) -> dict[int, dict[str, int | float]]:
    """Count how often each distinct value occurs in the ``"mask"`` entries.

    Args:
        dataloader: Loader whose batches are mappings with a non-floating-point
            ``"mask"`` tensor.

    Returns:
        A dict keyed by mask value. Each entry holds ``"count"`` (number of
        occurrences) and ``"percentage"`` (``count`` divided by the total
        number of mask values, i.e. a fraction between 0 and 1). The result
        is empty when the dataloader yields no batches.
    """
    counter = Counter()
    for batch in tqdm(dataloader, desc="Compute counts"):
        masks: torch.Tensor = batch["mask"]
        # Count inside torch and move only the distinct values to Python,
        # rather than materialising every mask element as a Python object.
        values, counts = torch.unique(masks, return_counts=True)
        counter.update(dict(zip(values.tolist(), counts.tolist())))

    # Compute the grand total once instead of once per distinct value.
    total = sum(counter.values())
    return {key: {"count": count, "percentage": count / total} for key, count in counter.items()}


def compute_float_mask_statistics(dataloader: DataLoader) -> dict[str, float]:
    """Compute the mean and standard deviation over all ``"mask"`` values.

    The data is traversed twice: once for the mean and once for the squared
    deviations from it. Masks are promoted to float64 before each reduction,
    matching ``compute_statistics``, so that reductions over large batches do
    not lose precision in a lower-precision dtype. The standard deviation is
    the population one (divided by the number of values).

    Args:
        dataloader: Loader whose batches are mappings with a floating-point
            ``"mask"`` tensor.

    Returns:
        A dict with the keys ``"mean"`` and ``"std"``.

    Raises:
        ZeroDivisionError: If the dataloader yields no mask values.
    """
    n_data = 0
    total = 0.0

    # Pass 1: overall sum and element count, giving the mean.
    for batch in tqdm(dataloader, desc="Compute mask mean"):
        masks: torch.Tensor = batch["mask"].double()
        total += masks.sum().item()
        n_data += masks.numel()
    mean = total / n_data

    # Pass 2: sum of squared deviations from the mean.
    sum_squared = 0.0
    for batch in tqdm(dataloader, desc="Compute mask variance"):
        masks = batch["mask"].double()
        sum_squared += ((masks - mean) ** 2).sum().item()

    variance = sum_squared / n_data
    std = math.sqrt(variance)
    return {"mean": mean, "std": std}

0 comments on commit 3cfee48

Please sign in to comment.