From 7ff22970732429f968ea01de9bbca9c90ea74cd3 Mon Sep 17 00:00:00 2001 From: Akrem Abayed Date: Thu, 4 Jan 2024 07:34:09 +0100 Subject: [PATCH] comparison logic --- .../routers/evaluation_router.py | 30 +++++++++ .../services/evaluation_service.py | 62 ++++++++++++++++++- 2 files changed, 91 insertions(+), 1 deletion(-) diff --git a/agenta-backend/agenta_backend/routers/evaluation_router.py b/agenta-backend/agenta_backend/routers/evaluation_router.py index c28c0aad2..57d7c228c 100644 --- a/agenta-backend/agenta_backend/routers/evaluation_router.py +++ b/agenta-backend/agenta_backend/routers/evaluation_router.py @@ -242,3 +242,33 @@ async def webhook_example_fake(): random_generator = secrets.SystemRandom() random_number = random_generator.random() return {"score": random_number} + + +@router.get( + "/evaluation_scenarios/comparison-results/", + response_model=List, +) +async def fetch_evaluation_scenarios( + evaluations_ids: str, + testset_id: str, + app_variant_id: str, + request: Request, +): + """Fetches scenarios of several evaluations, grouped by testset input value, for comparison. + + Arguments: + evaluations_ids (str): Comma-separated evaluation IDs (split on ','); testset_id (str) and app_variant_id (str) are passed through to evaluation_service.compare_evaluations_scenarios. + + Raises: + HTTPException: Presumably raised downstream if an evaluation is not found or access is denied (not visible in this patch; confirm). + + Returns: + List: Whatever compare_evaluations_scenarios returns; in this patch, dicts with keys 'input_name', 'input_value' and 'scenarios'. 
+ """ + evaluations_ids_list = evaluations_ids.split(',') + user_org_data: dict = await get_user_and_org_id(request.state.user_id) + eval_scenarios = await evaluation_service.compare_evaluations_scenarios( + evaluations_ids_list, testset_id, app_variant_id, **user_org_data + ) + + return eval_scenarios \ No newline at end of file diff --git a/agenta-backend/agenta_backend/services/evaluation_service.py b/agenta-backend/agenta_backend/services/evaluation_service.py index 91e69f42e..2055d7f23 100644 --- a/agenta-backend/agenta_backend/services/evaluation_service.py +++ b/agenta-backend/agenta_backend/services/evaluation_service.py @@ -25,7 +25,7 @@ ) from agenta_backend.models import converters from agenta_backend.services import db_manager -from agenta_backend.services.db_manager import query, get_user +from agenta_backend.services.db_manager import fetch_app_variant_by_id, query, get_user from agenta_backend.utils.common import engine, check_access_to_app from agenta_backend.services.security.sandbox import execute_code_safely from agenta_backend.models.db_models import ( @@ -1042,3 +1042,63 @@ async def retrieve_evaluation_results( detail=f"You do not have access to this app: {str(evaluation.app.id)}", ) return await converters.aggregated_result_to_pydantic(evaluation.aggregated_results) + + + +async def compare_evaluations_scenarios(evaluations_ids: List[str], testset_id: str, app_variant_id: str, **user_org_data: dict): + all_scenarios = [] + grouped_scenarios = {} + for evaluation_id in evaluations_ids: + eval_scenarios = await fetch_evaluation_scenarios_for_evaluation( + evaluation_id, **user_org_data + ) + all_scenarios.append(eval_scenarios) + + app_variant_db = await fetch_app_variant_by_id(app_variant_id) + testset = await db_manager.fetch_testset_by_id(testset_id=testset_id) + + inputs = app_variant_db.parameters.get("inputs", []) + # Example shape of inputs: [{'name': 'country'}] + formatted_inputs = extract_inputs_values_from_tesetset(inputs, testset.csvdata) + # 
formatted_inputs: [{'input_name': 'country', 'input_values': ['Nauru', 'Tuvalu'...]}] + + print(formatted_inputs) + groupped_scenarios_by_inputs=find_scenarios_by_input(formatted_inputs, all_scenarios) + print(groupped_scenarios_by_inputs) + return groupped_scenarios_by_inputs + + +def extract_inputs_values_from_tesetset(inputs, testset): + extracted_values = [] + + for input_item in inputs: + key_name = input_item['name'] + values = [entry[key_name] for entry in testset if key_name in entry] + + # Pair each input name with the values found under that key in the testset rows + input_dict = {'input_name': key_name, 'input_values': values} + extracted_values.append(input_dict) + + return extracted_values + + +def find_scenarios_by_input(formatted_inputs, all_scenarios): + results = [] + flattened_scenarios = [scenario for sublist in all_scenarios for scenario in sublist] + + for formatted_input in formatted_inputs: + input_name = formatted_input['input_name'] + for input_value in formatted_input['input_values']: + matching_scenarios = [ + scenario for scenario in flattened_scenarios + if any(input_item.name == input_name and input_item.value == input_value + for input_item in scenario.inputs) + ] + + results.append({ + 'input_name': input_name, + 'input_value': input_value, + 'scenarios': matching_scenarios + }) + + return results \ No newline at end of file