diff --git a/dynax/linearize.py b/dynax/linearize.py index 4fe5e83..be3641b 100644 --- a/dynax/linearize.py +++ b/dynax/linearize.py @@ -7,7 +7,7 @@ import jax.numpy as jnp import numpy as np from jaxtyping import Array - +from functools import partial from .derivative import lie_derivative from .system import ControlAffine, DynamicalSystem, LinearSystem @@ -37,6 +37,31 @@ def relative_degree(sys, xs, max_reldeg=10, output: Optional[int] = None) -> int raise RuntimeError("Could not estimate relative degree. Increase max_reldeg.") +def discrete_relative_degree(sys, xs, us, max_reldeg=10, output: Optional[int] = None): + """Estimate relative degree of discrete-time system on region xs.""" + f = sys.vector_field + h = sys.output + + def F(n, x, u): + if n == 0: + return x + elif n == 1: + return f(x, u) + return F(n-1, f(x, u), 0) + + deriv_u_hF = jax.grad(lambda k, x, u: h(F(k, x, u)), 2) + + for n in range(1, max_reldeg + 1): + res = jax.vmap(partial(deriv_u_hF, n))(xs, us) + if np.all(res == 0.0): + continue + elif np.all(res != 0.0): + return n + else: + raise RuntimeError("sys has ill defined relative degree.") + raise RuntimeError("Could not estmate relative degree. Increase max_reldeg.") + + def is_controllable(A, B) -> bool: """Test controllability of linear system.""" n = A.shape[0] diff --git a/tests/test_linearize.py b/tests/test_linearize.py index 8f19039..dc35fd5 100644 --- a/tests/test_linearize.py +++ b/tests/test_linearize.py @@ -10,7 +10,12 @@ LinearSystem, ) from dynax.example_models import NonlinearDrag, Sastry9_9 -from dynax.linearize import input_output_linearize, is_controllable, relative_degree +from dynax.linearize import ( + input_output_linearize, + is_controllable, + relative_degree, + discrete_relative_degree, +) tols = dict(rtol=1e-04, atol=1e-06) @@ -49,6 +54,21 @@ def test_relative_degree(): assert relative_degree(sys, xs) == 1 +def test_discrete_relative_degree(): + xs = np.random.normal(size=(100, 2)) + us = np.random.normal(size=(100)) + + sys = SpringMassDamperWithOutput(out=0) + assert discrete_relative_degree(sys, xs, us) == 2 + + with npt.assert_raises(RuntimeError): + discrete_relative_degree(sys, xs, us, max_reldeg=1) + + sys = SpringMassDamperWithOutput(out=1) + assert discrete_relative_degree(sys, xs, us) == 1 + + + def test_is_controllable(): n = 3 A = np.diag(np.arange(n)) @@ -130,3 +150,18 @@ def test_input_output_linearize_multiple_outputs(): y_ref = Flow(ref)(np.zeros(sys.n_states), t, u)[1] y = Flow(feedback_sys)(np.zeros(feedback_sys.n_states), t, u)[1] npt.assert_allclose(y_ref[:, out_idx], y[:, out_idx], **tols) + + +def test_discrete_input_output_linearize(): + sys = NonlinearDrag(0.1, 0.1, 0.1, 0.1) + ref = sys.linearize() + xs = np.random.normal(size=(100, sys.n_states)) + us = np.random.normal(size=100) + reldeg = discrete_relative_degree(sys, xs, us) + feedbacklaw = discrete_input_output_linearize(sys, reldeg, ref) + feedback_sys = DynamicStateFeedbackSystem(sys, ref, feedbacklaw) + t = np.linspace(0, 1, 1000) + u = np.sin(t) * 0.1 + y_ref = Map(ref)(np.zeros(sys.n_states), t, u)[1] + y = Map(feedback_sys)(np.zeros(feedback_sys.n_states), t, u)[1] + npt.assert_allclose(y_ref, y, **tols)