diff --git a/core/quivr_core/rag/prompts.py b/core/quivr_core/rag/prompts.py
index e24ca6bd26f7..dcb2b3ce146b 100644
--- a/core/quivr_core/rag/prompts.py
+++ b/core/quivr_core/rag/prompts.py
@@ -1,14 +1,14 @@
 import datetime
 
-from pydantic import ConfigDict, create_model
-from langchain_core.prompts.base import BasePromptTemplate
 from langchain_core.prompts import (
     ChatPromptTemplate,
     HumanMessagePromptTemplate,
+    MessagesPlaceholder,
     PromptTemplate,
     SystemMessagePromptTemplate,
-    MessagesPlaceholder,
 )
+from langchain_core.prompts.base import BasePromptTemplate
+from pydantic import ConfigDict, create_model
 
 
 class CustomPromptsDict(dict):
@@ -258,6 +258,30 @@ def _define_custom_prompts() -> CustomPromptsDict:
 
     custom_prompts["TOOL_ROUTING_PROMPT"] = TOOL_ROUTING_PROMPT
 
+    system_message_zendesk_template = """
+
+    - You are a Zendesk Agent.
+    - You are answering a client query.
+    - Based on the following similar client tickets, provide a response to the client query in the same format.
+
+    ------ Zendesk Similar Tickets ------
+    {similar_tickets}
+    -------------------------------------
+
+    ------ Client Query ------
+    {client_query}
+    --------------------------
+
+    Agent :
+    """
+
+    ZENDESK_TEMPLATE_PROMPT = ChatPromptTemplate.from_messages(
+        [
+            SystemMessagePromptTemplate.from_template(system_message_zendesk_template),
+        ]
+    )
+    custom_prompts["ZENDESK_TEMPLATE_PROMPT"] = ZENDESK_TEMPLATE_PROMPT
+
     return custom_prompts
 
 
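A minimal sketch (not part of the patch) of how the new `ZENDESK_TEMPLATE_PROMPT` renders on its own. It assumes the module-level `custom_prompts` registry is importable as shown below; the ticket and query strings are made-up placeholders.

```python
# Sketch only: render ZENDESK_TEMPLATE_PROMPT outside the graph.
# Assumes `custom_prompts` is the module-level prompt registry defined in
# quivr_core.rag.prompts; the ticket and query strings are fabricated.
from quivr_core.rag.prompts import custom_prompts

messages = custom_prompts.ZENDESK_TEMPLATE_PROMPT.format_messages(
    similar_tickets="Ticket #42: 'Cannot log in' -> Agent: please reset your password via ...",
    client_query="I can't log into my account, can you help?",
)

# The template produces a single system message containing both the similar
# tickets block and the client query block.
print(messages[0].content)
```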
diff --git a/core/quivr_core/rag/quivr_rag_langgraph.py b/core/quivr_core/rag/quivr_rag_langgraph.py
index 13f3cdbcf662..2b9e252a21e4 100644
--- a/core/quivr_core/rag/quivr_rag_langgraph.py
+++ b/core/quivr_core/rag/quivr_rag_langgraph.py
@@ -13,6 +13,7 @@
     TypedDict,
 )
 from uuid import UUID, uuid4
+
 import openai
 from langchain.retrievers import ContextualCompressionRetriever
 from langchain_cohere import CohereRerank
@@ -27,6 +28,7 @@
 from langgraph.graph.message import add_messages
 from langgraph.types import Send
 from pydantic import BaseModel, Field
+from quivr_api.modules.vector.service.vector_service import VectorService
 
 from quivr_core.llm import LLMEndpoint
 from quivr_core.llm_tools.llm_tools import LLMToolFactory
@@ -248,6 +250,7 @@ def __init__(
         retrieval_config: RetrievalConfig,
         llm: LLMEndpoint,
         vector_store: VectorStore | None = None,
+        vector_service: VectorService | None = None,
     ):
         """
         Construct a QuivrQARAGLangGraph object.
@@ -261,6 +264,7 @@ def __init__(
         self.retrieval_config = retrieval_config
         self.vector_store = vector_store
         self.llm_endpoint = llm
+        self.vector_service = vector_service
         self.graph = None
 
 
@@ -752,6 +756,49 @@ def _sort_docs_by_relevance(self, docs: List[Document]) -> List[Document]:
             reverse=True,
         )
 
+    def retrieve_full_documents_context(self, state: AgentState) -> AgentState:
+        task = state["tasks"]
+        docs = task.docs if task else []
+
+        relevant_knowledge = {}
+        for doc in docs:
+            knowledge_id = doc.metadata["knowledge_id"]
+            similarity_score = doc.metadata.get("similarity", 0)
+            if knowledge_id in relevant_knowledge:
+                relevant_knowledge[knowledge_id]["count"] += 1
+                relevant_knowledge[knowledge_id]["max_similarity_score"] = max(
+                    relevant_knowledge[knowledge_id]["max_similarity_score"], similarity_score
+                )
+            else:
+                relevant_knowledge[knowledge_id] = {
+                    "count": 1,
+                    "max_similarity_score": similarity_score,
+                    "index": doc.metadata["index"],
+                }
+
+        # FIXME: Tweak this to return the most relevant knowledges
+        top_knowledge_ids = sorted(
+            relevant_knowledge.keys(),
+            key=lambda x: (
+                relevant_knowledge[x]["max_similarity_score"],
+                relevant_knowledge[x]["count"],
+            ),
+            reverse=True,
+        )[:3]
+
+        _docs = []
+
+        for knowledge_id in top_knowledge_ids:
+            _docs.append(
+                self.vector_service.get_vectors_by_knowledge_id(
+                    knowledge_id, end_index=relevant_knowledge[knowledge_id]["index"]
+                )
+            )
+
+        task.set_docs(id=uuid4(), docs=_docs)  # FIXME: what is this id supposed to be?
+
+        return {**state, "tasks": task}
+
     def get_rag_context_length(self, state: AgentState, docs: List[Document]) -> int:
         final_inputs = self._build_rag_prompt_inputs(state, docs)
         msg = custom_prompts.RAG_ANSWER_PROMPT.format(**final_inputs)
@@ -836,6 +883,25 @@ def bind_tools_to_llm(self, node_name: str):
             return self.llm_endpoint._llm.bind_tools(tools, tool_choice="any")
         return self.llm_endpoint._llm
 
+    def generate_zendesk_rag(self, state: AgentState) -> AgentState:
+        tasks = state["tasks"]
+        docs = tasks.docs if tasks else []
+        messages = state["messages"]
+        user_task = messages[0].content
+        inputs = {
+            "similar_tickets": docs,
+            "client_query": user_task,
+        }
+        # state, inputs = self.reduce_rag_context(
+        #     state, inputs, custom_prompts.RAG_ANSWER_PROMPT
+        # )
+
+        msg = custom_prompts.ZENDESK_TEMPLATE_PROMPT.format(**inputs)
+        llm = self.bind_tools_to_llm(self.generate_zendesk_rag.__name__)
+        response = llm.invoke(msg)
+
+        return {**state, "messages": [response]}
+
     def generate_rag(self, state: AgentState) -> AgentState:
         tasks = state["tasks"]
         docs = tasks.docs if tasks else []
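For context on `retrieve_full_documents_context` above: retrieved chunks are grouped by `knowledge_id`, each group keeps its best similarity score and a chunk count, and the top three groups (best score first, count as tie-breaker) are then expanded into full documents via the vector service. Below is a self-contained sketch of that grouping step, not part of the patch, using fabricated `Document` metadata.

```python
from langchain_core.documents import Document

# Fabricated retrieved chunks spanning three knowledge entries.
docs = [
    Document(page_content="...", metadata={"knowledge_id": "kb-a", "similarity": 0.91, "index": 3}),
    Document(page_content="...", metadata={"knowledge_id": "kb-a", "similarity": 0.72, "index": 7}),
    Document(page_content="...", metadata={"knowledge_id": "kb-b", "similarity": 0.88, "index": 1}),
    Document(page_content="...", metadata={"knowledge_id": "kb-c", "similarity": 0.55, "index": 2}),
]

relevant_knowledge: dict[str, dict] = {}
for doc in docs:
    kid = doc.metadata["knowledge_id"]
    score = doc.metadata.get("similarity", 0)
    entry = relevant_knowledge.setdefault(
        kid, {"count": 0, "max_similarity_score": score, "index": doc.metadata["index"]}
    )
    entry["count"] += 1
    entry["max_similarity_score"] = max(entry["max_similarity_score"], score)

# Rank knowledge entries by best chunk score, then by how many chunks matched.
top_knowledge_ids = sorted(
    relevant_knowledge,
    key=lambda k: (
        relevant_knowledge[k]["max_similarity_score"],
        relevant_knowledge[k]["count"],
    ),
    reverse=True,
)[:3]

print(top_knowledge_ids)  # ['kb-a', 'kb-b', 'kb-c']
```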
diff --git a/core/tests/zendesk_rag_config_workflow.yaml b/core/tests/zendesk_rag_config_workflow.yaml
new file mode 100644
index 000000000000..bbd58d34cc01
--- /dev/null
+++ b/core/tests/zendesk_rag_config_workflow.yaml
@@ -0,0 +1,39 @@
+{
+  "max_files": 20,
+  "llm_config": { "temperature": 0.3, "max_context_tokens": 20000 },
+  "max_history": 10,
+  "reranker_config":
+    { "model": "rerank-v3.5", "top_n": 5, "supplier": "cohere" },
+  "workflow_config":
+    {
+      "name": "Standard RAG",
+      "nodes":
+        [
+          {
+            "name": "START",
+            "edges": ["filter_history"],
+            "description": "Starting workflow",
+          },
+          {
+            "name": "filter_history",
+            "edges": ["retrieve"],
+            "description": "Filtering history",
+          },
+          {
+            "name": "retrieve",
+            "edges": ["retrieve_full_documents_context"],
+            "description": "Retrieving relevant information",
+          },
+          {
+            "name": "retrieve_full_documents_context",
+            "edges": ["generate_zendesk_rag"],
+            "description": "Retrieving full tickets context",
+          },
+          {
+            "name": "generate_zendesk_rag",
+            "edges": ["END"],
+            "description": "Generating answer",
+          },
+        ],
+    },
+}
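As a quick sanity check on the new workflow file: every node name other than START and END presumably resolves to a node method on `QuivrQARAGLangGraph`, which is why the config uses exactly the names of the two methods added above. A small sketch, not part of the patch and assuming PyYAML tolerates the JSON-style flow syntax and trailing commas used here, that prints the declared node chain:

```python
import yaml

# Print the linear node chain declared by the Zendesk workflow config.
with open("core/tests/zendesk_rag_config_workflow.yaml") as f:
    config = yaml.safe_load(f)

nodes = {node["name"]: node for node in config["workflow_config"]["nodes"]}
current, chain = "START", ["START"]
while current in nodes and nodes[current]["edges"]:
    current = nodes[current]["edges"][0]  # every node here declares exactly one edge
    chain.append(current)

print(" -> ".join(chain))
# START -> filter_history -> retrieve -> retrieve_full_documents_context
#   -> generate_zendesk_rag -> END
```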