Skip to content

Commit

Permalink
fix: incorrect type hints on memoize key #958 (#959)
Browse files Browse the repository at this point in the history
* Fixes incorrect type hints on memoize `key` #958

* hashable now comes from collections.abc
  • Loading branch information
ntjess authored Jan 29, 2025
1 parent 5af26ab commit 26ebb83
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions solara/cache.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Hashable
import hashlib
import inspect
import logging
Expand Down Expand Up @@ -50,6 +51,7 @@
class Memory(cachetools.LRUCache):
def __init__(self, max_items=solara.settings.cache.memory_max_items):
super().__init__(maxsize=max_items)

else:

class Memory(dict): # type: ignore
Expand All @@ -64,7 +66,7 @@ def _default_key(*args, **kwargs):


class MemoizedFunction(Generic[P, R]):
def __init__(self, function: Callable[P, R], key: Callable[P, R], storage: Optional[Storage], allow_nonlocals=False):
def __init__(self, function: Callable[P, R], key: Callable[P, Hashable], storage: Optional[Storage], allow_nonlocals=False):
self.function = function
f: Callable = self.function
if not allow_nonlocals:
Expand Down Expand Up @@ -170,7 +172,7 @@ def memoize(
@overload
def memoize(
function: None = None,
key: Callable[P, R] = ...,
key: Callable[P, Hashable] = ...,
storage: Optional[Storage] = None,
allow_nonlocals=False,
) -> Callable[[Callable[P, R]], MemoizedFunction[P, R]]: ...
Expand All @@ -187,7 +189,7 @@ def memoize(

def memoize(
function: Union[None, Callable[P, R]] = None,
key: Union[None, Callable[P, R]] = None,
key: Union[None, Callable[P, Hashable]] = None,
storage: Optional[Storage] = None,
allow_nonlocals: bool = False,
) -> Union[Callable[[Callable[P, R]], MemoizedFunction[P, R]], MemoizedFunction[P, R]]:
Expand Down Expand Up @@ -249,7 +251,7 @@ def mean(df, column):
def wrapper(func: Callable[P, R]) -> MemoizedFunction[P, R]:
return MemoizedFunction[P, R](
func,
cast(Callable[P, R], key or _default_key),
cast(Callable[P, Hashable], key or _default_key),
storage,
allow_nonlocals,
)
Expand Down

0 comments on commit 26ebb83

Please sign in to comment.