diff --git a/retinal_rl/rl/analysis/statistics.py b/retinal_rl/rl/analysis/statistics.py
index 77689268..4eb2107b 100644
--- a/retinal_rl/rl/analysis/statistics.py
+++ b/retinal_rl/rl/analysis/statistics.py
@@ -9,12 +9,14 @@
 from torch import Tensor, nn
 from torch.utils.data import Dataset
 from tqdm import tqdm
+from typing_extensions import deprecated
 
 from retinal_rl.models.brain import Brain
 from retinal_rl.models.circuits.convolutional import ConvolutionalEncoder
 from retinal_rl.util import encoder_out_size, rf_size_and_start
 
 
+@deprecated("Use functions of retinal_rl.analysis.statistics")
 def gradient_receptive_fields(
     device: torch.device, enc: ConvolutionalEncoder
 ) -> Dict[str, NDArray[np.float64]]:
@@ -52,10 +54,10 @@
 
         # Assert min max is in bounds
         # potential TODO: change input size if rf is larger than actual input
-        h_min = max(0,h_min)
-        w_min = max(0,w_min)
-        hrf_size = min(hght,hrf_size)
-        wrf_size = min(wdth,wrf_size)
+        h_min = max(0, h_min)
+        w_min = max(0, w_min)
+        hrf_size = min(hght, hrf_size)
+        wrf_size = min(wdth, wrf_size)
 
         h_max = h_min + hrf_size
         w_max = w_min + wrf_size
@@ -74,13 +76,18 @@
 
     return stas
 
-def _activation_triggered_average(model: nn.Sequential, n_batch: int = 2048, rf_size=None, device=None):
+
+def _activation_triggered_average(
+    model: nn.Sequential, n_batch: int = 2048, rf_size=None, device=None
+):
     model.eval()
     if rf_size is None:
         _out_channels, input_size = get_input_output_shape(model)
     else:
         input_size = rf_size
