Skip to content

Commit

Permalink
Merge pull request #48 from ericmjl/bm25-querybot
Browse files Browse the repository at this point in the history
Bm25 querybot
  • Loading branch information
ericmjl authored Mar 16, 2024
2 parents 8f03794 + 5d28f7e commit c8d0517
Show file tree
Hide file tree
Showing 11 changed files with 89 additions and 688 deletions.
674 changes: 6 additions & 668 deletions docs/examples/simplebot.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/releases/v0.2.5.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ There are no new features in this release.

There are no deprecations in this release.

Note: The commit `48bb8c4` is related to version bump and does not introduce any new features or bug fixes. The commit `faa971d` is related to adding release notes and does not introduce any new features or bug fixes. Therefore, they are not included in the release notes.
Note: The commit `48bb8c4` is related to version bump and does not introduce any new features or bug fixes. The commit `faa971d` is related to adding release notes and does not introduce any new features or bug fixes. Therefore, they are not included in the release notes.
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,4 @@ dependencies:
- litellm
- pydantic>=2.0
- pdfminer.six
- rank-bm25
21 changes: 14 additions & 7 deletions llamabot/bot/querybot.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from llamabot.bot.simplebot import SimpleBot
from llamabot.components.messages import AIMessage, HumanMessage
from llamabot.components.docstore import DocumentStore
from llamabot.components.api import APIMixin
from llamabot.components.chatui import ChatUIMixin
from llamabot.components.messages import (
RetrievedMessage,
retrieve_messages_up_to_budget,
Expand All @@ -24,32 +24,36 @@
prompt_recorder_var = contextvars.ContextVar("prompt_recorder")


class QueryBot(SimpleBot, DocumentStore, APIMixin):
"""QueryBot is a bot that uses simple RAG to answer questions about a document."""
class QueryBot(SimpleBot, DocumentStore, ChatUIMixin):
"""QueryBot is a bot that uses the DocumentStore to answer questions about a document."""

def __init__(
self,
system_prompt: str,
collection_name: str,
initial_message: Optional[str] = None,
document_paths: Optional[Path | list[Path]] = None,
temperature: float = 0.0,
model_name: str = default_language_model(),
stream_target: str = "stdout",
**kwargs,
):
SimpleBot.__init__(
self,
system_prompt=system_prompt,
temperature=temperature,
model_name=model_name,
stream_target="stdout",
stream_target=stream_target,
**kwargs,
)
DocumentStore.__init__(self, collection_name=slugify(collection_name))
if document_paths:
self.add_documents(document_paths=document_paths)
self.response_budget = 2_000

def __call__(self, query: str, n_results: int = 20) -> AIMessage:
ChatUIMixin.__init__(self, initial_message)

def __call__(self, query: str, n_results: int = 10) -> AIMessage:
"""Query documents within QueryBot's document store.
We use RAG to query out documents.
Expand All @@ -70,5 +74,8 @@ def __call__(self, query: str, n_results: int = 20) -> AIMessage:
)
messages.extend(retrieved)
messages.append(HumanMessage(content=query))
response: AIMessage = self.stream_stdout(messages)
return response
if self.stream_target == "stdout":
response: AIMessage = self.stream_stdout(messages)
return response
elif self.stream_target == "panel":
return self.stream_panel(messages)
21 changes: 20 additions & 1 deletion llamabot/cli/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,35 @@

@app.command()
def chat(
model_name: str = typer.Option(
"mistral/mistral-medium", help="Name of the LLM to use."
),
initial_message: str = typer.Option(..., help="Initial message for the bot."),
panel: bool = typer.Option(True, help="Whether to use Panel or not."),
doc_path: Path = typer.Argument(
"", help="Path to the document you wish to chat with."
)
),
):
"""Chat with your document.
:param model_name: Name of the LLM to use.
:param panel: Whether to use Panel or not. If not, we default to using CLI chat.
:param initial_message: The initial message to send to the user.
:param doc_path: Path to the document you wish to chat with.
"""
stream_target = "stdout"
if panel:
stream_target = "panel"

bot = QueryBot(
system_prompt=(
"You are a bot that can answer questions about a document provided to you."
),
collection_name=doc_path.stem.lower().replace(" ", "-"),
document_paths=[doc_path],
model_name=model_name,
initial_message=initial_message,
stream_target=stream_target,
)
typer.echo(
(
Expand All @@ -35,6 +50,10 @@ def chat(
)
)

if panel:
print("Serving your document in a panel...")
bot.serve()

while True:
query = uniform_prompt()
exit_if_asked(query)
Expand Down
15 changes: 14 additions & 1 deletion llamabot/components/chatui.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,31 @@
class ChatUIMixin:
"""A mixin for a chat user interface."""

def __init__(self, callback_function: Optional[Callable] = None):
def __init__(
self,
initial_message: Optional[str] = None,
callback_function: Optional[Callable] = None,
):
self.callback_function = callback_function
if callback_function is None:
self.callback_function = lambda ai_message, user, instance: self(ai_message)

self.chat_interface = pn.chat.ChatInterface(
callback=self.callback_function, callback_exception="verbose"
)
if initial_message is not None:
self.chat_interface.send(initial_message, user="System", respond=False)

def servable(self):
"""Return the chat interface as a Panel servable object.
:returns: The chat interface as a Panel servable object.
"""
return self.chat_interface.servable()

def serve(self):
"""Serve the chat interface.
:returns: None
"""
self.chat_interface.show()
18 changes: 17 additions & 1 deletion llamabot/components/docstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@
ChromaDB is a great default choice because of its simplicity and FOSS nature.
Hence we use it by default.
"""

from pathlib import Path
import chromadb
from hashlib import sha256
from chromadb import QueryResult
from llamabot.doc_processor import magic_load_doc, split_document
from tqdm.auto import tqdm
from rank_bm25 import BM25Okapi


class DocumentStore:
Expand Down Expand Up @@ -63,10 +65,24 @@ def retrieve(self, query: str, n_results: int = 10) -> list[str]:
:param query: The query to use to retrieve documents.
"""
# Use BM25 to get documents.
self.existing_records = self.collection.get()
tokenized_documents = [
doc.split() for doc in self.existing_records["documents"]
]
search_engine = BM25Okapi(tokenized_documents)
bm25_documents: list[str] = search_engine.get_top_n(
query.split(), self.existing_records["documents"], n=n_results
)
# Use Vectordb to get documents.
results: QueryResult = self.collection.query(
query_texts=query, n_results=n_results
)
return results["documents"][0]
vectordb_documents: list[str] = results["documents"][0]

# Return the union of the retrieved documents
union = set(vectordb_documents).union(bm25_documents)
return list(union)

def reset(self):
"""Reset the document store."""
Expand Down
2 changes: 1 addition & 1 deletion llamabot/prompt_library/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def commitbot():
"""
return SimpleBot(
"You are an expert user of Git.",
model_name="mistral/mistral-medium",
model_name="gpt-4-0125-preview",
stream_target="stdout",
)

Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ dependencies = [
"chromadb",
"python-slugify",
"pydantic>=2.0",
"pdfminer.six"
"pdfminer.six",
"rank-bm25",
]
requires-python = ">3.10"
description = "A Pythonic interface to LLMs."
Expand Down
16 changes: 10 additions & 6 deletions tests/bot/test_querybot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,25 @@
collection_name=st.text(
alphabet="abcdefghijklmnopqrstuvwxyz0123456789", min_size=4, max_size=63
),
dummy_text=st.text(),
mock_response=st.text(),
human_message=st.text(),
dummy_text=st.text(min_size=400),
mock_response=st.text(min_size=4),
stream_target=st.one_of(st.just("panel"), st.just("stdout")),
)
@settings(suppress_health_check=[HealthCheck.function_scoped_fixture], deadline=None)
def test_querybot_init(
tmp_path, system_prompt, collection_name, dummy_text, mock_response, human_message
def test_querybot(
tmp_path, system_prompt, collection_name, dummy_text, mock_response, stream_target
):
"""Test initialization of QueryBot."""
tempfile = tmp_path / "test.txt"
tempfile.write_text(dummy_text)

QueryBot(
bot = QueryBot(
system_prompt=system_prompt,
collection_name=collection_name,
document_paths=tempfile,
mock_response=mock_response,
stream_target=stream_target,
)

bot("How are you doing?")
bot.reset()
4 changes: 3 additions & 1 deletion tests/components/test_docstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ def test_add_documents(tmp_path: Path):
retrieved_documents = docstore.retrieve("query", n_results=2)

# Assert that the retrieved documents match the added documents
assert retrieved_documents == ["content of document1", "content of document2"]
assert set(retrieved_documents) == set(
["content of document1", "content of document2"]
)

# Clean up the temporary collection
docstore.client.delete_collection(collection_name)

0 comments on commit c8d0517

Please sign in to comment.