Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Estimators return OpimizationResult object #8

Merged
merged 7 commits into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ repos:
hooks:
- id: nbqa-black
- id: nbqa-ruff
args: [--fix, --exit-non-zero-on-fix]
278 changes: 201 additions & 77 deletions dynax/estimation.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
"""Functions for estimating parameters of dynamical systems."""

import warnings
from dataclasses import fields
from typing import Callable, Optional, TypeVar, Union
from typing import Any, Callable, Literal, Optional, Union

import diffrax as dfx
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import scipy.signal as sig
from jax import Array
from jax.flatten_util import ravel_pytree
from jaxtyping import Array
from numpy.typing import ArrayLike, NDArray
from scipy.optimize import least_squares
from jax.typing import ArrayLike
from numpy.typing import NDArray
from scipy.linalg import svd
from scipy.optimize import least_squares, OptimizeResult as _OptimizeResult
from scipy.optimize._optimize import MemoizeJac

from .evolution import AbstractEvolution
from .system import DynamicalSystem
from .util import value_and_jacfwd
from .util import mse, nmse, nrmse, value_and_jacfwd


def _get_bounds(module: eqx.Module) -> tuple[list, list]:
Expand Down Expand Up @@ -47,88 +51,222 @@ def _get_bounds(module: eqx.Module) -> tuple[list, list]:
return lower_bounds, upper_bounds


Evolution = TypeVar("Evolution", bound=AbstractEvolution)
def _key_paths(tree: Any, root: str = "tree") -> list[str]:
"""List key_paths to free fields of pytree including elements of JAX arrays."""
f = lambda l: l.tolist() if isinstance(l, jax.Array) else l
flattened, _ = jtu.tree_flatten_with_path(jtu.tree_map(f, tree))
return [f"{root}{jtu.keystr(kp)}" for kp, _ in flattened]


class OptimizeResult(_OptimizeResult):
"""Represents the optimization result.

Attributes
----------
x : Evolution
The solution of the optimization.
success : bool
Whether or not the optimizer exited successfully.
status : int
Termination status of the optimizer. Its value depends on the
underlying solver. Refer to `message` for details.
message : str
Description of the cause of the termination.
fun, jac, hess: ndarray
Values of objective function, its Jacobian and its Hessian (if
available). The Hessians may be approximations, see the documentation
of the function in question.
pcov: ndarray
Estimate of the covariance matrix.
hess_inv : object
Inverse of the objective function's Hessian; may be an approximation.
Not available for all solvers. The type of this attribute may be
either np.ndarray or scipy.sparse.linalg.LinearOperator.
key_paths: List of key_paths for x that index the corresponding entries in `pcov`,
`jac`, `hess` and `hess_inv`.
nfev, njev, nhev : int
Number of evaluations of the objective functions and of its
Jacobian and Hessian.
nit : int
Number of iterations performed by the optimizer.
maxcv : float
The maximum constraint violation.
"""


def fit_least_squares(
model: Evolution,
t: Array,
y: Array,
x0: Array,
u: Optional[Array] = None,
model: AbstractEvolution,
t: ArrayLike,
y: ArrayLike,
x0: ArrayLike,
u: Optional[ArrayLike] = None,
batched: bool = False,
sigma: Optional[ArrayLike] = None,
absolute_sigma: bool = False,
reg_val: float = 0,
reg_bias: Optional[Literal["initial"]] = None,
reg_weight: Optional[Literal["inv_initial"]] = None,
**kwargs,
) -> Evolution:
"""Fit forward model with nonlinear least-squares.
) -> OptimizeResult:
"""Fit forward model with (regularized) nonlinear least-squares.

Parameters can be constrained via the `*_field` functions.

Args:
model: the forward model to fit
t: times at which `y` is given
y: target outputs of system
x0: initial state
u: optional system input
model: Forward model holding initial parameter estimates
t: Times at which `y` is given
y: Target outputs of system
x0: Initial state
u: Pptional system input
batched: If True, interpret `t`, `y`, `x0`, `u` as holding multiple
experiments stacked along the first axis.
sigma: A 1-D sequence with values of the standard deviation of the measurement
error for each output of `model.system`. If None, `sigma` will be set to
the rms values of each measurement in `y`, which makes the cost
scale-invariant to magnitude differences between measurements.
absolute_sigma: If True, `sigma` is used in an absolute sense and the estimated
parameter covariance `pcov` reflects these absolute values. If False
(default), only the relative magnitudes of the `sigma` values matter and
`sigma` is scaled to match the sample variance of the residuals after the
fit.
reg_val: Weight of the l2 penalty term
reg_bias: If "initial", bias the parameter estimates towards the values in
`model`.
reg_weight: If "inv_initial", weight each penalty term with the inverse
of the initial values. Applies no weightings to parameters with zero initial
value.
kwargs: optional parameters for `scipy.optimize.least_squares`

