From 7370e22b7b29cb9f038d620237df06e3c6aa6e34 Mon Sep 17 00:00:00 2001 From: Francesc Marti Escofet Date: Fri, 20 Dec 2024 17:52:27 +0100 Subject: [PATCH 1/4] Add compute_statistic subcommand Signed-off-by: Francesc Marti Escofet --- terratorch/cli_tools.py | 53 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/terratorch/cli_tools.py b/terratorch/cli_tools.py index a00f2ad3..547475c7 100644 --- a/terratorch/cli_tools.py +++ b/terratorch/cli_tools.py @@ -31,6 +31,8 @@ from lightning.pytorch.callbacks import BasePredictionWriter, ModelCheckpoint, RichProgressBar from lightning.pytorch.cli import ArgsType, LightningArgumentParser, LightningCLI, SaveConfigCallback from torchgeo.trainers import BaseTask +from torch.utils.data import DataLoader +from tqdm import tqdm import terratorch.datamodules import terratorch.tasks # noqa: F401 @@ -399,6 +401,13 @@ def instantiate_classes(self) -> None: import_custom_modules(custom_modules_path) + @staticmethod + def subcommands() -> dict[str, set[str]]: + existing_subcommands = LightningCLI.subcommands() + existing_subcommands["compute_statistics"] = {"datamodule"} + return existing_subcommands + + def build_lightning_cli( args: ArgsType = None, run=True, # noqa: FBT002 @@ -437,6 +446,7 @@ def build_lightning_cli( # save only state_dict as well as full state. Only state_dict will be used for exporting the model trainer_defaults={"callbacks": [CustomWriter(write_interval="batch")]}, run=run, + trainer_class=MyTrainer, ) @@ -559,3 +569,46 @@ def inference(self, file_path: Path) -> torch.Tensor: tmpdir, ) return prediction.squeeze(0) + + +class MyTrainer(Trainer): + def compute_statistics(self, datamodule: LightningDataModule, **kwargs) -> None: + datamodule.setup("fit") + original_dataloader = datamodule.train_dataloader() + if not isinstance(original_dataloader, DataLoader): + msg = "DataLoader not found in datamodule.train_dataloader()" + raise ValueError(msg) + new_dataloader = DataLoader( + dataset=original_dataloader.dataset, + batch_size=original_dataloader.batch_size, + shuffle=False, + num_workers=original_dataloader.num_workers, + collate_fn=original_dataloader.collate_fn, + pin_memory=original_dataloader.pin_memory, + drop_last=False, + ) + n_bands = original_dataloader.dataset[0]["image"].shape[0] + n_data = torch.zeros([n_bands], dtype=torch.int64) + sum_data = torch.zeros([n_bands], dtype=torch.float64) + + # First pass for mean + for batch in tqdm(new_dataloader, desc="Compute mean"): + imgs: torch.Tensor = batch["image"] + # switch batch and band dimensions and flatten + samples = imgs.transpose(0, 1).reshape(n_bands, -1).double() + sum_data += samples.sum(dim=1) + n_data += samples.shape[1] + mean = sum_data / n_data + + sum_squared = torch.zeros(n_bands, dtype=torch.float64) + for batch in tqdm(new_dataloader, desc="Compute variance"): + imgs = batch["image"] + samples = imgs.transpose(0, 1).reshape(n_bands, -1).double() + sum_squared += ((samples - mean.unsqueeze(1)) ** 2).sum(dim=1) + + variance = sum_squared / n_data + std = torch.sqrt(variance) + + torch.set_printoptions(precision=10) + logger.info(f"Dataset mean: {mean}") + logger.info(f"Dataset std: {std}") From 2d69f065cfd749428b55835bd5146e5ddeb31e4e Mon Sep 17 00:00:00 2001 From: Francesc Marti Escofet Date: Mon, 23 Dec 2024 13:12:58 +0100 Subject: [PATCH 2/4] Move compute_statistic to utils and output in yaml Signed-off-by: Francesc Marti Escofet --- terratorch/cli_tools.py | 29 ++++------------------------- terratorch/utils.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 25 deletions(-) create mode 100644 terratorch/utils.py diff --git a/terratorch/cli_tools.py b/terratorch/cli_tools.py index 547475c7..598f354f 100644 --- a/terratorch/cli_tools.py +++ b/terratorch/cli_tools.py @@ -35,6 +35,7 @@ from tqdm import tqdm import terratorch.datamodules +from terratorch.utils import compute_statistics import terratorch.tasks # noqa: F401 from terratorch.datamodules import ( # noqa: F401 GenericNonGeoClassificationDataModule, @@ -587,28 +588,6 @@ def compute_statistics(self, datamodule: LightningDataModule, **kwargs) -> None: pin_memory=original_dataloader.pin_memory, drop_last=False, ) - n_bands = original_dataloader.dataset[0]["image"].shape[0] - n_data = torch.zeros([n_bands], dtype=torch.int64) - sum_data = torch.zeros([n_bands], dtype=torch.float64) - - # First pass for mean - for batch in tqdm(new_dataloader, desc="Compute mean"): - imgs: torch.Tensor = batch["image"] - # switch batch and band dimensions and flatten - samples = imgs.transpose(0, 1).reshape(n_bands, -1).double() - sum_data += samples.sum(dim=1) - n_data += samples.shape[1] - mean = sum_data / n_data - - sum_squared = torch.zeros(n_bands, dtype=torch.float64) - for batch in tqdm(new_dataloader, desc="Compute variance"): - imgs = batch["image"] - samples = imgs.transpose(0, 1).reshape(n_bands, -1).double() - sum_squared += ((samples - mean.unsqueeze(1)) ** 2).sum(dim=1) - - variance = sum_squared / n_data - std = torch.sqrt(variance) - - torch.set_printoptions(precision=10) - logger.info(f"Dataset mean: {mean}") - logger.info(f"Dataset std: {std}") + mean, std = compute_statistics(new_dataloader) + + logger.info(yaml.dump({"means": mean, "stds": std})) diff --git a/terratorch/utils.py b/terratorch/utils.py new file mode 100644 index 00000000..631556c9 --- /dev/null +++ b/terratorch/utils.py @@ -0,0 +1,29 @@ +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + + +def compute_statistics(dataloader: DataLoader) -> tuple[list[float], list[float]]: + n_bands = dataloader.dataset[0]["image"].shape[0] + n_data = torch.zeros([n_bands], dtype=torch.int64) + sum_data = torch.zeros([n_bands], dtype=torch.float64) + + # First pass for mean + for batch in tqdm(dataloader, desc="Compute mean"): + imgs: torch.Tensor = batch["image"] + # switch batch and band dimensions and flatten + samples = imgs.transpose(0, 1).reshape(n_bands, -1).double() + sum_data += samples.sum(dim=1) + n_data += samples.shape[1] + mean = sum_data / n_data + + sum_squared = torch.zeros(n_bands, dtype=torch.float64) + for batch in tqdm(dataloader, desc="Compute variance"): + imgs = batch["image"] + samples = imgs.transpose(0, 1).reshape(n_bands, -1).double() + sum_squared += ((samples - mean.unsqueeze(1)) ** 2).sum(dim=1) + + variance = sum_squared / n_data + std = torch.sqrt(variance) + + return mean.numpy().tolist(), std.numpy().tolist() From 6b5b3d49f893f237752f8f90cc6fd26e748a80d6 Mon Sep 17 00:00:00 2001 From: Francesc Marti Escofet Date: Tue, 7 Jan 2025 11:05:05 +0100 Subject: [PATCH 3/4] Add docs and compute mask stats Signed-off-by: Francesc Marti Escofet --- terratorch/cli_tools.py | 22 ++++++++++++++++---- terratorch/utils.py | 45 +++++++++++++++++++++++++++++++++++++++-- 2 files changed, 61 insertions(+), 6 deletions(-) diff --git a/terratorch/cli_tools.py b/terratorch/cli_tools.py index 598f354f..38e69c3a 100644 --- a/terratorch/cli_tools.py +++ b/terratorch/cli_tools.py @@ -35,7 +35,7 @@ from tqdm import tqdm import terratorch.datamodules -from terratorch.utils import compute_statistics +from terratorch.utils import compute_mask_statistics, compute_statistics import terratorch.tasks # noqa: F401 from terratorch.datamodules import ( # noqa: F401 GenericNonGeoClassificationDataModule, @@ -574,6 +574,16 @@ def inference(self, file_path: Path) -> torch.Tensor: class MyTrainer(Trainer): def compute_statistics(self, datamodule: LightningDataModule, **kwargs) -> None: + """ + Compute the dataset statistics for the training dataset. + + This method will compute the mean and standard deviation of the image data and the count and percentage of each + unique value in the masks in case these are int and the mean and standard deviation of the mask values in case + these are floats. The statistics are computed using the entire training dataset and are printed to the logger. + + Please note that this method assumes that there is only one train dataloader in the datamodule and that the + dataset does not have any transforms that may introduce randomness. + """ datamodule.setup("fit") original_dataloader = datamodule.train_dataloader() if not isinstance(original_dataloader, DataLoader): @@ -588,6 +598,10 @@ def compute_statistics(self, datamodule: LightningDataModule, **kwargs) -> None: pin_memory=original_dataloader.pin_memory, drop_last=False, ) - mean, std = compute_statistics(new_dataloader) - - logger.info(yaml.dump({"means": mean, "stds": std})) + image_stats = compute_statistics(new_dataloader) + logger.info("Image statistics:") + logger.info(yaml.dump(image_stats)) + if "mask" in datamodule.train_dataloader().dataset[0]: + mask_stats = compute_mask_statistics(new_dataloader) + logger.info("Mask statistics:") + logger.info(yaml.dump(mask_stats)) diff --git a/terratorch/utils.py b/terratorch/utils.py index 631556c9..fea856c8 100644 --- a/terratorch/utils.py +++ b/terratorch/utils.py @@ -1,9 +1,12 @@ +import math +from collections import Counter + import torch from torch.utils.data import DataLoader from tqdm import tqdm -def compute_statistics(dataloader: DataLoader) -> tuple[list[float], list[float]]: +def compute_statistics(dataloader: DataLoader) -> dict[str, list[float]]: n_bands = dataloader.dataset[0]["image"].shape[0] n_data = torch.zeros([n_bands], dtype=torch.int64) sum_data = torch.zeros([n_bands], dtype=torch.float64) @@ -25,5 +28,43 @@ def compute_statistics(dataloader: DataLoader) -> tuple[list[float], list[float] variance = sum_squared / n_data std = torch.sqrt(variance) + return {"means": mean.numpy().tolist(), "stds": std.numpy().tolist()} + + +def compute_mask_statistics(dataloader: DataLoader) -> dict[int, dict[str, int | float]] | dict[str, float]: + if torch.is_floating_point(dataloader.dataset[0]["mask"]): + return compute_float_mask_statistics(dataloader) + else: + return compute_int_mask_statistics(dataloader) + + +def compute_int_mask_statistics(dataloader: DataLoader) -> dict[int, dict[str, int | float]]: + counter = Counter() + for batch in tqdm(dataloader, desc="Compute counts"): + masks: torch.Tensor = batch["mask"] + counter.update(masks.flatten().tolist()) + + stats = {} + for key, count in counter.items(): + stats[key] = {"count": count, "percentage": count / counter.total()} + return stats - return mean.numpy().tolist(), std.numpy().tolist() + +def compute_float_mask_statistics(dataloader: DataLoader) -> dict[str, float]: + n_data = 0 + total = 0.0 + + for batch in tqdm(dataloader, desc="Compute mask mean"): + masks: torch.Tensor = batch["mask"] + total += masks.sum().item() + n_data += masks.numel() + mean = total / n_data + + sum_squared = 0.0 + for batch in tqdm(dataloader, desc="Compute mask variance"): + masks = batch["mask"] + sum_squared += ((masks - mean) ** 2).sum().item() + + variance = sum_squared / n_data + std = math.sqrt(variance) + return {"mean": mean, "std": std} From 463c41eb6aaa668f63ad79936c48e29bb4a96b48 Mon Sep 17 00:00:00 2001 From: Francesc Marti Escofet Date: Tue, 7 Jan 2025 12:25:56 +0100 Subject: [PATCH 4/4] Remove train transforms Signed-off-by: Francesc Marti Escofet --- terratorch/cli_tools.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/terratorch/cli_tools.py b/terratorch/cli_tools.py index 38e69c3a..80cb488a 100644 --- a/terratorch/cli_tools.py +++ b/terratorch/cli_tools.py @@ -581,9 +581,13 @@ def compute_statistics(self, datamodule: LightningDataModule, **kwargs) -> None: unique value in the masks in case these are int and the mean and standard deviation of the mask values in case these are floats. The statistics are computed using the entire training dataset and are printed to the logger. - Please note that this method assumes that there is only one train dataloader in the datamodule and that the - dataset does not have any transforms that may introduce randomness. + Please note that this method assumes that there is only one train dataloader in the datamodule. The train + transforms are removed before computing the statistics to ensure that the statistics are computed on the raw + data without any augmentation and randomization. """ + # remove train transforms, this may not work for all datamodules + if hasattr(datamodule, "train_transform"): + datamodule.train_transform = None datamodule.setup("fit") original_dataloader = datamodule.train_dataloader() if not isinstance(original_dataloader, DataLoader):