Skip to content

Commit

Permalink
Simplify unit tests to make it easier to read
Browse files Browse the repository at this point in the history
  • Loading branch information
robertapplin committed Apr 7, 2024
1 parent c77b6aa commit c5d4d44
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 42 deletions.
20 changes: 7 additions & 13 deletions quasielasticbayes/test/qldata_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,10 @@
import numpy as np
import tempfile

from quasielasticbayes.testing import load_json, add_path
from quasielasticbayes.testing import add_path, load_json, DATA_DIR, RELATIVE_TOLERANCE_FIT, RELATIVE_TOLERANCE_PROB
from quasielasticbayes.QLdata import qldata


DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')


class QLdataTest(unittest.TestCase):
"""
Characterization tests using inputs that have been accepted as correct.
Expand All @@ -19,13 +16,11 @@ class QLdataTest(unittest.TestCase):
"""

def test_qlres_minimal_input(self):
# reference inputs
fin = 'qldata_input.json'
with open(os.path.join(DATA_DIR, 'qldata', fin), 'r') as fh:
inputs = load_json(fh)
with tempfile.TemporaryDirectory() as tmp_dir:
inputs = load_json("qldata", "qldata_input.json")

with tempfile.TemporaryDirectory() as tmp_dir:
inputs['wrks'] = add_path(tmp_dir, inputs['wrks'])

nd, xout, yout, eout, yfit, yprob = qldata(inputs['numb'],
inputs['Xv'],
inputs['Yv'],
Expand All @@ -42,16 +37,15 @@ def test_qlres_minimal_input(self):
inputs['wrkr'],
inputs['lwrk'])
# verify
cf = 'qldata_output.json'
with open(os.path.join(DATA_DIR, 'qldata', cf), 'r') as fh:
with open(os.path.join(DATA_DIR, 'qldata', 'qldata_output.json'), 'r') as fh:
reference = load_json(fh)

self.assertEqual(reference['nd'], nd)
np.testing.assert_allclose(reference['xout'], xout)
np.testing.assert_allclose(reference['yout'], yout)
np.testing.assert_allclose(reference['eout'], eout)
np.testing.assert_allclose(reference['yfit'], yfit, rtol=1e-3)
np.testing.assert_allclose(reference['yprob'], yprob, rtol=1e-3)
np.testing.assert_allclose(reference['yfit'], yfit, rtol=RELATIVE_TOLERANCE_FIT)
np.testing.assert_allclose(reference['yprob'], yprob, rtol=RELATIVE_TOLERANCE_PROB)


if __name__ == '__main__':
Expand Down
12 changes: 4 additions & 8 deletions quasielasticbayes/test/qlres_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@
import numpy as np
import sys

from quasielasticbayes.testing import add_path, load_json
from quasielasticbayes.testing import add_path, load_json, DATA_DIR, RELATIVE_TOLERANCE_FIT, RELATIVE_TOLERANCE_PROB
from quasielasticbayes.QLres import qlres

DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')


