Skip to content

Commit

Permalink
add additional scalar type
Browse files Browse the repository at this point in the history
  • Loading branch information
fhchl committed May 2, 2024
1 parent 84c8489 commit 1a6f69c
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 11 deletions.
7 changes: 6 additions & 1 deletion dynax/custom_types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import typing
from typing import Callable, TypeAlias
from typing import Callable, TypeAlias, Union

import jaxtyping
import numpy as np


generating_docs = getattr(typing, "GENERATING_DOCUMENTATION", False)
Expand All @@ -13,6 +14,7 @@

Scalar: TypeAlias = Array
ScalarLike: TypeAlias = ArrayLike
FloatScalarLike = Union[float, Array, np.ndarray]
elif generating_docs:
# In the docs.
class Scalar:
Expand All @@ -27,6 +29,8 @@ class Array:
class ArrayLike:
pass

FloatScalarLike = float

for cls in (Scalar, ScalarLike, Array, ArrayLike):
cls.__module__ = "builtins"
cls.__qualname__ = cls.__name__
Expand All @@ -37,6 +41,7 @@ class ArrayLike:

Scalar = jaxtyping.Shaped[Array, ""]
ScalarLike = jaxtyping.Shaped[ArrayLike, ""]
FloatScalarLike = jaxtyping.Float[ArrayLike, ""]


VectorFunc: TypeAlias = Callable[[Array], Array]
Expand Down
29 changes: 19 additions & 10 deletions dynax/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@

from abc import abstractmethod
from collections.abc import Callable
from dataclasses import field, Field
from typing import Literal, Any
from dataclasses import Field, field
from typing import Any, Literal

import equinox
import jax
import jax.numpy as jnp
import numpy as np
from jax import Array

from .custom_types import ArrayLike
from .custom_types import FloatScalarLike
from .util import dim2shape, pretty


Expand Down Expand Up @@ -165,7 +165,7 @@ def __check_init__(self):

@abstractmethod
def vector_field(
self, x: Array, u: Array | None = None, t: float | None = None
self, x: Array, u: Array | None = None, t: FloatScalarLike | None = None
) -> Array:
"""Compute state derivative.
Expand All @@ -180,7 +180,9 @@ def vector_field(
"""
raise NotImplementedError

def output(self, x: Array, u: Array | None = None, t: float | None = None) -> Array:
def output(
self, x: Array, u: Array | None = None, t: FloatScalarLike | None = None
) -> Array:
"""Compute output.
Args:
Expand All @@ -203,7 +205,10 @@ def n_outputs(self) -> int | Literal["scalar"]:
return "scalar" if y.ndim == 0 else y.shape[0]

def linearize(
self, x0: Array | None = None, u0: Array | None = None, t: float | None = None
self,
x0: Array | None = None,
u0: Array | None = None,
t: FloatScalarLike | None = None,
) -> "LinearSystem":
"""Compute the Jacobian linearizationaround a point.
Expand Down Expand Up @@ -389,15 +394,17 @@ def __init__(self, sys1: AbstractSystem, sys2: AbstractSystem):
self.n_inputs = sys1.n_inputs

def vector_field(
self, x: Array, u: Array | None = None, t: float | None = None
self, x: Array, u: Array | None = None, t: FloatScalarLike | None = None
) -> Array:
x1, x2 = self._unpack_states(x)
y1 = self._sys1.output(x1, u, t)
dx1 = self._sys1.vector_field(x1, u, t)
dx2 = self._sys2.vector_field(x2, y1, t)
return self._pack_states(dx1, dx2)

def output(self, x: Array, u: Array | None = None, t: float | None = None) -> Array:
def output(
self, x: Array, u: Array | None = None, t: FloatScalarLike | None = None
) -> Array:
x1, x2 = self._unpack_states(x)
y1 = self._sys1.output(x1, u, t)
y2 = self._sys2.output(x2, y1, t)
Expand Down Expand Up @@ -437,7 +444,7 @@ def __init__(self, sys1: AbstractSystem, sys2: AbstractSystem):
self.n_inputs = sys1.n_inputs

def vector_field(
self, x: Array, u: Array | None = None, t: float | None = None
self, x: Array, u: Array | None = None, t: FloatScalarLike | None = None
) -> Array:
if u is None:
u = jnp.zeros(dim2shape(self._sys1.n_inputs))
Expand All @@ -449,7 +456,9 @@ def vector_field(
dx = self._pack_states(dx1, dx2)
return dx

def output(self, x: Array, u: Array | None = None, t: float | None = None) -> Array:
def output(
self, x: Array, u: Array | None = None, t: FloatScalarLike | None = None
) -> Array:
x1, _ = self._unpack_states(x)
y = self._sys1.output(x1, None, t)
return y
Expand Down

0 comments on commit 1a6f69c

Please sign in to comment.