Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
fhchl committed Oct 11, 2023
1 parent 23faf87 commit 343335e
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 20 deletions.
66 changes: 48 additions & 18 deletions dynax/system.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Classes for representing dynamical systems."""

from abc import abstractmethod
from collections.abc import Callable
from dataclasses import field
from typing import Optional
Expand All @@ -12,7 +13,7 @@
from jax import Array
from jax.tree_util import tree_map, tree_structure
from jax.typing import ArrayLike
from jaxtyping import PyTree, Scalar, ScalarLike
from jaxtyping import PyTree, ScalarLike


def static_field(**kwargs):
Expand Down Expand Up @@ -55,6 +56,8 @@ def non_negative_field(min_val: float = 0.0, **kwargs):
return boxed_field(lower=min_val, upper=np.inf, **kwargs)


from jax._src.numpy.util import promote_dtypes_numeric

def _linearize(f, h, x, u, t):
"""Linearize dx=f(x,u,t), y=h(x,u,t) around x, u and t."""
A = jax.jacfwd(f, argnums=0)(x, u, t)
Expand Down Expand Up @@ -95,22 +98,23 @@ def output(self, x, u, t):

x0: Optional[PyTree] = static_field(init=False, default=None)

@abstractmethod
def vector_field(
self, x: PyTree, u: Optional[PyTree] = None, t: Optional[Scalar] = None
self, x: PyTree, u: Optional[ScalarLike] = None, t: Optional[ScalarLike] = None
) -> PyTree:
"""Compute the state derivative from current state, input and time."""
raise NotImplementedError
pass

def output(
self, x: PyTree, u: Optional[PyTree] = None, t: Optional[Scalar] = None
self, x: PyTree, u: Optional[ScalarLike] = None, t: Optional[ScalarLike] = None
) -> PyTree:
"""Compute the output from current state, input and time."""
return None

def linearize(
self,
x: Optional[PyTree] = None,
u: Optional[PyTree] = None,
u: Optional[ScalarLike] = None,
t: Optional[float] = None,
) -> "LinearSystem":
"""Compute the linearized system around a state, and input and time."""
Expand Down Expand Up @@ -175,21 +179,20 @@ def linearize(
# pass


def _control_affine(bfun, Afun, x, u=None, t=None):
def _control_affine(
bfun: Callable[[PyTree], PyTree],
Afun: Callable[[PyTree], PyTree],
x: PyTree,
u: Optional[ScalarLike] = None
):
b = bfun(x)
A = Afun(x)
if u is None or A is None:
return b
if isinstance(u, (Scalar, ScalarLike)):
# elementwise multiply A with u
Au = tree_map(lambda Ai: Ai * u, A)
elif tree_structure(A) == tree_structure(u):
Au = tree_map(jnp.dot, A, u)
else:
raise ValueError(
f"If u isn't scalar, {bfun}(x) and u must have the same pytree structure"
)
return (b**ω + Au**ω).ω
# TODO(pytree): some casting to arrays is needed here, as otherwise the omegas behave
# weired. For np.ndarray a, a**ω is ω applied for each element resulting in a array
# of object
return (b**ω + A**ω * u).ω


class ControlAffine(DynamicalSystem):
Expand All @@ -215,10 +218,10 @@ def i(self, x: PyTree) -> Array | None:
return None

def vector_field(self, x, u=None, t=None):
return _control_affine(self.f, self.g, x, u, t)
return _control_affine(self.f, self.g, x, u)

def output(self, x, u=None, t=None):
return _control_affine(self.h, self.i, x, u, t)
return _control_affine(self.h, self.i, x, u)


class LinearSystem(ControlAffine):
Expand Down Expand Up @@ -248,6 +251,33 @@ def i(self, x: PyTree) -> PyTree:
return self.D


class LinearFuncSystem(ControlAffine):
r"""A linear, time-invariant dynamical system.
.. math::
ẋ &= Ax + Bu \\
y &= Cx + Du
"""
A: Callable
B: Optional[Callable] = None
C: Optional[Callable] = None
D: Optional[Callable] = None

def f(self, x: PyTree) -> PyTree:
return self.A(x)

def g(self, x: PyTree) -> PyTree:
return self.B(x)

def h(self, x: PyTree) -> PyTree:
return tree_map(jnp.dot, self.C, x)

def i(self, x: PyTree) -> PyTree:
return self.D


class SeriesSystem(DynamicalSystem):
"""Two systems in series."""

Expand Down
12 changes: 12 additions & 0 deletions pytreenotes.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
Don't understand yet how to do Ax matrix vector product when A and x are both pytrees. This is needed for LinearSystem

When linearizing, one way could be to store not A = jacfwd(f, x), but A = jax.linearize(f, x). Then Ax is just A(x).

Above is a problem for feedback linearization. There, we need to compute `c.dot(np.linalg.matrix_power(A, reldeg))` and the like.

Could one just ravel all the pytrees in linearize?

Could ine just ravel the outputs of vector_field, and unravel the inputs, thus allowing arbitrary pytrees as in and output, but keep all the other machieneary strictly with ndarrays?



4 changes: 2 additions & 2 deletions tests/test_systems.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def output(self, x, u, t):


def test_series():
n1, m1, p1 = 4, 3, 2
n1, m1, p1 = 4, 1, 1
np.random.seed(42)
A1 = np.random.uniform(-5, 5, size=(n1, n1))
B1 = np.random.uniform(-5, 5, size=(n1, m1))
Expand All @@ -153,7 +153,7 @@ def test_series():
D2 = np.random.uniform(-5, 5, size=(p2, m2))
sys2 = LinearSystem(A2, B2, C2, D2)
sys = SeriesSystem(sys1, sys2)
linsys = sys.linearize(x=(np.zeros(n1), np.zeros(n2)), u=np.zeros(m1))
linsys = sys.linearize(x=(np.zeros(n1), np.zeros(n2)), u=0)
assert tree_equal(
linsys.A, ((A1, np.zeros((n1, n2))), (B2.dot(C1), A2))
)
Expand Down

0 comments on commit 343335e

Please sign in to comment.