def gemini_is_finished_parser(response: "LLMResult") -> bool:
    """Report whether every generation in *response* ended acceptably.

    Meant to be passed as ``is_finished_parser`` to ``LangchainLLMWrapper``
    when the wrapped model is ``ChatVertexAI``.

    For each flattened result, the finish reason is looked up in this order:

    1. ``generation_info["finish_reason"]``
    2. ``message.response_metadata["finish_reason"]`` (chat generations only)
    3. ``message.response_metadata["stop_reason"]`` (chat generations only)

    A generation counts as finished when its reason is ``"STOP"`` or
    ``"MAX_TOKENS"``. A generation that exposes no finish reason at all is
    treated as finished.

    Args:
        response: The LangChain result to inspect. Only ``flatten()`` and the
            first generation of each flattened result are read.

    Returns:
        ``True`` when no generation reports a finish reason outside the
        accepted set (including when there are no generations), ``False``
        otherwise.
    """
    # NOTE(review): "MAX_TOKENS" is accepted, so truncated outputs are not
    # flagged as unfinished -- confirm this is intended.
    finished_reasons = ("STOP", "MAX_TOKENS")

    for result in response.flatten():
        generation = result.generations[0][0]

        # Preferred source: the provider's generation_info. Any non-None
        # value found here is authoritative, even an empty string.
        reason = (generation.generation_info or {}).get("finish_reason")

        # Fallback: metadata attached to the chat message. Empty values are
        # ignored here, so they collapse to "no reason found".
        if (
            reason is None
            and isinstance(generation, ChatGeneration)
            and generation.message is not None
        ):
            metadata = generation.message.response_metadata or {}
            reason = (
                metadata.get("finish_reason")
                or metadata.get("stop_reason")
                or None
            )

        # A missing reason defaults to "finished"; only an explicit reason
        # outside the accepted set marks the response as unfinished.
        if reason is not None and reason not in finished_reasons:
            return False

    return True