diff --git a/mantis/cli/estimate_stabilization.py b/mantis/cli/estimate_stabilization.py index ac43050a..c152e038 100644 --- a/mantis/cli/estimate_stabilization.py +++ b/mantis/cli/estimate_stabilization.py @@ -67,8 +67,6 @@ def estimate_position_focus( def get_mean_z_positions(dataframe_path: Path, verbose: bool = False) -> None: - import matplotlib.pyplot as plt - df = pd.read_csv(dataframe_path) # Sort the DataFrame based on 'time_idx' @@ -79,13 +77,17 @@ def get_mean_z_positions(dataframe_path: Path, verbose: bool = False) -> None: # Get the mean of positions for each time point average_focus_idx = df.groupby("time_idx")["focus_idx"].mean().reset_index() + if verbose: + import matplotlib.pyplot as plt + # Get the moving average of the focus_idx plt.plot(average_focus_idx["focus_idx"], linestyle="--", label="mean of all positions") plt.xlabel('Time index') plt.ylabel('Focus index') plt.legend() plt.savefig(dataframe_path.parent / "z_drift.png") + return average_focus_idx["focus_idx"].values diff --git a/mantis/tests/test_analysis/test_stabilization.py b/mantis/tests/test_analysis/test_stabilization.py index f487ea01..efbbc16b 100644 --- a/mantis/tests/test_analysis/test_stabilization.py +++ b/mantis/tests/test_analysis/test_stabilization.py @@ -1,40 +1,30 @@ +import io + import numpy as np import pandas as pd -from mantis.cli.estimate_stabilization import estimate_position_focus, get_mean_z_positions - - -def test_estimate_position_focus(): - # Create input data - z_positions = [1, 2, 3, 4, 5] - focus_scores = [0.1, 0.2, 0.3, 0.4, 0.5] - - # Call the function - result = estimate_position_focus(z_positions, focus_scores) - - # Check the result - assert isinstance(result, tuple) - assert len(result) == 2 - assert isinstance(result[0], float) - assert isinstance(result[1], float) +from mantis.cli.estimate_stabilization import get_mean_z_positions def test_mean_z_positions(): # Create input data df = pd.DataFrame( { - "channel": ["GFP"], - "channel_idx": [0, 0, 0, 0, 0, 0, 0, 0], - "time_min": [1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0], - "focal_idx": [20, 25, 19, 16, 24, 29, 21, 18], - "position_idx": [1, 2, 3, 4, 1, 2, 3, 4], + "position": ['0/2/000000'] * 4 + ['0/2/001002'] * 4, + "time_idx": [0, 1, 2, 3] * 2, + "channel": ["GFP"] * 8, + "focus_idx": [20, 25, np.nan, 16, 24, 29, 21, 18], } ) + # Create a pretend file + s_buf = io.StringIO() + df.to_csv(s_buf) + s_buf.seek(0) + # Call the function - result = get_mean_z_positions(df, 0) - pos_1 = np.array([20, 25, 19, 16]).mean() - pos_2 = np.array([24, 29, 21, 18]).mean() + result = get_mean_z_positions(s_buf, verbose=False) + correct_result = np.array([22, 27, 25, 17]) # Check the result - assert result == [pos_1, pos_2] + assert all(result == correct_result)