Skip to content

Commit 01adece

Browse files
committed
Continue! working on tests.
1 parent 9da4cbf commit 01adece

File tree

3 files changed

+138
-64
lines changed

3 files changed

+138
-64
lines changed

src/spikeinterface/preprocessing/inter_session_alignment/alignment_utils.py

Lines changed: 53 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -163,17 +163,17 @@ def get_chunked_hist_median(chunked_session_histograms):
163163

164164
# TODO: a good test here is to give zero shift for even and off numbered hist and check the output is zero!
165165
def compute_histogram_crosscorrelation(
166-
session_histogram_list: list[np.ndarray],
166+
session_histogram_list: np.ndarray,
167167
non_rigid_windows: np.ndarray,
168168
num_shifts: int,
169169
interpolate: bool,
170170
interp_factor: int,
171171
kriging_sigma: float,
172172
kriging_p: float,
173173
kriging_d: float,
174-
smoothing_sigma_bin: float,
175-
smoothing_sigma_window: float,
176-
):
174+
smoothing_sigma_bin: None | float,
175+
smoothing_sigma_window: None | float,
176+
) -> tuple[np.ndarray, np.ndarray]:
177177
"""
178178
Given a list of session activity histograms, cross-correlate
179179
all histograms returning the peak correlation shift (in indices)
@@ -185,7 +185,8 @@ def compute_histogram_crosscorrelation(
185185
Parameters
186186
----------
187187
188-
session_histogram_list : list[np.ndarray]
188+
session_histogram_list : list[np.ndarray] TODO: change name!!
189+
(num_sessions, num_bins) array of session activity histograms.
189190
non_rigid_windows : np.ndarray
190191
A (num windows x num_bins) binary of weights by which to window
191192
the activity histogram for non-rigid-registration. For example, if
@@ -258,23 +259,21 @@ def compute_histogram_crosscorrelation(
258259
"""
259260
import matplotlib.pyplot as plt
260261

261-
num_sessions = len(session_histogram_list)
262+
num_sessions = session_histogram_list.shape[0]
262263
num_bins = session_histogram_list.shape[1] # all hists are same length
263264
num_windows = non_rigid_windows.shape[0]
264265

265266
shift_matrix = np.zeros((num_sessions, num_sessions, num_windows))
266267

267268
center_bin = np.floor((num_bins * 2 - 1) / 2).astype(int)
268269

