Revert "conditional edge that checks for hallucinations (#401)" (#402)
This reverts commit afb0f35.
davidx33 authored Dec 5, 2024
1 parent afb0f35 commit 8b86463
Showing 2 changed files with 5 additions and 50 deletions.
50 changes: 4 additions & 46 deletions backend/retrieval_graph/graph.py
@@ -11,7 +11,6 @@
 from langchain_core.messages import BaseMessage
 from langchain_core.runnables import RunnableConfig
 from langgraph.graph import END, START, StateGraph
-from pydantic import BaseModel, Field

 from backend.retrieval_graph.configuration import AgentConfiguration
 from backend.retrieval_graph.researcher_graph.graph import graph as researcher_graph
@@ -150,7 +149,6 @@ class Plan(TypedDict):
         "steps": response["steps"],
         "documents": "delete",
         "query": state.messages[-1].content,
-        "num_response_attempts": 0,
     }


@@ -209,57 +207,18 @@ async def respond(
     """
     configuration = AgentConfiguration.from_runnable_config(config)
     model = load_chat_model(configuration.response_model)
-    num_response_attempts = state.num_response_attempts
     # TODO: add a re-ranker here
     top_k = 20
     context = format_docs(state.documents[:top_k])
     prompt = configuration.response_system_prompt.format(context=context)
     messages = [{"role": "system", "content": prompt}] + state.messages
     response = await model.ainvoke(messages)
-    return {
-        "messages": [response],
-        "answer": response.content,
-        "num_response_attempts": num_response_attempts + 1,
-    }
-
-
-def check_hallucination(state: AgentState) -> Literal["respond", "end"]:
-    """Check if the answer is hallucinated."""
-    model = load_chat_model("openai/gpt-4o-mini")
-    top_k = 20
-    answer = state.answer
-    num_response_attempts = state.num_response_attempts
-    context = format_docs(state.documents[:top_k])
-
-    class GradeHallucinations(BaseModel):
-        """Binary score for hallucination present in generation answer."""
-
-        binary_score: str = Field(
-            description="Answer is grounded in the facts, 'yes' or 'no'"
-        )
-
-    grade_hallucinations_llm = model.with_structured_output(GradeHallucinations)
-    grade_hallucinations_system_prompt = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n
-    Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."""
-    grade_hallucinations_prompt = (
-        "Set of facts: \n\n {context} \n\n LLM generation: {answer}"
-    )
-    grade_hallucinations_prompt_formatted = grade_hallucinations_prompt.format(
-        context=context, answer=answer
-    )
-    result = grade_hallucinations_llm.invoke(
-        [
-            {"role": "system", "content": grade_hallucinations_system_prompt},
-            {"role": "human", "content": grade_hallucinations_prompt_formatted},
-        ]
-    )
-    if result.binary_score == "yes" or num_response_attempts >= 2:
-        return "end"
-    else:
-        return "respond"
+    return {"messages": [response], "answer": response.content}


 # Define the graph
 builder = StateGraph(AgentState, input=InputState, config_schema=AgentConfiguration)
 builder.add_node(create_research_plan)
 builder.add_node(conduct_research)
@@ -268,9 +227,8 @@ class GradeHallucinations(BaseModel):
 builder.add_edge(START, "create_research_plan")
 builder.add_edge("create_research_plan", "conduct_research")
 builder.add_conditional_edges("conduct_research", check_finished)
-builder.add_conditional_edges(
-    "respond", check_hallucination, {"end": END, "respond": "respond"}
-)
+builder.add_edge("respond", END)

 # Compile into a graph object that you can invoke and deploy.
 graph = builder.compile()
 graph.name = "RetrievalGraph"
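The check_hallucination/conditional-edge combination being removed above is a common LangGraph pattern: grade the generated answer with a structured-output LLM call, then route back to the respond node (retry) or to END based on the grade and a retry counter. Below is a minimal self-contained sketch of that pattern; the simplified AgentState, prompts, and node wiring are illustrative assumptions, and init_chat_model stands in for the repo's load_chat_model helper.

from dataclasses import dataclass
from typing import Literal

from langchain.chat_models import init_chat_model
from langgraph.graph import END, START, StateGraph
from pydantic import BaseModel, Field


class GradeHallucinations(BaseModel):
    """Binary score for hallucination present in the generated answer."""

    binary_score: str = Field(
        description="Answer is grounded in the facts, 'yes' or 'no'"
    )


@dataclass
class AgentState:
    # Simplified stand-in for the repo's AgentState (assumption).
    question: str = ""
    context: str = ""
    answer: str = ""
    num_response_attempts: int = 0


def respond(state: AgentState) -> dict:
    # Stand-in for load_chat_model(configuration.response_model).
    model = init_chat_model("openai:gpt-4o-mini")
    response = model.invoke(
        [
            {"role": "system", "content": f"Answer using only these facts:\n{state.context}"},
            {"role": "user", "content": state.question},
        ]
    )
    # Track attempts so the grader can cap retries.
    return {"answer": response.content, "num_response_attempts": state.num_response_attempts + 1}


def check_hallucination(state: AgentState) -> Literal["respond", "end"]:
    # Structured output forces the grader to return a GradeHallucinations instance.
    grader = init_chat_model("openai:gpt-4o-mini").with_structured_output(GradeHallucinations)
    result = grader.invoke(
        [
            {"role": "system", "content": "Grade whether the answer is grounded in the facts; reply 'yes' or 'no'."},
            {"role": "user", "content": f"Facts:\n{state.context}\n\nAnswer:\n{state.answer}"},
        ]
    )
    # Accept a grounded answer, or give up after two attempts.
    if result.binary_score == "yes" or state.num_response_attempts >= 2:
        return "end"
    return "respond"


builder = StateGraph(AgentState)
builder.add_node("respond", respond)
builder.add_edge(START, "respond")
# The conditional edge maps the grader's return value to a destination node.
builder.add_conditional_edges("respond", check_hallucination, {"end": END, "respond": "respond"})
graph = builder.compile()

Invoking it with graph.invoke({"question": ..., "context": ...}) loops respond until the grade is 'yes' or the attempt cap is hit, which is why the revert also drops the num_response_attempts state field below.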
5 changes: 1 addition & 4 deletions backend/retrieval_graph/state.py
@@ -80,8 +80,5 @@ class AgentState(InputState):
     documents: Annotated[list[Document], reduce_docs] = field(default_factory=list)
     """Populated by the retriever. This is a list of documents that the agent can reference."""
     answer: str = field(default="")
-    """Final answer. Useful for evaluations."""
+    """Final answer. Useful for evaluations"""
     query: str = field(default="")
     """The user's query."""
-    num_response_attempts: int = field(default=0)
-    """The number of times the agent has tried to respond."""
