Skip to content

Commit

Permalink
fix lints, redirect mock_endpoint to agent-retriever
Browse files Browse the repository at this point in the history
Signed-off-by: Jack Luar <[email protected]>
  • Loading branch information
luarss committed Feb 4, 2025
1 parent 2647d1c commit 333c34e
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 66 deletions.
2 changes: 1 addition & 1 deletion backend/src/api/models/response_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class UserInput(BaseModel):


class ContextSource(BaseModel):
sources: str = ""
source: str = ""
context: str = ""


Expand Down
60 changes: 29 additions & 31 deletions backend/src/api/routers/chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,30 +114,29 @@ async def get_hybrid_response(user_input: UserInput) -> ChatResponse:
context_sources = []
for i in result["context"]:
if "url" in i.metadata:
context_sources.append(ContextSource(
context=i.page_content,
source=i.metadata["url"]
))
context_sources.append(
ContextSource(context=i.page_content, source=i.metadata["url"])
)
elif "source" in i.metadata:
context_sources.append(ContextSource(
context=i.page_content,
source=i.metadata["source"]
))
context_sources.append(
ContextSource(context=i.page_content, source=i.metadata["source"])
)

if user_input.list_sources and user_input.list_context:
response = {
"response": result["answer"],
"context_sources": context_sources
}
response = {"response": result["answer"], "context_sources": context_sources}
elif user_input.list_sources:
response = {
"response": result["answer"],
"context_sources": [ContextSource(context="", source=cs.source) for cs in context_sources]
"response": result["answer"],
"context_sources": [
ContextSource(context="", source=cs.source) for cs in context_sources
],
}
elif user_input.list_context:
response = {
"response": result["answer"],
"context_sources": [ContextSource(context=cs.context, source="") for cs in context_sources]
"context_sources": [
ContextSource(context=cs.context, source="") for cs in context_sources
],
}
else:
response = {"response": result["answer"], "context_sources": []}
Expand Down Expand Up @@ -169,30 +168,29 @@ async def get_sim_response(user_input: UserInput) -> ChatResponse:
context_sources = []
for i in result["context"]:
if "url" in i.metadata:
context_sources.append(ContextSource(
context=i.page_content,
source=i.metadata["url"]
))
context_sources.append(
ContextSource(context=i.page_content, source=i.metadata["url"])
)
elif "source" in i.metadata:
context_sources.append(ContextSource(
context=i.page_content,
source=i.metadata["source"]
))
context_sources.append(
ContextSource(context=i.page_content, source=i.metadata["source"])
)

if user_input.list_sources and user_input.list_context:
response = {
"response": result["answer"],
"context_sources": context_sources
}
response = {"response": result["answer"], "context_sources": context_sources}
elif user_input.list_sources:
response = {
"response": result["answer"],
"context_sources": [ContextSource(context="", source=cs.source) for cs in context_sources]
"response": result["answer"],
"context_sources": [
ContextSource(context="", source=cs.source) for cs in context_sources
],
}
elif user_input.list_context:
response = {
"response": result["answer"],
"context_sources": [ContextSource(context=cs.context, source="") for cs in context_sources]
"response": result["answer"],
"context_sources": [
ContextSource(context=cs.context, source="") for cs in context_sources
],
}
else:
response = {"response": result["answer"], "context_sources": []}
Expand Down
37 changes: 26 additions & 11 deletions backend/src/api/routers/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ async def get_agent_response(user_input: UserInput) -> ChatResponse:
raise ValueError("RetrieverGraph not initialized.")
urls: list[str] = []
context: list[str] = []
context_sources: list[ContextSource] = []

if (
isinstance(output, list)
Expand All @@ -121,14 +122,12 @@ async def get_agent_response(user_input: UserInput) -> ChatResponse:

context_sources = []
tool_index = 1
for tool in tools:
urls.extend(list(output[tool_index].values())[0]["urls"])
context.append(list(output[tool_index].values())[0]["context"])
tool_index += 1

for url, context in zip(urls, [context]):
context_sources.append(ContextSource(context=context, source=url))
tool_index += 1
for tool_index, tool in enumerate(tools):
urls = list(output[tool_index].values())[0]["urls"]
context = list(output[tool_index].values())[0]["context"]

for _url, _context in zip(urls, context):
context_sources.append(ContextSource(context=_context, source=_url))
else:
llm_response = "LLM response extraction failed"
logging.error("LLM response extraction failed")
Expand All @@ -140,11 +139,27 @@ async def get_agent_response(user_input: UserInput) -> ChatResponse:
"tool": tools,
}
elif user_input.list_sources:
response = {"response": llm_response, "context_sources": [ContextSource(context="", source=cs.source) for cs in context_sources], "tool": tools}
response = {
"response": llm_response,
"context_sources": [
ContextSource(context="", source=cs.source) for cs in context_sources
],
"tool": tools,
}
elif user_input.list_context:
response = {"response": llm_response, "context_sources": [ContextSource(context=cs.context, source="") for cs in context_sources], "tool": tools}
response = {
"response": llm_response,
"context_sources": [
ContextSource(context=cs.context, source="") for cs in context_sources
],
"tool": tools,
}
else:
response = {"response": llm_response, "context_sources": [ContextSource(context="", source="")], "tool": tools}
response = {
"response": llm_response,
"context_sources": [ContextSource(context="", source="")],
"tool": tools,
}

