Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
fhchl committed Sep 21, 2023
1 parent 32ca291 commit c60c015
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 47 deletions.
3 changes: 3 additions & 0 deletions dynax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from .evolution import AbstractEvolution as AbstractEvolution, Flow as Flow, Map as Map
from .interpolation import spline_it as spline_it
from .linearize import (
discrete_input_output_linearize as discrete_input_output_linearize,
discrete_relative_degree as discrete_relative_degree,
DiscreteLinearizingSystem as DiscreteLinearizingSystem,
input_output_linearize as input_output_linearize,
LinearizingSystem as LinearizingSystem,
relative_degree as relative_degree,
Expand Down
4 changes: 2 additions & 2 deletions dynax/evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ class Map(AbstractEvolution):
def __call__(
self,
x0: ArrayLike,
num_steps: Optional[int] = None,
t: Optional[ArrayLike] = None,
u: Optional[ArrayLike] = None,
num_steps: Optional[int] = None,
squeeze: bool = True,
):
"""Solve discrete map."""
Expand Down Expand Up @@ -138,7 +138,7 @@ def scan_fun(state, input):
_, x = jax.lax.scan(scan_fun, x0, inputs, length=num_steps)

# Compute output
y = jax.vmap(self.system.output)(x)
y = jax.vmap(self.system.output)(x, u)
if squeeze:
# Remove singleton dimensions
x = x.squeeze()
Expand Down
147 changes: 120 additions & 27 deletions dynax/linearize.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
"""Functions related to feedback linearization of nonlinear systems."""

from collections.abc import Callable
from functools import partial
from typing import Optional, Sequence

import jax
import jax.numpy as jnp
import numpy as np
from jaxtyping import Array
from functools import partial
from diffrax import NewtonNonlinearSolver
from jaxtyping import Array, ArrayLike

from .derivative import lie_derivative
from .system import ControlAffine, DynamicalSystem, LinearSystem

Expand Down Expand Up @@ -37,31 +39,6 @@ def relative_degree(sys, xs, max_reldeg=10, output: Optional[int] = None) -> int
raise RuntimeError("Could not estimate relative degree. Increase max_reldeg.")


def discrete_relative_degree(sys, xs, us, max_reldeg=10, output: Optional[int] = None):
"""Estimate relative degree of discrete-time system on region xs."""
f = sys.vector_field
h = sys.output

def F(n, x, u):
if n == 0:
return x
elif n == 1:
return f(x, u)
return F(n-1, f(x, u), 0)

deriv_u_hF = jax.grad(lambda k, x, u: h(F(k, x, u)), 2)

for n in range(1, max_reldeg + 1):
res = jax.vmap(partial(deriv_u_hF, n))(xs, us)
if np.all(res == 0.0):
continue
elif np.all(res != 0.0):
return n
else:
raise RuntimeError("sys has ill defined relative degree.")
raise RuntimeError("Could not estmate relative degree. Increase max_reldeg.")


def is_controllable(A, B) -> bool:
"""Test controllability of linear system."""
n = A.shape[0]
Expand Down Expand Up @@ -139,6 +116,122 @@ def feedbacklaw(x: Array, z: Array, v: float) -> float:
return feedbacklaw


def prop(f: Callable[[Array, float], Array], n: int, x: Array, u: float) -> Array:
"""Propagates system n steps."""
# TODO: replace by lax.scan
if n == 0:
return x
elif n == 1:
return f(x, u)
return prop(f, n-1, f(x, u), 0)


def discrete_relative_degree(
sys: DynamicalSystem,
xs: ArrayLike,
us: ArrayLike,
max_reldeg=10,
output: Optional[int] = None
):
"""Estimate relative degree of discrete-time system on region xs.
Source: Lee, Linearization of Nonlinear Control Systems (2022), Def. 7.7
"""
f = sys.vector_field
h = sys.output

y_depends_u = jax.grad(lambda n, x, u: h(prop(f, n, x, u)).squeeze(), 2)

for n in range(1, max_reldeg + 1):
res = jax.vmap(partial(y_depends_u, n))(xs, us)
if np.all(res == 0):
continue
elif np.all(res != 0):
return n
else:
raise RuntimeError("sys has ill defined relative degree.")
raise RuntimeError("Could not estmate relative degree. Increase max_reldeg.")


def discrete_input_output_linearize(
sys: DynamicalSystem,
reldeg: int,
ref: LinearSystem,
output: Optional[int] = None,
solver=None
) -> Callable[[Array, Array, float, float], float]:
"""Construct input-output linearizing feedback law for a discrete-time system."""
if not (sys.n_inputs == ref.n_inputs == 1):
raise ValueError("systems must be single input")
if output is None:
if not (sys.n_outputs == ref.n_outputs == 1):
raise ValueError("Systems must be single output and `output` is None.")
h = sys.h
A, b, c = ref.A, ref.B, ref.C
else:
h = lambda x, t=None: sys.h(x, t=t)[output]
A, b, c = ref.A, ref.B, ref.C[output]
if solver is None:
solver = NewtonNonlinearSolver(rtol=1e-3, atol=1e-6)

cAn = c.dot(np.linalg.matrix_power(A, reldeg))
cAnm1b = c.dot(np.linalg.matrix_power(A, reldeg-1)).dot(b)

def feedbacklaw(x: Array, z: Array, v: float, u_prev: float):
y_reldeg_ref = cAn.dot(z) + cAnm1b * v
fn = lambda u, x: h(prop(sys.f, reldeg, x, u)) - y_reldeg_ref
jax.debug.breakpoint()
# Catch https://github.com/patrick-kidger/diffrax/issues/296
u = jnp.where(
fn(u_prev, x) == 0,
u_prev,
solver(fn, u_prev, x).root.squeeze()
)
# JDB: u is inf after second iteration
jax.debug.breakpoint()
jax.debug.print("u: {u}", u=u)
return u

return feedbacklaw


class DiscreteLinearizingSystem(DynamicalSystem):
r"""Coupled difference-equatio of nonlinear dynamics, linear reference and io linearizing law.
"""

sys: ControlAffine
refsys: LinearSystem
feedbacklaw: Optional[Callable] = None

def __init__(self, sys, refsys, reldeg, feedbacklaw=None, linearizing_output=None):
if sys.n_inputs > 1:
raise ValueError("Only single input systems supported.")
self.sys = sys
self.refsys = refsys
self.n_outputs = self.n_inputs = 1
self.n_states = self.sys.n_states + self.refsys.n_states + 1
self.feedbacklaw = feedbacklaw
if feedbacklaw is None:
self.feedbacklaw = discrete_input_output_linearize(
sys, reldeg, refsys, linearizing_output
)

def vector_field(self, x, u=None, t=None):
x, z, y_last = x[: self.sys.n_states], x[self.sys.n_states :-1], x[-1]
if u is None:
u = 0.0
v = self.feedbacklaw(x, z, u, y_last)
xn = self.sys.vector_field(x, v)
zn = self.refsys.vector_field(z, u)
return jnp.concatenate((xn, zn, jnp.array([y_last])))

def output(self, x, u=None, t=None):
x, z, y_last = x[: self.sys.n_states], x[self.sys.n_states :-1], x[-1]
y = self.feedbacklaw(x, z, u, y_last)
return y


class LinearizingSystem(DynamicalSystem):
r"""Coupled ODE of nonlinear dynamics, linear reference and io linearizing law.
Expand Down
42 changes: 24 additions & 18 deletions tests/test_linearize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,20 @@

from dynax import (
ControlAffine,
discrete_input_output_linearize,
discrete_relative_degree,
DiscreteLinearizingSystem,
DynamicalSystem,
DynamicStateFeedbackSystem,
Flow,
input_output_linearize,
LinearSystem,
Map,
relative_degree,
)
from dynax.example_models import NonlinearDrag, Sastry9_9
from dynax.linearize import (
input_output_linearize,
is_controllable,
relative_degree,
discrete_relative_degree,
)


Expand Down Expand Up @@ -57,13 +60,13 @@ def test_relative_degree():
def test_discrete_relative_degree():
xs = np.random.normal(size=(100, 2))
us = np.random.normal(size=(100))

sys = SpringMassDamperWithOutput(out=0)
assert discrete_relative_degree(sys, xs, us) == 2

with npt.assert_raises(RuntimeError):
discrete_relative_degree(sys, xs, us, max_reldeg=1)

sys = SpringMassDamperWithOutput(out=1)
assert discrete_relative_degree(sys, xs, us) == 1

Expand Down Expand Up @@ -153,15 +156,18 @@ def test_input_output_linearize_multiple_outputs():


def test_discrete_input_output_linearize():
sys = NonlinearDrag(0.1, 0.1, 0.1, 0.1)
ref = sys.linearize()
xs = np.random.normal(size=(100, sys.n_states))
us = np.random.normal(size=100)
reldeg = discrete_relative_degree(sys, xs, us)
feedbacklaw = discrete_input_output_linearize(sys, reldeg, ref)
feedback_sys = DynamicStateFeedbackSystem(sys, ref, feedbacklaw)
t = np.linspace(0, 1, 1000)
u = np.sin(t) * 0.1
y_ref = Map(ref)(np.zeros(sys.n_states), t, u)[1]
y = Map(feedback_sys)(np.zeros(feedback_sys.n_states), t, u)[1]
npt.assert_allclose(y_ref, y, **tols)
import jax

# with jax.disable_jit():
with jax.debug_nans():
sys = NonlinearDrag(0.1, 0.1, 0.1, 0.1)
refsys = sys.linearize()
xs = np.random.normal(size=(100, sys.n_states))
us = np.random.normal(size=100)
reldeg = discrete_relative_degree(sys, xs, us)
feedback_sys = DiscreteLinearizingSystem(sys, refsys, reldeg)
t = np.linspace(0, 1, 1000)
u = np.sin(t) * 0.1
y = Map(feedback_sys)(np.zeros(feedback_sys.n_states), t, u)[1]
y_ref = Map(refsys)(np.zeros(sys.n_states), t, u)[1]
npt.assert_allclose(y_ref, y, **tols)

0 comments on commit c60c015

Please sign in to comment.