Commit

REFACTOR

mazzasaverio committed Apr 1, 2024
1 parent c90be84 commit 3828f24
Showing 16 changed files with 411 additions and 117 deletions.
Empty file added backend/app/api/__init__.py
Empty file.
11 changes: 11 additions & 0 deletions backend/app/api/main.py
@@ -0,0 +1,11 @@
from fastapi import APIRouter

from app.api.routes import qa

api_router = APIRouter()

api_router.include_router(
    qa.router,
    prefix="/qa",
    tags=["qa"],
)
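Combined with the API_V1_STR prefix that app/main.py applies when mounting api_router (see below), the chat route defined in qa.py ends up served at /api/v1/qa/chat.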
Empty file added backend/app/api/routes/__init__.py
Empty file.
113 changes: 113 additions & 0 deletions backend/app/api/routes/qa.py
@@ -0,0 +1,113 @@
from operator import itemgetter
import os

import yaml
from fastapi import APIRouter
from langchain.memory import ConversationBufferMemory
from langchain.prompts.prompt import PromptTemplate
from langchain_community.vectorstores.pgvector import PGVector
from langchain_core.messages import get_buffer_string
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, format_document
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from pydantic import BaseModel

from app.core.config import logger, settings

router = APIRouter()

# Load chat configuration (model settings and prompt templates) from config.yml
config_path = os.path.join(os.path.dirname(__file__), "..", "..", "config.yml")
with open(config_path, "r") as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

chat_config = config.get("CHAT_CONFIG", None)


class ChatRequest(BaseModel):
    message: str


@router.post("/chat")
async def chat_action(request: ChatRequest):
    embeddings = OpenAIEmbeddings()

    store = PGVector(
        collection_name="docs",
        connection_string=settings.SYNC_DATABASE_URI,
        embedding_function=embeddings,
    )

    retriever = store.as_retriever()

    # Prompt templates come from configuration rather than being hardcoded
    _template_condense = chat_config["PROMPTS"]["CONDENSE_QUESTION"]
    _template_answer = chat_config["PROMPTS"]["ANSWER_QUESTION"]
    _template_default_document = chat_config["PROMPTS"]["DEFAULT_DOCUMENT"]

    CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template_condense)
    ANSWER_PROMPT = ChatPromptTemplate.from_template(_template_answer)
    DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(_template_default_document)

    def _combine_documents(
        docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
    ):
        doc_strings = [format_document(doc, document_prompt) for doc in docs]
        return document_separator.join(doc_strings)

    memory = ConversationBufferMemory(
        return_messages=True, output_key="answer", input_key="question"
    )

    # First, load memory: this adds a "chat_history" key to the input object
    loaded_memory = RunnablePassthrough.assign(
        chat_history=RunnableLambda(memory.load_memory_variables)
        | itemgetter("history"),
    )
    # Condense the follow-up question plus chat history into a standalone question
    standalone_question = {
        "standalone_question": {
            "question": lambda x: x["question"],
            "chat_history": lambda x: get_buffer_string(x["chat_history"]),
        }
        | CONDENSE_QUESTION_PROMPT
        | ChatOpenAI(temperature=0)
        | StrOutputParser(),
    }
    # Retrieve the documents relevant to the standalone question
    retrieved_documents = {
        "docs": itemgetter("standalone_question") | retriever,
        "question": lambda x: x["standalone_question"],
    }
    # Construct the inputs for the final prompt
    final_inputs = {
        "context": lambda x: _combine_documents(x["docs"]),
        "question": itemgetter("question"),
    }

    logger.info(f"Final inputs: {final_inputs}")
    # Finally, generate the answer from the retrieved context
    answer = {
        "answer": final_inputs | ANSWER_PROMPT | ChatOpenAI(),
        "docs": itemgetter("docs"),
    }

    final_chain = loaded_memory | standalone_question | retrieved_documents | answer

    result = final_chain.invoke({"question": request.message})

    return result["answer"].content
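For a quick smoke test of this route, a sketch assuming the backend is served locally on port 8000 and documents were already ingested into the "docs" collection (the /api/v1 prefix comes from API_V1_STR in the settings):

import requests

resp = requests.post(
    "http://localhost:8000/api/v1/qa/chat",
    json={"message": "What does this project use pgvector for?"},
)
print(resp.json())  # the answer, returned as a plain JSON string

Note that ConversationBufferMemory is instantiated inside the request handler, so chat_history starts empty on every call; persistent multi-turn memory would have to live outside the request scope.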
17 changes: 17 additions & 0 deletions backend/app/config.yml
@@ -0,0 +1,17 @@
CHAT_CONFIG:
  LLM_MODEL: "gpt-3.5-turbo"
  MAX_TOKEN_LIMIT: 1000
  PROMPTS:
    CONDENSE_QUESTION: >
      Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.
      Chat History:
      {chat_history}
      Follow Up Input: {question}
      Standalone question:
    ANSWER_QUESTION: >
      Answer the question based only on the following context:
      {context}
      Question: {question}
    DEFAULT_DOCUMENT: "{page_content}"
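When editing these prompts, note how the ">" folded block scalar loads: single newlines are folded into spaces. A quick check, assuming the repository root as the working directory:

import yaml

with open("backend/app/config.yml") as f:
    cfg = yaml.safe_load(f)

# The folded style joins the lines, so the whole condense template
# loads as a single line ending in "Standalone question:".
print(cfg["CHAT_CONFIG"]["PROMPTS"]["CONDENSE_QUESTION"])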
17 changes: 15 additions & 2 deletions backend/app/core/config.py
@@ -1,23 +1,36 @@
 from pydantic_settings import BaseSettings
 from typing import List
-from backend.logging_config import logger
+from loguru import logger
 
 
 class Settings(BaseSettings):
 
     API_VERSION: str = "v1"
     API_V1_STR: str = f"/api/{API_VERSION}"
     PROJECT_NAME: str
 
     DB_HOST: str
     DB_PORT: str
     DB_NAME: str
     DB_PASS: str
     DB_USER: str
 
     OPENAI_API_KEY: str
+    OPENAI_ORGANIZATION: str
+
+    REDIS_HOST: str
+    REDIS_PORT: str
+
+    TAVILY_API_KEY: str
 
     @property
-    def SQLALCHEMY_DATABASE_URI(self) -> str:
+    def ASYNC_DATABASE_URI(self) -> str:
         return f"postgresql+asyncpg://{self.DB_USER}:{self.DB_PASS}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"
 
+    @property
+    def SYNC_DATABASE_URI(self) -> str:
+        return f"postgresql+psycopg2://{self.DB_USER}:{self.DB_PASS}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"
 
     class Config:
         env_file = "../.env"
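Settings reads its values from an env file (env_file = "../.env"). A minimal .env sketch whose keys mirror the declared fields; every value here is a placeholder, not taken from the repository:

PROJECT_NAME=fastapi-langchain-rag
DB_HOST=localhost
DB_PORT=5432
DB_NAME=postgres
DB_USER=postgres
DB_PASS=changeme
OPENAI_API_KEY=sk-...
OPENAI_ORGANIZATION=org-...
REDIS_HOST=localhost
REDIS_PORT=6379
TAVILY_API_KEY=tvly-...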
89 changes: 29 additions & 60 deletions backend/app/core/db.py
@@ -1,67 +1,36 @@
-from sqlmodel import SQLModel
 from sqlalchemy.ext.asyncio import create_async_engine
 from app.core.config import settings
-import asyncpg
-import psycopg2
-from loguru import logger
-
-engine = create_async_engine(str(settings.SQLALCHEMY_DATABASE_URI), echo=True)
-
-async def create_extension():
-    conn: asyncpg.Connection = await asyncpg.connect(
-        user=settings.DB_USER,
-        password=settings.DB_PASS,
-        database=settings.DB_NAME,
-        host=settings.DB_HOST,
-    )
-    try:
-        await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
-        logger.info("pgvector extension created or already exists.")
-    except Exception as e:
-        logger.error(f"Error creating pgvector extension: {e}")
-    finally:
-        await conn.close()
-
-
-def create_database(database_name, user, password, host, port):
-    try:
-        # Connect to the default database
-        conn = psycopg2.connect(
-            dbname=database_name, user=user, password=password, host=host, port=port
-        )
-        conn.autocommit = True
-        cur = conn.cursor()
-
-        # Check if database exists
-        cur.execute(
-            f"SELECT 1 FROM pg_catalog.pg_database WHERE datname = '{database_name}'"
-        )
-        exists = cur.fetchone()
-        if not exists:
-            cur.execute(f"CREATE DATABASE {database_name}")
-            logger.info(f"Database '{database_name}' created.")
-        else:
-            logger.info(f"Database '{database_name}' already exists.")
-
-        cur.close()
-        conn.close()
-    except Exception as e:
-        logger.error(f"Error creating database: {e}")
-
-
-async def init_db() -> None:
-    create_database(
-        settings.DB_NAME,
-        settings.DB_USER,
-        settings.DB_PASS,
-        settings.DB_HOST,
-        settings.DB_PORT,
-    )
-    async with engine.begin() as conn:
-        # Use run_sync to execute the create_all method in an asynchronous context
-        await conn.run_sync(SQLModel.metadata.create_all)
-
-    # Your existing database initialization logic here
-    # For example, creating extensions or setting up initial data
-    await create_extension()
-    logger.info("Database initialized and all tables created if they didn't exist.")
+from sqlalchemy.orm import sessionmaker
+from sqlmodel.ext.asyncio.session import AsyncSession
 
+DB_POOL_SIZE = 83
+WEB_CONCURRENCY = 9
+POOL_SIZE = max(
+    DB_POOL_SIZE // WEB_CONCURRENCY,
+    5,
+)
+
+
+def _get_local_session() -> sessionmaker:
+    engine = (
+        create_async_engine(
+            url=settings.ASYNC_DATABASE_URI,
+            future=True,
+            pool_size=POOL_SIZE,
+            max_overflow=64,
+        )
+        if settings.ASYNC_DATABASE_URI is not None
+        else None
+    )
+    return sessionmaker(
+        autocommit=False,
+        autoflush=False,
+        bind=engine,
+        class_=AsyncSession,
+        expire_on_commit=False,
+    )
+
+
+SessionLocal = _get_local_session()
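Downstream code (qa.py imports SessionLocal) can call the factory directly, or wrap it in a FastAPI dependency. A minimal sketch; get_session is a hypothetical helper, not part of this commit:

from app.core.db import SessionLocal

async def get_session():
    # Open an AsyncSession per request and close it when the request finishes.
    async with SessionLocal() as session:
        yield session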
25 changes: 13 additions & 12 deletions backend/app/main.py
@@ -1,18 +1,16 @@
 from fastapi import FastAPI
-from fastapi.responses import RedirectResponse
-from contextlib import asynccontextmanager
 from fastapi.middleware.cors import CORSMiddleware
-from app.core.db import init_db
 
+from app.api.main import api_router
 from app.core.config import settings
+from typing import Dict
 
-@asynccontextmanager
-async def app_lifespan(app: FastAPI):
-    await init_db()
-    yield
-
-app = FastAPI(lifespan=app_lifespan)
+app = FastAPI(
+    openapi_url=f"{settings.API_V1_STR}/openapi.json",
+    docs_url=f"{settings.API_V1_STR}/docs",
+)
 
 app.add_middleware(
     CORSMiddleware,
@@ -22,13 +20,16 @@ async def app_lifespan(app: FastAPI):
     allow_headers=["*"],
 )
 
 
 @app.get("/metrics")
 def metrics():
     return {"message": "Metrics endpoint"}
 
-@app.get("/")
-async def home():
-    return {"data": "Hello"}
+
+@app.get("/")
+async def root() -> Dict[str, str]:
+    """An example "Hello world" FastAPI route."""
+    return {"message": "FastAPI backend"}
 
 
+app.include_router(api_router, prefix=settings.API_V1_STR)
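Assuming the usual FastAPI layout, running uvicorn app.main:app --reload from the backend/ directory should expose the interactive docs at /api/v1/docs and the chat route at /api/v1/qa/chat.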
5 changes: 5 additions & 0 deletions backend/app/schemas/chat_schema.py
@@ -0,0 +1,5 @@
from pydantic import BaseModel


class ChatBody(BaseModel):
    message: str
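ChatBody mirrors the ChatRequest model that qa.py defines inline. A minimal check that the schema validates as expected (model_dump assumes Pydantic v2, which the use of pydantic_settings implies):

from app.schemas.chat_schema import ChatBody

body = ChatBody(message="hello")
print(body.model_dump())  # {'message': 'hello'}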
43 changes: 43 additions & 0 deletions backend/ingestion/core/config.py
@@ -0,0 +1,43 @@
from pydantic_settings import BaseSettings
from typing import List
from loguru import logger
import sys


class Settings(BaseSettings):

    API_VERSION: str = "v1"
    API_V1_STR: str = f"/api/{API_VERSION}"

    DB_HOST: str
    DB_PORT: str
    DB_NAME: str
    DB_PASS: str
    DB_USER: str

    OPENAI_API_KEY: str

    @property
    def ASYNC_DATABASE_URI(self) -> str:
        return f"postgresql+asyncpg://{self.DB_USER}:{self.DB_PASS}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"

    class Config:
        env_file = "../.env"


class LogConfig:
    LOGGING_LEVEL = "DEBUG"
    LOGGING_FORMAT = "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level}</level> | <level>{message}</level>"

    @staticmethod
    def configure_logging():
        logger.remove()

        logger.add(
            sys.stderr, format=LogConfig.LOGGING_FORMAT, level=LogConfig.LOGGING_LEVEL
        )


LogConfig.configure_logging()

settings = Settings()
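Importing this module configures Loguru as a side effect of LogConfig.configure_logging(). A small usage sketch, assuming backend/ is the working directory so that the ingestion package is importable:

from ingestion.core.config import logger, settings

# DEBUG-level logging was already configured by the import above.
logger.debug(f"Ingestion target database: {settings.DB_NAME}")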