Skip to content

Add correct_inter_session_displacementfunction #3126

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 114 commits into
base: main
Choose a base branch
from

Conversation

JoeZiminski
Copy link
Collaborator

@JoeZiminski JoeZiminski commented Jul 2, 2024

This PR requires #3231 , on which it is rebased. Currently this PR contains commits from that PR so it makes sense for that to be merged before reviewing this in depth as the diff will be confusing.

The problem to be solved is inter-session movement ('displacement') of the probe. This is very similar problem to the inter-session motion correction, and can use many of the same tools. The idea is accept as input a list of recording objects (and, optionally, a list of intra-session motion correction outputs). For each session, an activity histogram will be generated, and then the inter-session drift and necessary correction estimated. The correction is then applied to the recording object, but a displacement_info object (similar to motion_info object) will be output. This could be used to correct peaks directly and optionally later, correct templates (see #2626).

Under the hood it essentially using the kilosort drift correction approach (cross correlation for linear alignment, and rigid alignment within multiple segments along the probe y-axis for non-rigid alignment). To generate the histogram for each session, the session is first split into segments and a histogram calculated for each segment, and then either the mean or median taken. The idea is to downweight periods of a session that may have noise. I tried a few different approaches (e.g. maximum value, first eigenvalue) but decided it was overkill and just kept the mean / median for now).

This PR makes a few changes to existing code to refactor anything I needed so it can be shared across modules, for the most part though the changes are contained within new modules.

Tests are added and a new tutorial is added in the documentation.

Question

  • One issue I had is that if I add the module to preprocessing __init__ I get a circular import issue because I am calling some sortingcomponents.motion modules. Please see the comment mode in the preprocessing/__init__.py and uncomment to reproduce. I think this can be solved by moving some of the motion code into core? For now, the functions are not in the API docs for this reason.
Example Script
from spikeinterface.generation.session_displacement_generator import generate_session_displacement_recordings

from spikeinterface.preprocessing.inter_session_alignment import (
    session_alignment,
)
from spikeinterface.widgets import plot_session_alignment, plot_activity_histogram_2d
import matplotlib.pyplot as plt

import spikeinterface.full as si

si.set_global_job_kwargs(n_jobs=10)


if __name__ == '__main__':

    # --------------------------------------------------------------------------------------
    # Load / generate some recordings
    # --------------------------------------------------------------------------------------

    recordings_list, _ = generate_session_displacement_recordings(
        num_units=5,
        recording_durations=[25, 25, 25],
        recording_shifts=((0, 0), (0, -200), (0, 150)),
        non_rigid_gradient=0.1,
        seed=55,
    )

    peaks_list, peak_locations_list = session_alignment.compute_peaks_locations_for_session_alignment(
        recordings_list,
        detect_kwargs={"method": "locally_exclusive"},
        localize_peaks_kwargs={"method": "grid_convolution"},
    )

    # --------------------------------------------------------------------------------------
    # Do the estimation
    # --------------------------------------------------------------------------------------
    # For each session, an 'activity histogram' is generated. This can be `entire_session`
    # or the session can be chunked into segments and some summary generated taken over then.
    # This might be useful if periods of the recording have weird kinetics or noise.
    estimate_histogram_kwargs = session_alignment.get_estimate_histogram_kwargs()
    estimate_histogram_kwargs["method"] = "chunked_median"
    estimate_histogram_kwargs["log_scale"] = True

    corrected_recordings_list, extra_info = session_alignment.align_sessions(
        recordings_list,
        peaks_list,
        peak_locations_list,
        alignment_order="to_session_1",  # "to_session_X" or "to_middle"
        estimate_histogram_kwargs=estimate_histogram_kwargs,
    )

    plot_session_alignment(
        recordings_list,
        peaks_list,
        peak_locations_list,
        extra_info["session_histogram_list"],
        **extra_info["corrected"],
        spatial_bin_centers=extra_info["spatial_bin_centers"],
        drift_raster_map_kwargs={"clim":(-250, 0), "scatter_decimate": 10}
    )
    plt.show()

    if estimate_histogram_kwargs["histogram_type"]  == "2d":
        plot_activity_histogram_2d(
            extra_info["session_histogram_list"],
            extra_info["spatial_bin_centers"],
            extra_info["corrected"]["corrected_session_histogram_list"]
        )
        plt.show()

image

@JoeZiminski JoeZiminski force-pushed the inter_session_displacement branch from bf9a84c to 58a6962 Compare July 15, 2024 22:24
@JoeZiminski JoeZiminski force-pushed the inter_session_displacement branch 2 times, most recently from 99431b5 to 8e6e77b Compare July 29, 2024 15:55
@JoeZiminski JoeZiminski changed the title Add 'correct_inter_session_displacement' function Add correct_inter_session_displacementfunction Jul 30, 2024
@JoeZiminski JoeZiminski force-pushed the inter_session_displacement branch 2 times, most recently from 9c8f81a to 1d3d8db Compare July 30, 2024 16:27
@JoeZiminski JoeZiminski force-pushed the inter_session_displacement branch 2 times, most recently from e7c22dc to 8ec328c Compare August 28, 2024 18:32
@alejoe91 alejoe91 added the preprocessing Related to preprocessing module label Nov 4, 2024
@JoeZiminski JoeZiminski force-pushed the inter_session_displacement branch from 2c1ebca to 74d6c45 Compare December 17, 2024 11:41
@JoeZiminski JoeZiminski force-pushed the inter_session_displacement branch 5 times, most recently from 98682ea to 01adece Compare January 17, 2025 19:42
@JoeZiminski JoeZiminski force-pushed the inter_session_displacement branch from 1faafaa to abe27c1 Compare January 21, 2025 00:14
@JoeZiminski JoeZiminski force-pushed the inter_session_displacement branch from 73edb81 to 140a8c7 Compare March 14, 2025 22:05
@@ -168,6 +168,18 @@ def get_spatial_bin_edges(recording, direction, hist_margin_um, bin_um):
return spatial_bins


def get_spatial_bins(recording, direction, hist_margin_um, bin_um):
# TODO: could this be merged with the above function?
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess where get_spatial_bin_edges is used this could be used instead

@@ -203,6 +203,7 @@ def __init__(
peak_amplitudes = peak_amplitudes[peak_mask]

from matplotlib.pyplot import colormaps
from matplotlib.colors import Normalize
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

small fixes in this module

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
preprocessing Related to preprocessing module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants