diff --git a/gilda/api.py b/gilda/api.py index e46e257..b5aaa1e 100644 --- a/gilda/api.py +++ b/gilda/api.py @@ -112,7 +112,6 @@ def annotate( sent_split_fun=None, organisms=None, namespaces=None, - return_first: bool = True, context_text: str = None, ): """Annotate a given text with Gilda (i.e., do named entity recognition). @@ -132,18 +131,16 @@ def annotate( namespaces : list[str], optional A list of namespaces to pass to the grounder to restrict the matches to. By default, no restriction is applied. - return_first : - If true, only returns the first result. Otherwise, returns all results. context_text : A longer span of text that serves as additional context for the text being annotated for disambiguation purposes. Returns ------- - list[tuple[str, ScoredMatch, int, int]] - A list of tuples of start and end character offsets of the text - corresponding to the entity, the entity text, and the ScoredMatch - object corresponding to the entity. + list[tuple[str, list[ScoredMatch], int, int]] + A list of matches where each match is a tuple consisting of + the matches text span, the list of ScoredMatches, and the + start and end character offsets of the text span. """ from .ner import annotate as _annotate @@ -153,7 +150,6 @@ def annotate( sent_split_fun=sent_split_fun, organisms=organisms, namespaces=namespaces, - return_first=return_first, context_text=context_text, ) diff --git a/gilda/app/templates/ner_matches.html b/gilda/app/templates/ner_matches.html index a9ff447..94614a4 100644 --- a/gilda/app/templates/ner_matches.html +++ b/gilda/app/templates/ner_matches.html @@ -25,9 +25,9 @@

NER Results