class QLresTest(unittest.TestCase):
"""
Expand All @@ -20,9 +18,7 @@ class QLresTest(unittest.TestCase):

@unittest.skipIf(sys.platform == "darwin", "Reading the json reference file causes an unexplained crash.")
def test_qlres_minimal_input(self):
# reference inputs
with open(os.path.join(DATA_DIR, 'qlres', 'qlres-input-spec-0.json'), 'r') as fh:
inputs = load_json(fh)
inputs = load_json("qlres", "qlres-input-spec-0.json")

with tempfile.TemporaryDirectory() as tmp_dir:
inputs['wrks'] = add_path(tmp_dir, inputs['wrks'])
Expand All @@ -41,8 +37,8 @@ def test_qlres_minimal_input(self):
np.testing.assert_allclose(reference['xout'], xout)
np.testing.assert_allclose(reference['yout'], yout)
np.testing.assert_allclose(reference['eout'], eout)
np.testing.assert_allclose(reference['yfit'], yfit, rtol=1e-3)
np.testing.assert_allclose(reference['yprob'], yprob, rtol=1e-2)
np.testing.assert_allclose(reference['yfit'], yfit, rtol=RELATIVE_TOLERANCE_FIT)
np.testing.assert_allclose(reference['yprob'], yprob, rtol=RELATIVE_TOLERANCE_PROB)

if __name__ == '__main__':
unittest.main()
17 changes: 6 additions & 11 deletions quasielasticbayes/test/qlse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,10 @@
import numpy as np
import tempfile

from quasielasticbayes.testing import load_json, add_path
from quasielasticbayes.testing import add_path, load_json, DATA_DIR, RELATIVE_TOLERANCE_FIT, RELATIVE_TOLERANCE_PROB
from quasielasticbayes.QLres import qlres


DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')


class QLresTest(unittest.TestCase):
"""
Characterization tests using inputs that have been accepted as correct.
Expand All @@ -19,13 +16,11 @@ class QLresTest(unittest.TestCase):
"""

def test_qlres_minimal_input(self):
# reference inputs
fin = 'qlse_input.json'
with open(os.path.join(DATA_DIR, 'qlse', fin), 'r') as fh:
inputs = load_json(fh)
with tempfile.TemporaryDirectory() as tmp_dir:
inputs = load_json("qlse", "qlse_input.json")

with tempfile.TemporaryDirectory() as tmp_dir:
inputs['wrks'] = add_path(tmp_dir, inputs['wrks'])

nd, xout, yout, eout, yfit, yprob = qlres(inputs['numb'],
inputs['Xv'],
inputs['Yv'],
Expand All @@ -51,8 +46,8 @@ def test_qlres_minimal_input(self):
np.testing.assert_allclose(reference['xout'], xout)
np.testing.assert_allclose(reference['yout'], yout)
np.testing.assert_allclose(reference['eout'], eout)
np.testing.assert_allclose(reference['yfit'], yfit, rtol=1e-3)
np.testing.assert_allclose(reference['yprob'], yprob, rtol=1e-2)
np.testing.assert_allclose(reference['yfit'], yfit, rtol=RELATIVE_TOLERANCE_FIT)
np.testing.assert_allclose(reference['yprob'], yprob, rtol=RELATIVE_TOLERANCE_PROB)


if __name__ == '__main__':
Expand Down
37 changes: 37 additions & 0 deletions quasielasticbayes/test/testing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Utility code to help with testing
"""
import base64
import json
import numpy as np
import os

RELATIVE_TOLERANCE_FIT=1e-3
RELATIVE_TOLERANCE_PROB=1e-2


def _json_numpy_obj_hook(dct):
"""
Decodes a previously encoded numpy ndarray
with proper shape and dtype
:param dct: (dict) json encoded ndarray
:return: (ndarray) if input was an encoded ndarray
"""
if isinstance(dct, dict) and '__ndarray__' in dct:
data = base64.b64decode(dct['__ndarray__'])
return np.frombuffer(data, dct['dtype']).reshape(dct['shape'])
return dct


def load_json(*args, **kwargs):
"""Loads a json-encoded file-like object to a dictionary.
Adds supports for decoding numpy arrays
See json.load.
"""
kwargs.setdefault('object_hook', _json_numpy_obj_hook)
return json.load(*args, **kwargs)


def add_path(file_path, file_name):
"""Sets the path for a file
"""
return os.path.join(file_path, file_name)
18 changes: 8 additions & 10 deletions quasielasticbayes/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
import numpy as np
import os

DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')

RELATIVE_TOLERANCE_FIT=1e-3
RELATIVE_TOLERANCE_PROB=1e-2


def _json_numpy_obj_hook(dct):
"""
Expand All @@ -19,23 +24,16 @@ def _json_numpy_obj_hook(dct):
return dct


def load_json(*args, **kwargs):
def load_json(sub_directory: str, filename: str):
"""Loads a json-encoded file-like object to a dictionary.
Adds supports for decoding numpy arrays
See json.load.
"""
kwargs.setdefault('object_hook', _json_numpy_obj_hook)
return json.load(*args, **kwargs)
with open(os.path.join(DATA_DIR, sub_directory, filename), 'r') as file:
return json.load(file, object_hook=_json_numpy_obj_hook)


def add_path(file_path, file_name):
"""Sets the path for a file
"""
return os.path.join(file_path, file_name)


# def get_qlse_prob(ref):
# if sys.platform == 'win32':
# return ref
# else:
# return [-2.7656994e+04, -1.8887866e+2, 0.0, -5.8251953e-1]

0 comments on commit c5d4d44

Please sign in to comment.