-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c90be84
commit 3828f24
Showing
16 changed files
with
411 additions
and
117 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from fastapi import APIRouter | ||
|
||
from app.api.routes import qa | ||
|
||
# Aggregate router; feature routers are attached to it below.
api_router = APIRouter()

# Expose the qa routes under the "/qa" prefix and tag them "qa".
api_router.include_router(qa.router, prefix="/qa", tags=["qa"])
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
from app.core.db import SessionLocal | ||
import os | ||
import yaml | ||
from fastapi import APIRouter | ||
|
||
|
||
from app.core.config import logger | ||
|
||
|
||
from operator import itemgetter | ||
|
||
from langchain_community.vectorstores import FAISS | ||
from langchain_core.output_parsers import StrOutputParser | ||
from langchain_core.prompts import ChatPromptTemplate | ||
from langchain_core.runnables import RunnableLambda, RunnablePassthrough | ||
from langchain_openai import ChatOpenAI, OpenAIEmbeddings | ||
from langchain_core.messages import AIMessage, HumanMessage, get_buffer_string | ||
from langchain_core.prompts import format_document | ||
from langchain_core.runnables import RunnableParallel | ||
from langchain_community.vectorstores.pgvector import PGVector | ||
from langchain.memory import ConversationBufferMemory | ||
from app.core.config import settings | ||
from langchain.prompts.prompt import PromptTemplate | ||
from pydantic import BaseModel | ||
|
||
router = APIRouter()

# config.yml is resolved two directories above this module and parsed once,
# at import time.
config_path = os.path.join(os.path.dirname(__file__), "..", "..", "config.yml")
with open(config_path, "r", encoding="utf-8") as config_file:
    # safe_load only constructs plain YAML types (dicts, lists, scalars),
    # which is all this configuration needs; an empty file parses to None,
    # so fall back to an empty mapping to keep the lookup below valid.
    config = yaml.safe_load(config_file) or {}

# None when the file has no CHAT_CONFIG section.
chat_config = config.get("CHAT_CONFIG", None)
|
||
|
||
class ChatRequest(BaseModel):
    """Request body accepted by the POST /chat endpoint."""

    # Text sent by the user; the endpoint passes it to the chain as the question.
    message: str