- {% for text, match, start, end in annotations %} + {% for text, matches, start, end in annotations %} - {% set match_curie = match.term.get_curie() %} + {% set match_curie = matches[0].term.get_curie() %} {{start}}-{{end}} diff --git a/gilda/ner.py b/gilda/ner.py index 2340367..75fbe59 100644 --- a/gilda/ner.py +++ b/gilda/ner.py @@ -9,7 +9,8 @@ The results are a list of 4-tuples containing: - the text string matched -- a :class:`gilda.grounder.ScoredMatch` instance containing the _best_ match +- a list of :class:`gilda.grounder.ScoredMatch` instances containing a sorted list of matches + for the given text span (first one is the best match) - the position in the text string where the entity starts - the position in the text string where the entity ends @@ -69,7 +70,6 @@ def annotate( sent_split_fun=None, organisms=None, namespaces=None, - return_first: bool = True, context_text: str = None, ) -> List[Annotation]: """Annotate a given text with Gilda. @@ -91,18 +91,16 @@ def annotate( namespaces : list[str], optional A list of namespaces to pass to the grounder to restrict the matches to. By default, no restriction is applied. - return_first : - If true, only returns the first result. Otherwise, returns all results. context_text : A longer span of text that serves as additional context for the text being annotated for disambiguation purposes. Returns ------- - list[tuple[str, ScoredMatch, int, int]] - A list of tuples of start and end character offsets of the text - corresponding to the entity, the entity text, and the ScoredMatch - object corresponding to the entity. + list[tuple[str, list[ScoredMatch], int, int]] + A list of matches where each match is a tuple consisting of + the matches text span, the list of ScoredMatches, and the + start and end character offsets of the text span. """ if grounder is None: grounder = get_grounder() @@ -136,22 +134,19 @@ def annotate( # Find the largest matching span for span in sorted(applicable_spans, reverse=True): txt_span = ' '.join(raw_words[idx:idx+span]) - matches = grounder.ground( - txt_span, context=text if context_text is None else context_text, - organisms=organisms, namespaces=namespaces, - ) + context = text if context_text is None else context_text + matches = grounder.ground(txt_span, + context=context, + organisms=organisms, + namespaces=namespaces) if matches: start_coord = word_coords[idx] end_coord = word_coords[idx+span-1] + \ len(raw_words[idx+span-1]) raw_span = ' '.join(raw_words[idx:idx+span]) - - if return_first: - matches = [matches[0]] - for match in matches: - entities.append( - (raw_span, match, start_coord, end_coord) - ) + entities.append(( + raw_span, matches, start_coord, end_coord + )) skip_until = idx + span break @@ -163,7 +158,7 @@ def get_brat(entities, entity_type="Entity", ix_offset=1, include_text=True): Parameters ---------- - entities : list[tuple[str, str | ScoredMatch, int, int]] + entities : list[tuple[str, str | list[str] | list[ScoredMatch], int, int]] A list of tuples of entity text, grounded curie, start and end character offsets in the text corresponding to an entity. entity_type : str, optional @@ -184,9 +179,14 @@ def get_brat(entities, entity_type="Entity", ix_offset=1, include_text=True): """ brat = [] ix_offset = max(1, ix_offset) - for idx, (raw_span, curie, start, end) in enumerate(entities, ix_offset): - if isinstance(curie, ScoredMatch): - curie = curie.term.get_curie() + for idx, (raw_span, curies, start, end) in enumerate(entities, ix_offset): + if isinstance(curies, str): + curie = curies + # Note that here we always that the best match and ignore the rest + else: + curie = curies[0] + if isinstance(curie, ScoredMatch): + curie = curie.term.get_curie() if entity_type != "Entity": curie += f"; Reading system: {entity_type}" row = f'T{idx}\t{entity_type} {start} {end}' + ( diff --git a/gilda/tests/test_ner.py b/gilda/tests/test_ner.py index 13ae67b..55060ac 100644 --- a/gilda/tests/test_ner.py +++ b/gilda/tests/test_ner.py @@ -28,14 +28,14 @@ def test_annotate(): assert annotations[6][2:4] == (56, 63) # protein # Check that the curies are correct - assert isinstance(annotations[0][1], gilda.ScoredMatch) - assert annotations[0][1].term.get_curie() == "CHEBI:36080" - assert annotations[1][1].term.get_curie() == "hgnc:1097" - assert annotations[2][1].term.get_curie() == "mesh:D010770" - assert annotations[3][1].term.get_curie() == "hgnc:1097" - assert annotations[4][1].term.get_curie() == "mesh:D005796" - assert annotations[5][1].term.get_curie() == "hgnc:1097" - assert annotations[6][1].term.get_curie() == "CHEBI:36080" + assert isinstance(annotations[0][1][0], gilda.ScoredMatch) + assert annotations[0][1][0].term.get_curie() == "CHEBI:36080" + assert annotations[1][1][0].term.get_curie() == "hgnc:1097" + assert annotations[2][1][0].term.get_curie() == "mesh:D010770" + assert annotations[3][1][0].term.get_curie() == "hgnc:1097" + assert annotations[4][1][0].term.get_curie() == "mesh:D005796" + assert annotations[5][1][0].term.get_curie() == "hgnc:1097" + assert annotations[6][1][0].term.get_curie() == "CHEBI:36080" def test_get_brat(): @@ -66,12 +66,12 @@ def test_get_brat(): def test_get_all(): full_text = "This is about ER." - results = gilda.annotate(full_text, return_first=False) - assert len(results) > 1 - curies = { - scored_match.term.get_curie() - for _, scored_match, _, _ in results - } + results = gilda.annotate(full_text) + assert len(results) == 1 + curies = set() + for _, scored_matches, _, _ in results: + for scored_match in scored_matches: + curies.add(scored_match.term.get_curie()) assert "hgnc:3467" in curies # ESR1 assert "fplx:ESR" in curies assert "GO:0005783" in curies # endoplasmic reticulum @@ -82,13 +82,13 @@ def test_context_test(): context_text = "Estrogen receptor (ER) is a protein family." results = gilda.annotate(text, context_text=context_text) assert len(results) == 1 - assert results[0][1].term.get_curie() == "fplx:ESR" + assert results[0][1][0].term.get_curie() == "fplx:ESR" assert results[0][0] == "ER" assert results[0][2:4] == (14, 16) context_text = "Calcium is released from the ER." results = gilda.annotate(text, context_text=context_text) assert len(results) == 1 - assert results[0][1].term.get_curie() == "GO:0005783" + assert results[0][1][0].term.get_curie() == "GO:0005783" assert results[0][0] == "ER" assert results[0][2:4] == (14, 16)