1
1
from __future__ import annotations
2
2
3
+ import copy
4
+
3
5
import numpy as np
4
6
import json
5
7
from pathlib import Path
@@ -45,6 +47,7 @@ def correct_inter_session_displacement(
45
47
from spikeinterface .sortingcomponents .motion_estimation import estimate_motion
46
48
from spikeinterface .sortingcomponents .motion_interpolation import InterpolateMotionRecording
47
49
from spikeinterface .core .node_pipeline import ExtractDenseWaveforms , run_node_pipeline
50
+ from spikeinterface .sortingcomponents .motion_utils import Motion
48
51
49
52
# TODO: do not accept multi-segment recordings.
50
53
# TODO: check all recordings have the same probe dimensions!
@@ -128,6 +131,7 @@ def correct_inter_session_displacement(
128
131
spatial_bin_edges = None ,
129
132
)
130
133
else :
134
+ assert NotImplementedError
131
135
motion_histogram = make_3d_motion_histograms (
132
136
recording ,
133
137
peaks ,
@@ -141,6 +145,9 @@ def correct_inter_session_displacement(
141
145
spatial_bin_edges = None ,
142
146
)
143
147
motion_histogram_list .append (motion_histogram [0 ].squeeze ())
148
+ # store bin edges
149
+ temporal_bin_edges = motion_histogram [1 ]
150
+ spatial_bin_edges = motion_histogram [2 ]
144
151
145
152
# Do some checks on temporal and spatial bin edges that they are all the same?
146
153
# TODO: do some smoothing? Try some other methds (e.g. NMI, KL divergence)
@@ -180,8 +187,42 @@ def correct_inter_session_displacement(
180
187
# TODO: do multi-session optimisation
181
188
182
189
# Handle drift
190
+ interpolate_motion_kwargs = {}
183
191
184
192
# TODO: add motion to motion if exists otherwise create InterpolateMotionRecording object!
185
193
# Will need the y-axis bins for this
186
- motion = Motion ([motion_array ], [temporal_bins ], non_rigid_window_centers , direction = direction )
187
- recording_corrected = InterpolateMotionRecording (recording , motion , ** interpolate_motion_kwargs )
194
+ all_recording_corrected = []
195
+ all_motion_info = []
196
+ for i , recording in enumerate (recordings_list ):
197
+
198
+ # TODO: direct copy, use 'get_window' from motion machinery
199
+ bin_centers = spatial_bin_edges [:- 1 ] + bin_um / 2.0
200
+ n = bin_centers .size
201
+ non_rigid_windows = [np .ones (n , dtype = "float64" )]
202
+ middle = (spatial_bin_edges [0 ] + spatial_bin_edges [- 1 ]) / 2.0
203
+ non_rigid_window_centers = np .array ([middle ])
204
+
205
+ motion_array = shifts [i ] # TODO: this is the rigid case!
206
+ temporal_bins = 0.5 * (temporal_bin_edges [1 :] + temporal_bin_edges [:- 1 ])
207
+ motion = Motion (
208
+ [np .atleast_2d (motion_array )], [temporal_bins ], non_rigid_window_centers , direction = "y"
209
+ ) # will be same for all except for shifts
210
+ all_motion_info .append (motion ) # not certain on this
211
+
212
+ if isinstance (recording , InterpolateMotionRecording ):
213
+ raise NotImplementedError
214
+ recording_corrected = copy .deepcopy (recording )
215
+ # TODO: add interpolation to the existing one.
216
+ # Not if inter-session motion correction already exists, but further
217
+ # up the preprocessing chain, it will NOT be added and interpolation
218
+ # will occur twice. Throw a warning here!
219
+ else :
220
+ recording_corrected = InterpolateMotionRecording (recording , motion , ** interpolate_motion_kwargs )
221
+ all_recording_corrected .append (recording_corrected )
222
+
223
+ displacement_info = {
224
+ "all_motion_info" : all_motion_info ,
225
+ "all_motion_histograms" : motion_histogram_list , # TODO: naming
226
+ "all_shifts" : shifts ,
227
+ }
228
+ return all_recording_corrected , displacement_info # TODO: output more stuff later e.g. the Motion object
0 commit comments