Skip to content

Commit

Permalink
Merge pull request #80 from uriyyo/feature/caching
Browse files Browse the repository at this point in the history
🎉 Add cached decorator to prevent multiple calls to same endpoints
  • Loading branch information
uriyyo authored Oct 9, 2020
2 parents 699516d + 522f55b commit 4d59e29
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 35 deletions.
60 changes: 58 additions & 2 deletions instapi/cache.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from collections import deque
from contextvars import ContextVar
from dataclasses import dataclass
from functools import partial, wraps
from hashlib import md5
from itertools import chain
from pathlib import Path
from typing import Optional
from time import time
from typing import Any, Callable, Deque, Dict, Optional, Tuple, TypeVar

from instagram_private_api.http import ClientCookieJar

Expand Down Expand Up @@ -39,4 +44,55 @@ def write_to_cache(credentials: Credentials, cookie: ClientCookieJar) -> None:
cache.write_bytes(cookie.dump())


__all__ = ["get_from_cache", "write_to_cache"]
CACHED_TIME = ContextVar("CACHED_TIME", default=60)
CacheKey = Tuple[Tuple, Tuple]

T = TypeVar("T")


@dataclass
class _CacheInfo:
cache: Dict[CacheKey, Any]
keys: Deque[Tuple[CacheKey, float]]


def cached(func: Callable[..., T]) -> Callable[..., T]:
cache: Dict[CacheKey, Any] = {}
keys: Deque[Tuple[CacheKey, float]] = deque()

def _delete_expired_keys() -> None: # pragma: no cover
while keys:
key, expired = keys[0]

if expired > time():
break

keys.popleft()
del cache[key]

def _add_key(key: CacheKey) -> None:
keys.append((key, time() + CACHED_TIME.get()))

@wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
_delete_expired_keys()

key: CacheKey = (args, tuple(kwargs.items()))

if key not in cache:
cache[key] = func(*args, **kwargs)
_add_key(key)

return cache[key]

wrapper.info: Callable[..., _CacheInfo] = partial(_CacheInfo, cache, keys) # type: ignore

return wrapper


__all__ = [
"CACHED_TIME",
"cached",
"get_from_cache",
"write_to_cache",
]
21 changes: 6 additions & 15 deletions instapi/models/direct.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type
from typing import Any, Iterable, List, Optional, Tuple, Type

from ..cache import cached
from ..client import client
from ..types import StrDict
from ..utils import process_many, to_list
Expand All @@ -19,18 +20,9 @@ class Message(BaseModel):
story_share: StrDict = field(default_factory=dict)

@classmethod
def create(cls: Type[ModelT_co], data: StrDict, cache: Dict[int, User] = None) -> ModelT_co:
def create(cls: Type[ModelT_co], data: StrDict) -> ModelT_co:
user_id = data.pop("user_id")

if cache is not None and user_id in cache:
user = cache[user_id]
else:
user = User.get(user_id)

if cache is not None:
cache[user_id] = user

return super().create({"user": user, **data}) # type: ignore
return super().create({"user": User.get(user_id), **data}) # type: ignore

def as_dict(self) -> StrDict:
data = super().as_dict()
Expand Down Expand Up @@ -68,6 +60,7 @@ def directs(cls, limit: Optional[int] = None) -> List["Direct"]:
return to_list(cls.iter_directs(), limit)

@classmethod
@cached
def with_user(cls: Type[ModelT_co], user: User) -> ModelT_co:
result = client.direct_v2_get_by_participants(user)

Expand All @@ -77,15 +70,13 @@ def with_user(cls: Type[ModelT_co], user: User) -> ModelT_co:
return cls(user.username, "private", False, (user,)) # type: ignore

def iter_message(self) -> Iterable["Message"]:
cache: Dict[int, User] = {}

for response in process_many(
client.direct_v2_thread,
self.thread_id,
key="cursor",
key_path="thread.oldest_cursor",
):
yield from (Message.create(item, cache) for item in response["thread"]["items"])
yield from map(Message.create, response["thread"]["items"])

def messages(self, limit: Optional[int] = None) -> List["Message"]:
return to_list(self.iter_message(), limit)
Expand Down
2 changes: 2 additions & 0 deletions instapi/models/feed.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from typing import Iterable, List, Optional

from ..cache import cached
from ..client import client
from ..types import StrDict
from ..utils import process_many, to_list
Expand Down Expand Up @@ -55,6 +56,7 @@ def _resources(self) -> Iterable["StrDict"]:
media_info = self._media_info()
return media_info.get("carousel_media", [media_info])

@cached
def user_tags(self) -> List["User"]:
"""
Generate list of Users from Feed usertags
Expand Down
5 changes: 5 additions & 0 deletions instapi/models/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from itertools import chain
from typing import TYPE_CHECKING, Counter, Iterable, List, Optional, cast

from ..cache import cached
from ..client import client
from ..types import StrDict
from ..utils import process_many, to_list
Expand All @@ -22,6 +23,7 @@ class User(Entity):
is_verified: bool

@classmethod
@cached
def get(cls, pk: int) -> "User":
"""
Create User object from unique user's identifier
Expand All @@ -32,6 +34,7 @@ def get(cls, pk: int) -> "User":
return cls.create(client.user_info(pk)["user"])

@classmethod
@cached
def from_username(cls, username: str) -> "User":
"""
Create User object from username
Expand All @@ -58,6 +61,7 @@ def match_username(cls, username: str, limit: Optional[int] = None) -> List["Use
return [cls.create(user) for user in response["users"]]

@classmethod
@cached
def self(cls) -> "User":
"""
Create User object from current user
Expand Down Expand Up @@ -105,6 +109,7 @@ def following_count(self) -> int:
def user_detail(self) -> StrDict:
return cast(StrDict, self.full_info()["user_detail"]["user"])

@cached
def full_info(self) -> StrDict:
return cast(StrDict, client.user_detail_info(self.pk))

Expand Down
16 changes: 0 additions & 16 deletions tests/unit_tests/models/test_direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,22 +77,6 @@ def test_direct_with_limit(self, mock_thread, direct, messages):
assert direct.messages(limit) == messages[:limit]


def test_messages_cache(mocker, message):
mocker.patch("instapi.models.user.User.get", return_value=message.user)

data = message.as_dict()
cache = {}

m1 = Message.create({**data}, cache)

assert data["user_id"] in cache

m2 = Message.create({**data}, cache)

# Second call should use user from a cache
assert m1.user is m2.user


class TestWithUser:
def test_with_user_thread_exists(self, mocker, direct, user):
mocker.patch(
Expand Down
3 changes: 1 addition & 2 deletions tests/unit_tests/models/test_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,7 @@ def test_user_details(user, mocker):
assert user.user_detail() == user_details
assert user.full_info() == full_info

details_mock.assert_called_with(user.pk)
assert details_mock.call_count == 6
details_mock.assert_called_once_with(user.pk)


def test_follow(user, mocker):
Expand Down

0 comments on commit 4d59e29

Please sign in to comment.