Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fhchl/issue14 #16

Merged
merged 3 commits into from
Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@ jobs:

- name: Test with pytest
run: |
python -m pip install jaxlib==0.4.11
python -m pip install jaxlib==0.4.16
python -m pip install .[dev]
python -m pytest --runslow --durations=0
6 changes: 5 additions & 1 deletion dynax/derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ def lie_derivative(f, h, n=1):
return h
else:
lie_der = lie_derivative(f, h, n=n - 1)
return lambda x, *args: jax.jvp(lie_der, (x, *args), (f(x, *args),))[1]
return lambda x: jax.jvp(
lie_der,
(x,),
(f(x),),
)[1]


def lie_derivatives_jet(f, h, n=1):
Expand Down
7 changes: 4 additions & 3 deletions dynax/estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def _least_squares(
verbose_mse: bool = True,
**kwargs: Any,
) -> OptimizeResult:
"""Least-squares with jit, autodiff and parameter scaling, regularization"""
"""Least-squares with jit, autodiff, parameter scaling and regularization."""

if reg_term is not None:
# Add regularization term
Expand All @@ -152,11 +152,12 @@ def f(params):
return res * np.sqrt(2 / res.size)

if x_scale:
# Scale parameters by initial value
norm = np.where(np.asarray(x0) != 0, x0, 1)
# Scale parameters and bounds by initial values
norm = np.where(np.asarray(x0) != 0, np.abs(x0), 1)
x0 = x0 / norm
___f = f
f = lambda x: ___f(x * norm)
bounds = (np.array(bounds[0]) / norm, np.array(bounds[1]) / norm)

fun = MemoizeJac(jax.jit(lambda x: value_and_jacfwd(f, x)))
jac = fun.derivative
Expand Down
12 changes: 6 additions & 6 deletions dynax/example_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,16 @@ class NonlinearDrag(ControlAffine):
n_states = 2
n_inputs = 1

def f(self, x, u=None, t=None):
def f(self, x):
x1, x2 = x
return jnp.array(
[x2, (-self.r * x2 - self.r2 * jnp.abs(x2) * x2 - self.k * x1) / self.m]
)

def g(self, x, u=None, t=None):
def g(self, x):
return jnp.array([0.0, 1.0 / self.m])

def h(self, x, u=None, t=None):
def h(self, x):
return x[jnp.array(self.outputs)]


Expand All @@ -84,13 +84,13 @@ class Sastry9_9(ControlAffine):
n_states = 3
n_inputs = 1

def f(self, x, t=None):
def f(self, x):
return jnp.array([0.0, x[0] + x[1] ** 2, x[0] - x[1]])

def g(self, x, t=None):
def g(self, x):
return jnp.array([jnp.exp(x[1]), jnp.exp(x[1]), 0.0])

def h(self, x, t=None):
def h(self, x):
return x[2]


Expand Down
2 changes: 1 addition & 1 deletion dynax/linearize.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def input_output_linearize(
h = sys.h
A, b, c = ref.A, ref.B, ref.C
else:
h = lambda x, t=None: sys.h(x, t=t)[output]
h = lambda x, t=None: sys.h(x)[output]
A, b, c = ref.A, ref.B, ref.C[output]

Lfnh = lie_derivative(sys.f, h, reldeg)
Expand Down
10 changes: 5 additions & 5 deletions dynax/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,23 +264,23 @@ class ControlAffine(DynamicalSystem):

"""

def f(self, x, t=None):
def f(self, x):
raise NotImplementedError

def g(self, x, t=None):
def g(self, x):
raise NotImplementedError

def h(self, x, t=None):
def h(self, x):
return x

# FIXME: remove time dependence
def vector_field(self, x, u=None, t=None):
if u is None:
u = 0
return self.f(x, t) + self.g(x, t) * u
return self.f(x) + self.g(x) * u

def output(self, x, u=None, t=None):
return self.h(x, t)
return self.h(x)


class SeriesSystem(DynamicalSystem):
Expand Down
6 changes: 3 additions & 3 deletions examples/fit_multiple_shooting_second_order_sys.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,16 @@ class NonlinearDrag(ControlAffine):
n_inputs = 1

# Define the dynamical system via the methods f, g, and h
def f(self, x, u=None, t=None):
def f(self, x):
x1, x2 = x
return jnp.array(
[x2, (-self.r * x2 - self.r2 * jnp.abs(x2) * x2 - self.k * x1) / self.m]
)

def g(self, x, u=None, t=None):
def g(self, x):
return jnp.array([0.0, 1.0 / self.m])

def h(self, x, u=None, t=None):
def h(self, x):
return x[0]


Expand Down
13 changes: 7 additions & 6 deletions examples/fit_second_order_sys.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,18 @@ class NonlinearDrag(ControlAffine):
# Set the number of states (order of system), the number of in- and outputs.
n_states = 2
n_inputs = 1
n_outputs = 1

# Define the dynamical system via the methods f, g, and h
def f(self, x, u=None, t=None):
def f(self, x):
x1, x2 = x
return jnp.array(
[x2, (-self.r * x2 - self.r2 * jnp.abs(x2) * x2 - self.k * x1) / self.m]
)

def g(self, x, u=None, t=None):
def g(self, x):
return jnp.array([0.0, 1.0 / self.m])

def h(self, x, u=None, t=None):
def h(self, x):
return x[0]


Expand All @@ -68,15 +67,17 @@ def h(self, x, u=None, t=None):
# If we have long-duration, wide-band input data we can fit the linear
# parameters by matching the transfer-functions. In this example the result is
# not very good.
initial_sys = fit_csd_matching(initial_sys, u_train, y_train, samplerate, nperseg=100).x
initial_sys = fit_csd_matching(
initial_sys, u_train, y_train, samplerate, nperseg=100
).sys
print("linear params fitted:", initial_sys)

# Combine the ODE with an ODE solver
init_model = Flow(initial_sys)
# Fit all parameters with previously estimated parameters as a starting guess.
pred_model = fit_least_squares(
model=init_model, t=t_train, y=y_train, x0=initial_x, u=u_train, verbose=0
).x
).model
print("fitted system:", pred_model.system)

# check the results
Expand Down
6 changes: 3 additions & 3 deletions tests/test_linearize.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ class SpringMassDamperWithOutput(ControlAffine):
n_states = 2
n_inputs = 1

def f(self, x, t=None):
def f(self, x):
x1, x2 = x
return jnp.array([x2, (-self.r * x2 - self.k * x1) / self.m])

def g(self, x, t=None):
def g(self, x):
return jnp.array([0, 1 / self.m])

def h(self, x, t=None):
def h(self, x):
return x[np.array(self.out)]


Expand Down
Loading