|
||
|
||
@router.post("/chat")
async def chat_action(request: ChatRequest):
    """Answer a chat message with a retrieval-augmented LangChain pipeline.

    The chain rewrites the incoming message as a standalone question,
    retrieves documents for it from the "docs" PGVector collection, and asks
    the chat model to answer from those documents.

    Args:
        request: Body carrying the user's ``message``.

    Returns:
        The ``content`` of the model's answer message.
    """
    embeddings = OpenAIEmbeddings()

    store = PGVector(
        collection_name="docs",
        connection_string=settings.SYNC_DATABASE_URI,
        embedding_function=embeddings,
    )

    retriever = store.as_retriever()

    # Prompt texts come from the CHAT_CONFIG section parsed at import time.
    _template_condense = chat_config["PROMPTS"]["CONDENSE_QUESTION"]
    _template_answer = chat_config["PROMPTS"]["ANSWER_QUESTION"]
    _template_default_document = chat_config["PROMPTS"]["DEFAULT_DOCUMENT"]

    CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template_condense)
    ANSWER_PROMPT = ChatPromptTemplate.from_template(_template_answer)
    DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(_template_default_document)

    def _combine_documents(
        docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
    ):
        """Render each document with ``document_prompt`` and join the results."""
        doc_strings = [format_document(doc, document_prompt) for doc in docs]
        return document_separator.join(doc_strings)

    # NOTE(review): this memory object is created per request and nothing in
    # this function saves to it, so the loaded history is empty on every call.
    memory = ConversationBufferMemory(
        return_messages=True, output_key="answer", input_key="question"
    )

    # Step 1: load memory. This adds a "chat_history" key to the input object.
    loaded_memory = RunnablePassthrough.assign(
        chat_history=RunnableLambda(memory.load_memory_variables)
        | itemgetter("history"),
    )

    # Step 2: condense the question and the history into a standalone question.
    standalone_question = {
        "standalone_question": {
            "question": lambda x: x["question"],
            "chat_history": lambda x: get_buffer_string(x["chat_history"]),
        }
        | CONDENSE_QUESTION_PROMPT
        | ChatOpenAI(temperature=0)
        | StrOutputParser(),
    }

    # Step 3: retrieve documents for the standalone question.
    retrieved_documents = {
        "docs": itemgetter("standalone_question") | retriever,
        "question": lambda x: x["standalone_question"],
    }

    # Step 4: build the inputs of the answer prompt.
    final_inputs = {
        "context": lambda x: _combine_documents(x["docs"]),
        "question": itemgetter("question"),
    }

    # This logs the mapping of callables itself, not the resolved values.
    logger.info(f"Final inputs: {final_inputs}")

    # Step 5: produce the answer and pass the retrieved documents through.
    answer = {
        "answer": final_inputs | ANSWER_PROMPT | ChatOpenAI(),
        "docs": itemgetter("docs"),
    }

    final_chain = loaded_memory | standalone_question | retrieved_documents | answer

    inputs = {"question": request.message}
    # Await the async entry point: calling the blocking invoke() inside an
    # async endpoint would stall the event loop for the whole chain run.
    result = await final_chain.ainvoke(inputs)

    return result["answer"].content
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Chat settings; the chat route reads the entries under PROMPTS.
CHAT_CONFIG:
  LLM_MODEL: "gpt-3.5-turbo"
  MAX_TOKEN_LIMIT: 1000
  PROMPTS:
    # Literal block scalars (|) keep the line breaks of the prompt layout;
    # folded scalars (>) would join every line below into a single line.
    CONDENSE_QUESTION: |
      Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.
      Chat History:
      {chat_history}
      Follow Up Input: {question}
      Standalone question:
    ANSWER_QUESTION: |
      Answer the question based only on the following context:
      {context}
      Question: {question}
    # Template applied to each retrieved document before joining them.
    DEFAULT_DOCUMENT: "{page_content}"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,67 +1,36 @@ | ||
from sqlmodel import SQLModel | ||
from sqlalchemy.ext.asyncio import create_async_engine | ||
from app.core.config import settings | ||
import asyncpg | ||
import psycopg2 | ||
from loguru import logger | ||
|
||
engine = create_async_engine(str(settings.SQLALCHEMY_DATABASE_URI), echo=True) | ||
|
||
async def create_extension():
    """Ensure the pgvector ``vector`` extension exists in the configured database.

    Opens a dedicated asyncpg connection from the DB_* settings, runs
    ``CREATE EXTENSION IF NOT EXISTS vector`` and closes the connection.
    A failure of the statement is logged and not re-raised; a failure to
    connect propagates to the caller.
    """
    conn: asyncpg.Connection = await asyncpg.connect(
        user=settings.DB_USER,
        password=settings.DB_PASS,
        database=settings.DB_NAME,
        host=settings.DB_HOST,
        # Use the configured port, as create_database does, instead of
        # silently falling back to the driver default.
        port=int(settings.DB_PORT),
    )
    try:
        await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
        logger.info("pgvector extension created or already exists.")
    except Exception as e:
        # Best effort: report the failure and let start-up continue.
        logger.error(f"Error creating pgvector extension: {e}")
    finally:
        await conn.close()
