Skip to content

Commit

Permalink
fix typing import
Browse files Browse the repository at this point in the history
  • Loading branch information
fhchl committed Apr 15, 2024
1 parent 7c625b5 commit fa585ef
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
14 changes: 7 additions & 7 deletions dynax/custom_types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import typing
from typing import Callable, TypeAlias

import jax.typing
import jaxtyping


Expand All @@ -12,8 +11,8 @@
from jax import Array as Array
from jax.typing import ArrayLike as ArrayLike

Scalar: TypeAlias = jax.Array
ScalarLike: TypeAlias = jax.typing.ArrayLike
Scalar: TypeAlias = Array
ScalarLike: TypeAlias = ArrayLike
elif generating_docs:
# In the docs.
class Scalar:
Expand All @@ -33,10 +32,11 @@ class ArrayLike:
cls.__qualname__ = cls.__name__
else:
# At runtime.
Array = jax.typing.Array
ArrayLike = jax.typing.ArrayLike
Scalar = jaxtyping.Shaped[jax.typing.Array, ""]
ScalarLike = jaxtyping.Shaped[jax.typing.ArrayLike, ""]
from jax import Array
from jax.typing import ArrayLike as ArrayLike

Scalar = jaxtyping.Shaped[Array, ""]
ScalarLike = jaxtyping.Shaped[ArrayLike, ""]


VectorFunc: TypeAlias = Callable[[Array], Array]
Expand Down
4 changes: 2 additions & 2 deletions dynax/estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from .util import broadcast_right, mse, nmse, nrmse, value_and_jacfwd


def _get_bounds(module: eqx.Module) -> tuple[Array, Array]:
def _get_bounds(module: eqx.Module) -> tuple[list[float], list[float]]:
"""Build flattened arrays of lower and upper parameter bounds."""
lower_bounds = []
upper_bounds = []
Expand All @@ -52,7 +52,7 @@ def _get_bounds(module: eqx.Module) -> tuple[Array, Array]:
size = np.asarray(value).size
lower_bounds.extend([-np.inf] * size)
upper_bounds.extend([np.inf] * size)
return jnp.asarray(lower_bounds), jnp.asarray(upper_bounds)
return list(lower_bounds), list(upper_bounds)


def _key_paths(tree: Any, root: str = "tree") -> list[str]:
Expand Down

0 comments on commit fa585ef

Please sign in to comment.