Skip to content

Commit

Permalink
feat: add validate_prompt_template util (#300)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaoyi-cheng authored Jul 1, 2024
1 parent 36a5de5 commit 00033e9
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 1 deletion.
17 changes: 17 additions & 0 deletions src/fmeval/eval_algorithms/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,23 @@ def validate_dataset(dataset: Dataset, column_names: List[str]):
)


def validate_prompt_template(prompt_template: str, placeholders: List[str]):
    """
    Verify that `prompt_template` contains every required placeholder keyword.

    Each entry of `placeholders` must appear in `prompt_template` with a `$`
    sign prepended. Example: for the template
    '{"Question":$question, "Answer": $answer}' the placeholders are
    ["question", "answer"].

    :param prompt_template: A template used to compose prompts.
    :param placeholders: Placeholder keyword list (without the `$` prefix).
    :raises: EvalAlgorithmClientError for an invalid prompt_template
    """
    for keyword in placeholders:
        token = f"${keyword}"
        # util.require raises EvalAlgorithmClientError when the condition is False
        util.require(
            token in prompt_template,
            f"Unable to find placeholder {token} in prompt_template.",
        )


def aggregate_evaluation_scores(
dataset: Dataset, score_column_names: List[str], agg_method: str
) -> Tuple[List[EvalScore], Optional[List[CategoryScore]]]:
Expand Down
27 changes: 26 additions & 1 deletion test/unit/eval_algorithms/test_util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import multiprocessing as mp
import os
import re

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -51,8 +52,9 @@
get_dataset_configs,
aggregate_evaluation_scores,
create_model_invocation_pipeline,
validate_prompt_template,
)
from fmeval.exceptions import EvalAlgorithmInternalError
from fmeval.exceptions import EvalAlgorithmInternalError, EvalAlgorithmClientError
from fmeval.transforms.common import GeneratePrompt, GetModelOutputs
from fmeval.util import camel_to_snake, get_num_actors

Expand Down Expand Up @@ -287,6 +289,29 @@ def test_generate_prompt_column_for_dataset(test_case):
assert sorted(returned_dataset.take(test_case.num_rows), key=lambda x: x["id"]) == test_case.expected_dataset


def test_validate_prompt_template_success():
    """
    GIVEN a prompt_template containing all of the required placeholder keywords
    WHEN validate_prompt_template is called
    THEN no exception is raised
    """
    template = '{"Question":$question, "Answer": $answer}'
    required = ["question", "answer"]
    validate_prompt_template(prompt_template=template, placeholders=required)


def test_validate_prompt_template_raise_error():
    """
    GIVEN a placeholder keyword that is absent from the prompt_template
    WHEN validate_prompt_template is called
    THEN EvalAlgorithmClientError is raised with the expected error message
    """
    expected_message = re.escape("Unable to find placeholder")
    with pytest.raises(EvalAlgorithmClientError, match=expected_message):
        validate_prompt_template(
            prompt_template='{"Question":$question, "Answer": $answer}',
            placeholders=["model_input"],
        )


@patch("ray.data.ActorPoolStrategy")
@patch("ray.data.Dataset")
def test_num_actors_in_generate_prompt_column_for_dataset(dataset, actor_pool_strategy):
Expand Down

0 comments on commit 00033e9

Please sign in to comment.