From 634568b123a5fe056d62c211cbf933160a76f561 Mon Sep 17 00:00:00 2001 From: Jonathan Bown Date: Mon, 14 Oct 2024 09:42:55 -0600 Subject: [PATCH] Add ItemMatch + ItemNoMatch Descriptors (#1338) --- docs/book/reference/all-metrics.md | 2 + src/evidently/descriptors/__init__.py | 4 + src/evidently/descriptors/_registry.py | 6 + .../descriptors/text_contains_descriptor.py | 34 ++++++ src/evidently/features/_registry.py | 6 + .../features/text_contains_feature.py | 106 ++++++++++++++++++ tests/features/test_text_contains_feature.py | 82 ++++++++++++++ 7 files changed, 240 insertions(+) diff --git a/docs/book/reference/all-metrics.md b/docs/book/reference/all-metrics.md index d697ba09d1..8bfed4b35a 100644 --- a/docs/book/reference/all-metrics.md +++ b/docs/book/reference/all-metrics.md @@ -272,6 +272,8 @@ Check for regular expression matches. | **DoesNotContain()** Example use:
`DoesNotContain(items=["as a large language model"]` | **Required:**
`items: List[str]`

**Optional:** | | **IncludesWords()** Example use:
`IncludesWords(words_list=['booking', 'hotel', 'flight']` | **Required:**
`words_list: List[str]`

**Optional:** | | **ExcludesWords()** Example use:
`ExcludesWords(words_list=['buy', 'sell', 'bet']`| **Required:**
`words_list: List[str]`

**Optional:** | +| **ItemMatch()** Example use:
`ItemMatch(with_column="expected")`| **Required:**
`with_column: str`

**Optional:** | +| **ItemNoMatch()** Example use:
`ItemMatch(with_column="forbidden")`| **Required:**
`with_column: str`

**Optional:** | ## Descriptors: Text stats diff --git a/src/evidently/descriptors/__init__.py b/src/evidently/descriptors/__init__.py index 2520abdbe3..b06f662f2a 100644 --- a/src/evidently/descriptors/__init__.py +++ b/src/evidently/descriptors/__init__.py @@ -19,6 +19,8 @@ from .sentiment_descriptor import Sentiment from .text_contains_descriptor import Contains from .text_contains_descriptor import DoesNotContain +from .text_contains_descriptor import ItemMatch +from .text_contains_descriptor import ItemNoMatch from .text_length_descriptor import TextLength from .text_part_descriptor import BeginsWith from .text_part_descriptor import EndsWith @@ -47,6 +49,8 @@ "EndsWith", "DoesNotContain", "IncludesWords", + "ItemMatch", + "ItemNoMatch", "ExcludesWords", "TextLength", "TriggerWordsPresence", diff --git a/src/evidently/descriptors/_registry.py b/src/evidently/descriptors/_registry.py index 0f912a86fe..6463f56d41 100644 --- a/src/evidently/descriptors/_registry.py +++ b/src/evidently/descriptors/_registry.py @@ -72,6 +72,12 @@ "evidently.descriptors.text_contains_descriptor.DoesNotContain", "evidently:descriptor:DoesNotContain", ) +register_type_alias( + FeatureDescriptor, "evidently.descriptors.text_contains_descriptor.ItemMatch", "evidently:descriptor:ItemMatch" +) +register_type_alias( + FeatureDescriptor, "evidently.descriptors.text_contains_descriptor.ItemNoMatch", "evidently:descriptor:ItemNoMatch" +) register_type_alias( FeatureDescriptor, "evidently.descriptors.text_length_descriptor.TextLength", "evidently:descriptor:TextLength" ) diff --git a/src/evidently/descriptors/text_contains_descriptor.py b/src/evidently/descriptors/text_contains_descriptor.py index 7e069970d5..4795c4f77f 100644 --- a/src/evidently/descriptors/text_contains_descriptor.py +++ b/src/evidently/descriptors/text_contains_descriptor.py @@ -39,3 +39,37 @@ def feature(self, column_name: str) -> GeneratedFeature: self.mode, self.display_name, ) + + +class ItemMatch(FeatureDescriptor): + class Config: + type_alias = "evidently:descriptor:ItemMatch" + + with_column: str + mode: str = "any" + case_sensitive: bool = True + + def feature(self, column_name: str) -> GeneratedFeature: + return text_contains_feature.ItemMatch( + columns=[column_name, self.with_column], + case_sensitive=self.case_sensitive, + mode=self.mode, + display_name=self.display_name, + ) + + +class ItemNoMatch(FeatureDescriptor): + class Config: + type_alias = "evidently:descriptor:ItemNoMatch" + + with_column: str + mode: str = "any" + case_sensitive: bool = True + + def feature(self, column_name: str) -> GeneratedFeature: + return text_contains_feature.ItemNoMatch( + columns=[column_name, self.with_column], + case_sensitive=self.case_sensitive, + mode=self.mode, + display_name=self.display_name, + ) diff --git a/src/evidently/features/_registry.py b/src/evidently/features/_registry.py index ba2e101f5f..3dacdcddea 100644 --- a/src/evidently/features/_registry.py +++ b/src/evidently/features/_registry.py @@ -52,6 +52,12 @@ register_type_alias( GeneratedFeatures, "evidently.features.text_contains_feature.DoesNotContain", "evidently:feature:DoesNotContain" ) +register_type_alias( + GeneratedFeatures, "evidently.features.text_contains_feature.ItemMatch", "evidently:feature:ItemMatch" +) +register_type_alias( + GeneratedFeatures, "evidently.features.text_contains_feature.ItemNoMatch", "evidently:feature:ItemNoMatch" +) register_type_alias( GeneratedFeatures, "evidently.features.text_length_feature.TextLength", "evidently:feature:TextLength" ) diff --git a/src/evidently/features/text_contains_feature.py b/src/evidently/features/text_contains_feature.py index 31bd3a0975..6b909b95c0 100644 --- a/src/evidently/features/text_contains_feature.py +++ b/src/evidently/features/text_contains_feature.py @@ -112,3 +112,109 @@ def comparison(self, item: str, string: str): if self.case_sensitive: return item in string return item.casefold() in string.casefold() + + +class ItemMatch(GeneratedFeature): + class Config: + type_alias = "evidently:feature:ItemMatch" + + __feature_type__: ClassVar = ColumnType.Categorical + columns: List[str] + case_sensitive: bool + mode: str + + def __init__( + self, + columns: List[str], + case_sensitive: bool = True, + mode: str = "any", + display_name: Optional[str] = None, + ): + if len(columns) != 2: + raise ValueError("two columns must be provided") + self.columns = columns + self.display_name = display_name + self.case_sensitive = case_sensitive + if mode not in ["any", "all"]: + raise ValueError("mode must be either 'any' or 'all'") + self.mode = mode + super().__init__() + + def _feature_column_name(self) -> str: + return f"{self.columns[0]}_{self.columns[1]}" + "_item_match_" + str(self.case_sensitive) + "_" + self.mode + + def generate_feature(self, data: pd.DataFrame, data_definition: DataDefinition) -> pd.DataFrame: + if self.mode == "any": + calculated = data.apply( + lambda row: any(self.comparison(word, row[self.columns[0]]) for word in row[self.columns[1]]), + axis=1, + ) + else: + calculated = data.apply( + lambda row: all(self.comparison(word, row[self.columns[0]]) for word in row[self.columns[1]]), + axis=1, + ) + return pd.DataFrame({self._feature_column_name(): calculated}) + + def _as_column(self) -> ColumnName: + return self._create_column( + self._feature_column_name(), + default_display_name=f"Text contains {self.mode} of defined items", + ) + + def comparison(self, item: str, string: str): + if self.case_sensitive: + return item in string + return item.casefold() in string.casefold() + + +class ItemNoMatch(GeneratedFeature): + class Config: + type_alias = "evidently:feature:ItemNoMatch" + + __feature_type__: ClassVar = ColumnType.Categorical + columns: List[str] + case_sensitive: bool + mode: str + + def __init__( + self, + columns: List[str], + case_sensitive: bool = True, + mode: str = "any", + display_name: Optional[str] = None, + ): + self.columns = columns + self.display_name = display_name + self.case_sensitive = case_sensitive + if mode not in ["any", "all"]: + raise ValueError("mode must be either 'any' or 'all'") + self.mode = mode + super().__init__() + + def _feature_column_name(self) -> str: + return f"{self.columns[0]}_{self.columns[1]}" + "_item_no_match_" + str(self.case_sensitive) + "_" + self.mode + + def generate_feature(self, data: pd.DataFrame, data_definition: DataDefinition) -> pd.DataFrame: + if self.mode == "any": + calculated = data.apply( + lambda row: not any(self.comparison(word, row[self.columns[0]]) for word in row[self.columns[1]]), + axis=1, + ) + else: + calculated = data.apply( + lambda row: not all(self.comparison(word, row[self.columns[0]]) for word in row[self.columns[1]]), + axis=1, + ) + return pd.DataFrame({self._feature_column_name(): calculated}) + + def _as_column(self) -> ColumnName: + return self._create_column( + self._feature_column_name(), + default_display_name=f"Text does not contain {self.mode} of defined items", + ) + + def comparison(self, item: str, string: str): + if self.case_sensitive: + return item in string + return item.casefold() in string.casefold() diff --git a/tests/features/test_text_contains_feature.py b/tests/features/test_text_contains_feature.py index 51dceeb680..3b590c9de4 100644 --- a/tests/features/test_text_contains_feature.py +++ b/tests/features/test_text_contains_feature.py @@ -5,6 +5,8 @@ from evidently.features.text_contains_feature import Contains from evidently.features.text_contains_feature import DoesNotContain +from evidently.features.text_contains_feature import ItemMatch +from evidently.features.text_contains_feature import ItemNoMatch from evidently.pipeline.column_mapping import ColumnMapping from evidently.utils.data_preprocessing import create_data_definition @@ -61,3 +63,83 @@ def test_text_not_contains_feature(items: List[str], case: bool, mode: str, expe column_expected = feature_generator._feature_column_name() expected_df = pd.DataFrame({column_expected: expected}) assert result.equals(expected_df) + + +@pytest.mark.parametrize( + ("case", "mode", "expected"), + [ + (True, "any", [False, True, False, True, False]), + (True, "all", [False, True, False, False, False]), + (False, "any", [True, True, True, True, False]), + (False, "all", [False, True, True, False, False]), + ], +) +def test_item_match(case: bool, mode: str, expected: List[bool]): + data = { + "generated": [ + "You should consider purchasing Nike or Adidas shoes.", + "I eat apples, grapes, and oranges", + "grapes, oranges, apples.", + "Oranges are more sour than grapes.", + "This test doesn't have the words.", + ], + "expected": [ + ["nike", "adidas", "puma"], + ["grapes", "apples", "oranges"], + ["Apples", "Oranges", "Grapes"], + ["orange", "sweet", "grape"], + ["none", "of", "these"], + ], + } + df = pd.DataFrame(data) + df["expected"] = df["expected"].apply(tuple) + feature_generator = ItemMatch(columns=["generated", "expected"], case_sensitive=case, mode=mode) + result = feature_generator.generate_feature( + data=df, + data_definition=create_data_definition(None, df, ColumnMapping()), + ) + column_expected = feature_generator._feature_column_name() + column_name_obj = feature_generator._as_column() + expected_df = pd.DataFrame({column_expected: expected}) + assert result.equals(expected_df) + assert column_name_obj.display_name == f"Text contains {mode} of defined items" + + +@pytest.mark.parametrize( + ("case", "mode", "expected"), + [ + (True, "any", [True, False, True, False, True]), + (True, "all", [True, False, True, True, True]), + (False, "any", [False, False, False, False, True]), + (False, "all", [True, False, False, True, True]), + ], +) +def test_item_no_match(case: bool, mode: str, expected: List[bool]): + data = { + "generated": [ + "You should consider purchasing Nike or Adidas shoes.", + "I eat apples, grapes, and oranges", + "grapes, oranges, apples.", + "Oranges are more sour than grapes.", + "This test doesn't have the words.", + ], + "forbidden": [ + ["nike", "adidas", "puma"], + ["grapes", "apples", "oranges"], + ["Apples", "Oranges", "Grapes"], + ["orange", "sweet", "grape"], + ["none", "of", "these"], + ], + } + feature_generator = ItemNoMatch(columns=["generated", "forbidden"], case_sensitive=case, mode=mode) + df = pd.DataFrame(data) + df["forbidden"] = df["forbidden"].apply(tuple) + result = feature_generator.generate_feature( + data=df, + data_definition=create_data_definition(None, df, ColumnMapping()), + ) + column_expected = feature_generator._feature_column_name() + column_name_obj = feature_generator._as_column() + expected_df = pd.DataFrame({column_expected: expected}) + assert result.equals(expected_df) + assert column_name_obj.display_name == f"Text does not contain {mode} of defined items"