Skip to content

Commit 58a6962

Browse files
committed
Play around with slope drift for generate_drifting_recording.
1 parent 9509da0 commit 58a6962

File tree

1 file changed

+43
-2
lines changed

1 file changed

+43
-2
lines changed

src/spikeinterface/preprocessing/inter_session_displacement.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import copy
4+
35
import numpy as np
46
import json
57
from pathlib import Path
@@ -45,6 +47,7 @@ def correct_inter_session_displacement(
4547
from spikeinterface.sortingcomponents.motion_estimation import estimate_motion
4648
from spikeinterface.sortingcomponents.motion_interpolation import InterpolateMotionRecording
4749
from spikeinterface.core.node_pipeline import ExtractDenseWaveforms, run_node_pipeline
50+
from spikeinterface.sortingcomponents.motion_utils import Motion
4851

4952
# TODO: do not accept multi-segment recordings.
5053
# TODO: check all recordings have the same probe dimensions!
@@ -128,6 +131,7 @@ def correct_inter_session_displacement(
128131
spatial_bin_edges=None,
129132
)
130133
else:
134+
assert NotImplementedError
131135
motion_histogram = make_3d_motion_histograms(
132136
recording,
133137
peaks,
@@ -141,6 +145,9 @@ def correct_inter_session_displacement(
141145
spatial_bin_edges=None,
142146
)
143147
motion_histogram_list.append(motion_histogram[0].squeeze())
148+
# store bin edges
149+
temporal_bin_edges = motion_histogram[1]
150+
spatial_bin_edges = motion_histogram[2]
144151

145152
# Do some checks on temporal and spatial bin edges that they are all the same?
146153
# TODO: do some smoothing? Try some other methds (e.g. NMI, KL divergence)
@@ -180,8 +187,42 @@ def correct_inter_session_displacement(
180187
# TODO: do multi-session optimisation
181188

182189
# Handle drift
190+
interpolate_motion_kwargs = {}
183191

184192
# TODO: add motion to motion if exists otherwise create InterpolateMotionRecording object!
185193
# Will need the y-axis bins for this
186-
motion = Motion([motion_array], [temporal_bins], non_rigid_window_centers, direction=direction)
187-
recording_corrected = InterpolateMotionRecording(recording, motion, **interpolate_motion_kwargs)
194+
all_recording_corrected = []
195+
all_motion_info = []
196+
for i, recording in enumerate(recordings_list):
197+
198+
# TODO: direct copy, use 'get_window' from motion machinery
199+
bin_centers = spatial_bin_edges[:-1] + bin_um / 2.0
200+
n = bin_centers.size
201+
non_rigid_windows = [np.ones(n, dtype="float64")]
202+
middle = (spatial_bin_edges[0] + spatial_bin_edges[-1]) / 2.0
203+
non_rigid_window_centers = np.array([middle])
204+
205+
motion_array = shifts[i] # TODO: this is the rigid case!
206+
temporal_bins = 0.5 * (temporal_bin_edges[1:] + temporal_bin_edges[:-1])
207+
motion = Motion(
208+
[np.atleast_2d(motion_array)], [temporal_bins], non_rigid_window_centers, direction="y"
209+
) # will be same for all except for shifts
210+
all_motion_info.append(motion) # not certain on this
211+
212+
if isinstance(recording, InterpolateMotionRecording):
213+
raise NotImplementedError
214+
recording_corrected = copy.deepcopy(recording)
215+
# TODO: add interpolation to the existing one.
216+
# Not if inter-session motion correction already exists, but further
217+
# up the preprocessing chain, it will NOT be added and interpolation
218+
# will occur twice. Throw a warning here!
219+
else:
220+
recording_corrected = InterpolateMotionRecording(recording, motion, **interpolate_motion_kwargs)
221+
all_recording_corrected.append(recording_corrected)
222+
223+
displacement_info = {
224+
"all_motion_info": all_motion_info,
225+
"all_motion_histograms": motion_histogram_list, # TODO: naming
226+
"all_shifts": shifts,
227+
}
228+
return all_recording_corrected, displacement_info # TODO: output more stuff later e.g. the Motion object

0 commit comments

Comments
 (0)