Merge pull request #2105 from Agenta-AI/mmabrouk/fix/AGE-1016-json-diff-evaluator-fix

Fixes to JSON evaluators
mmabrouk authored Oct 8, 2024
2 parents cc1e07d + cf33c49 commit 385cb90
Showing 2 changed files with 51 additions and 53 deletions.
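In summary: the shared validate_json_output helper is deleted, auto_contains_json inlines its own output handling, json_diff learns to unwrap the v2 app output format (a dict carrying the response under a "data" key), auto_json_diff reports a dedicated error when the expected answer is not valid JSON, and the contains_json tests switch from None to boolean expectations. A minimal sketch of the shared unwrapping idea, using a hypothetical helper name and sample payloads (the real code inlines the same expression):

from typing import Any, Dict, Union


def unwrap_app_output(output: Union[str, Dict[str, Any]]) -> Union[str, Dict[str, Any]]:
    # v2 app outputs wrap the response under a "data" key (hypothetical helper,
    # mirroring the inline expression in the diff below); v1 strings pass through.
    return output.get("data", "") if isinstance(output, dict) else output


assert unwrap_app_output({"data": '{"city": "Baku"}'}) == '{"city": "Baku"}'
assert unwrap_app_output("plain text") == "plain text"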
96 changes: 47 additions & 49 deletions agenta-backend/agenta_backend/services/evaluators_service.py
@@ -53,47 +53,6 @@ def validate_string_output(
     return output
 
 
-def validate_json_output(
-    evaluator_key: str, output: Union[str, Dict[str, Any]]
-) -> Union[str, dict]:
-    """Checks and validate the output to be of type JSON string or dictionary.
-
-    Args:
-        evaluator_key (str): the key of the evaluator
-        output (Union[str, Dict[str, Any]]): the llm response
-
-    Raises:
-        Exception: requires output to be a JSON string
-
-    Returns:
-        str, dict: output
-    """
-
-    output = output.get("data", "") if isinstance(output, dict) else output
-    if isinstance(output, dict):
-        output = json.dumps(output)
-    elif isinstance(output, str):
-        try:
-            json.loads(output)
-        except json.JSONDecodeError:
-            raise Exception(
-                f"Evaluator {evaluator_key} requires the output to be a JSON string or object."
-            )
-
-    if not isinstance(
-        output,
-        (
-            str,
-            dict,
-        ),
-    ):
-        raise Exception(
-            f"Evaluator {evaluator_key} requires the output to be either a JSON string or object, but received {type(output).__name__} instead."
-        )
-
-    return output
-
-
 async def map(
     mapping_input: EvaluatorMappingInputInterface,
 ) -> EvaluatorMappingOutputInterface:
@@ -684,7 +643,16 @@ async def auto_contains_json(
     lm_providers_keys: Dict[str, Any],  # pylint: disable=unused-argument
 ) -> Result:
     try:
-        output = validate_json_output("contains_json", output)
+        # parsing llm app output format if v2
+        output = output.get("data", "") if isinstance(output, dict) else output
+        if isinstance(output, dict):
+            output = json.dumps(
+                output
+            )  # contains_json expects inputs.prediction to be a string
+        elif not isinstance(output, (str, dict)):
+            raise Exception(
+                f"Evaluator contains_json requires the app output to be either a JSON string or object, but received {type(output).__name__} instead."
+            )
         response = await contains_json(
             input=EvaluatorInputInterface(**{"inputs": {"prediction": output}})
         )
@@ -707,7 +675,7 @@ async def contains_json(input: EvaluatorInputInterface) -> EvaluatorOutputInterface:
         potential_json = str(input.inputs["prediction"])[start_index:end_index]
         json.loads(potential_json)
         contains_json = True
-    except (ValueError, json.JSONDecodeError):
+    except (ValueError, json.JSONDecodeError) as e:
        contains_json = False
 
     return {"outputs": {"success": contains_json}}
@@ -825,8 +793,9 @@ async def auto_json_diff(
     lm_providers_keys: Dict[str, Any],  # pylint: disable=unused-argument
 ) -> Result:
     try:
-        output = validate_json_output("json_diff", output)
+        # 2. extract ground truth from data point
         correct_answer = get_correct_answer(data_point, settings_values)
+
         response = await json_diff(
             input=EvaluatorInputInterface(
                 **{
@@ -836,7 +805,16 @@
             )
         )
         return Result(type="number", value=response["outputs"]["score"])
-    except (ValueError, json.JSONDecodeError, Exception):
+    except json.JSONDecodeError:
+        return Result(
+            type="error",
+            value=None,
+            error=Error(
+                message="Expected answer is not a valid JSON",
+                stacktrace=traceback.format_exc(),
+            ),
+        )
+    except (ValueError, Exception):
         return Result(
             type="error",
             value=None,
@@ -848,12 +826,32 @@
 
 
 async def json_diff(input: EvaluatorInputInterface) -> EvaluatorOutputInterface:
-    average_score = compare_jsons(
-        ground_truth=input.inputs["ground_truth"],
-        app_output=json.loads(input.inputs["prediction"]),
+    ground_truth = input.inputs["ground_truth"]
+    if isinstance(ground_truth, str):
+        ground_truth = json.loads(ground_truth)  # if this fails we will return an error
+
+    # 1. extract llm app output if app output format is v2+
+    app_output = input.inputs["prediction"]
+    assert isinstance(
+        app_output, (str, dict)
+    ), "App output is expected to be a string or a JSON object"
+    app_output = (
+        app_output.get("data", "") if isinstance(app_output, dict) else app_output
+    )
+    if isinstance(app_output, str):
+        try:
+            app_output = json.loads(app_output)
+        except json.JSONDecodeError:
+            app_output = (
+                {}
+            )  # we will return 0 score for json diff in case we cannot parse the output as json
+
+    score = compare_jsons(
+        ground_truth=ground_truth,
+        app_output=app_output,
         settings_values=input.settings,
     )
-    return {"outputs": {"score": average_score}}
+    return {"outputs": {"score": score}}
 
 
 async def measure_rag_consistency(
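The rewritten json_diff is deliberately asymmetric: the ground truth must parse (a JSONDecodeError propagates to auto_json_diff, which reports "Expected answer is not a valid JSON"), while an unparseable prediction silently degrades to an empty dict and scores 0. A sketch of that asymmetry, with a hypothetical stand-in for compare_jsons:

import json
from typing import Any, Dict, Union


def compare_jsons_stub(ground_truth: dict, app_output: dict) -> float:
    # Hypothetical stand-in for agenta's compare_jsons: 1.0 on exact match, else 0.0.
    return 1.0 if ground_truth == app_output else 0.0


def diff_score(
    ground_truth: Union[str, dict],
    prediction: Union[str, Dict[str, Any]],
) -> float:
    if isinstance(ground_truth, str):
        ground_truth = json.loads(ground_truth)  # strict: caller turns this into an error result
    prediction = prediction.get("data", "") if isinstance(prediction, dict) else prediction
    if isinstance(prediction, str):
        try:
            prediction = json.loads(prediction)
        except json.JSONDecodeError:
            prediction = {}  # lenient: an unparseable app output just scores 0
    return compare_jsons_stub(ground_truth, prediction)


assert diff_score('{"a": 1}', {"data": '{"a": 1}'}) == 1.0
assert diff_score('{"a": 1}', "not json") == 0.0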
8 changes: 4 additions & 4 deletions agenta-backend/agenta_backend/tests/unit/test_evaluators.py
@@ -222,13 +222,13 @@ async def test_auto_contains_all(output, substrings, case_sensitive, expected):
 @pytest.mark.parametrize(
     "output, expected",
     [
-        ('Some random text {"key": "value"} more text', None),
-        ("No JSON here!", None),
-        ("{Malformed JSON, nope!}", None),
+        ('Some random text {"key": "value"} more text', True),
+        ("No JSON here!", False),
+        ("{Malformed JSON, nope!}", False),
         ('{"valid": "json", "number": 123}', True),
         ({"data": {"message": "The capital of Azerbaijan is Baku."}}, True),
         ({"data": '{"message": "The capital of Azerbaijan is Baku."}'}, True),
-        ({"data": "The capital of Azerbaijan is Baku."}, None),
+        ({"data": "The capital of Azerbaijan is Baku."}, False),
     ],
 )
 @pytest.mark.asyncio
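The switch from None to True/False reflects that these cases now resolve to a boolean success flag where they previously yielded None. A minimal pytest-style check under the assumption that EvaluatorInputInterface is importable from the service module:

import pytest

from agenta_backend.services.evaluators_service import (
    EvaluatorInputInterface,  # assumed to live in (or be re-exported from) this module
    contains_json,
)


@pytest.mark.asyncio
async def test_contains_json_returns_booleans():
    # Mirrors the updated expectations: embedded JSON -> True, plain text -> False.
    response = await contains_json(
        input=EvaluatorInputInterface(**{"inputs": {"prediction": '{"valid": "json"}'}})
    )
    assert response["outputs"]["success"] is True

    response = await contains_json(
        input=EvaluatorInputInterface(**{"inputs": {"prediction": "No JSON here!"}})
    )
    assert response["outputs"]["success"] is False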
