diff --git a/tatsu_api/client.py b/tatsu_api/client.py index e39d151..c3bfe26 100644 --- a/tatsu_api/client.py +++ b/tatsu_api/client.py @@ -2,9 +2,10 @@ import asyncio import itertools -from types import TracebackType from typing import TYPE_CHECKING, Literal +import aiohttp + from .enums import ActionType from .http import HTTPClient from .models import ( @@ -24,9 +25,11 @@ if TYPE_CHECKING: + from types import TracebackType + from typing_extensions import Self else: - Self = object + TracebackType = Self = object __all__ = ("Client",) @@ -39,10 +42,13 @@ class Client: ---------- token: :class:`str` The Tatsu API key that will be used to authorize all requests to it. + session: :class:`aiohttp.ClientSession`, optional + A web client session to use for connecting to the API. If provided, the library is not responsible for closing + it. If not provided, the client will create one. """ - def __init__(self, token: str) -> None: - self.http = HTTPClient(token) + def __init__(self, token: str, *, session: aiohttp.ClientSession | None = None) -> None: + self.http = HTTPClient(token, session=session) async def __aenter__(self) -> Self: return self @@ -216,7 +222,7 @@ async def get_guild_rankings( rankings_list = [GUILD_RANKINGS_DECODER.decode(result) for result in results] truncated_rankings = tuple( ranking - for ranking in itertools.chain(*[item.rankings for item in rankings_list]) + for ranking in itertools.chain.from_iterable(item.rankings for item in rankings_list) if ranking.rank in range(start + 1, end + 2) ) return GuildRankings(str(guild_id), truncated_rankings)