diff --git a/instapi/cache.py b/instapi/cache.py index e5a10d3..a72e383 100644 --- a/instapi/cache.py +++ b/instapi/cache.py @@ -1,7 +1,12 @@ +from collections import deque +from contextvars import ContextVar +from dataclasses import dataclass +from functools import partial, wraps from hashlib import md5 from itertools import chain from pathlib import Path -from typing import Optional +from time import time +from typing import Any, Callable, Deque, Dict, Optional, Tuple, TypeVar from instagram_private_api.http import ClientCookieJar @@ -39,4 +44,55 @@ def write_to_cache(credentials: Credentials, cookie: ClientCookieJar) -> None: cache.write_bytes(cookie.dump()) -__all__ = ["get_from_cache", "write_to_cache"] +CACHED_TIME = ContextVar("CACHED_TIME", default=60) +CacheKey = Tuple[Tuple, Tuple] + +T = TypeVar("T") + + +@dataclass +class _CacheInfo: + cache: Dict[CacheKey, Any] + keys: Deque[Tuple[CacheKey, float]] + + +def cached(func: Callable[..., T]) -> Callable[..., T]: + cache: Dict[CacheKey, Any] = {} + keys: Deque[Tuple[CacheKey, float]] = deque() + + def _delete_expired_keys() -> None: # pragma: no cover + while keys: + key, expired = keys[0] + + if expired > time(): + break + + keys.popleft() + del cache[key] + + def _add_key(key: CacheKey) -> None: + keys.append((key, time() + CACHED_TIME.get())) + + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + _delete_expired_keys() + + key: CacheKey = (args, tuple(kwargs.items())) + + if key not in cache: + cache[key] = func(*args, **kwargs) + _add_key(key) + + return cache[key] + + wrapper.info: Callable[..., _CacheInfo] = partial(_CacheInfo, cache, keys) # type: ignore + + return wrapper + + +__all__ = [ + "CACHED_TIME", + "cached", + "get_from_cache", + "write_to_cache", +] diff --git a/instapi/models/direct.py b/instapi/models/direct.py index b9d48f8..7858026 100644 --- a/instapi/models/direct.py +++ b/instapi/models/direct.py @@ -1,6 +1,7 @@ from dataclasses import dataclass, field -from typing import Any, Dict, Iterable, List, Optional, Tuple, Type +from typing import Any, Iterable, List, Optional, Tuple, Type +from ..cache import cached from ..client import client from ..types import StrDict from ..utils import process_many, to_list @@ -19,18 +20,9 @@ class Message(BaseModel): story_share: StrDict = field(default_factory=dict) @classmethod - def create(cls: Type[ModelT_co], data: StrDict, cache: Dict[int, User] = None) -> ModelT_co: + def create(cls: Type[ModelT_co], data: StrDict) -> ModelT_co: user_id = data.pop("user_id") - - if cache is not None and user_id in cache: - user = cache[user_id] - else: - user = User.get(user_id) - - if cache is not None: - cache[user_id] = user - - return super().create({"user": user, **data}) # type: ignore + return super().create({"user": User.get(user_id), **data}) # type: ignore def as_dict(self) -> StrDict: data = super().as_dict() @@ -68,6 +60,7 @@ def directs(cls, limit: Optional[int] = None) -> List["Direct"]: return to_list(cls.iter_directs(), limit) @classmethod + @cached def with_user(cls: Type[ModelT_co], user: User) -> ModelT_co: result = client.direct_v2_get_by_participants(user) @@ -77,15 +70,13 @@ def with_user(cls: Type[ModelT_co], user: User) -> ModelT_co: return cls(user.username, "private", False, (user,)) # type: ignore def iter_message(self) -> Iterable["Message"]: - cache: Dict[int, User] = {} - for response in process_many( client.direct_v2_thread, self.thread_id, key="cursor", key_path="thread.oldest_cursor", ): - yield from (Message.create(item, cache) for item in response["thread"]["items"]) + yield from map(Message.create, response["thread"]["items"]) def messages(self, limit: Optional[int] = None) -> List["Message"]: return to_list(self.iter_message(), limit) diff --git a/instapi/models/feed.py b/instapi/models/feed.py index 9cf04a3..300bb30 100644 --- a/instapi/models/feed.py +++ b/instapi/models/feed.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from typing import Iterable, List, Optional +from ..cache import cached from ..client import client from ..types import StrDict from ..utils import process_many, to_list @@ -55,6 +56,7 @@ def _resources(self) -> Iterable["StrDict"]: media_info = self._media_info() return media_info.get("carousel_media", [media_info]) + @cached def user_tags(self) -> List["User"]: """ Generate list of Users from Feed usertags diff --git a/instapi/models/user.py b/instapi/models/user.py index 8f1a750..8a41fc7 100644 --- a/instapi/models/user.py +++ b/instapi/models/user.py @@ -3,6 +3,7 @@ from itertools import chain from typing import TYPE_CHECKING, Counter, Iterable, List, Optional, cast +from ..cache import cached from ..client import client from ..types import StrDict from ..utils import process_many, to_list @@ -22,6 +23,7 @@ class User(Entity): is_verified: bool @classmethod + @cached def get(cls, pk: int) -> "User": """ Create User object from unique user's identifier @@ -32,6 +34,7 @@ def get(cls, pk: int) -> "User": return cls.create(client.user_info(pk)["user"]) @classmethod + @cached def from_username(cls, username: str) -> "User": """ Create User object from username @@ -58,6 +61,7 @@ def match_username(cls, username: str, limit: Optional[int] = None) -> List["Use return [cls.create(user) for user in response["users"]] @classmethod + @cached def self(cls) -> "User": """ Create User object from current user @@ -105,6 +109,7 @@ def following_count(self) -> int: def user_detail(self) -> StrDict: return cast(StrDict, self.full_info()["user_detail"]["user"]) + @cached def full_info(self) -> StrDict: return cast(StrDict, client.user_detail_info(self.pk)) diff --git a/tests/unit_tests/models/test_direct.py b/tests/unit_tests/models/test_direct.py index 85335b9..3f9660e 100644 --- a/tests/unit_tests/models/test_direct.py +++ b/tests/unit_tests/models/test_direct.py @@ -77,22 +77,6 @@ def test_direct_with_limit(self, mock_thread, direct, messages): assert direct.messages(limit) == messages[:limit] -def test_messages_cache(mocker, message): - mocker.patch("instapi.models.user.User.get", return_value=message.user) - - data = message.as_dict() - cache = {} - - m1 = Message.create({**data}, cache) - - assert data["user_id"] in cache - - m2 = Message.create({**data}, cache) - - # Second call should use user from a cache - assert m1.user is m2.user - - class TestWithUser: def test_with_user_thread_exists(self, mocker, direct, user): mocker.patch( diff --git a/tests/unit_tests/models/test_user.py b/tests/unit_tests/models/test_user.py index d0f0e39..4a5efef 100644 --- a/tests/unit_tests/models/test_user.py +++ b/tests/unit_tests/models/test_user.py @@ -127,8 +127,7 @@ def test_user_details(user, mocker): assert user.user_detail() == user_details assert user.full_info() == full_info - details_mock.assert_called_with(user.pk) - assert details_mock.call_count == 6 + details_mock.assert_called_once_with(user.pk) def test_follow(user, mocker):