Skip to content

Commit

Permalink
ruff and black
Browse files Browse the repository at this point in the history
  • Loading branch information
fhchl committed Jan 30, 2024
1 parent b7e8160 commit cd6c6b8
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 20 deletions.
14 changes: 8 additions & 6 deletions dynax/linearize.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,10 @@ def discrete_input_output_linearize(
reldeg: int,
ref: DynamicalSystem,
output: Optional[int] = None,
solver: Optional[optx.AbstractRootFinder] = None,
solver: Optional[optx.AbstractRootFinder] = None,
) -> Callable[[Array, Array, float, float], float]:
"""Construct the input-output linearizing feedback law for a discrete-time system.
"""

"""Construct the input-output linearizing feedback for a discrete-time system."""

# Lee 2022, Chap. 7.4
f = lambda x, u: sys.vector_field(x, u)
h = sys.output
Expand All @@ -186,7 +185,7 @@ def discrete_input_output_linearize(
_output = lambda x: x
else:
_output = lambda x: x[output]

if solver is None:
solver = optx.Newton(rtol=1e-6, atol=1e-6)

Expand All @@ -202,7 +201,10 @@ def y_reldeg_ref(z, v):

def feedbacklaw(x: Array, z: Array, v: float, u_prev: float):
def fn(u, args):
return (_output(h(propagate(f, reldeg, x, u))) - y_reldeg_ref(z, v)).squeeze()
return (
_output(h(propagate(f, reldeg, x, u))) - y_reldeg_ref(z, v)
).squeeze()

u = optx.root_find(fn, solver, u_prev).value
return u

Expand Down
2 changes: 1 addition & 1 deletion dynax/util.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import functools
from typing import Literal

import equinox
import jax
import jax.numpy as jnp
from jaxtyping import Array, ArrayLike
from typing import Literal


def ssmatrix(data: ArrayLike, axis: int = 0) -> Array:
Expand Down
15 changes: 6 additions & 9 deletions examples/linearize_discrete_time.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from equinox.nn import GRUCell
from jax.random import PRNGKey

from dynax import (
DynamicalSystem,
Map,
discrete_relative_degree,
DiscreteLinearizingSystem,
DynamicalSystem,
LinearSystem,
Map,
)
from equinox.nn import GRUCell
import jax
import jax.numpy as jnp
import numpy as np
from jax.random import PRNGKey


# A nonlinear discrete-time system.
Expand Down Expand Up @@ -89,5 +88,3 @@ def output(self, x, u=None, t=None):
plt.plot(linearizing_inputs, label="linearizing input")
plt.legend()
plt.show()


6 changes: 3 additions & 3 deletions tests/test_linearize.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_linearize_dyn2lin():
class ScalarScalar(DynamicalSystem):
n_states = "scalar"
n_inputs = "scalar"

def vector_field(self, x, u, t):
return -1 * x + 2 * u

Expand Down Expand Up @@ -164,7 +164,7 @@ def vector_field(self, x, u, t=None):
def output(self, x, u=None, t=None):
return x[0]


def test_discrete_input_output_linearize():
sys = Lee7_4_5()
refsys = sys.linearize()
Expand All @@ -176,7 +176,7 @@ def test_discrete_input_output_linearize():
feedback_sys = DiscreteLinearizingSystem(sys, refsys, reldeg)
t = np.linspace(0, 0.001, 10)
u = np.cos(t) * 0.1
_, v = Map(feedback_sys)(np.zeros(2+2+1), t, u)
_, v = Map(feedback_sys)(np.zeros(2 + 2 + 1), t, u)
_, y = Map(sys)(np.zeros(2), t, u)
_, y_ref = Map(refsys)(np.zeros(2), t, u)

Expand Down
1 change: 0 additions & 1 deletion tests/test_util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import jax
import jax.numpy as jnp
import numpy.testing as npt

Expand Down

0 comments on commit cd6c6b8

Please sign in to comment.