From eeffb204251005e47c91a6ceed9ea3f18f6df41e Mon Sep 17 00:00:00 2001
From: Vincent Menger
Date: Fri, 1 Dec 2023 20:44:56 +0100
Subject: [PATCH] Cleanup annotator code

---
 docdeid/process/annotator.py         | 45 +++++++++++++++++-----------
 tests/unit/process/test_annotator.py | 34 +++++++++++++++++++++
 2 files changed, 61 insertions(+), 18 deletions(-)

diff --git a/docdeid/process/annotator.py b/docdeid/process/annotator.py
index c579113..7e301d9 100644
--- a/docdeid/process/annotator.py
+++ b/docdeid/process/annotator.py
@@ -133,44 +133,52 @@ def __init__(
     ) -> None:
 
         self.overlapping = overlapping
-        self.trie = LookupTrie(matching_pipeline=matching_pipeline)
-        self.matching_pipeline = matching_pipeline or []
+        self._matching_pipeline = matching_pipeline or []
 
-        self.start_tokens = set()
+        self._trie = LookupTrie(matching_pipeline=matching_pipeline)
+        self._start_texts = set()
+
+        self._init_lookup_structures(lookup_values, tokenizer)
+
+        super().__init__(*args, **kwargs)
+
+    def _init_lookup_structures(self, lookup_values: Iterable[str], tokenizer: Tokenizer):
 
         for val in lookup_values:
+            texts = [token.text for token in tokenizer.tokenize(val)]
 
             if len(texts) > 0:
 
-                self.trie.add_item(texts)
+                self._trie.add_item(texts)
 
                 start_token = texts[0]
 
-                for string_modifier in self.matching_pipeline:
+                for string_modifier in self._matching_pipeline:
                     start_token = string_modifier.process(start_token)
 
-                self.start_tokens.add(start_token)
-
-        super().__init__(*args, **kwargs)
+                self._start_texts.add(start_token)
 
     def annotate(self, doc: Document) -> list[Annotation]:
 
         tokens = doc.get_tokens()
-        start_positions = sorted(
-            tokens.token_lookup(self.start_tokens, matching_pipeline=self.matching_pipeline),
+
+        start_tokens = sorted(
+            tokens.token_lookup(self._start_texts, matching_pipeline=self._matching_pipeline),
             key=lambda token: token.start_char,
         )
-        start_positions = [tokens.token_index(token) for token in start_positions]
+
+        start_indices = [tokens.token_index(token) for token in start_tokens]
+
         tokens_text = [token.text for token in tokens]
         annotations = []
         min_i = 0
 
-        for i in start_positions:
+        for i in start_indices:
 
             if i < min_i:
                 continue
 
-            longest_matching_prefix = self.trie.longest_matching_prefix(tokens_text, start_i=i)
+            longest_matching_prefix = self._trie.longest_matching_prefix(tokens_text, start_i=i)
 
             if longest_matching_prefix is None:
                 continue
@@ -203,9 +211,10 @@ class RegexpAnnotator(Annotator):
 
     Args:
         tag: The tag to use in the annotations.
-        regexp_pattern: A compiled ``re.Pattern``, that will be used for matching.
+        regexp_pattern: A pattern, either as a `str` or a ``re.Pattern``, that will be used for matching.
         capturing_group: The capturing group of the pattern that should be used to produce the annotation. By
             default, the entire match is used.
+        pre_match_tokens: A list of tokens, of which at least one must be present for the annotator to start matching the regexp at all.
""" def __init__( @@ -213,7 +222,7 @@ def __init__( regexp_pattern: Union[re.Pattern, str], *args, capturing_group: int = 0, - pre_tokens: Optional[list[str]] = None, + pre_match_tokens: Optional[list[str]] = None, **kwargs, ) -> None: @@ -223,10 +232,10 @@ def __init__( self.regexp_pattern = regexp_pattern self.capturing_group = capturing_group - self.pre_tokens = pre_tokens + self.pre_tokens = pre_match_tokens - if pre_tokens is not None: - self.pre_tokens = set(pre_tokens) + if pre_match_tokens is not None: + self.pre_tokens = set(pre_match_tokens) self.matching_pipeline = [docdeid.str.LowercaseString()] super().__init__(*args, **kwargs) diff --git a/tests/unit/process/test_annotator.py b/tests/unit/process/test_annotator.py index d826e38..6f397ec 100644 --- a/tests/unit/process/test_annotator.py +++ b/tests/unit/process/test_annotator.py @@ -78,7 +78,41 @@ def test_multi_token_with_matching_pipeline(self, long_text, long_tokenlist): ] with patch.object(doc, "get_tokens", return_value=long_tokenlist): + annotations = annotator.annotate(doc) + + assert annotations == expected_annotations + + def test_multi_token_lookup_with_overlap(self, long_text, long_tokenlist): + + doc = Document(long_text) + + annotator = MultiTokenLookupAnnotator( + lookup_values=["dr. John", "John Smith"], tokenizer=WordBoundaryTokenizer(), tag="prefix", overlapping=True + ) + expected_annotations = [ + Annotation(text="dr. John", start_char=11, end_char=19, tag="prefix"), + Annotation(text="John Smith", start_char=15, end_char=25, tag="prefix"), + ] + + with patch.object(doc, "get_tokens", return_value=long_tokenlist): + annotations = annotator.annotate(doc) + + assert annotations == expected_annotations + + def test_multi_token_lookup_no_overlap(self, long_text, long_tokenlist): + + doc = Document(long_text) + + annotator = MultiTokenLookupAnnotator( + lookup_values=["dr. John", "John Smith"], tokenizer=WordBoundaryTokenizer(), tag="prefix", overlapping=False + ) + + expected_annotations = [ + Annotation(text="dr. John", start_char=11, end_char=19, tag="prefix"), + ] + + with patch.object(doc, "get_tokens", return_value=long_tokenlist): annotations = annotator.annotate(doc) assert annotations == expected_annotations