Skip to content

Commit

Permalink
cleanup singleview and pupil eks scripts/functions (#16)
Browse files Browse the repository at this point in the history
* initial cleanup of singlecam smoother

* bugfix for already-arrayed smooth params

* initial cleanup of pupil smoother
  • Loading branch information
themattinthehatt authored Dec 10, 2024
1 parent c31ba0e commit af5d74d
Show file tree
Hide file tree
Showing 14 changed files with 313 additions and 240 deletions.
6 changes: 0 additions & 6 deletions eks/command_line_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,6 @@ def handle_parse_args(script_type):
default=None,
type=str,
)
parser.add_argument(
'--data-type',
help='format of input data (Lightning Pose = lp, SLEAP = slp), dlc by default.',
default='lp',
type=str,
)
parser.add_argument(
'--s-frames',
help='frames to be considered for smoothing '
Expand Down
34 changes: 22 additions & 12 deletions eks/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,10 +260,21 @@ def compute_nll(innovations, innovation_covs, epsilon=1e-6):

# ----- Sequential Functions for CPU -----

def jax_ensemble(markers_3d_array, mode='median'):
def jax_ensemble(
markers_3d_array: np.ndarray,
avg_mode: str = 'median',
var_mode: str = 'confidence_weighted_var',
):
"""
Computes ensemble median (or mean) and variance of a 3D array of DLC marker data using JAX.
Args:
markers_3d_array
avg_mode
'median' | 'mean'
var_mode
'confidence_weighted_var' | 'var'
Returns:
ensemble_preds: np.ndarray
shape (n_timepoints, n_keypoints, n_coordinates).
Expand All @@ -281,32 +292,31 @@ def jax_ensemble(markers_3d_array, mode='median'):
ensemble_vars = np.zeros((n_frames, n_keypoints, 2))

# Choose the appropriate JAX function based on the mode
if mode == 'median':
if avg_mode == 'median':
avg_func = lambda x: jnp.nanmedian(x, axis=0)
elif mode == 'mean':
elif avg_mode == 'mean':
avg_func = lambda x: jnp.nanmean(x, axis=0)
elif mode == 'confidence_weighted_mean':
avg_func = None
else:
raise ValueError(f"{mode} averaging not supported")
raise ValueError(f"{avg_mode} averaging not supported")

def compute_stats(i):
data_x = markers_3d_array[:, :, 3 * i]
data_y = markers_3d_array[:, :, 3 * i + 1]
data_likelihood = markers_3d_array[:, :, 3 * i + 2]

if mode == 'confidence_weighted_mean':
avg_x = avg_func(data_x)
avg_y = avg_func(data_y)

if var_mode in ['conf_weighted_var', 'confidence_weighted_var']:
conf_per_keypoint = jnp.sum(data_likelihood, axis=0)
mean_conf_per_keypoint = conf_per_keypoint / data_likelihood.shape[0]
avg_x = jnp.sum(data_x * data_likelihood, axis=0) / conf_per_keypoint
avg_y = jnp.sum(data_y * data_likelihood, axis=0) / conf_per_keypoint
var_x = jnp.nanvar(data_x, axis=0) / mean_conf_per_keypoint
var_y = jnp.nanvar(data_y, axis=0) / mean_conf_per_keypoint
else:
avg_x = avg_func(data_x)
avg_y = avg_func(data_y)
elif var_mode in ['var', 'variance']:
var_x = jnp.nanvar(data_x, axis=0)
var_y = jnp.nanvar(data_y, axis=0)
else:
raise ValueError(f"{var_mode} for variance computation not supported")

return avg_x, avg_y, var_x, var_y

Expand Down
137 changes: 74 additions & 63 deletions eks/ibl_pupil_smoother.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os
import warnings
from typing import Optional, Union

import numpy as np
import pandas as pd
Expand All @@ -8,9 +10,6 @@
from eks.utils import crop_frames, make_dlc_pandas_index, format_data


# -----------------------
# funcs for kalman pupil
# -----------------------
def get_pupil_location(dlc):
"""get mean of both pupil diameters
d1 = top - bottom, d2 = left - right
Expand Down Expand Up @@ -79,72 +78,75 @@ def add_mean_to_array(pred_arr, keys, mean_x, mean_y):
return processed_arr_dict


def fit_eks_pupil(input_source, data_type, save_dir, smooth_params, s_frames):
"""
Wrapper function to fit the Ensemble Kalman Smoother for the ibl-pupil dataset.
def fit_eks_pupil(
input_source: Union[str, list],
save_file: str,
smooth_params: list,
s_frames: Optional[list] = None,
) -> tuple:
"""Function to fit the Ensemble Kalman Smoother for the ibl-pupil dataset.
Args:
input_source (str or list): Directory path or list of input CSV files.
data_type (str): Type of data (e.g., 'csv', 'slp').
save_dir (str): Directory to save outputs.
smooth_params (list): List containing diameter_s and com_s.
s_frames (list or None): Frames for automatic optimization if needed.
input_source: Directory path or list of input CSV files.
save_file: File to save outputs.
smooth_params: List containing diameter_s and com_s.
s_frames: Frames for automatic optimization if needed.
Returns:
df_dicts (dict): Dictionary containing smoothed DataFrames.
smooth_params (list): Final smoothing parameters used.
input_dfs_list (list): List of input DataFrames.
keypoint_names (list): List of keypoint names.
nll_values (list): List of NLL values.
tuple:
df_dicts (dict): Dictionary containing smoothed DataFrames.
smooth_params (list): Final smoothing parameters used.
input_dfs_list (list): List of input DataFrames.
keypoint_names (list): List of keypoint names.
nll_values (list): List of NLL values.
"""

# Load and format input files
input_dfs_list, output_df, keypoint_names = format_data(input_source, data_type)
input_dfs_list, output_df, keypoint_names = format_data(input_source)

print(f"Input data loaded for keypoints: {keypoint_names}")

# Run the ensemble Kalman smoother
df_dicts, smooth_params, nll_values = ensemble_kalman_smoother_ibl_pupil(
markers_list=input_dfs_list,
keypoint_names=keypoint_names,
tracker_name='ensemble-kalman_tracker',
smooth_params=smooth_params,
s_frames=s_frames
)

# Save the output DataFrame to CSV
os.makedirs(os.path.dirname(save_file), exist_ok=True)
df_dicts['markers_df'].to_csv(save_file)
print("DataFrames successfully converted to CSV")

return df_dicts, smooth_params, input_dfs_list, keypoint_names, nll_values


def ensemble_kalman_smoother_ibl_pupil(
markers_list,
keypoint_names,
tracker_name,
smooth_params,
s_frames,
likelihood_default=np.nan,
zscore_threshold=2,
):
"""
markers_list: list,
keypoint_names: list,
smooth_params: list,
s_frames: Optional[list] = None,
zscore_threshold: float = 2,
) -> tuple:
"""Perform Ensemble Kalman Smoothing on pupil data.
Parameters
----------
markers_list : list of pd.DataFrames
Args:
markers_list: pd.DataFrames
each list element is a dataframe of predictions from one ensemble member
keypoint_names: list
tracker_name : str
tracker name for constructing final dataframe
smooth_params : [float, float]
contains smoothing parameters for diameter and center of mass
likelihood_default
value to store in likelihood column; should be np.nan or int in [0, 1]
zscore_threshold:
Minimum std threshold to reduce the effect of low ensemble std on a zscore metric
(default 2).
Returns
-------
dict
markers_df: dataframe containing smoothed markers; same format as input dataframes
latents_df: dataframe containing 3d latents: pupil diameter and pupil center of mass
keypoint_names
smooth_params: contains smoothing parameters for diameter and center of mass
s_frames: frames for automatic optimization if s is not provided
zscore_threshold: Minimum std threshold to reduce the effect of low ensemble std on a
zscore metric (default 2).
Returns:
tuple
dict: markers_df contains smoothed markers; latents_df contains 3d latents: pupil
diameter and pupil center of mass
final smooth params values
final nll
"""

Expand All @@ -153,7 +155,7 @@ def ensemble_kalman_smoother_ibl_pupil(
'pupil_right_r_x', 'pupil_right_r_y', 'pupil_left_r_x', 'pupil_left_r_y']
ensemble_preds, ensemble_vars, ensemble_stacks, keypoints_mean_dict, keypoints_var_dict, \
keypoints_stack_dict = ensemble(markers_list, keys)
# ## Set parameters

# compute center of mass
pupil_locations = get_pupil_location(keypoints_mean_dict)
pupil_diameters = get_pupil_diameter(keypoints_mean_dict)
Expand Down Expand Up @@ -186,9 +188,12 @@ def ensemble_kalman_smoother_ibl_pupil(
])

# Measurement function
C = np.asarray(
[[0, 1, 0], [-.5, 0, 1], [0, 1, 0], [.5, 0, 1], [.5, 1, 0], [0, 0, 1], [-.5, 1, 0],
[0, 0, 1]])
C = np.asarray([
[0, 1, 0], [-.5, 0, 1],
[0, 1, 0], [.5, 0, 1],
[.5, 1, 0], [0, 0, 1],
[-.5, 1, 0], [0, 0, 1]
])

# placeholder diagonal matrix for ensemble variance
R = np.eye(8)
Expand Down Expand Up @@ -224,9 +229,12 @@ def ensemble_kalman_smoother_ibl_pupil(
# --------------------------------------
# cleanup
# --------------------------------------

# save out marker info
pdindex = make_dlc_pandas_index(keypoint_names,
labels=["x", "y", "likelihood", "x_var", "y_var", "zscore"])
pdindex = make_dlc_pandas_index(
keypoint_names,
labels=["x", "y", "likelihood", "x_var", "y_var", "zscore"]
)
processed_arr_dict = add_mean_to_array(y_m_smooth, keys, mean_x_obs, mean_y_obs)
key_pair_list = [['pupil_top_r_x', 'pupil_top_r_y'],
['pupil_right_r_x', 'pupil_right_r_y'],
Expand All @@ -238,7 +246,7 @@ def ensemble_kalman_smoother_ibl_pupil(
pred_arr.append(processed_arr_dict[key_pair[0]])
pred_arr.append(processed_arr_dict[key_pair[1]])
var = np.empty(processed_arr_dict[key_pair[0]].shape)
var[:] = likelihood_default
var[:] = 1.0 # TODO: median of observed likelihoods
pred_arr.append(var)
x_var = y_v_smooth[:, i, i]
y_var = y_v_smooth[:, i + 1, i + 1]
Expand All @@ -253,17 +261,19 @@ def ensemble_kalman_smoother_ibl_pupil(
eks_predictions,
ensemble_preds_curr,
ensemble_vars_curr,
min_ensemble_std=zscore_threshold)
min_ensemble_std=zscore_threshold,
)
pred_arr.append(zscore)

pred_arr = np.asarray(pred_arr)
markers_df = pd.DataFrame(pred_arr.T, columns=pdindex)

# save out latents info: pupil diam, center of mass
pred_arr2 = []
pred_arr2.append(ms[:, 0])
pred_arr2.append(ms[:, 1] + mean_x_obs) # add back x mean of pupil location
pred_arr2.append(ms[:, 2] + mean_y_obs) # add back ys mean of pupil location
pred_arr2 = np.asarray(pred_arr2)
pred_arr2 = np.asarray([
ms[:, 0],
ms[:, 1] + mean_x_obs, # add back x mean of pupil location
ms[:, 2] + mean_y_obs, # add back ys mean of pupil location
])
tracker_name = 'ensemble-kalman_tracker'
arrays = [[tracker_name, tracker_name, tracker_name], ['diameter', 'com_x', 'com_y']]
pd_index2 = pd.MultiIndex.from_arrays(arrays, names=('scorer', 'latent'))
latents_df = pd.DataFrame(pred_arr2.T, columns=pd_index2)
Expand All @@ -272,9 +282,10 @@ def ensemble_kalman_smoother_ibl_pupil(


def pupil_optimize_smooth(
y, m0, S0, C, R, ensemble_vars, diameters_var, x_var, y_var,
s_frames=[(1, 2000)],
smooth_params=[None, None]):
y, m0, S0, C, R, ensemble_vars, diameters_var, x_var, y_var,
s_frames=[(1, 2000)],
smooth_params=[None, None],
):
"""Optimize-and-smooth function for the pupil example script."""
# Optimize smooth_param
if smooth_params[0] is None or smooth_params[1] is None:
Expand Down
Loading

0 comments on commit af5d74d

Please sign in to comment.