Skip to content

Commit

Permalink
Update chatbot to evaluate the RAG pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
mdciri committed Mar 3, 2025
1 parent 9705dbb commit bbe0089
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 10 deletions.
61 changes: 54 additions & 7 deletions apps/chatbot/src/modules/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from src.modules.engine import get_automerging_engine
from src.modules.handlers import EventHandler
from src.modules.presidio import PresidioPII
from src.modules.evaluator import Evaluator
from src.modules.utils import get_ssm_parameter

from dotenv import load_dotenv
Expand All @@ -46,6 +47,14 @@
"Your role is to provide accurate, professional, and helpful responses to users' queries regarding "
"the PagoPA DevPortal documentation available at: https://dev.developer.pagopa.it"
)
CONDENSE_PROMPT = (
"Given the following conversation between a user and an AI assistant and a follow up question from user, "
"rephrase the follow up question to be a standalone question.\n\n"
"Chat History:\n"
"{chat_history}\n"
"Follow Up Input: {query_str}\n"
"Standalone question:"
)
LANGFUSE_PUBLIC_KEY = get_ssm_parameter(os.getenv("CHB_LANGFUSE_PUBLIC_KEY"), os.getenv("LANGFUSE_INIT_PROJECT_PUBLIC_KEY"))
LANGFUSE_SECRET_KEY = get_ssm_parameter(os.getenv("CHB_LANGFUSE_SECRET_KEY"), os.getenv("LANGFUSE_INIT_PROJECT_SECRET_KEY"))
LANGFUSE_HOST = os.getenv("CHB_LANGFUSE_HOST")
Expand All @@ -72,6 +81,7 @@ def __init__(
self.pii = PresidioPII(config=params["config_presidio"])

self.model = get_llm()
self.judge = Evaluator()
self.embed_model = get_embed_model()
self.index = load_automerging_index_redis(
self.model,
Expand Down Expand Up @@ -282,11 +292,17 @@ def add_langfuse_score(
self,
trace_id: str,
name: str,
value: float,
value: float,
session_id: str | None = None,
user_id: str | None = None,
data_type: Literal['NUMERIC', 'BOOLEAN'] | None = None
) -> None:

with self.instrumentor.observe(trace_id=trace_id) as trace:
with self.instrumentor.observe(
trace_id = trace_id,
session_id = session_id,
user_id = user_id
) as trace:
trace_info = self.get_trace(trace_id, as_dict=False)
flag = True
for score in trace_info.scores:
Expand Down Expand Up @@ -374,21 +390,52 @@ def chat_generate(
engine_response = self.engine.chat(query_str, chat_history)
response_str = self._get_response_str(engine_response)

context = ""
retrieved_contexts = []
for node in engine_response.source_nodes:
url = REDIS_KVSTORE.get(
collection=f"hash_table_{INDEX_ID}",
key=node.metadata["filename"]
)
context += f"URL: {url}\n\n{node.text}\n\n------------------\n\n"
retrieved_contexts.append(f"URL: {url}\n\n{node.text}")

except Exception as e:
response_str = "Scusa, non posso elaborare la tua richiesta.\nProva a formulare una nuova domanda."
context = ""
retrieved_contexts = [""]
logger.error(f"Exception: {e}")

trace.update(output=self.mask_pii(response_str), metadata={"context": context})
trace.update(output=self.mask_pii(response_str), metadata={"context": retrieved_contexts})
trace.score(name="user-feedback", value=0, data_type="NUMERIC")
self.instrumentor.flush()

return response_str
return response_str, retrieved_contexts

def evaluate(
    self,
    query_str: str,
    response_str: str,
    retrieved_contexts: List[str],
    trace_id: str,
    session_id: str | None = None,
    user_id: str | None = None,
    messages: List[Dict[str, str]] | None = None,
) -> None:
    """Score a generated answer with the LLM judge and push the scores to Langfuse.

    The follow-up question is first condensed into a standalone question using
    the chat history (CONDENSE_PROMPT), so the judge evaluates the answer
    against a self-contained query rather than a context-dependent one.

    Args:
        query_str: The user's (possibly follow-up) question.
        response_str: The assistant's answer to evaluate.
        retrieved_contexts: The context chunks the answer was generated from.
        trace_id: Langfuse trace to attach the scores to.
        session_id: Optional Langfuse session identifier.
        user_id: Optional Langfuse user identifier.
        messages: Prior conversation turns used to build the chat history.

    Returns:
        None. Scores are reported as a side effect via `add_langfuse_score`.
    """
    chat_history = self._messages_to_chathistory(messages)
    condense_prompt = CONDENSE_PROMPT.format(
        chat_history=chat_history,
        query_str=query_str,
    )
    # The condensing LLM call is async; run it to completion here.
    condense_query_response = asyncio_run(self.model.acomplete(condense_prompt))

    scores = self.judge.evaluate(
        query_str=condense_query_response.text,
        response_str=response_str,
        retrieved_contexts=retrieved_contexts,
    )

    # One Langfuse score entry per metric (e.g. faithfulness, relevancy).
    for key, value in scores.items():
        self.add_langfuse_score(
            trace_id=trace_id,
            session_id=session_id,
            user_id=user_id,
            name=key,
            value=value,
            data_type="NUMERIC",
        )
85 changes: 85 additions & 0 deletions apps/chatbot/src/modules/evaluator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import os
from logging import getLogger
from dotenv import load_dotenv
from typing import List

from llama_index.core.async_utils import asyncio_run

from langchain_aws import ChatBedrockConverse
from langchain_aws import BedrockEmbeddings

from ragas import SingleTurnSample
from ragas.llms import LangchainLLMWrapper, BaseRagasLLM
from ragas.embeddings.base import LangchainEmbeddingsWrapper, BaseRagasEmbeddings
from ragas.metrics import (
Faithfulness,
ResponseRelevancy,
LLMContextPrecisionWithoutReference
)

from src.modules.utils import get_ssm_parameter


load_dotenv()
logger = getLogger(__name__)

# Provider of the evaluation (judge) model. Validated with an explicit raise
# instead of `assert`, which would be silently stripped under `python -O`.
PROVIDER = os.getenv("CHB_PROVIDER", "google")
if PROVIDER not in ("aws", "google"):
    raise ValueError(f"CHB_PROVIDER must be 'aws' or 'google', got: {PROVIDER}")

# Secrets are resolved through SSM; plain configuration comes from env vars.
GOOGLE_API_KEY = get_ssm_parameter(name=os.getenv("CHB_GOOGLE_API_KEY"))
AWS_ACCESS_KEY_ID = os.getenv("CHB_AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.getenv("CHB_AWS_SECRET_ACCESS_KEY")
AWS_BEDROCK_LLM_REGION = os.getenv("CHB_AWS_BEDROCK_LLM_REGION")
AWS_BEDROCK_EMBED_REGION = os.getenv("CHB_AWS_BEDROCK_EMBED_REGION")
AWS_GUARDRAIL_ID = os.getenv("CHB_AWS_GUARDRAIL_ID")
AWS_GUARDRAIL_VERSION = os.getenv("CHB_AWS_GUARDRAIL_VERSION")

MODEL_ID = os.getenv("CHB_MODEL_ID")
MODEL_TEMPERATURE = os.getenv("CHB_MODEL_TEMPERATURE", "0.3")
MODEL_MAXTOKENS = os.getenv("CHB_MODEL_MAXTOKENS", "768")
EMBED_MODEL_ID = os.getenv("CHB_EMBED_MODEL_ID")

if PROVIDER == "aws":
    # Module-level singletons shared by every Evaluator instance.
    LLM = LangchainLLMWrapper(
        ChatBedrockConverse(
            model=MODEL_ID,
            temperature=float(MODEL_TEMPERATURE),
            max_tokens=int(MODEL_MAXTOKENS),
        )
    )
    EMBEDDER = LangchainEmbeddingsWrapper(
        BedrockEmbeddings(
            model_id=EMBED_MODEL_ID
        )
    )
    logger.info("Loaded evaluation model successfully!")
else:
    # NOTE(review): "google" passes the validation above but is not implemented
    # here, and it is also the *default* value of CHB_PROVIDER — importing this
    # module without CHB_PROVIDER=aws fails. Confirm deployments always set it.
    raise NotImplementedError()


class Evaluator:
    """LLM-as-a-judge that scores RAG answers with RAGAS single-turn metrics.

    Computes response relevancy, context precision (without reference answer)
    and faithfulness for a (query, response, retrieved contexts) triple.
    """

    def __init__(self, llm: BaseRagasLLM | None = None, embedder: BaseRagasEmbeddings | None = None):
        """Create the metric objects backed by the given judge model.

        Args:
            llm: Judge LLM wrapper; defaults to the module-level LLM singleton.
            embedder: Embedding wrapper; defaults to the module-level EMBEDDER.
        """
        # Use an explicit `is None` sentinel check: a falsy-but-valid wrapper
        # passed by the caller must not be silently replaced by the default.
        self.llm = LLM if llm is None else llm
        self.embedder = EMBEDDER if embedder is None else embedder

        self.response_relevancy = ResponseRelevancy(llm=self.llm, embeddings=self.embedder)
        self.context_precision = LLMContextPrecisionWithoutReference(llm=self.llm)
        self.faithfulness = Faithfulness(llm=self.llm)

    def evaluate(self, query_str: str, response_str: str, retrieved_contexts: List[str]) -> dict:
        """Score one question/answer pair against its retrieved contexts.

        Args:
            query_str: Standalone user question.
            response_str: Generated answer to judge.
            retrieved_contexts: Context chunks the answer was generated from.

        Returns:
            Mapping from metric name ("response_relevancy", "context_precision",
            "faithfulness") to its score.
        """
        sample = SingleTurnSample(
            user_input=query_str,
            response=response_str,
            retrieved_contexts=retrieved_contexts,
        )

        # Each metric exposes only an async scorer; run them to completion.
        return {
            "response_relevancy": asyncio_run(self.response_relevancy.single_turn_ascore(sample)),
            "context_precision": asyncio_run(self.context_precision.single_turn_ascore(sample)),
            "faithfulness": asyncio_run(self.faithfulness.single_turn_ascore(sample)),
        }
7 changes: 4 additions & 3 deletions apps/chatbot/src/modules/models.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
from logging import getLogger

from llama_index.core.llms.llm import LLM
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.llms.bedrock_converse import BedrockConverse
from llama_index.embeddings.bedrock import BedrockEmbedding

from llama_index.llms.gemini import Gemini
from llama_index.embeddings.gemini import GeminiEmbedding
from google.generativeai.types import HarmCategory, HarmBlockThreshold
Expand Down Expand Up @@ -32,7 +33,7 @@
EMBED_MODEL_ID = os.getenv("CHB_EMBED_MODEL_ID")


def get_llm():
def get_llm() -> LLM:

if PROVIDER == "aws":

Expand Down Expand Up @@ -65,7 +66,7 @@ def get_llm():
return llm


def get_embed_model():
def get_embed_model() -> BaseEmbedding:

if PROVIDER == "aws":
embed_model = BedrockEmbedding(
Expand Down

0 comments on commit bbe0089

Please sign in to comment.