Skip to content

Commit

Permalink
test_estimation.py passes
Browse files Browse the repository at this point in the history
  • Loading branch information
fhchl committed Oct 8, 2023
1 parent 3ac7240 commit 7fd1985
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 62 deletions.
136 changes: 78 additions & 58 deletions dynax/estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@
import jax.tree_util as jtu
import numpy as np
import scipy.signal as sig
from jax import Array
from jax.flatten_util import ravel_pytree
from jax.tree_util import tree_map
from jaxtyping import ArrayLike, PyTree
from jaxtyping import ArrayLike, PyTree, Array, Float
from numpy.typing import NDArray
from scipy.linalg import pinvh
from scipy.optimize import least_squares, OptimizeResult
Expand Down Expand Up @@ -305,19 +304,23 @@ def residual_term(params):

return res

from typing import TypeVar

def _moving_window(a: Array, size: int, stride: int):
T = TypeVar("T")


def _moving_window(a: T, size: int, stride: int) -> T:
    """Gather windows of `size` consecutive entries along the leading axis.

    Window starts are `0, stride, 2 * stride, ...` for as long as a full
    window still fits into `len(a)`. A trailing remainder that cannot fill
    a complete window is dropped.

    Args:
        a: Indexable array-like; it is indexed with an integer index array.
        size: Number of consecutive entries per window.
        stride: Distance between the starts of neighbouring windows.

    Returns:
        The result of indexing `a` with a `(num_windows, size)` index grid,
        i.e. one window per entry along the new leading axis.
    """
    first_invalid_start = len(a) - size + 1
    window_starts = jnp.arange(0, first_invalid_start, stride)
    offsets_in_window = jnp.arange(size)
    # A column of start indices broadcast against a row of offsets gives the
    # absolute index of every element of every window.
    index_grid = window_starts[:, None] + offsets_in_window[None, :]
    return a[index_grid]

TimeSignal = Float[ArrayLike, "times ..."]

