Skip to content

Commit 9da4cbf

Browse files
committed
Continue adding tests.
1 parent 070e33f commit 9da4cbf

File tree

5 files changed

+346
-136
lines changed

5 files changed

+346
-136
lines changed

playing.py

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,73 @@
11
from spikeinterface.generation import generate_drifting_recording
22
from spikeinterface.preprocessing.motion import correct_motion
33
from spikeinterface.sortingcomponents.motion.motion_interpolation import InterpolateMotionRecording
4+
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
5+
from spikeinterface.generation.session_displacement_generator import generate_session_displacement_recordings
6+
from spikeinterface.generation import generate_ground_truth_recording
7+
from spikeinterface.core import get_noise_levels
8+
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
49

5-
rec = generate_drifting_recording(duration=100)[0]
610

7-
proc_rec = correct_motion(rec)
11+
recordings_list, _ = generate_session_displacement_recordings(
12+
num_units=5,
13+
recording_durations=[1, 1],
14+
recording_shifts=((0, 0), (0, 250)),
15+
# TODO: can see how well this is recaptured by comparing the displacements to the known displacement + gradient
16+
non_rigid_gradient=None, # 0.1, # 0.1,
17+
seed=55, # 52
18+
generate_sorting_kwargs=dict(firing_rates=(100, 250), refractory_period_ms=4.0),
19+
generate_unit_locations_kwargs=dict(
20+
margin_um=0.0,
21+
# if this is say 20, then units go off the edge of the probe and are such low amplitude they are not picked up.
22+
minimum_z=0.0,
23+
maximum_z=2.0,
24+
minimum_distance=18.0,
25+
max_iteration=100,
26+
distance_strict=False,
27+
),
28+
generate_noise_kwargs=dict(noise_levels=(0.0, 1.0), spatial_decay=1.0),
29+
)
30+
rec = recordings_list[1]
831

9-
rec.set_probe(rec.get_probe())
32+
detect_kwargs = {
33+
"method": "locally_exclusive",
34+
"peak_sign": "neg",
35+
"detect_threshold": 25,
36+
"exclude_sweep_ms": 0.1,
37+
"radius_um": 75,
38+
"noise_levels": None,
39+
"random_chunk_kwargs": {},
40+
}
41+
localize_peaks_kwargs = {"method": "grid_convolution"}
1042

43+
# noise_levels = get_noise_levels(rec, return_scaled=False)
44+
rec_0 = recordings_list[0]
45+
rec_1 = recordings_list[1]
46+
47+
peaks_before_0 = detect_peaks(rec_0, **detect_kwargs) # noise_levels=noise_levels,
48+
peaks_before_1 = detect_peaks(rec_1, **detect_kwargs)
49+
50+
proc_rec_0, motion_info_0 = correct_motion(rec_0, preset="rigid_fast", detect_kwargs=detect_kwargs, localize_peaks_kwargs=localize_peaks_kwargs, output_motion_info=True)
51+
proc_rec_1, motion_info_1 = correct_motion(rec_1, preset="rigid_fast", detect_kwargs=detect_kwargs, localize_peaks_kwargs=localize_peaks_kwargs, output_motion_info=True)
52+
53+
peaks_after_0 = detect_peaks(proc_rec_0, **detect_kwargs) # noise_levels=noise_levels
54+
peaks_after_1 = detect_peaks(proc_rec_1, **detect_kwargs)
55+
56+
57+
import spikeinterface.full as si
58+
import matplotlib.pyplot as plt
59+
60+
# TODO: need to test multi-shank
61+
plot = si.plot_traces(rec_1, order_channel_by_depth=True) # , time_range=(0, 0.1))
62+
x = peaks_before_1["sample_index"] * (1/ rec_1.get_sampling_frequency())
63+
y = rec_1.get_channel_locations()[peaks_before_1["channel_index"], 1]
64+
plot.ax.scatter(x, y, color="r", s=2)
65+
plt.show()
66+
67+
plot = si.plot_traces(proc_rec_1, order_channel_by_depth=True)
68+
x = peaks_after_1["sample_index"] * (1/ proc_rec_1.get_sampling_frequency())
69+
y = rec_1.get_channel_locations()[peaks_after_1["channel_index"], 1]
70+
plot.ax.scatter(x, y, color="r", s=2)
71+
plt.show()
72+
73+
breakpoint()

src/spikeinterface/preprocessing/inter_session_alignment/alignment_utils.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# #############################################################################
1717

1818

