Skip to content

Commit

Permalink
document evolution.py
Browse files Browse the repository at this point in the history
  • Loading branch information
fhchl committed May 2, 2024
1 parent 0c59bf5 commit 84c8489
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 22 deletions.
74 changes: 52 additions & 22 deletions dynax/evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@
class AbstractEvolution(Module):
"""Abstract base-class for evolutions.
from :py:class:`equinox.Module`.
Evolutions combine dynamical systems with a solver. They simulate the evolution of
the system state and output over time given an initial and, possibly, an input
sequence.
"""

system: AbstractSystem
Expand All @@ -39,12 +42,12 @@ def __call__(
Args:
t: Times at which to evaluate the evolution.
u: An optional input sequence of same length.
initial_state: An optional, fixed initial state used instead of
`system.initial_state`.
u: Optional input sequence of same length.
initial_state: Optional, fixed initial state used instead of
:py:attr:`AbstractSystem.initial_state`.
Returns:
A tuple `(x, y)` of state and output sequences.
Tuple `(x, y)` of state and output sequences.
"""
raise NotImplementedError
Expand All @@ -54,15 +57,16 @@ class Flow(AbstractEvolution):
"""Evolution for continous-time dynamical systems.
Args:
system: A dynamical system.
solver: A diffrax solver.
stepsize_controller: A diffrax stepsize controller.
system: Dynamical system.
solver: Differential equation solver. Defaults to :py:class:`diffrax.Dopri5`.
stepsize_controller: Stepsize controller. Defaults to
:py:class:`diffrax.ConstantStepSize`.
"""

solver: AbstractAdaptiveSolver = static_field(default_factory=Dopri5)
stepsize_controller: AbstractStepSizeController = static_field(
default_factory=lambda: ConstantStepSize()
default_factory=ConstantStepSize
)

def __call__(
Expand All @@ -75,16 +79,25 @@ def __call__(
ucoeffs: Optional[tuple[PyTree, PyTree, PyTree, PyTree]] = None,
**diffeqsolve_kwargs,
) -> tuple[Array, Array]:
(
super().__call__.__doc__
+ """
Additional args:
ufun: A function `t -> u` that returns the input at time `t`.
ucoeffs: A tuple of coefficients for a cubic spline interpolation.
**diffeqsolve_kwargs: Additional arguments to pass to `diffeqsolve`.
r"""Evolve an initial state along the vector field and compute output.
Args:
t: Times at which to evaluate the evolution.
u: Optional input sequence of same length.
initial_state: Optional, fixed initial state used instead of
:py:attr:`AbstractSystem.initial_state`.
ufun: A function :math:`t \mapsto u`. Can be used instead of `u` or
`ucoeffs`.
ucoeffs: Precomputed spline coefficients of the input passed to
:py:class:`diffrax.CubicInterpolation`. Can be used instead of `u` or
`ufun`.
**diffeqsolve_kwargs: Additional arguments passed to
:py:meth:`diffrax.diffeqsolve`.
Returns:
Tuple `(x, y)` of state and output sequences.
"""
)
# Parse inputs.
t = jnp.asarray(t)

Expand All @@ -100,7 +113,7 @@ def __call__(
path = CubicInterpolation(t, ucoeffs)
_ufun = path.evaluate
elif callable(ufun):
_ufun = u
_ufun = ufun
elif u is not None:
u = jnp.asarray(u)
if len(t) != u.shape[0]:
Expand Down Expand Up @@ -139,7 +152,7 @@ def __call__(
t1=t[-1],
y0=initial_state,
args=self, # https://github.com/patrick-kidger/diffrax/issues/135
**diffeqsolve_default_options,
**diffeqsolve_default_options, # type: ignore
).ys
# Could be in general a Pytree, but we only allow Array states.
x = cast(Array, x)
Expand All @@ -151,7 +164,12 @@ def __call__(


class Map(AbstractEvolution):
"""Evolution for discrete-time dynamical systems."""
"""Evolution for discrete-time dynamical systems.
Args:
system: Dynamical system.
"""

def __call__(
self,
Expand All @@ -161,7 +179,19 @@ def __call__(
*,
num_steps: Optional[int] = None,
) -> tuple[Array, Array]:
"""Solve discrete map."""
"""Evolve an initial state along the vector field and compute output.
Args:
t: Times at which to evaluate the evolution.
u: Optional input sequence of same length.
initial_state: Optional, fixed initial state used instead of
:py:attr:`AbstractSystem.initial_state`.
num_steps: Number of steps to compute if `t` and `u` are not specified.
Returns:
Tuple `(x, y)` of state and output sequences.
"""

# Parse inputs.
if initial_state is not None:
Expand Down
1 change: 1 addition & 0 deletions dynax/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(self, ts: Array, xs: Array):
self.path = dfx.CubicInterpolation(ts, coeffs)

def __call__(self, t: float) -> Array:
"""Evaluate the interpolating function at time `t`."""
return self.path.evaluate(t)


Expand Down

0 comments on commit 84c8489

Please sign in to comment.