Skip to content

Commit

Permalink
WIP: all pytrees, let's try...
Browse files Browse the repository at this point in the history
  • Loading branch information
fhchl committed Oct 1, 2023
1 parent 1ab8327 commit 0349407
Show file tree
Hide file tree
Showing 11 changed files with 155 additions and 127 deletions.
79 changes: 49 additions & 30 deletions dynax/estimation.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Functions for estimating parameters of dynamical systems."""

import warnings
from collections.abc import Callable, Sequence
from dataclasses import fields
from typing import Any, Callable, Literal, Optional, Union
from functools import partial
from typing import Any, Literal, Optional, Union

import diffrax as dfx
import equinox as eqx
Expand All @@ -13,15 +15,19 @@
import scipy.signal as sig
from jax import Array
from jax.flatten_util import ravel_pytree
from jax.typing import ArrayLike
from jax.tree_util import tree_map
from jaxtyping import ArrayLike, PyTree
from numpy.typing import NDArray
from scipy.linalg import pinvh
from scipy.optimize import least_squares, OptimizeResult
from scipy.optimize._optimize import MemoizeJac

from .evolution import AbstractEvolution
from .system import DynamicalSystem
from .util import mse, nmse, nrmse, value_and_jacfwd
from .util import mse, nmse, nrmse, value_and_jacfwd, tree_stack, tree_unstack


NDArrayLike = Union[Array, np.ndarray]


def _get_bounds(module: eqx.Module) -> tuple[list, list]:
Expand Down Expand Up @@ -51,7 +57,7 @@ def _get_bounds(module: eqx.Module) -> tuple[list, list]:
return lower_bounds, upper_bounds


def _key_paths(tree: Any, root: str = "tree") -> list[str]:
def _key_paths(tree: PyTree, root: str = "tree") -> list[str]:
"""List key_paths to free fields of pytree including elements of JAX arrays."""
f = lambda l: l.tolist() if isinstance(l, jax.Array) else l
flattened, _ = jtu.tree_flatten_with_path(jtu.tree_map(f, tree))
Expand Down Expand Up @@ -151,16 +157,17 @@ def f(params):
# x0 = (x, y)
# x = [np.array([x1, x2, x3, ...]), np.array([y1, y2, y3, ....])]


def fit_least_squares(
model: AbstractEvolution,
t: ArrayLike,
y: ArrayLike,
x0: ArrayLike,
u: Optional[ArrayLike] = None,
t: NDArrayLike,
y: NDArrayLike | list[NDArrayLike],
x0: PyTree | list[PyTree],
u: Optional[PyTree | list[PyTree]] = None,
batched: bool = False,
sigma: Optional[ArrayLike] = None,
sigma: Optional[NDArrayLike] = None,
absolute_sigma: bool = False,
reg_val: float = 0,
reg_val: float = 0.0,
reg_bias: Optional[Literal["initial"]] = None,
verbose_mse: bool = True,
**kwargs,
Expand All @@ -172,15 +179,20 @@ def fit_least_squares(
Args:
model: Flow instance holding initial parameter estimates
t: Times at which `y` is given
y: Target outputs of system
y: Target outputs of system of shape (times_size, output_size)
x0: Initial state
u: Optional system input
batched: If True, interpret `t`, `y`, `x0`, `u` as holding multiple
experiments stacked along the first axis.
u: System input
batched: If True, interpret `y`, `x0`, `u` as holding multiple experiments of
equal length.
If all three arguments are arrays, the experiments should be stacked along
their first axis and the model's `vector_field` should expect and return
`jax.Array`s. If it expects `PyTree`s, `y`, `x0` and `u` should
instead be lists of equal length holding the data for each
experiment.
sigma: A 1-D sequence with values of the standard deviation of the measurement
error for each output of `model.system`. If None, `sigma` will be set to
the rms values of each measurement in `y`, which makes the cost
scale-invariant to magnitude differences between measurements.
the rms values of each measurement making the cost scale-invariant to
magnitude differences between measurements.
absolute_sigma: If True, `sigma` is used in an absolute sense and the estimated
parameter covariance `pcov` reflects these absolute values. If False
(default), only the relative magnitudes of the `sigma` values matter and
Expand All @@ -206,15 +218,23 @@ def fit_least_squares(
nrmse: Normalized root-mean-squared-error.
"""
t = jnp.asarray(t)
y = jnp.asarray(y)
t_ = jnp.asarray(t)
y_ = jnp.stack(y, axis=0)

if batched:
# First axis holds experiments, second axis holds time.
std_y = np.std(y, axis=1, keepdims=True)
calc_coeffs = jax.vmap(dfx.backward_hermite_coefficients)
# Multiple experiments are expected as lists or stacked arrays.
batched_inputs = [y, x0] if u is None else [y, x0, u]
is_lists = all(isinstance(o, (list, tuple)) for o in batched_inputs)
is_arrays = all(isinstance(o, (jax.Array, np.ndarray)) for o in batched_inputs)
if not (is_lists or is_arrays):
raise TypeError("For batched inputs, y, x0 (and u) must have same type")
std_y = np.std(y_, axis=1, keepdims=True)
if is_lists:
pass
calc_coeffs = jax.vmap(dfx.backward_hermite_coefficients, in_axes=(None, 0))
else:
# First axis holds time.
if not isinstance(y, NDArrayLike):
raise TypeError("If batched is False, `y` must be an array")
std_y = np.std(y, axis=0, keepdims=True)
calc_coeffs = dfx.backward_hermite_coefficients

