Skip to content

Commit

Permalink
feat: discrete time linearization (#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
fhchl authored Jan 30, 2024
1 parent 1c8d855 commit bf02849
Show file tree
Hide file tree
Showing 17 changed files with 309 additions and 82 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
name: Tests

on: [pull_request, push]
on: [pull_request]

jobs:
run-tests:
strategy:
matrix:
python-version: [ "3.9", "3.10", "3.11" ]
python-version: [ "3.10", "3.11", "3.12" ]
os: [ ubuntu-latest ]
fail-fast: false
runs-on: ${{ matrix.os }}
Expand All @@ -24,6 +24,6 @@ jobs:

- name: Test with pytest
run: |
python -m pip install jaxlib==0.4.16
python -m pip install jaxlib==0.4.23
python -m pip install .[dev]
python -m pytest --runslow --durations=0
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@ include:
- estimation of linear ODE parameters via matching of frequency-response functions
- estimation from multiple experiments
- estimation with a poor man's multiple shooting
- input-output linearization of continuous- and discrete-time input-affine systems with well-defined relative degree
- input-output linearization of continuous-time input affine systems
- input-output linearization of discrete-time systems [example](examples/linearize_discrete_time)
- estimation of a system's relative-degree

Documentation is on its way. Until then, have a look at the [example](examples) and [test](tests) folders.


## Installing

Requires Python 3.9+, JAX 0.4.13+, Equinox 0.10.10+ and Diffrax 0.4.0+. With a
Requires Python 3.9+, JAX 0.4.23+, Equinox 0.11+ and Diffrax 0.5+. With a
suitable version of jaxlib installed:

pip install .
Expand Down
3 changes: 2 additions & 1 deletion docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ API documentation
=================

.. autosummary::

:toctree: generated
:recursive:

Expand All @@ -10,4 +11,4 @@ API documentation
dynax.linearize
dynax.derivative
dynax.interpolation
dynax.util
dynax.util
4 changes: 1 addition & 3 deletions docs/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@ Examples
========

.. toctree::

index
../examples/fit_ode.ipynb
:maxdepth: 2

Have a look at the notebooks on the left or the following scripts.

Expand Down
3 changes: 3 additions & 0 deletions dynax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from .evolution import AbstractEvolution as AbstractEvolution, Flow as Flow, Map as Map
from .interpolation import spline_it as spline_it
from .linearize import (
discrete_input_output_linearize as discrete_input_output_linearize,
discrete_relative_degree as discrete_relative_degree,
DiscreteLinearizingSystem as DiscreteLinearizingSystem,
input_output_linearize as input_output_linearize,
LinearizingSystem as LinearizingSystem,
relative_degree as relative_degree,
Expand Down
36 changes: 0 additions & 36 deletions dynax/estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,41 +164,6 @@ def fit_least_squares(
Parameters can be constrained via the `*_field` functions.
Args:
model: Flow instance holding initial parameter estimates
t: Times at which `y` is given
y: Target outputs of system
x0: Initial state
u: Pptional system input
batched: If True, interpret `t`, `y`, `x0`, `u` as holding multiple
experiments stacked along the first axis.
sigma: A 1-D sequence with values of the standard deviation of the measurement
error for each output of `model.system`. If None, `sigma` will be set to
the rms values of each measurement in `y`, which makes the cost
scale-invariant to magnitude differences between measurements.
absolute_sigma: If True, `sigma` is used in an absolute sense and the estimated
parameter covariance `pcov` reflects these absolute values. If False
(default), only the relative magnitudes of the `sigma` values matter and
`sigma` is scaled to match the sample variance of the residuals after the
fit.
reg_val: Weight of the l2 penalty term.
reg_bias: If "initial", bias the parameter estimates towards the values in
`model`.
verbose_mse: Scale cost to mean-squared-error for easier interpretation.
kwargs: Optional parameters for `scipy.optimize.least_squares`.
Returns:
`OptimizeResult` as returned by `scipy.optimize.least_squares` with the
following additional attributes defined:
result: `model` with estimated parameters.
cov: Covariance matrix of the parameter estimate.
y_pred: Model prediction at optimum.
key_paths: List of key_paths that index the corresponding entries in `cov`,
`jac`, and `x`.
mse: Mean-squared-error.
nmse: Normalized mean-squared-error.
nrmse: Normalized root-mean-squared-error.
"""
t = jnp.asarray(t)
Expand Down Expand Up @@ -463,7 +428,6 @@ def residuals(params):
sys = unravel(params)
H = transfer_function(sys)
Gyu_pred = jax.vmap(H)(s)
# FIXME: there are some bugs here, run pytest...
Syu_pred = Gyu_pred * broadcast_right(Suu, Gyu_pred)
r = (Syu - Syu_pred) * weight
r = jnp.concatenate((jnp.real(r), jnp.imag(r)))
Expand Down
4 changes: 1 addition & 3 deletions dynax/evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from jax._src.config import _validate_default_device
from jaxtyping import Array, ArrayLike, PyTree

from .interpolation import spline_it
Expand Down Expand Up @@ -149,6 +147,6 @@ def scan_fun(state, input):
_, x = jax.lax.scan(scan_fun, x0, inputs, length=num_steps)

# Compute output
y = jax.vmap(self.system.output)(x)
y = jax.vmap(self.system.output)(x, u, t)

return x, y
153 changes: 137 additions & 16 deletions dynax/linearize.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Functions related to feedback linearization of nonlinear systems."""

from collections.abc import Callable
from functools import partial
from typing import Optional, Sequence

import jax
import jax.numpy as jnp
import numpy as np
import optimistix as optx
from jaxtyping import Array

from .derivative import lie_derivative
Expand Down Expand Up @@ -53,7 +55,7 @@ def input_output_linearize(
ref: LinearSystem,
output: Optional[int] = None,
asymptotic: Optional[Sequence] = None,
reg: Optional[float] = None
reg: Optional[float] = None,
) -> Callable[[Array, Array, float], float]:
"""Construct input-output linearizing feedback law.
Expand All @@ -64,7 +66,7 @@ def input_output_linearize(
output: specify linearizing output if systems have multiple outputs
asymptotic: If `None`, compute the exactly linearizing law. Otherwise,
a sequence of length `reldeg` defining the tracking behaviour.
reg: parameter that control the linearization effort. Only effective if
reg: parameter that control the linearization effort. Only effective if
asymptotic is not None.
Note:
Expand Down Expand Up @@ -117,7 +119,7 @@ def feedbacklaw(x: Array, z: Array, v: float) -> float:
for ai, Lfih, cAi in zip(alphas, Lfihs, cAis)
]
)
error = (y_reldeg_ref - y_reldeg + jnp.sum(ae0s))
error = y_reldeg_ref - y_reldeg + jnp.sum(ae0s)
if reg is None:
return error / LgLfnm1h(x)
else:
Expand All @@ -127,6 +129,120 @@ def feedbacklaw(x: Array, z: Array, v: float) -> float:
return feedbacklaw


def propagate(f: Callable[[Array, float], Array], n: int, x: Array, u: float) -> Array:
"""Propagates system n steps."""
# TODO: replace by lax.scan
if n == 0:
return x
return propagate(f, n - 1, f(x, u), u)


def discrete_relative_degree(
sys: DynamicalSystem,
xs: Array,
us: Array,
max_reldeg=10,
output: Optional[int] = None,
):
"""Estimate relative degree of discrete-time system on region xs.
Source: Lee, Linearization of Nonlinear Control Systems (2022), Def. 7.7
"""
f = sys.vector_field
h = sys.output

y_depends_u = jax.grad(lambda n, x, u: h(propagate(f, n, x, u)), 2)

for n in range(1, max_reldeg + 1):
res = jax.vmap(partial(y_depends_u, n))(xs, us)
if np.all(res == 0):
continue
elif np.all(res != 0):
return n
else:
raise RuntimeError("sys has ill defined relative degree.")
raise RuntimeError("Could not estmate relative degree. Increase max_reldeg.")


def discrete_input_output_linearize(
sys: DynamicalSystem,
reldeg: int,
ref: DynamicalSystem,
output: Optional[int] = None,
solver: Optional[optx.AbstractRootFinder] = None,
) -> Callable[[Array, Array, float, float], float]:
"""Construct the input-output linearizing feedback for a discrete-time system."""

# Lee 2022, Chap. 7.4
f = lambda x, u: sys.vector_field(x, u)
h = sys.output
if sys.n_inputs != ref.n_inputs != 1:
raise ValueError("Systems must have single input.")
if output is None:
if not (sys.n_outputs == ref.n_outputs and sys.n_outputs in ["scalar", 1]):
raise ValueError("Systems must be single output and `output` is None.")
_output = lambda x: x
else:
_output = lambda x: x[output]

if solver is None:
solver = optx.Newton(rtol=1e-6, atol=1e-6)

def y_reldeg_ref(z, v):
if isinstance(ref, LinearSystem):
# A little faster for the linear case (if this is not optimized by jit)
A, b, c = ref.A, ref.B, ref.C
A_reldeg = c.dot(np.linalg.matrix_power(A, reldeg))
B_reldeg = c.dot(np.linalg.matrix_power(A, reldeg - 1)).dot(b)
return _output(A_reldeg.dot(z) + B_reldeg.dot(v))
else:
_output(ref.output(propagate(ref.vector_field, reldeg, z, v)))

def feedbacklaw(x: Array, z: Array, v: float, u_prev: float):
def fn(u, args):
return (
_output(h(propagate(f, reldeg, x, u))) - y_reldeg_ref(z, v)
).squeeze()

u = optx.root_find(fn, solver, u_prev).value
return u

return feedbacklaw


class DiscreteLinearizingSystem(DynamicalSystem):
r"""Dynamics computing linearizing feedback as output."""

sys: ControlAffine
refsys: LinearSystem
feedbacklaw: Callable

n_inputs = "scalar"

def __init__(self, sys, refsys, reldeg, linearizing_output=None):
if sys.n_inputs != "scalar":
raise ValueError("Only single input systems supported.")
self.sys = sys
self.refsys = refsys
self.n_states = self.sys.n_states + self.refsys.n_states + 1
self.feedbacklaw = discrete_input_output_linearize(
sys, reldeg, refsys, linearizing_output
)

def vector_field(self, x, u, t=None):
x, z, v_last = x[: self.sys.n_states], x[self.sys.n_states : -1], x[-1]
v = self.feedbacklaw(x, z, u, v_last)
xn = self.sys.vector_field(x, v)
zn = self.refsys.vector_field(z, u)
return jnp.concatenate((xn, zn, jnp.array([v])))

def output(self, x, u, t=None):
x, z, v_last = x[: self.sys.n_states], x[self.sys.n_states : -1], x[-1]
v = self.feedbacklaw(x, z, u, v_last) # FIXME: feedback law called twice
return v


class LinearizingSystem(DynamicalSystem):
r"""Coupled ODE of nonlinear dynamics, linear reference and io linearizing law.
Expand All @@ -145,33 +261,38 @@ class LinearizingSystem(DynamicalSystem):

sys: ControlAffine
refsys: LinearSystem
feedbacklaw: Optional[Callable] = None

def __init__(self, sys, refsys, reldeg, feedbacklaw=None, linearizing_output=None):
if sys.n_inputs > 1:
raise ValueError("Only single input systems supported.")
feedbacklaw: Callable[[Array, Array, float], float]

n_inputs = "scalar"

def __init__(
self,
sys: ControlAffine,
refsys: LinearSystem,
reldeg: int,
feedbacklaw: Optional[Callable] = None,
linearizing_output: Optional[int] = None,
):
self.sys = sys
self.refsys = refsys
self.n_inputs = "scalar"
self.n_states = (
self.sys.n_states + self.refsys.n_states
) # FIXME: support "scalar"
self.feedbacklaw = feedbacklaw
if feedbacklaw is None:
if callable(feedbacklaw):
self.feedbacklaw = feedbacklaw
else:
self.feedbacklaw = input_output_linearize(
sys, reldeg, refsys, linearizing_output
)

def vector_field(self, x, u=None, t=None):
x, z = x[: self.sys.n_states], x[self.sys.n_states :]
if u is None:
u = 0.0
y = self.feedbacklaw(x, z, u)
dx = self.sys.vector_field(x, y)
dz = self.refsys.vector_field(z, u)
return jnp.concatenate((dx, dz))

def output(self, x, u=None, t=None):
def output(self, x, u, t=None):
x, z = x[: self.sys.n_states], x[self.sys.n_states :]
y = self.feedbacklaw(x, z, u)
return y
ur = self.feedbacklaw(x, z, u)
return ur
4 changes: 2 additions & 2 deletions dynax/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from jax import Array
from jax.typing import ArrayLike

from .util import dim2shape, ssmatrix
from .util import dim2shape


def _linearize(f, h, x0, u0, t0):
Expand Down Expand Up @@ -398,7 +398,7 @@ class DynamicStateFeedbackSystem(DynamicalSystem):
ẋ &= f_1(x, v(x, z, u), t) \\
ż &= f_2(z, r, t) \\
y &= h(x, u, t)
y &= h_1(x, u, t)
"""

Expand Down
2 changes: 1 addition & 1 deletion dynax/util.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import functools
from typing import Literal

import equinox
import jax
import jax.numpy as jnp
from jaxtyping import Array, ArrayLike
from typing import Literal


def ssmatrix(data: ArrayLike, axis: int = 0) -> Array:
Expand Down
2 changes: 1 addition & 1 deletion examples/fit_multiple_shooting_second_order_sys.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class NonlinearDrag(ControlAffine):

# Set the number of states (order of system), the number of inputs
n_states = 2
n_inputs = 1
n_inputs = "scalar"

# Define the dynamical system via the methods f, g, and h
def f(self, x):
Expand Down
Loading

0 comments on commit bf02849

Please sign in to comment.