return ChatResponse(**response)

Expand Down
14 changes: 9 additions & 5 deletions frontend/streamlit_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def translate_chat_history_to_api(chat_history, max_pairs=4):


def display_sources_context(context_sources: list[dict[str, str]]):

with st.expander("Sources and Context"):
try:
if context_sources:
Expand Down Expand Up @@ -91,7 +90,6 @@ def response_generator(user_input: str) -> tuple[str, str] | tuple[None, None]:
context_sources = data.get("context_sources", [])
st.session_state.metadata[user_input] = {
"context_sources": context_sources,

}
return data.get("response", ""), context_sources
except requests.exceptions.RequestException as e:
Expand Down Expand Up @@ -144,7 +142,9 @@ def main() -> None:
user_message = st.session_state.chat_history[idx - 1]
if user_message["role"] == "user":
user_input = user_message["content"]
context_sources = st.session_state.metadata.get(user_input, {}).get("context_sources", [])
context_sources = st.session_state.metadata.get(user_input, {}).get(
"context_sources", []
)
display_sources_context(context_sources)

user_input = st.chat_input("Enter your queries ...")
Expand Down Expand Up @@ -223,8 +223,12 @@ def update_state() -> None:
# Handle thumbs up and thumbs down reactions
if thumbs_up or thumbs_down:
try:
selected_question = st.session_state.chat_history[-2]["content"] # Last user question
gen_ans = st.session_state.chat_history[-1]["content"] # Last AI response
selected_question = st.session_state.chat_history[-2][
"content"
] # Last user question
gen_ans = st.session_state.chat_history[-1][
"content"
] # Last AI response
metadata = st.session_state.metadata.get(selected_question, {})
context_sources = metadata.get("context_sources", [])

Expand Down
12 changes: 6 additions & 6 deletions frontend/utils/feedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@


import sys

sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from common.mongoClient import submit_feedback
from common.mongoClient import submit_feedback # type: ignore

load_dotenv()

Expand Down Expand Up @@ -152,8 +153,8 @@ def submit_feedback_to_google_sheet(
# Format sources and context as combined pairs
formatted_pairs = []
for cs in context_sources:
source = cs.get('source', '')
context = cs.get('context', '')
source = cs.get("source", "")
context = cs.get("context", "")
if source or context:
formatted_pairs.append(f"Source: {source}\nContext: {context}")

Expand All @@ -169,7 +170,6 @@ def submit_feedback_to_google_sheet(
reaction or "",
]

#
headers = sheet.row_values(1)
required_headers = [
"Question",
Expand All @@ -193,15 +193,15 @@ def submit_feedback_to_google_sheet(

def show_feedback_form(
questions: dict[str, int],
metadata: dict[str, dict[str, str]],
metadata: dict[str, dict[str, list]],
interactions: list[dict[str, str]],
) -> None:
"""
Display feedback form in the sidebar.
Args:
- questions (dict[str, int]): Dictionary of questions and indices.
- metadata (dict[str, dict[str, str]]): Metadata contains sources and context for each question.
- metadata (dict[str, dict[str, list]]): Metadata contains sources and context for each question.
- interactions (list[dict[str, str]]): List of chat interactions from st.session_state.chat_history
Returns:
Expand Down
31 changes: 19 additions & 12 deletions frontend/utils/mock_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def list_all_chains() -> Response:
return jsonify(["/chains/mock"])


@app.route("/chains/mock", methods=["POST"])
@app.route("/graphs/agent-retriever", methods=["POST"])
def chat_app() -> Response:
"""
Endpoint to handle chat requests.
Expand All @@ -25,19 +25,26 @@ def chat_app() -> Response:
"""
data: dict[str, Any] = request.get_json()
user_query: str = data.get("query", "")
list_sources: bool = data.get("list_sources", False)
list_context: bool = data.get("list_context", False)

list_sources: bool = data.get("list_sources", True)
list_context: bool = data.get("list_context", True)

dummy_context_sources = [
{"source": "https://mocksource1.com", "context": "This is Mock Context 1"},
{"source": "https://mocksource2.com", "context": "This is Mock Context 2"},
]
if not list_sources:
# drop the source keys
dummy_context_sources = [
{"source": "", "context": cs["context"]} for cs in dummy_context_sources
]
if not list_context:
# drop the context keys
dummy_context_sources = [
{"source": cs["source"], "context": ""} for cs in dummy_context_sources
]
response = {
"response": f"This is a mock response to your query: '{user_query}'",
"sources": [
"https://mocksource1.com",
"https://mocksource2.com",
"https://mocksource3.com",
]
if list_sources
else [],
"context": ["This is Mock Context"] if list_context else [],
"context_sources": dummy_context_sources,
}

return jsonify(response)
Expand Down

0 comments on commit 333c34e

Please sign in to comment.