
Commit

Merge pull request #1987 from Agenta-AI/feature/age-573-evaluators-fail-gracefully-when-we-send-a-dict-to-a-str-only

[Enhancement]: Handle non-string outputs gracefully in auto_contains_json evaluator
jp-agenta authored Aug 23, 2024
2 parents b224f10 + 2402f94 commit 532a4bb
Showing 7 changed files with 207 additions and 179 deletions.
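
For context, a minimal sketch of the failure mode this PR addresses, assuming an LLM app that returns a dict payload (the payload below is illustrative):

# Pre-change behaviour: every str-only evaluator unwrapped the payload itself.
output = {"data": {"message": "The capital of Azerbaijan is Baku."}}

if not isinstance(output, str):
    output = output.get("data", "")  # still a dict, not a str

# Downstream string operations then crash, e.g.:
# output.lower()  ->  AttributeError: 'dict' object has no attribute 'lower'
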
4 changes: 2 additions & 2 deletions agenta-backend/agenta_backend/routers/evaluators_router.py
@@ -103,11 +103,11 @@ async def evaluator_run(
)
return result
except Exception as e:
logger.error(f"Error while running evaluator: {str(e)}")
logger.error(f"Error while running {evaluator_key} evaluator: {str(e)}")
raise HTTPException(
status_code=500,
detail={
"message": "Error while running evaluator",
"message": f"Error while running {evaluator_key} evaluator",
"stacktrace": traceback.format_exc(),
},
)
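
With the evaluator key interpolated, a failing run now says which evaluator raised. A sketch of the resulting 500 payload (shape taken from the handler above; the concrete values are illustrative):

detail = {
    "message": "Error while running contains_json evaluator",
    "stacktrace": "Traceback (most recent call last): ...",  # traceback.format_exc()
}
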
130 changes: 83 additions & 47 deletions agenta-backend/agenta_backend/services/evaluators_service.py
@@ -29,6 +29,71 @@
logger.setLevel(logging.DEBUG)


def validate_string_output(
evaluator_key: str, output: Union[str, Dict[str, Any]]
) -> str:
"""Checks and validate the output to be of type string.
Args:
evaluator_key (str): the key of the evaluator
output (Union[str, Dict[str, Any]]): the llm response
Raises:
Exception: requires output to be a string
Returns:
str: output
"""

output = output.get("data", "") if isinstance(output, dict) else output
if not isinstance(output, str):
raise Exception(
f"Evaluator {evaluator_key} requires the output to be a string, but received {type(output).__name__} instead. "
)
return output


def validate_json_output(
evaluator_key: str, output: Union[str, Dict[str, Any]]
) -> Union[str, dict]:
"""Checks and validate the output to be of type JSON string or dictionary.
Args:
evaluator_key (str): the key of the evaluator
output (Union[str, Dict[str, Any]]): the llm response
Raises:
Exception: requires output to be a JSON string
Returns:
str, dict: output
"""

output = output.get("data", "") if isinstance(output, dict) else output
if isinstance(output, dict):
output = json.dumps(output)
elif isinstance(output, str):
try:
json.loads(output)
except json.JSONDecodeError:
raise Exception(
f"Evaluator {evaluator_key} requires the output to be a JSON string or object."
)

if not isinstance(output, (str, dict)):
    raise Exception(
        f"Evaluator {evaluator_key} requires the output to be either a JSON string or object, but received {type(output).__name__} instead."
    )

return output


async def map(
mapping_input: EvaluatorMappingInputInterface,
) -> EvaluatorMappingOutputInterface:
@@ -94,9 +159,9 @@ async def auto_exact_match(
Returns:
Result: A Result object containing the evaluation result.
"""
if not isinstance(output, str):
output = output.get("data", "")

try:
output = validate_string_output("exact_match", output)
correct_answer = get_correct_answer(data_point, settings_values)
inputs = {"ground_truth": correct_answer, "prediction": output}
response = exact_match(input=EvaluatorInputInterface(**{"inputs": inputs}))
@@ -136,9 +201,8 @@ async def auto_regex_test(
settings_values: Dict[str, Any],
lm_providers_keys: Dict[str, Any], # pylint: disable=unused-argument
) -> Result:
if not isinstance(output, str):
output = output.get("data", "")
try:
output = validate_string_output("regex_test", output)
inputs = {"ground_truth": data_point, "prediction": output}
response = await regex_test(
input=EvaluatorInputInterface(
@@ -174,9 +238,8 @@ async def auto_field_match_test(
settings_values: Dict[str, Any],
lm_providers_keys: Dict[str, Any], # pylint: disable=unused-argument
) -> Result:
if not isinstance(output, str):
output = output.get("data", "")
try:
output = validate_string_output("field_match_test", output)
correct_answer = get_correct_answer(data_point, settings_values)
inputs = {"ground_truth": correct_answer, "prediction": output}
response = await field_match_test(
@@ -210,9 +273,8 @@ async def auto_webhook_test(
settings_values: Dict[str, Any],
lm_providers_keys: Dict[str, Any], # pylint: disable=unused-argument
) -> Result:
if not isinstance(output, str):
output = output.get("data", "")
try:
output = validate_string_output("webhook_test", output)
correct_answer = get_correct_answer(data_point, settings_values)
inputs = {"prediction": output, "ground_truth": correct_answer}
response = await webhook_test(
@@ -272,9 +334,8 @@ async def auto_custom_code_run(
settings_values: Dict[str, Any],
lm_providers_keys: Dict[str, Any], # pylint: disable=unused-argument
) -> Result:
if not isinstance(output, str):
output = output.get("data", "")
try:
output = validate_string_output("custom_code_run", output)
correct_answer = get_correct_answer(data_point, settings_values)
inputs = {
"app_config": app_params,
@@ -332,9 +393,9 @@ async def auto_ai_critique(
Returns:
Result: Evaluation result.
"""
if not isinstance(output, str):
output = output.get("data", "")

try:
output = validate_string_output("ai_critique", output)
correct_answer = get_correct_answer(data_point, settings_values)
inputs = {
"prompt_user": app_params.get("prompt_user", ""),
@@ -395,9 +456,8 @@ async def auto_starts_with(
settings_values: Dict[str, Any],
lm_providers_keys: Dict[str, Any], # pylint: disable=unused-argument
) -> Result:
if not isinstance(output, str):
output = output.get("data", "")
try:
output = validate_string_output("starts_with", output)
inputs = {"prediction": output}
response = await starts_with(
input=EvaluatorInputInterface(
@@ -437,9 +497,8 @@ async def auto_ends_with(
settings_values: Dict[str, Any],
lm_providers_keys: Dict[str, Any], # pylint: disable=unused-argument
) -> Result:
if not isinstance(output, str):
output = output.get("data", "")
try:
output = validate_string_output("ends_with", output)
inputs = {"prediction": output}
response = await ends_with(
input=EvaluatorInputInterface(
@@ -480,9 +539,8 @@ async def auto_contains(
settings_values: Dict[str, Any],
lm_providers_keys: Dict[str, Any], # pylint: disable=unused-argument
) -> Result:
if not isinstance(output, str):
output = output.get("data", "")
try:
output = validate_string_output("contains", output)
inputs = {"prediction": output}
response = await contains(
input=EvaluatorInputInterface(
@@ -523,9 +581,8 @@ async def auto_contains_any(
settings_values: Dict[str, Any],
lm_providers_keys: Dict[str, Any], # pylint: disable=unused-argument
) -> Result:
if not isinstance(output, str):
output = output.get("data", "")
try:
output = validate_string_output("contains_any", output)
inputs = {"prediction": output}
response = await contains_any(
input=EvaluatorInputInterface(
@@ -568,9 +625,8 @@ async def auto_contains_all(
settings_values: Dict[str, Any],
lm_providers_keys: Dict[str, Any], # pylint: disable=unused-argument
) -> Result:
if not isinstance(output, str):
output = output.get("data", "")
try:
output = validate_string_output("contains_all", output)
response = await contains_all(
input=EvaluatorInputInterface(
**{"inputs": {"prediction": output}, "settings": settings_values}
@@ -611,9 +667,8 @@ async def auto_contains_json(
settings_values: Dict[str, Any], # pylint: disable=unused-argument
lm_providers_keys: Dict[str, Any], # pylint: disable=unused-argument
) -> Result:
if not isinstance(output, str):
output = output.get("data", "")
try:
output = validate_json_output("contains_json", output)
response = await contains_json(
input=EvaluatorInputInterface(**{"inputs": {"prediction": output}})
)
@@ -754,22 +809,7 @@ async def auto_json_diff(
lm_providers_keys: Dict[str, Any], # pylint: disable=unused-argument
) -> Result:
try:
output = output.get("data", "") if isinstance(output, dict) else output

if isinstance(output, dict):
output = json.dumps(output)
elif isinstance(output, str):
try:
json.loads(output)
except:
raise Exception(
f"Evaluator 'auto_json_diff' requires string outputs to be JSON strings."
)
else:
raise Exception(
f"Evaluator 'auto_json_diff' requires the output to be either a JSON string or a JSON object, but received {type(output).__name__} instead."
)

output = validate_json_output("json_diff", output)
correct_answer = get_correct_answer(data_point, settings_values)
response = await json_diff(
input=EvaluatorInputInterface(
@@ -1039,9 +1079,8 @@ async def auto_levenshtein_distance(
settings_values: Dict[str, Any],
lm_providers_keys: Dict[str, Any], # pylint: disable=unused-argument
) -> Result:
if not isinstance(output, str):
output = output.get("data", "")
try:
output = validate_string_output("levenshtein_distance", output)
correct_answer = get_correct_answer(data_point, settings_values)
response = await levenshtein_distance(
input=EvaluatorInputInterface(
@@ -1082,9 +1121,8 @@ async def auto_similarity_match(
settings_values: Dict[str, Any],
lm_providers_keys: Dict[str, Any],
) -> Result:
if not isinstance(output, str):
output = output.get("data", "")
try:
output = validate_string_output("similarity_match", output)
correct_answer = get_correct_answer(data_point, settings_values)
response = await similarity_match(
input=EvaluatorInputInterface(
@@ -1164,10 +1202,8 @@ async def auto_semantic_similarity(
settings_values: Dict[str, Any],
lm_providers_keys: Dict[str, Any],
) -> Result:
if not isinstance(output, str):
output = output.get("data", "")

try:
output = validate_string_output("semantic_similarity", output)
correct_answer = get_correct_answer(data_point, settings_values)
inputs = {"prediction": output, "ground_truth": correct_answer}
response = await semantic_similarity(
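
A usage sketch of the two new helpers above (inputs are illustrative; results follow from the definitions):

# validate_string_output unwraps {"data": ...} and enforces str:
validate_string_output("exact_match", "Baku")                 # -> "Baku"
validate_string_output("exact_match", {"data": "Baku"})       # -> "Baku"
# validate_string_output("exact_match", {"data": {"x": 1}})   # raises Exception

# validate_json_output also accepts dicts (serialized via json.dumps)
# and JSON strings (parsed once, only to check well-formedness):
validate_json_output("contains_json", {"data": {"k": "v"}})   # -> '{"k": "v"}'
validate_json_output("contains_json", '{"k": "v"}')           # -> '{"k": "v"}'
# validate_json_output("contains_json", "not json")           # raises Exception
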
46 changes: 41 additions & 5 deletions agenta-backend/agenta_backend/tests/unit/test_evaluators.py
@@ -1,8 +1,7 @@
import os
import pytest

from test_traces import simple_rag_trace

from agenta_backend.tests.unit.test_traces import simple_rag_trace
from agenta_backend.services.evaluators_service import (
auto_levenshtein_distance,
auto_starts_with,
@@ -175,10 +174,13 @@ async def test_auto_contains_all(output, substrings, case_sensitive, expected):
@pytest.mark.parametrize(
"output, expected",
[
('Some random text {"key": "value"} more text', True),
("No JSON here!", False),
("{Malformed JSON, nope!}", False),
('Some random text {"key": "value"} more text', None),
("No JSON here!", None),
("{Malformed JSON, nope!}", None),
('{"valid": "json", "number": 123}', True),
({"data": {"message": "The capital of Azerbaijan is Baku."}}, True),
({"data": '{"message": "The capital of Azerbaijan is Baku."}'}, True),
({"data": "The capital of Azerbaijan is Baku."}, None),
],
)
@pytest.mark.asyncio
@@ -232,6 +234,40 @@ async def test_auto_contains_json(output, expected):
0.0,
1.0,
),
(
{
"correct_answer": '{"user": {"name": "John", "details": {"age": 30, "location": "New York"}}}'
},
{
"data": '{"USER": {"NAME": "John", "DETAILS": {"AGE": 30, "LOCATION": "New York"}}}'
},
{
"predict_keys": True,
"compare_schema_only": False,
"case_insensitive_keys": True,
"correct_answer_key": "correct_answer",
},
0.0,
1.0,
),
(
{
"correct_answer": '{"user": {"name": "John", "details": {"age": 30, "location": "New York"}}}'
},
{
"data": {
"output": '{"USER": {"NAME": "John", "DETAILS": {"AGE": 30, "LOCATION": "New York"}}}'
}
},
{
"predict_keys": True,
"compare_schema_only": False,
"case_insensitive_keys": True,
"correct_answer_key": "correct_answer",
},
0.0,
1.0,
),
],
)
@pytest.mark.asyncio
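
Note the expected values for invalid outputs flipped from booleans to None: validation failures now raise inside the evaluator, which catches the exception and returns an error Result rather than a boolean score. A sketch of the presumed assertion (the call shape is an assumption; only the parametrize data is shown in this hunk):

result = await auto_contains_json({}, {}, output, {}, {}, {})  # args illustrative
assert result.value == expected  # None when validation raised
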
21 changes: 13 additions & 8 deletions agenta-cli/agenta/sdk/decorators/llm_entrypoint.py
@@ -216,9 +216,11 @@ async def wrapper(*args, **kwargs) -> Any:
{
"func": func.__name__,
"endpoint": route,
"params": {**config_params, **func_signature.parameters}
if not config
else func_signature.parameters,
"params": (
{**config_params, **func_signature.parameters}
if not config
else func_signature.parameters
),
"config": config,
}
)
@@ -229,9 +231,11 @@ async def wrapper(*args, **kwargs) -> Any:
{
"func": func.__name__,
"endpoint": route,
"params": {**config_params, **func_signature.parameters}
if not config
else func_signature.parameters,
"params": (
{**config_params, **func_signature.parameters}
if not config
else func_signature.parameters
),
"config": config,
}
)
@@ -402,15 +406,16 @@ async def execute_function(

# PATCH : if result is not a dict, make it a dict
if not isinstance(result, dict):
data = result
data = str(result)
else:
# PATCH : if result is a legacy dict, clean it up
if (
"message" in result.keys()
and "cost" in result.keys()
and "usage" in result.keys()
):
data = result["message"]
data = str(result["message"])

# END OF PATCH

if data is None:
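
The str(...) coercion in the entrypoint complements the backend validators: whatever the wrapped function returns, the data field handed onward is a plain string. A sketch under that assumption (return values are hypothetical):

result = {"message": {"role": "assistant"}, "cost": 0.0, "usage": {}}
data = str(result["message"])  # "{'role': 'assistant'}", a str rather than a dict

result = 42                    # non-dict results are stringified too
data = str(result)             # "42"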