Skip to content

Commit

Permalink
feat: update context to use lists (#305)
Browse files Browse the repository at this point in the history
  • Loading branch information
oyangz authored Jul 12, 2024
1 parent 688a81b commit e65c10e
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 68 deletions.
3 changes: 2 additions & 1 deletion src/fmeval/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class DatasetColumns(Enum):
MODEL_LOG_PROBABILITY = Column(name="model_log_probability")
TARGET_OUTPUT = Column(name="target_output", should_cast=True)
CATEGORY = Column(name="category", should_cast=True)
TARGET_CONTEXT = Column(name="target_context", should_cast=True)
CONTEXT = Column(name="context", should_cast=True)
SENT_MORE_INPUT = Column(name="sent_more_input", should_cast=True)
SENT_LESS_INPUT = Column(name="sent_less_input", should_cast=True)
SENT_MORE_PROMPT = Column(name="sent_more_prompt")
Expand All @@ -79,6 +79,7 @@ class DatasetColumns(Enum):


DATASET_COLUMNS = OrderedDict((col.value.name, col) for col in DatasetColumns)
COLUMNS_WITH_LISTS = [DatasetColumns.CONTEXT.value.name]

# This suffix must be included at the end of all
# DataConfig attribute names where the attribute
Expand Down
4 changes: 2 additions & 2 deletions src/fmeval/data_loaders/data_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class DataConfig:
input log probability (used by the Prompt Stereotyping evaluation algorithm)
:param sent_less_log_prob_location: the location for the "sent less"
input log probability (used by the Prompt Stereotyping evaluation algorithm).
:param target_context_location: the location of the target context.
:param context_location: the location of the context for RAG evaluations.
"""

dataset_name: str
Expand All @@ -51,7 +51,7 @@ class DataConfig:
sent_less_input_location: Optional[str] = None
sent_more_log_prob_location: Optional[str] = None
sent_less_log_prob_location: Optional[str] = None
target_context_location: Optional[str] = None
context_location: Optional[str] = None

def __post_init__(self):
require(
Expand Down
67 changes: 45 additions & 22 deletions src/fmeval/data_loaders/json_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from fmeval.constants import (
DatasetColumns,
DATASET_COLUMNS,
COLUMNS_WITH_LISTS,
DATA_CONFIG_LOCATION_SUFFIX,
MIME_TYPE_JSON,
MIME_TYPE_JSONLINES,
Expand Down Expand Up @@ -153,12 +154,14 @@ def _parse_column(args: ColumnParseArguments) -> Optional[Union[Any, List[Any]]]
return result

@staticmethod
def _validate_jmespath_result(result: Union[Any, List[Any]], args: ColumnParseArguments) -> None:
def _validate_jmespath_result(result: Union[Any, List[Any], List[List[Any]]], args: ColumnParseArguments) -> None:
"""Validates that the JMESPath result is as expected.
If `args.dataset_mime_type` is MIME_TYPE_JSON, then `result` is expected
to be a 1D array (list). If MIME_TYPE_JSON_LINES, then `result` is expected
to be a single scalar value.
For dataset columns in COLUMNS_WITH_LISTS, if `args.dataset_mime_type` is MIME_TYPE_JSON, then `result` is
expected to be a 2D array. If MIME_TYPE_JSONLINES, then `result` is expected to be a 1D array (list).
For all other dataset columns, if `args.dataset_mime_type` is MIME_TYPE_JSON, then `result` is expected to be
a 1D array (list). If MIME_TYPE_JSONLINES, then `result` is expected to be a single scalar value.
:param result: JMESPath query result to be validated.
:param args: See ColumnParseArguments docstring.
Expand All @@ -175,24 +178,38 @@ def _validate_jmespath_result(result: Union[Any, List[Any]], args: ColumnParseAr
f"the {args.column.value.name} column of dataset `{args.dataset_name}`, but found at least "
"one value that is None.",
)
require(
all(not isinstance(x, list) for x in result),
f"Expected a 1D array using JMESPath '{args.jmespath_parser.expression}' on dataset "
f"`{args.dataset_name}`, where each element of the array is a sample's {args.column.value.name}, "
f"but found at least one nested array.",
)
if args.column.value.name in COLUMNS_WITH_LISTS:
require(
all(isinstance(x, list) for x in result),
f"Expected a 2D array using JMESPath '{args.jmespath_parser.expression}' on dataset "
f"`{args.dataset_name}` but found at least one non-list object.",
)
else:
require(
all(not isinstance(x, list) for x in result),
f"Expected a 1D array using JMESPath '{args.jmespath_parser.expression}' on dataset "
f"`{args.dataset_name}`, where each element of the array is a sample's {args.column.value.name}, "
f"but found at least one nested array.",
)
elif args.dataset_mime_type == MIME_TYPE_JSONLINES:
require(
result is not None,
f"Found no values using {args.column.value.name} JMESPath '{args.jmespath_parser.expression}' "
f"on dataset `{args.dataset_name}`.",
)
require(
not isinstance(result, list),
f"Expected to find a single value using {args.column.value.name} JMESPath "
f"'{args.jmespath_parser.expression}' on a dataset line in "
f"dataset `{args.dataset_name}`, but found a list instead.",
)
if args.column.value.name in COLUMNS_WITH_LISTS:
require(
isinstance(result, list),
f"Expected to find a List using JMESPath '{args.jmespath_parser.expression}' on a dataset line in "
f"`{args.dataset_name}`, but found a non-list object instead.",
)
else:
require(
not isinstance(result, list),
f"Expected to find a single value using {args.column.value.name} JMESPath "
f"'{args.jmespath_parser.expression}' on a dataset line in "
f"dataset `{args.dataset_name}`, but found a list instead.",
)
else: # pragma: no cover
raise EvalAlgorithmInternalError(
f"args.dataset_mime_type is {args.dataset_mime_type}, but only JSON " "and JSON Lines are supported."
Expand All @@ -217,26 +234,32 @@ def _validate_parsed_columns_lengths(parsed_columns_dict: Dict[str, List[Any]]):
)

@staticmethod
def _cast_to_string(result: Union[Any, List[Any]], args: ColumnParseArguments) -> Union[str, List[str]]:
def _cast_to_string(
result: Union[Any, List[Any], List[List[Any]]], args: ColumnParseArguments
) -> Union[str, List[str], List[List[str]]]:
"""
Casts the contents of `result` to string(s), raising an error if casting fails.
It is extremely unlikely that the str() operation should fail; this basically
only happens if the object has explicitly overwritten the __str__ method to raise
an exception.
If `args.dataset_mime_type` is MIME_TYPE_JSON, then `result` is expected
to be a 1D array (list) of objects. If MIME_TYPE_JSON_LINES, then `result`
is expected to be a single object.
For dataset columns in COLUMNS_WITH_LISTS, if `args.dataset_mime_type` is MIME_TYPE_JSON, then `result` is
expected to be a 2D array. If MIME_TYPE_JSONLINES, then `result` is expected to be a 1D array (list).
For all other dataset columns, if `args.dataset_mime_type` is MIME_TYPE_JSON, then `result` is expected to be
a 1D array (list). If MIME_TYPE_JSONLINES, then `result` is expected to be a single scalar value.
:param result: JMESPath query result to be cast.
:param args: See ColumnParseArguments docstring.
:returns: `result` cast to a string, a list of strings, or a list of lists of strings.
"""
try:
if args.dataset_mime_type == MIME_TYPE_JSON:
return [str(x) for x in result]
if args.column.value.name in COLUMNS_WITH_LISTS:
return [[str(x) for x in sample] for sample in result]
else:
return [str(x) for x in result]
elif args.dataset_mime_type == MIME_TYPE_JSONLINES:
return str(result)
return [str(x) for x in result] if args.column.value.name in COLUMNS_WITH_LISTS else str(result)
else:
raise EvalAlgorithmInternalError( # pragma: no cover
f"args.dataset_mime_type is {args.dataset_mime_type}, but only JSON and JSON Lines are supported."
Expand Down
74 changes: 56 additions & 18 deletions test/unit/data_loaders/test_json_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
CustomJSONDatasource,
)
from fmeval.data_loaders.util import DataConfig
from typing import Any, Dict, List, NamedTuple, Optional
from typing import Any, Dict, List, NamedTuple, Optional, Union
from fmeval.constants import (
DatasetColumns,
MIME_TYPE_JSON,
Expand Down Expand Up @@ -45,14 +45,14 @@ def create_temp_jsonlines_data_file_from_input_dataset(path: pathlib.Path, input

class TestJsonDataLoader:
class TestCaseReadDataset(NamedTuple):
input_dataset: Dict[str, Any]
input_dataset: Union[Dict[str, Any], List[Dict[str, Any]]]
expected_dataset: List[Dict[str, Any]]
dataset_mime_type: str
model_input_jmespath: Optional[str] = None
model_output_jmespath: Optional[str] = None
target_output_jmespath: Optional[str] = None
category_jmespath: Optional[str] = None
target_context_jmespath: Optional[str] = None
context_jmespath: Optional[str] = None

@pytest.mark.parametrize(
"test_case",
Expand All @@ -72,67 +72,105 @@ class TestCaseReadDataset(NamedTuple):
# containing heterogeneous lists.
TestCaseReadDataset(
input_dataset={
"row_1": ["a", True, False, 0, "context_a"],
"row_2": ["b", False, False, 1, "context_b"],
"row_3": ["c", False, True, 2, "context_c"],
"row_1": ["a", True, False, 0, ["context_a"]],
"row_2": ["b", False, False, 1, ["context_b"]],
"row_3": ["c", False, True, 2, ["context_c"]],
},
expected_dataset=[
{
DatasetColumns.MODEL_INPUT.value.name: "a",
DatasetColumns.MODEL_OUTPUT.value.name: "True",
DatasetColumns.TARGET_OUTPUT.value.name: "False",
DatasetColumns.CATEGORY.value.name: "0",
DatasetColumns.TARGET_CONTEXT.value.name: "context_a",
DatasetColumns.CONTEXT.value.name: ["context_a"],
},
{
DatasetColumns.MODEL_INPUT.value.name: "b",
DatasetColumns.MODEL_OUTPUT.value.name: "False",
DatasetColumns.TARGET_OUTPUT.value.name: "False",
DatasetColumns.CATEGORY.value.name: "1",
DatasetColumns.TARGET_CONTEXT.value.name: "context_b",
DatasetColumns.CONTEXT.value.name: ["context_b"],
},
{
DatasetColumns.MODEL_INPUT.value.name: "c",
DatasetColumns.MODEL_OUTPUT.value.name: "False",
DatasetColumns.TARGET_OUTPUT.value.name: "True",
DatasetColumns.CATEGORY.value.name: "2",
DatasetColumns.TARGET_CONTEXT.value.name: "context_c",
DatasetColumns.CONTEXT.value.name: ["context_c"],
},
],
dataset_mime_type=MIME_TYPE_JSON,
model_input_jmespath="[row_1[0], row_2[0], row_3[0]]",
model_output_jmespath="[row_1[1], row_2[1], row_3[1]]",
target_output_jmespath="[row_1[2], row_2[2], row_3[2]]",
category_jmespath="[row_1[3], row_2[3], row_3[3]]",
target_context_jmespath="[row_1[4], row_2[4], row_3[4]]",
context_jmespath="[row_1[4], row_2[4], row_3[4]]",
),
TestCaseReadDataset(
input_dataset={
"model_input_col": ["a", "b", "c"],
"context": [["a", "b"], ["c", "d"], ["e", "f"]],
},
expected_dataset=[
{DatasetColumns.MODEL_INPUT.value.name: "a", DatasetColumns.CONTEXT.value.name: ["a", "b"]},
{DatasetColumns.MODEL_INPUT.value.name: "b", DatasetColumns.CONTEXT.value.name: ["c", "d"]},
{DatasetColumns.MODEL_INPUT.value.name: "c", DatasetColumns.CONTEXT.value.name: ["e", "f"]},
],
dataset_mime_type=MIME_TYPE_JSON,
model_input_jmespath="model_input_col",
context_jmespath="context",
),
TestCaseReadDataset(
input_dataset=[
{"input": "a", "output": 3.14, "context": "1"},
{"input": "c", "output": 2.718, "context": "2"},
{"input": "e", "output": 1.00, "context": "3"},
{"input": "a", "output": 3.14, "context": ["1"]},
{"input": "c", "output": 2.718, "context": ["2"]},
{"input": "e", "output": 1.00, "context": ["3"]},
],
expected_dataset=[
{
DatasetColumns.MODEL_INPUT.value.name: "a",
DatasetColumns.MODEL_OUTPUT.value.name: "3.14",
DatasetColumns.TARGET_CONTEXT.value.name: "1",
DatasetColumns.CONTEXT.value.name: ["1"],
},
{
DatasetColumns.MODEL_INPUT.value.name: "c",
DatasetColumns.MODEL_OUTPUT.value.name: "2.718",
DatasetColumns.TARGET_CONTEXT.value.name: "2",
DatasetColumns.CONTEXT.value.name: ["2"],
},
{
DatasetColumns.MODEL_INPUT.value.name: "e",
DatasetColumns.MODEL_OUTPUT.value.name: "1.0",
DatasetColumns.TARGET_CONTEXT.value.name: "3",
DatasetColumns.CONTEXT.value.name: ["3"],
},
],
dataset_mime_type=MIME_TYPE_JSONLINES,
model_input_jmespath="input",
model_output_jmespath="output",
target_context_jmespath="context",
context_jmespath="context",
),
TestCaseReadDataset(
input_dataset=[
{"input": "a", "context": ["context 1", "context 2"]},
{"input": "c", "context": ["context 3"]},
{"input": "e", "context": ["context 4"]},
],
expected_dataset=[
{
DatasetColumns.MODEL_INPUT.value.name: "a",
DatasetColumns.CONTEXT.value.name: ["context 1", "context 2"],
},
{
DatasetColumns.MODEL_INPUT.value.name: "c",
DatasetColumns.CONTEXT.value.name: ["context 3"],
},
{
DatasetColumns.MODEL_INPUT.value.name: "e",
DatasetColumns.CONTEXT.value.name: ["context 4"],
},
],
dataset_mime_type=MIME_TYPE_JSONLINES,
model_input_jmespath="input",
context_jmespath="context",
),
],
)
Expand All @@ -157,7 +195,7 @@ def test_load_dataset(self, tmp_path, test_case):
model_output_location=test_case.model_output_jmespath,
target_output_location=test_case.target_output_jmespath,
category_location=test_case.category_jmespath,
target_context_location=test_case.target_context_jmespath,
context_location=test_case.context_jmespath,
)
)
config = JsonDataLoaderConfig(
Expand Down
Loading

0 comments on commit e65c10e

Please sign in to comment.