Skip to content

Commit

Permalink
Update evaluator
Browse files (view the repository at this point in the history)
  • Loading branch information
mdciri committed Mar 3, 2025
1 parent bbe0089 commit 101990e
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 1 deletion.
1 change: 1 addition & 0 deletions apps/chatbot/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ CHB_ENGINE_SIMILARITY_TOPK=...
CHB_ENGINE_USE_ASYNC=True
CHB_ENGINE_USE_STREAMING=...
CHB_GOOGLE_API_KEY=...
CHB_GOOGLE_PROJECT_ID=...
CHB_LANGFUSE_HOST=http://localhost:3000
CHB_LANGFUSE_PUBLIC_KEY=/nonexistent/ssmpath
CHB_LANGFUSE_SECRET_KEY=/nonexistent/ssmpath
Expand Down
56 changes: 55 additions & 1 deletion apps/chatbot/src/modules/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
from dotenv import load_dotenv
from typing import List

import google.auth

from llama_index.core.async_utils import asyncio_run

from langchain_core.outputs import LLMResult, ChatGeneration
from langchain_aws import ChatBedrockConverse
from langchain_aws import BedrockEmbeddings
from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings

from ragas import SingleTurnSample
from ragas.llms import LangchainLLMWrapper, BaseRagasLLM
Expand All @@ -27,6 +31,7 @@
assert PROVIDER in ["aws", "google"]

GOOGLE_API_KEY = get_ssm_parameter(name=os.getenv("CHB_GOOGLE_API_KEY"))
GOOGLE_PROJECT_ID = os.getenv("CHB_GOOGLE_PROJECT_ID")
AWS_ACCESS_KEY_ID = os.getenv("CHB_AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.getenv("CHB_AWS_SECRET_ACCESS_KEY")
AWS_BEDROCK_LLM_REGION = os.getenv("CHB_AWS_BEDROCK_LLM_REGION")
Expand Down Expand Up @@ -54,7 +59,56 @@
)
logger.info("Loaded evaluation model successfully!")
else:
raise NotImplementedError()

def gemini_is_finished_parser(response: LLMResult) -> bool:
    """Report whether every generation in *response* terminated normally.

    A generation counts as finished when its finish reason is ``"STOP"`` or
    ``"MAX_TOKENS"``. The reason is read from ``generation_info`` first and,
    failing that, from the chat message's ``response_metadata`` (the
    ``finish_reason`` key, then ``stop_reason``). When no reason can be found
    anywhere, the response is assumed to be finished.
    """
    finished_flags = []

    for flat in response.flatten():
        generation = flat.generations[0][0]

        # Primary source: the per-generation info dict.
        info = generation.generation_info
        if info is not None:
            reason = info.get("finish_reason")
            if reason is not None:
                finished_flags.append(reason in ["STOP", "MAX_TOKENS"])
                continue

        # Fallback: metadata attached to the chat message, preferring
        # "finish_reason" over "stop_reason" (truthy values only).
        if isinstance(generation, ChatGeneration) and generation.message is not None:
            meta = generation.message.response_metadata
            reason = meta.get("finish_reason") or meta.get("stop_reason")
            if reason:
                finished_flags.append(reason in ["STOP", "MAX_TOKENS"])

        # Nothing appended yet for the whole response: assume it completed.
        if not finished_flags:
            finished_flags.append(True)

    return all(finished_flags)

# Obtain Google Application Default Credentials; quota_project_id selects the
# GCP project charged for Vertex AI usage (read from CHB_GOOGLE_PROJECT_ID).
creds, _ = google.auth.default(quota_project_id=GOOGLE_PROJECT_ID)

# Ragas judge LLM backed by Vertex AI Gemini. The custom is_finished_parser
# (defined above) derives completion from Gemini-style finish reasons.
# NOTE(review): MODEL_ID / MODEL_TEMPERATURE / MODEL_MAXTOKENS are defined
# earlier in this module from environment variables — presumably validated
# there; confirm they are non-None on the "google" provider path.
LLM = LangchainLLMWrapper(
    ChatVertexAI(
        credentials=creds,
        model_name=MODEL_ID,
        temperature=float(MODEL_TEMPERATURE),
        max_tokens=int(MODEL_MAXTOKENS)
    ),
    is_finished_parser=gemini_is_finished_parser
)
# Embedding model wrapper used by ragas metrics that need vector similarity;
# reuses the same credentials as the judge LLM.
EMBEDDER = LangchainEmbeddingsWrapper(
    VertexAIEmbeddings(
        credentials=creds,
        model_name=EMBED_MODEL_ID
    )
)


class Evaluator():
Expand Down

0 comments on commit 101990e

Please sign in to comment.