Skip to content

Commit

Permalink
Okay, everything seems to be, and there's now an alpha version of
Browse files Browse the repository at this point in the history
showing the effect of the selected image transforms.
  • Loading branch information
alex404 committed Oct 14, 2024
1 parent af51e43 commit 87b7085
Show file tree
Hide file tree
Showing 6 changed files with 186 additions and 21 deletions.
4 changes: 2 additions & 2 deletions resources/config_templates/user/dataset/cifar10.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ imageset:
noise_transforms:
- _target_: retinal_rl.datasets.transforms.ShotNoiseTransform
lambda_range:
- ${eval:"0.5 if ${shot_noise_transform} else 1"}
- ${eval:"1.5 if ${shot_noise_transform} else 1"}
- ${eval:"0.5 if ${shot_noise_transform} else 0"}
- ${eval:"1.5 if ${shot_noise_transform} else 0"}
- _target_: retinal_rl.datasets.transforms.ContrastTransform
contrast_range:
- ${eval:"0.01 if ${contrast_noise_transform} else 1"}
Expand Down
95 changes: 91 additions & 4 deletions retinal_rl/analysis/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,104 @@
import numpy as np
import numpy.fft as fft
import seaborn as sns
import torch
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.lines import Line2D
from matplotlib.ticker import MaxNLocator
from torch import Tensor
from torchvision.utils import make_grid

from retinal_rl.models.brain import Brain
from retinal_rl.models.goal import ContextT, Goal
from retinal_rl.util import FloatArray


def plot_transforms(
    source_transforms: Dict[str, Dict[float, List[torch.Tensor]]],
    noise_transforms: Dict[str, Dict[float, List[torch.Tensor]]],
) -> Figure:
    """Plot the effects of source and noise transforms on images.

    Uses the result of the transform_base_images function. Each transform gets
    its own subplot row; within a row, the images produced at each step of the
    transform are shown side by side, ordered by increasing step value.

    Args:
    ----
        source_transforms: Maps a source transform name to a dictionary from
            step value to the list of images transformed at that step.
        noise_transforms: Same structure as ``source_transforms``, for noise
            transforms.

    Returns:
    -------
        Figure: A matplotlib Figure containing the plotted transforms. If both
            dictionaries are empty, an empty Figure is returned.

    """
    # Source rows come first, then noise rows, each with its display title.
    rows = [
        (f"Source Transform: {transform_name}", transform_data)
        for transform_name, transform_data in source_transforms.items()
    ]
    rows += [
        (f"Noise Transform: {transform_name}", transform_data)
        for transform_name, transform_data in noise_transforms.items()
    ]

    if not rows:
        # plt.subplots cannot build a grid with zero rows; return a blank figure.
        return plt.figure()

    # squeeze=False always yields a 2D array of axes, so a single transform
    # needs no special casing.
    fig, axs = plt.subplots(
        len(rows), 1, figsize=(20, 5 * len(rows)), squeeze=False
    )

    for ax, (title, transform_data) in zip(axs[:, 0], rows):
        _plot_transform_row(ax, title, transform_data)

    plt.tight_layout()
    return fig


def _plot_transform_row(
    ax: Axes,
    title: str,
    transform_data: Dict[float, List[torch.Tensor]],
) -> None:
    """Draw one transform on ``ax`` as a row of image groups, one group per step."""
    ax.set_title(title)
    ax.set_yticks([])

    steps = sorted(transform_data.keys())
    if not steps:
        # Nothing to draw for this transform; leave a titled, empty axis.
        ax.set_xticks([])
        return

    # One single-row grid per step. The row length follows the images actually
    # present for that step rather than a count taken from another transform.
    # NOTE(review): `img * 0.5 + 0.5` presumably undoes a mean=0.5/std=0.5
    # normalization - confirm against the Imageset normalization stats.
    step_grids = []
    for step in steps:
        step_images = transform_data[step]
        step_grids.append(
            make_grid(
                torch.stack([img * 0.5 + 0.5 for img in step_images]),
                nrow=len(step_images),
            )
        )
    grid = make_grid(step_grids, nrow=len(steps))

    # make_grid returns (C, H, W); imshow expects (H, W, C).
    ax.imshow(grid.permute(1, 2, 0))
    # Centre one tick under each step's group and label it with the step value.
    ax.set_xticks([(i + 0.5) * grid.shape[2] / len(steps) for i in range(len(steps))])
    ax.set_xticklabels([f"{step:.2f}" for step in steps])


