Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

singleview EKS cleanup #22

Merged
merged 3 commits (source and target branch names were not captured in this extract)
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 30 additions & 28 deletions eks/core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from functools import partial
from collections import defaultdict

import jax
import jax.scipy as jsc
Expand Down Expand Up @@ -257,24 +256,29 @@ def jax_ensemble(
markers_3d_array: np.ndarray,
avg_mode: str = 'median',
var_mode: str = 'confidence_weighted_var',
):
) -> tuple:
"""
Computes ensemble median (or mean) and variance of a 3D array of DLC marker data using JAX.
Compute ensemble mean/median and variance of a 3D marker array using JAX.

Args:
markers_3d_array
markers_3d_array: shape (n_models, samples, 3 * n_keypoints); "3" is for x, y, likelihood
avg_mode
'median' | 'mean'
var_mode
'confidence_weighted_var' | 'var'

Returns:
ensemble_preds: np.ndarray
shape (n_timepoints, n_keypoints, n_coordinates).
ensembled predictions for each keypoint for each target
ensemble_vars: np.ndarray
shape (n_timepoints, n_keypoints, n_coordinates).
ensembled variances for each keypoint for each target
tuple:
ensemble_preds: np.ndarray
shape (n_timepoints, n_keypoints, n_coordinates).
ensembled predictions for each keypoint
ensemble_vars: np.ndarray
shape (n_timepoints, n_keypoints, n_coordinates).
ensembled variances for each keypoint
ensemble_likes: np.ndarray
shape (n_timepoints, n_keypoints, 1).
mean likelihood for each keypoint

"""
markers_3d_array = jnp.array(markers_3d_array) # Convert to JAX array
n_frames = markers_3d_array.shape[1]
Expand All @@ -283,6 +287,7 @@ def jax_ensemble(
# Initialize output structures
ensemble_preds = np.zeros((n_frames, n_keypoints, 2))
ensemble_vars = np.zeros((n_frames, n_keypoints, 2))
ensemble_likes = np.zeros((n_frames, n_keypoints, 1))

# Choose the appropriate JAX function based on the mode
if avg_mode == 'median':
Expand All @@ -300,9 +305,10 @@ def compute_stats(i):
avg_x = avg_func(data_x)
avg_y = avg_func(data_y)

conf_per_keypoint = jnp.sum(data_likelihood, axis=0)
mean_conf_per_keypoint = conf_per_keypoint / data_likelihood.shape[0]

if var_mode in ['conf_weighted_var', 'confidence_weighted_var']:
conf_per_keypoint = jnp.sum(data_likelihood, axis=0)
mean_conf_per_keypoint = conf_per_keypoint / data_likelihood.shape[0]
var_x = jnp.nanvar(data_x, axis=0) / mean_conf_per_keypoint
var_y = jnp.nanvar(data_y, axis=0) / mean_conf_per_keypoint
elif var_mode in ['var', 'variance']:
Expand All @@ -311,28 +317,25 @@ def compute_stats(i):
else:
raise ValueError(f"{var_mode} for variance computation not supported")

return avg_x, avg_y, var_x, var_y
return avg_x, avg_y, var_x, var_y, mean_conf_per_keypoint

compute_stats_jit = jax.jit(compute_stats)
stats = jax.vmap(compute_stats_jit)(jnp.arange(n_keypoints))

avg_x, avg_y, var_x, var_y = stats
avg_x, avg_y, var_x, var_y, likes = stats

keypoints_avg_dict = {}
for i in range(n_keypoints):
ensemble_preds[:, i, 0] = avg_x[i]
ensemble_preds[:, i, 1] = avg_y[i]
ensemble_vars[:, i, 0] = var_x[i]
ensemble_vars[:, i, 1] = var_y[i]
keypoints_avg_dict[2 * i] = avg_x[i]
keypoints_avg_dict[2 * i + 1] = avg_y[i]
ensemble_likes[:, i, 0] = likes[i]

# Convert outputs to JAX arrays
ensemble_preds = jnp.array(ensemble_preds)
ensemble_vars = jnp.array(ensemble_vars)
keypoints_avg_dict = {k: jnp.array(v) for k, v in keypoints_avg_dict.items()}

return ensemble_preds, ensemble_vars, keypoints_avg_dict
return ensemble_preds, ensemble_vars, ensemble_likes


def kalman_filter_step(carry, curr_y):
Expand Down Expand Up @@ -720,27 +723,26 @@ def eks_zscore(eks_predictions, ensemble_means, ensemble_vars, min_ensemble_std=


def compute_covariance_matrix(ensemble_preds):
"""
Compute the covariance matrix E for correlated noise dynamics.
"""Compute the covariance matrix E for correlated noise dynamics.

Parameters:
ensemble_preds: A 3D array of shape (T, n_keypoints, n_coords)
containing the ensemble predictions.
Args:
ensemble_preds: shape (T, n_keypoints, n_coords) containing the ensemble predictions.

Returns:
E: A 2K x 2K covariance matrix where K is the number of keypoints.
E: A 2K x 2K covariance matrix where K is the number of keypoints.

"""
# Get the number of time steps, keypoints, and coordinates
T, n_keypoints, n_coords = ensemble_preds.shape

# Flatten the ensemble predictions to shape (T, 2K) where K is the number of keypoints
flattened_preds = ensemble_preds.reshape(T, -1)
# flattened_preds = ensemble_preds.reshape(T, -1)

# Compute the temporal differences
temporal_diffs = np.diff(flattened_preds, axis=0)
# temporal_diffs = np.diff(flattened_preds, axis=0)

# Compute the covariance matrix of the temporal differences
E = np.cov(temporal_diffs, rowvar=False)
# E = np.cov(temporal_diffs, rowvar=False)

# Index covariance matrix into blocks for each keypoint
cov_mats = []
Expand Down
59 changes: 26 additions & 33 deletions eks/ibl_pupil_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,26 +83,28 @@ def add_mean_to_array(pred_arr, keys, mean_x, mean_y):
def fit_eks_pupil(
input_source: Union[str, list],
save_file: str,
smooth_params: list,
smooth_params: Optional[list] = None,
s_frames: Optional[list] = None,
avg_mode: str = 'median',
var_mode: str = 'confidence_weighted_var',
) -> tuple:
"""Function to fit the Ensemble Kalman Smoother for the ibl-pupil dataset.
"""Fit the Ensemble Kalman Smoother for the ibl-pupil dataset.

Args:
input_source: Directory path or list of input CSV files.
save_file: File to save outputs.
smooth_params: List containing diameter_s and com_s.
input_source: directory path or list of CSV file paths. If a directory path, all files
within this directory will be used.
save_file: File to save output dataframe.
smooth_params: [diameter param, center of mass param]
each value should be in (0, 1); closer to 1 means more smoothing
s_frames: Frames for automatic optimization if needed.
avg_mode
avg_mode: mode for averaging across ensemble
'median' | 'mean'
var_mode
'confidence_weighted_var' | 'var'
var_mode: mode for computing ensemble variance
'var' | 'confidence_weighted_var'

Returns:
tuple:
df_smotthed (pd.DataFrame):
df_smoothed (pd.DataFrame)
smooth_params (list): Final smoothing parameters used.
input_dfs_list (list): List of input DataFrames.
keypoint_names (list): List of keypoint names.
Expand All @@ -111,14 +113,13 @@ def fit_eks_pupil(
"""

# Load and format input files
input_dfs_list, output_df, keypoint_names = format_data(input_source)
input_dfs_list, _, keypoint_names = format_data(input_source)

print(f"Input data loaded for keypoints: {keypoint_names}")

# Run the ensemble Kalman smoother
df_smoothed, smooth_params, nll_values = ensemble_kalman_smoother_ibl_pupil(
df_smoothed, smooth_params_final, nll_values = ensemble_kalman_smoother_ibl_pupil(
markers_list=input_dfs_list,
keypoint_names=keypoint_names,
smooth_params=smooth_params,
s_frames=s_frames,
avg_mode=avg_mode,
Expand All @@ -130,13 +131,12 @@ def fit_eks_pupil(
df_smoothed.to_csv(save_file)
print("DataFrames successfully converted to CSV")

return df_smoothed, smooth_params, input_dfs_list, keypoint_names, nll_values
return df_smoothed, smooth_params_final, input_dfs_list, keypoint_names, nll_values


def ensemble_kalman_smoother_ibl_pupil(
markers_list: list,
keypoint_names: list,
smooth_params: list,
smooth_params: Optional[list] = None,
s_frames: Optional[list] = None,
avg_mode: str = 'median',
var_mode: str = 'confidence_weighted_var',
Expand All @@ -147,7 +147,6 @@ def ensemble_kalman_smoother_ibl_pupil(
Args:
markers_list: pd.DataFrames
each list element is a dataframe of predictions from one ensemble member
keypoint_names
smooth_params: contains smoothing parameters for diameter and center of mass
s_frames: frames for automatic optimization if s is not provided
avg_mode
Expand All @@ -165,12 +164,13 @@ def ensemble_kalman_smoother_ibl_pupil(

"""

# compute ensemble median
keys = [
'pupil_top_r_x', 'pupil_top_r_y', 'pupil_bottom_r_x', 'pupil_bottom_r_y',
'pupil_right_r_x', 'pupil_right_r_y', 'pupil_left_r_x', 'pupil_left_r_y',
]
ensemble_preds, ensemble_vars, ensemble_likes, ensemble_stacks = ensemble(
# pupil smoother only works for a pre-specified set of points
# NOTE: this order MUST be kept
keypoint_names = ['pupil_top_r', 'pupil_bottom_r', 'pupil_right_r', 'pupil_left_r']
keys = [f'{kp}_{coord}' for kp in keypoint_names for coord in ['x', 'y']]

# compute ensemble information
ensemble_preds, ensemble_vars, ensemble_likes, _ = ensemble(
markers_list, keys, avg_mode=avg_mode, var_mode=var_mode,
)

Expand Down Expand Up @@ -202,25 +202,18 @@ def ensemble_kalman_smoother_ibl_pupil(
[0, 1, 0], [.5, 0, 1],
[.5, 1, 0], [0, 0, 1],
[-.5, 1, 0], [0, 0, 1]
])
])

# placeholder diagonal matrix for ensemble variance
R = np.eye(8)

scaled_ensemble_preds = ensemble_preds.copy()
scaled_ensemble_stacks = ensemble_stacks.copy()
# subtract COM means from the ensemble predictions
for i in range(ensemble_preds.shape[1]):
if i % 2 == 0:
scaled_ensemble_preds[:, i] -= mean_x_obs
else:
scaled_ensemble_preds[:, i] -= mean_y_obs
# subtract COM means from all the predictions
for i in range(ensemble_preds.shape[1]):
if i % 2 == 0:
scaled_ensemble_stacks[:, :, i] -= mean_x_obs
else:
scaled_ensemble_stacks[:, :, i] -= mean_y_obs
y_obs = scaled_ensemble_preds

# --------------------------------------
Expand Down Expand Up @@ -304,12 +297,12 @@ def ensemble_kalman_smoother_ibl_pupil(

def pupil_optimize_smooth(
y, m0, S0, C, R, ensemble_vars, diameters_var, x_var, y_var,
s_frames=[(1, 2000)],
smooth_params=[None, None],
s_frames: Optional[list] = [(1, 2000)],
smooth_params: Optional[list] = [None, None],
):
"""Optimize-and-smooth function for the pupil example script."""
# Optimize smooth_param
if smooth_params[0] is None or smooth_params[1] is None:
if smooth_params is None or smooth_params[0] is None or smooth_params[1] is None:

# Unpack s_frames
y_shortened = crop_frames(y, s_frames)
Expand Down
20 changes: 14 additions & 6 deletions eks/multicam_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,16 @@
import pandas as pd
from scipy.optimize import minimize

from eks.core import ensemble, eks_zscore, compute_initial_guesses, forward_pass, backward_pass, \
compute_nll
from eks.ibl_paw_multiview_smoother import remove_camera_means, pca
from eks.utils import make_dlc_pandas_index, crop_frames
from eks.core import (
backward_pass,
compute_initial_guesses,
compute_nll,
eks_zscore,
ensemble,
forward_pass,
)
from eks.ibl_paw_multiview_smoother import pca, remove_camera_means
from eks.utils import crop_frames, make_dlc_pandas_index


def ensemble_kalman_smoother_multicam(
Expand Down Expand Up @@ -158,8 +164,10 @@ def ensemble_kalman_smoother_multicam(
# --------------------------------------
# final cleanup
# --------------------------------------
pdindex = make_dlc_pandas_index([keypoint_ensemble],
labels=["x", "y", "likelihood", "x_var", "y_var", "zscore", "nll", "ensemble_std"])
pdindex = make_dlc_pandas_index(
[keypoint_ensemble],
labels=["x", "y", "likelihood", "x_var", "y_var", "zscore", "nll", "ensemble_std"]
)
camera_indices = []
for camera in range(num_cameras):
camera_indices.append([camera * 2, camera * 2 + 1])
Expand Down
Loading
Loading