Skip to content

Commit

Permalink
Feat/estimate initial state (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
fhchl authored Feb 11, 2024
1 parent 0e0c3ad commit dc94bd8
Show file tree
Hide file tree
Showing 18 changed files with 730 additions and 481 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ experiments
.coverage
htmlcov
build
docs/source/_build
_build
docs/generated
*.pytest_cache
.pytype
Expand Down
21 changes: 17 additions & 4 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
import typing


project = "Dynax"
copyright = "2023, Franz M. Heuchel"
Expand All @@ -14,13 +16,15 @@
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

extensions = [
"sphinx.ext.napoleon",
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.todo",
"sphinx.ext.napoleon",
"sphinx.ext.viewcode",
"sphinxcontrib.bibtex",
"sphinx_autodoc_typehints",
"sphinxcontrib.aafig",
# "sphinx_autodoc_typehints",
"nbsphinx",
]

Expand All @@ -40,11 +44,20 @@
}

autoclass_content = "both"
# autodoc_typehints = "signature"
# typehints_use_signature = "True"
autodoc_typehints = "signature"
typehints_use_signature = True


napoleon_include_init_with_doc = True
napoleon_preprocess_types = True
napoleon_attr_annotations = True

# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output

html_theme = "furo"
html_theme = "sphinx_rtd_theme"
html_static_path = ["_static"]

