Skip to content

Commit 1d3d8db

Browse files
committed
Continue playing with estimation.
1 parent e0ebddf commit 1d3d8db

File tree

1 file changed

+74
-54
lines changed

1 file changed

+74
-54
lines changed

debugging/estimate_session_histogram.py

Lines changed: 74 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -17,59 +17,75 @@
1717
# if specifying 5 units, 5 unit peaks are clearly visible, none are lost
1818
# because their position is too far from probe.
1919

20-
default_unit_params_range = dict(
21-
alpha=(100.0, 500.0),
22-
depolarization_ms=(0.09, 0.14),
23-
repolarization_ms=(0.5, 0.8),
24-
recovery_ms=(1.0, 1.5),
25-
positive_amplitude=(0.1, 0.25),
26-
smooth_ms=(0.03, 0.07),
27-
spatial_decay=(20, 40),
28-
propagation_speed=(250.0, 350.0),
29-
b=(0.1, 1),
30-
c=(0.1, 1),
31-
x_angle=(0, np.pi),
32-
y_angle=(0, np.pi),
33-
z_angle=(0, np.pi),
20+
if False:
21+
default_unit_params_range = dict(
22+
alpha=(100.0, 500.0),
23+
depolarization_ms=(0.09, 0.14),
24+
repolarization_ms=(0.5, 0.8),
25+
recovery_ms=(1.0, 1.5),
26+
positive_amplitude=(0.1, 0.25),
27+
smooth_ms=(0.03, 0.07),
28+
spatial_decay=(20, 40),
29+
propagation_speed=(250.0, 350.0),
30+
b=(0.1, 1),
31+
c=(0.1, 1),
32+
x_angle=(0, np.pi),
33+
y_angle=(0, np.pi),
34+
z_angle=(0, np.pi),
35+
)
36+
37+
default_unit_params_range["alpha"] = (100, 600) # do this or change the margin on generate_unit_locations_kwargs
38+
default_unit_params_range["b"] = (0.5, 1) # and make the units fatter, easier to receive signal!
39+
default_unit_params_range["c"] = (0.5, 1)
40+
41+
rec_list, _ = generate_session_displacement_recordings(
42+
non_rigid_gradient=None, # 0.05, TODO: note this will set nonlinearity to both x and y (the same)
43+
num_units=100,
44+
recording_durations=(50,), # TODO: checks on inputs
45+
recording_shifts=(
46+
(0, 0),
47+
),
48+
recording_amplitude_scalings=None,
49+
# generate_sorting_kwargs=dict(firing_rates=(50, 100), refractory_period_ms=4.0),
50+
# generate_templates_kwargs=dict(unit_params=default_unit_params_range, ms_before=1.5, ms_after=3),
51+
seed=None,
52+
# generate_unit_locations_kwargs=dict(
53+
# margin_um=0.0, # if this is say 20, then units go off the edge of the probe and are such low amplitude they are not picked up.
54+
# minimum_z=5.0,
55+
# maximum_z=45.0,
56+
# minimum_distance=18.0,
57+
# max_iteration=100,
58+
# distance_strict=False,
59+
# ),
60+
)
61+
62+
_, drift_rec, _ = si.generate_drifting_recording(duration=250)
63+
64+
corrected_recording, motion_info = si.correct_motion(drift_rec, preset="kilosort_like", output_motion_info=True)
65+
66+
from spikeinterface.sortingcomponents.motion import correct_motion_on_peaks
67+
68+
peaks = motion_info["peaks"]
69+
peak_locations = correct_motion_on_peaks(
70+
peaks,
71+
motion_info["peak_locations"],
72+
motion_info["motion"],
73+
corrected_recording,
3474
)
3575

36-
default_unit_params_range["alpha"] = (100, 600) # do this or change the margin on generate_unit_locations_kwargs
37-
default_unit_params_range["b"] = (0.5, 1) # and make the units fatter, easier to receive signal!
38-
default_unit_params_range["c"] = (0.5, 1)
39-
40-
rec_list, _ = generate_session_displacement_recordings(
41-
non_rigid_gradient=None, # 0.05, TODO: note this will set nonlinearity to both x and y (the same)
42-
num_units=100,
43-
recording_durations=(50,), # TODO: checks on inputs
44-
recording_shifts=(
45-
(0, 0),
46-
),
47-
recording_amplitude_scalings=None,
48-
# generate_sorting_kwargs=dict(firing_rates=(50, 100), refractory_period_ms=4.0),
49-
# generate_templates_kwargs=dict(unit_params=default_unit_params_range, ms_before=1.5, ms_after=3),
50-
seed=None,
51-
# generate_unit_locations_kwargs=dict(
52-
# margin_um=0.0, # if this is say 20, then units go off the edge of the probe and are such low amplitude they are not picked up.
53-
# minimum_z=5.0,
54-
# maximum_z=45.0,
55-
# minimum_distance=18.0,
56-
# max_iteration=100,
57-
# distance_strict=False,
58-
# ),
59-
)
76+
if False:
77+
recording = rec_list[0]
6078

61-
recording = rec_list[0]
79+
peaks = detect_peaks(recording, method="locally_exclusive")
80+
peak_locations = localize_peaks(recording, peaks, method="grid_convolution")
6281

63-
peaks = detect_peaks(recording, method="locally_exclusive")
64-
peak_locations = localize_peaks(recording, peaks, method="grid_convolution")
65-
66-
si.plot_drift_raster_map(
67-
peaks=peaks,
68-
peak_locations=peak_locations,
69-
recording=recording,
70-
clim=(-300, 0) # fix clim for comparability across plots
71-
)
72-
plt.show()
82+
si.plot_drift_raster_map(
83+
peaks=peaks,
84+
peak_locations=peak_locations,
85+
recording=recording,
86+
clim=(-300, 0) # fix clim for comparability across plots
87+
)
88+
plt.show()
7389

7490

7591
# TODO: to test, get a real recording, interpolate each recording
@@ -78,16 +94,19 @@
7894
# -----------------------------------------------------------------------------
7995
# Over Entire Session
8096
# -----------------------------------------------------------------------------
81-
bin_um = 1
97+
98+
# TODO: figure out dynamic bin sizing based on 1-um histogram.
99+
#
100+
bin_um = 25 # TODO: maybe do some testing on benchmarks backed by some theory.
82101

83102
print("starting make hist")
84103
entire_session_hist, temporal_bin_edges, spatial_bin_edges = make_2d_motion_histogram(
85-
recording,
104+
corrected_recording,
86105
peaks,
87106
peak_locations,
88107
weight_with_amplitude=False,
89108
direction="y",
90-
bin_s=recording.get_duration(segment_index=0), # 1.0,
109+
bin_s=corrected_recording.get_duration(segment_index=0), # 1.0,
91110
bin_um=bin_um,
92111
hist_margin_um=50,
93112
spatial_bin_edges=None,
@@ -103,12 +122,12 @@
103122
# -----------------------------------------------------------------------------
104123

105124
chunk_session_hist, temporal_bin_edges, spatial_bin_edges = make_2d_motion_histogram(
106-
recording,
125+
corrected_recording,
107126
peaks,
108127
peak_locations,
109128
weight_with_amplitude=False,
110129
direction="y",
111-
bin_s=1, # Now make 25 histograms
130+
bin_s=5, # Now make 25 histograms
112131
bin_um=bin_um,
113132
hist_margin_um=50,
114133
spatial_bin_edges=None,
@@ -162,6 +181,7 @@
162181
# -----------------------------------------------------------------------------
163182
# Poisson Modelling
164183
# -----------------------------------------------------------------------------
184+
# Under assumption of independent bins and time points
165185

166186
def obj_fun(lambda_, m, sum_k ):
167187
return -(sum_k * np.log(lambda_) - m * lambda_)

0 commit comments

Comments
 (0)