Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/estimate initial state #25

Merged
merged 12 commits into from
Feb 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ experiments
.coverage
htmlcov
build
docs/source/_build
_build
docs/generated
*.pytest_cache
.pytype
Expand Down
21 changes: 17 additions & 4 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
import typing


project = "Dynax"
copyright = "2023, Franz M. Heuchel"
Expand All @@ -14,13 +16,15 @@
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

extensions = [
"sphinx.ext.napoleon",
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.todo",
"sphinx.ext.napoleon",
"sphinx.ext.viewcode",
"sphinxcontrib.bibtex",
"sphinx_autodoc_typehints",
"sphinxcontrib.aafig",
# "sphinx_autodoc_typehints",
"nbsphinx",
]

Expand All @@ -40,11 +44,20 @@
}

autoclass_content = "both"
# autodoc_typehints = "signature"
# typehints_use_signature = "True"
autodoc_typehints = "signature"
typehints_use_signature = True


napoleon_include_init_with_doc = True
napoleon_preprocess_types = True
napoleon_attr_annotations = True

# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output

html_theme = "furo"
html_theme = "sphinx_rtd_theme"
html_static_path = ["_static"]

# Short type docs for jaxtyping's types
# https://github.com/patrick-kidger/pytkdocs_tweaks/blob/2a7ce453e315f526d792f689e61d56ecaa4ab000/pytkdocs_tweaks/__init__.py#L283
typing.GENERATING_DOCUMENTATION = True # pyright: ignore
42 changes: 21 additions & 21 deletions dynax/estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import scipy.signal as sig
from jax import Array
from jax.flatten_util import ravel_pytree
from jax.typing import ArrayLike
from jaxtyping import ArrayLike
from numpy.typing import NDArray
from scipy.linalg import pinvh
from scipy.optimize import least_squares, OptimizeResult
Expand Down Expand Up @@ -92,7 +92,7 @@ def _compute_covariance(

def _least_squares(
f: Callable[[Array], Array],
x0: NDArray,
init_params: Array,
bounds: tuple[list, list],
reg_term: Optional[Callable[[Array], Array]] = None,
x_scale: bool = True,
Expand All @@ -105,7 +105,7 @@ def _least_squares(
# Add regularization term
_f = f
_reg_term = reg_term # https://github.com/python/mypy/issues/7268
f = lambda x: jnp.concatenate((_f(x), _reg_term(x)))
f = lambda params: jnp.concatenate((_f(params), _reg_term(params)))

if verbose_mse:
# Scale cost to mean-squared error
Expand All @@ -117,15 +117,17 @@ def f(params):

if x_scale:
# Scale parameters and bounds by initial values
norm = np.where(np.asarray(x0) != 0, np.abs(x0), 1)
x0 = x0 / norm
norm = np.where(np.asarray(init_params) != 0, np.abs(init_params), 1)
init_params = init_params / norm
___f = f
f = lambda x: ___f(x * norm)
f = lambda params: ___f(params * norm)
bounds = (np.array(bounds[0]) / norm, np.array(bounds[1]) / norm)

fun = MemoizeJac(jax.jit(lambda x: value_and_jacfwd(f, x)))
jac = fun.derivative
res = least_squares(fun, x0, bounds=bounds, jac=jac, x_scale="jac", **kwargs)
res = least_squares(
fun, init_params, bounds=bounds, jac=jac, x_scale="jac", **kwargs
)

if x_scale:
# Unscale parameters
Expand All @@ -139,8 +141,8 @@ def f(params):

if reg_term is not None:
# Remove regularization from residuals and Jacobian and cost
res.fun = res.fun[: -len(x0)]
res.jac = res.jac[: -len(x0)]
res.fun = res.fun[: -len(init_params)]
res.jac = res.jac[: -len(init_params)]
res.cost = np.sum(res.fun**2) / 2

return res
Expand All @@ -150,7 +152,6 @@ def fit_least_squares(
model: AbstractEvolution,
t: ArrayLike,
y: ArrayLike,
x0: ArrayLike,
u: Optional[ArrayLike] = None,
batched: bool = False,
sigma: Optional[ArrayLike] = None,
Expand All @@ -168,7 +169,6 @@ def fit_least_squares(
"""
t = jnp.asarray(t)
y = jnp.asarray(y)
x0 = jnp.asarray(x0)

if batched:
# First axis holds experiments, second axis holds time.
Expand Down Expand Up @@ -212,7 +212,7 @@ def residual_term(params):
# this can use pmap, if batch size is smaller than CPU cores
model = jax.vmap(model)
# FIXME: ucoeffs not supported for Map
_, pred_y = model(x0, t=t, ucoeffs=ucoeffs)
_, pred_y = model(t=t, ucoeffs=ucoeffs)
res = (y - pred_y) * weight
return res.reshape(-1)

Expand Down Expand Up @@ -246,7 +246,6 @@ def fit_multiple_shooting(
model: AbstractEvolution,
t: ArrayLike,
y: ArrayLike,
x0: ArrayLike,
u: Optional[Union[Callable[[float], Array], ArrayLike]] = None,
num_shots: int = 1,
continuity_penalty: float = 0.1,
Expand All @@ -273,7 +272,6 @@ def fit_multiple_shooting(
"""
t = jnp.asarray(t)
y = jnp.asarray(y)
x0 = jnp.asarray(x0)

if u is None:
msg = (
Expand All @@ -300,11 +298,13 @@ def fit_multiple_shooting(
t = t[:num_samples]
y = y[:num_samples]

n_states = len(model.system.initial_state)

# TODO: use numpy for everything that is not jitted
# Divide signals into segments.
ts = _moving_window(t, num_samples_per_segment, num_samples_per_segment - 1)
ys = _moving_window(y, num_samples_per_segment, num_samples_per_segment - 1)
x0s = np.zeros((num_shots - 1, len(x0)))
x0s = np.zeros((num_shots - 1, n_states))

ucoeffs = None
if u is not None:
Expand All @@ -321,8 +321,8 @@ def fit_multiple_shooting(
std_y = np.std(y, axis=0)
parameter_bounds = _get_bounds(model)
state_bounds = (
(num_shots - 1) * len(x0) * [-np.inf],
(num_shots - 1) * len(x0) * [np.inf],
(num_shots - 1) * n_states * [-np.inf],
(num_shots - 1) * n_states * [np.inf],
)
bounds = (
state_bounds[0] + parameter_bounds[0],
Expand All @@ -331,8 +331,8 @@ def fit_multiple_shooting(

def residuals(params):
x0s, model = unravel(params)
x0s = jnp.concatenate((x0[None], x0s), axis=0)
xs_pred, ys_pred = jax.vmap(model)(x0s, t=ts0, ucoeffs=ucoeffs)
x0s = jnp.concatenate((model.system.initial_state[None], x0s), axis=0)
xs_pred, ys_pred = jax.vmap(model)(t=ts0, ucoeffs=ucoeffs, initial_state=x0s)
# output residual
res_y = ((ys - ys_pred) / std_y).reshape(-1)
res_y = res_y / np.sqrt(len(res_y))
Expand All @@ -345,7 +345,7 @@ def residuals(params):
res = _least_squares(residuals, init_params, bounds, x_scale=False, **kwargs)

x0s, res.result = unravel(res.x)
res.x0s = np.asarray(jnp.concatenate((x0[None], x0s), axis=0))
res.x0s = jnp.concatenate((res.result.system.initial_state[None], x0s), axis=0)
res.ts = np.asarray(ts)
res.ts0 = np.asarray(ts0)

Expand All @@ -362,7 +362,7 @@ def transfer_function(sys: DynamicalSystem, to_states: bool = False, **kwargs):

def H(s: complex):
"""Transfer-function at s."""
identity = np.eye(linsys.n_states)
identity = np.eye(linsys.initial_state.size)
phi_B = jnp.linalg.solve(s * identity - A, B)
if to_states:
return phi_B
Expand Down
Loading
Loading