Skip to content

Commit e0ebddf

Browse files
committed
Start playing around with histogram estimation.
1 parent 502c152 commit e0ebddf

File tree

1 file changed

+197
-0
lines changed

1 file changed

+197
-0
lines changed
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
"""
2+
TODO: some notes on this debugging script.
3+
"""
4+
import spikeinterface.full as si
5+
from spikeinterface.generation.session_displacement_generator import generate_session_displacement_recordings
6+
import matplotlib.pyplot as plt
7+
import numpy as np
8+
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
9+
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
10+
from spikeinterface.sortingcomponents.motion.motion_utils import \
11+
make_2d_motion_histogram, make_3d_motion_histograms
12+
from scipy.optimize import minimize
13+
14+
# Generate a ground truth recording where every unit is firing a lot,
15+
# with high amplitude, and is close to the spike, so all picked up.
16+
# This just makes it easy to play around with the units, e.g.
17+
# if specifying 5 units, 5 unit peaks are clearly visible, none are lost
18+
# because their position is too far from probe.
19+
20+
default_unit_params_range = dict(
21+
alpha=(100.0, 500.0),
22+
depolarization_ms=(0.09, 0.14),
23+
repolarization_ms=(0.5, 0.8),
24+
recovery_ms=(1.0, 1.5),
25+
positive_amplitude=(0.1, 0.25),
26+
smooth_ms=(0.03, 0.07),
27+
spatial_decay=(20, 40),
28+
propagation_speed=(250.0, 350.0),
29+
b=(0.1, 1),
30+
c=(0.1, 1),
31+
x_angle=(0, np.pi),
32+
y_angle=(0, np.pi),
33+
z_angle=(0, np.pi),
34+
)
35+
36+
default_unit_params_range["alpha"] = (100, 600) # do this or change the margin on generate_unit_locations_kwargs
37+
default_unit_params_range["b"] = (0.5, 1) # and make the units fatter, easier to receive signal!
38+
default_unit_params_range["c"] = (0.5, 1)
39+
40+
rec_list, _ = generate_session_displacement_recordings(
41+
non_rigid_gradient=None, # 0.05, TODO: note this will set nonlinearity to both x and y (the same)
42+
num_units=100,
43+
recording_durations=(50,), # TODO: checks on inputs
44+
recording_shifts=(
45+
(0, 0),
46+
),
47+
recording_amplitude_scalings=None,
48+
# generate_sorting_kwargs=dict(firing_rates=(50, 100), refractory_period_ms=4.0),
49+
# generate_templates_kwargs=dict(unit_params=default_unit_params_range, ms_before=1.5, ms_after=3),
50+
seed=None,
51+
# generate_unit_locations_kwargs=dict(
52+
# margin_um=0.0, # if this is say 20, then units go off the edge of the probe and are such low amplitude they are not picked up.
53+
# minimum_z=5.0,
54+
# maximum_z=45.0,
55+
# minimum_distance=18.0,
56+
# max_iteration=100,
57+
# distance_strict=False,
58+
# ),
59+
)
60+
61+
recording = rec_list[0]
62+
63+
peaks = detect_peaks(recording, method="locally_exclusive")
64+
peak_locations = localize_peaks(recording, peaks, method="grid_convolution")
65+
66+
si.plot_drift_raster_map(
67+
peaks=peaks,
68+
peak_locations=peak_locations,
69+
recording=recording,
70+
clim=(-300, 0) # fix clim for comparability across plots
71+
)
72+
plt.show()
73+
74+
75+
# TODO: to test, get a real recording, interpolate each recording
76+
# one up, one down a small amount.
77+
78+
# -----------------------------------------------------------------------------
79+
# Over Entire Session
80+
# -----------------------------------------------------------------------------
81+
bin_um = 1
82+
83+
print("starting make hist")
84+
entire_session_hist, temporal_bin_edges, spatial_bin_edges = make_2d_motion_histogram(
85+
recording,
86+
peaks,
87+
peak_locations,
88+
weight_with_amplitude=False,
89+
direction="y",
90+
bin_s=recording.get_duration(segment_index=0), # 1.0,
91+
bin_um=bin_um,
92+
hist_margin_um=50,
93+
spatial_bin_edges=None,
94+
)
95+
entire_session_hist = entire_session_hist[0]
96+
entire_session_hist /= np.max(entire_session_hist)
97+
centers = (spatial_bin_edges[1:] + spatial_bin_edges[:-1]) / 2
98+
plt.plot(centers, entire_session_hist)
99+
plt.show()
100+
101+
# -----------------------------------------------------------------------------
102+
# CHUNKING!
103+
# -----------------------------------------------------------------------------
104+
105+
chunk_session_hist, temporal_bin_edges, spatial_bin_edges = make_2d_motion_histogram(
106+
recording,
107+
peaks,
108+
peak_locations,
109+
weight_with_amplitude=False,
110+
direction="y",
111+
bin_s=1, # Now make 25 histograms
112+
bin_um=bin_um,
113+
hist_margin_um=50,
114+
spatial_bin_edges=None,
115+
)
116+
117+
m = chunk_session_hist.shape[0] # TODO: n_hist
118+
n = chunk_session_hist.shape[1] # TOOD: n_bin
119+
120+
session_std = np.sum(np.std(chunk_session_hist, axis=0)) / m
121+
print("Histogram STD:: ", session_std)
122+
123+
# TODO: exclude histograms at the level of entire histogram or bin
124+
# TODO: how to handle this multidimensional std. think more.
125+
126+
# TODO: for now scale by max for interpretability
127+
# chunk_session_hist = chunk_session_hist / np.max(chunk_session_hist, axis=1)[:, np.newaxis]
128+
spatial_centers = (spatial_bin_edges[1:] + spatial_bin_edges[:-1]) / 2 # TODO: own function!
129+
temporal_centers = (temporal_bin_edges[1:] + temporal_bin_edges[:-1]) / 2
130+
131+
for i in range(chunk_session_hist.shape[0]):
132+
plt.plot(spatial_centers, chunk_session_hist[i, :])
133+
plt.show()
134+
135+
# -----------------------------------------------------------------------------
136+
# Mean of Chunks
137+
# -----------------------------------------------------------------------------
138+
139+
mean_hist = np.mean(chunk_session_hist, axis=0)
140+
mean_hist /= np.max(mean_hist)
141+
142+
# -----------------------------------------------------------------------------
143+
# Median of Chunks
144+
# -----------------------------------------------------------------------------
145+
146+
median_hist = np.median(chunk_session_hist, axis=0) # interesting, this is probably dumb
147+
median_hist /= np.max(median_hist)
148+
149+
# -----------------------------------------------------------------------------
150+
# Eigenvectors of Chunks
151+
# -----------------------------------------------------------------------------
152+
153+
A = chunk_session_hist
154+
S = A.T @ A # (num hist, num_bins)
155+
156+
U,S, Vh = np.linalg.svd(S)
157+
158+
# TODO: check why this is flipped
159+
first_eigenvalue = U[:, 0] * -1 # TODO: revise a little + consider another distance metric
160+
first_eigenvalue /= np.max(first_eigenvalue)
161+
162+
# -----------------------------------------------------------------------------
163+
# Poisson Modelling
164+
# -----------------------------------------------------------------------------
165+
166+
def obj_fun(lambda_, m, sum_k ):
167+
return -(sum_k * np.log(lambda_) - m * lambda_)
168+
169+
poisson_estimate = np.zeros(chunk_session_hist.shape[1]) # TODO: var names
170+
for i in range(chunk_session_hist.shape[1]):
171+
172+
ks = chunk_session_hist[:, i]
173+
174+
m = ks.shape
175+
sum_k = np.sum(ks)
176+
177+
# lol, this is painfully close to the mean...
178+
poisson_estimate[i] = minimize(obj_fun, 0.5, (m, sum_k), bounds=((1e-10, np.inf),)).x
179+
poisson_estimate /= np.max(poisson_estimate)
180+
181+
# -----------------------------------------------------------------------------
182+
# Plotting Results
183+
# -----------------------------------------------------------------------------
184+
185+
plt.plot(entire_session_hist) # obs this is equal to mean hist
186+
plt.plot(mean_hist)
187+
plt.plot(median_hist)
188+
plt.plot(first_eigenvalue)
189+
plt.plot(poisson_estimate)
190+
plt.legend(["entire", "chunk mean", "chunk median", "chunk eigenvalue", "Poisson estimate"])
191+
plt.show()
192+
193+
breakpoint()
194+
195+
# After this try (x, y) alignment
196+
# estimate chunk size based on firing rate
197+
# figure out best histogram size based on x0x0x0

0 commit comments

Comments
 (0)