Commit 7168e02
doc: mark rl analysis stuff as deprecated
fabioseel committed Oct 29, 2024
1 parent 3478315 commit 7168e02
Showing 1 changed file with 60 additions and 21 deletions.
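
Note: the @deprecated decorator added throughout this file comes from typing_extensions (PEP 702). It emits a DeprecationWarning when the decorated function is called, and type checkers flag call sites statically. A minimal sketch of the pattern this commit applies, with a hypothetical old_stats_fn standing in for the functions in the diff:

from typing_extensions import deprecated  # requires typing_extensions >= 4.5


@deprecated("Use functions of retinal_rl.analysis.statistics")
def old_stats_fn() -> None:
    """Hypothetical stand-in for one of the deprecated analysis functions."""


old_stats_fn()  # warns: DeprecationWarning: Use functions of retinal_rl.analysis.statistics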
retinal_rl/rl/analysis/statistics.py: 60 additions, 21 deletions
@@ -9,12 +9,14 @@
 from torch import Tensor, nn
 from torch.utils.data import Dataset
 from tqdm import tqdm
+from typing_extensions import deprecated

 from retinal_rl.models.brain import Brain
 from retinal_rl.models.circuits.convolutional import ConvolutionalEncoder
 from retinal_rl.util import encoder_out_size, rf_size_and_start


+@deprecated("Use functions of retinal_rl.analysis.statistics")
 def gradient_receptive_fields(
     device: torch.device, enc: ConvolutionalEncoder
 ) -> Dict[str, NDArray[np.float64]]:
@@ -52,10 +54,10 @@ def gradient_receptive_fields(

         # Assert min max is in bounds
         # potential TODO: change input size if rf is larger than actual input
-        h_min = max(0,h_min)
-        w_min = max(0,w_min)
-        hrf_size = min(hght,hrf_size)
-        wrf_size = min(wdth,wrf_size)
+        h_min = max(0, h_min)
+        w_min = max(0, w_min)
+        hrf_size = min(hght, hrf_size)
+        wrf_size = min(wdth, wrf_size)

         h_max = h_min + hrf_size
         w_max = w_min + wrf_size
@@ -74,13 +76,18 @@

     return stas

-def _activation_triggered_average(model: nn.Sequential, n_batch: int = 2048, rf_size=None, device=None):
+
+def _activation_triggered_average(
+    model: nn.Sequential, n_batch: int = 2048, rf_size=None, device=None
+):
     model.eval()
     if rf_size is None:
         _out_channels, input_size = get_input_output_shape(model)
     else:
         input_size = rf_size
-    input_tensor = torch.randn((n_batch, *input_size), requires_grad=False, device=device)
+    input_tensor = torch.randn(
+        (n_batch, *input_size), requires_grad=False, device=device
+    )
     output = model(input_tensor)
     output = sum_collapse_output(output)
     input_tensor = input_tensor[:, None, :, :, :].expand(
@@ -93,31 +100,50 @@ def _activation_triggered_average(model: nn.Sequential, n_batch: int = 2048, rf_
     weighted = (weights * input_tensor).sum(0)
     return weighted.cpu().detach(), weight_sums.cpu().detach()

+
 def activation_triggered_average(
-    model: nn.Sequential, n_batch: int = 2048, n_iter: int = 1, rf_size=None, device=None
+    model: nn.Sequential,
+    n_batch: int = 2048,
+    n_iter: int = 1,
+    rf_size=None,
+    device=None,
 ) -> Dict[str, NDArray[np.float64]]:
     # TODO: WIP
     warnings.warn("Code is not tested and might contain bugs.")
     stas: Dict[str, NDArray[np.float64]] = {}
     with torch.no_grad():
-        for index, (layer_name, mdl) in tqdm(enumerate(model.named_children()), total=len(model)):
-            weighted, weight_sums = _activation_triggered_average(model[:index+1], n_batch, device=device)
-            for _ in tqdm(range(n_iter - 1), total=n_iter-1, leave=False):
-                it_weighted, it_weight_sums = _activation_triggered_average(model[:index+1], n_batch, rf_size, device=device)
+        for index, (layer_name, mdl) in tqdm(
+            enumerate(model.named_children()), total=len(model)
+        ):
+            weighted, weight_sums = _activation_triggered_average(
+                model[: index + 1], n_batch, device=device
+            )
+            for _ in tqdm(range(n_iter - 1), total=n_iter - 1, leave=False):
+                it_weighted, it_weight_sums = _activation_triggered_average(
+                    model[: index + 1], n_batch, rf_size, device=device
+                )
                 weighted += it_weighted
                 weight_sums += it_weight_sums
-            stas[layer_name] = (weighted.cpu().detach() / weight_sums[:, None, None, None] / len(weight_sums)).numpy()
+            stas[layer_name] = (
+                weighted.cpu().detach()
+                / weight_sums[:, None, None, None]
+                / len(weight_sums)
+            ).numpy()
         torch.cuda.empty_cache()
     return stas


+@deprecated("Use functions of retinal_rl.analysis.statistics")
 def sum_collapse_output(out_tensor):
     if len(out_tensor.shape) > 2:
-        sum_dims = [2+i for i in range(len(out_tensor.shape)-2)]
+        sum_dims = [2 + i for i in range(len(out_tensor.shape) - 2)]
         out_tensor = torch.sum(out_tensor, dim=sum_dims)
     return out_tensor

-def _find_last_layer_shape(model: nn.Sequential) -> Tuple[int, Optional[int], Optional[int], Optional[int], bool]:
+
+def _find_last_layer_shape(
+    model: nn.Sequential,
+) -> Tuple[int, Optional[int], Optional[int], Optional[int], bool]:
     _first = 0
     down_stream_linear = False
     num_outputs = None
@@ -133,22 +159,31 @@ def _find_last_layer_shape(model: nn.Sequential) -> Tuple[int, Optional[int], Op
         if isinstance(layer, nn.Conv2d):
             num_outputs = layer.out_channels
             in_channels = layer.in_channels
-            in_size = layer.in_channels * ((layer.kernel_size[0]-1)*layer.dilation[0]+1) ** 2
+            in_size = (
+                layer.in_channels
+                * ((layer.kernel_size[0] - 1) * layer.dilation[0] + 1) ** 2
+            )
             break
         if isinstance(layer, (nn.MaxPool2d, nn.AvgPool2d)):
-            for prev_layer in reversed(model[:-i-1]):
+            for prev_layer in reversed(model[: -i - 1]):
                 if isinstance(prev_layer, nn.Conv2d):
                     in_channels = prev_layer.out_channels
                     break
                 if isinstance(prev_layer, nn.Linear):
-                    in_channels=1
+                    in_channels = 1
                 else:
                     raise TypeError("layer before pooling needs to be conv or linear")
-            _kernel_size = layer.kernel_size if isinstance(layer.kernel_size, int) else layer.kernel_size[0]
+            _kernel_size = (
+                layer.kernel_size
+                if isinstance(layer.kernel_size, int)
+                else layer.kernel_size[0]
+            )
             in_size = _kernel_size**2 * in_channels
             break
     return _first, num_outputs, in_size, in_channels, down_stream_linear

+
+@deprecated("Use functions of retinal_rl.analysis.statistics")
 def get_input_output_shape(model: nn.Sequential):
     """
     Calculates the 'minimal' input and output of a sequential model.
@@ -159,7 +194,9 @@ def get_input_output_shape(model: nn.Sequential):
     TODO: Check if still needed, function near duplicate of some of Sachas code
     """

-    _first, num_outputs, in_size, in_channels, down_stream_linear = _find_last_layer_shape(model)
+    _first, num_outputs, in_size, in_channels, down_stream_linear = (
+        _find_last_layer_shape(model)
+    )

     for i, layer in enumerate(reversed(model[:-_first])):
         if isinstance(layer, nn.Linear):
@@ -176,11 +213,11 @@ def get_reconstructions(
             in_size = (
                 (in_size - 1) * layer.stride[0]
                 - 2 * layer.padding[0] * down_stream_linear
-                + ((layer.kernel_size[0]-1)*layer.dilation[0]+1)
+                + ((layer.kernel_size[0] - 1) * layer.dilation[0] + 1)
             )
             in_size = in_size**2 * in_channels
         elif isinstance(layer, (nn.MaxPool2d, nn.AvgPool2d)):
-            for prev_layer in reversed(model[:-i-_first-1]):
+            for prev_layer in reversed(model[: -i - _first - 1]):
                 if isinstance(prev_layer, nn.Conv2d):
                     in_channels = prev_layer.out_channels
                     break
@@ -196,6 +233,8 @@ def get_input_output_shape(model: nn.Sequential):
     input_size = (in_channels, in_size, in_size)
     return num_outputs, input_size

+
+@deprecated("Use functions of retinal_rl.analysis.statistics")
 def get_reconstructions(
     device: torch.device,
     brain: Brain,
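
For readers unfamiliar with the technique: activation_triggered_average above approximates each unit's receptive field by averaging random noise inputs weighted by the unit's response, an analogue of a spike-triggered average. A toy restatement of the idea under an assumed two-layer model (the model and sizes here are illustrative, not taken from the repo):

import torch
from torch import nn

model = nn.Sequential(nn.Conv2d(1, 4, 5), nn.ReLU())
x = torch.randn(256, 1, 16, 16)  # random probe stimuli
with torch.no_grad():
    acts = model(x).sum(dim=(2, 3))             # summed activation per unit: (batch, channels)
    weights = acts / acts.sum(0, keepdim=True)  # normalize responses over the batch
    ata = torch.einsum("bc,bihw->cihw", weights, x)  # response-weighted average stimulus
print(ata.shape)  # torch.Size([4, 1, 16, 16]): one average stimulus per channel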

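Similarly, the inverse convolution arithmetic that get_input_output_shape relies on can be sanity-checked in isolation. Ignoring the down_stream_linear padding adjustment, the minimal input that yields o outputs from a Conv2d is (o - 1) * stride - 2 * padding + ((kernel - 1) * dilation + 1); the layer and numbers below are illustrative:

import torch
from torch import nn

conv = nn.Conv2d(1, 1, kernel_size=3, stride=2, padding=0, dilation=1)
out_size = 4
# (4 - 1) * 2 - 0 + ((3 - 1) * 1 + 1) = 9
min_in = (out_size - 1) * conv.stride[0] - 2 * conv.padding[0] + (
    (conv.kernel_size[0] - 1) * conv.dilation[0] + 1
)
x = torch.randn(1, 1, min_in, min_in)
print(conv(x).shape)  # torch.Size([1, 1, 4, 4])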