Skip to content

Commit

Permalink
TypeVar -> Any
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz committed Jan 3, 2025
1 parent de835c6 commit e24a7a2
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions numpyro/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from collections import namedtuple
from collections.abc import Callable
from typing import Any, Generic, TypeVar
from typing import Any

import jax
from jax import jacfwd, lax, value_and_grad
Expand All @@ -32,8 +32,8 @@
"SM3",
]

_Params = TypeVar("_Params")
_OptState = TypeVar("_OptState")
_Params = Any
_OptState = Any
_IterOptState = tuple[ArrayLike, _OptState]


Expand All @@ -50,7 +50,7 @@ def _wrapper(x):
return value_and_grad(f, has_aux=True)(x)


class _NumPyroOptim(Generic[_Params, _OptState]):
class _NumPyroOptim(object):
def __init__(self, optim_fn: Callable, *args, **kwargs) -> None:
self.init_fn: Callable[[_Params], _IterOptState]
self.update_fn: Callable[[ArrayLike, _Params, _OptState], _OptState]
Expand Down Expand Up @@ -258,7 +258,7 @@ def update_fn(
# we don't use update_fn in Minimize, so let it do nothing
return opt_state

def get_params(opt_state: _MinimizeState) -> _Params: # type: ignore[type-var]
def get_params(opt_state: _MinimizeState) -> _Params:
flat_params, unravel_fn = opt_state
return unravel_fn(flat_params)

Expand Down

0 comments on commit e24a7a2

Please sign in to comment.