-    input_tensor = torch.randn((n_batch, *input_size), requires_grad=False, device=device)
+    input_tensor = torch.randn(
+        (n_batch, *input_size), requires_grad=False, device=device
+    )
     output = model(input_tensor)
     output = sum_collapse_output(output)
     input_tensor = input_tensor[:, None, :, :, :].expand(
@@ -93,31 +100,50 @@
     weighted = (weights * input_tensor).sum(0)
     return weighted.cpu().detach(), weight_sums.cpu().detach()
 
+
 def activation_triggered_average(
-    model: nn.Sequential, n_batch: int = 2048, n_iter: int = 1, rf_size=None, device=None
+    model: nn.Sequential,
+    n_batch: int = 2048,
+    n_iter: int = 1,
+    rf_size=None,
+    device=None,
 ) -> Dict[str, NDArray[np.float64]]:
     # TODO: WIP
     warnings.warn("Code is not tested and might contain bugs.")
     stas: Dict[str, NDArray[np.float64]] = {}
     with torch.no_grad():
-        for index, (layer_name, mdl) in tqdm(enumerate(model.named_children()), total=len(model)):
-            weighted, weight_sums = _activation_triggered_average(model[:index+1], n_batch, device=device)
-            for _ in tqdm(range(n_iter - 1), total=n_iter-1, leave=False):
-                it_weighted, it_weight_sums = _activation_triggered_average(model[:index+1], n_batch, rf_size, device=device)
+        for index, (layer_name, mdl) in tqdm(
+            enumerate(model.named_children()), total=len(model)
+        ):
+            weighted, weight_sums = _activation_triggered_average(
+                model[: index + 1], n_batch, device=device
+            )
+            for _ in tqdm(range(n_iter - 1), total=n_iter - 1, leave=False):
+                it_weighted, it_weight_sums = _activation_triggered_average(
+                    model[: index + 1], n_batch, rf_size, device=device
+                )
                 weighted += it_weighted
                 weight_sums += it_weight_sums
-            stas[layer_name] = (weighted.cpu().detach() / weight_sums[:, None, None, None] / len(weight_sums)).numpy()
+            stas[layer_name] = (
+                weighted.cpu().detach()
+                / weight_sums[:, None, None, None]
+                / len(weight_sums)
+            ).numpy()
     torch.cuda.empty_cache()
     return stas
 
+
+@deprecated("Use functions of retinal_rl.analysis.statistics")
 def sum_collapse_output(out_tensor):
     if len(out_tensor.shape) > 2:
-        sum_dims = [2+i for i in range(len(out_tensor.shape)-2)]
+        sum_dims = [2 + i for i in range(len(out_tensor.shape) - 2)]
         out_tensor = torch.sum(out_tensor, dim=sum_dims)
     return out_tensor
 
 
-def _find_last_layer_shape(model: nn.Sequential) -> Tuple[int, Optional[int], Optional[int], Optional[int], bool]:
+def _find_last_layer_shape(
+    model: nn.Sequential,
+) -> Tuple[int, Optional[int], Optional[int], Optional[int], bool]:
     _first = 0
     down_stream_linear = False
     num_outputs = None
@@ -133,22 +159,31 @@
         if isinstance(layer, nn.Conv2d):
             num_outputs = layer.out_channels
             in_channels = layer.in_channels
-            in_size = layer.in_channels * ((layer.kernel_size[0]-1)*layer.dilation[0]+1) ** 2
+            in_size = (
+                layer.in_channels
+                * ((layer.kernel_size[0] - 1) * layer.dilation[0] + 1) ** 2
+            )
             break
         if isinstance(layer, (nn.MaxPool2d, nn.AvgPool2d)):
-            for prev_layer in reversed(model[:-i-1]):
+            for prev_layer in reversed(model[: -i - 1]):
                 if isinstance(prev_layer, nn.Conv2d):
                     in_channels = prev_layer.out_channels
                     break
                 if isinstance(prev_layer, nn.Linear):
-                    in_channels=1
+                    in_channels = 1
                 else:
                     raise TypeError("layer before pooling needs to be conv or linear")
-            _kernel_size = layer.kernel_size if isinstance(layer.kernel_size, int) else layer.kernel_size[0]
+            _kernel_size = (
+                layer.kernel_size
+                if isinstance(layer.kernel_size, int)
+                else layer.kernel_size[0]
+            )
             in_size = _kernel_size**2 * in_channels
             break
     return _first, num_outputs, in_size, in_channels, down_stream_linear
 
+
+@deprecated("Use functions of retinal_rl.analysis.statistics")
 def get_input_output_shape(model: nn.Sequential):
     """
     Calculates the 'minimal' input and output of a sequential model.
@@ -159,7 +194,9 @@
 
     TODO: Check if still needed, function near duplicate of some of Sachas code
     """
-    _first, num_outputs, in_size, in_channels, down_stream_linear = _find_last_layer_shape(model)
+    _first, num_outputs, in_size, in_channels, down_stream_linear = (
+        _find_last_layer_shape(model)
+    )
 
     for i, layer in enumerate(reversed(model[:-_first])):
         if isinstance(layer, nn.Linear):
@@ -176,11 +213,11 @@
             in_size = (
                 (in_size - 1) * layer.stride[0]
                 - 2 * layer.padding[0] * down_stream_linear
-                + ((layer.kernel_size[0]-1)*layer.dilation[0]+1)
+                + ((layer.kernel_size[0] - 1) * layer.dilation[0] + 1)
             )
             in_size = in_size**2 * in_channels
         elif isinstance(layer, (nn.MaxPool2d, nn.AvgPool2d)):
-            for prev_layer in reversed(model[:-i-_first-1]):
+            for prev_layer in reversed(model[: -i - _first - 1]):
                 if isinstance(prev_layer, nn.Conv2d):
                     in_channels = prev_layer.out_channels
                     break
@@ -196,6 +233,8 @@
         input_size = (in_channels, in_size, in_size)
     return num_outputs, input_size
 
+
+@deprecated("Use functions of retinal_rl.analysis.statistics")
 def get_reconstructions(
     device: torch.device,
     brain: Brain,