"Dynamical systems in JAX"
This is WIP. Expect things to break!
This package allows for straightforward simulation, fitting, and linearization of dynamical systems by combining JAX, Diffrax, Equinox, and scipy.optimize. Its main features include:
- estimation of ODE parameters and their covariance via the prediction-error method (example)
- estimation of the initial state (example)
- estimation of linear ODE parameters via matching of frequency-response functions (example)
- estimation from multiple experiments
- estimation with a poor man's multiple shooting (example)
- input-output linearization of continuous-time input-affine systems
- input-output linearization of discrete-time systems (example)
- estimation of a system's relative degree (example)
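To illustrate the first feature above, the sketch below fits the parameters a and b of a hypothetical first-order model dx/dt = -a*x + b*u to noisy synthetic data and estimates their covariance. It does not use this package's API: it is a minimal stand-alone sketch assuming plain JAX (a fixed-step RK4 loop in jax.lax.scan stands in for a Diffrax solver) and scipy.optimize.least_squares. With an output-error model (white measurement noise, no process noise), the one-step-ahead prediction error reduces to the simulation error, so the prediction-error estimate becomes a nonlinear least-squares fit, and the covariance follows from the usual Gauss-Newton approximation sigma^2 (J^T J)^-1.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import least_squares

jax.config.update("jax_enable_x64", True)  # float64 helps the optimizer and the covariance


def simulate(theta, x0, u, dt):
    """Simulate dx/dt = -a*x + b*u with fixed-step RK4, holding u constant over each step."""
    a, b = theta

    def f(x, uk):
        return -a * x + b * uk

    def rk4_step(x, uk):
        k1 = f(x, uk)
        k2 = f(x + 0.5 * dt * k1, uk)
        k3 = f(x + 0.5 * dt * k2, uk)
        k4 = f(x + dt * k3, uk)
        # Carry the next state, emit the current one (so xs[0] == x0).
        return x + dt / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4), x

    _, xs = jax.lax.scan(rk4_step, x0, u)
    return xs


# Hypothetical synthetic experiment: multisine input, noisy state measurements.
dt = 0.02
t = jnp.arange(500) * dt
u = jnp.sin(t) + 0.5 * jnp.sin(3.1 * t)
x0 = jnp.zeros(())
theta_true = jnp.array([1.5, 2.0])
rng = np.random.default_rng(0)
y = np.asarray(simulate(theta_true, x0, u, dt)) + 0.01 * rng.standard_normal(t.shape[0])


@jax.jit
def residuals(theta):
    # Output-error model: the one-step-ahead prediction equals the simulated output.
    return simulate(theta, x0, u, dt) - y


# Exact Jacobian of the residuals via forward-mode autodiff through the simulation.
jacobian = jax.jit(jax.jacfwd(residuals))

fit = least_squares(
    lambda th: np.asarray(residuals(th)),
    x0=np.array([1.0, 1.0]),
    jac=lambda th: np.asarray(jacobian(th)),
)
theta_hat = fit.x

# Gauss-Newton covariance: sigma^2 (J^T J)^-1, with sigma^2 estimated from the residuals.
J = np.asarray(jacobian(theta_hat))
n_samples, n_params = J.shape
sigma2 = np.sum(fit.fun**2) / (n_samples - n_params)
cov = sigma2 * np.linalg.inv(J.T @ J)
print("estimate:", theta_hat, "std:", np.sqrt(np.diag(cov)))
```

For models with process noise, the plain simulation would have to be replaced by a proper predictor (for example a Kalman filter) to obtain a true prediction-error estimate.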
Documentation is on its way. Until then, have a look at the example and test folders.
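The relative-degree estimation and input-output linearization listed above rest on Lie derivatives. For an input-affine system dx/dt = f(x) + g(x)u with output y = h(x), the relative degree r is the smallest integer with L_g L_f^(r-1) h(x) ≠ 0, and the feedback u = (v - L_f^r h(x)) / (L_g L_f^(r-1) h(x)) makes the output obey y^(r) = v. The following is a minimal sketch of these standard computations using JAX forward-mode differentiation; it is not this package's implementation, the function names (lie_derivative, relative_degree, linearizing_feedback) are hypothetical, and estimating r from a few sample states is only a heuristic.

```python
import jax
import jax.numpy as jnp


def lie_derivative(fun, vector_field):
    """Return x -> L_v fun(x) = dfun/dx(x) @ v(x), computed with a Jacobian-vector product."""
    def lie(x):
        _, directional = jax.jvp(fun, (x,), (vector_field(x),))
        return directional
    return lie


def relative_degree(f, g, h, sample_states, max_degree=10, tol=1e-6):
    """Heuristic: smallest r such that L_g L_f^(r-1) h is nonzero at some sample state."""
    lf_h = h  # holds L_f^(r-1) h
    for r in range(1, max_degree + 1):
        lg_lf_h = lie_derivative(lf_h, g)
        if any(abs(float(lg_lf_h(x))) > tol for x in sample_states):
            return r
        lf_h = lie_derivative(lf_h, f)
    raise ValueError("relative degree not found up to max_degree")


def linearizing_feedback(f, g, h, r):
    """Return (x, v) -> u such that y^(r) = v, valid where L_g L_f^(r-1) h(x) != 0."""
    lf_h = h
    for _ in range(r - 1):
        lf_h = lie_derivative(lf_h, f)
    lf_r_h = lie_derivative(lf_h, f)   # L_f^r h
    lg_lf_h = lie_derivative(lf_h, g)  # L_g L_f^(r-1) h

    def feedback(x, v):
        return (v - lf_r_h(x)) / lg_lf_h(x)
    return feedback


# Hypothetical example: pendulum with torque input, x = (angle, angular velocity), y = angle.
def f(x):
    return jnp.stack([x[1], -jnp.sin(x[0])])


def g(x):
    return jnp.zeros_like(x).at[1].set(1.0)


def h(x):
    return x[0]


states = [jnp.array([0.3, -0.2]), jnp.array([1.0, 0.5])]
r = relative_degree(f, g, h, states)  # the pendulum angle has relative degree 2
u_of = linearizing_feedback(f, g, h, r)
print("relative degree:", r, "u(x, v=0.7):", float(u_of(states[0], 0.7)))
```

Nesting jax.jvp in this way differentiates through both the function and the vector field, which is what higher-order Lie derivatives such as L_f^2 h require.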
Requires Python 3.9+, JAX 0.4.23+, Equinox 0.11+, and Diffrax 0.5+. With a suitable version of jaxlib installed, run:
pip install .
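To check the version requirements above in an existing environment, a small standard-library script such as the following can help. It is only a sketch: the helper names are hypothetical, and the parser compares just the leading numeric release segment rather than implementing full PEP 440 version ordering.

```python
import re
import sys
from importlib.metadata import PackageNotFoundError, version

# Minimum versions as stated in this README.
REQUIREMENTS = {"jax": (0, 4, 23), "equinox": (0, 11), "diffrax": (0, 5)}


def parse_version(text):
    """Parse the leading numeric release segment, e.g. '0.4.23.dev1' -> (0, 4, 23)."""
    match = re.match(r"\d+(?:\.\d+)*", text)
    return tuple(int(part) for part in match.group().split(".")) if match else ()


def check_requirements():
    """Return {name: (installed version or None, requirement met)}, including Python itself."""
    report = {"python": (sys.version.split()[0], sys.version_info[:2] >= (3, 9))}
    for name, minimum in REQUIREMENTS.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            report[name] = (None, False)
        else:
            report[name] = (installed, parse_version(installed) >= minimum)
    return report


if __name__ == "__main__":
    for name, (installed, ok) in check_requirements().items():
        print(f"{name:8s} {installed or 'missing':12s} {'ok' if ok else 'missing or too old'}")
```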
To run the tests, install the development dependencies with
pip install ".[dev]"
and run
pytest
To also test the examples, run
pytest --runslow
Related software:
- nlgreyfast: MATLAB library for fitting ODEs with multiple shooting
- dynamax: inference and learning for probabilistic state-space models