diff --git a/dynax/custom_types.py b/dynax/custom_types.py index 9ec38ae..a6c1448 100644 --- a/dynax/custom_types.py +++ b/dynax/custom_types.py @@ -1,7 +1,8 @@ import typing -from typing import Callable, TypeAlias +from typing import Callable, TypeAlias, Union import jaxtyping +import numpy as np generating_docs = getattr(typing, "GENERATING_DOCUMENTATION", False) @@ -13,6 +14,7 @@ Scalar: TypeAlias = Array ScalarLike: TypeAlias = ArrayLike + FloatScalarLike = Union[float, Array, np.ndarray] elif generating_docs: # In the docs. class Scalar: @@ -27,6 +29,8 @@ class Array: class ArrayLike: pass + FloatScalarLike = float + for cls in (Scalar, ScalarLike, Array, ArrayLike): cls.__module__ = "builtins" cls.__qualname__ = cls.__name__ @@ -37,6 +41,7 @@ class ArrayLike: Scalar = jaxtyping.Shaped[Array, ""] ScalarLike = jaxtyping.Shaped[ArrayLike, ""] + FloatScalarLike = jaxtyping.Float[ArrayLike, ""] VectorFunc: TypeAlias = Callable[[Array], Array] diff --git a/dynax/system.py b/dynax/system.py index 3969769..c9c6b76 100644 --- a/dynax/system.py +++ b/dynax/system.py @@ -2,8 +2,8 @@ from abc import abstractmethod from collections.abc import Callable -from dataclasses import field, Field -from typing import Literal, Any +from dataclasses import Field, field +from typing import Any, Literal import equinox import jax @@ -11,7 +11,7 @@ import numpy as np from jax import Array -from .custom_types import ArrayLike +from .custom_types import FloatScalarLike from .util import dim2shape, pretty @@ -165,7 +165,7 @@ def __check_init__(self): @abstractmethod def vector_field( - self, x: Array, u: Array | None = None, t: float | None = None + self, x: Array, u: Array | None = None, t: FloatScalarLike | None = None ) -> Array: """Compute state derivative. @@ -180,7 +180,9 @@ def vector_field( """ raise NotImplementedError - def output(self, x: Array, u: Array | None = None, t: float | None = None) -> Array: + def output( + self, x: Array, u: Array | None = None, t: FloatScalarLike | None = None + ) -> Array: """Compute output. Args: @@ -203,7 +205,10 @@ def n_outputs(self) -> int | Literal["scalar"]: return "scalar" if y.ndim == 0 else y.shape[0] def linearize( - self, x0: Array | None = None, u0: Array | None = None, t: float | None = None + self, + x0: Array | None = None, + u0: Array | None = None, + t: FloatScalarLike | None = None, ) -> "LinearSystem": """Compute the Jacobian linearizationaround a point. @@ -389,7 +394,7 @@ def __init__(self, sys1: AbstractSystem, sys2: AbstractSystem): self.n_inputs = sys1.n_inputs def vector_field( - self, x: Array, u: Array | None = None, t: float | None = None + self, x: Array, u: Array | None = None, t: FloatScalarLike | None = None ) -> Array: x1, x2 = self._unpack_states(x) y1 = self._sys1.output(x1, u, t) @@ -397,7 +402,9 @@ def vector_field( dx2 = self._sys2.vector_field(x2, y1, t) return self._pack_states(dx1, dx2) - def output(self, x: Array, u: Array | None = None, t: float | None = None) -> Array: + def output( + self, x: Array, u: Array | None = None, t: FloatScalarLike | None = None + ) -> Array: x1, x2 = self._unpack_states(x) y1 = self._sys1.output(x1, u, t) y2 = self._sys2.output(x2, y1, t) @@ -437,7 +444,7 @@ def __init__(self, sys1: AbstractSystem, sys2: AbstractSystem): self.n_inputs = sys1.n_inputs def vector_field( - self, x: Array, u: Array | None = None, t: float | None = None + self, x: Array, u: Array | None = None, t: FloatScalarLike | None = None ) -> Array: if u is None: u = jnp.zeros(dim2shape(self._sys1.n_inputs)) @@ -449,7 +456,9 @@ def vector_field( dx = self._pack_states(dx1, dx2) return dx - def output(self, x: Array, u: Array | None = None, t: float | None = None) -> Array: + def output( + self, x: Array, u: Array | None = None, t: FloatScalarLike | None = None + ) -> Array: x1, _ = self._unpack_states(x) y = self._sys1.output(x1, None, t) return y