def plot_brain_and_optimizers(brain: Brain, goal: Goal[ContextT]) -> Figure:
"""Visualize the Brain's connectome organized by depth and highlight optimizer targets using border colors.
Expand Down Expand Up @@ -384,11 +471,11 @@ def plot_reconstructions(
Args:
----
train_source (List[Tuple[Tensor, int]]): List of original source images and their classes.
train_input (List[Tuple[Tensor, int]]): List of original training images and their classes.
train_sources (List[Tuple[Tensor, int]]): List of original source images and their classes.
train_inputs (List[Tuple[Tensor, int]]): List of original training images and their classes.
train_estimates (List[Tuple[Tensor, int]]): List of reconstructed training images and their predicted classes.
test_source (List[Tuple[Tensor, int]]): List of original source images and their classes.
test_input (List[Tuple[Tensor, int]]): List of original test images and their classes.
test_sources (List[Tuple[Tensor, int]]): List of original source images and their classes.
test_inputs (List[Tuple[Tensor, int]]): List of original test images and their classes.
test_estimates (List[Tuple[Tensor, int]]): List of reconstructed test images and their predicted classes.
num_samples (int): The number of samples to plot.
Expand Down
55 changes: 55 additions & 0 deletions retinal_rl/analysis/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
import torch
import torch.fft as fft
import torch.nn as nn
from PIL import Image
from torch import Tensor
from torch.utils.data import DataLoader

from retinal_rl.dataset import Imageset, ImageSubset
from retinal_rl.datasets.transforms import ContinuousTransform
from retinal_rl.models.brain import Brain, get_cnn_circuit
from retinal_rl.util import (
FloatArray,
Expand All @@ -21,6 +23,59 @@
logger = logging.getLogger(__name__)


def transform_base_images(
    imageset: Imageset, num_steps: int, num_images: int
) -> Dict[str, Dict[str, Dict[float, List[Tensor]]]]:
    """Apply each continuous transform of an Imageset to a random sample of base images.

    Only transforms that are instances of ContinuousTransform are evaluated;
    any other transform in the imageset is skipped.

    Args:
    ----
        imageset (Imageset): The dataset whose base images and transforms are used.
        num_steps (int): The number of evenly spaced factors to evaluate across
            each transform's range.
        num_images (int): The number of base images to sample (drawn
            independently, so the same image may appear more than once).

    Returns:
    -------
        Dict[str, Dict[str, Dict[float, List[Tensor]]]]: A dictionary with the
            keys "source_transforms" and "noise_transforms". Each maps a
            transform name to a dictionary from step value to the list of
            transformed images (as tensors), one per sampled image.

    """
    dataset = imageset.base_dataset
    dataset_size = imageset.base_len

    # Sample the base images once so every transform is shown on the same set.
    samples: List[Image.Image] = [
        src
        for src, _label in (
            dataset[np.random.randint(dataset_size)] for _ in range(num_images)
        )
    ]

    # Tag every transform with the result category it belongs to.
    tagged: List[Tuple[str, nn.Module]] = [
        ("source_transforms", module) for module in imageset.source_transforms
    ]
    tagged.extend(("noise_transforms", module) for module in imageset.noise_transforms)

    results: Dict[str, Dict[str, Dict[float, List[Tensor]]]] = {
        "source_transforms": {},
        "noise_transforms": {},
    }

    for category, transform in tagged:
        if not isinstance(transform, ContinuousTransform):
            continue

        bounds: Tuple[float, float] = transform.trans_range
        per_step: Dict[float, List[Tensor]] = {}
        for step in np.linspace(*bounds, num_steps):
            per_step[step] = [
                imageset.to_tensor(transform.transform(img, step)) for img in samples
            ]
        results[category][transform.name] = per_step

    return results


def reconstruct_images(
device: torch.device,
brain: Brain,
Expand Down
19 changes: 10 additions & 9 deletions retinal_rl/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ def __init__(
self.normalization_stats = (normalization_mean, normalization_std)
self.fixed_transformation = fixed_transformation
self.multiplier = multiplier if fixed_transformation else 1
self._base_len = 0
self.base_len = 0
for _ in self.base_dataset:
self._base_len += 1
self.base_len += 1

if fixed_transformation:
self.transformed_dataset = self._create_fixed_dataset()
Expand All @@ -68,23 +68,24 @@ def _create_fixed_dataset(self) -> List[Tuple[Tensor, Tensor, int]]:
source_img = self.source_transforms(img)
noisy_img = self.noise_transforms(source_img)
transformed_data.append(
(self._to_tensor(source_img), self._to_tensor(noisy_img), label)
(self.to_tensor(source_img), self.to_tensor(noisy_img), label)
)
idx += 1
return transformed_data

def _to_tensor(self, img: Image.Image) -> Tensor:
tensor = tf.to_tensor(img)
def to_tensor(self, img: Image.Image) -> Tensor:
"""Convert a PIL image to a PyTorch tensor and apply normalization if needed."""
tensor: Tensor = tf.to_tensor(img)
if self.apply_normalization:
mean, std = self.normalization_stats
tensor = tf.normalize(tensor, mean, std)
return tensor

def __len__(self) -> int:
if self.fixed_transformation:
return self._base_len * self.multiplier
return self.base_len * self.multiplier
logger.warning("Length of on-the-fly transformed dataset is not really fixed.")
return self._base_len
return self.base_len

def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, int]:
if self.fixed_transformation:
Expand All @@ -98,8 +99,8 @@ def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, int]:
noisy_img = self.noise_transforms(source_img)

# Convert to tensor and normalize
source_tensor = self._to_tensor(source_img)
noisy_tensor = self._to_tensor(noisy_img)
source_tensor = self.to_tensor(source_img)
noisy_tensor = self.to_tensor(noisy_img)

return source_tensor, noisy_tensor, label

Expand Down
20 changes: 16 additions & 4 deletions retinal_rl/datasets/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,19 @@
class ContinuousTransform(nn.Module, ABC):
"""Base class for continuous image transformations."""

def __init__(self, range: Tuple[float, float]) -> None:
def __init__(self, trans_range: Tuple[float, float]) -> None:
"""Initialize the ContinuousTransform."""
super().__init__()
self.range = range
self.trans_range: Tuple[float, float] = trans_range

@property
def name(self) -> str:
    """Return a pretty, lower-case name of the transformation.

    The name is derived from the class name: a trailing "Transform" is
    dropped and the remaining CamelCase is split into space-separated words,
    e.g. ``ShotNoiseTransform`` becomes ``"shot noise"``.
    """
    import re  # local import: only this property needs it

    name = self.__class__.__name__
    # Remove the "Transform" suffix
    suffix = "Transform"
    if name.endswith(suffix):
        name = name[: -len(suffix)]
    # decamelcase: str.replace would treat the pattern as literal text, so a
    # regex substitution is required to insert the spaces.
    return re.sub(r"([a-z])([A-Z])", r"\1 \2", name).lower()

@abstractmethod
def transform(self, img: Image.Image, trans_factor: float) -> Image.Image:
Expand Down Expand Up @@ -52,7 +61,7 @@ def forward(self, img: Image.Image) -> Image.Image:
Image.Image: The transformed PIL Image.
"""
trans_factor = np.random.uniform(self.range[0], self.range[1])
trans_factor = np.random.uniform(self.trans_range[0], self.trans_range[1])
return self.transform(img, trans_factor)


