From 87b7085d7ceb20d729e90344986efc16f3522337 Mon Sep 17 00:00:00 2001
From: alex404
Date: Mon, 14 Oct 2024 17:56:21 +0200
Subject: [PATCH] Okay, everything seems to be working, and there's now an
 alpha version of showing the effect of the selected image transforms.

---
 .../user/dataset/cifar10.yaml                 |  4 +-
 retinal_rl/analysis/plot.py                   | 95 ++++++++++++++++++-
 retinal_rl/analysis/statistics.py             | 55 +++++++++++
 retinal_rl/dataset.py                         | 19 ++--
 retinal_rl/datasets/transforms.py             | 20 +++-
 runner/analyze.py                             | 14 ++-
 6 files changed, 186 insertions(+), 21 deletions(-)

diff --git a/resources/config_templates/user/dataset/cifar10.yaml b/resources/config_templates/user/dataset/cifar10.yaml
index 4e6b9db1..24845df4 100644
--- a/resources/config_templates/user/dataset/cifar10.yaml
+++ b/resources/config_templates/user/dataset/cifar10.yaml
@@ -12,8 +12,8 @@ imageset:
   noise_transforms:
     - _target_: retinal_rl.datasets.transforms.ShotNoiseTransform
       lambda_range:
-        - ${eval:"0.5 if ${shot_noise_transform} else 1"}
-        - ${eval:"1.5 if ${shot_noise_transform} else 1"}
+        - ${eval:"0.5 if ${shot_noise_transform} else 0"}
+        - ${eval:"1.5 if ${shot_noise_transform} else 0"}
     - _target_: retinal_rl.datasets.transforms.ContrastTransform
       contrast_range:
         - ${eval:"0.01 if ${contrast_noise_transform} else 1"}
diff --git a/retinal_rl/analysis/plot.py b/retinal_rl/analysis/plot.py
index dae4a7be..1897e12e 100644
--- a/retinal_rl/analysis/plot.py
+++ b/retinal_rl/analysis/plot.py
@@ -9,17 +9,104 @@
 import numpy as np
 import numpy.fft as fft
 import seaborn as sns
+import torch
 from matplotlib.axes import Axes
 from matplotlib.figure import Figure
 from matplotlib.lines import Line2D
 from matplotlib.ticker import MaxNLocator
 from torch import Tensor
+from torchvision.utils import make_grid
 
 from retinal_rl.models.brain import Brain
 from retinal_rl.models.goal import ContextT, Goal
 from retinal_rl.util import FloatArray
 
 
