diff --git a/src/wristpy/core/orchestrator.py b/src/wristpy/core/orchestrator.py index 072c25c..0f99cd3 100644 --- a/src/wristpy/core/orchestrator.py +++ b/src/wristpy/core/orchestrator.py @@ -230,19 +230,7 @@ def run( enmo = computations.moving_mean(enmo, epoch_length=epoch_length) anglez = computations.moving_mean(anglez, epoch_length=epoch_length) - # Watches require different criteria due to differences in the sensor values on the - # lower end of the distribution. - - if input.suffix == ".bin": - range_criterion = 0.5 - elif input.suffix == ".gt3x": - range_criterion = 0.05 - else: - raise exceptions.InvalidFileTypeError("Unknown input file type.") - - non_wear_array = metrics.detect_nonwear( - calibrated_acceleration, range_criteria=range_criterion - ) + non_wear_array = metrics.detect_nonwear(calibrated_acceleration) sleep_detector = analytics.GGIRSleepDetection(anglez) sleep_windows = sleep_detector.run_sleep_detection() diff --git a/src/wristpy/processing/metrics.py b/src/wristpy/processing/metrics.py index be44459..be2d6b1 100644 --- a/src/wristpy/processing/metrics.py +++ b/src/wristpy/processing/metrics.py @@ -63,16 +63,15 @@ def detect_nonwear( short_epoch_length: int = 900, n_short_epoch_in_long_epoch: int = 4, std_criteria: float = 0.013, - range_criteria: float = 0.05, ) -> models.Measurement: """Set non_wear_flag based on accelerometer data. - This implements GGIR "2023" non-wear detection algorithm. + This implements a modified version of the GGIR "2023" non-wear detection algorithm. Briefly, the algorithm, creates a sliding window of long epoch length that steps forward by the short epoch length. The long epoch length is an integer multiple of the short epoch length, that can be specified by the user. - It checks if the acceleration data in that long window, for each axis, meets certain - criteria thresholds for the standard deviation and range of acceleration values to + It checks if the acceleration data in that long window, for each axis, meets the + criteria threshold for the standard deviation of acceleration values to compute a non-wear value. The total non-wear value (0, 1, 2, 3) for the long window is the sum of each axis. The non-wear value is applied to all the short windows that make up the long @@ -90,7 +89,7 @@ def detect_nonwear( short_epoch_length: The short window size, in seconds. n_short_epoch_in_long_epoch: Number of short epochs that makeup one long epoch. std_criteria: Threshold criteria for standard deviation. - range_criteria: Threshold criteria for range of acceleration. + Returns: A new Measurment instance with the non-wear flag and corresponding timestamps. @@ -104,7 +103,6 @@ def detect_nonwear( acceleration_grouped_by_short_window, n_short_epoch_in_long_epoch, std_criteria, - range_criteria, ) nonwear_value_array_cleaned = _cleanup_isolated_ones_nonwear_value( @@ -152,7 +150,6 @@ def _compute_nonwear_value_array( grouped_acceleration: pl.DataFrame, n_short_epoch_in_long_epoch: int, std_criteria: float, - range_criteria: float, ) -> np.ndarray: """Helper function to calculate the nonwear value array. @@ -167,7 +164,6 @@ def _compute_nonwear_value_array( grouped_acceleration: The acceleration data grouped into short windows. n_short_epoch_in_long_epoch: Number of short epochs that makeup one long epoch. std_criteria: Threshold criteria for standard deviation. - range_criteria: Threshold criteria for range of acceleration. Returns: Non-wear value array. @@ -183,7 +179,8 @@ def _compute_nonwear_value_array( calculated_nonwear_value = acceleration_selected_long_window.select( pl.col("X", "Y", "Z").map_batches( lambda df: _compute_nonwear_value_per_axis( - df, std_criteria, range_criteria + df, + std_criteria, ) ) ).sum_horizontal() @@ -200,7 +197,8 @@ def _compute_nonwear_value_array( def _compute_nonwear_value_per_axis( - axis_acceleration_data: pl.Series, std_criteria: float, range_criteria: float + axis_acceleration_data: pl.Series, + std_criteria: float, ) -> bool: """Helper function to calculate the nonwear criteria per axis. @@ -210,16 +208,15 @@ def _compute_nonwear_value_per_axis( acceleration data of one axis (length of each list is the number of samples that make up short_epoch_length in seconds). std_criteria: Threshold criteria for standard deviation - range_criteria: Threshold criteria for range of acceleration + Returns: Non-wear value for the axis. """ axis_long_window_data = pl.concat(axis_acceleration_data, how="vertical") axis_std = axis_long_window_data.std() - axis_range = axis_long_window_data.max() - axis_long_window_data.min() + criteria_boolean = axis_std < std_criteria - criteria_boolean = (axis_std < std_criteria) & (axis_range < range_criteria) return criteria_boolean diff --git a/tests/smoke/test_orchestrator_smoke.py b/tests/smoke/test_orchestrator_smoke.py index fb26694..900017f 100644 --- a/tests/smoke/test_orchestrator_smoke.py +++ b/tests/smoke/test_orchestrator_smoke.py @@ -4,7 +4,7 @@ import pytest -from wristpy.core import orchestrator +from wristpy.core import models, orchestrator @pytest.mark.parametrize( @@ -18,11 +18,11 @@ def test_orchestrator_happy_path( assert (tmp_path / file_name).exists() assert isinstance(results, orchestrator.Results) - assert results.enmo is not None - assert results.anglez is not None - assert results.nonwear_epoch is not None - assert results.sleep_windows_epoch is not None - assert results.physical_activity_levels is not None + assert isinstance(results.enmo, models.Measurement) + assert isinstance(results.anglez, models.Measurement) + assert isinstance(results.nonwear_epoch, models.Measurement) + assert isinstance(results.sleep_windows_epoch, models.Measurement) + assert isinstance(results.physical_activity_levels, models.Measurement) def test_orchestrator_different_epoch( @@ -35,8 +35,8 @@ def test_orchestrator_different_epoch( assert (tmp_path / "good_file.csv").exists() assert isinstance(results, orchestrator.Results) - assert results.enmo is not None - assert results.anglez is not None - assert results.nonwear_epoch is not None - assert results.sleep_windows_epoch is not None - assert results.physical_activity_levels is not None + assert isinstance(results.enmo, models.Measurement) + assert isinstance(results.anglez, models.Measurement) + assert isinstance(results.nonwear_epoch, models.Measurement) + assert isinstance(results.sleep_windows_epoch, models.Measurement) + assert isinstance(results.physical_activity_levels, models.Measurement) diff --git a/tests/unit/test_metrics.py b/tests/unit/test_metrics.py index 745c152..7430b9f 100644 --- a/tests/unit/test_metrics.py +++ b/tests/unit/test_metrics.py @@ -155,15 +155,13 @@ def test_compute_nonwear_value_per_axis( create_acceleration: pl.DataFrame, modifier: int, expected_result: int ) -> None: """Test the nonwear value per axis function.""" - std_criteria = modifier - range_criteria = modifier acceleration = create_acceleration.with_columns(pl.col("time").set_sorted()) acceleration_grouped = acceleration.group_by_dynamic( index_column="time", every="5s" ).agg([pl.all().exclude(["time"])]) test_resultx = metrics._compute_nonwear_value_per_axis( - acceleration_grouped["X"], std_criteria, range_criteria + acceleration_grouped["X"], std_criteria=modifier ) assert ( @@ -185,7 +183,6 @@ def test_compute_nonwear_value_array(create_acceleration: pl.DataFrame) -> None: acceleration_grouped, n_short_epoch_in_long_epoch, std_criteria=1, - range_criteria=1, ) assert np.all( @@ -209,8 +206,6 @@ def test_detect_nonwear( """Test the detect nonwear function.""" short_epoch_length = 5 n_short_epoch_in_long_epoch = int(4) - std_criteria = modifier - range_criteria = modifier acceleration_df = create_acceleration acceleration = models.Measurement( measurements=acceleration_df.select(["X", "Y", "Z"]).to_numpy(), @@ -222,8 +217,7 @@ def test_detect_nonwear( acceleration, short_epoch_length, n_short_epoch_in_long_epoch, - std_criteria, - range_criteria, + std_criteria=modifier, ) assert np.all(