Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor v1.3 to make it configurable and evaluation pipeline #66

Open
wants to merge 30 commits into
base: feat/v1.3
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
e623ddb
copy ayulockin/feat/v1.3
ayulockin Mar 5, 2024
e8710b1
update readme
ayulockin Mar 5, 2024
f379765
pin simsimd
ayulockin Mar 6, 2024
b5ec6ed
centralize ingestion pipeline with config
ayulockin Mar 7, 2024
3e1ba27
add litellm embedding
ayulockin Mar 7, 2024
1928f4d
vector store for OpenAI and Cohere
ayulockin Mar 7, 2024
96c608c
add voyage embedding + refactor
ayulockin Mar 8, 2024
bf1d846
update app to take config
ayulockin Mar 8, 2024
d5f385a
log configs
ayulockin Mar 11, 2024
d2094c6
build simple retrieval, separate reranking and build simple synthesis
ayulockin Mar 12, 2024
c09b7b9
with or without reranker: update app
ayulockin Mar 12, 2024
9904859
rag pipeline working with different configs within chat
ayulockin Mar 12, 2024
14afc6f
control synthesizer with config
ayulockin Mar 12, 2024
ab0d463
control query enhancer with config
ayulockin Mar 12, 2024
d9e5992
control web search
ayulockin Mar 12, 2024
4f528ec
update misc
ayulockin Mar 12, 2024
4f37bc0
update ingestion pipeline to take config + update configs
ayulockin Mar 12, 2024
a0d0404
add anthropic config + litellm token + cost counter
ayulockin Mar 14, 2024
0e49d29
update chat.py
ayulockin Mar 14, 2024
6728a39
optimize usage logging
ayulockin Mar 14, 2024
e28f6dc
remove configs to eval
ayulockin Mar 15, 2024
5d70f1b
correct eval
ayulockin Mar 21, 2024
28a9ca5
fix max token issue
ayulockin Mar 21, 2024
8caed76
decouple query enhancement from rerank fusion
ayulockin Mar 21, 2024
8aef4b7
async eval pipeline
ayulockin Mar 22, 2024
04a58be
merge changes from feat/v1.3
ayulockin Mar 24, 2024
7031afa
fix: sub query and temperature issues
parambharat Mar 28, 2024
22d88dd
make eval go brr
ayulockin Apr 1, 2024
6111b2e
read old system prompt
ayulockin Apr 3, 2024
8573e4d
update eval for more speedup and better logging
ayulockin Apr 3, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,14 @@ WANDBOT_API_URL="http://localhost:8000"
WANDB_TRACING_ENABLED="true"
WANDB_PROJECT="wandbot-dev"
WANDB_ENTITY="wandbot"
WANDB_REPORT_API_ENABLE_V2="true"
WANDB_REPORT_API_DISABLE_MESSAGE="true"
```

The best practice is to create a `.env` file in the root of this repo and then run this command in your terminal to export its variables:

```
set -o allexport; source .env; set +o allexport
```

Once these environment variables are set, you can start the Q&A bot application using the following commands:
Expand Down
3,340 changes: 2,044 additions & 1,296 deletions poetry.lock

Large diffs are not rendered by default.

12 changes: 7 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ include = ["src/**/*", "LICENSE", "README.md"]
[tool.poetry.dependencies]
python = ">=3.10.0,<3.12"
numpy = "^1.26.1"
wandb = "<=0.16.1"
tiktoken = "^0.5.1"
wandb = "<=0.16.3"
tiktoken = "^0.6.0"
pandas = "^2.1.2"
unstructured = "^0.12.3"
pydantic-settings = "^2.0.3"
Expand All @@ -34,19 +34,21 @@ zenpy = "^2.0.46"
openai = "^1.3.2"
weave = "^0.31.0"
colorlog = "^6.8.0"
litellm = "^1.15.1"
litellm = "^1.31.6"
google-cloud-bigquery = "^3.14.1"
db-dtypes = "^1.2.0"
python-frontmatter = "^1.1.0"
pymdown-extensions = "^10.5"
langchain = "^0.1.5"
langchain-openai = "^0.0.5"
langchain-openai = "^0.0.8"
chromadb = "^0.4.22"
langchain-experimental = "^0.0.50"
simsimd = "^3.7.4"
simsimd = "3.7.7"

[tool.poetry.dev-dependencies]
#fasttext = {git = "https://github.com/cfculhane/fastText"} # FastText doesn't come with pybind11 and we need to use this workaround.
llama_index = "^0.10.20"
ragas = "^0.1.4"

[build-system]
requires = ["poetry-core>=1.0.0"]
Expand Down
31 changes: 24 additions & 7 deletions src/wandbot/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
It uses logger from the utils module for logging purposes.
"""

import os
os.environ["CONFIG_PATH"] = "config.yaml"

import asyncio
from contextlib import asynccontextmanager
from datetime import datetime, timezone
Expand All @@ -38,13 +41,21 @@
from wandbot.api.routers import chat as chat_router
from wandbot.api.routers import database as database_router
from wandbot.api.routers import retrieve as retrieve_router
from wandbot.ingestion.config import VectorStoreConfig
from wandbot.ingestion.utils import get_embedding_config
from wandbot.retriever import VectorStore
from wandbot.utils import get_logger
from wandbot.utils import get_logger, load_config

logger = get_logger(__name__)
last_backup = datetime.now().astimezone(timezone.utc)

config = load_config(os.environ["CONFIG_PATH"])
logger.info(
f"Loaded configuration from {os.environ['CONFIG_PATH']}: {config}"
)
vectorstore_config = get_embedding_config(
config.embeddings.type, config.embeddings.config
)


@asynccontextmanager
async def lifespan(app: FastAPI):
Expand All @@ -56,12 +67,18 @@ async def lifespan(app: FastAPI):
Returns:
None
"""
vector_store = VectorStore.from_config(VectorStoreConfig())
chat_router.chat = chat_router.Chat(vector_store=vector_store)
vector_store = VectorStore.from_config(vectorstore_config)
chat_router.chat = chat_router.Chat(vector_store=vector_store, config=config)
database_router.db_client = database_router.DatabaseClient()
retrieve_router.retriever = retrieve_router.SimpleRetrievalEngine(
vector_store=vector_store
)
if not config.retrieval_re_ranker.enabled:
retrieve_router.retriever = retrieve_router.SimpleRetrievalEngine(
vector_store=vector_store
)
else:
retrieve_router.retriever = retrieve_router.SimpleRetrievalEngineWithRerank(
vector_store=vector_store
)


async def backup_db():
"""Periodically backs up the database to a table.
Expand Down
8 changes: 6 additions & 2 deletions src/wandbot/api/routers/retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,17 @@
from pydantic import BaseModel
from starlette import status

from wandbot.retriever.base import SimpleRetrievalEngine
from wandbot.retriever.base import (
SimpleRetrievalEngine,
SimpleRetrievalEngineWithRerank,
)

router = APIRouter(
prefix="/retrieve",
tags=["retrievers"],
)

retriever: SimpleRetrievalEngine | None = None
retriever: SimpleRetrievalEngineWithRerank | None = None


class APIRetrievalResult(BaseModel):
Expand Down Expand Up @@ -53,6 +56,7 @@ async def retrieve(request: APIRetrievalRequest) -> APIRetrievalResponse:
sources=request.sources,
)

# TODO: Fix this
return APIRetrievalResponse(
query=request.query,
top_k=[
Expand Down
26 changes: 12 additions & 14 deletions src/wandbot/chat/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@
from wandbot.chat.schemas import ChatRequest, ChatResponse
from wandbot.database.schemas import QuestionAnswer
from wandbot.retriever import VectorStore
from wandbot.utils import Timer, get_logger
from weave.monitoring import StreamTable
from wandbot.utils import Timer, get_logger, RAGPipelineConfig

logger = get_logger(__name__)

Expand All @@ -49,29 +48,28 @@ class Chat:
config: ChatConfig = ChatConfig()

def __init__(
self, vector_store: VectorStore, config: ChatConfig | None = None
self,
vector_store: VectorStore,
config: RAGPipelineConfig,
chat_config: ChatConfig | None = None
):
"""Initializes the Chat instance.

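Args:
vector_store: The VectorStore instance stored on the chat and passed to the RAG pipeline.
config: A RAGPipelineConfig; its `project` and `entity` are used for the wandb run and it is passed to the RAG pipeline.
chat_config: An optional ChatConfig; stored as `self.chat_config` when provided.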
"""
self.vector_store = vector_store
if config is not None:
self.config = config
self.config = config
if chat_config is not None:
self.chat_config = chat_config
self.run = wandb.init(
project=self.config.wandb_project,
entity=self.config.wandb_entity,
project=self.config.project,
entity=self.config.entity,
job_type="chat",
)
self.run._label(repo="wandbot")
self.chat_table = StreamTable(
table_name="chat_logs",
project_name=self.config.wandb_project,
entity_name=self.config.wandb_entity,
)

self.rag_pipeline = RAGPipeline(vector_store=vector_store)
self.rag_pipeline = RAGPipeline(vector_store=vector_store, config=config)

def _get_answer(
self, question: str, chat_history: List[QuestionAnswer]
Expand Down Expand Up @@ -105,10 +103,10 @@ def __call__(self, chat_request: ChatRequest) -> ChatResponse:
"total_tokens": result.total_tokens,
"prompt_tokens": result.prompt_tokens,
"completion_tokens": result.completion_tokens,
"completion_cost": result.completion_cost,
}
result_dict.update({"application": chat_request.application})
self.run.log(usage_stats)
self.chat_table.log(result_dict)
return ChatResponse(**result_dict)
except Exception as e:
with Timer() as timer:
Expand Down
2 changes: 0 additions & 2 deletions src/wandbot/chat/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,3 @@ class ChatConfig(BaseSettings):
env="WANDB_INDEX_ARTIFACT",
validation_alias="wandb_index_artifact",
)
wandb_project: str | None = Field("wandbot_public", env="WANDB_PROJECT")
wandb_entity: str | None = Field("wandbot", env="WANDB_ENTITY")
139 changes: 87 additions & 52 deletions src/wandbot/chat/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,22 @@

from langchain_community.callbacks import get_openai_callback
from pydantic import BaseModel
from wandbot.rag.query_handler import QueryEnhancer
from wandbot.rag.response_synthesis import ResponseSynthesizer
from wandbot.rag.retrieval import FusionRetrieval
from wandbot.retriever import VectorStore
from wandbot.utils import Timer, get_logger

logger = get_logger(__name__)
import litellm

from wandbot.retriever import VectorStore
from wandbot.utils import Timer, get_logger, RAGPipelineConfig
from wandbot.rag.query_handler import QueryEnhancer
from wandbot.rag.retrieval import FusionRetrieval, SimpleRetrievalWithoutFusion
from wandbot.rag.response_synthesis import ResponseSynthesizer, SimpleResponseSynthesizer
from wandbot.retriever.base import SimpleRetrievalEngine, SimpleRetrievalEngineWithRerank
from wandbot.rag.utils import LiteLLMTokenCostLogger

def get_stats_dict_from_token_callback(token_callback):
return {
"total_tokens": token_callback.total_tokens,
"prompt_tokens": token_callback.prompt_tokens,
"completion_tokens": token_callback.completion_tokens,
"successful_requests": token_callback.successful_requests,
}

logger = get_logger(__name__)

def get_stats_dict_from_timer(timer):
return {
"start_time": timer.start,
"end_time": timer.stop,
"time_taken": timer.elapsed,
}
litellm_token_and_cost_counter = LiteLLMTokenCostLogger()
litellm.callbacks = [litellm_token_and_cost_counter]


class RAGPipelineOutput(BaseModel):
Expand All @@ -39,6 +31,7 @@ class RAGPipelineOutput(BaseModel):
total_tokens: int
prompt_tokens: int
completion_tokens: int
completion_cost: float
time_taken: float
start_time: datetime.datetime
end_time: datetime.datetime
Expand All @@ -48,56 +41,98 @@ class RAGPipeline:
def __init__(
self,
vector_store: VectorStore,
config: RAGPipelineConfig,
top_k: int = 5,
search_type: str = "mmr",
):
self.vector_store = vector_store
self.query_enhancer = QueryEnhancer()
self.retrieval = FusionRetrieval(
vector_store=vector_store, top_k=top_k, search_type=search_type
)
self.response_synthesizer = ResponseSynthesizer()
self.config = config

if config.query_enhancer.enabled:
self.query_enhancer = QueryEnhancer(config=config)
if config.rerank_fusion.enabled:
self.retrieval = FusionRetrieval(
vector_store=vector_store, top_k=top_k, search_type=search_type
)
else:
self.retrieval = SimpleRetrievalWithoutFusion(
vector_store=vector_store, top_k=top_k, search_type=search_type
)
self.response_synthesizer = ResponseSynthesizer(config=config)
else:
if config.retrieval_re_ranker.enabled:
self.retrieval = SimpleRetrievalEngineWithRerank(
vector_store=vector_store, top_k=top_k, search_type=search_type
)
else:
self.retrieval = SimpleRetrievalEngine(
vector_store=vector_store, top_k=top_k, search_type=search_type
)
self.response_synthesizer = SimpleResponseSynthesizer(config=config)

def __call__(
self, question: str, chat_history: List[Tuple[str, str]] | None = None
):
) -> RAGPipelineOutput:
if chat_history is None:
chat_history = []

with get_openai_callback() as query_enhancer_cb, Timer() as query_enhancer_tb:
enhanced_query = self.query_enhancer.chain.invoke(
{"query": question, "chat_history": chat_history}
)

with Timer() as retrieval_tb:
retrieval_results = self.retrieval.chain.invoke(enhanced_query)

with get_openai_callback() as response_cb, Timer() as response_tb:
if self.config.query_enhancer.enabled:
with Timer() as query_enhancer_tb:
enhanced_query = self.query_enhancer.chain.invoke(
{"query": question, "chat_history": chat_history}
)
with Timer() as retrieval_tb:
retrieval_results = self.retrieval.chain.invoke(enhanced_query)
else:
with Timer() as retrieval_tb:
retrieval_results = self.retrieval(
question=question, language="en", top_k=self.retrieval.top_k, search_type=self.retrieval.search_type
)

with Timer() as response_tb:
response = self.response_synthesizer.chain.invoke(retrieval_results)

question = question if not self.config.query_enhancer.enabled else enhanced_query["standalone_query"]
usage_stats = litellm_token_and_cost_counter.get_totals()
time_taken = (
query_enhancer_tb.elapsed
+ retrieval_tb.elapsed
+ response_tb.elapsed
if self.config.query_enhancer.enabled
else retrieval_tb.elapsed + response_tb.elapsed
)
start_time = (
query_enhancer_tb.start
if self.config.query_enhancer.enabled
else retrieval_tb.start
)
end_time = response_tb.stop

if isinstance(retrieval_results, dict):
context = retrieval_results["context"]
else:
context = retrieval_results.context
sources = "\n".join(
[
item.metadata.get("source")
for item in context
]
)

output = RAGPipelineOutput(
question=enhanced_query["standalone_query"],
question=question,
answer=response["response"],
sources="\n".join(
[
item.metadata["source"]
for item in retrieval_results["context"]
]
),
sources=sources,
source_documents=response["context_str"],
system_prompt=response["response_prompt"],
model=response["response_model"],
total_tokens=query_enhancer_cb.total_tokens
+ response_cb.total_tokens,
prompt_tokens=query_enhancer_cb.prompt_tokens
+ response_cb.prompt_tokens,
completion_tokens=query_enhancer_cb.completion_tokens
+ response_cb.completion_tokens,
time_taken=query_enhancer_tb.elapsed
+ retrieval_tb.elapsed
+ response_tb.elapsed,
start_time=query_enhancer_tb.start,
end_time=response_tb.stop,
total_tokens=usage_stats.total_tokens,
prompt_tokens=usage_stats.prompt_tokens,
completion_tokens=usage_stats.completion_tokens,
completion_cost=usage_stats.total_cost,
time_taken=time_taken,
start_time=start_time,
end_time=end_time,
)

return output
Loading