Skip to content

Commit

Permalink
added some road signs
Browse files Browse the repository at this point in the history
  • Loading branch information
pavelkomarov committed Jan 18, 2025
1 parent 6766e52 commit ae65554
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from derivative.differentiation import _gen_method


# Utilities for tests
# ===================
def default_args(kind):
""" The assumption is that the function will have dt = 1/100 over a range of 1 and not vary much. The goal is to
to set the parameters such that we obtain effective derivatives under these conditions.
Expand All @@ -26,8 +28,7 @@ def default_args(kind):
return {"sigma": 1, "lmbd": .01, "kernel": "gaussian"}
else:
raise ValueError('Unimplemented default args for kind {}.'.format(kind))



class NumericalExperiment:
def __init__(self, fn, fn_str, t, kind, args):
self.fn = fn
Expand All @@ -40,7 +41,6 @@ def __init__(self, fn, fn_str, t, kind, args):
def run(self):
return dxdt(self.fn(self.t), self.t, self.kind, self.axis, **self.kwargs)


def compare(experiment, truth, rel_tol, abs_tol, shape_only=False):
""" Compare a numerical experiment to theoretical expectations. Issue warnings for derivative methods that fail,
use asserts for implementation requirements.
Expand All @@ -60,8 +60,8 @@ def mean_sq(x):
assert np.linalg.norm(residual, ord=np.inf) < max(abs_tol, np.linalg.norm(truth, ord=np.inf) * rel_tol)


# Check that numbers are returned
# ===============================
# Check that only numbers are returned
# ====================================
@pytest.mark.parametrize("m", methods)
def test_notnan(m):
t = np.linspace(0, 1, 100)
Expand All @@ -71,8 +71,8 @@ def test_notnan(m):
assert not np.any(np.isnan(values)), message


# Test some basic functions
# =========================
# Test that basic functions are differentiated correctly
# ======================================================
funcs_and_derivs = (
(lambda t: np.ones_like(t), "f(t) = 1", lambda t: np.zeros_like(t), "const1"),
(lambda t: np.zeros_like(t), "f(t) = 0", lambda t: np.zeros_like(t), "const0"),
Expand Down Expand Up @@ -112,6 +112,8 @@ def test_fn(m, func_spec):
compare(nexp, deriv(t), 1e-1, 1e-1, bad_combo)


# Test smoothing for those that do it
# ===================================
@pytest.mark.parametrize("kind", ("kalman", "trend_filtered"))
def test_smoothing_x(kind):
t = np.linspace(0, 1, 100)
Expand All @@ -122,7 +124,6 @@ def test_smoothing_x(kind):
# MSE
assert np.linalg.norm(x_est - np.sin(t)) ** 2 / len(t) < 1e-1


@pytest.mark.parametrize("kind", ("kalman", "trend_filtered"))
def test_smoothing_functional(kind):
t = np.linspace(0, 1, 100)
Expand All @@ -133,13 +134,14 @@ def test_smoothing_functional(kind):
assert np.linalg.norm(x_est - np.sin(t)) ** 2 / len(t) < 1e-1


# Test caching of the expensive _gen_method using a dummy
# =======================================================
@pytest.fixture
def clean_gen_method_cache():
_gen_method.cache_clear()
yield
_gen_method.cache_clear()


def test_gen_method_caching(clean_gen_method_cache):
x = np.ones(3)
t = np.arange(3)
Expand All @@ -150,7 +152,6 @@ def test_gen_method_caching(clean_gen_method_cache):
assert _gen_method.cache_info().currsize == 1
assert id(expected) == id(result)


def test_gen_method_kwarg_caching(clean_gen_method_cache):
x = np.ones(3)
t = np.arange(3)
Expand All @@ -164,6 +165,8 @@ def test_gen_method_kwarg_caching(clean_gen_method_cache):
assert id(expected) != id(result)


# Test caching of the expensive private _global methods using a dummy
# ===================================================================
@pytest.fixture
def method_inst(request):
x = np.ones(3)
Expand All @@ -173,7 +176,6 @@ def method_inst(request):
yield x, t, method
method._global.cache_clear()


@pytest.mark.parametrize("method_inst", ["kalman", "trend_filtered", "spectral"], indirect=True)
def test_dglobal_caching(method_inst):
# make sure we're not recomputing expensive _global() method
Expand All @@ -184,8 +186,7 @@ def test_dglobal_caching(method_inst):
assert method._global.cache_info().misses == 1
assert method._global.cache_info().currsize == 1


@pytest.mark.parametrize("method_inst", ["kalman", "trend_filtered"], indirect=True)
@pytest.mark.parametrize("method_inst", ["kalman", "trend_filtered", "spectral"], indirect=True)
def test_cached_global_order(method_inst):
x, t, method = method_inst
x = np.vstack((x, -x))
Expand Down

0 comments on commit ae65554

Please sign in to comment.