Skip to content
This repository has been archived by the owner on Mar 23, 2024. It is now read-only.

20 lookup table class #25

Merged
merged 13 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions controller/common/lut.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from typing import List

import numpy as np
from numpy.typing import NDArray
from scipy import interpolate

from controller.common.types import Scalar


class LUT:
"""
Class for performing look-up table interpolation.

Methods:
__init__: Initializes the LUT object with lookup table data and interpolation method.
__call__: Calls the interpolation method with the given input.
"""

def __init__(
self,
lookup_table: List[List[Scalar]] | NDArray,
interpolation_method: str = "linear",
):
"""
Initializes the LUT object.

Args:
lookup_table (List[List[Scalar]] or NDArray): A list of lists or NDArray containing x-y
data points for interpolation. Shape should be (n, 2)
interpolation_method (str): Interpolation method to use. Default is "linear".

Raises:
ValueError: If the specified interpolation method is unknown
or if the table shape is incorrect.
"""
if isinstance(lookup_table, np.ndarray):
table = lookup_table
else:
table = np.array(lookup_table)

self.__verifyTable(table)
self.x = table[:, 0]
self.y = table[:, 1]

self.__interpolation_method = interpolation_method

match self.__interpolation_method:
case "linear":
self.__interpolation_function = self.__linearInterpolation
case "spline":
self.__interpolation_function = self.__splineInterpolation
case _:
raise ValueError(
self.__interpolation_method + " is an unknown interpolation method!"
)

def __call__(self, x: float) -> float:
"""
Calls the interpolation method with the given input.

Args:
x (float): The input value to interpolate.

Returns:
float: The interpolated value using the interpolation method defined when LUT instance
creation.
"""
return self.__interpolation_function(x)

def __linearInterpolation(self, x: Scalar) -> float:
output = np.interp(x, self.x, self.y)
if isinstance(output, np.ndarray):
raise ValueError(
"linear interpolation returned a NDArray when it should have returned a float"
)
return output

def __splineInterpolation(self, x: Scalar) -> float:
cs = interpolate.CubicSpline(self.x, self.y)
return cs(x)

def __verifyTable(self, table: NDArray) -> None:
if (len(table.shape) != 2) or table.shape[1] != 2:
raise ValueError(f"Input table has shape {table.shape}, but expected shape of (n, 2)")
Empty file added test.py
Empty file.
88 changes: 88 additions & 0 deletions tests/unit/wingsail/common/test_lut.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import math

import numpy as np
import pytest
from scipy import interpolate

from controller.common.lut import LUT


class TestLUT:
# Intialize lookup table
look_up_table = [[50000, 5.75], [100000, 6.75], [200000, 7], [500000, 9.75], [1000000, 10]]
look_up_np_array = np.array(look_up_table)

def test_LUT_constructor(self):
alberto-escobar marked this conversation as resolved.
Show resolved Hide resolved
# set up
testLUT = LUT(self.look_up_table)
testLUT2 = LUT(self.look_up_np_array)
# test that LUT return a known value
assert math.isclose(testLUT(50000), 5.75)
assert math.isclose(testLUT2(50000), 5.75)

def test_unknown_interpolation_exception(self):
with pytest.raises(ValueError):
testLUT = LUT(self.look_up_table, "gabagool")
assert math.isclose(testLUT(50000), 5.75)

@pytest.mark.parametrize(
"invalid_table",
[
[[10000, 10000, 10000], [1, 1, 1]],
[10000, 10000, 10000],
[[0, 1], 10000, 10000],
np.array([[10000, 10000, 10000], [1, 1, 1]]),
np.array([10000, 10000, 10000]),
np.array([[[0, 1]], [[0, 1]], [[0, 1]]]),
],
)
def test_invalid_table_exception(self, invalid_table):
with pytest.raises(ValueError):
testLUT = LUT(invalid_table)
assert math.isclose(testLUT(50000), 5.75)

@pytest.mark.parametrize(
"invalid_table",
[
np.array([[10000, 10000, 10000], [1, 1, 1]]),
np.array([10000, 10000, 10000]),
np.array([[[0, 1]], [[0, 1]], [[0, 1]]]),
],
)
def test_invalid_numpy_array_exception(self, invalid_table):
with pytest.raises(ValueError):
testLUT = LUT(invalid_table)
assert math.isclose(testLUT(50000), 5.75)

@pytest.mark.parametrize("linear_test_values", list(range(50000, 1100000, 10000)))
alberto-escobar marked this conversation as resolved.
Show resolved Hide resolved
def test_linear_interpolation(self, linear_test_values):
# set up
testLUT = LUT(self.look_up_table)
table = np.array(self.look_up_table)

# Test that LUT returns same values as np linear interpolate function
assert math.isclose(
testLUT(linear_test_values), np.interp(linear_test_values, table[:, 0], table[:, 1])
)

@pytest.mark.parametrize("test_value, expected_value", [(1000, 5.75), (2000000, 10)])
def test_linear_extrapolation(self, test_value, expected_value):
testLUT = LUT(self.look_up_table)
# Test that linear interpolation does not extrapolate
assert math.isclose(testLUT(test_value), expected_value)

@pytest.mark.parametrize("test_value", [[100, 200, 300]])
def test_linear_interpolation_exception(self, test_value):
testLUT = LUT(self.look_up_table)
# Test that linear interpolation does not extrapolate
with pytest.raises(ValueError):
testLUT(test_value)

@pytest.mark.parametrize("spline_test_values", list(range(10000, 2100000, 10000)))
alberto-escobar marked this conversation as resolved.
Show resolved Hide resolved
def test_spline_interpolation_extrapolation(self, spline_test_values):
testLUT = LUT(self.look_up_table, "spline")
table = np.array(self.look_up_table)
cs = interpolate.CubicSpline(table[:, 0], table[:, 1])

# Test that LUT returns same values as cubic interpolate function
assert math.isclose(testLUT(spline_test_values), cs(spline_test_values))