Skip to content

Commit

Permalink
Add database
Browse files Browse the repository at this point in the history
  • Loading branch information
mezgoodle committed Sep 26, 2023
1 parent b4eb778 commit e159a59
Show file tree
Hide file tree
Showing 5 changed files with 182 additions and 1 deletion.
13 changes: 12 additions & 1 deletion bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@

from loader import bot, dp
from tgbot.config import Settings, config
from tgbot.middlewares.database import DatabaseMiddleware
from tgbot.middlewares.settings import ConfigMiddleware
from tgbot.middlewares.throttling import ThrottlingMiddleware
from tgbot.misc.database import Database
from tgbot.models.models import close_db, init
from tgbot.services.admins_notify import on_startup_notify
from tgbot.services.setting_commands import set_default_commands

Expand All @@ -27,7 +30,7 @@ def register_global_middlewares(dp: Dispatcher, config: Settings):
middlewares = [
ConfigMiddleware(config),
ThrottlingMiddleware(),
# DatabaseMiddleware(session_pool),
DatabaseMiddleware(Database()),
]

for middleware in middlewares:
Expand All @@ -40,9 +43,15 @@ def register_global_middlewares(dp: Dispatcher, config: Settings):
logging.info("Middlewares registered.")


async def init_database():
await init()
logging.info("Database was inited")


async def on_startup(bot: Bot, dispatcher: Dispatcher) -> None:
register_all_handlers()
register_global_middlewares(dispatcher, config)
await init_database()
await register_all_commands(bot)
await on_startup_notify(bot)
logging.info("Bot started.")
Expand All @@ -51,6 +60,8 @@ async def on_startup(bot: Bot, dispatcher: Dispatcher) -> None:
async def on_shutdown(dispatcher: Dispatcher) -> None:
await dispatcher.storage.close()
logging.info("Storage closed.")
await close_db()
logging.info("Database was closed.")
logging.info("Bot stopped.")


Expand Down
20 changes: 20 additions & 0 deletions tgbot/middlewares/database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import Any, Awaitable, Callable, Dict

from aiogram import BaseMiddleware
from aiogram.types import Message

from tgbot.misc.database import Database


class DatabaseMiddleware(BaseMiddleware):
def __init__(self, db_instance: Database) -> None:
self.db = db_instance

async def __call__(
self,
handler: Callable[[Message, Dict[str, Any]], Awaitable[Any]],
event: Message,
data: Dict[str, Any],
) -> Any:
data["db"] = self.db
return await handler(event, data)
29 changes: 29 additions & 0 deletions tgbot/misc/database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from tgbot.models.models import Event, Team, Tournament


class Database:
def __init__(self):
self.tournament = Tournament
self.event = Event
self.team = Team

async def create_tournament(self, name: str) -> Tournament:
return await self.tournament.create(name=name)

async def create_team(self, name: str) -> Team:
return await self.team.create(name=name)

async def create_event(
self, name: str, tournament_name: str, participants: list[int]
) -> Event:
if (
tournament := await self.tournament.get_or_create(
name=tournament_name
)
) and (teams := await self.team.filter(id__in=participants)):
return await self.event.create(
name=name, tournament=tournament, participants=teams
)

async def get_teams(self) -> list[Team]:
return await self.team.all()
15 changes: 15 additions & 0 deletions tgbot/models/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from tortoise.fields import DatetimeField
from tortoise.models import Model


class BaseModel(Model):
class Meta:
abstract = True


class TimedBaseModel(BaseModel):
class Meta:
abstract = True

created_at = DatetimeField(auto_now_add=True)
updated_at = DatetimeField(auto_now=True)
106 changes: 106 additions & 0 deletions tgbot/models/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from tortoise import Tortoise, fields

from tgbot.models.base import TimedBaseModel

db = Tortoise()


class User(TimedBaseModel):
name = fields.CharField(max_length=255, null=True, description="User name")
user_id = fields.IntField(
unique=True,
description="Telegram user id",
index=True,
)
username = fields.CharField(
max_length=255,
null=True,
description="Telegram username",
unique=True,
)
subjects: fields.ManyToManyRelation["Subject"]

class Meta:
abstract = True


class Teacher(User):
pass


class Student(User):
completed_tasks: fields.ReverseRelation["CompletedTask"]
uncompleted_tasks: fields.ReverseRelation["UncompletedTask"]


class Subject(TimedBaseModel):
name = fields.CharField(max_length=255, description="Subject name")
description = fields.TextField(null=True)
teacher: fields.ForeignKeyRelation[Teacher] = fields.ForeignKeyField(
"models.Teacher",
related_name="subjects",
description="Subject teacher",
)
students: fields.ManyToManyRelation[Student] = fields.ManyToManyField(
"models.Student",
related_name="subjects",
description="Subject students",
)
drive_link = fields.CharField(
max_length=255,
null=True,
description="Drive link",
)
tasks: fields.ReverseRelation["Task"]

def __str__(self):
return self.name


class Task(TimedBaseModel):
name = fields.CharField(max_length=255)
description = fields.TextField()
subject: fields.ForeignKeyRelation[Subject] = fields.ForeignKeyField(
"models.Subject", related_name="tasks", description="Task subject"
)
due_date = fields.DatetimeField(
null=True,
)

def __str__(self):
return self.name

class Meta:
abstract = True


class UncompletedTask(Task):
student: fields.ForeignKeyRelation[Student] = fields.ForeignKeyField(
"models.Student",
related_name="uncompleted_tasks",
description="Task student",
)


class CompletedTask(Task):
student: fields.ForeignKeyRelation[Student] = fields.ForeignKeyField(
"models.Student",
related_name="completed_tasks",
description="Task student",
)


async def init():
# Here we create a SQLite DB using file "db.sqlite3"
# also specify the app name of "models"
# which contain models from "tgbot.models.models"
await db.init(
db_url="sqlite://db.sqlite3",
modules={"models": ["tgbot.models.models"]},
)
# Generate the schema
await Tortoise.generate_schemas()


async def close_db():
await db.close_connections()

0 comments on commit e159a59

Please sign in to comment.