flake8 cleanup
themattinthehatt committed Dec 11, 2024
1 parent fbc1369 commit 9e3440b
Showing 8 changed files with 75 additions and 66 deletions.
18 changes: 8 additions & 10 deletions eks/core.py
@@ -1,5 +1,4 @@
from functools import partial
from collections import defaultdict

import jax
import jax.scipy as jsc
@@ -724,27 +723,26 @@ def eks_zscore(eks_predictions, ensemble_means, ensemble_vars, min_ensemble_std=


def compute_covariance_matrix(ensemble_preds):
"""
Compute the covariance matrix E for correlated noise dynamics.
"""Compute the covariance matrix E for correlated noise dynamics.
Parameters:
ensemble_preds: A 3D array of shape (T, n_keypoints, n_coords)
containing the ensemble predictions.
Args:
ensemble_preds: shape (T, n_keypoints, n_coords) containing the ensemble predictions.
Returns:
E: A 2K x 2K covariance matrix where K is the number of keypoints.
E: A 2K x 2K covariance matrix where K is the number of keypoints.
"""
# Get the number of time steps, keypoints, and coordinates
T, n_keypoints, n_coords = ensemble_preds.shape

# Flatten the ensemble predictions to shape (T, 2K) where K is the number of keypoints
flattened_preds = ensemble_preds.reshape(T, -1)
# flattened_preds = ensemble_preds.reshape(T, -1)

# Compute the temporal differences
temporal_diffs = np.diff(flattened_preds, axis=0)
# temporal_diffs = np.diff(flattened_preds, axis=0)

# Compute the covariance matrix of the temporal differences
E = np.cov(temporal_diffs, rowvar=False)
# E = np.cov(temporal_diffs, rowvar=False)

# Index covariance matrix into blocks for each keypoint
cov_mats = []
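For reference, a minimal NumPy sketch of the empirical covariance computation that the commented-out lines above describe (flatten each frame's keypoint coordinates, take frame-to-frame differences, then compute their covariance). This is an illustration under the stated shape assumption, not the code path this commit leaves active; the helper name is hypothetical.

import numpy as np

def empirical_diff_covariance(ensemble_preds):
    """Illustrative only: covariance of frame-to-frame differences.

    ensemble_preds: array of shape (T, n_keypoints, n_coords).
    Returns a (n_keypoints * n_coords) x (n_keypoints * n_coords) matrix,
    i.e. 2K x 2K for x/y data with K keypoints.
    """
    T = ensemble_preds.shape[0]
    # Flatten each frame to a vector of length n_keypoints * n_coords
    flattened_preds = ensemble_preds.reshape(T, -1)
    # Frame-to-frame differences of the flattened predictions
    temporal_diffs = np.diff(flattened_preds, axis=0)
    # Covariance across coordinates; rowvar=False treats each row as one observation
    return np.cov(temporal_diffs, rowvar=False)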
2 changes: 1 addition & 1 deletion eks/ibl_pupil_smoother.py
@@ -202,7 +202,7 @@ def ensemble_kalman_smoother_ibl_pupil(
[0, 1, 0], [.5, 0, 1],
[.5, 1, 0], [0, 0, 1],
[-.5, 1, 0], [0, 0, 1]
])
])

# placeholder diagonal matrix for ensemble variance
R = np.eye(8)
6 changes: 4 additions & 2 deletions eks/multicam_smoother.py
@@ -164,8 +164,10 @@ def ensemble_kalman_smoother_multicam(
# --------------------------------------
# final cleanup
# --------------------------------------
pdindex = make_dlc_pandas_index([keypoint_ensemble],
labels=["x", "y", "likelihood", "x_var", "y_var", "zscore", "nll", "ensemble_std"])
pdindex = make_dlc_pandas_index(
[keypoint_ensemble],
labels=["x", "y", "likelihood", "x_var", "y_var", "zscore", "nll", "ensemble_std"]
)
camera_indices = []
for camera in range(num_cameras):
camera_indices.append([camera * 2, camera * 2 + 1])
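As an aside, the column index constructed in the hunk above follows the DLC convention of a three-level pandas MultiIndex (scorer / bodyparts / coords). The snippet below is a hypothetical sketch of such an index using the labels from this hunk; the helper name, scorer string, and level names are assumptions, and the real make_dlc_pandas_index in eks.utils may differ.

import pandas as pd

# Hypothetical illustration -- not the eks.utils implementation.
def make_dlc_style_index(keypoint_names, labels):
    return pd.MultiIndex.from_product(
        [["ensemble-kalman_tracker"], keypoint_names, labels],
        names=["scorer", "bodyparts", "coords"],
    )

pdindex = make_dlc_style_index(
    ["keypoint_0"],
    ["x", "y", "likelihood", "x_var", "y_var", "zscore", "nll", "ensemble_std"],
)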
2 changes: 1 addition & 1 deletion eks/singlecam_smoother.py
@@ -19,7 +19,7 @@
jax_forward_pass_nlls,
pkf_and_loss,
)
from eks.utils import crop_frames, format_data, make_dlc_pandas_index, populate_output_dataframe
from eks.utils import crop_frames, format_data, make_dlc_pandas_index


def fit_eks_singlecam(
2 changes: 1 addition & 1 deletion scripts/ibl_pupil_example.py
@@ -4,7 +4,7 @@

from eks.command_line_args import handle_io, handle_parse_args
from eks.ibl_pupil_smoother import fit_eks_pupil
from eks.utils import format_data, plot_results
from eks.utils import plot_results


smoother_type = 'ibl_pupil'
104 changes: 60 additions & 44 deletions tests/test_core.py
@@ -1,10 +1,8 @@
import pytest
import numpy as np
import jax.numpy as jnp
import jax
import pandas as pd
from eks.core import ensemble, kalman_dot, forward_pass, backward_pass, compute_nll, jax_ensemble
from collections import defaultdict


def test_ensemble():
@@ -47,11 +45,11 @@ def test_ensemble():
expected_mean = np.nanmedian(stack, axis=1)
expected_variance = np.nanvar(stack, axis=1)
assert np.allclose(ensemble_preds[:, i], expected_mean), \
f"Medians not computed correctly in numpy ensemble function"
"Medians not computed correctly in numpy ensemble function"
assert np.allclose(ensemble_vars[:, i], expected_variance), \
f"Vars not computed correctly in numpy ensemble function"
"Vars not computed correctly in numpy ensemble function"
assert np.all(ensemble_likes[:, i] == 0.5), \
f"Likelihoods not computed correctly in numpy ensemble function"
"Likelihoods not computed correctly in numpy ensemble function"

# Run the ensemble function with avg_mode='mean' and var_mode='conf_weighted_var'
ensemble_preds, ensemble_vars, ensemble_likes, ensemble_stacks = ensemble(
@@ -63,11 +61,11 @@ def test_ensemble():
expected_mean = np.nanmean(stack, axis=1)
expected_variance = 2.0 * np.nanvar(stack, axis=1) # 2x since likelihoods all 0.5
assert np.allclose(ensemble_preds[:, i], expected_mean), \
f"Means not computed correctly in numpy ensemble function"
"Means not computed correctly in numpy ensemble function"
assert np.allclose(ensemble_vars[:, i], expected_variance), \
f"Conf weighted vars not computed correctly in numpy ensemble function"
"Conf weighted vars not computed correctly in numpy ensemble function"
assert np.all(ensemble_likes[:, i] == 0.5), \
f"Likelihoods not computed correctly in numpy ensemble function"
"Likelihoods not computed correctly in numpy ensemble function"


def test_kalman_dot_basic():
@@ -166,7 +164,8 @@ def test_kalman_dot_random_values():
# Run kalman_dot
Ks, innovation_cov = kalman_dot(innovation, V, C, R)

# Check if innovation_cov is positive semi-definite (eigenvalues should be non-negative or close to zero)
# Check if innovation_cov is positive semi-definite
# (eigenvalues should be non-negative or close to zero)
eigvals = np.linalg.eigvalsh(innovation_cov)
assert np.all(eigvals >= -1e-8), "Expected innovation_cov to be positive semi-definite"
assert Ks.shape == (n_latents,), f"Expected shape {(n_latents,)}, got {Ks.shape}"
@@ -197,15 +196,16 @@ def test_forward_pass_basic():
mf, Vf, S, innovations, innovation_cov = forward_pass(y, m0, S0, C, R, A, Q, ensemble_vars)

# Check output shapes
assert mf.shape == (T, n_latents), f"Expected shape {(T, n_latents)}, got {mf.shape}"
assert Vf.shape == (
T, n_latents, n_latents), f"Expected shape {(T, n_latents, n_latents)}, got {Vf.shape}"
assert S.shape == (
T, n_latents, n_latents), f"Expected shape {(T, n_latents, n_latents)}, got {S.shape}"
assert innovations.shape == (
T, n_keypoints), f"Expected shape {(T, n_keypoints)}, got {innovations.shape}"
assert innovation_cov.shape == (T, n_keypoints,
n_keypoints), f"Expected shape {(T, n_keypoints, n_keypoints)}, got {innovation_cov.shape}"
assert mf.shape == (T, n_latents), \
f"Expected shape {(T, n_latents)}, got {mf.shape}"
assert Vf.shape == (T, n_latents, n_latents), \
f"Expected shape {(T, n_latents, n_latents)}, got {Vf.shape}"
assert S.shape == (T, n_latents, n_latents), \
f"Expected shape {(T, n_latents, n_latents)}, got {S.shape}"
assert innovations.shape == (T, n_keypoints), \
f"Expected shape {(T, n_keypoints)}, got {innovations.shape}"
assert innovation_cov.shape == (T, n_keypoints, n_keypoints), \
f"Expected shape {(T, n_keypoints, n_keypoints)}, got {innovation_cov.shape}"


def test_forward_pass_with_nan_values():
Expand Down Expand Up @@ -236,10 +236,12 @@ def test_forward_pass_with_nan_values():
else:
if found_nan_propagation:
# Once NaNs are expected, allow them to propagate
assert np.isnan(mf[t]).all(), f"Expected NaNs in mf at time {t} due to propagation, found finite values"
assert np.isnan(mf[t]).all(), \
f"Expected NaNs in mf at time {t} due to propagation, found finite values"
else:
# Check for finite values up until the first NaN propagation
assert np.isfinite(mf[t]).all(), f"Expected finite values in mf at time {t}, found NaNs"
assert np.isfinite(mf[t]).all(), \
f"Expected finite values in mf at time {t}, found NaNs"

# Ensure Vf and innovation_cov have finite values where possible
assert np.isfinite(Vf).all(), "Non-finite values found in Vf"
@@ -265,13 +267,14 @@ def test_forward_pass_single_sample():
mf, Vf, S, innovations, innovation_cov = forward_pass(y, m0, S0, C, R, A, Q, ensemble_vars)

# Check output shapes with a single sample
assert mf.shape == (T, n_latents), f"Expected shape {(T, n_latents)}, got {mf.shape}"
assert Vf.shape == (
T, n_latents, n_latents), f"Expected shape {(T, n_latents, n_latents)}, got {Vf.shape}"
assert S.shape == (
T, n_latents, n_latents), f"Expected shape {(T, n_latents, n_latents)}, got {S.shape}"
assert innovations.shape == (
T, n_keypoints), f"Expected shape {(T, n_keypoints)}, got {innovations.shape}"
assert mf.shape == (T, n_latents), \
f"Expected shape {(T, n_latents)}, got {mf.shape}"
assert Vf.shape == (T, n_latents, n_latents), \
f"Expected shape {(T, n_latents, n_latents)}, got {Vf.shape}"
assert S.shape == (T, n_latents, n_latents), \
f"Expected shape {(T, n_latents, n_latents)}, got {S.shape}"
assert innovations.shape == (T, n_keypoints), \
f"Expected shape {(T, n_keypoints)}, got {innovations.shape}"
assert innovation_cov.shape == (T, n_keypoints, n_keypoints), \
f"Expected shape {(T, n_keypoints, n_keypoints)}, got {innovation_cov.shape}"

@@ -314,9 +317,12 @@ def test_backward_pass_basic():
ms, Vs, CV = backward_pass(y, mf, Vf, S, A)

# Verify shapes of output arrays
assert ms.shape == (T, n_latents), f"Expected shape {(T, n_latents)}, got {ms.shape}"
assert Vs.shape == (T, n_latents, n_latents), f"Expected shape {(T, n_latents, n_latents)}, got {Vs.shape}"
assert CV.shape == (T - 1, n_latents, n_latents), f"Expected shape {(T - 1, n_latents, n_latents)}, got {CV.shape}"
assert ms.shape == (T, n_latents), \
f"Expected shape {(T, n_latents)}, got {ms.shape}"
assert Vs.shape == (T, n_latents, n_latents), \
f"Expected shape {(T, n_latents, n_latents)}, got {Vs.shape}"
assert CV.shape == (T - 1, n_latents, n_latents), \
f"Expected shape {(T - 1, n_latents, n_latents)}, got {CV.shape}"

# Check that ms, Vs, and CV contain finite values
assert np.isfinite(ms).all(), "Non-finite values found in ms"
@@ -342,9 +348,12 @@ def test_backward_pass_with_nan_values():
ms, Vs, CV = backward_pass(y, mf, Vf, S, A)

# Verify shapes of output arrays
assert ms.shape == (T, n_latents), f"Expected shape {(T, n_latents)}, got {ms.shape}"
assert Vs.shape == (T, n_latents, n_latents), f"Expected shape {(T, n_latents, n_latents)}, got {Vs.shape}"
assert CV.shape == (T - 1, n_latents, n_latents), f"Expected shape {(T - 1, n_latents, n_latents)}, got {CV.shape}"
assert ms.shape == (T, n_latents), \
f"Expected shape {(T, n_latents)}, got {ms.shape}"
assert Vs.shape == (T, n_latents, n_latents), \
f"Expected shape {(T, n_latents, n_latents)}, got {Vs.shape}"
assert CV.shape == (T - 1, n_latents, n_latents), \
f"Expected shape {(T - 1, n_latents, n_latents)}, got {CV.shape}"

# Check that ms, Vs, and CV contain finite values
assert np.isfinite(ms).all(), "Non-finite values found in ms"
@@ -368,9 +377,12 @@ def test_backward_pass_single_timestep():
ms, Vs, CV = backward_pass(y, mf, Vf, S, A)

# Verify shapes of output arrays
assert ms.shape == (T, n_latents), f"Expected shape {(T, n_latents)}, got {ms.shape}"
assert Vs.shape == (T, n_latents, n_latents), f"Expected shape {(T, n_latents, n_latents)}, got {Vs.shape}"
assert CV.shape == (T - 1, n_latents, n_latents), f"Expected shape {(T - 1, n_latents, n_latents)}, got {CV.shape}"
assert ms.shape == (T, n_latents), \
f"Expected shape {(T, n_latents)}, got {ms.shape}"
assert Vs.shape == (T, n_latents, n_latents), \
f"Expected shape {(T, n_latents, n_latents)}, got {Vs.shape}"
assert CV.shape == (T - 1, n_latents, n_latents), \
f"Expected shape {(T - 1, n_latents, n_latents)}, got {CV.shape}"

# Check that ms and Vs contain finite values
assert np.isfinite(ms).all(), "Non-finite values found in ms"
@@ -395,11 +407,12 @@ def test_backward_pass_singular_S_matrix():
ms, Vs, CV = backward_pass(y, mf, Vf, S, A)

# Verify shapes of output arrays
assert ms.shape == (T, n_latents), f"Expected shape {(T, n_latents)}, got {ms.shape}"
assert Vs.shape == (
T, n_latents, n_latents), f"Expected shape {(T, n_latents, n_latents)}, got {Vs.shape}"
assert CV.shape == (T - 1, n_latents,
n_latents), f"Expected shape {(T - 1, n_latents, n_latents)}, got {CV.shape}"
assert ms.shape == (T, n_latents), \
f"Expected shape {(T, n_latents)}, got {ms.shape}"
assert Vs.shape == (T, n_latents, n_latents), \
f"Expected shape {(T, n_latents, n_latents)}, got {Vs.shape}"
assert CV.shape == (T - 1, n_latents, n_latents), \
f"Expected shape {(T - 1, n_latents, n_latents)}, got {CV.shape}"

# Check for finite values in outputs, expecting NaNs or Infs due to singular S
assert np.all(np.isfinite(ms)), "Non-finite values found in ms"
@@ -427,9 +440,12 @@ def test_backward_pass_random_values():
ms, Vs, CV = backward_pass(y, mf, Vf, S, A)

# Verify shapes of output arrays
assert ms.shape == (T, n_latents), f"Expected shape {(T, n_latents)}, got {ms.shape}"
assert Vs.shape == (T, n_latents, n_latents), f"Expected shape {(T, n_latents, n_latents)}, got {Vs.shape}"
assert CV.shape == (T - 1, n_latents, n_latents), f"Expected shape {(T - 1, n_latents, n_latents)}, got {CV.shape}"
assert ms.shape == (T, n_latents), \
f"Expected shape {(T, n_latents)}, got {ms.shape}"
assert Vs.shape == (T, n_latents, n_latents), \
f"Expected shape {(T, n_latents, n_latents)}, got {Vs.shape}"
assert CV.shape == (T - 1, n_latents, n_latents), \
f"Expected shape {(T - 1, n_latents, n_latents)}, got {CV.shape}"

# Check that ms, Vs, and CV contain finite values
assert np.isfinite(ms).all(), "Non-finite values found in ms"
2 changes: 0 additions & 2 deletions tests/test_ibl_pupil_smoother.py
@@ -1,5 +1,3 @@
from unittest.mock import MagicMock, patch

import numpy as np
import pandas as pd
import pytest
5 changes: 0 additions & 5 deletions tests/test_singlecam_smoother.py
@@ -1,11 +1,6 @@
import os
from unittest.mock import MagicMock, patch

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import pytest

from eks.singlecam_smoother import (
adjust_observations,
