From adbc5ff8596fba3011910327f76bcfd2f86171f0 Mon Sep 17 00:00:00 2001 From: Harshil <37377066+harshil21@users.noreply.github.com> Date: Mon, 16 Sep 2024 16:55:52 -0400 Subject: [PATCH] More tests and fix few bugs --- airbrakes/airbrakes.py | 4 ++ airbrakes/imu/data_processor.py | 21 +++++--- airbrakes/logger.py | 2 +- airbrakes/servo.py | 8 ++- airbrakes/state.py | 7 +-- tests/conftest.py | 27 ++++++++++ tests/test_airbrakes.py | 22 ++++++++ tests/test_data_processor.py | 96 +++++++++++++++++++++++++++++++++ tests/test_imu.py | 9 ++-- tests/test_logger.py | 27 +++++----- 10 files changed, 193 insertions(+), 30 deletions(-) create mode 100644 tests/test_airbrakes.py create mode 100644 tests/test_data_processor.py diff --git a/airbrakes/airbrakes.py b/airbrakes/airbrakes.py index dbe3ec48..8fb133a9 100644 --- a/airbrakes/airbrakes.py +++ b/airbrakes/airbrakes.py @@ -66,6 +66,10 @@ def update(self) -> None: # behind on processing data_packets: collections.deque[IMUDataPacket] = self.imu.get_imu_data_packets() + # This should never happen, but if it does, we want to not error out and wait for packets + if not data_packets: + return + # Update the processed data with the new data packets. We only care about EstimatedDataPackets self.data_processor.update_data( [data_packet for data_packet in data_packets if isinstance(data_packet, EstimatedDataPacket)] diff --git a/airbrakes/imu/data_processor.py b/airbrakes/imu/data_processor.py index e74c9ec2..9904b85e 100644 --- a/airbrakes/imu/data_processor.py +++ b/airbrakes/imu/data_processor.py @@ -15,13 +15,18 @@ class IMUDataProcessor: :param data_points: A sequence of EstimatedDataPacket objects to process. """ - __slots__ = ("_avg_accel", "_avg_accel_mag", "_max_altitude", "data_points") + __slots__ = ("_avg_accel", "_avg_accel_mag", "_max_altitude", "_data_points") def __init__(self, data_points: Sequence[EstimatedDataPacket]): - self.data_points: Sequence[EstimatedDataPacket] = data_points self._avg_accel: tuple[float, float, float] = (0.0, 0.0, 0.0) self._avg_accel_mag: float = 0.0 self._max_altitude: float = 0.0 + self._data_points: Sequence[EstimatedDataPacket] + + if data_points: # actually update the data on init + self.update_data(data_points) + else: + self._data_points: Sequence[EstimatedDataPacket] = data_points def __str__(self) -> str: return ( @@ -37,11 +42,13 @@ def update_data(self, data_points: Sequence[EstimatedDataPacket]) -> None: altitude. :param data_points: A sequence of EstimatedDataPacket objects to process. """ - self.data_points = data_points + # if not data_points: + # raise ValueError("Data packets must be non-empty!") + self._data_points = data_points a_x, a_y, a_z = self._compute_averages() self._avg_accel = (a_x, a_y, a_z) self._avg_accel_mag = (self._avg_accel[0] ** 2 + self._avg_accel[1] ** 2 + self._avg_accel[2] ** 2) ** 0.5 - self._max_altitude = max(*(data_point.estPressureAlt for data_point in self.data_points), self._max_altitude) + self._max_altitude = max(*(data_point.estPressureAlt for data_point in self._data_points), self._max_altitude) def _compute_averages(self) -> tuple[float, float, float]: """ @@ -50,9 +57,9 @@ def _compute_averages(self) -> tuple[float, float, float]: """ # calculate the average acceleration in the x, y, and z directions # TODO: Test what these accel values actually look like - x_accel = stats.fmean(data_point.estCompensatedAccelX for data_point in self.data_points) - y_accel = stats.fmean(data_point.estCompensatedAccelY for data_point in self.data_points) - z_accel = stats.fmean(data_point.estCompensatedAccelZ for data_point in self.data_points) + x_accel = stats.fmean(data_point.estCompensatedAccelX for data_point in self._data_points) + y_accel = stats.fmean(data_point.estCompensatedAccelY for data_point in self._data_points) + z_accel = stats.fmean(data_point.estCompensatedAccelZ for data_point in self._data_points) # TODO: Calculate avg velocity if that's also available return x_accel, y_accel, z_accel diff --git a/airbrakes/logger.py b/airbrakes/logger.py index 000a445a..c7a98e1e 100644 --- a/airbrakes/logger.py +++ b/airbrakes/logger.py @@ -43,7 +43,7 @@ def __init__(self, log_dir: Path): self._log_queue: multiprocessing.Queue[dict[str, str] | str] = multiprocessing.Queue() # Start the logging process - self._log_process = multiprocessing.Process(target=self._logging_loop) + self._log_process = multiprocessing.Process(target=self._logging_loop, name="Logger") @property def is_running(self) -> bool: diff --git a/airbrakes/servo.py b/airbrakes/servo.py index 737bafa0..99d10fd7 100644 --- a/airbrakes/servo.py +++ b/airbrakes/servo.py @@ -10,14 +10,18 @@ class Servo: __slots__ = ("current_extension", "max_extension", "min_extension", "servo") - def __init__(self, gpio_pin_number: int, min_extension: float, max_extension: float): + def __init__(self, gpio_pin_number: int, min_extension: float, max_extension: float, + pin_factory=None): self.min_extension = min_extension self.max_extension = max_extension self.current_extension = 0.0 # Sets up the servo with the specified GPIO pin number # For this to work, you have to run the pigpio daemon on the Raspberry Pi (sudo pigpiod) - gpiozero.Device.pin_factory = gpiozero.pins.pigpio.PiGPIOFactory() + if pin_factory is None: + gpiozero.Device.pin_factory = gpiozero.pins.pigpio.PiGPIOFactory() + else: + gpiozero.Device.pin_factory = pin_factory self.servo = gpiozero.Servo(gpio_pin_number) def set_extension(self, extension: float): diff --git a/airbrakes/state.py b/airbrakes/state.py index 602134de..19f9f8a3 100644 --- a/airbrakes/state.py +++ b/airbrakes/state.py @@ -1,9 +1,10 @@ """Module for the finite state machine that represents which state of flight we are in.""" from abc import ABC, abstractmethod -from typing import override +from typing import override, TYPE_CHECKING -from airbrakes.airbrakes import AirbrakesContext +if TYPE_CHECKING: + from airbrakes.airbrakes import AirbrakesContext class State(ABC): @@ -22,7 +23,7 @@ class State(ABC): __slots__ = ("context",) - def __init__(self, context: AirbrakesContext): + def __init__(self, context: "AirbrakesContext"): """ :param context: The state context object that will be used to interact with the electronics """ diff --git a/tests/conftest.py b/tests/conftest.py index e69de29b..22c0265b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -0,0 +1,27 @@ +"""Module where fixtures are shared between all test files.""" + +import pytest +from pathlib import Path + +from airbrakes.constants import FREQUENCY, PORT, UPSIDE_DOWN, SERVO_PIN, MIN_EXTENSION, MAX_EXTENSION +from airbrakes.imu.imu import IMU +from airbrakes.logger import Logger +from airbrakes.servo import Servo + + +from gpiozero.pins.mock import MockFactory, MockPWMPin + + +LOG_PATH = Path("tests/logs") + +@pytest.fixture +def logger(): + return Logger(LOG_PATH) + +@pytest.fixture +def imu(): + return IMU(port=PORT, frequency=FREQUENCY, upside_down=UPSIDE_DOWN) + +@pytest.fixture +def servo(): + return Servo(SERVO_PIN, MIN_EXTENSION, MAX_EXTENSION, pin_factory=MockFactory(pin_class=MockPWMPin)) \ No newline at end of file diff --git a/tests/test_airbrakes.py b/tests/test_airbrakes.py new file mode 100644 index 00000000..2953009c --- /dev/null +++ b/tests/test_airbrakes.py @@ -0,0 +1,22 @@ +import pytest + +from airbrakes.airbrakes import AirbrakesContext + + +@pytest.fixture +def airbrakes_context(imu, logger, servo): + return AirbrakesContext(logger, servo, imu) + + +class TestAirbrakesContext: + """Tests the AirbrakesContext class""" + + def test_slots(self, airbrakes_context): + inst = airbrakes_context + for attr in inst.__slots__: + assert getattr(inst, attr, "err") != "err", f"got extra slot '{attr}'" + + def test_init(self, airbrakes_context, logger, imu, servo): + assert airbrakes_context.logger == logger + assert airbrakes_context.servo == servo + assert airbrakes_context.imu == imu \ No newline at end of file diff --git a/tests/test_data_processor.py b/tests/test_data_processor.py new file mode 100644 index 00000000..f33724c1 --- /dev/null +++ b/tests/test_data_processor.py @@ -0,0 +1,96 @@ +import math +import random + +import pytest + +from airbrakes.imu.data_processor import IMUDataProcessor +from airbrakes.imu.imu_data_packet import EstimatedDataPacket + + +def simulate_altitude_sine_wave(n_points=1000, frequency=0.01, amplitude=100, noise_level=3, base_altitude=20): + """Generates a random distribution of altitudes that follow a sine wave pattern, with some + noise added to simulate variations in the readings. + + :param n_points: The number of altitude points to generate. + :param frequency: The frequency of the sine wave. + :param amplitude: The amplitude of the sine wave. + :param noise_level: The standard deviation of the Gaussian noise to add. + :param base_altitude: The base altitude, i.e. starting altitude from sea level. + """ + altitudes = [] + for i in range(n_points): + # Calculate the sine wave value + # sine wave roughly models the altitude of the rocket + sine_value = amplitude * math.sin(math.pi * i / (n_points - 1)) + # Add Gaussian noise + noise = random.gauss(0, noise_level) + # Calculate the altitude at this point + altitude_value = base_altitude + sine_value + noise + altitudes.append(altitude_value) + return altitudes + + +@pytest.fixture +def data_processor(): + # list of randomly increasing altitudes up to 1000 items + sample_data = [ + EstimatedDataPacket( + 1, estCompensatedAccelX=1, estCompensatedAccelY=2, estCompensatedAccelZ=3, estPressureAlt=20 + ) + ] + return IMUDataProcessor(sample_data) + + +class TestIMUDataProcessor: + """Tests the IMUDataProcessor class""" + + def test_slots(self, data_processor): + inst = data_processor + for attr in inst.__slots__: + assert getattr(inst, attr, "err") != "err", f"got extra slot '{attr}'" + + def test_init(self, data_processor): + d = IMUDataProcessor([]) + assert d._avg_accel == (0.0, 0.0, 0.0) + assert d._avg_accel_mag == 0.0 + assert d._max_altitude == 0.0 + + d = data_processor + assert d._avg_accel == (1, 2, 3) + assert d._avg_accel_mag == math.sqrt(1**2 + 2**2 + 3**2) + assert d._max_altitude == 20 + + def test_str(self, data_processor): + assert ( + str(data_processor) == "IMUDataProcessor(avg_acceleration=(1.0, 2.0, 3.0), " + "avg_acceleration_mag=3.7416573867739413, max_altitude=20)" + ) + + def test_update_data(self, data_processor): + d = data_processor + d.update_data([ + EstimatedDataPacket( + 1, estCompensatedAccelX=1, estCompensatedAccelY=2, estCompensatedAccelZ=3, estPressureAlt=20 + ), + EstimatedDataPacket( + 2, estCompensatedAccelX=2, estCompensatedAccelY=3, estCompensatedAccelZ=4, estPressureAlt=30 + ), + ]) + assert d._avg_accel == (1.5, 2.5, 3.5) == d.avg_acceleration + assert d._avg_accel_mag == math.sqrt(1.5**2 + 2.5**2 + 3.5**2) == d.avg_acceleration_mag + assert d.avg_acceleration_z == 3.5 + assert d._max_altitude == 30 == d.max_altitude + + def test_max_altitude(self, data_processor): + """Tests whether the max altitude is correctly calculated even when alititude decreases""" + d = data_processor + altitudes = simulate_altitude_sine_wave(n_points=1000) + # run update_data every 10 packets, to simulate actual data processing in real time: + for i in range(0, len(altitudes), 10): + d.update_data([ + EstimatedDataPacket( + i, estCompensatedAccelX=1, estCompensatedAccelY=2, estCompensatedAccelZ=3, estPressureAlt=alt + ) + for alt in altitudes[i : i + 10] + ]) + assert d.max_altitude == max(altitudes) diff --git a/tests/test_imu.py b/tests/test_imu.py index 8b454d9c..84691a4a 100644 --- a/tests/test_imu.py +++ b/tests/test_imu.py @@ -10,9 +10,7 @@ from airbrakes.imu.imu_data_packet import EstimatedDataPacket, IMUDataPacket, RawDataPacket -@pytest.fixture -def imu(): - return IMU(port=PORT, frequency=FREQUENCY, upside_down=UPSIDE_DOWN) + RAW_DATA_PACKET_SAMPLING_RATE = 1 / 1000 # 1kHz @@ -22,6 +20,11 @@ def imu(): class TestIMU: """Class to test the IMU class in imu.py""" + def test_slots(self, imu): + inst = imu + for attr in inst.__slots__: + assert getattr(inst, attr, "err") != "err", f"got extra slot '{attr}'" + def test_init(self, imu): assert isinstance(imu._data_queue, multiprocessing.queues.Queue) # Test that _running is a shared boolean multiprocessing.Value: diff --git a/tests/test_logger.py b/tests/test_logger.py index e0b6d250..f966abe5 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -2,33 +2,32 @@ import multiprocessing import multiprocessing.sharedctypes import time -from pathlib import Path import pytest from airbrakes.constants import CSV_HEADERS from airbrakes.imu.imu_data_packet import EstimatedDataPacket, RawDataPacket from airbrakes.logger import Logger - - -@pytest.fixture -def logger(): - return Logger(TestLogger.LOG_PATH) +from tests.conftest import LOG_PATH class TestLogger: """Tests the Logger() class in logger.py""" - LOG_PATH = Path("tests/logs") @pytest.fixture(autouse=True) # autouse=True means run this function before/after every test def clear_directory(self): """Clear the tests/logs directory after running each test.""" yield # This is where the test runs # Test run is over, now clean up - for log in self.LOG_PATH.glob("log_*.csv"): + for log in LOG_PATH.glob("log_*.csv"): log.unlink() + def test_slots(self, logger): + inst = logger + for attr in inst.__slots__: + assert getattr(inst, attr, "err") != "err", f"got extra slot '{attr}'" + def test_init(self, logger): # Test if "logs" directory was created assert logger.log_path.parent.exists() @@ -44,18 +43,18 @@ def test_init(self, logger): def test_init_sets_log_path_correctly(self): # assert no files exist: - assert not list(self.LOG_PATH.glob("log_*.csv")) - logger = Logger(self.LOG_PATH) - expected_log_path = self.LOG_PATH / "log_1.csv" + assert not list(LOG_PATH.glob("log_*.csv")) + logger = Logger(LOG_PATH) + expected_log_path = LOG_PATH / "log_1.csv" assert expected_log_path.exists() assert logger.log_path == expected_log_path - logger_2 = Logger(self.LOG_PATH) - expected_log_path_2 = self.LOG_PATH / "log_2.csv" + logger_2 = Logger(LOG_PATH) + expected_log_path_2 = LOG_PATH / "log_2.csv" assert expected_log_path_2.exists() assert logger_2.log_path == expected_log_path_2 # Test only 2 csv files exist: - assert set(self.LOG_PATH.glob("log_*.csv")) == {expected_log_path, expected_log_path_2} + assert set(LOG_PATH.glob("log_*.csv")) == {expected_log_path, expected_log_path_2} def test_init_log_file_has_correct_headers(self, logger): with logger.log_path.open() as f: