Skip to content

Commit

Permalink
fix: newly created rag db not storing system_message (#176)
Browse files Browse the repository at this point in the history
  • Loading branch information
quitrk authored Mar 4, 2025
1 parent f4250c3 commit 3329ee7
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 13 deletions.
2 changes: 0 additions & 2 deletions skynet/modules/ttt/assistant/v1/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ class RagConfig(RagPayload):


class AssistantDocumentPayload(DocumentPayload):
hint: HintType = HintType.CONVERSATION
use_only_rag_data: bool = False

model_config = {
Expand All @@ -64,7 +63,6 @@ class AssistantDocumentPayload(DocumentPayload):
'text': 'User provided context here (will be appended to the RAG one)',
'prompt': 'User prompt here',
'max_completion_tokens': None,
'hint': 'conversation',
'use_only_rag_data': False, # If True and a vector store is available, only the RAG data will be used for assistance
}
]
Expand Down
20 changes: 10 additions & 10 deletions skynet/modules/ttt/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,25 +64,25 @@ def format_docs(docs: list[Document]) -> str:

async def assist(model: BaseChatModel, payload: DocumentPayload, customer_id: Optional[str] = None) -> str:
store = await get_vector_store()
vector_store = await store.get(customer_id)
config = await store.get_config(customer_id)
customer_store = await store.get(customer_id)
retriever = None
system_message = None
question = payload.prompt
is_generated_question = False

base_retriever = vector_store.as_retriever(search_kwargs={'k': 3}) if vector_store else None
retriever = (
ContextualCompressionRetriever(base_compressor=compressor, base_retriever=base_retriever)
if base_retriever
else None
)
if customer_store:
config = await store.get_config(customer_id)
system_message = config.system_message
base_retriever = customer_store.as_retriever(search_kwargs={'k': 3})
retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=base_retriever)

if retriever and payload.text:
question_payload = DocumentPayload(**(payload.model_dump() | {'prompt': assistant_rag_question_extractor}))
question = await process_text(model, question_payload)
is_generated_question = True

log.info(
f'Using {"generated" if is_generated_question else ""} question: {question} and system message: {config.system_message or "default"}'
f'Using {"generated " if is_generated_question else ""}question: {question}. System message: {system_message or "default"}'
)

template = ChatPromptTemplate(
Expand All @@ -91,7 +91,7 @@ async def assist(model: BaseChatModel, payload: DocumentPayload, customer_id: Op
use_only_rag_data=payload.use_only_rag_data,
text=payload.text,
prompt=payload.prompt,
system_message=config.system_message,
system_message=system_message,
)
)

Expand Down
2 changes: 1 addition & 1 deletion skynet/modules/ttt/rag/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ async def update_from_urls(self, payload: RagPayload, store_id: str) -> Optional
return await self.update_config(store_id, system_message=payload.system_message)

await db.rpush(RUNNING_RAG_KEY, store_id)
config = RagConfig(urls=payload.urls, max_depth=payload.max_depth)
config = RagConfig(**payload.model_dump())
await db.set(store_id, RagConfig.model_dump_json(config))

task = asyncio.create_task(self.workflow(payload, store_id))
Expand Down

0 comments on commit 3329ee7

Please sign in to comment.