Skip to content

gpu acceleration removed #26

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 0 additions & 159 deletions eks/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,165 +529,6 @@ def single_timestep_nll(innovation, innovation_cov):
return nll_increment


# ----- Parallel Functions for GPU -----

def first_filtering_element(C, A, Q, R, m0, P0, y):
# model.F = A, model.H = C,
S = C @ Q @ C.T + R
CF, low = jsc.linalg.cho_factor(S) # note the jsc

m1 = A @ m0
P1 = A @ P0 @ A.T + Q
S1 = C @ P1 @ C.T + R
K1 = jsc.linalg.solve(S1, C @ P1, assume_a='pos').T # note the jsc

A_updated = jnp.zeros_like(A)
b = m1 + K1 @ (y - C @ m1)
C_updated = P1 - K1 @ S1 @ K1.T

# note the jsc
eta = A.T @ C.T @ jsc.linalg.cho_solve((CF, low), y)
J = A.T @ C.T @ jsc.linalg.cho_solve((CF, low), C @ A)
return A_updated, b, C_updated, J, eta


def generic_filtering_element(C, A, Q, R, y):
S = C @ Q @ C.T + R
CF, low = jsc.linalg.cho_factor(S) # note the jsc
K = jsc.linalg.cho_solve((CF, low), C @ Q).T # note the jsc
A_updated = A - K @ C @ A
b = K @ y
C_updated = Q - K @ C @ Q

# note the jsc
eta = A.T @ C.T @ jsc.linalg.cho_solve((CF, low), y)
J = A.T @ C.T @ jsc.linalg.cho_solve((CF, low), C @ A)
return A_updated, b, C_updated, J, eta


def make_associative_filtering_elements(C, A, Q, R, m0, P0, observations):
first_elems = first_filtering_element(C, A, Q, R, m0, P0, observations[0])
generic_elems = vmap(lambda o: generic_filtering_element(C, A, Q, R, o))(observations[1:])
return tuple(jnp.concatenate([jnp.expand_dims(first_e, 0), gen_es])
for first_e, gen_es in zip(first_elems, generic_elems))


@partial(vmap)
def filtering_operator(elem1, elem2):
# # note the jsc everywhere
A1, b1, C1, J1, eta1 = elem1
A2, b2, C2, J2, eta2 = elem2
dim = A1.shape[0]
I_var = jnp.eye(dim) # note the jnp

I_C1J2 = I_var + C1 @ J2
temp = jsc.linalg.solve(I_C1J2.T, A2.T).T
A = temp @ A1
b = temp @ (b1 + C1 @ eta2) + b2
C = temp @ C1 @ A2.T + C2

I_J2C1 = I_var + J2 @ C1
temp = jsc.linalg.solve(I_J2C1.T, A1).T

eta = temp @ (eta2 - J2 @ b1) + eta1
J = temp @ J2 @ A1 + J1

return A, b, C, J, eta


def pkf(y, m0, cov0, A, Q, C, R):
initial_elements = make_associative_filtering_elements(C, A, Q, R, m0, cov0, y)
final_elements = associative_scan(filtering_operator, initial_elements)
return final_elements


pkf_func = jit(pkf)


def get_kalman_means(A_scan, b_scan, m0):
"""
Computes the Kalman mean at a single timepoint, the result is:
A_scan @ m0 + b_scan

Returned shape: (state_dimension, 1)
"""
return A_scan @ jnp.expand_dims(m0, axis=1) + jnp.expand_dims(b_scan, axis=1)


def get_kalman_variances(C):
return C


def get_next_cov(A, C, Q, R, filter_cov, filter_mean):
"""
Given the moments of p(x_t | y_1, ..., y_t) (normal filter distribution),
compute the moments of the distribution for:
p(y_{t+1} | y_1, ..., y_t)

Params:
A (np.ndarray): Shape (state_dimension, state_dimension) Process coeff matrix
C (np.ndarray): Shape (obs_dimension, state_dimension) Observation coeff matrix
Q (np.ndarray): Shape (state_dimension, state_dimension). Process noise covariance matrix.
R (np.ndarray): Shape (obs_dimension, obs_dimension). Observation noise covariance matrix.
filter_cov (np.ndarray). Shape (state_dimension, state_dimension). Filtered covariance
filter_mean (np.ndarray). Shape (state_dimension, 1). Filter mean

Returns:
mean (np.ndarray). Shape (obs_dimension, 1)
cov (np.ndarray). Shape (obs_dimension, obs_dimension).
"""
mean = C @ A @ filter_mean
cov = C @ (A @ filter_cov @ A.T + Q) @ C.T + R
return mean, cov


