simplify least squares
fhchl committed Aug 11, 2023
1 parent d3b82ba commit 1c1aa60
Showing 6 changed files with 150 additions and 128 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -3,7 +3,7 @@ repos:
rev: 'v0.0.269'
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
args: [--fix]
- repo: https://github.com/ambv/black
rev: 23.3.0
hooks:
@@ -13,4 +13,4 @@ repos:
hooks:
- id: nbqa-black
- id: nbqa-ruff
args: [--fix, --exit-non-zero-on-fix]
args: [--fix]
8 changes: 4 additions & 4 deletions dynax/__init__.py
@@ -1,6 +1,6 @@
import importlib.metadata

import jax as jax
import jax as _jax

from .derivative import lie_derivative as lie_derivative
from .estimation import (
@@ -29,13 +29,13 @@
static_field as static_field,
StaticStateFeedbackSystem as StaticStateFeedbackSystem,
)
from .util import monkeypatch_pretty_print, pretty as pretty
from .util import _monkeypatch_pretty_print, pretty as pretty


# TODO: leave out or make clear somewhere
print("Setting jax_enable_x64 to True.")
jax.config.update("jax_enable_x64", True)
_jax.config.update("jax_enable_x64", True)

monkeypatch_pretty_print()
_monkeypatch_pretty_print()

__version__ = importlib.metadata.version("dynax")
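
Note that importing the package flips JAX's global precision flag, as the TODO above points out. A minimal sketch (not part of the commit, assuming the standard `jax.config` API) of how a downstream user could revert to single precision after import:

```python
import jax

import dynax  # importing dynax sets jax_enable_x64 to True

# Revert to single precision if 32-bit floats are preferred.
jax.config.update("jax_enable_x64", False)
```
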
200 changes: 103 additions & 97 deletions dynax/estimation.py
@@ -94,6 +94,73 @@ class OptimizeResult(_OptimizeResult):
"""


def _compute_covariance(res, absolute_sigma):
"""Compute covariance matrix."""
# pcov = H^{-1} ~= inv(J^T J). Use regularized inverse.
_, s, VT = svd(res.jac, full_matrices=False)
threshold = np.finfo(float).eps * max(res.jac.shape) * s[0]
s = s[s > threshold]
VT = VT[: s.size]
pcov = np.dot(VT.T / s**2, VT)

warn_cov = False
if not absolute_sigma:
if res.fun.size > res.x.size:
s_sq = res.cost / (res.fun.size - res.x.size)
pcov = pcov * s_sq
else:
warn_cov = True

if np.isnan(pcov).any():
warn_cov = True

if warn_cov:
pcov.fill(np.inf)
warnings.warn(
"Covariance of the parameters could not be estimated", stacklevel=2
)

return pcov
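
As a quick sanity check of the SVD-based formula in `_compute_covariance` (a sketch, not part of the commit): for a full-rank, well-conditioned Jacobian the regularized pseudo-inverse reduces to `inv(J.T @ J)`.

```python
import numpy as np
from scipy.linalg import svd

# Toy Jacobian of a 2-parameter model evaluated at 3 data points.
J = np.array([[1.0, 0.0], [1.0, 1.0], [1.0, 2.0]])
_, s, VT = svd(J, full_matrices=False)
pcov = (VT.T / s**2) @ VT  # V diag(1/s^2) V^T
assert np.allclose(pcov, np.linalg.inv(J.T @ J))
```
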


def _least_squares(
f: Callable[[ArrayLike], Array],
x0: NDArray,
bounds: tuple[NDArray, NDArray],
x_scale: bool = True,
verbose_mse: bool = True,
**kwargs,
):
"""Least-squares with jit, autodiff and parameter scaling."""
if verbose_mse:
_f = f

def f(x):
res = _f(x)
return res * np.sqrt(2 / res.size)

if x_scale:
norm = np.where(np.asarray(x0) != 0, x0, 1)
x0 = x0 / norm
__f = f
f = lambda x: __f(x * norm)
else:
norm = 1

fun = MemoizeJac(jax.jit(lambda x: value_and_jacfwd(f, x)))
jac = fun.derivative
res = least_squares(fun, x0, bounds=bounds, jac=jac, x_scale="jac", **kwargs)

res.x = res.x * norm
if verbose_mse:
mse_scaling = np.sqrt(2 / res.fun.size)
res.fun = res.fun / mse_scaling
res.jac = res.jac / mse_scaling
res.cost = res.cost / mse_scaling**2

return res
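
A hypothetical usage sketch of the new helper (not part of the commit), fitting a two-parameter exponential decay; the data, bounds, and import path are made up for illustration:

```python
import jax.numpy as jnp
import numpy as np

from dynax.estimation import _least_squares  # module-private helper, path assumed

t = np.linspace(0.0, 1.0, 50)
y = 2.0 * np.exp(-3.0 * t)  # synthetic data generated with a=2, b=3

def residuals(params):
    a, b = params
    return y - a * jnp.exp(-b * t)

res = _least_squares(
    residuals,
    x0=np.array([1.0, 1.0]),
    bounds=(np.array([0.0, 0.0]), np.array([10.0, 10.0])),
)
print(res.x)  # approximately [2., 3.]
```

With the default `verbose_mse=True`, the cost reported in the optimizer's verbose output is the mean-squared error, while `res.fun`, `res.jac`, and `res.cost` are scaled back to the plain least-squares quantities before the result is returned.
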


def fit_least_squares(
model: AbstractEvolution,
t: ArrayLike,
@@ -105,7 +172,7 @@ def fit_least_squares(
absolute_sigma: bool = False,
reg_val: float = 0,
reg_bias: Optional[Literal["initial"]] = None,
reg_weight: Optional[Literal["inv_initial"]] = None,
verbose_mse: bool = True,
**kwargs,
) -> OptimizeResult:
"""Fit forward model with (regularized) nonlinear least-squares.
@@ -129,22 +196,24 @@
(default), only the relative magnitudes of the `sigma` values matter and
`sigma` is scaled to match the sample variance of the residuals after the
fit.
reg_val: Weight of the l2 penalty term
reg_val: Weight of the l2 penalty term.
reg_bias: If "initial", bias the parameter estimates towards the values in
`model`.
reg_weight: If "inv_initial", weight each penalty term with the inverse
of the initial values. Applies no weightings to parameters with zero initial
value.
kwargs: optional parameters for `scipy.optimize.least_squares`
verbose_mse: If True, scale the cost to the mean-squared error for easier
interpretation of the optimizer's verbose output.
kwargs: Optional parameters for `scipy.optimize.least_squares`.
Returns:
`OptimizeResult` as returned by `scipy.optimize.least_squares` with the
following fields defined:
x: `model` with estimated parameters
cov: Covariance matrix of the parameter estimate
key_paths: List of key_paths that index the corresponding entries in `cov`
and `jac`
model: `model` with estimated parameters.
pcov: Covariance matrix of the parameter estimate.
y_pred: Model prediction at optimum.
key_paths: List of key_paths that index the corresponding entries in `pcov`,
`jac`, and `x`.
mse: Mean-squared-error.
nmse: Normalized mean-squared-error.
nrmse: Normalized root-mean-squared-error.
"""
t = jnp.asarray(t)
@@ -179,74 +248,29 @@
if reg_bias == "initial":
param_bias = init_params

if reg_val != 0 and reg_weight == "initial":
one_or_init_param = np.where(np.asarray(init_params) != 0, init_params, 1)
reg_val = reg_val / one_or_init_param

is_regularized = np.any(reg_val != 0)
res_size = y.size
if is_regularized:
res_size = y.size + init_params.size

def residuals(params):
model = unravel(params)
if batched:
model = jax.vmap(model)
_, pred_y = model(x0, t=t, ucoeffs=ucoeffs)
res = (y - pred_y) * weight
res = res.reshape(-1)
res = ((y - pred_y) * weight).reshape(-1)
if is_regularized:
res = jnp.concatenate((res, reg_val * (params - param_bias)))
# Scale cost to mean squared error (mse) for interpretable verbose output.
res = res * np.sqrt(2 / res_size)
return res

# Compute primal and sensitivities in one forward pass.
fun = MemoizeJac(jax.jit(lambda x: value_and_jacfwd(residuals, x)))
jac = fun.derivative
res = least_squares(
fun, init_params, bounds=bounds, jac=jac, x_scale="jac", **kwargs
res = _least_squares(
residuals, init_params, bounds, verbose_mse=verbose_mse, **kwargs
)

# Unscale mse to Least-Squares cost.
res.fun = res.fun[: y.size].reshape(y.shape) / np.sqrt(2 / res_size) / weight
res.jac = res.jac * np.sqrt(res_size / 2)
res.cost = res.cost * res_size / 2

# Compute normalized root-mean-squared error
y_pred = y - res.fun
res.mse = np.atleast_1d(mse(y, y_pred))
res.nmse = np.atleast_1d(nmse(y, y_pred))
res.nrmse = np.atleast_1d(nrmse(y, y_pred))

# Compute covariance matrix.
# pcov = H^{-1} ~= inv(J^T J). Use regularized inverse.
_, s, VT = svd(res.jac, full_matrices=False)
threshold = np.finfo(float).eps * max(res.jac.shape) * s[0]
s = s[s > threshold]
VT = VT[: s.size]
pcov = np.dot(VT.T / s**2, VT)

warn_cov = False
if not absolute_sigma:
if res_size > res.x.size:
s_sq = res.cost / (res_size - res.x.size)
pcov = pcov * s_sq
else:
warn_cov = True

if np.isnan(pcov).any():
warn_cov = True

if warn_cov:
pcov.fill(np.inf)
warnings.warn(
"Covariance of the parameters could not be estimated", stacklevel=2
)

res.x = unravel(res.x)
res.pcov = pcov
res.model = unravel(res.x)
res.pcov = _compute_covariance(res, absolute_sigma)
res.y_pred = y - res.fun[: y.size].reshape(y.shape) / weight
res.key_paths = _key_paths(model, root=model.__class__.__name__)
res.mse = np.atleast_1d(mse(y, res.y_pred))
res.nmse = np.atleast_1d(nmse(y, res.y_pred))
res.nrmse = np.atleast_1d(nrmse(y, res.y_pred))

return res

@@ -264,7 +288,7 @@ def fit_multiple_shooting(
x0: ArrayLike,
u: Optional[Union[Callable[[float], Array], ArrayLike]] = None,
num_shots: int = 1,
continuity_penalty: float = 0.0,
continuity_penalty: float = 0.1,
**kwargs,
) -> OptimizeResult:
"""Fit forward model with multiple shooting and nonlinear least-squares.
@@ -319,8 +343,7 @@ def fit_multiple_shooting(
# Divide signals into segments.
ts = _moving_window(t, num_samples_per_segment, num_samples_per_segment - 1)
ys = _moving_window(y, num_samples_per_segment, num_samples_per_segment - 1)
x0s = np.broadcast_to(x0, (num_shots, len(x0))).copy()
x0s = np.concatenate((x0[None], np.zeros((num_shots - 1, len(x0)))))
x0s = np.zeros((num_shots - 1, len(x0)))

ucoeffs = None
if u is not None:
@@ -332,22 +355,8 @@
# Each segment's time starts at 0.
ts0 = ts - ts[:, :1]

def pack(x0s, model):
# remove initial condition, which is fixed and not a parameter
x0s = x0s[1:]
x0s_shape = x0s.shape
flat, unravel = ravel_pytree((x0s.flatten().tolist(), model))
return flat, unravel, x0s_shape

def unpack(flat, unravel, x0s_shape):
x0s_list, model = unravel(flat)
x0s = jnp.array(x0s_list).reshape(x0s_shape)
# add initial condition
x0s = jnp.concatenate((x0[None], x0s), axis=0)
return x0s, model

# prepare optimization
init_params, treedef, x0s_shape = pack(x0s, model)
init_params, unravel = ravel_pytree((x0s, model))
std_y = np.std(y, axis=0)
parameter_bounds = _get_bounds(model)
state_bounds = (
@@ -360,7 +369,8 @@ def unpack(flat, unravel, x0s_shape):
)

def residuals(params):
x0s, model = unpack(params, treedef, x0s_shape)
x0s, model = unravel(params)
x0s = jnp.concatenate((x0[None], x0s), axis=0)
xs_pred, ys_pred = jax.vmap(model)(x0s, t=ts0, ucoeffs=ucoeffs)
# output residual
res_y = ((ys - ys_pred) / std_y).reshape(-1)
@@ -371,15 +381,10 @@ def residuals(params):
res_x0 = res_x0 / np.sqrt(len(res_x0))
return jnp.concatenate((res_y, continuity_penalty * res_x0))

# compute primal and sensitivities in one forward pass
fun = MemoizeJac(jax.jit(lambda x: value_and_jacfwd(residuals, x)))
jac = fun.derivative
res = least_squares(
fun, init_params, bounds=bounds, jac=jac, x_scale="jac", **kwargs
)
res = _least_squares(residuals, init_params, bounds, x_scale=False, **kwargs)

x0s, res.x = unpack(res.x, treedef, x0s_shape)
res.x0s = np.asarray(x0s)
x0s, res.model = unravel(res.x)
res.x0s = np.asarray(jnp.concatenate((x0[None], x0s), axis=0))
res.ts = np.asarray(ts)
res.ts0 = np.asarray(ts0)
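
The pack/unpack helpers removed above are replaced by `jax.flatten_util.ravel_pytree`, which flattens the free segment states together with the model into a single parameter vector and returns a matching unflatten function. A minimal sketch of the mechanism, with a stand-in pytree in place of the actual model:

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

x0s = jnp.zeros((2, 3))  # free initial states of all but the first segment
params = {"a": jnp.array(1.0), "b": jnp.array([2.0, 3.0])}  # stand-in "model"

flat, unravel = ravel_pytree((x0s, params))  # one flat parameter vector
x0s_back, params_back = unravel(flat)        # round-trips to the original pytree
assert flat.shape == (9,) and x0s_back.shape == (2, 3)
```
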

@@ -427,21 +432,22 @@ def fit_csd_matching(
sr: int,
nperseg: int = 1024,
reg: float = 0,
verbose_mse: bool = True,
**kwargs,
) -> OptimizeResult:
"""Estimate parameters of linearized system by matching cross-spectral densities."""
f, S_yu, S_uu = estimate_spectra(u, y, sr, nperseg)
s = 2 * np.pi * f * 1j
weight = np.std(S_yu, axis=0) * np.sqrt(len(f))
x0, unravel = ravel_pytree(sys)
init_params, unravel = ravel_pytree(sys)

def residuals(params):
sys = unravel(params)
H = transfer_function(sys)
hatG_yx = jax.vmap(H)(s)
hatS_yu = hatG_yx * S_uu
res = (S_yu - hatS_yu) / weight
regterm = params / np.where(np.asarray(x0) != 0, x0, 1) * reg
regterm = params / np.where(np.asarray(init_params) != 0, init_params, 1) * reg
return jnp.concatenate(
(
jnp.real(res).reshape(-1),
@@ -451,8 +457,8 @@ def residuals(params):
)

bounds = _get_bounds(sys)
fun = MemoizeJac(jax.jit(lambda x: value_and_jacfwd(residuals, x)))
jac = fun.derivative
res = least_squares(fun, x0, jac=jac, x_scale="jac", bounds=bounds, **kwargs)
res.x = unravel(res.x)
res = _least_squares(
residuals, init_params, bounds, verbose_mse=verbose_mse, **kwargs
)
res.sys = unravel(res.x)
return res
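
For context on the residual above (a sketch, not part of the commit): `scipy.optimize.least_squares` operates on real-valued residual vectors, so the complex spectral misfit is stacked as its real and imaginary parts, which preserves the squared error because |r|^2 = Re(r)^2 + Im(r)^2.

```python
import jax.numpy as jnp

r = jnp.array([1.0 + 2.0j, -0.5 + 0.0j])  # toy complex residual
stacked = jnp.concatenate((jnp.real(r).reshape(-1), jnp.imag(r).reshape(-1)))
assert jnp.isclose(jnp.sum(stacked**2), jnp.sum(jnp.abs(r) ** 2))
```
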
2 changes: 1 addition & 1 deletion dynax/util.py
@@ -85,7 +85,7 @@ def nrmse(target, prediction, axis=0):
return jnp.sqrt(nmse(target, prediction, axis))


def monkeypatch_pretty_print():
def _monkeypatch_pretty_print():
from equinox._pretty_print import named_objs, bracketed, pp, dataclasses # noqa

def _pformat_dataclass(obj, **kwargs):
46 changes: 31 additions & 15 deletions examples/fit_ode.ipynb

Large diffs are not rendered by default.