def fit_multiple_shooting(
model: AbstractEvolution,
t: ArrayLike,
y: ArrayLike,
x0: ArrayLike,
u: Optional[Union[Callable[[float], Array], ArrayLike]] = None,
t: TimeSignal,
y: PyTree[TimeSignal], # noqa: F821
x0: PyTree,
u: Optional[PyTree[TimeSignal]] = None,
num_shots: int = 1,
continuity_penalty: float = 0.1,
**kwargs,
Expand All @@ -340,55 +343,61 @@ def fit_multiple_shooting(
`model` is the model with fitted parameters and `x0s`, `ts`, `us` are
the arrays of initial states, times, and inputs for each
shot. Else, return only `(model, x0s, ts, us)`.
"""
t = jnp.asarray(t)
y = jnp.asarray(y)
x0 = jnp.asarray(x0)
"""
# Check that all arguments have the same time size
if u is None:
msg = (
f"t, y must have same number of samples, but have shapes "
f"{t.shape}, {y.shape}"
)
assert t.shape[0] == y.shape[0], msg
ins = (t, y)
else:
u = jnp.asarray(u)
msg = (
f"t, y, u must have same number of samples, but have shapes "
f"{t.shape}, {y.shape} and {u.shape}"
ins = (t, y, u)
time_size = len(t)
is_right_size = lambda a: jnp.size(a, 0) == time_size
if not all(map(is_right_size, jtu.tree_flatten(ins)[0])):
raise ValueError("Inputs must be of same length.")

# Check that output is defined
x_shape, y_shape = jax.eval_shape(model, x0, t, u)
if y_shape is None:
raise ValueError(
"`model.system.output` seems to return `None`. "
"Did you forget to define the output method?"
)
assert t.shape[0] == y.shape[0] == u.shape[0], msg

# Compute number of samples per segment. Remove samples at end if total
# number is not divisible by num_shots.
num_samples = len(t)
num_samples_per_segment = int(np.floor((num_samples + (num_shots - 1)) / num_shots))
leftover_samples = num_samples - (num_samples_per_segment * num_shots)
if leftover_samples:
print("Warning: removing last ", leftover_samples, "samples.")
num_samples -= leftover_samples
t = t[:num_samples]
y = y[:num_samples]

# TODO: use numpy for everything that is not jitted
# Divide signals into segments.
ts = _moving_window(t, num_samples_per_segment, num_samples_per_segment - 1)
ys = _moving_window(y, num_samples_per_segment, num_samples_per_segment - 1)
x0s = np.zeros((num_shots - 1, len(x0)))

ucoeffs = None

# Make sure the number of samples is divisible by num_shots. This has to be
# a remainder test: a bitwise AND (`&`) only compares binary digits and
# says nothing about divisibility (e.g. 1000 & 3 == 0 but 1000 % 3 == 1).
if time_size % num_shots != 0:
    # TODO: zeropad and mask or remove samples otherwise.
    raise ValueError("Can't cleanly devide")
samples_per_segment = time_size // num_shots

# Divide signals into segments
window_with_single_overlap = lambda x: _moving_window(
x, samples_per_segment, samples_per_segment - 1
)
ts = jtu.tree_map(window_with_single_overlap, t)
ys = jtu.tree_map(window_with_single_overlap, y)
if u is not None:
us = u[:num_samples]
us = _moving_window(us, num_samples_per_segment, num_samples_per_segment - 1)
compute_coeffs = lambda t, u: jnp.stack(dfx.backward_hermite_coefficients(t, u))
ucoeffs = jax.vmap(compute_coeffs)(ts, us)
us = jtu.tree_map(window_with_single_overlap, u)
ucoeffs = jax.vmap(dfx.backward_hermite_coefficients)(ts, us)
else:
ucoeffs = us = None

# x0s for all segments but the first
zeros_like_repeated_along_first_axis = lambda a: jnp.tile(
jnp.zeros_like(a), (num_shots - 1,) + (1,)*jnp.ndim(a)
)
x0s = jtu.tree_map(zeros_like_repeated_along_first_axis, x0)

# Each segment's time starts at 0.
ts0 = ts - ts[:, :1]

# Residuals are weighted by standard deviation
std_y = tree_map(partial(np.std, axis=0), y)
std_ys = tree_map(lambda std, y: std, std_y, ys)

# prepare optimization
init_params, unravel = ravel_pytree((x0s, model))
std_y = np.std(y, axis=0)
parameter_bounds = _get_bounds(model)
state_bounds = (
(num_shots - 1) * len(x0) * [-np.inf],
Expand All @@ -399,28 +408,39 @@ def fit_multiple_shooting(
state_bounds[1] + parameter_bounds[1],
)

def residuals(params):
prepend = lambda x, xs: jnp.concatenate((jnp.asarray([x]), xs))

def residuals(params: Array) -> Array:
x0s, model = unravel(params)
x0s = jnp.concatenate((x0[None], x0s), axis=0)
# Prepend known initial state
x0s = jtu.tree_map(prepend, x0, x0s)
# Make prediction
xs_pred, ys_pred = jax.vmap(model)(x0s, t=ts0, ucoeffs=ucoeffs)
# output residual
res_y = ((ys - ys_pred) / std_y).reshape(-1)
res_y = res_y / np.sqrt(len(res_y))
# continuity residual
std_x = jnp.std(xs_pred, axis=(0, 1))
res_x0 = ((x0s[1:] - xs_pred[:-1, -1]) / std_x).reshape(-1)
res_x0 = res_x0 / np.sqrt(len(res_x0))
return jnp.concatenate((res_y, continuity_penalty * res_x0))
# Output residual
res_y = ((ys**ω - ys_pred**ω) / std_ys**ω).ω
# Continuity residual
std_along_shots_and_time = partial(jnp.std, axis=(0, 1), keepdims=True)
std_x = jtu.tree_map(std_along_shots_and_time, xs_pred)
normalized_overlap_error = lambda x0, xs, norm: (x0[1:] - xs[:-1, -1]) * norm
res_x0 = jtu.tree_map(
normalized_overlap_error,
x0s,
xs_pred,
(continuity_penalty / std_x**ω).ω
)
return jnp.concatenate((ravel_pytree(res_y)[0], ravel_pytree(res_x0)[0]))

res = _least_squares(residuals, init_params, bounds, x_scale=False, **kwargs)

x0s, res.result = unravel(res.x)
res.x0s = np.asarray(jnp.concatenate((x0[None], x0s), axis=0))
res.x0s = jtu.tree_map(prepend, x0, x0s)

# TODO: cast everything to np.ndarray?
res.ts = np.asarray(ts)
res.ts0 = np.asarray(ts0)

if u is not None:
res.us = np.asarray(us)
res.us = jtu.tree_map(np.asarray, us)

return res

Expand All @@ -433,7 +453,7 @@ def transfer_function(sys: DynamicalSystem, to_states: bool = False, **kwargs):
# x0, u0 and t0 are supplied?
A, B, C, D = linsys.A, linsys.B, linsys.C, linsys.D

def H(s: complex):
def H(s: ArrayLike) -> Array:
"""Transfer-function at s."""
# TODO(pytree): Here, we are _required_ to materialize A, B, C, D, because this
# is otherwise very hard :/
Expand Down
3 changes: 3 additions & 0 deletions dynax/example_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ def vector_field(self, x, u=None, t=None):
dx2 = (u - self.r * x2 - self.k * x1) / self.m
return dx1, dx2

def output(self, x, u=None, t=None):
    """Return the state `x` unchanged as the system output.

    `u` and `t` are accepted but not used; they default to `None` so the
    signature matches `vector_field` and the method can be called with the
    state alone.
    """
    return x


class NonlinearDrag(ControlAffine):
"""Spring-mass-damper system with nonlin drag.
Expand Down
8 changes: 4 additions & 4 deletions tests/test_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def test_fit_multiple_shooting_with_input(num_shots):
# data
t = np.linspace(0, 10, 10000)
u = np.sin(1 * 2 * np.pi * t)
x0 = [1.0, 0.0]
x0 = (1.0, 0.0)
true_model = Flow(SpringMassDamper(1.0, 2.0, 3.0))
x_true, _ = true_model(x0, t, u)
# fit
Expand All @@ -174,7 +174,7 @@ def test_fit_multiple_shooting_with_input(num_shots):
x_true,
x0,
u,
continuity_penalty=1,
continuity_penalty=1.,
num_shots=num_shots,
verbose=2,
).result
Expand All @@ -192,7 +192,7 @@ def test_fit_multiple_shooting_with_input(num_shots):
def test_fit_multiple_shooting_without_input(num_shots):
# data
t = np.linspace(0, 1, 1000)
x0 = [0.5, 0.5]
x0 = (0.5, 0.5)
solver_opt = dict(step=PIDController(rtol=1e-3, atol=1e-6))
true_model = Flow(
LotkaVolterra(alpha=2 / 3, beta=4 / 3, gamma=1.0, delta=1.0), **solver_opt
Expand All @@ -203,7 +203,7 @@ def test_fit_multiple_shooting_without_input(num_shots):
LotkaVolterra(alpha=1.0, beta=1.0, gamma=1.5, delta=2.0), **solver_opt
)
pred_model = fit_multiple_shooting(
init_model, t, x_true, x0, num_shots=num_shots, continuity_penalty=1
init_model, t, x_true, x0, num_shots=num_shots, continuity_penalty=1.
).result
# check result
x_pred, _ = pred_model(x0, t)
Expand Down

0 comments on commit 7fd1985

Please sign in to comment.