+def plot_transforms(
+    source_transforms: Dict[str, Dict[float, List[torch.Tensor]]],
+    noise_transforms: Dict[str, Dict[float, List[torch.Tensor]]],
+) -> Figure:
+    """Plot the effects of source and noise transforms on images, using the output of transform_base_images.
+
+    Args:
+    ----
+        source_transforms: A dictionary of source transforms and their effects on images.
+        noise_transforms: A dictionary of noise transforms and their effects on images.
+
+    Returns:
+    -------
+        Figure: A matplotlib Figure containing the plotted transforms.
+
+    """
+    # Determine the number of transforms and images
+    num_source_transforms = len(source_transforms)
+    num_noise_transforms = len(noise_transforms)
+    num_transforms = num_source_transforms + num_noise_transforms
+    # Infer the number of images per step from the first source transform,
+    # which is assumed to exist and to contain at least one step
+    first_transform_data = next(iter(source_transforms.values()))
+    first_step = next(iter(first_transform_data))
+    num_images = len(first_transform_data[first_step])
+
+    # Create a figure with subplots for each transform
+    fig, axs = plt.subplots(num_transforms, 1, figsize=(20, 5 * num_transforms))
+    if num_transforms == 1:
+        axs = [axs]
+
+    transform_index = 0
+
+    # Plot source transforms
+    for transform_name, transform_data in source_transforms.items():
+        ax = axs[transform_index]
+        steps = sorted(transform_data.keys())
+
+        # Create a grid of images for each step (img * 0.5 + 0.5 maps [-1, 1] to [0, 1] for display)
+        images = [
+            make_grid(
+                torch.stack([img * 0.5 + 0.5 for img in transform_data[step]]),
+                nrow=num_images,
+            )
+            for step in steps
+        ]
+        grid = make_grid(images, nrow=len(steps))
+
+        # Display the grid, labelling each column with its step value
+        ax.imshow(grid.permute(1, 2, 0))
+        ax.set_title(f"Source Transform: {transform_name}")
+        ax.set_xticks([(i + 0.5) * grid.shape[2] / len(steps) for i in range(len(steps))])
+        ax.set_xticklabels([f"{step:.2f}" for step in steps])
+        ax.set_yticks([])
+
+        transform_index += 1
+
+    # Plot noise transforms
+    for transform_name, transform_data in noise_transforms.items():
+        ax = axs[transform_index]
+        steps = sorted(transform_data.keys())
+
+        # Create a grid of images for each step (img * 0.5 + 0.5 maps [-1, 1] to [0, 1] for display)
+        images = [
+            make_grid(
+                torch.stack([img * 0.5 + 0.5 for img in transform_data[step]]),
+                nrow=num_images,
+            )
+            for step in steps
+        ]
+        grid = make_grid(images, nrow=len(steps))
+
+        # Display the grid, labelling each column with its step value
+        ax.imshow(grid.permute(1, 2, 0))
+        ax.set_title(f"Noise Transform: {transform_name}")
+        ax.set_xticks([(i + 0.5) * grid.shape[2] / len(steps) for i in range(len(steps))])
+        ax.set_xticklabels([f"{step:.2f}" for step in steps])
+        ax.set_yticks([])
+
+        transform_index += 1
+
+    plt.tight_layout()
+    return fig
+
+
 def plot_brain_and_optimizers(brain: Brain, goal: Goal[ContextT]) -> Figure:
     """Visualize the Brain's connectome organized by depth and highlight optimizer targets using border colors.
 
@@ -384,11 +471,11 @@ def plot_reconstructions(
 
     Args:
     ----
-        train_source (List[Tuple[Tensor, int]]): List of original source images and their classes.
-        train_input (List[Tuple[Tensor, int]]): List of original training images and their classes.
+        train_sources (List[Tuple[Tensor, int]]): List of original source images and their classes.
+        train_inputs (List[Tuple[Tensor, int]]): List of original training images and their classes.
         train_estimates (List[Tuple[Tensor, int]]): List of reconstructed training images and their predicted classes.
-        test_source (List[Tuple[Tensor, int]]): List of original source images and their classes.
-        test_input (List[Tuple[Tensor, int]]): List of original test images and their classes.
+        test_sources (List[Tuple[Tensor, int]]): List of original source images and their classes.
+        test_inputs (List[Tuple[Tensor, int]]): List of original test images and their classes.
         test_estimates (List[Tuple[Tensor, int]]): List of reconstructed test images and their predicted classes.
         num_samples (int): The number of samples to plot.
 
diff --git a/retinal_rl/analysis/statistics.py b/retinal_rl/analysis/statistics.py
index 38d49de8..87840b6f 100644
--- a/retinal_rl/analysis/statistics.py
+++ b/retinal_rl/analysis/statistics.py
@@ -7,10 +7,12 @@
 import torch
 import torch.fft as fft
 import torch.nn as nn
+from PIL import Image
 from torch import Tensor
 from torch.utils.data import DataLoader
 
 from retinal_rl.dataset import Imageset, ImageSubset
+from retinal_rl.datasets.transforms import ContinuousTransform
 from retinal_rl.models.brain import Brain, get_cnn_circuit
 from retinal_rl.util import (
     FloatArray,
@@ -21,6 +23,59 @@
 logger = logging.getLogger(__name__)
 
 
+def transform_base_images(
+    imageset: Imageset, num_steps: int, num_images: int
+) -> Dict[str, Dict[str, Dict[float, List[Tensor]]]]:
+    """Apply each continuous transformation to a sample of images from an Imageset.
+
+    Args:
+    ----
+        imageset (Imageset): The dataset to transform.
+        num_steps (int): The number of steps at which to sample each transformation across its range.
+        num_images (int): The number of images to sample.
+
+    Returns:
+    -------
+        Dict[str, Dict[str, Dict[float, List[Tensor]]]]: A dictionary mapping transform category to transform name to step value to transformed image tensors.
+
+    """
+    images: List[Image.Image] = []
+
+    base_dataset = imageset.base_dataset
+    base_len = imageset.base_len
+
+    for _ in range(num_images):
+        src, _ = base_dataset[np.random.randint(base_len)]
+        images.append(src)
+
+    results: Dict[str, Dict[str, Dict[float, List[Tensor]]]] = {
+        "source_transforms": {},
+        "noise_transforms": {},
+    }
+
+    transforms: List[Tuple[str, nn.Module]] = []
+    transforms += [
+        ("source_transforms", transform) for transform in imageset.source_transforms
+    ]
+    transforms += [
+        ("noise_transforms", transform) for transform in imageset.noise_transforms
+    ]
+
+    for category, transform in transforms:
+        if isinstance(transform, ContinuousTransform):
+            results[category][transform.name] = {}
+            trans_range: Tuple[float, float] = transform.trans_range
+            transform_steps = np.linspace(*trans_range, num_steps)
+            for step in transform_steps:
+                results[category][transform.name][step] = []
+                for img in images:
+                    results[category][transform.name][step].append(
+                        imageset.to_tensor(transform.transform(img, step))
+                    )
+
+    return results
+
+
 def reconstruct_images(
     device: torch.device,
     brain: Brain,
diff --git a/retinal_rl/dataset.py b/retinal_rl/dataset.py
index 82c46c1f..9b259351 100644
--- a/retinal_rl/dataset.py
+++ b/retinal_rl/dataset.py
@@ -53,9 +53,9 @@ def __init__(
         self.normalization_stats = (normalization_mean, normalization_std)
         self.fixed_transformation = fixed_transformation
         self.multiplier = multiplier if fixed_transformation else 1
-        self._base_len = 0
+        self.base_len = 0
         for _ in self.base_dataset:
-            self._base_len += 1
+            self.base_len += 1
 
         if fixed_transformation:
             self.transformed_dataset = self._create_fixed_dataset()
@@ -68,13 +68,14 @@ def _create_fixed_dataset(self) -> List[Tuple[Tensor, Tensor, int]]:
             source_img = self.source_transforms(img)
             noisy_img = self.noise_transforms(source_img)
             transformed_data.append(
-                (self._to_tensor(source_img), self._to_tensor(noisy_img), label)
+                (self.to_tensor(source_img), self.to_tensor(noisy_img), label)
             )
             idx += 1
         return transformed_data
 
-    def _to_tensor(self, img: Image.Image) -> Tensor:
-        tensor = tf.to_tensor(img)
+    def to_tensor(self, img: Image.Image) -> Tensor:
+        """Convert a PIL image to a PyTorch tensor and apply normalization if needed."""
+        tensor: Tensor = tf.to_tensor(img)
         if self.apply_normalization:
             mean, std = self.normalization_stats
             tensor = tf.normalize(tensor, mean, std)
@@ -82,9 +83,9 @@ def _to_tensor(self, img: Image.Image) -> Tensor:
 
     def __len__(self) -> int:
         if self.fixed_transformation:
-            return self._base_len * self.multiplier
+            return self.base_len * self.multiplier
         logger.warning("Length of on-the-fly transformed dataset is not really fixed.")
-        return self._base_len
+        return self.base_len
 
     def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, int]:
         if self.fixed_transformation:
@@ -98,8 +99,8 @@ def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, int]:
         noisy_img = self.noise_transforms(source_img)
 
         # Convert to tensor and normalize
-        source_tensor = self._to_tensor(source_img)
-        noisy_tensor = self._to_tensor(noisy_img)
+        source_tensor = self.to_tensor(source_img)
+        noisy_tensor = self.to_tensor(noisy_img)
 
         return source_tensor, noisy_tensor, label
 
diff --git a/retinal_rl/datasets/transforms.py b/retinal_rl/datasets/transforms.py
index 1a679113..9b9f87fd 100644
--- a/retinal_rl/datasets/transforms.py
+++ b/retinal_rl/datasets/transforms.py
@@ -19,10 +19,19 @@ class ContinuousTransform(nn.Module, ABC):
     """Base class for continuous image transformations."""
 
-    def __init__(self, range: Tuple[float, float]) -> None:
+    def __init__(self, trans_range: Tuple[float, float]) -> None:
         """Initialize the ContinuousTransform."""
         super().__init__()
-        self.range = range
+        self.trans_range: Tuple[float, float] = trans_range
+
+    @property
+    def name(self) -> str:
+        """Return a pretty name of the transformation."""
+        name = self.__class__.__name__
+        # Remove the "Transform" suffix
+        name = name.replace("Transform", "")
+        # decamelcase, e.g. "ShotNoise" -> "shot noise" (regex substitution via the re module)
+        return re.sub(r"([a-z])([A-Z])", r"\1 \2", name).lower()
 
     @abstractmethod
     def transform(self, img: Image.Image, trans_factor: float) -> Image.Image:
@@ -52,7 +61,7 @@ def forward(self, img: Image.Image) -> Image.Image:
             Image.Image: The transformed PIL Image.
 
         """
