Skip to content

Commit

Permalink
add: zendesk workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
chloedia committed Feb 3, 2025
1 parent c26d3ce commit ffb8b37
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 3 deletions.
30 changes: 27 additions & 3 deletions core/quivr_core/rag/prompts.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import datetime
from pydantic import ConfigDict, create_model

from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
PromptTemplate,
SystemMessagePromptTemplate,
MessagesPlaceholder,
)
from langchain_core.prompts.base import BasePromptTemplate
from pydantic import ConfigDict, create_model


class CustomPromptsDict(dict):
Expand Down Expand Up @@ -258,6 +258,30 @@ def _define_custom_prompts() -> CustomPromptsDict:

custom_prompts["TOOL_ROUTING_PROMPT"] = TOOL_ROUTING_PROMPT

system_message_zendesk_template = """
- You are a Zendesk Agent.
- You are answering a client query.
- Based on the following similar client tickets, provide a response to the client query in the same format.
------ Zendesk Similar Tickets ------
{similar_tickets}
-------------------------------------
------ Client Query ------
{client_query}
--------------------------
Agent :
"""

ZENDESK_TEMPLATE_PROMPT = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate.from_template(system_message_zendesk_template),
]
)
custom_prompts["ZENDESK_TEMPLATE_PROMPT"] = ZENDESK_TEMPLATE_PROMPT

return custom_prompts


Expand Down
66 changes: 66 additions & 0 deletions core/quivr_core/rag/quivr_rag_langgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
TypedDict,
)
from uuid import UUID, uuid4

import openai
from langchain.retrievers import ContextualCompressionRetriever
from langchain_cohere import CohereRerank
Expand All @@ -27,6 +28,7 @@
from langgraph.graph.message import add_messages
from langgraph.types import Send
from pydantic import BaseModel, Field
from quivr_api.modules.vector.service.vector_service import VectorService

from quivr_core.llm import LLMEndpoint
from quivr_core.llm_tools.llm_tools import LLMToolFactory
Expand Down Expand Up @@ -248,6 +250,7 @@ def __init__(
retrieval_config: RetrievalConfig,
llm: LLMEndpoint,
vector_store: VectorStore | None = None,
vector_service: VectorService | None = None,
):
"""
Construct a QuivrQARAGLangGraph object.
Expand All @@ -261,6 +264,7 @@ def __init__(
self.retrieval_config = retrieval_config
self.vector_store = vector_store
self.llm_endpoint = llm
self.vector_service = vector_service

self.graph = None

Expand Down Expand Up @@ -752,6 +756,49 @@ def _sort_docs_by_relevance(self, docs: List[Document]) -> List[Document]:
reverse=True,
)

def retrieve_full_documents_context(self, state: AgentState) -> AgentState:
task = state["tasks"]
docs = task.docs if task else []

relevant_knowledge = {}
for doc in docs:
knowledge_id = doc.metadata["knowledge_id"]
similarity_score = doc.metadata.get("similarity", 0)
if knowledge_id in relevant_knowledge:
relevant_knowledge[knowledge_id]["count"] += 1
relevant_knowledge[knowledge_id]["max_similarity_score"] = max(
relevant_knowledge[knowledge_id]["max_similarity"], similarity_score
)
else:
relevant_knowledge[knowledge_id] = {
"count": 1,
"max_similarity_score": similarity_score,
"index": doc.metadata["index"],
}

# FIXME: Tweak this to return the most relevant knowledges
top_knowledge_ids = sorted(
relevant_knowledge.keys(),
key=lambda x: (
relevant_knowledge[x]["max_similarity_score"],
relevant_knowledge[x]["count"],
),
reverse=True,
)[:3]

_docs = []

for knowledge_id in top_knowledge_ids:
_docs.append(
self.vector_service.get_vectors_by_knowledge_id(
knowledge_id, end_index=top_knowledge_ids[knowledge_id]["index"]
)
)

task.set_docs(id=uuid4(), docs=_docs) # FIXME what is supposed to be id ?

return {**state, "tasks": task}

def get_rag_context_length(self, state: AgentState, docs: List[Document]) -> int:
final_inputs = self._build_rag_prompt_inputs(state, docs)
msg = custom_prompts.RAG_ANSWER_PROMPT.format(**final_inputs)
Expand Down Expand Up @@ -836,6 +883,25 @@ def bind_tools_to_llm(self, node_name: str):
return self.llm_endpoint._llm.bind_tools(tools, tool_choice="any")
return self.llm_endpoint._llm

def generate_zendesk_rag(self, state: AgentState) -> AgentState:
tasks = state["tasks"]
docs = tasks.docs if tasks else []
messages = state["messages"]
user_task = messages[0].content
inputs = {
"similar_tickets": docs,
"client_query": user_task,
}
# state, inputs = self.reduce_rag_context(
# state, inputs, custom_prompts.RAG_ANSWER_PROMPT
# )

msg = custom_prompts.ZENDESK_TEMPLATE_PROMPT.format(**inputs)
llm = self.bind_tools_to_llm(self.generate_zendesk_rag.__name__)
response = llm.invoke(msg)

return {**state, "messages": [response]}

def generate_rag(self, state: AgentState) -> AgentState:
tasks = state["tasks"]
docs = tasks.docs if tasks else []
Expand Down
39 changes: 39 additions & 0 deletions core/tests/zendesk_rag_config_workflow.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
{
"max_files": 20,
"llm_config": { "temperature": 0.3, "max_context_tokens": 20000 },
"max_history": 10,
"reranker_config":
{ "model": "rerank-v3.5", "top_n": 5, "supplier": "cohere" },
"workflow_config":
{
"name": "Standard RAG",
"nodes":
[
{
"name": "START",
"edges": ["filter_history"],
"description": "Starting workflow",
},
{
"name": "filter_history",
"edges": ["retrieve"],
"description": "Filtering history",
},
{
"name": "retrieve",
"edges": ["retrieve_full_documents_context"],
"description": "Retrieving relevant information",
},
{
"name": "retrieve_full_documents_context",
"edges": ["generate_zendesk_rag"],
"description": "Retrieving full tickets context",
},
{
"name": "generate_zendesk_rag",
"edges": ["END"],
"description": "Generating answer",
},
],
},
}

0 comments on commit ffb8b37

Please sign in to comment.