Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ieivanov committed Jul 1, 2024
1 parent 13f0d4b commit 1983410
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 27 deletions.
6 changes: 4 additions & 2 deletions mantis/cli/estimate_stabilization.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,6 @@ def estimate_position_focus(


def get_mean_z_positions(dataframe_path: Path, verbose: bool = False) -> None:
import matplotlib.pyplot as plt

df = pd.read_csv(dataframe_path)

# Sort the DataFrame based on 'time_idx'
Expand All @@ -79,13 +77,17 @@ def get_mean_z_positions(dataframe_path: Path, verbose: bool = False) -> None:

# Get the mean of positions for each time point
average_focus_idx = df.groupby("time_idx")["focus_idx"].mean().reset_index()

if verbose:
import matplotlib.pyplot as plt

# Get the moving average of the focus_idx
plt.plot(average_focus_idx["focus_idx"], linestyle="--", label="mean of all positions")
plt.xlabel('Time index')
plt.ylabel('Focus index')
plt.legend()
plt.savefig(dataframe_path.parent / "z_drift.png")

return average_focus_idx["focus_idx"].values


Expand Down
40 changes: 15 additions & 25 deletions mantis/tests/test_analysis/test_stabilization.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,30 @@
import io

import numpy as np
import pandas as pd

from mantis.cli.estimate_stabilization import estimate_position_focus, get_mean_z_positions


def test_estimate_position_focus():
# Create input data
z_positions = [1, 2, 3, 4, 5]
focus_scores = [0.1, 0.2, 0.3, 0.4, 0.5]

# Call the function
result = estimate_position_focus(z_positions, focus_scores)

# Check the result
assert isinstance(result, tuple)
assert len(result) == 2
assert isinstance(result[0], float)
assert isinstance(result[1], float)
from mantis.cli.estimate_stabilization import get_mean_z_positions


def test_mean_z_positions():
# Create input data
df = pd.DataFrame(
{
"channel": ["GFP"],
"channel_idx": [0, 0, 0, 0, 0, 0, 0, 0],
"time_min": [1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0],
"focal_idx": [20, 25, 19, 16, 24, 29, 21, 18],
"position_idx": [1, 2, 3, 4, 1, 2, 3, 4],
"position": ['0/2/000000'] * 4 + ['0/2/001002'] * 4,
"time_idx": [0, 1, 2, 3] * 2,
"channel": ["GFP"] * 8,
"focus_idx": [20, 25, np.nan, 16, 24, 29, 21, 18],
}
)

# Create a pretend file
s_buf = io.StringIO()
df.to_csv(s_buf)
s_buf.seek(0)

# Call the function
result = get_mean_z_positions(df, 0)
pos_1 = np.array([20, 25, 19, 16]).mean()
pos_2 = np.array([24, 29, 21, 18]).mean()
result = get_mean_z_positions(s_buf, verbose=False)
correct_result = np.array([22, 27, 25, 17])

# Check the result
assert result == [pos_1, pos_2]
assert all(result == correct_result)

0 comments on commit 1983410

Please sign in to comment.