def compute_marginal_nll(value, mean, covariance):
return -1 * jax.scipy.stats.multivariate_normal.logpdf(value, mean, covariance)


def parallel_loss_single(A_scan, b_scan, C_scan, A, C, Q, R, next_observation, m0):
curr_mean = get_kalman_means(A_scan, b_scan, m0)
curr_cov = get_kalman_variances(C_scan) # Placeholder; just returns identity

next_mean, next_cov = get_next_cov(A, C, Q, R, curr_cov, curr_mean)
return jnp.squeeze(curr_mean), curr_cov, compute_marginal_nll(jnp.squeeze(next_observation),
jnp.squeeze(next_mean), next_cov)


parallel_loss_func_vmap = jit(
vmap(parallel_loss_single, in_axes=(0, 0, 0, None, None, None, None, 0, None),
out_axes=(0, 0, 0)))


@partial(jit)
def y1_given_x0_nll(C, A, Q, R, m0, cov0, obs):
y1_predictive_mean = C @ A @ jnp.expand_dims(m0, axis=1)
y1_predictive_cov = C @ (A @ cov0 @ A.T + Q) @ C.T + R
addend = -1 * jax.scipy.stats.multivariate_normal.logpdf(obs, jnp.squeeze(y1_predictive_mean),
y1_predictive_cov)
return addend


def pkf_and_loss(y, m0, cov0, A, Q, C, R):
A_scan, b_scan, C_scan, _, _ = pkf_func(y, m0, cov0, A, Q, C, R)

# Gives us the NLL for p(y_i | y_1, ..., y_{i-1}) for i > 1.
# Need to use the parallel scan outputs for this. i = 1 handled below
filtered_states, filtered_covariances, losses = parallel_loss_func_vmap(A_scan[:-1],
b_scan[:-1],
C_scan[:-1], A, C, Q,
R, y[1:], m0)

# Gives us the NLL for p_y(y_1 | x_0)
addend = y1_given_x0_nll(C, A, Q, R, m0, cov0, y[0])

final_mean = get_kalman_means(A_scan[-1], b_scan[-1], m0).T
final_covariance = jnp.expand_dims(get_kalman_variances(C_scan[-1]), axis=0)
filtered_states = jnp.concatenate([filtered_states, final_mean], axis=0)
filtered_variances = jnp.concatenate([filtered_covariances, final_covariance], axis=0)
return filtered_states, filtered_variances, jnp.sum(losses) + addend


