diff --git a/numpyro/optim.py b/numpyro/optim.py index d58f7c739..8d1717b16 100644 --- a/numpyro/optim.py +++ b/numpyro/optim.py @@ -9,7 +9,7 @@ from collections import namedtuple from collections.abc import Callable -from typing import Any, Generic, TypeVar +from typing import Any import jax from jax import jacfwd, lax, value_and_grad @@ -32,8 +32,8 @@ "SM3", ] -_Params = TypeVar("_Params") -_OptState = TypeVar("_OptState") +_Params = Any +_OptState = Any _IterOptState = tuple[ArrayLike, _OptState] @@ -50,7 +50,7 @@ def _wrapper(x): return value_and_grad(f, has_aux=True)(x) -class _NumPyroOptim(Generic[_Params, _OptState]): +class _NumPyroOptim(object): def __init__(self, optim_fn: Callable, *args, **kwargs) -> None: self.init_fn: Callable[[_Params], _IterOptState] self.update_fn: Callable[[ArrayLike, _Params, _OptState], _OptState] @@ -258,7 +258,7 @@ def update_fn( # we don't use update_fn in Minimize, so let it do nothing return opt_state - def get_params(opt_state: _MinimizeState) -> _Params: # type: ignore[type-var] + def get_params(opt_state: _MinimizeState) -> _Params: flat_params, unravel_fn = opt_state return unravel_fn(flat_params)