Skip to content

Commit 502c152

Browse files
committed
Edit motion interpolator for 1d case
1 parent be052f5 commit 502c152

File tree

1 file changed

+24
-19
lines changed

1 file changed

+24
-19
lines changed

src/spikeinterface/sortingcomponents/motion/motion_interpolation.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -120,20 +120,22 @@ def interpolate_motion_on_traces(
120120
time_bins = interpolation_time_bin_centers_s
121121
if time_bins is None:
122122
time_bins = motion.temporal_bins_s[segment_index]
123-
bin_s = (
124-
time_bins[1] - time_bins[0] if time_bins.size > 1 else time_bins * 2
125-
) # TODO: check this is * 2 but yes must be because its in the middle NO ITS NOT if first time is not 0
126-
# must use a different stragery
127-
bins_start = time_bins[0] - 0.5 * bin_s
128-
# nearest bin center for each frame?
129-
bin_inds = (times - bins_start) // bin_s
130-
bin_inds = bin_inds.astype(int)
131-
# the time bins may not cover the whole set of times in the recording,
132-
# so we need to clip these indices to the valid range
133-
np.clip(bin_inds, 0, time_bins.size, out=bin_inds)
134-
135-
# -- what are the possibilities here anyway?
136-
bins_here = np.arange(bin_inds[0], bin_inds[-1] + 1) # TODO: just replace this with 0
123+
124+
if time_bins.size == 1:
125+
bins_here = [0]
126+
else:
127+
bin_s = time_bins[1] - time_bins[0]
128+
# must use a different stragery
129+
bins_start = time_bins[0] - 0.5 * bin_s
130+
# nearest bin center for each frame?
131+
bin_inds = (times - bins_start) // bin_s
132+
bin_inds = bin_inds.astype(int)
133+
# the time bins may not cover the whole set of times in the recording,
134+
# so we need to clip these indices to the valid range
135+
np.clip(bin_inds, 0, time_bins.size, out=bin_inds)
136+
137+
# -- what are the possibilities here anyway?
138+
bins_here = np.arange(bin_inds[0], bin_inds[-1] + 1)
137139

138140
# inperpolation kernel will be the same per temporal bin
139141
interp_times = np.empty(total_num_chans)
@@ -168,16 +170,19 @@ def interpolate_motion_on_traces(
168170
# plt.show()
169171

170172
# quickly find the end of this bin, which is also the start of the next
171-
next_start_index = current_start_index + np.searchsorted(
172-
bin_inds[current_start_index:], bin_ind + 1, side="left"
173-
)
174-
in_bin = slice(current_start_index, next_start_index)
173+
if time_bins.size == 1:
174+
in_bin = None
175+
else:
176+
next_start_index = current_start_index + np.searchsorted(
177+
bin_inds[current_start_index:], bin_ind + 1, side="left"
178+
)
179+
in_bin = slice(current_start_index, next_start_index)
180+
current_start_index = next_start_index
175181

176182
# here we use a simple np.matmul even if dirft_kernel can be super sparse.
177183
# because the speed for a sparse matmul is not so good when we disable multi threaad (due multi processing
178184
# in ChunkRecordingExecutor)
179185
np.matmul(traces[in_bin], drift_kernel, out=traces_corrected[in_bin])
180-
current_start_index = next_start_index
181186

182187
return traces_corrected
183188

0 commit comments

Comments
 (0)