Skip to content

Commit

Permalink
ensure shapes are checked at compile time
Browse files Browse the repository at this point in the history
  • Loading branch information
fhchl committed Aug 18, 2024
1 parent 4901add commit 0280c2a
Showing 1 changed file with 39 additions and 35 deletions.
74 changes: 39 additions & 35 deletions dynax/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,21 +147,22 @@ def __check_init__(self):
if not hasattr(self, attr):
raise AttributeError(f"Attribute '{attr}' not initialized.")

# Check that vector_field and output returns Arrays or scalars and not PyTrees
x = self.initial_state
u = jax.ShapeDtypeStruct(dim2shape(self.n_inputs), jnp.float64)
try:
dx = eqx.filter_eval_shape(self.vector_field, x, u, t=1.0)
y = eqx.filter_eval_shape(self.output, x, u, t=1.0)
except Exception as e:
raise ValueError(
"Can not evaluate output shapes. Check your definitions!"
) from e
for val, func in zip((dx, y), ("vector_field, output")): # noqa: B905
if not isinstance(val, jax.ShapeDtypeStruct):
with jax.ensure_compile_time_eval():
# Check that vector_field and output returns Arrays or scalars - not PyTrees
x = jax.ShapeDtypeStruct(self.initial_state.shape, jnp.float64)
u = jax.ShapeDtypeStruct(dim2shape(self.n_inputs), jnp.float64)
try:
dx = eqx.filter_eval_shape(self.vector_field, x, u, t=1.0)
y = eqx.filter_eval_shape(self.output, x, u, t=1.0)
except Exception as e:
raise ValueError(
f"{func} must return arrays or scalars, not {type(val)}"
)
"Can not evaluate output shapes. Check your definitions!"
) from e
for val, func in zip((dx, y), ("vector_field, output")): # noqa: B905
if not isinstance(val, jax.ShapeDtypeStruct):
raise ValueError(
f"{func} must return arrays or scalars, not {type(val)}"
)

@abstractmethod
def vector_field(
Expand Down Expand Up @@ -199,10 +200,12 @@ def output(
@property
def n_outputs(self) -> int | Literal["scalar"]:
"""The size of the output vector."""
x = self.initial_state
u = jax.ShapeDtypeStruct(dim2shape(self.n_inputs), jnp.float64)
y = eqx.filter_eval_shape(self.output, x, u, t=1.0)
return "scalar" if y.ndim == 0 else y.shape[0]
with jax.ensure_compile_time_eval():
x = jax.ShapeDtypeStruct(self.initial_state.shape, jnp.float64)
u = jax.ShapeDtypeStruct(dim2shape(self.n_inputs), jnp.float64)
y = eqx.filter_eval_shape(self.output, x, u, t=1.0)
n_out = "scalar" if y.ndim == 0 else y.shape[0]
return n_out

def linearize(
self,
Expand Down Expand Up @@ -309,24 +312,25 @@ class LinearSystem(AbstractControlAffine):
"""Feedthrough matrix."""

def __post_init__(self):
self.initial_state = (
jnp.array(0) if self.A.ndim == 0 else jnp.zeros(self.A.shape[0])
)

if self.initial_state.ndim == 0:
if self.B.ndim == 0:
self.n_inputs = "scalar"
elif self.B.ndim == 1:
self.n_inputs = self.B.size
else:
raise ValueError("Dimension mismatch.")
else:
if self.B.ndim == 1:
self.n_inputs = "scalar"
elif self.B.ndim == 2:
self.n_inputs = self.B.shape[1]
# Without this context manager, `initial_state` will leak later
with jax.ensure_compile_time_eval():
self.initial_state = (
jnp.array(0) if self.A.ndim == 0 else jnp.zeros(self.A.shape[0])
)
if self.initial_state.ndim == 0:
if self.B.ndim == 0:
self.n_inputs = "scalar"
elif self.B.ndim == 1:
self.n_inputs = self.B.size
else:
raise ValueError("Dimension mismatch.")
else:
raise ValueError("Dimension mismatch.")
if self.B.ndim == 1:
self.n_inputs = "scalar"
elif self.B.ndim == 2:
self.n_inputs = self.B.shape[1]
else:
raise ValueError("Dimension mismatch.")

def f(self, x: Array) -> Array:
return self.A.dot(x)
Expand Down

0 comments on commit 0280c2a

Please sign in to comment.