270+
# Create the (num windows, num_bins) matrix for this pair of sessions
271+
num_iter = num_bins * 2 - 1 if not num_shifts else num_shifts * 2
272+
shifts_array = np.arange(-(num_iter // 2), num_iter // 2 + 1)
273+
269274
for i in range(num_sessions):
270275
for j in range(i, num_sessions):
271276

272-
# Create the (num windows, num_bins) matrix for this pair of sessions
273-
num_iter = (
274-
num_bins * 2 - 1
275-
if not num_shifts
276-
else num_shifts * 2 # num_shift_block with iterative alignment is 2x, the same, make note!
277-
)
278277
xcorr_matrix = np.zeros((non_rigid_windows.shape[0], num_iter))
279278

280279
# For each window, window the session histograms (`window` is binary)
@@ -292,12 +291,12 @@ def compute_histogram_crosscorrelation(
292291
window_i = windowed_histogram_i - np.mean(windowed_histogram_i, axis=1)[:, np.newaxis]
293292
window_j = windowed_histogram_j - np.mean(windowed_histogram_j, axis=1)[:, np.newaxis]
294293

295-
xcorr = np.zeros(num_iter)
296-
for idx, shift in enumerate(range(-num_iter // 2, num_iter // 2)):
294+
xcorr = np.zeros(num_iter + 1)
295+
296+
for idx, shift in enumerate(shifts_array):
297297
shifted_i = shift_array_fill_zeros(window_i, shift)
298298

299299
xcorr[idx] = np.correlate(shifted_i.flatten(), window_j.flatten())
300-
301300
else:
302301
# For a 1D histogram, compute the full cross-correlation and
303302
# window the desired shifts ( this is faster than manual looping).
@@ -315,11 +314,6 @@ def compute_histogram_crosscorrelation(
315314

316315
xcorr_matrix[win_idx, :] = xcorr
317316

318-
if num_shifts:
319-
shift_center_bin = num_shifts
320-
else:
321-
shift_center_bin = center_bin
322-
323317
# Smooth the cross-correlations across the bins
324318
if smoothing_sigma_bin:
325319
xcorr_matrix = gaussian_filter(xcorr_matrix, smoothing_sigma_bin, axes=1)
@@ -328,36 +322,67 @@ def compute_histogram_crosscorrelation(
328322
if num_windows > 1 and smoothing_sigma_window:
329323
xcorr_matrix = gaussian_filter(xcorr_matrix, smoothing_sigma_window, axes=0)
330324

325+
shifts_array = np.arange(-(num_iter // 2), num_iter // 2 + 1) # TODO: double check
331326
# Upsample the cross-correlation
332327
if interpolate:
333-
shifts = np.arange(xcorr_matrix.shape[1])
334-
shifts_upsampled = np.linspace(shifts[0], shifts[-1], shifts.size * interp_factor)
328+
329+
# shifts = np.arange(xcorr_matrix.shape[1])
330+
shifts_upsampled = np.linspace(shifts_array[0], shifts_array[-1], shifts_array.size * interp_factor)
335331

336332
K = kriging_kernel(
337-
np.c_[np.ones_like(shifts), shifts],
333+
np.c_[np.ones_like(shifts_array), shifts_array],
338334
np.c_[np.ones_like(shifts_upsampled), shifts_upsampled],
339335
kriging_sigma,
340336
kriging_p,
341337
kriging_d,
342338
)
343-
xcorr_matrix = np.matmul(xcorr_matrix, K, axes=[(-2, -1), (-2, -1), (-2, -1)])
344339

345-
xcorr_peak = np.argmax(xcorr_matrix, axis=1) / interp_factor
340+
# breakpoint()
341+
342+
xcorr_matrix_old = np.matmul(xcorr_matrix, K, axes=[(-2, -1), (-2, -1), (-2, -1)])
343+
xcorr_matrix_ = np.zeros(
344+
(xcorr_matrix.shape[0], shifts_upsampled.size)
345+
) # TODO: check in nonlinear case
346+
for i_ in range(xcorr_matrix.shape[0]):
347+
xcorr_matrix_[i_, :] = np.matmul(xcorr_matrix[i_, :], K)
348+
349+
# breakpoint()
350+
351+
plt.plot(shifts_array, xcorr_matrix.T)
352+
plt.show
353+
plt.plot(shifts_upsampled, xcorr_matrix_.T)
354+
plt.show()
355+
356+
xcorr_matrix = xcorr_matrix_
357+
358+
# plt.plot(xcorr_matrix.T)
359+
# plt.plot(xcorr_matrix_old.T)
360+
# plt.show()
361+
#
362+
363+
xcorr_peak = np.argmax(xcorr_matrix, axis=1)
364+
shift = shifts_upsampled[xcorr_peak]
365+
366+
# breakpoint()
367+
346368
else:
347369
xcorr_peak = np.argmax(xcorr_matrix, axis=1)
370+
shift = shifts_array[xcorr_peak]
348371

349-
# Caclulate and save the shift for session i to j
350-
shift = xcorr_peak - shift_center_bin
372+
# x=i;y=j
373+
# breakpoint()
351374
shift_matrix[i, j, :] = shift
352375

376+
breakpoint()
377+
353378
# As xcorr shifts are symmetric, the shift matrix is skew symmetric, so fill
354379
# the (empty) lower triangular with the negative (already computed) upper triangular to save computation
355380
for k in range(shift_matrix.shape[2]):
356381
lower_i, lower_j = np.tril_indices_from(shift_matrix[:, :, k], k=-1)
357382
upper_i, upper_j = np.triu_indices_from(shift_matrix[:, :, k], k=1)
358383
shift_matrix[lower_i, lower_j, k] = shift_matrix[upper_i, upper_j, k] * -1
359384

360-
return shift_matrix
385+
return shift_matrix, xcorr_matrix
361386

362387

363388
def shift_array_fill_zeros(array: np.ndarray, shift: int) -> np.ndarray:

src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def get_estimate_histogram_kwargs() -> dict:
4747
"bin_um": 2,
4848
"method": "chunked_mean",
4949
"chunked_bin_size_s": "estimate",
50-
"log_scale": False,
50+
"log_scale": True,
5151
"depth_smooth_um": None,
5252
"histogram_type": "activity_1d",
5353
"weight_with_amplitude": False,
@@ -881,8 +881,6 @@ def _compute_session_alignment(
881881
return rigid_shifts, non_rigid_windows, non_rigid_window_centers
882882

883883
# For non-rigid, first shift the histograms according to the rigid shift
884-
shifted_histograms = session_histogram_array.copy()
885-
886884
shifted_histograms = np.zeros_like(session_histogram_array)
887885
for ses_idx, orig_histogram in enumerate(session_histogram_array):
888886

@@ -892,7 +890,7 @@ def _compute_session_alignment(
892890
shifted_histograms[ses_idx, :] = shifted_histogram
893891

894892
# Then compute the nonrigid shifts
895-
nonrigid_session_offsets_matrix = alignment_utils.compute_histogram_crosscorrelation(
893+
nonrigid_session_offsets_matrix, _ = alignment_utils.compute_histogram_crosscorrelation(
896894
shifted_histograms, non_rigid_windows, num_shifts=num_shifts_block, **compute_alignment_kwargs
897895
)
898896
non_rigid_shifts = alignment_utils.get_shifts_from_session_matrix(alignment_order, nonrigid_session_offsets_matrix)
@@ -940,7 +938,7 @@ def _estimate_rigid_alignment(
940938

941939
rigid_window = np.ones(session_histogram_array.shape[1])[np.newaxis, :]
942940

943-
rigid_session_offsets_matrix = alignment_utils.compute_histogram_crosscorrelation(
941+
rigid_session_offsets_matrix, _ = alignment_utils.compute_histogram_crosscorrelation(
944942
session_histogram_array,
945943
rigid_window,
946944
num_shifts=num_shifts,

src/spikeinterface/preprocessing/tests/test_inter_session_alignment.py

Lines changed: 82 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ def test_recording_1(self):
4545
############################################################################
4646

4747
# TEST 1D AND 2D HERE
48-
@pytest.mark.parametrize("histogram_type", ["activity_2d"]) # "activity_1d"
48+
# TODO: test shift blocks...
49+
@pytest.mark.parametrize("histogram_type", ["activity_1d", "activity_2d"]) # "activity_1d" "activity_2d"
4950
def test_align_sessions_finds_correct_shifts(self, test_recording_1, histogram_type):
5051
""" """
5152
recordings_list, _, peaks_list, peak_locations_list = test_recording_1
@@ -57,8 +58,9 @@ def test_align_sessions_finds_correct_shifts(self, test_recording_1, histogram_t
5758
compute_alignment_kwargs["smoothing_sigma_window"] = None
5859

5960
estimate_histogram_kwargs = session_alignment.get_estimate_histogram_kwargs()
60-
estimate_histogram_kwargs["bin_um"] = 0.5
61+
estimate_histogram_kwargs["bin_um"] = 2
6162
estimate_histogram_kwargs["histogram_type"] = histogram_type
63+
estimate_histogram_kwargs["log_scale"] = True
6264

6365
for mode, expected in zip(
6466
["to_session_1", "to_session_2", "to_session_3", "to_middle"],
@@ -78,32 +80,7 @@ def test_align_sessions_finds_correct_shifts(self, test_recording_1, histogram_t
7880
estimate_histogram_kwargs=estimate_histogram_kwargs,
7981
)
8082

81-
# assert np.allclose(expected, extra_info["shifts_array"].squeeze(), rtol=0, atol=1.5)
82-
83-
from spikeinterface.widgets import plot_session_alignment, plot_activity_histogram_2d
84-
import matplotlib.pyplot as plt
85-
86-
corr_peaks_list, corr_peak_loc_list = session_alignment.compute_peaks_locations_for_session_alignment(
87-
corrected_recordings_list,
88-
detect_kwargs={"method": "locally_exclusive"},
89-
localize_peaks_kwargs={"method": "grid_convolution"},
90-
)
91-
92-
plot_session_alignment(
93-
corrected_recordings_list,
94-
corr_peaks_list,
95-
corr_peak_loc_list,
96-
extra_info["spatial_bin_centers"],
97-
**extra_info["corrected"],
98-
)
99-
plt.show()
100-
101-
plot_activity_histogram_2d(
102-
extra_info["session_histogram_list"],
103-
extra_info["spatial_bin_centers"],
104-
extra_info["corrected"]["session_histogram_list"],
105-
)
106-
plt.show()
83+
assert np.allclose(expected, extra_info["shifts_array"].squeeze(), rtol=0, atol=0.02)
10784

10885
# plot_session_alignment
10986
# recordings_list: list[BaseRecording],
@@ -124,8 +101,8 @@ def test_align_sessions_finds_correct_shifts(self, test_recording_1, histogram_t
124101

125102
rows, cols = np.triu_indices(len(new_histograms), k=1)
126103
assert np.all(
127-
np.abs(np.corrcoef(new_histograms)[rows, cols])
128-
- np.abs(np.corrcoef(extra_info["session_histogram_list"])[rows, cols])
104+
np.abs(np.corrcoef([hist.flatten() for hist in new_histograms])[rows, cols])
105+
- np.abs(np.corrcoef([hist.flatten() for hist in extra_info["session_histogram_list"]])[rows, cols])
129106
>= 0
130107
)
131108

@@ -766,7 +743,81 @@ def estimate_chunk_size(self):
766743
def test_akima_interpolate_nonrigid_shifts(self):
767744
pass
768745

769-
def test_compute_histogram_crosscorrelation(self):
746+
# TODO:
747+
@pytest.mark.parametrize("shifts", [3]) # -2 #test and off and even shift
748+
def test_compute_histogram_crosscorrelation(self, shifts):
749+
750+
even_hist = np.array([0, 0, 1, 1, 0, 1, 0, 1])
751+
odd_hist = np.array([1, 0, 1, 1, 1, 0])
752+
753+
even_hist_shift = alignment_utils.shift_array_fill_zeros(even_hist, shifts)
754+
odd_hist_shift = alignment_utils.shift_array_fill_zeros(odd_hist, shifts)
755+
756+
session_histogram_list = np.vstack([even_hist, even_hist_shift])
757+
758+
# Ut oh, is interpolate broken?
759+
interpolate = True # or False
760+
interp_factor = 50
761+
762+
shifts_matrix, xcorr_matrix_unsmoothed = alignment_utils.compute_histogram_crosscorrelation(
763+
session_histogram_list,
764+
non_rigid_windows=np.ones((1, even_hist.size)), # TODO: test non rigid!
765+
num_shifts=None, # TODO: test num shifts!
766+
interpolate=interpolate,
767+
interp_factor=interp_factor,
768+
kriging_sigma=0.5,
769+
kriging_p=2,
770+
kriging_d=2,
771+
smoothing_sigma_bin=None,
772+
smoothing_sigma_window=None,
773+
)
774+
breakpoint()
775+
assert alignment_utils.get_shifts_from_session_matrix("to_session_1", shifts_matrix)[-1] == -shifts
776+
777+
num_shifts = even_hist.size * 2 - 1
778+
if interpolate:
779+
assert xcorr_matrix_unsmoothed.shape[1] == num_shifts * interp_factor
780+
else:
781+
assert xcorr_matrix_unsmoothed.shape[1] == num_shifts
782+
783+
shifts_matrix_smoothed_bin, xcorr_matrix_smoothed_bin = alignment_utils.compute_histogram_crosscorrelation(
784+
session_histogram_list,
785+
non_rigid_windows=np.ones((1, even_hist.size)), # TODO: test non rigid!
786+
num_shifts=None, # TODO: test num shifts!
787+
interpolate=interpolate,
788+
interp_factor=interp_factor,
789+
kriging_sigma=1,
790+
kriging_p=1,
791+
kriging_d=1,
792+
smoothing_sigma_bin=0.5,
793+
smoothing_sigma_window=None,
794+
)
795+
796+
shifts_matrix_smoothed_window, xcorr_matrix_smoothed_window = (
797+
alignment_utils.compute_histogram_crosscorrelation(
798+
session_histogram_list,
799+
non_rigid_windows=np.ones((1, even_hist.size)), # TODO: test non rigid!
800+
num_shifts=None, # TODO: test num shifts!
801+
interpolate=interpolate,
802+
interp_factor=interp_factor,
803+
kriging_sigma=1,
804+
kriging_p=1,
805+
kriging_d=1,
806+
smoothing_sigma_bin=None,
807+
smoothing_sigma_window=0.5,
808+
)
809+
)
810+
811+
# make a histogram (odd and even length)
812+
# shift it (odd and even shift)
813+
# check smoothing across bins and time
814+
# check interpolate
815+
# thats it!
816+
817+
def test_compute_histogram_crosscorrelation_gaussian_filter_kwargs(self): ## TODO: incorporate these above
818+
pass
819+
820+
def test_compute_histogram_crosscorrelation_kriging_kwargs(self):
770821
pass
771822

772823
###########################################################################

0 commit comments

Comments
 (0)