Expand All @@ -225,8 +245,7 @@ def fit_least_squares(
weight = 1 / sigma

if u is not None:
u = jnp.asarray(u)
ucoeffs = calc_coeffs(t, u)
ucoeffs = calc_coeffs(t_, u)
else:
ucoeffs = None

Expand All @@ -250,8 +269,8 @@ def residual_term(params):
if batched:
model = jax.vmap(model)
# FIXME: ucoeffs not supported for Map
_, pred_y = model(x0, t=t, ucoeffs=ucoeffs)
res = (y - pred_y) * weight
_, pred_y = model(x0, t=t_, ucoeffs=ucoeffs)
res = (y_ - pred_y) * weight
return res.reshape(-1)

res = _least_squares(
Expand All @@ -265,11 +284,11 @@ def residual_term(params):

res.result = unravel(res.x)
res.pcov = _compute_covariance(res.jac, res.cost, absolute_sigma, cov_prior)
res.y_pred = y - res.fun.reshape(y.shape) / weight
res.y_pred = y_ - res.fun.reshape(y_.shape) / weight
res.key_paths = _key_paths(model, root=model.__class__.__name__)
res.mse = np.atleast_1d(mse(y, res.y_pred))
res.nmse = np.atleast_1d(nmse(y, res.y_pred))
res.nrmse = np.atleast_1d(nrmse(y, res.y_pred))
res.mse = np.atleast_1d(mse(y_, res.y_pred))
res.nmse = np.atleast_1d(nmse(y_, res.y_pred))
res.nrmse = np.atleast_1d(nrmse(y_, res.y_pred))

return res

Expand Down
33 changes: 15 additions & 18 deletions dynax/evolution.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Callable, Optional
from collections.abc import Callable
from typing import Optional

import diffrax as dfx
import equinox as eqx
Expand All @@ -11,13 +12,6 @@
from .system import DynamicalSystem


try:
# TODO: remove when upgrading to diffrax > v0.2
DefaultAdjoint = dfx.NoAdjoint
except AttributeError:
DefaultAdjoint = dfx.DirectAdjoint


class AbstractEvolution(eqx.Module):
"""Abstract base-class for evolutions."""

Expand All @@ -35,18 +29,21 @@ class Flow(AbstractEvolution):
)
dt0: Optional[float] = eqx.static_field(default=None)


# FIXME(pytree): I just changed the call order, so this has to be fixed EVERYWHERE!
def __call__(
self,
x0: PyTree,
t: ArrayLike,
u: Optional[ArrayLike] = None,
ufun: Optional[Callable[[float], float]] = None,
ucoeffs: Optional[ArrayLike] = None,
x0: Optional[PyTree] = None,
u: Optional[PyTree] = None,
ufun: Optional[Callable[[float], PyTree]] = None,
ucoeffs: Optional[tuple[PyTree, PyTree, PyTree, PyTree]] = None,
**diffeqsolve_kwargs,
) -> tuple[Array, Array]:
) -> tuple[PyTree, PyTree]:
"""Solve initial value problem for state and output trajectories."""
t_ = jnp.asarray(t)
# TODO(pytree): check that x0 has the right shape?
if x0 is None and self.system.x0 is None:
raise ValueError("One of x0 or system.x0 must be not None")
if u is None and ufun is None and ucoeffs is None:
_ufun = lambda t: None
elif ucoeffs is not None:
Expand All @@ -55,10 +52,10 @@ def __call__(
elif callable(u):
_ufun = u
elif u is not None:
u_ = jnp.asarray(u)
msg = "t and u must have matching first dimensions"
u = jnp.asarray(u)
assert len(t_) == u.shape[0], msg
_ufun = spline_it(t_, u)
assert len(t_) == u_.shape[0], msg
_ufun = spline_it(t_, u_)
else:
raise ValueError("Must specify one of u, ufun, ucoeffs.")

Expand All @@ -68,7 +65,7 @@ def __call__(
stepsize_controller=self.step,
saveat=dfx.SaveAt(ts=t_),
max_steps=50 * len(t_),
adjoint=DefaultAdjoint(),
adjoint=dfx.DirectAdjoint(),
dt0=self.dt0 if self.dt0 is not None else t_[1],
)
diffeqsolve_default_options |= diffeqsolve_kwargs
Expand Down
2 changes: 1 addition & 1 deletion dynax/example_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def g(self, x):
return (0.0, 1.0 / self.m)

def h(self, x):
return tuple(x[i] for i in self.outputs)
return jnp.array([x[i] for i in self.outputs])


class Sastry9_9(ControlAffine):
Expand Down
11 changes: 5 additions & 6 deletions dynax/interpolation.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
import diffrax as dfx
import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array, PyTree


class InterpolationFunction(eqx.Module):
"""Interpolating cubic-spline function."""

path: dfx.CubicInterpolation

def __init__(self, ts, us):
ts = jnp.asarray(ts)
us = jnp.asarray(us)
assert len(ts) == us.shape[0], "time and input must have same number of samples"
coeffs = dfx.backward_hermite_coefficients(ts, us)
self.path = dfx.CubicInterpolation(ts, coeffs)
def __init__(self, ts: Array, us: PyTree):
ts_ = jnp.asarray(ts)
coeffs = dfx.backward_hermite_coefficients(ts_, us)
self.path = dfx.CubicInterpolation(ts_, coeffs)

def __call__(self, t):
return self.path.evaluate(t)
Expand Down
4 changes: 2 additions & 2 deletions dynax/linearize.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Functions related to feedback linearization of nonlinear systems."""

from collections.abc import Callable
from typing import Optional, Sequence
from collections.abc import Callable, Sequence
from typing import Optional

import jax
import jax.numpy as jnp
Expand Down
Loading

0 comments on commit 0349407

Please sign in to comment.