Commit

Parameterize token limit and similarity top k
mawandm committed Jul 27, 2024
1 parent 2817f0e commit 964e9ce
Showing 5 changed files with 80 additions and 1 deletion.
12 changes: 11 additions & 1 deletion nesis/rag/core/server/chat/chat_service.py
@@ -8,6 +8,7 @@
 )
 from llama_index.core.indices.postprocessor import MetadataReplacementPostProcessor
 from llama_index.core.llms import ChatMessage, MessageRole
+from llama_index.core.memory import ChatMemoryBuffer
 from llama_index.core.types import TokenGen
 from pydantic import BaseModel

@@ -23,6 +24,7 @@
 )
 from nesis.rag.core.open_ai.extensions.context_filter import ContextFilter
 from nesis.rag.core.server.chunks.chunks_service import Chunk
+from nesis.rag.core.settings.settings import Settings


 class Completion(BaseModel):
@@ -78,8 +80,10 @@ def __init__(
         vector_store_component: VectorStoreComponent,
         embedding_component: EmbeddingComponent,
         node_store_component: NodeStoreComponent,
+        settings: Settings,
     ) -> None:
         self.llm_service = llm_component
+        self.settings = settings
         self.vector_store_component = vector_store_component
         self.storage_context = StorageContext.from_defaults(
             vector_store=vector_store_component.vector_store,
@@ -104,11 +108,17 @@ def _chat_engine(
     ) -> BaseChatEngine:
         if use_context:
             vector_index_retriever = self.vector_store_component.get_retriever(
-                index=self.index, context_filter=context_filter
+                index=self.index,
+                context_filter=context_filter,
+                similarity_top_k=self.settings.vectorstore.similarity_top_k,
             )
+            memory = ChatMemoryBuffer.from_defaults(
+                token_limit=self.settings.llm.token_limit
+            )
             return ContextChatEngine.from_defaults(
                 system_prompt=system_prompt,
                 retriever=vector_index_retriever,
+                memory=memory,
                 service_context=self.service_context,
                 node_postprocessors=[
                     MetadataReplacementPostProcessor(target_metadata_key="window"),
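
For orientation, here is a minimal sketch, not part of this commit, of what the two new parameters do in llama_index: similarity_top_k caps how many retrieved chunks feed the context prompt, while ChatMemoryBuffer's token_limit caps how much chat history is replayed to the model. The document text, index construction, and model defaults below are assumptions for illustration only.

# Hypothetical sketch; assumes a default llama_index LLM/embedding configuration.
from llama_index.core import Document, VectorStoreIndex
from llama_index.core.chat_engine import ContextChatEngine
from llama_index.core.memory import ChatMemoryBuffer

index = VectorStoreIndex.from_documents([Document(text="The Internet Protocol ...")])

# similarity_top_k bounds how many source chunks the retriever returns per query.
retriever = index.as_retriever(similarity_top_k=5)

# token_limit bounds how many tokens of chat history the buffer retains.
memory = ChatMemoryBuffer.from_defaults(token_limit=9439)

chat_engine = ContextChatEngine.from_defaults(
    retriever=retriever,
    memory=memory,
    system_prompt="Answer strictly from the provided context.",
)
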
5 changes: 5 additions & 0 deletions nesis/rag/core/settings/settings.py
@@ -106,10 +106,15 @@ class LLMSettings(BaseModel):
         "like `HuggingFaceH4/zephyr-7b-beta`. If not set, will load a tokenizer matching "
         "gpt-3.5-turbo LLM.",
     )
+    token_limit: int = Field(
+        9439,
+        description="The maximum number of chat memory tokens.",
+    )


 class VectorstoreSettings(BaseModel):
     database: Literal["chroma", "qdrant", "pgvector"]
+    similarity_top_k: int


 class LocalSettings(BaseModel):
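
A detail worth noting: pydantic's default (lax) validation coerces compatible strings to the declared field types, which is why the test below can pass similarity_top_k as the string "20" and still get 20 sources back. A standalone sketch of that behavior, with the field set trimmed to the two new settings:

from typing import Literal

from pydantic import BaseModel, Field


class LLMSettings(BaseModel):
    token_limit: int = Field(
        9439,
        description="The maximum number of chat memory tokens.",
    )


class VectorstoreSettings(BaseModel):
    database: Literal["chroma", "qdrant", "pgvector"]
    similarity_top_k: int


# Pydantic coerces the numeric string to an int during validation.
assert VectorstoreSettings(database="pgvector", similarity_top_k="20").similarity_top_k == 20
# An omitted field falls back to its declared default.
assert LLMSettings().token_limit == 9439
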
2 changes: 2 additions & 0 deletions nesis/rag/settings.yaml
@@ -12,6 +12,7 @@ llm:
   # Should be matching the selected model
   max_new_tokens: 512
   context_window: 3900
+  token_limit: ${NESIS_RAG_LLM_TOKEN_LIMIT:9439}
   tokenizer: mistralai/Mistral-7B-Instruct-v0.2

 embedding:
@@ -21,6 +22,7 @@ embedding:

 vectorstore:
   database: pgvector
+  similarity_top_k: ${NESIS_RAG_VECTORSTORE_SIMILARITY_TOP_K:5}

 pgvector:
   url: ${NESIS_RAG_PGVECTOR_URL:postgresql://postgres:password@localhost:65432/nesis}
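
The ${NAME:default} placeholders imply environment-variable interpolation when settings.yaml is loaded. The loader itself is not part of this diff, so the following is only an assumed sketch of the syntax's semantics:

import os
import re

# Assumed semantics of ${ENV_VAR:default}; the real nesis loader is not shown
# in this commit and may differ.
_PLACEHOLDER = re.compile(r"\$\{(\w+):([^}]*)\}")


def interpolate(raw: str) -> str:
    # Substitute each ${NAME:default} with os.environ["NAME"] if set, else the default.
    return _PLACEHOLDER.sub(lambda m: os.environ.get(m.group(1), m.group(2)), raw)


# With NESIS_RAG_LLM_TOKEN_LIMIT unset, the default applies.
assert interpolate("token_limit: ${NESIS_RAG_LLM_TOKEN_LIMIT:9439}") == "token_limit: 9439"
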
Empty file.
62 changes: 62 additions & 0 deletions nesis/rag/tests/rag/core/server/chat/test_chat_service.py
@@ -0,0 +1,62 @@
+import pathlib
+
+import pytest
+from injector import Injector
+from llama_index.core.base.llms.types import ChatMessage
+
+from nesis.rag.core.server.chat.chat_service import ChatService, Completion
+from nesis.rag.core.server.ingest.ingest_service import IngestService
+from nesis.rag.core.settings.settings import Settings
+from nesis.rag import tests
+from nesis.rag.core.server.ingest.model import IngestedDoc
+
+
+@pytest.fixture
+def injector(settings) -> Injector:
+    from nesis.rag.core.di import create_application_injector
+
+    return create_application_injector(settings=settings)
+
+
+@pytest.fixture
+def settings() -> Settings:
+    from nesis.rag.core.settings.settings import settings
+
+    return settings(
+        overrides={
+            "llm": {"mode": "mock", "token_limit": 100000},
+            "vectorstore": {"similarity_top_k": "20"},
+        }
+    )
+
+
+def test_chat_service_similarity_top_k(injector):
+    """
+    Test to ensure the similarity_top_k setting takes effect.
+    """
+    file_path: pathlib.Path = (
+        pathlib.Path(tests.__file__).parent.absolute() / "resources" / "rfc791.txt"
+    )
+
+    ingest_service = injector.get(IngestService)
+
+    ingested_list: list[IngestedDoc] = ingest_service.ingest_file(
+        file_name=file_path.name,
+        file_data=file_path,
+        metadata={
+            "file_name": str(file_path.absolute()),
+            "datasource": "rfc-documents",
+        },
+    )
+
+    chat_service = injector.get(ChatService)
+    completion: Completion = chat_service.chat(
+        use_context=True,
+        messages=[
+            ChatMessage.from_str(
+                content="describe the internet protocol from the darpa internet program"
+            )
+        ],
+    )
+
+    assert len(completion.sources) == 20
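
Two assumptions make the final assertion meaningful: pydantic coerces the fixture's "20" override to an int, and rfc791.txt is presumably large enough that ingestion yields at least 20 retrievable chunks, so the retriever's top-k cap, not the corpus size, determines len(completion.sources).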