Returns:
A copy of `model` with the fitted parameters.
`OptimizeResult` as returned by `scipy.optimize.least_squares` with the
following fields defined:

x: `model` with estimated parameters
cov: Covariance matrix of the parameter estimate
key_paths: List of key_paths that index the corresponding entries in `cov`
and `jac`

"""
t = jnp.asarray(t)
y = jnp.asarray(y)
x0 = jnp.asarray(x0)

if batched:
# First axis holds experiments, second axis holds time.
std_y = np.std(y, axis=1, keepdims=True)
calc_coeffs = jax.vmap(dfx.backward_hermite_coefficients)
else:
# First axis holds time.
std_y = np.std(y, axis=0, keepdims=True)
calc_coeffs = dfx.backward_hermite_coefficients

ucoeffs = None
if u is not None:
ucoeffs = calc_coeffs(t, jnp.asarray(u))
if sigma is None:
weight = 1 / std_y
else:
sigma = np.asarray(sigma)
weight = 1 / sigma

# NOTE: least_squares wrapper also implemented at `jaxopt.ScipyLeastSquares`
if u is not None:
u = jnp.asarray(u)
ucoeffs = calc_coeffs(t, u)
else:
ucoeffs = None

# use ravel instead of flatten as we also want to flatten all ndarrays
init_params, unravel = ravel_pytree(model)
bounds = _get_bounds(model)

param_bias = 0
if reg_bias == "initial":
param_bias = init_params

if reg_val != 0 and reg_weight == "initial":
one_or_init_param = np.where(np.asarray(init_params) != 0, init_params, 1)
reg_val = reg_val / one_or_init_param

is_regularized = np.any(reg_val != 0)
res_size = y.size
if is_regularized:
res_size = y.size + init_params.size

def residuals(params):
model = unravel(params)
if batched:
model = jax.vmap(model)
_, pred_y = model(x0, t=t, ucoeffs=ucoeffs)
res = ((y - pred_y) / std_y).reshape(-1)
return res / np.sqrt(len(res))

# compute primal and sensitivties in one forward pass
res = (y - pred_y) * weight
res = res.reshape(-1)
if is_regularized:
res = jnp.concatenate((res, reg_val * (params - param_bias)))
# Scale cost to mean squared error (mse) for interpretable verbose output.
res = res * np.sqrt(2 / res_size)
return res

# Compute primal and sensitivties in one forward pass.
fun = MemoizeJac(jax.jit(lambda x: value_and_jacfwd(residuals, x)))
jac = fun.derivative
res = least_squares(
fun, init_params, bounds=bounds, jac=jac, x_scale="jac", **kwargs
)
params = res.x
return unravel(params)

# Unscale mse to Least-Squares cost.
res.fun = res.fun[: y.size].reshape(y.shape) / np.sqrt(2 / res_size) / weight
res.jac = res.jac * np.sqrt(res_size / 2)
res.cost = res.cost * res_size / 2

# Compute normalized root-mean-squared error
y_pred = y - res.fun
res.mse = np.atleast_1d(mse(y, y_pred))
res.nmse = np.atleast_1d(nmse(y, y_pred))
res.nrmse = np.atleast_1d(nrmse(y, y_pred))

# Compute covariance matrix.
# pcov = H^{-1} ~= inv(J^T J). Use regularized inverse.
_, s, VT = svd(res.jac, full_matrices=False)
threshold = np.finfo(float).eps * max(res.jac.shape) * s[0]
s = s[s > threshold]
VT = VT[: s.size]
pcov = np.dot(VT.T / s**2, VT)

warn_cov = False
if not absolute_sigma:
if res_size > res.x.size:
s_sq = res.cost / (res_size - res.x.size)
pcov = pcov * s_sq
else:
warn_cov = True

if np.isnan(pcov).any():
warn_cov = True

if warn_cov:
pcov.fill(np.inf)
warnings.warn(
"Covariance of the parameters could not be estimated", stacklevel=2
)

res.x = unravel(res.x)
res.pcov = pcov
res.key_paths = _key_paths(model, root=model.__class__.__name__)

return res

def _moving_window(a: jnp.ndarray, size: int, stride: int):

def _moving_window(a: Array, size: int, stride: int):
start_idx = jnp.arange(0, len(a) - size + 1, stride)[:, None]
inner_idx = jnp.arange(size)[None, :]
return a[start_idx + inner_idx]


def fit_multiple_shooting(
model: Evolution,
t: Array,
y: Array,
x0: Array,
u: Optional[Union[Callable[[float], Array], Array]] = None,
model: AbstractEvolution,
t: ArrayLike,
y: ArrayLike,
x0: ArrayLike,
u: Optional[Union[Callable[[float], Array], ArrayLike]] = None,
num_shots: int = 1,
continuity_penalty: float = 0.0,
**kwargs,
) -> Union[
tuple[Evolution, NDArray, NDArray, NDArray],
tuple[Evolution, NDArray, NDArray, NDArray, NDArray],
]:
) -> OptimizeResult:
"""Fit forward model with multiple shooting and nonlinear least-squares.

