Skip to content

Commit

Permalink
update dbs to support using transactions
Browse files Browse the repository at this point in the history
added pytz
  • Loading branch information
Forden committed Sep 15, 2023
1 parent 07bcb87 commit 4cba60a
Show file tree
Hide file tree
Showing 9 changed files with 362 additions and 173 deletions.
148 changes: 33 additions & 115 deletions aiogram_bot_template/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,112 +12,18 @@
from aiogram.fsm.storage.redis import DefaultKeyBuilder, RedisStorage
from aiohttp import web
from redis.asyncio import Redis
from tenacity import _utils

from aiogram_bot_template import handlers, utils, web_handlers
from aiogram_bot_template.data import config
from aiogram_bot_template.middlewares import StructLoggingMiddleware

TIMEOUT_BETWEEN_ATTEMPTS = 2
MAX_TIMEOUT = 30


def before_log(retry_state: tenacity.RetryCallState) -> None:
    """Emit a structured log line before tenacity sleeps between attempts."""
    outcome = retry_state.outcome
    if outcome is None:
        return
    if outcome.failed:
        verb, value = "raised", outcome.exception()
    else:
        verb, value = "returned", outcome.result()
    # The retried coroutine receives its logger as a keyword argument.
    logger = retry_state.kwargs["logger"]
    callback_name = _utils.get_callback_name(retry_state.fn)  # type: ignore
    sleep_for = retry_state.next_action.sleep  # type: ignore
    logger.info(
        f"Retrying {callback_name} in {sleep_for} seconds as it {verb} {value}",
        callback=callback_name,
        sleep=sleep_for,
        verb=verb,
        value=value,
    )


def after_log(retry_state: tenacity.RetryCallState) -> None:
    """Emit a structured log line after each completed retry attempt."""
    # The retried coroutine receives its logger as a keyword argument.
    logger = retry_state.kwargs["logger"]
    callback_name = _utils.get_callback_name(retry_state.fn)  # type: ignore
    elapsed = retry_state.seconds_since_start
    attempt = _utils.to_ordinal(retry_state.attempt_number)
    logger.info(
        f"Finished call to {callback_name!r} after {elapsed:.2f}, "
        f"this was the {attempt} time calling it.",
        callback=callback_name,
        time=elapsed,
        attempt=attempt,
    )


@tenacity.retry(
    wait=tenacity.wait_fixed(TIMEOUT_BETWEEN_ATTEMPTS),
    stop=tenacity.stop_after_delay(MAX_TIMEOUT),
    before_sleep=before_log,
    after=after_log,
)
async def wait_postgres(
    logger: structlog.typing.FilteringBoundLogger,
    host: str,
    port: int,
    user: str,
    password: str,
    database: str,
) -> asyncpg.Pool:
    """Create an asyncpg connection pool, retrying until PostgreSQL answers."""
    pool: asyncpg.Pool = await asyncpg.create_pool(
        host=host,
        port=port,
        user=user,
        password=password,
        database=database,
        min_size=1,
        max_size=3,
    )
    # Probe with a real query so a half-started server still triggers a retry.
    row = await pool.fetchrow("SELECT version() as ver;")
    logger.debug("Connected to PostgreSQL.", version=row["ver"])
    return pool


@tenacity.retry(
    wait=tenacity.wait_fixed(TIMEOUT_BETWEEN_ATTEMPTS),
    stop=tenacity.stop_after_delay(MAX_TIMEOUT),
    before_sleep=before_log,
    after=after_log,
)
async def wait_redis_pool(
    logger: structlog.typing.FilteringBoundLogger,
    host: str,
    port: int,
    password: str,
    database: int,
) -> redis.asyncio.Redis:  # type: ignore[type-arg]
    """Create a pooled Redis client, retrying until the server answers."""
    connection_pool = redis.asyncio.ConnectionPool(
        host=host,
        port=port,
        password=password,
        db=database,
    )
    client: redis.asyncio.Redis = redis.asyncio.Redis(  # type: ignore[type-arg]
        connection_pool=connection_pool
    )
    # INFO round-trips to the server, so connectivity is actually verified.
    server_info = await client.info("server")
    logger.debug("Connected to Redis.", version=server_info["redis_version"])
    return client


async def create_db_connections(dp: Dispatcher) -> None:
logger: structlog.typing.FilteringBoundLogger = dp["business_logger"]

