diff --git a/src/fmeval/eval_algorithms/util.py b/src/fmeval/eval_algorithms/util.py
index f7bde99e..0a0089d2 100644
--- a/src/fmeval/eval_algorithms/util.py
+++ b/src/fmeval/eval_algorithms/util.py
@@ -140,6 +140,23 @@ def validate_dataset(dataset: Dataset, column_names: List[str]):
     )
 
 
+def validate_prompt_template(prompt_template: str, placeholders: List[str]):
+    """
+    Util function to validate that prompt_template contains the required placeholder keywords.
+
+    :param prompt_template: A template used to compose prompts. Ex: '{"Question":$question, "Answer": $answer}'
+    :param placeholders: List of placeholder keywords. Each keyword appears
+        in `prompt_template` with a $ sign prepended. In the above example,
+        the placeholders are ["question", "answer"].
+    :raises: EvalAlgorithmClientError for an invalid prompt_template
+    """
+    for placeholder in placeholders:
+        util.require(
+            f"${placeholder}" in prompt_template,
+            f"Unable to find placeholder ${placeholder} in prompt_template.",
+        )
+
+
 def aggregate_evaluation_scores(
     dataset: Dataset, score_column_names: List[str], agg_method: str
 ) -> Tuple[List[EvalScore], Optional[List[CategoryScore]]]:
diff --git a/test/unit/eval_algorithms/test_util.py b/test/unit/eval_algorithms/test_util.py
index 85bb7b04..df8d8ba4 100644
--- a/test/unit/eval_algorithms/test_util.py
+++ b/test/unit/eval_algorithms/test_util.py
@@ -1,6 +1,7 @@
 import json
 import multiprocessing as mp
 import os
+import re
 
 import numpy as np
 import pandas as pd
@@ -51,8 +52,9 @@
     get_dataset_configs,
     aggregate_evaluation_scores,
     create_model_invocation_pipeline,
+    validate_prompt_template,
 )
-from fmeval.exceptions import EvalAlgorithmInternalError
+from fmeval.exceptions import EvalAlgorithmInternalError, EvalAlgorithmClientError
 from fmeval.transforms.common import GeneratePrompt, GetModelOutputs
 from fmeval.util import camel_to_snake, get_num_actors
 
@@ -287,6 +289,29 @@ def test_generate_prompt_column_for_dataset(test_case):
     assert sorted(returned_dataset.take(test_case.num_rows), key=lambda x: x["id"]) == test_case.expected_dataset
 
 
+def test_validate_prompt_template_success():
+    """
+    GIVEN a prompt_template and the required placeholder keywords
+    WHEN validate_prompt_template is called
+    THEN no exception is raised
+    """
+    validate_prompt_template(
+        prompt_template='{"Question":$question, "Answer": $answer}', placeholders=["question", "answer"]
+    )
+
+
+def test_validate_prompt_template_raise_error():
+    """
+    GIVEN placeholder keywords and a prompt_template that does not contain a required placeholder
+    WHEN validate_prompt_template is called
+    THEN EvalAlgorithmClientError is raised with the correct error message
+    """
+    with pytest.raises(EvalAlgorithmClientError, match=re.escape("Unable to find placeholder")):
+        validate_prompt_template(
+            prompt_template='{"Question":$question, "Answer": $answer}', placeholders=["model_input"]
+        )
+
+
 @patch("ray.data.ActorPoolStrategy")
 @patch("ray.data.Dataset")
 def test_num_actors_in_generate_prompt_column_for_dataset(dataset, actor_pool_strategy):
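
For context (not part of the patch), here is a minimal usage sketch of the new helper. The template string and the `model_input`/`target_output` placeholder names are illustrative; the import paths and the error behavior follow the modules touched in this diff, assuming `util.require` raises `EvalAlgorithmClientError` on a failed check as the docstring states.

```python
from fmeval.eval_algorithms.util import validate_prompt_template
from fmeval.exceptions import EvalAlgorithmClientError

template = "Summarize the following text: $model_input"

# Succeeds silently: every required placeholder is present in the template.
validate_prompt_template(template, placeholders=["model_input"])

# Raises EvalAlgorithmClientError because $target_output is not in the template.
try:
    validate_prompt_template(template, placeholders=["model_input", "target_output"])
except EvalAlgorithmClientError as err:
    print(err)  # "Unable to find placeholder $target_output in prompt_template."
```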