Args:
Expand Down Expand Up @@ -177,7 +315,7 @@ def fit_multiple_shooting(
t = t[:num_samples]
y = y[:num_samples]

# FIXME: use numpy for everything that is not jitted
# TODO: use numpy for everything that is not jitted
# Divide signals into segments.
ts = _moving_window(t, num_samples_per_segment, num_samples_per_segment - 1)
ys = _moving_window(y, num_samples_per_segment, num_samples_per_segment - 1)
Expand Down Expand Up @@ -223,22 +361,7 @@ def unpack(flat, unravel, x0s_shape):

def residuals(params):
x0s, model = unpack(params, treedef, x0s_shape)
# TODO: for dfx.NoAdjoint using diffrax<v0.3 computing jacobians through
# vmap is very slow.
# pmap needs exact number of devices:
# xs_pred, ys_pred = jax.pmap(model)(x0s, t=ts0, ucoeffs=ucoeffs)
# vmap is slow:
xs_pred, ys_pred = jax.vmap(model)(x0s, t=ts0, ucoeffs=ucoeffs)
# xmap needs axies names and seems complicated:
# in_axes = [['shots', ...], ['shots', ...], ['shots', ...]]
# out_axes = ['shots', ...]
# m = lambda x, t, u: model(x, t=t, ucoeffs=u)
# xs_pred, ys_pred = xmap(m, in_axes=in_axes, out_axes=[...])(x0s, ts0, ucoeffs)
# just use serial map:
# m = lambda x, t, u: model(x, t=t, ucoeffs=u)
# xs_pred, ys_pred = zip(*list(map(m, x0s, ts0, ucoeffs)))
# xs_pred = jnp.stack(xs_pred)
# ys_pred = jnp.stack(ys_pred)
# output residual
res_y = ((ys - ys_pred) / std_y).reshape(-1)
res_y = res_y / np.sqrt(len(res_y))
Expand All @@ -254,20 +377,19 @@ def residuals(params):
res = least_squares(
fun, init_params, bounds=bounds, jac=jac, x_scale="jac", **kwargs
)
x0s, model = unpack(res.x, treedef, x0s_shape)

x0s = np.asarray(x0s)
ts = np.asarray(ts)
ts0 = np.asarray(ts0)
x0s, res.x = unpack(res.x, treedef, x0s_shape)
res.x0s = np.asarray(x0s)
res.ts = np.asarray(ts)
res.ts0 = np.asarray(ts0)

if u is None:
return model, x0s, ts, ts0
else:
us = np.asarray(us)
return model, x0s, ts, ts0, us
if u is not None:
res.us = np.asarray(us)

return res

def transfer_function(sys: DynamicalSystem, to_states=False, **kwargs):

def transfer_function(sys: DynamicalSystem, to_states: bool = False, **kwargs):
"""Compute transfer-function of linearized system."""
linsys = sys.linearize(**kwargs)
A, B, C, D = linsys.A, linsys.B, linsys.C, linsys.D
Expand All @@ -286,7 +408,9 @@ def H(s):
def estimate_spectra(
u: ArrayLike, y: ArrayLike, sr: int, nperseg: int
) -> tuple[NDArray, NDArray, NDArray]:
"""Estimate cross and autopectral densities."""
"""Estimate cross and autospectral densities."""
u = np.asarray(u)
y = np.asarray(y)
if u.ndim == 1:
u = u[:, None]
if y.ndim == 1:
Expand All @@ -297,10 +421,14 @@ def estimate_spectra(


def fit_csd_matching(
sys: DynamicalSystem, u, y, sr, nperseg=1024, reg=0, ret_Syx=False, **kwargs
) -> Union[
DynamicalSystem, tuple[DynamicalSystem, tuple[NDArray, NDArray, NDArray, NDArray]]
]:
sys: DynamicalSystem,
u: ArrayLike,
y: ArrayLike,
sr: int,
nperseg: int = 1024,
reg: float = 0,
**kwargs,
) -> OptimizeResult:
"""Estimate parameters of linearized system by matching cross-spectral densities."""
f, S_yu, S_uu = estimate_spectra(u, y, sr, nperseg)
s = 2 * np.pi * f * 1j
Expand All @@ -326,9 +454,5 @@ def residuals(params):
fun = MemoizeJac(jax.jit(lambda x: value_and_jacfwd(residuals, x)))
jac = fun.derivative
res = least_squares(fun, x0, jac=jac, x_scale="jac", bounds=bounds, **kwargs)
fitted_sys = unravel(res.x)
if ret_Syx:
H = transfer_function(fitted_sys)
hatS_yu = jax.vmap(H)(s) * S_uu
return fitted_sys, (f, hatS_yu, S_yu, S_uu)
return fitted_sys
res.x = unravel(res.x)
return res
Loading
Loading