19-
def get_activity_histogram(
19+
def get_2d_activity_histogram(
2020
recording: BaseRecording,
2121
peaks: np.ndarray,
2222
peak_locations: np.ndarray,
@@ -74,9 +74,6 @@ def get_activity_histogram(
7474
)
7575
assert np.array_equal(generated_spatial_bin_edges, spatial_bin_edges), "TODO: remove soon after testing"
7676

77-
temporal_bin_centers = get_bin_centers(temporal_bin_edges)
78-
spatial_bin_centers = get_bin_centers(spatial_bin_edges)
79-
8077
if scale_to_hz:
8178
if bin_s is None:
8279
scaler = 1 / recording.get_duration()
@@ -88,8 +85,12 @@ def get_activity_histogram(
8885
if log_scale:
8986
activity_histogram = np.log10(1 + activity_histogram) # TODO: make_2d_motion_histogram uses log2
9087

88+
temporal_bin_centers = get_bin_centers(temporal_bin_edges)
89+
spatial_bin_centers = get_bin_centers(spatial_bin_edges)
90+
9191
return activity_histogram, temporal_bin_centers, spatial_bin_centers
9292

93+
9394
def get_bin_centers(bin_edges):
9495
return (bin_edges[1:] + bin_edges[:-1]) / 2
9596

@@ -152,9 +153,6 @@ def get_chunked_hist_median(chunked_session_histograms):
152153
""" """
153154
median_hist = np.median(chunked_session_histograms, axis=0)
154155

155-
quartile_1 = np.percentile(chunked_session_histograms, 25, axis=0)
156-
quartile_3 = np.percentile(chunked_session_histograms, 75, axis=0)
157-
158156
return median_hist
159157

160158

@@ -311,15 +309,6 @@ def compute_histogram_crosscorrelation(
311309
windowed_histogram_j - np.mean(windowed_histogram_i),
312310
mode="full",
313311
)
314-
import os
315-
if "hello_world" in os.environ:
316-
plt.plot(windowed_histogram_i)
317-
plt.plot(windowed_histogram_j)
318-
plt.show()
319-
320-
plt.plot(xcorr)
321-
plt.show()
322-
323312
if num_shifts:
324313
window_indices = np.arange(center_bin - num_shifts, center_bin + num_shifts)
325314
xcorr = xcorr[window_indices]
@@ -436,7 +425,9 @@ def akima_interpolate_nonrigid_shifts(
436425
interpolated from the non-rigid shifts.
437426
438427
"""
439-
if Version(scipy.__version__) >= Version("1.4.0"):
428+
import scipy
429+
430+
if Version(scipy.__version__) < Version("1.14.0"):
440431
raise ImportError("Scipy version 14 or higher is required fro Akima interpolation.")
441432

442433
from scipy.interpolate import Akima1DInterpolator

src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py

Lines changed: 58 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ def get_compute_alignment_kwargs() -> dict:
7171
windows along the probe depth. See `get_spatial_windows`.
7272
"""
7373
return {
74+
"num_shifts_global": None,
75+
"num_shifts_block": 20,
7476
"interpolate": False,
7577
"interp_factor": 10,
7678
"kriging_sigma": 1,
@@ -93,8 +95,6 @@ def get_non_rigid_window_kwargs():
9395
"""
9496
return {
9597
"rigid": True,
96-
"num_shifts_global": None,
97-
"num_shifts_block": 20,
9898
"win_shape": "gaussian",
9999
"win_step_um": 50,
100100
"win_scale_um": 50,
@@ -109,12 +109,13 @@ def get_interpolate_motion_kwargs():
109109
see that class for parameter descriptions.
110110
"""
111111
return {
112-
"border_mode": "force_zeros", # fixed as this until can figure out probe
112+
"border_mode": "force_zeros", # fixed as this until can figure out probe
113113
"spatial_interpolation_method": "kriging",
114114
"sigma_um": 20.0,
115-
"p": 2
115+
"p": 2,
116116
}
117117

118+
118119
###############################################################################
119120
# Public Entry Level Functions
120121
###############################################################################
@@ -222,7 +223,12 @@ def align_sessions(
222223

223224
# Ensure list lengths match and all channel locations are the same across recordings.
224225
_check_align_sessions_inputs(
225-
recordings_list, peaks_list, peak_locations_list, alignment_order, estimate_histogram_kwargs, interpolate_motion_kwargs
226+
recordings_list,
227+
peaks_list,
228+
peak_locations_list,
229+
alignment_order,
230+
estimate_histogram_kwargs,
231+
interpolate_motion_kwargs,
226232
)
227233

228234
print("Computing a single activity histogram from each session...")
@@ -311,7 +317,10 @@ def align_sessions_after_motion_correction(
311317
)
312318

313319
motion_window_kwargs = copy.deepcopy(motion_kwargs_list[0])
314-
if motion_window_kwargs["direction"] != "y":
320+
321+
if (
322+
"direction" in motion_window_kwargs and motion_window_kwargs["direction"] != "y"
323+
): # TODO: why is this not in all?
315324
raise ValueError("motion correct must have been performed along the 'y' dimension.")
316325

317326
if align_sessions_kwargs is None:
@@ -322,24 +331,37 @@ def align_sessions_after_motion_correction(
322331
# shifts together.
323332
if (
324333
"non_rigid_window_kwargs" in align_sessions_kwargs
325-
and "nonrigid" in align_sessions_kwargs["non_rigid_window_kwargs"]["rigid_mode"]
334+
and not align_sessions_kwargs["non_rigid_window_kwargs"]["rigid"]
326335
):
327-
336+
# TODO: carefully walk through this function! and test all assumptions...
328337
if not motion_window_kwargs["rigid"]:
329-
print(
338+
print( # TODO: make a warning
330339
"Nonrigid inter-session alignment must use the motion correct "
331340
"nonrigid settings. Overwriting any passed `non_rigid_window_kwargs` "
332341
"with the motion object non_rigid_window_kwargs."
333342
)
334-
motion_window_kwargs.pop("method")
335-
motion_window_kwargs.pop("direction")
343+
non_rigid_window_kwargs = get_non_rigid_window_kwargs()
344+
345+
# TODO: generate function for replacing one dict into another?
346+
for (
347+
k,
348+
v,
349+
) in motion_window_kwargs.items(): # TODO: can get tighter alignment here with original implementation?
350+
if k in non_rigid_window_kwargs:
351+
non_rigid_window_kwargs[k] = v
352+
336353
align_sessions_kwargs = copy.deepcopy(align_sessions_kwargs)
337-
align_sessions_kwargs["non_rigid_window_kwargs"] = motion_window_kwargs
354+
align_sessions_kwargs["non_rigid_window_kwargs"] = non_rigid_window_kwargs
355+
356+
corrected_peak_locations = [
357+
correct_motion_on_peaks(info["peaks"], info["peak_locations"], info["motion"], recording)
358+
for info, recording in zip(motion_info_list, recordings_list)
359+
]
338360

339361
return align_sessions(
340362
recordings_list,
341363
[info["peaks"] for info in motion_info_list],
342-
[info["peak_locations"] for info in motion_info_list],
364+
corrected_peak_locations,
343365
**align_sessions_kwargs,
344366
)
345367

@@ -459,14 +481,14 @@ def _compute_session_histograms(
459481
recording,
460482
peaks,
461483
peak_locations,
462-
histogram_type,
463-
spatial_bin_edges,
464-
method,
465-
log_scale,
466-
chunked_bin_size_s,
467-
depth_smooth_um,
468-
weight_with_amplitude,
469-
avg_in_bin,
484+
histogram_type=histogram_type,
485+
spatial_bin_edges=spatial_bin_edges,
486+
method=method,
487+
log_scale=log_scale,
488+
chunked_bin_size_s=chunked_bin_size_s,
489+
depth_smooth_um=depth_smooth_um,
490+
weight_with_amplitude=weight_with_amplitude,
491+
avg_in_bin=avg_in_bin,
470492
)
471493
temporal_bin_centers_list.append(temporal_bin_centers)
472494
session_histogram_list.append(session_hist)
@@ -539,32 +561,31 @@ def _get_single_session_activity_histogram(
539561
# full estimation for chunked bin size
540562
if chunked_bin_size_s == "estimate":
541563

542-
one_bin_histogram, _, _ = alignment_utils.get_activity_histogram(
564+
scaled_hist, _, _ = alignment_utils.get_2d_activity_histogram(
543565
recording,
544566
peaks,
545567
peak_locations,
546568
spatial_bin_edges,
547569
log_scale=False,
548570
bin_s=None,
549571
depth_smooth_um=None,
550-
scale_to_hz=False,
572+
scale_to_hz=True,
551573
weight_with_amplitude=False,
552574
avg_in_bin=False,
553575
)
554576

555577
# It is important that the passed histogram is scaled to firing rate in Hz
556-
scaled_hist = one_bin_histogram / recording.get_duration() # TODO: why is this done here when have a scale_to_hz arg??!?
557578
chunked_bin_size_s = alignment_utils.estimate_chunk_size(scaled_hist)
558579
chunked_bin_size_s = np.min([chunked_bin_size_s, recording.get_duration()])
559580

560581
if histogram_type == "activity_1d":
561582

562-
chunked_histograms, chunked_temporal_bin_centers, _ = alignment_utils.get_activity_histogram(
583+
chunked_histograms, chunked_temporal_bin_centers, _ = alignment_utils.get_2d_activity_histogram(
563584
recording,
564585
peaks,
565586
peak_locations,
566587
spatial_bin_edges,
567-
log_scale,
588+
log_scale=log_scale,
568589
bin_s=chunked_bin_size_s,
569590
depth_smooth_um=depth_smooth_um,
570591
weight_with_amplitude=weight_with_amplitude,
@@ -656,9 +677,11 @@ def _create_motion_recordings(
656677
motion,
657678
interpolation_time_bin_centers_s=motion.temporal_bins_s,
658679
interpolation_time_bin_edges_s=[np.array(recording.get_times()[0], recording.get_times()[-1])],
659-
**interpolate_motion_kwargs
680+
**interpolate_motion_kwargs,
660681
)
661-
corrected_recording = corrected_recording.set_probe(recording.get_probe()) # TODO: if this works, might need to do above
682+
corrected_recording = corrected_recording.set_probe(
683+
recording.get_probe()
684+
) # TODO: if this works, might need to do above
662685

663686
corrected_recordings_list.append(corrected_recording)
664687

@@ -840,8 +863,8 @@ def _compute_session_alignment(
840863
session_histogram_array = np.array(session_histogram_list)
841864

842865
akima_interp_nonrigid = compute_alignment_kwargs.pop("akima_interp_nonrigid")
843-
num_shifts_global = non_rigid_window_kwargs.pop("num_shifts_global")
844-
num_shifts_block = non_rigid_window_kwargs.pop("num_shifts_block")
866+
num_shifts_global = compute_alignment_kwargs.pop("num_shifts_global")
867+
num_shifts_block = compute_alignment_kwargs.pop("num_shifts_block")
845868

846869
non_rigid_windows, non_rigid_window_centers = get_spatial_windows(
847870
contact_depths, spatial_bin_centers, **non_rigid_window_kwargs
@@ -870,7 +893,7 @@ def _compute_session_alignment(
870893

871894
# Then compute the nonrigid shifts
872895
nonrigid_session_offsets_matrix = alignment_utils.compute_histogram_crosscorrelation(
873-
shifted_histograms, non_rigid_windows, num_shifts_block, **compute_alignment_kwargs
896+
shifted_histograms, non_rigid_windows, num_shifts=num_shifts_block, **compute_alignment_kwargs
874897
)
875898
non_rigid_shifts = alignment_utils.get_shifts_from_session_matrix(alignment_order, nonrigid_session_offsets_matrix)
876899

@@ -920,7 +943,7 @@ def _estimate_rigid_alignment(
920943
rigid_session_offsets_matrix = alignment_utils.compute_histogram_crosscorrelation(
921944
session_histogram_array,
922945
rigid_window,
923-
num_shifts,
946+
num_shifts=num_shifts,
924947
**compute_alignment_kwargs, # TODO: remove the copy above and pass directly. Consider removing this function...
925948
)
926949
optimal_shift_indices = alignment_utils.get_shifts_from_session_matrix(
@@ -968,7 +991,6 @@ def _check_align_sessions_inputs(
968991
"performed using the same probe."
969992
)
970993

971-
972994
accepted_hist_methods = [
973995
"entire_session",
974996
"chunked_mean",
@@ -998,4 +1020,6 @@ def _check_align_sessions_inputs(
9981020
if ses_num == 0:
9991021
raise ValueError("`alignment_order` required the session number, not session index.")
10001022

1001-
assert interpolate_motion_kwargs["border_mode"] == "force_zeros", "InterpolateMotionRecording must be `force_zeros` until probe is figured out." # TODO: ask sam
1023+
assert (
1024+
interpolate_motion_kwargs["border_mode"] == "force_zeros"
1025+
), "InterpolateMotionRecording must be `force_zeros` until probe is figured out." # TODO: ask sam

src/spikeinterface/preprocessing/motion.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,10 @@ def run_peak_detection_pipeline_node(recording, gather_mode, detect_kwargs, loca
446446
from spikeinterface.core.node_pipeline import ExtractDenseWaveforms, run_node_pipeline
447447
from spikeinterface.sortingcomponents.peak_localization import localize_peak_methods
448448

449+
# Don't modify the kwargs in place in case the caller requires them
450+
detect_kwargs = copy.deepcopy(detect_kwargs)
451+
localize_peaks_kwargs = copy.deepcopy(localize_peaks_kwargs)
452+
449453
# node detect
450454
method = detect_kwargs.pop("method", "locally_exclusive")
451455
method_class = detect_peak_methods[method]

0 commit comments

Comments
 (0)