Ensemble Variance Bugfix (#10)
* added verbose flag to reduce print statements

* merging

* more verbose print removal

* more, more verbose print removal

* outputs nll values

* zscore threshold set

* eks scalar covariance inflation, initial pytest setup

* removed SLEAP fish workaround

* merge

* added posterior var to eks output csvs

* ens var dynamic update fix

* merge

* removed debug prints

* fixed zscore indexing

* removed debug print for covariance scaling

* flake8

---------

Co-authored-by: Matt Whiteway <[email protected]>
keeminlee and themattinthehatt authored Nov 4, 2024
1 parent 8b28a7e commit d601c1a
Showing 11 changed files with 383 additions and 59 deletions.
6 changes: 6 additions & 0 deletions eks/command_line_args.py
@@ -67,6 +67,12 @@ def handle_parse_args(script_type):
default=[],
type=parse_blocks,
)
parser.add_argument(
'--verbose',
help='if set to true, displays smoothing parameter optimization iterations',
default='',
type=str,
)
if script_type == 'singlecam':
add_bodyparts(parser)
add_s(parser)
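
For reference, the new --verbose argument is declared as a plain string with an empty default, so downstream code presumably treats any non-empty value as enabling verbose output (that interpretation is an assumption, not shown in this diff). A minimal standalone sketch of the same argparse pattern:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    '--verbose',
    help='if set to true, displays smoothing parameter optimization iterations',
    default='',
    type=str,
)
args = parser.parse_args(['--verbose', 'true'])
verbose = bool(args.verbose)  # assumed check: '' (default) -> False, non-empty -> True
print(verbose)  # True
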
94 changes: 85 additions & 9 deletions eks/core.py
@@ -96,7 +96,7 @@ def ensemble(markers_list, keys, mode='median'):
ensemble_vars = np.asarray(ensemble_vars).T
ensemble_stacks = np.asarray(ensemble_stacks).T
return ensemble_preds, ensemble_vars, ensemble_stacks, \
keypoints_avg_dict, keypoints_var_dict, keypoints_stack_dict
keypoints_avg_dict, keypoints_var_dict, keypoints_stack_dict


def forward_pass(y, m0, S0, C, R, A, Q, ensemble_vars):
@@ -344,18 +344,53 @@ def kalman_filter_step(carry, curr_y):
innovation_cov = jnp.dot(C, jnp.dot(V_pred, C.T)) + R
K = jnp.dot(V_pred, jnp.dot(C.T, jnp.linalg.inv(innovation_cov)))
m_t = m_pred + jnp.dot(K, innovation)
V_t = V_pred - jnp.dot(K, jnp.dot(C, V_pred))
V_t = jnp.dot((jnp.eye(V_pred.shape[0]) - jnp.dot(K, C)), V_pred)

nll_current = single_timestep_nll(innovation, innovation_cov)
nll_net = nll_net + nll_current

return (m_t, V_t, A, Q, C, R, nll_net), (m_t, V_t, nll_current)


def kalman_filter_step_nlls(carry, inputs):
# Unpack carry and inputs
m_prev, V_prev, A, Q, C, R, nll_net, nll_array, t = carry
curr_y, curr_ensemble_var = inputs

# Update R with the current ensemble variance
R = jnp.diag(curr_ensemble_var)

# Predict
m_pred = jnp.dot(A, m_prev)
V_pred = jnp.dot(A, jnp.dot(V_prev, A.T)) + Q

# Update
innovation = curr_y - jnp.dot(C, m_pred)
innovation_cov = jnp.dot(C, jnp.dot(V_pred, C.T)) + R
K = jnp.dot(V_pred, jnp.dot(C.T, jnp.linalg.inv(innovation_cov)))
m_t = m_pred + jnp.dot(K, innovation)
V_t = V_pred - jnp.dot(K, jnp.dot(C, V_pred))

# Compute the negative log-likelihood for the current time step
nll_current = single_timestep_nll(innovation, innovation_cov)

# Accumulate the negative log-likelihood
nll_net = nll_net + nll_current

# Save the current NLL to the preallocated array
nll_array = nll_array.at[t].set(nll_current)

# Increment the time step
t = t + 1

# Return the updated state and outputs
return (m_t, V_t, A, Q, C, R, nll_net, nll_array, t), (m_t, V_t, nll_current)


# Always run the sequential filter on CPU.
# GPU will deploy individual kernels for each scan iteration, very slow.
@partial(jit, backend='cpu')
def jax_forward_pass(y, m0, cov0, A, Q, C, R):
def jax_forward_pass(y, m0, cov0, A, Q, C, R, ensemble_vars):
"""
Kalman Filter for a single keypoint
(can be vectorized using vmap for handling multiple keypoints in parallel)
@@ -367,6 +402,7 @@ def jax_forward_pass(y, m0, cov0, A, Q, C, R):
Q: Shape (state_dim, state_dim). Process noise covariance matrix.
C: Shape (observation_dim, state_dim). Observation coefficient matrix.
R: Shape (observation_dim, observation_dim). Observation noise covar matrix.
ensemble_vars: Shape (num_timepoints, observation_dimension). Time-varying obs noise var.
Returns:
mfs: Shape (timepoints, state_dim). Mean filter state at each timepoint.
@@ -375,12 +411,52 @@ def jax_forward_pass(y, m0, cov0, A, Q, C, R):
"""
# Initialize carry
carry = (m0, cov0, A, Q, C, R, 0)
carry, outputs = jax.lax.scan(kalman_filter_step, carry, y)

# Run the scan, passing y and ensemble_vars as inputs to kalman_filter_step
carry, outputs = jax.lax.scan(kalman_filter_step, carry, (y, ensemble_vars))
mfs, Vfs, _ = outputs
nll_net = carry[-1]
return mfs, Vfs, nll_net
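
As the docstring notes, this per-keypoint filter can be batched over keypoints with vmap. A minimal sketch of that usage with the updated signature, assuming jax_forward_pass is importable from eks.core; the shapes and constant matrices below are illustrative, not taken from the repository:

import jax
import jax.numpy as jnp
from eks.core import jax_forward_pass

# Illustrative sizes: K keypoints, T timepoints, 2-D observations and states.
K, T, obs_dim, state_dim = 4, 100, 2, 2
ys = jnp.zeros((K, T, obs_dim))             # per-keypoint observations
ensemble_vars = jnp.ones((K, T, obs_dim))   # per-keypoint, time-varying obs noise
m0 = jnp.zeros((K, state_dim))
cov0 = jnp.stack([jnp.eye(state_dim)] * K)
A = jnp.eye(state_dim)                      # dynamics shared across keypoints
Q = 0.1 * jnp.eye(state_dim)
C = jnp.eye(obs_dim, state_dim)
R = jnp.eye(obs_dim)

# Map over the keypoint axis of y, m0, cov0, and ensemble_vars; share A, Q, C, R.
batched_filter = jax.vmap(jax_forward_pass, in_axes=(0, 0, 0, None, None, None, None, 0))
mfs, Vfs, nll_net = batched_filter(ys, m0, cov0, A, Q, C, R, ensemble_vars)
print(mfs.shape, Vfs.shape, nll_net.shape)  # (4, 100, 2) (4, 100, 2, 2) (4,)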


def jax_forward_pass_nlls(y, m0, cov0, A, Q, C, R, ensemble_vars):
"""
Kalman Filter for a single keypoint
(can be vectorized using vmap for handling multiple keypoints in parallel)
Parameters:
y: Shape (num_timepoints, observation_dimension).
m0: Shape (state_dim,). Initial state of system.
cov0: Shape (state_dim, state_dim). Initial covariance of state variable.
A: Shape (state_dim, state_dim). Process transition matrix.
Q: Shape (state_dim, state_dim). Process noise covariance matrix.
C: Shape (observation_dim, state_dim). Observation coefficient matrix.
R: Shape (observation_dim, observation_dim). Observation noise covar matrix.
Returns:
mfs: Shape (timepoints, state_dim). Mean filter state at each timepoint.
Vfs: Shape (timepoints, state_dim, state_dim). Covar for each filtered estimate.
nll_net: Shape (1,). Negative log likelihood observations -log (p(y_1, ..., y_T))
nll_array: Shape (num_timepoints,). Incremental negative log-likelihood at each timepoint.
"""
# Ensure R is a (2, 2) matrix
if R.ndim == 1:
R = jnp.diag(R)

# Initialize carry
num_timepoints = y.shape[0]
nll_array_init = jnp.zeros(num_timepoints) # Preallocate an array with zeros
t_init = 0 # Initialize the time step counter
carry = (m0, cov0, A, Q, C, R, 0, nll_array_init, t_init)

# Run the scan, passing y and ensemble_vars
carry, outputs = jax.lax.scan(kalman_filter_step_nlls, carry, (y, ensemble_vars))
mfs, Vfs, _ = outputs
nll_net = carry[-3] # Total NLL
nll_array = carry[-2] # Array of incremental NLL values

return mfs, Vfs, nll_net, nll_array
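
For reference, nll_net accumulates the same per-timestep values that are written into nll_array, so the total should equal the sum of the array. A small sketch with illustrative inputs, assuming the function is importable from eks.core:

import jax.numpy as jnp
from eks.core import jax_forward_pass_nlls

T = 50
y = jnp.zeros((T, 2))
ensemble_vars = jnp.ones((T, 2))   # time-varying observation noise variances
m0 = jnp.zeros(2)
cov0 = jnp.eye(2)
A, Q, C, R = jnp.eye(2), 0.1 * jnp.eye(2), jnp.eye(2), jnp.eye(2)

mfs, Vfs, nll_net, nll_array = jax_forward_pass_nlls(y, m0, cov0, A, Q, C, R, ensemble_vars)
assert nll_array.shape == (T,)                 # one incremental NLL per timepoint
assert jnp.allclose(nll_net, nll_array.sum())  # running total matches the array sum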


def kalman_smoother_step(carry, X):
m_ahead_smooth, v_ahead_smooth, A, Q = carry
m_curr_filter, v_curr_filter = X[0], X[1]
@@ -390,7 +466,7 @@ def kalman_smoother_step(carry, X):

smoothing_gain = jsc.linalg.solve(ahead_cov, jnp.dot(A, v_curr_filter.T)).T
smoothed_state = m_curr_filter + jnp.dot(smoothing_gain, m_ahead_smooth - m_curr_filter)
smoothed_cov = v_curr_filter + jnp.dot(jnp.dot(smoothing_gain, m_ahead_smooth - ahead_cov),
smoothed_cov = v_curr_filter + jnp.dot(jnp.dot(smoothing_gain, v_ahead_smooth - ahead_cov),
smoothing_gain.T)

return (smoothed_state, smoothed_cov, A, Q), (smoothed_state, smoothed_cov)
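
For context, the standard RTS smoother covariance update is V_t_smooth = V_t_filter + J_t (V_{t+1}_smooth - V_{t+1}_pred) J_t^T, where J_t is the smoothing gain and the predicted covariance corresponds to ahead_cov here; the corrected line above subtracts covariances (v_ahead_smooth - ahead_cov) rather than mixing in the smoothed mean (m_ahead_smooth). This identity is standard Kalman smoothing background, not part of the diff itself.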
@@ -612,7 +688,7 @@ def pkf_and_loss(y, m0, cov0, A, Q, C, R):
# -------------------------------------------------------------------------------------


def eks_zscore(eks_predictions, ensemble_means, ensemble_vars, min_ensemble_std=2):
def eks_zscore(eks_predictions, ensemble_means, ensemble_vars, min_ensemble_std=1e-5):
"""Computes zscore between eks prediction and the ensemble for a single keypoint.
Args:
eks_predictions: list
@@ -622,7 +698,7 @@ def eks_zscore(eks_predictions, ensemble_means, ensemble_vars, min_ensemble_std=
ensemble_vars: string
Ensemble var for each coordinate (x and ys) for as single keypoint - (samples, 2)
min_ensemble_std:
Minimum std threshold to reduce the effect of low ensemble std (default 2).
Minimum std threshold to reduce the effect of low ensemble std (default 1e-5).
Returns:
z_score
z_score for each time point - (samples, 1)
@@ -637,7 +713,7 @@ def eks_zscore(eks_predictions, ensemble_means, ensemble_vars, min_ensemble_std=
thresh_ensemble_std = ensemble_std.copy()
thresh_ensemble_std[thresh_ensemble_std < min_ensemble_std] = min_ensemble_std
z_score = num / thresh_ensemble_std
return z_score
return z_score, ensemble_std
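
To illustrate the effect of lowering the default threshold (illustrative numbers only): with a numerator of 0.5 and an ensemble std of 0.1, the old default min_ensemble_std=2 clamps the denominator to 2 and gives z = 0.5 / 2 = 0.25, whereas the new default of 1e-5 leaves the denominator at 0.1 and gives z = 0.5 / 0.1 = 5, so frames where the ensemble agrees closely are no longer damped.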


def compute_covariance_matrix(ensemble_preds):
Expand Down Expand Up @@ -667,7 +743,7 @@ def compute_covariance_matrix(ensemble_preds):
cov_mats = []
for i in range(n_keypoints):
E_block = extract_submatrix(E, i)
cov_mats.append(E_block)
cov_mats.append([[1, 0], [0, 1]])
cov_mats = jnp.array(cov_mats)
return cov_mats

4 changes: 2 additions & 2 deletions eks/ibl_paw_multiview_smoother.py
@@ -316,8 +316,8 @@ def ensemble_kalman_smoother_ibl_paw(
scaled_y_m_smooth.T[1 + 2 * i]]).T
ensemble_preds = scaled_y[:, 2 * i:2 * (i + 1)]
ensemble_vars_curr = ensemble_vars[:, 2 * i:2 * (i + 1)]
zscore = eks_zscore(eks_predictions, ensemble_preds, ensemble_vars_curr,
min_ensemble_std=4)
zscore, _ = eks_zscore(eks_predictions, ensemble_preds, ensemble_vars_curr,
min_ensemble_std=4)
pred_arr.append(zscore)
###
pred_arr = np.asarray(pred_arr)
7 changes: 5 additions & 2 deletions eks/ibl_pupil_smoother.py
@@ -214,8 +214,8 @@ def ensemble_kalman_smoother_ibl_pupil(
np.asarray([processed_arr_dict[key_pair[0]], processed_arr_dict[key_pair[1]]]).T
ensemble_preds_curr = ensemble_preds[:, ensemble_indices[i][0]: ensemble_indices[i][1] + 1]
ensemble_vars_curr = ensemble_vars[:, ensemble_indices[i][0]: ensemble_indices[i][1] + 1]
zscore = eks_zscore(eks_predictions, ensemble_preds_curr, ensemble_vars_curr,
min_ensemble_std=zscore_threshold)
zscore, _ = eks_zscore(
eks_predictions,
ensemble_preds_curr,
ensemble_vars_curr,
min_ensemble_std=zscore_threshold)
pred_arr.append(zscore)

pred_arr = np.asarray(pred_arr)
5 changes: 3 additions & 2 deletions eks/multicam_smoother.py
@@ -180,8 +180,9 @@ def ensemble_kalman_smoother_multicam(
y_m_smooth.T[camera_indices[camera][1]] + means_camera[camera_indices[camera][1]]
# compute zscore for EKS to see how it deviates from the ensemble
eks_predictions = np.asarray([eks_pred_x, eks_pred_y]).T
zscore = eks_zscore(eks_predictions, cam_ensemble_preds[camera], cam_ensemble_vars[camera],
min_ensemble_std=zscore_threshold)
zscore, _ = eks_zscore(
eks_predictions, cam_ensemble_preds[camera], cam_ensemble_vars[camera],
min_ensemble_std=zscore_threshold)
pred_arr = np.vstack([
eks_pred_x,
eks_pred_y,