# Short type docs for jaxtyping's types
# https://github.com/patrick-kidger/pytkdocs_tweaks/blob/2a7ce453e315f526d792f689e61d56ecaa4ab000/pytkdocs_tweaks/__init__.py#L283
typing.GENERATING_DOCUMENTATION = True # pyright: ignore
42 changes: 21 additions & 21 deletions dynax/estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import scipy.signal as sig
from jax import Array
from jax.flatten_util import ravel_pytree
from jax.typing import ArrayLike
from jaxtyping import ArrayLike
from numpy.typing import NDArray
from scipy.linalg import pinvh
from scipy.optimize import least_squares, OptimizeResult
Expand Down Expand Up @@ -92,7 +92,7 @@ def _compute_covariance(

def _least_squares(
f: Callable[[Array], Array],
x0: NDArray,
init_params: Array,
bounds: tuple[list, list],
reg_term: Optional[Callable[[Array], Array]] = None,
x_scale: bool = True,
Expand All @@ -105,7 +105,7 @@ def _least_squares(
# Add regularization term
_f = f
_reg_term = reg_term # https://github.com/python/mypy/issues/7268
f = lambda x: jnp.concatenate((_f(x), _reg_term(x)))
f = lambda params: jnp.concatenate((_f(params), _reg_term(params)))

if verbose_mse:
# Scale cost to mean-squared error
Expand All @@ -117,15 +117,17 @@ def f(params):

if x_scale:
# Scale parameters and bounds by initial values
norm = np.where(np.asarray(x0) != 0, np.abs(x0), 1)
x0 = x0 / norm
norm = np.where(np.asarray(init_params) != 0, np.abs(init_params), 1)
init_params = init_params / norm
___f = f
f = lambda x: ___f(x * norm)
f = lambda params: ___f(params * norm)
bounds = (np.array(bounds[0]) / norm, np.array(bounds[1]) / norm)

fun = MemoizeJac(jax.jit(lambda x: value_and_jacfwd(f, x)))
jac = fun.derivative
res = least_squares(fun, x0, bounds=bounds, jac=jac, x_scale="jac", **kwargs)
res = least_squares(
fun, init_params, bounds=bounds, jac=jac, x_scale="jac", **kwargs
)

if x_scale:
# Unscale parameters
Expand All @@ -139,8 +141,8 @@ def f(params):

if reg_term is not None:
# Remove regularization from residuals and Jacobian and cost
res.fun = res.fun[: -len(x0)]
res.jac = res.jac[: -len(x0)]
res.fun = res.fun[: -len(init_params)]
res.jac = res.jac[: -len(init_params)]
res.cost = np.sum(res.fun**2) / 2

return res
Expand All @@ -150,7 +152,6 @@ def fit_least_squares(
model: AbstractEvolution,
t: ArrayLike,
y: ArrayLike,
x0: ArrayLike,
u: Optional[ArrayLike] = None,
batched: bool = False,
sigma: Optional[ArrayLike] = None,
Expand All @@ -168,7 +169,6 @@ def fit_least_squares(
"""
t = jnp.asarray(t)
y = jnp.asarray(y)
x0 = jnp.asarray(x0)

if batched:
# First axis holds experiments, second axis holds time.
Expand Down Expand Up @@ -212,7 +212,7 @@ def residual_term(params):
# this can use pmap, if batch size is smaller than CPU cores
model = jax.vmap(model)
# FIXME: ucoeffs not supported for Map
_, pred_y = model(x0, t=t, ucoeffs=ucoeffs)
_, pred_y = model(t=t, ucoeffs=ucoeffs)
res = (y - pred_y) * weight
return res.reshape(-1)

Expand Down Expand Up @@ -246,7 +246,6 @@ def fit_multiple_shooting(
model: AbstractEvolution,
t: ArrayLike,
y: ArrayLike,
x0: ArrayLike,
u: Optional[Union[Callable[[float], Array], ArrayLike]] = None,
num_shots: int = 1,
continuity_penalty: float = 0.1,
Expand All @@ -273,7 +272,6 @@ def fit_multiple_shooting(
"""
t = jnp.asarray(t)
y = jnp.asarray(y)
x0 = jnp.asarray(x0)

if u is None:
msg = (
Expand All @@ -300,11 +298,13 @@ def fit_multiple_shooting(
t = t[:num_samples]
y = y[:num_samples]

n_states = len(model.system.initial_state)

# TODO: use numpy for everything that is not jitted
# Divide signals into segments.
ts = _moving_window(t, num_samples_per_segment, num_samples_per_segment - 1)
ys = _moving_window(y, num_samples_per_segment, num_samples_per_segment - 1)
x0s = np.zeros((num_shots - 1, len(x0)))
x0s = np.zeros((num_shots - 1, n_states))

ucoeffs = None
if u is not None:
Expand All @@ -321,8 +321,8 @@ def fit_multiple_shooting(
std_y = np.std(y, axis=0)
parameter_bounds = _get_bounds(model)
state_bounds = (
(num_shots - 1) * len(x0) * [-np.inf],
(num_shots - 1) * len(x0) * [np.inf],
(num_shots - 1) * n_states * [-np.inf],
(num_shots - 1) * n_states * [np.inf],
)
bounds = (
state_bounds[0] + parameter_bounds[0],
Expand All @@ -331,8 +331,8 @@ def fit_multiple_shooting(

def residuals(params):
x0s, model = unravel(params)
x0s = jnp.concatenate((x0[None], x0s), axis=0)
xs_pred, ys_pred = jax.vmap(model)(x0s, t=ts0, ucoeffs=ucoeffs)
x0s = jnp.concatenate((model.system.initial_state[None], x0s), axis=0)
xs_pred, ys_pred = jax.vmap(model)(t=ts0, ucoeffs=ucoeffs, initial_state=x0s)
# output residual
res_y = ((ys - ys_pred) / std_y).reshape(-1)
res_y = res_y / np.sqrt(len(res_y))
Expand All @@ -345,7 +345,7 @@ def residuals(params):
res = _least_squares(residuals, init_params, bounds, x_scale=False, **kwargs)

x0s, res.result = unravel(res.x)
res.x0s = np.asarray(jnp.concatenate((x0[None], x0s), axis=0))
res.x0s = jnp.concatenate((res.result.system.initial_state[None], x0s), axis=0)
res.ts = np.asarray(ts)
res.ts0 = np.asarray(ts0)

Expand All @@ -362,7 +362,7 @@ def transfer_function(sys: DynamicalSystem, to_states: bool = False, **kwargs):

def H(s: complex):
"""Transfer-function at s."""
identity = np.eye(linsys.n_states)
identity = np.eye(linsys.initial_state.size)
phi_B = jnp.linalg.solve(s * identity - A, B)
if to_states:
return phi_B
Expand Down
Loading

0 comments on commit dc94bd8

Please sign in to comment.