"""
Debugging script for inter-session alignment experiments.

Generates a ground-truth recording, detects and localizes peaks, then
compares several ways of summarising the per-chunk spatial activity
histograms (entire-session, chunk mean, chunk median, leading
eigenvector, Poisson maximum-likelihood rate).
"""
import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import minimize

import spikeinterface.full as si
from spikeinterface.generation.session_displacement_generator import (
    generate_session_displacement_recordings,
)
from spikeinterface.sortingcomponents.motion.motion_utils import (
    make_2d_motion_histogram,
    make_3d_motion_histograms,
)
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
# Generate a ground-truth recording where every unit fires a lot, with
# high amplitude, and sits close to the probe, so all units are picked up.
# This just makes it easy to play around with the units, e.g. if
# specifying 5 units, 5 unit peaks are clearly visible and none are lost
# because their position is too far from the probe.
#
# NOTE: the original draft set alpha=(100.0, 500.0) and b=c=(0.1, 1) and
# then immediately overwrote them below; the final values are inlined
# here so each parameter is defined exactly once.
default_unit_params_range = dict(
    # Widened upper bound (was 500.0): either do this or change the
    # margin in generate_unit_locations_kwargs.
    alpha=(100, 600),
    depolarization_ms=(0.09, 0.14),
    repolarization_ms=(0.5, 0.8),
    recovery_ms=(1.0, 1.5),
    positive_amplitude=(0.1, 0.25),
    smooth_ms=(0.03, 0.07),
    spatial_decay=(20, 40),
    propagation_speed=(250.0, 350.0),
    # Raised lower bounds (were 0.1): fatter units are easier to detect.
    b=(0.5, 1),
    c=(0.5, 1),
    x_angle=(0, np.pi),
    y_angle=(0, np.pi),
    z_angle=(0, np.pi),
)
# Generate a single session (50 s, no inter-session shift); only the
# first (and only) recording in the returned list is used below.
rec_list, _ = generate_session_displacement_recordings(
    non_rigid_gradient=None,  # e.g. 0.05; NOTE this applies the same nonlinearity to both x and y
    num_units=100,
    recording_durations=(50,),  # TODO: checks on inputs
    recording_shifts=((0, 0),),
    recording_amplitude_scalings=None,
    seed=None,
    # generate_sorting_kwargs=dict(firing_rates=(50, 100), refractory_period_ms=4.0),
    # generate_templates_kwargs=dict(unit_params=default_unit_params_range, ms_before=1.5, ms_after=3),
    # generate_unit_locations_kwargs=dict(
    #     margin_um=0.0,  # at e.g. 20, units fall off the probe edge and are too low-amplitude to be picked up
    #     minimum_z=5.0,
    #     maximum_z=45.0,
    #     minimum_distance=18.0,
    #     max_iteration=100,
    #     distance_strict=False,
    # ),
)

recording = rec_list[0]
# Detect peaks and estimate each peak's location on the probe.
peaks = detect_peaks(recording, method="locally_exclusive")
peak_locations = localize_peaks(recording, peaks, method="grid_convolution")

# Sanity-check raster of peak positions over time. clim is fixed so the
# colour scale is comparable across plots.
si.plot_drift_raster_map(
    recording=recording,
    peaks=peaks,
    peak_locations=peak_locations,
    clim=(-300, 0),
)
plt.show()


# TODO: to test, get a real recording, interpolate each recording
# one up, one down a small amount.
# -----------------------------------------------------------------------------
# Over Entire Session
# -----------------------------------------------------------------------------
bin_um = 1

print("starting make hist")

# One temporal bin spanning the whole recording -> a single spatial
# activity histogram for the entire session.
entire_session_hist, temporal_bin_edges, spatial_bin_edges = make_2d_motion_histogram(
    recording,
    peaks,
    peak_locations,
    weight_with_amplitude=False,
    direction="y",
    bin_s=recording.get_duration(segment_index=0),  # 1.0,
    bin_um=bin_um,
    hist_margin_um=50,
    spatial_bin_edges=None,
)

# Drop the singleton temporal axis and scale to a max of 1 for plotting.
entire_session_hist = entire_session_hist[0]
entire_session_hist /= np.max(entire_session_hist)

centers = spatial_bin_edges[:-1] + np.diff(spatial_bin_edges) / 2
plt.plot(centers, entire_session_hist)
plt.show()
# -----------------------------------------------------------------------------
# CHUNKING!
# -----------------------------------------------------------------------------

# Histogram the session in 1 s chunks: one spatial histogram per chunk,
# shape (num_hist, num_bins).
chunk_session_hist, temporal_bin_edges, spatial_bin_edges = make_2d_motion_histogram(
    recording,
    peaks,
    peak_locations,
    weight_with_amplitude=False,
    direction="y",
    bin_s=1,  # one histogram per second of recording
    bin_um=bin_um,
    hist_margin_um=50,
    spatial_bin_edges=None,
)

