feat(rag): parameterize chat similarity token limits (#141)
Parameterize the chat similarity top-k and the chat memory token limit. This increases the context made available to the model and thus improves the accuracy of results.
mawandm authored Jul 27, 2024
1 parent 2817f0e commit 69430ef
Showing 5 changed files with 80 additions and 1 deletion.
12 changes: 11 additions & 1 deletion nesis/rag/core/server/chat/chat_service.py
@@ -8,6 +8,7 @@
 )
 from llama_index.core.indices.postprocessor import MetadataReplacementPostProcessor
 from llama_index.core.llms import ChatMessage, MessageRole
+from llama_index.core.memory import ChatMemoryBuffer
 from llama_index.core.types import TokenGen
 from pydantic import BaseModel
 
@@ -23,6 +24,7 @@
 )
 from nesis.rag.core.open_ai.extensions.context_filter import ContextFilter
 from nesis.rag.core.server.chunks.chunks_service import Chunk
+from nesis.rag.core.settings.settings import Settings
 
 
 class Completion(BaseModel):
@@ -78,8 +80,10 @@ def __init__(
         vector_store_component: VectorStoreComponent,
         embedding_component: EmbeddingComponent,
         node_store_component: NodeStoreComponent,
+        settings: Settings,
     ) -> None:
         self.llm_service = llm_component
+        self.settings = settings
         self.vector_store_component = vector_store_component
         self.storage_context = StorageContext.from_defaults(
             vector_store=vector_store_component.vector_store,
@@ -104,11 +108,17 @@ def _chat_engine(
     ) -> BaseChatEngine:
         if use_context:
             vector_index_retriever = self.vector_store_component.get_retriever(
-                index=self.index, context_filter=context_filter
+                index=self.index,
+                context_filter=context_filter,
+                similarity_top_k=self.settings.vectorstore.similarity_top_k,
             )
+            memory = ChatMemoryBuffer.from_defaults(
+                token_limit=self.settings.llm.token_limit
+            )
             return ContextChatEngine.from_defaults(
                 system_prompt=system_prompt,
                 retriever=vector_index_retriever,
+                memory=memory,
                 service_context=self.service_context,
                 node_postprocessors=[
                     MetadataReplacementPostProcessor(target_metadata_key="window"),
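For orientation, a minimal sketch of what the newly wired ChatMemoryBuffer does, using llama_index's public API; the message contents are illustrative, and 9439 simply mirrors the new default. The buffer keeps chat history within a token budget, trimming the oldest messages first.

from llama_index.core.llms import ChatMessage, MessageRole
from llama_index.core.memory import ChatMemoryBuffer

# Mirrors ChatMemoryBuffer.from_defaults as used in _chat_engine above.
# 9439 matches the new default; any positive token budget works.
memory = ChatMemoryBuffer.from_defaults(token_limit=9439)
memory.put(ChatMessage(role=MessageRole.USER, content="Summarize RFC 791."))
memory.put(ChatMessage(role=MessageRole.ASSISTANT, content="RFC 791 specifies IPv4."))

# get() returns only as much recent history as fits the token budget,
# which is what the ContextChatEngine receives on each turn.
print(memory.get())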
5 changes: 5 additions & 0 deletions nesis/rag/core/settings/settings.py
@@ -106,10 +106,15 @@ class LLMSettings(BaseModel):
         "like `HuggingFaceH4/zephyr-7b-beta`. If not set, will load a tokenizer matching "
         "gpt-3.5-turbo LLM.",
     )
+    token_limit: int = Field(
+        9439,
+        description="The maximum number of chat memory tokens.",
+    )
 
 
 class VectorstoreSettings(BaseModel):
     database: Literal["chroma", "qdrant", "pgvector"]
+    similarity_top_k: int
 
 
 class LocalSettings(BaseModel):
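As a usage sketch, both new fields can be overridden programmatically through the settings() factory's overrides mapping, the same mechanism the new test below uses; the values here are illustrative, not recommended defaults.

from nesis.rag.core.settings.settings import settings

# Illustrative overrides via the settings() factory (same mechanism as
# the test fixture below); 16000 and 10 are arbitrary example values.
app_settings = settings(
    overrides={
        "llm": {"token_limit": 16000},
        "vectorstore": {"similarity_top_k": 10},
    }
)

assert app_settings.llm.token_limit == 16000
assert app_settings.vectorstore.similarity_top_k == 10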
2 changes: 2 additions & 0 deletions nesis/rag/settings.yaml
@@ -12,6 +12,7 @@ llm:
   # Should be matching the selected model
   max_new_tokens: 512
   context_window: 3900
+  token_limit: ${NESIS_RAG_LLM_TOKEN_LIMIT:9439}
   tokenizer: mistralai/Mistral-7B-Instruct-v0.2
 
 embedding:
@@ -21,6 +22,7 @@ embedding:
 
 vectorstore:
   database: pgvector
+  similarity_top_k: ${NESIS_RAG_VECTORSTORE_SIMILARITY_TOP_K:5}
 
 pgvector:
   url: ${NESIS_RAG_PGVECTOR_URL:postgresql://postgres:password@localhost:65432/nesis}
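Both limits can likewise be tuned from the environment, assuming the ${NAME:default} placeholders in settings.yaml are expanded from the process environment at load time (as the existing NESIS_RAG_PGVECTOR_URL entry suggests); a sketch:

import os

# Set before the service loads settings.yaml; the values are examples only.
os.environ["NESIS_RAG_LLM_TOKEN_LIMIT"] = "16000"
os.environ["NESIS_RAG_VECTORSTORE_SIMILARITY_TOP_K"] = "10"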
Empty file.
62 changes: 62 additions & 0 deletions nesis/rag/tests/rag/core/server/chat/test_chat_service.py
@@ -0,0 +1,62 @@
+import pathlib
+
+import pytest
+from injector import Injector
+from llama_index.core.base.llms.types import ChatMessage
+
+from nesis.rag.core.server.chat.chat_service import ChatService, Completion
+from nesis.rag.core.server.ingest.ingest_service import IngestService
+from nesis.rag.core.settings.settings import Settings
+from nesis.rag import tests
+from nesis.rag.core.server.ingest.model import IngestedDoc
+
+
+@pytest.fixture
+def injector(settings) -> Injector:
+    from nesis.rag.core.di import create_application_injector
+
+    return create_application_injector(settings=settings)
+
+
+@pytest.fixture
+def settings() -> Settings:
+    from nesis.rag.core.settings.settings import settings
+
+    return settings(
+        overrides={
+            "llm": {"mode": "mock", "token_limit": 100000},
+            "vectorstore": {"similarity_top_k": "20"},
+        }
+    )
+
+
+def test_chat_service_similarity_top_k(injector):
+    """
+    Test that the similarity_top_k override (20 in the settings fixture) takes effect: the completion should cite exactly 20 source chunks.
+    """
+    file_path: pathlib.Path = (
+        pathlib.Path(tests.__file__).parent.absolute() / "resources" / "rfc791.txt"
+    )
+
+    ingest_service = injector.get(IngestService)
+
+    ingested_list: list[IngestedDoc] = ingest_service.ingest_file(
+        file_name=file_path.name,
+        file_data=file_path,
+        metadata={
+            "file_name": str(file_path.absolute()),
+            "datasource": "rfc-documents",
+        },
+    )
+
+    chat_service = injector.get(ChatService)
+    completion: Completion = chat_service.chat(
+        use_context=True,
+        messages=[
+            ChatMessage.from_str(
+                content="describe the internet protocol from the darpa internet program"
+            )
+        ],
+    )
+
+    assert len(completion.sources) == 20
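To exercise just this test (assuming pytest is installed and a vector store is reachable as configured above), one option is to invoke pytest programmatically:

import pytest

# Hypothetical invocation; the path matches the new test module above.
pytest.main(["-q", "nesis/rag/tests/rag/core/server/chat/test_chat_service.py"])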