From bf0284928d6ab3daef5a78abdf11b9ae03ab1fc8 Mon Sep 17 00:00:00 2001 From: fhchl Date: Tue, 30 Jan 2024 15:26:14 +0100 Subject: [PATCH] feat: discrete time linearization (#24) --- .github/workflows/run_tests.yml | 6 +- README.md | 5 +- docs/api.rst | 3 +- docs/examples.rst | 4 +- dynax/__init__.py | 3 + dynax/estimation.py | 36 ----- dynax/evolution.py | 4 +- dynax/linearize.py | 153 ++++++++++++++++-- dynax/system.py | 4 +- dynax/util.py | 2 +- .../fit_multiple_shooting_second_order_sys.py | 2 +- examples/fit_ode.ipynb | 13 +- examples/fit_second_order_sys.py | 2 +- examples/linearize_discrete_time.py | 90 +++++++++++ tests/test_linearize.py | 61 ++++++- tests/test_systems.py | 2 - tests/test_util.py | 1 - 17 files changed, 309 insertions(+), 82 deletions(-) create mode 100644 examples/linearize_discrete_time.py diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index fa36c1b..9668105 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -1,12 +1,12 @@ name: Tests -on: [pull_request, push] +on: [pull_request] jobs: run-tests: strategy: matrix: - python-version: [ "3.9", "3.10", "3.11" ] + python-version: [ "3.10", "3.11", "3.12" ] os: [ ubuntu-latest ] fail-fast: false runs-on: ${{ matrix.os }} @@ -24,6 +24,6 @@ jobs: - name: Test with pytest run: | - python -m pip install jaxlib==0.4.16 + python -m pip install jaxlib==0.4.23 python -m pip install .[dev] python -m pytest --runslow --durations=0 diff --git a/README.md b/README.md index 15dd2c8..30b2435 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,8 @@ include: - estimation of linear ODE parameters via matching of frequency-response functions - estimation from multiple experiments - estimation with a poor man's multiple shooting -- input-output linearization of continuous- and discrete-time input-affine systems with well-defined relative degree +- input-output linearization of continuous-time input affine systems +- input-output linearization of discrete-time systems [example](examples/linearize_discrete_time) - estimation of a system's relative-degree Documentation is on its way. Until then, have a look at the [example](examples) and [test](tests) folders. @@ -22,7 +23,7 @@ Documentation is on its way. Until then, have a look at the [example](examples) ## Installing -Requires Python 3.9+, JAX 0.4.13+, Equinox 0.10.10+ and Diffrax 0.4.0+. With a +Requires Python 3.9+, JAX 0.4.23+, Equinox 0.11+ and Diffrax 0.5+. With a suitable version of jaxlib installed: pip install . diff --git a/docs/api.rst b/docs/api.rst index 11fd959..fbcbfb3 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -2,6 +2,7 @@ API documentation ================= .. autosummary:: + :toctree: generated :recursive: @@ -10,4 +11,4 @@ API documentation dynax.linearize dynax.derivative dynax.interpolation - dynax.util \ No newline at end of file + dynax.util diff --git a/docs/examples.rst b/docs/examples.rst index 734503c..1470641 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -2,9 +2,7 @@ Examples ======== .. toctree:: - - index - ../examples/fit_ode.ipynb + :maxdepth: 2 Have a look at the notebooks on the left or the following scripts. diff --git a/dynax/__init__.py b/dynax/__init__.py index 2266dbc..641d3cb 100644 --- a/dynax/__init__.py +++ b/dynax/__init__.py @@ -12,6 +12,9 @@ from .evolution import AbstractEvolution as AbstractEvolution, Flow as Flow, Map as Map from .interpolation import spline_it as spline_it from .linearize import ( + discrete_input_output_linearize as discrete_input_output_linearize, + discrete_relative_degree as discrete_relative_degree, + DiscreteLinearizingSystem as DiscreteLinearizingSystem, input_output_linearize as input_output_linearize, LinearizingSystem as LinearizingSystem, relative_degree as relative_degree, diff --git a/dynax/estimation.py b/dynax/estimation.py index 2f89766..b98818f 100644 --- a/dynax/estimation.py +++ b/dynax/estimation.py @@ -164,41 +164,6 @@ def fit_least_squares( Parameters can be constrained via the `*_field` functions. - Args: - model: Flow instance holding initial parameter estimates - t: Times at which `y` is given - y: Target outputs of system - x0: Initial state - u: Pptional system input - batched: If True, interpret `t`, `y`, `x0`, `u` as holding multiple - experiments stacked along the first axis. - sigma: A 1-D sequence with values of the standard deviation of the measurement - error for each output of `model.system`. If None, `sigma` will be set to - the rms values of each measurement in `y`, which makes the cost - scale-invariant to magnitude differences between measurements. - absolute_sigma: If True, `sigma` is used in an absolute sense and the estimated - parameter covariance `pcov` reflects these absolute values. If False - (default), only the relative magnitudes of the `sigma` values matter and - `sigma` is scaled to match the sample variance of the residuals after the - fit. - reg_val: Weight of the l2 penalty term. - reg_bias: If "initial", bias the parameter estimates towards the values in - `model`. - verbose_mse: Scale cost to mean-squared-error for easier interpretation. - kwargs: Optional parameters for `scipy.optimize.least_squares`. - - Returns: - `OptimizeResult` as returned by `scipy.optimize.least_squares` with the - following additional attributes defined: - - result: `model` with estimated parameters. - cov: Covariance matrix of the parameter estimate. - y_pred: Model prediction at optimum. - key_paths: List of key_paths that index the corresponding entries in `cov`, - `jac`, and `x`. - mse: Mean-squared-error. - nmse: Normalized mean-squared-error. - nrmse: Normalized root-mean-squared-error. """ t = jnp.asarray(t) @@ -463,7 +428,6 @@ def residuals(params): sys = unravel(params) H = transfer_function(sys) Gyu_pred = jax.vmap(H)(s) - # FIXME: there are some bugs here, run pytest... Syu_pred = Gyu_pred * broadcast_right(Suu, Gyu_pred) r = (Syu - Syu_pred) * weight r = jnp.concatenate((jnp.real(r), jnp.imag(r))) diff --git a/dynax/evolution.py b/dynax/evolution.py index f823b19..31d61ee 100644 --- a/dynax/evolution.py +++ b/dynax/evolution.py @@ -4,8 +4,6 @@ import equinox as eqx import jax import jax.numpy as jnp -import numpy as np -from jax._src.config import _validate_default_device from jaxtyping import Array, ArrayLike, PyTree from .interpolation import spline_it @@ -149,6 +147,6 @@ def scan_fun(state, input): _, x = jax.lax.scan(scan_fun, x0, inputs, length=num_steps) # Compute output - y = jax.vmap(self.system.output)(x) + y = jax.vmap(self.system.output)(x, u, t) return x, y diff --git a/dynax/linearize.py b/dynax/linearize.py index 1397780..fa39a8b 100644 --- a/dynax/linearize.py +++ b/dynax/linearize.py @@ -1,11 +1,13 @@ """Functions related to feedback linearization of nonlinear systems.""" from collections.abc import Callable +from functools import partial from typing import Optional, Sequence import jax import jax.numpy as jnp import numpy as np +import optimistix as optx from jaxtyping import Array from .derivative import lie_derivative @@ -53,7 +55,7 @@ def input_output_linearize( ref: LinearSystem, output: Optional[int] = None, asymptotic: Optional[Sequence] = None, - reg: Optional[float] = None + reg: Optional[float] = None, ) -> Callable[[Array, Array, float], float]: """Construct input-output linearizing feedback law. @@ -64,7 +66,7 @@ def input_output_linearize( output: specify linearizing output if systems have multiple outputs asymptotic: If `None`, compute the exactly linearizing law. Otherwise, a sequence of length `reldeg` defining the tracking behaviour. - reg: parameter that control the linearization effort. Only effective if + reg: parameter that control the linearization effort. Only effective if asymptotic is not None. Note: @@ -117,7 +119,7 @@ def feedbacklaw(x: Array, z: Array, v: float) -> float: for ai, Lfih, cAi in zip(alphas, Lfihs, cAis) ] ) - error = (y_reldeg_ref - y_reldeg + jnp.sum(ae0s)) + error = y_reldeg_ref - y_reldeg + jnp.sum(ae0s) if reg is None: return error / LgLfnm1h(x) else: @@ -127,6 +129,120 @@ def feedbacklaw(x: Array, z: Array, v: float) -> float: return feedbacklaw +def propagate(f: Callable[[Array, float], Array], n: int, x: Array, u: float) -> Array: + """Propagates system n steps.""" + # TODO: replace by lax.scan + if n == 0: + return x + return propagate(f, n - 1, f(x, u), u) + + +def discrete_relative_degree( + sys: DynamicalSystem, + xs: Array, + us: Array, + max_reldeg=10, + output: Optional[int] = None, +): + """Estimate relative degree of discrete-time system on region xs. + + Source: Lee, Linearization of Nonlinear Control Systems (2022), Def. 7.7 + + """ + f = sys.vector_field + h = sys.output + + y_depends_u = jax.grad(lambda n, x, u: h(propagate(f, n, x, u)), 2) + + for n in range(1, max_reldeg + 1): + res = jax.vmap(partial(y_depends_u, n))(xs, us) + if np.all(res == 0): + continue + elif np.all(res != 0): + return n + else: + raise RuntimeError("sys has ill defined relative degree.") + raise RuntimeError("Could not estmate relative degree. Increase max_reldeg.") + + +def discrete_input_output_linearize( + sys: DynamicalSystem, + reldeg: int, + ref: DynamicalSystem, + output: Optional[int] = None, + solver: Optional[optx.AbstractRootFinder] = None, +) -> Callable[[Array, Array, float, float], float]: + """Construct the input-output linearizing feedback for a discrete-time system.""" + + # Lee 2022, Chap. 7.4 + f = lambda x, u: sys.vector_field(x, u) + h = sys.output + if sys.n_inputs != ref.n_inputs != 1: + raise ValueError("Systems must have single input.") + if output is None: + if not (sys.n_outputs == ref.n_outputs and sys.n_outputs in ["scalar", 1]): + raise ValueError("Systems must be single output and `output` is None.") + _output = lambda x: x + else: + _output = lambda x: x[output] + + if solver is None: + solver = optx.Newton(rtol=1e-6, atol=1e-6) + + def y_reldeg_ref(z, v): + if isinstance(ref, LinearSystem): + # A little faster for the linear case (if this is not optimized by jit) + A, b, c = ref.A, ref.B, ref.C + A_reldeg = c.dot(np.linalg.matrix_power(A, reldeg)) + B_reldeg = c.dot(np.linalg.matrix_power(A, reldeg - 1)).dot(b) + return _output(A_reldeg.dot(z) + B_reldeg.dot(v)) + else: + _output(ref.output(propagate(ref.vector_field, reldeg, z, v))) + + def feedbacklaw(x: Array, z: Array, v: float, u_prev: float): + def fn(u, args): + return ( + _output(h(propagate(f, reldeg, x, u))) - y_reldeg_ref(z, v) + ).squeeze() + + u = optx.root_find(fn, solver, u_prev).value + return u + + return feedbacklaw + + +class DiscreteLinearizingSystem(DynamicalSystem): + r"""Dynamics computing linearizing feedback as output.""" + + sys: ControlAffine + refsys: LinearSystem + feedbacklaw: Callable + + n_inputs = "scalar" + + def __init__(self, sys, refsys, reldeg, linearizing_output=None): + if sys.n_inputs != "scalar": + raise ValueError("Only single input systems supported.") + self.sys = sys + self.refsys = refsys + self.n_states = self.sys.n_states + self.refsys.n_states + 1 + self.feedbacklaw = discrete_input_output_linearize( + sys, reldeg, refsys, linearizing_output + ) + + def vector_field(self, x, u, t=None): + x, z, v_last = x[: self.sys.n_states], x[self.sys.n_states : -1], x[-1] + v = self.feedbacklaw(x, z, u, v_last) + xn = self.sys.vector_field(x, v) + zn = self.refsys.vector_field(z, u) + return jnp.concatenate((xn, zn, jnp.array([v]))) + + def output(self, x, u, t=None): + x, z, v_last = x[: self.sys.n_states], x[self.sys.n_states : -1], x[-1] + v = self.feedbacklaw(x, z, u, v_last) # FIXME: feedback law called twice + return v + + class LinearizingSystem(DynamicalSystem): r"""Coupled ODE of nonlinear dynamics, linear reference and io linearizing law. @@ -145,33 +261,38 @@ class LinearizingSystem(DynamicalSystem): sys: ControlAffine refsys: LinearSystem - feedbacklaw: Optional[Callable] = None - - def __init__(self, sys, refsys, reldeg, feedbacklaw=None, linearizing_output=None): - if sys.n_inputs > 1: - raise ValueError("Only single input systems supported.") + feedbacklaw: Callable[[Array, Array, float], float] + + n_inputs = "scalar" + + def __init__( + self, + sys: ControlAffine, + refsys: LinearSystem, + reldeg: int, + feedbacklaw: Optional[Callable] = None, + linearizing_output: Optional[int] = None, + ): self.sys = sys self.refsys = refsys - self.n_inputs = "scalar" self.n_states = ( self.sys.n_states + self.refsys.n_states ) # FIXME: support "scalar" - self.feedbacklaw = feedbacklaw - if feedbacklaw is None: + if callable(feedbacklaw): + self.feedbacklaw = feedbacklaw + else: self.feedbacklaw = input_output_linearize( sys, reldeg, refsys, linearizing_output ) def vector_field(self, x, u=None, t=None): x, z = x[: self.sys.n_states], x[self.sys.n_states :] - if u is None: - u = 0.0 y = self.feedbacklaw(x, z, u) dx = self.sys.vector_field(x, y) dz = self.refsys.vector_field(z, u) return jnp.concatenate((dx, dz)) - def output(self, x, u=None, t=None): + def output(self, x, u, t=None): x, z = x[: self.sys.n_states], x[self.sys.n_states :] - y = self.feedbacklaw(x, z, u) - return y + ur = self.feedbacklaw(x, z, u) + return ur diff --git a/dynax/system.py b/dynax/system.py index 52347f6..416d2fd 100644 --- a/dynax/system.py +++ b/dynax/system.py @@ -11,7 +11,7 @@ from jax import Array from jax.typing import ArrayLike -from .util import dim2shape, ssmatrix +from .util import dim2shape def _linearize(f, h, x0, u0, t0): @@ -398,7 +398,7 @@ class DynamicStateFeedbackSystem(DynamicalSystem): ẋ &= f_1(x, v(x, z, u), t) \\ ż &= f_2(z, r, t) \\ - y &= h(x, u, t) + y &= h_1(x, u, t) """ diff --git a/dynax/util.py b/dynax/util.py index 79d54a1..3aaca7b 100644 --- a/dynax/util.py +++ b/dynax/util.py @@ -1,10 +1,10 @@ import functools +from typing import Literal import equinox import jax import jax.numpy as jnp from jaxtyping import Array, ArrayLike -from typing import Literal def ssmatrix(data: ArrayLike, axis: int = 0) -> Array: diff --git a/examples/fit_multiple_shooting_second_order_sys.py b/examples/fit_multiple_shooting_second_order_sys.py index a7d8974..429745a 100644 --- a/examples/fit_multiple_shooting_second_order_sys.py +++ b/examples/fit_multiple_shooting_second_order_sys.py @@ -36,7 +36,7 @@ class NonlinearDrag(ControlAffine): # Set the number of states (order of system), the number of inputs n_states = 2 - n_inputs = 1 + n_inputs = "scalar" # Define the dynamical system via the methods f, g, and h def f(self, x): diff --git a/examples/fit_ode.ipynb b/examples/fit_ode.ipynb index bfb1aff..86945e4 100644 --- a/examples/fit_ode.ipynb +++ b/examples/fit_ode.ipynb @@ -26,9 +26,9 @@ " Flow,\n", " non_negative_field,\n", " pretty,\n", - " static_field,\n", ")\n", "\n", + "\n", "np.random.seed(42)" ] }, @@ -79,7 +79,7 @@ " # Set the number of states (order of system) and inputs.\n", " n_states = 2\n", " n_inputs = 1\n", - " \n", + "\n", " # Declare parameters as dataclass fields.\n", " m: float\n", " r: float = non_negative_field()\n", @@ -92,13 +92,12 @@ " return jnp.array(\n", " [\n", " x2,\n", - " (-self.r * x2 - self.r2 * jnp.abs(x2) * x2 - self.k * x1 + u)\n", - " / self.m,\n", + " (-self.r * x2 - self.r2 * jnp.abs(x2) * x2 - self.k * x1 + u) / self.m,\n", " ]\n", " )\n", "\n", " def output(self, x, u, t):\n", - " return 1000*x[0]" + " return 1000 * x[0]" ] }, { @@ -185,7 +184,7 @@ "# time\n", "t = np.linspace(0, 10, 1000)\n", "# forcing signal\n", - "u = np.sin(t*2*np.pi)\n", + "u = np.sin(t * 2 * np.pi)\n", "# zero initial state\n", "x0 = [0.0, 0.0]\n", "# x are the states and y is the output\n", @@ -248,7 +247,7 @@ ], "source": [ "init_model = Flow(\n", - " system=NonlinearDrag(m=2.0, r=1.0, r2=1., k=2.0),\n", + " system=NonlinearDrag(m=2.0, r=1.0, r2=1.0, k=2.0),\n", " solver=diffrax.Tsit5(),\n", " step=diffrax.PIDController(rtol=1e-3, atol=1e-6),\n", ")\n", diff --git a/examples/fit_second_order_sys.py b/examples/fit_second_order_sys.py index 0996e2a..f85cf04 100644 --- a/examples/fit_second_order_sys.py +++ b/examples/fit_second_order_sys.py @@ -30,7 +30,7 @@ class NonlinearDrag(ControlAffine): # Set the number of states (order of system), the number of in- and outputs. n_states = 2 - n_inputs = 1 + n_inputs = "scalar" # Define the dynamical system via the methods f, g, and h def f(self, x): diff --git a/examples/linearize_discrete_time.py b/examples/linearize_discrete_time.py new file mode 100644 index 0000000..c8cef09 --- /dev/null +++ b/examples/linearize_discrete_time.py @@ -0,0 +1,90 @@ +import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpy as np +from equinox.nn import GRUCell +from jax.random import PRNGKey + +from dynax import ( + discrete_relative_degree, + DiscreteLinearizingSystem, + DynamicalSystem, + LinearSystem, + Map, +) + + +# A nonlinear discrete-time system. +class Recurrent(DynamicalSystem): + cell: GRUCell + + n_inputs = "scalar" + + def __init__(self, hidden_size, *, key): + input_size = 1 + self.cell = GRUCell(input_size, hidden_size, use_bias=False, key=key) + self.n_states = hidden_size + + def vector_field(self, x, u, t=None): + return self.cell(jnp.array([u]), x) + + def output(self, x, u=None, t=None): + return x[0] + + +# A linear reference system. +reference_system = LinearSystem( + A=jnp.array([[-0.3, 0.1], [0, -0.3]]), + B=jnp.array([0.0, 1.0]), + C=jnp.array([1, 0]), + D=jnp.array(0), +) + +# System to contol. +hidden_size = 3 +system = Recurrent(hidden_size=hidden_size, key=PRNGKey(0)) + +# We want the nonlinear systems output to be equal to the reference system's output +# when driven with this input. +inputs = 0.1 * jnp.concatenate((jnp.array([0.1, 0.2, 0.3]), jnp.zeros(10))) + +# The relative degree of the reference system can be larger or equal to the relative +# degree of the nonlinear system. Here we test for the relative degree with a set of +# points and inputs. +reldeg = discrete_relative_degree( + system, np.random.normal(size=(inputs.size, system.n_states)), inputs +) +print("Relative degree of nonlinear system:", reldeg) +print( + "Relative degree of reference system:", + discrete_relative_degree( + reference_system, + np.random.normal(size=(inputs.size, reference_system.n_states)), + inputs, + ), +) + +# We compute the input signal that forces the outputs of the nonlinear and reference +# systems to be equal by solving a coupled system. +linearizing_system = DiscreteLinearizingSystem(system, reference_system, reldeg) + +# The output of this system when driven with the reference input is the linearizing +# input. The coupled system as an extra state used internally. +_, linearizing_inputs = Map(linearizing_system)( + jnp.zeros(system.n_states + reference_system.n_states + 1), u=inputs +) + +# Lets simulate the original system, the linear reference and the linearized system. +states_orig, output_orig = Map(system)(x0=jnp.zeros(hidden_size), u=inputs) +_, output_ref = Map(reference_system)(x0=jnp.zeros(reference_system.n_states), u=inputs) +_, output_linearized = Map(system)(jnp.zeros(hidden_size), u=linearizing_inputs) + +assert np.allclose(output_ref, output_linearized) + +plt.plot(output_orig, label="GRUCell") +plt.plot(output_ref, label="linear reference") +plt.plot(output_linearized, "--", label="input-output linearized GRU") +plt.legend() +plt.figure() +plt.plot(linearizing_inputs, label="linearizing input") +plt.legend() +plt.show() diff --git a/tests/test_linearize.py b/tests/test_linearize.py index 342081e..feb3dec 100644 --- a/tests/test_linearize.py +++ b/tests/test_linearize.py @@ -4,13 +4,20 @@ from dynax import ( ControlAffine, + discrete_relative_degree, + DiscreteLinearizingSystem, DynamicalSystem, DynamicStateFeedbackSystem, Flow, + input_output_linearize, LinearSystem, + Map, + relative_degree, ) from dynax.example_models import NonlinearDrag, Sastry9_9 -from dynax.linearize import input_output_linearize, is_controllable, relative_degree +from dynax.linearize import ( + is_controllable, +) tols = dict(rtol=1e-04, atol=1e-06) @@ -45,6 +52,20 @@ def test_relative_degree(): assert relative_degree(sys, xs) == 1 +def test_discrete_relative_degree(): + xs = np.random.normal(size=(100, 2)) + us = np.random.normal(size=(100, 1)) + + sys = SpringMassDamperWithOutput(out=0) + assert discrete_relative_degree(sys, xs, us) == 2 + + with npt.assert_raises(RuntimeError): + discrete_relative_degree(sys, xs, us, max_reldeg=1) + + sys = SpringMassDamperWithOutput(out=1) + assert discrete_relative_degree(sys, xs, us) == 1 + + def test_is_controllable(): n = 3 A = np.diag(np.arange(n)) @@ -73,8 +94,12 @@ def test_linearize_dyn2lin(): class ScalarScalar(DynamicalSystem): n_states = "scalar" n_inputs = "scalar" - vector_field = lambda self, x, u, t: -1 * x + 2 * u # FIXME remove Nones - output = lambda self, x, u, t: 3 * x + 4 * u + + def vector_field(self, x, u, t): + return -1 * x + 2 * u + + def output(self, x, u, t): + return 3 * x + 4 * u sys = ScalarScalar() linsys = sys.linearize() @@ -126,3 +151,33 @@ def test_input_output_linearize_multiple_outputs(): y_ref = Flow(ref)(np.zeros(sys.n_states), t, u)[1] y = Flow(feedback_sys)(np.zeros(feedback_sys.n_states), t, u)[1] npt.assert_allclose(y_ref[:, out_idx], y[:, out_idx], **tols) + + +class Lee7_4_5(DynamicalSystem): + n_states = 2 + n_inputs = "scalar" + + def vector_field(self, x, u, t=None): + x1, x2 = x + return 0.1 * jnp.array([x1 + x1**3 + x2, x2 + x2**3 + u]) + + def output(self, x, u=None, t=None): + return x[0] + + +def test_discrete_input_output_linearize(): + sys = Lee7_4_5() + refsys = sys.linearize() + xs = np.random.normal(size=(100, 2)) + us = np.random.normal(size=100) + reldeg = discrete_relative_degree(sys, xs, us) + assert reldeg == 2 + + feedback_sys = DiscreteLinearizingSystem(sys, refsys, reldeg) + t = np.linspace(0, 0.001, 10) + u = np.cos(t) * 0.1 + _, v = Map(feedback_sys)(np.zeros(2 + 2 + 1), t, u) + _, y = Map(sys)(np.zeros(2), t, u) + _, y_ref = Map(refsys)(np.zeros(2), t, u) + + npt.assert_allclose(y_ref, y, **tols) diff --git a/tests/test_systems.py b/tests/test_systems.py index 2d91538..cf4d407 100644 --- a/tests/test_systems.py +++ b/tests/test_systems.py @@ -131,7 +131,6 @@ def test_discrete_forward_model(): B = jnp.array([[0], [1]]) C = jnp.array([[1, 0]]) D = jnp.zeros((1, 1)) - # test just input sys = LinearSystem(A, B, C, D) model = Map(sys) @@ -140,7 +139,6 @@ def test_discrete_forward_model(): _, scipy_y, scipy_x = dlsim(scipy_sys, u, x0=x0) npt.assert_allclose(scipy_y, y, **tols) npt.assert_allclose(scipy_x, x, **tols) - # test input and time (results should be same) x, y = model(x0, u=u, t=t) scipy_t, scipy_y, scipy_x = dlsim(scipy_sys, u, x0=x0, t=t) diff --git a/tests/test_util.py b/tests/test_util.py index 9f30e45..01ccbba 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,4 +1,3 @@ -import jax import jax.numpy as jnp import numpy.testing as npt