From cd6c6b8aa3c8c5312d7547322148d0671aee2436 Mon Sep 17 00:00:00 2001 From: fhchl Date: Tue, 30 Jan 2024 15:02:23 +0100 Subject: [PATCH] ruff and black --- dynax/linearize.py | 14 ++++++++------ dynax/util.py | 2 +- examples/linearize_discrete_time.py | 15 ++++++--------- tests/test_linearize.py | 6 +++--- tests/test_util.py | 1 - 5 files changed, 18 insertions(+), 20 deletions(-) diff --git a/dynax/linearize.py b/dynax/linearize.py index c789208..fa39a8b 100644 --- a/dynax/linearize.py +++ b/dynax/linearize.py @@ -170,11 +170,10 @@ def discrete_input_output_linearize( reldeg: int, ref: DynamicalSystem, output: Optional[int] = None, - solver: Optional[optx.AbstractRootFinder] = None, + solver: Optional[optx.AbstractRootFinder] = None, ) -> Callable[[Array, Array, float, float], float]: - """Construct the input-output linearizing feedback law for a discrete-time system. - """ - + """Construct the input-output linearizing feedback for a discrete-time system.""" + # Lee 2022, Chap. 7.4 f = lambda x, u: sys.vector_field(x, u) h = sys.output @@ -186,7 +185,7 @@ def discrete_input_output_linearize( _output = lambda x: x else: _output = lambda x: x[output] - + if solver is None: solver = optx.Newton(rtol=1e-6, atol=1e-6) @@ -202,7 +201,10 @@ def y_reldeg_ref(z, v): def feedbacklaw(x: Array, z: Array, v: float, u_prev: float): def fn(u, args): - return (_output(h(propagate(f, reldeg, x, u))) - y_reldeg_ref(z, v)).squeeze() + return ( + _output(h(propagate(f, reldeg, x, u))) - y_reldeg_ref(z, v) + ).squeeze() + u = optx.root_find(fn, solver, u_prev).value return u diff --git a/dynax/util.py b/dynax/util.py index 79d54a1..3aaca7b 100644 --- a/dynax/util.py +++ b/dynax/util.py @@ -1,10 +1,10 @@ import functools +from typing import Literal import equinox import jax import jax.numpy as jnp from jaxtyping import Array, ArrayLike -from typing import Literal def ssmatrix(data: ArrayLike, axis: int = 0) -> Array: diff --git a/examples/linearize_discrete_time.py b/examples/linearize_discrete_time.py index 844d2c6..c8cef09 100644 --- a/examples/linearize_discrete_time.py +++ b/examples/linearize_discrete_time.py @@ -1,17 +1,16 @@ +import jax.numpy as jnp import matplotlib.pyplot as plt +import numpy as np +from equinox.nn import GRUCell +from jax.random import PRNGKey from dynax import ( - DynamicalSystem, - Map, discrete_relative_degree, DiscreteLinearizingSystem, + DynamicalSystem, LinearSystem, + Map, ) -from equinox.nn import GRUCell -import jax -import jax.numpy as jnp -import numpy as np -from jax.random import PRNGKey # A nonlinear discrete-time system. @@ -89,5 +88,3 @@ def output(self, x, u=None, t=None): plt.plot(linearizing_inputs, label="linearizing input") plt.legend() plt.show() - - diff --git a/tests/test_linearize.py b/tests/test_linearize.py index ea4b9f6..feb3dec 100644 --- a/tests/test_linearize.py +++ b/tests/test_linearize.py @@ -94,7 +94,7 @@ def test_linearize_dyn2lin(): class ScalarScalar(DynamicalSystem): n_states = "scalar" n_inputs = "scalar" - + def vector_field(self, x, u, t): return -1 * x + 2 * u @@ -164,7 +164,7 @@ def vector_field(self, x, u, t=None): def output(self, x, u=None, t=None): return x[0] - + def test_discrete_input_output_linearize(): sys = Lee7_4_5() refsys = sys.linearize() @@ -176,7 +176,7 @@ def test_discrete_input_output_linearize(): feedback_sys = DiscreteLinearizingSystem(sys, refsys, reldeg) t = np.linspace(0, 0.001, 10) u = np.cos(t) * 0.1 - _, v = Map(feedback_sys)(np.zeros(2+2+1), t, u) + _, v = Map(feedback_sys)(np.zeros(2 + 2 + 1), t, u) _, y = Map(sys)(np.zeros(2), t, u) _, y_ref = Map(refsys)(np.zeros(2), t, u) diff --git a/tests/test_util.py b/tests/test_util.py index 9f30e45..01ccbba 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,4 +1,3 @@ -import jax import jax.numpy as jnp import numpy.testing as npt