From 78f4d4d5df65da1e9095887f7f5c163d90e2ae8a Mon Sep 17 00:00:00 2001 From: Fabio Seel Date: Thu, 12 Dec 2024 17:37:17 +0100 Subject: [PATCH] refactoring: reorganize analysis code etc. almost done --- retinal_rl/analysis/channel_analysis.py | 398 ++++++++++++++++ retinal_rl/analysis/plot.py | 497 ++++---------------- retinal_rl/analysis/receptive_fields.py | 124 +++++ retinal_rl/analysis/reconstructions.py | 248 ++++++++++ retinal_rl/analysis/statistics.py | 418 ---------------- retinal_rl/analysis/transforms_analysis.py | 145 ++++++ retinal_rl/util.py | 13 +- runner/frameworks/classification/analyze.py | 454 +++++------------- 8 files changed, 1151 insertions(+), 1146 deletions(-) create mode 100644 retinal_rl/analysis/channel_analysis.py create mode 100644 retinal_rl/analysis/receptive_fields.py create mode 100644 retinal_rl/analysis/reconstructions.py delete mode 100644 retinal_rl/analysis/statistics.py create mode 100644 retinal_rl/analysis/transforms_analysis.py diff --git a/retinal_rl/analysis/channel_analysis.py b/retinal_rl/analysis/channel_analysis.py new file mode 100644 index 00000000..d95caf57 --- /dev/null +++ b/retinal_rl/analysis/channel_analysis.py @@ -0,0 +1,398 @@ + + + +import logging +from dataclasses import dataclass +from typing import cast + +import numpy as np +import torch +import torch.utils +import torch.utils.data +from matplotlib import gridspec +from matplotlib import pyplot as plt +from matplotlib.axes import Axes +from matplotlib.figure import Figure +from torch import Tensor, fft, nn +from torch.utils.data import DataLoader + +from retinal_rl.analysis.plot import FigureLogger, set_integer_ticks +from retinal_rl.classification.imageset import Imageset, ImageSubset +from retinal_rl.models.brain import Brain, get_cnn_circuit +from retinal_rl.util import FloatArray, is_nonlinearity + +logger = logging.getLogger(__name__) + +@dataclass +class SpectralAnalysis: + """Results of spectral analysis for a layer.""" + + mean_power_spectrum: FloatArray + var_power_spectrum: FloatArray + mean_autocorr: FloatArray + var_autocorr: FloatArray + + +@dataclass +class HistogramAnalysis: + """Results of histogram analysis for a layer.""" + + channel_histograms: FloatArray + bin_edges: FloatArray + +def spectral_analysis( + device: torch.device, + imageset: Imageset, + brain: Brain, + max_sample_size: int = 0, +) -> dict[str, SpectralAnalysis]: + brain.eval() + brain.to(device) + _, cnn_layers = get_cnn_circuit(brain) + + # Prepare dataset + dataloader = _prepare_dataset(imageset, max_sample_size) + + # Initialize results + results = {"input": _layer_spectral_analysis(device, dataloader, nn.Identity())} + + # Analyze each layer + head_layers: list[nn.Module] = [] + + for layer_name, layer in cnn_layers.items(): + head_layers.append(layer) + + if is_nonlinearity(layer): + continue + # TODO: Possible for non Conv2D layers? 
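+        # Record spectra for the sub-network up to and including this layer;
+        # the nonlinearity that follows is appended on the next iteration, so
+        # it is folded into the next recorded entry.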
+ + results[layer_name] = _layer_spectral_analysis( + device, dataloader, nn.Sequential(*head_layers) + ) + + return results + + +def histogram_analysis( + device: torch.device, + imageset: Imageset, + brain: Brain, + max_sample_size: int = 0, +) -> dict[str, HistogramAnalysis]: + brain.eval() + brain.to(device) + _, cnn_layers = get_cnn_circuit(brain) + + # Prepare dataset + dataloader = _prepare_dataset(imageset, max_sample_size) + + # Initialize results + results = {"input": _layer_pixel_histograms(device, dataloader, nn.Identity())} + + # Analyze each layer + head_layers: list[nn.Module] = [] + + for layer_name, layer in cnn_layers.items(): + head_layers.append(layer) + if is_nonlinearity(layer): + continue + # TODO: Possible for non Conv2D layers? + results[layer_name] = _layer_pixel_histograms( + device, dataloader, nn.Sequential(*head_layers) + ) + + return results + + +def _prepare_dataset( + imageset: Imageset, max_sample_size: int = 0 +) -> DataLoader[tuple[Tensor, Tensor, int]]: + """Prepare dataset and dataloader for analysis.""" + epoch_len = imageset.epoch_len() + logger.info(f"Original dataset size: {epoch_len}") + + if max_sample_size > 0 and epoch_len > max_sample_size: + indices = torch.randperm(epoch_len)[:max_sample_size].tolist() + subset = ImageSubset(imageset, indices=indices) + logger.info(f"Reducing dataset size for cnn_statistics to {max_sample_size}") + else: + indices = list(range(epoch_len)) + subset = ImageSubset(imageset, indices=indices) + logger.info("Using full dataset for cnn_statistics") + + return DataLoader(subset, batch_size=64, shuffle=False) + + +def _layer_pixel_histograms( + device: torch.device, + dataloader: DataLoader[tuple[Tensor, Tensor, int]], + model: nn.Module, + num_bins: int = 20, +) -> HistogramAnalysis: + """Compute histograms of pixel/activation values for each channel across all data in an imageset.""" + _, first_batch, _ = next(iter(dataloader)) + with torch.no_grad(): + first_batch = model(first_batch.to(device)) + num_channels: int = first_batch.shape[1] + + # Initialize variables for dynamic range computation + global_min = torch.full((num_channels,), float("inf"), device=device) + global_max = torch.full((num_channels,), float("-inf"), device=device) + + # First pass: compute global min and max + total_elements = 0 + + for _, batch, _ in dataloader: + with torch.no_grad(): + batch = model(batch.to(device)) + batch_min, _ = batch.view(-1, num_channels).min(dim=0) + batch_max, _ = batch.view(-1, num_channels).max(dim=0) + global_min = torch.min(global_min, batch_min) + global_max = torch.max(global_max, batch_max) + total_elements += batch.numel() // num_channels + + # Compute histogram parameters + hist_range: tuple[float, float] = (global_min.min().item(), global_max.max().item()) + + histograms: Tensor = torch.zeros( + (num_channels, num_bins), dtype=torch.float64, device=device + ) + + for _, batch, _ in dataloader: + with torch.no_grad(): + batch = model(batch.to(device)) + for c in range(num_channels): + channel_data = batch[:, c, :, :].reshape(-1) + hist = torch.histc( + channel_data, bins=num_bins, min=hist_range[0], max=hist_range[1] + ) + histograms[c] += hist + + bin_width = (hist_range[1] - hist_range[0]) / num_bins + normalized_histograms = histograms / (total_elements * bin_width / num_channels) + + return HistogramAnalysis( + normalized_histograms.cpu().numpy(), + np.linspace(hist_range[0], hist_range[1], num_bins + 1, dtype=np.float64), + ) + + +def _layer_spectral_analysis( + device: torch.device, + dataloader: 
DataLoader[tuple[Tensor, Tensor, int]],
+    model: nn.Module,
+) -> SpectralAnalysis:
+    """Compute spectral analysis statistics for each channel across all data in an imageset."""
+    _, first_batch, _ = next(iter(dataloader))
+    with torch.no_grad():
+        first_batch = model(first_batch.to(device))
+    image_size = first_batch.shape[1:]
+
+    # Initialize accumulators for the running means and second moments
+    mean_power_spectrum = torch.zeros(image_size, dtype=torch.float64, device=device)
+    m2_power_spectrum = torch.zeros(image_size, dtype=torch.float64, device=device)
+    mean_autocorr = torch.zeros(image_size, dtype=torch.float64, device=device)
+    m2_autocorr = torch.zeros(image_size, dtype=torch.float64, device=device)
+    autocorr: Tensor = torch.zeros(image_size, dtype=torch.float64, device=device)
+
+    count = 0
+
+    for _, batch, _ in dataloader:
+        with torch.no_grad():
+            batch = model(batch.to(device))
+        for image in batch:
+            count += 1
+
+            # Compute power spectrum
+            power_spectrum = torch.abs(fft.fft2(image)) ** 2
+
+            # Accumulate power spectrum statistics
+            mean_power_spectrum += power_spectrum
+            m2_power_spectrum += power_spectrum**2
+
+            # Compute normalized autocorrelation
+            autocorr = cast(Tensor, fft.ifft2(power_spectrum)).real
+            max_abs_autocorr = torch.amax(
+                torch.abs(autocorr), dim=(-2, -1), keepdim=True
+            )
+            autocorr = autocorr / (max_abs_autocorr + 1e-8)
+
+            # Accumulate autocorrelation statistics
+            mean_autocorr += autocorr
+            m2_autocorr += autocorr**2
+
+    mean_power_spectrum /= count
+    mean_autocorr /= count
+    # var = E[X^2] - (E[X])^2; the means have already been normalized by count
+    var_power_spectrum = m2_power_spectrum / count - mean_power_spectrum**2
+    var_autocorr = m2_autocorr / count - mean_autocorr**2
+
+    return SpectralAnalysis(
+        mean_power_spectrum.cpu().numpy(),
+        var_power_spectrum.cpu().numpy(),
+        mean_autocorr.cpu().numpy(),
+        var_autocorr.cpu().numpy(),
+    )
+
+
+def plot(
+    log: FigureLogger,
+    rf_result: dict[str, FloatArray],
+    spectral_result: dict[str, SpectralAnalysis],
+    histogram_result: dict[str, HistogramAnalysis],
+    epoch: int,
+    copy_checkpoint: bool,
+):
+    for layer_name, layer_rfs in rf_result.items():
+        if layer_name == "input":
+            continue
+        layer_spectral = spectral_result[layer_name]
+        layer_histogram = histogram_result[layer_name]
+        for channel in range(layer_rfs.shape[0]):
+            channel_fig = layer_channel_plots(
+                layer_rfs,
+                layer_spectral,
+                layer_histogram,
+                layer_name=layer_name,
+                channel=channel,
+            )
+            log.log_figure(
+                channel_fig,
+                f"{layer_name}_layer_channel_analysis",
+                f"channel_{channel}",
+                epoch,
+                copy_checkpoint,
+            )
+
+
+def layer_channel_plots(
+    receptive_fields: FloatArray,
+    spectral: SpectralAnalysis,
+    histogram: HistogramAnalysis,
+    layer_name: str,
+    channel: int,
+) -> Figure:
+    """Plot receptive fields, pixel histograms, and autocorrelation plots for a single channel in a layer."""
+    fig, axs = plt.subplots(2, 3, figsize=(20, 10))
+    fig.suptitle(f"Layer: {layer_name}, Channel: {channel}", fontsize=16)
+
+    # Receptive Fields
+    rf = receptive_fields[channel]
+    _plot_receptive_fields(axs[0, 0], rf)
+    axs[0, 0].set_title("Receptive Field")
+    axs[0, 0].set_xlabel("X")
+    axs[0, 0].set_ylabel("Y")
+
+    # Pixel Histograms
+    hist = histogram.channel_histograms[channel]
+    bin_edges = histogram.bin_edges
+    axs[1, 0].bar(
+        bin_edges[:-1],
+        hist,
+        width=np.diff(bin_edges),
+        align="edge",
+        color="gray",
+        edgecolor="black",
+    )
+    axs[1, 0].set_title("Channel Histogram")
+    axs[1, 0].set_xlabel("Value")
+    axs[1, 0].set_ylabel("Empirical Probability")
+
+    # Autocorrelation plots
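+    # fftshift centers zero lag (and, below, zero frequency) for display, so
+    # the 2D maps read symmetrically about the origin.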
+ # Plot average 2D autocorrelation and variance + autocorr = fft.fftshift(spectral.mean_autocorr[channel]) + h, w = autocorr.shape + extent = [-w // 2, w // 2, -h // 2, h // 2] + im = axs[0, 1].imshow( + autocorr, cmap="twilight", vmin=-1, vmax=1, origin="lower", extent=extent + ) + axs[0, 1].set_title("Average 2D Autocorrelation") + axs[0, 1].set_xlabel("Lag X") + axs[0, 1].set_ylabel("Lag Y") + fig.colorbar(im, ax=axs[0, 1]) + set_integer_ticks(axs[0, 1]) + + autocorr_sd = fft.fftshift(np.sqrt(spectral.var_autocorr[channel])) + im = axs[0, 2].imshow( + autocorr_sd, cmap="inferno", origin="lower", extent=extent, vmin=0 + ) + axs[0, 2].set_title("2D Autocorrelation SD") + axs[0, 2].set_xlabel("Lag X") + axs[0, 2].set_ylabel("Lag Y") + fig.colorbar(im, ax=axs[0, 2]) + set_integer_ticks(axs[0, 2]) + + # Plot average 2D power spectrum + log_power_spectrum = fft.fftshift(np.log1p(spectral.mean_power_spectrum[channel])) + h, w = log_power_spectrum.shape + + im = axs[1, 1].imshow( + log_power_spectrum, cmap="viridis", origin="lower", extent=extent, vmin=0 + ) + axs[1, 1].set_title("Average 2D Power Spectrum (log)") + axs[1, 1].set_xlabel("Frequency X") + axs[1, 1].set_ylabel("Frequency Y") + fig.colorbar(im, ax=axs[1, 1]) + set_integer_ticks(axs[1, 1]) + + log_power_spectrum_sd = fft.fftshift( + np.log1p(np.sqrt(spectral.var_power_spectrum[channel])) + ) + im = axs[1, 2].imshow( + log_power_spectrum_sd, + cmap="viridis", + origin="lower", + extent=extent, + vmin=0, + ) + axs[1, 2].set_title("2D Power Spectrum SD") + axs[1, 2].set_xlabel("Frequency X") + axs[1, 2].set_ylabel("Frequency Y") + fig.colorbar(im, ax=axs[1, 2]) + set_integer_ticks(axs[1, 2]) + + plt.tight_layout() + return fig + + +def _plot_receptive_fields(ax: Axes, rf: FloatArray): + """Plot full-color receptive field and individual color channels for CIFAR-10 range (-1 to 1).""" + # Clear the main axes + ax.clear() + ax.axis("off") + + # Create a GridSpec within the given axes + gs = gridspec.GridSpecFromSubplotSpec(2, 2, subplot_spec=ax.get_subplotspec()) + + rf_full = np.moveaxis(rf, 0, -1) # Move channel axis to the last dimension + rf_min = rf_full.min() + rf_max = rf_full.max() + rf_full = (rf_full - rf_min) / (rf_max - rf_min) + # Full-color receptive field + + ax_full = ax.figure.add_subplot(gs[0, 0]) + ax_full.imshow(rf_full) + ax_full.set_title("Full Color") + ax_full.axis("off") + + # Individual color channels + channels = ["Red", "Green", "Blue"] + cmaps = ["RdGy_r", "RdYlGn", "PuOr"] # Diverging colormaps centered at 0 + positions = [(0, 1), (1, 0), (1, 1)] # Correct positions for a 2x2 grid + for i in range(3): + row, col = positions[i] + ax_channel = ax.figure.add_subplot(gs[row, col]) + im = ax_channel.imshow(rf[i], cmap=cmaps[i], vmin=rf_min, vmax=rf_max) + ax_channel.set_title(channels[i]) + ax_channel.axis("off") + plt.colorbar(im, ax=ax_channel, fraction=0.046, pad=0.04) + + # Add min and max values as text + ax.text( + 0.5, + -0.05, + f"Min: {rf.min():.2f}, Max: {rf.max():.2f}", + horizontalalignment="center", + verticalalignment="center", + transform=ax.transAxes, + ) \ No newline at end of file diff --git a/retinal_rl/analysis/plot.py b/retinal_rl/analysis/plot.py index 6341a337..0247493a 100644 --- a/retinal_rl/analysis/plot.py +++ b/retinal_rl/analysis/plot.py @@ -1,25 +1,26 @@ """Utility functions for plotting the results of statistical analyses.""" -from typing import Dict, List, Tuple +import shutil +from pathlib import Path import matplotlib.pyplot as plt import networkx as nx import numpy as np import 
seaborn as sns -from matplotlib import gridspec, patches +import wandb +from matplotlib import patches from matplotlib.axes import Axes from matplotlib.figure import Figure from matplotlib.lines import Line2D from matplotlib.patches import Circle, Wedge from matplotlib.ticker import MaxNLocator -from numpy import fft from retinal_rl.models.brain import Brain from retinal_rl.models.objective import ContextT, Objective from retinal_rl.util import FloatArray -def make_image_grid(arrays: List[FloatArray], nrow: int) -> FloatArray: +def make_image_grid(arrays: list[FloatArray], nrow: int) -> FloatArray: """Create a grid of images from a list of numpy arrays.""" # Assuming arrays are [C, H, W] n = len(arrays) @@ -40,105 +41,19 @@ def make_image_grid(arrays: List[FloatArray], nrow: int) -> FloatArray: return grid -def plot_transforms( - source_transforms: Dict[str, Dict[float, List[FloatArray]]], - noise_transforms: Dict[str, Dict[float, List[FloatArray]]], -) -> Figure: - """Plot effects of source and noise transforms on images. - - Args: - source_transforms: Dictionary of source transforms (numpy arrays) - noise_transforms: Dictionary of noise transforms (numpy arrays) - - Returns: - Figure containing the plotted transforms - """ - num_source_transforms = len(source_transforms) - num_noise_transforms = len(noise_transforms) - num_transforms = num_source_transforms + num_noise_transforms - num_images = len( - next(iter(source_transforms.values()))[ - next(iter(next(iter(source_transforms.values())).keys())) - ] - ) - - fig, axs = plt.subplots(num_transforms, 1, figsize=(20, 5 * num_transforms)) - if num_transforms == 1: - axs = [axs] - - transform_index = 0 - - # Plot source transforms - for transform_name, transform_data in source_transforms.items(): - ax = axs[transform_index] - steps = sorted(transform_data.keys()) - - # Create a grid of images for each step - images = [ - make_image_grid( - [(img * 0.5 + 0.5) for img in transform_data[step]], - nrow=num_images, - ) - for step in steps - ] - grid = make_image_grid(images, nrow=len(steps)) - - # Move channels last for imshow - grid_display = np.transpose(grid, (1, 2, 0)) - ax.imshow(grid_display) - ax.set_title(f"Source Transform: {transform_name}") - ax.set_xticks( - [(i + 0.5) * grid_display.shape[1] / len(steps) for i in range(len(steps))] - ) - ax.set_xticklabels([f"{step:.2f}" for step in steps]) - ax.set_yticks([]) - - transform_index += 1 - - # Plot noise transforms - for transform_name, transform_data in noise_transforms.items(): - ax = axs[transform_index] - steps = sorted(transform_data.keys()) - - # Create a grid of images for each step - images = [ - make_image_grid( - [(img * 0.5 + 0.5) for img in transform_data[step]], - nrow=num_images, - ) - for step in steps - ] - grid = make_image_grid(images, nrow=len(steps)) - - # Move channels last for imshow - grid_display = np.transpose(grid, (1, 2, 0)) - ax.imshow(grid_display) - ax.set_title(f"Noise Transform: {transform_name}") - ax.set_xticks( - [(i + 0.5) * grid_display.shape[1] / len(steps) for i in range(len(steps))] - ) - ax.set_xticklabels([f"{step:.2f}" for step in steps]) - ax.set_yticks([]) - - transform_index += 1 - - plt.tight_layout() - return fig - - def plot_brain_and_optimizers(brain: Brain, objective: Objective[ContextT]) -> Figure: graph = brain.connectome # Compute the depth of each node - depths: Dict[str, int] = {} + depths: dict[str, int] = {} for node in nx.topological_sort(graph): depths[node] = ( max([depths[pred] for pred in graph.predecessors(node)] + [-1]) + 
1 ) # Create a position dictionary based on depth - pos: Dict[str, Tuple[float, float]] = {} - nodes_at_depth: Dict[int, List[str]] = {} + pos: dict[str, tuple[float, float]] = {} + nodes_at_depth: dict[int, list[str]] = {} for node, depth in depths.items(): if depth not in nodes_at_depth: nodes_at_depth[depth] = [] @@ -253,19 +168,18 @@ def plot_brain_and_optimizers(brain: Brain, objective: Objective[ContextT]) -> F def plot_receptive_field_sizes( - input_shape: Tuple[int, ...], layers: Dict[str, Dict[str, FloatArray]] + input_shape: tuple[int, ...], rf_layers: dict[str, FloatArray] ) -> Figure: """Plot the receptive field sizes for each layer of the convolutional part of the network.""" # Get visual field size from the input shape [_, height, width] = list(input_shape) # Calculate receptive field sizes for each layer - rf_sizes: List[Tuple[int, int]] = [] - layer_names: List[str] = [] - for name, layer_data in layers.items(): - if name == "input": + rf_sizes: list[tuple[int, int]] = [] + layer_names: list[str] = [] + for name, rf in rf_layers.items(): + if name == "input": # TODO: Should not be possible?! continue - rf = layer_data["receptive_fields"] rf_height, rf_width = rf.shape[2:] rf_sizes.append((rf_height, rf_width)) layer_names.append(name) @@ -318,7 +232,7 @@ def plot_receptive_field_sizes( return fig -def plot_histories(histories: Dict[str, List[float]]) -> Figure: +def plot_histories(histories: dict[str, list[float]]) -> Figure: """Plot training and test losses over epochs.""" train_metrics = [ key.split("_", 1)[1] for key in histories if key.startswith("train_") @@ -334,7 +248,7 @@ def plot_histories(histories: Dict[str, List[float]]) -> Figure: num_rows = (len(metrics) + 1) // 2 fig: Figure - axs: List[Axes] + axs: list[Axes] fig, axs = plt.subplots( num_rows, 2, figsize=(15, 5 * num_rows), constrained_layout=True ) @@ -372,308 +286,87 @@ def plot_histories(histories: Dict[str, List[float]]) -> Figure: return fig -def plot_channel_statistics( - receptive_fields: FloatArray, - spectral: Dict[str, FloatArray], - histogram: Dict[str, FloatArray], - layer_name: str, - channel: int, -) -> Figure: - """Plot receptive fields, pixel histograms, and autocorrelation plots for a single channel in a layer.""" - fig, axs = plt.subplots(2, 3, figsize=(20, 10)) - fig.suptitle(f"Layer: {layer_name}, Channel: {channel}", fontsize=16) - - # Receptive Fields - rf = receptive_fields[channel] - _plot_receptive_fields(axs[0, 0], rf) - axs[0, 0].set_title("Receptive Field") - axs[0, 0].set_xlabel("X") - axs[0, 0].set_ylabel("Y") - - # Pixel Histograms - hist = histogram["channel_histograms"][channel] - bin_edges = histogram["bin_edges"] - axs[1, 0].bar( - bin_edges[:-1], - hist, - width=np.diff(bin_edges), - align="edge", - color="gray", - edgecolor="black", - ) - axs[1, 0].set_title("Channel Histogram") - axs[1, 0].set_xlabel("Value") - axs[1, 0].set_ylabel("Empirical Probability") - - # Autocorrelation plots - # Plot average 2D autocorrelation and variance - autocorr = fft.fftshift(spectral["mean_autocorr"][channel]) - h, w = autocorr.shape - extent = [-w // 2, w // 2, -h // 2, h // 2] - im = axs[0, 1].imshow( - autocorr, cmap="twilight", vmin=-1, vmax=1, origin="lower", extent=extent - ) - axs[0, 1].set_title("Average 2D Autocorrelation") - axs[0, 1].set_xlabel("Lag X") - axs[0, 1].set_ylabel("Lag Y") - fig.colorbar(im, ax=axs[0, 1]) - _set_integer_ticks(axs[0, 1]) - - autocorr_sd = fft.fftshift(np.sqrt(spectral["var_autocorr"][channel])) - im = axs[0, 2].imshow( - autocorr_sd, cmap="inferno", 
origin="lower", extent=extent, vmin=0 - ) - axs[0, 2].set_title("2D Autocorrelation SD") - axs[0, 2].set_xlabel("Lag X") - axs[0, 2].set_ylabel("Lag Y") - fig.colorbar(im, ax=axs[0, 2]) - _set_integer_ticks(axs[0, 2]) - - # Plot average 2D power spectrum - log_power_spectrum = fft.fftshift( - np.log1p(spectral["mean_power_spectrum"][channel]) - ) - h, w = log_power_spectrum.shape - - im = axs[1, 1].imshow( - log_power_spectrum, cmap="viridis", origin="lower", extent=extent, vmin=0 - ) - axs[1, 1].set_title("Average 2D Power Spectrum (log)") - axs[1, 1].set_xlabel("Frequency X") - axs[1, 1].set_ylabel("Frequency Y") - fig.colorbar(im, ax=axs[1, 1]) - _set_integer_ticks(axs[1, 1]) - - log_power_spectrum_sd = fft.fftshift( - np.log1p(np.sqrt(spectral["var_power_spectrum"][channel])) - ) - im = axs[1, 2].imshow( - log_power_spectrum_sd, - cmap="viridis", - origin="lower", - extent=extent, - vmin=0, - ) - axs[1, 2].set_title("2D Power Spectrum SD") - axs[1, 2].set_xlabel("Frequency X") - axs[1, 2].set_ylabel("Frequency Y") - fig.colorbar(im, ax=axs[1, 2]) - _set_integer_ticks(axs[1, 2]) - - plt.tight_layout() - return fig - - -def _set_integer_ticks(ax: Axes): +def set_integer_ticks(ax: Axes): """Set integer ticks for both x and y axes.""" ax.xaxis.set_major_locator(MaxNLocator(integer=True)) ax.yaxis.set_major_locator(MaxNLocator(integer=True)) - -def plot_reconstructions( - normalization_mean: List[float], - normalization_std: List[float], - train_sources: List[Tuple[FloatArray, int]], - train_inputs: List[Tuple[FloatArray, int]], - train_estimates: List[Tuple[FloatArray, int]], - test_sources: List[Tuple[FloatArray, int]], - test_inputs: List[Tuple[FloatArray, int]], - test_estimates: List[Tuple[FloatArray, int]], - num_samples: int, -) -> Figure: - """Plot original and reconstructed images for both training and test sets, including the classes.""" - fig, axes = plt.subplots(6, num_samples, figsize=(15, 10)) - - for i in range(num_samples): - train_source, _ = train_sources[i] - train_input, train_class = train_inputs[i] - train_recon, train_pred = train_estimates[i] - test_source, _ = test_sources[i] - test_input, test_class = test_inputs[i] - test_recon, test_pred = test_estimates[i] - - # Unnormalize the original images using the normalization lists - # Arrays are already [C, H, W], need to move channels to last dimension - train_source = ( - np.transpose(train_source, (1, 2, 0)) * normalization_std - + normalization_mean - ) - train_input = ( - np.transpose(train_input, (1, 2, 0)) * normalization_std - + normalization_mean - ) - train_recon = ( - np.transpose(train_recon, (1, 2, 0)) * normalization_std - + normalization_mean - ) - test_source = ( - np.transpose(test_source, (1, 2, 0)) * normalization_std - + normalization_mean - ) - test_input = ( - np.transpose(test_input, (1, 2, 0)) * normalization_std + normalization_mean - ) - test_recon = ( - np.transpose(test_recon, (1, 2, 0)) * normalization_std + normalization_mean - ) - - axes[0, i].imshow(np.clip(train_source, 0, 1)) - axes[0, i].axis("off") - axes[0, i].set_title(f"Class: {train_class}") - - axes[1, i].imshow(np.clip(train_input, 0, 1)) - axes[1, i].axis("off") - axes[1, i].set_title(f"Class: {train_class}") - - axes[2, i].imshow(np.clip(train_recon, 0, 1)) - axes[2, i].axis("off") - axes[2, i].set_title(f"Pred: {train_pred}") - - axes[3, i].imshow(np.clip(test_source, 0, 1)) - axes[3, i].axis("off") - axes[3, i].set_title(f"Class: {test_class}") - - axes[4, i].imshow(np.clip(test_input, 0, 1)) - axes[4, i].axis("off") - 
axes[4, i].set_title(f"Class: {test_class}") - - axes[5, i].imshow(np.clip(test_recon, 0, 1)) - axes[5, i].axis("off") - axes[5, i].set_title(f"Pred: {test_pred}") - - # Set y-axis labels for each row - fig.text( - 0.02, - 0.90, - "Train Source", - va="center", - rotation="vertical", - fontsize=12, - weight="bold", - ) - fig.text( - 0.02, - 0.74, - "Train Input", - va="center", - rotation="vertical", - fontsize=12, - weight="bold", - ) - fig.text( - 0.02, - 0.56, - "Train Recon.", - va="center", - rotation="vertical", - fontsize=12, - weight="bold", - ) - fig.text( - 0.02, - 0.40, - "Test Source", - va="center", - rotation="vertical", - fontsize=12, - weight="bold", - ) - fig.text( - 0.02, - 0.24, - "Test Input", - va="center", - rotation="vertical", - fontsize=12, - weight="bold", - ) - fig.text( - 0.02, - 0.08, - "Test Recon.", - va="center", - rotation="vertical", - fontsize=12, - weight="bold", - ) - - plt.tight_layout() - return fig - - -def _plot_receptive_fields(ax: Axes, rf: FloatArray): - """Plot full-color receptive field and individual color channels for CIFAR-10 range (-1 to 1).""" - # Clear the main axes - ax.clear() - ax.axis("off") - - # Create a GridSpec within the given axes - gs = gridspec.GridSpecFromSubplotSpec(2, 2, subplot_spec=ax.get_subplotspec()) - - rf_full = np.moveaxis(rf, 0, -1) # Move channel axis to the last dimension - rf_min = rf_full.min() - rf_max = rf_full.max() - rf_full = (rf_full - rf_min) / (rf_max - rf_min) - # Full-color receptive field - - ax_full = ax.figure.add_subplot(gs[0, 0]) - ax_full.imshow(rf_full) - ax_full.set_title("Full Color") - ax_full.axis("off") - - # Individual color channels - channels = ["Red", "Green", "Blue"] - cmaps = ["RdGy_r", "RdYlGn", "PuOr"] # Diverging colormaps centered at 0 - positions = [(0, 1), (1, 0), (1, 1)] # Correct positions for a 2x2 grid - for i in range(3): - row, col = positions[i] - ax_channel = ax.figure.add_subplot(gs[row, col]) - im = ax_channel.imshow(rf[i], cmap=cmaps[i], vmin=rf_min, vmax=rf_max) - ax_channel.set_title(channels[i]) - ax_channel.axis("off") - plt.colorbar(im, ax=ax_channel, fraction=0.046, pad=0.04) - - # Add min and max values as text - ax.text( - 0.5, - -0.05, - f"Min: {rf.min():.2f}, Max: {rf.max():.2f}", - horizontalalignment="center", - verticalalignment="center", - transform=ax.transAxes, - ) - - -def layer_receptive_field_plots(lyr_rfs: FloatArray, max_cols: int = 8) -> Figure: - """Plot the receptive fields of a convolutional layer.""" - ochns, _, _, _ = lyr_rfs.shape - - # Calculate the number of rows needed based on max_cols - cols = min(ochns, max_cols) - rows = ochns // cols + (1 if ochns % cols > 0 else 0) - - fig, axs0 = plt.subplots( - rows, - cols, - figsize=(cols * 2, 1.6 * rows), - squeeze=False, - ) - - axs = axs0.flat - - for i in range(ochns): - ax = axs[i] - data = np.moveaxis(lyr_rfs[i], 0, -1) # Move channel axis to the last dimension - data_min = data.min() - data_max = data.max() - data = (data - data_min) / (data_max - data_min) - ax.imshow(data) - - ax.set_xticks([]) - ax.set_yticks([]) - ax.spines["top"].set_visible(True) - ax.spines["right"].set_visible(True) - ax.set_title(f"Channel {i+1}") - - fig.tight_layout() # Adjust layout to fit color bars - return fig +class FigureLogger: + def __init__( + self, + use_wandb: bool, + plot_dir: Path, + checkpoint_plot_dir: Path, + run_dir: Path + ): + self.use_wandb = use_wandb + self.plot_dir = plot_dir + self.checkpoint_plot_dir = checkpoint_plot_dir + self.run_dir = run_dir + + def log_figure( + self, + fig: 
Figure, + sub_dir: str, + file_name: str, + epoch: int, + copy_checkpoint: bool, + ) -> None: + if self.use_wandb: + title = f"{self._wandb_title(sub_dir)}/{self._wandb_title(file_name)}" + img = wandb.Image(fig) + wandb.log({title: img}, commit=False) + else: + self.save_figure(sub_dir, file_name, fig) + if copy_checkpoint: + self._checkpoint_copy(sub_dir, file_name, epoch) + plt.close(fig) + + @staticmethod + def _wandb_title(title: str) -> str: + # Split the title by slashes + parts = title.split("/") + + def capitalize_part(part: str) -> str: + # Split the part by dashes + words = part.split("_") + # Capitalize each word + capitalized_words = [word.capitalize() for word in words] + # Join the words with spaces + return " ".join(capitalized_words) + + # Capitalize each part, then join with slashes + capitalized_parts = [capitalize_part(part) for part in parts] + return "/".join(capitalized_parts) + + def _checkpoint_copy(self, sub_dir: str, file_name: str, epoch: int) -> None: + # TODO: Does this need to be in here? + src_path = self.plot_dir / sub_dir / f"{file_name}.png" + + dest_dir = self.checkpoint_plot_dir / f"epoch_{epoch}" / sub_dir + dest_dir.mkdir(parents=True, exist_ok=True) + dest_path = dest_dir / f"{file_name}.png" + + shutil.copy2(src_path, dest_path) + + def save_figure(self, sub_dir: str, file_name: str, fig: Figure) -> None: + dir = self.plot_dir / sub_dir + dir.mkdir(exist_ok=True) + file_path = dir / f"{file_name}.png" + fig.savefig(file_path) + + def plot_and_save_histories( + self, histories: dict[str, list[float]], save_always: bool = False + ): + if not self.use_wandb or save_always: + hist_fig = plot_histories(histories) + self.save_figure("", "histories", hist_fig) + plt.close(hist_fig) + + def save_summary(self, brain: Brain): + summary = brain.scan() + filepath = self.run_dir / "brain_summary.txt" + filepath.write_text(summary) + + if self.use_wandb: + wandb.save(str(filepath), base_path=self.run_dir, policy="now") diff --git a/retinal_rl/analysis/receptive_fields.py b/retinal_rl/analysis/receptive_fields.py new file mode 100644 index 00000000..da188411 --- /dev/null +++ b/retinal_rl/analysis/receptive_fields.py @@ -0,0 +1,124 @@ +import matplotlib.pyplot as plt +import numpy as np +import torch +from matplotlib.axes import Axes +from matplotlib.figure import Figure +from torch import Tensor, nn + +from retinal_rl.analysis.plot import FigureLogger +from retinal_rl.models.brain import Brain, get_cnn_circuit +from retinal_rl.util import FloatArray, is_nonlinearity, rf_size_and_start + + +def analyze(brain: Brain, device: torch.device): + brain.eval() + brain.to(device) + input_shape, cnn_layers = get_cnn_circuit(brain) + + # Analyze each layer + head_layers: list[ + nn.Module + ] = [] # possible TODO: have get cnn circuit return a nn.Sequential, then looping here is nicer / easier + + results: dict[str, FloatArray] = {} + for layer_name, layer in cnn_layers.items(): + head_layers.append(layer) + + if is_nonlinearity(layer): + continue + + if isinstance(layer, nn.Conv2d): + out_channels = layer.out_channels + else: + raise NotImplementedError( + "Can only compute receptive fields for 2d convolutional layers" + ) + + results[layer_name] = _compute_receptive_fields( + device, head_layers, input_shape, out_channels + ) + + return input_shape, results + +def plot( + log: FigureLogger, + rf_result: dict[str, FloatArray], + epoch: int, + copy_checkpoint: bool, +): + for layer_name, layer_rfs in rf_result.items(): + if layer_name == "input": # TODO: Remove, see where 
this goes + continue + layer_rf_plots = layer_receptive_field_plots(layer_rfs) + log.log_figure( + layer_rf_plots, + "receptive_fields", + f"{layer_name}", + epoch, + copy_checkpoint, + ) + +def layer_receptive_field_plots( + lyr_rfs: FloatArray, max_cols: int = 8 +) -> Figure: + """Plot the receptive fields of a convolutional layer.""" + ochns, _, _, _ = lyr_rfs.shape + + # Calculate the number of rows needed based on max_cols + cols = min(ochns, max_cols) + rows = ochns // cols + (1 if ochns % cols > 0 else 0) + + fig, axs0 = plt.subplots( + rows, + cols, + figsize=(cols * 2, 1.6 * rows), + squeeze=False, + ) + + axs: list[Axes] = axs0.flat + + for i in range(ochns): + ax = axs[i] + data = np.moveaxis( + lyr_rfs[i], 0, -1 + ) # Move channel axis to the last dimension + data_min = data.min() + data_max = data.max() + data = (data - data_min) / (data_max - data_min) + ax.imshow(data) + + ax.set_xticks([]) + ax.set_yticks([]) + ax.spines["top"].set_visible(True) + ax.spines["right"].set_visible(True) + ax.set_title(f"Channel {i+1}") + + fig.tight_layout() # Adjust layout to fit color bars + return fig + +def _compute_receptive_fields( + device: torch.device, + head_layers: list[nn.Module], + input_shape: tuple[int, ...], + out_channels: int, +) -> FloatArray: + """Compute receptive fields for a sequence of layers.""" + nclrs, hght, wdth = input_shape + imgsz = [1, nclrs, hght, wdth] + obs = torch.zeros(size=imgsz, device=device, requires_grad=True) + + head_model = nn.Sequential(*head_layers) + x = head_model(obs) + + hsz, wsz = x.shape[2:] + hidx = (hsz - 1) // 2 + widx = (wsz - 1) // 2 + + hrf_size, wrf_size, hmn, wmn = rf_size_and_start(head_layers, hidx, widx) + grads: list[Tensor] = [] + + for j in range(out_channels): + grad = torch.autograd.grad(x[0, j, hidx, widx], obs, retain_graph=True)[0] + grads.append(grad[0, :, hmn : hmn + hrf_size, wmn : wmn + wrf_size]) + + return torch.stack(grads).cpu().numpy() diff --git a/retinal_rl/analysis/reconstructions.py b/retinal_rl/analysis/reconstructions.py new file mode 100644 index 00000000..e467a365 --- /dev/null +++ b/retinal_rl/analysis/reconstructions.py @@ -0,0 +1,248 @@ + +import json +from dataclasses import asdict, dataclass +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import torch +from matplotlib.figure import Figure + +from retinal_rl.analysis.plot import FigureLogger +from retinal_rl.classification.imageset import Imageset +from retinal_rl.models.brain import Brain +from retinal_rl.models.loss import ContextT, ReconstructionLoss +from retinal_rl.models.objective import Objective +from retinal_rl.util import FloatArray, NumpyEncoder + + +@dataclass +class Reconstructions: + """Set of source images, inputs, and their reconstructions.""" + + sources: list[tuple[FloatArray, int]] + inputs: list[tuple[FloatArray, int]] + estimates: list[tuple[FloatArray, int]] + +@dataclass +class ReconstructionStatistics: + """Results of image reconstruction for both training and test sets.""" + + train: Reconstructions + test: Reconstructions + +# TODO: Make structure match the analyze / plot structure as receptive_fields + +def perform_reconstruction_analysis( + log: FigureLogger, + analyses_dir: Path, + device: torch.device, + brain: Brain, + objective: Objective[ContextT], + train_set: Imageset, + test_set: Imageset, + epoch: int, + copy_checkpoint: bool, +): + reconstruction_decoders = [ + loss.target_decoder + for loss in objective.losses + if isinstance(loss, ReconstructionLoss) + ] + + for decoder in 
reconstruction_decoders:
+        norm_means, norm_stds = train_set.normalization_stats
+        rec_dict = asdict(
+            reconstruct_images(device, brain, decoder, train_set, test_set, 5)
+        )
+        # Save the reconstructions
+        rec_path = analyses_dir / f"{decoder}_reconstructions_epoch_{epoch}.json"
+        with open(rec_path, "w") as f:
+            json.dump(rec_dict, f, cls=NumpyEncoder)
+
+        recon_fig = plot_reconstructions(
+            norm_means,
+            norm_stds,
+            *rec_dict["train"].values(),
+            *rec_dict["test"].values(),
+            num_samples=5,
+        )
+        log.log_figure(
+            recon_fig,
+            "reconstruction",
+            f"{decoder}_reconstructions",
+            epoch,
+            copy_checkpoint,
+        )
+
+
+def reconstruct_images(
+    device: torch.device,
+    brain: Brain,
+    decoder: str,
+    train_set: Imageset,
+    test_set: Imageset,
+    sample_size: int,
+) -> ReconstructionStatistics:
+    """Compute reconstructions of a set of training and test images using a Brain model."""
+    brain.eval()  # Set the model to evaluation mode
+
+    def collect_reconstructions(
+        imageset: Imageset, sample_size: int
+    ) -> Reconstructions:
+        """Collect reconstructions for a subset of a dataset."""
+        source_subset: list[tuple[FloatArray, int]] = []
+        input_subset: list[tuple[FloatArray, int]] = []
+        estimates: list[tuple[FloatArray, int]] = []
+        indices = torch.randperm(imageset.epoch_len())[:sample_size]
+
+        with torch.no_grad():  # Disable gradient computation
+            for index in indices:
+                src, img, k = imageset[int(index)]
+                src = src.to(device)
+                img = img.to(device)
+                stimulus = {"vision": img.unsqueeze(0)}
+                response = brain(stimulus)
+                rec_img = response[decoder].squeeze(0)
+                pred_k = response["classifier"].argmax().item()
+                source_subset.append((src.cpu().numpy(), k))
+                input_subset.append((img.cpu().numpy(), k))
+                estimates.append((rec_img.cpu().numpy(), pred_k))
+
+        return Reconstructions(source_subset, input_subset, estimates)
+
+    return ReconstructionStatistics(
+        collect_reconstructions(train_set, sample_size),
+        collect_reconstructions(test_set, sample_size),
+    )
+
+
+def plot_reconstructions(
+    normalization_mean: list[float],
+    normalization_std: list[float],
+    train_sources: list[tuple[FloatArray, int]],
+    train_inputs: list[tuple[FloatArray, int]],
+    train_estimates: list[tuple[FloatArray, int]],
+    test_sources: list[tuple[FloatArray, int]],
+    test_inputs: list[tuple[FloatArray, int]],
+    test_estimates: list[tuple[FloatArray, int]],
+    num_samples: int,
+) -> Figure:
+    """Plot original and reconstructed images for both training and test sets, including the classes."""
+    fig, axes = plt.subplots(6, num_samples, figsize=(15, 10))
+
+    for i in range(num_samples):
+        train_source, _ = train_sources[i]
+        train_input, train_class = train_inputs[i]
+        train_recon, train_pred = train_estimates[i]
+        test_source, _ = test_sources[i]
+        test_input, test_class = test_inputs[i]
+        test_recon, test_pred = test_estimates[i]
+
+        # Unnormalize the original images using the normalization lists
+        # Arrays are already [C, H, W], need to move channels to last dimension
+        train_source = (
+            np.transpose(train_source, (1, 2, 0)) * normalization_std
+            + normalization_mean
+        )
+        train_input = (
+            np.transpose(train_input, (1, 2, 0)) * normalization_std
+            + normalization_mean
+        )
+        train_recon = (
+            np.transpose(train_recon, (1, 2, 0)) * normalization_std
+            + normalization_mean
+        )
+        test_source = (
+            np.transpose(test_source, (1, 2, 0)) * normalization_std
+            + normalization_mean
+        )
+        test_input = (
+            np.transpose(test_input, (1, 2, 0)) * normalization_std + normalization_mean
+        )
+        test_recon = (
+            np.transpose(test_recon, 
(1, 2, 0)) * normalization_std + normalization_mean + ) + + axes[0, i].imshow(np.clip(train_source, 0, 1)) + axes[0, i].axis("off") + axes[0, i].set_title(f"Class: {train_class}") + + axes[1, i].imshow(np.clip(train_input, 0, 1)) + axes[1, i].axis("off") + axes[1, i].set_title(f"Class: {train_class}") + + axes[2, i].imshow(np.clip(train_recon, 0, 1)) + axes[2, i].axis("off") + axes[2, i].set_title(f"Pred: {train_pred}") + + axes[3, i].imshow(np.clip(test_source, 0, 1)) + axes[3, i].axis("off") + axes[3, i].set_title(f"Class: {test_class}") + + axes[4, i].imshow(np.clip(test_input, 0, 1)) + axes[4, i].axis("off") + axes[4, i].set_title(f"Class: {test_class}") + + axes[5, i].imshow(np.clip(test_recon, 0, 1)) + axes[5, i].axis("off") + axes[5, i].set_title(f"Pred: {test_pred}") + + # Set y-axis labels for each row + fig.text( + 0.02, + 0.90, + "Train Source", + va="center", + rotation="vertical", + fontsize=12, + weight="bold", + ) + fig.text( + 0.02, + 0.74, + "Train Input", + va="center", + rotation="vertical", + fontsize=12, + weight="bold", + ) + fig.text( + 0.02, + 0.56, + "Train Recon.", + va="center", + rotation="vertical", + fontsize=12, + weight="bold", + ) + fig.text( + 0.02, + 0.40, + "Test Source", + va="center", + rotation="vertical", + fontsize=12, + weight="bold", + ) + fig.text( + 0.02, + 0.24, + "Test Input", + va="center", + rotation="vertical", + fontsize=12, + weight="bold", + ) + fig.text( + 0.02, + 0.08, + "Test Recon.", + va="center", + rotation="vertical", + fontsize=12, + weight="bold", + ) + + plt.tight_layout() + return fig diff --git a/retinal_rl/analysis/statistics.py b/retinal_rl/analysis/statistics.py deleted file mode 100644 index ca71acc0..00000000 --- a/retinal_rl/analysis/statistics.py +++ /dev/null @@ -1,418 +0,0 @@ -"""Functions for analysis and statistics on a Brain model.""" - -import logging -from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, cast - -import numpy as np -import torch -from PIL import Image -from torch import Tensor, fft, nn -from torch.utils.data import DataLoader - -from retinal_rl.classification.imageset import Imageset, ImageSubset -from retinal_rl.classification.transforms import ContinuousTransform -from retinal_rl.models.brain import Brain, get_cnn_circuit -from retinal_rl.util import ( - FloatArray, - is_nonlinearity, - rf_size_and_start, -) - -logger = logging.getLogger(__name__) - - -### Dataclasses ### - - -@dataclass -class TransformStatistics: - """Results of applying transformations to images.""" - - source_transforms: Dict[str, Dict[float, List[FloatArray]]] - noise_transforms: Dict[str, Dict[float, List[FloatArray]]] - - -@dataclass -class Reconstructions: - """Set of source images, inputs, and their reconstructions.""" - - sources: List[Tuple[FloatArray, int]] - inputs: List[Tuple[FloatArray, int]] - estimates: List[Tuple[FloatArray, int]] - - -@dataclass -class ReconstructionStatistics: - """Results of image reconstruction for both training and test sets.""" - - train: Reconstructions - test: Reconstructions - - -@dataclass -class SpectralAnalysis: - """Results of spectral analysis for a layer.""" - - mean_power_spectrum: FloatArray - var_power_spectrum: FloatArray - mean_autocorr: FloatArray - var_autocorr: FloatArray - - -@dataclass -class HistogramAnalysis: - """Results of histogram analysis for a layer.""" - - channel_histograms: FloatArray - bin_edges: FloatArray - - -@dataclass -class LayerStatistics: - """Statistics for a single layer.""" - - receptive_fields: FloatArray - 
num_channels: int - spectral: Optional[SpectralAnalysis] = None - histogram: Optional[HistogramAnalysis] = None - - -@dataclass -class CNNStatistics: - """Complete statistics for a CNN model.""" - - input_shape: Tuple[int, ...] # nclrs, hght, wdth - layers: Dict[str, LayerStatistics] - - -### Functions ### - - -def transform_base_images( - imageset: Imageset, num_steps: int, num_images: int -) -> TransformStatistics: - """Apply transformations to a set of images from an Imageset.""" - images: List[Image.Image] = [] - - base_dataset = imageset.base_dataset - base_len = imageset.base_len - - for _ in range(num_images): - src, _ = base_dataset[np.random.randint(base_len)] - images.append(src) - - resultss = TransformStatistics( - source_transforms={}, - noise_transforms={}, - ) - - for transforms, results in [ - (imageset.source_transforms, resultss.source_transforms), - (imageset.noise_transforms, resultss.noise_transforms), - ]: - for transform in transforms: - if isinstance(transform, ContinuousTransform): - results[transform.name] = {} - trans_range: Tuple[float, float] = transform.trans_range - transform_steps = np.linspace(*trans_range, num_steps) - for step in transform_steps: - results[transform.name][step] = [] - for img in images: - results[transform.name][step].append( - imageset.to_tensor(transform.transform(img, step)) - .cpu() - .numpy() - ) - - return resultss - - -def reconstruct_images( - device: torch.device, - brain: Brain, - decoder: str, - test_set: Imageset, - train_set: Imageset, - sample_size: int, -) -> ReconstructionStatistics: - """Compute reconstructions of a set of training and test images using a Brain model.""" - brain.eval() # Set the model to evaluation mode - - def collect_reconstructions( - imageset: Imageset, sample_size: int - ) -> Reconstructions: - """Collect reconstructions for a subset of a dataset.""" - source_subset: List[Tuple[FloatArray, int]] = [] - input_subset: List[Tuple[FloatArray, int]] = [] - estimates: List[Tuple[FloatArray, int]] = [] - indices = torch.randperm(imageset.epoch_len())[:sample_size] - - with torch.no_grad(): # Disable gradient computation - for index in indices: - src, img, k = imageset[int(index)] - src = src.to(device) - img = img.to(device) - stimulus = {"vision": img.unsqueeze(0)} - response = brain(stimulus) - rec_img = response[decoder].squeeze(0) - pred_k = response["classifier"].argmax().item() - source_subset.append((src.cpu().numpy(), k)) - input_subset.append((img.cpu().numpy(), k)) - estimates.append((rec_img.cpu().numpy(), pred_k)) - - return Reconstructions(source_subset, input_subset, estimates) - - return ReconstructionStatistics( - collect_reconstructions(train_set, sample_size), - collect_reconstructions(test_set, sample_size), - ) - - -def cnn_statistics( - device: torch.device, - imageset: Imageset, - brain: Brain, - channel_analysis: bool, - max_sample_size: int = 0, -) -> CNNStatistics: - """Compute statistics for a convolutional encoder model.""" - brain.eval() - brain.to(device) - input_shape, cnn_layers = get_cnn_circuit(brain) - - # Prepare dataset - dataloader = _prepare_dataset(imageset, max_sample_size) - - # Initialize results - results = { - "input": _analyze_input(device, dataloader, input_shape, channel_analysis) - } - - # Analyze each layer - head_layers: List[nn.Module] = [] - - for layer_name, layer in cnn_layers.items(): - head_layers.append(layer) - - if is_nonlinearity(layer): - continue - - if isinstance(layer, nn.Conv2d): - out_channels = layer.out_channels - else: - raise 
NotImplementedError( - "Can only compute receptive fields for 2d convolutional layers" - ) - - results[layer_name] = _analyze_layer( - device, dataloader, head_layers, input_shape, out_channels, channel_analysis - ) - - return CNNStatistics(input_shape, results) - - -def _prepare_dataset( - imageset: Imageset, max_sample_size: int = 0 -) -> DataLoader[Tuple[Tensor, Tensor, int]]: - """Prepare dataset and dataloader for analysis.""" - epoch_len = imageset.epoch_len() - logger.info(f"Original dataset size: {epoch_len}") - - if max_sample_size > 0 and epoch_len > max_sample_size: - indices = torch.randperm(epoch_len)[:max_sample_size].tolist() - subset = ImageSubset(imageset, indices=indices) - logger.info(f"Reducing dataset size for cnn_statistics to {max_sample_size}") - else: - indices = list(range(epoch_len)) - subset = ImageSubset(imageset, indices=indices) - logger.info("Using full dataset for cnn_statistics") - - return DataLoader(subset, batch_size=64, shuffle=False) - - -def _compute_receptive_fields( - device: torch.device, - head_layers: List[nn.Module], - input_shape: Tuple[int, ...], - out_channels: int, -) -> FloatArray: - """Compute receptive fields for a sequence of layers.""" - nclrs, hght, wdth = input_shape - imgsz = [1, nclrs, hght, wdth] - obs = torch.zeros(size=imgsz, device=device, requires_grad=True) - - head_model = nn.Sequential(*head_layers) - x = head_model(obs) - - hsz, wsz = x.shape[2:] - hidx = (hsz - 1) // 2 - widx = (wsz - 1) // 2 - - hrf_size, wrf_size, hmn, wmn = rf_size_and_start(head_layers, hidx, widx) - grads: List[Tensor] = [] - - for j in range(out_channels): - grad = torch.autograd.grad(x[0, j, hidx, widx], obs, retain_graph=True)[0] - grads.append(grad[0, :, hmn : hmn + hrf_size, wmn : wmn + wrf_size]) - - return torch.stack(grads).cpu().numpy() - - -def _analyze_layer( - device: torch.device, - dataloader: DataLoader[Tuple[Tensor, Tensor, int]], - head_layers: List[nn.Module], - input_shape: Tuple[int, ...], - out_channels: int, - channel_analysis: bool = True, -) -> LayerStatistics: - """Analyze statistics for a single layer.""" - head_model = nn.Sequential(*head_layers) - - # Always compute receptive fields - rfs = _compute_receptive_fields(device, head_layers, input_shape, out_channels) - - layer_spectral = None - layer_histograms = None - - # Compute channel-wise statistics only if requested - if channel_analysis: - layer_spectral = _layer_spectral_analysis(device, dataloader, head_model) - layer_histograms = _layer_pixel_histograms(device, dataloader, head_model) - - return LayerStatistics(rfs, out_channels, layer_spectral, layer_histograms) - - -def _analyze_input( - device: torch.device, - dataloader: DataLoader[Tuple[Tensor, Tensor, int]], - input_shape: Tuple[int, ...], - channel_analysis: bool, -) -> LayerStatistics: - """Analyze statistics for the input layer.""" - - input_spectral = None - input_histograms = None - - if channel_analysis: - input_spectral = _layer_spectral_analysis(device, dataloader, nn.Identity()) - input_histograms = _layer_pixel_histograms(device, dataloader, nn.Identity()) - - return LayerStatistics( - np.eye(input_shape[0])[:, :, np.newaxis, np.newaxis], - input_shape[0], - input_spectral, - input_histograms, - ) - - -def _layer_pixel_histograms( - device: torch.device, - dataloader: DataLoader[Tuple[Tensor, Tensor, int]], - model: nn.Module, - num_bins: int = 20, -) -> HistogramAnalysis: - """Compute histograms of pixel/activation values for each channel across all data in an imageset.""" - _, first_batch, _ = 
next(iter(dataloader)) - with torch.no_grad(): - first_batch = model(first_batch.to(device)) - num_channels: int = first_batch.shape[1] - - # Initialize variables for dynamic range computation - global_min = torch.full((num_channels,), float("inf"), device=device) - global_max = torch.full((num_channels,), float("-inf"), device=device) - - # First pass: compute global min and max - total_elements = 0 - - for _, batch, _ in dataloader: - with torch.no_grad(): - batch = model(batch.to(device)) - batch_min, _ = batch.view(-1, num_channels).min(dim=0) - batch_max, _ = batch.view(-1, num_channels).max(dim=0) - global_min = torch.min(global_min, batch_min) - global_max = torch.max(global_max, batch_max) - total_elements += batch.numel() // num_channels - - # Compute histogram parameters - hist_range: tuple[float, float] = (global_min.min().item(), global_max.max().item()) - - histograms: Tensor = torch.zeros( - (num_channels, num_bins), dtype=torch.float64, device=device - ) - - for _, batch, _ in dataloader: - with torch.no_grad(): - batch = model(batch.to(device)) - for c in range(num_channels): - channel_data = batch[:, c, :, :].reshape(-1) - hist = torch.histc( - channel_data, bins=num_bins, min=hist_range[0], max=hist_range[1] - ) - histograms[c] += hist - - bin_width = (hist_range[1] - hist_range[0]) / num_bins - normalized_histograms = histograms / (total_elements * bin_width / num_channels) - - return HistogramAnalysis( - normalized_histograms.cpu().numpy(), - np.linspace(hist_range[0], hist_range[1], num_bins + 1, dtype=np.float64), - ) - - -def _layer_spectral_analysis( - device: torch.device, - dataloader: DataLoader[Tuple[Tensor, Tensor, int]], - model: nn.Module, -) -> SpectralAnalysis: - """Compute spectral analysis statistics for each channel across all data in an imageset.""" - _, first_batch, _ = next(iter(dataloader)) - with torch.no_grad(): - first_batch = model(first_batch.to(device)) - image_size = first_batch.shape[1:] - - # Initialize variables for dynamic range computation - mean_power_spectrum = torch.zeros(image_size, dtype=torch.float64, device=device) - m2_power_spectrum = torch.zeros(image_size, dtype=torch.float64, device=device) - mean_autocorr = torch.zeros(image_size, dtype=torch.float64, device=device) - m2_autocorr = torch.zeros(image_size, dtype=torch.float64, device=device) - autocorr: Tensor = torch.zeros(image_size, dtype=torch.float64, device=device) - - count = 0 - - for _, batch, _ in dataloader: - with torch.no_grad(): - batch = model(batch.to(device)) - for image in batch: - count += 1 - - # Compute power spectrum - power_spectrum = torch.abs(fft.fft2(image)) ** 2 - - # Compute power spectrum statistics - mean_power_spectrum += power_spectrum - m2_power_spectrum += power_spectrum**2 - - # Compute normalized autocorrelation - autocorr = cast(Tensor, fft.ifft2(power_spectrum)).real - max_abs_autocorr = torch.amax( - torch.abs(autocorr), dim=(-2, -1), keepdim=True - ) - autocorr = autocorr / (max_abs_autocorr + 1e-8) - - # Compute autocorrelation statistics - mean_autocorr += autocorr - m2_autocorr += autocorr**2 - - mean_power_spectrum /= count - mean_autocorr /= count - var_power_spectrum = m2_power_spectrum / count - (mean_power_spectrum / count) ** 2 - var_autocorr = m2_autocorr / count - (mean_autocorr / count) ** 2 - - return SpectralAnalysis( - mean_power_spectrum.cpu().numpy(), - var_power_spectrum.cpu().numpy(), - mean_autocorr.cpu().numpy(), - var_autocorr.cpu().numpy(), - ) diff --git a/retinal_rl/analysis/transforms_analysis.py 
b/retinal_rl/analysis/transforms_analysis.py new file mode 100644 index 00000000..192225a0 --- /dev/null +++ b/retinal_rl/analysis/transforms_analysis.py @@ -0,0 +1,145 @@ +from dataclasses import dataclass + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.figure import Figure +from PIL import Image + +from retinal_rl.analysis.plot import make_image_grid +from retinal_rl.classification.imageset import Imageset +from retinal_rl.classification.transforms import ContinuousTransform +from retinal_rl.util import FloatArray + + +@dataclass +class TransformStatistics: + """Results of applying transformations to images.""" + + source_transforms: dict[str, dict[float, list[FloatArray]]] + noise_transforms: dict[str, dict[float, list[FloatArray]]] + +# TODO: Make structure match the analyze / plot structure as receptive_fields + +def transform_base_images( + imageset: Imageset, num_steps: int, num_images: int +) -> TransformStatistics: + """Apply transformations to a set of images from an Imageset.""" + images: list[Image.Image] = [] + + base_dataset = imageset.base_dataset + base_len = imageset.base_len + + for _ in range(num_images): + src, _ = base_dataset[np.random.randint(base_len)] + images.append(src) + + resultss = TransformStatistics( + source_transforms={}, + noise_transforms={}, + ) + + for transforms, results in [ + (imageset.source_transforms, resultss.source_transforms), + (imageset.noise_transforms, resultss.noise_transforms), + ]: + for transform in transforms: + if isinstance(transform, ContinuousTransform): + results[transform.name] = {} + trans_range: tuple[float, float] = transform.trans_range + transform_steps = np.linspace(*trans_range, num_steps) + for step in transform_steps: + results[transform.name][step] = [] + for img in images: + results[transform.name][step].append( + imageset.to_tensor(transform.transform(img, step)) + .cpu() + .numpy() + ) + + return resultss + + +def plot_transforms( + source_transforms: dict[str, dict[float, list[FloatArray]]], + noise_transforms: dict[str, dict[float, list[FloatArray]]], +) -> Figure: + """Plot effects of source and noise transforms on images. 
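+
+    Each transform occupies one row: every column shows a grid of the sampled
+    images at one interpolation step, labeled with the step value.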
+ + Args: + source_transforms: dictionary of source transforms (numpy arrays) + noise_transforms: dictionary of noise transforms (numpy arrays) + + Returns: + Figure containing the plotted transforms + """ + num_source_transforms = len(source_transforms) + num_noise_transforms = len(noise_transforms) + num_transforms = num_source_transforms + num_noise_transforms + num_images = len( + next(iter(source_transforms.values()))[ + next(iter(next(iter(source_transforms.values())).keys())) + ] + ) + + fig, axs = plt.subplots(num_transforms, 1, figsize=(20, 5 * num_transforms)) + if num_transforms == 1: + axs = [axs] + + transform_index = 0 + + # Plot source transforms + for transform_name, transform_data in source_transforms.items(): + ax = axs[transform_index] + steps = sorted(transform_data.keys()) + + # Create a grid of images for each step + images = [ + make_image_grid( + [(img * 0.5 + 0.5) for img in transform_data[step]], + nrow=num_images, + ) + for step in steps + ] + grid = make_image_grid(images, nrow=len(steps)) + + # Move channels last for imshow + grid_display = np.transpose(grid, (1, 2, 0)) + ax.imshow(grid_display) + ax.set_title(f"Source Transform: {transform_name}") + ax.set_xticks( + [(i + 0.5) * grid_display.shape[1] / len(steps) for i in range(len(steps))] + ) + ax.set_xticklabels([f"{step:.2f}" for step in steps]) + ax.set_yticks([]) + + transform_index += 1 + + # Plot noise transforms + for transform_name, transform_data in noise_transforms.items(): + ax = axs[transform_index] + steps = sorted(transform_data.keys()) + + # Create a grid of images for each step + images = [ + make_image_grid( + [(img * 0.5 + 0.5) for img in transform_data[step]], + nrow=num_images, + ) + for step in steps + ] + grid = make_image_grid(images, nrow=len(steps)) + + # Move channels last for imshow + grid_display = np.transpose(grid, (1, 2, 0)) + ax.imshow(grid_display) + ax.set_title(f"Noise Transform: {transform_name}") + ax.set_xticks( + [(i + 0.5) * grid_display.shape[1] / len(steps) for i in range(len(steps))] + ) + ax.set_xticklabels([f"{step:.2f}" for step in steps]) + ax.set_yticks([]) + + transform_index += 1 + + plt.tight_layout() + return fig \ No newline at end of file diff --git a/retinal_rl/util.py b/retinal_rl/util.py index 7fd84c8d..11c9dd0d 100644 --- a/retinal_rl/util.py +++ b/retinal_rl/util.py @@ -1,8 +1,9 @@ +import json import logging import re from enum import Enum from math import ceil, floor -from typing import List, Tuple, TypeVar, Union, cast +from typing import Any, List, Tuple, TypeVar, Union, cast import numpy as np from numpy.typing import NDArray @@ -16,6 +17,16 @@ T = TypeVar("T") +### IO Handling stuff + +class NumpyEncoder(json.JSONEncoder): + """JSON encoder that handles numpy arrays.""" + + def default(self, obj: Any) -> Any: + if isinstance(obj, np.ndarray): + return obj.tolist() + return super().default(obj) + ### Functions diff --git a/runner/frameworks/classification/analyze.py b/runner/frameworks/classification/analyze.py index 013ba78a..07584d0e 100644 --- a/runner/frameworks/classification/analyze.py +++ b/runner/frameworks/classification/analyze.py @@ -1,53 +1,51 @@ import json -import logging -import shutil -from dataclasses import asdict +from dataclasses import asdict, dataclass from pathlib import Path -from typing import Any, Dict, List +from typing import Optional -import matplotlib.pyplot as plt -import numpy as np import torch -import wandb -from matplotlib.figure import Figure from omegaconf import DictConfig +from retinal_rl.analysis import 
+from retinal_rl.analysis import receptive_fields
 from retinal_rl.analysis.plot import (
-    layer_receptive_field_plots,
+    FigureLogger,
     plot_brain_and_optimizers,
-    plot_channel_statistics,
-    plot_histories,
     plot_receptive_field_sizes,
-    plot_reconstructions,
-    plot_transforms,
 )
-from retinal_rl.analysis.statistics import (
-    CNNStatistics,
-    LayerStatistics,
-    cnn_statistics,
-    reconstruct_images,
+from retinal_rl.analysis.reconstructions import perform_reconstruction_analysis
+from retinal_rl.analysis.transforms_analysis import (
+    plot_transforms,
     transform_base_images,
 )
 from retinal_rl.classification.imageset import Imageset
 from retinal_rl.models.brain import Brain
-from retinal_rl.models.loss import ReconstructionLoss
 from retinal_rl.models.objective import ContextT, Objective
+from retinal_rl.util import FloatArray, NumpyEncoder
 
 ### Infrastructure ###
 
-
-logger = logging.getLogger(__name__)
-
 init_dir = "initialization_analysis"
 
 
-class NumpyEncoder(json.JSONEncoder):
-    """JSON encoder that handles numpy arrays."""
+@dataclass
+class AnalysesCfg:
+    run_dir: Path
+    plot_dir: Path
+    checkpoint_plot_dir: Path
+    data_dir: Path
+    use_wandb: bool
+    channel_analysis: bool
+    plot_sample_size: int
 
-    def default(self, obj: Any) -> Any:
-        if isinstance(obj, np.ndarray):
-            return obj.tolist()
-        return super().default(obj)
+    def __post_init__(self):
+        self.analyses_dir = Path(self.data_dir) / "analyses"
+
+        # Ensure all dirs exist
+        self.run_dir.mkdir(exist_ok=True)
+        self.plot_dir.mkdir(exist_ok=True)
+        self.checkpoint_plot_dir.mkdir(exist_ok=True)
+        self.analyses_dir.mkdir(exist_ok=True)
 
 
 ### Analysis ###
 
 
@@ -58,7 +56,7 @@ def analyze(
     device: torch.device,
     brain: Brain,
     objective: Objective[ContextT],
-    histories: Dict[str, List[float]],
+    histories: dict[str, list[float]],
     train_set: Imageset,
     test_set: Imageset,
     epoch: int,
@@ -66,71 +64,70 @@
 ):
     ## DictConfig
-    # Path creation
-    run_dir = Path(cfg.path.run_dir)
-    run_dir.mkdir(exist_ok=True)
-
-    plot_dir = Path(cfg.path.plot_dir)
-    plot_dir.mkdir(exist_ok=True)
-
-    checkpoint_plot_dir = Path(cfg.path.checkpoint_plot_dir)
-    checkpoint_plot_dir.mkdir(exist_ok=True)
-
-    analyses_dir = Path(cfg.path.data_dir) / "analyses"
-    analyses_dir.mkdir(exist_ok=True)
-
-    # Variables
-    use_wandb = cfg.logging.use_wandb
-    channel_analysis = cfg.logging.channel_analysis
-    plot_sample_size = cfg.logging.plot_sample_size
+    _cfg = AnalysesCfg(
+        Path(cfg.path.run_dir),
+        Path(cfg.path.plot_dir),
+        Path(cfg.path.checkpoint_plot_dir),
+        Path(cfg.path.data_dir),
+        cfg.logging.use_wandb,
+        cfg.logging.channel_analysis,
+        cfg.logging.plot_sample_size,
+    )
+    log = FigureLogger(
+        _cfg.use_wandb, _cfg.plot_dir, _cfg.checkpoint_plot_dir, _cfg.run_dir
+    )
 
     ## Analysis
+    log.plot_and_save_histories(histories)
 
-    if not use_wandb:
-        _plot_and_save_histories(plot_dir, histories)
+    # # Save CNN statistics # TODO: how to do this now...
+    # with open(_cfg.analyses_dir / f"cnn_stats_epoch_{epoch}.json", "w") as f:
+    #     json.dump(asdict(cnn_stats), f, cls=NumpyEncoder)
 
-    # Get CNN statistics and save them
-    cnn_stats = cnn_statistics(
-        device,
-        test_set,
-        brain,
-        channel_analysis,
-        plot_sample_size,
+    # perform different analyses
+    input_shape, rf_result = receptive_fields.analyze(brain, device)
+    receptive_fields.plot(
+        log,
+        rf_result,
+        epoch,
+        copy_checkpoint,
     )
 
-    # Save CNN statistics
-    with open(analyses_dir / f"cnn_stats_epoch_{epoch}.json", "w") as f:
-        json.dump(asdict(cnn_stats), f, cls=NumpyEncoder)
+    if _cfg.channel_analysis:
+        spectral_result = channel_ana.spectral_analysis(
+            device, test_set, brain, _cfg.plot_sample_size
+        )
+        histogram_result = channel_ana.histogram_analysis(
+            device, test_set, brain, _cfg.plot_sample_size
+        )
+        # TODO: Do we really want to replot rfs here?
+        channel_ana.plot(
+            log,
+            rf_result,
+            spectral_result,
+            histogram_result,
+            epoch,
+            copy_checkpoint,
+        )
+    else:
+        spectral_result, histogram_result = None, None
 
+    # plot results
     if epoch == 0:
-        _perform_initialization_analysis(
-            channel_analysis,
-            use_wandb,
-            analyses_dir,
-            plot_dir,
-            checkpoint_plot_dir,
-            run_dir,
-            brain,
-            objective,
+        _default_initialization_plots(log, brain, objective, input_shape, rf_result)
+        _extended_initialization_plots(
+            log,
+            _cfg.channel_analysis,
+            _cfg.analyses_dir,
             train_set,
-            cnn_stats,
+            rf_result,
+            spectral_result,
+            histogram_result,
         )
 
-        _analyze_layers(
-            channel_analysis,
-            use_wandb,
-            plot_dir,
-            checkpoint_plot_dir,
-            cnn_stats,
-            epoch,
-            copy_checkpoint,
-        )
-
-    _perform_reconstruction_analysis(
-        use_wandb,
-        analyses_dir,
-        plot_dir,
-        checkpoint_plot_dir,
+    perform_reconstruction_analysis(
+        log,
+        _cfg.analyses_dir,
         device,
         brain,
         objective,
@@ -140,61 +137,48 @@
         train_set,
         test_set,
         epoch,
         copy_checkpoint,
     )
 
-    hist_fig = plot_histories(histories)
-    _save_figure(plot_dir, "", "histories", hist_fig)
-    plt.close(hist_fig)
+    log.plot_and_save_histories(histories, save_always=True)
 
 
-def _plot_and_save_histories(plot_dir: Path, histories: Dict[str, List[float]]):
-    hist_fig = plot_histories(histories)
-    _save_figure(plot_dir, "", "histories", hist_fig)
-    plt.close(hist_fig)
-
-
-def _perform_initialization_analysis(
-    channel_analysis: bool,
-    use_wandb: bool,
-    analyses_dir: Path,
-    plot_dir: Path,
-    checkpoint_plot_dir: Path,
-    run_dir: Path,
+def _default_initialization_plots(
+    log: FigureLogger,
     brain: Brain,
     objective: Objective[ContextT],
-    train_set: Imageset,
-    cnn_stats: CNNStatistics,
+    input_shape: tuple[int, ...],
+    rf_result: dict[str, FloatArray],
 ):
-    summary = brain.scan()
-    filepath = run_dir / "brain_summary.txt"
-    filepath.write_text(summary)
-
-    if use_wandb:
-        wandb.save(str(filepath), base_path=run_dir, policy="now")
+    log.save_summary(brain)
 
+    # TODO: Move this somewhere accessible for RL
     # TODO: This is a bit of a hack, we should refactor this to get the relevant information out of cnn_stats
-    rf_sizes_fig = plot_receptive_field_sizes(**asdict(cnn_stats))
-    _process_figure(
-        use_wandb,
-        plot_dir,
-        checkpoint_plot_dir,
-        False,
+    rf_sizes_fig = plot_receptive_field_sizes(input_shape, rf_result)
+    log.log_figure(
         rf_sizes_fig,
         init_dir,
         "receptive_field_sizes",
         0,
+        False,
     )
 
     graph_fig = plot_brain_and_optimizers(brain, objective)
-    _process_figure(
-        use_wandb,
-        plot_dir,
-        checkpoint_plot_dir,
-        False,
+    log.log_figure(
         graph_fig,
         init_dir,
         "brain_graph",
         0,
+        False,
     )
 
 
+def _extended_initialization_plots(
+    log: FigureLogger,
+    channel_analysis: bool,
+    analyses_dir: Path,
+    train_set: Imageset,
+    rf_result: dict[str, FloatArray],
+    spectral_result: Optional[dict[str, channel_ana.SpectralAnalysis]] = None,
+    histogram_result: Optional[dict[str, channel_ana.HistogramAnalysis]] = None,
+):
     transforms = transform_base_images(train_set, num_steps=5, num_images=2)
     # Save transform statistics
     transform_path = analyses_dir / "transforms.json"
@@ -202,232 +186,52 @@
         json.dump(asdict(transforms), f, cls=NumpyEncoder)
 
     transforms_fig = plot_transforms(**asdict(transforms))
-    _process_figure(
-        use_wandb,
-        plot_dir,
-        checkpoint_plot_dir,
-        False,
+    log.log_figure(
         transforms_fig,
         init_dir,
         "transforms",
         0,
+        False,
     )
 
-    _analyze_input_layer(
-        use_wandb,
-        plot_dir,
-        checkpoint_plot_dir,
-        cnn_stats.layers["input"],
-        channel_analysis,
-    )
-
-
-def _analyze_layers(
-    channel_analysis: bool,
-    use_wandb: bool,
-    plot_dir: Path,
-    checkpoint_plot_dir: Path,
-    cnn_stats: CNNStatistics,
-    epoch: int,
-    copy_checkpoint: bool,
-):
-    for layer_name, layer_data in cnn_stats.layers.items():
-        if layer_name != "input":
-            _analyze_regular_layer(
-                use_wandb,
-                plot_dir,
-                checkpoint_plot_dir,
-                layer_name,
-                layer_data,
-                epoch,
-                copy_checkpoint,
-                channel_analysis,
-            )
+    if spectral_result and histogram_result:
+        _analyze_input_layer(
+            log,
+            rf_result["input"],
+            spectral_result["input"],
+            histogram_result["input"],
+            channel_analysis,
+        )
 
 
 def _analyze_input_layer(
-    use_wandb: bool,
-    plot_dir: Path,
-    checkpoint_plot_dir: Path,
-    layer_statistics: LayerStatistics,
+    log: FigureLogger,
+    rf_result: FloatArray,
+    spectral_result: channel_ana.SpectralAnalysis,
+    histogram_result: channel_ana.HistogramAnalysis,
     channel_analysis: bool,
 ):
-    layer_rfs = layer_receptive_field_plots(layer_statistics.receptive_fields)
-    _process_figure(
-        use_wandb,
-        plot_dir,
-        checkpoint_plot_dir,
-        False,
+    layer_rfs = receptive_fields.layer_receptive_field_plots(rf_result)
+    log.log_figure(
         layer_rfs,
         init_dir,
         "input_rfs",
         0,
-    )
-
+        False,
+    )  # TODO: What's the purpose of this? It's just RGB, I guess.
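+    # NOTE: the "input" statistics are computed with nn.Identity() as the
+    # model (see channel_analysis), so the channels plotted below are
+    # presumably just the raw image planes (e.g. RGB).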
     if channel_analysis:
-        layer_dict = asdict(layer_statistics)
-        num_channels = int(layer_dict.pop("num_channels"))
-        for channel in range(num_channels):
-            channel_fig = plot_channel_statistics(
-                **layer_dict, layer_name="input", channel=channel
+        for channel in range(rf_result.shape[0]):
+            channel_fig = channel_ana.layer_channel_plots(
+                rf_result,
+                spectral_result,
+                histogram_result,
+                layer_name="input",
+                channel=channel,
             )
-            _process_figure(
-                use_wandb,
-                plot_dir,
-                checkpoint_plot_dir,
-                False,
+            log.log_figure(
                 channel_fig,
                 init_dir,
                 f"input_channel_{channel}",
                 0,
+                False,
             )
-
-
-def _analyze_regular_layer(
-    use_wandb: bool,
-    plot_dir: Path,
-    checkpoint_plot_dir: Path,
-    layer_name: str,
-    layer_statistics: LayerStatistics,
-    epoch: int,
-    copy_checkpoint: bool,
-    channel_analysis: bool,
-):
-    layer_rfs = layer_receptive_field_plots(layer_statistics.receptive_fields)
-    _process_figure(
-        use_wandb,
-        plot_dir,
-        checkpoint_plot_dir,
-        copy_checkpoint,
-        layer_rfs,
-        "receptive_fields",
-        f"{layer_name}",
-        epoch,
-    )
-
-    if channel_analysis:
-        layer_dict = asdict(layer_statistics)
-        num_channels = int(layer_dict.pop("num_channels"))
-        for channel in range(num_channels):
-            channel_fig = plot_channel_statistics(
-                **layer_dict, layer_name=layer_name, channel=channel
-            )
-
-            _process_figure(
-                use_wandb,
-                plot_dir,
-                checkpoint_plot_dir,
-                copy_checkpoint,
-                channel_fig,
-                f"{layer_name}_layer_channel_analysis",
-                f"channel_{channel}",
-                epoch,
-            )
-
-
-def _perform_reconstruction_analysis(
-    use_wandb: bool,
-    analyses_dir: Path,
-    plot_dir: Path,
-    checkpoint_plot_dir: Path,
-    device: torch.device,
-    brain: Brain,
-    objective: Objective[ContextT],
-    train_set: Imageset,
-    test_set: Imageset,
-    epoch: int,
-    copy_checkpoint: bool,
-):
-    reconstruction_decoders = [
-        loss.target_decoder
-        for loss in objective.losses
-        if isinstance(loss, ReconstructionLoss)
-    ]
-
-    for decoder in reconstruction_decoders:
-        norm_means, norm_stds = train_set.normalization_stats
-        rec_dict = asdict(
-            reconstruct_images(device, brain, decoder, train_set, test_set, 5)
-        )
-        # Save the reconstructions
-        rec_path = analyses_dir / f"{decoder}_reconstructions_epoch_{epoch}.json"
-        with open(rec_path, "w") as f:
-            json.dump(rec_dict, f, cls=NumpyEncoder)
-
-        recon_fig = plot_reconstructions(
-            norm_means,
-            norm_stds,
-            *rec_dict["train"].values(),
-            *rec_dict["test"].values(),
-            num_samples=5,
-        )
-        _process_figure(
-            use_wandb,
-            plot_dir,
-            checkpoint_plot_dir,
-            copy_checkpoint,
-            recon_fig,
-            "reconstruction",
-            f"{decoder}_reconstructions",
-            epoch,
-        )
-
-
-### Helper Functions ###
-
-
-def _save_figure(plot_dir: Path, sub_dir: str, file_name: str, fig: Figure) -> None:
-    dir = plot_dir / sub_dir
-    dir.mkdir(exist_ok=True)
-    file_path = dir / f"{file_name}.png"
-    fig.savefig(file_path)
-
-
-def _checkpoint_copy(
-    plot_dir: Path, checkpoint_plot_dir: Path, sub_dir: str, file_name: str, epoch: int
-) -> None:
-    src_path = plot_dir / sub_dir / f"{file_name}.png"
-
-    dest_dir = checkpoint_plot_dir / f"epoch_{epoch}" / sub_dir
-    dest_dir.mkdir(parents=True, exist_ok=True)
-    dest_path = dest_dir / f"{file_name}.png"
-
-    shutil.copy2(src_path, dest_path)
-
-
-def _wandb_title(title: str) -> str:
-    # Split the title by slashes
-    parts = title.split("/")
-
-    def capitalize_part(part: str) -> str:
-        # Split the part by dashes
-        words = part.split("_")
-        # Capitalize each word
-        capitalized_words = [word.capitalize() for word in words]
-        # Join the words with spaces
-        return " ".join(capitalized_words)
-
-    # Capitalize each part, then join with slashes
-    capitalized_parts = [capitalize_part(part) for part in parts]
-    return "/".join(capitalized_parts)
-
-
-def _process_figure(
-    use_wandb: bool,
-    plot_dir: Path,
-    checkpoint_plot_dir: Path,
-    copy_checkpoint: bool,
-    fig: Figure,
-    sub_dir: str,
-    file_name: str,
-    epoch: int,
-) -> None:
-    if use_wandb:
-        title = f"{_wandb_title(sub_dir)}/{_wandb_title(file_name)}"
-        img = wandb.Image(fig)
-        wandb.log({title: img}, commit=False)
-    else:
-        _save_figure(plot_dir, sub_dir, file_name, fig)
-        if copy_checkpoint:
-            _checkpoint_copy(plot_dir, checkpoint_plot_dir, sub_dir, file_name, epoch)
-    plt.close(fig)
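The commented-out "Save CNN statistics" block in analyze() can be revived against the new result types, since NumpyEncoder now lives in retinal_rl.util. A minimal sketch of the JSON round trip it enables, using the same asdict/json.dump pattern the patch already applies to TransformStatistics (the transform name, array shape, and output path below are illustrative only, not part of this patch):

    import json
    from dataclasses import asdict

    import numpy as np

    from retinal_rl.analysis.transforms_analysis import TransformStatistics
    from retinal_rl.util import NumpyEncoder

    # asdict() flattens the dataclass into nested dicts; NumpyEncoder then
    # turns every ndarray leaf into a JSON-safe nested list.
    stats = TransformStatistics(
        source_transforms={"example_transform": {0.5: [np.zeros((3, 8, 8))]}},
        noise_transforms={},
    )

    with open("transforms.json", "w") as f:
        json.dump(asdict(stats), f, cls=NumpyEncoder)

Note that json.dump coerces the float step keys to strings, so any loader has to map them back with float(key).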