@@ -163,17 +163,17 @@ def get_chunked_hist_median(chunked_session_histograms):
163
163
164
164
# TODO: a good test here is to give zero shift for even and off numbered hist and check the output is zero!
165
165
def compute_histogram_crosscorrelation (
166
- session_histogram_list : list [ np .ndarray ] ,
166
+ session_histogram_list : np .ndarray ,
167
167
non_rigid_windows : np .ndarray ,
168
168
num_shifts : int ,
169
169
interpolate : bool ,
170
170
interp_factor : int ,
171
171
kriging_sigma : float ,
172
172
kriging_p : float ,
173
173
kriging_d : float ,
174
- smoothing_sigma_bin : float ,
175
- smoothing_sigma_window : float ,
176
- ):
174
+ smoothing_sigma_bin : None | float ,
175
+ smoothing_sigma_window : None | float ,
176
+ ) -> tuple [ np . ndarray , np . ndarray ] :
177
177
"""
178
178
Given a list of session activity histograms, cross-correlate
179
179
all histograms returning the peak correlation shift (in indices)
@@ -185,7 +185,8 @@ def compute_histogram_crosscorrelation(
185
185
Parameters
186
186
----------
187
187
188
- session_histogram_list : list[np.ndarray]
188
+ session_histogram_list : list[np.ndarray] TODO: change name!!
189
+ (num_sessions, num_bins) array of session activity histograms.
189
190
non_rigid_windows : np.ndarray
190
191
A (num windows x num_bins) binary of weights by which to window
191
192
the activity histogram for non-rigid-registration. For example, if
@@ -258,23 +259,21 @@ def compute_histogram_crosscorrelation(
258
259
"""
259
260
import matplotlib .pyplot as plt
260
261
261
- num_sessions = len ( session_histogram_list )
262
+ num_sessions = session_histogram_list . shape [ 0 ]
262
263
num_bins = session_histogram_list .shape [1 ] # all hists are same length
263
264
num_windows = non_rigid_windows .shape [0 ]
264
265
265
266
shift_matrix = np .zeros ((num_sessions , num_sessions , num_windows ))
266
267
267
268
center_bin = np .floor ((num_bins * 2 - 1 ) / 2 ).astype (int )
268
269
270
+ # Create the (num windows, num_bins) matrix for this pair of sessions
271
+ num_iter = num_bins * 2 - 1 if not num_shifts else num_shifts * 2
272
+ shifts_array = np .arange (- (num_iter // 2 ), num_iter // 2 + 1 )
273
+
269
274
for i in range (num_sessions ):
270
275
for j in range (i , num_sessions ):
271
276
272
- # Create the (num windows, num_bins) matrix for this pair of sessions
273
- num_iter = (
274
- num_bins * 2 - 1
275
- if not num_shifts
276
- else num_shifts * 2 # num_shift_block with iterative alignment is 2x, the same, make note!
277
- )
278
277
xcorr_matrix = np .zeros ((non_rigid_windows .shape [0 ], num_iter ))
279
278
280
279
# For each window, window the session histograms (`window` is binary)
@@ -292,12 +291,12 @@ def compute_histogram_crosscorrelation(
292
291
window_i = windowed_histogram_i - np .mean (windowed_histogram_i , axis = 1 )[:, np .newaxis ]
293
292
window_j = windowed_histogram_j - np .mean (windowed_histogram_j , axis = 1 )[:, np .newaxis ]
294
293
295
- xcorr = np .zeros (num_iter )
296
- for idx , shift in enumerate (range (- num_iter // 2 , num_iter // 2 )):
294
+ xcorr = np .zeros (num_iter + 1 )
295
+
296
+ for idx , shift in enumerate (shifts_array ):
297
297
shifted_i = shift_array_fill_zeros (window_i , shift )
298
298
299
299
xcorr [idx ] = np .correlate (shifted_i .flatten (), window_j .flatten ())
300
-
301
300
else :
302
301
# For a 1D histogram, compute the full cross-correlation and
303
302
# window the desired shifts ( this is faster than manual looping).
@@ -315,11 +314,6 @@ def compute_histogram_crosscorrelation(
315
314
316
315
xcorr_matrix [win_idx , :] = xcorr
317
316
318
- if num_shifts :
319
- shift_center_bin = num_shifts
320
- else :
321
- shift_center_bin = center_bin
322
-
323
317
# Smooth the cross-correlations across the bins
324
318
if smoothing_sigma_bin :
325
319
xcorr_matrix = gaussian_filter (xcorr_matrix , smoothing_sigma_bin , axes = 1 )
@@ -328,36 +322,67 @@ def compute_histogram_crosscorrelation(
328
322
if num_windows > 1 and smoothing_sigma_window :
329
323
xcorr_matrix = gaussian_filter (xcorr_matrix , smoothing_sigma_window , axes = 0 )
330
324
325
+ shifts_array = np .arange (- (num_iter // 2 ), num_iter // 2 + 1 ) # TODO: double check
331
326
# Upsample the cross-correlation
332
327
if interpolate :
333
- shifts = np .arange (xcorr_matrix .shape [1 ])
334
- shifts_upsampled = np .linspace (shifts [0 ], shifts [- 1 ], shifts .size * interp_factor )
328
+
329
+ # shifts = np.arange(xcorr_matrix.shape[1])
330
+ shifts_upsampled = np .linspace (shifts_array [0 ], shifts_array [- 1 ], shifts_array .size * interp_factor )
335
331
336
332
K = kriging_kernel (
337
- np .c_ [np .ones_like (shifts ), shifts ],
333
+ np .c_ [np .ones_like (shifts_array ), shifts_array ],
338
334
np .c_ [np .ones_like (shifts_upsampled ), shifts_upsampled ],
339
335
kriging_sigma ,
340
336
kriging_p ,
341
337
kriging_d ,
342
338
)
343
- xcorr_matrix = np .matmul (xcorr_matrix , K , axes = [(- 2 , - 1 ), (- 2 , - 1 ), (- 2 , - 1 )])
344
339
345
- xcorr_peak = np .argmax (xcorr_matrix , axis = 1 ) / interp_factor
340
+ # breakpoint()
341
+
342
+ xcorr_matrix_old = np .matmul (xcorr_matrix , K , axes = [(- 2 , - 1 ), (- 2 , - 1 ), (- 2 , - 1 )])
343
+ xcorr_matrix_ = np .zeros (
344
+ (xcorr_matrix .shape [0 ], shifts_upsampled .size )
345
+ ) # TODO: check in nonlinear case
346
+ for i_ in range (xcorr_matrix .shape [0 ]):
347
+ xcorr_matrix_ [i_ , :] = np .matmul (xcorr_matrix [i_ , :], K )
348
+
349
+ # breakpoint()
350
+
351
+ plt .plot (shifts_array , xcorr_matrix .T )
352
+ plt .show
353
+ plt .plot (shifts_upsampled , xcorr_matrix_ .T )
354
+ plt .show ()
355
+
356
+ xcorr_matrix = xcorr_matrix_
357
+
358
+ # plt.plot(xcorr_matrix.T)
359
+ # plt.plot(xcorr_matrix_old.T)
360
+ # plt.show()
361
+ #
362
+
363
+ xcorr_peak = np .argmax (xcorr_matrix , axis = 1 )
364
+ shift = shifts_upsampled [xcorr_peak ]
365
+
366
+ # breakpoint()
367
+
346
368
else :
347
369
xcorr_peak = np .argmax (xcorr_matrix , axis = 1 )
370
+ shift = shifts_array [xcorr_peak ]
348
371
349
- # Caclulate and save the shift for session i to j
350
- shift = xcorr_peak - shift_center_bin
372
+ # x=i;y= j
373
+ # breakpoint()
351
374
shift_matrix [i , j , :] = shift
352
375
376
+ breakpoint ()
377
+
353
378
# As xcorr shifts are symmetric, the shift matrix is skew symmetric, so fill
354
379
# the (empty) lower triangular with the negative (already computed) upper triangular to save computation
355
380
for k in range (shift_matrix .shape [2 ]):
356
381
lower_i , lower_j = np .tril_indices_from (shift_matrix [:, :, k ], k = - 1 )
357
382
upper_i , upper_j = np .triu_indices_from (shift_matrix [:, :, k ], k = 1 )
358
383
shift_matrix [lower_i , lower_j , k ] = shift_matrix [upper_i , upper_j , k ] * - 1
359
384
360
- return shift_matrix
385
+ return shift_matrix , xcorr_matrix
361
386
362
387
363
388
def shift_array_fill_zeros (array : np .ndarray , shift : int ) -> np .ndarray :
0 commit comments