logger.debug("Connecting to PostgreSQL", db="main")
try:
db_pool = await wait_postgres(
db_pool = await utils.connect_to_services.wait_postgres(
logger=dp["db_logger"],
host=config.PG_HOST,
port=config.PG_PORT,
Expand All @@ -132,31 +38,36 @@ async def create_db_connections(dp: Dispatcher) -> None:
logger.debug("Succesfully connected to PostgreSQL", db="main")
dp["db_pool"] = db_pool

logger.debug("Connecting to Redis")
try:
redis_pool = await wait_redis_pool(
logger=dp["cache_logger"],
host=config.CACHE_HOST,
password=config.CACHE_PASSWORD,
port=config.CACHE_PORT,
database=0,
)
except tenacity.RetryError:
logger.error("Failed to connect to Redis")
exit(1)
else:
logger.debug("Succesfully connected to Redis")
dp["cache_pool"] = redis_pool

dp["temp_bot_cloud_session"] = AiohttpSession(json_loads=orjson.loads)
if config.USE_CACHE:
logger.debug("Connecting to Redis")
try:
redis_pool = await utils.connect_to_services.wait_redis_pool(
logger=dp["cache_logger"],
host=config.CACHE_HOST,
password=config.CACHE_PASSWORD,
port=config.CACHE_PORT,
database=0,
)
except tenacity.RetryError:
logger.error("Failed to connect to Redis")
exit(1)
else:
logger.debug("Succesfully connected to Redis")
dp["cache_pool"] = redis_pool

dp["temp_bot_cloud_session"] = utils.smart_session.SmartAiogramAiohttpSession(
json_loads=orjson.loads,
logger=dp["aiogram_session_logger"],
)
if config.USE_CUSTOM_API_SERVER:
dp["temp_bot_local_session"] = AiohttpSession(
dp["temp_bot_local_session"] = utils.smart_session.SmartAiogramAiohttpSession(
api=TelegramAPIServer(
base=config.CUSTOM_API_SERVER_BASE,
file=config.CUSTOM_API_SERVER_FILE,
is_local=config.CUSTOM_API_SERVER_IS_LOCAL,
),
json_loads=orjson.loads,
logger=dp["aiogram_session_logger"],
)


Expand Down Expand Up @@ -283,17 +194,23 @@ async def setup_aiohttp_app(bot: Bot, dp: Dispatcher) -> web.Application:


def main() -> None:
aiogram_session_logger = utils.logging.setup_logger().bind(type="aiogram_session")

if config.USE_CUSTOM_API_SERVER:
session = AiohttpSession(
session = utils.smart_session.SmartAiogramAiohttpSession(
api=TelegramAPIServer(
base=config.CUSTOM_API_SERVER_BASE,
file=config.CUSTOM_API_SERVER_FILE,
is_local=config.CUSTOM_API_SERVER_IS_LOCAL,
),
json_loads=orjson.loads,
logger=aiogram_session_logger,
)
else:
session = AiohttpSession(json_loads=orjson.loads)
session = utils.smart_session.SmartAiogramAiohttpSession(
json_loads=orjson.loads,
logger=aiogram_session_logger,
)
bot = Bot(config.BOT_TOKEN, parse_mode="HTML", session=session)

dp = Dispatcher(
Expand All @@ -307,6 +224,7 @@ def main() -> None:
key_builder=DefaultKeyBuilder(with_bot_id=True),
)
)
dp["aiogram_session_logger"] = aiogram_session_logger

if config.USE_WEBHOOK:
dp.startup.register(aiogram_on_startup_webhook)
Expand Down
9 changes: 6 additions & 3 deletions aiogram_bot_template/data/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@
FSM_PORT: int = env.int("FSM_PORT")
FSM_PASSWORD: str = env.str("FSM_PASSWORD")

CACHE_HOST: str = env.str("CACHE_HOST")
CACHE_PORT: int = env.int("CACHE_PORT")
CACHE_PASSWORD: str = env.str("CACHE_PASSWORD")
USE_CACHE: bool = env.bool("USE_CACHE")

if USE_CACHE:
CACHE_HOST: str = env.str("CACHE_HOST")
CACHE_PORT: int = env.int("CACHE_PORT")
CACHE_PASSWORD: str = env.str("CACHE_PASSWORD")

USE_WEBHOOK: bool = env.bool("USE_WEBHOOK", False)

Expand Down
44 changes: 36 additions & 8 deletions aiogram_bot_template/db/db_api/storages/basestorage/storage.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,54 @@
from typing import Any, Optional, Type, TypeVar
import typing
from typing import Any, Optional, TypeVar

T = TypeVar("T")


class SingleQueryResult:
    """Wraps the optional single-row mapping returned by a database query."""

    def __init__(self, result: Optional[typing.Mapping[str, Any]]):
        # Copy into an independent plain dict; a missing/empty row becomes None.
        self._data = dict(result) if result else None

    @property
    def data(self) -> Optional[dict[str, Any]]:
        """The row as a plain dict, or None when the query matched nothing."""
        return self._data

    def convert(self, model: type[T]) -> Optional[T]:
        """Instantiate `model` from the row's columns, or None for no row."""
        if self._data is None:
            return None
        return model(**self._data)


class MultipleQueryResults:
    """Wraps the list of row mappings returned by a database query."""

    def __init__(self, results: list[typing.Mapping[str, Any]]):
        # Copy every mapping into an independent plain dict.
        self._data: list[dict[str, Any]] = [dict(row) for row in results]

    @property
    def data(self) -> list[dict[str, Any]]:
        """All rows as plain dicts (empty list when nothing matched)."""
        return self._data

    def convert(self, model: type[T]) -> list[T]:
        """Instantiate `model` from each row's columns."""
        return [model(**row) for row in self._data]


class RawConnection:
    """Abstract base for raw database connection wrappers.

    Diff residue fixed here: the pre-change signature fragments
    (``model_type: Type[T]`` / ``-> Optional[list[T]]``) were interleaved
    with the post-change ones, making the class invalid Python. This is the
    reconstructed post-commit interface: every primitive optionally accepts
    an existing connection ``con`` so callers can run several statements
    inside one transaction, and results come back as query-result wrappers.
    """

    async def _fetch(
        self,
        sql: str,
        params: Optional[tuple[Any, ...] | list[tuple[Any, ...]]] = None,
        con: Optional[Any] = None,
    ) -> MultipleQueryResults:
        """Run `sql` with optional `params` and return every resulting row."""
        raise NotImplementedError

    async def _fetchrow(
        self,
        sql: str,
        params: Optional[tuple[Any, ...] | list[tuple[Any, ...]]] = None,
        con: Optional[Any] = None,
    ) -> SingleQueryResult:
        """Run `sql` with optional `params` and return at most one row."""
        raise NotImplementedError

    async def _execute(
        self,
        sql: str,
        params: Optional[tuple[Any, ...] | list[tuple[Any, ...]]] = None,
        con: Optional[Any] = None,
    ) -> None:
        """Run `sql` for its side effects only; nothing is returned."""
        raise NotImplementedError
Loading

0 comments on commit 4cba60a

Please sign in to comment.