-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
9d5c42f
commit f898c2b
Showing
9 changed files
with
623 additions
and
4 deletions.
There are no files selected for viewing
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from pydantic_settings import BaseSettings, SettingsConfigDict | ||
from typing import List | ||
from loguru import logger | ||
import sys | ||
|
||
|
||
class Settings(BaseSettings): | ||
|
||
API_VERSION: str = "v1" | ||
API_V1_STR: str = f"/api/{API_VERSION}" | ||
|
||
DB_HOST: str | ||
DB_PORT: str | ||
DB_NAME: str | ||
DB_PASS: str | ||
DB_USER: str | ||
|
||
@property | ||
def SQLALCHEMY_DATABASE_URI(self) -> str: | ||
return f"postgresql+asyncpg://{self.DB_USER}:{self.DB_PASS}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}" | ||
|
||
class Config: | ||
env_file = "../.env" | ||
|
||
class LogConfig: | ||
LOGGING_LEVEL = "DEBUG" | ||
LOGGING_FORMAT = "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level}</level> | <level>{message}</level>" | ||
|
||
@staticmethod | ||
def configure_logging(): | ||
logger.remove() | ||
|
||
logger.add( | ||
sys.stderr, format=LogConfig.LOGGING_FORMAT, level=LogConfig.LOGGING_LEVEL | ||
) | ||
|
||
|
||
LogConfig.configure_logging() | ||
|
||
settings = Settings() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
from sqlmodel import SQLModel | ||
from sqlalchemy.ext.asyncio import create_async_engine | ||
from app.core.config import settings | ||
import asyncpg | ||
import psycopg2 | ||
from loguru import logger | ||
|
||
engine = create_async_engine(str(settings.SQLALCHEMY_DATABASE_URI), echo=True) | ||
|
||
async def create_extension(): | ||
conn: asyncpg.Connection = await asyncpg.connect( | ||
user=settings.DB_USER, | ||
password=settings.DB_PASS, | ||
database=settings.DB_NAME, | ||
host=settings.DB_HOST, | ||
) | ||
try: | ||
await conn.execute("CREATE EXTENSION IF NOT EXISTS vector") | ||
logger.info("pgvector extension created or already exists.") | ||
except Exception as e: | ||
logger.error(f"Error creating pgvector extension: {e}") | ||
finally: | ||
await conn.close() | ||
|
||
|
||
def create_database(database_name, user, password, host, port): | ||
try: | ||
# Connect to the default database | ||
conn = psycopg2.connect( | ||
dbname=database_name, user=user, password=password, host=host, port=port | ||
) | ||
conn.autocommit = True | ||
cur = conn.cursor() | ||
|
||
# Check if database exists | ||
cur.execute( | ||
f"SELECT 1 FROM pg_catalog.pg_database WHERE datname = '{database_name}'" | ||
) | ||
exists = cur.fetchone() | ||
if not exists: | ||
|
||
cur.execute(f"CREATE DATABASE {database_name}") | ||
logger.info(f"Database '{database_name}' created.") | ||
else: | ||
logger.info(f"Database '{database_name}' already exists.") | ||
|
||
cur.close() | ||
conn.close() | ||
except Exception as e: | ||
logger.error(f"Error creating database: {e}") | ||
|
||
async def init_db() -> None: | ||
create_database( | ||
settings.DB_NAME, | ||
settings.DB_USER, | ||
settings.DB_PASS, | ||
settings.DB_HOST, | ||
settings.DB_PORT, | ||
) | ||
async with engine.begin() as conn: | ||
# Use run_sync to execute the create_all method in an asynchronous context | ||
await conn.run_sync(SQLModel.metadata.create_all) | ||
|
||
# Your existing database initialization logic here | ||
# For example, creating extensions or setting up initial data | ||
await create_extension() | ||
logger.info("Database initialized and all tables created if they didn't exist.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
from sqlmodel import Field, Relationship, SQLModel | ||
|
||
|
||
# Shared properties | ||
# TODO replace email str with EmailStr when sqlmodel supports it | ||
class UserBase(SQLModel): | ||
email: str = Field(unique=True, index=True) | ||
is_active: bool = True | ||
is_superuser: bool = False | ||
full_name: str | None = None | ||
|
||
|
||
# Properties to receive via API on creation | ||
class UserCreate(UserBase): | ||
password: str | ||
|
||
|
||
# TODO replace email str with EmailStr when sqlmodel supports it | ||
class UserCreateOpen(SQLModel): | ||
email: str | ||
password: str | ||
full_name: str | None = None | ||
|
||
|
||
# Properties to receive via API on update, all are optional | ||
# TODO replace email str with EmailStr when sqlmodel supports it | ||
class UserUpdate(UserBase): | ||
email: str | None = None # type: ignore | ||
password: str | None = None | ||
|
||
|
||
# TODO replace email str with EmailStr when sqlmodel supports it | ||
class UserUpdateMe(SQLModel): | ||
full_name: str | None = None | ||
email: str | None = None | ||
|
||
|
||
class UpdatePassword(SQLModel): | ||
current_password: str | ||
new_password: str | ||
|
||
|
||
# Database model, database table inferred from class name | ||
class User(UserBase, table=True): | ||
id: int | None = Field(default=None, primary_key=True) | ||
hashed_password: str | ||
items: list["Item"] = Relationship(back_populates="owner") | ||
|
||
|
||
# Properties to return via API, id is always required | ||
class UserOut(UserBase): | ||
id: int | ||
|
||
|
||
class UsersOut(SQLModel): | ||
data: list[UserOut] | ||
count: int | ||
|
||
|
||
# Shared properties | ||
class ItemBase(SQLModel): | ||
title: str | ||
description: str | None = None | ||
|
||
|
||
# Properties to receive on item creation | ||
class ItemCreate(ItemBase): | ||
title: str | ||
|
||
|
||
# Properties to receive on item update | ||
class ItemUpdate(ItemBase): | ||
title: str | None = None # type: ignore | ||
|
||
|
||
# Database model, database table inferred from class name | ||
class Item(ItemBase, table=True): | ||
id: int | None = Field(default=None, primary_key=True) | ||
title: str | ||
owner_id: int | None = Field(default=None, foreign_key="user.id", nullable=False) | ||
owner: User | None = Relationship(back_populates="items") | ||
|
||
|
||
# Properties to return via API, id is always required | ||
class ItemOut(ItemBase): | ||
id: int | ||
owner_id: int | ||
|
||
|
||
class ItemsOut(SQLModel): | ||
data: list[ItemOut] | ||
count: int | ||
|
||
|
||
# Generic message | ||
class Message(SQLModel): | ||
message: str | ||
|
||
|
||
# JSON payload containing access token | ||
class Token(SQLModel): | ||
access_token: str | ||
token_type: str = "bearer" | ||
|
||
|
||
# Contents of JWT token | ||
class TokenPayload(SQLModel): | ||
sub: int | None = None | ||
|
||
|
||
class NewPassword(SQLModel): | ||
token: str | ||
new_password: str |
Oops, something went wrong.