Skip to content

Commit

Permalink
More tests and fix a few bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
harshil21 committed Sep 16, 2024
1 parent d3d4ecd commit adbc5ff
Show file tree
Hide file tree
Showing 10 changed files with 193 additions and 30 deletions.
4 changes: 4 additions & 0 deletions airbrakes/airbrakes.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ def update(self) -> None:
# behind on processing
data_packets: collections.deque[IMUDataPacket] = self.imu.get_imu_data_packets()

# This should never happen, but if it does, we want to not error out and wait for packets
if not data_packets:
return

# Update the processed data with the new data packets. We only care about EstimatedDataPackets
self.data_processor.update_data(
[data_packet for data_packet in data_packets if isinstance(data_packet, EstimatedDataPacket)]
Expand Down
21 changes: 14 additions & 7 deletions airbrakes/imu/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,18 @@ class IMUDataProcessor:
:param data_points: A sequence of EstimatedDataPacket objects to process.
"""

__slots__ = ("_avg_accel", "_avg_accel_mag", "_max_altitude", "data_points")
__slots__ = ("_avg_accel", "_avg_accel_mag", "_max_altitude", "_data_points")

Check failure on line 18 in airbrakes/imu/data_processor.py

View workflow job for this annotation

GitHub Actions / build

Ruff (RUF023)

airbrakes/imu/data_processor.py:18:17: RUF023 `IMUDataProcessor.__slots__` is not sorted

def __init__(self, data_points: Sequence[EstimatedDataPacket]):
    """
    Initialize the processor and, if data is already available, process it.

    :param data_points: A sequence of EstimatedDataPacket objects to process.
        May be empty, in which case all statistics stay at their zero defaults
        until update_data() is called with real data.
    """
    self._avg_accel: tuple[float, float, float] = (0.0, 0.0, 0.0)
    self._avg_accel_mag: float = 0.0
    self._max_altitude: float = 0.0
    # Bind the attribute unconditionally so it always exists (required for
    # __slots__ access); update_data() rebinds it when there is data.
    self._data_points: Sequence[EstimatedDataPacket] = data_points

    if data_points:  # actually update the data on init
        self.update_data(data_points)

def __str__(self) -> str:
return (
def update_data(self, data_points: Sequence[EstimatedDataPacket]) -> None:
    """
    Update the processed statistics from a new batch of data points: average
    acceleration, average acceleration magnitude, and maximum altitude.

    :param data_points: A sequence of EstimatedDataPacket objects to process.
        Must be non-empty; the averages are undefined for an empty batch.
    """
    self._data_points = data_points
    a_x, a_y, a_z = self._compute_averages()
    self._avg_accel = (a_x, a_y, a_z)
    # Euclidean norm of the average acceleration vector.
    self._avg_accel_mag = (self._avg_accel[0] ** 2 + self._avg_accel[1] ** 2 + self._avg_accel[2] ** 2) ** 0.5
    # Track the running maximum across all batches, not just this one, so the
    # apogee is retained even after the rocket starts descending.
    self._max_altitude = max(*(data_point.estPressureAlt for data_point in self._data_points), self._max_altitude)

def _compute_averages(self) -> tuple[float, float, float]:
    """
    Compute the mean compensated acceleration along each axis over the
    current batch of data points.

    :return: A tuple of (x, y, z) average accelerations.
    """
    # calculate the average acceleration in the x, y, and z directions
    # TODO: Test what these accel values actually look like
    x_accel = stats.fmean(data_point.estCompensatedAccelX for data_point in self._data_points)
    y_accel = stats.fmean(data_point.estCompensatedAccelY for data_point in self._data_points)
    z_accel = stats.fmean(data_point.estCompensatedAccelZ for data_point in self._data_points)
    # TODO: Calculate avg velocity if that's also available
    return x_accel, y_accel, z_accel

Expand Down
2 changes: 1 addition & 1 deletion airbrakes/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(self, log_dir: Path):
self._log_queue: multiprocessing.Queue[dict[str, str] | str] = multiprocessing.Queue()

# Start the logging process
self._log_process = multiprocessing.Process(target=self._logging_loop)
self._log_process = multiprocessing.Process(target=self._logging_loop, name="Logger")

@property
def is_running(self) -> bool:
Expand Down
8 changes: 6 additions & 2 deletions airbrakes/servo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,18 @@ class Servo:

__slots__ = ("current_extension", "max_extension", "min_extension", "servo")

def __init__(
    self,
    gpio_pin_number: int,
    min_extension: float,
    max_extension: float,
    pin_factory=None,
):
    """
    Set up the servo on the given GPIO pin.

    :param gpio_pin_number: GPIO pin the servo's signal wire is attached to.
    :param min_extension: Minimum allowed extension value for the airbrakes.
    :param max_extension: Maximum allowed extension value for the airbrakes.
    :param pin_factory: Optional gpiozero pin factory. Defaults to the PiGPIO
        factory (real hardware); tests can pass a mock factory instead.
    """
    self.min_extension = min_extension
    self.max_extension = max_extension
    self.current_extension = 0.0

    # Sets up the servo with the specified GPIO pin number
    # For this to work, you have to run the pigpio daemon on the Raspberry Pi (sudo pigpiod)
    if pin_factory is None:
        gpiozero.Device.pin_factory = gpiozero.pins.pigpio.PiGPIOFactory()
    else:
        gpiozero.Device.pin_factory = pin_factory
    self.servo = gpiozero.Servo(gpio_pin_number)

def set_extension(self, extension: float):
Expand Down
7 changes: 4 additions & 3 deletions airbrakes/state.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Module for the finite state machine that represents which state of flight we are in."""

from abc import ABC, abstractmethod
from typing import override
from typing import override, TYPE_CHECKING

from airbrakes.airbrakes import AirbrakesContext
if TYPE_CHECKING:

Check failure on line 6 in airbrakes/state.py

View workflow job for this annotation

GitHub Actions / build

Ruff (I001)

airbrakes/state.py:3:1: I001 Import block is un-sorted or un-formatted
from airbrakes.airbrakes import AirbrakesContext


class State(ABC):
Expand All @@ -22,7 +23,7 @@ class State(ABC):

__slots__ = ("context",)

def __init__(self, context: AirbrakesContext):
def __init__(self, context: "AirbrakesContext"):
"""
:param context: The state context object that will be used to interact with the electronics
"""
Expand Down
27 changes: 27 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""Module where fixtures are shared between all test files."""

import pytest
from pathlib import Path

from airbrakes.constants import FREQUENCY, PORT, UPSIDE_DOWN, SERVO_PIN, MIN_EXTENSION, MAX_EXTENSION
from airbrakes.imu.imu import IMU
from airbrakes.logger import Logger
from airbrakes.servo import Servo


from gpiozero.pins.mock import MockFactory, MockPWMPin


LOG_PATH = Path("tests/logs")

Check failure on line 15 in tests/conftest.py

View workflow job for this annotation

GitHub Actions / build

Ruff (I001)

tests/conftest.py:3:1: I001 Import block is un-sorted or un-formatted

@pytest.fixture
def logger():
    """Return a fresh Logger writing into the shared test log directory."""
    return Logger(LOG_PATH)

@pytest.fixture
def imu():
    """Return an IMU configured from the project-wide constants."""
    return IMU(port=PORT, frequency=FREQUENCY, upside_down=UPSIDE_DOWN)

@pytest.fixture
def servo():
    """Return a Servo backed by gpiozero mock pins (no hardware / pigpiod needed)."""
    return Servo(SERVO_PIN, MIN_EXTENSION, MAX_EXTENSION, pin_factory=MockFactory(pin_class=MockPWMPin))

Check failure on line 27 in tests/conftest.py

View workflow job for this annotation

GitHub Actions / build

Ruff (W292)

tests/conftest.py:27:105: W292 No newline at end of file
22 changes: 22 additions & 0 deletions tests/test_airbrakes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import pytest

from airbrakes.airbrakes import AirbrakesContext


@pytest.fixture
def airbrakes_context(imu, logger, servo):
    """Assemble an AirbrakesContext from the shared imu/logger/servo fixtures."""
    return AirbrakesContext(logger, servo, imu)


class TestAirbrakesContext:
    """Tests the AirbrakesContext class"""

    def test_slots(self, airbrakes_context):
        """Every name declared in __slots__ must be set after __init__."""
        inst = airbrakes_context
        # Use a unique sentinel object rather than the string "err": an
        # attribute whose actual value equals "err" would otherwise be
        # misreported as missing.
        sentinel = object()
        for attr in inst.__slots__:
            assert getattr(inst, attr, sentinel) is not sentinel, f"slot '{attr}' was never assigned"

    def test_init(self, airbrakes_context, logger, imu, servo):
        """The context should hold exactly the collaborators it was constructed with."""
        assert airbrakes_context.logger == logger
        assert airbrakes_context.servo == servo
        assert airbrakes_context.imu == imu

Check failure on line 22 in tests/test_airbrakes.py

View workflow job for this annotation

GitHub Actions / build

Ruff (W292)

tests/test_airbrakes.py:22:44: W292 No newline at end of file
96 changes: 96 additions & 0 deletions tests/test_data_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import math
import random

import pytest

from airbrakes.imu.data_processor import IMUDataProcessor
from airbrakes.imu.imu_data_packet import EstimatedDataPacket


def simulate_altitude_sine_wave(n_points=1000, frequency=0.01, amplitude=100, noise_level=3, base_altitude=20):
    """Generates a random distribution of altitudes that follow a sine wave pattern, with some
    noise added to simulate variations in the readings.

    :param n_points: The number of altitude points to generate. Must be >= 1.
    :param frequency: Currently unused -- the wave always spans exactly half a
        period over ``n_points`` samples. Kept for backward compatibility.
    :param amplitude: The amplitude of the sine wave.
    :param noise_level: The standard deviation of the Gaussian noise to add.
    :param base_altitude: The base altitude, i.e. starting altitude from sea level.
    :return: A list of ``n_points`` simulated altitude readings.
    """
    altitudes = []
    for i in range(n_points):
        # The sine wave roughly models the rocket's ascent and descent: half a
        # period (0 -> pi) spread across all points. Guard the n_points == 1
        # case, which would otherwise divide by zero.
        phase = math.pi * i / (n_points - 1) if n_points > 1 else 0.0
        sine_value = amplitude * math.sin(phase)
        # Add Gaussian noise to simulate sensor reading variation.
        noise = random.gauss(0, noise_level)
        # Calculate the altitude at this point
        altitude_value = base_altitude + sine_value + noise
        altitudes.append(altitude_value)
    return altitudes


@pytest.fixture
def data_processor():
    """Return an IMUDataProcessor seeded with a single known packet."""
    # A single packet with known values so tests can assert exact statistics.
    sample_data = [
        EstimatedDataPacket(
            1, estCompensatedAccelX=1, estCompensatedAccelY=2, estCompensatedAccelZ=3, estPressureAlt=20
        )
    ]
    return IMUDataProcessor(sample_data)


class TestIMUDataProcessor:
    """Tests the IMUDataProcessor class"""

    def test_slots(self, data_processor):
        """Every name declared in __slots__ must be set after __init__."""
        inst = data_processor
        # Use a unique sentinel object rather than the string "err": an
        # attribute whose actual value equals "err" would otherwise be
        # misreported as missing.
        sentinel = object()
        for attr in inst.__slots__:
            assert getattr(inst, attr, sentinel) is not sentinel, f"slot '{attr}' was never assigned"

    def test_init(self, data_processor):
        # An empty sequence must not crash and leaves stats at zero defaults.
        d = IMUDataProcessor([])
        assert d._avg_accel == (0.0, 0.0, 0.0)
        assert d._avg_accel_mag == 0.0
        assert d._max_altitude == 0.0

        # A non-empty sequence is processed immediately on construction.
        d = data_processor
        assert d._avg_accel == (1, 2, 3)
        assert d._avg_accel_mag == math.sqrt(1**2 + 2**2 + 3**2)
        assert d._max_altitude == 20

    def test_str(self, data_processor):
        assert (
            str(data_processor) == "IMUDataProcessor(avg_acceleration=(1.0, 2.0, 3.0), "
            "avg_acceleration_mag=3.7416573867739413, max_altitude=20)"
        )

    def test_update_data(self, data_processor):
        d = data_processor
        d.update_data([
            EstimatedDataPacket(
                1, estCompensatedAccelX=1, estCompensatedAccelY=2, estCompensatedAccelZ=3, estPressureAlt=20
            ),
            EstimatedDataPacket(
                2, estCompensatedAccelX=2, estCompensatedAccelY=3, estCompensatedAccelZ=4, estPressureAlt=30
            ),
        ])
        # The private fields and their public property views must agree.
        assert d._avg_accel == (1.5, 2.5, 3.5) == d.avg_acceleration
        assert d._avg_accel_mag == math.sqrt(1.5**2 + 2.5**2 + 3.5**2) == d.avg_acceleration_mag
        assert d.avg_acceleration_z == 3.5
        assert d._max_altitude == 30 == d.max_altitude

    def test_max_altitude(self, data_processor):
        """Tests whether the max altitude is correctly calculated even when altitude decreases"""
        d = data_processor
        altitudes = simulate_altitude_sine_wave(n_points=1000)
        # run update_data every 10 packets, to simulate actual data processing in real time:
        for i in range(0, len(altitudes), 10):
            d.update_data([
                EstimatedDataPacket(
                    i, estCompensatedAccelX=1, estCompensatedAccelY=2, estCompensatedAccelZ=3, estPressureAlt=alt
                )
                for alt in altitudes[i : i + 10]
            ])
        assert d.max_altitude == max(altitudes)
9 changes: 6 additions & 3 deletions tests/test_imu.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@
from airbrakes.imu.imu_data_packet import EstimatedDataPacket, IMUDataPacket, RawDataPacket


@pytest.fixture
def imu():
return IMU(port=PORT, frequency=FREQUENCY, upside_down=UPSIDE_DOWN)



RAW_DATA_PACKET_SAMPLING_RATE = 1 / 1000 # 1kHz

Check failure on line 16 in tests/test_imu.py

View workflow job for this annotation

GitHub Actions / build

Ruff (I001)

tests/test_imu.py:1:1: I001 Import block is un-sorted or un-formatted
Expand All @@ -22,6 +20,11 @@ def imu():
class TestIMU:
"""Class to test the IMU class in imu.py"""

def test_slots(self, imu):
    """Every name declared in __slots__ must be set after __init__."""
    inst = imu
    # Use a unique sentinel object rather than the string "err": an attribute
    # whose actual value equals "err" would otherwise be misreported as missing.
    sentinel = object()
    for attr in inst.__slots__:
        assert getattr(inst, attr, sentinel) is not sentinel, f"slot '{attr}' was never assigned"

def test_init(self, imu):
assert isinstance(imu._data_queue, multiprocessing.queues.Queue)
# Test that _running is a shared boolean multiprocessing.Value:
Expand Down
27 changes: 13 additions & 14 deletions tests/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,32 @@
import multiprocessing
import multiprocessing.sharedctypes
import time
from pathlib import Path

import pytest

from airbrakes.constants import CSV_HEADERS
from airbrakes.imu.imu_data_packet import EstimatedDataPacket, RawDataPacket
from airbrakes.logger import Logger


@pytest.fixture
def logger():
return Logger(TestLogger.LOG_PATH)
from tests.conftest import LOG_PATH


class TestLogger:
"""Tests the Logger() class in logger.py"""

LOG_PATH = Path("tests/logs")

@pytest.fixture(autouse=True) # autouse=True means run this function before/after every test
def clear_directory(self):
"""Clear the tests/logs directory after running each test."""
yield # This is where the test runs
# Test run is over, now clean up
for log in self.LOG_PATH.glob("log_*.csv"):
for log in LOG_PATH.glob("log_*.csv"):
log.unlink()

def test_slots(self, logger):
    """Every name declared in __slots__ must be set after __init__."""
    inst = logger
    # Use a unique sentinel object rather than the string "err": an attribute
    # whose actual value equals "err" would otherwise be misreported as missing.
    sentinel = object()
    for attr in inst.__slots__:
        assert getattr(inst, attr, sentinel) is not sentinel, f"slot '{attr}' was never assigned"

def test_init(self, logger):
# Test if "logs" directory was created
assert logger.log_path.parent.exists()
Expand All @@ -44,18 +43,18 @@ def test_init(self, logger):

def test_init_sets_log_path_correctly(self):
# assert no files exist:
assert not list(self.LOG_PATH.glob("log_*.csv"))
logger = Logger(self.LOG_PATH)
expected_log_path = self.LOG_PATH / "log_1.csv"
assert not list(LOG_PATH.glob("log_*.csv"))
logger = Logger(LOG_PATH)
expected_log_path = LOG_PATH / "log_1.csv"
assert expected_log_path.exists()
assert logger.log_path == expected_log_path
logger_2 = Logger(self.LOG_PATH)
expected_log_path_2 = self.LOG_PATH / "log_2.csv"
logger_2 = Logger(LOG_PATH)
expected_log_path_2 = LOG_PATH / "log_2.csv"
assert expected_log_path_2.exists()
assert logger_2.log_path == expected_log_path_2

# Test only 2 csv files exist:
assert set(self.LOG_PATH.glob("log_*.csv")) == {expected_log_path, expected_log_path_2}
assert set(LOG_PATH.glob("log_*.csv")) == {expected_log_path, expected_log_path_2}

def test_init_log_file_has_correct_headers(self, logger):
with logger.log_path.open() as f:
Expand Down

0 comments on commit adbc5ff

Please sign in to comment.