IBL pupil cleanup #21

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged 3 commits on Dec 11, 2024
eks/command_line_args.py: 4 changes (2 additions & 2 deletions)
@@ -79,10 +79,10 @@ def handle_parse_args(script_type):
         add_camera_names(parser)
         add_quantile_keep_pca(parser)
         add_s(parser)
-    elif script_type == 'pupil':
+    elif script_type == 'ibl_pupil':
         add_diameter_s(parser)
         add_com_s(parser)
-    elif script_type == 'paw':
+    elif script_type == 'ibl_paw':
         add_s(parser)
         add_quantile_keep_pca(parser)
     else:
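The renamed script types mean callers must now request the IBL-specific argument sets by their new names. Below is a minimal sketch of a call site, assuming handle_parse_args returns the parsed argument namespace; the example scripts themselves are not part of this diff.

```python
# Hypothetical call site (not shown in this PR): request the IBL-specific argument sets.
from eks.command_line_args import handle_parse_args

# 'pupil' and 'paw' are no longer matched by the branches above; use the new names.
pupil_args = handle_parse_args('ibl_pupil')  # adds the diameter/COM smoothing options
paw_args = handle_parse_args('ibl_paw')      # adds the s and quantile_keep_pca options
```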
eks/core.py: 131 changes (62 additions & 69 deletions)
@@ -16,87 +16,80 @@
 # ------------------------------------------------------------------------------------------


-def ensemble(markers_list, keys, mode='median'):
-    """Computes ensemble median (or mean) and variance of list of DLC marker dataframes
+def ensemble(
+    markers_list: list,
+    keys: list,
+    avg_mode: str = 'median',
+    var_mode: str = 'confidence_weighted_var',
+) -> tuple:
+    """Compute ensemble mean/median and variance of marker dataframes.

     Args:
-        markers_list: list
-            List of DLC marker dataframes`
-        keys: list
-            List of keys in each marker dataframe
-        mode: string
-            Averaging mode which includes 'median', 'mean', or 'confidence_weighted_mean'.
+        markers_list: List of DLC marker dataframes
+        keys: List of keys in each marker dataframe
+        avg_mode
+            'median' | 'mean'
+        var_mode
+            'confidence_weighted_var' | 'var'

     Returns:
-        ensemble_preds: np.ndarray
-            shape (samples, n_keypoints)
-        ensemble_vars: np.ndarray
-            shape (samples, n_keypoints)
-        ensemble_stacks: np.ndarray
-            shape (n_models, samples, n_keypoints)
-        keypoints_avg_dict: dict
-            keys: marker keypoints, values: shape (samples)
-        keypoints_var_dict: dict
-            keys: marker keypoints, values: shape (samples)
-        keypoints_stack_dict: dict(dict)
-            keys: model_ids, keys: marker keypoints, values: shape (samples)
+        tuple:
+            ensemble_preds: np.ndarray
+                shape (samples, n_keypoints)
+            ensemble_vars: np.ndarray
+                shape (samples, n_keypoints)
+            ensemble_likelihoods: np.ndarray
+                shape (samples, n_keypoints)
+            ensemble_stacks: np.ndarray
+                shape (n_models, samples, n_keypoints)

     """
-    ensemble_stacks = []
-    ensemble_vars = []
     ensemble_preds = []
-    keypoints_avg_dict = {}
-    keypoints_var_dict = {}
-    keypoints_stack_dict = defaultdict(dict)
-    if mode != 'confidence_weighted_mean':
-        if mode == 'median':
-            average_func = np.nanmedian
-        elif mode == 'mean':
-            average_func = np.nanmean
-        else:
-            raise ValueError(f"{mode} averaging not supported")
+    ensemble_vars = []
+    ensemble_likes = []
+    ensemble_stacks = []
+
+    if avg_mode == 'median':
+        average_func = np.nanmedian
+    elif avg_mode == 'mean':
+        average_func = np.nanmean
+    else:
+        raise ValueError(f"avg_mode={avg_mode} not supported")
+
     for key in keys:
-        if mode != 'confidence_weighted_mean':
-            stack = np.zeros((len(markers_list), markers_list[0].shape[0]))
-            for k in range(len(markers_list)):
-                stack[k] = markers_list[k][key]
-            stack = stack.T
-            avg = average_func(stack, 1)
-            var = np.nanvar(stack, 1)
-            ensemble_preds.append(avg)
-            ensemble_vars.append(var)
-            ensemble_stacks.append(stack)
-            keypoints_avg_dict[key] = avg
-            keypoints_var_dict[key] = var
-            for i, keypoints in enumerate(stack.T):
-                keypoints_stack_dict[i][key] = stack.T[i]
-        else:
-            likelihood_key = key[:-1] + 'likelihood'
-            if likelihood_key not in markers_list[0]:
-                raise ValueError(f"{likelihood_key} needs to be in your marker_df to use {mode}")
-            stack = np.zeros((len(markers_list), markers_list[0].shape[0]))
-            likelihood_stack = np.zeros((len(markers_list), markers_list[0].shape[0]))
+
+        # compute mean/median
+        stack = np.zeros((markers_list[0].shape[0], len(markers_list)))
+        for k in range(len(markers_list)):
+            stack[:, k] = markers_list[k][key]
+        ensemble_stacks.append(stack)
+        avg = average_func(stack, axis=1)
+        ensemble_preds.append(avg)
+
+        # collect likelihoods
+        likelihood_stack = np.ones((markers_list[0].shape[0], len(markers_list)))
+        likelihood_key = key[:-1] + 'likelihood'
+        if likelihood_key in markers_list[0]:
             for k in range(len(markers_list)):
-                stack[k] = markers_list[k][key]
-                likelihood_stack[k] = markers_list[k][likelihood_key]
-            stack = stack.T
-            likelihood_stack = likelihood_stack.T
-            conf_per_keypoint = np.sum(likelihood_stack, 1)
-            mean_conf_per_keypoint = np.sum(likelihood_stack, 1) / likelihood_stack.shape[1]
-            avg = np.sum(stack * likelihood_stack, 1) / conf_per_keypoint
-            var = np.nanvar(stack, 1)
-            ensemble_preds.append(avg)
-            ensemble_vars.append(var)
-            ensemble_stacks.append(stack)
-            keypoints_avg_dict[key] = avg
-            keypoints_var_dict[key] = var
-            for i, keypoints in enumerate(stack.T):
-                keypoints_stack_dict[i][key] = stack.T[i]
+                likelihood_stack[:, k] = markers_list[k][likelihood_key]
+        mean_conf_per_keypoint = np.mean(likelihood_stack, axis=1)
+        ensemble_likes.append(mean_conf_per_keypoint)
+
+        # compute variance
+        var = np.nanvar(stack, axis=1)
+        if var_mode in ['conf_weighted_var', 'confidence_weighted_var']:
+            var = var / mean_conf_per_keypoint  # low-confidence --> inflated obs variances
+        elif var_mode != 'var':
+            raise ValueError(f"var_mode={var_mode} not supported")
+        ensemble_vars.append(var)
+
     ensemble_preds = np.asarray(ensemble_preds).T
     ensemble_vars = np.asarray(ensemble_vars).T
+    ensemble_likes = np.asarray(ensemble_likes).T
     ensemble_stacks = np.asarray(ensemble_stacks).T
-    return ensemble_preds, ensemble_vars, ensemble_stacks, \
-        keypoints_avg_dict, keypoints_var_dict, keypoints_stack_dict
+
+    return ensemble_preds, ensemble_vars, ensemble_likes, ensemble_stacks


 def forward_pass(y, m0, S0, C, R, A, Q, ensemble_vars):
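The refactored ensemble() now returns a flat four-tuple (predictions, variances, mean likelihoods, per-model stacks) instead of the earlier arrays-plus-dicts six-tuple, and with the default var_mode='confidence_weighted_var' it divides the raw ensemble variance by the mean keypoint confidence, so low-confidence frames receive inflated observation variances. Below is a minimal usage sketch; the toy data, keypoint names, and ensemble size are invented for illustration, while the call signature and the column-naming convention (likelihood column = key[:-1] + 'likelihood') come from the diff above.

```python
import numpy as np
import pandas as pd

from eks.core import ensemble

rng = np.random.default_rng(0)
n_samples, n_models = 100, 5

# one dataframe per ensemble member; 'paw_l_x'[:-1] + 'likelihood' -> 'paw_l_likelihood'
markers_list = [
    pd.DataFrame({
        'paw_l_x': rng.normal(50, 2, n_samples),
        'paw_l_y': rng.normal(60, 2, n_samples),
        'paw_l_likelihood': rng.uniform(0.5, 1.0, n_samples),
    })
    for _ in range(n_models)
]

preds, variances, likes, stacks = ensemble(
    markers_list, keys=['paw_l_x', 'paw_l_y'],
    avg_mode='median', var_mode='confidence_weighted_var',
)
# preds, variances, likes: (n_samples, n_keypoints); stacks: (n_models, n_samples, n_keypoints)
print(preds.shape, variances.shape, likes.shape, stacks.shape)
```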
eks/ibl_paw_multiview_smoother.py: 33 changes (19 additions & 14 deletions)
@@ -7,6 +7,10 @@
 from eks.utils import make_dlc_pandas_index


+# TODO:
+# - allow conf_weighted_mean for ensemble variance computation
+
+
 def remove_camera_means(ensemble_stacks, camera_means):
     scaled_ensemble_stacks = ensemble_stacks.copy()
     for k in range(len(ensemble_stacks)):
@@ -31,10 +35,15 @@ def pca(S, n_comps):


 def ensemble_kalman_smoother_ibl_paw(
-        markers_list_left_cam, markers_list_right_cam, timestamps_left_cam,
-        timestamps_right_cam, keypoint_names, smooth_param, quantile_keep_pca,
-        ensembling_mode='median',
-        zscore_threshold=2, img_width=128):
+    markers_list_left_cam, markers_list_right_cam,
+    timestamps_left_cam, timestamps_right_cam,
+    keypoint_names,
+    smooth_param,
+    quantile_keep_pca,
+    ensembling_mode='median',
+    zscore_threshold=2,
+    img_width=128,
+):
     """
     --(IBL-specific)-
     -Use multi-view constraints to fit a 3d latent subspace for each body part with 2
@@ -63,8 +72,6 @@ def ensemble_kalman_smoother_ibl_paw(
         (default 2).
     img_width
         The width of the image being smoothed (128 default, IBL-specific).
-    Returns
-    -------

     Returns
     -------
@@ -125,16 +132,14 @@ def ensemble_kalman_smoother_ibl_paw(
         markers_list_right_cam.append(markers_right_cam)

     # compute ensemble median left camera
-    left_cam_ensemble_preds, left_cam_ensemble_vars, left_cam_ensemble_stacks, \
-        left_cam_keypoints_mean_dict, left_cam_keypoints_var_dict, \
-        left_cam_keypoints_stack_dict = \
-        ensemble(markers_list_left_cam, keys, mode=ensembling_mode)
+    left_cam_ensemble_preds, left_cam_ensemble_vars, _, left_cam_ensemble_stacks = ensemble(
+        markers_list_left_cam, keys, avg_mode=ensembling_mode, var_mode='var',
+    )

     # compute ensemble median right camera
-    right_cam_ensemble_preds, right_cam_ensemble_vars, right_cam_ensemble_stacks, \
-        right_cam_keypoints_mean_dict, right_cam_keypoints_var_dict, \
-        right_cam_keypoints_stack_dict = \
-        ensemble(markers_list_right_cam, keys, mode=ensembling_mode)
+    right_cam_ensemble_preds, right_cam_ensemble_vars, _, right_cam_ensemble_stacks = ensemble(
+        markers_list_right_cam, keys, avg_mode=ensembling_mode, var_mode='var',
+    )

     # keep percentage of the points for multi-view PCA based lowest ensemble variance
     hstacked_vars = np.hstack((left_cam_ensemble_vars, right_cam_ensemble_vars))
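With the new ensemble() the paw smoother unpacks four return values and explicitly requests plain (unweighted) ensemble variance via var_mode='var'; the TODO added at the top of the file notes that confidence weighting is not yet wired into this multi-view pipeline. The hstacked variances then feed the quantile_keep_pca frame selection mentioned in the trailing context. A rough sketch of what such a selection step can look like is below; the actual implementation lives outside this diff, so the helper is illustrative only.

```python
import numpy as np

def keep_low_variance_frames(left_vars, right_vars, quantile_keep_pca=25):
    """Return indices of frames whose worst per-keypoint ensemble variance lies
    below the given percentile across both camera views (illustrative helper)."""
    hstacked_vars = np.hstack((left_vars, right_vars))   # (n_samples, 2 * n_keypoints)
    worst_var = np.max(hstacked_vars, axis=1)            # worst keypoint per frame
    threshold = np.percentile(worst_var, quantile_keep_pca)
    return np.where(worst_var <= threshold)[0]
```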