@@ -71,6 +71,8 @@ def get_compute_alignment_kwargs() -> dict:
71
71
windows along the probe depth. See `get_spatial_windows`.
72
72
"""
73
73
return {
74
+ "num_shifts_global" : None ,
75
+ "num_shifts_block" : 20 ,
74
76
"interpolate" : False ,
75
77
"interp_factor" : 10 ,
76
78
"kriging_sigma" : 1 ,
@@ -93,8 +95,6 @@ def get_non_rigid_window_kwargs():
93
95
"""
94
96
return {
95
97
"rigid" : True ,
96
- "num_shifts_global" : None ,
97
- "num_shifts_block" : 20 ,
98
98
"win_shape" : "gaussian" ,
99
99
"win_step_um" : 50 ,
100
100
"win_scale_um" : 50 ,
@@ -109,12 +109,13 @@ def get_interpolate_motion_kwargs():
109
109
see that class for parameter descriptions.
110
110
"""
111
111
return {
112
- "border_mode" : "force_zeros" , # fixed as this until can figure out probe
112
+ "border_mode" : "force_zeros" , # fixed as this until can figure out probe
113
113
"spatial_interpolation_method" : "kriging" ,
114
114
"sigma_um" : 20.0 ,
115
- "p" : 2
115
+ "p" : 2 ,
116
116
}
117
117
118
+
118
119
###############################################################################
119
120
# Public Entry Level Functions
120
121
###############################################################################
@@ -222,7 +223,12 @@ def align_sessions(
222
223
223
224
# Ensure list lengths match and all channel locations are the same across recordings.
224
225
_check_align_sessions_inputs (
225
- recordings_list , peaks_list , peak_locations_list , alignment_order , estimate_histogram_kwargs , interpolate_motion_kwargs
226
+ recordings_list ,
227
+ peaks_list ,
228
+ peak_locations_list ,
229
+ alignment_order ,
230
+ estimate_histogram_kwargs ,
231
+ interpolate_motion_kwargs ,
226
232
)
227
233
228
234
print ("Computing a single activity histogram from each session..." )
@@ -311,7 +317,10 @@ def align_sessions_after_motion_correction(
311
317
)
312
318
313
319
motion_window_kwargs = copy .deepcopy (motion_kwargs_list [0 ])
314
- if motion_window_kwargs ["direction" ] != "y" :
320
+
321
+ if (
322
+ "direction" in motion_window_kwargs and motion_window_kwargs ["direction" ] != "y"
323
+ ): # TODO: why is this not in all?
315
324
raise ValueError ("motion correct must have been performed along the 'y' dimension." )
316
325
317
326
if align_sessions_kwargs is None :
@@ -322,24 +331,37 @@ def align_sessions_after_motion_correction(
322
331
# shifts together.
323
332
if (
324
333
"non_rigid_window_kwargs" in align_sessions_kwargs
325
- and "nonrigid" in align_sessions_kwargs ["non_rigid_window_kwargs" ]["rigid_mode " ]
334
+ and not align_sessions_kwargs ["non_rigid_window_kwargs" ]["rigid " ]
326
335
):
327
-
336
+ # TODO: carefully walk through this function! and test all assumptions...
328
337
if not motion_window_kwargs ["rigid" ]:
329
- print (
338
+ print ( # TODO: make a warning
330
339
"Nonrigid inter-session alignment must use the motion correct "
331
340
"nonrigid settings. Overwriting any passed `non_rigid_window_kwargs` "
332
341
"with the motion object non_rigid_window_kwargs."
333
342
)
334
- motion_window_kwargs .pop ("method" )
335
- motion_window_kwargs .pop ("direction" )
343
+ non_rigid_window_kwargs = get_non_rigid_window_kwargs ()
344
+
345
+ # TODO: generate function for replacing one dict into another?
346
+ for (
347
+ k ,
348
+ v ,
349
+ ) in motion_window_kwargs .items (): # TODO: can get tighter alignment here with original implementation?
350
+ if k in non_rigid_window_kwargs :
351
+ non_rigid_window_kwargs [k ] = v
352
+
336
353
align_sessions_kwargs = copy .deepcopy (align_sessions_kwargs )
337
- align_sessions_kwargs ["non_rigid_window_kwargs" ] = motion_window_kwargs
354
+ align_sessions_kwargs ["non_rigid_window_kwargs" ] = non_rigid_window_kwargs
355
+
356
+ corrected_peak_locations = [
357
+ correct_motion_on_peaks (info ["peaks" ], info ["peak_locations" ], info ["motion" ], recording )
358
+ for info , recording in zip (motion_info_list , recordings_list )
359
+ ]
338
360
339
361
return align_sessions (
340
362
recordings_list ,
341
363
[info ["peaks" ] for info in motion_info_list ],
342
- [ info [ "peak_locations" ] for info in motion_info_list ] ,
364
+ corrected_peak_locations ,
343
365
** align_sessions_kwargs ,
344
366
)
345
367
@@ -459,14 +481,14 @@ def _compute_session_histograms(
459
481
recording ,
460
482
peaks ,
461
483
peak_locations ,
462
- histogram_type ,
463
- spatial_bin_edges ,
464
- method ,
465
- log_scale ,
466
- chunked_bin_size_s ,
467
- depth_smooth_um ,
468
- weight_with_amplitude ,
469
- avg_in_bin ,
484
+ histogram_type = histogram_type ,
485
+ spatial_bin_edges = spatial_bin_edges ,
486
+ method = method ,
487
+ log_scale = log_scale ,
488
+ chunked_bin_size_s = chunked_bin_size_s ,
489
+ depth_smooth_um = depth_smooth_um ,
490
+ weight_with_amplitude = weight_with_amplitude ,
491
+ avg_in_bin = avg_in_bin ,
470
492
)
471
493
temporal_bin_centers_list .append (temporal_bin_centers )
472
494
session_histogram_list .append (session_hist )
@@ -539,32 +561,31 @@ def _get_single_session_activity_histogram(
539
561
# full estimation for chunked bin size
540
562
if chunked_bin_size_s == "estimate" :
541
563
542
- one_bin_histogram , _ , _ = alignment_utils .get_activity_histogram (
564
+ scaled_hist , _ , _ = alignment_utils .get_2d_activity_histogram (
543
565
recording ,
544
566
peaks ,
545
567
peak_locations ,
546
568
spatial_bin_edges ,
547
569
log_scale = False ,
548
570
bin_s = None ,
549
571
depth_smooth_um = None ,
550
- scale_to_hz = False ,
572
+ scale_to_hz = True ,
551
573
weight_with_amplitude = False ,
552
574
avg_in_bin = False ,
553
575
)
554
576
555
577
# It is important that the passed histogram is scaled to firing rate in Hz
556
- scaled_hist = one_bin_histogram / recording .get_duration () # TODO: why is this done here when have a scale_to_hz arg??!?
557
578
chunked_bin_size_s = alignment_utils .estimate_chunk_size (scaled_hist )
558
579
chunked_bin_size_s = np .min ([chunked_bin_size_s , recording .get_duration ()])
559
580
560
581
if histogram_type == "activity_1d" :
561
582
562
- chunked_histograms , chunked_temporal_bin_centers , _ = alignment_utils .get_activity_histogram (
583
+ chunked_histograms , chunked_temporal_bin_centers , _ = alignment_utils .get_2d_activity_histogram (
563
584
recording ,
564
585
peaks ,
565
586
peak_locations ,
566
587
spatial_bin_edges ,
567
- log_scale ,
588
+ log_scale = log_scale ,
568
589
bin_s = chunked_bin_size_s ,
569
590
depth_smooth_um = depth_smooth_um ,
570
591
weight_with_amplitude = weight_with_amplitude ,
@@ -656,9 +677,11 @@ def _create_motion_recordings(
656
677
motion ,
657
678
interpolation_time_bin_centers_s = motion .temporal_bins_s ,
658
679
interpolation_time_bin_edges_s = [np .array (recording .get_times ()[0 ], recording .get_times ()[- 1 ])],
659
- ** interpolate_motion_kwargs
680
+ ** interpolate_motion_kwargs ,
660
681
)
661
- corrected_recording = corrected_recording .set_probe (recording .get_probe ()) # TODO: if this works, might need to do above
682
+ corrected_recording = corrected_recording .set_probe (
683
+ recording .get_probe ()
684
+ ) # TODO: if this works, might need to do above
662
685
663
686
corrected_recordings_list .append (corrected_recording )
664
687
@@ -840,8 +863,8 @@ def _compute_session_alignment(
840
863
session_histogram_array = np .array (session_histogram_list )
841
864
842
865
akima_interp_nonrigid = compute_alignment_kwargs .pop ("akima_interp_nonrigid" )
843
- num_shifts_global = non_rigid_window_kwargs .pop ("num_shifts_global" )
844
- num_shifts_block = non_rigid_window_kwargs .pop ("num_shifts_block" )
866
+ num_shifts_global = compute_alignment_kwargs .pop ("num_shifts_global" )
867
+ num_shifts_block = compute_alignment_kwargs .pop ("num_shifts_block" )
845
868
846
869
non_rigid_windows , non_rigid_window_centers = get_spatial_windows (
847
870
contact_depths , spatial_bin_centers , ** non_rigid_window_kwargs
@@ -870,7 +893,7 @@ def _compute_session_alignment(
870
893
871
894
# Then compute the nonrigid shifts
872
895
nonrigid_session_offsets_matrix = alignment_utils .compute_histogram_crosscorrelation (
873
- shifted_histograms , non_rigid_windows , num_shifts_block , ** compute_alignment_kwargs
896
+ shifted_histograms , non_rigid_windows , num_shifts = num_shifts_block , ** compute_alignment_kwargs
874
897
)
875
898
non_rigid_shifts = alignment_utils .get_shifts_from_session_matrix (alignment_order , nonrigid_session_offsets_matrix )
876
899
@@ -920,7 +943,7 @@ def _estimate_rigid_alignment(
920
943
rigid_session_offsets_matrix = alignment_utils .compute_histogram_crosscorrelation (
921
944
session_histogram_array ,
922
945
rigid_window ,
923
- num_shifts ,
946
+ num_shifts = num_shifts ,
924
947
** compute_alignment_kwargs , # TODO: remove the copy above and pass directly. Consider removing this function...
925
948
)
926
949
optimal_shift_indices = alignment_utils .get_shifts_from_session_matrix (
@@ -968,7 +991,6 @@ def _check_align_sessions_inputs(
968
991
"performed using the same probe."
969
992
)
970
993
971
-
972
994
accepted_hist_methods = [
973
995
"entire_session" ,
974
996
"chunked_mean" ,
@@ -998,4 +1020,6 @@ def _check_align_sessions_inputs(
998
1020
if ses_num == 0 :
999
1021
raise ValueError ("`alignment_order` required the session number, not session index." )
1000
1022
1001
- assert interpolate_motion_kwargs ["border_mode" ] == "force_zeros" , "InterpolateMotionRecording must be `force_zeros` until probe is figured out." # TODO: ask sam
1023
+ assert (
1024
+ interpolate_motion_kwargs ["border_mode" ] == "force_zeros"
1025
+ ), "InterpolateMotionRecording must be `force_zeros` until probe is figured out." # TODO: ask sam
0 commit comments