num_hist = chunk_session_hist.shape[0]  # number of temporal chunks
num_bins = chunk_session_hist.shape[1]  # number of spatial bins

# Mean (over spatial bins) of the per-bin std across chunks.
# BUG FIX: the original summed the num_bins per-bin std values and then
# divided by num_hist; the sum runs over spatial bins, so the correct
# normaliser is num_bins, i.e. a plain mean.
session_std = np.mean(np.std(chunk_session_hist, axis=0))
print("Histogram STD: ", session_std)

# TODO: exclude histograms at the level of entire histogram or bin
# TODO: how to handle this multidimensional std. think more.

# TODO: for now scale by max for interpretability
# chunk_session_hist = chunk_session_hist / np.max(chunk_session_hist, axis=1)[:, np.newaxis]
spatial_centers = (spatial_bin_edges[1:] + spatial_bin_edges[:-1]) / 2  # TODO: own function!
temporal_centers = (temporal_bin_edges[1:] + temporal_bin_edges[:-1]) / 2

# Overlay every chunk histogram on one axis.
for i in range(num_hist):
    plt.plot(spatial_centers, chunk_session_hist[i, :])
plt.show()
# -----------------------------------------------------------------------------
# Mean of Chunks
# -----------------------------------------------------------------------------

# Average histogram across chunks, scaled to a maximum of 1.
mean_hist = chunk_session_hist.mean(axis=0)
mean_hist = mean_hist / mean_hist.max()

# -----------------------------------------------------------------------------
# Median of Chunks
# -----------------------------------------------------------------------------

# Per-bin median across chunks, scaled to a maximum of 1.
median_hist = np.median(chunk_session_hist, axis=0)  # interesting, this is probably dumb
median_hist = median_hist / median_hist.max()
# -----------------------------------------------------------------------------
# Eigenvectors of Chunks
# -----------------------------------------------------------------------------

# Leading eigenvector of the scatter matrix of the chunk histograms,
# used as a "consensus" spatial histogram. chunk_session_hist is
# (num_hist, num_bins), so the scatter matrix is (num_bins, num_bins).
scatter = chunk_session_hist.T @ chunk_session_hist
eigvecs, _singular_values, _vh = np.linalg.svd(scatter)

# TODO: check why this is flipped
# TODO: revise a little + consider another distance metric
first_eigenvalue = eigvecs[:, 0] * -1
first_eigenvalue /= np.max(first_eigenvalue)
# -----------------------------------------------------------------------------
# Poisson Modelling
# -----------------------------------------------------------------------------

def obj_fun(lambda_, m, sum_k):
    """Negative Poisson log-likelihood, up to an additive constant.

    For ``m`` iid Poisson counts whose sum is ``sum_k``, the
    log-likelihood of rate ``lambda_`` is
    ``sum_k * log(lambda_) - m * lambda_`` (the lambda-independent
    factorial term is dropped). The negation is returned so the function
    can be passed to ``scipy.optimize.minimize``.
    """
    return -1.0 * (sum_k * np.log(lambda_) - lambda_ * m)
# Per-spatial-bin Poisson rate estimate: treat the counts for one bin
# across all chunks as iid Poisson and fit the rate by maximum
# likelihood. (Lol — the MLE of a Poisson rate is the sample mean, so
# this is painfully close to mean_hist.)
poisson_estimate = np.zeros(chunk_session_hist.shape[1])
for bin_idx in range(chunk_session_hist.shape[1]):

    ks = chunk_session_hist[:, bin_idx]  # counts for this bin, one per chunk

    # BUG FIX: `ks.shape` is a tuple, not the observation count; it only
    # worked before via accidental tuple->array broadcasting inside
    # obj_fun. Pass the scalar number of observations instead.
    num_obs = ks.shape[0]
    sum_k = np.sum(ks)

    # `.x` is a length-1 array; index it explicitly (assigning the array
    # to a scalar slot is deprecated in NumPy >= 1.25).
    poisson_estimate[bin_idx] = minimize(
        obj_fun, 0.5, (num_obs, sum_k), bounds=((1e-10, np.inf),)
    ).x[0]
poisson_estimate /= np.max(poisson_estimate)
# -----------------------------------------------------------------------------
# Plotting Results
# -----------------------------------------------------------------------------

# Compare all summary histograms (each scaled to a maximum of 1).
# Observe that the entire-session histogram equals the chunk mean.
summaries = {
    "entire": entire_session_hist,
    "chunk mean": mean_hist,
    "chunk median": median_hist,
    "chunk eigenvalue": first_eigenvalue,
    "Poisson estimate": poisson_estimate,
}
for hist in summaries.values():
    plt.plot(hist)
plt.legend(list(summaries))
plt.show()

breakpoint()  # deliberate: drop into the debugger for interactive inspection

# After this try (x, y) alignment
# estimate chunk size based on firing rate
# figure out best histogram size based on x0x0x0