diff --git a/.gitignore b/.gitignore
index 3727e09..a00214e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -135,3 +135,6 @@ data/ibl-pupil/outputs/
 data/mirror-mouse/outputs/
 *.pdf
+/eks/timing.py
+/scripts/plotting.py
+/scripts/plotting2.py
diff --git a/README.md b/README.md
index a6e18a2..b26644d 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,16 @@
 # EKS
-This repo contains code to run an Ensemble Kalman Smoother (EKS) for improving pose estimation outputs.
+This repo contains code to run an Ensemble Kalman Smoother (EKS) for improving pose estimation outputs.
+
+The EKS uses a Kalman filter to ensemble and smooth pose estimation outputs as a post-processing
+step after multiple model predictions have been generated, resulting in a more robust output:
+
+![](assets/crim13_singlecam.gif)
+
+In this gif, EKS is run on the predictions of five individual models (in red) to generate the EKS
+output (in green). The EKS output improves tracking by ensembling and smoothing the individual
+model outputs, which often diverge considerably on difficult frames.
+
+---
 
 ## Installation
 
@@ -46,6 +57,8 @@ If you wish to install the developer version of the package, run installation li
 pip install -e ".[dev]"
 ```
 
+For more information on the project's package dependencies and their usage, see [Requirements](docs/requirements.md).
+
 ### Method 2: pip
 
 You can also install the `eks` package using the Python Package Index (PyPI):
@@ -55,10 +68,20 @@ python3 -m pip install ensemble-kalman-smoother
 ```
 Note that you will not have access to the example data or example scripts with the pip install
 option.
 
+### Note: Using GPU for fast parallel-scan
+As of now, EKS singlecam features a jitted parallel-scan implementation for quickly optimizing the
+smoothing parameter (notably for larger datasets of 10,000+ frames). To use parallel scan, you
+will need a CUDA-enabled environment with JAX installed. Further instructions can be found in the
+[JAX docs](https://jax.readthedocs.io/en/latest/installation.html).
+
+---
+
 ## Example scripts
 
 We provide several example datasets and fitting scripts to illustrate use of the package. See
-[Command-Line Arguments](docs/command-line_arguments.md) for more information on arguments.
+[Command-Line Arguments](docs/command-line_arguments.md) for more information on arguments,
+including optional flags and defaults. We recommend starting with `singlecam_example.py`, the
+first of the four scripts outlined below, and following along with the
+[Singlecam Overview](docs/singlecam_overview.md) for a deeper understanding of EKS.
 
 ### Single-camera datasets
 The `singlecam_example.py` script demonstrates how to run the EKS code for standard single-camera
@@ -68,8 +91,12 @@ our example.
 To run the EKS on the example data, execute the following command from inside this repo:
 ```console
-python scripts/singlecam_example.py --input-dir ./data/ibl-pupil --data-type lp --bodypart-list pupil_top_r pupil_bottom_r pupil_left_r pupil_right_r
+python scripts/singlecam_example.py --input-dir ./data/ibl-pupil
 ```
+
+The singlecam script is currently the most up-to-date and feature-rich of the example scripts,
+including fast smoothing-parameter auto-tuning using GPU-driven parallelization.
+[Here](docs/singlecam_overview.md) is a detailed overview of the workflow.
 
 ### Multi-camera datasets
 The `multicam_example.py` script demonstrates how to run the EKS code for multi-camera
@@ -82,7 +109,7 @@ for a two-view video of a mouse with cameras named `top` and `bot`.
 To run the EKS on the example data provided, execute the following command from inside this repo:
 ```console
-python scripts/multicam_example.py --input-dir ./data/mirror-mouse --data-type lp --bodypart-list paw1LH paw2LF paw3RF paw4RH --camera-names top bot
+python scripts/multicam_example.py --input-dir ./data/mirror-mouse --bodypart-list paw1LH paw2LF paw3RF paw4RH --camera-names top bot
 ```
 
 ### IBL pupil dataset
@@ -91,7 +118,7 @@ model predictions.
 To run this script on the example data provided, execute the following command from inside this repo:
 ```console
-python scripts/pupil_example.py --input-dir ./data/ibl-pupil
+python scripts/ibl_pupil_example.py --input-dir ./data/ibl-pupil
 ```
 
 ### IBL paw dataset (multiple asynchronous views)
@@ -101,12 +128,16 @@ the two cameras.
 To run this script on the example data provided, execute the following command from inside this repo:
 ```console
-python scripts/multiview_paw_example.py --input-dir ./data/ibl-paw
+python scripts/ibl_paw_multiview_example.py --input-dir ./data/ibl-paw
 ```
 
-### Authors
+### Authors
+
 Cole Hurwitz
+Keemin Lee
+
+Amol Pasarkar
+
 Matt Whiteway
-Keemin Lee
\ No newline at end of file
diff --git a/assets/crim13_singlecam.gif b/assets/crim13_singlecam.gif
new file mode 100644
index 0000000..3acafd8
Binary files /dev/null and b/assets/crim13_singlecam.gif differ
diff --git a/docs/command-line_arguments.md b/docs/command-line_arguments.md
index 0f78af5..a05f791 100644
--- a/docs/command-line_arguments.md
+++ b/docs/command-line_arguments.md
@@ -12,11 +12,15 @@ The script is run as a python executable. Arguments are parsed using the `argpar
 [Python module](https://docs.python.org/3/library/argparse.html), as seen in each of the example
 scripts.
 
-Arguments fall into two categories:
-## File I/O Arguments
-These arguments dictate the file directories for reading and writing data, and are present in all example scripts.
-- `--csv-dir ` String specifying read-in directory containing CSV files used as input data.
+Arguments are either general or script-specific.
+## General Arguments
+These arguments are present in all example scripts.
+- `--input-dir ` String specifying read-in directory containing data.
 - `--save-dir ` String specifying write-to directory (input-dir by default).
+- `--save-filename ` String specifying output file name (`{smoother_type}_{last_keypoint_smooth_param}` by default).
+- `--data-type ` String specifying input data file type. Accepts DeepLabCut (dlc), Lightning Pose (lp), and SLEAP (slp) file types. lp by default.
+- `--s-frames ` Frames to be considered for smoothing-parameter optimization (first 10,000 frames by default; ignored if `--s` is specified). Accepts a single int (uses frames 1 through int) or a list of (start, end) tuples.
 - `--s ` Float specifying the extent of smoothing to be done. Smoothing increases as param **decreases** (range 0.01-20, 0.01 by default)
 - `--quantile_keep_pca ` Float specifying the percentage of points kept for multi-view PCA. Selectivity increases as param increases (range 0-100, 25 by default)
-### [IBL Pupil](../scripts/pupil_example.py)
+### [IBL Pupil](../scripts/ibl_pupil_example.py)
 - `--diameter-s ` Float specifying the extent of smoothing to be done for diameter. Smoothing increases as param **increases** (range 0-1 exclusive, 0.9999 by default)
 - `--com-s ` Float specifying the extent of smoothing to be done for center of mass. Smoothing increases as param **increases** (range 0-1 exclusive, 0.999 by default)
-### [IBL Paw (multiple asynchronous view)](../scripts/multiview_paw_example.py)
+### [IBL Paw (multiple asynchronous view)](../scripts/ibl_paw_multiview_example.py)
 - `--s ` Float specifying the extent of smoothing to be done.
Smoothing increases as param **decreases** (range 0.01-20, 0.01 by default) - `--quantile_keep_pca ` Float specifying the percentage of points kept for multi-view PCA. Selectivity increases as param increases (range 0-100, 25 by default) The following table summarizes the script-specific arguments featured in each of the example scripts. -| Argument\Script | [Single-Camera](../scripts/singlecam_example.py) | [Multi-Camera](../scripts/multicam_example.py) | [IBL Pupil](../scripts/pupil_example.py) | [IBL Paw](../scripts/multiview_paw_example.py) | +| Argument\Script | [Single-Camera](../scripts/singlecam_example.py) | [Multi-Camera](../scripts/multicam_example.py) | [IBL Pupil](../scripts/ibl_pupil_example.py) | [IBL Paw](../scripts/ibl_paw_multiview_example.py) | |-----------------------|---------------|--------------|-----------|---------------------| | `--bodypart-list` | ✓ | ✓ | | | | `--camera-names` | | ✓ | | | diff --git a/docs/eks_smoothers.md b/docs/eks_smoothers.md deleted file mode 100644 index 3ad744e..0000000 --- a/docs/eks_smoothers.md +++ /dev/null @@ -1,47 +0,0 @@ -# EKS Smoothers - -This document covers the three EKS smoothers for Single-view, Multi-view, and Pupil use-cases, -working through their main functions while noting their similarities and differences. - -## Overview - -Smoothers take in the DataFrame-formatted marker data from [scripts](scripts.md), and utilizes -functions in [ensemble-kalman.py](../eks/ensemble_kalman.py) to run the EKS on the input data, using the -smoothing parameter in the state-covariance matrix to improve the accuracy of the smoothed predictions. -It returns a new DataFrame containing the smoothed markers while retaining the input format, using -`make_dlc_pandas_index()` from [utils.py](../eks/utils.py). - -## Function Details - -All smoothers have a main function beginning with `ensemble_kalman_...` (the Multi-view smoother -has two of these, one for multi-camera and one for IBL-paw). For instance, -[singleview_smoother.py](../eks/singleview_smoother.py) has the function: -```python -def ensemble_kalman_smoother_single_view( - markers_list, keypoint_ensemble, smooth_param, ensembling_mode='median', zscore_threshold=2, verbose=False): -``` - -These functions apply the Ensemble Kalman Smoother (EKS) to smooth marker data from multiple -ensemble members for each view. - -#### Parameters -- `markers_list` : List of List of DataFrame - - Contains the formatted DataFrames containing predictions from each ensemble member. The -multi-view smoother has an additional parameter for each camera view. -- `keypoint_ensemble` : str - - The name of the keypoint to be ensembled and smoothed. Parameter `keypoint_names` -taken as a list of keypoints instead in the multi-view and pupil smoother. -- `smooth_param` : float - - See [command-line_arguments.md](command-line_arguments.md) for information on the smoothing parameter. -The pupil smoother takes in parameter `state_transition_matrix` instead, which is built from the two smoothing -parameters `diameter-s` and `com-s` in the [pupil example script](../scripts/pupil_example.py). -- `ensembling_mode` : str, optional - - The function used for ensembling. Options are 'mean', 'median', or 'confidence_weighted_mean', -median by default. The parameter does not exist for [pupil_smoother](../eks/pupil_smoother) and instead defaults to median. -- `zscore_threshold` : float, optional - - Minimum standard deviation threshold to reduce the effect of low ensemble standard deviation on a z-score metric. Default is 2. 
-- `verbose` : bool, optional - - If True, progress will be printed for the user. - -#### Returns -- 'keypoint_df': DataFrame containing smoothed markers for one keypoint. Same format as input DataFrames. \ No newline at end of file diff --git a/docs/ensemble_kalman.md b/docs/ensemble_kalman.md deleted file mode 100644 index 1e15d2a..0000000 --- a/docs/ensemble_kalman.md +++ /dev/null @@ -1,134 +0,0 @@ -# Ensemble Kalman Smoother - -## ensemble -```python -def ensemble(markers_list, keys, mode='median'): -``` - -Computes ensemble median (or mean) and variance of a list of DLC marker dataframes. - -### Args: -- `markers_list`: list - - List of DLC marker dataframes. -- `keys`: list - - List of keys in each marker dataframe. -- `mode`: string, optional (default: 'median') - - Averaging mode which includes 'median', 'mean', or 'confidence_weighted_mean'. - -### Returns: -- `ensemble_preds`: np.ndarray - - Shape: (samples, n_keypoints) -- `ensemble_vars`: np.ndarray - - Shape: (samples, n_keypoints) -- `ensemble_stacks`: np.ndarray - - Shape: (n_models, samples, n_keypoints) -- `keypoints_avg_dict`: dict - - Keys: marker keypoints - - Values: Shape (samples) -- `keypoints_var_dict`: dict - - Keys: marker keypoints - - Values: Shape (samples) -- `keypoints_stack_dict`: dict(dict) - - Keys: model_ids - - Values: Shape (samples) - -## filtering_pass -```python -def filtering_pass(y, m0, S0, C, R, A, Q, ensemble_vars): -``` - -Implements the Kalman filter. - -### Args: -- `y`: np.ndarray - - Shape: (samples, n_keypoints) -- `m0`: np.ndarray - - Shape: (n_latents) -- `S0`: np.ndarray - - Shape: (n_latents, n_latents) -- `C`: np.ndarray - - Shape: (n_keypoints, n_latents) -- `R`: np.ndarray - - Shape: (n_keypoints, n_keypoints) -- `A`: np.ndarray - - Shape: (n_latents, n_latents) -- `Q`: np.ndarray - - Shape: (n_latents, n_latents) -- `ensemble_vars`: np.ndarray - - Shape: (samples, n_keypoints) - -### Returns: -- `mf`: np.ndarray - - Shape: (samples, n_keypoints) -- `Vf`: np.ndarray - - Shape: (samples, n_latents, n_latents) -- `S`: np.ndarray - - Shape: (samples, n_latents, n_latents) - -## kalman_dot -```python -def kalman_dot(array, V, C, R): -``` - -Helper function for matrix multiplication used in the Kalman filter. - -### Args: -- `array`: np.ndarray -- `V`: np.ndarray -- `C`: np.ndarray -- `R`: np.ndarray - -### Returns: -- `K_array`: np.ndarray - -## smooth_backward -```python -def smooth_backward(y, mf, Vf, S, A, Q, C): -``` - -Implements Kalman smoothing backwards. - -### Args: -- `y`: np.ndarray - - Shape: (samples, n_keypoints) -- `mf`: np.ndarray - - Shape: (samples, n_keypoints) -- `Vf`: np.ndarray - - Shape: (samples, n_latents, n_latents) -- `S`: np.ndarray - - Shape: (samples, n_latents, n_latents) -- `A`: np.ndarray - - Shape: (n_latents, n_latents) -- `Q`: np.ndarray - - Shape: (n_latents, n_latents) -- `C`: np.ndarray - - Shape: (n_keypoints, n_latents) - -### Returns: -- `ms`: np.ndarray - - Shape: (samples, n_keypoints) -- `Vs`: np.ndarray - - Shape: (samples, n_latents, n_latents) -- `CV`: np.ndarray - - Shape: (samples, n_latents, n_latents) - -## eks_zscore -```python -def eks_zscore(eks_predictions, ensemble_means, ensemble_vars, min_ensemble_std=2): -``` - -Computes the z-score between EKS prediction and the ensemble for a single keypoint. 
- -### Args: -- `eks_predictions`: list - - EKS prediction for each coordinate (x and y) for a single keypoint - Shape: (samples, 2) -- `ensemble_means`: list - - Ensemble mean for each coordinate (x and y) for a single keypoint - Shape: (samples, 2) -- `ensemble_vars`: string - - Ensemble variance for each coordinate (x and y) for a single keypoint - Shape: (samples, 2) -- `min_ensemble_std`: float, optional (default: 2) - - Minimum standard deviation threshold to reduce the effect of low ensemble standard deviation. - -### Returns: -- `z_score`: np.ndarray - - Z-score for each time point - Shape: (samples, 1) \ No newline at end of file diff --git a/docs/requirements.md b/docs/requirements.md index d53698e..a5ed809 100644 --- a/docs/requirements.md +++ b/docs/requirements.md @@ -1,50 +1,39 @@ -# Project Requirements - -This document outlines the Python packages required to run the project. - -## Overview - -The `requirements.txt` file lists all the necessary dependencies for the project. Below is a breakdown of each package and its purpose: - -1. **ipykernel** - - Description: Provides the IPython kernel for Jupyter notebooks and interactive computing. - - Usage: This package is used to enable the execution of Python code within Jupyter notebooks. - -2. **matplotlib** - - Description: Provides a MATLAB-like plotting interface for creating static, interactive, and animated visualizations in Python. - - Usage: This package is used for data visualization and plotting graphs within the project. - -3. **numpy** - - Description: Provides support for numerical computations and multidimensional array operations in Python. - - Usage: Numpy is extensively used for numerical computing tasks such as array manipulation, linear algebra, and mathematical operations. - -4. **opencv-python** - - Description: Provides the OpenCV library for computer vision and image processing tasks in Python. - - Usage: This package is used for various computer vision tasks, including image manipulation, object detection, and feature extraction. - -5. **pandas** - - Description: Provides data structures and data analysis tools for handling structured data in Python. - - Usage: Pandas is used for data manipulation, exploration, and analysis, especially with tabular data structures like DataFrames. - -6. **scikit-learn** - - Description: Provides a collection of machine learning algorithms and tools for data mining and data analysis tasks. - - Usage: Scikit-learn is used for implementing machine learning models, including classification, regression, clustering, and dimensionality reduction. - -7. **scipy>=1.2.0** - - Description: Provides scientific computing tools and algorithms for numerical integration, optimization, interpolation, and more. - - Usage: Scipy complements Numpy and provides additional mathematical functions and routines for scientific computing tasks. - -8. **tqdm** - - Description: Provides a fast, extensible progress bar for loops and tasks in Python. - - Usage: This package is used to display progress bars and monitor the progress of iterative tasks, such as data processing or model training. - -9. **typing** - - Description: Provides support for type hints and type checking in Python. - - Usage: Typing is used to annotate function signatures and variables with type information, improving code readability and enabling static type checking. 
-
-## Installation
-
-To install the required packages, run the following command (copied from `README.md`):
-
-```bash
-pip install -r requirements.txt
+# Requirements
+
+This document explains the purpose of each package listed in the `install_requires` and
+`extras_require` sections of the `setup.py` file for the `ensemble-kalman-smoother` project.
+
+## Basic Requirements
+
+These are the core dependencies required for the project to function properly, and can be installed via:
+```
+pip install -e .
+```
+
+- **ipykernel**: Provides the IPython kernel for Jupyter, allowing the execution of Python code in Jupyter notebooks.
+- **matplotlib**: A comprehensive library for creating static, animated, and interactive visualizations in Python.
+- **numpy**: The fundamental package for scientific computing with Python, providing support for arrays, matrices, and many mathematical functions.
+- **opencv-python**: A library of programming functions mainly aimed at real-time computer vision. It allows for image and video capture and processing.
+- **pandas**: A powerful data analysis and manipulation library for Python, providing data structures like DataFrames.
+- **scikit-learn**: A machine learning library for Python, offering simple and efficient tools for data mining and data analysis.
+- **scipy (>=1.2.0)**: A library used for scientific and technical computing, building on the capabilities of numpy and providing additional functionality.
+- **tqdm**: A fast, extensible progress bar for Python and CLI, useful for tracking the progress of loops and processes.
+- **typing**: Provides support for type hints, making it easier to write and maintain Python code by specifying expected types of variables.
+- **sleap_io**: A library for reading and writing SLEAP (Social LEAP Estimates Animal Poses) files, which are used for pose estimation in biological research.
+- **jax**: A library for high-performance numerical computing, offering support for automatic differentiation and optimized operations on CPUs and GPUs.
+- **jaxlib**: A companion library to jax, providing implementations of numerical operations on various hardware platforms.
+
+## Additional Requirements (for developers)
+
+These are optional dependencies used for development and documentation purposes, and can be installed via:
+```
+pip install -e ".[dev]"
+```
+
+- **flake8**: A linting tool for Python that checks the code for style and quality issues, ensuring adherence to coding standards.
+- **isort**: A tool to sort imports in Python files, organizing them according to the PEP8 style guide.
+- **Sphinx**: A documentation generator for Python projects, converting reStructuredText files into various output formats such as HTML and PDF.
+- **sphinx_rtd_theme**: A Sphinx theme used to create documentation that looks like Read the Docs.
+- **sphinx-rtd-dark-mode**: An extension for Sphinx to add dark mode support to the Read the Docs theme.
+- **sphinx-automodapi**: A Sphinx extension that helps generate documentation from docstrings in the code.
+- **sphinx-copybutton**: A Sphinx extension that adds a copy button to code blocks in the documentation.
+- **sphinx-design**: A Sphinx extension that adds design elements such as cards, grids, and buttons to the documentation.
\ No newline at end of file
diff --git a/docs/scripts.md b/docs/scripts.md
deleted file mode 100644
index 876559b..0000000
--- a/docs/scripts.md
+++ /dev/null
@@ -1,49 +0,0 @@
-# Scripts
-
-This document is a general overview of the inputs, workflow, and outputs for the example scripts in `scripts/`.
-It covers the overlapping code across the four examples and notes where they differ. See [Command-Line Arguments](command-line_arguments.md) for usage.
-
-## Input
-
-The input is a directory of csv files containing data in either DLC or LP output form.
-LP data is converted to DLC via `convert_lp_dlc()` in [utils.py](../eks/utils.py). Input data must
-have three headers, scorer, bodyparts, and coords. The name of the scorer will be replaced in the
-output as specified by `tracker_name`, a parameter in the main EKS smoother function called by the script.
-The body part names must be identical to the bodypart command-line arguments.
-The coords must take the form x, y, and likelihood. The following is an example taken from
-[IBL-paw](../data/ibl-paw/3f859b5c-e73a-4044-b49e-34bb81e96715.left.rng=0.csv), showing the necessary column
-headers:
-
-| scorer    | heatmap_mhcrnn_tracker | heatmap_mhcrnn_tracker | heatmap_mhcrnn_tracker | heatmap_mhcrnn_tracker | heatmap_mhcrnn_tracker | heatmap_mhcrnn_tracker |
-|-----------|------------------------|------------------------|------------------------|------------------------|------------------------|------------------------|
-| bodyparts | paw_l                  | paw_l                  | paw_l                  | paw_r                  | paw_r                  | paw_r                  |
-| coords    | x                      | y                      | likelihood             | x                      | y                      | likelihood             |
-
-
-## Workflow
-
-1. **Argument Parsing**:
-   - Parse [command-line arguments](command-line_arguments.md) to specify input directories and parameters.
-
-2. **EKS Execution**:
-   - Check if the provided CSV directory exists.
-   - Load CSV files containing marker predictions and convert them into the correct format.
-   - Apply Ensemble Kalman Smoothing (EKS) to each keypoint, iterating through each camera view in the case of multiple cameras
-   - Each example script calls one of the specialized [EKS smoothers](eks_smoothers.md) specific to that use-case.
-   - Save the EKS results to a CSV file.
-
-3. **Plotting Results**:
-   - Select an example keypoint from the provided list.
-   - Plot individual model predictions and EKS-smoothed predictions for `x`, `y`, `likelihood`, and `zscore`.
-   - Save the plot as a PDF file.
-
-## Output
-
-The script generates two main outputs:
-- A CSV file containing the EKS-smoothed results.
-- A PDF file containing visualizations of the EKS results for an example keypoint.
-
-## TODO
-Standardize where the tracker_name is stored (currently in pupil_smoother as a parameter,
-hard-coded in multiview_pca_smoother, and hard-coded in utils.py function make_dlc_pandas_index())
-Improve visualizations (video?)
\ No newline at end of file
diff --git a/docs/singlecam_overview.md b/docs/singlecam_overview.md
new file mode 100644
index 0000000..3e31907
--- /dev/null
+++ b/docs/singlecam_overview.md
@@ -0,0 +1,173 @@
+# Workflow Overview
+
+This documentation provides an overview of the high-level workflow involved in processing
+single-camera datasets using `singlecam_example.py`. It is currently the most up-to-date script,
+incorporating efficient optimization routines for finding a suitable smoothing parameter for the
+data, and is therefore the most useful script for building a high-level understanding of the EKS
+workflow.
+
+Here, we will progress through the high-level workflow of `singlecam_example.py`. Further details on
+key functions `ensemble_kalman_smoother_singlecam` and `singlecam_optimize_smooth` are provided as
+well.
+
+---
+
+## singlecam_example.py
+
+### Overview
+
+The `singlecam_example.py` script demonstrates how to process and smooth single-camera datasets.
+It includes steps to handle input/output operations, format data, and apply the Ensemble Kalman
+Smoother (EKS).
+
+### Workflow Steps
+
+1. **Collect User-Provided Arguments**:
+   - Define the `smoother_type` as 'singlecam'.
+   - Parse command-line arguments using `handle_parse_args(smoother_type)`.
+   - Extract and set various input parameters such as `input_dir`, `data_type`, `save_dir`,
+     `save_filename`, `bodypart_list`, `s`, `s_frames`, and `blocks`.
+
+2. **Load and Format Input Data**:
+   - Use `format_data` to read and format input files, and prepare an empty DataFrame for output.
+   - If `bodypart_list` is not provided, use keypoint names from the input data.
+   - Print the keypoints being processed.
+
+3. **Convert Input Data to 3D Array**:
+   - Convert the list of DataFrames to a 3D NumPy array using `np.stack`.
+   - Map keypoint names to their respective indices in the DataFrames.
+   - Crop the 3D array to include only the columns corresponding to the specified body parts
+     (`_x`, `_y`, `_likelihood`).
+
+4. **Apply Ensemble Kalman Smoother**:
+   - Call `ensemble_kalman_smoother_singlecam` from `singlecam_smoother.py` with the prepared
+     3D array and other arguments to obtain smoothed results (`df_dicts`, `s_finals`).
+
+5. **Save Smoothed Results**:
+   - For each body part, convert the resulting DataFrames to CSV files.
+   - Use `populate_output_dataframe` to integrate the results into the output DataFrame.
+   - Save the output DataFrame as a CSV file in the specified directory.
+
+6. **Plot Results**:
+   - Use `plot_results` to visualize the smoothed data.
+   - Plot the results for a specific keypoint (`keypoint_i`).
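+
+Putting these steps together, the core of the script looks roughly like the sketch below (a
+simplified illustration, not the script's exact code; the import path and argument handling are
+assumed from the function description that follows, and `markers_list` stands in for the
+DataFrames returned by `format_data`):
+
+```python
+import numpy as np
+from eks.singlecam_smoother import ensemble_kalman_smoother_singlecam
+
+# Stack one DataFrame of predictions per ensemble member into a 3D array of
+# shape (n_models, n_frames, n_keypoints * 3); each keypoint contributes
+# (_x, _y, _likelihood) columns.
+markers_3d_array = np.stack([df.to_numpy() for df in markers_list])
+
+df_dicts, s_finals = ensemble_kalman_smoother_singlecam(
+    markers_3d_array,
+    bodypart_list=['paw1', 'paw2'],  # hypothetical keypoint names
+    smooth_param=None,               # None triggers auto-tuning over s_frames
+    s_frames=[(None, 10000)],        # frames used for smoothing-param optimization
+    blocks=[],                       # optional keypoint blocks for correlated noise
+)
+```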
+
+---
+
+
+### Key Function: `ensemble_kalman_smoother_singlecam`
+
+(from `eks/singlecam_smoother.py`)
+
+This function performs Ensemble Kalman Smoothing on 3D marker data from a single camera. It takes as input a 3D array of marker data, a list of body parts, smoothing parameters, and frames, and returns dataframes with smoothed predictions, final smoothing parameters, and Negative Log-Likelihood (NLL) values.
+
+#### Parameters:
+
+- **`markers_3d_array` (np.ndarray)**: A 3D array of marker data with shape (n_models, n_frames, n_keypoints * 3); each keypoint contributes x, y, and likelihood columns.
+- **`bodypart_list` (list)**: A list of body parts for which the data is being processed.
+- **`smooth_param` (float)**: A parameter controlling the smoothing process.
+- **`s_frames` (list)**: A list of frames used in the smoothing process.
+- **`blocks` (list)**: Optional. A list of blocks for segmenting the data (default is an empty list).
+- **`ensembling_mode` (str)**: The mode used for ensembling the data (default is 'median').
+- **`zscore_threshold` (float)**: The Z-score threshold for outlier detection (default is 2).
+
+#### Returns:
+
+- **`tuple`**: A tuple containing:
+  - Dataframes with smoothed predictions.
+  - Final smoothing parameters (per keypoint).
+  - NLL values (used for finding the ideal smoothing parameter value).
+
+### Detailed Steps:
+
+1. **Initialization**:
+   - Extract the total number of frames (`T`) and the number of keypoints (`n_keypoints`) from the `markers_3d_array`.
+   - Define the number of coordinates (`n_coords`) as 2 (x and y).
+
+2. **Ensemble Statistics**:
+   - Compute ensemble statistics by calling `jax_ensemble` with the marker data and ensembling mode.
+   - Extract ensemble predictions (`ensemble_preds`), variances (`ensemble_vars`), and average keypoints (`keypoints_avg_dict`).
+
+3. **Adjust Observations**:
+   - Calculate mean and adjusted observations by calling `adjust_observations` with the average keypoints, number of keypoints, and ensemble predictions.
+   - Obtain `mean_obs_dict`, `adjusted_obs_dict`, and `scaled_ensemble_preds`.
+
+4. **Initialize Kalman Filter**:
+   - Initialize Kalman filter values by calling `initialize_kalman_filter` with the scaled ensemble predictions, adjusted observations, and number of keypoints.
+   - Obtain initial means (`m0s`), covariances (`S0s`), state transition matrices (`As`), covariance matrices (`cov_mats`), observation matrices (`Cs`), observation covariances (`Rs`), and observations (`ys`).
+
+5. **Smoothing**:
+   - Perform the main smoothing function by calling `singlecam_optimize_smooth` with the initialized values, ensemble variances, frames, smoothing parameter, and blocks.
+   - Obtain final smoothing parameters (`s_finals`), means (`ms`), and covariances (`Vs`).
+
+6. **Process Each Keypoint**:
+   - Initialize arrays for smoothed means (`y_m_smooths`), variances (`y_v_smooths`), and predicted arrays (`eks_preds_array`).
+   - Loop through each keypoint to compute smoothed predictions and variances, adjust predictions based on mean observations, and compute Z-scores using `eks_zscore`.
+
+7. **Final Cleanup**:
+   - Create a pandas DataFrame for each keypoint with smoothed predictions, variances, and Z-scores.
+   - Append each DataFrame to a list (`dfs`) and a dictionary (`df_dicts`).
+
+8. **Return Results**:
+   - Return a tuple containing the dictionary of DataFrames and the final smoothing parameters.
+
+---
+
+
+### Key Function: `singlecam_optimize_smooth`
+
+This function optimizes the smoothing parameter and uses the result to run the Kalman filter-smoother. It takes in various parameters related to covariance matrices, observations, and initial states, and returns the final smoothing parameters, smoothed means, and smoothed covariances.
+
+#### Parameters:
+
+- **`cov_mats` (np.ndarray)**: Covariance matrices.
+- **`ys` (np.ndarray)**: Observations with shape (keypoints, frames, coordinates), where the coordinate dimension is usually 2 (x and y).
+- **`m0s` (np.ndarray)**: Initial mean state.
+- **`S0s` (np.ndarray)**: Initial state covariance.
+- **`Cs` (np.ndarray)**: Measurement function.
+- **`As` (np.ndarray)**: State-transition matrix.
+- **`Rs` (np.ndarray)**: Measurement noise covariance.
+- **`ensemble_vars` (np.ndarray)**: Ensemble variances.
+- **`s_frames` (list)**: List of frames.
+- **`smooth_param` (float)**: Smoothing parameter.
+- **`blocks` (list)**: List of blocks for segmenting the data.
+- **`maxiter` (int)**: Maximum number of iterations for optimization (default is 1000).
+
+#### Returns:
+
+- **`tuple`**: A tuple containing:
+  - Final smoothing parameters.
+  - Smoothed means.
+  - Smoothed covariances.
+  - Total negative log-likelihoods.
+  - Per-timestep negative log-likelihood values.
+
+### Detailed Steps:
+
+1. **Initialization**:
+   - Extract the number of keypoints (`n_keypoints`) from the `ys` array.
+   - Initialize an empty list for final smoothing parameters (`s_finals`).
+   - If no blocks are provided, create a block for each keypoint.
+
+2. **Device Check**:
+   - Check if a GPU is available for parallel processing. If available, use the GPU for parallel smoothing parameter optimization. Otherwise, use the CPU for sequential optimization.
+
+3. **Define Loss Functions**:
+   - Define `nll_loss_parallel_scan` for GPU usage and `nll_loss_sequential_scan` for CPU usage. Both functions ensure positivity by taking the exponential of the smoothing parameter and call the appropriate smoothing function (`singlecam_smooth_min_parallel` or `singlecam_smooth_min`).
+
+4. **Smooth Parameter Optimization**:
+   - If a `smooth_param` is provided, use it directly. Otherwise, initialize guesses for each keypoint using `compute_initial_guesses` and crop the frames using `crop_frames`.
+   - Optimize the negative log-likelihood for each block of keypoints (see the sketch at the end of this document):
+     - Initialize the smoothing parameter (`s_init`) with a positive guess.
+     - Set up the optimizer using `optax.adam`.
+     - Select the relevant subsets of the input arrays for the current block.
+     - Define a `step` function to perform optimization steps.
+     - Iterate the optimization process until convergence or the maximum number of iterations is reached.
+
+5. **Final Smooth**:
+   - After optimization, perform a final forward-backward pass with the optimized smoothing parameters by calling `final_forwards_backwards_pass`.
+
+6. **Return Results**:
+   - Return a tuple containing the final smoothing parameters, smoothed means, and smoothed covariances.
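+
+For illustration, the per-block optimization described in step 4 follows the general pattern
+sketched below (a schematic, not the repo's exact code; the learning rate, convergence check,
+and `nll_loss` argument convention are assumptions):
+
+```python
+import jax
+import jax.numpy as jnp
+import optax
+
+def optimize_log_s(nll_loss, s_init, maxiter=1000, tol=1e-3):
+    """Minimize the negative log-likelihood w.r.t. the log of the smoothing parameter."""
+    params = jnp.log(s_init)  # optimize log(s) so that s = exp(params) stays positive
+    optimizer = optax.adam(learning_rate=0.25)
+    opt_state = optimizer.init(params)
+
+    @jax.jit
+    def step(params, opt_state):
+        loss, grads = jax.value_and_grad(nll_loss)(params)
+        updates, opt_state = optimizer.update(grads, opt_state)
+        return optax.apply_updates(params, updates), opt_state, loss
+
+    prev_loss = jnp.inf
+    for _ in range(maxiter):
+        params, opt_state, loss = step(params, opt_state)
+        if jnp.abs(prev_loss - loss) < tol:  # stop once the NLL plateaus
+            break
+        prev_loss = loss
+    return jnp.exp(params)  # map back to a positive smoothing parameter
+```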
diff --git a/scripts/general_scripting.py b/eks/command_line_args.py
similarity index 52%
rename from scripts/general_scripting.py
rename to eks/command_line_args.py
index b2de75e..9a46a04 100644
--- a/scripts/general_scripting.py
+++ b/eks/command_line_args.py
@@ -1,5 +1,6 @@
-import os
 import argparse
+import os
+import re
 
 # -------------------------------------------------------------
 """ Collection of General Functions for EKS Scripting
@@ -21,6 +22,7 @@ def handle_io(input_dir, save_dir):
     return save_dir
 
 
+# Handles extraction of arguments from command-line flags
 def handle_parse_args(script_type):
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -44,9 +46,27 @@ def handle_parse_args(script_type):
     parser.add_argument(
         '--data-type',
-        help='format of input data (Lightning Pose = lp, SLEAP = slp), dlc by default.',
-        default='dlc',
+        help='format of input data (DeepLabCut = dlc, Lightning Pose = lp, SLEAP = slp), '
+             'lp by default.',
+        default='lp',
         type=str,
     )
+    parser.add_argument(
+        '--s-frames',
+        help='frames to be considered for smoothing '
+             'parameter optimization, first 10,000 frames by default. Ignored if --s is '
+             'specified. Format: "[(start_int, end_int), (start_int, end_int), ... ]" or int. '
+             'Inputting a single int uses all frames from 1 to the int. '
+             '(None, end_int) starts from first frame; (start_int, None) proceeds to last frame.',
+        default=[(None, 10000)],
+        type=parse_s_frames,
+    )
+    parser.add_argument(
+        '--blocks',
+        help='keypoints to be blocked for correlated noise. Generates one smoothing param per '
+             'block, as opposed to per keypoint. Specified by the form "x1, x2, x3; y1, y2"'
+             ' referring to keypoint indices (starting at 0)',
+        default=[],
+        type=parse_blocks,
+    )
     if script_type == 'singlecam':
         add_bodyparts(parser)
         add_s(parser)
@@ -67,13 +87,59 @@ def handle_parse_args(script_type):
     return args
 
 
+# Helper function for parsing s-frames
+def parse_s_frames(input_string):
+    try:
+        # First, check if the input is a single integer
+        if input_string.isdigit():
+            end = int(input_string)
+            return [(1, end)]  # Handle as from first to 'end'
+
+        # Remove spaces, replace with nothing
+        cleaned = re.sub(r'\s+', '', input_string)
+        # Match tuples in the form of (x,y), (x,), (,y)
+        tuple_pattern = re.compile(r'\((\d*),(\d*)\)')
+        matches = tuple_pattern.findall(cleaned)
+
+        if not matches:
+            raise ValueError("No valid tuples found.")
+
+        tuples = []
+        for start, end in matches:
+            # Convert numbers to integers or None if empty
+            start = int(start) if start else None
+            end = int(end) if end else None
+            if start is not None and end is not None and start > end:
+                raise ValueError("Start index cannot be greater than end index.")
+            tuples.append((start, end))
+
+        return tuples
+    except Exception as e:
+        raise argparse.ArgumentTypeError(f"Invalid format for --s-frames: {e}")
+
+
+# Helper function for parsing blocks
+def parse_blocks(blocks_str):
+    try:
+        # Split the input string by ';' to separate each block
+        blocks = blocks_str.split(';')
+        # Split each block by ',' to get individual integers and convert to lists of integers
+        parsed_blocks = [list(map(int, block.split(','))) for block in blocks]
+        return parsed_blocks
+    except ValueError as e:
+        raise argparse.ArgumentTypeError(f"Invalid format for --blocks: {blocks_str}. Error: {e}")
+
+
+# --------------------------------------
+# Helper Functions for handle_parse_args
+# --------------------------------------
+
+
 def add_bodyparts(parser):
     parser.add_argument(
         '--bodypart-list',
-        required=True,
         nargs='+',
-        help='the list of body parts to be ensembled and smoothed',
+        help='the list of body parts to be ensembled and smoothed. If not specified, uses all.',
     )
     return parser
 
@@ -81,7 +147,9 @@ def add_bodyparts(parser):
 def add_s(parser):
     parser.add_argument(
         '--s',
-        help='smoothing parameter ranges from .01-20 (smaller values = more smoothing)',
+        help='Specifying a smoothing parameter overrides the auto-tuning function. '
+             'Providing multiple values assigns each successive bodypart the next s value.',
+        nargs='+',
         type=float,
     )
     return parser
@@ -111,7 +179,6 @@ def add_diameter_s(parser):
     parser.add_argument(
         '--diameter-s',
         help='smoothing parameter for diameter (closer to 1 = more smoothing)',
-        default=.9999,
         type=float
    )
    return parser
@@ -121,7 +188,6 @@ def add_com_s(parser):
    parser.add_argument(
        '--com-s',
        help='smoothing parameter for center of mass (closer to 1 = more smoothing)',
-        default=.999,
        type=float
    )
    return parser
diff --git a/eks/core.py b/eks/core.py
index 1861af2..72f9622 100644
--- a/eks/core.py
+++ b/eks/core.py
@@ -1,6 +1,19 @@
+from functools import partial
 from collections import defaultdict
+
+import jax
+import jax.scipy as jsc
 import numpy as np
-from scipy.optimize import minimize
+from jax import jit
+from jax import numpy as jnp
+from jax import vmap
+from jax.lax import associative_scan
+
+# ------------------------------------------------------------------------------------------
+# Original Core Functions: These functions are still in use for the multicam and IBL scripts
+# as of this update, but will eventually be replaced with the faster versions used in
+# the singlecam script
+# ------------------------------------------------------------------------------------------
 
 
 def ensemble(markers_list, keys, mode='median'):
@@ -83,7 +96,7 @@ def ensemble(markers_list, keys, mode='median'):
     ensemble_vars = np.asarray(ensemble_vars).T
     ensemble_stacks = np.asarray(ensemble_stacks).T
     return ensemble_preds, ensemble_vars, ensemble_stacks, \
-        keypoints_avg_dict, keypoints_var_dict, keypoints_stack_dict
+        keypoints_avg_dict, keypoints_var_dict, keypoints_stack_dict
 
 
 def forward_pass(y, m0, S0, C, R, A, Q, ensemble_vars):
@@ -107,11 +120,11 @@ def forward_pass(y, m0, S0, C, R, A, Q, ensemble_vars):
         shape (samples, n_keypoints)
 
     Returns:
         mf: np.ndarray
             shape (samples, n_keypoints)
         Vf: np.ndarray
             shape (samples, n_latents, n_latents)
         S: np.ndarray
             shape (samples, n_latents, n_latents)
         innovations: np.ndarray
             shape (samples, n_keypoints)
@@ -122,9 +135,8 @@ def forward_pass(y, m0, S0, C, R, A, Q, ensemble_vars):
     mf = np.zeros(shape=(T, m0.shape[0]))
     Vf = np.zeros(shape=(T, m0.shape[0], m0.shape[0]))
     S = np.zeros(shape=(T, m0.shape[0], m0.shape[0]))
-    innovations = np.zeros((T, y.shape[1]))  # Assuming y is m x T
+    innovations = np.zeros((T, y.shape[1]))
     innovation_cov = np.zeros((T, C.shape[0], C.shape[0]))
-
     # time-varying observation variance
     for i in range(ensemble_vars.shape[1]):
         R[i, i] = ensemble_vars[0][i]
@@ -141,7 +153,7 @@ def forward_pass(y, m0, S0, C, R, A, Q, ensemble_vars):
             mf[t, :] = np.dot(A, mf[t - 1, :])
             S[t - 1] = np.dot(A, np.dot(Vf[t - 1, :], A.T)) + Q
 
         if np.sum(~np.isnan(y[t, :])) >= 2:  # Require at least two non-NaN values in y[t]
             # Update R for time-varying observation variance
             for i in range(ensemble_vars.shape[1]):
                 R[i, i] = ensemble_vars[t][i]
@@ -154,18 +166,10 @@ def forward_pass(y, m0, S0, C, R, A, Q, ensemble_vars):
                 Vf[t, :] = S[t - 1] - K_array
             else:
                 Vf[t, :] = S[t - 1]
-
     return mf, Vf, S, innovations, innovation_cov
 
 
-def kalman_dot(innovation, V, C, R):
-    innovation_cov = R + np.dot(C, np.dot(V, C.T))
-    innovation_cov_inv = np.linalg.solve(innovation_cov, innovation)
-    K_array = np.dot(V, np.dot(C.T, innovation_cov_inv))
-    return K_array, innovation_cov
-
-
-def backward_pass(y, mf, Vf, S, A, Q, C):
+def backward_pass(y, mf, Vf, S, A):
     """Implements Kalman-smoothing backwards
     Args:
         y: np.ndarray
@@ -199,10 +203,9 @@ def backward_pass(y, mf, Vf, S, A):
     # Last-time smoothed posterior is equal to last-time filtered posterior
     ms[-1, :] = mf[-1, :]
     Vs[-1, :, :] = Vf[-1, :, :]
-
     # Smoothing steps
     for i in range(T - 2, -1, -1):
         if not np.all(np.isnan(y[i])):  # Skip if all values in y[i] are NaN
             try:
                 J = np.linalg.solve(S[i], np.dot(A, Vf[i])).T
             except np.linalg.LinAlgError:
@@ -212,19 +215,412 @@ def backward_pass(y, mf, Vf, S, A):
             Vs[i] = Vf[i] + np.dot(J, np.dot(Vs[i + 1] - S[i], J.T))
             ms[i] = mf[i] + np.dot(J, ms[i + 1] - np.dot(A, mf[i]))
             CV[i] = np.dot(Vs[i + 1], J.T)
-
     return ms, Vs, CV
 
 
+def kalman_dot(innovation, V, C, R):
+    """ Kalman dot product computation """
+    innovation_cov = R + np.dot(C, np.dot(V, C.T))
+    innovation_cov_inv = np.linalg.solve(innovation_cov, innovation)
+    Ks = np.dot(V, np.dot(C.T, innovation_cov_inv))
+    return Ks, innovation_cov
+
+
+def compute_nll(innovations, innovation_covs, epsilon=1e-6):
+    """
+    Computes the negative log likelihood (NLL), a measure of fit for the
+    EKS prediction. This metric is minimized to optimize s.
+    """
+    T = innovations.shape[0]
+    n_coords = innovations.shape[1]
+    nll = 0
+    nll_values = []
+    c = np.log(2 * np.pi) * n_coords  # The Gaussian normalization constant part
+    for t in range(T):
+        if not np.any(np.isnan(innovations[t])):  # Skip timesteps containing NaNs
+            # Regularize the innovation covariance matrix by adding epsilon to the diagonal
+            reg_innovation_cov = innovation_covs[t] + epsilon * np.eye(n_coords)
+
+            # Compute the log determinant of the regularized covariance matrix
+            log_det_S = np.log(np.abs(np.linalg.det(reg_innovation_cov)) + epsilon)
+            solved_term = np.linalg.solve(reg_innovation_cov, innovations[t])
+            quadratic_term = np.dot(innovations[t], solved_term)
+
+            # Compute the NLL increment for time step t
+            nll_increment = 0.5 * np.abs((log_det_S + quadratic_term + c))
+            nll_values.append(nll_increment)
+            nll += nll_increment
+    return nll, nll_values
+
+
+# -------------------------------------------------------------------------------------
+# Fast Core Functions: These functions are fast versions used by the singlecam script
+# and will eventually replace the Original Core Functions
+# -------------------------------------------------------------------------------------
+
+# ----- Sequential Functions for CPU -----
+
+def jax_ensemble(markers_3d_array, mode='median'):
+    """
+    Computes ensemble median (or mean) and variance of a 3D array of DLC marker data using JAX.
+
+    Returns:
+        ensemble_preds: np.ndarray
+            shape (n_timepoints, n_keypoints, n_coordinates).
+            ensembled predictions for each keypoint for each target
+        ensemble_vars: np.ndarray
+            shape (n_timepoints, n_keypoints, n_coordinates).
+ ensembled variances for each keypoint for each target + """ + markers_3d_array = jnp.array(markers_3d_array) # Convert to JAX array + n_frames = markers_3d_array.shape[1] + n_keypoints = markers_3d_array.shape[2] // 3 + + # Initialize output structures + ensemble_preds = np.zeros((n_frames, n_keypoints, 2)) + ensemble_vars = np.zeros((n_frames, n_keypoints, 2)) + + # Choose the appropriate JAX function based on the mode + if mode == 'median': + avg_func = lambda x: jnp.nanmedian(x, axis=0) + elif mode == 'mean': + avg_func = lambda x: jnp.nanmean(x, axis=0) + elif mode == 'confidence_weighted_mean': + avg_func = None + else: + raise ValueError(f"{mode} averaging not supported") + + def compute_stats(i): + data_x = markers_3d_array[:, :, 3 * i] + data_y = markers_3d_array[:, :, 3 * i + 1] + data_likelihood = markers_3d_array[:, :, 3 * i + 2] + + if mode == 'confidence_weighted_mean': + conf_per_keypoint = jnp.sum(data_likelihood, axis=0) + mean_conf_per_keypoint = conf_per_keypoint / data_likelihood.shape[0] + avg_x = jnp.sum(data_x * data_likelihood, axis=0) / conf_per_keypoint + avg_y = jnp.sum(data_y * data_likelihood, axis=0) / conf_per_keypoint + var_x = jnp.nanvar(data_x, axis=0) / mean_conf_per_keypoint + var_y = jnp.nanvar(data_y, axis=0) / mean_conf_per_keypoint + else: + avg_x = avg_func(data_x) + avg_y = avg_func(data_y) + var_x = jnp.nanvar(data_x, axis=0) + var_y = jnp.nanvar(data_y, axis=0) + + return avg_x, avg_y, var_x, var_y + + compute_stats_jit = jax.jit(compute_stats) + stats = jax.vmap(compute_stats_jit)(jnp.arange(n_keypoints)) + + avg_x, avg_y, var_x, var_y = stats + + keypoints_avg_dict = {} + for i in range(n_keypoints): + ensemble_preds[:, i, 0] = avg_x[i] + ensemble_preds[:, i, 1] = avg_y[i] + ensemble_vars[:, i, 0] = var_x[i] + ensemble_vars[:, i, 1] = var_y[i] + keypoints_avg_dict[2 * i] = avg_x[i] + keypoints_avg_dict[2 * i + 1] = avg_y[i] + + # Convert outputs to JAX arrays + ensemble_preds = jnp.array(ensemble_preds) + ensemble_vars = jnp.array(ensemble_vars) + keypoints_avg_dict = {k: jnp.array(v) for k, v in keypoints_avg_dict.items()} + + return ensemble_preds, ensemble_vars, keypoints_avg_dict + + +def kalman_filter_step(carry, curr_y): + m_prev, V_prev, A, Q, C, R, nll_net = carry + + # Predict + m_pred = jnp.dot(A, m_prev) + V_pred = jnp.dot(A, jnp.dot(V_prev, A.T)) + Q + + # Update + innovation = curr_y - jnp.dot(C, m_pred) + innovation_cov = jnp.dot(C, jnp.dot(V_pred, C.T)) + R + K = jnp.dot(V_pred, jnp.dot(C.T, jnp.linalg.inv(innovation_cov))) + m_t = m_pred + jnp.dot(K, innovation) + V_t = V_pred - jnp.dot(K, jnp.dot(C, V_pred)) + + nll_current = single_timestep_nll(innovation, innovation_cov) + nll_net = nll_net + nll_current + + return (m_t, V_t, A, Q, C, R, nll_net), (m_t, V_t, nll_current) + + +# Always run the sequential filter on CPU. +# GPU will deploy individual kernels for each scan iteration, very slow. +@partial(jit, backend='cpu') +def jax_forward_pass(y, m0, cov0, A, Q, C, R): + """ + Kalman Filter for a single keypoint + (can be vectorized using vmap for handling multiple keypoints in parallel) + Parameters: + y: Shape (num_timepoints, observation_dimension). + m0: Shape (state_dim,). Initial state of system. + cov0: Shape (state_dim, state_dim). Initial covariance of state variable. + A: Shape (state_dim, state_dim). Process transition matrix. + Q: Shape (state_dim, state_dim). Process noise covariance matrix. + C: Shape (observation_dim, state_dim). Observation coefficient matrix. + R: Shape (observation_dim, observation_dim). 
Observation noise covar matrix.
+
+    Returns:
+        mfs: Shape (timepoints, state_dim). Mean filter state at each timepoint.
+        Vfs: Shape (timepoints, state_dim, state_dim). Covar for each filtered estimate.
+        nll_net: Shape (1,). Negative log likelihood of the observations, -log p(y_1, ..., y_T)
+    """
+    # Initialize carry
+    carry = (m0, cov0, A, Q, C, R, 0)
+    carry, outputs = jax.lax.scan(kalman_filter_step, carry, y)
+    mfs, Vfs, _ = outputs
+    nll_net = carry[-1]
+    return mfs, Vfs, nll_net
+
+
+def kalman_smoother_step(carry, X):
+    m_ahead_smooth, v_ahead_smooth, A, Q = carry
+    m_curr_filter, v_curr_filter = X[0], X[1]
+
+    # Compute the one-step-ahead predicted covariance and the smoother gain
+    ahead_cov = jnp.dot(A, jnp.dot(v_curr_filter, A.T)) + Q
+    smoothing_gain = jsc.linalg.solve(ahead_cov, jnp.dot(A, v_curr_filter.T)).T
+
+    # Standard RTS updates: correct the filtered moments using the smoothed moments ahead
+    smoothed_state = m_curr_filter + jnp.dot(
+        smoothing_gain, m_ahead_smooth - jnp.dot(A, m_curr_filter))
+    smoothed_cov = v_curr_filter + jnp.dot(
+        jnp.dot(smoothing_gain, v_ahead_smooth - ahead_cov), smoothing_gain.T)
+
+    return (smoothed_state, smoothed_cov, A, Q), (smoothed_state, smoothed_cov)
+
+
+@partial(jit, backend='cpu')
+def jax_backward_pass(mfs, Vfs, A, Q):
+    """
+    Runs the Kalman smoother given the filtered values
+    Parameters:
+        mfs: Shape (timepoints, state_dim). The kalman-filtered means of the data.
+        Vfs: Shape (timepoints, state_dim, state_dimension).
+            The kalman-filtered covariance matrix of the state vector at each time point.
+        A: Shape (state_dim, state_dim). The process transition matrix
+        Q: Shape (state_dim, state_dim). The covariance of the process noise.
+    Returns:
+        smoothed_states: Shape (timepoints, state_dim).
+            The smoothed estimates for the state vector starting at the first timepoint
+            where observations are possible.
+        smoothed_state_covariances: Shape (timepoints, state_dim, state_dim).
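+            The smoothed covariance of the state vector at each timepoint.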
+ """ + carry = (mfs[-1], Vfs[-1], A, Q) + + # Reverse scan over the time steps + carry, outputs = jax.lax.scan( + kalman_smoother_step, + carry, + [mfs[:-1], Vfs[:-1]], + reverse=True + ) + + smoothed_states, smoothed_state_covariances = outputs + smoothed_states = jnp.append(smoothed_states, jnp.expand_dims(mfs[-1], 0), 0) + smoothed_state_covariances = jnp.append(smoothed_state_covariances, + jnp.expand_dims(Vfs[-1], 0), 0) + return smoothed_states, smoothed_state_covariances + + +def single_timestep_nll(innovation, innovation_cov): + epsilon = 1e-6 + n_coords = innovation.shape[0] + + # Regularize the innovation covariance matrix by adding epsilon to the diagonal + reg_innovation_cov = innovation_cov + epsilon * jnp.eye(n_coords) + + # Compute the log determinant of the regularized covariance matrix + log_det_S = jnp.log(jnp.abs(jnp.linalg.det(reg_innovation_cov)) + epsilon) + solved_term = jnp.linalg.solve(reg_innovation_cov, innovation) + quadratic_term = jnp.dot(innovation, solved_term) + + # Compute the NLL increment for the current time step + c = jnp.log(2 * jnp.pi) * n_coords # The Gaussian normalization constant part + nll_increment = 0.5 * jnp.abs(log_det_S + quadratic_term + c) + return nll_increment + + +# ----- Parallel Functions for GPU ----- + +def first_filtering_element(C, A, Q, R, m0, P0, y): + # model.F = A, model.H = C, + S = C @ Q @ C.T + R + CF, low = jsc.linalg.cho_factor(S) # note the jsc + + m1 = A @ m0 + P1 = A @ P0 @ A.T + Q + S1 = C @ P1 @ C.T + R + K1 = jsc.linalg.solve(S1, C @ P1, assume_a='pos').T # note the jsc + + A_updated = jnp.zeros_like(A) + b = m1 + K1 @ (y - C @ m1) + C_updated = P1 - K1 @ S1 @ K1.T + + # note the jsc + eta = A.T @ C.T @ jsc.linalg.cho_solve((CF, low), y) + J = A.T @ C.T @ jsc.linalg.cho_solve((CF, low), C @ A) + return A_updated, b, C_updated, J, eta + + +def generic_filtering_element(C, A, Q, R, y): + S = C @ Q @ C.T + R + CF, low = jsc.linalg.cho_factor(S) # note the jsc + K = jsc.linalg.cho_solve((CF, low), C @ Q).T # note the jsc + A_updated = A - K @ C @ A + b = K @ y + C_updated = Q - K @ C @ Q + + # note the jsc + eta = A.T @ C.T @ jsc.linalg.cho_solve((CF, low), y) + J = A.T @ C.T @ jsc.linalg.cho_solve((CF, low), C @ A) + return A_updated, b, C_updated, J, eta + + +def make_associative_filtering_elements(C, A, Q, R, m0, P0, observations): + first_elems = first_filtering_element(C, A, Q, R, m0, P0, observations[0]) + generic_elems = vmap(lambda o: generic_filtering_element(C, A, Q, R, o))(observations[1:]) + return tuple(jnp.concatenate([jnp.expand_dims(first_e, 0), gen_es]) + for first_e, gen_es in zip(first_elems, generic_elems)) + + +@partial(vmap) +def filtering_operator(elem1, elem2): + # # note the jsc everywhere + A1, b1, C1, J1, eta1 = elem1 + A2, b2, C2, J2, eta2 = elem2 + dim = A1.shape[0] + I_var = jnp.eye(dim) # note the jnp + + I_C1J2 = I_var + C1 @ J2 + temp = jsc.linalg.solve(I_C1J2.T, A2.T).T + A = temp @ A1 + b = temp @ (b1 + C1 @ eta2) + b2 + C = temp @ C1 @ A2.T + C2 + + I_J2C1 = I_var + J2 @ C1 + temp = jsc.linalg.solve(I_J2C1.T, A1).T + + eta = temp @ (eta2 - J2 @ b1) + eta1 + J = temp @ J2 @ A1 + J1 + + return A, b, C, J, eta + + +def pkf(y, m0, cov0, A, Q, C, R): + initial_elements = make_associative_filtering_elements(C, A, Q, R, m0, cov0, y) + final_elements = associative_scan(filtering_operator, initial_elements) + return final_elements + + +pkf_func = jit(pkf) + + +def get_kalman_means(A_scan, b_scan, m0): + """ + Computes the Kalman mean at a single timepoint, the result is: + A_scan @ m0 + 
b_scan + + Returned shape: (state_dimension, 1) + """ + return A_scan @ jnp.expand_dims(m0, axis=1) + jnp.expand_dims(b_scan, axis=1) + + +def get_kalman_variances(C): + return C + + +def get_next_cov(A, C, Q, R, filter_cov, filter_mean): + """ + Given the moments of p(x_t | y_1, ..., y_t) (normal filter distribution), + compute the moments of the distribution for: + p(y_{t+1} | y_1, ..., y_t) + + Params: + A (np.ndarray): Shape (state_dimension, state_dimension) Process coeff matrix + C (np.ndarray): Shape (obs_dimension, state_dimension) Observation coeff matrix + Q (np.ndarray): Shape (state_dimension, state_dimension). Process noise covariance matrix. + R (np.ndarray): Shape (obs_dimension, obs_dimension). Observation noise covariance matrix. + filter_cov (np.ndarray). Shape (state_dimension, state_dimension). Filtered covariance + filter_mean (np.ndarray). Shape (state_dimension, 1). Filter mean + + Returns: + mean (np.ndarray). Shape (obs_dimension, 1) + cov (np.ndarray). Shape (obs_dimension, obs_dimension). + """ + mean = C @ A @ filter_mean + cov = C @ (A @ filter_cov @ A.T + Q) @ C.T + R + return mean, cov + + +def compute_marginal_nll(value, mean, covariance): + return -1 * jax.scipy.stats.multivariate_normal.logpdf(value, mean, covariance) + + +def parallel_loss_single(A_scan, b_scan, C_scan, A, C, Q, R, next_observation, m0): + curr_mean = get_kalman_means(A_scan, b_scan, m0) + curr_cov = get_kalman_variances(C_scan) # Placeholder; just returns identity + + next_mean, next_cov = get_next_cov(A, C, Q, R, curr_cov, curr_mean) + return jnp.squeeze(curr_mean), curr_cov, compute_marginal_nll(jnp.squeeze(next_observation), + jnp.squeeze(next_mean), next_cov) + + +parallel_loss_func_vmap = jit( + vmap(parallel_loss_single, in_axes=(0, 0, 0, None, None, None, None, 0, None), + out_axes=(0, 0, 0))) + + +@partial(jit) +def y1_given_x0_nll(C, A, Q, R, m0, cov0, obs): + y1_predictive_mean = C @ A @ jnp.expand_dims(m0, axis=1) + y1_predictive_cov = C @ (A @ cov0 @ A.T + Q) @ C.T + R + addend = -1 * jax.scipy.stats.multivariate_normal.logpdf(obs, jnp.squeeze(y1_predictive_mean), + y1_predictive_cov) + return addend + + +def pkf_and_loss(y, m0, cov0, A, Q, C, R): + A_scan, b_scan, C_scan, _, _ = pkf_func(y, m0, cov0, A, Q, C, R) + + # Gives us the NLL for p(y_i | y_1, ..., y_{i-1}) for i > 1. + # Need to use the parallel scan outputs for this. i = 1 handled below + filtered_states, filtered_covariances, losses = parallel_loss_func_vmap(A_scan[:-1], + b_scan[:-1], + C_scan[:-1], A, C, Q, + R, y[1:], m0) + + # Gives us the NLL for p_y(y_1 | x_0) + addend = y1_given_x0_nll(C, A, Q, R, m0, cov0, y[0]) + + final_mean = get_kalman_means(A_scan[-1], b_scan[-1], m0).T + final_covariance = jnp.expand_dims(get_kalman_variances(C_scan[-1]), axis=0) + filtered_states = jnp.concatenate([filtered_states, final_mean], axis=0) + filtered_variances = jnp.concatenate([filtered_covariances, final_covariance], axis=0) + return filtered_states, filtered_variances, jnp.sum(losses) + addend + + +# ------------------------------------------------------------------------------------- +# Misc: These miscellaneous functions generally have specific computations used by the +# core functions or the smoothers +# ------------------------------------------------------------------------------------- + + def eks_zscore(eks_predictions, ensemble_means, ensemble_vars, min_ensemble_std=2): """Computes zscore between eks prediction and the ensemble for a single keypoint. 
Args:
         eks_predictions: list
-            EKS prediction for each coordinate (x and y) for as single keypoint - (samples, 2)
+            EKS prediction for each coordinate (x and y) for a single keypoint - (samples, 2)
         ensemble_means: list
-            Ensemble mean for each coordinate (x and y) for as single keypoint - (samples, 2)
+            Ensemble mean for each coordinate (x and y) for a single keypoint - (samples, 2)
         ensemble_vars: list
-            Ensemble var for each coordinate (x and y) for as single keypoint - (samples, 2)
+            Ensemble var for each coordinate (x and y) for a single keypoint - (samples, 2)
         min_ensemble_std:
             Minimum std threshold to reduce the effect of low ensemble std (default 2).
     Returns:
@@ -244,50 +640,51 @@ def eks_zscore(eks_predictions, ensemble_means, ensemble_vars, min_ensemble_std=
     return z_score
 
 
-'''
-# Kept for reference -- Returns high-precision but low speed result
+def compute_covariance_matrix(ensemble_preds):
+    """
+    Compute the covariance matrix E for correlated noise dynamics.
 
-def optimize_smoothing_param(cov_matrix, y, m0, s0, C, A, R, ensemble_vars):
-    guess = compute_initial_guess(y, ensemble_vars)
-    result = minimize(
-        return_nll_only,
-        x0=guess,  # initial smooth param guess
-        args=(cov_matrix, y, m0, s0, C, A, R, ensemble_vars),
-        method='Nelder-Mead'
-    )
-    print(f'Optimal at s={result.x[0]}')
-    return result.x[0]
-'''
-
-
-def optimize_smoothing_params(cov_matrix, y, m0, s0, C, A, R, ensemble_vars, max_frames=2000):
-    guess = compute_initial_guesses(ensemble_vars)
-    # Update xatol during optimization
-    def callback(xk):
-        # Update xatol based on the current solution xk
-        xatol = np.log(np.abs(xk)) * 0.01
-
-        # Update the options dictionary with the new xatol value
-        options['xatol'] = xatol
-
-    # Initialize options with initial xatol
-    options = {'xatol': np.log(guess)}
-
-    result = minimize(
-        return_nll_only,
-        x0=guess,  # initial smooth param guess
-        args=(cov_matrix, y[:max_frames], m0, s0, C, A, R, ensemble_vars),
-        method='Nelder-Mead',
-        options=options,
-        callback=callback  # Pass the callback function
-    )
-    print(f'Optimal at s={result.x[0]}')
-    return result.x[0]
+    Parameters:
+        ensemble_preds: A 3D array of shape (T, n_keypoints, n_coords)
+            containing the ensemble predictions.
+    Returns:
+        cov_mats: A (K, 2, 2) array of 2x2 covariance blocks, one per keypoint, extracted
+            from the full 2K x 2K temporal-difference covariance matrix E
+            (K = number of keypoints).
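+        Each 2x2 block is the covariance of one keypoint's (x, y) temporal differences; the
+        smoothing parameter s later scales these blocks to form the process noise covariance Q.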
+ """ + # Get the number of time steps, keypoints, and coordinates + T, n_keypoints, n_coords = ensemble_preds.shape + + # Flatten the ensemble predictions to shape (T, 2K) where K is the number of keypoints + flattened_preds = ensemble_preds.reshape(T, -1) + + # Compute the temporal differences + temporal_diffs = np.diff(flattened_preds, axis=0) + + # Compute the covariance matrix of the temporal differences + E = np.cov(temporal_diffs, rowvar=False) + + # Index covariance matrix into blocks for each keypoint + cov_mats = [] + for i in range(n_keypoints): + E_block = extract_submatrix(E, i) + cov_mats.append(E_block) + cov_mats = jnp.array(cov_mats) + return cov_mats -# Function to compute ensemble mean, temporal differences, and standard deviation -def compute_initial_guesses(ensemble_vars): +def extract_submatrix(Qs, i, submatrix_size=2): + # Compute the start indices for the submatrix + i_q = 2 * i + start_indices = (i_q, i_q) + + # Use jax.lax.dynamic_slice to extract the submatrix + submatrix = jax.lax.dynamic_slice(Qs, start_indices, (submatrix_size, submatrix_size)) + + return submatrix + + +def compute_initial_guesses(ensemble_vars): + """Computes an initial guess for optimized s, which is the stdev of temporal differences.""" # Consider only the first 2000 entries in ensemble_vars ensemble_vars = ensemble_vars[:2000] @@ -306,81 +703,5 @@ def compute_initial_guesses(ensemble_vars): temporal_diffs_list.append(temporal_diff) # Compute standard deviation of temporal differences - std_dev_guess = np.std(temporal_diffs_list) - print(f'Initial guess: {std_dev_guess}') + std_dev_guess = round(np.std(temporal_diffs_list), 5) return std_dev_guess - - -# Combines filtering_pass, smoothing, and computing nll -def filter_smooth_nll(cov_matrix, smooth_param, y, m0, S0, C, A, R, ensemble_vars): - - # Adjust Q based on smooth_param and cov_matrix - Q = smooth_param * cov_matrix - # Run filtering and smoothing with the current smooth_param - mf, Vf, S, innovs, innov_cov = forward_pass(y, m0, S0, C, R, A, Q, ensemble_vars) - ms, Vs, CV = backward_pass(y, mf, Vf, S, A, Q, C) - # Compute the negative log-likelihood based on innovations and their covariance - nll, nll_values = compute_nll(innovs, innov_cov) - return ms, Vs, nll, nll_values - - -# filter_smooth_nll version for iterative calls from optimize_smoothing_param -def return_nll_only(cov_matrix, smooth_param, y, m0, S0, C, A, R, ensemble_vars): - # Adjust Q based on smooth_param and cov_matrix - Q = smooth_param * cov_matrix - smooth_param = smooth_param[0] - # Run filtering and smoothing with the current smooth_param - mf, Vf, S, innovs, innov_cov = forward_pass(y, m0, S0, C, R, A, Q, ensemble_vars) - # Compute the negative log-likelihood based on innovations and their covariance - nll, nll_values = compute_nll(innovs, innov_cov) - return nll - - -def compute_nll(innovations, innovation_covs, epsilon=1e-6): - T = innovations.shape[0] - n_keypoints = innovations.shape[1] - nll = 0 - nll_values = [] - k = np.log(2 * np.pi) * n_keypoints # The Gaussian normalization constant part - for t in range(T): - if not np.any(np.isnan(innovations[t])): # Check if any value in innovations[t] is not NaN - # Regularize the innovation covariance matrix by adding epsilon to the diagonal - reg_innovation_cov = innovation_covs[t] + epsilon * np.eye(n_keypoints) - - # Compute the log determinant of the regularized covariance matrix - log_det_S = np.log(np.abs(np.linalg.det(reg_innovation_cov)) + epsilon) - solved_term = np.linalg.solve(reg_innovation_cov, 
innovations[t])
-            quadratic_term = np.dot(innovations[t], solved_term)
-
-            # Compute the NLL increment for time step t
-            nll_increment = 0.5 * np.abs((log_det_S + quadratic_term + k))
-            nll_values.append(nll_increment)
-            nll += nll_increment
-    return nll, nll_values
-
-
-# Alternative implementation of NLL
-
-def compute_nll_2(y, mf, S, C, epsilon=1e-6, lower_bound=0, upper_bound=0):
-    T, n_keypoints = y.shape
-    nll = 0
-    nll_values = []
-    k = np.log(2 * np.pi) * n_keypoints
-
-    for t in range(T):
-        # Compute the innovation for time t
-        innovation = y[t, :] - np.dot(C, mf[t, :])
-
-        # Compute the log determinant and the quadratic term
-        A = np.dot(C, S[t])
-
-        # Add epsilon to the diagonal elements of A and S[t]
-        A += np.eye(A.shape[0]) + epsilon
-        S[t] += np.eye(S[t].shape[0]) + epsilon
-
-        log_det_S = np.log(np.linalg.det(A))
-        quadratic_term = np.dot(innovation.T, np.linalg.solve(S[t], innovation))
-        nll_increment = 0.5 * (log_det_S + quadratic_term + k)
-        nll_values.append(nll_increment)
-        nll += nll_increment
-    return nll, nll_values
diff --git a/eks/multiview_pca_smoother.py b/eks/ibl_paw_multiview_smoother.py
similarity index 65%
rename from eks/multiview_pca_smoother.py
rename to eks/ibl_paw_multiview_smoother.py
index aab3c18..8831d9c 100644
--- a/eks/multiview_pca_smoother.py
+++ b/eks/ibl_paw_multiview_smoother.py
@@ -2,14 +2,11 @@
import pandas as pd
from scipy.interpolate import interp1d
from sklearn.decomposition import PCA
+
+from eks.core import backward_pass, eks_zscore, ensemble, forward_pass
from eks.utils import make_dlc_pandas_index
-from eks.core import ensemble, forward_pass, \
-    backward_pass, eks_zscore, optimize_smoothing_params, filter_smooth_nll


-# -----------------------
-# funcs for kalman paw
-# -----------------------
def remove_camera_means(ensemble_stacks, camera_means):
    scaled_ensemble_stacks = ensemble_stacks.copy()
    for k in range(len(ensemble_stacks)):
@@ -33,7 +30,7 @@ def pca(S, n_comps):
    return pca_.fit(S), pca_.explained_variance_ratio_


-def ensemble_kalman_smoother_paw_asynchronous(
+def ensemble_kalman_smoother_ibl_paw(
        markers_list_left_cam, markers_list_right_cam,
        timestamps_left_cam, timestamps_right_cam,
        keypoint_names, smooth_param, quantile_keep_pca, ensembling_mode='median',
@@ -219,7 +216,7 @@ def ensemble_kalman_smoother_paw_asynchronous(
    # kalman filtering + smoothing
    # --------------------------------------------------------------
    # $z_t = (d_t, x_t, y_t)$
    # $z_t = A z_{t-1} + e_t, e_t ~ N(0,E)$
    # $O_t = B z_t + n_t, n_t ~ N(0,D_t)$

    dfs = {}
@@ -282,9 +279,9 @@ def ensemble_kalman_smoother_paw_asynchronous(
        # --------------------------------------
        # Do the smoothing step
        print(f"smoothing {paw} paw...")
-        ms, Vs, _ = backward_pass(y, mf, Vf, S, A, Q, C)
+        ms, Vs, _ = backward_pass(y, mf, Vf, S, A)
        print("done smoothing")
        # Smoothed posterior over y
        y_m_smooth = np.dot(C, ms.T).T
        y_v_smooth = np.swapaxes(np.dot(C, np.dot(Vs, C.T)), 0, 1)

@@ -388,197 +385,3 @@ def ensemble_kalman_smoother_paw_asynchronous(

    return {'left_df': df_left, 'right_df': df_right}, \
        markers_list_left_cam, markers_list_right_cam
-
-
-# -----------------------
-# funcs for mirror-mouse
-# -----------------------
-def ensemble_kalman_smoother_multi_cam(
-        markers_list_cameras, keypoint_ensemble, smooth_param, quantile_keep_pca, camera_names,
-        ensembling_mode='median', zscore_threshold=2):
-    """Use multi-view constraints to fit a 3d latent subspace for each body part.
- - Parameters - ---------- - markers_list_cameras : list of list of pd.DataFrames - each list element is a list of dataframe predictions from one ensemble member for each - camera. - keypoint_ensemble : str - the name of the keypoint to be ensembled and smoothed - smooth_param : float - ranges from .01-2 (smaller values = more smoothing) - quantile_keep_pca - percentage of the points are kept for multi-view PCA (lowest ensemble variance) - camera_names: list - the camera names (should be the same length as markers_list_cameras). - ensembling_mode: - the function used for ensembling ('mean', 'median', or 'confidence_weighted_mean') - zscore_threshold: - Minimum std threshold to reduce the effect of low ensemble std on a zscore metric - (default 2). - - Returns - ------- - - Returns - ------- - dict - camera_dfs: dataframe containing smoothed markers for each camera; same format as input - dataframes - """ - - # -------------------------------------------------------------- - # interpolate right cam markers to left cam timestamps - # -------------------------------------------------------------- - num_cameras = len(camera_names) - markers_list_stacked_interp = [] - markers_list_interp = [[] for i in range(num_cameras)] - camera_likelihoods_stacked = [] - for model_id in range(len(markers_list_cameras[0])): - bl_markers_curr = [] - camera_markers_curr = [[] for i in range(num_cameras)] - camera_likelihoods = [[] for i in range(num_cameras)] - for i in range(markers_list_cameras[0][0].shape[0]): - curr_markers = [] - for camera in range(num_cameras): - markers = np.array(markers_list_cameras[camera][model_id].to_numpy()[i, [0, 1]]) - likelihood = np.array(markers_list_cameras[camera][model_id].to_numpy()[i, [2]])[0] - camera_markers_curr[camera].append(markers) - curr_markers.append(markers) - camera_likelihoods[camera].append(likelihood) - # combine predictions for all cameras - bl_markers_curr.append(np.concatenate(curr_markers)) - markers_list_stacked_interp.append(bl_markers_curr) - camera_likelihoods_stacked.append(camera_likelihoods) - camera_likelihoods = np.asarray(camera_likelihoods) - for camera in range(num_cameras): - markers_list_interp[camera].append(camera_markers_curr[camera]) - camera_likelihoods[camera] = np.asarray(camera_likelihoods[camera]) - markers_list_stacked_interp = np.asarray(markers_list_stacked_interp) - markers_list_interp = np.asarray(markers_list_interp) - camera_likelihoods_stacked = np.asarray(camera_likelihoods_stacked) - - keys = [keypoint_ensemble + '_x', keypoint_ensemble + '_y'] - markers_list_cams = [[] for i in range(num_cameras)] - for k in range(len(markers_list_interp[0])): - for camera in range(num_cameras): - markers_cam = pd.DataFrame(markers_list_interp[camera][k], columns=keys) - markers_cam[f'{keypoint_ensemble}_likelihood'] = camera_likelihoods_stacked[k][camera] - markers_list_cams[camera].append(markers_cam) - # compute ensemble median for each camera - cam_ensemble_preds = [] - cam_ensemble_vars = [] - cam_ensemble_stacks = [] - cam_keypoints_mean_dict = [] - cam_keypoints_var_dict = [] - cam_keypoints_stack_dict = [] - for camera in range(num_cameras): - cam_ensemble_preds_curr, cam_ensemble_vars_curr, cam_ensemble_stacks_curr, \ - cam_keypoints_mean_dict_curr, cam_keypoints_var_dict_curr, \ - cam_keypoints_stack_dict_curr = \ - ensemble(markers_list_cams[camera], keys, mode=ensembling_mode) - cam_ensemble_preds.append(cam_ensemble_preds_curr) - cam_ensemble_vars.append(cam_ensemble_vars_curr) - 
cam_ensemble_stacks.append(cam_ensemble_stacks_curr) - cam_keypoints_mean_dict.append(cam_keypoints_mean_dict_curr) - cam_keypoints_var_dict.append(cam_keypoints_var_dict_curr) - cam_keypoints_stack_dict.append(cam_keypoints_stack_dict_curr) - - # filter by low ensemble variances - hstacked_vars = np.hstack(cam_ensemble_vars) - max_vars = np.max(hstacked_vars, 1) - quantile_keep = quantile_keep_pca - good_frames = np.where(max_vars <= np.percentile(max_vars, quantile_keep))[0] - - good_cam_ensemble_preds = [] - good_cam_ensemble_vars = [] - for camera in range(num_cameras): - good_cam_ensemble_preds.append(cam_ensemble_preds[camera][good_frames]) - good_cam_ensemble_vars.append(cam_ensemble_vars[camera][good_frames]) - - good_ensemble_preds = np.hstack(good_cam_ensemble_preds) - # good_ensemble_vars = np.hstack(good_cam_ensemble_vars) - means_camera = [] - for i in range(good_ensemble_preds.shape[1]): - means_camera.append(good_ensemble_preds[:, i].mean()) - - ensemble_preds = np.hstack(cam_ensemble_preds) - ensemble_vars = np.hstack(cam_ensemble_vars) - ensemble_stacks = np.concatenate(cam_ensemble_stacks, 2) - remove_camera_means(ensemble_stacks, means_camera) - - good_scaled_ensemble_preds = remove_camera_means( - good_ensemble_preds[None, :, :], means_camera)[0] - ensemble_pca, ensemble_ex_var = pca( - good_scaled_ensemble_preds, 3) - - scaled_ensemble_preds = remove_camera_means(ensemble_preds[None, :, :], means_camera)[0] - ensemble_pcs = ensemble_pca.transform(scaled_ensemble_preds) - good_ensemble_pcs = ensemble_pcs[good_frames] - - y_obs = scaled_ensemble_preds - - # compute center of mass - # latent variables (observed) - good_z_t_obs = good_ensemble_pcs # latent variables - true 3D pca - - # ------ Set values for kalman filter ------ - m0 = np.asarray([0.0, 0.0, 0.0]) # initial state: mean - S0 = np.asarray([[np.nanvar(good_z_t_obs[:, 0]), 0.0, 0.0], - [0.0, np.nanvar(good_z_t_obs[:, 1]), 0.0], - [0.0, 0.0, np.nanvar(good_z_t_obs[:, 2])]]) # diagonal: var - - A = np.asarray([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) # state-transition matrix, - - # Q = np.asarray([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) <-- state-cov matrix? 
-
-    d_t = good_z_t_obs[1:] - good_z_t_obs[:-1]
-
-    C = ensemble_pca.components_.T  # Measurement function is inverse transform of PCA
-    R = np.eye(ensemble_pca.components_.shape[1])  # placeholder diagonal matrix for ensemble var
-
-    cov_matrix = np.cov(d_t.T)
-
-    # Call functions from ensemble_kalman to optimize smooth_param before filtering and smoothing
-    if smooth_param is None:
-        smooth_param = optimize_smoothing_params(cov_matrix, y_obs, m0, S0, C, A, R, ensemble_vars)
-    ms, Vs, nll, nll_values = filter_smooth_nll(
-        cov_matrix, smooth_param, y_obs, m0, S0, C, A, R, ensemble_vars)
-    print(f"NLL is {nll} for {keypoint_ensemble}, smooth_param={smooth_param}")
-    smooth_param_final = smooth_param
-
-    # Smoothed posterior over y
-    y_m_smooth = np.dot(C, ms.T).T
-    y_v_smooth = np.swapaxes(np.dot(C, np.dot(Vs, C.T)), 0, 1)
-
-    # --------------------------------------
-    # final cleanup
-    # --------------------------------------
-    pdindex = make_dlc_pandas_index([keypoint_ensemble],
-                                    labels=["x", "y", "likelihood", "x_var", "y_var", "zscore"])
-    camera_indices = []
-    for camera in range(num_cameras):
-        camera_indices.append([camera * 2, camera * 2 + 1])
-    camera_dfs = {}
-    for camera, camera_name in enumerate(camera_names):
-        var = np.empty(y_m_smooth.T[camera_indices[camera][0]].shape)
-        var[:] = np.nan
-        eks_pred_x = \
-            y_m_smooth.T[camera_indices[camera][0]] + means_camera[camera_indices[camera][0]]
-        eks_pred_y = \
-            y_m_smooth.T[camera_indices[camera][1]] + means_camera[camera_indices[camera][1]]
-        # compute zscore for EKS to see how it deviates from the ensemble
-        eks_predictions = np.asarray([eks_pred_x, eks_pred_y]).T
-        zscore = eks_zscore(eks_predictions, cam_ensemble_preds[camera], cam_ensemble_vars[camera],
-                            min_ensemble_std=zscore_threshold)
-        pred_arr = np.vstack([
-            eks_pred_x,
-            eks_pred_y,
-            var,
-            y_v_smooth[:, camera_indices[camera][0], camera_indices[camera][0]],
-            y_v_smooth[:, camera_indices[camera][1], camera_indices[camera][1]],
-            zscore,
-        ]).T
-        camera_dfs[camera_name + '_df'] = pd.DataFrame(pred_arr, columns=pdindex)
-    return camera_dfs, smooth_param_final, nll_values
-    # return camera_dfs, cam_keypoints_mean_dict, cam_keypoints_var_dict
diff --git a/eks/pupil_smoother.py b/eks/ibl_pupil_smoother.py
similarity index 69%
rename from eks/pupil_smoother.py
rename to eks/ibl_pupil_smoother.py
index 16521db..feed93d 100644
--- a/eks/pupil_smoother.py
+++ b/eks/ibl_pupil_smoother.py
@@ -1,8 +1,11 @@
+import warnings
+
import numpy as np
import pandas as pd
-from eks.utils import make_dlc_pandas_index
-from eks.core import ensemble, forward_pass, backward_pass, eks_zscore
-import warnings
+from scipy.optimize import minimize
+
+from eks.core import backward_pass, compute_nll, eks_zscore, ensemble, forward_pass
+from eks.utils import crop_frames, make_dlc_pandas_index


# -----------------------
@@ -28,9 +31,9 @@ def get_pupil_location(dlc):
    tmp_x2 = np.median(np.hstack([r[:, 0, None], le[:, 0, None]]), axis=1)
    center[:, 0] = np.nanmedian(np.hstack([tmp_x1[:, None], tmp_x2[:, None]]), axis=1)

    # both top and bottom must be present in y-dir
    tmp_y1 = np.median(np.hstack([t[:, 1, None], b[:, 1, None]]), axis=1)
    # ok if either left or right is nan in y-dir
    tmp_y2 = np.nanmedian(np.hstack([r[:, 1, None], le[:, 1, None]]), axis=1)
    center[:, 1] = np.nanmedian(np.hstack([tmp_y1[:, None], tmp_y2[:, None]]), axis=1)
    return center
@@ -48,7 +51,7 @@ def get_pupil_diameter(dlc):
    :return:
np.array, pupil diameter estimate for each time point, shape (n_frames,)
    """
    diameters = []
    # Get the x,y coordinates of the four pupil points
    top, bottom, left, right = [np.vstack((dlc[f'pupil_{point}_r_x'], dlc[f'pupil_{point}_r_y']))
                                for point in ['top', 'bottom', 'left', 'right']]
    # First compute direct diameters
@@ -76,11 +79,12 @@ def add_mean_to_array(pred_arr, keys, mean_x, mean_y):
    return processed_arr_dict


-def ensemble_kalman_smoother_pupil(
+def ensemble_kalman_smoother_ibl_pupil(
        markers_list,
        keypoint_names,
        tracker_name,
-        state_transition_matrix,
+        smooth_params,
+        s_frames,
        likelihood_default=np.nan,
        zscore_threshold=2,
):
@@ -93,7 +97,8 @@ def ensemble_kalman_smoother_pupil(
    keypoint_names: list
    tracker_name : str
        tracker name for constructing final dataframe
-    state_transition_matrix : np.ndarray
+    smooth_params : [float, float]
+        contains smoothing parameters for diameter and center of mass
+    s_frames : list of tuples or int
+        specifies frames to be used for smoothing parameter auto-tuning
    likelihood_default
        value to store in likelihood column; should be np.nan or int in [0, 1]
    zscore_threshold:
@@ -145,16 +150,6 @@ def ensemble_kalman_smoother_pupil(
        [0.0, 0.0, np.nanvar(y_t_obs)]
    ])

-    # state-transition matrix
-    A = state_transition_matrix
-
-    # state covariance matrix
-    Q = np.asarray([
-        [np.nanvar(pupil_diameters) * (1 - (A[0, 0] ** 2)), 0, 0],
-        [0, np.nanvar(x_t_obs) * (1 - A[1, 1] ** 2), 0],
-        [0, 0, np.nanvar(y_t_obs) * (1 - (A[2, 2] ** 2))]
-    ])
-
    # Measurement function
    C = np.asarray(
        [[0, 1, 0], [-.5, 0, 1], [0, 1, 0], [.5, 0, 1], [.5, 1, 0], [0, 0, 1], [-.5, 1, 0],
@@ -177,24 +172,17 @@ def ensemble_kalman_smoother_pupil(
            scaled_ensemble_stacks[:, :, i] -= mean_x_obs
        else:
            scaled_ensemble_stacks[:, :, i] -= mean_y_obs
-    y = scaled_ensemble_preds
+    y_obs = scaled_ensemble_preds

    # --------------------------------------
    # perform filtering
    # --------------------------------------
-    # do filtering pass with time-varying ensemble variances
-    print("filtering...")
-    mf, Vf, S, _, _ = forward_pass(y, m0, S0, C, R, A, Q, ensemble_vars)
-    print("done filtering")
-
-    # --------------------------------------
-    # perform smoothing
-    # --------------------------------------
-    # Do the smoothing step
-    print("smoothing...")
-    ms, Vs, _ = backward_pass(y, mf, Vf, S, A, Q, C)
-    print("done smoothing")
-    # Smoothed posterior over y
+    smooth_params, ms, Vs, nll, nll_values = pupil_optimize_smooth(
+        y_obs, m0, S0, C, R, ensemble_vars,
+        np.var(pupil_diameters), np.var(x_t_obs), np.var(y_t_obs), s_frames, smooth_params)
+    diameter_s, com_s = smooth_params[0], smooth_params[1]
+    print(f"NLL is {nll} for diameter_s={diameter_s}, com_s={com_s}")
+    # Smoothed posterior over y
    y_m_smooth = np.dot(C, ms.T).T
    y_v_smooth = np.swapaxes(np.dot(C, np.dot(Vs, C.T)), 0, 1)

@@ -236,10 +224,88 @@ def ensemble_kalman_smoother_pupil(
    pred_arr2 = []
    pred_arr2.append(ms[:, 0])
    pred_arr2.append(ms[:, 1] + mean_x_obs)  # add back x mean of pupil location
    pred_arr2.append(ms[:, 2] + mean_y_obs)  # add back y mean of pupil location
    pred_arr2 = np.asarray(pred_arr2)
    arrays = [[tracker_name, tracker_name, tracker_name], ['diameter', 'com_x', 'com_y']]
    pd_index2 = pd.MultiIndex.from_arrays(arrays, names=('scorer', 'latent'))
    latents_df = pd.DataFrame(pred_arr2.T, columns=pd_index2)

-    return {'markers_df': markers_df, 'latents_df': latents_df}
+    return {'markers_df': markers_df, 'latents_df': latents_df}, smooth_params, nll_values
+
+
+def pupil_optimize_smooth(
+        y,
m0, S0, C, R, ensemble_vars, diameters_var, x_var, y_var, + s_frames=[(1, 2000)], + smooth_params=[None, None]): + """Optimize-and-smooth function for the pupil example script.""" + # Optimize smooth_param + if smooth_params[0] is None or smooth_params[1] is None: + + # Unpack s_frames + y_shortened = crop_frames(y, s_frames) + + # Minimize negative log likelihood + smooth_params = minimize( + pupil_smooth_min, # function to minimize + x0=[1, 1], + args=(y_shortened, m0, S0, C, R, ensemble_vars, diameters_var, x_var, y_var), + method='Nelder-Mead', + tol=0.002, + bounds=[(0, 1), (0, 1)] # bounds for each parameter in smooth_params + ) + smooth_params = [round(smooth_params.x[0], 5), round(smooth_params.x[1], 5)] + print(f'Optimal at diameter_s={smooth_params[0]}, com_s={smooth_params[1]}') + + # Final smooth with optimized s + ms, Vs, nll, nll_values = pupil_smooth_final( + y, smooth_params, m0, S0, C, R, ensemble_vars, diameters_var, x_var, y_var) + + return smooth_params, ms, Vs, nll, nll_values + + +def pupil_smooth_final(y, smooth_params, m0, S0, C, R, ensemble_vars, diameters_var, x_var, y_var): + # Construct state transition matrix + diameter_s = smooth_params[0] + com_s = smooth_params[1] + A = np.asarray([ + [diameter_s, 0, 0], + [0, com_s, 0], + [0, 0, com_s] + ]) + # cov_matrix + Q = np.asarray([ + [diameters_var * (1 - (A[0, 0] ** 2)), 0, 0], + [0, x_var * (1 - A[1, 1] ** 2), 0], + [0, 0, y_var * (1 - (A[2, 2] ** 2))] + ]) + # Run filtering and smoothing with the current smooth_param + mf, Vf, S, innovs, innov_cov = forward_pass(y, m0, S0, C, R, A, Q, ensemble_vars) + ms, Vs, CV = backward_pass(y, mf, Vf, S, A) + # Compute the negative log-likelihood based on innovations and their covariance + nll, nll_values = compute_nll(innovs, innov_cov) + return ms, Vs, nll, nll_values + + +def pupil_smooth_min(smooth_params, y, m0, S0, C, R, ensemble_vars, diameters_var, x_var, y_var): + # Construct As + diameter_s, com_s = smooth_params[0], smooth_params[1] + A = np.array([ + [diameter_s, 0, 0], + [0, com_s, 0], + [0, 0, com_s] + ]) + + # Construct cov_matrix Q + Q = np.array([ + [diameters_var * (1 - (A[0, 0] ** 2)), 0, 0], + [0, x_var * (1 - A[1, 1] ** 2), 0], + [0, 0, y_var * (1 - (A[2, 2] ** 2))] + ]) + + # Run filtering with the current smooth_param + mf, Vf, S, innovs, innov_cov = forward_pass(y, m0, S0, C, R, A, Q, ensemble_vars) + + # Compute the negative log-likelihood + nll, nll_values = compute_nll(innovs, innov_cov) + + return nll diff --git a/eks/multicam_smoother.py b/eks/multicam_smoother.py new file mode 100644 index 0000000..5f2cdc0 --- /dev/null +++ b/eks/multicam_smoother.py @@ -0,0 +1,269 @@ +import numpy as np +import pandas as pd +from scipy.optimize import minimize + +from eks.core import ensemble, eks_zscore, compute_initial_guesses, forward_pass, backward_pass, \ + compute_nll +from eks.ibl_paw_multiview_smoother import remove_camera_means, pca +from eks.utils import make_dlc_pandas_index, crop_frames + + +def ensemble_kalman_smoother_multicam( + markers_list_cameras, keypoint_ensemble, smooth_param, quantile_keep_pca, camera_names, + s_frames, ensembling_mode='median', zscore_threshold=2): + """Use multi-view constraints to fit a 3d latent subspace for each body part. + + Parameters + ---------- + markers_list_cameras : list of list of pd.DataFrames + each list element is a list of dataframe predictions from one ensemble member for each + camera. 
+ keypoint_ensemble : str + the name of the keypoint to be ensembled and smoothed + smooth_param : float + ranges from .01-2 (smaller values = more smoothing) + quantile_keep_pca + percentage of the points are kept for multi-view PCA (lowest ensemble variance) + camera_names: list + the camera names (should be the same length as markers_list_cameras). + s_frames : list of tuples or int + specifies frames to be used for smoothing parameter auto-tuning + the function used for ensembling ('mean', 'median', or 'confidence_weighted_mean') + zscore_threshold: + Minimum std threshold to reduce the effect of low ensemble std on a zscore metric + (default 2). + + Returns + ------- + dict + camera_dfs: dataframe containing smoothed markers for each camera; same format as input + dataframes + """ + + # -------------------------------------------------------------- + # interpolate right cam markers to left cam timestamps + # -------------------------------------------------------------- + num_cameras = len(camera_names) + markers_list_stacked_interp = [] + markers_list_interp = [[] for i in range(num_cameras)] + camera_likelihoods_stacked = [] + for model_id in range(len(markers_list_cameras[0])): + bl_markers_curr = [] + camera_markers_curr = [[] for i in range(num_cameras)] + camera_likelihoods = [[] for i in range(num_cameras)] + for i in range(markers_list_cameras[0][0].shape[0]): + curr_markers = [] + for camera in range(num_cameras): + markers = np.array(markers_list_cameras[camera][model_id].to_numpy()[i, [0, 1]]) + likelihood = np.array(markers_list_cameras[camera][model_id].to_numpy()[i, [2]])[0] + camera_markers_curr[camera].append(markers) + curr_markers.append(markers) + camera_likelihoods[camera].append(likelihood) + # combine predictions for all cameras + bl_markers_curr.append(np.concatenate(curr_markers)) + markers_list_stacked_interp.append(bl_markers_curr) + camera_likelihoods_stacked.append(camera_likelihoods) + camera_likelihoods = np.asarray(camera_likelihoods) + for camera in range(num_cameras): + markers_list_interp[camera].append(camera_markers_curr[camera]) + camera_likelihoods[camera] = np.asarray(camera_likelihoods[camera]) + markers_list_stacked_interp = np.asarray(markers_list_stacked_interp) + markers_list_interp = np.asarray(markers_list_interp) + camera_likelihoods_stacked = np.asarray(camera_likelihoods_stacked) + + keys = [keypoint_ensemble + '_x', keypoint_ensemble + '_y'] + markers_list_cams = [[] for i in range(num_cameras)] + for k in range(len(markers_list_interp[0])): + for camera in range(num_cameras): + markers_cam = pd.DataFrame(markers_list_interp[camera][k], columns=keys) + markers_cam[f'{keypoint_ensemble}_likelihood'] = camera_likelihoods_stacked[k][camera] + markers_list_cams[camera].append(markers_cam) + # compute ensemble median for each camera + cam_ensemble_preds = [] + cam_ensemble_vars = [] + cam_ensemble_stacks = [] + cam_keypoints_mean_dict = [] + cam_keypoints_var_dict = [] + cam_keypoints_stack_dict = [] + for camera in range(num_cameras): + cam_ensemble_preds_curr, cam_ensemble_vars_curr, cam_ensemble_stacks_curr, \ + cam_keypoints_mean_dict_curr, cam_keypoints_var_dict_curr, \ + cam_keypoints_stack_dict_curr = \ + ensemble(markers_list_cams[camera], keys, mode=ensembling_mode) + cam_ensemble_preds.append(cam_ensemble_preds_curr) + cam_ensemble_vars.append(cam_ensemble_vars_curr) + cam_ensemble_stacks.append(cam_ensemble_stacks_curr) + cam_keypoints_mean_dict.append(cam_keypoints_mean_dict_curr) + 
cam_keypoints_var_dict.append(cam_keypoints_var_dict_curr) + cam_keypoints_stack_dict.append(cam_keypoints_stack_dict_curr) + + # filter by low ensemble variances + hstacked_vars = np.hstack(cam_ensemble_vars) + max_vars = np.max(hstacked_vars, 1) + quantile_keep = quantile_keep_pca + good_frames = np.where(max_vars <= np.percentile(max_vars, quantile_keep))[0] + + good_cam_ensemble_preds = [] + good_cam_ensemble_vars = [] + for camera in range(num_cameras): + good_cam_ensemble_preds.append(cam_ensemble_preds[camera][good_frames]) + good_cam_ensemble_vars.append(cam_ensemble_vars[camera][good_frames]) + + good_ensemble_preds = np.hstack(good_cam_ensemble_preds) + # good_ensemble_vars = np.hstack(good_cam_ensemble_vars) + means_camera = [] + for i in range(good_ensemble_preds.shape[1]): + means_camera.append(good_ensemble_preds[:, i].mean()) + + ensemble_preds = np.hstack(cam_ensemble_preds) + ensemble_vars = np.hstack(cam_ensemble_vars) + ensemble_stacks = np.concatenate(cam_ensemble_stacks, 2) + remove_camera_means(ensemble_stacks, means_camera) + + good_scaled_ensemble_preds = remove_camera_means( + good_ensemble_preds[None, :, :], means_camera)[0] + ensemble_pca, ensemble_ex_var = pca( + good_scaled_ensemble_preds, 3) + + scaled_ensemble_preds = remove_camera_means(ensemble_preds[None, :, :], means_camera)[0] + ensemble_pcs = ensemble_pca.transform(scaled_ensemble_preds) + good_ensemble_pcs = ensemble_pcs[good_frames] + + y_obs = scaled_ensemble_preds + + # compute center of mass + # latent variables (observed) + good_z_t_obs = good_ensemble_pcs # latent variables - true 3D pca + + # ------ Set values for kalman filter ------ + m0 = np.asarray([0.0, 0.0, 0.0]) # initial state: mean + S0 = np.asarray([[np.var(good_z_t_obs[:, 0]), 0.0, 0.0], + [0.0, np.var(good_z_t_obs[:, 1]), 0.0], + [0.0, 0.0, np.var(good_z_t_obs[:, 2])]]) # diagonal: var + + A = np.asarray([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) # state-transition matrix, + + # Q = np.asarray([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) <-- state-cov matrix? 
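+    # Q is not fixed by hand; the process noise covariance is instead estimated
+    # below from the temporal differences of the latents (cov_matrix = np.cov(d_t.T))
+    # and scaled by smooth_param inside multicam_smooth_final.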
+
+    d_t = good_z_t_obs[1:] - good_z_t_obs[:-1]
+
+    C = ensemble_pca.components_.T  # Measurement function is inverse transform of PCA
+    R = np.eye(ensemble_pca.components_.shape[1])  # placeholder diagonal matrix for ensemble var
+
+    cov_matrix = np.cov(d_t.T)
+
+    # Call functions from ensemble_kalman to optimize smooth_param before filtering and smoothing
+    smooth_param, ms, Vs, nll, nll_values = multicam_optimize_smooth(
+        cov_matrix, y_obs, m0, S0, C, A, R, ensemble_vars, s_frames, smooth_param)
+    print(f"NLL is {nll} for {keypoint_ensemble}, smooth_param={smooth_param}")
+    smooth_param_final = smooth_param
+
+    # Smoothed posterior over y
+    y_m_smooth = np.dot(C, ms.T).T
+    y_v_smooth = np.swapaxes(np.dot(C, np.dot(Vs, C.T)), 0, 1)
+
+    # --------------------------------------
+    # final cleanup
+    # --------------------------------------
+    pdindex = make_dlc_pandas_index([keypoint_ensemble],
+                                    labels=["x", "y", "likelihood", "x_var", "y_var", "zscore"])
+    camera_indices = []
+    for camera in range(num_cameras):
+        camera_indices.append([camera * 2, camera * 2 + 1])
+    camera_dfs = {}
+    for camera, camera_name in enumerate(camera_names):
+        var = np.empty(y_m_smooth.T[camera_indices[camera][0]].shape)
+        var[:] = np.nan
+        eks_pred_x = \
+            y_m_smooth.T[camera_indices[camera][0]] + means_camera[camera_indices[camera][0]]
+        eks_pred_y = \
+            y_m_smooth.T[camera_indices[camera][1]] + means_camera[camera_indices[camera][1]]
+        # compute zscore for EKS to see how it deviates from the ensemble
+        eks_predictions = np.asarray([eks_pred_x, eks_pred_y]).T
+        zscore = eks_zscore(eks_predictions, cam_ensemble_preds[camera], cam_ensemble_vars[camera],
+                            min_ensemble_std=zscore_threshold)
+        pred_arr = np.vstack([
+            eks_pred_x,
+            eks_pred_y,
+            var,
+            y_v_smooth[:, camera_indices[camera][0], camera_indices[camera][0]],
+            y_v_smooth[:, camera_indices[camera][1], camera_indices[camera][1]],
+            zscore,
+        ]).T
+        camera_dfs[camera_name + '_df'] = pd.DataFrame(pred_arr, columns=pdindex)
+    return camera_dfs, smooth_param_final, nll_values
+
+
+def multicam_optimize_smooth(
+        cov_matrix, y, m0, s0, C, A, R, ensemble_vars,
+        s_frames=[(None, None)],
+        smooth_param=None):
+    """
+    Optimizes s using Nelder-Mead minimization, then smooths using s.
+    Used by the multicam example script.
+    """
+    # Optimize smooth_param
+    if smooth_param is None:
+        guess = compute_initial_guesses(ensemble_vars)
+
+        # Update xatol during optimization
+        def callback(xk):
+            # Update xatol based on the current solution xk
+            xatol = np.log(np.abs(xk)) * 0.01
+
+            # Update the options dictionary with the new xatol value
+            options['xatol'] = xatol
+
+        # Initialize options with initial xatol
+        options = {'xatol': np.log(guess)}
+
+        # Unpack s_frames
+        cropped_y = crop_frames(y, s_frames)
+
+        # Minimize negative log likelihood
+        sol = minimize(
+            multicam_smooth_min,
+            x0=guess,  # initial smooth param guess
+            args=(cov_matrix, cropped_y, m0, s0, C, A, R, ensemble_vars),
+            method='Nelder-Mead',
+            options=options,
+            callback=callback,  # Pass the callback function
+            bounds=[(0, None)]
+        )
+        smooth_param = sol.x[0]
+        print(f'Optimal at s={smooth_param}')
+
+    # Final smooth with optimized s
+    ms, Vs, nll, nll_values = multicam_smooth_final(
+        cov_matrix, smooth_param, y, m0, s0, C, A, R, ensemble_vars)
+
+    return smooth_param, ms, Vs, nll, nll_values
+
+
+def multicam_smooth_final(cov_matrix, smooth_param, y, m0, S0, C, A, R, ensemble_vars):
+    """
+    Smooths once using the given smooth_param, used after optimizing smooth_param.
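+    Q is reconstructed as smooth_param * cov_matrix before the forward and backward passes.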
+    Used by the multicam example script.
+    """
+    # Adjust Q based on smooth_param and cov_matrix
+    Q = smooth_param * cov_matrix
+    # Run filtering and smoothing with the current smooth_param
+    mf, Vf, S, innovs, innov_cov = forward_pass(y, m0, S0, C, R, A, Q, ensemble_vars)
+    ms, Vs, CV = backward_pass(y, mf, Vf, S, A)
+    # Compute the negative log-likelihood based on innovations and their covariance
+    nll, nll_values = compute_nll(innovs, innov_cov)
+    return ms, Vs, nll, nll_values
+
+
+def multicam_smooth_min(smooth_param, cov_matrix, y, m0, S0, C, A, R, ensemble_vars):
+    """
+    Smooths once using the given smooth_param. Returns only the nll, which is the parameter to
+    be minimized using the scipy.minimize() function.
+    """
+    # Adjust Q based on smooth_param and cov_matrix
+    Q = smooth_param * cov_matrix
+    # Run filtering with the current smooth_param
+    mf, Vf, S, innovs, innov_cov = forward_pass(y, m0, S0, C, R, A, Q, ensemble_vars)
+    # Compute the negative log-likelihood based on innovations and their covariance
+    nll, nll_values = compute_nll(innovs, innov_cov)
+    return nll
diff --git a/eks/singlecam_smoother.py b/eks/singlecam_smoother.py
new file mode 100644
index 0000000..a66aa8b
--- /dev/null
+++ b/eks/singlecam_smoother.py
@@ -0,0 +1,451 @@
+import time
+from functools import partial
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import optax
+import pandas as pd
+from jax import jit, vmap
+
+from eks.core import (
+    compute_covariance_matrix,
+    compute_initial_guesses,
+    eks_zscore,
+    jax_backward_pass,
+    jax_ensemble,
+    jax_forward_pass,
+    pkf_and_loss,
+)
+from eks.utils import crop_frames, make_dlc_pandas_index
+
+
+def ensemble_kalman_smoother_singlecam(
+        markers_3d_array, bodypart_list, smooth_param, s_frames, blocks=[],
+        ensembling_mode='median',
+        zscore_threshold=2):
+    """
+    Perform Ensemble Kalman Smoothing on 3D marker data from a single camera.
+
+    Parameters:
+        markers_3d_array (np.ndarray): 3D array of marker data.
+        bodypart_list (list): List of body parts.
+        smooth_param (float): Smoothing parameter; auto-tuned if None.
+        s_frames (list): Frames (or frame ranges) used for smoothing-parameter auto-tuning.
+        blocks (list): Optional list of correlated-keypoint blocks; keypoints in a block share
+            one smoothing parameter.
+        ensembling_mode (str): Mode for ensembling ('median' by default).
+        zscore_threshold (float): Z-score threshold.
+
+    Returns:
+        tuple: Dataframes with smoothed predictions and the final smoothing parameters.
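+
+    Example (hypothetical shapes; 5 ensemble members, 1000 frames, 2 keypoints,
+    with columns ordered x, y, likelihood for each keypoint):
+        markers_3d_array = np.zeros((5, 1000, 6))
+        df_dicts, s_finals = ensemble_kalman_smoother_singlecam(
+            markers_3d_array, ['paw_l', 'paw_r'], None, s_frames=[(1, 500)])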
+ """ + + T = markers_3d_array.shape[1] + n_keypoints = markers_3d_array.shape[2] // 3 + n_coords = 2 + + # Compute ensemble statistics + print("Ensembling models") + ensemble_preds, ensemble_vars, keypoints_avg_dict = jax_ensemble( + markers_3d_array, mode=ensembling_mode) + + # Calculate mean and adjusted observations + mean_obs_dict, adjusted_obs_dict, scaled_ensemble_preds = adjust_observations( + keypoints_avg_dict, n_keypoints, ensemble_preds.copy()) + + # Initialize Kalman filter values + m0s, S0s, As, cov_mats, Cs, Rs, ys = initialize_kalman_filter( + scaled_ensemble_preds, adjusted_obs_dict, n_keypoints) + + # Main smoothing function + s_finals, ms, Vs = singlecam_optimize_smooth( + cov_mats, ys, m0s, S0s, Cs, As, Rs, ensemble_vars, + s_frames, smooth_param, blocks) + + y_m_smooths = np.zeros((n_keypoints, T, n_coords)) + y_v_smooths = np.zeros((n_keypoints, T, n_coords, n_coords)) + eks_preds_array = np.zeros(y_m_smooths.shape) + dfs = [] + df_dicts = [] + + # Process each keypoint + for k in range(n_keypoints): + y_m_smooths[k] = np.dot(Cs[k], ms[k].T).T + y_v_smooths[k] = np.swapaxes(np.dot(Cs[k], np.dot(Vs[k], Cs[k].T)), 0, 1) + mean_x_obs = mean_obs_dict[3 * k] + mean_y_obs = mean_obs_dict[3 * k + 1] + + # Computing z-score + eks_preds_array[k] = y_m_smooths[k].copy() + eks_preds_array[k] = np.asarray([eks_preds_array[k].T[0] + mean_x_obs, + eks_preds_array[k].T[1] + mean_y_obs]).T + zscore = eks_zscore(eks_preds_array[k], + ensemble_preds[:, k, :], + ensemble_vars[:, k, :], + min_ensemble_std=zscore_threshold) + + # Final Cleanup + pdindex = make_dlc_pandas_index([bodypart_list[k]], + labels=["x", "y", "likelihood", "x_var", "y_var", + "zscore"]) + var = np.empty(y_m_smooths[k].T[0].shape) + var[:] = np.nan + pred_arr = np.vstack([ + y_m_smooths[k].T[0] + mean_x_obs, + y_m_smooths[k].T[1] + mean_y_obs, + var, + y_v_smooths[k][:, 0, 0], + y_v_smooths[k][:, 1, 1], + zscore, + ]).T + df = pd.DataFrame(pred_arr, columns=pdindex) + dfs.append(df) + df_dicts.append({bodypart_list[k] + '_df': df}) + + return df_dicts, s_finals + + +def adjust_observations(keypoints_avg_dict, n_keypoints, scaled_ensemble_preds): + """ + Adjust observations by computing mean and adjusted observations for each keypoint. + + Parameters: + keypoints_avg_dict (dict): Dictionary of keypoints averages. + n_keypoints (int): Number of keypoints. + scaled_ensemble_preds (np.ndarray): Scaled ensemble predictions. + + Returns: + tuple: Mean observations dictionary, adjusted observations dictionary, scaled ensemble preds. 
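+
+    Note: centering each keypoint's observations around zero is what lets
+    initialize_kalman_filter use a zero initial mean state; the per-keypoint
+    means are added back to the smoothed outputs in
+    ensemble_kalman_smoother_singlecam.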
+ """ + + # Convert dictionaries to JAX arrays + keypoints_avg_array = jnp.array([keypoints_avg_dict[k] for k in keypoints_avg_dict.keys()]) + x_keys = jnp.array([3 * i for i in range(n_keypoints)]) + y_keys = jnp.array([3 * i + 1 for i in range(n_keypoints)]) + + def compute_adjusted_means(i): + mean_x_obs = jnp.nanmean(keypoints_avg_array[2 * i]) + mean_y_obs = jnp.nanmean(keypoints_avg_array[2 * i + 1]) + adjusted_x_obs = keypoints_avg_array[2 * i] - mean_x_obs + adjusted_y_obs = keypoints_avg_array[2 * i + 1] - mean_y_obs + return mean_x_obs, mean_y_obs, adjusted_x_obs, adjusted_y_obs + + means_and_adjustments = jax.vmap(compute_adjusted_means)(jnp.arange(n_keypoints)) + + mean_x_obs, mean_y_obs, adjusted_x_obs, adjusted_y_obs = means_and_adjustments + + # Convert JAX arrays to NumPy arrays for dictionary keys + x_keys_np = np.array(x_keys) + y_keys_np = np.array(y_keys) + + mean_obs_dict = {x_keys_np[i]: mean_x_obs[i] for i in range(n_keypoints)} + mean_obs_dict.update({y_keys_np[i]: mean_y_obs[i] for i in range(n_keypoints)}) + + adjusted_obs_dict = {x_keys_np[i]: adjusted_x_obs[i] for i in range(n_keypoints)} + adjusted_obs_dict.update({y_keys_np[i]: adjusted_y_obs[i] for i in range(n_keypoints)}) + + # Ensure scaled_ensemble_preds is a JAX array + scaled_ensemble_preds = jnp.array(scaled_ensemble_preds) + + def scale_ensemble_preds(mean_x_obs, mean_y_obs, scaled_ensemble_preds, i): + scaled_ensemble_preds = scaled_ensemble_preds.at[:, i, 0].add(-mean_x_obs) + scaled_ensemble_preds = scaled_ensemble_preds.at[:, i, 1].add(-mean_y_obs) + return scaled_ensemble_preds + + for i in range(n_keypoints): + mean_x = mean_obs_dict[x_keys_np[i]] + mean_y = mean_obs_dict[y_keys_np[i]] + scaled_ensemble_preds = scale_ensemble_preds(mean_x, mean_y, scaled_ensemble_preds, i) + + return mean_obs_dict, adjusted_obs_dict, scaled_ensemble_preds + + +def initialize_kalman_filter(scaled_ensemble_preds, adjusted_obs_dict, n_keypoints): + """ + Initialize the Kalman filter values. + + Parameters: + scaled_ensemble_preds (np.ndarray): Scaled ensemble predictions. + adjusted_obs_dict (dict): Adjusted observations dictionary. + n_keypoints (int): Number of keypoints. + + Returns: + tuple: Initial Kalman filter values and covariance matrices. 
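+
+    All returned arrays are batched over keypoints via jax.vmap, e.g. m0s has
+    shape (n_keypoints, 2) and As has shape (n_keypoints, 2, 2).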
+    """
+
+    # Convert inputs to JAX arrays
+    scaled_ensemble_preds = jnp.array(scaled_ensemble_preds)
+
+    # Extract the necessary values from adjusted_obs_dict
+    adjusted_x_obs_list = [adjusted_obs_dict[3 * i] for i in range(n_keypoints)]
+    adjusted_y_obs_list = [adjusted_obs_dict[3 * i + 1] for i in range(n_keypoints)]
+
+    # Convert these lists to JAX arrays
+    adjusted_x_obs_array = jnp.array(adjusted_x_obs_list)
+    adjusted_y_obs_array = jnp.array(adjusted_y_obs_list)
+
+    def init_kalman(i, adjusted_x_obs, adjusted_y_obs):
+        m0 = jnp.array([0.0, 0.0])  # initial state: mean
+        S0 = jnp.array([[jnp.nanvar(adjusted_x_obs), 0.0],
+                        [0.0, jnp.nanvar(adjusted_y_obs)]])  # diagonal: var
+        A = jnp.array([[1.0, 0], [0, 1.0]])  # state-transition matrix
+        C = jnp.array([[1, 0], [0, 1]])  # Measurement function
+        R = jnp.eye(2)  # placeholder diagonal matrix for ensemble variance
+        y_obs = scaled_ensemble_preds[:, i, :]
+
+        return m0, S0, A, C, R, y_obs
+
+    # Use vmap to vectorize the initialization over all keypoints
+    init_kalman_vmap = jax.vmap(init_kalman, in_axes=(0, 0, 0))
+    m0s, S0s, As, Cs, Rs, y_obs_array = init_kalman_vmap(jnp.arange(n_keypoints),
+                                                         adjusted_x_obs_array,
+                                                         adjusted_y_obs_array)
+    cov_mats = compute_covariance_matrix(scaled_ensemble_preds)
+    return m0s, S0s, As, cov_mats, Cs, Rs, y_obs_array
+
+
+def singlecam_optimize_smooth(
+        cov_mats, ys, m0s, S0s, Cs, As, Rs, ensemble_vars,
+        s_frames, smooth_param, blocks=[], maxiter=1000):
+    """
+    Optimize the smoothing parameter, then use the result to run the Kalman filter-smoother.
+
+    Parameters:
+        cov_mats (np.ndarray): Covariance matrices.
+        ys (np.ndarray): Observations. Shape (keypoints, frames, coordinates); coordinates is
+            usually 2.
+        m0s (np.ndarray): Initial mean state.
+        S0s (np.ndarray): Initial state covariance.
+        Cs (np.ndarray): Measurement function.
+        As (np.ndarray): State-transition matrix.
+        Rs (np.ndarray): Measurement noise covariance.
+        ensemble_vars (np.ndarray): Ensemble variances.
+        s_frames (list): Frames (or frame ranges) used for auto-tuning.
+        smooth_param (float): Smoothing parameter; auto-tuned if None.
+        blocks (list): Optional list of correlated-keypoint blocks; keypoints in a block share
+            one smoothing parameter.
+
+    Returns:
+        tuple: Final smoothing parameters, smoothed means, smoothed covariances.
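+
+    Note: the smoothing parameter is optimized in log-space (s = exp(theta))
+    with the optax Adam optimizer, which keeps s positive during the search;
+    see the loss functions defined below.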
+ """ + + n_keypoints = ys.shape[0] + s_finals = [] + if blocks == []: + for n in range(n_keypoints): + blocks.append([n]) + print(f'Correlated keypoint blocks: {blocks}') + + # Depending on whether we use GPU, choose parallel or sequential smoothing param optimization + try: + _ = jax.device_put(jax.numpy.ones(1), device=jax.devices('gpu')[0]) + print("Using GPU") + + @partial(jit) + def nll_loss_parallel_scan(s, cov_mats, cropped_ys, m0s, S0s, Cs, As, Rs): + s = jnp.exp(s) # To ensure positivity + output = singlecam_smooth_min_parallel(s, cov_mats, cropped_ys, m0s, S0s, Cs, As, Rs) + return output + + loss_function = nll_loss_parallel_scan + except: + print("Using CPU") + + @partial(jit) + def nll_loss_sequential_scan(s, cov_mats, cropped_ys, m0s, S0s, Cs, As, Rs): + s = jnp.exp(s) # To ensure positivity + return singlecam_smooth_min(s, cov_mats, cropped_ys, m0s, S0s, Cs, As, Rs) + + loss_function = nll_loss_sequential_scan + + # Optimize smooth_param + if smooth_param is not None: + s_finals = [smooth_param] + else: + guesses = [] + cropped_ys = [] + for k in range(n_keypoints): + current_guess = compute_initial_guesses(ensemble_vars[:, k, :]) + guesses.append(current_guess) + cropped_ys.append(crop_frames(ys[k], s_frames)) + + cropped_ys = np.array(cropped_ys) # Concatenation of this list along dimension 0 + + # Optimize negative log likelihood + for block in blocks: + s_init = guesses[block[0]] + if s_init <= 0: + s_init = 2 + s_init = jnp.log(s_init) + optimizer = optax.adam(learning_rate=0.25) + opt_state = optimizer.init(s_init) + + selector = np.array(block).astype(int) + cov_mats_sub = cov_mats[selector] + m0s_crop = m0s[selector] + S0s_crop = S0s[selector] + Cs_crop = Cs[selector] + As_crop = As[selector] + Rs_crop = Rs[selector] + y_subset = cropped_ys[selector] + + def step(s, opt_state): + loss, grads = jax.value_and_grad(loss_function)(s, cov_mats_sub, y_subset, + m0s_crop, + S0s_crop, Cs_crop, As_crop, + Rs_crop) + updates, opt_state = optimizer.update(grads, opt_state) + s = optax.apply_updates(s, updates) + return s, opt_state, loss + + prev_loss = jnp.inf + for iteration in range(maxiter): + start_time = time.time() + s_init, opt_state, loss = step(s_init, opt_state) + + # if iteration % 10 == 0 or iteration == maxiter - 1: + # print(f'Iteration {iteration}, Current loss: {loss}, Current s: {s_init}') + + tol = 0.001 * jnp.abs(jnp.log(prev_loss)) + if jnp.linalg.norm(loss - prev_loss) < tol + 1e-6: + # print( + # f'Converged at iteration {iteration} with ' + # f'smoothing parameter {jnp.exp(s_init)}. NLL={loss}') + break + + prev_loss = loss + + s_final = jnp.exp(s_init) # Convert back from log-space + for b in block: + print(f's={s_final} for keypoint {b}') + s_finals.append(s_final) + + s_finals = np.array(s_finals) + # Final smooth with optimized s + ms, Vs = final_forwards_backwards_pass( + cov_mats, s_finals, + ys, m0s, S0s, Cs, As, Rs) + + return s_finals, ms, Vs + + +###### +## Routines that use the sequential kalman filter implementation to arrive at the NLL function +## Note: this code is set up to always run on CPU. +###### + +def inner_smooth_min_routine(y, m0, S0, A, Q, C, R): + # Run filtering with the current smooth_param + _, _, nll = jax_forward_pass(y, m0, S0, A, Q, C, R) + return nll + + +inner_smooth_min_routine_vmap = vmap(inner_smooth_min_routine, in_axes=(0, 0, 0, 0, 0, 0, 0)) + + +def singlecam_smooth_min( + smooth_param, cov_mats, ys, m0s, S0s, Cs, As, Rs): + """ + Smooths once using the given smooth_param. 
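The per-keypoint NLLs from the sequential filter are summed via
+    inner_smooth_min_routine_vmap.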
Returns only the nll, which is the quantity
+    minimized by the Adam optimizer in singlecam_optimize_smooth.
+
+    Parameters:
+        smooth_param (float): Smoothing parameter.
+        cov_mats (np.ndarray): Covariance matrices.
+        ys (np.ndarray): Observations.
+        m0s (np.ndarray): Initial mean state.
+        S0s (np.ndarray): Initial state covariance.
+        Cs (np.ndarray): Measurement function.
+        As (np.ndarray): State-transition matrix.
+        Rs (np.ndarray): Measurement noise covariance.
+
+    Returns:
+        float: Negative log-likelihood.
+    """
+    # Adjust Q based on smooth_param and cov_matrix
+    Qs = smooth_param * cov_mats
+    nlls = jnp.sum(inner_smooth_min_routine_vmap(ys, m0s, S0s, As, Qs, Cs, Rs))
+    return nlls
+
+
+def inner_smooth_min_routine_parallel(y, m0, S0, A, Q, C, R):
+    # Run filtering with the current smooth_param
+    means, covariances, NLL = pkf_and_loss(y, m0, S0, A, Q, C, R)
+    return jnp.sum(NLL)
+
+
+inner_smooth_min_routine_parallel_vmap = jit(
+    vmap(inner_smooth_min_routine_parallel, in_axes=(0, 0, 0, 0, 0, 0, 0)))
+
+
+# ------------------------------------------------------------------------------------------------
+# Routines that use the parallel scan kalman filter implementation to arrive at the NLL function.
+# Note: This should only be run on GPUs
+# ------------------------------------------------------------------------------------------------
+
+
+def singlecam_smooth_min_parallel(
+        smooth_param, cov_mats, observations, initial_means, initial_covariances, Cs, As, Rs):
+    """
+    Computes the maximum likelihood estimator for the process noise variance (smoothness param).
+    This function is parallelized to process all keypoints in a given block.
+    KEY: This function uses the parallel scan algorithm, which has effectively O(log(n))
+    runtime on GPUs. On CPUs, it is slower than the jax.lax.scan implementation above.
+
+    Parameters:
+        smooth_param (float): Smoothing parameter.
+        cov_mats (np.ndarray): Covariance matrices.
+        observations (np.ndarray): Observations.
+        initial_means (np.ndarray): Initial mean states.
+        initial_covariances (np.ndarray): Initial state covariances.
+        Cs (np.ndarray): Measurement function.
+        As (np.ndarray): State-transition matrix.
+        Rs (np.ndarray): Measurement noise covariance.
+
+    Returns:
+        float: Negative log-likelihood.
+    """
+    # Adjust Q based on smooth_param and cov_matrix
+    Qs = smooth_param * cov_mats
+    values = inner_smooth_min_routine_parallel_vmap(observations, initial_means,
+                                                    initial_covariances, As, Qs, Cs, Rs)
+    return jnp.sum(values)
+
+
+def final_forwards_backwards_pass(process_cov, s, ys, m0s, S0s, Cs, As, Rs):
+    """
+    Perform final smoothing with the optimized smoothing parameters.
+
+    Parameters:
+        process_cov: Shape (keypoints, state_coords, state_coords). Process noise covariance
+            matrix.
+        s: Shape (keypoints,). We scale the process noise covariance by this value at each
+            keypoint.
+        ys: Shape (keypoints, frames, observation_coordinates). Observations for all keypoints.
+        m0s: Shape (keypoints, state_coords). Initial ensembled mean state for each keypoint.
+        S0s: Shape (keypoints, state_coords, state_coords). Initial ensembled state covariances
+            for each keypoint.
+        Cs: Shape (keypoints, obs_coords, state_coords). Observation measurement coeff matrix.
+        As: Shape (keypoints, state_coords, state_coords). Process matrix for each keypoint.
+        Rs: Shape (keypoints, obs_coords, obs_coords). Measurement noise covariance.
+
+    Returns:
+        smoothed means: Shape (keypoints, timepoints, coords).
+ Kalman smoother state estimates outputs for all frames/all keypoints. + smoothed covariances: Shape (num_keypoints, num_state_coordinates, num_state_coordinates) + """ + + # Initialize + n_keypoints = ys.shape[0] + ms_array = [] + Vs_array = [] + Qs = s[:, None, None] * process_cov + + # Run forward and backward pass for each keypoint + for k in range(n_keypoints): + mf, Vf, nll = jax_forward_pass(ys[k], m0s[k], S0s[k], As[k], Qs[k], Cs[k], Rs[k]) + ms, Vs = jax_backward_pass(mf, Vf, As[k], Qs[k]) + ms_array.append(np.array(ms)) + Vs_array.append(np.array(Vs)) + + smoothed_means = np.stack(ms_array, axis=0) + smoothed_covariances = np.stack(Vs_array, axis=0) + + return smoothed_means, smoothed_covariances diff --git a/eks/singleview_smoother.py b/eks/singleview_smoother.py deleted file mode 100644 index 44868d8..0000000 --- a/eks/singleview_smoother.py +++ /dev/null @@ -1,148 +0,0 @@ -import numpy as np -import pandas as pd -from eks.utils import make_dlc_pandas_index -from eks.core import ensemble, eks_zscore, optimize_smoothing_params, \ - filter_smooth_nll - - -# ----------------------- -# funcs for single-view -# ----------------------- -def ensemble_kalman_smoother_single_view( - markers_list, keypoint_ensemble, smooth_param, ensembling_mode='median', - zscore_threshold=2, verbose=False): - """ Use an identity observation matrix and smoothes by adjusting the smoothing parameter in the - state-covariance matrix. - - Parameters - ---------- - markers_list : list of list of pd.DataFrames - each list element is a list of dataframe predictions from one ensemble member. - keypoint_ensemble : str - the name of the keypoint to be ensembled and smoothed - smooth_param : float - ranges from .01-20 (smaller values = more smoothing) - ensembling_mode: - the function used for ensembling ('mean', 'median', or 'confidence_weighted_mean') - zscore_threshold: - Minimum std threshold to reduce the effect of low ensemble std on a zscore metric - (default 2). - verbose: bool - If True, progress will be printed for the user. 
- Returns - ------- - - Returns - ------- - dict - keypoint_df: dataframe containing smoothed markers for one keypoint; same format as input - dataframes - """ - - # -------------------------------------------------------------- - # interpolate right cam markers to left cam timestamps - # -------------------------------------------------------------- - keys = [keypoint_ensemble + '_x', keypoint_ensemble + '_y'] - x_key = keys[0] - y_key = keys[1] - - # compute ensemble median - ensemble_preds, ensemble_vars, ensemble_stacks, keypoints_mean_dict, keypoints_var_dict, \ - keypoints_stack_dict = ensemble(markers_list, keys, mode=ensembling_mode) - mean_x_obs = np.nanmean(keypoints_mean_dict[x_key]) - mean_y_obs = np.nanmean(keypoints_mean_dict[y_key]) - x_t_obs, y_t_obs = \ - keypoints_mean_dict[x_key] - mean_x_obs, keypoints_mean_dict[y_key] - mean_y_obs - # z_t_obs = np.vstack((x_t_obs, y_t_obs)) # latent variables - true x and y - - # ------ Set values for kalman filter ------ - m0 = np.asarray([0.0, 0.0]) # initial state: mean - S0 = np.asarray([[np.nanvar(x_t_obs), 0.0], [0.0 , np.nanvar(y_t_obs)]]) # diagonal: var - - A = np.asarray([[1.0, 0], [0, 1.0]]) # state-transition matrix, - cov_matrix = np.asarray([[1, 0], [0, 1]]) # state covariance matrix; smaller = more smoothing - C = np.asarray([[1, 0], [0, 1]]) # Measurement function - R = np.eye(2) # placeholder diagonal matrix for ensemble variance - - scaled_ensemble_preds = ensemble_preds.copy() - scaled_ensemble_preds[:, 0] -= mean_x_obs - scaled_ensemble_preds[:, 1] -= mean_y_obs - - y_obs = scaled_ensemble_preds - - ''' - if verbose: - print(f"filtering {keypoint_ensemble}...") - mf, Vf, S = filtering_pass(y_obs, m0, S0, C, R, A, Q, ensemble_vars) - if verbose: - print("done filtering") - y_m_filt = np.dot(C, mf.T).T - y_v_filt = np.swapaxes(np.dot(C, np.dot(Vf, C.T)), 0, 1) - - # Do the smoothing step - if verbose: - print(f"smoothing {keypoint_ensemble}...") - ms, Vs, _ = smooth_backward(y_obs, mf, Vf, S, A, Q, C) - if verbose: - print("done smoothing") - # compute NLL - nll = compute_nll_2(y_obs, mf, S, C) - nll_values = compute_nll_2_steps(y_obs, mf, S, C) - ''' - - - # Call functions from ensemble_kalman to optimize smooth_param before filtering and smoothing - if smooth_param is None: - smooth_param_final = \ - optimize_smoothing_params(cov_matrix, y_obs, m0, S0, C, A, R, ensemble_vars) - else: - smooth_param_final = smooth_param - ms, Vs, nll, nll_values = \ - filter_smooth_nll(cov_matrix, smooth_param_final, y_obs, m0, S0, C, A, R, ensemble_vars) - print(f"NLL is {nll} for {keypoint_ensemble}, smooth_param={smooth_param_final}") - - # Smoothed posterior over y - y_m_smooth = np.dot(C, ms.T).T - y_v_smooth = np.swapaxes(np.dot(C, np.dot(Vs, C.T)), 0, 1) - - # compute zscore for EKS to see how it deviates from the ensemble - eks_predictions = y_m_smooth.copy() - eks_predictions = \ - np.asarray([eks_predictions.T[0] + mean_x_obs, eks_predictions.T[1] + mean_y_obs]).T - zscore = \ - eks_zscore(eks_predictions, ensemble_preds, ensemble_vars, - min_ensemble_std=zscore_threshold) - - # -------------------------------------- - # final cleanup - # -------------------------------------- - pdindex = make_dlc_pandas_index([keypoint_ensemble], - labels=["x", "y", "likelihood", "x_var", "y_var", "zscore"]) - var = np.empty(y_m_smooth.T[0].shape) - var[:] = np.nan - pred_arr = np.vstack([ - y_m_smooth.T[0] + mean_x_obs, - y_m_smooth.T[1] + mean_y_obs, - var, - y_v_smooth[:, 0, 0], - y_v_smooth[:, 1, 1], - zscore, - ]).T - df = 
pd.DataFrame(pred_arr, columns=pdindex) - return {keypoint_ensemble + '_df': df}, smooth_param_final, nll_values - - -''' -Plotting NLL traces (paste in before final cleanup) - # Plot nll values against time - plt.plot(range(len(nll_values)), nll_values) - plt.xlabel('Time Step') - plt.ylabel('Negative Log Likelihood (nll)') - plt.title(f'Negative Log Likelihood vs Time for IBL Pupil s={smooth_param}') - plt.grid(True) - - # Save the plot as a PDF file - plt.savefig('nll_plot.pdf') - - plt.show() -''' diff --git a/eks/slp_test.py b/eks/slp_test.py deleted file mode 100644 index d888080..0000000 --- a/eks/slp_test.py +++ /dev/null @@ -1,39 +0,0 @@ -from sleap_io.io.slp import read_labels -import os - -# python scripts/singlecam_example.py --input-dir ./data/fish-slp --data-type slp --bodypart-list chin mouth head middle tail - -base_dir = "data/fish-slp/" -filenames = [ - "4fish.v009.slp.240422_114719.predictions.slp" - -] - -''' - "4fish.v009.slp.240422_114719.predictions.slp" - "4fish.v009.slp.240422_154713.predictions.slp" - "4fish.v009.slp.240422_154713.predictions.slp", - "4fish.v009.slp.240422_182825.predictions.slp", - "4fish.v009.slp.240423_113502.predictions.slp", - "4fish.v009.slp.240423_141211.predictions.slp", -''' - -for f, filename in enumerate(filenames): - filepath = os.path.join(base_dir, filename) - labels = read_labels(filepath) - print(labels[16][1][4].x) -# labels.labeled_frames[frame].instances[animal#][bodypart#].x) - - ''' - nodes = labels.skeletons[0].nodes - print(nodes) - keypoint_names = [] - for node in enumerate(nodes): - keypoint_name = node[1].name - keypoint_names.append(keypoint_name) - print(keypoint_names) - ''' - # labeled_frame = labels[0] - # instance = labeled_frame[0] - # print(f'Instance 1: {instance}') - # print(instance[0]) \ No newline at end of file diff --git a/eks/utils.py b/eks/utils.py index 513ea27..46b0012 100644 --- a/eks/utils.py +++ b/eks/utils.py @@ -1,4 +1,6 @@ import os + +import matplotlib.pyplot as plt import numpy as np import pandas as pd from sleap_io.io.slp import read_labels @@ -16,71 +18,71 @@ def convert_lp_dlc(df_lp, keypoint_names, model_name=None): df_dlc = {} for feat in keypoint_names: for feat2 in ['x', 'y', 'likelihood']: - if model_name is None: - df_dlc[f'{feat}_{feat2}'] = df_lp.loc[:, (feat, feat2)] - else: - df_dlc[f'{feat}_{feat2}'] = df_lp.loc[:, (model_name, feat, feat2)] - df_dlc = pd.DataFrame(df_dlc, index=df_lp.index) + try: + if model_name is None: + col_tuple = (feat, feat2) + else: + col_tuple = (model_name, feat, feat2) + + # Skip columns with any unnamed level + if any(level.startswith('Unnamed') for level in col_tuple if + isinstance(level, str)): + continue + + df_dlc[f'{feat}_{feat2}'] = df_lp.loc[:, col_tuple] + except KeyError: + # If the specified column does not exist, skip it + continue + df_dlc = pd.DataFrame(df_dlc, index=df_lp.index) return df_dlc def convert_slp_dlc(base_dir, slp_file): - print(f'Reading {base_dir}/{slp_file}') # Read data from .slp file filepath = os.path.join(base_dir, slp_file) labels = read_labels(filepath) - # Determine the maximum number of instances + # Determine the maximum number of instances and keypoints max_instances = len(labels[0].instances) + keypoint_names = [node.name for node in labels[0].instances[0].points.keys()] + print(keypoint_names) + num_keypoints = len(keypoint_names) + + # Initialize a NumPy array to store the data + num_frames = len(labels.labeled_frames) + data = np.zeros((num_frames, max_instances, num_keypoints, 3)) # 3 for x, y, likelihood 
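+    # data[i, j, k] holds (x, y, likelihood) for frame i, instance j, keypoint k;
+    # NaN coordinates are zero-filled below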
- data = [] # List to store data for DataFrame + # Fill the NumPy array with data for i, labeled_frame in enumerate(labels.labeled_frames): - frame_data = {} # Dictionary to store data for current frame for j, instance in enumerate(labeled_frame.instances): - # Check if the instance number exceeds the maximum expected if j >= max_instances: break - - for keypoint_node in instance.points.keys(): - # Extract the name from keypoint_node - keypoint_name = keypoint_node.name - # Extract x, y, and likelihood from the PredictedPoint object + for k, keypoint_node in enumerate(instance.points.keys()): point = instance.points[keypoint_node] - - # Ensure x and y are floats, handle blank entries by converting to 0 - x = point.x # if not np.isnan(point.x) else 0 - y = point.y # if not np.isnan(point.y) else 0 - likelihood = point.score + 1e-6 - - # Construct the column name based on instance number and keypoint name - column_name_x = f"{j + 1}_{keypoint_name}_x" - column_name_y = f"{j + 1}_{keypoint_name}_y" - column_name_likelihood = f"{j + 1}_{keypoint_name}_likelihood" - - # Add data to frame_data dictionary - frame_data[column_name_x] = x - frame_data[column_name_y] = y - frame_data[column_name_likelihood] = likelihood - - # Append frame_data to the data list - data.append(frame_data) - - # Create DataFrame from the list of frame data - df = pd.DataFrame(data) + data[i, j, k, 0] = point.x if not np.isnan(point.x) else 0 + data[i, j, k, 1] = point.y if not np.isnan(point.y) else 0 + data[i, j, k, 2] = point.score + 1e-6 + + # Reshape data to 2D array for DataFrame creation + reshaped_data = data.reshape(num_frames, -1) + columns = [] + for j in range(max_instances): + for keypoint_name in keypoint_names: + columns.append(f"{j + 1}_{keypoint_name}_x") + columns.append(f"{j + 1}_{keypoint_name}_y") + columns.append(f"{j + 1}_{keypoint_name}_likelihood") + + # Create DataFrame from the reshaped data + df = pd.DataFrame(reshaped_data, columns=columns) df.to_csv(f'{slp_file}.csv', index=False) - print(f"DataFrame successfully converted to CSV: input.csv") + print(f"File read. 
See read-in data at {slp_file}.csv") return df -# --------------------------------------------- -# Loading + Formatting CSV<->DataFrame -# --------------------------------------------- - - def format_data(input_dir, data_type): input_files = os.listdir(input_dir) - markers_list = [] + input_dfs_list = [] # Extracting markers from data # Applies correct format conversion and stores each file's markers in a list for input_file in input_files: @@ -89,30 +91,34 @@ if data_type == 'slp': if not input_file.endswith('.slp'): continue markers_curr = convert_slp_dlc(input_dir, input_file) + keypoint_names = [c[1] for c in markers_curr.columns[::3]] markers_curr_fmt = markers_curr elif data_type == 'lp' or data_type == 'dlc': if not input_file.endswith('csv'): continue - markers_curr = pd.read_csv(os.path.join(input_dir, input_file), header=[0, 1, 2], index_col=0) + markers_curr = pd.read_csv( + os.path.join(input_dir, input_file), header=[0, 1, 2], index_col=0) keypoint_names = [c[1] for c in markers_curr.columns[::3]] model_name = markers_curr.columns[0][0] if data_type == 'lp': - markers_curr_fmt = convert_lp_dlc(markers_curr, keypoint_names, model_name=model_name) + markers_curr_fmt = convert_lp_dlc( + markers_curr, keypoint_names, model_name=model_name) else: markers_curr_fmt = markers_curr - markers_list.append(markers_curr_fmt) + # markers_curr_fmt.to_csv('fmt_input.csv', index=False) + input_dfs_list.append(markers_curr_fmt) - if len(markers_list) == 0: + if len(input_dfs_list) == 0: raise FileNotFoundError(f'No marker input files found in {input_dir}') - markers_eks = make_output_dataframe(markers_curr) + output_df = make_output_dataframe(markers_curr) # returns both the formatted marker data and the empty dataframe for EKS output - return markers_list, markers_eks + return input_dfs_list, output_df, keypoint_names -# Making empty DataFrame for EKS output def make_output_dataframe(markers_curr): + """Makes an empty DataFrame for EKS output.""" markers_eks = markers_curr.copy() # Check if the columns Index is a MultiIndex @@ -148,11 +154,12 @@ def make_output_dataframe(markers_curr): markers_eks[col].values[:] = np.nan # Write DataFrame to CSV - output_csv = 'output_dataframe.csv' - dataframe_to_csv(markers_eks, output_csv) + # output_csv = 'output_dataframe.csv' + # dataframe_to_csv(markers_eks, output_csv) return markers_eks + def dataframe_to_csv(df, filename): """ Converts a DataFrame to a CSV file.
@@ -166,14 +173,97 @@ def dataframe_to_csv(df, filename): """ try: df.to_csv(filename, index=False) - print(f"DataFrame successfully converted to CSV: {filename}") except Exception as e: print("Error:", e) -def populate_output_dataframe(keypoint_df, keypoint_ensemble, markers_eks): +def populate_output_dataframe(keypoint_df, keypoint_ensemble, output_df, + key_suffix=''): # key_suffix only required for multi-camera setups for coord in ['x', 'y', 'zscore']: src_cols = ('ensemble-kalman_tracker', f'{keypoint_ensemble}', coord) - dst_cols = ('ensemble-kalman_tracker', f'{keypoint_ensemble}', coord) - markers_eks.loc[:, dst_cols] = keypoint_df.loc[:, src_cols] - return markers_eks + dst_cols = ('ensemble-kalman_tracker', f'{keypoint_ensemble}' + key_suffix, coord) + output_df.loc[:, dst_cols] = keypoint_df.loc[:, src_cols] + + return output_df + + +def plot_results(output_df, input_dfs_list, + key, s_final, nll_values, idxs, save_dir, smoother_type): + if nll_values is None: + fig, axes = plt.subplots(4, 1, figsize=(9, 10)) + else: + fig, axes = plt.subplots(5, 1) + + for ax, coord in zip(axes, ['x', 'y', 'likelihood', 'zscore']): + # Rename axes label for likelihood and zscore coordinates + if coord == 'likelihood': + ylabel = 'model likelihoods' + elif coord == 'zscore': + ylabel = 'EKS disagreement' + else: + ylabel = coord + + # plot individual models + ax.set_ylabel(ylabel, fontsize=12) + if coord == 'zscore': + ax.plot(output_df.loc[slice(*idxs), ('ensemble-kalman_tracker', key, coord)], + color='k', linewidth=2) + ax.set_xlabel('Time (frames)', fontsize=12) + continue + for m, markers_curr in enumerate(input_dfs_list): + ax.plot( + markers_curr.loc[slice(*idxs), key + f'_{coord}'], color=[0.5, 0.5, 0.5], + label='Individual models' if m == 0 else None, + ) + # plot eks + if coord == 'likelihood': + continue + ax.plot( + output_df.loc[slice(*idxs), ('ensemble-kalman_tracker', key, coord)], + color='k', linewidth=2, label='EKS', + ) + if coord == 'x': + ax.legend() + + # Plot nll_values against the time axis + if nll_values is not None: + nll_values_subset = nll_values[idxs[0]:idxs[1]] + axes[-1].plot(range(*idxs), nll_values_subset, color='k', linewidth=2) + axes[-1].set_ylabel('EKS NLL', fontsize=12) + + plt.suptitle(f'EKS results for {key}, smoothing = {s_final}', fontsize=14) + plt.tight_layout() + save_file = os.path.join(save_dir, + f'{smoother_type}_{key}.pdf') + plt.savefig(save_file) + plt.close() + print(f'see example EKS output at {save_file}') + + +def crop_frames(y, s_frames): + """ Crops frames as specified by s_frames to be used for auto-tuning s.""" + # Create an empty list to store arrays + result = [] + + for frame in s_frames: + # Unpack the frame, setting defaults for empty start or end + start, end = frame + # Default start to 0 if not specified (and adjust for zero indexing) + start = start - 1 if start is not None else 0 + # Default end to the length of y if not specified + end = end if end is not None else len(y) + + # Cap the indices within valid range + start = max(0, start) + end = min(len(y), end) + + # Validate the resulting range + if start >= end: + raise ValueError(f"Frame range ({start + 1}, {end}) " + f"is invalid for data of length {len(y)}.") + + # Use numpy slicing to preserve the data structure + result.append(y[start:end]) + + # Concatenate all slices into a single numpy array + return np.concatenate(result) diff --git a/scripts/EKSNLLPlot.py b/scripts/EKSNLLPlot.py deleted file mode 100644 index caab12d..0000000 --- a/scripts/EKSNLLPlot.py +++
/dev/null @@ -1,54 +0,0 @@ -import subprocess -import matplotlib.pyplot as plt - -''' -Generates a plot of EKS Negative Log Likelihood results at different Smoothing Parameters -''' - -# Smooth params to try: -smooth_params = [0.01, 0.1, 1, 5, 10, 100, 1000] # parameters to be tested - -# Collect output nll lists as a list of lists -nll_values_list = [] - -print('Starting runs') -for param in smooth_params: - # Run existing Python script with different parameters - result = subprocess.run([ - 'python', 'scripts/multicam_example.py', - '--csv-dir', './data/mirror-mouse', - '--bodypart-list', 'paw1LH', 'paw2LF', 'paw3RF', 'paw4RH', - '--camera-names', 'top', 'bot', '--s', str(param)], - capture_output=True, - text=True - ) - print(f'Run successful at smooth_param {param}') - - # Extract nll_values from result - output_lines = result.stdout.strip().split('\n') - nll_values = [] - for line in output_lines: - if line.startswith('NLL is'): - # Split the line to extract the NLL value - nll = float(line.split("is")[1].split()[0]) - nll_values.append(nll) - - # Store nll_values in the list - nll_values_list.append(nll_values) - -# Plot results for each list of nll_values -for i, nll_values in enumerate(nll_values_list): - # Create x-axis values evenly spaced - x_values = [i] * len(nll_values) - plt.plot(x_values, nll_values, marker='o', label=f'Smoothing Param: {smooth_params[i]}') - -plt.xlabel('Smoothing Parameter') -plt.ylabel('NLL') -plt.xticks(range(len(smooth_params)), smooth_params) # Set x-axis ticks to smooth_params values -plt.title('mirror-mouse multi-cam EKS NLL vs Smoothing Parameter') -plt.grid(True) - -# Save plot as PDF -plt.savefig('nll_vs_smoothing_param.pdf') -print('PDF MADE') -plt.show() diff --git a/scripts/multiview_paw_example.py b/scripts/ibl_paw_multiview_example.py similarity index 84% rename from scripts/multiview_paw_example.py rename to scripts/ibl_paw_multiview_example.py index 6f5d108..cec548a 100644 --- a/scripts/multiview_paw_example.py +++ b/scripts/ibl_paw_multiview_example.py @@ -1,39 +1,35 @@ """Example script for ibl-paw dataset.""" +import os import matplotlib.pyplot as plt import numpy as np -import os import pandas as pd +from eks.command_line_args import handle_io, handle_parse_args +from eks.ibl_paw_multiview_smoother import ensemble_kalman_smoother_ibl_paw from eks.utils import convert_lp_dlc -from eks.multiview_pca_smoother import ensemble_kalman_smoother_paw_asynchronous -from general_scripting import handle_io, handle_parse_args - -# collect user-provided args -args = handle_parse_args('paw') -csv_dir = os.path.abspath(args.input_dir) -save_dir = args.save_dir +# Collect User-Provided Args +smoother_type = 'paw' +args = handle_parse_args(smoother_type) +input_dir = os.path.abspath(args.input_dir) +data_type = args.data_type # Note: LP and DLC are .csv, SLP is .slp +save_dir = handle_io(input_dir, args.save_dir) # defaults to outputs\ +save_filename = args.save_filename s = args.s quantile_keep_pca = args.quantile_keep_pca - - -# --------------------------------------------- -# run EKS algorithm -# --------------------------------------------- - -# handle I/O -save_dir = handle_io(csv_dir, save_dir) +s_frames = args.s_frames # frames to be used for automatic optimization (only if no --s flag) # load files and put them in correct format markers_list_left = [] markers_list_right = [] timestamps_left = None timestamps_right = None -filenames = os.listdir(csv_dir) +filenames = os.listdir(input_dir) for filename in filenames: if 'timestamps' not in 
filename: - markers_curr = pd.read_csv(os.path.join(csv_dir, filename), header=[0, 1, 2], index_col=0) + markers_curr = pd.read_csv( + os.path.join(input_dir, filename), header=[0, 1, 2], index_col=0) keypoint_names = [c[1] for c in markers_curr.columns[::3]] model_name = markers_curr.columns[0][0] markers_curr_fmt = convert_lp_dlc(markers_curr, keypoint_names, model_name=model_name) @@ -53,9 +49,9 @@ markers_list_right.append(markers_curr_fmt) else: if 'left' in filename: - timestamps_left = np.load(os.path.join(csv_dir, filename)) + timestamps_left = np.load(os.path.join(input_dir, filename)) else: - timestamps_right = np.load(os.path.join(csv_dir, filename)) + timestamps_right = np.load(os.path.join(input_dir, filename)) # file checks if timestamps_left is None or timestamps_right is None: @@ -67,7 +63,7 @@ # run eks df_dicts, markers_list_left_cam, markers_list_right_cam = \ - ensemble_kalman_smoother_paw_asynchronous( + ensemble_kalman_smoother_ibl_paw( markers_list_left_cam=markers_list_left, markers_list_right_cam=markers_list_right, timestamps_left_cam=timestamps_left, diff --git a/scripts/ibl_pupil_example.py b/scripts/ibl_pupil_example.py new file mode 100644 index 0000000..b6848e7 --- /dev/null +++ b/scripts/ibl_pupil_example.py @@ -0,0 +1,53 @@ +"""Example script for ibl-pupil dataset.""" +import os + +from eks.command_line_args import handle_io, handle_parse_args +from eks.ibl_pupil_smoother import ensemble_kalman_smoother_ibl_pupil +from eks.utils import format_data, plot_results + +# Collect User-Provided Args +smoother_type = 'pupil' +args = handle_parse_args(smoother_type) +input_dir = os.path.abspath(args.input_dir) +data_type = args.data_type # Note: LP and DLC are .csv, SLP is .slp +save_dir = handle_io(input_dir, args.save_dir) # defaults to outputs\ +save_filename = args.save_filename +diameter_s = args.diameter_s # defaults to automatic optimization +com_s = args.com_s # defaults to automatic optimization +s_frames = args.s_frames # frames to be used for automatic optimization (only if no --s flag) + +# Load and format input files and prepare an empty DataFrame for output. 
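As reworked in `eks/utils.py` above, `format_data` now returns three values rather than two. A hedged usage sketch against the repo's `ibl-pupil` example data (assuming LP-format CSVs in that directory):

```python
from eks.utils import format_data

# Returns (list of per-model input DataFrames, empty EKS output DataFrame, keypoint names)
input_dfs_list, output_df, keypoint_names = format_data('./data/ibl-pupil', 'lp')
print(f'{len(input_dfs_list)} model files read; keypoints: {keypoint_names}')
```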
+input_dfs_list, output_df, keypoint_names = format_data(input_dir, data_type) + +# run eks +df_dicts, smooth_params, nll_values = ensemble_kalman_smoother_ibl_pupil( + markers_list=input_dfs_list, + keypoint_names=keypoint_names, + tracker_name='ensemble-kalman_tracker', + smooth_params=[diameter_s, com_s], + s_frames=s_frames +) + +save_file = os.path.join(save_dir, 'kalman_smoothed_pupil_traces.csv') +print(f'saving smoothed predictions to {save_file}') +df_dicts['markers_df'].to_csv(save_file) + +save_file = os.path.join(save_dir, 'kalman_smoothed_latents.csv') +print(f'saving latents to {save_file}') +df_dicts['latents_df'].to_csv(save_file) + + +# --------------------------------------------- +# plot results +# --------------------------------------------- + +# plot results +plot_results(output_df=df_dicts['markers_df'], + input_dfs_list=input_dfs_list, + key=f'{keypoint_names[-1]}', + idxs=(0, 500), + s_final=(smooth_params[0], smooth_params[1]), + nll_values=nll_values, + save_dir=save_dir, + smoother_type=smoother_type + ) diff --git a/scripts/multicam_example.py b/scripts/multicam_example.py index fa423d6..290b6c2 100644 --- a/scripts/multicam_example.py +++ b/scripts/multicam_example.py @@ -1,105 +1,70 @@ """Example script for multi-camera datasets.""" - -import matplotlib.pyplot as plt import os -from eks.multiview_pca_smoother import ensemble_kalman_smoother_multi_cam -from general_scripting import handle_io, handle_parse_args -from eks.utils import format_data +from eks.command_line_args import handle_io, handle_parse_args +from eks.multicam_smoother import ensemble_kalman_smoother_multicam +from eks.utils import format_data, plot_results, populate_output_dataframe -# collect user-provided args +# Collect User-Provided Args smoother_type = 'multicam' args = handle_parse_args(smoother_type) - input_dir = os.path.abspath(args.input_dir) - -# Note: LP and DLC are .csv, SLP is .slp -data_type = args.data_type - -# Find save directory if specified, otherwise defaults to outputs\ -save_dir = handle_io(input_dir, args.save_dir) - +data_type = args.data_type # Note: LP and DLC are .csv, SLP is .slp +save_dir = handle_io(input_dir, args.save_dir) # defaults to outputs\ +save_filename = args.save_filename bodypart_list = args.bodypart_list +s = args.s # defaults to automatic optimization +s_frames = args.s_frames # frames to be used for automatic optimization (only if no --s flag) +blocks = args.blocks camera_names = args.camera_names -num_cameras = len(camera_names) quantile_keep_pca = args.quantile_keep_pca -s = args.s # optional, defaults to automatic optimization # Load and format input files and prepare an empty DataFrame for output. 
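The `s_frames` ranges collected from the command line feed `crop_frames` (added to `eks/utils.py` above) when the smoothing parameter is auto-tuned. A small sketch of how its one-indexed, optionally open-ended ranges behave on toy data:

```python
import numpy as np
from eks.utils import crop_frames

y = np.arange(10)
# (1, 4) keeps one-indexed frames 1 through 4; (8, None) runs to the end of y
print(crop_frames(y, [(1, 4), (8, None)]))  # [0 1 2 3 7 8 9]
```

Open-ended ranges make it easy to tune on a subset, such as the first few hundred frames, without enumerating indices by hand.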
-# markers_list : list of input DataFrames -# markers_eks : empty DataFrame for EKS output -markers_list, markers_eks = format_data(input_dir, data_type) +input_dfs_list, output_df, keypoint_names = format_data(input_dir, data_type) +if bodypart_list is None: + bodypart_list = keypoint_names +print(f'Input data has been read in for the following keypoints:\n{bodypart_list}') # loop over keypoints; apply eks to each individually +# Note: all camera views must be stored in the same csv file for keypoint_ensemble in bodypart_list: - # this structure assumes all camera views are stored in the same csv file - # here we separate body part predictions by camera view - marker_list_by_cam = [[] for _ in range(num_cameras)] - for markers_curr in markers_list: + # Separate body part predictions by camera view + marker_list_by_cam = [[] for _ in range(len(camera_names))] + for markers_curr in input_dfs_list: for c, camera_name in enumerate(camera_names): non_likelihood_keys = [ key for key in markers_curr.keys() if camera_names[c] in key and keypoint_ensemble in key ] marker_list_by_cam[c].append(markers_curr[non_likelihood_keys]) + # run eks - cameras_df, s_final, nll_values = ensemble_kalman_smoother_multi_cam( + cameras_df_dict, s_final, nll_values = ensemble_kalman_smoother_multicam( markers_list_cameras=marker_list_by_cam, keypoint_ensemble=keypoint_ensemble, smooth_param=s, quantile_keep_pca=quantile_keep_pca, camera_names=camera_names, + s_frames=s_frames ) + # put results into new dataframe for camera in camera_names: - df_tmp = cameras_df[f'{camera}_df'] - for coord in ['x', 'y', 'zscore']: - src_cols = ('ensemble-kalman_tracker', f'{keypoint_ensemble}', coord) - dst_cols = ('ensemble-kalman_tracker', f'{keypoint_ensemble}_{camera}', coord) - markers_eks.loc[:, dst_cols] = df_tmp.loc[:, src_cols] + cameras_df = cameras_df_dict[f'{camera}_df'] + populate_output_dataframe(cameras_df, keypoint_ensemble, output_df, + key_suffix=f'_{camera}') # save eks results -markers_eks.to_csv(os.path.join(save_dir, 'eks.csv')) +save_filename = save_filename or f'{smoother_type}_{s_final}.csv' +output_df.to_csv(os.path.join(save_dir, save_filename)) -# --------------------------------------------- # plot results -# --------------------------------------------- - -# select example keypoint from example camera view -kp = bodypart_list[0] -cam = camera_names[0] -idxs = (0, 500) - -fig, axes = plt.subplots(4, 1, figsize=(9, 6)) - -for ax, coord in zip(axes, ['x', 'y', 'likelihood', 'zscore']): - # plot individual models - ax.set_ylabel(coord, fontsize=12) - if coord == 'zscore': - ax.plot( - markers_eks.loc[slice(*idxs), ('ensemble-kalman_tracker', f'{kp}_{cam}', coord)], - color=[0.5, 0.5, 0.5]) - ax.set_xlabel('Time (frames)', fontsize=12) - continue - for m, markers_curr in enumerate(markers_list): - ax.plot( - markers_curr.loc[slice(*idxs), f'{kp}_{cam}_{coord}'], color=[0.5, 0.5, 0.5], - label='Individual models' if m == 0 else None, - ) - # plot eks - if coord == 'likelihood': - continue - ax.plot( - markers_eks.loc[slice(*idxs), ('ensemble-kalman_tracker', f'{kp}_{cam}', coord)], - color='k', linewidth=2, label='EKS', - ) - if coord == 'x': - ax.legend() - -plt.suptitle(f'EKS results for {kp} ({cam} view)', fontsize=14) -plt.tight_layout() - -save_file = os.path.join(save_dir, 'example_multicam_eks_result.pdf') -plt.savefig(save_file) -plt.close() -print(f'see example EKS output at {save_file}') +plot_results(output_df=output_df, + input_dfs_list=input_dfs_list, + 
key=f'{bodypart_list[-1]}_{camera_names[0]}', + idxs=(0, 500), + s_final=s_final, + nll_values=nll_values, + save_dir=save_dir, + smoother_type=smoother_type + ) diff --git a/scripts/pupil_example.py b/scripts/pupil_example.py deleted file mode 100644 index b885d09..0000000 --- a/scripts/pupil_example.py +++ /dev/null @@ -1,105 +0,0 @@ -"""Example script for ibl-pupil dataset.""" - -import matplotlib.pyplot as plt -import numpy as np -import os -import pandas as pd - -from eks.utils import convert_lp_dlc -from eks.pupil_smoother import ensemble_kalman_smoother_pupil -from general_scripting import handle_io, handle_parse_args - - -# collect user-provided args -args = handle_parse_args('pupil') -csv_dir = os.path.abspath(args.input_dir) -save_dir = args.save_dir - - -# --------------------------------------------- -# run EKS algorithm -# --------------------------------------------- - -# handle I/O -save_dir = handle_io(csv_dir, save_dir) - -# load files and put them in correct format -csv_files = os.listdir(csv_dir) -markers_list = [] -for csv_file in csv_files: - if not csv_file.endswith('csv'): - continue - markers_curr = pd.read_csv(os.path.join(csv_dir, csv_file), header=[0, 1, 2], index_col=0) - keypoint_names = [c[1] for c in markers_curr.columns[::3]] - model_name = markers_curr.columns[0][0] - markers_curr_fmt = convert_lp_dlc(markers_curr, keypoint_names, model_name=model_name) - markers_list.append(markers_curr_fmt) -if len(markers_list) == 0: - raise FileNotFoundError(f'No marker csv files found in {csv_dir}') - -# parameters hand-picked for smoothing purposes (diameter_s, com_s, com_s) -state_transition_matrix = np.asarray([ - [args.diameter_s, 0, 0], - [0, args.com_s, 0], - [0, 0, args.com_s] -]) -print(f'Smoothing matrix: {state_transition_matrix}') - -# run eks -df_dicts = ensemble_kalman_smoother_pupil( - markers_list=markers_list, - keypoint_names=keypoint_names, - tracker_name='ensemble-kalman_tracker', - state_transition_matrix=state_transition_matrix, -) - -save_file = os.path.join(save_dir, 'kalman_smoothed_pupil_traces.csv') -print(f'saving smoothed predictions to {save_file }') -df_dicts['markers_df'].to_csv(save_file) - -save_file = os.path.join(save_dir, 'kalman_smoothed_latents.csv') -print(f'saving latents to {save_file}') -df_dicts['latents_df'].to_csv(save_file) - - -# --------------------------------------------- -# plot results -# --------------------------------------------- - -# select example keypoint -kp = keypoint_names[0] -idxs = (0, 500) - -fig, axes = plt.subplots(4, 1, figsize=(9, 8)) - -for ax, coord in zip(axes, ['x', 'y', 'likelihood', 'zscore']): - # plot individual models - ax.set_ylabel(coord, fontsize=12) - if coord == 'zscore': - ax.plot( - df_dicts['markers_df'].loc[slice(*idxs), ('ensemble-kalman_tracker', f'{kp}', coord)], - color='k', linewidth=2) - ax.set_xlabel('Time (frames)', fontsize=12) - continue - for m, markers_curr in enumerate(markers_list): - ax.plot( - markers_curr.loc[slice(*idxs), f'{kp}_{coord}'], color=[0.5, 0.5, 0.5], - label='Individual models' if m == 0 else None, - ) - # plot eks - if coord == 'likelihood': - continue - ax.plot( - df_dicts['markers_df'].loc[slice(*idxs), ('ensemble-kalman_tracker', kp, coord)], - color='k', linewidth=2, label='EKS', - ) - if coord == 'x': - ax.legend() - -plt.suptitle(f'EKS results for {kp}', fontsize=14) -plt.tight_layout() - -save_file = os.path.join(save_dir, 'example_pupil_eks_result.pdf') -plt.savefig(save_file) -plt.close() -print(f'see example EKS output at {save_file}') diff 
--git a/scripts/singlecam_example.py b/scripts/singlecam_example.py index 34ac68e..7c5ec4f 100644 --- a/scripts/singlecam_example.py +++ b/scripts/singlecam_example.py @@ -1,112 +1,71 @@ """Example script for single-camera datasets.""" - -import matplotlib.pyplot as plt import os -from general_scripting import handle_io, handle_parse_args -from eks.utils import format_data, populate_output_dataframe -from eks.singleview_smoother import ensemble_kalman_smoother_single_view +import numpy as np +from eks.command_line_args import handle_io, handle_parse_args +from eks.singlecam_smoother import ensemble_kalman_smoother_singlecam +from eks.utils import format_data, plot_results, populate_output_dataframe # Collect User-Provided Args smoother_type = 'singlecam' args = handle_parse_args(smoother_type) - input_dir = os.path.abspath(args.input_dir) - -# Note: LP and DLC are .csv, SLP is .slp -data_type = args.data_type - -# Find save directory if specified, otherwise defaults to outputs\ -save_dir = handle_io(input_dir, args.save_dir) +data_type = args.data_type # Note: LP and DLC are .csv, SLP is .slp +save_dir = handle_io(input_dir, args.save_dir) # defaults to outputs\ save_filename = args.save_filename - bodypart_list = args.bodypart_list -s = args.s # optional, defaults to automatic optimization - -# Load and format input files and prepare an empty DataFrame for output. -input_dfs_list, output_df = format_data(args.input_dir, data_type) - - -# --------------------------------------------- -# Run EKS Algorithm -# --------------------------------------------- - -# loop over keypoints; apply eks to each individually -for keypoint in bodypart_list: - # run eks - keypoint_df_dict, s_final, nll_values = ensemble_kalman_smoother_single_view( - input_dfs_list, - keypoint, - s, - ) - keypoint_df = keypoint_df_dict[keypoint + '_df'] - - # put results into new dataframe - output_df = populate_output_dataframe(keypoint_df, keypoint, output_df) - output_df.to_csv('populated_output.csv', index=False) - print(f"DataFrame successfully converted to CSV") -# save optimized smoothing param for plot title +s = args.s # defaults to automatic optimization +s_frames = args.s_frames # frames to be used for automatic optimization (only if no --s flag) +blocks = args.blocks -# save eks results -save_filename = save_filename or f'{smoother_type}.csv' # use type and s if no user input -output_df.to_csv(os.path.join(save_dir, save_filename)) - -# --------------------------------------------- -# plot results -# --------------------------------------------- - -# select example keypoint -kp = bodypart_list[-1] -idxs = (0, 1990) - -# crop NLL values -# nll_values_subset = nll_values[idxs[0]:idxs[1]] - -fig, axes = plt.subplots(5, 1, figsize=(9, 10)) - -for ax, coord in zip(axes, ['x', 'y', 'likelihood', 'zscore']): - # Rename axes label for likelihood and zscore coordinates - if coord == 'likelihood': - ylabel = 'model likelihoods' - elif coord == 'zscore': - ylabel = 'EKS disagreement' - else: - ylabel = coord - - # plot individual models - ax.set_ylabel(ylabel, fontsize=12) - if coord == 'zscore': - ax.plot( - output_df.loc[slice(*idxs), ('ensemble-kalman_tracker', f'{kp}', coord)], - color='k', linewidth=2) - ax.set_xlabel('Time (frames)', fontsize=12) - continue - for m, markers_curr in enumerate(input_dfs_list): - ax.plot( - markers_curr.loc[slice(*idxs), f'{kp}_{coord}'], color=[0.5, 0.5, 0.5], - label='Individual models' if m == 0 else None, - ) - # plot eks - if coord == 'likelihood': - continue - ax.plot( - 
output_df.loc[slice(*idxs), ('ensemble-kalman_tracker', f'{kp}', coord)], - color='k', linewidth=2, label='EKS', - ) - if coord == 'x': - ax.legend() - - # Plot nll_values_subset against the time axis - # axes[-1].plot(range(*idxs), nll_values_subset, color='k', linewidth=2) - # axes[-1].set_ylabel('EKS NLL', fontsize=12) - - -plt.suptitle(f'EKS results for {kp}, smoothing = {s}', fontsize=14) -plt.tight_layout() - -save_file = os.path.join(save_dir, f'singlecam_s={s}.pdf') -plt.savefig(save_file) -plt.close() -print(f'see example EKS output at {save_file}') +# Load and format input files and prepare an empty DataFrame for output. +input_dfs, output_df, keypoint_names = format_data(args.input_dir, data_type) +if bodypart_list is None: + bodypart_list = keypoint_names +print(f'Input data has been read in for the following keypoints:\n{bodypart_list}') + +# Convert list of DataFrames to a 3D NumPy array +data_arrays = [df.to_numpy() for df in input_dfs] +markers_3d_array = np.stack(data_arrays, axis=0) + +# Map keypoint names to keys in input_dfs and crop markers_3d_array +keypoint_is = {} +keys = [] +for i, col in enumerate(input_dfs[0].columns): + keypoint_is[col] = i +for part in bodypart_list: + keys.append(keypoint_is[part + '_x']) + keys.append(keypoint_is[part + '_y']) + keys.append(keypoint_is[part + '_likelihood']) +key_cols = np.array(keys) +markers_3d_array = markers_3d_array[:, :, key_cols] + +# Call the smoother function +df_dicts, s_finals = ensemble_kalman_smoother_singlecam( + markers_3d_array, + bodypart_list, + s, + s_frames, + blocks +) + +keypoint_i = -1 # keypoint to be plotted +# Save eks results in new DataFrames and .csv output files +for k in range(len(bodypart_list)): + df = df_dicts[k][bodypart_list[k] + '_df'] + output_df = populate_output_dataframe(df, bodypart_list[k], output_df) + save_filename = save_filename or f'{smoother_type}_{s_finals[keypoint_i]}.csv' + output_df.to_csv(os.path.join(save_dir, save_filename)) +print("DataFrames successfully converted to CSV") +# Plot results +plot_results(output_df=output_df, + input_dfs_list=input_dfs, + key=f'{bodypart_list[keypoint_i]}', + idxs=(0, 500), + s_final=s_finals[keypoint_i], + nll_values=None, + save_dir=save_dir, + smoother_type=smoother_type + ) diff --git a/scripts/timing.py b/scripts/timing.py deleted file mode 100644 index 550255d..0000000 --- a/scripts/timing.py +++ /dev/null @@ -1,19 +0,0 @@ -import subprocess -import time - -# Time script for timing runs of EKS. -start_time = time.time() -time_version_1 = subprocess.check_output([ - 'python', 'scripts/multicam_example.py', - '--csv-dir', './data/mirror-mouse', - '--bodypart-list', 'paw1LH', 'paw2LF', 'paw3RF', 'paw4RH', - '--camera-names', 'top', 'bot'], - text=True -) -end_time = time.time() - -# Calculate the execution time -execution_time = end_time - start_time - -# Print the results -print("Execution time:", execution_time) diff --git a/setup.py b/setup.py index 16fb509..4888322 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ -from setuptools import setup from pathlib import Path +from setuptools import setup # add the README.md file to the long_description with open('README.md', 'r') as fh: @@ -33,6 +33,9 @@ def get_version(rel_path): 'scipy>=1.2.0', 'tqdm', 'typing', + 'sleap_io', + 'jax', + 'jaxlib', ] # additional requirements
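With `sleap_io`, `jax`, and `jaxlib` added to the install requirements, a quick sanity check for which backend JAX will use for the parallel-scan path (CPU unless a CUDA-enabled `jaxlib` is installed):

```python
import jax

# CPU-only installs show a single CpuDevice; CUDA installs list GPU devices
print(jax.devices())
print(jax.default_backend())  # 'cpu', or 'gpu' with a CUDA-enabled jaxlib
```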