-        trans_factor = np.random.uniform(self.range[0], self.range[1])
+        trans_factor = np.random.uniform(self.trans_range[0], self.trans_range[1])
         return self.transform(img, trans_factor)
 
@@ -194,7 +203,7 @@ def __init__(self, lambda_range: Tuple[float, float]) -> None:
 
         Args:
         ----
-            lambda_range (Tuple[float, float]): Range of shot noise intensity factors. For an identity transform, set the range to (1, 1).
+            lambda_range (Tuple[float, float]): Range of shot noise intensity factors. Set the range to (0, 0) to disable shot noise (identity transform).
 
         """
         super().__init__(lambda_range)
@@ -212,6 +221,9 @@ def transform(self, img: Image.Image, trans_factor: float) -> Image.Image:
             Image.Image: The transformed PIL Image with added shot noise.
""" + if trans_factor <= 0: + return img + # Convert PIL Image to numpy array img_array = np.array(img) diff --git a/runner/analyze.py b/runner/analyze.py index affb2903..e1503df5 100644 --- a/runner/analyze.py +++ b/runner/analyze.py @@ -5,10 +5,10 @@ import matplotlib.pyplot as plt import torch -import wandb from matplotlib.figure import Figure from omegaconf import DictConfig +import wandb from retinal_rl.analysis.plot import ( layer_receptive_field_plots, plot_brain_and_optimizers, @@ -16,8 +16,13 @@ plot_histories, plot_receptive_field_sizes, plot_reconstructions, + plot_transforms, +) +from retinal_rl.analysis.statistics import ( + cnn_statistics, + reconstruct_images, + transform_base_images, ) -from retinal_rl.analysis.statistics import cnn_statistics, reconstruct_images from retinal_rl.dataset import Imageset from retinal_rl.models.brain import Brain from retinal_rl.models.goal import ContextT, Goal @@ -101,11 +106,16 @@ def analyze( # CNN analysis cnn_analysis = cnn_statistics(device, test_set, brain, 1000) + + # Initialization analysis if epoch == 0: rf_sizes_fig = plot_receptive_field_sizes(cnn_analysis) _process_figure(cfg, False, rf_sizes_fig, init_dir, "receptive_field_sizes", 0) graph_fig = plot_brain_and_optimizers(brain, goal) _process_figure(cfg, False, graph_fig, init_dir, "brain_graph", 0) + transforms = transform_base_images(train_set, num_steps=5, num_images=2) + transforms_fig = plot_transforms(**transforms) + _process_figure(cfg, False, transforms_fig, init_dir, "transforms", 0) for layer_name, layer_data in cnn_analysis.items(): if layer_name == "input":