Skip to content

Commit

Permalink
refactor: continue analysis/plot restructuring
Browse files Browse the repository at this point in the history
  • Loading branch information
fabioseel committed Dec 13, 2024
1 parent 78f4d4d commit 6f5861e
Show file tree
Hide file tree
Showing 9 changed files with 189 additions and 193 deletions.
50 changes: 39 additions & 11 deletions retinal_rl/analysis/channel_analysis.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@



import logging
from dataclasses import dataclass
from typing import cast
Expand All @@ -23,6 +20,7 @@

logger = logging.getLogger(__name__)


@dataclass
class SpectralAnalysis:
"""Results of spectral analysis for a layer."""
Expand All @@ -40,6 +38,7 @@ class HistogramAnalysis:
channel_histograms: FloatArray
bin_edges: FloatArray


def spectral_analysis(
device: torch.device,
imageset: Imageset,
Expand All @@ -54,7 +53,7 @@ def spectral_analysis(
dataloader = _prepare_dataset(imageset, max_sample_size)

# Initialize results
results = {"input": _layer_spectral_analysis(device, dataloader, nn.Identity())}
results: dict[str, SpectralAnalysis] = {}

# Analyze each layer
head_layers: list[nn.Module] = []
Expand All @@ -64,7 +63,6 @@ def spectral_analysis(

if is_nonlinearity(layer):
continue
# TODO: Possible for non Conv2D layers?

results[layer_name] = _layer_spectral_analysis(
device, dataloader, nn.Sequential(*head_layers)
Expand All @@ -84,10 +82,10 @@ def histogram_analysis(
_, cnn_layers = get_cnn_circuit(brain)

# Prepare dataset
dataloader = _prepare_dataset(imageset, max_sample_size)
dataloader = _prepare_dataset(imageset, max_sample_size) # TODO: Move outside?

# Initialize results
results = {"input": _layer_pixel_histograms(device, dataloader, nn.Identity())}
results: dict[str, HistogramAnalysis] = {}

# Analyze each layer
head_layers: list[nn.Module] = []
Expand All @@ -96,7 +94,6 @@ def histogram_analysis(
head_layers.append(layer)
if is_nonlinearity(layer):
continue
# TODO: Possible for non Conv2D layers?
results[layer_name] = _layer_pixel_histograms(
device, dataloader, nn.Sequential(*head_layers)
)
Expand Down Expand Up @@ -243,8 +240,6 @@ def plot(
copy_checkpoint: bool,
):
for layer_name, layer_rfs in rf_result.items():
if layer_name != "input":
continue
layer_spectral = spectral_result[layer_name]
layer_histogram = histogram_result[layer_name]
for channel in range(layer_rfs.shape[0]):
Expand Down Expand Up @@ -395,4 +390,37 @@ def _plot_receptive_fields(ax: Axes, rf: FloatArray):
horizontalalignment="center",
verticalalignment="center",
transform=ax.transAxes,
)
)


def analyze_input(
    device: torch.device, imageset: Imageset, max_sample_size: int
) -> tuple[SpectralAnalysis, HistogramAnalysis]:
    """Run the spectral and histogram analyses directly on the input images.

    Both analyses are given an ``nn.Identity()`` module in place of a stack of
    network layers, so they operate on whatever the dataloader yields rather
    than on layer activations.

    Args:
        device: Torch device forwarded to both per-layer analysis helpers.
        imageset: Image source handed to ``_prepare_dataset``.
        max_sample_size: Forwarded to ``_prepare_dataset`` (presumably an
            upper bound on the number of samples -- confirm in that helper).

    Returns:
        A ``(spectral, histogram)`` pair for the input.
    """
    # One dataloader is built and shared by both analyses.
    loader = _prepare_dataset(imageset, max_sample_size)
    return (
        _layer_spectral_analysis(device, loader, nn.Identity()),
        _layer_pixel_histograms(device, loader, nn.Identity()),
    )


def input_plot(
    log: FigureLogger,
    rf_result: FloatArray,
    spectral_result: SpectralAnalysis,
    histogram_result: HistogramAnalysis,
    init_dir: str,
):
    """Create and log one analysis figure per input channel.

    The number of channels is taken from the first axis of
    ``histogram_result.channel_histograms``. Each figure is produced by
    ``layer_channel_plots`` with ``layer_name="input"`` and logged under the
    name ``input_channel_<index>`` inside ``init_dir``.

    Args:
        log: Figure logger used to store the figures.
        rf_result: Receptive-field array passed through to the plot helper.
        spectral_result: Spectral analysis of the input.
        histogram_result: Histogram analysis of the input.
        init_dir: Sub-directory argument given to ``log.log_figure``.
    """
    n_channels = histogram_result.channel_histograms.shape[0]
    for idx in range(n_channels):
        fig = layer_channel_plots(
            rf_result,
            spectral_result,
            histogram_result,
            layer_name="input",
            channel=idx,
        )
        # The trailing 0 / False are presumably the epoch and the
        # copy-to-checkpoint flag of log_figure -- confirm against its signature.
        log.log_figure(fig, init_dir, f"input_channel_{idx}", 0, False)
38 changes: 38 additions & 0 deletions retinal_rl/analysis/default.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from retinal_rl.analysis.plot import (
FigureLogger,
plot_brain_and_optimizers,
plot_receptive_field_sizes,
)
from retinal_rl.models.brain import Brain
from retinal_rl.models.objective import ContextT, Objective
from retinal_rl.util import FloatArray

INIT_DIR = "initialization_analysis"

def initialization_plots(
    log: FigureLogger,
    brain: Brain,
    objective: Objective[ContextT],
    input_shape: tuple[int, ...],
    rf_result: dict[str, FloatArray],
):
    """Save the model summary and log the one-off initialization figures.

    In order, this saves the brain summary, logs the receptive-field-size
    figure as ``receptive_field_sizes``, and logs the brain/optimizer graph
    figure as ``brain_graph``. Both figures go into ``INIT_DIR``.

    Args:
        log: Figure logger used for the summary and both figures.
        brain: Model to summarize and to draw in the graph figure.
        objective: Objective passed to ``plot_brain_and_optimizers``.
        input_shape: Input shape passed to ``plot_receptive_field_sizes``.
        rf_result: Per-layer receptive fields, keyed by layer name.
    """
    log.save_summary(brain)

    # TODO: Somewhat of a hack -- refactor so that the relevant information
    # is taken from cnn_stats instead.
    sizes_figure = plot_receptive_field_sizes(input_shape, rf_result)
    log.log_figure(sizes_figure, INIT_DIR, "receptive_field_sizes", 0, False)

    # The graph figure is built only after the first figure has been logged,
    # so the order of side effects is unchanged.
    brain_figure = plot_brain_and_optimizers(brain, objective)
    log.log_figure(brain_figure, INIT_DIR, "brain_graph", 0, False)
18 changes: 9 additions & 9 deletions retinal_rl/analysis/plot.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Utility functions for plotting the results of statistical analyses."""

import json
import shutil
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import networkx as nx
Expand All @@ -17,7 +19,7 @@

from retinal_rl.models.brain import Brain
from retinal_rl.models.objective import ContextT, Objective
from retinal_rl.util import FloatArray
from retinal_rl.util import FloatArray, NumpyEncoder


def make_image_grid(arrays: list[FloatArray], nrow: int) -> FloatArray:
Expand Down Expand Up @@ -178,8 +180,6 @@ def plot_receptive_field_sizes(
rf_sizes: list[tuple[int, int]] = []
layer_names: list[str] = []
for name, rf in rf_layers.items():
if name == "input": # TODO: Should not be possible?!
continue
rf_height, rf_width = rf.shape[2:]
rf_sizes.append((rf_height, rf_width))
layer_names.append(name)
Expand Down Expand Up @@ -291,13 +291,10 @@ def set_integer_ticks(ax: Axes):
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
ax.yaxis.set_major_locator(MaxNLocator(integer=True))


class FigureLogger:
def __init__(
self,
use_wandb: bool,
plot_dir: Path,
checkpoint_plot_dir: Path,
run_dir: Path
self, use_wandb: bool, plot_dir: Path, checkpoint_plot_dir: Path, run_dir: Path
):
self.use_wandb = use_wandb
self.plot_dir = plot_dir
Expand Down Expand Up @@ -340,7 +337,6 @@ def capitalize_part(part: str) -> str:
return "/".join(capitalized_parts)

def _checkpoint_copy(self, sub_dir: str, file_name: str, epoch: int) -> None:
# TODO: Does this need to be in here?
src_path = self.plot_dir / sub_dir / f"{file_name}.png"

dest_dir = self.checkpoint_plot_dir / f"epoch_{epoch}" / sub_dir
Expand Down Expand Up @@ -370,3 +366,7 @@ def save_summary(self, brain: Brain):

if self.use_wandb:
wandb.save(str(filepath), base_path=self.run_dir, policy="now")

def save_dict(self, path: Path, dict: dict[str, Any]):
    """Write ``dict`` to ``path`` as JSON, encoding values with ``NumpyEncoder``.

    The file is opened in write mode, so existing content is replaced.
    The parameter name ``dict`` shadows the builtin inside this method; it is
    kept because renaming it would break callers that pass it by keyword.
    """
    with open(path, "w") as out_file:
        json.dump(dict, out_file, cls=NumpyEncoder)
13 changes: 5 additions & 8 deletions retinal_rl/analysis/receptive_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,14 @@ def analyze(brain: Brain, device: torch.device):

return input_shape, results


def plot(
log: FigureLogger,
rf_result: dict[str, FloatArray],
epoch: int,
copy_checkpoint: bool,
):
for layer_name, layer_rfs in rf_result.items():
if layer_name == "input": # TODO: Remove, see where this goes
continue
layer_rf_plots = layer_receptive_field_plots(layer_rfs)
log.log_figure(
layer_rf_plots,
Expand All @@ -58,9 +57,8 @@ def plot(
copy_checkpoint,
)

def layer_receptive_field_plots(
lyr_rfs: FloatArray, max_cols: int = 8
) -> Figure:

def layer_receptive_field_plots(lyr_rfs: FloatArray, max_cols: int = 8) -> Figure:
"""Plot the receptive fields of a convolutional layer."""
ochns, _, _, _ = lyr_rfs.shape

Expand All @@ -79,9 +77,7 @@ def layer_receptive_field_plots(

for i in range(ochns):
ax = axs[i]
data = np.moveaxis(
lyr_rfs[i], 0, -1
) # Move channel axis to the last dimension
data = np.moveaxis(lyr_rfs[i], 0, -1) # Move channel axis to the last dimension
data_min = data.min()
data_max = data.max()
data = (data - data_min) / (data_max - data_min)
Expand All @@ -96,6 +92,7 @@ def layer_receptive_field_plots(
fig.tight_layout() # Adjust layout to fit color bars
return fig


def _compute_receptive_fields(
device: torch.device,
head_layers: list[nn.Module],
Expand Down
41 changes: 24 additions & 17 deletions retinal_rl/analysis/reconstructions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@

import json
from dataclasses import asdict, dataclass
from pathlib import Path

Expand All @@ -13,7 +11,7 @@
from retinal_rl.models.brain import Brain
from retinal_rl.models.loss import ContextT, ReconstructionLoss
from retinal_rl.models.objective import Objective
from retinal_rl.util import FloatArray, NumpyEncoder
from retinal_rl.util import FloatArray


@dataclass
Expand All @@ -24,42 +22,47 @@ class Reconstructions:
inputs: list[tuple[FloatArray, int]]
estimates: list[tuple[FloatArray, int]]


@dataclass
class ReconstructionStatistics:
"""Results of image reconstruction for both training and test sets."""

train: Reconstructions
test: Reconstructions

# TODO: Make structure match the analyze / plot structure as receptive_fields

def perform_reconstruction_analysis(
log: FigureLogger,
analyses_dir: Path,
def analyze(
device: torch.device,
brain: Brain,
objective: Objective[ContextT],
train_set: Imageset,
test_set: Imageset,
epoch: int,
copy_checkpoint: bool,
):
) -> tuple[dict[str, ReconstructionStatistics], list[float], list[float]]:
reconstruction_decoders = [
loss.target_decoder
for loss in objective.losses
if isinstance(loss, ReconstructionLoss)
]

results: dict[str, ReconstructionStatistics] = {}
for decoder in reconstruction_decoders:
norm_means, norm_stds = train_set.normalization_stats
rec_dict = asdict(
reconstruct_images(device, brain, decoder, train_set, test_set, 5)
results[decoder] = reconstruct_images(
device, brain, decoder, train_set, test_set, 5
)
# Save the reconstructions
rec_path = analyses_dir / f"{decoder}_reconstructions_epoch_{epoch}.json"
with open(rec_path, "w") as f:
json.dump(rec_dict, f, cls=NumpyEncoder)
return results, *train_set.normalization_stats


def plot(
log: FigureLogger,
analyses_dir: Path,
result: dict[str, ReconstructionStatistics],
norm_means: list[float],
norm_stds: list[float],
epoch: int,
copy_checkpoint: bool,
):
for decoder, reconstructions in result.items():
rec_dict = asdict(reconstructions)
recon_fig = plot_reconstructions(
norm_means,
norm_stds,
Expand All @@ -74,6 +77,10 @@ def perform_reconstruction_analysis(
epoch,
copy_checkpoint,
)
# Save the reconstructions #TODO: most plot functions don't do this, should stay?
log.save_dict(
analyses_dir / f"{decoder}_reconstructions_epoch_{epoch}.json", rec_dict
)


def reconstruct_images(
Expand Down
9 changes: 3 additions & 6 deletions retinal_rl/analysis/transforms_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,8 @@ class TransformStatistics:
source_transforms: dict[str, dict[float, list[FloatArray]]]
noise_transforms: dict[str, dict[float, list[FloatArray]]]

# TODO: Make structure match the analyze / plot structure as receptive_fields

def transform_base_images(
imageset: Imageset, num_steps: int, num_images: int
) -> TransformStatistics:
def analyze(imageset: Imageset, num_steps: int, num_images: int) -> TransformStatistics:
"""Apply transformations to a set of images from an Imageset."""
images: list[Image.Image] = []

Expand Down Expand Up @@ -59,7 +56,7 @@ def transform_base_images(
return resultss


def plot_transforms(
def plot(
source_transforms: dict[str, dict[float, list[FloatArray]]],
noise_transforms: dict[str, dict[float, list[FloatArray]]],
) -> Figure:
Expand Down Expand Up @@ -142,4 +139,4 @@ def plot_transforms(
transform_index += 1

plt.tight_layout()
return fig
return fig
2 changes: 2 additions & 0 deletions retinal_rl/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

### IO Handling stuff


class NumpyEncoder(json.JSONEncoder):
"""JSON encoder that handles numpy arrays."""

Expand All @@ -27,6 +28,7 @@ def default(self, obj: Any) -> Any:
return obj.tolist()
return super().default(obj)


### Functions


Expand Down
Loading

0 comments on commit 6f5861e

Please sign in to comment.