Own classes for opt results
Fixes #15
fhchl committed Sep 21, 2023
1 parent 08fe3d1 commit cb57d03
Showing 5 changed files with 16 additions and 52 deletions.
dynax/estimation.py (7 additions, 43 deletions)

@@ -16,7 +16,7 @@
 from jax.typing import ArrayLike
 from numpy.typing import NDArray
 from scipy.linalg import pinvh
-from scipy.optimize import least_squares, OptimizeResult as _OptimizeResult
+from scipy.optimize import least_squares, OptimizeResult
 from scipy.optimize._optimize import MemoizeJac

 from .evolution import AbstractEvolution
@@ -58,42 +58,6 @@ def _key_paths(tree: Any, root: str = "tree") -> list[str]:
     return [f"{root}{jtu.keystr(kp)}" for kp, _ in flattened]


-class OptimizeResult(_OptimizeResult):
-    """Represents the optimization result.
-    Attributes
-    ----------
-    x : Evolution
-        The solution of the optimization.
-    success : bool
-        Whether or not the optimizer exited successfully.
-    status : int
-        Termination status of the optimizer. Its value depends on the
-        underlying solver. Refer to `message` for details.
-    message : str
-        Description of the cause of the termination.
-    fun, jac, hess: ndarray
-        Values of objective function, its Jacobian and its Hessian (if
-        available). The Hessians may be approximations, see the documentation
-        of the function in question.
-    pcov: ndarray
-        Estimate of the covariance matrix.
-    hess_inv : object
-        Inverse of the objective function's Hessian; may be an approximation.
-        Not available for all solvers. The type of this attribute may be
-        either np.ndarray or scipy.sparse.linalg.LinearOperator.
-    key_paths: List of key_paths for x that index the corresponding entries in
-        `pcov`, `jac`, `hess` and `hess_inv`.
-    nfev, njev, nhev : int
-        Number of evaluations of the objective functions and of its
-        Jacobian and Hessian.
-    nit : int
-        Number of iterations performed by the optimizer.
-    maxcv : float
-        The maximum constraint violation.
-    """
-
-
 def _compute_covariance(
     jac, cost, absolute_sigma: bool, cov_prior: Optional[NDArray] = None
 ) -> NDArray:
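Context for the deletion above: `scipy.optimize.OptimizeResult` is itself a dict subclass with attribute access, so extra fields can be attached to the plain scipy object directly and the documented subclass carried no behavior of its own. A minimal sketch of that property:

```python
# Minimal sketch: scipy's OptimizeResult already accepts arbitrary
# attributes, so no subclass is needed to carry extra result fields.
from scipy.optimize import OptimizeResult

res = OptimizeResult(x=[1.0, 2.0], success=True)
res.result = "fitted model goes here"  # attach any extra attribute
assert res["result"] is res.result     # dict and attribute access agree
```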
@@ -201,7 +165,7 @@ def fit_least_squares(
     Parameters can be constrained via the `*_field` functions.

     Args:
-        model: Forward model holding initial parameter estimates
+        model: Flow instance holding initial parameter estimates
         t: Times at which `y` is given
         y: Target outputs of system
         x0: Initial state
@@ -225,9 +189,9 @@
     Returns:
         `OptimizeResult` as returned by `scipy.optimize.least_squares` with the
-        following fields defined:
+        following additional attributes defined:
-        model: `model` with estimated parameters.
+        result: `model` with estimated parameters.
         cov: Covariance matrix of the parameter estimate.
         y_pred: Model prediction at optimum.
         key_paths: List of key_paths that index the corresponding entries in `cov`,
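What this means for callers, sketched with names borrowed from the notebook example updated below (the attribute set matches the assignments in `fit_least_squares`):

```python
# Hedged usage sketch; assumes the setup of examples/fit_ode.ipynb.
res = fit_least_squares(model=init_model, t=t, y=yn, x0=x0, u=u)
pred_model = res.result  # fitted model; was `res.model` before this commit
pcov = res.pcov          # covariance of the parameter estimate
y_fit = res.y_pred       # model prediction at the optimum
```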
@@ -295,7 +259,7 @@ def residual_term(params):
         **kwargs,
     )

-    res.model = unravel(res.x)
+    res.result = unravel(res.x)
     res.pcov = _compute_covariance(res.jac, res.cost, absolute_sigma, cov_prior)
     res.y_pred = y - res.fun.reshape(y.shape) / weight
     res.key_paths = _key_paths(model, root=model.__class__.__name__)
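For reference, `res.pcov` comes from the usual Gauss-Newton estimate built on the returned Jacobian. A hedged sketch of that computation; dynax's actual `_compute_covariance` may differ, for instance in how `cov_prior` enters:

```python
import numpy as np
from scipy.linalg import pinvh

def gauss_newton_covariance(jac, cost, absolute_sigma, n_obs):
    """Sketch of a curve_fit-style covariance estimate."""
    pcov = pinvh(jac.T @ jac)  # (J^T J)^+; pinvh tolerates rank deficiency
    if not absolute_sigma:
        # scipy's least_squares reports cost = 0.5 * sum(residuals**2),
        # so the residual variance is 2 * cost / (n_obs - n_params).
        dof = max(n_obs - jac.shape[1], 1)
        pcov = pcov * 2 * cost / dof
    return pcov
```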
@@ -414,7 +378,7 @@ def residuals(params):

     res = _least_squares(residuals, init_params, bounds, x_scale=False, **kwargs)

-    x0s, res.model = unravel(res.x)
+    x0s, res.result = unravel(res.x)
     res.x0s = np.asarray(jnp.concatenate((x0[None], x0s), axis=0))
     res.ts = np.asarray(ts)
     res.ts0 = np.asarray(ts0)
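The multiple-shooting result is consumed the same way, as in the example script updated below:

```python
# Hedged usage sketch; argument names follow the tests in this commit.
res = fit_multiple_shooting(
    init_model, t, y, x0, num_shots=num_shots, continuity_penalty=1
)
model = res.result   # was `res.model` before this commit
x0s = res.x0s        # initial state of each shooting segment
ts, ts0 = res.ts, res.ts0
```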
@@ -513,7 +477,7 @@ def residuals(params):
     Syu_pred_real, Syu_pred_imag = res.fun[: Syu.size], res.fun[Syu.size :]
     Syu_pred = Syu - (Syu_pred_real + 1j * Syu_pred_imag).reshape(Syu.shape) / weight

-    res.sys = unravel(res.x)
+    res.result = unravel(res.x)
     res.pcov = _compute_covariance(
         res.jac, res.cost, absolute_sigma, cov_prior=cov_prior
     )
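The `Syu_pred` reconstruction above undoes the standard real/imaginary stacking: `scipy.optimize.least_squares` only accepts real residual vectors, so complex frequency-response residuals are split and concatenated. A hedged sketch of both directions:

```python
import numpy as np

def stack_complex(r):
    """Flatten complex residuals into a real vector for least_squares."""
    return np.concatenate([r.real.ravel(), r.imag.ravel()])

def unstack_complex(r_stacked, shape):
    """Invert stack_complex: first half real part, second half imaginary."""
    n = r_stacked.size // 2
    return (r_stacked[:n] + 1j * r_stacked[n:]).reshape(shape)
```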
examples/fit_multiple_shooting_second_order_sys.py (1 addition, 1 deletion)

@@ -88,7 +88,7 @@ def h(self, x):
     verbose=2,
     num_shots=num_shots,
 )
-model = res.model
+model = res.result
 x0s = res.x0s
 ts = res.ts
 ts0 = res.ts0
examples/fit_ode.ipynb (1 addition, 1 deletion)

@@ -309,7 +309,7 @@
 "source": [
     "init_model = Flow(initial_sys)\n",
     "res = fit_least_squares(model=init_model, t=t, y=yn, x0=x0, u=u, verbose=2)\n",
-    "pred_model = res.model\n",
+    "pred_model = res.result\n",
     "print(\"fitted system:\", pretty(pred_model.system))\n",
     "print(\"Normalized mean squared error:\", res.nrmse)"
 ]
examples/fit_second_order_sys.py (1 addition, 1 deletion)

@@ -77,7 +77,7 @@ def h(self, x):
 # Fit all parameters with previously estimated parameters as a starting guess.
 pred_model = fit_least_squares(
     model=init_model, t=t_train, y=y_train, x0=initial_x, u=u_train, verbose=0
-).model
+).result
 print("fitted system:", pred_model.system)

 # check the results
tests/test_estimation.py (6 additions, 6 deletions)

@@ -36,7 +36,7 @@ def test_fit_least_squares(outputs):
     _, y_true = true_model(x0, t, u)
     # fit
     init_model = Flow(NonlinearDrag(1.0, 1.0, 1.0, 1.0, outputs))
-    pred_model = fit_least_squares(init_model, t, y_true, x0, u).model
+    pred_model = fit_least_squares(init_model, t, y_true, x0, u).result
     # check result
     _, y_pred = pred_model(x0, t, u)
     npt.assert_allclose(y_pred, y_true, **tols)
@@ -65,7 +65,7 @@ def test_fit_least_squares_on_batch():
     _, ys = jax.vmap(true_model)(x0s, ts, us)
     # fit
     init_model = Flow(NonlinearDrag(1.0, 1.0, 1.0, 1.0))
-    pred_model = fit_least_squares(init_model, ts, ys, x0s, us, batched=True).model
+    pred_model = fit_least_squares(init_model, ts, ys, x0s, us, batched=True).result
     # check result
     _, ys_pred = jax.vmap(pred_model)(x0s, ts, us)
     npt.assert_allclose(ys_pred, ys, **tols)
@@ -104,7 +104,7 @@ def test_fit_with_bounded_parameters():
     init_model = Flow(
         LotkaVolterra(alpha=1.0, beta=1.0, gamma=1.5, delta=2.0), **solver_opt
     )
-    pred_model = fit_least_squares(init_model, t, x_true, x0).model
+    pred_model = fit_least_squares(init_model, t, x_true, x0).result
     # check result
     x_pred, _ = pred_model(x0, t)
     npt.assert_allclose(x_pred, x_true, **tols)
@@ -145,7 +145,7 @@ def vector_field(self, x, u=None, t=None):
         LotkaVolterra(alpha=1.0, beta=1.0, delta_gamma=jnp.array([1.5, 2])),
         **solver_opt,
     )
-    pred_model = fit_least_squares(init_model, t, x_true, x0).model
+    pred_model = fit_least_squares(init_model, t, x_true, x0).result
     # check result
     x_pred, _ = pred_model(x0, t)
     npt.assert_allclose(x_pred, x_true, **tols)
@@ -173,7 +173,7 @@ def test_fit_multiple_shooting_with_input(num_shots):
         continuity_penalty=1,
         num_shots=num_shots,
         verbose=2,
-    ).model
+    ).result
     # check result
     x_pred, _ = pred_model(x0, t, u)
     npt.assert_allclose(x_pred, x_true, **tols)
@@ -200,7 +200,7 @@ def test_fit_multiple_shooting_without_input(num_shots):
     )
     pred_model = fit_multiple_shooting(
         init_model, t, x_true, x0, num_shots=num_shots, continuity_penalty=1
-    ).model
+    ).result
     # check result
     x_pred, _ = pred_model(x0, t)
     npt.assert_allclose(x_pred, x_true, atol=1e-3, rtol=1e-3)