# -------------------------------------------------------------------------------------
# Misc: These miscellaneous functions generally have specific computations used by the
# core functions or the smoothers
Expand Down
16 changes: 7 additions & 9 deletions eks/multicam_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def ensemble_kalman_smoother_multicam(
"""

# --------------------------------------------------------------
# interpolate right cam markers to left cam timestamps
# Setup: Interpolate right cam markers to left cam timestamps
# --------------------------------------------------------------
num_cameras = len(camera_names)
markers_list_stacked_interp = []
Expand All @@ -219,6 +219,7 @@ def ensemble_kalman_smoother_multicam(
for camera in range(num_cameras):
markers_list_interp[camera].append(camera_markers_curr[camera])
camera_likelihoods[camera] = np.asarray(camera_likelihoods[camera])

markers_list_stacked_interp = np.asarray(markers_list_stacked_interp)
markers_list_interp = np.asarray(markers_list_interp)
camera_likelihoods_stacked = np.asarray(camera_likelihoods_stacked)
Expand All @@ -230,6 +231,7 @@ def ensemble_kalman_smoother_multicam(
markers_cam = pd.DataFrame(markers_list_interp[camera][k], columns=keys)
markers_cam[f'{keypoint_ensemble}_likelihood'] = camera_likelihoods_stacked[k][camera]
markers_list_cams[camera].append(markers_cam)

# compute ensemble median for each camera
cam_ensemble_preds = []
cam_ensemble_vars = []
Expand Down Expand Up @@ -280,21 +282,17 @@ def ensemble_kalman_smoother_multicam(
# latent variables (observed)
good_z_t_obs = good_ensemble_pcs # latent variables - true 3D pca

# ------ Set values for kalman filter ------
# --------------------------------------------------------------
# Kalman Filter
# --------------------------------------------------------------
m0 = np.asarray([0.0, 0.0, 0.0]) # initial state: mean
S0 = np.asarray([[np.var(good_z_t_obs[:, 0]), 0.0, 0.0],
[0.0, np.var(good_z_t_obs[:, 1]), 0.0],
[0.0, 0.0, np.var(good_z_t_obs[:, 2])]]) # diagonal: var

A = np.asarray([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) # state-transition matrix,

# Q = np.asarray([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) <-- state-cov matrix?

A = np.asarray([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) # state-transition matrix
d_t = good_z_t_obs[1:] - good_z_t_obs[:-1]

C = ensemble_pca.components_.T # Measurement function is inverse transform of PCA
R = np.eye(ensemble_pca.components_.shape[1]) # placeholder diagonal matrix for ensemble var

cov_matrix = np.cov(d_t.T)

# Call functions from ensemble_kalman to optimize smooth_param before filtering and smoothing
Expand Down
75 changes: 6 additions & 69 deletions eks/singlecam_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
jax_ensemble,
jax_forward_pass,
jax_forward_pass_nlls,
pkf_and_loss,
)
from eks.utils import crop_frames, format_data, make_dlc_pandas_index

Expand Down Expand Up @@ -377,30 +376,13 @@ def singlecam_optimize_smooth(
if verbose:
print(f'Correlated keypoint blocks: {blocks}')

# Depending on whether we use GPU, choose parallel or sequential smoothing param optimization
try:
_ = jax.device_put(jax.numpy.ones(1), device=jax.devices('gpu')[0])
if verbose:
print("Using GPU")
@partial(jit)
def nll_loss_sequential_scan(s, cov_mats, cropped_ys, m0s, S0s, Cs, As, Rs, ensemble_vars):
s = jnp.exp(s) # To ensure positivity
return singlecam_smooth_min(
s, cov_mats, cropped_ys, m0s, S0s, Cs, As, Rs, ensemble_vars)

@partial(jit)
def nll_loss_parallel_scan(s, cov_mats, cropped_ys, m0s, S0s, Cs, As, Rs):
s = jnp.exp(s) # To ensure positivity
output = singlecam_smooth_min_parallel(s, cov_mats, cropped_ys, m0s, S0s, Cs, As, Rs)
return output

loss_function = nll_loss_parallel_scan
except:
if verbose:
print("Using CPU")

@partial(jit)
def nll_loss_sequential_scan(s, cov_mats, cropped_ys, m0s, S0s, Cs, As, Rs, ensemble_vars):
s = jnp.exp(s) # To ensure positivity
return singlecam_smooth_min(
s, cov_mats, cropped_ys, m0s, S0s, Cs, As, Rs, ensemble_vars)

loss_function = nll_loss_sequential_scan
loss_function = nll_loss_sequential_scan

# Optimize smooth_param
if smooth_param is not None:
Expand Down Expand Up @@ -516,51 +498,6 @@ def singlecam_smooth_min(smooth_param, cov_mats, ys, m0s, S0s, Cs, As, Rs, ensem
return nlls


def inner_smooth_min_routine_parallel(y, m0, S0, A, Q, C, R):
# Run filtering with the current smooth_param
means, covariances, NLL = pkf_and_loss(y, m0, S0, A, Q, C, R)
return jnp.sum(NLL)


inner_smooth_min_routine_parallel_vmap = jit(
vmap(inner_smooth_min_routine_parallel, in_axes=(0, 0, 0, 0, 0, 0, 0)))


# ------------------------------------------------------------------------------------------------
# Routines that use the parallel scan kalman filter implementation to arrive at the NLL function.
# Note: This should only be run on GPUs
# ------------------------------------------------------------------------------------------------

def singlecam_smooth_min_parallel(
smooth_param, cov_mats, observations, initial_means, initial_covariances, Cs, As, Rs,
):
"""
Computes the maximum likelihood estimator for the process noise variance (smoothness param).
This function is parallelized to process all keypoints in a given block.
KEY: This function uses the parallel scan algorithm, which has effectively O(log(n))
runtime on GPUs. On CPUs, it is slower than the jax.lax.scan implementation above.

Parameters:
smooth_param (float): Smoothing parameter.
block (list): List of blocks.
cov_mats (np.ndarray): Covariance matrices.
ys (np.ndarray): Observations.
m0s (np.ndarray): Initial mean state.
S0s (np.ndarray): Initial state covariance.
Cs (np.ndarray): Measurement function.
As (np.ndarray): State-transition matrix.
Rs (np.ndarray): Measurement noise covariance.

Returns:
float: Negative log-likelihood.
"""
# Adjust Q based on smooth_param and cov_matrix
Qs = smooth_param * cov_mats
values = inner_smooth_min_routine_parallel_vmap(observations, initial_means,
initial_covariances, As, Qs, Cs, Rs)
return jnp.sum(values)


def final_forwards_backwards_pass(process_cov, s, ys, m0s, S0s, Cs, As, Rs, ensemble_vars):
"""
Perform final smoothing with the optimized smoothing parameters.
Expand Down
Loading