Expand Down Expand Up @@ -194,7 +203,7 @@ def __init__(self, lambda_range: Tuple[float, float]) -> None:
Args:
----
lambda_range (Tuple[float, float]): Range of shot noise intensity factors. For an identity transform, set the range to (1, 1).
lambda_range (Tuple[float, float]): Range of shot noise intensity factors. For an identity transform, set the range to (0, 0) to disable the shot noise.
"""
super().__init__(lambda_range)
Expand All @@ -212,6 +221,9 @@ def transform(self, img: Image.Image, trans_factor: float) -> Image.Image:
Image.Image: The transformed PIL Image with added shot noise.
"""
if trans_factor <= 0:
return img

# Convert PIL Image to numpy array
img_array = np.array(img)

Expand Down
14 changes: 12 additions & 2 deletions runner/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,24 @@

import matplotlib.pyplot as plt
import torch
import wandb
from matplotlib.figure import Figure
from omegaconf import DictConfig

import wandb
from retinal_rl.analysis.plot import (
layer_receptive_field_plots,
plot_brain_and_optimizers,
plot_channel_statistics,
plot_histories,
plot_receptive_field_sizes,
plot_reconstructions,
plot_transforms,
)
from retinal_rl.analysis.statistics import (
cnn_statistics,
reconstruct_images,
transform_base_images,
)
from retinal_rl.analysis.statistics import cnn_statistics, reconstruct_images
from retinal_rl.dataset import Imageset
from retinal_rl.models.brain import Brain
from retinal_rl.models.goal import ContextT, Goal
Expand Down Expand Up @@ -101,11 +106,16 @@ def analyze(

# CNN analysis
cnn_analysis = cnn_statistics(device, test_set, brain, 1000)

# Initialization analysis
if epoch == 0:
rf_sizes_fig = plot_receptive_field_sizes(cnn_analysis)
_process_figure(cfg, False, rf_sizes_fig, init_dir, "receptive_field_sizes", 0)
graph_fig = plot_brain_and_optimizers(brain, goal)
_process_figure(cfg, False, graph_fig, init_dir, "brain_graph", 0)
transforms = transform_base_images(train_set, num_steps=5, num_images=2)
transforms_fig = plot_transforms(**transforms)
_process_figure(cfg, False, transforms_fig, init_dir, "transforms", 0)

for layer_name, layer_data in cnn_analysis.items():
if layer_name == "input":
Expand Down

0 comments on commit 87b7085

Please sign in to comment.