diff --git a/dynax/system.py b/dynax/system.py index a932f1d..c1f11d1 100644 --- a/dynax/system.py +++ b/dynax/system.py @@ -147,21 +147,22 @@ def __check_init__(self): if not hasattr(self, attr): raise AttributeError(f"Attribute '{attr}' not initialized.") - # Check that vector_field and output returns Arrays or scalars and not PyTrees - x = self.initial_state - u = jax.ShapeDtypeStruct(dim2shape(self.n_inputs), jnp.float64) - try: - dx = eqx.filter_eval_shape(self.vector_field, x, u, t=1.0) - y = eqx.filter_eval_shape(self.output, x, u, t=1.0) - except Exception as e: - raise ValueError( - "Can not evaluate output shapes. Check your definitions!" - ) from e - for val, func in zip((dx, y), ("vector_field, output")): # noqa: B905 - if not isinstance(val, jax.ShapeDtypeStruct): + with jax.ensure_compile_time_eval(): + # Check that vector_field and output returns Arrays or scalars - not PyTrees + x = jax.ShapeDtypeStruct(self.initial_state.shape, jnp.float64) + u = jax.ShapeDtypeStruct(dim2shape(self.n_inputs), jnp.float64) + try: + dx = eqx.filter_eval_shape(self.vector_field, x, u, t=1.0) + y = eqx.filter_eval_shape(self.output, x, u, t=1.0) + except Exception as e: raise ValueError( - f"{func} must return arrays or scalars, not {type(val)}" - ) + "Can not evaluate output shapes. Check your definitions!" + ) from e + for val, func in zip((dx, y), ("vector_field, output")): # noqa: B905 + if not isinstance(val, jax.ShapeDtypeStruct): + raise ValueError( + f"{func} must return arrays or scalars, not {type(val)}" + ) @abstractmethod def vector_field( @@ -199,10 +200,12 @@ def output( @property def n_outputs(self) -> int | Literal["scalar"]: """The size of the output vector.""" - x = self.initial_state - u = jax.ShapeDtypeStruct(dim2shape(self.n_inputs), jnp.float64) - y = eqx.filter_eval_shape(self.output, x, u, t=1.0) - return "scalar" if y.ndim == 0 else y.shape[0] + with jax.ensure_compile_time_eval(): + x = jax.ShapeDtypeStruct(self.initial_state.shape, jnp.float64) + u = jax.ShapeDtypeStruct(dim2shape(self.n_inputs), jnp.float64) + y = eqx.filter_eval_shape(self.output, x, u, t=1.0) + n_out = "scalar" if y.ndim == 0 else y.shape[0] + return n_out def linearize( self, @@ -309,24 +312,25 @@ class LinearSystem(AbstractControlAffine): """Feedthrough matrix.""" def __post_init__(self): - self.initial_state = ( - jnp.array(0) if self.A.ndim == 0 else jnp.zeros(self.A.shape[0]) - ) - - if self.initial_state.ndim == 0: - if self.B.ndim == 0: - self.n_inputs = "scalar" - elif self.B.ndim == 1: - self.n_inputs = self.B.size - else: - raise ValueError("Dimension mismatch.") - else: - if self.B.ndim == 1: - self.n_inputs = "scalar" - elif self.B.ndim == 2: - self.n_inputs = self.B.shape[1] + # Without this context manager, `initial_state` will leak later + with jax.ensure_compile_time_eval(): + self.initial_state = ( + jnp.array(0) if self.A.ndim == 0 else jnp.zeros(self.A.shape[0]) + ) + if self.initial_state.ndim == 0: + if self.B.ndim == 0: + self.n_inputs = "scalar" + elif self.B.ndim == 1: + self.n_inputs = self.B.size + else: + raise ValueError("Dimension mismatch.") else: - raise ValueError("Dimension mismatch.") + if self.B.ndim == 1: + self.n_inputs = "scalar" + elif self.B.ndim == 2: + self.n_inputs = self.B.shape[1] + else: + raise ValueError("Dimension mismatch.") def f(self, x: Array) -> Array: return self.A.dot(x)