diff --git a/src/fmeval/constants.py b/src/fmeval/constants.py
index f31422be..debf22ec 100644
--- a/src/fmeval/constants.py
+++ b/src/fmeval/constants.py
@@ -68,7 +68,7 @@ class DatasetColumns(Enum):
     MODEL_LOG_PROBABILITY = Column(name="model_log_probability")
     TARGET_OUTPUT = Column(name="target_output", should_cast=True)
     CATEGORY = Column(name="category", should_cast=True)
-    TARGET_CONTEXT = Column(name="target_context", should_cast=True)
+    CONTEXT = Column(name="context", should_cast=True)
     SENT_MORE_INPUT = Column(name="sent_more_input", should_cast=True)
     SENT_LESS_INPUT = Column(name="sent_less_input", should_cast=True)
     SENT_MORE_PROMPT = Column(name="sent_more_prompt")
@@ -79,6 +79,7 @@ class DatasetColumns(Enum):
 
 DATASET_COLUMNS = OrderedDict((col.value.name, col) for col in DatasetColumns)
 
+COLUMNS_WITH_LISTS = [DatasetColumns.CONTEXT.value.name]
 
 # This suffix must be included at the end of all
 # DataConfig attribute names where the attribute
diff --git a/src/fmeval/data_loaders/data_config.py b/src/fmeval/data_loaders/data_config.py
index 8fe69d6b..313553df 100644
--- a/src/fmeval/data_loaders/data_config.py
+++ b/src/fmeval/data_loaders/data_config.py
@@ -37,7 +37,7 @@ class DataConfig:
         input log probability (used by the Prompt Stereotyping evaluation algorithm)
     :param sent_less_log_prob_location: the location for the "sent less" input
         log probability (used by the Prompt Stereotyping evaluation algorithm).
-    :param target_context_location: the location of the target context.
+    :param context_location: the location of the context for RAG evaluations.
     """
 
     dataset_name: str
@@ -51,7 +51,7 @@ class DataConfig:
     sent_less_input_location: Optional[str] = None
     sent_more_log_prob_location: Optional[str] = None
     sent_less_log_prob_location: Optional[str] = None
-    target_context_location: Optional[str] = None
+    context_location: Optional[str] = None
 
     def __post_init__(self):
         require(
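For illustration, a minimal sketch (not part of the diff) of configuring a RAG dataset against the renamed field. The dataset name, URI, and JMESPath values are hypothetical; `dataset_uri` and `dataset_mime_type` are pre-existing `DataConfig` fields not touched by this change.

```python
from fmeval.constants import MIME_TYPE_JSONLINES
from fmeval.data_loaders.data_config import DataConfig

# Hypothetical JSON Lines dataset where each record looks like:
# {"question": "...", "context": ["passage 1", "passage 2"]}
config = DataConfig(
    dataset_name="my_rag_dataset",           # hypothetical name
    dataset_uri="s3://my-bucket/rag.jsonl",  # hypothetical location
    dataset_mime_type=MIME_TYPE_JSONLINES,
    model_input_location="question",
    context_location="context",              # field renamed from target_context_location
)
```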
diff --git a/src/fmeval/data_loaders/json_parser.py b/src/fmeval/data_loaders/json_parser.py
index b49cc77a..9f42fecb 100644
--- a/src/fmeval/data_loaders/json_parser.py
+++ b/src/fmeval/data_loaders/json_parser.py
@@ -9,6 +9,7 @@
 from fmeval.constants import (
     DatasetColumns,
     DATASET_COLUMNS,
+    COLUMNS_WITH_LISTS,
     DATA_CONFIG_LOCATION_SUFFIX,
     MIME_TYPE_JSON,
     MIME_TYPE_JSONLINES,
@@ -153,12 +154,14 @@ def _parse_column(args: ColumnParseArguments) -> Optional[Union[Any, List[Any]]]
         return result
 
     @staticmethod
-    def _validate_jmespath_result(result: Union[Any, List[Any]], args: ColumnParseArguments) -> None:
+    def _validate_jmespath_result(result: Union[Any, List[Any], List[List[Any]]], args: ColumnParseArguments) -> None:
         """Validates that the JMESPath result is as expected.
 
-        If `args.dataset_mime_type` is MIME_TYPE_JSON, then `result` is expected
-        to be a 1D array (list). If MIME_TYPE_JSON_LINES, then `result` is expected
-        to be a single scalar value.
+        For dataset columns in COLUMNS_WITH_LISTS, if `args.dataset_mime_type` is MIME_TYPE_JSON, then `result` is
+        expected to be a 2D array. If MIME_TYPE_JSONLINES, then `result` is expected to be a 1D array (list).
+
+        For all other dataset columns, if `args.dataset_mime_type` is MIME_TYPE_JSON, then `result` is expected to be
+        a 1D array (list). If MIME_TYPE_JSONLINES, then `result` is expected to be a single scalar value.
 
         :param result: JMESPath query result to be validated.
         :param args: See ColumnParseArguments docstring.
@@ -175,24 +178,38 @@ def _validate_jmespath_result(result: Union[Any, List[Any]], args: ColumnParseAr
                 f"the {args.column.value.name} column of dataset `{args.dataset_name}`, but found at least "
                 "one value that is None.",
             )
-            require(
-                all(not isinstance(x, list) for x in result),
-                f"Expected a 1D array using JMESPath '{args.jmespath_parser.expression}' on dataset "
-                f"`{args.dataset_name}`, where each element of the array is a sample's {args.column.value.name}, "
-                f"but found at least one nested array.",
-            )
+            if args.column.value.name in COLUMNS_WITH_LISTS:
+                require(
+                    all(isinstance(x, list) for x in result),
+                    f"Expected a 2D array using JMESPath '{args.jmespath_parser.expression}' on dataset "
+                    f"`{args.dataset_name}` but found at least one non-list object.",
+                )
+            else:
+                require(
+                    all(not isinstance(x, list) for x in result),
+                    f"Expected a 1D array using JMESPath '{args.jmespath_parser.expression}' on dataset "
+                    f"`{args.dataset_name}`, where each element of the array is a sample's {args.column.value.name}, "
+                    f"but found at least one nested array.",
+                )
         elif args.dataset_mime_type == MIME_TYPE_JSONLINES:
             require(
                 result is not None,
                 f"Found no values using {args.column.value.name} JMESPath '{args.jmespath_parser.expression}' "
                 f"on dataset `{args.dataset_name}`.",
             )
-            require(
-                not isinstance(result, list),
-                f"Expected to find a single value using {args.column.value.name} JMESPath "
-                f"'{args.jmespath_parser.expression}' on a dataset line in "
-                f"dataset `{args.dataset_name}`, but found a list instead.",
-            )
+            if args.column.value.name in COLUMNS_WITH_LISTS:
+                require(
+                    isinstance(result, list),
+                    f"Expected to find a List using JMESPath '{args.jmespath_parser.expression}' on a dataset line in "
+                    f"`{args.dataset_name}`, but found a non-list object instead.",
+                )
+            else:
+                require(
+                    not isinstance(result, list),
+                    f"Expected to find a single value using {args.column.value.name} JMESPath "
+                    f"'{args.jmespath_parser.expression}' on a dataset line in "
+                    f"dataset `{args.dataset_name}`, but found a list instead.",
+                )
         else:  # pragma: no cover
             raise EvalAlgorithmInternalError(
                 f"args.dataset_mime_type is {args.dataset_mime_type}, but only JSON " "and JSON Lines are supported."
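A standalone sketch (plain Python, not fmeval code) of the shape rules these new branches enforce: list-valued columns such as `context` must form a 2D array in a full-JSON dataset, while every other column stays 1D.

```python
from typing import Any, List

def valid_json_shape(result: List[Any], column_holds_lists: bool) -> bool:
    # Mirrors the MIME_TYPE_JSON branch: each element must itself be a list
    # for columns in COLUMNS_WITH_LISTS, and must be a scalar otherwise.
    if column_holds_lists:
        return all(isinstance(x, list) for x in result)
    return all(not isinstance(x, list) for x in result)

assert valid_json_shape([["p1", "p2"], ["p3"]], column_holds_lists=True)  # context: 2D passes
assert not valid_json_shape(["p1", ["p2"]], column_holds_lists=True)      # mixed shapes fail
assert valid_json_shape(["a", "b", "c"], column_holds_lists=False)        # model_input: 1D passes
```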
@@ -217,16 +234,19 @@ def _validate_parsed_columns_lengths(parsed_columns_dict: Dict[str, List[Any]]):
             )
 
     @staticmethod
-    def _cast_to_string(result: Union[Any, List[Any]], args: ColumnParseArguments) -> Union[str, List[str]]:
+    def _cast_to_string(
+        result: Union[Any, List[Any], List[List[Any]]], args: ColumnParseArguments
+    ) -> Union[str, List[str], List[List[str]]]:
         """
         Casts the contents of `result` to string(s), raising an error if casting fails.
         It is extremely unlikely that the str() operation should fail;
         this basically only happens if the object has explicitly
         overwritten the __str__ method to raise an exception.
 
-        If `args.dataset_mime_type` is MIME_TYPE_JSON, then `result` is expected
-        to be a 1D array (list) of objects. If MIME_TYPE_JSON_LINES, then `result`
-        is expected to be a single object.
+        For dataset columns in COLUMNS_WITH_LISTS, if `args.dataset_mime_type` is MIME_TYPE_JSON, then `result` is
+        expected to be a 2D array. If MIME_TYPE_JSONLINES, then `result` is expected to be a 1D array (list).
+        For all other dataset columns, if `args.dataset_mime_type` is MIME_TYPE_JSON, then `result` is expected to be
+        a 1D array (list). If MIME_TYPE_JSONLINES, then `result` is expected to be a single scalar value.
 
         :param result: JMESPath query result to be casted.
         :param args: See ColumnParseArguments docstring.
@@ -234,9 +254,12 @@ def _cast_to_string(result: Union[Any, List[Any]], args: ColumnParseArguments) -
         """
         try:
             if args.dataset_mime_type == MIME_TYPE_JSON:
-                return [str(x) for x in result]
+                if args.column.value.name in COLUMNS_WITH_LISTS:
+                    return [[str(x) for x in sample] for sample in result]
+                else:
+                    return [str(x) for x in result]
             elif args.dataset_mime_type == MIME_TYPE_JSONLINES:
-                return str(result)
+                return [str(x) for x in result] if args.column.value.name in COLUMNS_WITH_LISTS else str(result)
             else:
                 raise EvalAlgorithmInternalError(  # pragma: no cover
                     f"args.dataset_mime_type is {args.dataset_mime_type}, but only JSON and JSON Lines are supported."
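The casting branches above, restated as a small sketch (assumed equivalent, using made-up values):

```python
# MIME_TYPE_JSON: context arrives as one inner list per sample and is cast element-wise.
context_json = [[1, 2], [3]]
assert [[str(x) for x in sample] for sample in context_json] == [["1", "2"], ["3"]]

# MIME_TYPE_JSONLINES: a single record's context is a 1D list.
context_line = [4, 5]
assert [str(x) for x in context_line] == ["4", "5"]
```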
diff --git a/test/unit/data_loaders/test_json_data_loader.py b/test/unit/data_loaders/test_json_data_loader.py
index 7533e326..2928ba5c 100644
--- a/test/unit/data_loaders/test_json_data_loader.py
+++ b/test/unit/data_loaders/test_json_data_loader.py
@@ -11,7 +11,7 @@
     CustomJSONDatasource,
 )
 from fmeval.data_loaders.util import DataConfig
-from typing import Any, Dict, List, NamedTuple, Optional
+from typing import Any, Dict, List, NamedTuple, Optional, Union
 from fmeval.constants import (
     DatasetColumns,
     MIME_TYPE_JSON,
@@ -45,14 +45,14 @@ def create_temp_jsonlines_data_file_from_input_dataset(path: pathlib.Path, input
 class TestJsonDataLoader:
     class TestCaseReadDataset(NamedTuple):
-        input_dataset: Dict[str, Any]
+        input_dataset: Union[Dict[str, Any], List[Dict[str, Any]]]
         expected_dataset: List[Dict[str, Any]]
         dataset_mime_type: str
         model_input_jmespath: Optional[str] = None
         model_output_jmespath: Optional[str] = None
         target_output_jmespath: Optional[str] = None
         category_jmespath: Optional[str] = None
-        target_context_jmespath: Optional[str] = None
+        context_jmespath: Optional[str] = None
 
     @pytest.mark.parametrize(
         "test_case",
         [
@@ -72,9 +72,9 @@ class TestCaseReadDataset(NamedTuple):
             # containing heterogeneous lists.
             TestCaseReadDataset(
                 input_dataset={
-                    "row_1": ["a", True, False, 0, "context_a"],
-                    "row_2": ["b", False, False, 1, "context_b"],
-                    "row_3": ["c", False, True, 2, "context_c"],
+                    "row_1": ["a", True, False, 0, ["context_a"]],
+                    "row_2": ["b", False, False, 1, ["context_b"]],
+                    "row_3": ["c", False, True, 2, ["context_c"]],
                 },
                 expected_dataset=[
                     {
@@ -82,21 +82,21 @@ class TestCaseReadDataset(NamedTuple):
                         DatasetColumns.MODEL_INPUT.value.name: "a",
                         DatasetColumns.MODEL_OUTPUT.value.name: "True",
                         DatasetColumns.TARGET_OUTPUT.value.name: "False",
                         DatasetColumns.CATEGORY.value.name: "0",
-                        DatasetColumns.TARGET_CONTEXT.value.name: "context_a",
+                        DatasetColumns.CONTEXT.value.name: ["context_a"],
                     },
                     {
                         DatasetColumns.MODEL_INPUT.value.name: "b",
                         DatasetColumns.MODEL_OUTPUT.value.name: "False",
                         DatasetColumns.TARGET_OUTPUT.value.name: "False",
                         DatasetColumns.CATEGORY.value.name: "1",
-                        DatasetColumns.TARGET_CONTEXT.value.name: "context_b",
+                        DatasetColumns.CONTEXT.value.name: ["context_b"],
                     },
                     {
                         DatasetColumns.MODEL_INPUT.value.name: "c",
                         DatasetColumns.MODEL_OUTPUT.value.name: "False",
                         DatasetColumns.TARGET_OUTPUT.value.name: "True",
                         DatasetColumns.CATEGORY.value.name: "2",
-                        DatasetColumns.TARGET_CONTEXT.value.name: "context_c",
+                        DatasetColumns.CONTEXT.value.name: ["context_c"],
                     },
                 ],
                 dataset_mime_type=MIME_TYPE_JSON,
@@ -104,35 +104,73 @@ class TestCaseReadDataset(NamedTuple):
                 model_output_jmespath="[row_1[1], row_2[1], row_3[1]]",
                 target_output_jmespath="[row_1[2], row_2[2], row_3[2]]",
                 category_jmespath="[row_1[3], row_2[3], row_3[3]]",
-                target_context_jmespath="[row_1[4], row_2[4], row_3[4]]",
+                context_jmespath="[row_1[4], row_2[4], row_3[4]]",
+            ),
+            TestCaseReadDataset(
+                input_dataset={
+                    "model_input_col": ["a", "b", "c"],
+                    "context": [["a", "b"], ["c", "d"], ["e", "f"]],
+                },
+                expected_dataset=[
+                    {DatasetColumns.MODEL_INPUT.value.name: "a", DatasetColumns.CONTEXT.value.name: ["a", "b"]},
+                    {DatasetColumns.MODEL_INPUT.value.name: "b", DatasetColumns.CONTEXT.value.name: ["c", "d"]},
+                    {DatasetColumns.MODEL_INPUT.value.name: "c", DatasetColumns.CONTEXT.value.name: ["e", "f"]},
+                ],
+                dataset_mime_type=MIME_TYPE_JSON,
+                model_input_jmespath="model_input_col",
+                context_jmespath="context",
             ),
             TestCaseReadDataset(
                 input_dataset=[
-                    {"input": "a", "output": 3.14, "context": "1"},
-                    {"input": "c", "output": 2.718, "context": "2"},
-                    {"input": "e", "output": 1.00, "context": "3"},
+                    {"input": "a", "output": 3.14, "context": ["1"]},
+                    {"input": "c", "output": 2.718, "context": ["2"]},
+                    {"input": "e", "output": 1.00, "context": ["3"]},
                 ],
                 expected_dataset=[
                     {
                         DatasetColumns.MODEL_INPUT.value.name: "a",
                         DatasetColumns.MODEL_OUTPUT.value.name: "3.14",
-                        DatasetColumns.TARGET_CONTEXT.value.name: "1",
+                        DatasetColumns.CONTEXT.value.name: ["1"],
                     },
                     {
                         DatasetColumns.MODEL_INPUT.value.name: "c",
                         DatasetColumns.MODEL_OUTPUT.value.name: "2.718",
-                        DatasetColumns.TARGET_CONTEXT.value.name: "2",
+                        DatasetColumns.CONTEXT.value.name: ["2"],
                     },
                     {
                         DatasetColumns.MODEL_INPUT.value.name: "e",
                         DatasetColumns.MODEL_OUTPUT.value.name: "1.0",
-                        DatasetColumns.TARGET_CONTEXT.value.name: "3",
+                        DatasetColumns.CONTEXT.value.name: ["3"],
                     },
                 ],
                 dataset_mime_type=MIME_TYPE_JSONLINES,
                 model_input_jmespath="input",
                 model_output_jmespath="output",
-                target_context_jmespath="context",
+                context_jmespath="context",
+            ),
+            TestCaseReadDataset(
+                input_dataset=[
+                    {"input": "a", "context": ["context 1", "context 2"]},
+                    {"input": "c", "context": ["context 3"]},
+                    {"input": "e", "context": ["context 4"]},
+                ],
+                expected_dataset=[
+                    {
+                        DatasetColumns.MODEL_INPUT.value.name: "a",
+                        DatasetColumns.CONTEXT.value.name: ["context 1", "context 2"],
+                    },
+                    {
+                        DatasetColumns.MODEL_INPUT.value.name: "c",
+                        DatasetColumns.CONTEXT.value.name: ["context 3"],
+                    },
+                    {
+                        DatasetColumns.MODEL_INPUT.value.name: "e",
+                        DatasetColumns.CONTEXT.value.name: ["context 4"],
+                    },
+                ],
+                dataset_mime_type=MIME_TYPE_JSONLINES,
+                model_input_jmespath="input",
+                context_jmespath="context",
             ),
         ],
     )
@@ -157,7 +195,7 @@ def test_load_dataset(self, tmp_path, test_case):
                 model_output_location=test_case.model_output_jmespath,
                 target_output_location=test_case.target_output_jmespath,
                 category_location=test_case.category_jmespath,
-                target_context_location=test_case.target_context_jmespath,
+                context_location=test_case.context_jmespath,
             )
         )
         config = JsonDataLoaderConfig(
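Note how the fixtures above wrap even a single passage in a list; a bare string in the context field would now fail shape validation. Two hypothetical records for contrast:

```python
ok_record = {"input": "a", "context": ["context_a"]}  # one passage, still a list
bad_record = {"input": "a", "context": "context_a"}   # bare string: rejected as non-list
```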
diff --git a/test/unit/data_loaders/test_json_parser.py b/test/unit/data_loaders/test_json_parser.py
index c6e612bd..5b9ac087 100644
--- a/test/unit/data_loaders/test_json_parser.py
+++ b/test/unit/data_loaders/test_json_parser.py
@@ -90,25 +90,30 @@ def test_init_failure(self):
     class TestCaseParseColumnFailure(NamedTuple):
         result: List[Any]
         error_message: str
+        column: DatasetColumns
 
     @pytest.mark.parametrize(
-        "result, error_message",
+        "result, error_message, column",
         [
             TestCaseParseColumnFailure(
                 result="not a list",
                 error_message="Expected to find a non-empty list of samples",
+                column=DatasetColumns.MODEL_INPUT,
             ),
             TestCaseParseColumnFailure(
                 result=[1, 2, None],
                 error_message="Expected an array of non-null values",
+                column=DatasetColumns.MODEL_INPUT,
+            ),
+            TestCaseParseColumnFailure(
+                result=[1, 2, [3], 4], error_message="Expected a 1D array", column=DatasetColumns.MODEL_INPUT
             ),
             TestCaseParseColumnFailure(
-                result=[1, 2, [3], 4],
-                error_message="Expected a 1D array",
+                result=[[1], 2], error_message="Expected a 2D array", column=DatasetColumns.CONTEXT
             ),
         ],
     )
-    def test_validation_failure_json(self, result, error_message):
+    def test_validation_failure_json(self, result, error_message, column):
         """
         GIVEN a malformed `result` argument (obtained from a JSON dataset)
         WHEN _validate_jmespath_result is called
@@ -118,7 +123,7 @@ def test_validation_failure_json(self, result, error_message):
         with pytest.raises(EvalAlgorithmClientError, match=error_message):
             args = ColumnParseArguments(
                 jmespath_parser=Mock(),
-                column=Mock(),
+                column=column,
                 dataset={},
                 dataset_mime_type=MIME_TYPE_JSON,
                 dataset_name="dataset",
@@ -126,19 +131,20 @@ def test_validation_failure_json(self, result, error_message):
             JsonParser._validate_jmespath_result(result, args)
 
     @pytest.mark.parametrize(
-        "result, error_message",
+        "result, error_message, column",
         [
             TestCaseParseColumnFailure(
-                result=None,
-                error_message="Found no values using",
+                result=None, error_message="Found no values using", column=DatasetColumns.MODEL_INPUT
+            ),
+            TestCaseParseColumnFailure(
+                result=[1, 2, 3], error_message="Expected to find a single value", column=DatasetColumns.MODEL_INPUT
             ),
             TestCaseParseColumnFailure(
-                result=[1, 2, 3],
-                error_message="Expected to find a single value",
+                result="Not a list", error_message="Expected to find a List", column=DatasetColumns.CONTEXT
             ),
         ],
     )
-    def test_validation_failure_jsonlines(self, result, error_message):
+    def test_validation_failure_jsonlines(self, result, error_message, column):
         """
         GIVEN a malformed `result` argument (obtained from a JSON Lines dataset line)
         WHEN _validate_jmespath_result is called
@@ -148,7 +154,7 @@ def test_validation_failure_jsonlines(self, result, error_message):
         with pytest.raises(EvalAlgorithmClientError, match=error_message):
             args = ColumnParseArguments(
                 jmespath_parser=Mock(),
-                column=Mock(),
+                column=column,
                 dataset={},
                 dataset_mime_type=MIME_TYPE_JSONLINES,
                 dataset_name="dataset",
@@ -172,7 +178,7 @@ class TestCaseJsonParseDatasetColumns(NamedTuple):
             model_output_location="model_output.*",
             target_output_location="targets_outer.targets_inner[*].sentiment",
             category_location="category",
-            target_context_location="target_context",
+            context_location="context",
             # this JMESPath query will fail to find any results, and should effectively get ignored
             sent_more_input_location="invalid_jmespath_query",
         ),
@@ -185,7 +191,7 @@ class TestCaseJsonParseDatasetColumns(NamedTuple):
                 ],
             },
             "model_output": {"sample_1": "positive", "sample_2": "negative"},
-            "target_context": ["a", "b"],
+            "context": [["a", "b"], ["c"]],
             "category": ["category_0", "category_1"],
         },
     ),
@@ -201,7 +207,7 @@ class TestCaseJsonParseDatasetColumns(NamedTuple):
             category_location="[*].category_col",
             # this JMESPath query will fail to find any results, and should effectively get ignored
             sent_more_input_location="invalid_jmespath_query",
-            target_context_location="[*].target_context",
+            context_location="[*].context",
         ),
         dataset=[
             {
@@ -209,14 +215,14 @@ class TestCaseJsonParseDatasetColumns(NamedTuple):
                 "model_output_col": "positive",
                 "target_output_col": "negative",
                 "category_col": "category_0",
-                "target_context": "a",
+                "context": ["a", "b"],
             },
             {
                 "model_input_col": "B",
                 "model_output_col": "negative",
                 "target_output_col": "positive",
                 "category_col": "category_1",
-                "target_context": "b",
+                "context": ["c"],
             },
         ],
     ),
@@ -235,7 +241,7 @@ def test_json_parse_dataset_columns_success_json(self, mock_logger, config, data
         expected_model_outputs = ["positive", "negative"]
         expected_target_outputs = ["negative", "positive"]
         expected_categories = ["category_0", "category_1"]
-        expected_target_context = ["a", "b"]
+        expected_context = [["a", "b"], ["c"]]
 
         parser = JsonParser(config)
         cols = parser.parse_dataset_columns(dataset=dataset, dataset_mime_type=MIME_TYPE_JSON, dataset_name="dataset")
@@ -244,7 +250,7 @@ def test_json_parse_dataset_columns_success_json(self, mock_logger, config, data
         assert cols[DatasetColumns.MODEL_OUTPUT.value.name] == expected_model_outputs
         assert cols[DatasetColumns.TARGET_OUTPUT.value.name] == expected_target_outputs
         assert cols[DatasetColumns.CATEGORY.value.name] == expected_categories
-        assert cols[DatasetColumns.TARGET_CONTEXT.value.name] == expected_target_context
+        assert cols[DatasetColumns.CONTEXT.value.name] == expected_context
 
         # ensure that ColumnNames.SENT_MORE_INPUT_COLUMN.value.name does not show up in `cols`
         assert set(cols.keys()) == {
@@ -252,7 +258,7 @@ def test_json_parse_dataset_columns_success_json(self, mock_logger, config, data
             DatasetColumns.MODEL_OUTPUT.value.name,
             DatasetColumns.TARGET_OUTPUT.value.name,
             DatasetColumns.CATEGORY.value.name,
-            DatasetColumns.TARGET_CONTEXT.value.name,
+            DatasetColumns.CONTEXT.value.name,
         }
 
         # ensure that logger generated a warning when search_jmespath
@@ -279,7 +285,7 @@ def test_parse_dataset_columns_success_jsonlines(self, mock_logger):
             model_output_location="output",
             target_output_location="target",
             category_location="category",
-            target_context_location="target_context",
+            context_location="context",
             # this JMESPath query will fail to find any results, and should effectively get ignored
             sent_more_input_location="invalid_jmespath_query",
         )
@@ -288,14 +294,14 @@ def test_parse_dataset_columns_success_jsonlines(self, mock_logger):
         expected_model_output = "positive"
         expected_target_output = "negative"
         expected_category = "Red"
-        expected_target_context = "context"
+        expected_context = ["context"]
 
         dataset_line = {
             "input": "A",
             "output": "positive",
             "target": "negative",
             "category": "Red",
-            "target_context": "context",
+            "context": ["context"],
         }
         cols = parser.parse_dataset_columns(
             dataset=dataset_line, dataset_mime_type=MIME_TYPE_JSONLINES, dataset_name="dataset_line"
@@ -304,7 +310,7 @@ def test_parse_dataset_columns_success_jsonlines(self, mock_logger):
         assert cols[DatasetColumns.MODEL_OUTPUT.value.name] == expected_model_output
         assert cols[DatasetColumns.TARGET_OUTPUT.value.name] == expected_target_output
         assert cols[DatasetColumns.CATEGORY.value.name] == expected_category
-        assert cols[DatasetColumns.TARGET_CONTEXT.value.name] == expected_target_context
+        assert cols[DatasetColumns.CONTEXT.value.name] == expected_context
 
         # ensure that ColumnNames.SENT_MORE_INPUT_COLUMN.value.name does not show up in `cols`
         assert set(cols.keys()) == {
@@ -312,7 +318,7 @@ def test_parse_dataset_columns_success_jsonlines(self, mock_logger):
             DatasetColumns.MODEL_OUTPUT.value.name,
             DatasetColumns.TARGET_OUTPUT.value.name,
             DatasetColumns.CATEGORY.value.name,
-            DatasetColumns.TARGET_CONTEXT.value.name,
+            DatasetColumns.CONTEXT.value.name,
         }
 
         # ensure that logger generated a warning when search_jmespath
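For reference, a sketch of how a `context_location` query produces the expected 2D array from a full-JSON dataset, using the standard jmespath package (the records here are hypothetical):

```python
import jmespath

dataset = [
    {"input": "a", "context": ["p1", "p2"]},
    {"input": "b", "context": ["p3"]},
]
# "[*].context" projects each record's context list, yielding a 2D array.
assert jmespath.search("[*].context", dataset) == [["p1", "p2"], ["p3"]]
```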