This repository has been archived by the owner on Mar 23, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
completed testing suite for LUT class and implemented feedback from l…
…ast sync up
- Loading branch information
1 parent
b224731
commit 664b47f
Showing
2 changed files
with
60 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,33 +1,26 @@ | ||
import numpy as np | ||
from scipy import interpolate | ||
|
||
|
||
class LUT: | ||
def __init__(self, lookup_table, interpolation_method="linear"): | ||
self.table = np.array(lookup_table) | ||
self.interpolation = interpolation_method | ||
# look at python "match" statement to replace below | ||
if self.interpolation != "linear" and self.interpolation != "spline": | ||
raise ValueError(self.interpolation + " is an unknown interpolation method!") | ||
def __init__(self, lookup_table: list, interpolation_method: str = "linear"): | ||
self.__table = np.array(lookup_table) | ||
self.__interpolation = interpolation_method | ||
|
||
# Devon suggestions | ||
# rather then if else, store a reference to the function of the interpolation you want and call | ||
# it here. You can define the reference in the constructor | ||
match self.__interpolation: | ||
case "linear": | ||
self.__method = self.__linearInterpolation | ||
case "spline": | ||
self.__method = self.__splineInterpolation | ||
case _: | ||
raise ValueError(self.__interpolation + " is an unknown interpolation method!") | ||
|
||
# make self.interpolation field private | ||
# also make the interpolation methods private | ||
def __call__(self, x: float) -> float: | ||
return self.__method(x) | ||
|
||
# add variable and return typing | ||
def __call__(self, x): | ||
if self.interpolation == "linear": | ||
return self.linearInterpolation(x) | ||
elif self.interpolation == "spline": | ||
return self.splineInterpolation(x) | ||
else: | ||
return 0 | ||
def __linearInterpolation(self, x: float) -> float: | ||
return np.interp(x, self.__table[:, 0], self.__table[:, 1]) | ||
|
||
def linearInterpolation(self, x): | ||
return np.interp(x, self.table[:, 0], self.table[:, 1]) | ||
|
||
def splineInterpolation(self, x): | ||
cs = np.interpolate.CubicSpline(self.table[:, 0], self.table[:, 1]) | ||
def __splineInterpolation(self, x: float) -> float: | ||
cs = interpolate.CubicSpline(self.__table[:, 0], self.__table[:, 1]) | ||
return cs(x) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,53 @@ | ||
# import pytest | ||
|
||
import numpy as np | ||
from scipy import interpolate | ||
|
||
from controller.common.lut import LUT | ||
|
||
|
||
class TestLUT: | ||
# Intialize lookup table | ||
look_up_table = [[50000, 5.75], [100000, 6.75], [200000, 7], [500000, 9.75], [1000000, 10]] | ||
|
||
def test_LUT_constructor(self): | ||
look_up_table = [[50000, 5.75], [100000, 6.75], [200000, 7], [500000, 9.75], [1000000, 10]] | ||
testLUT = LUT(look_up_table) | ||
assert testLUT(40000) == 5.75 | ||
# set up | ||
testLUT = LUT(self.look_up_table) | ||
|
||
# test that LUT return a known value | ||
assert testLUT(50000) == 5.75 | ||
|
||
def test_unknown_interpolation_exception(self): | ||
try: | ||
testLUT = LUT(self.look_up_table, "gabagool") | ||
assert False # failure: constructor accepted deli meat interpolation method | ||
|
||
except ValueError: | ||
assert True | ||
|
||
except: | ||
assert False # failure: constructor threw wrong exception | ||
|
||
def test_linear_interpolation(self): | ||
# set up | ||
testLUT = LUT(self.look_up_table) | ||
table = np.array(self.look_up_table) | ||
test_values = list(range(50000, 1100000, 10000)) | ||
|
||
# Test that linear interpolation does not extrapolate | ||
assert testLUT(10000) == 5.75 | ||
assert testLUT(2000000) == 10 | ||
|
||
# def test_linear_interpolation | ||
# Test that LUT returns same values as np linear interpolate function | ||
for value in test_values: | ||
assert testLUT(value) == np.interp(value, table[:, 0], table[:, 1]) | ||
|
||
# def test_spline_interpolation | ||
def test_spline_interpolation(self): | ||
testLUT = LUT(self.look_up_table, "spline") | ||
table = np.array(self.look_up_table) | ||
test_values = list(range(10000, 2100000, 10000)) | ||
cs = interpolate.CubicSpline(table[:, 0], table[:, 1]) | ||
|
||
# test_LUT_constructor | ||
# Test that LUT returns same values as cubic interpolate function | ||
for value in test_values: | ||
assert testLUT(value) == cs(value) |