|
||
|
||
def create_database(database_name, user, password, host, port): | ||
try: | ||
# Connect to the default database | ||
conn = psycopg2.connect( | ||
dbname=database_name, user=user, password=password, host=host, port=port | ||
) | ||
conn.autocommit = True | ||
cur = conn.cursor() | ||
|
||
# Check if database exists | ||
cur.execute( | ||
f"SELECT 1 FROM pg_catalog.pg_database WHERE datname = '{database_name}'" | ||
from sqlalchemy.orm import sessionmaker | ||
from sqlmodel.ext.asyncio.session import AsyncSession | ||
|
||
# Presumably the total connection budget and the number of web workers
# sharing it; confirm against the deployment configuration.
DB_POOL_SIZE = 83
WEB_CONCURRENCY = 9

# Integer share of the budget per worker, with a floor of 5.
POOL_SIZE = max(5, DB_POOL_SIZE // WEB_CONCURRENCY)
|
||
|
||
def _get_local_session() -> sessionmaker: | ||
engine = ( | ||
create_async_engine( | ||
url=settings.ASYNC_DATABASE_URI, | ||
future=True, | ||
pool_size=POOL_SIZE, | ||
max_overflow=64, | ||
) | ||
exists = cur.fetchone() | ||
if not exists: | ||
|
||
cur.execute(f"CREATE DATABASE {database_name}") | ||
logger.info(f"Database '{database_name}' created.") | ||
else: | ||
logger.info(f"Database '{database_name}' already exists.") | ||
|
||
cur.close() | ||
conn.close() | ||
except Exception as e: | ||
logger.error(f"Error creating database: {e}") | ||
|
||
async def init_db() -> None: | ||
create_database( | ||
settings.DB_NAME, | ||
settings.DB_USER, | ||
settings.DB_PASS, | ||
settings.DB_HOST, | ||
settings.DB_PORT, | ||
if settings.ASYNC_DATABASE_URI is not None | ||
else None | ||
) | ||
return sessionmaker( | ||
autocommit=False, | ||
autoflush=False, | ||
bind=engine, | ||
class_=AsyncSession, | ||
expire_on_commit=False, | ||
) | ||
async with engine.begin() as conn: | ||
# Use run_sync to execute the create_all method in an asynchronous context | ||
await conn.run_sync(SQLModel.metadata.create_all) | ||
|
||
# Your existing database initialization logic here | ||
# For example, creating extensions or setting up initial data | ||
await create_extension() | ||
logger.info("Database initialized and all tables created if they didn't exist.") | ||
|
||
SessionLocal = _get_local_session() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from pydantic import BaseModel | ||
|
||
|
||
class ChatBody(BaseModel):
    """Schema with a single required string field, ``message``."""

    # Message text.
    message: str
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from pydantic_settings import BaseSettings | ||
from typing import List | ||
from loguru import logger | ||
import sys | ||
|
||
|
||
class Settings(BaseSettings):
    """Application settings: API prefix, database credentials and OpenAI key.

    Fields without a default are required. ``Config.env_file`` names the
    dotenv file used as a source of values.
    """

    API_VERSION: str = "v1"
    API_V1_STR: str = f"/api/{API_VERSION}"

    DB_HOST: str
    DB_PORT: str
    DB_NAME: str
    DB_PASS: str
    DB_USER: str

    OPENAI_API_KEY: str

    @property
    def ASYNC_DATABASE_URI(self) -> str:
        """Return the SQLAlchemy URL for the asyncpg driver."""
        return f"postgresql+asyncpg://{self.DB_USER}:{self.DB_PASS}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"

    @property
    def SYNC_DATABASE_URI(self) -> str:
        """Return the SQLAlchemy URL for the synchronous psycopg2 driver.

        The chat route passes this value to PGVector as its connection
        string; without this property that lookup raises AttributeError.
        """
        return f"postgresql+psycopg2://{self.DB_USER}:{self.DB_PASS}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"

    class Config:
        env_file = "../.env"
|
||
|
||
class LogConfig:
    """Logging level, line format and sink set-up for the loguru logger."""

    LOGGING_LEVEL = "DEBUG"
    LOGGING_FORMAT = "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level}</level> | <level>{message}</level>"

    @staticmethod
    def configure_logging():
        """Drop the logger's current handlers and add one stderr handler."""
        logger.remove()
        logger.add(
            sink=sys.stderr,
            level=LogConfig.LOGGING_LEVEL,
            format=LogConfig.LOGGING_FORMAT,
        )
|
||
|
||
# Install the logging configuration as soon as this module is imported.
LogConfig.configure_logging()

# Module-level settings instance; required fields without a value make this
# constructor fail at import time.
settings = Settings()
Oops, something went wrong.