diff --git a/aiogram_bot_template/bot.py b/aiogram_bot_template/bot.py index e3efd8f6..4a4760b8 100644 --- a/aiogram_bot_template/bot.py +++ b/aiogram_bot_template/bot.py @@ -12,112 +12,18 @@ from aiogram.fsm.storage.redis import DefaultKeyBuilder, RedisStorage from aiohttp import web from redis.asyncio import Redis -from tenacity import _utils from aiogram_bot_template import handlers, utils, web_handlers from aiogram_bot_template.data import config from aiogram_bot_template.middlewares import StructLoggingMiddleware -TIMEOUT_BETWEEN_ATTEMPTS = 2 -MAX_TIMEOUT = 30 - - -def before_log(retry_state: tenacity.RetryCallState) -> None: - if retry_state.outcome is None: - return - if retry_state.outcome.failed: - verb, value = "raised", retry_state.outcome.exception() - else: - verb, value = "returned", retry_state.outcome.result() - logger = retry_state.kwargs["logger"] - logger.info( - "Retrying {callback} in {sleep} seconds as it {verb} {value}".format( - callback=_utils.get_callback_name(retry_state.fn), # type: ignore - sleep=retry_state.next_action.sleep, # type: ignore - verb=verb, - value=value, - ), - callback=_utils.get_callback_name(retry_state.fn), # type: ignore - sleep=retry_state.next_action.sleep, # type: ignore - verb=verb, - value=value, - ) - - -def after_log(retry_state: tenacity.RetryCallState) -> None: - logger = retry_state.kwargs["logger"] - logger.info( - "Finished call to {callback!r} after {time:.2f}, this was the {attempt} time calling it.".format( - callback=_utils.get_callback_name(retry_state.fn), # type: ignore - time=retry_state.seconds_since_start, - attempt=_utils.to_ordinal(retry_state.attempt_number), - ), - callback=_utils.get_callback_name(retry_state.fn), # type: ignore - time=retry_state.seconds_since_start, - attempt=_utils.to_ordinal(retry_state.attempt_number), - ) - - -@tenacity.retry( - wait=tenacity.wait_fixed(TIMEOUT_BETWEEN_ATTEMPTS), - stop=tenacity.stop_after_delay(MAX_TIMEOUT), - before_sleep=before_log, - after=after_log, -) -async def wait_postgres( - logger: structlog.typing.FilteringBoundLogger, - host: str, - port: int, - user: str, - password: str, - database: str, -) -> asyncpg.Pool: - db_pool: asyncpg.Pool = await asyncpg.create_pool( - host=host, - port=port, - user=user, - password=password, - database=database, - min_size=1, - max_size=3, - ) - version = await db_pool.fetchrow("SELECT version() as ver;") - logger.debug("Connected to PostgreSQL.", version=version["ver"]) - return db_pool - - -@tenacity.retry( - wait=tenacity.wait_fixed(TIMEOUT_BETWEEN_ATTEMPTS), - stop=tenacity.stop_after_delay(MAX_TIMEOUT), - before_sleep=before_log, - after=after_log, -) -async def wait_redis_pool( - logger: structlog.typing.FilteringBoundLogger, - host: str, - port: int, - password: str, - database: int, -) -> redis.asyncio.Redis: # type: ignore[type-arg] - redis_pool: redis.asyncio.Redis = redis.asyncio.Redis( # type: ignore[type-arg] - connection_pool=redis.asyncio.ConnectionPool( - host=host, - port=port, - password=password, - db=database, - ) - ) - version = await redis_pool.info("server") - logger.debug("Connected to Redis.", version=version["redis_version"]) - return redis_pool - async def create_db_connections(dp: Dispatcher) -> None: logger: structlog.typing.FilteringBoundLogger = dp["business_logger"] logger.debug("Connecting to PostgreSQL", db="main") try: - db_pool = await wait_postgres( + db_pool = await utils.connect_to_services.wait_postgres( logger=dp["db_logger"], host=config.PG_HOST, port=config.PG_PORT, @@ -132,31 +38,36 @@ async def create_db_connections(dp: Dispatcher) -> None: logger.debug("Succesfully connected to PostgreSQL", db="main") dp["db_pool"] = db_pool - logger.debug("Connecting to Redis") - try: - redis_pool = await wait_redis_pool( - logger=dp["cache_logger"], - host=config.CACHE_HOST, - password=config.CACHE_PASSWORD, - port=config.CACHE_PORT, - database=0, - ) - except tenacity.RetryError: - logger.error("Failed to connect to Redis") - exit(1) - else: - logger.debug("Succesfully connected to Redis") - dp["cache_pool"] = redis_pool - - dp["temp_bot_cloud_session"] = AiohttpSession(json_loads=orjson.loads) + if config.USE_CACHE: + logger.debug("Connecting to Redis") + try: + redis_pool = await utils.connect_to_services.wait_redis_pool( + logger=dp["cache_logger"], + host=config.CACHE_HOST, + password=config.CACHE_PASSWORD, + port=config.CACHE_PORT, + database=0, + ) + except tenacity.RetryError: + logger.error("Failed to connect to Redis") + exit(1) + else: + logger.debug("Succesfully connected to Redis") + dp["cache_pool"] = redis_pool + + dp["temp_bot_cloud_session"] = utils.smart_session.SmartAiogramAiohttpSession( + json_loads=orjson.loads, + logger=dp["aiogram_session_logger"], + ) if config.USE_CUSTOM_API_SERVER: - dp["temp_bot_local_session"] = AiohttpSession( + dp["temp_bot_local_session"] = utils.smart_session.SmartAiogramAiohttpSession( api=TelegramAPIServer( base=config.CUSTOM_API_SERVER_BASE, file=config.CUSTOM_API_SERVER_FILE, is_local=config.CUSTOM_API_SERVER_IS_LOCAL, ), json_loads=orjson.loads, + logger=dp["aiogram_session_logger"], ) @@ -283,17 +194,23 @@ async def setup_aiohttp_app(bot: Bot, dp: Dispatcher) -> web.Application: def main() -> None: + aiogram_session_logger = utils.logging.setup_logger().bind(type="aiogram_session") + if config.USE_CUSTOM_API_SERVER: - session = AiohttpSession( + session = utils.smart_session.SmartAiogramAiohttpSession( api=TelegramAPIServer( base=config.CUSTOM_API_SERVER_BASE, file=config.CUSTOM_API_SERVER_FILE, is_local=config.CUSTOM_API_SERVER_IS_LOCAL, ), json_loads=orjson.loads, + logger=aiogram_session_logger, ) else: - session = AiohttpSession(json_loads=orjson.loads) + session = utils.smart_session.SmartAiogramAiohttpSession( + json_loads=orjson.loads, + logger=aiogram_session_logger, + ) bot = Bot(config.BOT_TOKEN, parse_mode="HTML", session=session) dp = Dispatcher( @@ -307,6 +224,7 @@ def main() -> None: key_builder=DefaultKeyBuilder(with_bot_id=True), ) ) + dp["aiogram_session_logger"] = aiogram_session_logger if config.USE_WEBHOOK: dp.startup.register(aiogram_on_startup_webhook) diff --git a/aiogram_bot_template/data/config.py b/aiogram_bot_template/data/config.py index b0dafff9..bfb27bab 100644 --- a/aiogram_bot_template/data/config.py +++ b/aiogram_bot_template/data/config.py @@ -21,9 +21,12 @@ FSM_PORT: int = env.int("FSM_PORT") FSM_PASSWORD: str = env.str("FSM_PASSWORD") -CACHE_HOST: str = env.str("CACHE_HOST") -CACHE_PORT: int = env.int("CACHE_PORT") -CACHE_PASSWORD: str = env.str("CACHE_PASSWORD") +USE_CACHE: bool = env.bool("USE_CACHE") + +if USE_CACHE: + CACHE_HOST: str = env.str("CACHE_HOST") + CACHE_PORT: int = env.int("CACHE_PORT") + CACHE_PASSWORD: str = env.str("CACHE_PASSWORD") USE_WEBHOOK: bool = env.bool("USE_WEBHOOK", False) diff --git a/aiogram_bot_template/db/db_api/storages/basestorage/storage.py b/aiogram_bot_template/db/db_api/storages/basestorage/storage.py index 3581db2f..295ef54e 100644 --- a/aiogram_bot_template/db/db_api/storages/basestorage/storage.py +++ b/aiogram_bot_template/db/db_api/storages/basestorage/storage.py @@ -1,26 +1,54 @@ -from typing import Any, Optional, Type, TypeVar +import typing +from typing import Any, Optional, TypeVar T = TypeVar("T") +class SingleQueryResult: + def __init__(self, result: Optional[typing.Mapping[str, Any]]): + self._data = {**result} if result else None + + @property + def data(self) -> Optional[dict[str, Any]]: + return self._data + + def convert(self, model: type[T]) -> Optional[T]: + return model(**self.data) if self._data else None + + +class MultipleQueryResults: + def __init__(self, results: list[typing.Mapping[str, Any]]): + self._data: list[dict[str, Any]] = [{**i} for i in results] + + @property + def data(self) -> list[dict[str, Any]]: + return self._data + + def convert(self, model: type[T]) -> list[T]: + return [model(**i) for i in self._data] + + class RawConnection: async def _fetch( self, sql: str, - params: Optional[tuple[Any, ...] | list[tuple[Any, ...]]], - model_type: Type[T], - ) -> Optional[list[T]]: + params: Optional[tuple[Any, ...] | list[tuple[Any, ...]]] = None, + con: Optional[Any] = None, + ) -> MultipleQueryResults: raise NotImplementedError async def _fetchrow( self, sql: str, - params: Optional[tuple[Any, ...] | list[tuple[Any, ...]]], - model_type: Type[T], - ) -> Optional[T]: + params: Optional[tuple[Any, ...] | list[tuple[Any, ...]]] = None, + con: Optional[Any] = None, + ) -> SingleQueryResult: raise NotImplementedError async def _execute( - self, sql: str, params: Optional[tuple[Any, ...] | list[tuple[Any, ...]]] + self, + sql: str, + params: Optional[tuple[Any, ...] | list[tuple[Any, ...]]] = None, + con: Optional[Any] = None, ) -> None: raise NotImplementedError diff --git a/aiogram_bot_template/db/db_api/storages/postgres/storage.py b/aiogram_bot_template/db/db_api/storages/postgres/storage.py index 0990cb2e..48a12c76 100644 --- a/aiogram_bot_template/db/db_api/storages/postgres/storage.py +++ b/aiogram_bot_template/db/db_api/storages/postgres/storage.py @@ -1,9 +1,10 @@ -from typing import Any, Optional, Type, TypeVar +import time +from typing import Any, Optional, TypeVar import asyncpg import structlog -from ..basestorage.storage import RawConnection +from ..basestorage.storage import MultipleQueryResults, RawConnection, SingleQueryResult T = TypeVar("T") @@ -21,61 +22,94 @@ def __init__( async def _fetch( self, sql: str, - params: Optional[tuple[Any, ...] | list[tuple[Any, ...]]], - model_type: Type[T], - ) -> list[T]: + params: Optional[tuple[Any, ...] | list[tuple[Any, ...]]] = None, + con: Optional[asyncpg.Connection] = None, + ) -> MultipleQueryResults: + st = time.monotonic() request_logger = self._logger.bind(sql=sql, params=params) request_logger.debug("Making query to DB") - con: asyncpg.Connection - async with self._pool.acquire() as con: - try: - if params is not None: - raw = await con.fetch(sql, *params) - else: - raw = await con.fetch(sql) - except Exception as e: - # change to appropriate error handling - request_logger = request_logger.bind(error=e) - request_logger.error(f"{e}") + try: + if con is None: + async with self._pool.acquire() as con: + if params is not None: + raw_result = await con.fetch(sql, *params) + else: + raw_result = await con.fetch(sql) else: - if raw: - return [_convert_to_model(i, model_type) for i in raw] + if params is not None: + raw_result = await con.fetch(sql, *params) else: - return [] - return [] + raw_result = await con.fetch(sql) + except Exception as e: + # change to appropriate error handling + request_logger = request_logger.bind(error=e) + request_logger.error(f"Error while making query: {e}") + raise e + else: + results = [i for i in raw_result] + finally: + request_logger.debug( + "Finished query to DB", spent_time_ms=(time.monotonic() - st) * 1000 + ) + + return MultipleQueryResults(results) async def _fetchrow( self, sql: str, - params: Optional[tuple[Any, ...] | list[tuple[Any, ...]]], - model_type: Type[T], - ) -> Optional[T]: + params: Optional[tuple[Any, ...] | list[tuple[Any, ...]]] = None, + con: Optional[asyncpg.Connection] = None, + ) -> SingleQueryResult: + st = time.monotonic() request_logger = self._logger.bind(sql=sql, params=params) request_logger.debug("Making query to DB") - con: asyncpg.Connection - async with self._pool.acquire() as con: - try: + + try: + if con is None: + async with self._pool.acquire() as con: + if params is not None: + raw = await con.fetchrow(sql, *params) + else: + raw = await con.fetchrow(sql) + else: if params is not None: raw = await con.fetchrow(sql, *params) else: raw = await con.fetchrow(sql) - except Exception as e: - # change to appropriate error handling - request_logger = self._logger.bind(error=e) - request_logger.error(f"{e}") - else: - if raw is not None: - return _convert_to_model(raw, model_type) - return None + except Exception as e: + # change to appropriate error handling + request_logger = request_logger.bind(error=e) + request_logger.error(f"Error while making query: {e}") + raise e + else: + result = raw + finally: + request_logger.debug( + "Finished query to DB", spent_time_ms=(time.monotonic() - st) * 1000 + ) + + return SingleQueryResult(result) async def _execute( - self, sql: str, params: Optional[tuple[Any, ...] | list[tuple[Any, ...]]] + self, + sql: str, + params: Optional[tuple[Any, ...] | list[tuple[Any, ...]]] = None, + con: Optional[asyncpg.Connection] = None, ) -> None: + st = time.monotonic() request_logger = self._logger.bind(sql=sql, params=params) request_logger.debug("Making query to DB") - con: asyncpg.Connection - async with self._pool.acquire() as con: - try: + try: + if con is None: + async with self._pool.acquire() as con: + if params is not None: + if isinstance(params, list): + await con.executemany(sql, params) + else: + await con.execute(sql, *params) + else: + await con.execute(sql) + else: if params is not None: if isinstance(params, list): await con.executemany(sql, params) @@ -83,11 +117,12 @@ async def _execute( await con.execute(sql, *params) else: await con.execute(sql) - except Exception as e: - # change to appropriate error handling - request_logger = request_logger.bind(error=e) - request_logger.error(f"{e}") - - -def _convert_to_model(data: asyncpg.Record, model: Type[T]) -> T: - return model(**data) + except Exception as e: + # change to appropriate error handling + request_logger = self._logger.bind(error=e) + request_logger.error(f"Error while making query: {e}") + raise e + finally: + request_logger.debug( + "Finished query to DB", spent_time_ms=(time.monotonic() - st) * 1000 + ) diff --git a/aiogram_bot_template/utils/__init__.py b/aiogram_bot_template/utils/__init__.py index ee972c5f..4f86f5db 100644 --- a/aiogram_bot_template/utils/__init__.py +++ b/aiogram_bot_template/utils/__init__.py @@ -1,2 +1,4 @@ from . import chunks as chunks +from . import connect_to_services as connect_to_services from . import logging as logging +from . import smart_session as smart_session diff --git a/aiogram_bot_template/utils/connect_to_services.py b/aiogram_bot_template/utils/connect_to_services.py new file mode 100644 index 00000000..399b49e5 --- /dev/null +++ b/aiogram_bot_template/utils/connect_to_services.py @@ -0,0 +1,99 @@ +import asyncpg +import redis +import structlog +import tenacity +from redis.asyncio import ConnectionPool, Redis +from tenacity import _utils + +TIMEOUT_BETWEEN_ATTEMPTS = 2 +MAX_TIMEOUT = 30 + + +def before_log(retry_state: tenacity.RetryCallState) -> None: + if retry_state.outcome is None: + return + if retry_state.outcome.failed: + verb, value = "raised", retry_state.outcome.exception() + else: + verb, value = "returned", retry_state.outcome.result() + logger = retry_state.kwargs["logger"] + logger.info( + "Retrying {callback} in {sleep} seconds as it {verb} {value}".format( + callback=_utils.get_callback_name(retry_state.fn), # type: ignore[arg-type] + sleep=retry_state.next_action.sleep, # type: ignore[union-attr] + verb=verb, + value=value, + ), + callback=_utils.get_callback_name(retry_state.fn), # type: ignore[arg-type] + sleep=retry_state.next_action.sleep, # type: ignore[union-attr] + verb=verb, + value=value, + ) + + +def after_log(retry_state: tenacity.RetryCallState) -> None: + logger = retry_state.kwargs["logger"] + logger.info( + "Finished call to {callback!r} after {time:.2f}, this was the {attempt} time calling it.".format( # type: ignore[str-format] + callback=_utils.get_callback_name(retry_state.fn), # type: ignore[arg-type] + time=retry_state.seconds_since_start, + attempt=_utils.to_ordinal(retry_state.attempt_number), + ), + callback=_utils.get_callback_name(retry_state.fn), # type: ignore[arg-type] + time=retry_state.seconds_since_start, + attempt=_utils.to_ordinal(retry_state.attempt_number), + ) + + +@tenacity.retry( + wait=tenacity.wait_fixed(TIMEOUT_BETWEEN_ATTEMPTS), + stop=tenacity.stop_after_delay(MAX_TIMEOUT), + before_sleep=before_log, + after=after_log, +) +async def wait_postgres( + logger: structlog.typing.FilteringBoundLogger, + host: str, + port: int, + user: str, + password: str, + database: str, +) -> asyncpg.Pool: + db_pool = await asyncpg.create_pool( + host=host, + port=port, + user=user, + password=password, + database=database, + min_size=1, + max_size=3, + ) + version = await db_pool.fetchrow("SELECT version() as ver;") + logger.debug("Connected to PostgreSQL.", version=version["ver"]) + return db_pool + + +@tenacity.retry( + wait=tenacity.wait_fixed(TIMEOUT_BETWEEN_ATTEMPTS), + stop=tenacity.stop_after_delay(MAX_TIMEOUT), + before_sleep=before_log, + after=after_log, +) +async def wait_redis_pool( + logger: structlog.typing.FilteringBoundLogger, + host: str, + port: int, + password: str, + database: int, +) -> redis.asyncio.Redis: # type: ignore[type-arg] + redis_pool: redis.asyncio.Redis = Redis( # type: ignore[type-arg] + connection_pool=ConnectionPool( + host=host, + port=port, + password=password, + db=database, + ) + ) + version = await redis_pool.info("server") + logger.debug("Connected to Redis.", version=version["redis_version"]) + return redis_pool diff --git a/aiogram_bot_template/utils/smart_session.py b/aiogram_bot_template/utils/smart_session.py new file mode 100644 index 00000000..f0a913bd --- /dev/null +++ b/aiogram_bot_template/utils/smart_session.py @@ -0,0 +1,80 @@ +import asyncio +import time +from typing import Any, Optional + +import structlog.typing +from aiogram import Bot +from aiogram.client.session.aiohttp import AiohttpSession +from aiogram.exceptions import ( + RestartingTelegram, + TelegramRetryAfter, + TelegramServerError, +) +from aiogram.methods.base import TelegramMethod, TelegramType + + +class StructLogAiogramAiohttpSessions(AiohttpSession): + def __init__(self, logger: structlog.typing.FilteringBoundLogger, **kwargs: Any): + super().__init__(**kwargs) + self._logger = logger + + async def make_request( + self, + bot: Bot, + method: TelegramMethod[TelegramType], + timeout: Optional[int] = None, + ) -> TelegramType: + req_logger = self._logger.bind( + bot=bot.token, + method=method.model_dump(exclude_none=True, exclude_unset=True), + timeout=timeout, + api=self.api, + url=self.api.api_url(bot.token, method.__api_method__), + ) + st = time.monotonic() + req_logger.debug("Making request to API") + try: + res = await super().make_request(bot, method, timeout) + except Exception as e: + req_logger.error( + "API error", + error=e, + time_spent_ms=(time.monotonic() - st) * 1000, + ) + raise e + req_logger.debug( + "API response", + response=( + res.model_dump(exclude_none=True, exclude_unset=True) + if hasattr(res, "model_dump") + else res + ), + time_spent_ms=(time.monotonic() - st) * 1000, + ) + return res + + +class SmartAiogramAiohttpSession(StructLogAiogramAiohttpSessions): + async def make_request( + self, + bot: Bot, + method: TelegramMethod[TelegramType], + timeout: Optional[int] = None, + ) -> TelegramType: + attempt = 0 + while True: + attempt += 1 + try: + res = await super().make_request(bot, method, timeout) + except TelegramRetryAfter as e: + await asyncio.sleep(e.retry_after) + except (RestartingTelegram, TelegramServerError): + if attempt > 6: + sleepy_time = 64 + else: + sleepy_time = 2**attempt + await asyncio.sleep(sleepy_time) + except Exception as e: + raise e + else: + return res diff --git a/poetry.lock b/poetry.lock index 0c1a891e..7864358c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1121,6 +1121,17 @@ files = [ [package.extras] cli = ["click (>=5.0)"] +[[package]] +name = "pytz" +version = "2023.3.post1" +description = "World timezone definitions, modern and historical" +optional = false +python-versions = "*" +files = [ + {file = "pytz-2023.3.post1-py2.py3-none-any.whl", hash = "sha256:ce42d816b81b68506614c11e8937d3aa9e41007ceb50bfdcb0749b921bf646c7"}, + {file = "pytz-2023.3.post1.tar.gz", hash = "sha256:7b4fddbeb94a1eba4b557da24f19fdf9db575192544270a9101d8509f9f43d7b"}, +] + [[package]] name = "redis" version = "5.0.0" @@ -1221,6 +1232,17 @@ files = [ [package.dependencies] cryptography = ">=35.0.0" +[[package]] +name = "types-pytz" +version = "2023.3.0.1" +description = "Typing stubs for pytz" +optional = false +python-versions = "*" +files = [ + {file = "types-pytz-2023.3.0.1.tar.gz", hash = "sha256:1a7b8d4aac70981cfa24478a41eadfcd96a087c986d6f150d77e3ceb3c2bdfab"}, + {file = "types_pytz-2023.3.0.1-py3-none-any.whl", hash = "sha256:65152e872137926bb67a8fe6cc9cfd794365df86650c5d5fdc7b167b0f38892e"}, +] + [[package]] name = "types-redis" version = "4.6.0.6" @@ -1337,4 +1359,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "3732d81a70a3c3418679051fb1029a459983f1aa43796a914f6e78620b1cf9ca" +content-hash = "ec7da3e029a569ebe899bd5e808f882e13008fcf7bd5ae276a22aa6523be56e0" diff --git a/pyproject.toml b/pyproject.toml index 9a39c485..bc2ec9ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,7 @@ redis = "5.0.0" orjson = "3.9.7" structlog = "^23.1.0" tenacity = "8.2.3" +pytz = "^2023.3.post1" [tool.poetry.group.dev] optional = true @@ -76,6 +77,7 @@ black = { extras = ["d"], version = "^23.3.0" } ruff = "^0.0.267" types-redis = "^4.5.5.2" isort= "5.12.0" +types-pytz = "^2023.3.0.1"